Add more tests

This commit is contained in:
Andrzej Janik
2020-09-20 15:44:52 +02:00
parent 17f2d09cc7
commit dcaea507ba
7 changed files with 175 additions and 139 deletions

View File

@ -13,8 +13,10 @@
%25 = OpTypeFunction %void %ulong %ulong %25 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong
%uchar = OpTypeInt 8 0 %uchar = OpTypeInt 8 0
%_arr_uchar_8 = OpTypeArray %uchar %8 %uint = OpTypeInt 32 0
%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8 %uint_8 = OpConstant %uint 8
%_arr_uchar_uint_8 = OpTypeArray %uchar %uint_8
%_ptr_Function__arr_uchar_uint_8 = OpTypePointer Function %_arr_uchar_uint_8
%_ptr_Generic_ulong = OpTypePointer Generic %ulong %_ptr_Generic_ulong = OpTypePointer Generic %ulong
%1 = OpFunction %void None %25 %1 = OpFunction %void None %25
%8 = OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong
@ -22,7 +24,7 @@
%20 = OpLabel %20 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function %2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup %4 = OpVariable %_ptr_Function__arr_uchar_uint_8 Function
%5 = OpVariable %_ptr_Function_ulong Function %5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function %6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function %7 = OpVariable %_ptr_Function_ulong Function

View File

@ -8,7 +8,7 @@ use spirv_headers::Word;
use spirv_tools_sys::{ use spirv_tools_sys::{
spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env, spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env,
}; };
use std::collections::hash_map::Entry; use std::{collections::hash_map::Entry, cmp};
use std::error; use std::error;
use std::ffi::{c_void, CStr, CString}; use std::ffi::{c_void, CStr, CString};
use std::fmt; use std::fmt;
@ -59,8 +59,9 @@ test_ptx!(local_align, [1u64], [1u64]);
test_ptx!(call, [1u64], [2u64]); test_ptx!(call, [1u64], [2u64]);
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]); test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]); test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
//test_ptx!(ntid, [3u32], [4u32]); test_ptx!(ntid, [3u32], [4u32]);
//test_ptx!(reg_slm, [12u64], [12u64]); test_ptx!(reg_local, [12u64], [12u64]);
test_ptx!(mov_address, [0xDEADu64], [0u64]);
struct DisplayError<T: Debug> { struct DisplayError<T: Debug> {
err: T, err: T,
@ -123,8 +124,8 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
kernel.set_indirect_access( kernel.set_indirect_access(
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE, ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
)?; )?;
let mut inp_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, input.len())?; let mut inp_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(input.len(),1))?;
let mut out_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, output.len())?; let mut out_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(output.len(), 1))?;
let inp_b_ptr_mut: ze::BufferPtrMut<T> = (&mut inp_b).into(); let inp_b_ptr_mut: ze::BufferPtrMut<T> = (&mut inp_b).into();
let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?; let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?;
let ev0 = ze::Event::new(&event_pool, 0)?; let ev0 = ze::Event::new(&event_pool, 0)?;

View File

@ -0,0 +1,15 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry mov_address(
.param .u64 input,
.param .u64 output
)
{
.local .b8 __local_depot0[8];
.reg .u64 temp;
mov.u64 temp, __local_depot0;
ret;
}

View File

@ -2,12 +2,12 @@
.target sm_30 .target sm_30
.address_size 64 .address_size 64
.visible .entry reg_slm( .visible .entry reg_local(
.param .u64 input, .param .u64 input,
.param .u64 output .param .u64 output
) )
{ {
.local .align 8 .b8 slm[8]; .local .align 8 .b8 local_x[8];
.reg .u64 in_addr; .reg .u64 in_addr;
.reg .u64 out_addr; .reg .u64 out_addr;
.reg .b64 temp; .reg .b64 temp;
@ -16,11 +16,9 @@
ld.param.u64 in_addr, [input]; ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output]; ld.param.u64 out_addr, [output];
mov.s64 unused, slm;
ld.global.u64 temp, [in_addr]; ld.global.u64 temp, [in_addr];
st.u64 [slm], temp; st.u64 [local_x], temp;
ld.u64 temp, [slm]; ld.u64 temp, [local_x];
st.global.u64 [out_addr], temp; st.global.u64 [out_addr], temp;
ret; ret;
} }

View File

