Mirror of https://github.com/vosen/ZLUDA.git, synced 2025-07-19 18:26:26 +03:00
Add prmt, membar, fix some of cvt
@@ -183,4 +183,13 @@ void LLVMZludaSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF)
   cast<Instruction>(P)->setFastMathFlags(mapFromLLVMFastMathFlags(FMF));
 }
 
+void LLVMZludaBuildFence(LLVMBuilderRef B, LLVMAtomicOrdering Ordering,
+                         char *scope, const char *Name)
+{
+  auto builder = llvm::unwrap(B);
+  LLVMContext &context = builder->getContext();
+  builder->CreateFence(mapFromLLVMOrdering(Ordering),
+                       context.getOrInsertSyncScopeID(scope));
+}
+
 LLVM_C_EXTERN_C_END
@@ -71,4 +71,11 @@ extern "C" {
     ) -> LLVMValueRef;
 
     pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags);
+
+    pub fn LLVMZludaBuildFence(
+        B: LLVMBuilderRef,
+        ordering: LLVMAtomicOrdering,
+        scope: *const i8,
+        Name: *const i8,
+    ) -> LLVMValueRef;
 }
@@ -385,20 +385,22 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
             | ptx_parser::ScalarType::U16 => unsafe {
                 LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
             },
+            ptx_parser::ScalarType::S32
+            | ptx_parser::ScalarType::B32
+            | ptx_parser::ScalarType::U32 => unsafe {
+                LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) as u64, 0)
+            },
             ptx_parser::ScalarType::F16 => todo!(),
             ptx_parser::ScalarType::BF16 => todo!(),
-            ptx_parser::ScalarType::S32 => todo!(),
             ptx_parser::ScalarType::U64 => todo!(),
             ptx_parser::ScalarType::S64 => todo!(),
             ptx_parser::ScalarType::S16x2 => todo!(),
-            ptx_parser::ScalarType::B32 => todo!(),
             ptx_parser::ScalarType::F32 => todo!(),
             ptx_parser::ScalarType::B64 => todo!(),
             ptx_parser::ScalarType::F64 => todo!(),
             ptx_parser::ScalarType::B128 => todo!(),
             ptx_parser::ScalarType::U16x2 => todo!(),
             ptx_parser::ScalarType::F16x2 => todo!(),
-            ptx_parser::ScalarType::U32 => todo!(),
             ptx_parser::ScalarType::BF16x2 => todo!(),
         })
     }
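For reference, a standalone sketch (not part of the commit, helper name is hypothetical) of what the new S32/B32/U32 arm computes: initializer bytes arrive little-endian and must match the width of the scalar type exactly, and the decoded value is then widened to the u64 payload that LLVMConstInt takes.

use std::array::TryFromSliceError;

// Hypothetical helper mirroring the new 32-bit arm above: decode exactly four
// little-endian initializer bytes into the u64 payload passed to LLVMConstInt.
fn const_u32_payload(bytes: &[u8]) -> Result<u64, TryFromSliceError> {
    // `try_into` fails unless the slice is exactly 4 bytes long, which is how a
    // mismatched initializer width surfaces as an error.
    Ok(u32::from_le_bytes(bytes.try_into()?) as u64)
}

fn main() {
    // 42u32 stored little-endian, as it would appear in a PTX global initializer.
    assert_eq!(const_u32_payload(&[0x2A, 0x00, 0x00, 0x00]).unwrap(), 42);
    assert!(const_u32_payload(&[0x2A, 0x00]).is_err()); // wrong width is rejected
}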
@@ -552,8 +554,8 @@ impl<'a> MethodEmitContext<'a> {
             ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
             ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments),
             ast::Instruction::PrmtSlow { .. } => todo!(),
-            ast::Instruction::Prmt { .. } => todo!(),
-            ast::Instruction::Membar { .. } => todo!(),
+            ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments),
+            ast::Instruction::Membar { data } => self.emit_membar(data),
             ast::Instruction::Trap {} => todo!(),
             // replaced by a function call
             ast::Instruction::Bfe { .. }
@@ -582,88 +584,14 @@ impl<'a> MethodEmitContext<'a> {
     fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> {
         let builder = self.builder;
         match conversion.kind {
-            ConversionKind::Default => {
-                match (&conversion.from_type, &conversion.to_type) {
-                    (ast::Type::Scalar(from_type), ast::Type::Scalar(to_type)) => {
-                        let from_layout = conversion.from_type.layout();
-                        let to_layout = conversion.to_type.layout();
-                        if from_layout.size() == to_layout.size() {
-                            let dst_type = get_type(self.context, &conversion.to_type)?;
-                            if from_type.kind() != ast::ScalarKind::Float
-                                && to_type.kind() != ast::ScalarKind::Float
-                            {
-                                // It is noop, but another instruction expects result of this conversion
-                                self.resolver
-                                    .register(conversion.dst, self.resolver.value(conversion.src)?);
-                            } else {
-                                let src = self.resolver.value(conversion.src)?;
-                                self.resolver.with_result(conversion.dst, |dst| unsafe {
-                                    LLVMBuildBitCast(builder, src, dst_type, dst)
-                                });
-                            }
-                            Ok(())
-                        } else {
-                            let src = self.resolver.value(conversion.src)?;
-                            // This block is safe because it's illegal to implictly convert between floating point values
-                            let same_width_bit_type = unsafe {
-                                LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32)
-                            };
-                            let same_width_bit_value = unsafe {
-                                LLVMBuildBitCast(
-                                    builder,
-                                    src,
-                                    same_width_bit_type,
-                                    LLVM_UNNAMED.as_ptr(),
-                                )
-                            };
-                            let wide_bit_type = unsafe {
-                                LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32)
-                            };
-                            if to_type.kind() == ast::ScalarKind::Unsigned
-                                || to_type.kind() == ast::ScalarKind::Bit
-                            {
-                                let llvm_fn = if to_type.size_of() >= from_type.size_of() {
-                                    LLVMBuildZExtOrBitCast
-                                } else {
-                                    LLVMBuildTrunc
-                                };
-                                self.resolver.with_result(conversion.dst, |dst| unsafe {
-                                    llvm_fn(builder, same_width_bit_value, wide_bit_type, dst)
-                                });
-                                Ok(())
-                            } else {
-                                let _conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
-                                    && to_type.kind() == ast::ScalarKind::Signed
-                                {
-                                    if to_type.size_of() >= from_type.size_of() {
-                                        LLVMBuildSExtOrBitCast
-                                    } else {
-                                        LLVMBuildTrunc
-                                    }
-                                } else {
-                                    if to_type.size_of() >= from_type.size_of() {
-                                        LLVMBuildZExtOrBitCast
-                                    } else {
-                                        LLVMBuildTrunc
-                                    }
-                                };
-                                todo!()
-                            }
-                        }
-                    }
-                    (ast::Type::Vector(..), ast::Type::Scalar(..))
-                    | (ast::Type::Scalar(..), ast::Type::Array(..))
-                    | (ast::Type::Array(..), ast::Type::Scalar(..)) => {
-                        let src = self.resolver.value(conversion.src)?;
-                        let dst_type = get_type(self.context, &conversion.to_type)?;
-                        self.resolver.with_result(conversion.dst, |dst| unsafe {
-                            LLVMBuildBitCast(builder, src, dst_type, dst)
-                        });
-                        Ok(())
-                    }
-                    _ => todo!(),
-                }
-            }
+            ConversionKind::Default => self.emit_conversion_default(
+                self.resolver.value(conversion.src)?,
+                conversion.dst,
+                &conversion.from_type,
+                conversion.from_space,
+                &conversion.to_type,
+                conversion.to_space,
+            ),
             ConversionKind::SignExtend => {
                 let src = self.resolver.value(conversion.src)?;
                 let type_ = get_type(self.context, &conversion.to_type)?;
@@ -699,6 +627,115 @@ impl<'a> MethodEmitContext<'a> {
         }
     }
 
+    fn emit_conversion_default(
+        &mut self,
+        src: LLVMValueRef,
+        dst: SpirvWord,
+        from_type: &ast::Type,
+        from_space: ast::StateSpace,
+        to_type: &ast::Type,
+        to_space: ast::StateSpace,
+    ) -> Result<(), TranslateError> {
+        match (from_type, to_type) {
+            (ast::Type::Scalar(from_type), ast::Type::Scalar(to_type_scalar)) => {
+                let from_layout = from_type.layout();
+                let to_layout = to_type.layout();
+                if from_layout.size() == to_layout.size() {
+                    let dst_type = get_type(self.context, &to_type)?;
+                    if from_type.kind() != ast::ScalarKind::Float
+                        && to_type_scalar.kind() != ast::ScalarKind::Float
+                    {
+                        // It is noop, but another instruction expects result of this conversion
+                        self.resolver.register(dst, src);
+                    } else {
+                        self.resolver.with_result(dst, |dst| unsafe {
+                            LLVMBuildBitCast(self.builder, src, dst_type, dst)
+                        });
+                    }
+                    Ok(())
+                } else {
+                    // This block is safe because it's illegal to implictly convert between floating point values
+                    let same_width_bit_type = unsafe {
+                        LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32)
+                    };
+                    let same_width_bit_value = unsafe {
+                        LLVMBuildBitCast(
+                            self.builder,
+                            src,
+                            same_width_bit_type,
+                            LLVM_UNNAMED.as_ptr(),
+                        )
+                    };
+                    let wide_bit_type = match to_type_scalar.layout().size() {
+                        1 => ast::ScalarType::B8,
+                        2 => ast::ScalarType::B16,
+                        4 => ast::ScalarType::B32,
+                        8 => ast::ScalarType::B64,
+                        _ => return Err(error_unreachable()),
+                    };
+                    let wide_bit_type_llvm = unsafe {
+                        LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32)
+                    };
+                    if to_type_scalar.kind() == ast::ScalarKind::Unsigned
+                        || to_type_scalar.kind() == ast::ScalarKind::Bit
+                    {
+                        let llvm_fn = if to_type_scalar.size_of() >= from_type.size_of() {
+                            LLVMBuildZExtOrBitCast
+                        } else {
+                            LLVMBuildTrunc
+                        };
+                        self.resolver.with_result(dst, |dst| unsafe {
+                            llvm_fn(self.builder, same_width_bit_value, wide_bit_type_llvm, dst)
+                        });
+                        Ok(())
+                    } else {
+                        let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
+                            && to_type_scalar.kind() == ast::ScalarKind::Signed
+                        {
+                            if to_type_scalar.size_of() >= from_type.size_of() {
+                                LLVMBuildSExtOrBitCast
+                            } else {
+                                LLVMBuildTrunc
+                            }
+                        } else {
+                            if to_type_scalar.size_of() >= from_type.size_of() {
+                                LLVMBuildZExtOrBitCast
+                            } else {
+                                LLVMBuildTrunc
+                            }
+                        };
+                        let wide_bit_value = unsafe {
+                            conversion_fn(
+                                self.builder,
+                                same_width_bit_value,
+                                wide_bit_type_llvm,
+                                LLVM_UNNAMED.as_ptr(),
+                            )
+                        };
+                        self.emit_conversion_default(
+                            wide_bit_value,
+                            dst,
+                            &wide_bit_type.into(),
+                            from_space,
+                            to_type,
+                            to_space,
+                        )
+                    }
+                }
+            }
+            (ast::Type::Vector(..), ast::Type::Scalar(..))
+            | (ast::Type::Scalar(..), ast::Type::Array(..))
+            | (ast::Type::Array(..), ast::Type::Scalar(..)) => {
+                let dst_type = get_type(self.context, to_type)?;
+                self.resolver.with_result(dst, |dst| unsafe {
+                    LLVMBuildBitCast(self.builder, src, dst_type, dst)
+                });
+                Ok(())
+            }
+            _ => todo!(),
+        }
+    }
+
     fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> {
         let type_ = get_scalar_type(self.context, constant.typ);
         let value = match constant.value {
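For readers following emit_conversion_default, the standalone sketch below (illustrative only, not ZLUDA code; the enum and function names are stand-ins) captures the width and signedness decision it makes once the source and destination layouts differ: a destination of unsigned or bit kind is zero-extended, signed-to-signed goes through sign-extension, any other mix is zero-extended, and narrowing always truncates.

// Illustrative stand-ins for ast::ScalarKind and the LLVMBuild* function
// pointers chosen in emit_conversion_default.
#[derive(Clone, Copy, PartialEq)]
enum Kind { Bit, Unsigned, Signed }

#[derive(Debug, PartialEq)]
enum Op { ZExt, SExt, Trunc }

// Takes (kind, size in bytes) of source and destination; returns the extension
// or truncation op the emitter would pick on the integer path.
fn integer_conversion_op(from: (Kind, usize), to: (Kind, usize)) -> Op {
    let ((from_kind, from_size), (to_kind, to_size)) = (from, to);
    if to_kind == Kind::Unsigned || to_kind == Kind::Bit {
        if to_size >= from_size { Op::ZExt } else { Op::Trunc }
    } else if from_kind == Kind::Signed && to_kind == Kind::Signed {
        if to_size >= from_size { Op::SExt } else { Op::Trunc }
    } else {
        if to_size >= from_size { Op::ZExt } else { Op::Trunc }
    }
}

fn main() {
    assert_eq!(integer_conversion_op((Kind::Signed, 2), (Kind::Signed, 4)), Op::SExt); // e.g. cvt.s32.s16
    assert_eq!(integer_conversion_op((Kind::Signed, 2), (Kind::Unsigned, 4)), Op::ZExt); // e.g. cvt.u32.s16
    assert_eq!(integer_conversion_op((Kind::Unsigned, 8), (Kind::Bit, 4)), Op::Trunc); // narrowing
}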
@@ -1879,6 +1916,60 @@ impl<'a> MethodEmitContext<'a> {
         Ok(())
     }
 
+    fn emit_membar(&self, data: ptx_parser::MemScope) -> Result<(), TranslateError> {
+        unsafe {
+            LLVMZludaBuildFence(
+                self.builder,
+                LLVMAtomicOrdering::LLVMAtomicOrderingSequentiallyConsistent,
+                get_scope_membar(data)?,
+                LLVM_UNNAMED.as_ptr(),
+            )
+        };
+        Ok(())
+    }
+
+    fn emit_prmt(
+        &mut self,
+        control: u16,
+        arguments: ptx_parser::PrmtArgs<SpirvWord>,
+    ) -> Result<(), TranslateError> {
+        let components = [
+            (control >> 0) & 0b1111,
+            (control >> 4) & 0b1111,
+            (control >> 8) & 0b1111,
+            (control >> 12) & 0b1111,
+        ];
+        if components.iter().any(|&c| c > 7) {
+            return Err(TranslateError::Todo);
+        }
+        let u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
+        let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?;
+        let mut components = [
+            unsafe { LLVMConstInt(u32_type, components[0] as _, 0) },
+            unsafe { LLVMConstInt(u32_type, components[1] as _, 0) },
+            unsafe { LLVMConstInt(u32_type, components[2] as _, 0) },
+            unsafe { LLVMConstInt(u32_type, components[3] as _, 0) },
+        ];
+        let components_indices =
+            unsafe { LLVMConstVector(components.as_mut_ptr(), components.len() as u32) };
+        let src1 = self.resolver.value(arguments.src1)?;
+        let src1_vector =
+            unsafe { LLVMBuildBitCast(self.builder, src1, v4u8_type, LLVM_UNNAMED.as_ptr()) };
+        let src2 = self.resolver.value(arguments.src2)?;
+        let src2_vector =
+            unsafe { LLVMBuildBitCast(self.builder, src2, v4u8_type, LLVM_UNNAMED.as_ptr()) };
+        self.resolver.with_result(arguments.dst, |dst| unsafe {
+            LLVMBuildShuffleVector(
+                self.builder,
+                src1_vector,
+                src2_vector,
+                components_indices,
+                dst,
+            )
+        });
+        Ok(())
+    }
+
     /*
     // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
     // Should be available in LLVM 19
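In its default mode, the shufflevector-based emit_prmt above corresponds to the following standalone sketch (illustrative, not ZLUDA code): each 4-bit selector in the control word picks one byte out of the eight-byte pool formed by src1 (bytes 0 to 3) and src2 (bytes 4 to 7); selector values above 7 would request PTX's sign-replication modes, which the commit leaves unimplemented (TranslateError::Todo).

// Pure-Rust model of prmt.b32 in default mode, mirroring the selector decoding
// and byte gathering that the emitted shufflevector performs.
fn prmt_default(src1: u32, src2: u32, control: u16) -> Option<u32> {
    let selectors = [
        (control >> 0) & 0b1111,
        (control >> 4) & 0b1111,
        (control >> 8) & 0b1111,
        (control >> 12) & 0b1111,
    ];
    if selectors.iter().any(|&s| s > 7) {
        return None; // sign-replication modes are not modelled (todo in the commit)
    }
    // Byte pool: src1 supplies bytes 0..=3, src2 supplies bytes 4..=7.
    let mut pool = [0u8; 8];
    pool[..4].copy_from_slice(&src1.to_le_bytes());
    pool[4..].copy_from_slice(&src2.to_le_bytes());
    let mut out = [0u8; 4];
    for (i, &s) in selectors.iter().enumerate() {
        out[i] = pool[s as usize];
    }
    Some(u32::from_le_bytes(out))
}

fn main() {
    // Control 0x3210 selects src1's bytes 0,1,2,3: an identity copy of src1.
    assert_eq!(prmt_default(0xAABBCCDD, 0x11223344, 0x3210), Some(0xAABBCCDD));
    // Control 0x7654 selects bytes 4..=7, i.e. all of src2.
    assert_eq!(prmt_default(0xAABBCCDD, 0x11223344, 0x7654), Some(0x11223344));
}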
@@ -1964,6 +2055,16 @@ fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
     .as_ptr())
 }
 
+fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
+    Ok(match scope {
+        ast::MemScope::Cta => c"workgroup",
+        ast::MemScope::Gpu => c"agent",
+        ast::MemScope::Sys => c"",
+        ast::MemScope::Cluster => todo!(),
+    }
+    .as_ptr())
+}
+
 fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
     match semantics {
         ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,