extern crate level_zero_sys as l0; #[macro_use] extern crate lazy_static; use std::sync::Mutex; use std::ptr; use std::os::raw::{c_char, c_int, c_uint}; mod cu; mod export_table; mod ze; macro_rules! l0_check_err { ($exp:expr) => { { let result = unsafe{ $exp }; if result != l0::ze_result_t::ZE_RESULT_SUCCESS { return Err(result); } } }; } lazy_static! { pub static ref GLOBAL_STATE: Mutex> = Mutex::new(None); } pub struct Driver { base: l0::ze_driver_handle_t, devices: Vec:: } unsafe impl Send for Driver {} unsafe impl Sync for Driver {} impl Driver { fn new() -> Result { let mut driver_count = 1; let mut handle = ptr::null_mut(); l0_check_err!{ l0::zeDriverGet(&mut driver_count, &mut handle) }; let mut count = 0; l0_check_err! { l0::zeDeviceGet(handle, &mut count, ptr::null_mut()) } let mut devices = vec![ptr::null_mut(); count as usize]; l0_check_err! { l0::zeDeviceGet(handle, &mut count, devices.as_mut_ptr()) } if (count as usize) < devices.len() { devices.truncate(count as usize); } Ok(Driver{ base: handle, devices: ze::Device::new_vec(devices) }) } fn call l0::ze_result_t>(f: F) -> cu::Result { let mut lock = GLOBAL_STATE.try_lock(); if let Ok(ref mut mutex) = lock { match **mutex { None => return cu::Result::ERROR_NOT_INITIALIZED, Some(ref mut driver) => { return cu::Result::from_l0(f(driver)); } } } else { return cu::Result::ERROR_UNKNOWN; } } fn call_device l0::ze_result_t>(cu::Device(dev): cu::Device, f: F) -> cu::Result { if dev < 0 { return cu::Result::ERROR_INVALID_VALUE; } let dev = dev as usize; Driver::call(|driver| { if dev >= driver.devices.len() { return l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT; } f(&mut driver.devices[dev]) }) } fn device_get_count(&self, count: *mut i32) -> l0::ze_result_t { unsafe { *count = self.devices.len() as i32 }; l0::ze_result_t::ZE_RESULT_SUCCESS } fn device_get(&self, device: *mut cu::Device, ordinal: c_int) -> l0::ze_result_t { if (ordinal as usize) >= self.devices.len() { return l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT; } unsafe { *device = cu::Device(ordinal) }; l0::ze_result_t::ZE_RESULT_SUCCESS } } #[no_mangle] pub unsafe extern "C" fn cuDriverGetVersion(version: *mut c_int) -> cu::Result { if version == ptr::null_mut() { return cu::Result::ERROR_INVALID_VALUE; } *version = i32::max_value(); return cu::Result::SUCCESS; } #[no_mangle] pub unsafe extern "C" fn cuInit(_: c_uint) -> cu::Result { let l0_init = l0::zeInit(l0::ze_init_flag_t::ZE_INIT_FLAG_GPU_ONLY); if l0_init != l0::ze_result_t::ZE_RESULT_SUCCESS { return cu::Result::from_l0(l0_init); } let mut lock = GLOBAL_STATE.try_lock(); if let Ok(ref mut mutex) = lock { if let None = **mutex { match Driver::new() { Ok(state) => **mutex = Some(state), Err(err) => return cu::Result::from_l0(err) } } } else { return cu::Result::ERROR_UNKNOWN; } cu::Result::SUCCESS } #[no_mangle] pub extern "C" fn cuDeviceGetCount(count: *mut c_int) -> cu::Result { if count == ptr::null_mut() { return cu::Result::ERROR_INVALID_VALUE; } Driver::call(|driver| driver.device_get_count(count)) } #[no_mangle] pub extern "C" fn cuDeviceGet(device: *mut cu::Device, ordinal: c_int) -> cu::Result { if ordinal < 0 || device == ptr::null_mut() { return cu::Result::ERROR_INVALID_VALUE; } Driver::call(|driver| driver.device_get(device, ordinal)) } #[no_mangle] pub extern "C" fn cuDeviceGetName(name: *mut c_char, len: c_int, dev_idx: cu::Device) -> cu::Result { if name == ptr::null_mut() || len <= 0 { return cu::Result::ERROR_INVALID_VALUE; } Driver::call_device(dev_idx, |dev| dev.get_name(name, len)) } #[no_mangle] pub extern "C" fn cuDeviceTotalMem_v2(bytes: *mut usize, dev_idx: cu::Device) -> cu::Result { if bytes == ptr::null_mut() { return cu::Result::ERROR_INVALID_VALUE; } Driver::call_device(dev_idx, |dev| dev.total_mem(bytes)) } /* #[no_mangle] pub extern "C" fn cuDeviceGetAttribute(pi: *mut c_int, attrib: cu::DeviceAttribute, dev: cu::Device) -> cu::Result { let cu::Device(dev_idx) = dev; if pi == ptr::null_mut() || dev_idx < 0 { return cu::Result::ERROR_INVALID_VALUE; } Driver::call(|driver| driver.device_get_attribute(bytes, dev)) } */