Slight refactor of ConvertIntoRustResult

This commit is contained in:
Violet
2025-07-28 01:24:05 +00:00
parent f192dd317a
commit db9ae4b430

View File

@ -783,14 +783,13 @@ fn generate_ml(crate_root: &PathBuf) {
.unwrap()
.to_string();
let mut module: syn::File = syn::parse_str(&ml_header).unwrap();
let mut converter = ConvertIntoRustResult {
let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions {
type_: "nvmlReturn_t",
underlying_type: "nvmlReturn_enum",
new_error_type: "nvmlError_t",
error_prefix: ("NVML_ERROR_", "ERROR_"),
success: ("NVML_SUCCESS", "SUCCESS"),
constants: Vec::new(),
};
});
module.items = module
.items
.into_iter()
@ -888,14 +887,13 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
.unwrap()
.to_string();
let mut module: syn::File = syn::parse_str(&hiprt_header).unwrap();
let mut converter = ConvertIntoRustResult {
let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions {
type_: "hipError_t",
underlying_type: "hipError_t",
new_error_type: "hipErrorCode_t",
error_prefix: ("hipError", "Error"),
success: ("hipSuccess", "Success"),
constants: Vec::new(),
};
});
module.items = module
.items
.into_iter()
@ -999,14 +997,13 @@ fn generate_functions(
fn generate_types_cuda(output: &PathBuf, path: &[&str], module: &syn::File) {
let mut module = module.clone();
let mut converter = ConvertIntoRustResult {
let mut converter = ConvertIntoRustResult::new(ConvertIntoRustResultOptions {
type_: "CUresult",
underlying_type: "cudaError_enum",
new_error_type: "CUerror",
error_prefix: ("CUDA_ERROR_", "ERROR_"),
success: ("CUDA_SUCCESS", "SUCCESS"),
constants: Vec::new(),
};
});
module.items = module
.items
.into_iter()
@ -1067,19 +1064,30 @@ fn write_rust_to_file(path: impl AsRef<std::path::Path>, content: &str) {
file.write(content.as_bytes()).unwrap();
}
struct ConvertIntoRustResult {
struct ConvertIntoRustResultOptions {
type_: &'static str,
underlying_type: &'static str,
new_error_type: &'static str,
error_prefix: (&'static str, &'static str),
success: (&'static str, &'static str),
}
struct ConvertIntoRustResult {
options: ConvertIntoRustResultOptions,
constants: Vec<syn::ItemConst>,
}
impl ConvertIntoRustResult {
fn new(options: ConvertIntoRustResultOptions) -> Self {
Self {
options,
constants: vec![],
}
}
fn get_const(&mut self, const_: syn::ItemConst) -> Option<syn::ItemConst> {
let name = const_.ident.to_string();
if name.starts_with(self.underlying_type) {
if name.starts_with(self.options.underlying_type) {
self.constants.push(const_);
None
} else {
@ -1090,7 +1098,7 @@ impl ConvertIntoRustResult {
fn get_use(&mut self, use_: ItemUse) -> Option<ItemUse> {
if let UseTree::Path(ref path) = use_.tree {
if let UseTree::Rename(ref rename) = &*path.tree {
if rename.rename == self.type_ {
if rename.rename == self.options.type_ {
return None;
}
}
@ -1099,22 +1107,26 @@ impl ConvertIntoRustResult {
}
fn flush(self, items: &mut Vec<Item>) {
let type_ = format_ident!("{}", self.type_);
let type_trait = format_ident!("{}Consts", self.type_);
let new_error_type = format_ident!("{}", self.new_error_type);
let success = format_ident!("{}", self.success.1);
let type_ = format_ident!("{}", self.options.type_);
let type_trait = format_ident!("{}Consts", self.options.type_);
let new_error_type = format_ident!("{}", self.options.new_error_type);
let success = format_ident!("{}", self.options.success.1);
let mut result_variants = Vec::new();
let mut error_variants = Vec::new();
for const_ in self.constants.iter() {
let ident = const_.ident.to_string();
if ident.ends_with(self.success.0) {
if ident.ends_with(self.options.success.0) {
result_variants.push(quote! {
const #success: #type_ = #type_::Ok(());
});
} else {
let old_prefix_len = self.underlying_type.len() + 1 + self.error_prefix.0.len();
let variant_ident =
format_ident!("{}{}", self.error_prefix.1, &ident[old_prefix_len..]);
let old_prefix_len =
self.options.underlying_type.len() + 1 + self.options.error_prefix.0.len();
let variant_ident = format_ident!(
"{}{}",
self.options.error_prefix.1,
&ident[old_prefix_len..]
);
let error_ident = format_ident!("{}", &ident[old_prefix_len..]);
let expr = &const_.expr;
result_variants.push(quote! {
@ -1147,7 +1159,7 @@ impl ConvertIntoRustResult {
}
fn get_type(&self, type_: syn::ItemType) -> Option<syn::ItemType> {
if type_.ident.to_string() == self.type_ {
if type_.ident.to_string() == self.options.type_ {
None
} else {
Some(type_)