Implement nanosleep.u32 (#421)

This commit is contained in:
Violet
2025-07-21 17:42:04 -07:00
committed by GitHub
parent 72e2fe5b9a
commit 27cfd50ddd
19 changed files with 330 additions and 187 deletions

View File

@ -178,11 +178,14 @@ pub fn compile_bitcode(
comgr: &Comgr,
gcn_arch: &CStr,
main_buffer: &[u8],
attributes_buffer: &[u8],
ptx_impl: &[u8],
) -> Result<Vec<u8>, Error> {
let bitcode_data_set = DataSet::new(comgr)?;
let main_bitcode_data = Data::new(comgr, DataKind::Bc, c"zluda.bc", main_buffer)?;
bitcode_data_set.add(&main_bitcode_data)?;
let attributes_bitcode_data = Data::new(comgr, DataKind::Bc, c"attributes.bc", attributes_buffer)?;
bitcode_data_set.add(&attributes_bitcode_data)?;
let stdlib_bitcode_data = Data::new(comgr, DataKind::Bc, c"ptx_impl.bc", ptx_impl)?;
bitcode_data_set.add(&stdlib_bitcode_data)?;
let linking_info = ActionInfo::new(comgr)?;

Binary file not shown.

View File

@ -8,6 +8,8 @@
#include <hip/amd_detail/amd_device_functions.h>
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
#define DECLARE_ATTR(TYPE, NAME) extern const TYPE ATTR(NAME) __device__
extern "C"
{
@ -220,6 +222,29 @@ extern "C"
SHFL_SYNC_IMPL(bfly, self ^ delta, >);
SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >);
DECLARE_ATTR(uint32_t, CLOCK_RATE);
// Implementation of PTX `nanosleep.u32`: stalls the calling wave for roughly
// `nanoseconds` nanoseconds by issuing a sequence of `s_sleep` instructions.
// The duration is approximate: it is rounded up to whole 64-cycle units.
void FUNC(nanosleep_u32)(uint32_t nanoseconds) {
    // clock_rate is in kHz, so cycles = nanoseconds * clock_rate / 10^6.
    // Multiply before dividing: dividing the clock rate first truncates the
    // per-nanosecond rate to a whole number (2.124 GHz -> 2) and to zero for
    // any clock below 1 GHz, in which case the function would never sleep.
    // The product of two 32-bit values always fits in 64 bits.
    uint64_t cycles = (uint64_t)nanoseconds * (uint64_t)ATTR(CLOCK_RATE) / 1000000;
    // Avoid small sleep values resulting in s_sleep 0
    cycles += 63;
    // s_sleep N sleeps for 64 * N cycles
    uint64_t sleep_amount = cycles / 64;
    // The argument to s_sleep must be a constant, so the amount is decomposed
    // into a run of s_sleep 16 followed by one s_sleep per set bit of the
    // remaining low four bits.
    for (uint64_t i = 0; i < (sleep_amount >> 4); i++)
        __builtin_amdgcn_s_sleep(16);
    if (sleep_amount & 8U)
        __builtin_amdgcn_s_sleep(8);
    if (sleep_amount & 4U)
        __builtin_amdgcn_s_sleep(4);
    if (sleep_amount & 2U)
        __builtin_amdgcn_s_sleep(2);
    if (sleep_amount & 1U)
        __builtin_amdgcn_s_sleep(1);
}
void FUNC(__assertfail)(uint64_t message,
uint64_t file,
uint32_t line,

View File

@ -3,4 +3,5 @@ pub(crate) mod pass;
mod test;
pub use pass::to_llvm_module;
pub use pass::Attributes;

View File

@ -152,6 +152,7 @@ fn run_instruction<'input>(
..
}
| ast::Instruction::Mul24 { .. }
| ast::Instruction::Nanosleep { .. }
| ast::Instruction::Neg { .. }
| ast::Instruction::Not { .. }
| ast::Instruction::Or { .. }

View File

@ -1809,6 +1809,7 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
| ast::Instruction::Cvta { .. }
| ast::Instruction::Atom { .. }
| ast::Instruction::Mul24 { .. }
| ast::Instruction::Nanosleep { .. }
| ast::Instruction::AtomCas { .. } => InstructionModes::none(),
ast::Instruction::Add {
data: ast::ArithDetails::Integer(_),

View File

@ -0,0 +1,34 @@
use std::ffi::CStr;
use super::*;
use super::super::*;
use llvm_zluda::{core::*};
/// Builds a separate, unnamed LLVM module in `context` that holds the GPU
/// attributes as constant globals (currently only `clock_rate`).
///
/// # Errors
/// Propagates any [`TranslateError`] raised while emitting an attribute.
///
/// # Panics
/// Panics with the verifier message if the finished module fails LLVM
/// verification.
pub(crate) fn run(
    context: &Context,
    attributes: Attributes,
) -> Result<llvm::Module, TranslateError> {
    let attributes_module = llvm::Module::new(context, LLVM_UNNAMED);
    emit_attribute(
        context,
        &attributes_module,
        "clock_rate",
        attributes.clock_rate,
    )?;
    match attributes_module.verify() {
        Ok(()) => Ok(attributes_module),
        Err(message) => panic!("{:?}", message),
    }
}
/// Adds one constant `u32` global to `module` holding the value `attribute`.
///
/// The symbol name is `ZLUDA_PTX_PREFIX` + `attribute_` + `name`, with the
/// whole string (prefix included) converted to ASCII upper case. The global is
/// placed in the address space that `get_state_space` returns for
/// `StateSpace::Global`.
///
/// # Errors
/// Returns the error from `get_state_space` if the address space lookup fails.
///
/// # Panics
/// Panics if the resulting symbol name contains an interior NUL byte. Names
/// are expected to be compile-time identifiers, so this indicates a bug in the
/// caller rather than a recoverable condition.
fn emit_attribute(context: &Context, module: &llvm::Module, name: &str, attribute: u32) -> Result<(), TranslateError> {
    let name = format!("{}attribute_{}\0", ZLUDA_PTX_PREFIX, name).to_ascii_uppercase();
    // Checked conversion: the unchecked variant would be undefined behaviour
    // for a `name` that happens to contain a NUL byte.
    let name = CStr::from_bytes_with_nul(name.as_bytes())
        .expect("attribute symbol names must not contain interior NUL bytes");
    let attribute_type = get_scalar_type(context.get(), ast::ScalarType::U32);
    let address_space = get_state_space(ast::StateSpace::Global)?;
    // SAFETY: `module` and `context` are live wrappers borrowed for the whole
    // call and `name` is a valid NUL-terminated string that outlives the call.
    // NOTE(review): assumes `module` was created in `context` — confirm at call sites.
    let global = unsafe {
        LLVMAddGlobalInAddressSpace(
            module.get(),
            attribute_type,
            name.as_ptr(),
            address_space,
        )
    };
    // SAFETY: `global` was just created with `attribute_type`, which is also
    // the type of the constant used as its initializer.
    unsafe {
        // Final argument 0: the value is not sign-extended.
        LLVMSetInitializer(global, LLVMConstInt(attribute_type, u64::from(attribute), 0));
        LLVMSetGlobalConstant(global, 1);
    }
    Ok(())
}

View File

@ -27,98 +27,15 @@
use std::array::TryFromSliceError;
use std::convert::TryInto;
use std::ffi::{CStr, NulError};
use std::ops::Deref;
use std::{i8, ptr};
use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
use crate::pass::*;
use llvm_zluda::{core::*, *};
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
use ptx_parser::Mul24Control;
const LLVM_UNNAMED: &CStr = c"";
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
const GENERIC_ADDRESS_SPACE: u32 = 0;
const GLOBAL_ADDRESS_SPACE: u32 = 1;
const SHARED_ADDRESS_SPACE: u32 = 3;
const CONSTANT_ADDRESS_SPACE: u32 = 4;
const PRIVATE_ADDRESS_SPACE: u32 = 5;
struct Context(LLVMContextRef);
impl Context {
fn new() -> Self {
Self(unsafe { LLVMContextCreate() })
}
fn get(&self) -> LLVMContextRef {
self.0
}
}
impl Drop for Context {
fn drop(&mut self) {
unsafe {
LLVMContextDispose(self.0);
}
}
}
pub struct Module(LLVMModuleRef, Context);
impl Module {
fn new(ctx: Context, name: &CStr) -> Self {
Self(
unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) },
ctx,
)
}
fn get(&self) -> LLVMModuleRef {
self.0
}
fn context(&self) -> &Context {
&self.1
}
fn verify(&self) -> Result<(), Message> {
let mut err = ptr::null_mut();
let error = unsafe {
LLVMVerifyModule(
self.get(),
LLVMVerifierFailureAction::LLVMReturnStatusAction,
&mut err,
)
};
if error == 1 && err != ptr::null_mut() {
Err(Message(unsafe { CStr::from_ptr(err) }))
} else {
Ok(())
}
}
pub fn write_bitcode_to_memory(&self) -> MemoryBuffer {
let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) };
MemoryBuffer(memory_buffer)
}
pub fn print_module_to_string(&self) -> Message {
let asm = unsafe { LLVMPrintModuleToString(self.get()) };
Message(unsafe { CStr::from_ptr(asm) })
}
}
impl Drop for Module {
fn drop(&mut self) {
unsafe {
LLVMDisposeModule(self.0);
}
}
}
struct Builder(LLVMBuilderRef);
impl Builder {
@ -143,55 +60,13 @@ impl Drop for Builder {
}
}
pub struct Message(&'static CStr);
impl Drop for Message {
fn drop(&mut self) {
unsafe {
LLVMDisposeMessage(self.0.as_ptr().cast_mut());
}
}
}
impl std::fmt::Debug for Message {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Debug::fmt(&self.0, f)
}
}
impl Message {
pub fn to_str(&self) -> &str {
self.0.to_str().unwrap().trim()
}
}
pub struct MemoryBuffer(LLVMMemoryBufferRef);
impl Drop for MemoryBuffer {
fn drop(&mut self) {
unsafe {
LLVMDisposeMemoryBuffer(self.0);
}
}
}
impl Deref for MemoryBuffer {
type Target = [u8];
fn deref(&self) -> &Self::Target {
let data = unsafe { LLVMGetBufferStart(self.0) };
let len = unsafe { LLVMGetBufferSize(self.0) };
unsafe { std::slice::from_raw_parts(data.cast(), len) }
}
}
pub(super) fn run<'input>(
pub(crate) fn run<'input>(
context: &Context,
id_defs: GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Module, TranslateError> {
let context = Context::new();
let module = Module::new(context, LLVM_UNNAMED);
let mut emit_ctx = ModuleEmitContext::new(&module, &id_defs);
) -> Result<llvm::Module, TranslateError> {
let module = llvm::Module::new(context, LLVM_UNNAMED);
let mut emit_ctx = ModuleEmitContext::new(context, &module, &id_defs);
for directive in directives {
match directive {
Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?,
@ -213,8 +88,7 @@ struct ModuleEmitContext<'a, 'input> {
}
impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn new(module: &Module, id_defs: &'a GlobalStringIdentResolver2<'input>) -> Self {
let context = module.context();
fn new(context: &Context, module: &llvm::Module, id_defs: &'a GlobalStringIdentResolver2<'input>) -> Self {
ModuleEmitContext {
context: context.get(),
module: module.get(),
@ -642,7 +516,8 @@ impl<'a> MethodEmitContext<'a> {
| ast::Instruction::BarRed { .. }
| ast::Instruction::Bfi { .. }
| ast::Instruction::Activemask { .. }
| ast::Instruction::ShflSync { .. } => return Err(error_unreachable()),
| ast::Instruction::ShflSync { .. }
| ast::Instruction::Nanosleep { .. } => return Err(error_unreachable()),
}
}
@ -2729,33 +2604,6 @@ fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, T
})
}
fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeRef {
match type_ {
ast::ScalarType::Pred => unsafe { LLVMInt1TypeInContext(context) },
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe {
LLVMInt8TypeInContext(context)
},
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe {
LLVMInt16TypeInContext(context)
},
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe {
LLVMInt32TypeInContext(context)
},
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => unsafe {
LLVMInt64TypeInContext(context)
},
ast::ScalarType::B128 => unsafe { LLVMInt128TypeInContext(context) },
ast::ScalarType::F16 => unsafe { LLVMHalfTypeInContext(context) },
ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) },
ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) },
ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
ast::ScalarType::U16x2 => todo!(),
ast::ScalarType::S16x2 => todo!(),
ast::ScalarType::F16x2 => todo!(),
ast::ScalarType::BF16x2 => todo!(),
}
}
fn get_array_type<'a>(
context: LLVMContextRef,
elem_type: &'a ast::Type,
@ -2808,22 +2656,6 @@ fn get_function_type<'a>(
})
}
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
match space {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
ast::StateSpace::Param => Err(TranslateError::Todo("".to_string())),
ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::ParamFunc => Err(TranslateError::Todo("".to_string())),
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE),
ast::StateSpace::SharedCta => Err(TranslateError::Todo("".to_string())),
ast::StateSpace::SharedCluster => Err(TranslateError::Todo("".to_string())),
}
}
struct ResolveIdent {
words: HashMap<SpirvWord, String>,
values: HashMap<SpirvWord, LLVMValueRef>,

173
ptx/src/pass/llvm/mod.rs Normal file
View File

@ -0,0 +1,173 @@
pub(super) mod emit;
pub(super) mod attributes;
use std::ffi::CStr;
use std::ops::Deref;
use std::ptr;
use crate::pass::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
use llvm_zluda::core::*;
use llvm_zluda::prelude::*;
const LLVM_UNNAMED: &CStr = c"";
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
const GENERIC_ADDRESS_SPACE: u32 = 0;
const GLOBAL_ADDRESS_SPACE: u32 = 1;
const SHARED_ADDRESS_SPACE: u32 = 3;
const CONSTANT_ADDRESS_SPACE: u32 = 4;
const PRIVATE_ADDRESS_SPACE: u32 = 5;
/// Owning wrapper around a raw `LLVMContextRef`.
///
/// The context is created by [`Context::new`] and disposed when the wrapper is
/// dropped.
pub(super) struct Context(LLVMContextRef);

impl Context {
    /// Creates a new LLVM context owned by the returned wrapper.
    pub fn new() -> Self {
        let raw = unsafe { LLVMContextCreate() };
        Context(raw)
    }

