mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:19:20 +03:00
Inject our own NVML
This commit is contained in:
@ -36,7 +36,7 @@ unsafe fn is_detoured() -> bool {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
let mut size = 0;
|
let mut size = 0;
|
||||||
let payload = detours_sys::DetourFindPayload(module, &PAYLOAD_GUID, &mut size);
|
let payload = detours_sys::DetourFindPayload(module, &PAYLOAD_NVCUDA_GUID, &mut size);
|
||||||
if payload != ptr::null_mut() {
|
if payload != ptr::null_mut() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
use std::mem;
|
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::ptr;
|
use std::ptr;
|
||||||
use std::{env, ops::Deref};
|
use std::{env, ops::Deref};
|
||||||
use std::{error::Error, process};
|
use std::{error::Error, process};
|
||||||
|
use std::{mem, path::PathBuf};
|
||||||
|
|
||||||
use mem::size_of_val;
|
use mem::size_of_val;
|
||||||
use winapi::um::{
|
use winapi::um::{
|
||||||
@ -20,6 +20,7 @@ use winapi::um::winbase::{INFINITE, WAIT_FAILED};
|
|||||||
|
|
||||||
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
|
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
|
||||||
static ZLUDA_DLL: &'static str = "nvcuda.dll";
|
static ZLUDA_DLL: &'static str = "nvcuda.dll";
|
||||||
|
static ZLUDA_ML_DLL: &'static str = "nvml.dll";
|
||||||
|
|
||||||
include!("../../zluda_redirect/src/payload_guid.rs");
|
include!("../../zluda_redirect/src/payload_guid.rs");
|
||||||
|
|
||||||
@ -31,7 +32,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||||||
let injector_path = env::current_exe()?;
|
let injector_path = env::current_exe()?;
|
||||||
let injector_dir = injector_path.parent().unwrap();
|
let injector_dir = injector_path.parent().unwrap();
|
||||||
let redirect_path = create_redirect_path(injector_dir);
|
let redirect_path = create_redirect_path(injector_dir);
|
||||||
let (mut inject_path, cmd) = create_inject_path(&args[1..], injector_dir);
|
let (mut inject_nvcuda_path, mut inject_nvml_path, cmd) =
|
||||||
|
create_inject_path(&args[1..], injector_dir);
|
||||||
let mut cmd_line = construct_command_line(cmd);
|
let mut cmd_line = construct_command_line(cmd);
|
||||||
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
|
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
|
||||||
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
|
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
|
||||||
@ -56,9 +58,18 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||||||
os_call!(
|
os_call!(
|
||||||
detours_sys::DetourCopyPayloadToProcess(
|
detours_sys::DetourCopyPayloadToProcess(
|
||||||
proc_info.hProcess,
|
proc_info.hProcess,
|
||||||
&PAYLOAD_GUID,
|
&PAYLOAD_NVCUDA_GUID,
|
||||||
inject_path.as_mut_ptr() as *mut _,
|
inject_nvcuda_path.as_mut_ptr() as *mut _,
|
||||||
(inject_path.len() * mem::size_of::<u16>()) as u32
|
(inject_nvcuda_path.len() * mem::size_of::<u16>()) as u32
|
||||||
|
),
|
||||||
|
|x| x != 0
|
||||||
|
);
|
||||||
|
os_call!(
|
||||||
|
detours_sys::DetourCopyPayloadToProcess(
|
||||||
|
proc_info.hProcess,
|
||||||
|
&PAYLOAD_NVML_GUID,
|
||||||
|
inject_nvml_path.as_mut_ptr() as *mut _,
|
||||||
|
(inject_nvml_path.len() * mem::size_of::<u16>()) as u32
|
||||||
),
|
),
|
||||||
|x| x != 0
|
|x| x != 0
|
||||||
);
|
);
|
||||||
@ -173,22 +184,34 @@ fn create_redirect_path(injector_dir: &Path) -> Vec<u8> {
|
|||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_inject_path<'a>(args: &'a [String], injector_dir: &Path) -> (Vec<u16>, &'a [String]) {
|
fn create_inject_path<'a>(
|
||||||
if args.get(0).map(Deref::deref) == Some("--") {
|
args: &'a [String],
|
||||||
let mut injector_dir = injector_dir.to_path_buf();
|
injector_dir: &Path,
|
||||||
injector_dir.push(ZLUDA_DLL);
|
) -> (Vec<u16>, Vec<u16>, &'a [String]) {
|
||||||
let mut result = injector_dir
|
let injector_dir = injector_dir.to_path_buf();
|
||||||
.to_string_lossy()
|
let (nvcuda_path, unparsed_args) = if args.get(0).map(Deref::deref) == Some("--") {
|
||||||
.as_ref()
|
(
|
||||||
.encode_utf16()
|
encode_file_in_directory_raw(injector_dir.clone(), ZLUDA_DLL),
|
||||||
.collect::<Vec<_>>();
|
&args[1..],
|
||||||
result.push(0);
|
)
|
||||||
(result, &args[1..])
|
|
||||||
} else if args.get(1).map(Deref::deref) == Some("--") {
|
} else if args.get(1).map(Deref::deref) == Some("--") {
|
||||||
let mut dll_path = args[0].encode_utf16().collect::<Vec<_>>();
|
let mut dll_path = args[0].encode_utf16().collect::<Vec<_>>();
|
||||||
dll_path.push(0);
|
dll_path.push(0);
|
||||||
(dll_path, &args[2..])
|
(dll_path, &args[2..])
|
||||||
} else {
|
} else {
|
||||||
print_help_and_exit()
|
print_help_and_exit()
|
||||||
}
|
};
|
||||||
|
let nvml_path = encode_file_in_directory_raw(injector_dir, ZLUDA_ML_DLL);
|
||||||
|
(nvcuda_path, nvml_path, unparsed_args)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode_file_in_directory_raw(mut dir: PathBuf, file: &'static str) -> Vec<u16> {
|
||||||
|
dir.push(file);
|
||||||
|
let mut result = dir
|
||||||
|
.to_string_lossy()
|
||||||
|
.as_ref()
|
||||||
|
.encode_utf16()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
result.push(0);
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
bindgen "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\include\nvml.h" --whitelist-function="^nvml.*" --size_t-is-usize --default-enum-style=newtype --no-layout-tests --no-doc-comments --no-derive-debug -o src/nvml.rs
|
bindgen "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\include\nvml.h" --whitelist-function="^nvml.*" --size_t-is-usize --default-enum-style=newtype --no-layout-tests --no-doc-comments --no-derive-debug -o src/nvml.rs
|
||||||
sed -i -e 's/extern "C" {//g' -e 's/-> nvmlReturn_t;/-> nvmlReturn_t { impl_::unsupported()/g' -e 's/pub fn /#[no_mangle] pub extern "C" fn /g' src/nvml.rs
|
sed -i -e 's/extern "C" {//g' -e 's/-> nvmlReturn_t;/-> nvmlReturn_t { crate::r#impl::unimplemented()/g' -e 's/pub fn /#[no_mangle] pub extern "C" fn /g' src/nvml.rs
|
||||||
rustfmt src/nvml.rs
|
rustfmt src/nvml.rs
|
@ -1,5 +1,5 @@
|
|||||||
use level_zero as l0;
|
use level_zero as l0;
|
||||||
use std::{io::Write, ops::Add};
|
use std::io::Write;
|
||||||
use std::{
|
use std::{
|
||||||
os::raw::{c_char, c_uint},
|
os::raw::{c_char, c_uint},
|
||||||
ptr,
|
ptr,
|
||||||
|
@ -55,8 +55,12 @@ include!("payload_guid.rs");
|
|||||||
|
|
||||||
const NVCUDA_UTF8: &'static str = "NVCUDA.DLL";
|
const NVCUDA_UTF8: &'static str = "NVCUDA.DLL";
|
||||||
const NVCUDA_UTF16: &[u16] = wch!("NVCUDA.DLL");
|
const NVCUDA_UTF16: &[u16] = wch!("NVCUDA.DLL");
|
||||||
|
const NVML_UTF8: &'static str = "NVML.DLL";
|
||||||
|
const NVML_UTF16: &[u16] = wch!("NVML.DLL");
|
||||||
static mut ZLUDA_PATH_UTF8: Vec<u8> = Vec::new();
|
static mut ZLUDA_PATH_UTF8: Vec<u8> = Vec::new();
|
||||||
static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None;
|
static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None;
|
||||||
|
static mut ZLUDA_ML_PATH_UTF8: Vec<u8> = Vec::new();
|
||||||
|
static mut ZLUDA_ML_PATH_UTF16: Option<&'static [u16]> = None;
|
||||||
static mut DETACH_LOAD_LIBRARY: bool = false;
|
static mut DETACH_LOAD_LIBRARY: bool = false;
|
||||||
static mut NVCUDA_ORIGINAL_MODULE: HMODULE = ptr::null_mut();
|
static mut NVCUDA_ORIGINAL_MODULE: HMODULE = ptr::null_mut();
|
||||||
static mut CUINIT_ORIGINAL_FN: FARPROC = ptr::null_mut();
|
static mut CUINIT_ORIGINAL_FN: FARPROC = ptr::null_mut();
|
||||||
@ -158,6 +162,8 @@ unsafe extern "system" fn ZludaLoadLibraryW_NoRedirect(lpLibFileName: LPCWSTR) -
|
|||||||
unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE {
|
unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE {
|
||||||
let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
|
let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
|
||||||
ZLUDA_PATH_UTF8.as_ptr() as *const _
|
ZLUDA_PATH_UTF8.as_ptr() as *const _
|
||||||
|
} else if is_nvml_dll_utf8(lpLibFileName as *const _) {
|
||||||
|
ZLUDA_ML_PATH_UTF8.as_ptr() as *const _
|
||||||
} else {
|
} else {
|
||||||
lpLibFileName
|
lpLibFileName
|
||||||
};
|
};
|
||||||
@ -168,6 +174,8 @@ unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE {
|
|||||||
unsafe extern "system" fn ZludaLoadLibraryW(lpLibFileName: LPCWSTR) -> HMODULE {
|
unsafe extern "system" fn ZludaLoadLibraryW(lpLibFileName: LPCWSTR) -> HMODULE {
|
||||||
let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
|
let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
|
||||||
ZLUDA_PATH_UTF16.unwrap().as_ptr()
|
ZLUDA_PATH_UTF16.unwrap().as_ptr()
|
||||||
|
} else if is_nvml_dll_utf16(lpLibFileName as *const _) {
|
||||||
|
ZLUDA_ML_PATH_UTF16.unwrap().as_ptr()
|
||||||
} else {
|
} else {
|
||||||
lpLibFileName
|
lpLibFileName
|
||||||
};
|
};
|
||||||
@ -182,6 +190,8 @@ unsafe extern "system" fn ZludaLoadLibraryExA(
|
|||||||
) -> HMODULE {
|
) -> HMODULE {
|
||||||
let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
|
let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
|
||||||
ZLUDA_PATH_UTF8.as_ptr() as *const _
|
ZLUDA_PATH_UTF8.as_ptr() as *const _
|
||||||
|
} else if is_nvml_dll_utf8(lpLibFileName as *const _) {
|
||||||
|
ZLUDA_ML_PATH_UTF8.as_ptr() as *const _
|
||||||
} else {
|
} else {
|
||||||
lpLibFileName
|
lpLibFileName
|
||||||
};
|
};
|
||||||
@ -196,6 +206,8 @@ unsafe extern "system" fn ZludaLoadLibraryExW(
|
|||||||
) -> HMODULE {
|
) -> HMODULE {
|
||||||
let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
|
let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
|
||||||
ZLUDA_PATH_UTF16.unwrap().as_ptr()
|
ZLUDA_PATH_UTF16.unwrap().as_ptr()
|
||||||
|
} else if is_nvml_dll_utf16(lpLibFileName as *const _) {
|
||||||
|
ZLUDA_ML_PATH_UTF16.unwrap().as_ptr()
|
||||||
} else {
|
} else {
|
||||||
lpLibFileName
|
lpLibFileName
|
||||||
};
|
};
|
||||||
@ -363,7 +375,7 @@ unsafe fn continue_create_process_hook(
|
|||||||
}
|
}
|
||||||
if detours_sys::DetourCopyPayloadToProcess(
|
if detours_sys::DetourCopyPayloadToProcess(
|
||||||
(*process_information).hProcess,
|
(*process_information).hProcess,
|
||||||
&PAYLOAD_GUID,
|
&PAYLOAD_NVCUDA_GUID,
|
||||||
ZLUDA_PATH_UTF16.unwrap().as_ptr() as *mut _,
|
ZLUDA_PATH_UTF16.unwrap().as_ptr() as *mut _,
|
||||||
(ZLUDA_PATH_UTF16.unwrap().len() * mem::size_of::<u16>()) as u32,
|
(ZLUDA_PATH_UTF16.unwrap().len() * mem::size_of::<u16>()) as u32,
|
||||||
) == FALSE
|
) == FALSE
|
||||||
@ -372,6 +384,16 @@ unsafe fn continue_create_process_hook(
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if detours_sys::DetourCopyPayloadToProcess(
|
||||||
|
(*process_information).hProcess,
|
||||||
|
&PAYLOAD_NVML_GUID,
|
||||||
|
ZLUDA_ML_PATH_UTF16.unwrap().as_ptr() as *mut _,
|
||||||
|
(ZLUDA_ML_PATH_UTF16.unwrap().len() * mem::size_of::<u16>()) as u32,
|
||||||
|
) == FALSE
|
||||||
|
{
|
||||||
|
TerminateProcess((*process_information).hProcess, 1);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
if creation_flags & CREATE_SUSPENDED == 0 {
|
if creation_flags & CREATE_SUSPENDED == 0 {
|
||||||
if ResumeThread((*process_information).hThread) == -1i32 as u32 {
|
if ResumeThread((*process_information).hThread) == -1i32 as u32 {
|
||||||
TerminateProcess((*process_information).hProcess, 1);
|
TerminateProcess((*process_information).hProcess, 1);
|
||||||
@ -490,7 +512,23 @@ unsafe extern "C" fn unsupported_cuda_fn() -> c_uint {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn is_nvcuda_dll_utf8(lib: *const u8) -> bool {
|
fn is_nvcuda_dll_utf8(lib: *const u8) -> bool {
|
||||||
is_nvcuda_dll(lib, 0, NVCUDA_UTF8.as_bytes(), |c| {
|
is_dll_utf8(lib, NVCUDA_UTF8.as_bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_nvcuda_dll_utf16(lib: *const u16) -> bool {
|
||||||
|
is_dll_utf16(lib, NVCUDA_UTF16)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_nvml_dll_utf8(lib: *const u8) -> bool {
|
||||||
|
is_dll_utf8(lib, NVML_UTF8.as_bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_nvml_dll_utf16(lib: *const u16) -> bool {
|
||||||
|
is_dll_utf16(lib, NVML_UTF16)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_dll_utf8(lib: *const u8, name: &[u8]) -> bool {
|
||||||
|
is_dll_impl(lib, 0, name, |c| {
|
||||||
if c >= 'a' as u8 && c <= 'z' as u8 {
|
if c >= 'a' as u8 && c <= 'z' as u8 {
|
||||||
c - 32
|
c - 32
|
||||||
} else {
|
} else {
|
||||||
@ -498,8 +536,9 @@ fn is_nvcuda_dll_utf8(lib: *const u8) -> bool {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
fn is_nvcuda_dll_utf16(lib: *const u16) -> bool {
|
|
||||||
is_nvcuda_dll(lib, 0u16, NVCUDA_UTF16, |c| {
|
fn is_dll_utf16(lib: *const u16, name: &[u16]) -> bool {
|
||||||
|
is_dll_impl(lib, 0u16, name, |c| {
|
||||||
if c >= 'a' as u16 && c <= 'z' as u16 {
|
if c >= 'a' as u16 && c <= 'z' as u16 {
|
||||||
c - 32
|
c - 32
|
||||||
} else {
|
} else {
|
||||||
@ -508,7 +547,7 @@ fn is_nvcuda_dll_utf16(lib: *const u16) -> bool {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_nvcuda_dll<T: Copy + PartialEq>(
|
fn is_dll_impl<T: Copy + PartialEq>(
|
||||||
lib: *const T,
|
lib: *const T,
|
||||||
zero: T,
|
zero: T,
|
||||||
dll_name: &[T],
|
dll_name: &[T],
|
||||||
@ -544,11 +583,13 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
|
|||||||
if !initialize_current_module_name(instDLL) {
|
if !initialize_current_module_name(instDLL) {
|
||||||
return FALSE;
|
return FALSE;
|
||||||
}
|
}
|
||||||
match get_zluda_dll_path() {
|
match get_zluda_dlls_paths() {
|
||||||
Some(path) => {
|
Some((nvcuda_path, nvml_path)) => {
|
||||||
ZLUDA_PATH_UTF16 = Some(path);
|
ZLUDA_PATH_UTF16 = Some(nvcuda_path);
|
||||||
|
ZLUDA_ML_PATH_UTF16 = Some(nvml_path);
|
||||||
// from_utf16_lossy(...) handles terminating NULL correctly
|
// from_utf16_lossy(...) handles terminating NULL correctly
|
||||||
ZLUDA_PATH_UTF8 = String::from_utf16_lossy(path).into_bytes();
|
ZLUDA_PATH_UTF8 = String::from_utf16_lossy(nvcuda_path).into_bytes();
|
||||||
|
ZLUDA_ML_PATH_UTF8 = String::from_utf16_lossy(nvml_path).into_bytes();
|
||||||
}
|
}
|
||||||
None => return FALSE,
|
None => return FALSE,
|
||||||
}
|
}
|
||||||
@ -740,25 +781,34 @@ unsafe fn detach_load_library() -> i32 {
|
|||||||
TRUE
|
TRUE
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_zluda_dll_path() -> Option<&'static [u16]> {
|
fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> {
|
||||||
|
match get_payload(&PAYLOAD_NVCUDA_GUID) {
|
||||||
|
None => None,
|
||||||
|
Some(nvcuda_payload) => match get_payload(&PAYLOAD_NVML_GUID) {
|
||||||
|
None => return None,
|
||||||
|
Some(nvml_payload) => return Some((nvcuda_payload, nvml_payload)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_payload(guid: &detours_sys::GUID) -> Option<&'static [u16]> {
|
||||||
let mut module = ptr::null_mut();
|
let mut module = ptr::null_mut();
|
||||||
loop {
|
loop {
|
||||||
module = unsafe { detours_sys::DetourEnumerateModules(module) };
|
module = unsafe { detours_sys::DetourEnumerateModules(module) };
|
||||||
if module == ptr::null_mut() {
|
if module == ptr::null_mut() {
|
||||||
break;
|
return None;
|
||||||
}
|
}
|
||||||
let mut size = 0;
|
let mut size = 0;
|
||||||
let payload = unsafe { detours_sys::DetourFindPayload(module, &PAYLOAD_GUID, &mut size) };
|
let payload_ptr = unsafe { detours_sys::DetourFindPayload(module, guid, &mut size) };
|
||||||
if payload != ptr::null_mut() {
|
if payload_ptr != ptr::null_mut() {
|
||||||
return unsafe {
|
return Some(unsafe {
|
||||||
Some(slice::from_raw_parts(
|
slice::from_raw_parts(
|
||||||
payload as *const _,
|
payload_ptr as *const _,
|
||||||
(size as usize) / mem::size_of::<u16>(),
|
(size as usize) / mem::size_of::<u16>(),
|
||||||
))
|
)
|
||||||
};
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
|
@ -1,6 +1,14 @@
|
|||||||
const PAYLOAD_GUID: detours_sys::GUID = detours_sys::GUID {
|
const PAYLOAD_NVCUDA_GUID: detours_sys::GUID = detours_sys::GUID {
|
||||||
Data1: 0xC225FC0C,
|
Data1: 0xC225FC0C,
|
||||||
Data2: 0x00D7,
|
Data2: 0x00D7,
|
||||||
Data3: 0x40B8,
|
Data3: 0x40B8,
|
||||||
Data4: [0x93, 0x5A, 0x7E, 0x34, 0x2A, 0x93, 0x44, 0xC1],
|
Data4: [0x93, 0x5A, 0x7E, 0x34, 0x2A, 0x93, 0x44, 0xC1],
|
||||||
|
};
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
const PAYLOAD_NVML_GUID: detours_sys::GUID = detours_sys::GUID {
|
||||||
|
Data1: 0x75B54759,
|
||||||
|
Data2: 0xB6F1,
|
||||||
|
Data3: 0x49C2,
|
||||||
|
Data4: [0xA2, 0x09, 0x68, 0x54, 0x96, 0xBD, 0x70, 0xC0],
|
||||||
};
|
};
|
Reference in New Issue
Block a user