mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-18 09:46:21 +03:00
[BROKEN] Start implementing better support for addressable arguments
This commit is contained in:
@ -354,6 +354,7 @@ pub struct CallInst<P: ArgParams> {
|
|||||||
pub trait ArgParams {
|
pub trait ArgParams {
|
||||||
type ID;
|
type ID;
|
||||||
type Operand;
|
type Operand;
|
||||||
|
type MemoryOperand;
|
||||||
type CallOperand;
|
type CallOperand;
|
||||||
type VecOperand;
|
type VecOperand;
|
||||||
}
|
}
|
||||||
@ -365,6 +366,7 @@ pub struct ParsedArgParams<'a> {
|
|||||||
impl<'a> ArgParams for ParsedArgParams<'a> {
|
impl<'a> ArgParams for ParsedArgParams<'a> {
|
||||||
type ID = &'a str;
|
type ID = &'a str;
|
||||||
type Operand = Operand<&'a str>;
|
type Operand = Operand<&'a str>;
|
||||||
|
type MemoryOperand = Operand<&'a str>;
|
||||||
type CallOperand = CallOperand<&'a str>;
|
type CallOperand = CallOperand<&'a str>;
|
||||||
type VecOperand = (&'a str, u8);
|
type VecOperand = (&'a str, u8);
|
||||||
}
|
}
|
||||||
@ -378,8 +380,13 @@ pub struct Arg2<P: ArgParams> {
|
|||||||
pub src: P::Operand,
|
pub src: P::Operand,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct Arg2Ld<P: ArgParams> {
|
||||||
|
pub dst: P::ID,
|
||||||
|
pub src: P::MemoryOperand,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Arg2St<P: ArgParams> {
|
pub struct Arg2St<P: ArgParams> {
|
||||||
pub src1: P::Operand,
|
pub src1: P::MemoryOperand,
|
||||||
pub src2: P::Operand,
|
pub src2: P::Operand,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -416,13 +423,13 @@ pub struct Arg5<P: ArgParams> {
|
|||||||
pub enum Operand<ID> {
|
pub enum Operand<ID> {
|
||||||
Reg(ID),
|
Reg(ID),
|
||||||
RegOffset(ID, i32),
|
RegOffset(ID, i32),
|
||||||
Imm(i128),
|
Imm(u32),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub enum CallOperand<ID> {
|
pub enum CallOperand<ID> {
|
||||||
Reg(ID),
|
Reg(ID),
|
||||||
Imm(i128),
|
Imm(u32),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum VectorPrefix {
|
pub enum VectorPrefix {
|
||||||
|
@ -446,7 +446,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
||||||
InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||||
"ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:ExtendedID> "," "[" <src:Operand> "]" => {
|
"ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:ExtendedID> "," <src:MemoryOperand> => {
|
||||||
ast::Instruction::Ld(
|
ast::Instruction::Ld(
|
||||||
ast::LdData {
|
ast::LdData {
|
||||||
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
|
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
|
||||||
@ -899,7 +899,7 @@ ShlType: ast::ShlType = {
|
|||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
|
||||||
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
|
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
|
||||||
InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||||
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> "[" <src1:Operand> "]" "," <src2:Operand> => {
|
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> <src1:MemoryOperand> "," <src2:Operand> => {
|
||||||
ast::Instruction::St(
|
ast::Instruction::St(
|
||||||
ast::StData {
|
ast::StData {
|
||||||
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
|
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
|
||||||
@ -912,6 +912,11 @@ InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#using-addresses-arrays-and-vectors
|
||||||
|
MemoryOperand: ast::Operand<&'input str> = {
|
||||||
|
"[" <o:Operand> "]" => o
|
||||||
|
}
|
||||||
|
|
||||||
StStateSpace: ast::StStateSpace = {
|
StStateSpace: ast::StStateSpace = {
|
||||||
".global" => ast::StStateSpace::Global,
|
".global" => ast::StStateSpace::Global,
|
||||||
".local" => ast::StStateSpace::Local,
|
".local" => ast::StStateSpace::Local,
|
||||||
@ -1006,7 +1011,7 @@ Operand: ast::Operand<&'input str> = {
|
|||||||
// TODO: start parsing whole constants sub-language:
|
// TODO: start parsing whole constants sub-language:
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants
|
||||||
<o:Num> => {
|
<o:Num> => {
|
||||||
let offset = o.parse::<i128>();
|
let offset = o.parse::<u32>();
|
||||||
let offset = offset.unwrap_with(errors);
|
let offset = offset.unwrap_with(errors);
|
||||||
ast::Operand::Imm(offset)
|
ast::Operand::Imm(offset)
|
||||||
}
|
}
|
||||||
@ -1015,7 +1020,7 @@ Operand: ast::Operand<&'input str> = {
|
|||||||
CallOperand: ast::CallOperand<&'input str> = {
|
CallOperand: ast::CallOperand<&'input str> = {
|
||||||
<r:ExtendedID> => ast::CallOperand::Reg(r),
|
<r:ExtendedID> => ast::CallOperand::Reg(r),
|
||||||
<o:Num> => {
|
<o:Num> => {
|
||||||
let offset = o.parse::<i128>();
|
let offset = o.parse::<u32>();
|
||||||
let offset = offset.unwrap_with(errors);
|
let offset = offset.unwrap_with(errors);
|
||||||
ast::CallOperand::Imm(offset)
|
ast::CallOperand::Imm(offset)
|
||||||
}
|
}
|
||||||
|
@ -59,6 +59,8 @@ test_ptx!(local_align, [1u64], [1u64]);
|
|||||||
test_ptx!(call, [1u64], [2u64]);
|
test_ptx!(call, [1u64], [2u64]);
|
||||||
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
|
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
|
||||||
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
|
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
|
||||||
|
test_ptx!(ntid, [3u32], [4u32]);
|
||||||
|
test_ptx!(reg_slm, [12u64], [12u64]);
|
||||||
|
|
||||||
struct DisplayError<T: Debug> {
|
struct DisplayError<T: Debug> {
|
||||||
err: T,
|
err: T,
|
||||||
|
23
ptx/src/test/spirv_run/ntid.ptx
Normal file
23
ptx/src/test/spirv_run/ntid.ptx
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry ntid(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .u32 in_val;
|
||||||
|
.reg .u32 global_count;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.u32 in_val, [in_addr];
|
||||||
|
mov.u32 global_count, %ntid.x;
|
||||||
|
add.u32 in_val, in_val, global_count;
|
||||||
|
st.u32 [out_addr], in_val;
|
||||||
|
ret;
|
||||||
|
}
|
56
ptx/src/test/spirv_run/ntid.spvtxt
Normal file
56
ptx/src/test/spirv_run/ntid.spvtxt
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
OpCapability GenericPointer
|
||||||
|
OpCapability Linkage
|
||||||
|
OpCapability Addresses
|
||||||
|
OpCapability Kernel
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Int8
|
||||||
|
%29 = OpExtInstImport "OpenCL.std"
|
||||||
|
OpMemoryModel Physical64 OpenCL
|
||||||
|
OpEntryPoint Kernel %1 "add" %GlobalSize
|
||||||
|
OpDecorate %GlobalSize BuiltIn GlobalSize
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%uint = OpTypeInt 32 0
|
||||||
|
%v3uint = OpTypeVector %uint 3
|
||||||
|
%_ptr_UniformConstant_v3uint = OpTypePointer UniformConstant %v3uint
|
||||||
|
%GlobalSize = OpVariable %_ptr_UniformConstant_v3uint UniformConstant
|
||||||
|
%ulong = OpTypeInt 64 0
|
||||||
|
%35 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
|
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||||
|
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||||
|
%1 = OpFunction %void None %35
|
||||||
|
%9 = OpFunctionParameter %ulong
|
||||||
|
%10 = OpFunctionParameter %ulong
|
||||||
|
%27 = OpLabel
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%5 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%6 = OpVariable %_ptr_Function_uint Function
|
||||||
|
%7 = OpVariable %_ptr_Function_uint Function
|
||||||
|
OpStore %2 %9
|
||||||
|
OpStore %3 %10
|
||||||
|
%12 = OpLoad %ulong %2
|
||||||
|
%11 = OpCopyObject %ulong %12
|
||||||
|
OpStore %4 %11
|
||||||
|
%14 = OpLoad %ulong %3
|
||||||
|
%13 = OpCopyObject %ulong %14
|
||||||
|
OpStore %5 %13
|
||||||
|
%16 = OpLoad %ulong %4
|
||||||
|
%25 = OpConvertUToPtr %_ptr_Generic_uint %16
|
||||||
|
%15 = OpLoad %uint %25
|
||||||
|
OpStore %6 %15
|
||||||
|
%18 = OpLoad %v3uint %GlobalSize
|
||||||
|
%24 = OpCompositeExtract %uint %18 0
|
||||||
|
%17 = OpCopyObject %uint %24
|
||||||
|
OpStore %7 %17
|
||||||
|
%20 = OpLoad %uint %6
|
||||||
|
%21 = OpLoad %uint %7
|
||||||
|
%19 = OpIAdd %uint %20 %21
|
||||||
|
OpStore %6 %19
|
||||||
|
%22 = OpLoad %ulong %5
|
||||||
|
%23 = OpLoad %uint %6
|
||||||
|
%26 = OpConvertUToPtr %_ptr_Generic_uint %22
|
||||||
|
OpStore %26 %23
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
26
ptx/src/test/spirv_run/reg_slm.ptx
Normal file
26
ptx/src/test/spirv_run/reg_slm.ptx
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry reg_slm(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.local .align 8 .b8 slm[8];
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .b64 temp;
|
||||||
|
.reg .s64 unused;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
mov.s64 unused, slm;
|
||||||
|
|
||||||
|
ld.global.u64 temp, [in_addr];
|
||||||
|
st.u64 [slm], temp;
|
||||||
|
ld.u64 temp, [slm];
|
||||||
|
st.global.u64 [out_addr], temp;
|
||||||
|
ret;
|
||||||
|
}
|
46
ptx/src/test/spirv_run/reg_slm.spvtxt
Normal file
46
ptx/src/test/spirv_run/reg_slm.spvtxt
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
OpCapability GenericPointer
|
||||||
|
OpCapability Linkage
|
||||||
|
OpCapability Addresses
|
||||||
|
OpCapability Kernel
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Int8
|
||||||
|
%25 = OpExtInstImport "OpenCL.std"
|
||||||
|
OpMemoryModel Physical64 OpenCL
|
||||||
|
OpEntryPoint Kernel %1 "add"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%ulong = OpTypeInt 64 0
|
||||||
|
%28 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
|
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||||
|
%ulong_1 = OpConstant %ulong 1
|
||||||
|
%1 = OpFunction %void None %28
|
||||||
|
%8 = OpFunctionParameter %ulong
|
||||||
|
%9 = OpFunctionParameter %ulong
|
||||||
|
%23 = OpLabel
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%5 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%6 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%7 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
OpStore %2 %8
|
||||||
|
OpStore %3 %9
|
||||||
|
%11 = OpLoad %ulong %2
|
||||||
|
%10 = OpCopyObject %ulong %11
|
||||||
|
OpStore %4 %10
|
||||||
|
%13 = OpLoad %ulong %3
|
||||||
|
%12 = OpCopyObject %ulong %13
|
||||||
|
OpStore %5 %12
|
||||||
|
%15 = OpLoad %ulong %4
|
||||||
|
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
|
||||||
|
%14 = OpLoad %ulong %21
|
||||||
|
OpStore %6 %14
|
||||||
|
%17 = OpLoad %ulong %6
|
||||||
|
%16 = OpIAdd %ulong %17 %ulong_1
|
||||||
|
OpStore %7 %16
|
||||||
|
%18 = OpLoad %ulong %5
|
||||||
|
%19 = OpLoad %ulong %7
|
||||||
|
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
|
||||||
|
OpStore %22 %19
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
@ -286,7 +286,7 @@ fn expand_kernel_params<'a, 'b>(
|
|||||||
args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
|
args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
|
||||||
) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
|
) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
|
||||||
args.map(|a| ast::KernelArgument {
|
args.map(|a| ast::KernelArgument {
|
||||||
name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
|
name: fn_resolver.add_def(a.name, Some((StateSpace::Param, ast::Type::from(a.v_type)))),
|
||||||
v_type: a.v_type,
|
v_type: a.v_type,
|
||||||
align: a.align,
|
align: a.align,
|
||||||
})
|
})
|
||||||
@ -297,10 +297,16 @@ fn expand_fn_params<'a, 'b>(
|
|||||||
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
|
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
|
||||||
args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
|
args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
|
||||||
) -> Vec<ast::FnArgument<ExpandedArgParams>> {
|
) -> Vec<ast::FnArgument<ExpandedArgParams>> {
|
||||||
args.map(|a| ast::FnArgument {
|
args.map(|a| {
|
||||||
name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
|
let ss = match a.v_type {
|
||||||
|
ast::FnArgumentType::Reg(_) => StateSpace::Reg,
|
||||||
|
ast::FnArgumentType::Param(_) => StateSpace::Param,
|
||||||
|
};
|
||||||
|
ast::FnArgument {
|
||||||
|
name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type)))),
|
||||||
v_type: a.v_type,
|
v_type: a.v_type,
|
||||||
align: a.align,
|
align: a.align,
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
@ -325,6 +331,8 @@ fn to_ssa<'input, 'b>(
|
|||||||
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
|
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
|
||||||
let unadorned_statements =
|
let unadorned_statements =
|
||||||
add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs);
|
add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs);
|
||||||
|
todo!()
|
||||||
|
/*
|
||||||
let (f_args, ssa_statements) =
|
let (f_args, ssa_statements) =
|
||||||
insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args);
|
insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args);
|
||||||
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs);
|
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs);
|
||||||
@ -336,6 +344,7 @@ fn to_ssa<'input, 'b>(
|
|||||||
func_directive: f_args,
|
func_directive: f_args,
|
||||||
body: Some(sorted_statements),
|
body: Some(sorted_statements),
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> {
|
fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> {
|
||||||
@ -350,7 +359,7 @@ fn add_types_to_statements(
|
|||||||
func: Vec<UnadornedStatement>,
|
func: Vec<UnadornedStatement>,
|
||||||
fn_defs: &GlobalFnDeclResolver,
|
fn_defs: &GlobalFnDeclResolver,
|
||||||
id_defs: &NumericIdResolver,
|
id_defs: &NumericIdResolver,
|
||||||
) -> Vec<UnadornedStatement> {
|
) -> Vec<TypedStatement> {
|
||||||
func.into_iter()
|
func.into_iter()
|
||||||
.map(|s| {
|
.map(|s| {
|
||||||
match s {
|
match s {
|
||||||
@ -359,7 +368,7 @@ fn add_types_to_statements(
|
|||||||
let fn_def = fn_defs.get_fn_decl(call.func);
|
let fn_def = fn_defs.get_fn_decl(call.func);
|
||||||
let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
|
let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
|
||||||
let param_list = to_resolved_fn_args(call.param_list, &*fn_def.params);
|
let param_list = to_resolved_fn_args(call.param_list, &*fn_def.params);
|
||||||
let resolved_call = ResolvedCall {
|
let resolved_call: ResolvedCall<TypedArgParams> = ResolvedCall {
|
||||||
uniform: call.uniform,
|
uniform: call.uniform,
|
||||||
ret_params,
|
ret_params,
|
||||||
func: call.func,
|
func: call.func,
|
||||||
@ -367,18 +376,13 @@ fn add_types_to_statements(
|
|||||||
};
|
};
|
||||||
Statement::Call(resolved_call)
|
Statement::Call(resolved_call)
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::MovVector(dets, args)) => {
|
Statement::Instruction(ast::Instruction::Ld(d, arg)) => {
|
||||||
// TODO fail on type mismatch
|
todo!()
|
||||||
let new_dets = match id_defs.get_type(*args.dst()) {
|
|
||||||
Some(ast::Type::Vector(_, len)) => ast::MovVectorDetails {
|
|
||||||
length: len,
|
|
||||||
..dets
|
|
||||||
},
|
|
||||||
_ => dets,
|
|
||||||
};
|
|
||||||
Statement::Instruction(ast::Instruction::MovVector(new_dets, args))
|
|
||||||
}
|
}
|
||||||
s => s,
|
Statement::Instruction(ast::Instruction::MovVector(dets, args)) => {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
s => todo!(),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
@ -485,7 +489,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||||||
ast::MethodDecl::Kernel(_, in_params) => {
|
ast::MethodDecl::Kernel(_, in_params) => {
|
||||||
for p in in_params.iter_mut() {
|
for p in in_params.iter_mut() {
|
||||||
let typ = ast::Type::from(p.v_type);
|
let typ = ast::Type::from(p.v_type);
|
||||||
let new_id = id_def.new_id(Some(typ));
|
let new_id = id_def.new_id(Some((StateSpace::Param, typ)));
|
||||||
result.push(Statement::Variable(ast::Variable {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
align: p.align,
|
align: p.align,
|
||||||
v_type: ast::VariableType::Param(p.v_type),
|
v_type: ast::VariableType::Param(p.v_type),
|
||||||
@ -504,8 +508,12 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||||||
}
|
}
|
||||||
ast::MethodDecl::Func(out_params, _, in_params) => {
|
ast::MethodDecl::Func(out_params, _, in_params) => {
|
||||||
for p in in_params.iter_mut() {
|
for p in in_params.iter_mut() {
|
||||||
|
let ss = match p.v_type {
|
||||||
|
ast::FnArgumentType::Reg(_) => StateSpace::Reg,
|
||||||
|
ast::FnArgumentType::Param(_) => StateSpace::Param,
|
||||||
|
};
|
||||||
let typ = ast::Type::from(p.v_type);
|
let typ = ast::Type::from(p.v_type);
|
||||||
let new_id = id_def.new_id(Some(typ));
|
let new_id = id_def.new_id(Some((ss, typ)));
|
||||||
let var_typ = ast::VariableType::from(p.v_type);
|
let var_typ = ast::VariableType::from(p.v_type);
|
||||||
result.push(Statement::Variable(ast::Variable {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
align: p.align,
|
align: p.align,
|
||||||
@ -548,7 +556,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||||||
dst: new_id,
|
dst: new_id,
|
||||||
src: out_param,
|
src: out_param,
|
||||||
},
|
},
|
||||||
typ.unwrap(),
|
typ.unwrap().1,
|
||||||
));
|
));
|
||||||
result.push(Statement::RetValue(d, new_id));
|
result.push(Statement::RetValue(d, new_id));
|
||||||
} else {
|
} else {
|
||||||
@ -558,7 +566,10 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||||||
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst),
|
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst),
|
||||||
},
|
},
|
||||||
Statement::Conditional(mut bra) => {
|
Statement::Conditional(mut bra) => {
|
||||||
let generated_id = id_def.new_id(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
|
let generated_id = id_def.new_id(Some((
|
||||||
|
StateSpace::Reg,
|
||||||
|
ast::Type::Scalar(ast::ScalarType::Pred),
|
||||||
|
)));
|
||||||
result.push(Statement::LoadVar(
|
result.push(Statement::LoadVar(
|
||||||
Arg2 {
|
Arg2 {
|
||||||
dst: generated_id,
|
dst: generated_id,
|
||||||
@ -607,11 +618,12 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
|
|||||||
let mut post_statements = Vec::new();
|
let mut post_statements = Vec::new();
|
||||||
let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, _| {
|
let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, _| {
|
||||||
let id_type = match (id_def.get_type(desc.op), desc.sema) {
|
let id_type = match (id_def.get_type(desc.op), desc.sema) {
|
||||||
(Some(t), ArgumentSemantics::ParamPtr) | (Some(t), ArgumentSemantics::Default) => t,
|
(Some((_, t)), ArgumentSemantics::ParamPtr)
|
||||||
(Some(t), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64),
|
| (Some((_, t)), ArgumentSemantics::Default) => t,
|
||||||
|
(Some((_, t)), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64),
|
||||||
(None, _) => return desc.op,
|
(None, _) => return desc.op,
|
||||||
};
|
};
|
||||||
let generated_id = id_def.new_id(Some(id_type));
|
let generated_id = id_def.new_id(Some((StateSpace::Reg, id_type)));
|
||||||
if !desc.is_dst {
|
if !desc.is_dst {
|
||||||
result.push(Statement::LoadVar(
|
result.push(Statement::LoadVar(
|
||||||
Arg2 {
|
Arg2 {
|
||||||
@ -716,11 +728,13 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||||||
} else {
|
} else {
|
||||||
todo!()
|
todo!()
|
||||||
};
|
};
|
||||||
let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
|
let id = self
|
||||||
|
.id_def
|
||||||
|
.new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
|
||||||
self.func.push(Statement::Constant(ConstantDefinition {
|
self.func.push(Statement::Constant(ConstantDefinition {
|
||||||
dst: id,
|
dst: id,
|
||||||
typ: scalar_t,
|
typ: scalar_t,
|
||||||
value: x,
|
value: x as i64,
|
||||||
}));
|
}));
|
||||||
id
|
id
|
||||||
}
|
}
|
||||||
@ -732,13 +746,14 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||||||
} else {
|
} else {
|
||||||
todo!()
|
todo!()
|
||||||
};
|
};
|
||||||
let id_constant_stmt =
|
let id_constant_stmt = self
|
||||||
self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
|
.id_def
|
||||||
let result_id = self.id_def.new_id(Some(typ));
|
.new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
|
||||||
|
let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ)));
|
||||||
self.func.push(Statement::Constant(ConstantDefinition {
|
self.func.push(Statement::Constant(ConstantDefinition {
|
||||||
dst: id_constant_stmt,
|
dst: id_constant_stmt,
|
||||||
typ: scalar_t,
|
typ: scalar_t,
|
||||||
value: offset as i128,
|
value: offset as i64,
|
||||||
}));
|
}));
|
||||||
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
|
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
|
||||||
self.func.push(Statement::Instruction(
|
self.func.push(Statement::Instruction(
|
||||||
@ -758,13 +773,14 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||||||
}
|
}
|
||||||
ArgumentSemantics::Ptr => {
|
ArgumentSemantics::Ptr => {
|
||||||
let scalar_t = ast::ScalarType::U64;
|
let scalar_t = ast::ScalarType::U64;
|
||||||
let id_constant_stmt =
|
let id_constant_stmt = self
|
||||||
self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
|
.id_def
|
||||||
let result_id = self.id_def.new_id(Some(typ));
|
.new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
|
||||||
|
let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ)));
|
||||||
self.func.push(Statement::Constant(ConstantDefinition {
|
self.func.push(Statement::Constant(ConstantDefinition {
|
||||||
dst: id_constant_stmt,
|
dst: id_constant_stmt,
|
||||||
typ: scalar_t,
|
typ: scalar_t,
|
||||||
value: offset as i128,
|
value: offset as i64,
|
||||||
}));
|
}));
|
||||||
let int_type = ast::IntType::U64;
|
let int_type = ast::IntType::U64;
|
||||||
self.func.push(Statement::Instruction(
|
self.func.push(Statement::Instruction(
|
||||||
@ -810,9 +826,10 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||||||
desc: ArgumentDescriptor<(spirv::Word, u8)>,
|
desc: ArgumentDescriptor<(spirv::Word, u8)>,
|
||||||
(scalar_type, vec_len): (ast::MovVectorType, u8),
|
(scalar_type, vec_len): (ast::MovVectorType, u8),
|
||||||
) -> spirv::Word {
|
) -> spirv::Word {
|
||||||
let new_id = self
|
let new_id = self.id_def.new_id(Some((
|
||||||
.id_def
|
StateSpace::Reg,
|
||||||
.new_id(Some(ast::Type::Vector(scalar_type.into(), vec_len)));
|
ast::Type::Vector(scalar_type.into(), vec_len),
|
||||||
|
)));
|
||||||
self.func.push(Statement::Composite(CompositeRead {
|
self.func.push(Statement::Composite(CompositeRead {
|
||||||
typ: scalar_type,
|
typ: scalar_type,
|
||||||
dst: new_id,
|
dst: new_id,
|
||||||
@ -821,6 +838,14 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||||||
}));
|
}));
|
||||||
new_id
|
new_id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn mov_operand(
|
||||||
|
&mut self,
|
||||||
|
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
|
||||||
|
typ: ast::Type,
|
||||||
|
) -> spirv::Word {
|
||||||
|
self.operand(desc, typ)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -911,7 +936,7 @@ fn insert_implicit_conversions(
|
|||||||
let mut did_vector_implicit = false;
|
let mut did_vector_implicit = false;
|
||||||
let mut post_conv = None;
|
let mut post_conv = None;
|
||||||
if inst_typ_is_bit {
|
if inst_typ_is_bit {
|
||||||
let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!());
|
let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!()).1;
|
||||||
if let ast::Type::Vector(_, _) = src_type {
|
if let ast::Type::Vector(_, _) = src_type {
|
||||||
arg.src = insert_conversion_src(
|
arg.src = insert_conversion_src(
|
||||||
&mut result,
|
&mut result,
|
||||||
@ -923,7 +948,7 @@ fn insert_implicit_conversions(
|
|||||||
);
|
);
|
||||||
did_vector_implicit = true;
|
did_vector_implicit = true;
|
||||||
}
|
}
|
||||||
let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!());
|
let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!()).1;
|
||||||
if let ast::Type::Vector(_, _) = src_type {
|
if let ast::Type::Vector(_, _) = src_type {
|
||||||
post_conv = Some(get_conversion_dst(
|
post_conv = Some(get_conversion_dst(
|
||||||
id_def,
|
id_def,
|
||||||
@ -1615,9 +1640,15 @@ fn expand_map_variables<'a, 'b>(
|
|||||||
p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))),
|
p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))),
|
||||||
i.map_variable(&mut |id| id_defs.get_id(id)),
|
i.map_variable(&mut |id| id_defs.get_id(id)),
|
||||||
))),
|
))),
|
||||||
ast::Statement::Variable(var) => match var.count {
|
ast::Statement::Variable(var) => {
|
||||||
|
let ss = match var.var.v_type {
|
||||||
|
ast::VariableType::Reg(_) => StateSpace::Reg,
|
||||||
|
ast::VariableType::Local(_) => StateSpace::Local,
|
||||||
|
ast::VariableType::Param(_) => StateSpace::ParamReg,
|
||||||
|
};
|
||||||
|
match var.count {
|
||||||
Some(count) => {
|
Some(count) => {
|
||||||
for new_id in id_defs.add_defs(var.var.name, count, var.var.v_type.into()) {
|
for new_id in id_defs.add_defs(var.var.name, count, ss, var.var.v_type.into()) {
|
||||||
result.push(Statement::Variable(ast::Variable {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
align: var.var.align,
|
align: var.var.align,
|
||||||
v_type: var.var.v_type,
|
v_type: var.var.v_type,
|
||||||
@ -1626,14 +1657,15 @@ fn expand_map_variables<'a, 'b>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
let new_id = id_defs.add_def(var.var.name, Some(var.var.v_type.into()));
|
let new_id = id_defs.add_def(var.var.name, Some((ss, var.var.v_type.into())));
|
||||||
result.push(Statement::Variable(ast::Variable {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
align: var.var.align,
|
align: var.var.align,
|
||||||
v_type: var.var.v_type,
|
v_type: var.var.v_type,
|
||||||
name: new_id,
|
name: new_id,
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1766,7 +1798,7 @@ struct FnStringIdResolver<'input, 'b> {
|
|||||||
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
|
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
|
||||||
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
|
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
|
||||||
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
|
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
|
||||||
type_check: HashMap<u32, ast::Type>,
|
type_check: HashMap<u32, (StateSpace, ast::Type)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
||||||
@ -1809,7 +1841,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
|
fn add_def(&mut self, id: &'a str, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
|
||||||
let numeric_id = *self.current_id;
|
let numeric_id = *self.current_id;
|
||||||
self.variables
|
self.variables
|
||||||
.last_mut()
|
.last_mut()
|
||||||
@ -1827,6 +1859,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||||||
&mut self,
|
&mut self,
|
||||||
base_id: &'a str,
|
base_id: &'a str,
|
||||||
count: u32,
|
count: u32,
|
||||||
|
ss: StateSpace,
|
||||||
typ: ast::Type,
|
typ: ast::Type,
|
||||||
) -> impl Iterator<Item = spirv::Word> {
|
) -> impl Iterator<Item = spirv::Word> {
|
||||||
let numeric_id = *self.current_id;
|
let numeric_id = *self.current_id;
|
||||||
@ -1835,7 +1868,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||||||
.last_mut()
|
.last_mut()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
|
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
|
||||||
self.type_check.insert(numeric_id + i, typ);
|
self.type_check.insert(numeric_id + i, (ss, typ));
|
||||||
}
|
}
|
||||||
*self.current_id += count;
|
*self.current_id += count;
|
||||||
(0..count).into_iter().map(move |i| i + numeric_id)
|
(0..count).into_iter().map(move |i| i + numeric_id)
|
||||||
@ -1844,15 +1877,15 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||||||
|
|
||||||
struct NumericIdResolver<'b> {
|
struct NumericIdResolver<'b> {
|
||||||
current_id: &'b mut spirv::Word,
|
current_id: &'b mut spirv::Word,
|
||||||
type_check: HashMap<u32, ast::Type>,
|
type_check: HashMap<u32, (StateSpace, ast::Type)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'b> NumericIdResolver<'b> {
|
impl<'b> NumericIdResolver<'b> {
|
||||||
fn get_type(&self, id: spirv::Word) -> Option<ast::Type> {
|
fn get_type(&self, id: spirv::Word) -> Option<(StateSpace, ast::Type)> {
|
||||||
self.type_check.get(&id).map(|x| *x)
|
self.type_check.get(&id).map(|x| *x)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_id(&mut self, typ: Option<ast::Type>) -> spirv::Word {
|
fn new_id(&mut self, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
|
||||||
let new_id = *self.current_id;
|
let new_id = *self.current_id;
|
||||||
if let Some(typ) = typ {
|
if let Some(typ) = typ {
|
||||||
self.type_check.insert(new_id, typ);
|
self.type_check.insert(new_id, typ);
|
||||||
@ -1982,16 +2015,48 @@ type UnadornedStatement = Statement<ast::Instruction<NormalizedArgParams>, Norma
|
|||||||
impl ast::ArgParams for NormalizedArgParams {
|
impl ast::ArgParams for NormalizedArgParams {
|
||||||
type ID = spirv::Word;
|
type ID = spirv::Word;
|
||||||
type Operand = ast::Operand<spirv::Word>;
|
type Operand = ast::Operand<spirv::Word>;
|
||||||
|
type MemoryOperand = ast::Operand<spirv::Word>;
|
||||||
type CallOperand = ast::CallOperand<spirv::Word>;
|
type CallOperand = ast::CallOperand<spirv::Word>;
|
||||||
type VecOperand = (spirv::Word, u8);
|
type VecOperand = (spirv::Word, u8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum TypedArgParams {}
|
||||||
|
impl ast::ArgParams for TypedArgParams {
|
||||||
|
type ID = spirv::Word;
|
||||||
|
type Operand = ast::Operand<spirv::Word>;
|
||||||
|
type MemoryOperand = MemoryOperand;
|
||||||
|
type CallOperand = ast::CallOperand<spirv::Word>;
|
||||||
|
type VecOperand = (spirv::Word, u8);
|
||||||
|
}
|
||||||
|
type TypedStatement = Statement<ast::Instruction<TypedArgParams>, TypedArgParams>;
|
||||||
|
|
||||||
impl ArgParamsEx for NormalizedArgParams {
|
impl ArgParamsEx for NormalizedArgParams {
|
||||||
fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl {
|
fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl {
|
||||||
decl.get_fn_decl(*id)
|
decl.get_fn_decl(*id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub enum StateSpace {
|
||||||
|
Reg,
|
||||||
|
Sreg,
|
||||||
|
Const,
|
||||||
|
Global,
|
||||||
|
Local,
|
||||||
|
Shared,
|
||||||
|
Param,
|
||||||
|
ParamReg,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub enum MemoryOperand {
|
||||||
|
Reg(spirv::Word),
|
||||||
|
Address(spirv::Word),
|
||||||
|
RegOffset(spirv::Word, i32),
|
||||||
|
AddressOffset(spirv::Word, i32),
|
||||||
|
Imm(u32),
|
||||||
|
}
|
||||||
|
|
||||||
enum ExpandedArgParams {}
|
enum ExpandedArgParams {}
|
||||||
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
|
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
|
||||||
type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>;
|
type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>;
|
||||||
@ -1999,6 +2064,7 @@ type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStateme
|
|||||||
impl ast::ArgParams for ExpandedArgParams {
|
impl ast::ArgParams for ExpandedArgParams {
|
||||||
type ID = spirv::Word;
|
type ID = spirv::Word;
|
||||||
type Operand = spirv::Word;
|
type Operand = spirv::Word;
|
||||||
|
type MemoryOperand = spirv::Word;
|
||||||
type CallOperand = spirv::Word;
|
type CallOperand = spirv::Word;
|
||||||
type VecOperand = spirv::Word;
|
type VecOperand = spirv::Word;
|
||||||
}
|
}
|
||||||
@ -2012,6 +2078,11 @@ impl ArgParamsEx for ExpandedArgParams {
|
|||||||
trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
|
trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
|
||||||
fn variable(&mut self, desc: ArgumentDescriptor<T::ID>, typ: Option<ast::Type>) -> U::ID;
|
fn variable(&mut self, desc: ArgumentDescriptor<T::ID>, typ: Option<ast::Type>) -> U::ID;
|
||||||
fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>, typ: ast::Type) -> U::Operand;
|
fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>, typ: ast::Type) -> U::Operand;
|
||||||
|
fn mov_operand(
|
||||||
|
&mut self,
|
||||||
|
desc: ArgumentDescriptor<T::MemoryOperand>,
|
||||||
|
typ: ast::Type,
|
||||||
|
) -> U::MemoryOperand;
|
||||||
fn src_call_operand(
|
fn src_call_operand(
|
||||||
&mut self,
|
&mut self,
|
||||||
desc: ArgumentDescriptor<T::CallOperand>,
|
desc: ArgumentDescriptor<T::CallOperand>,
|
||||||
@ -2035,9 +2106,15 @@ where
|
|||||||
) -> spirv::Word {
|
) -> spirv::Word {
|
||||||
self(desc, t)
|
self(desc, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
|
fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
|
||||||
self(desc, Some(t))
|
self(desc, Some(t))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn mov_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
|
||||||
|
self(desc, Some(t))
|
||||||
|
}
|
||||||
|
|
||||||
fn src_call_operand(
|
fn src_call_operand(
|
||||||
&mut self,
|
&mut self,
|
||||||
desc: ArgumentDescriptor<spirv::Word>,
|
desc: ArgumentDescriptor<spirv::Word>,
|
||||||
@ -2045,6 +2122,7 @@ where
|
|||||||
) -> spirv::Word {
|
) -> spirv::Word {
|
||||||
self(desc, Some(t))
|
self(desc, Some(t))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn src_vec_operand(
|
fn src_vec_operand(
|
||||||
&mut self,
|
&mut self,
|
||||||
desc: ArgumentDescriptor<spirv::Word>,
|
desc: ArgumentDescriptor<spirv::Word>,
|
||||||
@ -2095,6 +2173,14 @@ where
|
|||||||
) -> (spirv::Word, u8) {
|
) -> (spirv::Word, u8) {
|
||||||
(self(desc.op.0), desc.op.1)
|
(self(desc.op.0), desc.op.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn mov_operand(
|
||||||
|
&mut self,
|
||||||
|
desc: ArgumentDescriptor<ast::Operand<&str>>,
|
||||||
|
typ: ast::Type,
|
||||||
|
) -> ast::Operand<spirv::Word> {
|
||||||
|
self.operand(desc, typ)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ArgumentDescriptor<Op> {
|
struct ArgumentDescriptor<Op> {
|
||||||
@ -2260,6 +2346,16 @@ where
|
|||||||
desc.op.1,
|
desc.op.1,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn mov_operand(
|
||||||
|
&mut self,
|
||||||
|
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
|
||||||
|
typ: ast::Type,
|
||||||
|
) -> ast::Operand<spirv::Word> {
|
||||||
|
<Self as ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams>>::operand(
|
||||||
|
self, desc, typ,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ast::Type {
|
impl ast::Type {
|
||||||
@ -2365,7 +2461,7 @@ struct CompositeRead {
|
|||||||
struct ConstantDefinition {
|
struct ConstantDefinition {
|
||||||
pub dst: spirv::Word,
|
pub dst: spirv::Word,
|
||||||
pub typ: ast::ScalarType,
|
pub typ: ast::ScalarType,
|
||||||
pub value: i128,
|
pub value: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct BrachCondition {
|
struct BrachCondition {
|
||||||
@ -2534,7 +2630,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
|
|||||||
is_param: bool,
|
is_param: bool,
|
||||||
) -> ast::Arg2St<U> {
|
) -> ast::Arg2St<U> {
|
||||||
ast::Arg2St {
|
ast::Arg2St {
|
||||||
src1: visitor.operand(
|
src1: visitor.mov_operand(
|
||||||
ArgumentDescriptor {
|
ArgumentDescriptor {
|
||||||
op: self.src1,
|
op: self.src1,
|
||||||
is_dst: is_param,
|
is_dst: is_param,
|
||||||
@ -3012,6 +3108,16 @@ impl From<ast::FnArgumentType> for ast::VariableType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> ast::Operand<T> {
|
||||||
|
fn underlying(&self) -> Option<&T> {
|
||||||
|
match self {
|
||||||
|
ast::Operand::Reg(r) => Some(r),
|
||||||
|
ast::Operand::RegOffset(r, _) => Some(r),
|
||||||
|
ast::Operand::Imm(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
||||||
match (instr, operand) {
|
match (instr, operand) {
|
||||||
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
||||||
@ -3053,7 +3159,7 @@ fn insert_with_conversions<T, ToInstruction: FnOnce(T) -> ast::Instruction<Expan
|
|||||||
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src);
|
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src);
|
||||||
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst);
|
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst);
|
||||||
if post_conv.len() > 0 {
|
if post_conv.len() > 0 {
|
||||||
let new_id = id_def.new_id(Some(post_conv[0].from));
|
let new_id = id_def.new_id(Some((StateSpace::Reg, post_conv[0].from)));
|
||||||
post_conv[0].src = new_id;
|
post_conv[0].src = new_id;
|
||||||
post_conv.last_mut().unwrap().dst = *dst(&mut instr);
|
post_conv.last_mut().unwrap().dst = *dst(&mut instr);
|
||||||
*dst(&mut instr) = new_id;
|
*dst(&mut instr) = new_id;
|
||||||
@ -3078,7 +3184,7 @@ fn insert_with_conversions_pre_conv<T>(
|
|||||||
conv.src = *original_src;
|
conv.src = *original_src;
|
||||||
}
|
}
|
||||||
if i == pre_conv_len - 1 {
|
if i == pre_conv_len - 1 {
|
||||||
let new_id = id_def.new_id(Some(conv.to));
|
let new_id = id_def.new_id(Some((StateSpace::Reg, conv.to)));
|
||||||
conv.dst = new_id;
|
conv.dst = new_id;
|
||||||
*original_src = new_id;
|
*original_src = new_id;
|
||||||
}
|
}
|
||||||
@ -3095,7 +3201,7 @@ fn get_implicit_conversions_ld_dst<
|
|||||||
should_convert: ShouldConvert,
|
should_convert: ShouldConvert,
|
||||||
in_reverse: bool,
|
in_reverse: bool,
|
||||||
) -> Option<ImplicitConversion> {
|
) -> Option<ImplicitConversion> {
|
||||||
let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!());
|
let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!()).1;
|
||||||
if let Some(conv) = should_convert(dst_type, instr_type) {
|
if let Some(conv) = should_convert(dst_type, instr_type) {
|
||||||
Some(ImplicitConversion {
|
Some(ImplicitConversion {
|
||||||
src: u32::max_value(),
|
src: u32::max_value(),
|
||||||
@ -3115,7 +3221,7 @@ fn get_implicit_conversions_ld_src(
|
|||||||
state_space: ast::LdStateSpace,
|
state_space: ast::LdStateSpace,
|
||||||
src: spirv::Word,
|
src: spirv::Word,
|
||||||
) -> Vec<ImplicitConversion> {
|
) -> Vec<ImplicitConversion> {
|
||||||
let src_type = id_def.get_type(src).unwrap_or_else(|| todo!());
|
let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1;
|
||||||
match state_space {
|
match state_space {
|
||||||
ast::LdStateSpace::Param => {
|
ast::LdStateSpace::Param => {
|
||||||
if src_type != instr_type {
|
if src_type != instr_type {
|
||||||
@ -3162,7 +3268,7 @@ fn get_implicit_conversions_ld_src(
|
|||||||
kind: ConversionKind::Ptr(state_space),
|
kind: ConversionKind::Ptr(state_space),
|
||||||
});
|
});
|
||||||
if result.len() == 2 {
|
if result.len() == 2 {
|
||||||
let new_id = id_def.new_id(Some(new_src_type));
|
let new_id = id_def.new_id(Some((StateSpace::Reg, new_src_type)));
|
||||||
result[0].dst = new_id;
|
result[0].dst = new_id;
|
||||||
result[1].src = new_id;
|
result[1].src = new_id;
|
||||||
result[1].from = new_src_type;
|
result[1].from = new_src_type;
|
||||||
@ -3221,9 +3327,9 @@ fn insert_implicit_conversions_ld_src_impl<
|
|||||||
src: spirv::Word,
|
src: spirv::Word,
|
||||||
should_convert: ShouldConvert,
|
should_convert: ShouldConvert,
|
||||||
) -> spirv::Word {
|
) -> spirv::Word {
|
||||||
let src_type = id_def.get_type(src);
|
let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1;
|
||||||
if let Some(conv) = should_convert(src_type.unwrap(), instr_type) {
|
if let Some(conv) = should_convert(src_type, instr_type) {
|
||||||
insert_conversion_src(func, id_def, src, src_type.unwrap(), instr_type, conv)
|
insert_conversion_src(func, id_def, src, src_type, instr_type, conv)
|
||||||
} else {
|
} else {
|
||||||
src
|
src
|
||||||
}
|
}
|
||||||
@ -3263,7 +3369,7 @@ fn insert_conversion_src(
|
|||||||
instr_type: ast::Type,
|
instr_type: ast::Type,
|
||||||
conv: ConversionKind,
|
conv: ConversionKind,
|
||||||
) -> spirv::Word {
|
) -> spirv::Word {
|
||||||
let temp_src = id_def.new_id(Some(instr_type));
|
let temp_src = id_def.new_id(Some((StateSpace::Reg, instr_type)));
|
||||||
func.push(Statement::Conversion(ImplicitConversion {
|
func.push(Statement::Conversion(ImplicitConversion {
|
||||||
src: src,
|
src: src,
|
||||||
dst: temp_src,
|
dst: temp_src,
|
||||||
@ -3309,7 +3415,7 @@ fn get_conversion_dst(
|
|||||||
kind: ConversionKind,
|
kind: ConversionKind,
|
||||||
) -> ExpandedStatement {
|
) -> ExpandedStatement {
|
||||||
let original_dst = *dst;
|
let original_dst = *dst;
|
||||||
let temp_dst = id_def.new_id(Some(instr_type));
|
let temp_dst = id_def.new_id(Some((StateSpace::Reg, instr_type)));
|
||||||
*dst = temp_dst;
|
*dst = temp_dst;
|
||||||
Statement::Conversion(ImplicitConversion {
|
Statement::Conversion(ImplicitConversion {
|
||||||
src: temp_dst,
|
src: temp_dst,
|
||||||
@ -3428,8 +3534,8 @@ fn insert_implicit_bitcasts(
|
|||||||
Some(t) => t,
|
Some(t) => t,
|
||||||
None => return desc.op,
|
None => return desc.op,
|
||||||
};
|
};
|
||||||
let id_actual_type = id_def.get_type(desc.op).unwrap();
|
let id_actual_type = id_def.get_type(desc.op).unwrap().1;
|
||||||
if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap()) {
|
if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap().1) {
|
||||||
if desc.is_dst {
|
if desc.is_dst {
|
||||||
dst_coercion = Some(get_conversion_dst(
|
dst_coercion = Some(get_conversion_dst(
|
||||||
id_def,
|
id_def,
|
||||||
|
Reference in New Issue
Block a user