    /// Returns the raw handle; ownership stays with `self`.
    fn get(&self) -> LLVMContextRef {
        let Context(raw) = self;
        *raw
    }
}

impl Drop for Context {
    /// Disposes the wrapped LLVM context.
    fn drop(&mut self) {
        let raw = self.0;
        unsafe { LLVMContextDispose(raw) };
    }
}
pub struct Module(LLVMModuleRef);
impl Module {
fn new(ctx: &Context, name: &CStr) -> Self {
Self(
unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) },
)
}
fn get(&self) -> LLVMModuleRef {
self.0
}
fn verify(&self) -> Result<(), Message> {
let mut err = ptr::null_mut();
let error = unsafe {
LLVMVerifyModule(
self.get(),
LLVMVerifierFailureAction::LLVMReturnStatusAction,
&mut err,
)
};
if error == 1 && err != ptr::null_mut() {
Err(Message(unsafe { CStr::from_ptr(err) }))
} else {
Ok(())
}
}
pub fn write_bitcode_to_memory(&self) -> MemoryBuffer {
let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) };
MemoryBuffer(memory_buffer)
}
pub fn print_module_to_string(&self) -> Message {
let asm = unsafe { LLVMPrintModuleToString(self.get()) };
Message(unsafe { CStr::from_ptr(asm) })
}
}
impl Drop for Module {
fn drop(&mut self) {
unsafe {
LLVMDisposeModule(self.0);
}
}
}
/// Owned, LLVM-allocated C string. The underlying allocation is released with
/// `LLVMDisposeMessage` when the value is dropped.
pub struct Message(&'static CStr);

impl Drop for Message {
    /// Returns the string's storage to LLVM.
    fn drop(&mut self) {
        let raw = self.0.as_ptr().cast_mut();
        unsafe { LLVMDisposeMessage(raw) };
    }
}

impl std::fmt::Debug for Message {
    /// Formats exactly like the wrapped `CStr`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <CStr as std::fmt::Debug>::fmt(self.0, f)
    }
}

