Change codegen for mul.wide

This commit is contained in:
Andrzej Janik
2021-08-01 19:20:08 +02:00
parent 8f68287b18
commit 4a71fefb8a
2 changed files with 15 additions and 39 deletions

View File

@ -10,6 +10,8 @@
%30 = OpExtInstImport "OpenCL.std" %30 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "mul_wide" OpEntryPoint Kernel %1 "mul_wide"
OpExecutionMode %1 ContractionOff
OpDecorate %17 NoSignedWrap
%void = OpTypeVoid %void = OpTypeVoid
%ulong = OpTypeInt 64 0 %ulong = OpTypeInt 64 0
%33 = OpTypeFunction %void %ulong %ulong %33 = OpTypeFunction %void %ulong %ulong
@ -20,8 +22,6 @@
%ulong_4 = OpConstant %ulong 4 %ulong_4 = OpConstant %ulong 4
%uchar = OpTypeInt 8 0 %uchar = OpTypeInt 8 0
%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
%_struct_42 = OpTypeStruct %uint %uint
%v2uint = OpTypeVector %uint 2
%_ptr_Generic_ulong = OpTypePointer Generic %ulong %_ptr_Generic_ulong = OpTypePointer Generic %ulong
%1 = OpFunction %void None %33 %1 = OpFunction %void None %33
%9 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong
@ -53,11 +53,9 @@
OpStore %7 %15 OpStore %7 %15
%18 = OpLoad %uint %6 %18 = OpLoad %uint %6
%19 = OpLoad %uint %7 %19 = OpLoad %uint %7
%43 = OpSMulExtended %_struct_42 %18 %19 %42 = OpSConvert %ulong %18
%44 = OpCompositeExtract %uint %43 0 %43 = OpSConvert %ulong %19
%45 = OpCompositeExtract %uint %43 1 %17 = OpIMul %ulong %42 %43
%47 = OpCompositeConstruct %v2uint %44 %45
%17 = OpBitcast %ulong %47
OpStore %8 %17 OpStore %8 %17
%20 = OpLoad %ulong %5 %20 = OpLoad %ulong %5
%21 = OpLoad %ulong %8 %21 = OpLoad %ulong %8

View File

@ -3798,8 +3798,8 @@ fn emit_mul_sint(
desc: &ast::MulSInt, desc: &ast::MulSInt,
arg: &ast::Arg3<ExpandedArgParams>, arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> { ) -> Result<(), dr::Error> {
let instruction_type = ast::ScalarType::from(desc.typ); let instruction_type = desc.typ;
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); let inst_type = map.get_or_add(builder, SpirvType::from(desc.typ));
match desc.control { match desc.control {
ast::MulIntControl::Low => { ast::MulIntControl::Low => {
builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?; builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
@ -3816,25 +3816,14 @@ fn emit_mul_sint(
)?; )?;
} }
ast::MulIntControl::Wide => { ast::MulIntControl::Wide => {
let mul_ext_type = SpirvType::Struct(vec![
SpirvScalarKey::from(instruction_type),
SpirvScalarKey::from(instruction_type),
]);
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
let instr_width = instruction_type.size_of(); let instr_width = instruction_type.size_of();
let instr_kind = instruction_type.kind(); let instr_kind = instruction_type.kind();
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
let dst_type_id = map.get_or_add_scalar(builder, dst_type); let dst_type_id = map.get_or_add_scalar(builder, dst_type);
struct2_bitcast_to_wide( let src1 = builder.s_convert(dst_type_id, None, arg.src1)?;
builder, let src2 = builder.s_convert(dst_type_id, None, arg.src2)?;
map, builder.i_mul(dst_type_id, Some(arg.dst), src1, src2)?;
SpirvScalarKey::from(instruction_type), builder.decorate(arg.dst, spirv::Decoration::NoSignedWrap, []);
inst_type,
arg.dst,
dst_type_id,
mul,
)?;
} }
} }
Ok(()) Ok(())
@ -3865,25 +3854,14 @@ fn emit_mul_uint(
)?; )?;
} }
ast::MulIntControl::Wide => { ast::MulIntControl::Wide => {
let mul_ext_type = SpirvType::Struct(vec![
SpirvScalarKey::from(instruction_type),
SpirvScalarKey::from(instruction_type),
]);
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
let instr_width = instruction_type.size_of(); let instr_width = instruction_type.size_of();
let instr_kind = instruction_type.kind(); let instr_kind = instruction_type.kind();
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
let dst_type_id = map.get_or_add_scalar(builder, dst_type); let dst_type_id = map.get_or_add_scalar(builder, dst_type);
struct2_bitcast_to_wide( let src1 = builder.u_convert(dst_type_id, None, arg.src1)?;
builder, let src2 = builder.u_convert(dst_type_id, None, arg.src2)?;
map, builder.i_mul(dst_type_id, Some(arg.dst), src1, src2)?;
SpirvScalarKey::from(instruction_type), builder.decorate(arg.dst, spirv::Decoration::NoUnsignedWrap, []);
inst_type,
arg.dst,
dst_type_id,
mul,
)?;
} }
} }
Ok(()) Ok(())