From 8efd0d2c9f20b2e43e9782b6835ed9041fb7740d Mon Sep 17 00:00:00 2001 From: Violet Date: Thu, 31 Jul 2025 17:49:20 +0000 Subject: [PATCH] Handle cublasSetMathMode --- zluda_blas/src/impl.rs | 4 ++-- zluda_common/src/lib.rs | 13 +++++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/zluda_blas/src/impl.rs b/zluda_blas/src/impl.rs index 64b32d5..3a6faa3 100644 --- a/zluda_blas/src/impl.rs +++ b/zluda_blas/src/impl.rs @@ -63,8 +63,8 @@ pub(crate) fn get_cudart_version() -> usize { todo!() } -pub(crate) fn set_math_mode(_handle: &Handle, _mode: cublasMath_t) -> cublasStatus_t { - // TODO: hipblas implements this but rocblas does not +pub(crate) fn set_math_mode(handle: &Handle, mode: rocblas_math_mode) -> cublasStatus_t { + unsafe { rocblas_set_math_mode(handle.handle, mode) }?; Ok(()) } diff --git a/zluda_common/src/lib.rs b/zluda_common/src/lib.rs index c4edc0e..dcff5ad 100644 --- a/zluda_common/src/lib.rs +++ b/zluda_common/src/lib.rs @@ -134,8 +134,7 @@ from_cuda_nop!( CUmodule, CUcontext, cublasHandle_t, - cublasStatus_t, - cublasMath_t + cublasStatus_t ); from_cuda_transmute!( CUuuid => hipUUID, @@ -189,6 +188,16 @@ impl<'a, E: CudaErrorType> FromCuda<'a, cublasOperation_t, E> for rocblas_operat } } +impl<'a, E: CudaErrorType> FromCuda<'a, cublasMath_t, E> for rocblas_math_mode { + fn from_cuda(mode: &'a cublasMath_t) -> Result { + Ok(match *mode { + cublasMath_t::CUBLAS_DEFAULT_MATH => rocblas_math_mode_::rocblas_default_math, + cublasMath_t::CUBLAS_TF32_TENSOR_OP_MATH => rocblas_math_mode::rocblas_xf32_xdl_math_op, + _ => return Err(E::NOT_SUPPORTED), + }) + } +} + /// Represents an object that can be sent across the API boundary. /// /// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a