impl Message {
    /// Returns the message as UTF-8 text with surrounding whitespace trimmed.
    ///
    /// # Panics
    /// Panics if the message is not valid UTF-8.
    pub fn to_str(&self) -> &str {
        let text = self.0.to_str().unwrap();
        text.trim()
    }
}
/// Owning wrapper around a raw `LLVMMemoryBufferRef`; the buffer is disposed
/// when the wrapper is dropped. Dereferences to the buffer's bytes.
pub struct MemoryBuffer(LLVMMemoryBufferRef);

impl Drop for MemoryBuffer {
    /// Disposes the wrapped LLVM memory buffer.
    fn drop(&mut self) {
        let raw = self.0;
        unsafe { LLVMDisposeMemoryBuffer(raw) };
    }
}

impl Deref for MemoryBuffer {
    type Target = [u8];

    /// Exposes the buffer contents as a byte slice borrowed from `self`.
    fn deref(&self) -> &[u8] {
        unsafe {
            let start = LLVMGetBufferStart(self.0);
            let size = LLVMGetBufferSize(self.0);
            std::slice::from_raw_parts(start.cast::<u8>(), size)
        }
    }
}
/// Maps a PTX scalar type to the LLVM type of matching width in `context`.
///
/// Signed, unsigned and untyped-bit variants of the same width all map to the
/// same LLVM integer type; `Pred` maps to a 1-bit integer.
///
/// # Panics
/// The packed types (`U16x2`, `S16x2`, `F16x2`, `BF16x2`) are not implemented
/// and hit `todo!()`.
fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeRef {
    unsafe {
        match type_ {
            ast::ScalarType::Pred => LLVMInt1TypeInContext(context),
            ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => {
                LLVMInt8TypeInContext(context)
            }
            ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
                LLVMInt16TypeInContext(context)
            }
            ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => {
                LLVMInt32TypeInContext(context)
            }
            ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => {
                LLVMInt64TypeInContext(context)
            }
            ast::ScalarType::B128 => LLVMInt128TypeInContext(context),
            ast::ScalarType::F16 => LLVMHalfTypeInContext(context),
            ast::ScalarType::BF16 => LLVMBFloatTypeInContext(context),
            ast::ScalarType::F32 => LLVMFloatTypeInContext(context),
            ast::ScalarType::F64 => LLVMDoubleTypeInContext(context),
            ast::ScalarType::U16x2
            | ast::ScalarType::S16x2
            | ast::ScalarType::F16x2
            | ast::ScalarType::BF16x2 => todo!(),
        }
    }
}
/// Translates a PTX state space into the numeric address space constant
/// defined at the top of this module.
///
/// # Errors
/// `Param`, `ParamFunc`, `SharedCta` and `SharedCluster` have no mapping and
/// yield `TranslateError::Todo` with an empty message.
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
    let address_space = match space {
        ast::StateSpace::Generic => GENERIC_ADDRESS_SPACE,
        ast::StateSpace::Global => GLOBAL_ADDRESS_SPACE,
        ast::StateSpace::Shared => SHARED_ADDRESS_SPACE,
        ast::StateSpace::Const | ast::StateSpace::ParamEntry => CONSTANT_ADDRESS_SPACE,
        ast::StateSpace::Reg | ast::StateSpace::Local => PRIVATE_ADDRESS_SPACE,
        ast::StateSpace::Param
        | ast::StateSpace::ParamFunc
        | ast::StateSpace::SharedCta
        | ast::StateSpace::SharedCluster => {
            return Err(TranslateError::Todo("".to_string()));
        }
    };
    Ok(address_space)
}

