Remove inkwell

This commit is contained in:
Andrzej Janik
2024-09-12 04:37:31 +02:00
parent fb68c67adb
commit 631417b405
7 changed files with 366 additions and 436 deletions

View File

@ -1,2 +0,0 @@
[patch.crates-io]
inkwell = { git = "https://github.com/vosen/inkwell.git", rev = "46027c2afb7e98976438cdcc41a2949dedb60b2e" }

View File

@ -15,8 +15,3 @@ features = [ "disable-alltargets-init", "no-llvm-linking" ]
[build-dependencies]
cmake = "0.1"
cc = "1.0.69"
[dependencies.inkwell]
version = "0.5"
default-features = false # default features contain all LLVM targets (x86, mips, riscv, etc.)
features = [ "llvm17-0-no-llvm-linking", "no-libffi-linking" ]

View File

@ -1,15 +1,10 @@
pub mod inkwell {
pub use inkwell::*;
}
pub mod llvm {
use llvm_sys::prelude::*;
pub use llvm_sys::*;
extern "C" {
use llvm_sys::prelude::*;
pub use llvm_sys::*;
extern "C" {
pub fn LLVMZludaBuildAlloca(
B: LLVMBuilderRef,
Ty: LLVMTypeRef,
AddrSpace: u32,
Name: *const i8,
) -> LLVMValueRef;
}
}

View File

@ -2,7 +2,7 @@
name = "ptx"
version = "0.0.0"
authors = ["Andrzej Janik <vosen@vosen.pl>"]
edition = "2018"
edition = "2021"
[lib]

View File

@ -1,74 +1,207 @@
// We use Raw LLVM-C bindings here because using inkwell is just not worth it.
// Specifically the issue is with builder functions. We maintain the mapping
// between ZLUDA identifiers and LLVM values. When using inkwell, LLVM values
// are kept as instances `AnyValueEnum`. Now look at the signature of
// `Builder::build_int_add(...)`:
// pub fn build_int_add<T: IntMathValue<'ctx>>(&self, lhs: T, rhs: T, name: &str, ) -> Result<T, BuilderError>
// At this point both lhs and rhs are `AnyValueEnum`. To call
// `build_int_add(...)` we would have to do something like this:
// if let (Ok(lhs), Ok(rhs)) = (lhs.as_int(), rhs.as_int()) {
// builder.build_int_add(lhs, rhs, dst)?;
// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_pointer(), rhs.as_pointer()) {
// builder.build_int_add(lhs, rhs, dst)?;
// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_vector(), rhs.as_vector()) {
// builder.build_int_add(lhs, rhs, dst)?;
// } else {
// return Err(error_unrachable());
// }
// while with plain LLVM-C it's just:
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
use std::convert::{TryFrom, TryInto};
use std::ffi::CStr;
use std::ops::Deref;
use std::ptr;
use super::*;
use llvm_zluda::inkwell::builder::{Builder, BuilderError};
use llvm_zluda::inkwell::context::{AsContextRef, Context};
use llvm_zluda::inkwell::memory_buffer::MemoryBuffer;
use llvm_zluda::inkwell::types::{
ArrayType, AsTypeRef, BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType,
IntType, PointerType, VectorType, VoidType,
};
use llvm_zluda::inkwell::values::{
AnyValue, AnyValueEnum, ArrayValue, BasicValueEnum, FloatMathValue, FloatValue, FunctionValue,
InstructionValue, IntMathValue, IntValue, PhiValue, PointerValue, StructValue, VectorValue,
};
use llvm_zluda::inkwell::{self, module, AddressSpace};
use llvm_zluda::llvm::core::{
LLVMArrayType2, LLVMBFloatType, LLVMBFloatTypeInContext, LLVMVectorType,
};
use llvm_zluda::llvm::prelude::*;
use llvm_zluda::llvm::{LLVMCallConv, LLVMZludaBuildAlloca};
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
use llvm_zluda::core::*;
use llvm_zluda::prelude::*;
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
const LLVM_UNNAMED: &str = "\0";
const LLVM_UNNAMED: &CStr = c"";
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
const GENERIC_ADDRESS_SPACE: u16 = 0;
const GLOBAL_ADDRESS_SPACE: u16 = 1;
const SHARED_ADDRESS_SPACE: u16 = 3;
const CONSTANT_ADDRESS_SPACE: u16 = 4;
const PRIVATE_ADDRESS_SPACE: u16 = 5;
const GENERIC_ADDRESS_SPACE: u32 = 0;
const GLOBAL_ADDRESS_SPACE: u32 = 1;
const SHARED_ADDRESS_SPACE: u32 = 3;
const CONSTANT_ADDRESS_SPACE: u32 = 4;
const PRIVATE_ADDRESS_SPACE: u32 = 5;
struct Context(LLVMContextRef);
impl Context {
fn new() -> Self {
Self(unsafe { LLVMContextCreate() })
}
fn get(&self) -> LLVMContextRef {
self.0
}
}
impl Drop for Context {
fn drop(&mut self) {
unsafe {
LLVMContextDispose(self.0);
}
}
}
struct Module(LLVMModuleRef);
impl Module {
fn new(ctx: &Context, name: &CStr) -> Self {
Self(unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) })
}
fn get(&self) -> LLVMModuleRef {
self.0
}
fn verify(&self) -> Result<(), Message> {
let mut err = ptr::null_mut();
let error = unsafe {
LLVMVerifyModule(
self.get(),
LLVMVerifierFailureAction::LLVMReturnStatusAction,
&mut err,
)
};
if error == 1 && err != ptr::null_mut() {
Err(Message(unsafe { CStr::from_ptr(err) }))
} else {
Ok(())
}
}
fn write_bitcode_to_memory(&self) -> MemoryBuffer {
let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) };
MemoryBuffer(memory_buffer)
}
fn write_to_stderr(&self) {
unsafe { LLVMDumpModule(self.get()) };
}
}
impl Drop for Module {
fn drop(&mut self) {
unsafe {
LLVMDisposeModule(self.0);
}
}
}
struct Builder(LLVMBuilderRef);
impl Builder {
fn new(ctx: &Context) -> Self {
Self::new_raw(ctx.get())
}
fn new_raw(ctx: LLVMContextRef) -> Self {
Self(unsafe { LLVMCreateBuilderInContext(ctx) })
}
fn get(&self) -> LLVMBuilderRef {
self.0
}
}
impl Drop for Builder {
fn drop(&mut self) {
unsafe {
LLVMDisposeBuilder(self.0);
}
}
}
struct Message(&'static CStr);
impl Drop for Message {
fn drop(&mut self) {
unsafe {
LLVMDisposeMessage(self.0.as_ptr().cast_mut());
}
}
}
impl std::fmt::Debug for Message {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Debug::fmt(&self.0, f)
}
}
pub struct MemoryBuffer(LLVMMemoryBufferRef);
impl Drop for MemoryBuffer {
fn drop(&mut self) {
unsafe {
LLVMDisposeMemoryBuffer(self.0);
}
}
}
impl Deref for MemoryBuffer {
type Target = [u8];
fn deref(&self) -> &Self::Target {
let data = unsafe { LLVMGetBufferStart(self.0) };
let len = unsafe { LLVMGetBufferSize(self.0) };
unsafe { std::slice::from_raw_parts(data.cast(), len) }
}
}
pub(super) fn run<'input>(
id_defs: &GlobalStringIdResolver<'input>,
call_map: MethodsCallMap<'input>,
directives: Vec<Directive<'input>>,
) -> Result<MemoryBuffer, TranslateError> {
let context = inkwell::context::Context::create();
let module = context.create_module(LLVM_UNNAMED);
let builder = context.create_builder();
let mut emit_ctx = ModuleEmitContext::new(&context, module, builder, id_defs);
let context = Context::new();
let module = Module::new(&context, LLVM_UNNAMED);
let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs);
for directive in directives {
match directive {
Directive::Variable(..) => todo!(),
Directive::Method(method) => emit_ctx.emit_method(method)?,
}
}
if let Err(err) = emit_ctx.module.verify() {
emit_ctx.module.print_to_stderr();
panic!("{}", err);
module.write_to_stderr();
if let Err(err) = module.verify() {
panic!("{:?}", err);
}
Ok(emit_ctx.module.write_bitcode_to_memory())
Ok(module.write_bitcode_to_memory())
}
struct ModuleEmitContext<'ctx, 'input> {
context: &'ctx Context,
module: module::Module<'ctx>,
builder: Builder<'ctx>,
id_defs: &'ctx GlobalStringIdResolver<'input>,
resolver: ResolveIdent<'ctx>,
struct ModuleEmitContext<'a, 'input> {
context: LLVMContextRef,
module: LLVMModuleRef,
builder: Builder,
id_defs: &'a GlobalStringIdResolver<'input>,
resolver: ResolveIdent,
}
impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> {
impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn new(
context: &'ctx Context,
module: module::Module<'ctx>,
builder: Builder<'ctx>,
id_defs: &'ctx GlobalStringIdResolver<'input>,
context: &Context,
module: &Module,
id_defs: &'a GlobalStringIdResolver<'input>,
) -> Self {
ModuleEmitContext {
context: &context,
module,
builder,
context: context.get(),
module: module.get(),
builder: Builder::new(context),
id_defs,
resolver: ResolveIdent::new(&id_defs),
}
@ -84,85 +217,86 @@ impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> {
fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> {
let func_decl = method.func_decl.borrow();
let fn_ = self.module.add_function(
method
let name = method
.import_as
.as_deref()
.unwrap_or_else(|| match func_decl.name {
ast::MethodName::Kernel(name) => name,
ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
}),
self.function_type(
});
let name = CString::new(name).map_err(|_| error_unreachable())?;
let fn_type = self.function_type(
func_decl.return_arguments.iter().map(|v| &v.v_type),
func_decl.input_arguments.iter().map(|v| &v.v_type),
),
None,
);
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
for (i, param) in func_decl.input_arguments.iter().enumerate() {
let value = fn_
.get_nth_param(i as u32)
.ok_or_else(|| error_unreachable())?;
value.set_name(self.resolver.get_or_add(param.name));
let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name);
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value);
}
fn_.set_call_conventions(if func_decl.name.is_kernel() {
let call_conv = if func_decl.name.is_kernel() {
Self::kernel_call_convention()
} else {
Self::func_call_convention()
});
};
unsafe { LLVMSetFunctionCallConv(fn_, call_conv) };
if let Some(statements) = method.body {
let variables_bb = self.context.append_basic_block(fn_, LLVM_UNNAMED);
let variables_builder = self.context.create_builder();
variables_builder.position_at_end(variables_bb);
let real_bb = self.context.append_basic_block(fn_, LLVM_UNNAMED);
self.builder.position_at_end(real_bb);
let variables_bb =
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
let variables_builder = Builder::new_raw(self.context);
unsafe { LLVMPositionBuilderAtEnd(variables_builder.get(), variables_bb) };
let real_bb =
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
for statement in statements {
method_emitter.emit_statement(statement)?;
}
method_emitter.variables_builder.build_unconditional_branch(real_bb);
unsafe { LLVMBuildBr(method_emitter.variables_builder.get(), real_bb) };
}
Ok(())
}
fn function_type<'a>(
fn function_type(
&self,
return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
) -> FunctionType<'ctx> {
) -> LLVMTypeRef {
if return_args.len() == 0 {
let input_args = input_args
let mut input_args = input_args
.map(|type_| match type_ {
ast::Type::Scalar(scalar) => match scalar {
ast::ScalarType::Pred => {
BasicMetadataTypeEnum::from(self.context.bool_type())
unsafe { LLVMInt1TypeInContext(self.context) }
}
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
BasicMetadataTypeEnum::from(self.context.i8_type())
unsafe { LLVMInt8TypeInContext(self.context) }
}
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
BasicMetadataTypeEnum::from(self.context.i16_type())
unsafe { LLVMInt16TypeInContext(self.context) }
}
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
BasicMetadataTypeEnum::from(self.context.i32_type())
unsafe { LLVMInt32TypeInContext(self.context) }
}
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
BasicMetadataTypeEnum::from(self.context.i64_type())
unsafe { LLVMInt64TypeInContext(self.context) }
}
ast::ScalarType::B128 => {
BasicMetadataTypeEnum::from(self.context.i128_type())
unsafe { LLVMInt128TypeInContext(self.context) }
}
ast::ScalarType::F16 => {
BasicMetadataTypeEnum::from(self.context.f16_type())
unsafe { LLVMHalfTypeInContext(self.context) }
}
ast::ScalarType::F32 => {
BasicMetadataTypeEnum::from(self.context.f32_type())
unsafe { LLVMFloatTypeInContext(self.context) }
}
ast::ScalarType::F64 => {
BasicMetadataTypeEnum::from(self.context.f64_type())
unsafe { LLVMDoubleTypeInContext(self.context) }
}
ast::ScalarType::BF16 => {
BasicMetadataTypeEnum::from(unsafe { FloatType::new(LLVMBFloatType()) })
unsafe { LLVMBFloatTypeInContext(self.context) }
}
ast::ScalarType::U16x2 => todo!(),
ast::ScalarType::S16x2 => todo!(),
@ -174,41 +308,39 @@ impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> {
ast::Type::Pointer(_, _) => todo!(),
})
.collect::<Vec<_>>();
return self.context.void_type().fn_type(&*input_args, false);
return unsafe {
LLVMFunctionType(
LLVMVoidTypeInContext(self.context),
input_args.as_mut_ptr(),
input_args.len() as u32,
0,
)
};
}
todo!()
}
fn get_type(&self, type_: &ast::Type) -> FunctionType<'ctx> {
match type_ {
ast::Type::Scalar(_) => todo!(),
ast::Type::Vector(_, _) => todo!(),
ast::Type::Array(_, _, _) => todo!(),
ast::Type::Pointer(_, _) => todo!(),
}
}
}
struct MethodEmitContext<'a, 'ctx, 'input> {
context: &'ctx Context,
module: &'a module::Module<'ctx>,
method: FunctionValue<'ctx>,
builder: &'a Builder<'ctx>,
struct MethodEmitContext<'a, 'input> {
context: LLVMContextRef,
module: LLVMModuleRef,
method: LLVMValueRef,
builder: LLVMBuilderRef,
id_defs: &'a GlobalStringIdResolver<'input>,
variables_builder: Builder<'ctx>,
resolver: &'a mut ResolveIdent<'ctx>,
variables_builder: Builder,
resolver: &'a mut ResolveIdent,
}
impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
fn new(
parent: &'a mut ModuleEmitContext<'ctx, 'input>,
method: FunctionValue<'ctx>,
variables_builder: Builder<'ctx>,
) -> MethodEmitContext<'a, 'ctx, 'input> {
impl<'a, 'input> MethodEmitContext<'a, 'input> {
fn new<'x>(
parent: &'a mut ModuleEmitContext<'x, 'input>,
method: LLVMValueRef,
variables_builder: Builder,
) -> MethodEmitContext<'a, 'input> {
MethodEmitContext {
context: &parent.context,
module: &parent.module,
builder: &parent.builder,
context: parent.context,
module: parent.module,
builder: parent.builder.get(),
id_defs: parent.id_defs,
variables_builder,
resolver: &mut parent.resolver,
@ -238,19 +370,16 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
fn emit_variable(&mut self, var: ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
let alloca = unsafe {
PointerValue::new(LLVMZludaBuildAlloca(
self.variables_builder.as_mut_ptr(),
get_type::<BasicTypeEnum>(&self.context, &var.v_type)?.as_type_ref(),
get_state_space(var.state_space)? as u32,
LLVMZludaBuildAlloca(
self.variables_builder.get(),
get_type(self.context, &var.v_type)?,
get_state_space(var.state_space)?,
self.resolver.get_or_add_raw(var.name),
))
)
};
self.resolver.register(var.name, alloca);
if let Some(align) = var.align {
let alloca = alloca.as_instruction().ok_or_else(|| error_unreachable())?;
alloca
.set_alignment(align)
.map_err(|_| error_unreachable())?;
unsafe { LLVMSetAlignment(alloca, align) };
}
if !var.array_init.is_empty() {
todo!()
@ -259,27 +388,24 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
}
fn emit_label(&mut self, label: SpirvWord) {
let block = self
.context
.append_basic_block(self.method, self.resolver.get_or_add(label));
if self
.builder
.get_insert_block()
.unwrap()
.get_terminator()
.is_none()
{
self.builder.build_unconditional_branch(block);
let block = unsafe {
LLVMAppendBasicBlockInContext(
self.context,
self.method,
self.resolver.get_or_add_raw(label),
)
};
let last_block = unsafe { LLVMGetInsertBlock(self.builder) };
if unsafe { LLVMGetBasicBlockTerminator(last_block) } == ptr::null_mut() {
unsafe { LLVMBuildBr(self.builder, block) };
}
self.builder.position_at_end(block);
unsafe { LLVMPositionBuilderAtEnd(self.builder, block) };
}
fn emit_store_var(&mut self, store: StoreVarDetails) -> Result<(), TranslateError> {
let src1 = self.resolver.value(store.arg.src1)?;
let src2 = self.resolver.value(store.arg.src2)?;
self.builder
.build_store(src1.as_pointer()?, src2.as_basic()?)
.map_err(|_| error_unreachable())?;
let ptr = self.resolver.value(store.arg.src1)?;
let value = self.resolver.value(store.arg.src2)?;
unsafe { LLVMBuildStore(self.builder, value, ptr) };
Ok(())
}
@ -303,7 +429,7 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
ast::Instruction::Cvt { data, arguments } => todo!(),
ast::Instruction::Shr { data, arguments } => todo!(),
ast::Instruction::Shl { data, arguments } => todo!(),
ast::Instruction::Ret { data } => self.emit_ret(data),
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
ast::Instruction::Cvta { data, arguments } => todo!(),
ast::Instruction::Abs { data, arguments } => todo!(),
ast::Instruction::Mad { data, arguments } => todo!(),
@ -351,10 +477,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
todo!()
}
let builder = self.builder;
let type_ = get_type::<BasicTypeEnum>(&self.context, &data.typ)?;
let ptr = self.resolver.value(arguments.src)?.as_pointer()?;
self.resolver
.with_result(arguments.dst, |dst| builder.build_load(type_, ptr, dst))
let type_ = get_type(self.context, &data.typ)?;
let ptr = self.resolver.value(arguments.src)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildLoad2(builder, type_, ptr, dst)
});
Ok(())
}
fn emit_load_variable(&mut self, var: LoadVarDetails) -> Result<(), TranslateError> {
@ -362,10 +490,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
todo!()
}
let builder = self.builder;
let type_ = get_type::<BasicTypeEnum>(&self.context, &var.typ)?;
let ptr = self.resolver.value(var.arg.src)?.as_pointer()?;
self.resolver
.with_result(var.arg.dst, |dst| builder.build_load(type_, ptr, dst))
let type_ = get_type(self.context, &var.typ)?;
let ptr = self.resolver.value(var.arg.src)?;
self.resolver.with_result(var.arg.dst, |dst| unsafe {
LLVMBuildLoad2(builder, type_, ptr, dst)
});
Ok(())
}
fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> {
@ -374,11 +504,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
ConversionKind::Default => todo!(),
ConversionKind::SignExtend => todo!(),
ConversionKind::BitToPtr => {
let src = self.resolver.value(conversion.src)?.as_int()?;
let src = self.resolver.value(conversion.src)?;
let type_ = get_pointer_type(self.context, conversion.to_space)?;
self.resolver.with_result(conversion.dst, |dst| {
builder.build_int_to_ptr(src, type_, dst)
})
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildIntToPtr(builder, src, type_, dst)
});
Ok(())
}
ConversionKind::PtrToPtr => todo!(),
ConversionKind::AddressOf => todo!(),
@ -386,21 +517,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
}
fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> {
let type_ = get_scalar_type::<BasicTypeEnum>(&self.context, constant.typ);
let value: AnyValueEnum = match (type_, constant.value) {
(BasicTypeEnum::IntType(type_), ast::ImmediateValue::U64(x)) => {
type_.const_int(x, false).into()
}
(BasicTypeEnum::IntType(type_), ast::ImmediateValue::S64(x)) => {
type_.const_int(x as u64, false).into()
}
(BasicTypeEnum::FloatType(type_), ast::ImmediateValue::F32(x)) => {
type_.const_float(x as f64).into()
}
(BasicTypeEnum::FloatType(type_), ast::ImmediateValue::F64(x)) => {
type_.const_float(x).into()
}
_ => return Err(error_unreachable()),
let type_ = get_scalar_type(self.context, constant.typ);
let value = match constant.value {
ast::ImmediateValue::U64(x) => unsafe { LLVMConstInt(type_, x, 0) },
ast::ImmediateValue::S64(x) => unsafe { LLVMConstInt(type_, x as u64, 0) },
ast::ImmediateValue::F32(x) => unsafe { LLVMConstReal(type_, x as f64) },
ast::ImmediateValue::F64(x) => unsafe { LLVMConstReal(type_, x) },
};
self.resolver.register(constant.dst, value);
Ok(())
@ -412,14 +534,16 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
arguments: ast::AddArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let builder = self.builder;
let src1 = self.resolver.value(arguments.src1)?.as_int()?;
let src2 = self.resolver.value(arguments.src2)?.as_int()?;
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let fn_ = match data {
ast::ArithDetails::Integer(integer) => Builder::build_int_add,
ast::ArithDetails::Float(float) => todo!(),
ast::ArithDetails::Integer(integer) => LLVMBuildAdd,
ast::ArithDetails::Float(float) => LLVMBuildFAdd,
};
self.resolver
.with_result(arguments.dst, |dst| fn_(builder, src1, src2, dst))
self.resolver.with_result(arguments.dst, |dst| unsafe {
fn_(builder, src1, src2, dst)
});
Ok(())
}
fn emit_st(
@ -427,129 +551,80 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
data: ptx_parser::StData,
arguments: ptx_parser::StArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let builder = self.builder;
let src1 = self.resolver.value(arguments.src1)?.as_pointer()?;
let src2 = self.resolver.value(arguments.src2)?.as_basic()?;
let ptr = self.resolver.value(arguments.src1)?;
let value = self.resolver.value(arguments.src2)?;
if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
self.builder
.build_store(src1, src2)
.map_err(|_| error_unreachable())?;
unsafe { LLVMBuildStore(self.builder, value, ptr) };
Ok(())
}
fn emit_ret(&self, _data: ptx_parser::RetData) -> Result<(), TranslateError> {
self.builder
.build_return(None)
.map_err(|_| error_unreachable())?;
Ok(())
fn emit_ret(&self, _data: ptx_parser::RetData) {
unsafe { LLVMBuildRetVoid(self.builder) };
}
}
fn get_pointer_type<'ctx>(
context: &'ctx Context,
context: LLVMContextRef,
to_space: ast::StateSpace,
) -> Result<PointerType<'ctx>, TranslateError> {
Ok(context.ptr_type(AddressSpace::from(get_state_space(to_space)?)))
) -> Result<LLVMTypeRef, TranslateError> {
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) })
}
fn get_type<
'ctx,
T: From<IntType<'ctx>>
+ From<FloatType<'ctx>>
+ From<VectorType<'ctx>>
+ From<PointerType<'ctx>>
+ From<ArrayType<'ctx>>,
>(
context: &'ctx Context,
type_: &ast::Type,
) -> Result<T, TranslateError> {
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
Ok(match type_ {
ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),
ast::Type::Vector(size, scalar) => {
let base_type = get_scalar_type::<BasicTypeEnum>(context, *scalar);
let base_type = match base_type {
BasicTypeEnum::FloatType(t) => t.as_type_ref(),
BasicTypeEnum::IntType(t) => t.as_type_ref(),
_ => return Err(error_unreachable()),
};
T::from(unsafe { VectorType::new(LLVMVectorType(base_type, *size as u32)) })
let base_type = get_scalar_type(context, *scalar);
unsafe { LLVMVectorType(base_type, *size as u32) }
}
ast::Type::Array(vec, scalar, dimensions) => {
let mut underlying_type = get_scalar_type::<BasicTypeEnum>(context, *scalar);
let mut underlying_type = get_scalar_type(context, *scalar);
if let Some(size) = vec {
underlying_type = BasicTypeEnum::VectorType(unsafe {
VectorType::new(LLVMVectorType(
match underlying_type {
BasicTypeEnum::FloatType(t) => t.as_type_ref(),
BasicTypeEnum::IntType(t) => t.as_type_ref(),
_ => return Err(error_unreachable()),
},
size.get() as u32,
))
});
underlying_type = unsafe { LLVMVectorType(underlying_type, size.get() as u32) };
}
if dimensions.is_empty() {
return Ok(T::from(underlying_type.array_type(0)));
return Ok(unsafe { LLVMArrayType2(underlying_type, 0) });
}
let llvm_type = dimensions
dimensions
.iter()
.rfold(underlying_type.as_type_ref(), |result, dimension| unsafe {
.rfold(underlying_type, |result, dimension| unsafe {
LLVMArrayType2(result, *dimension as u64)
});
T::from(unsafe { ArrayType::new(llvm_type) })
}
ast::Type::Pointer(_, space) => {
T::from(context.ptr_type(AddressSpace::from(get_state_space(*space)?)))
})
}
ast::Type::Pointer(_, space) => get_pointer_type(context, *space)?,
})
}
fn get_scalar_type<
'ctx,
T: From<IntType<'ctx>> + From<FloatType<'ctx>> + From<VectorType<'ctx>>,
>(
context: &'ctx Context,
type_: ast::ScalarType,
) -> T {
fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeRef {
match type_ {
ast::ScalarType::Pred => T::from(context.bool_type()),
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
T::from(context.i8_type())
}
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
T::from(context.i16_type())
}
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
T::from(context.i32_type())
}
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
T::from(context.i64_type())
}
ast::ScalarType::B128 => T::from(context.i128_type()),
ast::ScalarType::F16 => T::from(context.f16_type()),
ast::ScalarType::F32 => T::from(context.f32_type()),
ast::ScalarType::F64 => T::from(context.f64_type()),
ast::ScalarType::BF16 => {
T::from(unsafe { FloatType::new(LLVMBFloatTypeInContext(context.as_ctx_ref())) })
}
ast::ScalarType::U16x2 | ast::ScalarType::S16x2 => {
T::from(unsafe { VectorType::new(LLVMVectorType(context.i16_type().as_type_ref(), 2)) })
}
ast::ScalarType::F16x2 => {
T::from(unsafe { VectorType::new(LLVMVectorType(context.f16_type().as_type_ref(), 2)) })
}
ast::ScalarType::BF16x2 => T::from(unsafe {
VectorType::new(LLVMVectorType(
LLVMBFloatTypeInContext(context.as_ctx_ref()),
2,
))
}),
ast::ScalarType::Pred => unsafe { LLVMInt1TypeInContext(context) },
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe {
LLVMInt8TypeInContext(context)
},
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe {
LLVMInt16TypeInContext(context)
},
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe {
LLVMInt32TypeInContext(context)
},
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => unsafe {
LLVMInt64TypeInContext(context)
},
ast::ScalarType::B128 => unsafe { LLVMInt128TypeInContext(context) },
ast::ScalarType::F16 => unsafe { LLVMHalfTypeInContext(context) },
ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) },
ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) },
ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
ast::ScalarType::U16x2 => todo!(),
ast::ScalarType::S16x2 => todo!(),
ast::ScalarType::F16x2 => todo!(),
ast::ScalarType::BF16x2 => todo!(),
}
}
fn get_state_space(space: ast::StateSpace) -> Result<u16, TranslateError> {
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
match space {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
@ -566,12 +641,12 @@ fn get_state_space(space: ast::StateSpace) -> Result<u16, TranslateError> {
}
}
struct ResolveIdent<'ctx> {
struct ResolveIdent {
words: HashMap<SpirvWord, String>,
values: HashMap<SpirvWord, AnyValueEnum<'ctx>>,
values: HashMap<SpirvWord, LLVMValueRef>,
}
impl<'ctx> ResolveIdent<'ctx> {
impl ResolveIdent {
fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self {
ResolveIdent {
words: HashMap::new(),
@ -580,14 +655,15 @@ impl<'ctx> ResolveIdent<'ctx> {
}
fn get_or_ad_impl<'a, T>(&'a mut self, word: SpirvWord, fn_: impl FnOnce(&'a str) -> T) -> T {
match self.words.entry(word) {
hash_map::Entry::Occupied(entry) => fn_(entry.into_mut()),
let str = match self.words.entry(word) {
hash_map::Entry::Occupied(entry) => entry.into_mut(),
hash_map::Entry::Vacant(entry) => {
let mut text = word.0.to_string();
text.push('\0');
fn_(entry.insert(text))
}
entry.insert(text)
}
};
fn_(&str[..str.len() - 1])
}
fn get_or_add(&mut self, word: SpirvWord) -> &str {
@ -598,153 +674,19 @@ impl<'ctx> ResolveIdent<'ctx> {
self.get_or_add(word).as_ptr().cast()
}
fn register(&mut self, word: SpirvWord, t: impl AnyValue<'ctx>) {
self.values.insert(word, t.as_any_value_enum());
fn register(&mut self, word: SpirvWord, v: LLVMValueRef) {
self.values.insert(word, v);
}
fn value(&self, word: SpirvWord) -> Result<AnyValueEnum<'ctx>, TranslateError> {
fn value(&self, word: SpirvWord) -> Result<LLVMValueRef, TranslateError> {
self.values
.get(&word)
.copied()
.ok_or_else(|| error_unreachable())
}
fn with_result<T: AnyValue<'ctx>>(
&mut self,
word: SpirvWord,
fn_: impl FnOnce(&str) -> Result<T, BuilderError>,
) -> Result<(), TranslateError> {
let t = self
.get_or_ad_impl(word, fn_)
.map_err(|_| error_unreachable())?;
fn with_result(&mut self, word: SpirvWord, fn_: impl FnOnce(*const i8) -> LLVMValueRef) {
let t = self.get_or_ad_impl(word, |dst| fn_(dst.as_ptr().cast()));
self.register(word, t);
Ok(())
}
fn build_int_math(
&mut self,
builder: &Builder<'ctx>,
dst: SpirvWord,
src1: SpirvWord,
src2: SpirvWord,
fn_: impl IntMathOp<'ctx>,
) -> Result<(), TranslateError> {
let src1 = self.value(src1)?;
let src2 = self.value(src2)?;
self.with_result(dst, |dst| {
Ok(match (src1, src2) {
(AnyValueEnum::IntValue(src1), AnyValueEnum::IntValue(src2)) => {
AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?)
}
(AnyValueEnum::PointerValue(src1), AnyValueEnum::PointerValue(src2)) => {
AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?)
}
(AnyValueEnum::VectorValue(src1), AnyValueEnum::VectorValue(src2)) => {
AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?)
}
_ => return todo!(),
})
})
}
}
trait IntMathOp<'ctx> {
fn call<T: IntMathValue<'ctx>>(
self,
builder: &Builder<'ctx>,
src1: T,
src2: T,
dst: &str,
) -> Result<T, BuilderError>;
}
trait AnyValueEnumExt<'ctx> {
fn as_array(self) -> Result<ArrayValue<'ctx>, TranslateError>;
fn as_int(self) -> Result<IntValue<'ctx>, TranslateError>;
fn as_float(self) -> Result<FloatValue<'ctx>, TranslateError>;
fn as_phi(self) -> Result<PhiValue<'ctx>, TranslateError>;
fn as_function(self) -> Result<FunctionValue<'ctx>, TranslateError>;
fn as_pointer(self) -> Result<PointerValue<'ctx>, TranslateError>;
fn as_struct(self) -> Result<StructValue<'ctx>, TranslateError>;
fn as_vector(self) -> Result<VectorValue<'ctx>, TranslateError>;
fn as_instruction(self) -> Result<InstructionValue<'ctx>, TranslateError>;
fn as_basic(self) -> Result<BasicValueEnum<'ctx>, TranslateError>;
}
impl<'ctx> AnyValueEnumExt<'ctx> for AnyValueEnum<'ctx> {
fn as_array(self) -> Result<ArrayValue<'ctx>, TranslateError> {
if let AnyValueEnum::ArrayValue(x) = self {
Ok(x)
} else {
Err(error_unreachable())
}
}
fn as_int(self) -> Result<IntValue<'ctx>, TranslateError> {
if let AnyValueEnum::IntValue(x) = self {
Ok(x)
} else {
Err(error_unreachable())
}
}
fn as_float(self) -> Result<FloatValue<'ctx>, TranslateError> {
if let AnyValueEnum::FloatValue(x) = self {
Ok(x)
} else {
Err(error_unreachable())
}
}
fn as_phi(self) -> Result<PhiValue<'ctx>, TranslateError> {
if let AnyValueEnum::PhiValue(x) = self {
Ok(x)
} else {
Err(error_unreachable())
}
}
fn as_function(self) -> Result<FunctionValue<'ctx>, TranslateError> {
if let AnyValueEnum::FunctionValue(x) = self {
Ok(x)
} else {
Err(error_unreachable())
}
}
fn as_pointer(self) -> Result<PointerValue<'ctx>, TranslateError> {
if let AnyValueEnum::PointerValue(x) = self {
Ok(x)
} else {
Err(error_unreachable())
}
}
fn as_struct(self) -> Result<StructValue<'ctx>, TranslateError> {
if let AnyValueEnum::StructValue(x) = self {
Ok(x)
} else {
Err(error_unreachable())
}
}
fn as_vector(self) -> Result<VectorValue<'ctx>, TranslateError> {
if let AnyValueEnum::VectorValue(x) = self {
Ok(x)
} else {
Err(error_unreachable())
}
}
fn as_instruction(self) -> Result<InstructionValue<'ctx>, TranslateError> {
if let AnyValueEnum::InstructionValue(x) = self {
Ok(x)
} else {
Err(error_unreachable())
}
}
fn as_basic(self) -> Result<BasicValueEnum<'ctx>, TranslateError> {
BasicValueEnum::try_from(self).map_err(|_| error_unreachable())
}
}

