mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-22 03:36:27 +03:00
Implement rcp instruction
This commit is contained in:
@ -510,6 +510,7 @@ pub enum Instruction<P: ArgParams> {
|
|||||||
Sub(ArithDetails, Arg3<P>),
|
Sub(ArithDetails, Arg3<P>),
|
||||||
Min(MinMaxDetails, Arg3<P>),
|
Min(MinMaxDetails, Arg3<P>),
|
||||||
Max(MinMaxDetails, Arg3<P>),
|
Max(MinMaxDetails, Arg3<P>),
|
||||||
|
Rcp(RcpDetails, Arg2<P>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
@ -520,6 +521,12 @@ pub struct AbsDetails {
|
|||||||
pub flush_to_zero: bool,
|
pub flush_to_zero: bool,
|
||||||
pub typ: ScalarType,
|
pub typ: ScalarType,
|
||||||
}
|
}
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct RcpDetails {
|
||||||
|
pub rounding: Option<RoundingMode>,
|
||||||
|
pub flush_to_zero: bool,
|
||||||
|
pub is_f64: bool,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct CallInst<P: ArgParams> {
|
pub struct CallInst<P: ArgParams> {
|
||||||
pub uniform: bool,
|
pub uniform: bool,
|
||||||
|
@ -35,6 +35,7 @@ match {
|
|||||||
".address_size",
|
".address_size",
|
||||||
".align",
|
".align",
|
||||||
".and",
|
".and",
|
||||||
|
".approx",
|
||||||
".b16",
|
".b16",
|
||||||
".b32",
|
".b32",
|
||||||
".b64",
|
".b64",
|
||||||
@ -134,6 +135,7 @@ match {
|
|||||||
"mul",
|
"mul",
|
||||||
"not",
|
"not",
|
||||||
"or",
|
"or",
|
||||||
|
"rcp",
|
||||||
"ret",
|
"ret",
|
||||||
"setp",
|
"setp",
|
||||||
"shl",
|
"shl",
|
||||||
@ -166,6 +168,7 @@ ExtendedID : &'input str = {
|
|||||||
"mul",
|
"mul",
|
||||||
"not",
|
"not",
|
||||||
"or",
|
"or",
|
||||||
|
"rcp",
|
||||||
"ret",
|
"ret",
|
||||||
"setp",
|
"setp",
|
||||||
"shl",
|
"shl",
|
||||||
@ -542,6 +545,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
InstSub,
|
InstSub,
|
||||||
InstMin,
|
InstMin,
|
||||||
InstMax,
|
InstMax,
|
||||||
|
InstRcp
|
||||||
};
|
};
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
||||||
@ -1119,6 +1123,31 @@ OrType: ast::OrType = {
|
|||||||
".b64" => ast::OrType::B64,
|
".b64" => ast::OrType::B64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp
|
||||||
|
InstRcp: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||||
|
"rcp" <rounding:RcpRoundingMode> <ftz:".ftz"?> ".f32" <a:Arg2> => {
|
||||||
|
let details = ast::RcpDetails {
|
||||||
|
rounding,
|
||||||
|
flush_to_zero: ftz.is_some(),
|
||||||
|
is_f64: false,
|
||||||
|
};
|
||||||
|
ast::Instruction::Rcp(details, a)
|
||||||
|
},
|
||||||
|
"rcp" <rn:RoundingModeFloat> ".f64" <a:Arg2> => {
|
||||||
|
let details = ast::RcpDetails {
|
||||||
|
rounding: Some(rn),
|
||||||
|
flush_to_zero: false,
|
||||||
|
is_f64: true,
|
||||||
|
};
|
||||||
|
ast::Instruction::Rcp(details, a)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
RcpRoundingMode: Option<ast::RoundingMode> = {
|
||||||
|
".approx" => None,
|
||||||
|
<r:RoundingModeFloat> => Some(r)
|
||||||
|
};
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub
|
||||||
|
@ -80,6 +80,7 @@ test_ptx!(max, [555i32, 444i32], [555i32]);
|
|||||||
test_ptx!(global_array, [0xDEADu32], [1u32]);
|
test_ptx!(global_array, [0xDEADu32], [1u32]);
|
||||||
test_ptx!(extern_shared, [127u64], [127u64]);
|
test_ptx!(extern_shared, [127u64], [127u64]);
|
||||||
test_ptx!(extern_shared_call, [121u64], [123u64]);
|
test_ptx!(extern_shared_call, [121u64], [123u64]);
|
||||||
|
test_ptx!(rcp, [2f32], [0.5f32]);
|
||||||
|
|
||||||
struct DisplayError<T: Debug> {
|
struct DisplayError<T: Debug> {
|
||||||
err: T,
|
err: T,
|
||||||
|
21
ptx/src/test/spirv_run/rcp.ptx
Normal file
21
ptx/src/test/spirv_run/rcp.ptx
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry rcp(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .f32 temp;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.f32 temp, [in_addr];
|
||||||
|
rcp.approx.f32 temp, temp;
|
||||||
|
st.f32 [out_addr], temp;
|
||||||
|
ret;
|
||||||
|
}
|
51
ptx/src/test/spirv_run/rcp.spvtxt
Normal file
51
ptx/src/test/spirv_run/rcp.spvtxt
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
OpCapability GenericPointer
|
||||||
|
OpCapability Linkage
|
||||||
|
OpCapability Addresses
|
||||||
|
OpCapability Kernel
|
||||||
|
OpCapability Int8
|
||||||
|
OpCapability Int16
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability Float64
|
||||||
|
%23 = OpExtInstImport "OpenCL.std"
|
||||||
|
OpMemoryModel Physical64 OpenCL
|
||||||
|
OpEntryPoint Kernel %1 "rcp"
|
||||||
|
OpDecorate %15 FPFastMathMode AllowRecip
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%ulong = OpTypeInt 64 0
|
||||||
|
%26 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
|
%float = OpTypeFloat 32
|
||||||
|
%_ptr_Function_float = OpTypePointer Function %float
|
||||||
|
%_ptr_Generic_float = OpTypePointer Generic %float
|
||||||
|
%float_1 = OpConstant %float 1
|
||||||
|
%1 = OpFunction %void None %26
|
||||||
|
%7 = OpFunctionParameter %ulong
|
||||||
|
%8 = OpFunctionParameter %ulong
|
||||||
|
%21 = OpLabel
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%5 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%6 = OpVariable %_ptr_Function_float Function
|
||||||
|
OpStore %2 %7
|
||||||
|
OpStore %3 %8
|
||||||
|
%10 = OpLoad %ulong %2
|
||||||
|
%9 = OpCopyObject %ulong %10
|
||||||
|
OpStore %4 %9
|
||||||
|
%12 = OpLoad %ulong %3
|
||||||
|
%11 = OpCopyObject %ulong %12
|
||||||
|
OpStore %5 %11
|
||||||
|
%14 = OpLoad %ulong %4
|
||||||
|
%19 = OpConvertUToPtr %_ptr_Generic_float %14
|
||||||
|
%13 = OpLoad %float %19
|
||||||
|
OpStore %6 %13
|
||||||
|
%16 = OpLoad %float %6
|
||||||
|
%15 = OpFDiv %float %float_1 %16
|
||||||
|
OpStore %6 %15
|
||||||
|
%17 = OpLoad %ulong %5
|
||||||
|
%18 = OpLoad %float %6
|
||||||
|
%20 = OpConvertUToPtr %_ptr_Generic_float %17
|
||||||
|
OpStore %20 %18
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
@ -1144,6 +1144,9 @@ fn convert_to_typed_statements(
|
|||||||
ast::Instruction::Max(d, a) => {
|
ast::Instruction::Max(d, a) => {
|
||||||
result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast())))
|
result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast())))
|
||||||
}
|
}
|
||||||
|
ast::Instruction::Rcp(d, a) => {
|
||||||
|
result.push(Statement::Instruction(ast::Instruction::Rcp(d, a.cast())))
|
||||||
|
}
|
||||||
},
|
},
|
||||||
Statement::Label(i) => result.push(Statement::Label(i)),
|
Statement::Label(i) => result.push(Statement::Label(i)),
|
||||||
Statement::Variable(v) => result.push(Statement::Variable(v)),
|
Statement::Variable(v) => result.push(Statement::Variable(v)),
|
||||||
@ -2179,6 +2182,9 @@ fn emit_function_body_ops(
|
|||||||
ast::Instruction::Max(d, a) => {
|
ast::Instruction::Max(d, a) => {
|
||||||
emit_max(builder, map, opencl, d, a)?;
|
emit_max(builder, map, opencl, d, a)?;
|
||||||
}
|
}
|
||||||
|
ast::Instruction::Rcp(d, a) => {
|
||||||
|
emit_rcp(builder, map, d, a)?;
|
||||||
|
}
|
||||||
},
|
},
|
||||||
Statement::LoadVar(arg, typ) => {
|
Statement::LoadVar(arg, typ) => {
|
||||||
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
|
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
|
||||||
@ -2209,6 +2215,40 @@ fn emit_function_body_ops(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn emit_rcp(
|
||||||
|
builder: &mut dr::Builder,
|
||||||
|
map: &mut TypeWordMap,
|
||||||
|
desc: &ast::RcpDetails,
|
||||||
|
a: &ast::Arg2<ExpandedArgParams>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
if desc.flush_to_zero {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
let (instr_type, constant) = if desc.is_f64 {
|
||||||
|
(ast::ScalarType::F64, vec_repr(1.0f64))
|
||||||
|
} else {
|
||||||
|
(ast::ScalarType::F32, vec_repr(1.0f32))
|
||||||
|
};
|
||||||
|
let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?;
|
||||||
|
let result_type = map.get_or_add_scalar(builder, instr_type);
|
||||||
|
builder.f_div(result_type, Some(a.dst), one, a.src)?;
|
||||||
|
emit_rounding_decoration(builder, a.dst, desc.rounding);
|
||||||
|
builder.decorate(
|
||||||
|
a.dst,
|
||||||
|
spirv::Decoration::FPFastMathMode,
|
||||||
|
&[dr::Operand::FPFastMathMode(
|
||||||
|
spirv::FPFastMathMode::ALLOW_RECIP,
|
||||||
|
)],
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
|
||||||
|
let mut result = vec![0; mem::size_of::<T>()];
|
||||||
|
unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) };
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
fn emit_variable(
|
fn emit_variable(
|
||||||
builder: &mut dr::Builder,
|
builder: &mut dr::Builder,
|
||||||
map: &mut TypeWordMap,
|
map: &mut TypeWordMap,
|
||||||
@ -3735,7 +3775,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||||||
) -> Result<ast::Instruction<U>, TranslateError> {
|
) -> Result<ast::Instruction<U>, TranslateError> {
|
||||||
Ok(match self {
|
Ok(match self {
|
||||||
ast::Instruction::Abs(d, arg) => {
|
ast::Instruction::Abs(d, arg) => {
|
||||||
ast::Instruction::Abs(d, arg.map(visitor, false, &ast::Type::Scalar(d.typ))?)
|
ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?)
|
||||||
}
|
}
|
||||||
// Call instruction is converted to a call statement early on
|
// Call instruction is converted to a call statement early on
|
||||||
ast::Instruction::Call(_) => return Err(TranslateError::Unreachable),
|
ast::Instruction::Call(_) => return Err(TranslateError::Unreachable),
|
||||||
@ -3766,9 +3806,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||||||
let inst_type = d.typ;
|
let inst_type = d.typ;
|
||||||
ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
|
ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
|
||||||
}
|
}
|
||||||
ast::Instruction::Not(t, a) => {
|
ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?),
|
||||||
ast::Instruction::Not(t, a.map(visitor, false, &t.to_type())?)
|
|
||||||
}
|
|
||||||
ast::Instruction::Cvt(d, a) => {
|
ast::Instruction::Cvt(d, a) => {
|
||||||
let (dst_t, src_t) = match &d {
|
let (dst_t, src_t) = match &d {
|
||||||
ast::CvtDetails::FloatFromFloat(desc) => (
|
ast::CvtDetails::FloatFromFloat(desc) => (
|
||||||
@ -3806,7 +3844,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||||||
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
|
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
|
||||||
ast::Instruction::Cvta(d, a) => {
|
ast::Instruction::Cvta(d, a) => {
|
||||||
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
|
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
|
||||||
ast::Instruction::Cvta(d, a.map(visitor, false, &inst_type)?)
|
ast::Instruction::Cvta(d, a.map(visitor, &inst_type)?)
|
||||||
}
|
}
|
||||||
ast::Instruction::Mad(d, a) => {
|
ast::Instruction::Mad(d, a) => {
|
||||||
let inst_type = d.get_type();
|
let inst_type = d.get_type();
|
||||||
@ -3829,6 +3867,14 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||||||
let typ = d.get_type();
|
let typ = d.get_type();
|
||||||
ast::Instruction::Max(d, a.map_non_shift(visitor, &typ, false)?)
|
ast::Instruction::Max(d, a.map_non_shift(visitor, &typ, false)?)
|
||||||
}
|
}
|
||||||
|
ast::Instruction::Rcp(d, a) => {
|
||||||
|
let typ = ast::Type::Scalar(if d.is_f64 {
|
||||||
|
ast::ScalarType::F64
|
||||||
|
} else {
|
||||||
|
ast::ScalarType::F32
|
||||||
|
});
|
||||||
|
ast::Instruction::Rcp(d, a.map(visitor, &typ)?)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -4072,6 +4118,7 @@ impl ast::Instruction<ExpandedArgParams> {
|
|||||||
| ast::Instruction::Sub(_, _)
|
| ast::Instruction::Sub(_, _)
|
||||||
| ast::Instruction::Min(_, _)
|
| ast::Instruction::Min(_, _)
|
||||||
| ast::Instruction::Max(_, _)
|
| ast::Instruction::Max(_, _)
|
||||||
|
| ast::Instruction::Rcp(_, _)
|
||||||
| ast::Instruction::Mad(_, _) => None,
|
| ast::Instruction::Mad(_, _) => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -4289,7 +4336,6 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
|
|||||||
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
||||||
self,
|
self,
|
||||||
visitor: &mut V,
|
visitor: &mut V,
|
||||||
src_is_addr: bool,
|
|
||||||
t: &ast::Type,
|
t: &ast::Type,
|
||||||
) -> Result<ast::Arg2<U>, TranslateError> {
|
) -> Result<ast::Arg2<U>, TranslateError> {
|
||||||
let new_dst = visitor.id(
|
let new_dst = visitor.id(
|
||||||
@ -4304,11 +4350,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
|
|||||||
ArgumentDescriptor {
|
ArgumentDescriptor {
|
||||||
op: self.src,
|
op: self.src,
|
||||||
is_dst: false,
|
is_dst: false,
|
||||||
sema: if src_is_addr {
|
sema: ArgumentSemantics::Default,
|
||||||
ArgumentSemantics::Address
|
|
||||||
} else {
|
|
||||||
ArgumentSemantics::Default
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
t,
|
t,
|
||||||
)?;
|
)?;
|
||||||
|
Reference in New Issue
Block a user