mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-25 13:16:23 +03:00
Assorted instruction fixes (#423)
This fixes transcendentals and some other buggy instructions exposed by `ptx_tests` (abs, neg). Add (slow - hardware limitation) tanh. Only two remaining incorrect instructions are div and sqrt with non-default rounding, but this commit is already bloated enough
This commit is contained in:
Binary file not shown.
@ -1,10 +1,11 @@
|
||||
// Every time this file changes it must be rebuilt; you need `rocm-llvm-dev` and `llvm-17`
|
||||
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
|
||||
// /opt/rocm/llvm/bin/clang -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
|
||||
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include <bit>
|
||||
#include <cmath>
|
||||
#include <hip/amd_detail/amd_device_functions.h>
|
||||
|
||||
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
|
||||
@ -163,7 +164,7 @@ extern "C"
|
||||
int32_t __ockl_wgred_and_i32(int32_t) __device__;
|
||||
int32_t __ockl_wgred_or_i32(int32_t) __device__;
|
||||
|
||||
#define BAR_RED_IMPL(reducer) \
|
||||
#define BAR_RED_IMPL(reducer) \
|
||||
bool FUNC(bar_red_##reducer##_pred)(uint32_t barrier __attribute__((unused)), bool predicate, bool invert_predicate) \
|
||||
{ \
|
||||
/* TODO: handle barrier */ \
|
||||
@ -173,7 +174,8 @@ extern "C"
|
||||
BAR_RED_IMPL(and);
|
||||
BAR_RED_IMPL(or);
|
||||
|
||||
struct ShflSyncResult {
|
||||
struct ShflSyncResult
|
||||
{
|
||||
uint32_t output;
|
||||
bool in_bounds;
|
||||
};
|
||||
@ -191,7 +193,7 @@ extern "C"
|
||||
// intrinsics, it is always 31 for idx, bfly, and down, and 0 for up. This is used for the
|
||||
// bounds check.
|
||||
|
||||
#define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \
|
||||
#define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \
|
||||
ShflSyncResult FUNC(shfl_sync_##mode##_b32_pred)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask __attribute__((unused))) \
|
||||
{ \
|
||||
int32_t section_mask = (opts >> 8) & 0b11111; \
|
||||
@ -201,10 +203,11 @@ extern "C"
|
||||
int32_t subsection_end = subsection | (~section_mask & warp_end); \
|
||||
int32_t idx = calculate_index; \
|
||||
bool out_of_bounds = idx CMP subsection_end; \
|
||||
if (out_of_bounds) { \
|
||||
if (out_of_bounds) \
|
||||
{ \
|
||||
idx = self; \
|
||||
} \
|
||||
int32_t output = __builtin_amdgcn_ds_bpermute(idx<<2, (int32_t)input); \
|
||||
int32_t output = __builtin_amdgcn_ds_bpermute(idx << 2, (int32_t)input); \
|
||||
return {(uint32_t)output, !out_of_bounds}; \
|
||||
} \
|
||||
\
|
||||
@ -212,15 +215,15 @@ extern "C"
|
||||
{ \
|
||||
return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).output; \
|
||||
}
|
||||
|
||||
|
||||
// We are using the HIP __shfl intrinsics to implement these, rather than the __shfl_sync
|
||||
// intrinsics, as those only add an assertion checking that the membermask is used correctly.
|
||||
// They do not return the result of the range check, so we must replicate that logic here.
|
||||
|
||||
SHFL_SYNC_IMPL(up, self - delta, <);
|
||||
SHFL_SYNC_IMPL(down, self + delta, >);
|
||||
SHFL_SYNC_IMPL(bfly, self ^ delta, >);
|
||||
SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >);
|
||||
SHFL_SYNC_IMPL(up, self - delta, <);
|
||||
SHFL_SYNC_IMPL(down, self + delta, >);
|
||||
SHFL_SYNC_IMPL(bfly, self ^ delta, >);
|
||||
SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >);
|
||||
|
||||
DECLARE_ATTR(uint32_t, CLOCK_RATE);
|
||||
void FUNC(nanosleep_u32)(uint32_t nanoseconds) {
|
||||
@ -254,4 +257,82 @@ extern "C"
|
||||
(void)char_size;
|
||||
__assert_fail((const char *)message, (const char *)file, line, (const char *)function);
|
||||
}
|
||||
|
||||
// * Smallest denormal is 1.4 × 10^-45
|
||||
// * Smallest normal is ~1.175494351 × 10^(-38)
|
||||
// * Now, 1.175494351×10^-38 / 1.4 × 10^-45 = 8396388 + 31/140
|
||||
// * Next power of 2 is 16777216
|
||||
const float DENORMAL_TO_NORMAL_FACTOR_F32 = 16777216.0f;
|
||||
// * Largest subnormal is ~1.175494210692441 × 10^(-38)
|
||||
// * Then any value equal or larger than following will produce subnormals: 8.50706018714406320806444272332455743547934627837873057975602739772164... × 10^37
|
||||
const float RCP_DENORMAL_OUTPUT = 8.50706018714406320806444272332455743547934627837873057975602739772164e37f;
|
||||
const float REVERSE_DENORMAL_TO_NORMAL_FACTOR_F32 = 0.029387360490963111877208252592662410455594571842846914442095471744599661631813495980086003637902577995683214210345151992265999035207077609582844f;
|
||||
|
||||
// Approximate f32 square root built on `__builtin_amdgcn_sqrtf`.
//
// Subnormal inputs are first multiplied by DENORMAL_TO_NORMAL_FACTOR_F32
// (2^24) so the builtin sees a normal number. Since
// sqrt(x * 2^24) == sqrt(x) * 2^12, the result is scaled back by
// 2^-12 (0.000244140625). Every other input is passed through untouched.
float FUNC(sqrt_approx_f32)(float x)
{
    if (__builtin_isfpclass(x, __FPCLASS_NEGSUBNORMAL | __FPCLASS_POSSUBNORMAL))
    {
        float scaled_root = __builtin_amdgcn_sqrtf(x * DENORMAL_TO_NORMAL_FACTOR_F32);
        return scaled_root * 0.000244140625f;
    }
    return __builtin_amdgcn_sqrtf(x);
}
|
||||
|
||||
// Approximate f32 reciprocal square root built on `__builtin_amdgcn_rsqf`.
//
// Subnormal inputs are first multiplied by DENORMAL_TO_NORMAL_FACTOR_F32
// (2^24) so the builtin sees a normal number. Since
// rsqrt(x * 2^24) == rsqrt(x) * 2^-12, the result is scaled back by
// 2^12 (4096). Every other input is passed through untouched.
float FUNC(rsqrt_approx_f32)(float x)
{
    if (__builtin_isfpclass(x, __FPCLASS_NEGSUBNORMAL | __FPCLASS_POSSUBNORMAL))
    {
        float scaled_rsqrt = __builtin_amdgcn_rsqf(x * DENORMAL_TO_NORMAL_FACTOR_F32);
        return scaled_rsqrt * 4096.0f;
    }
    return __builtin_amdgcn_rsqf(x);
}
|
||||
|
||||
// Approximate f32 reciprocal built on `__builtin_amdgcn_rcpf`.
//
// The input is multiplied by a scale factor before the builtin and the
// result is multiplied by the same factor afterwards; mathematically
// rcp(x * f) * f == 1 / x for any non-zero f. The factor is:
//   * REVERSE_DENORMAL_TO_NORMAL_FACTOR_F32 when |x| >= RCP_DENORMAL_OUTPUT
//     (this check wins if both conditions could apply),
//   * DENORMAL_TO_NORMAL_FACTOR_F32 when x is subnormal,
//   * 1.0 otherwise.
float FUNC(rcp_approx_f32)(float x)
{
    const bool subnormal_input = __builtin_isfpclass(x, __FPCLASS_NEGSUBNORMAL | __FPCLASS_POSSUBNORMAL);
    const bool huge_input = std::fabs(x) >= RCP_DENORMAL_OUTPUT;
    const float factor = huge_input
                             ? REVERSE_DENORMAL_TO_NORMAL_FACTOR_F32
                             : (subnormal_input ? DENORMAL_TO_NORMAL_FACTOR_F32 : 1.0f);
    return __builtin_amdgcn_rcpf(x * factor) * factor;
}
|
||||
|
||||
// When x = -126, exp2(x) = 2^(-126) ≈ 1.175494351 × 10^(-38),
|
||||
// which is the smallest normalized number in FP32
|
||||
// Approximate f32 base-2 exponential built on `__builtin_amdgcn_exp2f`.
//
// For x < -126 the exponent is halved before calling the builtin and the
// result is squared afterwards, using exp2(x) == exp2(x / 2)^2.
// All other inputs go straight to the builtin.
float FUNC(ex2_approx_f32)(float x)
{
    if (x < -126.0f)
    {
        float half_exponent_result = __builtin_amdgcn_exp2f(x * 0.5f);
        return half_exponent_result * half_exponent_result;
    }
    return __builtin_amdgcn_exp2f(x);
}
|
||||
|
||||
// Approximate f32 logarithm built on `__builtin_amdgcn_logf`.
//
// Subnormal inputs are first multiplied by DENORMAL_TO_NORMAL_FACTOR_F32
// (2^24) so the builtin sees a normal number; 24 is then subtracted from
// the result, which matches log2(x * 2^24) == log2(x) + 24.
// Every other input is passed through untouched.
float FUNC(lg2_approx_f32)(float x)
{
    if (__builtin_isfpclass(x, __FPCLASS_NEGSUBNORMAL | __FPCLASS_POSSUBNORMAL))
    {
        float scaled_log = __builtin_amdgcn_logf(x * DENORMAL_TO_NORMAL_FACTOR_F32);
        return scaled_log - 24.0f;
    }
    return __builtin_amdgcn_logf(x);
}
|
||||
}
|
||||
|
@ -164,6 +164,8 @@ fn run_instruction<'input>(
|
||||
| ast::Instruction::Ret { .. }
|
||||
| ast::Instruction::Rsqrt { .. }
|
||||
| ast::Instruction::Selp { .. }
|
||||
| ast::Instruction::Set { .. }
|
||||
| ast::Instruction::SetBool { .. }
|
||||
| ast::Instruction::Setp { .. }
|
||||
| ast::Instruction::SetpBool { .. }
|
||||
| ast::Instruction::ShflSync { .. }
|
||||
@ -183,6 +185,7 @@ fn run_instruction<'input>(
|
||||
data: ast::ArithDetails::Integer(..),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Tanh { .. }
|
||||
| ast::Instruction::Trap {}
|
||||
| ast::Instruction::Xor { .. } => result.push(Statement::Instruction(instruction)),
|
||||
ast::Instruction::Add {
|
||||
|
@ -221,12 +221,25 @@ impl InstructionModes {
|
||||
)
|
||||
}
|
||||
|
||||
fn from_rcp(data: ast::RcpData) -> InstructionModes {
|
||||
fn from_rtz_special(data: ast::RcpData) -> InstructionModes {
|
||||
let rounding = match data.kind {
|
||||
ast::RcpKind::Approx => None,
|
||||
ast::RcpKind::Compliant(rnd) => Some(RoundingMode::from_ast(rnd)),
|
||||
};
|
||||
let denormal = data.flush_to_zero.map(DenormalMode::from_ftz);
|
||||
let denormal = match (
|
||||
data.kind == ast::RcpKind::Approx,
|
||||
data.flush_to_zero == Some(true),
|
||||
) {
|
||||
// If we are approximate and flushing then we compile to V_RSQ_F32
|
||||
// or V_SQRT_F32 which ignores prevailing denormal mode and flushes anyway
|
||||
(true, true) => None,
|
||||
// If we are approximate and not flushing, V_RSQ_F32/V_SQRT_F32
|
||||
// ignores ftz mode, but we implement instruction in terms of fmuls
|
||||
// which must preserve denormals
|
||||
(true, false) => Some(DenormalMode::Preserve),
|
||||
// For full precision we set denormal mode accordingly
|
||||
(false, ftz) => Some(DenormalMode::from_ftz(ftz)),
|
||||
};
|
||||
InstructionModes::new(data.type_, denormal, rounding)
|
||||
}
|
||||
|
||||
@ -1780,6 +1793,11 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
||||
match inst {
|
||||
// TODO: review it when implementing virtual calls
|
||||
ast::Instruction::Call { .. }
|
||||
// abs and neg have special handling in AMD GPU ISA. They get compiled
|
||||
// down to instruction argument modifiers, floating point flags have no
|
||||
// effect on it. We handle it during LLVM bitcode emission
|
||||
| ast::Instruction::Abs { .. }
|
||||
| ast::Instruction::Neg {.. }
|
||||
| ast::Instruction::Mov { .. }
|
||||
| ast::Instruction::Ld { .. }
|
||||
| ast::Instruction::St { .. }
|
||||
@ -1856,7 +1874,31 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
||||
data: ast::ArithDetails::Float(data),
|
||||
..
|
||||
} => InstructionModes::from_arith_float(data),
|
||||
ast::Instruction::Setp {
|
||||
ast::Instruction::Set {
|
||||
data: ast::SetData{
|
||||
base: ast::SetpData {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
..
|
||||
},
|
||||
..
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::SetBool {
|
||||
data: ast::SetBoolData {
|
||||
base: ast::SetpBoolData {
|
||||
base: ast::SetpData {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
..
|
||||
},
|
||||
..
|
||||
},
|
||||
..
|
||||
}, ..
|
||||
}
|
||||
| ast::Instruction::Setp {
|
||||
data:
|
||||
ast::SetpData {
|
||||
type_,
|
||||
@ -1878,34 +1920,6 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Neg {
|
||||
data: ast::TypeFtz {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Ex2 {
|
||||
data: ast::TypeFtz {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Rsqrt {
|
||||
data: ast::TypeFtz {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Abs {
|
||||
data: ast::TypeFtz {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Min {
|
||||
data:
|
||||
ast::MinMaxDetails::Float(ast::MinMaxFloat {
|
||||
@ -1945,12 +1959,29 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
||||
)
|
||||
}
|
||||
ast::Instruction::Sin { data, .. }
|
||||
| ast::Instruction::Cos { data, .. }
|
||||
| ast::Instruction::Lg2 { data, .. } => InstructionModes::from_ftz_f32(data.flush_to_zero),
|
||||
| ast::Instruction::Cos { data, .. } => InstructionModes::from_ftz_f32(data.flush_to_zero),
|
||||
ast::Instruction::Rcp { data, .. } | ast::Instruction::Sqrt { data, .. } => {
|
||||
InstructionModes::from_rcp(*data)
|
||||
InstructionModes::from_rtz_special(*data)
|
||||
}
|
||||
ast::Instruction::Rsqrt { data, .. }
|
||||
| ast::Instruction::Ex2 { data, .. } => {
|
||||
let data = ast::RcpData {
|
||||
type_: data.type_,
|
||||
flush_to_zero: data.flush_to_zero,
|
||||
kind: ast::RcpKind::Approx,
|
||||
};
|
||||
InstructionModes::from_rtz_special(data)
|
||||
},
|
||||
ast::Instruction::Lg2 { data, .. } => {
|
||||
let data = ast::RcpData {
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: Some(data.flush_to_zero),
|
||||
kind: ast::RcpKind::Approx,
|
||||
};
|
||||
InstructionModes::from_rtz_special(data)
|
||||
},
|
||||
ast::Instruction::Cvt { data, .. } => InstructionModes::from_cvt(data),
|
||||
ast::Instruction::Tanh { data, .. } => InstructionModes::from_ftz(*data, Some(false)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,7 @@
|
||||
use std::array::TryFromSliceError;
|
||||
use std::convert::TryInto;
|
||||
use std::ffi::{CStr, NulError};
|
||||
use std::{i8, ptr};
|
||||
use std::{i8, ptr, u64};
|
||||
|
||||
use super::*;
|
||||
use crate::pass::*;
|
||||
@ -471,8 +471,10 @@ impl<'a> MethodEmitContext<'a> {
|
||||
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
|
||||
ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments),
|
||||
ast::Instruction::Mul24 { data, arguments } => self.emit_mul24(data, arguments),
|
||||
ast::Instruction::Set { data, arguments } => self.emit_set(data, arguments),
|
||||
ast::Instruction::SetBool { data, arguments } => self.emit_set_bool(data, arguments),
|
||||
ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments),
|
||||
ast::Instruction::SetpBool { .. } => todo!(),
|
||||
ast::Instruction::SetpBool { data, arguments } => self.emit_setp_bool(data, arguments),
|
||||
ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments),
|
||||
ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments),
|
||||
ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
|
||||
@ -506,10 +508,13 @@ impl<'a> MethodEmitContext<'a> {
|
||||
ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments),
|
||||
ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
|
||||
ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments),
|
||||
ast::Instruction::PrmtSlow { .. } => todo!(),
|
||||
ast::Instruction::PrmtSlow { .. } => {
|
||||
Err(error_todo_msg("PrmtSlow is not implemented yet"))
|
||||
}
|
||||
ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments),
|
||||
ast::Instruction::Membar { data } => self.emit_membar(data),
|
||||
ast::Instruction::Trap {} => todo!(),
|
||||
ast::Instruction::Trap {} => Err(error_todo_msg("Trap is not implemented yet")),
|
||||
ast::Instruction::Tanh { data, arguments } => self.emit_tanh(data, arguments),
|
||||
// replaced by a function call
|
||||
ast::Instruction::Bfe { .. }
|
||||
| ast::Instruction::Bar { .. }
|
||||
@ -713,18 +718,31 @@ impl<'a> MethodEmitContext<'a> {
|
||||
arguments: ast::AddArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let builder = self.builder;
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
let fn_ = match data {
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
saturate: true,
|
||||
type_,
|
||||
}) => return self.emit_add_sat(type_, arguments),
|
||||
}) => {
|
||||
let op = if type_.kind() == ast::ScalarKind::Signed {
|
||||
"sadd"
|
||||
} else {
|
||||
"uadd"
|
||||
};
|
||||
return self.emit_intrinsic_saturate(
|
||||
op,
|
||||
type_,
|
||||
arguments.dst,
|
||||
arguments.src1,
|
||||
arguments.src2,
|
||||
);
|
||||
}
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
saturate: false, ..
|
||||
}) => LLVMBuildAdd,
|
||||
ast::ArithDetails::Float(..) => LLVMBuildFAdd,
|
||||
};
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
fn_(builder, src1, src2, dst)
|
||||
});
|
||||
@ -1353,7 +1371,18 @@ impl<'a> MethodEmitContext<'a> {
|
||||
arguments: ptx_parser::SubArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
if arith_integer.saturate {
|
||||
todo!()
|
||||
let op = if arith_integer.type_.kind() == ast::ScalarKind::Signed {
|
||||
"ssub"
|
||||
} else {
|
||||
"usub"
|
||||
};
|
||||
return self.emit_intrinsic_saturate(
|
||||
op,
|
||||
arith_integer.type_,
|
||||
arguments.dst,
|
||||
arguments.src1,
|
||||
arguments.src2,
|
||||
);
|
||||
}
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
@ -1365,12 +1394,9 @@ impl<'a> MethodEmitContext<'a> {
|
||||
|
||||
fn emit_sub_float(
|
||||
&mut self,
|
||||
arith_float: ptx_parser::ArithFloat,
|
||||
_arith_float: ptx_parser::ArithFloat,
|
||||
arguments: ptx_parser::SubArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
if arith_float.saturate {
|
||||
todo!()
|
||||
}
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
@ -1430,25 +1456,39 @@ impl<'a> MethodEmitContext<'a> {
|
||||
arguments: ptx_parser::NegArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let src = self.resolver.value(arguments.src)?;
|
||||
let llvm_fn = if data.type_.kind() == ptx_parser::ScalarKind::Float {
|
||||
let is_floating_point = data.type_.kind() == ptx_parser::ScalarKind::Float;
|
||||
let llvm_fn = if is_floating_point {
|
||||
LLVMBuildFNeg
|
||||
} else {
|
||||
LLVMBuildNeg
|
||||
};
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
llvm_fn(self.builder, src, dst)
|
||||
});
|
||||
if is_floating_point && data.flush_to_zero == Some(true) {
|
||||
let negated = unsafe { llvm_fn(self.builder, src, LLVM_UNNAMED.as_ptr()) };
|
||||
let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(data.type_));
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
Some(&data.type_.into()),
|
||||
vec![(negated, get_scalar_type(self.context, data.type_))],
|
||||
)?;
|
||||
} else {
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
llvm_fn(self.builder, src, dst)
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_not(
|
||||
&mut self,
|
||||
_data: ptx_parser::ScalarType,
|
||||
type_: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::NotArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let src = self.resolver.value(arguments.src)?;
|
||||
let type_ = get_scalar_type(self.context, type_);
|
||||
let constant = unsafe { LLVMConstInt(type_, u64::MAX, 0) };
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildNot(self.builder, src, dst)
|
||||
LLVMBuildXor(self.builder, src, constant, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
@ -1458,15 +1498,29 @@ impl<'a> MethodEmitContext<'a> {
|
||||
data: ptx_parser::SetpData,
|
||||
arguments: ptx_parser::SetpArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
if arguments.dst2.is_some() {
|
||||
todo!()
|
||||
let dst = self.emit_setp_impl(data, arguments.dst2, arguments.src1, arguments.src2)?;
|
||||
self.resolver.register(arguments.dst1, dst);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_setp_impl(
|
||||
&mut self,
|
||||
data: ptx_parser::SetpData,
|
||||
dst2: Option<SpirvWord>,
|
||||
src1: SpirvWord,
|
||||
src2: SpirvWord,
|
||||
) -> Result<LLVMValueRef, TranslateError> {
|
||||
if dst2.is_some() {
|
||||
return Err(error_todo_msg(
|
||||
"setp with two dst arguments not yet supported",
|
||||
));
|
||||
}
|
||||
match data.cmp_op {
|
||||
ptx_parser::SetpCompareOp::Integer(setp_compare_int) => {
|
||||
self.emit_setp_int(setp_compare_int, arguments)
|
||||
self.emit_setp_int(setp_compare_int, src1, src2)
|
||||
}
|
||||
ptx_parser::SetpCompareOp::Float(setp_compare_float) => {
|
||||
self.emit_setp_float(setp_compare_float, arguments)
|
||||
self.emit_setp_float(setp_compare_float, src1, src2)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1474,8 +1528,9 @@ impl<'a> MethodEmitContext<'a> {
|
||||
fn emit_setp_int(
|
||||
&mut self,
|
||||
setp: ptx_parser::SetpCompareInt,
|
||||
arguments: ptx_parser::SetpArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
src1: SpirvWord,
|
||||
src2: SpirvWord,
|
||||
) -> Result<LLVMValueRef, TranslateError> {
|
||||
let op = match setp {
|
||||
ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ,
|
||||
ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE,
|
||||
@ -1488,19 +1543,17 @@ impl<'a> MethodEmitContext<'a> {
|
||||
ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT,
|
||||
ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE,
|
||||
};
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst1, |dst1| unsafe {
|
||||
LLVMBuildICmp(self.builder, op, src1, src2, dst1)
|
||||
});
|
||||
Ok(())
|
||||
let src1 = self.resolver.value(src1)?;
|
||||
let src2 = self.resolver.value(src2)?;
|
||||
Ok(unsafe { LLVMBuildICmp(self.builder, op, src1, src2, LLVM_UNNAMED.as_ptr()) })
|
||||
}
|
||||
|
||||
fn emit_setp_float(
|
||||
&mut self,
|
||||
setp: ptx_parser::SetpCompareFloat,
|
||||
arguments: ptx_parser::SetpArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
src1: SpirvWord,
|
||||
src2: SpirvWord,
|
||||
) -> Result<LLVMValueRef, TranslateError> {
|
||||
let op = match setp {
|
||||
ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ,
|
||||
ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE,
|
||||
@ -1517,12 +1570,9 @@ impl<'a> MethodEmitContext<'a> {
|
||||
ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD,
|
||||
ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO,
|
||||
};
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst1, |dst1| unsafe {
|
||||
LLVMBuildFCmp(self.builder, op, src1, src2, dst1)
|
||||
});
|
||||
Ok(())
|
||||
let src1 = self.resolver.value(src1)?;
|
||||
let src2 = self.resolver.value(src2)?;
|
||||
Ok(unsafe { LLVMBuildFCmp(self.builder, op, src1, src2, LLVM_UNNAMED.as_ptr()) })
|
||||
}
|
||||
|
||||
fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> {
|
||||
@ -1935,15 +1985,34 @@ impl<'a> MethodEmitContext<'a> {
|
||||
data: ptx_parser::ShrData,
|
||||
arguments: ptx_parser::ShrArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let shift_fn = match data.kind {
|
||||
ptx_parser::RightShiftKind::Arithmetic => LLVMBuildAShr,
|
||||
ptx_parser::RightShiftKind::Logical => LLVMBuildLShr,
|
||||
let llvm_type = get_scalar_type(self.context, data.type_);
|
||||
let (out_of_range, shift_fn): (
|
||||
*mut LLVMValue,
|
||||
unsafe extern "C" fn(
|
||||
LLVMBuilderRef,
|
||||
LLVMValueRef,
|
||||
LLVMValueRef,
|
||||
*const i8,
|
||||
) -> LLVMValueRef,
|
||||
) = match data.kind {
|
||||
ptx_parser::RightShiftKind::Logical => {
|
||||
(unsafe { LLVMConstNull(llvm_type) }, LLVMBuildLShr)
|
||||
}
|
||||
ptx_parser::RightShiftKind::Arithmetic => {
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let shift_size =
|
||||
unsafe { LLVMConstInt(llvm_type, (data.type_.size_of() as u64 * 8) - 1, 0) };
|
||||
let out_of_range =
|
||||
unsafe { LLVMBuildAShr(self.builder, src1, shift_size, LLVM_UNNAMED.as_ptr()) };
|
||||
(out_of_range, LLVMBuildAShr)
|
||||
}
|
||||
};
|
||||
self.emit_shift(
|
||||
data.type_,
|
||||
arguments.dst,
|
||||
arguments.src1,
|
||||
arguments.src2,
|
||||
out_of_range,
|
||||
shift_fn,
|
||||
)
|
||||
}
|
||||
@ -1953,11 +2022,13 @@ impl<'a> MethodEmitContext<'a> {
|
||||
type_: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::ShlArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_type = get_scalar_type(self.context, type_);
|
||||
self.emit_shift(
|
||||
type_,
|
||||
arguments.dst,
|
||||
arguments.src1,
|
||||
arguments.src2,
|
||||
unsafe { LLVMConstNull(llvm_type) },
|
||||
LLVMBuildShl,
|
||||
)
|
||||
}
|
||||
@ -1968,6 +2039,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||
dst: SpirvWord,
|
||||
src1: SpirvWord,
|
||||
src2: SpirvWord,
|
||||
out_of_range_value: LLVMValueRef,
|
||||
llvm_fn: unsafe extern "C" fn(
|
||||
LLVMBuilderRef,
|
||||
LLVMValueRef,
|
||||
@ -1995,7 +2067,6 @@ impl<'a> MethodEmitContext<'a> {
|
||||
)
|
||||
};
|
||||
let llvm_type = get_scalar_type(self.context, type_);
|
||||
let zero = unsafe { LLVMConstNull(llvm_type) };
|
||||
let normalized_shift_size = if type_.layout().size() >= 4 {
|
||||
unsafe {
|
||||
LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr())
|
||||
@ -2012,7 +2083,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||
)
|
||||
};
|
||||
self.resolver.with_result(dst, |dst| unsafe {
|
||||
LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst)
|
||||
LLVMBuildSelect(self.builder, should_clamp, out_of_range_value, shifted, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
@ -2207,7 +2278,19 @@ impl<'a> MethodEmitContext<'a> {
|
||||
},
|
||||
)
|
||||
}
|
||||
ptx_parser::MadDetails::Integer { saturate: true, .. } => return Err(error_todo()),
|
||||
ptx_parser::MadDetails::Integer {
|
||||
saturate: true,
|
||||
control: ast::MulIntControl::High,
|
||||
type_: ast::ScalarType::S32,
|
||||
} => {
|
||||
return self.emit_mad_hi_sat_s32(
|
||||
arguments.dst,
|
||||
(arguments.src1, arguments.src2, arguments.src3),
|
||||
);
|
||||
}
|
||||
ptx_parser::MadDetails::Integer { saturate: true, .. } => {
|
||||
return Err(error_unreachable())
|
||||
}
|
||||
ptx_parser::MadDetails::Integer { type_, control, .. } => {
|
||||
ast::MulDetails::Integer { control, type_ }
|
||||
}
|
||||
@ -2281,7 +2364,8 @@ impl<'a> MethodEmitContext<'a> {
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_type = get_scalar_type(self.context, data.type_);
|
||||
let src = self.resolver.value(arguments.src)?;
|
||||
let (prefix, intrinsic_arguments) = if data.type_.kind() == ast::ScalarKind::Float {
|
||||
let is_floating_point = data.type_.kind() == ast::ScalarKind::Float;
|
||||
let (prefix, intrinsic_arguments) = if is_floating_point {
|
||||
("llvm.fabs", vec![(src, llvm_type)])
|
||||
} else {
|
||||
let pred = get_scalar_type(self.context, ast::ScalarType::Pred);
|
||||
@ -2289,12 +2373,23 @@ impl<'a> MethodEmitContext<'a> {
|
||||
("llvm.abs", vec![(src, llvm_type), (zero, pred)])
|
||||
};
|
||||
let llvm_intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(data.type_));
|
||||
self.emit_intrinsic(
|
||||
let abs_result = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
None,
|
||||
Some(&data.type_.into()),
|
||||
intrinsic_arguments,
|
||||
)?;
|
||||
if is_floating_point && data.flush_to_zero == Some(true) {
|
||||
let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(data.type_));
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
Some(&data.type_.into()),
|
||||
vec![(abs_result, llvm_type)],
|
||||
)?;
|
||||
} else {
|
||||
self.resolver.register(arguments.dst, abs_result);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -2434,23 +2529,21 @@ impl<'a> MethodEmitContext<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_add_sat(
|
||||
fn emit_intrinsic_saturate(
|
||||
&mut self,
|
||||
op: &str,
|
||||
type_: ast::ScalarType,
|
||||
arguments: ast::AddArgs<SpirvWord>,
|
||||
dst: SpirvWord,
|
||||
src1: SpirvWord,
|
||||
src2: SpirvWord,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_type = get_scalar_type(self.context, type_);
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
let op = if type_.kind() == ast::ScalarKind::Signed {
|
||||
"sadd"
|
||||
} else {
|
||||
"uadd"
|
||||
};
|
||||
let src1 = self.resolver.value(src1)?;
|
||||
let src2 = self.resolver.value(src2)?;
|
||||
let intrinsic = format!("llvm.{}.sat.{}\0", op, LLVMTypeDisplay(type_));
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
Some(dst),
|
||||
Some(&type_.into()),
|
||||
vec![(src1, llvm_type), (src2, llvm_type)],
|
||||
)?;
|
||||
@ -2475,6 +2568,130 @@ impl<'a> MethodEmitContext<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mad_hi_sat_s32(
|
||||
&mut self,
|
||||
dst: SpirvWord,
|
||||
(src1, src2, src3): (SpirvWord, SpirvWord, SpirvWord),
|
||||
) -> Result<(), TranslateError> {
|
||||
let src1 = self.resolver.value(src1)?;
|
||||
let src2 = self.resolver.value(src2)?;
|
||||
let src3 = self.resolver.value(src3)?;
|
||||
let llvm_type_s32 = get_scalar_type(self.context, ast::ScalarType::S32);
|
||||
let llvm_type_s64 = get_scalar_type(self.context, ast::ScalarType::S64);
|
||||
let src1_wide =
|
||||
unsafe { LLVMBuildSExt(self.builder, src1, llvm_type_s64, LLVM_UNNAMED.as_ptr()) };
|
||||
let src2_wide =
|
||||
unsafe { LLVMBuildSExt(self.builder, src2, llvm_type_s64, LLVM_UNNAMED.as_ptr()) };
|
||||
let mul_wide =
|
||||
unsafe { LLVMBuildMul(self.builder, src1_wide, src2_wide, LLVM_UNNAMED.as_ptr()) };
|
||||
let const_32 = unsafe { LLVMConstInt(llvm_type_s64, 32, 0) };
|
||||
let mul_wide =
|
||||
unsafe { LLVMBuildLShr(self.builder, mul_wide, const_32, LLVM_UNNAMED.as_ptr()) };
|
||||
let mul_narrow =
|
||||
unsafe { LLVMBuildTrunc(self.builder, mul_wide, llvm_type_s32, LLVM_UNNAMED.as_ptr()) };
|
||||
self.emit_intrinsic(
|
||||
c"llvm.sadd.sat.i32",
|
||||
Some(dst),
|
||||
Some(&ast::ScalarType::S32.into()),
|
||||
vec![(mul_narrow, llvm_type_s32), (src3, llvm_type_s32)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_set(
|
||||
&mut self,
|
||||
data: ptx_parser::SetData,
|
||||
arguments: ptx_parser::SetArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let setp_result = self.emit_setp_impl(data.base, None, arguments.src1, arguments.src2)?;
|
||||
self.setp_to_set(arguments.dst, data.dtype, setp_result)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_set_bool(
|
||||
&mut self,
|
||||
data: ptx_parser::SetBoolData,
|
||||
arguments: ptx_parser::SetBoolArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let result =
|
||||
self.emit_setp_bool_impl(data.base, arguments.src1, arguments.src2, arguments.src3)?;
|
||||
self.setp_to_set(arguments.dst, data.dtype, result)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_setp_bool(
|
||||
&mut self,
|
||||
data: ast::SetpBoolData,
|
||||
args: ast::SetpBoolArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let dst = self.emit_setp_bool_impl(data, args.src1, args.src2, args.src3)?;
|
||||
self.resolver.register(args.dst1, dst);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_setp_bool_impl(
|
||||
&mut self,
|
||||
data: ptx_parser::SetpBoolData,
|
||||
src1: SpirvWord,
|
||||
src2: SpirvWord,
|
||||
src3: SpirvWord,
|
||||
) -> Result<LLVMValueRef, TranslateError> {
|
||||
let bool_result = self.emit_setp_impl(data.base, None, src1, src2)?;
|
||||
let bool_result = if data.negate_src3 {
|
||||
let constant =
|
||||
unsafe { LLVMConstInt(LLVMIntTypeInContext(self.context, 1), u64::MAX, 0) };
|
||||
unsafe { LLVMBuildXor(self.builder, bool_result, constant, LLVM_UNNAMED.as_ptr()) }
|
||||
} else {
|
||||
bool_result
|
||||
};
|
||||
let post_op = match data.bool_op {
|
||||
ptx_parser::SetpBoolPostOp::Xor => LLVMBuildXor,
|
||||
ptx_parser::SetpBoolPostOp::And => LLVMBuildAnd,
|
||||
ptx_parser::SetpBoolPostOp::Or => LLVMBuildOr,
|
||||
};
|
||||
let src3 = self.resolver.value(src3)?;
|
||||
Ok(unsafe { post_op(self.builder, bool_result, src3, LLVM_UNNAMED.as_ptr()) })
|
||||
}
|
||||
|
||||
fn setp_to_set(
|
||||
&mut self,
|
||||
dst: SpirvWord,
|
||||
dtype: ast::ScalarType,
|
||||
setp_result: LLVMValueRef,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_dtype = get_scalar_type(self.context, dtype);
|
||||
let zero = unsafe { LLVMConstNull(llvm_dtype) };
|
||||
let one = if dtype.kind() == ast::ScalarKind::Float {
|
||||
unsafe { LLVMConstReal(llvm_dtype, 1.0) }
|
||||
} else {
|
||||
unsafe { LLVMConstInt(llvm_dtype, u64::MAX, 0) }
|
||||
};
|
||||
self.resolver.with_result(dst, |dst| unsafe {
|
||||
LLVMBuildSelect(self.builder, setp_result, one, zero, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO: revisit this on gfx1250 which has native tanh support
|
||||
fn emit_tanh(
|
||||
&mut self,
|
||||
data: ast::ScalarType,
|
||||
arguments: ast::TanhArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let src = self.resolver.value(arguments.src)?;
|
||||
let llvm_type = get_scalar_type(self.context, data);
|
||||
let name = format!("__ocml_tanh_{}\0", LLVMTypeDisplay(data));
|
||||
let tanh = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(name.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
Some(&data.into()),
|
||||
vec![(src, llvm_type)],
|
||||
)?;
|
||||
// Not sure if it ultimately does anything
|
||||
unsafe { LLVMZludaSetFastMathFlags(tanh, LLVMZludaFastMathApproxFunc) }
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/*
|
||||
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
|
||||
// Should be available in LLVM 19
|
||||
@ -2651,7 +2868,7 @@ fn get_function_type<'a>(
|
||||
_ => {
|
||||
check_multiple_return_types(return_args)?;
|
||||
get_array_type(context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
Ok(unsafe {
|
||||
|
@ -94,6 +94,46 @@ fn run_instruction<'input>(
|
||||
instruction: ptx_parser::Instruction<SpirvWord>,
|
||||
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
|
||||
Ok(match instruction {
|
||||
i @ ptx_parser::Instruction::Sqrt {
|
||||
data:
|
||||
ast::RcpData {
|
||||
kind: ast::RcpKind::Approx,
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: None | Some(false),
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Rsqrt {
|
||||
data:
|
||||
ast::TypeFtz {
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: None | Some(false),
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Rcp {
|
||||
data:
|
||||
ast::RcpData {
|
||||
kind: ast::RcpKind::Approx,
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: None | Some(false),
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Ex2 {
|
||||
data:
|
||||
ast::TypeFtz {
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: None | Some(false),
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Lg2 {
|
||||
data: ast::FlushToZero {
|
||||
flush_to_zero: false,
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Activemask { .. } => {
|
||||
to_call(resolver, fn_declarations, "activemask".into(), i)?
|
||||
}
|
||||
@ -116,7 +156,12 @@ fn run_instruction<'input>(
|
||||
ptx_parser::Reduction::And => "bar_red_and_pred",
|
||||
ptx_parser::Reduction::Or => "bar_red_or_pred",
|
||||
};
|
||||
to_call(resolver, fn_declarations, name.into(), ptx_parser::Instruction::BarRed { data, arguments })?
|
||||
to_call(
|
||||
resolver,
|
||||
fn_declarations,
|
||||
name.into(),
|
||||
ptx_parser::Instruction::BarRed { data, arguments },
|
||||
)?
|
||||
}
|
||||
ptx_parser::Instruction::ShflSync { data, arguments } => {
|
||||
let mode = match data.mode {
|
||||
|
34
ptx/src/test/ll/abs.ll
Normal file
34
ptx/src/test/ll/abs.ll
Normal file
@ -0,0 +1,34 @@
|
||||
define amdgpu_kernel void @abs(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 {
|
||||
%"33" = alloca i64, align 8, addrspace(5)
|
||||
%"34" = alloca i64, align 8, addrspace(5)
|
||||
%"35" = alloca i32, align 4, addrspace(5)
|
||||
%"36" = alloca i32, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"30"
|
||||
|
||||
"30": ; preds = %1
|
||||
%"37" = load i64, ptr addrspace(4) %"31", align 4
|
||||
store i64 %"37", ptr addrspace(5) %"33", align 4
|
||||
%"38" = load i64, ptr addrspace(4) %"32", align 4
|
||||
store i64 %"38", ptr addrspace(5) %"34", align 4
|
||||
%"40" = load i64, ptr addrspace(5) %"33", align 4
|
||||
%"45" = inttoptr i64 %"40" to ptr
|
||||
%"39" = load i32, ptr %"45", align 4
|
||||
store i32 %"39", ptr addrspace(5) %"35", align 4
|
||||
%"42" = load i32, ptr addrspace(5) %"35", align 4
|
||||
%"41" = call i32 @llvm.abs.i32(i32 %"42", i1 false)
|
||||
store i32 %"41", ptr addrspace(5) %"36", align 4
|
||||
%"43" = load i64, ptr addrspace(5) %"34", align 4
|
||||
%"44" = load i32, ptr addrspace(5) %"36", align 4
|
||||
%"46" = inttoptr i64 %"43" to ptr
|
||||
store i32 %"44", ptr %"46", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
|
||||
declare i32 @llvm.abs.i32(i32, i1 immarg) #1
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
|
@ -17,8 +17,9 @@ define amdgpu_kernel void @shr(ptr addrspace(4) byref(i64) %"31", ptr addrspace(
|
||||
%"38" = load i32, ptr %"44", align 4
|
||||
store i32 %"38", ptr addrspace(5) %"35", align 4
|
||||
%"41" = load i32, ptr addrspace(5) %"35", align 4
|
||||
%2 = ashr i32 %"41", 1
|
||||
%"40" = select i1 false, i32 0, i32 %2
|
||||
%2 = ashr i32 %"41", 31
|
||||
%3 = ashr i32 %"41", 1
|
||||
%"40" = select i1 false, i32 %2, i32 %3
|
||||
store i32 %"40", ptr addrspace(5) %"35", align 4
|
||||
%"42" = load i64, ptr addrspace(5) %"34", align 8
|
||||
%"43" = load i32, ptr addrspace(5) %"35", align 4
|
||||
|
31
ptx/src/test/ll/shr_oob.ll
Normal file
31
ptx/src/test/ll/shr_oob.ll
Normal file
@ -0,0 +1,31 @@
|
||||
define amdgpu_kernel void @shr_oob(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 {
|
||||
%"33" = alloca i64, align 8, addrspace(5)
|
||||
%"34" = alloca i64, align 8, addrspace(5)
|
||||
%"35" = alloca i16, align 2, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"30"
|
||||
|
||||
"30": ; preds = %1
|
||||
%"36" = load i64, ptr addrspace(4) %"31", align 4
|
||||
store i64 %"36", ptr addrspace(5) %"33", align 4
|
||||
%"37" = load i64, ptr addrspace(4) %"32", align 4
|
||||
store i64 %"37", ptr addrspace(5) %"34", align 4
|
||||
%"39" = load i64, ptr addrspace(5) %"33", align 4
|
||||
%"44" = inttoptr i64 %"39" to ptr
|
||||
%"38" = load i16, ptr %"44", align 2
|
||||
store i16 %"38", ptr addrspace(5) %"35", align 2
|
||||
%"41" = load i16, ptr addrspace(5) %"35", align 2
|
||||
%2 = ashr i16 %"41", 15
|
||||
%3 = ashr i16 %"41", 16
|
||||
%"40" = select i1 true, i16 %2, i16 %3
|
||||
store i16 %"40", ptr addrspace(5) %"35", align 2
|
||||
%"42" = load i64, ptr addrspace(5) %"34", align 4
|
||||
%"43" = load i16, ptr addrspace(5) %"35", align 2
|
||||
%"45" = inttoptr i64 %"42" to ptr
|
||||
store i16 %"43", ptr %"45", align 2
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
31
ptx/src/test/ll/tanh.ll
Normal file
31
ptx/src/test/ll/tanh.ll
Normal file
@ -0,0 +1,31 @@
|
||||
define amdgpu_kernel void @tanh(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #0 {
|
||||
%"32" = alloca i64, align 8, addrspace(5)
|
||||
%"33" = alloca i64, align 8, addrspace(5)
|
||||
%"34" = alloca float, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"29"
|
||||
|
||||
"29": ; preds = %1
|
||||
%"35" = load i64, ptr addrspace(4) %"30", align 4
|
||||
store i64 %"35", ptr addrspace(5) %"32", align 4
|
||||
%"36" = load i64, ptr addrspace(4) %"31", align 4
|
||||
store i64 %"36", ptr addrspace(5) %"33", align 4
|
||||
%"38" = load i64, ptr addrspace(5) %"32", align 4
|
||||
%"43" = inttoptr i64 %"38" to ptr
|
||||
%"37" = load float, ptr %"43", align 4
|
||||
store float %"37", ptr addrspace(5) %"34", align 4
|
||||
%"40" = load float, ptr addrspace(5) %"34", align 4
|
||||
%"39" = call afn float @__ocml_tanh_f32(float %"40")
|
||||
store float %"39", ptr addrspace(5) %"34", align 4
|
||||
%"41" = load i64, ptr addrspace(5) %"33", align 4
|
||||
%"42" = load float, ptr addrspace(5) %"34", align 4
|
||||
%"44" = inttoptr i64 %"41" to ptr
|
||||
store float %"42", ptr %"44", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
declare float @__ocml_tanh_f32(float)
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
22
ptx/src/test/spirv_run/abs.ptx
Normal file
22
ptx/src/test/spirv_run/abs.ptx
Normal file
@ -0,0 +1,22 @@
|
||||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry abs(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .s32 temp1;
|
||||
.reg .s32 temp2;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.s32 temp1, [in_addr];
|
||||
abs.s32 temp2, temp1;
|
||||
st.s32 [out_addr], temp2;
|
||||
ret;
|
||||
}
|
@ -159,6 +159,7 @@ test_ptx!(
|
||||
);
|
||||
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
|
||||
test_ptx!(shr, [-2i32], [-1i32]);
|
||||
test_ptx!(shr_oob, [-32768i16], [-1i16]);
|
||||
test_ptx!(or, [1u64, 2u64], [3u64]);
|
||||
test_ptx!(sub, [2u64], [1u64]);
|
||||
test_ptx!(min, [555i32, 444i32], [444i32]);
|
||||
@ -180,6 +181,7 @@ test_ptx!(
|
||||
[0b1_00000000_01000000000000000000000u32]
|
||||
);
|
||||
test_ptx!(constant_f32, [10f32], [5f32]);
|
||||
test_ptx!(abs, [i32::MIN], [i32::MIN]);
|
||||
test_ptx!(constant_negative, [-101i32], [101i32]);
|
||||
test_ptx!(and, [6u32, 3u32], [2u32]);
|
||||
test_ptx!(selp, [100u16, 200u16], [200u16]);
|
||||
@ -296,6 +298,7 @@ test_ptx!(
|
||||
);
|
||||
test_ptx!(multiple_return, [5u32], [6u32, 123u32]);
|
||||
test_ptx!(warp_sz, [0u8], [32u8]);
|
||||
test_ptx!(tanh, [f32::INFINITY], [1.0f32]);
|
||||
|
||||
test_ptx!(nanosleep, [0u64], [0u64]);
|
||||
|
||||
|
21
ptx/src/test/spirv_run/shr_oob.ptx
Normal file
21
ptx/src/test/spirv_run/shr_oob.ptx
Normal file
@ -0,0 +1,21 @@
|
||||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry shr_oob(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .s16 temp;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.s16 temp, [in_addr];
|
||||
shr.s16 temp, temp, 16;
|
||||
st.s16 [out_addr], temp;
|
||||
ret;
|
||||
}
|
21
ptx/src/test/spirv_run/tanh.ptx
Normal file
21
ptx/src/test/spirv_run/tanh.ptx
Normal file
@ -0,0 +1,21 @@
|
||||
.version 7.0
|
||||
.target sm_75
|
||||
.address_size 64
|
||||
|
||||
.visible .entry tanh(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .f32 temp1;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.f32 temp1, [in_addr];
|
||||
tanh.approx.f32 temp1, temp1;
|
||||
st.f32 [out_addr], temp1;
|
||||
ret;
|
||||
}
|
@ -428,6 +428,44 @@ ptx_parser_macros::generate_instruction_type!(
|
||||
},
|
||||
}
|
||||
},
|
||||
Set {
|
||||
data: SetData,
|
||||
arguments<T>: {
|
||||
dst: {
|
||||
repr: T,
|
||||
type: Type::from(data.dtype)
|
||||
},
|
||||
src1: {
|
||||
repr: T,
|
||||
type: Type::from(data.base.type_),
|
||||
},
|
||||
src2: {
|
||||
repr: T,
|
||||
type: Type::from(data.base.type_),
|
||||
}
|
||||
}
|
||||
},
|
||||
SetBool {
|
||||
data: SetBoolData,
|
||||
arguments<T>: {
|
||||
dst: {
|
||||
repr: T,
|
||||
type: Type::from(data.dtype)
|
||||
},
|
||||
src1: {
|
||||
repr: T,
|
||||
type: Type::from(data.base.base.type_),
|
||||
},
|
||||
src2: {
|
||||
repr: T,
|
||||
type: Type::from(data.base.base.type_),
|
||||
},
|
||||
src3: {
|
||||
repr: T,
|
||||
type: Type::from(ScalarType::Pred)
|
||||
}
|
||||
}
|
||||
},
|
||||
Setp {
|
||||
data: SetpData,
|
||||
arguments<T>: {
|
||||
@ -562,6 +600,14 @@ ptx_parser_macros::generate_instruction_type!(
|
||||
src2: T
|
||||
}
|
||||
},
|
||||
Tanh {
|
||||
type: Type::Scalar(data.clone()),
|
||||
data: ScalarType,
|
||||
arguments<T>: {
|
||||
dst: T,
|
||||
src: T
|
||||
}
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
@ -1247,6 +1293,11 @@ pub struct Mul24Details {
|
||||
pub control: Mul24Control,
|
||||
}
|
||||
|
||||
/// Modifiers of a `set` instruction without a predicate operand.
pub struct SetData {
|
||||
/// Type of the destination operand (the `.dtype` modifier).
pub dtype: ScalarType,
|
||||
/// Comparison description shared with `setp`; its `type_` is the type of
/// both source operands.
pub base: SetpData,
|
||||
}
|
||||
|
||||
pub struct SetpData {
|
||||
pub type_: ScalarType,
|
||||
pub flush_to_zero: Option<bool>,
|
||||
@ -1288,6 +1339,12 @@ impl SetpData {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Modifiers of a `set` instruction that also takes a predicate operand.
pub struct SetBoolData {
|
||||
/// Type of the destination operand (the `.dtype` modifier).
pub dtype: ScalarType,
|
||||
/// Comparison, boolean post-operation and predicate-negation flag shared
/// with the predicate-combining form of `setp`.
pub base: SetpBoolData,
|
||||
}
|
||||
|
||||
pub struct SetpBoolData {
|
||||
pub base: SetpData,
|
||||
pub bool_op: SetpBoolPostOp,
|
||||
|
@ -442,7 +442,7 @@ fn directive<'a, 'input>(
|
||||
)),
|
||||
(
|
||||
any,
|
||||
take_till(1.., |(token, _)| match token {
|
||||
take_till(1.., |(token, _)| match token {
|
||||
// visibility
|
||||
Token::DotExtern | Token::DotVisible | Token::DotWeak
|
||||
// methods
|
||||
@ -2201,6 +2201,43 @@ derive_parser!(
|
||||
.rnd: RawRoundingMode = { .rn };
|
||||
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#comparison-and-selection-instructions-set
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-comparison-instructions-set
|
||||
set.CmpOp{.ftz}.dtype.stype d, a, b => {
|
||||
let base = ast::SetpData::try_parse(state, cmpop, ftz, stype);
|
||||
let data = ast::SetData {
|
||||
base, dtype
|
||||
};
|
||||
ast::Instruction::Set {
|
||||
data,
|
||||
arguments: SetArgs { dst: d, src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
set.CmpOp.BoolOp{.ftz}.dtype.stype d, a, b, {!}c => {
|
||||
let (negate_src3, c) = c;
|
||||
let base = ast::SetpData::try_parse(state, cmpop, ftz, stype);
|
||||
let base = ast::SetpBoolData {
|
||||
base,
|
||||
bool_op: boolop,
|
||||
negate_src3
|
||||
};
|
||||
let data = ast::SetBoolData {
|
||||
base, dtype
|
||||
};
|
||||
ast::Instruction::SetBool {
|
||||
data,
|
||||
arguments: SetBoolArgs { dst: d, src1: a, src2: b, src3: c }
|
||||
}
|
||||
}
|
||||
|
||||
.CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge,
|
||||
.lo, .ls, .hi, .hs, // signed
|
||||
.equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only
|
||||
.BoolOp: SetpBoolPostOp = { .and, .or, .xor };
|
||||
.dtype: ScalarType = { .u32, .s32, .f32 };
|
||||
.stype: ScalarType = { .b16, .b32, .b64, .u16, .u32, .u64, .s16, .s32, .s64, .f32, .f64 };
|
||||
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp
|
||||
setp.CmpOp{.ftz}.type p[|q], a, b => {
|
||||
@ -3509,6 +3546,16 @@ derive_parser!(
|
||||
arguments: NanosleepArgs { src: t }
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-tanh
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-tanh
|
||||
tanh.approx.type d, a => {
|
||||
Instruction::Tanh {
|
||||
data: type_,
|
||||
arguments: TanhArgs { dst: d, src: a }
|
||||
}
|
||||
}
|
||||
.type: ScalarType = { .f32, .f16, .f16x2, .bf16, .bf16x2 };
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
|
Reference in New Issue
Block a user