Refactor FromCuda error type to be generic

This commit is contained in:
Violet
2025-07-28 06:29:04 +00:00
parent c07d7678cd
commit f6301e22aa
6 changed files with 65 additions and 52 deletions

View File

@ -88,6 +88,7 @@ impl Context {
impl ZludaObject for Context {
const COOKIE: usize = 0x5f867c6d9cb73315;
type Error = CUerror;
type CudaHandle = CUcontext;
fn drop_checked(&mut self) -> CUresult {
@ -128,7 +129,7 @@ pub(crate) fn set_current(raw_ctx: CUcontext) -> CUresult {
None
})
} else {
let ctx: &Context = FromCuda::from_cuda(&raw_ctx)?;
let ctx: &Context = FromCuda::<_, CUerror>::from_cuda(&raw_ctx)?;
let device = ctx.device;
STACK.with(move |stack| {
let mut stack = stack.borrow_mut();
@ -157,7 +158,7 @@ pub(crate) fn get_current(pctx: &mut CUcontext) -> CUresult {
pub(crate) fn get_device(dev: &mut hipDevice_t) -> CUresult {
let cu_ctx = get_current_context()?;
let ctx: &Context = FromCuda::from_cuda(&cu_ctx)?;
let ctx: &Context = FromCuda::<_, CUerror>::from_cuda(&cu_ctx)?;
*dev = ctx.device;
Ok(())
}

View File

@ -175,7 +175,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
context::get_current(&mut current_ctx)?;
current_ctx
};
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
let ctx_obj: &context::Context = FromCuda::<_, CUerror>::from_cuda(&_ctx)?;
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
state.storage.insert(
key as usize,
@ -194,7 +194,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
cu_ctx: CUcontext,
key: *mut c_void,
) -> CUresult {
let ctx_obj: &context::Context = FromCuda::from_cuda(&cu_ctx)?;
let ctx_obj: &context::Context = FromCuda::<_, CUerror>::from_cuda(&cu_ctx)?;
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
state.storage.remove(&(key as usize));
Ok(())
@ -213,7 +213,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
} else {
_ctx = cu_ctx
};
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
let ctx_obj: &context::Context = FromCuda::<_, CUerror>::from_cuda(&_ctx)?;
ctx_obj.with_state(|state: &context::ContextState| {
match state.storage.get(&(key as usize)) {
Some(data) => *value = data.value as *mut c_void,

View File

@ -9,6 +9,7 @@ pub(crate) struct Library {
impl ZludaObject for Library {
const COOKIE: usize = 0xb328a916cc234d7c;
type Error = CUerror;
type CudaHandle = CUlibrary;
fn drop_checked(&mut self) -> CUresult {
@ -38,10 +39,7 @@ pub(crate) unsafe fn unload(library: CUlibrary) -> CUresult {
super::drop_checked::<Library>(library)
}
pub(crate) unsafe fn get_module(
out: &mut CUmodule,
library: &Library,
) -> CUresult {
*out = module::Module{base: library.base}.wrap();
pub(crate) unsafe fn get_module(out: &mut CUmodule, library: &Library) -> CUresult {
*out = module::Module { base: library.base }.wrap();
Ok(())
}

View File

@ -26,39 +26,49 @@ pub(crate) fn unimplemented() -> CUresult {
CUresult::ERROR_NOT_SUPPORTED
}
pub(crate) trait FromCuda<'a, T>: Sized {
fn from_cuda(t: &'a T) -> Result<Self, CUerror>;
pub(crate) trait CudaErrorType {
const INVALID_VALUE: Self;
const NOT_SUPPORTED: Self;
}
impl CudaErrorType for CUerror {
const INVALID_VALUE: Self = Self::INVALID_VALUE;
const NOT_SUPPORTED: Self = Self::NOT_SUPPORTED;
}
pub(crate) trait FromCuda<'a, T, E: CudaErrorType>: Sized {
fn from_cuda(t: &'a T) -> Result<Self, E>;
}
macro_rules! from_cuda_nop {
($($type_:ty),*) => {
$(
impl<'a> FromCuda<'a, $type_> for $type_ {
fn from_cuda(x: &'a $type_) -> Result<Self, CUerror> {
impl<'a, E: CudaErrorType> FromCuda<'a, $type_, E> for $type_ {
fn from_cuda(x: &'a $type_) -> Result<Self, E> {
Ok(*x)
}
}
impl<'a> FromCuda<'a, *mut $type_> for &'a mut $type_ {
fn from_cuda(x: &'a *mut $type_) -> Result<Self, CUerror> {
impl<'a, E: CudaErrorType> FromCuda<'a, *mut $type_, E> for &'a mut $type_ {
fn from_cuda(x: &'a *mut $type_) -> Result<Self, E> {
match unsafe { x.as_mut() } {
Some(x) => Ok(x),
None => Err(CUerror::INVALID_VALUE),
None => Err(E::INVALID_VALUE),
}
}
}
impl<'a> FromCuda<'a, *const $type_> for &'a $type_ {
fn from_cuda(x: &'a *const $type_) -> Result<Self, CUerror> {
impl<'a, E: CudaErrorType> FromCuda<'a, *const $type_, E> for &'a $type_ {
fn from_cuda(x: &'a *const $type_) -> Result<Self, E> {
match unsafe { x.as_ref() } {
Some(x) => Ok(x),
None => Err(CUerror::INVALID_VALUE),
None => Err(E::INVALID_VALUE),
}
}
}
impl<'a> FromCuda<'a, *mut $type_> for Option<&'a mut $type_> {
fn from_cuda(x: &'a *mut $type_) -> Result<Self, CUerror> {
impl<'a, E: CudaErrorType> FromCuda<'a, *mut $type_, E> for Option<&'a mut $type_> {
fn from_cuda(x: &'a *mut $type_) -> Result<Self, E> {
Ok(unsafe { x.as_mut() })
}
}
@ -69,23 +79,23 @@ macro_rules! from_cuda_nop {
macro_rules! from_cuda_transmute {
($($from:ty => $to:ty),*) => {
$(
impl<'a> FromCuda<'a, $from> for $to {
fn from_cuda(x: &'a $from) -> Result<Self, CUerror> {
impl<'a, E: CudaErrorType> FromCuda<'a, $from, E> for $to {
fn from_cuda(x: &'a $from) -> Result<Self, E> {
Ok(unsafe { std::mem::transmute(*x) })
}
}
impl<'a> FromCuda<'a, *mut $from> for &'a mut $to {
fn from_cuda(x: &'a *mut $from) -> Result<Self, CUerror> {
impl<'a, E: CudaErrorType> FromCuda<'a, *mut $from, E> for &'a mut $to {
fn from_cuda(x: &'a *mut $from) -> Result<Self, E> {
match unsafe { x.cast::<$to>().as_mut() } {
Some(x) => Ok(x),
None => Err(CUerror::INVALID_VALUE),
None => Err(E::INVALID_VALUE),
}
}
}
impl<'a> FromCuda<'a, *mut $from> for * mut $to {
fn from_cuda(x: &'a *mut $from) -> Result<Self, CUerror> {
impl<'a, E: CudaErrorType> FromCuda<'a, *mut $from, E> for * mut $to {
fn from_cuda(x: &'a *mut $from) -> Result<Self, E> {
Ok(x.cast::<$to>())
}
}
@ -96,23 +106,23 @@ macro_rules! from_cuda_transmute {
macro_rules! from_cuda_object {
($($type_:ty),*) => {
$(
impl<'a> FromCuda<'a, <$type_ as ZludaObject>::CudaHandle> for <$type_ as ZludaObject>::CudaHandle {
fn from_cuda(handle: &'a <$type_ as ZludaObject>::CudaHandle) -> Result<<$type_ as ZludaObject>::CudaHandle, CUerror> {
impl<'a> FromCuda<'a, <$type_ as ZludaObject>::CudaHandle, <$type_ as ZludaObject>::Error> for <$type_ as ZludaObject>::CudaHandle {
fn from_cuda(handle: &'a <$type_ as ZludaObject>::CudaHandle) -> Result<<$type_ as ZludaObject>::CudaHandle, <$type_ as ZludaObject>::Error> {
Ok(*handle)
}
}
impl<'a> FromCuda<'a, *mut <$type_ as ZludaObject>::CudaHandle> for &'a mut <$type_ as ZludaObject>::CudaHandle {
fn from_cuda(handle: &'a *mut <$type_ as ZludaObject>::CudaHandle) -> Result<&'a mut <$type_ as ZludaObject>::CudaHandle, CUerror> {
impl<'a> FromCuda<'a, *mut <$type_ as ZludaObject>::CudaHandle, <$type_ as ZludaObject>::Error> for &'a mut <$type_ as ZludaObject>::CudaHandle {
fn from_cuda(handle: &'a *mut <$type_ as ZludaObject>::CudaHandle) -> Result<&'a mut <$type_ as ZludaObject>::CudaHandle, <$type_ as ZludaObject>::Error> {
match unsafe { handle.as_mut() } {
Some(x) => Ok(x),
None => Err(CUerror::INVALID_VALUE),
None => Err(<$type_ as ZludaObject>::Error::INVALID_VALUE),
}
}
}
impl<'a> FromCuda<'a, <$type_ as ZludaObject>::CudaHandle> for &'a $type_ {
fn from_cuda(handle: &'a <$type_ as ZludaObject>::CudaHandle) -> Result<&'a $type_, CUerror> {
impl<'a> FromCuda<'a, <$type_ as ZludaObject>::CudaHandle, <$type_ as ZludaObject>::Error> for &'a $type_ {
fn from_cuda(handle: &'a <$type_ as ZludaObject>::CudaHandle) -> Result<&'a $type_, <$type_ as ZludaObject>::Error> {
Ok(as_ref(handle).as_result()?)
}
}
@ -151,43 +161,46 @@ from_cuda_transmute!(
);
from_cuda_object!(module::Module, context::Context, library::Library);
impl<'a> FromCuda<'a, CUlimit> for hipLimit_t {
fn from_cuda(limit: &'a CUlimit) -> Result<Self, CUerror> {
impl<'a, E: CudaErrorType> FromCuda<'a, CUlimit, E> for hipLimit_t {
fn from_cuda(limit: &'a CUlimit) -> Result<Self, E> {
Ok(match *limit {
CUlimit::CU_LIMIT_STACK_SIZE => hipLimit_t::hipLimitStackSize,
CUlimit::CU_LIMIT_PRINTF_FIFO_SIZE => hipLimit_t::hipLimitPrintfFifoSize,
CUlimit::CU_LIMIT_MALLOC_HEAP_SIZE => hipLimit_t::hipLimitMallocHeapSize,
_ => return Err(CUerror::NOT_SUPPORTED),
_ => return Err(E::NOT_SUPPORTED),
})
}
}
impl<'a> FromCuda<'a, *const ::core::ffi::c_char> for &CStr {
fn from_cuda(s: &'a *const ::core::ffi::c_char) -> Result<Self, CUerror> {
impl<'a, E: CudaErrorType> FromCuda<'a, *const ::core::ffi::c_char, E> for &CStr {
fn from_cuda(s: &'a *const ::core::ffi::c_char) -> Result<Self, E> {
if *s != ptr::null() {
Ok(unsafe { CStr::from_ptr(*s) })
} else {
Err(CUerror::INVALID_VALUE)
Err(E::INVALID_VALUE)
}
}
}
impl<'a> FromCuda<'a, *const ::core::ffi::c_void> for &'a ::core::ffi::c_void {
fn from_cuda(x: &'a *const ::core::ffi::c_void) -> Result<Self, CUerror> {
impl<'a, E: CudaErrorType> FromCuda<'a, *const ::core::ffi::c_void, E>
for &'a ::core::ffi::c_void
{
fn from_cuda(x: &'a *const ::core::ffi::c_void) -> Result<Self, E> {
match unsafe { x.as_ref() } {
Some(x) => Ok(x),
None => Err(CUerror::INVALID_VALUE),
None => Err(E::INVALID_VALUE),
}
}
}
pub(crate) trait ZludaObject: Sized + Send + Sync {
const COOKIE: usize;
const LIVENESS_FAIL: CUerror = cuda_types::cuda::CUerror::INVALID_VALUE;
const LIVENESS_FAIL: Self::Error = Self::Error::INVALID_VALUE;
type Error: CudaErrorType;
type CudaHandle: Sized;
fn drop_checked(&mut self) -> CUresult;
fn drop_checked(&mut self) -> Result<(), Self::Error>;
fn wrap(self) -> Self::CudaHandle {
unsafe { mem::transmute_copy(&LiveCheck::wrap(self)) }
@ -216,7 +229,7 @@ impl<T: ZludaObject> LiveCheck<T> {
Box::into_raw(Box::new(Self::new(data)))
}
fn as_result(&self) -> Result<&T, CUerror> {
fn as_result(&self) -> Result<&T, T::Error> {
if self.cookie == T::COOKIE {
Ok(unsafe { self.data.assume_init_ref() })
} else {
@ -229,7 +242,7 @@ impl<T: ZludaObject> LiveCheck<T> {
// Ok(maybe_error) -> meaning that the object is valid, we dropped everything, but there *might*
// an error in the underlying runtime that we want to propagate
#[must_use]
fn drop_checked(&mut self) -> Result<Result<(), CUerror>, CUerror> {
fn drop_checked(&mut self) -> Result<Result<(), T::Error>, T::Error> {
if self.cookie == T::COOKIE {
self.cookie = 0;
let result = unsafe { self.data.assume_init_mut().drop_checked() };
@ -247,7 +260,7 @@ pub fn as_ref<'a, T: ZludaObject>(
unsafe { mem::transmute(handle) }
}
pub fn drop_checked<T: ZludaObject>(handle: T::CudaHandle) -> Result<(), CUerror> {
pub fn drop_checked<T: ZludaObject>(handle: T::CudaHandle) -> Result<(), T::Error> {
let mut wrapped_object: ManuallyDrop<Box<LiveCheck<T>>> =
unsafe { mem::transmute_copy(&handle) };
let underlying_error = LiveCheck::drop_checked(&mut wrapped_object)?;

View File

@ -14,6 +14,7 @@ pub(crate) struct Module {
impl ZludaObject for Module {
const COOKIE: usize = 0xe9138bd040487d4a;
type Error = CUerror;
type CudaHandle = CUmodule;
fn drop_checked(&mut self) -> CUresult {

View File

@ -40,7 +40,7 @@ macro_rules! implemented {
if !initialized() {
return Err(CUerror::DEINITIALIZED);
}
cuda_macros::cuda_normalize_fn!( crate::r#impl::$fn_name ) ($(crate::r#impl::FromCuda::from_cuda(&$arg_id)?),*)?;
cuda_macros::cuda_normalize_fn!( crate::r#impl::$fn_name ) ($(crate::r#impl::FromCuda::<_, CUerror>::from_cuda(&$arg_id)?),*)?;
Ok(())
}
)*
@ -57,7 +57,7 @@ macro_rules! implemented_in_function {
if !initialized() {
return Err(CUerror::DEINITIALIZED);
}
cuda_macros::cuda_normalize_fn!( crate::r#impl::function::$fn_name ) ($(crate::r#impl::FromCuda::from_cuda(&$arg_id)?),*)?;
cuda_macros::cuda_normalize_fn!( crate::r#impl::function::$fn_name ) ($(crate::r#impl::FromCuda::<_, CUerror>::from_cuda(&$arg_id)?),*)?;
Ok(())
}
)*