From 2a374ad8806ad3d8a396db9de480c0cf7be8be69 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 16 Jun 2025 19:14:16 -0700 Subject: [PATCH] Add fp saturation, fix various bugs in cvt instruction exposed by ptx_tests (#379) --- comgr/src/lib.rs | 8 +- ptx/src/pass/emit_llvm.rs | 382 +++++++++++++----- ptx/src/pass/insert_post_saturation.rs | 296 ++++++++++++++ .../instruction_mode_to_global_mode/mod.rs | 101 +++-- ptx/src/pass/mod.rs | 24 +- ptx/src/test/ll/add_s32_sat.ll | 51 +++ ptx/src/test/ll/cvt_rni_u16_f32.ll | 38 ++ ptx/src/test/spirv_run/add_s32_sat.ptx | 24 ++ ptx/src/test/spirv_run/cvt_rni_u16_f32.ptx | 22 + ptx/src/test/spirv_run/mod.rs | 2 + ptx_parser/src/ast.rs | 91 +++-- ptx_parser/src/lib.rs | 20 +- 12 files changed, 875 insertions(+), 184 deletions(-) create mode 100644 ptx/src/pass/insert_post_saturation.rs create mode 100644 ptx/src/test/ll/add_s32_sat.ll create mode 100644 ptx/src/test/ll/cvt_rni_u16_f32.ll create mode 100644 ptx/src/test/spirv_run/add_s32_sat.ptx create mode 100644 ptx/src/test/spirv_run/cvt_rni_u16_f32.ptx diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index 7588bde..33067ee 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -307,14 +307,14 @@ impl From for Error { } impl From for Error { - fn from(_: comgr2::amd_comgr_status_s) -> Self { - todo!() + fn from(status: comgr2::amd_comgr_status_s) -> Self { + Error(status.0) } } impl From for Error { - fn from(_: comgr3::amd_comgr_status_s) -> Self { - todo!() + fn from(status: comgr3::amd_comgr_status_s) -> Self { + Error(status.0) } } diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 0f432ca..c2562e0 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -518,6 +518,7 @@ impl<'a> MethodEmitContext<'a> { Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?, Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?, Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?, + Statement::FpSaturate { dst, src, type_ } => self.emit_fp_saturate(type_, dst, src)?, }) } @@ -590,7 +591,7 @@ impl<'a> MethodEmitContext<'a> { inst: ast::Instruction, ) -> Result<(), TranslateError> { match inst { - ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments), + ast::Instruction::Mov { data: _, arguments } => self.emit_mov(arguments), ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments), ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments), ast::Instruction::St { data, arguments } => self.emit_st(data, arguments), @@ -836,7 +837,13 @@ impl<'a> MethodEmitContext<'a> { let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; let fn_ = match data { - ast::ArithDetails::Integer(..) => LLVMBuildAdd, + ast::ArithDetails::Integer(ast::ArithInteger { + saturate: true, + type_, + }) => return self.emit_add_sat(type_, arguments), + ast::ArithDetails::Integer(ast::ArithInteger { + saturate: false, .. + }) => LLVMBuildAdd, ast::ArithDetails::Float(..) 
=> LLVMBuildFAdd, }; self.resolver.with_result(arguments.dst, |dst| unsafe { @@ -917,11 +924,7 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } - fn emit_mov( - &mut self, - _data: ast::MovDetails, - arguments: ast::MovArgs, - ) -> Result<(), TranslateError> { + fn emit_mov(&mut self, arguments: ast::MovArgs) -> Result<(), TranslateError> { self.resolver .register(arguments.dst, self.resolver.value(arguments.src)?); Ok(()) @@ -1612,32 +1615,40 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::CvtMode::SignExtend => LLVMBuildSExt, ptx_parser::CvtMode::Truncate => LLVMBuildTrunc, ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast, - ptx_parser::CvtMode::SaturateUnsignedToSigned => { + ptx_parser::CvtMode::IntSaturateToSigned => { return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments) } - ptx_parser::CvtMode::SaturateSignedToUnsigned => { + ptx_parser::CvtMode::IntSaturateToUnsigned => { return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments) } ptx_parser::CvtMode::FPExtend { .. } => LLVMBuildFPExt, ptx_parser::CvtMode::FPTruncate { .. } => LLVMBuildFPTrunc, ptx_parser::CvtMode::FPRound { - integer_rounding, .. + integer_rounding: None, + flush_to_zero: None | Some(false), + .. } => { - return self.emit_cvt_float_to_int( - data.from, - data.to, - integer_rounding, - arguments, - Some(LLVMBuildFPToSI), - ) + return self.emit_mov(ast::MovArgs { + dst: arguments.dst, + src: arguments.src, + }) } + ptx_parser::CvtMode::FPRound { + integer_rounding: None, + flush_to_zero: Some(true), + .. + } => return self.flush_denormals(data.to, arguments.src, arguments.dst), + ptx_parser::CvtMode::FPRound { + integer_rounding: Some(rounding), + .. + } => return self.emit_cvt_float_to_int(data.from, data.to, rounding, arguments, None), ptx_parser::CvtMode::SignedFromFP { rounding, .. } => { return self.emit_cvt_float_to_int( data.from, data.to, rounding, arguments, - Some(LLVMBuildFPToSI), + Some(true), ) } ptx_parser::CvtMode::UnsignedFromFP { rounding, .. } => { @@ -1646,13 +1657,13 @@ impl<'a> MethodEmitContext<'a> { data.to, rounding, arguments, - Some(LLVMBuildFPToUI), + Some(false), ) } - ptx_parser::CvtMode::FPFromSigned(_) => { + ptx_parser::CvtMode::FPFromSigned { .. } => { return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildSIToFP) } - ptx_parser::CvtMode::FPFromUnsigned(_) => { + ptx_parser::CvtMode::FPFromUnsigned { .. } => { return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildUIToFP) } }; @@ -1669,27 +1680,7 @@ impl<'a> MethodEmitContext<'a> { to: ptx_parser::ScalarType, arguments: ptx_parser::CvtArgs, ) -> Result<(), TranslateError> { - // This looks dodgy, but it's fine. 
MAX bit pattern is always 0b11..1, - // so if it's downcast to a smaller type, it will be the maximum value - // of the smaller type - let max_value = match to { - ptx_parser::ScalarType::S8 => i8::MAX as u64, - ptx_parser::ScalarType::S16 => i16::MAX as u64, - ptx_parser::ScalarType::S32 => i32::MAX as u64, - ptx_parser::ScalarType::S64 => i64::MAX as u64, - _ => return Err(error_unreachable()), - }; - let from_llvm = get_scalar_type(self.context, from); - let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) }; - let clamped = self.emit_intrinsic( - c"llvm.umin", - None, - Some(&from.into()), - vec![ - (self.resolver.value(arguments.src)?, from_llvm), - (max, from_llvm), - ], - )?; + let clamped = self.emit_saturate_integer(from, to, &arguments)?; let resize_fn = if to.layout().size() >= from.layout().size() { LLVMBuildSExtOrBitCast } else { @@ -1702,40 +1693,92 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn emit_saturate_integer( + &mut self, + from: ptx_parser::ScalarType, + to: ptx_parser::ScalarType, + arguments: &ptx_parser::CvtArgs, + ) -> Result { + let from_llvm = get_scalar_type(self.context, from); + match from.kind() { + ptx_parser::ScalarKind::Unsigned => { + let max_value = match to { + ptx_parser::ScalarType::U8 => u8::MAX as u64, + ptx_parser::ScalarType::S8 => i8::MAX as u64, + ptx_parser::ScalarType::U16 => u16::MAX as u64, + ptx_parser::ScalarType::S16 => i16::MAX as u64, + ptx_parser::ScalarType::U32 => u32::MAX as u64, + ptx_parser::ScalarType::S32 => i32::MAX as u64, + ptx_parser::ScalarType::U64 => u64::MAX as u64, + ptx_parser::ScalarType::S64 => i64::MAX as u64, + _ => return Err(error_unreachable()), + }; + let intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from)); + let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) }; + let clamped = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + None, + Some(&from.into()), + vec![ + (self.resolver.value(arguments.src)?, from_llvm), + (max, from_llvm), + ], + )?; + Ok(clamped) + } + ptx_parser::ScalarKind::Signed => { + let (min_value_from, max_value_from) = match from { + ptx_parser::ScalarType::S8 => (i8::MIN as i128, i8::MAX as i128), + ptx_parser::ScalarType::S16 => (i16::MIN as i128, i16::MAX as i128), + ptx_parser::ScalarType::S32 => (i32::MIN as i128, i32::MAX as i128), + ptx_parser::ScalarType::S64 => (i64::MIN as i128, i64::MAX as i128), + _ => return Err(error_unreachable()), + }; + let (min_value_to, max_value_to) = match to { + ptx_parser::ScalarType::U8 => (u8::MIN as i128, u8::MAX as i128), + ptx_parser::ScalarType::S8 => (i8::MIN as i128, i8::MAX as i128), + ptx_parser::ScalarType::U16 => (u16::MIN as i128, u16::MAX as i128), + ptx_parser::ScalarType::S16 => (i16::MIN as i128, i16::MAX as i128), + ptx_parser::ScalarType::U32 => (u32::MIN as i128, u32::MAX as i128), + ptx_parser::ScalarType::S32 => (i32::MIN as i128, i32::MAX as i128), + ptx_parser::ScalarType::U64 => (u64::MIN as i128, u64::MAX as i128), + ptx_parser::ScalarType::S64 => (i64::MIN as i128, i64::MAX as i128), + _ => return Err(error_unreachable()), + }; + let min_value = min_value_from.max(min_value_to); + let max_value = max_value_from.min(max_value_to); + let max_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from)); + let min = unsafe { LLVMConstInt(from_llvm, min_value as u64, 1) }; + let min_intrinsic = format!("llvm.smin.{}\0", LLVMTypeDisplay(from)); + let max = unsafe { LLVMConstInt(from_llvm, max_value as u64, 1) }; + let clamped = self.emit_intrinsic( + unsafe { 
CStr::from_bytes_with_nul_unchecked(max_intrinsic.as_bytes()) }, + None, + Some(&from.into()), + vec![ + (self.resolver.value(arguments.src)?, from_llvm), + (min, from_llvm), + ], + )?; + let clamped = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(min_intrinsic.as_bytes()) }, + None, + Some(&from.into()), + vec![(clamped, from_llvm), (max, from_llvm)], + )?; + Ok(clamped) + } + _ => return Err(error_unreachable()), + } + } + fn emit_cvt_signed_to_unsigned_sat( &mut self, from: ptx_parser::ScalarType, to: ptx_parser::ScalarType, arguments: ptx_parser::CvtArgs, ) -> Result<(), TranslateError> { - let from_llvm = get_scalar_type(self.context, from); - let zero = unsafe { LLVMConstInt(from_llvm, 0, 0) }; - let zero_clamp_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from)); - let zero_clamped = self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) }, - None, - Some(&from.into()), - vec![ - (self.resolver.value(arguments.src)?, from_llvm), - (zero, from_llvm), - ], - )?; - // zero_clamped is now unsigned - let max_value = match to { - ptx_parser::ScalarType::U8 => u8::MAX as u64, - ptx_parser::ScalarType::U16 => u16::MAX as u64, - ptx_parser::ScalarType::U32 => u32::MAX as u64, - ptx_parser::ScalarType::U64 => u64::MAX as u64, - _ => return Err(error_unreachable()), - }; - let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) }; - let max_clamp_intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from)); - let fully_clamped = self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) }, - None, - Some(&from.into()), - vec![(zero_clamped, from_llvm), (max, from_llvm)], - )?; + let clamped = self.emit_saturate_integer(from, to, &arguments)?; let resize_fn = if to.layout().size() >= from.layout().size() { LLVMBuildZExtOrBitCast } else { @@ -1743,7 +1786,7 @@ impl<'a> MethodEmitContext<'a> { }; let to_llvm = get_scalar_type(self.context, to); self.resolver.with_result(arguments.dst, |dst| unsafe { - resize_fn(self.builder, fully_clamped, to_llvm, dst) + resize_fn(self.builder, clamped, to_llvm, dst) }); Ok(()) } @@ -1754,18 +1797,89 @@ impl<'a> MethodEmitContext<'a> { to: ast::ScalarType, rounding: ast::RoundingMode, arguments: ptx_parser::CvtArgs, - llvm_cast: Option< - unsafe extern "C" fn( - arg1: LLVMBuilderRef, - Val: LLVMValueRef, - DestTy: LLVMTypeRef, - Name: *const i8, - ) -> LLVMValueRef, - >, + signed_cast: Option, ) -> Result<(), TranslateError> { + let dst_int_rounded = + self.emit_fp_int_rounding(from, rounding, &arguments, signed_cast.is_some())?; + // In PTX all the int-from-float casts are saturating casts. On the other hand, in LLVM, + // out-of-range fptoui and fptosi have undefined behavior. + // We could handle this all with llvm.fptosi.sat and llvm.fptoui.sat intrinsics, but + // the problem is that, when using *.sat variants AMDGPU target _always_ emits saturation + // checks. Often they are unnecessary because v_cvt_* instructions saturates anyway. 
+ // For that reason, all from-to combinations that we know have a direct corresponding + // v_cvt_* instruction get special treatment + let is_saturating_cast = match (to, from) { + (ast::ScalarType::S16, ast::ScalarType::F16) + | (ast::ScalarType::S32, ast::ScalarType::F32) + | (ast::ScalarType::S32, ast::ScalarType::F64) + | (ast::ScalarType::U16, ast::ScalarType::F16) + | (ast::ScalarType::U32, ast::ScalarType::F32) + | (ast::ScalarType::U32, ast::ScalarType::F64) => true, + _ => false, + }; + let signed_cast = match signed_cast { + Some(s) => s, + None => { + self.resolver.register( + arguments.dst, + dst_int_rounded.ok_or_else(error_unreachable)?, + ); + return Ok(()); + } + }; + if is_saturating_cast { + let to = get_scalar_type(self.context, to); + let src = + dst_int_rounded.unwrap_or_else(|| self.resolver.value(arguments.src).unwrap()); + let llvm_cast = if signed_cast { + LLVMBuildFPToSI + } else { + LLVMBuildFPToUI + }; + let poisoned_dst = unsafe { llvm_cast(self.builder, src, to, LLVM_UNNAMED.as_ptr()) }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFreeze(self.builder, poisoned_dst, dst) + }); + } else { + let cvt_op = if to.kind() == ptx_parser::ScalarKind::Unsigned { + "fptoui" + } else { + "fptosi" + }; + let cast_intrinsic = format!( + "llvm.{cvt_op}.sat.{}.{}\0", + LLVMTypeDisplay(to), + LLVMTypeDisplay(from) + ); + let src = + dst_int_rounded.unwrap_or_else(|| self.resolver.value(arguments.src).unwrap()); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) }, + Some(arguments.dst), + Some(&to.into()), + vec![(src, get_scalar_type(self.context, from))], + )?; + } + Ok(()) + } + + fn emit_fp_int_rounding( + &mut self, + from: ptx_parser::ScalarType, + rounding: ptx_parser::RoundingMode, + arguments: &ptx_parser::CvtArgs, + will_saturate_with_cvt: bool, + ) -> Result, TranslateError> { let prefix = match rounding { ptx_parser::RoundingMode::NearestEven => "llvm.roundeven", - ptx_parser::RoundingMode::Zero => "llvm.trunc", + ptx_parser::RoundingMode::Zero => { + // cvt has round-to-zero semantics + if will_saturate_with_cvt { + return Ok(None); + } else { + "llvm.trunc" + } + } ptx_parser::RoundingMode::NegativeInf => "llvm.floor", ptx_parser::RoundingMode::PositiveInf => "llvm.ceil", }; @@ -1779,34 +1893,7 @@ impl<'a> MethodEmitContext<'a> { get_scalar_type(self.context, from), )], )?; - if let Some(llvm_cast) = llvm_cast { - let to = get_scalar_type(self.context, to); - let poisoned_dst = - unsafe { llvm_cast(self.builder, rounded_float, to, LLVM_UNNAMED.as_ptr()) }; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildFreeze(self.builder, poisoned_dst, dst) - }); - } else { - self.resolver.register(arguments.dst, rounded_float); - } - // Using explicit saturation gives us worse codegen: it explicitly checks for out of bound - // values and NaNs. 
Using non-saturated fptosi/fptoui emits v_cvt__ which - // saturates by default and we don't care about NaNs anyway - /* - let cast_intrinsic = format!( - "{}.{}.{}\0", - llvm_cast, - LLVMTypeDisplay(to), - LLVMTypeDisplay(from) - ); - self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) }, - Some(arguments.dst), - &to.into(), - vec![(rounded_float, get_scalar_type(self.context, from))], - )?; - */ - Ok(()) + Ok(Some(rounded_float)) } fn emit_cvt_int_to_float( @@ -2289,7 +2376,11 @@ impl<'a> MethodEmitContext<'a> { }; let res_lo = self.emit_intrinsic( name_lo, - if data.control == Mul24Control::Lo { Some(arguments.dst) } else { None }, + if data.control == Mul24Control::Lo { + Some(arguments.dst) + } else { + None + }, Some(&ast::Type::Scalar(data.type_)), vec![ (src1, get_scalar_type(self.context, data.type_)), @@ -2316,9 +2407,8 @@ impl<'a> MethodEmitContext<'a> { ], )?; let shift_number = unsafe { LLVMConstInt(LLVMInt32TypeInContext(self.context), 16, 0) }; - let res_lo_shr = unsafe { - LLVMBuildLShr(self.builder, res_lo, shift_number, LLVM_UNNAMED.as_ptr()) - }; + let res_lo_shr = + unsafe { LLVMBuildLShr(self.builder, res_lo, shift_number, LLVM_UNNAMED.as_ptr()) }; let res_hi_shl = unsafe { LLVMBuildShl(self.builder, res_hi, shift_number, LLVM_UNNAMED.as_ptr()) }; @@ -2381,6 +2471,74 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn emit_fp_saturate( + &mut self, + type_: ast::ScalarType, + dst: SpirvWord, + src: SpirvWord, + ) -> Result<(), TranslateError> { + let llvm_type = get_scalar_type(self.context, type_); + let zero = unsafe { LLVMConstReal(llvm_type, 0.0) }; + let one = unsafe { LLVMConstReal(llvm_type, 1.0) }; + let maxnum_intrinsic = format!("llvm.maxnum.{}\0", LLVMTypeDisplay(type_)); + let minnum_intrinsic = format!("llvm.minnum.{}\0", LLVMTypeDisplay(type_)); + let src = self.resolver.value(src)?; + let maxnum = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(maxnum_intrinsic.as_bytes()) }, + None, + Some(&type_.into()), + vec![(src, llvm_type), (zero, llvm_type)], + )?; + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(minnum_intrinsic.as_bytes()) }, + Some(dst), + Some(&type_.into()), + vec![(maxnum, llvm_type), (one, llvm_type)], + )?; + Ok(()) + } + + fn emit_add_sat( + &mut self, + type_: ast::ScalarType, + arguments: ast::AddArgs, + ) -> Result<(), TranslateError> { + let llvm_type = get_scalar_type(self.context, type_); + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let op = if type_.kind() == ast::ScalarKind::Signed { + "sadd" + } else { + "uadd" + }; + let intrinsic = format!("llvm.{}.sat.{}\0", op, LLVMTypeDisplay(type_)); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + Some(&type_.into()), + vec![(src1, llvm_type), (src2, llvm_type)], + )?; + Ok(()) + } + + fn flush_denormals( + &mut self, + type_: ptx_parser::ScalarType, + src: SpirvWord, + dst: SpirvWord, + ) -> Result<(), TranslateError> { + let llvm_type = get_scalar_type(self.context, type_); + let src = self.resolver.value(src)?; + let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(type_)); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(dst), + Some(&type_.into()), + vec![(src, llvm_type)], + )?; + Ok(()) + } + /* // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` // Should be available in LLVM 19 diff 
--git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs new file mode 100644 index 0000000..cc9afa7 --- /dev/null +++ b/ptx/src/pass/insert_post_saturation.rs @@ -0,0 +1,296 @@ +use super::*; +use ptx_parser as ast; + +pub(super) fn run<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2, + directive: Directive2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), + }) +} + +fn run_method<'input>( + resolver: &mut GlobalStringIdentResolver2, + method: Function2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + let mut new_statements = Vec::new(); + let body = method + .body + .map(|statements| { + for statement in statements { + run_statement(resolver, &mut new_statements, statement)?; + } + Ok::<_, TranslateError>(new_statements) + }) + .transpose()?; + Ok(Function2 { body, ..method }) +} + +fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + result: &mut Vec, SpirvWord>>, + statement: Statement, SpirvWord>, +) -> Result<(), TranslateError> { + match statement { + Statement::Instruction(inst) => run_instruction(resolver, result, inst)?, + statement => { + result.push(statement); + } + } + Ok(()) +} + +fn run_instruction<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + result: &mut Vec, SpirvWord>>, + mut instruction: ast::Instruction, +) -> Result<(), TranslateError> { + match instruction { + ast::Instruction::Abs { .. } + | ast::Instruction::Activemask { .. } + | ast::Instruction::Add { + data: + ast::ArithDetails::Float(ast::ArithFloat { + saturate: false, .. + }), + .. + } + | ast::Instruction::Add { + data: ast::ArithDetails::Integer(..), + .. + } + | ast::Instruction::And { .. } + | ast::Instruction::Atom { .. } + | ast::Instruction::AtomCas { .. } + | ast::Instruction::Bar { .. } + | ast::Instruction::Bfe { .. } + | ast::Instruction::Bfi { .. } + | ast::Instruction::Bra { .. } + | ast::Instruction::Brev { .. } + | ast::Instruction::Call { .. } + | ast::Instruction::Clz { .. } + | ast::Instruction::Cos { .. } + | ast::Instruction::Cvt { + data: + ast::CvtDetails { + mode: + ast::CvtMode::ZeroExtend + | ast::CvtMode::SignExtend + | ast::CvtMode::Truncate + | ast::CvtMode::Bitcast + | ast::CvtMode::IntSaturateToSigned + | ast::CvtMode::IntSaturateToUnsigned + | ast::CvtMode::SignedFromFP { .. } + | ast::CvtMode::UnsignedFromFP { .. } + | ast::CvtMode::FPFromSigned { + saturate: false, .. + } + | ast::CvtMode::FPFromUnsigned { + saturate: false, .. + } + | ast::CvtMode::FPExtend { + saturate: false, .. + } + | ast::CvtMode::FPTruncate { + saturate: false, .. + } + | ast::CvtMode::FPRound { + saturate: false, .. + }, + .. + }, + .. + } + | ast::Instruction::Cvta { .. } + | ast::Instruction::Div { .. } + | ast::Instruction::Ex2 { .. } + | ast::Instruction::Fma { + data: ast::ArithFloat { + saturate: false, .. + }, + .. + } + | ast::Instruction::Ld { .. } + | ast::Instruction::Lg2 { .. } + | ast::Instruction::Mad { + data: + ast::MadDetails::Float(ast::ArithFloat { + saturate: false, .. + }), + .. + } + | ast::Instruction::Mad { + data: ast::MadDetails::Integer { .. }, + .. 
+ } + | ast::Instruction::Max { .. } + | ast::Instruction::Membar { .. } + | ast::Instruction::Min { .. } + | ast::Instruction::Mov { .. } + | ast::Instruction::Mul { + data: + ast::MulDetails::Float(ast::ArithFloat { + saturate: false, .. + }), + .. + } + | ast::Instruction::Mul { + data: ast::MulDetails::Integer { .. }, + .. + } + | ast::Instruction::Mul24 { .. } + | ast::Instruction::Neg { .. } + | ast::Instruction::Not { .. } + | ast::Instruction::Or { .. } + | ast::Instruction::Popc { .. } + | ast::Instruction::Prmt { .. } + | ast::Instruction::PrmtSlow { .. } + | ast::Instruction::Rcp { .. } + | ast::Instruction::Rem { .. } + | ast::Instruction::Ret { .. } + | ast::Instruction::Rsqrt { .. } + | ast::Instruction::Selp { .. } + | ast::Instruction::Setp { .. } + | ast::Instruction::SetpBool { .. } + | ast::Instruction::Shl { .. } + | ast::Instruction::Shr { .. } + | ast::Instruction::Sin { .. } + | ast::Instruction::Sqrt { .. } + | ast::Instruction::St { .. } + | ast::Instruction::Sub { + data: + ast::ArithDetails::Float(ast::ArithFloat { + saturate: false, .. + }), + .. + } + | ast::Instruction::Sub { + data: ast::ArithDetails::Integer(..), + .. + } + | ast::Instruction::Trap {} + | ast::Instruction::Xor { .. } => result.push(Statement::Instruction(instruction)), + ast::Instruction::Add { + data: + ast::ArithDetails::Float(ast::ArithFloat { + saturate: true, + type_, + .. + }), + arguments: ast::AddArgs { ref mut dst, .. }, + } + | ast::Instruction::Fma { + data: + ast::ArithFloat { + saturate: true, + type_, + .. + }, + arguments: ast::FmaArgs { ref mut dst, .. }, + } + | ast::Instruction::Mad { + data: + ast::MadDetails::Float(ast::ArithFloat { + saturate: true, + type_, + .. + }), + arguments: ast::MadArgs { ref mut dst, .. }, + } + | ast::Instruction::Mul { + data: + ast::MulDetails::Float(ast::ArithFloat { + saturate: true, + type_, + .. + }), + arguments: ast::MulArgs { ref mut dst, .. }, + } + | ast::Instruction::Sub { + data: + ast::ArithDetails::Float(ast::ArithFloat { + saturate: true, + type_, + .. + }), + arguments: ast::SubArgs { ref mut dst, .. }, + } + | ast::Instruction::Cvt { + data: + ast::CvtDetails { + to: type_, + mode: ast::CvtMode::FPExtend { saturate: true, .. }, + .. + }, + arguments: ast::CvtArgs { ref mut dst, .. }, + } + | ast::Instruction::Cvt { + data: + ast::CvtDetails { + to: type_, + mode: ast::CvtMode::FPTruncate { saturate: true, .. }, + .. + }, + arguments: ast::CvtArgs { ref mut dst, .. }, + } + | ast::Instruction::Cvt { + data: + ast::CvtDetails { + to: type_, + mode: ast::CvtMode::FPRound { saturate: true, .. }, + .. + }, + arguments: ast::CvtArgs { ref mut dst, .. }, + } + | ast::Instruction::Cvt { + data: + ast::CvtDetails { + to: type_, + mode: ast::CvtMode::FPFromSigned { saturate: true, .. }, + .. + }, + arguments: ast::CvtArgs { ref mut dst, .. }, + } + | ast::Instruction::Cvt { + data: + ast::CvtDetails { + to: type_, + mode: ast::CvtMode::FPFromUnsigned { saturate: true, .. }, + .. + }, + arguments: ast::CvtArgs { ref mut dst, .. 
}, + } => { + let sat = get_post_saturation(resolver, type_, dst)?; + result.push(Statement::Instruction(instruction)); + result.push(sat); + } + } + Ok(()) +} + +fn get_post_saturation<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + type_: ast::ScalarType, + old_dst: &mut SpirvWord, +) -> Result, SpirvWord>, TranslateError> { + let post_sat = resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))); + let dst = *old_dst; + *old_dst = post_sat; + Ok(Statement::FpSaturate { + dst, + src: post_sat, + type_, + }) +} diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index c2b9672..fdaafd1 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -167,26 +167,42 @@ impl InstructionModes { } } - fn mixed_ftz_f32( - type_: ast::ScalarType, - denormal: Option, - rounding: Option, + fn from_typed_denormal_rounding( + from_type: ast::ScalarType, + to_type: ast::ScalarType, + denormal: DenormalMode, + rounding: RoundingMode, ) -> Self { - if type_ != ast::ScalarType::F32 { - Self { - denormal_f16f64: denormal, - rounding_f32: rounding, - ..Self::none() - } - } else { - Self { - denormal_f32: denormal, - rounding_f32: rounding, - ..Self::none() - } + Self { + rounding_f32: Some(rounding), + rounding_f16f64: Some(rounding), + ..Self::from_typed_denormal(from_type, to_type, denormal) } } + // This function accepts DenormalMode and not Option because + // the semantics are slightly different. + // * In instructions `None` means: flush-to-zero has not been explicitly requested + // * In this pass `None` means: neither flush-to-zero, nor preserve is applicable + fn from_typed_denormal( + from_type: ast::ScalarType, + to_type: ast::ScalarType, + denormal: DenormalMode, + ) -> Self { + let mut result = Self::none(); + if from_type == ast::ScalarType::F32 || to_type == ast::ScalarType::F32 { + result.denormal_f32 = if denormal == DenormalMode::FlushToZero { + Some(DenormalMode::FlushToZero) + } else { + Some(DenormalMode::Preserve) + }; + } + if !(from_type == ast::ScalarType::F32 && to_type == ast::ScalarType::F32) { + result.denormal_f16f64 = Some(DenormalMode::Preserve); + } + result + } + fn from_arith_float(arith: &ast::ArithFloat) -> InstructionModes { let denormal = arith.flush_to_zero.map(DenormalMode::from_ftz); let rounding = Some(RoundingMode::from_ast(arith.rounding)); @@ -220,31 +236,52 @@ impl InstructionModes { | ast::CvtMode::SignExtend | ast::CvtMode::Truncate | ast::CvtMode::Bitcast - | ast::CvtMode::SaturateUnsignedToSigned - | ast::CvtMode::SaturateSignedToUnsigned => Self::none(), - ast::CvtMode::FPExtend { flush_to_zero } => { - Self::from_ftz(ast::ScalarType::F32, flush_to_zero) - } + | ast::CvtMode::IntSaturateToSigned + | ast::CvtMode::IntSaturateToUnsigned => Self::none(), + ast::CvtMode::FPExtend { flush_to_zero, .. } => Self::from_typed_denormal( + cvt.from, + cvt.to, + flush_to_zero + .map(DenormalMode::from_ftz) + .unwrap_or(DenormalMode::Preserve), + ), ast::CvtMode::FPTruncate { rounding, flush_to_zero, + is_integer_rounding, + .. 
+ } => { + let denormal_mode = match (is_integer_rounding, flush_to_zero) { + (true, Some(true)) => DenormalMode::FlushToZero, + _ => DenormalMode::Preserve, + }; + Self::from_typed_denormal_rounding( + cvt.from, + cvt.to, + denormal_mode, + RoundingMode::from_ast(rounding), + ) } - | ast::CvtMode::FPRound { - integer_rounding: rounding, - flush_to_zero, - } => Self::mixed_ftz_f32( + ast::CvtMode::FPRound { flush_to_zero, .. } => Self::from_typed_denormal( + cvt.from, cvt.to, - flush_to_zero.map(DenormalMode::from_ftz), - Some(RoundingMode::from_ast(rounding)), + flush_to_zero + .map(DenormalMode::from_ftz) + .unwrap_or(DenormalMode::Preserve), ), // float to int contains rounding field, but it's not a rounding // mode but rather round-to-int operation that will be applied ast::CvtMode::SignedFromFP { flush_to_zero, .. } - | ast::CvtMode::UnsignedFromFP { flush_to_zero, .. } => { - Self::new(cvt.from, flush_to_zero.map(DenormalMode::from_ftz), None) - } - ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => { - Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd))) + | ast::CvtMode::UnsignedFromFP { flush_to_zero, .. } => Self::from_typed_denormal( + cvt.from, + cvt.from, + flush_to_zero + .map(DenormalMode::from_ftz) + .unwrap_or(DenormalMode::Preserve), + ), + ast::CvtMode::FPFromSigned { rounding, .. } + | ast::CvtMode::FPFromUnsigned { rounding, .. } => { + Self::new(cvt.to, None, Some(RoundingMode::from_ast(rounding))) } } } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 77d7e60..6b2042e 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -17,8 +17,9 @@ mod expand_operands; mod fix_special_registers2; mod hoist_globals; mod insert_explicit_load_store; -mod instruction_mode_to_global_mode; mod insert_implicit_conversions2; +mod insert_post_saturation; +mod instruction_mode_to_global_mode; mod normalize_basic_blocks; mod normalize_identifiers2; mod normalize_predicates2; @@ -51,6 +52,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result { VectorRead(VectorRead), VectorWrite(VectorWrite), SetMode(ModeRegister), + FpSaturate { + dst: SpirvWord, + src: SpirvWord, + type_: ast::ScalarType, + }, } #[derive(Eq, PartialEq, Clone, Copy)] @@ -488,6 +495,21 @@ impl> Statement, T> { Statement::FunctionPointer(FunctionPointerDetails { dst, src }) } Statement::SetMode(mode_register) => Statement::SetMode(mode_register), + Statement::FpSaturate { dst, src, type_ } => { + let dst = visitor.visit_ident( + dst, + Some((&type_.into(), ast::StateSpace::Reg)), + true, + false, + )?; + let src = visitor.visit_ident( + src, + Some((&type_.into(), ast::StateSpace::Reg)), + false, + false, + )?; + Statement::FpSaturate { dst, src, type_ } + } }) } } diff --git a/ptx/src/test/ll/add_s32_sat.ll b/ptx/src/test/ll/add_s32_sat.ll new file mode 100644 index 0000000..d50ae8d --- /dev/null +++ b/ptx/src/test/ll/add_s32_sat.ll @@ -0,0 +1,51 @@ +define amdgpu_kernel void @add_s32_sat(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 { + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca i64, align 8, addrspace(5) + %"41" = alloca i32, align 4, addrspace(5) + %"42" = alloca i32, align 4, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + %"44" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"36" + +"36": ; preds = %1 + %"45" = load i64, ptr addrspace(4) %"37", align 4 + store i64 %"45", ptr addrspace(5) %"39", align 4 + %"46" = load i64, ptr addrspace(4) %"38", align 4 + store i64 %"46", 
ptr addrspace(5) %"40", align 4 + %"48" = load i64, ptr addrspace(5) %"39", align 4 + %"61" = inttoptr i64 %"48" to ptr + %"47" = load i32, ptr %"61", align 4 + store i32 %"47", ptr addrspace(5) %"41", align 4 + %"49" = load i64, ptr addrspace(5) %"39", align 4 + %"62" = inttoptr i64 %"49" to ptr + %"33" = getelementptr inbounds i8, ptr %"62", i64 4 + %"50" = load i32, ptr %"33", align 4 + store i32 %"50", ptr addrspace(5) %"42", align 4 + %"52" = load i32, ptr addrspace(5) %"41", align 4 + %"53" = load i32, ptr addrspace(5) %"42", align 4 + %"51" = call i32 @llvm.sadd.sat.i32(i32 %"52", i32 %"53") + store i32 %"51", ptr addrspace(5) %"43", align 4 + %"55" = load i32, ptr addrspace(5) %"41", align 4 + %"56" = load i32, ptr addrspace(5) %"42", align 4 + %"54" = add i32 %"55", %"56" + store i32 %"54", ptr addrspace(5) %"44", align 4 + %"57" = load i64, ptr addrspace(5) %"40", align 4 + %"58" = load i32, ptr addrspace(5) %"43", align 4 + %"63" = inttoptr i64 %"57" to ptr + store i32 %"58", ptr %"63", align 4 + %"59" = load i64, ptr addrspace(5) %"40", align 4 + %"64" = inttoptr i64 %"59" to ptr + %"35" = getelementptr inbounds i8, ptr %"64", i64 4 + %"60" = load i32, ptr addrspace(5) %"44", align 4 + store i32 %"60", ptr %"35", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.sadd.sat.i32(i32, i32) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } \ No newline at end of file diff --git a/ptx/src/test/ll/cvt_rni_u16_f32.ll b/ptx/src/test/ll/cvt_rni_u16_f32.ll new file mode 100644 index 0000000..7b66751 --- /dev/null +++ b/ptx/src/test/ll/cvt_rni_u16_f32.ll @@ -0,0 +1,38 @@ +define amdgpu_kernel void @cvt_rni_u16_f32(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #0 { + %"33" = alloca i64, align 8, addrspace(5) + %"34" = alloca i64, align 8, addrspace(5) + %"35" = alloca float, align 4, addrspace(5) + %"36" = alloca i16, align 2, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"30" + +"30": ; preds = %1 + %"37" = load i64, ptr addrspace(4) %"31", align 4 + store i64 %"37", ptr addrspace(5) %"33", align 4 + %"38" = load i64, ptr addrspace(4) %"32", align 4 + store i64 %"38", ptr addrspace(5) %"34", align 4 + %"40" = load i64, ptr addrspace(5) %"33", align 4 + %"45" = inttoptr i64 %"40" to ptr addrspace(1) + %"39" = load float, ptr addrspace(1) %"45", align 4 + store float %"39", ptr addrspace(5) %"35", align 4 + %"42" = load float, ptr addrspace(5) %"35", align 4 + %2 = call float @llvm.roundeven.f32(float %"42") + %"41" = call i16 @llvm.fptoui.sat.i16.f32(float %2) + store i16 %"41", ptr addrspace(5) %"36", align 2 + %"43" = load i64, ptr addrspace(5) %"34", align 4 + %"44" = load i16, ptr addrspace(5) %"36", align 2 + %"46" = inttoptr i64 %"43" to ptr + store i16 %"44", ptr %"46", align 2 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare float @llvm.roundeven.f32(float) #1 + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i16 @llvm.fptoui.sat.i16.f32(float) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" 
"uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/add_s32_sat.ptx b/ptx/src/test/spirv_run/add_s32_sat.ptx new file mode 100644 index 0000000..8ed9e05 --- /dev/null +++ b/ptx/src/test/spirv_run/add_s32_sat.ptx @@ -0,0 +1,24 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry add_s32_sat( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .s32 temp<4>; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.s32 temp0, [in_addr]; + ld.s32 temp1, [in_addr+4]; + add.sat.s32 temp2, temp0, temp1; + add.s32 temp3, temp0, temp1; + st.s32 [out_addr], temp2; + st.s32 [out_addr+4], temp3; + ret; +} diff --git a/ptx/src/test/spirv_run/cvt_rni_u16_f32.ptx b/ptx/src/test/spirv_run/cvt_rni_u16_f32.ptx new file mode 100644 index 0000000..baf7bca --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rni_u16_f32.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry cvt_rni_u16_f32( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 temp_f32; + .reg .u16 temp_u16; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.global.f32 temp_f32, [in_addr]; + cvt.rni.u16.f32 temp_u16, temp_f32; + st.u16 [out_addr], temp_u16; + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f2c8ffa..424a1b8 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -147,6 +147,7 @@ test_ptx!(ex2, [10f32], [1024f32]); test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]); test_ptx!(cvt_rzi, [-13.8f32, 12.9f32], [-13f32, 12f32]); test_ptx!(cvt_s32_f32, [-13.8f32, 12.9f32], [-13i32, 13i32]); +test_ptx!(cvt_rni_u16_f32, [0x477FFF80u32], [65535u16]); test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]); test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]); test_ptx!( @@ -226,6 +227,7 @@ test_ptx!( [f32::from_bits(0x800000), f32::from_bits(0x007FFFFF)], [0x800000u32, 0xFFFFFF] ); +test_ptx!(add_s32_sat, [i32::MIN, -1], [i32::MIN, i32::MAX]); test_ptx!(malformed_label, [2u64], [3u64]); test_ptx!( call_rnd, diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 4e2502d..ca7b9df 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -2,7 +2,7 @@ use super::{ AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix, }; -use crate::{PtxError, PtxParserState, Mul24Control}; +use crate::{Mul24Control, PtxError, PtxParserState}; use bitflags::bitflags; use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8}; @@ -1197,7 +1197,6 @@ pub enum MulIntControl { Wide, } - #[derive(Copy, Clone)] pub struct Mul24Details { pub type_: ScalarType, @@ -1473,20 +1472,24 @@ pub enum CvtMode { SignExtend, Truncate, Bitcast, - SaturateUnsignedToSigned, - SaturateSignedToUnsigned, + IntSaturateToSigned, + IntSaturateToUnsigned, // float from float FPExtend { flush_to_zero: Option, + saturate: bool, }, FPTruncate { // float rounding rounding: RoundingMode, + is_integer_rounding: bool, flush_to_zero: Option, + saturate: bool, }, FPRound { - integer_rounding: RoundingMode, + integer_rounding: Option, flush_to_zero: Option, + saturate: bool, }, // int from float SignedFromFP { @@ -1498,8 +1501,14 @@ pub enum CvtMode { flush_to_zero: Option, }, // integer rounding // float from int, 
ftz is allowed in the grammar, but clearly nonsensical - FPFromSigned(RoundingMode), // float rounding - FPFromUnsigned(RoundingMode), // float rounding + FPFromSigned { + rounding: RoundingMode, + saturate: bool, + }, // float rounding + FPFromUnsigned { + rounding: RoundingMode, + saturate: bool, + }, // float rounding } impl CvtDetails { @@ -1511,9 +1520,6 @@ impl CvtDetails { dst: ScalarType, src: ScalarType, ) -> Self { - if saturate && dst.kind() == ScalarKind::Float { - errors.push(PtxError::SyntaxError); - } // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results. let flush_to_zero = match (dst, src) { (ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz), @@ -1524,55 +1530,81 @@ impl CvtDetails { None } }; - let rounding = rnd.map(Into::into); + let rounding = rnd.map(RawRoundingMode::normalize); let mut unwrap_rounding = || match rounding { - Some(rnd) => rnd, + Some((rnd, is_integer)) => (rnd, is_integer), None => { errors.push(PtxError::SyntaxError); - RoundingMode::NearestEven + (RoundingMode::NearestEven, false) } }; let mode = match (dst.kind(), src.kind()) { (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) { - Ordering::Less => CvtMode::FPTruncate { - rounding: unwrap_rounding(), - flush_to_zero, - }, + Ordering::Less => { + let (rounding, is_integer_rounding) = unwrap_rounding(); + CvtMode::FPTruncate { + rounding, + is_integer_rounding, + flush_to_zero, + saturate, + } + } Ordering::Equal => CvtMode::FPRound { - integer_rounding: rounding.unwrap_or(RoundingMode::NearestEven), + integer_rounding: rounding.map(|(rnd, _)| rnd), flush_to_zero, + saturate, }, Ordering::Greater => { if rounding.is_some() { errors.push(PtxError::SyntaxError); } - CvtMode::FPExtend { flush_to_zero } + CvtMode::FPExtend { + flush_to_zero, + saturate, + } } }, (ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP { - rounding: unwrap_rounding(), + rounding: unwrap_rounding().0, flush_to_zero, }, (ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP { - rounding: unwrap_rounding(), + rounding: unwrap_rounding().0, flush_to_zero, }, - (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()), - (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()), - (ScalarKind::Signed, ScalarKind::Unsigned) if saturate => { - CvtMode::SaturateUnsignedToSigned - } - (ScalarKind::Unsigned, ScalarKind::Signed) if saturate => { - CvtMode::SaturateSignedToUnsigned + (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned { + rounding: unwrap_rounding().0, + saturate, + }, + (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned { + rounding: unwrap_rounding().0, + saturate, + }, + (ScalarKind::Signed, ScalarKind::Unsigned) + | (ScalarKind::Signed, ScalarKind::Signed) + if saturate => + { + CvtMode::IntSaturateToSigned } (ScalarKind::Unsigned, ScalarKind::Signed) + | (ScalarKind::Unsigned, ScalarKind::Unsigned) + if saturate => + { + CvtMode::IntSaturateToUnsigned + } + (ScalarKind::Unsigned, ScalarKind::Unsigned) + | (ScalarKind::Signed, ScalarKind::Signed) + | (ScalarKind::Unsigned, ScalarKind::Signed) | (ScalarKind::Signed, ScalarKind::Unsigned) if dst.size_of() == src.size_of() => { CvtMode::Bitcast } (ScalarKind::Unsigned, ScalarKind::Unsigned) - | (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) { + | (ScalarKind::Signed, ScalarKind::Signed) + | 
(ScalarKind::Unsigned, ScalarKind::Signed) + | (ScalarKind::Signed, ScalarKind::Unsigned) => match dst.size_of().cmp(&src.size_of()) + { Ordering::Less => CvtMode::Truncate, Ordering::Equal => CvtMode::Bitcast, Ordering::Greater => { @@ -1583,7 +1615,6 @@ impl CvtDetails { } } }, - (ScalarKind::Unsigned, ScalarKind::Signed) => CvtMode::SaturateSignedToUnsigned, (_, _) => { errors.push(PtxError::SyntaxError); CvtMode::Bitcast diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index e4c07cc..13764e7 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -64,11 +64,21 @@ impl From for ast::LdStQualifier { impl From for ast::RoundingMode { fn from(value: RawRoundingMode) -> Self { - match value { - RawRoundingMode::Rn | RawRoundingMode::Rni => ast::RoundingMode::NearestEven, - RawRoundingMode::Rz | RawRoundingMode::Rzi => ast::RoundingMode::Zero, - RawRoundingMode::Rm | RawRoundingMode::Rmi => ast::RoundingMode::NegativeInf, - RawRoundingMode::Rp | RawRoundingMode::Rpi => ast::RoundingMode::PositiveInf, + value.normalize().0 + } +} + +impl RawRoundingMode { + fn normalize(self) -> (ast::RoundingMode, bool) { + match self { + RawRoundingMode::Rn => (ast::RoundingMode::NearestEven, false), + RawRoundingMode::Rz => (ast::RoundingMode::Zero, false), + RawRoundingMode::Rm => (ast::RoundingMode::NegativeInf, false), + RawRoundingMode::Rp => (ast::RoundingMode::PositiveInf, false), + RawRoundingMode::Rni => (ast::RoundingMode::NearestEven, true), + RawRoundingMode::Rzi => (ast::RoundingMode::Zero, true), + RawRoundingMode::Rmi => (ast::RoundingMode::NegativeInf, true), + RawRoundingMode::Rpi => (ast::RoundingMode::PositiveInf, true), } } }
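
Note (not part of the patch): the new `emit_add_sat` lowers `add.sat.s32` to `llvm.sadd.sat`/`llvm.uadd.sat`, and the added `add_s32_sat` test expects inputs `[i32::MIN, -1]` to produce `[i32::MIN, i32::MAX]`. A minimal Rust sketch of the same semantics, using the standard library rather than the compiler passes:

```rust
// Sketch: PTX add.sat.s32 clamps on overflow, while plain add.s32 wraps.
// These are the values exercised by the add_s32_sat test above.
fn main() {
    let (a, b) = (i32::MIN, -1);
    let saturated = a.saturating_add(b); // add.sat.s32 -> llvm.sadd.sat.i32
    let wrapped = a.wrapping_add(b);     // add.s32     -> plain LLVM add
    assert_eq!(saturated, i32::MIN);
    assert_eq!(wrapped, i32::MAX);
    println!("sat = {saturated}, wrap = {wrapped}");
}
```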
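Note (illustrative, assumed semantics): `emit_saturate_integer` clamps the source value to the destination type's representable range before resizing, via `llvm.smin`/`llvm.smax` for signed sources and `llvm.umin` for unsigned ones. A hedged sketch of one concrete case, with the type pair chosen only for illustration:

```rust
// Sketch of integer cvt.sat: clamp into the destination range, then narrow.
// Mirrors the i128 min/max bookkeeping in emit_saturate_integer.
fn cvt_sat_u16_s64(src: i64) -> u16 {
    let lo = u16::MIN as i128;
    let hi = u16::MAX as i128;
    (src as i128).clamp(lo, hi) as u16
}

fn main() {
    assert_eq!(cvt_sat_u16_s64(-5), 0);          // negative inputs clamp to 0
    assert_eq!(cvt_sat_u16_s64(70_000), 65_535); // overflow clamps to u16::MAX
    assert_eq!(cvt_sat_u16_s64(123), 123);       // in-range values pass through
}
```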
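Note (illustrative, assuming PTX `.sat` on float results clamps to [0.0, 1.0] with NaN mapped to 0.0): the new `FpSaturate` statement is emitted by the `insert_post_saturation` pass and lowered with an `llvm.maxnum`/`llvm.minnum` pair, which ignore a NaN operand. Rust's `f32::max`/`f32::min` have the same NaN-ignoring behavior, so a sketch of the intended result is:

```rust
// Sketch of the post-saturation step: clamp to [0.0, 1.0];
// max(NaN, 0.0) == 0.0, so NaN inputs become 0.0, like maxnum/minnum.
fn fp_saturate(x: f32) -> f32 {
    x.max(0.0).min(1.0)
}

fn main() {
    assert_eq!(fp_saturate(1.5), 1.0);
    assert_eq!(fp_saturate(-0.25), 0.0);
    assert_eq!(fp_saturate(f32::NAN), 0.0);
    assert_eq!(fp_saturate(0.5), 0.5);
}
```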