mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 06:47:46 +03:00
Use `normalize_fn` for performance libraries (#449)
The goal here is to make the performance library implementations work more like zluda.
This commit is contained in:
@ -203,8 +203,11 @@ const MODULES: &[&str] = &[
|
|||||||
"stream",
|
"stream",
|
||||||
];
|
];
|
||||||
|
|
||||||
#[proc_macro]
|
fn normalize_fn_impl(
|
||||||
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
|
prefix: &str,
|
||||||
|
default_module: Option<&str>,
|
||||||
|
tokens: TokenStream,
|
||||||
|
) -> TokenStream {
|
||||||
let mut path = parse_macro_input!(tokens as syn::Path);
|
let mut path = parse_macro_input!(tokens as syn::Path);
|
||||||
let fn_ = path
|
let fn_ = path
|
||||||
.segments
|
.segments
|
||||||
@ -215,14 +218,44 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
|
|||||||
.ident
|
.ident
|
||||||
.to_string();
|
.to_string();
|
||||||
let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string());
|
let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string());
|
||||||
let segments: Vec<String> = split(&fn_[2..]); // skip "cu"
|
let segments: Vec<String> = split(&fn_[prefix.len()..]); // skip the library prefix passed by the caller (e.g. "cu", "cublas")
|
||||||
let fn_path = join(segments, !already_has_module);
|
let fn_path = join(segments, default_module.filter(|_| !already_has_module));
|
||||||
quote! {
|
quote! {
|
||||||
#path #fn_path
|
#path #fn_path
|
||||||
}
|
}
|
||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[proc_macro]
|
||||||
|
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||||
|
normalize_fn_impl("cu", Some("driver"), tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[proc_macro]
|
||||||
|
pub fn cublas_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||||
|
normalize_fn_impl("cublas", None, tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[proc_macro]
|
||||||
|
pub fn cublaslt_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||||
|
normalize_fn_impl("cublasLt", None, tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[proc_macro]
|
||||||
|
pub fn cudnn_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||||
|
normalize_fn_impl("cudnn", None, tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[proc_macro]
|
||||||
|
pub fn cusparse_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||||
|
normalize_fn_impl("cusparse", None, tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[proc_macro]
|
||||||
|
pub fn nvml_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||||
|
normalize_fn_impl("nvml", None, tokens)
|
||||||
|
}
|
||||||
|
|
||||||
fn split(fn_: &str) -> Vec<String> {
|
fn split(fn_: &str) -> Vec<String> {
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
for c in fn_.chars() {
|
for c in fn_.chars() {
|
||||||
@ -235,7 +268,10 @@ fn split(fn_: &str) -> Vec<String> {
|
|||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
|
fn join(
|
||||||
|
fn_: Vec<String>,
|
||||||
|
default_module: Option<&str>,
|
||||||
|
) -> Punctuated<Ident, Token![::]> {
|
||||||
fn full_form(segment: &str) -> Option<&[&str]> {
|
fn full_form(segment: &str) -> Option<&[&str]> {
|
||||||
Some(match segment {
|
Some(match segment {
|
||||||
"ctx" => &["context"],
|
"ctx" => &["context"],
|
||||||
@ -253,13 +289,9 @@ fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
|
|||||||
None => normalized.push(&*segment),
|
None => normalized.push(&*segment),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !find_module {
|
if let Some(default_module) = default_module {
|
||||||
return [Ident::new(&normalized.join("_"), Span::call_site())]
|
if !MODULES.contains(&normalized[0]) {
|
||||||
.into_iter()
|
let mut globalized = vec![default_module];
|
||||||
.collect();
|
|
||||||
}
|
|
||||||
if !MODULES.contains(&normalized[0]) {
|
|
||||||
let mut globalized = vec!["driver"];
|
|
||||||
globalized.extend(normalized);
|
globalized.extend(normalized);
|
||||||
normalized = globalized;
|
normalized = globalized;
|
||||||
}
|
}
|
||||||
@ -269,4 +301,10 @@ fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|s| Ident::new(s, Span::call_site()))
|
.map(|s| Ident::new(s, Span::call_site()))
|
||||||
.collect()
|
.collect()
|
||||||
|
} else {
|
||||||
|
return [Ident::new(&normalized.join("_"), Span::call_site())]
|
||||||
|
.into_iter()
|
||||||
|
.collect();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -10,23 +10,22 @@ pub(crate) fn unimplemented() -> cublasStatus_t {
|
|||||||
cublasStatus_t::ERROR_NOT_SUPPORTED
|
cublasStatus_t::ERROR_NOT_SUPPORTED
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_status_name(
|
||||||
pub fn cublasGetStatusName(_status: cuda_types::cublas::cublasStatus_t) -> *const ::core::ffi::c_char {
|
_status: cublasStatus_t,
|
||||||
|
) -> *const ::core::ffi::c_char {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_status_string(
|
||||||
pub fn cublasGetStatusString(_status: cuda_types::cublas::cublasStatus_t) -> *const ::core::ffi::c_char {
|
_status: cublasStatus_t,
|
||||||
|
) -> *const ::core::ffi::c_char {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn xerbla(_sr_name: *const ::core::ffi::c_char, _info: ::core::ffi::c_int) -> () {
|
||||||
pub fn cublasXerbla(_srName: *const ::core::ffi::c_char, _info: ::core::ffi::c_int) -> () {
|
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_cudart_version() -> usize {
|
||||||
#[allow(non_snake_case)]
|
|
||||||
pub fn cublasGetCudartVersion() -> usize {
|
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@ macro_rules! implemented {
|
|||||||
#[allow(improper_ctypes)]
|
#[allow(improper_ctypes)]
|
||||||
#[allow(improper_ctypes_definitions)]
|
#[allow(improper_ctypes_definitions)]
|
||||||
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
||||||
crate::r#impl::$fn_name( $( $arg_id ),* )
|
cuda_macros::cublas_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
};
|
};
|
||||||
@ -32,6 +32,6 @@ cuda_macros::cublas_function_declarations!(
|
|||||||
cublasGetStatusName,
|
cublasGetStatusName,
|
||||||
cublasGetStatusString,
|
cublasGetStatusString,
|
||||||
cublasXerbla,
|
cublasXerbla,
|
||||||
cublasGetCudartVersion
|
cublasGetCudartVersion,
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
@ -10,32 +10,28 @@ pub(crate) fn unimplemented() -> cublasStatus_t {
|
|||||||
cublasStatus_t::ERROR_NOT_SUPPORTED
|
cublasStatus_t::ERROR_NOT_SUPPORTED
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_status_name(
|
||||||
pub(crate) fn cublasLtGetStatusName(
|
|
||||||
_status: cuda_types::cublas::cublasStatus_t,
|
_status: cuda_types::cublas::cublasStatus_t,
|
||||||
) -> *const ::core::ffi::c_char {
|
) -> *const ::core::ffi::c_char {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_status_string(
|
||||||
pub(crate) fn cublasLtGetStatusString(
|
|
||||||
_status: cuda_types::cublas::cublasStatus_t,
|
_status: cuda_types::cublas::cublasStatus_t,
|
||||||
) -> *const ::core::ffi::c_char {
|
) -> *const ::core::ffi::c_char {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_version() -> usize {
|
||||||
pub(crate) fn cublasLtGetVersion() -> usize {
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_cudart_version() -> usize {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
pub(crate) fn cublasLtGetCudartVersion() -> usize {
|
pub(crate) fn disable_cpu_instructions_set_mask(
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
|
||||||
pub(crate) fn cublasLtDisableCpuInstructionsSetMask(
|
|
||||||
_mask: ::core::ffi::c_uint,
|
_mask: ::core::ffi::c_uint,
|
||||||
) -> ::core::ffi::c_uint {
|
) -> ::core::ffi::c_uint {
|
||||||
todo!()
|
todo!()
|
||||||
|
@ -20,7 +20,7 @@ macro_rules! implemented {
|
|||||||
#[allow(improper_ctypes)]
|
#[allow(improper_ctypes)]
|
||||||
#[allow(improper_ctypes_definitions)]
|
#[allow(improper_ctypes_definitions)]
|
||||||
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
||||||
crate::r#impl::$fn_name( $( $arg_id ),* )
|
cuda_macros::cublaslt_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
};
|
};
|
||||||
|
@ -10,25 +10,20 @@ pub(crate) fn unimplemented() -> cudnnStatus_t {
|
|||||||
cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED
|
cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_version() -> usize {
|
||||||
pub(crate) fn cudnnGetVersion() -> usize {
|
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_max_device_version() -> usize {
|
||||||
pub(crate) fn cudnnGetMaxDeviceVersion() -> usize {
|
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_cudart_version() -> usize {
|
||||||
pub(crate) fn cudnnGetCudartVersion() -> usize {
|
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_error_string(
|
||||||
pub(crate) fn cudnnGetErrorString(
|
|
||||||
_status: cuda_types::cudnn9::cudnnStatus_t,
|
_status: cuda_types::cudnn9::cudnnStatus_t,
|
||||||
) -> *const ::core::ffi::c_char {
|
) -> *const ::core::ffi::c_char {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_last_error_string(_message: *mut ::core::ffi::c_char, _max_size: usize) -> () {
|
||||||
pub(crate) fn cudnnGetLastErrorString(_message: *mut ::core::ffi::c_char, _max_size: usize) -> () {
|
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@ macro_rules! implemented {
|
|||||||
#[allow(improper_ctypes)]
|
#[allow(improper_ctypes)]
|
||||||
#[allow(improper_ctypes_definitions)]
|
#[allow(improper_ctypes_definitions)]
|
||||||
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
||||||
crate::r#impl::$fn_name( $( $arg_id ),* )
|
cuda_macros::cudnn_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
};
|
};
|
||||||
|
@ -11,22 +11,19 @@ pub(crate) fn unimplemented() -> nvmlReturn_t {
|
|||||||
nvmlReturn_t::ERROR_NOT_SUPPORTED
|
nvmlReturn_t::ERROR_NOT_SUPPORTED
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn error_string(
|
||||||
pub(crate) fn nvmlErrorString(
|
|
||||||
_result: cuda_types::nvml::nvmlReturn_t,
|
_result: cuda_types::nvml::nvmlReturn_t,
|
||||||
) -> *const ::core::ffi::c_char {
|
) -> *const ::core::ffi::c_char {
|
||||||
c"".as_ptr()
|
c"".as_ptr()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn init_v2() -> cuda_types::nvml::nvmlReturn_t {
|
||||||
pub(crate) fn nvmlInit_v2() -> cuda_types::nvml::nvmlReturn_t {
|
|
||||||
nvmlReturn_t::SUCCESS
|
nvmlReturn_t::SUCCESS
|
||||||
}
|
}
|
||||||
|
|
||||||
const VERSION: &'static CStr = c"550.77";
|
const VERSION: &'static CStr = c"550.77";
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn system_get_driver_version(
|
||||||
pub(crate) fn nvmlSystemGetDriverVersion(
|
|
||||||
result: *mut ::core::ffi::c_char,
|
result: *mut ::core::ffi::c_char,
|
||||||
length: ::core::ffi::c_uint,
|
length: ::core::ffi::c_uint,
|
||||||
) -> cuda_types::nvml::nvmlReturn_t {
|
) -> cuda_types::nvml::nvmlReturn_t {
|
||||||
|
@ -18,7 +18,7 @@ macro_rules! implemented_fn {
|
|||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
#[allow(improper_ctypes_definitions)]
|
#[allow(improper_ctypes_definitions)]
|
||||||
pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
||||||
r#impl::$fn_name($($arg_id),*)
|
cuda_macros::nvml_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
};
|
};
|
||||||
|
@ -10,44 +10,38 @@ pub(crate) fn unimplemented() -> cusparseStatus_t {
|
|||||||
cusparseStatus_t::ERROR_NOT_SUPPORTED
|
cusparseStatus_t::ERROR_NOT_SUPPORTED
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_error_name(
|
||||||
pub(crate) fn cusparseGetErrorName(
|
|
||||||
_status: cuda_types::cusparse::cusparseStatus_t,
|
_status: cuda_types::cusparse::cusparseStatus_t,
|
||||||
) -> *const ::core::ffi::c_char {
|
) -> *const ::core::ffi::c_char {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_error_string(
|
||||||
pub(crate) fn cusparseGetErrorString(
|
|
||||||
_status: cuda_types::cusparse::cusparseStatus_t,
|
_status: cuda_types::cusparse::cusparseStatus_t,
|
||||||
) -> *const ::core::ffi::c_char {
|
) -> *const ::core::ffi::c_char {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_mat_type(
|
||||||
pub(crate) fn cusparseGetMatType(
|
_descr_a: cuda_types::cusparse::cusparseMatDescr_t,
|
||||||
_descrA: cuda_types::cusparse::cusparseMatDescr_t,
|
|
||||||
) -> cuda_types::cusparse::cusparseMatrixType_t {
|
) -> cuda_types::cusparse::cusparseMatrixType_t {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_mat_fill_mode(
|
||||||
pub(crate) fn cusparseGetMatFillMode(
|
_descr_a: cuda_types::cusparse::cusparseMatDescr_t,
|
||||||
_descrA: cuda_types::cusparse::cusparseMatDescr_t,
|
|
||||||
) -> cuda_types::cusparse::cusparseFillMode_t {
|
) -> cuda_types::cusparse::cusparseFillMode_t {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_mat_diag_type(
|
||||||
pub(crate) fn cusparseGetMatDiagType(
|
_descr_a: cuda_types::cusparse::cusparseMatDescr_t,
|
||||||
_descrA: cuda_types::cusparse::cusparseMatDescr_t,
|
|
||||||
) -> cuda_types::cusparse::cusparseDiagType_t {
|
) -> cuda_types::cusparse::cusparseDiagType_t {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
pub(crate) fn get_mat_index_base(
|
||||||
pub(crate) fn cusparseGetMatIndexBase(
|
_descr_a: cuda_types::cusparse::cusparseMatDescr_t,
|
||||||
_descrA: cuda_types::cusparse::cusparseMatDescr_t,
|
|
||||||
) -> cuda_types::cusparse::cusparseIndexBase_t {
|
) -> cuda_types::cusparse::cusparseIndexBase_t {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@ macro_rules! implemented {
|
|||||||
#[allow(improper_ctypes)]
|
#[allow(improper_ctypes)]
|
||||||
#[allow(improper_ctypes_definitions)]
|
#[allow(improper_ctypes_definitions)]
|
||||||
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
||||||
crate::r#impl::$fn_name( $( $arg_id ),* )
|
cuda_macros::cusparse_normalize_fn!( crate::r#impl::$fn_name ) ( $( $arg_id ),* )
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
};
|
};
|
||||||
|
Reference in New Issue
Block a user