View File

@ -12,7 +12,6 @@ use strum::IntoEnumIterator;
use strum_macros::EnumIter;
mod deparamize_functions;
pub(crate) mod emit_llvm;
mod expand_operands;
mod fix_special_registers2;
mod hoist_globals;
@ -20,6 +19,7 @@ mod insert_explicit_load_store;
mod insert_implicit_conversions2;
mod insert_post_saturation;
mod instruction_mode_to_global_mode;
mod llvm;
mod normalize_basic_blocks;
mod normalize_identifiers2;
mod normalize_predicates2;
@ -46,7 +46,13 @@ quick_error! {
}
}
pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
/// GPU attributes needed at compile time.
///
/// The values are emitted as constant globals into a separate LLVM module
/// (`Module::attributes_ir`) that is linked together with the translated PTX.
pub struct Attributes {
/// Clock frequency in kHz.
///
/// Read by the `nanosleep.u32` implementation to convert a duration in
/// nanoseconds into a number of GPU clock cycles.
pub clock_rate: u32,
}
pub fn to_llvm_module<'input>(ast: ast::Module<'input>, attributes: Attributes) -> Result<Module, TranslateError> {
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
@ -65,16 +71,23 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
let directives = hoist_globals::run(directives)?;
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
let context = llvm::Context::new();
let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?;
let attributes_ir = llvm::attributes::run(&context, attributes)?;
Ok(Module {
llvm_ir,
attributes_ir,
kernel_info: HashMap::new(),
_context: context,
})
}
pub struct Module {
pub llvm_ir: emit_llvm::Module,
pub llvm_ir: llvm::Module,
pub attributes_ir: llvm::Module,
pub kernel_info: HashMap<String, KernelInfo>,
_context: llvm::Context,
}
impl Module {

View File

@ -137,6 +137,9 @@ fn run_instruction<'input>(
ptx_parser::Instruction::ShflSync { data, arguments },
)?
}
i @ ptx_parser::Instruction::Nanosleep { .. } => {
to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)?
}
i => i,
})
}

