From 99c36092bea198c90fd00f6eea83c728e497ff35 Mon Sep 17 00:00:00 2001 From: Violet Date: Thu, 31 Jul 2025 09:52:10 -0700 Subject: [PATCH] Use `FromCuda` in `zluda_blas` (#455) --- Cargo.lock | 1 + cuda_types/src/cublas.rs | 4 +++- format/src/format_generated_blas.rs | 6 +----- zluda_bindgen/src/main.rs | 1 + zluda_blas/Cargo.toml | 1 + zluda_blas/src/lib.rs | 17 ++++++++++++++++- zluda_common/src/lib.rs | 11 +++++++++-- 7 files changed, 32 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1291f17..574e9af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1816,6 +1816,7 @@ version = "0.0.0" dependencies = [ "cuda_macros", "cuda_types", + "zluda_common", ] [[package]] diff --git a/cuda_types/src/cublas.rs b/cuda_types/src/cublas.rs index dc422a3..cb9190d 100644 --- a/cuda_types/src/cublas.rs +++ b/cuda_types/src/cublas.rs @@ -284,7 +284,9 @@ pub struct cublasComputeType_t(pub ::core::ffi::c_uint); pub struct cublasContext { _unused: [u8; 0], } -pub type cublasHandle_t = *mut cublasContext; +#[repr(transparent)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct cublasHandle_t(pub *mut cublasContext); pub type cublasLogCallback = ::core::option::Option< unsafe extern "C" fn(msg: *const ::core::ffi::c_char), >; diff --git a/format/src/format_generated_blas.rs b/format/src/format_generated_blas.rs index 3c01f23..d57d664 100644 --- a/format/src/format_generated_blas.rs +++ b/format/src/format_generated_blas.rs @@ -348,11 +348,7 @@ impl crate::CudaDisplay for cuda_types::cublas::cublasHandle_t { _index: usize, writer: &mut (impl std::io::Write + ?Sized), ) -> std::io::Result<()> { - if self.is_null() { - writer.write_all(b"NULL") - } else { - write!(writer, "{:p}", *self) - } + write!(writer, "{:p}", self.0) } } pub fn write_cublasCreate_v2( diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs index 4439683..210fdd9 100644 --- a/zluda_bindgen/src/main.rs +++ b/zluda_bindgen/src/main.rs @@ -658,6 +658,7 @@ fn generate_cublas(crate_root: &PathBuf) { .allowlist_var("^CUBLAS_.*") .must_use_type("cublasStatus_t") .constified_enum("cublasStatus_t") + .new_type_alias(r"^cublasHandle_t$") .allowlist_recursively(false) .clang_args(["-I/usr/local/cuda/include", "-x", "c++"]) .generate() diff --git a/zluda_blas/Cargo.toml b/zluda_blas/Cargo.toml index 4a2752c..f0f1046 100644 --- a/zluda_blas/Cargo.toml +++ b/zluda_blas/Cargo.toml @@ -10,6 +10,7 @@ name = "cublas" [dependencies] cuda_macros = { path = "../cuda_macros" } cuda_types = { path = "../cuda_types" } +zluda_common = { path = "../zluda_common" } [package.metadata.zluda] linux_symlinks = [ diff --git a/zluda_blas/src/lib.rs b/zluda_blas/src/lib.rs index 09bb28f..c650759 100644 --- a/zluda_blas/src/lib.rs +++ b/zluda_blas/src/lib.rs @@ -14,6 +14,20 @@ macro_rules! unimplemented { } macro_rules! implemented { + ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => { + $( + #[cfg_attr(not(test), no_mangle)] + #[allow(improper_ctypes)] + #[allow(improper_ctypes_definitions)] + pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { + cuda_macros::cublas_normalize_fn!( crate::r#impl::$fn_name ) ($(zluda_common::FromCuda::<_, cublasError_t>::from_cuda(&$arg_id)?),*)?; + Ok(()) + } + )* + }; +} + +macro_rules! implemented_and_always_succeeds { ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => { $( #[cfg_attr(not(test), no_mangle)] @@ -28,7 +42,8 @@ macro_rules! implemented { cuda_macros::cublas_function_declarations!( unimplemented, - implemented + implemented <= [], + implemented_and_always_succeeds <= [ cublasGetStatusName, cublasGetStatusString, diff --git a/zluda_common/src/lib.rs b/zluda_common/src/lib.rs index 33a465f..4615a48 100644 --- a/zluda_common/src/lib.rs +++ b/zluda_common/src/lib.rs @@ -1,4 +1,4 @@ -use cuda_types::cuda::*; +use cuda_types::{cublas::*, cuda::*}; use hip_runtime_sys::*; use std::{ ffi::CStr, @@ -16,6 +16,11 @@ impl CudaErrorType for CUerror { const NOT_SUPPORTED: Self = Self::NOT_SUPPORTED; } +impl CudaErrorType for cublasError_t { + const INVALID_VALUE: Self = Self::INVALID_VALUE; + const NOT_SUPPORTED: Self = Self::NOT_SUPPORTED; +} + /// Used to try to convert CUDA API values into our internal representation. /// /// Similar to [`TryFrom`], but we can implement this for primitive types. We also provide conversions from pointers to references. @@ -123,7 +128,9 @@ from_cuda_nop!( CUuuid, CUlibrary, CUmodule, - CUcontext + CUcontext, + cublasHandle_t, + cublasStatus_t ); from_cuda_transmute!( CUuuid => hipUUID,