Use normalize_fn for performance libraries (#449)

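The goal here is to make the performance library implementations (cuBLAS, cuBLASLt, cuDNN, cuSPARSE, NVML) work more like zluda: each exported function now dispatches through a per-library `*_normalize_fn!` macro to a snake_case function in `impl` (for example, `cublasGetStatusName` now calls `impl::get_status_name`).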
This commit is contained in:
Violet
2025-07-30 14:02:01 -07:00
committed by GitHub
parent c07d7678cd
commit 98b601d15a
11 changed files with 90 additions and 71 deletions

View File

@ -203,8 +203,11 @@ const MODULES: &[&str] = &[
"stream", "stream",
]; ];
#[proc_macro] fn normalize_fn_impl(
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream { prefix: &str,
default_module: Option<&str>,
tokens: TokenStream,
) -> TokenStream {
let mut path = parse_macro_input!(tokens as syn::Path); let mut path = parse_macro_input!(tokens as syn::Path);
let fn_ = path let fn_ = path
.segments .segments
@ -215,14 +218,44 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
.ident .ident
.to_string(); .to_string();
let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string()); let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string());
let segments: Vec<String> = split(&fn_[2..]); // skip "cu" let segments: Vec<String> = split(&fn_[prefix.len()..]); // skip "cu"
let fn_path = join(segments, !already_has_module); let fn_path = join(segments, default_module.filter(|_| !already_has_module));
quote! { quote! {
#path #fn_path #path #fn_path
} }
.into() .into()
} }
#[proc_macro]
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
normalize_fn_impl("cu", Some("driver"), tokens)
}
#[proc_macro]
pub fn cublas_normalize_fn(tokens: TokenStream) -> TokenStream {
normalize_fn_impl("cublas", None, tokens)
}
#[proc_macro]
pub fn cublaslt_normalize_fn(tokens: TokenStream) -> TokenStream {
normalize_fn_impl("cublasLt", None, tokens)
}
#[proc_macro]
pub fn cudnn_normalize_fn(tokens: TokenStream) -> TokenStream {
normalize_fn_impl("cudnn", None, tokens)
}
#[proc_macro]
pub fn cusparse_normalize_fn(tokens: TokenStream) -> TokenStream {
normalize_fn_impl("cusparse", None, tokens)
}
#[proc_macro]
pub fn nvml_normalize_fn(tokens: TokenStream) -> TokenStream {
normalize_fn_impl("nvml", None, tokens)
}
fn split(fn_: &str) -> Vec<String> { fn split(fn_: &str) -> Vec<String> {
let mut result = Vec::new(); let mut result = Vec::new();
for c in fn_.chars() { for c in fn_.chars() {
@ -235,7 +268,10 @@ fn split(fn_: &str) -> Vec<String> {
result result
} }
fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> { fn join(
fn_: Vec<String>,
default_module: Option<&str>,
) -> Punctuated<Ident, Token![::]> {
fn full_form(segment: &str) -> Option<&[&str]> { fn full_form(segment: &str) -> Option<&[&str]> {
Some(match segment { Some(match segment {
"ctx" => &["context"], "ctx" => &["context"],
@ -253,13 +289,9 @@ fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
None => normalized.push(&*segment), None => normalized.push(&*segment),
} }
} }
if !find_module { if let Some(default_module) = default_module {
return [Ident::new(&normalized.join("_"), Span::call_site())] if !MODULES.contains(&normalized[0]) {
.into_iter() let mut globalized = vec![default_module];
.collect();
}
if !MODULES.contains(&normalized[0]) {
let mut globalized = vec!["driver"];
globalized.extend(normalized); globalized.extend(normalized);
normalized = globalized; normalized = globalized;
} }
@ -269,4 +301,10 @@ fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
.into_iter() .into_iter()
.map(|s| Ident::new(s, Span::call_site())) .map(|s| Ident::new(s, Span::call_site()))
.collect() .collect()
} else {
return [Ident::new(&normalized.join("_"), Span::call_site())]
.into_iter()
.collect();
}
} }

View File

@ -10,23 +10,22 @@ pub(crate) fn unimplemented() -> cublasStatus_t {
cublasStatus_t::ERROR_NOT_SUPPORTED cublasStatus_t::ERROR_NOT_SUPPORTED
} }
#[allow(non_snake_case)] pub(crate) fn get_status_name(
pub fn cublasGetStatusName(_status: cuda_types::cublas::cublasStatus_t) -> *const ::core::ffi::c_char { _status: cublasStatus_t,
) -> *const ::core::ffi::c_char {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn get_status_string(
pub fn cublasGetStatusString(_status: cuda_types::cublas::cublasStatus_t) -> *const ::core::ffi::c_char { _status: cublasStatus_t,
) -> *const ::core::ffi::c_char {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn xerbla(_sr_name: *const ::core::ffi::c_char, _info: ::core::ffi::c_int) -> () {
pub fn cublasXerbla(_srName: *const ::core::ffi::c_char, _info: ::core::ffi::c_int) -> () {
todo!() todo!()
} }
pub(crate) fn get_cudart_version() -> usize {
#[allow(non_snake_case)]
pub fn cublasGetCudartVersion() -> usize {
todo!() todo!()
} }

View File

@ -20,7 +20,7 @@ macro_rules! implemented {
#[allow(improper_ctypes)] #[allow(improper_ctypes)]
#[allow(improper_ctypes_definitions)] #[allow(improper_ctypes_definitions)]
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
crate::r#impl::$fn_name( $( $arg_id ),* ) cuda_macros::cublas_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
} }
)* )*
}; };
@ -32,6 +32,6 @@ cuda_macros::cublas_function_declarations!(
cublasGetStatusName, cublasGetStatusName,
cublasGetStatusString, cublasGetStatusString,
cublasXerbla, cublasXerbla,
cublasGetCudartVersion cublasGetCudartVersion,
] ]
); );

View File

@ -10,32 +10,28 @@ pub(crate) fn unimplemented() -> cublasStatus_t {
cublasStatus_t::ERROR_NOT_SUPPORTED cublasStatus_t::ERROR_NOT_SUPPORTED
} }
#[allow(non_snake_case)] pub(crate) fn get_status_name(
pub(crate) fn cublasLtGetStatusName(
_status: cuda_types::cublas::cublasStatus_t, _status: cuda_types::cublas::cublasStatus_t,
) -> *const ::core::ffi::c_char { ) -> *const ::core::ffi::c_char {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn get_status_string(
pub(crate) fn cublasLtGetStatusString(
_status: cuda_types::cublas::cublasStatus_t, _status: cuda_types::cublas::cublasStatus_t,
) -> *const ::core::ffi::c_char { ) -> *const ::core::ffi::c_char {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn get_version() -> usize {
pub(crate) fn cublasLtGetVersion() -> usize { todo!()
}
pub(crate) fn get_cudart_version() -> usize {
todo!() todo!()
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub(crate) fn cublasLtGetCudartVersion() -> usize { pub(crate) fn disable_cpu_instructions_set_mask(
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cublasLtDisableCpuInstructionsSetMask(
_mask: ::core::ffi::c_uint, _mask: ::core::ffi::c_uint,
) -> ::core::ffi::c_uint { ) -> ::core::ffi::c_uint {
todo!() todo!()

View File

@ -20,7 +20,7 @@ macro_rules! implemented {
#[allow(improper_ctypes)] #[allow(improper_ctypes)]
#[allow(improper_ctypes_definitions)] #[allow(improper_ctypes_definitions)]
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
crate::r#impl::$fn_name( $( $arg_id ),* ) cuda_macros::cublaslt_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
} }
)* )*
}; };

View File

@ -10,25 +10,20 @@ pub(crate) fn unimplemented() -> cudnnStatus_t {
cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED
} }
#[allow(non_snake_case)] pub(crate) fn get_version() -> usize {
pub(crate) fn cudnnGetVersion() -> usize {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn get_max_device_version() -> usize {
pub(crate) fn cudnnGetMaxDeviceVersion() -> usize {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn get_cudart_version() -> usize {
pub(crate) fn cudnnGetCudartVersion() -> usize {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn get_error_string(
pub(crate) fn cudnnGetErrorString(
_status: cuda_types::cudnn9::cudnnStatus_t, _status: cuda_types::cudnn9::cudnnStatus_t,
) -> *const ::core::ffi::c_char { ) -> *const ::core::ffi::c_char {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn get_last_error_string(_message: *mut ::core::ffi::c_char, _max_size: usize) -> () {
pub(crate) fn cudnnGetLastErrorString(_message: *mut ::core::ffi::c_char, _max_size: usize) -> () {
todo!() todo!()
} }

View File

@ -20,7 +20,7 @@ macro_rules! implemented {
#[allow(improper_ctypes)] #[allow(improper_ctypes)]
#[allow(improper_ctypes_definitions)] #[allow(improper_ctypes_definitions)]
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
crate::r#impl::$fn_name( $( $arg_id ),* ) cuda_macros::cudnn_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
} }
)* )*
}; };

View File

@ -11,22 +11,19 @@ pub(crate) fn unimplemented() -> nvmlReturn_t {
nvmlReturn_t::ERROR_NOT_SUPPORTED nvmlReturn_t::ERROR_NOT_SUPPORTED
} }
#[allow(non_snake_case)] pub(crate) fn error_string(
pub(crate) fn nvmlErrorString(
_result: cuda_types::nvml::nvmlReturn_t, _result: cuda_types::nvml::nvmlReturn_t,
) -> *const ::core::ffi::c_char { ) -> *const ::core::ffi::c_char {
c"".as_ptr() c"".as_ptr()
} }
#[allow(non_snake_case)] pub(crate) fn init_v2() -> cuda_types::nvml::nvmlReturn_t {
pub(crate) fn nvmlInit_v2() -> cuda_types::nvml::nvmlReturn_t {
nvmlReturn_t::SUCCESS nvmlReturn_t::SUCCESS
} }
const VERSION: &'static CStr = c"550.77"; const VERSION: &'static CStr = c"550.77";
#[allow(non_snake_case)] pub(crate) fn system_get_driver_version(
pub(crate) fn nvmlSystemGetDriverVersion(
result: *mut ::core::ffi::c_char, result: *mut ::core::ffi::c_char,
length: ::core::ffi::c_uint, length: ::core::ffi::c_uint,
) -> cuda_types::nvml::nvmlReturn_t { ) -> cuda_types::nvml::nvmlReturn_t {

View File

@ -18,7 +18,7 @@ macro_rules! implemented_fn {
#[no_mangle] #[no_mangle]
#[allow(improper_ctypes_definitions)] #[allow(improper_ctypes_definitions)]
pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
r#impl::$fn_name($($arg_id),*) cuda_macros::nvml_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
} }
)* )*
}; };

View File

@ -10,44 +10,38 @@ pub(crate) fn unimplemented() -> cusparseStatus_t {
cusparseStatus_t::ERROR_NOT_SUPPORTED cusparseStatus_t::ERROR_NOT_SUPPORTED
} }
#[allow(non_snake_case)] pub(crate) fn get_error_name(
pub(crate) fn cusparseGetErrorName(
_status: cuda_types::cusparse::cusparseStatus_t, _status: cuda_types::cusparse::cusparseStatus_t,
) -> *const ::core::ffi::c_char { ) -> *const ::core::ffi::c_char {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn get_error_string(
pub(crate) fn cusparseGetErrorString(
_status: cuda_types::cusparse::cusparseStatus_t, _status: cuda_types::cusparse::cusparseStatus_t,
) -> *const ::core::ffi::c_char { ) -> *const ::core::ffi::c_char {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn get_mat_type(
pub(crate) fn cusparseGetMatType( _descr_a: cuda_types::cusparse::cusparseMatDescr_t,
_descrA: cuda_types::cusparse::cusparseMatDescr_t,
) -> cuda_types::cusparse::cusparseMatrixType_t { ) -> cuda_types::cusparse::cusparseMatrixType_t {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn get_mat_fill_mode(
pub(crate) fn cusparseGetMatFillMode( _descr_a: cuda_types::cusparse::cusparseMatDescr_t,
_descrA: cuda_types::cusparse::cusparseMatDescr_t,
) -> cuda_types::cusparse::cusparseFillMode_t { ) -> cuda_types::cusparse::cusparseFillMode_t {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn get_mat_diag_type(
pub(crate) fn cusparseGetMatDiagType( _descr_a: cuda_types::cusparse::cusparseMatDescr_t,
_descrA: cuda_types::cusparse::cusparseMatDescr_t,
) -> cuda_types::cusparse::cusparseDiagType_t { ) -> cuda_types::cusparse::cusparseDiagType_t {
todo!() todo!()
} }
#[allow(non_snake_case)] pub(crate) fn get_mat_index_base(
pub(crate) fn cusparseGetMatIndexBase( _descr_a: cuda_types::cusparse::cusparseMatDescr_t,
_descrA: cuda_types::cusparse::cusparseMatDescr_t,
) -> cuda_types::cusparse::cusparseIndexBase_t { ) -> cuda_types::cusparse::cusparseIndexBase_t {
todo!() todo!()
} }

View File

@ -20,7 +20,7 @@ macro_rules! implemented {
#[allow(improper_ctypes)] #[allow(improper_ctypes)]
#[allow(improper_ctypes_definitions)] #[allow(improper_ctypes_definitions)]
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
crate::r#impl::$fn_name( $( $arg_id ),* ) cuda_macros::cusparse_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
} }
)* )*
}; };