From 3105674618b790214ab629bd28162bcc27d8827a Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Tue, 15 Oct 2024 18:05:32 +0200 Subject: [PATCH] Add prmt, membar, fix some of cvt --- llvm_zluda/src/lib.cpp | 9 ++ llvm_zluda/src/lib.rs | 7 + ptx/src/pass/emit_llvm.rs | 275 ++++++++++++++++++++++++++------------ 3 files changed, 204 insertions(+), 87 deletions(-) diff --git a/llvm_zluda/src/lib.cpp b/llvm_zluda/src/lib.cpp index 073dba7..886aa0d 100644 --- a/llvm_zluda/src/lib.cpp +++ b/llvm_zluda/src/lib.cpp @@ -183,4 +183,13 @@ void LLVMZludaSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF) cast(P)->setFastMathFlags(mapFromLLVMFastMathFlags(FMF)); } +void LLVMZludaBuildFence(LLVMBuilderRef B, LLVMAtomicOrdering Ordering, + char *scope, const char *Name) +{ + auto builder = llvm::unwrap(B); + LLVMContext &context = builder->getContext(); + builder->CreateFence(mapFromLLVMOrdering(Ordering), + context.getOrInsertSyncScopeID(scope)); +} + LLVM_C_EXTERN_C_END \ No newline at end of file diff --git a/llvm_zluda/src/lib.rs b/llvm_zluda/src/lib.rs index afcfd89..fb5cc47 100644 --- a/llvm_zluda/src/lib.rs +++ b/llvm_zluda/src/lib.rs @@ -71,4 +71,11 @@ extern "C" { ) -> LLVMValueRef; pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags); + + pub fn LLVMZludaBuildFence( + B: LLVMBuilderRef, + ordering: LLVMAtomicOrdering, + scope: *const i8, + Name: *const i8, + ) -> LLVMValueRef; } diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 209840f..54a07aa 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -385,20 +385,22 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { | ptx_parser::ScalarType::U16 => unsafe { LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0) }, + ptx_parser::ScalarType::S32 + | ptx_parser::ScalarType::B32 + | ptx_parser::ScalarType::U32 => unsafe { + LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) as u64, 0) + }, ptx_parser::ScalarType::F16 => todo!(), ptx_parser::ScalarType::BF16 => todo!(), - ptx_parser::ScalarType::S32 => todo!(), ptx_parser::ScalarType::U64 => todo!(), ptx_parser::ScalarType::S64 => todo!(), ptx_parser::ScalarType::S16x2 => todo!(), - ptx_parser::ScalarType::B32 => todo!(), ptx_parser::ScalarType::F32 => todo!(), ptx_parser::ScalarType::B64 => todo!(), ptx_parser::ScalarType::F64 => todo!(), ptx_parser::ScalarType::B128 => todo!(), ptx_parser::ScalarType::U16x2 => todo!(), ptx_parser::ScalarType::F16x2 => todo!(), - ptx_parser::ScalarType::U32 => todo!(), ptx_parser::ScalarType::BF16x2 => todo!(), }) } @@ -552,8 +554,8 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments), ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments), ast::Instruction::PrmtSlow { .. } => todo!(), - ast::Instruction::Prmt { .. } => todo!(), - ast::Instruction::Membar { .. } => todo!(), + ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments), + ast::Instruction::Membar { data } => self.emit_membar(data), ast::Instruction::Trap {} => todo!(), // replaced by a function call ast::Instruction::Bfe { .. } @@ -582,88 +584,14 @@ impl<'a> MethodEmitContext<'a> { fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> { let builder = self.builder; match conversion.kind { - ConversionKind::Default => { - match (&conversion.from_type, &conversion.to_type) { - (ast::Type::Scalar(from_type), ast::Type::Scalar(to_type)) => { - let from_layout = conversion.from_type.layout(); - let to_layout = conversion.to_type.layout(); - if from_layout.size() == to_layout.size() { - let dst_type = get_type(self.context, &conversion.to_type)?; - if from_type.kind() != ast::ScalarKind::Float - && to_type.kind() != ast::ScalarKind::Float - { - // It is noop, but another instruction expects result of this conversion - self.resolver - .register(conversion.dst, self.resolver.value(conversion.src)?); - } else { - let src = self.resolver.value(conversion.src)?; - self.resolver.with_result(conversion.dst, |dst| unsafe { - LLVMBuildBitCast(builder, src, dst_type, dst) - }); - } - Ok(()) - } else { - let src = self.resolver.value(conversion.src)?; - // This block is safe because it's illegal to implictly convert between floating point values - let same_width_bit_type = unsafe { - LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32) - }; - let same_width_bit_value = unsafe { - LLVMBuildBitCast( - builder, - src, - same_width_bit_type, - LLVM_UNNAMED.as_ptr(), - ) - }; - let wide_bit_type = unsafe { - LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32) - }; - if to_type.kind() == ast::ScalarKind::Unsigned - || to_type.kind() == ast::ScalarKind::Bit - { - let llvm_fn = if to_type.size_of() >= from_type.size_of() { - LLVMBuildZExtOrBitCast - } else { - LLVMBuildTrunc - }; - self.resolver.with_result(conversion.dst, |dst| unsafe { - llvm_fn(builder, same_width_bit_value, wide_bit_type, dst) - }); - Ok(()) - } else { - let _conversion_fn = if from_type.kind() == ast::ScalarKind::Signed - && to_type.kind() == ast::ScalarKind::Signed - { - if to_type.size_of() >= from_type.size_of() { - LLVMBuildSExtOrBitCast - } else { - LLVMBuildTrunc - } - } else { - if to_type.size_of() >= from_type.size_of() { - LLVMBuildZExtOrBitCast - } else { - LLVMBuildTrunc - } - }; - todo!() - } - } - } - (ast::Type::Vector(..), ast::Type::Scalar(..)) - | (ast::Type::Scalar(..), ast::Type::Array(..)) - | (ast::Type::Array(..), ast::Type::Scalar(..)) => { - let src = self.resolver.value(conversion.src)?; - let dst_type = get_type(self.context, &conversion.to_type)?; - self.resolver.with_result(conversion.dst, |dst| unsafe { - LLVMBuildBitCast(builder, src, dst_type, dst) - }); - Ok(()) - } - _ => todo!(), - } - } + ConversionKind::Default => self.emit_conversion_default( + self.resolver.value(conversion.src)?, + conversion.dst, + &conversion.from_type, + conversion.from_space, + &conversion.to_type, + conversion.to_space, + ), ConversionKind::SignExtend => { let src = self.resolver.value(conversion.src)?; let type_ = get_type(self.context, &conversion.to_type)?; @@ -699,6 +627,115 @@ impl<'a> MethodEmitContext<'a> { } } + fn emit_conversion_default( + &mut self, + src: LLVMValueRef, + dst: SpirvWord, + from_type: &ast::Type, + from_space: ast::StateSpace, + to_type: &ast::Type, + to_space: ast::StateSpace, + ) -> Result<(), TranslateError> { + match (from_type, to_type) { + (ast::Type::Scalar(from_type), ast::Type::Scalar(to_type_scalar)) => { + let from_layout = from_type.layout(); + let to_layout = to_type.layout(); + if from_layout.size() == to_layout.size() { + let dst_type = get_type(self.context, &to_type)?; + if from_type.kind() != ast::ScalarKind::Float + && to_type_scalar.kind() != ast::ScalarKind::Float + { + // It is noop, but another instruction expects result of this conversion + self.resolver.register(dst, src); + } else { + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildBitCast(self.builder, src, dst_type, dst) + }); + } + Ok(()) + } else { + // This block is safe because it's illegal to implictly convert between floating point values + let same_width_bit_type = unsafe { + LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32) + }; + let same_width_bit_value = unsafe { + LLVMBuildBitCast( + self.builder, + src, + same_width_bit_type, + LLVM_UNNAMED.as_ptr(), + ) + }; + let wide_bit_type = match to_type_scalar.layout().size() { + 1 => ast::ScalarType::B8, + 2 => ast::ScalarType::B16, + 4 => ast::ScalarType::B32, + 8 => ast::ScalarType::B64, + _ => return Err(error_unreachable()), + }; + let wide_bit_type_llvm = unsafe { + LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32) + }; + if to_type_scalar.kind() == ast::ScalarKind::Unsigned + || to_type_scalar.kind() == ast::ScalarKind::Bit + { + let llvm_fn = if to_type_scalar.size_of() >= from_type.size_of() { + LLVMBuildZExtOrBitCast + } else { + LLVMBuildTrunc + }; + self.resolver.with_result(dst, |dst| unsafe { + llvm_fn(self.builder, same_width_bit_value, wide_bit_type_llvm, dst) + }); + Ok(()) + } else { + let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed + && to_type_scalar.kind() == ast::ScalarKind::Signed + { + if to_type_scalar.size_of() >= from_type.size_of() { + LLVMBuildSExtOrBitCast + } else { + LLVMBuildTrunc + } + } else { + if to_type_scalar.size_of() >= from_type.size_of() { + LLVMBuildZExtOrBitCast + } else { + LLVMBuildTrunc + } + }; + let wide_bit_value = unsafe { + conversion_fn( + self.builder, + same_width_bit_value, + wide_bit_type_llvm, + LLVM_UNNAMED.as_ptr(), + ) + }; + self.emit_conversion_default( + wide_bit_value, + dst, + &wide_bit_type.into(), + from_space, + to_type, + to_space, + ) + } + } + } + (ast::Type::Vector(..), ast::Type::Scalar(..)) + | (ast::Type::Scalar(..), ast::Type::Array(..)) + | (ast::Type::Array(..), ast::Type::Scalar(..)) => { + let dst_type = get_type(self.context, to_type)?; + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildBitCast(self.builder, src, dst_type, dst) + }); + Ok(()) + } + _ => todo!(), + } + } + fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> { let type_ = get_scalar_type(self.context, constant.typ); let value = match constant.value { @@ -1879,6 +1916,60 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn emit_membar(&self, data: ptx_parser::MemScope) -> Result<(), TranslateError> { + unsafe { + LLVMZludaBuildFence( + self.builder, + LLVMAtomicOrdering::LLVMAtomicOrderingSequentiallyConsistent, + get_scope_membar(data)?, + LLVM_UNNAMED.as_ptr(), + ) + }; + Ok(()) + } + + fn emit_prmt( + &mut self, + control: u16, + arguments: ptx_parser::PrmtArgs, + ) -> Result<(), TranslateError> { + let components = [ + (control >> 0) & 0b1111, + (control >> 4) & 0b1111, + (control >> 8) & 0b1111, + (control >> 12) & 0b1111, + ]; + if components.iter().any(|&c| c > 7) { + return Err(TranslateError::Todo); + } + let u32_type = get_scalar_type(self.context, ast::ScalarType::U32); + let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?; + let mut components = [ + unsafe { LLVMConstInt(u32_type, components[0] as _, 0) }, + unsafe { LLVMConstInt(u32_type, components[1] as _, 0) }, + unsafe { LLVMConstInt(u32_type, components[2] as _, 0) }, + unsafe { LLVMConstInt(u32_type, components[3] as _, 0) }, + ]; + let components_indices = + unsafe { LLVMConstVector(components.as_mut_ptr(), components.len() as u32) }; + let src1 = self.resolver.value(arguments.src1)?; + let src1_vector = + unsafe { LLVMBuildBitCast(self.builder, src1, v4u8_type, LLVM_UNNAMED.as_ptr()) }; + let src2 = self.resolver.value(arguments.src2)?; + let src2_vector = + unsafe { LLVMBuildBitCast(self.builder, src2, v4u8_type, LLVM_UNNAMED.as_ptr()) }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildShuffleVector( + self.builder, + src1_vector, + src2_vector, + components_indices, + dst, + ) + }); + Ok(()) + } + /* // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` // Should be available in LLVM 19 @@ -1964,6 +2055,16 @@ fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> { .as_ptr()) } +fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> { + Ok(match scope { + ast::MemScope::Cta => c"workgroup", + ast::MemScope::Gpu => c"agent", + ast::MemScope::Sys => c"", + ast::MemScope::Cluster => todo!(), + } + .as_ptr()) +} + fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering { match semantics { ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,