Andrzej Janik
2025-07-23 22:15:23 +00:00
parent eb2d1f81fb
commit 63f02c4158
8 changed files with 94 additions and 0 deletions


@@ -640,6 +640,7 @@ impl<'a> MethodEmitContext<'a> {
            ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments),
            ast::Instruction::Membar { data } => self.emit_membar(data),
            ast::Instruction::Trap {} => Err(error_todo_msg("Trap is not implemented yet")),
            ast::Instruction::Tanh { data, arguments } => self.emit_tanh(data, arguments),
            // replaced by a function call
            ast::Instruction::Bfe { .. }
            | ast::Instruction::Bar { .. }
@@ -2787,6 +2788,26 @@ impl<'a> MethodEmitContext<'a> {
        Ok(())
    }
    // TODO: revisit this on gfx1250, which has native tanh support
    fn emit_tanh(
        &mut self,
        data: ast::ScalarType,
        arguments: ast::TanhArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        let src = self.resolver.value(arguments.src)?;
        let llvm_type = get_scalar_type(self.context, data);
        let name = format!("__ocml_tanh_{}\0", LLVMTypeDisplay(data));
        let tanh = self.emit_intrinsic(
            unsafe { CStr::from_bytes_with_nul_unchecked(name.as_bytes()) },
            Some(arguments.dst),
            Some(&data.into()),
            vec![(src, llvm_type)],
        )?;
        // Mark the call with the approx-func fast-math flag (shows up as
        // `afn` in the emitted IR); not sure if it ultimately does anything
        unsafe { LLVMZludaSetFastMathFlags(tanh, LLVMZludaFastMathApproxFunc) }
        Ok(())
    }
    /*
    // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
    // Should be available in LLVM 19
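
For context: `emit_tanh` lowers PTX `tanh.approx` to a call into ROCm's OCML math library, deriving the symbol name from the LLVM spelling of the scalar type. A minimal sketch of that naming scheme, assuming only what the diff shows (the helper below is illustrative; the real code formats the suffix via `LLVMTypeDisplay`):

    // Hypothetical helper mirroring the format! call in emit_tanh;
    // the trailing NUL lets the buffer be viewed as a CStr without a copy.
    fn ocml_tanh_symbol(suffix: &str) -> String {
        format!("__ocml_tanh_{}\0", suffix)
    }

    fn main() {
        // An f32 tanh therefore becomes a call to @__ocml_tanh_f32,
        // matching the declaration in the tanh.ll test below.
        assert_eq!(ocml_tanh_symbol("f32"), "__ocml_tanh_f32\0");
    }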


@@ -184,6 +184,7 @@ fn run_instruction<'input>(
            data: ast::ArithDetails::Integer(..),
            ..
        }
        | ast::Instruction::Tanh { .. }
        | ast::Instruction::Trap {}
        | ast::Instruction::Xor { .. } => result.push(Statement::Instruction(instruction)),
        ast::Instruction::Add {


@@ -1980,6 +1980,7 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
            InstructionModes::from_rtz_special(data)
        },
        ast::Instruction::Cvt { data, .. } => InstructionModes::from_cvt(data),
        ast::Instruction::Tanh { data, .. } => InstructionModes::from_ftz(*data, Some(false)),
    }
}

ptx/src/test/ll/tanh.ll (new file)

@@ -0,0 +1,31 @@
define amdgpu_kernel void @tanh(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #0 {
  %"32" = alloca i64, align 8, addrspace(5)
  %"33" = alloca i64, align 8, addrspace(5)
  %"34" = alloca float, align 4, addrspace(5)
  br label %1

1:                                                ; preds = %0
  br label %"29"

"29":                                             ; preds = %1
  %"35" = load i64, ptr addrspace(4) %"30", align 4
  store i64 %"35", ptr addrspace(5) %"32", align 4
  %"36" = load i64, ptr addrspace(4) %"31", align 4
  store i64 %"36", ptr addrspace(5) %"33", align 4
  %"38" = load i64, ptr addrspace(5) %"32", align 4
  %"43" = inttoptr i64 %"38" to ptr
  %"37" = load float, ptr %"43", align 4
  store float %"37", ptr addrspace(5) %"34", align 4
  %"40" = load float, ptr addrspace(5) %"34", align 4
  %"39" = call afn float @__ocml_tanh_f32(float %"40")
  store float %"39", ptr addrspace(5) %"34", align 4
  %"41" = load i64, ptr addrspace(5) %"33", align 4
  %"42" = load float, ptr addrspace(5) %"34", align 4
  %"44" = inttoptr i64 %"41" to ptr
  store float %"42", ptr %"44", align 4
  ret void
}
declare float @__ocml_tanh_f32(float)
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }


@@ -298,6 +298,7 @@ test_ptx!(
);
test_ptx!(multiple_return, [5u32], [6u32, 123u32]);
test_ptx!(warp_sz, [0u8], [32u8]);
test_ptx!(tanh, [f32::INFINITY], [1.0f32]);
test_ptx!(assertfail);
// TODO: not yet supported
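
The new case feeds `f32::INFINITY` and expects exactly `1.0f32`: tanh saturates at ±1, so even an approximate implementation must return the exact limit here. A host-side sanity check of that expectation (plain Rust, just illustrating the math, not the test harness):

    fn main() {
        // tanh(+inf) is exactly 1.0 and tanh(-inf) is exactly -1.0,
        // which is what the PTX test below relies on.
        assert_eq!(f32::INFINITY.tanh(), 1.0f32);
        assert_eq!(f32::NEG_INFINITY.tanh(), -1.0f32);
    }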


@@ -0,0 +1,21 @@
.version 7.0
.target sm_75
.address_size 64

.visible .entry tanh(
    .param .u64 input,
    .param .u64 output
)
{
    .reg .u64 in_addr;
    .reg .u64 out_addr;
    .reg .f32 temp1;

    ld.param.u64 in_addr, [input];
    ld.param.u64 out_addr, [output];
    ld.f32 temp1, [in_addr];
    tanh.approx.f32 temp1, temp1;
    st.f32 [out_addr], temp1;
    ret;
}


@@ -594,6 +594,14 @@ ptx_parser_macros::generate_instruction_type!(
                src2: T
            }
        },
        Tanh {
            type: Type::Scalar(data.clone()),
            data: ScalarType,
            arguments<T>: {
                dst: T,
                src: T
            }
        },
    }
);
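
A rough sketch of the shapes this entry implies, inferred from how `emit_tanh` consumes them rather than from the macro's actual expansion:

    // Hypothetical approximation; the real definitions are generated by
    // ptx_parser_macros::generate_instruction_type!.
    pub struct TanhArgs<T> {
        pub dst: T,
        pub src: T,
    }

    // ...plus a Tanh { data: ScalarType, arguments: TanhArgs<T> } variant
    // on the instruction enum. `type: Type::Scalar(data.clone())` declares
    // that both operands carry the scalar type stored in `data`.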


@@ -3539,6 +3539,16 @@ derive_parser!(
        }
    }
    .mode: ShuffleMode = { .up, .down, .bfly, .idx };
    // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-tanh
    // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-tanh
    tanh.approx.type d, a => {
        Instruction::Tanh {
            data: type_,
            arguments: TanhArgs { dst: d, src: a }
        }
    }
    .type: ScalarType = { .f32, .f16, .f16x2, .bf16, .bf16x2 };
);
#[cfg(test)]