From db9ae4b430043bfea1df310d6a268173cdd46724 Mon Sep 17 00:00:00 2001 From: Violet Date: Mon, 28 Jul 2025 01:24:05 +0000 Subject: [PATCH] Slight refactor of ConvertIntoRustResult --- zluda_bindgen/src/main.rs | 54 ++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs index cb35c14..d5816e8 100644 --- a/zluda_bindgen/src/main.rs +++ b/zluda_bindgen/src/main.rs @@ -783,14 +783,13 @@ fn generate_ml(crate_root: &PathBuf) { .unwrap() .to_string(); let mut module: syn::File = syn::parse_str(&ml_header).unwrap(); - let mut converter = ConvertIntoRustResult { + let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions { type_: "nvmlReturn_t", underlying_type: "nvmlReturn_enum", new_error_type: "nvmlError_t", error_prefix: ("NVML_ERROR_", "ERROR_"), success: ("NVML_SUCCESS", "SUCCESS"), - constants: Vec::new(), - }; + }); module.items = module .items .into_iter() @@ -888,14 +887,13 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) { .unwrap() .to_string(); let mut module: syn::File = syn::parse_str(&hiprt_header).unwrap(); - let mut converter = ConvertIntoRustResult { + let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions { type_: "hipError_t", underlying_type: "hipError_t", new_error_type: "hipErrorCode_t", error_prefix: ("hipError", "Error"), success: ("hipSuccess", "Success"), - constants: Vec::new(), - }; + }); module.items = module .items .into_iter() @@ -999,14 +997,13 @@ fn generate_functions( fn generate_types_cuda(output: &PathBuf, path: &[&str], module: &syn::File) { let mut module = module.clone(); - let mut converter = ConvertIntoRustResult { + let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions { type_: "CUresult", underlying_type: "cudaError_enum", new_error_type: "CUerror", error_prefix: ("CUDA_ERROR_", "ERROR_"), success: ("CUDA_SUCCESS", "SUCCESS"), - constants: Vec::new(), - }; + }); module.items = module .items .into_iter() @@ -1067,19 +1064,30 @@ fn write_rust_to_file(path: impl AsRef, content: &str) { file.write(content.as_bytes()).unwrap(); } -struct ConvertIntoRustResult { +struct ConvertIntoRustResultOptions { type_: &'static str, underlying_type: &'static str, new_error_type: &'static str, error_prefix: (&'static str, &'static str), success: (&'static str, &'static str), +} + +struct ConvertIntoRustResult { + options: ConvertIntoRustResultOptions, constants: Vec, } impl ConvertIntoRustResult { + fn new(options: ConvertIntoRustResultOptions) -> Self { + Self { + options, + constants: vec![], + } + } + fn get_const(&mut self, const_: syn::ItemConst) -> Option { let name = const_.ident.to_string(); - if name.starts_with(self.underlying_type) { + if name.starts_with(self.options.underlying_type) { self.constants.push(const_); None } else { @@ -1090,7 +1098,7 @@ impl ConvertIntoRustResult { fn get_use(&mut self, use_: ItemUse) -> Option { if let UseTree::Path(ref path) = use_.tree { if let UseTree::Rename(ref rename) = &*path.tree { - if rename.rename == self.type_ { + if rename.rename == self.options.type_ { return None; } } @@ -1099,22 +1107,26 @@ impl ConvertIntoRustResult { } fn flush(self, items: &mut Vec) { - let type_ = format_ident!("{}", self.type_); - let type_trait = format_ident!("{}Consts", self.type_); - let new_error_type = format_ident!("{}", self.new_error_type); - let success = format_ident!("{}", self.success.1); + let type_ = format_ident!("{}", self.options.type_); + let type_trait = format_ident!("{}Consts", self.options.type_); + let new_error_type = format_ident!("{}", self.options.new_error_type); + let success = format_ident!("{}", self.options.success.1); let mut result_variants = Vec::new(); let mut error_variants = Vec::new(); for const_ in self.constants.iter() { let ident = const_.ident.to_string(); - if ident.ends_with(self.success.0) { + if ident.ends_with(self.options.success.0) { result_variants.push(quote! { const #success: #type_ = #type_::Ok(()); }); } else { - let old_prefix_len = self.underlying_type.len() + 1 + self.error_prefix.0.len(); - let variant_ident = - format_ident!("{}{}", self.error_prefix.1, &ident[old_prefix_len..]); + let old_prefix_len = + self.options.underlying_type.len() + 1 + self.options.error_prefix.0.len(); + let variant_ident = format_ident!( + "{}{}", + self.options.error_prefix.1, + &ident[old_prefix_len..] + ); let error_ident = format_ident!("{}", &ident[old_prefix_len..]); let expr = &const_.expr; result_variants.push(quote! { @@ -1147,7 +1159,7 @@ impl ConvertIntoRustResult { } fn get_type(&self, type_: syn::ItemType) -> Option { - if type_.ident.to_string() == self.type_ { + if type_.ident.to_string() == self.options.type_ { None } else { Some(type_)