Add test for injecting app that directly uses nvcuda

This commit is contained in:
Andrzej Janik
2021-12-01 23:08:07 +01:00
parent fd1c13560f
commit 400feaf015
5 changed files with 139 additions and 77 deletions

View File

@ -1,14 +1,13 @@
use std::{
ffi::{c_void, CStr},
ffi::{c_void, CStr, CString, OsString},
mem,
os::raw::c_ushort,
ptr,
};
use std::os::windows::io::AsRawHandle;
use wchar::wch_c;
use winapi::{
shared::minwindef::HMODULE,
shared::minwindef::{FARPROC, HMODULE},
um::debugapi::OutputDebugStringA,
um::libloaderapi::{GetProcAddress, LoadLibraryW},
};
@ -17,62 +16,76 @@ use crate::cuda::CUuuid;
pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll";
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0";
lazy_static! {
static ref PLATFORM_LIBRARY: PlatformLibrary = unsafe { PlatformLibrary::new() };
}
include!("../../zluda_redirect/src/payload_guid.rs");
pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
let load_lib = if is_detoured() {
match get_non_detoured_load_library() {
Some(load_lib) => load_lib,
None => return ptr::null_mut(),
#[allow(non_snake_case)]
struct PlatformLibrary {
LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE,
GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC,
}
impl PlatformLibrary {
#[allow(non_snake_case)]
unsafe fn new() -> Self {
let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() {
None => (
LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE,
mem::transmute(
GetProcAddress
as unsafe extern "system" fn(
hModule: HMODULE,
lpProcName: *const i8,
) -> FARPROC,
),
),
Some(zluda_with) => (
mem::transmute(GetProcAddress(
zluda_with,
LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _,
)),
mem::transmute(GetProcAddress(
zluda_with,
GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _,
)),
),
};
PlatformLibrary {
LoadLibraryW,
GetProcAddress,
}
} else {
LoadLibraryW
};
}
unsafe fn get_detourer_module() -> Option<HMODULE> {
let mut module = ptr::null_mut();
loop {
module = detours_sys::DetourEnumerateModules(module);
if module == ptr::null_mut() {
break;
}
let mut size = 0;
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
if payload != ptr::null_mut() {
return Some(module as _);
}
}
None
}
}
pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
let libcuda_path_uf16 = libcuda_path
.encode_utf16()
.chain(std::iter::once(0))
.collect::<Vec<_>>();
load_lib(libcuda_path_uf16.as_ptr()) as *mut _
}
unsafe fn is_detoured() -> bool {
let mut module = ptr::null_mut();
loop {
module = detours_sys::DetourEnumerateModules(module);
if module == ptr::null_mut() {
break;
}
let mut size = 0;
let payload = detours_sys::DetourFindPayload(module, &PAYLOAD_NVCUDA_GUID, &mut size);
if payload != ptr::null_mut() {
return true;
}
}
false
}
unsafe fn get_non_detoured_load_library(
) -> Option<unsafe extern "system" fn(*const c_ushort) -> HMODULE> {
let mut module = ptr::null_mut();
loop {
module = detours_sys::DetourEnumerateModules(module);
if module == ptr::null_mut() {
break;
}
let result = GetProcAddress(
module as *mut _,
LOAD_LIBRARY_NO_REDIRECT.as_ptr() as *mut _,
);
if result != ptr::null_mut() {
return Some(mem::transmute(result));
}
}
None
(PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _
}
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
GetProcAddress(handle as *mut _, func.as_ptr()) as *mut _
(PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _
}
#[macro_export]

View File

@ -11,3 +11,9 @@ path = "src/main.rs"
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchapi", "winbase", "std"] }
detours-sys = { path = "../detours-sys" }
[dev-dependencies]
# dependency for integration tests
zluda_redirect = { path = "../zluda_redirect" }
# dependency for integration tests
zluda_dump = { path = "../zluda_dump" }

View File

@ -3,8 +3,8 @@ use std::{env, io, path::PathBuf, process::Command};
#[test]
fn direct_cuinit() -> io::Result<()> {
let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with"));
let mut zluda_redirect_dll = zluda_with_exe.parent().unwrap().to_path_buf();
zluda_redirect_dll.push("zluda_redirect.dll");
let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf();
zluda_dump_dll.push("zluda_dump.dll");
let helpers_dir = env!("HELPERS_OUT_DIR");
let exe_under_test = format!(
"{}{}direct_cuinit.exe",
@ -12,11 +12,9 @@ fn direct_cuinit() -> io::Result<()> {
std::path::MAIN_SEPARATOR
);
let mut test_cmd = Command::new(&zluda_with_exe);
test_cmd
.arg(&zluda_redirect_dll)
.arg("--")
.arg(&exe_under_test);
let test_cmd = test_cmd.arg(&zluda_dump_dll).arg("--").arg(&exe_under_test);
let test_output = test_cmd.output()?;
assert!(test_output.status.success());
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
assert!(stderr_text.contains("ZLUDA_DUMP"));
Ok(())

View File

@ -4,6 +4,7 @@ extern crate detours_sys;
extern crate winapi;
use std::{
collections::HashMap,
ffi::{c_void, CStr},
mem,
os::raw::c_uint,
@ -61,9 +62,13 @@ static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None;
static mut ZLUDA_ML_PATH_UTF8: Vec<u8> = Vec::new();
static mut ZLUDA_ML_PATH_UTF16: Option<&'static [u16]> = None;
static mut CURRENT_MODULE_FILENAME: Vec<u8> = Vec::new();
static mut DETOUR_DETACH: Option<DetourDetachGuard> = None;
static mut DETOUR_STATE: Option<DetourDetachGuard> = None;
const CUDA_ERROR_NOT_SUPPORTED: c_uint = 801;
#[no_mangle]
#[used]
pub static ZLUDA_REDIRECT: () = ();
static mut LOAD_LIBRARY_A: unsafe extern "system" fn(lpLibFileName: LPCSTR) -> HMODULE =
LoadLibraryA;
@ -148,6 +153,24 @@ static mut CREATE_PROCESS_WITH_LOGON_W: unsafe extern "system" fn(
lpProcessInformation: LPPROCESS_INFORMATION,
) -> BOOL = CreateProcessWithLogonW;
#[no_mangle]
#[allow(non_snake_case)]
unsafe extern "system" fn ZludaGetProcAddress_NoRedirect(
hModule: HMODULE,
lpProcName: LPCSTR,
) -> FARPROC {
if let Some(detour_guard) = &DETOUR_STATE {
if hModule != ptr::null_mut() && detour_guard.nvcuda_module == hModule {
let proc_name = CStr::from_ptr(lpProcName);
return match detour_guard.overriden_cuda_fns.get(proc_name) {
Some((original_fn, _)) => mem::transmute::<*mut c_void, _>(*original_fn),
None => ptr::null_mut(),
};
}
}
GetProcAddress(hModule, lpProcName)
}
#[no_mangle]
#[allow(non_snake_case)]
unsafe extern "system" fn ZludaLoadLibraryW_NoRedirect(lpLibFileName: LPCWSTR) -> HMODULE {
@ -361,7 +384,9 @@ struct DetourDetachGuard {
state: DetourUndoState,
suspended_threads: Vec<*mut c_void>,
// First element is the original fn, second is the new fn
overriden_functions: Vec<(*mut c_void, *mut c_void)>,
overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
nvcuda_module: HMODULE,
overriden_cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
}
impl DetourDetachGuard {
@ -371,12 +396,16 @@ impl DetourDetachGuard {
// also get overriden, so for example ZludaLoadLibraryExW ends calling
// itself recursively until stack overflow exception occurs
unsafe fn detour_functions<'a>(
override_fn_pairs: Vec<(*mut c_void, *mut c_void)>,
nvcuda_module: HMODULE,
non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
) -> Option<Self> {
let mut result = DetourDetachGuard {
state: DetourUndoState::DoNothing,
suspended_threads: Vec::new(),
overriden_functions: override_fn_pairs,
overriden_non_cuda_fns: non_cuda_fns,
nvcuda_module,
overriden_cuda_fns: cuda_fns,
};
if DetourTransactionBegin() != NO_ERROR as i32 {
return None;
@ -390,24 +419,35 @@ impl DetourDetachGuard {
return None;
}
}
result.overriden_functions.extend_from_slice(&[
(CREATE_PROCESS_A as _, ZludaCreateProcessA as _),
(CREATE_PROCESS_W as _, ZludaCreateProcessW as _),
result.overriden_non_cuda_fns.extend_from_slice(&[
(
CREATE_PROCESS_AS_USER_W as _,
&mut CREATE_PROCESS_A as *mut _ as _,
ZludaCreateProcessA as _,
),
(
&mut CREATE_PROCESS_W as *mut _ as _,
ZludaCreateProcessW as _,
),
(
&mut CREATE_PROCESS_AS_USER_W as *mut _ as _,
ZludaCreateProcessAsUserW as _,
),
(
CREATE_PROCESS_WITH_LOGON_W as _,
&mut CREATE_PROCESS_WITH_LOGON_W as *mut _ as _,
ZludaCreateProcessWithLogonW as _,
),
(
CREATE_PROCESS_WITH_TOKEN_W as _,
&mut CREATE_PROCESS_WITH_TOKEN_W as *mut _ as _,
ZludaCreateProcessWithTokenW as _,
),
]);
for (original_fn, new_fn) in result.overriden_functions.iter_mut() {
if DetourAttach(original_fn as *mut _, *new_fn) != NO_ERROR as i32 {
for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied().chain(
result
.overriden_cuda_fns
.values_mut()
.map(|(original_ptr, new_ptr)| (original_ptr as *mut _, *new_ptr)),
) {
if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 {
return None;
}
}
@ -633,13 +673,13 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
};
match detach_guard {
Some(g) => {
DETOUR_DETACH = Some(g);
DETOUR_STATE = Some(g);
TRUE
}
None => FALSE,
}
} else if dwReason == DLL_PROCESS_DETACH {
match DETOUR_DETACH.take() {
match DETOUR_STATE.take() {
Some(_) => TRUE,
None => FALSE,
}
@ -691,9 +731,9 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
}
let original_functions = gather_imports(nvcuda_mod);
let override_functions = gather_imports(zluda_module);
let mut override_fn_pairs = Vec::with_capacity(original_functions.len());
let mut override_fn_pairs = HashMap::with_capacity(original_functions.len());
// TODO: optimize
for (original_fn_name, mut original_fn_address) in original_functions {
for (original_fn_name, original_fn_address) in original_functions {
let override_fn_address =
match override_functions.binary_search_by_key(&original_fn_name, |(name, _)| *name) {
Ok(x) => override_functions[x].1,
@ -702,9 +742,12 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
cuda_unsupported as _
}
};
override_fn_pairs.push((original_fn_address as _, override_fn_address));
override_fn_pairs.insert(
original_fn_name,
(original_fn_address as _, override_fn_address),
);
}
DetourDetachGuard::detour_functions(override_fn_pairs)
DetourDetachGuard::detour_functions(nvcuda_mod, Vec::new(), override_fn_pairs)
}
unsafe extern "system" fn cuda_unsupported() -> c_uint {
@ -735,7 +778,10 @@ unsafe extern "stdcall" fn gather_imports_impl(
#[must_use]
unsafe fn attach_load_libary() -> Option<DetourDetachGuard> {
let detour_functions = vec![
(&mut LOAD_LIBRARY_A as *mut _ as _, ZludaLoadLibraryA as _),
(
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
ZludaLoadLibraryA as *mut c_void,
),
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
(
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
@ -746,9 +792,7 @@ unsafe fn attach_load_libary() -> Option<DetourDetachGuard> {
ZludaLoadLibraryExW as _,
),
];
let result = DetourDetachGuard::detour_functions(detour_functions);
result
DetourDetachGuard::detour_functions(ptr::null_mut(), detour_functions, HashMap::new())
}
fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> {

View File

@ -1,3 +1,4 @@
#[allow(dead_code)]
const PAYLOAD_NVCUDA_GUID: detours_sys::GUID = detours_sys::GUID {
Data1: 0xC225FC0C,
Data2: 0x00D7,