@ -0,0 +1,46 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%25 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "add"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%28 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_1 = OpConstant %ulong 1
%1 = OpFunction %void None %28
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%23 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
%14 = OpLoad %ulong %21
OpStore %6 %14
%17 = OpLoad %ulong %6
%16 = OpIAdd %ulong %17 %ulong_1
OpStore %7 %16
%18 = OpLoad %ulong %5
%19 = OpLoad %ulong %7
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
OpStore %22 %19
OpReturn
OpFunctionEnd

View File

@ -217,11 +217,13 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, Translate
let opencl_id = emit_opencl_import(&mut builder); let opencl_id = emit_opencl_import(&mut builder);
emit_memory_model(&mut builder); emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder); let mut map = TypeWordMap::new(&mut builder);
emit_builtins(&mut builder, &mut map, &id_defs);
for f in ssa_functions { for f in ssa_functions {
let f_body = match f.body { let f_body = match f.body {
Some(f) => f, Some(f) => f,
None => continue, None => continue,
}; };
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive)?; emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive)?;
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?; emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
builder.end_function()?; builder.end_function()?;
@ -229,6 +231,33 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, Translate
Ok(builder.module()) Ok(builder.module())
} }
fn emit_builtins(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
id_defs: &GlobalStringIdResolver,
) {
for (reg, id) in id_defs.special_registers.iter() {
let result_type = map.get_or_add(
builder,
SpirvType::Pointer(
Box::new(SpirvType::from(reg.get_type())),
spirv::StorageClass::UniformConstant,
),
);
builder.variable(
result_type,
Some(*id),
spirv::StorageClass::UniformConstant,
None,
);
builder.decorate(
*id,
spirv::Decoration::BuiltIn,
&[dr::Operand::BuiltIn(reg.get_builtin())],
);
}
}
fn emit_function_header<'a>( fn emit_function_header<'a>(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
@ -239,7 +268,12 @@ fn emit_function_header<'a>(
let fn_id = match func_directive { let fn_id = match func_directive {
ast::MethodDecl::Kernel(name, _) => { ast::MethodDecl::Kernel(name, _) => {
let fn_id = global.get_id(name)?; let fn_id = global.get_id(name)?;
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]); let interface = global
.special_registers
.iter()
.map(|(_, id)| *id)
.collect::<Vec<_>>();
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, interface);
fn_id fn_id
} }
ast::MethodDecl::Func(_, name, _) => name, ast::MethodDecl::Func(_, name, _) => name,
@ -293,7 +327,7 @@ fn emit_memory_model(builder: &mut dr::Builder) {
fn to_ssa_function<'a>( fn to_ssa_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>, id_defs: &mut GlobalStringIdResolver<'a>,
f: ast::ParsedFunction<'a>, f: ast::ParsedFunction<'a>,
) -> Result<ExpandedFunction<'a>, TranslateError> { ) -> Result<Function<'a>, TranslateError> {
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive); let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive);
to_ssa(str_resolver, fn_resolver, fn_decl, f.body) to_ssa(str_resolver, fn_resolver, fn_decl, f.body)
} }
@ -333,13 +367,14 @@ fn to_ssa<'input, 'b>(
fn_defs: GlobalFnDeclResolver<'input, 'b>, fn_defs: GlobalFnDeclResolver<'input, 'b>,
f_args: ast::MethodDecl<'input, ExpandedArgParams>, f_args: ast::MethodDecl<'input, ExpandedArgParams>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>, f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
) -> Result<ExpandedFunction<'input>, TranslateError> { ) -> Result<Function<'input>, TranslateError> {
let f_body = match f_body { let f_body = match f_body {
Some(vec) => vec, Some(vec) => vec,
None => { None => {
return Ok(ExpandedFunction { return Ok(Function {
func_directive: f_args, func_directive: f_args,
body: None, body: None,
globals: Vec::new(),
}) })
} }
}; };
@ -357,12 +392,21 @@ fn to_ssa<'input, 'b>(
let mut numeric_id_defs = numeric_id_defs.unmut(); let mut numeric_id_defs = numeric_id_defs.unmut();
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
let sorted_statements = normalize_variable_decls(labeled_statements); let sorted_statements = normalize_variable_decls(labeled_statements);
Ok(ExpandedFunction { let (f_body, globals) = extract_globals(sorted_statements);
Ok(Function {
func_directive: f_args, func_directive: f_args,
body: Some(sorted_statements), globals: globals,
body: Some(f_body),
}) })
} }
fn extract_globals(
sorted_statements: Vec<ExpandedStatement>,
) -> (Vec<ExpandedStatement>, Vec<ExpandedStatement>) {
// This fn will be used for SLM
(sorted_statements, Vec::new())
}
fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> { fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> {
func[1..].sort_by_key(|s| match s { func[1..].sort_by_key(|s| match s {
Statement::Variable(_) => 0, Statement::Variable(_) => 0,
@ -477,7 +521,9 @@ fn add_types_to_statements(
}, },
_ => dets, _ => dets,
}; };
Ok(Statement::Instruction(ast::Instruction::MovVector(new_dets, args))) Ok(Statement::Instruction(ast::Instruction::MovVector(
new_dets, args,
)))
} }
s => Ok(s), s => Ok(s),
} }
@ -724,7 +770,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
(_, ArgumentSemantics::Address) => return Ok(desc.op), (_, ArgumentSemantics::Address) => return Ok(desc.op),
(t, ArgumentSemantics::RegisterPointer) (t, ArgumentSemantics::RegisterPointer)
| (t, ArgumentSemantics::Default) | (t, ArgumentSemantics::Default)
| (t, ArgumentSemantics::Ptr) => t, | (t, ArgumentSemantics::PhysicalPointer) => t,
}; };
let generated_id = id_def.new_id(id_type); let generated_id = id_def.new_id(id_type);
if !desc.is_dst { if !desc.is_dst {
@ -873,7 +919,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
)); ));
Ok(result_id) Ok(result_id)
} }
ArgumentSemantics::Ptr => { ArgumentSemantics::PhysicalPointer => {
let scalar_t = ast::ScalarType::U64; let scalar_t = ast::ScalarType::U64;
let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t)); let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
let result_id = self.id_def.new_id(typ); let result_id = self.id_def.new_id(typ);
@ -1137,7 +1183,7 @@ fn emit_function_body_ops(
builder.begin_block(Some(*id))?; builder.begin_block(Some(*id))?;
} }
_ => { _ => {
if builder.block.is_none() { if builder.block.is_none() && builder.function.is_some() {
builder.begin_block(None)?; builder.begin_block(None)?;
} }
} }
@ -1166,10 +1212,9 @@ fn emit_function_body_ops(
name, name,
}) => { }) => {
let st_class = match v_type { let st_class = match v_type {
ast::VariableType::Reg(_) | ast::VariableType::Param(_) => { ast::VariableType::Reg(_)
spirv::StorageClass::Function | ast::VariableType::Param(_)
} | ast::VariableType::Local(_) => spirv::StorageClass::Function,
ast::VariableType::Local(_) => spirv::StorageClass::Workgroup,
}; };
let type_id = map.get_or_add( let type_id = map.get_or_add(
builder, builder,
@ -1234,7 +1279,7 @@ fn emit_function_body_ops(
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
builder.load(result_type, Some(arg.dst), arg.src, None, [])?; builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
} }
ast::LdStateSpace::Param => { ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
let result_type = map.get_or_add(builder, SpirvType::from(data.typ)); let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
builder.copy_object(result_type, Some(arg.dst), arg.src)?; builder.copy_object(result_type, Some(arg.dst), arg.src)?;
} }
@ -1242,18 +1287,20 @@ fn emit_function_body_ops(
} }
} }
ast::Instruction::St(data, arg) => { ast::Instruction::St(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak if data.qualifier != ast::LdStQualifier::Weak {
|| (data.state_space != ast::StStateSpace::Generic
&& data.state_space != ast::StStateSpace::Param
&& data.state_space != ast::StStateSpace::Global)
{
todo!() todo!()
} }
if data.state_space == ast::StStateSpace::Param { if data.state_space == ast::StStateSpace::Param
|| data.state_space == ast::StStateSpace::Local
{
let result_type = map.get_or_add(builder, SpirvType::from(data.typ)); let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
builder.copy_object(result_type, Some(arg.src1), arg.src2)?; builder.copy_object(result_type, Some(arg.src1), arg.src2)?;
} else { } else if data.state_space == ast::StStateSpace::Generic
|| data.state_space == ast::StStateSpace::Global
{
builder.store(arg.src1, arg.src2, None, &[])?; builder.store(arg.src1, arg.src2, None, &[])?;
} else {
todo!()
} }
} }
// SPIR-V does not support ret as guaranteed-converged // SPIR-V does not support ret as guaranteed-converged
@ -1643,7 +1690,7 @@ fn emit_implicit_conversion(
let from_parts = cv.from.to_parts(); let from_parts = cv.from.to_parts();
let to_parts = cv.to.to_parts(); let to_parts = cv.to.to_parts();
match (from_parts.kind, to_parts.kind, cv.kind) { match (from_parts.kind, to_parts.kind, cv.kind) {
(_, _, ConversionKind::Ptr(space)) => { (_, _, ConversionKind::BitToPtr(space)) => {
let dst_type = map.get_or_add( let dst_type = map.get_or_add(
builder, builder,
SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()), SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()),
@ -1699,14 +1746,11 @@ fn emit_implicit_conversion(
} }
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(), (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(),
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
| (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default) => { | (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default)
| (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
let into_type = map.get_or_add(builder, SpirvType::from(cv.to)); let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
builder.bitcast(into_type, Some(cv.dst), cv.src)?; builder.bitcast(into_type, Some(cv.dst), cv.src)?;
} }
(TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
builder.convert_ptr_to_u(into_type, Some(cv.dst), cv.src)?;
}
_ => unreachable!(), _ => unreachable!(),
} }
Ok(()) Ok(())
@ -2181,7 +2225,7 @@ impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
decl.get_fn_decl_str(id) decl.get_fn_decl_str(id)
} }
fn get_src_semantics(m: &Self::MovOperand) -> ArgumentSemantics { fn get_src_semantics(_: &Self::MovOperand) -> ArgumentSemantics {
ArgumentSemantics::Default ArgumentSemantics::Default
} }
} }
@ -2230,7 +2274,12 @@ pub enum StateSpace {
enum ExpandedArgParams {} enum ExpandedArgParams {}
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>; type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>;
struct Function<'input> {
pub func_directive: ast::MethodDecl<'input, ExpandedArgParams>,
pub globals: Vec<ExpandedStatement>,
pub body: Option<Vec<ExpandedStatement>>,
}
impl ast::ArgParams for ExpandedArgParams { impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word; type ID = spirv::Word;
@ -2248,7 +2297,7 @@ impl ArgParamsEx for ExpandedArgParams {
decl.get_fn_decl(*id) decl.get_fn_decl(*id)
} }
fn get_src_semantics(m: &Self::MovOperand) -> ArgumentSemantics { fn get_src_semantics(_: &spirv::Word) -> ArgumentSemantics {
ArgumentSemantics::Default ArgumentSemantics::Default
} }
} }
@ -2398,12 +2447,12 @@ struct ArgumentDescriptor<Op> {
sema: ArgumentSemantics, sema: ArgumentSemantics,
} }
#[derive(Copy, Clone, PartialEq, Eq)] #[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum ArgumentSemantics { pub enum ArgumentSemantics {
// normal register access // normal register access
Default, Default,
// st/ld global // st/ld global
Ptr, PhysicalPointer,
// st/ld .param, .local // st/ld .param, .local
RegisterPointer, RegisterPointer,
// mov of .local/.global variables // mov of .local/.global variables
@ -2720,7 +2769,8 @@ enum ConversionKind {
Default, Default,
// zero-extend/chop/bitcast depending on types // zero-extend/chop/bitcast depending on types
SignExtend, SignExtend,
Ptr(ast::LdStateSpace), BitToPtr(ast::LdStateSpace),
PtrToBit,
} }
impl<T> ast::PredAt<T> { impl<T> ast::PredAt<T> {
@ -2831,7 +2881,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
sema: if is_param { sema: if is_param {
ArgumentSemantics::RegisterPointer ArgumentSemantics::RegisterPointer
} else { } else {
ArgumentSemantics::Ptr ArgumentSemantics::PhysicalPointer
}, },
}, },
t, t,
@ -2919,7 +2969,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
sema: if is_param { sema: if is_param {
ArgumentSemantics::RegisterPointer ArgumentSemantics::RegisterPointer
} else { } else {
ArgumentSemantics::Ptr ArgumentSemantics::PhysicalPointer
}, },
}, },
t, t,
@ -3518,7 +3568,7 @@ fn get_implicit_conversions_ld_src(
) -> Result<Vec<ImplicitConversion>, TranslateError> { ) -> Result<Vec<ImplicitConversion>, TranslateError> {
let src_type = id_def.get_typed(src)?; let src_type = id_def.get_typed(src)?;
match state_space { match state_space {
ast::LdStateSpace::Param => { ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
if src_type != instr_type { if src_type != instr_type {
Ok(vec![ Ok(vec![
ImplicitConversion { ImplicitConversion {
@ -3560,7 +3610,7 @@ fn get_implicit_conversions_ld_src(
dst: u32::max_value(), dst: u32::max_value(),
from: src_type, from: src_type,
to: instr_type, to: instr_type,
kind: ConversionKind::Ptr(state_space), kind: ConversionKind::BitToPtr(state_space),
}); });
if result.len() == 2 { if result.len() == 2 {
let new_id = id_def.new_id(new_src_type); let new_id = id_def.new_id(new_src_type);
@ -3570,92 +3620,9 @@ fn get_implicit_conversions_ld_src(
} }
Ok(result) Ok(result)
} }
_ => todo!(), _ => Err(TranslateError::Todo),
} }
} }
fn insert_implicit_conversions_ld_src(
func: &mut Vec<ExpandedStatement>,
instr_type: ast::Type,
id_def: &mut MutableNumericIdResolver,
state_space: ast::LdStateSpace,
src: spirv::Word,
) -> Result<spirv::Word, TranslateError> {
match state_space {
ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl(
func,
id_def,
instr_type,
src,
should_convert_ld_param_src,
),
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
mem::size_of::<usize>() as u8,
ScalarKind::Bit,
));
let new_src = insert_implicit_conversions_ld_src_impl(
func,
id_def,
new_src_type,
src,
should_convert_ld_generic_src_to_bitcast,
)?;
Ok(insert_conversion_src(
func,
id_def,
new_src,
new_src_type,
instr_type,
ConversionKind::Ptr(state_space),
))
}
_ => todo!(),
}
}
fn insert_implicit_conversions_ld_src_impl<
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
>(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
instr_type: ast::Type,
src: spirv::Word,
should_convert: ShouldConvert,
) -> Result<spirv::Word, TranslateError> {
let src_type = id_def.get_typed(src)?;
if let Some(conv) = should_convert(src_type, instr_type) {
Ok(insert_conversion_src(
func, id_def, src, src_type, instr_type, conv,
))
} else {
Ok(src)
}
}
fn should_convert_ld_param_src(
src_type: ast::Type,
instr_type: ast::Type,
) -> Option<ConversionKind> {
if src_type != instr_type {
return Some(ConversionKind::Default);
}
None
}
// HACK ALERT
// IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
// additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
fn should_convert_ld_generic_src_to_bitcast(
src_type: ast::Type,
_instr_type: ast::Type,
) -> Option<ConversionKind> {
if let ast::Type::Scalar(src_type) = src_type {
if src_type.kind() == ScalarKind::Signed {
return Some(ConversionKind::Default);
}
}
None
}
#[must_use] #[must_use]
fn insert_conversion_src( fn insert_conversion_src(
@ -3832,14 +3799,21 @@ fn insert_implicit_bitcasts(
None => return Ok(desc.op), None => return Ok(desc.op),
}; };
let id_actual_type = id_def.get_typed(desc.op)?; let id_actual_type = id_def.get_typed(desc.op)?;
if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) { let conv_kind = if desc.sema == ArgumentSemantics::Address {
Some(ConversionKind::PtrToBit)
} else if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) {
Some(ConversionKind::Default)
} else {
None
};
if let Some(conv_kind) = conv_kind {
if desc.is_dst { if desc.is_dst {
dst_coercion = Some(get_conversion_dst( dst_coercion = Some(get_conversion_dst(
id_def, id_def,
&mut desc.op, &mut desc.op,
id_type_from_instr, id_type_from_instr,
id_actual_type, id_actual_type,
ConversionKind::Default, conv_kind,
)); ));
Ok(desc.op) Ok(desc.op)
} else { } else {
@ -3849,7 +3823,7 @@ fn insert_implicit_bitcasts(
desc.op, desc.op,
id_actual_type, id_actual_type,
id_type_from_instr, id_type_from_instr,
ConversionKind::Default, conv_kind,
)) ))
} }
} else { } else {