mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 14:57:43 +03:00
Implement From for result structs
This commit is contained in:
@ -170,6 +170,7 @@ fn generate_cufft(crate_root: &PathBuf) {
|
||||
new_error_type: "cufftError_t",
|
||||
error_prefix: ("CUFFT_", "ERROR_"),
|
||||
success: ("CUFFT_SUCCESS", "SUCCESS"),
|
||||
hip_type: None,
|
||||
};
|
||||
generate_types_library(
|
||||
Some(&result_options),
|
||||
@ -235,6 +236,7 @@ fn generate_cusparse(crate_root: &PathBuf) {
|
||||
new_error_type: "cusparseError_t",
|
||||
error_prefix: ("CUSPARSE_STATUS_", "ERROR_"),
|
||||
success: ("CUSPARSE_STATUS_SUCCESS", "SUCCESS"),
|
||||
hip_type: None,
|
||||
};
|
||||
generate_types_library(
|
||||
Some(&result_options),
|
||||
@ -271,6 +273,7 @@ fn generate_cudnn(crate_root: &PathBuf) {
|
||||
new_error_type: "cudnnError_",
|
||||
error_prefix: ("CUDNN_STATUS_", "ERROR_"),
|
||||
success: ("CUDNN_STATUS_SUCCESS", "SUCCESS"),
|
||||
hip_type: None,
|
||||
};
|
||||
let cudnn9_module: syn::File = syn::parse_str(&cudnn9).unwrap();
|
||||
let cudnn9_types = generate_types_library_impl(Some(&result_options), &cudnn9_module);
|
||||
@ -672,6 +675,7 @@ fn generate_cublas(crate_root: &PathBuf) {
|
||||
new_error_type: "cublasError_t",
|
||||
error_prefix: ("CUBLAS_STATUS_", "ERROR_"),
|
||||
success: ("CUBLAS_STATUS_SUCCESS", "SUCCESS"),
|
||||
hip_type: Some(syn::parse_str("rocblas_sys::rocblas_error").unwrap()),
|
||||
};
|
||||
generate_types_library(
|
||||
Some(&result_options),
|
||||
@ -804,6 +808,7 @@ fn generate_cuda(crate_root: &PathBuf) -> Vec<Ident> {
|
||||
new_error_type: "CUerror",
|
||||
error_prefix: ("CUDA_ERROR_", "ERROR_"),
|
||||
success: ("CUDA_SUCCESS", "SUCCESS"),
|
||||
hip_type: Some(syn::parse_str("hip_runtime_sys::hipErrorCode_t").unwrap()),
|
||||
};
|
||||
generate_types_cuda(
|
||||
&result_options,
|
||||
@ -846,6 +851,7 @@ fn generate_ml(crate_root: &PathBuf) {
|
||||
new_error_type: "nvmlError_t",
|
||||
error_prefix: ("NVML_ERROR_", "ERROR_"),
|
||||
success: ("NVML_SUCCESS", "SUCCESS"),
|
||||
hip_type: None,
|
||||
};
|
||||
generate_types_library(
|
||||
Some(&result_options),
|
||||
@ -955,6 +961,7 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
|
||||
new_error_type: "hipErrorCode_t",
|
||||
error_prefix: ("hipError", "Error"),
|
||||
success: ("hipSuccess", "Success"),
|
||||
hip_type: None,
|
||||
});
|
||||
module.items = converter.convert(module.items).collect::<Vec<Item>>();
|
||||
converter.flush(&mut module.items);
|
||||
@ -1081,13 +1088,6 @@ fn generate_types_cuda(
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
converter.flush(&mut module.items);
|
||||
module.items.push(parse_quote! {
|
||||
impl From<hip_runtime_sys::hipErrorCode_t> for CUerror {
|
||||
fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self {
|
||||
Self(error.0)
|
||||
}
|
||||
}
|
||||
});
|
||||
add_send_sync(
|
||||
&mut module.items,
|
||||
&[
|
||||
@ -1119,6 +1119,8 @@ struct ConvertIntoRustResultOptions {
|
||||
new_error_type: &'static str,
|
||||
error_prefix: (&'static str, &'static str),
|
||||
success: (&'static str, &'static str),
|
||||
// TODO: this should no longer be an Option once all hip perf libraries are present
|
||||
hip_type: Option<Path>,
|
||||
}
|
||||
|
||||
struct ConvertIntoRustResult {
|
||||
@ -1205,6 +1207,13 @@ impl ConvertIntoRustResult {
|
||||
};
|
||||
};
|
||||
items.extend(extra_items);
|
||||
if let Some(hip_error_path) = self.options.hip_type {
|
||||
items.push(parse_quote! {impl From<#hip_error_path> for #new_error_type {
|
||||
fn from(error: #hip_error_path) -> Self {
|
||||
Self(error.0)
|
||||
}
|
||||
}});
|
||||
}
|
||||
}
|
||||
|
||||
fn get_type(&self, type_: syn::ItemType) -> Option<syn::ItemType> {
|
||||
|
Reference in New Issue
Block a user