mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-01 22:37:48 +03:00
Always use Unix line endings (#453)
This commit is contained in:
1
.rustfmt.toml
Normal file
1
.rustfmt.toml
Normal file
@ -0,0 +1 @@
|
||||
newline_style = "Unix"
|
@ -1,147 +1,147 @@
|
||||
use cmake::Config;
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
|
||||
const COMPONENTS: &[&'static str] = &[
|
||||
"LLVMCore",
|
||||
"LLVMBitWriter",
|
||||
#[cfg(debug_assertions)]
|
||||
"LLVMAnalysis", // for module verify
|
||||
#[cfg(debug_assertions)]
|
||||
"LLVMBitReader",
|
||||
];
|
||||
|
||||
fn main() {
|
||||
let mut cmake = Config::new(r"../ext/llvm-project/llvm");
|
||||
try_use_sccache(&mut cmake);
|
||||
try_use_ninja(&mut cmake);
|
||||
cmake
|
||||
// It's not like we can do anything about the warnings
|
||||
.define("LLVM_ENABLE_WARNINGS", "OFF")
|
||||
// For some reason Rust always links to release CRT
|
||||
.define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreaded")
|
||||
.define("LLVM_ENABLE_TERMINFO", "OFF")
|
||||
.define("LLVM_ENABLE_LIBXML2", "OFF")
|
||||
.define("LLVM_ENABLE_LIBEDIT", "OFF")
|
||||
.define("LLVM_ENABLE_LIBPFM", "OFF")
|
||||
.define("LLVM_ENABLE_ZLIB", "OFF")
|
||||
.define("LLVM_ENABLE_ZSTD", "OFF")
|
||||
.define("LLVM_INCLUDE_BENCHMARKS", "OFF")
|
||||
.define("LLVM_INCLUDE_EXAMPLES", "OFF")
|
||||
.define("LLVM_INCLUDE_TESTS", "OFF")
|
||||
.define("LLVM_BUILD_TOOLS", "OFF")
|
||||
.define("LLVM_TARGETS_TO_BUILD", "")
|
||||
.define("LLVM_ENABLE_PROJECTS", "");
|
||||
cmake.build_target("llvm-config");
|
||||
let llvm_dir = cmake.build();
|
||||
for c in COMPONENTS {
|
||||
cmake.build_target(c);
|
||||
cmake.build();
|
||||
}
|
||||
let cmake_profile = cmake.get_profile();
|
||||
let (cxxflags, ldflags, libdir, lib_names, system_libs) =
|
||||
llvm_config(&llvm_dir, &["build", "bin", "llvm-config"])
|
||||
.or_else(|_| llvm_config(&llvm_dir, &["build", cmake_profile, "bin", "llvm-config"]))
|
||||
.unwrap();
|
||||
println!("cargo:rustc-link-arg={ldflags}");
|
||||
println!("cargo:rustc-link-search=native={libdir}");
|
||||
for lib in system_libs.split_ascii_whitespace() {
|
||||
println!("cargo:rustc-link-arg={lib}");
|
||||
}
|
||||
link_llvm_components(lib_names);
|
||||
compile_cxx_lib(cxxflags);
|
||||
}
|
||||
|
||||
// https://github.com/mozilla/sccache/blob/main/README.md#usage
|
||||
fn try_use_sccache(cmake: &mut Config) {
|
||||
if let Ok(sccache) = std::env::var("SCCACHE_PATH") {
|
||||
cmake.define("CMAKE_CXX_COMPILER_LAUNCHER", &*sccache);
|
||||
cmake.define("CMAKE_C_COMPILER_LAUNCHER", &*sccache);
|
||||
match std::env::var_os("CARGO_CFG_TARGET_OS") {
|
||||
Some(os) if os == "windows" => {
|
||||
cmake.define("CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "Embedded");
|
||||
cmake.define("CMAKE_POLICY_CMP0141", "NEW");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_use_ninja(cmake: &mut Config) {
|
||||
let mut cmd = Command::new("ninja");
|
||||
cmd.arg("--version");
|
||||
if let Ok(status) = cmd.status() {
|
||||
if status.success() {
|
||||
cmake.generator("Ninja");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn llvm_config(
|
||||
llvm_build_dir: &PathBuf,
|
||||
path_to_llvm_config: &[&str],
|
||||
) -> io::Result<(String, String, String, String, String)> {
|
||||
let mut llvm_build_path = llvm_build_dir.clone();
|
||||
llvm_build_path.extend(path_to_llvm_config);
|
||||
let mut cmd = Command::new(llvm_build_path);
|
||||
cmd.args([
|
||||
"--link-static",
|
||||
"--cxxflags",
|
||||
"--ldflags",
|
||||
"--libdir",
|
||||
"--libnames",
|
||||
"--system-libs",
|
||||
]);
|
||||
for c in COMPONENTS {
|
||||
cmd.arg(c[4..].to_lowercase());
|
||||
}
|
||||
let output = cmd.output()?;
|
||||
if !output.status.success() {
|
||||
return Err(io::Error::from(io::ErrorKind::Other));
|
||||
}
|
||||
let output = unsafe { String::from_utf8_unchecked(output.stdout) };
|
||||
let mut lines = output.lines();
|
||||
let cxxflags = lines.next().unwrap();
|
||||
let ldflags = lines.next().unwrap();
|
||||
let libdir = lines.next().unwrap();
|
||||
let lib_names = lines.next().unwrap();
|
||||
let system_libs = lines.next().unwrap();
|
||||
Ok((
|
||||
cxxflags.to_string(),
|
||||
ldflags.to_string(),
|
||||
libdir.to_string(),
|
||||
lib_names.to_string(),
|
||||
system_libs.to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn compile_cxx_lib(cxxflags: String) {
|
||||
let mut cc = cc::Build::new();
|
||||
for flag in cxxflags.split_whitespace() {
|
||||
cc.flag(flag);
|
||||
}
|
||||
cc.cpp(true).file("src/lib.cpp").compile("llvm_zluda_cpp");
|
||||
println!("cargo:rerun-if-changed=src/lib.cpp");
|
||||
println!("cargo:rerun-if-changed=src/lib.rs");
|
||||
}
|
||||
|
||||
fn link_llvm_components(components: String) {
|
||||
for component in components.split_whitespace() {
|
||||
let component = if let Some(component) = component
|
||||
.strip_prefix("lib")
|
||||
.and_then(|component| component.strip_suffix(".a"))
|
||||
{
|
||||
// Unix (Linux/Mac)
|
||||
// libLLVMfoo.a
|
||||
component
|
||||
} else if let Some(component) = component.strip_suffix(".lib") {
|
||||
// Windows
|
||||
// LLVMfoo.lib
|
||||
component
|
||||
} else {
|
||||
panic!("'{}' does not look like a static library name", component)
|
||||
};
|
||||
println!("cargo:rustc-link-lib={component}");
|
||||
}
|
||||
}
|
||||
use cmake::Config;
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
|
||||
const COMPONENTS: &[&'static str] = &[
|
||||
"LLVMCore",
|
||||
"LLVMBitWriter",
|
||||
#[cfg(debug_assertions)]
|
||||
"LLVMAnalysis", // for module verify
|
||||
#[cfg(debug_assertions)]
|
||||
"LLVMBitReader",
|
||||
];
|
||||
|
||||
fn main() {
|
||||
let mut cmake = Config::new(r"../ext/llvm-project/llvm");
|
||||
try_use_sccache(&mut cmake);
|
||||
try_use_ninja(&mut cmake);
|
||||
cmake
|
||||
// It's not like we can do anything about the warnings
|
||||
.define("LLVM_ENABLE_WARNINGS", "OFF")
|
||||
// For some reason Rust always links to release CRT
|
||||
.define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreaded")
|
||||
.define("LLVM_ENABLE_TERMINFO", "OFF")
|
||||
.define("LLVM_ENABLE_LIBXML2", "OFF")
|
||||
.define("LLVM_ENABLE_LIBEDIT", "OFF")
|
||||
.define("LLVM_ENABLE_LIBPFM", "OFF")
|
||||
.define("LLVM_ENABLE_ZLIB", "OFF")
|
||||
.define("LLVM_ENABLE_ZSTD", "OFF")
|
||||
.define("LLVM_INCLUDE_BENCHMARKS", "OFF")
|
||||
.define("LLVM_INCLUDE_EXAMPLES", "OFF")
|
||||
.define("LLVM_INCLUDE_TESTS", "OFF")
|
||||
.define("LLVM_BUILD_TOOLS", "OFF")
|
||||
.define("LLVM_TARGETS_TO_BUILD", "")
|
||||
.define("LLVM_ENABLE_PROJECTS", "");
|
||||
cmake.build_target("llvm-config");
|
||||
let llvm_dir = cmake.build();
|
||||
for c in COMPONENTS {
|
||||
cmake.build_target(c);
|
||||
cmake.build();
|
||||
}
|
||||
let cmake_profile = cmake.get_profile();
|
||||
let (cxxflags, ldflags, libdir, lib_names, system_libs) =
|
||||
llvm_config(&llvm_dir, &["build", "bin", "llvm-config"])
|
||||
.or_else(|_| llvm_config(&llvm_dir, &["build", cmake_profile, "bin", "llvm-config"]))
|
||||
.unwrap();
|
||||
println!("cargo:rustc-link-arg={ldflags}");
|
||||
println!("cargo:rustc-link-search=native={libdir}");
|
||||
for lib in system_libs.split_ascii_whitespace() {
|
||||
println!("cargo:rustc-link-arg={lib}");
|
||||
}
|
||||
link_llvm_components(lib_names);
|
||||
compile_cxx_lib(cxxflags);
|
||||
}
|
||||
|
||||
// https://github.com/mozilla/sccache/blob/main/README.md#usage
|
||||
fn try_use_sccache(cmake: &mut Config) {
|
||||
if let Ok(sccache) = std::env::var("SCCACHE_PATH") {
|
||||
cmake.define("CMAKE_CXX_COMPILER_LAUNCHER", &*sccache);
|
||||
cmake.define("CMAKE_C_COMPILER_LAUNCHER", &*sccache);
|
||||
match std::env::var_os("CARGO_CFG_TARGET_OS") {
|
||||
Some(os) if os == "windows" => {
|
||||
cmake.define("CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "Embedded");
|
||||
cmake.define("CMAKE_POLICY_CMP0141", "NEW");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_use_ninja(cmake: &mut Config) {
|
||||
let mut cmd = Command::new("ninja");
|
||||
cmd.arg("--version");
|
||||
if let Ok(status) = cmd.status() {
|
||||
if status.success() {
|
||||
cmake.generator("Ninja");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn llvm_config(
|
||||
llvm_build_dir: &PathBuf,
|
||||
path_to_llvm_config: &[&str],
|
||||
) -> io::Result<(String, String, String, String, String)> {
|
||||
let mut llvm_build_path = llvm_build_dir.clone();
|
||||
llvm_build_path.extend(path_to_llvm_config);
|
||||
let mut cmd = Command::new(llvm_build_path);
|
||||
cmd.args([
|
||||
"--link-static",
|
||||
"--cxxflags",
|
||||
"--ldflags",
|
||||
"--libdir",
|
||||
"--libnames",
|
||||
"--system-libs",
|
||||
]);
|
||||
for c in COMPONENTS {
|
||||
cmd.arg(c[4..].to_lowercase());
|
||||
}
|
||||
let output = cmd.output()?;
|
||||
if !output.status.success() {
|
||||
return Err(io::Error::from(io::ErrorKind::Other));
|
||||
}
|
||||
let output = unsafe { String::from_utf8_unchecked(output.stdout) };
|
||||
let mut lines = output.lines();
|
||||
let cxxflags = lines.next().unwrap();
|
||||
let ldflags = lines.next().unwrap();
|
||||
let libdir = lines.next().unwrap();
|
||||
let lib_names = lines.next().unwrap();
|
||||
let system_libs = lines.next().unwrap();
|
||||
Ok((
|
||||
cxxflags.to_string(),
|
||||
ldflags.to_string(),
|
||||
libdir.to_string(),
|
||||
lib_names.to_string(),
|
||||
system_libs.to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn compile_cxx_lib(cxxflags: String) {
|
||||
let mut cc = cc::Build::new();
|
||||
for flag in cxxflags.split_whitespace() {
|
||||
cc.flag(flag);
|
||||
}
|
||||
cc.cpp(true).file("src/lib.cpp").compile("llvm_zluda_cpp");
|
||||
println!("cargo:rerun-if-changed=src/lib.cpp");
|
||||
println!("cargo:rerun-if-changed=src/lib.rs");
|
||||
}
|
||||
|
||||
fn link_llvm_components(components: String) {
|
||||
for component in components.split_whitespace() {
|
||||
let component = if let Some(component) = component
|
||||
.strip_prefix("lib")
|
||||
.and_then(|component| component.strip_suffix(".a"))
|
||||
{
|
||||
// Unix (Linux/Mac)
|
||||
// libLLVMfoo.a
|
||||
component
|
||||
} else if let Some(component) = component.strip_suffix(".lib") {
|
||||
// Windows
|
||||
// LLVMfoo.lib
|
||||
component
|
||||
} else {
|
||||
panic!("'{}' does not look like a static library name", component)
|
||||
};
|
||||
println!("cargo:rustc-link-lib={component}");
|
||||
}
|
||||
}
|
||||
|
@ -1,81 +1,81 @@
|
||||
#![allow(non_upper_case_globals)]
|
||||
use llvm_sys::prelude::*;
|
||||
pub use llvm_sys::*;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
pub enum LLVMZludaAtomicRMWBinOp {
|
||||
LLVMZludaAtomicRMWBinOpXchg = 0,
|
||||
LLVMZludaAtomicRMWBinOpAdd = 1,
|
||||
LLVMZludaAtomicRMWBinOpSub = 2,
|
||||
LLVMZludaAtomicRMWBinOpAnd = 3,
|
||||
LLVMZludaAtomicRMWBinOpNand = 4,
|
||||
LLVMZludaAtomicRMWBinOpOr = 5,
|
||||
LLVMZludaAtomicRMWBinOpXor = 6,
|
||||
LLVMZludaAtomicRMWBinOpMax = 7,
|
||||
LLVMZludaAtomicRMWBinOpMin = 8,
|
||||
LLVMZludaAtomicRMWBinOpUMax = 9,
|
||||
LLVMZludaAtomicRMWBinOpUMin = 10,
|
||||
LLVMZludaAtomicRMWBinOpFAdd = 11,
|
||||
LLVMZludaAtomicRMWBinOpFSub = 12,
|
||||
LLVMZludaAtomicRMWBinOpFMax = 13,
|
||||
LLVMZludaAtomicRMWBinOpFMin = 14,
|
||||
LLVMZludaAtomicRMWBinOpUIncWrap = 15,
|
||||
LLVMZludaAtomicRMWBinOpUDecWrap = 16,
|
||||
}
|
||||
|
||||
// Backport from LLVM 19
|
||||
pub const LLVMZludaFastMathAllowReassoc: ::std::ffi::c_uint = 1 << 0;
|
||||
pub const LLVMZludaFastMathNoNaNs: ::std::ffi::c_uint = 1 << 1;
|
||||
pub const LLVMZludaFastMathNoInfs: ::std::ffi::c_uint = 1 << 2;
|
||||
pub const LLVMZludaFastMathNoSignedZeros: ::std::ffi::c_uint = 1 << 3;
|
||||
pub const LLVMZludaFastMathAllowReciprocal: ::std::ffi::c_uint = 1 << 4;
|
||||
pub const LLVMZludaFastMathAllowContract: ::std::ffi::c_uint = 1 << 5;
|
||||
pub const LLVMZludaFastMathApproxFunc: ::std::ffi::c_uint = 1 << 6;
|
||||
pub const LLVMZludaFastMathNone: ::std::ffi::c_uint = 0;
|
||||
pub const LLVMZludaFastMathAll: ::std::ffi::c_uint = LLVMZludaFastMathAllowReassoc
|
||||
| LLVMZludaFastMathNoNaNs
|
||||
| LLVMZludaFastMathNoInfs
|
||||
| LLVMZludaFastMathNoSignedZeros
|
||||
| LLVMZludaFastMathAllowReciprocal
|
||||
| LLVMZludaFastMathAllowContract
|
||||
| LLVMZludaFastMathApproxFunc;
|
||||
|
||||
pub type LLVMZludaFastMathFlags = std::ffi::c_uint;
|
||||
|
||||
extern "C" {
|
||||
pub fn LLVMZludaBuildAlloca(
|
||||
B: LLVMBuilderRef,
|
||||
Ty: LLVMTypeRef,
|
||||
AddrSpace: u32,
|
||||
Name: *const i8,
|
||||
) -> LLVMValueRef;
|
||||
|
||||
pub fn LLVMZludaBuildAtomicRMW(
|
||||
B: LLVMBuilderRef,
|
||||
op: LLVMZludaAtomicRMWBinOp,
|
||||
PTR: LLVMValueRef,
|
||||
Val: LLVMValueRef,
|
||||
scope: *const i8,
|
||||
ordering: LLVMAtomicOrdering,
|
||||
) -> LLVMValueRef;
|
||||
|
||||
pub fn LLVMZludaBuildAtomicCmpXchg(
|
||||
B: LLVMBuilderRef,
|
||||
Ptr: LLVMValueRef,
|
||||
Cmp: LLVMValueRef,
|
||||
New: LLVMValueRef,
|
||||
scope: *const i8,
|
||||
SuccessOrdering: LLVMAtomicOrdering,
|
||||
FailureOrdering: LLVMAtomicOrdering,
|
||||
) -> LLVMValueRef;
|
||||
|
||||
pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags);
|
||||
|
||||
pub fn LLVMZludaBuildFence(
|
||||
B: LLVMBuilderRef,
|
||||
ordering: LLVMAtomicOrdering,
|
||||
scope: *const i8,
|
||||
Name: *const i8,
|
||||
) -> LLVMValueRef;
|
||||
}
|
||||
#![allow(non_upper_case_globals)]
|
||||
use llvm_sys::prelude::*;
|
||||
pub use llvm_sys::*;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
pub enum LLVMZludaAtomicRMWBinOp {
|
||||
LLVMZludaAtomicRMWBinOpXchg = 0,
|
||||
LLVMZludaAtomicRMWBinOpAdd = 1,
|
||||
LLVMZludaAtomicRMWBinOpSub = 2,
|
||||
LLVMZludaAtomicRMWBinOpAnd = 3,
|
||||
LLVMZludaAtomicRMWBinOpNand = 4,
|
||||
LLVMZludaAtomicRMWBinOpOr = 5,
|
||||
LLVMZludaAtomicRMWBinOpXor = 6,
|
||||
LLVMZludaAtomicRMWBinOpMax = 7,
|
||||
LLVMZludaAtomicRMWBinOpMin = 8,
|
||||
LLVMZludaAtomicRMWBinOpUMax = 9,
|
||||
LLVMZludaAtomicRMWBinOpUMin = 10,
|
||||
LLVMZludaAtomicRMWBinOpFAdd = 11,
|
||||
LLVMZludaAtomicRMWBinOpFSub = 12,
|
||||
LLVMZludaAtomicRMWBinOpFMax = 13,
|
||||
LLVMZludaAtomicRMWBinOpFMin = 14,
|
||||
LLVMZludaAtomicRMWBinOpUIncWrap = 15,
|
||||
LLVMZludaAtomicRMWBinOpUDecWrap = 16,
|
||||
}
|
||||
|
||||
// Backport from LLVM 19
|
||||
pub const LLVMZludaFastMathAllowReassoc: ::std::ffi::c_uint = 1 << 0;
|
||||
pub const LLVMZludaFastMathNoNaNs: ::std::ffi::c_uint = 1 << 1;
|
||||
pub const LLVMZludaFastMathNoInfs: ::std::ffi::c_uint = 1 << 2;
|
||||
pub const LLVMZludaFastMathNoSignedZeros: ::std::ffi::c_uint = 1 << 3;
|
||||
pub const LLVMZludaFastMathAllowReciprocal: ::std::ffi::c_uint = 1 << 4;
|
||||
pub const LLVMZludaFastMathAllowContract: ::std::ffi::c_uint = 1 << 5;
|
||||
pub const LLVMZludaFastMathApproxFunc: ::std::ffi::c_uint = 1 << 6;
|
||||
pub const LLVMZludaFastMathNone: ::std::ffi::c_uint = 0;
|
||||
pub const LLVMZludaFastMathAll: ::std::ffi::c_uint = LLVMZludaFastMathAllowReassoc
|
||||
| LLVMZludaFastMathNoNaNs
|
||||
| LLVMZludaFastMathNoInfs
|
||||
| LLVMZludaFastMathNoSignedZeros
|
||||
| LLVMZludaFastMathAllowReciprocal
|
||||
| LLVMZludaFastMathAllowContract
|
||||
| LLVMZludaFastMathApproxFunc;
|
||||
|
||||
pub type LLVMZludaFastMathFlags = std::ffi::c_uint;
|
||||
|
||||
extern "C" {
|
||||
pub fn LLVMZludaBuildAlloca(
|
||||
B: LLVMBuilderRef,
|
||||
Ty: LLVMTypeRef,
|
||||
AddrSpace: u32,
|
||||
Name: *const i8,
|
||||
) -> LLVMValueRef;
|
||||
|
||||
pub fn LLVMZludaBuildAtomicRMW(
|
||||
B: LLVMBuilderRef,
|
||||
op: LLVMZludaAtomicRMWBinOp,
|
||||
PTR: LLVMValueRef,
|
||||
Val: LLVMValueRef,
|
||||
scope: *const i8,
|
||||
ordering: LLVMAtomicOrdering,
|
||||
) -> LLVMValueRef;
|
||||
|
||||
pub fn LLVMZludaBuildAtomicCmpXchg(
|
||||
B: LLVMBuilderRef,
|
||||
Ptr: LLVMValueRef,
|
||||
Cmp: LLVMValueRef,
|
||||
New: LLVMValueRef,
|
||||
scope: *const i8,
|
||||
SuccessOrdering: LLVMAtomicOrdering,
|
||||
FailureOrdering: LLVMAtomicOrdering,
|
||||
) -> LLVMValueRef;
|
||||
|
||||
pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags);
|
||||
|
||||
pub fn LLVMZludaBuildFence(
|
||||
B: LLVMBuilderRef,
|
||||
ordering: LLVMAtomicOrdering,
|
||||
scope: *const i8,
|
||||
Name: *const i8,
|
||||
) -> LLVMValueRef;
|
||||
}
|
||||
|
@ -1,191 +1,191 @@
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let is_declaration = method.body.is_none();
|
||||
let mut body = Vec::new();
|
||||
let mut remap_returns = Vec::new();
|
||||
if !method.is_kernel {
|
||||
for arg in method.return_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = arg.name;
|
||||
arg.name =
|
||||
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||
if is_declaration {
|
||||
continue;
|
||||
}
|
||||
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: old_name,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ptx_parser::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
}
|
||||
ptx_parser::StateSpace::Reg => {}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
for arg in method.input_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = arg.name;
|
||||
arg.name =
|
||||
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||
if is_declaration {
|
||||
continue;
|
||||
}
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: old_name,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ptx_parser::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
body.push(Statement::Instruction(ast::Instruction::St {
|
||||
data: ast::StData {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::StCacheOperator::Writethrough,
|
||||
typ: arg.v_type.clone(),
|
||||
},
|
||||
arguments: ast::StArgs {
|
||||
src1: old_name,
|
||||
src2: arg.name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
ptx_parser::StateSpace::Reg => {}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
}
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
for statement in statements {
|
||||
run_statement(resolver, &remap_returns, &mut body, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(body)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Call {
|
||||
mut data,
|
||||
mut arguments,
|
||||
}) => {
|
||||
let mut post_st = Vec::new();
|
||||
for ((type_, space), ident) in data
|
||||
.input_arguments
|
||||
.iter_mut()
|
||||
.zip(arguments.input_arguments.iter_mut())
|
||||
{
|
||||
if *space == ptx_parser::StateSpace::Param {
|
||||
*space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = *ident;
|
||||
*ident = resolver
|
||||
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
|
||||
result.push(Statement::Instruction(ast::Instruction::Ld {
|
||||
data: ast::LdDetails {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::LdCacheOperator::Cached,
|
||||
typ: type_.clone(),
|
||||
non_coherent: false,
|
||||
},
|
||||
arguments: ast::LdArgs {
|
||||
dst: *ident,
|
||||
src: old_name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
}
|
||||
for ((type_, space), ident) in data
|
||||
.return_arguments
|
||||
.iter_mut()
|
||||
.zip(arguments.return_arguments.iter_mut())
|
||||
{
|
||||
if *space == ptx_parser::StateSpace::Param {
|
||||
*space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = *ident;
|
||||
*ident = resolver
|
||||
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
|
||||
post_st.push(Statement::Instruction(ast::Instruction::St {
|
||||
data: ast::StData {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::StCacheOperator::Writethrough,
|
||||
typ: type_.clone(),
|
||||
},
|
||||
arguments: ast::StArgs {
|
||||
src1: old_name,
|
||||
src2: *ident,
|
||||
},
|
||||
}));
|
||||
}
|
||||
}
|
||||
result.push(Statement::Instruction(ast::Instruction::Call {
|
||||
data,
|
||||
arguments,
|
||||
}));
|
||||
result.extend(post_st.into_iter());
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Ret { data }) => {
|
||||
for (old_name, new_name, type_) in remap_returns.iter() {
|
||||
result.push(Statement::Instruction(ast::Instruction::Ld {
|
||||
data: ast::LdDetails {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::LdCacheOperator::Cached,
|
||||
typ: type_.clone(),
|
||||
non_coherent: false,
|
||||
},
|
||||
arguments: ast::LdArgs {
|
||||
dst: *new_name,
|
||||
src: *old_name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
result.push(Statement::Instruction(ast::Instruction::Ret { data }));
|
||||
}
|
||||
statement => {
|
||||
result.push(statement);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let is_declaration = method.body.is_none();
|
||||
let mut body = Vec::new();
|
||||
let mut remap_returns = Vec::new();
|
||||
if !method.is_kernel {
|
||||
for arg in method.return_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = arg.name;
|
||||
arg.name =
|
||||
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||
if is_declaration {
|
||||
continue;
|
||||
}
|
||||
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: old_name,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ptx_parser::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
}
|
||||
ptx_parser::StateSpace::Reg => {}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
for arg in method.input_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = arg.name;
|
||||
arg.name =
|
||||
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||
if is_declaration {
|
||||
continue;
|
||||
}
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: old_name,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ptx_parser::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
body.push(Statement::Instruction(ast::Instruction::St {
|
||||
data: ast::StData {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::StCacheOperator::Writethrough,
|
||||
typ: arg.v_type.clone(),
|
||||
},
|
||||
arguments: ast::StArgs {
|
||||
src1: old_name,
|
||||
src2: arg.name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
ptx_parser::StateSpace::Reg => {}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
}
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
for statement in statements {
|
||||
run_statement(resolver, &remap_returns, &mut body, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(body)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Call {
|
||||
mut data,
|
||||
mut arguments,
|
||||
}) => {
|
||||
let mut post_st = Vec::new();
|
||||
for ((type_, space), ident) in data
|
||||
.input_arguments
|
||||
.iter_mut()
|
||||
.zip(arguments.input_arguments.iter_mut())
|
||||
{
|
||||
if *space == ptx_parser::StateSpace::Param {
|
||||
*space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = *ident;
|
||||
*ident = resolver
|
||||
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
|
||||
result.push(Statement::Instruction(ast::Instruction::Ld {
|
||||
data: ast::LdDetails {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::LdCacheOperator::Cached,
|
||||
typ: type_.clone(),
|
||||
non_coherent: false,
|
||||
},
|
||||
arguments: ast::LdArgs {
|
||||
dst: *ident,
|
||||
src: old_name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
}
|
||||
for ((type_, space), ident) in data
|
||||
.return_arguments
|
||||
.iter_mut()
|
||||
.zip(arguments.return_arguments.iter_mut())
|
||||
{
|
||||
if *space == ptx_parser::StateSpace::Param {
|
||||
*space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = *ident;
|
||||
*ident = resolver
|
||||
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
|
||||
post_st.push(Statement::Instruction(ast::Instruction::St {
|
||||
data: ast::StData {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::StCacheOperator::Writethrough,
|
||||
typ: type_.clone(),
|
||||
},
|
||||
arguments: ast::StArgs {
|
||||
src1: old_name,
|
||||
src2: *ident,
|
||||
},
|
||||
}));
|
||||
}
|
||||
}
|
||||
result.push(Statement::Instruction(ast::Instruction::Call {
|
||||
data,
|
||||
arguments,
|
||||
}));
|
||||
result.extend(post_st.into_iter());
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Ret { data }) => {
|
||||
for (old_name, new_name, type_) in remap_returns.iter() {
|
||||
result.push(Statement::Instruction(ast::Instruction::Ld {
|
||||
data: ast::LdDetails {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::LdCacheOperator::Cached,
|
||||
typ: type_.clone(),
|
||||
non_coherent: false,
|
||||
},
|
||||
arguments: ast::LdArgs {
|
||||
dst: *new_name,
|
||||
src: *old_name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
result.push(Statement::Instruction(ast::Instruction::Ret { data }));
|
||||
}
|
||||
statement => {
|
||||
result.push(statement);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,301 +1,301 @@
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<UnconditionalDirective>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
method: Function2<
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>,
|
||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(resolver, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
body,
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
flush_to_zero_f32: method.flush_to_zero_f32,
|
||||
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
||||
rounding_mode_f32: method.rounding_mode_f32,
|
||||
rounding_mode_f16f64: method.rounding_mode_f16f64,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
statement: UnconditionalStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut visitor = FlattenArguments::new(resolver, result);
|
||||
let new_statement = statement.visit_map(&mut visitor)?;
|
||||
visitor.result.push(new_statement);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct FlattenArguments<'a, 'input> {
|
||||
result: &'a mut Vec<ExpandedStatement>,
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
post_stmts: Vec<ExpandedStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'input> FlattenArguments<'a, 'input> {
|
||||
fn new(
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
result: &'a mut Vec<ExpandedStatement>,
|
||||
) -> Self {
|
||||
FlattenArguments {
|
||||
result,
|
||||
resolver,
|
||||
post_stmts: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
|
||||
Ok(name)
|
||||
}
|
||||
|
||||
fn reg_offset(
|
||||
&mut self,
|
||||
reg: SpirvWord,
|
||||
offset: i32,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
|
||||
(type_, state_space)
|
||||
} else {
|
||||
return Err(TranslateError::UntypedSymbol);
|
||||
};
|
||||
if state_space == ast::StateSpace::Reg {
|
||||
let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
|
||||
if *reg_space != ast::StateSpace::Reg {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let reg_scalar_type = match reg_type {
|
||||
ast::Type::Scalar(underlying_type) => *underlying_type,
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let reg_type = reg_type.clone();
|
||||
let id_constant_stmt = self
|
||||
.resolver
|
||||
.register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: reg_scalar_type,
|
||||
value: ast::ImmediateValue::S64(offset as i64),
|
||||
}));
|
||||
let arith_details = match reg_scalar_type.kind() {
|
||||
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: reg_scalar_type,
|
||||
saturate: false,
|
||||
}),
|
||||
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: reg_scalar_type,
|
||||
saturate: false,
|
||||
})
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let id_add_result = self
|
||||
.resolver
|
||||
.register_unnamed(Some((reg_type, state_space)));
|
||||
self.result
|
||||
.push(Statement::Instruction(ast::Instruction::Add {
|
||||
data: arith_details,
|
||||
arguments: ast::AddArgs {
|
||||
dst: id_add_result,
|
||||
src1: reg,
|
||||
src2: id_constant_stmt,
|
||||
},
|
||||
}));
|
||||
Ok(id_add_result)
|
||||
} else {
|
||||
let id_constant_stmt = self.resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::S64),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: ast::ScalarType::S64,
|
||||
value: ast::ImmediateValue::S64(offset as i64),
|
||||
}));
|
||||
let dst = self
|
||||
.resolver
|
||||
.register_unnamed(Some((type_.clone(), state_space)));
|
||||
self.result.push(Statement::PtrAccess(PtrAccess {
|
||||
underlying_type: type_.clone(),
|
||||
state_space: state_space,
|
||||
dst,
|
||||
ptr_src: reg,
|
||||
offset_src: id_constant_stmt,
|
||||
}));
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
fn immediate(
|
||||
&mut self,
|
||||
value: ast::ImmediateValue,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (scalar_t, state_space) =
|
||||
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
|
||||
(*scalar, state_space)
|
||||
} else {
|
||||
return Err(TranslateError::UntypedSymbol);
|
||||
};
|
||||
let id = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id,
|
||||
typ: scalar_t,
|
||||
value,
|
||||
}));
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
fn vec_member(
|
||||
&mut self,
|
||||
vector_ident: SpirvWord,
|
||||
member: u8,
|
||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? {
|
||||
(ast::Type::Vector(vector_width, scalar_t), space) => {
|
||||
(*vector_width, *scalar_t, *space)
|
||||
}
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let temporary = self
|
||||
.resolver
|
||||
.register_unnamed(Some((scalar_type.into(), space)));
|
||||
if is_dst {
|
||||
self.post_stmts.push(Statement::VectorWrite(VectorWrite {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
vector_dst: vector_ident,
|
||||
vector_src: vector_ident,
|
||||
scalar_src: temporary,
|
||||
member,
|
||||
}));
|
||||
} else {
|
||||
self.result.push(Statement::VectorRead(VectorRead {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
scalar_dst: temporary,
|
||||
vector_src: vector_ident,
|
||||
member,
|
||||
}));
|
||||
}
|
||||
Ok(temporary)
|
||||
}
|
||||
|
||||
fn vec_pack(
|
||||
&mut self,
|
||||
vector_elements: Vec<SpirvWord>,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (width, scalar_t, state_space) = match type_space {
|
||||
Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let temporary_vector = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
|
||||
let statement = Statement::RepackVector(RepackVectorDetails {
|
||||
is_extract: is_dst,
|
||||
typ: scalar_t,
|
||||
packed: temporary_vector,
|
||||
unpacked: vector_elements,
|
||||
relaxed_type_check,
|
||||
});
|
||||
if is_dst {
|
||||
self.post_stmts.push(statement);
|
||||
} else {
|
||||
self.result.push(statement);
|
||||
}
|
||||
Ok(temporary_vector)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
|
||||
for FlattenArguments<'a, 'b>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: ast::ParsedOperand<SpirvWord>,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
match args {
|
||||
ast::ParsedOperand::Reg(r) => self.reg(r),
|
||||
ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
|
||||
ast::ParsedOperand::RegOffset(reg, offset) => {
|
||||
self.reg_offset(reg, offset, type_space, is_dst)
|
||||
}
|
||||
ast::ParsedOperand::VecMember(vec, member) => {
|
||||
self.vec_member(vec, member, type_space, is_dst)
|
||||
}
|
||||
ast::ParsedOperand::VecPack(vecs) => {
|
||||
self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
name: SpirvWord,
|
||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
|
||||
self.reg(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for FlattenArguments<'_, '_> {
|
||||
fn drop(&mut self) {
|
||||
self.result.extend(self.post_stmts.drain(..));
|
||||
}
|
||||
}
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<UnconditionalDirective>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
method: Function2<
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>,
|
||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(resolver, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
body,
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
flush_to_zero_f32: method.flush_to_zero_f32,
|
||||
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
||||
rounding_mode_f32: method.rounding_mode_f32,
|
||||
rounding_mode_f16f64: method.rounding_mode_f16f64,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
statement: UnconditionalStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut visitor = FlattenArguments::new(resolver, result);
|
||||
let new_statement = statement.visit_map(&mut visitor)?;
|
||||
visitor.result.push(new_statement);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct FlattenArguments<'a, 'input> {
|
||||
result: &'a mut Vec<ExpandedStatement>,
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
post_stmts: Vec<ExpandedStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'input> FlattenArguments<'a, 'input> {
|
||||
fn new(
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
result: &'a mut Vec<ExpandedStatement>,
|
||||
) -> Self {
|
||||
FlattenArguments {
|
||||
result,
|
||||
resolver,
|
||||
post_stmts: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
|
||||
Ok(name)
|
||||
}
|
||||
|
||||
fn reg_offset(
|
||||
&mut self,
|
||||
reg: SpirvWord,
|
||||
offset: i32,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
|
||||
(type_, state_space)
|
||||
} else {
|
||||
return Err(TranslateError::UntypedSymbol);
|
||||
};
|
||||
if state_space == ast::StateSpace::Reg {
|
||||
let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
|
||||
if *reg_space != ast::StateSpace::Reg {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let reg_scalar_type = match reg_type {
|
||||
ast::Type::Scalar(underlying_type) => *underlying_type,
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let reg_type = reg_type.clone();
|
||||
let id_constant_stmt = self
|
||||
.resolver
|
||||
.register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: reg_scalar_type,
|
||||
value: ast::ImmediateValue::S64(offset as i64),
|
||||
}));
|
||||
let arith_details = match reg_scalar_type.kind() {
|
||||
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: reg_scalar_type,
|
||||
saturate: false,
|
||||
}),
|
||||
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: reg_scalar_type,
|
||||
saturate: false,
|
||||
})
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let id_add_result = self
|
||||
.resolver
|
||||
.register_unnamed(Some((reg_type, state_space)));
|
||||
self.result
|
||||
.push(Statement::Instruction(ast::Instruction::Add {
|
||||
data: arith_details,
|
||||
arguments: ast::AddArgs {
|
||||
dst: id_add_result,
|
||||
src1: reg,
|
||||
src2: id_constant_stmt,
|
||||
},
|
||||
}));
|
||||
Ok(id_add_result)
|
||||
} else {
|
||||
let id_constant_stmt = self.resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::S64),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: ast::ScalarType::S64,
|
||||
value: ast::ImmediateValue::S64(offset as i64),
|
||||
}));
|
||||
let dst = self
|
||||
.resolver
|
||||
.register_unnamed(Some((type_.clone(), state_space)));
|
||||
self.result.push(Statement::PtrAccess(PtrAccess {
|
||||
underlying_type: type_.clone(),
|
||||
state_space: state_space,
|
||||
dst,
|
||||
ptr_src: reg,
|
||||
offset_src: id_constant_stmt,
|
||||
}));
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
fn immediate(
|
||||
&mut self,
|
||||
value: ast::ImmediateValue,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (scalar_t, state_space) =
|
||||
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
|
||||
(*scalar, state_space)
|
||||
} else {
|
||||
return Err(TranslateError::UntypedSymbol);
|
||||
};
|
||||
let id = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id,
|
||||
typ: scalar_t,
|
||||
value,
|
||||
}));
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
fn vec_member(
|
||||
&mut self,
|
||||
vector_ident: SpirvWord,
|
||||
member: u8,
|
||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? {
|
||||
(ast::Type::Vector(vector_width, scalar_t), space) => {
|
||||
(*vector_width, *scalar_t, *space)
|
||||
}
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let temporary = self
|
||||
.resolver
|
||||
.register_unnamed(Some((scalar_type.into(), space)));
|
||||
if is_dst {
|
||||
self.post_stmts.push(Statement::VectorWrite(VectorWrite {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
vector_dst: vector_ident,
|
||||
vector_src: vector_ident,
|
||||
scalar_src: temporary,
|
||||
member,
|
||||
}));
|
||||
} else {
|
||||
self.result.push(Statement::VectorRead(VectorRead {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
scalar_dst: temporary,
|
||||
vector_src: vector_ident,
|
||||
member,
|
||||
}));
|
||||
}
|
||||
Ok(temporary)
|
||||
}
|
||||
|
||||
fn vec_pack(
|
||||
&mut self,
|
||||
vector_elements: Vec<SpirvWord>,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (width, scalar_t, state_space) = match type_space {
|
||||
Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let temporary_vector = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
|
||||
let statement = Statement::RepackVector(RepackVectorDetails {
|
||||
is_extract: is_dst,
|
||||
typ: scalar_t,
|
||||
packed: temporary_vector,
|
||||
unpacked: vector_elements,
|
||||
relaxed_type_check,
|
||||
});
|
||||
if is_dst {
|
||||
self.post_stmts.push(statement);
|
||||
} else {
|
||||
self.result.push(statement);
|
||||
}
|
||||
Ok(temporary_vector)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
|
||||
for FlattenArguments<'a, 'b>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: ast::ParsedOperand<SpirvWord>,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
match args {
|
||||
ast::ParsedOperand::Reg(r) => self.reg(r),
|
||||
ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
|
||||
ast::ParsedOperand::RegOffset(reg, offset) => {
|
||||
self.reg_offset(reg, offset, type_space, is_dst)
|
||||
}
|
||||
ast::ParsedOperand::VecMember(vec, member) => {
|
||||
self.vec_member(vec, member, type_space, is_dst)
|
||||
}
|
||||
ast::ParsedOperand::VecPack(vecs) => {
|
||||
self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
name: SpirvWord,
|
||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
|
||||
self.reg(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for FlattenArguments<'_, '_> {
|
||||
fn drop(&mut self) {
|
||||
self.result.extend(self.post_stmts.drain(..));
|
||||
}
|
||||
}
|
||||
|
@ -1,208 +1,208 @@
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
special_registers: &'a SpecialRegistersMap2,
|
||||
directives: Vec<UnconditionalDirective>,
|
||||
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len());
|
||||
let mut sreg_to_function =
|
||||
FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default());
|
||||
SpecialRegistersMap2::foreach_declaration(
|
||||
resolver,
|
||||
|sreg, (return_arguments, name, input_arguments)| {
|
||||
result.push(UnconditionalDirective::Method(UnconditionalFunction {
|
||||
return_arguments,
|
||||
name,
|
||||
input_arguments,
|
||||
body: None,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
is_kernel: false,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
}));
|
||||
sreg_to_function.insert(sreg, name);
|
||||
},
|
||||
);
|
||||
let mut visitor = SpecialRegisterResolver {
|
||||
resolver,
|
||||
special_registers,
|
||||
sreg_to_function,
|
||||
result: Vec::new(),
|
||||
};
|
||||
for directive in directives.into_iter() {
|
||||
result.push(run_directive(&mut visitor, directive)?);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn run_directive<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
directive: UnconditionalDirective,
|
||||
) -> Result<UnconditionalDirective, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
method: UnconditionalFunction,
|
||||
) -> Result<UnconditionalFunction, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(visitor, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
result: &mut Vec<UnconditionalStatement>,
|
||||
statement: UnconditionalStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let converted_statement = statement.visit_map(visitor)?;
|
||||
result.extend(visitor.result.drain(..));
|
||||
result.push(converted_statement);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct SpecialRegisterResolver<'a, 'input> {
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
special_registers: &'a SpecialRegistersMap2,
|
||||
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
|
||||
result: Vec<UnconditionalStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'input>
|
||||
ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
|
||||
for SpecialRegisterResolver<'a, 'input>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
operand: ast::ParsedOperand<SpirvWord>,
|
||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
|
||||
map_operand(operand, &mut |ident, vector_index| {
|
||||
self.replace_sreg(ident, vector_index, is_dst)
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: SpirvWord,
|
||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
|
||||
fn replace_sreg(
|
||||
&mut self,
|
||||
name: SpirvWord,
|
||||
vector_index: Option<u8>,
|
||||
is_dst: bool,
|
||||
) -> Result<Option<SpirvWord>, TranslateError> {
|
||||
if let Some(sreg) = self.special_registers.get(name) {
|
||||
if is_dst {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
|
||||
(Some(idx), Some(inp_type)) => {
|
||||
if inp_type != ast::ScalarType::U8 {
|
||||
return Err(TranslateError::Unreachable);
|
||||
}
|
||||
let constant = self.resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(inp_type),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: constant,
|
||||
typ: inp_type,
|
||||
value: ast::ImmediateValue::U64(idx as u64),
|
||||
}));
|
||||
vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
|
||||
}
|
||||
(None, None) => Vec::new(),
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let return_type = sreg.get_function_return_type();
|
||||
let fn_result = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
|
||||
let return_arguments = vec![(
|
||||
fn_result,
|
||||
ast::Type::Scalar(return_type),
|
||||
ast::StateSpace::Reg,
|
||||
)];
|
||||
let data = ast::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: return_arguments
|
||||
.iter()
|
||||
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||
.collect(),
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||
.collect(),
|
||||
};
|
||||
let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
|
||||
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
|
||||
func: self.sreg_to_function[&sreg],
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
|
||||
.collect(),
|
||||
};
|
||||
self.result
|
||||
.push(Statement::Instruction(ast::Instruction::Call {
|
||||
data,
|
||||
arguments,
|
||||
}));
|
||||
Ok(Some(fn_result))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_operand<T: Copy, Err>(
|
||||
this: ast::ParsedOperand<T>,
|
||||
fn_: &mut impl FnMut(T, Option<u8>) -> Result<Option<T>, Err>,
|
||||
) -> Result<ast::ParsedOperand<T>, Err> {
|
||||
Ok(match this {
|
||||
ast::ParsedOperand::Reg(ident) => {
|
||||
ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident))
|
||||
}
|
||||
ast::ParsedOperand::RegOffset(ident, offset) => {
|
||||
ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset)
|
||||
}
|
||||
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
|
||||
ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? {
|
||||
Some(ident) => ast::ParsedOperand::Reg(ident),
|
||||
None => ast::ParsedOperand::VecMember(ident, member),
|
||||
},
|
||||
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
|
||||
idents
|
||||
.into_iter()
|
||||
.map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
),
|
||||
})
|
||||
}
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
special_registers: &'a SpecialRegistersMap2,
|
||||
directives: Vec<UnconditionalDirective>,
|
||||
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len());
|
||||
let mut sreg_to_function =
|
||||
FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default());
|
||||
SpecialRegistersMap2::foreach_declaration(
|
||||
resolver,
|
||||
|sreg, (return_arguments, name, input_arguments)| {
|
||||
result.push(UnconditionalDirective::Method(UnconditionalFunction {
|
||||
return_arguments,
|
||||
name,
|
||||
input_arguments,
|
||||
body: None,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
is_kernel: false,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
}));
|
||||
sreg_to_function.insert(sreg, name);
|
||||
},
|
||||
);
|
||||
let mut visitor = SpecialRegisterResolver {
|
||||
resolver,
|
||||
special_registers,
|
||||
sreg_to_function,
|
||||
result: Vec::new(),
|
||||
};
|
||||
for directive in directives.into_iter() {
|
||||
result.push(run_directive(&mut visitor, directive)?);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn run_directive<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
directive: UnconditionalDirective,
|
||||
) -> Result<UnconditionalDirective, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
method: UnconditionalFunction,
|
||||
) -> Result<UnconditionalFunction, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(visitor, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
result: &mut Vec<UnconditionalStatement>,
|
||||
statement: UnconditionalStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let converted_statement = statement.visit_map(visitor)?;
|
||||
result.extend(visitor.result.drain(..));
|
||||
result.push(converted_statement);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct SpecialRegisterResolver<'a, 'input> {
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
special_registers: &'a SpecialRegistersMap2,
|
||||
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
|
||||
result: Vec<UnconditionalStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'input>
|
||||
ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
|
||||
for SpecialRegisterResolver<'a, 'input>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
operand: ast::ParsedOperand<SpirvWord>,
|
||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
|
||||
map_operand(operand, &mut |ident, vector_index| {
|
||||
self.replace_sreg(ident, vector_index, is_dst)
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: SpirvWord,
|
||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
|
||||
fn replace_sreg(
|
||||
&mut self,
|
||||
name: SpirvWord,
|
||||
vector_index: Option<u8>,
|
||||
is_dst: bool,
|
||||
) -> Result<Option<SpirvWord>, TranslateError> {
|
||||
if let Some(sreg) = self.special_registers.get(name) {
|
||||
if is_dst {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
|
||||
(Some(idx), Some(inp_type)) => {
|
||||
if inp_type != ast::ScalarType::U8 {
|
||||
return Err(TranslateError::Unreachable);
|
||||
}
|
||||
let constant = self.resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(inp_type),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: constant,
|
||||
typ: inp_type,
|
||||
value: ast::ImmediateValue::U64(idx as u64),
|
||||
}));
|
||||
vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
|
||||
}
|
||||
(None, None) => Vec::new(),
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let return_type = sreg.get_function_return_type();
|
||||
let fn_result = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
|
||||
let return_arguments = vec![(
|
||||
fn_result,
|
||||
ast::Type::Scalar(return_type),
|
||||
ast::StateSpace::Reg,
|
||||
)];
|
||||
let data = ast::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: return_arguments
|
||||
.iter()
|
||||
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||
.collect(),
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||
.collect(),
|
||||
};
|
||||
let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
|
||||
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
|
||||
func: self.sreg_to_function[&sreg],
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
|
||||
.collect(),
|
||||
};
|
||||
self.result
|
||||
.push(Statement::Instruction(ast::Instruction::Call {
|
||||
data,
|
||||
arguments,
|
||||
}));
|
||||
Ok(Some(fn_result))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_operand<T: Copy, Err>(
|
||||
this: ast::ParsedOperand<T>,
|
||||
fn_: &mut impl FnMut(T, Option<u8>) -> Result<Option<T>, Err>,
|
||||
) -> Result<ast::ParsedOperand<T>, Err> {
|
||||
Ok(match this {
|
||||
ast::ParsedOperand::Reg(ident) => {
|
||||
ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident))
|
||||
}
|
||||
ast::ParsedOperand::RegOffset(ident, offset) => {
|
||||
ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset)
|
||||
}
|
||||
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
|
||||
ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? {
|
||||
Some(ident) => ast::ParsedOperand::Reg(ident),
|
||||
None => ast::ParsedOperand::VecMember(ident, member),
|
||||
},
|
||||
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
|
||||
idents
|
||||
.into_iter()
|
||||
.map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
),
|
||||
})
|
||||
}
|
||||
|
@ -1,45 +1,45 @@
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'input>(
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(directives.len());
|
||||
for mut directive in directives.into_iter() {
|
||||
run_directive(&mut result, &mut directive)?;
|
||||
result.push(directive);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
directive: &mut Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match directive {
|
||||
Directive2::Variable(..) => {}
|
||||
Directive2::Method(function2) => run_function(result, function2),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_function<'input>(
|
||||
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
function: &mut Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) {
|
||||
function.body = function.body.take().map(|statements| {
|
||||
statements
|
||||
.into_iter()
|
||||
.filter_map(|statement| match statement {
|
||||
Statement::Variable(var @ ast::Variable {
|
||||
state_space:
|
||||
ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared,
|
||||
..
|
||||
}) => {
|
||||
result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
|
||||
None
|
||||
}
|
||||
s => Some(s),
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
}
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'input>(
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(directives.len());
|
||||
for mut directive in directives.into_iter() {
|
||||
run_directive(&mut result, &mut directive)?;
|
||||
result.push(directive);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
directive: &mut Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match directive {
|
||||
Directive2::Variable(..) => {}
|
||||
Directive2::Method(function2) => run_function(result, function2),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_function<'input>(
|
||||
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
function: &mut Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) {
|
||||
function.body = function.body.take().map(|statements| {
|
||||
statements
|
||||
.into_iter()
|
||||
.filter_map(|statement| match statement {
|
||||
Statement::Variable(var @ ast::Variable {
|
||||
state_space:
|
||||
ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared,
|
||||
..
|
||||
}) => {
|
||||
result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
|
||||
None
|
||||
}
|
||||
s => Some(s),
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
}
|
||||
|
@ -1,404 +1,404 @@
|
||||
use super::*;
|
||||
// This pass:
|
||||
// * Turns all .local, .param and .reg in-body variables into .local variables
|
||||
// (if _not_ an input method argument)
|
||||
// * Inserts explicit `ld`/`st` for newly converted .reg variables
|
||||
// * Fixup state space of all existing `ld`/`st` instructions into newly
|
||||
// converted variables
|
||||
// * Turns `.entry` input arguments into param::entry and all related `.param`
|
||||
// loads into `param::entry` loads
|
||||
// * All `.func` input arguments are turned into `.reg` arguments by another
|
||||
// pass, so we do nothing there
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => {
|
||||
let visitor = InsertMemSSAVisitor::new(resolver);
|
||||
Directive2::Method(run_method(visitor, method)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'a, 'input>(
|
||||
mut visitor: InsertMemSSAVisitor<'a, 'input>,
|
||||
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let is_kernel = method.is_kernel;
|
||||
if is_kernel {
|
||||
for arg in method.input_arguments.iter_mut() {
|
||||
let old_name = arg.name;
|
||||
let old_space = arg.state_space;
|
||||
let new_space = ast::StateSpace::ParamEntry;
|
||||
let new_name = visitor
|
||||
.resolver
|
||||
.register_unnamed(Some((arg.v_type.clone(), new_space)));
|
||||
visitor.input_argument(old_name, new_name, old_space)?;
|
||||
arg.name = new_name;
|
||||
arg.state_space = new_space;
|
||||
}
|
||||
};
|
||||
for arg in method.return_arguments.iter_mut() {
|
||||
visitor.visit_variable(arg)?;
|
||||
}
|
||||
let return_arguments = &method.return_arguments[..];
|
||||
let body = method
|
||||
.body
|
||||
.map(move |statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(&mut visitor, return_arguments, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'a, 'input>(
|
||||
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
|
||||
return_arguments: &[ast::Variable<SpirvWord>],
|
||||
result: &mut Vec<ExpandedStatement>,
|
||||
statement: ExpandedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Ret { data }) => {
|
||||
let statement = if return_arguments.is_empty() {
|
||||
Statement::Instruction(ast::Instruction::Ret { data })
|
||||
} else {
|
||||
Statement::RetValue(
|
||||
data,
|
||||
return_arguments
|
||||
.iter()
|
||||
.map(|arg| {
|
||||
if arg.state_space != ast::StateSpace::Local {
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
Ok((arg.name, arg.v_type.clone()))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
)
|
||||
};
|
||||
let new_statement = statement.visit_map(visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(new_statement);
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
Statement::Variable(mut var) => {
|
||||
visitor.visit_variable(&mut var)?;
|
||||
result.push(Statement::Variable(var));
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
|
||||
let instruction = visitor.visit_ld(data, arguments)?;
|
||||
let instruction = ast::visit_map(instruction, visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(Statement::Instruction(instruction));
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::St { data, arguments }) => {
|
||||
let instruction = visitor.visit_st(data, arguments)?;
|
||||
let instruction = ast::visit_map(instruction, visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(Statement::Instruction(instruction));
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
Statement::PtrAccess(ptr_access) => {
|
||||
let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
|
||||
let statement = statement.visit_map(visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(statement);
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
s => {
|
||||
let new_statement = s.visit_map(visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(new_statement);
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct InsertMemSSAVisitor<'a, 'input> {
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
variables: FxHashMap<SpirvWord, RemapAction>,
|
||||
pre: Vec<ast::Instruction<SpirvWord>>,
|
||||
post: Vec<ast::Instruction<SpirvWord>>,
|
||||
}
|
||||
|
||||
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||
fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
|
||||
Self {
|
||||
resolver,
|
||||
variables: FxHashMap::default(),
|
||||
pre: Vec::new(),
|
||||
post: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn input_argument(
|
||||
&mut self,
|
||||
old_name: SpirvWord,
|
||||
new_name: SpirvWord,
|
||||
old_space: ast::StateSpace,
|
||||
) -> Result<(), TranslateError> {
|
||||
if old_space != ast::StateSpace::Param {
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
self.variables.insert(
|
||||
old_name,
|
||||
RemapAction::LDStSpaceChange {
|
||||
name: new_name,
|
||||
old_space,
|
||||
new_space: ast::StateSpace::ParamEntry,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn variable(
|
||||
&mut self,
|
||||
type_: &ast::Type,
|
||||
old_name: SpirvWord,
|
||||
new_name: SpirvWord,
|
||||
old_space: ast::StateSpace,
|
||||
) -> Result<bool, TranslateError> {
|
||||
Ok(match old_space {
|
||||
ast::StateSpace::Reg => {
|
||||
self.variables.insert(
|
||||
old_name,
|
||||
RemapAction::PreLdPostSt {
|
||||
name: new_name,
|
||||
type_: type_.clone(),
|
||||
},
|
||||
);
|
||||
true
|
||||
}
|
||||
ast::StateSpace::Param => {
|
||||
self.variables.insert(
|
||||
old_name,
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space: ast::StateSpace::Local,
|
||||
name: new_name,
|
||||
},
|
||||
);
|
||||
true
|
||||
}
|
||||
// Good as-is
|
||||
ast::StateSpace::Local
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::Global
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::SharedCta
|
||||
| ast::StateSpace::Shared
|
||||
| ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc => return Err(error_unreachable()),
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_st(
|
||||
&self,
|
||||
mut data: ast::StData,
|
||||
mut arguments: ast::StArgs<SpirvWord>,
|
||||
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
||||
if let Some(remap) = self.variables.get(&arguments.src1) {
|
||||
match remap {
|
||||
RemapAction::PreLdPostSt { .. } => {}
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space,
|
||||
name,
|
||||
} => {
|
||||
if data.state_space != *old_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
data.state_space = *new_space;
|
||||
arguments.src1 = *name;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(ast::Instruction::St { data, arguments })
|
||||
}
|
||||
|
||||
fn visit_ld(
|
||||
&self,
|
||||
mut data: ast::LdDetails,
|
||||
mut arguments: ast::LdArgs<SpirvWord>,
|
||||
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
||||
if let Some(remap) = self.variables.get(&arguments.src) {
|
||||
match remap {
|
||||
RemapAction::PreLdPostSt { .. } => {}
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space,
|
||||
name,
|
||||
} => {
|
||||
if data.state_space != *old_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
data.state_space = *new_space;
|
||||
arguments.src = *name;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(ast::Instruction::Ld { data, arguments })
|
||||
}
|
||||
|
||||
fn visit_ptr_access(
|
||||
&mut self,
|
||||
ptr_access: PtrAccess<SpirvWord>,
|
||||
) -> Result<PtrAccess<SpirvWord>, TranslateError> {
|
||||
let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) {
|
||||
Some(RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space,
|
||||
name,
|
||||
}) => (*old_space, *new_space, *name),
|
||||
Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access),
|
||||
};
|
||||
if ptr_access.state_space != old_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
// Propagate space changes in dst
|
||||
let new_dst = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ptr_access.underlying_type.clone(), new_space)));
|
||||
self.variables.insert(
|
||||
ptr_access.dst,
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space,
|
||||
name: new_dst,
|
||||
},
|
||||
);
|
||||
Ok(PtrAccess {
|
||||
ptr_src: name,
|
||||
dst: new_dst,
|
||||
state_space: new_space,
|
||||
..ptr_access
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
|
||||
let old_space = match var.state_space {
|
||||
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
|
||||
// Do nothing
|
||||
ptx_parser::StateSpace::Local => return Ok(()),
|
||||
// Handled by another pass
|
||||
ptx_parser::StateSpace::Generic
|
||||
| ptx_parser::StateSpace::SharedCluster
|
||||
| ptx_parser::StateSpace::ParamEntry
|
||||
| ptx_parser::StateSpace::Global
|
||||
| ptx_parser::StateSpace::SharedCta
|
||||
| ptx_parser::StateSpace::Const
|
||||
| ptx_parser::StateSpace::Shared
|
||||
| ptx_parser::StateSpace::ParamFunc => return Ok(()),
|
||||
};
|
||||
let old_name = var.name;
|
||||
let new_space = ast::StateSpace::Local;
|
||||
let new_name = self
|
||||
.resolver
|
||||
.register_unnamed(Some((var.v_type.clone(), new_space)));
|
||||
self.variable(&var.v_type, old_name, new_name, old_space)?;
|
||||
var.name = new_name;
|
||||
var.state_space = new_space;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
|
||||
for InsertMemSSAVisitor<'a, 'input>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
ident: SpirvWord,
|
||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
if let Some(remap) = self.variables.get(&ident) {
|
||||
match remap {
|
||||
RemapAction::PreLdPostSt { name, type_ } => {
|
||||
if is_dst {
|
||||
let temp = self
|
||||
.resolver
|
||||
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
|
||||
self.post.push(ast::Instruction::St {
|
||||
data: ast::StData {
|
||||
state_space: ast::StateSpace::Local,
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
caching: ast::StCacheOperator::Writethrough,
|
||||
typ: type_.clone(),
|
||||
},
|
||||
arguments: ast::StArgs {
|
||||
src1: *name,
|
||||
src2: temp,
|
||||
},
|
||||
});
|
||||
Ok(temp)
|
||||
} else {
|
||||
let temp = self
|
||||
.resolver
|
||||
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
|
||||
self.pre.push(ast::Instruction::Ld {
|
||||
data: ast::LdDetails {
|
||||
state_space: ast::StateSpace::Local,
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
caching: ast::LdCacheOperator::Cached,
|
||||
typ: type_.clone(),
|
||||
non_coherent: false,
|
||||
},
|
||||
arguments: ast::LdArgs {
|
||||
dst: temp,
|
||||
src: *name,
|
||||
},
|
||||
});
|
||||
Ok(temp)
|
||||
}
|
||||
}
|
||||
RemapAction::LDStSpaceChange { .. } => {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(ident)
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: SpirvWord,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
self.visit(args, type_space, is_dst, relaxed_type_check)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum RemapAction {
|
||||
PreLdPostSt {
|
||||
name: SpirvWord,
|
||||
type_: ast::Type,
|
||||
},
|
||||
LDStSpaceChange {
|
||||
old_space: ast::StateSpace,
|
||||
new_space: ast::StateSpace,
|
||||
name: SpirvWord,
|
||||
},
|
||||
}
|
||||
use super::*;
|
||||
// This pass:
|
||||
// * Turns all .local, .param and .reg in-body variables into .local variables
|
||||
// (if _not_ an input method argument)
|
||||
// * Inserts explicit `ld`/`st` for newly converted .reg variables
|
||||
// * Fixup state space of all existing `ld`/`st` instructions into newly
|
||||
// converted variables
|
||||
// * Turns `.entry` input arguments into param::entry and all related `.param`
|
||||
// loads into `param::entry` loads
|
||||
// * All `.func` input arguments are turned into `.reg` arguments by another
|
||||
// pass, so we do nothing there
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => {
|
||||
let visitor = InsertMemSSAVisitor::new(resolver);
|
||||
Directive2::Method(run_method(visitor, method)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'a, 'input>(
|
||||
mut visitor: InsertMemSSAVisitor<'a, 'input>,
|
||||
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let is_kernel = method.is_kernel;
|
||||
if is_kernel {
|
||||
for arg in method.input_arguments.iter_mut() {
|
||||
let old_name = arg.name;
|
||||
let old_space = arg.state_space;
|
||||
let new_space = ast::StateSpace::ParamEntry;
|
||||
let new_name = visitor
|
||||
.resolver
|
||||
.register_unnamed(Some((arg.v_type.clone(), new_space)));
|
||||
visitor.input_argument(old_name, new_name, old_space)?;
|
||||
arg.name = new_name;
|
||||
arg.state_space = new_space;
|
||||
}
|
||||
};
|
||||
for arg in method.return_arguments.iter_mut() {
|
||||
visitor.visit_variable(arg)?;
|
||||
}
|
||||
let return_arguments = &method.return_arguments[..];
|
||||
let body = method
|
||||
.body
|
||||
.map(move |statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(&mut visitor, return_arguments, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'a, 'input>(
|
||||
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
|
||||
return_arguments: &[ast::Variable<SpirvWord>],
|
||||
result: &mut Vec<ExpandedStatement>,
|
||||
statement: ExpandedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Ret { data }) => {
|
||||
let statement = if return_arguments.is_empty() {
|
||||
Statement::Instruction(ast::Instruction::Ret { data })
|
||||
} else {
|
||||
Statement::RetValue(
|
||||
data,
|
||||
return_arguments
|
||||
.iter()
|
||||
.map(|arg| {
|
||||
if arg.state_space != ast::StateSpace::Local {
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
Ok((arg.name, arg.v_type.clone()))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
)
|
||||
};
|
||||
let new_statement = statement.visit_map(visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(new_statement);
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
Statement::Variable(mut var) => {
|
||||
visitor.visit_variable(&mut var)?;
|
||||
result.push(Statement::Variable(var));
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
|
||||
let instruction = visitor.visit_ld(data, arguments)?;
|
||||
let instruction = ast::visit_map(instruction, visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(Statement::Instruction(instruction));
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::St { data, arguments }) => {
|
||||
let instruction = visitor.visit_st(data, arguments)?;
|
||||
let instruction = ast::visit_map(instruction, visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(Statement::Instruction(instruction));
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
Statement::PtrAccess(ptr_access) => {
|
||||
let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
|
||||
let statement = statement.visit_map(visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(statement);
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
s => {
|
||||
let new_statement = s.visit_map(visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(new_statement);
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct InsertMemSSAVisitor<'a, 'input> {
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
variables: FxHashMap<SpirvWord, RemapAction>,
|
||||
pre: Vec<ast::Instruction<SpirvWord>>,
|
||||
post: Vec<ast::Instruction<SpirvWord>>,
|
||||
}
|
||||
|
||||
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||
fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
|
||||
Self {
|
||||
resolver,
|
||||
variables: FxHashMap::default(),
|
||||
pre: Vec::new(),
|
||||
post: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn input_argument(
|
||||
&mut self,
|
||||
old_name: SpirvWord,
|
||||
new_name: SpirvWord,
|
||||
old_space: ast::StateSpace,
|
||||
) -> Result<(), TranslateError> {
|
||||
if old_space != ast::StateSpace::Param {
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
self.variables.insert(
|
||||
old_name,
|
||||
RemapAction::LDStSpaceChange {
|
||||
name: new_name,
|
||||
old_space,
|
||||
new_space: ast::StateSpace::ParamEntry,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn variable(
|
||||
&mut self,
|
||||
type_: &ast::Type,
|
||||
old_name: SpirvWord,
|
||||
new_name: SpirvWord,
|
||||
old_space: ast::StateSpace,
|
||||
) -> Result<bool, TranslateError> {
|
||||
Ok(match old_space {
|
||||
ast::StateSpace::Reg => {
|
||||
self.variables.insert(
|
||||
old_name,
|
||||
RemapAction::PreLdPostSt {
|
||||
name: new_name,
|
||||
type_: type_.clone(),
|
||||
},
|
||||
);
|
||||
true
|
||||
}
|
||||
ast::StateSpace::Param => {
|
||||
self.variables.insert(
|
||||
old_name,
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space: ast::StateSpace::Local,
|
||||
name: new_name,
|
||||
},
|
||||
);
|
||||
true
|
||||
}
|
||||
// Good as-is
|
||||
ast::StateSpace::Local
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::Global
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::SharedCta
|
||||
| ast::StateSpace::Shared
|
||||
| ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc => return Err(error_unreachable()),
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_st(
|
||||
&self,
|
||||
mut data: ast::StData,
|
||||
mut arguments: ast::StArgs<SpirvWord>,
|
||||
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
||||
if let Some(remap) = self.variables.get(&arguments.src1) {
|
||||
match remap {
|
||||
RemapAction::PreLdPostSt { .. } => {}
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space,
|
||||
name,
|
||||
} => {
|
||||
if data.state_space != *old_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
data.state_space = *new_space;
|
||||
arguments.src1 = *name;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(ast::Instruction::St { data, arguments })
|
||||
}
|
||||
|
||||
fn visit_ld(
|
||||
&self,
|
||||
mut data: ast::LdDetails,
|
||||
mut arguments: ast::LdArgs<SpirvWord>,
|
||||
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
||||
if let Some(remap) = self.variables.get(&arguments.src) {
|
||||
match remap {
|
||||
RemapAction::PreLdPostSt { .. } => {}
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space,
|
||||
name,
|
||||
} => {
|
||||
if data.state_space != *old_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
data.state_space = *new_space;
|
||||
arguments.src = *name;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(ast::Instruction::Ld { data, arguments })
|
||||
}
|
||||
|
||||
fn visit_ptr_access(
|
||||
&mut self,
|
||||
ptr_access: PtrAccess<SpirvWord>,
|
||||
) -> Result<PtrAccess<SpirvWord>, TranslateError> {
|
||||
let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) {
|
||||
Some(RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space,
|
||||
name,
|
||||
}) => (*old_space, *new_space, *name),
|
||||
Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access),
|
||||
};
|
||||
if ptr_access.state_space != old_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
// Propagate space changes in dst
|
||||
let new_dst = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ptr_access.underlying_type.clone(), new_space)));
|
||||
self.variables.insert(
|
||||
ptr_access.dst,
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space,
|
||||
name: new_dst,
|
||||
},
|
||||
);
|
||||
Ok(PtrAccess {
|
||||
ptr_src: name,
|
||||
dst: new_dst,
|
||||
state_space: new_space,
|
||||
..ptr_access
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
|
||||
let old_space = match var.state_space {
|
||||
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
|
||||
// Do nothing
|
||||
ptx_parser::StateSpace::Local => return Ok(()),
|
||||
// Handled by another pass
|
||||
ptx_parser::StateSpace::Generic
|
||||
| ptx_parser::StateSpace::SharedCluster
|
||||
| ptx_parser::StateSpace::ParamEntry
|
||||
| ptx_parser::StateSpace::Global
|
||||
| ptx_parser::StateSpace::SharedCta
|
||||
| ptx_parser::StateSpace::Const
|
||||
| ptx_parser::StateSpace::Shared
|
||||
| ptx_parser::StateSpace::ParamFunc => return Ok(()),
|
||||
};
|
||||
let old_name = var.name;
|
||||
let new_space = ast::StateSpace::Local;
|
||||
let new_name = self
|
||||
.resolver
|
||||
.register_unnamed(Some((var.v_type.clone(), new_space)));
|
||||
self.variable(&var.v_type, old_name, new_name, old_space)?;
|
||||
var.name = new_name;
|
||||
var.state_space = new_space;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
|
||||
for InsertMemSSAVisitor<'a, 'input>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
ident: SpirvWord,
|
||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
if let Some(remap) = self.variables.get(&ident) {
|
||||
match remap {
|
||||
RemapAction::PreLdPostSt { name, type_ } => {
|
||||
if is_dst {
|
||||
let temp = self
|
||||
.resolver
|
||||
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
|
||||
self.post.push(ast::Instruction::St {
|
||||
data: ast::StData {
|
||||
state_space: ast::StateSpace::Local,
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
caching: ast::StCacheOperator::Writethrough,
|
||||
typ: type_.clone(),
|
||||
},
|
||||
arguments: ast::StArgs {
|
||||
src1: *name,
|
||||
src2: temp,
|
||||
},
|
||||
});
|
||||
Ok(temp)
|
||||
} else {
|
||||
let temp = self
|
||||
.resolver
|
||||
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
|
||||
self.pre.push(ast::Instruction::Ld {
|
||||
data: ast::LdDetails {
|
||||
state_space: ast::StateSpace::Local,
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
caching: ast::LdCacheOperator::Cached,
|
||||
typ: type_.clone(),
|
||||
non_coherent: false,
|
||||
},
|
||||
arguments: ast::LdArgs {
|
||||
dst: temp,
|
||||
src: *name,
|
||||
},
|
||||
});
|
||||
Ok(temp)
|
||||
}
|
||||
}
|
||||
RemapAction::LDStSpaceChange { .. } => {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(ident)
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: SpirvWord,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
self.visit(args, type_space, is_dst, relaxed_type_check)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum RemapAction {
|
||||
PreLdPostSt {
|
||||
name: SpirvWord,
|
||||
type_: ast::Type,
|
||||
},
|
||||
LDStSpaceChange {
|
||||
old_space: ast::StateSpace,
|
||||
new_space: ast::StateSpace,
|
||||
name: SpirvWord,
|
||||
},
|
||||
}
|
||||
|
@ -1,401 +1,401 @@
|
||||
use std::mem;
|
||||
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
/*
|
||||
There are several kinds of implicit conversions in PTX:
|
||||
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
|
||||
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
|
||||
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
|
||||
semantics are to first zext/chop/bitcast `y` as needed and then do
|
||||
documented special ld/st/cvt conversion rules for destination operands
|
||||
- st.param [x] y (used as function return arguments) same rule as above applies
|
||||
- generic/global ld: for instruction `ld x, [y]`, y must be of type
|
||||
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
|
||||
documented special ld/st/cvt conversion rules are applied to dst
|
||||
- generic/global st: for instruction `st [x], y`, x must be of type
|
||||
b64/u64/s64, which is bitcast to a pointer
|
||||
*/
|
||||
pub(super) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(mut method) => {
|
||||
method.body = method
|
||||
.body
|
||||
.map(|statements| run_statements(resolver, statements))
|
||||
.transpose()?;
|
||||
Directive2::Method(method)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statements<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
func: Vec<ExpandedStatement>,
|
||||
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
for s in func.into_iter() {
|
||||
insert_implicit_conversions_impl(resolver, &mut result, s)?;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn insert_implicit_conversions_impl<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
stmt: ExpandedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut post_conv = Vec::new();
|
||||
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
|
||||
&mut |operand,
|
||||
type_state: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst,
|
||||
relaxed_type_check| {
|
||||
let (instr_type, instruction_space) = match type_state {
|
||||
None => return Ok(operand),
|
||||
Some(t) => t,
|
||||
};
|
||||
let (operand_type, operand_space) = resolver.get_typed(operand)?;
|
||||
let conversion_fn = if relaxed_type_check {
|
||||
if is_dst {
|
||||
should_convert_relaxed_dst_wrapper
|
||||
} else {
|
||||
should_convert_relaxed_src_wrapper
|
||||
}
|
||||
} else {
|
||||
default_implicit_conversion
|
||||
};
|
||||
match conversion_fn(
|
||||
(*operand_space, &operand_type),
|
||||
(instruction_space, instr_type),
|
||||
)? {
|
||||
Some(conv_kind) => {
|
||||
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
|
||||
let mut from_type = instr_type.clone();
|
||||
let mut from_space = instruction_space;
|
||||
let mut to_type = operand_type.clone();
|
||||
let mut to_space = *operand_space;
|
||||
let mut src =
|
||||
resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
|
||||
let mut dst = operand;
|
||||
let result = Ok::<_, TranslateError>(src);
|
||||
if !is_dst {
|
||||
mem::swap(&mut src, &mut dst);
|
||||
mem::swap(&mut from_type, &mut to_type);
|
||||
mem::swap(&mut from_space, &mut to_space);
|
||||
}
|
||||
conv_output.push(Statement::Conversion(ImplicitConversion {
|
||||
src,
|
||||
dst,
|
||||
from_type,
|
||||
from_space,
|
||||
to_type,
|
||||
to_space,
|
||||
kind: conv_kind,
|
||||
}));
|
||||
result
|
||||
}
|
||||
None => Ok(operand),
|
||||
}
|
||||
},
|
||||
)?;
|
||||
func.push(statement);
|
||||
func.append(&mut post_conv);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn default_implicit_conversion(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if instruction_space == ast::StateSpace::Reg {
|
||||
if operand_space == ast::StateSpace::Reg {
|
||||
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
|
||||
(operand_type, instruction_type)
|
||||
{
|
||||
if scalar.kind() == ast::ScalarKind::Bit
|
||||
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
|
||||
{
|
||||
return Ok(Some(ConversionKind::Default));
|
||||
}
|
||||
}
|
||||
} else if is_addressable(operand_space) {
|
||||
return Ok(Some(ConversionKind::AddressOf));
|
||||
}
|
||||
}
|
||||
if instruction_space != operand_space {
|
||||
default_implicit_conversion_space((operand_space, operand_type), instruction_space)
|
||||
} else if instruction_type != operand_type {
|
||||
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn is_addressable(this: ast::StateSpace) -> bool {
|
||||
match this {
|
||||
ast::StateSpace::Const
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Global
|
||||
| ast::StateSpace::Local
|
||||
| ast::StateSpace::Shared => true,
|
||||
ast::StateSpace::Param | ast::StateSpace::Reg => false,
|
||||
ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::SharedCta
|
||||
| ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
// Space is different
|
||||
fn default_implicit_conversion_space(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
instruction_space: ast::StateSpace,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|
||||
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
|
||||
{
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
} else if operand_space == ast::StateSpace::Reg {
|
||||
match operand_type {
|
||||
// TODO: 32 bit
|
||||
ast::Type::Scalar(ast::ScalarType::B64)
|
||||
| ast::Type::Scalar(ast::ScalarType::U64)
|
||||
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
|
||||
ast::StateSpace::Global
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::Local
|
||||
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
|
||||
_ => Err(error_mismatched_type()),
|
||||
},
|
||||
ast::Type::Scalar(ast::ScalarType::B32)
|
||||
| ast::Type::Scalar(ast::ScalarType::U32)
|
||||
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
|
||||
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
|
||||
Ok(Some(ConversionKind::BitToPtr))
|
||||
}
|
||||
_ => Err(error_mismatched_type()),
|
||||
},
|
||||
_ => Err(error_mismatched_type()),
|
||||
}
|
||||
} else {
|
||||
Err(error_mismatched_type())
|
||||
}
|
||||
}
|
||||
|
||||
// Space is same, but type is different
|
||||
fn default_implicit_conversion_type(
|
||||
space: ast::StateSpace,
|
||||
operand_type: &ast::Type,
|
||||
instruction_type: &ast::Type,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if space == ast::StateSpace::Reg {
|
||||
if should_bitcast(instruction_type, operand_type) {
|
||||
Ok(Some(ConversionKind::Default))
|
||||
} else {
|
||||
Err(TranslateError::MismatchedType)
|
||||
}
|
||||
} else {
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
}
|
||||
}
|
||||
|
||||
fn coerces_to_generic(this: ast::StateSpace) -> bool {
|
||||
match this {
|
||||
ast::StateSpace::Global
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::Local
|
||||
| ptx_parser::StateSpace::SharedCta
|
||||
| ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::Shared => true,
|
||||
ast::StateSpace::Reg
|
||||
| ast::StateSpace::Param
|
||||
| ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc
|
||||
| ast::StateSpace::Generic => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
|
||||
match (instr, operand) {
|
||||
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
||||
if inst.size_of() != operand.size_of() {
|
||||
return false;
|
||||
}
|
||||
match inst.kind() {
|
||||
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
|
||||
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
|
||||
ast::ScalarKind::Signed => {
|
||||
operand.kind() == ast::ScalarKind::Bit
|
||||
|| operand.kind() == ast::ScalarKind::Unsigned
|
||||
}
|
||||
ast::ScalarKind::Unsigned => {
|
||||
operand.kind() == ast::ScalarKind::Bit
|
||||
|| operand.kind() == ast::ScalarKind::Signed
|
||||
}
|
||||
ast::ScalarKind::Pred => false,
|
||||
}
|
||||
}
|
||||
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
|
||||
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
|
||||
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn should_convert_relaxed_dst_wrapper(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if operand_space != instruction_space {
|
||||
return Err(TranslateError::MismatchedType);
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_dst(operand_type, instruction_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
|
||||
fn should_convert_relaxed_dst(
|
||||
dst_type: &ast::Type,
|
||||
instr_type: &ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if dst_type == instr_type {
|
||||
return None;
|
||||
}
|
||||
match (dst_type, instr_type) {
|
||||
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||
ast::ScalarKind::Bit => {
|
||||
if instr_type.size_of() <= dst_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Signed => {
|
||||
if dst_type.kind() != ast::ScalarKind::Float {
|
||||
if instr_type.size_of() == dst_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else if instr_type.size_of() < dst_type.size_of() {
|
||||
Some(ConversionKind::SignExtend)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Unsigned => {
|
||||
if instr_type.size_of() <= dst_type.size_of()
|
||||
&& dst_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Float => {
|
||||
if instr_type.size_of() <= dst_type.size_of()
|
||||
&& dst_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Pred => None,
|
||||
},
|
||||
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||
should_convert_relaxed_dst(
|
||||
&ast::Type::Scalar(*dst_type),
|
||||
&ast::Type::Scalar(*instr_type),
|
||||
)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn should_convert_relaxed_src_wrapper(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if operand_space != instruction_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_src(operand_type, instruction_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(error_mismatched_type()),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
|
||||
fn should_convert_relaxed_src(
|
||||
src_type: &ast::Type,
|
||||
instr_type: &ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if src_type == instr_type {
|
||||
return None;
|
||||
}
|
||||
match (src_type, instr_type) {
|
||||
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||
ast::ScalarKind::Bit => {
|
||||
if instr_type.size_of() <= src_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
|
||||
if instr_type.size_of() <= src_type.size_of()
|
||||
&& src_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Float => {
|
||||
if instr_type.size_of() <= src_type.size_of()
|
||||
&& src_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Pred => None,
|
||||
},
|
||||
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||
should_convert_relaxed_src(
|
||||
&ast::Type::Scalar(*dst_type),
|
||||
&ast::Type::Scalar(*instr_type),
|
||||
)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
use std::mem;
|
||||
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
/*
|
||||
There are several kinds of implicit conversions in PTX:
|
||||
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
|
||||
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
|
||||
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
|
||||
semantics are to first zext/chop/bitcast `y` as needed and then do
|
||||
documented special ld/st/cvt conversion rules for destination operands
|
||||
- st.param [x] y (used as function return arguments) same rule as above applies
|
||||
- generic/global ld: for instruction `ld x, [y]`, y must be of type
|
||||
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
|
||||
documented special ld/st/cvt conversion rules are applied to dst
|
||||
- generic/global st: for instruction `st [x], y`, x must be of type
|
||||
b64/u64/s64, which is bitcast to a pointer
|
||||
*/
|
||||
pub(super) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(mut method) => {
|
||||
method.body = method
|
||||
.body
|
||||
.map(|statements| run_statements(resolver, statements))
|
||||
.transpose()?;
|
||||
Directive2::Method(method)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statements<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
func: Vec<ExpandedStatement>,
|
||||
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
for s in func.into_iter() {
|
||||
insert_implicit_conversions_impl(resolver, &mut result, s)?;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn insert_implicit_conversions_impl<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
stmt: ExpandedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut post_conv = Vec::new();
|
||||
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
|
||||
&mut |operand,
|
||||
type_state: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst,
|
||||
relaxed_type_check| {
|
||||
let (instr_type, instruction_space) = match type_state {
|
||||
None => return Ok(operand),
|
||||
Some(t) => t,
|
||||
};
|
||||
let (operand_type, operand_space) = resolver.get_typed(operand)?;
|
||||
let conversion_fn = if relaxed_type_check {
|
||||
if is_dst {
|
||||
should_convert_relaxed_dst_wrapper
|
||||
} else {
|
||||
should_convert_relaxed_src_wrapper
|
||||
}
|
||||
} else {
|
||||
default_implicit_conversion
|
||||
};
|
||||
match conversion_fn(
|
||||
(*operand_space, &operand_type),
|
||||
(instruction_space, instr_type),
|
||||
)? {
|
||||
Some(conv_kind) => {
|
||||
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
|
||||
let mut from_type = instr_type.clone();
|
||||
let mut from_space = instruction_space;
|
||||
let mut to_type = operand_type.clone();
|
||||
let mut to_space = *operand_space;
|
||||
let mut src =
|
||||
resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
|
||||
let mut dst = operand;
|
||||
let result = Ok::<_, TranslateError>(src);
|
||||
if !is_dst {
|
||||
mem::swap(&mut src, &mut dst);
|
||||
mem::swap(&mut from_type, &mut to_type);
|
||||
mem::swap(&mut from_space, &mut to_space);
|
||||
}
|
||||
conv_output.push(Statement::Conversion(ImplicitConversion {
|
||||
src,
|
||||
dst,
|
||||
from_type,
|
||||
from_space,
|
||||
to_type,
|
||||
to_space,
|
||||
kind: conv_kind,
|
||||
}));
|
||||
result
|
||||
}
|
||||
None => Ok(operand),
|
||||
}
|
||||
},
|
||||
)?;
|
||||
func.push(statement);
|
||||
func.append(&mut post_conv);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn default_implicit_conversion(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if instruction_space == ast::StateSpace::Reg {
|
||||
if operand_space == ast::StateSpace::Reg {
|
||||
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
|
||||
(operand_type, instruction_type)
|
||||
{
|
||||
if scalar.kind() == ast::ScalarKind::Bit
|
||||
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
|
||||
{
|
||||
return Ok(Some(ConversionKind::Default));
|
||||
}
|
||||
}
|
||||
} else if is_addressable(operand_space) {
|
||||
return Ok(Some(ConversionKind::AddressOf));
|
||||
}
|
||||
}
|
||||
if instruction_space != operand_space {
|
||||
default_implicit_conversion_space((operand_space, operand_type), instruction_space)
|
||||
} else if instruction_type != operand_type {
|
||||
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn is_addressable(this: ast::StateSpace) -> bool {
|
||||
match this {
|
||||
ast::StateSpace::Const
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Global
|
||||
| ast::StateSpace::Local
|
||||
| ast::StateSpace::Shared => true,
|
||||
ast::StateSpace::Param | ast::StateSpace::Reg => false,
|
||||
ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::SharedCta
|
||||
| ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
// Space is different
|
||||
fn default_implicit_conversion_space(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
instruction_space: ast::StateSpace,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|
||||
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
|
||||
{
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
} else if operand_space == ast::StateSpace::Reg {
|
||||
match operand_type {
|
||||
// TODO: 32 bit
|
||||
ast::Type::Scalar(ast::ScalarType::B64)
|
||||
| ast::Type::Scalar(ast::ScalarType::U64)
|
||||
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
|
||||
ast::StateSpace::Global
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::Local
|
||||
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
|
||||
_ => Err(error_mismatched_type()),
|
||||
},
|
||||
ast::Type::Scalar(ast::ScalarType::B32)
|
||||
| ast::Type::Scalar(ast::ScalarType::U32)
|
||||
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
|
||||
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
|
||||
Ok(Some(ConversionKind::BitToPtr))
|
||||
}
|
||||
_ => Err(error_mismatched_type()),
|
||||
},
|
||||
_ => Err(error_mismatched_type()),
|
||||
}
|
||||
} else {
|
||||
Err(error_mismatched_type())
|
||||
}
|
||||
}
|
||||
|
||||
// Space is same, but type is different
|
||||
fn default_implicit_conversion_type(
|
||||
space: ast::StateSpace,
|
||||
operand_type: &ast::Type,
|
||||
instruction_type: &ast::Type,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if space == ast::StateSpace::Reg {
|
||||
if should_bitcast(instruction_type, operand_type) {
|
||||
Ok(Some(ConversionKind::Default))
|
||||
} else {
|
||||
Err(TranslateError::MismatchedType)
|
||||
}
|
||||
} else {
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
}
|
||||
}
|
||||
|
||||
fn coerces_to_generic(this: ast::StateSpace) -> bool {
|
||||
match this {
|
||||
ast::StateSpace::Global
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::Local
|
||||
| ptx_parser::StateSpace::SharedCta
|
||||
| ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::Shared => true,
|
||||
ast::StateSpace::Reg
|
||||
| ast::StateSpace::Param
|
||||
| ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc
|
||||
| ast::StateSpace::Generic => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
|
||||
match (instr, operand) {
|
||||
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
||||
if inst.size_of() != operand.size_of() {
|
||||
return false;
|
||||
}
|
||||
match inst.kind() {
|
||||
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
|
||||
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
|
||||
ast::ScalarKind::Signed => {
|
||||
operand.kind() == ast::ScalarKind::Bit
|
||||
|| operand.kind() == ast::ScalarKind::Unsigned
|
||||
}
|
||||
ast::ScalarKind::Unsigned => {
|
||||
operand.kind() == ast::ScalarKind::Bit
|
||||
|| operand.kind() == ast::ScalarKind::Signed
|
||||
}
|
||||
ast::ScalarKind::Pred => false,
|
||||
}
|
||||
}
|
||||
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
|
||||
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
|
||||
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn should_convert_relaxed_dst_wrapper(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if operand_space != instruction_space {
|
||||
return Err(TranslateError::MismatchedType);
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_dst(operand_type, instruction_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
|
||||
fn should_convert_relaxed_dst(
|
||||
dst_type: &ast::Type,
|
||||
instr_type: &ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if dst_type == instr_type {
|
||||
return None;
|
||||
}
|
||||
match (dst_type, instr_type) {
|
||||
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||
ast::ScalarKind::Bit => {
|
||||
if instr_type.size_of() <= dst_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Signed => {
|
||||
if dst_type.kind() != ast::ScalarKind::Float {
|
||||
if instr_type.size_of() == dst_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else if instr_type.size_of() < dst_type.size_of() {
|
||||
Some(ConversionKind::SignExtend)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Unsigned => {
|
||||
if instr_type.size_of() <= dst_type.size_of()
|
||||
&& dst_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Float => {
|
||||
if instr_type.size_of() <= dst_type.size_of()
|
||||
&& dst_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Pred => None,
|
||||
},
|
||||
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||
should_convert_relaxed_dst(
|
||||
&ast::Type::Scalar(*dst_type),
|
||||
&ast::Type::Scalar(*instr_type),
|
||||
)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn should_convert_relaxed_src_wrapper(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if operand_space != instruction_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_src(operand_type, instruction_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(error_mismatched_type()),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
|
||||
fn should_convert_relaxed_src(
|
||||
src_type: &ast::Type,
|
||||
instr_type: &ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if src_type == instr_type {
|
||||
return None;
|
||||
}
|
||||
match (src_type, instr_type) {
|
||||
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||
ast::ScalarKind::Bit => {
|
||||
if instr_type.size_of() <= src_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
|
||||
if instr_type.size_of() <= src_type.size_of()
|
||||
&& src_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Float => {
|
||||
if instr_type.size_of() <= src_type.size_of()
|
||||
&& src_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Pred => None,
|
||||
},
|
||||
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||
should_convert_relaxed_src(
|
||||
&ast::Type::Scalar(*dst_type),
|
||||
&ast::Type::Scalar(*instr_type),
|
||||
)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
@ -1,194 +1,194 @@
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(crate) fn run<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<Vec<NormalizedDirective2>, TranslateError> {
|
||||
resolver.start_scope();
|
||||
let result = directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
resolver.end_scope();
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn run_directive<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
||||
) -> Result<NormalizedDirective2, TranslateError> {
|
||||
Ok(match directive {
|
||||
ast::Directive::Variable(linking, var) => {
|
||||
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
|
||||
}
|
||||
ast::Directive::Method(linking, directive) => {
|
||||
NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
linkage: ast::LinkingDirective,
|
||||
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<NormalizedFunction2, TranslateError> {
|
||||
let is_kernel = method.func_directive.name.is_kernel();
|
||||
let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?;
|
||||
resolver.start_scope();
|
||||
let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?;
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
run_statements(resolver, &mut result, statements)?;
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
resolver.end_scope();
|
||||
Ok(Function2 {
|
||||
return_arguments,
|
||||
name,
|
||||
input_arguments,
|
||||
body,
|
||||
import_as: None,
|
||||
linkage,
|
||||
is_kernel,
|
||||
tuning: method.tuning,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_function_decl<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
||||
) -> Result<(Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
|
||||
assert!(func_directive.shared_mem.is_none());
|
||||
let return_arguments = func_directive
|
||||
.return_arguments
|
||||
.into_iter()
|
||||
.map(|var| run_variable(resolver, var))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let input_arguments = func_directive
|
||||
.input_arguments
|
||||
.into_iter()
|
||||
.map(|var| run_variable(resolver, var))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok((return_arguments, input_arguments))
|
||||
}
|
||||
|
||||
fn run_variable<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
variable: ast::Variable<&'input str>,
|
||||
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
|
||||
Ok(ast::Variable {
|
||||
name: resolver.add(
|
||||
Cow::Borrowed(variable.name),
|
||||
Some((variable.v_type.clone(), variable.state_space)),
|
||||
)?,
|
||||
align: variable.align,
|
||||
v_type: variable.v_type,
|
||||
state_space: variable.state_space,
|
||||
array_init: variable.array_init,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statements<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<(), TranslateError> {
|
||||
for statement in statements.iter() {
|
||||
match statement {
|
||||
ast::Statement::Label(label) => {
|
||||
resolver.add(Cow::Borrowed(*label), None)?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
for statement in statements {
|
||||
match statement {
|
||||
ast::Statement::Label(label) => {
|
||||
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
|
||||
}
|
||||
ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
|
||||
ast::Statement::Instruction(predicate, instruction) => {
|
||||
result.push(Statement::Instruction((
|
||||
predicate
|
||||
.map(|pred| {
|
||||
Ok::<_, TranslateError>(ast::PredAt {
|
||||
not: pred.not,
|
||||
label: resolver.get(pred.label)?,
|
||||
})
|
||||
})
|
||||
.transpose()?,
|
||||
run_instruction(resolver, instruction)?,
|
||||
)))
|
||||
}
|
||||
ast::Statement::Block(block) => {
|
||||
resolver.start_scope();
|
||||
run_statements(resolver, result, block)?;
|
||||
resolver.end_scope();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_instruction<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
|
||||
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
||||
ast::visit_map(instruction, &mut |name: &'input str,
|
||||
_: Option<(
|
||||
&ast::Type,
|
||||
ast::StateSpace,
|
||||
)>,
|
||||
_,
|
||||
_| {
|
||||
resolver.get(&name)
|
||||
})
|
||||
}
|
||||
|
||||
fn run_multivariable<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
variable: ast::MultiVariable<&'input str>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match variable.count {
|
||||
Some(count) => {
|
||||
for i in 0..count {
|
||||
let name = Cow::Owned(format!("{}{}", variable.var.name, i));
|
||||
let ident = resolver.add(
|
||||
name,
|
||||
Some((variable.var.v_type.clone(), variable.var.state_space)),
|
||||
)?;
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: variable.var.align,
|
||||
v_type: variable.var.v_type.clone(),
|
||||
state_space: variable.var.state_space,
|
||||
name: ident,
|
||||
array_init: variable.var.array_init.clone(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let name = Cow::Borrowed(variable.var.name);
|
||||
let ident = resolver.add(
|
||||
name,
|
||||
Some((variable.var.v_type.clone(), variable.var.state_space)),
|
||||
)?;
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: variable.var.align,
|
||||
v_type: variable.var.v_type.clone(),
|
||||
state_space: variable.var.state_space,
|
||||
name: ident,
|
||||
array_init: variable.var.array_init.clone(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(crate) fn run<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<Vec<NormalizedDirective2>, TranslateError> {
|
||||
resolver.start_scope();
|
||||
let result = directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
resolver.end_scope();
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn run_directive<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
||||
) -> Result<NormalizedDirective2, TranslateError> {
|
||||
Ok(match directive {
|
||||
ast::Directive::Variable(linking, var) => {
|
||||
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
|
||||
}
|
||||
ast::Directive::Method(linking, directive) => {
|
||||
NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
linkage: ast::LinkingDirective,
|
||||
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<NormalizedFunction2, TranslateError> {
|
||||
let is_kernel = method.func_directive.name.is_kernel();
|
||||
let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?;
|
||||
resolver.start_scope();
|
||||
let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?;
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
run_statements(resolver, &mut result, statements)?;
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
resolver.end_scope();
|
||||
Ok(Function2 {
|
||||
return_arguments,
|
||||
name,
|
||||
input_arguments,
|
||||
body,
|
||||
import_as: None,
|
||||
linkage,
|
||||
is_kernel,
|
||||
tuning: method.tuning,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_function_decl<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
||||
) -> Result<(Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
|
||||
assert!(func_directive.shared_mem.is_none());
|
||||
let return_arguments = func_directive
|
||||
.return_arguments
|
||||
.into_iter()
|
||||
.map(|var| run_variable(resolver, var))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let input_arguments = func_directive
|
||||
.input_arguments
|
||||
.into_iter()
|
||||
.map(|var| run_variable(resolver, var))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok((return_arguments, input_arguments))
|
||||
}
|
||||
|
||||
fn run_variable<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
variable: ast::Variable<&'input str>,
|
||||
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
|
||||
Ok(ast::Variable {
|
||||
name: resolver.add(
|
||||
Cow::Borrowed(variable.name),
|
||||
Some((variable.v_type.clone(), variable.state_space)),
|
||||
)?,
|
||||
align: variable.align,
|
||||
v_type: variable.v_type,
|
||||
state_space: variable.state_space,
|
||||
array_init: variable.array_init,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statements<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<(), TranslateError> {
|
||||
for statement in statements.iter() {
|
||||
match statement {
|
||||
ast::Statement::Label(label) => {
|
||||
resolver.add(Cow::Borrowed(*label), None)?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
for statement in statements {
|
||||
match statement {
|
||||
ast::Statement::Label(label) => {
|
||||
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
|
||||
}
|
||||
ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
|
||||
ast::Statement::Instruction(predicate, instruction) => {
|
||||
result.push(Statement::Instruction((
|
||||
predicate
|
||||
.map(|pred| {
|
||||
Ok::<_, TranslateError>(ast::PredAt {
|
||||
not: pred.not,
|
||||
label: resolver.get(pred.label)?,
|
||||
})
|
||||
})
|
||||
.transpose()?,
|
||||
run_instruction(resolver, instruction)?,
|
||||
)))
|
||||
}
|
||||
ast::Statement::Block(block) => {
|
||||
resolver.start_scope();
|
||||
run_statements(resolver, result, block)?;
|
||||
resolver.end_scope();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_instruction<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
|
||||
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
||||
ast::visit_map(instruction, &mut |name: &'input str,
|
||||
_: Option<(
|
||||
&ast::Type,
|
||||
ast::StateSpace,
|
||||
)>,
|
||||
_,
|
||||
_| {
|
||||
resolver.get(&name)
|
||||
})
|
||||
}
|
||||
|
||||
fn run_multivariable<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
variable: ast::MultiVariable<&'input str>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match variable.count {
|
||||
Some(count) => {
|
||||
for i in 0..count {
|
||||
let name = Cow::Owned(format!("{}{}", variable.var.name, i));
|
||||
let ident = resolver.add(
|
||||
name,
|
||||
Some((variable.var.v_type.clone(), variable.var.state_space)),
|
||||
)?;
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: variable.var.align,
|
||||
v_type: variable.var.v_type.clone(),
|
||||
state_space: variable.var.state_space,
|
||||
name: ident,
|
||||
array_init: variable.var.array_init.clone(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let name = Cow::Borrowed(variable.var.name);
|
||||
let ident = resolver.add(
|
||||
name,
|
||||
Some((variable.var.v_type.clone(), variable.var.state_space)),
|
||||
)?;
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: variable.var.align,
|
||||
v_type: variable.var.v_type.clone(),
|
||||
state_space: variable.var.state_space,
|
||||
name: ident,
|
||||
array_init: variable.var.array_init.clone(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,90 +1,90 @@
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<NormalizedDirective2>,
|
||||
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: NormalizedDirective2,
|
||||
) -> Result<UnconditionalDirective, TranslateError> {
|
||||
Ok(match directive {
|
||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
method: NormalizedFunction2,
|
||||
) -> Result<UnconditionalFunction, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(resolver, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
body,
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
flush_to_zero_f32: method.flush_to_zero_f32,
|
||||
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
||||
rounding_mode_f32: method.rounding_mode_f32,
|
||||
rounding_mode_f16f64: method.rounding_mode_f16f64,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
result: &mut Vec<UnconditionalStatement>,
|
||||
statement: NormalizedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
Ok(match statement {
|
||||
Statement::Label(label) => result.push(Statement::Label(label)),
|
||||
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
||||
Statement::Instruction((predicate, instruction)) => {
|
||||
if let Some(pred) = predicate {
|
||||
let if_true = resolver.register_unnamed(None);
|
||||
let if_false = resolver.register_unnamed(None);
|
||||
let folded_bra = match &instruction {
|
||||
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
||||
_ => None,
|
||||
};
|
||||
let mut branch = BrachCondition {
|
||||
predicate: pred.label,
|
||||
if_true: folded_bra.unwrap_or(if_true),
|
||||
if_false,
|
||||
};
|
||||
if pred.not {
|
||||
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||
}
|
||||
result.push(Statement::Conditional(branch));
|
||||
if folded_bra.is_none() {
|
||||
result.push(Statement::Label(if_true));
|
||||
result.push(Statement::Instruction(instruction));
|
||||
}
|
||||
result.push(Statement::Label(if_false));
|
||||
} else {
|
||||
result.push(Statement::Instruction(instruction));
|
||||
}
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
})
|
||||
}
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<NormalizedDirective2>,
|
||||
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: NormalizedDirective2,
|
||||
) -> Result<UnconditionalDirective, TranslateError> {
|
||||
Ok(match directive {
|
||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
method: NormalizedFunction2,
|
||||
) -> Result<UnconditionalFunction, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(resolver, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
body,
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
flush_to_zero_f32: method.flush_to_zero_f32,
|
||||
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
||||
rounding_mode_f32: method.rounding_mode_f32,
|
||||
rounding_mode_f16f64: method.rounding_mode_f16f64,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
result: &mut Vec<UnconditionalStatement>,
|
||||
statement: NormalizedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
Ok(match statement {
|
||||
Statement::Label(label) => result.push(Statement::Label(label)),
|
||||
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
||||
Statement::Instruction((predicate, instruction)) => {
|
||||
if let Some(pred) = predicate {
|
||||
let if_true = resolver.register_unnamed(None);
|
||||
let if_false = resolver.register_unnamed(None);
|
||||
let folded_bra = match &instruction {
|
||||
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
||||
_ => None,
|
||||
};
|
||||
let mut branch = BrachCondition {
|
||||
predicate: pred.label,
|
||||
if_true: folded_bra.unwrap_or(if_true),
|
||||
if_false,
|
||||
};
|
||||
if pred.not {
|
||||
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||
}
|
||||
result.push(Statement::Conditional(branch));
|
||||
if folded_bra.is_none() {
|
||||
result.push(Statement::Label(if_true));
|
||||
result.push(Statement::Instruction(instruction));
|
||||
}
|
||||
result.push(Statement::Label(if_false));
|
||||
} else {
|
||||
result.push(Statement::Instruction(instruction));
|
||||
}
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
})
|
||||
}
|
||||
|
@ -1,268 +1,268 @@
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let mut fn_declarations = FxHashMap::default();
|
||||
let remapped_directives = directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, &mut fn_declarations, directive))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let mut result = fn_declarations
|
||||
.into_iter()
|
||||
.map(|(_, (return_arguments, name, input_arguments))| {
|
||||
Directive2::Method(Function2 {
|
||||
return_arguments,
|
||||
name: name,
|
||||
input_arguments,
|
||||
body: None,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
is_kernel: false,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
result.extend(remapped_directives);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(mut method) => {
|
||||
method.body = method
|
||||
.body
|
||||
.map(|statements| run_statements(resolver, fn_declarations, statements))
|
||||
.transpose()?;
|
||||
Directive2::Method(method)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statements<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|statement| {
|
||||
Ok(match statement {
|
||||
Statement::Instruction(instruction) => {
|
||||
Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
|
||||
}
|
||||
s => s,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_instruction<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
instruction: ptx_parser::Instruction<SpirvWord>,
|
||||
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
|
||||
Ok(match instruction {
|
||||
i @ ptx_parser::Instruction::Sqrt {
|
||||
data:
|
||||
ast::RcpData {
|
||||
kind: ast::RcpKind::Approx,
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: None | Some(false),
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Rsqrt {
|
||||
data:
|
||||
ast::TypeFtz {
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: None | Some(false),
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Rcp {
|
||||
data:
|
||||
ast::RcpData {
|
||||
kind: ast::RcpKind::Approx,
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: None | Some(false),
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Ex2 {
|
||||
data:
|
||||
ast::TypeFtz {
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: None | Some(false),
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Lg2 {
|
||||
data: ast::FlushToZero {
|
||||
flush_to_zero: false,
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Activemask { .. } => {
|
||||
to_call(resolver, fn_declarations, "activemask".into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Bfe { data, .. } => {
|
||||
let name = ["bfe_", scalar_to_ptx_name(data)].concat();
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Bfi { data, .. } => {
|
||||
let name = ["bfi_", scalar_to_ptx_name(data)].concat();
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Bar { .. } => {
|
||||
to_call(resolver, fn_declarations, "bar_sync".into(), i)?
|
||||
}
|
||||
ptx_parser::Instruction::BarRed { data, arguments } => {
|
||||
if arguments.src_threadcount.is_some() {
|
||||
return Err(error_todo());
|
||||
}
|
||||
let name = match data.pred_reduction {
|
||||
ptx_parser::Reduction::And => "bar_red_and_pred",
|
||||
ptx_parser::Reduction::Or => "bar_red_or_pred",
|
||||
};
|
||||
to_call(
|
||||
resolver,
|
||||
fn_declarations,
|
||||
name.into(),
|
||||
ptx_parser::Instruction::BarRed { data, arguments },
|
||||
)?
|
||||
}
|
||||
ptx_parser::Instruction::ShflSync { data, arguments } => {
|
||||
let mode = match data.mode {
|
||||
ptx_parser::ShuffleMode::Up => "up",
|
||||
ptx_parser::ShuffleMode::Down => "down",
|
||||
ptx_parser::ShuffleMode::BFly => "bfly",
|
||||
ptx_parser::ShuffleMode::Idx => "idx",
|
||||
};
|
||||
let pred = if arguments.dst_pred.is_some() {
|
||||
"_pred"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
to_call(
|
||||
resolver,
|
||||
fn_declarations,
|
||||
format!("shfl_sync_{}_b32{}", mode, pred).into(),
|
||||
ptx_parser::Instruction::ShflSync { data, arguments },
|
||||
)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Nanosleep { .. } => {
|
||||
to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)?
|
||||
}
|
||||
i => i,
|
||||
})
|
||||
}
|
||||
|
||||
fn to_call<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
name: Cow<'input, str>,
|
||||
i: ast::Instruction<SpirvWord>,
|
||||
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
|
||||
let mut data_return = Vec::new();
|
||||
let mut data_input = Vec::new();
|
||||
let mut arguments_return = Vec::new();
|
||||
let mut arguments_input = Vec::new();
|
||||
ast::visit(&i, &mut |name: &SpirvWord,
|
||||
type_space: Option<(
|
||||
&ptx_parser::Type,
|
||||
ptx_parser::StateSpace,
|
||||
)>,
|
||||
is_dst: bool,
|
||||
_: bool| {
|
||||
let (type_, space) = type_space.ok_or_else(error_mismatched_type)?;
|
||||
if is_dst {
|
||||
data_return.push((type_.clone(), space));
|
||||
arguments_return.push(*name);
|
||||
} else {
|
||||
data_input.push((type_.clone(), space));
|
||||
arguments_input.push(*name);
|
||||
};
|
||||
Ok::<_, TranslateError>(())
|
||||
})?;
|
||||
let fn_name = match fn_declarations.entry(name) {
|
||||
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
|
||||
hash_map::Entry::Vacant(vacant_entry) => {
|
||||
let name = vacant_entry.key().clone();
|
||||
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
|
||||
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
|
||||
vacant_entry.insert((
|
||||
to_variables(resolver, &data_return),
|
||||
name,
|
||||
to_variables(resolver, &data_input),
|
||||
));
|
||||
name
|
||||
}
|
||||
};
|
||||
Ok(ast::Instruction::Call {
|
||||
data: ptx_parser::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: data_return,
|
||||
input_arguments: data_input,
|
||||
},
|
||||
arguments: ptx_parser::CallArgs {
|
||||
return_arguments: arguments_return,
|
||||
func: fn_name,
|
||||
input_arguments: arguments_input,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn to_variables<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
) -> Vec<ptx_parser::Variable<SpirvWord>> {
|
||||
arguments
|
||||
.iter()
|
||||
.map(|(type_, space)| ast::Variable {
|
||||
align: None,
|
||||
v_type: type_.clone(),
|
||||
state_space: *space,
|
||||
name: resolver.register_unnamed(Some((type_.clone(), *space))),
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let mut fn_declarations = FxHashMap::default();
|
||||
let remapped_directives = directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, &mut fn_declarations, directive))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let mut result = fn_declarations
|
||||
.into_iter()
|
||||
.map(|(_, (return_arguments, name, input_arguments))| {
|
||||
Directive2::Method(Function2 {
|
||||
return_arguments,
|
||||
name: name,
|
||||
input_arguments,
|
||||
body: None,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
is_kernel: false,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
result.extend(remapped_directives);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(mut method) => {
|
||||
method.body = method
|
||||
.body
|
||||
.map(|statements| run_statements(resolver, fn_declarations, statements))
|
||||
.transpose()?;
|
||||
Directive2::Method(method)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statements<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|statement| {
|
||||
Ok(match statement {
|
||||
Statement::Instruction(instruction) => {
|
||||
Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
|
||||
}
|
||||
s => s,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_instruction<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
instruction: ptx_parser::Instruction<SpirvWord>,
|
||||
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
|
||||
Ok(match instruction {
|
||||
i @ ptx_parser::Instruction::Sqrt {
|
||||
data:
|
||||
ast::RcpData {
|
||||
kind: ast::RcpKind::Approx,
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: None | Some(false),
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Rsqrt {
|
||||
data:
|
||||
ast::TypeFtz {
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: None | Some(false),
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Rcp {
|
||||
data:
|
||||
ast::RcpData {
|
||||
kind: ast::RcpKind::Approx,
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: None | Some(false),
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Ex2 {
|
||||
data:
|
||||
ast::TypeFtz {
|
||||
type_: ast::ScalarType::F32,
|
||||
flush_to_zero: None | Some(false),
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Lg2 {
|
||||
data: ast::FlushToZero {
|
||||
flush_to_zero: false,
|
||||
},
|
||||
..
|
||||
} => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?,
|
||||
i @ ptx_parser::Instruction::Activemask { .. } => {
|
||||
to_call(resolver, fn_declarations, "activemask".into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Bfe { data, .. } => {
|
||||
let name = ["bfe_", scalar_to_ptx_name(data)].concat();
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Bfi { data, .. } => {
|
||||
let name = ["bfi_", scalar_to_ptx_name(data)].concat();
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Bar { .. } => {
|
||||
to_call(resolver, fn_declarations, "bar_sync".into(), i)?
|
||||
}
|
||||
ptx_parser::Instruction::BarRed { data, arguments } => {
|
||||
if arguments.src_threadcount.is_some() {
|
||||
return Err(error_todo());
|
||||
}
|
||||
let name = match data.pred_reduction {
|
||||
ptx_parser::Reduction::And => "bar_red_and_pred",
|
||||
ptx_parser::Reduction::Or => "bar_red_or_pred",
|
||||
};
|
||||
to_call(
|
||||
resolver,
|
||||
fn_declarations,
|
||||
name.into(),
|
||||
ptx_parser::Instruction::BarRed { data, arguments },
|
||||
)?
|
||||
}
|
||||
ptx_parser::Instruction::ShflSync { data, arguments } => {
|
||||
let mode = match data.mode {
|
||||
ptx_parser::ShuffleMode::Up => "up",
|
||||
ptx_parser::ShuffleMode::Down => "down",
|
||||
ptx_parser::ShuffleMode::BFly => "bfly",
|
||||
ptx_parser::ShuffleMode::Idx => "idx",
|
||||
};
|
||||
let pred = if arguments.dst_pred.is_some() {
|
||||
"_pred"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
to_call(
|
||||
resolver,
|
||||
fn_declarations,
|
||||
format!("shfl_sync_{}_b32{}", mode, pred).into(),
|
||||
ptx_parser::Instruction::ShflSync { data, arguments },
|
||||
)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Nanosleep { .. } => {
|
||||
to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)?
|
||||
}
|
||||
i => i,
|
||||
})
|
||||
}
|
||||
|
||||
fn to_call<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
name: Cow<'input, str>,
|
||||
i: ast::Instruction<SpirvWord>,
|
||||
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
|
||||
let mut data_return = Vec::new();
|
||||
let mut data_input = Vec::new();
|
||||
let mut arguments_return = Vec::new();
|
||||
let mut arguments_input = Vec::new();
|
||||
ast::visit(&i, &mut |name: &SpirvWord,
|
||||
type_space: Option<(
|
||||
&ptx_parser::Type,
|
||||
ptx_parser::StateSpace,
|
||||
)>,
|
||||
is_dst: bool,
|
||||
_: bool| {
|
||||
let (type_, space) = type_space.ok_or_else(error_mismatched_type)?;
|
||||
if is_dst {
|
||||
data_return.push((type_.clone(), space));
|
||||
arguments_return.push(*name);
|
||||
} else {
|
||||
data_input.push((type_.clone(), space));
|
||||
arguments_input.push(*name);
|
||||
};
|
||||
Ok::<_, TranslateError>(())
|
||||
})?;
|
||||
let fn_name = match fn_declarations.entry(name) {
|
||||
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
|
||||
hash_map::Entry::Vacant(vacant_entry) => {
|
||||
let name = vacant_entry.key().clone();
|
||||
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
|
||||
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
|
||||
vacant_entry.insert((
|
||||
to_variables(resolver, &data_return),
|
||||
name,
|
||||
to_variables(resolver, &data_input),
|
||||
));
|
||||
name
|
||||
}
|
||||
};
|
||||
Ok(ast::Instruction::Call {
|
||||
data: ptx_parser::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: data_return,
|
||||
input_arguments: data_input,
|
||||
},
|
||||
arguments: ptx_parser::CallArgs {
|
||||
return_arguments: arguments_return,
|
||||
func: fn_name,
|
||||
input_arguments: arguments_input,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn to_variables<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
) -> Vec<ptx_parser::Variable<SpirvWord>> {
|
||||
arguments
|
||||
.iter()
|
||||
.map(|(type_, space)| ast::Variable {
|
||||
align: None,
|
||||
v_type: type_.clone(),
|
||||
state_space: *space,
|
||||
name: resolver.register_unnamed(Some((type_.clone(), *space))),
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
@ -1,33 +1,33 @@
|
||||
use std::borrow::Cow;
|
||||
|
||||
use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
mut directives: Vec<NormalizedDirective2>,
|
||||
) -> Vec<NormalizedDirective2> {
|
||||
for directive in directives.iter_mut() {
|
||||
match directive {
|
||||
NormalizedDirective2::Method(func) => {
|
||||
replace_with_ptx_impl(resolver, func.name);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
directives
|
||||
}
|
||||
|
||||
fn replace_with_ptx_impl<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_name: SpirvWord,
|
||||
) {
|
||||
let known_names = ["__assertfail"];
|
||||
if let Some(super::IdentEntry {
|
||||
name: Some(name), ..
|
||||
}) = resolver.ident_map.get_mut(&fn_name)
|
||||
{
|
||||
if known_names.contains(&&**name) {
|
||||
*name = Cow::Owned(format!("__zluda_ptx_impl_{}", name));
|
||||
}
|
||||
}
|
||||
}
|
||||
use std::borrow::Cow;
|
||||
|
||||
use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
mut directives: Vec<NormalizedDirective2>,
|
||||
) -> Vec<NormalizedDirective2> {
|
||||
for directive in directives.iter_mut() {
|
||||
match directive {
|
||||
NormalizedDirective2::Method(func) => {
|
||||
replace_with_ptx_impl(resolver, func.name);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
directives
|
||||
}
|
||||
|
||||
fn replace_with_ptx_impl<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_name: SpirvWord,
|
||||
) {
|
||||
let known_names = ["__assertfail"];
|
||||
if let Some(super::IdentEntry {
|
||||
name: Some(name), ..
|
||||
}) = resolver.ident_map.get_mut(&fn_name)
|
||||
{
|
||||
if known_names.contains(&&**name) {
|
||||
*name = Cow::Owned(format!("__zluda_ptx_impl_{}", name));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,69 +1,69 @@
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
directives: Vec<UnconditionalDirective>,
|
||||
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||
let mut functions = FxHashSet::default();
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(&mut functions, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
functions: &mut FxHashSet<SpirvWord>,
|
||||
directive: UnconditionalDirective,
|
||||
) -> Result<UnconditionalDirective, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => {
|
||||
if !method.is_kernel {
|
||||
functions.insert(method.name);
|
||||
}
|
||||
Directive2::Method(run_method(functions, method)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
functions: &mut FxHashSet<SpirvWord>,
|
||||
method: UnconditionalFunction,
|
||||
) -> Result<UnconditionalFunction, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|statement| run_statement(functions, statement))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
functions: &mut FxHashSet<SpirvWord>,
|
||||
statement: UnconditionalStatement,
|
||||
) -> Result<UnconditionalStatement, TranslateError> {
|
||||
Ok(match statement {
|
||||
Statement::Instruction(ast::Instruction::Mov {
|
||||
data,
|
||||
arguments:
|
||||
ast::MovArgs {
|
||||
dst: ast::ParsedOperand::Reg(dst_reg),
|
||||
src: ast::ParsedOperand::Reg(src_reg),
|
||||
},
|
||||
}) if functions.contains(&src_reg) => {
|
||||
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
UnconditionalStatement::FunctionPointer(FunctionPointerDetails {
|
||||
dst: dst_reg,
|
||||
src: src_reg,
|
||||
})
|
||||
}
|
||||
s => s,
|
||||
})
|
||||
}
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
directives: Vec<UnconditionalDirective>,
|
||||
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||
let mut functions = FxHashSet::default();
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(&mut functions, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
functions: &mut FxHashSet<SpirvWord>,
|
||||
directive: UnconditionalDirective,
|
||||
) -> Result<UnconditionalDirective, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => {
|
||||
if !method.is_kernel {
|
||||
functions.insert(method.name);
|
||||
}
|
||||
Directive2::Method(run_method(functions, method)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
functions: &mut FxHashSet<SpirvWord>,
|
||||
method: UnconditionalFunction,
|
||||
) -> Result<UnconditionalFunction, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|statement| run_statement(functions, statement))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
functions: &mut FxHashSet<SpirvWord>,
|
||||
statement: UnconditionalStatement,
|
||||
) -> Result<UnconditionalStatement, TranslateError> {
|
||||
Ok(match statement {
|
||||
Statement::Instruction(ast::Instruction::Mov {
|
||||
data,
|
||||
arguments:
|
||||
ast::MovArgs {
|
||||
dst: ast::ParsedOperand::Reg(dst_reg),
|
||||
src: ast::ParsedOperand::Reg(src_reg),
|
||||
},
|
||||
}) if functions.contains(&src_reg) => {
|
||||
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
UnconditionalStatement::FunctionPointer(FunctionPointerDetails {
|
||||
dst: dst_reg,
|
||||
src: src_reg,
|
||||
})
|
||||
}
|
||||
s => s,
|
||||
})
|
||||
}
|
||||
|
@ -1,327 +1,327 @@
|
||||
use bpaf::{Args, Bpaf, Parser};
|
||||
use cargo_metadata::{MetadataCommand, Package};
|
||||
use serde::Deserialize;
|
||||
use std::{env, ffi::OsString, path::PathBuf, process::Command};
|
||||
|
||||
#[derive(Debug, Clone, Bpaf)]
|
||||
#[bpaf(options)]
|
||||
enum Options {
|
||||
#[bpaf(command)]
|
||||
/// Compile ZLUDA (default command)
|
||||
Build(#[bpaf(external(build))] Build),
|
||||
#[bpaf(command)]
|
||||
/// Compile ZLUDA and build a package
|
||||
Zip(#[bpaf(external(build))] Build),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Bpaf)]
|
||||
struct Build {
|
||||
#[bpaf(any("CARGO", not_help), many)]
|
||||
/// Arguments to pass to cargo, e.g. `--release` for release build
|
||||
cargo_arguments: Vec<OsString>,
|
||||
}
|
||||
|
||||
fn not_help(s: OsString) -> Option<OsString> {
|
||||
if s == "-h" || s == "--help" {
|
||||
None
|
||||
} else {
|
||||
Some(s)
|
||||
}
|
||||
}
|
||||
|
||||
// We need to sniff out some args passed to cargo to understand how to create
|
||||
// symlinks (should they go into `target/debug`, `target/release` or custom)
|
||||
#[derive(Debug, Clone, Bpaf)]
|
||||
struct Cargo {
|
||||
#[bpaf(switch, long, short)]
|
||||
release: Option<bool>,
|
||||
#[bpaf(long)]
|
||||
profile: Option<String>,
|
||||
#[bpaf(any("", Some), many)]
|
||||
_unused: Vec<OsString>,
|
||||
}
|
||||
|
||||
struct Project {
|
||||
name: String,
|
||||
target_name: String,
|
||||
target_kind: ProjectTarget,
|
||||
meta: ZludaMetadata,
|
||||
}
|
||||
|
||||
impl Project {
|
||||
fn try_new(p: Package) -> Option<Project> {
|
||||
let name = p.name;
|
||||
serde_json::from_value::<Option<Metadata>>(p.metadata)
|
||||
.unwrap()
|
||||
.map(|m| {
|
||||
let (target_name, target_kind) = p
|
||||
.targets
|
||||
.into_iter()
|
||||
.find_map(|target| {
|
||||
if target.is_cdylib() {
|
||||
Some((target.name, ProjectTarget::Cdylib))
|
||||
} else if target.is_bin() {
|
||||
Some((target.name, ProjectTarget::Bin))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
Self {
|
||||
name,
|
||||
target_name,
|
||||
target_kind,
|
||||
meta: m.zluda,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn prefix(&self) -> &'static str {
|
||||
match self.target_kind {
|
||||
ProjectTarget::Bin => "",
|
||||
ProjectTarget::Cdylib => "lib",
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn prefix(&self) -> &'static str {
|
||||
""
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn suffix(&self) -> &'static str {
|
||||
match self.target_kind {
|
||||
ProjectTarget::Bin => "",
|
||||
ProjectTarget::Cdylib => ".so",
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn suffix(&self) -> &'static str {
|
||||
match self.target_kind {
|
||||
ProjectTarget::Bin => ".exe",
|
||||
ProjectTarget::Cdylib => ".dll",
|
||||
}
|
||||
}
|
||||
|
||||
// Returns tuple:
|
||||
// * symlink file path (relative to the root of build dir)
|
||||
// * symlink absolute file path
|
||||
// * target actual file (relative to symlink file)
|
||||
#[cfg_attr(not(unix), allow(unused))]
|
||||
fn symlinks<'a>(
|
||||
&'a self,
|
||||
target_dir: &'a PathBuf,
|
||||
profile: &'a str,
|
||||
libname: &'a str,
|
||||
) -> impl Iterator<Item = (&'a str, PathBuf, PathBuf)> + 'a {
|
||||
self.meta.linux_symlinks.iter().map(move |source| {
|
||||
let mut link = target_dir.clone();
|
||||
link.extend([profile, source]);
|
||||
let relative_link = PathBuf::from(source);
|
||||
let ancestors = relative_link.as_path().ancestors().count();
|
||||
let mut target = std::iter::repeat_with(|| "../").take(ancestors - 2).fold(
|
||||
PathBuf::new(),
|
||||
|mut buff, segment| {
|
||||
buff.push(segment);
|
||||
buff
|
||||
},
|
||||
);
|
||||
target.push(libname);
|
||||
(&**source, link, target)
|
||||
})
|
||||
}
|
||||
|
||||
fn file_name(&self) -> String {
|
||||
let target_name = &self.target_name;
|
||||
let prefix = self.prefix();
|
||||
let suffix = self.suffix();
|
||||
format!("{prefix}{target_name}{suffix}")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum ProjectTarget {
|
||||
Cdylib,
|
||||
Bin,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Metadata {
|
||||
zluda: ZludaMetadata,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
struct ZludaMetadata {
|
||||
#[serde(default)]
|
||||
windows_only: bool,
|
||||
#[serde(default)]
|
||||
debug_only: bool,
|
||||
#[cfg_attr(not(unix), allow(unused))]
|
||||
#[serde(default)]
|
||||
linux_symlinks: Vec<String>,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let options = match options().run_inner(Args::current_args()) {
|
||||
Ok(b) => b,
|
||||
Err(err) => match build().to_options().run_inner(Args::current_args()) {
|
||||
Ok(b) => Options::Build(b),
|
||||
Err(_) => {
|
||||
err.print_message(100);
|
||||
std::process::exit(err.exit_code());
|
||||
}
|
||||
},
|
||||
};
|
||||
match options {
|
||||
Options::Build(b) => {
|
||||
compile(b);
|
||||
}
|
||||
Options::Zip(b) => zip(b),
|
||||
}
|
||||
}
|
||||
|
||||
fn compile(b: Build) -> (PathBuf, String, Vec<Project>) {
|
||||
let profile = sniff_out_profile_name(&b.cargo_arguments);
|
||||
let meta = MetadataCommand::new().no_deps().exec().unwrap();
|
||||
let target_directory = meta.target_directory.into_std_path_buf();
|
||||
let projects = meta
|
||||
.packages
|
||||
.into_iter()
|
||||
.filter_map(Project::try_new)
|
||||
.filter(|project| {
|
||||
if project.meta.windows_only && cfg!(not(windows)) {
|
||||
return false;
|
||||
}
|
||||
if project.meta.debug_only && profile != "debug" {
|
||||
return false;
|
||||
}
|
||||
true
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let cargo = env::var("CARGO").unwrap_or_else(|_| "cargo".to_string());
|
||||
let mut command = Command::new(&cargo);
|
||||
command.arg("build");
|
||||
command.arg("--locked");
|
||||
for project in projects.iter() {
|
||||
command.arg("--package");
|
||||
command.arg(&project.name);
|
||||
}
|
||||
command.args(b.cargo_arguments);
|
||||
assert!(command.status().unwrap().success());
|
||||
os::make_symlinks(&target_directory, &*projects, &*profile);
|
||||
(target_directory, profile, projects)
|
||||
}
|
||||
|
||||
fn sniff_out_profile_name(b: &[OsString]) -> String {
|
||||
let parsed_cargo_arguments = cargo().to_options().run_inner(b);
|
||||
match parsed_cargo_arguments {
|
||||
Ok(Cargo {
|
||||
release: Some(true),
|
||||
..
|
||||
}) => "release".to_string(),
|
||||
Ok(Cargo {
|
||||
profile: Some(profile),
|
||||
..
|
||||
}) => profile,
|
||||
_ => "debug".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn zip(zip: Build) {
|
||||
let (target_dir, profile, projects) = compile(zip);
|
||||
os::zip(target_dir, profile, projects)
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
mod os {
|
||||
use flate2::write::GzEncoder;
|
||||
use flate2::Compression;
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
path::PathBuf,
|
||||
};
|
||||
use tar::Header;
|
||||
|
||||
pub fn make_symlinks(
|
||||
target_directory: &std::path::PathBuf,
|
||||
projects: &[super::Project],
|
||||
profile: &str,
|
||||
) {
|
||||
use std::os::unix::fs as unix_fs;
|
||||
for project in projects.iter() {
|
||||
let libname = project.file_name();
|
||||
for (_, full_path, target) in project.symlinks(target_directory, profile, &libname) {
|
||||
let mut dir = full_path.clone();
|
||||
assert!(dir.pop());
|
||||
fs::create_dir_all(dir).unwrap();
|
||||
fs::remove_file(&full_path).ok();
|
||||
unix_fs::symlink(&target, full_path).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
|
||||
let tar_gz =
|
||||
File::create(format!("{}/{profile}/zluda.tar.gz", target_dir.display())).unwrap();
|
||||
let enc = GzEncoder::new(tar_gz, Compression::default());
|
||||
let mut tar = tar::Builder::new(enc);
|
||||
for project in projects.iter() {
|
||||
let file_name = project.file_name();
|
||||
let mut file =
|
||||
File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap();
|
||||
tar.append_file(format!("zluda/{file_name}"), &mut file)
|
||||
.unwrap();
|
||||
for (source, full_path, target) in project.symlinks(&target_dir, &profile, &file_name) {
|
||||
let mut header = Header::new_gnu();
|
||||
let meta = fs::symlink_metadata(&full_path).unwrap();
|
||||
header.set_metadata(&meta);
|
||||
tar.append_link(&mut header, format!("zluda/{source}"), target)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
tar.finish().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
mod os {
|
||||
use std::{fs::File, io, path::PathBuf};
|
||||
use zip::{write::SimpleFileOptions, ZipWriter};
|
||||
|
||||
pub fn make_symlinks(
|
||||
_target_directory: &std::path::PathBuf,
|
||||
_projects: &[super::Project],
|
||||
_profile: &str,
|
||||
) {
|
||||
}
|
||||
|
||||
pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
|
||||
let zip_file =
|
||||
File::create(format!("{}/{profile}/zluda.zip", target_dir.display())).unwrap();
|
||||
let mut zip = ZipWriter::new(zip_file);
|
||||
zip.add_directory("zluda", SimpleFileOptions::default())
|
||||
.unwrap();
|
||||
for project in projects.iter() {
|
||||
let file_name = project.file_name();
|
||||
let mut file =
|
||||
File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap();
|
||||
let file_options = file_options_from_time(&file).unwrap_or_default();
|
||||
zip.start_file(format!("zluda/{file_name}"), file_options)
|
||||
.unwrap();
|
||||
io::copy(&mut file, &mut zip).unwrap();
|
||||
}
|
||||
zip.finish().unwrap();
|
||||
}
|
||||
|
||||
fn file_options_from_time(from: &File) -> io::Result<SimpleFileOptions> {
|
||||
let metadata = from.metadata()?;
|
||||
let modified = metadata.modified()?;
|
||||
let modified = time::OffsetDateTime::from(modified);
|
||||
Ok(SimpleFileOptions::default().last_modified_time(
|
||||
zip::DateTime::try_from(modified).map_err(|err| io::Error::other(err))?,
|
||||
))
|
||||
}
|
||||
}
|
||||
use bpaf::{Args, Bpaf, Parser};
|
||||
use cargo_metadata::{MetadataCommand, Package};
|
||||
use serde::Deserialize;
|
||||
use std::{env, ffi::OsString, path::PathBuf, process::Command};
|
||||
|
||||
#[derive(Debug, Clone, Bpaf)]
|
||||
#[bpaf(options)]
|
||||
enum Options {
|
||||
#[bpaf(command)]
|
||||
/// Compile ZLUDA (default command)
|
||||
Build(#[bpaf(external(build))] Build),
|
||||
#[bpaf(command)]
|
||||
/// Compile ZLUDA and build a package
|
||||
Zip(#[bpaf(external(build))] Build),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Bpaf)]
|
||||
struct Build {
|
||||
#[bpaf(any("CARGO", not_help), many)]
|
||||
/// Arguments to pass to cargo, e.g. `--release` for release build
|
||||
cargo_arguments: Vec<OsString>,
|
||||
}
|
||||
|
||||
fn not_help(s: OsString) -> Option<OsString> {
|
||||
if s == "-h" || s == "--help" {
|
||||
None
|
||||
} else {
|
||||
Some(s)
|
||||
}
|
||||
}
|
||||
|
||||
// We need to sniff out some args passed to cargo to understand how to create
|
||||
// symlinks (should they go into `target/debug`, `target/release` or custom)
|
||||
#[derive(Debug, Clone, Bpaf)]
|
||||
struct Cargo {
|
||||
#[bpaf(switch, long, short)]
|
||||
release: Option<bool>,
|
||||
#[bpaf(long)]
|
||||
profile: Option<String>,
|
||||
#[bpaf(any("", Some), many)]
|
||||
_unused: Vec<OsString>,
|
||||
}
|
||||
|
||||
struct Project {
|
||||
name: String,
|
||||
target_name: String,
|
||||
target_kind: ProjectTarget,
|
||||
meta: ZludaMetadata,
|
||||
}
|
||||
|
||||
impl Project {
|
||||
fn try_new(p: Package) -> Option<Project> {
|
||||
let name = p.name;
|
||||
serde_json::from_value::<Option<Metadata>>(p.metadata)
|
||||
.unwrap()
|
||||
.map(|m| {
|
||||
let (target_name, target_kind) = p
|
||||
.targets
|
||||
.into_iter()
|
||||
.find_map(|target| {
|
||||
if target.is_cdylib() {
|
||||
Some((target.name, ProjectTarget::Cdylib))
|
||||
} else if target.is_bin() {
|
||||
Some((target.name, ProjectTarget::Bin))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
Self {
|
||||
name,
|
||||
target_name,
|
||||
target_kind,
|
||||
meta: m.zluda,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn prefix(&self) -> &'static str {
|
||||
match self.target_kind {
|
||||
ProjectTarget::Bin => "",
|
||||
ProjectTarget::Cdylib => "lib",
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn prefix(&self) -> &'static str {
|
||||
""
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn suffix(&self) -> &'static str {
|
||||
match self.target_kind {
|
||||
ProjectTarget::Bin => "",
|
||||
ProjectTarget::Cdylib => ".so",
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn suffix(&self) -> &'static str {
|
||||
match self.target_kind {
|
||||
ProjectTarget::Bin => ".exe",
|
||||
ProjectTarget::Cdylib => ".dll",
|
||||
}
|
||||
}
|
||||
|
||||
// Returns tuple:
|
||||
// * symlink file path (relative to the root of build dir)
|
||||
// * symlink absolute file path
|
||||
// * target actual file (relative to symlink file)
|
||||
#[cfg_attr(not(unix), allow(unused))]
|
||||
fn symlinks<'a>(
|
||||
&'a self,
|
||||
target_dir: &'a PathBuf,
|
||||
profile: &'a str,
|
||||
libname: &'a str,
|
||||
) -> impl Iterator<Item = (&'a str, PathBuf, PathBuf)> + 'a {
|
||||
self.meta.linux_symlinks.iter().map(move |source| {
|
||||
let mut link = target_dir.clone();
|
||||
link.extend([profile, source]);
|
||||
let relative_link = PathBuf::from(source);
|
||||
let ancestors = relative_link.as_path().ancestors().count();
|
||||
let mut target = std::iter::repeat_with(|| "../").take(ancestors - 2).fold(
|
||||
PathBuf::new(),
|
||||
|mut buff, segment| {
|
||||
buff.push(segment);
|
||||
buff
|
||||
},
|
||||
);
|
||||
target.push(libname);
|
||||
(&**source, link, target)
|
||||
})
|
||||
}
|
||||
|
||||
fn file_name(&self) -> String {
|
||||
let target_name = &self.target_name;
|
||||
let prefix = self.prefix();
|
||||
let suffix = self.suffix();
|
||||
format!("{prefix}{target_name}{suffix}")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum ProjectTarget {
|
||||
Cdylib,
|
||||
Bin,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Metadata {
|
||||
zluda: ZludaMetadata,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
struct ZludaMetadata {
|
||||
#[serde(default)]
|
||||
windows_only: bool,
|
||||
#[serde(default)]
|
||||
debug_only: bool,
|
||||
#[cfg_attr(not(unix), allow(unused))]
|
||||
#[serde(default)]
|
||||
linux_symlinks: Vec<String>,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let options = match options().run_inner(Args::current_args()) {
|
||||
Ok(b) => b,
|
||||
Err(err) => match build().to_options().run_inner(Args::current_args()) {
|
||||
Ok(b) => Options::Build(b),
|
||||
Err(_) => {
|
||||
err.print_message(100);
|
||||
std::process::exit(err.exit_code());
|
||||
}
|
||||
},
|
||||
};
|
||||
match options {
|
||||
Options::Build(b) => {
|
||||
compile(b);
|
||||
}
|
||||
Options::Zip(b) => zip(b),
|
||||
}
|
||||
}
|
||||
|
||||
fn compile(b: Build) -> (PathBuf, String, Vec<Project>) {
|
||||
let profile = sniff_out_profile_name(&b.cargo_arguments);
|
||||
let meta = MetadataCommand::new().no_deps().exec().unwrap();
|
||||
let target_directory = meta.target_directory.into_std_path_buf();
|
||||
let projects = meta
|
||||
.packages
|
||||
.into_iter()
|
||||
.filter_map(Project::try_new)
|
||||
.filter(|project| {
|
||||
if project.meta.windows_only && cfg!(not(windows)) {
|
||||
return false;
|
||||
}
|
||||
if project.meta.debug_only && profile != "debug" {
|
||||
return false;
|
||||
}
|
||||
true
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let cargo = env::var("CARGO").unwrap_or_else(|_| "cargo".to_string());
|
||||
let mut command = Command::new(&cargo);
|
||||
command.arg("build");
|
||||
command.arg("--locked");
|
||||
for project in projects.iter() {
|
||||
command.arg("--package");
|
||||
command.arg(&project.name);
|
||||
}
|
||||
command.args(b.cargo_arguments);
|
||||
assert!(command.status().unwrap().success());
|
||||
os::make_symlinks(&target_directory, &*projects, &*profile);
|
||||
(target_directory, profile, projects)
|
||||
}
|
||||
|
||||
fn sniff_out_profile_name(b: &[OsString]) -> String {
|
||||
let parsed_cargo_arguments = cargo().to_options().run_inner(b);
|
||||
match parsed_cargo_arguments {
|
||||
Ok(Cargo {
|
||||
release: Some(true),
|
||||
..
|
||||
}) => "release".to_string(),
|
||||
Ok(Cargo {
|
||||
profile: Some(profile),
|
||||
..
|
||||
}) => profile,
|
||||
_ => "debug".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn zip(zip: Build) {
|
||||
let (target_dir, profile, projects) = compile(zip);
|
||||
os::zip(target_dir, profile, projects)
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
mod os {
|
||||
use flate2::write::GzEncoder;
|
||||
use flate2::Compression;
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
path::PathBuf,
|
||||
};
|
||||
use tar::Header;
|
||||
|
||||
pub fn make_symlinks(
|
||||
target_directory: &std::path::PathBuf,
|
||||
projects: &[super::Project],
|
||||
profile: &str,
|
||||
) {
|
||||
use std::os::unix::fs as unix_fs;
|
||||
for project in projects.iter() {
|
||||
let libname = project.file_name();
|
||||
for (_, full_path, target) in project.symlinks(target_directory, profile, &libname) {
|
||||
let mut dir = full_path.clone();
|
||||
assert!(dir.pop());
|
||||
fs::create_dir_all(dir).unwrap();
|
||||
fs::remove_file(&full_path).ok();
|
||||
unix_fs::symlink(&target, full_path).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
|
||||
let tar_gz =
|
||||
File::create(format!("{}/{profile}/zluda.tar.gz", target_dir.display())).unwrap();
|
||||
let enc = GzEncoder::new(tar_gz, Compression::default());
|
||||
let mut tar = tar::Builder::new(enc);
|
||||
for project in projects.iter() {
|
||||
let file_name = project.file_name();
|
||||
let mut file =
|
||||
File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap();
|
||||
tar.append_file(format!("zluda/{file_name}"), &mut file)
|
||||
.unwrap();
|
||||
for (source, full_path, target) in project.symlinks(&target_dir, &profile, &file_name) {
|
||||
let mut header = Header::new_gnu();
|
||||
let meta = fs::symlink_metadata(&full_path).unwrap();
|
||||
header.set_metadata(&meta);
|
||||
tar.append_link(&mut header, format!("zluda/{source}"), target)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
tar.finish().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
mod os {
|
||||
use std::{fs::File, io, path::PathBuf};
|
||||
use zip::{write::SimpleFileOptions, ZipWriter};
|
||||
|
||||
pub fn make_symlinks(
|
||||
_target_directory: &std::path::PathBuf,
|
||||
_projects: &[super::Project],
|
||||
_profile: &str,
|
||||
) {
|
||||
}
|
||||
|
||||
pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
|
||||
let zip_file =
|
||||
File::create(format!("{}/{profile}/zluda.zip", target_dir.display())).unwrap();
|
||||
let mut zip = ZipWriter::new(zip_file);
|
||||
zip.add_directory("zluda", SimpleFileOptions::default())
|
||||
.unwrap();
|
||||
for project in projects.iter() {
|
||||
let file_name = project.file_name();
|
||||
let mut file =
|
||||
File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap();
|
||||
let file_options = file_options_from_time(&file).unwrap_or_default();
|
||||
zip.start_file(format!("zluda/{file_name}"), file_options)
|
||||
.unwrap();
|
||||
io::copy(&mut file, &mut zip).unwrap();
|
||||
}
|
||||
zip.finish().unwrap();
|
||||
}
|
||||
|
||||
fn file_options_from_time(from: &File) -> io::Result<SimpleFileOptions> {
|
||||
let metadata = from.metadata()?;
|
||||
let modified = metadata.modified()?;
|
||||
let modified = time::OffsetDateTime::from(modified);
|
||||
Ok(SimpleFileOptions::default().last_modified_time(
|
||||
zip::DateTime::try_from(modified).map_err(|err| io::Error::other(err))?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
@ -1,426 +1,426 @@
|
||||
use super::{FromCuda, LiveCheck};
|
||||
use crate::r#impl::{context, device};
|
||||
use comgr::Comgr;
|
||||
use cuda_types::cuda::*;
|
||||
use hip_runtime_sys::*;
|
||||
use std::{
|
||||
ffi::{c_void, CStr, CString},
|
||||
mem, ptr, slice,
|
||||
sync::OnceLock,
|
||||
usize,
|
||||
};
|
||||
|
||||
#[cfg_attr(windows, path = "os_win.rs")]
|
||||
#[cfg_attr(not(windows), path = "os_unix.rs")]
|
||||
mod os;
|
||||
|
||||
pub(crate) struct GlobalState {
|
||||
pub devices: Vec<Device>,
|
||||
pub comgr: Comgr,
|
||||
}
|
||||
|
||||
pub(crate) struct Device {
|
||||
pub(crate) _comgr_isa: CString,
|
||||
primary_context: LiveCheck<context::Context>,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) {
|
||||
unsafe {
|
||||
(
|
||||
self.primary_context.data.assume_init_ref(),
|
||||
self.primary_context.as_handle(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> {
|
||||
global_state()?
|
||||
.devices
|
||||
.get(dev as usize)
|
||||
.ok_or(CUerror::INVALID_DEVICE)
|
||||
}
|
||||
|
||||
pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
|
||||
static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new();
|
||||
fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] {
|
||||
unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) }
|
||||
}
|
||||
GLOBAL_STATE
|
||||
.get_or_init(|| {
|
||||
let mut device_count = 0;
|
||||
unsafe { hipGetDeviceCount(&mut device_count) }?;
|
||||
let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?;
|
||||
Ok(GlobalState {
|
||||
comgr,
|
||||
devices: (0..device_count)
|
||||
.map(|i| {
|
||||
let mut props = unsafe { mem::zeroed() };
|
||||
unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?;
|
||||
Ok::<_, CUerror>(Device {
|
||||
_comgr_isa: CStr::from_bytes_until_nul(cast_slice(
|
||||
&props.gcnArchName[..],
|
||||
))
|
||||
.map_err(|_| CUerror::UNKNOWN)?
|
||||
.to_owned(),
|
||||
primary_context: LiveCheck::new(context::Context::new(i)),
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
})
|
||||
})
|
||||
.as_ref()
|
||||
.map_err(|e| *e)
|
||||
}
|
||||
|
||||
pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult {
|
||||
unsafe { hipInit(flags) }?;
|
||||
global_state()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct UnknownBuffer<const S: usize> {
|
||||
buffer: std::cell::UnsafeCell<[u32; S]>,
|
||||
}
|
||||
|
||||
impl<const S: usize> UnknownBuffer<S> {
|
||||
const fn new() -> Self {
|
||||
UnknownBuffer {
|
||||
buffer: std::cell::UnsafeCell::new([0; S]),
|
||||
}
|
||||
}
|
||||
const fn len(&self) -> usize {
|
||||
S
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<const S: usize> Sync for UnknownBuffer<S> {}
|
||||
|
||||
static UNKNOWN_BUFFER1: UnknownBuffer<1024> = UnknownBuffer::new();
|
||||
static UNKNOWN_BUFFER2: UnknownBuffer<14> = UnknownBuffer::new();
|
||||
|
||||
struct DarkApi {}
|
||||
|
||||
impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
||||
unsafe extern "system" fn get_module_from_cubin(
|
||||
_module: *mut cuda_types::cuda::CUmodule,
|
||||
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn cudart_interface_fn2(
|
||||
pctx: *mut cuda_types::cuda::CUcontext,
|
||||
hip_dev: hipDevice_t,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
let pctx = match pctx.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return CUresult::ERROR_INVALID_VALUE,
|
||||
};
|
||||
|
||||
device::primary_context_retain(pctx, hip_dev)
|
||||
}
|
||||
|
||||
unsafe extern "system" fn get_module_from_cubin_ext1(
|
||||
_result: *mut cuda_types::cuda::CUmodule,
|
||||
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
|
||||
_arg3: *mut std::ffi::c_void,
|
||||
_arg4: *mut std::ffi::c_void,
|
||||
_arg5: u32,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn get_module_from_cubin_ext2(
|
||||
_fatbin_header: *const cuda_types::dark_api::FatbinHeader,
|
||||
_result: *mut cuda_types::cuda::CUmodule,
|
||||
_arg3: *mut std::ffi::c_void,
|
||||
_arg4: *mut std::ffi::c_void,
|
||||
_arg5: u32,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn get_unknown_buffer1(
|
||||
ptr: *mut *mut std::ffi::c_void,
|
||||
size: *mut usize,
|
||||
) -> () {
|
||||
*ptr = UNKNOWN_BUFFER1.buffer.get() as *mut std::ffi::c_void;
|
||||
*size = UNKNOWN_BUFFER1.len();
|
||||
}
|
||||
|
||||
unsafe extern "system" fn get_unknown_buffer2(
|
||||
ptr: *mut *mut std::ffi::c_void,
|
||||
size: *mut usize,
|
||||
) -> () {
|
||||
*ptr = UNKNOWN_BUFFER2.buffer.get() as *mut std::ffi::c_void;
|
||||
*size = UNKNOWN_BUFFER2.len();
|
||||
}
|
||||
|
||||
unsafe extern "system" fn context_local_storage_put(
|
||||
cu_ctx: CUcontext,
|
||||
key: *mut c_void,
|
||||
value: *mut c_void,
|
||||
dtor_cb: Option<extern "system" fn(CUcontext, *mut c_void, *mut c_void)>,
|
||||
) -> CUresult {
|
||||
let _ctx = if cu_ctx.0 != ptr::null_mut() {
|
||||
cu_ctx
|
||||
} else {
|
||||
let mut current_ctx: CUcontext = CUcontext(ptr::null_mut());
|
||||
context::get_current(&mut current_ctx)?;
|
||||
current_ctx
|
||||
};
|
||||
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
|
||||
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
|
||||
state.storage.insert(
|
||||
key as usize,
|
||||
context::StorageData {
|
||||
value: value as usize,
|
||||
reset_cb: dtor_cb,
|
||||
handle: _ctx,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe extern "system" fn context_local_storage_delete(
|
||||
cu_ctx: CUcontext,
|
||||
key: *mut c_void,
|
||||
) -> CUresult {
|
||||
let ctx_obj: &context::Context = FromCuda::from_cuda(&cu_ctx)?;
|
||||
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
|
||||
state.storage.remove(&(key as usize));
|
||||
Ok(())
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe extern "system" fn context_local_storage_get(
|
||||
value: *mut *mut c_void,
|
||||
cu_ctx: CUcontext,
|
||||
key: *mut c_void,
|
||||
) -> CUresult {
|
||||
let mut _ctx: CUcontext;
|
||||
if cu_ctx.0 == ptr::null_mut() {
|
||||
_ctx = context::get_current_context()?;
|
||||
} else {
|
||||
_ctx = cu_ctx
|
||||
};
|
||||
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
|
||||
ctx_obj.with_state(|state: &context::ContextState| {
|
||||
match state.storage.get(&(key as usize)) {
|
||||
Some(data) => *value = data.value as *mut c_void,
|
||||
None => return CUresult::ERROR_INVALID_HANDLE,
|
||||
}
|
||||
Ok(())
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe extern "system" fn ctx_create_v2_bypass(
|
||||
_pctx: *mut cuda_types::cuda::CUcontext,
|
||||
_flags: ::std::os::raw::c_uint,
|
||||
_dev: cuda_types::cuda::CUdevice,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn heap_alloc(
|
||||
_heap_alloc_record_ptr: *mut *const std::ffi::c_void,
|
||||
_arg2: usize,
|
||||
_arg3: usize,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn heap_free(
|
||||
_heap_alloc_record_ptr: *const std::ffi::c_void,
|
||||
_arg2: *mut usize,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn device_get_attribute_ext(
|
||||
_dev: cuda_types::cuda::CUdevice,
|
||||
_attribute: std::ffi::c_uint,
|
||||
_unknown: std::ffi::c_int,
|
||||
_result: *mut [usize; 2],
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn device_get_something(
|
||||
_result: *mut std::ffi::c_uchar,
|
||||
_dev: cuda_types::cuda::CUdevice,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn integrity_check(
|
||||
version: u32,
|
||||
unix_seconds: u64,
|
||||
result: *mut [u64; 2],
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
let current_process = std::process::id();
|
||||
let current_thread = os::current_thread();
|
||||
|
||||
let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast();
|
||||
let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast();
|
||||
let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1];
|
||||
|
||||
let devices = get_device_hash_info()?;
|
||||
let device_count = devices.len() as u32;
|
||||
let get_device = |dev| devices[dev as usize];
|
||||
|
||||
let hash = ::dark_api::integrity_check(
|
||||
version,
|
||||
unix_seconds,
|
||||
cuda_types::cuda::CUDA_VERSION,
|
||||
current_process,
|
||||
current_thread,
|
||||
integrity_check_table,
|
||||
cudart_table,
|
||||
fn_address,
|
||||
device_count,
|
||||
get_device,
|
||||
);
|
||||
*result = hash;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe extern "system" fn context_check(
|
||||
_ctx_in: cuda_types::cuda::CUcontext,
|
||||
result1: *mut u32,
|
||||
_result2: *mut *const std::ffi::c_void,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
*result1 = 0;
|
||||
CUresult::SUCCESS
|
||||
}
|
||||
|
||||
unsafe extern "system" fn check_fn3() -> u32 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
|
||||
let mut device_count = 0;
|
||||
device::get_count(&mut device_count)?;
|
||||
|
||||
(0..device_count)
|
||||
.map(|dev| {
|
||||
let mut guid = CUuuid_st { bytes: [0; 16] };
|
||||
unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? };
|
||||
|
||||
let mut pci_domain = 0;
|
||||
device::get_attribute(
|
||||
&mut pci_domain,
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID,
|
||||
dev,
|
||||
)?;
|
||||
|
||||
let mut pci_bus = 0;
|
||||
device::get_attribute(
|
||||
&mut pci_bus,
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID,
|
||||
dev,
|
||||
)?;
|
||||
|
||||
let mut pci_device = 0;
|
||||
device::get_attribute(
|
||||
&mut pci_device,
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID,
|
||||
dev,
|
||||
)?;
|
||||
|
||||
Ok(::dark_api::DeviceHashinfo {
|
||||
guid,
|
||||
pci_domain,
|
||||
pci_bus,
|
||||
pci_device,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable =
|
||||
::dark_api::cuda::CudaDarkApiGlobalTable::new::<DarkApi>();
|
||||
|
||||
pub(crate) fn get_export_table(
|
||||
pp_export_table: &mut *const ::core::ffi::c_void,
|
||||
p_export_table_id: &CUuuid,
|
||||
) -> CUresult {
|
||||
if let Some(table) = EXPORT_TABLE.get(p_export_table_id) {
|
||||
*pp_export_table = table.start();
|
||||
cuda_types::cuda::CUresult::SUCCESS
|
||||
} else {
|
||||
cuda_types::cuda::CUresult::ERROR_INVALID_VALUE
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
|
||||
*version = cuda_types::cuda::CUDA_VERSION as i32;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn get_proc_address(
|
||||
symbol: &CStr,
|
||||
pfn: &mut *mut ::core::ffi::c_void,
|
||||
cuda_version: ::core::ffi::c_int,
|
||||
flags: cuda_types::cuda::cuuint64_t,
|
||||
) -> CUresult {
|
||||
get_proc_address_v2(symbol, pfn, cuda_version, flags, None)
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn get_proc_address_v2(
|
||||
symbol: &CStr,
|
||||
pfn: &mut *mut ::core::ffi::c_void,
|
||||
cuda_version: ::core::ffi::c_int,
|
||||
flags: cuda_types::cuda::cuuint64_t,
|
||||
symbol_status: Option<&mut cuda_types::cuda::CUdriverProcAddressQueryResult>,
|
||||
) -> CUresult {
|
||||
// This implementation is mostly the same as cuGetProcAddress_v2 in zluda_dump. We may want to factor out the duplication at some point.
|
||||
fn raw_match(name: &[u8], flag: u64, version: i32) -> *mut ::core::ffi::c_void {
|
||||
use crate::*;
|
||||
include!("../../../zluda_bindgen/src/process_table.rs")
|
||||
}
|
||||
let fn_ptr = raw_match(symbol.to_bytes(), flags, cuda_version);
|
||||
match fn_ptr as usize {
|
||||
0 => {
|
||||
if let Some(symbol_status) = symbol_status {
|
||||
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND;
|
||||
}
|
||||
*pfn = ptr::null_mut();
|
||||
CUresult::ERROR_NOT_FOUND
|
||||
}
|
||||
usize::MAX => {
|
||||
if let Some(symbol_status) = symbol_status {
|
||||
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT;
|
||||
}
|
||||
*pfn = ptr::null_mut();
|
||||
CUresult::ERROR_NOT_FOUND
|
||||
}
|
||||
_ => {
|
||||
if let Some(symbol_status) = symbol_status {
|
||||
*symbol_status =
|
||||
cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SUCCESS;
|
||||
}
|
||||
*pfn = fn_ptr;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn profiler_start() -> CUresult {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn profiler_stop() -> CUresult {
|
||||
Ok(())
|
||||
}
|
||||
use super::{FromCuda, LiveCheck};
|
||||
use crate::r#impl::{context, device};
|
||||
use comgr::Comgr;
|
||||
use cuda_types::cuda::*;
|
||||
use hip_runtime_sys::*;
|
||||
use std::{
|
||||
ffi::{c_void, CStr, CString},
|
||||
mem, ptr, slice,
|
||||
sync::OnceLock,
|
||||
usize,
|
||||
};
|
||||
|
||||
#[cfg_attr(windows, path = "os_win.rs")]
|
||||
#[cfg_attr(not(windows), path = "os_unix.rs")]
|
||||
mod os;
|
||||
|
||||
pub(crate) struct GlobalState {
|
||||
pub devices: Vec<Device>,
|
||||
pub comgr: Comgr,
|
||||
}
|
||||
|
||||
pub(crate) struct Device {
|
||||
pub(crate) _comgr_isa: CString,
|
||||
primary_context: LiveCheck<context::Context>,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) {
|
||||
unsafe {
|
||||
(
|
||||
self.primary_context.data.assume_init_ref(),
|
||||
self.primary_context.as_handle(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> {
|
||||
global_state()?
|
||||
.devices
|
||||
.get(dev as usize)
|
||||
.ok_or(CUerror::INVALID_DEVICE)
|
||||
}
|
||||
|
||||
pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
|
||||
static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new();
|
||||
fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] {
|
||||
unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) }
|
||||
}
|
||||
GLOBAL_STATE
|
||||
.get_or_init(|| {
|
||||
let mut device_count = 0;
|
||||
unsafe { hipGetDeviceCount(&mut device_count) }?;
|
||||
let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?;
|
||||
Ok(GlobalState {
|
||||
comgr,
|
||||
devices: (0..device_count)
|
||||
.map(|i| {
|
||||
let mut props = unsafe { mem::zeroed() };
|
||||
unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?;
|
||||
Ok::<_, CUerror>(Device {
|
||||
_comgr_isa: CStr::from_bytes_until_nul(cast_slice(
|
||||
&props.gcnArchName[..],
|
||||
))
|
||||
.map_err(|_| CUerror::UNKNOWN)?
|
||||
.to_owned(),
|
||||
primary_context: LiveCheck::new(context::Context::new(i)),
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
})
|
||||
})
|
||||
.as_ref()
|
||||
.map_err(|e| *e)
|
||||
}
|
||||
|
||||
pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult {
|
||||
unsafe { hipInit(flags) }?;
|
||||
global_state()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct UnknownBuffer<const S: usize> {
|
||||
buffer: std::cell::UnsafeCell<[u32; S]>,
|
||||
}
|
||||
|
||||
impl<const S: usize> UnknownBuffer<S> {
|
||||
const fn new() -> Self {
|
||||
UnknownBuffer {
|
||||
buffer: std::cell::UnsafeCell::new([0; S]),
|
||||
}
|
||||
}
|
||||
const fn len(&self) -> usize {
|
||||
S
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<const S: usize> Sync for UnknownBuffer<S> {}
|
||||
|
||||
static UNKNOWN_BUFFER1: UnknownBuffer<1024> = UnknownBuffer::new();
|
||||
static UNKNOWN_BUFFER2: UnknownBuffer<14> = UnknownBuffer::new();
|
||||
|
||||
struct DarkApi {}
|
||||
|
||||
impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
||||
unsafe extern "system" fn get_module_from_cubin(
|
||||
_module: *mut cuda_types::cuda::CUmodule,
|
||||
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn cudart_interface_fn2(
|
||||
pctx: *mut cuda_types::cuda::CUcontext,
|
||||
hip_dev: hipDevice_t,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
let pctx = match pctx.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return CUresult::ERROR_INVALID_VALUE,
|
||||
};
|
||||
|
||||
device::primary_context_retain(pctx, hip_dev)
|
||||
}
|
||||
|
||||
unsafe extern "system" fn get_module_from_cubin_ext1(
|
||||
_result: *mut cuda_types::cuda::CUmodule,
|
||||
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
|
||||
_arg3: *mut std::ffi::c_void,
|
||||
_arg4: *mut std::ffi::c_void,
|
||||
_arg5: u32,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn get_module_from_cubin_ext2(
|
||||
_fatbin_header: *const cuda_types::dark_api::FatbinHeader,
|
||||
_result: *mut cuda_types::cuda::CUmodule,
|
||||
_arg3: *mut std::ffi::c_void,
|
||||
_arg4: *mut std::ffi::c_void,
|
||||
_arg5: u32,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn get_unknown_buffer1(
|
||||
ptr: *mut *mut std::ffi::c_void,
|
||||
size: *mut usize,
|
||||
) -> () {
|
||||
*ptr = UNKNOWN_BUFFER1.buffer.get() as *mut std::ffi::c_void;
|
||||
*size = UNKNOWN_BUFFER1.len();
|
||||
}
|
||||
|
||||
unsafe extern "system" fn get_unknown_buffer2(
|
||||
ptr: *mut *mut std::ffi::c_void,
|
||||
size: *mut usize,
|
||||
) -> () {
|
||||
*ptr = UNKNOWN_BUFFER2.buffer.get() as *mut std::ffi::c_void;
|
||||
*size = UNKNOWN_BUFFER2.len();
|
||||
}
|
||||
|
||||
unsafe extern "system" fn context_local_storage_put(
|
||||
cu_ctx: CUcontext,
|
||||
key: *mut c_void,
|
||||
value: *mut c_void,
|
||||
dtor_cb: Option<extern "system" fn(CUcontext, *mut c_void, *mut c_void)>,
|
||||
) -> CUresult {
|
||||
let _ctx = if cu_ctx.0 != ptr::null_mut() {
|
||||
cu_ctx
|
||||
} else {
|
||||
let mut current_ctx: CUcontext = CUcontext(ptr::null_mut());
|
||||
context::get_current(&mut current_ctx)?;
|
||||
current_ctx
|
||||
};
|
||||
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
|
||||
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
|
||||
state.storage.insert(
|
||||
key as usize,
|
||||
context::StorageData {
|
||||
value: value as usize,
|
||||
reset_cb: dtor_cb,
|
||||
handle: _ctx,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe extern "system" fn context_local_storage_delete(
|
||||
cu_ctx: CUcontext,
|
||||
key: *mut c_void,
|
||||
) -> CUresult {
|
||||
let ctx_obj: &context::Context = FromCuda::from_cuda(&cu_ctx)?;
|
||||
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
|
||||
state.storage.remove(&(key as usize));
|
||||
Ok(())
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe extern "system" fn context_local_storage_get(
|
||||
value: *mut *mut c_void,
|
||||
cu_ctx: CUcontext,
|
||||
key: *mut c_void,
|
||||
) -> CUresult {
|
||||
let mut _ctx: CUcontext;
|
||||
if cu_ctx.0 == ptr::null_mut() {
|
||||
_ctx = context::get_current_context()?;
|
||||
} else {
|
||||
_ctx = cu_ctx
|
||||
};
|
||||
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
|
||||
ctx_obj.with_state(|state: &context::ContextState| {
|
||||
match state.storage.get(&(key as usize)) {
|
||||
Some(data) => *value = data.value as *mut c_void,
|
||||
None => return CUresult::ERROR_INVALID_HANDLE,
|
||||
}
|
||||
Ok(())
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe extern "system" fn ctx_create_v2_bypass(
|
||||
_pctx: *mut cuda_types::cuda::CUcontext,
|
||||
_flags: ::std::os::raw::c_uint,
|
||||
_dev: cuda_types::cuda::CUdevice,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn heap_alloc(
|
||||
_heap_alloc_record_ptr: *mut *const std::ffi::c_void,
|
||||
_arg2: usize,
|
||||
_arg3: usize,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn heap_free(
|
||||
_heap_alloc_record_ptr: *const std::ffi::c_void,
|
||||
_arg2: *mut usize,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn device_get_attribute_ext(
|
||||
_dev: cuda_types::cuda::CUdevice,
|
||||
_attribute: std::ffi::c_uint,
|
||||
_unknown: std::ffi::c_int,
|
||||
_result: *mut [usize; 2],
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn device_get_something(
|
||||
_result: *mut std::ffi::c_uchar,
|
||||
_dev: cuda_types::cuda::CUdevice,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
unsafe extern "system" fn integrity_check(
|
||||
version: u32,
|
||||
unix_seconds: u64,
|
||||
result: *mut [u64; 2],
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
let current_process = std::process::id();
|
||||
let current_thread = os::current_thread();
|
||||
|
||||
let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast();
|
||||
let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast();
|
||||
let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1];
|
||||
|
||||
let devices = get_device_hash_info()?;
|
||||
let device_count = devices.len() as u32;
|
||||
let get_device = |dev| devices[dev as usize];
|
||||
|
||||
let hash = ::dark_api::integrity_check(
|
||||
version,
|
||||
unix_seconds,
|
||||
cuda_types::cuda::CUDA_VERSION,
|
||||
current_process,
|
||||
current_thread,
|
||||
integrity_check_table,
|
||||
cudart_table,
|
||||
fn_address,
|
||||
device_count,
|
||||
get_device,
|
||||
);
|
||||
*result = hash;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe extern "system" fn context_check(
|
||||
_ctx_in: cuda_types::cuda::CUcontext,
|
||||
result1: *mut u32,
|
||||
_result2: *mut *const std::ffi::c_void,
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
*result1 = 0;
|
||||
CUresult::SUCCESS
|
||||
}
|
||||
|
||||
unsafe extern "system" fn check_fn3() -> u32 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
|
||||
let mut device_count = 0;
|
||||
device::get_count(&mut device_count)?;
|
||||
|
||||
(0..device_count)
|
||||
.map(|dev| {
|
||||
let mut guid = CUuuid_st { bytes: [0; 16] };
|
||||
unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? };
|
||||
|
||||
let mut pci_domain = 0;
|
||||
device::get_attribute(
|
||||
&mut pci_domain,
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID,
|
||||
dev,
|
||||
)?;
|
||||
|
||||
let mut pci_bus = 0;
|
||||
device::get_attribute(
|
||||
&mut pci_bus,
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID,
|
||||
dev,
|
||||
)?;
|
||||
|
||||
let mut pci_device = 0;
|
||||
device::get_attribute(
|
||||
&mut pci_device,
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID,
|
||||
dev,
|
||||
)?;
|
||||
|
||||
Ok(::dark_api::DeviceHashinfo {
|
||||
guid,
|
||||
pci_domain,
|
||||
pci_bus,
|
||||
pci_device,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable =
|
||||
::dark_api::cuda::CudaDarkApiGlobalTable::new::<DarkApi>();
|
||||
|
||||
pub(crate) fn get_export_table(
|
||||
pp_export_table: &mut *const ::core::ffi::c_void,
|
||||
p_export_table_id: &CUuuid,
|
||||
) -> CUresult {
|
||||
if let Some(table) = EXPORT_TABLE.get(p_export_table_id) {
|
||||
*pp_export_table = table.start();
|
||||
cuda_types::cuda::CUresult::SUCCESS
|
||||
} else {
|
||||
cuda_types::cuda::CUresult::ERROR_INVALID_VALUE
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
|
||||
*version = cuda_types::cuda::CUDA_VERSION as i32;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn get_proc_address(
|
||||
symbol: &CStr,
|
||||
pfn: &mut *mut ::core::ffi::c_void,
|
||||
cuda_version: ::core::ffi::c_int,
|
||||
flags: cuda_types::cuda::cuuint64_t,
|
||||
) -> CUresult {
|
||||
get_proc_address_v2(symbol, pfn, cuda_version, flags, None)
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn get_proc_address_v2(
|
||||
symbol: &CStr,
|
||||
pfn: &mut *mut ::core::ffi::c_void,
|
||||
cuda_version: ::core::ffi::c_int,
|
||||
flags: cuda_types::cuda::cuuint64_t,
|
||||
symbol_status: Option<&mut cuda_types::cuda::CUdriverProcAddressQueryResult>,
|
||||
) -> CUresult {
|
||||
// This implementation is mostly the same as cuGetProcAddress_v2 in zluda_dump. We may want to factor out the duplication at some point.
|
||||
fn raw_match(name: &[u8], flag: u64, version: i32) -> *mut ::core::ffi::c_void {
|
||||
use crate::*;
|
||||
include!("../../../zluda_bindgen/src/process_table.rs")
|
||||
}
|
||||
let fn_ptr = raw_match(symbol.to_bytes(), flags, cuda_version);
|
||||
match fn_ptr as usize {
|
||||
0 => {
|
||||
if let Some(symbol_status) = symbol_status {
|
||||
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND;
|
||||
}
|
||||
*pfn = ptr::null_mut();
|
||||
CUresult::ERROR_NOT_FOUND
|
||||
}
|
||||
usize::MAX => {
|
||||
if let Some(symbol_status) = symbol_status {
|
||||
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT;
|
||||
}
|
||||
*pfn = ptr::null_mut();
|
||||
CUresult::ERROR_NOT_FOUND
|
||||
}
|
||||
_ => {
|
||||
if let Some(symbol_status) = symbol_status {
|
||||
*symbol_status =
|
||||
cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SUCCESS;
|
||||
}
|
||||
*pfn = fn_ptr;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn profiler_start() -> CUresult {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn profiler_stop() -> CUresult {
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
// TODO: remove duplication with zluda_dump
|
||||
#[link(name = "pthread")]
|
||||
unsafe extern "C" {
|
||||
fn pthread_self() -> std::os::unix::thread::RawPthread;
|
||||
}
|
||||
|
||||
pub(crate) fn current_thread() -> u32 {
|
||||
(unsafe { pthread_self() }) as u32
|
||||
}
|
||||
// TODO: remove duplication with zluda_dump
|
||||
#[link(name = "pthread")]
|
||||
unsafe extern "C" {
|
||||
fn pthread_self() -> std::os::unix::thread::RawPthread;
|
||||
}
|
||||
|
||||
pub(crate) fn current_thread() -> u32 {
|
||||
(unsafe { pthread_self() }) as u32
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
// TODO: remove duplication with zluda_dump
|
||||
#[link(name = "kernel32")]
|
||||
unsafe extern "system" {
|
||||
fn GetCurrentThreadId() -> u32;
|
||||
}
|
||||
|
||||
pub(crate) fn current_thread() -> u32 {
|
||||
unsafe { GetCurrentThreadId() }
|
||||
}
|
||||
// TODO: remove duplication with zluda_dump
|
||||
#[link(name = "kernel32")]
|
||||
unsafe extern "system" {
|
||||
fn GetCurrentThreadId() -> u32;
|
||||
}
|
||||
|
||||
pub(crate) fn current_thread() -> u32 {
|
||||
unsafe { GetCurrentThreadId() }
|
||||
}
|
||||
|
@ -1,124 +1,124 @@
|
||||
use crate::os;
|
||||
use crate::{CudaFunctionName, ErrorEntry};
|
||||
use cuda_types::cuda::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::cell::RefMut;
|
||||
use std::hash::Hash;
|
||||
use std::{collections::hash_map, ffi::c_void, mem};
|
||||
|
||||
pub(crate) struct DarkApiState2 {
|
||||
// Key is Box<CUuuid, because thunk reporting unknown export table needs a
|
||||
// stable memory location for the guid
|
||||
pub(crate) overrides: FxHashMap<Box<CUuuidWrapper>, (*const *const c_void, Vec<*const c_void>)>,
|
||||
}
|
||||
|
||||
unsafe impl Send for DarkApiState2 {}
|
||||
unsafe impl Sync for DarkApiState2 {}
|
||||
|
||||
impl DarkApiState2 {
|
||||
pub(crate) fn new() -> Self {
|
||||
DarkApiState2 {
|
||||
overrides: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn override_export_table(
|
||||
&mut self,
|
||||
known_exports: &::dark_api::cuda::CudaDarkApiGlobalTable,
|
||||
original_export_table: *const *const c_void,
|
||||
guid: &CUuuid_st,
|
||||
) -> (*const *const c_void, Option<ErrorEntry>) {
|
||||
let entry = match self.overrides.entry(Box::new(CUuuidWrapper(*guid))) {
|
||||
hash_map::Entry::Occupied(entry) => {
|
||||
let (_, override_table) = entry.get();
|
||||
return (override_table.as_ptr(), None);
|
||||
}
|
||||
hash_map::Entry::Vacant(entry) => entry,
|
||||
};
|
||||
let mut error = None;
|
||||
let byte_size: usize = unsafe { *(original_export_table.cast::<usize>()) };
|
||||
// Some export tables don't start with a byte count, but directly with a
|
||||
// pointer, and are instead terminated by 0 or MAX
|
||||
let export_functions_start_idx;
|
||||
let export_functions_size;
|
||||
if byte_size > 0x10000 {
|
||||
export_functions_start_idx = 0;
|
||||
let mut i = 0;
|
||||
loop {
|
||||
let current_ptr = unsafe { original_export_table.add(i) };
|
||||
let current_ptr_numeric = unsafe { *current_ptr } as usize;
|
||||
if current_ptr_numeric == 0usize || current_ptr_numeric == usize::MAX {
|
||||
export_functions_size = i;
|
||||
break;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
} else {
|
||||
export_functions_start_idx = 1;
|
||||
export_functions_size = byte_size / mem::size_of::<usize>();
|
||||
}
|
||||
let our_functions = known_exports.get(guid);
|
||||
if let Some(ref our_functions) = our_functions {
|
||||
if our_functions.len() != export_functions_size {
|
||||
error = Some(ErrorEntry::UnexpectedExportTableSize {
|
||||
expected: our_functions.len(),
|
||||
computed: export_functions_size,
|
||||
});
|
||||
}
|
||||
}
|
||||
let mut override_table =
|
||||
unsafe { std::slice::from_raw_parts(original_export_table, export_functions_size) }
|
||||
.to_vec();
|
||||
for i in export_functions_start_idx..export_functions_size {
|
||||
let current_fn = (|| {
|
||||
if let Some(ref our_functions) = our_functions {
|
||||
if let Some(fn_) = our_functions.get_fn(i) {
|
||||
return fn_;
|
||||
}
|
||||
}
|
||||
os::get_thunk(
|
||||
override_table[i],
|
||||
Self::report_unknown_export_table_call,
|
||||
std::ptr::from_ref(entry.key().as_ref()).cast(),
|
||||
i,
|
||||
)
|
||||
})();
|
||||
override_table[i] = current_fn;
|
||||
}
|
||||
(
|
||||
entry
|
||||
.insert((original_export_table, override_table))
|
||||
.1
|
||||
.as_ptr(),
|
||||
error,
|
||||
)
|
||||
}
|
||||
|
||||
unsafe extern "system" fn report_unknown_export_table_call(guid: &CUuuid, index: usize) {
|
||||
let global_state = crate::GLOBAL_STATE2.lock();
|
||||
let global_state_ref_cell = &*global_state;
|
||||
let mut global_state_ref_mut = global_state_ref_cell.borrow_mut();
|
||||
let global_state = &mut *global_state_ref_mut;
|
||||
let log_guard = crate::OuterCallGuard {
|
||||
writer: &mut global_state.log_writer,
|
||||
log_root: &global_state.log_stack,
|
||||
};
|
||||
{
|
||||
let mut logger = RefMut::map(global_state.log_stack.borrow_mut(), |log_stack| {
|
||||
log_stack.enter()
|
||||
});
|
||||
logger.name = CudaFunctionName::Dark { guid: *guid, index };
|
||||
};
|
||||
drop(log_guard);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq)]
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct CUuuidWrapper(pub CUuuid);
|
||||
|
||||
impl Hash for CUuuidWrapper {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.0.bytes.hash(state);
|
||||
}
|
||||
}
|
||||
use crate::os;
|
||||
use crate::{CudaFunctionName, ErrorEntry};
|
||||
use cuda_types::cuda::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::cell::RefMut;
|
||||
use std::hash::Hash;
|
||||
use std::{collections::hash_map, ffi::c_void, mem};
|
||||
|
||||
pub(crate) struct DarkApiState2 {
|
||||
// Key is Box<CUuuid, because thunk reporting unknown export table needs a
|
||||
// stable memory location for the guid
|
||||
pub(crate) overrides: FxHashMap<Box<CUuuidWrapper>, (*const *const c_void, Vec<*const c_void>)>,
|
||||
}
|
||||
|
||||
unsafe impl Send for DarkApiState2 {}
|
||||
unsafe impl Sync for DarkApiState2 {}
|
||||
|
||||
impl DarkApiState2 {
|
||||
pub(crate) fn new() -> Self {
|
||||
DarkApiState2 {
|
||||
overrides: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn override_export_table(
|
||||
&mut self,
|
||||
known_exports: &::dark_api::cuda::CudaDarkApiGlobalTable,
|
||||
original_export_table: *const *const c_void,
|
||||
guid: &CUuuid_st,
|
||||
) -> (*const *const c_void, Option<ErrorEntry>) {
|
||||
let entry = match self.overrides.entry(Box::new(CUuuidWrapper(*guid))) {
|
||||
hash_map::Entry::Occupied(entry) => {
|
||||
let (_, override_table) = entry.get();
|
||||
return (override_table.as_ptr(), None);
|
||||
}
|
||||
hash_map::Entry::Vacant(entry) => entry,
|
||||
};
|
||||
let mut error = None;
|
||||
let byte_size: usize = unsafe { *(original_export_table.cast::<usize>()) };
|
||||
// Some export tables don't start with a byte count, but directly with a
|
||||
// pointer, and are instead terminated by 0 or MAX
|
||||
let export_functions_start_idx;
|
||||
let export_functions_size;
|
||||
if byte_size > 0x10000 {
|
||||
export_functions_start_idx = 0;
|
||||
let mut i = 0;
|
||||
loop {
|
||||
let current_ptr = unsafe { original_export_table.add(i) };
|
||||
let current_ptr_numeric = unsafe { *current_ptr } as usize;
|
||||
if current_ptr_numeric == 0usize || current_ptr_numeric == usize::MAX {
|
||||
export_functions_size = i;
|
||||
break;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
} else {
|
||||
export_functions_start_idx = 1;
|
||||
export_functions_size = byte_size / mem::size_of::<usize>();
|
||||
}
|
||||
let our_functions = known_exports.get(guid);
|
||||
if let Some(ref our_functions) = our_functions {
|
||||
if our_functions.len() != export_functions_size {
|
||||
error = Some(ErrorEntry::UnexpectedExportTableSize {
|
||||
expected: our_functions.len(),
|
||||
computed: export_functions_size,
|
||||
});
|
||||
}
|
||||
}
|
||||
let mut override_table =
|
||||
unsafe { std::slice::from_raw_parts(original_export_table, export_functions_size) }
|
||||
.to_vec();
|
||||
for i in export_functions_start_idx..export_functions_size {
|
||||
let current_fn = (|| {
|
||||
if let Some(ref our_functions) = our_functions {
|
||||
if let Some(fn_) = our_functions.get_fn(i) {
|
||||
return fn_;
|
||||
}
|
||||
}
|
||||
os::get_thunk(
|
||||
override_table[i],
|
||||
Self::report_unknown_export_table_call,
|
||||
std::ptr::from_ref(entry.key().as_ref()).cast(),
|
||||
i,
|
||||
)
|
||||
})();
|
||||
override_table[i] = current_fn;
|
||||
}
|
||||
(
|
||||
entry
|
||||
.insert((original_export_table, override_table))
|
||||
.1
|
||||
.as_ptr(),
|
||||
error,
|
||||
)
|
||||
}
|
||||
|
||||
unsafe extern "system" fn report_unknown_export_table_call(guid: &CUuuid, index: usize) {
|
||||
let global_state = crate::GLOBAL_STATE2.lock();
|
||||
let global_state_ref_cell = &*global_state;
|
||||
let mut global_state_ref_mut = global_state_ref_cell.borrow_mut();
|
||||
let global_state = &mut *global_state_ref_mut;
|
||||
let log_guard = crate::OuterCallGuard {
|
||||
writer: &mut global_state.log_writer,
|
||||
log_root: &global_state.log_stack,
|
||||
};
|
||||
{
|
||||
let mut logger = RefMut::map(global_state.log_stack.borrow_mut(), |log_stack| {
|
||||
log_stack.enter()
|
||||
});
|
||||
logger.name = CudaFunctionName::Dark { guid: *guid, index };
|
||||
};
|
||||
drop(log_guard);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq)]
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct CUuuidWrapper(pub CUuuid);
|
||||
|
||||
impl Hash for CUuuidWrapper {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.0.bytes.hash(state);
|
||||
}
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,81 +1,81 @@
|
||||
use cuda_types::cuda::CUuuid;
|
||||
use std::ffi::{c_void, CStr, CString};
|
||||
use std::mem;
|
||||
|
||||
pub(crate) const LIBCUDA_DEFAULT_PATH: &str = "/usr/lib/x86_64-linux-gnu/libcuda.so.1";
|
||||
|
||||
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
|
||||
let libcuda_path = CString::new(libcuda_path).unwrap();
|
||||
libc::dlopen(
|
||||
libcuda_path.as_ptr() as *const _,
|
||||
libc::RTLD_LOCAL | libc::RTLD_NOW,
|
||||
)
|
||||
}
|
||||
|
||||
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
|
||||
libc::dlsym(handle, func.as_ptr() as *const _)
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! os_log {
|
||||
($format:tt) => {
|
||||
{
|
||||
eprintln!("[ZLUDA_DUMP] {}", format!($format));
|
||||
}
|
||||
};
|
||||
($format:tt, $($obj: expr),+) => {
|
||||
{
|
||||
eprintln!("[ZLUDA_DUMP] {}", format!($format, $($obj,)+));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
//RDI, RSI, RDX, RCX, R8, R9
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub fn get_thunk(
|
||||
original_fn: *const c_void,
|
||||
report_fn: unsafe extern "system" fn(&CUuuid, usize),
|
||||
guid: *const CUuuid,
|
||||
idx: usize,
|
||||
) -> *const c_void {
|
||||
use dynasmrt::{dynasm, DynasmApi};
|
||||
let mut ops = dynasmrt::x64::Assembler::new().unwrap();
|
||||
let start = ops.offset();
|
||||
dynasm!(ops
|
||||
// stack alignment
|
||||
; sub rsp, 8
|
||||
; push rdi
|
||||
; push rsi
|
||||
; push rdx
|
||||
; push rcx
|
||||
; push r8
|
||||
; push r9
|
||||
; mov rdi, QWORD guid as i64
|
||||
; mov rsi, QWORD idx as i64
|
||||
; mov rax, QWORD report_fn as i64
|
||||
; call rax
|
||||
; pop r9
|
||||
; pop r8
|
||||
; pop rcx
|
||||
; pop rdx
|
||||
; pop rsi
|
||||
; pop rdi
|
||||
; add rsp, 8
|
||||
; mov rax, QWORD original_fn as i64
|
||||
; jmp rax
|
||||
; int 3
|
||||
);
|
||||
let exe_buf = ops.finalize().unwrap();
|
||||
let result_fn = exe_buf.ptr(start);
|
||||
mem::forget(exe_buf);
|
||||
result_fn as *const _
|
||||
}
|
||||
|
||||
#[link(name = "pthread")]
|
||||
unsafe extern "C" {
|
||||
fn pthread_self() -> std::os::unix::thread::RawPthread;
|
||||
}
|
||||
|
||||
pub(crate) fn current_thread() -> u32 {
|
||||
(unsafe { pthread_self() }) as u32
|
||||
}
|
||||
use cuda_types::cuda::CUuuid;
|
||||
use std::ffi::{c_void, CStr, CString};
|
||||
use std::mem;
|
||||
|
||||
pub(crate) const LIBCUDA_DEFAULT_PATH: &str = "/usr/lib/x86_64-linux-gnu/libcuda.so.1";
|
||||
|
||||
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
|
||||
let libcuda_path = CString::new(libcuda_path).unwrap();
|
||||
libc::dlopen(
|
||||
libcuda_path.as_ptr() as *const _,
|
||||
libc::RTLD_LOCAL | libc::RTLD_NOW,
|
||||
)
|
||||
}
|
||||
|
||||
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
|
||||
libc::dlsym(handle, func.as_ptr() as *const _)
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! os_log {
|
||||
($format:tt) => {
|
||||
{
|
||||
eprintln!("[ZLUDA_DUMP] {}", format!($format));
|
||||
}
|
||||
};
|
||||
($format:tt, $($obj: expr),+) => {
|
||||
{
|
||||
eprintln!("[ZLUDA_DUMP] {}", format!($format, $($obj,)+));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
//RDI, RSI, RDX, RCX, R8, R9
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub fn get_thunk(
|
||||
original_fn: *const c_void,
|
||||
report_fn: unsafe extern "system" fn(&CUuuid, usize),
|
||||
guid: *const CUuuid,
|
||||
idx: usize,
|
||||
) -> *const c_void {
|
||||
use dynasmrt::{dynasm, DynasmApi};
|
||||
let mut ops = dynasmrt::x64::Assembler::new().unwrap();
|
||||
let start = ops.offset();
|
||||
dynasm!(ops
|
||||
// stack alignment
|
||||
; sub rsp, 8
|
||||
; push rdi
|
||||
; push rsi
|
||||
; push rdx
|
||||
; push rcx
|
||||
; push r8
|
||||
; push r9
|
||||
; mov rdi, QWORD guid as i64
|
||||
; mov rsi, QWORD idx as i64
|
||||
; mov rax, QWORD report_fn as i64
|
||||
; call rax
|
||||
; pop r9
|
||||
; pop r8
|
||||
; pop rcx
|
||||
; pop rdx
|
||||
; pop rsi
|
||||
; pop rdi
|
||||
; add rsp, 8
|
||||
; mov rax, QWORD original_fn as i64
|
||||
; jmp rax
|
||||
; int 3
|
||||
);
|
||||
let exe_buf = ops.finalize().unwrap();
|
||||
let result_fn = exe_buf.ptr(start);
|
||||
mem::forget(exe_buf);
|
||||
result_fn as *const _
|
||||
}
|
||||
|
||||
#[link(name = "pthread")]
|
||||
unsafe extern "C" {
|
||||
fn pthread_self() -> std::os::unix::thread::RawPthread;
|
||||
}
|
||||
|
||||
pub(crate) fn current_thread() -> u32 {
|
||||
(unsafe { pthread_self() }) as u32
|
||||
}
|
||||
|
@ -1,190 +1,190 @@
|
||||
use std::{
|
||||
ffi::{c_void, CStr},
|
||||
mem, ptr,
|
||||
sync::LazyLock,
|
||||
};
|
||||
|
||||
use std::os::windows::io::AsRawHandle;
|
||||
use winapi::{
|
||||
shared::minwindef::{FARPROC, HMODULE},
|
||||
um::debugapi::OutputDebugStringA,
|
||||
um::libloaderapi::{GetProcAddress, LoadLibraryW},
|
||||
};
|
||||
|
||||
use cuda_types::cuda::CUuuid;
|
||||
|
||||
pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll";
|
||||
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
|
||||
const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0";
|
||||
|
||||
static PLATFORM_LIBRARY: LazyLock<PlatformLibrary> =
|
||||
LazyLock::new(|| unsafe { PlatformLibrary::new() });
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
struct PlatformLibrary {
|
||||
LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE,
|
||||
GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC,
|
||||
}
|
||||
|
||||
impl PlatformLibrary {
|
||||
#[allow(non_snake_case)]
|
||||
unsafe fn new() -> Self {
|
||||
let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() {
|
||||
None => (
|
||||
LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE,
|
||||
mem::transmute(
|
||||
GetProcAddress
|
||||
as unsafe extern "system" fn(
|
||||
hModule: HMODULE,
|
||||
lpProcName: *const i8,
|
||||
) -> FARPROC,
|
||||
),
|
||||
),
|
||||
Some(zluda_with) => (
|
||||
mem::transmute(GetProcAddress(
|
||||
zluda_with,
|
||||
LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _,
|
||||
)),
|
||||
mem::transmute(GetProcAddress(
|
||||
zluda_with,
|
||||
GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _,
|
||||
)),
|
||||
),
|
||||
};
|
||||
PlatformLibrary {
|
||||
LoadLibraryW,
|
||||
GetProcAddress,
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn get_detourer_module() -> Option<HMODULE> {
|
||||
let mut module = ptr::null_mut();
|
||||
loop {
|
||||
module = detours_sys::DetourEnumerateModules(module);
|
||||
if module == ptr::null_mut() {
|
||||
break;
|
||||
}
|
||||
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
|
||||
if payload != ptr::null_mut() {
|
||||
return Some(module as _);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
|
||||
let libcuda_path_uf16 = libcuda_path
|
||||
.encode_utf16()
|
||||
.chain(std::iter::once(0))
|
||||
.collect::<Vec<_>>();
|
||||
(PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _
|
||||
}
|
||||
|
||||
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
|
||||
(PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! os_log {
|
||||
($format:tt) => {
|
||||
{
|
||||
use crate::os::__log_impl;
|
||||
__log_impl(format!($format));
|
||||
}
|
||||
};
|
||||
($format:tt, $($obj: expr),+) => {
|
||||
{
|
||||
use crate::os::__log_impl;
|
||||
__log_impl(format!($format, $($obj,)+));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn __log_impl(s: String) {
|
||||
let log_to_stderr = std::io::stderr().as_raw_handle() != ptr::null_mut();
|
||||
if log_to_stderr {
|
||||
eprintln!("[ZLUDA_DUMP] {}", s);
|
||||
} else {
|
||||
let mut win_str = String::with_capacity("[ZLUDA_DUMP] ".len() + s.len() + 2);
|
||||
win_str.push_str("[ZLUDA_DUMP] ");
|
||||
win_str.push_str(&s);
|
||||
win_str.push_str("\n\0");
|
||||
unsafe { OutputDebugStringA(win_str.as_ptr() as *const _) };
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86")]
|
||||
pub fn get_thunk(
|
||||
original_fn: *const c_void,
|
||||
report_fn: unsafe extern "system" fn(&CUuuid, usize),
|
||||
guid: *const CUuuid,
|
||||
idx: usize,
|
||||
) -> *const c_void {
|
||||
use dynasmrt::{dynasm, DynasmApi};
|
||||
let mut ops = dynasmrt::x86::Assembler::new().unwrap();
|
||||
let start = ops.offset();
|
||||
dynasm!(ops
|
||||
; .arch x86
|
||||
; push idx as i32
|
||||
; push guid as i32
|
||||
; mov eax, report_fn as i32
|
||||
; call eax
|
||||
; mov eax, original_fn as i32
|
||||
; jmp eax
|
||||
; int 3
|
||||
);
|
||||
let exe_buf = ops.finalize().unwrap();
|
||||
let result_fn = exe_buf.ptr(start);
|
||||
mem::forget(exe_buf);
|
||||
result_fn as *const _
|
||||
}
|
||||
|
||||
//RCX, RDX, R8, R9
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub fn get_thunk(
|
||||
original_fn: *const c_void,
|
||||
report_fn: unsafe extern "system" fn(&CUuuid, usize),
|
||||
guid: *const CUuuid,
|
||||
idx: usize,
|
||||
) -> *const c_void {
|
||||
use dynasmrt::{dynasm, DynasmApi};
|
||||
let mut ops = dynasmrt::x86::Assembler::new().unwrap();
|
||||
let start = ops.offset();
|
||||
// Let's hope there's never more than 4 arguments
|
||||
dynasm!(ops
|
||||
; .arch x64
|
||||
; push rbp
|
||||
; mov rbp, rsp
|
||||
; push rcx
|
||||
; push rdx
|
||||
; push r8
|
||||
; push r9
|
||||
; mov rcx, QWORD guid as i64
|
||||
; mov rdx, QWORD idx as i64
|
||||
; mov rax, QWORD report_fn as i64
|
||||
; call rax
|
||||
; pop r9
|
||||
; pop r8
|
||||
; pop rdx
|
||||
; pop rcx
|
||||
; mov rax, QWORD original_fn as i64
|
||||
; call rax
|
||||
; pop rbp
|
||||
; ret
|
||||
; int 3
|
||||
);
|
||||
let exe_buf = ops.finalize().unwrap();
|
||||
let result_fn = exe_buf.ptr(start);
|
||||
mem::forget(exe_buf);
|
||||
result_fn as *const _
|
||||
}
|
||||
|
||||
#[link(name = "kernel32")]
|
||||
unsafe extern "system" {
|
||||
fn GetCurrentThreadId() -> u32;
|
||||
}
|
||||
|
||||
pub(crate) fn current_thread() -> u32 {
|
||||
unsafe { GetCurrentThreadId() }
|
||||
}
|
||||
use std::{
|
||||
ffi::{c_void, CStr},
|
||||
mem, ptr,
|
||||
sync::LazyLock,
|
||||
};
|
||||
|
||||
use std::os::windows::io::AsRawHandle;
|
||||
use winapi::{
|
||||
shared::minwindef::{FARPROC, HMODULE},
|
||||
um::debugapi::OutputDebugStringA,
|
||||
um::libloaderapi::{GetProcAddress, LoadLibraryW},
|
||||
};
|
||||
|
||||
use cuda_types::cuda::CUuuid;
|
||||
|
||||
pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll";
|
||||
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
|
||||
const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0";
|
||||
|
||||
static PLATFORM_LIBRARY: LazyLock<PlatformLibrary> =
|
||||
LazyLock::new(|| unsafe { PlatformLibrary::new() });
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
struct PlatformLibrary {
|
||||
LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE,
|
||||
GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC,
|
||||
}
|
||||
|
||||
impl PlatformLibrary {
|
||||
#[allow(non_snake_case)]
|
||||
unsafe fn new() -> Self {
|
||||
let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() {
|
||||
None => (
|
||||
LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE,
|
||||
mem::transmute(
|
||||
GetProcAddress
|
||||
as unsafe extern "system" fn(
|
||||
hModule: HMODULE,
|
||||
lpProcName: *const i8,
|
||||
) -> FARPROC,
|
||||
),
|
||||
),
|
||||
Some(zluda_with) => (
|
||||
mem::transmute(GetProcAddress(
|
||||
zluda_with,
|
||||
LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _,
|
||||
)),
|
||||
mem::transmute(GetProcAddress(
|
||||
zluda_with,
|
||||
GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _,
|
||||
)),
|
||||
),
|
||||
};
|
||||
PlatformLibrary {
|
||||
LoadLibraryW,
|
||||
GetProcAddress,
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn get_detourer_module() -> Option<HMODULE> {
|
||||
let mut module = ptr::null_mut();
|
||||
loop {
|
||||
module = detours_sys::DetourEnumerateModules(module);
|
||||
if module == ptr::null_mut() {
|
||||
break;
|
||||
}
|
||||
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
|
||||
if payload != ptr::null_mut() {
|
||||
return Some(module as _);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
|
||||
let libcuda_path_uf16 = libcuda_path
|
||||
.encode_utf16()
|
||||
.chain(std::iter::once(0))
|
||||
.collect::<Vec<_>>();
|
||||
(PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _
|
||||
}
|
||||
|
||||
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
|
||||
(PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! os_log {
|
||||
($format:tt) => {
|
||||
{
|
||||
use crate::os::__log_impl;
|
||||
__log_impl(format!($format));
|
||||
}
|
||||
};
|
||||
($format:tt, $($obj: expr),+) => {
|
||||
{
|
||||
use crate::os::__log_impl;
|
||||
__log_impl(format!($format, $($obj,)+));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn __log_impl(s: String) {
|
||||
let log_to_stderr = std::io::stderr().as_raw_handle() != ptr::null_mut();
|
||||
if log_to_stderr {
|
||||
eprintln!("[ZLUDA_DUMP] {}", s);
|
||||
} else {
|
||||
let mut win_str = String::with_capacity("[ZLUDA_DUMP] ".len() + s.len() + 2);
|
||||
win_str.push_str("[ZLUDA_DUMP] ");
|
||||
win_str.push_str(&s);
|
||||
win_str.push_str("\n\0");
|
||||
unsafe { OutputDebugStringA(win_str.as_ptr() as *const _) };
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86")]
|
||||
pub fn get_thunk(
|
||||
original_fn: *const c_void,
|
||||
report_fn: unsafe extern "system" fn(&CUuuid, usize),
|
||||
guid: *const CUuuid,
|
||||
idx: usize,
|
||||
) -> *const c_void {
|
||||
use dynasmrt::{dynasm, DynasmApi};
|
||||
let mut ops = dynasmrt::x86::Assembler::new().unwrap();
|
||||
let start = ops.offset();
|
||||
dynasm!(ops
|
||||
; .arch x86
|
||||
; push idx as i32
|
||||
; push guid as i32
|
||||
; mov eax, report_fn as i32
|
||||
; call eax
|
||||
; mov eax, original_fn as i32
|
||||
; jmp eax
|
||||
; int 3
|
||||
);
|
||||
let exe_buf = ops.finalize().unwrap();
|
||||
let result_fn = exe_buf.ptr(start);
|
||||
mem::forget(exe_buf);
|
||||
result_fn as *const _
|
||||
}
|
||||
|
||||
//RCX, RDX, R8, R9
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub fn get_thunk(
|
||||
original_fn: *const c_void,
|
||||
report_fn: unsafe extern "system" fn(&CUuuid, usize),
|
||||
guid: *const CUuuid,
|
||||
idx: usize,
|
||||
) -> *const c_void {
|
||||
use dynasmrt::{dynasm, DynasmApi};
|
||||
let mut ops = dynasmrt::x86::Assembler::new().unwrap();
|
||||
let start = ops.offset();
|
||||
// Let's hope there's never more than 4 arguments
|
||||
dynasm!(ops
|
||||
; .arch x64
|
||||
; push rbp
|
||||
; mov rbp, rsp
|
||||
; push rcx
|
||||
; push rdx
|
||||
; push r8
|
||||
; push r9
|
||||
; mov rcx, QWORD guid as i64
|
||||
; mov rdx, QWORD idx as i64
|
||||
; mov rax, QWORD report_fn as i64
|
||||
; call rax
|
||||
; pop r9
|
||||
; pop r8
|
||||
; pop rdx
|
||||
; pop rcx
|
||||
; mov rax, QWORD original_fn as i64
|
||||
; call rax
|
||||
; pop rbp
|
||||
; ret
|
||||
; int 3
|
||||
);
|
||||
let exe_buf = ops.finalize().unwrap();
|
||||
let result_fn = exe_buf.ptr(start);
|
||||
mem::forget(exe_buf);
|
||||
result_fn as *const _
|
||||
}
|
||||
|
||||
#[link(name = "kernel32")]
|
||||
unsafe extern "system" {
|
||||
fn GetCurrentThreadId() -> u32;
|
||||
}
|
||||
|
||||
pub(crate) fn current_thread() -> u32 {
|
||||
unsafe { GetCurrentThreadId() }
|
||||
}
|
||||
|
@ -1,334 +1,334 @@
|
||||
use crate::{
|
||||
log::{self, UInt},
|
||||
trace, ErrorEntry, FnCallLog, Settings,
|
||||
};
|
||||
use cuda_types::{
|
||||
cuda::*,
|
||||
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper},
|
||||
};
|
||||
use dark_api::fatbin::{
|
||||
decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule,
|
||||
};
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
ffi::{c_void, CStr, CString},
|
||||
fs::{self, File},
|
||||
io::{self, Read, Write},
|
||||
path::PathBuf,
|
||||
};
|
||||
use unwrap_or::unwrap_some_or;
|
||||
|
||||
// This struct is the heart of CUDA state tracking, it:
|
||||
// * receives calls from the probes about changes to CUDA state
|
||||
// * records updates to the state change
|
||||
// * writes out relevant state change and details to disk and log
|
||||
pub(crate) struct StateTracker {
|
||||
writer: DumpWriter,
|
||||
pub(crate) libraries: FxHashMap<CUlibrary, CodePointer>,
|
||||
saved_modules: FxHashSet<CUmodule>,
|
||||
module_counter: usize,
|
||||
submodule_counter: usize,
|
||||
pub(crate) override_cc: Option<(u32, u32)>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) struct CodePointer(pub *const c_void);
|
||||
|
||||
unsafe impl Send for CodePointer {}
|
||||
unsafe impl Sync for CodePointer {}
|
||||
|
||||
impl StateTracker {
|
||||
pub(crate) fn new(settings: &Settings) -> Self {
|
||||
StateTracker {
|
||||
writer: DumpWriter::new(settings.dump_dir.clone()),
|
||||
libraries: FxHashMap::default(),
|
||||
saved_modules: FxHashSet::default(),
|
||||
module_counter: 0,
|
||||
submodule_counter: 0,
|
||||
override_cc: settings.override_cc,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn record_new_module_file(
|
||||
&mut self,
|
||||
module: CUmodule,
|
||||
file_name: *const i8,
|
||||
fn_logger: &mut FnCallLog,
|
||||
) {
|
||||
let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() {
|
||||
Ok(f) => f,
|
||||
Err(err) => {
|
||||
fn_logger.log(log::ErrorEntry::MalformedModulePath(err));
|
||||
return;
|
||||
}
|
||||
};
|
||||
let maybe_io_error = self.try_record_new_module_file(module, fn_logger, file_name);
|
||||
fn_logger.log_io_error(maybe_io_error)
|
||||
}
|
||||
|
||||
fn try_record_new_module_file(
|
||||
&mut self,
|
||||
module: CUmodule,
|
||||
fn_logger: &mut FnCallLog,
|
||||
file_name: &str,
|
||||
) -> io::Result<()> {
|
||||
let mut module_file = fs::File::open(file_name)?;
|
||||
let mut read_buff = Vec::new();
|
||||
module_file.read_to_end(&mut read_buff)?;
|
||||
self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn record_new_submodule(
|
||||
&mut self,
|
||||
module: CUmodule,
|
||||
submodule: &[u8],
|
||||
fn_logger: &mut FnCallLog,
|
||||
type_: &'static str,
|
||||
) {
|
||||
if self.saved_modules.insert(module) {
|
||||
self.module_counter += 1;
|
||||
self.submodule_counter = 0;
|
||||
}
|
||||
self.submodule_counter += 1;
|
||||
fn_logger.log_io_error(self.writer.save_module(
|
||||
self.module_counter,
|
||||
Some(self.submodule_counter),
|
||||
submodule,
|
||||
type_,
|
||||
));
|
||||
if type_ == "ptx" {
|
||||
match CString::new(submodule) {
|
||||
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
|
||||
Ok(submodule_cstring) => match submodule_cstring.to_str() {
|
||||
Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)),
|
||||
Ok(submodule_text) => self.try_parse_and_record_kernels(
|
||||
fn_logger,
|
||||
self.module_counter,
|
||||
Some(self.submodule_counter),
|
||||
submodule_text,
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn record_new_module(
|
||||
&mut self,
|
||||
module: CUmodule,
|
||||
raw_image: *const c_void,
|
||||
fn_logger: &mut FnCallLog,
|
||||
) {
|
||||
self.module_counter += 1;
|
||||
if unsafe { *(raw_image as *const [u8; 4]) } == *goblin::elf64::header::ELFMAG {
|
||||
self.saved_modules.insert(module);
|
||||
// TODO: Parse ELF and write it to disk
|
||||
fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
||||
module,
|
||||
raw_image,
|
||||
kind: "ELF",
|
||||
})
|
||||
} else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC {
|
||||
self.saved_modules.insert(module);
|
||||
// TODO: Figure out how to get size of archive module and write it to disk
|
||||
fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
||||
module,
|
||||
raw_image,
|
||||
kind: "archive",
|
||||
})
|
||||
} else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC {
|
||||
unsafe {
|
||||
fn_logger.try_(|fn_logger| {
|
||||
trace::record_submodules_from_wrapped_fatbin(
|
||||
module,
|
||||
raw_image as *const FatbincWrapper,
|
||||
fn_logger,
|
||||
self,
|
||||
)
|
||||
});
|
||||
}
|
||||
} else {
|
||||
self.record_module_ptx(module, raw_image, fn_logger)
|
||||
}
|
||||
}
|
||||
|
||||
fn record_module_ptx(
|
||||
&mut self,
|
||||
module: CUmodule,
|
||||
raw_image: *const c_void,
|
||||
fn_logger: &mut FnCallLog,
|
||||
) {
|
||||
self.saved_modules.insert(module);
|
||||
let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str();
|
||||
let module_text = match module_text {
|
||||
Ok(m) => m,
|
||||
Err(utf8_err) => {
|
||||
fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err));
|
||||
return;
|
||||
}
|
||||
};
|
||||
fn_logger.log_io_error(self.writer.save_module(
|
||||
self.module_counter,
|
||||
None,
|
||||
module_text.as_bytes(),
|
||||
"ptx",
|
||||
));
|
||||
self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text);
|
||||
}
|
||||
|
||||
fn try_parse_and_record_kernels(
|
||||
&mut self,
|
||||
fn_logger: &mut FnCallLog,
|
||||
module_index: usize,
|
||||
submodule_index: Option<usize>,
|
||||
module_text: &str,
|
||||
) {
|
||||
let errors = ptx_parser::parse_for_errors(module_text);
|
||||
if !errors.is_empty() {
|
||||
fn_logger.log(log::ErrorEntry::ModuleParsingError(
|
||||
DumpWriter::get_file_name(module_index, submodule_index, "log"),
|
||||
));
|
||||
fn_logger.log_io_error(self.writer.save_module_error_log(
|
||||
module_index,
|
||||
submodule_index,
|
||||
&*errors,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This structs writes out information about CUDA execution to the dump dir
|
||||
struct DumpWriter {
|
||||
dump_dir: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl DumpWriter {
|
||||
fn new(dump_dir: Option<PathBuf>) -> Self {
|
||||
Self { dump_dir }
|
||||
}
|
||||
|
||||
fn save_module(
|
||||
&self,
|
||||
module_index: usize,
|
||||
submodule_index: Option<usize>,
|
||||
buffer: &[u8],
|
||||
kind: &'static str,
|
||||
) -> io::Result<()> {
|
||||
let mut dump_file = match &self.dump_dir {
|
||||
None => return Ok(()),
|
||||
Some(d) => d.clone(),
|
||||
};
|
||||
dump_file.push(Self::get_file_name(module_index, submodule_index, kind));
|
||||
let mut file = File::create(dump_file)?;
|
||||
file.write_all(buffer)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn save_module_error_log<'input>(
|
||||
&self,
|
||||
module_index: usize,
|
||||
submodule_index: Option<usize>,
|
||||
errors: &[ptx_parser::PtxError<'input>],
|
||||
) -> io::Result<()> {
|
||||
let mut log_file = match &self.dump_dir {
|
||||
None => return Ok(()),
|
||||
Some(d) => d.clone(),
|
||||
};
|
||||
log_file.push(Self::get_file_name(module_index, submodule_index, "log"));
|
||||
let mut file = File::create(log_file)?;
|
||||
for error in errors {
|
||||
writeln!(file, "{}", error)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_file_name(module_index: usize, submodule_index: Option<usize>, kind: &str) -> String {
|
||||
match submodule_index {
|
||||
None => {
|
||||
format!("module_{:04}.{:02}", module_index, kind)
|
||||
}
|
||||
Some(submodule_index) => {
|
||||
format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
|
||||
module: CUmodule,
|
||||
fatbinc_wrapper: *const FatbincWrapper,
|
||||
fn_logger: &mut FnCallLog,
|
||||
state: &mut StateTracker,
|
||||
) -> Result<(), ErrorEntry> {
|
||||
let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
|
||||
let mut submodules = fatbin.get_submodules()?;
|
||||
while let Some(current) = submodules.next()? {
|
||||
record_submodules_from_fatbin(module, current, fn_logger, state)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn record_submodules_from_fatbin(
|
||||
module: CUmodule,
|
||||
submodule: FatbinSubmodule,
|
||||
logger: &mut FnCallLog,
|
||||
state: &mut StateTracker,
|
||||
) -> Result<(), ErrorEntry> {
|
||||
record_submodules(module, logger, state, submodule.get_files())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn record_submodules(
|
||||
module: CUmodule,
|
||||
fn_logger: &mut FnCallLog,
|
||||
state: &mut StateTracker,
|
||||
mut files: FatbinFileIterator,
|
||||
) -> Result<(), ErrorEntry> {
|
||||
while let Some(file) = files.next()? {
|
||||
let mut payload = if file
|
||||
.header
|
||||
.flags
|
||||
.contains(FatbinFileHeaderFlags::CompressedLz4)
|
||||
{
|
||||
Cow::Owned(unwrap_some_or!(
|
||||
fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())),
|
||||
continue
|
||||
))
|
||||
} else if file
|
||||
.header
|
||||
.flags
|
||||
.contains(FatbinFileHeaderFlags::CompressedZstd)
|
||||
{
|
||||
Cow::Owned(unwrap_some_or!(
|
||||
fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())),
|
||||
continue
|
||||
))
|
||||
} else {
|
||||
Cow::Borrowed(file.get_payload())
|
||||
};
|
||||
match file.header.kind {
|
||||
FatbinFileHeader::HEADER_KIND_PTX => {
|
||||
while payload.last() == Some(&0) {
|
||||
// remove trailing zeros
|
||||
payload.to_mut().pop();
|
||||
}
|
||||
state.record_new_submodule(module, &*payload, fn_logger, "ptx")
|
||||
}
|
||||
FatbinFileHeader::HEADER_KIND_ELF => {
|
||||
state.record_new_submodule(module, &*payload, fn_logger, "elf")
|
||||
}
|
||||
_ => {
|
||||
fn_logger.log(log::ErrorEntry::UnexpectedBinaryField {
|
||||
field_name: "FATBIN_FILE_HEADER_KIND",
|
||||
expected: vec![
|
||||
UInt::U16(FatbinFileHeader::HEADER_KIND_PTX),
|
||||
UInt::U16(FatbinFileHeader::HEADER_KIND_ELF),
|
||||
],
|
||||
observed: UInt::U16(file.header.kind),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
use crate::{
|
||||
log::{self, UInt},
|
||||
trace, ErrorEntry, FnCallLog, Settings,
|
||||
};
|
||||
use cuda_types::{
|
||||
cuda::*,
|
||||
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper},
|
||||
};
|
||||
use dark_api::fatbin::{
|
||||
decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule,
|
||||
};
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
ffi::{c_void, CStr, CString},
|
||||
fs::{self, File},
|
||||
io::{self, Read, Write},
|
||||
path::PathBuf,
|
||||
};
|
||||
use unwrap_or::unwrap_some_or;
|
||||
|
||||
// This struct is the heart of CUDA state tracking, it:
|
||||
// * receives calls from the probes about changes to CUDA state
|
||||
// * records updates to the state change
|
||||
// * writes out relevant state change and details to disk and log
|
||||
pub(crate) struct StateTracker {
|
||||
writer: DumpWriter,
|
||||
pub(crate) libraries: FxHashMap<CUlibrary, CodePointer>,
|
||||
saved_modules: FxHashSet<CUmodule>,
|
||||
module_counter: usize,
|
||||
submodule_counter: usize,
|
||||
pub(crate) override_cc: Option<(u32, u32)>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) struct CodePointer(pub *const c_void);
|
||||
|
||||
unsafe impl Send for CodePointer {}
|
||||
unsafe impl Sync for CodePointer {}
|
||||
|
||||
impl StateTracker {
|
||||
pub(crate) fn new(settings: &Settings) -> Self {
|
||||
StateTracker {
|
||||
writer: DumpWriter::new(settings.dump_dir.clone()),
|
||||
libraries: FxHashMap::default(),
|
||||
saved_modules: FxHashSet::default(),
|
||||
module_counter: 0,
|
||||
submodule_counter: 0,
|
||||
override_cc: settings.override_cc,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn record_new_module_file(
|
||||
&mut self,
|
||||
module: CUmodule,
|
||||
file_name: *const i8,
|
||||
fn_logger: &mut FnCallLog,
|
||||
) {
|
||||
let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() {
|
||||
Ok(f) => f,
|
||||
Err(err) => {
|
||||
fn_logger.log(log::ErrorEntry::MalformedModulePath(err));
|
||||
return;
|
||||
}
|
||||
};
|
||||
let maybe_io_error = self.try_record_new_module_file(module, fn_logger, file_name);
|
||||
fn_logger.log_io_error(maybe_io_error)
|
||||
}
|
||||
|
||||
fn try_record_new_module_file(
|
||||
&mut self,
|
||||
module: CUmodule,
|
||||
fn_logger: &mut FnCallLog,
|
||||
file_name: &str,
|
||||
) -> io::Result<()> {
|
||||
let mut module_file = fs::File::open(file_name)?;
|
||||
let mut read_buff = Vec::new();
|
||||
module_file.read_to_end(&mut read_buff)?;
|
||||
self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn record_new_submodule(
|
||||
&mut self,
|
||||
module: CUmodule,
|
||||
submodule: &[u8],
|
||||
fn_logger: &mut FnCallLog,
|
||||
type_: &'static str,
|
||||
) {
|
||||
if self.saved_modules.insert(module) {
|
||||
self.module_counter += 1;
|
||||
self.submodule_counter = 0;
|
||||
}
|
||||
self.submodule_counter += 1;
|
||||
fn_logger.log_io_error(self.writer.save_module(
|
||||
self.module_counter,
|
||||
Some(self.submodule_counter),
|
||||
submodule,
|
||||
type_,
|
||||
));
|
||||
if type_ == "ptx" {
|
||||
match CString::new(submodule) {
|
||||
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
|
||||
Ok(submodule_cstring) => match submodule_cstring.to_str() {
|
||||
Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)),
|
||||
Ok(submodule_text) => self.try_parse_and_record_kernels(
|
||||
fn_logger,
|
||||
self.module_counter,
|
||||
Some(self.submodule_counter),
|
||||
submodule_text,
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn record_new_module(
|
||||
&mut self,
|
||||
module: CUmodule,
|
||||
raw_image: *const c_void,
|
||||
fn_logger: &mut FnCallLog,
|
||||
) {
|
||||
self.module_counter += 1;
|
||||
if unsafe { *(raw_image as *const [u8; 4]) } == *goblin::elf64::header::ELFMAG {
|
||||
self.saved_modules.insert(module);
|
||||
// TODO: Parse ELF and write it to disk
|
||||
fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
||||
module,
|
||||
raw_image,
|
||||
kind: "ELF",
|
||||
})
|
||||
} else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC {
|
||||
self.saved_modules.insert(module);
|
||||
// TODO: Figure out how to get size of archive module and write it to disk
|
||||
fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
||||
module,
|
||||
raw_image,
|
||||
kind: "archive",
|
||||
})
|
||||
} else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC {
|
||||
unsafe {
|
||||
fn_logger.try_(|fn_logger| {
|
||||
trace::record_submodules_from_wrapped_fatbin(
|
||||
module,
|
||||
raw_image as *const FatbincWrapper,
|
||||
fn_logger,
|
||||
self,
|
||||
)
|
||||
});
|
||||
}
|
||||
} else {
|
||||
self.record_module_ptx(module, raw_image, fn_logger)
|
||||
}
|
||||
}
|
||||
|
||||
fn record_module_ptx(
|
||||
&mut self,
|
||||
module: CUmodule,
|
||||
raw_image: *const c_void,
|
||||
fn_logger: &mut FnCallLog,
|
||||
) {
|
||||
self.saved_modules.insert(module);
|
||||
let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str();
|
||||
let module_text = match module_text {
|
||||
Ok(m) => m,
|
||||
Err(utf8_err) => {
|
||||
fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err));
|
||||
return;
|
||||
}
|
||||
};
|
||||
fn_logger.log_io_error(self.writer.save_module(
|
||||
self.module_counter,
|
||||
None,
|
||||
module_text.as_bytes(),
|
||||
"ptx",
|
||||
));
|
||||
self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text);
|
||||
}
|
||||
|
||||
fn try_parse_and_record_kernels(
|
||||
&mut self,
|
||||
fn_logger: &mut FnCallLog,
|
||||
module_index: usize,
|
||||
submodule_index: Option<usize>,
|
||||
module_text: &str,
|
||||
) {
|
||||
let errors = ptx_parser::parse_for_errors(module_text);
|
||||
if !errors.is_empty() {
|
||||
fn_logger.log(log::ErrorEntry::ModuleParsingError(
|
||||
DumpWriter::get_file_name(module_index, submodule_index, "log"),
|
||||
));
|
||||
fn_logger.log_io_error(self.writer.save_module_error_log(
|
||||
module_index,
|
||||
submodule_index,
|
||||
&*errors,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This structs writes out information about CUDA execution to the dump dir
|
||||
struct DumpWriter {
|
||||
dump_dir: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl DumpWriter {
|
||||
fn new(dump_dir: Option<PathBuf>) -> Self {
|
||||
Self { dump_dir }
|
||||
}
|
||||
|
||||
fn save_module(
|
||||
&self,
|
||||
module_index: usize,
|
||||
submodule_index: Option<usize>,
|
||||
buffer: &[u8],
|
||||
kind: &'static str,
|
||||
) -> io::Result<()> {
|
||||
let mut dump_file = match &self.dump_dir {
|
||||
None => return Ok(()),
|
||||
Some(d) => d.clone(),
|
||||
};
|
||||
dump_file.push(Self::get_file_name(module_index, submodule_index, kind));
|
||||
let mut file = File::create(dump_file)?;
|
||||
file.write_all(buffer)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn save_module_error_log<'input>(
|
||||
&self,
|
||||
module_index: usize,
|
||||
submodule_index: Option<usize>,
|
||||
errors: &[ptx_parser::PtxError<'input>],
|
||||
) -> io::Result<()> {
|
||||
let mut log_file = match &self.dump_dir {
|
||||
None => return Ok(()),
|
||||
Some(d) => d.clone(),
|
||||
};
|
||||
log_file.push(Self::get_file_name(module_index, submodule_index, "log"));
|
||||
let mut file = File::create(log_file)?;
|
||||
for error in errors {
|
||||
writeln!(file, "{}", error)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_file_name(module_index: usize, submodule_index: Option<usize>, kind: &str) -> String {
|
||||
match submodule_index {
|
||||
None => {
|
||||
format!("module_{:04}.{:02}", module_index, kind)
|
||||
}
|
||||
Some(submodule_index) => {
|
||||
format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
|
||||
module: CUmodule,
|
||||
fatbinc_wrapper: *const FatbincWrapper,
|
||||
fn_logger: &mut FnCallLog,
|
||||
state: &mut StateTracker,
|
||||
) -> Result<(), ErrorEntry> {
|
||||
let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
|
||||
let mut submodules = fatbin.get_submodules()?;
|
||||
while let Some(current) = submodules.next()? {
|
||||
record_submodules_from_fatbin(module, current, fn_logger, state)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn record_submodules_from_fatbin(
|
||||
module: CUmodule,
|
||||
submodule: FatbinSubmodule,
|
||||
logger: &mut FnCallLog,
|
||||
state: &mut StateTracker,
|
||||
) -> Result<(), ErrorEntry> {
|
||||
record_submodules(module, logger, state, submodule.get_files())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn record_submodules(
|
||||
module: CUmodule,
|
||||
fn_logger: &mut FnCallLog,
|
||||
state: &mut StateTracker,
|
||||
mut files: FatbinFileIterator,
|
||||
) -> Result<(), ErrorEntry> {
|
||||
while let Some(file) = files.next()? {
|
||||
let mut payload = if file
|
||||
.header
|
||||
.flags
|
||||
.contains(FatbinFileHeaderFlags::CompressedLz4)
|
||||
{
|
||||
Cow::Owned(unwrap_some_or!(
|
||||
fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())),
|
||||
continue
|
||||
))
|
||||
} else if file
|
||||
.header
|
||||
.flags
|
||||
.contains(FatbinFileHeaderFlags::CompressedZstd)
|
||||
{
|
||||
Cow::Owned(unwrap_some_or!(
|
||||
fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())),
|
||||
continue
|
||||
))
|
||||
} else {
|
||||
Cow::Borrowed(file.get_payload())
|
||||
};
|
||||
match file.header.kind {
|
||||
FatbinFileHeader::HEADER_KIND_PTX => {
|
||||
while payload.last() == Some(&0) {
|
||||
// remove trailing zeros
|
||||
payload.to_mut().pop();
|
||||
}
|
||||
state.record_new_submodule(module, &*payload, fn_logger, "ptx")
|
||||
}
|
||||
FatbinFileHeader::HEADER_KIND_ELF => {
|
||||
state.record_new_submodule(module, &*payload, fn_logger, "elf")
|
||||
}
|
||||
_ => {
|
||||
fn_logger.log(log::ErrorEntry::UnexpectedBinaryField {
|
||||
field_name: "FATBIN_FILE_HEADER_KIND",
|
||||
expected: vec![
|
||||
UInt::U16(FatbinFileHeader::HEADER_KIND_PTX),
|
||||
UInt::U16(FatbinFileHeader::HEADER_KIND_ELF),
|
||||
],
|
||||
observed: UInt::U16(file.header.kind),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,81 +1,81 @@
|
||||
use std::{
|
||||
env::{self, VarError},
|
||||
fs::{self, DirEntry},
|
||||
io,
|
||||
path::{self, PathBuf},
|
||||
process::Command,
|
||||
};
|
||||
|
||||
fn main() -> Result<(), VarError> {
|
||||
if std::env::var_os("CARGO_CFG_WINDOWS").is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
if env::var("PROFILE")? != "debug" {
|
||||
return Ok(());
|
||||
}
|
||||
let rustc_exe = env::var("RUSTC")?;
|
||||
let out_dir = env::var("OUT_DIR")?;
|
||||
let target = env::var("TARGET")?;
|
||||
let is_msvc = env::var("CARGO_CFG_TARGET_ENV")? == "msvc";
|
||||
let opt_level = env::var("OPT_LEVEL")?;
|
||||
let debug = str::parse::<bool>(env::var("DEBUG")?.as_str()).unwrap();
|
||||
let mut helpers_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
|
||||
helpers_dir.push("tests");
|
||||
helpers_dir.push("helpers");
|
||||
let helpers_dir_as_string = helpers_dir.to_string_lossy();
|
||||
println!("cargo:rerun-if-changed={}", helpers_dir_as_string);
|
||||
for rust_file in fs::read_dir(&helpers_dir).unwrap().filter_map(rust_file) {
|
||||
let full_file_path = format!(
|
||||
"{}{}{}",
|
||||
helpers_dir_as_string,
|
||||
path::MAIN_SEPARATOR,
|
||||
rust_file
|
||||
);
|
||||
let mut rustc_cmd = Command::new(&*rustc_exe);
|
||||
if debug {
|
||||
rustc_cmd.arg("-g");
|
||||
}
|
||||
rustc_cmd.arg(format!("-Lnative={}", helpers_dir_as_string));
|
||||
if !is_msvc {
|
||||
// HACK ALERT
|
||||
// I have no idea why the extra library below have to be linked
|
||||
rustc_cmd.arg(r"-lucrt");
|
||||
}
|
||||
rustc_cmd
|
||||
.arg("-C")
|
||||
.arg(format!("opt-level={}", opt_level))
|
||||
.arg("-L")
|
||||
.arg(format!("{}", out_dir))
|
||||
.arg("--out-dir")
|
||||
.arg(format!("{}", out_dir))
|
||||
.arg("--target")
|
||||
.arg(format!("{}", target))
|
||||
.arg(full_file_path);
|
||||
assert!(rustc_cmd.status().unwrap().success());
|
||||
}
|
||||
std::fs::copy(
|
||||
format!(
|
||||
"{}{}do_cuinit_late_clr.exe",
|
||||
helpers_dir_as_string,
|
||||
path::MAIN_SEPARATOR
|
||||
),
|
||||
format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR),
|
||||
)
|
||||
.unwrap();
|
||||
println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn rust_file(entry: io::Result<DirEntry>) -> Option<String> {
|
||||
entry.ok().and_then(|e| {
|
||||
let os_file_name = e.file_name();
|
||||
let file_name = os_file_name.to_string_lossy();
|
||||
let is_file = e.file_type().ok().map(|t| t.is_file()).unwrap_or(false);
|
||||
if is_file && file_name.ends_with(".rs") {
|
||||
Some(file_name.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
use std::{
|
||||
env::{self, VarError},
|
||||
fs::{self, DirEntry},
|
||||
io,
|
||||
path::{self, PathBuf},
|
||||
process::Command,
|
||||
};
|
||||
|
||||
fn main() -> Result<(), VarError> {
|
||||
if std::env::var_os("CARGO_CFG_WINDOWS").is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
if env::var("PROFILE")? != "debug" {
|
||||
return Ok(());
|
||||
}
|
||||
let rustc_exe = env::var("RUSTC")?;
|
||||
let out_dir = env::var("OUT_DIR")?;
|
||||
let target = env::var("TARGET")?;
|
||||
let is_msvc = env::var("CARGO_CFG_TARGET_ENV")? == "msvc";
|
||||
let opt_level = env::var("OPT_LEVEL")?;
|
||||
let debug = str::parse::<bool>(env::var("DEBUG")?.as_str()).unwrap();
|
||||
let mut helpers_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
|
||||
helpers_dir.push("tests");
|
||||
helpers_dir.push("helpers");
|
||||
let helpers_dir_as_string = helpers_dir.to_string_lossy();
|
||||
println!("cargo:rerun-if-changed={}", helpers_dir_as_string);
|
||||
for rust_file in fs::read_dir(&helpers_dir).unwrap().filter_map(rust_file) {
|
||||
let full_file_path = format!(
|
||||
"{}{}{}",
|
||||
helpers_dir_as_string,
|
||||
path::MAIN_SEPARATOR,
|
||||
rust_file
|
||||
);
|
||||
let mut rustc_cmd = Command::new(&*rustc_exe);
|
||||
if debug {
|
||||
rustc_cmd.arg("-g");
|
||||
}
|
||||
rustc_cmd.arg(format!("-Lnative={}", helpers_dir_as_string));
|
||||
if !is_msvc {
|
||||
// HACK ALERT
|
||||
// I have no idea why the extra library below have to be linked
|
||||
rustc_cmd.arg(r"-lucrt");
|
||||
}
|
||||
rustc_cmd
|
||||
.arg("-C")
|
||||
.arg(format!("opt-level={}", opt_level))
|
||||
.arg("-L")
|
||||
.arg(format!("{}", out_dir))
|
||||
.arg("--out-dir")
|
||||
.arg(format!("{}", out_dir))
|
||||
.arg("--target")
|
||||
.arg(format!("{}", target))
|
||||
.arg(full_file_path);
|
||||
assert!(rustc_cmd.status().unwrap().success());
|
||||
}
|
||||
std::fs::copy(
|
||||
format!(
|
||||
"{}{}do_cuinit_late_clr.exe",
|
||||
helpers_dir_as_string,
|
||||
path::MAIN_SEPARATOR
|
||||
),
|
||||
format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR),
|
||||
)
|
||||
.unwrap();
|
||||
println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn rust_file(entry: io::Result<DirEntry>) -> Option<String> {
|
||||
entry.ok().and_then(|e| {
|
||||
let os_file_name = e.file_name();
|
||||
let file_name = os_file_name.to_string_lossy();
|
||||
let is_file = e.file_type().ok().map(|t| t.is_file()).unwrap_or(false);
|
||||
if is_file && file_name.ends_with(".rs") {
|
||||
Some(file_name.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -1,311 +1,311 @@
|
||||
use std::env;
|
||||
use std::os::windows;
|
||||
use std::os::windows::ffi::OsStrExt;
|
||||
use std::{error::Error, process};
|
||||
use std::{fs, io, ptr};
|
||||
use std::{mem, path::PathBuf};
|
||||
|
||||
use argh::FromArgs;
|
||||
use mem::size_of_val;
|
||||
use tempfile::TempDir;
|
||||
use winapi::um::processenv::SearchPathW;
|
||||
use winapi::um::{
|
||||
jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
|
||||
processthreadsapi::{GetExitCodeProcess, ResumeThread},
|
||||
synchapi::WaitForSingleObject,
|
||||
winbase::CreateJobObjectA,
|
||||
winnt::{
|
||||
JobObjectExtendedLimitInformation, HANDLE, JOBOBJECT_EXTENDED_LIMIT_INFORMATION,
|
||||
JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
|
||||
},
|
||||
};
|
||||
|
||||
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
|
||||
|
||||
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
|
||||
static NVCUDA_DLL: &'static str = "nvcuda.dll";
|
||||
static NVML_DLL: &'static str = "nvml.dll";
|
||||
|
||||
include!("../../zluda_redirect/src/payload_guid.rs");
|
||||
|
||||
#[derive(FromArgs)]
|
||||
/// Launch application with custom CUDA libraries
|
||||
struct ProgramArguments {
|
||||
/// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory
|
||||
#[argh(option)]
|
||||
nvcuda: Option<PathBuf>,
|
||||
|
||||
/// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory
|
||||
#[argh(option)]
|
||||
nvml: Option<PathBuf>,
|
||||
|
||||
/// executable to be injected with custom CUDA libraries
|
||||
#[argh(positional)]
|
||||
exe: String,
|
||||
|
||||
/// arguments to the executable
|
||||
#[argh(positional)]
|
||||
args: Vec<String>,
|
||||
}
|
||||
|
||||
pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
||||
let raw_args = argh::from_env::<ProgramArguments>();
|
||||
let normalized_args = NormalizedArguments::new(raw_args)?;
|
||||
let mut environment = Environment::setup(normalized_args)?;
|
||||
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
|
||||
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
|
||||
let mut dlls_to_inject = [
|
||||
environment.nvml_path_zero_terminated.as_ptr() as *const i8,
|
||||
environment.nvcuda_path_zero_terminated.as_ptr() as _,
|
||||
environment.redirect_path_zero_terminated.as_ptr() as _,
|
||||
];
|
||||
os_call!(
|
||||
detours_sys::DetourCreateProcessWithDllsW(
|
||||
ptr::null(),
|
||||
environment.winapi_command_line_zero_terminated.as_mut_ptr(),
|
||||
ptr::null_mut(),
|
||||
ptr::null_mut(),
|
||||
0,
|
||||
0,
|
||||
ptr::null_mut(),
|
||||
ptr::null(),
|
||||
&mut startup_info as *mut _,
|
||||
&mut proc_info as *mut _,
|
||||
dlls_to_inject.len() as u32,
|
||||
dlls_to_inject.as_mut_ptr(),
|
||||
Option::None
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
kill_child_on_process_exit(proc_info.hProcess)?;
|
||||
os_call!(
|
||||
detours_sys::DetourCopyPayloadToProcess(
|
||||
proc_info.hProcess,
|
||||
&PAYLOAD_NVCUDA_GUID,
|
||||
environment.nvcuda_path_zero_terminated.as_ptr() as *mut _,
|
||||
environment.nvcuda_path_zero_terminated.len() as u32
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
os_call!(
|
||||
detours_sys::DetourCopyPayloadToProcess(
|
||||
proc_info.hProcess,
|
||||
&PAYLOAD_NVML_GUID,
|
||||
environment.nvml_path_zero_terminated.as_ptr() as *mut _,
|
||||
environment.nvml_path_zero_terminated.len() as u32
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1);
|
||||
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x
|
||||
!= WAIT_FAILED);
|
||||
let mut child_exit_code: u32 = 0;
|
||||
os_call!(
|
||||
GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _),
|
||||
|x| x != 0
|
||||
);
|
||||
process::exit(child_exit_code as i32)
|
||||
}
|
||||
|
||||
struct NormalizedArguments {
|
||||
nvml_path: PathBuf,
|
||||
nvcuda_path: PathBuf,
|
||||
redirect_path: PathBuf,
|
||||
winapi_command_line_zero_terminated: Vec<u16>,
|
||||
}
|
||||
|
||||
impl NormalizedArguments {
|
||||
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
|
||||
let current_exe = env::current_exe()?;
|
||||
let nvml_path = Self::get_absolute_path(¤t_exe, prog_args.nvml, NVML_DLL)?;
|
||||
let nvcuda_path = Self::get_absolute_path(¤t_exe, prog_args.nvcuda, NVCUDA_DLL)?;
|
||||
let winapi_command_line_zero_terminated =
|
||||
construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args));
|
||||
let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
|
||||
redirect_path.push(REDIRECT_DLL);
|
||||
Ok(Self {
|
||||
nvml_path,
|
||||
nvcuda_path,
|
||||
redirect_path,
|
||||
winapi_command_line_zero_terminated,
|
||||
})
|
||||
}
|
||||
|
||||
const WIN_MAX_PATH: usize = 260;
|
||||
|
||||
fn get_absolute_path(
|
||||
current_exe: &PathBuf,
|
||||
dll: Option<PathBuf>,
|
||||
default: &str,
|
||||
) -> Result<PathBuf, Box<dyn Error>> {
|
||||
Ok(if let Some(dll) = dll {
|
||||
if dll.is_absolute() {
|
||||
dll
|
||||
} else {
|
||||
let mut full_dll_path = vec![0; Self::WIN_MAX_PATH];
|
||||
let mut dll_utf16 = dll.as_os_str().encode_wide().collect::<Vec<_>>();
|
||||
dll_utf16.push(0);
|
||||
loop {
|
||||
let copied_len = os_call!(
|
||||
SearchPathW(
|
||||
ptr::null_mut(),
|
||||
dll_utf16.as_ptr(),
|
||||
ptr::null(),
|
||||
full_dll_path.len() as u32,
|
||||
full_dll_path.as_mut_ptr(),
|
||||
ptr::null_mut()
|
||||
),
|
||||
|x| x != 0
|
||||
) as usize;
|
||||
if copied_len > full_dll_path.len() {
|
||||
full_dll_path.resize(copied_len + 1, 0);
|
||||
} else {
|
||||
full_dll_path.truncate(copied_len);
|
||||
break;
|
||||
}
|
||||
}
|
||||
PathBuf::from(String::from_utf16_lossy(&full_dll_path))
|
||||
}
|
||||
} else {
|
||||
let mut dll_path = current_exe.parent().unwrap().to_path_buf();
|
||||
dll_path.push(default);
|
||||
dll_path
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct Environment {
|
||||
nvml_path_zero_terminated: String,
|
||||
nvcuda_path_zero_terminated: String,
|
||||
redirect_path_zero_terminated: String,
|
||||
winapi_command_line_zero_terminated: Vec<u16>,
|
||||
_temp_dir: TempDir,
|
||||
}
|
||||
|
||||
// This structs represents "enviroment". By environment we mean all paths
|
||||
// (nvcuda.dll, nvml.dll, etc.) and all related resources like the temporary
|
||||
// directory which contains nvcuda.dll
|
||||
impl Environment {
|
||||
fn setup(args: NormalizedArguments) -> io::Result<Self> {
|
||||
let _temp_dir = TempDir::new()?;
|
||||
let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||
args.nvml_path,
|
||||
&_temp_dir,
|
||||
NVML_DLL,
|
||||
)?);
|
||||
let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||
args.nvcuda_path,
|
||||
&_temp_dir,
|
||||
NVCUDA_DLL,
|
||||
)?);
|
||||
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
|
||||
Ok(Self {
|
||||
nvml_path_zero_terminated,
|
||||
nvcuda_path_zero_terminated,
|
||||
redirect_path_zero_terminated,
|
||||
winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated,
|
||||
_temp_dir,
|
||||
})
|
||||
}
|
||||
|
||||
fn copy_to_correct_name(
|
||||
path_buf: PathBuf,
|
||||
temp_dir: &TempDir,
|
||||
correct_name: &str,
|
||||
) -> io::Result<PathBuf> {
|
||||
let file_name = path_buf.file_name().unwrap();
|
||||
if file_name == correct_name {
|
||||
Ok(path_buf)
|
||||
} else {
|
||||
let mut temp_file_path = temp_dir.path().to_path_buf();
|
||||
temp_file_path.push(correct_name);
|
||||
match windows::fs::symlink_file(&path_buf, &temp_file_path) {
|
||||
Ok(()) => {}
|
||||
Err(_) => {
|
||||
fs::copy(&path_buf, &temp_file_path)?;
|
||||
}
|
||||
}
|
||||
Ok(temp_file_path)
|
||||
}
|
||||
}
|
||||
|
||||
fn zero_terminate(p: PathBuf) -> String {
|
||||
let mut s = p.to_string_lossy().to_string();
|
||||
s.push('\0');
|
||||
s
|
||||
}
|
||||
}
|
||||
|
||||
fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
|
||||
let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
|
||||
!= ptr::null_mut());
|
||||
let mut info = unsafe { mem::zeroed::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() };
|
||||
info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
|
||||
os_call!(
|
||||
SetInformationJobObject(
|
||||
job_handle,
|
||||
JobObjectExtendedLimitInformation,
|
||||
&mut info as *mut _ as *mut _,
|
||||
size_of_val(&info) as u32
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
os_call!(AssignProcessToJobObject(job_handle, child), |x| x != 0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
|
||||
fn construct_command_line(args: impl Iterator<Item = String>) -> Vec<u16> {
|
||||
let mut cmd_line = Vec::new();
|
||||
let args_len = args.size_hint().0;
|
||||
for (idx, arg) in args.enumerate() {
|
||||
if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) {
|
||||
cmd_line.extend(arg.encode_utf16());
|
||||
} else {
|
||||
cmd_line.push('"' as u16); // "
|
||||
let mut char_iter = arg.chars().peekable();
|
||||
loop {
|
||||
let mut current = char_iter.next();
|
||||
let mut backslashes = 0;
|
||||
match current {
|
||||
Some('\\') => {
|
||||
backslashes = 1;
|
||||
while let Some('\\') = char_iter.peek() {
|
||||
backslashes += 1;
|
||||
char_iter.next();
|
||||
}
|
||||
current = char_iter.next();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
match current {
|
||||
None => {
|
||||
for _ in 0..(backslashes * 2) {
|
||||
cmd_line.push('\\' as u16);
|
||||
}
|
||||
break;
|
||||
}
|
||||
Some('"') => {
|
||||
for _ in 0..(backslashes * 2 + 1) {
|
||||
cmd_line.push('\\' as u16);
|
||||
}
|
||||
cmd_line.push('"' as u16);
|
||||
}
|
||||
Some(c) => {
|
||||
for _ in 0..backslashes {
|
||||
cmd_line.push('\\' as u16);
|
||||
}
|
||||
let mut temp = [0u16; 2];
|
||||
cmd_line.extend(&*c.encode_utf16(&mut temp));
|
||||
}
|
||||
}
|
||||
}
|
||||
cmd_line.push('"' as u16);
|
||||
}
|
||||
if idx < args_len - 1 {
|
||||
cmd_line.push(' ' as u16);
|
||||
}
|
||||
}
|
||||
cmd_line.push(0);
|
||||
cmd_line
|
||||
}
|
||||
use std::env;
|
||||
use std::os::windows;
|
||||
use std::os::windows::ffi::OsStrExt;
|
||||
use std::{error::Error, process};
|
||||
use std::{fs, io, ptr};
|
||||
use std::{mem, path::PathBuf};
|
||||
|
||||
use argh::FromArgs;
|
||||
use mem::size_of_val;
|
||||
use tempfile::TempDir;
|
||||
use winapi::um::processenv::SearchPathW;
|
||||
use winapi::um::{
|
||||
jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
|
||||
processthreadsapi::{GetExitCodeProcess, ResumeThread},
|
||||
synchapi::WaitForSingleObject,
|
||||
winbase::CreateJobObjectA,
|
||||
winnt::{
|
||||
JobObjectExtendedLimitInformation, HANDLE, JOBOBJECT_EXTENDED_LIMIT_INFORMATION,
|
||||
JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
|
||||
},
|
||||
};
|
||||
|
||||
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
|
||||
|
||||
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
|
||||
static NVCUDA_DLL: &'static str = "nvcuda.dll";
|
||||
static NVML_DLL: &'static str = "nvml.dll";
|
||||
|
||||
include!("../../zluda_redirect/src/payload_guid.rs");
|
||||
|
||||
#[derive(FromArgs)]
|
||||
/// Launch application with custom CUDA libraries
|
||||
struct ProgramArguments {
|
||||
/// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory
|
||||
#[argh(option)]
|
||||
nvcuda: Option<PathBuf>,
|
||||
|
||||
/// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory
|
||||
#[argh(option)]
|
||||
nvml: Option<PathBuf>,
|
||||
|
||||
/// executable to be injected with custom CUDA libraries
|
||||
#[argh(positional)]
|
||||
exe: String,
|
||||
|
||||
/// arguments to the executable
|
||||
#[argh(positional)]
|
||||
args: Vec<String>,
|
||||
}
|
||||
|
||||
pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
||||
let raw_args = argh::from_env::<ProgramArguments>();
|
||||
let normalized_args = NormalizedArguments::new(raw_args)?;
|
||||
let mut environment = Environment::setup(normalized_args)?;
|
||||
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
|
||||
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
|
||||
let mut dlls_to_inject = [
|
||||
environment.nvml_path_zero_terminated.as_ptr() as *const i8,
|
||||
environment.nvcuda_path_zero_terminated.as_ptr() as _,
|
||||
environment.redirect_path_zero_terminated.as_ptr() as _,
|
||||
];
|
||||
os_call!(
|
||||
detours_sys::DetourCreateProcessWithDllsW(
|
||||
ptr::null(),
|
||||
environment.winapi_command_line_zero_terminated.as_mut_ptr(),
|
||||
ptr::null_mut(),
|
||||
ptr::null_mut(),
|
||||
0,
|
||||
0,
|
||||
ptr::null_mut(),
|
||||
ptr::null(),
|
||||
&mut startup_info as *mut _,
|
||||
&mut proc_info as *mut _,
|
||||
dlls_to_inject.len() as u32,
|
||||
dlls_to_inject.as_mut_ptr(),
|
||||
Option::None
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
kill_child_on_process_exit(proc_info.hProcess)?;
|
||||
os_call!(
|
||||
detours_sys::DetourCopyPayloadToProcess(
|
||||
proc_info.hProcess,
|
||||
&PAYLOAD_NVCUDA_GUID,
|
||||
environment.nvcuda_path_zero_terminated.as_ptr() as *mut _,
|
||||
environment.nvcuda_path_zero_terminated.len() as u32
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
os_call!(
|
||||
detours_sys::DetourCopyPayloadToProcess(
|
||||
proc_info.hProcess,
|
||||
&PAYLOAD_NVML_GUID,
|
||||
environment.nvml_path_zero_terminated.as_ptr() as *mut _,
|
||||
environment.nvml_path_zero_terminated.len() as u32
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1);
|
||||
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x
|
||||
!= WAIT_FAILED);
|
||||
let mut child_exit_code: u32 = 0;
|
||||
os_call!(
|
||||
GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _),
|
||||
|x| x != 0
|
||||
);
|
||||
process::exit(child_exit_code as i32)
|
||||
}
|
||||
|
||||
struct NormalizedArguments {
|
||||
nvml_path: PathBuf,
|
||||
nvcuda_path: PathBuf,
|
||||
redirect_path: PathBuf,
|
||||
winapi_command_line_zero_terminated: Vec<u16>,
|
||||
}
|
||||
|
||||
impl NormalizedArguments {
|
||||
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
|
||||
let current_exe = env::current_exe()?;
|
||||
let nvml_path = Self::get_absolute_path(¤t_exe, prog_args.nvml, NVML_DLL)?;
|
||||
let nvcuda_path = Self::get_absolute_path(¤t_exe, prog_args.nvcuda, NVCUDA_DLL)?;
|
||||
let winapi_command_line_zero_terminated =
|
||||
construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args));
|
||||
let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
|
||||
redirect_path.push(REDIRECT_DLL);
|
||||
Ok(Self {
|
||||
nvml_path,
|
||||
nvcuda_path,
|
||||
redirect_path,
|
||||
winapi_command_line_zero_terminated,
|
||||
})
|
||||
}
|
||||
|
||||
const WIN_MAX_PATH: usize = 260;
|
||||
|
||||
fn get_absolute_path(
|
||||
current_exe: &PathBuf,
|
||||
dll: Option<PathBuf>,
|
||||
default: &str,
|
||||
) -> Result<PathBuf, Box<dyn Error>> {
|
||||
Ok(if let Some(dll) = dll {
|
||||
if dll.is_absolute() {
|
||||
dll
|
||||
} else {
|
||||
let mut full_dll_path = vec![0; Self::WIN_MAX_PATH];
|
||||
let mut dll_utf16 = dll.as_os_str().encode_wide().collect::<Vec<_>>();
|
||||
dll_utf16.push(0);
|
||||
loop {
|
||||
let copied_len = os_call!(
|
||||
SearchPathW(
|
||||
ptr::null_mut(),
|
||||
dll_utf16.as_ptr(),
|
||||
ptr::null(),
|
||||
full_dll_path.len() as u32,
|
||||
full_dll_path.as_mut_ptr(),
|
||||
ptr::null_mut()
|
||||
),
|
||||
|x| x != 0
|
||||
) as usize;
|
||||
if copied_len > full_dll_path.len() {
|
||||
full_dll_path.resize(copied_len + 1, 0);
|
||||
} else {
|
||||
full_dll_path.truncate(copied_len);
|
||||
break;
|
||||
}
|
||||
}
|
||||
PathBuf::from(String::from_utf16_lossy(&full_dll_path))
|
||||
}
|
||||
} else {
|
||||
let mut dll_path = current_exe.parent().unwrap().to_path_buf();
|
||||
dll_path.push(default);
|
||||
dll_path
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct Environment {
|
||||
nvml_path_zero_terminated: String,
|
||||
nvcuda_path_zero_terminated: String,
|
||||
redirect_path_zero_terminated: String,
|
||||
winapi_command_line_zero_terminated: Vec<u16>,
|
||||
_temp_dir: TempDir,
|
||||
}
|
||||
|
||||
// This structs represents "enviroment". By environment we mean all paths
|
||||
// (nvcuda.dll, nvml.dll, etc.) and all related resources like the temporary
|
||||
// directory which contains nvcuda.dll
|
||||
impl Environment {
|
||||
fn setup(args: NormalizedArguments) -> io::Result<Self> {
|
||||
let _temp_dir = TempDir::new()?;
|
||||
let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||
args.nvml_path,
|
||||
&_temp_dir,
|
||||
NVML_DLL,
|
||||
)?);
|
||||
let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||
args.nvcuda_path,
|
||||
&_temp_dir,
|
||||
NVCUDA_DLL,
|
||||
)?);
|
||||
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
|
||||
Ok(Self {
|
||||
nvml_path_zero_terminated,
|
||||
nvcuda_path_zero_terminated,
|
||||
redirect_path_zero_terminated,
|
||||
winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated,
|
||||
_temp_dir,
|
||||
})
|
||||
}
|
||||
|
||||
fn copy_to_correct_name(
|
||||
path_buf: PathBuf,
|
||||
temp_dir: &TempDir,
|
||||
correct_name: &str,
|
||||
) -> io::Result<PathBuf> {
|
||||
let file_name = path_buf.file_name().unwrap();
|
||||
if file_name == correct_name {
|
||||
Ok(path_buf)
|
||||
} else {
|
||||
let mut temp_file_path = temp_dir.path().to_path_buf();
|
||||
temp_file_path.push(correct_name);
|
||||
match windows::fs::symlink_file(&path_buf, &temp_file_path) {
|
||||
Ok(()) => {}
|
||||
Err(_) => {
|
||||
fs::copy(&path_buf, &temp_file_path)?;
|
||||
}
|
||||
}
|
||||
Ok(temp_file_path)
|
||||
}
|
||||
}
|
||||
|
||||
fn zero_terminate(p: PathBuf) -> String {
|
||||
let mut s = p.to_string_lossy().to_string();
|
||||
s.push('\0');
|
||||
s
|
||||
}
|
||||
}
|
||||
|
||||
fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
|
||||
let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
|
||||
!= ptr::null_mut());
|
||||
let mut info = unsafe { mem::zeroed::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() };
|
||||
info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
|
||||
os_call!(
|
||||
SetInformationJobObject(
|
||||
job_handle,
|
||||
JobObjectExtendedLimitInformation,
|
||||
&mut info as *mut _ as *mut _,
|
||||
size_of_val(&info) as u32
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
os_call!(AssignProcessToJobObject(job_handle, child), |x| x != 0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
|
||||
fn construct_command_line(args: impl Iterator<Item = String>) -> Vec<u16> {
|
||||
let mut cmd_line = Vec::new();
|
||||
let args_len = args.size_hint().0;
|
||||
for (idx, arg) in args.enumerate() {
|
||||
if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) {
|
||||
cmd_line.extend(arg.encode_utf16());
|
||||
} else {
|
||||
cmd_line.push('"' as u16); // "
|
||||
let mut char_iter = arg.chars().peekable();
|
||||
loop {
|
||||
let mut current = char_iter.next();
|
||||
let mut backslashes = 0;
|
||||
match current {
|
||||
Some('\\') => {
|
||||
backslashes = 1;
|
||||
while let Some('\\') = char_iter.peek() {
|
||||
backslashes += 1;
|
||||
char_iter.next();
|
||||
}
|
||||
current = char_iter.next();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
match current {
|
||||
None => {
|
||||
for _ in 0..(backslashes * 2) {
|
||||
cmd_line.push('\\' as u16);
|
||||
}
|
||||
break;
|
||||
}
|
||||
Some('"') => {
|
||||
for _ in 0..(backslashes * 2 + 1) {
|
||||
cmd_line.push('\\' as u16);
|
||||
}
|
||||
cmd_line.push('"' as u16);
|
||||
}
|
||||
Some(c) => {
|
||||
for _ in 0..backslashes {
|
||||
cmd_line.push('\\' as u16);
|
||||
}
|
||||
let mut temp = [0u16; 2];
|
||||
cmd_line.extend(&*c.encode_utf16(&mut temp));
|
||||
}
|
||||
}
|
||||
}
|
||||
cmd_line.push('"' as u16);
|
||||
}
|
||||
if idx < args_len - 1 {
|
||||
cmd_line.push(' ' as u16);
|
||||
}
|
||||
}
|
||||
cmd_line.push(0);
|
||||
cmd_line
|
||||
}
|
||||
|
@ -1,13 +1,13 @@
|
||||
#[macro_use]
|
||||
#[cfg(target_os = "windows")]
|
||||
mod win;
|
||||
#[cfg(target_os = "windows")]
|
||||
mod bin;
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
bin::main_impl()
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn main() {}
|
||||
#[macro_use]
|
||||
#[cfg(target_os = "windows")]
|
||||
mod win;
|
||||
#[cfg(target_os = "windows")]
|
||||
mod bin;
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
bin::main_impl()
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn main() {}
|
||||
|
@ -1,151 +1,151 @@
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use std::error;
|
||||
use std::fmt;
|
||||
use std::ptr;
|
||||
|
||||
mod c {
|
||||
use std::ffi::c_void;
|
||||
use std::os::raw::c_ulong;
|
||||
|
||||
pub type DWORD = c_ulong;
|
||||
pub type HANDLE = LPVOID;
|
||||
pub type LPVOID = *mut c_void;
|
||||
pub type HINSTANCE = HANDLE;
|
||||
pub type HMODULE = HINSTANCE;
|
||||
pub type WCHAR = u16;
|
||||
pub type LPCWSTR = *const WCHAR;
|
||||
pub type LPWSTR = *mut WCHAR;
|
||||
|
||||
pub const FACILITY_NT_BIT: DWORD = 0x1000_0000;
|
||||
pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x00000800;
|
||||
pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x00001000;
|
||||
pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x00000200;
|
||||
|
||||
extern "system" {
|
||||
pub fn GetLastError() -> DWORD;
|
||||
pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE;
|
||||
pub fn FormatMessageW(
|
||||
flags: DWORD,
|
||||
lpSrc: LPVOID,
|
||||
msgId: DWORD,
|
||||
langId: DWORD,
|
||||
buf: LPWSTR,
|
||||
nsize: DWORD,
|
||||
args: *const c_void,
|
||||
) -> DWORD;
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! last_ident {
|
||||
($i:ident) => {
|
||||
stringify!($i)
|
||||
};
|
||||
($start:ident, $($cont:ident),+) => {
|
||||
last_ident!($($cont),+)
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! os_call {
|
||||
($($path:ident)::+ ($($args:expr),*), $success:expr) => {
|
||||
{
|
||||
let result = unsafe{ $($path)::+ ($($args),*) };
|
||||
if !($success)(result) {
|
||||
let name = last_ident!($($path),+);
|
||||
let err_code = $crate::win::errno();
|
||||
Err($crate::win::OsError{
|
||||
function: name,
|
||||
error_code: err_code as u32,
|
||||
message: $crate::win::error_string(err_code)
|
||||
})?;
|
||||
}
|
||||
result
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OsError {
|
||||
pub function: &'static str,
|
||||
pub error_code: u32,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl fmt::Display for OsError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl error::Error for OsError {
|
||||
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn errno() -> i32 {
|
||||
unsafe { c::GetLastError() as i32 }
|
||||
}
|
||||
|
||||
/// Gets a detailed string description for the given error number.
|
||||
pub fn error_string(mut errnum: i32) -> String {
|
||||
// This value is calculated from the macro
|
||||
// MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT)
|
||||
let langId = 0x0800 as c::DWORD;
|
||||
|
||||
let mut buf = [0 as c::WCHAR; 2048];
|
||||
|
||||
unsafe {
|
||||
let mut module = ptr::null_mut();
|
||||
let mut flags = 0;
|
||||
|
||||
// NTSTATUS errors may be encoded as HRESULT, which may returned from
|
||||
// GetLastError. For more information about Windows error codes, see
|
||||
// `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx
|
||||
if (errnum & c::FACILITY_NT_BIT as i32) != 0 {
|
||||
// format according to https://support.microsoft.com/en-us/help/259693
|
||||
const NTDLL_DLL: &[u16] = &[
|
||||
'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _,
|
||||
'L' as _, 0,
|
||||
];
|
||||
module = c::GetModuleHandleW(NTDLL_DLL.as_ptr());
|
||||
|
||||
if module != ptr::null_mut() {
|
||||
errnum ^= c::FACILITY_NT_BIT as i32;
|
||||
flags = c::FORMAT_MESSAGE_FROM_HMODULE;
|
||||
}
|
||||
}
|
||||
|
||||
let res = c::FormatMessageW(
|
||||
flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS,
|
||||
module,
|
||||
errnum as c::DWORD,
|
||||
langId,
|
||||
buf.as_mut_ptr(),
|
||||
buf.len() as c::DWORD,
|
||||
ptr::null(),
|
||||
) as usize;
|
||||
if res == 0 {
|
||||
// Sometimes FormatMessageW can fail e.g., system doesn't like langId,
|
||||
let fm_err = errno();
|
||||
return format!(
|
||||
"OS Error {} (FormatMessageW() returned error {})",
|
||||
errnum, fm_err
|
||||
);
|
||||
}
|
||||
|
||||
match String::from_utf16(&buf[..res]) {
|
||||
Ok(mut msg) => {
|
||||
// Trim trailing CRLF inserted by FormatMessageW
|
||||
let len = msg.trim_end().len();
|
||||
msg.truncate(len);
|
||||
msg
|
||||
}
|
||||
Err(..) => format!(
|
||||
"OS Error {} (FormatMessageW() returned \
|
||||
invalid UTF-16)",
|
||||
errnum
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use std::error;
|
||||
use std::fmt;
|
||||
use std::ptr;
|
||||
|
||||
mod c {
|
||||
use std::ffi::c_void;
|
||||
use std::os::raw::c_ulong;
|
||||
|
||||
pub type DWORD = c_ulong;
|
||||
pub type HANDLE = LPVOID;
|
||||
pub type LPVOID = *mut c_void;
|
||||
pub type HINSTANCE = HANDLE;
|
||||
pub type HMODULE = HINSTANCE;
|
||||
pub type WCHAR = u16;
|
||||
pub type LPCWSTR = *const WCHAR;
|
||||
pub type LPWSTR = *mut WCHAR;
|
||||
|
||||
pub const FACILITY_NT_BIT: DWORD = 0x1000_0000;
|
||||
pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x00000800;
|
||||
pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x00001000;
|
||||
pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x00000200;
|
||||
|
||||
extern "system" {
|
||||
pub fn GetLastError() -> DWORD;
|
||||
pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE;
|
||||
pub fn FormatMessageW(
|
||||
flags: DWORD,
|
||||
lpSrc: LPVOID,
|
||||
msgId: DWORD,
|
||||
langId: DWORD,
|
||||
buf: LPWSTR,
|
||||
nsize: DWORD,
|
||||
args: *const c_void,
|
||||
) -> DWORD;
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! last_ident {
|
||||
($i:ident) => {
|
||||
stringify!($i)
|
||||
};
|
||||
($start:ident, $($cont:ident),+) => {
|
||||
last_ident!($($cont),+)
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! os_call {
|
||||
($($path:ident)::+ ($($args:expr),*), $success:expr) => {
|
||||
{
|
||||
let result = unsafe{ $($path)::+ ($($args),*) };
|
||||
if !($success)(result) {
|
||||
let name = last_ident!($($path),+);
|
||||
let err_code = $crate::win::errno();
|
||||
Err($crate::win::OsError{
|
||||
function: name,
|
||||
error_code: err_code as u32,
|
||||
message: $crate::win::error_string(err_code)
|
||||
})?;
|
||||
}
|
||||
result
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OsError {
|
||||
pub function: &'static str,
|
||||
pub error_code: u32,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl fmt::Display for OsError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl error::Error for OsError {
|
||||
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn errno() -> i32 {
|
||||
unsafe { c::GetLastError() as i32 }
|
||||
}
|
||||
|
||||
/// Gets a detailed string description for the given error number.
|
||||
pub fn error_string(mut errnum: i32) -> String {
|
||||
// This value is calculated from the macro
|
||||
// MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT)
|
||||
let langId = 0x0800 as c::DWORD;
|
||||
|
||||
let mut buf = [0 as c::WCHAR; 2048];
|
||||
|
||||
unsafe {
|
||||
let mut module = ptr::null_mut();
|
||||
let mut flags = 0;
|
||||
|
||||
// NTSTATUS errors may be encoded as HRESULT, which may returned from
|
||||
// GetLastError. For more information about Windows error codes, see
|
||||
// `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx
|
||||
if (errnum & c::FACILITY_NT_BIT as i32) != 0 {
|
||||
// format according to https://support.microsoft.com/en-us/help/259693
|
||||
const NTDLL_DLL: &[u16] = &[
|
||||
'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _,
|
||||
'L' as _, 0,
|
||||
];
|
||||
module = c::GetModuleHandleW(NTDLL_DLL.as_ptr());
|
||||
|
||||
if module != ptr::null_mut() {
|
||||
errnum ^= c::FACILITY_NT_BIT as i32;
|
||||
flags = c::FORMAT_MESSAGE_FROM_HMODULE;
|
||||
}
|
||||
}
|
||||
|
||||
let res = c::FormatMessageW(
|
||||
flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS,
|
||||
module,
|
||||
errnum as c::DWORD,
|
||||
langId,
|
||||
buf.as_mut_ptr(),
|
||||
buf.len() as c::DWORD,
|
||||
ptr::null(),
|
||||
) as usize;
|
||||
if res == 0 {
|
||||
// Sometimes FormatMessageW can fail e.g., system doesn't like langId,
|
||||
let fm_err = errno();
|
||||
return format!(
|
||||
"OS Error {} (FormatMessageW() returned error {})",
|
||||
errnum, fm_err
|
||||
);
|
||||
}
|
||||
|
||||
match String::from_utf16(&buf[..res]) {
|
||||
Ok(mut msg) => {
|
||||
// Trim trailing CRLF inserted by FormatMessageW
|
||||
let len = msg.trim_end().len();
|
||||
msg.truncate(len);
|
||||
msg
|
||||
}
|
||||
Err(..) => format!(
|
||||
"OS Error {} (FormatMessageW() returned \
|
||||
invalid UTF-16)",
|
||||
errnum
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,51 +1,51 @@
|
||||
#![cfg(windows)]
|
||||
use std::{env, io, path::PathBuf, process::Command};
|
||||
|
||||
#[test]
|
||||
fn direct_cuinit() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("direct_cuinit")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn do_cuinit_early() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("do_cuinit_early")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn do_cuinit_late() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("do_cuinit_late")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn do_cuinit_late_clr() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("do_cuinit_late_clr")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn indirect_cuinit() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("indirect_cuinit")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subprocess() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("subprocess")
|
||||
}
|
||||
|
||||
fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
|
||||
let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with"));
|
||||
let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf();
|
||||
zluda_dump_dll.push("zluda_dump.dll");
|
||||
let helpers_dir = env!("HELPERS_OUT_DIR");
|
||||
let exe_under_test = format!("{}{}{}.exe", helpers_dir, std::path::MAIN_SEPARATOR, name);
|
||||
let mut test_cmd = Command::new(&zluda_with_exe);
|
||||
let test_cmd = test_cmd
|
||||
.arg("--nvcuda")
|
||||
.arg(&zluda_dump_dll)
|
||||
.arg("--")
|
||||
.arg(&exe_under_test);
|
||||
let test_output = test_cmd.output()?;
|
||||
assert!(test_output.status.success());
|
||||
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
|
||||
assert!(stderr_text.contains("ZLUDA_DUMP"));
|
||||
Ok(())
|
||||
}
|
||||
#![cfg(windows)]
|
||||
use std::{env, io, path::PathBuf, process::Command};
|
||||
|
||||
#[test]
|
||||
fn direct_cuinit() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("direct_cuinit")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn do_cuinit_early() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("do_cuinit_early")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn do_cuinit_late() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("do_cuinit_late")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn do_cuinit_late_clr() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("do_cuinit_late_clr")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn indirect_cuinit() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("indirect_cuinit")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subprocess() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("subprocess")
|
||||
}
|
||||
|
||||
fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
|
||||
let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with"));
|
||||
let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf();
|
||||
zluda_dump_dll.push("zluda_dump.dll");
|
||||
let helpers_dir = env!("HELPERS_OUT_DIR");
|
||||
let exe_under_test = format!("{}{}{}.exe", helpers_dir, std::path::MAIN_SEPARATOR, name);
|
||||
let mut test_cmd = Command::new(&zluda_with_exe);
|
||||
let test_cmd = test_cmd
|
||||
.arg("--nvcuda")
|
||||
.arg(&zluda_dump_dll)
|
||||
.arg("--")
|
||||
.arg(&exe_under_test);
|
||||
let test_output = test_cmd.output()?;
|
||||
assert!(test_output.status.success());
|
||||
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
|
||||
assert!(stderr_text.contains("ZLUDA_DUMP"));
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user