mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-22 03:36:27 +03:00
Implement div
This commit is contained in:
@ -18,6 +18,12 @@
|
||||
// while with plain LLVM-C it's just:
|
||||
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
|
||||
|
||||
// AMDGPU LLVM backend support for llvm.experimental.constrained.* is incomplete.
|
||||
// Emitting @llvm.experimental.constrained.fdiv.f32(...) makes LLVm fail with
|
||||
// "LLVM ERROR: unsupported libcall legalization". Running with "-mllvm -print-before-all"
|
||||
// shows it fails inside amdgpu-isel. You can get a little bit furthr with "-mllvm -global-isel",
|
||||
// but it will too fail similarly, but with "unable to legalize instruction"
|
||||
|
||||
use std::array::TryFromSliceError;
|
||||
use std::convert::TryInto;
|
||||
use std::ffi::{CStr, NulError};
|
||||
@ -534,7 +540,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||
ast::Instruction::Bar { .. } => todo!(),
|
||||
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
|
||||
ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
|
||||
ast::Instruction::Div { .. } => todo!(),
|
||||
ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments),
|
||||
ast::Instruction::Neg { .. } => todo!(),
|
||||
ast::Instruction::Sin { .. } => todo!(),
|
||||
ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
|
||||
@ -626,7 +632,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
|
||||
let _conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
|
||||
&& to_type.kind() == ast::ScalarKind::Signed
|
||||
{
|
||||
if to_type.size_of() >= from_type.size_of() {
|
||||
@ -1086,6 +1092,147 @@ impl<'a> MethodEmitContext<'a> {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_div(
|
||||
&mut self,
|
||||
data: ptx_parser::DivDetails,
|
||||
arguments: ptx_parser::DivArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let integer_div = match data {
|
||||
ptx_parser::DivDetails::Unsigned(_) => LLVMBuildUDiv,
|
||||
ptx_parser::DivDetails::Signed(_) => LLVMBuildSDiv,
|
||||
ptx_parser::DivDetails::Float(float_div) => {
|
||||
return self.emit_div_float(float_div, arguments)
|
||||
}
|
||||
};
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
integer_div(self.builder, src1, src2, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_div_float(
|
||||
&mut self,
|
||||
float_div: ptx_parser::DivFloatDetails,
|
||||
arguments: ptx_parser::DivArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let builder = self.builder;
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
let _rnd = match float_div.kind {
|
||||
ptx_parser::DivFloatKind::Approx => ast::RoundingMode::NearestEven,
|
||||
ptx_parser::DivFloatKind::ApproxFull => ast::RoundingMode::NearestEven,
|
||||
ptx_parser::DivFloatKind::Rounding(rounding_mode) => rounding_mode,
|
||||
};
|
||||
let approx = match float_div.kind {
|
||||
ptx_parser::DivFloatKind::Approx => {
|
||||
LLVMZludaFastMathAllowReciprocal | LLVMZludaFastMathApproxFunc
|
||||
}
|
||||
ptx_parser::DivFloatKind::ApproxFull => LLVMZludaFastMathNone,
|
||||
ptx_parser::DivFloatKind::Rounding(_) => LLVMZludaFastMathNone,
|
||||
};
|
||||
let fdiv = self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildFDiv(builder, src1, src2, dst)
|
||||
});
|
||||
unsafe { LLVMZludaSetFastMathFlags(fdiv, approx) };
|
||||
if let ptx_parser::DivFloatKind::ApproxFull = float_div.kind {
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-div:
|
||||
// div.full.f32 implements a relatively fast, full-range approximation that scales
|
||||
// operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not
|
||||
// support rounding modifiers. The maximum ulp error is 2 across the full range of
|
||||
// inputs.
|
||||
// https://llvm.org/docs/LangRef.html#fpmath-metadata
|
||||
let fpmath_value =
|
||||
unsafe { LLVMConstReal(get_scalar_type(self.context, ast::ScalarType::F32), 2.0) };
|
||||
let fpmath_value = unsafe { LLVMValueAsMetadata(fpmath_value) };
|
||||
let mut md_node_content = [fpmath_value];
|
||||
let md_node = unsafe {
|
||||
LLVMMDNodeInContext2(
|
||||
self.context,
|
||||
md_node_content.as_mut_ptr(),
|
||||
md_node_content.len(),
|
||||
)
|
||||
};
|
||||
let md_node = unsafe { LLVMMetadataAsValue(self.context, md_node) };
|
||||
let kind = unsafe {
|
||||
LLVMGetMDKindIDInContext(
|
||||
self.context,
|
||||
"fpmath".as_ptr().cast(),
|
||||
"fpmath".len() as u32,
|
||||
)
|
||||
};
|
||||
unsafe { LLVMSetMetadata(fdiv, kind, md_node) };
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/*
|
||||
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
|
||||
// Should be available in LLVM 19
|
||||
fn with_rounding<T>(&mut self, rnd: ast::RoundingMode, fn_: impl FnOnce(&mut Self) -> T) -> T {
|
||||
let mut u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
|
||||
let void_type = unsafe { LLVMVoidTypeInContext(self.context) };
|
||||
let get_rounding = c"llvm.get.rounding";
|
||||
let get_rounding_fn_type = unsafe { LLVMFunctionType(u32_type, ptr::null_mut(), 0, 0) };
|
||||
let mut get_rounding_fn =
|
||||
unsafe { LLVMGetNamedFunction(self.module, get_rounding.as_ptr()) };
|
||||
if get_rounding_fn == ptr::null_mut() {
|
||||
get_rounding_fn = unsafe {
|
||||
LLVMAddFunction(self.module, get_rounding.as_ptr(), get_rounding_fn_type)
|
||||
};
|
||||
}
|
||||
let set_rounding = c"llvm.set.rounding";
|
||||
let set_rounding_fn_type = unsafe { LLVMFunctionType(void_type, &mut u32_type, 1, 0) };
|
||||
let mut set_rounding_fn =
|
||||
unsafe { LLVMGetNamedFunction(self.module, set_rounding.as_ptr()) };
|
||||
if set_rounding_fn == ptr::null_mut() {
|
||||
set_rounding_fn = unsafe {
|
||||
LLVMAddFunction(self.module, set_rounding.as_ptr(), set_rounding_fn_type)
|
||||
};
|
||||
}
|
||||
let mut preserved_rounding_mode = unsafe {
|
||||
LLVMBuildCall2(
|
||||
self.builder,
|
||||
get_rounding_fn_type,
|
||||
get_rounding_fn,
|
||||
ptr::null_mut(),
|
||||
0,
|
||||
LLVM_UNNAMED.as_ptr(),
|
||||
)
|
||||
};
|
||||
let mut requested_rounding = unsafe {
|
||||
LLVMConstInt(
|
||||
get_scalar_type(self.context, ast::ScalarType::B32),
|
||||
rounding_to_llvm(rnd) as u64,
|
||||
0,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
LLVMBuildCall2(
|
||||
self.builder,
|
||||
set_rounding_fn_type,
|
||||
set_rounding_fn,
|
||||
&mut requested_rounding,
|
||||
1,
|
||||
LLVM_UNNAMED.as_ptr(),
|
||||
)
|
||||
};
|
||||
let result = fn_(self);
|
||||
unsafe {
|
||||
LLVMBuildCall2(
|
||||
self.builder,
|
||||
set_rounding_fn_type,
|
||||
set_rounding_fn,
|
||||
&mut preserved_rounding_mode,
|
||||
1,
|
||||
LLVM_UNNAMED.as_ptr(),
|
||||
)
|
||||
};
|
||||
result
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
fn get_pointer_type<'ctx>(
|
||||
@ -1279,3 +1426,36 @@ impl ResolveIdent {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
struct ScalarTypeInLLVM(ast::ScalarType);
|
||||
|
||||
impl std::fmt::Display for ScalarTypeInLLVM {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self.0 {
|
||||
ast::ScalarType::Pred => write!(f, "i1"),
|
||||
ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"),
|
||||
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"),
|
||||
ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"),
|
||||
ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"),
|
||||
ptx_parser::ScalarType::B128 => write!(f, "i128"),
|
||||
ast::ScalarType::F16 => write!(f, "f16"),
|
||||
ptx_parser::ScalarType::BF16 => write!(f, "bfloat"),
|
||||
ast::ScalarType::F32 => write!(f, "f32"),
|
||||
ast::ScalarType::F64 => write!(f, "f64"),
|
||||
ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"),
|
||||
ast::ScalarType::F16x2 => write!(f, "v2f16"),
|
||||
ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rounding_to_llvm(this: ast::RoundingMode) -> u32 {
|
||||
match this {
|
||||
ptx_parser::RoundingMode::Zero => 0,
|
||||
ptx_parser::RoundingMode::NearestEven => 1,
|
||||
ptx_parser::RoundingMode::PositiveInf => 2,
|
||||
ptx_parser::RoundingMode::NegativeInf => 3,
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
Reference in New Issue
Block a user