Add ptx_impl bitcode module

This commit is contained in:
Andrzej Janik
2024-09-25 02:46:08 +02:00
parent c92abba2bb
commit 81baecf2c8
8 changed files with 152 additions and 40 deletions

View File

@ -79,6 +79,10 @@ impl ActionInfo {
unsafe { amd_comgr_action_info_set_isa_name(self.get(), full_isa.as_ptr().cast()) } unsafe { amd_comgr_action_info_set_isa_name(self.get(), full_isa.as_ptr().cast()) }
} }
fn set_language(&self, language: amd_comgr_language_t) -> Result<(), amd_comgr_status_s> {
unsafe { amd_comgr_action_info_set_language(self.get(), language) }
}
fn get(&self) -> amd_comgr_action_info_t { fn get(&self) -> amd_comgr_action_info_t {
self.0 self.0
} }
@ -90,36 +94,56 @@ impl Drop for ActionInfo {
} }
} }
pub fn compile_bitcode(gcn_arch: &CStr, buffer: &[u8]) -> Result<Vec<u8>, amd_comgr_status_s> { pub fn compile_bitcode(
gcn_arch: &CStr,
main_buffer: &[u8],
ptx_impl: &[u8],
) -> Result<Vec<u8>, amd_comgr_status_s> {
use amd_comgr_sys::*; use amd_comgr_sys::*;
let bitcode_data_set = DataSet::new()?; let bitcode_data_set = DataSet::new()?;
let bitcode_data = Data::new( let main_bitcode_data = Data::new(
amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC, amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC,
c"zluda.bc", c"zluda.bc",
buffer, main_buffer,
)?;
bitcode_data_set.add(&main_bitcode_data)?;
let stdlib_bitcode_data = Data::new(
amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC,
c"ptx_impl.bc",
ptx_impl,
)?;
bitcode_data_set.add(&stdlib_bitcode_data)?;
let lang_action_info = ActionInfo::new()?;
lang_action_info.set_isa_name(gcn_arch)?;
lang_action_info.set_language(amd_comgr_language_t::AMD_COMGR_LANGUAGE_LLVM_IR)?;
let linked_data_set = do_action(
&bitcode_data_set,
&lang_action_info,
amd_comgr_action_kind_t::AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC,
)?; )?;
bitcode_data_set.add(&bitcode_data)?;
let reloc_data_set = DataSet::new()?;
let action_info = ActionInfo::new()?; let action_info = ActionInfo::new()?;
action_info.set_isa_name(gcn_arch)?; action_info.set_isa_name(gcn_arch)?;
unsafe { let reloc_data_set = do_action(
amd_comgr_do_action( &linked_data_set,
amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, &action_info,
action_info.get(), amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE,
bitcode_data_set.get(), )?;
reloc_data_set.get(), let exec_data_set = do_action(
) &reloc_data_set,
}?; &action_info,
let exec_data_set = DataSet::new()?; amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE,
unsafe { )?;
amd_comgr_do_action(
amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE,
action_info.get(),
reloc_data_set.get(),
exec_data_set.get(),
)
}?;
let executable = let executable =
exec_data_set.get_data(amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_EXECUTABLE, 0)?; exec_data_set.get_data(amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_EXECUTABLE, 0)?;
executable.copy_content() executable.copy_content()
} }
fn do_action(
data_set: &DataSet,
action: &ActionInfo,
kind: amd_comgr_action_kind_t,
) -> Result<DataSet, amd_comgr_status_s> {
let result = DataSet::new()?;
unsafe { amd_comgr_do_action(kind, action.get(), data_set.get(), result.get()) }?;
Ok(result)
}

Binary file not shown.

View File

@ -0,0 +1,18 @@
// Every time this file changes it must te rebuilt, you need llvm-17:
// /opt/rocm/llvm/bin/clang -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && llvm-dis-17 zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | llvm-as-17 - -o zluda_ptx_impl.bc && llvm-dis-17 zluda_ptx_impl.bc
#include <cstddef>
#include <cstdint>
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_ ## NAME
extern "C" {
uint32_t FUNC(activemask)() {
return __builtin_amdgcn_read_exec_lo();
}
size_t __ockl_get_local_size(uint32_t) __device__;
uint32_t FUNC(sreg_ntid)(uint8_t member) {
return (uint32_t)__ockl_get_local_size(member);
}
}

View File

