diff --git a/Cargo.lock b/Cargo.lock index e65e2ea..0787625 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1793,6 +1793,7 @@ dependencies = [ "rustc-hash 1.1.0", "tempfile", "winapi", + "zluda_common", ] [[package]] @@ -1825,6 +1826,14 @@ dependencies = [ "cuda_types", ] +[[package]] +name = "zluda_common" +version = "0.1.0" +dependencies = [ + "cuda_types", + "hip_runtime-sys", +] + [[package]] name = "zluda_dnn" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index 5022510..07cdda2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ members = [ "zluda_bindgen", "zluda_blas", "zluda_blaslt", + "zluda_common", "zluda_dnn", "zluda_dump", "zluda_dump_blas", diff --git a/zluda/Cargo.toml b/zluda/Cargo.toml index 61de140..664401c 100644 --- a/zluda/Cargo.toml +++ b/zluda/Cargo.toml @@ -22,6 +22,8 @@ lz4-sys = "1.9" tempfile = "3" paste = "1.0" rustc-hash = "1.1" +dtor = "0.0.6" +zluda_common = { path = "../zluda_common" } [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["heapapi", "std"] } @@ -34,4 +36,4 @@ dtor = "0.0.6" linux_symlinks = [ "libcuda.so", "libcuda.so.1", -] \ No newline at end of file +] diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index 9d7bcc0..577a454 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -1,8 +1,9 @@ -use super::{module, FromCuda, ZludaObject}; +use super::module; use cuda_types::cuda::*; use hip_runtime_sys::*; use rustc_hash::{FxHashMap, FxHashSet}; use std::{cell::RefCell, ffi::c_void, ptr, sync::Mutex}; +use zluda_common::{FromCuda, ZludaObject}; thread_local! { pub(crate) static STACK: RefCell> = RefCell::new(Vec::new()); @@ -48,7 +49,7 @@ impl ContextState { self.flags = 0; // drop all modules and return first error if any let result = self.modules.drain().fold(Ok(()), |res: CUresult, hmod| { - match (res, super::drop_checked::(hmod)) { + match (res, zluda_common::drop_checked::(hmod)) { (Err(e), _) => Err(e), (_, Err(e)) => Err(e), _ => Ok(()), @@ -88,6 +89,7 @@ impl Context { impl ZludaObject for Context { const COOKIE: usize = 0x5f867c6d9cb73315; + type Error = CUerror; type CudaHandle = CUcontext; fn drop_checked(&mut self) -> CUresult { @@ -128,7 +130,7 @@ pub(crate) fn set_current(raw_ctx: CUcontext) -> CUresult { None }) } else { - let ctx: &Context = FromCuda::from_cuda(&raw_ctx)?; + let ctx: &Context = FromCuda::<_, CUerror>::from_cuda(&raw_ctx)?; let device = ctx.device; STACK.with(move |stack| { let mut stack = stack.borrow_mut(); @@ -157,7 +159,7 @@ pub(crate) fn get_current(pctx: &mut CUcontext) -> CUresult { pub(crate) fn get_device(dev: &mut hipDevice_t) -> CUresult { let cu_ctx = get_current_context()?; - let ctx: &Context = FromCuda::from_cuda(&cu_ctx)?; + let ctx: &Context = FromCuda::<_, CUerror>::from_cuda(&cu_ctx)?; *dev = ctx.device; Ok(()) } @@ -195,7 +197,7 @@ pub(crate) unsafe fn create_v2( } pub(crate) unsafe fn destroy_v2(ctx: CUcontext) -> CUresult { - super::drop_checked::(ctx) + zluda_common::drop_checked::(ctx) } pub(crate) unsafe fn pop_current_v2(ctx: &mut CUcontext) -> CUresult { diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs index 7346d62..d7fadab 100644 --- a/zluda/src/impl/driver.rs +++ b/zluda/src/impl/driver.rs @@ -1,4 +1,3 @@ -use super::{FromCuda, LiveCheck}; use crate::r#impl::{context, device}; use comgr::Comgr; use cuda_types::cuda::*; @@ -9,6 +8,7 @@ use std::{ sync::OnceLock, usize, }; +use zluda_common::{FromCuda, LiveCheck}; #[cfg_attr(windows, path = "os_win.rs")] #[cfg_attr(not(windows), path = "os_unix.rs")] @@ -175,7 +175,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { context::get_current(&mut current_ctx)?; current_ctx }; - let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?; + let ctx_obj: &context::Context = FromCuda::<_, CUerror>::from_cuda(&_ctx)?; ctx_obj.with_state_mut(|state: &mut context::ContextState| { state.storage.insert( key as usize, @@ -194,7 +194,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { cu_ctx: CUcontext, key: *mut c_void, ) -> CUresult { - let ctx_obj: &context::Context = FromCuda::from_cuda(&cu_ctx)?; + let ctx_obj: &context::Context = FromCuda::<_, CUerror>::from_cuda(&cu_ctx)?; ctx_obj.with_state_mut(|state: &mut context::ContextState| { state.storage.remove(&(key as usize)); Ok(()) @@ -213,7 +213,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { } else { _ctx = cu_ctx }; - let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?; + let ctx_obj: &context::Context = FromCuda::<_, CUerror>::from_cuda(&_ctx)?; ctx_obj.with_state(|state: &context::ContextState| { match state.storage.get(&(key as usize)) { Some(data) => *value = data.value as *mut c_void, diff --git a/zluda/src/impl/library.rs b/zluda/src/impl/library.rs index c5f60b1..4a166e9 100644 --- a/zluda/src/impl/library.rs +++ b/zluda/src/impl/library.rs @@ -1,6 +1,7 @@ -use super::{module, ZludaObject}; +use super::module; use cuda_types::cuda::*; use hip_runtime_sys::*; +use zluda_common::ZludaObject; pub(crate) struct Library { base: hipModule_t, @@ -9,6 +10,7 @@ pub(crate) struct Library { impl ZludaObject for Library { const COOKIE: usize = 0xb328a916cc234d7c; + type Error = CUerror; type CudaHandle = CUlibrary; fn drop_checked(&mut self) -> CUresult { @@ -35,7 +37,7 @@ pub(crate) fn load_data( } pub(crate) unsafe fn unload(library: CUlibrary) -> CUresult { - super::drop_checked::(library) + zluda_common::drop_checked::(library) } pub(crate) unsafe fn get_module(out: &mut CUmodule, library: &Library) -> CUresult { diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index 02a81bd..21d315d 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -1,10 +1,5 @@ use cuda_types::cuda::*; -use hip_runtime_sys::*; -use std::{ - ffi::CStr, - mem::{self, ManuallyDrop, MaybeUninit}, - ptr, -}; +use zluda_common::from_cuda_object; pub(super) mod context; pub(super) mod device; @@ -26,231 +21,4 @@ pub(crate) fn unimplemented() -> CUresult { CUresult::ERROR_NOT_SUPPORTED } -pub(crate) trait FromCuda<'a, T>: Sized { - fn from_cuda(t: &'a T) -> Result; -} - -macro_rules! from_cuda_nop { - ($($type_:ty),*) => { - $( - impl<'a> FromCuda<'a, $type_> for $type_ { - fn from_cuda(x: &'a $type_) -> Result { - Ok(*x) - } - } - - impl<'a> FromCuda<'a, *mut $type_> for &'a mut $type_ { - fn from_cuda(x: &'a *mut $type_) -> Result { - match unsafe { x.as_mut() } { - Some(x) => Ok(x), - None => Err(CUerror::INVALID_VALUE), - } - } - } - - impl<'a> FromCuda<'a, *const $type_> for &'a $type_ { - fn from_cuda(x: &'a *const $type_) -> Result { - match unsafe { x.as_ref() } { - Some(x) => Ok(x), - None => Err(CUerror::INVALID_VALUE), - } - } - } - - impl<'a> FromCuda<'a, *mut $type_> for Option<&'a mut $type_> { - fn from_cuda(x: &'a *mut $type_) -> Result { - Ok(unsafe { x.as_mut() }) - } - } - )* - }; -} - -macro_rules! from_cuda_transmute { - ($($from:ty => $to:ty),*) => { - $( - impl<'a> FromCuda<'a, $from> for $to { - fn from_cuda(x: &'a $from) -> Result { - Ok(unsafe { std::mem::transmute(*x) }) - } - } - - impl<'a> FromCuda<'a, *mut $from> for &'a mut $to { - fn from_cuda(x: &'a *mut $from) -> Result { - match unsafe { x.cast::<$to>().as_mut() } { - Some(x) => Ok(x), - None => Err(CUerror::INVALID_VALUE), - } - } - } - - impl<'a> FromCuda<'a, *mut $from> for * mut $to { - fn from_cuda(x: &'a *mut $from) -> Result { - Ok(x.cast::<$to>()) - } - } - )* - }; -} - -macro_rules! from_cuda_object { - ($($type_:ty),*) => { - $( - impl<'a> FromCuda<'a, <$type_ as ZludaObject>::CudaHandle> for <$type_ as ZludaObject>::CudaHandle { - fn from_cuda(handle: &'a <$type_ as ZludaObject>::CudaHandle) -> Result<<$type_ as ZludaObject>::CudaHandle, CUerror> { - Ok(*handle) - } - } - - impl<'a> FromCuda<'a, *mut <$type_ as ZludaObject>::CudaHandle> for &'a mut <$type_ as ZludaObject>::CudaHandle { - fn from_cuda(handle: &'a *mut <$type_ as ZludaObject>::CudaHandle) -> Result<&'a mut <$type_ as ZludaObject>::CudaHandle, CUerror> { - match unsafe { handle.as_mut() } { - Some(x) => Ok(x), - None => Err(CUerror::INVALID_VALUE), - } - } - } - - impl<'a> FromCuda<'a, <$type_ as ZludaObject>::CudaHandle> for &'a $type_ { - fn from_cuda(handle: &'a <$type_ as ZludaObject>::CudaHandle) -> Result<&'a $type_, CUerror> { - Ok(as_ref(handle).as_result()?) - } - } - )* - }; -} - -from_cuda_nop!( - *mut i8, - *mut i32, - *mut usize, - *const ::core::ffi::c_void, - *const ::core::ffi::c_char, - *mut ::core::ffi::c_void, - *mut *mut ::core::ffi::c_void, - u8, - i32, - u32, - u64, - usize, - cuda_types::cuda::CUdevprop, - CUdevice_attribute, - CUdriverProcAddressQueryResult, - CUjit_option, - CUlibraryOption, - CUmoduleLoadingMode, - CUuuid -); -from_cuda_transmute!( - CUuuid => hipUUID, - CUfunction => hipFunction_t, - CUfunction_attribute => hipFunction_attribute, - CUstream => hipStream_t, - CUpointer_attribute => hipPointer_attribute, - CUdeviceptr_v2 => hipDeviceptr_t -); from_cuda_object!(module::Module, context::Context, library::Library); - -impl<'a> FromCuda<'a, CUlimit> for hipLimit_t { - fn from_cuda(limit: &'a CUlimit) -> Result { - Ok(match *limit { - CUlimit::CU_LIMIT_STACK_SIZE => hipLimit_t::hipLimitStackSize, - CUlimit::CU_LIMIT_PRINTF_FIFO_SIZE => hipLimit_t::hipLimitPrintfFifoSize, - CUlimit::CU_LIMIT_MALLOC_HEAP_SIZE => hipLimit_t::hipLimitMallocHeapSize, - _ => return Err(CUerror::NOT_SUPPORTED), - }) - } -} - -impl<'a> FromCuda<'a, *const ::core::ffi::c_char> for &CStr { - fn from_cuda(s: &'a *const ::core::ffi::c_char) -> Result { - if *s != ptr::null() { - Ok(unsafe { CStr::from_ptr(*s) }) - } else { - Err(CUerror::INVALID_VALUE) - } - } -} - -impl<'a> FromCuda<'a, *const ::core::ffi::c_void> for &'a ::core::ffi::c_void { - fn from_cuda(x: &'a *const ::core::ffi::c_void) -> Result { - match unsafe { x.as_ref() } { - Some(x) => Ok(x), - None => Err(CUerror::INVALID_VALUE), - } - } -} - -pub(crate) trait ZludaObject: Sized + Send + Sync { - const COOKIE: usize; - const LIVENESS_FAIL: CUerror = cuda_types::cuda::CUerror::INVALID_VALUE; - - type CudaHandle: Sized; - - fn drop_checked(&mut self) -> CUresult; - - fn wrap(self) -> Self::CudaHandle { - unsafe { mem::transmute_copy(&LiveCheck::wrap(self)) } - } -} - -#[repr(C)] -pub(crate) struct LiveCheck { - cookie: usize, - data: MaybeUninit, -} - -impl LiveCheck { - fn new(data: T) -> Self { - LiveCheck { - cookie: T::COOKIE, - data: MaybeUninit::new(data), - } - } - - fn as_handle(&self) -> T::CudaHandle { - unsafe { mem::transmute_copy(&self) } - } - - fn wrap(data: T) -> *mut Self { - Box::into_raw(Box::new(Self::new(data))) - } - - fn as_result(&self) -> Result<&T, CUerror> { - if self.cookie == T::COOKIE { - Ok(unsafe { self.data.assume_init_ref() }) - } else { - Err(T::LIVENESS_FAIL) - } - } - - // This looks like nonsense, but it's not. There are two cases: - // Err(CUerror) -> meaning that the object is invalid, this pointer does not point into valid memory - // Ok(maybe_error) -> meaning that the object is valid, we dropped everything, but there *might* - // an error in the underlying runtime that we want to propagate - #[must_use] - fn drop_checked(&mut self) -> Result, CUerror> { - if self.cookie == T::COOKIE { - self.cookie = 0; - let result = unsafe { self.data.assume_init_mut().drop_checked() }; - unsafe { MaybeUninit::assume_init_drop(&mut self.data) }; - Ok(result) - } else { - Err(T::LIVENESS_FAIL) - } - } -} - -pub fn as_ref<'a, T: ZludaObject>( - handle: &'a T::CudaHandle, -) -> &'a ManuallyDrop>> { - unsafe { mem::transmute(handle) } -} - -pub fn drop_checked(handle: T::CudaHandle) -> Result<(), CUerror> { - let mut wrapped_object: ManuallyDrop>> = - unsafe { mem::transmute_copy(&handle) }; - let underlying_error = LiveCheck::drop_checked(&mut wrapped_object)?; - unsafe { ManuallyDrop::drop(&mut wrapped_object) }; - underlying_error -} diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 797bec0..981ab06 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -1,4 +1,4 @@ -use super::{driver, ZludaObject}; +use super::driver; use cuda_types::{ cuda::*, dark_api::{FatbinFileHeader, FatbincWrapper}, @@ -6,6 +6,7 @@ use cuda_types::{ use dark_api::fatbin::Fatbin; use hip_runtime_sys::*; use std::{ffi::CStr, mem}; +use zluda_common::ZludaObject; pub(crate) struct Module { pub(crate) base: hipModule_t, @@ -14,6 +15,7 @@ pub(crate) struct Module { impl ZludaObject for Module { const COOKIE: usize = 0xe9138bd040487d4a; + type Error = CUerror; type CudaHandle = CUmodule; fn drop_checked(&mut self) -> CUresult { @@ -92,7 +94,7 @@ pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUre } pub(crate) fn unload(hmod: CUmodule) -> CUresult { - super::drop_checked::(hmod) + zluda_common::drop_checked::(hmod) } pub(crate) fn get_function( diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs index 0de04b1..1fdc1ea 100644 --- a/zluda/src/lib.rs +++ b/zluda/src/lib.rs @@ -39,7 +39,7 @@ macro_rules! implemented { if !initialized() { return Err(CUerror::DEINITIALIZED); } - cuda_macros::cuda_normalize_fn!( crate::r#impl::$fn_name ) ($(crate::r#impl::FromCuda::from_cuda(&$arg_id)?),*)?; + cuda_macros::cuda_normalize_fn!( crate::r#impl::$fn_name ) ($(zluda_common::FromCuda::<_, CUerror>::from_cuda(&$arg_id)?),*)?; Ok(()) } )* @@ -56,7 +56,7 @@ macro_rules! implemented_in_function { if !initialized() { return Err(CUerror::DEINITIALIZED); } - cuda_macros::cuda_normalize_fn!( crate::r#impl::function::$fn_name ) ($(crate::r#impl::FromCuda::from_cuda(&$arg_id)?),*)?; + cuda_macros::cuda_normalize_fn!( crate::r#impl::function::$fn_name ) ($(zluda_common::FromCuda::<_, CUerror>::from_cuda(&$arg_id)?),*)?; Ok(()) } )* diff --git a/zluda_common/Cargo.toml b/zluda_common/Cargo.toml new file mode 100644 index 0000000..8258e83 --- /dev/null +++ b/zluda_common/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "zluda_common" +version = "0.1.0" +authors = ["Violet "] +edition = "2021" + +[dependencies] +cuda_types = { path = "../cuda_types" } +hip_runtime-sys = { path = "../ext/hip_runtime-sys" } diff --git a/zluda_common/src/lib.rs b/zluda_common/src/lib.rs new file mode 100644 index 0000000..33a465f --- /dev/null +++ b/zluda_common/src/lib.rs @@ -0,0 +1,278 @@ +use cuda_types::cuda::*; +use hip_runtime_sys::*; +use std::{ + ffi::CStr, + mem::{self, ManuallyDrop, MaybeUninit}, + ptr, +}; + +pub trait CudaErrorType { + const INVALID_VALUE: Self; + const NOT_SUPPORTED: Self; +} + +impl CudaErrorType for CUerror { + const INVALID_VALUE: Self = Self::INVALID_VALUE; + const NOT_SUPPORTED: Self = Self::NOT_SUPPORTED; +} + +/// Used to try to convert CUDA API values into our internal representation. +/// +/// Similar to [`TryFrom`], but we can implement this for primitive types. We also provide conversions from pointers to references. +pub trait FromCuda<'a, T, E: CudaErrorType>: Sized { + /// Tries to convert to this type. + fn from_cuda(t: &'a T) -> Result; +} + +macro_rules! from_cuda_nop { + ($($type_:ty),*) => { + $( + impl<'a, E: CudaErrorType> FromCuda<'a, $type_, E> for $type_ { + fn from_cuda(x: &'a $type_) -> Result { + Ok(*x) + } + } + + impl<'a, E: CudaErrorType> FromCuda<'a, *mut $type_, E> for &'a mut $type_ { + fn from_cuda(x: &'a *mut $type_) -> Result { + match unsafe { x.as_mut() } { + Some(x) => Ok(x), + None => Err(E::INVALID_VALUE), + } + } + } + + impl<'a, E: CudaErrorType> FromCuda<'a, *const $type_, E> for &'a $type_ { + fn from_cuda(x: &'a *const $type_) -> Result { + match unsafe { x.as_ref() } { + Some(x) => Ok(x), + None => Err(E::INVALID_VALUE), + } + } + } + + impl<'a, E: CudaErrorType> FromCuda<'a, *mut $type_, E> for Option<&'a mut $type_> { + fn from_cuda(x: &'a *mut $type_) -> Result { + Ok(unsafe { x.as_mut() }) + } + } + )* + }; +} + +macro_rules! from_cuda_transmute { + ($($from:ty => $to:ty),*) => { + $( + impl<'a, E: CudaErrorType> FromCuda<'a, $from, E> for $to { + fn from_cuda(x: &'a $from) -> Result { + Ok(unsafe { std::mem::transmute(*x) }) + } + } + + impl<'a, E: CudaErrorType> FromCuda<'a, *mut $from, E> for &'a mut $to { + fn from_cuda(x: &'a *mut $from) -> Result { + match unsafe { x.cast::<$to>().as_mut() } { + Some(x) => Ok(x), + None => Err(E::INVALID_VALUE), + } + } + } + + impl<'a, E: CudaErrorType> FromCuda<'a, *mut $from, E> for * mut $to { + fn from_cuda(x: &'a *mut $from) -> Result { + Ok(x.cast::<$to>()) + } + } + )* + }; +} + +/// Implement the [`FromCuda`] trait for a [`ZludaObject`]. +#[macro_export] +macro_rules! from_cuda_object { + ($($type_:ty),*) => { + $( + impl<'a> zluda_common::FromCuda<'a, <$type_ as zluda_common::ZludaObject>::CudaHandle, <$type_ as zluda_common::ZludaObject>::Error> for &'a $type_ { + fn from_cuda(handle: &'a <$type_ as zluda_common::ZludaObject>::CudaHandle) -> Result<&'a $type_, <$type_ as zluda_common::ZludaObject>::Error> { + Ok(zluda_common::as_ref(handle).as_result()?) + } + } + )* + }; +} + +from_cuda_nop!( + *mut i8, + *mut i32, + *mut usize, + *const ::core::ffi::c_void, + *const ::core::ffi::c_char, + *mut ::core::ffi::c_void, + *mut *mut ::core::ffi::c_void, + u8, + i32, + u32, + u64, + usize, + cuda_types::cuda::CUdevprop, + CUdevice_attribute, + CUdriverProcAddressQueryResult, + CUjit_option, + CUlibraryOption, + CUmoduleLoadingMode, + CUuuid, + CUlibrary, + CUmodule, + CUcontext +); +from_cuda_transmute!( + CUuuid => hipUUID, + CUfunction => hipFunction_t, + CUfunction_attribute => hipFunction_attribute, + CUstream => hipStream_t, + CUpointer_attribute => hipPointer_attribute, + CUdeviceptr_v2 => hipDeviceptr_t +); + +impl<'a, E: CudaErrorType> FromCuda<'a, CUlimit, E> for hipLimit_t { + fn from_cuda(limit: &'a CUlimit) -> Result { + Ok(match *limit { + CUlimit::CU_LIMIT_STACK_SIZE => hipLimit_t::hipLimitStackSize, + CUlimit::CU_LIMIT_PRINTF_FIFO_SIZE => hipLimit_t::hipLimitPrintfFifoSize, + CUlimit::CU_LIMIT_MALLOC_HEAP_SIZE => hipLimit_t::hipLimitMallocHeapSize, + _ => return Err(E::NOT_SUPPORTED), + }) + } +} + +impl<'a, E: CudaErrorType> FromCuda<'a, *const ::core::ffi::c_char, E> for &CStr { + fn from_cuda(s: &'a *const ::core::ffi::c_char) -> Result { + if *s != ptr::null() { + Ok(unsafe { CStr::from_ptr(*s) }) + } else { + Err(E::INVALID_VALUE) + } + } +} + +impl<'a, E: CudaErrorType> FromCuda<'a, *const ::core::ffi::c_void, E> for &'a ::core::ffi::c_void { + fn from_cuda(x: &'a *const ::core::ffi::c_void) -> Result { + match unsafe { x.as_ref() } { + Some(x) => Ok(x), + None => Err(E::INVALID_VALUE), + } + } +} + +/// Represents an object that can be sent across the API boundary. +/// +/// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a +/// module's data and set the `module` output argument to a new `CUmodule`. Then, other functions +/// like `cuModuleGetFunction` can take that `CUmodule` as an argument. +pub trait ZludaObject: Sized + Send + Sync { + /// This is a unique identifier used by [`LiveCheck`] for runtime type and lifetime checking. + /// + /// You can generate a new one with Python: + /// + /// ```python + /// import random + /// hex(random.getrandbits(64)) + /// ``` + const COOKIE: usize; + + /// The value of [`Self::Error`] used to represent a type check failure or use after free. + const LIVENESS_FAIL: Self::Error = Self::Error::INVALID_VALUE; + + /// The error type that should be used. This is generally specific to the library this trait + /// is being implemented in – for example, a [`ZludaObject`] in `zluda` should use the + /// [`CUerror`] type, and a [`ZludaObject`] in `zluda_blas` should use the [`cublasStatus_t`] + /// type. + type Error: CudaErrorType; + /// The handle type that an object of this trait should should be wrapped as. + type CudaHandle: Sized; + + /// Executes the destructor for this type. + fn drop_checked(&mut self) -> Result<(), Self::Error>; + + /// Wraps an object of this trait in a [`LiveCheck`] that can be used for runtime type and + /// lifetime checking, and returns it as an opaque [`Self::CudaHandle`] that can be passed to + /// the API caller. + fn wrap(self) -> Self::CudaHandle { + unsafe { mem::transmute_copy(&LiveCheck::wrap(self)) } + } +} + +/// Wraps a [`ZludaObject`] and provides runtime type and lifetime checking. +/// +/// Arbitrary memory can be cast to a value of this type, and then [`LiveCheck::as_result`] can be +/// used to get the wrapped [`ZludaObject`] value, if it is valid. +#[repr(C)] +pub struct LiveCheck { + cookie: usize, + /// The wrapped [`ZludaObject`]. + pub data: MaybeUninit, +} + +impl LiveCheck { + /// Wraps `data` as a valid, initialized `LiveCheck`. + pub fn new(data: T) -> Self { + LiveCheck { + cookie: T::COOKIE, + data: MaybeUninit::new(data), + } + } + + /// Returns this value as an opaque `T::CudaHandle`. + pub fn as_handle(&self) -> T::CudaHandle { + unsafe { mem::transmute_copy(&self) } + } + + fn wrap(data: T) -> *mut Self { + Box::into_raw(Box::new(Self::new(data))) + } + + /// Checks if this value represents a valid and initialized value of `T` and returns it. + /// Returns `T::LIVENESS_FAIL` if it does not. + pub fn as_result(&self) -> Result<&T, T::Error> { + if self.cookie == T::COOKIE { + Ok(unsafe { self.data.assume_init_ref() }) + } else { + Err(T::LIVENESS_FAIL) + } + } + + // This looks like nonsense, but it's not. There are two cases: + // Err(CUerror) -> meaning that the object is invalid, this pointer does not point into valid memory + // Ok(maybe_error) -> meaning that the object is valid, we dropped everything, but there *might* + // an error in the underlying runtime that we want to propagate + #[must_use] + fn drop_checked(&mut self) -> Result, T::Error> { + if self.cookie == T::COOKIE { + self.cookie = 0; + let result = unsafe { self.data.assume_init_mut().drop_checked() }; + unsafe { MaybeUninit::assume_init_drop(&mut self.data) }; + Ok(result) + } else { + Err(T::LIVENESS_FAIL) + } + } +} + +/// Cast a `T::CudaHandle` reference to a [`LiveCheck`] reference, preserving the lifetime. +pub fn as_ref<'a, T: ZludaObject>( + handle: &'a T::CudaHandle, +) -> &'a ManuallyDrop>> { + unsafe { mem::transmute(handle) } +} + +/// Try to drop `handle`. +/// +/// Returns an error if `handle` is not initialized, not a value of `T`, or if `T::drop_checked` +/// returns an error. +pub fn drop_checked(handle: T::CudaHandle) -> Result<(), T::Error> { + let mut wrapped_object: ManuallyDrop>> = + unsafe { mem::transmute_copy(&handle) }; + let underlying_error = LiveCheck::drop_checked(&mut wrapped_object)?; + unsafe { ManuallyDrop::drop(&mut wrapped_object) }; + underlying_error +}