Start converting to OpenCL

This commit is contained in:
Andrzej Janik
2021-07-21 01:46:50 +02:00
parent 58fb8a234c
commit 3d2024bf62
9 changed files with 220 additions and 129 deletions

View File

@ -1,2 +1,5 @@
[build]
rustflags = ["-C", "target-cpu=haswell"]
[target."x86_64-pc-windows-gnu"] [target."x86_64-pc-windows-gnu"]
rustflags = ["-C", "link-self-contained=y"] rustflags = ["-C", "link-self-contained=y", "-C", "target-cpu=haswell"]

View File

@ -15,6 +15,10 @@ lazy_static = "1.4"
num_enum = "0.4" num_enum = "0.4"
lz4-sys = "1.9" lz4-sys = "1.9"
[dependencies.ocl-core]
version = "0.11"
features = ["opencl_version_1_2", "opencl_version_2_0", "opencl_version_2_1"]
[target.'cfg(windows)'.dependencies] [target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["heapapi", "std"] } winapi = { version = "0.3", features = ["heapapi", "std"] }

View File

@ -137,11 +137,11 @@ pub fn create_v2(
let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| { let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| {
let dev_ptr = dev as *mut _; let dev_ptr = dev as *mut _;
let mut ctx_box = Box::new(LiveCheck::new(ContextData::new( let mut ctx_box = Box::new(LiveCheck::new(ContextData::new(
&dev.l0_context, &dev.ocl_context,
dev.base, dev.base,
flags, flags,
false, false,
dev.host_event_pool.get(dev.base, &dev.l0_context)?, dev.host_event_pool.get(dev.base, &dev.ocl_context)?,
dev_ptr as *mut _, dev_ptr as *mut _,
)?)); )?));
ctx_box.late_init(); ctx_box.late_init();

View File

@ -1,6 +1,7 @@
use super::{context, transmute_lifetime, transmute_lifetime_mut, CUresult, GlobalState}; use super::{context, transmute_lifetime, transmute_lifetime_mut, CUresult, GlobalState};
use crate::cuda; use crate::cuda;
use cuda::{CUdevice_attribute, CUuuid_st}; use cuda::{CUdevice_attribute, CUuuid_st};
use ocl_core::DeviceType;
use std::{ use std::{
cmp, mem, cmp, mem,
os::raw::{c_char, c_int, c_uint}, os::raw::{c_char, c_int, c_uint},
@ -18,11 +19,10 @@ pub struct Index(pub c_int);
pub struct Device { pub struct Device {
pub index: Index, pub index: Index,
pub base: l0::Device, pub base: l0::Device,
pub default_queue: l0::CommandQueue<'static>, pub ocl_base: ocl_core::DeviceId,
pub l0_context: l0::Context, pub default_queue: ocl_core::CommandQueue,
pub ocl_context: ocl_core::Context,
pub primary_context: context::Context, pub primary_context: context::Context,
pub device_event_pool: DynamicEventPool,
pub host_event_pool: DynamicEventPool,
properties: Option<Box<l0::sys::ze_device_properties_t>>, properties: Option<Box<l0::sys::ze_device_properties_t>>,
image_properties: Option<Box<l0::sys::ze_device_image_properties_t>>, image_properties: Option<Box<l0::sys::ze_device_image_properties_t>>,
memory_properties: Option<Vec<l0::sys::ze_device_memory_properties_t>>, memory_properties: Option<Vec<l0::sys::ze_device_memory_properties_t>>,
@ -32,41 +32,22 @@ pub struct Device {
unsafe impl Send for Device {} unsafe impl Send for Device {}
impl Device { impl Device {
// Unsafe because it does not fully initalize primary_context pub fn new(
// and we transmute lifetimes left and right drv: &l0::Driver,
unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> { l0_dev: l0::Device,
let ctx = l0::Context::new(*drv, Some(&[l0_dev]))?; ocl_dev: ocl_core::DeviceId,
let queue = l0::CommandQueue::new(mem::transmute(&ctx), l0_dev)?; idx: usize,
let mut host_event_pool = DynamicEventPool::new( ) -> Result<Self, CUresult> {
l0_dev, let ctx = ocl_core::create_context(None, &[ocl_dev], None, None)?;
transmute_lifetime(&ctx), let queue = ocl_core::create_command_queue(&ctx, ocl_dev, None)?;
l0::sys::ze_event_pool_flags_t::ZE_EVENT_POOL_FLAG_HOST_VISIBLE, let primary_context = context::Context::new(context::ContextData::new());
l0::sys::ze_event_scope_flags_t::ZE_EVENT_SCOPE_FLAG_HOST,
)?;
let host_event =
transmute_lifetime_mut(&mut host_event_pool).get(l0_dev, transmute_lifetime(&ctx))?;
let primary_context = context::Context::new(context::ContextData::new(
transmute_lifetime(&ctx),
l0_dev,
0,
true,
host_event,
ptr::null_mut(),
)?);
let device_event_pool = DynamicEventPool::new(
l0_dev,
transmute_lifetime(&ctx),
l0::sys::ze_event_pool_flags_t(0),
l0::sys::ze_event_scope_flags_t(0),
)?;
Ok(Self { Ok(Self {
index: Index(idx as c_int), index: Index(idx as c_int),
base: l0_dev, base: l0_dev,
ocl_base: ocl_dev,
default_queue: queue, default_queue: queue,
l0_context: ctx, ocl_context: ctx,
primary_context: primary_context, primary_context,
device_event_pool,
host_event_pool,
properties: None, properties: None,
image_properties: None, image_properties: None,
memory_properties: None, memory_properties: None,
@ -111,10 +92,6 @@ impl Device {
Ok(self.compute_properties.get_or_insert(Box::new(props))) Ok(self.compute_properties.get_or_insert(Box::new(props)))
} }
pub fn late_init(&mut self) {
self.primary_context.as_option_mut().unwrap().device = self as *mut _;
}
fn get_max_simd(&mut self) -> l0::Result<u32> { fn get_max_simd(&mut self) -> l0::Result<u32> {
let props = self.get_compute_properties()?; let props = self.get_compute_properties()?;
Ok(*props.subGroupSizes[0..props.numSubGroupSizes as usize] Ok(*props.subGroupSizes[0..props.numSubGroupSizes as usize]
@ -124,20 +101,6 @@ impl Device {
} }
} }
pub fn init(driver: &l0::Driver) -> Result<Vec<Device>, CUresult> {
let ze_devices = driver.devices()?;
let mut devices = ze_devices
.into_iter()
.enumerate()
.map(|(idx, d)| unsafe { Device::new(driver, d, idx) })
.collect::<Result<Vec<_>, _>>()?;
for dev in devices.iter_mut() {
dev.late_init();
dev.primary_context.late_init();
}
Ok(devices)
}
pub fn get_count(count: *mut c_int) -> Result<(), CUresult> { pub fn get_count(count: *mut c_int) -> Result<(), CUresult> {
let len = GlobalState::lock(|state| state.devices.len())?; let len = GlobalState::lock(|state| state.devices.len())?;
unsafe { *count = len as c_int }; unsafe { *count = len as c_int };
@ -215,8 +178,6 @@ impl CUdevice_attribute {
match self { match self {
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_GPU_OVERLAP => Some(1), CUdevice_attribute::CU_DEVICE_ATTRIBUTE_GPU_OVERLAP => Some(1),
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT => Some(1), CUdevice_attribute::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT => Some(1),
// TODO: fix this for DG1
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_INTEGRATED => Some(1),
// TODO: go back to this once we have more funcitonality implemented // TODO: go back to this once we have more funcitonality implemented
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR => Some(8), CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR => Some(8),
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR => Some(0), CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR => Some(0),
@ -239,6 +200,19 @@ pub fn get_attribute(
return Ok(()); return Ok(());
} }
let value = match attrib { let value = match attrib {
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_INTEGRATED => {
GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_properties()?;
if (props.flags
& l0::sys::ze_device_property_flags_t::ZE_DEVICE_PROPERTY_FLAG_INTEGRATED)
== l0::sys::ze_device_property_flags_t::ZE_DEVICE_PROPERTY_FLAG_INTEGRATED
{
Ok(1)
} else {
Ok(0)
}
})??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT => {
GlobalState::lock_device(dev_idx, |dev| { GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_properties()?; let props = dev.get_properties()?;

View File

@ -1,3 +1,5 @@
use ocl_core::DeviceId;
use super::{stream::Stream, CUresult, GlobalState, HasLivenessCookie, LiveCheck}; use super::{stream::Stream, CUresult, GlobalState, HasLivenessCookie, LiveCheck};
use crate::cuda::CUfunction_attribute; use crate::cuda::CUfunction_attribute;
use ::std::os::raw::{c_uint, c_void}; use ::std::os::raw::{c_uint, c_void};
@ -24,10 +26,9 @@ impl HasLivenessCookie for FunctionData {
} }
pub struct FunctionData { pub struct FunctionData {
pub base: l0::Kernel<'static>, pub base: ocl_core::Kernel,
pub arg_size: Vec<usize>, pub arg_size: Vec<usize>,
pub use_shared_mem: bool, pub use_shared_mem: bool,
pub properties: Option<Box<l0::sys::ze_kernel_properties_t>>,
pub legacy_args: LegacyArguments, pub legacy_args: LegacyArguments,
} }
@ -50,18 +51,6 @@ impl LegacyArguments {
} }
} }
impl FunctionData {
fn get_properties(&mut self) -> Result<&l0::sys::ze_kernel_properties_t, l0::sys::ze_result_t> {
if let None = self.properties {
self.properties = Some(self.base.get_properties()?)
}
match self.properties {
Some(ref props) => Ok(props.as_ref()),
None => unsafe { hint::unreachable_unchecked() },
}
}
}
pub fn launch_kernel( pub fn launch_kernel(
f: *mut Function, f: *mut Function,
grid_dim_x: c_uint, grid_dim_x: c_uint,
@ -81,13 +70,16 @@ pub fn launch_kernel(
{ {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE); return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
} }
GlobalState::lock_enqueue(hstream, |cmd_list, signal, wait| { GlobalState::lock_enqueue(hstream, |queue| {
let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?; let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?;
if kernel_params != ptr::null_mut() { if kernel_params != ptr::null_mut() {
for (i, arg_size) in func.arg_size.iter().enumerate() { for (i, arg_size) in func.arg_size.iter().enumerate() {
unsafe { unsafe {
func.base ocl_core::set_kernel_arg(
.set_arg_raw(i as u32, *arg_size, *kernel_params.add(i))? &func.base,
i as u32,
ocl_core::ArgVal::from_raw(*arg_size, *kernel_params.add(i), false),
)?;
}; };
} }
} else { } else {
@ -120,11 +112,15 @@ pub fn launch_kernel(
for (i, arg_size) in func.arg_size.iter().enumerate() { for (i, arg_size) in func.arg_size.iter().enumerate() {
let buffer_offset = round_up_to_multiple(offset, *arg_size); let buffer_offset = round_up_to_multiple(offset, *arg_size);
unsafe { unsafe {
func.base.set_arg_raw( ocl_core::set_kernel_arg(
&func.base,
i as u32, i as u32,
ocl_core::ArgVal::from_raw(
*arg_size, *arg_size,
buffer_ptr.add(buffer_offset) as *const _, buffer_ptr.add(buffer_offset) as *const _,
)? false,
),
)?;
}; };
offset = buffer_offset + *arg_size; offset = buffer_offset + *arg_size;
} }
@ -134,24 +130,34 @@ pub fn launch_kernel(
} }
if func.use_shared_mem { if func.use_shared_mem {
unsafe { unsafe {
func.base.set_arg_raw( ocl_core::set_kernel_arg(
&func.base,
func.arg_size.len() as u32, func.arg_size.len() as u32,
shared_mem_bytes as usize, ocl_core::ArgVal::from_raw(shared_mem_bytes as usize, ptr::null(), false),
ptr::null(), )?;
)?
}; };
} }
func.base let global_dims = [
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?; (block_dim_x * grid_dim_x) as usize,
func.legacy_args.reset(); (block_dim_y * grid_dim_y) as usize,
(block_dim_z * grid_dim_z) as usize,
];
unsafe { unsafe {
cmd_list.append_launch_kernel( ocl_core::enqueue_kernel::<&mut ocl_core::Event, ocl_core::Event>(
&mut func.base, queue,
&[grid_dim_x, grid_dim_y, grid_dim_z], &func.base,
Some(signal), 3,
wait, None,
)?; &global_dims,
} Some([
block_dim_x as usize,
block_dim_y as usize,
block_dim_z as usize,
]),
None,
None,
)?
};
Ok::<_, CUresult>(()) Ok::<_, CUresult>(())
}) })
} }
@ -171,8 +177,17 @@ pub(crate) fn get_attribute(
match attrib { match attrib {
CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => { CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
let max_threads = GlobalState::lock_function(func, |func| { let max_threads = GlobalState::lock_function(func, |func| {
let props = func.get_properties()?; if let ocl_core::KernelWorkGroupInfoResult::WorkGroupSize(size) =
Ok::<_, CUresult>(props.maxSubgroupSize * props.maxNumSubgroups) ocl_core::get_kernel_work_group_info::<ocl_core::DeviceId>(
&func.base,
unsafe { ocl_core::DeviceId::null() },
ocl_core::KernelWorkGroupInfo::WorkGroupSize,
)?
{
Ok(size)
} else {
Err(CUresult::CUDA_ERROR_UNKNOWN)
}
})??; })??;
unsafe { *pi = max_threads as i32 }; unsafe { *pi = max_threads as i32 };
Ok(()) Ok(())

View File

@ -4,7 +4,7 @@ use std::{ffi::c_void, mem};
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> { pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
let ptr = GlobalState::lock_current_context(|ctx| { let ptr = GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device }; let dev = unsafe { &mut *ctx.device };
Ok::<_, CUresult>(dev.l0_context.mem_alloc_device(bytesize, 0, dev.base)?) Ok::<_, CUresult>(dev.ocl_context.mem_alloc_device(bytesize, 0, dev.base)?)
})??; })??;
unsafe { *dptr = ptr }; unsafe { *dptr = ptr };
Ok(()) Ok(())
@ -20,7 +20,7 @@ pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<
pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> { pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
GlobalState::lock_current_context(|ctx| { GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device }; let dev = unsafe { &mut *ctx.device };
Ok::<_, CUresult>(dev.l0_context.mem_free(ptr)?) Ok::<_, CUresult>(dev.ocl_context.mem_free(ptr)?)
}) })
.map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)? .map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
} }

View File

@ -164,6 +164,14 @@ impl<T> From<TryLockError<T>> for CUresult {
} }
} }
impl From<ocl_core::Error> for CUresult {
fn from(result: ocl_core::Error) -> Self {
match result {
_ => CUresult::CUDA_ERROR_UNKNOWN,
}
}
}
pub trait Encuda { pub trait Encuda {
type To: Sized; type To: Sized;
fn encuda(self: Self) -> Self::To; fn encuda(self: Self) -> Self::To;
@ -207,6 +215,7 @@ lazy_static! {
struct GlobalState { struct GlobalState {
devices: Vec<Device>, devices: Vec<Device>,
global_heap: *mut c_void, global_heap: *mut c_void,
platform: ocl_core::PlatformId,
} }
unsafe impl Send for GlobalState {} unsafe impl Send for GlobalState {}
@ -275,15 +284,11 @@ impl GlobalState {
fn lock_enqueue( fn lock_enqueue(
stream: *mut stream::Stream, stream: *mut stream::Stream,
f: impl FnOnce( f: impl FnOnce(&ocl_core::CommandQueue) -> Result<(), CUresult>,
&l0::CommandList,
&l0::Event<'static>,
&[&l0::Event<'static>],
) -> Result<(), CUresult>,
) -> Result<(), CUresult> { ) -> Result<(), CUresult> {
Self::lock_stream(stream, |stream_data| { Self::lock_stream(stream, |stream_data| {
let l0_dev = unsafe { (*(*stream_data.context).device).base }; let l0_dev = unsafe { (*(*stream_data.context).device).base };
let l0_ctx = unsafe { &mut (*(*stream_data.context).device).l0_context }; let l0_ctx = unsafe { &mut (*(*stream_data.context).device).ocl_context };
let cmd_list = unsafe { transmute_lifetime(&stream_data.cmd_list) }; let cmd_list = unsafe { transmute_lifetime(&stream_data.cmd_list) };
// TODO: make new_marker drop-safe // TODO: make new_marker drop-safe
let (new_event, new_marker) = stream_data.get_event(l0_dev, l0_ctx)?; let (new_event, new_marker) = stream_data.get_event(l0_dev, l0_ctx)?;
@ -325,10 +330,34 @@ pub fn init() -> Result<(), CUresult> {
return Ok(()); return Ok(());
} }
l0::init()?; l0::init()?;
let platforms = ocl_core::get_platform_ids()?;
let (platform, device) = platforms
.iter()
.find_map(|plat| {
let devices =
ocl_core::get_device_ids(plat, Some(ocl_core::DeviceType::GPU), None).ok()?;
for dev in devices {
let vendor = ocl_core::get_device_info(dev, ocl_core::DeviceInfo::VendorId).ok()?;
if let ocl_core::DeviceInfoResult::VendorId(0x8086) = vendor {
let dev_type =
ocl_core::get_device_info(dev, ocl_core::DeviceInfo::Type).ok()?;
if let ocl_core::DeviceInfoResult::Type(ocl_core::DeviceType::GPU) = dev_type {
return Some((plat.clone(), dev));
}
}
}
None
})
.ok_or(CUresult::CUDA_ERROR_UNKNOWN)?;
let drivers = l0::Driver::get()?; let drivers = l0::Driver::get()?;
let devices = match drivers.into_iter().find(is_intel_gpu_driver) { let devices = match drivers.into_iter().find(is_intel_gpu_driver) {
None => return Err(CUresult::CUDA_ERROR_UNKNOWN), None => return Err(CUresult::CUDA_ERROR_UNKNOWN),
Some(driver) => device::init(&driver)?, Some(driver) => driver
.devices()?
.into_iter()
.enumerate()
.map(|(idx, l0_dev)| device::Device::new(&driver, l0_dev, device, idx).unwrap())
.collect::<Vec<_>>(),
}; };
let global_heap = unsafe { os::heap_create() }; let global_heap = unsafe { os::heap_create() };
if global_heap == ptr::null_mut() { if global_heap == ptr::null_mut() {
@ -337,6 +366,7 @@ pub fn init() -> Result<(), CUresult> {
*global_state = Some(GlobalState { *global_state = Some(GlobalState {
devices, devices,
global_heap, global_heap,
platform,
}); });
drop(global_state); drop(global_state);
Ok(()) Ok(())

View File

@ -1,8 +1,18 @@
use std::{ use std::{
collections::hash_map, collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem, collections::hash_map,
os::raw::c_char, ptr, slice, collections::HashMap,
ffi::c_void,
ffi::CStr,
ffi::CString,
mem,
os::raw::{c_char, c_int, c_uint},
ptr, slice,
}; };
const CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL: u32 = 0x4200;
const CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL: u32 = 0x4201;
const CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL: u32 = 0x4202;
use super::{ use super::{
device, device,
function::Function, function::Function,
@ -41,7 +51,7 @@ pub struct SpirvModule {
} }
pub struct CompiledModule { pub struct CompiledModule {
pub base: l0::Module<'static>, pub base: ocl_core::Program,
pub kernels: HashMap<CString, Box<Function>>, pub kernels: HashMap<CString, Box<Function>>,
} }
@ -80,28 +90,57 @@ impl SpirvModule {
pub fn compile<'a>( pub fn compile<'a>(
&self, &self,
ctx: &'a l0::Context, ctx: &ocl_core::Context,
dev: l0::Device, dev: &ocl_core::DeviceId,
) -> Result<l0::Module<'a>, CUresult> { ) -> Result<ocl_core::Program, CUresult> {
let byte_il = unsafe { let byte_il = unsafe {
slice::from_raw_parts( slice::from_raw_parts(
self.binaries.as_ptr() as *const u8, self.binaries.as_ptr() as *const u8,
self.binaries.len() * mem::size_of::<u32>(), self.binaries.len() * mem::size_of::<u32>(),
) )
}; };
let l0_module = match self.should_link_ptx_impl { let main_module = ocl_core::create_program_with_il(ctx, byte_il, None)?;
None => l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())), match self.should_link_ptx_impl {
None => {
ocl_core::compile_program(
&main_module,
Some(&[dev]),
&self.build_options,
&[],
&[],
None,
None,
None,
)?;
}
Some(ptx_impl) => { Some(ptx_impl) => {
l0::Module::build_link_spirv( let ptx_impl_prog = ocl_core::create_program_with_il(ctx, ptx_impl, None)?;
ocl_core::build_program(
&main_module,
Some(&[dev]),
&self.build_options,
None,
None,
)?;
ocl_core::build_program(
&ptx_impl_prog,
Some(&[dev]),
&self.build_options,
None,
None,
)?;
ocl_core::link_program(
ctx, ctx,
dev, Some(&[dev]),
&[ptx_impl, byte_il], &self.build_options,
Some(self.build_options.as_c_str()), &[&main_module, &ptx_impl_prog],
) None,
.0 None,
None,
)?;
} }
}; };
Ok(l0_module?) Ok(main_module)
} }
} }
@ -121,7 +160,9 @@ pub fn get_function(
hash_map::Entry::Occupied(entry) => entry.into_mut(), hash_map::Entry::Occupied(entry) => entry.into_mut(),
hash_map::Entry::Vacant(entry) => { hash_map::Entry::Vacant(entry) => {
let new_module = CompiledModule { let new_module = CompiledModule {
base: module.spirv.compile(&mut device.l0_context, device.base)?, base: module
.spirv
.compile(&device.ocl_context, &device.ocl_base)?,
kernels: HashMap::new(), kernels: HashMap::new(),
}; };
entry.insert(new_module) entry.insert(new_module)
@ -137,18 +178,42 @@ pub fn get_function(
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes()) std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
}) })
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?; .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
let kernel = let kernel = ocl_core::create_kernel(
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?; &compiled_module.base,
kernel.set_indirect_access( &entry.key().as_c_str().to_string_lossy(),
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED
)?; )?;
let true_b: ocl_core::ffi::cl_bool = 1;
let err = unsafe {
ocl_core::ffi::clSetKernelExecInfo(
kernel.as_ptr(),
CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
mem::size_of::<ocl_core::ffi::cl_bool>(),
&true_b as *const _ as *const _,
)
};
assert_eq!(err, 0);
let err = unsafe {
ocl_core::ffi::clSetKernelExecInfo(
kernel.as_ptr(),
CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
mem::size_of::<ocl_core::ffi::cl_bool>(),
&true_b as *const _ as *const _,
)
};
assert_eq!(err, 0);
let err = unsafe {
ocl_core::ffi::clSetKernelExecInfo(
kernel.as_ptr(),
CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
mem::size_of::<ocl_core::ffi::cl_bool>(),
&true_b as *const _ as *const _,
)
};
assert_eq!(err, 0);
entry.insert(Box::new(Function::new(FunctionData { entry.insert(Box::new(Function::new(FunctionData {
base: kernel, base: kernel,
arg_size: kernel_info.arguments_sizes.clone(), arg_size: kernel_info.arguments_sizes.clone(),
use_shared_mem: kernel_info.uses_shared_mem, use_shared_mem: kernel_info.uses_shared_mem,
properties: None,
legacy_args: LegacyArguments::new(), legacy_args: LegacyArguments::new(),
}))) })))
} }
@ -167,7 +232,7 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> { pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
let module = GlobalState::lock_current_context(|ctx| { let module = GlobalState::lock_current_context(|ctx| {
let device = unsafe { &mut *ctx.device }; let device = unsafe { &mut *ctx.device };
let l0_module = spirv_data.compile(&device.l0_context, device.base)?; let l0_module = spirv_data.compile(&device.ocl_context, &device.ocl_base)?;
let mut device_binaries = HashMap::new(); let mut device_binaries = HashMap::new();
let compiled_module = CompiledModule { let compiled_module = CompiledModule {
base: l0_module, base: l0_module,

View File

@ -56,7 +56,7 @@ impl StreamData {
}) })
} }
pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> { pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> {
let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context; let l0_ctx = &mut unsafe { &mut *ctx.device }.ocl_context;
let device = unsafe { &*ctx.device }.base; let device = unsafe { &*ctx.device }.base;
let synchronization_event = unsafe { &mut *ctx.device } let synchronization_event = unsafe { &mut *ctx.device }
.host_event_pool .host_event_pool