Always use Unix line endings (#453)

This commit is contained in:
Violet
2025-07-30 15:09:47 -07:00
committed by GitHub
parent 21ef5f60a3
commit b8bcbec295
28 changed files with 5208 additions and 5207 deletions

1
.rustfmt.toml Normal file
View File

@ -0,0 +1 @@
newline_style = "Unix"

View File

@ -1,147 +1,147 @@
use cmake::Config; use cmake::Config;
use std::io; use std::io;
use std::path::PathBuf; use std::path::PathBuf;
use std::process::Command; use std::process::Command;
/// CMake targets of the LLVM libraries this crate builds and links against.
///
/// Each entry is used twice: as a CMake build target in `main`, and (after the
/// `LLVM` prefix is removed and the rest is lowercased) as a component name
/// passed to `llvm-config`.
///
/// The entries guarded by `#[cfg(debug_assertions)]` are only present when the
/// build script itself is compiled with debug assertions enabled.
const COMPONENTS: &[&str] = &[
    "LLVMCore",
    "LLVMBitWriter",
    #[cfg(debug_assertions)]
    "LLVMAnalysis", // for module verify
    #[cfg(debug_assertions)]
    "LLVMBitReader",
];
fn main() { fn main() {
let mut cmake = Config::new(r"../ext/llvm-project/llvm"); let mut cmake = Config::new(r"../ext/llvm-project/llvm");
try_use_sccache(&mut cmake); try_use_sccache(&mut cmake);
try_use_ninja(&mut cmake); try_use_ninja(&mut cmake);
cmake cmake
// It's not like we can do anything about the warnings // It's not like we can do anything about the warnings
.define("LLVM_ENABLE_WARNINGS", "OFF") .define("LLVM_ENABLE_WARNINGS", "OFF")
// For some reason Rust always links to release CRT // For some reason Rust always links to release CRT
.define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreaded") .define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreaded")
.define("LLVM_ENABLE_TERMINFO", "OFF") .define("LLVM_ENABLE_TERMINFO", "OFF")
.define("LLVM_ENABLE_LIBXML2", "OFF") .define("LLVM_ENABLE_LIBXML2", "OFF")
.define("LLVM_ENABLE_LIBEDIT", "OFF") .define("LLVM_ENABLE_LIBEDIT", "OFF")
.define("LLVM_ENABLE_LIBPFM", "OFF") .define("LLVM_ENABLE_LIBPFM", "OFF")
.define("LLVM_ENABLE_ZLIB", "OFF") .define("LLVM_ENABLE_ZLIB", "OFF")
.define("LLVM_ENABLE_ZSTD", "OFF") .define("LLVM_ENABLE_ZSTD", "OFF")
.define("LLVM_INCLUDE_BENCHMARKS", "OFF") .define("LLVM_INCLUDE_BENCHMARKS", "OFF")
.define("LLVM_INCLUDE_EXAMPLES", "OFF") .define("LLVM_INCLUDE_EXAMPLES", "OFF")
.define("LLVM_INCLUDE_TESTS", "OFF") .define("LLVM_INCLUDE_TESTS", "OFF")
.define("LLVM_BUILD_TOOLS", "OFF") .define("LLVM_BUILD_TOOLS", "OFF")
.define("LLVM_TARGETS_TO_BUILD", "") .define("LLVM_TARGETS_TO_BUILD", "")
.define("LLVM_ENABLE_PROJECTS", ""); .define("LLVM_ENABLE_PROJECTS", "");
cmake.build_target("llvm-config"); cmake.build_target("llvm-config");
let llvm_dir = cmake.build(); let llvm_dir = cmake.build();
for c in COMPONENTS { for c in COMPONENTS {
cmake.build_target(c); cmake.build_target(c);
cmake.build(); cmake.build();
} }
let cmake_profile = cmake.get_profile(); let cmake_profile = cmake.get_profile();
let (cxxflags, ldflags, libdir, lib_names, system_libs) = let (cxxflags, ldflags, libdir, lib_names, system_libs) =
llvm_config(&llvm_dir, &["build", "bin", "llvm-config"]) llvm_config(&llvm_dir, &["build", "bin", "llvm-config"])
.or_else(|_| llvm_config(&llvm_dir, &["build", cmake_profile, "bin", "llvm-config"])) .or_else(|_| llvm_config(&llvm_dir, &["build", cmake_profile, "bin", "llvm-config"]))
.unwrap(); .unwrap();
println!("cargo:rustc-link-arg={ldflags}"); println!("cargo:rustc-link-arg={ldflags}");
println!("cargo:rustc-link-search=native={libdir}"); println!("cargo:rustc-link-search=native={libdir}");
for lib in system_libs.split_ascii_whitespace() { for lib in system_libs.split_ascii_whitespace() {
println!("cargo:rustc-link-arg={lib}"); println!("cargo:rustc-link-arg={lib}");
} }
link_llvm_components(lib_names); link_llvm_components(lib_names);
compile_cxx_lib(cxxflags); compile_cxx_lib(cxxflags);
} }
// https://github.com/mozilla/sccache/blob/main/README.md#usage // https://github.com/mozilla/sccache/blob/main/README.md#usage
fn try_use_sccache(cmake: &mut Config) { fn try_use_sccache(cmake: &mut Config) {
if let Ok(sccache) = std::env::var("SCCACHE_PATH") { if let Ok(sccache) = std::env::var("SCCACHE_PATH") {
cmake.define("CMAKE_CXX_COMPILER_LAUNCHER", &*sccache); cmake.define("CMAKE_CXX_COMPILER_LAUNCHER", &*sccache);
cmake.define("CMAKE_C_COMPILER_LAUNCHER", &*sccache); cmake.define("CMAKE_C_COMPILER_LAUNCHER", &*sccache);
match std::env::var_os("CARGO_CFG_TARGET_OS") { match std::env::var_os("CARGO_CFG_TARGET_OS") {
Some(os) if os == "windows" => { Some(os) if os == "windows" => {
cmake.define("CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "Embedded"); cmake.define("CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "Embedded");
cmake.define("CMAKE_POLICY_CMP0141", "NEW"); cmake.define("CMAKE_POLICY_CMP0141", "NEW");
} }
_ => {} _ => {}
} }
} }
} }
fn try_use_ninja(cmake: &mut Config) { fn try_use_ninja(cmake: &mut Config) {
let mut cmd = Command::new("ninja"); let mut cmd = Command::new("ninja");
cmd.arg("--version"); cmd.arg("--version");
if let Ok(status) = cmd.status() { if let Ok(status) = cmd.status() {
if status.success() { if status.success() {
cmake.generator("Ninja"); cmake.generator("Ninja");
} }
} }
} }
fn llvm_config( fn llvm_config(
llvm_build_dir: &PathBuf, llvm_build_dir: &PathBuf,
path_to_llvm_config: &[&str], path_to_llvm_config: &[&str],
) -> io::Result<(String, String, String, String, String)> { ) -> io::Result<(String, String, String, String, String)> {
let mut llvm_build_path = llvm_build_dir.clone(); let mut llvm_build_path = llvm_build_dir.clone();
llvm_build_path.extend(path_to_llvm_config); llvm_build_path.extend(path_to_llvm_config);
let mut cmd = Command::new(llvm_build_path); let mut cmd = Command::new(llvm_build_path);
cmd.args([ cmd.args([
"--link-static", "--link-static",
"--cxxflags", "--cxxflags",
"--ldflags", "--ldflags",
"--libdir", "--libdir",
"--libnames", "--libnames",
"--system-libs", "--system-libs",
]); ]);
for c in COMPONENTS { for c in COMPONENTS {
cmd.arg(c[4..].to_lowercase()); cmd.arg(c[4..].to_lowercase());
} }
let output = cmd.output()?; let output = cmd.output()?;
if !output.status.success() { if !output.status.success() {
return Err(io::Error::from(io::ErrorKind::Other)); return Err(io::Error::from(io::ErrorKind::Other));
} }
let output = unsafe { String::from_utf8_unchecked(output.stdout) }; let output = unsafe { String::from_utf8_unchecked(output.stdout) };
let mut lines = output.lines(); let mut lines = output.lines();
let cxxflags = lines.next().unwrap(); let cxxflags = lines.next().unwrap();
let ldflags = lines.next().unwrap(); let ldflags = lines.next().unwrap();
let libdir = lines.next().unwrap(); let libdir = lines.next().unwrap();
let lib_names = lines.next().unwrap(); let lib_names = lines.next().unwrap();
let system_libs = lines.next().unwrap(); let system_libs = lines.next().unwrap();
Ok(( Ok((
cxxflags.to_string(), cxxflags.to_string(),
ldflags.to_string(), ldflags.to_string(),
libdir.to_string(), libdir.to_string(),
lib_names.to_string(), lib_names.to_string(),
system_libs.to_string(), system_libs.to_string(),
)) ))
} }
fn compile_cxx_lib(cxxflags: String) { fn compile_cxx_lib(cxxflags: String) {
let mut cc = cc::Build::new(); let mut cc = cc::Build::new();
for flag in cxxflags.split_whitespace() { for flag in cxxflags.split_whitespace() {
cc.flag(flag); cc.flag(flag);
} }
cc.cpp(true).file("src/lib.cpp").compile("llvm_zluda_cpp"); cc.cpp(true).file("src/lib.cpp").compile("llvm_zluda_cpp");
println!("cargo:rerun-if-changed=src/lib.cpp"); println!("cargo:rerun-if-changed=src/lib.cpp");
println!("cargo:rerun-if-changed=src/lib.rs"); println!("cargo:rerun-if-changed=src/lib.rs");
} }
/// Emits one `cargo:rustc-link-lib` directive per static library file name.
///
/// `components` is a whitespace-separated list of library file names. Two
/// spellings are understood:
///
/// * Unix (Linux/Mac): `libLLVMfoo.a` is linked as `LLVMfoo`
/// * Windows: `LLVMfoo.lib` is linked as `LLVMfoo`
///
/// # Panics
///
/// Panics on any name that matches neither spelling.
fn link_llvm_components(components: String) {
    for file_name in components.split_whitespace() {
        let unix_name = file_name
            .strip_prefix("lib")
            .and_then(|rest| rest.strip_suffix(".a"));
        let link_name = match unix_name {
            Some(name) => name,
            None => match file_name.strip_suffix(".lib") {
                Some(name) => name,
                None => panic!("'{}' does not look like a static library name", file_name),
            },
        };
        println!("cargo:rustc-link-lib={link_name}");
    }
}

View File

@ -1,81 +1,81 @@
#![allow(non_upper_case_globals)] #![allow(non_upper_case_globals)]
use llvm_sys::prelude::*; use llvm_sys::prelude::*;
pub use llvm_sys::*; pub use llvm_sys::*;
/// Operation selector for `LLVMZludaBuildAtomicRMW`.
///
/// The enum is `#[repr(C)]` and is passed by value across the FFI boundary,
/// so every discriminant is spelled out explicitly.
// NOTE(review): the numbering presumably has to match the enum declared on the
// C++ side of the shim — confirm before adding or reordering variants.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LLVMZludaAtomicRMWBinOp {
    LLVMZludaAtomicRMWBinOpXchg = 0,
    LLVMZludaAtomicRMWBinOpAdd = 1,
    LLVMZludaAtomicRMWBinOpSub = 2,
    LLVMZludaAtomicRMWBinOpAnd = 3,
    LLVMZludaAtomicRMWBinOpNand = 4,
    LLVMZludaAtomicRMWBinOpOr = 5,
    LLVMZludaAtomicRMWBinOpXor = 6,
    LLVMZludaAtomicRMWBinOpMax = 7,
    LLVMZludaAtomicRMWBinOpMin = 8,
    LLVMZludaAtomicRMWBinOpUMax = 9,
    LLVMZludaAtomicRMWBinOpUMin = 10,
    LLVMZludaAtomicRMWBinOpFAdd = 11,
    LLVMZludaAtomicRMWBinOpFSub = 12,
    LLVMZludaAtomicRMWBinOpFMax = 13,
    LLVMZludaAtomicRMWBinOpFMin = 14,
    LLVMZludaAtomicRMWBinOpUIncWrap = 15,
    LLVMZludaAtomicRMWBinOpUDecWrap = 16,
}
// Backport from LLVM 19
/// Bit set of fast-math flags accepted by `LLVMZludaSetFastMathFlags`.
pub type LLVMZludaFastMathFlags = std::ffi::c_uint;

// Individual flags: one bit each, bits 0 through 6.
pub const LLVMZludaFastMathAllowReassoc: LLVMZludaFastMathFlags = 1 << 0;
pub const LLVMZludaFastMathNoNaNs: LLVMZludaFastMathFlags = 1 << 1;
pub const LLVMZludaFastMathNoInfs: LLVMZludaFastMathFlags = 1 << 2;
pub const LLVMZludaFastMathNoSignedZeros: LLVMZludaFastMathFlags = 1 << 3;
pub const LLVMZludaFastMathAllowReciprocal: LLVMZludaFastMathFlags = 1 << 4;
pub const LLVMZludaFastMathAllowContract: LLVMZludaFastMathFlags = 1 << 5;
pub const LLVMZludaFastMathApproxFunc: LLVMZludaFastMathFlags = 1 << 6;

/// No flag set.
pub const LLVMZludaFastMathNone: LLVMZludaFastMathFlags = 0;

/// Union of all seven individual flags.
pub const LLVMZludaFastMathAll: LLVMZludaFastMathFlags = LLVMZludaFastMathAllowReassoc
    | LLVMZludaFastMathNoNaNs
    | LLVMZludaFastMathNoInfs
    | LLVMZludaFastMathNoSignedZeros
    | LLVMZludaFastMathAllowReciprocal
    | LLVMZludaFastMathAllowContract
    | LLVMZludaFastMathApproxFunc;
extern "C" { extern "C" {
pub fn LLVMZludaBuildAlloca( pub fn LLVMZludaBuildAlloca(
B: LLVMBuilderRef, B: LLVMBuilderRef,
Ty: LLVMTypeRef, Ty: LLVMTypeRef,
AddrSpace: u32, AddrSpace: u32,
Name: *const i8, Name: *const i8,
) -> LLVMValueRef; ) -> LLVMValueRef;
pub fn LLVMZludaBuildAtomicRMW( pub fn LLVMZludaBuildAtomicRMW(
B: LLVMBuilderRef, B: LLVMBuilderRef,
op: LLVMZludaAtomicRMWBinOp, op: LLVMZludaAtomicRMWBinOp,
PTR: LLVMValueRef, PTR: LLVMValueRef,
Val: LLVMValueRef, Val: LLVMValueRef,
scope: *const i8, scope: *const i8,
ordering: LLVMAtomicOrdering, ordering: LLVMAtomicOrdering,
) -> LLVMValueRef; ) -> LLVMValueRef;
pub fn LLVMZludaBuildAtomicCmpXchg( pub fn LLVMZludaBuildAtomicCmpXchg(
B: LLVMBuilderRef, B: LLVMBuilderRef,
Ptr: LLVMValueRef, Ptr: LLVMValueRef,
Cmp: LLVMValueRef, Cmp: LLVMValueRef,
New: LLVMValueRef, New: LLVMValueRef,
scope: *const i8, scope: *const i8,
SuccessOrdering: LLVMAtomicOrdering, SuccessOrdering: LLVMAtomicOrdering,
FailureOrdering: LLVMAtomicOrdering, FailureOrdering: LLVMAtomicOrdering,
) -> LLVMValueRef; ) -> LLVMValueRef;
pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags); pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags);
pub fn LLVMZludaBuildFence( pub fn LLVMZludaBuildFence(
B: LLVMBuilderRef, B: LLVMBuilderRef,
ordering: LLVMAtomicOrdering, ordering: LLVMAtomicOrdering,
scope: *const i8, scope: *const i8,
Name: *const i8, Name: *const i8,
) -> LLVMValueRef; ) -> LLVMValueRef;
} }

View File

@ -1,191 +1,191 @@
use super::*; use super::*;
pub(super) fn run<'a, 'input>( pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives directives
.into_iter() .into_iter()
.map(|directive| run_directive(resolver, directive)) .map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
} }
fn run_directive<'input>( fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2, resolver: &mut GlobalStringIdentResolver2,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>, directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive { Ok(match directive {
var @ Directive2::Variable(..) => var, var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
}) })
} }
fn run_method<'input>( fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2, resolver: &mut GlobalStringIdentResolver2,
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>, mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let is_declaration = method.body.is_none(); let is_declaration = method.body.is_none();
let mut body = Vec::new(); let mut body = Vec::new();
let mut remap_returns = Vec::new(); let mut remap_returns = Vec::new();
if !method.is_kernel { if !method.is_kernel {
for arg in method.return_arguments.iter_mut() { for arg in method.return_arguments.iter_mut() {
match arg.state_space { match arg.state_space {
ptx_parser::StateSpace::Param => { ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg; arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name; let old_name = arg.name;
arg.name = arg.name =
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration { if is_declaration {
continue; continue;
} }
remap_returns.push((old_name, arg.name, arg.v_type.clone())); remap_returns.push((old_name, arg.name, arg.v_type.clone()));
body.push(Statement::Variable(ast::Variable { body.push(Statement::Variable(ast::Variable {
align: None, align: None,
name: old_name, name: old_name,
v_type: arg.v_type.clone(), v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param, state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(), array_init: Vec::new(),
})); }));
} }
ptx_parser::StateSpace::Reg => {} ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()), _ => return Err(error_unreachable()),
} }
} }
for arg in method.input_arguments.iter_mut() { for arg in method.input_arguments.iter_mut() {
match arg.state_space { match arg.state_space {
ptx_parser::StateSpace::Param => { ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg; arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name; let old_name = arg.name;
arg.name = arg.name =
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration { if is_declaration {
continue; continue;
} }
body.push(Statement::Variable(ast::Variable { body.push(Statement::Variable(ast::Variable {
align: None, align: None,
name: old_name, name: old_name,
v_type: arg.v_type.clone(), v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param, state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(), array_init: Vec::new(),
})); }));
body.push(Statement::Instruction(ast::Instruction::St { body.push(Statement::Instruction(ast::Instruction::St {
data: ast::StData { data: ast::StData {
qualifier: ast::LdStQualifier::Weak, qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param, state_space: ast::StateSpace::Param,
caching: ast::StCacheOperator::Writethrough, caching: ast::StCacheOperator::Writethrough,
typ: arg.v_type.clone(), typ: arg.v_type.clone(),
}, },
arguments: ast::StArgs { arguments: ast::StArgs {
src1: old_name, src1: old_name,
src2: arg.name, src2: arg.name,
}, },
})); }));
} }
ptx_parser::StateSpace::Reg => {} ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()), _ => return Err(error_unreachable()),
} }
} }
} }
let body = method let body = method
.body .body
.map(|statements| { .map(|statements| {
for statement in statements { for statement in statements {
run_statement(resolver, &remap_returns, &mut body, statement)?; run_statement(resolver, &remap_returns, &mut body, statement)?;
} }
Ok::<_, TranslateError>(body) Ok::<_, TranslateError>(body)
}) })
.transpose()?; .transpose()?;
Ok(Function2 { body, ..method }) Ok(Function2 { body, ..method })
} }
fn run_statement<'input>( fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>, remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>, statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
match statement { match statement {
Statement::Instruction(ast::Instruction::Call { Statement::Instruction(ast::Instruction::Call {
mut data, mut data,
mut arguments, mut arguments,
}) => { }) => {
let mut post_st = Vec::new(); let mut post_st = Vec::new();
for ((type_, space), ident) in data for ((type_, space), ident) in data
.input_arguments .input_arguments
.iter_mut() .iter_mut()
.zip(arguments.input_arguments.iter_mut()) .zip(arguments.input_arguments.iter_mut())
{ {
if *space == ptx_parser::StateSpace::Param { if *space == ptx_parser::StateSpace::Param {
*space = ptx_parser::StateSpace::Reg; *space = ptx_parser::StateSpace::Reg;
let old_name = *ident; let old_name = *ident;
*ident = resolver *ident = resolver
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
result.push(Statement::Instruction(ast::Instruction::Ld { result.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails { data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak, qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param, state_space: ast::StateSpace::Param,
caching: ast::LdCacheOperator::Cached, caching: ast::LdCacheOperator::Cached,
typ: type_.clone(), typ: type_.clone(),
non_coherent: false, non_coherent: false,
}, },
arguments: ast::LdArgs { arguments: ast::LdArgs {
dst: *ident, dst: *ident,
src: old_name, src: old_name,
}, },
})); }));
} }
} }
for ((type_, space), ident) in data for ((type_, space), ident) in data
.return_arguments .return_arguments
.iter_mut() .iter_mut()
.zip(arguments.return_arguments.iter_mut()) .zip(arguments.return_arguments.iter_mut())
{ {
if *space == ptx_parser::StateSpace::Param { if *space == ptx_parser::StateSpace::Param {
*space = ptx_parser::StateSpace::Reg; *space = ptx_parser::StateSpace::Reg;
let old_name = *ident; let old_name = *ident;
*ident = resolver *ident = resolver
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
post_st.push(Statement::Instruction(ast::Instruction::St { post_st.push(Statement::Instruction(ast::Instruction::St {
data: ast::StData { data: ast::StData {
qualifier: ast::LdStQualifier::Weak, qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param, state_space: ast::StateSpace::Param,
caching: ast::StCacheOperator::Writethrough, caching: ast::StCacheOperator::Writethrough,
typ: type_.clone(), typ: type_.clone(),
}, },
arguments: ast::StArgs { arguments: ast::StArgs {
src1: old_name, src1: old_name,
src2: *ident, src2: *ident,
}, },
})); }));
} }
} }
result.push(Statement::Instruction(ast::Instruction::Call { result.push(Statement::Instruction(ast::Instruction::Call {
data, data,
arguments, arguments,
})); }));
result.extend(post_st.into_iter()); result.extend(post_st.into_iter());
} }
Statement::Instruction(ast::Instruction::Ret { data }) => { Statement::Instruction(ast::Instruction::Ret { data }) => {
for (old_name, new_name, type_) in remap_returns.iter() { for (old_name, new_name, type_) in remap_returns.iter() {
result.push(Statement::Instruction(ast::Instruction::Ld { result.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails { data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak, qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param, state_space: ast::StateSpace::Param,
caching: ast::LdCacheOperator::Cached, caching: ast::LdCacheOperator::Cached,
typ: type_.clone(), typ: type_.clone(),
non_coherent: false, non_coherent: false,
}, },
arguments: ast::LdArgs { arguments: ast::LdArgs {
dst: *new_name, dst: *new_name,
src: *old_name, src: *old_name,
}, },
})); }));
} }
result.push(Statement::Instruction(ast::Instruction::Ret { data })); result.push(Statement::Instruction(ast::Instruction::Ret { data }));
} }
statement => { statement => {
result.push(statement); result.push(statement);
} }
} }
Ok(()) Ok(())
} }

View File

