mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 06:47:46 +03:00
Implement cublas functions needed for llm.c (#457)
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -1816,6 +1816,7 @@ version = "0.0.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"cuda_macros",
|
"cuda_macros",
|
||||||
"cuda_types",
|
"cuda_types",
|
||||||
|
"rocblas-sys",
|
||||||
"zluda_common",
|
"zluda_common",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -1833,6 +1834,7 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"cuda_types",
|
"cuda_types",
|
||||||
"hip_runtime-sys",
|
"hip_runtime-sys",
|
||||||
|
"rocblas-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -11,6 +11,7 @@ name = "cublas"
|
|||||||
cuda_macros = { path = "../cuda_macros" }
|
cuda_macros = { path = "../cuda_macros" }
|
||||||
cuda_types = { path = "../cuda_types" }
|
cuda_types = { path = "../cuda_types" }
|
||||||
zluda_common = { path = "../zluda_common" }
|
zluda_common = { path = "../zluda_common" }
|
||||||
|
rocblas-sys = { path = "../ext/rocblas-sys" }
|
||||||
|
|
||||||
[package.metadata.zluda]
|
[package.metadata.zluda]
|
||||||
linux_symlinks = [
|
linux_symlinks = [
|
||||||
|
@ -1,4 +1,34 @@
|
|||||||
|
use std::mem;
|
||||||
|
|
||||||
use cuda_types::cublas::*;
|
use cuda_types::cublas::*;
|
||||||
|
use zluda_common::{from_cuda_object, ZludaObject};
|
||||||
|
|
||||||
|
use rocblas_sys::*;
|
||||||
|
|
||||||
|
pub struct Handle {
|
||||||
|
handle: rocblas_handle,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Handle {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
handle: unsafe { mem::zeroed() },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ZludaObject for Handle {
|
||||||
|
const COOKIE: usize = 0x57c3fdb0fd72b08e;
|
||||||
|
|
||||||
|
type Error = cublasError_t;
|
||||||
|
type CudaHandle = cublasHandle_t;
|
||||||
|
|
||||||
|
fn drop_checked(&mut self) -> cublasStatus_t {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
from_cuda_object!(Handle);
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
pub(crate) fn unimplemented() -> cublasStatus_t {
|
pub(crate) fn unimplemented() -> cublasStatus_t {
|
||||||
@ -10,6 +40,13 @@ pub(crate) fn unimplemented() -> cublasStatus_t {
|
|||||||
cublasStatus_t::ERROR_NOT_SUPPORTED
|
cublasStatus_t::ERROR_NOT_SUPPORTED
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn create_v2(handle: &mut cublasHandle_t) -> cublasStatus_t {
|
||||||
|
let mut zluda_blas_handle = Handle::new();
|
||||||
|
unsafe { rocblas_create_handle(&mut zluda_blas_handle.handle) }?;
|
||||||
|
*handle = Handle::wrap(zluda_blas_handle);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn get_status_name(_status: cublasStatus_t) -> *const ::core::ffi::c_char {
|
pub(crate) fn get_status_name(_status: cublasStatus_t) -> *const ::core::ffi::c_char {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
@ -25,3 +62,94 @@ pub(crate) fn xerbla(_sr_name: *const ::core::ffi::c_char, _info: ::core::ffi::c
|
|||||||
pub(crate) fn get_cudart_version() -> usize {
|
pub(crate) fn get_cudart_version() -> usize {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn set_math_mode(handle: &Handle, mode: rocblas_math_mode) -> cublasStatus_t {
|
||||||
|
unsafe { rocblas_set_math_mode(handle.handle, mode) }?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn sgemm_strided_batched(
|
||||||
|
handle: &Handle,
|
||||||
|
transa: rocblas_operation,
|
||||||
|
transb: rocblas_operation,
|
||||||
|
m: ::core::ffi::c_int,
|
||||||
|
n: ::core::ffi::c_int,
|
||||||
|
k: ::core::ffi::c_int,
|
||||||
|
alpha: *const f32,
|
||||||
|
a: *const f32,
|
||||||
|
lda: ::core::ffi::c_int,
|
||||||
|
stride_a: ::core::ffi::c_longlong,
|
||||||
|
b: *const f32,
|
||||||
|
ldb: ::core::ffi::c_int,
|
||||||
|
stride_b: ::core::ffi::c_longlong,
|
||||||
|
beta: *const f32,
|
||||||
|
c: *mut f32,
|
||||||
|
ldc: ::core::ffi::c_int,
|
||||||
|
stride_c: ::core::ffi::c_longlong,
|
||||||
|
batch_count: ::core::ffi::c_int,
|
||||||
|
) -> cublasStatus_t {
|
||||||
|
unsafe {
|
||||||
|
rocblas_sgemm_strided_batched(
|
||||||
|
handle.handle,
|
||||||
|
transa,
|
||||||
|
transb,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
alpha,
|
||||||
|
a,
|
||||||
|
lda,
|
||||||
|
stride_a,
|
||||||
|
b,
|
||||||
|
ldb,
|
||||||
|
stride_b,
|
||||||
|
beta,
|
||||||
|
c,
|
||||||
|
ldc,
|
||||||
|
stride_c,
|
||||||
|
batch_count,
|
||||||
|
)
|
||||||
|
}?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn sgemm_v2(
|
||||||
|
handle: &Handle,
|
||||||
|
transa: rocblas_operation,
|
||||||
|
transb: rocblas_operation,
|
||||||
|
m: ::core::ffi::c_int,
|
||||||
|
n: ::core::ffi::c_int,
|
||||||
|
k: ::core::ffi::c_int,
|
||||||
|
alpha: *const f32,
|
||||||
|
a: *const f32,
|
||||||
|
lda: ::core::ffi::c_int,
|
||||||
|
b: *const f32,
|
||||||
|
ldb: ::core::ffi::c_int,
|
||||||
|
beta: *const f32,
|
||||||
|
c: *mut f32,
|
||||||
|
ldc: ::core::ffi::c_int,
|
||||||
|
) -> cublasStatus_t {
|
||||||
|
unsafe {
|
||||||
|
rocblas_sgemm(
|
||||||
|
handle.handle,
|
||||||
|
transa,
|
||||||
|
transb,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
alpha,
|
||||||
|
a,
|
||||||
|
lda,
|
||||||
|
b,
|
||||||
|
ldb,
|
||||||
|
beta,
|
||||||
|
c,
|
||||||
|
ldc,
|
||||||
|
)
|
||||||
|
}?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn destroy_v2(handle: cublasHandle_t) -> cublasStatus_t {
|
||||||
|
zluda_common::drop_checked::<Handle>(handle)
|
||||||
|
}
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
mod r#impl;
|
mod r#impl;
|
||||||
|
|
||||||
|
use cuda_types::cublas::cublasError_t;
|
||||||
|
|
||||||
macro_rules! unimplemented {
|
macro_rules! unimplemented {
|
||||||
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
|
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
|
||||||
$(
|
$(
|
||||||
@ -42,7 +44,14 @@ macro_rules! implemented_and_always_succeeds {
|
|||||||
|
|
||||||
cuda_macros::cublas_function_declarations!(
|
cuda_macros::cublas_function_declarations!(
|
||||||
unimplemented,
|
unimplemented,
|
||||||
implemented <= [],
|
implemented
|
||||||
|
<= [
|
||||||
|
cublasCreate_v2,
|
||||||
|
cublasSetMathMode,
|
||||||
|
cublasSgemmStridedBatched,
|
||||||
|
cublasSgemm_v2,
|
||||||
|
cublasDestroy_v2
|
||||||
|
],
|
||||||
implemented_and_always_succeeds
|
implemented_and_always_succeeds
|
||||||
<= [
|
<= [
|
||||||
cublasGetStatusName,
|
cublasGetStatusName,
|
||||||
|
@ -7,3 +7,4 @@ edition = "2021"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
cuda_types = { path = "../cuda_types" }
|
cuda_types = { path = "../cuda_types" }
|
||||||
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
||||||
|
rocblas-sys = { path = "../ext/rocblas-sys" }
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use cuda_types::{cublas::*, cuda::*};
|
use cuda_types::{cublas::*, cuda::*};
|
||||||
use hip_runtime_sys::*;
|
use hip_runtime_sys::*;
|
||||||
|
use rocblas_sys::*;
|
||||||
use std::{
|
use std::{
|
||||||
ffi::CStr,
|
ffi::CStr,
|
||||||
mem::{self, ManuallyDrop, MaybeUninit},
|
mem::{self, ManuallyDrop, MaybeUninit},
|
||||||
@ -110,6 +111,8 @@ from_cuda_nop!(
|
|||||||
*mut i8,
|
*mut i8,
|
||||||
*mut i32,
|
*mut i32,
|
||||||
*mut usize,
|
*mut usize,
|
||||||
|
*const f32,
|
||||||
|
*mut f32,
|
||||||
*const ::core::ffi::c_void,
|
*const ::core::ffi::c_void,
|
||||||
*const ::core::ffi::c_char,
|
*const ::core::ffi::c_char,
|
||||||
*mut ::core::ffi::c_void,
|
*mut ::core::ffi::c_void,
|
||||||
@ -118,6 +121,7 @@ from_cuda_nop!(
|
|||||||
i32,
|
i32,
|
||||||
u32,
|
u32,
|
||||||
u64,
|
u64,
|
||||||
|
i64,
|
||||||
usize,
|
usize,
|
||||||
cuda_types::cuda::CUdevprop,
|
cuda_types::cuda::CUdevprop,
|
||||||
CUdevice_attribute,
|
CUdevice_attribute,
|
||||||
@ -171,6 +175,29 @@ impl<'a, E: CudaErrorType> FromCuda<'a, *const ::core::ffi::c_void, E> for &'a :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'a, E: CudaErrorType> FromCuda<'a, cublasOperation_t, E> for rocblas_operation {
|
||||||
|
fn from_cuda(t: &'a cublasOperation_t) -> Result<Self, E> {
|
||||||
|
Ok(match *t {
|
||||||
|
cublasOperation_t::CUBLAS_OP_N => rocblas_operation::rocblas_operation_none,
|
||||||
|
cublasOperation_t::CUBLAS_OP_T => rocblas_operation::rocblas_operation_transpose,
|
||||||
|
cublasOperation_t::CUBLAS_OP_C => {
|
||||||
|
rocblas_operation::rocblas_operation_conjugate_transpose
|
||||||
|
}
|
||||||
|
_ => return Err(E::NOT_SUPPORTED),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, E: CudaErrorType> FromCuda<'a, cublasMath_t, E> for rocblas_math_mode {
|
||||||
|
fn from_cuda(mode: &'a cublasMath_t) -> Result<Self, E> {
|
||||||
|
Ok(match *mode {
|
||||||
|
cublasMath_t::CUBLAS_DEFAULT_MATH => rocblas_math_mode_::rocblas_default_math,
|
||||||
|
cublasMath_t::CUBLAS_TF32_TENSOR_OP_MATH => rocblas_math_mode::rocblas_xf32_xdl_math_op,
|
||||||
|
_ => return Err(E::NOT_SUPPORTED),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Represents an object that can be sent across the API boundary.
|
/// Represents an object that can be sent across the API boundary.
|
||||||
///
|
///
|
||||||
/// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a
|
/// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a
|
||||||
|
Reference in New Issue
Block a user