Check Rust formatting on pull requests (#451)

* Check Rust formatting on pull requests

This should help us maintain consistent style, without having unrelated style changes in pull requests from running `rustfmt`.

* cargo fmt non-generated files

* Ignore generated files
This commit is contained in:
Violet
2025-07-30 14:55:09 -07:00
committed by GitHub
parent 98b601d15a
commit 21ef5f60a3
40 changed files with 11463 additions and 11399 deletions

View File

@ -11,6 +11,15 @@ env:
ROCM_VERSION: "6.3.1"
jobs:
formatting:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: actions-rust-lang/setup-rust-toolchain@v1
with:
components: rustfmt
- name: Check Rust formatting
uses: actions-rust-lang/rustfmt@v1
build_linux:
name: Build (Linux)
runs-on: ubuntu-22.04

View File

@ -184,7 +184,8 @@ pub fn compile_bitcode(
let bitcode_data_set = DataSet::new(comgr)?;
let main_bitcode_data = Data::new(comgr, DataKind::Bc, c"zluda.bc", main_buffer)?;
bitcode_data_set.add(&main_bitcode_data)?;
let attributes_bitcode_data = Data::new(comgr, DataKind::Bc, c"attributes.bc", attributes_buffer)?;
let attributes_bitcode_data =
Data::new(comgr, DataKind::Bc, c"attributes.bc", attributes_buffer)?;
bitcode_data_set.add(&attributes_bitcode_data)?;
let stdlib_bitcode_data = Data::new(comgr, DataKind::Bc, c"ptx_impl.bc", ptx_impl)?;
bitcode_data_set.add(&stdlib_bitcode_data)?;

View File

@ -0,0 +1 @@
disable_all_formatting = true

1
cuda_types/.rustfmt.toml Normal file
View File

@ -0,0 +1 @@
disable_all_formatting = true

View File

@ -77,21 +77,22 @@ impl<'a> Fatbin<'a> {
pub fn get_submodules(&self) -> Result<FatbinIter<'a>, FatbinError> {
match self.wrapper.version {
FatbincWrapper::VERSION_V2 =>
Ok(FatbinIter::V2(FatbinSubmoduleIterator {
fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void,
_phantom: std::marker::PhantomData,
})),
FatbincWrapper::VERSION_V2 => Ok(FatbinIter::V2(FatbinSubmoduleIterator {
fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void,
_phantom: std::marker::PhantomData,
})),
FatbincWrapper::VERSION_V1 => {
let header = parse_fatbin_header(&self.wrapper.data)
.map_err(FatbinError::ParseFailure)?;
let header =
parse_fatbin_header(&self.wrapper.data).map_err(FatbinError::ParseFailure)?;
Ok(FatbinIter::V1(Some(FatbinSubmodule::new(header))))
}
version => Err(FatbinError::ParseFailure(ParseError::UnexpectedBinaryField{
field_name: "FATBINC_VERSION",
observed: version,
expected: [FatbincWrapper::VERSION_V1, FatbincWrapper::VERSION_V2].into(),
})),
version => Err(FatbinError::ParseFailure(
ParseError::UnexpectedBinaryField {
field_name: "FATBINC_VERSION",
observed: version,
expected: [FatbincWrapper::VERSION_V1, FatbincWrapper::VERSION_V2].into(),
},
)),
}
}
}
@ -176,7 +177,6 @@ impl<'a> FatbinFile<'a> {
unsafe { self.get_payload().to_vec() }
};
while payload.last() == Some(&0) {
// remove trailing zeros
payload.pop();

1
ext/hip_runtime-sys/.rustfmt.toml vendored Normal file
View File

@ -0,0 +1 @@
disable_all_formatting = true

1
ext/rocblas-sys/.rustfmt.toml vendored Normal file
View File

@ -0,0 +1 @@
disable_all_formatting = true

1
format/.rustfmt.toml Normal file
View File

@ -0,0 +1 @@
disable_all_formatting = true

View File

@ -428,7 +428,8 @@ impl CudaDisplay for CUmemcpy3DOperand_st {
CudaDisplay::write(unsafe { &self.op.ptr }, fn_name, index, writer)?;
}
_ => {
const CU_MEMCPY_3D_OP_SIZE: usize = mem::size_of::<CUmemcpy3DOperand_st__bindgen_ty_1>();
const CU_MEMCPY_3D_OP_SIZE: usize =
mem::size_of::<CUmemcpy3DOperand_st__bindgen_ty_1>();
CudaDisplay::write(
&unsafe { mem::transmute::<_, [u8; CU_MEMCPY_3D_OP_SIZE]>(self.op) },
fn_name,

View File

@ -4,4 +4,3 @@ mod test;
pub use pass::to_llvm_module;
pub use pass::Attributes;

View File

@ -1,10 +1,13 @@
use std::ffi::CStr;
use super::*;
use super::super::*;
use llvm_zluda::{core::*};
use super::*;
use llvm_zluda::core::*;
pub(crate) fn run(context: &Context, attributes: Attributes) -> Result<llvm::Module, TranslateError> {
pub(crate) fn run(
context: &Context,
attributes: Attributes,
) -> Result<llvm::Module, TranslateError> {
let module = llvm::Module::new(context, LLVM_UNNAMED);
emit_attribute(context, &module, "clock_rate", attributes.clock_rate)?;
@ -16,7 +19,12 @@ pub(crate) fn run(context: &Context, attributes: Attributes) -> Result<llvm::Mod
Ok(module)
}
fn emit_attribute(context: &Context, module: &llvm::Module, name: &str, attribute: u32) -> Result<(), TranslateError> {
fn emit_attribute(
context: &Context,
module: &llvm::Module,
name: &str,
attribute: u32,
) -> Result<(), TranslateError> {
let name = format!("{}attribute_{}\0", ZLUDA_PTX_PREFIX, name).to_ascii_uppercase();
let name = unsafe { CStr::from_bytes_with_nul_unchecked(name.as_bytes()) };
let attribute_type = get_scalar_type(context.get(), ast::ScalarType::U32);

View File

@ -88,7 +88,11 @@ struct ModuleEmitContext<'a, 'input> {
}
impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn new(context: &Context, module: &llvm::Module, id_defs: &'a GlobalStringIdentResolver2<'input>) -> Self {
fn new(
context: &Context,
module: &llvm::Module,
id_defs: &'a GlobalStringIdentResolver2<'input>,
) -> Self {
ModuleEmitContext {
context: context.get(),
module: module.get(),
@ -516,7 +520,7 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Trap {} => Err(error_todo_msg("Trap is not implemented yet")),
ast::Instruction::Tanh { data, arguments } => self.emit_tanh(data, arguments),
ast::Instruction::CpAsync { data, arguments } => self.emit_cp_async(data, arguments),
ast::Instruction::CpAsyncCommitGroup { } => Ok(()), // nop
ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop
ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop
ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop
// replaced by a function call
@ -764,7 +768,9 @@ impl<'a> MethodEmitContext<'a> {
todo!()
}
let store = unsafe { LLVMBuildStore(self.builder, value, ptr) };
unsafe { LLVMSetAlignment(store, data.typ.layout().align() as u32); }
unsafe {
LLVMSetAlignment(store, data.typ.layout().align() as u32);
}
Ok(())
}
@ -2587,7 +2593,6 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
fn flush_denormals(
&mut self,
type_: ptx_parser::ScalarType,

View File

@ -1,5 +1,5 @@
pub(super) mod emit;
pub(super) mod attributes;
pub(super) mod emit;
use std::ffi::CStr;
use std::ops::Deref;
@ -44,9 +44,7 @@ pub struct Module(LLVMModuleRef);
impl Module {
fn new(ctx: &Context, name: &CStr) -> Self {
Self(
unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) },
)
Self(unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) })
}
fn get(&self) -> LLVMModuleRef {

View File

@ -52,7 +52,10 @@ pub struct Attributes {
pub clock_rate: u32,
}
pub fn to_llvm_module<'input>(ast: ast::Module<'input>, attributes: Attributes) -> Result<Module, TranslateError> {
pub fn to_llvm_module<'input>(
ast: ast::Module<'input>,
attributes: Attributes,
) -> Result<Module, TranslateError> {
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;

View File

@ -21,7 +21,9 @@ pub(crate) fn run(
for directive in directives.iter_mut() {
let (body_ref, is_kernel) = match directive {
Directive2::Method(Function2 {
body: Some(body), is_kernel, ..
body: Some(body),
is_kernel,
..
}) => (body, *is_kernel),
_ => continue,
};

View File

@ -9,7 +9,9 @@ fn parse_and_assert(ptx_text: &str) {
fn compile_and_assert(ptx_text: &str) -> Result<(), TranslateError> {
let ast = ast::parse_module_checked(ptx_text).unwrap();
let attributes = pass::Attributes { clock_rate: 2124000 };
let attributes = pass::Attributes {
clock_rate: 2124000,
};
crate::to_llvm_module(ast, attributes)?;
Ok(())
}

View File

@ -46,7 +46,7 @@ macro_rules! test_ptx_llvm {
test_llvm_assert(stringify!($fn_name), &ptx, ll.trim())
}
}
}
};
}
macro_rules! test_ptx {
@ -309,48 +309,77 @@ test_ptx!(assertfail);
test_ptx!(lanemask_lt);
test_ptx!(extern_func);
test_ptx_warp!(tid, [
0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8, 11u8, 12u8, 13u8, 14u8, 15u8,
16u8, 17u8, 18u8, 19u8, 20u8, 21u8, 22u8, 23u8, 24u8, 25u8, 26u8, 27u8, 28u8, 29u8, 30u8, 31u8,
32u8, 33u8, 34u8, 35u8, 36u8, 37u8, 38u8, 39u8, 40u8, 41u8, 42u8, 43u8, 44u8, 45u8, 46u8, 47u8,
48u8, 49u8, 50u8, 51u8, 52u8, 53u8, 54u8, 55u8, 56u8, 57u8, 58u8, 59u8, 60u8, 61u8, 62u8, 63u8,
]);
test_ptx_warp!(bar_red_and_pred, [
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
]);
test_ptx_warp!(shfl_sync_up_b32_pred, [
1000u32, 1001u32, 1002u32, 0u32, 1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32,
13u32, 14u32, 15u32, 16u32, 17u32, 18u32, 19u32, 20u32, 21u32, 22u32, 23u32, 24u32, 25u32, 26u32, 27u32, 28u32,
1032u32, 1033u32, 1034u32, 32u32, 33u32, 34u32, 35u32, 36u32, 37u32, 38u32, 39u32, 40u32, 41u32, 42u32, 43u32, 44u32,
45u32, 46u32, 47u32, 48u32, 49u32, 50u32, 51u32, 52u32, 53u32, 54u32, 55u32, 56u32, 57u32, 58u32, 59u32, 60u32,
]);
test_ptx_warp!(shfl_sync_down_b32_pred, [
3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32, 13u32, 14u32, 15u32, 16u32, 17u32, 18u32,
19u32, 20u32, 21u32, 22u32, 23u32, 24u32, 25u32, 26u32, 27u32, 28u32, 29u32, 30u32, 31u32, 1029u32, 1030u32, 1031u32,
35u32, 36u32, 37u32, 38u32, 39u32, 40u32, 41u32, 42u32, 43u32, 44u32, 45u32, 46u32, 47u32, 48u32, 49u32, 50u32,
51u32, 52u32, 53u32, 54u32, 55u32, 56u32, 57u32, 58u32, 59u32, 60u32, 61u32, 62u32, 63u32, 1061u32, 1062u32, 1063u32,
]);
test_ptx_warp!(shfl_sync_bfly_b32_pred, [
3u32, 2u32, 1u32, 0u32, 7u32, 6u32, 5u32, 4u32, 11u32, 10u32, 9u32, 8u32, 15u32, 14u32, 13u32, 12u32,
19u32, 18u32, 17u32, 16u32, 23u32, 22u32, 21u32, 20u32, 27u32, 26u32, 25u32, 24u32, 31u32, 30u32, 29u32, 28u32,
35u32, 34u32, 33u32, 32u32, 39u32, 38u32, 37u32, 36u32, 43u32, 42u32, 41u32, 40u32, 47u32, 46u32, 45u32, 44u32,
51u32, 50u32, 49u32, 48u32, 55u32, 54u32, 53u32, 52u32, 59u32, 58u32, 57u32, 56u32, 63u32, 62u32, 61u32, 60u32,
]);
test_ptx_warp!(shfl_sync_idx_b32_pred, [
12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32,
12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32,
44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32,
44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32,
]);
test_ptx_warp!(shfl_sync_mode_b32, [
9u32, 7u32, 8u32, 9u32, 21u32, 19u32, 20u32, 21u32, 33u32, 31u32, 32u32, 33u32, 45u32, 43u32, 44u32, 45u32,
73u32, 71u32, 72u32, 73u32, 85u32, 83u32, 84u32, 85u32, 97u32, 95u32, 96u32, 97u32, 109u32, 107u32, 108u32, 109u32,
137u32, 135u32, 136u32, 137u32, 149u32, 147u32, 148u32, 149u32, 161u32, 159u32, 160u32, 161u32, 173u32, 171u32, 172u32, 173u32,
201u32, 199u32, 200u32, 201u32, 213u32, 211u32, 212u32, 213u32, 225u32, 223u32, 224u32, 225u32, 237u32, 235u32, 236u32, 237u32,
]);
test_ptx_warp!(
tid,
[
0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8, 11u8, 12u8, 13u8, 14u8, 15u8, 16u8,
17u8, 18u8, 19u8, 20u8, 21u8, 22u8, 23u8, 24u8, 25u8, 26u8, 27u8, 28u8, 29u8, 30u8, 31u8,
32u8, 33u8, 34u8, 35u8, 36u8, 37u8, 38u8, 39u8, 40u8, 41u8, 42u8, 43u8, 44u8, 45u8, 46u8,
47u8, 48u8, 49u8, 50u8, 51u8, 52u8, 53u8, 54u8, 55u8, 56u8, 57u8, 58u8, 59u8, 60u8, 61u8,
62u8, 63u8,
]
);
test_ptx_warp!(
bar_red_and_pred,
[
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
2u32, 2u32, 2u32, 2u32,
]
);
test_ptx_warp!(
shfl_sync_up_b32_pred,
[
1000u32, 1001u32, 1002u32, 0u32, 1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32,
10u32, 11u32, 12u32, 13u32, 14u32, 15u32, 16u32, 17u32, 18u32, 19u32, 20u32, 21u32, 22u32,
23u32, 24u32, 25u32, 26u32, 27u32, 28u32, 1032u32, 1033u32, 1034u32, 32u32, 33u32, 34u32,
35u32, 36u32, 37u32, 38u32, 39u32, 40u32, 41u32, 42u32, 43u32, 44u32, 45u32, 46u32, 47u32,
48u32, 49u32, 50u32, 51u32, 52u32, 53u32, 54u32, 55u32, 56u32, 57u32, 58u32, 59u32, 60u32,
]
);
test_ptx_warp!(
shfl_sync_down_b32_pred,
[
3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32, 13u32, 14u32, 15u32, 16u32,
17u32, 18u32, 19u32, 20u32, 21u32, 22u32, 23u32, 24u32, 25u32, 26u32, 27u32, 28u32, 29u32,
30u32, 31u32, 1029u32, 1030u32, 1031u32, 35u32, 36u32, 37u32, 38u32, 39u32, 40u32, 41u32,
42u32, 43u32, 44u32, 45u32, 46u32, 47u32, 48u32, 49u32, 50u32, 51u32, 52u32, 53u32, 54u32,
55u32, 56u32, 57u32, 58u32, 59u32, 60u32, 61u32, 62u32, 63u32, 1061u32, 1062u32, 1063u32,
]
);
test_ptx_warp!(
shfl_sync_bfly_b32_pred,
[
3u32, 2u32, 1u32, 0u32, 7u32, 6u32, 5u32, 4u32, 11u32, 10u32, 9u32, 8u32, 15u32, 14u32,
13u32, 12u32, 19u32, 18u32, 17u32, 16u32, 23u32, 22u32, 21u32, 20u32, 27u32, 26u32, 25u32,
24u32, 31u32, 30u32, 29u32, 28u32, 35u32, 34u32, 33u32, 32u32, 39u32, 38u32, 37u32, 36u32,
43u32, 42u32, 41u32, 40u32, 47u32, 46u32, 45u32, 44u32, 51u32, 50u32, 49u32, 48u32, 55u32,
54u32, 53u32, 52u32, 59u32, 58u32, 57u32, 56u32, 63u32, 62u32, 61u32, 60u32,
]
);
test_ptx_warp!(
shfl_sync_idx_b32_pred,
[
12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32,
12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32,
12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32,
44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32,
44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32,
]
);
test_ptx_warp!(
shfl_sync_mode_b32,
[
9u32, 7u32, 8u32, 9u32, 21u32, 19u32, 20u32, 21u32, 33u32, 31u32, 32u32, 33u32, 45u32,
43u32, 44u32, 45u32, 73u32, 71u32, 72u32, 73u32, 85u32, 83u32, 84u32, 85u32, 97u32, 95u32,
96u32, 97u32, 109u32, 107u32, 108u32, 109u32, 137u32, 135u32, 136u32, 137u32, 149u32,
147u32, 148u32, 149u32, 161u32, 159u32, 160u32, 161u32, 173u32, 171u32, 172u32, 173u32,
201u32, 199u32, 200u32, 201u32, 213u32, 211u32, 212u32, 213u32, 225u32, 223u32, 224u32,
225u32, 237u32, 235u32, 236u32, 237u32,
]
);
struct DisplayError<T: Debug> {
err: T,
@ -381,10 +410,16 @@ fn test_hip_assert<
block_dim_x: u32,
) -> Result<(), Box<dyn error::Error>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
let llvm_ir = pass::to_llvm_module(ast, pass::Attributes { clock_rate: 2124000 }).unwrap();
let llvm_ir = pass::to_llvm_module(
ast,
pass::Attributes {
clock_rate: 2124000,
},
)
.unwrap();
let name = CString::new(name)?;
let result =
run_hip(name.as_c_str(), llvm_ir, input, output, block_dim_x).map_err(|err| DisplayError { err })?;
let result = run_hip(name.as_c_str(), llvm_ir, input, output, block_dim_x)
.map_err(|err| DisplayError { err })?;
assert_eq!(result.as_slice(), output);
Ok(())
}
@ -395,7 +430,13 @@ fn test_llvm_assert(
expected_ll: &str,
) -> Result<(), Box<dyn error::Error>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
let llvm_ir = pass::to_llvm_module(ast, pass::Attributes { clock_rate: 2124000 }).unwrap();
let llvm_ir = pass::to_llvm_module(
ast,
pass::Attributes {
clock_rate: 2124000,
},
)
.unwrap();
let actual_ll = llvm_ir.llvm_ir.print_module_to_string();
let actual_ll = actual_ll.to_str();
compare_llvm(name, actual_ll, expected_ll);

View File

@ -2,7 +2,7 @@ use super::{
AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp,
StateSpace, VectorPrefix,
};
use crate::{Mul24Control, Reduction, PtxError, PtxParserState, ShuffleMode};
use crate::{Mul24Control, PtxError, PtxParserState, Reduction, ShuffleMode};
use bitflags::bitflags;
use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8};
@ -1418,7 +1418,6 @@ impl SetpData {
}
}
pub struct SetBoolData {
pub dtype: ScalarType,
pub base: SetpBoolData,

View File

@ -1,6 +1,6 @@
use either::Either;
use ptx_parser_macros_impl::parser;
use proc_macro2::{Span, TokenStream};
use ptx_parser_macros_impl::parser;
use quote::{format_ident, quote, ToTokens};
use rustc_hash::{FxHashMap, FxHashSet};
use std::{collections::hash_map, hash::Hash, iter, rc::Rc};
@ -359,7 +359,8 @@ fn gather_rules(
#[proc_macro]
pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
let parse_definitions = parse_macro_input!(tokens as ptx_parser_macros_impl::parser::ParseDefinitions);
let parse_definitions =
parse_macro_input!(tokens as ptx_parser_macros_impl::parser::ParseDefinitions);
let mut definitions = FxHashMap::default();
let mut special_definitions = FxHashMap::default();
let types = OpcodeDefinitions::get_enum_types(&parse_definitions.definitions);
@ -384,7 +385,13 @@ pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream
special_definitions.keys(),
&mut token_enum.variants,
);
let token_impl = emit_parse_function(&token_enum.ident, &definitions, &special_definitions, all_opcode, all_modifier);
let token_impl = emit_parse_function(
&token_enum.ident,
&definitions,
&special_definitions,
all_opcode,
all_modifier,
);
let tokens = quote! {
#enum_types_tokens
@ -846,7 +853,8 @@ fn emit_definition_parser(
(pattern, parser)
});
let arguments_parse = quote! { let #arguments_pattern = ( #arguments_parser ).parse_next(stream)?; };
let arguments_parse =
quote! { let #arguments_pattern = ( #arguments_parser ).parse_next(stream)?; };
let fn_args = definition.function_arguments();
let fn_name = format_ident!("{}_{}", opcode, fn_idx);

View File

@ -764,7 +764,7 @@ impl Parse for ArgumentField {
repr,
type_,
space,
relaxed_type_check
relaxed_type_check,
})
}
}

View File

@ -423,7 +423,6 @@ impl std::fmt::Display for HyphenatedIdent {
Ok(())
}
}
impl Parse for HyphenatedIdent {

View File

@ -1,7 +1,7 @@
use super::{context, driver};
use cuda_types::cuda::*;
use hip_runtime_sys::*;
use std::{mem, ptr};
use super::{driver, context};
const PROJECT_SUFFIX: &[u8] = b" [ZLUDA]\0";
pub const COMPUTE_CAPABILITY_MAJOR: i32 = 8;
@ -462,15 +462,14 @@ fn clamp_usize(x: usize) -> i32 {
usize::min(x, i32::MAX as usize) as i32
}
pub(crate) fn get_primary_context(hip_dev: hipDevice_t) -> Result<(&'static context::Context, CUcontext), CUerror> {
pub(crate) fn get_primary_context(
hip_dev: hipDevice_t,
) -> Result<(&'static context::Context, CUcontext), CUerror> {
let dev: &'static driver::Device = driver::device(hip_dev)?;
Ok(dev.primary_context())
}
pub(crate) fn primary_context_retain(
pctx: &mut CUcontext,
hip_dev: hipDevice_t,
) -> CUresult {
pub(crate) fn primary_context_retain(pctx: &mut CUcontext, hip_dev: hipDevice_t) -> CUresult {
let (ctx, cu_ctx) = get_primary_context(hip_dev)?;
ctx.with_state_mut(|state: &mut context::ContextState| {
@ -497,8 +496,6 @@ pub(crate) fn primary_context_release(hip_dev: hipDevice_t) -> CUresult {
pub(crate) fn primary_context_reset(hip_dev: hipDevice_t) -> CUresult {
let (ctx, _) = get_primary_context(hip_dev)?;
ctx.with_state_mut(|state| {
state.reset()
})?;
ctx.with_state_mut(|state| state.reset())?;
Ok(())
}

View File

@ -38,10 +38,7 @@ pub(crate) unsafe fn unload(library: CUlibrary) -> CUresult {
super::drop_checked::<Library>(library)
}
pub(crate) unsafe fn get_module(
out: &mut CUmodule,
library: &Library,
) -> CUresult {
*out = module::Module{base: library.base}.wrap();
pub(crate) unsafe fn get_module(out: &mut CUmodule, library: &Library) -> CUresult {
*out = module::Module { base: library.base }.wrap();
Ok(())
}

View File

@ -68,7 +68,9 @@ pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result<hipModul
unsafe { hipCtxGetDevice(&mut dev) }?;
let mut props = unsafe { mem::zeroed() };
unsafe { hipGetDevicePropertiesR0600(&mut props, dev) }?;
let attributes = ptx::Attributes { clock_rate: props.clockRate as u32 };
let attributes = ptx::Attributes {
clock_rate: props.clockRate as u32,
};
let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?;
let elf_module = comgr::compile_bitcode(
&global_state.comgr,
@ -91,7 +93,6 @@ pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUre
pub(crate) fn unload(hmod: CUmodule) -> CUresult {
super::drop_checked::<Module>(hmod)
}
pub(crate) fn get_function(

View File

@ -1,11 +1,10 @@
use cuda_types::cuda::CUerror;
use std::sync::atomic::{AtomicBool, Ordering};
pub(crate) mod r#impl;
#[cfg_attr(windows, path = "os_win.rs")]
#[cfg_attr(not(windows), path = "os_unix.rs")]
mod os;
pub(crate) mod r#impl;
static INITIALIZED: AtomicBool = AtomicBool::new(true);
pub(crate) fn initialized() -> bool {
@ -66,61 +65,60 @@ macro_rules! implemented_in_function {
cuda_macros::cuda_function_declarations!(
unimplemented,
implemented <= [
cuCtxCreate_v2,
cuCtxDestroy_v2,
cuCtxGetLimit,
cuCtxSetCurrent,
cuCtxGetCurrent,
cuCtxGetDevice,
cuCtxSetLimit,
cuCtxSynchronize,
cuCtxPushCurrent,
cuCtxPushCurrent_v2,
cuCtxPopCurrent,
cuCtxPopCurrent_v2,
cuDeviceComputeCapability,
cuDeviceGet,
cuDeviceGetAttribute,
cuDeviceGetCount,
cuDeviceGetLuid,
cuDeviceGetName,
cuDeviceGetProperties,
cuDeviceGetUuid,
cuDeviceGetUuid_v2,
cuDevicePrimaryCtxRelease,
cuDevicePrimaryCtxRetain,
cuDevicePrimaryCtxReset,
cuDeviceTotalMem_v2,
cuDriverGetVersion,
cuFuncGetAttribute,
cuGetExportTable,
cuGetProcAddress,
cuGetProcAddress_v2,
cuInit,
cuLibraryLoadData,
cuLibraryGetModule,
cuLibraryUnload,
cuMemAlloc_v2,
cuMemFree_v2,
cuMemHostAlloc,
cuMemFreeHost,
cuMemGetAddressRange_v2,
cuMemGetInfo_v2,
cuMemcpyDtoH_v2,
cuMemcpyHtoD_v2,
cuMemsetD32_v2,
cuMemsetD8_v2,
cuModuleGetFunction,
cuModuleGetLoadingMode,
cuModuleLoadData,
cuModuleUnload,
cuPointerGetAttribute,
cuStreamSynchronize,
cuProfilerStart,
cuProfilerStop,
],
implemented_in_function <= [
cuLaunchKernel,
]
implemented
<= [
cuCtxCreate_v2,
cuCtxDestroy_v2,
cuCtxGetLimit,
cuCtxSetCurrent,
cuCtxGetCurrent,
cuCtxGetDevice,
cuCtxSetLimit,
cuCtxSynchronize,
cuCtxPushCurrent,
cuCtxPushCurrent_v2,
cuCtxPopCurrent,
cuCtxPopCurrent_v2,
cuDeviceComputeCapability,
cuDeviceGet,
cuDeviceGetAttribute,
cuDeviceGetCount,
cuDeviceGetLuid,
cuDeviceGetName,
cuDeviceGetProperties,
cuDeviceGetUuid,
cuDeviceGetUuid_v2,
cuDevicePrimaryCtxRelease,
cuDevicePrimaryCtxRetain,
cuDevicePrimaryCtxReset,
cuDeviceTotalMem_v2,
cuDriverGetVersion,
cuFuncGetAttribute,
cuGetExportTable,
cuGetProcAddress,
cuGetProcAddress_v2,
cuInit,
cuLibraryLoadData,
cuLibraryGetModule,
cuLibraryUnload,
cuMemAlloc_v2,
cuMemFree_v2,
cuMemHostAlloc,
cuMemFreeHost,
cuMemGetAddressRange_v2,
cuMemGetInfo_v2,
cuMemcpyDtoH_v2,
cuMemcpyHtoD_v2,
cuMemsetD32_v2,
cuMemsetD8_v2,
cuModuleGetFunction,
cuModuleGetLoadingMode,
cuModuleLoadData,
cuModuleUnload,
cuPointerGetAttribute,
cuStreamSynchronize,
cuProfilerStart,
cuProfilerStop,
],
implemented_in_function <= [cuLaunchKernel,]
);

View File

@ -0,0 +1 @@

View File

@ -10,15 +10,11 @@ pub(crate) fn unimplemented() -> cublasStatus_t {
cublasStatus_t::ERROR_NOT_SUPPORTED
}
pub(crate) fn get_status_name(
_status: cublasStatus_t,
) -> *const ::core::ffi::c_char {
pub(crate) fn get_status_name(_status: cublasStatus_t) -> *const ::core::ffi::c_char {
todo!()
}
pub(crate) fn get_status_string(
_status: cublasStatus_t,
) -> *const ::core::ffi::c_char {
pub(crate) fn get_status_string(_status: cublasStatus_t) -> *const ::core::ffi::c_char {
todo!()
}

View File

@ -28,10 +28,11 @@ macro_rules! implemented {
cuda_macros::cublas_function_declarations!(
unimplemented,
implemented <= [
cublasGetStatusName,
cublasGetStatusString,
cublasXerbla,
cublasGetCudartVersion,
]
implemented
<= [
cublasGetStatusName,
cublasGetStatusString,
cublasXerbla,
cublasGetCudartVersion
]
);

View File

@ -31,8 +31,6 @@ pub(crate) fn get_cudart_version() -> usize {
}
#[allow(non_snake_case)]
pub(crate) fn disable_cpu_instructions_set_mask(
_mask: ::core::ffi::c_uint,
) -> ::core::ffi::c_uint {
pub(crate) fn disable_cpu_instructions_set_mask(_mask: ::core::ffi::c_uint) -> ::core::ffi::c_uint {
todo!()
}

View File

@ -28,11 +28,12 @@ macro_rules! implemented {
cuda_macros::cublaslt_function_declarations!(
unimplemented,
implemented <= [
cublasLtGetStatusName,
cublasLtGetStatusString,
cublasLtDisableCpuInstructionsSetMask,
cublasLtGetVersion,
cublasLtGetCudartVersion
]
implemented
<= [
cublasLtGetStatusName,
cublasLtGetStatusString,
cublasLtDisableCpuInstructionsSetMask,
cublasLtGetVersion,
cublasLtGetCudartVersion
]
);

View File

@ -28,11 +28,12 @@ macro_rules! implemented {
cuda_macros::cudnn9_function_declarations!(
unimplemented,
implemented <= [
cudnnGetVersion,
cudnnGetMaxDeviceVersion,
cudnnGetCudartVersion,
cudnnGetErrorString,
cudnnGetLastErrorString
]
implemented
<= [
cudnnGetVersion,
cudnnGetMaxDeviceVersion,
cudnnGetCudartVersion,
cudnnGetErrorString,
cudnnGetLastErrorString
]
);

View File

@ -13,6 +13,4 @@ macro_rules! unimplemented {
};
}
cuda_macros::cufft_function_declarations!(
unimplemented
);
cuda_macros::cufft_function_declarations!(unimplemented);

View File

@ -11,9 +11,7 @@ pub(crate) fn unimplemented() -> nvmlReturn_t {
nvmlReturn_t::ERROR_NOT_SUPPORTED
}
pub(crate) fn error_string(
_result: cuda_types::nvml::nvmlReturn_t,
) -> *const ::core::ffi::c_char {
pub(crate) fn error_string(_result: cuda_types::nvml::nvmlReturn_t) -> *const ::core::ffi::c_char {
c"".as_ptr()
}

View File

@ -26,9 +26,5 @@ macro_rules! implemented_fn {
cuda_macros::nvml_function_declarations!(
unimplemented_fn,
implemented_fn <= [
nvmlErrorString,
nvmlInit_v2,
nvmlSystemGetDriverVersion
]
implemented_fn <= [nvmlErrorString, nvmlInit_v2, nvmlSystemGetDriverVersion]
);

View File

@ -28,12 +28,13 @@ macro_rules! implemented {
cuda_macros::cusparse_function_declarations!(
unimplemented,
implemented <= [
cusparseGetErrorName,
cusparseGetErrorString,
cusparseGetMatIndexBase,
cusparseGetMatDiagType,
cusparseGetMatFillMode,
cusparseGetMatType
]
implemented
<= [
cusparseGetErrorName,
cusparseGetErrorString,
cusparseGetMatIndexBase,
cusparseGetMatDiagType,
cusparseGetMatFillMode,
cusparseGetMatType
]
);