mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 06:47:46 +03:00
Use FromCuda
in zluda_blas
(#455)
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -1816,6 +1816,7 @@ version = "0.0.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"cuda_macros",
|
"cuda_macros",
|
||||||
"cuda_types",
|
"cuda_types",
|
||||||
|
"zluda_common",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -284,7 +284,9 @@ pub struct cublasComputeType_t(pub ::core::ffi::c_uint);
|
|||||||
pub struct cublasContext {
|
pub struct cublasContext {
|
||||||
_unused: [u8; 0],
|
_unused: [u8; 0],
|
||||||
}
|
}
|
||||||
pub type cublasHandle_t = *mut cublasContext;
|
#[repr(transparent)]
|
||||||
|
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
|
||||||
|
pub struct cublasHandle_t(pub *mut cublasContext);
|
||||||
pub type cublasLogCallback = ::core::option::Option<
|
pub type cublasLogCallback = ::core::option::Option<
|
||||||
unsafe extern "C" fn(msg: *const ::core::ffi::c_char),
|
unsafe extern "C" fn(msg: *const ::core::ffi::c_char),
|
||||||
>;
|
>;
|
||||||
|
@ -348,11 +348,7 @@ impl crate::CudaDisplay for cuda_types::cublas::cublasHandle_t {
|
|||||||
_index: usize,
|
_index: usize,
|
||||||
writer: &mut (impl std::io::Write + ?Sized),
|
writer: &mut (impl std::io::Write + ?Sized),
|
||||||
) -> std::io::Result<()> {
|
) -> std::io::Result<()> {
|
||||||
if self.is_null() {
|
write!(writer, "{:p}", self.0)
|
||||||
writer.write_all(b"NULL")
|
|
||||||
} else {
|
|
||||||
write!(writer, "{:p}", *self)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub fn write_cublasCreate_v2(
|
pub fn write_cublasCreate_v2(
|
||||||
|
@ -658,6 +658,7 @@ fn generate_cublas(crate_root: &PathBuf) {
|
|||||||
.allowlist_var("^CUBLAS_.*")
|
.allowlist_var("^CUBLAS_.*")
|
||||||
.must_use_type("cublasStatus_t")
|
.must_use_type("cublasStatus_t")
|
||||||
.constified_enum("cublasStatus_t")
|
.constified_enum("cublasStatus_t")
|
||||||
|
.new_type_alias(r"^cublasHandle_t$")
|
||||||
.allowlist_recursively(false)
|
.allowlist_recursively(false)
|
||||||
.clang_args(["-I/usr/local/cuda/include", "-x", "c++"])
|
.clang_args(["-I/usr/local/cuda/include", "-x", "c++"])
|
||||||
.generate()
|
.generate()
|
||||||
|
@ -10,6 +10,7 @@ name = "cublas"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
cuda_macros = { path = "../cuda_macros" }
|
cuda_macros = { path = "../cuda_macros" }
|
||||||
cuda_types = { path = "../cuda_types" }
|
cuda_types = { path = "../cuda_types" }
|
||||||
|
zluda_common = { path = "../zluda_common" }
|
||||||
|
|
||||||
[package.metadata.zluda]
|
[package.metadata.zluda]
|
||||||
linux_symlinks = [
|
linux_symlinks = [
|
||||||
|
@ -14,6 +14,20 @@ macro_rules! unimplemented {
|
|||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! implemented {
|
macro_rules! implemented {
|
||||||
|
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
|
||||||
|
$(
|
||||||
|
#[cfg_attr(not(test), no_mangle)]
|
||||||
|
#[allow(improper_ctypes)]
|
||||||
|
#[allow(improper_ctypes_definitions)]
|
||||||
|
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
||||||
|
cuda_macros::cublas_normalize_fn!( crate::r#impl::$fn_name ) ($(zluda_common::FromCuda::<_, cublasError_t>::from_cuda(&$arg_id)?),*)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
)*
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! implemented_and_always_succeeds {
|
||||||
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
|
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
|
||||||
$(
|
$(
|
||||||
#[cfg_attr(not(test), no_mangle)]
|
#[cfg_attr(not(test), no_mangle)]
|
||||||
@ -28,7 +42,8 @@ macro_rules! implemented {
|
|||||||
|
|
||||||
cuda_macros::cublas_function_declarations!(
|
cuda_macros::cublas_function_declarations!(
|
||||||
unimplemented,
|
unimplemented,
|
||||||
implemented
|
implemented <= [],
|
||||||
|
implemented_and_always_succeeds
|
||||||
<= [
|
<= [
|
||||||
cublasGetStatusName,
|
cublasGetStatusName,
|
||||||
cublasGetStatusString,
|
cublasGetStatusString,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use cuda_types::cuda::*;
|
use cuda_types::{cublas::*, cuda::*};
|
||||||
use hip_runtime_sys::*;
|
use hip_runtime_sys::*;
|
||||||
use std::{
|
use std::{
|
||||||
ffi::CStr,
|
ffi::CStr,
|
||||||
@ -16,6 +16,11 @@ impl CudaErrorType for CUerror {
|
|||||||
const NOT_SUPPORTED: Self = Self::NOT_SUPPORTED;
|
const NOT_SUPPORTED: Self = Self::NOT_SUPPORTED;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl CudaErrorType for cublasError_t {
|
||||||
|
const INVALID_VALUE: Self = Self::INVALID_VALUE;
|
||||||
|
const NOT_SUPPORTED: Self = Self::NOT_SUPPORTED;
|
||||||
|
}
|
||||||
|
|
||||||
/// Used to try to convert CUDA API values into our internal representation.
|
/// Used to try to convert CUDA API values into our internal representation.
|
||||||
///
|
///
|
||||||
/// Similar to [`TryFrom`], but we can implement this for primitive types. We also provide conversions from pointers to references.
|
/// Similar to [`TryFrom`], but we can implement this for primitive types. We also provide conversions from pointers to references.
|
||||||
@ -123,7 +128,9 @@ from_cuda_nop!(
|
|||||||
CUuuid,
|
CUuuid,
|
||||||
CUlibrary,
|
CUlibrary,
|
||||||
CUmodule,
|
CUmodule,
|
||||||
CUcontext
|
CUcontext,
|
||||||
|
cublasHandle_t,
|
||||||
|
cublasStatus_t
|
||||||
);
|
);
|
||||||
from_cuda_transmute!(
|
from_cuda_transmute!(
|
||||||
CUuuid => hipUUID,
|
CUuuid => hipUUID,
|
||||||
|
Reference in New Issue
Block a user