mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-23 01:48:56 +03:00
Add platform initialization
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
[workspace]
|
||||
|
||||
members = [
|
||||
"level_zero-sys",
|
||||
"notcuda",
|
||||
"notcuda_inject",
|
||||
"notcuda_redirect",
|
||||
|
8
level_zero-sys/Cargo.toml
Normal file
8
level_zero-sys/Cargo.toml
Normal file
@ -0,0 +1,8 @@
|
||||
[package]
|
||||
name = "level_zero-sys"
|
||||
version = "0.4.1"
|
||||
authors = ["Andrzej Janik <vosen@vosen.pl>"]
|
||||
edition = "2018"
|
||||
links = "ze_loader"
|
||||
|
||||
[lib]
|
1
level_zero-sys/README
Normal file
1
level_zero-sys/README
Normal file
@ -0,0 +1 @@
|
||||
bindgen --default-enum-style=rust --whitelist-function ze.* /usr/include/level_zero/zex_api.h -o zex_api.rs -- -x c++ && sed -i 's/pub enum _ze_result_t/#[must_use]\npub enum _ze_result_t/g' zex_api.rs
|
5
level_zero-sys/build.rs
Normal file
5
level_zero-sys/build.rs
Normal file
@ -0,0 +1,5 @@
|
||||
|
||||
fn main() {
|
||||
println!("cargo:rustc-link-lib=dylib=ze_loader");
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
}
|
3
level_zero-sys/src/lib.rs
Normal file
3
level_zero-sys/src/lib.rs
Normal file
@ -0,0 +1,3 @@
|
||||
#![allow(warnings)]
|
||||
pub mod zex_api;
|
||||
pub use zex_api::*;
|
7227
level_zero-sys/src/zex_api.rs
Normal file
7227
level_zero-sys/src/zex_api.rs
Normal file
File diff suppressed because it is too large
Load Diff
@ -7,3 +7,7 @@ edition = "2018"
|
||||
[lib]
|
||||
name = "nvcuda"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
level_zero-sys = { path = "../level_zero-sys" }
|
||||
lazy_static = "1.4"
|
@ -79,7 +79,24 @@ pub enum Result {
|
||||
ERROR_UNKNOWN = 999,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct Uuid {
|
||||
x: [std::os::raw::c_char; 16]
|
||||
impl Result {
|
||||
pub fn from_l0(result: l0::ze_result_t) -> Result {
|
||||
match result {
|
||||
l0::ze_result_t::ZE_RESULT_SUCCESS => Result::SUCCESS,
|
||||
l0::ze_result_t::ZE_RESULT_ERROR_UNINITIALIZED => Result::ERROR_NOT_INITIALIZED,
|
||||
l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ENUMERATION => Result::ERROR_INVALID_VALUE,
|
||||
l0::ze_result_t::ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY => Result::ERROR_OUT_OF_MEMORY,
|
||||
_ => Result::ERROR_UNKNOWN
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct Uuid {
|
||||
pub x: [std::os::raw::c_uchar; 16]
|
||||
}
|
||||
|
||||
pub struct Device {
|
||||
base: level_zero_sys::ze_driver_handle_t
|
||||
}
|
58
notcuda/src/export_table.rs
Normal file
58
notcuda/src/export_table.rs
Normal file
@ -0,0 +1,58 @@
|
||||
use super::cu;
|
||||
|
||||
use std::mem;
|
||||
use std::ptr;
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "stdcall" fn cuGetExportTable(
|
||||
table: *mut *const std::os::raw::c_void,
|
||||
id: *const cu::Uuid,
|
||||
) -> cu::Result {
|
||||
if *id == GUID0 {
|
||||
*table = TABLE0.as_ptr() as *const _;
|
||||
}
|
||||
return cu::Result::SUCCESS;
|
||||
}
|
||||
|
||||
const GUID0: cu::Uuid = cu::Uuid {
|
||||
x: [
|
||||
0xa0, 0x94, 0x79, 0x8c, 0x2e, 0x74, 0x2e, 0x74, 0x93, 0xf2, 0x08, 0x00, 0x20, 0x0c, 0x0a,
|
||||
0x66,
|
||||
],
|
||||
};
|
||||
#[repr(C)]
|
||||
union Table0Member {
|
||||
count: usize,
|
||||
ptr: *const (),
|
||||
}
|
||||
unsafe impl Sync for Table0Member {}
|
||||
const TABLE0_LEN: usize = 7;
|
||||
static TABLE0: [Table0Member; TABLE0_LEN] = [
|
||||
Table0Member {
|
||||
count: mem::size_of::<[Table0Member; TABLE0_LEN]>(),
|
||||
},
|
||||
Table0Member { ptr: ptr::null() },
|
||||
Table0Member {
|
||||
ptr: table0_fn1 as *const (),
|
||||
},
|
||||
Table0Member { ptr: ptr::null() },
|
||||
Table0Member { ptr: ptr::null() },
|
||||
Table0Member { ptr: ptr::null() },
|
||||
Table0Member {
|
||||
ptr: table0_fn5 as *const (),
|
||||
},
|
||||
];
|
||||
static mut TABLE0_FN1_SPACE: [u8; 512] = [0; 512];
|
||||
static mut TABLE0_FN5_SPACE: [u8; 2] = [0; 2];
|
||||
|
||||
unsafe extern "stdcall" fn table0_fn1(ptr: *mut *mut u8, size: *mut usize) -> *mut u8 {
|
||||
*ptr = TABLE0_FN1_SPACE.as_mut_ptr();
|
||||
*size = TABLE0_FN1_SPACE.len();
|
||||
return TABLE0_FN1_SPACE.as_mut_ptr();
|
||||
}
|
||||
|
||||
unsafe extern "stdcall" fn table0_fn5(ptr: *mut *mut u8, size: *mut usize) -> *mut u8 {
|
||||
*ptr = TABLE0_FN5_SPACE.as_mut_ptr();
|
||||
*size = TABLE0_FN5_SPACE.len();
|
||||
return TABLE0_FN5_SPACE.as_mut_ptr();
|
||||
}
|
@ -1,23 +1,76 @@
|
||||
extern crate level_zero_sys as l0;
|
||||
#[macro_use]
|
||||
extern crate lazy_static;
|
||||
|
||||
use std::sync::Mutex;
|
||||
use std::ptr;
|
||||
|
||||
mod cu;
|
||||
mod export_table;
|
||||
|
||||
macro_rules! l0_check {
|
||||
($exp:expr) => {
|
||||
{
|
||||
let result = unsafe{ $exp };
|
||||
if result != l0::ze_result_t::ZE_RESULT_SUCCESS {
|
||||
return Err(result)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
pub static ref GLOBAL_STATE: Mutex<Option<Driver>> = Mutex::new(None);
|
||||
}
|
||||
|
||||
pub struct Driver {
|
||||
base: l0::ze_driver_handle_t
|
||||
}
|
||||
|
||||
unsafe impl Send for Driver {}
|
||||
unsafe impl Sync for Driver {}
|
||||
|
||||
impl Driver {
|
||||
fn new() -> Result<Driver, l0::ze_result_t> {
|
||||
let mut driver_count = 1;
|
||||
let mut handle = ptr::null_mut();
|
||||
l0_check!{ l0::zeDriverGet(&mut driver_count, &mut handle) };
|
||||
Ok(Driver{ base: handle })
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "stdcall" fn cuDriverGetVersion(version: &mut std::os::raw::c_int) -> cu::Result {
|
||||
*version = 0;
|
||||
*version = i32::max_value();
|
||||
return cu::Result::SUCCESS;
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "stdcall" fn cuInit(_: *const std::os::raw::c_uint) -> cu::Result {
|
||||
return cu::Result::SUCCESS;
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "stdcall" fn cuGetExportTable(_: *const *const std::os::raw::c_void, _: cu::Uuid) -> cu::Result {
|
||||
return cu::Result::ERROR_NOT_SUPPORTED;
|
||||
pub unsafe extern "stdcall" fn cuInit(_: *const std::os::raw::c_uint) -> cu::Result {
|
||||
let l0_init = l0::zeInit(l0::ze_init_flag_t::ZE_INIT_FLAG_GPU_ONLY);
|
||||
if l0_init != l0::ze_result_t::ZE_RESULT_SUCCESS {
|
||||
return cu::Result::from_l0(l0_init);
|
||||
};
|
||||
let mut lock = GLOBAL_STATE.try_lock();
|
||||
if let Ok(ref mut mutex) = lock {
|
||||
if let None = **mutex {
|
||||
match Driver::new() {
|
||||
Ok(state) => **mutex = Some(state),
|
||||
Err(err) => return cu::Result::from_l0(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return cu::Result::ERROR_UNKNOWN;
|
||||
}
|
||||
cu::Result::SUCCESS
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "stdcall" fn cuDeviceGetCount(count: &mut std::os::raw::c_int) -> cu::Result {
|
||||
*count = 1;
|
||||
return cu::Result::SUCCESS;
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "stdcall" fn cuDeviceGet(device: *mut cu::Device, ordinal: ::std::os::raw::c_int) -> cu::Result {
|
||||
unimplemented!()
|
||||
}
|
@ -6,12 +6,11 @@ edition = "2018"
|
||||
|
||||
[[bin]]
|
||||
name = "notcuda"
|
||||
path = "src/bin.rs"
|
||||
|
||||
[dependencies]
|
||||
notcuda_redirect = { path = "../notcuda_redirect" }
|
||||
detours-sys = "0.1"
|
||||
clap = "2.33"
|
||||
path = "src/main.rs"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
notcuda_redirect = { path = "../notcuda_redirect" }
|
||||
winapi = { version = "0.3", features = ["processthreadsapi", "std", "synchapi"] }
|
||||
detours-sys = "0.1"
|
||||
clap = "2.33"
|
||||
guid = "0.1"
|
@ -1,4 +1,6 @@
|
||||
extern crate clap;
|
||||
#[macro_use]
|
||||
extern crate guid;
|
||||
extern crate detours_sys;
|
||||
extern crate winapi;
|
||||
|
||||
@ -60,7 +62,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
ptr::null_mut(),
|
||||
ptr::null_mut(),
|
||||
0,
|
||||
0x10,
|
||||
0,
|
||||
ptr::null_mut(),
|
||||
ptr::null(),
|
||||
&mut startup_info as *mut _,
|
||||
@ -70,10 +72,28 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
let mut exe_path = std::env::current_dir()?
|
||||
.as_os_str()
|
||||
.encode_wide()
|
||||
.collect::<Vec<_>>();
|
||||
let guid = guid! {"C225FC0C-00D7-40B8-935A-7E342A9344C1"};
|
||||
os_call!(
|
||||
detours_sys::DetourCopyPayloadToProcess(
|
||||
proc_info.hProcess,
|
||||
mem::transmute(&guid),
|
||||
exe_path.as_mut_ptr() as *mut _,
|
||||
(exe_path.len() * mem::size_of::<u16>()) as u32
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1);
|
||||
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x != WAIT_FAILED);
|
||||
let mut child_exit_code : u32 = 0;
|
||||
os_call!(GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _), |x| x != 0);
|
||||
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x
|
||||
!= WAIT_FAILED);
|
||||
let mut child_exit_code: u32 = 0;
|
||||
os_call!(
|
||||
GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _),
|
||||
|x| x != 0
|
||||
);
|
||||
std::process::exit(child_exit_code as i32)
|
||||
}
|
||||
|
||||
@ -82,3 +102,5 @@ fn copy_to(from: &OsStr, to: &mut Vec<u16>) {
|
||||
to.push(x);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
|
5
notcuda_inject/src/main.rs
Normal file
5
notcuda_inject/src/main.rs
Normal file
@ -0,0 +1,5 @@
|
||||
#[cfg(target_os = "windows")]
|
||||
mod bin;
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn main() {}
|
@ -7,9 +7,8 @@ edition = "2018"
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
detours-sys = "0.1"
|
||||
wchar = "0.6"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
guid = "0.1"
|
||||
winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "std"] }
|
@ -1,19 +1,26 @@
|
||||
#![cfg(windows)]
|
||||
|
||||
extern crate detours_sys;
|
||||
#[macro_use]
|
||||
extern crate guid;
|
||||
extern crate winapi;
|
||||
|
||||
use std::mem;
|
||||
|
||||
use detours_sys::{
|
||||
DetourAttach, DetourDetach, DetourRestoreAfterWith, DetourTransactionBegin,
|
||||
DetourTransactionCommit, DetourUpdateThread,
|
||||
};
|
||||
use wchar::wch_c;
|
||||
use winapi::shared::minwindef::{DWORD, HMODULE, TRUE};
|
||||
use wchar::{wch, wch_c};
|
||||
use winapi::shared::minwindef::{DWORD, FALSE, HMODULE, TRUE};
|
||||
use winapi::um::libloaderapi::LoadLibraryExW;
|
||||
use winapi::um::processthreadsapi::GetCurrentThread;
|
||||
use winapi::um::winbase::lstrcmpiW;
|
||||
use winapi::um::winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR};
|
||||
|
||||
const NVCUDA_LONG_PATH: &[u16] = wch_c!(r"C:\WINDOWS\system32\nvcuda.dll");
|
||||
const NVCUDA_SHORT_PATH: &[u16] = wch_c!("nvcuda.dll");
|
||||
const NVCUDA_PATH: &[u16] = wch_c!(r"C:\WINDOWS\system32\nvcuda.dll");
|
||||
const NOTCUDA_DLL: &[u16] = wch!(r"nvcuda.dll");
|
||||
static mut NOTCUDA_PATH: Option<Vec<u16>> = None;
|
||||
|
||||
static mut LOAD_LIBRARY_EX: unsafe extern "system" fn(
|
||||
lpLibFileName: LPCWSTR,
|
||||
@ -28,8 +35,8 @@ unsafe extern "system" fn NotCudaLoadLibraryExW(
|
||||
hFile: HANDLE,
|
||||
dwFlags: DWORD,
|
||||
) -> HMODULE {
|
||||
let nvcuda_file_name = if lstrcmpiW(lpLibFileName, NVCUDA_LONG_PATH.as_ptr()) == 0 {
|
||||
NVCUDA_SHORT_PATH.as_ptr()
|
||||
let nvcuda_file_name = if lstrcmpiW(lpLibFileName, NVCUDA_PATH.as_ptr()) == 0 {
|
||||
NOTCUDA_PATH.as_ref().unwrap().as_ptr()
|
||||
} else {
|
||||
lpLibFileName
|
||||
};
|
||||
@ -41,6 +48,10 @@ unsafe extern "system" fn NotCudaLoadLibraryExW(
|
||||
unsafe extern "system" fn DllMain(_: *const u8, dwReason: u32, _: *const u8) -> i32 {
|
||||
if dwReason == DLL_PROCESS_ATTACH {
|
||||
DetourRestoreAfterWith();
|
||||
match get_notcuda_dll_path() {
|
||||
Some((path, len)) => set_notcuda_dll_path(path, len),
|
||||
None => return FALSE,
|
||||
}
|
||||
DetourTransactionBegin();
|
||||
DetourUpdateThread(GetCurrentThread());
|
||||
DetourAttach(
|
||||
@ -59,3 +70,36 @@ unsafe extern "system" fn DllMain(_: *const u8, dwReason: u32, _: *const u8) ->
|
||||
}
|
||||
TRUE
|
||||
}
|
||||
|
||||
fn get_notcuda_dll_path() -> Option<(*const u16, usize)> {
|
||||
let guid = guid! {"C225FC0C-00D7-40B8-935A-7E342A9344C1"};
|
||||
let mut module = std::ptr::null_mut();
|
||||
loop {
|
||||
module = unsafe { detours_sys::DetourEnumerateModules(module) };
|
||||
if module == std::ptr::null_mut() {
|
||||
break;
|
||||
}
|
||||
let mut size = 0;
|
||||
let payload = unsafe {
|
||||
detours_sys::DetourFindPayload(module, std::mem::transmute(&guid), &mut size)
|
||||
};
|
||||
if payload != std::ptr::null_mut() {
|
||||
return Some((payload as *const _, (size as usize) / mem::size_of::<u16>()));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
unsafe fn set_notcuda_dll_path(path: *const u16, len: usize) {
|
||||
let len = len as usize;
|
||||
let mut result = Vec::<u16>::with_capacity(len + NOTCUDA_DLL.len() + 2);
|
||||
for i in 0..len {
|
||||
result.push(*path.add(i));
|
||||
}
|
||||
result.push(0x5c); // \
|
||||
for c in NOTCUDA_DLL.iter().copied() {
|
||||
result.push(c);
|
||||
}
|
||||
result.push(0);
|
||||
NOTCUDA_PATH = Some(result);
|
||||
}
|
||||
|
Reference in New Issue
Block a user