Always use Unix line endings (#453)

This commit is contained in:
Violet
2025-07-30 15:09:47 -07:00
committed by GitHub
parent 21ef5f60a3
commit b8bcbec295
28 changed files with 5208 additions and 5207 deletions

1
.rustfmt.toml Normal file
View File

@ -0,0 +1 @@
newline_style = "Unix"

View File

@ -1,147 +1,147 @@
use cmake::Config;
use std::io;
use std::path::PathBuf;
use std::process::Command;
const COMPONENTS: &[&'static str] = &[
"LLVMCore",
"LLVMBitWriter",
#[cfg(debug_assertions)]
"LLVMAnalysis", // for module verify
#[cfg(debug_assertions)]
"LLVMBitReader",
];
fn main() {
let mut cmake = Config::new(r"../ext/llvm-project/llvm");
try_use_sccache(&mut cmake);
try_use_ninja(&mut cmake);
cmake
// It's not like we can do anything about the warnings
.define("LLVM_ENABLE_WARNINGS", "OFF")
// For some reason Rust always links to release CRT
.define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreaded")
.define("LLVM_ENABLE_TERMINFO", "OFF")
.define("LLVM_ENABLE_LIBXML2", "OFF")
.define("LLVM_ENABLE_LIBEDIT", "OFF")
.define("LLVM_ENABLE_LIBPFM", "OFF")
.define("LLVM_ENABLE_ZLIB", "OFF")
.define("LLVM_ENABLE_ZSTD", "OFF")
.define("LLVM_INCLUDE_BENCHMARKS", "OFF")
.define("LLVM_INCLUDE_EXAMPLES", "OFF")
.define("LLVM_INCLUDE_TESTS", "OFF")
.define("LLVM_BUILD_TOOLS", "OFF")
.define("LLVM_TARGETS_TO_BUILD", "")
.define("LLVM_ENABLE_PROJECTS", "");
cmake.build_target("llvm-config");
let llvm_dir = cmake.build();
for c in COMPONENTS {
cmake.build_target(c);
cmake.build();
}
let cmake_profile = cmake.get_profile();
let (cxxflags, ldflags, libdir, lib_names, system_libs) =
llvm_config(&llvm_dir, &["build", "bin", "llvm-config"])
.or_else(|_| llvm_config(&llvm_dir, &["build", cmake_profile, "bin", "llvm-config"]))
.unwrap();
println!("cargo:rustc-link-arg={ldflags}");
println!("cargo:rustc-link-search=native={libdir}");
for lib in system_libs.split_ascii_whitespace() {
println!("cargo:rustc-link-arg={lib}");
}
link_llvm_components(lib_names);
compile_cxx_lib(cxxflags);
}
// https://github.com/mozilla/sccache/blob/main/README.md#usage
fn try_use_sccache(cmake: &mut Config) {
if let Ok(sccache) = std::env::var("SCCACHE_PATH") {
cmake.define("CMAKE_CXX_COMPILER_LAUNCHER", &*sccache);
cmake.define("CMAKE_C_COMPILER_LAUNCHER", &*sccache);
match std::env::var_os("CARGO_CFG_TARGET_OS") {
Some(os) if os == "windows" => {
cmake.define("CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "Embedded");
cmake.define("CMAKE_POLICY_CMP0141", "NEW");
}
_ => {}
}
}
}
fn try_use_ninja(cmake: &mut Config) {
let mut cmd = Command::new("ninja");
cmd.arg("--version");
if let Ok(status) = cmd.status() {
if status.success() {
cmake.generator("Ninja");
}
}
}
fn llvm_config(
llvm_build_dir: &PathBuf,
path_to_llvm_config: &[&str],
) -> io::Result<(String, String, String, String, String)> {
let mut llvm_build_path = llvm_build_dir.clone();
llvm_build_path.extend(path_to_llvm_config);
let mut cmd = Command::new(llvm_build_path);
cmd.args([
"--link-static",
"--cxxflags",
"--ldflags",
"--libdir",
"--libnames",
"--system-libs",
]);
for c in COMPONENTS {
cmd.arg(c[4..].to_lowercase());
}
let output = cmd.output()?;
if !output.status.success() {
return Err(io::Error::from(io::ErrorKind::Other));
}
let output = unsafe { String::from_utf8_unchecked(output.stdout) };
let mut lines = output.lines();
let cxxflags = lines.next().unwrap();
let ldflags = lines.next().unwrap();
let libdir = lines.next().unwrap();
let lib_names = lines.next().unwrap();
let system_libs = lines.next().unwrap();
Ok((
cxxflags.to_string(),
ldflags.to_string(),
libdir.to_string(),
lib_names.to_string(),
system_libs.to_string(),
))
}
fn compile_cxx_lib(cxxflags: String) {
let mut cc = cc::Build::new();
for flag in cxxflags.split_whitespace() {
cc.flag(flag);
}
cc.cpp(true).file("src/lib.cpp").compile("llvm_zluda_cpp");
println!("cargo:rerun-if-changed=src/lib.cpp");
println!("cargo:rerun-if-changed=src/lib.rs");
}
fn link_llvm_components(components: String) {
for component in components.split_whitespace() {
let component = if let Some(component) = component
.strip_prefix("lib")
.and_then(|component| component.strip_suffix(".a"))
{
// Unix (Linux/Mac)
// libLLVMfoo.a
component
} else if let Some(component) = component.strip_suffix(".lib") {
// Windows
// LLVMfoo.lib
component
} else {
panic!("'{}' does not look like a static library name", component)
};
println!("cargo:rustc-link-lib={component}");
}
}
use cmake::Config;
use std::io;
use std::path::PathBuf;
use std::process::Command;
const COMPONENTS: &[&'static str] = &[
"LLVMCore",
"LLVMBitWriter",
#[cfg(debug_assertions)]
"LLVMAnalysis", // for module verify
#[cfg(debug_assertions)]
"LLVMBitReader",
];
fn main() {
let mut cmake = Config::new(r"../ext/llvm-project/llvm");
try_use_sccache(&mut cmake);
try_use_ninja(&mut cmake);
cmake
// It's not like we can do anything about the warnings
.define("LLVM_ENABLE_WARNINGS", "OFF")
// For some reason Rust always links to release CRT
.define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreaded")
.define("LLVM_ENABLE_TERMINFO", "OFF")
.define("LLVM_ENABLE_LIBXML2", "OFF")
.define("LLVM_ENABLE_LIBEDIT", "OFF")
.define("LLVM_ENABLE_LIBPFM", "OFF")
.define("LLVM_ENABLE_ZLIB", "OFF")
.define("LLVM_ENABLE_ZSTD", "OFF")
.define("LLVM_INCLUDE_BENCHMARKS", "OFF")
.define("LLVM_INCLUDE_EXAMPLES", "OFF")
.define("LLVM_INCLUDE_TESTS", "OFF")
.define("LLVM_BUILD_TOOLS", "OFF")
.define("LLVM_TARGETS_TO_BUILD", "")
.define("LLVM_ENABLE_PROJECTS", "");
cmake.build_target("llvm-config");
let llvm_dir = cmake.build();
for c in COMPONENTS {
cmake.build_target(c);
cmake.build();
}
let cmake_profile = cmake.get_profile();
let (cxxflags, ldflags, libdir, lib_names, system_libs) =
llvm_config(&llvm_dir, &["build", "bin", "llvm-config"])
.or_else(|_| llvm_config(&llvm_dir, &["build", cmake_profile, "bin", "llvm-config"]))
.unwrap();
println!("cargo:rustc-link-arg={ldflags}");
println!("cargo:rustc-link-search=native={libdir}");
for lib in system_libs.split_ascii_whitespace() {
println!("cargo:rustc-link-arg={lib}");
}
link_llvm_components(lib_names);
compile_cxx_lib(cxxflags);
}
// https://github.com/mozilla/sccache/blob/main/README.md#usage
fn try_use_sccache(cmake: &mut Config) {
if let Ok(sccache) = std::env::var("SCCACHE_PATH") {
cmake.define("CMAKE_CXX_COMPILER_LAUNCHER", &*sccache);
cmake.define("CMAKE_C_COMPILER_LAUNCHER", &*sccache);
match std::env::var_os("CARGO_CFG_TARGET_OS") {
Some(os) if os == "windows" => {
cmake.define("CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "Embedded");
cmake.define("CMAKE_POLICY_CMP0141", "NEW");
}
_ => {}
}
}
}
fn try_use_ninja(cmake: &mut Config) {
let mut cmd = Command::new("ninja");
cmd.arg("--version");
if let Ok(status) = cmd.status() {
if status.success() {
cmake.generator("Ninja");
}
}
}
fn llvm_config(
llvm_build_dir: &PathBuf,
path_to_llvm_config: &[&str],
) -> io::Result<(String, String, String, String, String)> {
let mut llvm_build_path = llvm_build_dir.clone();
llvm_build_path.extend(path_to_llvm_config);
let mut cmd = Command::new(llvm_build_path);
cmd.args([
"--link-static",
"--cxxflags",
"--ldflags",
"--libdir",
"--libnames",
"--system-libs",
]);
for c in COMPONENTS {
cmd.arg(c[4..].to_lowercase());
}
let output = cmd.output()?;
if !output.status.success() {
return Err(io::Error::from(io::ErrorKind::Other));
}
let output = unsafe { String::from_utf8_unchecked(output.stdout) };
let mut lines = output.lines();
let cxxflags = lines.next().unwrap();
let ldflags = lines.next().unwrap();
let libdir = lines.next().unwrap();
let lib_names = lines.next().unwrap();
let system_libs = lines.next().unwrap();
Ok((
cxxflags.to_string(),
ldflags.to_string(),
libdir.to_string(),
lib_names.to_string(),
system_libs.to_string(),
))
}
fn compile_cxx_lib(cxxflags: String) {
let mut cc = cc::Build::new();
for flag in cxxflags.split_whitespace() {
cc.flag(flag);
}
cc.cpp(true).file("src/lib.cpp").compile("llvm_zluda_cpp");
println!("cargo:rerun-if-changed=src/lib.cpp");
println!("cargo:rerun-if-changed=src/lib.rs");
}
fn link_llvm_components(components: String) {
for component in components.split_whitespace() {
let component = if let Some(component) = component
.strip_prefix("lib")
.and_then(|component| component.strip_suffix(".a"))
{
// Unix (Linux/Mac)
// libLLVMfoo.a
component
} else if let Some(component) = component.strip_suffix(".lib") {
// Windows
// LLVMfoo.lib
component
} else {
panic!("'{}' does not look like a static library name", component)
};
println!("cargo:rustc-link-lib={component}");
}
}

View File

@ -1,81 +1,81 @@
#![allow(non_upper_case_globals)]
use llvm_sys::prelude::*;
pub use llvm_sys::*;
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LLVMZludaAtomicRMWBinOp {
LLVMZludaAtomicRMWBinOpXchg = 0,
LLVMZludaAtomicRMWBinOpAdd = 1,
LLVMZludaAtomicRMWBinOpSub = 2,
LLVMZludaAtomicRMWBinOpAnd = 3,
LLVMZludaAtomicRMWBinOpNand = 4,
LLVMZludaAtomicRMWBinOpOr = 5,
LLVMZludaAtomicRMWBinOpXor = 6,
LLVMZludaAtomicRMWBinOpMax = 7,
LLVMZludaAtomicRMWBinOpMin = 8,
LLVMZludaAtomicRMWBinOpUMax = 9,
LLVMZludaAtomicRMWBinOpUMin = 10,
LLVMZludaAtomicRMWBinOpFAdd = 11,
LLVMZludaAtomicRMWBinOpFSub = 12,
LLVMZludaAtomicRMWBinOpFMax = 13,
LLVMZludaAtomicRMWBinOpFMin = 14,
LLVMZludaAtomicRMWBinOpUIncWrap = 15,
LLVMZludaAtomicRMWBinOpUDecWrap = 16,
}
// Backport from LLVM 19
pub const LLVMZludaFastMathAllowReassoc: ::std::ffi::c_uint = 1 << 0;
pub const LLVMZludaFastMathNoNaNs: ::std::ffi::c_uint = 1 << 1;
pub const LLVMZludaFastMathNoInfs: ::std::ffi::c_uint = 1 << 2;
pub const LLVMZludaFastMathNoSignedZeros: ::std::ffi::c_uint = 1 << 3;
pub const LLVMZludaFastMathAllowReciprocal: ::std::ffi::c_uint = 1 << 4;
pub const LLVMZludaFastMathAllowContract: ::std::ffi::c_uint = 1 << 5;
pub const LLVMZludaFastMathApproxFunc: ::std::ffi::c_uint = 1 << 6;
pub const LLVMZludaFastMathNone: ::std::ffi::c_uint = 0;
pub const LLVMZludaFastMathAll: ::std::ffi::c_uint = LLVMZludaFastMathAllowReassoc
| LLVMZludaFastMathNoNaNs
| LLVMZludaFastMathNoInfs
| LLVMZludaFastMathNoSignedZeros
| LLVMZludaFastMathAllowReciprocal
| LLVMZludaFastMathAllowContract
| LLVMZludaFastMathApproxFunc;
pub type LLVMZludaFastMathFlags = std::ffi::c_uint;
extern "C" {
pub fn LLVMZludaBuildAlloca(
B: LLVMBuilderRef,
Ty: LLVMTypeRef,
AddrSpace: u32,
Name: *const i8,
) -> LLVMValueRef;
pub fn LLVMZludaBuildAtomicRMW(
B: LLVMBuilderRef,
op: LLVMZludaAtomicRMWBinOp,
PTR: LLVMValueRef,
Val: LLVMValueRef,
scope: *const i8,
ordering: LLVMAtomicOrdering,
) -> LLVMValueRef;
pub fn LLVMZludaBuildAtomicCmpXchg(
B: LLVMBuilderRef,
Ptr: LLVMValueRef,
Cmp: LLVMValueRef,
New: LLVMValueRef,
scope: *const i8,
SuccessOrdering: LLVMAtomicOrdering,
FailureOrdering: LLVMAtomicOrdering,
) -> LLVMValueRef;
pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags);
pub fn LLVMZludaBuildFence(
B: LLVMBuilderRef,
ordering: LLVMAtomicOrdering,
scope: *const i8,
Name: *const i8,
) -> LLVMValueRef;
}
#![allow(non_upper_case_globals)]
use llvm_sys::prelude::*;
pub use llvm_sys::*;
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LLVMZludaAtomicRMWBinOp {
LLVMZludaAtomicRMWBinOpXchg = 0,
LLVMZludaAtomicRMWBinOpAdd = 1,
LLVMZludaAtomicRMWBinOpSub = 2,
LLVMZludaAtomicRMWBinOpAnd = 3,
LLVMZludaAtomicRMWBinOpNand = 4,
LLVMZludaAtomicRMWBinOpOr = 5,
LLVMZludaAtomicRMWBinOpXor = 6,
LLVMZludaAtomicRMWBinOpMax = 7,
LLVMZludaAtomicRMWBinOpMin = 8,
LLVMZludaAtomicRMWBinOpUMax = 9,
LLVMZludaAtomicRMWBinOpUMin = 10,
LLVMZludaAtomicRMWBinOpFAdd = 11,
LLVMZludaAtomicRMWBinOpFSub = 12,
LLVMZludaAtomicRMWBinOpFMax = 13,
LLVMZludaAtomicRMWBinOpFMin = 14,
LLVMZludaAtomicRMWBinOpUIncWrap = 15,
LLVMZludaAtomicRMWBinOpUDecWrap = 16,
}
// Backport from LLVM 19
pub const LLVMZludaFastMathAllowReassoc: ::std::ffi::c_uint = 1 << 0;
pub const LLVMZludaFastMathNoNaNs: ::std::ffi::c_uint = 1 << 1;
pub const LLVMZludaFastMathNoInfs: ::std::ffi::c_uint = 1 << 2;
pub const LLVMZludaFastMathNoSignedZeros: ::std::ffi::c_uint = 1 << 3;
pub const LLVMZludaFastMathAllowReciprocal: ::std::ffi::c_uint = 1 << 4;
pub const LLVMZludaFastMathAllowContract: ::std::ffi::c_uint = 1 << 5;
pub const LLVMZludaFastMathApproxFunc: ::std::ffi::c_uint = 1 << 6;
pub const LLVMZludaFastMathNone: ::std::ffi::c_uint = 0;
pub const LLVMZludaFastMathAll: ::std::ffi::c_uint = LLVMZludaFastMathAllowReassoc
| LLVMZludaFastMathNoNaNs
| LLVMZludaFastMathNoInfs
| LLVMZludaFastMathNoSignedZeros
| LLVMZludaFastMathAllowReciprocal
| LLVMZludaFastMathAllowContract
| LLVMZludaFastMathApproxFunc;
pub type LLVMZludaFastMathFlags = std::ffi::c_uint;
extern "C" {
pub fn LLVMZludaBuildAlloca(
B: LLVMBuilderRef,
Ty: LLVMTypeRef,
AddrSpace: u32,
Name: *const i8,
) -> LLVMValueRef;
pub fn LLVMZludaBuildAtomicRMW(
B: LLVMBuilderRef,
op: LLVMZludaAtomicRMWBinOp,
PTR: LLVMValueRef,
Val: LLVMValueRef,
scope: *const i8,
ordering: LLVMAtomicOrdering,
) -> LLVMValueRef;
pub fn LLVMZludaBuildAtomicCmpXchg(
B: LLVMBuilderRef,
Ptr: LLVMValueRef,
Cmp: LLVMValueRef,
New: LLVMValueRef,
scope: *const i8,
SuccessOrdering: LLVMAtomicOrdering,
FailureOrdering: LLVMAtomicOrdering,
) -> LLVMValueRef;
pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags);
pub fn LLVMZludaBuildFence(
B: LLVMBuilderRef,
ordering: LLVMAtomicOrdering,
scope: *const i8,
Name: *const i8,
) -> LLVMValueRef;
}

View File

@ -1,191 +1,191 @@
use super::*;
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
})
}
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2,
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let is_declaration = method.body.is_none();
let mut body = Vec::new();
let mut remap_returns = Vec::new();
if !method.is_kernel {
for arg in method.return_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name;
arg.name =
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration {
continue;
}
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
body.push(Statement::Variable(ast::Variable {
align: None,
name: old_name,
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(),
}));
}
ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()),
}
}
for arg in method.input_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name;
arg.name =
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration {
continue;
}
body.push(Statement::Variable(ast::Variable {
align: None,
name: old_name,
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(),
}));
body.push(Statement::Instruction(ast::Instruction::St {
data: ast::StData {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::StCacheOperator::Writethrough,
typ: arg.v_type.clone(),
},
arguments: ast::StArgs {
src1: old_name,
src2: arg.name,
},
}));
}
ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()),
}
}
}
let body = method
.body
.map(|statements| {
for statement in statements {
run_statement(resolver, &remap_returns, &mut body, statement)?;
}
Ok::<_, TranslateError>(body)
})
.transpose()?;
Ok(Function2 { body, ..method })
}
fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
match statement {
Statement::Instruction(ast::Instruction::Call {
mut data,
mut arguments,
}) => {
let mut post_st = Vec::new();
for ((type_, space), ident) in data
.input_arguments
.iter_mut()
.zip(arguments.input_arguments.iter_mut())
{
if *space == ptx_parser::StateSpace::Param {
*space = ptx_parser::StateSpace::Reg;
let old_name = *ident;
*ident = resolver
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
result.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::LdCacheOperator::Cached,
typ: type_.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
dst: *ident,
src: old_name,
},
}));
}
}
for ((type_, space), ident) in data
.return_arguments
.iter_mut()
.zip(arguments.return_arguments.iter_mut())
{
if *space == ptx_parser::StateSpace::Param {
*space = ptx_parser::StateSpace::Reg;
let old_name = *ident;
*ident = resolver
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
post_st.push(Statement::Instruction(ast::Instruction::St {
data: ast::StData {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::StCacheOperator::Writethrough,
typ: type_.clone(),
},
arguments: ast::StArgs {
src1: old_name,
src2: *ident,
},
}));
}
}
result.push(Statement::Instruction(ast::Instruction::Call {
data,
arguments,
}));
result.extend(post_st.into_iter());
}
Statement::Instruction(ast::Instruction::Ret { data }) => {
for (old_name, new_name, type_) in remap_returns.iter() {
result.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::LdCacheOperator::Cached,
typ: type_.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
dst: *new_name,
src: *old_name,
},
}));
}
result.push(Statement::Instruction(ast::Instruction::Ret { data }));
}
statement => {
result.push(statement);
}
}
Ok(())
}
use super::*;
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
})
}
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2,
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let is_declaration = method.body.is_none();
let mut body = Vec::new();
let mut remap_returns = Vec::new();
if !method.is_kernel {
for arg in method.return_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name;
arg.name =
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration {
continue;
}
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
body.push(Statement::Variable(ast::Variable {
align: None,
name: old_name,
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(),
}));
}
ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()),
}
}
for arg in method.input_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name;
arg.name =
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration {
continue;
}
body.push(Statement::Variable(ast::Variable {
align: None,
name: old_name,
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(),
}));
body.push(Statement::Instruction(ast::Instruction::St {
data: ast::StData {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::StCacheOperator::Writethrough,
typ: arg.v_type.clone(),
},
arguments: ast::StArgs {
src1: old_name,
src2: arg.name,
},
}));
}
ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()),
}
}
}
let body = method
.body
.map(|statements| {
for statement in statements {
run_statement(resolver, &remap_returns, &mut body, statement)?;
}
Ok::<_, TranslateError>(body)
})
.transpose()?;
Ok(Function2 { body, ..method })
}
fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
match statement {
Statement::Instruction(ast::Instruction::Call {
mut data,
mut arguments,
}) => {
let mut post_st = Vec::new();
for ((type_, space), ident) in data
.input_arguments
.iter_mut()
.zip(arguments.input_arguments.iter_mut())
{
if *space == ptx_parser::StateSpace::Param {
*space = ptx_parser::StateSpace::Reg;
let old_name = *ident;
*ident = resolver
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
result.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::LdCacheOperator::Cached,
typ: type_.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
dst: *ident,
src: old_name,
},
}));
}
}
for ((type_, space), ident) in data
.return_arguments
.iter_mut()
.zip(arguments.return_arguments.iter_mut())
{
if *space == ptx_parser::StateSpace::Param {
*space = ptx_parser::StateSpace::Reg;
let old_name = *ident;
*ident = resolver
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
post_st.push(Statement::Instruction(ast::Instruction::St {
data: ast::StData {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::StCacheOperator::Writethrough,
typ: type_.clone(),
},
arguments: ast::StArgs {
src1: old_name,
src2: *ident,
},
}));
}
}
result.push(Statement::Instruction(ast::Instruction::Call {
data,
arguments,
}));
result.extend(post_st.into_iter());
}
Statement::Instruction(ast::Instruction::Ret { data }) => {
for (old_name, new_name, type_) in remap_returns.iter() {
result.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::LdCacheOperator::Cached,
typ: type_.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
dst: *new_name,
src: *old_name,
},
}));
}
result.push(Statement::Instruction(ast::Instruction::Ret { data }));
}
statement => {
result.push(statement);
}
}
Ok(())
}

View File

