Add cuCtxCreate_v2 and cuCtxDestroy_v2 (#430)

This commit is contained in:
Andrzej Janik
2025-07-24 02:33:59 +02:00
committed by GitHub
parent 2b90fdb56c
commit 5deada8426
2 changed files with 27 additions and 17 deletions

View File

@ -1,9 +1,8 @@
use super::{FromCuda, ZludaObject, module};
use super::{module, FromCuda, ZludaObject};
use cuda_types::cuda::*;
use hip_runtime_sys::*;
use rustc_hash::{FxHashSet, FxHashMap};
use std::{cell::RefCell, ptr, sync::Mutex, ffi::c_void};
use rustc_hash::{FxHashMap, FxHashSet};
use std::{cell::RefCell, ffi::c_void, ptr, sync::Mutex};
thread_local! {
pub(crate) static STACK: RefCell<Vec<(CUcontext, hipDevice_t)>> = RefCell::new(Vec::new());
@ -48,12 +47,11 @@ impl ContextState {
self.ref_count = 0;
self.flags = 0;
// drop all modules and return first error if any
let result = self.modules.drain().fold(
Ok(()), |res: CUresult, hmod| {
match (res, super::drop_checked::<module::Module>(hmod)) {
(Err(e), _) => Err(e),
(_, Err(e)) => Err(e),
_ => Ok(()),
let result = self.modules.drain().fold(Ok(()), |res: CUresult, hmod| {
match (res, super::drop_checked::<module::Module>(hmod)) {
(Err(e), _) => Err(e),
(_, Err(e)) => Err(e),
_ => Ok(()),
}
});
self.storage.clear();
@ -69,12 +67,9 @@ impl Context {
}
}
pub(crate) fn with_state(
&self,
fn_: impl FnOnce(&ContextState) -> CUresult,
) -> CUresult {
pub(crate) fn with_state(&self, fn_: impl FnOnce(&ContextState) -> CUresult) -> CUresult {
match self.state.lock() {
Ok(guard) => fn_(& *guard),
Ok(guard) => fn_(&*guard),
Err(_) => CUresult::ERROR_UNKNOWN,
}
}
@ -167,7 +162,6 @@ pub(crate) fn get_device(dev: &mut hipDevice_t) -> CUresult {
Ok(())
}
pub(crate) unsafe fn push_current(ctx: CUcontext) -> CUresult {
if ctx == CUcontext(ptr::null_mut()) {
return CUresult::ERROR_INVALID_VALUE;
@ -188,6 +182,20 @@ pub(crate) unsafe fn pop_current(ctx: &mut CUcontext) -> CUresult {
Ok(())
}
pub(crate) unsafe fn create_v2(
ctx: &mut CUcontext,
_flags: ::core::ffi::c_uint,
dev: cuda_types::cuda::CUdevice,
) -> CUresult {
let handle = Context::wrap(Context::new(dev));
*ctx = handle;
Ok(())
}
pub(crate) unsafe fn destroy_v2(ctx: CUcontext) -> CUresult {
super::drop_checked::<Context>(ctx)
}
pub(crate) unsafe fn pop_current_v2(ctx: &mut CUcontext) -> CUresult {
pop_current(ctx)
}
}

View File

@ -61,6 +61,8 @@ macro_rules! implemented_in_function {
cuda_base::cuda_function_declarations!(
unimplemented,
implemented <= [
cuCtxCreate_v2,
cuCtxDestroy_v2,
cuCtxGetLimit,
cuCtxSetCurrent,
cuCtxGetCurrent,