Add platform initialization

This commit is contained in:
Andrzej Janik
2020-02-16 15:58:15 +01:00
parent 35caa53c3f
commit 6d748a3959
15 changed files with 7477 additions and 31 deletions

View File

@ -1,6 +1,7 @@
[workspace]
members = [
"level_zero-sys",
"notcuda",
"notcuda_inject",
"notcuda_redirect",

View File

@ -0,0 +1,8 @@
[package]
name = "level_zero-sys"
version = "0.4.1"
authors = ["Andrzej Janik <vosen@vosen.pl>"]
edition = "2018"
links = "ze_loader"
[lib]

1
level_zero-sys/README Normal file
View File

@ -0,0 +1 @@
bindgen --default-enum-style=rust --whitelist-function ze.* /usr/include/level_zero/zex_api.h -o zex_api.rs -- -x c++ && sed -i 's/pub enum _ze_result_t/#[must_use]\npub enum _ze_result_t/g' zex_api.rs

5
level_zero-sys/build.rs Normal file
View File

@ -0,0 +1,5 @@
fn main() {
println!("cargo:rustc-link-lib=dylib=ze_loader");
println!("cargo:rerun-if-changed=build.rs");
}

View File

@ -0,0 +1,3 @@
#![allow(warnings)]
pub mod zex_api;
pub use zex_api::*;

File diff suppressed because it is too large Load Diff

View File

@ -7,3 +7,7 @@ edition = "2018"
[lib]
name = "nvcuda"
crate-type = ["cdylib"]
[dependencies]
level_zero-sys = { path = "../level_zero-sys" }
lazy_static = "1.4"

View File

@ -79,7 +79,24 @@ pub enum Result {
ERROR_UNKNOWN = 999,
}
#[repr(C)]
pub struct Uuid {
x: [std::os::raw::c_char; 16]
impl Result {
pub fn from_l0(result: l0::ze_result_t) -> Result {
match result {
l0::ze_result_t::ZE_RESULT_SUCCESS => Result::SUCCESS,
l0::ze_result_t::ZE_RESULT_ERROR_UNINITIALIZED => Result::ERROR_NOT_INITIALIZED,
l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ENUMERATION => Result::ERROR_INVALID_VALUE,
l0::ze_result_t::ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY => Result::ERROR_OUT_OF_MEMORY,
_ => Result::ERROR_UNKNOWN
}
}
}
#[repr(C)]
#[derive(PartialEq, Eq)]
pub struct Uuid {
pub x: [std::os::raw::c_uchar; 16]
}
pub struct Device {
base: level_zero_sys::ze_driver_handle_t
}

View File

@ -0,0 +1,58 @@
use super::cu;
use std::mem;
use std::ptr;
#[no_mangle]
pub unsafe extern "stdcall" fn cuGetExportTable(
table: *mut *const std::os::raw::c_void,
id: *const cu::Uuid,
) -> cu::Result {
if *id == GUID0 {
*table = TABLE0.as_ptr() as *const _;
}
return cu::Result::SUCCESS;
}
const GUID0: cu::Uuid = cu::Uuid {
x: [
0xa0, 0x94, 0x79, 0x8c, 0x2e, 0x74, 0x2e, 0x74, 0x93, 0xf2, 0x08, 0x00, 0x20, 0x0c, 0x0a,
0x66,
],
};
#[repr(C)]
union Table0Member {
count: usize,
ptr: *const (),
}
unsafe impl Sync for Table0Member {}
const TABLE0_LEN: usize = 7;
static TABLE0: [Table0Member; TABLE0_LEN] = [
Table0Member {
count: mem::size_of::<[Table0Member; TABLE0_LEN]>(),
},
Table0Member { ptr: ptr::null() },
Table0Member {
ptr: table0_fn1 as *const (),
},
Table0Member { ptr: ptr::null() },
Table0Member { ptr: ptr::null() },
Table0Member { ptr: ptr::null() },
Table0Member {
ptr: table0_fn5 as *const (),
},
];
static mut TABLE0_FN1_SPACE: [u8; 512] = [0; 512];
static mut TABLE0_FN5_SPACE: [u8; 2] = [0; 2];
unsafe extern "stdcall" fn table0_fn1(ptr: *mut *mut u8, size: *mut usize) -> *mut u8 {
*ptr = TABLE0_FN1_SPACE.as_mut_ptr();
*size = TABLE0_FN1_SPACE.len();
return TABLE0_FN1_SPACE.as_mut_ptr();
}
unsafe extern "stdcall" fn table0_fn5(ptr: *mut *mut u8, size: *mut usize) -> *mut u8 {
*ptr = TABLE0_FN5_SPACE.as_mut_ptr();
*size = TABLE0_FN5_SPACE.len();
return TABLE0_FN5_SPACE.as_mut_ptr();
}

View File

@ -1,23 +1,76 @@
extern crate level_zero_sys as l0;
#[macro_use]
extern crate lazy_static;
use std::sync::Mutex;
use std::ptr;
mod cu;
mod export_table;
macro_rules! l0_check {
($exp:expr) => {
{
let result = unsafe{ $exp };
if result != l0::ze_result_t::ZE_RESULT_SUCCESS {
return Err(result)
}
}
};
}
lazy_static! {
pub static ref GLOBAL_STATE: Mutex<Option<Driver>> = Mutex::new(None);
}
pub struct Driver {
base: l0::ze_driver_handle_t
}
unsafe impl Send for Driver {}
unsafe impl Sync for Driver {}
impl Driver {
fn new() -> Result<Driver, l0::ze_result_t> {
let mut driver_count = 1;
let mut handle = ptr::null_mut();
l0_check!{ l0::zeDriverGet(&mut driver_count, &mut handle) };
Ok(Driver{ base: handle })
}
}
#[no_mangle]
pub extern "stdcall" fn cuDriverGetVersion(version: &mut std::os::raw::c_int) -> cu::Result {
*version = 0;
*version = i32::max_value();
return cu::Result::SUCCESS;
}
#[no_mangle]
pub extern "stdcall" fn cuInit(_: *const std::os::raw::c_uint) -> cu::Result {
return cu::Result::SUCCESS;
}
#[no_mangle]
pub extern "stdcall" fn cuGetExportTable(_: *const *const std::os::raw::c_void, _: cu::Uuid) -> cu::Result {
return cu::Result::ERROR_NOT_SUPPORTED;
pub unsafe extern "stdcall" fn cuInit(_: *const std::os::raw::c_uint) -> cu::Result {
let l0_init = l0::zeInit(l0::ze_init_flag_t::ZE_INIT_FLAG_GPU_ONLY);
if l0_init != l0::ze_result_t::ZE_RESULT_SUCCESS {
return cu::Result::from_l0(l0_init);
};
let mut lock = GLOBAL_STATE.try_lock();
if let Ok(ref mut mutex) = lock {
if let None = **mutex {
match Driver::new() {
Ok(state) => **mutex = Some(state),
Err(err) => return cu::Result::from_l0(err)
}
}
} else {
return cu::Result::ERROR_UNKNOWN;
}
cu::Result::SUCCESS
}
#[no_mangle]
pub extern "stdcall" fn cuDeviceGetCount(count: &mut std::os::raw::c_int) -> cu::Result {
*count = 1;
return cu::Result::SUCCESS;
}
#[no_mangle]
pub extern "stdcall" fn cuDeviceGet(device: *mut cu::Device, ordinal: ::std::os::raw::c_int) -> cu::Result {
unimplemented!()
}

View File

@ -6,12 +6,11 @@ edition = "2018"
[[bin]]
name = "notcuda"
path = "src/bin.rs"
[dependencies]
notcuda_redirect = { path = "../notcuda_redirect" }
detours-sys = "0.1"
clap = "2.33"
path = "src/main.rs"
[target.'cfg(windows)'.dependencies]
notcuda_redirect = { path = "../notcuda_redirect" }
winapi = { version = "0.3", features = ["processthreadsapi", "std", "synchapi"] }
detours-sys = "0.1"
clap = "2.33"
guid = "0.1"

View File

@ -1,4 +1,6 @@
extern crate clap;
#[macro_use]
extern crate guid;
extern crate detours_sys;
extern crate winapi;
@ -60,7 +62,7 @@ fn main() -> Result<(), Box<dyn Error>> {
ptr::null_mut(),
ptr::null_mut(),
0,
0x10,
0,
ptr::null_mut(),
ptr::null(),
&mut startup_info as *mut _,
@ -70,10 +72,28 @@ fn main() -> Result<(), Box<dyn Error>> {
),
|x| x != 0
);
let mut exe_path = std::env::current_dir()?
.as_os_str()
.encode_wide()
.collect::<Vec<_>>();
let guid = guid! {"C225FC0C-00D7-40B8-935A-7E342A9344C1"};
os_call!(
detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess,
mem::transmute(&guid),
exe_path.as_mut_ptr() as *mut _,
(exe_path.len() * mem::size_of::<u16>()) as u32
),
|x| x != 0
);
os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1);
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x != WAIT_FAILED);
let mut child_exit_code : u32 = 0;
os_call!(GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _), |x| x != 0);
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x
!= WAIT_FAILED);
let mut child_exit_code: u32 = 0;
os_call!(
GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _),
|x| x != 0
);
std::process::exit(child_exit_code as i32)
}
@ -82,3 +102,5 @@ fn copy_to(from: &OsStr, to: &mut Vec<u16>) {
to.push(x);
}
}
//