@ -1,301 +1,301 @@
use super::*;
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<UnconditionalDirective>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
})
}
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
method: Function2<
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let body = method
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(resolver, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 {
body,
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
flush_to_zero_f32: method.flush_to_zero_f32,
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
rounding_mode_f32: method.rounding_mode_f32,
rounding_mode_f16f64: method.rounding_mode_f16f64,
})
}
fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
statement: UnconditionalStatement,
) -> Result<(), TranslateError> {
let mut visitor = FlattenArguments::new(resolver, result);
let new_statement = statement.visit_map(&mut visitor)?;
visitor.result.push(new_statement);
Ok(())
}
struct FlattenArguments<'a, 'input> {
result: &'a mut Vec<ExpandedStatement>,
resolver: &'a mut GlobalStringIdentResolver2<'input>,
post_stmts: Vec<ExpandedStatement>,
}
impl<'a, 'input> FlattenArguments<'a, 'input> {
fn new(
resolver: &'a mut GlobalStringIdentResolver2<'input>,
result: &'a mut Vec<ExpandedStatement>,
) -> Self {
FlattenArguments {
result,
resolver,
post_stmts: Vec::new(),
}
}
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
Ok(name)
}
fn reg_offset(
&mut self,
reg: SpirvWord,
offset: i32,
type_space: Option<(&ast::Type, ast::StateSpace)>,
_is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
(type_, state_space)
} else {
return Err(TranslateError::UntypedSymbol);
};
if state_space == ast::StateSpace::Reg {
let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
if *reg_space != ast::StateSpace::Reg {
return Err(error_mismatched_type());
}
let reg_scalar_type = match reg_type {
ast::Type::Scalar(underlying_type) => *underlying_type,
_ => return Err(error_mismatched_type()),
};
let reg_type = reg_type.clone();
let id_constant_stmt = self
.resolver
.register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: reg_scalar_type,
value: ast::ImmediateValue::S64(offset as i64),
}));
let arith_details = match reg_scalar_type.kind() {
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
type_: reg_scalar_type,
saturate: false,
}),
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
ast::ArithDetails::Integer(ast::ArithInteger {
type_: reg_scalar_type,
saturate: false,
})
}
_ => return Err(error_unreachable()),
};
let id_add_result = self
.resolver
.register_unnamed(Some((reg_type, state_space)));
self.result
.push(Statement::Instruction(ast::Instruction::Add {
data: arith_details,
arguments: ast::AddArgs {
dst: id_add_result,
src1: reg,
src2: id_constant_stmt,
},
}));
Ok(id_add_result)
} else {
let id_constant_stmt = self.resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::S64),
ast::StateSpace::Reg,
)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::S64,
value: ast::ImmediateValue::S64(offset as i64),
}));
let dst = self
.resolver
.register_unnamed(Some((type_.clone(), state_space)));
self.result.push(Statement::PtrAccess(PtrAccess {
underlying_type: type_.clone(),
state_space: state_space,
dst,
ptr_src: reg,
offset_src: id_constant_stmt,
}));
Ok(dst)
}
}
fn immediate(
&mut self,
value: ast::ImmediateValue,
type_space: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
let (scalar_t, state_space) =
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
(*scalar, state_space)
} else {
return Err(TranslateError::UntypedSymbol);
};
let id = self
.resolver
.register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
value,
}));
Ok(id)
}
fn vec_member(
&mut self,
vector_ident: SpirvWord,
member: u8,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? {
(ast::Type::Vector(vector_width, scalar_t), space) => {
(*vector_width, *scalar_t, *space)
}
_ => return Err(error_mismatched_type()),
};
let temporary = self
.resolver
.register_unnamed(Some((scalar_type.into(), space)));
if is_dst {
self.post_stmts.push(Statement::VectorWrite(VectorWrite {
scalar_type,
vector_width,
vector_dst: vector_ident,
vector_src: vector_ident,
scalar_src: temporary,
member,
}));
} else {
self.result.push(Statement::VectorRead(VectorRead {
scalar_type,
vector_width,
scalar_dst: temporary,
vector_src: vector_ident,
member,
}));
}
Ok(temporary)
}
fn vec_pack(
&mut self,
vector_elements: Vec<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
let (width, scalar_t, state_space) = match type_space {
Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
_ => return Err(error_mismatched_type()),
};
let temporary_vector = self
.resolver
.register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
packed: temporary_vector,
unpacked: vector_elements,
relaxed_type_check,
});
if is_dst {
self.post_stmts.push(statement);
} else {
self.result.push(statement);
}
Ok(temporary_vector)
}
}
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
for FlattenArguments<'a, 'b>
{
fn visit(
&mut self,
args: ast::ParsedOperand<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
match args {
ast::ParsedOperand::Reg(r) => self.reg(r),
ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
ast::ParsedOperand::RegOffset(reg, offset) => {
self.reg_offset(reg, offset, type_space, is_dst)
}
ast::ParsedOperand::VecMember(vec, member) => {
self.vec_member(vec, member, type_space, is_dst)
}
ast::ParsedOperand::VecPack(vecs) => {
self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
}
}
}
fn visit_ident(
&mut self,
name: SpirvWord,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
_is_dst: bool,
_relaxed_type_check: bool,
) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
self.reg(name)
}
}
impl Drop for FlattenArguments<'_, '_> {
fn drop(&mut self) {
self.result.extend(self.post_stmts.drain(..));
}
}
use super::*;
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<UnconditionalDirective>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
})
}
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
method: Function2<
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let body = method
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(resolver, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 {
body,
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
flush_to_zero_f32: method.flush_to_zero_f32,
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
rounding_mode_f32: method.rounding_mode_f32,
rounding_mode_f16f64: method.rounding_mode_f16f64,
})
}
fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
statement: UnconditionalStatement,
) -> Result<(), TranslateError> {
let mut visitor = FlattenArguments::new(resolver, result);
let new_statement = statement.visit_map(&mut visitor)?;
visitor.result.push(new_statement);
Ok(())
}
struct FlattenArguments<'a, 'input> {
result: &'a mut Vec<ExpandedStatement>,
resolver: &'a mut GlobalStringIdentResolver2<'input>,
post_stmts: Vec<ExpandedStatement>,
}
impl<'a, 'input> FlattenArguments<'a, 'input> {
fn new(
resolver: &'a mut GlobalStringIdentResolver2<'input>,
result: &'a mut Vec<ExpandedStatement>,
) -> Self {
FlattenArguments {
result,
resolver,
post_stmts: Vec::new(),
}
}
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
Ok(name)
}
fn reg_offset(
&mut self,
reg: SpirvWord,
offset: i32,
type_space: Option<(&ast::Type, ast::StateSpace)>,
_is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
(type_, state_space)
} else {
return Err(TranslateError::UntypedSymbol);
};
if state_space == ast::StateSpace::Reg {
let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
if *reg_space != ast::StateSpace::Reg {
return Err(error_mismatched_type());
}
let reg_scalar_type = match reg_type {
ast::Type::Scalar(underlying_type) => *underlying_type,
_ => return Err(error_mismatched_type()),
};
let reg_type = reg_type.clone();
let id_constant_stmt = self
.resolver
.register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: reg_scalar_type,
value: ast::ImmediateValue::S64(offset as i64),
}));
let arith_details = match reg_scalar_type.kind() {
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
type_: reg_scalar_type,
saturate: false,
}),
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
ast::ArithDetails::Integer(ast::ArithInteger {
type_: reg_scalar_type,
saturate: false,
})
}
_ => return Err(error_unreachable()),
};
let id_add_result = self
.resolver
.register_unnamed(Some((reg_type, state_space)));
self.result
.push(Statement::Instruction(ast::Instruction::Add {
data: arith_details,
arguments: ast::AddArgs {
dst: id_add_result,
src1: reg,
src2: id_constant_stmt,
},
}));
Ok(id_add_result)
} else {
let id_constant_stmt = self.resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::S64),
ast::StateSpace::Reg,
)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::S64,
value: ast::ImmediateValue::S64(offset as i64),
}));
let dst = self
.resolver
.register_unnamed(Some((type_.clone(), state_space)));
self.result.push(Statement::PtrAccess(PtrAccess {
underlying_type: type_.clone(),
state_space: state_space,
dst,
ptr_src: reg,
offset_src: id_constant_stmt,
}));
Ok(dst)
}
}
fn immediate(
&mut self,
value: ast::ImmediateValue,
type_space: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
let (scalar_t, state_space) =
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
(*scalar, state_space)
} else {
return Err(TranslateError::UntypedSymbol);
};
let id = self
.resolver
.register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
value,
}));
Ok(id)
}
fn vec_member(
&mut self,
vector_ident: SpirvWord,
member: u8,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? {
(ast::Type::Vector(vector_width, scalar_t), space) => {
(*vector_width, *scalar_t, *space)
}
_ => return Err(error_mismatched_type()),
};
let temporary = self
.resolver
.register_unnamed(Some((scalar_type.into(), space)));
if is_dst {
self.post_stmts.push(Statement::VectorWrite(VectorWrite {
scalar_type,
vector_width,
vector_dst: vector_ident,
vector_src: vector_ident,
scalar_src: temporary,
member,
}));
} else {
self.result.push(Statement::VectorRead(VectorRead {
scalar_type,
vector_width,
scalar_dst: temporary,
vector_src: vector_ident,
member,
}));
}
Ok(temporary)
}
fn vec_pack(
&mut self,
vector_elements: Vec<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
let (width, scalar_t, state_space) = match type_space {
Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
_ => return Err(error_mismatched_type()),
};
let temporary_vector = self
.resolver
.register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
packed: temporary_vector,
unpacked: vector_elements,
relaxed_type_check,
});
if is_dst {
self.post_stmts.push(statement);
} else {
self.result.push(statement);
}
Ok(temporary_vector)
}
}
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
for FlattenArguments<'a, 'b>
{
fn visit(
&mut self,
args: ast::ParsedOperand<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
match args {
ast::ParsedOperand::Reg(r) => self.reg(r),
ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
ast::ParsedOperand::RegOffset(reg, offset) => {
self.reg_offset(reg, offset, type_space, is_dst)
}
ast::ParsedOperand::VecMember(vec, member) => {
self.vec_member(vec, member, type_space, is_dst)
}
ast::ParsedOperand::VecPack(vecs) => {
self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
}
}
}
fn visit_ident(
&mut self,
name: SpirvWord,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
_is_dst: bool,
_relaxed_type_check: bool,
) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
self.reg(name)
}
}
impl Drop for FlattenArguments<'_, '_> {
fn drop(&mut self) {
self.result.extend(self.post_stmts.drain(..));
}
}

View File

@ -1,208 +1,208 @@
use super::*;
pub(super) fn run<'a, 'input>(
resolver: &'a mut GlobalStringIdentResolver2<'input>,
special_registers: &'a SpecialRegistersMap2,
directives: Vec<UnconditionalDirective>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len());
let mut sreg_to_function =
FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default());
SpecialRegistersMap2::foreach_declaration(
resolver,
|sreg, (return_arguments, name, input_arguments)| {
result.push(UnconditionalDirective::Method(UnconditionalFunction {
return_arguments,
name,
input_arguments,
body: None,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
is_kernel: false,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
}));
sreg_to_function.insert(sreg, name);
},
);
let mut visitor = SpecialRegisterResolver {
resolver,
special_registers,
sreg_to_function,
result: Vec::new(),
};
for directive in directives.into_iter() {
result.push(run_directive(&mut visitor, directive)?);
}
Ok(result)
}
fn run_directive<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>,
directive: UnconditionalDirective,
) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
})
}
fn run_method<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>,
method: UnconditionalFunction,
) -> Result<UnconditionalFunction, TranslateError> {
let body = method
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(visitor, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 { body, ..method })
}
fn run_statement<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>,
result: &mut Vec<UnconditionalStatement>,
statement: UnconditionalStatement,
) -> Result<(), TranslateError> {
let converted_statement = statement.visit_map(visitor)?;
result.extend(visitor.result.drain(..));
result.push(converted_statement);
Ok(())
}
struct SpecialRegisterResolver<'a, 'input> {
resolver: &'a mut GlobalStringIdentResolver2<'input>,
special_registers: &'a SpecialRegistersMap2,
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
result: Vec<UnconditionalStatement>,
}
impl<'a, 'b, 'input>
ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
for SpecialRegisterResolver<'a, 'input>
{
fn visit(
&mut self,
operand: ast::ParsedOperand<SpirvWord>,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
map_operand(operand, &mut |ident, vector_index| {
self.replace_sreg(ident, vector_index, is_dst)
})
}
fn visit_ident(
&mut self,
args: SpirvWord,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args))
}
}
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
fn replace_sreg(
&mut self,
name: SpirvWord,
vector_index: Option<u8>,
is_dst: bool,
) -> Result<Option<SpirvWord>, TranslateError> {
if let Some(sreg) = self.special_registers.get(name) {
if is_dst {
return Err(error_mismatched_type());
}
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
(Some(idx), Some(inp_type)) => {
if inp_type != ast::ScalarType::U8 {
return Err(TranslateError::Unreachable);
}
let constant = self.resolver.register_unnamed(Some((
ast::Type::Scalar(inp_type),
ast::StateSpace::Reg,
)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: constant,
typ: inp_type,
value: ast::ImmediateValue::U64(idx as u64),
}));
vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
}
(None, None) => Vec::new(),
_ => return Err(error_mismatched_type()),
};
let return_type = sreg.get_function_return_type();
let fn_result = self
.resolver
.register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
let return_arguments = vec![(
fn_result,
ast::Type::Scalar(return_type),
ast::StateSpace::Reg,
)];
let data = ast::CallDetails {
uniform: false,
return_arguments: return_arguments
.iter()
.map(|(_, typ, space)| (typ.clone(), *space))
.collect(),
input_arguments: input_arguments
.iter()
.map(|(_, typ, space)| (typ.clone(), *space))
.collect(),
};
let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
func: self.sreg_to_function[&sreg],
input_arguments: input_arguments
.iter()
.map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
.collect(),
};
self.result
.push(Statement::Instruction(ast::Instruction::Call {
data,
arguments,
}));
Ok(Some(fn_result))
} else {
Ok(None)
}
}
}
pub fn map_operand<T: Copy, Err>(
this: ast::ParsedOperand<T>,
fn_: &mut impl FnMut(T, Option<u8>) -> Result<Option<T>, Err>,
) -> Result<ast::ParsedOperand<T>, Err> {
Ok(match this {
ast::ParsedOperand::Reg(ident) => {
ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident))
}
ast::ParsedOperand::RegOffset(ident, offset) => {
ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset)
}
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? {
Some(ident) => ast::ParsedOperand::Reg(ident),
None => ast::ParsedOperand::VecMember(ident, member),
},
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
idents
.into_iter()
.map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
.collect::<Result<Vec<_>, _>>()?,
),
})
}
use super::*;
pub(super) fn run<'a, 'input>(
resolver: &'a mut GlobalStringIdentResolver2<'input>,
special_registers: &'a SpecialRegistersMap2,
directives: Vec<UnconditionalDirective>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len());
let mut sreg_to_function =
FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default());
SpecialRegistersMap2::foreach_declaration(
resolver,
|sreg, (return_arguments, name, input_arguments)| {
result.push(UnconditionalDirective::Method(UnconditionalFunction {
return_arguments,
name,
input_arguments,
body: None,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
is_kernel: false,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
}));
sreg_to_function.insert(sreg, name);
},
);
let mut visitor = SpecialRegisterResolver {
resolver,
special_registers,
sreg_to_function,
result: Vec::new(),
};
for directive in directives.into_iter() {
result.push(run_directive(&mut visitor, directive)?);
}
Ok(result)
}
fn run_directive<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>,
directive: UnconditionalDirective,
) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
})
}
fn run_method<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>,
method: UnconditionalFunction,
) -> Result<UnconditionalFunction, TranslateError> {
let body = method
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(visitor, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 { body, ..method })
}
fn run_statement<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>,
result: &mut Vec<UnconditionalStatement>,
statement: UnconditionalStatement,
) -> Result<(), TranslateError> {
let converted_statement = statement.visit_map(visitor)?;
result.extend(visitor.result.drain(..));
result.push(converted_statement);
Ok(())
}
struct SpecialRegisterResolver<'a, 'input> {
resolver: &'a mut GlobalStringIdentResolver2<'input>,
special_registers: &'a SpecialRegistersMap2,
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
result: Vec<UnconditionalStatement>,
}
impl<'a, 'b, 'input>
ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
for SpecialRegisterResolver<'a, 'input>
{
fn visit(
&mut self,
operand: ast::ParsedOperand<SpirvWord>,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
map_operand(operand, &mut |ident, vector_index| {
self.replace_sreg(ident, vector_index, is_dst)
})
}
fn visit_ident(
&mut self,
args: SpirvWord,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args))
}
}
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
fn replace_sreg(
&mut self,
name: SpirvWord,
vector_index: Option<u8>,
is_dst: bool,
) -> Result<Option<SpirvWord>, TranslateError> {
if let Some(sreg) = self.special_registers.get(name) {
if is_dst {
return Err(error_mismatched_type());
}
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
(Some(idx), Some(inp_type)) => {
if inp_type != ast::ScalarType::U8 {
return Err(TranslateError::Unreachable);
}
let constant = self.resolver.register_unnamed(Some((
ast::Type::Scalar(inp_type),
ast::StateSpace::Reg,
)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: constant,
typ: inp_type,
value: ast::ImmediateValue::U64(idx as u64),
}));
vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
}
(None, None) => Vec::new(),
_ => return Err(error_mismatched_type()),
};
let return_type = sreg.get_function_return_type();
let fn_result = self
.resolver
.register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
let return_arguments = vec![(
fn_result,
ast::Type::Scalar(return_type),
ast::StateSpace::Reg,
)];
let data = ast::CallDetails {
uniform: false,
return_arguments: return_arguments
.iter()
.map(|(_, typ, space)| (typ.clone(), *space))
.collect(),
input_arguments: input_arguments
.iter()
.map(|(_, typ, space)| (typ.clone(), *space))
.collect(),
};
let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
func: self.sreg_to_function[&sreg],
input_arguments: input_arguments
.iter()
.map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
.collect(),
};
self.result
.push(Statement::Instruction(ast::Instruction::Call {
data,
arguments,
}));
Ok(Some(fn_result))
} else {
Ok(None)
}
}
}
pub fn map_operand<T: Copy, Err>(
this: ast::ParsedOperand<T>,
fn_: &mut impl FnMut(T, Option<u8>) -> Result<Option<T>, Err>,
) -> Result<ast::ParsedOperand<T>, Err> {
Ok(match this {
ast::ParsedOperand::Reg(ident) => {
ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident))
}
ast::ParsedOperand::RegOffset(ident, offset) => {
ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset)
}
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? {
Some(ident) => ast::ParsedOperand::Reg(ident),
None => ast::ParsedOperand::VecMember(ident, member),
},
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
idents
.into_iter()
.map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
.collect::<Result<Vec<_>, _>>()?,
),
})
}

View File

@ -1,45 +1,45 @@
use super::*;
pub(super) fn run<'input>(
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut result = Vec::with_capacity(directives.len());
for mut directive in directives.into_iter() {
run_directive(&mut result, &mut directive)?;
result.push(directive);
}
Ok(result)
}
fn run_directive<'input>(
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
directive: &mut Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
match directive {
Directive2::Variable(..) => {}
Directive2::Method(function2) => run_function(result, function2),
}
Ok(())
}
fn run_function<'input>(
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
function: &mut Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) {
function.body = function.body.take().map(|statements| {
statements
.into_iter()
.filter_map(|statement| match statement {
Statement::Variable(var @ ast::Variable {
state_space:
ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared,
..
}) => {
result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
None
}
s => Some(s),
})
.collect()
});
}
use super::*;
pub(super) fn run<'input>(
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut result = Vec::with_capacity(directives.len());
for mut directive in directives.into_iter() {
run_directive(&mut result, &mut directive)?;
result.push(directive);
}
Ok(result)
}
fn run_directive<'input>(
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
directive: &mut Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
match directive {
Directive2::Variable(..) => {}
Directive2::Method(function2) => run_function(result, function2),
}
Ok(())
}
fn run_function<'input>(
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
function: &mut Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) {
function.body = function.body.take().map(|statements| {
statements
.into_iter()
.filter_map(|statement| match statement {
Statement::Variable(var @ ast::Variable {
state_space:
ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared,
..
}) => {
result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
None
}
s => Some(s),
})
.collect()
});
}

View File

@ -1,404 +1,404 @@
use super::*;
// This pass:
// * Turns all .local, .param and .reg in-body variables into .local variables
// (if _not_ an input method argument)
// * Inserts explicit `ld`/`st` for newly converted .reg variables
// * Fixup state space of all existing `ld`/`st` instructions into newly
// converted variables
// * Turns `.entry` input arguments into param::entry and all related `.param`
// loads into `param::entry` loads
// * All `.func` input arguments are turned into `.reg` arguments by another
// pass, so we do nothing there
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => {
let visitor = InsertMemSSAVisitor::new(resolver);
Directive2::Method(run_method(visitor, method)?)
}
})
}
fn run_method<'a, 'input>(
mut visitor: InsertMemSSAVisitor<'a, 'input>,
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let is_kernel = method.is_kernel;
if is_kernel {
for arg in method.input_arguments.iter_mut() {
let old_name = arg.name;
let old_space = arg.state_space;
let new_space = ast::StateSpace::ParamEntry;
let new_name = visitor
.resolver
.register_unnamed(Some((arg.v_type.clone(), new_space)));
visitor.input_argument(old_name, new_name, old_space)?;
arg.name = new_name;
arg.state_space = new_space;
}
};
for arg in method.return_arguments.iter_mut() {
visitor.visit_variable(arg)?;
}
let return_arguments = &method.return_arguments[..];
let body = method
.body
.map(move |statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(&mut visitor, return_arguments, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 { body, ..method })
}
fn run_statement<'a, 'input>(
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
return_arguments: &[ast::Variable<SpirvWord>],
result: &mut Vec<ExpandedStatement>,
statement: ExpandedStatement,
) -> Result<(), TranslateError> {
match statement {
Statement::Instruction(ast::Instruction::Ret { data }) => {
let statement = if return_arguments.is_empty() {
Statement::Instruction(ast::Instruction::Ret { data })
} else {
Statement::RetValue(
data,
return_arguments
.iter()
.map(|arg| {
if arg.state_space != ast::StateSpace::Local {
return Err(error_unreachable());
}
Ok((arg.name, arg.v_type.clone()))
})
.collect::<Result<Vec<_>, _>>()?,
)
};
let new_statement = statement.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(new_statement);
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
Statement::Variable(mut var) => {
visitor.visit_variable(&mut var)?;
result.push(Statement::Variable(var));
}
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
let instruction = visitor.visit_ld(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
Statement::Instruction(ast::Instruction::St { data, arguments }) => {
let instruction = visitor.visit_st(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
Statement::PtrAccess(ptr_access) => {
let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
let statement = statement.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(statement);
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
s => {
let new_statement = s.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(new_statement);
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
}
Ok(())
}
struct InsertMemSSAVisitor<'a, 'input> {
resolver: &'a mut GlobalStringIdentResolver2<'input>,
variables: FxHashMap<SpirvWord, RemapAction>,
pre: Vec<ast::Instruction<SpirvWord>>,
post: Vec<ast::Instruction<SpirvWord>>,
}
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
Self {
resolver,
variables: FxHashMap::default(),
pre: Vec::new(),
post: Vec::new(),
}
}
fn input_argument(
&mut self,
old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
) -> Result<(), TranslateError> {
if old_space != ast::StateSpace::Param {
return Err(error_unreachable());
}
self.variables.insert(
old_name,
RemapAction::LDStSpaceChange {
name: new_name,
old_space,
new_space: ast::StateSpace::ParamEntry,
},
);
Ok(())
}
fn variable(
&mut self,
type_: &ast::Type,
old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
) -> Result<bool, TranslateError> {
Ok(match old_space {
ast::StateSpace::Reg => {
self.variables.insert(
old_name,
RemapAction::PreLdPostSt {
name: new_name,
type_: type_.clone(),
},
);
true
}
ast::StateSpace::Param => {
self.variables.insert(
old_name,
RemapAction::LDStSpaceChange {
old_space,
new_space: ast::StateSpace::Local,
name: new_name,
},
);
true
}
// Good as-is
ast::StateSpace::Local
| ast::StateSpace::Generic
| ast::StateSpace::SharedCluster
| ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::SharedCta
| ast::StateSpace::Shared
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc => return Err(error_unreachable()),
})
}
fn visit_st(
&self,
mut data: ast::StData,
mut arguments: ast::StArgs<SpirvWord>,
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src1) {
match remap {
RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange {
old_space,
new_space,
name,
} => {
if data.state_space != *old_space {
return Err(error_mismatched_type());
}
data.state_space = *new_space;
arguments.src1 = *name;
}
}
}
Ok(ast::Instruction::St { data, arguments })
}
fn visit_ld(
&self,
mut data: ast::LdDetails,
mut arguments: ast::LdArgs<SpirvWord>,
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src) {
match remap {
RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange {
old_space,
new_space,
name,
} => {
if data.state_space != *old_space {
return Err(error_mismatched_type());
}
data.state_space = *new_space;
arguments.src = *name;
}
}
}
Ok(ast::Instruction::Ld { data, arguments })
}
fn visit_ptr_access(
&mut self,
ptr_access: PtrAccess<SpirvWord>,
) -> Result<PtrAccess<SpirvWord>, TranslateError> {
let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) {
Some(RemapAction::LDStSpaceChange {
old_space,
new_space,
name,
}) => (*old_space, *new_space, *name),
Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access),
};
if ptr_access.state_space != old_space {
return Err(error_mismatched_type());
}
// Propagate space changes in dst
let new_dst = self
.resolver
.register_unnamed(Some((ptr_access.underlying_type.clone(), new_space)));
self.variables.insert(
ptr_access.dst,
RemapAction::LDStSpaceChange {
old_space,
new_space,
name: new_dst,
},
);
Ok(PtrAccess {
ptr_src: name,
dst: new_dst,
state_space: new_space,
..ptr_access
})
}
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
let old_space = match var.state_space {
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
// Do nothing
ptx_parser::StateSpace::Local => return Ok(()),
// Handled by another pass
ptx_parser::StateSpace::Generic
| ptx_parser::StateSpace::SharedCluster
| ptx_parser::StateSpace::ParamEntry
| ptx_parser::StateSpace::Global
| ptx_parser::StateSpace::SharedCta
| ptx_parser::StateSpace::Const
| ptx_parser::StateSpace::Shared
| ptx_parser::StateSpace::ParamFunc => return Ok(()),
};
let old_name = var.name;
let new_space = ast::StateSpace::Local;
let new_name = self
.resolver
.register_unnamed(Some((var.v_type.clone(), new_space)));
self.variable(&var.v_type, old_name, new_name, old_space)?;
var.name = new_name;
var.state_space = new_space;
Ok(())
}
}
impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
for InsertMemSSAVisitor<'a, 'input>
{
fn visit(
&mut self,
ident: SpirvWord,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
if let Some(remap) = self.variables.get(&ident) {
match remap {
RemapAction::PreLdPostSt { name, type_ } => {
if is_dst {
let temp = self
.resolver
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
self.post.push(ast::Instruction::St {
data: ast::StData {
state_space: ast::StateSpace::Local,
qualifier: ast::LdStQualifier::Weak,
caching: ast::StCacheOperator::Writethrough,
typ: type_.clone(),
},
arguments: ast::StArgs {
src1: *name,
src2: temp,
},
});
Ok(temp)
} else {
let temp = self
.resolver
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
self.pre.push(ast::Instruction::Ld {
data: ast::LdDetails {
state_space: ast::StateSpace::Local,
qualifier: ast::LdStQualifier::Weak,
caching: ast::LdCacheOperator::Cached,
typ: type_.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
dst: temp,
src: *name,
},
});
Ok(temp)
}
}
RemapAction::LDStSpaceChange { .. } => {
return Err(error_mismatched_type());
}
}
} else {
Ok(ident)
}
}
fn visit_ident(
&mut self,
args: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
self.visit(args, type_space, is_dst, relaxed_type_check)
}
}
#[derive(Clone)]
enum RemapAction {
PreLdPostSt {
name: SpirvWord,
type_: ast::Type,
},
LDStSpaceChange {
old_space: ast::StateSpace,
new_space: ast::StateSpace,
name: SpirvWord,
},
}
use super::*;
// This pass:
// * Turns all .local, .param and .reg in-body variables into .local variables
// (if _not_ an input method argument)
// * Inserts explicit `ld`/`st` for newly converted .reg variables
// * Fixup state space of all existing `ld`/`st` instructions into newly
// converted variables
// * Turns `.entry` input arguments into param::entry and all related `.param`
// loads into `param::entry` loads
// * All `.func` input arguments are turned into `.reg` arguments by another
// pass, so we do nothing there
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => {
let visitor = InsertMemSSAVisitor::new(resolver);
Directive2::Method(run_method(visitor, method)?)
}
})
}
fn run_method<'a, 'input>(
mut visitor: InsertMemSSAVisitor<'a, 'input>,
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let is_kernel = method.is_kernel;
if is_kernel {
for arg in method.input_arguments.iter_mut() {
let old_name = arg.name;
let old_space = arg.state_space;
let new_space = ast::StateSpace::ParamEntry;
let new_name = visitor
.resolver
.register_unnamed(Some((arg.v_type.clone(), new_space)));
visitor.input_argument(old_name, new_name, old_space)?;
arg.name = new_name;
arg.state_space = new_space;
}
};
for arg in method.return_arguments.iter_mut() {
visitor.visit_variable(arg)?;
}
let return_arguments = &method.return_arguments[..];
let body = method
.body
.map(move |statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(&mut visitor, return_arguments, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 { body, ..method })
}
fn run_statement<'a, 'input>(
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
return_arguments: &[ast::Variable<SpirvWord>],
result: &mut Vec<ExpandedStatement>,
statement: ExpandedStatement,
) -> Result<(), TranslateError> {
match statement {
Statement::Instruction(ast::Instruction::Ret { data }) => {
let statement = if return_arguments.is_empty() {
Statement::Instruction(ast::Instruction::Ret { data })
} else {
Statement::RetValue(
data,
return_arguments
.iter()
.map(|arg| {
if arg.state_space != ast::StateSpace::Local {
return Err(error_unreachable());
}
Ok((arg.name, arg.v_type.clone()))
})
.collect::<Result<Vec<_>, _>>()?,
)
};
let new_statement = statement.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(new_statement);
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
Statement::Variable(mut var) => {
visitor.visit_variable(&mut var)?;
result.push(Statement::Variable(var));
}
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
let instruction = visitor.visit_ld(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
Statement::Instruction(ast::Instruction::St { data, arguments }) => {
let instruction = visitor.visit_st(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
Statement::PtrAccess(ptr_access) => {
let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
let statement = statement.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(statement);
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
s => {
let new_statement = s.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(new_statement);
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
}
Ok(())
}
struct InsertMemSSAVisitor<'a, 'input> {
resolver: &'a mut GlobalStringIdentResolver2<'input>,
variables: FxHashMap<SpirvWord, RemapAction>,
pre: Vec<ast::Instruction<SpirvWord>>,
post: Vec<ast::Instruction<SpirvWord>>,
}
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
Self {
resolver,
variables: FxHashMap::default(),
pre: Vec::new(),
post: Vec::new(),
}
}
fn input_argument(
&mut self,
old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
) -> Result<(), TranslateError> {
if old_space != ast::StateSpace::Param {
return Err(error_unreachable());
}
self.variables.insert(
old_name,
RemapAction::LDStSpaceChange {
name: new_name,
old_space,
new_space: ast::StateSpace::ParamEntry,
},
);
Ok(())
}
fn variable(
&mut self,
type_: &ast::Type,
old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
) -> Result<bool, TranslateError> {
Ok(match old_space {
ast::StateSpace::Reg => {
self.variables.insert(
old_name,
RemapAction::PreLdPostSt {
name: new_name,
type_: type_.clone(),
},
);
true
}
ast::StateSpace::Param => {
self.variables.insert(
old_name,
RemapAction::LDStSpaceChange {
old_space,
new_space: ast::StateSpace::Local,
name: new_name,
},
);
true
}
// Good as-is
ast::StateSpace::Local
| ast::StateSpace::Generic
| ast::StateSpace::SharedCluster
| ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::SharedCta
| ast::StateSpace::Shared
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc => return Err(error_unreachable()),
})
}
fn visit_st(
&self,
mut data: ast::StData,
mut arguments: ast::StArgs<SpirvWord>,
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src1) {
match remap {
RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange {
old_space,
new_space,
name,
} => {
if data.state_space != *old_space {
return Err(error_mismatched_type());
}
data.state_space = *new_space;
arguments.src1 = *name;
}
}
}
Ok(ast::Instruction::St { data, arguments })
}
fn visit_ld(
&self,
mut data: ast::LdDetails,
mut arguments: ast::LdArgs<SpirvWord>,
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src) {
match remap {
RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange {
old_space,
new_space,
name,
} => {
if data.state_space != *old_space {
return Err(error_mismatched_type());
}
data.state_space = *new_space;
arguments.src = *name;
}
}
}
Ok(ast::Instruction::Ld { data, arguments })
}
fn visit_ptr_access(
&mut self,
ptr_access: PtrAccess<SpirvWord>,
) -> Result<PtrAccess<SpirvWord>, TranslateError> {
let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) {
Some(RemapAction::LDStSpaceChange {
old_space,
new_space,
name,
}) => (*old_space, *new_space, *name),
Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access),
};
if ptr_access.state_space != old_space {
return Err(error_mismatched_type());
}
// Propagate space changes in dst
let new_dst = self
.resolver
.register_unnamed(Some((ptr_access.underlying_type.clone(), new_space)));
self.variables.insert(
ptr_access.dst,
RemapAction::LDStSpaceChange {
old_space,
new_space,
name: new_dst,
},
);
Ok(PtrAccess {
ptr_src: name,
dst: new_dst,
state_space: new_space,
..ptr_access
})
}
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
let old_space = match var.state_space {
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
// Do nothing
ptx_parser::StateSpace::Local => return Ok(()),
// Handled by another pass
ptx_parser::StateSpace::Generic
| ptx_parser::StateSpace::SharedCluster
| ptx_parser::StateSpace::ParamEntry
| ptx_parser::StateSpace::Global
| ptx_parser::StateSpace::SharedCta
| ptx_parser::StateSpace::Const
| ptx_parser::StateSpace::Shared
| ptx_parser::StateSpace::ParamFunc => return Ok(()),
};
let old_name = var.name;
let new_space = ast::StateSpace::Local;
let new_name = self
.resolver
.register_unnamed(Some((var.v_type.clone(), new_space)));
self.variable(&var.v_type, old_name, new_name, old_space)?;
var.name = new_name;
var.state_space = new_space;
Ok(())
}
}
impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
for InsertMemSSAVisitor<'a, 'input>
{
fn visit(
&mut self,
ident: SpirvWord,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
if let Some(remap) = self.variables.get(&ident) {
match remap {
RemapAction::PreLdPostSt { name, type_ } => {
if is_dst {
let temp = self
.resolver
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
self.post.push(ast::Instruction::St {
data: ast::StData {
state_space: ast::StateSpace::Local,
qualifier: ast::LdStQualifier::Weak,
caching: ast::StCacheOperator::Writethrough,
typ: type_.clone(),
},
arguments: ast::StArgs {
src1: *name,
src2: temp,
},
});
Ok(temp)
} else {
let temp = self
.resolver
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
self.pre.push(ast::Instruction::Ld {
data: ast::LdDetails {
state_space: ast::StateSpace::Local,
qualifier: ast::LdStQualifier::Weak,
caching: ast::LdCacheOperator::Cached,
typ: type_.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
dst: temp,
src: *name,
},
});
Ok(temp)
}
}
RemapAction::LDStSpaceChange { .. } => {
return Err(error_mismatched_type());
}
}
} else {
Ok(ident)
}
}
fn visit_ident(
&mut self,
args: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
self.visit(args, type_space, is_dst, relaxed_type_check)
}
}
#[derive(Clone)]
enum RemapAction {
PreLdPostSt {
name: SpirvWord,
type_: ast::Type,
},
LDStSpaceChange {
old_space: ast::StateSpace,
new_space: ast::StateSpace,
name: SpirvWord,
},
}

