mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 14:57:43 +03:00
Derive CudaDisplay trait for performance library result types
This commit is contained in:
@ -164,7 +164,7 @@ fn generate_cufft(crate_root: &PathBuf) {
|
||||
&["..", "cuda_macros", "src", "cufft.rs"],
|
||||
&module,
|
||||
);
|
||||
let convert_options = ConvertIntoRustResultOptions {
|
||||
let result_options = ConvertIntoRustResultOptions {
|
||||
type_: "cufftResult",
|
||||
underlying_type: "cufftResult_t",
|
||||
new_error_type: "cufftError_t",
|
||||
@ -172,13 +172,14 @@ fn generate_cufft(crate_root: &PathBuf) {
|
||||
success: ("CUFFT_SUCCESS", "SUCCESS"),
|
||||
};
|
||||
generate_types_library(
|
||||
Some(convert_options),
|
||||
Some(&result_options),
|
||||
Some(LibraryOverride::CuFft),
|
||||
&crate_root,
|
||||
&["..", "cuda_types", "src", "cufft.rs"],
|
||||
&module,
|
||||
);
|
||||
generate_display_perflib(
|
||||
Some(&result_options),
|
||||
&crate_root,
|
||||
&["..", "format", "src", "format_generated_fft.rs"],
|
||||
&["cuda_types", "cufft"],
|
||||
@ -228,7 +229,7 @@ fn generate_cusparse(crate_root: &PathBuf) {
|
||||
&["..", "cuda_macros", "src", "cusparse.rs"],
|
||||
&module,
|
||||
);
|
||||
let convert_options = ConvertIntoRustResultOptions {
|
||||
let result_options = ConvertIntoRustResultOptions {
|
||||
type_: "cusparseStatus_t",
|
||||
underlying_type: "cusparseStatus_t",
|
||||
new_error_type: "cusparseError_t",
|
||||
@ -236,13 +237,14 @@ fn generate_cusparse(crate_root: &PathBuf) {
|
||||
success: ("CUSPARSE_STATUS_SUCCESS", "SUCCESS"),
|
||||
};
|
||||
generate_types_library(
|
||||
Some(convert_options),
|
||||
Some(&result_options),
|
||||
None,
|
||||
&crate_root,
|
||||
&["..", "cuda_types", "src", "cusparse.rs"],
|
||||
&module,
|
||||
);
|
||||
generate_display_perflib(
|
||||
Some(&result_options),
|
||||
&crate_root,
|
||||
&["..", "format", "src", "format_generated_sparse.rs"],
|
||||
&["cuda_types", "cusparse"],
|
||||
@ -263,7 +265,7 @@ fn generate_cudnn(crate_root: &PathBuf) {
|
||||
.generate()
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let convert_options = ConvertIntoRustResultOptions {
|
||||
let result_options = ConvertIntoRustResultOptions {
|
||||
type_: "cudnnStatus_t",
|
||||
underlying_type: "cudnnStatus_t",
|
||||
new_error_type: "cudnnError_",
|
||||
@ -271,7 +273,7 @@ fn generate_cudnn(crate_root: &PathBuf) {
|
||||
success: ("CUDNN_STATUS_SUCCESS", "SUCCESS"),
|
||||
};
|
||||
let cudnn9_module: syn::File = syn::parse_str(&cudnn9).unwrap();
|
||||
let cudnn9_types = generate_types_library_impl(Some(convert_options.clone()), &cudnn9_module);
|
||||
let cudnn9_types = generate_types_library_impl(Some(&result_options), &cudnn9_module);
|
||||
let mut current_dir = PathBuf::from(file!());
|
||||
current_dir.pop();
|
||||
let cudnn8 = new_builder()
|
||||
@ -289,7 +291,7 @@ fn generate_cudnn(crate_root: &PathBuf) {
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let cudnn8_module: syn::File = syn::parse_str(&cudnn8).unwrap();
|
||||
let cudnn8_types = generate_types_library_impl(Some(convert_options), &cudnn8_module);
|
||||
let cudnn8_types = generate_types_library_impl(Some(&result_options), &cudnn8_module);
|
||||
merge_types(
|
||||
&crate_root,
|
||||
&["..", "cuda_types", "src", "cudnn.rs"],
|
||||
@ -311,6 +313,7 @@ fn generate_cudnn(crate_root: &PathBuf) {
|
||||
&cudnn9_module,
|
||||
);
|
||||
generate_display_perflib(
|
||||
Some(&result_options),
|
||||
&crate_root,
|
||||
&["..", "format", "src", "format_generated_dnn9.rs"],
|
||||
&["cuda_types", "cudnn9"],
|
||||
@ -663,7 +666,7 @@ fn generate_cublas(crate_root: &PathBuf) {
|
||||
&["..", "cuda_macros", "src", "cublas.rs"],
|
||||
&module,
|
||||
);
|
||||
let convert_options = ConvertIntoRustResultOptions {
|
||||
let result_options = ConvertIntoRustResultOptions {
|
||||
type_: "cublasStatus_t",
|
||||
underlying_type: "cublasStatus_t",
|
||||
new_error_type: "cublasError_t",
|
||||
@ -671,13 +674,14 @@ fn generate_cublas(crate_root: &PathBuf) {
|
||||
success: ("CUBLAS_STATUS_SUCCESS", "SUCCESS"),
|
||||
};
|
||||
generate_types_library(
|
||||
Some(convert_options),
|
||||
Some(&result_options),
|
||||
None,
|
||||
&crate_root,
|
||||
&["..", "cuda_types", "src", "cublas.rs"],
|
||||
&module,
|
||||
);
|
||||
generate_display_perflib(
|
||||
Some(&result_options),
|
||||
&crate_root,
|
||||
&["..", "format", "src", "format_generated_blas.rs"],
|
||||
&["cuda_types", "cublas"],
|
||||
@ -750,12 +754,14 @@ fn generate_cublaslt(crate_root: &PathBuf) {
|
||||
&module_blas,
|
||||
);
|
||||
generate_display_perflib(
|
||||
None,
|
||||
&crate_root,
|
||||
&["..", "format", "src", "format_generated_blaslt.rs"],
|
||||
&["cuda_types", "cublaslt"],
|
||||
&module_blas,
|
||||
);
|
||||
generate_display_perflib(
|
||||
None,
|
||||
&crate_root,
|
||||
&["..", "format", "src", "format_generated_blaslt_internal.rs"],
|
||||
&["cuda_types", "cublaslt"],
|
||||
@ -792,12 +798,21 @@ fn generate_cuda(crate_root: &PathBuf) -> Vec<Ident> {
|
||||
&["..", "cuda_macros", "src", "cuda.rs"],
|
||||
&module,
|
||||
));
|
||||
let result_options = ConvertIntoRustResultOptions {
|
||||
type_: "CUresult",
|
||||
underlying_type: "cudaError_enum",
|
||||
new_error_type: "CUerror",
|
||||
error_prefix: ("CUDA_ERROR_", "ERROR_"),
|
||||
success: ("CUDA_SUCCESS", "SUCCESS"),
|
||||
};
|
||||
generate_types_cuda(
|
||||
&result_options,
|
||||
&crate_root,
|
||||
&["..", "cuda_types", "src", "cuda.rs"],
|
||||
&module,
|
||||
);
|
||||
generate_display_cuda(
|
||||
&result_options,
|
||||
&crate_root,
|
||||
&["..", "format", "src", "format_generated.rs"],
|
||||
&["cuda_types", "cuda"],
|
||||
@ -825,7 +840,7 @@ fn generate_ml(crate_root: &PathBuf) {
|
||||
&["..", "cuda_macros", "src", "nvml.rs"],
|
||||
&module,
|
||||
);
|
||||
let convert_options = ConvertIntoRustResultOptions {
|
||||
let result_options = ConvertIntoRustResultOptions {
|
||||
type_: "nvmlReturn_t",
|
||||
underlying_type: "nvmlReturn_enum",
|
||||
new_error_type: "nvmlError_t",
|
||||
@ -833,7 +848,7 @@ fn generate_ml(crate_root: &PathBuf) {
|
||||
success: ("NVML_SUCCESS", "SUCCESS"),
|
||||
};
|
||||
generate_types_library(
|
||||
Some(convert_options),
|
||||
Some(&result_options),
|
||||
None,
|
||||
&crate_root,
|
||||
&["..", "cuda_types", "src", "nvml.rs"],
|
||||
@ -842,13 +857,13 @@ fn generate_ml(crate_root: &PathBuf) {
|
||||
}
|
||||
|
||||
fn generate_types_library(
|
||||
convert_options: Option<ConvertIntoRustResultOptions>,
|
||||
result_options: Option<&ConvertIntoRustResultOptions>,
|
||||
override_: Option<LibraryOverride>,
|
||||
crate_root: &PathBuf,
|
||||
path: &[&str],
|
||||
module: &syn::File,
|
||||
) {
|
||||
let module = generate_types_library_impl(convert_options, module);
|
||||
let module = generate_types_library_impl(result_options, module);
|
||||
let mut output = crate_root.clone();
|
||||
output.extend(path);
|
||||
let mut text =
|
||||
@ -874,7 +889,7 @@ enum LibraryOverride {
|
||||
}
|
||||
|
||||
fn generate_types_library_impl(
|
||||
convert_options: Option<ConvertIntoRustResultOptions>,
|
||||
result_options: Option<&ConvertIntoRustResultOptions>,
|
||||
module: &syn::File,
|
||||
) -> syn::File {
|
||||
let known_reexports: Punctuated<syn::Item, syn::parse::Nothing> = parse_quote! {
|
||||
@ -894,8 +909,8 @@ fn generate_types_library_impl(
|
||||
Item::ForeignMod(_) => None,
|
||||
_ => Some(item),
|
||||
};
|
||||
let non_fn = if let Some(options) = convert_options {
|
||||
let mut converter = ConvertIntoRustResult::new(options);
|
||||
let non_fn = if let Some(options) = result_options {
|
||||
let mut converter = ConvertIntoRustResult::new(options.clone());
|
||||
let mut non_fn = converter
|
||||
.convert(module.items.clone())
|
||||
.filter_map(remove_functions)
|
||||
@ -1033,15 +1048,14 @@ fn generate_functions(
|
||||
*/
|
||||
}
|
||||
|
||||
fn generate_types_cuda(output: &PathBuf, path: &[&str], module: &syn::File) {
|
||||
fn generate_types_cuda(
|
||||
options: &ConvertIntoRustResultOptions,
|
||||
output: &PathBuf,
|
||||
path: &[&str],
|
||||
module: &syn::File,
|
||||
) {
|
||||
let mut module = module.clone();
|
||||
let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions {
|
||||
type_: "CUresult",
|
||||
underlying_type: "cudaError_enum",
|
||||
new_error_type: "CUerror",
|
||||
error_prefix: ("CUDA_ERROR_", "ERROR_"),
|
||||
success: ("CUDA_SUCCESS", "SUCCESS"),
|
||||
});
|
||||
let mut converter = ConvertIntoRustResult::new(options.clone());
|
||||
module.items = converter
|
||||
.convert(module.items)
|
||||
.filter_map(|item| match item {
|
||||
@ -1265,6 +1279,7 @@ impl VisitMut for ExplicitReturnType {
|
||||
}
|
||||
|
||||
fn generate_display_cuda(
|
||||
result_options: &ConvertIntoRustResultOptions,
|
||||
output: &PathBuf,
|
||||
path: &[&str],
|
||||
types_crate: &[&'static str],
|
||||
@ -1321,9 +1336,16 @@ fn generate_display_cuda(
|
||||
let mut items = module
|
||||
.items
|
||||
.iter()
|
||||
.filter_map(|i| cuda_derive_display_trait_for_item(types_crate, &mut derive_state, i))
|
||||
.filter_map(|i| {
|
||||
cuda_derive_display_trait_for_item(
|
||||
Some(result_options),
|
||||
types_crate,
|
||||
&mut derive_state,
|
||||
i,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
items.push(curesult_display_trait(&derive_state));
|
||||
items.push(result_display_trait(result_options, &derive_state));
|
||||
let mut output = output.clone();
|
||||
output.extend(path);
|
||||
write_rust_to_file(
|
||||
@ -1337,6 +1359,7 @@ fn generate_display_cuda(
|
||||
}
|
||||
|
||||
fn generate_display_perflib(
|
||||
result_options: Option<&ConvertIntoRustResultOptions>,
|
||||
output: &PathBuf,
|
||||
path: &[&str],
|
||||
types_crate: &[&'static str],
|
||||
@ -1363,12 +1386,16 @@ fn generate_display_perflib(
|
||||
&ignore_functions,
|
||||
&count_selectors,
|
||||
);
|
||||
let items = module
|
||||
let mut items = module
|
||||
.items
|
||||
.iter()
|
||||
.filter_map(|i| cuda_derive_display_trait_for_item(types_crate, &mut derive_state, i))
|
||||
.filter_map(|i| {
|
||||
cuda_derive_display_trait_for_item(result_options, types_crate, &mut derive_state, i)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
// TODO: derive result display trait (look at curesult_display_trait)
|
||||
if let Some(result_options) = result_options {
|
||||
items.push(result_display_trait(result_options, &derive_state));
|
||||
}
|
||||
let mut output = output.clone();
|
||||
output.extend(path);
|
||||
write_rust_to_file(
|
||||
@ -1439,6 +1466,7 @@ impl<'a> DeriveDisplayState<'a> {
|
||||
}
|
||||
|
||||
fn cuda_derive_display_trait_for_item<'a>(
|
||||
result_options: Option<&ConvertIntoRustResultOptions>,
|
||||
path: &[&str],
|
||||
state: &mut DeriveDisplayState<'a>,
|
||||
item: &'a Item,
|
||||
@ -1453,8 +1481,10 @@ fn cuda_derive_display_trait_for_item<'a>(
|
||||
};
|
||||
match item {
|
||||
Item::Const(const_) => {
|
||||
if const_.ty.to_token_stream().to_string() == "cudaError_enum" {
|
||||
state.result_variants.push(const_);
|
||||
if let Some(result_options) = result_options {
|
||||
if const_.ty.to_token_stream().to_string() == result_options.underlying_type {
|
||||
state.result_variants.push(const_);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
@ -1657,11 +1687,21 @@ fn fn_arg_name(fn_arg: &FnArg) -> &Box<syn::Pat> {
|
||||
name
|
||||
}
|
||||
|
||||
fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item {
|
||||
fn result_display_trait(
|
||||
result_options: &ConvertIntoRustResultOptions,
|
||||
derive_state: &DeriveDisplayState,
|
||||
) -> syn::Item {
|
||||
let path = &derive_state.types_crate;
|
||||
|
||||
let type_ = Ident::new(result_options.type_, Span::call_site());
|
||||
|
||||
let success = result_options.success.0;
|
||||
let success_bstr = syn::LitByteStr::new(success.as_bytes(), Span::call_site());
|
||||
|
||||
let errors = derive_state.result_variants.iter().filter_map(|const_| {
|
||||
let prefix = "cudaError_enum_";
|
||||
let prefix = format!("{}_", result_options.underlying_type);
|
||||
let text = &const_.ident.to_string()[prefix.len()..];
|
||||
if text == "CUDA_SUCCESS" {
|
||||
if text == success {
|
||||
return None;
|
||||
}
|
||||
let expr = &const_.expr;
|
||||
@ -1670,10 +1710,10 @@ fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item {
|
||||
})
|
||||
});
|
||||
parse_quote! {
|
||||
impl crate::CudaDisplay for cuda_types::cuda::CUresult {
|
||||
impl crate::CudaDisplay for #path::#type_ {
|
||||
fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
|
||||
match self {
|
||||
Ok(()) => writer.write_all(b"CUDA_SUCCESS"),
|
||||
Ok(()) => writer.write_all(#success_bstr),
|
||||
Err(err) => {
|
||||
match err.0.get() {
|
||||
#(#errors)*
|
||||
|
Reference in New Issue
Block a user