View File

@ -0,0 +1 @@
@__ZLUDA_PTX_IMPL_ATTRIBUTE_CLOCK_RATE = addrspace(1) constant i32 2124000

View File

@ -0,0 +1,15 @@
declare void @__zluda_ptx_impl_nanosleep_u32(i32) #0
define amdgpu_kernel void @nanosleep(ptr addrspace(4) byref(i64) %"28", ptr addrspace(4) byref(i64) %"29") #1 {
br label %1
1: ; preds = %0
br label %"27"
"27": ; preds = %1
call void @__zluda_ptx_impl_nanosleep_u32(i32 1)
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -1,4 +1,4 @@
use crate::pass::TranslateError;
use crate::pass::{self, TranslateError};
use ptx_parser as ast;
mod spirv_run;
@ -9,7 +9,8 @@ fn parse_and_assert(ptx_text: &str) {
fn compile_and_assert(ptx_text: &str) -> Result<(), TranslateError> {
let ast = ast::parse_module_checked(ptx_text).unwrap();
crate::to_llvm_module(ast)?;
let attributes = pass::Attributes { clock_rate: 2124000 };
crate::to_llvm_module(ast, attributes)?;
Ok(())
}

View File

@ -297,6 +297,8 @@ test_ptx!(
test_ptx!(multiple_return, [5u32], [6u32, 123u32]);
test_ptx!(warp_sz, [0u8], [32u8]);
test_ptx!(nanosleep, [0u64], [0u64]);
test_ptx!(assertfail);
// TODO: not yet supported
//test_ptx!(func_ptr);
@ -375,7 +377,7 @@ fn test_hip_assert<
block_dim_x: u32,
) -> Result<(), Box<dyn error::Error>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
let llvm_ir = pass::to_llvm_module(ast).unwrap();
let llvm_ir = pass::to_llvm_module(ast, pass::Attributes { clock_rate: 2124000 }).unwrap();
let name = CString::new(name)?;
let result =
run_hip(name.as_c_str(), llvm_ir, input, output, block_dim_x).map_err(|err| DisplayError { err })?;
@ -389,9 +391,19 @@ fn test_llvm_assert(
expected_ll: &str,
) -> Result<(), Box<dyn error::Error>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
let llvm_ir = pass::to_llvm_module(ast).unwrap();
let llvm_ir = pass::to_llvm_module(ast, pass::Attributes { clock_rate: 2124000 }).unwrap();
let actual_ll = llvm_ir.llvm_ir.print_module_to_string();
let actual_ll = actual_ll.to_str();
compare_llvm(name, actual_ll, expected_ll);
let expected_attributes_ll = read_test_file!(concat!("../ll/_attributes.ll"));
let actual_attributes_ll = llvm_ir.attributes_ir.print_module_to_string();
let actual_attributes_ll = actual_attributes_ll.to_str();
compare_llvm("_attributes", actual_attributes_ll, &expected_attributes_ll);
Ok(())
}
fn compare_llvm(name: &str, actual_ll: &str, expected_ll: &str) {
if actual_ll != expected_ll {
let output_dir = env::var("TEST_PTX_LLVM_FAIL_DIR");
if let Ok(output_dir) = output_dir {
@ -404,7 +416,6 @@ fn test_llvm_assert(
let comparison = pretty_assertions::StrComparison::new(&expected_ll, &actual_ll);
panic!("assertion failed: `(left == right)`\n\n{}", comparison);
}
Ok(())
}
fn test_cuda_assert<
@ -567,6 +578,7 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def
&comgr,
unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) },
&*module.llvm_ir.write_bitcode_to_memory(),
&*module.attributes_ir.write_bitcode_to_memory(),
module.linked_bitcode(),
)
.unwrap();