@ -1,301 +1,301 @@
use super::*; use super::*;
pub(super) fn run<'a, 'input>( pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<UnconditionalDirective>, directives: Vec<UnconditionalDirective>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives directives
.into_iter() .into_iter()
.map(|directive| run_directive(resolver, directive)) .map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
} }
fn run_directive<'input>( fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2< directive: Directive2<
ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>,
>, >,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive { Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var), Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
}) })
} }
fn run_method<'input>( fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
method: Function2< method: Function2<
ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>,
>, >,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let body = method let body = method
.body .body
.map(|statements| { .map(|statements| {
let mut result = Vec::with_capacity(statements.len()); let mut result = Vec::with_capacity(statements.len());
for statement in statements { for statement in statements {
run_statement(resolver, &mut result, statement)?; run_statement(resolver, &mut result, statement)?;
} }
Ok::<_, TranslateError>(result) Ok::<_, TranslateError>(result)
}) })
.transpose()?; .transpose()?;
Ok(Function2 { Ok(Function2 {
body, body,
return_arguments: method.return_arguments, return_arguments: method.return_arguments,
name: method.name, name: method.name,
input_arguments: method.input_arguments, input_arguments: method.input_arguments,
import_as: method.import_as, import_as: method.import_as,
tuning: method.tuning, tuning: method.tuning,
linkage: method.linkage, linkage: method.linkage,
is_kernel: method.is_kernel, is_kernel: method.is_kernel,
flush_to_zero_f32: method.flush_to_zero_f32, flush_to_zero_f32: method.flush_to_zero_f32,
flush_to_zero_f16f64: method.flush_to_zero_f16f64, flush_to_zero_f16f64: method.flush_to_zero_f16f64,
rounding_mode_f32: method.rounding_mode_f32, rounding_mode_f32: method.rounding_mode_f32,
rounding_mode_f16f64: method.rounding_mode_f16f64, rounding_mode_f16f64: method.rounding_mode_f16f64,
}) })
} }
fn run_statement<'input>( fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
statement: UnconditionalStatement, statement: UnconditionalStatement,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
let mut visitor = FlattenArguments::new(resolver, result); let mut visitor = FlattenArguments::new(resolver, result);
let new_statement = statement.visit_map(&mut visitor)?; let new_statement = statement.visit_map(&mut visitor)?;
visitor.result.push(new_statement); visitor.result.push(new_statement);
Ok(()) Ok(())
} }
struct FlattenArguments<'a, 'input> { struct FlattenArguments<'a, 'input> {
result: &'a mut Vec<ExpandedStatement>, result: &'a mut Vec<ExpandedStatement>,
resolver: &'a mut GlobalStringIdentResolver2<'input>, resolver: &'a mut GlobalStringIdentResolver2<'input>,
post_stmts: Vec<ExpandedStatement>, post_stmts: Vec<ExpandedStatement>,
} }
impl<'a, 'input> FlattenArguments<'a, 'input> { impl<'a, 'input> FlattenArguments<'a, 'input> {
fn new( fn new(
resolver: &'a mut GlobalStringIdentResolver2<'input>, resolver: &'a mut GlobalStringIdentResolver2<'input>,
result: &'a mut Vec<ExpandedStatement>, result: &'a mut Vec<ExpandedStatement>,
) -> Self { ) -> Self {
FlattenArguments { FlattenArguments {
result, result,
resolver, resolver,
post_stmts: Vec::new(), post_stmts: Vec::new(),
} }
} }
/// A plain register operand is already flat: it is returned as-is and no
/// statement is emitted.
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
    let flattened = name;
    Ok(flattened)
}
fn reg_offset( fn reg_offset(
&mut self, &mut self,
reg: SpirvWord, reg: SpirvWord,
offset: i32, offset: i32,
type_space: Option<(&ast::Type, ast::StateSpace)>, type_space: Option<(&ast::Type, ast::StateSpace)>,
_is_dst: bool, _is_dst: bool,
) -> Result<SpirvWord, TranslateError> { ) -> Result<SpirvWord, TranslateError> {
let (type_, state_space) = if let Some((type_, state_space)) = type_space { let (type_, state_space) = if let Some((type_, state_space)) = type_space {
(type_, state_space) (type_, state_space)
} else { } else {
return Err(TranslateError::UntypedSymbol); return Err(TranslateError::UntypedSymbol);
}; };
if state_space == ast::StateSpace::Reg { if state_space == ast::StateSpace::Reg {
let (reg_type, reg_space) = self.resolver.get_typed(reg)?; let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
if *reg_space != ast::StateSpace::Reg { if *reg_space != ast::StateSpace::Reg {
return Err(error_mismatched_type()); return Err(error_mismatched_type());
} }
let reg_scalar_type = match reg_type { let reg_scalar_type = match reg_type {
ast::Type::Scalar(underlying_type) => *underlying_type, ast::Type::Scalar(underlying_type) => *underlying_type,
_ => return Err(error_mismatched_type()), _ => return Err(error_mismatched_type()),
}; };
let reg_type = reg_type.clone(); let reg_type = reg_type.clone();
let id_constant_stmt = self let id_constant_stmt = self
.resolver .resolver
.register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg))); .register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
self.result.push(Statement::Constant(ConstantDefinition { self.result.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt, dst: id_constant_stmt,
typ: reg_scalar_type, typ: reg_scalar_type,
value: ast::ImmediateValue::S64(offset as i64), value: ast::ImmediateValue::S64(offset as i64),
})); }));
let arith_details = match reg_scalar_type.kind() { let arith_details = match reg_scalar_type.kind() {
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger { ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
type_: reg_scalar_type, type_: reg_scalar_type,
saturate: false, saturate: false,
}), }),
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => { ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
ast::ArithDetails::Integer(ast::ArithInteger { ast::ArithDetails::Integer(ast::ArithInteger {
type_: reg_scalar_type, type_: reg_scalar_type,
saturate: false, saturate: false,
}) })
} }
_ => return Err(error_unreachable()), _ => return Err(error_unreachable()),
}; };
let id_add_result = self let id_add_result = self
.resolver .resolver
.register_unnamed(Some((reg_type, state_space))); .register_unnamed(Some((reg_type, state_space)));
self.result self.result
.push(Statement::Instruction(ast::Instruction::Add { .push(Statement::Instruction(ast::Instruction::Add {
data: arith_details, data: arith_details,
arguments: ast::AddArgs { arguments: ast::AddArgs {
dst: id_add_result, dst: id_add_result,
src1: reg, src1: reg,
src2: id_constant_stmt, src2: id_constant_stmt,
}, },
})); }));
Ok(id_add_result) Ok(id_add_result)
} else { } else {
let id_constant_stmt = self.resolver.register_unnamed(Some(( let id_constant_stmt = self.resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::S64), ast::Type::Scalar(ast::ScalarType::S64),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))); )));
self.result.push(Statement::Constant(ConstantDefinition { self.result.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt, dst: id_constant_stmt,
typ: ast::ScalarType::S64, typ: ast::ScalarType::S64,
value: ast::ImmediateValue::S64(offset as i64), value: ast::ImmediateValue::S64(offset as i64),
})); }));
let dst = self let dst = self
.resolver .resolver
.register_unnamed(Some((type_.clone(), state_space))); .register_unnamed(Some((type_.clone(), state_space)));
self.result.push(Statement::PtrAccess(PtrAccess { self.result.push(Statement::PtrAccess(PtrAccess {
underlying_type: type_.clone(), underlying_type: type_.clone(),
state_space: state_space, state_space: state_space,
dst, dst,
ptr_src: reg, ptr_src: reg,
offset_src: id_constant_stmt, offset_src: id_constant_stmt,
})); }));
Ok(dst) Ok(dst)
} }
} }
fn immediate( fn immediate(
&mut self, &mut self,
value: ast::ImmediateValue, value: ast::ImmediateValue,
type_space: Option<(&ast::Type, ast::StateSpace)>, type_space: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> { ) -> Result<SpirvWord, TranslateError> {
let (scalar_t, state_space) = let (scalar_t, state_space) =
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space { if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
(*scalar, state_space) (*scalar, state_space)
} else { } else {
return Err(TranslateError::UntypedSymbol); return Err(TranslateError::UntypedSymbol);
}; };
let id = self let id = self
.resolver .resolver
.register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space))); .register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
self.result.push(Statement::Constant(ConstantDefinition { self.result.push(Statement::Constant(ConstantDefinition {
dst: id, dst: id,
typ: scalar_t, typ: scalar_t,
value, value,
})); }));
Ok(id) Ok(id)
} }
fn vec_member( fn vec_member(
&mut self, &mut self,
vector_ident: SpirvWord, vector_ident: SpirvWord,
member: u8, member: u8,
_type_space: Option<(&ast::Type, ast::StateSpace)>, _type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool, is_dst: bool,
) -> Result<SpirvWord, TranslateError> { ) -> Result<SpirvWord, TranslateError> {
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? { let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? {
(ast::Type::Vector(vector_width, scalar_t), space) => { (ast::Type::Vector(vector_width, scalar_t), space) => {
(*vector_width, *scalar_t, *space) (*vector_width, *scalar_t, *space)
} }
_ => return Err(error_mismatched_type()), _ => return Err(error_mismatched_type()),
}; };
let temporary = self let temporary = self
.resolver .resolver
.register_unnamed(Some((scalar_type.into(), space))); .register_unnamed(Some((scalar_type.into(), space)));
if is_dst { if is_dst {
self.post_stmts.push(Statement::VectorWrite(VectorWrite { self.post_stmts.push(Statement::VectorWrite(VectorWrite {
scalar_type, scalar_type,
vector_width, vector_width,
vector_dst: vector_ident, vector_dst: vector_ident,
vector_src: vector_ident, vector_src: vector_ident,
scalar_src: temporary, scalar_src: temporary,
member, member,
})); }));
} else { } else {
self.result.push(Statement::VectorRead(VectorRead { self.result.push(Statement::VectorRead(VectorRead {
scalar_type, scalar_type,
vector_width, vector_width,
scalar_dst: temporary, scalar_dst: temporary,
vector_src: vector_ident, vector_src: vector_ident,
member, member,
})); }));
} }
Ok(temporary) Ok(temporary)
} }
fn vec_pack( fn vec_pack(
&mut self, &mut self,
vector_elements: Vec<SpirvWord>, vector_elements: Vec<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>, type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool, relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> { ) -> Result<SpirvWord, TranslateError> {
let (width, scalar_t, state_space) = match type_space { let (width, scalar_t, state_space) = match type_space {
Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space), Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
_ => return Err(error_mismatched_type()), _ => return Err(error_mismatched_type()),
}; };
let temporary_vector = self let temporary_vector = self
.resolver .resolver
.register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space))); .register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails { let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst, is_extract: is_dst,
typ: scalar_t, typ: scalar_t,
packed: temporary_vector, packed: temporary_vector,
unpacked: vector_elements, unpacked: vector_elements,
relaxed_type_check, relaxed_type_check,
}); });
if is_dst { if is_dst {
self.post_stmts.push(statement); self.post_stmts.push(statement);
} else { } else {
self.result.push(statement); self.result.push(statement);
} }
Ok(temporary_vector) Ok(temporary_vector)
} }
} }
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError> impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
for FlattenArguments<'a, 'b> for FlattenArguments<'a, 'b>
{ {
fn visit( fn visit(
&mut self, &mut self,
args: ast::ParsedOperand<SpirvWord>, args: ast::ParsedOperand<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>, type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool, relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> { ) -> Result<SpirvWord, TranslateError> {
match args { match args {
ast::ParsedOperand::Reg(r) => self.reg(r), ast::ParsedOperand::Reg(r) => self.reg(r),
ast::ParsedOperand::Imm(x) => self.immediate(x, type_space), ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
ast::ParsedOperand::RegOffset(reg, offset) => { ast::ParsedOperand::RegOffset(reg, offset) => {
self.reg_offset(reg, offset, type_space, is_dst) self.reg_offset(reg, offset, type_space, is_dst)
} }
ast::ParsedOperand::VecMember(vec, member) => { ast::ParsedOperand::VecMember(vec, member) => {
self.vec_member(vec, member, type_space, is_dst) self.vec_member(vec, member, type_space, is_dst)
} }
ast::ParsedOperand::VecPack(vecs) => { ast::ParsedOperand::VecPack(vecs) => {
self.vec_pack(vecs, type_space, is_dst, relaxed_type_check) self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
} }
} }
} }
fn visit_ident( fn visit_ident(
&mut self, &mut self,
name: SpirvWord, name: SpirvWord,
_type_space: Option<(&ast::Type, ast::StateSpace)>, _type_space: Option<(&ast::Type, ast::StateSpace)>,
_is_dst: bool, _is_dst: bool,
_relaxed_type_check: bool, _relaxed_type_check: bool,
) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> { ) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
self.reg(name) self.reg(name)
} }
} }
impl Drop for FlattenArguments<'_, '_> { impl Drop for FlattenArguments<'_, '_> {
fn drop(&mut self) { fn drop(&mut self) {
self.result.extend(self.post_stmts.drain(..)); self.result.extend(self.post_stmts.drain(..));
} }
} }

View File

@ -1,208 +1,208 @@
use super::*; use super::*;
pub(super) fn run<'a, 'input>( pub(super) fn run<'a, 'input>(
resolver: &'a mut GlobalStringIdentResolver2<'input>, resolver: &'a mut GlobalStringIdentResolver2<'input>,
special_registers: &'a SpecialRegistersMap2, special_registers: &'a SpecialRegistersMap2,
directives: Vec<UnconditionalDirective>, directives: Vec<UnconditionalDirective>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> { ) -> Result<Vec<UnconditionalDirective>, TranslateError> {
let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len()); let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len());
let mut sreg_to_function = let mut sreg_to_function =
FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default()); FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default());
SpecialRegistersMap2::foreach_declaration( SpecialRegistersMap2::foreach_declaration(
resolver, resolver,
|sreg, (return_arguments, name, input_arguments)| { |sreg, (return_arguments, name, input_arguments)| {
result.push(UnconditionalDirective::Method(UnconditionalFunction { result.push(UnconditionalDirective::Method(UnconditionalFunction {
return_arguments, return_arguments,
name, name,
input_arguments, input_arguments,
body: None, body: None,
import_as: None, import_as: None,
tuning: Vec::new(), tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN, linkage: ast::LinkingDirective::EXTERN,
is_kernel: false, is_kernel: false,
flush_to_zero_f32: false, flush_to_zero_f32: false,
flush_to_zero_f16f64: false, flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
})); }));
sreg_to_function.insert(sreg, name); sreg_to_function.insert(sreg, name);
}, },
); );
let mut visitor = SpecialRegisterResolver { let mut visitor = SpecialRegisterResolver {
resolver, resolver,
special_registers, special_registers,
sreg_to_function, sreg_to_function,
result: Vec::new(), result: Vec::new(),
}; };
for directive in directives.into_iter() { for directive in directives.into_iter() {
result.push(run_directive(&mut visitor, directive)?); result.push(run_directive(&mut visitor, directive)?);
} }
Ok(result) Ok(result)
} }
fn run_directive<'a, 'input>( fn run_directive<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>, visitor: &mut SpecialRegisterResolver<'a, 'input>,
directive: UnconditionalDirective, directive: UnconditionalDirective,
) -> Result<UnconditionalDirective, TranslateError> { ) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive { Ok(match directive {
var @ Directive2::Variable(..) => var, var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?), Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
}) })
} }
fn run_method<'a, 'input>( fn run_method<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>, visitor: &mut SpecialRegisterResolver<'a, 'input>,
method: UnconditionalFunction, method: UnconditionalFunction,
) -> Result<UnconditionalFunction, TranslateError> { ) -> Result<UnconditionalFunction, TranslateError> {
let body = method let body = method
.body .body
.map(|statements| { .map(|statements| {
let mut result = Vec::with_capacity(statements.len()); let mut result = Vec::with_capacity(statements.len());
for statement in statements { for statement in statements {
run_statement(visitor, &mut result, statement)?; run_statement(visitor, &mut result, statement)?;
} }
Ok::<_, TranslateError>(result) Ok::<_, TranslateError>(result)
}) })
.transpose()?; .transpose()?;
Ok(Function2 { body, ..method }) Ok(Function2 { body, ..method })
} }
fn run_statement<'a, 'input>( fn run_statement<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>, visitor: &mut SpecialRegisterResolver<'a, 'input>,
result: &mut Vec<UnconditionalStatement>, result: &mut Vec<UnconditionalStatement>,
statement: UnconditionalStatement, statement: UnconditionalStatement,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
let converted_statement = statement.visit_map(visitor)?; let converted_statement = statement.visit_map(visitor)?;
result.extend(visitor.result.drain(..)); result.extend(visitor.result.drain(..));
result.push(converted_statement); result.push(converted_statement);
Ok(()) Ok(())
} }
struct SpecialRegisterResolver<'a, 'input> { struct SpecialRegisterResolver<'a, 'input> {
resolver: &'a mut GlobalStringIdentResolver2<'input>, resolver: &'a mut GlobalStringIdentResolver2<'input>,
special_registers: &'a SpecialRegistersMap2, special_registers: &'a SpecialRegistersMap2,
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>, sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
result: Vec<UnconditionalStatement>, result: Vec<UnconditionalStatement>,
} }
impl<'a, 'b, 'input> impl<'a, 'b, 'input>
ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
for SpecialRegisterResolver<'a, 'input> for SpecialRegisterResolver<'a, 'input>
{ {
fn visit( fn visit(
&mut self, &mut self,
operand: ast::ParsedOperand<SpirvWord>, operand: ast::ParsedOperand<SpirvWord>,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool, is_dst: bool,
_relaxed_type_check: bool, _relaxed_type_check: bool,
) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> { ) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
map_operand(operand, &mut |ident, vector_index| { map_operand(operand, &mut |ident, vector_index| {
self.replace_sreg(ident, vector_index, is_dst) self.replace_sreg(ident, vector_index, is_dst)
}) })
} }
fn visit_ident( fn visit_ident(
&mut self, &mut self,
args: SpirvWord, args: SpirvWord,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool, is_dst: bool,
_relaxed_type_check: bool, _relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> { ) -> Result<SpirvWord, TranslateError> {
Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args)) Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args))
} }
} }
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> { impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
fn replace_sreg( fn replace_sreg(
&mut self, &mut self,
name: SpirvWord, name: SpirvWord,
vector_index: Option<u8>, vector_index: Option<u8>,
is_dst: bool, is_dst: bool,
) -> Result<Option<SpirvWord>, TranslateError> { ) -> Result<Option<SpirvWord>, TranslateError> {
if let Some(sreg) = self.special_registers.get(name) { if let Some(sreg) = self.special_registers.get(name) {
if is_dst { if is_dst {
return Err(error_mismatched_type()); return Err(error_mismatched_type());
} }
let input_arguments = match (vector_index, sreg.get_function_input_type()) { let input_arguments = match (vector_index, sreg.get_function_input_type()) {
(Some(idx), Some(inp_type)) => { (Some(idx), Some(inp_type)) => {
if inp_type != ast::ScalarType::U8 { if inp_type != ast::ScalarType::U8 {
return Err(TranslateError::Unreachable); return Err(TranslateError::Unreachable);
} }
let constant = self.resolver.register_unnamed(Some(( let constant = self.resolver.register_unnamed(Some((
ast::Type::Scalar(inp_type), ast::Type::Scalar(inp_type),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))); )));
self.result.push(Statement::Constant(ConstantDefinition { self.result.push(Statement::Constant(ConstantDefinition {
dst: constant, dst: constant,
typ: inp_type, typ: inp_type,
value: ast::ImmediateValue::U64(idx as u64), value: ast::ImmediateValue::U64(idx as u64),
})); }));
vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)] vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
} }
(None, None) => Vec::new(), (None, None) => Vec::new(),
_ => return Err(error_mismatched_type()), _ => return Err(error_mismatched_type()),
}; };
let return_type = sreg.get_function_return_type(); let return_type = sreg.get_function_return_type();
let fn_result = self let fn_result = self
.resolver .resolver
.register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg))); .register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
let return_arguments = vec![( let return_arguments = vec![(
fn_result, fn_result,
ast::Type::Scalar(return_type), ast::Type::Scalar(return_type),
ast::StateSpace::Reg, ast::StateSpace::Reg,
)]; )];
let data = ast::CallDetails { let data = ast::CallDetails {
uniform: false, uniform: false,
return_arguments: return_arguments return_arguments: return_arguments
.iter() .iter()
.map(|(_, typ, space)| (typ.clone(), *space)) .map(|(_, typ, space)| (typ.clone(), *space))
.collect(), .collect(),
input_arguments: input_arguments input_arguments: input_arguments
.iter() .iter()
.map(|(_, typ, space)| (typ.clone(), *space)) .map(|(_, typ, space)| (typ.clone(), *space))
.collect(), .collect(),
}; };
let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> { let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(), return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
func: self.sreg_to_function[&sreg], func: self.sreg_to_function[&sreg],
input_arguments: input_arguments input_arguments: input_arguments
.iter() .iter()
.map(|(name, _, _)| ast::ParsedOperand::Reg(*name)) .map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
.collect(), .collect(),
}; };
self.result self.result
.push(Statement::Instruction(ast::Instruction::Call { .push(Statement::Instruction(ast::Instruction::Call {
data, data,
arguments, arguments,
})); }));
Ok(Some(fn_result)) Ok(Some(fn_result))
} else { } else {
Ok(None) Ok(None)
} }
} }
} }
pub fn map_operand<T: Copy, Err>( pub fn map_operand<T: Copy, Err>(
this: ast::ParsedOperand<T>, this: ast::ParsedOperand<T>,
fn_: &mut impl FnMut(T, Option<u8>) -> Result<Option<T>, Err>, fn_: &mut impl FnMut(T, Option<u8>) -> Result<Option<T>, Err>,
) -> Result<ast::ParsedOperand<T>, Err> { ) -> Result<ast::ParsedOperand<T>, Err> {
Ok(match this { Ok(match this {
ast::ParsedOperand::Reg(ident) => { ast::ParsedOperand::Reg(ident) => {
ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident)) ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident))
} }
ast::ParsedOperand::RegOffset(ident, offset) => { ast::ParsedOperand::RegOffset(ident, offset) => {
ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset) ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset)
} }
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm), ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? { ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? {
Some(ident) => ast::ParsedOperand::Reg(ident), Some(ident) => ast::ParsedOperand::Reg(ident),
None => ast::ParsedOperand::VecMember(ident, member), None => ast::ParsedOperand::VecMember(ident, member),
}, },
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack( ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
idents idents
.into_iter() .into_iter()
.map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident))) .map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
.collect::<Result<Vec<_>, _>>()?, .collect::<Result<Vec<_>, _>>()?,
), ),
}) })
} }

View File

@ -1,45 +1,45 @@
use super::*; use super::*;
pub(super) fn run<'input>( pub(super) fn run<'input>(
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut result = Vec::with_capacity(directives.len()); let mut result = Vec::with_capacity(directives.len());
for mut directive in directives.into_iter() { for mut directive in directives.into_iter() {
run_directive(&mut result, &mut directive)?; run_directive(&mut result, &mut directive)?;
result.push(directive); result.push(directive);
} }
Ok(result) Ok(result)
} }
fn run_directive<'input>( fn run_directive<'input>(
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
directive: &mut Directive2<ast::Instruction<SpirvWord>, SpirvWord>, directive: &mut Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
match directive { match directive {
Directive2::Variable(..) => {} Directive2::Variable(..) => {}
Directive2::Method(function2) => run_function(result, function2), Directive2::Method(function2) => run_function(result, function2),
} }
Ok(()) Ok(())
} }
fn run_function<'input>( fn run_function<'input>(
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
function: &mut Function2<ast::Instruction<SpirvWord>, SpirvWord>, function: &mut Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) { ) {
function.body = function.body.take().map(|statements| { function.body = function.body.take().map(|statements| {
statements statements
.into_iter() .into_iter()
.filter_map(|statement| match statement { .filter_map(|statement| match statement {
Statement::Variable(var @ ast::Variable { Statement::Variable(var @ ast::Variable {
state_space: state_space:
ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared, ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared,
.. ..
}) => { }) => {
result.push(Directive2::Variable(ast::LinkingDirective::NONE, var)); result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
None None
} }
s => Some(s), s => Some(s),
}) })
.collect() .collect()
}); });
} }

View File