View File

@ -1,401 +1,401 @@
use std::mem;
use super::*;
use ptx_parser as ast;
/*
There are several kinds of implicit conversions in PTX:
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
semantics are to first zext/chop/bitcast `y` as needed and then do
documented special ld/st/cvt conversion rules for destination operands
- st.param [x] y (used as function return arguments) same rule as above applies
- generic/global ld: for instruction `ld x, [y]`, y must be of type
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
documented special ld/st/cvt conversion rules are applied to dst
- generic/global st: for instruction `st [x], y`, x must be of type
b64/u64/s64, which is bitcast to a pointer
*/
pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(mut method) => {
method.body = method
.body
.map(|statements| run_statements(resolver, statements))
.transpose()?;
Directive2::Method(method)
}
})
}
fn run_statements<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
func: Vec<ExpandedStatement>,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
insert_implicit_conversions_impl(resolver, &mut result, s)?;
}
Ok(result)
}
fn insert_implicit_conversions_impl<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
func: &mut Vec<ExpandedStatement>,
stmt: ExpandedStatement,
) -> Result<(), TranslateError> {
let mut post_conv = Vec::new();
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
&mut |operand,
type_state: Option<(&ast::Type, ast::StateSpace)>,
is_dst,
relaxed_type_check| {
let (instr_type, instruction_space) = match type_state {
None => return Ok(operand),
Some(t) => t,
};
let (operand_type, operand_space) = resolver.get_typed(operand)?;
let conversion_fn = if relaxed_type_check {
if is_dst {
should_convert_relaxed_dst_wrapper
} else {
should_convert_relaxed_src_wrapper
}
} else {
default_implicit_conversion
};
match conversion_fn(
(*operand_space, &operand_type),
(instruction_space, instr_type),
)? {
Some(conv_kind) => {
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
let mut from_type = instr_type.clone();
let mut from_space = instruction_space;
let mut to_type = operand_type.clone();
let mut to_space = *operand_space;
let mut src =
resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
let mut dst = operand;
let result = Ok::<_, TranslateError>(src);
if !is_dst {
mem::swap(&mut src, &mut dst);
mem::swap(&mut from_type, &mut to_type);
mem::swap(&mut from_space, &mut to_space);
}
conv_output.push(Statement::Conversion(ImplicitConversion {
src,
dst,
from_type,
from_space,
to_type,
to_space,
kind: conv_kind,
}));
result
}
None => Ok(operand),
}
},
)?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
}
pub(crate) fn default_implicit_conversion(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if instruction_space == ast::StateSpace::Reg {
if operand_space == ast::StateSpace::Reg {
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
(operand_type, instruction_type)
{
if scalar.kind() == ast::ScalarKind::Bit
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
{
return Ok(Some(ConversionKind::Default));
}
}
} else if is_addressable(operand_space) {
return Ok(Some(ConversionKind::AddressOf));
}
}
if instruction_space != operand_space {
default_implicit_conversion_space((operand_space, operand_type), instruction_space)
} else if instruction_type != operand_type {
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
} else {
Ok(None)
}
}
fn is_addressable(this: ast::StateSpace) -> bool {
match this {
ast::StateSpace::Const
| ast::StateSpace::Generic
| ast::StateSpace::Global
| ast::StateSpace::Local
| ast::StateSpace::Shared => true,
ast::StateSpace::Param | ast::StateSpace::Reg => false,
ast::StateSpace::SharedCluster
| ast::StateSpace::SharedCta
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc => todo!(),
}
}
// Space is different
fn default_implicit_conversion_space(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
instruction_space: ast::StateSpace,
) -> Result<Option<ConversionKind>, TranslateError> {
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
{
Ok(Some(ConversionKind::PtrToPtr))
} else if operand_space == ast::StateSpace::Reg {
match operand_type {
// TODO: 32 bit
ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
ast::StateSpace::Global
| ast::StateSpace::Generic
| ast::StateSpace::Const
| ast::StateSpace::Local
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
_ => Err(error_mismatched_type()),
},
ast::Type::Scalar(ast::ScalarType::B32)
| ast::Type::Scalar(ast::ScalarType::U32)
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
Ok(Some(ConversionKind::BitToPtr))
}
_ => Err(error_mismatched_type()),
},
_ => Err(error_mismatched_type()),
}
} else {
Err(error_mismatched_type())
}
}
// Space is same, but type is different
fn default_implicit_conversion_type(
space: ast::StateSpace,
operand_type: &ast::Type,
instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> {
if space == ast::StateSpace::Reg {
if should_bitcast(instruction_type, operand_type) {
Ok(Some(ConversionKind::Default))
} else {
Err(TranslateError::MismatchedType)
}
} else {
Ok(Some(ConversionKind::PtrToPtr))
}
}
fn coerces_to_generic(this: ast::StateSpace) -> bool {
match this {
ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::Local
| ptx_parser::StateSpace::SharedCta
| ast::StateSpace::SharedCluster
| ast::StateSpace::Shared => true,
ast::StateSpace::Reg
| ast::StateSpace::Param
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc
| ast::StateSpace::Generic => false,
}
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
if inst.size_of() != operand.size_of() {
return false;
}
match inst.kind() {
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
ast::ScalarKind::Signed => {
operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Unsigned
}
ast::ScalarKind::Unsigned => {
operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Signed
}
ast::ScalarKind::Pred => false,
}
}
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
}
_ => false,
}
}
pub(crate) fn should_convert_relaxed_dst_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if operand_space != instruction_space {
return Err(TranslateError::MismatchedType);
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_dst(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: &ast::Type,
instr_type: &ast::Type,
) -> Option<ConversionKind> {
if dst_type == instr_type {
return None;
}
match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ast::ScalarKind::Bit => {
if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Signed => {
if dst_type.kind() != ast::ScalarKind::Float {
if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default)
} else if instr_type.size_of() < dst_type.size_of() {
Some(ConversionKind::SignExtend)
} else {
None
}
} else {
None
}
}
ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Float => {
if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
should_convert_relaxed_dst(
&ast::Type::Scalar(*dst_type),
&ast::Type::Scalar(*instr_type),
)
}
_ => None,
}
}
pub(crate) fn should_convert_relaxed_src_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if operand_space != instruction_space {
return Err(error_mismatched_type());
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(error_mismatched_type()),
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src(
src_type: &ast::Type,
instr_type: &ast::Type,
) -> Option<ConversionKind> {
if src_type == instr_type {
return None;
}
match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ast::ScalarKind::Bit => {
if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Float => {
if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
should_convert_relaxed_src(
&ast::Type::Scalar(*dst_type),
&ast::Type::Scalar(*instr_type),
)
}
_ => None,
}
}
use std::mem;
use super::*;
use ptx_parser as ast;
/*
There are several kinds of implicit conversions in PTX:
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
semantics are to first zext/chop/bitcast `y` as needed and then do
documented special ld/st/cvt conversion rules for destination operands
- st.param [x] y (used as function return arguments) same rule as above applies
- generic/global ld: for instruction `ld x, [y]`, y must be of type
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
documented special ld/st/cvt conversion rules are applied to dst
- generic/global st: for instruction `st [x], y`, x must be of type
b64/u64/s64, which is bitcast to a pointer
*/
pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(mut method) => {
method.body = method
.body
.map(|statements| run_statements(resolver, statements))
.transpose()?;
Directive2::Method(method)
}
})
}
fn run_statements<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
func: Vec<ExpandedStatement>,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
insert_implicit_conversions_impl(resolver, &mut result, s)?;
}
Ok(result)
}
fn insert_implicit_conversions_impl<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
func: &mut Vec<ExpandedStatement>,
stmt: ExpandedStatement,
) -> Result<(), TranslateError> {
let mut post_conv = Vec::new();
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
&mut |operand,
type_state: Option<(&ast::Type, ast::StateSpace)>,
is_dst,
relaxed_type_check| {
let (instr_type, instruction_space) = match type_state {
None => return Ok(operand),
Some(t) => t,
};
let (operand_type, operand_space) = resolver.get_typed(operand)?;
let conversion_fn = if relaxed_type_check {
if is_dst {
should_convert_relaxed_dst_wrapper
} else {
should_convert_relaxed_src_wrapper
}
} else {
default_implicit_conversion
};
match conversion_fn(
(*operand_space, &operand_type),
(instruction_space, instr_type),
)? {
Some(conv_kind) => {
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
let mut from_type = instr_type.clone();
let mut from_space = instruction_space;
let mut to_type = operand_type.clone();
let mut to_space = *operand_space;
let mut src =
resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
let mut dst = operand;
let result = Ok::<_, TranslateError>(src);
if !is_dst {
mem::swap(&mut src, &mut dst);
mem::swap(&mut from_type, &mut to_type);
mem::swap(&mut from_space, &mut to_space);
}
conv_output.push(Statement::Conversion(ImplicitConversion {
src,
dst,
from_type,
from_space,
to_type,
to_space,
kind: conv_kind,
}));
result
}
None => Ok(operand),
}
},
)?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
}
pub(crate) fn default_implicit_conversion(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if instruction_space == ast::StateSpace::Reg {
if operand_space == ast::StateSpace::Reg {
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
(operand_type, instruction_type)
{
if scalar.kind() == ast::ScalarKind::Bit
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
{
return Ok(Some(ConversionKind::Default));
}
}
} else if is_addressable(operand_space) {
return Ok(Some(ConversionKind::AddressOf));
}
}
if instruction_space != operand_space {
default_implicit_conversion_space((operand_space, operand_type), instruction_space)
} else if instruction_type != operand_type {
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
} else {
Ok(None)
}
}
fn is_addressable(this: ast::StateSpace) -> bool {
match this {
ast::StateSpace::Const
| ast::StateSpace::Generic
| ast::StateSpace::Global
| ast::StateSpace::Local
| ast::StateSpace::Shared => true,
ast::StateSpace::Param | ast::StateSpace::Reg => false,
ast::StateSpace::SharedCluster
| ast::StateSpace::SharedCta
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc => todo!(),
}
}
// Space is different
fn default_implicit_conversion_space(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
instruction_space: ast::StateSpace,
) -> Result<Option<ConversionKind>, TranslateError> {
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
{
Ok(Some(ConversionKind::PtrToPtr))
} else if operand_space == ast::StateSpace::Reg {
match operand_type {
// TODO: 32 bit
ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
ast::StateSpace::Global
| ast::StateSpace::Generic
| ast::StateSpace::Const
| ast::StateSpace::Local
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
_ => Err(error_mismatched_type()),
},
ast::Type::Scalar(ast::ScalarType::B32)
| ast::Type::Scalar(ast::ScalarType::U32)
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
Ok(Some(ConversionKind::BitToPtr))
}
_ => Err(error_mismatched_type()),
},
_ => Err(error_mismatched_type()),
}
} else {
Err(error_mismatched_type())
}
}
// Space is same, but type is different
fn default_implicit_conversion_type(
space: ast::StateSpace,
operand_type: &ast::Type,
instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> {
if space == ast::StateSpace::Reg {
if should_bitcast(instruction_type, operand_type) {
Ok(Some(ConversionKind::Default))
} else {
Err(TranslateError::MismatchedType)
}
} else {
Ok(Some(ConversionKind::PtrToPtr))
}
}
fn coerces_to_generic(this: ast::StateSpace) -> bool {
match this {
ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::Local
| ptx_parser::StateSpace::SharedCta
| ast::StateSpace::SharedCluster
| ast::StateSpace::Shared => true,
ast::StateSpace::Reg
| ast::StateSpace::Param
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc
| ast::StateSpace::Generic => false,
}
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
if inst.size_of() != operand.size_of() {
return false;
}
match inst.kind() {
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
ast::ScalarKind::Signed => {
operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Unsigned
}
ast::ScalarKind::Unsigned => {
operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Signed
}
ast::ScalarKind::Pred => false,
}
}
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
}
_ => false,
}
}
pub(crate) fn should_convert_relaxed_dst_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if operand_space != instruction_space {
return Err(TranslateError::MismatchedType);
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_dst(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: &ast::Type,
instr_type: &ast::Type,
) -> Option<ConversionKind> {
if dst_type == instr_type {
return None;
}
match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ast::ScalarKind::Bit => {
if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Signed => {
if dst_type.kind() != ast::ScalarKind::Float {
if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default)
} else if instr_type.size_of() < dst_type.size_of() {
Some(ConversionKind::SignExtend)
} else {
None
}
} else {
None
}
}
ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Float => {
if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
should_convert_relaxed_dst(
&ast::Type::Scalar(*dst_type),
&ast::Type::Scalar(*instr_type),
)
}
_ => None,
}
}
pub(crate) fn should_convert_relaxed_src_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if operand_space != instruction_space {
return Err(error_mismatched_type());
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(error_mismatched_type()),
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src(
src_type: &ast::Type,
instr_type: &ast::Type,
) -> Option<ConversionKind> {
if src_type == instr_type {
return None;
}
match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ast::ScalarKind::Bit => {
if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Float => {
if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
should_convert_relaxed_src(
&ast::Type::Scalar(*dst_type),
&ast::Type::Scalar(*instr_type),
)
}
_ => None,
}
}

View File

@ -1,194 +1,194 @@
use super::*;
use ptx_parser as ast;
pub(crate) fn run<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedDirective2>, TranslateError> {
resolver.start_scope();
let result = directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()?;
resolver.end_scope();
Ok(result)
}
fn run_directive<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<NormalizedDirective2, TranslateError> {
Ok(match directive {
ast::Directive::Variable(linking, var) => {
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
}
ast::Directive::Method(linking, directive) => {
NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
}
})
}
fn run_method<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
linkage: ast::LinkingDirective,
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<NormalizedFunction2, TranslateError> {
let is_kernel = method.func_directive.name.is_kernel();
let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?;
resolver.start_scope();
let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?;
let body = method
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
run_statements(resolver, &mut result, statements)?;
Ok::<_, TranslateError>(result)
})
.transpose()?;
resolver.end_scope();
Ok(Function2 {
return_arguments,
name,
input_arguments,
body,
import_as: None,
linkage,
is_kernel,
tuning: method.tuning,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
})
}
fn run_function_decl<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
func_directive: ast::MethodDeclaration<'input, &'input str>,
) -> Result<(Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
assert!(func_directive.shared_mem.is_none());
let return_arguments = func_directive
.return_arguments
.into_iter()
.map(|var| run_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?;
let input_arguments = func_directive
.input_arguments
.into_iter()
.map(|var| run_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?;
Ok((return_arguments, input_arguments))
}
fn run_variable<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
variable: ast::Variable<&'input str>,
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
Ok(ast::Variable {
name: resolver.add(
Cow::Borrowed(variable.name),
Some((variable.v_type.clone(), variable.state_space)),
)?,
align: variable.align,
v_type: variable.v_type,
state_space: variable.state_space,
array_init: variable.array_init,
})
}
fn run_statements<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<(), TranslateError> {
for statement in statements.iter() {
match statement {
ast::Statement::Label(label) => {
resolver.add(Cow::Borrowed(*label), None)?;
}
_ => {}
}
}
for statement in statements {
match statement {
ast::Statement::Label(label) => {
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
}
ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
ast::Statement::Instruction(predicate, instruction) => {
result.push(Statement::Instruction((
predicate
.map(|pred| {
Ok::<_, TranslateError>(ast::PredAt {
not: pred.not,
label: resolver.get(pred.label)?,
})
})
.transpose()?,
run_instruction(resolver, instruction)?,
)))
}
ast::Statement::Block(block) => {
resolver.start_scope();
run_statements(resolver, result, block)?;
resolver.end_scope();
}
}
}
Ok(())
}
fn run_instruction<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
ast::visit_map(instruction, &mut |name: &'input str,
_: Option<(
&ast::Type,
ast::StateSpace,
)>,
_,
_| {
resolver.get(&name)
})
}
fn run_multivariable<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
variable: ast::MultiVariable<&'input str>,
) -> Result<(), TranslateError> {
match variable.count {
Some(count) => {
for i in 0..count {
let name = Cow::Owned(format!("{}{}", variable.var.name, i));
let ident = resolver.add(
name,
Some((variable.var.v_type.clone(), variable.var.state_space)),
)?;
result.push(Statement::Variable(ast::Variable {
align: variable.var.align,
v_type: variable.var.v_type.clone(),
state_space: variable.var.state_space,
name: ident,
array_init: variable.var.array_init.clone(),
}));
}
}
None => {
let name = Cow::Borrowed(variable.var.name);
let ident = resolver.add(
name,
Some((variable.var.v_type.clone(), variable.var.state_space)),
)?;
result.push(Statement::Variable(ast::Variable {
align: variable.var.align,
v_type: variable.var.v_type.clone(),
state_space: variable.var.state_space,
name: ident,
array_init: variable.var.array_init.clone(),
}));
}
}
Ok(())
}
use super::*;
use ptx_parser as ast;
pub(crate) fn run<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedDirective2>, TranslateError> {
resolver.start_scope();
let result = directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()?;
resolver.end_scope();
Ok(result)
}
fn run_directive<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<NormalizedDirective2, TranslateError> {
Ok(match directive {
ast::Directive::Variable(linking, var) => {
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
}
ast::Directive::Method(linking, directive) => {
NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
}
})
}
fn run_method<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
linkage: ast::LinkingDirective,
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<NormalizedFunction2, TranslateError> {
let is_kernel = method.func_directive.name.is_kernel();
let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?;
resolver.start_scope();
let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?;
let body = method
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
run_statements(resolver, &mut result, statements)?;
Ok::<_, TranslateError>(result)
})
.transpose()?;
resolver.end_scope();
Ok(Function2 {
return_arguments,
name,
input_arguments,
body,
import_as: None,
linkage,
is_kernel,
tuning: method.tuning,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
})
}
fn run_function_decl<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
func_directive: ast::MethodDeclaration<'input, &'input str>,
) -> Result<(Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
assert!(func_directive.shared_mem.is_none());
let return_arguments = func_directive
.return_arguments
.into_iter()
.map(|var| run_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?;
let input_arguments = func_directive
.input_arguments
.into_iter()
.map(|var| run_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?;
Ok((return_arguments, input_arguments))
}
fn run_variable<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
variable: ast::Variable<&'input str>,
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
Ok(ast::Variable {
name: resolver.add(
Cow::Borrowed(variable.name),
Some((variable.v_type.clone(), variable.state_space)),
)?,
align: variable.align,
v_type: variable.v_type,
state_space: variable.state_space,
array_init: variable.array_init,
})
}
fn run_statements<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<(), TranslateError> {
for statement in statements.iter() {
match statement {
ast::Statement::Label(label) => {
resolver.add(Cow::Borrowed(*label), None)?;
}
_ => {}
}
}
for statement in statements {
match statement {
ast::Statement::Label(label) => {
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
}
ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
ast::Statement::Instruction(predicate, instruction) => {
result.push(Statement::Instruction((
predicate
.map(|pred| {
Ok::<_, TranslateError>(ast::PredAt {
not: pred.not,
label: resolver.get(pred.label)?,
})
})
.transpose()?,
run_instruction(resolver, instruction)?,
)))
}
ast::Statement::Block(block) => {
resolver.start_scope();
run_statements(resolver, result, block)?;
resolver.end_scope();
}
}
}
Ok(())
}
fn run_instruction<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
ast::visit_map(instruction, &mut |name: &'input str,
_: Option<(
&ast::Type,
ast::StateSpace,
)>,
_,
_| {
resolver.get(&name)
})
}
fn run_multivariable<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
variable: ast::MultiVariable<&'input str>,
) -> Result<(), TranslateError> {
match variable.count {
Some(count) => {
for i in 0..count {
let name = Cow::Owned(format!("{}{}", variable.var.name, i));
let ident = resolver.add(
name,
Some((variable.var.v_type.clone(), variable.var.state_space)),
)?;
result.push(Statement::Variable(ast::Variable {
align: variable.var.align,
v_type: variable.var.v_type.clone(),
state_space: variable.var.state_space,
name: ident,
array_init: variable.var.array_init.clone(),
}));
}
}
None => {
let name = Cow::Borrowed(variable.var.name);
let ident = resolver.add(
name,
Some((variable.var.v_type.clone(), variable.var.state_space)),
)?;
result.push(Statement::Variable(ast::Variable {
align: variable.var.align,
v_type: variable.var.v_type.clone(),
state_space: variable.var.state_space,
name: ident,
array_init: variable.var.array_init.clone(),
}));
}
}
Ok(())
}

