diff --git a/notcuda/src/impl/export_table.rs b/notcuda/src/impl/export_table.rs index 9a6d72c..562af37 100644 --- a/notcuda/src/impl/export_table.rs +++ b/notcuda/src/impl/export_table.rs @@ -323,6 +323,9 @@ fn context_local_storage_ctor_impl( if cu_ctx == ptr::null_mut() { context::get_current(&mut cu_ctx)?; } + if cu_ctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } unsafe { &*cu_ctx } .as_ref() .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT) @@ -354,9 +357,12 @@ unsafe extern "C" fn context_local_storage_get_state( fn context_local_storage_get_state_impl( ctx_state: *mut *mut cuda_impl::rt::ContextState, - cu_ctx: *mut context::Context, + mut cu_ctx: *mut context::Context, _: *mut cuda_impl::rt::ContextStateManager, ) -> Result<(), CUresult> { + if cu_ctx == ptr::null_mut() { + context::get_current(&mut cu_ctx)?; + } if cu_ctx == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } @@ -369,6 +375,10 @@ fn context_local_storage_get_state_impl( .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE) .map(|mutable| mutable.cuda_state) })?; - unsafe { *ctx_state = cuda_state }; - Ok(()) + if cuda_state == ptr::null_mut() { + Err(CUresult::CUDA_ERROR_INVALID_VALUE) + } else { + unsafe { *ctx_state = cuda_state }; + Ok(()) + } }