mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-12 10:48:53 +03:00
Add test for injecting app that directly uses nvcuda
This commit is contained in:
@ -1,14 +1,13 @@
|
||||
use std::{
|
||||
ffi::{c_void, CStr},
|
||||
ffi::{c_void, CStr, CString, OsString},
|
||||
mem,
|
||||
os::raw::c_ushort,
|
||||
ptr,
|
||||
};
|
||||
|
||||
use std::os::windows::io::AsRawHandle;
|
||||
use wchar::wch_c;
|
||||
use winapi::{
|
||||
shared::minwindef::HMODULE,
|
||||
shared::minwindef::{FARPROC, HMODULE},
|
||||
um::debugapi::OutputDebugStringA,
|
||||
um::libloaderapi::{GetProcAddress, LoadLibraryW},
|
||||
};
|
||||
@ -17,62 +16,76 @@ use crate::cuda::CUuuid;
|
||||
|
||||
pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll";
|
||||
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
|
||||
|
||||
const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0";
|
||||
lazy_static! {
|
||||
static ref PLATFORM_LIBRARY: PlatformLibrary = unsafe { PlatformLibrary::new() };
|
||||
}
|
||||
include!("../../zluda_redirect/src/payload_guid.rs");
|
||||
|
||||
pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
|
||||
let load_lib = if is_detoured() {
|
||||
match get_non_detoured_load_library() {
|
||||
Some(load_lib) => load_lib,
|
||||
None => return ptr::null_mut(),
|
||||
#[allow(non_snake_case)]
|
||||
struct PlatformLibrary {
|
||||
LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE,
|
||||
GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC,
|
||||
}
|
||||
|
||||
impl PlatformLibrary {
|
||||
#[allow(non_snake_case)]
|
||||
unsafe fn new() -> Self {
|
||||
let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() {
|
||||
None => (
|
||||
LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE,
|
||||
mem::transmute(
|
||||
GetProcAddress
|
||||
as unsafe extern "system" fn(
|
||||
hModule: HMODULE,
|
||||
lpProcName: *const i8,
|
||||
) -> FARPROC,
|
||||
),
|
||||
),
|
||||
Some(zluda_with) => (
|
||||
mem::transmute(GetProcAddress(
|
||||
zluda_with,
|
||||
LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _,
|
||||
)),
|
||||
mem::transmute(GetProcAddress(
|
||||
zluda_with,
|
||||
GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _,
|
||||
)),
|
||||
),
|
||||
};
|
||||
PlatformLibrary {
|
||||
LoadLibraryW,
|
||||
GetProcAddress,
|
||||
}
|
||||
} else {
|
||||
LoadLibraryW
|
||||
};
|
||||
}
|
||||
|
||||
unsafe fn get_detourer_module() -> Option<HMODULE> {
|
||||
let mut module = ptr::null_mut();
|
||||
loop {
|
||||
module = detours_sys::DetourEnumerateModules(module);
|
||||
if module == ptr::null_mut() {
|
||||
break;
|
||||
}
|
||||
let mut size = 0;
|
||||
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
|
||||
if payload != ptr::null_mut() {
|
||||
return Some(module as _);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
|
||||
let libcuda_path_uf16 = libcuda_path
|
||||
.encode_utf16()
|
||||
.chain(std::iter::once(0))
|
||||
.collect::<Vec<_>>();
|
||||
load_lib(libcuda_path_uf16.as_ptr()) as *mut _
|
||||
}
|
||||
|
||||
unsafe fn is_detoured() -> bool {
|
||||
let mut module = ptr::null_mut();
|
||||
loop {
|
||||
module = detours_sys::DetourEnumerateModules(module);
|
||||
if module == ptr::null_mut() {
|
||||
break;
|
||||
}
|
||||
let mut size = 0;
|
||||
let payload = detours_sys::DetourFindPayload(module, &PAYLOAD_NVCUDA_GUID, &mut size);
|
||||
if payload != ptr::null_mut() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
unsafe fn get_non_detoured_load_library(
|
||||
) -> Option<unsafe extern "system" fn(*const c_ushort) -> HMODULE> {
|
||||
let mut module = ptr::null_mut();
|
||||
loop {
|
||||
module = detours_sys::DetourEnumerateModules(module);
|
||||
if module == ptr::null_mut() {
|
||||
break;
|
||||
}
|
||||
let result = GetProcAddress(
|
||||
module as *mut _,
|
||||
LOAD_LIBRARY_NO_REDIRECT.as_ptr() as *mut _,
|
||||
);
|
||||
if result != ptr::null_mut() {
|
||||
return Some(mem::transmute(result));
|
||||
}
|
||||
}
|
||||
None
|
||||
(PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _
|
||||
}
|
||||
|
||||
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
|
||||
GetProcAddress(handle as *mut _, func.as_ptr()) as *mut _
|
||||
(PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
|
@ -11,3 +11,9 @@ path = "src/main.rs"
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchapi", "winbase", "std"] }
|
||||
detours-sys = { path = "../detours-sys" }
|
||||
|
||||
[dev-dependencies]
|
||||
# dependency for integration tests
|
||||
zluda_redirect = { path = "../zluda_redirect" }
|
||||
# dependency for integration tests
|
||||
zluda_dump = { path = "../zluda_dump" }
|
||||
|
@ -3,8 +3,8 @@ use std::{env, io, path::PathBuf, process::Command};
|
||||
#[test]
|
||||
fn direct_cuinit() -> io::Result<()> {
|
||||
let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with"));
|
||||
let mut zluda_redirect_dll = zluda_with_exe.parent().unwrap().to_path_buf();
|
||||
zluda_redirect_dll.push("zluda_redirect.dll");
|
||||
let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf();
|
||||
zluda_dump_dll.push("zluda_dump.dll");
|
||||
let helpers_dir = env!("HELPERS_OUT_DIR");
|
||||
let exe_under_test = format!(
|
||||
"{}{}direct_cuinit.exe",
|
||||
@ -12,11 +12,9 @@ fn direct_cuinit() -> io::Result<()> {
|
||||
std::path::MAIN_SEPARATOR
|
||||
);
|
||||
let mut test_cmd = Command::new(&zluda_with_exe);
|
||||
test_cmd
|
||||
.arg(&zluda_redirect_dll)
|
||||
.arg("--")
|
||||
.arg(&exe_under_test);
|
||||
let test_cmd = test_cmd.arg(&zluda_dump_dll).arg("--").arg(&exe_under_test);
|
||||
let test_output = test_cmd.output()?;
|
||||
assert!(test_output.status.success());
|
||||
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
|
||||
assert!(stderr_text.contains("ZLUDA_DUMP"));
|
||||
Ok(())
|
||||
|
@ -4,6 +4,7 @@ extern crate detours_sys;
|
||||
extern crate winapi;
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
ffi::{c_void, CStr},
|
||||
mem,
|
||||
os::raw::c_uint,
|
||||
@ -61,9 +62,13 @@ static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None;
|
||||
static mut ZLUDA_ML_PATH_UTF8: Vec<u8> = Vec::new();
|
||||
static mut ZLUDA_ML_PATH_UTF16: Option<&'static [u16]> = None;
|
||||
static mut CURRENT_MODULE_FILENAME: Vec<u8> = Vec::new();
|
||||
static mut DETOUR_DETACH: Option<DetourDetachGuard> = None;
|
||||
static mut DETOUR_STATE: Option<DetourDetachGuard> = None;
|
||||
const CUDA_ERROR_NOT_SUPPORTED: c_uint = 801;
|
||||
|
||||
#[no_mangle]
|
||||
#[used]
|
||||
pub static ZLUDA_REDIRECT: () = ();
|
||||
|
||||
static mut LOAD_LIBRARY_A: unsafe extern "system" fn(lpLibFileName: LPCSTR) -> HMODULE =
|
||||
LoadLibraryA;
|
||||
|
||||
@ -148,6 +153,24 @@ static mut CREATE_PROCESS_WITH_LOGON_W: unsafe extern "system" fn(
|
||||
lpProcessInformation: LPPROCESS_INFORMATION,
|
||||
) -> BOOL = CreateProcessWithLogonW;
|
||||
|
||||
#[no_mangle]
|
||||
#[allow(non_snake_case)]
|
||||
unsafe extern "system" fn ZludaGetProcAddress_NoRedirect(
|
||||
hModule: HMODULE,
|
||||
lpProcName: LPCSTR,
|
||||
) -> FARPROC {
|
||||
if let Some(detour_guard) = &DETOUR_STATE {
|
||||
if hModule != ptr::null_mut() && detour_guard.nvcuda_module == hModule {
|
||||
let proc_name = CStr::from_ptr(lpProcName);
|
||||
return match detour_guard.overriden_cuda_fns.get(proc_name) {
|
||||
Some((original_fn, _)) => mem::transmute::<*mut c_void, _>(*original_fn),
|
||||
None => ptr::null_mut(),
|
||||
};
|
||||
}
|
||||
}
|
||||
GetProcAddress(hModule, lpProcName)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
#[allow(non_snake_case)]
|
||||
unsafe extern "system" fn ZludaLoadLibraryW_NoRedirect(lpLibFileName: LPCWSTR) -> HMODULE {
|
||||
@ -361,7 +384,9 @@ struct DetourDetachGuard {
|
||||
state: DetourUndoState,
|
||||
suspended_threads: Vec<*mut c_void>,
|
||||
// First element is the original fn, second is the new fn
|
||||
overriden_functions: Vec<(*mut c_void, *mut c_void)>,
|
||||
overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
|
||||
nvcuda_module: HMODULE,
|
||||
overriden_cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
|
||||
}
|
||||
|
||||
impl DetourDetachGuard {
|
||||
@ -371,12 +396,16 @@ impl DetourDetachGuard {
|
||||
// also get overriden, so for example ZludaLoadLibraryExW ends calling
|
||||
// itself recursively until stack overflow exception occurs
|
||||
unsafe fn detour_functions<'a>(
|
||||
override_fn_pairs: Vec<(*mut c_void, *mut c_void)>,
|
||||
nvcuda_module: HMODULE,
|
||||
non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
|
||||
cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
|
||||
) -> Option<Self> {
|
||||
let mut result = DetourDetachGuard {
|
||||
state: DetourUndoState::DoNothing,
|
||||
suspended_threads: Vec::new(),
|
||||
overriden_functions: override_fn_pairs,
|
||||
overriden_non_cuda_fns: non_cuda_fns,
|
||||
nvcuda_module,
|
||||
overriden_cuda_fns: cuda_fns,
|
||||
};
|
||||
if DetourTransactionBegin() != NO_ERROR as i32 {
|
||||
return None;
|
||||
@ -390,24 +419,35 @@ impl DetourDetachGuard {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
result.overriden_functions.extend_from_slice(&[
|
||||
(CREATE_PROCESS_A as _, ZludaCreateProcessA as _),
|
||||
(CREATE_PROCESS_W as _, ZludaCreateProcessW as _),
|
||||
result.overriden_non_cuda_fns.extend_from_slice(&[
|
||||
(
|
||||
CREATE_PROCESS_AS_USER_W as _,
|
||||
&mut CREATE_PROCESS_A as *mut _ as _,
|
||||
ZludaCreateProcessA as _,
|
||||
),
|
||||
(
|
||||
&mut CREATE_PROCESS_W as *mut _ as _,
|
||||
ZludaCreateProcessW as _,
|
||||
),
|
||||
(
|
||||
&mut CREATE_PROCESS_AS_USER_W as *mut _ as _,
|
||||
ZludaCreateProcessAsUserW as _,
|
||||
),
|
||||
(
|
||||
CREATE_PROCESS_WITH_LOGON_W as _,
|
||||
&mut CREATE_PROCESS_WITH_LOGON_W as *mut _ as _,
|
||||
ZludaCreateProcessWithLogonW as _,
|
||||
),
|
||||
(
|
||||
CREATE_PROCESS_WITH_TOKEN_W as _,
|
||||
&mut CREATE_PROCESS_WITH_TOKEN_W as *mut _ as _,
|
||||
ZludaCreateProcessWithTokenW as _,
|
||||
),
|
||||
]);
|
||||
for (original_fn, new_fn) in result.overriden_functions.iter_mut() {
|
||||
if DetourAttach(original_fn as *mut _, *new_fn) != NO_ERROR as i32 {
|
||||
for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied().chain(
|
||||
result
|
||||
.overriden_cuda_fns
|
||||
.values_mut()
|
||||
.map(|(original_ptr, new_ptr)| (original_ptr as *mut _, *new_ptr)),
|
||||
) {
|
||||
if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
@ -633,13 +673,13 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
|
||||
};
|
||||
match detach_guard {
|
||||
Some(g) => {
|
||||
DETOUR_DETACH = Some(g);
|
||||
DETOUR_STATE = Some(g);
|
||||
TRUE
|
||||
}
|
||||
None => FALSE,
|
||||
}
|
||||
} else if dwReason == DLL_PROCESS_DETACH {
|
||||
match DETOUR_DETACH.take() {
|
||||
match DETOUR_STATE.take() {
|
||||
Some(_) => TRUE,
|
||||
None => FALSE,
|
||||
}
|
||||
@ -691,9 +731,9 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
|
||||
}
|
||||
let original_functions = gather_imports(nvcuda_mod);
|
||||
let override_functions = gather_imports(zluda_module);
|
||||
let mut override_fn_pairs = Vec::with_capacity(original_functions.len());
|
||||
let mut override_fn_pairs = HashMap::with_capacity(original_functions.len());
|
||||
// TODO: optimize
|
||||
for (original_fn_name, mut original_fn_address) in original_functions {
|
||||
for (original_fn_name, original_fn_address) in original_functions {
|
||||
let override_fn_address =
|
||||
match override_functions.binary_search_by_key(&original_fn_name, |(name, _)| *name) {
|
||||
Ok(x) => override_functions[x].1,
|
||||
@ -702,9 +742,12 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
|
||||
cuda_unsupported as _
|
||||
}
|
||||
};
|
||||
override_fn_pairs.push((original_fn_address as _, override_fn_address));
|
||||
override_fn_pairs.insert(
|
||||
original_fn_name,
|
||||
(original_fn_address as _, override_fn_address),
|
||||
);
|
||||
}
|
||||
DetourDetachGuard::detour_functions(override_fn_pairs)
|
||||
DetourDetachGuard::detour_functions(nvcuda_mod, Vec::new(), override_fn_pairs)
|
||||
}
|
||||
|
||||
unsafe extern "system" fn cuda_unsupported() -> c_uint {
|
||||
@ -735,7 +778,10 @@ unsafe extern "stdcall" fn gather_imports_impl(
|
||||
#[must_use]
|
||||
unsafe fn attach_load_libary() -> Option<DetourDetachGuard> {
|
||||
let detour_functions = vec![
|
||||
(&mut LOAD_LIBRARY_A as *mut _ as _, ZludaLoadLibraryA as _),
|
||||
(
|
||||
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
|
||||
ZludaLoadLibraryA as *mut c_void,
|
||||
),
|
||||
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
|
||||
(
|
||||
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
|
||||
@ -746,9 +792,7 @@ unsafe fn attach_load_libary() -> Option<DetourDetachGuard> {
|
||||
ZludaLoadLibraryExW as _,
|
||||
),
|
||||
];
|
||||
let result = DetourDetachGuard::detour_functions(detour_functions);
|
||||
|
||||
result
|
||||
DetourDetachGuard::detour_functions(ptr::null_mut(), detour_functions, HashMap::new())
|
||||
}
|
||||
|
||||
fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> {
|
||||
|
@ -1,3 +1,4 @@
|
||||
#[allow(dead_code)]
|
||||
const PAYLOAD_NVCUDA_GUID: detours_sys::GUID = detours_sys::GUID {
|
||||
Data1: 0xC225FC0C,
|
||||
Data2: 0x00D7,
|
||||
|
Reference in New Issue
Block a user