diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 154e4e6..40b11a4 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -2063,45 +2063,50 @@ impl<'a> MethodEmitContext<'a> { *const i8, ) -> LLVMValueRef, ) -> Result<(), TranslateError> { + let llvm_type = get_scalar_type(self.context, type_); let src1 = self.resolver.value(src1)?; let shift_size = self.resolver.value(src2)?; - let integer_bits = type_.layout().size() * 8; - let integer_bits_constant = unsafe { - LLVMConstInt( - get_scalar_type(self.context, ast::ScalarType::U32), - integer_bits as u64, - 0, - ) - }; - let should_clamp = unsafe { - LLVMBuildICmp( - self.builder, - LLVMIntPredicate::LLVMIntUGE, - shift_size, - integer_bits_constant, - LLVM_UNNAMED.as_ptr(), - ) - }; - let llvm_type = get_scalar_type(self.context, type_); - let zero = unsafe { LLVMConstNull(llvm_type) }; - let normalized_shift_size = if type_.layout().size() >= 4 { + let typed_shift_size = if type_.layout().size() >= 4 { unsafe { LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) } } else { unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) } }; - let shifted = unsafe { - llvm_fn( + let integer_bits = type_.layout().size() * 8; + let integer_bits_constant = unsafe { LLVMConstInt(llvm_type, integer_bits as u64, 0) }; + let should_clamp = unsafe { + LLVMBuildICmp( self.builder, - src1, - normalized_shift_size, + LLVMIntPredicate::LLVMIntUGE, + typed_shift_size, + integer_bits_constant, LLVM_UNNAMED.as_ptr(), ) }; - self.resolver.with_result(dst, |dst| unsafe { - LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst) - }); + if type_.kind() == ast::ScalarKind::Signed { + let integer_bits_constant_minus_one = + unsafe { LLVMConstInt(llvm_type, integer_bits as u64 - 1, 0) }; + let clamped_shift_size = unsafe { + LLVMBuildSelect( + self.builder, + should_clamp, + integer_bits_constant_minus_one, + typed_shift_size, + LLVM_UNNAMED.as_ptr(), + ) + }; + self.resolver.with_result(dst, |dst| unsafe { + llvm_fn(self.builder, src1, clamped_shift_size, dst) + }); + } else { + let shifted_value = + unsafe { llvm_fn(self.builder, src1, typed_shift_size, LLVM_UNNAMED.as_ptr()) }; + let zero = unsafe { LLVMConstNull(llvm_type) }; + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildSelect(self.builder, should_clamp, zero, shifted_value, dst) + }); + } Ok(()) }