diff --git a/ptx/src/test/spirv_run/mul_wide.spvtxt b/ptx/src/test/spirv_run/mul_wide.spvtxt index e96a964..b8ffac0 100644 --- a/ptx/src/test/spirv_run/mul_wide.spvtxt +++ b/ptx/src/test/spirv_run/mul_wide.spvtxt @@ -10,6 +10,8 @@ %30 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "mul_wide" + OpExecutionMode %1 ContractionOff + OpDecorate %17 NoSignedWrap %void = OpTypeVoid %ulong = OpTypeInt 64 0 %33 = OpTypeFunction %void %ulong %ulong @@ -20,8 +22,6 @@ %ulong_4 = OpConstant %ulong 4 %uchar = OpTypeInt 8 0 %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %_struct_42 = OpTypeStruct %uint %uint - %v2uint = OpTypeVector %uint 2 %_ptr_Generic_ulong = OpTypePointer Generic %ulong %1 = OpFunction %void None %33 %9 = OpFunctionParameter %ulong @@ -53,11 +53,9 @@ OpStore %7 %15 %18 = OpLoad %uint %6 %19 = OpLoad %uint %7 - %43 = OpSMulExtended %_struct_42 %18 %19 - %44 = OpCompositeExtract %uint %43 0 - %45 = OpCompositeExtract %uint %43 1 - %47 = OpCompositeConstruct %v2uint %44 %45 - %17 = OpBitcast %ulong %47 + %42 = OpSConvert %ulong %18 + %43 = OpSConvert %ulong %19 + %17 = OpIMul %ulong %42 %43 OpStore %8 %17 %20 = OpLoad %ulong %5 %21 = OpLoad %ulong %8 diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 91e4237..5fea075 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -3798,8 +3798,8 @@ fn emit_mul_sint( desc: &ast::MulSInt, arg: &ast::Arg3, ) -> Result<(), dr::Error> { - let instruction_type = ast::ScalarType::from(desc.typ); - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); + let instruction_type = desc.typ; + let inst_type = map.get_or_add(builder, SpirvType::from(desc.typ)); match desc.control { ast::MulIntControl::Low => { builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?; @@ -3816,25 +3816,14 @@ fn emit_mul_sint( )?; } ast::MulIntControl::Wide => { - let mul_ext_type = SpirvType::Struct(vec![ - SpirvScalarKey::from(instruction_type), - SpirvScalarKey::from(instruction_type), - ]); - let mul_ext_type_id = map.get_or_add(builder, mul_ext_type); - let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?; let instr_width = instruction_type.size_of(); let instr_kind = instruction_type.kind(); let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); let dst_type_id = map.get_or_add_scalar(builder, dst_type); - struct2_bitcast_to_wide( - builder, - map, - SpirvScalarKey::from(instruction_type), - inst_type, - arg.dst, - dst_type_id, - mul, - )?; + let src1 = builder.s_convert(dst_type_id, None, arg.src1)?; + let src2 = builder.s_convert(dst_type_id, None, arg.src2)?; + builder.i_mul(dst_type_id, Some(arg.dst), src1, src2)?; + builder.decorate(arg.dst, spirv::Decoration::NoSignedWrap, []); } } Ok(()) @@ -3865,25 +3854,14 @@ fn emit_mul_uint( )?; } ast::MulIntControl::Wide => { - let mul_ext_type = SpirvType::Struct(vec![ - SpirvScalarKey::from(instruction_type), - SpirvScalarKey::from(instruction_type), - ]); - let mul_ext_type_id = map.get_or_add(builder, mul_ext_type); - let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?; let instr_width = instruction_type.size_of(); let instr_kind = instruction_type.kind(); let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); let dst_type_id = map.get_or_add_scalar(builder, dst_type); - struct2_bitcast_to_wide( - builder, - map, - SpirvScalarKey::from(instruction_type), - inst_type, - arg.dst, - dst_type_id, - mul, - )?; + let src1 = builder.u_convert(dst_type_id, None, arg.src1)?; + let src2 = builder.u_convert(dst_type_id, None, arg.src2)?; + builder.i_mul(dst_type_id, Some(arg.dst), src1, src2)?; + builder.decorate(arg.dst, spirv::Decoration::NoUnsignedWrap, []); } } Ok(())