Add fp saturation, fix various bugs in cvt instruction exposed by ptx_tests (#379)

Author: Andrzej Janik
Date: 2025-06-16 19:14:16 -07:00
Committed by: GitHub
Parent: 4d4053194a
Commit: 2a374ad880
12 changed files with 875 additions and 184 deletions
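In PTX terms, the headline change is support for the floating-point `.sat` modifier plus several `cvt` fixes caught by ptx_tests. As a rough host-side sketch of the semantics being added (the function name is illustrative, not part of the codebase), float `.sat` clamps a result into [0.0, 1.0]:

fn fp_saturate(x: f32) -> f32 {
    // f32::max/min return the non-NaN operand, so a NaN input comes out as 0.0,
    // mirroring the llvm.maxnum/llvm.minnum pair that emit_fp_saturate emits below.
    x.max(0.0).min(1.0)
}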

View File

@ -307,14 +307,14 @@ impl From<libloading::Error> for Error {
}
impl From<comgr2::amd_comgr_status_s> for Error {
-fn from(_: comgr2::amd_comgr_status_s) -> Self {
-todo!()
+fn from(status: comgr2::amd_comgr_status_s) -> Self {
+Error(status.0)
}
}
impl From<comgr3::amd_comgr_status_s> for Error {
-fn from(_: comgr3::amd_comgr_status_s) -> Self {
-todo!()
+fn from(status: comgr3::amd_comgr_status_s) -> Self {
+Error(status.0)
}
}

View File

@ -518,6 +518,7 @@ impl<'a> MethodEmitContext<'a> {
Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?,
Statement::FpSaturate { dst, src, type_ } => self.emit_fp_saturate(type_, dst, src)?,
})
}
@ -590,7 +591,7 @@ impl<'a> MethodEmitContext<'a> {
inst: ast::Instruction<SpirvWord>,
) -> Result<(), TranslateError> {
match inst {
-ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments),
+ast::Instruction::Mov { data: _, arguments } => self.emit_mov(arguments),
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
@ -836,7 +837,13 @@ impl<'a> MethodEmitContext<'a> {
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let fn_ = match data {
-ast::ArithDetails::Integer(..) => LLVMBuildAdd,
+ast::ArithDetails::Integer(ast::ArithInteger {
+saturate: true,
+type_,
+}) => return self.emit_add_sat(type_, arguments),
+ast::ArithDetails::Integer(ast::ArithInteger {
+saturate: false, ..
+}) => LLVMBuildAdd,
ast::ArithDetails::Float(..) => LLVMBuildFAdd,
};
self.resolver.with_result(arguments.dst, |dst| unsafe {
@ -917,11 +924,7 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
-fn emit_mov(
-&mut self,
-_data: ast::MovDetails,
-arguments: ast::MovArgs<SpirvWord>,
-) -> Result<(), TranslateError> {
+fn emit_mov(&mut self, arguments: ast::MovArgs<SpirvWord>) -> Result<(), TranslateError> {
self.resolver
.register(arguments.dst, self.resolver.value(arguments.src)?);
Ok(())
@ -1612,32 +1615,40 @@ impl<'a> MethodEmitContext<'a> {
ptx_parser::CvtMode::SignExtend => LLVMBuildSExt,
ptx_parser::CvtMode::Truncate => LLVMBuildTrunc,
ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
-ptx_parser::CvtMode::SaturateUnsignedToSigned => {
+ptx_parser::CvtMode::IntSaturateToSigned => {
return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments)
}
-ptx_parser::CvtMode::SaturateSignedToUnsigned => {
+ptx_parser::CvtMode::IntSaturateToUnsigned => {
return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments)
}
ptx_parser::CvtMode::FPExtend { .. } => LLVMBuildFPExt,
ptx_parser::CvtMode::FPTruncate { .. } => LLVMBuildFPTrunc,
ptx_parser::CvtMode::FPRound {
-integer_rounding, ..
+integer_rounding: None,
+flush_to_zero: None | Some(false),
+..
} => {
-return self.emit_cvt_float_to_int(
-data.from,
-data.to,
-integer_rounding,
-arguments,
-Some(LLVMBuildFPToSI),
-)
+return self.emit_mov(ast::MovArgs {
+dst: arguments.dst,
+src: arguments.src,
+})
}
+ptx_parser::CvtMode::FPRound {
+integer_rounding: None,
+flush_to_zero: Some(true),
+..
+} => return self.flush_denormals(data.to, arguments.src, arguments.dst),
+ptx_parser::CvtMode::FPRound {
+integer_rounding: Some(rounding),
+..
+} => return self.emit_cvt_float_to_int(data.from, data.to, rounding, arguments, None),
ptx_parser::CvtMode::SignedFromFP { rounding, .. } => {
return self.emit_cvt_float_to_int(
data.from,
data.to,
rounding,
arguments,
-Some(LLVMBuildFPToSI),
+Some(true),
)
}
ptx_parser::CvtMode::UnsignedFromFP { rounding, .. } => {
@ -1646,13 +1657,13 @@ impl<'a> MethodEmitContext<'a> {
data.to,
rounding,
arguments,
-Some(LLVMBuildFPToUI),
+Some(false),
)
}
-ptx_parser::CvtMode::FPFromSigned(_) => {
+ptx_parser::CvtMode::FPFromSigned { .. } => {
return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildSIToFP)
}
-ptx_parser::CvtMode::FPFromUnsigned(_) => {
+ptx_parser::CvtMode::FPFromUnsigned { .. } => {
return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildUIToFP)
}
};
@ -1669,27 +1680,7 @@ impl<'a> MethodEmitContext<'a> {
to: ptx_parser::ScalarType,
arguments: ptx_parser::CvtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
-// This looks dodgy, but it's fine. MAX bit pattern is always 0b11..1,
-// so if it's downcast to a smaller type, it will be the maximum value
-// of the smaller type
-let max_value = match to {
-ptx_parser::ScalarType::S8 => i8::MAX as u64,
-ptx_parser::ScalarType::S16 => i16::MAX as u64,
-ptx_parser::ScalarType::S32 => i32::MAX as u64,
-ptx_parser::ScalarType::S64 => i64::MAX as u64,
-_ => return Err(error_unreachable()),
-};
-let from_llvm = get_scalar_type(self.context, from);
-let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
-let clamped = self.emit_intrinsic(
-c"llvm.umin",
-None,
-Some(&from.into()),
-vec![
-(self.resolver.value(arguments.src)?, from_llvm),
-(max, from_llvm),
-],
-)?;
+let clamped = self.emit_saturate_integer(from, to, &arguments)?;
let resize_fn = if to.layout().size() >= from.layout().size() {
LLVMBuildSExtOrBitCast
} else {
@ -1702,40 +1693,92 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
fn emit_saturate_integer(
&mut self,
from: ptx_parser::ScalarType,
to: ptx_parser::ScalarType,
arguments: &ptx_parser::CvtArgs<SpirvWord>,
) -> Result<LLVMValueRef, TranslateError> {
let from_llvm = get_scalar_type(self.context, from);
match from.kind() {
ptx_parser::ScalarKind::Unsigned => {
let max_value = match to {
ptx_parser::ScalarType::U8 => u8::MAX as u64,
ptx_parser::ScalarType::S8 => i8::MAX as u64,
ptx_parser::ScalarType::U16 => u16::MAX as u64,
ptx_parser::ScalarType::S16 => i16::MAX as u64,
ptx_parser::ScalarType::U32 => u32::MAX as u64,
ptx_parser::ScalarType::S32 => i32::MAX as u64,
ptx_parser::ScalarType::U64 => u64::MAX as u64,
ptx_parser::ScalarType::S64 => i64::MAX as u64,
_ => return Err(error_unreachable()),
};
let intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from));
let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
let clamped = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
None,
Some(&from.into()),
vec![
(self.resolver.value(arguments.src)?, from_llvm),
(max, from_llvm),
],
)?;
Ok(clamped)
}
ptx_parser::ScalarKind::Signed => {
let (min_value_from, max_value_from) = match from {
ptx_parser::ScalarType::S8 => (i8::MIN as i128, i8::MAX as i128),
ptx_parser::ScalarType::S16 => (i16::MIN as i128, i16::MAX as i128),
ptx_parser::ScalarType::S32 => (i32::MIN as i128, i32::MAX as i128),
ptx_parser::ScalarType::S64 => (i64::MIN as i128, i64::MAX as i128),
_ => return Err(error_unreachable()),
};
let (min_value_to, max_value_to) = match to {
ptx_parser::ScalarType::U8 => (u8::MIN as i128, u8::MAX as i128),
ptx_parser::ScalarType::S8 => (i8::MIN as i128, i8::MAX as i128),
ptx_parser::ScalarType::U16 => (u16::MIN as i128, u16::MAX as i128),
ptx_parser::ScalarType::S16 => (i16::MIN as i128, i16::MAX as i128),
ptx_parser::ScalarType::U32 => (u32::MIN as i128, u32::MAX as i128),
ptx_parser::ScalarType::S32 => (i32::MIN as i128, i32::MAX as i128),
ptx_parser::ScalarType::U64 => (u64::MIN as i128, u64::MAX as i128),
ptx_parser::ScalarType::S64 => (i64::MIN as i128, i64::MAX as i128),
_ => return Err(error_unreachable()),
};
let min_value = min_value_from.max(min_value_to);
let max_value = max_value_from.min(max_value_to);
let max_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from));
let min = unsafe { LLVMConstInt(from_llvm, min_value as u64, 1) };
let min_intrinsic = format!("llvm.smin.{}\0", LLVMTypeDisplay(from));
let max = unsafe { LLVMConstInt(from_llvm, max_value as u64, 1) };
let clamped = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(max_intrinsic.as_bytes()) },
None,
Some(&from.into()),
vec![
(self.resolver.value(arguments.src)?, from_llvm),
(min, from_llvm),
],
)?;
let clamped = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(min_intrinsic.as_bytes()) },
None,
Some(&from.into()),
vec![(clamped, from_llvm), (max, from_llvm)],
)?;
Ok(clamped)
}
_ => return Err(error_unreachable()),
}
}
fn emit_cvt_signed_to_unsigned_sat(
&mut self,
from: ptx_parser::ScalarType,
to: ptx_parser::ScalarType,
arguments: ptx_parser::CvtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
-let from_llvm = get_scalar_type(self.context, from);
-let zero = unsafe { LLVMConstInt(from_llvm, 0, 0) };
-let zero_clamp_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from));
-let zero_clamped = self.emit_intrinsic(
-unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) },
-None,
-Some(&from.into()),
-vec![
-(self.resolver.value(arguments.src)?, from_llvm),
-(zero, from_llvm),
-],
-)?;
-// zero_clamped is now unsigned
-let max_value = match to {
-ptx_parser::ScalarType::U8 => u8::MAX as u64,
-ptx_parser::ScalarType::U16 => u16::MAX as u64,
-ptx_parser::ScalarType::U32 => u32::MAX as u64,
-ptx_parser::ScalarType::U64 => u64::MAX as u64,
-_ => return Err(error_unreachable()),
-};
-let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
-let max_clamp_intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from));
-let fully_clamped = self.emit_intrinsic(
-unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) },
-None,
-Some(&from.into()),
-vec![(zero_clamped, from_llvm), (max, from_llvm)],
-)?;
+let clamped = self.emit_saturate_integer(from, to, &arguments)?;
let resize_fn = if to.layout().size() >= from.layout().size() {
LLVMBuildZExtOrBitCast
} else {
@ -1743,7 +1786,7 @@ impl<'a> MethodEmitContext<'a> {
};
let to_llvm = get_scalar_type(self.context, to);
self.resolver.with_result(arguments.dst, |dst| unsafe {
-resize_fn(self.builder, fully_clamped, to_llvm, dst)
+resize_fn(self.builder, clamped, to_llvm, dst)
});
Ok(())
}
@ -1754,18 +1797,89 @@ impl<'a> MethodEmitContext<'a> {
to: ast::ScalarType,
rounding: ast::RoundingMode,
arguments: ptx_parser::CvtArgs<SpirvWord>,
-llvm_cast: Option<
-unsafe extern "C" fn(
-arg1: LLVMBuilderRef,
-Val: LLVMValueRef,
-DestTy: LLVMTypeRef,
-Name: *const i8,
-) -> LLVMValueRef,
->,
+signed_cast: Option<bool>,
) -> Result<(), TranslateError> {
let dst_int_rounded =
self.emit_fp_int_rounding(from, rounding, &arguments, signed_cast.is_some())?;
// In PTX all the int-from-float casts are saturating casts. On the other hand, in LLVM,
// out-of-range fptoui and fptosi have undefined behavior.
// We could handle this all with llvm.fptosi.sat and llvm.fptoui.sat intrinsics, but
// the problem is that, when using the *.sat variants, the AMDGPU target _always_ emits saturation
// checks. Often they are unnecessary because v_cvt_* instructions saturate anyway.
// For that reason, all from-to combinations that we know have a direct corresponding
// v_cvt_* instruction get special treatment
let is_saturating_cast = match (to, from) {
(ast::ScalarType::S16, ast::ScalarType::F16)
| (ast::ScalarType::S32, ast::ScalarType::F32)
| (ast::ScalarType::S32, ast::ScalarType::F64)
| (ast::ScalarType::U16, ast::ScalarType::F16)
| (ast::ScalarType::U32, ast::ScalarType::F32)
| (ast::ScalarType::U32, ast::ScalarType::F64) => true,
_ => false,
};
let signed_cast = match signed_cast {
Some(s) => s,
None => {
self.resolver.register(
arguments.dst,
dst_int_rounded.ok_or_else(error_unreachable)?,
);
return Ok(());
}
};
if is_saturating_cast {
let to = get_scalar_type(self.context, to);
let src =
dst_int_rounded.unwrap_or_else(|| self.resolver.value(arguments.src).unwrap());
let llvm_cast = if signed_cast {
LLVMBuildFPToSI
} else {
LLVMBuildFPToUI
};
let poisoned_dst = unsafe { llvm_cast(self.builder, src, to, LLVM_UNNAMED.as_ptr()) };
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildFreeze(self.builder, poisoned_dst, dst)
});
} else {
let cvt_op = if to.kind() == ptx_parser::ScalarKind::Unsigned {
"fptoui"
} else {
"fptosi"
};
let cast_intrinsic = format!(
"llvm.{cvt_op}.sat.{}.{}\0",
LLVMTypeDisplay(to),
LLVMTypeDisplay(from)
);
let src =
dst_int_rounded.unwrap_or_else(|| self.resolver.value(arguments.src).unwrap());
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) },
Some(arguments.dst),
Some(&to.into()),
vec![(src, get_scalar_type(self.context, from))],
)?;
}
Ok(())
}
fn emit_fp_int_rounding(
&mut self,
from: ptx_parser::ScalarType,
rounding: ptx_parser::RoundingMode,
arguments: &ptx_parser::CvtArgs<SpirvWord>,
will_saturate_with_cvt: bool,
) -> Result<Option<LLVMValueRef>, TranslateError> {
let prefix = match rounding {
ptx_parser::RoundingMode::NearestEven => "llvm.roundeven",
-ptx_parser::RoundingMode::Zero => "llvm.trunc",
+ptx_parser::RoundingMode::Zero => {
+// cvt has round-to-zero semantics
+if will_saturate_with_cvt {
+return Ok(None);
+} else {
+"llvm.trunc"
+}
+}
ptx_parser::RoundingMode::NegativeInf => "llvm.floor",
ptx_parser::RoundingMode::PositiveInf => "llvm.ceil",
};
@ -1779,34 +1893,7 @@ impl<'a> MethodEmitContext<'a> {
get_scalar_type(self.context, from),
)],
)?;
-if let Some(llvm_cast) = llvm_cast {
-let to = get_scalar_type(self.context, to);
-let poisoned_dst =
-unsafe { llvm_cast(self.builder, rounded_float, to, LLVM_UNNAMED.as_ptr()) };
-self.resolver.with_result(arguments.dst, |dst| unsafe {
-LLVMBuildFreeze(self.builder, poisoned_dst, dst)
-});
-} else {
-self.resolver.register(arguments.dst, rounded_float);
-}
-// Using explicit saturation gives us worse codegen: it explicitly checks for out of bound
-// values and NaNs. Using non-saturated fptosi/fptoui emits v_cvt_<TO>_<FROM> which
-// saturates by default and we don't care about NaNs anyway
-/*
-let cast_intrinsic = format!(
-"{}.{}.{}\0",
-llvm_cast,
-LLVMTypeDisplay(to),
-LLVMTypeDisplay(from)
-);
-self.emit_intrinsic(
-unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) },
-Some(arguments.dst),
-&to.into(),
-vec![(rounded_float, get_scalar_type(self.context, from))],
-)?;
-*/
-Ok(())
+Ok(Some(rounded_float))
}
fn emit_cvt_int_to_float(
@ -2289,7 +2376,11 @@ impl<'a> MethodEmitContext<'a> {
};
let res_lo = self.emit_intrinsic(
name_lo,
-if data.control == Mul24Control::Lo { Some(arguments.dst) } else { None },
+if data.control == Mul24Control::Lo {
+Some(arguments.dst)
+} else {
+None
+},
Some(&ast::Type::Scalar(data.type_)),
vec![
(src1, get_scalar_type(self.context, data.type_)),
@ -2316,9 +2407,8 @@ impl<'a> MethodEmitContext<'a> {
],
)?;
let shift_number = unsafe { LLVMConstInt(LLVMInt32TypeInContext(self.context), 16, 0) };
-let res_lo_shr = unsafe {
-LLVMBuildLShr(self.builder, res_lo, shift_number, LLVM_UNNAMED.as_ptr())
-};
+let res_lo_shr =
+unsafe { LLVMBuildLShr(self.builder, res_lo, shift_number, LLVM_UNNAMED.as_ptr()) };
let res_hi_shl =
unsafe { LLVMBuildShl(self.builder, res_hi, shift_number, LLVM_UNNAMED.as_ptr()) };
@ -2381,6 +2471,74 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
fn emit_fp_saturate(
&mut self,
type_: ast::ScalarType,
dst: SpirvWord,
src: SpirvWord,
) -> Result<(), TranslateError> {
let llvm_type = get_scalar_type(self.context, type_);
let zero = unsafe { LLVMConstReal(llvm_type, 0.0) };
let one = unsafe { LLVMConstReal(llvm_type, 1.0) };
let maxnum_intrinsic = format!("llvm.maxnum.{}\0", LLVMTypeDisplay(type_));
let minnum_intrinsic = format!("llvm.minnum.{}\0", LLVMTypeDisplay(type_));
let src = self.resolver.value(src)?;
let maxnum = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(maxnum_intrinsic.as_bytes()) },
None,
Some(&type_.into()),
vec![(src, llvm_type), (zero, llvm_type)],
)?;
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(minnum_intrinsic.as_bytes()) },
Some(dst),
Some(&type_.into()),
vec![(maxnum, llvm_type), (one, llvm_type)],
)?;
Ok(())
}
fn emit_add_sat(
&mut self,
type_: ast::ScalarType,
arguments: ast::AddArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_type = get_scalar_type(self.context, type_);
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let op = if type_.kind() == ast::ScalarKind::Signed {
"sadd"
} else {
"uadd"
};
let intrinsic = format!("llvm.{}.sat.{}\0", op, LLVMTypeDisplay(type_));
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
Some(&type_.into()),
vec![(src1, llvm_type), (src2, llvm_type)],
)?;
Ok(())
}
fn flush_denormals(
&mut self,
type_: ptx_parser::ScalarType,
src: SpirvWord,
dst: SpirvWord,
) -> Result<(), TranslateError> {
let llvm_type = get_scalar_type(self.context, type_);
let src = self.resolver.value(src)?;
let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(type_));
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(dst),
Some(&type_.into()),
vec![(src, llvm_type)],
)?;
Ok(())
}
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19
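A note on the float-to-int paths above: as the comment in emit_cvt_float_to_int says, every PTX int-from-float cvt saturates, so the lowering either leans on the AMDGPU v_cvt_* saturation (plain fptosi/fptoui plus freeze) or falls back to llvm.fptosi.sat / llvm.fptoui.sat. As a host-side model of the semantics being preserved (illustrative helper, not part of the codebase), Rust's own float-to-int `as` cast is already saturating:

fn cvt_rzi_s32_f32(x: f32) -> i32 {
    // .rzi rounds toward zero first; `as` truncates the same way and
    // saturates out-of-range values to i32::MIN / i32::MAX.
    x.trunc() as i32
}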

View File

@ -0,0 +1,296 @@
use super::*;
use ptx_parser as ast;
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
})
}
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2,
method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let mut new_statements = Vec::new();
let body = method
.body
.map(|statements| {
for statement in statements {
run_statement(resolver, &mut new_statements, statement)?;
}
Ok::<_, TranslateError>(new_statements)
})
.transpose()?;
Ok(Function2 { body, ..method })
}
fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
match statement {
Statement::Instruction(inst) => run_instruction(resolver, result, inst)?,
statement => {
result.push(statement);
}
}
Ok(())
}
fn run_instruction<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
mut instruction: ast::Instruction<SpirvWord>,
) -> Result<(), TranslateError> {
match instruction {
ast::Instruction::Abs { .. }
| ast::Instruction::Activemask { .. }
| ast::Instruction::Add {
data:
ast::ArithDetails::Float(ast::ArithFloat {
saturate: false, ..
}),
..
}
| ast::Instruction::Add {
data: ast::ArithDetails::Integer(..),
..
}
| ast::Instruction::And { .. }
| ast::Instruction::Atom { .. }
| ast::Instruction::AtomCas { .. }
| ast::Instruction::Bar { .. }
| ast::Instruction::Bfe { .. }
| ast::Instruction::Bfi { .. }
| ast::Instruction::Bra { .. }
| ast::Instruction::Brev { .. }
| ast::Instruction::Call { .. }
| ast::Instruction::Clz { .. }
| ast::Instruction::Cos { .. }
| ast::Instruction::Cvt {
data:
ast::CvtDetails {
mode:
ast::CvtMode::ZeroExtend
| ast::CvtMode::SignExtend
| ast::CvtMode::Truncate
| ast::CvtMode::Bitcast
| ast::CvtMode::IntSaturateToSigned
| ast::CvtMode::IntSaturateToUnsigned
| ast::CvtMode::SignedFromFP { .. }
| ast::CvtMode::UnsignedFromFP { .. }
| ast::CvtMode::FPFromSigned {
saturate: false, ..
}
| ast::CvtMode::FPFromUnsigned {
saturate: false, ..
}
| ast::CvtMode::FPExtend {
saturate: false, ..
}
| ast::CvtMode::FPTruncate {
saturate: false, ..
}
| ast::CvtMode::FPRound {
saturate: false, ..
},
..
},
..
}
| ast::Instruction::Cvta { .. }
| ast::Instruction::Div { .. }
| ast::Instruction::Ex2 { .. }
| ast::Instruction::Fma {
data: ast::ArithFloat {
saturate: false, ..
},
..
}
| ast::Instruction::Ld { .. }
| ast::Instruction::Lg2 { .. }
| ast::Instruction::Mad {
data:
ast::MadDetails::Float(ast::ArithFloat {
saturate: false, ..
}),
..
}
| ast::Instruction::Mad {
data: ast::MadDetails::Integer { .. },
..
}
| ast::Instruction::Max { .. }
| ast::Instruction::Membar { .. }
| ast::Instruction::Min { .. }
| ast::Instruction::Mov { .. }
| ast::Instruction::Mul {
data:
ast::MulDetails::Float(ast::ArithFloat {
saturate: false, ..
}),
..
}
| ast::Instruction::Mul {
data: ast::MulDetails::Integer { .. },
..
}
| ast::Instruction::Mul24 { .. }
| ast::Instruction::Neg { .. }
| ast::Instruction::Not { .. }
| ast::Instruction::Or { .. }
| ast::Instruction::Popc { .. }
| ast::Instruction::Prmt { .. }
| ast::Instruction::PrmtSlow { .. }
| ast::Instruction::Rcp { .. }
| ast::Instruction::Rem { .. }
| ast::Instruction::Ret { .. }
| ast::Instruction::Rsqrt { .. }
| ast::Instruction::Selp { .. }
| ast::Instruction::Setp { .. }
| ast::Instruction::SetpBool { .. }
| ast::Instruction::Shl { .. }
| ast::Instruction::Shr { .. }
| ast::Instruction::Sin { .. }
| ast::Instruction::Sqrt { .. }
| ast::Instruction::St { .. }
| ast::Instruction::Sub {
data:
ast::ArithDetails::Float(ast::ArithFloat {
saturate: false, ..
}),
..
}
| ast::Instruction::Sub {
data: ast::ArithDetails::Integer(..),
..
}
| ast::Instruction::Trap {}
| ast::Instruction::Xor { .. } => result.push(Statement::Instruction(instruction)),
ast::Instruction::Add {
data:
ast::ArithDetails::Float(ast::ArithFloat {
saturate: true,
type_,
..
}),
arguments: ast::AddArgs { ref mut dst, .. },
}
| ast::Instruction::Fma {
data:
ast::ArithFloat {
saturate: true,
type_,
..
},
arguments: ast::FmaArgs { ref mut dst, .. },
}
| ast::Instruction::Mad {
data:
ast::MadDetails::Float(ast::ArithFloat {
saturate: true,
type_,
..
}),
arguments: ast::MadArgs { ref mut dst, .. },
}
| ast::Instruction::Mul {
data:
ast::MulDetails::Float(ast::ArithFloat {
saturate: true,
type_,
..
}),
arguments: ast::MulArgs { ref mut dst, .. },
}
| ast::Instruction::Sub {
data:
ast::ArithDetails::Float(ast::ArithFloat {
saturate: true,
type_,
..
}),
arguments: ast::SubArgs { ref mut dst, .. },
}
| ast::Instruction::Cvt {
data:
ast::CvtDetails {
to: type_,
mode: ast::CvtMode::FPExtend { saturate: true, .. },
..
},
arguments: ast::CvtArgs { ref mut dst, .. },
}
| ast::Instruction::Cvt {
data:
ast::CvtDetails {
to: type_,
mode: ast::CvtMode::FPTruncate { saturate: true, .. },
..
},
arguments: ast::CvtArgs { ref mut dst, .. },
}
| ast::Instruction::Cvt {
data:
ast::CvtDetails {
to: type_,
mode: ast::CvtMode::FPRound { saturate: true, .. },
..
},
arguments: ast::CvtArgs { ref mut dst, .. },
}
| ast::Instruction::Cvt {
data:
ast::CvtDetails {
to: type_,
mode: ast::CvtMode::FPFromSigned { saturate: true, .. },
..
},
arguments: ast::CvtArgs { ref mut dst, .. },
}
| ast::Instruction::Cvt {
data:
ast::CvtDetails {
to: type_,
mode: ast::CvtMode::FPFromUnsigned { saturate: true, .. },
..
},
arguments: ast::CvtArgs { ref mut dst, .. },
} => {
let sat = get_post_saturation(resolver, type_, dst)?;
result.push(Statement::Instruction(instruction));
result.push(sat);
}
}
Ok(())
}
fn get_post_saturation<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
type_: ast::ScalarType,
old_dst: &mut SpirvWord,
) -> Result<Statement<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let post_sat = resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg)));
let dst = *old_dst;
*old_dst = post_sat;
Ok(Statement::FpSaturate {
dst,
src: post_sat,
type_,
})
}
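The key move in get_post_saturation is the destination swap: the saturating instruction stays in place but now writes a fresh temporary, and the appended FpSaturate clamps that temporary into the original register. A minimal stand-alone model of that rewrite, with simplified statement types and hypothetical names rather than the real pass API:

#[derive(Debug)]
enum Stmt {
    AddF32 { dst: u32, src1: u32, src2: u32 },
    FpSaturate { dst: u32, src: u32 },
}

// `fresh_id` stands in for resolver.register_unnamed(): the add now writes the
// temporary and FpSaturate writes the original destination.
fn rewrite_saturating_add(mut add: Stmt, fresh_id: u32) -> Vec<Stmt> {
    if let Stmt::AddF32 { dst, .. } = &mut add {
        let original_dst = *dst;
        *dst = fresh_id;
        return vec![
            add,
            Stmt::FpSaturate { dst: original_dst, src: fresh_id },
        ];
    }
    vec![add]
}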

View File

@ -167,26 +167,42 @@ impl InstructionModes {
}
}
-fn mixed_ftz_f32(
-type_: ast::ScalarType,
-denormal: Option<DenormalMode>,
-rounding: Option<RoundingMode>,
+fn from_typed_denormal_rounding(
+from_type: ast::ScalarType,
+to_type: ast::ScalarType,
+denormal: DenormalMode,
+rounding: RoundingMode,
) -> Self {
-if type_ != ast::ScalarType::F32 {
-Self {
-denormal_f16f64: denormal,
-rounding_f32: rounding,
-..Self::none()
-}
-} else {
-Self {
-denormal_f32: denormal,
-rounding_f32: rounding,
-..Self::none()
-}
+Self {
+rounding_f32: Some(rounding),
+rounding_f16f64: Some(rounding),
+..Self::from_typed_denormal(from_type, to_type, denormal)
+}
}
// This function accepts DenormalMode and not Option<DenormalMode> because
// the semantics are slightly different.
// * In instructions `None` means: flush-to-zero has not been explicitly requested
// * In this pass `None` means: neither flush-to-zero, nor preserve is applicable
fn from_typed_denormal(
from_type: ast::ScalarType,
to_type: ast::ScalarType,
denormal: DenormalMode,
) -> Self {
let mut result = Self::none();
if from_type == ast::ScalarType::F32 || to_type == ast::ScalarType::F32 {
result.denormal_f32 = if denormal == DenormalMode::FlushToZero {
Some(DenormalMode::FlushToZero)
} else {
Some(DenormalMode::Preserve)
};
}
if !(from_type == ast::ScalarType::F32 && to_type == ast::ScalarType::F32) {
result.denormal_f16f64 = Some(DenormalMode::Preserve);
}
result
}
fn from_arith_float(arith: &ast::ArithFloat) -> InstructionModes {
let denormal = arith.flush_to_zero.map(DenormalMode::from_ftz);
let rounding = Some(RoundingMode::from_ast(arith.rounding));
@ -220,31 +236,52 @@ impl InstructionModes {
| ast::CvtMode::SignExtend
| ast::CvtMode::Truncate
| ast::CvtMode::Bitcast
-| ast::CvtMode::SaturateUnsignedToSigned
-| ast::CvtMode::SaturateSignedToUnsigned => Self::none(),
-ast::CvtMode::FPExtend { flush_to_zero } => {
-Self::from_ftz(ast::ScalarType::F32, flush_to_zero)
-}
+| ast::CvtMode::IntSaturateToSigned
+| ast::CvtMode::IntSaturateToUnsigned => Self::none(),
+ast::CvtMode::FPExtend { flush_to_zero, .. } => Self::from_typed_denormal(
+cvt.from,
+cvt.to,
+flush_to_zero
+.map(DenormalMode::from_ftz)
+.unwrap_or(DenormalMode::Preserve),
+),
ast::CvtMode::FPTruncate {
rounding,
flush_to_zero,
+is_integer_rounding,
+..
+} => {
+let denormal_mode = match (is_integer_rounding, flush_to_zero) {
+(true, Some(true)) => DenormalMode::FlushToZero,
+_ => DenormalMode::Preserve,
+};
+Self::from_typed_denormal_rounding(
+cvt.from,
+cvt.to,
+denormal_mode,
+RoundingMode::from_ast(rounding),
+)
+}
-| ast::CvtMode::FPRound {
-integer_rounding: rounding,
-flush_to_zero,
-} => Self::mixed_ftz_f32(
+ast::CvtMode::FPRound { flush_to_zero, .. } => Self::from_typed_denormal(
+cvt.from,
+cvt.to,
-flush_to_zero.map(DenormalMode::from_ftz),
-Some(RoundingMode::from_ast(rounding)),
+flush_to_zero
+.map(DenormalMode::from_ftz)
+.unwrap_or(DenormalMode::Preserve),
),
// float to int contains a rounding field, but it's not a rounding
// mode but rather a round-to-int operation that will be applied
ast::CvtMode::SignedFromFP { flush_to_zero, .. }
-| ast::CvtMode::UnsignedFromFP { flush_to_zero, .. } => {
-Self::new(cvt.from, flush_to_zero.map(DenormalMode::from_ftz), None)
-}
-ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => {
-Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd)))
+| ast::CvtMode::UnsignedFromFP { flush_to_zero, .. } => Self::from_typed_denormal(
+cvt.from,
+cvt.from,
+flush_to_zero
+.map(DenormalMode::from_ftz)
+.unwrap_or(DenormalMode::Preserve),
+),
+ast::CvtMode::FPFromSigned { rounding, .. }
+| ast::CvtMode::FPFromUnsigned { rounding, .. } => {
+Self::new(cvt.to, None, Some(RoundingMode::from_ast(rounding)))
}
}
}
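The comment on from_typed_denormal above is the crux: a cvt only pins the f32 denormal mode when one of its operand types is f32, and it pins the f16/f64 mode to preserve whenever at least one operand is not f32. The same decision table as plain Rust, with booleans standing in for the real ScalarType/DenormalMode enums (hypothetical helper):

// Returns (denormal_f32, denormal_f16f64): Some(true) = flush-to-zero,
// Some(false) = preserve, None = leave that mode register alone.
fn cvt_denormal_modes(
    from_is_f32: bool,
    to_is_f32: bool,
    flush_to_zero: bool,
) -> (Option<bool>, Option<bool>) {
    let f32_mode = if from_is_f32 || to_is_f32 { Some(flush_to_zero) } else { None };
    let f16f64_mode = if from_is_f32 && to_is_f32 { None } else { Some(false) };
    (f32_mode, f16f64_mode)
}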

View File

@ -17,8 +17,9 @@ mod expand_operands;
mod fix_special_registers2;
mod hoist_globals;
mod insert_explicit_load_store;
-mod instruction_mode_to_global_mode;
mod insert_implicit_conversions2;
+mod insert_post_saturation;
+mod instruction_mode_to_global_mode;
mod normalize_basic_blocks;
mod normalize_identifiers2;
mod normalize_predicates2;
@ -51,6 +52,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
let directives = resolve_function_pointers::run(directives)?;
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
let directives = expand_operands::run(&mut flat_resolver, directives)?;
let directives = insert_post_saturation::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
let directives = remove_unreachable_basic_blocks::run(directives)?;
@ -202,6 +204,11 @@ enum Statement<I, P: ast::Operand> {
VectorRead(VectorRead),
VectorWrite(VectorWrite),
SetMode(ModeRegister),
FpSaturate {
dst: SpirvWord,
src: SpirvWord,
type_: ast::ScalarType,
},
}
#[derive(Eq, PartialEq, Clone, Copy)]
@ -488,6 +495,21 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
Statement::FunctionPointer(FunctionPointerDetails { dst, src })
}
Statement::SetMode(mode_register) => Statement::SetMode(mode_register),
Statement::FpSaturate { dst, src, type_ } => {
let dst = visitor.visit_ident(
dst,
Some((&type_.into(), ast::StateSpace::Reg)),
true,
false,
)?;
let src = visitor.visit_ident(
src,
Some((&type_.into(), ast::StateSpace::Reg)),
false,
false,
)?;
Statement::FpSaturate { dst, src, type_ }
}
})
}
}

View File

@ -0,0 +1,51 @@
define amdgpu_kernel void @add_s32_sat(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 {
%"39" = alloca i64, align 8, addrspace(5)
%"40" = alloca i64, align 8, addrspace(5)
%"41" = alloca i32, align 4, addrspace(5)
%"42" = alloca i32, align 4, addrspace(5)
%"43" = alloca i32, align 4, addrspace(5)
%"44" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"36"
"36": ; preds = %1
%"45" = load i64, ptr addrspace(4) %"37", align 4
store i64 %"45", ptr addrspace(5) %"39", align 4
%"46" = load i64, ptr addrspace(4) %"38", align 4
store i64 %"46", ptr addrspace(5) %"40", align 4
%"48" = load i64, ptr addrspace(5) %"39", align 4
%"61" = inttoptr i64 %"48" to ptr
%"47" = load i32, ptr %"61", align 4
store i32 %"47", ptr addrspace(5) %"41", align 4
%"49" = load i64, ptr addrspace(5) %"39", align 4
%"62" = inttoptr i64 %"49" to ptr
%"33" = getelementptr inbounds i8, ptr %"62", i64 4
%"50" = load i32, ptr %"33", align 4
store i32 %"50", ptr addrspace(5) %"42", align 4
%"52" = load i32, ptr addrspace(5) %"41", align 4
%"53" = load i32, ptr addrspace(5) %"42", align 4
%"51" = call i32 @llvm.sadd.sat.i32(i32 %"52", i32 %"53")
store i32 %"51", ptr addrspace(5) %"43", align 4
%"55" = load i32, ptr addrspace(5) %"41", align 4
%"56" = load i32, ptr addrspace(5) %"42", align 4
%"54" = add i32 %"55", %"56"
store i32 %"54", ptr addrspace(5) %"44", align 4
%"57" = load i64, ptr addrspace(5) %"40", align 4
%"58" = load i32, ptr addrspace(5) %"43", align 4
%"63" = inttoptr i64 %"57" to ptr
store i32 %"58", ptr %"63", align 4
%"59" = load i64, ptr addrspace(5) %"40", align 4
%"64" = inttoptr i64 %"59" to ptr
%"35" = getelementptr inbounds i8, ptr %"64", i64 4
%"60" = load i32, ptr addrspace(5) %"44", align 4
store i32 %"60", ptr %"35", align 4
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.sadd.sat.i32(i32, i32) #1
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

View File

@ -0,0 +1,38 @@
define amdgpu_kernel void @cvt_rni_u16_f32(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 {
%"33" = alloca i64, align 8, addrspace(5)
%"34" = alloca i64, align 8, addrspace(5)
%"35" = alloca float, align 4, addrspace(5)
%"36" = alloca i16, align 2, addrspace(5)
br label %1
1: ; preds = %0
br label %"30"
"30": ; preds = %1
%"37" = load i64, ptr addrspace(4) %"31", align 4
store i64 %"37", ptr addrspace(5) %"33", align 4
%"38" = load i64, ptr addrspace(4) %"32", align 4
store i64 %"38", ptr addrspace(5) %"34", align 4
%"40" = load i64, ptr addrspace(5) %"33", align 4
%"45" = inttoptr i64 %"40" to ptr addrspace(1)
%"39" = load float, ptr addrspace(1) %"45", align 4
store float %"39", ptr addrspace(5) %"35", align 4
%"42" = load float, ptr addrspace(5) %"35", align 4
%2 = call float @llvm.roundeven.f32(float %"42")
%"41" = call i16 @llvm.fptoui.sat.i16.f32(float %2)
store i16 %"41", ptr addrspace(5) %"36", align 2
%"43" = load i64, ptr addrspace(5) %"34", align 4
%"44" = load i16, ptr addrspace(5) %"36", align 2
%"46" = inttoptr i64 %"43" to ptr
store i16 %"44", ptr %"46", align 2
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare float @llvm.roundeven.f32(float) #1
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i16 @llvm.fptoui.sat.i16.f32(float) #1
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

View File

@ -0,0 +1,24 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry add_s32_sat(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .s32 temp<4>;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.s32 temp0, [in_addr];
ld.s32 temp1, [in_addr+4];
add.sat.s32 temp2, temp0, temp1;
add.s32 temp3, temp0, temp1;
st.s32 [out_addr], temp2;
st.s32 [out_addr+4], temp3;
ret;
}
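The matching ptx_tests entry drives this kernel with i32::MIN and -1 and expects [i32::MIN, i32::MAX]; in host Rust terms the two stores are just the saturating and wrapping forms of the same addition:

fn add_s32_sat_expectations() {
    // add.sat.s32 clamps at the type bounds, plain add.s32 wraps.
    assert_eq!(i32::MIN.saturating_add(-1), i32::MIN);
    assert_eq!(i32::MIN.wrapping_add(-1), i32::MAX);
}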

View File

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry cvt_rni_u16_f32(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f32 temp_f32;
.reg .u16 temp_u16;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.global.f32 temp_f32, [in_addr];
cvt.rni.u16.f32 temp_u16, temp_f32;
st.u16 [out_addr], temp_u16;
ret;
}
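The corresponding ptx_tests entry feeds this kernel 0x477FFF80 and expects 65535. The arithmetic behind that expectation, as a small Rust helper (illustrative name), mirrors the roundeven-then-saturate sequence in the LLVM output shown earlier:

fn expected_cvt_rni_u16(bits: u32) -> u16 {
    let x = f32::from_bits(bits); // 0x477FFF80 is 65535.5
    // .rni rounds to nearest even, giving 65536.0; the `as` cast then saturates
    // to u16::MAX, just like llvm.fptoui.sat.i16.f32.
    x.round_ties_even() as u16
}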

View File

@ -147,6 +147,7 @@ test_ptx!(ex2, [10f32], [1024f32]);
test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]);
test_ptx!(cvt_rzi, [-13.8f32, 12.9f32], [-13f32, 12f32]);
test_ptx!(cvt_s32_f32, [-13.8f32, 12.9f32], [-13i32, 13i32]);
test_ptx!(cvt_rni_u16_f32, [0x477FFF80u32], [65535u16]);
test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]);
test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]);
test_ptx!(
@ -226,6 +227,7 @@ test_ptx!(
[f32::from_bits(0x800000), f32::from_bits(0x007FFFFF)],
[0x800000u32, 0xFFFFFF]
);
test_ptx!(add_s32_sat, [i32::MIN, -1], [i32::MIN, i32::MAX]);
test_ptx!(malformed_label, [2u64], [3u64]);
test_ptx!(
call_rnd,

View File

@ -2,7 +2,7 @@ use super::{
AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp,
StateSpace, VectorPrefix,
};
-use crate::{PtxError, PtxParserState, Mul24Control};
+use crate::{Mul24Control, PtxError, PtxParserState};
use bitflags::bitflags;
use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8};
@ -1197,7 +1197,6 @@ pub enum MulIntControl {
Wide,
}
#[derive(Copy, Clone)]
pub struct Mul24Details {
pub type_: ScalarType,
@ -1473,20 +1472,24 @@ pub enum CvtMode {
SignExtend,
Truncate,
Bitcast,
-SaturateUnsignedToSigned,
-SaturateSignedToUnsigned,
+IntSaturateToSigned,
+IntSaturateToUnsigned,
// float from float
FPExtend {
flush_to_zero: Option<bool>,
+saturate: bool,
},
FPTruncate {
// float rounding
rounding: RoundingMode,
+is_integer_rounding: bool,
flush_to_zero: Option<bool>,
+saturate: bool,
},
FPRound {
-integer_rounding: RoundingMode,
+integer_rounding: Option<RoundingMode>,
flush_to_zero: Option<bool>,
+saturate: bool,
},
// int from float
SignedFromFP {
@ -1498,8 +1501,14 @@ pub enum CvtMode {
flush_to_zero: Option<bool>,
}, // integer rounding
// float from int, ftz is allowed in the grammar, but clearly nonsensical
-FPFromSigned(RoundingMode), // float rounding
-FPFromUnsigned(RoundingMode), // float rounding
+FPFromSigned {
+rounding: RoundingMode,
+saturate: bool,
+}, // float rounding
+FPFromUnsigned {
+rounding: RoundingMode,
+saturate: bool,
+}, // float rounding
}
impl CvtDetails {
@ -1511,9 +1520,6 @@ impl CvtDetails {
dst: ScalarType,
src: ScalarType,
) -> Self {
-if saturate && dst.kind() == ScalarKind::Float {
-errors.push(PtxError::SyntaxError);
-}
// Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results.
let flush_to_zero = match (dst, src) {
(ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz),
@ -1524,55 +1530,81 @@ impl CvtDetails {
None
}
};
-let rounding = rnd.map(Into::into);
+let rounding = rnd.map(RawRoundingMode::normalize);
let mut unwrap_rounding = || match rounding {
-Some(rnd) => rnd,
+Some((rnd, is_integer)) => (rnd, is_integer),
None => {
errors.push(PtxError::SyntaxError);
-RoundingMode::NearestEven
+(RoundingMode::NearestEven, false)
}
};
let mode = match (dst.kind(), src.kind()) {
(ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) {
-Ordering::Less => CvtMode::FPTruncate {
-rounding: unwrap_rounding(),
-flush_to_zero,
-},
+Ordering::Less => {
+let (rounding, is_integer_rounding) = unwrap_rounding();
+CvtMode::FPTruncate {
+rounding,
+is_integer_rounding,
+flush_to_zero,
+saturate,
+}
+}
Ordering::Equal => CvtMode::FPRound {
-integer_rounding: rounding.unwrap_or(RoundingMode::NearestEven),
+integer_rounding: rounding.map(|(rnd, _)| rnd),
flush_to_zero,
+saturate,
},
Ordering::Greater => {
if rounding.is_some() {
errors.push(PtxError::SyntaxError);
}
-CvtMode::FPExtend { flush_to_zero }
+CvtMode::FPExtend {
+flush_to_zero,
+saturate,
+}
}
},
(ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP {
-rounding: unwrap_rounding(),
+rounding: unwrap_rounding().0,
flush_to_zero,
},
(ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP {
-rounding: unwrap_rounding(),
+rounding: unwrap_rounding().0,
flush_to_zero,
},
-(ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()),
-(ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()),
-(ScalarKind::Signed, ScalarKind::Unsigned) if saturate => {
-CvtMode::SaturateUnsignedToSigned
-}
-(ScalarKind::Unsigned, ScalarKind::Signed) if saturate => {
-CvtMode::SaturateSignedToUnsigned
+(ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned {
+rounding: unwrap_rounding().0,
+saturate,
+},
+(ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned {
+rounding: unwrap_rounding().0,
+saturate,
+},
+(ScalarKind::Signed, ScalarKind::Unsigned)
+| (ScalarKind::Signed, ScalarKind::Signed)
+if saturate =>
+{
+CvtMode::IntSaturateToSigned
+}
+(ScalarKind::Unsigned, ScalarKind::Signed)
+| (ScalarKind::Unsigned, ScalarKind::Unsigned)
+if saturate =>
+{
+CvtMode::IntSaturateToUnsigned
+}
+(ScalarKind::Unsigned, ScalarKind::Unsigned)
+| (ScalarKind::Signed, ScalarKind::Signed)
+| (ScalarKind::Unsigned, ScalarKind::Signed)
+| (ScalarKind::Signed, ScalarKind::Unsigned)
+if dst.size_of() == src.size_of() =>
+{
+CvtMode::Bitcast
+}
(ScalarKind::Unsigned, ScalarKind::Unsigned)
-| (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) {
+| (ScalarKind::Signed, ScalarKind::Signed)
+| (ScalarKind::Unsigned, ScalarKind::Signed)
+| (ScalarKind::Signed, ScalarKind::Unsigned) => match dst.size_of().cmp(&src.size_of())
+{
Ordering::Less => CvtMode::Truncate,
Ordering::Equal => CvtMode::Bitcast,
Ordering::Greater => {
@ -1583,7 +1615,6 @@ impl CvtDetails {
}
}
},
-(ScalarKind::Unsigned, ScalarKind::Signed) => CvtMode::SaturateSignedToUnsigned,
(_, _) => {
errors.push(PtxError::SyntaxError);
CvtMode::Bitcast

View File

@ -64,11 +64,21 @@ impl From<RawLdStQualifier> for ast::LdStQualifier {
impl From<RawRoundingMode> for ast::RoundingMode {
fn from(value: RawRoundingMode) -> Self {
-match value {
-RawRoundingMode::Rn | RawRoundingMode::Rni => ast::RoundingMode::NearestEven,
-RawRoundingMode::Rz | RawRoundingMode::Rzi => ast::RoundingMode::Zero,
-RawRoundingMode::Rm | RawRoundingMode::Rmi => ast::RoundingMode::NegativeInf,
-RawRoundingMode::Rp | RawRoundingMode::Rpi => ast::RoundingMode::PositiveInf,
+value.normalize().0
}
}
impl RawRoundingMode {
fn normalize(self) -> (ast::RoundingMode, bool) {
match self {
RawRoundingMode::Rn => (ast::RoundingMode::NearestEven, false),
RawRoundingMode::Rz => (ast::RoundingMode::Zero, false),
RawRoundingMode::Rm => (ast::RoundingMode::NegativeInf, false),
RawRoundingMode::Rp => (ast::RoundingMode::PositiveInf, false),
RawRoundingMode::Rni => (ast::RoundingMode::NearestEven, true),
RawRoundingMode::Rzi => (ast::RoundingMode::Zero, true),
RawRoundingMode::Rmi => (ast::RoundingMode::NegativeInf, true),
RawRoundingMode::Rpi => (ast::RoundingMode::PositiveInf, true),
}
}
}
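The normalize() split above is what the reworked CvtDetails relies on to tell float rounding apart from PTX's integer-rounding suffixes (.rni, .rzi, ...). A crate-internal sketch of how one raw suffix comes apart (the test name is illustrative; normalize() is private, so this only compiles inside ptx_parser):

#[test]
fn rzi_splits_into_round_to_zero_plus_integer_flag() {
    assert!(matches!(
        RawRoundingMode::Rzi.normalize(),
        (ast::RoundingMode::Zero, true)
    ));
}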