From e5a53ed5d30fad3d8ebae6d72ead1564d2b97275 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 1 Nov 2020 14:58:44 +0100 Subject: [PATCH] Implement neg instruction --- ptx/src/ast.rs | 7 +++++ ptx/src/ptx.lalrpop | 36 +++++++++++++++++++++++ ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/test/spirv_run/neg.ptx | 21 ++++++++++++++ ptx/src/test/spirv_run/neg.spvtxt | 47 +++++++++++++++++++++++++++++++ ptx/src/translate.rs | 20 ++++++++++++- 6 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 ptx/src/test/spirv_run/neg.ptx create mode 100644 ptx/src/test/spirv_run/neg.spvtxt diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index f00ddce..7f2fc9a 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -542,6 +542,7 @@ pub enum Instruction { Div(DivDetails, Arg3

), Sqrt(SqrtDetails, Arg2

), Rsqrt(RsqrtDetails, Arg2

), + Neg(NegDetails, Arg2

), } #[derive(Copy, Clone)] @@ -1183,6 +1184,12 @@ pub struct RsqrtDetails { pub flush_to_zero: bool, } +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct NegDetails { + pub typ: ScalarType, + pub flush_to_zero: Option, +} + impl<'a> NumsOrArrays<'a> { pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result, PtxError> { self.normalize_dimensions(dimensions)?; diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 4cf4255..9d2adec 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -156,6 +156,7 @@ match { "min", "mov", "mul", + "neg", "not", "or", "rcp", @@ -198,6 +199,7 @@ ExtendedID : &'input str = { "min", "mov", "mul", + "neg", "not", "or", "rcp", @@ -684,6 +686,7 @@ Instruction: ast::Instruction> = { InstDiv, InstSqrt, InstRsqrt, + InstNeg, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -1577,6 +1580,39 @@ InstRsqrt: ast::Instruction> = { }, } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-neg +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-neg +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-neg +InstNeg: ast::Instruction> = { + "neg" => { + let details = ast::NegDetails { + typ, + flush_to_zero: Some(ftz.is_some()), + }; + ast::Instruction::Neg(details, a) + }, + "neg" => { + let details = ast::NegDetails { + typ, + flush_to_zero: None, + }; + ast::Instruction::Neg(details, a) + }, +} + +NegTypeFtz: ast::ScalarType = { + ".f16" => ast::ScalarType::F16, + ".f16x2" => ast::ScalarType::F16x2, + ".f32" => ast::ScalarType::F32, +} + +NegTypeNonFtz: ast::ScalarType = { + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f64" => ast::ScalarType::F64 +} + ArithDetails: ast::ArithDetails = { => ast::ArithDetails::Unsigned(t), => ast::ArithDetails::Signed(ast::ArithSInt { diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 4e9d39f..7ba3c4d 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -104,6 +104,7 @@ test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]); test_ptx!(div_approx, [1f32, 2f32], [0.5f32]); test_ptx!(sqrt, [0.25f32], [0.5f32]); test_ptx!(rsqrt, [0.25f64], [2f64]); +test_ptx!(neg, [181i32], [-181i32]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/neg.ptx b/ptx/src/test/spirv_run/neg.ptx new file mode 100644 index 0000000..60fe162 --- /dev/null +++ b/ptx/src/test/spirv_run/neg.ptx @@ -0,0 +1,21 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry neg( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .s32 temp1; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.s32 temp1, [in_addr]; + neg.s32 temp1, temp1; + st.s32 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/neg.spvtxt b/ptx/src/test/spirv_run/neg.spvtxt new file mode 100644 index 0000000..b358858 --- /dev/null +++ b/ptx/src/test/spirv_run/neg.spvtxt @@ -0,0 +1,47 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %26 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "not" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %29 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %1 = OpFunction %void None %29 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %24 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %20 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %20 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %22 = OpCopyObject %ulong %17 + %21 = OpNot %ulong %22 + %16 = OpCopyObject %ulong %21 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %23 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %23 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c351ccd..36e15f9 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1511,6 +1511,9 @@ fn convert_to_typed_statements( ast::Instruction::Rsqrt(d, a) => { result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast()))) } + ast::Instruction::Neg(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Neg(d, a.cast()))) + } }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), @@ -2805,6 +2808,15 @@ fn emit_function_body_ops( &[a.src], )?; } + ast::Instruction::Neg(details, arg) => { + let result_type = map.get_or_add_scalar(builder, details.typ); + let negate_func = if details.typ.kind() == ScalarKind::Float { + dr::Builder::f_negate + } else { + dr::Builder::s_negate + }; + negate_func(builder, result_type, Some(arg.dst), arg.src)?; + } }, Statement::LoadVar(arg, typ) => { let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); @@ -3406,7 +3418,7 @@ fn emit_setp( (ast::SetpCompareOp::NanGreaterOrEq, _) => { builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2) } - _ => todo!() + _ => todo!(), }?; Ok(()) } @@ -4678,6 +4690,9 @@ impl ast::Instruction { ast::Instruction::Rsqrt(d, a) => { ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?) } + ast::Instruction::Neg(d, a) => { + ast::Instruction::Neg(d, a.map(visitor, &ast::Type::Scalar(d.typ))?) + } }) } } @@ -4984,6 +4999,9 @@ impl ast::Instruction { details.flush_to_zero, ast::ScalarType::from(details.typ).size_of(), )), + ast::Instruction::Neg(details, _) => details + .flush_to_zero + .map(|ftz| (ftz, details.typ.size_of())), } } }