@ -94,7 +94,7 @@ fn run_method<'input>(
.body .body
.map(|statements| { .map(|statements| {
for statement in statements { for statement in statements {
run_statement(&remap_returns, &mut body, statement)?; run_statement(resolver, &remap_returns, &mut body, statement)?;
} }
Ok::<_, TranslateError>(body) Ok::<_, TranslateError>(body)
}) })
@ -110,6 +110,7 @@ fn run_method<'input>(
} }
fn run_statement<'input>( fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>, remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>, statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
@ -133,6 +134,66 @@ fn run_statement<'input>(
} }
result.push(statement); result.push(statement);
} }
Statement::Instruction(ast::Instruction::Call {
mut data,
mut arguments,
}) => {
let mut post_st = Vec::new();
for ((type_, space), ident) in data
.input_arguments
.iter_mut()
.zip(arguments.input_arguments.iter_mut())
{
if *space == ptx_parser::StateSpace::Param {
*space = ptx_parser::StateSpace::Reg;
let old_name = *ident;
*ident = resolver
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
result.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::LdCacheOperator::Cached,
typ: type_.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
dst: *ident,
src: old_name,
},
}));
}
}
for ((type_, space), ident) in data
.return_arguments
.iter_mut()
.zip(arguments.return_arguments.iter_mut())
{
if *space == ptx_parser::StateSpace::Param {
*space = ptx_parser::StateSpace::Reg;
let old_name = *ident;
*ident = resolver
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
post_st.push(Statement::Instruction(ast::Instruction::St {
data: ast::StData {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::StCacheOperator::Writethrough,
typ: type_.clone(),
},
arguments: ast::StArgs {
src1: old_name,
src2: *ident,
},
}));
}
}
result.push(Statement::Instruction(ast::Instruction::Call {
data,
arguments,
}));
result.extend(post_st.into_iter());
}
statement => { statement => {
result.push(statement); result.push(statement);
} }

View File

@ -31,10 +31,10 @@ pub(super) fn run<'a, 'input>(
sreg_to_function, sreg_to_function,
result: Vec::new(), result: Vec::new(),
}; };
directives for directive in directives.into_iter() {
.into_iter() result.push(run_directive(&mut visitor, directive)?);
.map(|directive| run_directive(&mut visitor, directive)) }
.collect::<Result<Vec<_>, _>>() Ok(result)
} }
fn run_directive<'a, 'input>( fn run_directive<'a, 'input>(

View File

@ -5,7 +5,7 @@ pub(super) fn run<'input>(
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut result = Vec::with_capacity(directives.len()); let mut result = Vec::with_capacity(directives.len());
for mut directive in directives.into_iter() { for mut directive in directives.into_iter() {
run_directive(&mut result, &mut directive); run_directive(&mut result, &mut directive)?;
result.push(directive); result.push(directive);
} }
Ok(result) Ok(result)

View File

@ -39,9 +39,8 @@ mod normalize_predicates;
mod normalize_predicates2; mod normalize_predicates2;
mod resolve_function_pointers; mod resolve_function_pointers;
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_";
const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__";
pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> { pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1)); let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1));
@ -220,6 +219,12 @@ pub struct Module {
pub kernel_info: HashMap<String, KernelInfo>, pub kernel_info: HashMap<String, KernelInfo>,
} }
impl Module {
pub fn linked_bitcode(&self) -> &[u8] {
ZLUDA_PTX_IMPL
}
}
struct GlobalStringIdResolver<'input> { struct GlobalStringIdResolver<'input> {
current_id: SpirvWord, current_id: SpirvWord,
variables: HashMap<Cow<'input, str>, SpirvWord>, variables: HashMap<Cow<'input, str>, SpirvWord>,
@ -1975,7 +1980,7 @@ impl SpecialRegistersMap2 {
let name = let name =
ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None)); ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
let return_type = sreg.get_function_return_type(); let return_type = sreg.get_function_return_type();
let input_type = sreg.get_function_return_type(); let input_type = sreg.get_function_input_type();
( (
sreg, sreg,
ast::MethodDeclaration { ast::MethodDeclaration {
@ -1988,14 +1993,17 @@ impl SpecialRegistersMap2 {
array_init: Vec::new(), array_init: Vec::new(),
}], }],
name: name, name: name,
input_arguments: vec![ast::Variable { input_arguments: input_type
align: None, .into_iter()
v_type: input_type.into(), .map(|type_| ast::Variable {
state_space: ast::StateSpace::Reg, align: None,
name: resolver v_type: type_.into(),
.register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))), state_space: ast::StateSpace::Reg,
array_init: Vec::new(), name: resolver
}], .register_unnamed(Some((type_.into(), ast::StateSpace::Reg))),
array_init: Vec::new(),
})
.collect::<Vec<_>>(),
shared_mem: None, shared_mem: None,
}, },
) )

View File

@ -326,6 +326,7 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def
let elf_module = comgr::compile_bitcode( let elf_module = comgr::compile_bitcode(
unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }, unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) },
&*module.llvm_ir, &*module.llvm_ir,
module.linked_bitcode(),
) )
.unwrap(); .unwrap();
let mut module = ptr::null_mut(); let mut module = ptr::null_mut();