View File

@ -1,90 +1,90 @@
use super::*;
use ptx_parser as ast;
pub(crate) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<NormalizedDirective2>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: NormalizedDirective2,
) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
})
}
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
method: NormalizedFunction2,
) -> Result<UnconditionalFunction, TranslateError> {
let body = method
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(resolver, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 {
body,
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
flush_to_zero_f32: method.flush_to_zero_f32,
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
rounding_mode_f32: method.rounding_mode_f32,
rounding_mode_f16f64: method.rounding_mode_f16f64,
})
}
fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
result: &mut Vec<UnconditionalStatement>,
statement: NormalizedStatement,
) -> Result<(), TranslateError> {
Ok(match statement {
Statement::Label(label) => result.push(Statement::Label(label)),
Statement::Variable(var) => result.push(Statement::Variable(var)),
Statement::Instruction((predicate, instruction)) => {
if let Some(pred) = predicate {
let if_true = resolver.register_unnamed(None);
let if_false = resolver.register_unnamed(None);
let folded_bra = match &instruction {
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
_ => None,
};
let mut branch = BrachCondition {
predicate: pred.label,
if_true: folded_bra.unwrap_or(if_true),
if_false,
};
if pred.not {
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
}
result.push(Statement::Conditional(branch));
if folded_bra.is_none() {
result.push(Statement::Label(if_true));
result.push(Statement::Instruction(instruction));
}
result.push(Statement::Label(if_false));
} else {
result.push(Statement::Instruction(instruction));
}
}
_ => return Err(error_unreachable()),
})
}
use super::*;
use ptx_parser as ast;
pub(crate) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<NormalizedDirective2>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: NormalizedDirective2,
) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
})
}
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
method: NormalizedFunction2,
) -> Result<UnconditionalFunction, TranslateError> {
let body = method
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(resolver, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 {
body,
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
flush_to_zero_f32: method.flush_to_zero_f32,
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
rounding_mode_f32: method.rounding_mode_f32,
rounding_mode_f16f64: method.rounding_mode_f16f64,
})
}
fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
result: &mut Vec<UnconditionalStatement>,
statement: NormalizedStatement,
) -> Result<(), TranslateError> {
Ok(match statement {
Statement::Label(label) => result.push(Statement::Label(label)),
Statement::Variable(var) => result.push(Statement::Variable(var)),
Statement::Instruction((predicate, instruction)) => {
if let Some(pred) = predicate {
let if_true = resolver.register_unnamed(None);
let if_false = resolver.register_unnamed(None);
let folded_bra = match &instruction {
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
_ => None,
};
let mut branch = BrachCondition {
predicate: pred.label,
if_true: folded_bra.unwrap_or(if_true),
if_false,
};
if pred.not {
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
}
result.push(Statement::Conditional(branch));
if folded_bra.is_none() {
result.push(Statement::Label(if_true));
result.push(Statement::Instruction(instruction));
}
result.push(Statement::Label(if_false));
} else {
result.push(Statement::Instruction(instruction));
}
}
_ => return Err(error_unreachable()),
})
}

View File

@ -1,268 +1,268 @@
use super::*;
pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut fn_declarations = FxHashMap::default();
let remapped_directives = directives
.into_iter()
.map(|directive| run_directive(resolver, &mut fn_declarations, directive))
.collect::<Result<Vec<_>, _>>()?;
let mut result = fn_declarations
.into_iter()
.map(|(_, (return_arguments, name, input_arguments))| {
Directive2::Method(Function2 {
return_arguments,
name: name,
input_arguments,
body: None,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
is_kernel: false,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
})
})
.collect::<Vec<_>>();
result.extend(remapped_directives);
Ok(result)
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
Cow<'input, str>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
>,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(mut method) => {
method.body = method
.body
.map(|statements| run_statements(resolver, fn_declarations, statements))
.transpose()?;
Directive2::Method(method)
}
})
}
fn run_statements<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
Cow<'input, str>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
>,
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
statements
.into_iter()
.map(|statement| {
Ok(match statement {
Statement::Instruction(instruction) => {
Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
}
s => s,
})
})
.collect::<Result<Vec<_>, _>>()
}
fn run_instruction<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
Cow<'input, str>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
>,
instruction: ptx_parser::Instruction<SpirvWord>,
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
Ok(match instruction {
i @ ptx_parser::Instruction::Sqrt {
data:
ast::RcpData {
kind: ast::RcpKind::Approx,
type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false),
},
..
} => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Rsqrt {
data:
ast::TypeFtz {
type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false),
},
..
} => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Rcp {
data:
ast::RcpData {
kind: ast::RcpKind::Approx,
type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false),
},
..
} => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Ex2 {
data:
ast::TypeFtz {
type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false),
},
..
} => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Lg2 {
data: ast::FlushToZero {
flush_to_zero: false,
},
..
} => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Activemask { .. } => {
to_call(resolver, fn_declarations, "activemask".into(), i)?
}
i @ ptx_parser::Instruction::Bfe { data, .. } => {
let name = ["bfe_", scalar_to_ptx_name(data)].concat();
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Bfi { data, .. } => {
let name = ["bfi_", scalar_to_ptx_name(data)].concat();
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Bar { .. } => {
to_call(resolver, fn_declarations, "bar_sync".into(), i)?
}
ptx_parser::Instruction::BarRed { data, arguments } => {
if arguments.src_threadcount.is_some() {
return Err(error_todo());
}
let name = match data.pred_reduction {
ptx_parser::Reduction::And => "bar_red_and_pred",
ptx_parser::Reduction::Or => "bar_red_or_pred",
};
to_call(
resolver,
fn_declarations,
name.into(),
ptx_parser::Instruction::BarRed { data, arguments },
)?
}
ptx_parser::Instruction::ShflSync { data, arguments } => {
let mode = match data.mode {
ptx_parser::ShuffleMode::Up => "up",
ptx_parser::ShuffleMode::Down => "down",
ptx_parser::ShuffleMode::BFly => "bfly",
ptx_parser::ShuffleMode::Idx => "idx",
};
let pred = if arguments.dst_pred.is_some() {
"_pred"
} else {
""
};
to_call(
resolver,
fn_declarations,
format!("shfl_sync_{}_b32{}", mode, pred).into(),
ptx_parser::Instruction::ShflSync { data, arguments },
)?
}
i @ ptx_parser::Instruction::Nanosleep { .. } => {
to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)?
}
i => i,
})
}
fn to_call<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
Cow<'input, str>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
>,
name: Cow<'input, str>,
i: ast::Instruction<SpirvWord>,
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
let mut data_return = Vec::new();
let mut data_input = Vec::new();
let mut arguments_return = Vec::new();
let mut arguments_input = Vec::new();
ast::visit(&i, &mut |name: &SpirvWord,
type_space: Option<(
&ptx_parser::Type,
ptx_parser::StateSpace,
)>,
is_dst: bool,
_: bool| {
let (type_, space) = type_space.ok_or_else(error_mismatched_type)?;
if is_dst {
data_return.push((type_.clone(), space));
arguments_return.push(*name);
} else {
data_input.push((type_.clone(), space));
arguments_input.push(*name);
};
Ok::<_, TranslateError>(())
})?;
let fn_name = match fn_declarations.entry(name) {
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
hash_map::Entry::Vacant(vacant_entry) => {
let name = vacant_entry.key().clone();
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
vacant_entry.insert((
to_variables(resolver, &data_return),
name,
to_variables(resolver, &data_input),
));
name
}
};
Ok(ast::Instruction::Call {
data: ptx_parser::CallDetails {
uniform: false,
return_arguments: data_return,
input_arguments: data_input,
},
arguments: ptx_parser::CallArgs {
return_arguments: arguments_return,
func: fn_name,
input_arguments: arguments_input,
},
})
}
fn to_variables<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
) -> Vec<ptx_parser::Variable<SpirvWord>> {
arguments
.iter()
.map(|(type_, space)| ast::Variable {
align: None,
v_type: type_.clone(),
state_space: *space,
name: resolver.register_unnamed(Some((type_.clone(), *space))),
array_init: Vec::new(),
})
.collect::<Vec<_>>()
}
use super::*;
pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut fn_declarations = FxHashMap::default();
let remapped_directives = directives
.into_iter()
.map(|directive| run_directive(resolver, &mut fn_declarations, directive))
.collect::<Result<Vec<_>, _>>()?;
let mut result = fn_declarations
.into_iter()
.map(|(_, (return_arguments, name, input_arguments))| {
Directive2::Method(Function2 {
return_arguments,
name: name,
input_arguments,
body: None,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
is_kernel: false,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
})
})
.collect::<Vec<_>>();
result.extend(remapped_directives);
Ok(result)
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
Cow<'input, str>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
>,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(mut method) => {
method.body = method
.body
.map(|statements| run_statements(resolver, fn_declarations, statements))
.transpose()?;
Directive2::Method(method)
}
})
}
fn run_statements<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
Cow<'input, str>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
>,
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
statements
.into_iter()
.map(|statement| {
Ok(match statement {
Statement::Instruction(instruction) => {
Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
}
s => s,
})
})
.collect::<Result<Vec<_>, _>>()
}
fn run_instruction<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
Cow<'input, str>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
>,
instruction: ptx_parser::Instruction<SpirvWord>,
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
Ok(match instruction {
i @ ptx_parser::Instruction::Sqrt {
data:
ast::RcpData {
kind: ast::RcpKind::Approx,
type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false),
},
..
} => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Rsqrt {
data:
ast::TypeFtz {
type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false),
},
..
} => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Rcp {
data:
ast::RcpData {
kind: ast::RcpKind::Approx,
type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false),
},
..
} => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Ex2 {
data:
ast::TypeFtz {
type_: ast::ScalarType::F32,
flush_to_zero: None | Some(false),
},
..
} => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Lg2 {
data: ast::FlushToZero {
flush_to_zero: false,
},
..
} => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?,
i @ ptx_parser::Instruction::Activemask { .. } => {
to_call(resolver, fn_declarations, "activemask".into(), i)?
}
i @ ptx_parser::Instruction::Bfe { data, .. } => {
let name = ["bfe_", scalar_to_ptx_name(data)].concat();
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Bfi { data, .. } => {
let name = ["bfi_", scalar_to_ptx_name(data)].concat();
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Bar { .. } => {
to_call(resolver, fn_declarations, "bar_sync".into(), i)?
}
ptx_parser::Instruction::BarRed { data, arguments } => {
if arguments.src_threadcount.is_some() {
return Err(error_todo());
}
let name = match data.pred_reduction {
ptx_parser::Reduction::And => "bar_red_and_pred",
ptx_parser::Reduction::Or => "bar_red_or_pred",
};
to_call(
resolver,
fn_declarations,
name.into(),
ptx_parser::Instruction::BarRed { data, arguments },
)?
}
ptx_parser::Instruction::ShflSync { data, arguments } => {
let mode = match data.mode {
ptx_parser::ShuffleMode::Up => "up",
ptx_parser::ShuffleMode::Down => "down",
ptx_parser::ShuffleMode::BFly => "bfly",
ptx_parser::ShuffleMode::Idx => "idx",
};
let pred = if arguments.dst_pred.is_some() {
"_pred"
} else {
""
};
to_call(
resolver,
fn_declarations,
format!("shfl_sync_{}_b32{}", mode, pred).into(),
ptx_parser::Instruction::ShflSync { data, arguments },
)?
}
i @ ptx_parser::Instruction::Nanosleep { .. } => {
to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)?
}
i => i,
})
}
fn to_call<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
Cow<'input, str>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
>,
name: Cow<'input, str>,
i: ast::Instruction<SpirvWord>,
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
let mut data_return = Vec::new();
let mut data_input = Vec::new();
let mut arguments_return = Vec::new();
let mut arguments_input = Vec::new();
ast::visit(&i, &mut |name: &SpirvWord,
type_space: Option<(
&ptx_parser::Type,
ptx_parser::StateSpace,
)>,
is_dst: bool,
_: bool| {
let (type_, space) = type_space.ok_or_else(error_mismatched_type)?;
if is_dst {
data_return.push((type_.clone(), space));
arguments_return.push(*name);
} else {
data_input.push((type_.clone(), space));
arguments_input.push(*name);
};
Ok::<_, TranslateError>(())
})?;
let fn_name = match fn_declarations.entry(name) {
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
hash_map::Entry::Vacant(vacant_entry) => {
let name = vacant_entry.key().clone();
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
vacant_entry.insert((
to_variables(resolver, &data_return),
name,
to_variables(resolver, &data_input),
));
name
}
};
Ok(ast::Instruction::Call {
data: ptx_parser::CallDetails {
uniform: false,
return_arguments: data_return,
input_arguments: data_input,
},
arguments: ptx_parser::CallArgs {
return_arguments: arguments_return,
func: fn_name,
input_arguments: arguments_input,
},
})
}
fn to_variables<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
) -> Vec<ptx_parser::Variable<SpirvWord>> {
arguments
.iter()
.map(|(type_, space)| ast::Variable {
align: None,
v_type: type_.clone(),
state_space: *space,
name: resolver.register_unnamed(Some((type_.clone(), *space))),
array_init: Vec::new(),
})
.collect::<Vec<_>>()
}

View File

@ -1,33 +1,33 @@
use std::borrow::Cow;
use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
pub(crate) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
mut directives: Vec<NormalizedDirective2>,
) -> Vec<NormalizedDirective2> {
for directive in directives.iter_mut() {
match directive {
NormalizedDirective2::Method(func) => {
replace_with_ptx_impl(resolver, func.name);
}
_ => {}
}
}
directives
}
fn replace_with_ptx_impl<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_name: SpirvWord,
) {
let known_names = ["__assertfail"];
if let Some(super::IdentEntry {
name: Some(name), ..
}) = resolver.ident_map.get_mut(&fn_name)
{
if known_names.contains(&&**name) {
*name = Cow::Owned(format!("__zluda_ptx_impl_{}", name));
}
}
}
use std::borrow::Cow;
use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
pub(crate) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
mut directives: Vec<NormalizedDirective2>,
) -> Vec<NormalizedDirective2> {
for directive in directives.iter_mut() {
match directive {
NormalizedDirective2::Method(func) => {
replace_with_ptx_impl(resolver, func.name);
}
_ => {}
}
}
directives
}
fn replace_with_ptx_impl<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_name: SpirvWord,
) {
let known_names = ["__assertfail"];
if let Some(super::IdentEntry {
name: Some(name), ..
}) = resolver.ident_map.get_mut(&fn_name)
{
if known_names.contains(&&**name) {
*name = Cow::Owned(format!("__zluda_ptx_impl_{}", name));
}
}
}

View File

@ -1,69 +1,69 @@
use super::*;
use ptx_parser as ast;
use rustc_hash::FxHashSet;
pub(crate) fn run<'input>(
directives: Vec<UnconditionalDirective>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
let mut functions = FxHashSet::default();
directives
.into_iter()
.map(|directive| run_directive(&mut functions, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
functions: &mut FxHashSet<SpirvWord>,
directive: UnconditionalDirective,
) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => {
if !method.is_kernel {
functions.insert(method.name);
}
Directive2::Method(run_method(functions, method)?)
}
})
}
fn run_method<'input>(
functions: &mut FxHashSet<SpirvWord>,
method: UnconditionalFunction,
) -> Result<UnconditionalFunction, TranslateError> {
let body = method
.body
.map(|statements| {
statements
.into_iter()
.map(|statement| run_statement(functions, statement))
.collect::<Result<Vec<_>, _>>()
})
.transpose()?;
Ok(Function2 { body, ..method })
}
fn run_statement<'input>(
functions: &mut FxHashSet<SpirvWord>,
statement: UnconditionalStatement,
) -> Result<UnconditionalStatement, TranslateError> {
Ok(match statement {
Statement::Instruction(ast::Instruction::Mov {
data,
arguments:
ast::MovArgs {
dst: ast::ParsedOperand::Reg(dst_reg),
src: ast::ParsedOperand::Reg(src_reg),
},
}) if functions.contains(&src_reg) => {
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
return Err(error_mismatched_type());
}
UnconditionalStatement::FunctionPointer(FunctionPointerDetails {
dst: dst_reg,
src: src_reg,
})
}
s => s,
})
}
use super::*;
use ptx_parser as ast;
use rustc_hash::FxHashSet;
pub(crate) fn run<'input>(
directives: Vec<UnconditionalDirective>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
let mut functions = FxHashSet::default();
directives
.into_iter()
.map(|directive| run_directive(&mut functions, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
functions: &mut FxHashSet<SpirvWord>,
directive: UnconditionalDirective,
) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => {
if !method.is_kernel {
functions.insert(method.name);
}
Directive2::Method(run_method(functions, method)?)
}
})
}
fn run_method<'input>(
functions: &mut FxHashSet<SpirvWord>,
method: UnconditionalFunction,
) -> Result<UnconditionalFunction, TranslateError> {
let body = method
.body
.map(|statements| {
statements
.into_iter()
.map(|statement| run_statement(functions, statement))
.collect::<Result<Vec<_>, _>>()
})
.transpose()?;
Ok(Function2 { body, ..method })
}
fn run_statement<'input>(
functions: &mut FxHashSet<SpirvWord>,
statement: UnconditionalStatement,
) -> Result<UnconditionalStatement, TranslateError> {
Ok(match statement {
Statement::Instruction(ast::Instruction::Mov {
data,
arguments:
ast::MovArgs {
dst: ast::ParsedOperand::Reg(dst_reg),
src: ast::ParsedOperand::Reg(src_reg),
},
}) if functions.contains(&src_reg) => {
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
return Err(error_mismatched_type());
}
UnconditionalStatement::FunctionPointer(FunctionPointerDetails {
dst: dst_reg,
src: src_reg,
})
}
s => s,
})
}

View File

