diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 039e4d1..64593d3 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index b391347..0f0f1d3 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -1,10 +1,11 @@ // Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17` // `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining -// /opt/rocm/llvm/bin/clang -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc +// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc #include #include - +#include +#include #include #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME @@ -163,7 +164,7 @@ extern "C" int32_t __ockl_wgred_and_i32(int32_t) __device__; int32_t __ockl_wgred_or_i32(int32_t) __device__; - #define BAR_RED_IMPL(reducer) \ +#define BAR_RED_IMPL(reducer) \ bool FUNC(bar_red_##reducer##_pred)(uint32_t barrier __attribute__((unused)), bool predicate, bool invert_predicate) \ { \ /* TODO: handle barrier */ \ @@ -173,7 +174,8 @@ extern "C" BAR_RED_IMPL(and); BAR_RED_IMPL(or); - struct ShflSyncResult { + struct ShflSyncResult + { uint32_t output; bool in_bounds; }; @@ -191,7 +193,7 @@ extern "C" // intrinsics, it is always 31 for idx, bfly, and down, and 0 for up. This is used for the // bounds check. 
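    // Worked example of the bounds check above (a sketch; the exact packing of warp_end and
    // self is assumed from the PTX shfl.sync description, the other names refer to the macro below):
    // for shfl.sync.up.b32 with delta = 1 and opts = 0, section_mask and subsection_end are both 0,
    // so lane 0 computes idx = -1, which is < 0 and therefore out of bounds; it keeps its own
    // value and reports in_bounds = false, while lanes 1..31 read from lane - 1 with in_bounds = true.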
- #define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \ +#define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \ ShflSyncResult FUNC(shfl_sync_##mode##_b32_pred)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask __attribute__((unused))) \ { \ int32_t section_mask = (opts >> 8) & 0b11111; \ @@ -201,10 +203,11 @@ extern "C" int32_t subsection_end = subsection | (~section_mask & warp_end); \ int32_t idx = calculate_index; \ bool out_of_bounds = idx CMP subsection_end; \ - if (out_of_bounds) { \ + if (out_of_bounds) \ + { \ idx = self; \ } \ - int32_t output = __builtin_amdgcn_ds_bpermute(idx<<2, (int32_t)input); \ + int32_t output = __builtin_amdgcn_ds_bpermute(idx << 2, (int32_t)input); \ return {(uint32_t)output, !out_of_bounds}; \ } \ \ @@ -212,15 +215,15 @@ extern "C" { \ return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).output; \ } - + // We are using the HIP __shfl intrinsics to implement these, rather than the __shfl_sync // intrinsics, as those only add an assertion checking that the membermask is used correctly. // They do not return the result of the range check, so we must replicate that logic here. - SHFL_SYNC_IMPL(up, self - delta, <); - SHFL_SYNC_IMPL(down, self + delta, >); - SHFL_SYNC_IMPL(bfly, self ^ delta, >); - SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >); + SHFL_SYNC_IMPL(up, self - delta, <); + SHFL_SYNC_IMPL(down, self + delta, >); + SHFL_SYNC_IMPL(bfly, self ^ delta, >); + SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >); DECLARE_ATTR(uint32_t, CLOCK_RATE); void FUNC(nanosleep_u32)(uint32_t nanoseconds) { @@ -254,4 +257,82 @@ extern "C" (void)char_size; __assert_fail((const char *)message, (const char *)file, line, (const char *)function); } + + // * Smallest denormal is 1.4 × 10^-45 + // * Smallest normal is ~1.175494351 × 10^(-38) + // * Now, 1.175494351×10^-38 / 1.4 × 10^-45 = 8396388 + 31/140 + // * Next power of 2 is 16777216 + const float DENORMAL_TO_NORMAL_FACTOR_F32 = 16777216.0f; + // * Largest subnormal is ~1.175494210692441e × 10^(-38) + // * Then any value equal or larger than following will produce subnormals: 8.50706018714406320806444272332455743547934627837873057975602739772164... 
× 10^37 + const float RCP_DENORMAL_OUTPUT = 8.50706018714406320806444272332455743547934627837873057975602739772164e37f; + const float REVERSE_DENORMAL_TO_NORMAL_FACTOR_F32 = 0.029387360490963111877208252592662410455594571842846914442095471744599661631813495980086003637902577995683214210345151992265999035207077609582844f; + + float FUNC(sqrt_approx_f32)(float x) + { + bool is_subnormal = __builtin_isfpclass(x, __FPCLASS_NEGSUBNORMAL | __FPCLASS_POSSUBNORMAL); + float input = x; + if (is_subnormal) + input = x * DENORMAL_TO_NORMAL_FACTOR_F32; + float value = __builtin_amdgcn_sqrtf(input); + if (is_subnormal) + return value * 0.000244140625f; + else + return value; + } + + float FUNC(rsqrt_approx_f32)(float x) + { + bool is_subnormal = __builtin_isfpclass(x, __FPCLASS_NEGSUBNORMAL | __FPCLASS_POSSUBNORMAL); + float input = x; + if (is_subnormal) + input = x * DENORMAL_TO_NORMAL_FACTOR_F32; + float value = __builtin_amdgcn_rsqf(input); + if (is_subnormal) + return value * 4096.0f; + else + return value; + } + + float FUNC(rcp_approx_f32)(float x) + { + float factor = 1.0f; + if (__builtin_isfpclass(x, __FPCLASS_NEGSUBNORMAL | __FPCLASS_POSSUBNORMAL)) + { + factor = DENORMAL_TO_NORMAL_FACTOR_F32; + } + if (std::fabs(x) >= RCP_DENORMAL_OUTPUT) + { + factor = REVERSE_DENORMAL_TO_NORMAL_FACTOR_F32; + } + return __builtin_amdgcn_rcpf(x * factor) * factor; + } + + // When x = -126, exp2(x) = 2^(-126) ≈ 1.175494351 × 10^(-38), + // which is the smallest normalized number in FP32 + float FUNC(ex2_approx_f32)(float x) + { + bool special_handling = x < -126.0f; + float input = x; + if (special_handling) + input *= 0.5f; + float result = __builtin_amdgcn_exp2f(input); + if (special_handling) + return result * result; + else + return result; + } + + float FUNC(lg2_approx_f32)(float x) + { + bool is_subnormal = __builtin_isfpclass(x, __FPCLASS_NEGSUBNORMAL | __FPCLASS_POSSUBNORMAL); + float input = x; + if (is_subnormal) + input = x * DENORMAL_TO_NORMAL_FACTOR_F32; + float value = __builtin_amdgcn_logf(input); + if (is_subnormal) + return value - 24.0f; + else + return value; + } } diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index f2fead7..8f71fe7 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -164,6 +164,8 @@ fn run_instruction<'input>( | ast::Instruction::Ret { .. } | ast::Instruction::Rsqrt { .. } | ast::Instruction::Selp { .. } + | ast::Instruction::Set { .. } + | ast::Instruction::SetBool { .. } | ast::Instruction::Setp { .. } | ast::Instruction::SetpBool { .. } | ast::Instruction::ShflSync { .. } @@ -183,6 +185,7 @@ fn run_instruction<'input>( data: ast::ArithDetails::Integer(..), .. } + | ast::Instruction::Tanh { .. } | ast::Instruction::Trap {} | ast::Instruction::Xor { .. 
} => result.push(Statement::Instruction(instruction)), ast::Instruction::Add { diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index 5692365..86ea659 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -221,12 +221,25 @@ impl InstructionModes { ) } - fn from_rcp(data: ast::RcpData) -> InstructionModes { + fn from_rtz_special(data: ast::RcpData) -> InstructionModes { let rounding = match data.kind { ast::RcpKind::Approx => None, ast::RcpKind::Compliant(rnd) => Some(RoundingMode::from_ast(rnd)), }; - let denormal = data.flush_to_zero.map(DenormalMode::from_ftz); + let denormal = match ( + data.kind == ast::RcpKind::Approx, + data.flush_to_zero == Some(true), + ) { + // If we are approximate and flushing then we compile to V_RSQ_F32 + // or V_SQRT_F32 which ignores prevailing denormal mode and flushes anyway + (true, true) => None, + // If we are approximate and flushing the V_RSQ_F32/V_SQRT_F32 + // ignores ftz mode, but we implement instruction in terms of fmuls + // which must preserve denormals + (true, false) => Some(DenormalMode::Preserve), + // For full precision we set denormal mode accordingly + (false, ftz) => Some(DenormalMode::from_ftz(ftz)), + }; InstructionModes::new(data.type_, denormal, rounding) } @@ -1780,6 +1793,11 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { match inst { // TODO: review it when implementing virtual calls ast::Instruction::Call { .. } + // abs and neg have special handling in AMD GPU ISA. They get compiled + // down to instruction argument modifiers, floating point flags have no + // effect on it. We handle it during LLVM bitcode emission + | ast::Instruction::Abs { .. } + | ast::Instruction::Neg {.. } | ast::Instruction::Mov { .. } | ast::Instruction::Ld { .. } | ast::Instruction::St { .. } @@ -1856,7 +1874,31 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { data: ast::ArithDetails::Float(data), .. } => InstructionModes::from_arith_float(data), - ast::Instruction::Setp { + ast::Instruction::Set { + data: ast::SetData{ + base: ast::SetpData { + type_, + flush_to_zero, + .. + }, + .. + }, + .. + } + | ast::Instruction::SetBool { + data: ast::SetBoolData { + base: ast::SetpBoolData { + base: ast::SetpData { + type_, + flush_to_zero, + .. + }, + .. + }, + .. + }, .. + } + | ast::Instruction::Setp { data: ast::SetpData { type_, @@ -1878,34 +1920,6 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { }, .. } - | ast::Instruction::Neg { - data: ast::TypeFtz { - type_, - flush_to_zero, - }, - .. - } - | ast::Instruction::Ex2 { - data: ast::TypeFtz { - type_, - flush_to_zero, - }, - .. - } - | ast::Instruction::Rsqrt { - data: ast::TypeFtz { - type_, - flush_to_zero, - }, - .. - } - | ast::Instruction::Abs { - data: ast::TypeFtz { - type_, - flush_to_zero, - }, - .. - } | ast::Instruction::Min { data: ast::MinMaxDetails::Float(ast::MinMaxFloat { @@ -1945,12 +1959,29 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { ) } ast::Instruction::Sin { data, .. } - | ast::Instruction::Cos { data, .. } - | ast::Instruction::Lg2 { data, .. } => InstructionModes::from_ftz_f32(data.flush_to_zero), + | ast::Instruction::Cos { data, .. } => InstructionModes::from_ftz_f32(data.flush_to_zero), ast::Instruction::Rcp { data, .. } | ast::Instruction::Sqrt { data, .. } => { - InstructionModes::from_rcp(*data) + InstructionModes::from_rtz_special(*data) } + ast::Instruction::Rsqrt { data, .. 
} + | ast::Instruction::Ex2 { data, .. } => { + let data = ast::RcpData { + type_: data.type_, + flush_to_zero: data.flush_to_zero, + kind: ast::RcpKind::Approx, + }; + InstructionModes::from_rtz_special(data) + }, + ast::Instruction::Lg2 { data, .. } => { + let data = ast::RcpData { + type_: ast::ScalarType::F32, + flush_to_zero: Some(data.flush_to_zero), + kind: ast::RcpKind::Approx, + }; + InstructionModes::from_rtz_special(data) + }, ast::Instruction::Cvt { data, .. } => InstructionModes::from_cvt(data), + ast::Instruction::Tanh { data, .. } => InstructionModes::from_ftz(*data, Some(false)), } } diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index eb1fbd4..bd0160a 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -27,7 +27,7 @@ use std::array::TryFromSliceError; use std::convert::TryInto; use std::ffi::{CStr, NulError}; -use std::{i8, ptr}; +use std::{i8, ptr, u64}; use super::*; use crate::pass::*; @@ -471,8 +471,10 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::St { data, arguments } => self.emit_st(data, arguments), ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments), ast::Instruction::Mul24 { data, arguments } => self.emit_mul24(data, arguments), + ast::Instruction::Set { data, arguments } => self.emit_set(data, arguments), + ast::Instruction::SetBool { data, arguments } => self.emit_set_bool(data, arguments), ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments), - ast::Instruction::SetpBool { .. } => todo!(), + ast::Instruction::SetpBool { data, arguments } => self.emit_setp_bool(data, arguments), ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments), ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments), ast::Instruction::And { arguments, .. } => self.emit_and(arguments), @@ -506,10 +508,13 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments), ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments), ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments), - ast::Instruction::PrmtSlow { .. } => todo!(), + ast::Instruction::PrmtSlow { .. } => { + Err(error_todo_msg("PrmtSlow is not implemented yet")) + } ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments), ast::Instruction::Membar { data } => self.emit_membar(data), - ast::Instruction::Trap {} => todo!(), + ast::Instruction::Trap {} => Err(error_todo_msg("Trap is not implemented yet")), + ast::Instruction::Tanh { data, arguments } => self.emit_tanh(data, arguments), // replaced by a function call ast::Instruction::Bfe { .. } | ast::Instruction::Bar { .. } @@ -713,18 +718,31 @@ impl<'a> MethodEmitContext<'a> { arguments: ast::AddArgs, ) -> Result<(), TranslateError> { let builder = self.builder; - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; let fn_ = match data { ast::ArithDetails::Integer(ast::ArithInteger { saturate: true, type_, - }) => return self.emit_add_sat(type_, arguments), + }) => { + let op = if type_.kind() == ast::ScalarKind::Signed { + "sadd" + } else { + "uadd" + }; + return self.emit_intrinsic_saturate( + op, + type_, + arguments.dst, + arguments.src1, + arguments.src2, + ); + } ast::ArithDetails::Integer(ast::ArithInteger { saturate: false, .. }) => LLVMBuildAdd, ast::ArithDetails::Float(..) 
=> LLVMBuildFAdd, }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; self.resolver.with_result(arguments.dst, |dst| unsafe { fn_(builder, src1, src2, dst) }); @@ -1353,7 +1371,18 @@ impl<'a> MethodEmitContext<'a> { arguments: ptx_parser::SubArgs, ) -> Result<(), TranslateError> { if arith_integer.saturate { - todo!() + let op = if arith_integer.type_.kind() == ast::ScalarKind::Signed { + "ssub" + } else { + "usub" + }; + return self.emit_intrinsic_saturate( + op, + arith_integer.type_, + arguments.dst, + arguments.src1, + arguments.src2, + ); } let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; @@ -1365,12 +1394,9 @@ impl<'a> MethodEmitContext<'a> { fn emit_sub_float( &mut self, - arith_float: ptx_parser::ArithFloat, + _arith_float: ptx_parser::ArithFloat, arguments: ptx_parser::SubArgs, ) -> Result<(), TranslateError> { - if arith_float.saturate { - todo!() - } let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; self.resolver.with_result(arguments.dst, |dst| unsafe { @@ -1430,25 +1456,39 @@ impl<'a> MethodEmitContext<'a> { arguments: ptx_parser::NegArgs, ) -> Result<(), TranslateError> { let src = self.resolver.value(arguments.src)?; - let llvm_fn = if data.type_.kind() == ptx_parser::ScalarKind::Float { + let is_floating_point = data.type_.kind() == ptx_parser::ScalarKind::Float; + let llvm_fn = if is_floating_point { LLVMBuildFNeg } else { LLVMBuildNeg }; - self.resolver.with_result(arguments.dst, |dst| unsafe { - llvm_fn(self.builder, src, dst) - }); + if is_floating_point && data.flush_to_zero == Some(true) { + let negated = unsafe { llvm_fn(self.builder, src, LLVM_UNNAMED.as_ptr()) }; + let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(data.type_)); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + Some(&data.type_.into()), + vec![(negated, get_scalar_type(self.context, data.type_))], + )?; + } else { + self.resolver.with_result(arguments.dst, |dst| unsafe { + llvm_fn(self.builder, src, dst) + }); + } Ok(()) } fn emit_not( &mut self, - _data: ptx_parser::ScalarType, + type_: ptx_parser::ScalarType, arguments: ptx_parser::NotArgs, ) -> Result<(), TranslateError> { let src = self.resolver.value(arguments.src)?; + let type_ = get_scalar_type(self.context, type_); + let constant = unsafe { LLVMConstInt(type_, u64::MAX, 0) }; self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildNot(self.builder, src, dst) + LLVMBuildXor(self.builder, src, constant, dst) }); Ok(()) } @@ -1458,15 +1498,29 @@ impl<'a> MethodEmitContext<'a> { data: ptx_parser::SetpData, arguments: ptx_parser::SetpArgs, ) -> Result<(), TranslateError> { - if arguments.dst2.is_some() { - todo!() + let dst = self.emit_setp_impl(data, arguments.dst2, arguments.src1, arguments.src2)?; + self.resolver.register(arguments.dst1, dst); + Ok(()) + } + + fn emit_setp_impl( + &mut self, + data: ptx_parser::SetpData, + dst2: Option, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result { + if dst2.is_some() { + return Err(error_todo_msg( + "setp with two dst arguments not yet supported", + )); } match data.cmp_op { ptx_parser::SetpCompareOp::Integer(setp_compare_int) => { - self.emit_setp_int(setp_compare_int, arguments) + self.emit_setp_int(setp_compare_int, src1, src2) } ptx_parser::SetpCompareOp::Float(setp_compare_float) => { - self.emit_setp_float(setp_compare_float, arguments) + 
self.emit_setp_float(setp_compare_float, src1, src2) } } } @@ -1474,8 +1528,9 @@ impl<'a> MethodEmitContext<'a> { fn emit_setp_int( &mut self, setp: ptx_parser::SetpCompareInt, - arguments: ptx_parser::SetpArgs, - ) -> Result<(), TranslateError> { + src1: SpirvWord, + src2: SpirvWord, + ) -> Result { let op = match setp { ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ, ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE, @@ -1488,19 +1543,17 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT, ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE, }; - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - self.resolver.with_result(arguments.dst1, |dst1| unsafe { - LLVMBuildICmp(self.builder, op, src1, src2, dst1) - }); - Ok(()) + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + Ok(unsafe { LLVMBuildICmp(self.builder, op, src1, src2, LLVM_UNNAMED.as_ptr()) }) } fn emit_setp_float( &mut self, setp: ptx_parser::SetpCompareFloat, - arguments: ptx_parser::SetpArgs, - ) -> Result<(), TranslateError> { + src1: SpirvWord, + src2: SpirvWord, + ) -> Result { let op = match setp { ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ, ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE, @@ -1517,12 +1570,9 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD, ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO, }; - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - self.resolver.with_result(arguments.dst1, |dst1| unsafe { - LLVMBuildFCmp(self.builder, op, src1, src2, dst1) - }); - Ok(()) + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + Ok(unsafe { LLVMBuildFCmp(self.builder, op, src1, src2, LLVM_UNNAMED.as_ptr()) }) } fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> { @@ -1935,15 +1985,34 @@ impl<'a> MethodEmitContext<'a> { data: ptx_parser::ShrData, arguments: ptx_parser::ShrArgs, ) -> Result<(), TranslateError> { - let shift_fn = match data.kind { - ptx_parser::RightShiftKind::Arithmetic => LLVMBuildAShr, - ptx_parser::RightShiftKind::Logical => LLVMBuildLShr, + let llvm_type = get_scalar_type(self.context, data.type_); + let (out_of_range, shift_fn): ( + *mut LLVMValue, + unsafe extern "C" fn( + LLVMBuilderRef, + LLVMValueRef, + LLVMValueRef, + *const i8, + ) -> LLVMValueRef, + ) = match data.kind { + ptx_parser::RightShiftKind::Logical => { + (unsafe { LLVMConstNull(llvm_type) }, LLVMBuildLShr) + } + ptx_parser::RightShiftKind::Arithmetic => { + let src1 = self.resolver.value(arguments.src1)?; + let shift_size = + unsafe { LLVMConstInt(llvm_type, (data.type_.size_of() as u64 * 8) - 1, 0) }; + let out_of_range = + unsafe { LLVMBuildAShr(self.builder, src1, shift_size, LLVM_UNNAMED.as_ptr()) }; + (out_of_range, LLVMBuildAShr) + } }; self.emit_shift( data.type_, arguments.dst, arguments.src1, arguments.src2, + out_of_range, shift_fn, ) } @@ -1953,11 +2022,13 @@ impl<'a> MethodEmitContext<'a> { type_: ptx_parser::ScalarType, arguments: ptx_parser::ShlArgs, ) -> Result<(), TranslateError> { + let llvm_type = get_scalar_type(self.context, type_); self.emit_shift( type_, arguments.dst, arguments.src1, arguments.src2, + unsafe { LLVMConstNull(llvm_type) }, LLVMBuildShl, ) } @@ -1968,6 +2039,7 
@@ impl<'a> MethodEmitContext<'a> { dst: SpirvWord, src1: SpirvWord, src2: SpirvWord, + out_of_range_value: LLVMValueRef, llvm_fn: unsafe extern "C" fn( LLVMBuilderRef, LLVMValueRef, @@ -1995,7 +2067,6 @@ impl<'a> MethodEmitContext<'a> { ) }; let llvm_type = get_scalar_type(self.context, type_); - let zero = unsafe { LLVMConstNull(llvm_type) }; let normalized_shift_size = if type_.layout().size() >= 4 { unsafe { LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) @@ -2012,7 +2083,7 @@ impl<'a> MethodEmitContext<'a> { ) }; self.resolver.with_result(dst, |dst| unsafe { - LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst) + LLVMBuildSelect(self.builder, should_clamp, out_of_range_value, shifted, dst) }); Ok(()) } @@ -2207,7 +2278,19 @@ impl<'a> MethodEmitContext<'a> { }, ) } - ptx_parser::MadDetails::Integer { saturate: true, .. } => return Err(error_todo()), + ptx_parser::MadDetails::Integer { + saturate: true, + control: ast::MulIntControl::High, + type_: ast::ScalarType::S32, + } => { + return self.emit_mad_hi_sat_s32( + arguments.dst, + (arguments.src1, arguments.src2, arguments.src3), + ); + } + ptx_parser::MadDetails::Integer { saturate: true, .. } => { + return Err(error_unreachable()) + } ptx_parser::MadDetails::Integer { type_, control, .. } => { ast::MulDetails::Integer { control, type_ } } @@ -2281,7 +2364,8 @@ impl<'a> MethodEmitContext<'a> { ) -> Result<(), TranslateError> { let llvm_type = get_scalar_type(self.context, data.type_); let src = self.resolver.value(arguments.src)?; - let (prefix, intrinsic_arguments) = if data.type_.kind() == ast::ScalarKind::Float { + let is_floating_point = data.type_.kind() == ast::ScalarKind::Float; + let (prefix, intrinsic_arguments) = if is_floating_point { ("llvm.fabs", vec![(src, llvm_type)]) } else { let pred = get_scalar_type(self.context, ast::ScalarType::Pred); @@ -2289,12 +2373,23 @@ impl<'a> MethodEmitContext<'a> { ("llvm.abs", vec![(src, llvm_type), (zero, pred)]) }; let llvm_intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(data.type_)); - self.emit_intrinsic( + let abs_result = self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) }, - Some(arguments.dst), + None, Some(&data.type_.into()), intrinsic_arguments, )?; + if is_floating_point && data.flush_to_zero == Some(true) { + let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(data.type_)); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + Some(&data.type_.into()), + vec![(abs_result, llvm_type)], + )?; + } else { + self.resolver.register(arguments.dst, abs_result); + } Ok(()) } @@ -2434,23 +2529,21 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } - fn emit_add_sat( + fn emit_intrinsic_saturate( &mut self, + op: &str, type_: ast::ScalarType, - arguments: ast::AddArgs, + dst: SpirvWord, + src1: SpirvWord, + src2: SpirvWord, ) -> Result<(), TranslateError> { let llvm_type = get_scalar_type(self.context, type_); - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - let op = if type_.kind() == ast::ScalarKind::Signed { - "sadd" - } else { - "uadd" - }; + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; let intrinsic = format!("llvm.{}.sat.{}\0", op, LLVMTypeDisplay(type_)); self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, - Some(arguments.dst), + Some(dst), Some(&type_.into()), vec![(src1, llvm_type), 
(src2, llvm_type)], )?; @@ -2475,6 +2568,130 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn emit_mad_hi_sat_s32( + &mut self, + dst: SpirvWord, + (src1, src2, src3): (SpirvWord, SpirvWord, SpirvWord), + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + let src3 = self.resolver.value(src3)?; + let llvm_type_s32 = get_scalar_type(self.context, ast::ScalarType::S32); + let llvm_type_s64 = get_scalar_type(self.context, ast::ScalarType::S64); + let src1_wide = + unsafe { LLVMBuildSExt(self.builder, src1, llvm_type_s64, LLVM_UNNAMED.as_ptr()) }; + let src2_wide = + unsafe { LLVMBuildSExt(self.builder, src2, llvm_type_s64, LLVM_UNNAMED.as_ptr()) }; + let mul_wide = + unsafe { LLVMBuildMul(self.builder, src1_wide, src2_wide, LLVM_UNNAMED.as_ptr()) }; + let const_32 = unsafe { LLVMConstInt(llvm_type_s64, 32, 0) }; + let mul_wide = + unsafe { LLVMBuildLShr(self.builder, mul_wide, const_32, LLVM_UNNAMED.as_ptr()) }; + let mul_narrow = + unsafe { LLVMBuildTrunc(self.builder, mul_wide, llvm_type_s32, LLVM_UNNAMED.as_ptr()) }; + self.emit_intrinsic( + c"llvm.sadd.sat.i32", + Some(dst), + Some(&ast::ScalarType::S32.into()), + vec![(mul_narrow, llvm_type_s32), (src3, llvm_type_s32)], + )?; + Ok(()) + } + + fn emit_set( + &mut self, + data: ptx_parser::SetData, + arguments: ptx_parser::SetArgs, + ) -> Result<(), TranslateError> { + let setp_result = self.emit_setp_impl(data.base, None, arguments.src1, arguments.src2)?; + self.setp_to_set(arguments.dst, data.dtype, setp_result)?; + Ok(()) + } + + fn emit_set_bool( + &mut self, + data: ptx_parser::SetBoolData, + arguments: ptx_parser::SetBoolArgs, + ) -> Result<(), TranslateError> { + let result = + self.emit_setp_bool_impl(data.base, arguments.src1, arguments.src2, arguments.src3)?; + self.setp_to_set(arguments.dst, data.dtype, result)?; + Ok(()) + } + + fn emit_setp_bool( + &mut self, + data: ast::SetpBoolData, + args: ast::SetpBoolArgs, + ) -> Result<(), TranslateError> { + let dst = self.emit_setp_bool_impl(data, args.src1, args.src2, args.src3)?; + self.resolver.register(args.dst1, dst); + Ok(()) + } + + fn emit_setp_bool_impl( + &mut self, + data: ptx_parser::SetpBoolData, + src1: SpirvWord, + src2: SpirvWord, + src3: SpirvWord, + ) -> Result { + let bool_result = self.emit_setp_impl(data.base, None, src1, src2)?; + let bool_result = if data.negate_src3 { + let constant = + unsafe { LLVMConstInt(LLVMIntTypeInContext(self.context, 1), u64::MAX, 0) }; + unsafe { LLVMBuildXor(self.builder, bool_result, constant, LLVM_UNNAMED.as_ptr()) } + } else { + bool_result + }; + let post_op = match data.bool_op { + ptx_parser::SetpBoolPostOp::Xor => LLVMBuildXor, + ptx_parser::SetpBoolPostOp::And => LLVMBuildAnd, + ptx_parser::SetpBoolPostOp::Or => LLVMBuildOr, + }; + let src3 = self.resolver.value(src3)?; + Ok(unsafe { post_op(self.builder, bool_result, src3, LLVM_UNNAMED.as_ptr()) }) + } + + fn setp_to_set( + &mut self, + dst: SpirvWord, + dtype: ast::ScalarType, + setp_result: LLVMValueRef, + ) -> Result<(), TranslateError> { + let llvm_dtype = get_scalar_type(self.context, dtype); + let zero = unsafe { LLVMConstNull(llvm_dtype) }; + let one = if dtype.kind() == ast::ScalarKind::Float { + unsafe { LLVMConstReal(llvm_dtype, 1.0) } + } else { + unsafe { LLVMConstInt(llvm_dtype, u64::MAX, 0) } + }; + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildSelect(self.builder, setp_result, one, zero, dst) + }); + Ok(()) + } + + // TODO: revisit this on gfx1250 which has native tanh support + fn 
emit_tanh( + &mut self, + data: ast::ScalarType, + arguments: ast::TanhArgs, + ) -> Result<(), TranslateError> { + let src = self.resolver.value(arguments.src)?; + let llvm_type = get_scalar_type(self.context, data); + let name = format!("__ocml_tanh_{}\0", LLVMTypeDisplay(data)); + let tanh = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(name.as_bytes()) }, + Some(arguments.dst), + Some(&data.into()), + vec![(src, llvm_type)], + )?; + // Not sure if it ultimately does anything + unsafe { LLVMZludaSetFastMathFlags(tanh, LLVMZludaFastMathApproxFunc) } + Ok(()) + } + /* // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` // Should be available in LLVM 19 @@ -2651,7 +2868,7 @@ fn get_function_type<'a>( _ => { check_multiple_return_types(return_args)?; get_array_type(context, &ast::Type::Scalar(ast::ScalarType::B32), 2)? - }, + } }; Ok(unsafe { diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs index 6420c79..84bb442 100644 --- a/ptx/src/pass/replace_instructions_with_function_calls.rs +++ b/ptx/src/pass/replace_instructions_with_function_calls.rs @@ -94,6 +94,46 @@ fn run_instruction<'input>( instruction: ptx_parser::Instruction, ) -> Result, TranslateError> { Ok(match instruction { + i @ ptx_parser::Instruction::Sqrt { + data: + ast::RcpData { + kind: ast::RcpKind::Approx, + type_: ast::ScalarType::F32, + flush_to_zero: None | Some(false), + }, + .. + } => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?, + i @ ptx_parser::Instruction::Rsqrt { + data: + ast::TypeFtz { + type_: ast::ScalarType::F32, + flush_to_zero: None | Some(false), + }, + .. + } => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?, + i @ ptx_parser::Instruction::Rcp { + data: + ast::RcpData { + kind: ast::RcpKind::Approx, + type_: ast::ScalarType::F32, + flush_to_zero: None | Some(false), + }, + .. + } => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?, + i @ ptx_parser::Instruction::Ex2 { + data: + ast::TypeFtz { + type_: ast::ScalarType::F32, + flush_to_zero: None | Some(false), + }, + .. + } => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?, + i @ ptx_parser::Instruction::Lg2 { + data: ast::FlushToZero { + flush_to_zero: false, + }, + .. + } => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?, i @ ptx_parser::Instruction::Activemask { .. } => { to_call(resolver, fn_declarations, "activemask".into(), i)? } @@ -116,7 +156,12 @@ fn run_instruction<'input>( ptx_parser::Reduction::And => "bar_red_and_pred", ptx_parser::Reduction::Or => "bar_red_or_pred", }; - to_call(resolver, fn_declarations, name.into(), ptx_parser::Instruction::BarRed { data, arguments })? + to_call( + resolver, + fn_declarations, + name.into(), + ptx_parser::Instruction::BarRed { data, arguments }, + )? 
} ptx_parser::Instruction::ShflSync { data, arguments } => { let mode = match data.mode { diff --git a/ptx/src/test/ll/abs.ll b/ptx/src/test/ll/abs.ll new file mode 100644 index 0000000..026c854 --- /dev/null +++ b/ptx/src/test/ll/abs.ll @@ -0,0 +1,34 @@ +define amdgpu_kernel void @abs(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 { + %"33" = alloca i64, align 8, addrspace(5) + %"34" = alloca i64, align 8, addrspace(5) + %"35" = alloca i32, align 4, addrspace(5) + %"36" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"30" + +"30": ; preds = %1 + %"37" = load i64, ptr addrspace(4) %"31", align 4 + store i64 %"37", ptr addrspace(5) %"33", align 4 + %"38" = load i64, ptr addrspace(4) %"32", align 4 + store i64 %"38", ptr addrspace(5) %"34", align 4 + %"40" = load i64, ptr addrspace(5) %"33", align 4 + %"45" = inttoptr i64 %"40" to ptr + %"39" = load i32, ptr %"45", align 4 + store i32 %"39", ptr addrspace(5) %"35", align 4 + %"42" = load i32, ptr addrspace(5) %"35", align 4 + %"41" = call i32 @llvm.abs.i32(i32 %"42", i1 false) + store i32 %"41", ptr addrspace(5) %"36", align 4 + %"43" = load i64, ptr addrspace(5) %"34", align 4 + %"44" = load i32, ptr addrspace(5) %"36", align 4 + %"46" = inttoptr i64 %"43" to ptr + store i32 %"44", ptr %"46", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.abs.i32(i32, i1 immarg) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } \ No newline at end of file diff --git a/ptx/src/test/ll/shr.ll b/ptx/src/test/ll/shr.ll index bc0acae..da665e4 100644 --- a/ptx/src/test/ll/shr.ll +++ b/ptx/src/test/ll/shr.ll @@ -17,8 +17,9 @@ define amdgpu_kernel void @shr(ptr addrspace(4) byref(i64) %"31", ptr addrspace( %"38" = load i32, ptr %"44", align 4 store i32 %"38", ptr addrspace(5) %"35", align 4 %"41" = load i32, ptr addrspace(5) %"35", align 4 - %2 = ashr i32 %"41", 1 - %"40" = select i1 false, i32 0, i32 %2 + %2 = ashr i32 %"41", 31 + %3 = ashr i32 %"41", 1 + %"40" = select i1 false, i32 %2, i32 %3 store i32 %"40", ptr addrspace(5) %"35", align 4 %"42" = load i64, ptr addrspace(5) %"34", align 8 %"43" = load i32, ptr addrspace(5) %"35", align 4 diff --git a/ptx/src/test/ll/shr_oob.ll b/ptx/src/test/ll/shr_oob.ll new file mode 100644 index 0000000..cfe2532 --- /dev/null +++ b/ptx/src/test/ll/shr_oob.ll @@ -0,0 +1,31 @@ +define amdgpu_kernel void @shr_oob(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 { + %"33" = alloca i64, align 8, addrspace(5) + %"34" = alloca i64, align 8, addrspace(5) + %"35" = alloca i16, align 2, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"30" + +"30": ; preds = %1 + %"36" = load i64, ptr addrspace(4) %"31", align 4 + store i64 %"36", ptr addrspace(5) %"33", align 4 + %"37" = load i64, ptr addrspace(4) %"32", align 4 + store i64 %"37", ptr addrspace(5) %"34", align 4 + %"39" = load i64, ptr addrspace(5) %"33", align 4 + %"44" = inttoptr i64 %"39" to ptr + %"38" = load i16, ptr %"44", align 2 + store i16 %"38", ptr addrspace(5) %"35", align 2 + %"41" = load i16, ptr addrspace(5) %"35", align 2 + %2 = ashr i16 %"41", 15 + %3 = ashr i16 %"41", 16 + %"40" = select i1 true, i16 %2, i16 %3 + store i16 %"40", ptr addrspace(5) %"35", align 2 
+ %"42" = load i64, ptr addrspace(5) %"34", align 4 + %"43" = load i16, ptr addrspace(5) %"35", align 2 + %"45" = inttoptr i64 %"42" to ptr + store i16 %"43", ptr %"45", align 2 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/ll/tanh.ll b/ptx/src/test/ll/tanh.ll new file mode 100644 index 0000000..71f8af5 --- /dev/null +++ b/ptx/src/test/ll/tanh.ll @@ -0,0 +1,31 @@ +define amdgpu_kernel void @tanh(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #0 { + %"32" = alloca i64, align 8, addrspace(5) + %"33" = alloca i64, align 8, addrspace(5) + %"34" = alloca float, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"29" + +"29": ; preds = %1 + %"35" = load i64, ptr addrspace(4) %"30", align 4 + store i64 %"35", ptr addrspace(5) %"32", align 4 + %"36" = load i64, ptr addrspace(4) %"31", align 4 + store i64 %"36", ptr addrspace(5) %"33", align 4 + %"38" = load i64, ptr addrspace(5) %"32", align 4 + %"43" = inttoptr i64 %"38" to ptr + %"37" = load float, ptr %"43", align 4 + store float %"37", ptr addrspace(5) %"34", align 4 + %"40" = load float, ptr addrspace(5) %"34", align 4 + %"39" = call afn float @__ocml_tanh_f32(float %"40") + store float %"39", ptr addrspace(5) %"34", align 4 + %"41" = load i64, ptr addrspace(5) %"33", align 4 + %"42" = load float, ptr addrspace(5) %"34", align 4 + %"44" = inttoptr i64 %"41" to ptr + store float %"42", ptr %"44", align 4 + ret void +} + +declare float @__ocml_tanh_f32(float) + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/abs.ptx b/ptx/src/test/spirv_run/abs.ptx new file mode 100644 index 0000000..0650ea5 --- /dev/null +++ b/ptx/src/test/spirv_run/abs.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry abs( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .s32 temp1; + .reg .s32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.s32 temp1, [in_addr]; + abs.s32 temp2, temp1; + st.s32 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index df85a24..381a224 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -159,6 +159,7 @@ test_ptx!( ); test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]); test_ptx!(shr, [-2i32], [-1i32]); +test_ptx!(shr_oob, [-32768i16], [-1i16]); test_ptx!(or, [1u64, 2u64], [3u64]); test_ptx!(sub, [2u64], [1u64]); test_ptx!(min, [555i32, 444i32], [444i32]); @@ -180,6 +181,7 @@ test_ptx!( [0b1_00000000_01000000000000000000000u32] ); test_ptx!(constant_f32, [10f32], [5f32]); +test_ptx!(abs, [i32::MIN], [i32::MIN]); test_ptx!(constant_negative, [-101i32], [101i32]); test_ptx!(and, [6u32, 3u32], [2u32]); test_ptx!(selp, [100u16, 200u16], [200u16]); @@ -296,6 +298,7 @@ test_ptx!( ); test_ptx!(multiple_return, [5u32], [6u32, 123u32]); test_ptx!(warp_sz, [0u8], [32u8]); +test_ptx!(tanh, [f32::INFINITY], [1.0f32]); test_ptx!(nanosleep, [0u64], [0u64]); diff --git a/ptx/src/test/spirv_run/shr_oob.ptx b/ptx/src/test/spirv_run/shr_oob.ptx new file mode 100644 index 0000000..07d65bf --- /dev/null +++ 
b/ptx/src/test/spirv_run/shr_oob.ptx @@ -0,0 +1,21 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry shr_oob( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .s16 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.s16 temp, [in_addr]; + shr.s16 temp, temp, 16; + st.s16 [out_addr], temp; + ret; +} diff --git a/ptx/src/test/spirv_run/tanh.ptx b/ptx/src/test/spirv_run/tanh.ptx new file mode 100644 index 0000000..4eb7805 --- /dev/null +++ b/ptx/src/test/spirv_run/tanh.ptx @@ -0,0 +1,21 @@ +.version 7.0 +.target sm_75 +.address_size 64 + +.visible .entry tanh( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 temp1; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 temp1, [in_addr]; + tanh.approx.f32 temp1, temp1; + st.f32 [out_addr], temp1; + ret; +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index dee15ce..7e99d6b 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -428,6 +428,44 @@ ptx_parser_macros::generate_instruction_type!( }, } }, + Set { + data: SetData, + arguments: { + dst: { + repr: T, + type: Type::from(data.dtype) + }, + src1: { + repr: T, + type: Type::from(data.base.type_), + }, + src2: { + repr: T, + type: Type::from(data.base.type_), + } + } + }, + SetBool { + data: SetBoolData, + arguments: { + dst: { + repr: T, + type: Type::from(data.dtype) + }, + src1: { + repr: T, + type: Type::from(data.base.base.type_), + }, + src2: { + repr: T, + type: Type::from(data.base.base.type_), + }, + src3: { + repr: T, + type: Type::from(ScalarType::Pred) + } + } + }, Setp { data: SetpData, arguments: { @@ -562,6 +600,14 @@ ptx_parser_macros::generate_instruction_type!( src2: T } }, + Tanh { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src: T + } + }, } ); @@ -1247,6 +1293,11 @@ pub struct Mul24Details { pub control: Mul24Control, } +pub struct SetData { + pub dtype: ScalarType, + pub base: SetpData, +} + pub struct SetpData { pub type_: ScalarType, pub flush_to_zero: Option, @@ -1288,6 +1339,12 @@ impl SetpData { } } + +pub struct SetBoolData { + pub dtype: ScalarType, + pub base: SetpBoolData, +} + pub struct SetpBoolData { pub base: SetpData, pub bool_op: SetpBoolPostOp, diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index e2c87fc..4572842 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -442,7 +442,7 @@ fn directive<'a, 'input>( )), ( any, - take_till(1.., |(token, _)| match token { + take_till(1.., |(token, _)| match token { // visibility Token::DotExtern | Token::DotVisible | Token::DotWeak // methods @@ -2201,6 +2201,43 @@ derive_parser!( .rnd: RawRoundingMode = { .rn }; ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/#comparison-and-selection-instructions-set + // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-comparison-instructions-set + set.CmpOp{.ftz}.dtype.stype d, a, b => { + let base = ast::SetpData::try_parse(state, cmpop, ftz, stype); + let data = ast::SetData { + base, dtype + }; + ast::Instruction::Set { + data, + arguments: SetArgs { dst: d, src1: a, src2: b } + } + } + set.CmpOp.BoolOp{.ftz}.dtype.stype d, a, b, {!}c => { + let (negate_src3, c) = c; + let base = ast::SetpData::try_parse(state, cmpop, ftz, stype); + let base = ast::SetpBoolData { + base, + bool_op: boolop, + negate_src3 + }; + let data = 
ast::SetBoolData { + base, dtype + }; + ast::Instruction::SetBool { + data, + arguments: SetBoolArgs { dst: d, src1: a, src2: b, src3: c } + } + } + + .CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge, + .lo, .ls, .hi, .hs, // signed + .equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only + .BoolOp: SetpBoolPostOp = { .and, .or, .xor }; + .dtype: ScalarType = { .u32, .s32, .f32 }; + .stype: ScalarType = { .b16, .b32, .b64, .u16, .u32, .u64, .s16, .s32, .s64, .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp setp.CmpOp{.ftz}.type p[|q], a, b => { @@ -3509,6 +3546,16 @@ derive_parser!( arguments: NanosleepArgs { src: t } } } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-tanh + // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-tanh + tanh.approx.type d, a => { + Instruction::Tanh { + data: type_, + arguments: TanhArgs { dst: d, src: a } + } + } + .type: ScalarType = { .f32, .f16, .f16x2, .bf16, .bf16x2 }; ); #[cfg(test)]
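A minimal host-side C++ sketch (separate from the patch) of the scaling identities that the *_approx_f32 helpers in zluda_ptx_impl.cpp rely on: a subnormal input scaled by 2^24 (DENORMAL_TO_NORMAL_FACTOR_F32) is corrected by 2^-12 after sqrt, by 2^12 after rsqrt, and by subtracting 24 after lg2. The input 1e-40f is just an assumed subnormal chosen for illustration.

#include <cmath>
#include <cstdio>

int main() {
    const float scale = 16777216.0f;  // 2^24, DENORMAL_TO_NORMAL_FACTOR_F32
    const float x = 1e-40f;           // subnormal f32, chosen for illustration
    // sqrt(x * 2^24) = sqrt(x) * 2^12, so multiply by 2^-12 to undo the scaling
    float sqrt_fixed = std::sqrt(x * scale) * 0.000244140625f;
    // rsqrt(x * 2^24) = rsqrt(x) * 2^-12, so multiply by 2^12 to undo it
    float rsqrt_fixed = (1.0f / std::sqrt(x * scale)) * 4096.0f;
    // log2(x * 2^24) = log2(x) + 24, so subtract 24 to undo it
    float lg2_fixed = std::log2(x * scale) - 24.0f;
    std::printf("sqrt:  %g vs %g\n", sqrt_fixed, std::sqrt((double)x));
    std::printf("rsqrt: %g vs %g\n", rsqrt_fixed, 1.0 / std::sqrt((double)x));
    std::printf("lg2:   %g vs %g\n", lg2_fixed, std::log2((double)x));
    return 0;
}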
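Similarly, a small C++ sketch of why emit_shr now pre-computes src1 >> (width - 1) as the out-of-range value for arithmetic shifts: PTX clamps the shift amount to the data width, and a clamped shr.sN yields the sign fill rather than 0, as exercised by the new shr_oob test (shifting -32768i16 right by 16 gives -1i16). The helper name shr_s16_clamped is hypothetical; the sketch assumes arithmetic right shift of signed integers, which holds on the targeted platforms.

#include <cstdint>
#include <cstdio>

// Hypothetical scalar model of the LLVM select emitted for shr.s16:
// shifting by width - 1 replicates the sign bit, which is the value PTX
// specifies when the shift amount is out of range.
int16_t shr_s16_clamped(int16_t x, uint32_t amount) {
    int16_t sign_fill = (int16_t)(x >> 15); // -1 for negative x, 0 otherwise
    return amount >= 16 ? sign_fill : (int16_t)(x >> amount);
}

int main() {
    std::printf("%d\n", shr_s16_clamped(-32768, 16)); // prints -1, matching shr_oob
    return 0;
}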