View File

@ -1,4 +1,3 @@
use llvm_zluda::inkwell::memory_buffer::MemoryBuffer;
use ptx_parser as ast;
use rspirv::{binary::Assemble, dr};
use std::hash::Hash;
@ -17,7 +16,7 @@ use std::{
mod convert_dynamic_shared_memory_usage;
mod convert_to_stateful_memory_access;
mod convert_to_typed;
mod emit_llvm;
pub(crate) mod emit_llvm;
mod emit_spirv;
mod expand_arguments;
mod extract_globals;
@ -182,7 +181,7 @@ fn to_ssa<'input, 'b>(
}
pub struct Module {
pub llvm_ir: MemoryBuffer,
pub llvm_ir: emit_llvm::MemoryBuffer,
pub kernel_info: HashMap<String, KernelInfo>,
}
@ -598,6 +597,7 @@ fn error_unreachable() -> TranslateError {
TranslateError::Unreachable
}
#[cfg(debug_assertions)]
fn error_unknown_symbol() -> TranslateError {
panic!()
}
@ -607,6 +607,7 @@ fn error_unknown_symbol() -> TranslateError {
TranslateError::UnknownSymbol
}
#[cfg(debug_assertions)]
fn error_mismatched_type() -> TranslateError {
panic!()
}

View File

@ -2,7 +2,6 @@ use crate::pass;
use crate::ptx;
use crate::translate;
use hip_runtime_sys::hipError_t;
use llvm_zluda::inkwell::memory_buffer::MemoryBuffer;
use rspirv::{
binary::{Assemble, Disassemble},
dr::{Block, Function, Instruction, Loader, Operand},
@ -379,21 +378,21 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def
Ok(result)
}
unsafe fn compile_amd(buffer: &MemoryBuffer) -> Vec<u8> {
unsafe fn compile_amd(buffer: &pass::emit_llvm::MemoryBuffer) -> Vec<u8> {
use amd_comgr_sys::*;
let mut data_set = mem::zeroed();
amd_comgr_create_data_set(&mut data_set).unwrap();
let mut data = mem::zeroed();
amd_comgr_create_data(amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC, &mut data).unwrap();
let buffer = buffer.as_slice();
let buffer = &**buffer;
amd_comgr_set_data(data, buffer.len(), buffer.as_ptr().cast()).unwrap();
amd_comgr_set_data_name(data, "zluda.bc\0".as_ptr().cast()).unwrap();
amd_comgr_set_data_name(data, c"zluda.bc".as_ptr()).unwrap();
amd_comgr_data_set_add(data_set, data).unwrap();
let mut reloc_data = mem::zeroed();
amd_comgr_create_data_set(&mut reloc_data).unwrap();
let mut action_info = mem::zeroed();
amd_comgr_create_action_info(&mut action_info).unwrap();
amd_comgr_action_info_set_isa_name(action_info, "amdgcn-amd-amdhsa--gfx1030\0".as_ptr().cast())
amd_comgr_action_info_set_isa_name(action_info, c"amdgcn-amd-amdhsa--gfx1030".as_ptr())
.unwrap();
amd_comgr_do_action(
amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE,