Use FromCuda in zluda_blas (#455)

This commit is contained in:
Violet
2025-07-31 09:52:10 -07:00
committed by GitHub
parent 49aabffdcc
commit 99c36092be
7 changed files with 32 additions and 9 deletions

1
Cargo.lock generated
View File

@ -1816,6 +1816,7 @@ version = "0.0.0"
dependencies = [ dependencies = [
"cuda_macros", "cuda_macros",
"cuda_types", "cuda_types",
"zluda_common",
] ]
[[package]] [[package]]

View File

@ -284,7 +284,9 @@ pub struct cublasComputeType_t(pub ::core::ffi::c_uint);
pub struct cublasContext { pub struct cublasContext {
_unused: [u8; 0], _unused: [u8; 0],
} }
pub type cublasHandle_t = *mut cublasContext; #[repr(transparent)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct cublasHandle_t(pub *mut cublasContext);
pub type cublasLogCallback = ::core::option::Option< pub type cublasLogCallback = ::core::option::Option<
unsafe extern "C" fn(msg: *const ::core::ffi::c_char), unsafe extern "C" fn(msg: *const ::core::ffi::c_char),
>; >;

View File

@ -348,11 +348,7 @@ impl crate::CudaDisplay for cuda_types::cublas::cublasHandle_t {
_index: usize, _index: usize,
writer: &mut (impl std::io::Write + ?Sized), writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> { ) -> std::io::Result<()> {
if self.is_null() { write!(writer, "{:p}", self.0)
writer.write_all(b"NULL")
} else {
write!(writer, "{:p}", *self)
}
} }
} }
pub fn write_cublasCreate_v2( pub fn write_cublasCreate_v2(

View File

@ -658,6 +658,7 @@ fn generate_cublas(crate_root: &PathBuf) {
.allowlist_var("^CUBLAS_.*") .allowlist_var("^CUBLAS_.*")
.must_use_type("cublasStatus_t") .must_use_type("cublasStatus_t")
.constified_enum("cublasStatus_t") .constified_enum("cublasStatus_t")
.new_type_alias(r"^cublasHandle_t$")
.allowlist_recursively(false) .allowlist_recursively(false)
.clang_args(["-I/usr/local/cuda/include", "-x", "c++"]) .clang_args(["-I/usr/local/cuda/include", "-x", "c++"])
.generate() .generate()

View File

@ -10,6 +10,7 @@ name = "cublas"
[dependencies] [dependencies]
cuda_macros = { path = "../cuda_macros" } cuda_macros = { path = "../cuda_macros" }
cuda_types = { path = "../cuda_types" } cuda_types = { path = "../cuda_types" }
zluda_common = { path = "../zluda_common" }
[package.metadata.zluda] [package.metadata.zluda]
linux_symlinks = [ linux_symlinks = [

View File

@ -14,6 +14,20 @@ macro_rules! unimplemented {
} }
macro_rules! implemented { macro_rules! implemented {
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
$(
#[cfg_attr(not(test), no_mangle)]
#[allow(improper_ctypes)]
#[allow(improper_ctypes_definitions)]
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
cuda_macros::cublas_normalize_fn!( crate::r#impl::$fn_name ) ($(zluda_common::FromCuda::<_, cublasError_t>::from_cuda(&$arg_id)?),*)?;
Ok(())
}
)*
};
}
macro_rules! implemented_and_always_succeeds {
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => { ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
$( $(
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
@ -28,7 +42,8 @@ macro_rules! implemented {
cuda_macros::cublas_function_declarations!( cuda_macros::cublas_function_declarations!(
unimplemented, unimplemented,
implemented implemented <= [],
implemented_and_always_succeeds
<= [ <= [
cublasGetStatusName, cublasGetStatusName,
cublasGetStatusString, cublasGetStatusString,

View File

@ -1,4 +1,4 @@
use cuda_types::cuda::*; use cuda_types::{cublas::*, cuda::*};
use hip_runtime_sys::*; use hip_runtime_sys::*;
use std::{ use std::{
ffi::CStr, ffi::CStr,
@ -16,6 +16,11 @@ impl CudaErrorType for CUerror {
const NOT_SUPPORTED: Self = Self::NOT_SUPPORTED; const NOT_SUPPORTED: Self = Self::NOT_SUPPORTED;
} }
impl CudaErrorType for cublasError_t {
const INVALID_VALUE: Self = Self::INVALID_VALUE;
const NOT_SUPPORTED: Self = Self::NOT_SUPPORTED;
}
/// Used to try to convert CUDA API values into our internal representation. /// Used to try to convert CUDA API values into our internal representation.
/// ///
/// Similar to [`TryFrom`], but we can implement this for primitive types. We also provide conversions from pointers to references. /// Similar to [`TryFrom`], but we can implement this for primitive types. We also provide conversions from pointers to references.
@ -123,7 +128,9 @@ from_cuda_nop!(
CUuuid, CUuuid,
CUlibrary, CUlibrary,
CUmodule, CUmodule,
CUcontext CUcontext,
cublasHandle_t,
cublasStatus_t
); );
from_cuda_transmute!( from_cuda_transmute!(
CUuuid => hipUUID, CUuuid => hipUUID,