Handle cublasSetMathMode

This commit is contained in:
Violet
2025-07-31 17:49:20 +00:00
parent ed853980ff
commit 8efd0d2c9f
2 changed files with 13 additions and 4 deletions

View File

@ -63,8 +63,8 @@ pub(crate) fn get_cudart_version() -> usize {
todo!()
}
pub(crate) fn set_math_mode(_handle: &Handle, _mode: cublasMath_t) -> cublasStatus_t {
// TODO: hipblas implements this but rocblas does not
pub(crate) fn set_math_mode(handle: &Handle, mode: rocblas_math_mode) -> cublasStatus_t {
unsafe { rocblas_set_math_mode(handle.handle, mode) }?;
Ok(())
}

View File

@ -134,8 +134,7 @@ from_cuda_nop!(
CUmodule,
CUcontext,
cublasHandle_t,
cublasStatus_t,
cublasMath_t
cublasStatus_t
);
from_cuda_transmute!(
CUuuid => hipUUID,
@ -189,6 +188,16 @@ impl<'a, E: CudaErrorType> FromCuda<'a, cublasOperation_t, E> for rocblas_operat
}
}
impl<'a, E: CudaErrorType> FromCuda<'a, cublasMath_t, E> for rocblas_math_mode {
fn from_cuda(mode: &'a cublasMath_t) -> Result<Self, E> {
Ok(match *mode {
cublasMath_t::CUBLAS_DEFAULT_MATH => rocblas_math_mode_::rocblas_default_math,
cublasMath_t::CUBLAS_TF32_TENSOR_OP_MATH => rocblas_math_mode::rocblas_xf32_xdl_math_op,
_ => return Err(E::NOT_SUPPORTED),
})
}
}
/// Represents an object that can be sent across the API boundary.
///
/// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a