Implement From for result structs

This commit is contained in:
Violet
2025-07-28 05:49:13 +00:00
parent 7c129806a0
commit 91934ea522

View File

@ -170,6 +170,7 @@ fn generate_cufft(crate_root: &PathBuf) {
new_error_type: "cufftError_t",
error_prefix: ("CUFFT_", "ERROR_"),
success: ("CUFFT_SUCCESS", "SUCCESS"),
hip_type: None,
};
generate_types_library(
Some(&result_options),
@ -235,6 +236,7 @@ fn generate_cusparse(crate_root: &PathBuf) {
new_error_type: "cusparseError_t",
error_prefix: ("CUSPARSE_STATUS_", "ERROR_"),
success: ("CUSPARSE_STATUS_SUCCESS", "SUCCESS"),
hip_type: None,
};
generate_types_library(
Some(&result_options),
@ -271,6 +273,7 @@ fn generate_cudnn(crate_root: &PathBuf) {
new_error_type: "cudnnError_",
error_prefix: ("CUDNN_STATUS_", "ERROR_"),
success: ("CUDNN_STATUS_SUCCESS", "SUCCESS"),
hip_type: None,
};
let cudnn9_module: syn::File = syn::parse_str(&cudnn9).unwrap();
let cudnn9_types = generate_types_library_impl(Some(&result_options), &cudnn9_module);
@ -672,6 +675,7 @@ fn generate_cublas(crate_root: &PathBuf) {
new_error_type: "cublasError_t",
error_prefix: ("CUBLAS_STATUS_", "ERROR_"),
success: ("CUBLAS_STATUS_SUCCESS", "SUCCESS"),
hip_type: Some(syn::parse_str("rocblas_sys::rocblas_error").unwrap()),
};
generate_types_library(
Some(&result_options),
@ -804,6 +808,7 @@ fn generate_cuda(crate_root: &PathBuf) -> Vec<Ident> {
new_error_type: "CUerror",
error_prefix: ("CUDA_ERROR_", "ERROR_"),
success: ("CUDA_SUCCESS", "SUCCESS"),
hip_type: Some(syn::parse_str("hip_runtime_sys::hipErrorCode_t").unwrap()),
};
generate_types_cuda(
&result_options,
@ -846,6 +851,7 @@ fn generate_ml(crate_root: &PathBuf) {
new_error_type: "nvmlError_t",
error_prefix: ("NVML_ERROR_", "ERROR_"),
success: ("NVML_SUCCESS", "SUCCESS"),
hip_type: None,
};
generate_types_library(
Some(&result_options),
@ -955,6 +961,7 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
new_error_type: "hipErrorCode_t",
error_prefix: ("hipError", "Error"),
success: ("hipSuccess", "Success"),
hip_type: None,
});
module.items = converter.convert(module.items).collect::<Vec<Item>>();
converter.flush(&mut module.items);
@ -1081,13 +1088,6 @@ fn generate_types_cuda(
})
.collect::<Vec<_>>();
converter.flush(&mut module.items);
module.items.push(parse_quote! {
impl From<hip_runtime_sys::hipErrorCode_t> for CUerror {
fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self {
Self(error.0)
}
}
});
add_send_sync(
&mut module.items,
&[
@ -1119,6 +1119,8 @@ struct ConvertIntoRustResultOptions {
new_error_type: &'static str,
error_prefix: (&'static str, &'static str),
success: (&'static str, &'static str),
// TODO: this should no longer be an Option once all hip perf libraries are present
hip_type: Option<Path>,
}
struct ConvertIntoRustResult {
@ -1205,6 +1207,13 @@ impl ConvertIntoRustResult {
};
};
items.extend(extra_items);
if let Some(hip_error_path) = self.options.hip_type {
items.push(parse_quote! {impl From<#hip_error_path> for #new_error_type {
fn from(error: #hip_error_path) -> Self {
Self(error.0)
}
}});
}
}
fn get_type(&self, type_: syn::ItemType) -> Option<syn::ItemType> {