@ -1,327 +1,327 @@
use bpaf::{Args, Bpaf, Parser};
use cargo_metadata::{MetadataCommand, Package};
use serde::Deserialize;
use std::{env, ffi::OsString, path::PathBuf, process::Command};
#[derive(Debug, Clone, Bpaf)]
#[bpaf(options)]
enum Options {
#[bpaf(command)]
/// Compile ZLUDA (default command)
Build(#[bpaf(external(build))] Build),
#[bpaf(command)]
/// Compile ZLUDA and build a package
Zip(#[bpaf(external(build))] Build),
}
#[derive(Debug, Clone, Bpaf)]
struct Build {
#[bpaf(any("CARGO", not_help), many)]
/// Arguments to pass to cargo, e.g. `--release` for release build
cargo_arguments: Vec<OsString>,
}
fn not_help(s: OsString) -> Option<OsString> {
if s == "-h" || s == "--help" {
None
} else {
Some(s)
}
}
// We need to sniff out some args passed to cargo to understand how to create
// symlinks (should they go into `target/debug`, `target/release` or custom)
#[derive(Debug, Clone, Bpaf)]
struct Cargo {
#[bpaf(switch, long, short)]
release: Option<bool>,
#[bpaf(long)]
profile: Option<String>,
#[bpaf(any("", Some), many)]
_unused: Vec<OsString>,
}
struct Project {
name: String,
target_name: String,
target_kind: ProjectTarget,
meta: ZludaMetadata,
}
impl Project {
fn try_new(p: Package) -> Option<Project> {
let name = p.name;
serde_json::from_value::<Option<Metadata>>(p.metadata)
.unwrap()
.map(|m| {
let (target_name, target_kind) = p
.targets
.into_iter()
.find_map(|target| {
if target.is_cdylib() {
Some((target.name, ProjectTarget::Cdylib))
} else if target.is_bin() {
Some((target.name, ProjectTarget::Bin))
} else {
None
}
})
.unwrap();
Self {
name,
target_name,
target_kind,
meta: m.zluda,
}
})
}
#[cfg(unix)]
fn prefix(&self) -> &'static str {
match self.target_kind {
ProjectTarget::Bin => "",
ProjectTarget::Cdylib => "lib",
}
}
#[cfg(not(unix))]
fn prefix(&self) -> &'static str {
""
}
#[cfg(unix)]
fn suffix(&self) -> &'static str {
match self.target_kind {
ProjectTarget::Bin => "",
ProjectTarget::Cdylib => ".so",
}
}
#[cfg(not(unix))]
fn suffix(&self) -> &'static str {
match self.target_kind {
ProjectTarget::Bin => ".exe",
ProjectTarget::Cdylib => ".dll",
}
}
// Returns tuple:
// * symlink file path (relative to the root of build dir)
// * symlink absolute file path
// * target actual file (relative to symlink file)
#[cfg_attr(not(unix), allow(unused))]
fn symlinks<'a>(
&'a self,
target_dir: &'a PathBuf,
profile: &'a str,
libname: &'a str,
) -> impl Iterator<Item = (&'a str, PathBuf, PathBuf)> + 'a {
self.meta.linux_symlinks.iter().map(move |source| {
let mut link = target_dir.clone();
link.extend([profile, source]);
let relative_link = PathBuf::from(source);
let ancestors = relative_link.as_path().ancestors().count();
let mut target = std::iter::repeat_with(|| "../").take(ancestors - 2).fold(
PathBuf::new(),
|mut buff, segment| {
buff.push(segment);
buff
},
);
target.push(libname);
(&**source, link, target)
})
}
fn file_name(&self) -> String {
let target_name = &self.target_name;
let prefix = self.prefix();
let suffix = self.suffix();
format!("{prefix}{target_name}{suffix}")
}
}
#[derive(Clone, Copy)]
enum ProjectTarget {
Cdylib,
Bin,
}
#[derive(Deserialize)]
struct Metadata {
zluda: ZludaMetadata,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct ZludaMetadata {
#[serde(default)]
windows_only: bool,
#[serde(default)]
debug_only: bool,
#[cfg_attr(not(unix), allow(unused))]
#[serde(default)]
linux_symlinks: Vec<String>,
}
fn main() {
let options = match options().run_inner(Args::current_args()) {
Ok(b) => b,
Err(err) => match build().to_options().run_inner(Args::current_args()) {
Ok(b) => Options::Build(b),
Err(_) => {
err.print_message(100);
std::process::exit(err.exit_code());
}
},
};
match options {
Options::Build(b) => {
compile(b);
}
Options::Zip(b) => zip(b),
}
}
fn compile(b: Build) -> (PathBuf, String, Vec<Project>) {
let profile = sniff_out_profile_name(&b.cargo_arguments);
let meta = MetadataCommand::new().no_deps().exec().unwrap();
let target_directory = meta.target_directory.into_std_path_buf();
let projects = meta
.packages
.into_iter()
.filter_map(Project::try_new)
.filter(|project| {
if project.meta.windows_only && cfg!(not(windows)) {
return false;
}
if project.meta.debug_only && profile != "debug" {
return false;
}
true
})
.collect::<Vec<_>>();
let cargo = env::var("CARGO").unwrap_or_else(|_| "cargo".to_string());
let mut command = Command::new(&cargo);
command.arg("build");
command.arg("--locked");
for project in projects.iter() {
command.arg("--package");
command.arg(&project.name);
}
command.args(b.cargo_arguments);
assert!(command.status().unwrap().success());
os::make_symlinks(&target_directory, &*projects, &*profile);
(target_directory, profile, projects)
}
fn sniff_out_profile_name(b: &[OsString]) -> String {
let parsed_cargo_arguments = cargo().to_options().run_inner(b);
match parsed_cargo_arguments {
Ok(Cargo {
release: Some(true),
..
}) => "release".to_string(),
Ok(Cargo {
profile: Some(profile),
..
}) => profile,
_ => "debug".to_string(),
}
}
fn zip(zip: Build) {
let (target_dir, profile, projects) = compile(zip);
os::zip(target_dir, profile, projects)
}
#[cfg(unix)]
mod os {
use flate2::write::GzEncoder;
use flate2::Compression;
use std::{
fs::{self, File},
path::PathBuf,
};
use tar::Header;
pub fn make_symlinks(
target_directory: &std::path::PathBuf,
projects: &[super::Project],
profile: &str,
) {
use std::os::unix::fs as unix_fs;
for project in projects.iter() {
let libname = project.file_name();
for (_, full_path, target) in project.symlinks(target_directory, profile, &libname) {
let mut dir = full_path.clone();
assert!(dir.pop());
fs::create_dir_all(dir).unwrap();
fs::remove_file(&full_path).ok();
unix_fs::symlink(&target, full_path).unwrap();
}
}
}
pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
let tar_gz =
File::create(format!("{}/{profile}/zluda.tar.gz", target_dir.display())).unwrap();
let enc = GzEncoder::new(tar_gz, Compression::default());
let mut tar = tar::Builder::new(enc);
for project in projects.iter() {
let file_name = project.file_name();
let mut file =
File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap();
tar.append_file(format!("zluda/{file_name}"), &mut file)
.unwrap();
for (source, full_path, target) in project.symlinks(&target_dir, &profile, &file_name) {
let mut header = Header::new_gnu();
let meta = fs::symlink_metadata(&full_path).unwrap();
header.set_metadata(&meta);
tar.append_link(&mut header, format!("zluda/{source}"), target)
.unwrap();
}
}
tar.finish().unwrap();
}
}
#[cfg(not(unix))]
mod os {
use std::{fs::File, io, path::PathBuf};
use zip::{write::SimpleFileOptions, ZipWriter};
pub fn make_symlinks(
_target_directory: &std::path::PathBuf,
_projects: &[super::Project],
_profile: &str,
) {
}
pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
let zip_file =
File::create(format!("{}/{profile}/zluda.zip", target_dir.display())).unwrap();
let mut zip = ZipWriter::new(zip_file);
zip.add_directory("zluda", SimpleFileOptions::default())
.unwrap();
for project in projects.iter() {
let file_name = project.file_name();
let mut file =
File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap();
let file_options = file_options_from_time(&file).unwrap_or_default();
zip.start_file(format!("zluda/{file_name}"), file_options)
.unwrap();
io::copy(&mut file, &mut zip).unwrap();
}
zip.finish().unwrap();
}
fn file_options_from_time(from: &File) -> io::Result<SimpleFileOptions> {
let metadata = from.metadata()?;
let modified = metadata.modified()?;
let modified = time::OffsetDateTime::from(modified);
Ok(SimpleFileOptions::default().last_modified_time(
zip::DateTime::try_from(modified).map_err(|err| io::Error::other(err))?,
))
}
}
use bpaf::{Args, Bpaf, Parser};
use cargo_metadata::{MetadataCommand, Package};
use serde::Deserialize;
use std::{env, ffi::OsString, path::PathBuf, process::Command};
#[derive(Debug, Clone, Bpaf)]
#[bpaf(options)]
enum Options {
#[bpaf(command)]
/// Compile ZLUDA (default command)
Build(#[bpaf(external(build))] Build),
#[bpaf(command)]
/// Compile ZLUDA and build a package
Zip(#[bpaf(external(build))] Build),
}
#[derive(Debug, Clone, Bpaf)]
struct Build {
#[bpaf(any("CARGO", not_help), many)]
/// Arguments to pass to cargo, e.g. `--release` for release build
cargo_arguments: Vec<OsString>,
}
fn not_help(s: OsString) -> Option<OsString> {
if s == "-h" || s == "--help" {
None
} else {
Some(s)
}
}
// We need to sniff out some args passed to cargo to understand how to create
// symlinks (should they go into `target/debug`, `target/release` or custom)
#[derive(Debug, Clone, Bpaf)]
struct Cargo {
#[bpaf(switch, long, short)]
release: Option<bool>,
#[bpaf(long)]
profile: Option<String>,
#[bpaf(any("", Some), many)]
_unused: Vec<OsString>,
}
struct Project {
name: String,
target_name: String,
target_kind: ProjectTarget,
meta: ZludaMetadata,
}
impl Project {
fn try_new(p: Package) -> Option<Project> {
let name = p.name;
serde_json::from_value::<Option<Metadata>>(p.metadata)
.unwrap()
.map(|m| {
let (target_name, target_kind) = p
.targets
.into_iter()
.find_map(|target| {
if target.is_cdylib() {
Some((target.name, ProjectTarget::Cdylib))
} else if target.is_bin() {
Some((target.name, ProjectTarget::Bin))
} else {
None
}
})
.unwrap();
Self {
name,
target_name,
target_kind,
meta: m.zluda,
}
})
}
#[cfg(unix)]
fn prefix(&self) -> &'static str {
match self.target_kind {
ProjectTarget::Bin => "",
ProjectTarget::Cdylib => "lib",
}
}
#[cfg(not(unix))]
fn prefix(&self) -> &'static str {
""
}
#[cfg(unix)]
fn suffix(&self) -> &'static str {
match self.target_kind {
ProjectTarget::Bin => "",
ProjectTarget::Cdylib => ".so",
}
}
#[cfg(not(unix))]
fn suffix(&self) -> &'static str {
match self.target_kind {
ProjectTarget::Bin => ".exe",
ProjectTarget::Cdylib => ".dll",
}
}
// Returns tuple:
// * symlink file path (relative to the root of build dir)
// * symlink absolute file path
// * target actual file (relative to symlink file)
#[cfg_attr(not(unix), allow(unused))]
fn symlinks<'a>(
&'a self,
target_dir: &'a PathBuf,
profile: &'a str,
libname: &'a str,
) -> impl Iterator<Item = (&'a str, PathBuf, PathBuf)> + 'a {
self.meta.linux_symlinks.iter().map(move |source| {
let mut link = target_dir.clone();
link.extend([profile, source]);
let relative_link = PathBuf::from(source);
let ancestors = relative_link.as_path().ancestors().count();
let mut target = std::iter::repeat_with(|| "../").take(ancestors - 2).fold(
PathBuf::new(),
|mut buff, segment| {
buff.push(segment);
buff
},
);
target.push(libname);
(&**source, link, target)
})
}
fn file_name(&self) -> String {
let target_name = &self.target_name;
let prefix = self.prefix();
let suffix = self.suffix();
format!("{prefix}{target_name}{suffix}")
}
}
#[derive(Clone, Copy)]
enum ProjectTarget {
Cdylib,
Bin,
}
#[derive(Deserialize)]
struct Metadata {
zluda: ZludaMetadata,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct ZludaMetadata {
#[serde(default)]
windows_only: bool,
#[serde(default)]
debug_only: bool,
#[cfg_attr(not(unix), allow(unused))]
#[serde(default)]
linux_symlinks: Vec<String>,
}
fn main() {
let options = match options().run_inner(Args::current_args()) {
Ok(b) => b,
Err(err) => match build().to_options().run_inner(Args::current_args()) {
Ok(b) => Options::Build(b),
Err(_) => {
err.print_message(100);
std::process::exit(err.exit_code());
}
},
};
match options {
Options::Build(b) => {
compile(b);
}
Options::Zip(b) => zip(b),
}
}
fn compile(b: Build) -> (PathBuf, String, Vec<Project>) {
let profile = sniff_out_profile_name(&b.cargo_arguments);
let meta = MetadataCommand::new().no_deps().exec().unwrap();
let target_directory = meta.target_directory.into_std_path_buf();
let projects = meta
.packages
.into_iter()
.filter_map(Project::try_new)
.filter(|project| {
if project.meta.windows_only && cfg!(not(windows)) {
return false;
}
if project.meta.debug_only && profile != "debug" {
return false;
}
true
})
.collect::<Vec<_>>();
let cargo = env::var("CARGO").unwrap_or_else(|_| "cargo".to_string());
let mut command = Command::new(&cargo);
command.arg("build");
command.arg("--locked");
for project in projects.iter() {
command.arg("--package");
command.arg(&project.name);
}
command.args(b.cargo_arguments);
assert!(command.status().unwrap().success());
os::make_symlinks(&target_directory, &*projects, &*profile);
(target_directory, profile, projects)
}
fn sniff_out_profile_name(b: &[OsString]) -> String {
let parsed_cargo_arguments = cargo().to_options().run_inner(b);
match parsed_cargo_arguments {
Ok(Cargo {
release: Some(true),
..
}) => "release".to_string(),
Ok(Cargo {
profile: Some(profile),
..
}) => profile,
_ => "debug".to_string(),
}
}
fn zip(zip: Build) {
let (target_dir, profile, projects) = compile(zip);
os::zip(target_dir, profile, projects)
}
#[cfg(unix)]
mod os {
use flate2::write::GzEncoder;
use flate2::Compression;
use std::{
fs::{self, File},
path::PathBuf,
};
use tar::Header;
pub fn make_symlinks(
target_directory: &std::path::PathBuf,
projects: &[super::Project],
profile: &str,
) {
use std::os::unix::fs as unix_fs;
for project in projects.iter() {
let libname = project.file_name();
for (_, full_path, target) in project.symlinks(target_directory, profile, &libname) {
let mut dir = full_path.clone();
assert!(dir.pop());
fs::create_dir_all(dir).unwrap();
fs::remove_file(&full_path).ok();
unix_fs::symlink(&target, full_path).unwrap();
}
}
}
pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
let tar_gz =
File::create(format!("{}/{profile}/zluda.tar.gz", target_dir.display())).unwrap();
let enc = GzEncoder::new(tar_gz, Compression::default());
let mut tar = tar::Builder::new(enc);
for project in projects.iter() {
let file_name = project.file_name();
let mut file =
File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap();
tar.append_file(format!("zluda/{file_name}"), &mut file)
.unwrap();
for (source, full_path, target) in project.symlinks(&target_dir, &profile, &file_name) {
let mut header = Header::new_gnu();
let meta = fs::symlink_metadata(&full_path).unwrap();
header.set_metadata(&meta);
tar.append_link(&mut header, format!("zluda/{source}"), target)
.unwrap();
}
}
tar.finish().unwrap();
}
}
#[cfg(not(unix))]
mod os {
use std::{fs::File, io, path::PathBuf};
use zip::{write::SimpleFileOptions, ZipWriter};
pub fn make_symlinks(
_target_directory: &std::path::PathBuf,
_projects: &[super::Project],
_profile: &str,
) {
}
pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
let zip_file =
File::create(format!("{}/{profile}/zluda.zip", target_dir.display())).unwrap();
let mut zip = ZipWriter::new(zip_file);
zip.add_directory("zluda", SimpleFileOptions::default())
.unwrap();
for project in projects.iter() {
let file_name = project.file_name();
let mut file =
File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap();
let file_options = file_options_from_time(&file).unwrap_or_default();
zip.start_file(format!("zluda/{file_name}"), file_options)
.unwrap();
io::copy(&mut file, &mut zip).unwrap();
}
zip.finish().unwrap();
}
fn file_options_from_time(from: &File) -> io::Result<SimpleFileOptions> {
let metadata = from.metadata()?;
let modified = metadata.modified()?;
let modified = time::OffsetDateTime::from(modified);
Ok(SimpleFileOptions::default().last_modified_time(
zip::DateTime::try_from(modified).map_err(|err| io::Error::other(err))?,
))
}
}

View File

@ -1,426 +1,426 @@
use super::{FromCuda, LiveCheck};
use crate::r#impl::{context, device};
use comgr::Comgr;
use cuda_types::cuda::*;
use hip_runtime_sys::*;
use std::{
ffi::{c_void, CStr, CString},
mem, ptr, slice,
sync::OnceLock,
usize,
};
#[cfg_attr(windows, path = "os_win.rs")]
#[cfg_attr(not(windows), path = "os_unix.rs")]
mod os;
pub(crate) struct GlobalState {
pub devices: Vec<Device>,
pub comgr: Comgr,
}
pub(crate) struct Device {
pub(crate) _comgr_isa: CString,
primary_context: LiveCheck<context::Context>,
}
impl Device {
pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) {
unsafe {
(
self.primary_context.data.assume_init_ref(),
self.primary_context.as_handle(),
)
}
}
}
pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> {
global_state()?
.devices
.get(dev as usize)
.ok_or(CUerror::INVALID_DEVICE)
}
pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new();
fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] {
unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) }
}
GLOBAL_STATE
.get_or_init(|| {
let mut device_count = 0;
unsafe { hipGetDeviceCount(&mut device_count) }?;
let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?;
Ok(GlobalState {
comgr,
devices: (0..device_count)
.map(|i| {
let mut props = unsafe { mem::zeroed() };
unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?;
Ok::<_, CUerror>(Device {
_comgr_isa: CStr::from_bytes_until_nul(cast_slice(
&props.gcnArchName[..],
))
.map_err(|_| CUerror::UNKNOWN)?
.to_owned(),
primary_context: LiveCheck::new(context::Context::new(i)),
})
})
.collect::<Result<Vec<_>, _>>()?,
})
})
.as_ref()
.map_err(|e| *e)
}
pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult {
unsafe { hipInit(flags) }?;
global_state()?;
Ok(())
}
struct UnknownBuffer<const S: usize> {
buffer: std::cell::UnsafeCell<[u32; S]>,
}
impl<const S: usize> UnknownBuffer<S> {
const fn new() -> Self {
UnknownBuffer {
buffer: std::cell::UnsafeCell::new([0; S]),
}
}
const fn len(&self) -> usize {
S
}
}
unsafe impl<const S: usize> Sync for UnknownBuffer<S> {}
static UNKNOWN_BUFFER1: UnknownBuffer<1024> = UnknownBuffer::new();
static UNKNOWN_BUFFER2: UnknownBuffer<14> = UnknownBuffer::new();
struct DarkApi {}
impl ::dark_api::cuda::CudaDarkApi for DarkApi {
unsafe extern "system" fn get_module_from_cubin(
_module: *mut cuda_types::cuda::CUmodule,
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn cudart_interface_fn2(
pctx: *mut cuda_types::cuda::CUcontext,
hip_dev: hipDevice_t,
) -> cuda_types::cuda::CUresult {
let pctx = match pctx.as_mut() {
Some(p) => p,
None => return CUresult::ERROR_INVALID_VALUE,
};
device::primary_context_retain(pctx, hip_dev)
}
unsafe extern "system" fn get_module_from_cubin_ext1(
_result: *mut cuda_types::cuda::CUmodule,
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
_arg3: *mut std::ffi::c_void,
_arg4: *mut std::ffi::c_void,
_arg5: u32,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn get_module_from_cubin_ext2(
_fatbin_header: *const cuda_types::dark_api::FatbinHeader,
_result: *mut cuda_types::cuda::CUmodule,
_arg3: *mut std::ffi::c_void,
_arg4: *mut std::ffi::c_void,
_arg5: u32,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn get_unknown_buffer1(
ptr: *mut *mut std::ffi::c_void,
size: *mut usize,
) -> () {
*ptr = UNKNOWN_BUFFER1.buffer.get() as *mut std::ffi::c_void;
*size = UNKNOWN_BUFFER1.len();
}
unsafe extern "system" fn get_unknown_buffer2(
ptr: *mut *mut std::ffi::c_void,
size: *mut usize,
) -> () {
*ptr = UNKNOWN_BUFFER2.buffer.get() as *mut std::ffi::c_void;
*size = UNKNOWN_BUFFER2.len();
}
unsafe extern "system" fn context_local_storage_put(
cu_ctx: CUcontext,
key: *mut c_void,
value: *mut c_void,
dtor_cb: Option<extern "system" fn(CUcontext, *mut c_void, *mut c_void)>,
) -> CUresult {
let _ctx = if cu_ctx.0 != ptr::null_mut() {
cu_ctx
} else {
let mut current_ctx: CUcontext = CUcontext(ptr::null_mut());
context::get_current(&mut current_ctx)?;
current_ctx
};
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
state.storage.insert(
key as usize,
context::StorageData {
value: value as usize,
reset_cb: dtor_cb,
handle: _ctx,
},
);
Ok(())
})?;
Ok(())
}
unsafe extern "system" fn context_local_storage_delete(
cu_ctx: CUcontext,
key: *mut c_void,
) -> CUresult {
let ctx_obj: &context::Context = FromCuda::from_cuda(&cu_ctx)?;
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
state.storage.remove(&(key as usize));
Ok(())
})?;
Ok(())
}
unsafe extern "system" fn context_local_storage_get(
value: *mut *mut c_void,
cu_ctx: CUcontext,
key: *mut c_void,
) -> CUresult {
let mut _ctx: CUcontext;
if cu_ctx.0 == ptr::null_mut() {
_ctx = context::get_current_context()?;
} else {
_ctx = cu_ctx
};
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
ctx_obj.with_state(|state: &context::ContextState| {
match state.storage.get(&(key as usize)) {
Some(data) => *value = data.value as *mut c_void,
None => return CUresult::ERROR_INVALID_HANDLE,
}
Ok(())
})?;
Ok(())
}
unsafe extern "system" fn ctx_create_v2_bypass(
_pctx: *mut cuda_types::cuda::CUcontext,
_flags: ::std::os::raw::c_uint,
_dev: cuda_types::cuda::CUdevice,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn heap_alloc(
_heap_alloc_record_ptr: *mut *const std::ffi::c_void,
_arg2: usize,
_arg3: usize,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn heap_free(
_heap_alloc_record_ptr: *const std::ffi::c_void,
_arg2: *mut usize,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn device_get_attribute_ext(
_dev: cuda_types::cuda::CUdevice,
_attribute: std::ffi::c_uint,
_unknown: std::ffi::c_int,
_result: *mut [usize; 2],
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn device_get_something(
_result: *mut std::ffi::c_uchar,
_dev: cuda_types::cuda::CUdevice,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn integrity_check(
version: u32,
unix_seconds: u64,
result: *mut [u64; 2],
) -> cuda_types::cuda::CUresult {
let current_process = std::process::id();
let current_thread = os::current_thread();
let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast();
let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast();
let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1];
let devices = get_device_hash_info()?;
let device_count = devices.len() as u32;
let get_device = |dev| devices[dev as usize];
let hash = ::dark_api::integrity_check(
version,
unix_seconds,
cuda_types::cuda::CUDA_VERSION,
current_process,
current_thread,
integrity_check_table,
cudart_table,
fn_address,
device_count,
get_device,
);
*result = hash;
Ok(())
}
unsafe extern "system" fn context_check(
_ctx_in: cuda_types::cuda::CUcontext,
result1: *mut u32,
_result2: *mut *const std::ffi::c_void,
) -> cuda_types::cuda::CUresult {
*result1 = 0;
CUresult::SUCCESS
}
unsafe extern "system" fn check_fn3() -> u32 {
0
}
}
fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
let mut device_count = 0;
device::get_count(&mut device_count)?;
(0..device_count)
.map(|dev| {
let mut guid = CUuuid_st { bytes: [0; 16] };
unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? };
let mut pci_domain = 0;
device::get_attribute(
&mut pci_domain,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID,
dev,
)?;
let mut pci_bus = 0;
device::get_attribute(
&mut pci_bus,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID,
dev,
)?;
let mut pci_device = 0;
device::get_attribute(
&mut pci_device,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID,
dev,
)?;
Ok(::dark_api::DeviceHashinfo {
guid,
pci_domain,
pci_bus,
pci_device,
})
})
.collect()
}
static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable =
::dark_api::cuda::CudaDarkApiGlobalTable::new::<DarkApi>();
pub(crate) fn get_export_table(
pp_export_table: &mut *const ::core::ffi::c_void,
p_export_table_id: &CUuuid,
) -> CUresult {
if let Some(table) = EXPORT_TABLE.get(p_export_table_id) {
*pp_export_table = table.start();
cuda_types::cuda::CUresult::SUCCESS
} else {
cuda_types::cuda::CUresult::ERROR_INVALID_VALUE
}
}
pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
*version = cuda_types::cuda::CUDA_VERSION as i32;
Ok(())
}
pub(crate) unsafe fn get_proc_address(
symbol: &CStr,
pfn: &mut *mut ::core::ffi::c_void,
cuda_version: ::core::ffi::c_int,
flags: cuda_types::cuda::cuuint64_t,
) -> CUresult {
get_proc_address_v2(symbol, pfn, cuda_version, flags, None)
}
pub(crate) unsafe fn get_proc_address_v2(
symbol: &CStr,
pfn: &mut *mut ::core::ffi::c_void,
cuda_version: ::core::ffi::c_int,
flags: cuda_types::cuda::cuuint64_t,
symbol_status: Option<&mut cuda_types::cuda::CUdriverProcAddressQueryResult>,
) -> CUresult {
// This implementation is mostly the same as cuGetProcAddress_v2 in zluda_dump. We may want to factor out the duplication at some point.
fn raw_match(name: &[u8], flag: u64, version: i32) -> *mut ::core::ffi::c_void {
use crate::*;
include!("../../../zluda_bindgen/src/process_table.rs")
}
let fn_ptr = raw_match(symbol.to_bytes(), flags, cuda_version);
match fn_ptr as usize {
0 => {
if let Some(symbol_status) = symbol_status {
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND;
}
*pfn = ptr::null_mut();
CUresult::ERROR_NOT_FOUND
}
usize::MAX => {
if let Some(symbol_status) = symbol_status {
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT;
}
*pfn = ptr::null_mut();
CUresult::ERROR_NOT_FOUND
}
_ => {
if let Some(symbol_status) = symbol_status {
*symbol_status =
cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SUCCESS;
}
*pfn = fn_ptr;
Ok(())
}
}
}
pub(crate) fn profiler_start() -> CUresult {
Ok(())
}
pub(crate) fn profiler_stop() -> CUresult {
Ok(())
}
use super::{FromCuda, LiveCheck};
use crate::r#impl::{context, device};
use comgr::Comgr;
use cuda_types::cuda::*;
use hip_runtime_sys::*;
use std::{
ffi::{c_void, CStr, CString},
mem, ptr, slice,
sync::OnceLock,
usize,
};
#[cfg_attr(windows, path = "os_win.rs")]
#[cfg_attr(not(windows), path = "os_unix.rs")]
mod os;
pub(crate) struct GlobalState {
pub devices: Vec<Device>,
pub comgr: Comgr,
}
pub(crate) struct Device {
pub(crate) _comgr_isa: CString,
primary_context: LiveCheck<context::Context>,
}
impl Device {
pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) {
unsafe {
(
self.primary_context.data.assume_init_ref(),
self.primary_context.as_handle(),
)
}
}
}
pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> {
global_state()?
.devices
.get(dev as usize)
.ok_or(CUerror::INVALID_DEVICE)
}
pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new();
fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] {
unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) }
}
GLOBAL_STATE
.get_or_init(|| {
let mut device_count = 0;
unsafe { hipGetDeviceCount(&mut device_count) }?;
let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?;
Ok(GlobalState {
comgr,
devices: (0..device_count)
.map(|i| {
let mut props = unsafe { mem::zeroed() };
unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?;
Ok::<_, CUerror>(Device {
_comgr_isa: CStr::from_bytes_until_nul(cast_slice(
&props.gcnArchName[..],
))
.map_err(|_| CUerror::UNKNOWN)?
.to_owned(),
primary_context: LiveCheck::new(context::Context::new(i)),
})
})
.collect::<Result<Vec<_>, _>>()?,
})
})
.as_ref()
.map_err(|e| *e)
}
pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult {
unsafe { hipInit(flags) }?;
global_state()?;
Ok(())
}
struct UnknownBuffer<const S: usize> {
buffer: std::cell::UnsafeCell<[u32; S]>,
}
impl<const S: usize> UnknownBuffer<S> {
const fn new() -> Self {
UnknownBuffer {
buffer: std::cell::UnsafeCell::new([0; S]),
}
}
const fn len(&self) -> usize {
S
}
}
unsafe impl<const S: usize> Sync for UnknownBuffer<S> {}
static UNKNOWN_BUFFER1: UnknownBuffer<1024> = UnknownBuffer::new();
static UNKNOWN_BUFFER2: UnknownBuffer<14> = UnknownBuffer::new();
struct DarkApi {}
impl ::dark_api::cuda::CudaDarkApi for DarkApi {
unsafe extern "system" fn get_module_from_cubin(
_module: *mut cuda_types::cuda::CUmodule,
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn cudart_interface_fn2(
pctx: *mut cuda_types::cuda::CUcontext,
hip_dev: hipDevice_t,
) -> cuda_types::cuda::CUresult {
let pctx = match pctx.as_mut() {
Some(p) => p,
None => return CUresult::ERROR_INVALID_VALUE,
};
device::primary_context_retain(pctx, hip_dev)
}
unsafe extern "system" fn get_module_from_cubin_ext1(
_result: *mut cuda_types::cuda::CUmodule,
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
_arg3: *mut std::ffi::c_void,
_arg4: *mut std::ffi::c_void,
_arg5: u32,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn get_module_from_cubin_ext2(
_fatbin_header: *const cuda_types::dark_api::FatbinHeader,
_result: *mut cuda_types::cuda::CUmodule,
_arg3: *mut std::ffi::c_void,
_arg4: *mut std::ffi::c_void,
_arg5: u32,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn get_unknown_buffer1(
ptr: *mut *mut std::ffi::c_void,
size: *mut usize,
) -> () {
*ptr = UNKNOWN_BUFFER1.buffer.get() as *mut std::ffi::c_void;
*size = UNKNOWN_BUFFER1.len();
}
unsafe extern "system" fn get_unknown_buffer2(
ptr: *mut *mut std::ffi::c_void,
size: *mut usize,
) -> () {
*ptr = UNKNOWN_BUFFER2.buffer.get() as *mut std::ffi::c_void;
*size = UNKNOWN_BUFFER2.len();
}
unsafe extern "system" fn context_local_storage_put(
cu_ctx: CUcontext,
key: *mut c_void,
value: *mut c_void,
dtor_cb: Option<extern "system" fn(CUcontext, *mut c_void, *mut c_void)>,
) -> CUresult {
let _ctx = if cu_ctx.0 != ptr::null_mut() {
cu_ctx
} else {
let mut current_ctx: CUcontext = CUcontext(ptr::null_mut());
context::get_current(&mut current_ctx)?;
current_ctx
};
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
state.storage.insert(
key as usize,
context::StorageData {
value: value as usize,
reset_cb: dtor_cb,
handle: _ctx,
},
);
Ok(())
})?;
Ok(())
}
unsafe extern "system" fn context_local_storage_delete(
cu_ctx: CUcontext,
key: *mut c_void,
) -> CUresult {
let ctx_obj: &context::Context = FromCuda::from_cuda(&cu_ctx)?;
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
state.storage.remove(&(key as usize));
Ok(())
})?;
Ok(())
}
unsafe extern "system" fn context_local_storage_get(
value: *mut *mut c_void,
cu_ctx: CUcontext,
key: *mut c_void,
) -> CUresult {
let mut _ctx: CUcontext;
if cu_ctx.0 == ptr::null_mut() {
_ctx = context::get_current_context()?;
} else {
_ctx = cu_ctx
};
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
ctx_obj.with_state(|state: &context::ContextState| {
match state.storage.get(&(key as usize)) {
Some(data) => *value = data.value as *mut c_void,
None => return CUresult::ERROR_INVALID_HANDLE,
}
Ok(())
})?;
Ok(())
}
unsafe extern "system" fn ctx_create_v2_bypass(
_pctx: *mut cuda_types::cuda::CUcontext,
_flags: ::std::os::raw::c_uint,
_dev: cuda_types::cuda::CUdevice,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn heap_alloc(
_heap_alloc_record_ptr: *mut *const std::ffi::c_void,
_arg2: usize,
_arg3: usize,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn heap_free(
_heap_alloc_record_ptr: *const std::ffi::c_void,
_arg2: *mut usize,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn device_get_attribute_ext(
_dev: cuda_types::cuda::CUdevice,
_attribute: std::ffi::c_uint,
_unknown: std::ffi::c_int,
_result: *mut [usize; 2],
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn device_get_something(
_result: *mut std::ffi::c_uchar,
_dev: cuda_types::cuda::CUdevice,
) -> cuda_types::cuda::CUresult {
todo!()
}
unsafe extern "system" fn integrity_check(
version: u32,
unix_seconds: u64,
result: *mut [u64; 2],
) -> cuda_types::cuda::CUresult {
let current_process = std::process::id();
let current_thread = os::current_thread();
let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast();
let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast();
let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1];
let devices = get_device_hash_info()?;
let device_count = devices.len() as u32;
let get_device = |dev| devices[dev as usize];
let hash = ::dark_api::integrity_check(
version,
unix_seconds,
cuda_types::cuda::CUDA_VERSION,
current_process,
current_thread,
integrity_check_table,
cudart_table,
fn_address,
device_count,
get_device,
);
*result = hash;
Ok(())
}
unsafe extern "system" fn context_check(
_ctx_in: cuda_types::cuda::CUcontext,
result1: *mut u32,
_result2: *mut *const std::ffi::c_void,
) -> cuda_types::cuda::CUresult {
*result1 = 0;
CUresult::SUCCESS
}
unsafe extern "system" fn check_fn3() -> u32 {
0
}
}
fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
let mut device_count = 0;
device::get_count(&mut device_count)?;
(0..device_count)
.map(|dev| {
let mut guid = CUuuid_st { bytes: [0; 16] };
unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? };
let mut pci_domain = 0;
device::get_attribute(
&mut pci_domain,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID,
dev,
)?;
let mut pci_bus = 0;
device::get_attribute(
&mut pci_bus,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID,
dev,
)?;
let mut pci_device = 0;
device::get_attribute(
&mut pci_device,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID,
dev,
)?;
Ok(::dark_api::DeviceHashinfo {
guid,
pci_domain,
pci_bus,
pci_device,
})
})
.collect()
}
static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable =
::dark_api::cuda::CudaDarkApiGlobalTable::new::<DarkApi>();
pub(crate) fn get_export_table(
pp_export_table: &mut *const ::core::ffi::c_void,
p_export_table_id: &CUuuid,
) -> CUresult {
if let Some(table) = EXPORT_TABLE.get(p_export_table_id) {
*pp_export_table = table.start();
cuda_types::cuda::CUresult::SUCCESS
} else {
cuda_types::cuda::CUresult::ERROR_INVALID_VALUE
}
}
pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
*version = cuda_types::cuda::CUDA_VERSION as i32;
Ok(())
}
pub(crate) unsafe fn get_proc_address(
symbol: &CStr,
pfn: &mut *mut ::core::ffi::c_void,
cuda_version: ::core::ffi::c_int,
flags: cuda_types::cuda::cuuint64_t,
) -> CUresult {
get_proc_address_v2(symbol, pfn, cuda_version, flags, None)
}
pub(crate) unsafe fn get_proc_address_v2(
symbol: &CStr,
pfn: &mut *mut ::core::ffi::c_void,
cuda_version: ::core::ffi::c_int,
flags: cuda_types::cuda::cuuint64_t,
symbol_status: Option<&mut cuda_types::cuda::CUdriverProcAddressQueryResult>,
) -> CUresult {
// This implementation is mostly the same as cuGetProcAddress_v2 in zluda_dump. We may want to factor out the duplication at some point.
fn raw_match(name: &[u8], flag: u64, version: i32) -> *mut ::core::ffi::c_void {
use crate::*;
include!("../../../zluda_bindgen/src/process_table.rs")
}
let fn_ptr = raw_match(symbol.to_bytes(), flags, cuda_version);
match fn_ptr as usize {
0 => {
if let Some(symbol_status) = symbol_status {
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND;
}
*pfn = ptr::null_mut();
CUresult::ERROR_NOT_FOUND
}
usize::MAX => {
if let Some(symbol_status) = symbol_status {
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT;
}
*pfn = ptr::null_mut();
CUresult::ERROR_NOT_FOUND
}
_ => {
if let Some(symbol_status) = symbol_status {
*symbol_status =
cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SUCCESS;
}
*pfn = fn_ptr;
Ok(())
}
}
}
pub(crate) fn profiler_start() -> CUresult {
Ok(())
}
pub(crate) fn profiler_stop() -> CUresult {
Ok(())
}