@ -1,404 +1,404 @@
use super::*; use super::*;
// This pass: // This pass:
// * Turns all .local, .param and .reg in-body variables into .local variables // * Turns all .local, .param and .reg in-body variables into .local variables
// (if _not_ an input method argument) // (if _not_ an input method argument)
// * Inserts explicit `ld`/`st` for newly converted .reg variables // * Inserts explicit `ld`/`st` for newly converted .reg variables
// * Fixup state space of all existing `ld`/`st` instructions into newly // * Fixup state space of all existing `ld`/`st` instructions into newly
// converted variables // converted variables
// * Turns `.entry` input arguments into param::entry and all related `.param` // * Turns `.entry` input arguments into param::entry and all related `.param`
// loads into `param::entry` loads // loads into `param::entry` loads
// * All `.func` input arguments are turned into `.reg` arguments by another // * All `.func` input arguments are turned into `.reg` arguments by another
// pass, so we do nothing there // pass, so we do nothing there
pub(super) fn run<'a, 'input>( pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives directives
.into_iter() .into_iter()
.map(|directive| run_directive(resolver, directive)) .map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
} }
fn run_directive<'a, 'input>( fn run_directive<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>, directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive { Ok(match directive {
var @ Directive2::Variable(..) => var, var @ Directive2::Variable(..) => var,
Directive2::Method(method) => { Directive2::Method(method) => {
let visitor = InsertMemSSAVisitor::new(resolver); let visitor = InsertMemSSAVisitor::new(resolver);
Directive2::Method(run_method(visitor, method)?) Directive2::Method(run_method(visitor, method)?)
} }
}) })
} }
fn run_method<'a, 'input>( fn run_method<'a, 'input>(
mut visitor: InsertMemSSAVisitor<'a, 'input>, mut visitor: InsertMemSSAVisitor<'a, 'input>,
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>, mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let is_kernel = method.is_kernel; let is_kernel = method.is_kernel;
if is_kernel { if is_kernel {
for arg in method.input_arguments.iter_mut() { for arg in method.input_arguments.iter_mut() {
let old_name = arg.name; let old_name = arg.name;
let old_space = arg.state_space; let old_space = arg.state_space;
let new_space = ast::StateSpace::ParamEntry; let new_space = ast::StateSpace::ParamEntry;
let new_name = visitor let new_name = visitor
.resolver .resolver
.register_unnamed(Some((arg.v_type.clone(), new_space))); .register_unnamed(Some((arg.v_type.clone(), new_space)));
visitor.input_argument(old_name, new_name, old_space)?; visitor.input_argument(old_name, new_name, old_space)?;
arg.name = new_name; arg.name = new_name;
arg.state_space = new_space; arg.state_space = new_space;
} }
}; };
for arg in method.return_arguments.iter_mut() { for arg in method.return_arguments.iter_mut() {
visitor.visit_variable(arg)?; visitor.visit_variable(arg)?;
} }
let return_arguments = &method.return_arguments[..]; let return_arguments = &method.return_arguments[..];
let body = method let body = method
.body .body
.map(move |statements| { .map(move |statements| {
let mut result = Vec::with_capacity(statements.len()); let mut result = Vec::with_capacity(statements.len());
for statement in statements { for statement in statements {
run_statement(&mut visitor, return_arguments, &mut result, statement)?; run_statement(&mut visitor, return_arguments, &mut result, statement)?;
} }
Ok::<_, TranslateError>(result) Ok::<_, TranslateError>(result)
}) })
.transpose()?; .transpose()?;
Ok(Function2 { body, ..method }) Ok(Function2 { body, ..method })
} }
fn run_statement<'a, 'input>( fn run_statement<'a, 'input>(
visitor: &mut InsertMemSSAVisitor<'a, 'input>, visitor: &mut InsertMemSSAVisitor<'a, 'input>,
return_arguments: &[ast::Variable<SpirvWord>], return_arguments: &[ast::Variable<SpirvWord>],
result: &mut Vec<ExpandedStatement>, result: &mut Vec<ExpandedStatement>,
statement: ExpandedStatement, statement: ExpandedStatement,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
match statement { match statement {
Statement::Instruction(ast::Instruction::Ret { data }) => { Statement::Instruction(ast::Instruction::Ret { data }) => {
let statement = if return_arguments.is_empty() { let statement = if return_arguments.is_empty() {
Statement::Instruction(ast::Instruction::Ret { data }) Statement::Instruction(ast::Instruction::Ret { data })
} else { } else {
Statement::RetValue( Statement::RetValue(
data, data,
return_arguments return_arguments
.iter() .iter()
.map(|arg| { .map(|arg| {
if arg.state_space != ast::StateSpace::Local { if arg.state_space != ast::StateSpace::Local {
return Err(error_unreachable()); return Err(error_unreachable());
} }
Ok((arg.name, arg.v_type.clone())) Ok((arg.name, arg.v_type.clone()))
}) })
.collect::<Result<Vec<_>, _>>()?, .collect::<Result<Vec<_>, _>>()?,
) )
}; };
let new_statement = statement.visit_map(visitor)?; let new_statement = statement.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction)); result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(new_statement); result.push(new_statement);
result.extend(visitor.post.drain(..).map(Statement::Instruction)); result.extend(visitor.post.drain(..).map(Statement::Instruction));
} }
Statement::Variable(mut var) => { Statement::Variable(mut var) => {
visitor.visit_variable(&mut var)?; visitor.visit_variable(&mut var)?;
result.push(Statement::Variable(var)); result.push(Statement::Variable(var));
} }
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => { Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
let instruction = visitor.visit_ld(data, arguments)?; let instruction = visitor.visit_ld(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?; let instruction = ast::visit_map(instruction, visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction)); result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction)); result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction)); result.extend(visitor.post.drain(..).map(Statement::Instruction));
} }
Statement::Instruction(ast::Instruction::St { data, arguments }) => { Statement::Instruction(ast::Instruction::St { data, arguments }) => {
let instruction = visitor.visit_st(data, arguments)?; let instruction = visitor.visit_st(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?; let instruction = ast::visit_map(instruction, visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction)); result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction)); result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction)); result.extend(visitor.post.drain(..).map(Statement::Instruction));
} }
Statement::PtrAccess(ptr_access) => { Statement::PtrAccess(ptr_access) => {
let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?); let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
let statement = statement.visit_map(visitor)?; let statement = statement.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction)); result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(statement); result.push(statement);
result.extend(visitor.post.drain(..).map(Statement::Instruction)); result.extend(visitor.post.drain(..).map(Statement::Instruction));
} }
s => { s => {
let new_statement = s.visit_map(visitor)?; let new_statement = s.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction)); result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(new_statement); result.push(new_statement);
result.extend(visitor.post.drain(..).map(Statement::Instruction)); result.extend(visitor.post.drain(..).map(Statement::Instruction));
} }
} }
Ok(()) Ok(())
} }
struct InsertMemSSAVisitor<'a, 'input> { struct InsertMemSSAVisitor<'a, 'input> {
resolver: &'a mut GlobalStringIdentResolver2<'input>, resolver: &'a mut GlobalStringIdentResolver2<'input>,
variables: FxHashMap<SpirvWord, RemapAction>, variables: FxHashMap<SpirvWord, RemapAction>,
pre: Vec<ast::Instruction<SpirvWord>>, pre: Vec<ast::Instruction<SpirvWord>>,
post: Vec<ast::Instruction<SpirvWord>>, post: Vec<ast::Instruction<SpirvWord>>,
} }
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self { fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
Self { Self {
resolver, resolver,
variables: FxHashMap::default(), variables: FxHashMap::default(),
pre: Vec::new(), pre: Vec::new(),
post: Vec::new(), post: Vec::new(),
} }
} }
fn input_argument( fn input_argument(
&mut self, &mut self,
old_name: SpirvWord, old_name: SpirvWord,
new_name: SpirvWord, new_name: SpirvWord,
old_space: ast::StateSpace, old_space: ast::StateSpace,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
if old_space != ast::StateSpace::Param { if old_space != ast::StateSpace::Param {
return Err(error_unreachable()); return Err(error_unreachable());
} }
self.variables.insert( self.variables.insert(
old_name, old_name,
RemapAction::LDStSpaceChange { RemapAction::LDStSpaceChange {
name: new_name, name: new_name,
old_space, old_space,
new_space: ast::StateSpace::ParamEntry, new_space: ast::StateSpace::ParamEntry,
}, },
); );
Ok(()) Ok(())
} }
fn variable( fn variable(
&mut self, &mut self,
type_: &ast::Type, type_: &ast::Type,
old_name: SpirvWord, old_name: SpirvWord,
new_name: SpirvWord, new_name: SpirvWord,
old_space: ast::StateSpace, old_space: ast::StateSpace,
) -> Result<bool, TranslateError> { ) -> Result<bool, TranslateError> {
Ok(match old_space { Ok(match old_space {
ast::StateSpace::Reg => { ast::StateSpace::Reg => {
self.variables.insert( self.variables.insert(
old_name, old_name,
RemapAction::PreLdPostSt { RemapAction::PreLdPostSt {
name: new_name, name: new_name,
type_: type_.clone(), type_: type_.clone(),
}, },
); );
true true
} }
ast::StateSpace::Param => { ast::StateSpace::Param => {
self.variables.insert( self.variables.insert(
old_name, old_name,
RemapAction::LDStSpaceChange { RemapAction::LDStSpaceChange {
old_space, old_space,
new_space: ast::StateSpace::Local, new_space: ast::StateSpace::Local,
name: new_name, name: new_name,
}, },
); );
true true
} }
// Good as-is // Good as-is
ast::StateSpace::Local ast::StateSpace::Local
| ast::StateSpace::Generic | ast::StateSpace::Generic
| ast::StateSpace::SharedCluster | ast::StateSpace::SharedCluster
| ast::StateSpace::Global | ast::StateSpace::Global
| ast::StateSpace::Const | ast::StateSpace::Const
| ast::StateSpace::SharedCta | ast::StateSpace::SharedCta
| ast::StateSpace::Shared | ast::StateSpace::Shared
| ast::StateSpace::ParamEntry | ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc => return Err(error_unreachable()), | ast::StateSpace::ParamFunc => return Err(error_unreachable()),
}) })
} }
fn visit_st( fn visit_st(
&self, &self,
mut data: ast::StData, mut data: ast::StData,
mut arguments: ast::StArgs<SpirvWord>, mut arguments: ast::StArgs<SpirvWord>,
) -> Result<ast::Instruction<SpirvWord>, TranslateError> { ) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src1) { if let Some(remap) = self.variables.get(&arguments.src1) {
match remap { match remap {
RemapAction::PreLdPostSt { .. } => {} RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange { RemapAction::LDStSpaceChange {
old_space, old_space,
new_space, new_space,
name, name,
} => { } => {
if data.state_space != *old_space { if data.state_space != *old_space {
return Err(error_mismatched_type()); return Err(error_mismatched_type());
} }
data.state_space = *new_space; data.state_space = *new_space;
arguments.src1 = *name; arguments.src1 = *name;
} }
} }
} }
Ok(ast::Instruction::St { data, arguments }) Ok(ast::Instruction::St { data, arguments })
} }
fn visit_ld( fn visit_ld(
&self, &self,
mut data: ast::LdDetails, mut data: ast::LdDetails,
mut arguments: ast::LdArgs<SpirvWord>, mut arguments: ast::LdArgs<SpirvWord>,
) -> Result<ast::Instruction<SpirvWord>, TranslateError> { ) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src) { if let Some(remap) = self.variables.get(&arguments.src) {
match remap { match remap {
RemapAction::PreLdPostSt { .. } => {} RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange { RemapAction::LDStSpaceChange {
old_space, old_space,
new_space, new_space,
name, name,
} => { } => {
if data.state_space != *old_space { if data.state_space != *old_space {
return Err(error_mismatched_type()); return Err(error_mismatched_type());
} }
data.state_space = *new_space; data.state_space = *new_space;
arguments.src = *name; arguments.src = *name;
} }
} }
} }
Ok(ast::Instruction::Ld { data, arguments }) Ok(ast::Instruction::Ld { data, arguments })
} }
fn visit_ptr_access( fn visit_ptr_access(
&mut self, &mut self,
ptr_access: PtrAccess<SpirvWord>, ptr_access: PtrAccess<SpirvWord>,
) -> Result<PtrAccess<SpirvWord>, TranslateError> { ) -> Result<PtrAccess<SpirvWord>, TranslateError> {
let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) { let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) {
Some(RemapAction::LDStSpaceChange { Some(RemapAction::LDStSpaceChange {
old_space, old_space,
new_space, new_space,
name, name,
}) => (*old_space, *new_space, *name), }) => (*old_space, *new_space, *name),
Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access), Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access),
}; };
if ptr_access.state_space != old_space { if ptr_access.state_space != old_space {
return Err(error_mismatched_type()); return Err(error_mismatched_type());
} }
// Propagate space changes in dst // Propagate space changes in dst
let new_dst = self let new_dst = self
.resolver .resolver
.register_unnamed(Some((ptr_access.underlying_type.clone(), new_space))); .register_unnamed(Some((ptr_access.underlying_type.clone(), new_space)));
self.variables.insert( self.variables.insert(
ptr_access.dst, ptr_access.dst,
RemapAction::LDStSpaceChange { RemapAction::LDStSpaceChange {
old_space, old_space,
new_space, new_space,
name: new_dst, name: new_dst,
}, },
); );
Ok(PtrAccess { Ok(PtrAccess {
ptr_src: name, ptr_src: name,
dst: new_dst, dst: new_dst,
state_space: new_space, state_space: new_space,
..ptr_access ..ptr_access
}) })
} }
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> { fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
let old_space = match var.state_space { let old_space = match var.state_space {
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
// Do nothing // Do nothing
ptx_parser::StateSpace::Local => return Ok(()), ptx_parser::StateSpace::Local => return Ok(()),
// Handled by another pass // Handled by another pass
ptx_parser::StateSpace::Generic ptx_parser::StateSpace::Generic
| ptx_parser::StateSpace::SharedCluster | ptx_parser::StateSpace::SharedCluster
| ptx_parser::StateSpace::ParamEntry | ptx_parser::StateSpace::ParamEntry
| ptx_parser::StateSpace::Global | ptx_parser::StateSpace::Global
| ptx_parser::StateSpace::SharedCta | ptx_parser::StateSpace::SharedCta
| ptx_parser::StateSpace::Const | ptx_parser::StateSpace::Const
| ptx_parser::StateSpace::Shared | ptx_parser::StateSpace::Shared
| ptx_parser::StateSpace::ParamFunc => return Ok(()), | ptx_parser::StateSpace::ParamFunc => return Ok(()),
}; };
let old_name = var.name; let old_name = var.name;
let new_space = ast::StateSpace::Local; let new_space = ast::StateSpace::Local;
let new_name = self let new_name = self
.resolver .resolver
.register_unnamed(Some((var.v_type.clone(), new_space))); .register_unnamed(Some((var.v_type.clone(), new_space)));
self.variable(&var.v_type, old_name, new_name, old_space)?; self.variable(&var.v_type, old_name, new_name, old_space)?;
var.name = new_name; var.name = new_name;
var.state_space = new_space; var.state_space = new_space;
Ok(()) Ok(())
} }
} }
impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError> impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
for InsertMemSSAVisitor<'a, 'input> for InsertMemSSAVisitor<'a, 'input>
{ {
fn visit( fn visit(
&mut self, &mut self,
ident: SpirvWord, ident: SpirvWord,
_type_space: Option<(&ast::Type, ast::StateSpace)>, _type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool, is_dst: bool,
_relaxed_type_check: bool, _relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> { ) -> Result<SpirvWord, TranslateError> {
if let Some(remap) = self.variables.get(&ident) { if let Some(remap) = self.variables.get(&ident) {
match remap { match remap {
RemapAction::PreLdPostSt { name, type_ } => { RemapAction::PreLdPostSt { name, type_ } => {
if is_dst { if is_dst {
let temp = self let temp = self
.resolver .resolver
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg))); .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
self.post.push(ast::Instruction::St { self.post.push(ast::Instruction::St {
data: ast::StData { data: ast::StData {
state_space: ast::StateSpace::Local, state_space: ast::StateSpace::Local,
qualifier: ast::LdStQualifier::Weak, qualifier: ast::LdStQualifier::Weak,
caching: ast::StCacheOperator::Writethrough, caching: ast::StCacheOperator::Writethrough,
typ: type_.clone(), typ: type_.clone(),
}, },
arguments: ast::StArgs { arguments: ast::StArgs {
src1: *name, src1: *name,
src2: temp, src2: temp,
}, },
}); });
Ok(temp) Ok(temp)
} else { } else {
let temp = self let temp = self
.resolver .resolver
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg))); .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
self.pre.push(ast::Instruction::Ld { self.pre.push(ast::Instruction::Ld {
data: ast::LdDetails { data: ast::LdDetails {
state_space: ast::StateSpace::Local, state_space: ast::StateSpace::Local,
qualifier: ast::LdStQualifier::Weak, qualifier: ast::LdStQualifier::Weak,
caching: ast::LdCacheOperator::Cached, caching: ast::LdCacheOperator::Cached,
typ: type_.clone(), typ: type_.clone(),
non_coherent: false, non_coherent: false,
}, },
arguments: ast::LdArgs { arguments: ast::LdArgs {
dst: temp, dst: temp,
src: *name, src: *name,
}, },
}); });
Ok(temp) Ok(temp)
} }
} }
RemapAction::LDStSpaceChange { .. } => { RemapAction::LDStSpaceChange { .. } => {
return Err(error_mismatched_type()); return Err(error_mismatched_type());
} }
} }
} else { } else {
Ok(ident) Ok(ident)
} }
} }
fn visit_ident( fn visit_ident(
&mut self, &mut self,
args: SpirvWord, args: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>, type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool, relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> { ) -> Result<SpirvWord, TranslateError> {
self.visit(args, type_space, is_dst, relaxed_type_check) self.visit(args, type_space, is_dst, relaxed_type_check)
} }
} }
#[derive(Clone)] #[derive(Clone)]
enum RemapAction { enum RemapAction {
PreLdPostSt { PreLdPostSt {
name: SpirvWord, name: SpirvWord,
type_: ast::Type, type_: ast::Type,
}, },
LDStSpaceChange { LDStSpaceChange {
old_space: ast::StateSpace, old_space: ast::StateSpace,
new_space: ast::StateSpace, new_space: ast::StateSpace,
name: SpirvWord, name: SpirvWord,
}, },
} }

View File

@ -1,401 +1,401 @@
use std::mem; use std::mem;
use super::*; use super::*;
use ptx_parser as ast; use ptx_parser as ast;
/*
 There are several kinds of implicit conversions in PTX:
 * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
 * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
   - ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
     semantics are to first zext/chop/bitcast `y` as needed and then apply the
     documented special ld/st/cvt conversion rules for destination operands
   - st.param [x] y (used as function return arguments): the same rule as above applies
   - generic/global ld: for instruction `ld x, [y]`, y must be of type
     b64/u64/s64, which is bitcast to a pointer, dereferenced, and then the
     documented special ld/st/cvt conversion rules are applied to dst
   - generic/global st: for instruction `st [x], y`, x must be of type
     b64/u64/s64, which is bitcast to a pointer
*/
pub(super) fn run<'input>( pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives directives
.into_iter() .into_iter()
.map(|directive| run_directive(resolver, directive)) .map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
} }
fn run_directive<'a, 'input>( fn run_directive<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>, directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive { Ok(match directive {
var @ Directive2::Variable(..) => var, var @ Directive2::Variable(..) => var,
Directive2::Method(mut method) => { Directive2::Method(mut method) => {
method.body = method method.body = method
.body .body
.map(|statements| run_statements(resolver, statements)) .map(|statements| run_statements(resolver, statements))
.transpose()?; .transpose()?;
Directive2::Method(method) Directive2::Method(method)
} }
}) })
} }
fn run_statements<'input>( fn run_statements<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
func: Vec<ExpandedStatement>, func: Vec<ExpandedStatement>,
) -> Result<Vec<ExpandedStatement>, TranslateError> { ) -> Result<Vec<ExpandedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len()); let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() { for s in func.into_iter() {
insert_implicit_conversions_impl(resolver, &mut result, s)?; insert_implicit_conversions_impl(resolver, &mut result, s)?;
} }
Ok(result) Ok(result)
} }
fn insert_implicit_conversions_impl<'input>( fn insert_implicit_conversions_impl<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
stmt: ExpandedStatement, stmt: ExpandedStatement,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
let mut post_conv = Vec::new(); let mut post_conv = Vec::new();
let statement = stmt.visit_map::<SpirvWord, TranslateError>( let statement = stmt.visit_map::<SpirvWord, TranslateError>(
&mut |operand, &mut |operand,
type_state: Option<(&ast::Type, ast::StateSpace)>, type_state: Option<(&ast::Type, ast::StateSpace)>,
is_dst, is_dst,
relaxed_type_check| { relaxed_type_check| {
let (instr_type, instruction_space) = match type_state { let (instr_type, instruction_space) = match type_state {
None => return Ok(operand), None => return Ok(operand),
Some(t) => t, Some(t) => t,
}; };
let (operand_type, operand_space) = resolver.get_typed(operand)?; let (operand_type, operand_space) = resolver.get_typed(operand)?;
let conversion_fn = if relaxed_type_check { let conversion_fn = if relaxed_type_check {
if is_dst { if is_dst {
should_convert_relaxed_dst_wrapper should_convert_relaxed_dst_wrapper
} else { } else {
should_convert_relaxed_src_wrapper should_convert_relaxed_src_wrapper
} }
} else { } else {
default_implicit_conversion default_implicit_conversion
}; };
match conversion_fn( match conversion_fn(
(*operand_space, &operand_type), (*operand_space, &operand_type),
(instruction_space, instr_type), (instruction_space, instr_type),
)? { )? {
Some(conv_kind) => { Some(conv_kind) => {
let conv_output = if is_dst { &mut post_conv } else { &mut *func }; let conv_output = if is_dst { &mut post_conv } else { &mut *func };
let mut from_type = instr_type.clone(); let mut from_type = instr_type.clone();
let mut from_space = instruction_space; let mut from_space = instruction_space;
let mut to_type = operand_type.clone(); let mut to_type = operand_type.clone();
let mut to_space = *operand_space; let mut to_space = *operand_space;
let mut src = let mut src =
resolver.register_unnamed(Some((instr_type.clone(), instruction_space))); resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
let mut dst = operand; let mut dst = operand;
let result = Ok::<_, TranslateError>(src); let result = Ok::<_, TranslateError>(src);
if !is_dst { if !is_dst {
mem::swap(&mut src, &mut dst); mem::swap(&mut src, &mut dst);
mem::swap(&mut from_type, &mut to_type); mem::swap(&mut from_type, &mut to_type);
mem::swap(&mut from_space, &mut to_space); mem::swap(&mut from_space, &mut to_space);
} }
conv_output.push(Statement::Conversion(ImplicitConversion { conv_output.push(Statement::Conversion(ImplicitConversion {
src, src,
dst, dst,
from_type, from_type,
from_space, from_space,
to_type, to_type,
to_space, to_space,
kind: conv_kind, kind: conv_kind,
})); }));
result result
} }
None => Ok(operand), None => Ok(operand),
} }
}, },
)?; )?;
func.push(statement); func.push(statement);
func.append(&mut post_conv); func.append(&mut post_conv);
Ok(()) Ok(())
} }
pub(crate) fn default_implicit_conversion( pub(crate) fn default_implicit_conversion(
(operand_space, operand_type): (ast::StateSpace, &ast::Type), (operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> { ) -> Result<Option<ConversionKind>, TranslateError> {
if instruction_space == ast::StateSpace::Reg { if instruction_space == ast::StateSpace::Reg {
if operand_space == ast::StateSpace::Reg { if operand_space == ast::StateSpace::Reg {
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) = if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
(operand_type, instruction_type) (operand_type, instruction_type)
{ {
if scalar.kind() == ast::ScalarKind::Bit if scalar.kind() == ast::ScalarKind::Bit
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len) && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
{ {
return Ok(Some(ConversionKind::Default)); return Ok(Some(ConversionKind::Default));
} }
} }
} else if is_addressable(operand_space) { } else if is_addressable(operand_space) {
return Ok(Some(ConversionKind::AddressOf)); return Ok(Some(ConversionKind::AddressOf));
} }
} }
if instruction_space != operand_space { if instruction_space != operand_space {
default_implicit_conversion_space((operand_space, operand_type), instruction_space) default_implicit_conversion_space((operand_space, operand_type), instruction_space)
} else if instruction_type != operand_type { } else if instruction_type != operand_type {
default_implicit_conversion_type(instruction_space, operand_type, instruction_type) default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
} else { } else {
Ok(None) Ok(None)
} }
} }
fn is_addressable(this: ast::StateSpace) -> bool { fn is_addressable(this: ast::StateSpace) -> bool {
match this { match this {
ast::StateSpace::Const ast::StateSpace::Const
| ast::StateSpace::Generic | ast::StateSpace::Generic
| ast::StateSpace::Global | ast::StateSpace::Global
| ast::StateSpace::Local | ast::StateSpace::Local
| ast::StateSpace::Shared => true, | ast::StateSpace::Shared => true,
ast::StateSpace::Param | ast::StateSpace::Reg => false, ast::StateSpace::Param | ast::StateSpace::Reg => false,
ast::StateSpace::SharedCluster ast::StateSpace::SharedCluster
| ast::StateSpace::SharedCta | ast::StateSpace::SharedCta
| ast::StateSpace::ParamEntry | ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc => todo!(), | ast::StateSpace::ParamFunc => todo!(),
} }
} }
// Space is different // Space is different
fn default_implicit_conversion_space( fn default_implicit_conversion_space(
(operand_space, operand_type): (ast::StateSpace, &ast::Type), (operand_space, operand_type): (ast::StateSpace, &ast::Type),
instruction_space: ast::StateSpace, instruction_space: ast::StateSpace,
) -> Result<Option<ConversionKind>, TranslateError> { ) -> Result<Option<ConversionKind>, TranslateError> {
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space)) if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space)) || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
{ {
Ok(Some(ConversionKind::PtrToPtr)) Ok(Some(ConversionKind::PtrToPtr))
} else if operand_space == ast::StateSpace::Reg { } else if operand_space == ast::StateSpace::Reg {
match operand_type { match operand_type {
// TODO: 32 bit // TODO: 32 bit
ast::Type::Scalar(ast::ScalarType::B64) ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::U64) | ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
ast::StateSpace::Global ast::StateSpace::Global
| ast::StateSpace::Generic | ast::StateSpace::Generic
| ast::StateSpace::Const | ast::StateSpace::Const
| ast::StateSpace::Local | ast::StateSpace::Local
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
_ => Err(error_mismatched_type()), _ => Err(error_mismatched_type()),
}, },
ast::Type::Scalar(ast::ScalarType::B32) ast::Type::Scalar(ast::ScalarType::B32)
| ast::Type::Scalar(ast::ScalarType::U32) | ast::Type::Scalar(ast::ScalarType::U32)
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
Ok(Some(ConversionKind::BitToPtr)) Ok(Some(ConversionKind::BitToPtr))
} }
_ => Err(error_mismatched_type()), _ => Err(error_mismatched_type()),
}, },
_ => Err(error_mismatched_type()), _ => Err(error_mismatched_type()),
} }
} else { } else {
Err(error_mismatched_type()) Err(error_mismatched_type())
} }
} }
// Space is same, but type is different // Space is same, but type is different
fn default_implicit_conversion_type( fn default_implicit_conversion_type(
space: ast::StateSpace, space: ast::StateSpace,
operand_type: &ast::Type, operand_type: &ast::Type,
instruction_type: &ast::Type, instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> { ) -> Result<Option<ConversionKind>, TranslateError> {
if space == ast::StateSpace::Reg { if space == ast::StateSpace::Reg {
if should_bitcast(instruction_type, operand_type) { if should_bitcast(instruction_type, operand_type) {
Ok(Some(ConversionKind::Default)) Ok(Some(ConversionKind::Default))
} else { } else {
Err(TranslateError::MismatchedType) Err(TranslateError::MismatchedType)
} }
} else { } else {
Ok(Some(ConversionKind::PtrToPtr)) Ok(Some(ConversionKind::PtrToPtr))
} }
} }
fn coerces_to_generic(this: ast::StateSpace) -> bool { fn coerces_to_generic(this: ast::StateSpace) -> bool {
match this { match this {
ast::StateSpace::Global ast::StateSpace::Global
| ast::StateSpace::Const | ast::StateSpace::Const
| ast::StateSpace::Local | ast::StateSpace::Local
| ptx_parser::StateSpace::SharedCta | ptx_parser::StateSpace::SharedCta
| ast::StateSpace::SharedCluster | ast::StateSpace::SharedCluster
| ast::StateSpace::Shared => true, | ast::StateSpace::Shared => true,
ast::StateSpace::Reg ast::StateSpace::Reg
| ast::StateSpace::Param | ast::StateSpace::Param
| ast::StateSpace::ParamEntry | ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc | ast::StateSpace::ParamFunc
| ast::StateSpace::Generic => false, | ast::StateSpace::Generic => false,
} }
} }
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
match (instr, operand) { match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
if inst.size_of() != operand.size_of() { if inst.size_of() != operand.size_of() {
return false; return false;
} }
match inst.kind() { match inst.kind() {
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
ast::ScalarKind::Signed => { ast::ScalarKind::Signed => {
operand.kind() == ast::ScalarKind::Bit operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Unsigned || operand.kind() == ast::ScalarKind::Unsigned
} }
ast::ScalarKind::Unsigned => { ast::ScalarKind::Unsigned => {
operand.kind() == ast::ScalarKind::Bit operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Signed || operand.kind() == ast::ScalarKind::Signed
} }
ast::ScalarKind::Pred => false, ast::ScalarKind::Pred => false,
} }
} }
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand)) (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => { | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand)) should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
} }
_ => false, _ => false,
} }
} }
pub(crate) fn should_convert_relaxed_dst_wrapper( pub(crate) fn should_convert_relaxed_dst_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type), (operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> { ) -> Result<Option<ConversionKind>, TranslateError> {
if operand_space != instruction_space { if operand_space != instruction_space {
return Err(TranslateError::MismatchedType); return Err(TranslateError::MismatchedType);
} }
if operand_type == instruction_type { if operand_type == instruction_type {
return Ok(None); return Ok(None);
} }
match should_convert_relaxed_dst(operand_type, instruction_type) { match should_convert_relaxed_dst(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv), conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType), None => Err(TranslateError::MismatchedType),
} }
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst( fn should_convert_relaxed_dst(
dst_type: &ast::Type, dst_type: &ast::Type,
instr_type: &ast::Type, instr_type: &ast::Type,
) -> Option<ConversionKind> { ) -> Option<ConversionKind> {
if dst_type == instr_type { if dst_type == instr_type {
return None; return None;
} }
match (dst_type, instr_type) { match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ast::ScalarKind::Bit => { ast::ScalarKind::Bit => {
if instr_type.size_of() <= dst_type.size_of() { if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else { } else {
None None
} }
} }
ast::ScalarKind::Signed => { ast::ScalarKind::Signed => {
if dst_type.kind() != ast::ScalarKind::Float { if dst_type.kind() != ast::ScalarKind::Float {
if instr_type.size_of() == dst_type.size_of() { if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else if instr_type.size_of() < dst_type.size_of() { } else if instr_type.size_of() < dst_type.size_of() {
Some(ConversionKind::SignExtend) Some(ConversionKind::SignExtend)
} else { } else {
None None
} }
} else { } else {
None None
} }
} }
ast::ScalarKind::Unsigned => { ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= dst_type.size_of() if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() != ast::ScalarKind::Float && dst_type.kind() != ast::ScalarKind::Float
{ {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else { } else {
None None
} }
} }
ast::ScalarKind::Float => { ast::ScalarKind::Float => {
if instr_type.size_of() <= dst_type.size_of() if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() == ast::ScalarKind::Bit && dst_type.kind() == ast::ScalarKind::Bit
{ {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else { } else {
None None
} }
} }
ast::ScalarKind::Pred => None, ast::ScalarKind::Pred => None,
}, },
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
should_convert_relaxed_dst( should_convert_relaxed_dst(
&ast::Type::Scalar(*dst_type), &ast::Type::Scalar(*dst_type),
&ast::Type::Scalar(*instr_type), &ast::Type::Scalar(*instr_type),
) )
} }
_ => None, _ => None,
} }
} }
pub(crate) fn should_convert_relaxed_src_wrapper( pub(crate) fn should_convert_relaxed_src_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type), (operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> { ) -> Result<Option<ConversionKind>, TranslateError> {
if operand_space != instruction_space { if operand_space != instruction_space {
return Err(error_mismatched_type()); return Err(error_mismatched_type());
} }
if operand_type == instruction_type { if operand_type == instruction_type {
return Ok(None); return Ok(None);
} }
match should_convert_relaxed_src(operand_type, instruction_type) { match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv), conv @ Some(_) => Ok(conv),
None => Err(error_mismatched_type()), None => Err(error_mismatched_type()),
} }
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src( fn should_convert_relaxed_src(
src_type: &ast::Type, src_type: &ast::Type,
instr_type: &ast::Type, instr_type: &ast::Type,
) -> Option<ConversionKind> { ) -> Option<ConversionKind> {
if src_type == instr_type { if src_type == instr_type {
return None; return None;
} }
match (src_type, instr_type) { match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ast::ScalarKind::Bit => { ast::ScalarKind::Bit => {
if instr_type.size_of() <= src_type.size_of() { if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else { } else {
None None
} }
} }
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= src_type.size_of() if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() != ast::ScalarKind::Float && src_type.kind() != ast::ScalarKind::Float
{ {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else { } else {
None None
} }
} }
ast::ScalarKind::Float => { ast::ScalarKind::Float => {
if instr_type.size_of() <= src_type.size_of() if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() == ast::ScalarKind::Bit && src_type.kind() == ast::ScalarKind::Bit
{ {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else { } else {
None None
} }
} }
ast::ScalarKind::Pred => None, ast::ScalarKind::Pred => None,
}, },
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
should_convert_relaxed_src( should_convert_relaxed_src(
&ast::Type::Scalar(*dst_type), &ast::Type::Scalar(*dst_type),
&ast::Type::Scalar(*instr_type), &ast::Type::Scalar(*instr_type),
) )
} }
_ => None, _ => None,
} }
} }

