|
|
|
@ -1,74 +1,207 @@
|
|
|
|
|
// We use Raw LLVM-C bindings here because using inkwell is just not worth it.
|
|
|
|
|
// Specifically the issue is with builder functions. We maintain the mapping
|
|
|
|
|
// between ZLUDA identifiers and LLVM values. When using inkwell, LLVM values
|
|
|
|
|
// are kept as instances `AnyValueEnum`. Now look at the signature of
|
|
|
|
|
// `Builder::build_int_add(...)`:
|
|
|
|
|
// pub fn build_int_add<T: IntMathValue<'ctx>>(&self, lhs: T, rhs: T, name: &str, ) -> Result<T, BuilderError>
|
|
|
|
|
// At this point both lhs and rhs are `AnyValueEnum`. To call
|
|
|
|
|
// `build_int_add(...)` we would have to do something like this:
|
|
|
|
|
// if let (Ok(lhs), Ok(rhs)) = (lhs.as_int(), rhs.as_int()) {
|
|
|
|
|
// builder.build_int_add(lhs, rhs, dst)?;
|
|
|
|
|
// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_pointer(), rhs.as_pointer()) {
|
|
|
|
|
// builder.build_int_add(lhs, rhs, dst)?;
|
|
|
|
|
// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_vector(), rhs.as_vector()) {
|
|
|
|
|
// builder.build_int_add(lhs, rhs, dst)?;
|
|
|
|
|
// } else {
|
|
|
|
|
// return Err(error_unrachable());
|
|
|
|
|
// }
|
|
|
|
|
// while with plain LLVM-C it's just:
|
|
|
|
|
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
|
|
|
|
|
|
|
|
|
|
use std::convert::{TryFrom, TryInto};
|
|
|
|
|
use std::ffi::CStr;
|
|
|
|
|
use std::ops::Deref;
|
|
|
|
|
use std::ptr;
|
|
|
|
|
|
|
|
|
|
use super::*;
|
|
|
|
|
use llvm_zluda::inkwell::builder::{Builder, BuilderError};
|
|
|
|
|
use llvm_zluda::inkwell::context::{AsContextRef, Context};
|
|
|
|
|
use llvm_zluda::inkwell::memory_buffer::MemoryBuffer;
|
|
|
|
|
use llvm_zluda::inkwell::types::{
|
|
|
|
|
ArrayType, AsTypeRef, BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType,
|
|
|
|
|
IntType, PointerType, VectorType, VoidType,
|
|
|
|
|
};
|
|
|
|
|
use llvm_zluda::inkwell::values::{
|
|
|
|
|
AnyValue, AnyValueEnum, ArrayValue, BasicValueEnum, FloatMathValue, FloatValue, FunctionValue,
|
|
|
|
|
InstructionValue, IntMathValue, IntValue, PhiValue, PointerValue, StructValue, VectorValue,
|
|
|
|
|
};
|
|
|
|
|
use llvm_zluda::inkwell::{self, module, AddressSpace};
|
|
|
|
|
use llvm_zluda::llvm::core::{
|
|
|
|
|
LLVMArrayType2, LLVMBFloatType, LLVMBFloatTypeInContext, LLVMVectorType,
|
|
|
|
|
};
|
|
|
|
|
use llvm_zluda::llvm::prelude::*;
|
|
|
|
|
use llvm_zluda::llvm::{LLVMCallConv, LLVMZludaBuildAlloca};
|
|
|
|
|
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
|
|
|
|
|
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
|
|
|
|
|
use llvm_zluda::core::*;
|
|
|
|
|
use llvm_zluda::prelude::*;
|
|
|
|
|
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
|
|
|
|
|
|
|
|
|
|
const LLVM_UNNAMED: &str = "\0";
|
|
|
|
|
const LLVM_UNNAMED: &CStr = c"";
|
|
|
|
|
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
|
|
|
|
|
const GENERIC_ADDRESS_SPACE: u16 = 0;
|
|
|
|
|
const GLOBAL_ADDRESS_SPACE: u16 = 1;
|
|
|
|
|
const SHARED_ADDRESS_SPACE: u16 = 3;
|
|
|
|
|
const CONSTANT_ADDRESS_SPACE: u16 = 4;
|
|
|
|
|
const PRIVATE_ADDRESS_SPACE: u16 = 5;
|
|
|
|
|
const GENERIC_ADDRESS_SPACE: u32 = 0;
|
|
|
|
|
const GLOBAL_ADDRESS_SPACE: u32 = 1;
|
|
|
|
|
const SHARED_ADDRESS_SPACE: u32 = 3;
|
|
|
|
|
const CONSTANT_ADDRESS_SPACE: u32 = 4;
|
|
|
|
|
const PRIVATE_ADDRESS_SPACE: u32 = 5;
|
|
|
|
|
|
|
|
|
|
struct Context(LLVMContextRef);
|
|
|
|
|
|
|
|
|
|
impl Context {
|
|
|
|
|
fn new() -> Self {
|
|
|
|
|
Self(unsafe { LLVMContextCreate() })
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn get(&self) -> LLVMContextRef {
|
|
|
|
|
self.0
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Drop for Context {
|
|
|
|
|
fn drop(&mut self) {
|
|
|
|
|
unsafe {
|
|
|
|
|
LLVMContextDispose(self.0);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct Module(LLVMModuleRef);
|
|
|
|
|
|
|
|
|
|
impl Module {
|
|
|
|
|
fn new(ctx: &Context, name: &CStr) -> Self {
|
|
|
|
|
Self(unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) })
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn get(&self) -> LLVMModuleRef {
|
|
|
|
|
self.0
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn verify(&self) -> Result<(), Message> {
|
|
|
|
|
let mut err = ptr::null_mut();
|
|
|
|
|
let error = unsafe {
|
|
|
|
|
LLVMVerifyModule(
|
|
|
|
|
self.get(),
|
|
|
|
|
LLVMVerifierFailureAction::LLVMReturnStatusAction,
|
|
|
|
|
&mut err,
|
|
|
|
|
)
|
|
|
|
|
};
|
|
|
|
|
if error == 1 && err != ptr::null_mut() {
|
|
|
|
|
Err(Message(unsafe { CStr::from_ptr(err) }))
|
|
|
|
|
} else {
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn write_bitcode_to_memory(&self) -> MemoryBuffer {
|
|
|
|
|
let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) };
|
|
|
|
|
MemoryBuffer(memory_buffer)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn write_to_stderr(&self) {
|
|
|
|
|
unsafe { LLVMDumpModule(self.get()) };
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Drop for Module {
|
|
|
|
|
fn drop(&mut self) {
|
|
|
|
|
unsafe {
|
|
|
|
|
LLVMDisposeModule(self.0);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct Builder(LLVMBuilderRef);
|
|
|
|
|
|
|
|
|
|
impl Builder {
|
|
|
|
|
fn new(ctx: &Context) -> Self {
|
|
|
|
|
Self::new_raw(ctx.get())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn new_raw(ctx: LLVMContextRef) -> Self {
|
|
|
|
|
Self(unsafe { LLVMCreateBuilderInContext(ctx) })
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn get(&self) -> LLVMBuilderRef {
|
|
|
|
|
self.0
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Drop for Builder {
|
|
|
|
|
fn drop(&mut self) {
|
|
|
|
|
unsafe {
|
|
|
|
|
LLVMDisposeBuilder(self.0);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct Message(&'static CStr);
|
|
|
|
|
|
|
|
|
|
impl Drop for Message {
|
|
|
|
|
fn drop(&mut self) {
|
|
|
|
|
unsafe {
|
|
|
|
|
LLVMDisposeMessage(self.0.as_ptr().cast_mut());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl std::fmt::Debug for Message {
|
|
|
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
|
|
|
std::fmt::Debug::fmt(&self.0, f)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub struct MemoryBuffer(LLVMMemoryBufferRef);
|
|
|
|
|
|
|
|
|
|
impl Drop for MemoryBuffer {
|
|
|
|
|
fn drop(&mut self) {
|
|
|
|
|
unsafe {
|
|
|
|
|
LLVMDisposeMemoryBuffer(self.0);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Deref for MemoryBuffer {
|
|
|
|
|
type Target = [u8];
|
|
|
|
|
|
|
|
|
|
fn deref(&self) -> &Self::Target {
|
|
|
|
|
let data = unsafe { LLVMGetBufferStart(self.0) };
|
|
|
|
|
let len = unsafe { LLVMGetBufferSize(self.0) };
|
|
|
|
|
unsafe { std::slice::from_raw_parts(data.cast(), len) }
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub(super) fn run<'input>(
|
|
|
|
|
id_defs: &GlobalStringIdResolver<'input>,
|
|
|
|
|
call_map: MethodsCallMap<'input>,
|
|
|
|
|
directives: Vec<Directive<'input>>,
|
|
|
|
|
) -> Result<MemoryBuffer, TranslateError> {
|
|
|
|
|
let context = inkwell::context::Context::create();
|
|
|
|
|
let module = context.create_module(LLVM_UNNAMED);
|
|
|
|
|
let builder = context.create_builder();
|
|
|
|
|
let mut emit_ctx = ModuleEmitContext::new(&context, module, builder, id_defs);
|
|
|
|
|
let context = Context::new();
|
|
|
|
|
let module = Module::new(&context, LLVM_UNNAMED);
|
|
|
|
|
let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs);
|
|
|
|
|
for directive in directives {
|
|
|
|
|
match directive {
|
|
|
|
|
Directive::Variable(..) => todo!(),
|
|
|
|
|
Directive::Method(method) => emit_ctx.emit_method(method)?,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if let Err(err) = emit_ctx.module.verify() {
|
|
|
|
|
emit_ctx.module.print_to_stderr();
|
|
|
|
|
panic!("{}", err);
|
|
|
|
|
module.write_to_stderr();
|
|
|
|
|
if let Err(err) = module.verify() {
|
|
|
|
|
panic!("{:?}", err);
|
|
|
|
|
}
|
|
|
|
|
Ok(emit_ctx.module.write_bitcode_to_memory())
|
|
|
|
|
Ok(module.write_bitcode_to_memory())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct ModuleEmitContext<'ctx, 'input> {
|
|
|
|
|
context: &'ctx Context,
|
|
|
|
|
module: module::Module<'ctx>,
|
|
|
|
|
builder: Builder<'ctx>,
|
|
|
|
|
id_defs: &'ctx GlobalStringIdResolver<'input>,
|
|
|
|
|
resolver: ResolveIdent<'ctx>,
|
|
|
|
|
struct ModuleEmitContext<'a, 'input> {
|
|
|
|
|
context: LLVMContextRef,
|
|
|
|
|
module: LLVMModuleRef,
|
|
|
|
|
builder: Builder,
|
|
|
|
|
id_defs: &'a GlobalStringIdResolver<'input>,
|
|
|
|
|
resolver: ResolveIdent,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> {
|
|
|
|
|
impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|
|
|
|
fn new(
|
|
|
|
|
context: &'ctx Context,
|
|
|
|
|
module: module::Module<'ctx>,
|
|
|
|
|
builder: Builder<'ctx>,
|
|
|
|
|
id_defs: &'ctx GlobalStringIdResolver<'input>,
|
|
|
|
|
context: &Context,
|
|
|
|
|
module: &Module,
|
|
|
|
|
id_defs: &'a GlobalStringIdResolver<'input>,
|
|
|
|
|
) -> Self {
|
|
|
|
|
ModuleEmitContext {
|
|
|
|
|
context: &context,
|
|
|
|
|
module,
|
|
|
|
|
builder,
|
|
|
|
|
context: context.get(),
|
|
|
|
|
module: module.get(),
|
|
|
|
|
builder: Builder::new(context),
|
|
|
|
|
id_defs,
|
|
|
|
|
resolver: ResolveIdent::new(&id_defs),
|
|
|
|
|
}
|
|
|
|
@ -84,85 +217,86 @@ impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> {
|
|
|
|
|
|
|
|
|
|
fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> {
|
|
|
|
|
let func_decl = method.func_decl.borrow();
|
|
|
|
|
let fn_ = self.module.add_function(
|
|
|
|
|
method
|
|
|
|
|
let name = method
|
|
|
|
|
.import_as
|
|
|
|
|
.as_deref()
|
|
|
|
|
.unwrap_or_else(|| match func_decl.name {
|
|
|
|
|
ast::MethodName::Kernel(name) => name,
|
|
|
|
|
ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
|
|
|
|
|
}),
|
|
|
|
|
self.function_type(
|
|
|
|
|
});
|
|
|
|
|
let name = CString::new(name).map_err(|_| error_unreachable())?;
|
|
|
|
|
let fn_type = self.function_type(
|
|
|
|
|
func_decl.return_arguments.iter().map(|v| &v.v_type),
|
|
|
|
|
func_decl.input_arguments.iter().map(|v| &v.v_type),
|
|
|
|
|
),
|
|
|
|
|
None,
|
|
|
|
|
);
|
|
|
|
|
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
|
|
|
|
|
for (i, param) in func_decl.input_arguments.iter().enumerate() {
|
|
|
|
|
let value = fn_
|
|
|
|
|
.get_nth_param(i as u32)
|
|
|
|
|
.ok_or_else(|| error_unreachable())?;
|
|
|
|
|
value.set_name(self.resolver.get_or_add(param.name));
|
|
|
|
|
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
|
|
|
|
let name = self.resolver.get_or_add(param.name);
|
|
|
|
|
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
|
|
|
|
|
self.resolver.register(param.name, value);
|
|
|
|
|
}
|
|
|
|
|
fn_.set_call_conventions(if func_decl.name.is_kernel() {
|
|
|
|
|
let call_conv = if func_decl.name.is_kernel() {
|
|
|
|
|
Self::kernel_call_convention()
|
|
|
|
|
} else {
|
|
|
|
|
Self::func_call_convention()
|
|
|
|
|
});
|
|
|
|
|
};
|
|
|
|
|
unsafe { LLVMSetFunctionCallConv(fn_, call_conv) };
|
|
|
|
|
if let Some(statements) = method.body {
|
|
|
|
|
let variables_bb = self.context.append_basic_block(fn_, LLVM_UNNAMED);
|
|
|
|
|
let variables_builder = self.context.create_builder();
|
|
|
|
|
variables_builder.position_at_end(variables_bb);
|
|
|
|
|
let real_bb = self.context.append_basic_block(fn_, LLVM_UNNAMED);
|
|
|
|
|
self.builder.position_at_end(real_bb);
|
|
|
|
|
let variables_bb =
|
|
|
|
|
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
|
|
|
|
|
let variables_builder = Builder::new_raw(self.context);
|
|
|
|
|
unsafe { LLVMPositionBuilderAtEnd(variables_builder.get(), variables_bb) };
|
|
|
|
|
let real_bb =
|
|
|
|
|
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
|
|
|
|
|
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
|
|
|
|
|
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
|
|
|
|
|
for statement in statements {
|
|
|
|
|
method_emitter.emit_statement(statement)?;
|
|
|
|
|
}
|
|
|
|
|
method_emitter.variables_builder.build_unconditional_branch(real_bb);
|
|
|
|
|
unsafe { LLVMBuildBr(method_emitter.variables_builder.get(), real_bb) };
|
|
|
|
|
}
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn function_type<'a>(
|
|
|
|
|
fn function_type(
|
|
|
|
|
&self,
|
|
|
|
|
return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
|
|
|
|
input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
|
|
|
|
) -> FunctionType<'ctx> {
|
|
|
|
|
) -> LLVMTypeRef {
|
|
|
|
|
if return_args.len() == 0 {
|
|
|
|
|
let input_args = input_args
|
|
|
|
|
let mut input_args = input_args
|
|
|
|
|
.map(|type_| match type_ {
|
|
|
|
|
ast::Type::Scalar(scalar) => match scalar {
|
|
|
|
|
ast::ScalarType::Pred => {
|
|
|
|
|
BasicMetadataTypeEnum::from(self.context.bool_type())
|
|
|
|
|
unsafe { LLVMInt1TypeInContext(self.context) }
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
|
|
|
|
|
BasicMetadataTypeEnum::from(self.context.i8_type())
|
|
|
|
|
unsafe { LLVMInt8TypeInContext(self.context) }
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
|
|
|
|
|
BasicMetadataTypeEnum::from(self.context.i16_type())
|
|
|
|
|
unsafe { LLVMInt16TypeInContext(self.context) }
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
|
|
|
|
|
BasicMetadataTypeEnum::from(self.context.i32_type())
|
|
|
|
|
unsafe { LLVMInt32TypeInContext(self.context) }
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
|
|
|
|
|
BasicMetadataTypeEnum::from(self.context.i64_type())
|
|
|
|
|
unsafe { LLVMInt64TypeInContext(self.context) }
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::B128 => {
|
|
|
|
|
BasicMetadataTypeEnum::from(self.context.i128_type())
|
|
|
|
|
unsafe { LLVMInt128TypeInContext(self.context) }
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::F16 => {
|
|
|
|
|
BasicMetadataTypeEnum::from(self.context.f16_type())
|
|
|
|
|
unsafe { LLVMHalfTypeInContext(self.context) }
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::F32 => {
|
|
|
|
|
BasicMetadataTypeEnum::from(self.context.f32_type())
|
|
|
|
|
unsafe { LLVMFloatTypeInContext(self.context) }
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::F64 => {
|
|
|
|
|
BasicMetadataTypeEnum::from(self.context.f64_type())
|
|
|
|
|
unsafe { LLVMDoubleTypeInContext(self.context) }
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::BF16 => {
|
|
|
|
|
BasicMetadataTypeEnum::from(unsafe { FloatType::new(LLVMBFloatType()) })
|
|
|
|
|
unsafe { LLVMBFloatTypeInContext(self.context) }
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::U16x2 => todo!(),
|
|
|
|
|
ast::ScalarType::S16x2 => todo!(),
|
|
|
|
@ -174,41 +308,39 @@ impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> {
|
|
|
|
|
ast::Type::Pointer(_, _) => todo!(),
|
|
|
|
|
})
|
|
|
|
|
.collect::<Vec<_>>();
|
|
|
|
|
return self.context.void_type().fn_type(&*input_args, false);
|
|
|
|
|
return unsafe {
|
|
|
|
|
LLVMFunctionType(
|
|
|
|
|
LLVMVoidTypeInContext(self.context),
|
|
|
|
|
input_args.as_mut_ptr(),
|
|
|
|
|
input_args.len() as u32,
|
|
|
|
|
0,
|
|
|
|
|
)
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
todo!()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn get_type(&self, type_: &ast::Type) -> FunctionType<'ctx> {
|
|
|
|
|
match type_ {
|
|
|
|
|
ast::Type::Scalar(_) => todo!(),
|
|
|
|
|
ast::Type::Vector(_, _) => todo!(),
|
|
|
|
|
ast::Type::Array(_, _, _) => todo!(),
|
|
|
|
|
ast::Type::Pointer(_, _) => todo!(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct MethodEmitContext<'a, 'ctx, 'input> {
|
|
|
|
|
context: &'ctx Context,
|
|
|
|
|
module: &'a module::Module<'ctx>,
|
|
|
|
|
method: FunctionValue<'ctx>,
|
|
|
|
|
builder: &'a Builder<'ctx>,
|
|
|
|
|
struct MethodEmitContext<'a, 'input> {
|
|
|
|
|
context: LLVMContextRef,
|
|
|
|
|
module: LLVMModuleRef,
|
|
|
|
|
method: LLVMValueRef,
|
|
|
|
|
builder: LLVMBuilderRef,
|
|
|
|
|
id_defs: &'a GlobalStringIdResolver<'input>,
|
|
|
|
|
variables_builder: Builder<'ctx>,
|
|
|
|
|
resolver: &'a mut ResolveIdent<'ctx>,
|
|
|
|
|
variables_builder: Builder,
|
|
|
|
|
resolver: &'a mut ResolveIdent,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|
|
|
|
fn new(
|
|
|
|
|
parent: &'a mut ModuleEmitContext<'ctx, 'input>,
|
|
|
|
|
method: FunctionValue<'ctx>,
|
|
|
|
|
variables_builder: Builder<'ctx>,
|
|
|
|
|
) -> MethodEmitContext<'a, 'ctx, 'input> {
|
|
|
|
|
impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|
|
|
|
fn new<'x>(
|
|
|
|
|
parent: &'a mut ModuleEmitContext<'x, 'input>,
|
|
|
|
|
method: LLVMValueRef,
|
|
|
|
|
variables_builder: Builder,
|
|
|
|
|
) -> MethodEmitContext<'a, 'input> {
|
|
|
|
|
MethodEmitContext {
|
|
|
|
|
context: &parent.context,
|
|
|
|
|
module: &parent.module,
|
|
|
|
|
builder: &parent.builder,
|
|
|
|
|
context: parent.context,
|
|
|
|
|
module: parent.module,
|
|
|
|
|
builder: parent.builder.get(),
|
|
|
|
|
id_defs: parent.id_defs,
|
|
|
|
|
variables_builder,
|
|
|
|
|
resolver: &mut parent.resolver,
|
|
|
|
@ -238,19 +370,16 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|
|
|
|
|
|
|
|
|
fn emit_variable(&mut self, var: ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
|
|
|
|
|
let alloca = unsafe {
|
|
|
|
|
PointerValue::new(LLVMZludaBuildAlloca(
|
|
|
|
|
self.variables_builder.as_mut_ptr(),
|
|
|
|
|
get_type::<BasicTypeEnum>(&self.context, &var.v_type)?.as_type_ref(),
|
|
|
|
|
get_state_space(var.state_space)? as u32,
|
|
|
|
|
LLVMZludaBuildAlloca(
|
|
|
|
|
self.variables_builder.get(),
|
|
|
|
|
get_type(self.context, &var.v_type)?,
|
|
|
|
|
get_state_space(var.state_space)?,
|
|
|
|
|
self.resolver.get_or_add_raw(var.name),
|
|
|
|
|
))
|
|
|
|
|
)
|
|
|
|
|
};
|
|
|
|
|
self.resolver.register(var.name, alloca);
|
|
|
|
|
if let Some(align) = var.align {
|
|
|
|
|
let alloca = alloca.as_instruction().ok_or_else(|| error_unreachable())?;
|
|
|
|
|
alloca
|
|
|
|
|
.set_alignment(align)
|
|
|
|
|
.map_err(|_| error_unreachable())?;
|
|
|
|
|
unsafe { LLVMSetAlignment(alloca, align) };
|
|
|
|
|
}
|
|
|
|
|
if !var.array_init.is_empty() {
|
|
|
|
|
todo!()
|
|
|
|
@ -259,27 +388,24 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn emit_label(&mut self, label: SpirvWord) {
|
|
|
|
|
let block = self
|
|
|
|
|
.context
|
|
|
|
|
.append_basic_block(self.method, self.resolver.get_or_add(label));
|
|
|
|
|
if self
|
|
|
|
|
.builder
|
|
|
|
|
.get_insert_block()
|
|
|
|
|
.unwrap()
|
|
|
|
|
.get_terminator()
|
|
|
|
|
.is_none()
|
|
|
|
|
{
|
|
|
|
|
self.builder.build_unconditional_branch(block);
|
|
|
|
|
let block = unsafe {
|
|
|
|
|
LLVMAppendBasicBlockInContext(
|
|
|
|
|
self.context,
|
|
|
|
|
self.method,
|
|
|
|
|
self.resolver.get_or_add_raw(label),
|
|
|
|
|
)
|
|
|
|
|
};
|
|
|
|
|
let last_block = unsafe { LLVMGetInsertBlock(self.builder) };
|
|
|
|
|
if unsafe { LLVMGetBasicBlockTerminator(last_block) } == ptr::null_mut() {
|
|
|
|
|
unsafe { LLVMBuildBr(self.builder, block) };
|
|
|
|
|
}
|
|
|
|
|
self.builder.position_at_end(block);
|
|
|
|
|
unsafe { LLVMPositionBuilderAtEnd(self.builder, block) };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn emit_store_var(&mut self, store: StoreVarDetails) -> Result<(), TranslateError> {
|
|
|
|
|
let src1 = self.resolver.value(store.arg.src1)?;
|
|
|
|
|
let src2 = self.resolver.value(store.arg.src2)?;
|
|
|
|
|
self.builder
|
|
|
|
|
.build_store(src1.as_pointer()?, src2.as_basic()?)
|
|
|
|
|
.map_err(|_| error_unreachable())?;
|
|
|
|
|
let ptr = self.resolver.value(store.arg.src1)?;
|
|
|
|
|
let value = self.resolver.value(store.arg.src2)?;
|
|
|
|
|
unsafe { LLVMBuildStore(self.builder, value, ptr) };
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -303,7 +429,7 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|
|
|
|
ast::Instruction::Cvt { data, arguments } => todo!(),
|
|
|
|
|
ast::Instruction::Shr { data, arguments } => todo!(),
|
|
|
|
|
ast::Instruction::Shl { data, arguments } => todo!(),
|
|
|
|
|
ast::Instruction::Ret { data } => self.emit_ret(data),
|
|
|
|
|
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
|
|
|
|
|
ast::Instruction::Cvta { data, arguments } => todo!(),
|
|
|
|
|
ast::Instruction::Abs { data, arguments } => todo!(),
|
|
|
|
|
ast::Instruction::Mad { data, arguments } => todo!(),
|
|
|
|
@ -351,10 +477,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|
|
|
|
todo!()
|
|
|
|
|
}
|
|
|
|
|
let builder = self.builder;
|
|
|
|
|
let type_ = get_type::<BasicTypeEnum>(&self.context, &data.typ)?;
|
|
|
|
|
let ptr = self.resolver.value(arguments.src)?.as_pointer()?;
|
|
|
|
|
self.resolver
|
|
|
|
|
.with_result(arguments.dst, |dst| builder.build_load(type_, ptr, dst))
|
|
|
|
|
let type_ = get_type(self.context, &data.typ)?;
|
|
|
|
|
let ptr = self.resolver.value(arguments.src)?;
|
|
|
|
|
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
|
|
|
|
LLVMBuildLoad2(builder, type_, ptr, dst)
|
|
|
|
|
});
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn emit_load_variable(&mut self, var: LoadVarDetails) -> Result<(), TranslateError> {
|
|
|
|
@ -362,10 +490,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|
|
|
|
todo!()
|
|
|
|
|
}
|
|
|
|
|
let builder = self.builder;
|
|
|
|
|
let type_ = get_type::<BasicTypeEnum>(&self.context, &var.typ)?;
|
|
|
|
|
let ptr = self.resolver.value(var.arg.src)?.as_pointer()?;
|
|
|
|
|
self.resolver
|
|
|
|
|
.with_result(var.arg.dst, |dst| builder.build_load(type_, ptr, dst))
|
|
|
|
|
let type_ = get_type(self.context, &var.typ)?;
|
|
|
|
|
let ptr = self.resolver.value(var.arg.src)?;
|
|
|
|
|
self.resolver.with_result(var.arg.dst, |dst| unsafe {
|
|
|
|
|
LLVMBuildLoad2(builder, type_, ptr, dst)
|
|
|
|
|
});
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> {
|
|
|
|
@ -374,11 +504,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|
|
|
|
ConversionKind::Default => todo!(),
|
|
|
|
|
ConversionKind::SignExtend => todo!(),
|
|
|
|
|
ConversionKind::BitToPtr => {
|
|
|
|
|
let src = self.resolver.value(conversion.src)?.as_int()?;
|
|
|
|
|
let src = self.resolver.value(conversion.src)?;
|
|
|
|
|
let type_ = get_pointer_type(self.context, conversion.to_space)?;
|
|
|
|
|
self.resolver.with_result(conversion.dst, |dst| {
|
|
|
|
|
builder.build_int_to_ptr(src, type_, dst)
|
|
|
|
|
})
|
|
|
|
|
self.resolver.with_result(conversion.dst, |dst| unsafe {
|
|
|
|
|
LLVMBuildIntToPtr(builder, src, type_, dst)
|
|
|
|
|
});
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
ConversionKind::PtrToPtr => todo!(),
|
|
|
|
|
ConversionKind::AddressOf => todo!(),
|
|
|
|
@ -386,21 +517,12 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> {
|
|
|
|
|
let type_ = get_scalar_type::<BasicTypeEnum>(&self.context, constant.typ);
|
|
|
|
|
let value: AnyValueEnum = match (type_, constant.value) {
|
|
|
|
|
(BasicTypeEnum::IntType(type_), ast::ImmediateValue::U64(x)) => {
|
|
|
|
|
type_.const_int(x, false).into()
|
|
|
|
|
}
|
|
|
|
|
(BasicTypeEnum::IntType(type_), ast::ImmediateValue::S64(x)) => {
|
|
|
|
|
type_.const_int(x as u64, false).into()
|
|
|
|
|
}
|
|
|
|
|
(BasicTypeEnum::FloatType(type_), ast::ImmediateValue::F32(x)) => {
|
|
|
|
|
type_.const_float(x as f64).into()
|
|
|
|
|
}
|
|
|
|
|
(BasicTypeEnum::FloatType(type_), ast::ImmediateValue::F64(x)) => {
|
|
|
|
|
type_.const_float(x).into()
|
|
|
|
|
}
|
|
|
|
|
_ => return Err(error_unreachable()),
|
|
|
|
|
let type_ = get_scalar_type(self.context, constant.typ);
|
|
|
|
|
let value = match constant.value {
|
|
|
|
|
ast::ImmediateValue::U64(x) => unsafe { LLVMConstInt(type_, x, 0) },
|
|
|
|
|
ast::ImmediateValue::S64(x) => unsafe { LLVMConstInt(type_, x as u64, 0) },
|
|
|
|
|
ast::ImmediateValue::F32(x) => unsafe { LLVMConstReal(type_, x as f64) },
|
|
|
|
|
ast::ImmediateValue::F64(x) => unsafe { LLVMConstReal(type_, x) },
|
|
|
|
|
};
|
|
|
|
|
self.resolver.register(constant.dst, value);
|
|
|
|
|
Ok(())
|
|
|
|
@ -412,14 +534,16 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|
|
|
|
arguments: ast::AddArgs<SpirvWord>,
|
|
|
|
|
) -> Result<(), TranslateError> {
|
|
|
|
|
let builder = self.builder;
|
|
|
|
|
let src1 = self.resolver.value(arguments.src1)?.as_int()?;
|
|
|
|
|
let src2 = self.resolver.value(arguments.src2)?.as_int()?;
|
|
|
|
|
let src1 = self.resolver.value(arguments.src1)?;
|
|
|
|
|
let src2 = self.resolver.value(arguments.src2)?;
|
|
|
|
|
let fn_ = match data {
|
|
|
|
|
ast::ArithDetails::Integer(integer) => Builder::build_int_add,
|
|
|
|
|
ast::ArithDetails::Float(float) => todo!(),
|
|
|
|
|
ast::ArithDetails::Integer(integer) => LLVMBuildAdd,
|
|
|
|
|
ast::ArithDetails::Float(float) => LLVMBuildFAdd,
|
|
|
|
|
};
|
|
|
|
|
self.resolver
|
|
|
|
|
.with_result(arguments.dst, |dst| fn_(builder, src1, src2, dst))
|
|
|
|
|
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
|
|
|
|
fn_(builder, src1, src2, dst)
|
|
|
|
|
});
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn emit_st(
|
|
|
|
@ -427,129 +551,80 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> {
|
|
|
|
|
data: ptx_parser::StData,
|
|
|
|
|
arguments: ptx_parser::StArgs<SpirvWord>,
|
|
|
|
|
) -> Result<(), TranslateError> {
|
|
|
|
|
let builder = self.builder;
|
|
|
|
|
let src1 = self.resolver.value(arguments.src1)?.as_pointer()?;
|
|
|
|
|
let src2 = self.resolver.value(arguments.src2)?.as_basic()?;
|
|
|
|
|
let ptr = self.resolver.value(arguments.src1)?;
|
|
|
|
|
let value = self.resolver.value(arguments.src2)?;
|
|
|
|
|
if data.qualifier != ast::LdStQualifier::Weak {
|
|
|
|
|
todo!()
|
|
|
|
|
}
|
|
|
|
|
self.builder
|
|
|
|
|
.build_store(src1, src2)
|
|
|
|
|
.map_err(|_| error_unreachable())?;
|
|
|
|
|
unsafe { LLVMBuildStore(self.builder, value, ptr) };
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn emit_ret(&self, _data: ptx_parser::RetData) -> Result<(), TranslateError> {
|
|
|
|
|
self.builder
|
|
|
|
|
.build_return(None)
|
|
|
|
|
.map_err(|_| error_unreachable())?;
|
|
|
|
|
Ok(())
|
|
|
|
|
fn emit_ret(&self, _data: ptx_parser::RetData) {
|
|
|
|
|
unsafe { LLVMBuildRetVoid(self.builder) };
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn get_pointer_type<'ctx>(
|
|
|
|
|
context: &'ctx Context,
|
|
|
|
|
context: LLVMContextRef,
|
|
|
|
|
to_space: ast::StateSpace,
|
|
|
|
|
) -> Result<PointerType<'ctx>, TranslateError> {
|
|
|
|
|
Ok(context.ptr_type(AddressSpace::from(get_state_space(to_space)?)))
|
|
|
|
|
) -> Result<LLVMTypeRef, TranslateError> {
|
|
|
|
|
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) })
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn get_type<
|
|
|
|
|
'ctx,
|
|
|
|
|
T: From<IntType<'ctx>>
|
|
|
|
|
+ From<FloatType<'ctx>>
|
|
|
|
|
+ From<VectorType<'ctx>>
|
|
|
|
|
+ From<PointerType<'ctx>>
|
|
|
|
|
+ From<ArrayType<'ctx>>,
|
|
|
|
|
>(
|
|
|
|
|
context: &'ctx Context,
|
|
|
|
|
type_: &ast::Type,
|
|
|
|
|
) -> Result<T, TranslateError> {
|
|
|
|
|
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
|
|
|
|
|
Ok(match type_ {
|
|
|
|
|
ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),
|
|
|
|
|
ast::Type::Vector(size, scalar) => {
|
|
|
|
|
let base_type = get_scalar_type::<BasicTypeEnum>(context, *scalar);
|
|
|
|
|
let base_type = match base_type {
|
|
|
|
|
BasicTypeEnum::FloatType(t) => t.as_type_ref(),
|
|
|
|
|
BasicTypeEnum::IntType(t) => t.as_type_ref(),
|
|
|
|
|
_ => return Err(error_unreachable()),
|
|
|
|
|
};
|
|
|
|
|
T::from(unsafe { VectorType::new(LLVMVectorType(base_type, *size as u32)) })
|
|
|
|
|
let base_type = get_scalar_type(context, *scalar);
|
|
|
|
|
unsafe { LLVMVectorType(base_type, *size as u32) }
|
|
|
|
|
}
|
|
|
|
|
ast::Type::Array(vec, scalar, dimensions) => {
|
|
|
|
|
let mut underlying_type = get_scalar_type::<BasicTypeEnum>(context, *scalar);
|
|
|
|
|
let mut underlying_type = get_scalar_type(context, *scalar);
|
|
|
|
|
if let Some(size) = vec {
|
|
|
|
|
underlying_type = BasicTypeEnum::VectorType(unsafe {
|
|
|
|
|
VectorType::new(LLVMVectorType(
|
|
|
|
|
match underlying_type {
|
|
|
|
|
BasicTypeEnum::FloatType(t) => t.as_type_ref(),
|
|
|
|
|
BasicTypeEnum::IntType(t) => t.as_type_ref(),
|
|
|
|
|
_ => return Err(error_unreachable()),
|
|
|
|
|
},
|
|
|
|
|
size.get() as u32,
|
|
|
|
|
))
|
|
|
|
|
});
|
|
|
|
|
underlying_type = unsafe { LLVMVectorType(underlying_type, size.get() as u32) };
|
|
|
|
|
}
|
|
|
|
|
if dimensions.is_empty() {
|
|
|
|
|
return Ok(T::from(underlying_type.array_type(0)));
|
|
|
|
|
return Ok(unsafe { LLVMArrayType2(underlying_type, 0) });
|
|
|
|
|
}
|
|
|
|
|
let llvm_type = dimensions
|
|
|
|
|
dimensions
|
|
|
|
|
.iter()
|
|
|
|
|
.rfold(underlying_type.as_type_ref(), |result, dimension| unsafe {
|
|
|
|
|
.rfold(underlying_type, |result, dimension| unsafe {
|
|
|
|
|
LLVMArrayType2(result, *dimension as u64)
|
|
|
|
|
});
|
|
|
|
|
T::from(unsafe { ArrayType::new(llvm_type) })
|
|
|
|
|
}
|
|
|
|
|
ast::Type::Pointer(_, space) => {
|
|
|
|
|
T::from(context.ptr_type(AddressSpace::from(get_state_space(*space)?)))
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
ast::Type::Pointer(_, space) => get_pointer_type(context, *space)?,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn get_scalar_type<
|
|
|
|
|
'ctx,
|
|
|
|
|
T: From<IntType<'ctx>> + From<FloatType<'ctx>> + From<VectorType<'ctx>>,
|
|
|
|
|
>(
|
|
|
|
|
context: &'ctx Context,
|
|
|
|
|
type_: ast::ScalarType,
|
|
|
|
|
) -> T {
|
|
|
|
|
fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeRef {
|
|
|
|
|
match type_ {
|
|
|
|
|
ast::ScalarType::Pred => T::from(context.bool_type()),
|
|
|
|
|
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
|
|
|
|
|
T::from(context.i8_type())
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
|
|
|
|
|
T::from(context.i16_type())
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
|
|
|
|
|
T::from(context.i32_type())
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
|
|
|
|
|
T::from(context.i64_type())
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::B128 => T::from(context.i128_type()),
|
|
|
|
|
ast::ScalarType::F16 => T::from(context.f16_type()),
|
|
|
|
|
ast::ScalarType::F32 => T::from(context.f32_type()),
|
|
|
|
|
ast::ScalarType::F64 => T::from(context.f64_type()),
|
|
|
|
|
ast::ScalarType::BF16 => {
|
|
|
|
|
T::from(unsafe { FloatType::new(LLVMBFloatTypeInContext(context.as_ctx_ref())) })
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::U16x2 | ast::ScalarType::S16x2 => {
|
|
|
|
|
T::from(unsafe { VectorType::new(LLVMVectorType(context.i16_type().as_type_ref(), 2)) })
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::F16x2 => {
|
|
|
|
|
T::from(unsafe { VectorType::new(LLVMVectorType(context.f16_type().as_type_ref(), 2)) })
|
|
|
|
|
}
|
|
|
|
|
ast::ScalarType::BF16x2 => T::from(unsafe {
|
|
|
|
|
VectorType::new(LLVMVectorType(
|
|
|
|
|
LLVMBFloatTypeInContext(context.as_ctx_ref()),
|
|
|
|
|
2,
|
|
|
|
|
))
|
|
|
|
|
}),
|
|
|
|
|
ast::ScalarType::Pred => unsafe { LLVMInt1TypeInContext(context) },
|
|
|
|
|
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe {
|
|
|
|
|
LLVMInt8TypeInContext(context)
|
|
|
|
|
},
|
|
|
|
|
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe {
|
|
|
|
|
LLVMInt16TypeInContext(context)
|
|
|
|
|
},
|
|
|
|
|
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe {
|
|
|
|
|
LLVMInt32TypeInContext(context)
|
|
|
|
|
},
|
|
|
|
|
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => unsafe {
|
|
|
|
|
LLVMInt64TypeInContext(context)
|
|
|
|
|
},
|
|
|
|
|
ast::ScalarType::B128 => unsafe { LLVMInt128TypeInContext(context) },
|
|
|
|
|
ast::ScalarType::F16 => unsafe { LLVMHalfTypeInContext(context) },
|
|
|
|
|
ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) },
|
|
|
|
|
ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) },
|
|
|
|
|
ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
|
|
|
|
|
ast::ScalarType::U16x2 => todo!(),
|
|
|
|
|
ast::ScalarType::S16x2 => todo!(),
|
|
|
|
|
ast::ScalarType::F16x2 => todo!(),
|
|
|
|
|
ast::ScalarType::BF16x2 => todo!(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn get_state_space(space: ast::StateSpace) -> Result<u16, TranslateError> {
|
|
|
|
|
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
|
|
|
|
|
match space {
|
|
|
|
|
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
|
|
|
|
|
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
|
|
|
|
@ -566,12 +641,12 @@ fn get_state_space(space: ast::StateSpace) -> Result<u16, TranslateError> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct ResolveIdent<'ctx> {
|
|
|
|
|
struct ResolveIdent {
|
|
|
|
|
words: HashMap<SpirvWord, String>,
|
|
|
|
|
values: HashMap<SpirvWord, AnyValueEnum<'ctx>>,
|
|
|
|
|
values: HashMap<SpirvWord, LLVMValueRef>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<'ctx> ResolveIdent<'ctx> {
|
|
|
|
|
impl ResolveIdent {
|
|
|
|
|
fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self {
|
|
|
|
|
ResolveIdent {
|
|
|
|
|
words: HashMap::new(),
|
|
|
|
@ -580,14 +655,15 @@ impl<'ctx> ResolveIdent<'ctx> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn get_or_ad_impl<'a, T>(&'a mut self, word: SpirvWord, fn_: impl FnOnce(&'a str) -> T) -> T {
|
|
|
|
|
match self.words.entry(word) {
|
|
|
|
|
hash_map::Entry::Occupied(entry) => fn_(entry.into_mut()),
|
|
|
|
|
let str = match self.words.entry(word) {
|
|
|
|
|
hash_map::Entry::Occupied(entry) => entry.into_mut(),
|
|
|
|
|
hash_map::Entry::Vacant(entry) => {
|
|
|
|
|
let mut text = word.0.to_string();
|
|
|
|
|
text.push('\0');
|
|
|
|
|
fn_(entry.insert(text))
|
|
|
|
|
}
|
|
|
|
|
entry.insert(text)
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
fn_(&str[..str.len() - 1])
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn get_or_add(&mut self, word: SpirvWord) -> &str {
|
|
|
|
@ -598,153 +674,19 @@ impl<'ctx> ResolveIdent<'ctx> {
|
|
|
|
|
self.get_or_add(word).as_ptr().cast()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn register(&mut self, word: SpirvWord, t: impl AnyValue<'ctx>) {
|
|
|
|
|
self.values.insert(word, t.as_any_value_enum());
|
|
|
|
|
fn register(&mut self, word: SpirvWord, v: LLVMValueRef) {
|
|
|
|
|
self.values.insert(word, v);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn value(&self, word: SpirvWord) -> Result<AnyValueEnum<'ctx>, TranslateError> {
|
|
|
|
|
fn value(&self, word: SpirvWord) -> Result<LLVMValueRef, TranslateError> {
|
|
|
|
|
self.values
|
|
|
|
|
.get(&word)
|
|
|
|
|
.copied()
|
|
|
|
|
.ok_or_else(|| error_unreachable())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn with_result<T: AnyValue<'ctx>>(
|
|
|
|
|
&mut self,
|
|
|
|
|
word: SpirvWord,
|
|
|
|
|
fn_: impl FnOnce(&str) -> Result<T, BuilderError>,
|
|
|
|
|
) -> Result<(), TranslateError> {
|
|
|
|
|
let t = self
|
|
|
|
|
.get_or_ad_impl(word, fn_)
|
|
|
|
|
.map_err(|_| error_unreachable())?;
|
|
|
|
|
fn with_result(&mut self, word: SpirvWord, fn_: impl FnOnce(*const i8) -> LLVMValueRef) {
|
|
|
|
|
let t = self.get_or_ad_impl(word, |dst| fn_(dst.as_ptr().cast()));
|
|
|
|
|
self.register(word, t);
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn build_int_math(
|
|
|
|
|
&mut self,
|
|
|
|
|
builder: &Builder<'ctx>,
|
|
|
|
|
dst: SpirvWord,
|
|
|
|
|
src1: SpirvWord,
|
|
|
|
|
src2: SpirvWord,
|
|
|
|
|
fn_: impl IntMathOp<'ctx>,
|
|
|
|
|
) -> Result<(), TranslateError> {
|
|
|
|
|
let src1 = self.value(src1)?;
|
|
|
|
|
let src2 = self.value(src2)?;
|
|
|
|
|
self.with_result(dst, |dst| {
|
|
|
|
|
Ok(match (src1, src2) {
|
|
|
|
|
(AnyValueEnum::IntValue(src1), AnyValueEnum::IntValue(src2)) => {
|
|
|
|
|
AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?)
|
|
|
|
|
}
|
|
|
|
|
(AnyValueEnum::PointerValue(src1), AnyValueEnum::PointerValue(src2)) => {
|
|
|
|
|
AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?)
|
|
|
|
|
}
|
|
|
|
|
(AnyValueEnum::VectorValue(src1), AnyValueEnum::VectorValue(src2)) => {
|
|
|
|
|
AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?)
|
|
|
|
|
}
|
|
|
|
|
_ => return todo!(),
|
|
|
|
|
})
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
trait IntMathOp<'ctx> {
|
|
|
|
|
fn call<T: IntMathValue<'ctx>>(
|
|
|
|
|
self,
|
|
|
|
|
builder: &Builder<'ctx>,
|
|
|
|
|
src1: T,
|
|
|
|
|
src2: T,
|
|
|
|
|
dst: &str,
|
|
|
|
|
) -> Result<T, BuilderError>;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
trait AnyValueEnumExt<'ctx> {
|
|
|
|
|
fn as_array(self) -> Result<ArrayValue<'ctx>, TranslateError>;
|
|
|
|
|
fn as_int(self) -> Result<IntValue<'ctx>, TranslateError>;
|
|
|
|
|
fn as_float(self) -> Result<FloatValue<'ctx>, TranslateError>;
|
|
|
|
|
fn as_phi(self) -> Result<PhiValue<'ctx>, TranslateError>;
|
|
|
|
|
fn as_function(self) -> Result<FunctionValue<'ctx>, TranslateError>;
|
|
|
|
|
fn as_pointer(self) -> Result<PointerValue<'ctx>, TranslateError>;
|
|
|
|
|
fn as_struct(self) -> Result<StructValue<'ctx>, TranslateError>;
|
|
|
|
|
fn as_vector(self) -> Result<VectorValue<'ctx>, TranslateError>;
|
|
|
|
|
fn as_instruction(self) -> Result<InstructionValue<'ctx>, TranslateError>;
|
|
|
|
|
fn as_basic(self) -> Result<BasicValueEnum<'ctx>, TranslateError>;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<'ctx> AnyValueEnumExt<'ctx> for AnyValueEnum<'ctx> {
|
|
|
|
|
fn as_array(self) -> Result<ArrayValue<'ctx>, TranslateError> {
|
|
|
|
|
if let AnyValueEnum::ArrayValue(x) = self {
|
|
|
|
|
Ok(x)
|
|
|
|
|
} else {
|
|
|
|
|
Err(error_unreachable())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn as_int(self) -> Result<IntValue<'ctx>, TranslateError> {
|
|
|
|
|
if let AnyValueEnum::IntValue(x) = self {
|
|
|
|
|
Ok(x)
|
|
|
|
|
} else {
|
|
|
|
|
Err(error_unreachable())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn as_float(self) -> Result<FloatValue<'ctx>, TranslateError> {
|
|
|
|
|
if let AnyValueEnum::FloatValue(x) = self {
|
|
|
|
|
Ok(x)
|
|
|
|
|
} else {
|
|
|
|
|
Err(error_unreachable())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn as_phi(self) -> Result<PhiValue<'ctx>, TranslateError> {
|
|
|
|
|
if let AnyValueEnum::PhiValue(x) = self {
|
|
|
|
|
Ok(x)
|
|
|
|
|
} else {
|
|
|
|
|
Err(error_unreachable())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn as_function(self) -> Result<FunctionValue<'ctx>, TranslateError> {
|
|
|
|
|
if let AnyValueEnum::FunctionValue(x) = self {
|
|
|
|
|
Ok(x)
|
|
|
|
|
} else {
|
|
|
|
|
Err(error_unreachable())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn as_pointer(self) -> Result<PointerValue<'ctx>, TranslateError> {
|
|
|
|
|
if let AnyValueEnum::PointerValue(x) = self {
|
|
|
|
|
Ok(x)
|
|
|
|
|
} else {
|
|
|
|
|
Err(error_unreachable())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn as_struct(self) -> Result<StructValue<'ctx>, TranslateError> {
|
|
|
|
|
if let AnyValueEnum::StructValue(x) = self {
|
|
|
|
|
Ok(x)
|
|
|
|
|
} else {
|
|
|
|
|
Err(error_unreachable())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn as_vector(self) -> Result<VectorValue<'ctx>, TranslateError> {
|
|
|
|
|
if let AnyValueEnum::VectorValue(x) = self {
|
|
|
|
|
Ok(x)
|
|
|
|
|
} else {
|
|
|
|
|
Err(error_unreachable())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn as_instruction(self) -> Result<InstructionValue<'ctx>, TranslateError> {
|
|
|
|
|
if let AnyValueEnum::InstructionValue(x) = self {
|
|
|
|
|
Ok(x)
|
|
|
|
|
} else {
|
|
|
|
|
Err(error_unreachable())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn as_basic(self) -> Result<BasicValueEnum<'ctx>, TranslateError> {
|
|
|
|
|
BasicValueEnum::try_from(self).map_err(|_| error_unreachable())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|