mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 14:57:43 +03:00
Add tanh
This commit is contained in:
@ -640,6 +640,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||
ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments),
|
||||
ast::Instruction::Membar { data } => self.emit_membar(data),
|
||||
ast::Instruction::Trap {} => Err(error_todo_msg("Trap is not implemented yet")),
|
||||
ast::Instruction::Tanh { data, arguments } => self.emit_tanh(data, arguments),
|
||||
// replaced by a function call
|
||||
ast::Instruction::Bfe { .. }
|
||||
| ast::Instruction::Bar { .. }
|
||||
@ -2787,6 +2788,26 @@ impl<'a> MethodEmitContext<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO: revisit this on gfx1250 which has native tanh support
|
||||
// Lowers a PTX `tanh.approx.<type>` instruction into a call to the ROCm OCML
// device-library routine `__ocml_tanh_<type>` (e.g. `__ocml_tanh_f32`), since
// the currently targeted GPUs have no native tanh instruction (see note above).
//
// `data` is the scalar operand type and selects the OCML overload; `arguments`
// carries the destination and source operands. Returns a `TranslateError` if
// the source operand cannot be resolved or intrinsic emission fails.
fn emit_tanh(
|
||||
&mut self,
|
||||
data: ast::ScalarType,
|
||||
arguments: ast::TanhArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let src = self.resolver.value(arguments.src)?;
|
||||
let llvm_type = get_scalar_type(self.context, data);
|
||||
// Build the mangled OCML symbol name with an explicit trailing NUL so it
// can be passed to the C API below without re-allocating.
let name = format!("__ocml_tanh_{}\0", LLVMTypeDisplay(data));
|
||||
let tanh = self.emit_intrinsic(
|
||||
// SAFETY: `name` was built by the `format!` above with a trailing `\0`,
// and `LLVMTypeDisplay` is presumed to emit no interior NULs (TODO:
// confirm), so the bytes form a valid NUL-terminated C string.
unsafe { CStr::from_bytes_with_nul_unchecked(name.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
Some(&data.into()),
|
||||
vec![(src, llvm_type)],
|
||||
)?;
|
||||
// `tanh.approx` permits an approximate result, so tag the call with the
// approx-func fast-math flag (emitted as `afn` in the IR).
// Not sure if it ultimately does anything
|
||||
// SAFETY: `tanh` is the call instruction just returned by
// `emit_intrinsic`; setting fast-math flags on a valid instruction is sound.
unsafe { LLVMZludaSetFastMathFlags(tanh, LLVMZludaFastMathApproxFunc) }
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/*
|
||||
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
|
||||
// Should be available in LLVM 19
|
||||
|
@ -184,6 +184,7 @@ fn run_instruction<'input>(
|
||||
data: ast::ArithDetails::Integer(..),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Tanh { .. }
|
||||
| ast::Instruction::Trap {}
|
||||
| ast::Instruction::Xor { .. } => result.push(Statement::Instruction(instruction)),
|
||||
ast::Instruction::Add {
|
||||
|
@ -1980,6 +1980,7 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
||||
InstructionModes::from_rtz_special(data)
|
||||
},
|
||||
ast::Instruction::Cvt { data, .. } => InstructionModes::from_cvt(data),
|
||||
ast::Instruction::Tanh { data, .. } => InstructionModes::from_ftz(*data, Some(false)),
|
||||
}
|
||||
}
|
||||
|
||||
|
31
ptx/src/test/ll/tanh.ll
Normal file
31
ptx/src/test/ll/tanh.ll
Normal file
@ -0,0 +1,31 @@
|
||||
; Expected LLVM IR for ptx/src/test/spirv_run/tanh.ptx: the PTX
; `tanh.approx.f32` instruction lowers to an `afn`-flagged call to the
; ROCm OCML device-library routine `__ocml_tanh_f32`.
define amdgpu_kernel void @tanh(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #0 {
|
||||
%"32" = alloca i64, align 8, addrspace(5)
|
||||
%"33" = alloca i64, align 8, addrspace(5)
|
||||
%"34" = alloca float, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"29"
|
||||
|
||||
"29": ; preds = %1
|
||||
%"35" = load i64, ptr addrspace(4) %"30", align 4
|
||||
store i64 %"35", ptr addrspace(5) %"32", align 4
|
||||
%"36" = load i64, ptr addrspace(4) %"31", align 4
|
||||
store i64 %"36", ptr addrspace(5) %"33", align 4
|
||||
%"38" = load i64, ptr addrspace(5) %"32", align 4
|
||||
%"43" = inttoptr i64 %"38" to ptr
|
||||
%"37" = load float, ptr %"43", align 4
|
||||
store float %"37", ptr addrspace(5) %"34", align 4
|
||||
%"40" = load float, ptr addrspace(5) %"34", align 4
|
||||
; tanh.approx.f32 -> approximate (afn) OCML library call
%"39" = call afn float @__ocml_tanh_f32(float %"40")
|
||||
store float %"39", ptr addrspace(5) %"34", align 4
|
||||
%"41" = load i64, ptr addrspace(5) %"33", align 4
|
||||
%"42" = load float, ptr addrspace(5) %"34", align 4
|
||||
%"44" = inttoptr i64 %"41" to ptr
|
||||
store float %"42", ptr %"44", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
declare float @__ocml_tanh_f32(float)
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
@ -298,6 +298,7 @@ test_ptx!(
|
||||
);
|
||||
test_ptx!(multiple_return, [5u32], [6u32, 123u32]);
|
||||
test_ptx!(warp_sz, [0u8], [32u8]);
|
||||
test_ptx!(tanh, [f32::INFINITY], [1.0f32]);
|
||||
|
||||
test_ptx!(assertfail);
|
||||
// TODO: not yet supported
|
||||
|
21
ptx/src/test/spirv_run/tanh.ptx
Normal file
21
ptx/src/test/spirv_run/tanh.ptx
Normal file
@ -0,0 +1,21 @@
|
||||
// Test kernel: loads one f32 from *input, applies tanh.approx.f32, and
// stores the result to *output. Driven by test_ptx!(tanh, [f32::INFINITY], [1.0f32]).
.version 7.0
|
||||
.target sm_75
|
||||
.address_size 64
|
||||
|
||||
.visible .entry tanh(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .f32 temp1;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.f32 temp1, [in_addr];
|
||||
// approximate hyperbolic tangent; expected to map INF -> 1.0
tanh.approx.f32 temp1, temp1;
|
||||
st.f32 [out_addr], temp1;
|
||||
ret;
|
||||
}
|
@ -594,6 +594,14 @@ ptx_parser_macros::generate_instruction_type!(
|
||||
src2: T
|
||||
}
|
||||
},
|
||||
Tanh {
|
||||
type: Type::Scalar(data.clone()),
|
||||
data: ScalarType,
|
||||
arguments<T>: {
|
||||
dst: T,
|
||||
src: T
|
||||
}
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -3539,6 +3539,16 @@ derive_parser!(
|
||||
}
|
||||
}
|
||||
.mode: ShuffleMode = { .up, .down, .bfly, .idx };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-tanh
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-tanh
|
||||
tanh.approx.type d, a => {
|
||||
Instruction::Tanh {
|
||||
data: type_,
|
||||
arguments: TanhArgs { dst: d, src: a }
|
||||
}
|
||||
}
|
||||
.type: ScalarType = { .f32, .f16, .f16x2, .bf16, .bf16x2 };
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
|
Reference in New Issue
Block a user