From b8bcbec295b0d989c95bf3f8b66260dd95ff5ed1 Mon Sep 17 00:00:00 2001 From: Violet Date: Wed, 30 Jul 2025 15:09:47 -0700 Subject: [PATCH] Always use Unix line endings (#453) --- .rustfmt.toml | 1 + llvm_zluda/build.rs | 294 ++-- llvm_zluda/src/lib.rs | 162 +- ptx/src/pass/deparamize_functions.rs | 382 ++--- ptx/src/pass/expand_operands.rs | 602 ++++---- ptx/src/pass/fix_special_registers2.rs | 416 ++--- ptx/src/pass/hoist_globals.rs | 90 +- ptx/src/pass/insert_explicit_load_store.rs | 808 +++++----- ptx/src/pass/insert_implicit_conversions2.rs | 802 +++++----- ptx/src/pass/normalize_identifiers2.rs | 388 ++--- ptx/src/pass/normalize_predicates2.rs | 180 +-- ...eplace_instructions_with_function_calls.rs | 536 +++---- ptx/src/pass/replace_known_functions.rs | 66 +- ptx/src/pass/resolve_function_pointers.rs | 138 +- xtask/src/main.rs | 654 ++++---- zluda/src/impl/driver.rs | 852 +++++------ zluda/src/impl/os_unix.rs | 18 +- zluda/src/impl/os_win.rs | 18 +- zluda_dump/src/dark_api.rs | 248 +-- zluda_dump/src/log.rs | 1336 ++++++++--------- zluda_dump/src/os_unix.rs | 162 +- zluda_dump/src/os_win.rs | 380 ++--- zluda_dump/src/trace.rs | 668 ++++----- zluda_inject/build.rs | 162 +- zluda_inject/src/bin.rs | 622 ++++---- zluda_inject/src/main.rs | 26 +- zluda_inject/src/win.rs | 302 ++-- zluda_inject/tests/inject.rs | 102 +- 28 files changed, 5208 insertions(+), 5207 deletions(-) create mode 100644 .rustfmt.toml diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..43d4840 --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1 @@ +newline_style = "Unix" diff --git a/llvm_zluda/build.rs b/llvm_zluda/build.rs index 9505a86..ac9da3f 100644 --- a/llvm_zluda/build.rs +++ b/llvm_zluda/build.rs @@ -1,147 +1,147 @@ -use cmake::Config; -use std::io; -use std::path::PathBuf; -use std::process::Command; - -const COMPONENTS: &[&'static str] = &[ - "LLVMCore", - "LLVMBitWriter", - #[cfg(debug_assertions)] - "LLVMAnalysis", // for module verify - #[cfg(debug_assertions)] - "LLVMBitReader", -]; - -fn main() { - let mut cmake = Config::new(r"../ext/llvm-project/llvm"); - try_use_sccache(&mut cmake); - try_use_ninja(&mut cmake); - cmake - // It's not like we can do anything about the warnings - .define("LLVM_ENABLE_WARNINGS", "OFF") - // For some reason Rust always links to release CRT - .define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreaded") - .define("LLVM_ENABLE_TERMINFO", "OFF") - .define("LLVM_ENABLE_LIBXML2", "OFF") - .define("LLVM_ENABLE_LIBEDIT", "OFF") - .define("LLVM_ENABLE_LIBPFM", "OFF") - .define("LLVM_ENABLE_ZLIB", "OFF") - .define("LLVM_ENABLE_ZSTD", "OFF") - .define("LLVM_INCLUDE_BENCHMARKS", "OFF") - .define("LLVM_INCLUDE_EXAMPLES", "OFF") - .define("LLVM_INCLUDE_TESTS", "OFF") - .define("LLVM_BUILD_TOOLS", "OFF") - .define("LLVM_TARGETS_TO_BUILD", "") - .define("LLVM_ENABLE_PROJECTS", ""); - cmake.build_target("llvm-config"); - let llvm_dir = cmake.build(); - for c in COMPONENTS { - cmake.build_target(c); - cmake.build(); - } - let cmake_profile = cmake.get_profile(); - let (cxxflags, ldflags, libdir, lib_names, system_libs) = - llvm_config(&llvm_dir, &["build", "bin", "llvm-config"]) - .or_else(|_| llvm_config(&llvm_dir, &["build", cmake_profile, "bin", "llvm-config"])) - .unwrap(); - println!("cargo:rustc-link-arg={ldflags}"); - println!("cargo:rustc-link-search=native={libdir}"); - for lib in system_libs.split_ascii_whitespace() { - println!("cargo:rustc-link-arg={lib}"); - } - link_llvm_components(lib_names); - compile_cxx_lib(cxxflags); -} - -// https://github.com/mozilla/sccache/blob/main/README.md#usage -fn try_use_sccache(cmake: &mut Config) { - if let Ok(sccache) = std::env::var("SCCACHE_PATH") { - cmake.define("CMAKE_CXX_COMPILER_LAUNCHER", &*sccache); - cmake.define("CMAKE_C_COMPILER_LAUNCHER", &*sccache); - match std::env::var_os("CARGO_CFG_TARGET_OS") { - Some(os) if os == "windows" => { - cmake.define("CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "Embedded"); - cmake.define("CMAKE_POLICY_CMP0141", "NEW"); - } - _ => {} - } - } -} - -fn try_use_ninja(cmake: &mut Config) { - let mut cmd = Command::new("ninja"); - cmd.arg("--version"); - if let Ok(status) = cmd.status() { - if status.success() { - cmake.generator("Ninja"); - } - } -} - -fn llvm_config( - llvm_build_dir: &PathBuf, - path_to_llvm_config: &[&str], -) -> io::Result<(String, String, String, String, String)> { - let mut llvm_build_path = llvm_build_dir.clone(); - llvm_build_path.extend(path_to_llvm_config); - let mut cmd = Command::new(llvm_build_path); - cmd.args([ - "--link-static", - "--cxxflags", - "--ldflags", - "--libdir", - "--libnames", - "--system-libs", - ]); - for c in COMPONENTS { - cmd.arg(c[4..].to_lowercase()); - } - let output = cmd.output()?; - if !output.status.success() { - return Err(io::Error::from(io::ErrorKind::Other)); - } - let output = unsafe { String::from_utf8_unchecked(output.stdout) }; - let mut lines = output.lines(); - let cxxflags = lines.next().unwrap(); - let ldflags = lines.next().unwrap(); - let libdir = lines.next().unwrap(); - let lib_names = lines.next().unwrap(); - let system_libs = lines.next().unwrap(); - Ok(( - cxxflags.to_string(), - ldflags.to_string(), - libdir.to_string(), - lib_names.to_string(), - system_libs.to_string(), - )) -} - -fn compile_cxx_lib(cxxflags: String) { - let mut cc = cc::Build::new(); - for flag in cxxflags.split_whitespace() { - cc.flag(flag); - } - cc.cpp(true).file("src/lib.cpp").compile("llvm_zluda_cpp"); - println!("cargo:rerun-if-changed=src/lib.cpp"); - println!("cargo:rerun-if-changed=src/lib.rs"); -} - -fn link_llvm_components(components: String) { - for component in components.split_whitespace() { - let component = if let Some(component) = component - .strip_prefix("lib") - .and_then(|component| component.strip_suffix(".a")) - { - // Unix (Linux/Mac) - // libLLVMfoo.a - component - } else if let Some(component) = component.strip_suffix(".lib") { - // Windows - // LLVMfoo.lib - component - } else { - panic!("'{}' does not look like a static library name", component) - }; - println!("cargo:rustc-link-lib={component}"); - } -} +use cmake::Config; +use std::io; +use std::path::PathBuf; +use std::process::Command; + +const COMPONENTS: &[&'static str] = &[ + "LLVMCore", + "LLVMBitWriter", + #[cfg(debug_assertions)] + "LLVMAnalysis", // for module verify + #[cfg(debug_assertions)] + "LLVMBitReader", +]; + +fn main() { + let mut cmake = Config::new(r"../ext/llvm-project/llvm"); + try_use_sccache(&mut cmake); + try_use_ninja(&mut cmake); + cmake + // It's not like we can do anything about the warnings + .define("LLVM_ENABLE_WARNINGS", "OFF") + // For some reason Rust always links to release CRT + .define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreaded") + .define("LLVM_ENABLE_TERMINFO", "OFF") + .define("LLVM_ENABLE_LIBXML2", "OFF") + .define("LLVM_ENABLE_LIBEDIT", "OFF") + .define("LLVM_ENABLE_LIBPFM", "OFF") + .define("LLVM_ENABLE_ZLIB", "OFF") + .define("LLVM_ENABLE_ZSTD", "OFF") + .define("LLVM_INCLUDE_BENCHMARKS", "OFF") + .define("LLVM_INCLUDE_EXAMPLES", "OFF") + .define("LLVM_INCLUDE_TESTS", "OFF") + .define("LLVM_BUILD_TOOLS", "OFF") + .define("LLVM_TARGETS_TO_BUILD", "") + .define("LLVM_ENABLE_PROJECTS", ""); + cmake.build_target("llvm-config"); + let llvm_dir = cmake.build(); + for c in COMPONENTS { + cmake.build_target(c); + cmake.build(); + } + let cmake_profile = cmake.get_profile(); + let (cxxflags, ldflags, libdir, lib_names, system_libs) = + llvm_config(&llvm_dir, &["build", "bin", "llvm-config"]) + .or_else(|_| llvm_config(&llvm_dir, &["build", cmake_profile, "bin", "llvm-config"])) + .unwrap(); + println!("cargo:rustc-link-arg={ldflags}"); + println!("cargo:rustc-link-search=native={libdir}"); + for lib in system_libs.split_ascii_whitespace() { + println!("cargo:rustc-link-arg={lib}"); + } + link_llvm_components(lib_names); + compile_cxx_lib(cxxflags); +} + +// https://github.com/mozilla/sccache/blob/main/README.md#usage +fn try_use_sccache(cmake: &mut Config) { + if let Ok(sccache) = std::env::var("SCCACHE_PATH") { + cmake.define("CMAKE_CXX_COMPILER_LAUNCHER", &*sccache); + cmake.define("CMAKE_C_COMPILER_LAUNCHER", &*sccache); + match std::env::var_os("CARGO_CFG_TARGET_OS") { + Some(os) if os == "windows" => { + cmake.define("CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "Embedded"); + cmake.define("CMAKE_POLICY_CMP0141", "NEW"); + } + _ => {} + } + } +} + +fn try_use_ninja(cmake: &mut Config) { + let mut cmd = Command::new("ninja"); + cmd.arg("--version"); + if let Ok(status) = cmd.status() { + if status.success() { + cmake.generator("Ninja"); + } + } +} + +fn llvm_config( + llvm_build_dir: &PathBuf, + path_to_llvm_config: &[&str], +) -> io::Result<(String, String, String, String, String)> { + let mut llvm_build_path = llvm_build_dir.clone(); + llvm_build_path.extend(path_to_llvm_config); + let mut cmd = Command::new(llvm_build_path); + cmd.args([ + "--link-static", + "--cxxflags", + "--ldflags", + "--libdir", + "--libnames", + "--system-libs", + ]); + for c in COMPONENTS { + cmd.arg(c[4..].to_lowercase()); + } + let output = cmd.output()?; + if !output.status.success() { + return Err(io::Error::from(io::ErrorKind::Other)); + } + let output = unsafe { String::from_utf8_unchecked(output.stdout) }; + let mut lines = output.lines(); + let cxxflags = lines.next().unwrap(); + let ldflags = lines.next().unwrap(); + let libdir = lines.next().unwrap(); + let lib_names = lines.next().unwrap(); + let system_libs = lines.next().unwrap(); + Ok(( + cxxflags.to_string(), + ldflags.to_string(), + libdir.to_string(), + lib_names.to_string(), + system_libs.to_string(), + )) +} + +fn compile_cxx_lib(cxxflags: String) { + let mut cc = cc::Build::new(); + for flag in cxxflags.split_whitespace() { + cc.flag(flag); + } + cc.cpp(true).file("src/lib.cpp").compile("llvm_zluda_cpp"); + println!("cargo:rerun-if-changed=src/lib.cpp"); + println!("cargo:rerun-if-changed=src/lib.rs"); +} + +fn link_llvm_components(components: String) { + for component in components.split_whitespace() { + let component = if let Some(component) = component + .strip_prefix("lib") + .and_then(|component| component.strip_suffix(".a")) + { + // Unix (Linux/Mac) + // libLLVMfoo.a + component + } else if let Some(component) = component.strip_suffix(".lib") { + // Windows + // LLVMfoo.lib + component + } else { + panic!("'{}' does not look like a static library name", component) + }; + println!("cargo:rustc-link-lib={component}"); + } +} diff --git a/llvm_zluda/src/lib.rs b/llvm_zluda/src/lib.rs index fb5cc47..18046a5 100644 --- a/llvm_zluda/src/lib.rs +++ b/llvm_zluda/src/lib.rs @@ -1,81 +1,81 @@ -#![allow(non_upper_case_globals)] -use llvm_sys::prelude::*; -pub use llvm_sys::*; - -#[repr(C)] -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum LLVMZludaAtomicRMWBinOp { - LLVMZludaAtomicRMWBinOpXchg = 0, - LLVMZludaAtomicRMWBinOpAdd = 1, - LLVMZludaAtomicRMWBinOpSub = 2, - LLVMZludaAtomicRMWBinOpAnd = 3, - LLVMZludaAtomicRMWBinOpNand = 4, - LLVMZludaAtomicRMWBinOpOr = 5, - LLVMZludaAtomicRMWBinOpXor = 6, - LLVMZludaAtomicRMWBinOpMax = 7, - LLVMZludaAtomicRMWBinOpMin = 8, - LLVMZludaAtomicRMWBinOpUMax = 9, - LLVMZludaAtomicRMWBinOpUMin = 10, - LLVMZludaAtomicRMWBinOpFAdd = 11, - LLVMZludaAtomicRMWBinOpFSub = 12, - LLVMZludaAtomicRMWBinOpFMax = 13, - LLVMZludaAtomicRMWBinOpFMin = 14, - LLVMZludaAtomicRMWBinOpUIncWrap = 15, - LLVMZludaAtomicRMWBinOpUDecWrap = 16, -} - -// Backport from LLVM 19 -pub const LLVMZludaFastMathAllowReassoc: ::std::ffi::c_uint = 1 << 0; -pub const LLVMZludaFastMathNoNaNs: ::std::ffi::c_uint = 1 << 1; -pub const LLVMZludaFastMathNoInfs: ::std::ffi::c_uint = 1 << 2; -pub const LLVMZludaFastMathNoSignedZeros: ::std::ffi::c_uint = 1 << 3; -pub const LLVMZludaFastMathAllowReciprocal: ::std::ffi::c_uint = 1 << 4; -pub const LLVMZludaFastMathAllowContract: ::std::ffi::c_uint = 1 << 5; -pub const LLVMZludaFastMathApproxFunc: ::std::ffi::c_uint = 1 << 6; -pub const LLVMZludaFastMathNone: ::std::ffi::c_uint = 0; -pub const LLVMZludaFastMathAll: ::std::ffi::c_uint = LLVMZludaFastMathAllowReassoc - | LLVMZludaFastMathNoNaNs - | LLVMZludaFastMathNoInfs - | LLVMZludaFastMathNoSignedZeros - | LLVMZludaFastMathAllowReciprocal - | LLVMZludaFastMathAllowContract - | LLVMZludaFastMathApproxFunc; - -pub type LLVMZludaFastMathFlags = std::ffi::c_uint; - -extern "C" { - pub fn LLVMZludaBuildAlloca( - B: LLVMBuilderRef, - Ty: LLVMTypeRef, - AddrSpace: u32, - Name: *const i8, - ) -> LLVMValueRef; - - pub fn LLVMZludaBuildAtomicRMW( - B: LLVMBuilderRef, - op: LLVMZludaAtomicRMWBinOp, - PTR: LLVMValueRef, - Val: LLVMValueRef, - scope: *const i8, - ordering: LLVMAtomicOrdering, - ) -> LLVMValueRef; - - pub fn LLVMZludaBuildAtomicCmpXchg( - B: LLVMBuilderRef, - Ptr: LLVMValueRef, - Cmp: LLVMValueRef, - New: LLVMValueRef, - scope: *const i8, - SuccessOrdering: LLVMAtomicOrdering, - FailureOrdering: LLVMAtomicOrdering, - ) -> LLVMValueRef; - - pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags); - - pub fn LLVMZludaBuildFence( - B: LLVMBuilderRef, - ordering: LLVMAtomicOrdering, - scope: *const i8, - Name: *const i8, - ) -> LLVMValueRef; -} +#![allow(non_upper_case_globals)] +use llvm_sys::prelude::*; +pub use llvm_sys::*; + +#[repr(C)] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum LLVMZludaAtomicRMWBinOp { + LLVMZludaAtomicRMWBinOpXchg = 0, + LLVMZludaAtomicRMWBinOpAdd = 1, + LLVMZludaAtomicRMWBinOpSub = 2, + LLVMZludaAtomicRMWBinOpAnd = 3, + LLVMZludaAtomicRMWBinOpNand = 4, + LLVMZludaAtomicRMWBinOpOr = 5, + LLVMZludaAtomicRMWBinOpXor = 6, + LLVMZludaAtomicRMWBinOpMax = 7, + LLVMZludaAtomicRMWBinOpMin = 8, + LLVMZludaAtomicRMWBinOpUMax = 9, + LLVMZludaAtomicRMWBinOpUMin = 10, + LLVMZludaAtomicRMWBinOpFAdd = 11, + LLVMZludaAtomicRMWBinOpFSub = 12, + LLVMZludaAtomicRMWBinOpFMax = 13, + LLVMZludaAtomicRMWBinOpFMin = 14, + LLVMZludaAtomicRMWBinOpUIncWrap = 15, + LLVMZludaAtomicRMWBinOpUDecWrap = 16, +} + +// Backport from LLVM 19 +pub const LLVMZludaFastMathAllowReassoc: ::std::ffi::c_uint = 1 << 0; +pub const LLVMZludaFastMathNoNaNs: ::std::ffi::c_uint = 1 << 1; +pub const LLVMZludaFastMathNoInfs: ::std::ffi::c_uint = 1 << 2; +pub const LLVMZludaFastMathNoSignedZeros: ::std::ffi::c_uint = 1 << 3; +pub const LLVMZludaFastMathAllowReciprocal: ::std::ffi::c_uint = 1 << 4; +pub const LLVMZludaFastMathAllowContract: ::std::ffi::c_uint = 1 << 5; +pub const LLVMZludaFastMathApproxFunc: ::std::ffi::c_uint = 1 << 6; +pub const LLVMZludaFastMathNone: ::std::ffi::c_uint = 0; +pub const LLVMZludaFastMathAll: ::std::ffi::c_uint = LLVMZludaFastMathAllowReassoc + | LLVMZludaFastMathNoNaNs + | LLVMZludaFastMathNoInfs + | LLVMZludaFastMathNoSignedZeros + | LLVMZludaFastMathAllowReciprocal + | LLVMZludaFastMathAllowContract + | LLVMZludaFastMathApproxFunc; + +pub type LLVMZludaFastMathFlags = std::ffi::c_uint; + +extern "C" { + pub fn LLVMZludaBuildAlloca( + B: LLVMBuilderRef, + Ty: LLVMTypeRef, + AddrSpace: u32, + Name: *const i8, + ) -> LLVMValueRef; + + pub fn LLVMZludaBuildAtomicRMW( + B: LLVMBuilderRef, + op: LLVMZludaAtomicRMWBinOp, + PTR: LLVMValueRef, + Val: LLVMValueRef, + scope: *const i8, + ordering: LLVMAtomicOrdering, + ) -> LLVMValueRef; + + pub fn LLVMZludaBuildAtomicCmpXchg( + B: LLVMBuilderRef, + Ptr: LLVMValueRef, + Cmp: LLVMValueRef, + New: LLVMValueRef, + scope: *const i8, + SuccessOrdering: LLVMAtomicOrdering, + FailureOrdering: LLVMAtomicOrdering, + ) -> LLVMValueRef; + + pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags); + + pub fn LLVMZludaBuildFence( + B: LLVMBuilderRef, + ordering: LLVMAtomicOrdering, + scope: *const i8, + Name: *const i8, + ) -> LLVMValueRef; +} diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs index e203394..e80f6a3 100644 --- a/ptx/src/pass/deparamize_functions.rs +++ b/ptx/src/pass/deparamize_functions.rs @@ -1,191 +1,191 @@ -use super::*; - -pub(super) fn run<'a, 'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - directives: Vec, SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { - directives - .into_iter() - .map(|directive| run_directive(resolver, directive)) - .collect::, _>>() -} - -fn run_directive<'input>( - resolver: &mut GlobalStringIdentResolver2, - directive: Directive2, SpirvWord>, -) -> Result, SpirvWord>, TranslateError> { - Ok(match directive { - var @ Directive2::Variable(..) => var, - Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), - }) -} - -fn run_method<'input>( - resolver: &mut GlobalStringIdentResolver2, - mut method: Function2, SpirvWord>, -) -> Result, SpirvWord>, TranslateError> { - let is_declaration = method.body.is_none(); - let mut body = Vec::new(); - let mut remap_returns = Vec::new(); - if !method.is_kernel { - for arg in method.return_arguments.iter_mut() { - match arg.state_space { - ptx_parser::StateSpace::Param => { - arg.state_space = ptx_parser::StateSpace::Reg; - let old_name = arg.name; - arg.name = - resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); - if is_declaration { - continue; - } - remap_returns.push((old_name, arg.name, arg.v_type.clone())); - body.push(Statement::Variable(ast::Variable { - align: None, - name: old_name, - v_type: arg.v_type.clone(), - state_space: ptx_parser::StateSpace::Param, - array_init: Vec::new(), - })); - } - ptx_parser::StateSpace::Reg => {} - _ => return Err(error_unreachable()), - } - } - for arg in method.input_arguments.iter_mut() { - match arg.state_space { - ptx_parser::StateSpace::Param => { - arg.state_space = ptx_parser::StateSpace::Reg; - let old_name = arg.name; - arg.name = - resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); - if is_declaration { - continue; - } - body.push(Statement::Variable(ast::Variable { - align: None, - name: old_name, - v_type: arg.v_type.clone(), - state_space: ptx_parser::StateSpace::Param, - array_init: Vec::new(), - })); - body.push(Statement::Instruction(ast::Instruction::St { - data: ast::StData { - qualifier: ast::LdStQualifier::Weak, - state_space: ast::StateSpace::Param, - caching: ast::StCacheOperator::Writethrough, - typ: arg.v_type.clone(), - }, - arguments: ast::StArgs { - src1: old_name, - src2: arg.name, - }, - })); - } - ptx_parser::StateSpace::Reg => {} - _ => return Err(error_unreachable()), - } - } - } - let body = method - .body - .map(|statements| { - for statement in statements { - run_statement(resolver, &remap_returns, &mut body, statement)?; - } - Ok::<_, TranslateError>(body) - }) - .transpose()?; - Ok(Function2 { body, ..method }) -} - -fn run_statement<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>, - result: &mut Vec, SpirvWord>>, - statement: Statement, SpirvWord>, -) -> Result<(), TranslateError> { - match statement { - Statement::Instruction(ast::Instruction::Call { - mut data, - mut arguments, - }) => { - let mut post_st = Vec::new(); - for ((type_, space), ident) in data - .input_arguments - .iter_mut() - .zip(arguments.input_arguments.iter_mut()) - { - if *space == ptx_parser::StateSpace::Param { - *space = ptx_parser::StateSpace::Reg; - let old_name = *ident; - *ident = resolver - .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); - result.push(Statement::Instruction(ast::Instruction::Ld { - data: ast::LdDetails { - qualifier: ast::LdStQualifier::Weak, - state_space: ast::StateSpace::Param, - caching: ast::LdCacheOperator::Cached, - typ: type_.clone(), - non_coherent: false, - }, - arguments: ast::LdArgs { - dst: *ident, - src: old_name, - }, - })); - } - } - for ((type_, space), ident) in data - .return_arguments - .iter_mut() - .zip(arguments.return_arguments.iter_mut()) - { - if *space == ptx_parser::StateSpace::Param { - *space = ptx_parser::StateSpace::Reg; - let old_name = *ident; - *ident = resolver - .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); - post_st.push(Statement::Instruction(ast::Instruction::St { - data: ast::StData { - qualifier: ast::LdStQualifier::Weak, - state_space: ast::StateSpace::Param, - caching: ast::StCacheOperator::Writethrough, - typ: type_.clone(), - }, - arguments: ast::StArgs { - src1: old_name, - src2: *ident, - }, - })); - } - } - result.push(Statement::Instruction(ast::Instruction::Call { - data, - arguments, - })); - result.extend(post_st.into_iter()); - } - Statement::Instruction(ast::Instruction::Ret { data }) => { - for (old_name, new_name, type_) in remap_returns.iter() { - result.push(Statement::Instruction(ast::Instruction::Ld { - data: ast::LdDetails { - qualifier: ast::LdStQualifier::Weak, - state_space: ast::StateSpace::Param, - caching: ast::LdCacheOperator::Cached, - typ: type_.clone(), - non_coherent: false, - }, - arguments: ast::LdArgs { - dst: *new_name, - src: *old_name, - }, - })); - } - result.push(Statement::Instruction(ast::Instruction::Ret { data })); - } - statement => { - result.push(statement); - } - } - Ok(()) -} +use super::*; + +pub(super) fn run<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2, + directive: Directive2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), + }) +} + +fn run_method<'input>( + resolver: &mut GlobalStringIdentResolver2, + mut method: Function2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + let is_declaration = method.body.is_none(); + let mut body = Vec::new(); + let mut remap_returns = Vec::new(); + if !method.is_kernel { + for arg in method.return_arguments.iter_mut() { + match arg.state_space { + ptx_parser::StateSpace::Param => { + arg.state_space = ptx_parser::StateSpace::Reg; + let old_name = arg.name; + arg.name = + resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); + if is_declaration { + continue; + } + remap_returns.push((old_name, arg.name, arg.v_type.clone())); + body.push(Statement::Variable(ast::Variable { + align: None, + name: old_name, + v_type: arg.v_type.clone(), + state_space: ptx_parser::StateSpace::Param, + array_init: Vec::new(), + })); + } + ptx_parser::StateSpace::Reg => {} + _ => return Err(error_unreachable()), + } + } + for arg in method.input_arguments.iter_mut() { + match arg.state_space { + ptx_parser::StateSpace::Param => { + arg.state_space = ptx_parser::StateSpace::Reg; + let old_name = arg.name; + arg.name = + resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); + if is_declaration { + continue; + } + body.push(Statement::Variable(ast::Variable { + align: None, + name: old_name, + v_type: arg.v_type.clone(), + state_space: ptx_parser::StateSpace::Param, + array_init: Vec::new(), + })); + body.push(Statement::Instruction(ast::Instruction::St { + data: ast::StData { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::StCacheOperator::Writethrough, + typ: arg.v_type.clone(), + }, + arguments: ast::StArgs { + src1: old_name, + src2: arg.name, + }, + })); + } + ptx_parser::StateSpace::Reg => {} + _ => return Err(error_unreachable()), + } + } + } + let body = method + .body + .map(|statements| { + for statement in statements { + run_statement(resolver, &remap_returns, &mut body, statement)?; + } + Ok::<_, TranslateError>(body) + }) + .transpose()?; + Ok(Function2 { body, ..method }) +} + +fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>, + result: &mut Vec, SpirvWord>>, + statement: Statement, SpirvWord>, +) -> Result<(), TranslateError> { + match statement { + Statement::Instruction(ast::Instruction::Call { + mut data, + mut arguments, + }) => { + let mut post_st = Vec::new(); + for ((type_, space), ident) in data + .input_arguments + .iter_mut() + .zip(arguments.input_arguments.iter_mut()) + { + if *space == ptx_parser::StateSpace::Param { + *space = ptx_parser::StateSpace::Reg; + let old_name = *ident; + *ident = resolver + .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); + result.push(Statement::Instruction(ast::Instruction::Ld { + data: ast::LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::LdCacheOperator::Cached, + typ: type_.clone(), + non_coherent: false, + }, + arguments: ast::LdArgs { + dst: *ident, + src: old_name, + }, + })); + } + } + for ((type_, space), ident) in data + .return_arguments + .iter_mut() + .zip(arguments.return_arguments.iter_mut()) + { + if *space == ptx_parser::StateSpace::Param { + *space = ptx_parser::StateSpace::Reg; + let old_name = *ident; + *ident = resolver + .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); + post_st.push(Statement::Instruction(ast::Instruction::St { + data: ast::StData { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::StCacheOperator::Writethrough, + typ: type_.clone(), + }, + arguments: ast::StArgs { + src1: old_name, + src2: *ident, + }, + })); + } + } + result.push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })); + result.extend(post_st.into_iter()); + } + Statement::Instruction(ast::Instruction::Ret { data }) => { + for (old_name, new_name, type_) in remap_returns.iter() { + result.push(Statement::Instruction(ast::Instruction::Ld { + data: ast::LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::LdCacheOperator::Cached, + typ: type_.clone(), + non_coherent: false, + }, + arguments: ast::LdArgs { + dst: *new_name, + src: *old_name, + }, + })); + } + result.push(Statement::Instruction(ast::Instruction::Ret { data })); + } + statement => { + result.push(statement); + } + } + Ok(()) +} diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs index a9ede33..b21c343 100644 --- a/ptx/src/pass/expand_operands.rs +++ b/ptx/src/pass/expand_operands.rs @@ -1,301 +1,301 @@ -use super::*; - -pub(super) fn run<'a, 'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - directives: Vec, -) -> Result, SpirvWord>>, TranslateError> { - directives - .into_iter() - .map(|directive| run_directive(resolver, directive)) - .collect::, _>>() -} - -fn run_directive<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - directive: Directive2< - ast::Instruction>, - ast::ParsedOperand, - >, -) -> Result, SpirvWord>, TranslateError> { - Ok(match directive { - Directive2::Variable(linking, var) => Directive2::Variable(linking, var), - Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), - }) -} - -fn run_method<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - method: Function2< - ast::Instruction>, - ast::ParsedOperand, - >, -) -> Result, SpirvWord>, TranslateError> { - let body = method - .body - .map(|statements| { - let mut result = Vec::with_capacity(statements.len()); - for statement in statements { - run_statement(resolver, &mut result, statement)?; - } - Ok::<_, TranslateError>(result) - }) - .transpose()?; - Ok(Function2 { - body, - return_arguments: method.return_arguments, - name: method.name, - input_arguments: method.input_arguments, - import_as: method.import_as, - tuning: method.tuning, - linkage: method.linkage, - is_kernel: method.is_kernel, - flush_to_zero_f32: method.flush_to_zero_f32, - flush_to_zero_f16f64: method.flush_to_zero_f16f64, - rounding_mode_f32: method.rounding_mode_f32, - rounding_mode_f16f64: method.rounding_mode_f16f64, - }) -} - -fn run_statement<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - result: &mut Vec, SpirvWord>>, - statement: UnconditionalStatement, -) -> Result<(), TranslateError> { - let mut visitor = FlattenArguments::new(resolver, result); - let new_statement = statement.visit_map(&mut visitor)?; - visitor.result.push(new_statement); - Ok(()) -} - -struct FlattenArguments<'a, 'input> { - result: &'a mut Vec, - resolver: &'a mut GlobalStringIdentResolver2<'input>, - post_stmts: Vec, -} - -impl<'a, 'input> FlattenArguments<'a, 'input> { - fn new( - resolver: &'a mut GlobalStringIdentResolver2<'input>, - result: &'a mut Vec, - ) -> Self { - FlattenArguments { - result, - resolver, - post_stmts: Vec::new(), - } - } - - fn reg(&mut self, name: SpirvWord) -> Result { - Ok(name) - } - - fn reg_offset( - &mut self, - reg: SpirvWord, - offset: i32, - type_space: Option<(&ast::Type, ast::StateSpace)>, - _is_dst: bool, - ) -> Result { - let (type_, state_space) = if let Some((type_, state_space)) = type_space { - (type_, state_space) - } else { - return Err(TranslateError::UntypedSymbol); - }; - if state_space == ast::StateSpace::Reg { - let (reg_type, reg_space) = self.resolver.get_typed(reg)?; - if *reg_space != ast::StateSpace::Reg { - return Err(error_mismatched_type()); - } - let reg_scalar_type = match reg_type { - ast::Type::Scalar(underlying_type) => *underlying_type, - _ => return Err(error_mismatched_type()), - }; - let reg_type = reg_type.clone(); - let id_constant_stmt = self - .resolver - .register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg))); - self.result.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: reg_scalar_type, - value: ast::ImmediateValue::S64(offset as i64), - })); - let arith_details = match reg_scalar_type.kind() { - ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger { - type_: reg_scalar_type, - saturate: false, - }), - ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => { - ast::ArithDetails::Integer(ast::ArithInteger { - type_: reg_scalar_type, - saturate: false, - }) - } - _ => return Err(error_unreachable()), - }; - let id_add_result = self - .resolver - .register_unnamed(Some((reg_type, state_space))); - self.result - .push(Statement::Instruction(ast::Instruction::Add { - data: arith_details, - arguments: ast::AddArgs { - dst: id_add_result, - src1: reg, - src2: id_constant_stmt, - }, - })); - Ok(id_add_result) - } else { - let id_constant_stmt = self.resolver.register_unnamed(Some(( - ast::Type::Scalar(ast::ScalarType::S64), - ast::StateSpace::Reg, - ))); - self.result.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: ast::ScalarType::S64, - value: ast::ImmediateValue::S64(offset as i64), - })); - let dst = self - .resolver - .register_unnamed(Some((type_.clone(), state_space))); - self.result.push(Statement::PtrAccess(PtrAccess { - underlying_type: type_.clone(), - state_space: state_space, - dst, - ptr_src: reg, - offset_src: id_constant_stmt, - })); - Ok(dst) - } - } - - fn immediate( - &mut self, - value: ast::ImmediateValue, - type_space: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - let (scalar_t, state_space) = - if let Some((ast::Type::Scalar(scalar), state_space)) = type_space { - (*scalar, state_space) - } else { - return Err(TranslateError::UntypedSymbol); - }; - let id = self - .resolver - .register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space))); - self.result.push(Statement::Constant(ConstantDefinition { - dst: id, - typ: scalar_t, - value, - })); - Ok(id) - } - - fn vec_member( - &mut self, - vector_ident: SpirvWord, - member: u8, - _type_space: Option<(&ast::Type, ast::StateSpace)>, - is_dst: bool, - ) -> Result { - let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? { - (ast::Type::Vector(vector_width, scalar_t), space) => { - (*vector_width, *scalar_t, *space) - } - _ => return Err(error_mismatched_type()), - }; - let temporary = self - .resolver - .register_unnamed(Some((scalar_type.into(), space))); - if is_dst { - self.post_stmts.push(Statement::VectorWrite(VectorWrite { - scalar_type, - vector_width, - vector_dst: vector_ident, - vector_src: vector_ident, - scalar_src: temporary, - member, - })); - } else { - self.result.push(Statement::VectorRead(VectorRead { - scalar_type, - vector_width, - scalar_dst: temporary, - vector_src: vector_ident, - member, - })); - } - Ok(temporary) - } - - fn vec_pack( - &mut self, - vector_elements: Vec, - type_space: Option<(&ast::Type, ast::StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result { - let (width, scalar_t, state_space) = match type_space { - Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space), - _ => return Err(error_mismatched_type()), - }; - let temporary_vector = self - .resolver - .register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space))); - let statement = Statement::RepackVector(RepackVectorDetails { - is_extract: is_dst, - typ: scalar_t, - packed: temporary_vector, - unpacked: vector_elements, - relaxed_type_check, - }); - if is_dst { - self.post_stmts.push(statement); - } else { - self.result.push(statement); - } - Ok(temporary_vector) - } -} - -impl<'a, 'b> ast::VisitorMap, SpirvWord, TranslateError> - for FlattenArguments<'a, 'b> -{ - fn visit( - &mut self, - args: ast::ParsedOperand, - type_space: Option<(&ast::Type, ast::StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result { - match args { - ast::ParsedOperand::Reg(r) => self.reg(r), - ast::ParsedOperand::Imm(x) => self.immediate(x, type_space), - ast::ParsedOperand::RegOffset(reg, offset) => { - self.reg_offset(reg, offset, type_space, is_dst) - } - ast::ParsedOperand::VecMember(vec, member) => { - self.vec_member(vec, member, type_space, is_dst) - } - ast::ParsedOperand::VecPack(vecs) => { - self.vec_pack(vecs, type_space, is_dst, relaxed_type_check) - } - } - } - - fn visit_ident( - &mut self, - name: SpirvWord, - _type_space: Option<(&ast::Type, ast::StateSpace)>, - _is_dst: bool, - _relaxed_type_check: bool, - ) -> Result<::Ident, TranslateError> { - self.reg(name) - } -} - -impl Drop for FlattenArguments<'_, '_> { - fn drop(&mut self) { - self.result.extend(self.post_stmts.drain(..)); - } -} +use super::*; + +pub(super) fn run<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: Directive2< + ast::Instruction>, + ast::ParsedOperand, + >, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + Directive2::Variable(linking, var) => Directive2::Variable(linking, var), + Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), + }) +} + +fn run_method<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + method: Function2< + ast::Instruction>, + ast::ParsedOperand, + >, +) -> Result, SpirvWord>, TranslateError> { + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(resolver, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { + body, + return_arguments: method.return_arguments, + name: method.name, + input_arguments: method.input_arguments, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + is_kernel: method.is_kernel, + flush_to_zero_f32: method.flush_to_zero_f32, + flush_to_zero_f16f64: method.flush_to_zero_f16f64, + rounding_mode_f32: method.rounding_mode_f32, + rounding_mode_f16f64: method.rounding_mode_f16f64, + }) +} + +fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + result: &mut Vec, SpirvWord>>, + statement: UnconditionalStatement, +) -> Result<(), TranslateError> { + let mut visitor = FlattenArguments::new(resolver, result); + let new_statement = statement.visit_map(&mut visitor)?; + visitor.result.push(new_statement); + Ok(()) +} + +struct FlattenArguments<'a, 'input> { + result: &'a mut Vec, + resolver: &'a mut GlobalStringIdentResolver2<'input>, + post_stmts: Vec, +} + +impl<'a, 'input> FlattenArguments<'a, 'input> { + fn new( + resolver: &'a mut GlobalStringIdentResolver2<'input>, + result: &'a mut Vec, + ) -> Self { + FlattenArguments { + result, + resolver, + post_stmts: Vec::new(), + } + } + + fn reg(&mut self, name: SpirvWord) -> Result { + Ok(name) + } + + fn reg_offset( + &mut self, + reg: SpirvWord, + offset: i32, + type_space: Option<(&ast::Type, ast::StateSpace)>, + _is_dst: bool, + ) -> Result { + let (type_, state_space) = if let Some((type_, state_space)) = type_space { + (type_, state_space) + } else { + return Err(TranslateError::UntypedSymbol); + }; + if state_space == ast::StateSpace::Reg { + let (reg_type, reg_space) = self.resolver.get_typed(reg)?; + if *reg_space != ast::StateSpace::Reg { + return Err(error_mismatched_type()); + } + let reg_scalar_type = match reg_type { + ast::Type::Scalar(underlying_type) => *underlying_type, + _ => return Err(error_mismatched_type()), + }; + let reg_type = reg_type.clone(); + let id_constant_stmt = self + .resolver + .register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: reg_scalar_type, + value: ast::ImmediateValue::S64(offset as i64), + })); + let arith_details = match reg_scalar_type.kind() { + ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, + saturate: false, + }), + ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => { + ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, + saturate: false, + }) + } + _ => return Err(error_unreachable()), + }; + let id_add_result = self + .resolver + .register_unnamed(Some((reg_type, state_space))); + self.result + .push(Statement::Instruction(ast::Instruction::Add { + data: arith_details, + arguments: ast::AddArgs { + dst: id_add_result, + src1: reg, + src2: id_constant_stmt, + }, + })); + Ok(id_add_result) + } else { + let id_constant_stmt = self.resolver.register_unnamed(Some(( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: ast::ScalarType::S64, + value: ast::ImmediateValue::S64(offset as i64), + })); + let dst = self + .resolver + .register_unnamed(Some((type_.clone(), state_space))); + self.result.push(Statement::PtrAccess(PtrAccess { + underlying_type: type_.clone(), + state_space: state_space, + dst, + ptr_src: reg, + offset_src: id_constant_stmt, + })); + Ok(dst) + } + } + + fn immediate( + &mut self, + value: ast::ImmediateValue, + type_space: Option<(&ast::Type, ast::StateSpace)>, + ) -> Result { + let (scalar_t, state_space) = + if let Some((ast::Type::Scalar(scalar), state_space)) = type_space { + (*scalar, state_space) + } else { + return Err(TranslateError::UntypedSymbol); + }; + let id = self + .resolver + .register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: id, + typ: scalar_t, + value, + })); + Ok(id) + } + + fn vec_member( + &mut self, + vector_ident: SpirvWord, + member: u8, + _type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + ) -> Result { + let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? { + (ast::Type::Vector(vector_width, scalar_t), space) => { + (*vector_width, *scalar_t, *space) + } + _ => return Err(error_mismatched_type()), + }; + let temporary = self + .resolver + .register_unnamed(Some((scalar_type.into(), space))); + if is_dst { + self.post_stmts.push(Statement::VectorWrite(VectorWrite { + scalar_type, + vector_width, + vector_dst: vector_ident, + vector_src: vector_ident, + scalar_src: temporary, + member, + })); + } else { + self.result.push(Statement::VectorRead(VectorRead { + scalar_type, + vector_width, + scalar_dst: temporary, + vector_src: vector_ident, + member, + })); + } + Ok(temporary) + } + + fn vec_pack( + &mut self, + vector_elements: Vec, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + let (width, scalar_t, state_space) = match type_space { + Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space), + _ => return Err(error_mismatched_type()), + }; + let temporary_vector = self + .resolver + .register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space))); + let statement = Statement::RepackVector(RepackVectorDetails { + is_extract: is_dst, + typ: scalar_t, + packed: temporary_vector, + unpacked: vector_elements, + relaxed_type_check, + }); + if is_dst { + self.post_stmts.push(statement); + } else { + self.result.push(statement); + } + Ok(temporary_vector) + } +} + +impl<'a, 'b> ast::VisitorMap, SpirvWord, TranslateError> + for FlattenArguments<'a, 'b> +{ + fn visit( + &mut self, + args: ast::ParsedOperand, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + match args { + ast::ParsedOperand::Reg(r) => self.reg(r), + ast::ParsedOperand::Imm(x) => self.immediate(x, type_space), + ast::ParsedOperand::RegOffset(reg, offset) => { + self.reg_offset(reg, offset, type_space, is_dst) + } + ast::ParsedOperand::VecMember(vec, member) => { + self.vec_member(vec, member, type_space, is_dst) + } + ast::ParsedOperand::VecPack(vecs) => { + self.vec_pack(vecs, type_space, is_dst, relaxed_type_check) + } + } + } + + fn visit_ident( + &mut self, + name: SpirvWord, + _type_space: Option<(&ast::Type, ast::StateSpace)>, + _is_dst: bool, + _relaxed_type_check: bool, + ) -> Result<::Ident, TranslateError> { + self.reg(name) + } +} + +impl Drop for FlattenArguments<'_, '_> { + fn drop(&mut self) { + self.result.extend(self.post_stmts.drain(..)); + } +} diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs index 78e66c9..fc9e028 100644 --- a/ptx/src/pass/fix_special_registers2.rs +++ b/ptx/src/pass/fix_special_registers2.rs @@ -1,208 +1,208 @@ -use super::*; - -pub(super) fn run<'a, 'input>( - resolver: &'a mut GlobalStringIdentResolver2<'input>, - special_registers: &'a SpecialRegistersMap2, - directives: Vec, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len()); - let mut sreg_to_function = - FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default()); - SpecialRegistersMap2::foreach_declaration( - resolver, - |sreg, (return_arguments, name, input_arguments)| { - result.push(UnconditionalDirective::Method(UnconditionalFunction { - return_arguments, - name, - input_arguments, - body: None, - import_as: None, - tuning: Vec::new(), - linkage: ast::LinkingDirective::EXTERN, - is_kernel: false, - flush_to_zero_f32: false, - flush_to_zero_f16f64: false, - rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, - rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, - })); - sreg_to_function.insert(sreg, name); - }, - ); - let mut visitor = SpecialRegisterResolver { - resolver, - special_registers, - sreg_to_function, - result: Vec::new(), - }; - for directive in directives.into_iter() { - result.push(run_directive(&mut visitor, directive)?); - } - Ok(result) -} - -fn run_directive<'a, 'input>( - visitor: &mut SpecialRegisterResolver<'a, 'input>, - directive: UnconditionalDirective, -) -> Result { - Ok(match directive { - var @ Directive2::Variable(..) => var, - Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?), - }) -} - -fn run_method<'a, 'input>( - visitor: &mut SpecialRegisterResolver<'a, 'input>, - method: UnconditionalFunction, -) -> Result { - let body = method - .body - .map(|statements| { - let mut result = Vec::with_capacity(statements.len()); - for statement in statements { - run_statement(visitor, &mut result, statement)?; - } - Ok::<_, TranslateError>(result) - }) - .transpose()?; - Ok(Function2 { body, ..method }) -} - -fn run_statement<'a, 'input>( - visitor: &mut SpecialRegisterResolver<'a, 'input>, - result: &mut Vec, - statement: UnconditionalStatement, -) -> Result<(), TranslateError> { - let converted_statement = statement.visit_map(visitor)?; - result.extend(visitor.result.drain(..)); - result.push(converted_statement); - Ok(()) -} - -struct SpecialRegisterResolver<'a, 'input> { - resolver: &'a mut GlobalStringIdentResolver2<'input>, - special_registers: &'a SpecialRegistersMap2, - sreg_to_function: FxHashMap, - result: Vec, -} - -impl<'a, 'b, 'input> - ast::VisitorMap, ast::ParsedOperand, TranslateError> - for SpecialRegisterResolver<'a, 'input> -{ - fn visit( - &mut self, - operand: ast::ParsedOperand, - _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - is_dst: bool, - _relaxed_type_check: bool, - ) -> Result, TranslateError> { - map_operand(operand, &mut |ident, vector_index| { - self.replace_sreg(ident, vector_index, is_dst) - }) - } - - fn visit_ident( - &mut self, - args: SpirvWord, - _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - is_dst: bool, - _relaxed_type_check: bool, - ) -> Result { - Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args)) - } -} - -impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> { - fn replace_sreg( - &mut self, - name: SpirvWord, - vector_index: Option, - is_dst: bool, - ) -> Result, TranslateError> { - if let Some(sreg) = self.special_registers.get(name) { - if is_dst { - return Err(error_mismatched_type()); - } - let input_arguments = match (vector_index, sreg.get_function_input_type()) { - (Some(idx), Some(inp_type)) => { - if inp_type != ast::ScalarType::U8 { - return Err(TranslateError::Unreachable); - } - let constant = self.resolver.register_unnamed(Some(( - ast::Type::Scalar(inp_type), - ast::StateSpace::Reg, - ))); - self.result.push(Statement::Constant(ConstantDefinition { - dst: constant, - typ: inp_type, - value: ast::ImmediateValue::U64(idx as u64), - })); - vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)] - } - (None, None) => Vec::new(), - _ => return Err(error_mismatched_type()), - }; - let return_type = sreg.get_function_return_type(); - let fn_result = self - .resolver - .register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg))); - let return_arguments = vec![( - fn_result, - ast::Type::Scalar(return_type), - ast::StateSpace::Reg, - )]; - let data = ast::CallDetails { - uniform: false, - return_arguments: return_arguments - .iter() - .map(|(_, typ, space)| (typ.clone(), *space)) - .collect(), - input_arguments: input_arguments - .iter() - .map(|(_, typ, space)| (typ.clone(), *space)) - .collect(), - }; - let arguments = ast::CallArgs::> { - return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(), - func: self.sreg_to_function[&sreg], - input_arguments: input_arguments - .iter() - .map(|(name, _, _)| ast::ParsedOperand::Reg(*name)) - .collect(), - }; - self.result - .push(Statement::Instruction(ast::Instruction::Call { - data, - arguments, - })); - Ok(Some(fn_result)) - } else { - Ok(None) - } - } -} - -pub fn map_operand( - this: ast::ParsedOperand, - fn_: &mut impl FnMut(T, Option) -> Result, Err>, -) -> Result, Err> { - Ok(match this { - ast::ParsedOperand::Reg(ident) => { - ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident)) - } - ast::ParsedOperand::RegOffset(ident, offset) => { - ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset) - } - ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm), - ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? { - Some(ident) => ast::ParsedOperand::Reg(ident), - None => ast::ParsedOperand::VecMember(ident, member), - }, - ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack( - idents - .into_iter() - .map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident))) - .collect::, _>>()?, - ), - }) -} +use super::*; + +pub(super) fn run<'a, 'input>( + resolver: &'a mut GlobalStringIdentResolver2<'input>, + special_registers: &'a SpecialRegistersMap2, + directives: Vec, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len()); + let mut sreg_to_function = + FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default()); + SpecialRegistersMap2::foreach_declaration( + resolver, + |sreg, (return_arguments, name, input_arguments)| { + result.push(UnconditionalDirective::Method(UnconditionalFunction { + return_arguments, + name, + input_arguments, + body: None, + import_as: None, + tuning: Vec::new(), + linkage: ast::LinkingDirective::EXTERN, + is_kernel: false, + flush_to_zero_f32: false, + flush_to_zero_f16f64: false, + rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, + rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, + })); + sreg_to_function.insert(sreg, name); + }, + ); + let mut visitor = SpecialRegisterResolver { + resolver, + special_registers, + sreg_to_function, + result: Vec::new(), + }; + for directive in directives.into_iter() { + result.push(run_directive(&mut visitor, directive)?); + } + Ok(result) +} + +fn run_directive<'a, 'input>( + visitor: &mut SpecialRegisterResolver<'a, 'input>, + directive: UnconditionalDirective, +) -> Result { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?), + }) +} + +fn run_method<'a, 'input>( + visitor: &mut SpecialRegisterResolver<'a, 'input>, + method: UnconditionalFunction, +) -> Result { + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(visitor, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { body, ..method }) +} + +fn run_statement<'a, 'input>( + visitor: &mut SpecialRegisterResolver<'a, 'input>, + result: &mut Vec, + statement: UnconditionalStatement, +) -> Result<(), TranslateError> { + let converted_statement = statement.visit_map(visitor)?; + result.extend(visitor.result.drain(..)); + result.push(converted_statement); + Ok(()) +} + +struct SpecialRegisterResolver<'a, 'input> { + resolver: &'a mut GlobalStringIdentResolver2<'input>, + special_registers: &'a SpecialRegistersMap2, + sreg_to_function: FxHashMap, + result: Vec, +} + +impl<'a, 'b, 'input> + ast::VisitorMap, ast::ParsedOperand, TranslateError> + for SpecialRegisterResolver<'a, 'input> +{ + fn visit( + &mut self, + operand: ast::ParsedOperand, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result, TranslateError> { + map_operand(operand, &mut |ident, vector_index| { + self.replace_sreg(ident, vector_index, is_dst) + }) + } + + fn visit_ident( + &mut self, + args: SpirvWord, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args)) + } +} + +impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> { + fn replace_sreg( + &mut self, + name: SpirvWord, + vector_index: Option, + is_dst: bool, + ) -> Result, TranslateError> { + if let Some(sreg) = self.special_registers.get(name) { + if is_dst { + return Err(error_mismatched_type()); + } + let input_arguments = match (vector_index, sreg.get_function_input_type()) { + (Some(idx), Some(inp_type)) => { + if inp_type != ast::ScalarType::U8 { + return Err(TranslateError::Unreachable); + } + let constant = self.resolver.register_unnamed(Some(( + ast::Type::Scalar(inp_type), + ast::StateSpace::Reg, + ))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: constant, + typ: inp_type, + value: ast::ImmediateValue::U64(idx as u64), + })); + vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)] + } + (None, None) => Vec::new(), + _ => return Err(error_mismatched_type()), + }; + let return_type = sreg.get_function_return_type(); + let fn_result = self + .resolver + .register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg))); + let return_arguments = vec![( + fn_result, + ast::Type::Scalar(return_type), + ast::StateSpace::Reg, + )]; + let data = ast::CallDetails { + uniform: false, + return_arguments: return_arguments + .iter() + .map(|(_, typ, space)| (typ.clone(), *space)) + .collect(), + input_arguments: input_arguments + .iter() + .map(|(_, typ, space)| (typ.clone(), *space)) + .collect(), + }; + let arguments = ast::CallArgs::> { + return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(), + func: self.sreg_to_function[&sreg], + input_arguments: input_arguments + .iter() + .map(|(name, _, _)| ast::ParsedOperand::Reg(*name)) + .collect(), + }; + self.result + .push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })); + Ok(Some(fn_result)) + } else { + Ok(None) + } + } +} + +pub fn map_operand( + this: ast::ParsedOperand, + fn_: &mut impl FnMut(T, Option) -> Result, Err>, +) -> Result, Err> { + Ok(match this { + ast::ParsedOperand::Reg(ident) => { + ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident)) + } + ast::ParsedOperand::RegOffset(ident, offset) => { + ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset) + } + ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm), + ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? { + Some(ident) => ast::ParsedOperand::Reg(ident), + None => ast::ParsedOperand::VecMember(ident, member), + }, + ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack( + idents + .into_iter() + .map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident))) + .collect::, _>>()?, + ), + }) +} diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs index 654a7e9..dfc88c2 100644 --- a/ptx/src/pass/hoist_globals.rs +++ b/ptx/src/pass/hoist_globals.rs @@ -1,45 +1,45 @@ -use super::*; - -pub(super) fn run<'input>( - directives: Vec, SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { - let mut result = Vec::with_capacity(directives.len()); - for mut directive in directives.into_iter() { - run_directive(&mut result, &mut directive)?; - result.push(directive); - } - Ok(result) -} - -fn run_directive<'input>( - result: &mut Vec, SpirvWord>>, - directive: &mut Directive2, SpirvWord>, -) -> Result<(), TranslateError> { - match directive { - Directive2::Variable(..) => {} - Directive2::Method(function2) => run_function(result, function2), - } - Ok(()) -} - -fn run_function<'input>( - result: &mut Vec, SpirvWord>>, - function: &mut Function2, SpirvWord>, -) { - function.body = function.body.take().map(|statements| { - statements - .into_iter() - .filter_map(|statement| match statement { - Statement::Variable(var @ ast::Variable { - state_space: - ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared, - .. - }) => { - result.push(Directive2::Variable(ast::LinkingDirective::NONE, var)); - None - } - s => Some(s), - }) - .collect() - }); -} +use super::*; + +pub(super) fn run<'input>( + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + let mut result = Vec::with_capacity(directives.len()); + for mut directive in directives.into_iter() { + run_directive(&mut result, &mut directive)?; + result.push(directive); + } + Ok(result) +} + +fn run_directive<'input>( + result: &mut Vec, SpirvWord>>, + directive: &mut Directive2, SpirvWord>, +) -> Result<(), TranslateError> { + match directive { + Directive2::Variable(..) => {} + Directive2::Method(function2) => run_function(result, function2), + } + Ok(()) +} + +fn run_function<'input>( + result: &mut Vec, SpirvWord>>, + function: &mut Function2, SpirvWord>, +) { + function.body = function.body.take().map(|statements| { + statements + .into_iter() + .filter_map(|statement| match statement { + Statement::Variable(var @ ast::Variable { + state_space: + ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared, + .. + }) => { + result.push(Directive2::Variable(ast::LinkingDirective::NONE, var)); + None + } + s => Some(s), + }) + .collect() + }); +} diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index 935e78d..32597c5 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -1,404 +1,404 @@ -use super::*; -// This pass: -// * Turns all .local, .param and .reg in-body variables into .local variables -// (if _not_ an input method argument) -// * Inserts explicit `ld`/`st` for newly converted .reg variables -// * Fixup state space of all existing `ld`/`st` instructions into newly -// converted variables -// * Turns `.entry` input arguments into param::entry and all related `.param` -// loads into `param::entry` loads -// * All `.func` input arguments are turned into `.reg` arguments by another -// pass, so we do nothing there -pub(super) fn run<'a, 'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - directives: Vec, SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { - directives - .into_iter() - .map(|directive| run_directive(resolver, directive)) - .collect::, _>>() -} - -fn run_directive<'a, 'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - directive: Directive2, SpirvWord>, -) -> Result, SpirvWord>, TranslateError> { - Ok(match directive { - var @ Directive2::Variable(..) => var, - Directive2::Method(method) => { - let visitor = InsertMemSSAVisitor::new(resolver); - Directive2::Method(run_method(visitor, method)?) - } - }) -} - -fn run_method<'a, 'input>( - mut visitor: InsertMemSSAVisitor<'a, 'input>, - mut method: Function2, SpirvWord>, -) -> Result, SpirvWord>, TranslateError> { - let is_kernel = method.is_kernel; - if is_kernel { - for arg in method.input_arguments.iter_mut() { - let old_name = arg.name; - let old_space = arg.state_space; - let new_space = ast::StateSpace::ParamEntry; - let new_name = visitor - .resolver - .register_unnamed(Some((arg.v_type.clone(), new_space))); - visitor.input_argument(old_name, new_name, old_space)?; - arg.name = new_name; - arg.state_space = new_space; - } - }; - for arg in method.return_arguments.iter_mut() { - visitor.visit_variable(arg)?; - } - let return_arguments = &method.return_arguments[..]; - let body = method - .body - .map(move |statements| { - let mut result = Vec::with_capacity(statements.len()); - for statement in statements { - run_statement(&mut visitor, return_arguments, &mut result, statement)?; - } - Ok::<_, TranslateError>(result) - }) - .transpose()?; - Ok(Function2 { body, ..method }) -} - -fn run_statement<'a, 'input>( - visitor: &mut InsertMemSSAVisitor<'a, 'input>, - return_arguments: &[ast::Variable], - result: &mut Vec, - statement: ExpandedStatement, -) -> Result<(), TranslateError> { - match statement { - Statement::Instruction(ast::Instruction::Ret { data }) => { - let statement = if return_arguments.is_empty() { - Statement::Instruction(ast::Instruction::Ret { data }) - } else { - Statement::RetValue( - data, - return_arguments - .iter() - .map(|arg| { - if arg.state_space != ast::StateSpace::Local { - return Err(error_unreachable()); - } - Ok((arg.name, arg.v_type.clone())) - }) - .collect::, _>>()?, - ) - }; - let new_statement = statement.visit_map(visitor)?; - result.extend(visitor.pre.drain(..).map(Statement::Instruction)); - result.push(new_statement); - result.extend(visitor.post.drain(..).map(Statement::Instruction)); - } - Statement::Variable(mut var) => { - visitor.visit_variable(&mut var)?; - result.push(Statement::Variable(var)); - } - Statement::Instruction(ast::Instruction::Ld { data, arguments }) => { - let instruction = visitor.visit_ld(data, arguments)?; - let instruction = ast::visit_map(instruction, visitor)?; - result.extend(visitor.pre.drain(..).map(Statement::Instruction)); - result.push(Statement::Instruction(instruction)); - result.extend(visitor.post.drain(..).map(Statement::Instruction)); - } - Statement::Instruction(ast::Instruction::St { data, arguments }) => { - let instruction = visitor.visit_st(data, arguments)?; - let instruction = ast::visit_map(instruction, visitor)?; - result.extend(visitor.pre.drain(..).map(Statement::Instruction)); - result.push(Statement::Instruction(instruction)); - result.extend(visitor.post.drain(..).map(Statement::Instruction)); - } - Statement::PtrAccess(ptr_access) => { - let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?); - let statement = statement.visit_map(visitor)?; - result.extend(visitor.pre.drain(..).map(Statement::Instruction)); - result.push(statement); - result.extend(visitor.post.drain(..).map(Statement::Instruction)); - } - s => { - let new_statement = s.visit_map(visitor)?; - result.extend(visitor.pre.drain(..).map(Statement::Instruction)); - result.push(new_statement); - result.extend(visitor.post.drain(..).map(Statement::Instruction)); - } - } - Ok(()) -} - -struct InsertMemSSAVisitor<'a, 'input> { - resolver: &'a mut GlobalStringIdentResolver2<'input>, - variables: FxHashMap, - pre: Vec>, - post: Vec>, -} - -impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { - fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self { - Self { - resolver, - variables: FxHashMap::default(), - pre: Vec::new(), - post: Vec::new(), - } - } - - fn input_argument( - &mut self, - old_name: SpirvWord, - new_name: SpirvWord, - old_space: ast::StateSpace, - ) -> Result<(), TranslateError> { - if old_space != ast::StateSpace::Param { - return Err(error_unreachable()); - } - self.variables.insert( - old_name, - RemapAction::LDStSpaceChange { - name: new_name, - old_space, - new_space: ast::StateSpace::ParamEntry, - }, - ); - Ok(()) - } - - fn variable( - &mut self, - type_: &ast::Type, - old_name: SpirvWord, - new_name: SpirvWord, - old_space: ast::StateSpace, - ) -> Result { - Ok(match old_space { - ast::StateSpace::Reg => { - self.variables.insert( - old_name, - RemapAction::PreLdPostSt { - name: new_name, - type_: type_.clone(), - }, - ); - true - } - ast::StateSpace::Param => { - self.variables.insert( - old_name, - RemapAction::LDStSpaceChange { - old_space, - new_space: ast::StateSpace::Local, - name: new_name, - }, - ); - true - } - // Good as-is - ast::StateSpace::Local - | ast::StateSpace::Generic - | ast::StateSpace::SharedCluster - | ast::StateSpace::Global - | ast::StateSpace::Const - | ast::StateSpace::SharedCta - | ast::StateSpace::Shared - | ast::StateSpace::ParamEntry - | ast::StateSpace::ParamFunc => return Err(error_unreachable()), - }) - } - - fn visit_st( - &self, - mut data: ast::StData, - mut arguments: ast::StArgs, - ) -> Result, TranslateError> { - if let Some(remap) = self.variables.get(&arguments.src1) { - match remap { - RemapAction::PreLdPostSt { .. } => {} - RemapAction::LDStSpaceChange { - old_space, - new_space, - name, - } => { - if data.state_space != *old_space { - return Err(error_mismatched_type()); - } - data.state_space = *new_space; - arguments.src1 = *name; - } - } - } - Ok(ast::Instruction::St { data, arguments }) - } - - fn visit_ld( - &self, - mut data: ast::LdDetails, - mut arguments: ast::LdArgs, - ) -> Result, TranslateError> { - if let Some(remap) = self.variables.get(&arguments.src) { - match remap { - RemapAction::PreLdPostSt { .. } => {} - RemapAction::LDStSpaceChange { - old_space, - new_space, - name, - } => { - if data.state_space != *old_space { - return Err(error_mismatched_type()); - } - data.state_space = *new_space; - arguments.src = *name; - } - } - } - Ok(ast::Instruction::Ld { data, arguments }) - } - - fn visit_ptr_access( - &mut self, - ptr_access: PtrAccess, - ) -> Result, TranslateError> { - let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) { - Some(RemapAction::LDStSpaceChange { - old_space, - new_space, - name, - }) => (*old_space, *new_space, *name), - Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access), - }; - if ptr_access.state_space != old_space { - return Err(error_mismatched_type()); - } - // Propagate space changes in dst - let new_dst = self - .resolver - .register_unnamed(Some((ptr_access.underlying_type.clone(), new_space))); - self.variables.insert( - ptr_access.dst, - RemapAction::LDStSpaceChange { - old_space, - new_space, - name: new_dst, - }, - ); - Ok(PtrAccess { - ptr_src: name, - dst: new_dst, - state_space: new_space, - ..ptr_access - }) - } - - fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { - let old_space = match var.state_space { - space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, - // Do nothing - ptx_parser::StateSpace::Local => return Ok(()), - // Handled by another pass - ptx_parser::StateSpace::Generic - | ptx_parser::StateSpace::SharedCluster - | ptx_parser::StateSpace::ParamEntry - | ptx_parser::StateSpace::Global - | ptx_parser::StateSpace::SharedCta - | ptx_parser::StateSpace::Const - | ptx_parser::StateSpace::Shared - | ptx_parser::StateSpace::ParamFunc => return Ok(()), - }; - let old_name = var.name; - let new_space = ast::StateSpace::Local; - let new_name = self - .resolver - .register_unnamed(Some((var.v_type.clone(), new_space))); - self.variable(&var.v_type, old_name, new_name, old_space)?; - var.name = new_name; - var.state_space = new_space; - Ok(()) - } -} - -impl<'a, 'input> ast::VisitorMap - for InsertMemSSAVisitor<'a, 'input> -{ - fn visit( - &mut self, - ident: SpirvWord, - _type_space: Option<(&ast::Type, ast::StateSpace)>, - is_dst: bool, - _relaxed_type_check: bool, - ) -> Result { - if let Some(remap) = self.variables.get(&ident) { - match remap { - RemapAction::PreLdPostSt { name, type_ } => { - if is_dst { - let temp = self - .resolver - .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg))); - self.post.push(ast::Instruction::St { - data: ast::StData { - state_space: ast::StateSpace::Local, - qualifier: ast::LdStQualifier::Weak, - caching: ast::StCacheOperator::Writethrough, - typ: type_.clone(), - }, - arguments: ast::StArgs { - src1: *name, - src2: temp, - }, - }); - Ok(temp) - } else { - let temp = self - .resolver - .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg))); - self.pre.push(ast::Instruction::Ld { - data: ast::LdDetails { - state_space: ast::StateSpace::Local, - qualifier: ast::LdStQualifier::Weak, - caching: ast::LdCacheOperator::Cached, - typ: type_.clone(), - non_coherent: false, - }, - arguments: ast::LdArgs { - dst: temp, - src: *name, - }, - }); - Ok(temp) - } - } - RemapAction::LDStSpaceChange { .. } => { - return Err(error_mismatched_type()); - } - } - } else { - Ok(ident) - } - } - - fn visit_ident( - &mut self, - args: SpirvWord, - type_space: Option<(&ast::Type, ast::StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result { - self.visit(args, type_space, is_dst, relaxed_type_check) - } -} - -#[derive(Clone)] -enum RemapAction { - PreLdPostSt { - name: SpirvWord, - type_: ast::Type, - }, - LDStSpaceChange { - old_space: ast::StateSpace, - new_space: ast::StateSpace, - name: SpirvWord, - }, -} +use super::*; +// This pass: +// * Turns all .local, .param and .reg in-body variables into .local variables +// (if _not_ an input method argument) +// * Inserts explicit `ld`/`st` for newly converted .reg variables +// * Fixup state space of all existing `ld`/`st` instructions into newly +// converted variables +// * Turns `.entry` input arguments into param::entry and all related `.param` +// loads into `param::entry` loads +// * All `.func` input arguments are turned into `.reg` arguments by another +// pass, so we do nothing there +pub(super) fn run<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: Directive2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => { + let visitor = InsertMemSSAVisitor::new(resolver); + Directive2::Method(run_method(visitor, method)?) + } + }) +} + +fn run_method<'a, 'input>( + mut visitor: InsertMemSSAVisitor<'a, 'input>, + mut method: Function2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + let is_kernel = method.is_kernel; + if is_kernel { + for arg in method.input_arguments.iter_mut() { + let old_name = arg.name; + let old_space = arg.state_space; + let new_space = ast::StateSpace::ParamEntry; + let new_name = visitor + .resolver + .register_unnamed(Some((arg.v_type.clone(), new_space))); + visitor.input_argument(old_name, new_name, old_space)?; + arg.name = new_name; + arg.state_space = new_space; + } + }; + for arg in method.return_arguments.iter_mut() { + visitor.visit_variable(arg)?; + } + let return_arguments = &method.return_arguments[..]; + let body = method + .body + .map(move |statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(&mut visitor, return_arguments, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { body, ..method }) +} + +fn run_statement<'a, 'input>( + visitor: &mut InsertMemSSAVisitor<'a, 'input>, + return_arguments: &[ast::Variable], + result: &mut Vec, + statement: ExpandedStatement, +) -> Result<(), TranslateError> { + match statement { + Statement::Instruction(ast::Instruction::Ret { data }) => { + let statement = if return_arguments.is_empty() { + Statement::Instruction(ast::Instruction::Ret { data }) + } else { + Statement::RetValue( + data, + return_arguments + .iter() + .map(|arg| { + if arg.state_space != ast::StateSpace::Local { + return Err(error_unreachable()); + } + Ok((arg.name, arg.v_type.clone())) + }) + .collect::, _>>()?, + ) + }; + let new_statement = statement.visit_map(visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(new_statement); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } + Statement::Variable(mut var) => { + visitor.visit_variable(&mut var)?; + result.push(Statement::Variable(var)); + } + Statement::Instruction(ast::Instruction::Ld { data, arguments }) => { + let instruction = visitor.visit_ld(data, arguments)?; + let instruction = ast::visit_map(instruction, visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(Statement::Instruction(instruction)); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } + Statement::Instruction(ast::Instruction::St { data, arguments }) => { + let instruction = visitor.visit_st(data, arguments)?; + let instruction = ast::visit_map(instruction, visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(Statement::Instruction(instruction)); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } + Statement::PtrAccess(ptr_access) => { + let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?); + let statement = statement.visit_map(visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(statement); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } + s => { + let new_statement = s.visit_map(visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(new_statement); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } + } + Ok(()) +} + +struct InsertMemSSAVisitor<'a, 'input> { + resolver: &'a mut GlobalStringIdentResolver2<'input>, + variables: FxHashMap, + pre: Vec>, + post: Vec>, +} + +impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { + fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self { + Self { + resolver, + variables: FxHashMap::default(), + pre: Vec::new(), + post: Vec::new(), + } + } + + fn input_argument( + &mut self, + old_name: SpirvWord, + new_name: SpirvWord, + old_space: ast::StateSpace, + ) -> Result<(), TranslateError> { + if old_space != ast::StateSpace::Param { + return Err(error_unreachable()); + } + self.variables.insert( + old_name, + RemapAction::LDStSpaceChange { + name: new_name, + old_space, + new_space: ast::StateSpace::ParamEntry, + }, + ); + Ok(()) + } + + fn variable( + &mut self, + type_: &ast::Type, + old_name: SpirvWord, + new_name: SpirvWord, + old_space: ast::StateSpace, + ) -> Result { + Ok(match old_space { + ast::StateSpace::Reg => { + self.variables.insert( + old_name, + RemapAction::PreLdPostSt { + name: new_name, + type_: type_.clone(), + }, + ); + true + } + ast::StateSpace::Param => { + self.variables.insert( + old_name, + RemapAction::LDStSpaceChange { + old_space, + new_space: ast::StateSpace::Local, + name: new_name, + }, + ); + true + } + // Good as-is + ast::StateSpace::Local + | ast::StateSpace::Generic + | ast::StateSpace::SharedCluster + | ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::SharedCta + | ast::StateSpace::Shared + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc => return Err(error_unreachable()), + }) + } + + fn visit_st( + &self, + mut data: ast::StData, + mut arguments: ast::StArgs, + ) -> Result, TranslateError> { + if let Some(remap) = self.variables.get(&arguments.src1) { + match remap { + RemapAction::PreLdPostSt { .. } => {} + RemapAction::LDStSpaceChange { + old_space, + new_space, + name, + } => { + if data.state_space != *old_space { + return Err(error_mismatched_type()); + } + data.state_space = *new_space; + arguments.src1 = *name; + } + } + } + Ok(ast::Instruction::St { data, arguments }) + } + + fn visit_ld( + &self, + mut data: ast::LdDetails, + mut arguments: ast::LdArgs, + ) -> Result, TranslateError> { + if let Some(remap) = self.variables.get(&arguments.src) { + match remap { + RemapAction::PreLdPostSt { .. } => {} + RemapAction::LDStSpaceChange { + old_space, + new_space, + name, + } => { + if data.state_space != *old_space { + return Err(error_mismatched_type()); + } + data.state_space = *new_space; + arguments.src = *name; + } + } + } + Ok(ast::Instruction::Ld { data, arguments }) + } + + fn visit_ptr_access( + &mut self, + ptr_access: PtrAccess, + ) -> Result, TranslateError> { + let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) { + Some(RemapAction::LDStSpaceChange { + old_space, + new_space, + name, + }) => (*old_space, *new_space, *name), + Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access), + }; + if ptr_access.state_space != old_space { + return Err(error_mismatched_type()); + } + // Propagate space changes in dst + let new_dst = self + .resolver + .register_unnamed(Some((ptr_access.underlying_type.clone(), new_space))); + self.variables.insert( + ptr_access.dst, + RemapAction::LDStSpaceChange { + old_space, + new_space, + name: new_dst, + }, + ); + Ok(PtrAccess { + ptr_src: name, + dst: new_dst, + state_space: new_space, + ..ptr_access + }) + } + + fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { + let old_space = match var.state_space { + space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, + // Do nothing + ptx_parser::StateSpace::Local => return Ok(()), + // Handled by another pass + ptx_parser::StateSpace::Generic + | ptx_parser::StateSpace::SharedCluster + | ptx_parser::StateSpace::ParamEntry + | ptx_parser::StateSpace::Global + | ptx_parser::StateSpace::SharedCta + | ptx_parser::StateSpace::Const + | ptx_parser::StateSpace::Shared + | ptx_parser::StateSpace::ParamFunc => return Ok(()), + }; + let old_name = var.name; + let new_space = ast::StateSpace::Local; + let new_name = self + .resolver + .register_unnamed(Some((var.v_type.clone(), new_space))); + self.variable(&var.v_type, old_name, new_name, old_space)?; + var.name = new_name; + var.state_space = new_space; + Ok(()) + } +} + +impl<'a, 'input> ast::VisitorMap + for InsertMemSSAVisitor<'a, 'input> +{ + fn visit( + &mut self, + ident: SpirvWord, + _type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + if let Some(remap) = self.variables.get(&ident) { + match remap { + RemapAction::PreLdPostSt { name, type_ } => { + if is_dst { + let temp = self + .resolver + .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg))); + self.post.push(ast::Instruction::St { + data: ast::StData { + state_space: ast::StateSpace::Local, + qualifier: ast::LdStQualifier::Weak, + caching: ast::StCacheOperator::Writethrough, + typ: type_.clone(), + }, + arguments: ast::StArgs { + src1: *name, + src2: temp, + }, + }); + Ok(temp) + } else { + let temp = self + .resolver + .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg))); + self.pre.push(ast::Instruction::Ld { + data: ast::LdDetails { + state_space: ast::StateSpace::Local, + qualifier: ast::LdStQualifier::Weak, + caching: ast::LdCacheOperator::Cached, + typ: type_.clone(), + non_coherent: false, + }, + arguments: ast::LdArgs { + dst: temp, + src: *name, + }, + }); + Ok(temp) + } + } + RemapAction::LDStSpaceChange { .. } => { + return Err(error_mismatched_type()); + } + } + } else { + Ok(ident) + } + } + + fn visit_ident( + &mut self, + args: SpirvWord, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + self.visit(args, type_space, is_dst, relaxed_type_check) + } +} + +#[derive(Clone)] +enum RemapAction { + PreLdPostSt { + name: SpirvWord, + type_: ast::Type, + }, + LDStSpaceChange { + old_space: ast::StateSpace, + new_space: ast::StateSpace, + name: SpirvWord, + }, +} diff --git a/ptx/src/pass/insert_implicit_conversions2.rs b/ptx/src/pass/insert_implicit_conversions2.rs index 5189664..5b0fd3b 100644 --- a/ptx/src/pass/insert_implicit_conversions2.rs +++ b/ptx/src/pass/insert_implicit_conversions2.rs @@ -1,401 +1,401 @@ -use std::mem; - -use super::*; -use ptx_parser as ast; - -/* - There are several kinds of implicit conversions in PTX: - * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands - * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size - - ld.param: not documented, but for instruction `ld.param. x, [y]`, - semantics are to first zext/chop/bitcast `y` as needed and then do - documented special ld/st/cvt conversion rules for destination operands - - st.param [x] y (used as function return arguments) same rule as above applies - - generic/global ld: for instruction `ld x, [y]`, y must be of type - b64/u64/s64, which is bitcast to a pointer, dereferenced and then - documented special ld/st/cvt conversion rules are applied to dst - - generic/global st: for instruction `st [x], y`, x must be of type - b64/u64/s64, which is bitcast to a pointer -*/ -pub(super) fn run<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - directives: Vec, SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { - directives - .into_iter() - .map(|directive| run_directive(resolver, directive)) - .collect::, _>>() -} - -fn run_directive<'a, 'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - directive: Directive2, SpirvWord>, -) -> Result, SpirvWord>, TranslateError> { - Ok(match directive { - var @ Directive2::Variable(..) => var, - Directive2::Method(mut method) => { - method.body = method - .body - .map(|statements| run_statements(resolver, statements)) - .transpose()?; - Directive2::Method(method) - } - }) -} - -fn run_statements<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - func: Vec, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(func.len()); - for s in func.into_iter() { - insert_implicit_conversions_impl(resolver, &mut result, s)?; - } - Ok(result) -} - -fn insert_implicit_conversions_impl<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - func: &mut Vec, - stmt: ExpandedStatement, -) -> Result<(), TranslateError> { - let mut post_conv = Vec::new(); - let statement = stmt.visit_map::( - &mut |operand, - type_state: Option<(&ast::Type, ast::StateSpace)>, - is_dst, - relaxed_type_check| { - let (instr_type, instruction_space) = match type_state { - None => return Ok(operand), - Some(t) => t, - }; - let (operand_type, operand_space) = resolver.get_typed(operand)?; - let conversion_fn = if relaxed_type_check { - if is_dst { - should_convert_relaxed_dst_wrapper - } else { - should_convert_relaxed_src_wrapper - } - } else { - default_implicit_conversion - }; - match conversion_fn( - (*operand_space, &operand_type), - (instruction_space, instr_type), - )? { - Some(conv_kind) => { - let conv_output = if is_dst { &mut post_conv } else { &mut *func }; - let mut from_type = instr_type.clone(); - let mut from_space = instruction_space; - let mut to_type = operand_type.clone(); - let mut to_space = *operand_space; - let mut src = - resolver.register_unnamed(Some((instr_type.clone(), instruction_space))); - let mut dst = operand; - let result = Ok::<_, TranslateError>(src); - if !is_dst { - mem::swap(&mut src, &mut dst); - mem::swap(&mut from_type, &mut to_type); - mem::swap(&mut from_space, &mut to_space); - } - conv_output.push(Statement::Conversion(ImplicitConversion { - src, - dst, - from_type, - from_space, - to_type, - to_space, - kind: conv_kind, - })); - result - } - None => Ok(operand), - } - }, - )?; - func.push(statement); - func.append(&mut post_conv); - Ok(()) -} - -pub(crate) fn default_implicit_conversion( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), -) -> Result, TranslateError> { - if instruction_space == ast::StateSpace::Reg { - if operand_space == ast::StateSpace::Reg { - if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) = - (operand_type, instruction_type) - { - if scalar.kind() == ast::ScalarKind::Bit - && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) - { - return Ok(Some(ConversionKind::Default)); - } - } - } else if is_addressable(operand_space) { - return Ok(Some(ConversionKind::AddressOf)); - } - } - if instruction_space != operand_space { - default_implicit_conversion_space((operand_space, operand_type), instruction_space) - } else if instruction_type != operand_type { - default_implicit_conversion_type(instruction_space, operand_type, instruction_type) - } else { - Ok(None) - } -} - -fn is_addressable(this: ast::StateSpace) -> bool { - match this { - ast::StateSpace::Const - | ast::StateSpace::Generic - | ast::StateSpace::Global - | ast::StateSpace::Local - | ast::StateSpace::Shared => true, - ast::StateSpace::Param | ast::StateSpace::Reg => false, - ast::StateSpace::SharedCluster - | ast::StateSpace::SharedCta - | ast::StateSpace::ParamEntry - | ast::StateSpace::ParamFunc => todo!(), - } -} - -// Space is different -fn default_implicit_conversion_space( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - instruction_space: ast::StateSpace, -) -> Result, TranslateError> { - if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space)) - || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space)) - { - Ok(Some(ConversionKind::PtrToPtr)) - } else if operand_space == ast::StateSpace::Reg { - match operand_type { - // TODO: 32 bit - ast::Type::Scalar(ast::ScalarType::B64) - | ast::Type::Scalar(ast::ScalarType::U64) - | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { - ast::StateSpace::Global - | ast::StateSpace::Generic - | ast::StateSpace::Const - | ast::StateSpace::Local - | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), - _ => Err(error_mismatched_type()), - }, - ast::Type::Scalar(ast::ScalarType::B32) - | ast::Type::Scalar(ast::ScalarType::U32) - | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { - ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { - Ok(Some(ConversionKind::BitToPtr)) - } - _ => Err(error_mismatched_type()), - }, - _ => Err(error_mismatched_type()), - } - } else { - Err(error_mismatched_type()) - } -} - -// Space is same, but type is different -fn default_implicit_conversion_type( - space: ast::StateSpace, - operand_type: &ast::Type, - instruction_type: &ast::Type, -) -> Result, TranslateError> { - if space == ast::StateSpace::Reg { - if should_bitcast(instruction_type, operand_type) { - Ok(Some(ConversionKind::Default)) - } else { - Err(TranslateError::MismatchedType) - } - } else { - Ok(Some(ConversionKind::PtrToPtr)) - } -} - -fn coerces_to_generic(this: ast::StateSpace) -> bool { - match this { - ast::StateSpace::Global - | ast::StateSpace::Const - | ast::StateSpace::Local - | ptx_parser::StateSpace::SharedCta - | ast::StateSpace::SharedCluster - | ast::StateSpace::Shared => true, - ast::StateSpace::Reg - | ast::StateSpace::Param - | ast::StateSpace::ParamEntry - | ast::StateSpace::ParamFunc - | ast::StateSpace::Generic => false, - } -} - -fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { - match (instr, operand) { - (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { - if inst.size_of() != operand.size_of() { - return false; - } - match inst.kind() { - ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, - ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, - ast::ScalarKind::Signed => { - operand.kind() == ast::ScalarKind::Bit - || operand.kind() == ast::ScalarKind::Unsigned - } - ast::ScalarKind::Unsigned => { - operand.kind() == ast::ScalarKind::Bit - || operand.kind() == ast::ScalarKind::Signed - } - ast::ScalarKind::Pred => false, - } - } - (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand)) - | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => { - should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand)) - } - _ => false, - } -} - -pub(crate) fn should_convert_relaxed_dst_wrapper( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), -) -> Result, TranslateError> { - if operand_space != instruction_space { - return Err(TranslateError::MismatchedType); - } - if operand_type == instruction_type { - return Ok(None); - } - match should_convert_relaxed_dst(operand_type, instruction_type) { - conv @ Some(_) => Ok(conv), - None => Err(TranslateError::MismatchedType), - } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands -fn should_convert_relaxed_dst( - dst_type: &ast::Type, - instr_type: &ast::Type, -) -> Option { - if dst_type == instr_type { - return None; - } - match (dst_type, instr_type) { - (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { - ast::ScalarKind::Bit => { - if instr_type.size_of() <= dst_type.size_of() { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Signed => { - if dst_type.kind() != ast::ScalarKind::Float { - if instr_type.size_of() == dst_type.size_of() { - Some(ConversionKind::Default) - } else if instr_type.size_of() < dst_type.size_of() { - Some(ConversionKind::SignExtend) - } else { - None - } - } else { - None - } - } - ast::ScalarKind::Unsigned => { - if instr_type.size_of() <= dst_type.size_of() - && dst_type.kind() != ast::ScalarKind::Float - { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Float => { - if instr_type.size_of() <= dst_type.size_of() - && dst_type.kind() == ast::ScalarKind::Bit - { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Pred => None, - }, - (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) - | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { - should_convert_relaxed_dst( - &ast::Type::Scalar(*dst_type), - &ast::Type::Scalar(*instr_type), - ) - } - _ => None, - } -} - -pub(crate) fn should_convert_relaxed_src_wrapper( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), -) -> Result, TranslateError> { - if operand_space != instruction_space { - return Err(error_mismatched_type()); - } - if operand_type == instruction_type { - return Ok(None); - } - match should_convert_relaxed_src(operand_type, instruction_type) { - conv @ Some(_) => Ok(conv), - None => Err(error_mismatched_type()), - } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands -fn should_convert_relaxed_src( - src_type: &ast::Type, - instr_type: &ast::Type, -) -> Option { - if src_type == instr_type { - return None; - } - match (src_type, instr_type) { - (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { - ast::ScalarKind::Bit => { - if instr_type.size_of() <= src_type.size_of() { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { - if instr_type.size_of() <= src_type.size_of() - && src_type.kind() != ast::ScalarKind::Float - { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Float => { - if instr_type.size_of() <= src_type.size_of() - && src_type.kind() == ast::ScalarKind::Bit - { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Pred => None, - }, - (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) - | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { - should_convert_relaxed_src( - &ast::Type::Scalar(*dst_type), - &ast::Type::Scalar(*instr_type), - ) - } - _ => None, - } -} +use std::mem; + +use super::*; +use ptx_parser as ast; + +/* + There are several kinds of implicit conversions in PTX: + * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands + * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size + - ld.param: not documented, but for instruction `ld.param. x, [y]`, + semantics are to first zext/chop/bitcast `y` as needed and then do + documented special ld/st/cvt conversion rules for destination operands + - st.param [x] y (used as function return arguments) same rule as above applies + - generic/global ld: for instruction `ld x, [y]`, y must be of type + b64/u64/s64, which is bitcast to a pointer, dereferenced and then + documented special ld/st/cvt conversion rules are applied to dst + - generic/global st: for instruction `st [x], y`, x must be of type + b64/u64/s64, which is bitcast to a pointer +*/ +pub(super) fn run<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: Directive2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(mut method) => { + method.body = method + .body + .map(|statements| run_statements(resolver, statements)) + .transpose()?; + Directive2::Method(method) + } + }) +} + +fn run_statements<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + func: Vec, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for s in func.into_iter() { + insert_implicit_conversions_impl(resolver, &mut result, s)?; + } + Ok(result) +} + +fn insert_implicit_conversions_impl<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + func: &mut Vec, + stmt: ExpandedStatement, +) -> Result<(), TranslateError> { + let mut post_conv = Vec::new(); + let statement = stmt.visit_map::( + &mut |operand, + type_state: Option<(&ast::Type, ast::StateSpace)>, + is_dst, + relaxed_type_check| { + let (instr_type, instruction_space) = match type_state { + None => return Ok(operand), + Some(t) => t, + }; + let (operand_type, operand_space) = resolver.get_typed(operand)?; + let conversion_fn = if relaxed_type_check { + if is_dst { + should_convert_relaxed_dst_wrapper + } else { + should_convert_relaxed_src_wrapper + } + } else { + default_implicit_conversion + }; + match conversion_fn( + (*operand_space, &operand_type), + (instruction_space, instr_type), + )? { + Some(conv_kind) => { + let conv_output = if is_dst { &mut post_conv } else { &mut *func }; + let mut from_type = instr_type.clone(); + let mut from_space = instruction_space; + let mut to_type = operand_type.clone(); + let mut to_space = *operand_space; + let mut src = + resolver.register_unnamed(Some((instr_type.clone(), instruction_space))); + let mut dst = operand; + let result = Ok::<_, TranslateError>(src); + if !is_dst { + mem::swap(&mut src, &mut dst); + mem::swap(&mut from_type, &mut to_type); + mem::swap(&mut from_space, &mut to_space); + } + conv_output.push(Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + from_space, + to_type, + to_space, + kind: conv_kind, + })); + result + } + None => Ok(operand), + } + }, + )?; + func.push(statement); + func.append(&mut post_conv); + Ok(()) +} + +pub(crate) fn default_implicit_conversion( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if instruction_space == ast::StateSpace::Reg { + if operand_space == ast::StateSpace::Reg { + if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) = + (operand_type, instruction_type) + { + if scalar.kind() == ast::ScalarKind::Bit + && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) + { + return Ok(Some(ConversionKind::Default)); + } + } + } else if is_addressable(operand_space) { + return Ok(Some(ConversionKind::AddressOf)); + } + } + if instruction_space != operand_space { + default_implicit_conversion_space((operand_space, operand_type), instruction_space) + } else if instruction_type != operand_type { + default_implicit_conversion_type(instruction_space, operand_type, instruction_type) + } else { + Ok(None) + } +} + +fn is_addressable(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Const + | ast::StateSpace::Generic + | ast::StateSpace::Global + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Param | ast::StateSpace::Reg => false, + ast::StateSpace::SharedCluster + | ast::StateSpace::SharedCta + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc => todo!(), + } +} + +// Space is different +fn default_implicit_conversion_space( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + instruction_space: ast::StateSpace, +) -> Result, TranslateError> { + if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space)) + || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space)) + { + Ok(Some(ConversionKind::PtrToPtr)) + } else if operand_space == ast::StateSpace::Reg { + match operand_type { + // TODO: 32 bit + ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { + ast::StateSpace::Global + | ast::StateSpace::Generic + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), + _ => Err(error_mismatched_type()), + }, + ast::Type::Scalar(ast::ScalarType::B32) + | ast::Type::Scalar(ast::ScalarType::U32) + | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { + ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { + Ok(Some(ConversionKind::BitToPtr)) + } + _ => Err(error_mismatched_type()), + }, + _ => Err(error_mismatched_type()), + } + } else { + Err(error_mismatched_type()) + } +} + +// Space is same, but type is different +fn default_implicit_conversion_type( + space: ast::StateSpace, + operand_type: &ast::Type, + instruction_type: &ast::Type, +) -> Result, TranslateError> { + if space == ast::StateSpace::Reg { + if should_bitcast(instruction_type, operand_type) { + Ok(Some(ConversionKind::Default)) + } else { + Err(TranslateError::MismatchedType) + } + } else { + Ok(Some(ConversionKind::PtrToPtr)) + } +} + +fn coerces_to_generic(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::Local + | ptx_parser::StateSpace::SharedCta + | ast::StateSpace::SharedCluster + | ast::StateSpace::Shared => true, + ast::StateSpace::Reg + | ast::StateSpace::Param + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc + | ast::StateSpace::Generic => false, + } +} + +fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { + match (instr, operand) { + (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { + if inst.size_of() != operand.size_of() { + return false; + } + match inst.kind() { + ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, + ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, + ast::ScalarKind::Signed => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Unsigned + } + ast::ScalarKind::Unsigned => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Signed + } + ast::ScalarKind::Pred => false, + } + } + (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand)) + | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => { + should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand)) + } + _ => false, + } +} + +pub(crate) fn should_convert_relaxed_dst_wrapper( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if operand_space != instruction_space { + return Err(TranslateError::MismatchedType); + } + if operand_type == instruction_type { + return Ok(None); + } + match should_convert_relaxed_dst(operand_type, instruction_type) { + conv @ Some(_) => Ok(conv), + None => Err(TranslateError::MismatchedType), + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands +fn should_convert_relaxed_dst( + dst_type: &ast::Type, + instr_type: &ast::Type, +) -> Option { + if dst_type == instr_type { + return None; + } + match (dst_type, instr_type) { + (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ast::ScalarKind::Bit => { + if instr_type.size_of() <= dst_type.size_of() { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Signed => { + if dst_type.kind() != ast::ScalarKind::Float { + if instr_type.size_of() == dst_type.size_of() { + Some(ConversionKind::Default) + } else if instr_type.size_of() < dst_type.size_of() { + Some(ConversionKind::SignExtend) + } else { + None + } + } else { + None + } + } + ast::ScalarKind::Unsigned => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() != ast::ScalarKind::Float + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Float => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() == ast::ScalarKind::Bit + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Pred => None, + }, + (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) + | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { + should_convert_relaxed_dst( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) + } + _ => None, + } +} + +pub(crate) fn should_convert_relaxed_src_wrapper( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if operand_space != instruction_space { + return Err(error_mismatched_type()); + } + if operand_type == instruction_type { + return Ok(None); + } + match should_convert_relaxed_src(operand_type, instruction_type) { + conv @ Some(_) => Ok(conv), + None => Err(error_mismatched_type()), + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands +fn should_convert_relaxed_src( + src_type: &ast::Type, + instr_type: &ast::Type, +) -> Option { + if src_type == instr_type { + return None; + } + match (src_type, instr_type) { + (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ast::ScalarKind::Bit => { + if instr_type.size_of() <= src_type.size_of() { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() != ast::ScalarKind::Float + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Float => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() == ast::ScalarKind::Bit + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Pred => None, + }, + (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) + | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { + should_convert_relaxed_src( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) + } + _ => None, + } +} diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs index 810ef3e..05045b7 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -1,194 +1,194 @@ -use super::*; -use ptx_parser as ast; - -pub(crate) fn run<'input, 'b>( - resolver: &mut ScopedResolver<'input, 'b>, - directives: Vec>>, -) -> Result, TranslateError> { - resolver.start_scope(); - let result = directives - .into_iter() - .map(|directive| run_directive(resolver, directive)) - .collect::, _>>()?; - resolver.end_scope(); - Ok(result) -} - -fn run_directive<'input, 'b>( - resolver: &mut ScopedResolver<'input, 'b>, - directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>, -) -> Result { - Ok(match directive { - ast::Directive::Variable(linking, var) => { - NormalizedDirective2::Variable(linking, run_variable(resolver, var)?) - } - ast::Directive::Method(linking, directive) => { - NormalizedDirective2::Method(run_method(resolver, linking, directive)?) - } - }) -} - -fn run_method<'input, 'b>( - resolver: &mut ScopedResolver<'input, 'b>, - linkage: ast::LinkingDirective, - method: ast::Function<'input, &'input str, ast::Statement>>, -) -> Result { - let is_kernel = method.func_directive.name.is_kernel(); - let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?; - resolver.start_scope(); - let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?; - let body = method - .body - .map(|statements| { - let mut result = Vec::with_capacity(statements.len()); - run_statements(resolver, &mut result, statements)?; - Ok::<_, TranslateError>(result) - }) - .transpose()?; - resolver.end_scope(); - Ok(Function2 { - return_arguments, - name, - input_arguments, - body, - import_as: None, - linkage, - is_kernel, - tuning: method.tuning, - flush_to_zero_f32: false, - flush_to_zero_f16f64: false, - rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, - rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, - }) -} - -fn run_function_decl<'input, 'b>( - resolver: &mut ScopedResolver<'input, 'b>, - func_directive: ast::MethodDeclaration<'input, &'input str>, -) -> Result<(Vec>, Vec>), TranslateError> { - assert!(func_directive.shared_mem.is_none()); - let return_arguments = func_directive - .return_arguments - .into_iter() - .map(|var| run_variable(resolver, var)) - .collect::, _>>()?; - let input_arguments = func_directive - .input_arguments - .into_iter() - .map(|var| run_variable(resolver, var)) - .collect::, _>>()?; - Ok((return_arguments, input_arguments)) -} - -fn run_variable<'input, 'b>( - resolver: &mut ScopedResolver<'input, 'b>, - variable: ast::Variable<&'input str>, -) -> Result, TranslateError> { - Ok(ast::Variable { - name: resolver.add( - Cow::Borrowed(variable.name), - Some((variable.v_type.clone(), variable.state_space)), - )?, - align: variable.align, - v_type: variable.v_type, - state_space: variable.state_space, - array_init: variable.array_init, - }) -} - -fn run_statements<'input, 'b>( - resolver: &mut ScopedResolver<'input, 'b>, - result: &mut Vec, - statements: Vec>>, -) -> Result<(), TranslateError> { - for statement in statements.iter() { - match statement { - ast::Statement::Label(label) => { - resolver.add(Cow::Borrowed(*label), None)?; - } - _ => {} - } - } - for statement in statements { - match statement { - ast::Statement::Label(label) => { - result.push(Statement::Label(resolver.get_in_current_scope(label)?)) - } - ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?, - ast::Statement::Instruction(predicate, instruction) => { - result.push(Statement::Instruction(( - predicate - .map(|pred| { - Ok::<_, TranslateError>(ast::PredAt { - not: pred.not, - label: resolver.get(pred.label)?, - }) - }) - .transpose()?, - run_instruction(resolver, instruction)?, - ))) - } - ast::Statement::Block(block) => { - resolver.start_scope(); - run_statements(resolver, result, block)?; - resolver.end_scope(); - } - } - } - Ok(()) -} - -fn run_instruction<'input, 'b>( - resolver: &mut ScopedResolver<'input, 'b>, - instruction: ast::Instruction>, -) -> Result>, TranslateError> { - ast::visit_map(instruction, &mut |name: &'input str, - _: Option<( - &ast::Type, - ast::StateSpace, - )>, - _, - _| { - resolver.get(&name) - }) -} - -fn run_multivariable<'input, 'b>( - resolver: &mut ScopedResolver<'input, 'b>, - result: &mut Vec, - variable: ast::MultiVariable<&'input str>, -) -> Result<(), TranslateError> { - match variable.count { - Some(count) => { - for i in 0..count { - let name = Cow::Owned(format!("{}{}", variable.var.name, i)); - let ident = resolver.add( - name, - Some((variable.var.v_type.clone(), variable.var.state_space)), - )?; - result.push(Statement::Variable(ast::Variable { - align: variable.var.align, - v_type: variable.var.v_type.clone(), - state_space: variable.var.state_space, - name: ident, - array_init: variable.var.array_init.clone(), - })); - } - } - None => { - let name = Cow::Borrowed(variable.var.name); - let ident = resolver.add( - name, - Some((variable.var.v_type.clone(), variable.var.state_space)), - )?; - result.push(Statement::Variable(ast::Variable { - align: variable.var.align, - v_type: variable.var.v_type.clone(), - state_space: variable.var.state_space, - name: ident, - array_init: variable.var.array_init.clone(), - })); - } - } - Ok(()) -} +use super::*; +use ptx_parser as ast; + +pub(crate) fn run<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + directives: Vec>>, +) -> Result, TranslateError> { + resolver.start_scope(); + let result = directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>()?; + resolver.end_scope(); + Ok(result) +} + +fn run_directive<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>, +) -> Result { + Ok(match directive { + ast::Directive::Variable(linking, var) => { + NormalizedDirective2::Variable(linking, run_variable(resolver, var)?) + } + ast::Directive::Method(linking, directive) => { + NormalizedDirective2::Method(run_method(resolver, linking, directive)?) + } + }) +} + +fn run_method<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + linkage: ast::LinkingDirective, + method: ast::Function<'input, &'input str, ast::Statement>>, +) -> Result { + let is_kernel = method.func_directive.name.is_kernel(); + let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?; + resolver.start_scope(); + let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?; + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + run_statements(resolver, &mut result, statements)?; + Ok::<_, TranslateError>(result) + }) + .transpose()?; + resolver.end_scope(); + Ok(Function2 { + return_arguments, + name, + input_arguments, + body, + import_as: None, + linkage, + is_kernel, + tuning: method.tuning, + flush_to_zero_f32: false, + flush_to_zero_f16f64: false, + rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, + rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, + }) +} + +fn run_function_decl<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + func_directive: ast::MethodDeclaration<'input, &'input str>, +) -> Result<(Vec>, Vec>), TranslateError> { + assert!(func_directive.shared_mem.is_none()); + let return_arguments = func_directive + .return_arguments + .into_iter() + .map(|var| run_variable(resolver, var)) + .collect::, _>>()?; + let input_arguments = func_directive + .input_arguments + .into_iter() + .map(|var| run_variable(resolver, var)) + .collect::, _>>()?; + Ok((return_arguments, input_arguments)) +} + +fn run_variable<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + variable: ast::Variable<&'input str>, +) -> Result, TranslateError> { + Ok(ast::Variable { + name: resolver.add( + Cow::Borrowed(variable.name), + Some((variable.v_type.clone(), variable.state_space)), + )?, + align: variable.align, + v_type: variable.v_type, + state_space: variable.state_space, + array_init: variable.array_init, + }) +} + +fn run_statements<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + result: &mut Vec, + statements: Vec>>, +) -> Result<(), TranslateError> { + for statement in statements.iter() { + match statement { + ast::Statement::Label(label) => { + resolver.add(Cow::Borrowed(*label), None)?; + } + _ => {} + } + } + for statement in statements { + match statement { + ast::Statement::Label(label) => { + result.push(Statement::Label(resolver.get_in_current_scope(label)?)) + } + ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?, + ast::Statement::Instruction(predicate, instruction) => { + result.push(Statement::Instruction(( + predicate + .map(|pred| { + Ok::<_, TranslateError>(ast::PredAt { + not: pred.not, + label: resolver.get(pred.label)?, + }) + }) + .transpose()?, + run_instruction(resolver, instruction)?, + ))) + } + ast::Statement::Block(block) => { + resolver.start_scope(); + run_statements(resolver, result, block)?; + resolver.end_scope(); + } + } + } + Ok(()) +} + +fn run_instruction<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + instruction: ast::Instruction>, +) -> Result>, TranslateError> { + ast::visit_map(instruction, &mut |name: &'input str, + _: Option<( + &ast::Type, + ast::StateSpace, + )>, + _, + _| { + resolver.get(&name) + }) +} + +fn run_multivariable<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + result: &mut Vec, + variable: ast::MultiVariable<&'input str>, +) -> Result<(), TranslateError> { + match variable.count { + Some(count) => { + for i in 0..count { + let name = Cow::Owned(format!("{}{}", variable.var.name, i)); + let ident = resolver.add( + name, + Some((variable.var.v_type.clone(), variable.var.state_space)), + )?; + result.push(Statement::Variable(ast::Variable { + align: variable.var.align, + v_type: variable.var.v_type.clone(), + state_space: variable.var.state_space, + name: ident, + array_init: variable.var.array_init.clone(), + })); + } + } + None => { + let name = Cow::Borrowed(variable.var.name); + let ident = resolver.add( + name, + Some((variable.var.v_type.clone(), variable.var.state_space)), + )?; + result.push(Statement::Variable(ast::Variable { + align: variable.var.align, + v_type: variable.var.v_type.clone(), + state_space: variable.var.state_space, + name: ident, + array_init: variable.var.array_init.clone(), + })); + } + } + Ok(()) +} diff --git a/ptx/src/pass/normalize_predicates2.rs b/ptx/src/pass/normalize_predicates2.rs index ae41021..a053d06 100644 --- a/ptx/src/pass/normalize_predicates2.rs +++ b/ptx/src/pass/normalize_predicates2.rs @@ -1,90 +1,90 @@ -use super::*; -use ptx_parser as ast; - -pub(crate) fn run<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - directives: Vec, -) -> Result, TranslateError> { - directives - .into_iter() - .map(|directive| run_directive(resolver, directive)) - .collect::, _>>() -} - -fn run_directive<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - directive: NormalizedDirective2, -) -> Result { - Ok(match directive { - Directive2::Variable(linking, var) => Directive2::Variable(linking, var), - Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), - }) -} - -fn run_method<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - method: NormalizedFunction2, -) -> Result { - let body = method - .body - .map(|statements| { - let mut result = Vec::with_capacity(statements.len()); - for statement in statements { - run_statement(resolver, &mut result, statement)?; - } - Ok::<_, TranslateError>(result) - }) - .transpose()?; - Ok(Function2 { - body, - return_arguments: method.return_arguments, - name: method.name, - input_arguments: method.input_arguments, - import_as: method.import_as, - tuning: method.tuning, - linkage: method.linkage, - is_kernel: method.is_kernel, - flush_to_zero_f32: method.flush_to_zero_f32, - flush_to_zero_f16f64: method.flush_to_zero_f16f64, - rounding_mode_f32: method.rounding_mode_f32, - rounding_mode_f16f64: method.rounding_mode_f16f64, - }) -} - -fn run_statement<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - result: &mut Vec, - statement: NormalizedStatement, -) -> Result<(), TranslateError> { - Ok(match statement { - Statement::Label(label) => result.push(Statement::Label(label)), - Statement::Variable(var) => result.push(Statement::Variable(var)), - Statement::Instruction((predicate, instruction)) => { - if let Some(pred) = predicate { - let if_true = resolver.register_unnamed(None); - let if_false = resolver.register_unnamed(None); - let folded_bra = match &instruction { - ast::Instruction::Bra { arguments, .. } => Some(arguments.src), - _ => None, - }; - let mut branch = BrachCondition { - predicate: pred.label, - if_true: folded_bra.unwrap_or(if_true), - if_false, - }; - if pred.not { - std::mem::swap(&mut branch.if_true, &mut branch.if_false); - } - result.push(Statement::Conditional(branch)); - if folded_bra.is_none() { - result.push(Statement::Label(if_true)); - result.push(Statement::Instruction(instruction)); - } - result.push(Statement::Label(if_false)); - } else { - result.push(Statement::Instruction(instruction)); - } - } - _ => return Err(error_unreachable()), - }) -} +use super::*; +use ptx_parser as ast; + +pub(crate) fn run<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, +) -> Result, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: NormalizedDirective2, +) -> Result { + Ok(match directive { + Directive2::Variable(linking, var) => Directive2::Variable(linking, var), + Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), + }) +} + +fn run_method<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + method: NormalizedFunction2, +) -> Result { + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(resolver, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { + body, + return_arguments: method.return_arguments, + name: method.name, + input_arguments: method.input_arguments, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + is_kernel: method.is_kernel, + flush_to_zero_f32: method.flush_to_zero_f32, + flush_to_zero_f16f64: method.flush_to_zero_f16f64, + rounding_mode_f32: method.rounding_mode_f32, + rounding_mode_f16f64: method.rounding_mode_f16f64, + }) +} + +fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + result: &mut Vec, + statement: NormalizedStatement, +) -> Result<(), TranslateError> { + Ok(match statement { + Statement::Label(label) => result.push(Statement::Label(label)), + Statement::Variable(var) => result.push(Statement::Variable(var)), + Statement::Instruction((predicate, instruction)) => { + if let Some(pred) = predicate { + let if_true = resolver.register_unnamed(None); + let if_false = resolver.register_unnamed(None); + let folded_bra = match &instruction { + ast::Instruction::Bra { arguments, .. } => Some(arguments.src), + _ => None, + }; + let mut branch = BrachCondition { + predicate: pred.label, + if_true: folded_bra.unwrap_or(if_true), + if_false, + }; + if pred.not { + std::mem::swap(&mut branch.if_true, &mut branch.if_false); + } + result.push(Statement::Conditional(branch)); + if folded_bra.is_none() { + result.push(Statement::Label(if_true)); + result.push(Statement::Instruction(instruction)); + } + result.push(Statement::Label(if_false)); + } else { + result.push(Statement::Instruction(instruction)); + } + } + _ => return Err(error_unreachable()), + }) +} diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs index 84bb442..8123e41 100644 --- a/ptx/src/pass/replace_instructions_with_function_calls.rs +++ b/ptx/src/pass/replace_instructions_with_function_calls.rs @@ -1,268 +1,268 @@ -use super::*; - -pub(super) fn run<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - directives: Vec, SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { - let mut fn_declarations = FxHashMap::default(); - let remapped_directives = directives - .into_iter() - .map(|directive| run_directive(resolver, &mut fn_declarations, directive)) - .collect::, _>>()?; - let mut result = fn_declarations - .into_iter() - .map(|(_, (return_arguments, name, input_arguments))| { - Directive2::Method(Function2 { - return_arguments, - name: name, - input_arguments, - body: None, - import_as: None, - tuning: Vec::new(), - linkage: ast::LinkingDirective::EXTERN, - is_kernel: false, - flush_to_zero_f32: false, - flush_to_zero_f16f64: false, - rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, - rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, - }) - }) - .collect::>(); - result.extend(remapped_directives); - Ok(result) -} - -fn run_directive<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - fn_declarations: &mut FxHashMap< - Cow<'input, str>, - ( - Vec>, - SpirvWord, - Vec>, - ), - >, - directive: Directive2, SpirvWord>, -) -> Result, SpirvWord>, TranslateError> { - Ok(match directive { - var @ Directive2::Variable(..) => var, - Directive2::Method(mut method) => { - method.body = method - .body - .map(|statements| run_statements(resolver, fn_declarations, statements)) - .transpose()?; - Directive2::Method(method) - } - }) -} - -fn run_statements<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - fn_declarations: &mut FxHashMap< - Cow<'input, str>, - ( - Vec>, - SpirvWord, - Vec>, - ), - >, - statements: Vec, SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { - statements - .into_iter() - .map(|statement| { - Ok(match statement { - Statement::Instruction(instruction) => { - Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?) - } - s => s, - }) - }) - .collect::, _>>() -} - -fn run_instruction<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - fn_declarations: &mut FxHashMap< - Cow<'input, str>, - ( - Vec>, - SpirvWord, - Vec>, - ), - >, - instruction: ptx_parser::Instruction, -) -> Result, TranslateError> { - Ok(match instruction { - i @ ptx_parser::Instruction::Sqrt { - data: - ast::RcpData { - kind: ast::RcpKind::Approx, - type_: ast::ScalarType::F32, - flush_to_zero: None | Some(false), - }, - .. - } => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?, - i @ ptx_parser::Instruction::Rsqrt { - data: - ast::TypeFtz { - type_: ast::ScalarType::F32, - flush_to_zero: None | Some(false), - }, - .. - } => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?, - i @ ptx_parser::Instruction::Rcp { - data: - ast::RcpData { - kind: ast::RcpKind::Approx, - type_: ast::ScalarType::F32, - flush_to_zero: None | Some(false), - }, - .. - } => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?, - i @ ptx_parser::Instruction::Ex2 { - data: - ast::TypeFtz { - type_: ast::ScalarType::F32, - flush_to_zero: None | Some(false), - }, - .. - } => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?, - i @ ptx_parser::Instruction::Lg2 { - data: ast::FlushToZero { - flush_to_zero: false, - }, - .. - } => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?, - i @ ptx_parser::Instruction::Activemask { .. } => { - to_call(resolver, fn_declarations, "activemask".into(), i)? - } - i @ ptx_parser::Instruction::Bfe { data, .. } => { - let name = ["bfe_", scalar_to_ptx_name(data)].concat(); - to_call(resolver, fn_declarations, name.into(), i)? - } - i @ ptx_parser::Instruction::Bfi { data, .. } => { - let name = ["bfi_", scalar_to_ptx_name(data)].concat(); - to_call(resolver, fn_declarations, name.into(), i)? - } - i @ ptx_parser::Instruction::Bar { .. } => { - to_call(resolver, fn_declarations, "bar_sync".into(), i)? - } - ptx_parser::Instruction::BarRed { data, arguments } => { - if arguments.src_threadcount.is_some() { - return Err(error_todo()); - } - let name = match data.pred_reduction { - ptx_parser::Reduction::And => "bar_red_and_pred", - ptx_parser::Reduction::Or => "bar_red_or_pred", - }; - to_call( - resolver, - fn_declarations, - name.into(), - ptx_parser::Instruction::BarRed { data, arguments }, - )? - } - ptx_parser::Instruction::ShflSync { data, arguments } => { - let mode = match data.mode { - ptx_parser::ShuffleMode::Up => "up", - ptx_parser::ShuffleMode::Down => "down", - ptx_parser::ShuffleMode::BFly => "bfly", - ptx_parser::ShuffleMode::Idx => "idx", - }; - let pred = if arguments.dst_pred.is_some() { - "_pred" - } else { - "" - }; - to_call( - resolver, - fn_declarations, - format!("shfl_sync_{}_b32{}", mode, pred).into(), - ptx_parser::Instruction::ShflSync { data, arguments }, - )? - } - i @ ptx_parser::Instruction::Nanosleep { .. } => { - to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)? - } - i => i, - }) -} - -fn to_call<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - fn_declarations: &mut FxHashMap< - Cow<'input, str>, - ( - Vec>, - SpirvWord, - Vec>, - ), - >, - name: Cow<'input, str>, - i: ast::Instruction, -) -> Result, TranslateError> { - let mut data_return = Vec::new(); - let mut data_input = Vec::new(); - let mut arguments_return = Vec::new(); - let mut arguments_input = Vec::new(); - ast::visit(&i, &mut |name: &SpirvWord, - type_space: Option<( - &ptx_parser::Type, - ptx_parser::StateSpace, - )>, - is_dst: bool, - _: bool| { - let (type_, space) = type_space.ok_or_else(error_mismatched_type)?; - if is_dst { - data_return.push((type_.clone(), space)); - arguments_return.push(*name); - } else { - data_input.push((type_.clone(), space)); - arguments_input.push(*name); - }; - Ok::<_, TranslateError>(()) - })?; - let fn_name = match fn_declarations.entry(name) { - hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1, - hash_map::Entry::Vacant(vacant_entry) => { - let name = vacant_entry.key().clone(); - let full_name = [ZLUDA_PTX_PREFIX, &*name].concat(); - let name = resolver.register_named(Cow::Owned(full_name.clone()), None); - vacant_entry.insert(( - to_variables(resolver, &data_return), - name, - to_variables(resolver, &data_input), - )); - name - } - }; - Ok(ast::Instruction::Call { - data: ptx_parser::CallDetails { - uniform: false, - return_arguments: data_return, - input_arguments: data_input, - }, - arguments: ptx_parser::CallArgs { - return_arguments: arguments_return, - func: fn_name, - input_arguments: arguments_input, - }, - }) -} - -fn to_variables<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>, -) -> Vec> { - arguments - .iter() - .map(|(type_, space)| ast::Variable { - align: None, - v_type: type_.clone(), - state_space: *space, - name: resolver.register_unnamed(Some((type_.clone(), *space))), - array_init: Vec::new(), - }) - .collect::>() -} +use super::*; + +pub(super) fn run<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + let mut fn_declarations = FxHashMap::default(); + let remapped_directives = directives + .into_iter() + .map(|directive| run_directive(resolver, &mut fn_declarations, directive)) + .collect::, _>>()?; + let mut result = fn_declarations + .into_iter() + .map(|(_, (return_arguments, name, input_arguments))| { + Directive2::Method(Function2 { + return_arguments, + name: name, + input_arguments, + body: None, + import_as: None, + tuning: Vec::new(), + linkage: ast::LinkingDirective::EXTERN, + is_kernel: false, + flush_to_zero_f32: false, + flush_to_zero_f16f64: false, + rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, + rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, + }) + }) + .collect::>(); + result.extend(remapped_directives); + Ok(result) +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut FxHashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + >, + directive: Directive2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(mut method) => { + method.body = method + .body + .map(|statements| run_statements(resolver, fn_declarations, statements)) + .transpose()?; + Directive2::Method(method) + } + }) +} + +fn run_statements<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut FxHashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + >, + statements: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + statements + .into_iter() + .map(|statement| { + Ok(match statement { + Statement::Instruction(instruction) => { + Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?) + } + s => s, + }) + }) + .collect::, _>>() +} + +fn run_instruction<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut FxHashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + >, + instruction: ptx_parser::Instruction, +) -> Result, TranslateError> { + Ok(match instruction { + i @ ptx_parser::Instruction::Sqrt { + data: + ast::RcpData { + kind: ast::RcpKind::Approx, + type_: ast::ScalarType::F32, + flush_to_zero: None | Some(false), + }, + .. + } => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?, + i @ ptx_parser::Instruction::Rsqrt { + data: + ast::TypeFtz { + type_: ast::ScalarType::F32, + flush_to_zero: None | Some(false), + }, + .. + } => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?, + i @ ptx_parser::Instruction::Rcp { + data: + ast::RcpData { + kind: ast::RcpKind::Approx, + type_: ast::ScalarType::F32, + flush_to_zero: None | Some(false), + }, + .. + } => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?, + i @ ptx_parser::Instruction::Ex2 { + data: + ast::TypeFtz { + type_: ast::ScalarType::F32, + flush_to_zero: None | Some(false), + }, + .. + } => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?, + i @ ptx_parser::Instruction::Lg2 { + data: ast::FlushToZero { + flush_to_zero: false, + }, + .. + } => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?, + i @ ptx_parser::Instruction::Activemask { .. } => { + to_call(resolver, fn_declarations, "activemask".into(), i)? + } + i @ ptx_parser::Instruction::Bfe { data, .. } => { + let name = ["bfe_", scalar_to_ptx_name(data)].concat(); + to_call(resolver, fn_declarations, name.into(), i)? + } + i @ ptx_parser::Instruction::Bfi { data, .. } => { + let name = ["bfi_", scalar_to_ptx_name(data)].concat(); + to_call(resolver, fn_declarations, name.into(), i)? + } + i @ ptx_parser::Instruction::Bar { .. } => { + to_call(resolver, fn_declarations, "bar_sync".into(), i)? + } + ptx_parser::Instruction::BarRed { data, arguments } => { + if arguments.src_threadcount.is_some() { + return Err(error_todo()); + } + let name = match data.pred_reduction { + ptx_parser::Reduction::And => "bar_red_and_pred", + ptx_parser::Reduction::Or => "bar_red_or_pred", + }; + to_call( + resolver, + fn_declarations, + name.into(), + ptx_parser::Instruction::BarRed { data, arguments }, + )? + } + ptx_parser::Instruction::ShflSync { data, arguments } => { + let mode = match data.mode { + ptx_parser::ShuffleMode::Up => "up", + ptx_parser::ShuffleMode::Down => "down", + ptx_parser::ShuffleMode::BFly => "bfly", + ptx_parser::ShuffleMode::Idx => "idx", + }; + let pred = if arguments.dst_pred.is_some() { + "_pred" + } else { + "" + }; + to_call( + resolver, + fn_declarations, + format!("shfl_sync_{}_b32{}", mode, pred).into(), + ptx_parser::Instruction::ShflSync { data, arguments }, + )? + } + i @ ptx_parser::Instruction::Nanosleep { .. } => { + to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)? + } + i => i, + }) +} + +fn to_call<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut FxHashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + >, + name: Cow<'input, str>, + i: ast::Instruction, +) -> Result, TranslateError> { + let mut data_return = Vec::new(); + let mut data_input = Vec::new(); + let mut arguments_return = Vec::new(); + let mut arguments_input = Vec::new(); + ast::visit(&i, &mut |name: &SpirvWord, + type_space: Option<( + &ptx_parser::Type, + ptx_parser::StateSpace, + )>, + is_dst: bool, + _: bool| { + let (type_, space) = type_space.ok_or_else(error_mismatched_type)?; + if is_dst { + data_return.push((type_.clone(), space)); + arguments_return.push(*name); + } else { + data_input.push((type_.clone(), space)); + arguments_input.push(*name); + }; + Ok::<_, TranslateError>(()) + })?; + let fn_name = match fn_declarations.entry(name) { + hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1, + hash_map::Entry::Vacant(vacant_entry) => { + let name = vacant_entry.key().clone(); + let full_name = [ZLUDA_PTX_PREFIX, &*name].concat(); + let name = resolver.register_named(Cow::Owned(full_name.clone()), None); + vacant_entry.insert(( + to_variables(resolver, &data_return), + name, + to_variables(resolver, &data_input), + )); + name + } + }; + Ok(ast::Instruction::Call { + data: ptx_parser::CallDetails { + uniform: false, + return_arguments: data_return, + input_arguments: data_input, + }, + arguments: ptx_parser::CallArgs { + return_arguments: arguments_return, + func: fn_name, + input_arguments: arguments_input, + }, + }) +} + +fn to_variables<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>, +) -> Vec> { + arguments + .iter() + .map(|(type_, space)| ast::Variable { + align: None, + v_type: type_.clone(), + state_space: *space, + name: resolver.register_unnamed(Some((type_.clone(), *space))), + array_init: Vec::new(), + }) + .collect::>() +} diff --git a/ptx/src/pass/replace_known_functions.rs b/ptx/src/pass/replace_known_functions.rs index 48f2b45..99509fe 100644 --- a/ptx/src/pass/replace_known_functions.rs +++ b/ptx/src/pass/replace_known_functions.rs @@ -1,33 +1,33 @@ -use std::borrow::Cow; - -use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord}; - -pub(crate) fn run<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - mut directives: Vec, -) -> Vec { - for directive in directives.iter_mut() { - match directive { - NormalizedDirective2::Method(func) => { - replace_with_ptx_impl(resolver, func.name); - } - _ => {} - } - } - directives -} - -fn replace_with_ptx_impl<'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, - fn_name: SpirvWord, -) { - let known_names = ["__assertfail"]; - if let Some(super::IdentEntry { - name: Some(name), .. - }) = resolver.ident_map.get_mut(&fn_name) - { - if known_names.contains(&&**name) { - *name = Cow::Owned(format!("__zluda_ptx_impl_{}", name)); - } - } -} +use std::borrow::Cow; + +use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord}; + +pub(crate) fn run<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + mut directives: Vec, +) -> Vec { + for directive in directives.iter_mut() { + match directive { + NormalizedDirective2::Method(func) => { + replace_with_ptx_impl(resolver, func.name); + } + _ => {} + } + } + directives +} + +fn replace_with_ptx_impl<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_name: SpirvWord, +) { + let known_names = ["__assertfail"]; + if let Some(super::IdentEntry { + name: Some(name), .. + }) = resolver.ident_map.get_mut(&fn_name) + { + if known_names.contains(&&**name) { + *name = Cow::Owned(format!("__zluda_ptx_impl_{}", name)); + } + } +} diff --git a/ptx/src/pass/resolve_function_pointers.rs b/ptx/src/pass/resolve_function_pointers.rs index 81b9f0a..a9448c7 100644 --- a/ptx/src/pass/resolve_function_pointers.rs +++ b/ptx/src/pass/resolve_function_pointers.rs @@ -1,69 +1,69 @@ -use super::*; -use ptx_parser as ast; -use rustc_hash::FxHashSet; - -pub(crate) fn run<'input>( - directives: Vec, -) -> Result, TranslateError> { - let mut functions = FxHashSet::default(); - directives - .into_iter() - .map(|directive| run_directive(&mut functions, directive)) - .collect::, _>>() -} - -fn run_directive<'input>( - functions: &mut FxHashSet, - directive: UnconditionalDirective, -) -> Result { - Ok(match directive { - var @ Directive2::Variable(..) => var, - Directive2::Method(method) => { - if !method.is_kernel { - functions.insert(method.name); - } - Directive2::Method(run_method(functions, method)?) - } - }) -} - -fn run_method<'input>( - functions: &mut FxHashSet, - method: UnconditionalFunction, -) -> Result { - let body = method - .body - .map(|statements| { - statements - .into_iter() - .map(|statement| run_statement(functions, statement)) - .collect::, _>>() - }) - .transpose()?; - Ok(Function2 { body, ..method }) -} - -fn run_statement<'input>( - functions: &mut FxHashSet, - statement: UnconditionalStatement, -) -> Result { - Ok(match statement { - Statement::Instruction(ast::Instruction::Mov { - data, - arguments: - ast::MovArgs { - dst: ast::ParsedOperand::Reg(dst_reg), - src: ast::ParsedOperand::Reg(src_reg), - }, - }) if functions.contains(&src_reg) => { - if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { - return Err(error_mismatched_type()); - } - UnconditionalStatement::FunctionPointer(FunctionPointerDetails { - dst: dst_reg, - src: src_reg, - }) - } - s => s, - }) -} +use super::*; +use ptx_parser as ast; +use rustc_hash::FxHashSet; + +pub(crate) fn run<'input>( + directives: Vec, +) -> Result, TranslateError> { + let mut functions = FxHashSet::default(); + directives + .into_iter() + .map(|directive| run_directive(&mut functions, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + functions: &mut FxHashSet, + directive: UnconditionalDirective, +) -> Result { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => { + if !method.is_kernel { + functions.insert(method.name); + } + Directive2::Method(run_method(functions, method)?) + } + }) +} + +fn run_method<'input>( + functions: &mut FxHashSet, + method: UnconditionalFunction, +) -> Result { + let body = method + .body + .map(|statements| { + statements + .into_iter() + .map(|statement| run_statement(functions, statement)) + .collect::, _>>() + }) + .transpose()?; + Ok(Function2 { body, ..method }) +} + +fn run_statement<'input>( + functions: &mut FxHashSet, + statement: UnconditionalStatement, +) -> Result { + Ok(match statement { + Statement::Instruction(ast::Instruction::Mov { + data, + arguments: + ast::MovArgs { + dst: ast::ParsedOperand::Reg(dst_reg), + src: ast::ParsedOperand::Reg(src_reg), + }, + }) if functions.contains(&src_reg) => { + if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { + return Err(error_mismatched_type()); + } + UnconditionalStatement::FunctionPointer(FunctionPointerDetails { + dst: dst_reg, + src: src_reg, + }) + } + s => s, + }) +} diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 50a4124..8d5e0be 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -1,327 +1,327 @@ -use bpaf::{Args, Bpaf, Parser}; -use cargo_metadata::{MetadataCommand, Package}; -use serde::Deserialize; -use std::{env, ffi::OsString, path::PathBuf, process::Command}; - -#[derive(Debug, Clone, Bpaf)] -#[bpaf(options)] -enum Options { - #[bpaf(command)] - /// Compile ZLUDA (default command) - Build(#[bpaf(external(build))] Build), - #[bpaf(command)] - /// Compile ZLUDA and build a package - Zip(#[bpaf(external(build))] Build), -} - -#[derive(Debug, Clone, Bpaf)] -struct Build { - #[bpaf(any("CARGO", not_help), many)] - /// Arguments to pass to cargo, e.g. `--release` for release build - cargo_arguments: Vec, -} - -fn not_help(s: OsString) -> Option { - if s == "-h" || s == "--help" { - None - } else { - Some(s) - } -} - -// We need to sniff out some args passed to cargo to understand how to create -// symlinks (should they go into `target/debug`, `target/release` or custom) -#[derive(Debug, Clone, Bpaf)] -struct Cargo { - #[bpaf(switch, long, short)] - release: Option, - #[bpaf(long)] - profile: Option, - #[bpaf(any("", Some), many)] - _unused: Vec, -} - -struct Project { - name: String, - target_name: String, - target_kind: ProjectTarget, - meta: ZludaMetadata, -} - -impl Project { - fn try_new(p: Package) -> Option { - let name = p.name; - serde_json::from_value::>(p.metadata) - .unwrap() - .map(|m| { - let (target_name, target_kind) = p - .targets - .into_iter() - .find_map(|target| { - if target.is_cdylib() { - Some((target.name, ProjectTarget::Cdylib)) - } else if target.is_bin() { - Some((target.name, ProjectTarget::Bin)) - } else { - None - } - }) - .unwrap(); - Self { - name, - target_name, - target_kind, - meta: m.zluda, - } - }) - } - - #[cfg(unix)] - fn prefix(&self) -> &'static str { - match self.target_kind { - ProjectTarget::Bin => "", - ProjectTarget::Cdylib => "lib", - } - } - - #[cfg(not(unix))] - fn prefix(&self) -> &'static str { - "" - } - - #[cfg(unix)] - fn suffix(&self) -> &'static str { - match self.target_kind { - ProjectTarget::Bin => "", - ProjectTarget::Cdylib => ".so", - } - } - - #[cfg(not(unix))] - fn suffix(&self) -> &'static str { - match self.target_kind { - ProjectTarget::Bin => ".exe", - ProjectTarget::Cdylib => ".dll", - } - } - - // Returns tuple: - // * symlink file path (relative to the root of build dir) - // * symlink absolute file path - // * target actual file (relative to symlink file) - #[cfg_attr(not(unix), allow(unused))] - fn symlinks<'a>( - &'a self, - target_dir: &'a PathBuf, - profile: &'a str, - libname: &'a str, - ) -> impl Iterator + 'a { - self.meta.linux_symlinks.iter().map(move |source| { - let mut link = target_dir.clone(); - link.extend([profile, source]); - let relative_link = PathBuf::from(source); - let ancestors = relative_link.as_path().ancestors().count(); - let mut target = std::iter::repeat_with(|| "../").take(ancestors - 2).fold( - PathBuf::new(), - |mut buff, segment| { - buff.push(segment); - buff - }, - ); - target.push(libname); - (&**source, link, target) - }) - } - - fn file_name(&self) -> String { - let target_name = &self.target_name; - let prefix = self.prefix(); - let suffix = self.suffix(); - format!("{prefix}{target_name}{suffix}") - } -} - -#[derive(Clone, Copy)] -enum ProjectTarget { - Cdylib, - Bin, -} - -#[derive(Deserialize)] -struct Metadata { - zluda: ZludaMetadata, -} - -#[derive(Deserialize)] -#[serde(deny_unknown_fields)] -struct ZludaMetadata { - #[serde(default)] - windows_only: bool, - #[serde(default)] - debug_only: bool, - #[cfg_attr(not(unix), allow(unused))] - #[serde(default)] - linux_symlinks: Vec, -} - -fn main() { - let options = match options().run_inner(Args::current_args()) { - Ok(b) => b, - Err(err) => match build().to_options().run_inner(Args::current_args()) { - Ok(b) => Options::Build(b), - Err(_) => { - err.print_message(100); - std::process::exit(err.exit_code()); - } - }, - }; - match options { - Options::Build(b) => { - compile(b); - } - Options::Zip(b) => zip(b), - } -} - -fn compile(b: Build) -> (PathBuf, String, Vec) { - let profile = sniff_out_profile_name(&b.cargo_arguments); - let meta = MetadataCommand::new().no_deps().exec().unwrap(); - let target_directory = meta.target_directory.into_std_path_buf(); - let projects = meta - .packages - .into_iter() - .filter_map(Project::try_new) - .filter(|project| { - if project.meta.windows_only && cfg!(not(windows)) { - return false; - } - if project.meta.debug_only && profile != "debug" { - return false; - } - true - }) - .collect::>(); - let cargo = env::var("CARGO").unwrap_or_else(|_| "cargo".to_string()); - let mut command = Command::new(&cargo); - command.arg("build"); - command.arg("--locked"); - for project in projects.iter() { - command.arg("--package"); - command.arg(&project.name); - } - command.args(b.cargo_arguments); - assert!(command.status().unwrap().success()); - os::make_symlinks(&target_directory, &*projects, &*profile); - (target_directory, profile, projects) -} - -fn sniff_out_profile_name(b: &[OsString]) -> String { - let parsed_cargo_arguments = cargo().to_options().run_inner(b); - match parsed_cargo_arguments { - Ok(Cargo { - release: Some(true), - .. - }) => "release".to_string(), - Ok(Cargo { - profile: Some(profile), - .. - }) => profile, - _ => "debug".to_string(), - } -} - -fn zip(zip: Build) { - let (target_dir, profile, projects) = compile(zip); - os::zip(target_dir, profile, projects) -} - -#[cfg(unix)] -mod os { - use flate2::write::GzEncoder; - use flate2::Compression; - use std::{ - fs::{self, File}, - path::PathBuf, - }; - use tar::Header; - - pub fn make_symlinks( - target_directory: &std::path::PathBuf, - projects: &[super::Project], - profile: &str, - ) { - use std::os::unix::fs as unix_fs; - for project in projects.iter() { - let libname = project.file_name(); - for (_, full_path, target) in project.symlinks(target_directory, profile, &libname) { - let mut dir = full_path.clone(); - assert!(dir.pop()); - fs::create_dir_all(dir).unwrap(); - fs::remove_file(&full_path).ok(); - unix_fs::symlink(&target, full_path).unwrap(); - } - } - } - - pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec) { - let tar_gz = - File::create(format!("{}/{profile}/zluda.tar.gz", target_dir.display())).unwrap(); - let enc = GzEncoder::new(tar_gz, Compression::default()); - let mut tar = tar::Builder::new(enc); - for project in projects.iter() { - let file_name = project.file_name(); - let mut file = - File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap(); - tar.append_file(format!("zluda/{file_name}"), &mut file) - .unwrap(); - for (source, full_path, target) in project.symlinks(&target_dir, &profile, &file_name) { - let mut header = Header::new_gnu(); - let meta = fs::symlink_metadata(&full_path).unwrap(); - header.set_metadata(&meta); - tar.append_link(&mut header, format!("zluda/{source}"), target) - .unwrap(); - } - } - tar.finish().unwrap(); - } -} - -#[cfg(not(unix))] -mod os { - use std::{fs::File, io, path::PathBuf}; - use zip::{write::SimpleFileOptions, ZipWriter}; - - pub fn make_symlinks( - _target_directory: &std::path::PathBuf, - _projects: &[super::Project], - _profile: &str, - ) { - } - - pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec) { - let zip_file = - File::create(format!("{}/{profile}/zluda.zip", target_dir.display())).unwrap(); - let mut zip = ZipWriter::new(zip_file); - zip.add_directory("zluda", SimpleFileOptions::default()) - .unwrap(); - for project in projects.iter() { - let file_name = project.file_name(); - let mut file = - File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap(); - let file_options = file_options_from_time(&file).unwrap_or_default(); - zip.start_file(format!("zluda/{file_name}"), file_options) - .unwrap(); - io::copy(&mut file, &mut zip).unwrap(); - } - zip.finish().unwrap(); - } - - fn file_options_from_time(from: &File) -> io::Result { - let metadata = from.metadata()?; - let modified = metadata.modified()?; - let modified = time::OffsetDateTime::from(modified); - Ok(SimpleFileOptions::default().last_modified_time( - zip::DateTime::try_from(modified).map_err(|err| io::Error::other(err))?, - )) - } -} +use bpaf::{Args, Bpaf, Parser}; +use cargo_metadata::{MetadataCommand, Package}; +use serde::Deserialize; +use std::{env, ffi::OsString, path::PathBuf, process::Command}; + +#[derive(Debug, Clone, Bpaf)] +#[bpaf(options)] +enum Options { + #[bpaf(command)] + /// Compile ZLUDA (default command) + Build(#[bpaf(external(build))] Build), + #[bpaf(command)] + /// Compile ZLUDA and build a package + Zip(#[bpaf(external(build))] Build), +} + +#[derive(Debug, Clone, Bpaf)] +struct Build { + #[bpaf(any("CARGO", not_help), many)] + /// Arguments to pass to cargo, e.g. `--release` for release build + cargo_arguments: Vec, +} + +fn not_help(s: OsString) -> Option { + if s == "-h" || s == "--help" { + None + } else { + Some(s) + } +} + +// We need to sniff out some args passed to cargo to understand how to create +// symlinks (should they go into `target/debug`, `target/release` or custom) +#[derive(Debug, Clone, Bpaf)] +struct Cargo { + #[bpaf(switch, long, short)] + release: Option, + #[bpaf(long)] + profile: Option, + #[bpaf(any("", Some), many)] + _unused: Vec, +} + +struct Project { + name: String, + target_name: String, + target_kind: ProjectTarget, + meta: ZludaMetadata, +} + +impl Project { + fn try_new(p: Package) -> Option { + let name = p.name; + serde_json::from_value::>(p.metadata) + .unwrap() + .map(|m| { + let (target_name, target_kind) = p + .targets + .into_iter() + .find_map(|target| { + if target.is_cdylib() { + Some((target.name, ProjectTarget::Cdylib)) + } else if target.is_bin() { + Some((target.name, ProjectTarget::Bin)) + } else { + None + } + }) + .unwrap(); + Self { + name, + target_name, + target_kind, + meta: m.zluda, + } + }) + } + + #[cfg(unix)] + fn prefix(&self) -> &'static str { + match self.target_kind { + ProjectTarget::Bin => "", + ProjectTarget::Cdylib => "lib", + } + } + + #[cfg(not(unix))] + fn prefix(&self) -> &'static str { + "" + } + + #[cfg(unix)] + fn suffix(&self) -> &'static str { + match self.target_kind { + ProjectTarget::Bin => "", + ProjectTarget::Cdylib => ".so", + } + } + + #[cfg(not(unix))] + fn suffix(&self) -> &'static str { + match self.target_kind { + ProjectTarget::Bin => ".exe", + ProjectTarget::Cdylib => ".dll", + } + } + + // Returns tuple: + // * symlink file path (relative to the root of build dir) + // * symlink absolute file path + // * target actual file (relative to symlink file) + #[cfg_attr(not(unix), allow(unused))] + fn symlinks<'a>( + &'a self, + target_dir: &'a PathBuf, + profile: &'a str, + libname: &'a str, + ) -> impl Iterator + 'a { + self.meta.linux_symlinks.iter().map(move |source| { + let mut link = target_dir.clone(); + link.extend([profile, source]); + let relative_link = PathBuf::from(source); + let ancestors = relative_link.as_path().ancestors().count(); + let mut target = std::iter::repeat_with(|| "../").take(ancestors - 2).fold( + PathBuf::new(), + |mut buff, segment| { + buff.push(segment); + buff + }, + ); + target.push(libname); + (&**source, link, target) + }) + } + + fn file_name(&self) -> String { + let target_name = &self.target_name; + let prefix = self.prefix(); + let suffix = self.suffix(); + format!("{prefix}{target_name}{suffix}") + } +} + +#[derive(Clone, Copy)] +enum ProjectTarget { + Cdylib, + Bin, +} + +#[derive(Deserialize)] +struct Metadata { + zluda: ZludaMetadata, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct ZludaMetadata { + #[serde(default)] + windows_only: bool, + #[serde(default)] + debug_only: bool, + #[cfg_attr(not(unix), allow(unused))] + #[serde(default)] + linux_symlinks: Vec, +} + +fn main() { + let options = match options().run_inner(Args::current_args()) { + Ok(b) => b, + Err(err) => match build().to_options().run_inner(Args::current_args()) { + Ok(b) => Options::Build(b), + Err(_) => { + err.print_message(100); + std::process::exit(err.exit_code()); + } + }, + }; + match options { + Options::Build(b) => { + compile(b); + } + Options::Zip(b) => zip(b), + } +} + +fn compile(b: Build) -> (PathBuf, String, Vec) { + let profile = sniff_out_profile_name(&b.cargo_arguments); + let meta = MetadataCommand::new().no_deps().exec().unwrap(); + let target_directory = meta.target_directory.into_std_path_buf(); + let projects = meta + .packages + .into_iter() + .filter_map(Project::try_new) + .filter(|project| { + if project.meta.windows_only && cfg!(not(windows)) { + return false; + } + if project.meta.debug_only && profile != "debug" { + return false; + } + true + }) + .collect::>(); + let cargo = env::var("CARGO").unwrap_or_else(|_| "cargo".to_string()); + let mut command = Command::new(&cargo); + command.arg("build"); + command.arg("--locked"); + for project in projects.iter() { + command.arg("--package"); + command.arg(&project.name); + } + command.args(b.cargo_arguments); + assert!(command.status().unwrap().success()); + os::make_symlinks(&target_directory, &*projects, &*profile); + (target_directory, profile, projects) +} + +fn sniff_out_profile_name(b: &[OsString]) -> String { + let parsed_cargo_arguments = cargo().to_options().run_inner(b); + match parsed_cargo_arguments { + Ok(Cargo { + release: Some(true), + .. + }) => "release".to_string(), + Ok(Cargo { + profile: Some(profile), + .. + }) => profile, + _ => "debug".to_string(), + } +} + +fn zip(zip: Build) { + let (target_dir, profile, projects) = compile(zip); + os::zip(target_dir, profile, projects) +} + +#[cfg(unix)] +mod os { + use flate2::write::GzEncoder; + use flate2::Compression; + use std::{ + fs::{self, File}, + path::PathBuf, + }; + use tar::Header; + + pub fn make_symlinks( + target_directory: &std::path::PathBuf, + projects: &[super::Project], + profile: &str, + ) { + use std::os::unix::fs as unix_fs; + for project in projects.iter() { + let libname = project.file_name(); + for (_, full_path, target) in project.symlinks(target_directory, profile, &libname) { + let mut dir = full_path.clone(); + assert!(dir.pop()); + fs::create_dir_all(dir).unwrap(); + fs::remove_file(&full_path).ok(); + unix_fs::symlink(&target, full_path).unwrap(); + } + } + } + + pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec) { + let tar_gz = + File::create(format!("{}/{profile}/zluda.tar.gz", target_dir.display())).unwrap(); + let enc = GzEncoder::new(tar_gz, Compression::default()); + let mut tar = tar::Builder::new(enc); + for project in projects.iter() { + let file_name = project.file_name(); + let mut file = + File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap(); + tar.append_file(format!("zluda/{file_name}"), &mut file) + .unwrap(); + for (source, full_path, target) in project.symlinks(&target_dir, &profile, &file_name) { + let mut header = Header::new_gnu(); + let meta = fs::symlink_metadata(&full_path).unwrap(); + header.set_metadata(&meta); + tar.append_link(&mut header, format!("zluda/{source}"), target) + .unwrap(); + } + } + tar.finish().unwrap(); + } +} + +#[cfg(not(unix))] +mod os { + use std::{fs::File, io, path::PathBuf}; + use zip::{write::SimpleFileOptions, ZipWriter}; + + pub fn make_symlinks( + _target_directory: &std::path::PathBuf, + _projects: &[super::Project], + _profile: &str, + ) { + } + + pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec) { + let zip_file = + File::create(format!("{}/{profile}/zluda.zip", target_dir.display())).unwrap(); + let mut zip = ZipWriter::new(zip_file); + zip.add_directory("zluda", SimpleFileOptions::default()) + .unwrap(); + for project in projects.iter() { + let file_name = project.file_name(); + let mut file = + File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap(); + let file_options = file_options_from_time(&file).unwrap_or_default(); + zip.start_file(format!("zluda/{file_name}"), file_options) + .unwrap(); + io::copy(&mut file, &mut zip).unwrap(); + } + zip.finish().unwrap(); + } + + fn file_options_from_time(from: &File) -> io::Result { + let metadata = from.metadata()?; + let modified = metadata.modified()?; + let modified = time::OffsetDateTime::from(modified); + Ok(SimpleFileOptions::default().last_modified_time( + zip::DateTime::try_from(modified).map_err(|err| io::Error::other(err))?, + )) + } +} diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs index 43e3d42..7346d62 100644 --- a/zluda/src/impl/driver.rs +++ b/zluda/src/impl/driver.rs @@ -1,426 +1,426 @@ -use super::{FromCuda, LiveCheck}; -use crate::r#impl::{context, device}; -use comgr::Comgr; -use cuda_types::cuda::*; -use hip_runtime_sys::*; -use std::{ - ffi::{c_void, CStr, CString}, - mem, ptr, slice, - sync::OnceLock, - usize, -}; - -#[cfg_attr(windows, path = "os_win.rs")] -#[cfg_attr(not(windows), path = "os_unix.rs")] -mod os; - -pub(crate) struct GlobalState { - pub devices: Vec, - pub comgr: Comgr, -} - -pub(crate) struct Device { - pub(crate) _comgr_isa: CString, - primary_context: LiveCheck, -} - -impl Device { - pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) { - unsafe { - ( - self.primary_context.data.assume_init_ref(), - self.primary_context.as_handle(), - ) - } - } -} - -pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> { - global_state()? - .devices - .get(dev as usize) - .ok_or(CUerror::INVALID_DEVICE) -} - -pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> { - static GLOBAL_STATE: OnceLock> = OnceLock::new(); - fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] { - unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) } - } - GLOBAL_STATE - .get_or_init(|| { - let mut device_count = 0; - unsafe { hipGetDeviceCount(&mut device_count) }?; - let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?; - Ok(GlobalState { - comgr, - devices: (0..device_count) - .map(|i| { - let mut props = unsafe { mem::zeroed() }; - unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?; - Ok::<_, CUerror>(Device { - _comgr_isa: CStr::from_bytes_until_nul(cast_slice( - &props.gcnArchName[..], - )) - .map_err(|_| CUerror::UNKNOWN)? - .to_owned(), - primary_context: LiveCheck::new(context::Context::new(i)), - }) - }) - .collect::, _>>()?, - }) - }) - .as_ref() - .map_err(|e| *e) -} - -pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult { - unsafe { hipInit(flags) }?; - global_state()?; - Ok(()) -} - -struct UnknownBuffer { - buffer: std::cell::UnsafeCell<[u32; S]>, -} - -impl UnknownBuffer { - const fn new() -> Self { - UnknownBuffer { - buffer: std::cell::UnsafeCell::new([0; S]), - } - } - const fn len(&self) -> usize { - S - } -} - -unsafe impl Sync for UnknownBuffer {} - -static UNKNOWN_BUFFER1: UnknownBuffer<1024> = UnknownBuffer::new(); -static UNKNOWN_BUFFER2: UnknownBuffer<14> = UnknownBuffer::new(); - -struct DarkApi {} - -impl ::dark_api::cuda::CudaDarkApi for DarkApi { - unsafe extern "system" fn get_module_from_cubin( - _module: *mut cuda_types::cuda::CUmodule, - _fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper, - ) -> cuda_types::cuda::CUresult { - todo!() - } - - unsafe extern "system" fn cudart_interface_fn2( - pctx: *mut cuda_types::cuda::CUcontext, - hip_dev: hipDevice_t, - ) -> cuda_types::cuda::CUresult { - let pctx = match pctx.as_mut() { - Some(p) => p, - None => return CUresult::ERROR_INVALID_VALUE, - }; - - device::primary_context_retain(pctx, hip_dev) - } - - unsafe extern "system" fn get_module_from_cubin_ext1( - _result: *mut cuda_types::cuda::CUmodule, - _fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper, - _arg3: *mut std::ffi::c_void, - _arg4: *mut std::ffi::c_void, - _arg5: u32, - ) -> cuda_types::cuda::CUresult { - todo!() - } - - unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult { - todo!() - } - - unsafe extern "system" fn get_module_from_cubin_ext2( - _fatbin_header: *const cuda_types::dark_api::FatbinHeader, - _result: *mut cuda_types::cuda::CUmodule, - _arg3: *mut std::ffi::c_void, - _arg4: *mut std::ffi::c_void, - _arg5: u32, - ) -> cuda_types::cuda::CUresult { - todo!() - } - - unsafe extern "system" fn get_unknown_buffer1( - ptr: *mut *mut std::ffi::c_void, - size: *mut usize, - ) -> () { - *ptr = UNKNOWN_BUFFER1.buffer.get() as *mut std::ffi::c_void; - *size = UNKNOWN_BUFFER1.len(); - } - - unsafe extern "system" fn get_unknown_buffer2( - ptr: *mut *mut std::ffi::c_void, - size: *mut usize, - ) -> () { - *ptr = UNKNOWN_BUFFER2.buffer.get() as *mut std::ffi::c_void; - *size = UNKNOWN_BUFFER2.len(); - } - - unsafe extern "system" fn context_local_storage_put( - cu_ctx: CUcontext, - key: *mut c_void, - value: *mut c_void, - dtor_cb: Option, - ) -> CUresult { - let _ctx = if cu_ctx.0 != ptr::null_mut() { - cu_ctx - } else { - let mut current_ctx: CUcontext = CUcontext(ptr::null_mut()); - context::get_current(&mut current_ctx)?; - current_ctx - }; - let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?; - ctx_obj.with_state_mut(|state: &mut context::ContextState| { - state.storage.insert( - key as usize, - context::StorageData { - value: value as usize, - reset_cb: dtor_cb, - handle: _ctx, - }, - ); - Ok(()) - })?; - Ok(()) - } - - unsafe extern "system" fn context_local_storage_delete( - cu_ctx: CUcontext, - key: *mut c_void, - ) -> CUresult { - let ctx_obj: &context::Context = FromCuda::from_cuda(&cu_ctx)?; - ctx_obj.with_state_mut(|state: &mut context::ContextState| { - state.storage.remove(&(key as usize)); - Ok(()) - })?; - Ok(()) - } - - unsafe extern "system" fn context_local_storage_get( - value: *mut *mut c_void, - cu_ctx: CUcontext, - key: *mut c_void, - ) -> CUresult { - let mut _ctx: CUcontext; - if cu_ctx.0 == ptr::null_mut() { - _ctx = context::get_current_context()?; - } else { - _ctx = cu_ctx - }; - let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?; - ctx_obj.with_state(|state: &context::ContextState| { - match state.storage.get(&(key as usize)) { - Some(data) => *value = data.value as *mut c_void, - None => return CUresult::ERROR_INVALID_HANDLE, - } - Ok(()) - })?; - Ok(()) - } - - unsafe extern "system" fn ctx_create_v2_bypass( - _pctx: *mut cuda_types::cuda::CUcontext, - _flags: ::std::os::raw::c_uint, - _dev: cuda_types::cuda::CUdevice, - ) -> cuda_types::cuda::CUresult { - todo!() - } - - unsafe extern "system" fn heap_alloc( - _heap_alloc_record_ptr: *mut *const std::ffi::c_void, - _arg2: usize, - _arg3: usize, - ) -> cuda_types::cuda::CUresult { - todo!() - } - - unsafe extern "system" fn heap_free( - _heap_alloc_record_ptr: *const std::ffi::c_void, - _arg2: *mut usize, - ) -> cuda_types::cuda::CUresult { - todo!() - } - - unsafe extern "system" fn device_get_attribute_ext( - _dev: cuda_types::cuda::CUdevice, - _attribute: std::ffi::c_uint, - _unknown: std::ffi::c_int, - _result: *mut [usize; 2], - ) -> cuda_types::cuda::CUresult { - todo!() - } - - unsafe extern "system" fn device_get_something( - _result: *mut std::ffi::c_uchar, - _dev: cuda_types::cuda::CUdevice, - ) -> cuda_types::cuda::CUresult { - todo!() - } - - unsafe extern "system" fn integrity_check( - version: u32, - unix_seconds: u64, - result: *mut [u64; 2], - ) -> cuda_types::cuda::CUresult { - let current_process = std::process::id(); - let current_thread = os::current_thread(); - - let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast(); - let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast(); - let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1]; - - let devices = get_device_hash_info()?; - let device_count = devices.len() as u32; - let get_device = |dev| devices[dev as usize]; - - let hash = ::dark_api::integrity_check( - version, - unix_seconds, - cuda_types::cuda::CUDA_VERSION, - current_process, - current_thread, - integrity_check_table, - cudart_table, - fn_address, - device_count, - get_device, - ); - *result = hash; - Ok(()) - } - - unsafe extern "system" fn context_check( - _ctx_in: cuda_types::cuda::CUcontext, - result1: *mut u32, - _result2: *mut *const std::ffi::c_void, - ) -> cuda_types::cuda::CUresult { - *result1 = 0; - CUresult::SUCCESS - } - - unsafe extern "system" fn check_fn3() -> u32 { - 0 - } -} - -fn get_device_hash_info() -> Result, CUerror> { - let mut device_count = 0; - device::get_count(&mut device_count)?; - - (0..device_count) - .map(|dev| { - let mut guid = CUuuid_st { bytes: [0; 16] }; - unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? }; - - let mut pci_domain = 0; - device::get_attribute( - &mut pci_domain, - CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, - dev, - )?; - - let mut pci_bus = 0; - device::get_attribute( - &mut pci_bus, - CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, - dev, - )?; - - let mut pci_device = 0; - device::get_attribute( - &mut pci_device, - CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, - dev, - )?; - - Ok(::dark_api::DeviceHashinfo { - guid, - pci_domain, - pci_bus, - pci_device, - }) - }) - .collect() -} - -static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable = - ::dark_api::cuda::CudaDarkApiGlobalTable::new::(); - -pub(crate) fn get_export_table( - pp_export_table: &mut *const ::core::ffi::c_void, - p_export_table_id: &CUuuid, -) -> CUresult { - if let Some(table) = EXPORT_TABLE.get(p_export_table_id) { - *pp_export_table = table.start(); - cuda_types::cuda::CUresult::SUCCESS - } else { - cuda_types::cuda::CUresult::ERROR_INVALID_VALUE - } -} - -pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult { - *version = cuda_types::cuda::CUDA_VERSION as i32; - Ok(()) -} - -pub(crate) unsafe fn get_proc_address( - symbol: &CStr, - pfn: &mut *mut ::core::ffi::c_void, - cuda_version: ::core::ffi::c_int, - flags: cuda_types::cuda::cuuint64_t, -) -> CUresult { - get_proc_address_v2(symbol, pfn, cuda_version, flags, None) -} - -pub(crate) unsafe fn get_proc_address_v2( - symbol: &CStr, - pfn: &mut *mut ::core::ffi::c_void, - cuda_version: ::core::ffi::c_int, - flags: cuda_types::cuda::cuuint64_t, - symbol_status: Option<&mut cuda_types::cuda::CUdriverProcAddressQueryResult>, -) -> CUresult { - // This implementation is mostly the same as cuGetProcAddress_v2 in zluda_dump. We may want to factor out the duplication at some point. - fn raw_match(name: &[u8], flag: u64, version: i32) -> *mut ::core::ffi::c_void { - use crate::*; - include!("../../../zluda_bindgen/src/process_table.rs") - } - let fn_ptr = raw_match(symbol.to_bytes(), flags, cuda_version); - match fn_ptr as usize { - 0 => { - if let Some(symbol_status) = symbol_status { - *symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND; - } - *pfn = ptr::null_mut(); - CUresult::ERROR_NOT_FOUND - } - usize::MAX => { - if let Some(symbol_status) = symbol_status { - *symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT; - } - *pfn = ptr::null_mut(); - CUresult::ERROR_NOT_FOUND - } - _ => { - if let Some(symbol_status) = symbol_status { - *symbol_status = - cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SUCCESS; - } - *pfn = fn_ptr; - Ok(()) - } - } -} - -pub(crate) fn profiler_start() -> CUresult { - Ok(()) -} - -pub(crate) fn profiler_stop() -> CUresult { - Ok(()) -} +use super::{FromCuda, LiveCheck}; +use crate::r#impl::{context, device}; +use comgr::Comgr; +use cuda_types::cuda::*; +use hip_runtime_sys::*; +use std::{ + ffi::{c_void, CStr, CString}, + mem, ptr, slice, + sync::OnceLock, + usize, +}; + +#[cfg_attr(windows, path = "os_win.rs")] +#[cfg_attr(not(windows), path = "os_unix.rs")] +mod os; + +pub(crate) struct GlobalState { + pub devices: Vec, + pub comgr: Comgr, +} + +pub(crate) struct Device { + pub(crate) _comgr_isa: CString, + primary_context: LiveCheck, +} + +impl Device { + pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) { + unsafe { + ( + self.primary_context.data.assume_init_ref(), + self.primary_context.as_handle(), + ) + } + } +} + +pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> { + global_state()? + .devices + .get(dev as usize) + .ok_or(CUerror::INVALID_DEVICE) +} + +pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> { + static GLOBAL_STATE: OnceLock> = OnceLock::new(); + fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] { + unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) } + } + GLOBAL_STATE + .get_or_init(|| { + let mut device_count = 0; + unsafe { hipGetDeviceCount(&mut device_count) }?; + let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?; + Ok(GlobalState { + comgr, + devices: (0..device_count) + .map(|i| { + let mut props = unsafe { mem::zeroed() }; + unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?; + Ok::<_, CUerror>(Device { + _comgr_isa: CStr::from_bytes_until_nul(cast_slice( + &props.gcnArchName[..], + )) + .map_err(|_| CUerror::UNKNOWN)? + .to_owned(), + primary_context: LiveCheck::new(context::Context::new(i)), + }) + }) + .collect::, _>>()?, + }) + }) + .as_ref() + .map_err(|e| *e) +} + +pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult { + unsafe { hipInit(flags) }?; + global_state()?; + Ok(()) +} + +struct UnknownBuffer { + buffer: std::cell::UnsafeCell<[u32; S]>, +} + +impl UnknownBuffer { + const fn new() -> Self { + UnknownBuffer { + buffer: std::cell::UnsafeCell::new([0; S]), + } + } + const fn len(&self) -> usize { + S + } +} + +unsafe impl Sync for UnknownBuffer {} + +static UNKNOWN_BUFFER1: UnknownBuffer<1024> = UnknownBuffer::new(); +static UNKNOWN_BUFFER2: UnknownBuffer<14> = UnknownBuffer::new(); + +struct DarkApi {} + +impl ::dark_api::cuda::CudaDarkApi for DarkApi { + unsafe extern "system" fn get_module_from_cubin( + _module: *mut cuda_types::cuda::CUmodule, + _fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper, + ) -> cuda_types::cuda::CUresult { + todo!() + } + + unsafe extern "system" fn cudart_interface_fn2( + pctx: *mut cuda_types::cuda::CUcontext, + hip_dev: hipDevice_t, + ) -> cuda_types::cuda::CUresult { + let pctx = match pctx.as_mut() { + Some(p) => p, + None => return CUresult::ERROR_INVALID_VALUE, + }; + + device::primary_context_retain(pctx, hip_dev) + } + + unsafe extern "system" fn get_module_from_cubin_ext1( + _result: *mut cuda_types::cuda::CUmodule, + _fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper, + _arg3: *mut std::ffi::c_void, + _arg4: *mut std::ffi::c_void, + _arg5: u32, + ) -> cuda_types::cuda::CUresult { + todo!() + } + + unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult { + todo!() + } + + unsafe extern "system" fn get_module_from_cubin_ext2( + _fatbin_header: *const cuda_types::dark_api::FatbinHeader, + _result: *mut cuda_types::cuda::CUmodule, + _arg3: *mut std::ffi::c_void, + _arg4: *mut std::ffi::c_void, + _arg5: u32, + ) -> cuda_types::cuda::CUresult { + todo!() + } + + unsafe extern "system" fn get_unknown_buffer1( + ptr: *mut *mut std::ffi::c_void, + size: *mut usize, + ) -> () { + *ptr = UNKNOWN_BUFFER1.buffer.get() as *mut std::ffi::c_void; + *size = UNKNOWN_BUFFER1.len(); + } + + unsafe extern "system" fn get_unknown_buffer2( + ptr: *mut *mut std::ffi::c_void, + size: *mut usize, + ) -> () { + *ptr = UNKNOWN_BUFFER2.buffer.get() as *mut std::ffi::c_void; + *size = UNKNOWN_BUFFER2.len(); + } + + unsafe extern "system" fn context_local_storage_put( + cu_ctx: CUcontext, + key: *mut c_void, + value: *mut c_void, + dtor_cb: Option, + ) -> CUresult { + let _ctx = if cu_ctx.0 != ptr::null_mut() { + cu_ctx + } else { + let mut current_ctx: CUcontext = CUcontext(ptr::null_mut()); + context::get_current(&mut current_ctx)?; + current_ctx + }; + let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?; + ctx_obj.with_state_mut(|state: &mut context::ContextState| { + state.storage.insert( + key as usize, + context::StorageData { + value: value as usize, + reset_cb: dtor_cb, + handle: _ctx, + }, + ); + Ok(()) + })?; + Ok(()) + } + + unsafe extern "system" fn context_local_storage_delete( + cu_ctx: CUcontext, + key: *mut c_void, + ) -> CUresult { + let ctx_obj: &context::Context = FromCuda::from_cuda(&cu_ctx)?; + ctx_obj.with_state_mut(|state: &mut context::ContextState| { + state.storage.remove(&(key as usize)); + Ok(()) + })?; + Ok(()) + } + + unsafe extern "system" fn context_local_storage_get( + value: *mut *mut c_void, + cu_ctx: CUcontext, + key: *mut c_void, + ) -> CUresult { + let mut _ctx: CUcontext; + if cu_ctx.0 == ptr::null_mut() { + _ctx = context::get_current_context()?; + } else { + _ctx = cu_ctx + }; + let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?; + ctx_obj.with_state(|state: &context::ContextState| { + match state.storage.get(&(key as usize)) { + Some(data) => *value = data.value as *mut c_void, + None => return CUresult::ERROR_INVALID_HANDLE, + } + Ok(()) + })?; + Ok(()) + } + + unsafe extern "system" fn ctx_create_v2_bypass( + _pctx: *mut cuda_types::cuda::CUcontext, + _flags: ::std::os::raw::c_uint, + _dev: cuda_types::cuda::CUdevice, + ) -> cuda_types::cuda::CUresult { + todo!() + } + + unsafe extern "system" fn heap_alloc( + _heap_alloc_record_ptr: *mut *const std::ffi::c_void, + _arg2: usize, + _arg3: usize, + ) -> cuda_types::cuda::CUresult { + todo!() + } + + unsafe extern "system" fn heap_free( + _heap_alloc_record_ptr: *const std::ffi::c_void, + _arg2: *mut usize, + ) -> cuda_types::cuda::CUresult { + todo!() + } + + unsafe extern "system" fn device_get_attribute_ext( + _dev: cuda_types::cuda::CUdevice, + _attribute: std::ffi::c_uint, + _unknown: std::ffi::c_int, + _result: *mut [usize; 2], + ) -> cuda_types::cuda::CUresult { + todo!() + } + + unsafe extern "system" fn device_get_something( + _result: *mut std::ffi::c_uchar, + _dev: cuda_types::cuda::CUdevice, + ) -> cuda_types::cuda::CUresult { + todo!() + } + + unsafe extern "system" fn integrity_check( + version: u32, + unix_seconds: u64, + result: *mut [u64; 2], + ) -> cuda_types::cuda::CUresult { + let current_process = std::process::id(); + let current_thread = os::current_thread(); + + let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast(); + let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast(); + let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1]; + + let devices = get_device_hash_info()?; + let device_count = devices.len() as u32; + let get_device = |dev| devices[dev as usize]; + + let hash = ::dark_api::integrity_check( + version, + unix_seconds, + cuda_types::cuda::CUDA_VERSION, + current_process, + current_thread, + integrity_check_table, + cudart_table, + fn_address, + device_count, + get_device, + ); + *result = hash; + Ok(()) + } + + unsafe extern "system" fn context_check( + _ctx_in: cuda_types::cuda::CUcontext, + result1: *mut u32, + _result2: *mut *const std::ffi::c_void, + ) -> cuda_types::cuda::CUresult { + *result1 = 0; + CUresult::SUCCESS + } + + unsafe extern "system" fn check_fn3() -> u32 { + 0 + } +} + +fn get_device_hash_info() -> Result, CUerror> { + let mut device_count = 0; + device::get_count(&mut device_count)?; + + (0..device_count) + .map(|dev| { + let mut guid = CUuuid_st { bytes: [0; 16] }; + unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? }; + + let mut pci_domain = 0; + device::get_attribute( + &mut pci_domain, + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, + dev, + )?; + + let mut pci_bus = 0; + device::get_attribute( + &mut pci_bus, + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, + dev, + )?; + + let mut pci_device = 0; + device::get_attribute( + &mut pci_device, + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, + dev, + )?; + + Ok(::dark_api::DeviceHashinfo { + guid, + pci_domain, + pci_bus, + pci_device, + }) + }) + .collect() +} + +static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable = + ::dark_api::cuda::CudaDarkApiGlobalTable::new::(); + +pub(crate) fn get_export_table( + pp_export_table: &mut *const ::core::ffi::c_void, + p_export_table_id: &CUuuid, +) -> CUresult { + if let Some(table) = EXPORT_TABLE.get(p_export_table_id) { + *pp_export_table = table.start(); + cuda_types::cuda::CUresult::SUCCESS + } else { + cuda_types::cuda::CUresult::ERROR_INVALID_VALUE + } +} + +pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult { + *version = cuda_types::cuda::CUDA_VERSION as i32; + Ok(()) +} + +pub(crate) unsafe fn get_proc_address( + symbol: &CStr, + pfn: &mut *mut ::core::ffi::c_void, + cuda_version: ::core::ffi::c_int, + flags: cuda_types::cuda::cuuint64_t, +) -> CUresult { + get_proc_address_v2(symbol, pfn, cuda_version, flags, None) +} + +pub(crate) unsafe fn get_proc_address_v2( + symbol: &CStr, + pfn: &mut *mut ::core::ffi::c_void, + cuda_version: ::core::ffi::c_int, + flags: cuda_types::cuda::cuuint64_t, + symbol_status: Option<&mut cuda_types::cuda::CUdriverProcAddressQueryResult>, +) -> CUresult { + // This implementation is mostly the same as cuGetProcAddress_v2 in zluda_dump. We may want to factor out the duplication at some point. + fn raw_match(name: &[u8], flag: u64, version: i32) -> *mut ::core::ffi::c_void { + use crate::*; + include!("../../../zluda_bindgen/src/process_table.rs") + } + let fn_ptr = raw_match(symbol.to_bytes(), flags, cuda_version); + match fn_ptr as usize { + 0 => { + if let Some(symbol_status) = symbol_status { + *symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND; + } + *pfn = ptr::null_mut(); + CUresult::ERROR_NOT_FOUND + } + usize::MAX => { + if let Some(symbol_status) = symbol_status { + *symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT; + } + *pfn = ptr::null_mut(); + CUresult::ERROR_NOT_FOUND + } + _ => { + if let Some(symbol_status) = symbol_status { + *symbol_status = + cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SUCCESS; + } + *pfn = fn_ptr; + Ok(()) + } + } +} + +pub(crate) fn profiler_start() -> CUresult { + Ok(()) +} + +pub(crate) fn profiler_stop() -> CUresult { + Ok(()) +} diff --git a/zluda/src/impl/os_unix.rs b/zluda/src/impl/os_unix.rs index f02a9d7..0149b76 100644 --- a/zluda/src/impl/os_unix.rs +++ b/zluda/src/impl/os_unix.rs @@ -1,9 +1,9 @@ -// TODO: remove duplication with zluda_dump -#[link(name = "pthread")] -unsafe extern "C" { - fn pthread_self() -> std::os::unix::thread::RawPthread; -} - -pub(crate) fn current_thread() -> u32 { - (unsafe { pthread_self() }) as u32 -} +// TODO: remove duplication with zluda_dump +#[link(name = "pthread")] +unsafe extern "C" { + fn pthread_self() -> std::os::unix::thread::RawPthread; +} + +pub(crate) fn current_thread() -> u32 { + (unsafe { pthread_self() }) as u32 +} diff --git a/zluda/src/impl/os_win.rs b/zluda/src/impl/os_win.rs index 5c7459f..b9691b0 100644 --- a/zluda/src/impl/os_win.rs +++ b/zluda/src/impl/os_win.rs @@ -1,9 +1,9 @@ -// TODO: remove duplication with zluda_dump -#[link(name = "kernel32")] -unsafe extern "system" { - fn GetCurrentThreadId() -> u32; -} - -pub(crate) fn current_thread() -> u32 { - unsafe { GetCurrentThreadId() } -} +// TODO: remove duplication with zluda_dump +#[link(name = "kernel32")] +unsafe extern "system" { + fn GetCurrentThreadId() -> u32; +} + +pub(crate) fn current_thread() -> u32 { + unsafe { GetCurrentThreadId() } +} diff --git a/zluda_dump/src/dark_api.rs b/zluda_dump/src/dark_api.rs index 2a929f5..ea52149 100644 --- a/zluda_dump/src/dark_api.rs +++ b/zluda_dump/src/dark_api.rs @@ -1,124 +1,124 @@ -use crate::os; -use crate::{CudaFunctionName, ErrorEntry}; -use cuda_types::cuda::*; -use rustc_hash::FxHashMap; -use std::cell::RefMut; -use std::hash::Hash; -use std::{collections::hash_map, ffi::c_void, mem}; - -pub(crate) struct DarkApiState2 { - // Key is Box, (*const *const c_void, Vec<*const c_void>)>, -} - -unsafe impl Send for DarkApiState2 {} -unsafe impl Sync for DarkApiState2 {} - -impl DarkApiState2 { - pub(crate) fn new() -> Self { - DarkApiState2 { - overrides: FxHashMap::default(), - } - } - - pub(crate) fn override_export_table( - &mut self, - known_exports: &::dark_api::cuda::CudaDarkApiGlobalTable, - original_export_table: *const *const c_void, - guid: &CUuuid_st, - ) -> (*const *const c_void, Option) { - let entry = match self.overrides.entry(Box::new(CUuuidWrapper(*guid))) { - hash_map::Entry::Occupied(entry) => { - let (_, override_table) = entry.get(); - return (override_table.as_ptr(), None); - } - hash_map::Entry::Vacant(entry) => entry, - }; - let mut error = None; - let byte_size: usize = unsafe { *(original_export_table.cast::()) }; - // Some export tables don't start with a byte count, but directly with a - // pointer, and are instead terminated by 0 or MAX - let export_functions_start_idx; - let export_functions_size; - if byte_size > 0x10000 { - export_functions_start_idx = 0; - let mut i = 0; - loop { - let current_ptr = unsafe { original_export_table.add(i) }; - let current_ptr_numeric = unsafe { *current_ptr } as usize; - if current_ptr_numeric == 0usize || current_ptr_numeric == usize::MAX { - export_functions_size = i; - break; - } - i += 1; - } - } else { - export_functions_start_idx = 1; - export_functions_size = byte_size / mem::size_of::(); - } - let our_functions = known_exports.get(guid); - if let Some(ref our_functions) = our_functions { - if our_functions.len() != export_functions_size { - error = Some(ErrorEntry::UnexpectedExportTableSize { - expected: our_functions.len(), - computed: export_functions_size, - }); - } - } - let mut override_table = - unsafe { std::slice::from_raw_parts(original_export_table, export_functions_size) } - .to_vec(); - for i in export_functions_start_idx..export_functions_size { - let current_fn = (|| { - if let Some(ref our_functions) = our_functions { - if let Some(fn_) = our_functions.get_fn(i) { - return fn_; - } - } - os::get_thunk( - override_table[i], - Self::report_unknown_export_table_call, - std::ptr::from_ref(entry.key().as_ref()).cast(), - i, - ) - })(); - override_table[i] = current_fn; - } - ( - entry - .insert((original_export_table, override_table)) - .1 - .as_ptr(), - error, - ) - } - - unsafe extern "system" fn report_unknown_export_table_call(guid: &CUuuid, index: usize) { - let global_state = crate::GLOBAL_STATE2.lock(); - let global_state_ref_cell = &*global_state; - let mut global_state_ref_mut = global_state_ref_cell.borrow_mut(); - let global_state = &mut *global_state_ref_mut; - let log_guard = crate::OuterCallGuard { - writer: &mut global_state.log_writer, - log_root: &global_state.log_stack, - }; - { - let mut logger = RefMut::map(global_state.log_stack.borrow_mut(), |log_stack| { - log_stack.enter() - }); - logger.name = CudaFunctionName::Dark { guid: *guid, index }; - }; - drop(log_guard); - } -} - -#[derive(Eq, PartialEq)] -#[repr(transparent)] -pub(crate) struct CUuuidWrapper(pub CUuuid); - -impl Hash for CUuuidWrapper { - fn hash(&self, state: &mut H) { - self.0.bytes.hash(state); - } -} +use crate::os; +use crate::{CudaFunctionName, ErrorEntry}; +use cuda_types::cuda::*; +use rustc_hash::FxHashMap; +use std::cell::RefMut; +use std::hash::Hash; +use std::{collections::hash_map, ffi::c_void, mem}; + +pub(crate) struct DarkApiState2 { + // Key is Box, (*const *const c_void, Vec<*const c_void>)>, +} + +unsafe impl Send for DarkApiState2 {} +unsafe impl Sync for DarkApiState2 {} + +impl DarkApiState2 { + pub(crate) fn new() -> Self { + DarkApiState2 { + overrides: FxHashMap::default(), + } + } + + pub(crate) fn override_export_table( + &mut self, + known_exports: &::dark_api::cuda::CudaDarkApiGlobalTable, + original_export_table: *const *const c_void, + guid: &CUuuid_st, + ) -> (*const *const c_void, Option) { + let entry = match self.overrides.entry(Box::new(CUuuidWrapper(*guid))) { + hash_map::Entry::Occupied(entry) => { + let (_, override_table) = entry.get(); + return (override_table.as_ptr(), None); + } + hash_map::Entry::Vacant(entry) => entry, + }; + let mut error = None; + let byte_size: usize = unsafe { *(original_export_table.cast::()) }; + // Some export tables don't start with a byte count, but directly with a + // pointer, and are instead terminated by 0 or MAX + let export_functions_start_idx; + let export_functions_size; + if byte_size > 0x10000 { + export_functions_start_idx = 0; + let mut i = 0; + loop { + let current_ptr = unsafe { original_export_table.add(i) }; + let current_ptr_numeric = unsafe { *current_ptr } as usize; + if current_ptr_numeric == 0usize || current_ptr_numeric == usize::MAX { + export_functions_size = i; + break; + } + i += 1; + } + } else { + export_functions_start_idx = 1; + export_functions_size = byte_size / mem::size_of::(); + } + let our_functions = known_exports.get(guid); + if let Some(ref our_functions) = our_functions { + if our_functions.len() != export_functions_size { + error = Some(ErrorEntry::UnexpectedExportTableSize { + expected: our_functions.len(), + computed: export_functions_size, + }); + } + } + let mut override_table = + unsafe { std::slice::from_raw_parts(original_export_table, export_functions_size) } + .to_vec(); + for i in export_functions_start_idx..export_functions_size { + let current_fn = (|| { + if let Some(ref our_functions) = our_functions { + if let Some(fn_) = our_functions.get_fn(i) { + return fn_; + } + } + os::get_thunk( + override_table[i], + Self::report_unknown_export_table_call, + std::ptr::from_ref(entry.key().as_ref()).cast(), + i, + ) + })(); + override_table[i] = current_fn; + } + ( + entry + .insert((original_export_table, override_table)) + .1 + .as_ptr(), + error, + ) + } + + unsafe extern "system" fn report_unknown_export_table_call(guid: &CUuuid, index: usize) { + let global_state = crate::GLOBAL_STATE2.lock(); + let global_state_ref_cell = &*global_state; + let mut global_state_ref_mut = global_state_ref_cell.borrow_mut(); + let global_state = &mut *global_state_ref_mut; + let log_guard = crate::OuterCallGuard { + writer: &mut global_state.log_writer, + log_root: &global_state.log_stack, + }; + { + let mut logger = RefMut::map(global_state.log_stack.borrow_mut(), |log_stack| { + log_stack.enter() + }); + logger.name = CudaFunctionName::Dark { guid: *guid, index }; + }; + drop(log_guard); + } +} + +#[derive(Eq, PartialEq)] +#[repr(transparent)] +pub(crate) struct CUuuidWrapper(pub CUuuid); + +impl Hash for CUuuidWrapper { + fn hash(&self, state: &mut H) { + self.0.bytes.hash(state); + } +} diff --git a/zluda_dump/src/log.rs b/zluda_dump/src/log.rs index 0bc5113..91e6b2c 100644 --- a/zluda_dump/src/log.rs +++ b/zluda_dump/src/log.rs @@ -1,668 +1,668 @@ -use super::Settings; -use crate::FnCallLog; -use crate::LogEntry; -use cuda_types::cuda::*; -use format::CudaDisplay; -use std::error::Error; -use std::ffi::c_void; -use std::ffi::NulError; -use std::fmt::Display; -use std::fs::File; -use std::io; -use std::io::Stderr; -use std::io::Write; -use std::path::PathBuf; -use std::str::Utf8Error; - -const LOG_PREFIX: &[u8] = b"[ZLUDA_DUMP] "; - -pub(crate) struct Writer { - // Fallible emitter is an optional emitter to the file system, we might lack - // file permissions or be out of disk space - fallible_emitter: Option>, - // This is emitter that "always works" (and if it does not, then we don't - // care). In addition of normal logs it emits errors from fallible emitter - infallible_emitter: Box, - // This object could be recreated every time, but it's slightly better for performance to - // reuse the allocations by keeping the object in globals - write_buffer: WriteBuffer, -} - -impl Writer { - pub(crate) fn new() -> Self { - let debug_emitter = os::new_debug_logger(); - Self { - infallible_emitter: debug_emitter, - fallible_emitter: None, - write_buffer: WriteBuffer::new(), - } - } - - pub(crate) fn late_init(&mut self, settings: &Settings) -> Result<(), ErrorEntry> { - self.fallible_emitter = settings - .dump_dir - .as_ref() - .map(|path| { - Ok::<_, std::io::Error>(Box::new(File::create(path.to_path_buf().join("log.txt"))?) - as Box) - }) - .transpose() - .map_err(ErrorEntry::IoError)?; - self.write_buffer - .init(&self.fallible_emitter, &self.infallible_emitter); - Ok(()) - } - - pub(crate) fn write_and_flush(&mut self, log_root: &mut FnCallLog) { - self.write_all_from_depth(0, log_root); - self.write_buffer.finish(); - let error_from_writing_to_fallible_emitter = match self.fallible_emitter { - Some(ref mut emitter) => self.write_buffer.send_to_and_flush(emitter), - None => Ok(()), - }; - if let Err(e) = error_from_writing_to_fallible_emitter { - self.hack_squeeze_in_additional_error(ErrorEntry::IoError(e)) - } - self.write_buffer - .send_to_and_flush(&mut self.infallible_emitter) - .ok(); - self.write_buffer.reset(); - log_root.reset(); - } - - fn write_all_from_depth(&mut self, depth: usize, fn_call: &FnCallLog) { - self.write_call(depth, fn_call); - for sub in fn_call.subcalls.iter() { - match sub { - LogEntry::FnCall(fn_call) => self.write_all_from_depth(depth + 1, fn_call), - LogEntry::Error(err) => self.write_error(depth + 1, err), - } - } - } - - fn write_call(&mut self, depth: usize, call: &FnCallLog) { - self.write_buffer.start_line(depth); - write!(self.write_buffer, "{}", call.name).ok(); - match call.args { - Some(ref args) => { - self.write_buffer.write_all(args).ok(); - } - None => { - self.write_buffer.write_all(b"(...)").ok(); - } - } - self.write_buffer.write_all(b" -> ").ok(); - if let Some(ref result) = call.output { - self.write_buffer.write_all(result).ok(); - } else { - self.write_buffer.write_all(b"UNKNOWN").ok(); - }; - self.write_buffer.end_line(); - } - - fn write_error(&mut self, depth: usize, error: &ErrorEntry) { - self.write_buffer.start_line(depth); - write!(self.write_buffer, "{}", error).ok(); - self.write_buffer.end_line(); - } - - fn hack_squeeze_in_additional_error(&mut self, entry: ErrorEntry) { - self.write_buffer.undo_finish(); - write!(self.write_buffer, " {}", entry).ok(); - self.write_buffer.end_line(); - self.write_buffer.finish(); - } -} - -// When writing out to the emitter (file, WinAPI, whatever else) instead of -// writing piece-by-piece it's better to first concatenate everything in memory -// then write out from memory to the slow emitter only once. -// Additionally we might have an unprefixed and prefixed buffer, this struct -// handles this detail -struct WriteBuffer { - prefixed_buffer: Option>, - unprefixed_buffer: Option>, -} - -impl WriteBuffer { - fn new() -> Self { - WriteBuffer { - prefixed_buffer: None, - unprefixed_buffer: None, - } - } - - fn init( - &mut self, - fallible_emitter: &Option>, - infallible_emitter: &Box, - ) { - if infallible_emitter.should_prefix() { - self.prefixed_buffer = Some(Vec::new()); - } else { - self.unprefixed_buffer = Some(Vec::new()); - } - if let Some(emitter) = fallible_emitter { - if emitter.should_prefix() { - self.prefixed_buffer = Some(Vec::new()); - } else { - self.unprefixed_buffer = Some(Vec::new()); - } - } - } - - fn all_buffers(&mut self) -> impl Iterator> { - self.prefixed_buffer - .as_mut() - .into_iter() - .chain(self.unprefixed_buffer.as_mut().into_iter()) - } - - fn start_line(&mut self, depth: usize) { - if let Some(buffer) = &mut self.prefixed_buffer { - buffer.extend_from_slice(LOG_PREFIX); - } - if depth == 0 { - return; - } - for buffer in self.all_buffers() { - buffer.extend(std::iter::repeat_n(b' ', depth * 4)); - } - } - - fn end_line(&mut self) { - for buffer in self.all_buffers() { - buffer.push(b'\n'); - } - } - - fn finish(&mut self) { - for buffer in self.all_buffers() { - buffer.push(b'\0'); - } - } - - fn undo_finish(&mut self) { - for buffer in self.all_buffers() { - buffer.truncate(buffer.len() - 1); - } - } - - fn send_to_and_flush( - &self, - log_emitter: &mut Box, - ) -> Result<(), io::Error> { - if log_emitter.should_prefix() { - log_emitter.write_zero_aware( - &*self - .prefixed_buffer - .as_ref() - .unwrap_or_else(|| unreachable!()), - )?; - } else { - log_emitter.write_zero_aware( - &*self - .unprefixed_buffer - .as_ref() - .unwrap_or_else(|| unreachable!()), - )?; - } - log_emitter.flush() - } - - fn reset(&mut self) { - for buffer in self.all_buffers() { - unsafe { buffer.set_len(0) }; - } - } -} - -impl Write for WriteBuffer { - fn write(&mut self, buf: &[u8]) -> io::Result { - if let Some(buffer) = &mut self.prefixed_buffer { - buffer.extend_from_slice(buf); - } - if let Some(buffer) = &mut self.unprefixed_buffer { - buffer.extend_from_slice(buf); - } - Ok(buf.len()) - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - -#[derive(Clone)] -pub(crate) enum CudaFunctionName { - Normal(&'static str), - Dark { guid: CUuuid, index: usize }, -} - -impl Display for CudaFunctionName { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - CudaFunctionName::Normal(fn_) => f.write_str(fn_), - CudaFunctionName::Dark { guid, index } => { - match ::dark_api::cuda::guid_to_name(guid, *index) { - Some((name, fn_)) => match fn_ { - Some(fn_) => write!(f, "{{{name}}}::{fn_}"), - None => write!(f, "{{{name}}}::{index}"), - }, - None => { - let mut temp = Vec::new(); - format::CudaDisplay::write(guid, "", 0, &mut temp) - .map_err(|_| std::fmt::Error::default())?; - let temp = String::from_utf8_lossy(&*temp); - write!(f, "{temp}::{index}") - } - } - } - } - } -} - -pub(crate) enum ErrorEntry { - IoError(io::Error), - CreatedDumpDirectory(PathBuf), - ErrorBox(Box), - UnsupportedModule { - module: CUmodule, - raw_image: *const c_void, - kind: &'static str, - }, - FunctionNotFound(CudaFunctionName), - MalformedModulePath(Utf8Error), - NonUtf8ModuleText(Utf8Error), - NulInsideModuleText(NulError), - ModuleParsingError(String), - Lz4DecompressionFailure, - ZstdDecompressionFailure(usize), - UnexpectedArgument { - arg_name: &'static str, - expected: Vec, - observed: UInt, - }, - UnexpectedBinaryField { - field_name: &'static str, - expected: Vec, - observed: UInt, - }, - InvalidEnvVar { - var: &'static str, - pattern: &'static str, - value: String, - }, - UnexpectedExportTableSize { - expected: usize, - computed: usize, - }, - IntegrityCheck { - original: [u64; 2], - overriden: [u64; 2], - }, - NullPointer(&'static str), - UnknownLibrary(CUlibrary), -} - -unsafe impl Send for ErrorEntry {} -unsafe impl Sync for ErrorEntry {} - -impl From for ErrorEntry { - fn from(e: dark_api::fatbin::ParseError) -> Self { - match e { - dark_api::fatbin::ParseError::NullPointer(s) => ErrorEntry::NullPointer(s), - dark_api::fatbin::ParseError::UnexpectedBinaryField { - field_name, - observed, - expected, - } => ErrorEntry::UnexpectedBinaryField { - field_name, - observed: UInt::from(observed), - expected: expected.into_iter().map(UInt::from).collect(), - }, - } - } -} - -impl From for ErrorEntry { - fn from(e: dark_api::fatbin::FatbinError) -> Self { - match e { - dark_api::fatbin::FatbinError::ParseFailure(parse_error) => parse_error.into(), - dark_api::fatbin::FatbinError::Lz4DecompressionFailure => { - ErrorEntry::Lz4DecompressionFailure - } - dark_api::fatbin::FatbinError::ZstdDecompressionFailure(c) => { - ErrorEntry::ZstdDecompressionFailure(c) - } - } - } -} - -impl Display for ErrorEntry { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ErrorEntry::IoError(e) => e.fmt(f), - ErrorEntry::CreatedDumpDirectory(dir) => { - write!( - f, - "Created dump directory {} ", - dir.as_os_str().to_string_lossy() - ) - } - ErrorEntry::ErrorBox(e) => e.fmt(f), - ErrorEntry::UnsupportedModule { - module, - raw_image, - kind, - } => { - write!( - f, - "Unsupported {} module {:?} loaded from module image {:?}", - kind, module, raw_image - ) - } - ErrorEntry::MalformedModulePath(e) => e.fmt(f), - ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f), - ErrorEntry::ModuleParsingError(file_name) => { - write!( - f, - "Error parsing module, log has been written to {}", - file_name - ) - } - ErrorEntry::NulInsideModuleText(e) => e.fmt(f), - ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"), - ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)), - ErrorEntry::UnexpectedBinaryField { - field_name, - expected, - observed, - } => write!( - f, - "Unexpected field {}. Expected one of: [{}], observed: {}", - field_name, - expected - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(", "), - observed - ), - ErrorEntry::UnexpectedArgument { - arg_name, - expected, - observed, - } => write!( - f, - "Unexpected argument {}. Expected one of: {{{}}}, observed: {}", - arg_name, - expected - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(", "), - observed - ), - ErrorEntry::InvalidEnvVar { - var, - pattern, - value, - } => write!( - f, - "Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}" - ), - ErrorEntry::FunctionNotFound(cuda_function_name) => write!( - f, - "No function {cuda_function_name} in the underlying library" - ), - ErrorEntry::UnexpectedExportTableSize { expected, computed } => { - write!(f, "Table length mismatch. Expected: {expected}, got: {computed}") - } - ErrorEntry::IntegrityCheck { original, overriden } => { - write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}") - } - ErrorEntry::NullPointer(type_) => { - write!(f, "Null pointer of type {type_} encountered") - } - ErrorEntry::UnknownLibrary(culibrary) => { - write!(f, "Unknown library: ")?; - let mut temp_buffer = Vec::new(); - CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok(); - f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) }) - } - } - } -} - -#[derive(Clone, Copy)] -pub(crate) enum UInt { - U16(u16), - U32(u32), - USize(usize), -} - -impl From for UInt { - fn from(value: u16) -> Self { - UInt::U16(value) - } -} - -impl From for UInt { - fn from(value: u32) -> Self { - UInt::U32(value) - } -} - -impl From for UInt { - fn from(value: usize) -> Self { - UInt::USize(value) - } -} - -impl Display for UInt { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - UInt::U16(x) => write!(f, "{:#x}", x), - UInt::U32(x) => write!(f, "{:#x}", x), - UInt::USize(x) => write!(f, "{:#x}", x), - } - } -} - -// Some of our writers want to have trailing zero (WinAPI debug logger) and some -// don't (everything else), this trait encapsulates that logic -pub(crate) trait WriteTrailingZeroAware { - fn write_zero_aware(&mut self, buf: &[u8]) -> std::io::Result<()>; - fn flush(&mut self) -> std::io::Result<()>; - fn should_prefix(&self) -> bool; -} - -impl WriteTrailingZeroAware for File { - fn write_zero_aware(&mut self, buf: &[u8]) -> std::io::Result<()> { - ::write_all(self, buf.split_last().unwrap().1) - } - - fn flush(&mut self) -> std::io::Result<()> { - ::flush(self) - } - - fn should_prefix(&self) -> bool { - false - } -} - -impl WriteTrailingZeroAware for Stderr { - fn write_zero_aware(&mut self, buf: &[u8]) -> std::io::Result<()> { - ::write_all(self, buf.split_last().unwrap().1) - } - - fn flush(&mut self) -> std::io::Result<()> { - ::flush(self) - } - - fn should_prefix(&self) -> bool { - true - } -} - -#[cfg(windows)] -mod os { - use super::WriteTrailingZeroAware; - use std::{os::windows::prelude::AsRawHandle, ptr}; - use winapi::um::debugapi::OutputDebugStringA; - - struct OutputDebugString {} - - impl WriteTrailingZeroAware for OutputDebugString { - fn write_zero_aware(&mut self, buf: &[u8]) -> std::io::Result<()> { - unsafe { OutputDebugStringA(buf.as_ptr() as *const _) }; - Ok(()) - } - - fn flush(&mut self) -> std::io::Result<()> { - Ok(()) - } - - fn should_prefix(&self) -> bool { - true - } - } - - pub(crate) fn new_debug_logger() -> Box { - let stderr = std::io::stderr(); - let log_to_stderr = stderr.as_raw_handle() != ptr::null_mut(); - if log_to_stderr { - Box::new(stderr) - } else { - Box::new(OutputDebugString {}) - } - } -} - -#[cfg(not(windows))] -mod os { - use super::WriteTrailingZeroAware; - - pub(crate) fn new_debug_logger() -> Box { - Box::new(std::io::stderr()) - } -} - -#[cfg(test)] -mod tests { - use super::{ErrorEntry, FnCallLog, WriteTrailingZeroAware}; - use crate::{ - log::{CudaFunctionName, WriteBuffer}, - FnCallLogStack, OuterCallGuard, - }; - use std::{ - cell::RefCell, - io, str, - sync::{Arc, Mutex}, - }; - - struct FailOnNthWrite { - fail_on: usize, - counter: usize, - } - - impl WriteTrailingZeroAware for FailOnNthWrite { - fn write_zero_aware(&mut self, _: &[u8]) -> std::io::Result<()> { - self.counter += 1; - if self.counter >= self.fail_on { - Err(io::Error::from_raw_os_error(4)) - } else { - Ok(()) - } - } - - fn flush(&mut self) -> std::io::Result<()> { - panic!() - } - - fn should_prefix(&self) -> bool { - false - } - } - - // Custom type to not trigger trait coherence rules - #[derive(Clone)] - struct ArcVec(Arc>>); - - impl WriteTrailingZeroAware for ArcVec { - fn write_zero_aware(&mut self, buf: &[u8]) -> std::io::Result<()> { - let mut vec = self.0.lock().unwrap(); - vec.extend_from_slice(buf.split_last().unwrap().1); - Ok(()) - } - - fn flush(&mut self) -> std::io::Result<()> { - Ok(()) - } - - fn should_prefix(&self) -> bool { - false - } - } - - #[test] - // TODO: fix this, it should use drop guard for testing. - // Previously FnCallLog would implement Drop and write to the log - fn error_in_fallible_emitter_is_handled_gracefully() { - let result = ArcVec(Arc::new(Mutex::new(Vec::::new()))); - let infallible_emitter = Box::new(result.clone()) as Box; - let fallible_emitter = Some(Box::new(FailOnNthWrite { - fail_on: 1, - counter: 0, - }) as Box); - let mut write_buffer = WriteBuffer::new(); - write_buffer.unprefixed_buffer = Some(Vec::new()); - let mut writer = super::Writer { - fallible_emitter, - infallible_emitter, - write_buffer, - }; - let func_logger = FnCallLog { - name: CudaFunctionName::Normal("cuInit"), - args: None, - output: None, - subcalls: Vec::new(), - }; - let log_root = FnCallLogStack { - depth: 1, - log_root: func_logger, - }; - let log_root = RefCell::new(log_root); - let drop_guard = OuterCallGuard { - writer: &mut writer, - log_root: &log_root, - }; - - { - log_root - .borrow_mut() - .log_root - .log(ErrorEntry::IoError(io::Error::from_raw_os_error(1))); - log_root - .borrow_mut() - .log_root - .log(ErrorEntry::IoError(io::Error::from_raw_os_error(2))); - log_root - .borrow_mut() - .log_root - .log(ErrorEntry::IoError(io::Error::from_raw_os_error(3))); - } - drop(drop_guard); - - let result = result.0.lock().unwrap(); - let result_str = str::from_utf8(&*result).unwrap(); - let result_lines = result_str.lines().collect::>(); - assert_eq!(result_lines.len(), 5); - assert_eq!(result_lines[0], "cuInit(...) -> UNKNOWN"); - assert!(result_lines[1].starts_with(" ")); - assert!(result_lines[2].starts_with(" ")); - assert!(result_lines[3].starts_with(" ")); - assert!(result_lines[4].starts_with(" ")); - } -} +use super::Settings; +use crate::FnCallLog; +use crate::LogEntry; +use cuda_types::cuda::*; +use format::CudaDisplay; +use std::error::Error; +use std::ffi::c_void; +use std::ffi::NulError; +use std::fmt::Display; +use std::fs::File; +use std::io; +use std::io::Stderr; +use std::io::Write; +use std::path::PathBuf; +use std::str::Utf8Error; + +const LOG_PREFIX: &[u8] = b"[ZLUDA_DUMP] "; + +pub(crate) struct Writer { + // Fallible emitter is an optional emitter to the file system, we might lack + // file permissions or be out of disk space + fallible_emitter: Option>, + // This is emitter that "always works" (and if it does not, then we don't + // care). In addition of normal logs it emits errors from fallible emitter + infallible_emitter: Box, + // This object could be recreated every time, but it's slightly better for performance to + // reuse the allocations by keeping the object in globals + write_buffer: WriteBuffer, +} + +impl Writer { + pub(crate) fn new() -> Self { + let debug_emitter = os::new_debug_logger(); + Self { + infallible_emitter: debug_emitter, + fallible_emitter: None, + write_buffer: WriteBuffer::new(), + } + } + + pub(crate) fn late_init(&mut self, settings: &Settings) -> Result<(), ErrorEntry> { + self.fallible_emitter = settings + .dump_dir + .as_ref() + .map(|path| { + Ok::<_, std::io::Error>(Box::new(File::create(path.to_path_buf().join("log.txt"))?) + as Box) + }) + .transpose() + .map_err(ErrorEntry::IoError)?; + self.write_buffer + .init(&self.fallible_emitter, &self.infallible_emitter); + Ok(()) + } + + pub(crate) fn write_and_flush(&mut self, log_root: &mut FnCallLog) { + self.write_all_from_depth(0, log_root); + self.write_buffer.finish(); + let error_from_writing_to_fallible_emitter = match self.fallible_emitter { + Some(ref mut emitter) => self.write_buffer.send_to_and_flush(emitter), + None => Ok(()), + }; + if let Err(e) = error_from_writing_to_fallible_emitter { + self.hack_squeeze_in_additional_error(ErrorEntry::IoError(e)) + } + self.write_buffer + .send_to_and_flush(&mut self.infallible_emitter) + .ok(); + self.write_buffer.reset(); + log_root.reset(); + } + + fn write_all_from_depth(&mut self, depth: usize, fn_call: &FnCallLog) { + self.write_call(depth, fn_call); + for sub in fn_call.subcalls.iter() { + match sub { + LogEntry::FnCall(fn_call) => self.write_all_from_depth(depth + 1, fn_call), + LogEntry::Error(err) => self.write_error(depth + 1, err), + } + } + } + + fn write_call(&mut self, depth: usize, call: &FnCallLog) { + self.write_buffer.start_line(depth); + write!(self.write_buffer, "{}", call.name).ok(); + match call.args { + Some(ref args) => { + self.write_buffer.write_all(args).ok(); + } + None => { + self.write_buffer.write_all(b"(...)").ok(); + } + } + self.write_buffer.write_all(b" -> ").ok(); + if let Some(ref result) = call.output { + self.write_buffer.write_all(result).ok(); + } else { + self.write_buffer.write_all(b"UNKNOWN").ok(); + }; + self.write_buffer.end_line(); + } + + fn write_error(&mut self, depth: usize, error: &ErrorEntry) { + self.write_buffer.start_line(depth); + write!(self.write_buffer, "{}", error).ok(); + self.write_buffer.end_line(); + } + + fn hack_squeeze_in_additional_error(&mut self, entry: ErrorEntry) { + self.write_buffer.undo_finish(); + write!(self.write_buffer, " {}", entry).ok(); + self.write_buffer.end_line(); + self.write_buffer.finish(); + } +} + +// When writing out to the emitter (file, WinAPI, whatever else) instead of +// writing piece-by-piece it's better to first concatenate everything in memory +// then write out from memory to the slow emitter only once. +// Additionally we might have an unprefixed and prefixed buffer, this struct +// handles this detail +struct WriteBuffer { + prefixed_buffer: Option>, + unprefixed_buffer: Option>, +} + +impl WriteBuffer { + fn new() -> Self { + WriteBuffer { + prefixed_buffer: None, + unprefixed_buffer: None, + } + } + + fn init( + &mut self, + fallible_emitter: &Option>, + infallible_emitter: &Box, + ) { + if infallible_emitter.should_prefix() { + self.prefixed_buffer = Some(Vec::new()); + } else { + self.unprefixed_buffer = Some(Vec::new()); + } + if let Some(emitter) = fallible_emitter { + if emitter.should_prefix() { + self.prefixed_buffer = Some(Vec::new()); + } else { + self.unprefixed_buffer = Some(Vec::new()); + } + } + } + + fn all_buffers(&mut self) -> impl Iterator> { + self.prefixed_buffer + .as_mut() + .into_iter() + .chain(self.unprefixed_buffer.as_mut().into_iter()) + } + + fn start_line(&mut self, depth: usize) { + if let Some(buffer) = &mut self.prefixed_buffer { + buffer.extend_from_slice(LOG_PREFIX); + } + if depth == 0 { + return; + } + for buffer in self.all_buffers() { + buffer.extend(std::iter::repeat_n(b' ', depth * 4)); + } + } + + fn end_line(&mut self) { + for buffer in self.all_buffers() { + buffer.push(b'\n'); + } + } + + fn finish(&mut self) { + for buffer in self.all_buffers() { + buffer.push(b'\0'); + } + } + + fn undo_finish(&mut self) { + for buffer in self.all_buffers() { + buffer.truncate(buffer.len() - 1); + } + } + + fn send_to_and_flush( + &self, + log_emitter: &mut Box, + ) -> Result<(), io::Error> { + if log_emitter.should_prefix() { + log_emitter.write_zero_aware( + &*self + .prefixed_buffer + .as_ref() + .unwrap_or_else(|| unreachable!()), + )?; + } else { + log_emitter.write_zero_aware( + &*self + .unprefixed_buffer + .as_ref() + .unwrap_or_else(|| unreachable!()), + )?; + } + log_emitter.flush() + } + + fn reset(&mut self) { + for buffer in self.all_buffers() { + unsafe { buffer.set_len(0) }; + } + } +} + +impl Write for WriteBuffer { + fn write(&mut self, buf: &[u8]) -> io::Result { + if let Some(buffer) = &mut self.prefixed_buffer { + buffer.extend_from_slice(buf); + } + if let Some(buffer) = &mut self.unprefixed_buffer { + buffer.extend_from_slice(buf); + } + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +#[derive(Clone)] +pub(crate) enum CudaFunctionName { + Normal(&'static str), + Dark { guid: CUuuid, index: usize }, +} + +impl Display for CudaFunctionName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CudaFunctionName::Normal(fn_) => f.write_str(fn_), + CudaFunctionName::Dark { guid, index } => { + match ::dark_api::cuda::guid_to_name(guid, *index) { + Some((name, fn_)) => match fn_ { + Some(fn_) => write!(f, "{{{name}}}::{fn_}"), + None => write!(f, "{{{name}}}::{index}"), + }, + None => { + let mut temp = Vec::new(); + format::CudaDisplay::write(guid, "", 0, &mut temp) + .map_err(|_| std::fmt::Error::default())?; + let temp = String::from_utf8_lossy(&*temp); + write!(f, "{temp}::{index}") + } + } + } + } + } +} + +pub(crate) enum ErrorEntry { + IoError(io::Error), + CreatedDumpDirectory(PathBuf), + ErrorBox(Box), + UnsupportedModule { + module: CUmodule, + raw_image: *const c_void, + kind: &'static str, + }, + FunctionNotFound(CudaFunctionName), + MalformedModulePath(Utf8Error), + NonUtf8ModuleText(Utf8Error), + NulInsideModuleText(NulError), + ModuleParsingError(String), + Lz4DecompressionFailure, + ZstdDecompressionFailure(usize), + UnexpectedArgument { + arg_name: &'static str, + expected: Vec, + observed: UInt, + }, + UnexpectedBinaryField { + field_name: &'static str, + expected: Vec, + observed: UInt, + }, + InvalidEnvVar { + var: &'static str, + pattern: &'static str, + value: String, + }, + UnexpectedExportTableSize { + expected: usize, + computed: usize, + }, + IntegrityCheck { + original: [u64; 2], + overriden: [u64; 2], + }, + NullPointer(&'static str), + UnknownLibrary(CUlibrary), +} + +unsafe impl Send for ErrorEntry {} +unsafe impl Sync for ErrorEntry {} + +impl From for ErrorEntry { + fn from(e: dark_api::fatbin::ParseError) -> Self { + match e { + dark_api::fatbin::ParseError::NullPointer(s) => ErrorEntry::NullPointer(s), + dark_api::fatbin::ParseError::UnexpectedBinaryField { + field_name, + observed, + expected, + } => ErrorEntry::UnexpectedBinaryField { + field_name, + observed: UInt::from(observed), + expected: expected.into_iter().map(UInt::from).collect(), + }, + } + } +} + +impl From for ErrorEntry { + fn from(e: dark_api::fatbin::FatbinError) -> Self { + match e { + dark_api::fatbin::FatbinError::ParseFailure(parse_error) => parse_error.into(), + dark_api::fatbin::FatbinError::Lz4DecompressionFailure => { + ErrorEntry::Lz4DecompressionFailure + } + dark_api::fatbin::FatbinError::ZstdDecompressionFailure(c) => { + ErrorEntry::ZstdDecompressionFailure(c) + } + } + } +} + +impl Display for ErrorEntry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ErrorEntry::IoError(e) => e.fmt(f), + ErrorEntry::CreatedDumpDirectory(dir) => { + write!( + f, + "Created dump directory {} ", + dir.as_os_str().to_string_lossy() + ) + } + ErrorEntry::ErrorBox(e) => e.fmt(f), + ErrorEntry::UnsupportedModule { + module, + raw_image, + kind, + } => { + write!( + f, + "Unsupported {} module {:?} loaded from module image {:?}", + kind, module, raw_image + ) + } + ErrorEntry::MalformedModulePath(e) => e.fmt(f), + ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f), + ErrorEntry::ModuleParsingError(file_name) => { + write!( + f, + "Error parsing module, log has been written to {}", + file_name + ) + } + ErrorEntry::NulInsideModuleText(e) => e.fmt(f), + ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"), + ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)), + ErrorEntry::UnexpectedBinaryField { + field_name, + expected, + observed, + } => write!( + f, + "Unexpected field {}. Expected one of: [{}], observed: {}", + field_name, + expected + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "), + observed + ), + ErrorEntry::UnexpectedArgument { + arg_name, + expected, + observed, + } => write!( + f, + "Unexpected argument {}. Expected one of: {{{}}}, observed: {}", + arg_name, + expected + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "), + observed + ), + ErrorEntry::InvalidEnvVar { + var, + pattern, + value, + } => write!( + f, + "Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}" + ), + ErrorEntry::FunctionNotFound(cuda_function_name) => write!( + f, + "No function {cuda_function_name} in the underlying library" + ), + ErrorEntry::UnexpectedExportTableSize { expected, computed } => { + write!(f, "Table length mismatch. Expected: {expected}, got: {computed}") + } + ErrorEntry::IntegrityCheck { original, overriden } => { + write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}") + } + ErrorEntry::NullPointer(type_) => { + write!(f, "Null pointer of type {type_} encountered") + } + ErrorEntry::UnknownLibrary(culibrary) => { + write!(f, "Unknown library: ")?; + let mut temp_buffer = Vec::new(); + CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok(); + f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) }) + } + } + } +} + +#[derive(Clone, Copy)] +pub(crate) enum UInt { + U16(u16), + U32(u32), + USize(usize), +} + +impl From for UInt { + fn from(value: u16) -> Self { + UInt::U16(value) + } +} + +impl From for UInt { + fn from(value: u32) -> Self { + UInt::U32(value) + } +} + +impl From for UInt { + fn from(value: usize) -> Self { + UInt::USize(value) + } +} + +impl Display for UInt { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UInt::U16(x) => write!(f, "{:#x}", x), + UInt::U32(x) => write!(f, "{:#x}", x), + UInt::USize(x) => write!(f, "{:#x}", x), + } + } +} + +// Some of our writers want to have trailing zero (WinAPI debug logger) and some +// don't (everything else), this trait encapsulates that logic +pub(crate) trait WriteTrailingZeroAware { + fn write_zero_aware(&mut self, buf: &[u8]) -> std::io::Result<()>; + fn flush(&mut self) -> std::io::Result<()>; + fn should_prefix(&self) -> bool; +} + +impl WriteTrailingZeroAware for File { + fn write_zero_aware(&mut self, buf: &[u8]) -> std::io::Result<()> { + ::write_all(self, buf.split_last().unwrap().1) + } + + fn flush(&mut self) -> std::io::Result<()> { + ::flush(self) + } + + fn should_prefix(&self) -> bool { + false + } +} + +impl WriteTrailingZeroAware for Stderr { + fn write_zero_aware(&mut self, buf: &[u8]) -> std::io::Result<()> { + ::write_all(self, buf.split_last().unwrap().1) + } + + fn flush(&mut self) -> std::io::Result<()> { + ::flush(self) + } + + fn should_prefix(&self) -> bool { + true + } +} + +#[cfg(windows)] +mod os { + use super::WriteTrailingZeroAware; + use std::{os::windows::prelude::AsRawHandle, ptr}; + use winapi::um::debugapi::OutputDebugStringA; + + struct OutputDebugString {} + + impl WriteTrailingZeroAware for OutputDebugString { + fn write_zero_aware(&mut self, buf: &[u8]) -> std::io::Result<()> { + unsafe { OutputDebugStringA(buf.as_ptr() as *const _) }; + Ok(()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } + + fn should_prefix(&self) -> bool { + true + } + } + + pub(crate) fn new_debug_logger() -> Box { + let stderr = std::io::stderr(); + let log_to_stderr = stderr.as_raw_handle() != ptr::null_mut(); + if log_to_stderr { + Box::new(stderr) + } else { + Box::new(OutputDebugString {}) + } + } +} + +#[cfg(not(windows))] +mod os { + use super::WriteTrailingZeroAware; + + pub(crate) fn new_debug_logger() -> Box { + Box::new(std::io::stderr()) + } +} + +#[cfg(test)] +mod tests { + use super::{ErrorEntry, FnCallLog, WriteTrailingZeroAware}; + use crate::{ + log::{CudaFunctionName, WriteBuffer}, + FnCallLogStack, OuterCallGuard, + }; + use std::{ + cell::RefCell, + io, str, + sync::{Arc, Mutex}, + }; + + struct FailOnNthWrite { + fail_on: usize, + counter: usize, + } + + impl WriteTrailingZeroAware for FailOnNthWrite { + fn write_zero_aware(&mut self, _: &[u8]) -> std::io::Result<()> { + self.counter += 1; + if self.counter >= self.fail_on { + Err(io::Error::from_raw_os_error(4)) + } else { + Ok(()) + } + } + + fn flush(&mut self) -> std::io::Result<()> { + panic!() + } + + fn should_prefix(&self) -> bool { + false + } + } + + // Custom type to not trigger trait coherence rules + #[derive(Clone)] + struct ArcVec(Arc>>); + + impl WriteTrailingZeroAware for ArcVec { + fn write_zero_aware(&mut self, buf: &[u8]) -> std::io::Result<()> { + let mut vec = self.0.lock().unwrap(); + vec.extend_from_slice(buf.split_last().unwrap().1); + Ok(()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } + + fn should_prefix(&self) -> bool { + false + } + } + + #[test] + // TODO: fix this, it should use drop guard for testing. + // Previously FnCallLog would implement Drop and write to the log + fn error_in_fallible_emitter_is_handled_gracefully() { + let result = ArcVec(Arc::new(Mutex::new(Vec::::new()))); + let infallible_emitter = Box::new(result.clone()) as Box; + let fallible_emitter = Some(Box::new(FailOnNthWrite { + fail_on: 1, + counter: 0, + }) as Box); + let mut write_buffer = WriteBuffer::new(); + write_buffer.unprefixed_buffer = Some(Vec::new()); + let mut writer = super::Writer { + fallible_emitter, + infallible_emitter, + write_buffer, + }; + let func_logger = FnCallLog { + name: CudaFunctionName::Normal("cuInit"), + args: None, + output: None, + subcalls: Vec::new(), + }; + let log_root = FnCallLogStack { + depth: 1, + log_root: func_logger, + }; + let log_root = RefCell::new(log_root); + let drop_guard = OuterCallGuard { + writer: &mut writer, + log_root: &log_root, + }; + + { + log_root + .borrow_mut() + .log_root + .log(ErrorEntry::IoError(io::Error::from_raw_os_error(1))); + log_root + .borrow_mut() + .log_root + .log(ErrorEntry::IoError(io::Error::from_raw_os_error(2))); + log_root + .borrow_mut() + .log_root + .log(ErrorEntry::IoError(io::Error::from_raw_os_error(3))); + } + drop(drop_guard); + + let result = result.0.lock().unwrap(); + let result_str = str::from_utf8(&*result).unwrap(); + let result_lines = result_str.lines().collect::>(); + assert_eq!(result_lines.len(), 5); + assert_eq!(result_lines[0], "cuInit(...) -> UNKNOWN"); + assert!(result_lines[1].starts_with(" ")); + assert!(result_lines[2].starts_with(" ")); + assert!(result_lines[3].starts_with(" ")); + assert!(result_lines[4].starts_with(" ")); + } +} diff --git a/zluda_dump/src/os_unix.rs b/zluda_dump/src/os_unix.rs index 7d915b1..27e695c 100644 --- a/zluda_dump/src/os_unix.rs +++ b/zluda_dump/src/os_unix.rs @@ -1,81 +1,81 @@ -use cuda_types::cuda::CUuuid; -use std::ffi::{c_void, CStr, CString}; -use std::mem; - -pub(crate) const LIBCUDA_DEFAULT_PATH: &str = "/usr/lib/x86_64-linux-gnu/libcuda.so.1"; - -pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void { - let libcuda_path = CString::new(libcuda_path).unwrap(); - libc::dlopen( - libcuda_path.as_ptr() as *const _, - libc::RTLD_LOCAL | libc::RTLD_NOW, - ) -} - -pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void { - libc::dlsym(handle, func.as_ptr() as *const _) -} - -#[macro_export] -macro_rules! os_log { - ($format:tt) => { - { - eprintln!("[ZLUDA_DUMP] {}", format!($format)); - } - }; - ($format:tt, $($obj: expr),+) => { - { - eprintln!("[ZLUDA_DUMP] {}", format!($format, $($obj,)+)); - } - }; -} - -//RDI, RSI, RDX, RCX, R8, R9 -#[cfg(target_arch = "x86_64")] -pub fn get_thunk( - original_fn: *const c_void, - report_fn: unsafe extern "system" fn(&CUuuid, usize), - guid: *const CUuuid, - idx: usize, -) -> *const c_void { - use dynasmrt::{dynasm, DynasmApi}; - let mut ops = dynasmrt::x64::Assembler::new().unwrap(); - let start = ops.offset(); - dynasm!(ops - // stack alignment - ; sub rsp, 8 - ; push rdi - ; push rsi - ; push rdx - ; push rcx - ; push r8 - ; push r9 - ; mov rdi, QWORD guid as i64 - ; mov rsi, QWORD idx as i64 - ; mov rax, QWORD report_fn as i64 - ; call rax - ; pop r9 - ; pop r8 - ; pop rcx - ; pop rdx - ; pop rsi - ; pop rdi - ; add rsp, 8 - ; mov rax, QWORD original_fn as i64 - ; jmp rax - ; int 3 - ); - let exe_buf = ops.finalize().unwrap(); - let result_fn = exe_buf.ptr(start); - mem::forget(exe_buf); - result_fn as *const _ -} - -#[link(name = "pthread")] -unsafe extern "C" { - fn pthread_self() -> std::os::unix::thread::RawPthread; -} - -pub(crate) fn current_thread() -> u32 { - (unsafe { pthread_self() }) as u32 -} +use cuda_types::cuda::CUuuid; +use std::ffi::{c_void, CStr, CString}; +use std::mem; + +pub(crate) const LIBCUDA_DEFAULT_PATH: &str = "/usr/lib/x86_64-linux-gnu/libcuda.so.1"; + +pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void { + let libcuda_path = CString::new(libcuda_path).unwrap(); + libc::dlopen( + libcuda_path.as_ptr() as *const _, + libc::RTLD_LOCAL | libc::RTLD_NOW, + ) +} + +pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void { + libc::dlsym(handle, func.as_ptr() as *const _) +} + +#[macro_export] +macro_rules! os_log { + ($format:tt) => { + { + eprintln!("[ZLUDA_DUMP] {}", format!($format)); + } + }; + ($format:tt, $($obj: expr),+) => { + { + eprintln!("[ZLUDA_DUMP] {}", format!($format, $($obj,)+)); + } + }; +} + +//RDI, RSI, RDX, RCX, R8, R9 +#[cfg(target_arch = "x86_64")] +pub fn get_thunk( + original_fn: *const c_void, + report_fn: unsafe extern "system" fn(&CUuuid, usize), + guid: *const CUuuid, + idx: usize, +) -> *const c_void { + use dynasmrt::{dynasm, DynasmApi}; + let mut ops = dynasmrt::x64::Assembler::new().unwrap(); + let start = ops.offset(); + dynasm!(ops + // stack alignment + ; sub rsp, 8 + ; push rdi + ; push rsi + ; push rdx + ; push rcx + ; push r8 + ; push r9 + ; mov rdi, QWORD guid as i64 + ; mov rsi, QWORD idx as i64 + ; mov rax, QWORD report_fn as i64 + ; call rax + ; pop r9 + ; pop r8 + ; pop rcx + ; pop rdx + ; pop rsi + ; pop rdi + ; add rsp, 8 + ; mov rax, QWORD original_fn as i64 + ; jmp rax + ; int 3 + ); + let exe_buf = ops.finalize().unwrap(); + let result_fn = exe_buf.ptr(start); + mem::forget(exe_buf); + result_fn as *const _ +} + +#[link(name = "pthread")] +unsafe extern "C" { + fn pthread_self() -> std::os::unix::thread::RawPthread; +} + +pub(crate) fn current_thread() -> u32 { + (unsafe { pthread_self() }) as u32 +} diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs index fd2ea36..5d6ce6f 100644 --- a/zluda_dump/src/os_win.rs +++ b/zluda_dump/src/os_win.rs @@ -1,190 +1,190 @@ -use std::{ - ffi::{c_void, CStr}, - mem, ptr, - sync::LazyLock, -}; - -use std::os::windows::io::AsRawHandle; -use winapi::{ - shared::minwindef::{FARPROC, HMODULE}, - um::debugapi::OutputDebugStringA, - um::libloaderapi::{GetProcAddress, LoadLibraryW}, -}; - -use cuda_types::cuda::CUuuid; - -pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll"; -const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0"; -const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0"; - -static PLATFORM_LIBRARY: LazyLock = - LazyLock::new(|| unsafe { PlatformLibrary::new() }); - -#[allow(non_snake_case)] -struct PlatformLibrary { - LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE, - GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC, -} - -impl PlatformLibrary { - #[allow(non_snake_case)] - unsafe fn new() -> Self { - let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() { - None => ( - LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE, - mem::transmute( - GetProcAddress - as unsafe extern "system" fn( - hModule: HMODULE, - lpProcName: *const i8, - ) -> FARPROC, - ), - ), - Some(zluda_with) => ( - mem::transmute(GetProcAddress( - zluda_with, - LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _, - )), - mem::transmute(GetProcAddress( - zluda_with, - GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _, - )), - ), - }; - PlatformLibrary { - LoadLibraryW, - GetProcAddress, - } - } - - unsafe fn get_detourer_module() -> Option { - let mut module = ptr::null_mut(); - loop { - module = detours_sys::DetourEnumerateModules(module); - if module == ptr::null_mut() { - break; - } - let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _); - if payload != ptr::null_mut() { - return Some(module as _); - } - } - None - } -} - -pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void { - let libcuda_path_uf16 = libcuda_path - .encode_utf16() - .chain(std::iter::once(0)) - .collect::>(); - (PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _ -} - -pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void { - (PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _ -} - -#[macro_export] -macro_rules! os_log { - ($format:tt) => { - { - use crate::os::__log_impl; - __log_impl(format!($format)); - } - }; - ($format:tt, $($obj: expr),+) => { - { - use crate::os::__log_impl; - __log_impl(format!($format, $($obj,)+)); - } - }; -} - -pub fn __log_impl(s: String) { - let log_to_stderr = std::io::stderr().as_raw_handle() != ptr::null_mut(); - if log_to_stderr { - eprintln!("[ZLUDA_DUMP] {}", s); - } else { - let mut win_str = String::with_capacity("[ZLUDA_DUMP] ".len() + s.len() + 2); - win_str.push_str("[ZLUDA_DUMP] "); - win_str.push_str(&s); - win_str.push_str("\n\0"); - unsafe { OutputDebugStringA(win_str.as_ptr() as *const _) }; - } -} - -#[cfg(target_arch = "x86")] -pub fn get_thunk( - original_fn: *const c_void, - report_fn: unsafe extern "system" fn(&CUuuid, usize), - guid: *const CUuuid, - idx: usize, -) -> *const c_void { - use dynasmrt::{dynasm, DynasmApi}; - let mut ops = dynasmrt::x86::Assembler::new().unwrap(); - let start = ops.offset(); - dynasm!(ops - ; .arch x86 - ; push idx as i32 - ; push guid as i32 - ; mov eax, report_fn as i32 - ; call eax - ; mov eax, original_fn as i32 - ; jmp eax - ; int 3 - ); - let exe_buf = ops.finalize().unwrap(); - let result_fn = exe_buf.ptr(start); - mem::forget(exe_buf); - result_fn as *const _ -} - -//RCX, RDX, R8, R9 -#[cfg(target_arch = "x86_64")] -pub fn get_thunk( - original_fn: *const c_void, - report_fn: unsafe extern "system" fn(&CUuuid, usize), - guid: *const CUuuid, - idx: usize, -) -> *const c_void { - use dynasmrt::{dynasm, DynasmApi}; - let mut ops = dynasmrt::x86::Assembler::new().unwrap(); - let start = ops.offset(); - // Let's hope there's never more than 4 arguments - dynasm!(ops - ; .arch x64 - ; push rbp - ; mov rbp, rsp - ; push rcx - ; push rdx - ; push r8 - ; push r9 - ; mov rcx, QWORD guid as i64 - ; mov rdx, QWORD idx as i64 - ; mov rax, QWORD report_fn as i64 - ; call rax - ; pop r9 - ; pop r8 - ; pop rdx - ; pop rcx - ; mov rax, QWORD original_fn as i64 - ; call rax - ; pop rbp - ; ret - ; int 3 - ); - let exe_buf = ops.finalize().unwrap(); - let result_fn = exe_buf.ptr(start); - mem::forget(exe_buf); - result_fn as *const _ -} - -#[link(name = "kernel32")] -unsafe extern "system" { - fn GetCurrentThreadId() -> u32; -} - -pub(crate) fn current_thread() -> u32 { - unsafe { GetCurrentThreadId() } -} +use std::{ + ffi::{c_void, CStr}, + mem, ptr, + sync::LazyLock, +}; + +use std::os::windows::io::AsRawHandle; +use winapi::{ + shared::minwindef::{FARPROC, HMODULE}, + um::debugapi::OutputDebugStringA, + um::libloaderapi::{GetProcAddress, LoadLibraryW}, +}; + +use cuda_types::cuda::CUuuid; + +pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll"; +const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0"; +const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0"; + +static PLATFORM_LIBRARY: LazyLock = + LazyLock::new(|| unsafe { PlatformLibrary::new() }); + +#[allow(non_snake_case)] +struct PlatformLibrary { + LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE, + GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC, +} + +impl PlatformLibrary { + #[allow(non_snake_case)] + unsafe fn new() -> Self { + let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() { + None => ( + LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE, + mem::transmute( + GetProcAddress + as unsafe extern "system" fn( + hModule: HMODULE, + lpProcName: *const i8, + ) -> FARPROC, + ), + ), + Some(zluda_with) => ( + mem::transmute(GetProcAddress( + zluda_with, + LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _, + )), + mem::transmute(GetProcAddress( + zluda_with, + GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _, + )), + ), + }; + PlatformLibrary { + LoadLibraryW, + GetProcAddress, + } + } + + unsafe fn get_detourer_module() -> Option { + let mut module = ptr::null_mut(); + loop { + module = detours_sys::DetourEnumerateModules(module); + if module == ptr::null_mut() { + break; + } + let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _); + if payload != ptr::null_mut() { + return Some(module as _); + } + } + None + } +} + +pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void { + let libcuda_path_uf16 = libcuda_path + .encode_utf16() + .chain(std::iter::once(0)) + .collect::>(); + (PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _ +} + +pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void { + (PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _ +} + +#[macro_export] +macro_rules! os_log { + ($format:tt) => { + { + use crate::os::__log_impl; + __log_impl(format!($format)); + } + }; + ($format:tt, $($obj: expr),+) => { + { + use crate::os::__log_impl; + __log_impl(format!($format, $($obj,)+)); + } + }; +} + +pub fn __log_impl(s: String) { + let log_to_stderr = std::io::stderr().as_raw_handle() != ptr::null_mut(); + if log_to_stderr { + eprintln!("[ZLUDA_DUMP] {}", s); + } else { + let mut win_str = String::with_capacity("[ZLUDA_DUMP] ".len() + s.len() + 2); + win_str.push_str("[ZLUDA_DUMP] "); + win_str.push_str(&s); + win_str.push_str("\n\0"); + unsafe { OutputDebugStringA(win_str.as_ptr() as *const _) }; + } +} + +#[cfg(target_arch = "x86")] +pub fn get_thunk( + original_fn: *const c_void, + report_fn: unsafe extern "system" fn(&CUuuid, usize), + guid: *const CUuuid, + idx: usize, +) -> *const c_void { + use dynasmrt::{dynasm, DynasmApi}; + let mut ops = dynasmrt::x86::Assembler::new().unwrap(); + let start = ops.offset(); + dynasm!(ops + ; .arch x86 + ; push idx as i32 + ; push guid as i32 + ; mov eax, report_fn as i32 + ; call eax + ; mov eax, original_fn as i32 + ; jmp eax + ; int 3 + ); + let exe_buf = ops.finalize().unwrap(); + let result_fn = exe_buf.ptr(start); + mem::forget(exe_buf); + result_fn as *const _ +} + +//RCX, RDX, R8, R9 +#[cfg(target_arch = "x86_64")] +pub fn get_thunk( + original_fn: *const c_void, + report_fn: unsafe extern "system" fn(&CUuuid, usize), + guid: *const CUuuid, + idx: usize, +) -> *const c_void { + use dynasmrt::{dynasm, DynasmApi}; + let mut ops = dynasmrt::x86::Assembler::new().unwrap(); + let start = ops.offset(); + // Let's hope there's never more than 4 arguments + dynasm!(ops + ; .arch x64 + ; push rbp + ; mov rbp, rsp + ; push rcx + ; push rdx + ; push r8 + ; push r9 + ; mov rcx, QWORD guid as i64 + ; mov rdx, QWORD idx as i64 + ; mov rax, QWORD report_fn as i64 + ; call rax + ; pop r9 + ; pop r8 + ; pop rdx + ; pop rcx + ; mov rax, QWORD original_fn as i64 + ; call rax + ; pop rbp + ; ret + ; int 3 + ); + let exe_buf = ops.finalize().unwrap(); + let result_fn = exe_buf.ptr(start); + mem::forget(exe_buf); + result_fn as *const _ +} + +#[link(name = "kernel32")] +unsafe extern "system" { + fn GetCurrentThreadId() -> u32; +} + +pub(crate) fn current_thread() -> u32 { + unsafe { GetCurrentThreadId() } +} diff --git a/zluda_dump/src/trace.rs b/zluda_dump/src/trace.rs index 13e2f4a..b840171 100644 --- a/zluda_dump/src/trace.rs +++ b/zluda_dump/src/trace.rs @@ -1,334 +1,334 @@ -use crate::{ - log::{self, UInt}, - trace, ErrorEntry, FnCallLog, Settings, -}; -use cuda_types::{ - cuda::*, - dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper}, -}; -use dark_api::fatbin::{ - decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule, -}; -use rustc_hash::{FxHashMap, FxHashSet}; -use std::{ - borrow::Cow, - ffi::{c_void, CStr, CString}, - fs::{self, File}, - io::{self, Read, Write}, - path::PathBuf, -}; -use unwrap_or::unwrap_some_or; - -// This struct is the heart of CUDA state tracking, it: -// * receives calls from the probes about changes to CUDA state -// * records updates to the state change -// * writes out relevant state change and details to disk and log -pub(crate) struct StateTracker { - writer: DumpWriter, - pub(crate) libraries: FxHashMap, - saved_modules: FxHashSet, - module_counter: usize, - submodule_counter: usize, - pub(crate) override_cc: Option<(u32, u32)>, -} - -#[derive(Clone, Copy)] -pub(crate) struct CodePointer(pub *const c_void); - -unsafe impl Send for CodePointer {} -unsafe impl Sync for CodePointer {} - -impl StateTracker { - pub(crate) fn new(settings: &Settings) -> Self { - StateTracker { - writer: DumpWriter::new(settings.dump_dir.clone()), - libraries: FxHashMap::default(), - saved_modules: FxHashSet::default(), - module_counter: 0, - submodule_counter: 0, - override_cc: settings.override_cc, - } - } - - pub(crate) fn record_new_module_file( - &mut self, - module: CUmodule, - file_name: *const i8, - fn_logger: &mut FnCallLog, - ) { - let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() { - Ok(f) => f, - Err(err) => { - fn_logger.log(log::ErrorEntry::MalformedModulePath(err)); - return; - } - }; - let maybe_io_error = self.try_record_new_module_file(module, fn_logger, file_name); - fn_logger.log_io_error(maybe_io_error) - } - - fn try_record_new_module_file( - &mut self, - module: CUmodule, - fn_logger: &mut FnCallLog, - file_name: &str, - ) -> io::Result<()> { - let mut module_file = fs::File::open(file_name)?; - let mut read_buff = Vec::new(); - module_file.read_to_end(&mut read_buff)?; - self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger); - Ok(()) - } - - pub(crate) fn record_new_submodule( - &mut self, - module: CUmodule, - submodule: &[u8], - fn_logger: &mut FnCallLog, - type_: &'static str, - ) { - if self.saved_modules.insert(module) { - self.module_counter += 1; - self.submodule_counter = 0; - } - self.submodule_counter += 1; - fn_logger.log_io_error(self.writer.save_module( - self.module_counter, - Some(self.submodule_counter), - submodule, - type_, - )); - if type_ == "ptx" { - match CString::new(submodule) { - Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)), - Ok(submodule_cstring) => match submodule_cstring.to_str() { - Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)), - Ok(submodule_text) => self.try_parse_and_record_kernels( - fn_logger, - self.module_counter, - Some(self.submodule_counter), - submodule_text, - ), - }, - } - } - } - - pub(crate) fn record_new_module( - &mut self, - module: CUmodule, - raw_image: *const c_void, - fn_logger: &mut FnCallLog, - ) { - self.module_counter += 1; - if unsafe { *(raw_image as *const [u8; 4]) } == *goblin::elf64::header::ELFMAG { - self.saved_modules.insert(module); - // TODO: Parse ELF and write it to disk - fn_logger.log(log::ErrorEntry::UnsupportedModule { - module, - raw_image, - kind: "ELF", - }) - } else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC { - self.saved_modules.insert(module); - // TODO: Figure out how to get size of archive module and write it to disk - fn_logger.log(log::ErrorEntry::UnsupportedModule { - module, - raw_image, - kind: "archive", - }) - } else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC { - unsafe { - fn_logger.try_(|fn_logger| { - trace::record_submodules_from_wrapped_fatbin( - module, - raw_image as *const FatbincWrapper, - fn_logger, - self, - ) - }); - } - } else { - self.record_module_ptx(module, raw_image, fn_logger) - } - } - - fn record_module_ptx( - &mut self, - module: CUmodule, - raw_image: *const c_void, - fn_logger: &mut FnCallLog, - ) { - self.saved_modules.insert(module); - let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str(); - let module_text = match module_text { - Ok(m) => m, - Err(utf8_err) => { - fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err)); - return; - } - }; - fn_logger.log_io_error(self.writer.save_module( - self.module_counter, - None, - module_text.as_bytes(), - "ptx", - )); - self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text); - } - - fn try_parse_and_record_kernels( - &mut self, - fn_logger: &mut FnCallLog, - module_index: usize, - submodule_index: Option, - module_text: &str, - ) { - let errors = ptx_parser::parse_for_errors(module_text); - if !errors.is_empty() { - fn_logger.log(log::ErrorEntry::ModuleParsingError( - DumpWriter::get_file_name(module_index, submodule_index, "log"), - )); - fn_logger.log_io_error(self.writer.save_module_error_log( - module_index, - submodule_index, - &*errors, - )); - } - } -} - -// This structs writes out information about CUDA execution to the dump dir -struct DumpWriter { - dump_dir: Option, -} - -impl DumpWriter { - fn new(dump_dir: Option) -> Self { - Self { dump_dir } - } - - fn save_module( - &self, - module_index: usize, - submodule_index: Option, - buffer: &[u8], - kind: &'static str, - ) -> io::Result<()> { - let mut dump_file = match &self.dump_dir { - None => return Ok(()), - Some(d) => d.clone(), - }; - dump_file.push(Self::get_file_name(module_index, submodule_index, kind)); - let mut file = File::create(dump_file)?; - file.write_all(buffer)?; - Ok(()) - } - - fn save_module_error_log<'input>( - &self, - module_index: usize, - submodule_index: Option, - errors: &[ptx_parser::PtxError<'input>], - ) -> io::Result<()> { - let mut log_file = match &self.dump_dir { - None => return Ok(()), - Some(d) => d.clone(), - }; - log_file.push(Self::get_file_name(module_index, submodule_index, "log")); - let mut file = File::create(log_file)?; - for error in errors { - writeln!(file, "{}", error)?; - } - Ok(()) - } - - fn get_file_name(module_index: usize, submodule_index: Option, kind: &str) -> String { - match submodule_index { - None => { - format!("module_{:04}.{:02}", module_index, kind) - } - Some(submodule_index) => { - format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind) - } - } - } -} - -pub(crate) unsafe fn record_submodules_from_wrapped_fatbin( - module: CUmodule, - fatbinc_wrapper: *const FatbincWrapper, - fn_logger: &mut FnCallLog, - state: &mut StateTracker, -) -> Result<(), ErrorEntry> { - let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?; - let mut submodules = fatbin.get_submodules()?; - while let Some(current) = submodules.next()? { - record_submodules_from_fatbin(module, current, fn_logger, state)?; - } - Ok(()) -} - -pub(crate) unsafe fn record_submodules_from_fatbin( - module: CUmodule, - submodule: FatbinSubmodule, - logger: &mut FnCallLog, - state: &mut StateTracker, -) -> Result<(), ErrorEntry> { - record_submodules(module, logger, state, submodule.get_files())?; - Ok(()) -} - -pub(crate) unsafe fn record_submodules( - module: CUmodule, - fn_logger: &mut FnCallLog, - state: &mut StateTracker, - mut files: FatbinFileIterator, -) -> Result<(), ErrorEntry> { - while let Some(file) = files.next()? { - let mut payload = if file - .header - .flags - .contains(FatbinFileHeaderFlags::CompressedLz4) - { - Cow::Owned(unwrap_some_or!( - fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())), - continue - )) - } else if file - .header - .flags - .contains(FatbinFileHeaderFlags::CompressedZstd) - { - Cow::Owned(unwrap_some_or!( - fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())), - continue - )) - } else { - Cow::Borrowed(file.get_payload()) - }; - match file.header.kind { - FatbinFileHeader::HEADER_KIND_PTX => { - while payload.last() == Some(&0) { - // remove trailing zeros - payload.to_mut().pop(); - } - state.record_new_submodule(module, &*payload, fn_logger, "ptx") - } - FatbinFileHeader::HEADER_KIND_ELF => { - state.record_new_submodule(module, &*payload, fn_logger, "elf") - } - _ => { - fn_logger.log(log::ErrorEntry::UnexpectedBinaryField { - field_name: "FATBIN_FILE_HEADER_KIND", - expected: vec![ - UInt::U16(FatbinFileHeader::HEADER_KIND_PTX), - UInt::U16(FatbinFileHeader::HEADER_KIND_ELF), - ], - observed: UInt::U16(file.header.kind), - }); - } - } - } - Ok(()) -} +use crate::{ + log::{self, UInt}, + trace, ErrorEntry, FnCallLog, Settings, +}; +use cuda_types::{ + cuda::*, + dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper}, +}; +use dark_api::fatbin::{ + decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule, +}; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::{ + borrow::Cow, + ffi::{c_void, CStr, CString}, + fs::{self, File}, + io::{self, Read, Write}, + path::PathBuf, +}; +use unwrap_or::unwrap_some_or; + +// This struct is the heart of CUDA state tracking, it: +// * receives calls from the probes about changes to CUDA state +// * records updates to the state change +// * writes out relevant state change and details to disk and log +pub(crate) struct StateTracker { + writer: DumpWriter, + pub(crate) libraries: FxHashMap, + saved_modules: FxHashSet, + module_counter: usize, + submodule_counter: usize, + pub(crate) override_cc: Option<(u32, u32)>, +} + +#[derive(Clone, Copy)] +pub(crate) struct CodePointer(pub *const c_void); + +unsafe impl Send for CodePointer {} +unsafe impl Sync for CodePointer {} + +impl StateTracker { + pub(crate) fn new(settings: &Settings) -> Self { + StateTracker { + writer: DumpWriter::new(settings.dump_dir.clone()), + libraries: FxHashMap::default(), + saved_modules: FxHashSet::default(), + module_counter: 0, + submodule_counter: 0, + override_cc: settings.override_cc, + } + } + + pub(crate) fn record_new_module_file( + &mut self, + module: CUmodule, + file_name: *const i8, + fn_logger: &mut FnCallLog, + ) { + let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() { + Ok(f) => f, + Err(err) => { + fn_logger.log(log::ErrorEntry::MalformedModulePath(err)); + return; + } + }; + let maybe_io_error = self.try_record_new_module_file(module, fn_logger, file_name); + fn_logger.log_io_error(maybe_io_error) + } + + fn try_record_new_module_file( + &mut self, + module: CUmodule, + fn_logger: &mut FnCallLog, + file_name: &str, + ) -> io::Result<()> { + let mut module_file = fs::File::open(file_name)?; + let mut read_buff = Vec::new(); + module_file.read_to_end(&mut read_buff)?; + self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger); + Ok(()) + } + + pub(crate) fn record_new_submodule( + &mut self, + module: CUmodule, + submodule: &[u8], + fn_logger: &mut FnCallLog, + type_: &'static str, + ) { + if self.saved_modules.insert(module) { + self.module_counter += 1; + self.submodule_counter = 0; + } + self.submodule_counter += 1; + fn_logger.log_io_error(self.writer.save_module( + self.module_counter, + Some(self.submodule_counter), + submodule, + type_, + )); + if type_ == "ptx" { + match CString::new(submodule) { + Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)), + Ok(submodule_cstring) => match submodule_cstring.to_str() { + Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)), + Ok(submodule_text) => self.try_parse_and_record_kernels( + fn_logger, + self.module_counter, + Some(self.submodule_counter), + submodule_text, + ), + }, + } + } + } + + pub(crate) fn record_new_module( + &mut self, + module: CUmodule, + raw_image: *const c_void, + fn_logger: &mut FnCallLog, + ) { + self.module_counter += 1; + if unsafe { *(raw_image as *const [u8; 4]) } == *goblin::elf64::header::ELFMAG { + self.saved_modules.insert(module); + // TODO: Parse ELF and write it to disk + fn_logger.log(log::ErrorEntry::UnsupportedModule { + module, + raw_image, + kind: "ELF", + }) + } else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC { + self.saved_modules.insert(module); + // TODO: Figure out how to get size of archive module and write it to disk + fn_logger.log(log::ErrorEntry::UnsupportedModule { + module, + raw_image, + kind: "archive", + }) + } else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC { + unsafe { + fn_logger.try_(|fn_logger| { + trace::record_submodules_from_wrapped_fatbin( + module, + raw_image as *const FatbincWrapper, + fn_logger, + self, + ) + }); + } + } else { + self.record_module_ptx(module, raw_image, fn_logger) + } + } + + fn record_module_ptx( + &mut self, + module: CUmodule, + raw_image: *const c_void, + fn_logger: &mut FnCallLog, + ) { + self.saved_modules.insert(module); + let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str(); + let module_text = match module_text { + Ok(m) => m, + Err(utf8_err) => { + fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err)); + return; + } + }; + fn_logger.log_io_error(self.writer.save_module( + self.module_counter, + None, + module_text.as_bytes(), + "ptx", + )); + self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text); + } + + fn try_parse_and_record_kernels( + &mut self, + fn_logger: &mut FnCallLog, + module_index: usize, + submodule_index: Option, + module_text: &str, + ) { + let errors = ptx_parser::parse_for_errors(module_text); + if !errors.is_empty() { + fn_logger.log(log::ErrorEntry::ModuleParsingError( + DumpWriter::get_file_name(module_index, submodule_index, "log"), + )); + fn_logger.log_io_error(self.writer.save_module_error_log( + module_index, + submodule_index, + &*errors, + )); + } + } +} + +// This structs writes out information about CUDA execution to the dump dir +struct DumpWriter { + dump_dir: Option, +} + +impl DumpWriter { + fn new(dump_dir: Option) -> Self { + Self { dump_dir } + } + + fn save_module( + &self, + module_index: usize, + submodule_index: Option, + buffer: &[u8], + kind: &'static str, + ) -> io::Result<()> { + let mut dump_file = match &self.dump_dir { + None => return Ok(()), + Some(d) => d.clone(), + }; + dump_file.push(Self::get_file_name(module_index, submodule_index, kind)); + let mut file = File::create(dump_file)?; + file.write_all(buffer)?; + Ok(()) + } + + fn save_module_error_log<'input>( + &self, + module_index: usize, + submodule_index: Option, + errors: &[ptx_parser::PtxError<'input>], + ) -> io::Result<()> { + let mut log_file = match &self.dump_dir { + None => return Ok(()), + Some(d) => d.clone(), + }; + log_file.push(Self::get_file_name(module_index, submodule_index, "log")); + let mut file = File::create(log_file)?; + for error in errors { + writeln!(file, "{}", error)?; + } + Ok(()) + } + + fn get_file_name(module_index: usize, submodule_index: Option, kind: &str) -> String { + match submodule_index { + None => { + format!("module_{:04}.{:02}", module_index, kind) + } + Some(submodule_index) => { + format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind) + } + } + } +} + +pub(crate) unsafe fn record_submodules_from_wrapped_fatbin( + module: CUmodule, + fatbinc_wrapper: *const FatbincWrapper, + fn_logger: &mut FnCallLog, + state: &mut StateTracker, +) -> Result<(), ErrorEntry> { + let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?; + let mut submodules = fatbin.get_submodules()?; + while let Some(current) = submodules.next()? { + record_submodules_from_fatbin(module, current, fn_logger, state)?; + } + Ok(()) +} + +pub(crate) unsafe fn record_submodules_from_fatbin( + module: CUmodule, + submodule: FatbinSubmodule, + logger: &mut FnCallLog, + state: &mut StateTracker, +) -> Result<(), ErrorEntry> { + record_submodules(module, logger, state, submodule.get_files())?; + Ok(()) +} + +pub(crate) unsafe fn record_submodules( + module: CUmodule, + fn_logger: &mut FnCallLog, + state: &mut StateTracker, + mut files: FatbinFileIterator, +) -> Result<(), ErrorEntry> { + while let Some(file) = files.next()? { + let mut payload = if file + .header + .flags + .contains(FatbinFileHeaderFlags::CompressedLz4) + { + Cow::Owned(unwrap_some_or!( + fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())), + continue + )) + } else if file + .header + .flags + .contains(FatbinFileHeaderFlags::CompressedZstd) + { + Cow::Owned(unwrap_some_or!( + fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())), + continue + )) + } else { + Cow::Borrowed(file.get_payload()) + }; + match file.header.kind { + FatbinFileHeader::HEADER_KIND_PTX => { + while payload.last() == Some(&0) { + // remove trailing zeros + payload.to_mut().pop(); + } + state.record_new_submodule(module, &*payload, fn_logger, "ptx") + } + FatbinFileHeader::HEADER_KIND_ELF => { + state.record_new_submodule(module, &*payload, fn_logger, "elf") + } + _ => { + fn_logger.log(log::ErrorEntry::UnexpectedBinaryField { + field_name: "FATBIN_FILE_HEADER_KIND", + expected: vec![ + UInt::U16(FatbinFileHeader::HEADER_KIND_PTX), + UInt::U16(FatbinFileHeader::HEADER_KIND_ELF), + ], + observed: UInt::U16(file.header.kind), + }); + } + } + } + Ok(()) +} diff --git a/zluda_inject/build.rs b/zluda_inject/build.rs index b971a65..7a66420 100644 --- a/zluda_inject/build.rs +++ b/zluda_inject/build.rs @@ -1,81 +1,81 @@ -use std::{ - env::{self, VarError}, - fs::{self, DirEntry}, - io, - path::{self, PathBuf}, - process::Command, -}; - -fn main() -> Result<(), VarError> { - if std::env::var_os("CARGO_CFG_WINDOWS").is_none() { - return Ok(()); - } - println!("cargo:rerun-if-changed=build.rs"); - if env::var("PROFILE")? != "debug" { - return Ok(()); - } - let rustc_exe = env::var("RUSTC")?; - let out_dir = env::var("OUT_DIR")?; - let target = env::var("TARGET")?; - let is_msvc = env::var("CARGO_CFG_TARGET_ENV")? == "msvc"; - let opt_level = env::var("OPT_LEVEL")?; - let debug = str::parse::(env::var("DEBUG")?.as_str()).unwrap(); - let mut helpers_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); - helpers_dir.push("tests"); - helpers_dir.push("helpers"); - let helpers_dir_as_string = helpers_dir.to_string_lossy(); - println!("cargo:rerun-if-changed={}", helpers_dir_as_string); - for rust_file in fs::read_dir(&helpers_dir).unwrap().filter_map(rust_file) { - let full_file_path = format!( - "{}{}{}", - helpers_dir_as_string, - path::MAIN_SEPARATOR, - rust_file - ); - let mut rustc_cmd = Command::new(&*rustc_exe); - if debug { - rustc_cmd.arg("-g"); - } - rustc_cmd.arg(format!("-Lnative={}", helpers_dir_as_string)); - if !is_msvc { - // HACK ALERT - // I have no idea why the extra library below have to be linked - rustc_cmd.arg(r"-lucrt"); - } - rustc_cmd - .arg("-C") - .arg(format!("opt-level={}", opt_level)) - .arg("-L") - .arg(format!("{}", out_dir)) - .arg("--out-dir") - .arg(format!("{}", out_dir)) - .arg("--target") - .arg(format!("{}", target)) - .arg(full_file_path); - assert!(rustc_cmd.status().unwrap().success()); - } - std::fs::copy( - format!( - "{}{}do_cuinit_late_clr.exe", - helpers_dir_as_string, - path::MAIN_SEPARATOR - ), - format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR), - ) - .unwrap(); - println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir); - Ok(()) -} - -fn rust_file(entry: io::Result) -> Option { - entry.ok().and_then(|e| { - let os_file_name = e.file_name(); - let file_name = os_file_name.to_string_lossy(); - let is_file = e.file_type().ok().map(|t| t.is_file()).unwrap_or(false); - if is_file && file_name.ends_with(".rs") { - Some(file_name.to_string()) - } else { - None - } - }) -} +use std::{ + env::{self, VarError}, + fs::{self, DirEntry}, + io, + path::{self, PathBuf}, + process::Command, +}; + +fn main() -> Result<(), VarError> { + if std::env::var_os("CARGO_CFG_WINDOWS").is_none() { + return Ok(()); + } + println!("cargo:rerun-if-changed=build.rs"); + if env::var("PROFILE")? != "debug" { + return Ok(()); + } + let rustc_exe = env::var("RUSTC")?; + let out_dir = env::var("OUT_DIR")?; + let target = env::var("TARGET")?; + let is_msvc = env::var("CARGO_CFG_TARGET_ENV")? == "msvc"; + let opt_level = env::var("OPT_LEVEL")?; + let debug = str::parse::(env::var("DEBUG")?.as_str()).unwrap(); + let mut helpers_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); + helpers_dir.push("tests"); + helpers_dir.push("helpers"); + let helpers_dir_as_string = helpers_dir.to_string_lossy(); + println!("cargo:rerun-if-changed={}", helpers_dir_as_string); + for rust_file in fs::read_dir(&helpers_dir).unwrap().filter_map(rust_file) { + let full_file_path = format!( + "{}{}{}", + helpers_dir_as_string, + path::MAIN_SEPARATOR, + rust_file + ); + let mut rustc_cmd = Command::new(&*rustc_exe); + if debug { + rustc_cmd.arg("-g"); + } + rustc_cmd.arg(format!("-Lnative={}", helpers_dir_as_string)); + if !is_msvc { + // HACK ALERT + // I have no idea why the extra library below have to be linked + rustc_cmd.arg(r"-lucrt"); + } + rustc_cmd + .arg("-C") + .arg(format!("opt-level={}", opt_level)) + .arg("-L") + .arg(format!("{}", out_dir)) + .arg("--out-dir") + .arg(format!("{}", out_dir)) + .arg("--target") + .arg(format!("{}", target)) + .arg(full_file_path); + assert!(rustc_cmd.status().unwrap().success()); + } + std::fs::copy( + format!( + "{}{}do_cuinit_late_clr.exe", + helpers_dir_as_string, + path::MAIN_SEPARATOR + ), + format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR), + ) + .unwrap(); + println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir); + Ok(()) +} + +fn rust_file(entry: io::Result) -> Option { + entry.ok().and_then(|e| { + let os_file_name = e.file_name(); + let file_name = os_file_name.to_string_lossy(); + let is_file = e.file_type().ok().map(|t| t.is_file()).unwrap_or(false); + if is_file && file_name.ends_with(".rs") { + Some(file_name.to_string()) + } else { + None + } + }) +} diff --git a/zluda_inject/src/bin.rs b/zluda_inject/src/bin.rs index 408f8ab..a642871 100644 --- a/zluda_inject/src/bin.rs +++ b/zluda_inject/src/bin.rs @@ -1,311 +1,311 @@ -use std::env; -use std::os::windows; -use std::os::windows::ffi::OsStrExt; -use std::{error::Error, process}; -use std::{fs, io, ptr}; -use std::{mem, path::PathBuf}; - -use argh::FromArgs; -use mem::size_of_val; -use tempfile::TempDir; -use winapi::um::processenv::SearchPathW; -use winapi::um::{ - jobapi2::{AssignProcessToJobObject, SetInformationJobObject}, - processthreadsapi::{GetExitCodeProcess, ResumeThread}, - synchapi::WaitForSingleObject, - winbase::CreateJobObjectA, - winnt::{ - JobObjectExtendedLimitInformation, HANDLE, JOBOBJECT_EXTENDED_LIMIT_INFORMATION, - JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, - }, -}; - -use winapi::um::winbase::{INFINITE, WAIT_FAILED}; - -static REDIRECT_DLL: &'static str = "zluda_redirect.dll"; -static NVCUDA_DLL: &'static str = "nvcuda.dll"; -static NVML_DLL: &'static str = "nvml.dll"; - -include!("../../zluda_redirect/src/payload_guid.rs"); - -#[derive(FromArgs)] -/// Launch application with custom CUDA libraries -struct ProgramArguments { - /// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory - #[argh(option)] - nvcuda: Option, - - /// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory - #[argh(option)] - nvml: Option, - - /// executable to be injected with custom CUDA libraries - #[argh(positional)] - exe: String, - - /// arguments to the executable - #[argh(positional)] - args: Vec, -} - -pub fn main_impl() -> Result<(), Box> { - let raw_args = argh::from_env::(); - let normalized_args = NormalizedArguments::new(raw_args)?; - let mut environment = Environment::setup(normalized_args)?; - let mut startup_info = unsafe { mem::zeroed::() }; - let mut proc_info = unsafe { mem::zeroed::() }; - let mut dlls_to_inject = [ - environment.nvml_path_zero_terminated.as_ptr() as *const i8, - environment.nvcuda_path_zero_terminated.as_ptr() as _, - environment.redirect_path_zero_terminated.as_ptr() as _, - ]; - os_call!( - detours_sys::DetourCreateProcessWithDllsW( - ptr::null(), - environment.winapi_command_line_zero_terminated.as_mut_ptr(), - ptr::null_mut(), - ptr::null_mut(), - 0, - 0, - ptr::null_mut(), - ptr::null(), - &mut startup_info as *mut _, - &mut proc_info as *mut _, - dlls_to_inject.len() as u32, - dlls_to_inject.as_mut_ptr(), - Option::None - ), - |x| x != 0 - ); - kill_child_on_process_exit(proc_info.hProcess)?; - os_call!( - detours_sys::DetourCopyPayloadToProcess( - proc_info.hProcess, - &PAYLOAD_NVCUDA_GUID, - environment.nvcuda_path_zero_terminated.as_ptr() as *mut _, - environment.nvcuda_path_zero_terminated.len() as u32 - ), - |x| x != 0 - ); - os_call!( - detours_sys::DetourCopyPayloadToProcess( - proc_info.hProcess, - &PAYLOAD_NVML_GUID, - environment.nvml_path_zero_terminated.as_ptr() as *mut _, - environment.nvml_path_zero_terminated.len() as u32 - ), - |x| x != 0 - ); - os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1); - os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x - != WAIT_FAILED); - let mut child_exit_code: u32 = 0; - os_call!( - GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _), - |x| x != 0 - ); - process::exit(child_exit_code as i32) -} - -struct NormalizedArguments { - nvml_path: PathBuf, - nvcuda_path: PathBuf, - redirect_path: PathBuf, - winapi_command_line_zero_terminated: Vec, -} - -impl NormalizedArguments { - fn new(prog_args: ProgramArguments) -> Result> { - let current_exe = env::current_exe()?; - let nvml_path = Self::get_absolute_path(¤t_exe, prog_args.nvml, NVML_DLL)?; - let nvcuda_path = Self::get_absolute_path(¤t_exe, prog_args.nvcuda, NVCUDA_DLL)?; - let winapi_command_line_zero_terminated = - construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args)); - let mut redirect_path = current_exe.parent().unwrap().to_path_buf(); - redirect_path.push(REDIRECT_DLL); - Ok(Self { - nvml_path, - nvcuda_path, - redirect_path, - winapi_command_line_zero_terminated, - }) - } - - const WIN_MAX_PATH: usize = 260; - - fn get_absolute_path( - current_exe: &PathBuf, - dll: Option, - default: &str, - ) -> Result> { - Ok(if let Some(dll) = dll { - if dll.is_absolute() { - dll - } else { - let mut full_dll_path = vec![0; Self::WIN_MAX_PATH]; - let mut dll_utf16 = dll.as_os_str().encode_wide().collect::>(); - dll_utf16.push(0); - loop { - let copied_len = os_call!( - SearchPathW( - ptr::null_mut(), - dll_utf16.as_ptr(), - ptr::null(), - full_dll_path.len() as u32, - full_dll_path.as_mut_ptr(), - ptr::null_mut() - ), - |x| x != 0 - ) as usize; - if copied_len > full_dll_path.len() { - full_dll_path.resize(copied_len + 1, 0); - } else { - full_dll_path.truncate(copied_len); - break; - } - } - PathBuf::from(String::from_utf16_lossy(&full_dll_path)) - } - } else { - let mut dll_path = current_exe.parent().unwrap().to_path_buf(); - dll_path.push(default); - dll_path - }) - } -} - -struct Environment { - nvml_path_zero_terminated: String, - nvcuda_path_zero_terminated: String, - redirect_path_zero_terminated: String, - winapi_command_line_zero_terminated: Vec, - _temp_dir: TempDir, -} - -// This structs represents "enviroment". By environment we mean all paths -// (nvcuda.dll, nvml.dll, etc.) and all related resources like the temporary -// directory which contains nvcuda.dll -impl Environment { - fn setup(args: NormalizedArguments) -> io::Result { - let _temp_dir = TempDir::new()?; - let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( - args.nvml_path, - &_temp_dir, - NVML_DLL, - )?); - let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( - args.nvcuda_path, - &_temp_dir, - NVCUDA_DLL, - )?); - let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path); - Ok(Self { - nvml_path_zero_terminated, - nvcuda_path_zero_terminated, - redirect_path_zero_terminated, - winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated, - _temp_dir, - }) - } - - fn copy_to_correct_name( - path_buf: PathBuf, - temp_dir: &TempDir, - correct_name: &str, - ) -> io::Result { - let file_name = path_buf.file_name().unwrap(); - if file_name == correct_name { - Ok(path_buf) - } else { - let mut temp_file_path = temp_dir.path().to_path_buf(); - temp_file_path.push(correct_name); - match windows::fs::symlink_file(&path_buf, &temp_file_path) { - Ok(()) => {} - Err(_) => { - fs::copy(&path_buf, &temp_file_path)?; - } - } - Ok(temp_file_path) - } - } - - fn zero_terminate(p: PathBuf) -> String { - let mut s = p.to_string_lossy().to_string(); - s.push('\0'); - s - } -} - -fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box> { - let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x - != ptr::null_mut()); - let mut info = unsafe { mem::zeroed::() }; - info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; - os_call!( - SetInformationJobObject( - job_handle, - JobObjectExtendedLimitInformation, - &mut info as *mut _ as *mut _, - size_of_val(&info) as u32 - ), - |x| x != 0 - ); - os_call!(AssignProcessToJobObject(job_handle, child), |x| x != 0); - Ok(()) -} - -// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way -fn construct_command_line(args: impl Iterator) -> Vec { - let mut cmd_line = Vec::new(); - let args_len = args.size_hint().0; - for (idx, arg) in args.enumerate() { - if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) { - cmd_line.extend(arg.encode_utf16()); - } else { - cmd_line.push('"' as u16); // " - let mut char_iter = arg.chars().peekable(); - loop { - let mut current = char_iter.next(); - let mut backslashes = 0; - match current { - Some('\\') => { - backslashes = 1; - while let Some('\\') = char_iter.peek() { - backslashes += 1; - char_iter.next(); - } - current = char_iter.next(); - } - _ => {} - } - match current { - None => { - for _ in 0..(backslashes * 2) { - cmd_line.push('\\' as u16); - } - break; - } - Some('"') => { - for _ in 0..(backslashes * 2 + 1) { - cmd_line.push('\\' as u16); - } - cmd_line.push('"' as u16); - } - Some(c) => { - for _ in 0..backslashes { - cmd_line.push('\\' as u16); - } - let mut temp = [0u16; 2]; - cmd_line.extend(&*c.encode_utf16(&mut temp)); - } - } - } - cmd_line.push('"' as u16); - } - if idx < args_len - 1 { - cmd_line.push(' ' as u16); - } - } - cmd_line.push(0); - cmd_line -} +use std::env; +use std::os::windows; +use std::os::windows::ffi::OsStrExt; +use std::{error::Error, process}; +use std::{fs, io, ptr}; +use std::{mem, path::PathBuf}; + +use argh::FromArgs; +use mem::size_of_val; +use tempfile::TempDir; +use winapi::um::processenv::SearchPathW; +use winapi::um::{ + jobapi2::{AssignProcessToJobObject, SetInformationJobObject}, + processthreadsapi::{GetExitCodeProcess, ResumeThread}, + synchapi::WaitForSingleObject, + winbase::CreateJobObjectA, + winnt::{ + JobObjectExtendedLimitInformation, HANDLE, JOBOBJECT_EXTENDED_LIMIT_INFORMATION, + JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, + }, +}; + +use winapi::um::winbase::{INFINITE, WAIT_FAILED}; + +static REDIRECT_DLL: &'static str = "zluda_redirect.dll"; +static NVCUDA_DLL: &'static str = "nvcuda.dll"; +static NVML_DLL: &'static str = "nvml.dll"; + +include!("../../zluda_redirect/src/payload_guid.rs"); + +#[derive(FromArgs)] +/// Launch application with custom CUDA libraries +struct ProgramArguments { + /// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory + #[argh(option)] + nvcuda: Option, + + /// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory + #[argh(option)] + nvml: Option, + + /// executable to be injected with custom CUDA libraries + #[argh(positional)] + exe: String, + + /// arguments to the executable + #[argh(positional)] + args: Vec, +} + +pub fn main_impl() -> Result<(), Box> { + let raw_args = argh::from_env::(); + let normalized_args = NormalizedArguments::new(raw_args)?; + let mut environment = Environment::setup(normalized_args)?; + let mut startup_info = unsafe { mem::zeroed::() }; + let mut proc_info = unsafe { mem::zeroed::() }; + let mut dlls_to_inject = [ + environment.nvml_path_zero_terminated.as_ptr() as *const i8, + environment.nvcuda_path_zero_terminated.as_ptr() as _, + environment.redirect_path_zero_terminated.as_ptr() as _, + ]; + os_call!( + detours_sys::DetourCreateProcessWithDllsW( + ptr::null(), + environment.winapi_command_line_zero_terminated.as_mut_ptr(), + ptr::null_mut(), + ptr::null_mut(), + 0, + 0, + ptr::null_mut(), + ptr::null(), + &mut startup_info as *mut _, + &mut proc_info as *mut _, + dlls_to_inject.len() as u32, + dlls_to_inject.as_mut_ptr(), + Option::None + ), + |x| x != 0 + ); + kill_child_on_process_exit(proc_info.hProcess)?; + os_call!( + detours_sys::DetourCopyPayloadToProcess( + proc_info.hProcess, + &PAYLOAD_NVCUDA_GUID, + environment.nvcuda_path_zero_terminated.as_ptr() as *mut _, + environment.nvcuda_path_zero_terminated.len() as u32 + ), + |x| x != 0 + ); + os_call!( + detours_sys::DetourCopyPayloadToProcess( + proc_info.hProcess, + &PAYLOAD_NVML_GUID, + environment.nvml_path_zero_terminated.as_ptr() as *mut _, + environment.nvml_path_zero_terminated.len() as u32 + ), + |x| x != 0 + ); + os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1); + os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x + != WAIT_FAILED); + let mut child_exit_code: u32 = 0; + os_call!( + GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _), + |x| x != 0 + ); + process::exit(child_exit_code as i32) +} + +struct NormalizedArguments { + nvml_path: PathBuf, + nvcuda_path: PathBuf, + redirect_path: PathBuf, + winapi_command_line_zero_terminated: Vec, +} + +impl NormalizedArguments { + fn new(prog_args: ProgramArguments) -> Result> { + let current_exe = env::current_exe()?; + let nvml_path = Self::get_absolute_path(¤t_exe, prog_args.nvml, NVML_DLL)?; + let nvcuda_path = Self::get_absolute_path(¤t_exe, prog_args.nvcuda, NVCUDA_DLL)?; + let winapi_command_line_zero_terminated = + construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args)); + let mut redirect_path = current_exe.parent().unwrap().to_path_buf(); + redirect_path.push(REDIRECT_DLL); + Ok(Self { + nvml_path, + nvcuda_path, + redirect_path, + winapi_command_line_zero_terminated, + }) + } + + const WIN_MAX_PATH: usize = 260; + + fn get_absolute_path( + current_exe: &PathBuf, + dll: Option, + default: &str, + ) -> Result> { + Ok(if let Some(dll) = dll { + if dll.is_absolute() { + dll + } else { + let mut full_dll_path = vec![0; Self::WIN_MAX_PATH]; + let mut dll_utf16 = dll.as_os_str().encode_wide().collect::>(); + dll_utf16.push(0); + loop { + let copied_len = os_call!( + SearchPathW( + ptr::null_mut(), + dll_utf16.as_ptr(), + ptr::null(), + full_dll_path.len() as u32, + full_dll_path.as_mut_ptr(), + ptr::null_mut() + ), + |x| x != 0 + ) as usize; + if copied_len > full_dll_path.len() { + full_dll_path.resize(copied_len + 1, 0); + } else { + full_dll_path.truncate(copied_len); + break; + } + } + PathBuf::from(String::from_utf16_lossy(&full_dll_path)) + } + } else { + let mut dll_path = current_exe.parent().unwrap().to_path_buf(); + dll_path.push(default); + dll_path + }) + } +} + +struct Environment { + nvml_path_zero_terminated: String, + nvcuda_path_zero_terminated: String, + redirect_path_zero_terminated: String, + winapi_command_line_zero_terminated: Vec, + _temp_dir: TempDir, +} + +// This structs represents "enviroment". By environment we mean all paths +// (nvcuda.dll, nvml.dll, etc.) and all related resources like the temporary +// directory which contains nvcuda.dll +impl Environment { + fn setup(args: NormalizedArguments) -> io::Result { + let _temp_dir = TempDir::new()?; + let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( + args.nvml_path, + &_temp_dir, + NVML_DLL, + )?); + let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( + args.nvcuda_path, + &_temp_dir, + NVCUDA_DLL, + )?); + let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path); + Ok(Self { + nvml_path_zero_terminated, + nvcuda_path_zero_terminated, + redirect_path_zero_terminated, + winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated, + _temp_dir, + }) + } + + fn copy_to_correct_name( + path_buf: PathBuf, + temp_dir: &TempDir, + correct_name: &str, + ) -> io::Result { + let file_name = path_buf.file_name().unwrap(); + if file_name == correct_name { + Ok(path_buf) + } else { + let mut temp_file_path = temp_dir.path().to_path_buf(); + temp_file_path.push(correct_name); + match windows::fs::symlink_file(&path_buf, &temp_file_path) { + Ok(()) => {} + Err(_) => { + fs::copy(&path_buf, &temp_file_path)?; + } + } + Ok(temp_file_path) + } + } + + fn zero_terminate(p: PathBuf) -> String { + let mut s = p.to_string_lossy().to_string(); + s.push('\0'); + s + } +} + +fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box> { + let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x + != ptr::null_mut()); + let mut info = unsafe { mem::zeroed::() }; + info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + os_call!( + SetInformationJobObject( + job_handle, + JobObjectExtendedLimitInformation, + &mut info as *mut _ as *mut _, + size_of_val(&info) as u32 + ), + |x| x != 0 + ); + os_call!(AssignProcessToJobObject(job_handle, child), |x| x != 0); + Ok(()) +} + +// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way +fn construct_command_line(args: impl Iterator) -> Vec { + let mut cmd_line = Vec::new(); + let args_len = args.size_hint().0; + for (idx, arg) in args.enumerate() { + if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) { + cmd_line.extend(arg.encode_utf16()); + } else { + cmd_line.push('"' as u16); // " + let mut char_iter = arg.chars().peekable(); + loop { + let mut current = char_iter.next(); + let mut backslashes = 0; + match current { + Some('\\') => { + backslashes = 1; + while let Some('\\') = char_iter.peek() { + backslashes += 1; + char_iter.next(); + } + current = char_iter.next(); + } + _ => {} + } + match current { + None => { + for _ in 0..(backslashes * 2) { + cmd_line.push('\\' as u16); + } + break; + } + Some('"') => { + for _ in 0..(backslashes * 2 + 1) { + cmd_line.push('\\' as u16); + } + cmd_line.push('"' as u16); + } + Some(c) => { + for _ in 0..backslashes { + cmd_line.push('\\' as u16); + } + let mut temp = [0u16; 2]; + cmd_line.extend(&*c.encode_utf16(&mut temp)); + } + } + } + cmd_line.push('"' as u16); + } + if idx < args_len - 1 { + cmd_line.push(' ' as u16); + } + } + cmd_line.push(0); + cmd_line +} diff --git a/zluda_inject/src/main.rs b/zluda_inject/src/main.rs index 201802b..fb3bfcf 100644 --- a/zluda_inject/src/main.rs +++ b/zluda_inject/src/main.rs @@ -1,13 +1,13 @@ -#[macro_use] -#[cfg(target_os = "windows")] -mod win; -#[cfg(target_os = "windows")] -mod bin; - -#[cfg(target_os = "windows")] -fn main() -> Result<(), Box> { - bin::main_impl() -} - -#[cfg(not(target_os = "windows"))] -fn main() {} +#[macro_use] +#[cfg(target_os = "windows")] +mod win; +#[cfg(target_os = "windows")] +mod bin; + +#[cfg(target_os = "windows")] +fn main() -> Result<(), Box> { + bin::main_impl() +} + +#[cfg(not(target_os = "windows"))] +fn main() {} diff --git a/zluda_inject/src/win.rs b/zluda_inject/src/win.rs index 4d7fcdd..5c66ba0 100644 --- a/zluda_inject/src/win.rs +++ b/zluda_inject/src/win.rs @@ -1,151 +1,151 @@ -#![allow(non_snake_case)] - -use std::error; -use std::fmt; -use std::ptr; - -mod c { - use std::ffi::c_void; - use std::os::raw::c_ulong; - - pub type DWORD = c_ulong; - pub type HANDLE = LPVOID; - pub type LPVOID = *mut c_void; - pub type HINSTANCE = HANDLE; - pub type HMODULE = HINSTANCE; - pub type WCHAR = u16; - pub type LPCWSTR = *const WCHAR; - pub type LPWSTR = *mut WCHAR; - - pub const FACILITY_NT_BIT: DWORD = 0x1000_0000; - pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x00000800; - pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x00001000; - pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x00000200; - - extern "system" { - pub fn GetLastError() -> DWORD; - pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE; - pub fn FormatMessageW( - flags: DWORD, - lpSrc: LPVOID, - msgId: DWORD, - langId: DWORD, - buf: LPWSTR, - nsize: DWORD, - args: *const c_void, - ) -> DWORD; - } -} - -macro_rules! last_ident { - ($i:ident) => { - stringify!($i) - }; - ($start:ident, $($cont:ident),+) => { - last_ident!($($cont),+) - }; -} - -macro_rules! os_call { - ($($path:ident)::+ ($($args:expr),*), $success:expr) => { - { - let result = unsafe{ $($path)::+ ($($args),*) }; - if !($success)(result) { - let name = last_ident!($($path),+); - let err_code = $crate::win::errno(); - Err($crate::win::OsError{ - function: name, - error_code: err_code as u32, - message: $crate::win::error_string(err_code) - })?; - } - result - } - }; -} - -#[derive(Debug)] -pub struct OsError { - pub function: &'static str, - pub error_code: u32, - pub message: String, -} - -impl fmt::Display for OsError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{:?}", self) - } -} - -impl error::Error for OsError { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - None - } -} - -pub fn errno() -> i32 { - unsafe { c::GetLastError() as i32 } -} - -/// Gets a detailed string description for the given error number. -pub fn error_string(mut errnum: i32) -> String { - // This value is calculated from the macro - // MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT) - let langId = 0x0800 as c::DWORD; - - let mut buf = [0 as c::WCHAR; 2048]; - - unsafe { - let mut module = ptr::null_mut(); - let mut flags = 0; - - // NTSTATUS errors may be encoded as HRESULT, which may returned from - // GetLastError. For more information about Windows error codes, see - // `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx - if (errnum & c::FACILITY_NT_BIT as i32) != 0 { - // format according to https://support.microsoft.com/en-us/help/259693 - const NTDLL_DLL: &[u16] = &[ - 'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _, - 'L' as _, 0, - ]; - module = c::GetModuleHandleW(NTDLL_DLL.as_ptr()); - - if module != ptr::null_mut() { - errnum ^= c::FACILITY_NT_BIT as i32; - flags = c::FORMAT_MESSAGE_FROM_HMODULE; - } - } - - let res = c::FormatMessageW( - flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS, - module, - errnum as c::DWORD, - langId, - buf.as_mut_ptr(), - buf.len() as c::DWORD, - ptr::null(), - ) as usize; - if res == 0 { - // Sometimes FormatMessageW can fail e.g., system doesn't like langId, - let fm_err = errno(); - return format!( - "OS Error {} (FormatMessageW() returned error {})", - errnum, fm_err - ); - } - - match String::from_utf16(&buf[..res]) { - Ok(mut msg) => { - // Trim trailing CRLF inserted by FormatMessageW - let len = msg.trim_end().len(); - msg.truncate(len); - msg - } - Err(..) => format!( - "OS Error {} (FormatMessageW() returned \ - invalid UTF-16)", - errnum - ), - } - } -} +#![allow(non_snake_case)] + +use std::error; +use std::fmt; +use std::ptr; + +mod c { + use std::ffi::c_void; + use std::os::raw::c_ulong; + + pub type DWORD = c_ulong; + pub type HANDLE = LPVOID; + pub type LPVOID = *mut c_void; + pub type HINSTANCE = HANDLE; + pub type HMODULE = HINSTANCE; + pub type WCHAR = u16; + pub type LPCWSTR = *const WCHAR; + pub type LPWSTR = *mut WCHAR; + + pub const FACILITY_NT_BIT: DWORD = 0x1000_0000; + pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x00000800; + pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x00001000; + pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x00000200; + + extern "system" { + pub fn GetLastError() -> DWORD; + pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE; + pub fn FormatMessageW( + flags: DWORD, + lpSrc: LPVOID, + msgId: DWORD, + langId: DWORD, + buf: LPWSTR, + nsize: DWORD, + args: *const c_void, + ) -> DWORD; + } +} + +macro_rules! last_ident { + ($i:ident) => { + stringify!($i) + }; + ($start:ident, $($cont:ident),+) => { + last_ident!($($cont),+) + }; +} + +macro_rules! os_call { + ($($path:ident)::+ ($($args:expr),*), $success:expr) => { + { + let result = unsafe{ $($path)::+ ($($args),*) }; + if !($success)(result) { + let name = last_ident!($($path),+); + let err_code = $crate::win::errno(); + Err($crate::win::OsError{ + function: name, + error_code: err_code as u32, + message: $crate::win::error_string(err_code) + })?; + } + result + } + }; +} + +#[derive(Debug)] +pub struct OsError { + pub function: &'static str, + pub error_code: u32, + pub message: String, +} + +impl fmt::Display for OsError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + +impl error::Error for OsError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + None + } +} + +pub fn errno() -> i32 { + unsafe { c::GetLastError() as i32 } +} + +/// Gets a detailed string description for the given error number. +pub fn error_string(mut errnum: i32) -> String { + // This value is calculated from the macro + // MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT) + let langId = 0x0800 as c::DWORD; + + let mut buf = [0 as c::WCHAR; 2048]; + + unsafe { + let mut module = ptr::null_mut(); + let mut flags = 0; + + // NTSTATUS errors may be encoded as HRESULT, which may returned from + // GetLastError. For more information about Windows error codes, see + // `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx + if (errnum & c::FACILITY_NT_BIT as i32) != 0 { + // format according to https://support.microsoft.com/en-us/help/259693 + const NTDLL_DLL: &[u16] = &[ + 'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _, + 'L' as _, 0, + ]; + module = c::GetModuleHandleW(NTDLL_DLL.as_ptr()); + + if module != ptr::null_mut() { + errnum ^= c::FACILITY_NT_BIT as i32; + flags = c::FORMAT_MESSAGE_FROM_HMODULE; + } + } + + let res = c::FormatMessageW( + flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS, + module, + errnum as c::DWORD, + langId, + buf.as_mut_ptr(), + buf.len() as c::DWORD, + ptr::null(), + ) as usize; + if res == 0 { + // Sometimes FormatMessageW can fail e.g., system doesn't like langId, + let fm_err = errno(); + return format!( + "OS Error {} (FormatMessageW() returned error {})", + errnum, fm_err + ); + } + + match String::from_utf16(&buf[..res]) { + Ok(mut msg) => { + // Trim trailing CRLF inserted by FormatMessageW + let len = msg.trim_end().len(); + msg.truncate(len); + msg + } + Err(..) => format!( + "OS Error {} (FormatMessageW() returned \ + invalid UTF-16)", + errnum + ), + } + } +} diff --git a/zluda_inject/tests/inject.rs b/zluda_inject/tests/inject.rs index f897f9c..bf7bcdf 100644 --- a/zluda_inject/tests/inject.rs +++ b/zluda_inject/tests/inject.rs @@ -1,51 +1,51 @@ -#![cfg(windows)] -use std::{env, io, path::PathBuf, process::Command}; - -#[test] -fn direct_cuinit() -> io::Result<()> { - run_process_and_check_for_zluda_dump("direct_cuinit") -} - -#[test] -fn do_cuinit_early() -> io::Result<()> { - run_process_and_check_for_zluda_dump("do_cuinit_early") -} - -#[test] -fn do_cuinit_late() -> io::Result<()> { - run_process_and_check_for_zluda_dump("do_cuinit_late") -} - -#[test] -fn do_cuinit_late_clr() -> io::Result<()> { - run_process_and_check_for_zluda_dump("do_cuinit_late_clr") -} - -#[test] -fn indirect_cuinit() -> io::Result<()> { - run_process_and_check_for_zluda_dump("indirect_cuinit") -} - -#[test] -fn subprocess() -> io::Result<()> { - run_process_and_check_for_zluda_dump("subprocess") -} - -fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> { - let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with")); - let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf(); - zluda_dump_dll.push("zluda_dump.dll"); - let helpers_dir = env!("HELPERS_OUT_DIR"); - let exe_under_test = format!("{}{}{}.exe", helpers_dir, std::path::MAIN_SEPARATOR, name); - let mut test_cmd = Command::new(&zluda_with_exe); - let test_cmd = test_cmd - .arg("--nvcuda") - .arg(&zluda_dump_dll) - .arg("--") - .arg(&exe_under_test); - let test_output = test_cmd.output()?; - assert!(test_output.status.success()); - let stderr_text = String::from_utf8(test_output.stderr).unwrap(); - assert!(stderr_text.contains("ZLUDA_DUMP")); - Ok(()) -} +#![cfg(windows)] +use std::{env, io, path::PathBuf, process::Command}; + +#[test] +fn direct_cuinit() -> io::Result<()> { + run_process_and_check_for_zluda_dump("direct_cuinit") +} + +#[test] +fn do_cuinit_early() -> io::Result<()> { + run_process_and_check_for_zluda_dump("do_cuinit_early") +} + +#[test] +fn do_cuinit_late() -> io::Result<()> { + run_process_and_check_for_zluda_dump("do_cuinit_late") +} + +#[test] +fn do_cuinit_late_clr() -> io::Result<()> { + run_process_and_check_for_zluda_dump("do_cuinit_late_clr") +} + +#[test] +fn indirect_cuinit() -> io::Result<()> { + run_process_and_check_for_zluda_dump("indirect_cuinit") +} + +#[test] +fn subprocess() -> io::Result<()> { + run_process_and_check_for_zluda_dump("subprocess") +} + +fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> { + let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with")); + let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf(); + zluda_dump_dll.push("zluda_dump.dll"); + let helpers_dir = env!("HELPERS_OUT_DIR"); + let exe_under_test = format!("{}{}{}.exe", helpers_dir, std::path::MAIN_SEPARATOR, name); + let mut test_cmd = Command::new(&zluda_with_exe); + let test_cmd = test_cmd + .arg("--nvcuda") + .arg(&zluda_dump_dll) + .arg("--") + .arg(&exe_under_test); + let test_output = test_cmd.output()?; + assert!(test_output.status.success()); + let stderr_text = String::from_utf8(test_output.stderr).unwrap(); + assert!(stderr_text.contains("ZLUDA_DUMP")); + Ok(()) +}