View File

@ -1,9 +1,9 @@
// TODO: remove duplication with zluda_dump
#[link(name = "pthread")]
unsafe extern "C" {
fn pthread_self() -> std::os::unix::thread::RawPthread;
}
pub(crate) fn current_thread() -> u32 {
(unsafe { pthread_self() }) as u32
}
// TODO: remove duplication with zluda_dump
#[link(name = "pthread")]
unsafe extern "C" {
fn pthread_self() -> std::os::unix::thread::RawPthread;
}
pub(crate) fn current_thread() -> u32 {
(unsafe { pthread_self() }) as u32
}

View File

@ -1,9 +1,9 @@
// TODO: remove duplication with zluda_dump
#[link(name = "kernel32")]
unsafe extern "system" {
fn GetCurrentThreadId() -> u32;
}
pub(crate) fn current_thread() -> u32 {
unsafe { GetCurrentThreadId() }
}
// TODO: remove duplication with zluda_dump
#[link(name = "kernel32")]
unsafe extern "system" {
fn GetCurrentThreadId() -> u32;
}
pub(crate) fn current_thread() -> u32 {
unsafe { GetCurrentThreadId() }
}

View File

@ -1,124 +1,124 @@
use crate::os;
use crate::{CudaFunctionName, ErrorEntry};
use cuda_types::cuda::*;
use rustc_hash::FxHashMap;
use std::cell::RefMut;
use std::hash::Hash;
use std::{collections::hash_map, ffi::c_void, mem};
pub(crate) struct DarkApiState2 {
// Key is Box<CUuuid, because thunk reporting unknown export table needs a
// stable memory location for the guid
pub(crate) overrides: FxHashMap<Box<CUuuidWrapper>, (*const *const c_void, Vec<*const c_void>)>,
}
unsafe impl Send for DarkApiState2 {}
unsafe impl Sync for DarkApiState2 {}
impl DarkApiState2 {
pub(crate) fn new() -> Self {
DarkApiState2 {
overrides: FxHashMap::default(),
}
}
pub(crate) fn override_export_table(
&mut self,
known_exports: &::dark_api::cuda::CudaDarkApiGlobalTable,
original_export_table: *const *const c_void,
guid: &CUuuid_st,
) -> (*const *const c_void, Option<ErrorEntry>) {
let entry = match self.overrides.entry(Box::new(CUuuidWrapper(*guid))) {
hash_map::Entry::Occupied(entry) => {
let (_, override_table) = entry.get();
return (override_table.as_ptr(), None);
}
hash_map::Entry::Vacant(entry) => entry,
};
let mut error = None;
let byte_size: usize = unsafe { *(original_export_table.cast::<usize>()) };
// Some export tables don't start with a byte count, but directly with a
// pointer, and are instead terminated by 0 or MAX
let export_functions_start_idx;
let export_functions_size;
if byte_size > 0x10000 {
export_functions_start_idx = 0;
let mut i = 0;
loop {
let current_ptr = unsafe { original_export_table.add(i) };
let current_ptr_numeric = unsafe { *current_ptr } as usize;
if current_ptr_numeric == 0usize || current_ptr_numeric == usize::MAX {
export_functions_size = i;
break;
}
i += 1;
}
} else {
export_functions_start_idx = 1;
export_functions_size = byte_size / mem::size_of::<usize>();
}
let our_functions = known_exports.get(guid);
if let Some(ref our_functions) = our_functions {
if our_functions.len() != export_functions_size {
error = Some(ErrorEntry::UnexpectedExportTableSize {
expected: our_functions.len(),
computed: export_functions_size,
});
}
}
let mut override_table =
unsafe { std::slice::from_raw_parts(original_export_table, export_functions_size) }
.to_vec();
for i in export_functions_start_idx..export_functions_size {
let current_fn = (|| {
if let Some(ref our_functions) = our_functions {
if let Some(fn_) = our_functions.get_fn(i) {
return fn_;
}
}
os::get_thunk(
override_table[i],
Self::report_unknown_export_table_call,
std::ptr::from_ref(entry.key().as_ref()).cast(),
i,
)
})();
override_table[i] = current_fn;
}
(
entry
.insert((original_export_table, override_table))
.1
.as_ptr(),
error,
)
}
unsafe extern "system" fn report_unknown_export_table_call(guid: &CUuuid, index: usize) {
let global_state = crate::GLOBAL_STATE2.lock();
let global_state_ref_cell = &*global_state;
let mut global_state_ref_mut = global_state_ref_cell.borrow_mut();
let global_state = &mut *global_state_ref_mut;
let log_guard = crate::OuterCallGuard {
writer: &mut global_state.log_writer,
log_root: &global_state.log_stack,
};
{
let mut logger = RefMut::map(global_state.log_stack.borrow_mut(), |log_stack| {
log_stack.enter()
});
logger.name = CudaFunctionName::Dark { guid: *guid, index };
};
drop(log_guard);
}
}
#[derive(Eq, PartialEq)]
#[repr(transparent)]
pub(crate) struct CUuuidWrapper(pub CUuuid);
impl Hash for CUuuidWrapper {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.bytes.hash(state);
}
}
use crate::os;
use crate::{CudaFunctionName, ErrorEntry};
use cuda_types::cuda::*;
use rustc_hash::FxHashMap;
use std::cell::RefMut;
use std::hash::Hash;
use std::{collections::hash_map, ffi::c_void, mem};
pub(crate) struct DarkApiState2 {
// Key is Box<CUuuid, because thunk reporting unknown export table needs a
// stable memory location for the guid
pub(crate) overrides: FxHashMap<Box<CUuuidWrapper>, (*const *const c_void, Vec<*const c_void>)>,
}
unsafe impl Send for DarkApiState2 {}
unsafe impl Sync for DarkApiState2 {}
impl DarkApiState2 {
pub(crate) fn new() -> Self {
DarkApiState2 {
overrides: FxHashMap::default(),
}
}
pub(crate) fn override_export_table(
&mut self,
known_exports: &::dark_api::cuda::CudaDarkApiGlobalTable,
original_export_table: *const *const c_void,
guid: &CUuuid_st,
) -> (*const *const c_void, Option<ErrorEntry>) {
let entry = match self.overrides.entry(Box::new(CUuuidWrapper(*guid))) {
hash_map::Entry::Occupied(entry) => {
let (_, override_table) = entry.get();
return (override_table.as_ptr(), None);
}
hash_map::Entry::Vacant(entry) => entry,
};
let mut error = None;
let byte_size: usize = unsafe { *(original_export_table.cast::<usize>()) };
// Some export tables don't start with a byte count, but directly with a
// pointer, and are instead terminated by 0 or MAX
let export_functions_start_idx;
let export_functions_size;
if byte_size > 0x10000 {
export_functions_start_idx = 0;
let mut i = 0;
loop {
let current_ptr = unsafe { original_export_table.add(i) };
let current_ptr_numeric = unsafe { *current_ptr } as usize;
if current_ptr_numeric == 0usize || current_ptr_numeric == usize::MAX {
export_functions_size = i;
break;
}
i += 1;
}
} else {
export_functions_start_idx = 1;
export_functions_size = byte_size / mem::size_of::<usize>();
}
let our_functions = known_exports.get(guid);
if let Some(ref our_functions) = our_functions {
if our_functions.len() != export_functions_size {
error = Some(ErrorEntry::UnexpectedExportTableSize {
expected: our_functions.len(),
computed: export_functions_size,
});
}
}
let mut override_table =
unsafe { std::slice::from_raw_parts(original_export_table, export_functions_size) }
.to_vec();
for i in export_functions_start_idx..export_functions_size {
let current_fn = (|| {
if let Some(ref our_functions) = our_functions {
if let Some(fn_) = our_functions.get_fn(i) {
return fn_;
}
}
os::get_thunk(
override_table[i],
Self::report_unknown_export_table_call,
std::ptr::from_ref(entry.key().as_ref()).cast(),
i,
)
})();
override_table[i] = current_fn;
}
(
entry
.insert((original_export_table, override_table))
.1
.as_ptr(),
error,
)
}
unsafe extern "system" fn report_unknown_export_table_call(guid: &CUuuid, index: usize) {
let global_state = crate::GLOBAL_STATE2.lock();
let global_state_ref_cell = &*global_state;
let mut global_state_ref_mut = global_state_ref_cell.borrow_mut();
let global_state = &mut *global_state_ref_mut;
let log_guard = crate::OuterCallGuard {
writer: &mut global_state.log_writer,
log_root: &global_state.log_stack,
};
{
let mut logger = RefMut::map(global_state.log_stack.borrow_mut(), |log_stack| {
log_stack.enter()
});
logger.name = CudaFunctionName::Dark { guid: *guid, index };
};
drop(log_guard);
}
}
#[derive(Eq, PartialEq)]
#[repr(transparent)]
pub(crate) struct CUuuidWrapper(pub CUuuid);
impl Hash for CUuuidWrapper {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.bytes.hash(state);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,81 +1,81 @@
use cuda_types::cuda::CUuuid;
use std::ffi::{c_void, CStr, CString};
use std::mem;
pub(crate) const LIBCUDA_DEFAULT_PATH: &str = "/usr/lib/x86_64-linux-gnu/libcuda.so.1";
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
let libcuda_path = CString::new(libcuda_path).unwrap();
libc::dlopen(
libcuda_path.as_ptr() as *const _,
libc::RTLD_LOCAL | libc::RTLD_NOW,
)
}
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
libc::dlsym(handle, func.as_ptr() as *const _)
}
#[macro_export]
macro_rules! os_log {
($format:tt) => {
{
eprintln!("[ZLUDA_DUMP] {}", format!($format));
}
};
($format:tt, $($obj: expr),+) => {
{
eprintln!("[ZLUDA_DUMP] {}", format!($format, $($obj,)+));
}
};
}
//RDI, RSI, RDX, RCX, R8, R9
#[cfg(target_arch = "x86_64")]
pub fn get_thunk(
original_fn: *const c_void,
report_fn: unsafe extern "system" fn(&CUuuid, usize),
guid: *const CUuuid,
idx: usize,
) -> *const c_void {
use dynasmrt::{dynasm, DynasmApi};
let mut ops = dynasmrt::x64::Assembler::new().unwrap();
let start = ops.offset();
dynasm!(ops
// stack alignment
; sub rsp, 8
; push rdi
; push rsi
; push rdx
; push rcx
; push r8
; push r9
; mov rdi, QWORD guid as i64
; mov rsi, QWORD idx as i64
; mov rax, QWORD report_fn as i64
; call rax
; pop r9
; pop r8
; pop rcx
; pop rdx
; pop rsi
; pop rdi
; add rsp, 8
; mov rax, QWORD original_fn as i64
; jmp rax
; int 3
);
let exe_buf = ops.finalize().unwrap();
let result_fn = exe_buf.ptr(start);
mem::forget(exe_buf);
result_fn as *const _
}
#[link(name = "pthread")]
unsafe extern "C" {
fn pthread_self() -> std::os::unix::thread::RawPthread;
}
pub(crate) fn current_thread() -> u32 {
(unsafe { pthread_self() }) as u32
}
use cuda_types::cuda::CUuuid;
use std::ffi::{c_void, CStr, CString};
use std::mem;
pub(crate) const LIBCUDA_DEFAULT_PATH: &str = "/usr/lib/x86_64-linux-gnu/libcuda.so.1";
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
let libcuda_path = CString::new(libcuda_path).unwrap();
libc::dlopen(
libcuda_path.as_ptr() as *const _,
libc::RTLD_LOCAL | libc::RTLD_NOW,
)
}
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
libc::dlsym(handle, func.as_ptr() as *const _)
}
#[macro_export]
macro_rules! os_log {
($format:tt) => {
{
eprintln!("[ZLUDA_DUMP] {}", format!($format));
}
};
($format:tt, $($obj: expr),+) => {
{
eprintln!("[ZLUDA_DUMP] {}", format!($format, $($obj,)+));
}
};
}
//RDI, RSI, RDX, RCX, R8, R9
#[cfg(target_arch = "x86_64")]
pub fn get_thunk(
original_fn: *const c_void,
report_fn: unsafe extern "system" fn(&CUuuid, usize),
guid: *const CUuuid,
idx: usize,
) -> *const c_void {
use dynasmrt::{dynasm, DynasmApi};
let mut ops = dynasmrt::x64::Assembler::new().unwrap();
let start = ops.offset();
dynasm!(ops
// stack alignment
; sub rsp, 8
; push rdi
; push rsi
; push rdx
; push rcx
; push r8
; push r9
; mov rdi, QWORD guid as i64
; mov rsi, QWORD idx as i64
; mov rax, QWORD report_fn as i64
; call rax
; pop r9
; pop r8
; pop rcx
; pop rdx
; pop rsi
; pop rdi
; add rsp, 8
; mov rax, QWORD original_fn as i64
; jmp rax
; int 3
);
let exe_buf = ops.finalize().unwrap();
let result_fn = exe_buf.ptr(start);
mem::forget(exe_buf);
result_fn as *const _
}
#[link(name = "pthread")]
unsafe extern "C" {
fn pthread_self() -> std::os::unix::thread::RawPthread;
}
pub(crate) fn current_thread() -> u32 {
(unsafe { pthread_self() }) as u32
}

View File

@ -1,190 +1,190 @@
use std::{
ffi::{c_void, CStr},
mem, ptr,
sync::LazyLock,
};
use std::os::windows::io::AsRawHandle;
use winapi::{
shared::minwindef::{FARPROC, HMODULE},
um::debugapi::OutputDebugStringA,
um::libloaderapi::{GetProcAddress, LoadLibraryW},
};
use cuda_types::cuda::CUuuid;
pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll";
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0";
static PLATFORM_LIBRARY: LazyLock<PlatformLibrary> =
LazyLock::new(|| unsafe { PlatformLibrary::new() });
#[allow(non_snake_case)]
struct PlatformLibrary {
LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE,
GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC,
}
impl PlatformLibrary {
#[allow(non_snake_case)]
unsafe fn new() -> Self {
let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() {
None => (
LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE,
mem::transmute(
GetProcAddress
as unsafe extern "system" fn(
hModule: HMODULE,
lpProcName: *const i8,
) -> FARPROC,
),
),
Some(zluda_with) => (
mem::transmute(GetProcAddress(
zluda_with,
LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _,
)),
mem::transmute(GetProcAddress(
zluda_with,
GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _,
)),
),
};
PlatformLibrary {
LoadLibraryW,
GetProcAddress,
}
}
unsafe fn get_detourer_module() -> Option<HMODULE> {
let mut module = ptr::null_mut();
loop {
module = detours_sys::DetourEnumerateModules(module);
if module == ptr::null_mut() {
break;
}
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
if payload != ptr::null_mut() {
return Some(module as _);
}
}
None
}
}
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
let libcuda_path_uf16 = libcuda_path
.encode_utf16()
.chain(std::iter::once(0))
.collect::<Vec<_>>();
(PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _
}
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
(PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _
}
#[macro_export]
macro_rules! os_log {
($format:tt) => {
{
use crate::os::__log_impl;
__log_impl(format!($format));
}
};
($format:tt, $($obj: expr),+) => {
{
use crate::os::__log_impl;
__log_impl(format!($format, $($obj,)+));
}
};
}
pub fn __log_impl(s: String) {
let log_to_stderr = std::io::stderr().as_raw_handle() != ptr::null_mut();
if log_to_stderr {
eprintln!("[ZLUDA_DUMP] {}", s);
} else {
let mut win_str = String::with_capacity("[ZLUDA_DUMP] ".len() + s.len() + 2);
win_str.push_str("[ZLUDA_DUMP] ");
win_str.push_str(&s);
win_str.push_str("\n\0");
unsafe { OutputDebugStringA(win_str.as_ptr() as *const _) };
}
}
#[cfg(target_arch = "x86")]
pub fn get_thunk(
original_fn: *const c_void,
report_fn: unsafe extern "system" fn(&CUuuid, usize),
guid: *const CUuuid,
idx: usize,
) -> *const c_void {
use dynasmrt::{dynasm, DynasmApi};
let mut ops = dynasmrt::x86::Assembler::new().unwrap();
let start = ops.offset();
dynasm!(ops
; .arch x86
; push idx as i32
; push guid as i32
; mov eax, report_fn as i32
; call eax
; mov eax, original_fn as i32
; jmp eax
; int 3
);
let exe_buf = ops.finalize().unwrap();
let result_fn = exe_buf.ptr(start);
mem::forget(exe_buf);
result_fn as *const _
}
//RCX, RDX, R8, R9
#[cfg(target_arch = "x86_64")]
pub fn get_thunk(
original_fn: *const c_void,
report_fn: unsafe extern "system" fn(&CUuuid, usize),
guid: *const CUuuid,
idx: usize,
) -> *const c_void {
use dynasmrt::{dynasm, DynasmApi};
let mut ops = dynasmrt::x86::Assembler::new().unwrap();
let start = ops.offset();
// Let's hope there's never more than 4 arguments
dynasm!(ops
; .arch x64
; push rbp
; mov rbp, rsp
; push rcx
; push rdx
; push r8
; push r9
; mov rcx, QWORD guid as i64
; mov rdx, QWORD idx as i64
; mov rax, QWORD report_fn as i64
; call rax
; pop r9
; pop r8
; pop rdx
; pop rcx
; mov rax, QWORD original_fn as i64
; call rax
; pop rbp
; ret
; int 3
);
let exe_buf = ops.finalize().unwrap();
let result_fn = exe_buf.ptr(start);
mem::forget(exe_buf);
result_fn as *const _
}
#[link(name = "kernel32")]
unsafe extern "system" {
fn GetCurrentThreadId() -> u32;
}
pub(crate) fn current_thread() -> u32 {
unsafe { GetCurrentThreadId() }
}
use std::{
ffi::{c_void, CStr},
mem, ptr,
sync::LazyLock,
};
use std::os::windows::io::AsRawHandle;
use winapi::{
shared::minwindef::{FARPROC, HMODULE},
um::debugapi::OutputDebugStringA,
um::libloaderapi::{GetProcAddress, LoadLibraryW},
};
use cuda_types::cuda::CUuuid;
pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll";
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0";
static PLATFORM_LIBRARY: LazyLock<PlatformLibrary> =
LazyLock::new(|| unsafe { PlatformLibrary::new() });
#[allow(non_snake_case)]
struct PlatformLibrary {
LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE,
GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC,
}
impl PlatformLibrary {
#[allow(non_snake_case)]
unsafe fn new() -> Self {
let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() {
None => (
LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE,
mem::transmute(
GetProcAddress
as unsafe extern "system" fn(
hModule: HMODULE,
lpProcName: *const i8,
) -> FARPROC,
),
),
Some(zluda_with) => (
mem::transmute(GetProcAddress(
zluda_with,
LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _,
)),
mem::transmute(GetProcAddress(
zluda_with,
GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _,
)),
),
};
PlatformLibrary {
LoadLibraryW,
GetProcAddress,
}
}
unsafe fn get_detourer_module() -> Option<HMODULE> {
let mut module = ptr::null_mut();
loop {
module = detours_sys::DetourEnumerateModules(module);
if module == ptr::null_mut() {
break;
}
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
if payload != ptr::null_mut() {
return Some(module as _);
}
}
None
}
}
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
let libcuda_path_uf16 = libcuda_path
.encode_utf16()
.chain(std::iter::once(0))
.collect::<Vec<_>>();
(PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _
}
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
(PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _
}
#[macro_export]
macro_rules! os_log {
($format:tt) => {
{
use crate::os::__log_impl;
__log_impl(format!($format));
}
};
($format:tt, $($obj: expr),+) => {
{
use crate::os::__log_impl;
__log_impl(format!($format, $($obj,)+));
}
};
}
pub fn __log_impl(s: String) {
let log_to_stderr = std::io::stderr().as_raw_handle() != ptr::null_mut();
if log_to_stderr {
eprintln!("[ZLUDA_DUMP] {}", s);
} else {
let mut win_str = String::with_capacity("[ZLUDA_DUMP] ".len() + s.len() + 2);
win_str.push_str("[ZLUDA_DUMP] ");
win_str.push_str(&s);
win_str.push_str("\n\0");
unsafe { OutputDebugStringA(win_str.as_ptr() as *const _) };
}
}
#[cfg(target_arch = "x86")]
pub fn get_thunk(
original_fn: *const c_void,
report_fn: unsafe extern "system" fn(&CUuuid, usize),
guid: *const CUuuid,
idx: usize,
) -> *const c_void {
use dynasmrt::{dynasm, DynasmApi};
let mut ops = dynasmrt::x86::Assembler::new().unwrap();
let start = ops.offset();
dynasm!(ops
; .arch x86
; push idx as i32
; push guid as i32
; mov eax, report_fn as i32
; call eax
; mov eax, original_fn as i32
; jmp eax
; int 3
);
let exe_buf = ops.finalize().unwrap();
let result_fn = exe_buf.ptr(start);
mem::forget(exe_buf);
result_fn as *const _
}
//RCX, RDX, R8, R9
#[cfg(target_arch = "x86_64")]
pub fn get_thunk(
original_fn: *const c_void,
report_fn: unsafe extern "system" fn(&CUuuid, usize),
guid: *const CUuuid,
idx: usize,
) -> *const c_void {
use dynasmrt::{dynasm, DynasmApi};
let mut ops = dynasmrt::x86::Assembler::new().unwrap();
let start = ops.offset();
// Let's hope there's never more than 4 arguments
dynasm!(ops
; .arch x64
; push rbp
; mov rbp, rsp
; push rcx
; push rdx
; push r8
; push r9
; mov rcx, QWORD guid as i64
; mov rdx, QWORD idx as i64
; mov rax, QWORD report_fn as i64
; call rax
; pop r9
; pop r8
; pop rdx
; pop rcx
; mov rax, QWORD original_fn as i64
; call rax
; pop rbp
; ret
; int 3
);
let exe_buf = ops.finalize().unwrap();
let result_fn = exe_buf.ptr(start);
mem::forget(exe_buf);
result_fn as *const _
}
#[link(name = "kernel32")]
unsafe extern "system" {
fn GetCurrentThreadId() -> u32;
}
pub(crate) fn current_thread() -> u32 {
unsafe { GetCurrentThreadId() }
}

