Implement cublas functions needed for llm.c (#457)

This commit is contained in:
Violet
2025-07-31 11:08:53 -07:00
committed by GitHub
parent 99c36092be
commit 96ae27e9e1
6 changed files with 169 additions and 1 deletions

2
Cargo.lock generated
View File

@ -1816,6 +1816,7 @@ version = "0.0.0"
dependencies = [ dependencies = [
"cuda_macros", "cuda_macros",
"cuda_types", "cuda_types",
"rocblas-sys",
"zluda_common", "zluda_common",
] ]
@ -1833,6 +1834,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"cuda_types", "cuda_types",
"hip_runtime-sys", "hip_runtime-sys",
"rocblas-sys",
] ]
[[package]] [[package]]

View File

@ -11,6 +11,7 @@ name = "cublas"
cuda_macros = { path = "../cuda_macros" } cuda_macros = { path = "../cuda_macros" }
cuda_types = { path = "../cuda_types" } cuda_types = { path = "../cuda_types" }
zluda_common = { path = "../zluda_common" } zluda_common = { path = "../zluda_common" }
rocblas-sys = { path = "../ext/rocblas-sys" }
[package.metadata.zluda] [package.metadata.zluda]
linux_symlinks = [ linux_symlinks = [

View File

@ -1,4 +1,34 @@
use std::mem;
use cuda_types::cublas::*; use cuda_types::cublas::*;
use zluda_common::{from_cuda_object, ZludaObject};
use rocblas_sys::*;
/// ZLUDA-side object that stands behind a `cublasHandle_t` (see the `ZludaObject` impl for
/// `Handle`, which names `cublasHandle_t` as its `CudaHandle`).
pub struct Handle {
/// Underlying rocBLAS handle. `Handle::new` zero-initialises it and `create_v2` passes it to
/// `rocblas_create_handle` to be filled in.
handle: rocblas_handle,
}
impl Handle {
fn new() -> Self {
Self {
handle: unsafe { mem::zeroed() },
}
}
}
impl ZludaObject for Handle {
const COOKIE: usize = 0x57c3fdb0fd72b08e;
type Error = cublasError_t;
type CudaHandle = cublasHandle_t;
fn drop_checked(&mut self) -> cublasStatus_t {
Ok(())
}
}
from_cuda_object!(Handle);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
pub(crate) fn unimplemented() -> cublasStatus_t { pub(crate) fn unimplemented() -> cublasStatus_t {
@ -10,6 +40,13 @@ pub(crate) fn unimplemented() -> cublasStatus_t {
cublasStatus_t::ERROR_NOT_SUPPORTED cublasStatus_t::ERROR_NOT_SUPPORTED
} }
/// Creates a rocBLAS handle, wraps it in a ZLUDA `Handle` and stores the resulting
/// `cublasHandle_t` in `handle`.
///
/// # Errors
///
/// Returns the converted rocBLAS error if `rocblas_create_handle` fails; `handle` is left
/// untouched in that case.
pub(crate) fn create_v2(handle: &mut cublasHandle_t) -> cublasStatus_t {
    let mut blas = Handle::new();
    // SAFETY: passes a valid, exclusive pointer to the handle slot that rocBLAS fills in.
    let created = unsafe { rocblas_create_handle(&mut blas.handle) };
    created?;
    // Only publish the handle to the caller once rocBLAS has reported success.
    *handle = Handle::wrap(blas);
    Ok(())
}
pub(crate) fn get_status_name(_status: cublasStatus_t) -> *const ::core::ffi::c_char { pub(crate) fn get_status_name(_status: cublasStatus_t) -> *const ::core::ffi::c_char {
todo!() todo!()
} }
@ -25,3 +62,94 @@ pub(crate) fn xerbla(_sr_name: *const ::core::ffi::c_char, _info: ::core::ffi::c
pub(crate) fn get_cudart_version() -> usize { pub(crate) fn get_cudart_version() -> usize {
todo!() todo!()
} }
/// Applies `mode` to the rocBLAS handle stored in `handle` via `rocblas_set_math_mode`.
///
/// # Errors
///
/// Returns the converted rocBLAS error if the call fails.
pub(crate) fn set_math_mode(handle: &Handle, mode: rocblas_math_mode) -> cublasStatus_t {
    let raw = handle.handle;
    // SAFETY: NOTE(review) — assumes `raw` is a live handle produced by `create_v2`.
    let status = unsafe { rocblas_set_math_mode(raw, mode) };
    status?;
    Ok(())
}
/// Forwards a strided-batched single-precision GEMM to `rocblas_sgemm_strided_batched`.
///
/// Every argument is passed through unchanged and in the same order; the only
/// translation performed here is unwrapping the rocBLAS handle from `handle`.
///
/// # Errors
///
/// Returns the converted rocBLAS error if the call fails.
pub(crate) fn sgemm_strided_batched(
    handle: &Handle,
    transa: rocblas_operation,
    transb: rocblas_operation,
    m: ::core::ffi::c_int,
    n: ::core::ffi::c_int,
    k: ::core::ffi::c_int,
    alpha: *const f32,
    a: *const f32,
    lda: ::core::ffi::c_int,
    stride_a: ::core::ffi::c_longlong,
    b: *const f32,
    ldb: ::core::ffi::c_int,
    stride_b: ::core::ffi::c_longlong,
    beta: *const f32,
    c: *mut f32,
    ldc: ::core::ffi::c_int,
    stride_c: ::core::ffi::c_longlong,
    batch_count: ::core::ffi::c_int,
) -> cublasStatus_t {
    let raw = handle.handle;
    // SAFETY: NOTE(review) — pointer validity and dimensions are the caller's
    // responsibility; this wrapper performs no checks of its own.
    let status = unsafe {
        rocblas_sgemm_strided_batched(
            raw,
            transa,
            transb,
            m,
            n,
            k,
            alpha,
            a,
            lda,
            stride_a,
            b,
            ldb,
            stride_b,
            beta,
            c,
            ldc,
            stride_c,
            batch_count,
        )
    };
    status?;
    Ok(())
}
/// Forwards a single-precision GEMM to `rocblas_sgemm`.
///
/// Every argument is passed through unchanged and in the same order; the only
/// translation performed here is unwrapping the rocBLAS handle from `handle`.
///
/// # Errors
///
/// Returns the converted rocBLAS error if the call fails.
pub(crate) fn sgemm_v2(
    handle: &Handle,
    transa: rocblas_operation,
    transb: rocblas_operation,
    m: ::core::ffi::c_int,
    n: ::core::ffi::c_int,
    k: ::core::ffi::c_int,
    alpha: *const f32,
    a: *const f32,
    lda: ::core::ffi::c_int,
    b: *const f32,
    ldb: ::core::ffi::c_int,
    beta: *const f32,
    c: *mut f32,
    ldc: ::core::ffi::c_int,
) -> cublasStatus_t {
    let raw = handle.handle;
    // SAFETY: NOTE(review) — pointer validity and dimensions are the caller's
    // responsibility; this wrapper performs no checks of its own.
    let status = unsafe {
        rocblas_sgemm(
            raw, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
        )
    };
    status?;
    Ok(())
}
/// Passes `handle` to `zluda_common::drop_checked::<Handle>` and returns its status.
///
/// NOTE(review): presumably this ends up invoking `Handle`'s `ZludaObject::drop_checked`
/// and freeing the wrapper — confirm against `zluda_common`.
pub(crate) fn destroy_v2(handle: cublasHandle_t) -> cublasStatus_t {
zluda_common::drop_checked::<Handle>(handle)
}

View File

@ -1,5 +1,7 @@
mod r#impl; mod r#impl;
use cuda_types::cublas::cublasError_t;
macro_rules! unimplemented { macro_rules! unimplemented {
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => { ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
$( $(
@ -42,7 +44,14 @@ macro_rules! implemented_and_always_succeeds {
cuda_macros::cublas_function_declarations!( cuda_macros::cublas_function_declarations!(
unimplemented, unimplemented,
implemented <= [], implemented
<= [
cublasCreate_v2,
cublasSetMathMode,
cublasSgemmStridedBatched,
cublasSgemm_v2,
cublasDestroy_v2
],
implemented_and_always_succeeds implemented_and_always_succeeds
<= [ <= [
cublasGetStatusName, cublasGetStatusName,

View File

@ -7,3 +7,4 @@ edition = "2021"
[dependencies] [dependencies]
cuda_types = { path = "../cuda_types" } cuda_types = { path = "../cuda_types" }
hip_runtime-sys = { path = "../ext/hip_runtime-sys" } hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
rocblas-sys = { path = "../ext/rocblas-sys" }

View File

@ -1,5 +1,6 @@
use cuda_types::{cublas::*, cuda::*}; use cuda_types::{cublas::*, cuda::*};
use hip_runtime_sys::*; use hip_runtime_sys::*;
use rocblas_sys::*;
use std::{ use std::{
ffi::CStr, ffi::CStr,
mem::{self, ManuallyDrop, MaybeUninit}, mem::{self, ManuallyDrop, MaybeUninit},
@ -110,6 +111,8 @@ from_cuda_nop!(
*mut i8, *mut i8,
*mut i32, *mut i32,
*mut usize, *mut usize,
*const f32,
*mut f32,
*const ::core::ffi::c_void, *const ::core::ffi::c_void,
*const ::core::ffi::c_char, *const ::core::ffi::c_char,
*mut ::core::ffi::c_void, *mut ::core::ffi::c_void,
@ -118,6 +121,7 @@ from_cuda_nop!(
i32, i32,
u32, u32,
u64, u64,
i64,
usize, usize,
cuda_types::cuda::CUdevprop, cuda_types::cuda::CUdevprop,
CUdevice_attribute, CUdevice_attribute,
@ -171,6 +175,29 @@ impl<'a, E: CudaErrorType> FromCuda<'a, *const ::core::ffi::c_void, E> for &'a :
} }
} }
impl<'a, E: CudaErrorType> FromCuda<'a, cublasOperation_t, E> for rocblas_operation {
    /// Translates a cuBLAS transpose selector into its rocBLAS equivalent.
    ///
    /// `CUBLAS_OP_N`, `CUBLAS_OP_T` and `CUBLAS_OP_C` are mapped; any other value yields
    /// `E::NOT_SUPPORTED`.
    fn from_cuda(t: &'a cublasOperation_t) -> Result<Self, E> {
        let op = match *t {
            cublasOperation_t::CUBLAS_OP_N => rocblas_operation::rocblas_operation_none,
            cublasOperation_t::CUBLAS_OP_T => rocblas_operation::rocblas_operation_transpose,
            cublasOperation_t::CUBLAS_OP_C => {
                rocblas_operation::rocblas_operation_conjugate_transpose
            }
            // Anything outside the three selectors above has no mapping here.
            _ => return Err(E::NOT_SUPPORTED),
        };
        Ok(op)
    }
}
impl<'a, E: CudaErrorType> FromCuda<'a, cublasMath_t, E> for rocblas_math_mode {
    /// Translates a cuBLAS math mode into its rocBLAS equivalent.
    ///
    /// `CUBLAS_DEFAULT_MATH` maps to `rocblas_default_math` and
    /// `CUBLAS_TF32_TENSOR_OP_MATH` maps to `rocblas_xf32_xdl_math_op`; any other value
    /// yields `E::NOT_SUPPORTED`.
    fn from_cuda(mode: &'a cublasMath_t) -> Result<Self, E> {
        Ok(match *mode {
            // Both arms name the target type through the same public alias
            // (`rocblas_math_mode`) rather than mixing it with `rocblas_math_mode_`.
            cublasMath_t::CUBLAS_DEFAULT_MATH => rocblas_math_mode::rocblas_default_math,
            cublasMath_t::CUBLAS_TF32_TENSOR_OP_MATH => {
                rocblas_math_mode::rocblas_xf32_xdl_math_op
            }
            _ => return Err(E::NOT_SUPPORTED),
        })
    }
}
/// Represents an object that can be sent across the API boundary. /// Represents an object that can be sent across the API boundary.
/// ///
/// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a /// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a