mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-23 18:08:57 +03:00
Start converting to OpenCL
This commit is contained in:
@ -1,2 +1,5 @@
|
||||
[build]
|
||||
rustflags = ["-C", "target-cpu=haswell"]
|
||||
|
||||
[target."x86_64-pc-windows-gnu"]
|
||||
rustflags = ["-C", "link-self-contained=y"]
|
||||
rustflags = ["-C", "link-self-contained=y", "-C", "target-cpu=haswell"]
|
||||
|
@ -15,6 +15,10 @@ lazy_static = "1.4"
|
||||
num_enum = "0.4"
|
||||
lz4-sys = "1.9"
|
||||
|
||||
[dependencies.ocl-core]
|
||||
version = "0.11"
|
||||
features = ["opencl_version_1_2", "opencl_version_2_0", "opencl_version_2_1"]
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
winapi = { version = "0.3", features = ["heapapi", "std"] }
|
||||
|
||||
|
@ -137,11 +137,11 @@ pub fn create_v2(
|
||||
let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| {
|
||||
let dev_ptr = dev as *mut _;
|
||||
let mut ctx_box = Box::new(LiveCheck::new(ContextData::new(
|
||||
&dev.l0_context,
|
||||
&dev.ocl_context,
|
||||
dev.base,
|
||||
flags,
|
||||
false,
|
||||
dev.host_event_pool.get(dev.base, &dev.l0_context)?,
|
||||
dev.host_event_pool.get(dev.base, &dev.ocl_context)?,
|
||||
dev_ptr as *mut _,
|
||||
)?));
|
||||
ctx_box.late_init();
|
||||
|
@ -1,6 +1,7 @@
|
||||
use super::{context, transmute_lifetime, transmute_lifetime_mut, CUresult, GlobalState};
|
||||
use crate::cuda;
|
||||
use cuda::{CUdevice_attribute, CUuuid_st};
|
||||
use ocl_core::DeviceType;
|
||||
use std::{
|
||||
cmp, mem,
|
||||
os::raw::{c_char, c_int, c_uint},
|
||||
@ -18,11 +19,10 @@ pub struct Index(pub c_int);
|
||||
pub struct Device {
|
||||
pub index: Index,
|
||||
pub base: l0::Device,
|
||||
pub default_queue: l0::CommandQueue<'static>,
|
||||
pub l0_context: l0::Context,
|
||||
pub ocl_base: ocl_core::DeviceId,
|
||||
pub default_queue: ocl_core::CommandQueue,
|
||||
pub ocl_context: ocl_core::Context,
|
||||
pub primary_context: context::Context,
|
||||
pub device_event_pool: DynamicEventPool,
|
||||
pub host_event_pool: DynamicEventPool,
|
||||
properties: Option<Box<l0::sys::ze_device_properties_t>>,
|
||||
image_properties: Option<Box<l0::sys::ze_device_image_properties_t>>,
|
||||
memory_properties: Option<Vec<l0::sys::ze_device_memory_properties_t>>,
|
||||
@ -32,41 +32,22 @@ pub struct Device {
|
||||
unsafe impl Send for Device {}
|
||||
|
||||
impl Device {
|
||||
// Unsafe because it does not fully initialize primary_context
|
||||
// and we transmute lifetimes left and right
|
||||
unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> {
|
||||
let ctx = l0::Context::new(*drv, Some(&[l0_dev]))?;
|
||||
let queue = l0::CommandQueue::new(mem::transmute(&ctx), l0_dev)?;
|
||||
let mut host_event_pool = DynamicEventPool::new(
|
||||
l0_dev,
|
||||
transmute_lifetime(&ctx),
|
||||
l0::sys::ze_event_pool_flags_t::ZE_EVENT_POOL_FLAG_HOST_VISIBLE,
|
||||
l0::sys::ze_event_scope_flags_t::ZE_EVENT_SCOPE_FLAG_HOST,
|
||||
)?;
|
||||
let host_event =
|
||||
transmute_lifetime_mut(&mut host_event_pool).get(l0_dev, transmute_lifetime(&ctx))?;
|
||||
let primary_context = context::Context::new(context::ContextData::new(
|
||||
transmute_lifetime(&ctx),
|
||||
l0_dev,
|
||||
0,
|
||||
true,
|
||||
host_event,
|
||||
ptr::null_mut(),
|
||||
)?);
|
||||
let device_event_pool = DynamicEventPool::new(
|
||||
l0_dev,
|
||||
transmute_lifetime(&ctx),
|
||||
l0::sys::ze_event_pool_flags_t(0),
|
||||
l0::sys::ze_event_scope_flags_t(0),
|
||||
)?;
|
||||
pub fn new(
|
||||
drv: &l0::Driver,
|
||||
l0_dev: l0::Device,
|
||||
ocl_dev: ocl_core::DeviceId,
|
||||
idx: usize,
|
||||
) -> Result<Self, CUresult> {
|
||||
let ctx = ocl_core::create_context(None, &[ocl_dev], None, None)?;
|
||||
let queue = ocl_core::create_command_queue(&ctx, ocl_dev, None)?;
|
||||
let primary_context = context::Context::new(context::ContextData::new());
|
||||
Ok(Self {
|
||||
index: Index(idx as c_int),
|
||||
base: l0_dev,
|
||||
ocl_base: ocl_dev,
|
||||
default_queue: queue,
|
||||
l0_context: ctx,
|
||||
primary_context: primary_context,
|
||||
device_event_pool,
|
||||
host_event_pool,
|
||||
ocl_context: ctx,
|
||||
primary_context,
|
||||
properties: None,
|
||||
image_properties: None,
|
||||
memory_properties: None,
|
||||
@ -111,10 +92,6 @@ impl Device {
|
||||
Ok(self.compute_properties.get_or_insert(Box::new(props)))
|
||||
}
|
||||
|
||||
/// Points the primary context's `device` field back at this `Device`.
///
/// The value stored is the address of `self` at the time of the call, so it
/// is only meaningful for as long as this `Device` is not moved afterwards.
///
/// # Panics
/// Panics (via `unwrap`) if `primary_context.as_option_mut()` yields no value.
pub fn late_init(&mut self) {
|
||||
self.primary_context.as_option_mut().unwrap().device = self as *mut _;
|
||||
}
|
||||
|
||||
fn get_max_simd(&mut self) -> l0::Result<u32> {
|
||||
let props = self.get_compute_properties()?;
|
||||
Ok(*props.subGroupSizes[0..props.numSubGroupSizes as usize]
|
||||
@ -124,20 +101,6 @@ impl Device {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates one `Device` for every device reported by `driver`.
///
/// Each device receives its position in the driver's device list as its
/// index. After all devices have been collected, `late_init` is run on every
/// device and on its primary context.
///
/// # Errors
/// Returns the error from `driver.devices()` or the first error produced by
/// `Device::new`, converted to `CUresult`.
pub fn init(driver: &l0::Driver) -> Result<Vec<Device>, CUresult> {
|
||||
let ze_devices = driver.devices()?;
|
||||
let mut devices = ze_devices
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, d)| unsafe { Device::new(driver, d, idx) })
|
||||
// Collecting into `Result<Vec<_>, _>` stops at the first failing device.
.collect::<Result<Vec<_>, _>>()?;
|
||||
// NOTE(review): presumably the late-init pass runs only after `collect` so
// that the self-pointers it records refer to the elements' final slots in
// the Vec's buffer — confirm.
for dev in devices.iter_mut() {
|
||||
dev.late_init();
|
||||
dev.primary_context.late_init();
|
||||
}
|
||||
Ok(devices)
|
||||
}
|
||||
|
||||
pub fn get_count(count: *mut c_int) -> Result<(), CUresult> {
|
||||
let len = GlobalState::lock(|state| state.devices.len())?;
|
||||
unsafe { *count = len as c_int };
|
||||
@ -215,8 +178,6 @@ impl CUdevice_attribute {
|
||||
match self {
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_GPU_OVERLAP => Some(1),
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT => Some(1),
|
||||
// TODO: fix this for DG1
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_INTEGRATED => Some(1),
|
||||
// TODO: go back to this once we have more functionality implemented
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR => Some(8),
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR => Some(0),
|
||||
@ -239,6 +200,19 @@ pub fn get_attribute(
|
||||
return Ok(());
|
||||
}
|
||||
let value = match attrib {
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_INTEGRATED => {
|
||||
GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_properties()?;
|
||||
if (props.flags
|
||||
& l0::sys::ze_device_property_flags_t::ZE_DEVICE_PROPERTY_FLAG_INTEGRATED)
|
||||
== l0::sys::ze_device_property_flags_t::ZE_DEVICE_PROPERTY_FLAG_INTEGRATED
|
||||
{
|
||||
Ok(1)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
})??
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT => {
|
||||
GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_properties()?;
|
||||
|
@ -1,3 +1,5 @@
|
||||
use ocl_core::DeviceId;
|
||||
|
||||
use super::{stream::Stream, CUresult, GlobalState, HasLivenessCookie, LiveCheck};
|
||||
use crate::cuda::CUfunction_attribute;
|
||||
use ::std::os::raw::{c_uint, c_void};
|
||||
@ -24,10 +26,9 @@ impl HasLivenessCookie for FunctionData {
|
||||
}
|
||||
|
||||
pub struct FunctionData {
|
||||
pub base: l0::Kernel<'static>,
|
||||
pub base: ocl_core::Kernel,
|
||||
pub arg_size: Vec<usize>,
|
||||
pub use_shared_mem: bool,
|
||||
pub properties: Option<Box<l0::sys::ze_kernel_properties_t>>,
|
||||
pub legacy_args: LegacyArguments,
|
||||
}
|
||||
|
||||
@ -50,18 +51,6 @@ impl LegacyArguments {
|
||||
}
|
||||
}
|
||||
|
||||
impl FunctionData {
|
||||
/// Returns the kernel properties, querying them from `self.base` on the
/// first call and caching the result in `self.properties`.
///
/// # Errors
/// Propagates the `ze_result_t` returned by `self.base.get_properties()`;
/// in that case nothing is cached.
fn get_properties(&mut self) -> Result<&l0::sys::ze_kernel_properties_t, l0::sys::ze_result_t> {
|
||||
// Fill the cache only when it is still empty.
if let None = self.properties {
|
||||
self.properties = Some(self.base.get_properties()?)
|
||||
}
|
||||
match self.properties {
|
||||
Some(ref props) => Ok(props.as_ref()),
|
||||
// SAFETY: `self.properties` was either already `Some` or has just been
// assigned `Some(..)` above (an error would have returned early), so
// this arm can never be reached.
None => unsafe { hint::unreachable_unchecked() },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn launch_kernel(
|
||||
f: *mut Function,
|
||||
grid_dim_x: c_uint,
|
||||
@ -81,13 +70,16 @@ pub fn launch_kernel(
|
||||
{
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
GlobalState::lock_enqueue(hstream, |cmd_list, signal, wait| {
|
||||
GlobalState::lock_enqueue(hstream, |queue| {
|
||||
let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?;
|
||||
if kernel_params != ptr::null_mut() {
|
||||
for (i, arg_size) in func.arg_size.iter().enumerate() {
|
||||
unsafe {
|
||||
func.base
|
||||
.set_arg_raw(i as u32, *arg_size, *kernel_params.add(i))?
|
||||
ocl_core::set_kernel_arg(
|
||||
&func.base,
|
||||
i as u32,
|
||||
ocl_core::ArgVal::from_raw(*arg_size, *kernel_params.add(i), false),
|
||||
)?;
|
||||
};
|
||||
}
|
||||
} else {
|
||||
@ -120,11 +112,15 @@ pub fn launch_kernel(
|
||||
for (i, arg_size) in func.arg_size.iter().enumerate() {
|
||||
let buffer_offset = round_up_to_multiple(offset, *arg_size);
|
||||
unsafe {
|
||||
func.base.set_arg_raw(
|
||||
ocl_core::set_kernel_arg(
|
||||
&func.base,
|
||||
i as u32,
|
||||
ocl_core::ArgVal::from_raw(
|
||||
*arg_size,
|
||||
buffer_ptr.add(buffer_offset) as *const _,
|
||||
)?
|
||||
false,
|
||||
),
|
||||
)?;
|
||||
};
|
||||
offset = buffer_offset + *arg_size;
|
||||
}
|
||||
@ -134,24 +130,34 @@ pub fn launch_kernel(
|
||||
}
|
||||
if func.use_shared_mem {
|
||||
unsafe {
|
||||
func.base.set_arg_raw(
|
||||
ocl_core::set_kernel_arg(
|
||||
&func.base,
|
||||
func.arg_size.len() as u32,
|
||||
shared_mem_bytes as usize,
|
||||
ptr::null(),
|
||||
)?
|
||||
ocl_core::ArgVal::from_raw(shared_mem_bytes as usize, ptr::null(), false),
|
||||
)?;
|
||||
};
|
||||
}
|
||||
func.base
|
||||
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
|
||||
func.legacy_args.reset();
|
||||
let global_dims = [
|
||||
(block_dim_x * grid_dim_x) as usize,
|
||||
(block_dim_y * grid_dim_y) as usize,
|
||||
(block_dim_z * grid_dim_z) as usize,
|
||||
];
|
||||
unsafe {
|
||||
cmd_list.append_launch_kernel(
|
||||
&mut func.base,
|
||||
&[grid_dim_x, grid_dim_y, grid_dim_z],
|
||||
Some(signal),
|
||||
wait,
|
||||
)?;
|
||||
}
|
||||
ocl_core::enqueue_kernel::<&mut ocl_core::Event, ocl_core::Event>(
|
||||
queue,
|
||||
&func.base,
|
||||
3,
|
||||
None,
|
||||
&global_dims,
|
||||
Some([
|
||||
block_dim_x as usize,
|
||||
block_dim_y as usize,
|
||||
block_dim_z as usize,
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
)?
|
||||
};
|
||||
Ok::<_, CUresult>(())
|
||||
})
|
||||
}
|
||||
@ -171,8 +177,17 @@ pub(crate) fn get_attribute(
|
||||
match attrib {
|
||||
CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
|
||||
let max_threads = GlobalState::lock_function(func, |func| {
|
||||
let props = func.get_properties()?;
|
||||
Ok::<_, CUresult>(props.maxSubgroupSize * props.maxNumSubgroups)
|
||||
if let ocl_core::KernelWorkGroupInfoResult::WorkGroupSize(size) =
|
||||
ocl_core::get_kernel_work_group_info::<ocl_core::DeviceId>(
|
||||
&func.base,
|
||||
unsafe { ocl_core::DeviceId::null() },
|
||||
ocl_core::KernelWorkGroupInfo::WorkGroupSize,
|
||||
)?
|
||||
{
|
||||
Ok(size)
|
||||
} else {
|
||||
Err(CUresult::CUDA_ERROR_UNKNOWN)
|
||||
}
|
||||
})??;
|
||||
unsafe { *pi = max_threads as i32 };
|
||||
Ok(())
|
||||
|
@ -4,7 +4,7 @@ use std::{ffi::c_void, mem};
|
||||
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
|
||||
let ptr = GlobalState::lock_current_context(|ctx| {
|
||||
let dev = unsafe { &mut *ctx.device };
|
||||
Ok::<_, CUresult>(dev.l0_context.mem_alloc_device(bytesize, 0, dev.base)?)
|
||||
Ok::<_, CUresult>(dev.ocl_context.mem_alloc_device(bytesize, 0, dev.base)?)
|
||||
})??;
|
||||
unsafe { *dptr = ptr };
|
||||
Ok(())
|
||||
@ -20,7 +20,7 @@ pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<
|
||||
pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
|
||||
GlobalState::lock_current_context(|ctx| {
|
||||
let dev = unsafe { &mut *ctx.device };
|
||||
Ok::<_, CUresult>(dev.l0_context.mem_free(ptr)?)
|
||||
Ok::<_, CUresult>(dev.ocl_context.mem_free(ptr)?)
|
||||
})
|
||||
.map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
|
||||
}
|
||||
|
@ -164,6 +164,14 @@ impl<T> From<TryLockError<T>> for CUresult {
|
||||
}
|
||||
}
|
||||
|
||||
/// Lets `?` convert an `ocl_core::Error` into a `CUresult`.
///
/// Every OpenCL error is currently reported as `CUDA_ERROR_UNKNOWN`; the
/// details carried by the original error are discarded.
impl From<ocl_core::Error> for CUresult {
|
||||
fn from(result: ocl_core::Error) -> Self {
|
||||
// Only a wildcard arm exists, so the error value is never inspected.
match result {
|
||||
_ => CUresult::CUDA_ERROR_UNKNOWN,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Encuda {
|
||||
type To: Sized;
|
||||
fn encuda(self: Self) -> Self::To;
|
||||
@ -207,6 +215,7 @@ lazy_static! {
|
||||
/// Process-wide driver state, filled in by `init`.
struct GlobalState {
|
||||
// All devices created during `init`; `get_count` reports this vector's length.
devices: Vec<Device>,
|
||||
// Raw handle returned by `os::heap_create()` during `init`.
global_heap: *mut c_void,
|
||||
// OpenCL platform selected during `init` while searching for a GPU device.
platform: ocl_core::PlatformId,
|
||||
}
|
||||
|
||||
unsafe impl Send for GlobalState {}
|
||||
@ -275,15 +284,11 @@ impl GlobalState {
|
||||
|
||||
fn lock_enqueue(
|
||||
stream: *mut stream::Stream,
|
||||
f: impl FnOnce(
|
||||
&l0::CommandList,
|
||||
&l0::Event<'static>,
|
||||
&[&l0::Event<'static>],
|
||||
) -> Result<(), CUresult>,
|
||||
f: impl FnOnce(&ocl_core::CommandQueue) -> Result<(), CUresult>,
|
||||
) -> Result<(), CUresult> {
|
||||
Self::lock_stream(stream, |stream_data| {
|
||||
let l0_dev = unsafe { (*(*stream_data.context).device).base };
|
||||
let l0_ctx = unsafe { &mut (*(*stream_data.context).device).l0_context };
|
||||
let l0_ctx = unsafe { &mut (*(*stream_data.context).device).ocl_context };
|
||||
let cmd_list = unsafe { transmute_lifetime(&stream_data.cmd_list) };
|
||||
// TODO: make new_marker drop-safe
|
||||
let (new_event, new_marker) = stream_data.get_event(l0_dev, l0_ctx)?;
|
||||
@ -325,10 +330,34 @@ pub fn init() -> Result<(), CUresult> {
|
||||
return Ok(());
|
||||
}
|
||||
l0::init()?;
|
||||
let platforms = ocl_core::get_platform_ids()?;
|
||||
let (platform, device) = platforms
|
||||
.iter()
|
||||
.find_map(|plat| {
|
||||
let devices =
|
||||
ocl_core::get_device_ids(plat, Some(ocl_core::DeviceType::GPU), None).ok()?;
|
||||
for dev in devices {
|
||||
let vendor = ocl_core::get_device_info(dev, ocl_core::DeviceInfo::VendorId).ok()?;
|
||||
if let ocl_core::DeviceInfoResult::VendorId(0x8086) = vendor {
|
||||
let dev_type =
|
||||
ocl_core::get_device_info(dev, ocl_core::DeviceInfo::Type).ok()?;
|
||||
if let ocl_core::DeviceInfoResult::Type(ocl_core::DeviceType::GPU) = dev_type {
|
||||
return Some((plat.clone(), dev));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
.ok_or(CUresult::CUDA_ERROR_UNKNOWN)?;
|
||||
let drivers = l0::Driver::get()?;
|
||||
let devices = match drivers.into_iter().find(is_intel_gpu_driver) {
|
||||
None => return Err(CUresult::CUDA_ERROR_UNKNOWN),
|
||||
Some(driver) => device::init(&driver)?,
|
||||
Some(driver) => driver
|
||||
.devices()?
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, l0_dev)| device::Device::new(&driver, l0_dev, device, idx).unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
};
|
||||
let global_heap = unsafe { os::heap_create() };
|
||||
if global_heap == ptr::null_mut() {
|
||||
@ -337,6 +366,7 @@ pub fn init() -> Result<(), CUresult> {
|
||||
*global_state = Some(GlobalState {
|
||||
devices,
|
||||
global_heap,
|
||||
platform,
|
||||
});
|
||||
drop(global_state);
|
||||
Ok(())
|
||||
|
@ -1,8 +1,18 @@
|
||||
use std::{
|
||||
collections::hash_map, collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem,
|
||||
os::raw::c_char, ptr, slice,
|
||||
collections::hash_map,
|
||||
collections::HashMap,
|
||||
ffi::c_void,
|
||||
ffi::CStr,
|
||||
ffi::CString,
|
||||
mem,
|
||||
os::raw::{c_char, c_int, c_uint},
|
||||
ptr, slice,
|
||||
};
|
||||
|
||||
// Parameter names passed to `clSetKernelExecInfo` in `get_function`, each
// together with a `cl_bool` set to 1, right after a kernel is created.
// NOTE(review): values presumably come from Intel's unified shared memory
// OpenCL extension — confirm against the extension header.
const CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL: u32 = 0x4200;
|
||||
const CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL: u32 = 0x4201;
|
||||
const CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL: u32 = 0x4202;
|
||||
|
||||
use super::{
|
||||
device,
|
||||
function::Function,
|
||||
@ -41,7 +51,7 @@ pub struct SpirvModule {
|
||||
}
|
||||
|
||||
pub struct CompiledModule {
|
||||
pub base: l0::Module<'static>,
|
||||
pub base: ocl_core::Program,
|
||||
pub kernels: HashMap<CString, Box<Function>>,
|
||||
}
|
||||
|
||||
@ -80,28 +90,57 @@ impl SpirvModule {
|
||||
|
||||
pub fn compile<'a>(
|
||||
&self,
|
||||
ctx: &'a l0::Context,
|
||||
dev: l0::Device,
|
||||
) -> Result<l0::Module<'a>, CUresult> {
|
||||
ctx: &ocl_core::Context,
|
||||
dev: &ocl_core::DeviceId,
|
||||
) -> Result<ocl_core::Program, CUresult> {
|
||||
let byte_il = unsafe {
|
||||
slice::from_raw_parts(
|
||||
self.binaries.as_ptr() as *const u8,
|
||||
self.binaries.len() * mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
let l0_module = match self.should_link_ptx_impl {
|
||||
None => l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())),
|
||||
let main_module = ocl_core::create_program_with_il(ctx, byte_il, None)?;
|
||||
match self.should_link_ptx_impl {
|
||||
None => {
|
||||
ocl_core::compile_program(
|
||||
&main_module,
|
||||
Some(&[dev]),
|
||||
&self.build_options,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)?;
|
||||
}
|
||||
Some(ptx_impl) => {
|
||||
l0::Module::build_link_spirv(
|
||||
let ptx_impl_prog = ocl_core::create_program_with_il(ctx, ptx_impl, None)?;
|
||||
ocl_core::build_program(
|
||||
&main_module,
|
||||
Some(&[dev]),
|
||||
&self.build_options,
|
||||
None,
|
||||
None,
|
||||
)?;
|
||||
ocl_core::build_program(
|
||||
&ptx_impl_prog,
|
||||
Some(&[dev]),
|
||||
&self.build_options,
|
||||
None,
|
||||
None,
|
||||
)?;
|
||||
ocl_core::link_program(
|
||||
ctx,
|
||||
dev,
|
||||
&[ptx_impl, byte_il],
|
||||
Some(self.build_options.as_c_str()),
|
||||
)
|
||||
.0
|
||||
Some(&[dev]),
|
||||
&self.build_options,
|
||||
&[&main_module, &ptx_impl_prog],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)?;
|
||||
}
|
||||
};
|
||||
Ok(l0_module?)
|
||||
Ok(main_module)
|
||||
}
|
||||
}
|
||||
|
||||
@ -121,7 +160,9 @@ pub fn get_function(
|
||||
hash_map::Entry::Occupied(entry) => entry.into_mut(),
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
let new_module = CompiledModule {
|
||||
base: module.spirv.compile(&mut device.l0_context, device.base)?,
|
||||
base: module
|
||||
.spirv
|
||||
.compile(&device.ocl_context, &device.ocl_base)?,
|
||||
kernels: HashMap::new(),
|
||||
};
|
||||
entry.insert(new_module)
|
||||
@ -137,18 +178,42 @@ pub fn get_function(
|
||||
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
|
||||
})
|
||||
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
|
||||
let kernel =
|
||||
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
|
||||
kernel.set_indirect_access(
|
||||
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
|
||||
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
|
||||
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED
|
||||
let kernel = ocl_core::create_kernel(
|
||||
&compiled_module.base,
|
||||
&entry.key().as_c_str().to_string_lossy(),
|
||||
)?;
|
||||
let true_b: ocl_core::ffi::cl_bool = 1;
|
||||
let err = unsafe {
|
||||
ocl_core::ffi::clSetKernelExecInfo(
|
||||
kernel.as_ptr(),
|
||||
CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
|
||||
mem::size_of::<ocl_core::ffi::cl_bool>(),
|
||||
&true_b as *const _ as *const _,
|
||||
)
|
||||
};
|
||||
assert_eq!(err, 0);
|
||||
let err = unsafe {
|
||||
ocl_core::ffi::clSetKernelExecInfo(
|
||||
kernel.as_ptr(),
|
||||
CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
|
||||
mem::size_of::<ocl_core::ffi::cl_bool>(),
|
||||
&true_b as *const _ as *const _,
|
||||
)
|
||||
};
|
||||
assert_eq!(err, 0);
|
||||
let err = unsafe {
|
||||
ocl_core::ffi::clSetKernelExecInfo(
|
||||
kernel.as_ptr(),
|
||||
CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
|
||||
mem::size_of::<ocl_core::ffi::cl_bool>(),
|
||||
&true_b as *const _ as *const _,
|
||||
)
|
||||
};
|
||||
assert_eq!(err, 0);
|
||||
entry.insert(Box::new(Function::new(FunctionData {
|
||||
base: kernel,
|
||||
arg_size: kernel_info.arguments_sizes.clone(),
|
||||
use_shared_mem: kernel_info.uses_shared_mem,
|
||||
properties: None,
|
||||
legacy_args: LegacyArguments::new(),
|
||||
})))
|
||||
}
|
||||
@ -167,7 +232,7 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<
|
||||
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
|
||||
let module = GlobalState::lock_current_context(|ctx| {
|
||||
let device = unsafe { &mut *ctx.device };
|
||||
let l0_module = spirv_data.compile(&device.l0_context, device.base)?;
|
||||
let l0_module = spirv_data.compile(&device.ocl_context, &device.ocl_base)?;
|
||||
let mut device_binaries = HashMap::new();
|
||||
let compiled_module = CompiledModule {
|
||||
base: l0_module,
|
||||
|
@ -56,7 +56,7 @@ impl StreamData {
|
||||
})
|
||||
}
|
||||
pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> {
|
||||
let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context;
|
||||
let l0_ctx = &mut unsafe { &mut *ctx.device }.ocl_context;
|
||||
let device = unsafe { &*ctx.device }.base;
|
||||
let synchronization_event = unsafe { &mut *ctx.device }
|
||||
.host_event_pool
|
||||
|
Reference in New Issue
Block a user