mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-21 00:48:49 +03:00
Overhaul DLL injection
This commit is contained in:
@ -18,7 +18,6 @@ const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedi
|
|||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref PLATFORM_LIBRARY: PlatformLibrary = unsafe { PlatformLibrary::new() };
|
static ref PLATFORM_LIBRARY: PlatformLibrary = unsafe { PlatformLibrary::new() };
|
||||||
}
|
}
|
||||||
include!("../../zluda_redirect/src/payload_guid.rs");
|
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
struct PlatformLibrary {
|
struct PlatformLibrary {
|
||||||
|
@ -10,6 +10,8 @@ path = "src/main.rs"
|
|||||||
|
|
||||||
[target.'cfg(windows)'.dependencies]
|
[target.'cfg(windows)'.dependencies]
|
||||||
winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchapi", "winbase", "std"] }
|
winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchapi", "winbase", "std"] }
|
||||||
|
tempfile = "3"
|
||||||
|
argh = "0.1"
|
||||||
detours-sys = { path = "../detours-sys" }
|
detours-sys = { path = "../detours-sys" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
@ -43,6 +43,8 @@ fn main() -> Result<(), VarError> {
|
|||||||
.arg("-ldylib=nvcuda")
|
.arg("-ldylib=nvcuda")
|
||||||
.arg("-C")
|
.arg("-C")
|
||||||
.arg(format!("opt-level={}", opt_level))
|
.arg(format!("opt-level={}", opt_level))
|
||||||
|
.arg("-L")
|
||||||
|
.arg(format!("{}", out_dir))
|
||||||
.arg("--out-dir")
|
.arg("--out-dir")
|
||||||
.arg(format!("{}", out_dir))
|
.arg(format!("{}", out_dir))
|
||||||
.arg("--target")
|
.arg("--target")
|
||||||
@ -52,11 +54,11 @@ fn main() -> Result<(), VarError> {
|
|||||||
}
|
}
|
||||||
std::fs::copy(
|
std::fs::copy(
|
||||||
format!(
|
format!(
|
||||||
"{}{}do_cuinit_main_clr.exe",
|
"{}{}do_cuinit_late_clr.exe",
|
||||||
helpers_dir_as_string,
|
helpers_dir_as_string,
|
||||||
path::MAIN_SEPARATOR
|
path::MAIN_SEPARATOR
|
||||||
),
|
),
|
||||||
format!("{}{}do_cuinit_main_clr.exe", out_dir, path::MAIN_SEPARATOR),
|
format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir);
|
println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir);
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
|
use std::env;
|
||||||
|
use std::os::windows;
|
||||||
use std::os::windows::ffi::OsStrExt;
|
use std::os::windows::ffi::OsStrExt;
|
||||||
use std::path::Path;
|
|
||||||
use std::ptr;
|
|
||||||
use std::{env, ops::Deref};
|
|
||||||
use std::{error::Error, process};
|
use std::{error::Error, process};
|
||||||
|
use std::{fs, io, ptr};
|
||||||
use std::{mem, path::PathBuf};
|
use std::{mem, path::PathBuf};
|
||||||
|
|
||||||
|
use argh::FromArgs;
|
||||||
use mem::size_of_val;
|
use mem::size_of_val;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
use winapi::um::processenv::SearchPathW;
|
||||||
use winapi::um::{
|
use winapi::um::{
|
||||||
jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
|
jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
|
||||||
processthreadsapi::{GetExitCodeProcess, ResumeThread},
|
processthreadsapi::{GetExitCodeProcess, ResumeThread},
|
||||||
@ -20,28 +23,46 @@ use winapi::um::{
|
|||||||
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
|
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
|
||||||
|
|
||||||
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
|
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
|
||||||
static ZLUDA_DLL: &'static str = "nvcuda.dll";
|
static NVCUDA_DLL: &'static str = "nvcuda.dll";
|
||||||
static ZLUDA_ML_DLL: &'static str = "nvml.dll";
|
static NVML_DLL: &'static str = "nvml.dll";
|
||||||
|
|
||||||
include!("../../zluda_redirect/src/payload_guid.rs");
|
include!("../../zluda_redirect/src/payload_guid.rs");
|
||||||
|
|
||||||
|
// Raw command-line arguments of the injector, parsed in `main_impl` with
// `argh::from_env::<ProgramArguments>()`.
// NOTE(review): the `///` lines below look like they double as the help text
// produced by the `FromArgs` derive, so treat them as runtime strings; `{0}`
// presumably expands to the command name - confirm against the argh version in use.
#[derive(FromArgs)]
|
||||||
|
/// Launch application with custom CUDA libraries
|
||||||
|
struct ProgramArguments {
|
||||||
|
// Optional replacement for nvcuda.dll; when absent, `NormalizedArguments`
// falls back to nvcuda.dll in the directory of the injector executable.
/// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory
|
||||||
|
#[argh(option)]
|
||||||
|
nvcuda: Option<PathBuf>,
|
||||||
|
|
||||||||
|
// Optional replacement for nvml.dll; same fallback rule as `nvcuda`.
/// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory
|
||||||
|
#[argh(option)]
|
||||||
|
nvml: Option<PathBuf>,
|
||||||
|
|
||||||||
|
// First element of the child's command line.
/// executable to be injected with custom CUDA libraries
|
||||||
|
#[argh(positional)]
|
||||||
|
exe: String,
|
||||||
|
|
||||||||
|
// Remaining elements of the child's command line, appended after `exe`.
/// arguments to the executable
|
||||||
|
#[argh(positional)]
|
||||||
|
args: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
||||||
let args = env::args().collect::<Vec<_>>();
|
let raw_args = argh::from_env::<ProgramArguments>();
|
||||||
if args.len() <= 1 {
|
let normalized_args = NormalizedArguments::new(raw_args)?;
|
||||||
print_help_and_exit();
|
let mut environment = Environment::setup(normalized_args)?;
|
||||||
}
|
|
||||||
let injector_path = env::current_exe()?;
|
|
||||||
let injector_dir = injector_path.parent().unwrap();
|
|
||||||
let redirect_path = create_redirect_path(injector_dir);
|
|
||||||
let (mut inject_nvcuda_path, mut inject_nvml_path, cmd) =
|
|
||||||
create_inject_path(&args[1..], injector_dir)?;
|
|
||||||
let mut cmd_line = construct_command_line(cmd);
|
|
||||||
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
|
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
|
||||||
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
|
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
|
||||||
|
let mut dlls_to_inject = [
|
||||||
|
environment.nvml_path_zero_terminated.as_ptr() as *const i8,
|
||||||
|
environment.nvcuda_path_zero_terminated.as_ptr() as _,
|
||||||
|
environment.redirect_path_zero_terminated.as_ptr() as _,
|
||||||
|
];
|
||||||
os_call!(
|
os_call!(
|
||||||
detours_sys::DetourCreateProcessWithDllExW(
|
detours_sys::DetourCreateProcessWithDllsW(
|
||||||
ptr::null(),
|
ptr::null(),
|
||||||
cmd_line.as_mut_ptr(),
|
environment.winapi_command_line_zero_terminated.as_mut_ptr(),
|
||||||
ptr::null_mut(),
|
ptr::null_mut(),
|
||||||
ptr::null_mut(),
|
ptr::null_mut(),
|
||||||
0,
|
0,
|
||||||
@ -50,7 +71,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||||||
ptr::null(),
|
ptr::null(),
|
||||||
&mut startup_info as *mut _,
|
&mut startup_info as *mut _,
|
||||||
&mut proc_info as *mut _,
|
&mut proc_info as *mut _,
|
||||||
redirect_path.as_ptr() as *const i8,
|
dlls_to_inject.len() as u32,
|
||||||
|
dlls_to_inject.as_mut_ptr(),
|
||||||
Option::None
|
Option::None
|
||||||
),
|
),
|
||||||
|x| x != 0
|
|x| x != 0
|
||||||
@ -60,8 +82,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||||||
detours_sys::DetourCopyPayloadToProcess(
|
detours_sys::DetourCopyPayloadToProcess(
|
||||||
proc_info.hProcess,
|
proc_info.hProcess,
|
||||||
&PAYLOAD_NVCUDA_GUID,
|
&PAYLOAD_NVCUDA_GUID,
|
||||||
inject_nvcuda_path.as_mut_ptr() as *mut _,
|
environment.nvcuda_path_zero_terminated.as_ptr() as *mut _,
|
||||||
(inject_nvcuda_path.len() * mem::size_of::<u16>()) as u32
|
environment.nvcuda_path_zero_terminated.len() as u32
|
||||||
),
|
),
|
||||||
|x| x != 0
|
|x| x != 0
|
||||||
);
|
);
|
||||||
@ -69,8 +91,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||||||
detours_sys::DetourCopyPayloadToProcess(
|
detours_sys::DetourCopyPayloadToProcess(
|
||||||
proc_info.hProcess,
|
proc_info.hProcess,
|
||||||
&PAYLOAD_NVML_GUID,
|
&PAYLOAD_NVML_GUID,
|
||||||
inject_nvml_path.as_mut_ptr() as *mut _,
|
environment.nvml_path_zero_terminated.as_ptr() as *mut _,
|
||||||
(inject_nvml_path.len() * mem::size_of::<u16>()) as u32
|
environment.nvml_path_zero_terminated.len() as u32
|
||||||
),
|
),
|
||||||
|x| x != 0
|
|x| x != 0
|
||||||
);
|
);
|
||||||
@ -85,6 +107,135 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||||||
process::exit(child_exit_code as i32)
|
process::exit(child_exit_code as i32)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct NormalizedArguments {
|
||||||
|
nvml_path: PathBuf,
|
||||||
|
nvcuda_path: PathBuf,
|
||||||
|
redirect_path: PathBuf,
|
||||||
|
winapi_command_line_zero_terminated: Vec<u16>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NormalizedArguments {
|
||||||
|
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
|
||||||
|
let current_exe = env::current_exe()?;
|
||||||
|
let nvml_path = Self::get_absolute_path(¤t_exe, prog_args.nvml, NVML_DLL)?;
|
||||||
|
let nvcuda_path = Self::get_absolute_path(¤t_exe, prog_args.nvcuda, NVCUDA_DLL)?;
|
||||||
|
let winapi_command_line_zero_terminated =
|
||||||
|
construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args));
|
||||||
|
let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
|
||||||
|
redirect_path.push(REDIRECT_DLL);
|
||||||
|
Ok(Self {
|
||||||
|
nvml_path,
|
||||||
|
nvcuda_path,
|
||||||
|
redirect_path,
|
||||||
|
winapi_command_line_zero_terminated,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const WIN_MAX_PATH: usize = 260;
|
||||||
|
|
||||||
|
fn get_absolute_path(
|
||||||
|
current_exe: &PathBuf,
|
||||||
|
dll: Option<PathBuf>,
|
||||||
|
default: &str,
|
||||||
|
) -> Result<PathBuf, Box<dyn Error>> {
|
||||||
|
Ok(if let Some(dll) = dll {
|
||||||
|
if dll.is_absolute() {
|
||||||
|
dll
|
||||||
|
} else {
|
||||||
|
let mut full_dll_path = vec![0; Self::WIN_MAX_PATH];
|
||||||
|
let mut dll_utf16 = dll.as_os_str().encode_wide().collect::<Vec<_>>();
|
||||||
|
dll_utf16.push(0);
|
||||||
|
loop {
|
||||||
|
let copied_len = os_call!(
|
||||||
|
SearchPathW(
|
||||||
|
ptr::null_mut(),
|
||||||
|
dll_utf16.as_ptr(),
|
||||||
|
ptr::null(),
|
||||||
|
full_dll_path.len() as u32,
|
||||||
|
full_dll_path.as_mut_ptr(),
|
||||||
|
ptr::null_mut()
|
||||||
|
),
|
||||||
|
|x| x != 0
|
||||||
|
) as usize;
|
||||||
|
if copied_len > full_dll_path.len() {
|
||||||
|
full_dll_path.resize(copied_len + 1, 0);
|
||||||
|
} else {
|
||||||
|
full_dll_path.truncate(copied_len);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PathBuf::from(String::from_utf16_lossy(&full_dll_path))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let mut dll_path = current_exe.parent().unwrap().to_path_buf();
|
||||||
|
dll_path.push(default);
|
||||||
|
dll_path
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Environment {
|
||||||
|
nvml_path_zero_terminated: String,
|
||||||
|
nvcuda_path_zero_terminated: String,
|
||||||
|
redirect_path_zero_terminated: String,
|
||||||
|
winapi_command_line_zero_terminated: Vec<u16>,
|
||||||
|
_temp_dir: TempDir,
|
||||||
|
}
|
||||||
|
|
||||||
|
// This structs represents "enviroment". By environment we mean all paths
|
||||||
|
// (nvcuda.dll, nvml.dll, etc.) and all related resources like the temporary
|
||||||
|
// directory which contains nvcuda.dll
|
||||||
|
impl Environment {
|
||||||
|
fn setup(args: NormalizedArguments) -> io::Result<Self> {
|
||||||
|
let _temp_dir = TempDir::new()?;
|
||||||
|
let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||||
|
args.nvml_path,
|
||||||
|
&_temp_dir,
|
||||||
|
NVML_DLL,
|
||||||
|
)?);
|
||||||
|
let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||||
|
args.nvcuda_path,
|
||||||
|
&_temp_dir,
|
||||||
|
NVCUDA_DLL,
|
||||||
|
)?);
|
||||||
|
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
|
||||||
|
Ok(Self {
|
||||||
|
nvml_path_zero_terminated,
|
||||||
|
nvcuda_path_zero_terminated,
|
||||||
|
redirect_path_zero_terminated,
|
||||||
|
winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated,
|
||||||
|
_temp_dir,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn copy_to_correct_name(
|
||||||
|
path_buf: PathBuf,
|
||||||
|
temp_dir: &TempDir,
|
||||||
|
correct_name: &str,
|
||||||
|
) -> io::Result<PathBuf> {
|
||||||
|
let file_name = path_buf.file_name().unwrap();
|
||||||
|
if file_name == correct_name {
|
||||||
|
Ok(path_buf)
|
||||||
|
} else {
|
||||||
|
let mut temp_file_path = temp_dir.path().to_path_buf();
|
||||||
|
temp_file_path.push(correct_name);
|
||||||
|
match windows::fs::symlink_file(&path_buf, &temp_file_path) {
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(_) => {
|
||||||
|
fs::copy(&path_buf, &temp_file_path)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(temp_file_path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn zero_terminate(p: PathBuf) -> String {
|
||||||
|
let mut s = p.to_string_lossy().to_string();
|
||||||
|
s.push('\0');
|
||||||
|
s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
|
fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
|
||||||
let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
|
let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
|
||||||
!= ptr::null_mut());
|
!= ptr::null_mut());
|
||||||
@ -103,29 +254,11 @@ fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn print_help_and_exit() -> ! {
|
|
||||||
let current_exe = env::current_exe().unwrap();
|
|
||||||
let exe_name = current_exe.file_name().unwrap().to_string_lossy();
|
|
||||||
println!(
|
|
||||||
"USAGE:
|
|
||||||
{0} -- <EXE> [ARGS]...
|
|
||||||
{0} <DLL> -- <EXE> [ARGS]...
|
|
||||||
ARGS:
|
|
||||||
<DLL> DLL to be injected instead of system nvcuda.dll, if not provided
|
|
||||||
will use nvcuda.dll from the directory where {0} is located
|
|
||||||
<EXE> Path to the executable to be injected with <DLL>
|
|
||||||
<ARGS>... Arguments that will be passed to <EXE>
|
|
||||||
",
|
|
||||||
exe_name
|
|
||||||
);
|
|
||||||
process::exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
|
// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
|
||||||
fn construct_command_line(args: &[String]) -> Vec<u16> {
|
fn construct_command_line(args: impl Iterator<Item = String>) -> Vec<u16> {
|
||||||
let mut cmd_line = Vec::new();
|
let mut cmd_line = Vec::new();
|
||||||
let args_len = args.len();
|
let args_len = args.size_hint().0;
|
||||||
for (idx, arg) in args.iter().enumerate() {
|
for (idx, arg) in args.enumerate() {
|
||||||
if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) {
|
if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) {
|
||||||
cmd_line.extend(arg.encode_utf16());
|
cmd_line.extend(arg.encode_utf16());
|
||||||
} else {
|
} else {
|
||||||
@ -176,55 +309,3 @@ fn construct_command_line(args: &[String]) -> Vec<u16> {
|
|||||||
cmd_line.push(0);
|
cmd_line.push(0);
|
||||||
cmd_line
|
cmd_line
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_redirect_path(injector_dir: &Path) -> Vec<u8> {
|
|
||||||
let mut injector_dir = injector_dir.to_path_buf();
|
|
||||||
injector_dir.push(REDIRECT_DLL);
|
|
||||||
let mut result = injector_dir.to_string_lossy().into_owned().into_bytes();
|
|
||||||
result.push(0);
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_inject_path<'a>(
|
|
||||||
args: &'a [String],
|
|
||||||
injector_dir: &Path,
|
|
||||||
) -> std::io::Result<(Vec<u16>, Vec<u16>, &'a [String])> {
|
|
||||||
let injector_dir = injector_dir.to_path_buf();
|
|
||||||
let (nvcuda_path, unparsed_args) = if args.get(0).map(Deref::deref) == Some("--") {
|
|
||||||
(
|
|
||||||
encode_file_in_directory_raw(injector_dir.clone(), ZLUDA_DLL),
|
|
||||||
&args[1..],
|
|
||||||
)
|
|
||||||
} else if args.get(1).map(Deref::deref) == Some("--") {
|
|
||||||
let dll_path = make_absolute_and_encode(&args[0])?;
|
|
||||||
(dll_path, &args[2..])
|
|
||||||
} else {
|
|
||||||
print_help_and_exit()
|
|
||||||
};
|
|
||||||
let nvml_path = encode_file_in_directory_raw(injector_dir, ZLUDA_ML_DLL);
|
|
||||||
Ok((nvcuda_path, nvml_path, unparsed_args))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn encode_file_in_directory_raw(mut dir: PathBuf, file: &'static str) -> Vec<u16> {
|
|
||||||
dir.push(file);
|
|
||||||
let mut result = dir
|
|
||||||
.to_string_lossy()
|
|
||||||
.as_ref()
|
|
||||||
.encode_utf16()
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
result.push(0);
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
fn make_absolute_and_encode(maybe_path: &str) -> std::io::Result<Vec<u16>> {
|
|
||||||
let path = Path::new(maybe_path);
|
|
||||||
let mut encoded_path = if path.is_relative() {
|
|
||||||
let mut current_dir = env::current_dir()?;
|
|
||||||
current_dir.push(path);
|
|
||||||
current_dir.as_os_str().encode_wide().collect::<Vec<_>>()
|
|
||||||
} else {
|
|
||||||
maybe_path.encode_utf16().collect::<Vec<_>>()
|
|
||||||
};
|
|
||||||
encoded_path.push(0);
|
|
||||||
Ok(encoded_path)
|
|
||||||
}
|
|
||||||
|
10
zluda_inject/tests/helpers/do_cuinit_early.rs
Normal file
10
zluda_inject/tests/helpers/do_cuinit_early.rs
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
#![crate_type = "bin"]
|
||||||
|
|
||||||
|
// `do_cuinit` is provided by the `do_cuinit` library this helper links against.
#[link(name = "do_cuinit")]
extern "system" {
    fn do_cuinit(flags: u32) -> u32;
}

/// Calls `do_cuinit(0)` once; the status code it returns is discarded.
fn main() {
    let _ = unsafe { do_cuinit(0) };
}
|
10
zluda_inject/tests/helpers/subprocess.rs
Normal file
10
zluda_inject/tests/helpers/subprocess.rs
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
#![crate_type = "bin"]
|
||||||
|
|
||||||
|
use std::io;
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
/// Runs `direct_cuinit.exe` as a child process and waits for it.
///
/// A failure to spawn the child is returned as the `io::Error`; an
/// unsuccessful exit status of the child triggers the assertion.
fn main() -> io::Result<()> {
    Command::new("direct_cuinit.exe").status().map(|status| {
        assert!(status.success());
    })
}
|
@ -5,19 +5,29 @@ fn direct_cuinit() -> io::Result<()> {
|
|||||||
run_process_and_check_for_zluda_dump("direct_cuinit")
|
run_process_and_check_for_zluda_dump("direct_cuinit")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Launches the `do_cuinit_early.exe` helper through the injector via
// `run_process_and_check_for_zluda_dump` and propagates its I/O result.
#[test]
|
||||||
|
fn do_cuinit_early() -> io::Result<()> {
|
||||||
|
run_process_and_check_for_zluda_dump("do_cuinit_early")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launches the `do_cuinit_late.exe` helper through the injector via
// `run_process_and_check_for_zluda_dump` and propagates its I/O result.
#[test]
|
||||||
|
fn do_cuinit_late() -> io::Result<()> {
|
||||||
|
run_process_and_check_for_zluda_dump("do_cuinit_late")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launches the `do_cuinit_late_clr.exe` helper (copied into the helpers output
// directory by the build script) through the injector via
// `run_process_and_check_for_zluda_dump` and propagates its I/O result.
#[test]
|
||||||
|
fn do_cuinit_late_clr() -> io::Result<()> {
|
||||||
|
run_process_and_check_for_zluda_dump("do_cuinit_late_clr")
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn indirect_cuinit() -> io::Result<()> {
|
fn indirect_cuinit() -> io::Result<()> {
|
||||||
run_process_and_check_for_zluda_dump("indirect_cuinit")
|
run_process_and_check_for_zluda_dump("indirect_cuinit")
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn do_cuinit() -> io::Result<()> {
|
fn subprocess() -> io::Result<()> {
|
||||||
run_process_and_check_for_zluda_dump("do_cuinit_main")
|
run_process_and_check_for_zluda_dump("subprocess")
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn do_cuinit_clr() -> io::Result<()> {
|
|
||||||
run_process_and_check_for_zluda_dump("do_cuinit_main_clr")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
|
fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
|
||||||
@ -27,7 +37,11 @@ fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
|
|||||||
let helpers_dir = env!("HELPERS_OUT_DIR");
|
let helpers_dir = env!("HELPERS_OUT_DIR");
|
||||||
let exe_under_test = format!("{}{}{}.exe", helpers_dir, std::path::MAIN_SEPARATOR, name);
|
let exe_under_test = format!("{}{}{}.exe", helpers_dir, std::path::MAIN_SEPARATOR, name);
|
||||||
let mut test_cmd = Command::new(&zluda_with_exe);
|
let mut test_cmd = Command::new(&zluda_with_exe);
|
||||||
let test_cmd = test_cmd.arg(&zluda_dump_dll).arg("--").arg(&exe_under_test);
|
let test_cmd = test_cmd
|
||||||
|
.arg("--nvcuda")
|
||||||
|
.arg(&zluda_dump_dll)
|
||||||
|
.arg("--")
|
||||||
|
.arg(&exe_under_test);
|
||||||
let test_output = test_cmd.output()?;
|
let test_output = test_cmd.output()?;
|
||||||
assert!(test_output.status.success());
|
assert!(test_output.status.success());
|
||||||
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
|
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
use std::io::Write;
|
|
||||||
use std::slice;
|
|
||||||
use std::{
|
use std::{
|
||||||
os::raw::{c_char, c_uint},
|
os::raw::{c_char, c_uint},
|
||||||
ptr,
|
ptr,
|
||||||
|
@ -11,6 +11,3 @@ crate-type = ["cdylib"]
|
|||||||
detours-sys = { path = "../detours-sys" }
|
detours-sys = { path = "../detours-sys" }
|
||||||
wchar = "0.6"
|
wchar = "0.6"
|
||||||
winapi = { version = "0.3", features = [ "sysinfoapi", "memoryapi", "processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] }
|
winapi = { version = "0.3", features = [ "sysinfoapi", "memoryapi", "processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] }
|
||||||
tempfile = "3"
|
|
||||||
goblin = { version = "0.4", default-features = false, features = ["pe64"] }
|
|
||||||
memoffset = "0.6"
|
|
@ -6,32 +6,35 @@ extern crate winapi;
|
|||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
ffi::{c_void, CStr},
|
ffi::{c_void, CStr},
|
||||||
io, mem,
|
mem, ptr, slice, usize,
|
||||||
os::raw::c_uint,
|
|
||||||
ptr, slice, usize,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use detours_sys::{
|
use detours_sys::{
|
||||||
DetourAllocateRegionWithinJumpBounds, DetourAttach, DetourEnumerateExports,
|
DetourAttach, DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin,
|
||||||
DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin,
|
|
||||||
DetourTransactionCommit, DetourUpdateProcessWithDll, DetourUpdateThread,
|
DetourTransactionCommit, DetourUpdateProcessWithDll, DetourUpdateThread,
|
||||||
};
|
};
|
||||||
use goblin::pe::{
|
|
||||||
self,
|
|
||||||
header::{CoffHeader, DOS_MAGIC, PE_MAGIC, PE_POINTER_OFFSET},
|
|
||||||
optional_header::StandardFields64,
|
|
||||||
};
|
|
||||||
use memoffset::offset_of;
|
|
||||||
use tempfile::TempDir;
|
|
||||||
use wchar::wch;
|
use wchar::wch;
|
||||||
|
use winapi::{
|
||||||
|
shared::minwindef::{BOOL, LPVOID},
|
||||||
|
um::{
|
||||||
|
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
|
||||||
|
minwinbase::LPSECURITY_ATTRIBUTES,
|
||||||
|
processthreadsapi::{
|
||||||
|
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
|
||||||
|
SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW,
|
||||||
|
},
|
||||||
|
tlhelp32::{
|
||||||
|
CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32,
|
||||||
|
},
|
||||||
|
winbase::CREATE_SUSPENDED,
|
||||||
|
winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME},
|
||||||
|
},
|
||||||
|
};
|
||||||
use winapi::{
|
use winapi::{
|
||||||
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
|
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
|
||||||
um::{
|
um::{
|
||||||
libloaderapi::{GetModuleHandleA, LoadLibraryExA},
|
libloaderapi::{GetModuleHandleA, LoadLibraryExA},
|
||||||
memoryapi::VirtualProtect,
|
winnt::LPCSTR,
|
||||||
processthreadsapi::{FlushInstructionCache, GetCurrentProcess},
|
|
||||||
sysinfoapi::GetSystemInfo,
|
|
||||||
winnt::{LPCSTR, PAGE_READWRITE},
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use winapi::{
|
use winapi::{
|
||||||
@ -47,26 +50,6 @@ use winapi::{
|
|||||||
shared::winerror::NO_ERROR,
|
shared::winerror::NO_ERROR,
|
||||||
um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW},
|
um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW},
|
||||||
};
|
};
|
||||||
use winapi::{
|
|
||||||
shared::{
|
|
||||||
minwindef::{BOOL, LPVOID},
|
|
||||||
winerror::E_UNEXPECTED,
|
|
||||||
},
|
|
||||||
um::{
|
|
||||||
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
|
|
||||||
libloaderapi::GetModuleHandleW,
|
|
||||||
minwinbase::LPSECURITY_ATTRIBUTES,
|
|
||||||
processthreadsapi::{
|
|
||||||
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
|
|
||||||
SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW,
|
|
||||||
},
|
|
||||||
tlhelp32::{
|
|
||||||
CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32,
|
|
||||||
},
|
|
||||||
winbase::{CopyFileW, CreateSymbolicLinkW, CREATE_SUSPENDED},
|
|
||||||
winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
include!("payload_guid.rs");
|
include!("payload_guid.rs");
|
||||||
|
|
||||||
@ -74,13 +57,12 @@ const NVCUDA_UTF8: &'static str = "NVCUDA.DLL";
|
|||||||
const NVCUDA_UTF16: &[u16] = wch!("NVCUDA.DLL");
|
const NVCUDA_UTF16: &[u16] = wch!("NVCUDA.DLL");
|
||||||
const NVML_UTF8: &'static str = "NVML.DLL";
|
const NVML_UTF8: &'static str = "NVML.DLL";
|
||||||
const NVML_UTF16: &[u16] = wch!("NVML.DLL");
|
const NVML_UTF16: &[u16] = wch!("NVML.DLL");
|
||||||
static mut ZLUDA_PATH_UTF8: Vec<u8> = Vec::new();
|
static mut ZLUDA_PATH_UTF8: Option<&'static [u8]> = None;
|
||||||
static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None;
|
static mut ZLUDA_PATH_UTF16: Vec<u16> = Vec::new();
|
||||||
static mut ZLUDA_ML_PATH_UTF8: Vec<u8> = Vec::new();
|
static mut ZLUDA_ML_PATH_UTF8: Option<&'static [u8]> = None;
|
||||||
static mut ZLUDA_ML_PATH_UTF16: Option<&'static [u16]> = None;
|
static mut ZLUDA_ML_PATH_UTF16: Vec<u16> = Vec::new();
|
||||||
static mut CURRENT_MODULE_FILENAME: Vec<u8> = Vec::new();
|
static mut CURRENT_MODULE_FILENAME: Vec<u8> = Vec::new();
|
||||||
static mut DETOUR_STATE: Option<DetourDetachGuard> = None;
|
static mut DETOUR_STATE: Option<DetourDetachGuard> = None;
|
||||||
const CUDA_ERROR_NOT_SUPPORTED: c_uint = 801;
|
|
||||||
|
|
||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
#[used]
|
#[used]
|
||||||
@ -197,9 +179,9 @@ unsafe extern "system" fn ZludaLoadLibraryW_NoRedirect(lpLibFileName: LPCWSTR) -
|
|||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE {
|
unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE {
|
||||||
let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
|
let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
|
||||||
ZLUDA_PATH_UTF8.as_ptr() as *const _
|
ZLUDA_PATH_UTF8.unwrap().as_ptr() as *const _
|
||||||
} else if is_nvml_dll_utf8(lpLibFileName as *const _) {
|
} else if is_nvml_dll_utf8(lpLibFileName as *const _) {
|
||||||
ZLUDA_ML_PATH_UTF8.as_ptr() as *const _
|
ZLUDA_ML_PATH_UTF8.unwrap().as_ptr() as *const _
|
||||||
} else {
|
} else {
|
||||||
lpLibFileName
|
lpLibFileName
|
||||||
};
|
};
|
||||||
@ -209,9 +191,9 @@ unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE {
|
|||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
unsafe extern "system" fn ZludaLoadLibraryW(lpLibFileName: LPCWSTR) -> HMODULE {
|
unsafe extern "system" fn ZludaLoadLibraryW(lpLibFileName: LPCWSTR) -> HMODULE {
|
||||||
let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
|
let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
|
||||||
ZLUDA_PATH_UTF16.unwrap().as_ptr()
|
ZLUDA_PATH_UTF16.as_ptr()
|
||||||
} else if is_nvml_dll_utf16(lpLibFileName as *const _) {
|
} else if is_nvml_dll_utf16(lpLibFileName as *const _) {
|
||||||
ZLUDA_ML_PATH_UTF16.unwrap().as_ptr()
|
ZLUDA_ML_PATH_UTF16.as_ptr()
|
||||||
} else {
|
} else {
|
||||||
lpLibFileName
|
lpLibFileName
|
||||||
};
|
};
|
||||||
@ -225,9 +207,9 @@ unsafe extern "system" fn ZludaLoadLibraryExA(
|
|||||||
dwFlags: DWORD,
|
dwFlags: DWORD,
|
||||||
) -> HMODULE {
|
) -> HMODULE {
|
||||||
let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
|
let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
|
||||||
ZLUDA_PATH_UTF8.as_ptr() as *const _
|
ZLUDA_PATH_UTF8.unwrap().as_ptr() as *const _
|
||||||
} else if is_nvml_dll_utf8(lpLibFileName as *const _) {
|
} else if is_nvml_dll_utf8(lpLibFileName as *const _) {
|
||||||
ZLUDA_ML_PATH_UTF8.as_ptr() as *const _
|
ZLUDA_ML_PATH_UTF8.unwrap().as_ptr() as *const _
|
||||||
} else {
|
} else {
|
||||||
lpLibFileName
|
lpLibFileName
|
||||||
};
|
};
|
||||||
@ -241,9 +223,9 @@ unsafe extern "system" fn ZludaLoadLibraryExW(
|
|||||||
dwFlags: DWORD,
|
dwFlags: DWORD,
|
||||||
) -> HMODULE {
|
) -> HMODULE {
|
||||||
let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
|
let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
|
||||||
ZLUDA_PATH_UTF16.unwrap().as_ptr()
|
ZLUDA_PATH_UTF16.as_ptr()
|
||||||
} else if is_nvml_dll_utf16(lpLibFileName as *const _) {
|
} else if is_nvml_dll_utf16(lpLibFileName as *const _) {
|
||||||
ZLUDA_ML_PATH_UTF16.unwrap().as_ptr()
|
ZLUDA_ML_PATH_UTF16.as_ptr()
|
||||||
} else {
|
} else {
|
||||||
lpLibFileName
|
lpLibFileName
|
||||||
};
|
};
|
||||||
@ -392,67 +374,6 @@ unsafe extern "system" fn ZludaCreateProcessWithTokenW(
|
|||||||
continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation)
|
continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation)
|
||||||
}
|
}
|
||||||
|
|
||||||
static mut MAIN: unsafe extern "system" fn() -> DWORD = zluda_main;
|
|
||||||
static mut COR_EXE_MAIN: unsafe extern "system" fn() -> DWORD = zluda_main_clr;
|
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#search-order-for-desktop-applications
|
|
||||||
// "If a DLL with the same module name is already loaded in memory, the system
|
|
||||||
// uses the loaded DLL, no matter which directory it is in. The system does not
|
|
||||||
// search for the DLL."
|
|
||||||
unsafe extern "system" fn zluda_main() -> DWORD {
|
|
||||||
zluda_main_impl(MAIN)
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe extern "system" fn zluda_main_clr() -> DWORD {
|
|
||||||
zluda_main_impl(COR_EXE_MAIN)
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe fn zluda_main_impl(original: unsafe extern "system" fn() -> DWORD) -> DWORD {
|
|
||||||
let temp_dir = match do_zluda_preload() {
|
|
||||||
Ok(f) => f,
|
|
||||||
Err(e) => return e.raw_os_error().unwrap_or(E_UNEXPECTED) as DWORD,
|
|
||||||
};
|
|
||||||
let result = original();
|
|
||||||
drop(temp_dir);
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe fn do_zluda_preload() -> std::io::Result<TempDir> {
|
|
||||||
let temp_dir = tempfile::tempdir()?;
|
|
||||||
do_single_zluda_preload(&temp_dir, ZLUDA_PATH_UTF16.unwrap().as_ptr(), NVCUDA_UTF8)?;
|
|
||||||
do_single_zluda_preload(&temp_dir, ZLUDA_ML_PATH_UTF16.unwrap().as_ptr(), NVML_UTF8)?;
|
|
||||||
Ok(temp_dir)
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe fn do_single_zluda_preload(
|
|
||||||
temp_dir: &TempDir,
|
|
||||||
full_path: *const u16,
|
|
||||||
file_name: &'static str,
|
|
||||||
) -> io::Result<()> {
|
|
||||||
let mut temp_file_path = temp_dir.path().to_path_buf();
|
|
||||||
temp_file_path.push(file_name);
|
|
||||||
let mut temp_file_path_utf16 = temp_file_path
|
|
||||||
.into_os_string()
|
|
||||||
.to_string_lossy()
|
|
||||||
.encode_utf16()
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
temp_file_path_utf16.push(0);
|
|
||||||
// Probably we are not in developer mode, do a copty then
|
|
||||||
if 0 == CreateSymbolicLinkW(
|
|
||||||
temp_file_path_utf16.as_ptr(),
|
|
||||||
full_path,
|
|
||||||
0x2, //SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE
|
|
||||||
) {
|
|
||||||
if 0 == CopyFileW(full_path, temp_file_path_utf16.as_ptr(), 1) {
|
|
||||||
return Err(io::Error::last_os_error());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ptr::null_mut() == ZludaLoadLibraryW_NoRedirect(temp_file_path_utf16.as_ptr()) {
|
|
||||||
return Err(io::Error::last_os_error());
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
// This type encapsulates typical calling sequence of detours and cleanup.
|
// This type encapsulates typical calling sequence of detours and cleanup.
|
||||||
// We have two ways we do detours:
|
// We have two ways we do detours:
|
||||||
// * If we are loaded before nvcuda.dll, we hook LoadLibrary*
|
// * If we are loaded before nvcuda.dll, we hook LoadLibrary*
|
||||||
@ -633,21 +554,31 @@ unsafe fn continue_create_process_hook(
|
|||||||
// continues uninjected than to break the parent
|
// continues uninjected than to break the parent
|
||||||
if DetourUpdateProcessWithDll(
|
if DetourUpdateProcessWithDll(
|
||||||
(*process_information).hProcess,
|
(*process_information).hProcess,
|
||||||
&mut CURRENT_MODULE_FILENAME.as_ptr() as *mut _ as *mut _,
|
&mut ZLUDA_ML_PATH_UTF8.unwrap().as_ptr() as *mut _ as *mut _,
|
||||||
1,
|
1,
|
||||||
) != FALSE
|
) != FALSE
|
||||||
|
&& DetourUpdateProcessWithDll(
|
||||||
|
(*process_information).hProcess,
|
||||||
|
&mut ZLUDA_PATH_UTF8.unwrap().as_ptr() as *mut _ as *mut _,
|
||||||
|
1,
|
||||||
|
) != FALSE
|
||||||
|
&& DetourUpdateProcessWithDll(
|
||||||
|
(*process_information).hProcess,
|
||||||
|
&mut CURRENT_MODULE_FILENAME.as_ptr() as *mut _ as *mut _,
|
||||||
|
1,
|
||||||
|
) != FALSE
|
||||||
{
|
{
|
||||||
detours_sys::DetourCopyPayloadToProcess(
|
detours_sys::DetourCopyPayloadToProcess(
|
||||||
(*process_information).hProcess,
|
(*process_information).hProcess,
|
||||||
&PAYLOAD_NVML_GUID,
|
&PAYLOAD_NVML_GUID,
|
||||||
ZLUDA_ML_PATH_UTF16.unwrap().as_ptr() as *mut _,
|
ZLUDA_ML_PATH_UTF16.as_ptr() as *mut _,
|
||||||
(ZLUDA_ML_PATH_UTF16.unwrap().len() * mem::size_of::<u16>()) as u32,
|
(ZLUDA_ML_PATH_UTF16.len() * mem::size_of::<u16>()) as u32,
|
||||||
);
|
);
|
||||||
detours_sys::DetourCopyPayloadToProcess(
|
detours_sys::DetourCopyPayloadToProcess(
|
||||||
(*process_information).hProcess,
|
(*process_information).hProcess,
|
||||||
&PAYLOAD_NVCUDA_GUID,
|
&PAYLOAD_NVCUDA_GUID,
|
||||||
ZLUDA_PATH_UTF16.unwrap().as_ptr() as *mut _,
|
ZLUDA_PATH_UTF16.as_ptr() as *mut _,
|
||||||
(ZLUDA_PATH_UTF16.unwrap().len() * mem::size_of::<u16>()) as u32,
|
(ZLUDA_PATH_UTF16.len() * mem::size_of::<u16>()) as u32,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if original_creation_flags & CREATE_SUSPENDED == 0 {
|
if original_creation_flags & CREATE_SUSPENDED == 0 {
|
||||||
@ -733,23 +664,18 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
|
|||||||
}
|
}
|
||||||
match get_zluda_dlls_paths() {
|
match get_zluda_dlls_paths() {
|
||||||
Some((nvcuda_path, nvml_path)) => {
|
Some((nvcuda_path, nvml_path)) => {
|
||||||
ZLUDA_PATH_UTF16 = Some(nvcuda_path);
|
ZLUDA_PATH_UTF8 = Some(nvcuda_path);
|
||||||
ZLUDA_ML_PATH_UTF16 = Some(nvml_path);
|
ZLUDA_ML_PATH_UTF8 = Some(nvml_path);
|
||||||
// from_utf16_lossy(...) handles terminating NULL correctly
|
ZLUDA_PATH_UTF16 = std::str::from_utf8_unchecked(nvcuda_path)
|
||||||
ZLUDA_PATH_UTF8 = String::from_utf16_lossy(nvcuda_path).into_bytes();
|
.encode_utf16()
|
||||||
ZLUDA_ML_PATH_UTF8 = String::from_utf16_lossy(nvml_path).into_bytes();
|
.collect::<Vec<_>>();
|
||||||
|
ZLUDA_ML_PATH_UTF16 = std::str::from_utf8_unchecked(nvml_path)
|
||||||
|
.encode_utf16()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
}
|
}
|
||||||
None => return FALSE,
|
None => return FALSE,
|
||||||
}
|
}
|
||||||
// If the application (directly or not) links to nvcuda.dll, nvcuda.dll
|
match detour_already_loaded_nvcuda() {
|
||||||
// will get loaded before we can act. In this case, instead of
|
|
||||||
// redirecting LoadLibrary* to load ZLUDA, we override already loaded
|
|
||||||
// functions
|
|
||||||
let detach_guard = match get_cuinit() {
|
|
||||||
Some((nvcuda_mod, _)) => detour_already_loaded_nvcuda(nvcuda_mod),
|
|
||||||
None => detour_main(),
|
|
||||||
};
|
|
||||||
match detach_guard {
|
|
||||||
Some(g) => {
|
Some(g) => {
|
||||||
DETOUR_STATE = Some(g);
|
DETOUR_STATE = Some(g);
|
||||||
TRUE
|
TRUE
|
||||||
@ -787,42 +713,9 @@ unsafe fn initialize_current_module_name(current_module: HINSTANCE) -> bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Looks up an already loaded `nvcuda` module and its `cuInit` export.
//
// Returns `None` when the module is not loaded in this process or does not export
// `cuInit`; otherwise returns the module handle together with the export's address.
unsafe fn get_cuinit() -> Option<(HMODULE, FARPROC)> {
|
|
||||||
// GetModuleHandleA only finds modules that are already loaded; it does not load anything.
let nvcuda = GetModuleHandleA(b"nvcuda\0".as_ptr() as _);
|
|
||||||
if nvcuda == ptr::null_mut() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
let cuinit_addr = GetProcAddress(nvcuda, b"cuInit\0".as_ptr() as _);
|
|
||||||
if cuinit_addr == ptr::null_mut() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
Some((nvcuda as *mut _, cuinit_addr))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
unsafe fn detour_already_loaded_nvcuda(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
|
unsafe fn detour_already_loaded_nvcuda() -> Option<DetourDetachGuard> {
|
||||||
let zluda_module = LoadLibraryW(ZLUDA_PATH_UTF16.unwrap().as_ptr());
|
let nvcuda_mod = GetModuleHandleA(b"nvcuda\0".as_ptr() as _);
|
||||||
if zluda_module == ptr::null_mut() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
let original_functions = gather_imports(nvcuda_mod);
|
|
||||||
let override_functions = gather_imports(zluda_module);
|
|
||||||
let mut override_fn_pairs = HashMap::with_capacity(original_functions.len());
|
|
||||||
// TODO: optimize
|
|
||||||
for (original_fn_name, original_fn_address) in original_functions {
|
|
||||||
let override_fn_address =
|
|
||||||
match override_functions.binary_search_by_key(&original_fn_name, |(name, _)| *name) {
|
|
||||||
Ok(x) => override_functions[x].1,
|
|
||||||
Err(_) => {
|
|
||||||
// TODO: print a warning in debug
|
|
||||||
cuda_unsupported as _
|
|
||||||
}
|
|
||||||
};
|
|
||||||
override_fn_pairs.insert(
|
|
||||||
original_fn_name,
|
|
||||||
(original_fn_address as _, override_fn_address),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let detour_functions = vec![
|
let detour_functions = vec![
|
||||||
(
|
(
|
||||||
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
|
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
|
||||||
@ -838,146 +731,10 @@ unsafe fn detour_already_loaded_nvcuda(nvcuda_mod: HMODULE) -> Option<DetourDeta
|
|||||||
ZludaLoadLibraryExW as _,
|
ZludaLoadLibraryExW as _,
|
||||||
),
|
),
|
||||||
];
|
];
|
||||||
DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, override_fn_pairs)
|
DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, HashMap::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn cuda_unsupported() -> c_uint {
|
fn get_zluda_dlls_paths() -> Option<(&'static [u8], &'static [u8])> {
|
||||||
CUDA_ERROR_NOT_SUPPORTED
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collects `(name, address)` pairs for `module` by running `DetourEnumerateExports`
// with `gather_imports_impl` as the callback. Despite the function's name, the pairs
// come from the module's exports.
//
// The return value of `DetourEnumerateExports` is ignored, so a failed enumeration
// yields whatever was collected up to that point (possibly an empty vector).
// NOTE(review): the `'static` names are produced by `CStr::from_ptr` in the callback;
// presumably they point into the module image and are only valid while the module
// stays loaded - confirm.
// NOTE(review): a caller binary-searches this list by name; the order is whatever
// `DetourEnumerateExports` yields - confirm it is sorted by name.
unsafe fn gather_imports(module: HINSTANCE) -> Vec<(&'static CStr, *mut c_void)> {
|
|
||||||
let mut result = Vec::new();
|
|
||||||
DetourEnumerateExports(
|
|
||||||
module as _,
|
|
||||||
// Context pointer: the callback casts it back to `*mut Vec<_>` and pushes into it.
&mut result as *mut _ as *mut _,
|
|
||||||
Some(gather_imports_impl),
|
|
||||||
);
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe extern "stdcall" fn gather_imports_impl(
|
|
||||||
context: *mut c_void,
|
|
||||||
_: u32,
|
|
||||||
name: LPCSTR,
|
|
||||||
code: *mut c_void,
|
|
||||||
) -> i32 {
|
|
||||||
let result: &mut Vec<(&'static CStr, *mut c_void)> = &mut *(context as *mut Vec<_>);
|
|
||||||
result.push((CStr::from_ptr(name), code));
|
|
||||||
TRUE
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets up the hooks used when the entry point itself is taken over: first the
// executable's entry point is redirected (`override_entry_point`), then the four
// LoadLibrary* variants and, when available, the CLR entry point are detoured.
//
// Returns `None` if the entry point could not be redirected; otherwise returns
// whatever `DetourDetachGuard::detour_functions` produces.
#[must_use]
|
|
||||||
unsafe fn detour_main() -> Option<DetourDetachGuard> {
|
|
||||||
if !override_entry_point() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
// Each entry pairs a pointer to the slot holding the original function with the
// replacement function.
let mut detour_functions = vec![
|
|
||||||
(
|
|
||||||
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
|
|
||||||
ZludaLoadLibraryA as *mut c_void,
|
|
||||||
),
|
|
||||||
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
|
|
||||||
(
|
|
||||||
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
|
|
||||||
ZludaLoadLibraryExA as _,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
&mut LOAD_LIBRARY_EX_W as *mut _ as _,
|
|
||||||
ZludaLoadLibraryExW as _,
|
|
||||||
),
|
|
||||||
];
|
|
||||||
// `get_clr_entry_point` returns an `Option`, so this adds zero or one entry.
detour_functions.extend(get_clr_entry_point());
|
|
||||||
DetourDetachGuard::detour_functions(ptr::null_mut(), detour_functions, HashMap::new())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Redirects the entry point of the running executable to `zluda_main`.
//
// The function walks the in-memory PE headers of the main module, saves the current
// entry point in `MAIN`, writes a `JmpThunk64` that jumps to `zluda_main` into a
// region obtained from `DetourAllocateRegionWithinJumpBounds`, and stores that
// region's offset from the image base in the header's entry point field.
//
// Returns `false` when a header check fails, when no suitable region is available or
// when changing the page protection fails; returns `true` otherwise.
// NOTE(review): `MAIN` is assigned and the thunk is written before the header is
// patched, so a `false` returned by a failing `VirtualProtect` still leaves those
// side effects in place.
unsafe fn override_entry_point() -> bool {
|
|
||||||
// A null module name asks for the handle of the executable of the current process.
let exe_handle = GetModuleHandleW(ptr::null());
|
|
||||||
let dos_signature = exe_handle as *mut u16;
|
|
||||||
if *dos_signature != DOS_MAGIC {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// Offset of the PE signature, read from the DOS header at `PE_POINTER_OFFSET`.
let pe_offset = *((exe_handle as *mut u8).add(PE_POINTER_OFFSET as usize) as *mut u32);
|
|
||||||
let pe_sig = (exe_handle as *mut u8).add(pe_offset as usize) as *mut u32;
|
|
||||||
if (*pe_sig) != PE_MAGIC {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// The COFF header follows the 4-byte signature; the optional header follows the COFF header.
let coff_header = pe_sig.add(1) as *mut CoffHeader;
|
|
||||||
let standard_coff_fields = coff_header.add(1) as *mut StandardFields64;
|
|
||||||
// Only images whose optional header magic equals `MAGIC_64` are handled.
if (*standard_coff_fields).magic != pe::optional_header::MAGIC_64 {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// The header stores the entry point as an offset from the image base.
let entry_point = mem::transmute::<_, unsafe extern "system" fn() -> DWORD>(
|
|
||||||
(exe_handle as *mut u8).add((*standard_coff_fields).address_of_entry_point as usize),
|
|
||||||
);
|
|
||||||
let mut allocated_size = 0;
|
|
||||||
let exe_region = DetourAllocateRegionWithinJumpBounds(exe_handle as _, &mut allocated_size);
|
|
||||||
// The region must exist and be large enough to hold the thunk.
if (allocated_size as usize) < mem::size_of::<JmpThunk64>() || exe_region == ptr::null_mut() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// Save the real entry point first so `zluda_main` can call it.
MAIN = entry_point;
|
|
||||||
*(exe_region as *mut JmpThunk64) = JmpThunk64::new(zluda_main);
|
|
||||||
FlushInstructionCache(
|
|
||||||
GetCurrentProcess(),
|
|
||||||
exe_region,
|
|
||||||
mem::size_of::<JmpThunk64>(),
|
|
||||||
);
|
|
||||||
// New header value: distance of the thunk from the image base.
let new_address_of_entry_point = (exe_region as *mut u8).offset_from(exe_handle as *mut u8);
|
|
||||||
let entry_point_offset = offset_of!(StandardFields64, address_of_entry_point);
|
|
||||||
let mut system_info = mem::zeroed();
|
|
||||||
GetSystemInfo(&mut system_info);
|
|
||||||
let pointer_to_address_of_entry_point =
|
|
||||||
(standard_coff_fields as *mut u8).add(entry_point_offset) as *mut i32;
|
|
||||||
let page_size = system_info.dwPageSize as usize;
|
|
||||||
// Round the field's address down to the start of its page.
let page_start = (((pointer_to_address_of_entry_point as usize) / page_size) * page_size) as _;
|
|
||||||
let mut old_protect = 0;
|
|
||||||
// Make the header page writable for the duration of the patch.
if VirtualProtect(page_start, page_size, PAGE_READWRITE, &mut old_protect) == 0 {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// NOTE(review): narrows an `isize` to `i32`; assumes the allocated region lies close
// enough to the image base for the offset to fit - confirm.
*pointer_to_address_of_entry_point = new_address_of_entry_point as i32;
|
|
||||||
// Restore the protection the page had before the patch.
if VirtualProtect(page_start, page_size, old_protect, &mut old_protect) == 0 {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Machine code for an absolute jump on x86-64:
//
// mov rax, $address;
// jmp rax;
// int 3;
//
// The struct is written into executable memory byte for byte, so its layout must be
// exactly the declaration order with no padding. `packed` removes the padding;
// `C` is needed as well because the default Rust representation does not guarantee
// that fields are laid out in declaration order.
#[repr(C, packed)]
#[allow(dead_code)]
#[cfg(target_pointer_width = "64")]
struct JmpThunk64 {
    // Opcode bytes of `mov rax, imm64`.
    mov_rax: [u8; 2],
    // The 64-bit immediate: absolute address to jump to.
    address: u64,
    // Opcode bytes of `jmp rax`.
    jmp_rax: [u8; 2],
    // `int 3`, never reached in normal execution.
    int3: u8,
}

// Gated like the struct so that the impl never refers to a type that is compiled out.
#[cfg(target_pointer_width = "64")]
impl JmpThunk64 {
    /// Builds a thunk that jumps to `target`.
    fn new<T: Sized>(target: unsafe extern "system" fn() -> T) -> Self {
        JmpThunk64 {
            mov_rax: [0x48, 0xB8],
            address: target as u64,
            jmp_rax: [0xFF, 0xE0],
            int3: 0xcc,
        }
    }
}
|
|
||||||
|
|
||||||
// Builds the detour entry for the CLR entry point, if `mscoree` is already loaded.
//
// Returns `None` when `mscoree` is not loaded in this process or does not export
// `_CorExeMain`. Otherwise stores the export's address in `COR_EXE_MAIN` and returns
// a pointer to that slot paired with the replacement `zluda_main_clr`.
unsafe fn get_clr_entry_point() -> Option<(*mut *mut c_void, *mut c_void)> {
|
|
||||||
let mscoree = GetModuleHandleA(b"mscoree\0".as_ptr() as _);
|
|
||||||
if mscoree == ptr::null_mut() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
let proc = GetProcAddress(mscoree, b"_CorExeMain\0".as_ptr() as _);
|
|
||||||
if proc == ptr::null_mut() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
// Side effect: the global slot is updated before the pair is returned.
COR_EXE_MAIN = mem::transmute(proc);
|
|
||||||
Some((&mut COR_EXE_MAIN as *mut _ as _, zluda_main_clr as _))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> {
|
|
||||||
match get_payload(&PAYLOAD_NVCUDA_GUID) {
|
match get_payload(&PAYLOAD_NVCUDA_GUID) {
|
||||||
None => None,
|
None => None,
|
||||||
Some(nvcuda_payload) => match get_payload(&PAYLOAD_NVML_GUID) {
|
Some(nvcuda_payload) => match get_payload(&PAYLOAD_NVML_GUID) {
|
||||||
@ -987,22 +744,17 @@ fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_payload(guid: &detours_sys::GUID) -> Option<&'static [u16]> {
|
fn get_payload(guid: &detours_sys::GUID) -> Option<&'static [u8]> {
|
||||||
let mut module = ptr::null_mut();
|
let mut size = 0;
|
||||||
loop {
|
let payload_ptr = unsafe { detours_sys::DetourFindPayloadEx(guid, &mut size) };
|
||||||
module = unsafe { detours_sys::DetourEnumerateModules(module) };
|
if payload_ptr != ptr::null_mut() {
|
||||||
if module == ptr::null_mut() {
|
Some(unsafe {
|
||||||
return None;
|
slice::from_raw_parts(
|
||||||
}
|
payload_ptr as *const _,
|
||||||
let mut size = 0;
|
(size as usize) / mem::size_of::<u16>(),
|
||||||
let payload_ptr = unsafe { detours_sys::DetourFindPayload(module, guid, &mut size) };
|
)
|
||||||
if payload_ptr != ptr::null_mut() {
|
})
|
||||||
return Some(unsafe {
|
} else {
|
||||||
slice::from_raw_parts(
|
None
|
||||||
payload_ptr as *const _,
|
|
||||||
(size as usize) / mem::size_of::<u16>(),
|
|
||||||
)
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user