diff --git a/notcuda/src/cu.rs b/notcuda/src/cu.rs index a3a515b..df07099 100644 --- a/notcuda/src/cu.rs +++ b/notcuda/src/cu.rs @@ -85,6 +85,7 @@ impl Result { l0::ze_result_t::ZE_RESULT_SUCCESS => Result::SUCCESS, l0::ze_result_t::ZE_RESULT_ERROR_UNINITIALIZED => Result::ERROR_NOT_INITIALIZED, l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ENUMERATION => Result::ERROR_INVALID_VALUE, + l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT => Result::ERROR_INVALID_VALUE, l0::ze_result_t::ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY => Result::ERROR_OUT_OF_MEMORY, _ => Result::ERROR_UNKNOWN } @@ -95,8 +96,4 @@ impl Result { #[derive(PartialEq, Eq)] pub struct Uuid { pub x: [std::os::raw::c_uchar; 16] -} - -pub struct Device { - base: level_zero_sys::ze_driver_handle_t } \ No newline at end of file diff --git a/notcuda/src/lib.rs b/notcuda/src/lib.rs index c491dbb..29c5a70 100644 --- a/notcuda/src/lib.rs +++ b/notcuda/src/lib.rs @@ -37,6 +37,39 @@ impl Driver { l0_check!{ l0::zeDriverGet(&mut driver_count, &mut handle) }; Ok(Driver{ base: handle }) } + + fn call l0::ze_result_t>(f: F) -> cu::Result { + let mut lock = GLOBAL_STATE.try_lock(); + if let Ok(ref mut mutex) = lock { + match **mutex { + None => return cu::Result::ERROR_NOT_INITIALIZED, + Some(ref mut driver) => { + return cu::Result::from_l0(f(driver)); + } + } + } else { + return cu::Result::ERROR_UNKNOWN; + } + } + + fn device_get_count(&self, count: *mut i32) -> l0::ze_result_t { + unsafe { l0::zeDeviceGet(self.base, count as *mut _ as *mut _, ptr::null_mut()) } + } + + fn device_get(&self, device: *mut l0::ze_device_handle_t, ordinal: ::std::os::raw::c_int) -> l0::ze_result_t { + let count = (ordinal as u32) + 1; + let mut devices_found = count; + let mut handles = vec![ptr::null_mut(); count as usize]; + let result = unsafe { l0::zeDeviceGet(self.base, &mut devices_found, handles.as_mut_ptr()) }; + if result != l0::ze_result_t::ZE_RESULT_SUCCESS { + return result; + } + if devices_found < count { + return l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT; + } + unsafe { *device = handles[(count as usize) - 1] }; + l0::ze_result_t::ZE_RESULT_SUCCESS + } } #[no_mangle] @@ -66,11 +99,11 @@ pub unsafe extern "stdcall" fn cuInit(_: *const std::os::raw::c_uint) -> cu::Res } #[no_mangle] -pub extern "stdcall" fn cuDeviceGetCount(count: &mut std::os::raw::c_int) -> cu::Result { - return cu::Result::SUCCESS; +pub extern "stdcall" fn cuDeviceGetCount(count: *mut std::os::raw::c_int) -> cu::Result { + Driver::call(|driver| driver.device_get_count(count)) } #[no_mangle] -pub extern "stdcall" fn cuDeviceGet(device: *mut cu::Device, ordinal: ::std::os::raw::c_int) -> cu::Result { - unimplemented!() +pub extern "stdcall" fn cuDeviceGet(device: *mut l0::ze_device_handle_t, ordinal: ::std::os::raw::c_int) -> cu::Result { + Driver::call(|driver| driver.device_get(device, ordinal)) } \ No newline at end of file