View File

@ -1,334 +1,334 @@
use crate::{
log::{self, UInt},
trace, ErrorEntry, FnCallLog, Settings,
};
use cuda_types::{
cuda::*,
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper},
};
use dark_api::fatbin::{
decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule,
};
use rustc_hash::{FxHashMap, FxHashSet};
use std::{
borrow::Cow,
ffi::{c_void, CStr, CString},
fs::{self, File},
io::{self, Read, Write},
path::PathBuf,
};
use unwrap_or::unwrap_some_or;
// This struct is the heart of CUDA state tracking, it:
// * receives calls from the probes about changes to CUDA state
// * records updates to the state change
// * writes out relevant state change and details to disk and log
pub(crate) struct StateTracker {
writer: DumpWriter,
pub(crate) libraries: FxHashMap<CUlibrary, CodePointer>,
saved_modules: FxHashSet<CUmodule>,
module_counter: usize,
submodule_counter: usize,
pub(crate) override_cc: Option<(u32, u32)>,
}
#[derive(Clone, Copy)]
pub(crate) struct CodePointer(pub *const c_void);
unsafe impl Send for CodePointer {}
unsafe impl Sync for CodePointer {}
impl StateTracker {
pub(crate) fn new(settings: &Settings) -> Self {
StateTracker {
writer: DumpWriter::new(settings.dump_dir.clone()),
libraries: FxHashMap::default(),
saved_modules: FxHashSet::default(),
module_counter: 0,
submodule_counter: 0,
override_cc: settings.override_cc,
}
}
pub(crate) fn record_new_module_file(
&mut self,
module: CUmodule,
file_name: *const i8,
fn_logger: &mut FnCallLog,
) {
let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() {
Ok(f) => f,
Err(err) => {
fn_logger.log(log::ErrorEntry::MalformedModulePath(err));
return;
}
};
let maybe_io_error = self.try_record_new_module_file(module, fn_logger, file_name);
fn_logger.log_io_error(maybe_io_error)
}
fn try_record_new_module_file(
&mut self,
module: CUmodule,
fn_logger: &mut FnCallLog,
file_name: &str,
) -> io::Result<()> {
let mut module_file = fs::File::open(file_name)?;
let mut read_buff = Vec::new();
module_file.read_to_end(&mut read_buff)?;
self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger);
Ok(())
}
pub(crate) fn record_new_submodule(
&mut self,
module: CUmodule,
submodule: &[u8],
fn_logger: &mut FnCallLog,
type_: &'static str,
) {
if self.saved_modules.insert(module) {
self.module_counter += 1;
self.submodule_counter = 0;
}
self.submodule_counter += 1;
fn_logger.log_io_error(self.writer.save_module(
self.module_counter,
Some(self.submodule_counter),
submodule,
type_,
));
if type_ == "ptx" {
match CString::new(submodule) {
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
Ok(submodule_cstring) => match submodule_cstring.to_str() {
Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)),
Ok(submodule_text) => self.try_parse_and_record_kernels(
fn_logger,
self.module_counter,
Some(self.submodule_counter),
submodule_text,
),
},
}
}
}
pub(crate) fn record_new_module(
&mut self,
module: CUmodule,
raw_image: *const c_void,
fn_logger: &mut FnCallLog,
) {
self.module_counter += 1;
if unsafe { *(raw_image as *const [u8; 4]) } == *goblin::elf64::header::ELFMAG {
self.saved_modules.insert(module);
// TODO: Parse ELF and write it to disk
fn_logger.log(log::ErrorEntry::UnsupportedModule {
module,
raw_image,
kind: "ELF",
})
} else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC {
self.saved_modules.insert(module);
// TODO: Figure out how to get size of archive module and write it to disk
fn_logger.log(log::ErrorEntry::UnsupportedModule {
module,
raw_image,
kind: "archive",
})
} else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC {
unsafe {
fn_logger.try_(|fn_logger| {
trace::record_submodules_from_wrapped_fatbin(
module,
raw_image as *const FatbincWrapper,
fn_logger,
self,
)
});
}
} else {
self.record_module_ptx(module, raw_image, fn_logger)
}
}
fn record_module_ptx(
&mut self,
module: CUmodule,
raw_image: *const c_void,
fn_logger: &mut FnCallLog,
) {
self.saved_modules.insert(module);
let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str();
let module_text = match module_text {
Ok(m) => m,
Err(utf8_err) => {
fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err));
return;
}
};
fn_logger.log_io_error(self.writer.save_module(
self.module_counter,
None,
module_text.as_bytes(),
"ptx",
));
self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text);
}
fn try_parse_and_record_kernels(
&mut self,
fn_logger: &mut FnCallLog,
module_index: usize,
submodule_index: Option<usize>,
module_text: &str,
) {
let errors = ptx_parser::parse_for_errors(module_text);
if !errors.is_empty() {
fn_logger.log(log::ErrorEntry::ModuleParsingError(
DumpWriter::get_file_name(module_index, submodule_index, "log"),
));
fn_logger.log_io_error(self.writer.save_module_error_log(
module_index,
submodule_index,
&*errors,
));
}
}
}
// This structs writes out information about CUDA execution to the dump dir
struct DumpWriter {
dump_dir: Option<PathBuf>,
}
impl DumpWriter {
fn new(dump_dir: Option<PathBuf>) -> Self {
Self { dump_dir }
}
fn save_module(
&self,
module_index: usize,
submodule_index: Option<usize>,
buffer: &[u8],
kind: &'static str,
) -> io::Result<()> {
let mut dump_file = match &self.dump_dir {
None => return Ok(()),
Some(d) => d.clone(),
};
dump_file.push(Self::get_file_name(module_index, submodule_index, kind));
let mut file = File::create(dump_file)?;
file.write_all(buffer)?;
Ok(())
}
fn save_module_error_log<'input>(
&self,
module_index: usize,
submodule_index: Option<usize>,
errors: &[ptx_parser::PtxError<'input>],
) -> io::Result<()> {
let mut log_file = match &self.dump_dir {
None => return Ok(()),
Some(d) => d.clone(),
};
log_file.push(Self::get_file_name(module_index, submodule_index, "log"));
let mut file = File::create(log_file)?;
for error in errors {
writeln!(file, "{}", error)?;
}
Ok(())
}
fn get_file_name(module_index: usize, submodule_index: Option<usize>, kind: &str) -> String {
match submodule_index {
None => {
format!("module_{:04}.{:02}", module_index, kind)
}
Some(submodule_index) => {
format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind)
}
}
}
}
pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
module: CUmodule,
fatbinc_wrapper: *const FatbincWrapper,
fn_logger: &mut FnCallLog,
state: &mut StateTracker,
) -> Result<(), ErrorEntry> {
let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
let mut submodules = fatbin.get_submodules()?;
while let Some(current) = submodules.next()? {
record_submodules_from_fatbin(module, current, fn_logger, state)?;
}
Ok(())
}
pub(crate) unsafe fn record_submodules_from_fatbin(
module: CUmodule,
submodule: FatbinSubmodule,
logger: &mut FnCallLog,
state: &mut StateTracker,
) -> Result<(), ErrorEntry> {
record_submodules(module, logger, state, submodule.get_files())?;
Ok(())
}
pub(crate) unsafe fn record_submodules(
module: CUmodule,
fn_logger: &mut FnCallLog,
state: &mut StateTracker,
mut files: FatbinFileIterator,
) -> Result<(), ErrorEntry> {
while let Some(file) = files.next()? {
let mut payload = if file
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedLz4)
{
Cow::Owned(unwrap_some_or!(
fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())),
continue
))
} else if file
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedZstd)
{
Cow::Owned(unwrap_some_or!(
fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())),
continue
))
} else {
Cow::Borrowed(file.get_payload())
};
match file.header.kind {
FatbinFileHeader::HEADER_KIND_PTX => {
while payload.last() == Some(&0) {
// remove trailing zeros
payload.to_mut().pop();
}
state.record_new_submodule(module, &*payload, fn_logger, "ptx")
}
FatbinFileHeader::HEADER_KIND_ELF => {
state.record_new_submodule(module, &*payload, fn_logger, "elf")
}
_ => {
fn_logger.log(log::ErrorEntry::UnexpectedBinaryField {
field_name: "FATBIN_FILE_HEADER_KIND",
expected: vec![
UInt::U16(FatbinFileHeader::HEADER_KIND_PTX),
UInt::U16(FatbinFileHeader::HEADER_KIND_ELF),
],
observed: UInt::U16(file.header.kind),
});
}
}
}
Ok(())
}
use crate::{
log::{self, UInt},
trace, ErrorEntry, FnCallLog, Settings,
};
use cuda_types::{
cuda::*,
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper},
};
use dark_api::fatbin::{
decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule,
};
use rustc_hash::{FxHashMap, FxHashSet};
use std::{
borrow::Cow,
ffi::{c_void, CStr, CString},
fs::{self, File},
io::{self, Read, Write},
path::PathBuf,
};
use unwrap_or::unwrap_some_or;
// This struct is the heart of CUDA state tracking, it:
// * receives calls from the probes about changes to CUDA state
// * records updates to the state change
// * writes out relevant state change and details to disk and log
pub(crate) struct StateTracker {
writer: DumpWriter,
pub(crate) libraries: FxHashMap<CUlibrary, CodePointer>,
saved_modules: FxHashSet<CUmodule>,
module_counter: usize,
submodule_counter: usize,
pub(crate) override_cc: Option<(u32, u32)>,
}
#[derive(Clone, Copy)]
pub(crate) struct CodePointer(pub *const c_void);
unsafe impl Send for CodePointer {}
unsafe impl Sync for CodePointer {}
impl StateTracker {
pub(crate) fn new(settings: &Settings) -> Self {
StateTracker {
writer: DumpWriter::new(settings.dump_dir.clone()),
libraries: FxHashMap::default(),
saved_modules: FxHashSet::default(),
module_counter: 0,
submodule_counter: 0,
override_cc: settings.override_cc,
}
}
pub(crate) fn record_new_module_file(
&mut self,
module: CUmodule,
file_name: *const i8,
fn_logger: &mut FnCallLog,
) {
let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() {
Ok(f) => f,
Err(err) => {
fn_logger.log(log::ErrorEntry::MalformedModulePath(err));
return;
}
};
let maybe_io_error = self.try_record_new_module_file(module, fn_logger, file_name);
fn_logger.log_io_error(maybe_io_error)
}
fn try_record_new_module_file(
&mut self,
module: CUmodule,
fn_logger: &mut FnCallLog,
file_name: &str,
) -> io::Result<()> {
let mut module_file = fs::File::open(file_name)?;
let mut read_buff = Vec::new();
module_file.read_to_end(&mut read_buff)?;
self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger);
Ok(())
}
pub(crate) fn record_new_submodule(
&mut self,
module: CUmodule,
submodule: &[u8],
fn_logger: &mut FnCallLog,
type_: &'static str,
) {
if self.saved_modules.insert(module) {
self.module_counter += 1;
self.submodule_counter = 0;
}
self.submodule_counter += 1;
fn_logger.log_io_error(self.writer.save_module(
self.module_counter,
Some(self.submodule_counter),
submodule,
type_,
));
if type_ == "ptx" {
match CString::new(submodule) {
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
Ok(submodule_cstring) => match submodule_cstring.to_str() {
Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)),
Ok(submodule_text) => self.try_parse_and_record_kernels(
fn_logger,
self.module_counter,
Some(self.submodule_counter),
submodule_text,
),
},
}
}
}
pub(crate) fn record_new_module(
&mut self,
module: CUmodule,
raw_image: *const c_void,
fn_logger: &mut FnCallLog,
) {
self.module_counter += 1;
if unsafe { *(raw_image as *const [u8; 4]) } == *goblin::elf64::header::ELFMAG {
self.saved_modules.insert(module);
// TODO: Parse ELF and write it to disk
fn_logger.log(log::ErrorEntry::UnsupportedModule {
module,
raw_image,
kind: "ELF",
})
} else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC {
self.saved_modules.insert(module);
// TODO: Figure out how to get size of archive module and write it to disk
fn_logger.log(log::ErrorEntry::UnsupportedModule {
module,
raw_image,
kind: "archive",
})
} else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC {
unsafe {
fn_logger.try_(|fn_logger| {
trace::record_submodules_from_wrapped_fatbin(
module,
raw_image as *const FatbincWrapper,
fn_logger,
self,
)
});
}
} else {
self.record_module_ptx(module, raw_image, fn_logger)
}
}
fn record_module_ptx(
&mut self,
module: CUmodule,
raw_image: *const c_void,
fn_logger: &mut FnCallLog,
) {
self.saved_modules.insert(module);
let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str();
let module_text = match module_text {
Ok(m) => m,
Err(utf8_err) => {
fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err));
return;
}
};
fn_logger.log_io_error(self.writer.save_module(
self.module_counter,
None,
module_text.as_bytes(),
"ptx",
));
self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text);
}
fn try_parse_and_record_kernels(
&mut self,
fn_logger: &mut FnCallLog,
module_index: usize,
submodule_index: Option<usize>,
module_text: &str,
) {
let errors = ptx_parser::parse_for_errors(module_text);
if !errors.is_empty() {
fn_logger.log(log::ErrorEntry::ModuleParsingError(
DumpWriter::get_file_name(module_index, submodule_index, "log"),
));
fn_logger.log_io_error(self.writer.save_module_error_log(
module_index,
submodule_index,
&*errors,
));
}
}
}
// This structs writes out information about CUDA execution to the dump dir
struct DumpWriter {
dump_dir: Option<PathBuf>,
}
impl DumpWriter {
fn new(dump_dir: Option<PathBuf>) -> Self {
Self { dump_dir }
}
fn save_module(
&self,
module_index: usize,
submodule_index: Option<usize>,
buffer: &[u8],
kind: &'static str,
) -> io::Result<()> {
let mut dump_file = match &self.dump_dir {
None => return Ok(()),
Some(d) => d.clone(),
};
dump_file.push(Self::get_file_name(module_index, submodule_index, kind));
let mut file = File::create(dump_file)?;
file.write_all(buffer)?;
Ok(())
}
fn save_module_error_log<'input>(
&self,
module_index: usize,
submodule_index: Option<usize>,
errors: &[ptx_parser::PtxError<'input>],
) -> io::Result<()> {
let mut log_file = match &self.dump_dir {
None => return Ok(()),
Some(d) => d.clone(),
};
log_file.push(Self::get_file_name(module_index, submodule_index, "log"));
let mut file = File::create(log_file)?;
for error in errors {
writeln!(file, "{}", error)?;
}
Ok(())
}
fn get_file_name(module_index: usize, submodule_index: Option<usize>, kind: &str) -> String {
match submodule_index {
None => {
format!("module_{:04}.{:02}", module_index, kind)
}
Some(submodule_index) => {
format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind)
}
}
}
}
pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
module: CUmodule,
fatbinc_wrapper: *const FatbincWrapper,
fn_logger: &mut FnCallLog,
state: &mut StateTracker,
) -> Result<(), ErrorEntry> {
let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
let mut submodules = fatbin.get_submodules()?;
while let Some(current) = submodules.next()? {
record_submodules_from_fatbin(module, current, fn_logger, state)?;
}
Ok(())
}
pub(crate) unsafe fn record_submodules_from_fatbin(
module: CUmodule,
submodule: FatbinSubmodule,
logger: &mut FnCallLog,
state: &mut StateTracker,
) -> Result<(), ErrorEntry> {
record_submodules(module, logger, state, submodule.get_files())?;
Ok(())
}
pub(crate) unsafe fn record_submodules(
module: CUmodule,
fn_logger: &mut FnCallLog,
state: &mut StateTracker,
mut files: FatbinFileIterator,
) -> Result<(), ErrorEntry> {
while let Some(file) = files.next()? {
let mut payload = if file
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedLz4)
{
Cow::Owned(unwrap_some_or!(
fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())),
continue
))
} else if file
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedZstd)
{
Cow::Owned(unwrap_some_or!(
fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())),
continue
))
} else {
Cow::Borrowed(file.get_payload())
};
match file.header.kind {
FatbinFileHeader::HEADER_KIND_PTX => {
while payload.last() == Some(&0) {
// remove trailing zeros
payload.to_mut().pop();
}
state.record_new_submodule(module, &*payload, fn_logger, "ptx")
}
FatbinFileHeader::HEADER_KIND_ELF => {
state.record_new_submodule(module, &*payload, fn_logger, "elf")
}
_ => {
fn_logger.log(log::ErrorEntry::UnexpectedBinaryField {
field_name: "FATBIN_FILE_HEADER_KIND",
expected: vec![
UInt::U16(FatbinFileHeader::HEADER_KIND_PTX),
UInt::U16(FatbinFileHeader::HEADER_KIND_ELF),
],
observed: UInt::U16(file.header.kind),
});
}
}
}
Ok(())
}

View File

@ -1,81 +1,81 @@
use std::{
env::{self, VarError},
fs::{self, DirEntry},
io,
path::{self, PathBuf},
process::Command,
};
fn main() -> Result<(), VarError> {
if std::env::var_os("CARGO_CFG_WINDOWS").is_none() {
return Ok(());
}
println!("cargo:rerun-if-changed=build.rs");
if env::var("PROFILE")? != "debug" {
return Ok(());
}
let rustc_exe = env::var("RUSTC")?;
let out_dir = env::var("OUT_DIR")?;
let target = env::var("TARGET")?;
let is_msvc = env::var("CARGO_CFG_TARGET_ENV")? == "msvc";
let opt_level = env::var("OPT_LEVEL")?;
let debug = str::parse::<bool>(env::var("DEBUG")?.as_str()).unwrap();
let mut helpers_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
helpers_dir.push("tests");
helpers_dir.push("helpers");
let helpers_dir_as_string = helpers_dir.to_string_lossy();
println!("cargo:rerun-if-changed={}", helpers_dir_as_string);
for rust_file in fs::read_dir(&helpers_dir).unwrap().filter_map(rust_file) {
let full_file_path = format!(
"{}{}{}",
helpers_dir_as_string,
path::MAIN_SEPARATOR,
rust_file
);
let mut rustc_cmd = Command::new(&*rustc_exe);
if debug {
rustc_cmd.arg("-g");
}
rustc_cmd.arg(format!("-Lnative={}", helpers_dir_as_string));
if !is_msvc {
// HACK ALERT
// I have no idea why the extra library below have to be linked
rustc_cmd.arg(r"-lucrt");
}
rustc_cmd
.arg("-C")
.arg(format!("opt-level={}", opt_level))
.arg("-L")
.arg(format!("{}", out_dir))
.arg("--out-dir")
.arg(format!("{}", out_dir))
.arg("--target")
.arg(format!("{}", target))
.arg(full_file_path);
assert!(rustc_cmd.status().unwrap().success());
}
std::fs::copy(
format!(
"{}{}do_cuinit_late_clr.exe",
helpers_dir_as_string,
path::MAIN_SEPARATOR
),
format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR),
)
.unwrap();
println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir);
Ok(())
}
fn rust_file(entry: io::Result<DirEntry>) -> Option<String> {
entry.ok().and_then(|e| {
let os_file_name = e.file_name();
let file_name = os_file_name.to_string_lossy();
let is_file = e.file_type().ok().map(|t| t.is_file()).unwrap_or(false);
if is_file && file_name.ends_with(".rs") {
Some(file_name.to_string())
} else {
None
}
})
}
use std::{
env::{self, VarError},
fs::{self, DirEntry},
io,
path::{self, PathBuf},
process::Command,
};
fn main() -> Result<(), VarError> {
if std::env::var_os("CARGO_CFG_WINDOWS").is_none() {
return Ok(());
}
println!("cargo:rerun-if-changed=build.rs");
if env::var("PROFILE")? != "debug" {
return Ok(());
}
let rustc_exe = env::var("RUSTC")?;
let out_dir = env::var("OUT_DIR")?;
let target = env::var("TARGET")?;
let is_msvc = env::var("CARGO_CFG_TARGET_ENV")? == "msvc";
let opt_level = env::var("OPT_LEVEL")?;
let debug = str::parse::<bool>(env::var("DEBUG")?.as_str()).unwrap();
let mut helpers_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
helpers_dir.push("tests");
helpers_dir.push("helpers");
let helpers_dir_as_string = helpers_dir.to_string_lossy();
println!("cargo:rerun-if-changed={}", helpers_dir_as_string);
for rust_file in fs::read_dir(&helpers_dir).unwrap().filter_map(rust_file) {
let full_file_path = format!(
"{}{}{}",
helpers_dir_as_string,
path::MAIN_SEPARATOR,
rust_file
);
let mut rustc_cmd = Command::new(&*rustc_exe);
if debug {
rustc_cmd.arg("-g");
}
rustc_cmd.arg(format!("-Lnative={}", helpers_dir_as_string));
if !is_msvc {
// HACK ALERT
// I have no idea why the extra library below have to be linked
rustc_cmd.arg(r"-lucrt");
}
rustc_cmd
.arg("-C")
.arg(format!("opt-level={}", opt_level))
.arg("-L")
.arg(format!("{}", out_dir))
.arg("--out-dir")
.arg(format!("{}", out_dir))
.arg("--target")
.arg(format!("{}", target))
.arg(full_file_path);
assert!(rustc_cmd.status().unwrap().success());
}
std::fs::copy(
format!(
"{}{}do_cuinit_late_clr.exe",
helpers_dir_as_string,
path::MAIN_SEPARATOR
),
format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR),
)
.unwrap();
println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir);
Ok(())
}
fn rust_file(entry: io::Result<DirEntry>) -> Option<String> {
entry.ok().and_then(|e| {
let os_file_name = e.file_name();
let file_name = os_file_name.to_string_lossy();
let is_file = e.file_type().ok().map(|t| t.is_file()).unwrap_or(false);
if is_file && file_name.ends_with(".rs") {
Some(file_name.to_string())
} else {
None
}
})
}

View File

