mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 14:57:43 +03:00
Add fp saturation, fix various bugs in cvt instruction exposed by ptx_tests (#379)
This commit is contained in:
@ -307,14 +307,14 @@ impl From<libloading::Error> for Error {
|
||||
}
|
||||
|
||||
impl From<comgr2::amd_comgr_status_s> for Error {
|
||||
fn from(_: comgr2::amd_comgr_status_s) -> Self {
|
||||
todo!()
|
||||
fn from(status: comgr2::amd_comgr_status_s) -> Self {
|
||||
Error(status.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<comgr3::amd_comgr_status_s> for Error {
|
||||
fn from(_: comgr3::amd_comgr_status_s) -> Self {
|
||||
todo!()
|
||||
fn from(status: comgr3::amd_comgr_status_s) -> Self {
|
||||
Error(status.0)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -518,6 +518,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||
Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
|
||||
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
|
||||
Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?,
|
||||
Statement::FpSaturate { dst, src, type_ } => self.emit_fp_saturate(type_, dst, src)?,
|
||||
})
|
||||
}
|
||||
|
||||
@ -590,7 +591,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||
inst: ast::Instruction<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match inst {
|
||||
ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments),
|
||||
ast::Instruction::Mov { data: _, arguments } => self.emit_mov(arguments),
|
||||
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
|
||||
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
|
||||
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
|
||||
@ -836,7 +837,13 @@ impl<'a> MethodEmitContext<'a> {
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
let fn_ = match data {
|
||||
ast::ArithDetails::Integer(..) => LLVMBuildAdd,
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
saturate: true,
|
||||
type_,
|
||||
}) => return self.emit_add_sat(type_, arguments),
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
saturate: false, ..
|
||||
}) => LLVMBuildAdd,
|
||||
ast::ArithDetails::Float(..) => LLVMBuildFAdd,
|
||||
};
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
@ -917,11 +924,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mov(
|
||||
&mut self,
|
||||
_data: ast::MovDetails,
|
||||
arguments: ast::MovArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
fn emit_mov(&mut self, arguments: ast::MovArgs<SpirvWord>) -> Result<(), TranslateError> {
|
||||
self.resolver
|
||||
.register(arguments.dst, self.resolver.value(arguments.src)?);
|
||||
Ok(())
|
||||
@ -1612,32 +1615,40 @@ impl<'a> MethodEmitContext<'a> {
|
||||
ptx_parser::CvtMode::SignExtend => LLVMBuildSExt,
|
||||
ptx_parser::CvtMode::Truncate => LLVMBuildTrunc,
|
||||
ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
|
||||
ptx_parser::CvtMode::SaturateUnsignedToSigned => {
|
||||
ptx_parser::CvtMode::IntSaturateToSigned => {
|
||||
return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments)
|
||||
}
|
||||
ptx_parser::CvtMode::SaturateSignedToUnsigned => {
|
||||
ptx_parser::CvtMode::IntSaturateToUnsigned => {
|
||||
return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments)
|
||||
}
|
||||
ptx_parser::CvtMode::FPExtend { .. } => LLVMBuildFPExt,
|
||||
ptx_parser::CvtMode::FPTruncate { .. } => LLVMBuildFPTrunc,
|
||||
ptx_parser::CvtMode::FPRound {
|
||||
integer_rounding, ..
|
||||
integer_rounding: None,
|
||||
flush_to_zero: None | Some(false),
|
||||
..
|
||||
} => {
|
||||
return self.emit_cvt_float_to_int(
|
||||
data.from,
|
||||
data.to,
|
||||
integer_rounding,
|
||||
arguments,
|
||||
Some(LLVMBuildFPToSI),
|
||||
)
|
||||
return self.emit_mov(ast::MovArgs {
|
||||
dst: arguments.dst,
|
||||
src: arguments.src,
|
||||
})
|
||||
}
|
||||
ptx_parser::CvtMode::FPRound {
|
||||
integer_rounding: None,
|
||||
flush_to_zero: Some(true),
|
||||
..
|
||||
} => return self.flush_denormals(data.to, arguments.src, arguments.dst),
|
||||
ptx_parser::CvtMode::FPRound {
|
||||
integer_rounding: Some(rounding),
|
||||
..
|
||||
} => return self.emit_cvt_float_to_int(data.from, data.to, rounding, arguments, None),
|
||||
ptx_parser::CvtMode::SignedFromFP { rounding, .. } => {
|
||||
return self.emit_cvt_float_to_int(
|
||||
data.from,
|
||||
data.to,
|
||||
rounding,
|
||||
arguments,
|
||||
Some(LLVMBuildFPToSI),
|
||||
Some(true),
|
||||
)
|
||||
}
|
||||
ptx_parser::CvtMode::UnsignedFromFP { rounding, .. } => {
|
||||
@ -1646,13 +1657,13 @@ impl<'a> MethodEmitContext<'a> {
|
||||
data.to,
|
||||
rounding,
|
||||
arguments,
|
||||
Some(LLVMBuildFPToUI),
|
||||
Some(false),
|
||||
)
|
||||
}
|
||||
ptx_parser::CvtMode::FPFromSigned(_) => {
|
||||
ptx_parser::CvtMode::FPFromSigned { .. } => {
|
||||
return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildSIToFP)
|
||||
}
|
||||
ptx_parser::CvtMode::FPFromUnsigned(_) => {
|
||||
ptx_parser::CvtMode::FPFromUnsigned { .. } => {
|
||||
return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildUIToFP)
|
||||
}
|
||||
};
|
||||
@ -1669,27 +1680,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||
to: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
// This looks dodgy, but it's fine. MAX bit pattern is always 0b11..1,
|
||||
// so if it's downcast to a smaller type, it will be the maximum value
|
||||
// of the smaller type
|
||||
let max_value = match to {
|
||||
ptx_parser::ScalarType::S8 => i8::MAX as u64,
|
||||
ptx_parser::ScalarType::S16 => i16::MAX as u64,
|
||||
ptx_parser::ScalarType::S32 => i32::MAX as u64,
|
||||
ptx_parser::ScalarType::S64 => i64::MAX as u64,
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let from_llvm = get_scalar_type(self.context, from);
|
||||
let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
|
||||
let clamped = self.emit_intrinsic(
|
||||
c"llvm.umin",
|
||||
None,
|
||||
Some(&from.into()),
|
||||
vec![
|
||||
(self.resolver.value(arguments.src)?, from_llvm),
|
||||
(max, from_llvm),
|
||||
],
|
||||
)?;
|
||||
let clamped = self.emit_saturate_integer(from, to, &arguments)?;
|
||||
let resize_fn = if to.layout().size() >= from.layout().size() {
|
||||
LLVMBuildSExtOrBitCast
|
||||
} else {
|
||||
@ -1702,40 +1693,92 @@ impl<'a> MethodEmitContext<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_saturate_integer(
|
||||
&mut self,
|
||||
from: ptx_parser::ScalarType,
|
||||
to: ptx_parser::ScalarType,
|
||||
arguments: &ptx_parser::CvtArgs<SpirvWord>,
|
||||
) -> Result<LLVMValueRef, TranslateError> {
|
||||
let from_llvm = get_scalar_type(self.context, from);
|
||||
match from.kind() {
|
||||
ptx_parser::ScalarKind::Unsigned => {
|
||||
let max_value = match to {
|
||||
ptx_parser::ScalarType::U8 => u8::MAX as u64,
|
||||
ptx_parser::ScalarType::S8 => i8::MAX as u64,
|
||||
ptx_parser::ScalarType::U16 => u16::MAX as u64,
|
||||
ptx_parser::ScalarType::S16 => i16::MAX as u64,
|
||||
ptx_parser::ScalarType::U32 => u32::MAX as u64,
|
||||
ptx_parser::ScalarType::S32 => i32::MAX as u64,
|
||||
ptx_parser::ScalarType::U64 => u64::MAX as u64,
|
||||
ptx_parser::ScalarType::S64 => i64::MAX as u64,
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from));
|
||||
let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
|
||||
let clamped = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
None,
|
||||
Some(&from.into()),
|
||||
vec![
|
||||
(self.resolver.value(arguments.src)?, from_llvm),
|
||||
(max, from_llvm),
|
||||
],
|
||||
)?;
|
||||
Ok(clamped)
|
||||
}
|
||||
ptx_parser::ScalarKind::Signed => {
|
||||
let (min_value_from, max_value_from) = match from {
|
||||
ptx_parser::ScalarType::S8 => (i8::MIN as i128, i8::MAX as i128),
|
||||
ptx_parser::ScalarType::S16 => (i16::MIN as i128, i16::MAX as i128),
|
||||
ptx_parser::ScalarType::S32 => (i32::MIN as i128, i32::MAX as i128),
|
||||
ptx_parser::ScalarType::S64 => (i64::MIN as i128, i64::MAX as i128),
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let (min_value_to, max_value_to) = match to {
|
||||
ptx_parser::ScalarType::U8 => (u8::MIN as i128, u8::MAX as i128),
|
||||
ptx_parser::ScalarType::S8 => (i8::MIN as i128, i8::MAX as i128),
|
||||
ptx_parser::ScalarType::U16 => (u16::MIN as i128, u16::MAX as i128),
|
||||
ptx_parser::ScalarType::S16 => (i16::MIN as i128, i16::MAX as i128),
|
||||
ptx_parser::ScalarType::U32 => (u32::MIN as i128, u32::MAX as i128),
|
||||
ptx_parser::ScalarType::S32 => (i32::MIN as i128, i32::MAX as i128),
|
||||
ptx_parser::ScalarType::U64 => (u64::MIN as i128, u64::MAX as i128),
|
||||
ptx_parser::ScalarType::S64 => (i64::MIN as i128, i64::MAX as i128),
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let min_value = min_value_from.max(min_value_to);
|
||||
let max_value = max_value_from.min(max_value_to);
|
||||
let max_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from));
|
||||
let min = unsafe { LLVMConstInt(from_llvm, min_value as u64, 1) };
|
||||
let min_intrinsic = format!("llvm.smin.{}\0", LLVMTypeDisplay(from));
|
||||
let max = unsafe { LLVMConstInt(from_llvm, max_value as u64, 1) };
|
||||
let clamped = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(max_intrinsic.as_bytes()) },
|
||||
None,
|
||||
Some(&from.into()),
|
||||
vec![
|
||||
(self.resolver.value(arguments.src)?, from_llvm),
|
||||
(min, from_llvm),
|
||||
],
|
||||
)?;
|
||||
let clamped = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(min_intrinsic.as_bytes()) },
|
||||
None,
|
||||
Some(&from.into()),
|
||||
vec![(clamped, from_llvm), (max, from_llvm)],
|
||||
)?;
|
||||
Ok(clamped)
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_cvt_signed_to_unsigned_sat(
|
||||
&mut self,
|
||||
from: ptx_parser::ScalarType,
|
||||
to: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let from_llvm = get_scalar_type(self.context, from);
|
||||
let zero = unsafe { LLVMConstInt(from_llvm, 0, 0) };
|
||||
let zero_clamp_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from));
|
||||
let zero_clamped = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) },
|
||||
None,
|
||||
Some(&from.into()),
|
||||
vec![
|
||||
(self.resolver.value(arguments.src)?, from_llvm),
|
||||
(zero, from_llvm),
|
||||
],
|
||||
)?;
|
||||
// zero_clamped is now unsigned
|
||||
let max_value = match to {
|
||||
ptx_parser::ScalarType::U8 => u8::MAX as u64,
|
||||
ptx_parser::ScalarType::U16 => u16::MAX as u64,
|
||||
ptx_parser::ScalarType::U32 => u32::MAX as u64,
|
||||
ptx_parser::ScalarType::U64 => u64::MAX as u64,
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
|
||||
let max_clamp_intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from));
|
||||
let fully_clamped = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) },
|
||||
None,
|
||||
Some(&from.into()),
|
||||
vec![(zero_clamped, from_llvm), (max, from_llvm)],
|
||||
)?;
|
||||
let clamped = self.emit_saturate_integer(from, to, &arguments)?;
|
||||
let resize_fn = if to.layout().size() >= from.layout().size() {
|
||||
LLVMBuildZExtOrBitCast
|
||||
} else {
|
||||
@ -1743,7 +1786,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||
};
|
||||
let to_llvm = get_scalar_type(self.context, to);
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
resize_fn(self.builder, fully_clamped, to_llvm, dst)
|
||||
resize_fn(self.builder, clamped, to_llvm, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
@ -1754,18 +1797,89 @@ impl<'a> MethodEmitContext<'a> {
|
||||
to: ast::ScalarType,
|
||||
rounding: ast::RoundingMode,
|
||||
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
||||
llvm_cast: Option<
|
||||
unsafe extern "C" fn(
|
||||
arg1: LLVMBuilderRef,
|
||||
Val: LLVMValueRef,
|
||||
DestTy: LLVMTypeRef,
|
||||
Name: *const i8,
|
||||
) -> LLVMValueRef,
|
||||
>,
|
||||
signed_cast: Option<bool>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let dst_int_rounded =
|
||||
self.emit_fp_int_rounding(from, rounding, &arguments, signed_cast.is_some())?;
|
||||
// In PTX all the int-from-float casts are saturating casts. On the other hand, in LLVM,
|
||||
// out-of-range fptoui and fptosi have undefined behavior.
|
||||
// We could handle this all with llvm.fptosi.sat and llvm.fptoui.sat intrinsics, but
|
||||
// the problem is that, when using *.sat variants AMDGPU target _always_ emits saturation
|
||||
// checks. Often they are unnecessary because v_cvt_* instructions saturate anyway.
|
||||
// For that reason, all from-to combinations that we know have a direct corresponding
|
||||
// v_cvt_* instruction get special treatment
|
||||
let is_saturating_cast = match (to, from) {
|
||||
(ast::ScalarType::S16, ast::ScalarType::F16)
|
||||
| (ast::ScalarType::S32, ast::ScalarType::F32)
|
||||
| (ast::ScalarType::S32, ast::ScalarType::F64)
|
||||
| (ast::ScalarType::U16, ast::ScalarType::F16)
|
||||
| (ast::ScalarType::U32, ast::ScalarType::F32)
|
||||
| (ast::ScalarType::U32, ast::ScalarType::F64) => true,
|
||||
_ => false,
|
||||
};
|
||||
let signed_cast = match signed_cast {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
self.resolver.register(
|
||||
arguments.dst,
|
||||
dst_int_rounded.ok_or_else(error_unreachable)?,
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
if is_saturating_cast {
|
||||
let to = get_scalar_type(self.context, to);
|
||||
let src =
|
||||
dst_int_rounded.unwrap_or_else(|| self.resolver.value(arguments.src).unwrap());
|
||||
let llvm_cast = if signed_cast {
|
||||
LLVMBuildFPToSI
|
||||
} else {
|
||||
LLVMBuildFPToUI
|
||||
};
|
||||
let poisoned_dst = unsafe { llvm_cast(self.builder, src, to, LLVM_UNNAMED.as_ptr()) };
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildFreeze(self.builder, poisoned_dst, dst)
|
||||
});
|
||||
} else {
|
||||
let cvt_op = if to.kind() == ptx_parser::ScalarKind::Unsigned {
|
||||
"fptoui"
|
||||
} else {
|
||||
"fptosi"
|
||||
};
|
||||
let cast_intrinsic = format!(
|
||||
"llvm.{cvt_op}.sat.{}.{}\0",
|
||||
LLVMTypeDisplay(to),
|
||||
LLVMTypeDisplay(from)
|
||||
);
|
||||
let src =
|
||||
dst_int_rounded.unwrap_or_else(|| self.resolver.value(arguments.src).unwrap());
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
Some(&to.into()),
|
||||
vec![(src, get_scalar_type(self.context, from))],
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_fp_int_rounding(
|
||||
&mut self,
|
||||
from: ptx_parser::ScalarType,
|
||||
rounding: ptx_parser::RoundingMode,
|
||||
arguments: &ptx_parser::CvtArgs<SpirvWord>,
|
||||
will_saturate_with_cvt: bool,
|
||||
) -> Result<Option<LLVMValueRef>, TranslateError> {
|
||||
let prefix = match rounding {
|
||||
ptx_parser::RoundingMode::NearestEven => "llvm.roundeven",
|
||||
ptx_parser::RoundingMode::Zero => "llvm.trunc",
|
||||
ptx_parser::RoundingMode::Zero => {
|
||||
// cvt has round-to-zero semantics
|
||||
if will_saturate_with_cvt {
|
||||
return Ok(None);
|
||||
} else {
|
||||
"llvm.trunc"
|
||||
}
|
||||
}
|
||||
ptx_parser::RoundingMode::NegativeInf => "llvm.floor",
|
||||
ptx_parser::RoundingMode::PositiveInf => "llvm.ceil",
|
||||
};
|
||||
@ -1779,34 +1893,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||
get_scalar_type(self.context, from),
|
||||
)],
|
||||
)?;
|
||||
if let Some(llvm_cast) = llvm_cast {
|
||||
let to = get_scalar_type(self.context, to);
|
||||
let poisoned_dst =
|
||||
unsafe { llvm_cast(self.builder, rounded_float, to, LLVM_UNNAMED.as_ptr()) };
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildFreeze(self.builder, poisoned_dst, dst)
|
||||
});
|
||||
} else {
|
||||
self.resolver.register(arguments.dst, rounded_float);
|
||||
}
|
||||
// Using explicit saturation gives us worse codegen: it explicitly checks for out of bound
|
||||
// values and NaNs. Using non-saturated fptosi/fptoui emits v_cvt_<TO>_<FROM> which
|
||||
// saturates by default and we don't care about NaNs anyway
|
||||
/*
|
||||
let cast_intrinsic = format!(
|
||||
"{}.{}.{}\0",
|
||||
llvm_cast,
|
||||
LLVMTypeDisplay(to),
|
||||
LLVMTypeDisplay(from)
|
||||
);
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
&to.into(),
|
||||
vec![(rounded_float, get_scalar_type(self.context, from))],
|
||||
)?;
|
||||
*/
|
||||
Ok(())
|
||||
Ok(Some(rounded_float))
|
||||
}
|
||||
|
||||
fn emit_cvt_int_to_float(
|
||||
@ -2289,7 +2376,11 @@ impl<'a> MethodEmitContext<'a> {
|
||||
};
|
||||
let res_lo = self.emit_intrinsic(
|
||||
name_lo,
|
||||
if data.control == Mul24Control::Lo { Some(arguments.dst) } else { None },
|
||||
if data.control == Mul24Control::Lo {
|
||||
Some(arguments.dst)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
Some(&ast::Type::Scalar(data.type_)),
|
||||
vec![
|
||||
(src1, get_scalar_type(self.context, data.type_)),
|
||||
@ -2316,9 +2407,8 @@ impl<'a> MethodEmitContext<'a> {
|
||||
],
|
||||
)?;
|
||||
let shift_number = unsafe { LLVMConstInt(LLVMInt32TypeInContext(self.context), 16, 0) };
|
||||
let res_lo_shr = unsafe {
|
||||
LLVMBuildLShr(self.builder, res_lo, shift_number, LLVM_UNNAMED.as_ptr())
|
||||
};
|
||||
let res_lo_shr =
|
||||
unsafe { LLVMBuildLShr(self.builder, res_lo, shift_number, LLVM_UNNAMED.as_ptr()) };
|
||||
let res_hi_shl =
|
||||
unsafe { LLVMBuildShl(self.builder, res_hi, shift_number, LLVM_UNNAMED.as_ptr()) };
|
||||
|
||||
@ -2381,6 +2471,74 @@ impl<'a> MethodEmitContext<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_fp_saturate(
|
||||
&mut self,
|
||||
type_: ast::ScalarType,
|
||||
dst: SpirvWord,
|
||||
src: SpirvWord,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_type = get_scalar_type(self.context, type_);
|
||||
let zero = unsafe { LLVMConstReal(llvm_type, 0.0) };
|
||||
let one = unsafe { LLVMConstReal(llvm_type, 1.0) };
|
||||
let maxnum_intrinsic = format!("llvm.maxnum.{}\0", LLVMTypeDisplay(type_));
|
||||
let minnum_intrinsic = format!("llvm.minnum.{}\0", LLVMTypeDisplay(type_));
|
||||
let src = self.resolver.value(src)?;
|
||||
let maxnum = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(maxnum_intrinsic.as_bytes()) },
|
||||
None,
|
||||
Some(&type_.into()),
|
||||
vec![(src, llvm_type), (zero, llvm_type)],
|
||||
)?;
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(minnum_intrinsic.as_bytes()) },
|
||||
Some(dst),
|
||||
Some(&type_.into()),
|
||||
vec![(maxnum, llvm_type), (one, llvm_type)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_add_sat(
|
||||
&mut self,
|
||||
type_: ast::ScalarType,
|
||||
arguments: ast::AddArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_type = get_scalar_type(self.context, type_);
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
let op = if type_.kind() == ast::ScalarKind::Signed {
|
||||
"sadd"
|
||||
} else {
|
||||
"uadd"
|
||||
};
|
||||
let intrinsic = format!("llvm.{}.sat.{}\0", op, LLVMTypeDisplay(type_));
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
Some(&type_.into()),
|
||||
vec![(src1, llvm_type), (src2, llvm_type)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush_denormals(
|
||||
&mut self,
|
||||
type_: ptx_parser::ScalarType,
|
||||
src: SpirvWord,
|
||||
dst: SpirvWord,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_type = get_scalar_type(self.context, type_);
|
||||
let src = self.resolver.value(src)?;
|
||||
let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(type_));
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(dst),
|
||||
Some(&type_.into()),
|
||||
vec![(src, llvm_type)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/*
|
||||
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
|
||||
// Should be available in LLVM 19
|
||||
|
296
ptx/src/pass/insert_post_saturation.rs
Normal file
296
ptx/src/pass/insert_post_saturation.rs
Normal file
@ -0,0 +1,296 @@
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let mut new_statements = Vec::new();
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
for statement in statements {
|
||||
run_statement(resolver, &mut new_statements, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(new_statements)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match statement {
|
||||
Statement::Instruction(inst) => run_instruction(resolver, result, inst)?,
|
||||
statement => {
|
||||
result.push(statement);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_instruction<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
mut instruction: ast::Instruction<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match instruction {
|
||||
ast::Instruction::Abs { .. }
|
||||
| ast::Instruction::Activemask { .. }
|
||||
| ast::Instruction::Add {
|
||||
data:
|
||||
ast::ArithDetails::Float(ast::ArithFloat {
|
||||
saturate: false, ..
|
||||
}),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Add {
|
||||
data: ast::ArithDetails::Integer(..),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::And { .. }
|
||||
| ast::Instruction::Atom { .. }
|
||||
| ast::Instruction::AtomCas { .. }
|
||||
| ast::Instruction::Bar { .. }
|
||||
| ast::Instruction::Bfe { .. }
|
||||
| ast::Instruction::Bfi { .. }
|
||||
| ast::Instruction::Bra { .. }
|
||||
| ast::Instruction::Brev { .. }
|
||||
| ast::Instruction::Call { .. }
|
||||
| ast::Instruction::Clz { .. }
|
||||
| ast::Instruction::Cos { .. }
|
||||
| ast::Instruction::Cvt {
|
||||
data:
|
||||
ast::CvtDetails {
|
||||
mode:
|
||||
ast::CvtMode::ZeroExtend
|
||||
| ast::CvtMode::SignExtend
|
||||
| ast::CvtMode::Truncate
|
||||
| ast::CvtMode::Bitcast
|
||||
| ast::CvtMode::IntSaturateToSigned
|
||||
| ast::CvtMode::IntSaturateToUnsigned
|
||||
| ast::CvtMode::SignedFromFP { .. }
|
||||
| ast::CvtMode::UnsignedFromFP { .. }
|
||||
| ast::CvtMode::FPFromSigned {
|
||||
saturate: false, ..
|
||||
}
|
||||
| ast::CvtMode::FPFromUnsigned {
|
||||
saturate: false, ..
|
||||
}
|
||||
| ast::CvtMode::FPExtend {
|
||||
saturate: false, ..
|
||||
}
|
||||
| ast::CvtMode::FPTruncate {
|
||||
saturate: false, ..
|
||||
}
|
||||
| ast::CvtMode::FPRound {
|
||||
saturate: false, ..
|
||||
},
|
||||
..
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Cvta { .. }
|
||||
| ast::Instruction::Div { .. }
|
||||
| ast::Instruction::Ex2 { .. }
|
||||
| ast::Instruction::Fma {
|
||||
data: ast::ArithFloat {
|
||||
saturate: false, ..
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Ld { .. }
|
||||
| ast::Instruction::Lg2 { .. }
|
||||
| ast::Instruction::Mad {
|
||||
data:
|
||||
ast::MadDetails::Float(ast::ArithFloat {
|
||||
saturate: false, ..
|
||||
}),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Mad {
|
||||
data: ast::MadDetails::Integer { .. },
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Max { .. }
|
||||
| ast::Instruction::Membar { .. }
|
||||
| ast::Instruction::Min { .. }
|
||||
| ast::Instruction::Mov { .. }
|
||||
| ast::Instruction::Mul {
|
||||
data:
|
||||
ast::MulDetails::Float(ast::ArithFloat {
|
||||
saturate: false, ..
|
||||
}),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Integer { .. },
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Mul24 { .. }
|
||||
| ast::Instruction::Neg { .. }
|
||||
| ast::Instruction::Not { .. }
|
||||
| ast::Instruction::Or { .. }
|
||||
| ast::Instruction::Popc { .. }
|
||||
| ast::Instruction::Prmt { .. }
|
||||
| ast::Instruction::PrmtSlow { .. }
|
||||
| ast::Instruction::Rcp { .. }
|
||||
| ast::Instruction::Rem { .. }
|
||||
| ast::Instruction::Ret { .. }
|
||||
| ast::Instruction::Rsqrt { .. }
|
||||
| ast::Instruction::Selp { .. }
|
||||
| ast::Instruction::Setp { .. }
|
||||
| ast::Instruction::SetpBool { .. }
|
||||
| ast::Instruction::Shl { .. }
|
||||
| ast::Instruction::Shr { .. }
|
||||
| ast::Instruction::Sin { .. }
|
||||
| ast::Instruction::Sqrt { .. }
|
||||
| ast::Instruction::St { .. }
|
||||
| ast::Instruction::Sub {
|
||||
data:
|
||||
ast::ArithDetails::Float(ast::ArithFloat {
|
||||
saturate: false, ..
|
||||
}),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Sub {
|
||||
data: ast::ArithDetails::Integer(..),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Trap {}
|
||||
| ast::Instruction::Xor { .. } => result.push(Statement::Instruction(instruction)),
|
||||
ast::Instruction::Add {
|
||||
data:
|
||||
ast::ArithDetails::Float(ast::ArithFloat {
|
||||
saturate: true,
|
||||
type_,
|
||||
..
|
||||
}),
|
||||
arguments: ast::AddArgs { ref mut dst, .. },
|
||||
}
|
||||
| ast::Instruction::Fma {
|
||||
data:
|
||||
ast::ArithFloat {
|
||||
saturate: true,
|
||||
type_,
|
||||
..
|
||||
},
|
||||
arguments: ast::FmaArgs { ref mut dst, .. },
|
||||
}
|
||||
| ast::Instruction::Mad {
|
||||
data:
|
||||
ast::MadDetails::Float(ast::ArithFloat {
|
||||
saturate: true,
|
||||
type_,
|
||||
..
|
||||
}),
|
||||
arguments: ast::MadArgs { ref mut dst, .. },
|
||||
}
|
||||
| ast::Instruction::Mul {
|
||||
data:
|
||||
ast::MulDetails::Float(ast::ArithFloat {
|
||||
saturate: true,
|
||||
type_,
|
||||
..
|
||||
}),
|
||||
arguments: ast::MulArgs { ref mut dst, .. },
|
||||
}
|
||||
| ast::Instruction::Sub {
|
||||
data:
|
||||
ast::ArithDetails::Float(ast::ArithFloat {
|
||||
saturate: true,
|
||||
type_,
|
||||
..
|
||||
}),
|
||||
arguments: ast::SubArgs { ref mut dst, .. },
|
||||
}
|
||||
| ast::Instruction::Cvt {
|
||||
data:
|
||||
ast::CvtDetails {
|
||||
to: type_,
|
||||
mode: ast::CvtMode::FPExtend { saturate: true, .. },
|
||||
..
|
||||
},
|
||||
arguments: ast::CvtArgs { ref mut dst, .. },
|
||||
}
|
||||
| ast::Instruction::Cvt {
|
||||
data:
|
||||
ast::CvtDetails {
|
||||
to: type_,
|
||||
mode: ast::CvtMode::FPTruncate { saturate: true, .. },
|
||||
..
|
||||
},
|
||||
arguments: ast::CvtArgs { ref mut dst, .. },
|
||||
}
|
||||
| ast::Instruction::Cvt {
|
||||
data:
|
||||
ast::CvtDetails {
|
||||
to: type_,
|
||||
mode: ast::CvtMode::FPRound { saturate: true, .. },
|
||||
..
|
||||
},
|
||||
arguments: ast::CvtArgs { ref mut dst, .. },
|
||||
}
|
||||
| ast::Instruction::Cvt {
|
||||
data:
|
||||
ast::CvtDetails {
|
||||
to: type_,
|
||||
mode: ast::CvtMode::FPFromSigned { saturate: true, .. },
|
||||
..
|
||||
},
|
||||
arguments: ast::CvtArgs { ref mut dst, .. },
|
||||
}
|
||||
| ast::Instruction::Cvt {
|
||||
data:
|
||||
ast::CvtDetails {
|
||||
to: type_,
|
||||
mode: ast::CvtMode::FPFromUnsigned { saturate: true, .. },
|
||||
..
|
||||
},
|
||||
arguments: ast::CvtArgs { ref mut dst, .. },
|
||||
} => {
|
||||
let sat = get_post_saturation(resolver, type_, dst)?;
|
||||
result.push(Statement::Instruction(instruction));
|
||||
result.push(sat);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_post_saturation<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
type_: ast::ScalarType,
|
||||
old_dst: &mut SpirvWord,
|
||||
) -> Result<Statement<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let post_sat = resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg)));
|
||||
let dst = *old_dst;
|
||||
*old_dst = post_sat;
|
||||
Ok(Statement::FpSaturate {
|
||||
dst,
|
||||
src: post_sat,
|
||||
type_,
|
||||
})
|
||||
}
|
@ -167,24 +167,40 @@ impl InstructionModes {
|
||||
}
|
||||
}
|
||||
|
||||
fn mixed_ftz_f32(
|
||||
type_: ast::ScalarType,
|
||||
denormal: Option<DenormalMode>,
|
||||
rounding: Option<RoundingMode>,
|
||||
fn from_typed_denormal_rounding(
|
||||
from_type: ast::ScalarType,
|
||||
to_type: ast::ScalarType,
|
||||
denormal: DenormalMode,
|
||||
rounding: RoundingMode,
|
||||
) -> Self {
|
||||
if type_ != ast::ScalarType::F32 {
|
||||
Self {
|
||||
denormal_f16f64: denormal,
|
||||
rounding_f32: rounding,
|
||||
..Self::none()
|
||||
rounding_f32: Some(rounding),
|
||||
rounding_f16f64: Some(rounding),
|
||||
..Self::from_typed_denormal(from_type, to_type, denormal)
|
||||
}
|
||||
}
|
||||
|
||||
// This function accepts DenormalMode and not Option<DenormalMode> because
|
||||
// the semantics are slightly different.
|
||||
// * In instructions `None` means: flush-to-zero has not been explicitly requested
|
||||
// * In this pass `None` means: neither flush-to-zero, nor preserve is applicable
|
||||
fn from_typed_denormal(
|
||||
from_type: ast::ScalarType,
|
||||
to_type: ast::ScalarType,
|
||||
denormal: DenormalMode,
|
||||
) -> Self {
|
||||
let mut result = Self::none();
|
||||
if from_type == ast::ScalarType::F32 || to_type == ast::ScalarType::F32 {
|
||||
result.denormal_f32 = if denormal == DenormalMode::FlushToZero {
|
||||
Some(DenormalMode::FlushToZero)
|
||||
} else {
|
||||
Self {
|
||||
denormal_f32: denormal,
|
||||
rounding_f32: rounding,
|
||||
..Self::none()
|
||||
Some(DenormalMode::Preserve)
|
||||
};
|
||||
}
|
||||
if !(from_type == ast::ScalarType::F32 && to_type == ast::ScalarType::F32) {
|
||||
result.denormal_f16f64 = Some(DenormalMode::Preserve);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn from_arith_float(arith: &ast::ArithFloat) -> InstructionModes {
|
||||
@ -220,31 +236,52 @@ impl InstructionModes {
|
||||
| ast::CvtMode::SignExtend
|
||||
| ast::CvtMode::Truncate
|
||||
| ast::CvtMode::Bitcast
|
||||
| ast::CvtMode::SaturateUnsignedToSigned
|
||||
| ast::CvtMode::SaturateSignedToUnsigned => Self::none(),
|
||||
ast::CvtMode::FPExtend { flush_to_zero } => {
|
||||
Self::from_ftz(ast::ScalarType::F32, flush_to_zero)
|
||||
}
|
||||
| ast::CvtMode::IntSaturateToSigned
|
||||
| ast::CvtMode::IntSaturateToUnsigned => Self::none(),
|
||||
ast::CvtMode::FPExtend { flush_to_zero, .. } => Self::from_typed_denormal(
|
||||
cvt.from,
|
||||
cvt.to,
|
||||
flush_to_zero
|
||||
.map(DenormalMode::from_ftz)
|
||||
.unwrap_or(DenormalMode::Preserve),
|
||||
),
|
||||
ast::CvtMode::FPTruncate {
|
||||
rounding,
|
||||
flush_to_zero,
|
||||
}
|
||||
| ast::CvtMode::FPRound {
|
||||
integer_rounding: rounding,
|
||||
flush_to_zero,
|
||||
} => Self::mixed_ftz_f32(
|
||||
is_integer_rounding,
|
||||
..
|
||||
} => {
|
||||
let denormal_mode = match (is_integer_rounding, flush_to_zero) {
|
||||
(true, Some(true)) => DenormalMode::FlushToZero,
|
||||
_ => DenormalMode::Preserve,
|
||||
};
|
||||
Self::from_typed_denormal_rounding(
|
||||
cvt.from,
|
||||
cvt.to,
|
||||
flush_to_zero.map(DenormalMode::from_ftz),
|
||||
Some(RoundingMode::from_ast(rounding)),
|
||||
denormal_mode,
|
||||
RoundingMode::from_ast(rounding),
|
||||
)
|
||||
}
|
||||
ast::CvtMode::FPRound { flush_to_zero, .. } => Self::from_typed_denormal(
|
||||
cvt.from,
|
||||
cvt.to,
|
||||
flush_to_zero
|
||||
.map(DenormalMode::from_ftz)
|
||||
.unwrap_or(DenormalMode::Preserve),
|
||||
),
|
||||
// float to int contains rounding field, but it's not a rounding
|
||||
// mode but rather round-to-int operation that will be applied
|
||||
ast::CvtMode::SignedFromFP { flush_to_zero, .. }
|
||||
| ast::CvtMode::UnsignedFromFP { flush_to_zero, .. } => {
|
||||
Self::new(cvt.from, flush_to_zero.map(DenormalMode::from_ftz), None)
|
||||
}
|
||||
ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => {
|
||||
Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd)))
|
||||
| ast::CvtMode::UnsignedFromFP { flush_to_zero, .. } => Self::from_typed_denormal(
|
||||
cvt.from,
|
||||
cvt.from,
|
||||
flush_to_zero
|
||||
.map(DenormalMode::from_ftz)
|
||||
.unwrap_or(DenormalMode::Preserve),
|
||||
),
|
||||
ast::CvtMode::FPFromSigned { rounding, .. }
|
||||
| ast::CvtMode::FPFromUnsigned { rounding, .. } => {
|
||||
Self::new(cvt.to, None, Some(RoundingMode::from_ast(rounding)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -17,8 +17,9 @@ mod expand_operands;
|
||||
mod fix_special_registers2;
|
||||
mod hoist_globals;
|
||||
mod insert_explicit_load_store;
|
||||
mod instruction_mode_to_global_mode;
|
||||
mod insert_implicit_conversions2;
|
||||
mod insert_post_saturation;
|
||||
mod instruction_mode_to_global_mode;
|
||||
mod normalize_basic_blocks;
|
||||
mod normalize_identifiers2;
|
||||
mod normalize_predicates2;
|
||||
@ -51,6 +52,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
|
||||
let directives = resolve_function_pointers::run(directives)?;
|
||||
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
|
||||
let directives = expand_operands::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_post_saturation::run(&mut flat_resolver, directives)?;
|
||||
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
||||
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
|
||||
let directives = remove_unreachable_basic_blocks::run(directives)?;
|
||||
@ -202,6 +204,11 @@ enum Statement<I, P: ast::Operand> {
|
||||
VectorRead(VectorRead),
|
||||
VectorWrite(VectorWrite),
|
||||
SetMode(ModeRegister),
|
||||
FpSaturate {
|
||||
dst: SpirvWord,
|
||||
src: SpirvWord,
|
||||
type_: ast::ScalarType,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq, Clone, Copy)]
|
||||
@ -488,6 +495,21 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||
Statement::FunctionPointer(FunctionPointerDetails { dst, src })
|
||||
}
|
||||
Statement::SetMode(mode_register) => Statement::SetMode(mode_register),
|
||||
Statement::FpSaturate { dst, src, type_ } => {
|
||||
let dst = visitor.visit_ident(
|
||||
dst,
|
||||
Some((&type_.into(), ast::StateSpace::Reg)),
|
||||
true,
|
||||
false,
|
||||
)?;
|
||||
let src = visitor.visit_ident(
|
||||
src,
|
||||
Some((&type_.into(), ast::StateSpace::Reg)),
|
||||
false,
|
||||
false,
|
||||
)?;
|
||||
Statement::FpSaturate { dst, src, type_ }
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
51
ptx/src/test/ll/add_s32_sat.ll
Normal file
51
ptx/src/test/ll/add_s32_sat.ll
Normal file
@ -0,0 +1,51 @@
|
||||
define amdgpu_kernel void @add_s32_sat(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 {
|
||||
%"39" = alloca i64, align 8, addrspace(5)
|
||||
%"40" = alloca i64, align 8, addrspace(5)
|
||||
%"41" = alloca i32, align 4, addrspace(5)
|
||||
%"42" = alloca i32, align 4, addrspace(5)
|
||||
%"43" = alloca i32, align 4, addrspace(5)
|
||||
%"44" = alloca i32, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"36"
|
||||
|
||||
"36": ; preds = %1
|
||||
%"45" = load i64, ptr addrspace(4) %"37", align 4
|
||||
store i64 %"45", ptr addrspace(5) %"39", align 4
|
||||
%"46" = load i64, ptr addrspace(4) %"38", align 4
|
||||
store i64 %"46", ptr addrspace(5) %"40", align 4
|
||||
%"48" = load i64, ptr addrspace(5) %"39", align 4
|
||||
%"61" = inttoptr i64 %"48" to ptr
|
||||
%"47" = load i32, ptr %"61", align 4
|
||||
store i32 %"47", ptr addrspace(5) %"41", align 4
|
||||
%"49" = load i64, ptr addrspace(5) %"39", align 4
|
||||
%"62" = inttoptr i64 %"49" to ptr
|
||||
%"33" = getelementptr inbounds i8, ptr %"62", i64 4
|
||||
%"50" = load i32, ptr %"33", align 4
|
||||
store i32 %"50", ptr addrspace(5) %"42", align 4
|
||||
%"52" = load i32, ptr addrspace(5) %"41", align 4
|
||||
%"53" = load i32, ptr addrspace(5) %"42", align 4
|
||||
%"51" = call i32 @llvm.sadd.sat.i32(i32 %"52", i32 %"53")
|
||||
store i32 %"51", ptr addrspace(5) %"43", align 4
|
||||
%"55" = load i32, ptr addrspace(5) %"41", align 4
|
||||
%"56" = load i32, ptr addrspace(5) %"42", align 4
|
||||
%"54" = add i32 %"55", %"56"
|
||||
store i32 %"54", ptr addrspace(5) %"44", align 4
|
||||
%"57" = load i64, ptr addrspace(5) %"40", align 4
|
||||
%"58" = load i32, ptr addrspace(5) %"43", align 4
|
||||
%"63" = inttoptr i64 %"57" to ptr
|
||||
store i32 %"58", ptr %"63", align 4
|
||||
%"59" = load i64, ptr addrspace(5) %"40", align 4
|
||||
%"64" = inttoptr i64 %"59" to ptr
|
||||
%"35" = getelementptr inbounds i8, ptr %"64", i64 4
|
||||
%"60" = load i32, ptr addrspace(5) %"44", align 4
|
||||
store i32 %"60", ptr %"35", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
|
||||
declare i32 @llvm.sadd.sat.i32(i32, i32) #1
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
|
38
ptx/src/test/ll/cvt_rni_u16_f32.ll
Normal file
38
ptx/src/test/ll/cvt_rni_u16_f32.ll
Normal file
@ -0,0 +1,38 @@
|
||||
define amdgpu_kernel void @cvt_rni_u16_f32(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 {
|
||||
%"33" = alloca i64, align 8, addrspace(5)
|
||||
%"34" = alloca i64, align 8, addrspace(5)
|
||||
%"35" = alloca float, align 4, addrspace(5)
|
||||
%"36" = alloca i16, align 2, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"30"
|
||||
|
||||
"30": ; preds = %1
|
||||
%"37" = load i64, ptr addrspace(4) %"31", align 4
|
||||
store i64 %"37", ptr addrspace(5) %"33", align 4
|
||||
%"38" = load i64, ptr addrspace(4) %"32", align 4
|
||||
store i64 %"38", ptr addrspace(5) %"34", align 4
|
||||
%"40" = load i64, ptr addrspace(5) %"33", align 4
|
||||
%"45" = inttoptr i64 %"40" to ptr addrspace(1)
|
||||
%"39" = load float, ptr addrspace(1) %"45", align 4
|
||||
store float %"39", ptr addrspace(5) %"35", align 4
|
||||
%"42" = load float, ptr addrspace(5) %"35", align 4
|
||||
%2 = call float @llvm.roundeven.f32(float %"42")
|
||||
%"41" = call i16 @llvm.fptoui.sat.i16.f32(float %2)
|
||||
store i16 %"41", ptr addrspace(5) %"36", align 2
|
||||
%"43" = load i64, ptr addrspace(5) %"34", align 4
|
||||
%"44" = load i16, ptr addrspace(5) %"36", align 2
|
||||
%"46" = inttoptr i64 %"43" to ptr
|
||||
store i16 %"44", ptr %"46", align 2
|
||||
ret void
|
||||
}
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
|
||||
declare float @llvm.roundeven.f32(float) #1
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
|
||||
declare i16 @llvm.fptoui.sat.i16.f32(float) #1
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
|
24
ptx/src/test/spirv_run/add_s32_sat.ptx
Normal file
24
ptx/src/test/spirv_run/add_s32_sat.ptx
Normal file
@ -0,0 +1,24 @@
|
||||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
// Test kernel contrasting saturating and plain signed 32-bit addition.
// Reads two s32 values from `input` and writes two s32 results to `output`:
//   output[0] = add.sat.s32 of the two inputs
//   output[1] = add.s32 of the same two inputs
// The accompanying test_ptx! case feeds [i32::MIN, -1] and expects
// [i32::MIN, i32::MAX], i.e. the saturating form clamps where the plain form wraps.
.visible .entry add_s32_sat(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .s32 temp<4>;
|
||||
|
||||
// Fetch the input and output buffer addresses from the kernel parameters.
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
// Load the two operands: element 0 and element 1 (byte offset 4) of the input.
ld.s32 temp0, [in_addr];
|
||||
ld.s32 temp1, [in_addr+4];
|
||||
// Same operands, once with .sat and once without, so the results can be compared.
add.sat.s32 temp2, temp0, temp1;
|
||||
add.s32 temp3, temp0, temp1;
|
||||
// Saturated sum goes to output element 0, plain sum to element 1.
st.s32 [out_addr], temp2;
|
||||
st.s32 [out_addr+4], temp3;
|
||||
ret;
|
||||
}
|
22
ptx/src/test/spirv_run/cvt_rni_u16_f32.ptx
Normal file
22
ptx/src/test/spirv_run/cvt_rni_u16_f32.ptx
Normal file
@ -0,0 +1,22 @@
|
||||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
// Test kernel for converting f32 to u16 with the .rni rounding modifier.
// Reads one f32 from `input`, converts it with cvt.rni.u16.f32 and stores the
// u16 result to `output`.
// The accompanying test_ptx! case supplies the bit pattern 0x477FFF80 (65535.5
// as f32) and expects 65535, so a value that rounds to 65536 has to clamp to
// the u16 maximum instead of wrapping.
.visible .entry cvt_rni_u16_f32(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .f32 temp_f32;
|
||||
.reg .u16 temp_u16;
|
||||
|
||||
// Fetch the input and output buffer addresses from the kernel parameters.
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
// The source value is read through the global state space.
ld.global.f32 temp_f32, [in_addr];
|
||||
// Conversion under test: round to integer (.rni), then narrow to u16.
cvt.rni.u16.f32 temp_u16, temp_f32;
|
||||
st.u16 [out_addr], temp_u16;
|
||||
ret;
|
||||
}
|
@ -147,6 +147,7 @@ test_ptx!(ex2, [10f32], [1024f32]);
|
||||
test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]);
|
||||
test_ptx!(cvt_rzi, [-13.8f32, 12.9f32], [-13f32, 12f32]);
|
||||
test_ptx!(cvt_s32_f32, [-13.8f32, 12.9f32], [-13i32, 13i32]);
|
||||
test_ptx!(cvt_rni_u16_f32, [0x477FFF80u32], [65535u16]);
|
||||
test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]);
|
||||
test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]);
|
||||
test_ptx!(
|
||||
@ -226,6 +227,7 @@ test_ptx!(
|
||||
[f32::from_bits(0x800000), f32::from_bits(0x007FFFFF)],
|
||||
[0x800000u32, 0xFFFFFF]
|
||||
);
|
||||
test_ptx!(add_s32_sat, [i32::MIN, -1], [i32::MIN, i32::MAX]);
|
||||
test_ptx!(malformed_label, [2u64], [3u64]);
|
||||
test_ptx!(
|
||||
call_rnd,
|
||||
|
@ -2,7 +2,7 @@ use super::{
|
||||
AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp,
|
||||
StateSpace, VectorPrefix,
|
||||
};
|
||||
use crate::{PtxError, PtxParserState, Mul24Control};
|
||||
use crate::{Mul24Control, PtxError, PtxParserState};
|
||||
use bitflags::bitflags;
|
||||
use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8};
|
||||
|
||||
@ -1197,7 +1197,6 @@ pub enum MulIntControl {
|
||||
Wide,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct Mul24Details {
|
||||
pub type_: ScalarType,
|
||||
@ -1473,20 +1472,24 @@ pub enum CvtMode {
|
||||
SignExtend,
|
||||
Truncate,
|
||||
Bitcast,
|
||||
SaturateUnsignedToSigned,
|
||||
SaturateSignedToUnsigned,
|
||||
IntSaturateToSigned,
|
||||
IntSaturateToUnsigned,
|
||||
// float from float
|
||||
FPExtend {
|
||||
flush_to_zero: Option<bool>,
|
||||
saturate: bool,
|
||||
},
|
||||
FPTruncate {
|
||||
// float rounding
|
||||
rounding: RoundingMode,
|
||||
is_integer_rounding: bool,
|
||||
flush_to_zero: Option<bool>,
|
||||
saturate: bool,
|
||||
},
|
||||
FPRound {
|
||||
integer_rounding: RoundingMode,
|
||||
integer_rounding: Option<RoundingMode>,
|
||||
flush_to_zero: Option<bool>,
|
||||
saturate: bool,
|
||||
},
|
||||
// int from float
|
||||
SignedFromFP {
|
||||
@ -1498,8 +1501,14 @@ pub enum CvtMode {
|
||||
flush_to_zero: Option<bool>,
|
||||
}, // integer rounding
|
||||
// float from int, ftz is allowed in the grammar, but clearly nonsensical
|
||||
FPFromSigned(RoundingMode), // float rounding
|
||||
FPFromUnsigned(RoundingMode), // float rounding
|
||||
FPFromSigned {
|
||||
rounding: RoundingMode,
|
||||
saturate: bool,
|
||||
}, // float rounding
|
||||
FPFromUnsigned {
|
||||
rounding: RoundingMode,
|
||||
saturate: bool,
|
||||
}, // float rounding
|
||||
}
|
||||
|
||||
impl CvtDetails {
|
||||
@ -1511,9 +1520,6 @@ impl CvtDetails {
|
||||
dst: ScalarType,
|
||||
src: ScalarType,
|
||||
) -> Self {
|
||||
if saturate && dst.kind() == ScalarKind::Float {
|
||||
errors.push(PtxError::SyntaxError);
|
||||
}
|
||||
// Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results.
|
||||
let flush_to_zero = match (dst, src) {
|
||||
(ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz),
|
||||
@ -1524,55 +1530,81 @@ impl CvtDetails {
|
||||
None
|
||||
}
|
||||
};
|
||||
let rounding = rnd.map(Into::into);
|
||||
let rounding = rnd.map(RawRoundingMode::normalize);
|
||||
let mut unwrap_rounding = || match rounding {
|
||||
Some(rnd) => rnd,
|
||||
Some((rnd, is_integer)) => (rnd, is_integer),
|
||||
None => {
|
||||
errors.push(PtxError::SyntaxError);
|
||||
RoundingMode::NearestEven
|
||||
(RoundingMode::NearestEven, false)
|
||||
}
|
||||
};
|
||||
let mode = match (dst.kind(), src.kind()) {
|
||||
(ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) {
|
||||
Ordering::Less => CvtMode::FPTruncate {
|
||||
rounding: unwrap_rounding(),
|
||||
Ordering::Less => {
|
||||
let (rounding, is_integer_rounding) = unwrap_rounding();
|
||||
CvtMode::FPTruncate {
|
||||
rounding,
|
||||
is_integer_rounding,
|
||||
flush_to_zero,
|
||||
},
|
||||
saturate,
|
||||
}
|
||||
}
|
||||
Ordering::Equal => CvtMode::FPRound {
|
||||
integer_rounding: rounding.unwrap_or(RoundingMode::NearestEven),
|
||||
integer_rounding: rounding.map(|(rnd, _)| rnd),
|
||||
flush_to_zero,
|
||||
saturate,
|
||||
},
|
||||
Ordering::Greater => {
|
||||
if rounding.is_some() {
|
||||
errors.push(PtxError::SyntaxError);
|
||||
}
|
||||
CvtMode::FPExtend { flush_to_zero }
|
||||
CvtMode::FPExtend {
|
||||
flush_to_zero,
|
||||
saturate,
|
||||
}
|
||||
}
|
||||
},
|
||||
(ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP {
|
||||
rounding: unwrap_rounding(),
|
||||
rounding: unwrap_rounding().0,
|
||||
flush_to_zero,
|
||||
},
|
||||
(ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP {
|
||||
rounding: unwrap_rounding(),
|
||||
rounding: unwrap_rounding().0,
|
||||
flush_to_zero,
|
||||
},
|
||||
(ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()),
|
||||
(ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()),
|
||||
(ScalarKind::Signed, ScalarKind::Unsigned) if saturate => {
|
||||
CvtMode::SaturateUnsignedToSigned
|
||||
}
|
||||
(ScalarKind::Unsigned, ScalarKind::Signed) if saturate => {
|
||||
CvtMode::SaturateSignedToUnsigned
|
||||
(ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned {
|
||||
rounding: unwrap_rounding().0,
|
||||
saturate,
|
||||
},
|
||||
(ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned {
|
||||
rounding: unwrap_rounding().0,
|
||||
saturate,
|
||||
},
|
||||
(ScalarKind::Signed, ScalarKind::Unsigned)
|
||||
| (ScalarKind::Signed, ScalarKind::Signed)
|
||||
if saturate =>
|
||||
{
|
||||
CvtMode::IntSaturateToSigned
|
||||
}
|
||||
(ScalarKind::Unsigned, ScalarKind::Signed)
|
||||
| (ScalarKind::Unsigned, ScalarKind::Unsigned)
|
||||
if saturate =>
|
||||
{
|
||||
CvtMode::IntSaturateToUnsigned
|
||||
}
|
||||
(ScalarKind::Unsigned, ScalarKind::Unsigned)
|
||||
| (ScalarKind::Signed, ScalarKind::Signed)
|
||||
| (ScalarKind::Unsigned, ScalarKind::Signed)
|
||||
| (ScalarKind::Signed, ScalarKind::Unsigned)
|
||||
if dst.size_of() == src.size_of() =>
|
||||
{
|
||||
CvtMode::Bitcast
|
||||
}
|
||||
(ScalarKind::Unsigned, ScalarKind::Unsigned)
|
||||
| (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) {
|
||||
| (ScalarKind::Signed, ScalarKind::Signed)
|
||||
| (ScalarKind::Unsigned, ScalarKind::Signed)
|
||||
| (ScalarKind::Signed, ScalarKind::Unsigned) => match dst.size_of().cmp(&src.size_of())
|
||||
{
|
||||
Ordering::Less => CvtMode::Truncate,
|
||||
Ordering::Equal => CvtMode::Bitcast,
|
||||
Ordering::Greater => {
|
||||
@ -1583,7 +1615,6 @@ impl CvtDetails {
|
||||
}
|
||||
}
|
||||
},
|
||||
(ScalarKind::Unsigned, ScalarKind::Signed) => CvtMode::SaturateSignedToUnsigned,
|
||||
(_, _) => {
|
||||
errors.push(PtxError::SyntaxError);
|
||||
CvtMode::Bitcast
|
||||
|
@ -64,11 +64,21 @@ impl From<RawLdStQualifier> for ast::LdStQualifier {
|
||||
|
||||
impl From<RawRoundingMode> for ast::RoundingMode {
|
||||
fn from(value: RawRoundingMode) -> Self {
|
||||
match value {
|
||||
RawRoundingMode::Rn | RawRoundingMode::Rni => ast::RoundingMode::NearestEven,
|
||||
RawRoundingMode::Rz | RawRoundingMode::Rzi => ast::RoundingMode::Zero,
|
||||
RawRoundingMode::Rm | RawRoundingMode::Rmi => ast::RoundingMode::NegativeInf,
|
||||
RawRoundingMode::Rp | RawRoundingMode::Rpi => ast::RoundingMode::PositiveInf,
|
||||
value.normalize().0
|
||||
}
|
||||
}
|
||||
|
||||
impl RawRoundingMode {
|
||||
fn normalize(self) -> (ast::RoundingMode, bool) {
|
||||
match self {
|
||||
RawRoundingMode::Rn => (ast::RoundingMode::NearestEven, false),
|
||||
RawRoundingMode::Rz => (ast::RoundingMode::Zero, false),
|
||||
RawRoundingMode::Rm => (ast::RoundingMode::NegativeInf, false),
|
||||
RawRoundingMode::Rp => (ast::RoundingMode::PositiveInf, false),
|
||||
RawRoundingMode::Rni => (ast::RoundingMode::NearestEven, true),
|
||||
RawRoundingMode::Rzi => (ast::RoundingMode::Zero, true),
|
||||
RawRoundingMode::Rmi => (ast::RoundingMode::NegativeInf, true),
|
||||
RawRoundingMode::Rpi => (ast::RoundingMode::PositiveInf, true),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user