View File

@ -0,0 +1,5 @@
#[cfg(target_os = "windows")]
mod bin;
#[cfg(not(target_os = "windows"))]
fn main() {}

View File

@ -7,9 +7,8 @@ edition = "2018"
[lib]
crate-type = ["cdylib"]
[dependencies]
[target.'cfg(windows)'.dependencies]
detours-sys = "0.1"
wchar = "0.6"
[target.'cfg(windows)'.dependencies]
guid = "0.1"
winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "std"] }

View File

@ -1,19 +1,26 @@
#![cfg(windows)]
extern crate detours_sys;
#[macro_use]
extern crate guid;
extern crate winapi;
use std::mem;
use detours_sys::{
DetourAttach, DetourDetach, DetourRestoreAfterWith, DetourTransactionBegin,
DetourTransactionCommit, DetourUpdateThread,
};
use wchar::wch_c;
use winapi::shared::minwindef::{DWORD, HMODULE, TRUE};
use wchar::{wch, wch_c};
use winapi::shared::minwindef::{DWORD, FALSE, HMODULE, TRUE};
use winapi::um::libloaderapi::LoadLibraryExW;
use winapi::um::processthreadsapi::GetCurrentThread;
use winapi::um::winbase::lstrcmpiW;
use winapi::um::winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR};
const NVCUDA_LONG_PATH: &[u16] = wch_c!(r"C:\WINDOWS\system32\nvcuda.dll");
const NVCUDA_SHORT_PATH: &[u16] = wch_c!("nvcuda.dll");
const NVCUDA_PATH: &[u16] = wch_c!(r"C:\WINDOWS\system32\nvcuda.dll");
const NOTCUDA_DLL: &[u16] = wch!(r"nvcuda.dll");
static mut NOTCUDA_PATH: Option<Vec<u16>> = None;
static mut LOAD_LIBRARY_EX: unsafe extern "system" fn(
lpLibFileName: LPCWSTR,
@ -28,8 +35,8 @@ unsafe extern "system" fn NotCudaLoadLibraryExW(
hFile: HANDLE,
dwFlags: DWORD,
) -> HMODULE {
let nvcuda_file_name = if lstrcmpiW(lpLibFileName, NVCUDA_LONG_PATH.as_ptr()) == 0 {
NVCUDA_SHORT_PATH.as_ptr()
let nvcuda_file_name = if lstrcmpiW(lpLibFileName, NVCUDA_PATH.as_ptr()) == 0 {
NOTCUDA_PATH.as_ref().unwrap().as_ptr()
} else {
lpLibFileName
};
@ -41,6 +48,10 @@ unsafe extern "system" fn NotCudaLoadLibraryExW(
unsafe extern "system" fn DllMain(_: *const u8, dwReason: u32, _: *const u8) -> i32 {
if dwReason == DLL_PROCESS_ATTACH {
DetourRestoreAfterWith();
match get_notcuda_dll_path() {
Some((path, len)) => set_notcuda_dll_path(path, len),
None => return FALSE,
}
DetourTransactionBegin();
DetourUpdateThread(GetCurrentThread());
DetourAttach(
@ -59,3 +70,36 @@ unsafe extern "system" fn DllMain(_: *const u8, dwReason: u32, _: *const u8) ->
}
TRUE
}
fn get_notcuda_dll_path() -> Option<(*const u16, usize)> {
let guid = guid! {"C225FC0C-00D7-40B8-935A-7E342A9344C1"};
let mut module = std::ptr::null_mut();
loop {
module = unsafe { detours_sys::DetourEnumerateModules(module) };
if module == std::ptr::null_mut() {
break;
}
let mut size = 0;
let payload = unsafe {
detours_sys::DetourFindPayload(module, std::mem::transmute(&guid), &mut size)
};
if payload != std::ptr::null_mut() {
return Some((payload as *const _, (size as usize) / mem::size_of::<u16>()));
}
}
None
}
unsafe fn set_notcuda_dll_path(path: *const u16, len: usize) {
let len = len as usize;
let mut result = Vec::<u16>::with_capacity(len + NOTCUDA_DLL.len() + 2);
for i in 0..len {
result.push(*path.add(i));
}
result.push(0x5c); // \
for c in NOTCUDA_DLL.iter().copied() {
result.push(c);
}
result.push(0);
NOTCUDA_PATH = Some(result);
}