View File

@ -1,194 +1,194 @@
use super::*; use super::*;
use ptx_parser as ast; use ptx_parser as ast;
pub(crate) fn run<'input, 'b>( pub(crate) fn run<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>, directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedDirective2>, TranslateError> { ) -> Result<Vec<NormalizedDirective2>, TranslateError> {
resolver.start_scope(); resolver.start_scope();
let result = directives let result = directives
.into_iter() .into_iter()
.map(|directive| run_directive(resolver, directive)) .map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
resolver.end_scope(); resolver.end_scope();
Ok(result) Ok(result)
} }
fn run_directive<'input, 'b>( fn run_directive<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>, directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<NormalizedDirective2, TranslateError> { ) -> Result<NormalizedDirective2, TranslateError> {
Ok(match directive { Ok(match directive {
ast::Directive::Variable(linking, var) => { ast::Directive::Variable(linking, var) => {
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?) NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
} }
ast::Directive::Method(linking, directive) => { ast::Directive::Method(linking, directive) => {
NormalizedDirective2::Method(run_method(resolver, linking, directive)?) NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
} }
}) })
} }
fn run_method<'input, 'b>( fn run_method<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
linkage: ast::LinkingDirective, linkage: ast::LinkingDirective,
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>, method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<NormalizedFunction2, TranslateError> { ) -> Result<NormalizedFunction2, TranslateError> {
let is_kernel = method.func_directive.name.is_kernel(); let is_kernel = method.func_directive.name.is_kernel();
let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?; let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?;
resolver.start_scope(); resolver.start_scope();
let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?; let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?;
let body = method let body = method
.body .body
.map(|statements| { .map(|statements| {
let mut result = Vec::with_capacity(statements.len()); let mut result = Vec::with_capacity(statements.len());
run_statements(resolver, &mut result, statements)?; run_statements(resolver, &mut result, statements)?;
Ok::<_, TranslateError>(result) Ok::<_, TranslateError>(result)
}) })
.transpose()?; .transpose()?;
resolver.end_scope(); resolver.end_scope();
Ok(Function2 { Ok(Function2 {
return_arguments, return_arguments,
name, name,
input_arguments, input_arguments,
body, body,
import_as: None, import_as: None,
linkage, linkage,
is_kernel, is_kernel,
tuning: method.tuning, tuning: method.tuning,
flush_to_zero_f32: false, flush_to_zero_f32: false,
flush_to_zero_f16f64: false, flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
}) })
} }
fn run_function_decl<'input, 'b>( fn run_function_decl<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
func_directive: ast::MethodDeclaration<'input, &'input str>, func_directive: ast::MethodDeclaration<'input, &'input str>,
) -> Result<(Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>), TranslateError> { ) -> Result<(Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
assert!(func_directive.shared_mem.is_none()); assert!(func_directive.shared_mem.is_none());
let return_arguments = func_directive let return_arguments = func_directive
.return_arguments .return_arguments
.into_iter() .into_iter()
.map(|var| run_variable(resolver, var)) .map(|var| run_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let input_arguments = func_directive let input_arguments = func_directive
.input_arguments .input_arguments
.into_iter() .into_iter()
.map(|var| run_variable(resolver, var)) .map(|var| run_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok((return_arguments, input_arguments)) Ok((return_arguments, input_arguments))
} }
fn run_variable<'input, 'b>( fn run_variable<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
variable: ast::Variable<&'input str>, variable: ast::Variable<&'input str>,
) -> Result<ast::Variable<SpirvWord>, TranslateError> { ) -> Result<ast::Variable<SpirvWord>, TranslateError> {
Ok(ast::Variable { Ok(ast::Variable {
name: resolver.add( name: resolver.add(
Cow::Borrowed(variable.name), Cow::Borrowed(variable.name),
Some((variable.v_type.clone(), variable.state_space)), Some((variable.v_type.clone(), variable.state_space)),
)?, )?,
align: variable.align, align: variable.align,
v_type: variable.v_type, v_type: variable.v_type,
state_space: variable.state_space, state_space: variable.state_space,
array_init: variable.array_init, array_init: variable.array_init,
}) })
} }
fn run_statements<'input, 'b>( fn run_statements<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>, result: &mut Vec<NormalizedStatement>,
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>, statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
for statement in statements.iter() { for statement in statements.iter() {
match statement { match statement {
ast::Statement::Label(label) => { ast::Statement::Label(label) => {
resolver.add(Cow::Borrowed(*label), None)?; resolver.add(Cow::Borrowed(*label), None)?;
} }
_ => {} _ => {}
} }
} }
for statement in statements { for statement in statements {
match statement { match statement {
ast::Statement::Label(label) => { ast::Statement::Label(label) => {
result.push(Statement::Label(resolver.get_in_current_scope(label)?)) result.push(Statement::Label(resolver.get_in_current_scope(label)?))
} }
ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?, ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
ast::Statement::Instruction(predicate, instruction) => { ast::Statement::Instruction(predicate, instruction) => {
result.push(Statement::Instruction(( result.push(Statement::Instruction((
predicate predicate
.map(|pred| { .map(|pred| {
Ok::<_, TranslateError>(ast::PredAt { Ok::<_, TranslateError>(ast::PredAt {
not: pred.not, not: pred.not,
label: resolver.get(pred.label)?, label: resolver.get(pred.label)?,
}) })
}) })
.transpose()?, .transpose()?,
run_instruction(resolver, instruction)?, run_instruction(resolver, instruction)?,
))) )))
} }
ast::Statement::Block(block) => { ast::Statement::Block(block) => {
resolver.start_scope(); resolver.start_scope();
run_statements(resolver, result, block)?; run_statements(resolver, result, block)?;
resolver.end_scope(); resolver.end_scope();
} }
} }
} }
Ok(()) Ok(())
} }
fn run_instruction<'input, 'b>( fn run_instruction<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>, instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> { ) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
ast::visit_map(instruction, &mut |name: &'input str, ast::visit_map(instruction, &mut |name: &'input str,
_: Option<( _: Option<(
&ast::Type, &ast::Type,
ast::StateSpace, ast::StateSpace,
)>, )>,
_, _,
_| { _| {
resolver.get(&name) resolver.get(&name)
}) })
} }
fn run_multivariable<'input, 'b>( fn run_multivariable<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>, result: &mut Vec<NormalizedStatement>,
variable: ast::MultiVariable<&'input str>, variable: ast::MultiVariable<&'input str>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
match variable.count { match variable.count {
Some(count) => { Some(count) => {
for i in 0..count { for i in 0..count {
let name = Cow::Owned(format!("{}{}", variable.var.name, i)); let name = Cow::Owned(format!("{}{}", variable.var.name, i));
let ident = resolver.add( let ident = resolver.add(
name, name,
Some((variable.var.v_type.clone(), variable.var.state_space)), Some((variable.var.v_type.clone(), variable.var.state_space)),
)?; )?;
result.push(Statement::Variable(ast::Variable { result.push(Statement::Variable(ast::Variable {
align: variable.var.align, align: variable.var.align,
v_type: variable.var.v_type.clone(), v_type: variable.var.v_type.clone(),
state_space: variable.var.state_space, state_space: variable.var.state_space,
name: ident, name: ident,
array_init: variable.var.array_init.clone(), array_init: variable.var.array_init.clone(),
})); }));
} }
} }
None => { None => {
let name = Cow::Borrowed(variable.var.name); let name = Cow::Borrowed(variable.var.name);
let ident = resolver.add( let ident = resolver.add(
name, name,
Some((variable.var.v_type.clone(), variable.var.state_space)), Some((variable.var.v_type.clone(), variable.var.state_space)),
)?; )?;
result.push(Statement::Variable(ast::Variable { result.push(Statement::Variable(ast::Variable {
align: variable.var.align, align: variable.var.align,
v_type: variable.var.v_type.clone(), v_type: variable.var.v_type.clone(),
state_space: variable.var.state_space, state_space: variable.var.state_space,
name: ident, name: ident,
array_init: variable.var.array_init.clone(), array_init: variable.var.array_init.clone(),
})); }));
} }
} }
Ok(()) Ok(())
} }

View File

@ -1,90 +1,90 @@
use super::*; use super::*;
use ptx_parser as ast; use ptx_parser as ast;
pub(crate) fn run<'input>( pub(crate) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<NormalizedDirective2>, directives: Vec<NormalizedDirective2>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> { ) -> Result<Vec<UnconditionalDirective>, TranslateError> {
directives directives
.into_iter() .into_iter()
.map(|directive| run_directive(resolver, directive)) .map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
} }
fn run_directive<'input>( fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directive: NormalizedDirective2, directive: NormalizedDirective2,
) -> Result<UnconditionalDirective, TranslateError> { ) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive { Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var), Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
}) })
} }
fn run_method<'input>( fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
method: NormalizedFunction2, method: NormalizedFunction2,
) -> Result<UnconditionalFunction, TranslateError> { ) -> Result<UnconditionalFunction, TranslateError> {
let body = method let body = method
.body .body
.map(|statements| { .map(|statements| {
let mut result = Vec::with_capacity(statements.len()); let mut result = Vec::with_capacity(statements.len());
for statement in statements { for statement in statements {
run_statement(resolver, &mut result, statement)?; run_statement(resolver, &mut result, statement)?;
} }
Ok::<_, TranslateError>(result) Ok::<_, TranslateError>(result)
}) })
.transpose()?; .transpose()?;
Ok(Function2 { Ok(Function2 {
body, body,
return_arguments: method.return_arguments, return_arguments: method.return_arguments,
name: method.name, name: method.name,
input_arguments: method.input_arguments, input_arguments: method.input_arguments,
import_as: method.import_as, import_as: method.import_as,
tuning: method.tuning, tuning: method.tuning,
linkage: method.linkage, linkage: method.linkage,
is_kernel: method.is_kernel, is_kernel: method.is_kernel,
flush_to_zero_f32: method.flush_to_zero_f32, flush_to_zero_f32: method.flush_to_zero_f32,
flush_to_zero_f16f64: method.flush_to_zero_f16f64, flush_to_zero_f16f64: method.flush_to_zero_f16f64,
rounding_mode_f32: method.rounding_mode_f32, rounding_mode_f32: method.rounding_mode_f32,
rounding_mode_f16f64: method.rounding_mode_f16f64, rounding_mode_f16f64: method.rounding_mode_f16f64,
}) })
} }
fn run_statement<'input>( fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
result: &mut Vec<UnconditionalStatement>, result: &mut Vec<UnconditionalStatement>,
statement: NormalizedStatement, statement: NormalizedStatement,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
Ok(match statement { Ok(match statement {
Statement::Label(label) => result.push(Statement::Label(label)), Statement::Label(label) => result.push(Statement::Label(label)),
Statement::Variable(var) => result.push(Statement::Variable(var)), Statement::Variable(var) => result.push(Statement::Variable(var)),
Statement::Instruction((predicate, instruction)) => { Statement::Instruction((predicate, instruction)) => {
if let Some(pred) = predicate { if let Some(pred) = predicate {
let if_true = resolver.register_unnamed(None); let if_true = resolver.register_unnamed(None);
let if_false = resolver.register_unnamed(None); let if_false = resolver.register_unnamed(None);
let folded_bra = match &instruction { let folded_bra = match &instruction {
ast::Instruction::Bra { arguments, .. } => Some(arguments.src), ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
_ => None, _ => None,
}; };
let mut branch = BrachCondition { let mut branch = BrachCondition {
predicate: pred.label, predicate: pred.label,
if_true: folded_bra.unwrap_or(if_true), if_true: folded_bra.unwrap_or(if_true),
if_false, if_false,
}; };
if pred.not { if pred.not {
std::mem::swap(&mut branch.if_true, &mut branch.if_false); std::mem::swap(&mut branch.if_true, &mut branch.if_false);
} }
result.push(Statement::Conditional(branch)); result.push(Statement::Conditional(branch));
if folded_bra.is_none() { if folded_bra.is_none() {
result.push(Statement::Label(if_true)); result.push(Statement::Label(if_true));
result.push(Statement::Instruction(instruction)); result.push(Statement::Instruction(instruction));
} }
result.push(Statement::Label(if_false)); result.push(Statement::Label(if_false));
} else { } else {
result.push(Statement::Instruction(instruction)); result.push(Statement::Instruction(instruction));
} }
} }
_ => return Err(error_unreachable()), _ => return Err(error_unreachable()),
}) })
} }

View File

@ -1,268 +1,268 @@
use super::*; use super::*;
pub(super) fn run<'input>( pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut fn_declarations = FxHashMap::default(); let mut fn_declarations = FxHashMap::default();
let remapped_directives = directives let remapped_directives = directives
.into_iter() .into_iter()
.map(|directive| run_directive(resolver, &mut fn_declarations, directive)) .map(|directive| run_directive(resolver, &mut fn_declarations, directive))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let mut result = fn_declarations let mut result = fn_declarations
.into_iter() .into_iter()
.map(|(_, (return_arguments, name, input_arguments))| { .map(|(_, (return_arguments, name, input_arguments))| {
Directive2::Method(Function2 { Directive2::Method(Function2 {
return_arguments, return_arguments,
name: name, name: name,
input_arguments, input_arguments,
body: None, body: None,
import_as: None, import_as: None,
tuning: Vec::new(), tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN, linkage: ast::LinkingDirective::EXTERN,
is_kernel: false, is_kernel: false,
flush_to_zero_f32: false, flush_to_zero_f32: false,
flush_to_zero_f16f64: false, flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
}) })
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
result.extend(remapped_directives); result.extend(remapped_directives);
Ok(result) Ok(result)
} }
fn run_directive<'input>( fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap< fn_declarations: &mut FxHashMap<
Cow<'input, str>, Cow<'input, str>,
( (
Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>,
SpirvWord, SpirvWord,
Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>,
), ),
>, >,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>, directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive { Ok(match directive {
var @ Directive2::Variable(..) => var, var @ Directive2::Variable(..) => var,
Directive2::Method(mut method) => { Directive2::Method(mut method) => {
method.body = method method.body = method
.body .body
.map(|statements| run_statements(resolver, fn_declarations, statements)) .map(|statements| run_statements(resolver, fn_declarations, statements))
.transpose()?; .transpose()?;
Directive2::Method(method) Directive2::Method(method)
} }
}) })
} }
fn run_statements<'input>( fn run_statements<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap< fn_declarations: &mut FxHashMap<
Cow<'input, str>, Cow<'input, str>,
( (
Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>,
SpirvWord, SpirvWord,
Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>,
), ),
>, >,
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
statements statements
.into_iter() .into_iter()
.map(|statement| { .map(|statement| {
Ok(match statement { Ok(match statement {
Statement::Instruction(instruction) => { Statement::Instruction(instruction) => {
Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?) Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
} }
s => s, s => s,
}) })
}) })
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
} }
fn run_instruction<'input>( fn run_instruction<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap< fn_declarations: &mut FxHashMap<
Cow<'input, str>, Cow<'input, str>,
( (
Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>,
SpirvWord, SpirvWord,
Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>,
), ),
>, >,
instruction: ptx_parser::Instruction<SpirvWord>, instruction: ptx_parser::Instruction<SpirvWord>,
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> { ) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
Ok(match instruction { Ok(match instruction {
i @ ptx_parser::Instruction::Sqrt { i @ ptx_parser::Instruction::Sqrt {
data: data:
ast::RcpData { ast::RcpData {
kind: ast::RcpKind::Approx, kind: ast::RcpKind::Approx,
type_: ast::ScalarType::F32, type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false), flush_to_zero: None | Some(false),
}, },
.. ..
} => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?, } => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Rsqrt { i @ ptx_parser::Instruction::Rsqrt {
data: data:
ast::TypeFtz { ast::TypeFtz {
type_: ast::ScalarType::F32, type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false), flush_to_zero: None | Some(false),
}, },
.. ..
} => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?, } => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Rcp { i @ ptx_parser::Instruction::Rcp {
data: data:
ast::RcpData { ast::RcpData {
kind: ast::RcpKind::Approx, kind: ast::RcpKind::Approx,
type_: ast::ScalarType::F32, type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false), flush_to_zero: None | Some(false),
}, },
.. ..
} => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?, } => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Ex2 { i @ ptx_parser::Instruction::Ex2 {
data: data:
ast::TypeFtz { ast::TypeFtz {
type_: ast::ScalarType::F32, type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false), flush_to_zero: None | Some(false),
}, },
.. ..
} => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?, } => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Lg2 { i @ ptx_parser::Instruction::Lg2 {
data: ast::FlushToZero { data: ast::FlushToZero {
flush_to_zero: false, flush_to_zero: false,
}, },
.. ..
} => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?, } => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Activemask { .. } => { i @ ptx_parser::Instruction::Activemask { .. } => {
to_call(resolver, fn_declarations, "activemask".into(), i)? to_call(resolver, fn_declarations, "activemask".into(), i)?
} }
i @ ptx_parser::Instruction::Bfe { data, .. } => { i @ ptx_parser::Instruction::Bfe { data, .. } => {
let name = ["bfe_", scalar_to_ptx_name(data)].concat(); let name = ["bfe_", scalar_to_ptx_name(data)].concat();
to_call(resolver, fn_declarations, name.into(), i)? to_call(resolver, fn_declarations, name.into(), i)?
} }
i @ ptx_parser::Instruction::Bfi { data, .. } => { i @ ptx_parser::Instruction::Bfi { data, .. } => {
let name = ["bfi_", scalar_to_ptx_name(data)].concat(); let name = ["bfi_", scalar_to_ptx_name(data)].concat();
to_call(resolver, fn_declarations, name.into(), i)? to_call(resolver, fn_declarations, name.into(), i)?
} }
i @ ptx_parser::Instruction::Bar { .. } => { i @ ptx_parser::Instruction::Bar { .. } => {
to_call(resolver, fn_declarations, "bar_sync".into(), i)? to_call(resolver, fn_declarations, "bar_sync".into(), i)?
} }
ptx_parser::Instruction::BarRed { data, arguments } => { ptx_parser::Instruction::BarRed { data, arguments } => {
if arguments.src_threadcount.is_some() { if arguments.src_threadcount.is_some() {
return Err(error_todo()); return Err(error_todo());
} }
let name = match data.pred_reduction { let name = match data.pred_reduction {
ptx_parser::Reduction::And => "bar_red_and_pred", ptx_parser::Reduction::And => "bar_red_and_pred",
ptx_parser::Reduction::Or => "bar_red_or_pred", ptx_parser::Reduction::Or => "bar_red_or_pred",
}; };
to_call( to_call(
resolver, resolver,
fn_declarations, fn_declarations,
name.into(), name.into(),
ptx_parser::Instruction::BarRed { data, arguments }, ptx_parser::Instruction::BarRed { data, arguments },
)? )?
} }
ptx_parser::Instruction::ShflSync { data, arguments } => { ptx_parser::Instruction::ShflSync { data, arguments } => {
let mode = match data.mode { let mode = match data.mode {
ptx_parser::ShuffleMode::Up => "up", ptx_parser::ShuffleMode::Up => "up",
ptx_parser::ShuffleMode::Down => "down", ptx_parser::ShuffleMode::Down => "down",
ptx_parser::ShuffleMode::BFly => "bfly", ptx_parser::ShuffleMode::BFly => "bfly",
ptx_parser::ShuffleMode::Idx => "idx", ptx_parser::ShuffleMode::Idx => "idx",
}; };
let pred = if arguments.dst_pred.is_some() { let pred = if arguments.dst_pred.is_some() {
"_pred" "_pred"
} else { } else {
"" ""
}; };
to_call( to_call(
resolver, resolver,
fn_declarations, fn_declarations,
format!("shfl_sync_{}_b32{}", mode, pred).into(), format!("shfl_sync_{}_b32{}", mode, pred).into(),
ptx_parser::Instruction::ShflSync { data, arguments }, ptx_parser::Instruction::ShflSync { data, arguments },
)? )?
} }
i @ ptx_parser::Instruction::Nanosleep { .. } => { i @ ptx_parser::Instruction::Nanosleep { .. } => {
to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)? to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)?
} }
i => i, i => i,
}) })
} }
fn to_call<'input>( fn to_call<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap< fn_declarations: &mut FxHashMap<
Cow<'input, str>, Cow<'input, str>,
( (
Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>,
SpirvWord, SpirvWord,
Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>,
), ),
>, >,
name: Cow<'input, str>, name: Cow<'input, str>,
i: ast::Instruction<SpirvWord>, i: ast::Instruction<SpirvWord>,
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> { ) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
let mut data_return = Vec::new(); let mut data_return = Vec::new();
let mut data_input = Vec::new(); let mut data_input = Vec::new();
let mut arguments_return = Vec::new(); let mut arguments_return = Vec::new();
let mut arguments_input = Vec::new(); let mut arguments_input = Vec::new();
ast::visit(&i, &mut |name: &SpirvWord, ast::visit(&i, &mut |name: &SpirvWord,
type_space: Option<( type_space: Option<(
&ptx_parser::Type, &ptx_parser::Type,
ptx_parser::StateSpace, ptx_parser::StateSpace,
)>, )>,
is_dst: bool, is_dst: bool,
_: bool| { _: bool| {
let (type_, space) = type_space.ok_or_else(error_mismatched_type)?; let (type_, space) = type_space.ok_or_else(error_mismatched_type)?;
if is_dst { if is_dst {
data_return.push((type_.clone(), space)); data_return.push((type_.clone(), space));
arguments_return.push(*name); arguments_return.push(*name);
} else { } else {
data_input.push((type_.clone(), space)); data_input.push((type_.clone(), space));
arguments_input.push(*name); arguments_input.push(*name);
}; };
Ok::<_, TranslateError>(()) Ok::<_, TranslateError>(())
})?; })?;
let fn_name = match fn_declarations.entry(name) { let fn_name = match fn_declarations.entry(name) {
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1, hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
hash_map::Entry::Vacant(vacant_entry) => { hash_map::Entry::Vacant(vacant_entry) => {
let name = vacant_entry.key().clone(); let name = vacant_entry.key().clone();
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat(); let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
let name = resolver.register_named(Cow::Owned(full_name.clone()), None); let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
vacant_entry.insert(( vacant_entry.insert((
to_variables(resolver, &data_return), to_variables(resolver, &data_return),
name, name,
to_variables(resolver, &data_input), to_variables(resolver, &data_input),
)); ));
name name
} }
}; };
Ok(ast::Instruction::Call { Ok(ast::Instruction::Call {
data: ptx_parser::CallDetails { data: ptx_parser::CallDetails {
uniform: false, uniform: false,
return_arguments: data_return, return_arguments: data_return,
input_arguments: data_input, input_arguments: data_input,
}, },
arguments: ptx_parser::CallArgs { arguments: ptx_parser::CallArgs {
return_arguments: arguments_return, return_arguments: arguments_return,
func: fn_name, func: fn_name,
input_arguments: arguments_input, input_arguments: arguments_input,
}, },
}) })
} }
fn to_variables<'input>( fn to_variables<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>, arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
) -> Vec<ptx_parser::Variable<SpirvWord>> { ) -> Vec<ptx_parser::Variable<SpirvWord>> {
arguments arguments
.iter() .iter()
.map(|(type_, space)| ast::Variable { .map(|(type_, space)| ast::Variable {
align: None, align: None,
v_type: type_.clone(), v_type: type_.clone(),
state_space: *space, state_space: *space,
name: resolver.register_unnamed(Some((type_.clone(), *space))), name: resolver.register_unnamed(Some((type_.clone(), *space))),
array_init: Vec::new(), array_init: Vec::new(),
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }

View File

@ -1,33 +1,33 @@
use std::borrow::Cow; use std::borrow::Cow;
use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord}; use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
pub(crate) fn run<'input>( pub(crate) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
mut directives: Vec<NormalizedDirective2>, mut directives: Vec<NormalizedDirective2>,
) -> Vec<NormalizedDirective2> { ) -> Vec<NormalizedDirective2> {
for directive in directives.iter_mut() { for directive in directives.iter_mut() {
match directive { match directive {
NormalizedDirective2::Method(func) => { NormalizedDirective2::Method(func) => {
replace_with_ptx_impl(resolver, func.name); replace_with_ptx_impl(resolver, func.name);
} }
_ => {} _ => {}
} }
} }
directives directives
} }
fn replace_with_ptx_impl<'input>( fn replace_with_ptx_impl<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
fn_name: SpirvWord, fn_name: SpirvWord,
) { ) {
let known_names = ["__assertfail"]; let known_names = ["__assertfail"];
if let Some(super::IdentEntry { if let Some(super::IdentEntry {
name: Some(name), .. name: Some(name), ..
}) = resolver.ident_map.get_mut(&fn_name) }) = resolver.ident_map.get_mut(&fn_name)
{ {
if known_names.contains(&&**name) { if known_names.contains(&&**name) {
*name = Cow::Owned(format!("__zluda_ptx_impl_{}", name)); *name = Cow::Owned(format!("__zluda_ptx_impl_{}", name));
} }
} }
} }

View File

