Implement div

This commit is contained in:
Andrzej Janik
2024-10-11 16:27:36 +02:00
parent 9035c4a24d
commit c8b88f4483

View File

@ -18,6 +18,12 @@
// while with plain LLVM-C it's just:
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
// AMDGPU LLVM backend support for llvm.experimental.constrained.* is incomplete.
// Emitting @llvm.experimental.constrained.fdiv.f32(...) makes LLVM fail with
// "LLVM ERROR: unsupported libcall legalization". Running with "-mllvm -print-before-all"
// shows it fails inside amdgpu-isel. You can get a little bit further with "-mllvm -global-isel",
// but it too will fail similarly, with "unable to legalize instruction"
use std::array::TryFromSliceError;
use std::convert::TryInto;
use std::ffi::{CStr, NulError};
@ -534,7 +540,7 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Bar { .. } => todo!(),
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
ast::Instruction::Div { .. } => todo!(),
ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments),
ast::Instruction::Neg { .. } => todo!(),
ast::Instruction::Sin { .. } => todo!(),
ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
@ -626,7 +632,7 @@ impl<'a> MethodEmitContext<'a> {
});
Ok(())
} else {
let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
let _conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
&& to_type.kind() == ast::ScalarKind::Signed
{
if to_type.size_of() >= from_type.size_of() {
@ -1086,6 +1092,147 @@ impl<'a> MethodEmitContext<'a> {
}
Ok(())
}
/// Emits LLVM IR for a PTX `div` instruction.
///
/// Unsigned and signed integer variants are lowered to a single `LLVMBuildUDiv` /
/// `LLVMBuildSDiv` call. The floating-point variant is forwarded to
/// [`Self::emit_div_float`].
///
/// # Errors
///
/// Propagates any error returned by the resolver while looking up `src1` or `src2`,
/// or by `emit_div_float` for the floating-point variant.
fn emit_div(
    &mut self,
    data: ptx_parser::DivDetails,
    arguments: ptx_parser::DivArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    // Select the LLVM-C builder function for the integer cases; the float case
    // leaves this function early because it is handled separately.
    let build_int_div = match data {
        ptx_parser::DivDetails::Unsigned(_) => LLVMBuildUDiv,
        ptx_parser::DivDetails::Signed(_) => LLVMBuildSDiv,
        ptx_parser::DivDetails::Float(details) => {
            return self.emit_div_float(details, arguments);
        }
    };
    let builder = self.builder;
    let dividend = self.resolver.value(arguments.src1)?;
    let divisor = self.resolver.value(arguments.src2)?;
    // `with_result` hands the closure the name under which the result is registered.
    self.resolver.with_result(arguments.dst, |result_name| unsafe {
        build_int_div(builder, dividend, divisor, result_name)
    });
    Ok(())
}
fn emit_div_float(
&mut self,
float_div: ptx_parser::DivFloatDetails,
arguments: ptx_parser::DivArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let builder = self.builder;
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let _rnd = match float_div.kind {
ptx_parser::DivFloatKind::Approx => ast::RoundingMode::NearestEven,
ptx_parser::DivFloatKind::ApproxFull => ast::RoundingMode::NearestEven,
ptx_parser::DivFloatKind::Rounding(rounding_mode) => rounding_mode,
};
let approx = match float_div.kind {
ptx_parser::DivFloatKind::Approx => {
LLVMZludaFastMathAllowReciprocal | LLVMZludaFastMathApproxFunc
}
ptx_parser::DivFloatKind::ApproxFull => LLVMZludaFastMathNone,
ptx_parser::DivFloatKind::Rounding(_) => LLVMZludaFastMathNone,
};
let fdiv = self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildFDiv(builder, src1, src2, dst)
});
unsafe { LLVMZludaSetFastMathFlags(fdiv, approx) };
if let ptx_parser::DivFloatKind::ApproxFull = float_div.kind {
// https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-div:
// div.full.f32 implements a relatively fast, full-range approximation that scales
// operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not
// support rounding modifiers. The maximum ulp error is 2 across the full range of
// inputs.
// https://llvm.org/docs/LangRef.html#fpmath-metadata
let fpmath_value =
unsafe { LLVMConstReal(get_scalar_type(self.context, ast::ScalarType::F32), 2.0) };
let fpmath_value = unsafe { LLVMValueAsMetadata(fpmath_value) };
let mut md_node_content = [fpmath_value];
let md_node = unsafe {
LLVMMDNodeInContext2(
self.context,
md_node_content.as_mut_ptr(),
md_node_content.len(),
)
};
let md_node = unsafe { LLVMMetadataAsValue(self.context, md_node) };
let kind = unsafe {
LLVMGetMDKindIDInContext(
self.context,
"fpmath".as_ptr().cast(),
"fpmath".len() as u32,
)
};
unsafe { LLVMSetMetadata(fdiv, kind, md_node) };
}
Ok(())
}
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19
fn with_rounding<T>(&mut self, rnd: ast::RoundingMode, fn_: impl FnOnce(&mut Self) -> T) -> T {
let mut u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
let void_type = unsafe { LLVMVoidTypeInContext(self.context) };
let get_rounding = c"llvm.get.rounding";
let get_rounding_fn_type = unsafe { LLVMFunctionType(u32_type, ptr::null_mut(), 0, 0) };
let mut get_rounding_fn =
unsafe { LLVMGetNamedFunction(self.module, get_rounding.as_ptr()) };
if get_rounding_fn == ptr::null_mut() {
get_rounding_fn = unsafe {
LLVMAddFunction(self.module, get_rounding.as_ptr(), get_rounding_fn_type)
};
}
let set_rounding = c"llvm.set.rounding";
let set_rounding_fn_type = unsafe { LLVMFunctionType(void_type, &mut u32_type, 1, 0) };
let mut set_rounding_fn =
unsafe { LLVMGetNamedFunction(self.module, set_rounding.as_ptr()) };
if set_rounding_fn == ptr::null_mut() {
set_rounding_fn = unsafe {
LLVMAddFunction(self.module, set_rounding.as_ptr(), set_rounding_fn_type)
};
}
let mut preserved_rounding_mode = unsafe {
LLVMBuildCall2(
self.builder,
get_rounding_fn_type,
get_rounding_fn,
ptr::null_mut(),
0,
LLVM_UNNAMED.as_ptr(),
)
};
let mut requested_rounding = unsafe {
LLVMConstInt(
get_scalar_type(self.context, ast::ScalarType::B32),
rounding_to_llvm(rnd) as u64,
0,
)
};
unsafe {
LLVMBuildCall2(
self.builder,
set_rounding_fn_type,
set_rounding_fn,
&mut requested_rounding,
1,
LLVM_UNNAMED.as_ptr(),
)
};
let result = fn_(self);
unsafe {
LLVMBuildCall2(
self.builder,
set_rounding_fn_type,
set_rounding_fn,
&mut preserved_rounding_mode,
1,
LLVM_UNNAMED.as_ptr(),
)
};
result
}
*/
}
fn get_pointer_type<'ctx>(
@ -1279,3 +1426,36 @@ impl ResolveIdent {
}
}
}
/*
struct ScalarTypeInLLVM(ast::ScalarType);
impl std::fmt::Display for ScalarTypeInLLVM {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.0 {
ast::ScalarType::Pred => write!(f, "i1"),
ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"),
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"),
ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"),
ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"),
ptx_parser::ScalarType::B128 => write!(f, "i128"),
ast::ScalarType::F16 => write!(f, "f16"),
ptx_parser::ScalarType::BF16 => write!(f, "bfloat"),
ast::ScalarType::F32 => write!(f, "f32"),
ast::ScalarType::F64 => write!(f, "f64"),
ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"),
ast::ScalarType::F16x2 => write!(f, "v2f16"),
ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"),
}
}
}
fn rounding_to_llvm(this: ast::RoundingMode) -> u32 {
match this {
ptx_parser::RoundingMode::Zero => 0,
ptx_parser::RoundingMode::NearestEven => 1,
ptx_parser::RoundingMode::PositiveInf => 2,
ptx_parser::RoundingMode::NegativeInf => 3,
}
}
*/