mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 14:57:43 +03:00
Rustify performance library return types
This commit is contained in:
@ -35,7 +35,7 @@ fn main() {
|
|||||||
generate_cublaslt(&crate_root);
|
generate_cublaslt(&crate_root);
|
||||||
generate_cufft(&crate_root);
|
generate_cufft(&crate_root);
|
||||||
generate_cusparse(&crate_root);
|
generate_cusparse(&crate_root);
|
||||||
generate_cudnn(&crate_root);
|
// generate_cudnn(&crate_root);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn generate_process_address_table(crate_root: &PathBuf, mut cuda_fns: Vec<Ident>) {
|
fn generate_process_address_table(crate_root: &PathBuf, mut cuda_fns: Vec<Ident>) {
|
||||||
@ -151,6 +151,7 @@ fn generate_cufft(crate_root: &PathBuf) {
|
|||||||
.allowlist_function("^cufft.*")
|
.allowlist_function("^cufft.*")
|
||||||
.allowlist_var("^CUFFT_.*")
|
.allowlist_var("^CUFFT_.*")
|
||||||
.must_use_type("cufftResult_t")
|
.must_use_type("cufftResult_t")
|
||||||
|
.constified_enum("cufftResult_t")
|
||||||
.allowlist_recursively(false)
|
.allowlist_recursively(false)
|
||||||
.clang_args(["-I/usr/local/cuda/include"])
|
.clang_args(["-I/usr/local/cuda/include"])
|
||||||
.generate()
|
.generate()
|
||||||
@ -163,7 +164,15 @@ fn generate_cufft(crate_root: &PathBuf) {
|
|||||||
&["..", "cuda_macros", "src", "cufft.rs"],
|
&["..", "cuda_macros", "src", "cufft.rs"],
|
||||||
&module,
|
&module,
|
||||||
);
|
);
|
||||||
|
let convert_options = ConvertIntoRustResultOptions {
|
||||||
|
type_: "cufftResult",
|
||||||
|
underlying_type: "cufftResult_t",
|
||||||
|
new_error_type: "cufftError_t",
|
||||||
|
error_prefix: ("CUFFT_", "ERROR_"),
|
||||||
|
success: ("CUFFT_SUCCESS", "SUCCESS"),
|
||||||
|
};
|
||||||
generate_types_library(
|
generate_types_library(
|
||||||
|
Some(convert_options),
|
||||||
Some(LibraryOverride::CuFft),
|
Some(LibraryOverride::CuFft),
|
||||||
&crate_root,
|
&crate_root,
|
||||||
&["..", "cuda_types", "src", "cufft.rs"],
|
&["..", "cuda_types", "src", "cufft.rs"],
|
||||||
@ -206,6 +215,7 @@ fn generate_cusparse(crate_root: &PathBuf) {
|
|||||||
.allowlist_function("^cusparse.*")
|
.allowlist_function("^cusparse.*")
|
||||||
.allowlist_var("^CUSPARSE_.*")
|
.allowlist_var("^CUSPARSE_.*")
|
||||||
.must_use_type("cusparseStatus_t")
|
.must_use_type("cusparseStatus_t")
|
||||||
|
.constified_enum("cusparseStatus_t")
|
||||||
.allowlist_recursively(false)
|
.allowlist_recursively(false)
|
||||||
.clang_args(["-I/usr/local/cuda/include"])
|
.clang_args(["-I/usr/local/cuda/include"])
|
||||||
.generate()
|
.generate()
|
||||||
@ -218,7 +228,15 @@ fn generate_cusparse(crate_root: &PathBuf) {
|
|||||||
&["..", "cuda_macros", "src", "cusparse.rs"],
|
&["..", "cuda_macros", "src", "cusparse.rs"],
|
||||||
&module,
|
&module,
|
||||||
);
|
);
|
||||||
|
let convert_options = ConvertIntoRustResultOptions {
|
||||||
|
type_: "cusparseStatus_t",
|
||||||
|
underlying_type: "cusparseStatus_t",
|
||||||
|
new_error_type: "cusparseError_t",
|
||||||
|
error_prefix: ("CUSPARSE_STATUS_", "ERROR_"),
|
||||||
|
success: ("CUSPARSE_STATUS_SUCCESS", "SUCCESS"),
|
||||||
|
};
|
||||||
generate_types_library(
|
generate_types_library(
|
||||||
|
Some(convert_options),
|
||||||
None,
|
None,
|
||||||
&crate_root,
|
&crate_root,
|
||||||
&["..", "cuda_types", "src", "cusparse.rs"],
|
&["..", "cuda_types", "src", "cusparse.rs"],
|
||||||
@ -239,13 +257,21 @@ fn generate_cudnn(crate_root: &PathBuf) {
|
|||||||
.allowlist_function("^cudnn.*")
|
.allowlist_function("^cudnn.*")
|
||||||
.allowlist_var("^CUDNN_.*")
|
.allowlist_var("^CUDNN_.*")
|
||||||
.must_use_type("cudnnStatus_t")
|
.must_use_type("cudnnStatus_t")
|
||||||
|
.constified_enum("cudnnStatus_t")
|
||||||
.allowlist_recursively(false)
|
.allowlist_recursively(false)
|
||||||
.clang_args(["-I/usr/local/cuda/include"])
|
.clang_args(["-I/usr/local/cuda/include"])
|
||||||
.generate()
|
.generate()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_string();
|
.to_string();
|
||||||
|
let convert_options = ConvertIntoRustResultOptions {
|
||||||
|
type_: "cudnnStatus_t",
|
||||||
|
underlying_type: "cudnnStatus_t",
|
||||||
|
new_error_type: "cudnnError_",
|
||||||
|
error_prefix: ("CUDNN_STATUS_", "ERROR_"),
|
||||||
|
success: ("CUDNN_STATUS_SUCCESS", "SUCCESS"),
|
||||||
|
};
|
||||||
let cudnn9_module: syn::File = syn::parse_str(&cudnn9).unwrap();
|
let cudnn9_module: syn::File = syn::parse_str(&cudnn9).unwrap();
|
||||||
let cudnn9_types = generate_types_library_impl(&cudnn9_module);
|
let cudnn9_types = generate_types_library_impl(Some(convert_options.clone()), &cudnn9_module);
|
||||||
let mut current_dir = PathBuf::from(file!());
|
let mut current_dir = PathBuf::from(file!());
|
||||||
current_dir.pop();
|
current_dir.pop();
|
||||||
let cudnn8 = new_builder()
|
let cudnn8 = new_builder()
|
||||||
@ -263,7 +289,7 @@ fn generate_cudnn(crate_root: &PathBuf) {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
.to_string();
|
.to_string();
|
||||||
let cudnn8_module: syn::File = syn::parse_str(&cudnn8).unwrap();
|
let cudnn8_module: syn::File = syn::parse_str(&cudnn8).unwrap();
|
||||||
let cudnn8_types = generate_types_library_impl(&cudnn8_module);
|
let cudnn8_types = generate_types_library_impl(Some(convert_options), &cudnn8_module);
|
||||||
merge_types(
|
merge_types(
|
||||||
&crate_root,
|
&crate_root,
|
||||||
&["..", "cuda_types", "src", "cudnn.rs"],
|
&["..", "cuda_types", "src", "cudnn.rs"],
|
||||||
@ -624,6 +650,7 @@ fn generate_cublas(crate_root: &PathBuf) {
|
|||||||
.allowlist_function("^cublas.*")
|
.allowlist_function("^cublas.*")
|
||||||
.allowlist_var("^CUBLAS_.*")
|
.allowlist_var("^CUBLAS_.*")
|
||||||
.must_use_type("cublasStatus_t")
|
.must_use_type("cublasStatus_t")
|
||||||
|
.constified_enum("cublasStatus_t")
|
||||||
.allowlist_recursively(false)
|
.allowlist_recursively(false)
|
||||||
.clang_args(["-I/usr/local/cuda/include", "-x", "c++"])
|
.clang_args(["-I/usr/local/cuda/include", "-x", "c++"])
|
||||||
.generate()
|
.generate()
|
||||||
@ -636,7 +663,15 @@ fn generate_cublas(crate_root: &PathBuf) {
|
|||||||
&["..", "cuda_macros", "src", "cublas.rs"],
|
&["..", "cuda_macros", "src", "cublas.rs"],
|
||||||
&module,
|
&module,
|
||||||
);
|
);
|
||||||
|
let convert_options = ConvertIntoRustResultOptions {
|
||||||
|
type_: "cublasStatus_t",
|
||||||
|
underlying_type: "cublasStatus_t",
|
||||||
|
new_error_type: "cublasError_t",
|
||||||
|
error_prefix: ("CUBLAS_STATUS_", "ERROR_"),
|
||||||
|
success: ("CUBLAS_STATUS_SUCCESS", "SUCCESS"),
|
||||||
|
};
|
||||||
generate_types_library(
|
generate_types_library(
|
||||||
|
Some(convert_options),
|
||||||
None,
|
None,
|
||||||
&crate_root,
|
&crate_root,
|
||||||
&["..", "cuda_types", "src", "cublas.rs"],
|
&["..", "cuda_types", "src", "cublas.rs"],
|
||||||
@ -708,6 +743,7 @@ fn generate_cublaslt(crate_root: &PathBuf) {
|
|||||||
&module_blas,
|
&module_blas,
|
||||||
);
|
);
|
||||||
generate_types_library(
|
generate_types_library(
|
||||||
|
None,
|
||||||
Some(LibraryOverride::CuBlasLt),
|
Some(LibraryOverride::CuBlasLt),
|
||||||
&crate_root,
|
&crate_root,
|
||||||
&["..", "cuda_types", "src", "cublaslt.rs"],
|
&["..", "cuda_types", "src", "cublaslt.rs"],
|
||||||
@ -782,32 +818,22 @@ fn generate_ml(crate_root: &PathBuf) {
|
|||||||
.generate()
|
.generate()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_string();
|
.to_string();
|
||||||
let mut module: syn::File = syn::parse_str(&ml_header).unwrap();
|
let module: syn::File = syn::parse_str(&ml_header).unwrap();
|
||||||
let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions {
|
|
||||||
type_: "nvmlReturn_t",
|
|
||||||
underlying_type: "nvmlReturn_enum",
|
|
||||||
new_error_type: "nvmlError_t",
|
|
||||||
error_prefix: ("NVML_ERROR_", "ERROR_"),
|
|
||||||
success: ("NVML_SUCCESS", "SUCCESS"),
|
|
||||||
});
|
|
||||||
module.items = module
|
|
||||||
.items
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|item| match item {
|
|
||||||
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
|
|
||||||
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
|
|
||||||
Item::Type(type_) => converter.get_type(type_).map(Item::Type),
|
|
||||||
item => Some(item),
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
converter.flush(&mut module.items);
|
|
||||||
generate_functions(
|
generate_functions(
|
||||||
&crate_root,
|
&crate_root,
|
||||||
"nvml",
|
"nvml",
|
||||||
&["..", "cuda_macros", "src", "nvml.rs"],
|
&["..", "cuda_macros", "src", "nvml.rs"],
|
||||||
&module,
|
&module,
|
||||||
);
|
);
|
||||||
|
let convert_options = ConvertIntoRustResultOptions {
|
||||||
|
type_: "nvmlReturn_t",
|
||||||
|
underlying_type: "nvmlReturn_enum",
|
||||||
|
new_error_type: "nvmlError_t",
|
||||||
|
error_prefix: ("NVML_ERROR_", "ERROR_"),
|
||||||
|
success: ("NVML_SUCCESS", "SUCCESS"),
|
||||||
|
};
|
||||||
generate_types_library(
|
generate_types_library(
|
||||||
|
Some(convert_options),
|
||||||
None,
|
None,
|
||||||
&crate_root,
|
&crate_root,
|
||||||
&["..", "cuda_types", "src", "nvml.rs"],
|
&["..", "cuda_types", "src", "nvml.rs"],
|
||||||
@ -816,12 +842,13 @@ fn generate_ml(crate_root: &PathBuf) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn generate_types_library(
|
fn generate_types_library(
|
||||||
|
convert_options: Option<ConvertIntoRustResultOptions>,
|
||||||
override_: Option<LibraryOverride>,
|
override_: Option<LibraryOverride>,
|
||||||
crate_root: &PathBuf,
|
crate_root: &PathBuf,
|
||||||
path: &[&str],
|
path: &[&str],
|
||||||
module: &syn::File,
|
module: &syn::File,
|
||||||
) {
|
) {
|
||||||
let module = generate_types_library_impl(module);
|
let module = generate_types_library_impl(convert_options, module);
|
||||||
let mut output = crate_root.clone();
|
let mut output = crate_root.clone();
|
||||||
output.extend(path);
|
output.extend(path);
|
||||||
let mut text =
|
let mut text =
|
||||||
@ -846,7 +873,10 @@ enum LibraryOverride {
|
|||||||
CuFft,
|
CuFft,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn generate_types_library_impl(module: &syn::File) -> syn::File {
|
fn generate_types_library_impl(
|
||||||
|
convert_options: Option<ConvertIntoRustResultOptions>,
|
||||||
|
module: &syn::File,
|
||||||
|
) -> syn::File {
|
||||||
let known_reexports: Punctuated<syn::Item, syn::parse::Nothing> = parse_quote! {
|
let known_reexports: Punctuated<syn::Item, syn::parse::Nothing> = parse_quote! {
|
||||||
pub type __half = u16;
|
pub type __half = u16;
|
||||||
pub type __nv_bfloat16 = u16;
|
pub type __nv_bfloat16 = u16;
|
||||||
@ -860,11 +890,28 @@ fn generate_types_library_impl(module: &syn::File) -> syn::File {
|
|||||||
pub type cudaAsyncNotificationType = super::cuda::CUasyncNotificationType_enum;
|
pub type cudaAsyncNotificationType = super::cuda::CUasyncNotificationType_enum;
|
||||||
pub type cudaGraph_t = super::cuda::CUgraph;
|
pub type cudaGraph_t = super::cuda::CUgraph;
|
||||||
};
|
};
|
||||||
let non_fn = module.items.iter().filter_map(|item| match item {
|
let remove_functions = |item| match item {
|
||||||
Item::ForeignMod(_) => None,
|
Item::ForeignMod(_) => None,
|
||||||
_ => Some(item),
|
_ => Some(item),
|
||||||
});
|
};
|
||||||
let items = known_reexports.iter().chain(non_fn);
|
let non_fn = if let Some(options) = convert_options {
|
||||||
|
let mut converter = ConvertIntoRustResult::new(options);
|
||||||
|
let mut non_fn = converter
|
||||||
|
.convert(module.items.clone())
|
||||||
|
.filter_map(remove_functions)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
converter.flush(&mut non_fn);
|
||||||
|
non_fn
|
||||||
|
} else {
|
||||||
|
let non_fn = module
|
||||||
|
.items
|
||||||
|
.clone()
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(remove_functions)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
non_fn
|
||||||
|
};
|
||||||
|
let items = known_reexports.into_iter().chain(non_fn);
|
||||||
parse_quote! {
|
parse_quote! {
|
||||||
#(#items)*
|
#(#items)*
|
||||||
}
|
}
|
||||||
@ -894,16 +941,7 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
|
|||||||
error_prefix: ("hipError", "Error"),
|
error_prefix: ("hipError", "Error"),
|
||||||
success: ("hipSuccess", "Success"),
|
success: ("hipSuccess", "Success"),
|
||||||
});
|
});
|
||||||
module.items = module
|
module.items = converter.convert(module.items).collect::<Vec<Item>>();
|
||||||
.items
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|item| match item {
|
|
||||||
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
|
|
||||||
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
|
|
||||||
Item::Type(type_) => converter.get_type(type_).map(Item::Type),
|
|
||||||
item => Some(item),
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
converter.flush(&mut module.items);
|
converter.flush(&mut module.items);
|
||||||
add_send_sync(
|
add_send_sync(
|
||||||
&mut module.items,
|
&mut module.items,
|
||||||
@ -1004,14 +1042,10 @@ fn generate_types_cuda(output: &PathBuf, path: &[&str], module: &syn::File) {
|
|||||||
error_prefix: ("CUDA_ERROR_", "ERROR_"),
|
error_prefix: ("CUDA_ERROR_", "ERROR_"),
|
||||||
success: ("CUDA_SUCCESS", "SUCCESS"),
|
success: ("CUDA_SUCCESS", "SUCCESS"),
|
||||||
});
|
});
|
||||||
module.items = module
|
module.items = converter
|
||||||
.items
|
.convert(module.items)
|
||||||
.into_iter()
|
|
||||||
.filter_map(|item| match item {
|
.filter_map(|item| match item {
|
||||||
Item::ForeignMod(_) => None,
|
Item::ForeignMod(_) => None,
|
||||||
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
|
|
||||||
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
|
|
||||||
Item::Type(type_) => converter.get_type(type_).map(Item::Type),
|
|
||||||
Item::Struct(mut struct_) => {
|
Item::Struct(mut struct_) => {
|
||||||
let ident_string = struct_.ident.to_string();
|
let ident_string = struct_.ident.to_string();
|
||||||
match &*ident_string {
|
match &*ident_string {
|
||||||
@ -1064,6 +1098,7 @@ fn write_rust_to_file(path: impl AsRef<std::path::Path>, content: &str) {
|
|||||||
file.write(content.as_bytes()).unwrap();
|
file.write(content.as_bytes()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
struct ConvertIntoRustResultOptions {
|
struct ConvertIntoRustResultOptions {
|
||||||
type_: &'static str,
|
type_: &'static str,
|
||||||
underlying_type: &'static str,
|
underlying_type: &'static str,
|
||||||
@ -1165,6 +1200,15 @@ impl ConvertIntoRustResult {
|
|||||||
Some(type_)
|
Some(type_)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn convert(&mut self, items: Vec<Item>) -> impl Iterator<Item = Item> + use<'_> {
|
||||||
|
items.into_iter().filter_map(|item| match item {
|
||||||
|
Item::Const(const_) => self.get_const(const_).map(Item::Const),
|
||||||
|
Item::Use(use_) => self.get_use(use_).map(Item::Use),
|
||||||
|
Item::Type(type_) => self.get_type(type_).map(Item::Type),
|
||||||
|
item => Some(item),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct FixAbi;
|
struct FixAbi;
|
||||||
@ -1324,6 +1368,7 @@ fn generate_display_perflib(
|
|||||||
.iter()
|
.iter()
|
||||||
.filter_map(|i| cuda_derive_display_trait_for_item(types_crate, &mut derive_state, i))
|
.filter_map(|i| cuda_derive_display_trait_for_item(types_crate, &mut derive_state, i))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
// TODO: derive result display trait (look at curesult_display_trait)
|
||||||
let mut output = output.clone();
|
let mut output = output.clone();
|
||||||
output.extend(path);
|
output.extend(path);
|
||||||
write_rust_to_file(
|
write_rust_to_file(
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
use cuda_types::cufft::cufftResult_t;
|
use cuda_types::cufft::*;
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
pub(crate) fn unimplemented() -> cufftResult_t {
|
pub(crate) fn unimplemented() -> cufftResult {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(debug_assertions))]
|
#[cfg(not(debug_assertions))]
|
||||||
pub(crate) fn unimplemented() -> cufftResult_t {
|
pub(crate) fn unimplemented() -> cufftResult {
|
||||||
cufftResult_t::CUFFT_NOT_SUPPORTED
|
cufftResult::ERROR_NOT_SUPPORTED
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user