From 98b601d15af455828fd4c3df1215681816c29cef Mon Sep 17 00:00:00 2001 From: Violet Date: Wed, 30 Jul 2025 14:02:01 -0700 Subject: [PATCH] Use `normalize_fn` for performance libraries (#449) The goal here is to make the performance library implementations work more like zluda. --- cuda_macros/src/lib.rs | 62 ++++++++++++++++++++++++++++++++-------- zluda_blas/src/impl.rs | 17 ++++++----- zluda_blas/src/lib.rs | 4 +-- zluda_blaslt/src/impl.rs | 20 ++++++------- zluda_blaslt/src/lib.rs | 2 +- zluda_dnn/src/impl.rs | 15 ++++------ zluda_dnn/src/lib.rs | 2 +- zluda_ml/src/impl.rs | 9 ++---- zluda_ml/src/lib.rs | 2 +- zluda_sparse/src/impl.rs | 26 +++++++---------- zluda_sparse/src/lib.rs | 2 +- 11 files changed, 90 insertions(+), 71 deletions(-) diff --git a/cuda_macros/src/lib.rs b/cuda_macros/src/lib.rs index 6cef62d..37ae690 100644 --- a/cuda_macros/src/lib.rs +++ b/cuda_macros/src/lib.rs @@ -203,8 +203,11 @@ const MODULES: &[&str] = &[ "stream", ]; -#[proc_macro] -pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream { +fn normalize_fn_impl( + prefix: &str, + default_module: Option<&str>, + tokens: TokenStream, +) -> TokenStream { let mut path = parse_macro_input!(tokens as syn::Path); let fn_ = path .segments @@ -215,14 +218,44 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream { .ident .to_string(); let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string()); - let segments: Vec = split(&fn_[2..]); // skip "cu" - let fn_path = join(segments, !already_has_module); + let segments: Vec = split(&fn_[prefix.len()..]); // skip "cu" + let fn_path = join(segments, default_module.filter(|_| !already_has_module)); quote! { #path #fn_path } .into() } +#[proc_macro] +pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream { + normalize_fn_impl("cu", Some("driver"), tokens) +} + +#[proc_macro] +pub fn cublas_normalize_fn(tokens: TokenStream) -> TokenStream { + normalize_fn_impl("cublas", None, tokens) +} + +#[proc_macro] +pub fn cublaslt_normalize_fn(tokens: TokenStream) -> TokenStream { + normalize_fn_impl("cublasLt", None, tokens) +} + +#[proc_macro] +pub fn cudnn_normalize_fn(tokens: TokenStream) -> TokenStream { + normalize_fn_impl("cudnn", None, tokens) +} + +#[proc_macro] +pub fn cusparse_normalize_fn(tokens: TokenStream) -> TokenStream { + normalize_fn_impl("cusparse", None, tokens) +} + +#[proc_macro] +pub fn nvml_normalize_fn(tokens: TokenStream) -> TokenStream { + normalize_fn_impl("nvml", None, tokens) +} + fn split(fn_: &str) -> Vec { let mut result = Vec::new(); for c in fn_.chars() { @@ -235,7 +268,10 @@ fn split(fn_: &str) -> Vec { result } -fn join(fn_: Vec, find_module: bool) -> Punctuated { +fn join( + fn_: Vec, + default_module: Option<&str>, +) -> Punctuated { fn full_form(segment: &str) -> Option<&[&str]> { Some(match segment { "ctx" => &["context"], @@ -253,13 +289,9 @@ fn join(fn_: Vec, find_module: bool) -> Punctuated { None => normalized.push(&*segment), } } - if !find_module { - return [Ident::new(&normalized.join("_"), Span::call_site())] - .into_iter() - .collect(); - } - if !MODULES.contains(&normalized[0]) { - let mut globalized = vec!["driver"]; + if let Some(default_module) = default_module { + if !MODULES.contains(&normalized[0]) { + let mut globalized = vec![default_module]; globalized.extend(normalized); normalized = globalized; } @@ -269,4 +301,10 @@ fn join(fn_: Vec, find_module: bool) -> Punctuated { .into_iter() .map(|s| Ident::new(s, Span::call_site())) .collect() + } else { + return [Ident::new(&normalized.join("_"), Span::call_site())] + .into_iter() + .collect(); + } + } diff --git a/zluda_blas/src/impl.rs b/zluda_blas/src/impl.rs index 627ecc0..feb95e2 100644 --- a/zluda_blas/src/impl.rs +++ b/zluda_blas/src/impl.rs @@ -10,23 +10,22 @@ pub(crate) fn unimplemented() -> cublasStatus_t { cublasStatus_t::ERROR_NOT_SUPPORTED } -#[allow(non_snake_case)] -pub fn cublasGetStatusName(_status: cuda_types::cublas::cublasStatus_t) -> *const ::core::ffi::c_char { +pub(crate) fn get_status_name( + _status: cublasStatus_t, +) -> *const ::core::ffi::c_char { todo!() } -#[allow(non_snake_case)] -pub fn cublasGetStatusString(_status: cuda_types::cublas::cublasStatus_t) -> *const ::core::ffi::c_char { +pub(crate) fn get_status_string( + _status: cublasStatus_t, +) -> *const ::core::ffi::c_char { todo!() } -#[allow(non_snake_case)] -pub fn cublasXerbla(_srName: *const ::core::ffi::c_char, _info: ::core::ffi::c_int) -> () { +pub(crate) fn xerbla(_sr_name: *const ::core::ffi::c_char, _info: ::core::ffi::c_int) -> () { todo!() } - -#[allow(non_snake_case)] -pub fn cublasGetCudartVersion() -> usize { +pub(crate) fn get_cudart_version() -> usize { todo!() } diff --git a/zluda_blas/src/lib.rs b/zluda_blas/src/lib.rs index 2f09536..ed86c01 100644 --- a/zluda_blas/src/lib.rs +++ b/zluda_blas/src/lib.rs @@ -20,7 +20,7 @@ macro_rules! implemented { #[allow(improper_ctypes)] #[allow(improper_ctypes_definitions)] pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { - crate::r#impl::$fn_name( $( $arg_id ),* ) + cuda_macros::cublas_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* ) } )* }; @@ -32,6 +32,6 @@ cuda_macros::cublas_function_declarations!( cublasGetStatusName, cublasGetStatusString, cublasXerbla, - cublasGetCudartVersion + cublasGetCudartVersion, ] ); diff --git a/zluda_blaslt/src/impl.rs b/zluda_blaslt/src/impl.rs index d2ec310..8b67915 100644 --- a/zluda_blaslt/src/impl.rs +++ b/zluda_blaslt/src/impl.rs @@ -10,32 +10,28 @@ pub(crate) fn unimplemented() -> cublasStatus_t { cublasStatus_t::ERROR_NOT_SUPPORTED } -#[allow(non_snake_case)] -pub(crate) fn cublasLtGetStatusName( +pub(crate) fn get_status_name( _status: cuda_types::cublas::cublasStatus_t, ) -> *const ::core::ffi::c_char { todo!() } -#[allow(non_snake_case)] -pub(crate) fn cublasLtGetStatusString( +pub(crate) fn get_status_string( _status: cuda_types::cublas::cublasStatus_t, ) -> *const ::core::ffi::c_char { todo!() } -#[allow(non_snake_case)] -pub(crate) fn cublasLtGetVersion() -> usize { +pub(crate) fn get_version() -> usize { + todo!() +} + +pub(crate) fn get_cudart_version() -> usize { todo!() } #[allow(non_snake_case)] -pub(crate) fn cublasLtGetCudartVersion() -> usize { - todo!() -} - -#[allow(non_snake_case)] -pub(crate) fn cublasLtDisableCpuInstructionsSetMask( +pub(crate) fn disable_cpu_instructions_set_mask( _mask: ::core::ffi::c_uint, ) -> ::core::ffi::c_uint { todo!() diff --git a/zluda_blaslt/src/lib.rs b/zluda_blaslt/src/lib.rs index 8cbab95..326ac0a 100644 --- a/zluda_blaslt/src/lib.rs +++ b/zluda_blaslt/src/lib.rs @@ -20,7 +20,7 @@ macro_rules! implemented { #[allow(improper_ctypes)] #[allow(improper_ctypes_definitions)] pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { - crate::r#impl::$fn_name( $( $arg_id ),* ) + cuda_macros::cublaslt_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* ) } )* }; diff --git a/zluda_dnn/src/impl.rs b/zluda_dnn/src/impl.rs index 1357224..69da3a4 100644 --- a/zluda_dnn/src/impl.rs +++ b/zluda_dnn/src/impl.rs @@ -10,25 +10,20 @@ pub(crate) fn unimplemented() -> cudnnStatus_t { cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED } -#[allow(non_snake_case)] -pub(crate) fn cudnnGetVersion() -> usize { +pub(crate) fn get_version() -> usize { todo!() } -#[allow(non_snake_case)] -pub(crate) fn cudnnGetMaxDeviceVersion() -> usize { +pub(crate) fn get_max_device_version() -> usize { todo!() } -#[allow(non_snake_case)] -pub(crate) fn cudnnGetCudartVersion() -> usize { +pub(crate) fn get_cudart_version() -> usize { todo!() } -#[allow(non_snake_case)] -pub(crate) fn cudnnGetErrorString( +pub(crate) fn get_error_string( _status: cuda_types::cudnn9::cudnnStatus_t, ) -> *const ::core::ffi::c_char { todo!() } -#[allow(non_snake_case)] -pub(crate) fn cudnnGetLastErrorString(_message: *mut ::core::ffi::c_char, _max_size: usize) -> () { +pub(crate) fn get_last_error_string(_message: *mut ::core::ffi::c_char, _max_size: usize) -> () { todo!() } diff --git a/zluda_dnn/src/lib.rs b/zluda_dnn/src/lib.rs index 2d25559..a744a59 100644 --- a/zluda_dnn/src/lib.rs +++ b/zluda_dnn/src/lib.rs @@ -20,7 +20,7 @@ macro_rules! implemented { #[allow(improper_ctypes)] #[allow(improper_ctypes_definitions)] pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { - crate::r#impl::$fn_name( $( $arg_id ),* ) + cuda_macros::cudnn_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* ) } )* }; diff --git a/zluda_ml/src/impl.rs b/zluda_ml/src/impl.rs index ea05633..accb048 100644 --- a/zluda_ml/src/impl.rs +++ b/zluda_ml/src/impl.rs @@ -11,22 +11,19 @@ pub(crate) fn unimplemented() -> nvmlReturn_t { nvmlReturn_t::ERROR_NOT_SUPPORTED } -#[allow(non_snake_case)] -pub(crate) fn nvmlErrorString( +pub(crate) fn error_string( _result: cuda_types::nvml::nvmlReturn_t, ) -> *const ::core::ffi::c_char { c"".as_ptr() } -#[allow(non_snake_case)] -pub(crate) fn nvmlInit_v2() -> cuda_types::nvml::nvmlReturn_t { +pub(crate) fn init_v2() -> cuda_types::nvml::nvmlReturn_t { nvmlReturn_t::SUCCESS } const VERSION: &'static CStr = c"550.77"; -#[allow(non_snake_case)] -pub(crate) fn nvmlSystemGetDriverVersion( +pub(crate) fn system_get_driver_version( result: *mut ::core::ffi::c_char, length: ::core::ffi::c_uint, ) -> cuda_types::nvml::nvmlReturn_t { diff --git a/zluda_ml/src/lib.rs b/zluda_ml/src/lib.rs index cbd1301..d65fae8 100644 --- a/zluda_ml/src/lib.rs +++ b/zluda_ml/src/lib.rs @@ -18,7 +18,7 @@ macro_rules! implemented_fn { #[no_mangle] #[allow(improper_ctypes_definitions)] pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { - r#impl::$fn_name($($arg_id),*) + cuda_macros::nvml_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* ) } )* }; diff --git a/zluda_sparse/src/impl.rs b/zluda_sparse/src/impl.rs index 691961c..39c4824 100644 --- a/zluda_sparse/src/impl.rs +++ b/zluda_sparse/src/impl.rs @@ -10,44 +10,38 @@ pub(crate) fn unimplemented() -> cusparseStatus_t { cusparseStatus_t::ERROR_NOT_SUPPORTED } -#[allow(non_snake_case)] -pub(crate) fn cusparseGetErrorName( +pub(crate) fn get_error_name( _status: cuda_types::cusparse::cusparseStatus_t, ) -> *const ::core::ffi::c_char { todo!() } -#[allow(non_snake_case)] -pub(crate) fn cusparseGetErrorString( +pub(crate) fn get_error_string( _status: cuda_types::cusparse::cusparseStatus_t, ) -> *const ::core::ffi::c_char { todo!() } -#[allow(non_snake_case)] -pub(crate) fn cusparseGetMatType( - _descrA: cuda_types::cusparse::cusparseMatDescr_t, +pub(crate) fn get_mat_type( + _descr_a: cuda_types::cusparse::cusparseMatDescr_t, ) -> cuda_types::cusparse::cusparseMatrixType_t { todo!() } -#[allow(non_snake_case)] -pub(crate) fn cusparseGetMatFillMode( - _descrA: cuda_types::cusparse::cusparseMatDescr_t, +pub(crate) fn get_mat_fill_mode( + _descr_a: cuda_types::cusparse::cusparseMatDescr_t, ) -> cuda_types::cusparse::cusparseFillMode_t { todo!() } -#[allow(non_snake_case)] -pub(crate) fn cusparseGetMatDiagType( - _descrA: cuda_types::cusparse::cusparseMatDescr_t, +pub(crate) fn get_mat_diag_type( + _descr_a: cuda_types::cusparse::cusparseMatDescr_t, ) -> cuda_types::cusparse::cusparseDiagType_t { todo!() } -#[allow(non_snake_case)] -pub(crate) fn cusparseGetMatIndexBase( - _descrA: cuda_types::cusparse::cusparseMatDescr_t, +pub(crate) fn get_mat_index_base( + _descr_a: cuda_types::cusparse::cusparseMatDescr_t, ) -> cuda_types::cusparse::cusparseIndexBase_t { todo!() } diff --git a/zluda_sparse/src/lib.rs b/zluda_sparse/src/lib.rs index 1c2ad6d..795f680 100644 --- a/zluda_sparse/src/lib.rs +++ b/zluda_sparse/src/lib.rs @@ -20,7 +20,7 @@ macro_rules! implemented { #[allow(improper_ctypes)] #[allow(improper_ctypes_definitions)] pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { - crate::r#impl::$fn_name( $( $arg_id ),* ) + cuda_macros::cusparse_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* ) } )* };