diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs index d5816e8..91f3dc2 100644 --- a/zluda_bindgen/src/main.rs +++ b/zluda_bindgen/src/main.rs @@ -35,7 +35,7 @@ fn main() { generate_cublaslt(&crate_root); generate_cufft(&crate_root); generate_cusparse(&crate_root); - generate_cudnn(&crate_root); + // generate_cudnn(&crate_root); } fn generate_process_address_table(crate_root: &PathBuf, mut cuda_fns: Vec) { @@ -151,6 +151,7 @@ fn generate_cufft(crate_root: &PathBuf) { .allowlist_function("^cufft.*") .allowlist_var("^CUFFT_.*") .must_use_type("cufftResult_t") + .constified_enum("cufftResult_t") .allowlist_recursively(false) .clang_args(["-I/usr/local/cuda/include"]) .generate() @@ -163,7 +164,15 @@ fn generate_cufft(crate_root: &PathBuf) { &["..", "cuda_macros", "src", "cufft.rs"], &module, ); + let convert_options = ConvertIntoRustResultOptions { + type_: "cufftResult", + underlying_type: "cufftResult_t", + new_error_type: "cufftError_t", + error_prefix: ("CUFFT_", "ERROR_"), + success: ("CUFFT_SUCCESS", "SUCCESS"), + }; generate_types_library( + Some(convert_options), Some(LibraryOverride::CuFft), &crate_root, &["..", "cuda_types", "src", "cufft.rs"], @@ -206,6 +215,7 @@ fn generate_cusparse(crate_root: &PathBuf) { .allowlist_function("^cusparse.*") .allowlist_var("^CUSPARSE_.*") .must_use_type("cusparseStatus_t") + .constified_enum("cusparseStatus_t") .allowlist_recursively(false) .clang_args(["-I/usr/local/cuda/include"]) .generate() @@ -218,7 +228,15 @@ fn generate_cusparse(crate_root: &PathBuf) { &["..", "cuda_macros", "src", "cusparse.rs"], &module, ); + let convert_options = ConvertIntoRustResultOptions { + type_: "cusparseStatus_t", + underlying_type: "cusparseStatus_t", + new_error_type: "cusparseError_t", + error_prefix: ("CUSPARSE_STATUS_", "ERROR_"), + success: ("CUSPARSE_STATUS_SUCCESS", "SUCCESS"), + }; generate_types_library( + Some(convert_options), None, &crate_root, &["..", "cuda_types", "src", "cusparse.rs"], @@ -239,13 +257,21 @@ fn generate_cudnn(crate_root: &PathBuf) { .allowlist_function("^cudnn.*") .allowlist_var("^CUDNN_.*") .must_use_type("cudnnStatus_t") + .constified_enum("cudnnStatus_t") .allowlist_recursively(false) .clang_args(["-I/usr/local/cuda/include"]) .generate() .unwrap() .to_string(); + let convert_options = ConvertIntoRustResultOptions { + type_: "cudnnStatus_t", + underlying_type: "cudnnStatus_t", + new_error_type: "cudnnError_", + error_prefix: ("CUDNN_STATUS_", "ERROR_"), + success: ("CUDNN_STATUS_SUCCESS", "SUCCESS"), + }; let cudnn9_module: syn::File = syn::parse_str(&cudnn9).unwrap(); - let cudnn9_types = generate_types_library_impl(&cudnn9_module); + let cudnn9_types = generate_types_library_impl(Some(convert_options.clone()), &cudnn9_module); let mut current_dir = PathBuf::from(file!()); current_dir.pop(); let cudnn8 = new_builder() @@ -263,7 +289,7 @@ fn generate_cudnn(crate_root: &PathBuf) { .unwrap() .to_string(); let cudnn8_module: syn::File = syn::parse_str(&cudnn8).unwrap(); - let cudnn8_types = generate_types_library_impl(&cudnn8_module); + let cudnn8_types = generate_types_library_impl(Some(convert_options), &cudnn8_module); merge_types( &crate_root, &["..", "cuda_types", "src", "cudnn.rs"], @@ -624,6 +650,7 @@ fn generate_cublas(crate_root: &PathBuf) { .allowlist_function("^cublas.*") .allowlist_var("^CUBLAS_.*") .must_use_type("cublasStatus_t") + .constified_enum("cublasStatus_t") .allowlist_recursively(false) .clang_args(["-I/usr/local/cuda/include", "-x", "c++"]) .generate() @@ -636,7 +663,15 @@ fn generate_cublas(crate_root: &PathBuf) { &["..", "cuda_macros", "src", "cublas.rs"], &module, ); + let convert_options = ConvertIntoRustResultOptions { + type_: "cublasStatus_t", + underlying_type: "cublasStatus_t", + new_error_type: "cublasError_t", + error_prefix: ("CUBLAS_STATUS_", "ERROR_"), + success: ("CUBLAS_STATUS_SUCCESS", "SUCCESS"), + }; generate_types_library( + Some(convert_options), None, &crate_root, &["..", "cuda_types", "src", "cublas.rs"], @@ -708,6 +743,7 @@ fn generate_cublaslt(crate_root: &PathBuf) { &module_blas, ); generate_types_library( + None, Some(LibraryOverride::CuBlasLt), &crate_root, &["..", "cuda_types", "src", "cublaslt.rs"], @@ -782,32 +818,22 @@ fn generate_ml(crate_root: &PathBuf) { .generate() .unwrap() .to_string(); - let mut module: syn::File = syn::parse_str(&ml_header).unwrap(); - let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions { - type_: "nvmlReturn_t", - underlying_type: "nvmlReturn_enum", - new_error_type: "nvmlError_t", - error_prefix: ("NVML_ERROR_", "ERROR_"), - success: ("NVML_SUCCESS", "SUCCESS"), - }); - module.items = module - .items - .into_iter() - .filter_map(|item| match item { - Item::Const(const_) => converter.get_const(const_).map(Item::Const), - Item::Use(use_) => converter.get_use(use_).map(Item::Use), - Item::Type(type_) => converter.get_type(type_).map(Item::Type), - item => Some(item), - }) - .collect::>(); - converter.flush(&mut module.items); + let module: syn::File = syn::parse_str(&ml_header).unwrap(); generate_functions( &crate_root, "nvml", &["..", "cuda_macros", "src", "nvml.rs"], &module, ); + let convert_options = ConvertIntoRustResultOptions { + type_: "nvmlReturn_t", + underlying_type: "nvmlReturn_enum", + new_error_type: "nvmlError_t", + error_prefix: ("NVML_ERROR_", "ERROR_"), + success: ("NVML_SUCCESS", "SUCCESS"), + }; generate_types_library( + Some(convert_options), None, &crate_root, &["..", "cuda_types", "src", "nvml.rs"], @@ -816,12 +842,13 @@ fn generate_ml(crate_root: &PathBuf) { } fn generate_types_library( + convert_options: Option, override_: Option, crate_root: &PathBuf, path: &[&str], module: &syn::File, ) { - let module = generate_types_library_impl(module); + let module = generate_types_library_impl(convert_options, module); let mut output = crate_root.clone(); output.extend(path); let mut text = @@ -846,7 +873,10 @@ enum LibraryOverride { CuFft, } -fn generate_types_library_impl(module: &syn::File) -> syn::File { +fn generate_types_library_impl( + convert_options: Option, + module: &syn::File, +) -> syn::File { let known_reexports: Punctuated = parse_quote! { pub type __half = u16; pub type __nv_bfloat16 = u16; @@ -860,11 +890,28 @@ fn generate_types_library_impl(module: &syn::File) -> syn::File { pub type cudaAsyncNotificationType = super::cuda::CUasyncNotificationType_enum; pub type cudaGraph_t = super::cuda::CUgraph; }; - let non_fn = module.items.iter().filter_map(|item| match item { + let remove_functions = |item| match item { Item::ForeignMod(_) => None, _ => Some(item), - }); - let items = known_reexports.iter().chain(non_fn); + }; + let non_fn = if let Some(options) = convert_options { + let mut converter = ConvertIntoRustResult::new(options); + let mut non_fn = converter + .convert(module.items.clone()) + .filter_map(remove_functions) + .collect::>(); + converter.flush(&mut non_fn); + non_fn + } else { + let non_fn = module + .items + .clone() + .into_iter() + .filter_map(remove_functions) + .collect::>(); + non_fn + }; + let items = known_reexports.into_iter().chain(non_fn); parse_quote! { #(#items)* } @@ -894,16 +941,7 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) { error_prefix: ("hipError", "Error"), success: ("hipSuccess", "Success"), }); - module.items = module - .items - .into_iter() - .filter_map(|item| match item { - Item::Const(const_) => converter.get_const(const_).map(Item::Const), - Item::Use(use_) => converter.get_use(use_).map(Item::Use), - Item::Type(type_) => converter.get_type(type_).map(Item::Type), - item => Some(item), - }) - .collect::>(); + module.items = converter.convert(module.items).collect::>(); converter.flush(&mut module.items); add_send_sync( &mut module.items, @@ -1004,14 +1042,10 @@ fn generate_types_cuda(output: &PathBuf, path: &[&str], module: &syn::File) { error_prefix: ("CUDA_ERROR_", "ERROR_"), success: ("CUDA_SUCCESS", "SUCCESS"), }); - module.items = module - .items - .into_iter() + module.items = converter + .convert(module.items) .filter_map(|item| match item { Item::ForeignMod(_) => None, - Item::Const(const_) => converter.get_const(const_).map(Item::Const), - Item::Use(use_) => converter.get_use(use_).map(Item::Use), - Item::Type(type_) => converter.get_type(type_).map(Item::Type), Item::Struct(mut struct_) => { let ident_string = struct_.ident.to_string(); match &*ident_string { @@ -1064,6 +1098,7 @@ fn write_rust_to_file(path: impl AsRef, content: &str) { file.write(content.as_bytes()).unwrap(); } +#[derive(Clone)] struct ConvertIntoRustResultOptions { type_: &'static str, underlying_type: &'static str, @@ -1165,6 +1200,15 @@ impl ConvertIntoRustResult { Some(type_) } } + + fn convert(&mut self, items: Vec) -> impl Iterator + use<'_> { + items.into_iter().filter_map(|item| match item { + Item::Const(const_) => self.get_const(const_).map(Item::Const), + Item::Use(use_) => self.get_use(use_).map(Item::Use), + Item::Type(type_) => self.get_type(type_).map(Item::Type), + item => Some(item), + }) + } } struct FixAbi; @@ -1324,6 +1368,7 @@ fn generate_display_perflib( .iter() .filter_map(|i| cuda_derive_display_trait_for_item(types_crate, &mut derive_state, i)) .collect::>(); + // TODO: derive result display trait (look at curesult_display_trait) let mut output = output.clone(); output.extend(path); write_rust_to_file( diff --git a/zluda_fft/src/impl.rs b/zluda_fft/src/impl.rs index ece814e..94cc02b 100644 --- a/zluda_fft/src/impl.rs +++ b/zluda_fft/src/impl.rs @@ -1,11 +1,11 @@ -use cuda_types::cufft::cufftResult_t; +use cuda_types::cufft::*; #[cfg(debug_assertions)] -pub(crate) fn unimplemented() -> cufftResult_t { +pub(crate) fn unimplemented() -> cufftResult { unimplemented!() } #[cfg(not(debug_assertions))] -pub(crate) fn unimplemented() -> cufftResult_t { - cufftResult_t::CUFFT_NOT_SUPPORTED +pub(crate) fn unimplemented() -> cufftResult { + cufftResult::ERROR_NOT_SUPPORTED }