@ -1,69 +1,69 @@
use super::*; use super::*;
use ptx_parser as ast; use ptx_parser as ast;
use rustc_hash::FxHashSet; use rustc_hash::FxHashSet;
pub(crate) fn run<'input>( pub(crate) fn run<'input>(
directives: Vec<UnconditionalDirective>, directives: Vec<UnconditionalDirective>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> { ) -> Result<Vec<UnconditionalDirective>, TranslateError> {
let mut functions = FxHashSet::default(); let mut functions = FxHashSet::default();
directives directives
.into_iter() .into_iter()
.map(|directive| run_directive(&mut functions, directive)) .map(|directive| run_directive(&mut functions, directive))
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
} }
fn run_directive<'input>( fn run_directive<'input>(
functions: &mut FxHashSet<SpirvWord>, functions: &mut FxHashSet<SpirvWord>,
directive: UnconditionalDirective, directive: UnconditionalDirective,
) -> Result<UnconditionalDirective, TranslateError> { ) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive { Ok(match directive {
var @ Directive2::Variable(..) => var, var @ Directive2::Variable(..) => var,
Directive2::Method(method) => { Directive2::Method(method) => {
if !method.is_kernel { if !method.is_kernel {
functions.insert(method.name); functions.insert(method.name);
} }
Directive2::Method(run_method(functions, method)?) Directive2::Method(run_method(functions, method)?)
} }
}) })
} }
fn run_method<'input>( fn run_method<'input>(
functions: &mut FxHashSet<SpirvWord>, functions: &mut FxHashSet<SpirvWord>,
method: UnconditionalFunction, method: UnconditionalFunction,
) -> Result<UnconditionalFunction, TranslateError> { ) -> Result<UnconditionalFunction, TranslateError> {
let body = method let body = method
.body .body
.map(|statements| { .map(|statements| {
statements statements
.into_iter() .into_iter()
.map(|statement| run_statement(functions, statement)) .map(|statement| run_statement(functions, statement))
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
}) })
.transpose()?; .transpose()?;
Ok(Function2 { body, ..method }) Ok(Function2 { body, ..method })
} }
fn run_statement<'input>( fn run_statement<'input>(
functions: &mut FxHashSet<SpirvWord>, functions: &mut FxHashSet<SpirvWord>,
statement: UnconditionalStatement, statement: UnconditionalStatement,
) -> Result<UnconditionalStatement, TranslateError> { ) -> Result<UnconditionalStatement, TranslateError> {
Ok(match statement { Ok(match statement {
Statement::Instruction(ast::Instruction::Mov { Statement::Instruction(ast::Instruction::Mov {
data, data,
arguments: arguments:
ast::MovArgs { ast::MovArgs {
dst: ast::ParsedOperand::Reg(dst_reg), dst: ast::ParsedOperand::Reg(dst_reg),
src: ast::ParsedOperand::Reg(src_reg), src: ast::ParsedOperand::Reg(src_reg),
}, },
}) if functions.contains(&src_reg) => { }) if functions.contains(&src_reg) => {
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
return Err(error_mismatched_type()); return Err(error_mismatched_type());
} }
UnconditionalStatement::FunctionPointer(FunctionPointerDetails { UnconditionalStatement::FunctionPointer(FunctionPointerDetails {
dst: dst_reg, dst: dst_reg,
src: src_reg, src: src_reg,
}) })
} }
s => s, s => s,
}) })
} }

View File

@ -1,327 +1,327 @@
use bpaf::{Args, Bpaf, Parser}; use bpaf::{Args, Bpaf, Parser};
use cargo_metadata::{MetadataCommand, Package}; use cargo_metadata::{MetadataCommand, Package};
use serde::Deserialize; use serde::Deserialize;
use std::{env, ffi::OsString, path::PathBuf, process::Command}; use std::{env, ffi::OsString, path::PathBuf, process::Command};
#[derive(Debug, Clone, Bpaf)] #[derive(Debug, Clone, Bpaf)]
#[bpaf(options)] #[bpaf(options)]
enum Options { enum Options {
#[bpaf(command)] #[bpaf(command)]
/// Compile ZLUDA (default command) /// Compile ZLUDA (default command)
Build(#[bpaf(external(build))] Build), Build(#[bpaf(external(build))] Build),
#[bpaf(command)] #[bpaf(command)]
/// Compile ZLUDA and build a package /// Compile ZLUDA and build a package
Zip(#[bpaf(external(build))] Build), Zip(#[bpaf(external(build))] Build),
} }
#[derive(Debug, Clone, Bpaf)] #[derive(Debug, Clone, Bpaf)]
struct Build { struct Build {
#[bpaf(any("CARGO", not_help), many)] #[bpaf(any("CARGO", not_help), many)]
/// Arguments to pass to cargo, e.g. `--release` for release build /// Arguments to pass to cargo, e.g. `--release` for release build
cargo_arguments: Vec<OsString>, cargo_arguments: Vec<OsString>,
} }
/// Returns `None` for the help flags `-h` and `--help`, and gives every other argument back
/// unchanged inside `Some`.
fn not_help(s: OsString) -> Option<OsString> {
    let is_help_flag = ["-h", "--help"].iter().any(|flag| s == *flag);
    (!is_help_flag).then_some(s)
}
// We need to sniff out some args passed to cargo to understand how to create // We need to sniff out some args passed to cargo to understand how to create
// symlinks (should they go into `target/debug`, `target/release` or custom) // symlinks (should they go into `target/debug`, `target/release` or custom)
#[derive(Debug, Clone, Bpaf)] #[derive(Debug, Clone, Bpaf)]
struct Cargo { struct Cargo {
#[bpaf(switch, long, short)] #[bpaf(switch, long, short)]
release: Option<bool>, release: Option<bool>,
#[bpaf(long)] #[bpaf(long)]
profile: Option<String>, profile: Option<String>,
#[bpaf(any("", Some), many)] #[bpaf(any("", Some), many)]
_unused: Vec<OsString>, _unused: Vec<OsString>,
} }
struct Project { struct Project {
name: String, name: String,
target_name: String, target_name: String,
target_kind: ProjectTarget, target_kind: ProjectTarget,
meta: ZludaMetadata, meta: ZludaMetadata,
} }
impl Project { impl Project {
fn try_new(p: Package) -> Option<Project> { fn try_new(p: Package) -> Option<Project> {
let name = p.name; let name = p.name;
serde_json::from_value::<Option<Metadata>>(p.metadata) serde_json::from_value::<Option<Metadata>>(p.metadata)
.unwrap() .unwrap()
.map(|m| { .map(|m| {
let (target_name, target_kind) = p let (target_name, target_kind) = p
.targets .targets
.into_iter() .into_iter()
.find_map(|target| { .find_map(|target| {
if target.is_cdylib() { if target.is_cdylib() {
Some((target.name, ProjectTarget::Cdylib)) Some((target.name, ProjectTarget::Cdylib))
} else if target.is_bin() { } else if target.is_bin() {
Some((target.name, ProjectTarget::Bin)) Some((target.name, ProjectTarget::Bin))
} else { } else {
None None
} }
}) })
.unwrap(); .unwrap();
Self { Self {
name, name,
target_name, target_name,
target_kind, target_kind,
meta: m.zluda, meta: m.zluda,
} }
}) })
} }
#[cfg(unix)] #[cfg(unix)]
fn prefix(&self) -> &'static str { fn prefix(&self) -> &'static str {
match self.target_kind { match self.target_kind {
ProjectTarget::Bin => "", ProjectTarget::Bin => "",
ProjectTarget::Cdylib => "lib", ProjectTarget::Cdylib => "lib",
} }
} }
#[cfg(not(unix))] #[cfg(not(unix))]
fn prefix(&self) -> &'static str { fn prefix(&self) -> &'static str {
"" ""
} }
#[cfg(unix)] #[cfg(unix)]
fn suffix(&self) -> &'static str { fn suffix(&self) -> &'static str {
match self.target_kind { match self.target_kind {
ProjectTarget::Bin => "", ProjectTarget::Bin => "",
ProjectTarget::Cdylib => ".so", ProjectTarget::Cdylib => ".so",
} }
} }
#[cfg(not(unix))] #[cfg(not(unix))]
fn suffix(&self) -> &'static str { fn suffix(&self) -> &'static str {
match self.target_kind { match self.target_kind {
ProjectTarget::Bin => ".exe", ProjectTarget::Bin => ".exe",
ProjectTarget::Cdylib => ".dll", ProjectTarget::Cdylib => ".dll",
} }
} }
// Returns tuple: // Returns tuple:
// * symlink file path (relative to the root of build dir) // * symlink file path (relative to the root of build dir)
// * symlink absolute file path // * symlink absolute file path
// * target actual file (relative to symlink file) // * target actual file (relative to symlink file)
#[cfg_attr(not(unix), allow(unused))] #[cfg_attr(not(unix), allow(unused))]
fn symlinks<'a>( fn symlinks<'a>(
&'a self, &'a self,
target_dir: &'a PathBuf, target_dir: &'a PathBuf,
profile: &'a str, profile: &'a str,
libname: &'a str, libname: &'a str,
) -> impl Iterator<Item = (&'a str, PathBuf, PathBuf)> + 'a { ) -> impl Iterator<Item = (&'a str, PathBuf, PathBuf)> + 'a {
self.meta.linux_symlinks.iter().map(move |source| { self.meta.linux_symlinks.iter().map(move |source| {
let mut link = target_dir.clone(); let mut link = target_dir.clone();
link.extend([profile, source]); link.extend([profile, source]);
let relative_link = PathBuf::from(source); let relative_link = PathBuf::from(source);
let ancestors = relative_link.as_path().ancestors().count(); let ancestors = relative_link.as_path().ancestors().count();
let mut target = std::iter::repeat_with(|| "../").take(ancestors - 2).fold( let mut target = std::iter::repeat_with(|| "../").take(ancestors - 2).fold(
PathBuf::new(), PathBuf::new(),
|mut buff, segment| { |mut buff, segment| {
buff.push(segment); buff.push(segment);
buff buff
}, },
); );
target.push(libname); target.push(libname);
(&**source, link, target) (&**source, link, target)
}) })
} }
fn file_name(&self) -> String { fn file_name(&self) -> String {
let target_name = &self.target_name; let target_name = &self.target_name;
let prefix = self.prefix(); let prefix = self.prefix();
let suffix = self.suffix(); let suffix = self.suffix();
format!("{prefix}{target_name}{suffix}") format!("{prefix}{target_name}{suffix}")
} }
} }
/// Kind of build artifact a project produces.
#[derive(Clone, Copy)]
enum ProjectTarget {
    /// A dynamic library (`cdylib` target).
    Cdylib,
    /// An executable (`bin` target).
    Bin,
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct Metadata { struct Metadata {
zluda: ZludaMetadata, zluda: ZludaMetadata,
} }
/// Per-package build settings. Every field is optional and unknown keys are rejected.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct ZludaMetadata {
    /// When set, the package is skipped unless the tool itself is compiled for Windows.
    #[serde(default)]
    windows_only: bool,
    /// When set, the package is skipped unless the selected profile is `debug`.
    #[serde(default)]
    debug_only: bool,
    /// Paths, relative to the profile directory, of symlinks that should point at the built
    /// artifact.
    #[cfg_attr(not(unix), allow(unused))]
    #[serde(default)]
    linux_symlinks: Vec<String>,
}
fn main() { fn main() {
let options = match options().run_inner(Args::current_args()) { let options = match options().run_inner(Args::current_args()) {
Ok(b) => b, Ok(b) => b,
Err(err) => match build().to_options().run_inner(Args::current_args()) { Err(err) => match build().to_options().run_inner(Args::current_args()) {
Ok(b) => Options::Build(b), Ok(b) => Options::Build(b),
Err(_) => { Err(_) => {
err.print_message(100); err.print_message(100);
std::process::exit(err.exit_code()); std::process::exit(err.exit_code());
} }
}, },
}; };
match options { match options {
Options::Build(b) => { Options::Build(b) => {
compile(b); compile(b);
} }
Options::Zip(b) => zip(b), Options::Zip(b) => zip(b),
} }
} }
fn compile(b: Build) -> (PathBuf, String, Vec<Project>) { fn compile(b: Build) -> (PathBuf, String, Vec<Project>) {
let profile = sniff_out_profile_name(&b.cargo_arguments); let profile = sniff_out_profile_name(&b.cargo_arguments);
let meta = MetadataCommand::new().no_deps().exec().unwrap(); let meta = MetadataCommand::new().no_deps().exec().unwrap();
let target_directory = meta.target_directory.into_std_path_buf(); let target_directory = meta.target_directory.into_std_path_buf();
let projects = meta let projects = meta
.packages .packages
.into_iter() .into_iter()
.filter_map(Project::try_new) .filter_map(Project::try_new)
.filter(|project| { .filter(|project| {
if project.meta.windows_only && cfg!(not(windows)) { if project.meta.windows_only && cfg!(not(windows)) {
return false; return false;
} }
if project.meta.debug_only && profile != "debug" { if project.meta.debug_only && profile != "debug" {
return false; return false;
} }
true true
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let cargo = env::var("CARGO").unwrap_or_else(|_| "cargo".to_string()); let cargo = env::var("CARGO").unwrap_or_else(|_| "cargo".to_string());
let mut command = Command::new(&cargo); let mut command = Command::new(&cargo);
command.arg("build"); command.arg("build");
command.arg("--locked"); command.arg("--locked");
for project in projects.iter() { for project in projects.iter() {
command.arg("--package"); command.arg("--package");
command.arg(&project.name); command.arg(&project.name);
} }
command.args(b.cargo_arguments); command.args(b.cargo_arguments);
assert!(command.status().unwrap().success()); assert!(command.status().unwrap().success());
os::make_symlinks(&target_directory, &*projects, &*profile); os::make_symlinks(&target_directory, &*projects, &*profile);
(target_directory, profile, projects) (target_directory, profile, projects)
} }
fn sniff_out_profile_name(b: &[OsString]) -> String { fn sniff_out_profile_name(b: &[OsString]) -> String {
let parsed_cargo_arguments = cargo().to_options().run_inner(b); let parsed_cargo_arguments = cargo().to_options().run_inner(b);
match parsed_cargo_arguments { match parsed_cargo_arguments {
Ok(Cargo { Ok(Cargo {
release: Some(true), release: Some(true),
.. ..
}) => "release".to_string(), }) => "release".to_string(),
Ok(Cargo { Ok(Cargo {
profile: Some(profile), profile: Some(profile),
.. ..
}) => profile, }) => profile,
_ => "debug".to_string(), _ => "debug".to_string(),
} }
} }
fn zip(zip: Build) { fn zip(zip: Build) {
let (target_dir, profile, projects) = compile(zip); let (target_dir, profile, projects) = compile(zip);
os::zip(target_dir, profile, projects) os::zip(target_dir, profile, projects)
} }
#[cfg(unix)] #[cfg(unix)]
mod os { mod os {
use flate2::write::GzEncoder; use flate2::write::GzEncoder;
use flate2::Compression; use flate2::Compression;
use std::{ use std::{
fs::{self, File}, fs::{self, File},
path::PathBuf, path::PathBuf,
}; };
use tar::Header; use tar::Header;
pub fn make_symlinks( pub fn make_symlinks(
target_directory: &std::path::PathBuf, target_directory: &std::path::PathBuf,
projects: &[super::Project], projects: &[super::Project],
profile: &str, profile: &str,
) { ) {
use std::os::unix::fs as unix_fs; use std::os::unix::fs as unix_fs;
for project in projects.iter() { for project in projects.iter() {
let libname = project.file_name(); let libname = project.file_name();
for (_, full_path, target) in project.symlinks(target_directory, profile, &libname) { for (_, full_path, target) in project.symlinks(target_directory, profile, &libname) {
let mut dir = full_path.clone(); let mut dir = full_path.clone();
assert!(dir.pop()); assert!(dir.pop());
fs::create_dir_all(dir).unwrap(); fs::create_dir_all(dir).unwrap();
fs::remove_file(&full_path).ok(); fs::remove_file(&full_path).ok();
unix_fs::symlink(&target, full_path).unwrap(); unix_fs::symlink(&target, full_path).unwrap();
} }
} }
} }
pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) { pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
let tar_gz = let tar_gz =
File::create(format!("{}/{profile}/zluda.tar.gz", target_dir.display())).unwrap(); File::create(format!("{}/{profile}/zluda.tar.gz", target_dir.display())).unwrap();
let enc = GzEncoder::new(tar_gz, Compression::default()); let enc = GzEncoder::new(tar_gz, Compression::default());
let mut tar = tar::Builder::new(enc); let mut tar = tar::Builder::new(enc);
for project in projects.iter() { for project in projects.iter() {
let file_name = project.file_name(); let file_name = project.file_name();
let mut file = let mut file =
File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap(); File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap();
tar.append_file(format!("zluda/{file_name}"), &mut file) tar.append_file(format!("zluda/{file_name}"), &mut file)
.unwrap(); .unwrap();
for (source, full_path, target) in project.symlinks(&target_dir, &profile, &file_name) { for (source, full_path, target) in project.symlinks(&target_dir, &profile, &file_name) {
let mut header = Header::new_gnu(); let mut header = Header::new_gnu();
let meta = fs::symlink_metadata(&full_path).unwrap(); let meta = fs::symlink_metadata(&full_path).unwrap();
header.set_metadata(&meta); header.set_metadata(&meta);
tar.append_link(&mut header, format!("zluda/{source}"), target) tar.append_link(&mut header, format!("zluda/{source}"), target)
.unwrap(); .unwrap();
} }
} }
tar.finish().unwrap(); tar.finish().unwrap();
} }
} }
/// Non-unix implementation of the packaging steps: no symlinks, and a `.zip` archive.
#[cfg(not(unix))]
mod os {
    use std::{fs::File, io, path::PathBuf};
    use zip::{write::SimpleFileOptions, ZipWriter};

    /// Does nothing: symlinks are only created on unix targets.
    pub fn make_symlinks(
        _target_directory: &std::path::PathBuf,
        _projects: &[super::Project],
        _profile: &str,
    ) {
    }

    /// Writes `<target_dir>/<profile>/zluda.zip` containing every built artifact under a
    /// top-level `zluda` directory.
    ///
    /// # Panics
    ///
    /// Panics on any I/O or archive error.
    pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
        let archive = File::create(format!("{}/{profile}/zluda.zip", target_dir.display())).unwrap();
        let mut zip = ZipWriter::new(archive);
        zip.add_directory("zluda", SimpleFileOptions::default())
            .unwrap();
        for project in &projects {
            let file_name = project.file_name();
            let mut artifact =
                File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap();
            // Fall back to default options when the modification time is unavailable.
            let entry_options = file_options_from_time(&artifact).unwrap_or_default();
            zip.start_file(format!("zluda/{file_name}"), entry_options)
                .unwrap();
            io::copy(&mut artifact, &mut zip).unwrap();
        }
        zip.finish().unwrap();
    }

    /// Builds archive entry options whose last-modified time is copied from `from`.
    ///
    /// Fails when the file's modification time cannot be read or cannot be converted into a
    /// `zip::DateTime`.
    fn file_options_from_time(from: &File) -> io::Result<SimpleFileOptions> {
        let modified = time::OffsetDateTime::from(from.metadata()?.modified()?);
        let timestamp = zip::DateTime::try_from(modified).map_err(io::Error::other)?;
        Ok(SimpleFileOptions::default().last_modified_time(timestamp))
    }
}

View File

@ -1,426 +1,426 @@
use super::{FromCuda, LiveCheck}; use super::{FromCuda, LiveCheck};
use crate::r#impl::{context, device}; use crate::r#impl::{context, device};
use comgr::Comgr; use comgr::Comgr;
use cuda_types::cuda::*; use cuda_types::cuda::*;
use hip_runtime_sys::*; use hip_runtime_sys::*;
use std::{ use std::{
ffi::{c_void, CStr, CString}, ffi::{c_void, CStr, CString},
mem, ptr, slice, mem, ptr, slice,
sync::OnceLock, sync::OnceLock,
usize, usize,
}; };
#[cfg_attr(windows, path = "os_win.rs")] #[cfg_attr(windows, path = "os_win.rs")]
#[cfg_attr(not(windows), path = "os_unix.rs")] #[cfg_attr(not(windows), path = "os_unix.rs")]
mod os; mod os;
pub(crate) struct GlobalState { pub(crate) struct GlobalState {
pub devices: Vec<Device>, pub devices: Vec<Device>,
pub comgr: Comgr, pub comgr: Comgr,
} }
pub(crate) struct Device { pub(crate) struct Device {
pub(crate) _comgr_isa: CString, pub(crate) _comgr_isa: CString,
primary_context: LiveCheck<context::Context>, primary_context: LiveCheck<context::Context>,
} }
impl Device { impl Device {
pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) { pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) {
unsafe { unsafe {
( (
self.primary_context.data.assume_init_ref(), self.primary_context.data.assume_init_ref(),
self.primary_context.as_handle(), self.primary_context.as_handle(),
) )
} }
} }
} }
pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> { pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> {
global_state()? global_state()?
.devices .devices
.get(dev as usize) .get(dev as usize)
.ok_or(CUerror::INVALID_DEVICE) .ok_or(CUerror::INVALID_DEVICE)
} }
pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> { pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new(); static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new();
fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] { fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] {
unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) } unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) }
} }
GLOBAL_STATE GLOBAL_STATE
.get_or_init(|| { .get_or_init(|| {
let mut device_count = 0; let mut device_count = 0;
unsafe { hipGetDeviceCount(&mut device_count) }?; unsafe { hipGetDeviceCount(&mut device_count) }?;
let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?; let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?;
Ok(GlobalState { Ok(GlobalState {
comgr, comgr,
devices: (0..device_count) devices: (0..device_count)
.map(|i| { .map(|i| {
let mut props = unsafe { mem::zeroed() }; let mut props = unsafe { mem::zeroed() };
unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?; unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?;
Ok::<_, CUerror>(Device { Ok::<_, CUerror>(Device {
_comgr_isa: CStr::from_bytes_until_nul(cast_slice( _comgr_isa: CStr::from_bytes_until_nul(cast_slice(
&props.gcnArchName[..], &props.gcnArchName[..],
)) ))
.map_err(|_| CUerror::UNKNOWN)? .map_err(|_| CUerror::UNKNOWN)?
.to_owned(), .to_owned(),
primary_context: LiveCheck::new(context::Context::new(i)), primary_context: LiveCheck::new(context::Context::new(i)),
}) })
}) })
.collect::<Result<Vec<_>, _>>()?, .collect::<Result<Vec<_>, _>>()?,
}) })
}) })
.as_ref() .as_ref()
.map_err(|e| *e) .map_err(|e| *e)
} }
/// Initializes HIP with `flags` and then makes sure the global state exists.
///
/// # Errors
///
/// Returns the error of `hipInit` or of `global_state`, whichever fails first.
pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult {
    let hip_status = unsafe { hipInit(flags) };
    hip_status?;
    global_state().map(|_state| ())
}
/// A zero-initialized array of `S` 32-bit words behind an `UnsafeCell`, so that a pointer to
/// its storage can be handed out from a `static`.
struct UnknownBuffer<const S: usize> {
    buffer: std::cell::UnsafeCell<[u32; S]>,
}

impl<const S: usize> UnknownBuffer<S> {
    /// Creates a buffer with every word set to zero.
    const fn new() -> Self {
        Self {
            buffer: std::cell::UnsafeCell::new([0u32; S]),
        }
    }

    /// Number of 32-bit words in the buffer (the const parameter `S`), not a byte count.
    const fn len(&self) -> usize {
        S
    }
}

// Required for placing the buffer in a `static`. Nothing in this type synchronizes access to
// the contents; that is left to whoever obtains the pointer.
unsafe impl<const S: usize> Sync for UnknownBuffer<S> {}
// Backing storage returned by `get_unknown_buffer1`: 1024 zeroed 32-bit words.
static UNKNOWN_BUFFER1: UnknownBuffer<1024> = UnknownBuffer::new();
// Backing storage returned by `get_unknown_buffer2`: 14 zeroed 32-bit words.
static UNKNOWN_BUFFER2: UnknownBuffer<14> = UnknownBuffer::new();

