Slight refactor of ConvertIntoRustResult

This commit is contained in:
Violet
2025-07-28 01:24:05 +00:00
parent f192dd317a
commit db9ae4b430

View File

@ -783,14 +783,13 @@ fn generate_ml(crate_root: &PathBuf) {
.unwrap() .unwrap()
.to_string(); .to_string();
let mut module: syn::File = syn::parse_str(&ml_header).unwrap(); let mut module: syn::File = syn::parse_str(&ml_header).unwrap();
let mut converter = ConvertIntoRustResult { let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions {
type_: "nvmlReturn_t", type_: "nvmlReturn_t",
underlying_type: "nvmlReturn_enum", underlying_type: "nvmlReturn_enum",
new_error_type: "nvmlError_t", new_error_type: "nvmlError_t",
error_prefix: ("NVML_ERROR_", "ERROR_"), error_prefix: ("NVML_ERROR_", "ERROR_"),
success: ("NVML_SUCCESS", "SUCCESS"), success: ("NVML_SUCCESS", "SUCCESS"),
constants: Vec::new(), });
};
module.items = module module.items = module
.items .items
.into_iter() .into_iter()
@ -888,14 +887,13 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
.unwrap() .unwrap()
.to_string(); .to_string();
let mut module: syn::File = syn::parse_str(&hiprt_header).unwrap(); let mut module: syn::File = syn::parse_str(&hiprt_header).unwrap();
let mut converter = ConvertIntoRustResult { let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions {
type_: "hipError_t", type_: "hipError_t",
underlying_type: "hipError_t", underlying_type: "hipError_t",
new_error_type: "hipErrorCode_t", new_error_type: "hipErrorCode_t",
error_prefix: ("hipError", "Error"), error_prefix: ("hipError", "Error"),
success: ("hipSuccess", "Success"), success: ("hipSuccess", "Success"),
constants: Vec::new(), });
};
module.items = module module.items = module
.items .items
.into_iter() .into_iter()
@ -999,14 +997,13 @@ fn generate_functions(
fn generate_types_cuda(output: &PathBuf, path: &[&str], module: &syn::File) { fn generate_types_cuda(output: &PathBuf, path: &[&str], module: &syn::File) {
let mut module = module.clone(); let mut module = module.clone();
let mut converter = ConvertIntoRustResult { let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions {
type_: "CUresult", type_: "CUresult",
underlying_type: "cudaError_enum", underlying_type: "cudaError_enum",
new_error_type: "CUerror", new_error_type: "CUerror",
error_prefix: ("CUDA_ERROR_", "ERROR_"), error_prefix: ("CUDA_ERROR_", "ERROR_"),
success: ("CUDA_SUCCESS", "SUCCESS"), success: ("CUDA_SUCCESS", "SUCCESS"),
constants: Vec::new(), });
};
module.items = module module.items = module
.items .items
.into_iter() .into_iter()
@ -1067,19 +1064,30 @@ fn write_rust_to_file(path: impl AsRef<std::path::Path>, content: &str) {
file.write(content.as_bytes()).unwrap(); file.write(content.as_bytes()).unwrap();
} }
struct ConvertIntoRustResult { struct ConvertIntoRustResultOptions {
type_: &'static str, type_: &'static str,
underlying_type: &'static str, underlying_type: &'static str,
new_error_type: &'static str, new_error_type: &'static str,
error_prefix: (&'static str, &'static str), error_prefix: (&'static str, &'static str),
success: (&'static str, &'static str), success: (&'static str, &'static str),
}
struct ConvertIntoRustResult {
options: ConvertIntoRustResultOptions,
constants: Vec<syn::ItemConst>, constants: Vec<syn::ItemConst>,
} }
impl ConvertIntoRustResult { impl ConvertIntoRustResult {
fn new(options: ConvertIntoRustResultOptions) -> Self {
Self {
options,
constants: vec![],
}
}
fn get_const(&mut self, const_: syn::ItemConst) -> Option<syn::ItemConst> { fn get_const(&mut self, const_: syn::ItemConst) -> Option<syn::ItemConst> {
let name = const_.ident.to_string(); let name = const_.ident.to_string();
if name.starts_with(self.underlying_type) { if name.starts_with(self.options.underlying_type) {
self.constants.push(const_); self.constants.push(const_);
None None
} else { } else {
@ -1090,7 +1098,7 @@ impl ConvertIntoRustResult {
fn get_use(&mut self, use_: ItemUse) -> Option<ItemUse> { fn get_use(&mut self, use_: ItemUse) -> Option<ItemUse> {
if let UseTree::Path(ref path) = use_.tree { if let UseTree::Path(ref path) = use_.tree {
if let UseTree::Rename(ref rename) = &*path.tree { if let UseTree::Rename(ref rename) = &*path.tree {
if rename.rename == self.type_ { if rename.rename == self.options.type_ {
return None; return None;
} }
} }
@ -1099,22 +1107,26 @@ impl ConvertIntoRustResult {
} }
fn flush(self, items: &mut Vec<Item>) { fn flush(self, items: &mut Vec<Item>) {
let type_ = format_ident!("{}", self.type_); let type_ = format_ident!("{}", self.options.type_);
let type_trait = format_ident!("{}Consts", self.type_); let type_trait = format_ident!("{}Consts", self.options.type_);
let new_error_type = format_ident!("{}", self.new_error_type); let new_error_type = format_ident!("{}", self.options.new_error_type);
let success = format_ident!("{}", self.success.1); let success = format_ident!("{}", self.options.success.1);
let mut result_variants = Vec::new(); let mut result_variants = Vec::new();
let mut error_variants = Vec::new(); let mut error_variants = Vec::new();
for const_ in self.constants.iter() { for const_ in self.constants.iter() {
let ident = const_.ident.to_string(); let ident = const_.ident.to_string();
if ident.ends_with(self.success.0) { if ident.ends_with(self.options.success.0) {
result_variants.push(quote! { result_variants.push(quote! {
const #success: #type_ = #type_::Ok(()); const #success: #type_ = #type_::Ok(());
}); });
} else { } else {
let old_prefix_len = self.underlying_type.len() + 1 + self.error_prefix.0.len(); let old_prefix_len =
let variant_ident = self.options.underlying_type.len() + 1 + self.options.error_prefix.0.len();
format_ident!("{}{}", self.error_prefix.1, &ident[old_prefix_len..]); let variant_ident = format_ident!(
"{}{}",
self.options.error_prefix.1,
&ident[old_prefix_len..]
);
let error_ident = format_ident!("{}", &ident[old_prefix_len..]); let error_ident = format_ident!("{}", &ident[old_prefix_len..]);
let expr = &const_.expr; let expr = &const_.expr;
result_variants.push(quote! { result_variants.push(quote! {
@ -1147,7 +1159,7 @@ impl ConvertIntoRustResult {
} }
fn get_type(&self, type_: syn::ItemType) -> Option<syn::ItemType> { fn get_type(&self, type_: syn::ItemType) -> Option<syn::ItemType> {
if type_.ident.to_string() == self.type_ { if type_.ident.to_string() == self.options.type_ {
None None
} else { } else {
Some(type_) Some(type_)