View File

@ -0,0 +1,13 @@
.version 6.5
.target sm_70
.address_size 64
// Test kernel for the `nanosleep.u32` instruction: it issues a single sleep
// and returns. Both parameters are declared but never read or written.
.visible .entry nanosleep(
.param .u64 input,
.param .u64 output
)
{
// TODO: check if there's some way of testing that it actually sleeps
// Sleep with an immediate operand of 1, i.e. the shortest non-zero request.
nanosleep.u32 1;
ret;
}

View File

@ -327,6 +327,12 @@ ptx_parser_macros::generate_instruction_type!(
src2: T,
}
},
Nanosleep {
type: Type::Scalar(ScalarType::U32),
arguments<T>: {
src: T
}
},
Neg {
type: Type::Scalar(data.type_),
data: TypeFtz,

View File

@ -3502,6 +3502,13 @@ derive_parser!(
}
}
.mode: ShuffleMode = { .up, .down, .bfly, .idx };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-nanosleep
nanosleep.u32 t => {
Instruction::Nanosleep {
arguments: NanosleepArgs { src: t }
}
}
);
#[cfg(test)]

View File

@ -64,15 +64,17 @@ pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result<hipModul
let global_state = driver::global_state()?;
let text = get_ptx(image)?;
let ast = ptx_parser::parse_module_checked(&text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
let llvm_module = ptx::to_llvm_module(ast).map_err(|_| CUerror::UNKNOWN)?;
let mut dev = 0;
unsafe { hipCtxGetDevice(&mut dev) }?;
let mut props = unsafe { mem::zeroed() };
unsafe { hipGetDevicePropertiesR0600(&mut props, dev) }?;
let attributes = ptx::Attributes { clock_rate: props.clockRate as u32 };
let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?;
let elf_module = comgr::compile_bitcode(
&global_state.comgr,
unsafe { CStr::from_ptr(props.gcnArchName.as_ptr()) },
&*llvm_module.llvm_ir.write_bitcode_to_memory(),
&*llvm_module.attributes_ir.write_bitcode_to_memory(),
llvm_module.linked_bitcode(),
)
.map_err(|_| CUerror::UNKNOWN)?;