// Stateless type that carries this crate's `CudaDarkApi` implementation.
struct DarkApi {}
impl ::dark_api::cuda::CudaDarkApi for DarkApi { impl ::dark_api::cuda::CudaDarkApi for DarkApi {
unsafe extern "system" fn get_module_from_cubin( unsafe extern "system" fn get_module_from_cubin(
_module: *mut cuda_types::cuda::CUmodule, _module: *mut cuda_types::cuda::CUmodule,
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper, _fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
) -> cuda_types::cuda::CUresult { ) -> cuda_types::cuda::CUresult {
todo!() todo!()
} }
unsafe extern "system" fn cudart_interface_fn2( unsafe extern "system" fn cudart_interface_fn2(
pctx: *mut cuda_types::cuda::CUcontext, pctx: *mut cuda_types::cuda::CUcontext,
hip_dev: hipDevice_t, hip_dev: hipDevice_t,
) -> cuda_types::cuda::CUresult { ) -> cuda_types::cuda::CUresult {
let pctx = match pctx.as_mut() { let pctx = match pctx.as_mut() {
Some(p) => p, Some(p) => p,
None => return CUresult::ERROR_INVALID_VALUE, None => return CUresult::ERROR_INVALID_VALUE,
}; };
device::primary_context_retain(pctx, hip_dev) device::primary_context_retain(pctx, hip_dev)
} }
unsafe extern "system" fn get_module_from_cubin_ext1( unsafe extern "system" fn get_module_from_cubin_ext1(
_result: *mut cuda_types::cuda::CUmodule, _result: *mut cuda_types::cuda::CUmodule,
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper, _fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
_arg3: *mut std::ffi::c_void, _arg3: *mut std::ffi::c_void,
_arg4: *mut std::ffi::c_void, _arg4: *mut std::ffi::c_void,
_arg5: u32, _arg5: u32,
) -> cuda_types::cuda::CUresult { ) -> cuda_types::cuda::CUresult {
todo!() todo!()
} }
unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult { unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult {
todo!() todo!()
} }
unsafe extern "system" fn get_module_from_cubin_ext2( unsafe extern "system" fn get_module_from_cubin_ext2(
_fatbin_header: *const cuda_types::dark_api::FatbinHeader, _fatbin_header: *const cuda_types::dark_api::FatbinHeader,
_result: *mut cuda_types::cuda::CUmodule, _result: *mut cuda_types::cuda::CUmodule,
_arg3: *mut std::ffi::c_void, _arg3: *mut std::ffi::c_void,
_arg4: *mut std::ffi::c_void, _arg4: *mut std::ffi::c_void,
_arg5: u32, _arg5: u32,
) -> cuda_types::cuda::CUresult { ) -> cuda_types::cuda::CUresult {
todo!() todo!()
} }
unsafe extern "system" fn get_unknown_buffer1( unsafe extern "system" fn get_unknown_buffer1(
ptr: *mut *mut std::ffi::c_void, ptr: *mut *mut std::ffi::c_void,
size: *mut usize, size: *mut usize,
) -> () { ) -> () {
*ptr = UNKNOWN_BUFFER1.buffer.get() as *mut std::ffi::c_void; *ptr = UNKNOWN_BUFFER1.buffer.get() as *mut std::ffi::c_void;
*size = UNKNOWN_BUFFER1.len(); *size = UNKNOWN_BUFFER1.len();
} }
unsafe extern "system" fn get_unknown_buffer2( unsafe extern "system" fn get_unknown_buffer2(
ptr: *mut *mut std::ffi::c_void, ptr: *mut *mut std::ffi::c_void,
size: *mut usize, size: *mut usize,
) -> () { ) -> () {
*ptr = UNKNOWN_BUFFER2.buffer.get() as *mut std::ffi::c_void; *ptr = UNKNOWN_BUFFER2.buffer.get() as *mut std::ffi::c_void;
*size = UNKNOWN_BUFFER2.len(); *size = UNKNOWN_BUFFER2.len();
} }
unsafe extern "system" fn context_local_storage_put( unsafe extern "system" fn context_local_storage_put(
cu_ctx: CUcontext, cu_ctx: CUcontext,
key: *mut c_void, key: *mut c_void,
value: *mut c_void, value: *mut c_void,
dtor_cb: Option<extern "system" fn(CUcontext, *mut c_void, *mut c_void)>, dtor_cb: Option<extern "system" fn(CUcontext, *mut c_void, *mut c_void)>,
) -> CUresult { ) -> CUresult {
let _ctx = if cu_ctx.0 != ptr::null_mut() { let _ctx = if cu_ctx.0 != ptr::null_mut() {
cu_ctx cu_ctx
} else { } else {
let mut current_ctx: CUcontext = CUcontext(ptr::null_mut()); let mut current_ctx: CUcontext = CUcontext(ptr::null_mut());
context::get_current(&mut current_ctx)?; context::get_current(&mut current_ctx)?;
current_ctx current_ctx
}; };
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?; let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
ctx_obj.with_state_mut(|state: &mut context::ContextState| { ctx_obj.with_state_mut(|state: &mut context::ContextState| {
state.storage.insert( state.storage.insert(
key as usize, key as usize,
context::StorageData { context::StorageData {
value: value as usize, value: value as usize,
reset_cb: dtor_cb, reset_cb: dtor_cb,
handle: _ctx, handle: _ctx,
}, },
); );
Ok(()) Ok(())
})?; })?;
Ok(()) Ok(())
} }
unsafe extern "system" fn context_local_storage_delete( unsafe extern "system" fn context_local_storage_delete(
cu_ctx: CUcontext, cu_ctx: CUcontext,
key: *mut c_void, key: *mut c_void,
) -> CUresult { ) -> CUresult {
let ctx_obj: &context::Context = FromCuda::from_cuda(&cu_ctx)?; let ctx_obj: &context::Context = FromCuda::from_cuda(&cu_ctx)?;
ctx_obj.with_state_mut(|state: &mut context::ContextState| { ctx_obj.with_state_mut(|state: &mut context::ContextState| {
state.storage.remove(&(key as usize)); state.storage.remove(&(key as usize));
Ok(()) Ok(())
})?; })?;
Ok(()) Ok(())
} }
unsafe extern "system" fn context_local_storage_get( unsafe extern "system" fn context_local_storage_get(
value: *mut *mut c_void, value: *mut *mut c_void,
cu_ctx: CUcontext, cu_ctx: CUcontext,
key: *mut c_void, key: *mut c_void,
) -> CUresult { ) -> CUresult {
let mut _ctx: CUcontext; let mut _ctx: CUcontext;
if cu_ctx.0 == ptr::null_mut() { if cu_ctx.0 == ptr::null_mut() {
_ctx = context::get_current_context()?; _ctx = context::get_current_context()?;
} else { } else {
_ctx = cu_ctx _ctx = cu_ctx
}; };
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?; let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
ctx_obj.with_state(|state: &context::ContextState| { ctx_obj.with_state(|state: &context::ContextState| {
match state.storage.get(&(key as usize)) { match state.storage.get(&(key as usize)) {
Some(data) => *value = data.value as *mut c_void, Some(data) => *value = data.value as *mut c_void,
None => return CUresult::ERROR_INVALID_HANDLE, None => return CUresult::ERROR_INVALID_HANDLE,
} }
Ok(()) Ok(())
})?; })?;
Ok(()) Ok(())
} }
unsafe extern "system" fn ctx_create_v2_bypass( unsafe extern "system" fn ctx_create_v2_bypass(
_pctx: *mut cuda_types::cuda::CUcontext, _pctx: *mut cuda_types::cuda::CUcontext,
_flags: ::std::os::raw::c_uint, _flags: ::std::os::raw::c_uint,
_dev: cuda_types::cuda::CUdevice, _dev: cuda_types::cuda::CUdevice,
) -> cuda_types::cuda::CUresult { ) -> cuda_types::cuda::CUresult {
todo!() todo!()
} }
unsafe extern "system" fn heap_alloc( unsafe extern "system" fn heap_alloc(
_heap_alloc_record_ptr: *mut *const std::ffi::c_void, _heap_alloc_record_ptr: *mut *const std::ffi::c_void,
_arg2: usize, _arg2: usize,
_arg3: usize, _arg3: usize,
) -> cuda_types::cuda::CUresult { ) -> cuda_types::cuda::CUresult {
todo!() todo!()
} }
unsafe extern "system" fn heap_free( unsafe extern "system" fn heap_free(
_heap_alloc_record_ptr: *const std::ffi::c_void, _heap_alloc_record_ptr: *const std::ffi::c_void,
_arg2: *mut usize, _arg2: *mut usize,
) -> cuda_types::cuda::CUresult { ) -> cuda_types::cuda::CUresult {
todo!() todo!()
} }
unsafe extern "system" fn device_get_attribute_ext( unsafe extern "system" fn device_get_attribute_ext(
_dev: cuda_types::cuda::CUdevice, _dev: cuda_types::cuda::CUdevice,
_attribute: std::ffi::c_uint, _attribute: std::ffi::c_uint,
_unknown: std::ffi::c_int, _unknown: std::ffi::c_int,
_result: *mut [usize; 2], _result: *mut [usize; 2],
) -> cuda_types::cuda::CUresult { ) -> cuda_types::cuda::CUresult {
todo!() todo!()
} }
unsafe extern "system" fn device_get_something( unsafe extern "system" fn device_get_something(
_result: *mut std::ffi::c_uchar, _result: *mut std::ffi::c_uchar,
_dev: cuda_types::cuda::CUdevice, _dev: cuda_types::cuda::CUdevice,
) -> cuda_types::cuda::CUresult { ) -> cuda_types::cuda::CUresult {
todo!() todo!()
} }
unsafe extern "system" fn integrity_check( unsafe extern "system" fn integrity_check(
version: u32, version: u32,
unix_seconds: u64, unix_seconds: u64,
result: *mut [u64; 2], result: *mut [u64; 2],
) -> cuda_types::cuda::CUresult { ) -> cuda_types::cuda::CUresult {
let current_process = std::process::id(); let current_process = std::process::id();
let current_thread = os::current_thread(); let current_thread = os::current_thread();
let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast(); let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast();
let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast(); let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast();
let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1]; let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1];
let devices = get_device_hash_info()?; let devices = get_device_hash_info()?;
let device_count = devices.len() as u32; let device_count = devices.len() as u32;
let get_device = |dev| devices[dev as usize]; let get_device = |dev| devices[dev as usize];
let hash = ::dark_api::integrity_check( let hash = ::dark_api::integrity_check(
version, version,
unix_seconds, unix_seconds,
cuda_types::cuda::CUDA_VERSION, cuda_types::cuda::CUDA_VERSION,
current_process, current_process,
current_thread, current_thread,
integrity_check_table, integrity_check_table,
cudart_table, cudart_table,
fn_address, fn_address,
device_count, device_count,
get_device, get_device,
); );
*result = hash; *result = hash;
Ok(()) Ok(())
} }
unsafe extern "system" fn context_check( unsafe extern "system" fn context_check(
_ctx_in: cuda_types::cuda::CUcontext, _ctx_in: cuda_types::cuda::CUcontext,
result1: *mut u32, result1: *mut u32,
_result2: *mut *const std::ffi::c_void, _result2: *mut *const std::ffi::c_void,
) -> cuda_types::cuda::CUresult { ) -> cuda_types::cuda::CUresult {
*result1 = 0; *result1 = 0;
CUresult::SUCCESS CUresult::SUCCESS
} }
/// Always returns `0`.
unsafe extern "system" fn check_fn3() -> u32 {
    0_u32
}
} }
fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> { fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
let mut device_count = 0; let mut device_count = 0;
device::get_count(&mut device_count)?; device::get_count(&mut device_count)?;
(0..device_count) (0..device_count)
.map(|dev| { .map(|dev| {
let mut guid = CUuuid_st { bytes: [0; 16] }; let mut guid = CUuuid_st { bytes: [0; 16] };
unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? }; unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? };
let mut pci_domain = 0; let mut pci_domain = 0;
device::get_attribute( device::get_attribute(
&mut pci_domain, &mut pci_domain,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID,
dev, dev,
)?; )?;
let mut pci_bus = 0; let mut pci_bus = 0;
device::get_attribute( device::get_attribute(
&mut pci_bus, &mut pci_bus,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID,
dev, dev,
)?; )?;
let mut pci_device = 0; let mut pci_device = 0;
device::get_attribute( device::get_attribute(
&mut pci_device, &mut pci_device,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID,
dev, dev,
)?; )?;
Ok(::dark_api::DeviceHashinfo { Ok(::dark_api::DeviceHashinfo {
guid, guid,
pci_domain, pci_domain,
pci_bus, pci_bus,
pci_device, pci_device,
}) })
}) })
.collect() .collect()
} }
static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable = static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable =
::dark_api::cuda::CudaDarkApiGlobalTable::new::<DarkApi>(); ::dark_api::cuda::CudaDarkApiGlobalTable::new::<DarkApi>();
pub(crate) fn get_export_table( pub(crate) fn get_export_table(
pp_export_table: &mut *const ::core::ffi::c_void, pp_export_table: &mut *const ::core::ffi::c_void,
p_export_table_id: &CUuuid, p_export_table_id: &CUuuid,
) -> CUresult { ) -> CUresult {
if let Some(table) = EXPORT_TABLE.get(p_export_table_id) { if let Some(table) = EXPORT_TABLE.get(p_export_table_id) {
*pp_export_table = table.start(); *pp_export_table = table.start();
cuda_types::cuda::CUresult::SUCCESS cuda_types::cuda::CUresult::SUCCESS
} else { } else {
cuda_types::cuda::CUresult::ERROR_INVALID_VALUE cuda_types::cuda::CUresult::ERROR_INVALID_VALUE
} }
} }
/// Stores `cuda_types::cuda::CUDA_VERSION` (cast to `i32`) in `*version`;
/// never fails.
pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
    let reported_version = cuda_types::cuda::CUDA_VERSION as i32;
    *version = reported_version;
    Ok(())
}
pub(crate) unsafe fn get_proc_address( pub(crate) unsafe fn get_proc_address(
symbol: &CStr, symbol: &CStr,
pfn: &mut *mut ::core::ffi::c_void, pfn: &mut *mut ::core::ffi::c_void,
cuda_version: ::core::ffi::c_int, cuda_version: ::core::ffi::c_int,
flags: cuda_types::cuda::cuuint64_t, flags: cuda_types::cuda::cuuint64_t,
) -> CUresult { ) -> CUresult {
get_proc_address_v2(symbol, pfn, cuda_version, flags, None) get_proc_address_v2(symbol, pfn, cuda_version, flags, None)
} }
pub(crate) unsafe fn get_proc_address_v2( pub(crate) unsafe fn get_proc_address_v2(
symbol: &CStr, symbol: &CStr,
pfn: &mut *mut ::core::ffi::c_void, pfn: &mut *mut ::core::ffi::c_void,
cuda_version: ::core::ffi::c_int, cuda_version: ::core::ffi::c_int,
flags: cuda_types::cuda::cuuint64_t, flags: cuda_types::cuda::cuuint64_t,
symbol_status: Option<&mut cuda_types::cuda::CUdriverProcAddressQueryResult>, symbol_status: Option<&mut cuda_types::cuda::CUdriverProcAddressQueryResult>,
) -> CUresult { ) -> CUresult {
// This implementation is mostly the same as cuGetProcAddress_v2 in zluda_dump. We may want to factor out the duplication at some point. // This implementation is mostly the same as cuGetProcAddress_v2 in zluda_dump. We may want to factor out the duplication at some point.
fn raw_match(name: &[u8], flag: u64, version: i32) -> *mut ::core::ffi::c_void { fn raw_match(name: &[u8], flag: u64, version: i32) -> *mut ::core::ffi::c_void {
use crate::*; use crate::*;
include!("../../../zluda_bindgen/src/process_table.rs") include!("../../../zluda_bindgen/src/process_table.rs")
} }
let fn_ptr = raw_match(symbol.to_bytes(), flags, cuda_version); let fn_ptr = raw_match(symbol.to_bytes(), flags, cuda_version);
match fn_ptr as usize { match fn_ptr as usize {
0 => { 0 => {
if let Some(symbol_status) = symbol_status { if let Some(symbol_status) = symbol_status {
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND; *symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND;
} }
*pfn = ptr::null_mut(); *pfn = ptr::null_mut();
CUresult::ERROR_NOT_FOUND CUresult::ERROR_NOT_FOUND
} }
usize::MAX => { usize::MAX => {
if let Some(symbol_status) = symbol_status { if let Some(symbol_status) = symbol_status {
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT; *symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT;
} }
*pfn = ptr::null_mut(); *pfn = ptr::null_mut();
CUresult::ERROR_NOT_FOUND CUresult::ERROR_NOT_FOUND
} }
_ => { _ => {
if let Some(symbol_status) = symbol_status { if let Some(symbol_status) = symbol_status {
*symbol_status = *symbol_status =
cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SUCCESS; cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SUCCESS;
} }
*pfn = fn_ptr; *pfn = fn_ptr;
Ok(()) Ok(())
} }
} }
} }
pub(crate) fn profiler_start() -> CUresult { pub(crate) fn profiler_start() -> CUresult {
Ok(()) Ok(())
} }
pub(crate) fn profiler_stop() -> CUresult { pub(crate) fn profiler_stop() -> CUresult {
Ok(()) Ok(())
} }

View File

@ -1,9 +1,9 @@
// TODO: remove duplication with zluda_dump // TODO: remove duplication with zluda_dump
// Foreign declaration of `pthread_self`, linked from the `pthread` library.
#[link(name = "pthread")]
unsafe extern "C" {
    fn pthread_self() -> std::os::unix::thread::RawPthread;
}
pub(crate) fn current_thread() -> u32 { pub(crate) fn current_thread() -> u32 {
(unsafe { pthread_self() }) as u32 (unsafe { pthread_self() }) as u32
} }

View File

@ -1,9 +1,9 @@
// TODO: remove duplication with zluda_dump // TODO: remove duplication with zluda_dump
// Foreign declaration of `GetCurrentThreadId`, linked from `kernel32`.
#[link(name = "kernel32")]
unsafe extern "system" {
    fn GetCurrentThreadId() -> u32;
}
pub(crate) fn current_thread() -> u32 { pub(crate) fn current_thread() -> u32 {
unsafe { GetCurrentThreadId() } unsafe { GetCurrentThreadId() }
} }

View File

@ -1,124 +1,124 @@
use crate::os; use crate::os;
use crate::{CudaFunctionName, ErrorEntry}; use crate::{CudaFunctionName, ErrorEntry};
use cuda_types::cuda::*; use cuda_types::cuda::*;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use std::cell::RefMut; use std::cell::RefMut;
use std::hash::Hash; use std::hash::Hash;
use std::{collections::hash_map, ffi::c_void, mem}; use std::{collections::hash_map, ffi::c_void, mem};
pub(crate) struct DarkApiState2 { pub(crate) struct DarkApiState2 {
// Key is Box<CUuuid, because thunk reporting unknown export table needs a // Key is Box<CUuuid, because thunk reporting unknown export table needs a
// stable memory location for the guid // stable memory location for the guid
pub(crate) overrides: FxHashMap<Box<CUuuidWrapper>, (*const *const c_void, Vec<*const c_void>)>, pub(crate) overrides: FxHashMap<Box<CUuuidWrapper>, (*const *const c_void, Vec<*const c_void>)>,
} }
unsafe impl Send for DarkApiState2 {} unsafe impl Send for DarkApiState2 {}
unsafe impl Sync for DarkApiState2 {} unsafe impl Sync for DarkApiState2 {}
impl DarkApiState2 { impl DarkApiState2 {
pub(crate) fn new() -> Self { pub(crate) fn new() -> Self {
DarkApiState2 { DarkApiState2 {
overrides: FxHashMap::default(), overrides: FxHashMap::default(),
} }
} }
pub(crate) fn override_export_table( pub(crate) fn override_export_table(
&mut self, &mut self,
known_exports: &::dark_api::cuda::CudaDarkApiGlobalTable, known_exports: &::dark_api::cuda::CudaDarkApiGlobalTable,
original_export_table: *const *const c_void, original_export_table: *const *const c_void,
guid: &CUuuid_st, guid: &CUuuid_st,
) -> (*const *const c_void, Option<ErrorEntry>) { ) -> (*const *const c_void, Option<ErrorEntry>) {
let entry = match self.overrides.entry(Box::new(CUuuidWrapper(*guid))) { let entry = match self.overrides.entry(Box::new(CUuuidWrapper(*guid))) {
hash_map::Entry::Occupied(entry) => { hash_map::Entry::Occupied(entry) => {
let (_, override_table) = entry.get(); let (_, override_table) = entry.get();
return (override_table.as_ptr(), None); return (override_table.as_ptr(), None);
} }
hash_map::Entry::Vacant(entry) => entry, hash_map::Entry::Vacant(entry) => entry,
}; };
let mut error = None; let mut error = None;
let byte_size: usize = unsafe { *(original_export_table.cast::<usize>()) }; let byte_size: usize = unsafe { *(original_export_table.cast::<usize>()) };
// Some export tables don't start with a byte count, but directly with a // Some export tables don't start with a byte count, but directly with a
// pointer, and are instead terminated by 0 or MAX // pointer, and are instead terminated by 0 or MAX
let export_functions_start_idx; let export_functions_start_idx;
let export_functions_size; let export_functions_size;
if byte_size > 0x10000 { if byte_size > 0x10000 {
export_functions_start_idx = 0; export_functions_start_idx = 0;
let mut i = 0; let mut i = 0;
loop { loop {
let current_ptr = unsafe { original_export_table.add(i) }; let current_ptr = unsafe { original_export_table.add(i) };
let current_ptr_numeric = unsafe { *current_ptr } as usize; let current_ptr_numeric = unsafe { *current_ptr } as usize;
if current_ptr_numeric == 0usize || current_ptr_numeric == usize::MAX { if current_ptr_numeric == 0usize || current_ptr_numeric == usize::MAX {
export_functions_size = i; export_functions_size = i;
break; break;
} }
i += 1; i += 1;
} }
} else { } else {
export_functions_start_idx = 1; export_functions_start_idx = 1;
export_functions_size = byte_size / mem::size_of::<usize>(); export_functions_size = byte_size / mem::size_of::<usize>();
} }
let our_functions = known_exports.get(guid); let our_functions = known_exports.get(guid);
if let Some(ref our_functions) = our_functions { if let Some(ref our_functions) = our_functions {
if our_functions.len() != export_functions_size { if our_functions.len() != export_functions_size {
error = Some(ErrorEntry::UnexpectedExportTableSize { error = Some(ErrorEntry::UnexpectedExportTableSize {
expected: our_functions.len(), expected: our_functions.len(),
computed: export_functions_size, computed: export_functions_size,
}); });
} }
} }
let mut override_table = let mut override_table =
unsafe { std::slice::from_raw_parts(original_export_table, export_functions_size) } unsafe { std::slice::from_raw_parts(original_export_table, export_functions_size) }
.to_vec(); .to_vec();
for i in export_functions_start_idx..export_functions_size { for i in export_functions_start_idx..export_functions_size {
let current_fn = (|| { let current_fn = (|| {
if let Some(ref our_functions) = our_functions { if let Some(ref our_functions) = our_functions {
if let Some(fn_) = our_functions.get_fn(i) { if let Some(fn_) = our_functions.get_fn(i) {
return fn_; return fn_;
} }
} }
os::get_thunk( os::get_thunk(
override_table[i], override_table[i],
Self::report_unknown_export_table_call, Self::report_unknown_export_table_call,
std::ptr::from_ref(entry.key().as_ref()).cast(), std::ptr::from_ref(entry.key().as_ref()).cast(),
i, i,
) )
})(); })();
override_table[i] = current_fn; override_table[i] = current_fn;
} }
( (
entry entry
.insert((original_export_table, override_table)) .insert((original_export_table, override_table))
.1 .1
.as_ptr(), .as_ptr(),
error, error,
) )
} }
unsafe extern "system" fn report_unknown_export_table_call(guid: &CUuuid, index: usize) { unsafe extern "system" fn report_unknown_export_table_call(guid: &CUuuid, index: usize) {
let global_state = crate::GLOBAL_STATE2.lock(); let global_state = crate::GLOBAL_STATE2.lock();
let global_state_ref_cell = &*global_state; let global_state_ref_cell = &*global_state;
let mut global_state_ref_mut = global_state_ref_cell.borrow_mut(); let mut global_state_ref_mut = global_state_ref_cell.borrow_mut();
let global_state = &mut *global_state_ref_mut; let global_state = &mut *global_state_ref_mut;
let log_guard = crate::OuterCallGuard { let log_guard = crate::OuterCallGuard {
writer: &mut global_state.log_writer, writer: &mut global_state.log_writer,
log_root: &global_state.log_stack, log_root: &global_state.log_stack,
}; };
{ {
let mut logger = RefMut::map(global_state.log_stack.borrow_mut(), |log_stack| { let mut logger = RefMut::map(global_state.log_stack.borrow_mut(), |log_stack| {
log_stack.enter() log_stack.enter()
}); });
logger.name = CudaFunctionName::Dark { guid: *guid, index }; logger.name = CudaFunctionName::Dark { guid: *guid, index };
}; };
drop(log_guard); drop(log_guard);
} }
} }
#[derive(Eq, PartialEq)] #[derive(Eq, PartialEq)]
#[repr(transparent)] #[repr(transparent)]
pub(crate) struct CUuuidWrapper(pub CUuuid); pub(crate) struct CUuuidWrapper(pub CUuuid);
impl Hash for CUuuidWrapper { impl Hash for CUuuidWrapper {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) { fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.bytes.hash(state); self.0.bytes.hash(state);
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,81 +1,81 @@
use cuda_types::cuda::CUuuid; use cuda_types::cuda::CUuuid;
use std::ffi::{c_void, CStr, CString}; use std::ffi::{c_void, CStr, CString};
use std::mem; use std::mem;
/// Default location of the CUDA driver library on this platform.
pub(crate) const LIBCUDA_DEFAULT_PATH: &str = "/usr/lib/x86_64-linux-gnu/libcuda.so.1";
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void { pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
let libcuda_path = CString::new(libcuda_path).unwrap(); let libcuda_path = CString::new(libcuda_path).unwrap();
libc::dlopen( libc::dlopen(
libcuda_path.as_ptr() as *const _, libcuda_path.as_ptr() as *const _,
libc::RTLD_LOCAL | libc::RTLD_NOW, libc::RTLD_LOCAL | libc::RTLD_NOW,
) )
} }
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void { pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
libc::dlsym(handle, func.as_ptr() as *const _) libc::dlsym(handle, func.as_ptr() as *const _)
} }
/// Formats its arguments like `format!` and prints the result to stderr,
/// prefixed with `[ZLUDA_DUMP] `.
#[macro_export]
macro_rules! os_log {
    // Format string only.
    ($format:tt) => {
        {
            eprintln!("[ZLUDA_DUMP] {}", format!($format));
        }
    };
    // Format string followed by one or more arguments.
    ($format:tt, $($obj: expr),+) => {
        {
            eprintln!("[ZLUDA_DUMP] {}", format!($format, $($obj,)+));
        }
    };
}
//RDI, RSI, RDX, RCX, R8, R9 //RDI, RSI, RDX, RCX, R8, R9
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
pub fn get_thunk( pub fn get_thunk(
original_fn: *const c_void, original_fn: *const c_void,
report_fn: unsafe extern "system" fn(&CUuuid, usize), report_fn: unsafe extern "system" fn(&CUuuid, usize),
guid: *const CUuuid, guid: *const CUuuid,
idx: usize, idx: usize,
) -> *const c_void { ) -> *const c_void {
use dynasmrt::{dynasm, DynasmApi}; use dynasmrt::{dynasm, DynasmApi};
let mut ops = dynasmrt::x64::Assembler::new().unwrap(); let mut ops = dynasmrt::x64::Assembler::new().unwrap();
let start = ops.offset(); let start = ops.offset();
dynasm!(ops dynasm!(ops
// stack alignment // stack alignment
; sub rsp, 8 ; sub rsp, 8
; push rdi ; push rdi
; push rsi ; push rsi
; push rdx ; push rdx
; push rcx ; push rcx
; push r8 ; push r8
; push r9 ; push r9
; mov rdi, QWORD guid as i64 ; mov rdi, QWORD guid as i64
; mov rsi, QWORD idx as i64 ; mov rsi, QWORD idx as i64
; mov rax, QWORD report_fn as i64 ; mov rax, QWORD report_fn as i64
; call rax ; call rax
; pop r9 ; pop r9
; pop r8 ; pop r8
; pop rcx ; pop rcx
; pop rdx ; pop rdx
; pop rsi ; pop rsi
; pop rdi ; pop rdi
; add rsp, 8 ; add rsp, 8
; mov rax, QWORD original_fn as i64 ; mov rax, QWORD original_fn as i64
; jmp rax ; jmp rax
; int 3 ; int 3
); );
let exe_buf = ops.finalize().unwrap(); let exe_buf = ops.finalize().unwrap();
let result_fn = exe_buf.ptr(start); let result_fn = exe_buf.ptr(start);
mem::forget(exe_buf); mem::forget(exe_buf);
result_fn as *const _ result_fn as *const _
} }
// Foreign declaration of `pthread_self`, linked from the `pthread` library.
#[link(name = "pthread")]
unsafe extern "C" {
    fn pthread_self() -> std::os::unix::thread::RawPthread;
}
pub(crate) fn current_thread() -> u32 { pub(crate) fn current_thread() -> u32 {
(unsafe { pthread_self() }) as u32 (unsafe { pthread_self() }) as u32
} }

View File

