Handle cublasSetMathMode

This commit is contained in:
Violet
2025-07-31 17:49:20 +00:00
parent ed853980ff
commit 8efd0d2c9f
2 changed files with 13 additions and 4 deletions

View File

@ -63,8 +63,8 @@ pub(crate) fn get_cudart_version() -> usize {
todo!() todo!()
} }
pub(crate) fn set_math_mode(_handle: &Handle, _mode: cublasMath_t) -> cublasStatus_t { pub(crate) fn set_math_mode(handle: &Handle, mode: rocblas_math_mode) -> cublasStatus_t {
// TODO: hipblas implements this but rocblas does not unsafe { rocblas_set_math_mode(handle.handle, mode) }?;
Ok(()) Ok(())
} }

View File

@ -134,8 +134,7 @@ from_cuda_nop!(
CUmodule, CUmodule,
CUcontext, CUcontext,
cublasHandle_t, cublasHandle_t,
cublasStatus_t, cublasStatus_t
cublasMath_t
); );
from_cuda_transmute!( from_cuda_transmute!(
CUuuid => hipUUID, CUuuid => hipUUID,
@ -189,6 +188,16 @@ impl<'a, E: CudaErrorType> FromCuda<'a, cublasOperation_t, E> for rocblas_operat
} }
} }
impl<'a, E: CudaErrorType> FromCuda<'a, cublasMath_t, E> for rocblas_math_mode {
fn from_cuda(mode: &'a cublasMath_t) -> Result<Self, E> {
Ok(match *mode {
cublasMath_t::CUBLAS_DEFAULT_MATH => rocblas_math_mode_::rocblas_default_math,
cublasMath_t::CUBLAS_TF32_TENSOR_OP_MATH => rocblas_math_mode::rocblas_xf32_xdl_math_op,
_ => return Err(E::NOT_SUPPORTED),
})
}
}
/// Represents an object that can be sent across the API boundary. /// Represents an object that can be sent across the API boundary.
/// ///
/// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a /// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a