Rustify performance library return types

This commit is contained in:
Violet
2025-07-26 23:12:56 +00:00
parent 01552367dc
commit 4e383406b9
2 changed files with 92 additions and 47 deletions

View File

@ -35,7 +35,7 @@ fn main() {
generate_cublaslt(&crate_root); generate_cublaslt(&crate_root);
generate_cufft(&crate_root); generate_cufft(&crate_root);
generate_cusparse(&crate_root); generate_cusparse(&crate_root);
generate_cudnn(&crate_root); // generate_cudnn(&crate_root);
} }
fn generate_process_address_table(crate_root: &PathBuf, mut cuda_fns: Vec<Ident>) { fn generate_process_address_table(crate_root: &PathBuf, mut cuda_fns: Vec<Ident>) {
@ -151,6 +151,7 @@ fn generate_cufft(crate_root: &PathBuf) {
.allowlist_function("^cufft.*") .allowlist_function("^cufft.*")
.allowlist_var("^CUFFT_.*") .allowlist_var("^CUFFT_.*")
.must_use_type("cufftResult_t") .must_use_type("cufftResult_t")
.constified_enum("cufftResult_t")
.allowlist_recursively(false) .allowlist_recursively(false)
.clang_args(["-I/usr/local/cuda/include"]) .clang_args(["-I/usr/local/cuda/include"])
.generate() .generate()
@ -163,7 +164,15 @@ fn generate_cufft(crate_root: &PathBuf) {
&["..", "cuda_macros", "src", "cufft.rs"], &["..", "cuda_macros", "src", "cufft.rs"],
&module, &module,
); );
let convert_options = ConvertIntoRustResultOptions {
type_: "cufftResult",
underlying_type: "cufftResult_t",
new_error_type: "cufftError_t",
error_prefix: ("CUFFT_", "ERROR_"),
success: ("CUFFT_SUCCESS", "SUCCESS"),
};
generate_types_library( generate_types_library(
Some(convert_options),
Some(LibraryOverride::CuFft), Some(LibraryOverride::CuFft),
&crate_root, &crate_root,
&["..", "cuda_types", "src", "cufft.rs"], &["..", "cuda_types", "src", "cufft.rs"],
@ -206,6 +215,7 @@ fn generate_cusparse(crate_root: &PathBuf) {
.allowlist_function("^cusparse.*") .allowlist_function("^cusparse.*")
.allowlist_var("^CUSPARSE_.*") .allowlist_var("^CUSPARSE_.*")
.must_use_type("cusparseStatus_t") .must_use_type("cusparseStatus_t")
.constified_enum("cusparseStatus_t")
.allowlist_recursively(false) .allowlist_recursively(false)
.clang_args(["-I/usr/local/cuda/include"]) .clang_args(["-I/usr/local/cuda/include"])
.generate() .generate()
@ -218,7 +228,15 @@ fn generate_cusparse(crate_root: &PathBuf) {
&["..", "cuda_macros", "src", "cusparse.rs"], &["..", "cuda_macros", "src", "cusparse.rs"],
&module, &module,
); );
let convert_options = ConvertIntoRustResultOptions {
type_: "cusparseStatus_t",
underlying_type: "cusparseStatus_t",
new_error_type: "cusparseError_t",
error_prefix: ("CUSPARSE_STATUS_", "ERROR_"),
success: ("CUSPARSE_STATUS_SUCCESS", "SUCCESS"),
};
generate_types_library( generate_types_library(
Some(convert_options),
None, None,
&crate_root, &crate_root,
&["..", "cuda_types", "src", "cusparse.rs"], &["..", "cuda_types", "src", "cusparse.rs"],
@ -239,13 +257,21 @@ fn generate_cudnn(crate_root: &PathBuf) {
.allowlist_function("^cudnn.*") .allowlist_function("^cudnn.*")
.allowlist_var("^CUDNN_.*") .allowlist_var("^CUDNN_.*")
.must_use_type("cudnnStatus_t") .must_use_type("cudnnStatus_t")
.constified_enum("cudnnStatus_t")
.allowlist_recursively(false) .allowlist_recursively(false)
.clang_args(["-I/usr/local/cuda/include"]) .clang_args(["-I/usr/local/cuda/include"])
.generate() .generate()
.unwrap() .unwrap()
.to_string(); .to_string();
let convert_options = ConvertIntoRustResultOptions {
type_: "cudnnStatus_t",
underlying_type: "cudnnStatus_t",
new_error_type: "cudnnError_",
error_prefix: ("CUDNN_STATUS_", "ERROR_"),
success: ("CUDNN_STATUS_SUCCESS", "SUCCESS"),
};
let cudnn9_module: syn::File = syn::parse_str(&cudnn9).unwrap(); let cudnn9_module: syn::File = syn::parse_str(&cudnn9).unwrap();
let cudnn9_types = generate_types_library_impl(&cudnn9_module); let cudnn9_types = generate_types_library_impl(Some(convert_options.clone()), &cudnn9_module);
let mut current_dir = PathBuf::from(file!()); let mut current_dir = PathBuf::from(file!());
current_dir.pop(); current_dir.pop();
let cudnn8 = new_builder() let cudnn8 = new_builder()
@ -263,7 +289,7 @@ fn generate_cudnn(crate_root: &PathBuf) {
.unwrap() .unwrap()
.to_string(); .to_string();
let cudnn8_module: syn::File = syn::parse_str(&cudnn8).unwrap(); let cudnn8_module: syn::File = syn::parse_str(&cudnn8).unwrap();
let cudnn8_types = generate_types_library_impl(&cudnn8_module); let cudnn8_types = generate_types_library_impl(Some(convert_options), &cudnn8_module);
merge_types( merge_types(
&crate_root, &crate_root,
&["..", "cuda_types", "src", "cudnn.rs"], &["..", "cuda_types", "src", "cudnn.rs"],
@ -624,6 +650,7 @@ fn generate_cublas(crate_root: &PathBuf) {
.allowlist_function("^cublas.*") .allowlist_function("^cublas.*")
.allowlist_var("^CUBLAS_.*") .allowlist_var("^CUBLAS_.*")
.must_use_type("cublasStatus_t") .must_use_type("cublasStatus_t")
.constified_enum("cublasStatus_t")
.allowlist_recursively(false) .allowlist_recursively(false)
.clang_args(["-I/usr/local/cuda/include", "-x", "c++"]) .clang_args(["-I/usr/local/cuda/include", "-x", "c++"])
.generate() .generate()
@ -636,7 +663,15 @@ fn generate_cublas(crate_root: &PathBuf) {
&["..", "cuda_macros", "src", "cublas.rs"], &["..", "cuda_macros", "src", "cublas.rs"],
&module, &module,
); );
let convert_options = ConvertIntoRustResultOptions {
type_: "cublasStatus_t",
underlying_type: "cublasStatus_t",
new_error_type: "cublasError_t",
error_prefix: ("CUBLAS_STATUS_", "ERROR_"),
success: ("CUBLAS_STATUS_SUCCESS", "SUCCESS"),
};
generate_types_library( generate_types_library(
Some(convert_options),
None, None,
&crate_root, &crate_root,
&["..", "cuda_types", "src", "cublas.rs"], &["..", "cuda_types", "src", "cublas.rs"],
@ -708,6 +743,7 @@ fn generate_cublaslt(crate_root: &PathBuf) {
&module_blas, &module_blas,
); );
generate_types_library( generate_types_library(
None,
Some(LibraryOverride::CuBlasLt), Some(LibraryOverride::CuBlasLt),
&crate_root, &crate_root,
&["..", "cuda_types", "src", "cublaslt.rs"], &["..", "cuda_types", "src", "cublaslt.rs"],
@ -782,32 +818,22 @@ fn generate_ml(crate_root: &PathBuf) {
.generate() .generate()
.unwrap() .unwrap()
.to_string(); .to_string();
let mut module: syn::File = syn::parse_str(&ml_header).unwrap(); let module: syn::File = syn::parse_str(&ml_header).unwrap();
let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions {
type_: "nvmlReturn_t",
underlying_type: "nvmlReturn_enum",
new_error_type: "nvmlError_t",
error_prefix: ("NVML_ERROR_", "ERROR_"),
success: ("NVML_SUCCESS", "SUCCESS"),
});
module.items = module
.items
.into_iter()
.filter_map(|item| match item {
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
Item::Type(type_) => converter.get_type(type_).map(Item::Type),
item => Some(item),
})
.collect::<Vec<_>>();
converter.flush(&mut module.items);
generate_functions( generate_functions(
&crate_root, &crate_root,
"nvml", "nvml",
&["..", "cuda_macros", "src", "nvml.rs"], &["..", "cuda_macros", "src", "nvml.rs"],
&module, &module,
); );
let convert_options = ConvertIntoRustResultOptions {
type_: "nvmlReturn_t",
underlying_type: "nvmlReturn_enum",
new_error_type: "nvmlError_t",
error_prefix: ("NVML_ERROR_", "ERROR_"),
success: ("NVML_SUCCESS", "SUCCESS"),
};
generate_types_library( generate_types_library(
Some(convert_options),
None, None,
&crate_root, &crate_root,
&["..", "cuda_types", "src", "nvml.rs"], &["..", "cuda_types", "src", "nvml.rs"],
@ -816,12 +842,13 @@ fn generate_ml(crate_root: &PathBuf) {
} }
fn generate_types_library( fn generate_types_library(
convert_options: Option<ConvertIntoRustResultOptions>,
override_: Option<LibraryOverride>, override_: Option<LibraryOverride>,
crate_root: &PathBuf, crate_root: &PathBuf,
path: &[&str], path: &[&str],
module: &syn::File, module: &syn::File,
) { ) {
let module = generate_types_library_impl(module); let module = generate_types_library_impl(convert_options, module);
let mut output = crate_root.clone(); let mut output = crate_root.clone();
output.extend(path); output.extend(path);
let mut text = let mut text =
@ -846,7 +873,10 @@ enum LibraryOverride {
CuFft, CuFft,
} }
fn generate_types_library_impl(module: &syn::File) -> syn::File { fn generate_types_library_impl(
convert_options: Option<ConvertIntoRustResultOptions>,
module: &syn::File,
) -> syn::File {
let known_reexports: Punctuated<syn::Item, syn::parse::Nothing> = parse_quote! { let known_reexports: Punctuated<syn::Item, syn::parse::Nothing> = parse_quote! {
pub type __half = u16; pub type __half = u16;
pub type __nv_bfloat16 = u16; pub type __nv_bfloat16 = u16;
@ -860,11 +890,28 @@ fn generate_types_library_impl(module: &syn::File) -> syn::File {
pub type cudaAsyncNotificationType = super::cuda::CUasyncNotificationType_enum; pub type cudaAsyncNotificationType = super::cuda::CUasyncNotificationType_enum;
pub type cudaGraph_t = super::cuda::CUgraph; pub type cudaGraph_t = super::cuda::CUgraph;
}; };
let non_fn = module.items.iter().filter_map(|item| match item { let remove_functions = |item| match item {
Item::ForeignMod(_) => None, Item::ForeignMod(_) => None,
_ => Some(item), _ => Some(item),
}); };
let items = known_reexports.iter().chain(non_fn); let non_fn = if let Some(options) = convert_options {
let mut converter = ConvertIntoRustResult::new(options);
let mut non_fn = converter
.convert(module.items.clone())
.filter_map(remove_functions)
.collect::<Vec<_>>();
converter.flush(&mut non_fn);
non_fn
} else {
let non_fn = module
.items
.clone()
.into_iter()
.filter_map(remove_functions)
.collect::<Vec<_>>();
non_fn
};
let items = known_reexports.into_iter().chain(non_fn);
parse_quote! { parse_quote! {
#(#items)* #(#items)*
} }
@ -894,16 +941,7 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
error_prefix: ("hipError", "Error"), error_prefix: ("hipError", "Error"),
success: ("hipSuccess", "Success"), success: ("hipSuccess", "Success"),
}); });
module.items = module module.items = converter.convert(module.items).collect::<Vec<Item>>();
.items
.into_iter()
.filter_map(|item| match item {
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
Item::Type(type_) => converter.get_type(type_).map(Item::Type),
item => Some(item),
})
.collect::<Vec<_>>();
converter.flush(&mut module.items); converter.flush(&mut module.items);
add_send_sync( add_send_sync(
&mut module.items, &mut module.items,
@ -1004,14 +1042,10 @@ fn generate_types_cuda(output: &PathBuf, path: &[&str], module: &syn::File) {
error_prefix: ("CUDA_ERROR_", "ERROR_"), error_prefix: ("CUDA_ERROR_", "ERROR_"),
success: ("CUDA_SUCCESS", "SUCCESS"), success: ("CUDA_SUCCESS", "SUCCESS"),
}); });
module.items = module module.items = converter
.items .convert(module.items)
.into_iter()
.filter_map(|item| match item { .filter_map(|item| match item {
Item::ForeignMod(_) => None, Item::ForeignMod(_) => None,
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
Item::Type(type_) => converter.get_type(type_).map(Item::Type),
Item::Struct(mut struct_) => { Item::Struct(mut struct_) => {
let ident_string = struct_.ident.to_string(); let ident_string = struct_.ident.to_string();
match &*ident_string { match &*ident_string {
@ -1064,6 +1098,7 @@ fn write_rust_to_file(path: impl AsRef<std::path::Path>, content: &str) {
file.write(content.as_bytes()).unwrap(); file.write(content.as_bytes()).unwrap();
} }
#[derive(Clone)]
struct ConvertIntoRustResultOptions { struct ConvertIntoRustResultOptions {
type_: &'static str, type_: &'static str,
underlying_type: &'static str, underlying_type: &'static str,
@ -1165,6 +1200,15 @@ impl ConvertIntoRustResult {
Some(type_) Some(type_)
} }
} }
fn convert(&mut self, items: Vec<Item>) -> impl Iterator<Item = Item> + use<'_> {
items.into_iter().filter_map(|item| match item {
Item::Const(const_) => self.get_const(const_).map(Item::Const),
Item::Use(use_) => self.get_use(use_).map(Item::Use),
Item::Type(type_) => self.get_type(type_).map(Item::Type),
item => Some(item),
})
}
} }
struct FixAbi; struct FixAbi;
@ -1324,6 +1368,7 @@ fn generate_display_perflib(
.iter() .iter()
.filter_map(|i| cuda_derive_display_trait_for_item(types_crate, &mut derive_state, i)) .filter_map(|i| cuda_derive_display_trait_for_item(types_crate, &mut derive_state, i))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// TODO: derive result display trait (look at curesult_display_trait)
let mut output = output.clone(); let mut output = output.clone();
output.extend(path); output.extend(path);
write_rust_to_file( write_rust_to_file(

View File

@ -1,11 +1,11 @@
use cuda_types::cufft::cufftResult_t; use cuda_types::cufft::*;
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
pub(crate) fn unimplemented() -> cufftResult_t { pub(crate) fn unimplemented() -> cufftResult {
unimplemented!() unimplemented!()
} }
#[cfg(not(debug_assertions))] #[cfg(not(debug_assertions))]
pub(crate) fn unimplemented() -> cufftResult_t { pub(crate) fn unimplemented() -> cufftResult {
cufftResult_t::CUFFT_NOT_SUPPORTED cufftResult::ERROR_NOT_SUPPORTED
} }