@ -1,190 +1,190 @@
use std::{ use std::{
ffi::{c_void, CStr}, ffi::{c_void, CStr},
mem, ptr, mem, ptr,
sync::LazyLock, sync::LazyLock,
}; };
use std::os::windows::io::AsRawHandle; use std::os::windows::io::AsRawHandle;
use winapi::{ use winapi::{
shared::minwindef::{FARPROC, HMODULE}, shared::minwindef::{FARPROC, HMODULE},
um::debugapi::OutputDebugStringA, um::debugapi::OutputDebugStringA,
um::libloaderapi::{GetProcAddress, LoadLibraryW}, um::libloaderapi::{GetProcAddress, LoadLibraryW},
}; };
use cuda_types::cuda::CUuuid; use cuda_types::cuda::CUuuid;
/// Default location of the CUDA driver library on this platform.
pub(crate) const LIBCUDA_DEFAULT_PATH: &str = "C:\\Windows\\System32\\nvcuda.dll";
/// NUL-terminated export name looked up in the detourer module for the
/// non-redirecting `LoadLibraryW`.
const LOAD_LIBRARY_NO_REDIRECT: &[u8] = b"ZludaLoadLibraryW_NoRedirect\0";
/// NUL-terminated export name looked up in the detourer module for the
/// non-redirecting `GetProcAddress`.
const GET_PROC_ADDRESS_NO_REDIRECT: &[u8] = b"ZludaGetProcAddress_NoRedirect\0";
static PLATFORM_LIBRARY: LazyLock<PlatformLibrary> = static PLATFORM_LIBRARY: LazyLock<PlatformLibrary> =
LazyLock::new(|| unsafe { PlatformLibrary::new() }); LazyLock::new(|| unsafe { PlatformLibrary::new() });
#[allow(non_snake_case)] #[allow(non_snake_case)]
struct PlatformLibrary { struct PlatformLibrary {
LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE, LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE,
GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC, GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC,
} }
impl PlatformLibrary { impl PlatformLibrary {
#[allow(non_snake_case)] #[allow(non_snake_case)]
unsafe fn new() -> Self { unsafe fn new() -> Self {
let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() { let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() {
None => ( None => (
LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE, LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE,
mem::transmute( mem::transmute(
GetProcAddress GetProcAddress
as unsafe extern "system" fn( as unsafe extern "system" fn(
hModule: HMODULE, hModule: HMODULE,
lpProcName: *const i8, lpProcName: *const i8,
) -> FARPROC, ) -> FARPROC,
), ),
), ),
Some(zluda_with) => ( Some(zluda_with) => (
mem::transmute(GetProcAddress( mem::transmute(GetProcAddress(
zluda_with, zluda_with,
LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _, LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _,
)), )),
mem::transmute(GetProcAddress( mem::transmute(GetProcAddress(
zluda_with, zluda_with,
GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _, GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _,
)), )),
), ),
}; };
PlatformLibrary { PlatformLibrary {
LoadLibraryW, LoadLibraryW,
GetProcAddress, GetProcAddress,
} }
} }
unsafe fn get_detourer_module() -> Option<HMODULE> { unsafe fn get_detourer_module() -> Option<HMODULE> {
let mut module = ptr::null_mut(); let mut module = ptr::null_mut();
loop { loop {
module = detours_sys::DetourEnumerateModules(module); module = detours_sys::DetourEnumerateModules(module);
if module == ptr::null_mut() { if module == ptr::null_mut() {
break; break;
} }
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _); let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
if payload != ptr::null_mut() { if payload != ptr::null_mut() {
return Some(module as _); return Some(module as _);
} }
} }
None None
} }
} }
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void { pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
let libcuda_path_uf16 = libcuda_path let libcuda_path_uf16 = libcuda_path
.encode_utf16() .encode_utf16()
.chain(std::iter::once(0)) .chain(std::iter::once(0))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
(PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _ (PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _
} }
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void { pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
(PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _ (PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _
} }
/// Formats its arguments like `format!` and hands the resulting string to
/// `crate::os::__log_impl`.
#[macro_export]
macro_rules! os_log {
    // Format string only.
    ($format:tt) => {
        {
            use crate::os::__log_impl;
            __log_impl(format!($format));
        }
    };
    // Format string followed by one or more arguments.
    ($format:tt, $($obj: expr),+) => {
        {
            use crate::os::__log_impl;
            __log_impl(format!($format, $($obj,)+));
        }
    };
}
pub fn __log_impl(s: String) { pub fn __log_impl(s: String) {
let log_to_stderr = std::io::stderr().as_raw_handle() != ptr::null_mut(); let log_to_stderr = std::io::stderr().as_raw_handle() != ptr::null_mut();
if log_to_stderr { if log_to_stderr {
eprintln!("[ZLUDA_DUMP] {}", s); eprintln!("[ZLUDA_DUMP] {}", s);
} else { } else {
let mut win_str = String::with_capacity("[ZLUDA_DUMP] ".len() + s.len() + 2); let mut win_str = String::with_capacity("[ZLUDA_DUMP] ".len() + s.len() + 2);
win_str.push_str("[ZLUDA_DUMP] "); win_str.push_str("[ZLUDA_DUMP] ");
win_str.push_str(&s); win_str.push_str(&s);
win_str.push_str("\n\0"); win_str.push_str("\n\0");
unsafe { OutputDebugStringA(win_str.as_ptr() as *const _) }; unsafe { OutputDebugStringA(win_str.as_ptr() as *const _) };
} }
} }
/// Generates a 32-bit thunk that pushes `idx` and `guid`, calls `report_fn`,
/// and then jumps to `original_fn`.
///
/// The executable buffer is deliberately leaked with `mem::forget` so the
/// returned pointer is never freed.
///
/// # Panics
/// Panics if the assembler cannot be created or finalized.
#[cfg(target_arch = "x86")]
pub fn get_thunk(
    original_fn: *const c_void,
    report_fn: unsafe extern "system" fn(&CUuuid, usize),
    guid: *const CUuuid,
    idx: usize,
) -> *const c_void {
    use dynasmrt::{dynasm, DynasmApi};
    let mut assembler = dynasmrt::x86::Assembler::new().unwrap();
    let entry_offset = assembler.offset();
    dynasm!(assembler
        ; .arch x86
        // Arguments are pushed right to left: idx first, then guid.
        ; push idx as i32
        ; push guid as i32
        ; mov eax, report_fn as i32
        ; call eax
        ; mov eax, original_fn as i32
        ; jmp eax
        ; int 3
    );
    let executable = assembler.finalize().unwrap();
    let thunk = executable.ptr(entry_offset);
    // Leak the buffer: the thunk must outlive this function.
    mem::forget(executable);
    thunk as *const _
}
//RCX, RDX, R8, R9 //RCX, RDX, R8, R9
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
pub fn get_thunk( pub fn get_thunk(
original_fn: *const c_void, original_fn: *const c_void,
report_fn: unsafe extern "system" fn(&CUuuid, usize), report_fn: unsafe extern "system" fn(&CUuuid, usize),
guid: *const CUuuid, guid: *const CUuuid,
idx: usize, idx: usize,
) -> *const c_void { ) -> *const c_void {
use dynasmrt::{dynasm, DynasmApi}; use dynasmrt::{dynasm, DynasmApi};
let mut ops = dynasmrt::x86::Assembler::new().unwrap(); let mut ops = dynasmrt::x86::Assembler::new().unwrap();
let start = ops.offset(); let start = ops.offset();
// Let's hope there's never more than 4 arguments // Let's hope there's never more than 4 arguments
dynasm!(ops dynasm!(ops
; .arch x64 ; .arch x64
; push rbp ; push rbp
; mov rbp, rsp ; mov rbp, rsp
; push rcx ; push rcx
; push rdx ; push rdx
; push r8 ; push r8
; push r9 ; push r9
; mov rcx, QWORD guid as i64 ; mov rcx, QWORD guid as i64
; mov rdx, QWORD idx as i64 ; mov rdx, QWORD idx as i64
; mov rax, QWORD report_fn as i64 ; mov rax, QWORD report_fn as i64
; call rax ; call rax
; pop r9 ; pop r9
; pop r8 ; pop r8
; pop rdx ; pop rdx
; pop rcx ; pop rcx
; mov rax, QWORD original_fn as i64 ; mov rax, QWORD original_fn as i64
; call rax ; call rax
; pop rbp ; pop rbp
; ret ; ret
; int 3 ; int 3
); );
let exe_buf = ops.finalize().unwrap(); let exe_buf = ops.finalize().unwrap();
let result_fn = exe_buf.ptr(start); let result_fn = exe_buf.ptr(start);
mem::forget(exe_buf); mem::forget(exe_buf);
result_fn as *const _ result_fn as *const _
} }
// Foreign declaration of `GetCurrentThreadId`, linked from `kernel32`.
#[link(name = "kernel32")]
unsafe extern "system" {
    fn GetCurrentThreadId() -> u32;
}
pub(crate) fn current_thread() -> u32 { pub(crate) fn current_thread() -> u32 {
unsafe { GetCurrentThreadId() } unsafe { GetCurrentThreadId() }
} }

View File

@ -1,334 +1,334 @@
use crate::{ use crate::{
log::{self, UInt}, log::{self, UInt},
trace, ErrorEntry, FnCallLog, Settings, trace, ErrorEntry, FnCallLog, Settings,
}; };
use cuda_types::{ use cuda_types::{
cuda::*, cuda::*,
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper}, dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper},
}; };
use dark_api::fatbin::{ use dark_api::fatbin::{
decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule, decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule,
}; };
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::{FxHashMap, FxHashSet};
use std::{ use std::{
borrow::Cow, borrow::Cow,
ffi::{c_void, CStr, CString}, ffi::{c_void, CStr, CString},
fs::{self, File}, fs::{self, File},
io::{self, Read, Write}, io::{self, Read, Write},
path::PathBuf, path::PathBuf,
}; };
use unwrap_or::unwrap_some_or; use unwrap_or::unwrap_some_or;
// This struct is the heart of CUDA state tracking, it: // This struct is the heart of CUDA state tracking, it:
// * receives calls from the probes about changes to CUDA state // * receives calls from the probes about changes to CUDA state
// * records updates to the state change // * records updates to the state change
// * writes out relevant state change and details to disk and log // * writes out relevant state change and details to disk and log
pub(crate) struct StateTracker { pub(crate) struct StateTracker {
writer: DumpWriter, writer: DumpWriter,
pub(crate) libraries: FxHashMap<CUlibrary, CodePointer>, pub(crate) libraries: FxHashMap<CUlibrary, CodePointer>,
saved_modules: FxHashSet<CUmodule>, saved_modules: FxHashSet<CUmodule>,
module_counter: usize, module_counter: usize,
submodule_counter: usize, submodule_counter: usize,
pub(crate) override_cc: Option<(u32, u32)>, pub(crate) override_cc: Option<(u32, u32)>,
} }
/// Raw pointer to code, wrapped so it can be marked `Send` and `Sync`.
#[derive(Clone, Copy)]
pub(crate) struct CodePointer(pub *const c_void);

// NOTE(review): the pointer is only stored, never dereferenced, in the code
// visible here — confirm that holds for every user before relying on these.
unsafe impl Send for CodePointer {}
unsafe impl Sync for CodePointer {}
impl StateTracker { impl StateTracker {
pub(crate) fn new(settings: &Settings) -> Self { pub(crate) fn new(settings: &Settings) -> Self {
StateTracker { StateTracker {
writer: DumpWriter::new(settings.dump_dir.clone()), writer: DumpWriter::new(settings.dump_dir.clone()),
libraries: FxHashMap::default(), libraries: FxHashMap::default(),
saved_modules: FxHashSet::default(), saved_modules: FxHashSet::default(),
module_counter: 0, module_counter: 0,
submodule_counter: 0, submodule_counter: 0,
override_cc: settings.override_cc, override_cc: settings.override_cc,
} }
} }
pub(crate) fn record_new_module_file( pub(crate) fn record_new_module_file(
&mut self, &mut self,
module: CUmodule, module: CUmodule,
file_name: *const i8, file_name: *const i8,
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
) { ) {
let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() { let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() {
Ok(f) => f, Ok(f) => f,
Err(err) => { Err(err) => {
fn_logger.log(log::ErrorEntry::MalformedModulePath(err)); fn_logger.log(log::ErrorEntry::MalformedModulePath(err));
return; return;
} }
}; };
let maybe_io_error = self.try_record_new_module_file(module, fn_logger, file_name); let maybe_io_error = self.try_record_new_module_file(module, fn_logger, file_name);
fn_logger.log_io_error(maybe_io_error) fn_logger.log_io_error(maybe_io_error)
} }
fn try_record_new_module_file( fn try_record_new_module_file(
&mut self, &mut self,
module: CUmodule, module: CUmodule,
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
file_name: &str, file_name: &str,
) -> io::Result<()> { ) -> io::Result<()> {
let mut module_file = fs::File::open(file_name)?; let mut module_file = fs::File::open(file_name)?;
let mut read_buff = Vec::new(); let mut read_buff = Vec::new();
module_file.read_to_end(&mut read_buff)?; module_file.read_to_end(&mut read_buff)?;
self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger); self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger);
Ok(()) Ok(())
} }
pub(crate) fn record_new_submodule( pub(crate) fn record_new_submodule(
&mut self, &mut self,
module: CUmodule, module: CUmodule,
submodule: &[u8], submodule: &[u8],
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
type_: &'static str, type_: &'static str,
) { ) {
if self.saved_modules.insert(module) { if self.saved_modules.insert(module) {
self.module_counter += 1; self.module_counter += 1;
self.submodule_counter = 0; self.submodule_counter = 0;
} }
self.submodule_counter += 1; self.submodule_counter += 1;
fn_logger.log_io_error(self.writer.save_module( fn_logger.log_io_error(self.writer.save_module(
self.module_counter, self.module_counter,
Some(self.submodule_counter), Some(self.submodule_counter),
submodule, submodule,
type_, type_,
)); ));
if type_ == "ptx" { if type_ == "ptx" {
match CString::new(submodule) { match CString::new(submodule) {
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)), Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
Ok(submodule_cstring) => match submodule_cstring.to_str() { Ok(submodule_cstring) => match submodule_cstring.to_str() {
Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)), Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)),
Ok(submodule_text) => self.try_parse_and_record_kernels( Ok(submodule_text) => self.try_parse_and_record_kernels(
fn_logger, fn_logger,
self.module_counter, self.module_counter,
Some(self.submodule_counter), Some(self.submodule_counter),
submodule_text, submodule_text,
), ),
}, },
} }
} }
} }
pub(crate) fn record_new_module( pub(crate) fn record_new_module(
&mut self, &mut self,
module: CUmodule, module: CUmodule,
raw_image: *const c_void, raw_image: *const c_void,
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
) { ) {
self.module_counter += 1; self.module_counter += 1;
if unsafe { *(raw_image as *const [u8; 4]) } == *goblin::elf64::header::ELFMAG { if unsafe { *(raw_image as *const [u8; 4]) } == *goblin::elf64::header::ELFMAG {
self.saved_modules.insert(module); self.saved_modules.insert(module);
// TODO: Parse ELF and write it to disk // TODO: Parse ELF and write it to disk
fn_logger.log(log::ErrorEntry::UnsupportedModule { fn_logger.log(log::ErrorEntry::UnsupportedModule {
module, module,
raw_image, raw_image,
kind: "ELF", kind: "ELF",
}) })
} else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC { } else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC {
self.saved_modules.insert(module); self.saved_modules.insert(module);
// TODO: Figure out how to get size of archive module and write it to disk // TODO: Figure out how to get size of archive module and write it to disk
fn_logger.log(log::ErrorEntry::UnsupportedModule { fn_logger.log(log::ErrorEntry::UnsupportedModule {
module, module,
raw_image, raw_image,
kind: "archive", kind: "archive",
}) })
} else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC { } else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC {
unsafe { unsafe {
fn_logger.try_(|fn_logger| { fn_logger.try_(|fn_logger| {
trace::record_submodules_from_wrapped_fatbin( trace::record_submodules_from_wrapped_fatbin(
module, module,
raw_image as *const FatbincWrapper, raw_image as *const FatbincWrapper,
fn_logger, fn_logger,
self, self,
) )
}); });
} }
} else { } else {
self.record_module_ptx(module, raw_image, fn_logger) self.record_module_ptx(module, raw_image, fn_logger)
} }
} }
fn record_module_ptx( fn record_module_ptx(
&mut self, &mut self,
module: CUmodule, module: CUmodule,
raw_image: *const c_void, raw_image: *const c_void,
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
) { ) {
self.saved_modules.insert(module); self.saved_modules.insert(module);
let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str(); let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str();
let module_text = match module_text { let module_text = match module_text {
Ok(m) => m, Ok(m) => m,
Err(utf8_err) => { Err(utf8_err) => {
fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err)); fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err));
return; return;
} }
}; };
fn_logger.log_io_error(self.writer.save_module( fn_logger.log_io_error(self.writer.save_module(
self.module_counter, self.module_counter,
None, None,
module_text.as_bytes(), module_text.as_bytes(),
"ptx", "ptx",
)); ));
self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text); self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text);
} }
fn try_parse_and_record_kernels( fn try_parse_and_record_kernels(
&mut self, &mut self,
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
module_index: usize, module_index: usize,
submodule_index: Option<usize>, submodule_index: Option<usize>,
module_text: &str, module_text: &str,
) { ) {
let errors = ptx_parser::parse_for_errors(module_text); let errors = ptx_parser::parse_for_errors(module_text);
if !errors.is_empty() { if !errors.is_empty() {
fn_logger.log(log::ErrorEntry::ModuleParsingError( fn_logger.log(log::ErrorEntry::ModuleParsingError(
DumpWriter::get_file_name(module_index, submodule_index, "log"), DumpWriter::get_file_name(module_index, submodule_index, "log"),
)); ));
fn_logger.log_io_error(self.writer.save_module_error_log( fn_logger.log_io_error(self.writer.save_module_error_log(
module_index, module_index,
submodule_index, submodule_index,
&*errors, &*errors,
)); ));
} }
} }
} }
// This structs writes out information about CUDA execution to the dump dir // This structs writes out information about CUDA execution to the dump dir
struct DumpWriter { struct DumpWriter {
dump_dir: Option<PathBuf>, dump_dir: Option<PathBuf>,
} }
impl DumpWriter { impl DumpWriter {
fn new(dump_dir: Option<PathBuf>) -> Self { fn new(dump_dir: Option<PathBuf>) -> Self {
Self { dump_dir } Self { dump_dir }
} }
fn save_module( fn save_module(
&self, &self,
module_index: usize, module_index: usize,
submodule_index: Option<usize>, submodule_index: Option<usize>,
buffer: &[u8], buffer: &[u8],
kind: &'static str, kind: &'static str,
) -> io::Result<()> { ) -> io::Result<()> {
let mut dump_file = match &self.dump_dir { let mut dump_file = match &self.dump_dir {
None => return Ok(()), None => return Ok(()),
Some(d) => d.clone(), Some(d) => d.clone(),
}; };
dump_file.push(Self::get_file_name(module_index, submodule_index, kind)); dump_file.push(Self::get_file_name(module_index, submodule_index, kind));
let mut file = File::create(dump_file)?; let mut file = File::create(dump_file)?;
file.write_all(buffer)?; file.write_all(buffer)?;
Ok(()) Ok(())
} }
fn save_module_error_log<'input>( fn save_module_error_log<'input>(
&self, &self,
module_index: usize, module_index: usize,
submodule_index: Option<usize>, submodule_index: Option<usize>,
errors: &[ptx_parser::PtxError<'input>], errors: &[ptx_parser::PtxError<'input>],
) -> io::Result<()> { ) -> io::Result<()> {
let mut log_file = match &self.dump_dir { let mut log_file = match &self.dump_dir {
None => return Ok(()), None => return Ok(()),
Some(d) => d.clone(), Some(d) => d.clone(),
}; };
log_file.push(Self::get_file_name(module_index, submodule_index, "log")); log_file.push(Self::get_file_name(module_index, submodule_index, "log"));
let mut file = File::create(log_file)?; let mut file = File::create(log_file)?;
for error in errors { for error in errors {
writeln!(file, "{}", error)?; writeln!(file, "{}", error)?;
} }
Ok(()) Ok(())
} }
fn get_file_name(module_index: usize, submodule_index: Option<usize>, kind: &str) -> String { fn get_file_name(module_index: usize, submodule_index: Option<usize>, kind: &str) -> String {
match submodule_index { match submodule_index {
None => { None => {
format!("module_{:04}.{:02}", module_index, kind) format!("module_{:04}.{:02}", module_index, kind)
} }
Some(submodule_index) => { Some(submodule_index) => {
format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind) format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind)
} }
} }
} }
} }
pub(crate) unsafe fn record_submodules_from_wrapped_fatbin( pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
module: CUmodule, module: CUmodule,
fatbinc_wrapper: *const FatbincWrapper, fatbinc_wrapper: *const FatbincWrapper,
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
state: &mut StateTracker, state: &mut StateTracker,
) -> Result<(), ErrorEntry> { ) -> Result<(), ErrorEntry> {
let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?; let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
let mut submodules = fatbin.get_submodules()?; let mut submodules = fatbin.get_submodules()?;
while let Some(current) = submodules.next()? { while let Some(current) = submodules.next()? {
record_submodules_from_fatbin(module, current, fn_logger, state)?; record_submodules_from_fatbin(module, current, fn_logger, state)?;
} }
Ok(()) Ok(())
} }
pub(crate) unsafe fn record_submodules_from_fatbin( pub(crate) unsafe fn record_submodules_from_fatbin(
module: CUmodule, module: CUmodule,
submodule: FatbinSubmodule, submodule: FatbinSubmodule,
logger: &mut FnCallLog, logger: &mut FnCallLog,
state: &mut StateTracker, state: &mut StateTracker,
) -> Result<(), ErrorEntry> { ) -> Result<(), ErrorEntry> {
record_submodules(module, logger, state, submodule.get_files())?; record_submodules(module, logger, state, submodule.get_files())?;
Ok(()) Ok(())
} }
pub(crate) unsafe fn record_submodules( pub(crate) unsafe fn record_submodules(
module: CUmodule, module: CUmodule,
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
state: &mut StateTracker, state: &mut StateTracker,
mut files: FatbinFileIterator, mut files: FatbinFileIterator,
) -> Result<(), ErrorEntry> { ) -> Result<(), ErrorEntry> {
while let Some(file) = files.next()? { while let Some(file) = files.next()? {
let mut payload = if file let mut payload = if file
.header .header
.flags .flags
.contains(FatbinFileHeaderFlags::CompressedLz4) .contains(FatbinFileHeaderFlags::CompressedLz4)
{ {
Cow::Owned(unwrap_some_or!( Cow::Owned(unwrap_some_or!(
fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())), fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())),
continue continue
)) ))
} else if file } else if file
.header .header
.flags .flags
.contains(FatbinFileHeaderFlags::CompressedZstd) .contains(FatbinFileHeaderFlags::CompressedZstd)
{ {
Cow::Owned(unwrap_some_or!( Cow::Owned(unwrap_some_or!(
fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())), fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())),
continue continue
)) ))
} else { } else {
Cow::Borrowed(file.get_payload()) Cow::Borrowed(file.get_payload())
}; };
match file.header.kind { match file.header.kind {
FatbinFileHeader::HEADER_KIND_PTX => { FatbinFileHeader::HEADER_KIND_PTX => {
while payload.last() == Some(&0) { while payload.last() == Some(&0) {
// remove trailing zeros // remove trailing zeros
payload.to_mut().pop(); payload.to_mut().pop();
} }
state.record_new_submodule(module, &*payload, fn_logger, "ptx") state.record_new_submodule(module, &*payload, fn_logger, "ptx")
} }
FatbinFileHeader::HEADER_KIND_ELF => { FatbinFileHeader::HEADER_KIND_ELF => {
state.record_new_submodule(module, &*payload, fn_logger, "elf") state.record_new_submodule(module, &*payload, fn_logger, "elf")
} }
_ => { _ => {
fn_logger.log(log::ErrorEntry::UnexpectedBinaryField { fn_logger.log(log::ErrorEntry::UnexpectedBinaryField {
field_name: "FATBIN_FILE_HEADER_KIND", field_name: "FATBIN_FILE_HEADER_KIND",
expected: vec![ expected: vec![
UInt::U16(FatbinFileHeader::HEADER_KIND_PTX), UInt::U16(FatbinFileHeader::HEADER_KIND_PTX),
UInt::U16(FatbinFileHeader::HEADER_KIND_ELF), UInt::U16(FatbinFileHeader::HEADER_KIND_ELF),
], ],
observed: UInt::U16(file.header.kind), observed: UInt::U16(file.header.kind),
}); });
} }
} }
} }
Ok(()) Ok(())
} }

View File

@ -1,81 +1,81 @@
use std::{ use std::{
env::{self, VarError}, env::{self, VarError},
fs::{self, DirEntry}, fs::{self, DirEntry},
io, io,
path::{self, PathBuf}, path::{self, PathBuf},
process::Command, process::Command,
}; };
fn main() -> Result<(), VarError> { fn main() -> Result<(), VarError> {
if std::env::var_os("CARGO_CFG_WINDOWS").is_none() { if std::env::var_os("CARGO_CFG_WINDOWS").is_none() {
return Ok(()); return Ok(());
} }
println!("cargo:rerun-if-changed=build.rs"); println!("cargo:rerun-if-changed=build.rs");
if env::var("PROFILE")? != "debug" { if env::var("PROFILE")? != "debug" {
return Ok(()); return Ok(());
} }
let rustc_exe = env::var("RUSTC")?; let rustc_exe = env::var("RUSTC")?;
let out_dir = env::var("OUT_DIR")?; let out_dir = env::var("OUT_DIR")?;
let target = env::var("TARGET")?; let target = env::var("TARGET")?;
let is_msvc = env::var("CARGO_CFG_TARGET_ENV")? == "msvc"; let is_msvc = env::var("CARGO_CFG_TARGET_ENV")? == "msvc";
let opt_level = env::var("OPT_LEVEL")?; let opt_level = env::var("OPT_LEVEL")?;
let debug = str::parse::<bool>(env::var("DEBUG")?.as_str()).unwrap(); let debug = str::parse::<bool>(env::var("DEBUG")?.as_str()).unwrap();
let mut helpers_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); let mut helpers_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
helpers_dir.push("tests"); helpers_dir.push("tests");
helpers_dir.push("helpers"); helpers_dir.push("helpers");
let helpers_dir_as_string = helpers_dir.to_string_lossy(); let helpers_dir_as_string = helpers_dir.to_string_lossy();
println!("cargo:rerun-if-changed={}", helpers_dir_as_string); println!("cargo:rerun-if-changed={}", helpers_dir_as_string);
for rust_file in fs::read_dir(&helpers_dir).unwrap().filter_map(rust_file) { for rust_file in fs::read_dir(&helpers_dir).unwrap().filter_map(rust_file) {
let full_file_path = format!( let full_file_path = format!(
"{}{}{}", "{}{}{}",
helpers_dir_as_string, helpers_dir_as_string,
path::MAIN_SEPARATOR, path::MAIN_SEPARATOR,
rust_file rust_file
); );
let mut rustc_cmd = Command::new(&*rustc_exe); let mut rustc_cmd = Command::new(&*rustc_exe);
if debug { if debug {
rustc_cmd.arg("-g"); rustc_cmd.arg("-g");
} }
rustc_cmd.arg(format!("-Lnative={}", helpers_dir_as_string)); rustc_cmd.arg(format!("-Lnative={}", helpers_dir_as_string));
if !is_msvc { if !is_msvc {
// HACK ALERT // HACK ALERT
// I have no idea why the extra library below have to be linked // I have no idea why the extra library below have to be linked
rustc_cmd.arg(r"-lucrt"); rustc_cmd.arg(r"-lucrt");
} }
rustc_cmd rustc_cmd
.arg("-C") .arg("-C")
.arg(format!("opt-level={}", opt_level)) .arg(format!("opt-level={}", opt_level))
.arg("-L") .arg("-L")
.arg(format!("{}", out_dir)) .arg(format!("{}", out_dir))
.arg("--out-dir") .arg("--out-dir")
.arg(format!("{}", out_dir)) .arg(format!("{}", out_dir))
.arg("--target") .arg("--target")
.arg(format!("{}", target)) .arg(format!("{}", target))
.arg(full_file_path); .arg(full_file_path);
assert!(rustc_cmd.status().unwrap().success()); assert!(rustc_cmd.status().unwrap().success());
} }
std::fs::copy( std::fs::copy(
format!( format!(
"{}{}do_cuinit_late_clr.exe", "{}{}do_cuinit_late_clr.exe",
helpers_dir_as_string, helpers_dir_as_string,
path::MAIN_SEPARATOR path::MAIN_SEPARATOR
), ),
format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR), format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR),
) )
.unwrap(); .unwrap();
println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir); println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir);
Ok(()) Ok(())
} }
/// Maps a directory entry to its file name when it is a regular file whose
/// name ends in `.rs`; unreadable entries and everything else yield `None`.
fn rust_file(entry: io::Result<DirEntry>) -> Option<String> {
    let dir_entry = entry.ok()?;
    let name = dir_entry.file_name().to_string_lossy().to_string();
    let regular_file = match dir_entry.file_type() {
        Ok(kind) => kind.is_file(),
        Err(_) => false,
    };
    if regular_file && name.ends_with(".rs") {
        Some(name)
    } else {
        None
    }
}

