mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-22 03:36:27 +03:00
Add support for fma instruction
This commit is contained in:
@ -131,6 +131,7 @@ match {
|
|||||||
"cvt",
|
"cvt",
|
||||||
"cvta",
|
"cvta",
|
||||||
"debug",
|
"debug",
|
||||||
|
"fma",
|
||||||
"ld",
|
"ld",
|
||||||
"mad",
|
"mad",
|
||||||
"map_f64_to_f32",
|
"map_f64_to_f32",
|
||||||
@ -166,6 +167,7 @@ ExtendedID : &'input str = {
|
|||||||
"cvt",
|
"cvt",
|
||||||
"cvta",
|
"cvta",
|
||||||
"debug",
|
"debug",
|
||||||
|
"fma",
|
||||||
"ld",
|
"ld",
|
||||||
"mad",
|
"mad",
|
||||||
"map_f64_to_f32",
|
"map_f64_to_f32",
|
||||||
@ -1185,7 +1187,8 @@ InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad
|
||||||
InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||||
"mad" <d:MulDetails> <a:Arg4> => ast::Instruction::Mad(d, a),
|
"mad" <d:MulDetails> <a:Arg4> => ast::Instruction::Mad(d, a),
|
||||||
"mad" ".hi" ".sat" ".s32" => todo!()
|
"mad" ".hi" ".sat" ".s32" => todo!(),
|
||||||
|
"fma" <f:ArithFloatMustRound> <a:Arg4> => ast::Instruction::Mad(ast::MulDetails::Float(f), a),
|
||||||
};
|
};
|
||||||
|
|
||||||
SignedIntType: ast::ScalarType = {
|
SignedIntType: ast::ScalarType = {
|
||||||
@ -1333,6 +1336,33 @@ ArithFloat: ast::ArithFloat = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ArithFloatMustRound: ast::ArithFloat = {
|
||||||
|
<rn:RoundingModeFloat> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
|
||||||
|
typ: ast::FloatType::F32,
|
||||||
|
rounding: Some(rn),
|
||||||
|
flush_to_zero: Some(ftz.is_some()),
|
||||||
|
saturate: sat.is_some(),
|
||||||
|
},
|
||||||
|
<rn:RoundingModeFloat> ".f64" => ast::ArithFloat {
|
||||||
|
typ: ast::FloatType::F64,
|
||||||
|
rounding: Some(rn),
|
||||||
|
flush_to_zero: None,
|
||||||
|
saturate: false,
|
||||||
|
},
|
||||||
|
".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
|
||||||
|
typ: ast::FloatType::F16,
|
||||||
|
rounding: Some(ast::RoundingMode::NearestEven),
|
||||||
|
flush_to_zero: Some(ftz.is_some()),
|
||||||
|
saturate: sat.is_some(),
|
||||||
|
},
|
||||||
|
".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
|
||||||
|
typ: ast::FloatType::F16x2,
|
||||||
|
rounding: Some(ast::RoundingMode::NearestEven),
|
||||||
|
flush_to_zero: Some(ftz.is_some()),
|
||||||
|
saturate: sat.is_some(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
Operand: ast::Operand<&'input str> = {
|
Operand: ast::Operand<&'input str> = {
|
||||||
<r:ExtendedID> => ast::Operand::Reg(r),
|
<r:ExtendedID> => ast::Operand::Reg(r),
|
||||||
<r:ExtendedID> "+" <offset:S32Num> => ast::Operand::RegOffset(r, offset),
|
<r:ExtendedID> "+" <offset:S32Num> => ast::Operand::RegOffset(r, offset),
|
||||||
|
25
ptx/src/test/spirv_run/fma.ptx
Normal file
25
ptx/src/test/spirv_run/fma.ptx
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry fma(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .f32 temp1;
|
||||||
|
.reg .f32 temp2;
|
||||||
|
.reg .f32 temp3;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.f32 temp1, [in_addr];
|
||||||
|
ld.f32 temp2, [in_addr+4];
|
||||||
|
ld.f32 temp3, [in_addr+8];
|
||||||
|
fma.rn.f32 temp1, temp1, temp2, temp3;
|
||||||
|
st.f32 [out_addr], temp1;
|
||||||
|
ret;
|
||||||
|
}
|
72
ptx/src/test/spirv_run/fma.spvtxt
Normal file
72
ptx/src/test/spirv_run/fma.spvtxt
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
; SPIR-V
|
||||||
|
; Version: 1.3
|
||||||
|
; Generator: rspirv
|
||||||
|
; Bound: 45
|
||||||
|
OpCapability GenericPointer
|
||||||
|
OpCapability Linkage
|
||||||
|
OpCapability Addresses
|
||||||
|
OpCapability Kernel
|
||||||
|
OpCapability Int8
|
||||||
|
OpCapability Int16
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability Float64
|
||||||
|
OpCapability FunctionFloatControlINTEL
|
||||||
|
OpExtension "SPV_INTEL_float_controls2"
|
||||||
|
%37 = OpExtInstImport "OpenCL.std"
|
||||||
|
OpMemoryModel Physical64 OpenCL
|
||||||
|
OpEntryPoint Kernel %1 "fma"
|
||||||
|
OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
|
||||||
|
%38 = OpTypeVoid
|
||||||
|
%39 = OpTypeInt 64 0
|
||||||
|
%40 = OpTypeFunction %38 %39 %39
|
||||||
|
%41 = OpTypePointer Function %39
|
||||||
|
%42 = OpTypeFloat 32
|
||||||
|
%43 = OpTypePointer Function %42
|
||||||
|
%44 = OpTypePointer Generic %42
|
||||||
|
%27 = OpConstant %39 4
|
||||||
|
%29 = OpConstant %39 8
|
||||||
|
%1 = OpFunction %38 None %40
|
||||||
|
%9 = OpFunctionParameter %39
|
||||||
|
%10 = OpFunctionParameter %39
|
||||||
|
%35 = OpLabel
|
||||||
|
%2 = OpVariable %41 Function
|
||||||
|
%3 = OpVariable %41 Function
|
||||||
|
%4 = OpVariable %41 Function
|
||||||
|
%5 = OpVariable %41 Function
|
||||||
|
%6 = OpVariable %43 Function
|
||||||
|
%7 = OpVariable %43 Function
|
||||||
|
%8 = OpVariable %43 Function
|
||||||
|
OpStore %2 %9
|
||||||
|
OpStore %3 %10
|
||||||
|
%12 = OpLoad %39 %2
|
||||||
|
%11 = OpCopyObject %39 %12
|
||||||
|
OpStore %4 %11
|
||||||
|
%14 = OpLoad %39 %3
|
||||||
|
%13 = OpCopyObject %39 %14
|
||||||
|
OpStore %5 %13
|
||||||
|
%16 = OpLoad %39 %4
|
||||||
|
%31 = OpConvertUToPtr %44 %16
|
||||||
|
%15 = OpLoad %42 %31
|
||||||
|
OpStore %6 %15
|
||||||
|
%18 = OpLoad %39 %4
|
||||||
|
%28 = OpIAdd %39 %18 %27
|
||||||
|
%32 = OpConvertUToPtr %44 %28
|
||||||
|
%17 = OpLoad %42 %32
|
||||||
|
OpStore %7 %17
|
||||||
|
%20 = OpLoad %39 %4
|
||||||
|
%30 = OpIAdd %39 %20 %29
|
||||||
|
%33 = OpConvertUToPtr %44 %30
|
||||||
|
%19 = OpLoad %42 %33
|
||||||
|
OpStore %8 %19
|
||||||
|
%22 = OpLoad %42 %6
|
||||||
|
%23 = OpLoad %42 %7
|
||||||
|
%24 = OpLoad %42 %8
|
||||||
|
%21 = OpExtInst %42 %37 mad %22 %23 %24
|
||||||
|
OpStore %6 %21
|
||||||
|
%25 = OpLoad %39 %5
|
||||||
|
%26 = OpLoad %42 %6
|
||||||
|
%34 = OpConvertUToPtr %44 %25
|
||||||
|
OpStore %34 %26
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
@ -91,6 +91,7 @@ test_ptx!(constant_f32, [10f32], [5f32]);
|
|||||||
test_ptx!(constant_negative, [-101i32], [101i32]);
|
test_ptx!(constant_negative, [-101i32], [101i32]);
|
||||||
test_ptx!(and, [6u32, 3u32], [2u32]);
|
test_ptx!(and, [6u32, 3u32], [2u32]);
|
||||||
test_ptx!(selp, [100u16, 200u16], [200u16]);
|
test_ptx!(selp, [100u16, 200u16], [200u16]);
|
||||||
|
test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
|
||||||
|
|
||||||
struct DisplayError<T: Debug> {
|
struct DisplayError<T: Debug> {
|
||||||
err: T,
|
err: T,
|
||||||
|
@ -2343,7 +2343,9 @@ fn emit_function_body_ops(
|
|||||||
ast::MulDetails::Unsigned(ref desc) => {
|
ast::MulDetails::Unsigned(ref desc) => {
|
||||||
emit_mad_uint(builder, map, opencl, desc, arg)?
|
emit_mad_uint(builder, map, opencl, desc, arg)?
|
||||||
}
|
}
|
||||||
ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?,
|
ast::MulDetails::Float(desc) => {
|
||||||
|
emit_mad_float(builder, map, opencl, desc, arg)?
|
||||||
|
}
|
||||||
},
|
},
|
||||||
ast::Instruction::Or(t, a) => {
|
ast::Instruction::Or(t, a) => {
|
||||||
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
|
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
|
||||||
@ -2560,10 +2562,19 @@ fn emit_mad_sint(
|
|||||||
fn emit_mad_float(
|
fn emit_mad_float(
|
||||||
builder: &mut dr::Builder,
|
builder: &mut dr::Builder,
|
||||||
map: &mut TypeWordMap,
|
map: &mut TypeWordMap,
|
||||||
|
opencl: spirv::Word,
|
||||||
desc: &ast::ArithFloat,
|
desc: &ast::ArithFloat,
|
||||||
arg: &ast::Arg4<ExpandedArgParams>,
|
arg: &ast::Arg4<ExpandedArgParams>,
|
||||||
) -> Result<(), dr::Error> {
|
) -> Result<(), dr::Error> {
|
||||||
todo!()
|
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
|
||||||
|
builder.ext_inst(
|
||||||
|
inst_type,
|
||||||
|
Some(arg.dst),
|
||||||
|
opencl,
|
||||||
|
spirv::CLOp::mad as spirv::Word,
|
||||||
|
[arg.src1, arg.src2, arg.src3],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn emit_add_float(
|
fn emit_add_float(
|
||||||
|
Reference in New Issue
Block a user