From 96ae27e9e10317855fff526dcee5b44a1d930d88 Mon Sep 17 00:00:00 2001 From: Violet Date: Thu, 31 Jul 2025 11:08:53 -0700 Subject: [PATCH] Implement cublas functions needed for llm.c (#457) --- Cargo.lock | 2 + zluda_blas/Cargo.toml | 1 + zluda_blas/src/impl.rs | 128 ++++++++++++++++++++++++++++++++++++++++ zluda_blas/src/lib.rs | 11 +++- zluda_common/Cargo.toml | 1 + zluda_common/src/lib.rs | 27 +++++++++ 6 files changed, 169 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 574e9af..8c663b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1816,6 +1816,7 @@ version = "0.0.0" dependencies = [ "cuda_macros", "cuda_types", + "rocblas-sys", "zluda_common", ] @@ -1833,6 +1834,7 @@ version = "0.1.0" dependencies = [ "cuda_types", "hip_runtime-sys", + "rocblas-sys", ] [[package]] diff --git a/zluda_blas/Cargo.toml b/zluda_blas/Cargo.toml index f0f1046..6f7960e 100644 --- a/zluda_blas/Cargo.toml +++ b/zluda_blas/Cargo.toml @@ -11,6 +11,7 @@ name = "cublas" cuda_macros = { path = "../cuda_macros" } cuda_types = { path = "../cuda_types" } zluda_common = { path = "../zluda_common" } +rocblas-sys = { path = "../ext/rocblas-sys" } [package.metadata.zluda] linux_symlinks = [ diff --git a/zluda_blas/src/impl.rs b/zluda_blas/src/impl.rs index 55b1edc..3a6faa3 100644 --- a/zluda_blas/src/impl.rs +++ b/zluda_blas/src/impl.rs @@ -1,4 +1,34 @@ +use std::mem; + use cuda_types::cublas::*; +use zluda_common::{from_cuda_object, ZludaObject}; + +use rocblas_sys::*; + +pub struct Handle { + handle: rocblas_handle, +} + +impl Handle { + fn new() -> Self { + Self { + handle: unsafe { mem::zeroed() }, + } + } +} + +impl ZludaObject for Handle { + const COOKIE: usize = 0x57c3fdb0fd72b08e; + + type Error = cublasError_t; + type CudaHandle = cublasHandle_t; + + fn drop_checked(&mut self) -> cublasStatus_t { + Ok(()) + } +} + +from_cuda_object!(Handle); #[cfg(debug_assertions)] pub(crate) fn unimplemented() -> cublasStatus_t { @@ -10,6 +40,13 @@ pub(crate) fn unimplemented() -> 
cublasStatus_t { cublasStatus_t::ERROR_NOT_SUPPORTED } +pub(crate) fn create_v2(handle: &mut cublasHandle_t) -> cublasStatus_t { + let mut zluda_blas_handle = Handle::new(); + unsafe { rocblas_create_handle(&mut zluda_blas_handle.handle) }?; + *handle = Handle::wrap(zluda_blas_handle); + Ok(()) +} + pub(crate) fn get_status_name(_status: cublasStatus_t) -> *const ::core::ffi::c_char { todo!() } @@ -25,3 +62,94 @@ pub(crate) fn xerbla(_sr_name: *const ::core::ffi::c_char, _info: ::core::ffi::c pub(crate) fn get_cudart_version() -> usize { todo!() } + +pub(crate) fn set_math_mode(handle: &Handle, mode: rocblas_math_mode) -> cublasStatus_t { + unsafe { rocblas_set_math_mode(handle.handle, mode) }?; + Ok(()) +} + +pub(crate) fn sgemm_strided_batched( + handle: &Handle, + transa: rocblas_operation, + transb: rocblas_operation, + m: ::core::ffi::c_int, + n: ::core::ffi::c_int, + k: ::core::ffi::c_int, + alpha: *const f32, + a: *const f32, + lda: ::core::ffi::c_int, + stride_a: ::core::ffi::c_longlong, + b: *const f32, + ldb: ::core::ffi::c_int, + stride_b: ::core::ffi::c_longlong, + beta: *const f32, + c: *mut f32, + ldc: ::core::ffi::c_int, + stride_c: ::core::ffi::c_longlong, + batch_count: ::core::ffi::c_int, +) -> cublasStatus_t { + unsafe { + rocblas_sgemm_strided_batched( + handle.handle, + transa, + transb, + m, + n, + k, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + c, + ldc, + stride_c, + batch_count, + ) + }?; + Ok(()) +} + +pub(crate) fn sgemm_v2( + handle: &Handle, + transa: rocblas_operation, + transb: rocblas_operation, + m: ::core::ffi::c_int, + n: ::core::ffi::c_int, + k: ::core::ffi::c_int, + alpha: *const f32, + a: *const f32, + lda: ::core::ffi::c_int, + b: *const f32, + ldb: ::core::ffi::c_int, + beta: *const f32, + c: *mut f32, + ldc: ::core::ffi::c_int, +) -> cublasStatus_t { + unsafe { + rocblas_sgemm( + handle.handle, + transa, + transb, + m, + n, + k, + alpha, + a, + lda, + b, + ldb, + beta, + c, + ldc, + ) + }?; + Ok(()) +} 
+ +pub(crate) fn destroy_v2(handle: cublasHandle_t) -> cublasStatus_t { + zluda_common::drop_checked::<Handle>(handle) +} diff --git a/zluda_blas/src/lib.rs b/zluda_blas/src/lib.rs index c650759..712183c 100644 --- a/zluda_blas/src/lib.rs +++ b/zluda_blas/src/lib.rs @@ -1,5 +1,7 @@ mod r#impl; +use cuda_types::cublas::cublasError_t; + macro_rules! unimplemented { ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => { $( @@ -42,7 +44,14 @@ macro_rules! implemented_and_always_succeeds { cuda_macros::cublas_function_declarations!( unimplemented, - implemented <= [], + implemented + <= [ + cublasCreate_v2, + cublasSetMathMode, + cublasSgemmStridedBatched, + cublasSgemm_v2, + cublasDestroy_v2 + ], implemented_and_always_succeeds <= [ cublasGetStatusName, diff --git a/zluda_common/Cargo.toml b/zluda_common/Cargo.toml index 8258e83..4c528e5 100644 --- a/zluda_common/Cargo.toml +++ b/zluda_common/Cargo.toml @@ -7,3 +7,4 @@ edition = "2021" [dependencies] cuda_types = { path = "../cuda_types" } hip_runtime-sys = { path = "../ext/hip_runtime-sys" } +rocblas-sys = { path = "../ext/rocblas-sys" } diff --git a/zluda_common/src/lib.rs b/zluda_common/src/lib.rs index 4615a48..dcff5ad 100644 --- a/zluda_common/src/lib.rs +++ b/zluda_common/src/lib.rs @@ -1,5 +1,6 @@ use cuda_types::{cublas::*, cuda::*}; use hip_runtime_sys::*; +use rocblas_sys::*; use std::{ ffi::CStr, mem::{self, ManuallyDrop, MaybeUninit}, @@ -110,6 +111,8 @@ from_cuda_nop!( *mut i8, *mut i32, *mut usize, + *const f32, + *mut f32, *const ::core::ffi::c_void, *const ::core::ffi::c_char, *mut ::core::ffi::c_void, @@ -118,6 +121,7 @@ from_cuda_nop!( i32, u32, u64, + i64, usize, cuda_types::cuda::CUdevprop, CUdevice_attribute, @@ -171,6 +175,29 @@ impl<'a, E: CudaErrorType> FromCuda<'a, *const ::core::ffi::c_void, E> for &'a : } } +impl<'a, E: CudaErrorType> FromCuda<'a, cublasOperation_t, E> for rocblas_operation { + fn from_cuda(t: &'a cublasOperation_t) -> Result<Self, E> { + Ok(match *t {
+ cublasOperation_t::CUBLAS_OP_N => rocblas_operation::rocblas_operation_none, + cublasOperation_t::CUBLAS_OP_T => rocblas_operation::rocblas_operation_transpose, + cublasOperation_t::CUBLAS_OP_C => { + rocblas_operation::rocblas_operation_conjugate_transpose + } + _ => return Err(E::NOT_SUPPORTED), + }) + } +} + +impl<'a, E: CudaErrorType> FromCuda<'a, cublasMath_t, E> for rocblas_math_mode { + fn from_cuda(mode: &'a cublasMath_t) -> Result<Self, E> { + Ok(match *mode { + cublasMath_t::CUBLAS_DEFAULT_MATH => rocblas_math_mode_::rocblas_default_math, + cublasMath_t::CUBLAS_TF32_TENSOR_OP_MATH => rocblas_math_mode::rocblas_xf32_xdl_math_op, + _ => return Err(E::NOT_SUPPORTED), + }) + } +} + /// Represents an object that can be sent across the API boundary. /// /// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a