View File

@ -1,311 +1,311 @@
use std::env; use std::env;
use std::os::windows; use std::os::windows;
use std::os::windows::ffi::OsStrExt; use std::os::windows::ffi::OsStrExt;
use std::{error::Error, process}; use std::{error::Error, process};
use std::{fs, io, ptr}; use std::{fs, io, ptr};
use std::{mem, path::PathBuf}; use std::{mem, path::PathBuf};
use argh::FromArgs; use argh::FromArgs;
use mem::size_of_val; use mem::size_of_val;
use tempfile::TempDir; use tempfile::TempDir;
use winapi::um::processenv::SearchPathW; use winapi::um::processenv::SearchPathW;
use winapi::um::{ use winapi::um::{
jobapi2::{AssignProcessToJobObject, SetInformationJobObject}, jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
processthreadsapi::{GetExitCodeProcess, ResumeThread}, processthreadsapi::{GetExitCodeProcess, ResumeThread},
synchapi::WaitForSingleObject, synchapi::WaitForSingleObject,
winbase::CreateJobObjectA, winbase::CreateJobObjectA,
winnt::{ winnt::{
JobObjectExtendedLimitInformation, HANDLE, JOBOBJECT_EXTENDED_LIMIT_INFORMATION, JobObjectExtendedLimitInformation, HANDLE, JOBOBJECT_EXTENDED_LIMIT_INFORMATION,
JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
}, },
}; };
use winapi::um::winbase::{INFINITE, WAIT_FAILED}; use winapi::um::winbase::{INFINITE, WAIT_FAILED};
static REDIRECT_DLL: &'static str = "zluda_redirect.dll"; static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
static NVCUDA_DLL: &'static str = "nvcuda.dll"; static NVCUDA_DLL: &'static str = "nvcuda.dll";
static NVML_DLL: &'static str = "nvml.dll"; static NVML_DLL: &'static str = "nvml.dll";
include!("../../zluda_redirect/src/payload_guid.rs"); include!("../../zluda_redirect/src/payload_guid.rs");
// Command-line interface of the injector, parsed with argh's `FromArgs` derive.
// NOTE: argh turns the `///` doc comments below into the `--help` text, so they
// are user-facing program output rather than ordinary documentation.
#[derive(FromArgs)] #[derive(FromArgs)]
/// Launch application with custom CUDA libraries /// Launch application with custom CUDA libraries
struct ProgramArguments { struct ProgramArguments {
/// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory /// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory
#[argh(option)] #[argh(option)]
nvcuda: Option<PathBuf>, nvcuda: Option<PathBuf>,
/// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory /// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory
#[argh(option)] #[argh(option)]
nvml: Option<PathBuf>, nvml: Option<PathBuf>,
/// executable to be injected with custom CUDA libraries /// executable to be injected with custom CUDA libraries
#[argh(positional)] #[argh(positional)]
exe: String, exe: String,
/// arguments to the executable /// arguments to the executable
#[argh(positional)] #[argh(positional)]
args: Vec<String>, args: Vec<String>,
} }
pub fn main_impl() -> Result<(), Box<dyn Error>> { pub fn main_impl() -> Result<(), Box<dyn Error>> {
let raw_args = argh::from_env::<ProgramArguments>(); let raw_args = argh::from_env::<ProgramArguments>();
let normalized_args = NormalizedArguments::new(raw_args)?; let normalized_args = NormalizedArguments::new(raw_args)?;
let mut environment = Environment::setup(normalized_args)?; let mut environment = Environment::setup(normalized_args)?;
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() }; let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() }; let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
let mut dlls_to_inject = [ let mut dlls_to_inject = [
environment.nvml_path_zero_terminated.as_ptr() as *const i8, environment.nvml_path_zero_terminated.as_ptr() as *const i8,
environment.nvcuda_path_zero_terminated.as_ptr() as _, environment.nvcuda_path_zero_terminated.as_ptr() as _,
environment.redirect_path_zero_terminated.as_ptr() as _, environment.redirect_path_zero_terminated.as_ptr() as _,
]; ];
os_call!( os_call!(
detours_sys::DetourCreateProcessWithDllsW( detours_sys::DetourCreateProcessWithDllsW(
ptr::null(), ptr::null(),
environment.winapi_command_line_zero_terminated.as_mut_ptr(), environment.winapi_command_line_zero_terminated.as_mut_ptr(),
ptr::null_mut(), ptr::null_mut(),
ptr::null_mut(), ptr::null_mut(),
0, 0,
0, 0,
ptr::null_mut(), ptr::null_mut(),
ptr::null(), ptr::null(),
&mut startup_info as *mut _, &mut startup_info as *mut _,
&mut proc_info as *mut _, &mut proc_info as *mut _,
dlls_to_inject.len() as u32, dlls_to_inject.len() as u32,
dlls_to_inject.as_mut_ptr(), dlls_to_inject.as_mut_ptr(),
Option::None Option::None
), ),
|x| x != 0 |x| x != 0
); );
kill_child_on_process_exit(proc_info.hProcess)?; kill_child_on_process_exit(proc_info.hProcess)?;
os_call!( os_call!(
detours_sys::DetourCopyPayloadToProcess( detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess, proc_info.hProcess,
&PAYLOAD_NVCUDA_GUID, &PAYLOAD_NVCUDA_GUID,
environment.nvcuda_path_zero_terminated.as_ptr() as *mut _, environment.nvcuda_path_zero_terminated.as_ptr() as *mut _,
environment.nvcuda_path_zero_terminated.len() as u32 environment.nvcuda_path_zero_terminated.len() as u32
), ),
|x| x != 0 |x| x != 0
); );
os_call!( os_call!(
detours_sys::DetourCopyPayloadToProcess( detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess, proc_info.hProcess,
&PAYLOAD_NVML_GUID, &PAYLOAD_NVML_GUID,
environment.nvml_path_zero_terminated.as_ptr() as *mut _, environment.nvml_path_zero_terminated.as_ptr() as *mut _,
environment.nvml_path_zero_terminated.len() as u32 environment.nvml_path_zero_terminated.len() as u32
), ),
|x| x != 0 |x| x != 0
); );
os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1); os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1);
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x
!= WAIT_FAILED); != WAIT_FAILED);
let mut child_exit_code: u32 = 0; let mut child_exit_code: u32 = 0;
os_call!( os_call!(
GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _), GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _),
|x| x != 0 |x| x != 0
); );
process::exit(child_exit_code as i32) process::exit(child_exit_code as i32)
} }
struct NormalizedArguments { struct NormalizedArguments {
nvml_path: PathBuf, nvml_path: PathBuf,
nvcuda_path: PathBuf, nvcuda_path: PathBuf,
redirect_path: PathBuf, redirect_path: PathBuf,
winapi_command_line_zero_terminated: Vec<u16>, winapi_command_line_zero_terminated: Vec<u16>,
} }
impl NormalizedArguments { impl NormalizedArguments {
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> { fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
let current_exe = env::current_exe()?; let current_exe = env::current_exe()?;
let nvml_path = Self::get_absolute_path(&current_exe, prog_args.nvml, NVML_DLL)?; let nvml_path = Self::get_absolute_path(&current_exe, prog_args.nvml, NVML_DLL)?;
let nvcuda_path = Self::get_absolute_path(&current_exe, prog_args.nvcuda, NVCUDA_DLL)?; let nvcuda_path = Self::get_absolute_path(&current_exe, prog_args.nvcuda, NVCUDA_DLL)?;
let winapi_command_line_zero_terminated = let winapi_command_line_zero_terminated =
construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args)); construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args));
let mut redirect_path = current_exe.parent().unwrap().to_path_buf(); let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
redirect_path.push(REDIRECT_DLL); redirect_path.push(REDIRECT_DLL);
Ok(Self { Ok(Self {
nvml_path, nvml_path,
nvcuda_path, nvcuda_path,
redirect_path, redirect_path,
winapi_command_line_zero_terminated, winapi_command_line_zero_terminated,
}) })
} }
const WIN_MAX_PATH: usize = 260; const WIN_MAX_PATH: usize = 260;
fn get_absolute_path( fn get_absolute_path(
current_exe: &PathBuf, current_exe: &PathBuf,
dll: Option<PathBuf>, dll: Option<PathBuf>,
default: &str, default: &str,
) -> Result<PathBuf, Box<dyn Error>> { ) -> Result<PathBuf, Box<dyn Error>> {
Ok(if let Some(dll) = dll { Ok(if let Some(dll) = dll {
if dll.is_absolute() { if dll.is_absolute() {
dll dll
} else { } else {
let mut full_dll_path = vec![0; Self::WIN_MAX_PATH]; let mut full_dll_path = vec![0; Self::WIN_MAX_PATH];
let mut dll_utf16 = dll.as_os_str().encode_wide().collect::<Vec<_>>(); let mut dll_utf16 = dll.as_os_str().encode_wide().collect::<Vec<_>>();
dll_utf16.push(0); dll_utf16.push(0);
loop { loop {
let copied_len = os_call!( let copied_len = os_call!(
SearchPathW( SearchPathW(
ptr::null_mut(), ptr::null_mut(),
dll_utf16.as_ptr(), dll_utf16.as_ptr(),
ptr::null(), ptr::null(),
full_dll_path.len() as u32, full_dll_path.len() as u32,
full_dll_path.as_mut_ptr(), full_dll_path.as_mut_ptr(),
ptr::null_mut() ptr::null_mut()
), ),
|x| x != 0 |x| x != 0
) as usize; ) as usize;
if copied_len > full_dll_path.len() { if copied_len > full_dll_path.len() {
full_dll_path.resize(copied_len + 1, 0); full_dll_path.resize(copied_len + 1, 0);
} else { } else {
full_dll_path.truncate(copied_len); full_dll_path.truncate(copied_len);
break; break;
} }
} }
PathBuf::from(String::from_utf16_lossy(&full_dll_path)) PathBuf::from(String::from_utf16_lossy(&full_dll_path))
} }
} else { } else {
let mut dll_path = current_exe.parent().unwrap().to_path_buf(); let mut dll_path = current_exe.parent().unwrap().to_path_buf();
dll_path.push(default); dll_path.push(default);
dll_path dll_path
}) })
} }
} }
struct Environment { struct Environment {
nvml_path_zero_terminated: String, nvml_path_zero_terminated: String,
nvcuda_path_zero_terminated: String, nvcuda_path_zero_terminated: String,
redirect_path_zero_terminated: String, redirect_path_zero_terminated: String,
winapi_command_line_zero_terminated: Vec<u16>, winapi_command_line_zero_terminated: Vec<u16>,
_temp_dir: TempDir, _temp_dir: TempDir,
} }
// This structs represents "enviroment". By environment we mean all paths // This structs represents "enviroment". By environment we mean all paths
// (nvcuda.dll, nvml.dll, etc.) and all related resources like the temporary // (nvcuda.dll, nvml.dll, etc.) and all related resources like the temporary
// directory which contains nvcuda.dll // directory which contains nvcuda.dll
impl Environment { impl Environment {
fn setup(args: NormalizedArguments) -> io::Result<Self> { fn setup(args: NormalizedArguments) -> io::Result<Self> {
let _temp_dir = TempDir::new()?; let _temp_dir = TempDir::new()?;
let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
args.nvml_path, args.nvml_path,
&_temp_dir, &_temp_dir,
NVML_DLL, NVML_DLL,
)?); )?);
let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
args.nvcuda_path, args.nvcuda_path,
&_temp_dir, &_temp_dir,
NVCUDA_DLL, NVCUDA_DLL,
)?); )?);
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path); let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
Ok(Self { Ok(Self {
nvml_path_zero_terminated, nvml_path_zero_terminated,
nvcuda_path_zero_terminated, nvcuda_path_zero_terminated,
redirect_path_zero_terminated, redirect_path_zero_terminated,
winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated, winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated,
_temp_dir, _temp_dir,
}) })
} }
fn copy_to_correct_name( fn copy_to_correct_name(
path_buf: PathBuf, path_buf: PathBuf,
temp_dir: &TempDir, temp_dir: &TempDir,
correct_name: &str, correct_name: &str,
) -> io::Result<PathBuf> { ) -> io::Result<PathBuf> {
let file_name = path_buf.file_name().unwrap(); let file_name = path_buf.file_name().unwrap();
if file_name == correct_name { if file_name == correct_name {
Ok(path_buf) Ok(path_buf)
} else { } else {
let mut temp_file_path = temp_dir.path().to_path_buf(); let mut temp_file_path = temp_dir.path().to_path_buf();
temp_file_path.push(correct_name); temp_file_path.push(correct_name);
match windows::fs::symlink_file(&path_buf, &temp_file_path) { match windows::fs::symlink_file(&path_buf, &temp_file_path) {
Ok(()) => {} Ok(()) => {}
Err(_) => { Err(_) => {
fs::copy(&path_buf, &temp_file_path)?; fs::copy(&path_buf, &temp_file_path)?;
} }
} }
Ok(temp_file_path) Ok(temp_file_path)
} }
} }
fn zero_terminate(p: PathBuf) -> String { fn zero_terminate(p: PathBuf) -> String {
let mut s = p.to_string_lossy().to_string(); let mut s = p.to_string_lossy().to_string();
s.push('\0'); s.push('\0');
s s
} }
} }
fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> { fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
!= ptr::null_mut()); != ptr::null_mut());
let mut info = unsafe { mem::zeroed::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() }; let mut info = unsafe { mem::zeroed::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() };
info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
os_call!( os_call!(
SetInformationJobObject( SetInformationJobObject(
job_handle, job_handle,
JobObjectExtendedLimitInformation, JobObjectExtendedLimitInformation,
&mut info as *mut _ as *mut _, &mut info as *mut _ as *mut _,
size_of_val(&info) as u32 size_of_val(&info) as u32
), ),
|x| x != 0 |x| x != 0
); );
os_call!(AssignProcessToJobObject(job_handle, child), |x| x != 0); os_call!(AssignProcessToJobObject(job_handle, child), |x| x != 0);
Ok(()) Ok(())
} }
// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
/// Joins `args` into a single NUL-terminated UTF-16 Windows command line.
///
/// Arguments are separated by one space. An argument that is empty or
/// contains a space, tab, newline, vertical tab or double quote is wrapped in
/// double quotes; inside the quotes, backslashes that precede a quote (or the
/// closing quote) are doubled and embedded quotes are escaped with a
/// backslash. All other arguments are copied verbatim.
fn construct_command_line(args: impl Iterator<Item = String>) -> Vec<u16> {
    const BACKSLASH: u16 = b'\\' as u16;
    const QUOTE: u16 = b'"' as u16;
    const SPACE: u16 = b' ' as u16;

    let mut cmd_line = Vec::new();
    for (idx, arg) in args.enumerate() {
        // Emit the separator before every argument but the first, so the
        // result does not depend on the iterator's `size_hint`.
        if idx > 0 {
            cmd_line.push(SPACE);
        }
        // An empty argument must be written as `""`, or it would vanish.
        let needs_quoting = arg.is_empty() || arg.contains(&[' ', '\t', '\n', '\x0B', '"'][..]);
        if !needs_quoting {
            cmd_line.extend(arg.encode_utf16());
            continue;
        }
        cmd_line.push(QUOTE);
        // Backslashes are only special in front of a quote, so hold them back
        // until the next character is known.
        let mut pending_backslashes = 0usize;
        for c in arg.chars() {
            match c {
                '\\' => pending_backslashes += 1,
                '"' => {
                    // Double the backslashes, plus one more to escape the quote.
                    cmd_line.extend(std::iter::repeat(BACKSLASH).take(pending_backslashes * 2 + 1));
                    cmd_line.push(QUOTE);
                    pending_backslashes = 0;
                }
                _ => {
                    // Backslashes in front of an ordinary character are literal.
                    cmd_line.extend(std::iter::repeat(BACKSLASH).take(pending_backslashes));
                    pending_backslashes = 0;
                    let mut units = [0u16; 2];
                    cmd_line.extend_from_slice(c.encode_utf16(&mut units));
                }
            }
        }
        // Trailing backslashes are doubled so they do not escape the closing
        // quote.
        cmd_line.extend(std::iter::repeat(BACKSLASH).take(pending_backslashes * 2));
        cmd_line.push(QUOTE);
    }
    cmd_line.push(0);
    cmd_line
}

View File

@ -1,13 +1,13 @@
// Windows-only implementation modules. `win` carries `#[macro_use]` and is
// declared first, so the macros it defines are visible to modules declared
// after it.
#[macro_use]
#[cfg(target_os = "windows")]
mod win;
#[cfg(target_os = "windows")]
mod bin;

/// Windows entry point: runs `bin::main_impl` and propagates its error, if any.
#[cfg(target_os = "windows")]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    bin::main_impl()?;
    Ok(())
}

/// Entry point on every other target: does nothing.
#[cfg(not(target_os = "windows"))]
fn main() {}

View File

@ -1,151 +1,151 @@
#![allow(non_snake_case)] #![allow(non_snake_case)]
use std::error; use std::error;
use std::fmt; use std::fmt;
use std::ptr; use std::ptr;
/// Hand-written declarations of the Win32 types, constants and functions
/// referenced through the `c::` prefix.
mod c {
    use std::ffi::c_void;
    use std::os::raw::c_ulong;

    // Pointer-like aliases.
    pub type LPVOID = *mut c_void;
    pub type HANDLE = LPVOID;
    pub type HINSTANCE = HANDLE;
    pub type HMODULE = HINSTANCE;

    // Integer and wide-character aliases.
    pub type DWORD = c_ulong;
    pub type WCHAR = u16;
    pub type LPCWSTR = *const WCHAR;
    pub type LPWSTR = *mut WCHAR;

    // `FormatMessageW` flag bits.
    pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x0200;
    pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x0800;
    pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x1000;

    // Bit tested on an error code to detect an NTSTATUS value.
    pub const FACILITY_NT_BIT: DWORD = 0x1000_0000;

    extern "system" {
        pub fn FormatMessageW(
            flags: DWORD,
            lpSrc: LPVOID,
            msgId: DWORD,
            langId: DWORD,
            buf: LPWSTR,
            nsize: DWORD,
            args: *const c_void,
        ) -> DWORD;
        pub fn GetLastError() -> DWORD;
        pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE;
    }
}
/// Expands to the string form (via `stringify!`) of the last identifier in a
/// non-empty, comma-separated list of identifiers.
macro_rules! last_ident {
    // Two or more identifiers: drop the first and recurse on the rest.
    ($head:ident, $($tail:ident),+) => {
        last_ident!($($tail),+)
    };
    // A single identifier: this is the last one.
    ($last:ident) => {
        stringify!($last)
    };
}
/// Calls an OS function inside `unsafe` and checks its result.
///
/// Usage: `os_call!(path::to::function(arg, ...), success_predicate)`.
///
/// The call's return value is passed to `success_predicate`. When the
/// predicate returns `false`, an `OsError` is built from the last path
/// segment (as the function name), `errno()` and `error_string()`, and is
/// propagated out of the enclosing function with `?`. Otherwise the macro
/// evaluates to the call's return value.
macro_rules! os_call {
    ($($segment:ident)::+ ($($arg:expr),*), $is_ok:expr) => {{
        let ret = unsafe { $($segment)::+ ($($arg),*) };
        if !($is_ok)(ret) {
            // Read the error code right away, before anything else runs.
            let code = $crate::win::errno();
            let failure = $crate::win::OsError {
                function: last_ident!($($segment),+),
                error_code: code as u32,
                message: $crate::win::error_string(code),
            };
            Err(failure)?;
        }
        ret
    }};
}
/// Describes a failed OS call: which function failed, with what error code,
/// and a textual description of that code.
#[derive(Debug)]
pub struct OsError {
    /// Name of the function that failed.
    pub function: &'static str,
    /// Numeric error code.
    pub error_code: u32,
    /// Text describing the error.
    pub message: String,
}

impl fmt::Display for OsError {
    /// Writes the same text as the derived `Debug` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}

// No underlying cause is recorded, so the default `source()` (which returns
// `None`) is used.
impl error::Error for OsError {}
pub fn errno() -> i32 { pub fn errno() -> i32 {
unsafe { c::GetLastError() as i32 } unsafe { c::GetLastError() as i32 }
} }
/// Gets a detailed string description for the given error number. /// Gets a detailed string description for the given error number.
pub fn error_string(mut errnum: i32) -> String { pub fn error_string(mut errnum: i32) -> String {
// This value is calculated from the macro // This value is calculated from the macro
// MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT) // MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT)
let langId = 0x0800 as c::DWORD; let langId = 0x0800 as c::DWORD;
let mut buf = [0 as c::WCHAR; 2048]; let mut buf = [0 as c::WCHAR; 2048];
unsafe { unsafe {
let mut module = ptr::null_mut(); let mut module = ptr::null_mut();
let mut flags = 0; let mut flags = 0;
// NTSTATUS errors may be encoded as HRESULT, which may returned from // NTSTATUS errors may be encoded as HRESULT, which may returned from
// GetLastError. For more information about Windows error codes, see // GetLastError. For more information about Windows error codes, see
// `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx // `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx
if (errnum & c::FACILITY_NT_BIT as i32) != 0 { if (errnum & c::FACILITY_NT_BIT as i32) != 0 {
// format according to https://support.microsoft.com/en-us/help/259693 // format according to https://support.microsoft.com/en-us/help/259693
const NTDLL_DLL: &[u16] = &[ const NTDLL_DLL: &[u16] = &[
'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _, 'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _,
'L' as _, 0, 'L' as _, 0,
]; ];
module = c::GetModuleHandleW(NTDLL_DLL.as_ptr()); module = c::GetModuleHandleW(NTDLL_DLL.as_ptr());
if module != ptr::null_mut() { if module != ptr::null_mut() {
errnum ^= c::FACILITY_NT_BIT as i32; errnum ^= c::FACILITY_NT_BIT as i32;
flags = c::FORMAT_MESSAGE_FROM_HMODULE; flags = c::FORMAT_MESSAGE_FROM_HMODULE;
} }
} }
let res = c::FormatMessageW( let res = c::FormatMessageW(
flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS, flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS,
module, module,
errnum as c::DWORD, errnum as c::DWORD,
langId, langId,
buf.as_mut_ptr(), buf.as_mut_ptr(),
buf.len() as c::DWORD, buf.len() as c::DWORD,
ptr::null(), ptr::null(),
) as usize; ) as usize;
if res == 0 { if res == 0 {
// Sometimes FormatMessageW can fail e.g., system doesn't like langId, // Sometimes FormatMessageW can fail e.g., system doesn't like langId,
let fm_err = errno(); let fm_err = errno();
return format!( return format!(
"OS Error {} (FormatMessageW() returned error {})", "OS Error {} (FormatMessageW() returned error {})",
errnum, fm_err errnum, fm_err
); );
} }
match String::from_utf16(&buf[..res]) { match String::from_utf16(&buf[..res]) {
Ok(mut msg) => { Ok(mut msg) => {
// Trim trailing CRLF inserted by FormatMessageW // Trim trailing CRLF inserted by FormatMessageW
let len = msg.trim_end().len(); let len = msg.trim_end().len();
msg.truncate(len); msg.truncate(len);
msg msg
} }
Err(..) => format!( Err(..) => format!(
"OS Error {} (FormatMessageW() returned \ "OS Error {} (FormatMessageW() returned \
invalid UTF-16)", invalid UTF-16)",
errnum errnum
), ),
} }
} }
} }

View File

@ -1,51 +1,51 @@
#![cfg(windows)] #![cfg(windows)]
use std::{env, io, path::PathBuf, process::Command}; use std::{env, io, path::PathBuf, process::Command};
/// Declares one `#[test]` function per listed identifier. Each test passes
/// its own name to `run_process_and_check_for_zluda_dump`.
macro_rules! zluda_dump_tests {
    ($($helper:ident),+ $(,)?) => {
        $(
            #[test]
            fn $helper() -> io::Result<()> {
                run_process_and_check_for_zluda_dump(stringify!($helper))
            }
        )+
    };
}

zluda_dump_tests!(
    direct_cuinit,
    do_cuinit_early,
    do_cuinit_late,
    do_cuinit_late_clr,
    indirect_cuinit,
    subprocess,
);
/// Runs `<HELPERS_OUT_DIR>/<name>.exe` through the `zluda_with` binary and
/// checks the outcome.
///
/// `zluda_with` is invoked as
/// `zluda_with --nvcuda <dir of zluda_with>/zluda_dump.dll -- <helper exe>`.
/// The test passes when the process exits successfully and its stderr
/// contains the text `ZLUDA_DUMP`.
///
/// # Errors
///
/// Returns the I/O error if the process cannot be spawned or waited on.
///
/// # Panics
///
/// Panics if the exit status is not success, if stderr is not valid UTF-8,
/// or if stderr lacks the `ZLUDA_DUMP` marker.
fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
    let launcher = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with"));
    // The dump DLL is looked up in the same directory as the launcher.
    let dump_dll = launcher.parent().unwrap().join("zluda_dump.dll");
    let helper_exe = format!(
        "{}{}{}.exe",
        env!("HELPERS_OUT_DIR"),
        std::path::MAIN_SEPARATOR,
        name
    );

    let output = Command::new(&launcher)
        .arg("--nvcuda")
        .arg(&dump_dll)
        .arg("--")
        .arg(&helper_exe)
        .output()?;

    assert!(output.status.success());
    let stderr_text = String::from_utf8(output.stderr).unwrap();
    assert!(stderr_text.contains("ZLUDA_DUMP"));
    Ok(())
}