Assorted instruction fixes (#423)

This fixes transcendentals and some other buggy instructions exposed by `ptx_tests` (abs, neg). Add (slow - hardware limitation) tanh.
Only two remaining incorrect instructions are div and sqrt with non-default rounding, but this commit is already bloated enough
This commit is contained in:
Andrzej Janik
2025-07-24 00:50:35 +02:00
committed by GitHub
parent 119b635b9d
commit 3746079b1a
16 changed files with 753 additions and 108 deletions

Binary file not shown.

View File

@ -1,10 +1,11 @@
// Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17`
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
// /opt/rocm/llvm/bin/clang -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
#include <cstddef>
#include <cstdint>
#include <bit>
#include <cmath>
#include <hip/amd_detail/amd_device_functions.h>
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
@ -163,7 +164,7 @@ extern "C"
int32_t __ockl_wgred_and_i32(int32_t) __device__;
int32_t __ockl_wgred_or_i32(int32_t) __device__;
#define BAR_RED_IMPL(reducer) \
#define BAR_RED_IMPL(reducer) \
bool FUNC(bar_red_##reducer##_pred)(uint32_t barrier __attribute__((unused)), bool predicate, bool invert_predicate) \
{ \
/* TODO: handle barrier */ \
@ -173,7 +174,8 @@ extern "C"
BAR_RED_IMPL(and);
BAR_RED_IMPL(or);
struct ShflSyncResult {
struct ShflSyncResult
{
uint32_t output;
bool in_bounds;
};
@ -191,7 +193,7 @@ extern "C"
// intrinsics, it is always 31 for idx, bfly, and down, and 0 for up. This is used for the
// bounds check.
#define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \
#define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \
ShflSyncResult FUNC(shfl_sync_##mode##_b32_pred)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask __attribute__((unused))) \
{ \
int32_t section_mask = (opts >> 8) & 0b11111; \
@ -201,10 +203,11 @@ extern "C"
int32_t subsection_end = subsection | (~section_mask & warp_end); \
int32_t idx = calculate_index; \
bool out_of_bounds = idx CMP subsection_end; \
if (out_of_bounds) { \
if (out_of_bounds) \
{ \
idx = self; \
} \
int32_t output = __builtin_amdgcn_ds_bpermute(idx<<2, (int32_t)input); \
int32_t output = __builtin_amdgcn_ds_bpermute(idx << 2, (int32_t)input); \
return {(uint32_t)output, !out_of_bounds}; \
} \
\
@ -212,15 +215,15 @@ extern "C"
{ \
return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).output; \
}
// We are using the HIP __shfl intrinsics to implement these, rather than the __shfl_sync
// intrinsics, as those only add an assertion checking that the membermask is used correctly.
// They do not return the result of the range check, so we must replicate that logic here.
SHFL_SYNC_IMPL(up, self - delta, <);
SHFL_SYNC_IMPL(down, self + delta, >);
SHFL_SYNC_IMPL(bfly, self ^ delta, >);
SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >);
SHFL_SYNC_IMPL(up, self - delta, <);
SHFL_SYNC_IMPL(down, self + delta, >);
SHFL_SYNC_IMPL(bfly, self ^ delta, >);
SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >);
DECLARE_ATTR(uint32_t, CLOCK_RATE);
void FUNC(nanosleep_u32)(uint32_t nanoseconds) {
@ -254,4 +257,82 @@ extern "C"
(void)char_size;
__assert_fail((const char *)message, (const char *)file, line, (const char *)function);
}
// * Smallest denormal is 1.4 × 10^-45
// * Smallest normal is ~1.175494351 × 10^(-38)
// * Now, 1.175494351×10^-38 / 1.4 × 10^-45 = 8396388 + 31/140
// * Next power of 2 is 16777216
const float DENORMAL_TO_NORMAL_FACTOR_F32 = 16777216.0f;
// * Largest subnormal is ~1.175494210692441e × 10^(-38)
// * Then any value equal or larger than following will produce subnormals: 8.50706018714406320806444272332455743547934627837873057975602739772164... × 10^37
const float RCP_DENORMAL_OUTPUT = 8.50706018714406320806444272332455743547934627837873057975602739772164e37f;
const float REVERSE_DENORMAL_TO_NORMAL_FACTOR_F32 = 0.029387360490963111877208252592662410455594571842846914442095471744599661631813495980086003637902577995683214210345151992265999035207077609582844f;
float FUNC(sqrt_approx_f32)(float x)
{
bool is_subnormal = __builtin_isfpclass(x, __FPCLASS_NEGSUBNORMAL | __FPCLASS_POSSUBNORMAL);
float input = x;
if (is_subnormal)
input = x * DENORMAL_TO_NORMAL_FACTOR_F32;
float value = __builtin_amdgcn_sqrtf(input);
if (is_subnormal)
return value * 0.000244140625f;
else
return value;
}
float FUNC(rsqrt_approx_f32)(float x)
{
bool is_subnormal = __builtin_isfpclass(x, __FPCLASS_NEGSUBNORMAL | __FPCLASS_POSSUBNORMAL);
float input = x;
if (is_subnormal)
input = x * DENORMAL_TO_NORMAL_FACTOR_F32;
float value = __builtin_amdgcn_rsqf(input);
if (is_subnormal)
return value * 4096.0f;
else
return value;
}
float FUNC(rcp_approx_f32)(float x)
{
float factor = 1.0f;
if (__builtin_isfpclass(x, __FPCLASS_NEGSUBNORMAL | __FPCLASS_POSSUBNORMAL))
{
factor = DENORMAL_TO_NORMAL_FACTOR_F32;
}
if (std::fabs(x) >= RCP_DENORMAL_OUTPUT)
{
factor = REVERSE_DENORMAL_TO_NORMAL_FACTOR_F32;
}
return __builtin_amdgcn_rcpf(x * factor) * factor;
}
// When x = -126, exp2(x) = 2^(-126) ≈ 1.175494351 × 10^(-38),
// which is the smallest normalized number in FP32
float FUNC(ex2_approx_f32)(float x)
{
bool special_handling = x < -126.0f;
float input = x;
if (special_handling)
input *= 0.5f;
float result = __builtin_amdgcn_exp2f(input);
if (special_handling)
return result * result;
else
return result;
}
float FUNC(lg2_approx_f32)(float x)
{
bool is_subnormal = __builtin_isfpclass(x, __FPCLASS_NEGSUBNORMAL | __FPCLASS_POSSUBNORMAL);
float input = x;
if (is_subnormal)
input = x * DENORMAL_TO_NORMAL_FACTOR_F32;
float value = __builtin_amdgcn_logf(input);
if (is_subnormal)
return value - 24.0f;
else
return value;
}
}

View File

@ -164,6 +164,8 @@ fn run_instruction<'input>(
| ast::Instruction::Ret { .. }
| ast::Instruction::Rsqrt { .. }
| ast::Instruction::Selp { .. }
| ast::Instruction::Set { .. }
| ast::Instruction::SetBool { .. }
| ast::Instruction::Setp { .. }
| ast::Instruction::SetpBool { .. }
| ast::Instruction::ShflSync { .. }
@ -183,6 +185,7 @@ fn run_instruction<'input>(
data: ast::ArithDetails::Integer(..),
..
}
| ast::Instruction::Tanh { .. }
| ast::Instruction::Trap {}
| ast::Instruction::Xor { .. } => result.push(Statement::Instruction(instruction)),
ast::Instruction::Add {

View File

@ -221,12 +221,25 @@ impl InstructionModes {
)
}
fn from_rcp(data: ast::RcpData) -> InstructionModes {
fn from_rtz_special(data: ast::RcpData) -> InstructionModes {
let rounding = match data.kind {
ast::RcpKind::Approx => None,
ast::RcpKind::Compliant(rnd) => Some(RoundingMode::from_ast(rnd)),
};
let denormal = data.flush_to_zero.map(DenormalMode::from_ftz);
let denormal = match (
data.kind == ast::RcpKind::Approx,
data.flush_to_zero == Some(true),
) {
// If we are approximate and flushing then we compile to V_RSQ_F32
// or V_SQRT_F32 which ignores prevailing denormal mode and flushes anyway
(true, true) => None,
// If we are approximate and flushing the V_RSQ_F32/V_SQRT_F32
// ignores ftz mode, but we implement instruction in terms of fmuls
// which must preserve denormals
(true, false) => Some(DenormalMode::Preserve),
// For full precision we set denormal mode accordingly
(false, ftz) => Some(DenormalMode::from_ftz(ftz)),
};
InstructionModes::new(data.type_, denormal, rounding)
}
@ -1780,6 +1793,11 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
match inst {
// TODO: review it when implementing virtual calls
ast::Instruction::Call { .. }
// abs and neg have special handling in AMD GPU ISA. They get compiled
// down to instruction argument modifiers, floating point flags have no
// effect on it. We handle it during LLVM bitcode emission
| ast::Instruction::Abs { .. }
| ast::Instruction::Neg {.. }
| ast::Instruction::Mov { .. }
| ast::Instruction::Ld { .. }
| ast::Instruction::St { .. }
@ -1856,7 +1874,31 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
data: ast::ArithDetails::Float(data),
..
} => InstructionModes::from_arith_float(data),
ast::Instruction::Setp {
ast::Instruction::Set {
data: ast::SetData{
base: ast::SetpData {
type_,
flush_to_zero,
..
},
..
},
..
}
| ast::Instruction::SetBool {
data: ast::SetBoolData {
base: ast::SetpBoolData {
base: ast::SetpData {
type_,
flush_to_zero,
..
},
..
},
..
}, ..
}
| ast::Instruction::Setp {
data:
ast::SetpData {
type_,
@ -1878,34 +1920,6 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
},
..
}
| ast::Instruction::Neg {
data: ast::TypeFtz {
type_,
flush_to_zero,
},
..
}
| ast::Instruction::Ex2 {
data: ast::TypeFtz {
type_,
flush_to_zero,
},
..
}
| ast::Instruction::Rsqrt {
data: ast::TypeFtz {
type_,
flush_to_zero,
},
..
}
| ast::Instruction::Abs {
data: ast::TypeFtz {
type_,
flush_to_zero,
},
..
}
| ast::Instruction::Min {
data:
ast::MinMaxDetails::Float(ast::MinMaxFloat {
@ -1945,12 +1959,29 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
)
}
ast::Instruction::Sin { data, .. }
| ast::Instruction::Cos { data, .. }
| ast::Instruction::Lg2 { data, .. } => InstructionModes::from_ftz_f32(data.flush_to_zero),
| ast::Instruction::Cos { data, .. } => InstructionModes::from_ftz_f32(data.flush_to_zero),
ast::Instruction::Rcp { data, .. } | ast::Instruction::Sqrt { data, .. } => {
InstructionModes::from_rcp(*data)
InstructionModes::from_rtz_special(*data)
}
ast::Instruction::Rsqrt { data, .. }
| ast::Instruction::Ex2 { data, .. } => {
let data = ast::RcpData {
type_: data.type_,
flush_to_zero: data.flush_to_zero,
kind: ast::RcpKind::Approx,
};
InstructionModes::from_rtz_special(data)
},
ast::Instruction::Lg2 { data, .. } => {
let data = ast::RcpData {
type_: ast::ScalarType::F32,
flush_to_zero: Some(data.flush_to_zero),
kind: ast::RcpKind::Approx,
};
InstructionModes::from_rtz_special(data)
},
ast::Instruction::Cvt { data, .. } => InstructionModes::from_cvt(data),
ast::Instruction::Tanh { data, .. } => InstructionModes::from_ftz(*data, Some(false)),
}
}

View File

@ -27,7 +27,7 @@
use std::array::TryFromSliceError;
use std::convert::TryInto;
use std::ffi::{CStr, NulError};
use std::{i8, ptr};
use std::{i8, ptr, u64};
use super::*;
use crate::pass::*;
@ -471,8 +471,10 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments),
ast::Instruction::Mul24 { data, arguments } => self.emit_mul24(data, arguments),
ast::Instruction::Set { data, arguments } => self.emit_set(data, arguments),
ast::Instruction::SetBool { data, arguments } => self.emit_set_bool(data, arguments),
ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments),
ast::Instruction::SetpBool { .. } => todo!(),
ast::Instruction::SetpBool { data, arguments } => self.emit_setp_bool(data, arguments),
ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments),
ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments),
ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
@ -506,10 +508,13 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments),
ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments),
ast::Instruction::PrmtSlow { .. } => todo!(),
ast::Instruction::PrmtSlow { .. } => {
Err(error_todo_msg("PrmtSlow is not implemented yet"))
}
ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments),
ast::Instruction::Membar { data } => self.emit_membar(data),
ast::Instruction::Trap {} => todo!(),
ast::Instruction::Trap {} => Err(error_todo_msg("Trap is not implemented yet")),
ast::Instruction::Tanh { data, arguments } => self.emit_tanh(data, arguments),
// replaced by a function call
ast::Instruction::Bfe { .. }
| ast::Instruction::Bar { .. }
@ -713,18 +718,31 @@ impl<'a> MethodEmitContext<'a> {
arguments: ast::AddArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let builder = self.builder;
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let fn_ = match data {
ast::ArithDetails::Integer(ast::ArithInteger {
saturate: true,
type_,
}) => return self.emit_add_sat(type_, arguments),
}) => {
let op = if type_.kind() == ast::ScalarKind::Signed {
"sadd"
} else {
"uadd"
};
return self.emit_intrinsic_saturate(
op,
type_,
arguments.dst,
arguments.src1,
arguments.src2,
);
}
ast::ArithDetails::Integer(ast::ArithInteger {
saturate: false, ..
}) => LLVMBuildAdd,
ast::ArithDetails::Float(..) => LLVMBuildFAdd,
};
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
fn_(builder, src1, src2, dst)
});
@ -1353,7 +1371,18 @@ impl<'a> MethodEmitContext<'a> {
arguments: ptx_parser::SubArgs<SpirvWord>,
) -> Result<(), TranslateError> {
if arith_integer.saturate {
todo!()
let op = if arith_integer.type_.kind() == ast::ScalarKind::Signed {
"ssub"
} else {
"usub"
};
return self.emit_intrinsic_saturate(
op,
arith_integer.type_,
arguments.dst,
arguments.src1,
arguments.src2,
);
}
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
@ -1365,12 +1394,9 @@ impl<'a> MethodEmitContext<'a> {
fn emit_sub_float(
&mut self,
arith_float: ptx_parser::ArithFloat,
_arith_float: ptx_parser::ArithFloat,
arguments: ptx_parser::SubArgs<SpirvWord>,
) -> Result<(), TranslateError> {
if arith_float.saturate {
todo!()
}
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
@ -1430,25 +1456,39 @@ impl<'a> MethodEmitContext<'a> {
arguments: ptx_parser::NegArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let src = self.resolver.value(arguments.src)?;
let llvm_fn = if data.type_.kind() == ptx_parser::ScalarKind::Float {
let is_floating_point = data.type_.kind() == ptx_parser::ScalarKind::Float;
let llvm_fn = if is_floating_point {
LLVMBuildFNeg
} else {
LLVMBuildNeg
};
self.resolver.with_result(arguments.dst, |dst| unsafe {
llvm_fn(self.builder, src, dst)
});
if is_floating_point && data.flush_to_zero == Some(true) {
let negated = unsafe { llvm_fn(self.builder, src, LLVM_UNNAMED.as_ptr()) };
let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(data.type_));
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
Some(&data.type_.into()),
vec![(negated, get_scalar_type(self.context, data.type_))],
)?;
} else {
self.resolver.with_result(arguments.dst, |dst| unsafe {
llvm_fn(self.builder, src, dst)
});
}
Ok(())
}
fn emit_not(
&mut self,
_data: ptx_parser::ScalarType,
type_: ptx_parser::ScalarType,
arguments: ptx_parser::NotArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let src = self.resolver.value(arguments.src)?;
let type_ = get_scalar_type(self.context, type_);
let constant = unsafe { LLVMConstInt(type_, u64::MAX, 0) };
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildNot(self.builder, src, dst)
LLVMBuildXor(self.builder, src, constant, dst)
});
Ok(())
}
@ -1458,15 +1498,29 @@ impl<'a> MethodEmitContext<'a> {
data: ptx_parser::SetpData,
arguments: ptx_parser::SetpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
if arguments.dst2.is_some() {
todo!()
let dst = self.emit_setp_impl(data, arguments.dst2, arguments.src1, arguments.src2)?;
self.resolver.register(arguments.dst1, dst);
Ok(())
}
fn emit_setp_impl(
&mut self,
data: ptx_parser::SetpData,
dst2: Option<SpirvWord>,
src1: SpirvWord,
src2: SpirvWord,
) -> Result<LLVMValueRef, TranslateError> {
if dst2.is_some() {
return Err(error_todo_msg(
"setp with two dst arguments not yet supported",
));
}
match data.cmp_op {
ptx_parser::SetpCompareOp::Integer(setp_compare_int) => {
self.emit_setp_int(setp_compare_int, arguments)
self.emit_setp_int(setp_compare_int, src1, src2)
}
ptx_parser::SetpCompareOp::Float(setp_compare_float) => {
self.emit_setp_float(setp_compare_float, arguments)
self.emit_setp_float(setp_compare_float, src1, src2)
}
}
}
@ -1474,8 +1528,9 @@ impl<'a> MethodEmitContext<'a> {
fn emit_setp_int(
&mut self,
setp: ptx_parser::SetpCompareInt,
arguments: ptx_parser::SetpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
src1: SpirvWord,
src2: SpirvWord,
) -> Result<LLVMValueRef, TranslateError> {
let op = match setp {
ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ,
ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE,
@ -1488,19 +1543,17 @@ impl<'a> MethodEmitContext<'a> {
ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT,
ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE,
};
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst1, |dst1| unsafe {
LLVMBuildICmp(self.builder, op, src1, src2, dst1)
});
Ok(())
let src1 = self.resolver.value(src1)?;
let src2 = self.resolver.value(src2)?;
Ok(unsafe { LLVMBuildICmp(self.builder, op, src1, src2, LLVM_UNNAMED.as_ptr()) })
}
fn emit_setp_float(
&mut self,
setp: ptx_parser::SetpCompareFloat,
arguments: ptx_parser::SetpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
src1: SpirvWord,
src2: SpirvWord,
) -> Result<LLVMValueRef, TranslateError> {
let op = match setp {
ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ,
ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE,
@ -1517,12 +1570,9 @@ impl<'a> MethodEmitContext<'a> {
ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD,
ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO,
};
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst1, |dst1| unsafe {
LLVMBuildFCmp(self.builder, op, src1, src2, dst1)
});
Ok(())
let src1 = self.resolver.value(src1)?;
let src2 = self.resolver.value(src2)?;
Ok(unsafe { LLVMBuildFCmp(self.builder, op, src1, src2, LLVM_UNNAMED.as_ptr()) })
}
fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> {
@ -1935,15 +1985,34 @@ impl<'a> MethodEmitContext<'a> {
data: ptx_parser::ShrData,
arguments: ptx_parser::ShrArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let shift_fn = match data.kind {
ptx_parser::RightShiftKind::Arithmetic => LLVMBuildAShr,
ptx_parser::RightShiftKind::Logical => LLVMBuildLShr,
let llvm_type = get_scalar_type(self.context, data.type_);
let (out_of_range, shift_fn): (
*mut LLVMValue,
unsafe extern "C" fn(
LLVMBuilderRef,
LLVMValueRef,
LLVMValueRef,
*const i8,
) -> LLVMValueRef,
) = match data.kind {
ptx_parser::RightShiftKind::Logical => {
(unsafe { LLVMConstNull(llvm_type) }, LLVMBuildLShr)
}
ptx_parser::RightShiftKind::Arithmetic => {
let src1 = self.resolver.value(arguments.src1)?;
let shift_size =
unsafe { LLVMConstInt(llvm_type, (data.type_.size_of() as u64 * 8) - 1, 0) };
let out_of_range =
unsafe { LLVMBuildAShr(self.builder, src1, shift_size, LLVM_UNNAMED.as_ptr()) };
(out_of_range, LLVMBuildAShr)
}
};
self.emit_shift(
data.type_,
arguments.dst,
arguments.src1,
arguments.src2,
out_of_range,
shift_fn,
)
}
@ -1953,11 +2022,13 @@ impl<'a> MethodEmitContext<'a> {
type_: ptx_parser::ScalarType,
arguments: ptx_parser::ShlArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_type = get_scalar_type(self.context, type_);
self.emit_shift(
type_,
arguments.dst,
arguments.src1,
arguments.src2,
unsafe { LLVMConstNull(llvm_type) },
LLVMBuildShl,
)
}
@ -1968,6 +2039,7 @@ impl<'a> MethodEmitContext<'a> {
dst: SpirvWord,
src1: SpirvWord,
src2: SpirvWord,
out_of_range_value: LLVMValueRef,
llvm_fn: unsafe extern "C" fn(
LLVMBuilderRef,
LLVMValueRef,
@ -1995,7 +2067,6 @@ impl<'a> MethodEmitContext<'a> {
)
};
let llvm_type = get_scalar_type(self.context, type_);
let zero = unsafe { LLVMConstNull(llvm_type) };
let normalized_shift_size = if type_.layout().size() >= 4 {
unsafe {
LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr())
@ -2012,7 +2083,7 @@ impl<'a> MethodEmitContext<'a> {
)
};
self.resolver.with_result(dst, |dst| unsafe {
LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst)
LLVMBuildSelect(self.builder, should_clamp, out_of_range_value, shifted, dst)
});
Ok(())
}
@ -2207,7 +2278,19 @@ impl<'a> MethodEmitContext<'a> {
},
)
}
ptx_parser::MadDetails::Integer { saturate: true, .. } => return Err(error_todo()),
ptx_parser::MadDetails::Integer {
saturate: true,
control: ast::MulIntControl::High,
type_: ast::ScalarType::S32,
} => {
return self.emit_mad_hi_sat_s32(
arguments.dst,
(arguments.src1, arguments.src2, arguments.src3),
);
}
ptx_parser::MadDetails::Integer { saturate: true, .. } => {
return Err(error_unreachable())
}
ptx_parser::MadDetails::Integer { type_, control, .. } => {
ast::MulDetails::Integer { control, type_ }
}
@ -2281,7 +2364,8 @@ impl<'a> MethodEmitContext<'a> {
) -> Result<(), TranslateError> {
let llvm_type = get_scalar_type(self.context, data.type_);
let src = self.resolver.value(arguments.src)?;
let (prefix, intrinsic_arguments) = if data.type_.kind() == ast::ScalarKind::Float {
let is_floating_point = data.type_.kind() == ast::ScalarKind::Float;
let (prefix, intrinsic_arguments) = if is_floating_point {
("llvm.fabs", vec![(src, llvm_type)])
} else {
let pred = get_scalar_type(self.context, ast::ScalarType::Pred);
@ -2289,12 +2373,23 @@ impl<'a> MethodEmitContext<'a> {
("llvm.abs", vec![(src, llvm_type), (zero, pred)])
};
let llvm_intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(data.type_));
self.emit_intrinsic(
let abs_result = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) },
Some(arguments.dst),
None,
Some(&data.type_.into()),
intrinsic_arguments,
)?;
if is_floating_point && data.flush_to_zero == Some(true) {
let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(data.type_));
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
Some(&data.type_.into()),
vec![(abs_result, llvm_type)],
)?;
} else {
self.resolver.register(arguments.dst, abs_result);
}
Ok(())
}
@ -2434,23 +2529,21 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
fn emit_add_sat(
fn emit_intrinsic_saturate(
&mut self,
op: &str,
type_: ast::ScalarType,
arguments: ast::AddArgs<SpirvWord>,
dst: SpirvWord,
src1: SpirvWord,
src2: SpirvWord,
) -> Result<(), TranslateError> {
let llvm_type = get_scalar_type(self.context, type_);
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let op = if type_.kind() == ast::ScalarKind::Signed {
"sadd"
} else {
"uadd"
};
let src1 = self.resolver.value(src1)?;
let src2 = self.resolver.value(src2)?;
let intrinsic = format!("llvm.{}.sat.{}\0", op, LLVMTypeDisplay(type_));
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
Some(dst),
Some(&type_.into()),
vec![(src1, llvm_type), (src2, llvm_type)],
)?;
@ -2475,6 +2568,130 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
fn emit_mad_hi_sat_s32(
&mut self,
dst: SpirvWord,
(src1, src2, src3): (SpirvWord, SpirvWord, SpirvWord),
) -> Result<(), TranslateError> {
let src1 = self.resolver.value(src1)?;
let src2 = self.resolver.value(src2)?;
let src3 = self.resolver.value(src3)?;
let llvm_type_s32 = get_scalar_type(self.context, ast::ScalarType::S32);
let llvm_type_s64 = get_scalar_type(self.context, ast::ScalarType::S64);
let src1_wide =
unsafe { LLVMBuildSExt(self.builder, src1, llvm_type_s64, LLVM_UNNAMED.as_ptr()) };
let src2_wide =
unsafe { LLVMBuildSExt(self.builder, src2, llvm_type_s64, LLVM_UNNAMED.as_ptr()) };
let mul_wide =
unsafe { LLVMBuildMul(self.builder, src1_wide, src2_wide, LLVM_UNNAMED.as_ptr()) };
let const_32 = unsafe { LLVMConstInt(llvm_type_s64, 32, 0) };
let mul_wide =
unsafe { LLVMBuildLShr(self.builder, mul_wide, const_32, LLVM_UNNAMED.as_ptr()) };
let mul_narrow =
unsafe { LLVMBuildTrunc(self.builder, mul_wide, llvm_type_s32, LLVM_UNNAMED.as_ptr()) };
self.emit_intrinsic(
c"llvm.sadd.sat.i32",
Some(dst),
Some(&ast::ScalarType::S32.into()),
vec![(mul_narrow, llvm_type_s32), (src3, llvm_type_s32)],
)?;
Ok(())
}
fn emit_set(
&mut self,
data: ptx_parser::SetData,
arguments: ptx_parser::SetArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let setp_result = self.emit_setp_impl(data.base, None, arguments.src1, arguments.src2)?;
self.setp_to_set(arguments.dst, data.dtype, setp_result)?;
Ok(())
}
fn emit_set_bool(
&mut self,
data: ptx_parser::SetBoolData,
arguments: ptx_parser::SetBoolArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let result =
self.emit_setp_bool_impl(data.base, arguments.src1, arguments.src2, arguments.src3)?;
self.setp_to_set(arguments.dst, data.dtype, result)?;
Ok(())
}
fn emit_setp_bool(
&mut self,
data: ast::SetpBoolData,
args: ast::SetpBoolArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let dst = self.emit_setp_bool_impl(data, args.src1, args.src2, args.src3)?;
self.resolver.register(args.dst1, dst);
Ok(())
}
fn emit_setp_bool_impl(
&mut self,
data: ptx_parser::SetpBoolData,
src1: SpirvWord,
src2: SpirvWord,
src3: SpirvWord,
) -> Result<LLVMValueRef, TranslateError> {
let bool_result = self.emit_setp_impl(data.base, None, src1, src2)?;
let bool_result = if data.negate_src3 {
let constant =
unsafe { LLVMConstInt(LLVMIntTypeInContext(self.context, 1), u64::MAX, 0) };
unsafe { LLVMBuildXor(self.builder, bool_result, constant, LLVM_UNNAMED.as_ptr()) }
} else {
bool_result
};
let post_op = match data.bool_op {
ptx_parser::SetpBoolPostOp::Xor => LLVMBuildXor,
ptx_parser::SetpBoolPostOp::And => LLVMBuildAnd,
ptx_parser::SetpBoolPostOp::Or => LLVMBuildOr,
};
let src3 = self.resolver.value(src3)?;
Ok(unsafe { post_op(self.builder, bool_result, src3, LLVM_UNNAMED.as_ptr()) })
}
fn setp_to_set(
&mut self,
dst: SpirvWord,
dtype: ast::ScalarType,
setp_result: LLVMValueRef,
) -> Result<(), TranslateError> {
let llvm_dtype = get_scalar_type(self.context, dtype);
let zero = unsafe { LLVMConstNull(llvm_dtype) };
let one = if dtype.kind() == ast::ScalarKind::Float {
unsafe { LLVMConstReal(llvm_dtype, 1.0) }
} else {
unsafe { LLVMConstInt(llvm_dtype, u64::MAX, 0) }
};
self.resolver.with_result(dst, |dst| unsafe {
LLVMBuildSelect(self.builder, setp_result, one, zero, dst)
});
Ok(())
}
// TODO: revisit this on gfx1250 which has native tanh support
fn emit_tanh(
&mut self,
data: ast::ScalarType,
arguments: ast::TanhArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let src = self.resolver.value(arguments.src)?;
let llvm_type = get_scalar_type(self.context, data);
let name = format!("__ocml_tanh_{}\0", LLVMTypeDisplay(data));
let tanh = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(name.as_bytes()) },
Some(arguments.dst),
Some(&data.into()),
vec![(src, llvm_type)],
)?;
// Not sure if it ultimately does anything
unsafe { LLVMZludaSetFastMathFlags(tanh, LLVMZludaFastMathApproxFunc) }
Ok(())
}
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19
@ -2651,7 +2868,7 @@ fn get_function_type<'a>(
_ => {
check_multiple_return_types(return_args)?;
get_array_type(context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?
},
}
};
Ok(unsafe {

View File

@ -94,6 +94,46 @@ fn run_instruction<'input>(
instruction: ptx_parser::Instruction<SpirvWord>,
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
Ok(match instruction {
i @ ptx_parser::Instruction::Sqrt {
data:
ast::RcpData {
kind: ast::RcpKind::Approx,
type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false),
},
..
} => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Rsqrt {
data:
ast::TypeFtz {
type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false),
},
..
} => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Rcp {
data:
ast::RcpData {
kind: ast::RcpKind::Approx,
type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false),
},
..
} => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Ex2 {
data:
ast::TypeFtz {
type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false),
},
..
} => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Lg2 {
data: ast::FlushToZero {
flush_to_zero: false,
},
..
} => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Activemask { .. } => {
to_call(resolver, fn_declarations, "activemask".into(), i)?
}
@ -116,7 +156,12 @@ fn run_instruction<'input>(
ptx_parser::Reduction::And => "bar_red_and_pred",
ptx_parser::Reduction::Or => "bar_red_or_pred",
};
to_call(resolver, fn_declarations, name.into(), ptx_parser::Instruction::BarRed { data, arguments })?
to_call(
resolver,
fn_declarations,
name.into(),
ptx_parser::Instruction::BarRed { data, arguments },
)?
}
ptx_parser::Instruction::ShflSync { data, arguments } => {
let mode = match data.mode {

34
ptx/src/test/ll/abs.ll Normal file
View File

@ -0,0 +1,34 @@
define amdgpu_kernel void @abs(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 {
%"33" = alloca i64, align 8, addrspace(5)
%"34" = alloca i64, align 8, addrspace(5)
%"35" = alloca i32, align 4, addrspace(5)
%"36" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"30"
"30": ; preds = %1
%"37" = load i64, ptr addrspace(4) %"31", align 4
store i64 %"37", ptr addrspace(5) %"33", align 4
%"38" = load i64, ptr addrspace(4) %"32", align 4
store i64 %"38", ptr addrspace(5) %"34", align 4
%"40" = load i64, ptr addrspace(5) %"33", align 4
%"45" = inttoptr i64 %"40" to ptr
%"39" = load i32, ptr %"45", align 4
store i32 %"39", ptr addrspace(5) %"35", align 4
%"42" = load i32, ptr addrspace(5) %"35", align 4
%"41" = call i32 @llvm.abs.i32(i32 %"42", i1 false)
store i32 %"41", ptr addrspace(5) %"36", align 4
%"43" = load i64, ptr addrspace(5) %"34", align 4
%"44" = load i32, ptr addrspace(5) %"36", align 4
%"46" = inttoptr i64 %"43" to ptr
store i32 %"44", ptr %"46", align 4
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.abs.i32(i32, i1 immarg) #1
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

View File

@ -17,8 +17,9 @@ define amdgpu_kernel void @shr(ptr addrspace(4) byref(i64) %"31", ptr addrspace(
%"38" = load i32, ptr %"44", align 4
store i32 %"38", ptr addrspace(5) %"35", align 4
%"41" = load i32, ptr addrspace(5) %"35", align 4
%2 = ashr i32 %"41", 1
%"40" = select i1 false, i32 0, i32 %2
%2 = ashr i32 %"41", 31
%3 = ashr i32 %"41", 1
%"40" = select i1 false, i32 %2, i32 %3
store i32 %"40", ptr addrspace(5) %"35", align 4
%"42" = load i64, ptr addrspace(5) %"34", align 8
%"43" = load i32, ptr addrspace(5) %"35", align 4

View File

@ -0,0 +1,31 @@
define amdgpu_kernel void @shr_oob(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 {
%"33" = alloca i64, align 8, addrspace(5)
%"34" = alloca i64, align 8, addrspace(5)
%"35" = alloca i16, align 2, addrspace(5)
br label %1
1: ; preds = %0
br label %"30"
"30": ; preds = %1
%"36" = load i64, ptr addrspace(4) %"31", align 4
store i64 %"36", ptr addrspace(5) %"33", align 4
%"37" = load i64, ptr addrspace(4) %"32", align 4
store i64 %"37", ptr addrspace(5) %"34", align 4
%"39" = load i64, ptr addrspace(5) %"33", align 4
%"44" = inttoptr i64 %"39" to ptr
%"38" = load i16, ptr %"44", align 2
store i16 %"38", ptr addrspace(5) %"35", align 2
%"41" = load i16, ptr addrspace(5) %"35", align 2
%2 = ashr i16 %"41", 15
%3 = ashr i16 %"41", 16
%"40" = select i1 true, i16 %2, i16 %3
store i16 %"40", ptr addrspace(5) %"35", align 2
%"42" = load i64, ptr addrspace(5) %"34", align 4
%"43" = load i16, ptr addrspace(5) %"35", align 2
%"45" = inttoptr i64 %"42" to ptr
store i16 %"43", ptr %"45", align 2
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

31
ptx/src/test/ll/tanh.ll Normal file
View File

@ -0,0 +1,31 @@
define amdgpu_kernel void @tanh(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #0 {
%"32" = alloca i64, align 8, addrspace(5)
%"33" = alloca i64, align 8, addrspace(5)
%"34" = alloca float, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"29"
"29": ; preds = %1
%"35" = load i64, ptr addrspace(4) %"30", align 4
store i64 %"35", ptr addrspace(5) %"32", align 4
%"36" = load i64, ptr addrspace(4) %"31", align 4
store i64 %"36", ptr addrspace(5) %"33", align 4
%"38" = load i64, ptr addrspace(5) %"32", align 4
%"43" = inttoptr i64 %"38" to ptr
%"37" = load float, ptr %"43", align 4
store float %"37", ptr addrspace(5) %"34", align 4
%"40" = load float, ptr addrspace(5) %"34", align 4
%"39" = call afn float @__ocml_tanh_f32(float %"40")
store float %"39", ptr addrspace(5) %"34", align 4
%"41" = load i64, ptr addrspace(5) %"33", align 4
%"42" = load float, ptr addrspace(5) %"34", align 4
%"44" = inttoptr i64 %"41" to ptr
store float %"42", ptr %"44", align 4
ret void
}
declare float @__ocml_tanh_f32(float)
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry abs(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .s32 temp1;
.reg .s32 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.s32 temp1, [in_addr];
abs.s32 temp2, temp1;
st.s32 [out_addr], temp2;
ret;
}

View File

@ -159,6 +159,7 @@ test_ptx!(
);
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
test_ptx!(shr, [-2i32], [-1i32]);
test_ptx!(shr_oob, [-32768i16], [-1i16]);
test_ptx!(or, [1u64, 2u64], [3u64]);
test_ptx!(sub, [2u64], [1u64]);
test_ptx!(min, [555i32, 444i32], [444i32]);
@ -180,6 +181,7 @@ test_ptx!(
[0b1_00000000_01000000000000000000000u32]
);
test_ptx!(constant_f32, [10f32], [5f32]);
test_ptx!(abs, [i32::MIN], [i32::MIN]);
test_ptx!(constant_negative, [-101i32], [101i32]);
test_ptx!(and, [6u32, 3u32], [2u32]);
test_ptx!(selp, [100u16, 200u16], [200u16]);
@ -296,6 +298,7 @@ test_ptx!(
);
test_ptx!(multiple_return, [5u32], [6u32, 123u32]);
test_ptx!(warp_sz, [0u8], [32u8]);
test_ptx!(tanh, [f32::INFINITY], [1.0f32]);
test_ptx!(nanosleep, [0u64], [0u64]);

View File

@ -0,0 +1,21 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry shr_oob(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .s16 temp;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.s16 temp, [in_addr];
shr.s16 temp, temp, 16;
st.s16 [out_addr], temp;
ret;
}

View File

@ -0,0 +1,21 @@
.version 7.0
.target sm_75
.address_size 64
.visible .entry tanh(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f32 temp1;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.f32 temp1, [in_addr];
tanh.approx.f32 temp1, temp1;
st.f32 [out_addr], temp1;
ret;
}

View File

@ -428,6 +428,44 @@ ptx_parser_macros::generate_instruction_type!(
},
}
},
Set {
data: SetData,
arguments<T>: {
dst: {
repr: T,
type: Type::from(data.dtype)
},
src1: {
repr: T,
type: Type::from(data.base.type_),
},
src2: {
repr: T,
type: Type::from(data.base.type_),
}
}
},
SetBool {
data: SetBoolData,
arguments<T>: {
dst: {
repr: T,
type: Type::from(data.dtype)
},
src1: {
repr: T,
type: Type::from(data.base.base.type_),
},
src2: {
repr: T,
type: Type::from(data.base.base.type_),
},
src3: {
repr: T,
type: Type::from(ScalarType::Pred)
}
}
},
Setp {
data: SetpData,
arguments<T>: {
@ -562,6 +600,14 @@ ptx_parser_macros::generate_instruction_type!(
src2: T
}
},
Tanh {
type: Type::Scalar(data.clone()),
data: ScalarType,
arguments<T>: {
dst: T,
src: T
}
},
}
);
@ -1247,6 +1293,11 @@ pub struct Mul24Details {
pub control: Mul24Control,
}
pub struct SetData {
pub dtype: ScalarType,
pub base: SetpData,
}
pub struct SetpData {
pub type_: ScalarType,
pub flush_to_zero: Option<bool>,
@ -1288,6 +1339,12 @@ impl SetpData {
}
}
pub struct SetBoolData {
pub dtype: ScalarType,
pub base: SetpBoolData,
}
pub struct SetpBoolData {
pub base: SetpData,
pub bool_op: SetpBoolPostOp,

View File

@ -442,7 +442,7 @@ fn directive<'a, 'input>(
)),
(
any,
take_till(1.., |(token, _)| match token {
take_till(1.., |(token, _)| match token {
// visibility
Token::DotExtern | Token::DotVisible | Token::DotWeak
// methods
@ -2201,6 +2201,43 @@ derive_parser!(
.rnd: RawRoundingMode = { .rn };
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#comparison-and-selection-instructions-set
// https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-comparison-instructions-set
set.CmpOp{.ftz}.dtype.stype d, a, b => {
let base = ast::SetpData::try_parse(state, cmpop, ftz, stype);
let data = ast::SetData {
base, dtype
};
ast::Instruction::Set {
data,
arguments: SetArgs { dst: d, src1: a, src2: b }
}
}
set.CmpOp.BoolOp{.ftz}.dtype.stype d, a, b, {!}c => {
let (negate_src3, c) = c;
let base = ast::SetpData::try_parse(state, cmpop, ftz, stype);
let base = ast::SetpBoolData {
base,
bool_op: boolop,
negate_src3
};
let data = ast::SetBoolData {
base, dtype
};
ast::Instruction::SetBool {
data,
arguments: SetBoolArgs { dst: d, src1: a, src2: b, src3: c }
}
}
.CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge,
.lo, .ls, .hi, .hs, // signed
.equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only
.BoolOp: SetpBoolPostOp = { .and, .or, .xor };
.dtype: ScalarType = { .u32, .s32, .f32 };
.stype: ScalarType = { .b16, .b32, .b64, .u16, .u32, .u64, .s16, .s32, .s64, .f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp
setp.CmpOp{.ftz}.type p[|q], a, b => {
@ -3509,6 +3546,16 @@ derive_parser!(
arguments: NanosleepArgs { src: t }
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-tanh
// https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-tanh
tanh.approx.type d, a => {
Instruction::Tanh {
data: type_,
arguments: TanhArgs { dst: d, src: a }
}
}
.type: ScalarType = { .f32, .f16, .f16x2, .bf16, .bf16x2 };
);
#[cfg(test)]