Fix minor problems with a private CUDA function

This commit is contained in:
Andrzej Janik
2020-09-24 02:20:54 +02:00
parent 3f41f21acb
commit 42bcd999eb

View File

@ -323,6 +323,9 @@ fn context_local_storage_ctor_impl(
if cu_ctx == ptr::null_mut() { if cu_ctx == ptr::null_mut() {
context::get_current(&mut cu_ctx)?; context::get_current(&mut cu_ctx)?;
} }
if cu_ctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
unsafe { &*cu_ctx } unsafe { &*cu_ctx }
.as_ref() .as_ref()
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT) .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
@ -354,9 +357,12 @@ unsafe extern "C" fn context_local_storage_get_state(
fn context_local_storage_get_state_impl( fn context_local_storage_get_state_impl(
ctx_state: *mut *mut cuda_impl::rt::ContextState, ctx_state: *mut *mut cuda_impl::rt::ContextState,
cu_ctx: *mut context::Context, mut cu_ctx: *mut context::Context,
_: *mut cuda_impl::rt::ContextStateManager, _: *mut cuda_impl::rt::ContextStateManager,
) -> Result<(), CUresult> { ) -> Result<(), CUresult> {
if cu_ctx == ptr::null_mut() {
context::get_current(&mut cu_ctx)?;
}
if cu_ctx == ptr::null_mut() { if cu_ctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE); return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
} }
@ -369,6 +375,10 @@ fn context_local_storage_get_state_impl(
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE) .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
.map(|mutable| mutable.cuda_state) .map(|mutable| mutable.cuda_state)
})?; })?;
unsafe { *ctx_state = cuda_state }; if cuda_state == ptr::null_mut() {
Ok(()) Err(CUresult::CUDA_ERROR_INVALID_VALUE)
} else {
unsafe { *ctx_state = cuda_state };
Ok(())
}
} }