Use normalize_fn for performance libraries (#449)

The goal here is to make the performance library implementations work more like zluda.
This commit is contained in:
Violet
2025-07-30 14:02:01 -07:00
committed by GitHub
parent c07d7678cd
commit 98b601d15a
11 changed files with 90 additions and 71 deletions

View File

@ -203,8 +203,11 @@ const MODULES: &[&str] = &[
"stream",
];
#[proc_macro]
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
fn normalize_fn_impl(
prefix: &str,
default_module: Option<&str>,
tokens: TokenStream,
) -> TokenStream {
let mut path = parse_macro_input!(tokens as syn::Path);
let fn_ = path
.segments
@ -215,14 +218,44 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
.ident
.to_string();
let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string());
let segments: Vec<String> = split(&fn_[2..]); // skip "cu"
let fn_path = join(segments, !already_has_module);
let segments: Vec<String> = split(&fn_[prefix.len()..]); // skip the library prefix (e.g. "cu", "cublas"); sliced by length only, the name is not checked to start with it
let fn_path = join(segments, default_module.filter(|_| !already_has_module));
quote! {
#path #fn_path
}
.into()
}
/// Rewrites the path of a CUDA driver API function (name prefixed with `cu`).
///
/// Delegates to `normalize_fn_impl` with the `"cu"` prefix and `"driver"` as
/// the default module. Unless the input path already ends in one of `MODULES`,
/// `join` prepends `driver` when the normalized name does not itself start
/// with a known module.
#[proc_macro]
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
normalize_fn_impl("cu", Some("driver"), tokens)
}
/// Rewrites the path of a cuBLAS function (name prefixed with `cublas`).
///
/// Delegates to `normalize_fn_impl` with the `"cublas"` prefix and no default
/// module, so `join` emits the normalized name as a single `_`-joined
/// identifier instead of a module path.
/// NOTE(review): e.g. `cublasGetStatusName` is expected to become
/// `get_status_name` — confirm against `split`.
#[proc_macro]
pub fn cublas_normalize_fn(tokens: TokenStream) -> TokenStream {
normalize_fn_impl("cublas", None, tokens)
}
/// Rewrites the path of a cuBLASLt function (name prefixed with `cublasLt`).
///
/// Delegates to `normalize_fn_impl` with the `"cublasLt"` prefix and no default
/// module, so `join` emits the normalized name as a single `_`-joined
/// identifier instead of a module path.
/// NOTE(review): e.g. `cublasLtGetVersion` is expected to become
/// `get_version` — confirm against `split`.
#[proc_macro]
pub fn cublaslt_normalize_fn(tokens: TokenStream) -> TokenStream {
normalize_fn_impl("cublasLt", None, tokens)
}
/// Rewrites the path of a cuDNN function (name prefixed with `cudnn`).
///
/// Delegates to `normalize_fn_impl` with the `"cudnn"` prefix and no default
/// module, so `join` emits the normalized name as a single `_`-joined
/// identifier instead of a module path.
/// NOTE(review): e.g. `cudnnGetMaxDeviceVersion` is expected to become
/// `get_max_device_version` — confirm against `split`.
#[proc_macro]
pub fn cudnn_normalize_fn(tokens: TokenStream) -> TokenStream {
normalize_fn_impl("cudnn", None, tokens)
}
/// Rewrites the path of a cuSPARSE function (name prefixed with `cusparse`).
///
/// Delegates to `normalize_fn_impl` with the `"cusparse"` prefix and no default
/// module, so `join` emits the normalized name as a single `_`-joined
/// identifier instead of a module path.
/// NOTE(review): e.g. `cusparseGetMatType` is expected to become
/// `get_mat_type` — confirm against `split`.
#[proc_macro]
pub fn cusparse_normalize_fn(tokens: TokenStream) -> TokenStream {
normalize_fn_impl("cusparse", None, tokens)
}
/// Rewrites the path of an NVML function (name prefixed with `nvml`).
///
/// Delegates to `normalize_fn_impl` with the `"nvml"` prefix and no default
/// module, so `join` emits the normalized name as a single `_`-joined
/// identifier instead of a module path.
/// NOTE(review): e.g. `nvmlSystemGetDriverVersion` is expected to become
/// `system_get_driver_version` — confirm against `split`.
#[proc_macro]
pub fn nvml_normalize_fn(tokens: TokenStream) -> TokenStream {
normalize_fn_impl("nvml", None, tokens)
}
fn split(fn_: &str) -> Vec<String> {
let mut result = Vec::new();
for c in fn_.chars() {
@ -235,7 +268,10 @@ fn split(fn_: &str) -> Vec<String> {
result
}
fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
fn join(
fn_: Vec<String>,
default_module: Option<&str>,
) -> Punctuated<Ident, Token![::]> {
fn full_form(segment: &str) -> Option<&[&str]> {
Some(match segment {
"ctx" => &["context"],
@ -253,13 +289,9 @@ fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
None => normalized.push(&*segment),
}
}
if !find_module {
return [Ident::new(&normalized.join("_"), Span::call_site())]
.into_iter()
.collect();
}
if let Some(default_module) = default_module {
if !MODULES.contains(&normalized[0]) {
let mut globalized = vec!["driver"];
let mut globalized = vec![default_module];
globalized.extend(normalized);
normalized = globalized;
}
@ -269,4 +301,10 @@ fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
.into_iter()
.map(|s| Ident::new(s, Span::call_site()))
.collect()
} else {
return [Ident::new(&normalized.join("_"), Span::call_site())]
.into_iter()
.collect();
}
}

View File

@ -10,23 +10,22 @@ pub(crate) fn unimplemented() -> cublasStatus_t {
cublasStatus_t::ERROR_NOT_SUPPORTED
}
#[allow(non_snake_case)]
pub fn cublasGetStatusName(_status: cuda_types::cublas::cublasStatus_t) -> *const ::core::ffi::c_char {
pub(crate) fn get_status_name(
_status: cublasStatus_t,
) -> *const ::core::ffi::c_char {
todo!()
}
#[allow(non_snake_case)]
pub fn cublasGetStatusString(_status: cuda_types::cublas::cublasStatus_t) -> *const ::core::ffi::c_char {
pub(crate) fn get_status_string(
_status: cublasStatus_t,
) -> *const ::core::ffi::c_char {
todo!()
}
#[allow(non_snake_case)]
pub fn cublasXerbla(_srName: *const ::core::ffi::c_char, _info: ::core::ffi::c_int) -> () {
pub(crate) fn xerbla(_sr_name: *const ::core::ffi::c_char, _info: ::core::ffi::c_int) -> () {
todo!()
}
#[allow(non_snake_case)]
pub fn cublasGetCudartVersion() -> usize {
pub(crate) fn get_cudart_version() -> usize {
todo!()
}

View File

@ -20,7 +20,7 @@ macro_rules! implemented {
#[allow(improper_ctypes)]
#[allow(improper_ctypes_definitions)]
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
crate::r#impl::$fn_name( $( $arg_id ),* )
cuda_macros::cublas_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
}
)*
};
@ -32,6 +32,6 @@ cuda_macros::cublas_function_declarations!(
cublasGetStatusName,
cublasGetStatusString,
cublasXerbla,
cublasGetCudartVersion
cublasGetCudartVersion,
]
);

View File

@ -10,32 +10,28 @@ pub(crate) fn unimplemented() -> cublasStatus_t {
cublasStatus_t::ERROR_NOT_SUPPORTED
}
#[allow(non_snake_case)]
pub(crate) fn cublasLtGetStatusName(
pub(crate) fn get_status_name(
_status: cuda_types::cublas::cublasStatus_t,
) -> *const ::core::ffi::c_char {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cublasLtGetStatusString(
pub(crate) fn get_status_string(
_status: cuda_types::cublas::cublasStatus_t,
) -> *const ::core::ffi::c_char {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cublasLtGetVersion() -> usize {
pub(crate) fn get_version() -> usize {
todo!()
}
pub(crate) fn get_cudart_version() -> usize {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cublasLtGetCudartVersion() -> usize {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cublasLtDisableCpuInstructionsSetMask(
pub(crate) fn disable_cpu_instructions_set_mask(
_mask: ::core::ffi::c_uint,
) -> ::core::ffi::c_uint {
todo!()

View File

@ -20,7 +20,7 @@ macro_rules! implemented {
#[allow(improper_ctypes)]
#[allow(improper_ctypes_definitions)]
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
crate::r#impl::$fn_name( $( $arg_id ),* )
cuda_macros::cublaslt_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
}
)*
};

View File

@ -10,25 +10,20 @@ pub(crate) fn unimplemented() -> cudnnStatus_t {
cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED
}
#[allow(non_snake_case)]
pub(crate) fn cudnnGetVersion() -> usize {
pub(crate) fn get_version() -> usize {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cudnnGetMaxDeviceVersion() -> usize {
pub(crate) fn get_max_device_version() -> usize {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cudnnGetCudartVersion() -> usize {
pub(crate) fn get_cudart_version() -> usize {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cudnnGetErrorString(
pub(crate) fn get_error_string(
_status: cuda_types::cudnn9::cudnnStatus_t,
) -> *const ::core::ffi::c_char {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cudnnGetLastErrorString(_message: *mut ::core::ffi::c_char, _max_size: usize) -> () {
pub(crate) fn get_last_error_string(_message: *mut ::core::ffi::c_char, _max_size: usize) -> () {
todo!()
}

View File

@ -20,7 +20,7 @@ macro_rules! implemented {
#[allow(improper_ctypes)]
#[allow(improper_ctypes_definitions)]
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
crate::r#impl::$fn_name( $( $arg_id ),* )
cuda_macros::cudnn_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
}
)*
};

View File

@ -11,22 +11,19 @@ pub(crate) fn unimplemented() -> nvmlReturn_t {
nvmlReturn_t::ERROR_NOT_SUPPORTED
}
#[allow(non_snake_case)]
pub(crate) fn nvmlErrorString(
pub(crate) fn error_string(
_result: cuda_types::nvml::nvmlReturn_t,
) -> *const ::core::ffi::c_char {
c"".as_ptr()
}
#[allow(non_snake_case)]
pub(crate) fn nvmlInit_v2() -> cuda_types::nvml::nvmlReturn_t {
pub(crate) fn init_v2() -> cuda_types::nvml::nvmlReturn_t {
nvmlReturn_t::SUCCESS
}
const VERSION: &'static CStr = c"550.77";
#[allow(non_snake_case)]
pub(crate) fn nvmlSystemGetDriverVersion(
pub(crate) fn system_get_driver_version(
result: *mut ::core::ffi::c_char,
length: ::core::ffi::c_uint,
) -> cuda_types::nvml::nvmlReturn_t {

View File

@ -18,7 +18,7 @@ macro_rules! implemented_fn {
#[no_mangle]
#[allow(improper_ctypes_definitions)]
pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
r#impl::$fn_name($($arg_id),*)
cuda_macros::nvml_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
}
)*
};

View File

@ -10,44 +10,38 @@ pub(crate) fn unimplemented() -> cusparseStatus_t {
cusparseStatus_t::ERROR_NOT_SUPPORTED
}
#[allow(non_snake_case)]
pub(crate) fn cusparseGetErrorName(
pub(crate) fn get_error_name(
_status: cuda_types::cusparse::cusparseStatus_t,
) -> *const ::core::ffi::c_char {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cusparseGetErrorString(
pub(crate) fn get_error_string(
_status: cuda_types::cusparse::cusparseStatus_t,
) -> *const ::core::ffi::c_char {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cusparseGetMatType(
_descrA: cuda_types::cusparse::cusparseMatDescr_t,
pub(crate) fn get_mat_type(
_descr_a: cuda_types::cusparse::cusparseMatDescr_t,
) -> cuda_types::cusparse::cusparseMatrixType_t {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cusparseGetMatFillMode(
_descrA: cuda_types::cusparse::cusparseMatDescr_t,
pub(crate) fn get_mat_fill_mode(
_descr_a: cuda_types::cusparse::cusparseMatDescr_t,
) -> cuda_types::cusparse::cusparseFillMode_t {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cusparseGetMatDiagType(
_descrA: cuda_types::cusparse::cusparseMatDescr_t,
pub(crate) fn get_mat_diag_type(
_descr_a: cuda_types::cusparse::cusparseMatDescr_t,
) -> cuda_types::cusparse::cusparseDiagType_t {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn cusparseGetMatIndexBase(
_descrA: cuda_types::cusparse::cusparseMatDescr_t,
pub(crate) fn get_mat_index_base(
_descr_a: cuda_types::cusparse::cusparseMatDescr_t,
) -> cuda_types::cusparse::cusparseIndexBase_t {
todo!()
}

View File

@ -20,7 +20,7 @@ macro_rules! implemented {
#[allow(improper_ctypes)]
#[allow(improper_ctypes_definitions)]
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
crate::r#impl::$fn_name( $( $arg_id ),* )
cuda_macros::cusparse_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
}
)*
};