@ -1,311 +1,311 @@
use std::env;
use std::os::windows;
use std::os::windows::ffi::OsStrExt;
use std::{error::Error, process};
use std::{fs, io, ptr};
use std::{mem, path::PathBuf};
use argh::FromArgs;
use mem::size_of_val;
use tempfile::TempDir;
use winapi::um::processenv::SearchPathW;
use winapi::um::{
jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
processthreadsapi::{GetExitCodeProcess, ResumeThread},
synchapi::WaitForSingleObject,
winbase::CreateJobObjectA,
winnt::{
JobObjectExtendedLimitInformation, HANDLE, JOBOBJECT_EXTENDED_LIMIT_INFORMATION,
JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
},
};
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
static NVCUDA_DLL: &'static str = "nvcuda.dll";
static NVML_DLL: &'static str = "nvml.dll";
include!("../../zluda_redirect/src/payload_guid.rs");
#[derive(FromArgs)]
/// Launch application with custom CUDA libraries
struct ProgramArguments {
/// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory
#[argh(option)]
nvcuda: Option<PathBuf>,
/// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory
#[argh(option)]
nvml: Option<PathBuf>,
/// executable to be injected with custom CUDA libraries
#[argh(positional)]
exe: String,
/// arguments to the executable
#[argh(positional)]
args: Vec<String>,
}
pub fn main_impl() -> Result<(), Box<dyn Error>> {
let raw_args = argh::from_env::<ProgramArguments>();
let normalized_args = NormalizedArguments::new(raw_args)?;
let mut environment = Environment::setup(normalized_args)?;
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
let mut dlls_to_inject = [
environment.nvml_path_zero_terminated.as_ptr() as *const i8,
environment.nvcuda_path_zero_terminated.as_ptr() as _,
environment.redirect_path_zero_terminated.as_ptr() as _,
];
os_call!(
detours_sys::DetourCreateProcessWithDllsW(
ptr::null(),
environment.winapi_command_line_zero_terminated.as_mut_ptr(),
ptr::null_mut(),
ptr::null_mut(),
0,
0,
ptr::null_mut(),
ptr::null(),
&mut startup_info as *mut _,
&mut proc_info as *mut _,
dlls_to_inject.len() as u32,
dlls_to_inject.as_mut_ptr(),
Option::None
),
|x| x != 0
);
kill_child_on_process_exit(proc_info.hProcess)?;
os_call!(
detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess,
&PAYLOAD_NVCUDA_GUID,
environment.nvcuda_path_zero_terminated.as_ptr() as *mut _,
environment.nvcuda_path_zero_terminated.len() as u32
),
|x| x != 0
);
os_call!(
detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess,
&PAYLOAD_NVML_GUID,
environment.nvml_path_zero_terminated.as_ptr() as *mut _,
environment.nvml_path_zero_terminated.len() as u32
),
|x| x != 0
);
os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1);
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x
!= WAIT_FAILED);
let mut child_exit_code: u32 = 0;
os_call!(
GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _),
|x| x != 0
);
process::exit(child_exit_code as i32)
}
struct NormalizedArguments {
nvml_path: PathBuf,
nvcuda_path: PathBuf,
redirect_path: PathBuf,
winapi_command_line_zero_terminated: Vec<u16>,
}
impl NormalizedArguments {
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
let current_exe = env::current_exe()?;
let nvml_path = Self::get_absolute_path(&current_exe, prog_args.nvml, NVML_DLL)?;
let nvcuda_path = Self::get_absolute_path(&current_exe, prog_args.nvcuda, NVCUDA_DLL)?;
let winapi_command_line_zero_terminated =
construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args));
let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
redirect_path.push(REDIRECT_DLL);
Ok(Self {
nvml_path,
nvcuda_path,
redirect_path,
winapi_command_line_zero_terminated,
})
}
const WIN_MAX_PATH: usize = 260;
fn get_absolute_path(
current_exe: &PathBuf,
dll: Option<PathBuf>,
default: &str,
) -> Result<PathBuf, Box<dyn Error>> {
Ok(if let Some(dll) = dll {
if dll.is_absolute() {
dll
} else {
let mut full_dll_path = vec![0; Self::WIN_MAX_PATH];
let mut dll_utf16 = dll.as_os_str().encode_wide().collect::<Vec<_>>();
dll_utf16.push(0);
loop {
let copied_len = os_call!(
SearchPathW(
ptr::null_mut(),
dll_utf16.as_ptr(),
ptr::null(),
full_dll_path.len() as u32,
full_dll_path.as_mut_ptr(),
ptr::null_mut()
),
|x| x != 0
) as usize;
if copied_len > full_dll_path.len() {
full_dll_path.resize(copied_len + 1, 0);
} else {
full_dll_path.truncate(copied_len);
break;
}
}
PathBuf::from(String::from_utf16_lossy(&full_dll_path))
}
} else {
let mut dll_path = current_exe.parent().unwrap().to_path_buf();
dll_path.push(default);
dll_path
})
}
}
struct Environment {
nvml_path_zero_terminated: String,
nvcuda_path_zero_terminated: String,
redirect_path_zero_terminated: String,
winapi_command_line_zero_terminated: Vec<u16>,
_temp_dir: TempDir,
}
// This structs represents "enviroment". By environment we mean all paths
// (nvcuda.dll, nvml.dll, etc.) and all related resources like the temporary
// directory which contains nvcuda.dll
impl Environment {
fn setup(args: NormalizedArguments) -> io::Result<Self> {
let _temp_dir = TempDir::new()?;
let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
args.nvml_path,
&_temp_dir,
NVML_DLL,
)?);
let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
args.nvcuda_path,
&_temp_dir,
NVCUDA_DLL,
)?);
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
Ok(Self {
nvml_path_zero_terminated,
nvcuda_path_zero_terminated,
redirect_path_zero_terminated,
winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated,
_temp_dir,
})
}
fn copy_to_correct_name(
path_buf: PathBuf,
temp_dir: &TempDir,
correct_name: &str,
) -> io::Result<PathBuf> {
let file_name = path_buf.file_name().unwrap();
if file_name == correct_name {
Ok(path_buf)
} else {
let mut temp_file_path = temp_dir.path().to_path_buf();
temp_file_path.push(correct_name);
match windows::fs::symlink_file(&path_buf, &temp_file_path) {
Ok(()) => {}
Err(_) => {
fs::copy(&path_buf, &temp_file_path)?;
}
}
Ok(temp_file_path)
}
}
fn zero_terminate(p: PathBuf) -> String {
let mut s = p.to_string_lossy().to_string();
s.push('\0');
s
}
}
fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
!= ptr::null_mut());
let mut info = unsafe { mem::zeroed::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() };
info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
os_call!(
SetInformationJobObject(
job_handle,
JobObjectExtendedLimitInformation,
&mut info as *mut _ as *mut _,
size_of_val(&info) as u32
),
|x| x != 0
);
os_call!(AssignProcessToJobObject(job_handle, child), |x| x != 0);
Ok(())
}
// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
fn construct_command_line(args: impl Iterator<Item = String>) -> Vec<u16> {
let mut cmd_line = Vec::new();
let args_len = args.size_hint().0;
for (idx, arg) in args.enumerate() {
if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) {
cmd_line.extend(arg.encode_utf16());
} else {
cmd_line.push('"' as u16); // "
let mut char_iter = arg.chars().peekable();
loop {
let mut current = char_iter.next();
let mut backslashes = 0;
match current {
Some('\\') => {
backslashes = 1;
while let Some('\\') = char_iter.peek() {
backslashes += 1;
char_iter.next();
}
current = char_iter.next();
}
_ => {}
}
match current {
None => {
for _ in 0..(backslashes * 2) {
cmd_line.push('\\' as u16);
}
break;
}
Some('"') => {
for _ in 0..(backslashes * 2 + 1) {
cmd_line.push('\\' as u16);
}
cmd_line.push('"' as u16);
}
Some(c) => {
for _ in 0..backslashes {
cmd_line.push('\\' as u16);
}
let mut temp = [0u16; 2];
cmd_line.extend(&*c.encode_utf16(&mut temp));
}
}
}
cmd_line.push('"' as u16);
}
if idx < args_len - 1 {
cmd_line.push(' ' as u16);
}
}
cmd_line.push(0);
cmd_line
}
use std::env;
use std::os::windows;
use std::os::windows::ffi::OsStrExt;
use std::{error::Error, process};
use std::{fs, io, ptr};
use std::{mem, path::PathBuf};
use argh::FromArgs;
use mem::size_of_val;
use tempfile::TempDir;
use winapi::um::processenv::SearchPathW;
use winapi::um::{
jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
processthreadsapi::{GetExitCodeProcess, ResumeThread},
synchapi::WaitForSingleObject,
winbase::CreateJobObjectA,
winnt::{
JobObjectExtendedLimitInformation, HANDLE, JOBOBJECT_EXTENDED_LIMIT_INFORMATION,
JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
},
};
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
static NVCUDA_DLL: &'static str = "nvcuda.dll";
static NVML_DLL: &'static str = "nvml.dll";
include!("../../zluda_redirect/src/payload_guid.rs");
#[derive(FromArgs)]
/// Launch application with custom CUDA libraries
struct ProgramArguments {
/// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory
#[argh(option)]
nvcuda: Option<PathBuf>,
/// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory
#[argh(option)]
nvml: Option<PathBuf>,
/// executable to be injected with custom CUDA libraries
#[argh(positional)]
exe: String,
/// arguments to the executable
#[argh(positional)]
args: Vec<String>,
}
pub fn main_impl() -> Result<(), Box<dyn Error>> {
let raw_args = argh::from_env::<ProgramArguments>();
let normalized_args = NormalizedArguments::new(raw_args)?;
let mut environment = Environment::setup(normalized_args)?;
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
let mut dlls_to_inject = [
environment.nvml_path_zero_terminated.as_ptr() as *const i8,
environment.nvcuda_path_zero_terminated.as_ptr() as _,
environment.redirect_path_zero_terminated.as_ptr() as _,
];
os_call!(
detours_sys::DetourCreateProcessWithDllsW(
ptr::null(),
environment.winapi_command_line_zero_terminated.as_mut_ptr(),
ptr::null_mut(),
ptr::null_mut(),
0,
0,
ptr::null_mut(),
ptr::null(),
&mut startup_info as *mut _,
&mut proc_info as *mut _,
dlls_to_inject.len() as u32,
dlls_to_inject.as_mut_ptr(),
Option::None
),
|x| x != 0
);
kill_child_on_process_exit(proc_info.hProcess)?;
os_call!(
detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess,
&PAYLOAD_NVCUDA_GUID,
environment.nvcuda_path_zero_terminated.as_ptr() as *mut _,
environment.nvcuda_path_zero_terminated.len() as u32
),
|x| x != 0
);
os_call!(
detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess,
&PAYLOAD_NVML_GUID,
environment.nvml_path_zero_terminated.as_ptr() as *mut _,
environment.nvml_path_zero_terminated.len() as u32
),
|x| x != 0
);
os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1);
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x
!= WAIT_FAILED);
let mut child_exit_code: u32 = 0;
os_call!(
GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _),
|x| x != 0
);
process::exit(child_exit_code as i32)
}
struct NormalizedArguments {
nvml_path: PathBuf,
nvcuda_path: PathBuf,
redirect_path: PathBuf,
winapi_command_line_zero_terminated: Vec<u16>,
}
impl NormalizedArguments {
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
let current_exe = env::current_exe()?;
let nvml_path = Self::get_absolute_path(&current_exe, prog_args.nvml, NVML_DLL)?;
let nvcuda_path = Self::get_absolute_path(&current_exe, prog_args.nvcuda, NVCUDA_DLL)?;
let winapi_command_line_zero_terminated =
construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args));
let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
redirect_path.push(REDIRECT_DLL);
Ok(Self {
nvml_path,
nvcuda_path,
redirect_path,
winapi_command_line_zero_terminated,
})
}
const WIN_MAX_PATH: usize = 260;
fn get_absolute_path(
current_exe: &PathBuf,
dll: Option<PathBuf>,
default: &str,
) -> Result<PathBuf, Box<dyn Error>> {
Ok(if let Some(dll) = dll {
if dll.is_absolute() {
dll
} else {
let mut full_dll_path = vec![0; Self::WIN_MAX_PATH];
let mut dll_utf16 = dll.as_os_str().encode_wide().collect::<Vec<_>>();
dll_utf16.push(0);
loop {
let copied_len = os_call!(
SearchPathW(
ptr::null_mut(),
dll_utf16.as_ptr(),
ptr::null(),
full_dll_path.len() as u32,
full_dll_path.as_mut_ptr(),
ptr::null_mut()
),
|x| x != 0
) as usize;
if copied_len > full_dll_path.len() {
full_dll_path.resize(copied_len + 1, 0);
} else {
full_dll_path.truncate(copied_len);
break;
}
}
PathBuf::from(String::from_utf16_lossy(&full_dll_path))
}
} else {
let mut dll_path = current_exe.parent().unwrap().to_path_buf();
dll_path.push(default);
dll_path
})
}
}
struct Environment {
nvml_path_zero_terminated: String,
nvcuda_path_zero_terminated: String,
redirect_path_zero_terminated: String,
winapi_command_line_zero_terminated: Vec<u16>,
_temp_dir: TempDir,
}
// This structs represents "enviroment". By environment we mean all paths
// (nvcuda.dll, nvml.dll, etc.) and all related resources like the temporary
// directory which contains nvcuda.dll
impl Environment {
fn setup(args: NormalizedArguments) -> io::Result<Self> {
let _temp_dir = TempDir::new()?;
let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
args.nvml_path,
&_temp_dir,
NVML_DLL,
)?);
let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
args.nvcuda_path,
&_temp_dir,
NVCUDA_DLL,
)?);
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
Ok(Self {
nvml_path_zero_terminated,
nvcuda_path_zero_terminated,
redirect_path_zero_terminated,
winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated,
_temp_dir,
})
}
fn copy_to_correct_name(
path_buf: PathBuf,
temp_dir: &TempDir,
correct_name: &str,
) -> io::Result<PathBuf> {
let file_name = path_buf.file_name().unwrap();
if file_name == correct_name {
Ok(path_buf)
} else {
let mut temp_file_path = temp_dir.path().to_path_buf();
temp_file_path.push(correct_name);
match windows::fs::symlink_file(&path_buf, &temp_file_path) {
Ok(()) => {}
Err(_) => {
fs::copy(&path_buf, &temp_file_path)?;
}
}
Ok(temp_file_path)
}
}
fn zero_terminate(p: PathBuf) -> String {
let mut s = p.to_string_lossy().to_string();
s.push('\0');
s
}
}
fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
!= ptr::null_mut());
let mut info = unsafe { mem::zeroed::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() };
info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
os_call!(
SetInformationJobObject(
job_handle,
JobObjectExtendedLimitInformation,
&mut info as *mut _ as *mut _,
size_of_val(&info) as u32
),
|x| x != 0
);
os_call!(AssignProcessToJobObject(job_handle, child), |x| x != 0);
Ok(())
}
// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
fn construct_command_line(args: impl Iterator<Item = String>) -> Vec<u16> {
let mut cmd_line = Vec::new();
let args_len = args.size_hint().0;
for (idx, arg) in args.enumerate() {
if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) {
cmd_line.extend(arg.encode_utf16());
} else {
cmd_line.push('"' as u16); // "
let mut char_iter = arg.chars().peekable();
loop {
let mut current = char_iter.next();
let mut backslashes = 0;
match current {
Some('\\') => {
backslashes = 1;
while let Some('\\') = char_iter.peek() {
backslashes += 1;
char_iter.next();
}
current = char_iter.next();
}
_ => {}
}
match current {
None => {
for _ in 0..(backslashes * 2) {
cmd_line.push('\\' as u16);
}
break;
}
Some('"') => {
for _ in 0..(backslashes * 2 + 1) {
cmd_line.push('\\' as u16);
}
cmd_line.push('"' as u16);
}
Some(c) => {
for _ in 0..backslashes {
cmd_line.push('\\' as u16);
}
let mut temp = [0u16; 2];
cmd_line.extend(&*c.encode_utf16(&mut temp));
}
}
}
cmd_line.push('"' as u16);
}
if idx < args_len - 1 {
cmd_line.push(' ' as u16);
}
}
cmd_line.push(0);
cmd_line
}

View File

@ -1,13 +1,13 @@
#[macro_use]
#[cfg(target_os = "windows")]
mod win;
#[cfg(target_os = "windows")]
mod bin;
#[cfg(target_os = "windows")]
fn main() -> Result<(), Box<dyn std::error::Error>> {
bin::main_impl()
}
#[cfg(not(target_os = "windows"))]
fn main() {}
#[macro_use]
#[cfg(target_os = "windows")]
mod win;
#[cfg(target_os = "windows")]
mod bin;
#[cfg(target_os = "windows")]
fn main() -> Result<(), Box<dyn std::error::Error>> {
bin::main_impl()
}
#[cfg(not(target_os = "windows"))]
fn main() {}

View File

@ -1,151 +1,151 @@
#![allow(non_snake_case)]
use std::error;
use std::fmt;
use std::ptr;
mod c {
use std::ffi::c_void;
use std::os::raw::c_ulong;
pub type DWORD = c_ulong;
pub type HANDLE = LPVOID;
pub type LPVOID = *mut c_void;
pub type HINSTANCE = HANDLE;
pub type HMODULE = HINSTANCE;
pub type WCHAR = u16;
pub type LPCWSTR = *const WCHAR;
pub type LPWSTR = *mut WCHAR;
pub const FACILITY_NT_BIT: DWORD = 0x1000_0000;
pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x00000800;
pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x00001000;
pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x00000200;
extern "system" {
pub fn GetLastError() -> DWORD;
pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE;
pub fn FormatMessageW(
flags: DWORD,
lpSrc: LPVOID,
msgId: DWORD,
langId: DWORD,
buf: LPWSTR,
nsize: DWORD,
args: *const c_void,
) -> DWORD;
}
}
macro_rules! last_ident {
($i:ident) => {
stringify!($i)
};
($start:ident, $($cont:ident),+) => {
last_ident!($($cont),+)
};
}
macro_rules! os_call {
($($path:ident)::+ ($($args:expr),*), $success:expr) => {
{
let result = unsafe{ $($path)::+ ($($args),*) };
if !($success)(result) {
let name = last_ident!($($path),+);
let err_code = $crate::win::errno();
Err($crate::win::OsError{
function: name,
error_code: err_code as u32,
message: $crate::win::error_string(err_code)
})?;
}
result
}
};
}
#[derive(Debug)]
pub struct OsError {
pub function: &'static str,
pub error_code: u32,
pub message: String,
}
impl fmt::Display for OsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
impl error::Error for OsError {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
None
}
}
pub fn errno() -> i32 {
unsafe { c::GetLastError() as i32 }
}
/// Gets a detailed string description for the given error number.
pub fn error_string(mut errnum: i32) -> String {
// This value is calculated from the macro
// MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT)
let langId = 0x0800 as c::DWORD;
let mut buf = [0 as c::WCHAR; 2048];
unsafe {
let mut module = ptr::null_mut();
let mut flags = 0;
// NTSTATUS errors may be encoded as HRESULT, which may returned from
// GetLastError. For more information about Windows error codes, see
// `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx
if (errnum & c::FACILITY_NT_BIT as i32) != 0 {
// format according to https://support.microsoft.com/en-us/help/259693
const NTDLL_DLL: &[u16] = &[
'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _,
'L' as _, 0,
];
module = c::GetModuleHandleW(NTDLL_DLL.as_ptr());
if module != ptr::null_mut() {
errnum ^= c::FACILITY_NT_BIT as i32;
flags = c::FORMAT_MESSAGE_FROM_HMODULE;
}
}
let res = c::FormatMessageW(
flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS,
module,
errnum as c::DWORD,
langId,
buf.as_mut_ptr(),
buf.len() as c::DWORD,
ptr::null(),
) as usize;
if res == 0 {
// Sometimes FormatMessageW can fail e.g., system doesn't like langId,
let fm_err = errno();
return format!(
"OS Error {} (FormatMessageW() returned error {})",
errnum, fm_err
);
}
match String::from_utf16(&buf[..res]) {
Ok(mut msg) => {
// Trim trailing CRLF inserted by FormatMessageW
let len = msg.trim_end().len();
msg.truncate(len);
msg
}
Err(..) => format!(
"OS Error {} (FormatMessageW() returned \
invalid UTF-16)",
errnum
),
}
}
}
#![allow(non_snake_case)]
use std::error;
use std::fmt;
use std::ptr;
mod c {
use std::ffi::c_void;
use std::os::raw::c_ulong;
pub type DWORD = c_ulong;
pub type HANDLE = LPVOID;
pub type LPVOID = *mut c_void;
pub type HINSTANCE = HANDLE;
pub type HMODULE = HINSTANCE;
pub type WCHAR = u16;
pub type LPCWSTR = *const WCHAR;
pub type LPWSTR = *mut WCHAR;
pub const FACILITY_NT_BIT: DWORD = 0x1000_0000;
pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x00000800;
pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x00001000;
pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x00000200;
extern "system" {
pub fn GetLastError() -> DWORD;
pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE;
pub fn FormatMessageW(
flags: DWORD,
lpSrc: LPVOID,
msgId: DWORD,
langId: DWORD,
buf: LPWSTR,
nsize: DWORD,
args: *const c_void,
) -> DWORD;
}
}
macro_rules! last_ident {
($i:ident) => {
stringify!($i)
};
($start:ident, $($cont:ident),+) => {
last_ident!($($cont),+)
};
}
macro_rules! os_call {
($($path:ident)::+ ($($args:expr),*), $success:expr) => {
{
let result = unsafe{ $($path)::+ ($($args),*) };
if !($success)(result) {
let name = last_ident!($($path),+);
let err_code = $crate::win::errno();
Err($crate::win::OsError{
function: name,
error_code: err_code as u32,
message: $crate::win::error_string(err_code)
})?;
}
result
}
};
}
#[derive(Debug)]
pub struct OsError {
pub function: &'static str,
pub error_code: u32,
pub message: String,
}
impl fmt::Display for OsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
impl error::Error for OsError {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
None
}
}
pub fn errno() -> i32 {
unsafe { c::GetLastError() as i32 }
}
/// Gets a detailed string description for the given error number.
pub fn error_string(mut errnum: i32) -> String {
// This value is calculated from the macro
// MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT)
let langId = 0x0800 as c::DWORD;
let mut buf = [0 as c::WCHAR; 2048];
unsafe {
let mut module = ptr::null_mut();
let mut flags = 0;
// NTSTATUS errors may be encoded as HRESULT, which may returned from
// GetLastError. For more information about Windows error codes, see
// `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx
if (errnum & c::FACILITY_NT_BIT as i32) != 0 {
// format according to https://support.microsoft.com/en-us/help/259693
const NTDLL_DLL: &[u16] = &[
'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _,
'L' as _, 0,
];
module = c::GetModuleHandleW(NTDLL_DLL.as_ptr());
if module != ptr::null_mut() {
errnum ^= c::FACILITY_NT_BIT as i32;
flags = c::FORMAT_MESSAGE_FROM_HMODULE;
}
}
let res = c::FormatMessageW(
flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS,
module,
errnum as c::DWORD,
langId,
buf.as_mut_ptr(),
buf.len() as c::DWORD,
ptr::null(),
) as usize;
if res == 0 {
// Sometimes FormatMessageW can fail e.g., system doesn't like langId,
let fm_err = errno();
return format!(
"OS Error {} (FormatMessageW() returned error {})",
errnum, fm_err
);
}
match String::from_utf16(&buf[..res]) {
Ok(mut msg) => {
// Trim trailing CRLF inserted by FormatMessageW
let len = msg.trim_end().len();
msg.truncate(len);
msg
}
Err(..) => format!(
"OS Error {} (FormatMessageW() returned \
invalid UTF-16)",
errnum
),
}
}
}

View File

@ -1,51 +1,51 @@
#![cfg(windows)]
use std::{env, io, path::PathBuf, process::Command};
#[test]
fn direct_cuinit() -> io::Result<()> {
run_process_and_check_for_zluda_dump("direct_cuinit")
}
#[test]
fn do_cuinit_early() -> io::Result<()> {
run_process_and_check_for_zluda_dump("do_cuinit_early")
}
#[test]
fn do_cuinit_late() -> io::Result<()> {
run_process_and_check_for_zluda_dump("do_cuinit_late")
}
#[test]
fn do_cuinit_late_clr() -> io::Result<()> {
run_process_and_check_for_zluda_dump("do_cuinit_late_clr")
}
#[test]
fn indirect_cuinit() -> io::Result<()> {
run_process_and_check_for_zluda_dump("indirect_cuinit")
}
#[test]
fn subprocess() -> io::Result<()> {
run_process_and_check_for_zluda_dump("subprocess")
}
fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with"));
let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf();
zluda_dump_dll.push("zluda_dump.dll");
let helpers_dir = env!("HELPERS_OUT_DIR");
let exe_under_test = format!("{}{}{}.exe", helpers_dir, std::path::MAIN_SEPARATOR, name);
let mut test_cmd = Command::new(&zluda_with_exe);
let test_cmd = test_cmd
.arg("--nvcuda")
.arg(&zluda_dump_dll)
.arg("--")
.arg(&exe_under_test);
let test_output = test_cmd.output()?;
assert!(test_output.status.success());
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
assert!(stderr_text.contains("ZLUDA_DUMP"));
Ok(())
}
#![cfg(windows)]
use std::{env, io, path::PathBuf, process::Command};
#[test]
fn direct_cuinit() -> io::Result<()> {
run_process_and_check_for_zluda_dump("direct_cuinit")
}
#[test]
fn do_cuinit_early() -> io::Result<()> {
run_process_and_check_for_zluda_dump("do_cuinit_early")
}
#[test]
fn do_cuinit_late() -> io::Result<()> {
run_process_and_check_for_zluda_dump("do_cuinit_late")
}
#[test]
fn do_cuinit_late_clr() -> io::Result<()> {
run_process_and_check_for_zluda_dump("do_cuinit_late_clr")
}
#[test]
fn indirect_cuinit() -> io::Result<()> {
run_process_and_check_for_zluda_dump("indirect_cuinit")
}
#[test]
fn subprocess() -> io::Result<()> {
run_process_and_check_for_zluda_dump("subprocess")
}
fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with"));
let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf();
zluda_dump_dll.push("zluda_dump.dll");
let helpers_dir = env!("HELPERS_OUT_DIR");
let exe_under_test = format!("{}{}{}.exe", helpers_dir, std::path::MAIN_SEPARATOR, name);
let mut test_cmd = Command::new(&zluda_with_exe);
let test_cmd = test_cmd
.arg("--nvcuda")
.arg(&zluda_dump_dll)
.arg("--")
.arg(&exe_under_test);
let test_output = test_cmd.output()?;
assert!(test_output.status.success());
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
assert!(stderr_text.contains("ZLUDA_DUMP"));
Ok(())
}