mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 14:57:43 +03:00
Always use Unix line endings (#453)
This commit is contained in:
1
.rustfmt.toml
Normal file
1
.rustfmt.toml
Normal file
@ -0,0 +1 @@
|
|||||||
|
newline_style = "Unix"
|
@ -1,147 +1,147 @@
|
|||||||
use cmake::Config;
|
use cmake::Config;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::process::Command;
|
use std::process::Command;
|
||||||
|
|
||||||
/// LLVM libraries this build script builds and links.
///
/// Each entry is used verbatim as a CMake build target. With the leading
/// `LLVM` removed and the rest lowercased it is also passed to `llvm-config`
/// as a component name. The last two entries are only present when
/// `debug_assertions` is enabled.
const COMPONENTS: &[&str] = &[
    "LLVMCore",
    "LLVMBitWriter",
    #[cfg(debug_assertions)]
    "LLVMAnalysis", // for module verify
    #[cfg(debug_assertions)]
    "LLVMBitReader",
];
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let mut cmake = Config::new(r"../ext/llvm-project/llvm");
|
let mut cmake = Config::new(r"../ext/llvm-project/llvm");
|
||||||
try_use_sccache(&mut cmake);
|
try_use_sccache(&mut cmake);
|
||||||
try_use_ninja(&mut cmake);
|
try_use_ninja(&mut cmake);
|
||||||
cmake
|
cmake
|
||||||
// It's not like we can do anything about the warnings
|
// It's not like we can do anything about the warnings
|
||||||
.define("LLVM_ENABLE_WARNINGS", "OFF")
|
.define("LLVM_ENABLE_WARNINGS", "OFF")
|
||||||
// For some reason Rust always links to release CRT
|
// For some reason Rust always links to release CRT
|
||||||
.define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreaded")
|
.define("CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreaded")
|
||||||
.define("LLVM_ENABLE_TERMINFO", "OFF")
|
.define("LLVM_ENABLE_TERMINFO", "OFF")
|
||||||
.define("LLVM_ENABLE_LIBXML2", "OFF")
|
.define("LLVM_ENABLE_LIBXML2", "OFF")
|
||||||
.define("LLVM_ENABLE_LIBEDIT", "OFF")
|
.define("LLVM_ENABLE_LIBEDIT", "OFF")
|
||||||
.define("LLVM_ENABLE_LIBPFM", "OFF")
|
.define("LLVM_ENABLE_LIBPFM", "OFF")
|
||||||
.define("LLVM_ENABLE_ZLIB", "OFF")
|
.define("LLVM_ENABLE_ZLIB", "OFF")
|
||||||
.define("LLVM_ENABLE_ZSTD", "OFF")
|
.define("LLVM_ENABLE_ZSTD", "OFF")
|
||||||
.define("LLVM_INCLUDE_BENCHMARKS", "OFF")
|
.define("LLVM_INCLUDE_BENCHMARKS", "OFF")
|
||||||
.define("LLVM_INCLUDE_EXAMPLES", "OFF")
|
.define("LLVM_INCLUDE_EXAMPLES", "OFF")
|
||||||
.define("LLVM_INCLUDE_TESTS", "OFF")
|
.define("LLVM_INCLUDE_TESTS", "OFF")
|
||||||
.define("LLVM_BUILD_TOOLS", "OFF")
|
.define("LLVM_BUILD_TOOLS", "OFF")
|
||||||
.define("LLVM_TARGETS_TO_BUILD", "")
|
.define("LLVM_TARGETS_TO_BUILD", "")
|
||||||
.define("LLVM_ENABLE_PROJECTS", "");
|
.define("LLVM_ENABLE_PROJECTS", "");
|
||||||
cmake.build_target("llvm-config");
|
cmake.build_target("llvm-config");
|
||||||
let llvm_dir = cmake.build();
|
let llvm_dir = cmake.build();
|
||||||
for c in COMPONENTS {
|
for c in COMPONENTS {
|
||||||
cmake.build_target(c);
|
cmake.build_target(c);
|
||||||
cmake.build();
|
cmake.build();
|
||||||
}
|
}
|
||||||
let cmake_profile = cmake.get_profile();
|
let cmake_profile = cmake.get_profile();
|
||||||
let (cxxflags, ldflags, libdir, lib_names, system_libs) =
|
let (cxxflags, ldflags, libdir, lib_names, system_libs) =
|
||||||
llvm_config(&llvm_dir, &["build", "bin", "llvm-config"])
|
llvm_config(&llvm_dir, &["build", "bin", "llvm-config"])
|
||||||
.or_else(|_| llvm_config(&llvm_dir, &["build", cmake_profile, "bin", "llvm-config"]))
|
.or_else(|_| llvm_config(&llvm_dir, &["build", cmake_profile, "bin", "llvm-config"]))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
println!("cargo:rustc-link-arg={ldflags}");
|
println!("cargo:rustc-link-arg={ldflags}");
|
||||||
println!("cargo:rustc-link-search=native={libdir}");
|
println!("cargo:rustc-link-search=native={libdir}");
|
||||||
for lib in system_libs.split_ascii_whitespace() {
|
for lib in system_libs.split_ascii_whitespace() {
|
||||||
println!("cargo:rustc-link-arg={lib}");
|
println!("cargo:rustc-link-arg={lib}");
|
||||||
}
|
}
|
||||||
link_llvm_components(lib_names);
|
link_llvm_components(lib_names);
|
||||||
compile_cxx_lib(cxxflags);
|
compile_cxx_lib(cxxflags);
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/mozilla/sccache/blob/main/README.md#usage
|
// https://github.com/mozilla/sccache/blob/main/README.md#usage
|
||||||
fn try_use_sccache(cmake: &mut Config) {
|
fn try_use_sccache(cmake: &mut Config) {
|
||||||
if let Ok(sccache) = std::env::var("SCCACHE_PATH") {
|
if let Ok(sccache) = std::env::var("SCCACHE_PATH") {
|
||||||
cmake.define("CMAKE_CXX_COMPILER_LAUNCHER", &*sccache);
|
cmake.define("CMAKE_CXX_COMPILER_LAUNCHER", &*sccache);
|
||||||
cmake.define("CMAKE_C_COMPILER_LAUNCHER", &*sccache);
|
cmake.define("CMAKE_C_COMPILER_LAUNCHER", &*sccache);
|
||||||
match std::env::var_os("CARGO_CFG_TARGET_OS") {
|
match std::env::var_os("CARGO_CFG_TARGET_OS") {
|
||||||
Some(os) if os == "windows" => {
|
Some(os) if os == "windows" => {
|
||||||
cmake.define("CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "Embedded");
|
cmake.define("CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "Embedded");
|
||||||
cmake.define("CMAKE_POLICY_CMP0141", "NEW");
|
cmake.define("CMAKE_POLICY_CMP0141", "NEW");
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_use_ninja(cmake: &mut Config) {
|
fn try_use_ninja(cmake: &mut Config) {
|
||||||
let mut cmd = Command::new("ninja");
|
let mut cmd = Command::new("ninja");
|
||||||
cmd.arg("--version");
|
cmd.arg("--version");
|
||||||
if let Ok(status) = cmd.status() {
|
if let Ok(status) = cmd.status() {
|
||||||
if status.success() {
|
if status.success() {
|
||||||
cmake.generator("Ninja");
|
cmake.generator("Ninja");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn llvm_config(
|
fn llvm_config(
|
||||||
llvm_build_dir: &PathBuf,
|
llvm_build_dir: &PathBuf,
|
||||||
path_to_llvm_config: &[&str],
|
path_to_llvm_config: &[&str],
|
||||||
) -> io::Result<(String, String, String, String, String)> {
|
) -> io::Result<(String, String, String, String, String)> {
|
||||||
let mut llvm_build_path = llvm_build_dir.clone();
|
let mut llvm_build_path = llvm_build_dir.clone();
|
||||||
llvm_build_path.extend(path_to_llvm_config);
|
llvm_build_path.extend(path_to_llvm_config);
|
||||||
let mut cmd = Command::new(llvm_build_path);
|
let mut cmd = Command::new(llvm_build_path);
|
||||||
cmd.args([
|
cmd.args([
|
||||||
"--link-static",
|
"--link-static",
|
||||||
"--cxxflags",
|
"--cxxflags",
|
||||||
"--ldflags",
|
"--ldflags",
|
||||||
"--libdir",
|
"--libdir",
|
||||||
"--libnames",
|
"--libnames",
|
||||||
"--system-libs",
|
"--system-libs",
|
||||||
]);
|
]);
|
||||||
for c in COMPONENTS {
|
for c in COMPONENTS {
|
||||||
cmd.arg(c[4..].to_lowercase());
|
cmd.arg(c[4..].to_lowercase());
|
||||||
}
|
}
|
||||||
let output = cmd.output()?;
|
let output = cmd.output()?;
|
||||||
if !output.status.success() {
|
if !output.status.success() {
|
||||||
return Err(io::Error::from(io::ErrorKind::Other));
|
return Err(io::Error::from(io::ErrorKind::Other));
|
||||||
}
|
}
|
||||||
let output = unsafe { String::from_utf8_unchecked(output.stdout) };
|
let output = unsafe { String::from_utf8_unchecked(output.stdout) };
|
||||||
let mut lines = output.lines();
|
let mut lines = output.lines();
|
||||||
let cxxflags = lines.next().unwrap();
|
let cxxflags = lines.next().unwrap();
|
||||||
let ldflags = lines.next().unwrap();
|
let ldflags = lines.next().unwrap();
|
||||||
let libdir = lines.next().unwrap();
|
let libdir = lines.next().unwrap();
|
||||||
let lib_names = lines.next().unwrap();
|
let lib_names = lines.next().unwrap();
|
||||||
let system_libs = lines.next().unwrap();
|
let system_libs = lines.next().unwrap();
|
||||||
Ok((
|
Ok((
|
||||||
cxxflags.to_string(),
|
cxxflags.to_string(),
|
||||||
ldflags.to_string(),
|
ldflags.to_string(),
|
||||||
libdir.to_string(),
|
libdir.to_string(),
|
||||||
lib_names.to_string(),
|
lib_names.to_string(),
|
||||||
system_libs.to_string(),
|
system_libs.to_string(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compile_cxx_lib(cxxflags: String) {
|
fn compile_cxx_lib(cxxflags: String) {
|
||||||
let mut cc = cc::Build::new();
|
let mut cc = cc::Build::new();
|
||||||
for flag in cxxflags.split_whitespace() {
|
for flag in cxxflags.split_whitespace() {
|
||||||
cc.flag(flag);
|
cc.flag(flag);
|
||||||
}
|
}
|
||||||
cc.cpp(true).file("src/lib.cpp").compile("llvm_zluda_cpp");
|
cc.cpp(true).file("src/lib.cpp").compile("llvm_zluda_cpp");
|
||||||
println!("cargo:rerun-if-changed=src/lib.cpp");
|
println!("cargo:rerun-if-changed=src/lib.cpp");
|
||||||
println!("cargo:rerun-if-changed=src/lib.rs");
|
println!("cargo:rerun-if-changed=src/lib.rs");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Prints a `cargo:rustc-link-lib` directive for every library in `components`.
///
/// `components` is a whitespace-separated list of static library file names,
/// either `libLLVMfoo.a` (Unix) or `LLVMfoo.lib` (Windows).
///
/// # Panics
///
/// Panics if an entry matches neither naming scheme.
fn link_llvm_components(components: String) {
    for file_name in components.split_whitespace() {
        let library = static_library_stem(file_name).unwrap_or_else(|| {
            panic!("'{}' does not look like a static library name", file_name)
        });
        println!("cargo:rustc-link-lib={library}");
    }
}

/// Strips the platform-specific decoration from a static library file name.
///
/// Returns `None` when `file_name` follows neither recognised scheme.
fn static_library_stem(file_name: &str) -> Option<&str> {
    file_name
        // Unix (Linux/Mac): libLLVMfoo.a
        .strip_prefix("lib")
        .and_then(|rest| rest.strip_suffix(".a"))
        // Windows: LLVMfoo.lib
        .or_else(|| file_name.strip_suffix(".lib"))
}
|
||||||
|
@ -1,81 +1,81 @@
|
|||||||
#![allow(non_upper_case_globals)]
|
#![allow(non_upper_case_globals)]
|
||||||
use llvm_sys::prelude::*;
|
use llvm_sys::prelude::*;
|
||||||
pub use llvm_sys::*;
|
pub use llvm_sys::*;
|
||||||
|
|
||||||
/// Operation selector passed by value to `LLVMZludaBuildAtomicRMW`.
///
/// The enum is `#[repr(C)]` and every discriminant is spelled out because the
/// value crosses the FFI boundary.
// NOTE(review): the numbering presumably has to match the enum on the C++
// side of this crate; confirm before changing any value.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LLVMZludaAtomicRMWBinOp {
    LLVMZludaAtomicRMWBinOpXchg = 0,
    LLVMZludaAtomicRMWBinOpAdd = 1,
    LLVMZludaAtomicRMWBinOpSub = 2,
    LLVMZludaAtomicRMWBinOpAnd = 3,
    LLVMZludaAtomicRMWBinOpNand = 4,
    LLVMZludaAtomicRMWBinOpOr = 5,
    LLVMZludaAtomicRMWBinOpXor = 6,
    LLVMZludaAtomicRMWBinOpMax = 7,
    LLVMZludaAtomicRMWBinOpMin = 8,
    LLVMZludaAtomicRMWBinOpUMax = 9,
    LLVMZludaAtomicRMWBinOpUMin = 10,
    LLVMZludaAtomicRMWBinOpFAdd = 11,
    LLVMZludaAtomicRMWBinOpFSub = 12,
    LLVMZludaAtomicRMWBinOpFMax = 13,
    LLVMZludaAtomicRMWBinOpFMin = 14,
    LLVMZludaAtomicRMWBinOpUIncWrap = 15,
    LLVMZludaAtomicRMWBinOpUDecWrap = 16,
}
|
||||||
|
|
||||||
// Backport from LLVM 19

/// Bit set of fast-math flags, as taken by `LLVMZludaSetFastMathFlags`.
pub type LLVMZludaFastMathFlags = std::ffi::c_uint;

// Individual flags: one bit each, so they can be combined with `|`.
pub const LLVMZludaFastMathAllowReassoc: LLVMZludaFastMathFlags = 1 << 0;
pub const LLVMZludaFastMathNoNaNs: LLVMZludaFastMathFlags = 1 << 1;
pub const LLVMZludaFastMathNoInfs: LLVMZludaFastMathFlags = 1 << 2;
pub const LLVMZludaFastMathNoSignedZeros: LLVMZludaFastMathFlags = 1 << 3;
pub const LLVMZludaFastMathAllowReciprocal: LLVMZludaFastMathFlags = 1 << 4;
pub const LLVMZludaFastMathAllowContract: LLVMZludaFastMathFlags = 1 << 5;
pub const LLVMZludaFastMathApproxFunc: LLVMZludaFastMathFlags = 1 << 6;

/// No flag set.
pub const LLVMZludaFastMathNone: LLVMZludaFastMathFlags = 0;

/// Union of all seven individual flags.
pub const LLVMZludaFastMathAll: LLVMZludaFastMathFlags = LLVMZludaFastMathAllowReassoc
    | LLVMZludaFastMathNoNaNs
    | LLVMZludaFastMathNoInfs
    | LLVMZludaFastMathNoSignedZeros
    | LLVMZludaFastMathAllowReciprocal
    | LLVMZludaFastMathAllowContract
    | LLVMZludaFastMathApproxFunc;
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
pub fn LLVMZludaBuildAlloca(
|
pub fn LLVMZludaBuildAlloca(
|
||||||
B: LLVMBuilderRef,
|
B: LLVMBuilderRef,
|
||||||
Ty: LLVMTypeRef,
|
Ty: LLVMTypeRef,
|
||||||
AddrSpace: u32,
|
AddrSpace: u32,
|
||||||
Name: *const i8,
|
Name: *const i8,
|
||||||
) -> LLVMValueRef;
|
) -> LLVMValueRef;
|
||||||
|
|
||||||
pub fn LLVMZludaBuildAtomicRMW(
|
pub fn LLVMZludaBuildAtomicRMW(
|
||||||
B: LLVMBuilderRef,
|
B: LLVMBuilderRef,
|
||||||
op: LLVMZludaAtomicRMWBinOp,
|
op: LLVMZludaAtomicRMWBinOp,
|
||||||
PTR: LLVMValueRef,
|
PTR: LLVMValueRef,
|
||||||
Val: LLVMValueRef,
|
Val: LLVMValueRef,
|
||||||
scope: *const i8,
|
scope: *const i8,
|
||||||
ordering: LLVMAtomicOrdering,
|
ordering: LLVMAtomicOrdering,
|
||||||
) -> LLVMValueRef;
|
) -> LLVMValueRef;
|
||||||
|
|
||||||
pub fn LLVMZludaBuildAtomicCmpXchg(
|
pub fn LLVMZludaBuildAtomicCmpXchg(
|
||||||
B: LLVMBuilderRef,
|
B: LLVMBuilderRef,
|
||||||
Ptr: LLVMValueRef,
|
Ptr: LLVMValueRef,
|
||||||
Cmp: LLVMValueRef,
|
Cmp: LLVMValueRef,
|
||||||
New: LLVMValueRef,
|
New: LLVMValueRef,
|
||||||
scope: *const i8,
|
scope: *const i8,
|
||||||
SuccessOrdering: LLVMAtomicOrdering,
|
SuccessOrdering: LLVMAtomicOrdering,
|
||||||
FailureOrdering: LLVMAtomicOrdering,
|
FailureOrdering: LLVMAtomicOrdering,
|
||||||
) -> LLVMValueRef;
|
) -> LLVMValueRef;
|
||||||
|
|
||||||
pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags);
|
pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags);
|
||||||
|
|
||||||
pub fn LLVMZludaBuildFence(
|
pub fn LLVMZludaBuildFence(
|
||||||
B: LLVMBuilderRef,
|
B: LLVMBuilderRef,
|
||||||
ordering: LLVMAtomicOrdering,
|
ordering: LLVMAtomicOrdering,
|
||||||
scope: *const i8,
|
scope: *const i8,
|
||||||
Name: *const i8,
|
Name: *const i8,
|
||||||
) -> LLVMValueRef;
|
) -> LLVMValueRef;
|
||||||
}
|
}
|
||||||
|
@ -1,191 +1,191 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
pub(super) fn run<'a, 'input>(
|
pub(super) fn run<'a, 'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(resolver, directive))
|
.map(|directive| run_directive(resolver, directive))
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_directive<'input>(
|
fn run_directive<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2,
|
resolver: &mut GlobalStringIdentResolver2,
|
||||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
var @ Directive2::Variable(..) => var,
|
var @ Directive2::Variable(..) => var,
|
||||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_method<'input>(
|
fn run_method<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2,
|
resolver: &mut GlobalStringIdentResolver2,
|
||||||
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
let is_declaration = method.body.is_none();
|
let is_declaration = method.body.is_none();
|
||||||
let mut body = Vec::new();
|
let mut body = Vec::new();
|
||||||
let mut remap_returns = Vec::new();
|
let mut remap_returns = Vec::new();
|
||||||
if !method.is_kernel {
|
if !method.is_kernel {
|
||||||
for arg in method.return_arguments.iter_mut() {
|
for arg in method.return_arguments.iter_mut() {
|
||||||
match arg.state_space {
|
match arg.state_space {
|
||||||
ptx_parser::StateSpace::Param => {
|
ptx_parser::StateSpace::Param => {
|
||||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||||
let old_name = arg.name;
|
let old_name = arg.name;
|
||||||
arg.name =
|
arg.name =
|
||||||
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||||
if is_declaration {
|
if is_declaration {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
|
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
|
||||||
body.push(Statement::Variable(ast::Variable {
|
body.push(Statement::Variable(ast::Variable {
|
||||||
align: None,
|
align: None,
|
||||||
name: old_name,
|
name: old_name,
|
||||||
v_type: arg.v_type.clone(),
|
v_type: arg.v_type.clone(),
|
||||||
state_space: ptx_parser::StateSpace::Param,
|
state_space: ptx_parser::StateSpace::Param,
|
||||||
array_init: Vec::new(),
|
array_init: Vec::new(),
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
ptx_parser::StateSpace::Reg => {}
|
ptx_parser::StateSpace::Reg => {}
|
||||||
_ => return Err(error_unreachable()),
|
_ => return Err(error_unreachable()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for arg in method.input_arguments.iter_mut() {
|
for arg in method.input_arguments.iter_mut() {
|
||||||
match arg.state_space {
|
match arg.state_space {
|
||||||
ptx_parser::StateSpace::Param => {
|
ptx_parser::StateSpace::Param => {
|
||||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||||
let old_name = arg.name;
|
let old_name = arg.name;
|
||||||
arg.name =
|
arg.name =
|
||||||
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||||
if is_declaration {
|
if is_declaration {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
body.push(Statement::Variable(ast::Variable {
|
body.push(Statement::Variable(ast::Variable {
|
||||||
align: None,
|
align: None,
|
||||||
name: old_name,
|
name: old_name,
|
||||||
v_type: arg.v_type.clone(),
|
v_type: arg.v_type.clone(),
|
||||||
state_space: ptx_parser::StateSpace::Param,
|
state_space: ptx_parser::StateSpace::Param,
|
||||||
array_init: Vec::new(),
|
array_init: Vec::new(),
|
||||||
}));
|
}));
|
||||||
body.push(Statement::Instruction(ast::Instruction::St {
|
body.push(Statement::Instruction(ast::Instruction::St {
|
||||||
data: ast::StData {
|
data: ast::StData {
|
||||||
qualifier: ast::LdStQualifier::Weak,
|
qualifier: ast::LdStQualifier::Weak,
|
||||||
state_space: ast::StateSpace::Param,
|
state_space: ast::StateSpace::Param,
|
||||||
caching: ast::StCacheOperator::Writethrough,
|
caching: ast::StCacheOperator::Writethrough,
|
||||||
typ: arg.v_type.clone(),
|
typ: arg.v_type.clone(),
|
||||||
},
|
},
|
||||||
arguments: ast::StArgs {
|
arguments: ast::StArgs {
|
||||||
src1: old_name,
|
src1: old_name,
|
||||||
src2: arg.name,
|
src2: arg.name,
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
ptx_parser::StateSpace::Reg => {}
|
ptx_parser::StateSpace::Reg => {}
|
||||||
_ => return Err(error_unreachable()),
|
_ => return Err(error_unreachable()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| {
|
.map(|statements| {
|
||||||
for statement in statements {
|
for statement in statements {
|
||||||
run_statement(resolver, &remap_returns, &mut body, statement)?;
|
run_statement(resolver, &remap_returns, &mut body, statement)?;
|
||||||
}
|
}
|
||||||
Ok::<_, TranslateError>(body)
|
Ok::<_, TranslateError>(body)
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 { body, ..method })
|
Ok(Function2 { body, ..method })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statement<'input>(
|
fn run_statement<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
|
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
|
||||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
|
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
match statement {
|
match statement {
|
||||||
Statement::Instruction(ast::Instruction::Call {
|
Statement::Instruction(ast::Instruction::Call {
|
||||||
mut data,
|
mut data,
|
||||||
mut arguments,
|
mut arguments,
|
||||||
}) => {
|
}) => {
|
||||||
let mut post_st = Vec::new();
|
let mut post_st = Vec::new();
|
||||||
for ((type_, space), ident) in data
|
for ((type_, space), ident) in data
|
||||||
.input_arguments
|
.input_arguments
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.zip(arguments.input_arguments.iter_mut())
|
.zip(arguments.input_arguments.iter_mut())
|
||||||
{
|
{
|
||||||
if *space == ptx_parser::StateSpace::Param {
|
if *space == ptx_parser::StateSpace::Param {
|
||||||
*space = ptx_parser::StateSpace::Reg;
|
*space = ptx_parser::StateSpace::Reg;
|
||||||
let old_name = *ident;
|
let old_name = *ident;
|
||||||
*ident = resolver
|
*ident = resolver
|
||||||
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
|
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
|
||||||
result.push(Statement::Instruction(ast::Instruction::Ld {
|
result.push(Statement::Instruction(ast::Instruction::Ld {
|
||||||
data: ast::LdDetails {
|
data: ast::LdDetails {
|
||||||
qualifier: ast::LdStQualifier::Weak,
|
qualifier: ast::LdStQualifier::Weak,
|
||||||
state_space: ast::StateSpace::Param,
|
state_space: ast::StateSpace::Param,
|
||||||
caching: ast::LdCacheOperator::Cached,
|
caching: ast::LdCacheOperator::Cached,
|
||||||
typ: type_.clone(),
|
typ: type_.clone(),
|
||||||
non_coherent: false,
|
non_coherent: false,
|
||||||
},
|
},
|
||||||
arguments: ast::LdArgs {
|
arguments: ast::LdArgs {
|
||||||
dst: *ident,
|
dst: *ident,
|
||||||
src: old_name,
|
src: old_name,
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for ((type_, space), ident) in data
|
for ((type_, space), ident) in data
|
||||||
.return_arguments
|
.return_arguments
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.zip(arguments.return_arguments.iter_mut())
|
.zip(arguments.return_arguments.iter_mut())
|
||||||
{
|
{
|
||||||
if *space == ptx_parser::StateSpace::Param {
|
if *space == ptx_parser::StateSpace::Param {
|
||||||
*space = ptx_parser::StateSpace::Reg;
|
*space = ptx_parser::StateSpace::Reg;
|
||||||
let old_name = *ident;
|
let old_name = *ident;
|
||||||
*ident = resolver
|
*ident = resolver
|
||||||
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
|
.register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
|
||||||
post_st.push(Statement::Instruction(ast::Instruction::St {
|
post_st.push(Statement::Instruction(ast::Instruction::St {
|
||||||
data: ast::StData {
|
data: ast::StData {
|
||||||
qualifier: ast::LdStQualifier::Weak,
|
qualifier: ast::LdStQualifier::Weak,
|
||||||
state_space: ast::StateSpace::Param,
|
state_space: ast::StateSpace::Param,
|
||||||
caching: ast::StCacheOperator::Writethrough,
|
caching: ast::StCacheOperator::Writethrough,
|
||||||
typ: type_.clone(),
|
typ: type_.clone(),
|
||||||
},
|
},
|
||||||
arguments: ast::StArgs {
|
arguments: ast::StArgs {
|
||||||
src1: old_name,
|
src1: old_name,
|
||||||
src2: *ident,
|
src2: *ident,
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result.push(Statement::Instruction(ast::Instruction::Call {
|
result.push(Statement::Instruction(ast::Instruction::Call {
|
||||||
data,
|
data,
|
||||||
arguments,
|
arguments,
|
||||||
}));
|
}));
|
||||||
result.extend(post_st.into_iter());
|
result.extend(post_st.into_iter());
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::Ret { data }) => {
|
Statement::Instruction(ast::Instruction::Ret { data }) => {
|
||||||
for (old_name, new_name, type_) in remap_returns.iter() {
|
for (old_name, new_name, type_) in remap_returns.iter() {
|
||||||
result.push(Statement::Instruction(ast::Instruction::Ld {
|
result.push(Statement::Instruction(ast::Instruction::Ld {
|
||||||
data: ast::LdDetails {
|
data: ast::LdDetails {
|
||||||
qualifier: ast::LdStQualifier::Weak,
|
qualifier: ast::LdStQualifier::Weak,
|
||||||
state_space: ast::StateSpace::Param,
|
state_space: ast::StateSpace::Param,
|
||||||
caching: ast::LdCacheOperator::Cached,
|
caching: ast::LdCacheOperator::Cached,
|
||||||
typ: type_.clone(),
|
typ: type_.clone(),
|
||||||
non_coherent: false,
|
non_coherent: false,
|
||||||
},
|
},
|
||||||
arguments: ast::LdArgs {
|
arguments: ast::LdArgs {
|
||||||
dst: *new_name,
|
dst: *new_name,
|
||||||
src: *old_name,
|
src: *old_name,
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
result.push(Statement::Instruction(ast::Instruction::Ret { data }));
|
result.push(Statement::Instruction(ast::Instruction::Ret { data }));
|
||||||
}
|
}
|
||||||
statement => {
|
statement => {
|
||||||
result.push(statement);
|
result.push(statement);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,301 +1,301 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
pub(super) fn run<'a, 'input>(
|
pub(super) fn run<'a, 'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<UnconditionalDirective>,
|
directives: Vec<UnconditionalDirective>,
|
||||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(resolver, directive))
|
.map(|directive| run_directive(resolver, directive))
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_directive<'input>(
|
fn run_directive<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directive: Directive2<
|
directive: Directive2<
|
||||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
ast::ParsedOperand<SpirvWord>,
|
ast::ParsedOperand<SpirvWord>,
|
||||||
>,
|
>,
|
||||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_method<'input>(
|
fn run_method<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
method: Function2<
|
method: Function2<
|
||||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
ast::ParsedOperand<SpirvWord>,
|
ast::ParsedOperand<SpirvWord>,
|
||||||
>,
|
>,
|
||||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| {
|
.map(|statements| {
|
||||||
let mut result = Vec::with_capacity(statements.len());
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
for statement in statements {
|
for statement in statements {
|
||||||
run_statement(resolver, &mut result, statement)?;
|
run_statement(resolver, &mut result, statement)?;
|
||||||
}
|
}
|
||||||
Ok::<_, TranslateError>(result)
|
Ok::<_, TranslateError>(result)
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 {
|
||||||
body,
|
body,
|
||||||
return_arguments: method.return_arguments,
|
return_arguments: method.return_arguments,
|
||||||
name: method.name,
|
name: method.name,
|
||||||
input_arguments: method.input_arguments,
|
input_arguments: method.input_arguments,
|
||||||
import_as: method.import_as,
|
import_as: method.import_as,
|
||||||
tuning: method.tuning,
|
tuning: method.tuning,
|
||||||
linkage: method.linkage,
|
linkage: method.linkage,
|
||||||
is_kernel: method.is_kernel,
|
is_kernel: method.is_kernel,
|
||||||
flush_to_zero_f32: method.flush_to_zero_f32,
|
flush_to_zero_f32: method.flush_to_zero_f32,
|
||||||
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
||||||
rounding_mode_f32: method.rounding_mode_f32,
|
rounding_mode_f32: method.rounding_mode_f32,
|
||||||
rounding_mode_f16f64: method.rounding_mode_f16f64,
|
rounding_mode_f16f64: method.rounding_mode_f16f64,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statement<'input>(
|
fn run_statement<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
statement: UnconditionalStatement,
|
statement: UnconditionalStatement,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
let mut visitor = FlattenArguments::new(resolver, result);
|
let mut visitor = FlattenArguments::new(resolver, result);
|
||||||
let new_statement = statement.visit_map(&mut visitor)?;
|
let new_statement = statement.visit_map(&mut visitor)?;
|
||||||
visitor.result.push(new_statement);
|
visitor.result.push(new_statement);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
struct FlattenArguments<'a, 'input> {
|
struct FlattenArguments<'a, 'input> {
|
||||||
result: &'a mut Vec<ExpandedStatement>,
|
result: &'a mut Vec<ExpandedStatement>,
|
||||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||||
post_stmts: Vec<ExpandedStatement>,
|
post_stmts: Vec<ExpandedStatement>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'input> FlattenArguments<'a, 'input> {
|
impl<'a, 'input> FlattenArguments<'a, 'input> {
|
||||||
fn new(
|
fn new(
|
||||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||||
result: &'a mut Vec<ExpandedStatement>,
|
result: &'a mut Vec<ExpandedStatement>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
FlattenArguments {
|
FlattenArguments {
|
||||||
result,
|
result,
|
||||||
resolver,
|
resolver,
|
||||||
post_stmts: Vec::new(),
|
post_stmts: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
|
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
|
||||||
Ok(name)
|
Ok(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reg_offset(
|
fn reg_offset(
|
||||||
&mut self,
|
&mut self,
|
||||||
reg: SpirvWord,
|
reg: SpirvWord,
|
||||||
offset: i32,
|
offset: i32,
|
||||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
_is_dst: bool,
|
_is_dst: bool,
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
|
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
|
||||||
(type_, state_space)
|
(type_, state_space)
|
||||||
} else {
|
} else {
|
||||||
return Err(TranslateError::UntypedSymbol);
|
return Err(TranslateError::UntypedSymbol);
|
||||||
};
|
};
|
||||||
if state_space == ast::StateSpace::Reg {
|
if state_space == ast::StateSpace::Reg {
|
||||||
let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
|
let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
|
||||||
if *reg_space != ast::StateSpace::Reg {
|
if *reg_space != ast::StateSpace::Reg {
|
||||||
return Err(error_mismatched_type());
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
let reg_scalar_type = match reg_type {
|
let reg_scalar_type = match reg_type {
|
||||||
ast::Type::Scalar(underlying_type) => *underlying_type,
|
ast::Type::Scalar(underlying_type) => *underlying_type,
|
||||||
_ => return Err(error_mismatched_type()),
|
_ => return Err(error_mismatched_type()),
|
||||||
};
|
};
|
||||||
let reg_type = reg_type.clone();
|
let reg_type = reg_type.clone();
|
||||||
let id_constant_stmt = self
|
let id_constant_stmt = self
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
|
.register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
|
||||||
self.result.push(Statement::Constant(ConstantDefinition {
|
self.result.push(Statement::Constant(ConstantDefinition {
|
||||||
dst: id_constant_stmt,
|
dst: id_constant_stmt,
|
||||||
typ: reg_scalar_type,
|
typ: reg_scalar_type,
|
||||||
value: ast::ImmediateValue::S64(offset as i64),
|
value: ast::ImmediateValue::S64(offset as i64),
|
||||||
}));
|
}));
|
||||||
let arith_details = match reg_scalar_type.kind() {
|
let arith_details = match reg_scalar_type.kind() {
|
||||||
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
|
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
type_: reg_scalar_type,
|
type_: reg_scalar_type,
|
||||||
saturate: false,
|
saturate: false,
|
||||||
}),
|
}),
|
||||||
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
|
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
|
||||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
type_: reg_scalar_type,
|
type_: reg_scalar_type,
|
||||||
saturate: false,
|
saturate: false,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
_ => return Err(error_unreachable()),
|
_ => return Err(error_unreachable()),
|
||||||
};
|
};
|
||||||
let id_add_result = self
|
let id_add_result = self
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((reg_type, state_space)));
|
.register_unnamed(Some((reg_type, state_space)));
|
||||||
self.result
|
self.result
|
||||||
.push(Statement::Instruction(ast::Instruction::Add {
|
.push(Statement::Instruction(ast::Instruction::Add {
|
||||||
data: arith_details,
|
data: arith_details,
|
||||||
arguments: ast::AddArgs {
|
arguments: ast::AddArgs {
|
||||||
dst: id_add_result,
|
dst: id_add_result,
|
||||||
src1: reg,
|
src1: reg,
|
||||||
src2: id_constant_stmt,
|
src2: id_constant_stmt,
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
Ok(id_add_result)
|
Ok(id_add_result)
|
||||||
} else {
|
} else {
|
||||||
let id_constant_stmt = self.resolver.register_unnamed(Some((
|
let id_constant_stmt = self.resolver.register_unnamed(Some((
|
||||||
ast::Type::Scalar(ast::ScalarType::S64),
|
ast::Type::Scalar(ast::ScalarType::S64),
|
||||||
ast::StateSpace::Reg,
|
ast::StateSpace::Reg,
|
||||||
)));
|
)));
|
||||||
self.result.push(Statement::Constant(ConstantDefinition {
|
self.result.push(Statement::Constant(ConstantDefinition {
|
||||||
dst: id_constant_stmt,
|
dst: id_constant_stmt,
|
||||||
typ: ast::ScalarType::S64,
|
typ: ast::ScalarType::S64,
|
||||||
value: ast::ImmediateValue::S64(offset as i64),
|
value: ast::ImmediateValue::S64(offset as i64),
|
||||||
}));
|
}));
|
||||||
let dst = self
|
let dst = self
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((type_.clone(), state_space)));
|
.register_unnamed(Some((type_.clone(), state_space)));
|
||||||
self.result.push(Statement::PtrAccess(PtrAccess {
|
self.result.push(Statement::PtrAccess(PtrAccess {
|
||||||
underlying_type: type_.clone(),
|
underlying_type: type_.clone(),
|
||||||
state_space: state_space,
|
state_space: state_space,
|
||||||
dst,
|
dst,
|
||||||
ptr_src: reg,
|
ptr_src: reg,
|
||||||
offset_src: id_constant_stmt,
|
offset_src: id_constant_stmt,
|
||||||
}));
|
}));
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn immediate(
|
fn immediate(
|
||||||
&mut self,
|
&mut self,
|
||||||
value: ast::ImmediateValue,
|
value: ast::ImmediateValue,
|
||||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
let (scalar_t, state_space) =
|
let (scalar_t, state_space) =
|
||||||
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
|
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
|
||||||
(*scalar, state_space)
|
(*scalar, state_space)
|
||||||
} else {
|
} else {
|
||||||
return Err(TranslateError::UntypedSymbol);
|
return Err(TranslateError::UntypedSymbol);
|
||||||
};
|
};
|
||||||
let id = self
|
let id = self
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
|
.register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
|
||||||
self.result.push(Statement::Constant(ConstantDefinition {
|
self.result.push(Statement::Constant(ConstantDefinition {
|
||||||
dst: id,
|
dst: id,
|
||||||
typ: scalar_t,
|
typ: scalar_t,
|
||||||
value,
|
value,
|
||||||
}));
|
}));
|
||||||
Ok(id)
|
Ok(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_member(
|
fn vec_member(
|
||||||
&mut self,
|
&mut self,
|
||||||
vector_ident: SpirvWord,
|
vector_ident: SpirvWord,
|
||||||
member: u8,
|
member: u8,
|
||||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? {
|
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? {
|
||||||
(ast::Type::Vector(vector_width, scalar_t), space) => {
|
(ast::Type::Vector(vector_width, scalar_t), space) => {
|
||||||
(*vector_width, *scalar_t, *space)
|
(*vector_width, *scalar_t, *space)
|
||||||
}
|
}
|
||||||
_ => return Err(error_mismatched_type()),
|
_ => return Err(error_mismatched_type()),
|
||||||
};
|
};
|
||||||
let temporary = self
|
let temporary = self
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((scalar_type.into(), space)));
|
.register_unnamed(Some((scalar_type.into(), space)));
|
||||||
if is_dst {
|
if is_dst {
|
||||||
self.post_stmts.push(Statement::VectorWrite(VectorWrite {
|
self.post_stmts.push(Statement::VectorWrite(VectorWrite {
|
||||||
scalar_type,
|
scalar_type,
|
||||||
vector_width,
|
vector_width,
|
||||||
vector_dst: vector_ident,
|
vector_dst: vector_ident,
|
||||||
vector_src: vector_ident,
|
vector_src: vector_ident,
|
||||||
scalar_src: temporary,
|
scalar_src: temporary,
|
||||||
member,
|
member,
|
||||||
}));
|
}));
|
||||||
} else {
|
} else {
|
||||||
self.result.push(Statement::VectorRead(VectorRead {
|
self.result.push(Statement::VectorRead(VectorRead {
|
||||||
scalar_type,
|
scalar_type,
|
||||||
vector_width,
|
vector_width,
|
||||||
scalar_dst: temporary,
|
scalar_dst: temporary,
|
||||||
vector_src: vector_ident,
|
vector_src: vector_ident,
|
||||||
member,
|
member,
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
Ok(temporary)
|
Ok(temporary)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_pack(
|
fn vec_pack(
|
||||||
&mut self,
|
&mut self,
|
||||||
vector_elements: Vec<SpirvWord>,
|
vector_elements: Vec<SpirvWord>,
|
||||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
relaxed_type_check: bool,
|
relaxed_type_check: bool,
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
let (width, scalar_t, state_space) = match type_space {
|
let (width, scalar_t, state_space) = match type_space {
|
||||||
Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
|
Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
|
||||||
_ => return Err(error_mismatched_type()),
|
_ => return Err(error_mismatched_type()),
|
||||||
};
|
};
|
||||||
let temporary_vector = self
|
let temporary_vector = self
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
|
.register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
|
||||||
let statement = Statement::RepackVector(RepackVectorDetails {
|
let statement = Statement::RepackVector(RepackVectorDetails {
|
||||||
is_extract: is_dst,
|
is_extract: is_dst,
|
||||||
typ: scalar_t,
|
typ: scalar_t,
|
||||||
packed: temporary_vector,
|
packed: temporary_vector,
|
||||||
unpacked: vector_elements,
|
unpacked: vector_elements,
|
||||||
relaxed_type_check,
|
relaxed_type_check,
|
||||||
});
|
});
|
||||||
if is_dst {
|
if is_dst {
|
||||||
self.post_stmts.push(statement);
|
self.post_stmts.push(statement);
|
||||||
} else {
|
} else {
|
||||||
self.result.push(statement);
|
self.result.push(statement);
|
||||||
}
|
}
|
||||||
Ok(temporary_vector)
|
Ok(temporary_vector)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
|
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
|
||||||
for FlattenArguments<'a, 'b>
|
for FlattenArguments<'a, 'b>
|
||||||
{
|
{
|
||||||
fn visit(
|
fn visit(
|
||||||
&mut self,
|
&mut self,
|
||||||
args: ast::ParsedOperand<SpirvWord>,
|
args: ast::ParsedOperand<SpirvWord>,
|
||||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
relaxed_type_check: bool,
|
relaxed_type_check: bool,
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
match args {
|
match args {
|
||||||
ast::ParsedOperand::Reg(r) => self.reg(r),
|
ast::ParsedOperand::Reg(r) => self.reg(r),
|
||||||
ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
|
ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
|
||||||
ast::ParsedOperand::RegOffset(reg, offset) => {
|
ast::ParsedOperand::RegOffset(reg, offset) => {
|
||||||
self.reg_offset(reg, offset, type_space, is_dst)
|
self.reg_offset(reg, offset, type_space, is_dst)
|
||||||
}
|
}
|
||||||
ast::ParsedOperand::VecMember(vec, member) => {
|
ast::ParsedOperand::VecMember(vec, member) => {
|
||||||
self.vec_member(vec, member, type_space, is_dst)
|
self.vec_member(vec, member, type_space, is_dst)
|
||||||
}
|
}
|
||||||
ast::ParsedOperand::VecPack(vecs) => {
|
ast::ParsedOperand::VecPack(vecs) => {
|
||||||
self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
|
self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn visit_ident(
|
fn visit_ident(
|
||||||
&mut self,
|
&mut self,
|
||||||
name: SpirvWord,
|
name: SpirvWord,
|
||||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
_is_dst: bool,
|
_is_dst: bool,
|
||||||
_relaxed_type_check: bool,
|
_relaxed_type_check: bool,
|
||||||
) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
|
) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
|
||||||
self.reg(name)
|
self.reg(name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for FlattenArguments<'_, '_> {
|
impl Drop for FlattenArguments<'_, '_> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
self.result.extend(self.post_stmts.drain(..));
|
self.result.extend(self.post_stmts.drain(..));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,208 +1,208 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
pub(super) fn run<'a, 'input>(
|
pub(super) fn run<'a, 'input>(
|
||||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||||
special_registers: &'a SpecialRegistersMap2,
|
special_registers: &'a SpecialRegistersMap2,
|
||||||
directives: Vec<UnconditionalDirective>,
|
directives: Vec<UnconditionalDirective>,
|
||||||
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||||
let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len());
|
let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len());
|
||||||
let mut sreg_to_function =
|
let mut sreg_to_function =
|
||||||
FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default());
|
FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default());
|
||||||
SpecialRegistersMap2::foreach_declaration(
|
SpecialRegistersMap2::foreach_declaration(
|
||||||
resolver,
|
resolver,
|
||||||
|sreg, (return_arguments, name, input_arguments)| {
|
|sreg, (return_arguments, name, input_arguments)| {
|
||||||
result.push(UnconditionalDirective::Method(UnconditionalFunction {
|
result.push(UnconditionalDirective::Method(UnconditionalFunction {
|
||||||
return_arguments,
|
return_arguments,
|
||||||
name,
|
name,
|
||||||
input_arguments,
|
input_arguments,
|
||||||
body: None,
|
body: None,
|
||||||
import_as: None,
|
import_as: None,
|
||||||
tuning: Vec::new(),
|
tuning: Vec::new(),
|
||||||
linkage: ast::LinkingDirective::EXTERN,
|
linkage: ast::LinkingDirective::EXTERN,
|
||||||
is_kernel: false,
|
is_kernel: false,
|
||||||
flush_to_zero_f32: false,
|
flush_to_zero_f32: false,
|
||||||
flush_to_zero_f16f64: false,
|
flush_to_zero_f16f64: false,
|
||||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||||
}));
|
}));
|
||||||
sreg_to_function.insert(sreg, name);
|
sreg_to_function.insert(sreg, name);
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
let mut visitor = SpecialRegisterResolver {
|
let mut visitor = SpecialRegisterResolver {
|
||||||
resolver,
|
resolver,
|
||||||
special_registers,
|
special_registers,
|
||||||
sreg_to_function,
|
sreg_to_function,
|
||||||
result: Vec::new(),
|
result: Vec::new(),
|
||||||
};
|
};
|
||||||
for directive in directives.into_iter() {
|
for directive in directives.into_iter() {
|
||||||
result.push(run_directive(&mut visitor, directive)?);
|
result.push(run_directive(&mut visitor, directive)?);
|
||||||
}
|
}
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_directive<'a, 'input>(
|
fn run_directive<'a, 'input>(
|
||||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||||
directive: UnconditionalDirective,
|
directive: UnconditionalDirective,
|
||||||
) -> Result<UnconditionalDirective, TranslateError> {
|
) -> Result<UnconditionalDirective, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
var @ Directive2::Variable(..) => var,
|
var @ Directive2::Variable(..) => var,
|
||||||
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
|
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_method<'a, 'input>(
|
fn run_method<'a, 'input>(
|
||||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||||
method: UnconditionalFunction,
|
method: UnconditionalFunction,
|
||||||
) -> Result<UnconditionalFunction, TranslateError> {
|
) -> Result<UnconditionalFunction, TranslateError> {
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| {
|
.map(|statements| {
|
||||||
let mut result = Vec::with_capacity(statements.len());
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
for statement in statements {
|
for statement in statements {
|
||||||
run_statement(visitor, &mut result, statement)?;
|
run_statement(visitor, &mut result, statement)?;
|
||||||
}
|
}
|
||||||
Ok::<_, TranslateError>(result)
|
Ok::<_, TranslateError>(result)
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 { body, ..method })
|
Ok(Function2 { body, ..method })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statement<'a, 'input>(
|
fn run_statement<'a, 'input>(
|
||||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||||
result: &mut Vec<UnconditionalStatement>,
|
result: &mut Vec<UnconditionalStatement>,
|
||||||
statement: UnconditionalStatement,
|
statement: UnconditionalStatement,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
let converted_statement = statement.visit_map(visitor)?;
|
let converted_statement = statement.visit_map(visitor)?;
|
||||||
result.extend(visitor.result.drain(..));
|
result.extend(visitor.result.drain(..));
|
||||||
result.push(converted_statement);
|
result.push(converted_statement);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SpecialRegisterResolver<'a, 'input> {
|
struct SpecialRegisterResolver<'a, 'input> {
|
||||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||||
special_registers: &'a SpecialRegistersMap2,
|
special_registers: &'a SpecialRegistersMap2,
|
||||||
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
|
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
|
||||||
result: Vec<UnconditionalStatement>,
|
result: Vec<UnconditionalStatement>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b, 'input>
|
impl<'a, 'b, 'input>
|
||||||
ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
|
ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
|
||||||
for SpecialRegisterResolver<'a, 'input>
|
for SpecialRegisterResolver<'a, 'input>
|
||||||
{
|
{
|
||||||
fn visit(
|
fn visit(
|
||||||
&mut self,
|
&mut self,
|
||||||
operand: ast::ParsedOperand<SpirvWord>,
|
operand: ast::ParsedOperand<SpirvWord>,
|
||||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
_relaxed_type_check: bool,
|
_relaxed_type_check: bool,
|
||||||
) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
|
) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
|
||||||
map_operand(operand, &mut |ident, vector_index| {
|
map_operand(operand, &mut |ident, vector_index| {
|
||||||
self.replace_sreg(ident, vector_index, is_dst)
|
self.replace_sreg(ident, vector_index, is_dst)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn visit_ident(
|
fn visit_ident(
|
||||||
&mut self,
|
&mut self,
|
||||||
args: SpirvWord,
|
args: SpirvWord,
|
||||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
_relaxed_type_check: bool,
|
_relaxed_type_check: bool,
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args))
|
Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
|
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
|
||||||
fn replace_sreg(
|
fn replace_sreg(
|
||||||
&mut self,
|
&mut self,
|
||||||
name: SpirvWord,
|
name: SpirvWord,
|
||||||
vector_index: Option<u8>,
|
vector_index: Option<u8>,
|
||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
) -> Result<Option<SpirvWord>, TranslateError> {
|
) -> Result<Option<SpirvWord>, TranslateError> {
|
||||||
if let Some(sreg) = self.special_registers.get(name) {
|
if let Some(sreg) = self.special_registers.get(name) {
|
||||||
if is_dst {
|
if is_dst {
|
||||||
return Err(error_mismatched_type());
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
|
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
|
||||||
(Some(idx), Some(inp_type)) => {
|
(Some(idx), Some(inp_type)) => {
|
||||||
if inp_type != ast::ScalarType::U8 {
|
if inp_type != ast::ScalarType::U8 {
|
||||||
return Err(TranslateError::Unreachable);
|
return Err(TranslateError::Unreachable);
|
||||||
}
|
}
|
||||||
let constant = self.resolver.register_unnamed(Some((
|
let constant = self.resolver.register_unnamed(Some((
|
||||||
ast::Type::Scalar(inp_type),
|
ast::Type::Scalar(inp_type),
|
||||||
ast::StateSpace::Reg,
|
ast::StateSpace::Reg,
|
||||||
)));
|
)));
|
||||||
self.result.push(Statement::Constant(ConstantDefinition {
|
self.result.push(Statement::Constant(ConstantDefinition {
|
||||||
dst: constant,
|
dst: constant,
|
||||||
typ: inp_type,
|
typ: inp_type,
|
||||||
value: ast::ImmediateValue::U64(idx as u64),
|
value: ast::ImmediateValue::U64(idx as u64),
|
||||||
}));
|
}));
|
||||||
vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
|
vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
|
||||||
}
|
}
|
||||||
(None, None) => Vec::new(),
|
(None, None) => Vec::new(),
|
||||||
_ => return Err(error_mismatched_type()),
|
_ => return Err(error_mismatched_type()),
|
||||||
};
|
};
|
||||||
let return_type = sreg.get_function_return_type();
|
let return_type = sreg.get_function_return_type();
|
||||||
let fn_result = self
|
let fn_result = self
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
|
.register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
|
||||||
let return_arguments = vec![(
|
let return_arguments = vec![(
|
||||||
fn_result,
|
fn_result,
|
||||||
ast::Type::Scalar(return_type),
|
ast::Type::Scalar(return_type),
|
||||||
ast::StateSpace::Reg,
|
ast::StateSpace::Reg,
|
||||||
)];
|
)];
|
||||||
let data = ast::CallDetails {
|
let data = ast::CallDetails {
|
||||||
uniform: false,
|
uniform: false,
|
||||||
return_arguments: return_arguments
|
return_arguments: return_arguments
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(_, typ, space)| (typ.clone(), *space))
|
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||||
.collect(),
|
.collect(),
|
||||||
input_arguments: input_arguments
|
input_arguments: input_arguments
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(_, typ, space)| (typ.clone(), *space))
|
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||||
.collect(),
|
.collect(),
|
||||||
};
|
};
|
||||||
let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
|
let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
|
||||||
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
|
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
|
||||||
func: self.sreg_to_function[&sreg],
|
func: self.sreg_to_function[&sreg],
|
||||||
input_arguments: input_arguments
|
input_arguments: input_arguments
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
|
.map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
|
||||||
.collect(),
|
.collect(),
|
||||||
};
|
};
|
||||||
self.result
|
self.result
|
||||||
.push(Statement::Instruction(ast::Instruction::Call {
|
.push(Statement::Instruction(ast::Instruction::Call {
|
||||||
data,
|
data,
|
||||||
arguments,
|
arguments,
|
||||||
}));
|
}));
|
||||||
Ok(Some(fn_result))
|
Ok(Some(fn_result))
|
||||||
} else {
|
} else {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn map_operand<T: Copy, Err>(
|
pub fn map_operand<T: Copy, Err>(
|
||||||
this: ast::ParsedOperand<T>,
|
this: ast::ParsedOperand<T>,
|
||||||
fn_: &mut impl FnMut(T, Option<u8>) -> Result<Option<T>, Err>,
|
fn_: &mut impl FnMut(T, Option<u8>) -> Result<Option<T>, Err>,
|
||||||
) -> Result<ast::ParsedOperand<T>, Err> {
|
) -> Result<ast::ParsedOperand<T>, Err> {
|
||||||
Ok(match this {
|
Ok(match this {
|
||||||
ast::ParsedOperand::Reg(ident) => {
|
ast::ParsedOperand::Reg(ident) => {
|
||||||
ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident))
|
ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident))
|
||||||
}
|
}
|
||||||
ast::ParsedOperand::RegOffset(ident, offset) => {
|
ast::ParsedOperand::RegOffset(ident, offset) => {
|
||||||
ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset)
|
ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset)
|
||||||
}
|
}
|
||||||
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
|
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
|
||||||
ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? {
|
ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? {
|
||||||
Some(ident) => ast::ParsedOperand::Reg(ident),
|
Some(ident) => ast::ParsedOperand::Reg(ident),
|
||||||
None => ast::ParsedOperand::VecMember(ident, member),
|
None => ast::ParsedOperand::VecMember(ident, member),
|
||||||
},
|
},
|
||||||
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
|
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
|
||||||
idents
|
idents
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
|
.map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
|
||||||
.collect::<Result<Vec<_>, _>>()?,
|
.collect::<Result<Vec<_>, _>>()?,
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,45 +1,45 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
pub(super) fn run<'input>(
|
pub(super) fn run<'input>(
|
||||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
let mut result = Vec::with_capacity(directives.len());
|
let mut result = Vec::with_capacity(directives.len());
|
||||||
for mut directive in directives.into_iter() {
|
for mut directive in directives.into_iter() {
|
||||||
run_directive(&mut result, &mut directive)?;
|
run_directive(&mut result, &mut directive)?;
|
||||||
result.push(directive);
|
result.push(directive);
|
||||||
}
|
}
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_directive<'input>(
|
fn run_directive<'input>(
|
||||||
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
directive: &mut Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
directive: &mut Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
match directive {
|
match directive {
|
||||||
Directive2::Variable(..) => {}
|
Directive2::Variable(..) => {}
|
||||||
Directive2::Method(function2) => run_function(result, function2),
|
Directive2::Method(function2) => run_function(result, function2),
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_function<'input>(
|
fn run_function<'input>(
|
||||||
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
function: &mut Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
function: &mut Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) {
|
) {
|
||||||
function.body = function.body.take().map(|statements| {
|
function.body = function.body.take().map(|statements| {
|
||||||
statements
|
statements
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|statement| match statement {
|
.filter_map(|statement| match statement {
|
||||||
Statement::Variable(var @ ast::Variable {
|
Statement::Variable(var @ ast::Variable {
|
||||||
state_space:
|
state_space:
|
||||||
ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared,
|
ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared,
|
||||||
..
|
..
|
||||||
}) => {
|
}) => {
|
||||||
result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
|
result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
s => Some(s),
|
s => Some(s),
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -1,404 +1,404 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
// This pass:
|
// This pass:
|
||||||
// * Turns all .local, .param and .reg in-body variables into .local variables
|
// * Turns all .local, .param and .reg in-body variables into .local variables
|
||||||
// (if _not_ an input method argument)
|
// (if _not_ an input method argument)
|
||||||
// * Inserts explicit `ld`/`st` for newly converted .reg variables
|
// * Inserts explicit `ld`/`st` for newly converted .reg variables
|
||||||
// * Fixup state space of all existing `ld`/`st` instructions into newly
|
// * Fixup state space of all existing `ld`/`st` instructions into newly
|
||||||
// converted variables
|
// converted variables
|
||||||
// * Turns `.entry` input arguments into param::entry and all related `.param`
|
// * Turns `.entry` input arguments into param::entry and all related `.param`
|
||||||
// loads into `param::entry` loads
|
// loads into `param::entry` loads
|
||||||
// * All `.func` input arguments are turned into `.reg` arguments by another
|
// * All `.func` input arguments are turned into `.reg` arguments by another
|
||||||
// pass, so we do nothing there
|
// pass, so we do nothing there
|
||||||
pub(super) fn run<'a, 'input>(
|
pub(super) fn run<'a, 'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(resolver, directive))
|
.map(|directive| run_directive(resolver, directive))
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_directive<'a, 'input>(
|
fn run_directive<'a, 'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
var @ Directive2::Variable(..) => var,
|
var @ Directive2::Variable(..) => var,
|
||||||
Directive2::Method(method) => {
|
Directive2::Method(method) => {
|
||||||
let visitor = InsertMemSSAVisitor::new(resolver);
|
let visitor = InsertMemSSAVisitor::new(resolver);
|
||||||
Directive2::Method(run_method(visitor, method)?)
|
Directive2::Method(run_method(visitor, method)?)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_method<'a, 'input>(
|
fn run_method<'a, 'input>(
|
||||||
mut visitor: InsertMemSSAVisitor<'a, 'input>,
|
mut visitor: InsertMemSSAVisitor<'a, 'input>,
|
||||||
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
let is_kernel = method.is_kernel;
|
let is_kernel = method.is_kernel;
|
||||||
if is_kernel {
|
if is_kernel {
|
||||||
for arg in method.input_arguments.iter_mut() {
|
for arg in method.input_arguments.iter_mut() {
|
||||||
let old_name = arg.name;
|
let old_name = arg.name;
|
||||||
let old_space = arg.state_space;
|
let old_space = arg.state_space;
|
||||||
let new_space = ast::StateSpace::ParamEntry;
|
let new_space = ast::StateSpace::ParamEntry;
|
||||||
let new_name = visitor
|
let new_name = visitor
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((arg.v_type.clone(), new_space)));
|
.register_unnamed(Some((arg.v_type.clone(), new_space)));
|
||||||
visitor.input_argument(old_name, new_name, old_space)?;
|
visitor.input_argument(old_name, new_name, old_space)?;
|
||||||
arg.name = new_name;
|
arg.name = new_name;
|
||||||
arg.state_space = new_space;
|
arg.state_space = new_space;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
for arg in method.return_arguments.iter_mut() {
|
for arg in method.return_arguments.iter_mut() {
|
||||||
visitor.visit_variable(arg)?;
|
visitor.visit_variable(arg)?;
|
||||||
}
|
}
|
||||||
let return_arguments = &method.return_arguments[..];
|
let return_arguments = &method.return_arguments[..];
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(move |statements| {
|
.map(move |statements| {
|
||||||
let mut result = Vec::with_capacity(statements.len());
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
for statement in statements {
|
for statement in statements {
|
||||||
run_statement(&mut visitor, return_arguments, &mut result, statement)?;
|
run_statement(&mut visitor, return_arguments, &mut result, statement)?;
|
||||||
}
|
}
|
||||||
Ok::<_, TranslateError>(result)
|
Ok::<_, TranslateError>(result)
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 { body, ..method })
|
Ok(Function2 { body, ..method })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statement<'a, 'input>(
|
fn run_statement<'a, 'input>(
|
||||||
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
|
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
|
||||||
return_arguments: &[ast::Variable<SpirvWord>],
|
return_arguments: &[ast::Variable<SpirvWord>],
|
||||||
result: &mut Vec<ExpandedStatement>,
|
result: &mut Vec<ExpandedStatement>,
|
||||||
statement: ExpandedStatement,
|
statement: ExpandedStatement,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
match statement {
|
match statement {
|
||||||
Statement::Instruction(ast::Instruction::Ret { data }) => {
|
Statement::Instruction(ast::Instruction::Ret { data }) => {
|
||||||
let statement = if return_arguments.is_empty() {
|
let statement = if return_arguments.is_empty() {
|
||||||
Statement::Instruction(ast::Instruction::Ret { data })
|
Statement::Instruction(ast::Instruction::Ret { data })
|
||||||
} else {
|
} else {
|
||||||
Statement::RetValue(
|
Statement::RetValue(
|
||||||
data,
|
data,
|
||||||
return_arguments
|
return_arguments
|
||||||
.iter()
|
.iter()
|
||||||
.map(|arg| {
|
.map(|arg| {
|
||||||
if arg.state_space != ast::StateSpace::Local {
|
if arg.state_space != ast::StateSpace::Local {
|
||||||
return Err(error_unreachable());
|
return Err(error_unreachable());
|
||||||
}
|
}
|
||||||
Ok((arg.name, arg.v_type.clone()))
|
Ok((arg.name, arg.v_type.clone()))
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?,
|
.collect::<Result<Vec<_>, _>>()?,
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
let new_statement = statement.visit_map(visitor)?;
|
let new_statement = statement.visit_map(visitor)?;
|
||||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||||
result.push(new_statement);
|
result.push(new_statement);
|
||||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||||
}
|
}
|
||||||
Statement::Variable(mut var) => {
|
Statement::Variable(mut var) => {
|
||||||
visitor.visit_variable(&mut var)?;
|
visitor.visit_variable(&mut var)?;
|
||||||
result.push(Statement::Variable(var));
|
result.push(Statement::Variable(var));
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
|
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
|
||||||
let instruction = visitor.visit_ld(data, arguments)?;
|
let instruction = visitor.visit_ld(data, arguments)?;
|
||||||
let instruction = ast::visit_map(instruction, visitor)?;
|
let instruction = ast::visit_map(instruction, visitor)?;
|
||||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||||
result.push(Statement::Instruction(instruction));
|
result.push(Statement::Instruction(instruction));
|
||||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::St { data, arguments }) => {
|
Statement::Instruction(ast::Instruction::St { data, arguments }) => {
|
||||||
let instruction = visitor.visit_st(data, arguments)?;
|
let instruction = visitor.visit_st(data, arguments)?;
|
||||||
let instruction = ast::visit_map(instruction, visitor)?;
|
let instruction = ast::visit_map(instruction, visitor)?;
|
||||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||||
result.push(Statement::Instruction(instruction));
|
result.push(Statement::Instruction(instruction));
|
||||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||||
}
|
}
|
||||||
Statement::PtrAccess(ptr_access) => {
|
Statement::PtrAccess(ptr_access) => {
|
||||||
let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
|
let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
|
||||||
let statement = statement.visit_map(visitor)?;
|
let statement = statement.visit_map(visitor)?;
|
||||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||||
result.push(statement);
|
result.push(statement);
|
||||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||||
}
|
}
|
||||||
s => {
|
s => {
|
||||||
let new_statement = s.visit_map(visitor)?;
|
let new_statement = s.visit_map(visitor)?;
|
||||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||||
result.push(new_statement);
|
result.push(new_statement);
|
||||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
struct InsertMemSSAVisitor<'a, 'input> {
|
struct InsertMemSSAVisitor<'a, 'input> {
|
||||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||||
variables: FxHashMap<SpirvWord, RemapAction>,
|
variables: FxHashMap<SpirvWord, RemapAction>,
|
||||||
pre: Vec<ast::Instruction<SpirvWord>>,
|
pre: Vec<ast::Instruction<SpirvWord>>,
|
||||||
post: Vec<ast::Instruction<SpirvWord>>,
|
post: Vec<ast::Instruction<SpirvWord>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||||
fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
|
fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
resolver,
|
resolver,
|
||||||
variables: FxHashMap::default(),
|
variables: FxHashMap::default(),
|
||||||
pre: Vec::new(),
|
pre: Vec::new(),
|
||||||
post: Vec::new(),
|
post: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn input_argument(
|
fn input_argument(
|
||||||
&mut self,
|
&mut self,
|
||||||
old_name: SpirvWord,
|
old_name: SpirvWord,
|
||||||
new_name: SpirvWord,
|
new_name: SpirvWord,
|
||||||
old_space: ast::StateSpace,
|
old_space: ast::StateSpace,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
if old_space != ast::StateSpace::Param {
|
if old_space != ast::StateSpace::Param {
|
||||||
return Err(error_unreachable());
|
return Err(error_unreachable());
|
||||||
}
|
}
|
||||||
self.variables.insert(
|
self.variables.insert(
|
||||||
old_name,
|
old_name,
|
||||||
RemapAction::LDStSpaceChange {
|
RemapAction::LDStSpaceChange {
|
||||||
name: new_name,
|
name: new_name,
|
||||||
old_space,
|
old_space,
|
||||||
new_space: ast::StateSpace::ParamEntry,
|
new_space: ast::StateSpace::ParamEntry,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn variable(
|
fn variable(
|
||||||
&mut self,
|
&mut self,
|
||||||
type_: &ast::Type,
|
type_: &ast::Type,
|
||||||
old_name: SpirvWord,
|
old_name: SpirvWord,
|
||||||
new_name: SpirvWord,
|
new_name: SpirvWord,
|
||||||
old_space: ast::StateSpace,
|
old_space: ast::StateSpace,
|
||||||
) -> Result<bool, TranslateError> {
|
) -> Result<bool, TranslateError> {
|
||||||
Ok(match old_space {
|
Ok(match old_space {
|
||||||
ast::StateSpace::Reg => {
|
ast::StateSpace::Reg => {
|
||||||
self.variables.insert(
|
self.variables.insert(
|
||||||
old_name,
|
old_name,
|
||||||
RemapAction::PreLdPostSt {
|
RemapAction::PreLdPostSt {
|
||||||
name: new_name,
|
name: new_name,
|
||||||
type_: type_.clone(),
|
type_: type_.clone(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
ast::StateSpace::Param => {
|
ast::StateSpace::Param => {
|
||||||
self.variables.insert(
|
self.variables.insert(
|
||||||
old_name,
|
old_name,
|
||||||
RemapAction::LDStSpaceChange {
|
RemapAction::LDStSpaceChange {
|
||||||
old_space,
|
old_space,
|
||||||
new_space: ast::StateSpace::Local,
|
new_space: ast::StateSpace::Local,
|
||||||
name: new_name,
|
name: new_name,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
// Good as-is
|
// Good as-is
|
||||||
ast::StateSpace::Local
|
ast::StateSpace::Local
|
||||||
| ast::StateSpace::Generic
|
| ast::StateSpace::Generic
|
||||||
| ast::StateSpace::SharedCluster
|
| ast::StateSpace::SharedCluster
|
||||||
| ast::StateSpace::Global
|
| ast::StateSpace::Global
|
||||||
| ast::StateSpace::Const
|
| ast::StateSpace::Const
|
||||||
| ast::StateSpace::SharedCta
|
| ast::StateSpace::SharedCta
|
||||||
| ast::StateSpace::Shared
|
| ast::StateSpace::Shared
|
||||||
| ast::StateSpace::ParamEntry
|
| ast::StateSpace::ParamEntry
|
||||||
| ast::StateSpace::ParamFunc => return Err(error_unreachable()),
|
| ast::StateSpace::ParamFunc => return Err(error_unreachable()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn visit_st(
|
fn visit_st(
|
||||||
&self,
|
&self,
|
||||||
mut data: ast::StData,
|
mut data: ast::StData,
|
||||||
mut arguments: ast::StArgs<SpirvWord>,
|
mut arguments: ast::StArgs<SpirvWord>,
|
||||||
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
||||||
if let Some(remap) = self.variables.get(&arguments.src1) {
|
if let Some(remap) = self.variables.get(&arguments.src1) {
|
||||||
match remap {
|
match remap {
|
||||||
RemapAction::PreLdPostSt { .. } => {}
|
RemapAction::PreLdPostSt { .. } => {}
|
||||||
RemapAction::LDStSpaceChange {
|
RemapAction::LDStSpaceChange {
|
||||||
old_space,
|
old_space,
|
||||||
new_space,
|
new_space,
|
||||||
name,
|
name,
|
||||||
} => {
|
} => {
|
||||||
if data.state_space != *old_space {
|
if data.state_space != *old_space {
|
||||||
return Err(error_mismatched_type());
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
data.state_space = *new_space;
|
data.state_space = *new_space;
|
||||||
arguments.src1 = *name;
|
arguments.src1 = *name;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(ast::Instruction::St { data, arguments })
|
Ok(ast::Instruction::St { data, arguments })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn visit_ld(
|
fn visit_ld(
|
||||||
&self,
|
&self,
|
||||||
mut data: ast::LdDetails,
|
mut data: ast::LdDetails,
|
||||||
mut arguments: ast::LdArgs<SpirvWord>,
|
mut arguments: ast::LdArgs<SpirvWord>,
|
||||||
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
||||||
if let Some(remap) = self.variables.get(&arguments.src) {
|
if let Some(remap) = self.variables.get(&arguments.src) {
|
||||||
match remap {
|
match remap {
|
||||||
RemapAction::PreLdPostSt { .. } => {}
|
RemapAction::PreLdPostSt { .. } => {}
|
||||||
RemapAction::LDStSpaceChange {
|
RemapAction::LDStSpaceChange {
|
||||||
old_space,
|
old_space,
|
||||||
new_space,
|
new_space,
|
||||||
name,
|
name,
|
||||||
} => {
|
} => {
|
||||||
if data.state_space != *old_space {
|
if data.state_space != *old_space {
|
||||||
return Err(error_mismatched_type());
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
data.state_space = *new_space;
|
data.state_space = *new_space;
|
||||||
arguments.src = *name;
|
arguments.src = *name;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(ast::Instruction::Ld { data, arguments })
|
Ok(ast::Instruction::Ld { data, arguments })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn visit_ptr_access(
|
fn visit_ptr_access(
|
||||||
&mut self,
|
&mut self,
|
||||||
ptr_access: PtrAccess<SpirvWord>,
|
ptr_access: PtrAccess<SpirvWord>,
|
||||||
) -> Result<PtrAccess<SpirvWord>, TranslateError> {
|
) -> Result<PtrAccess<SpirvWord>, TranslateError> {
|
||||||
let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) {
|
let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) {
|
||||||
Some(RemapAction::LDStSpaceChange {
|
Some(RemapAction::LDStSpaceChange {
|
||||||
old_space,
|
old_space,
|
||||||
new_space,
|
new_space,
|
||||||
name,
|
name,
|
||||||
}) => (*old_space, *new_space, *name),
|
}) => (*old_space, *new_space, *name),
|
||||||
Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access),
|
Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access),
|
||||||
};
|
};
|
||||||
if ptr_access.state_space != old_space {
|
if ptr_access.state_space != old_space {
|
||||||
return Err(error_mismatched_type());
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
// Propagate space changes in dst
|
// Propagate space changes in dst
|
||||||
let new_dst = self
|
let new_dst = self
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((ptr_access.underlying_type.clone(), new_space)));
|
.register_unnamed(Some((ptr_access.underlying_type.clone(), new_space)));
|
||||||
self.variables.insert(
|
self.variables.insert(
|
||||||
ptr_access.dst,
|
ptr_access.dst,
|
||||||
RemapAction::LDStSpaceChange {
|
RemapAction::LDStSpaceChange {
|
||||||
old_space,
|
old_space,
|
||||||
new_space,
|
new_space,
|
||||||
name: new_dst,
|
name: new_dst,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
Ok(PtrAccess {
|
Ok(PtrAccess {
|
||||||
ptr_src: name,
|
ptr_src: name,
|
||||||
dst: new_dst,
|
dst: new_dst,
|
||||||
state_space: new_space,
|
state_space: new_space,
|
||||||
..ptr_access
|
..ptr_access
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
|
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
|
||||||
let old_space = match var.state_space {
|
let old_space = match var.state_space {
|
||||||
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
|
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
|
||||||
// Do nothing
|
// Do nothing
|
||||||
ptx_parser::StateSpace::Local => return Ok(()),
|
ptx_parser::StateSpace::Local => return Ok(()),
|
||||||
// Handled by another pass
|
// Handled by another pass
|
||||||
ptx_parser::StateSpace::Generic
|
ptx_parser::StateSpace::Generic
|
||||||
| ptx_parser::StateSpace::SharedCluster
|
| ptx_parser::StateSpace::SharedCluster
|
||||||
| ptx_parser::StateSpace::ParamEntry
|
| ptx_parser::StateSpace::ParamEntry
|
||||||
| ptx_parser::StateSpace::Global
|
| ptx_parser::StateSpace::Global
|
||||||
| ptx_parser::StateSpace::SharedCta
|
| ptx_parser::StateSpace::SharedCta
|
||||||
| ptx_parser::StateSpace::Const
|
| ptx_parser::StateSpace::Const
|
||||||
| ptx_parser::StateSpace::Shared
|
| ptx_parser::StateSpace::Shared
|
||||||
| ptx_parser::StateSpace::ParamFunc => return Ok(()),
|
| ptx_parser::StateSpace::ParamFunc => return Ok(()),
|
||||||
};
|
};
|
||||||
let old_name = var.name;
|
let old_name = var.name;
|
||||||
let new_space = ast::StateSpace::Local;
|
let new_space = ast::StateSpace::Local;
|
||||||
let new_name = self
|
let new_name = self
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((var.v_type.clone(), new_space)));
|
.register_unnamed(Some((var.v_type.clone(), new_space)));
|
||||||
self.variable(&var.v_type, old_name, new_name, old_space)?;
|
self.variable(&var.v_type, old_name, new_name, old_space)?;
|
||||||
var.name = new_name;
|
var.name = new_name;
|
||||||
var.state_space = new_space;
|
var.state_space = new_space;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
|
impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
|
||||||
for InsertMemSSAVisitor<'a, 'input>
|
for InsertMemSSAVisitor<'a, 'input>
|
||||||
{
|
{
|
||||||
fn visit(
|
fn visit(
|
||||||
&mut self,
|
&mut self,
|
||||||
ident: SpirvWord,
|
ident: SpirvWord,
|
||||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
_relaxed_type_check: bool,
|
_relaxed_type_check: bool,
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
if let Some(remap) = self.variables.get(&ident) {
|
if let Some(remap) = self.variables.get(&ident) {
|
||||||
match remap {
|
match remap {
|
||||||
RemapAction::PreLdPostSt { name, type_ } => {
|
RemapAction::PreLdPostSt { name, type_ } => {
|
||||||
if is_dst {
|
if is_dst {
|
||||||
let temp = self
|
let temp = self
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
|
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
|
||||||
self.post.push(ast::Instruction::St {
|
self.post.push(ast::Instruction::St {
|
||||||
data: ast::StData {
|
data: ast::StData {
|
||||||
state_space: ast::StateSpace::Local,
|
state_space: ast::StateSpace::Local,
|
||||||
qualifier: ast::LdStQualifier::Weak,
|
qualifier: ast::LdStQualifier::Weak,
|
||||||
caching: ast::StCacheOperator::Writethrough,
|
caching: ast::StCacheOperator::Writethrough,
|
||||||
typ: type_.clone(),
|
typ: type_.clone(),
|
||||||
},
|
},
|
||||||
arguments: ast::StArgs {
|
arguments: ast::StArgs {
|
||||||
src1: *name,
|
src1: *name,
|
||||||
src2: temp,
|
src2: temp,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
Ok(temp)
|
Ok(temp)
|
||||||
} else {
|
} else {
|
||||||
let temp = self
|
let temp = self
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
|
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
|
||||||
self.pre.push(ast::Instruction::Ld {
|
self.pre.push(ast::Instruction::Ld {
|
||||||
data: ast::LdDetails {
|
data: ast::LdDetails {
|
||||||
state_space: ast::StateSpace::Local,
|
state_space: ast::StateSpace::Local,
|
||||||
qualifier: ast::LdStQualifier::Weak,
|
qualifier: ast::LdStQualifier::Weak,
|
||||||
caching: ast::LdCacheOperator::Cached,
|
caching: ast::LdCacheOperator::Cached,
|
||||||
typ: type_.clone(),
|
typ: type_.clone(),
|
||||||
non_coherent: false,
|
non_coherent: false,
|
||||||
},
|
},
|
||||||
arguments: ast::LdArgs {
|
arguments: ast::LdArgs {
|
||||||
dst: temp,
|
dst: temp,
|
||||||
src: *name,
|
src: *name,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
Ok(temp)
|
Ok(temp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
RemapAction::LDStSpaceChange { .. } => {
|
RemapAction::LDStSpaceChange { .. } => {
|
||||||
return Err(error_mismatched_type());
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Ok(ident)
|
Ok(ident)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn visit_ident(
|
fn visit_ident(
|
||||||
&mut self,
|
&mut self,
|
||||||
args: SpirvWord,
|
args: SpirvWord,
|
||||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
relaxed_type_check: bool,
|
relaxed_type_check: bool,
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
self.visit(args, type_space, is_dst, relaxed_type_check)
|
self.visit(args, type_space, is_dst, relaxed_type_check)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// How references to a converted variable are rewritten.
#[derive(Clone)]
enum RemapAction {
    /// Recorded for former `.reg` variables: each use is replaced by a fresh
    /// temporary, with an explicit load from the local before the instruction
    /// (source operand) or a store to it after (destination operand).
    PreLdPostSt {
        /// New name of the variable.
        name: SpirvWord,
        /// Type used for the temporary and for the inserted `ld`/`st`.
        type_: ast::Type,
    },
    /// Recorded for former `.param` variables and kernel input arguments:
    /// `ld`/`st`/pointer accesses that address the variable are retargeted
    /// to `name` in `new_space`.
    LDStSpaceChange {
        /// Space the accessing instruction must currently name.
        old_space: ast::StateSpace,
        /// Space written into the rewritten instruction.
        new_space: ast::StateSpace,
        /// New name of the variable.
        name: SpirvWord,
    },
}
|
||||||
|
@ -1,401 +1,401 @@
|
|||||||
use std::mem;
|
use std::mem;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use ptx_parser as ast;
|
use ptx_parser as ast;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
There are several kinds of implicit conversions in PTX:
|
There are several kinds of implicit conversions in PTX:
|
||||||
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
|
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
|
||||||
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
|
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
|
||||||
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
|
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
|
||||||
semantics are to first zext/chop/bitcast `y` as needed and then do
|
semantics are to first zext/chop/bitcast `y` as needed and then do
|
||||||
documented special ld/st/cvt conversion rules for destination operands
|
documented special ld/st/cvt conversion rules for destination operands
|
||||||
- st.param [x] y (used as function return arguments) same rule as above applies
|
- st.param [x] y (used as function return arguments) same rule as above applies
|
||||||
- generic/global ld: for instruction `ld x, [y]`, y must be of type
|
- generic/global ld: for instruction `ld x, [y]`, y must be of type
|
||||||
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
|
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
|
||||||
documented special ld/st/cvt conversion rules are applied to dst
|
documented special ld/st/cvt conversion rules are applied to dst
|
||||||
- generic/global st: for instruction `st [x], y`, x must be of type
|
- generic/global st: for instruction `st [x], y`, x must be of type
|
||||||
b64/u64/s64, which is bitcast to a pointer
|
b64/u64/s64, which is bitcast to a pointer
|
||||||
*/
|
*/
|
||||||
pub(super) fn run<'input>(
|
pub(super) fn run<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(resolver, directive))
|
.map(|directive| run_directive(resolver, directive))
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_directive<'a, 'input>(
|
fn run_directive<'a, 'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
var @ Directive2::Variable(..) => var,
|
var @ Directive2::Variable(..) => var,
|
||||||
Directive2::Method(mut method) => {
|
Directive2::Method(mut method) => {
|
||||||
method.body = method
|
method.body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| run_statements(resolver, statements))
|
.map(|statements| run_statements(resolver, statements))
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Directive2::Method(method)
|
Directive2::Method(method)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statements<'input>(
|
fn run_statements<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
func: Vec<ExpandedStatement>,
|
func: Vec<ExpandedStatement>,
|
||||||
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||||
let mut result = Vec::with_capacity(func.len());
|
let mut result = Vec::with_capacity(func.len());
|
||||||
for s in func.into_iter() {
|
for s in func.into_iter() {
|
||||||
insert_implicit_conversions_impl(resolver, &mut result, s)?;
|
insert_implicit_conversions_impl(resolver, &mut result, s)?;
|
||||||
}
|
}
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn insert_implicit_conversions_impl<'input>(
|
fn insert_implicit_conversions_impl<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
func: &mut Vec<ExpandedStatement>,
|
func: &mut Vec<ExpandedStatement>,
|
||||||
stmt: ExpandedStatement,
|
stmt: ExpandedStatement,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
let mut post_conv = Vec::new();
|
let mut post_conv = Vec::new();
|
||||||
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
|
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
|
||||||
&mut |operand,
|
&mut |operand,
|
||||||
type_state: Option<(&ast::Type, ast::StateSpace)>,
|
type_state: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
is_dst,
|
is_dst,
|
||||||
relaxed_type_check| {
|
relaxed_type_check| {
|
||||||
let (instr_type, instruction_space) = match type_state {
|
let (instr_type, instruction_space) = match type_state {
|
||||||
None => return Ok(operand),
|
None => return Ok(operand),
|
||||||
Some(t) => t,
|
Some(t) => t,
|
||||||
};
|
};
|
||||||
let (operand_type, operand_space) = resolver.get_typed(operand)?;
|
let (operand_type, operand_space) = resolver.get_typed(operand)?;
|
||||||
let conversion_fn = if relaxed_type_check {
|
let conversion_fn = if relaxed_type_check {
|
||||||
if is_dst {
|
if is_dst {
|
||||||
should_convert_relaxed_dst_wrapper
|
should_convert_relaxed_dst_wrapper
|
||||||
} else {
|
} else {
|
||||||
should_convert_relaxed_src_wrapper
|
should_convert_relaxed_src_wrapper
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
default_implicit_conversion
|
default_implicit_conversion
|
||||||
};
|
};
|
||||||
match conversion_fn(
|
match conversion_fn(
|
||||||
(*operand_space, &operand_type),
|
(*operand_space, &operand_type),
|
||||||
(instruction_space, instr_type),
|
(instruction_space, instr_type),
|
||||||
)? {
|
)? {
|
||||||
Some(conv_kind) => {
|
Some(conv_kind) => {
|
||||||
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
|
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
|
||||||
let mut from_type = instr_type.clone();
|
let mut from_type = instr_type.clone();
|
||||||
let mut from_space = instruction_space;
|
let mut from_space = instruction_space;
|
||||||
let mut to_type = operand_type.clone();
|
let mut to_type = operand_type.clone();
|
||||||
let mut to_space = *operand_space;
|
let mut to_space = *operand_space;
|
||||||
let mut src =
|
let mut src =
|
||||||
resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
|
resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
|
||||||
let mut dst = operand;
|
let mut dst = operand;
|
||||||
let result = Ok::<_, TranslateError>(src);
|
let result = Ok::<_, TranslateError>(src);
|
||||||
if !is_dst {
|
if !is_dst {
|
||||||
mem::swap(&mut src, &mut dst);
|
mem::swap(&mut src, &mut dst);
|
||||||
mem::swap(&mut from_type, &mut to_type);
|
mem::swap(&mut from_type, &mut to_type);
|
||||||
mem::swap(&mut from_space, &mut to_space);
|
mem::swap(&mut from_space, &mut to_space);
|
||||||
}
|
}
|
||||||
conv_output.push(Statement::Conversion(ImplicitConversion {
|
conv_output.push(Statement::Conversion(ImplicitConversion {
|
||||||
src,
|
src,
|
||||||
dst,
|
dst,
|
||||||
from_type,
|
from_type,
|
||||||
from_space,
|
from_space,
|
||||||
to_type,
|
to_type,
|
||||||
to_space,
|
to_space,
|
||||||
kind: conv_kind,
|
kind: conv_kind,
|
||||||
}));
|
}));
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
None => Ok(operand),
|
None => Ok(operand),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
func.push(statement);
|
func.push(statement);
|
||||||
func.append(&mut post_conv);
|
func.append(&mut post_conv);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn default_implicit_conversion(
|
pub(crate) fn default_implicit_conversion(
|
||||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
if instruction_space == ast::StateSpace::Reg {
|
if instruction_space == ast::StateSpace::Reg {
|
||||||
if operand_space == ast::StateSpace::Reg {
|
if operand_space == ast::StateSpace::Reg {
|
||||||
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
|
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
|
||||||
(operand_type, instruction_type)
|
(operand_type, instruction_type)
|
||||||
{
|
{
|
||||||
if scalar.kind() == ast::ScalarKind::Bit
|
if scalar.kind() == ast::ScalarKind::Bit
|
||||||
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
|
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
|
||||||
{
|
{
|
||||||
return Ok(Some(ConversionKind::Default));
|
return Ok(Some(ConversionKind::Default));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if is_addressable(operand_space) {
|
} else if is_addressable(operand_space) {
|
||||||
return Ok(Some(ConversionKind::AddressOf));
|
return Ok(Some(ConversionKind::AddressOf));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if instruction_space != operand_space {
|
if instruction_space != operand_space {
|
||||||
default_implicit_conversion_space((operand_space, operand_type), instruction_space)
|
default_implicit_conversion_space((operand_space, operand_type), instruction_space)
|
||||||
} else if instruction_type != operand_type {
|
} else if instruction_type != operand_type {
|
||||||
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
|
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
|
||||||
} else {
|
} else {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_addressable(this: ast::StateSpace) -> bool {
|
fn is_addressable(this: ast::StateSpace) -> bool {
|
||||||
match this {
|
match this {
|
||||||
ast::StateSpace::Const
|
ast::StateSpace::Const
|
||||||
| ast::StateSpace::Generic
|
| ast::StateSpace::Generic
|
||||||
| ast::StateSpace::Global
|
| ast::StateSpace::Global
|
||||||
| ast::StateSpace::Local
|
| ast::StateSpace::Local
|
||||||
| ast::StateSpace::Shared => true,
|
| ast::StateSpace::Shared => true,
|
||||||
ast::StateSpace::Param | ast::StateSpace::Reg => false,
|
ast::StateSpace::Param | ast::StateSpace::Reg => false,
|
||||||
ast::StateSpace::SharedCluster
|
ast::StateSpace::SharedCluster
|
||||||
| ast::StateSpace::SharedCta
|
| ast::StateSpace::SharedCta
|
||||||
| ast::StateSpace::ParamEntry
|
| ast::StateSpace::ParamEntry
|
||||||
| ast::StateSpace::ParamFunc => todo!(),
|
| ast::StateSpace::ParamFunc => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Space is different
|
// Space is different
|
||||||
fn default_implicit_conversion_space(
|
fn default_implicit_conversion_space(
|
||||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
instruction_space: ast::StateSpace,
|
instruction_space: ast::StateSpace,
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|
||||||
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
|
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
|
||||||
{
|
{
|
||||||
Ok(Some(ConversionKind::PtrToPtr))
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
} else if operand_space == ast::StateSpace::Reg {
|
} else if operand_space == ast::StateSpace::Reg {
|
||||||
match operand_type {
|
match operand_type {
|
||||||
// TODO: 32 bit
|
// TODO: 32 bit
|
||||||
ast::Type::Scalar(ast::ScalarType::B64)
|
ast::Type::Scalar(ast::ScalarType::B64)
|
||||||
| ast::Type::Scalar(ast::ScalarType::U64)
|
| ast::Type::Scalar(ast::ScalarType::U64)
|
||||||
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
|
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
|
||||||
ast::StateSpace::Global
|
ast::StateSpace::Global
|
||||||
| ast::StateSpace::Generic
|
| ast::StateSpace::Generic
|
||||||
| ast::StateSpace::Const
|
| ast::StateSpace::Const
|
||||||
| ast::StateSpace::Local
|
| ast::StateSpace::Local
|
||||||
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
|
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
|
||||||
_ => Err(error_mismatched_type()),
|
_ => Err(error_mismatched_type()),
|
||||||
},
|
},
|
||||||
ast::Type::Scalar(ast::ScalarType::B32)
|
ast::Type::Scalar(ast::ScalarType::B32)
|
||||||
| ast::Type::Scalar(ast::ScalarType::U32)
|
| ast::Type::Scalar(ast::ScalarType::U32)
|
||||||
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
|
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
|
||||||
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
|
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
|
||||||
Ok(Some(ConversionKind::BitToPtr))
|
Ok(Some(ConversionKind::BitToPtr))
|
||||||
}
|
}
|
||||||
_ => Err(error_mismatched_type()),
|
_ => Err(error_mismatched_type()),
|
||||||
},
|
},
|
||||||
_ => Err(error_mismatched_type()),
|
_ => Err(error_mismatched_type()),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Err(error_mismatched_type())
|
Err(error_mismatched_type())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Space is same, but type is different
|
// Space is same, but type is different
|
||||||
fn default_implicit_conversion_type(
|
fn default_implicit_conversion_type(
|
||||||
space: ast::StateSpace,
|
space: ast::StateSpace,
|
||||||
operand_type: &ast::Type,
|
operand_type: &ast::Type,
|
||||||
instruction_type: &ast::Type,
|
instruction_type: &ast::Type,
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
if space == ast::StateSpace::Reg {
|
if space == ast::StateSpace::Reg {
|
||||||
if should_bitcast(instruction_type, operand_type) {
|
if should_bitcast(instruction_type, operand_type) {
|
||||||
Ok(Some(ConversionKind::Default))
|
Ok(Some(ConversionKind::Default))
|
||||||
} else {
|
} else {
|
||||||
Err(TranslateError::MismatchedType)
|
Err(TranslateError::MismatchedType)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Ok(Some(ConversionKind::PtrToPtr))
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn coerces_to_generic(this: ast::StateSpace) -> bool {
|
fn coerces_to_generic(this: ast::StateSpace) -> bool {
|
||||||
match this {
|
match this {
|
||||||
ast::StateSpace::Global
|
ast::StateSpace::Global
|
||||||
| ast::StateSpace::Const
|
| ast::StateSpace::Const
|
||||||
| ast::StateSpace::Local
|
| ast::StateSpace::Local
|
||||||
| ptx_parser::StateSpace::SharedCta
|
| ptx_parser::StateSpace::SharedCta
|
||||||
| ast::StateSpace::SharedCluster
|
| ast::StateSpace::SharedCluster
|
||||||
| ast::StateSpace::Shared => true,
|
| ast::StateSpace::Shared => true,
|
||||||
ast::StateSpace::Reg
|
ast::StateSpace::Reg
|
||||||
| ast::StateSpace::Param
|
| ast::StateSpace::Param
|
||||||
| ast::StateSpace::ParamEntry
|
| ast::StateSpace::ParamEntry
|
||||||
| ast::StateSpace::ParamFunc
|
| ast::StateSpace::ParamFunc
|
||||||
| ast::StateSpace::Generic => false,
|
| ast::StateSpace::Generic => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
|
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
|
||||||
match (instr, operand) {
|
match (instr, operand) {
|
||||||
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
||||||
if inst.size_of() != operand.size_of() {
|
if inst.size_of() != operand.size_of() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
match inst.kind() {
|
match inst.kind() {
|
||||||
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
|
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
|
||||||
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
|
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
|
||||||
ast::ScalarKind::Signed => {
|
ast::ScalarKind::Signed => {
|
||||||
operand.kind() == ast::ScalarKind::Bit
|
operand.kind() == ast::ScalarKind::Bit
|
||||||
|| operand.kind() == ast::ScalarKind::Unsigned
|
|| operand.kind() == ast::ScalarKind::Unsigned
|
||||||
}
|
}
|
||||||
ast::ScalarKind::Unsigned => {
|
ast::ScalarKind::Unsigned => {
|
||||||
operand.kind() == ast::ScalarKind::Bit
|
operand.kind() == ast::ScalarKind::Bit
|
||||||
|| operand.kind() == ast::ScalarKind::Signed
|
|| operand.kind() == ast::ScalarKind::Signed
|
||||||
}
|
}
|
||||||
ast::ScalarKind::Pred => false,
|
ast::ScalarKind::Pred => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
|
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
|
||||||
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
|
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
|
||||||
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
|
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
|
||||||
}
|
}
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn should_convert_relaxed_dst_wrapper(
|
pub(crate) fn should_convert_relaxed_dst_wrapper(
|
||||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
if operand_space != instruction_space {
|
if operand_space != instruction_space {
|
||||||
return Err(TranslateError::MismatchedType);
|
return Err(TranslateError::MismatchedType);
|
||||||
}
|
}
|
||||||
if operand_type == instruction_type {
|
if operand_type == instruction_type {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
match should_convert_relaxed_dst(operand_type, instruction_type) {
|
match should_convert_relaxed_dst(operand_type, instruction_type) {
|
||||||
conv @ Some(_) => Ok(conv),
|
conv @ Some(_) => Ok(conv),
|
||||||
None => Err(TranslateError::MismatchedType),
|
None => Err(TranslateError::MismatchedType),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
|
||||||
fn should_convert_relaxed_dst(
|
fn should_convert_relaxed_dst(
|
||||||
dst_type: &ast::Type,
|
dst_type: &ast::Type,
|
||||||
instr_type: &ast::Type,
|
instr_type: &ast::Type,
|
||||||
) -> Option<ConversionKind> {
|
) -> Option<ConversionKind> {
|
||||||
if dst_type == instr_type {
|
if dst_type == instr_type {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
match (dst_type, instr_type) {
|
match (dst_type, instr_type) {
|
||||||
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||||
ast::ScalarKind::Bit => {
|
ast::ScalarKind::Bit => {
|
||||||
if instr_type.size_of() <= dst_type.size_of() {
|
if instr_type.size_of() <= dst_type.size_of() {
|
||||||
Some(ConversionKind::Default)
|
Some(ConversionKind::Default)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::ScalarKind::Signed => {
|
ast::ScalarKind::Signed => {
|
||||||
if dst_type.kind() != ast::ScalarKind::Float {
|
if dst_type.kind() != ast::ScalarKind::Float {
|
||||||
if instr_type.size_of() == dst_type.size_of() {
|
if instr_type.size_of() == dst_type.size_of() {
|
||||||
Some(ConversionKind::Default)
|
Some(ConversionKind::Default)
|
||||||
} else if instr_type.size_of() < dst_type.size_of() {
|
} else if instr_type.size_of() < dst_type.size_of() {
|
||||||
Some(ConversionKind::SignExtend)
|
Some(ConversionKind::SignExtend)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::ScalarKind::Unsigned => {
|
ast::ScalarKind::Unsigned => {
|
||||||
if instr_type.size_of() <= dst_type.size_of()
|
if instr_type.size_of() <= dst_type.size_of()
|
||||||
&& dst_type.kind() != ast::ScalarKind::Float
|
&& dst_type.kind() != ast::ScalarKind::Float
|
||||||
{
|
{
|
||||||
Some(ConversionKind::Default)
|
Some(ConversionKind::Default)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::ScalarKind::Float => {
|
ast::ScalarKind::Float => {
|
||||||
if instr_type.size_of() <= dst_type.size_of()
|
if instr_type.size_of() <= dst_type.size_of()
|
||||||
&& dst_type.kind() == ast::ScalarKind::Bit
|
&& dst_type.kind() == ast::ScalarKind::Bit
|
||||||
{
|
{
|
||||||
Some(ConversionKind::Default)
|
Some(ConversionKind::Default)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::ScalarKind::Pred => None,
|
ast::ScalarKind::Pred => None,
|
||||||
},
|
},
|
||||||
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||||
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||||
should_convert_relaxed_dst(
|
should_convert_relaxed_dst(
|
||||||
&ast::Type::Scalar(*dst_type),
|
&ast::Type::Scalar(*dst_type),
|
||||||
&ast::Type::Scalar(*instr_type),
|
&ast::Type::Scalar(*instr_type),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn should_convert_relaxed_src_wrapper(
|
pub(crate) fn should_convert_relaxed_src_wrapper(
|
||||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
if operand_space != instruction_space {
|
if operand_space != instruction_space {
|
||||||
return Err(error_mismatched_type());
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
if operand_type == instruction_type {
|
if operand_type == instruction_type {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
match should_convert_relaxed_src(operand_type, instruction_type) {
|
match should_convert_relaxed_src(operand_type, instruction_type) {
|
||||||
conv @ Some(_) => Ok(conv),
|
conv @ Some(_) => Ok(conv),
|
||||||
None => Err(error_mismatched_type()),
|
None => Err(error_mismatched_type()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
|
||||||
fn should_convert_relaxed_src(
|
fn should_convert_relaxed_src(
|
||||||
src_type: &ast::Type,
|
src_type: &ast::Type,
|
||||||
instr_type: &ast::Type,
|
instr_type: &ast::Type,
|
||||||
) -> Option<ConversionKind> {
|
) -> Option<ConversionKind> {
|
||||||
if src_type == instr_type {
|
if src_type == instr_type {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
match (src_type, instr_type) {
|
match (src_type, instr_type) {
|
||||||
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||||
ast::ScalarKind::Bit => {
|
ast::ScalarKind::Bit => {
|
||||||
if instr_type.size_of() <= src_type.size_of() {
|
if instr_type.size_of() <= src_type.size_of() {
|
||||||
Some(ConversionKind::Default)
|
Some(ConversionKind::Default)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
|
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
|
||||||
if instr_type.size_of() <= src_type.size_of()
|
if instr_type.size_of() <= src_type.size_of()
|
||||||
&& src_type.kind() != ast::ScalarKind::Float
|
&& src_type.kind() != ast::ScalarKind::Float
|
||||||
{
|
{
|
||||||
Some(ConversionKind::Default)
|
Some(ConversionKind::Default)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::ScalarKind::Float => {
|
ast::ScalarKind::Float => {
|
||||||
if instr_type.size_of() <= src_type.size_of()
|
if instr_type.size_of() <= src_type.size_of()
|
||||||
&& src_type.kind() == ast::ScalarKind::Bit
|
&& src_type.kind() == ast::ScalarKind::Bit
|
||||||
{
|
{
|
||||||
Some(ConversionKind::Default)
|
Some(ConversionKind::Default)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::ScalarKind::Pred => None,
|
ast::ScalarKind::Pred => None,
|
||||||
},
|
},
|
||||||
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||||
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||||
should_convert_relaxed_src(
|
should_convert_relaxed_src(
|
||||||
&ast::Type::Scalar(*dst_type),
|
&ast::Type::Scalar(*dst_type),
|
||||||
&ast::Type::Scalar(*instr_type),
|
&ast::Type::Scalar(*instr_type),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,194 +1,194 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use ptx_parser as ast;
|
use ptx_parser as ast;
|
||||||
|
|
||||||
pub(crate) fn run<'input, 'b>(
|
pub(crate) fn run<'input, 'b>(
|
||||||
resolver: &mut ScopedResolver<'input, 'b>,
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
||||||
) -> Result<Vec<NormalizedDirective2>, TranslateError> {
|
) -> Result<Vec<NormalizedDirective2>, TranslateError> {
|
||||||
resolver.start_scope();
|
resolver.start_scope();
|
||||||
let result = directives
|
let result = directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(resolver, directive))
|
.map(|directive| run_directive(resolver, directive))
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
resolver.end_scope();
|
resolver.end_scope();
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_directive<'input, 'b>(
|
fn run_directive<'input, 'b>(
|
||||||
resolver: &mut ScopedResolver<'input, 'b>,
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
||||||
) -> Result<NormalizedDirective2, TranslateError> {
|
) -> Result<NormalizedDirective2, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
ast::Directive::Variable(linking, var) => {
|
ast::Directive::Variable(linking, var) => {
|
||||||
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
|
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
|
||||||
}
|
}
|
||||||
ast::Directive::Method(linking, directive) => {
|
ast::Directive::Method(linking, directive) => {
|
||||||
NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
|
NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_method<'input, 'b>(
|
fn run_method<'input, 'b>(
|
||||||
resolver: &mut ScopedResolver<'input, 'b>,
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
linkage: ast::LinkingDirective,
|
linkage: ast::LinkingDirective,
|
||||||
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||||
) -> Result<NormalizedFunction2, TranslateError> {
|
) -> Result<NormalizedFunction2, TranslateError> {
|
||||||
let is_kernel = method.func_directive.name.is_kernel();
|
let is_kernel = method.func_directive.name.is_kernel();
|
||||||
let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?;
|
let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?;
|
||||||
resolver.start_scope();
|
resolver.start_scope();
|
||||||
let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?;
|
let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?;
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| {
|
.map(|statements| {
|
||||||
let mut result = Vec::with_capacity(statements.len());
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
run_statements(resolver, &mut result, statements)?;
|
run_statements(resolver, &mut result, statements)?;
|
||||||
Ok::<_, TranslateError>(result)
|
Ok::<_, TranslateError>(result)
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
resolver.end_scope();
|
resolver.end_scope();
|
||||||
Ok(Function2 {
|
Ok(Function2 {
|
||||||
return_arguments,
|
return_arguments,
|
||||||
name,
|
name,
|
||||||
input_arguments,
|
input_arguments,
|
||||||
body,
|
body,
|
||||||
import_as: None,
|
import_as: None,
|
||||||
linkage,
|
linkage,
|
||||||
is_kernel,
|
is_kernel,
|
||||||
tuning: method.tuning,
|
tuning: method.tuning,
|
||||||
flush_to_zero_f32: false,
|
flush_to_zero_f32: false,
|
||||||
flush_to_zero_f16f64: false,
|
flush_to_zero_f16f64: false,
|
||||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_function_decl<'input, 'b>(
|
fn run_function_decl<'input, 'b>(
|
||||||
resolver: &mut ScopedResolver<'input, 'b>,
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
||||||
) -> Result<(Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
|
) -> Result<(Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
|
||||||
assert!(func_directive.shared_mem.is_none());
|
assert!(func_directive.shared_mem.is_none());
|
||||||
let return_arguments = func_directive
|
let return_arguments = func_directive
|
||||||
.return_arguments
|
.return_arguments
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|var| run_variable(resolver, var))
|
.map(|var| run_variable(resolver, var))
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
let input_arguments = func_directive
|
let input_arguments = func_directive
|
||||||
.input_arguments
|
.input_arguments
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|var| run_variable(resolver, var))
|
.map(|var| run_variable(resolver, var))
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
Ok((return_arguments, input_arguments))
|
Ok((return_arguments, input_arguments))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_variable<'input, 'b>(
|
fn run_variable<'input, 'b>(
|
||||||
resolver: &mut ScopedResolver<'input, 'b>,
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
variable: ast::Variable<&'input str>,
|
variable: ast::Variable<&'input str>,
|
||||||
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
|
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
|
||||||
Ok(ast::Variable {
|
Ok(ast::Variable {
|
||||||
name: resolver.add(
|
name: resolver.add(
|
||||||
Cow::Borrowed(variable.name),
|
Cow::Borrowed(variable.name),
|
||||||
Some((variable.v_type.clone(), variable.state_space)),
|
Some((variable.v_type.clone(), variable.state_space)),
|
||||||
)?,
|
)?,
|
||||||
align: variable.align,
|
align: variable.align,
|
||||||
v_type: variable.v_type,
|
v_type: variable.v_type,
|
||||||
state_space: variable.state_space,
|
state_space: variable.state_space,
|
||||||
array_init: variable.array_init,
|
array_init: variable.array_init,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statements<'input, 'b>(
|
fn run_statements<'input, 'b>(
|
||||||
resolver: &mut ScopedResolver<'input, 'b>,
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
result: &mut Vec<NormalizedStatement>,
|
result: &mut Vec<NormalizedStatement>,
|
||||||
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
for statement in statements.iter() {
|
for statement in statements.iter() {
|
||||||
match statement {
|
match statement {
|
||||||
ast::Statement::Label(label) => {
|
ast::Statement::Label(label) => {
|
||||||
resolver.add(Cow::Borrowed(*label), None)?;
|
resolver.add(Cow::Borrowed(*label), None)?;
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for statement in statements {
|
for statement in statements {
|
||||||
match statement {
|
match statement {
|
||||||
ast::Statement::Label(label) => {
|
ast::Statement::Label(label) => {
|
||||||
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
|
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
|
||||||
}
|
}
|
||||||
ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
|
ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
|
||||||
ast::Statement::Instruction(predicate, instruction) => {
|
ast::Statement::Instruction(predicate, instruction) => {
|
||||||
result.push(Statement::Instruction((
|
result.push(Statement::Instruction((
|
||||||
predicate
|
predicate
|
||||||
.map(|pred| {
|
.map(|pred| {
|
||||||
Ok::<_, TranslateError>(ast::PredAt {
|
Ok::<_, TranslateError>(ast::PredAt {
|
||||||
not: pred.not,
|
not: pred.not,
|
||||||
label: resolver.get(pred.label)?,
|
label: resolver.get(pred.label)?,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.transpose()?,
|
.transpose()?,
|
||||||
run_instruction(resolver, instruction)?,
|
run_instruction(resolver, instruction)?,
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
ast::Statement::Block(block) => {
|
ast::Statement::Block(block) => {
|
||||||
resolver.start_scope();
|
resolver.start_scope();
|
||||||
run_statements(resolver, result, block)?;
|
run_statements(resolver, result, block)?;
|
||||||
resolver.end_scope();
|
resolver.end_scope();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_instruction<'input, 'b>(
|
fn run_instruction<'input, 'b>(
|
||||||
resolver: &mut ScopedResolver<'input, 'b>,
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
|
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
|
||||||
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
||||||
ast::visit_map(instruction, &mut |name: &'input str,
|
ast::visit_map(instruction, &mut |name: &'input str,
|
||||||
_: Option<(
|
_: Option<(
|
||||||
&ast::Type,
|
&ast::Type,
|
||||||
ast::StateSpace,
|
ast::StateSpace,
|
||||||
)>,
|
)>,
|
||||||
_,
|
_,
|
||||||
_| {
|
_| {
|
||||||
resolver.get(&name)
|
resolver.get(&name)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_multivariable<'input, 'b>(
|
fn run_multivariable<'input, 'b>(
|
||||||
resolver: &mut ScopedResolver<'input, 'b>,
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
result: &mut Vec<NormalizedStatement>,
|
result: &mut Vec<NormalizedStatement>,
|
||||||
variable: ast::MultiVariable<&'input str>,
|
variable: ast::MultiVariable<&'input str>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
match variable.count {
|
match variable.count {
|
||||||
Some(count) => {
|
Some(count) => {
|
||||||
for i in 0..count {
|
for i in 0..count {
|
||||||
let name = Cow::Owned(format!("{}{}", variable.var.name, i));
|
let name = Cow::Owned(format!("{}{}", variable.var.name, i));
|
||||||
let ident = resolver.add(
|
let ident = resolver.add(
|
||||||
name,
|
name,
|
||||||
Some((variable.var.v_type.clone(), variable.var.state_space)),
|
Some((variable.var.v_type.clone(), variable.var.state_space)),
|
||||||
)?;
|
)?;
|
||||||
result.push(Statement::Variable(ast::Variable {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
align: variable.var.align,
|
align: variable.var.align,
|
||||||
v_type: variable.var.v_type.clone(),
|
v_type: variable.var.v_type.clone(),
|
||||||
state_space: variable.var.state_space,
|
state_space: variable.var.state_space,
|
||||||
name: ident,
|
name: ident,
|
||||||
array_init: variable.var.array_init.clone(),
|
array_init: variable.var.array_init.clone(),
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
let name = Cow::Borrowed(variable.var.name);
|
let name = Cow::Borrowed(variable.var.name);
|
||||||
let ident = resolver.add(
|
let ident = resolver.add(
|
||||||
name,
|
name,
|
||||||
Some((variable.var.v_type.clone(), variable.var.state_space)),
|
Some((variable.var.v_type.clone(), variable.var.state_space)),
|
||||||
)?;
|
)?;
|
||||||
result.push(Statement::Variable(ast::Variable {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
align: variable.var.align,
|
align: variable.var.align,
|
||||||
v_type: variable.var.v_type.clone(),
|
v_type: variable.var.v_type.clone(),
|
||||||
state_space: variable.var.state_space,
|
state_space: variable.var.state_space,
|
||||||
name: ident,
|
name: ident,
|
||||||
array_init: variable.var.array_init.clone(),
|
array_init: variable.var.array_init.clone(),
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,90 +1,90 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use ptx_parser as ast;
|
use ptx_parser as ast;
|
||||||
|
|
||||||
pub(crate) fn run<'input>(
|
pub(crate) fn run<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<NormalizedDirective2>,
|
directives: Vec<NormalizedDirective2>,
|
||||||
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(resolver, directive))
|
.map(|directive| run_directive(resolver, directive))
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_directive<'input>(
|
fn run_directive<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directive: NormalizedDirective2,
|
directive: NormalizedDirective2,
|
||||||
) -> Result<UnconditionalDirective, TranslateError> {
|
) -> Result<UnconditionalDirective, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_method<'input>(
|
fn run_method<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
method: NormalizedFunction2,
|
method: NormalizedFunction2,
|
||||||
) -> Result<UnconditionalFunction, TranslateError> {
|
) -> Result<UnconditionalFunction, TranslateError> {
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| {
|
.map(|statements| {
|
||||||
let mut result = Vec::with_capacity(statements.len());
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
for statement in statements {
|
for statement in statements {
|
||||||
run_statement(resolver, &mut result, statement)?;
|
run_statement(resolver, &mut result, statement)?;
|
||||||
}
|
}
|
||||||
Ok::<_, TranslateError>(result)
|
Ok::<_, TranslateError>(result)
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 {
|
||||||
body,
|
body,
|
||||||
return_arguments: method.return_arguments,
|
return_arguments: method.return_arguments,
|
||||||
name: method.name,
|
name: method.name,
|
||||||
input_arguments: method.input_arguments,
|
input_arguments: method.input_arguments,
|
||||||
import_as: method.import_as,
|
import_as: method.import_as,
|
||||||
tuning: method.tuning,
|
tuning: method.tuning,
|
||||||
linkage: method.linkage,
|
linkage: method.linkage,
|
||||||
is_kernel: method.is_kernel,
|
is_kernel: method.is_kernel,
|
||||||
flush_to_zero_f32: method.flush_to_zero_f32,
|
flush_to_zero_f32: method.flush_to_zero_f32,
|
||||||
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
||||||
rounding_mode_f32: method.rounding_mode_f32,
|
rounding_mode_f32: method.rounding_mode_f32,
|
||||||
rounding_mode_f16f64: method.rounding_mode_f16f64,
|
rounding_mode_f16f64: method.rounding_mode_f16f64,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statement<'input>(
|
fn run_statement<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
result: &mut Vec<UnconditionalStatement>,
|
result: &mut Vec<UnconditionalStatement>,
|
||||||
statement: NormalizedStatement,
|
statement: NormalizedStatement,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
Ok(match statement {
|
Ok(match statement {
|
||||||
Statement::Label(label) => result.push(Statement::Label(label)),
|
Statement::Label(label) => result.push(Statement::Label(label)),
|
||||||
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
||||||
Statement::Instruction((predicate, instruction)) => {
|
Statement::Instruction((predicate, instruction)) => {
|
||||||
if let Some(pred) = predicate {
|
if let Some(pred) = predicate {
|
||||||
let if_true = resolver.register_unnamed(None);
|
let if_true = resolver.register_unnamed(None);
|
||||||
let if_false = resolver.register_unnamed(None);
|
let if_false = resolver.register_unnamed(None);
|
||||||
let folded_bra = match &instruction {
|
let folded_bra = match &instruction {
|
||||||
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
let mut branch = BrachCondition {
|
let mut branch = BrachCondition {
|
||||||
predicate: pred.label,
|
predicate: pred.label,
|
||||||
if_true: folded_bra.unwrap_or(if_true),
|
if_true: folded_bra.unwrap_or(if_true),
|
||||||
if_false,
|
if_false,
|
||||||
};
|
};
|
||||||
if pred.not {
|
if pred.not {
|
||||||
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||||
}
|
}
|
||||||
result.push(Statement::Conditional(branch));
|
result.push(Statement::Conditional(branch));
|
||||||
if folded_bra.is_none() {
|
if folded_bra.is_none() {
|
||||||
result.push(Statement::Label(if_true));
|
result.push(Statement::Label(if_true));
|
||||||
result.push(Statement::Instruction(instruction));
|
result.push(Statement::Instruction(instruction));
|
||||||
}
|
}
|
||||||
result.push(Statement::Label(if_false));
|
result.push(Statement::Label(if_false));
|
||||||
} else {
|
} else {
|
||||||
result.push(Statement::Instruction(instruction));
|
result.push(Statement::Instruction(instruction));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => return Err(error_unreachable()),
|
_ => return Err(error_unreachable()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,268 +1,268 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
pub(super) fn run<'input>(
|
pub(super) fn run<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
let mut fn_declarations = FxHashMap::default();
|
let mut fn_declarations = FxHashMap::default();
|
||||||
let remapped_directives = directives
|
let remapped_directives = directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(resolver, &mut fn_declarations, directive))
|
.map(|directive| run_directive(resolver, &mut fn_declarations, directive))
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
let mut result = fn_declarations
|
let mut result = fn_declarations
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(_, (return_arguments, name, input_arguments))| {
|
.map(|(_, (return_arguments, name, input_arguments))| {
|
||||||
Directive2::Method(Function2 {
|
Directive2::Method(Function2 {
|
||||||
return_arguments,
|
return_arguments,
|
||||||
name: name,
|
name: name,
|
||||||
input_arguments,
|
input_arguments,
|
||||||
body: None,
|
body: None,
|
||||||
import_as: None,
|
import_as: None,
|
||||||
tuning: Vec::new(),
|
tuning: Vec::new(),
|
||||||
linkage: ast::LinkingDirective::EXTERN,
|
linkage: ast::LinkingDirective::EXTERN,
|
||||||
is_kernel: false,
|
is_kernel: false,
|
||||||
flush_to_zero_f32: false,
|
flush_to_zero_f32: false,
|
||||||
flush_to_zero_f16f64: false,
|
flush_to_zero_f16f64: false,
|
||||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
result.extend(remapped_directives);
|
result.extend(remapped_directives);
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_directive<'input>(
|
fn run_directive<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
fn_declarations: &mut FxHashMap<
|
fn_declarations: &mut FxHashMap<
|
||||||
Cow<'input, str>,
|
Cow<'input, str>,
|
||||||
(
|
(
|
||||||
Vec<ast::Variable<SpirvWord>>,
|
Vec<ast::Variable<SpirvWord>>,
|
||||||
SpirvWord,
|
SpirvWord,
|
||||||
Vec<ast::Variable<SpirvWord>>,
|
Vec<ast::Variable<SpirvWord>>,
|
||||||
),
|
),
|
||||||
>,
|
>,
|
||||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
var @ Directive2::Variable(..) => var,
|
var @ Directive2::Variable(..) => var,
|
||||||
Directive2::Method(mut method) => {
|
Directive2::Method(mut method) => {
|
||||||
method.body = method
|
method.body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| run_statements(resolver, fn_declarations, statements))
|
.map(|statements| run_statements(resolver, fn_declarations, statements))
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Directive2::Method(method)
|
Directive2::Method(method)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statements<'input>(
|
fn run_statements<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
fn_declarations: &mut FxHashMap<
|
fn_declarations: &mut FxHashMap<
|
||||||
Cow<'input, str>,
|
Cow<'input, str>,
|
||||||
(
|
(
|
||||||
Vec<ast::Variable<SpirvWord>>,
|
Vec<ast::Variable<SpirvWord>>,
|
||||||
SpirvWord,
|
SpirvWord,
|
||||||
Vec<ast::Variable<SpirvWord>>,
|
Vec<ast::Variable<SpirvWord>>,
|
||||||
),
|
),
|
||||||
>,
|
>,
|
||||||
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
statements
|
statements
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|statement| {
|
.map(|statement| {
|
||||||
Ok(match statement {
|
Ok(match statement {
|
||||||
Statement::Instruction(instruction) => {
|
Statement::Instruction(instruction) => {
|
||||||
Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
|
Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
|
||||||
}
|
}
|
||||||
s => s,
|
s => s,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_instruction<'input>(
|
fn run_instruction<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
fn_declarations: &mut FxHashMap<
|
fn_declarations: &mut FxHashMap<
|
||||||
Cow<'input, str>,
|
Cow<'input, str>,
|
||||||
(
|
(
|
||||||
Vec<ast::Variable<SpirvWord>>,
|
Vec<ast::Variable<SpirvWord>>,
|
||||||
SpirvWord,
|
SpirvWord,
|
||||||
Vec<ast::Variable<SpirvWord>>,
|
Vec<ast::Variable<SpirvWord>>,
|
||||||
),
|
),
|
||||||
>,
|
>,
|
||||||
instruction: ptx_parser::Instruction<SpirvWord>,
|
instruction: ptx_parser::Instruction<SpirvWord>,
|
||||||
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
|
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
|
||||||
Ok(match instruction {
|
Ok(match instruction {
|
||||||
i @ ptx_parser::Instruction::Sqrt {
|
i @ ptx_parser::Instruction::Sqrt {
|
||||||
data:
|
data:
|
||||||
ast::RcpData {
|
ast::RcpData {
|
||||||
kind: ast::RcpKind::Approx,
|
kind: ast::RcpKind::Approx,
|
||||||
type_: ast::ScalarType::F32,
|
type_: ast::ScalarType::F32,
|
||||||
flush_to_zero: None | Some(false),
|
flush_to_zero: None | Some(false),
|
||||||
},
|
},
|
||||||
..
|
..
|
||||||
} => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?,
|
} => to_call(resolver, fn_declarations, "sqrt_approx_f32".into(), i)?,
|
||||||
i @ ptx_parser::Instruction::Rsqrt {
|
i @ ptx_parser::Instruction::Rsqrt {
|
||||||
data:
|
data:
|
||||||
ast::TypeFtz {
|
ast::TypeFtz {
|
||||||
type_: ast::ScalarType::F32,
|
type_: ast::ScalarType::F32,
|
||||||
flush_to_zero: None | Some(false),
|
flush_to_zero: None | Some(false),
|
||||||
},
|
},
|
||||||
..
|
..
|
||||||
} => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?,
|
} => to_call(resolver, fn_declarations, "rsqrt_approx_f32".into(), i)?,
|
||||||
i @ ptx_parser::Instruction::Rcp {
|
i @ ptx_parser::Instruction::Rcp {
|
||||||
data:
|
data:
|
||||||
ast::RcpData {
|
ast::RcpData {
|
||||||
kind: ast::RcpKind::Approx,
|
kind: ast::RcpKind::Approx,
|
||||||
type_: ast::ScalarType::F32,
|
type_: ast::ScalarType::F32,
|
||||||
flush_to_zero: None | Some(false),
|
flush_to_zero: None | Some(false),
|
||||||
},
|
},
|
||||||
..
|
..
|
||||||
} => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?,
|
} => to_call(resolver, fn_declarations, "rcp_approx_f32".into(), i)?,
|
||||||
i @ ptx_parser::Instruction::Ex2 {
|
i @ ptx_parser::Instruction::Ex2 {
|
||||||
data:
|
data:
|
||||||
ast::TypeFtz {
|
ast::TypeFtz {
|
||||||
type_: ast::ScalarType::F32,
|
type_: ast::ScalarType::F32,
|
||||||
flush_to_zero: None | Some(false),
|
flush_to_zero: None | Some(false),
|
||||||
},
|
},
|
||||||
..
|
..
|
||||||
} => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?,
|
} => to_call(resolver, fn_declarations, "ex2_approx_f32".into(), i)?,
|
||||||
i @ ptx_parser::Instruction::Lg2 {
|
i @ ptx_parser::Instruction::Lg2 {
|
||||||
data: ast::FlushToZero {
|
data: ast::FlushToZero {
|
||||||
flush_to_zero: false,
|
flush_to_zero: false,
|
||||||
},
|
},
|
||||||
..
|
..
|
||||||
} => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?,
|
} => to_call(resolver, fn_declarations, "lg2_approx_f32".into(), i)?,
|
||||||
i @ ptx_parser::Instruction::Activemask { .. } => {
|
i @ ptx_parser::Instruction::Activemask { .. } => {
|
||||||
to_call(resolver, fn_declarations, "activemask".into(), i)?
|
to_call(resolver, fn_declarations, "activemask".into(), i)?
|
||||||
}
|
}
|
||||||
i @ ptx_parser::Instruction::Bfe { data, .. } => {
|
i @ ptx_parser::Instruction::Bfe { data, .. } => {
|
||||||
let name = ["bfe_", scalar_to_ptx_name(data)].concat();
|
let name = ["bfe_", scalar_to_ptx_name(data)].concat();
|
||||||
to_call(resolver, fn_declarations, name.into(), i)?
|
to_call(resolver, fn_declarations, name.into(), i)?
|
||||||
}
|
}
|
||||||
i @ ptx_parser::Instruction::Bfi { data, .. } => {
|
i @ ptx_parser::Instruction::Bfi { data, .. } => {
|
||||||
let name = ["bfi_", scalar_to_ptx_name(data)].concat();
|
let name = ["bfi_", scalar_to_ptx_name(data)].concat();
|
||||||
to_call(resolver, fn_declarations, name.into(), i)?
|
to_call(resolver, fn_declarations, name.into(), i)?
|
||||||
}
|
}
|
||||||
i @ ptx_parser::Instruction::Bar { .. } => {
|
i @ ptx_parser::Instruction::Bar { .. } => {
|
||||||
to_call(resolver, fn_declarations, "bar_sync".into(), i)?
|
to_call(resolver, fn_declarations, "bar_sync".into(), i)?
|
||||||
}
|
}
|
||||||
ptx_parser::Instruction::BarRed { data, arguments } => {
|
ptx_parser::Instruction::BarRed { data, arguments } => {
|
||||||
if arguments.src_threadcount.is_some() {
|
if arguments.src_threadcount.is_some() {
|
||||||
return Err(error_todo());
|
return Err(error_todo());
|
||||||
}
|
}
|
||||||
let name = match data.pred_reduction {
|
let name = match data.pred_reduction {
|
||||||
ptx_parser::Reduction::And => "bar_red_and_pred",
|
ptx_parser::Reduction::And => "bar_red_and_pred",
|
||||||
ptx_parser::Reduction::Or => "bar_red_or_pred",
|
ptx_parser::Reduction::Or => "bar_red_or_pred",
|
||||||
};
|
};
|
||||||
to_call(
|
to_call(
|
||||||
resolver,
|
resolver,
|
||||||
fn_declarations,
|
fn_declarations,
|
||||||
name.into(),
|
name.into(),
|
||||||
ptx_parser::Instruction::BarRed { data, arguments },
|
ptx_parser::Instruction::BarRed { data, arguments },
|
||||||
)?
|
)?
|
||||||
}
|
}
|
||||||
ptx_parser::Instruction::ShflSync { data, arguments } => {
|
ptx_parser::Instruction::ShflSync { data, arguments } => {
|
||||||
let mode = match data.mode {
|
let mode = match data.mode {
|
||||||
ptx_parser::ShuffleMode::Up => "up",
|
ptx_parser::ShuffleMode::Up => "up",
|
||||||
ptx_parser::ShuffleMode::Down => "down",
|
ptx_parser::ShuffleMode::Down => "down",
|
||||||
ptx_parser::ShuffleMode::BFly => "bfly",
|
ptx_parser::ShuffleMode::BFly => "bfly",
|
||||||
ptx_parser::ShuffleMode::Idx => "idx",
|
ptx_parser::ShuffleMode::Idx => "idx",
|
||||||
};
|
};
|
||||||
let pred = if arguments.dst_pred.is_some() {
|
let pred = if arguments.dst_pred.is_some() {
|
||||||
"_pred"
|
"_pred"
|
||||||
} else {
|
} else {
|
||||||
""
|
""
|
||||||
};
|
};
|
||||||
to_call(
|
to_call(
|
||||||
resolver,
|
resolver,
|
||||||
fn_declarations,
|
fn_declarations,
|
||||||
format!("shfl_sync_{}_b32{}", mode, pred).into(),
|
format!("shfl_sync_{}_b32{}", mode, pred).into(),
|
||||||
ptx_parser::Instruction::ShflSync { data, arguments },
|
ptx_parser::Instruction::ShflSync { data, arguments },
|
||||||
)?
|
)?
|
||||||
}
|
}
|
||||||
i @ ptx_parser::Instruction::Nanosleep { .. } => {
|
i @ ptx_parser::Instruction::Nanosleep { .. } => {
|
||||||
to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)?
|
to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)?
|
||||||
}
|
}
|
||||||
i => i,
|
i => i,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_call<'input>(
|
fn to_call<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
fn_declarations: &mut FxHashMap<
|
fn_declarations: &mut FxHashMap<
|
||||||
Cow<'input, str>,
|
Cow<'input, str>,
|
||||||
(
|
(
|
||||||
Vec<ast::Variable<SpirvWord>>,
|
Vec<ast::Variable<SpirvWord>>,
|
||||||
SpirvWord,
|
SpirvWord,
|
||||||
Vec<ast::Variable<SpirvWord>>,
|
Vec<ast::Variable<SpirvWord>>,
|
||||||
),
|
),
|
||||||
>,
|
>,
|
||||||
name: Cow<'input, str>,
|
name: Cow<'input, str>,
|
||||||
i: ast::Instruction<SpirvWord>,
|
i: ast::Instruction<SpirvWord>,
|
||||||
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
|
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
|
||||||
let mut data_return = Vec::new();
|
let mut data_return = Vec::new();
|
||||||
let mut data_input = Vec::new();
|
let mut data_input = Vec::new();
|
||||||
let mut arguments_return = Vec::new();
|
let mut arguments_return = Vec::new();
|
||||||
let mut arguments_input = Vec::new();
|
let mut arguments_input = Vec::new();
|
||||||
ast::visit(&i, &mut |name: &SpirvWord,
|
ast::visit(&i, &mut |name: &SpirvWord,
|
||||||
type_space: Option<(
|
type_space: Option<(
|
||||||
&ptx_parser::Type,
|
&ptx_parser::Type,
|
||||||
ptx_parser::StateSpace,
|
ptx_parser::StateSpace,
|
||||||
)>,
|
)>,
|
||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
_: bool| {
|
_: bool| {
|
||||||
let (type_, space) = type_space.ok_or_else(error_mismatched_type)?;
|
let (type_, space) = type_space.ok_or_else(error_mismatched_type)?;
|
||||||
if is_dst {
|
if is_dst {
|
||||||
data_return.push((type_.clone(), space));
|
data_return.push((type_.clone(), space));
|
||||||
arguments_return.push(*name);
|
arguments_return.push(*name);
|
||||||
} else {
|
} else {
|
||||||
data_input.push((type_.clone(), space));
|
data_input.push((type_.clone(), space));
|
||||||
arguments_input.push(*name);
|
arguments_input.push(*name);
|
||||||
};
|
};
|
||||||
Ok::<_, TranslateError>(())
|
Ok::<_, TranslateError>(())
|
||||||
})?;
|
})?;
|
||||||
let fn_name = match fn_declarations.entry(name) {
|
let fn_name = match fn_declarations.entry(name) {
|
||||||
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
|
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
|
||||||
hash_map::Entry::Vacant(vacant_entry) => {
|
hash_map::Entry::Vacant(vacant_entry) => {
|
||||||
let name = vacant_entry.key().clone();
|
let name = vacant_entry.key().clone();
|
||||||
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
|
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
|
||||||
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
|
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
|
||||||
vacant_entry.insert((
|
vacant_entry.insert((
|
||||||
to_variables(resolver, &data_return),
|
to_variables(resolver, &data_return),
|
||||||
name,
|
name,
|
||||||
to_variables(resolver, &data_input),
|
to_variables(resolver, &data_input),
|
||||||
));
|
));
|
||||||
name
|
name
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Ok(ast::Instruction::Call {
|
Ok(ast::Instruction::Call {
|
||||||
data: ptx_parser::CallDetails {
|
data: ptx_parser::CallDetails {
|
||||||
uniform: false,
|
uniform: false,
|
||||||
return_arguments: data_return,
|
return_arguments: data_return,
|
||||||
input_arguments: data_input,
|
input_arguments: data_input,
|
||||||
},
|
},
|
||||||
arguments: ptx_parser::CallArgs {
|
arguments: ptx_parser::CallArgs {
|
||||||
return_arguments: arguments_return,
|
return_arguments: arguments_return,
|
||||||
func: fn_name,
|
func: fn_name,
|
||||||
input_arguments: arguments_input,
|
input_arguments: arguments_input,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_variables<'input>(
|
fn to_variables<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
|
arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
) -> Vec<ptx_parser::Variable<SpirvWord>> {
|
) -> Vec<ptx_parser::Variable<SpirvWord>> {
|
||||||
arguments
|
arguments
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(type_, space)| ast::Variable {
|
.map(|(type_, space)| ast::Variable {
|
||||||
align: None,
|
align: None,
|
||||||
v_type: type_.clone(),
|
v_type: type_.clone(),
|
||||||
state_space: *space,
|
state_space: *space,
|
||||||
name: resolver.register_unnamed(Some((type_.clone(), *space))),
|
name: resolver.register_unnamed(Some((type_.clone(), *space))),
|
||||||
array_init: Vec::new(),
|
array_init: Vec::new(),
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
}
|
}
|
||||||
|
@ -1,33 +1,33 @@
|
|||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
|
|
||||||
use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
|
use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
|
||||||
|
|
||||||
pub(crate) fn run<'input>(
|
pub(crate) fn run<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
mut directives: Vec<NormalizedDirective2>,
|
mut directives: Vec<NormalizedDirective2>,
|
||||||
) -> Vec<NormalizedDirective2> {
|
) -> Vec<NormalizedDirective2> {
|
||||||
for directive in directives.iter_mut() {
|
for directive in directives.iter_mut() {
|
||||||
match directive {
|
match directive {
|
||||||
NormalizedDirective2::Method(func) => {
|
NormalizedDirective2::Method(func) => {
|
||||||
replace_with_ptx_impl(resolver, func.name);
|
replace_with_ptx_impl(resolver, func.name);
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
directives
|
directives
|
||||||
}
|
}
|
||||||
|
|
||||||
fn replace_with_ptx_impl<'input>(
|
fn replace_with_ptx_impl<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
fn_name: SpirvWord,
|
fn_name: SpirvWord,
|
||||||
) {
|
) {
|
||||||
let known_names = ["__assertfail"];
|
let known_names = ["__assertfail"];
|
||||||
if let Some(super::IdentEntry {
|
if let Some(super::IdentEntry {
|
||||||
name: Some(name), ..
|
name: Some(name), ..
|
||||||
}) = resolver.ident_map.get_mut(&fn_name)
|
}) = resolver.ident_map.get_mut(&fn_name)
|
||||||
{
|
{
|
||||||
if known_names.contains(&&**name) {
|
if known_names.contains(&&**name) {
|
||||||
*name = Cow::Owned(format!("__zluda_ptx_impl_{}", name));
|
*name = Cow::Owned(format!("__zluda_ptx_impl_{}", name));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,69 +1,69 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use ptx_parser as ast;
|
use ptx_parser as ast;
|
||||||
use rustc_hash::FxHashSet;
|
use rustc_hash::FxHashSet;
|
||||||
|
|
||||||
pub(crate) fn run<'input>(
|
pub(crate) fn run<'input>(
|
||||||
directives: Vec<UnconditionalDirective>,
|
directives: Vec<UnconditionalDirective>,
|
||||||
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||||
let mut functions = FxHashSet::default();
|
let mut functions = FxHashSet::default();
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(&mut functions, directive))
|
.map(|directive| run_directive(&mut functions, directive))
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_directive<'input>(
|
fn run_directive<'input>(
|
||||||
functions: &mut FxHashSet<SpirvWord>,
|
functions: &mut FxHashSet<SpirvWord>,
|
||||||
directive: UnconditionalDirective,
|
directive: UnconditionalDirective,
|
||||||
) -> Result<UnconditionalDirective, TranslateError> {
|
) -> Result<UnconditionalDirective, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
var @ Directive2::Variable(..) => var,
|
var @ Directive2::Variable(..) => var,
|
||||||
Directive2::Method(method) => {
|
Directive2::Method(method) => {
|
||||||
if !method.is_kernel {
|
if !method.is_kernel {
|
||||||
functions.insert(method.name);
|
functions.insert(method.name);
|
||||||
}
|
}
|
||||||
Directive2::Method(run_method(functions, method)?)
|
Directive2::Method(run_method(functions, method)?)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_method<'input>(
|
fn run_method<'input>(
|
||||||
functions: &mut FxHashSet<SpirvWord>,
|
functions: &mut FxHashSet<SpirvWord>,
|
||||||
method: UnconditionalFunction,
|
method: UnconditionalFunction,
|
||||||
) -> Result<UnconditionalFunction, TranslateError> {
|
) -> Result<UnconditionalFunction, TranslateError> {
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| {
|
.map(|statements| {
|
||||||
statements
|
statements
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|statement| run_statement(functions, statement))
|
.map(|statement| run_statement(functions, statement))
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 { body, ..method })
|
Ok(Function2 { body, ..method })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statement<'input>(
|
fn run_statement<'input>(
|
||||||
functions: &mut FxHashSet<SpirvWord>,
|
functions: &mut FxHashSet<SpirvWord>,
|
||||||
statement: UnconditionalStatement,
|
statement: UnconditionalStatement,
|
||||||
) -> Result<UnconditionalStatement, TranslateError> {
|
) -> Result<UnconditionalStatement, TranslateError> {
|
||||||
Ok(match statement {
|
Ok(match statement {
|
||||||
Statement::Instruction(ast::Instruction::Mov {
|
Statement::Instruction(ast::Instruction::Mov {
|
||||||
data,
|
data,
|
||||||
arguments:
|
arguments:
|
||||||
ast::MovArgs {
|
ast::MovArgs {
|
||||||
dst: ast::ParsedOperand::Reg(dst_reg),
|
dst: ast::ParsedOperand::Reg(dst_reg),
|
||||||
src: ast::ParsedOperand::Reg(src_reg),
|
src: ast::ParsedOperand::Reg(src_reg),
|
||||||
},
|
},
|
||||||
}) if functions.contains(&src_reg) => {
|
}) if functions.contains(&src_reg) => {
|
||||||
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
|
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
|
||||||
return Err(error_mismatched_type());
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
UnconditionalStatement::FunctionPointer(FunctionPointerDetails {
|
UnconditionalStatement::FunctionPointer(FunctionPointerDetails {
|
||||||
dst: dst_reg,
|
dst: dst_reg,
|
||||||
src: src_reg,
|
src: src_reg,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
s => s,
|
s => s,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,327 +1,327 @@
|
|||||||
use bpaf::{Args, Bpaf, Parser};
|
use bpaf::{Args, Bpaf, Parser};
|
||||||
use cargo_metadata::{MetadataCommand, Package};
|
use cargo_metadata::{MetadataCommand, Package};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::{env, ffi::OsString, path::PathBuf, process::Command};
|
use std::{env, ffi::OsString, path::PathBuf, process::Command};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Bpaf)]
|
#[derive(Debug, Clone, Bpaf)]
|
||||||
#[bpaf(options)]
|
#[bpaf(options)]
|
||||||
enum Options {
|
enum Options {
|
||||||
#[bpaf(command)]
|
#[bpaf(command)]
|
||||||
/// Compile ZLUDA (default command)
|
/// Compile ZLUDA (default command)
|
||||||
Build(#[bpaf(external(build))] Build),
|
Build(#[bpaf(external(build))] Build),
|
||||||
#[bpaf(command)]
|
#[bpaf(command)]
|
||||||
/// Compile ZLUDA and build a package
|
/// Compile ZLUDA and build a package
|
||||||
Zip(#[bpaf(external(build))] Build),
|
Zip(#[bpaf(external(build))] Build),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Bpaf)]
|
#[derive(Debug, Clone, Bpaf)]
|
||||||
struct Build {
|
struct Build {
|
||||||
#[bpaf(any("CARGO", not_help), many)]
|
#[bpaf(any("CARGO", not_help), many)]
|
||||||
/// Arguments to pass to cargo, e.g. `--release` for release build
|
/// Arguments to pass to cargo, e.g. `--release` for release build
|
||||||
cargo_arguments: Vec<OsString>,
|
cargo_arguments: Vec<OsString>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Accepts any argument except the help flags.
///
/// Returns `None` for `-h` and `--help`, and `Some(s)` for everything else, so
/// that help requests are not swallowed into the list of cargo arguments.
fn not_help(s: OsString) -> Option<OsString> {
    let is_help = s == "-h" || s == "--help";
    (!is_help).then_some(s)
}
|
||||||
|
|
||||||
// We need to sniff out some args passed to cargo to understand how to create
|
// We need to sniff out some args passed to cargo to understand how to create
|
||||||
// symlinks (should they go into `target/debug`, `target/release` or custom)
|
// symlinks (should they go into `target/debug`, `target/release` or custom)
|
||||||
#[derive(Debug, Clone, Bpaf)]
|
#[derive(Debug, Clone, Bpaf)]
|
||||||
struct Cargo {
|
struct Cargo {
|
||||||
#[bpaf(switch, long, short)]
|
#[bpaf(switch, long, short)]
|
||||||
release: Option<bool>,
|
release: Option<bool>,
|
||||||
#[bpaf(long)]
|
#[bpaf(long)]
|
||||||
profile: Option<String>,
|
profile: Option<String>,
|
||||||
#[bpaf(any("", Some), many)]
|
#[bpaf(any("", Some), many)]
|
||||||
_unused: Vec<OsString>,
|
_unused: Vec<OsString>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Project {
|
struct Project {
|
||||||
name: String,
|
name: String,
|
||||||
target_name: String,
|
target_name: String,
|
||||||
target_kind: ProjectTarget,
|
target_kind: ProjectTarget,
|
||||||
meta: ZludaMetadata,
|
meta: ZludaMetadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Project {
|
impl Project {
|
||||||
fn try_new(p: Package) -> Option<Project> {
|
fn try_new(p: Package) -> Option<Project> {
|
||||||
let name = p.name;
|
let name = p.name;
|
||||||
serde_json::from_value::<Option<Metadata>>(p.metadata)
|
serde_json::from_value::<Option<Metadata>>(p.metadata)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.map(|m| {
|
.map(|m| {
|
||||||
let (target_name, target_kind) = p
|
let (target_name, target_kind) = p
|
||||||
.targets
|
.targets
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.find_map(|target| {
|
.find_map(|target| {
|
||||||
if target.is_cdylib() {
|
if target.is_cdylib() {
|
||||||
Some((target.name, ProjectTarget::Cdylib))
|
Some((target.name, ProjectTarget::Cdylib))
|
||||||
} else if target.is_bin() {
|
} else if target.is_bin() {
|
||||||
Some((target.name, ProjectTarget::Bin))
|
Some((target.name, ProjectTarget::Bin))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
Self {
|
Self {
|
||||||
name,
|
name,
|
||||||
target_name,
|
target_name,
|
||||||
target_kind,
|
target_kind,
|
||||||
meta: m.zluda,
|
meta: m.zluda,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
fn prefix(&self) -> &'static str {
|
fn prefix(&self) -> &'static str {
|
||||||
match self.target_kind {
|
match self.target_kind {
|
||||||
ProjectTarget::Bin => "",
|
ProjectTarget::Bin => "",
|
||||||
ProjectTarget::Cdylib => "lib",
|
ProjectTarget::Cdylib => "lib",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
fn prefix(&self) -> &'static str {
|
fn prefix(&self) -> &'static str {
|
||||||
""
|
""
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
fn suffix(&self) -> &'static str {
|
fn suffix(&self) -> &'static str {
|
||||||
match self.target_kind {
|
match self.target_kind {
|
||||||
ProjectTarget::Bin => "",
|
ProjectTarget::Bin => "",
|
||||||
ProjectTarget::Cdylib => ".so",
|
ProjectTarget::Cdylib => ".so",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
fn suffix(&self) -> &'static str {
|
fn suffix(&self) -> &'static str {
|
||||||
match self.target_kind {
|
match self.target_kind {
|
||||||
ProjectTarget::Bin => ".exe",
|
ProjectTarget::Bin => ".exe",
|
||||||
ProjectTarget::Cdylib => ".dll",
|
ProjectTarget::Cdylib => ".dll",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns tuple:
|
// Returns tuple:
|
||||||
// * symlink file path (relative to the root of build dir)
|
// * symlink file path (relative to the root of build dir)
|
||||||
// * symlink absolute file path
|
// * symlink absolute file path
|
||||||
// * target actual file (relative to symlink file)
|
// * target actual file (relative to symlink file)
|
||||||
#[cfg_attr(not(unix), allow(unused))]
|
#[cfg_attr(not(unix), allow(unused))]
|
||||||
fn symlinks<'a>(
|
fn symlinks<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
target_dir: &'a PathBuf,
|
target_dir: &'a PathBuf,
|
||||||
profile: &'a str,
|
profile: &'a str,
|
||||||
libname: &'a str,
|
libname: &'a str,
|
||||||
) -> impl Iterator<Item = (&'a str, PathBuf, PathBuf)> + 'a {
|
) -> impl Iterator<Item = (&'a str, PathBuf, PathBuf)> + 'a {
|
||||||
self.meta.linux_symlinks.iter().map(move |source| {
|
self.meta.linux_symlinks.iter().map(move |source| {
|
||||||
let mut link = target_dir.clone();
|
let mut link = target_dir.clone();
|
||||||
link.extend([profile, source]);
|
link.extend([profile, source]);
|
||||||
let relative_link = PathBuf::from(source);
|
let relative_link = PathBuf::from(source);
|
||||||
let ancestors = relative_link.as_path().ancestors().count();
|
let ancestors = relative_link.as_path().ancestors().count();
|
||||||
let mut target = std::iter::repeat_with(|| "../").take(ancestors - 2).fold(
|
let mut target = std::iter::repeat_with(|| "../").take(ancestors - 2).fold(
|
||||||
PathBuf::new(),
|
PathBuf::new(),
|
||||||
|mut buff, segment| {
|
|mut buff, segment| {
|
||||||
buff.push(segment);
|
buff.push(segment);
|
||||||
buff
|
buff
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
target.push(libname);
|
target.push(libname);
|
||||||
(&**source, link, target)
|
(&**source, link, target)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn file_name(&self) -> String {
|
fn file_name(&self) -> String {
|
||||||
let target_name = &self.target_name;
|
let target_name = &self.target_name;
|
||||||
let prefix = self.prefix();
|
let prefix = self.prefix();
|
||||||
let suffix = self.suffix();
|
let suffix = self.suffix();
|
||||||
format!("{prefix}{target_name}{suffix}")
|
format!("{prefix}{target_name}{suffix}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Kind of cargo target a `Project` produces.
#[derive(Clone, Copy)]
enum ProjectTarget {
    /// A C-compatible dynamic library.
    Cdylib,
    /// An executable.
    Bin,
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct Metadata {
|
struct Metadata {
|
||||||
zluda: ZludaMetadata,
|
zluda: ZludaMetadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
#[serde(deny_unknown_fields)]
|
#[serde(deny_unknown_fields)]
|
||||||
struct ZludaMetadata {
|
struct ZludaMetadata {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
windows_only: bool,
|
windows_only: bool,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
debug_only: bool,
|
debug_only: bool,
|
||||||
#[cfg_attr(not(unix), allow(unused))]
|
#[cfg_attr(not(unix), allow(unused))]
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
linux_symlinks: Vec<String>,
|
linux_symlinks: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let options = match options().run_inner(Args::current_args()) {
|
let options = match options().run_inner(Args::current_args()) {
|
||||||
Ok(b) => b,
|
Ok(b) => b,
|
||||||
Err(err) => match build().to_options().run_inner(Args::current_args()) {
|
Err(err) => match build().to_options().run_inner(Args::current_args()) {
|
||||||
Ok(b) => Options::Build(b),
|
Ok(b) => Options::Build(b),
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
err.print_message(100);
|
err.print_message(100);
|
||||||
std::process::exit(err.exit_code());
|
std::process::exit(err.exit_code());
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
match options {
|
match options {
|
||||||
Options::Build(b) => {
|
Options::Build(b) => {
|
||||||
compile(b);
|
compile(b);
|
||||||
}
|
}
|
||||||
Options::Zip(b) => zip(b),
|
Options::Zip(b) => zip(b),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compile(b: Build) -> (PathBuf, String, Vec<Project>) {
|
fn compile(b: Build) -> (PathBuf, String, Vec<Project>) {
|
||||||
let profile = sniff_out_profile_name(&b.cargo_arguments);
|
let profile = sniff_out_profile_name(&b.cargo_arguments);
|
||||||
let meta = MetadataCommand::new().no_deps().exec().unwrap();
|
let meta = MetadataCommand::new().no_deps().exec().unwrap();
|
||||||
let target_directory = meta.target_directory.into_std_path_buf();
|
let target_directory = meta.target_directory.into_std_path_buf();
|
||||||
let projects = meta
|
let projects = meta
|
||||||
.packages
|
.packages
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(Project::try_new)
|
.filter_map(Project::try_new)
|
||||||
.filter(|project| {
|
.filter(|project| {
|
||||||
if project.meta.windows_only && cfg!(not(windows)) {
|
if project.meta.windows_only && cfg!(not(windows)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if project.meta.debug_only && profile != "debug" {
|
if project.meta.debug_only && profile != "debug" {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
true
|
true
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let cargo = env::var("CARGO").unwrap_or_else(|_| "cargo".to_string());
|
let cargo = env::var("CARGO").unwrap_or_else(|_| "cargo".to_string());
|
||||||
let mut command = Command::new(&cargo);
|
let mut command = Command::new(&cargo);
|
||||||
command.arg("build");
|
command.arg("build");
|
||||||
command.arg("--locked");
|
command.arg("--locked");
|
||||||
for project in projects.iter() {
|
for project in projects.iter() {
|
||||||
command.arg("--package");
|
command.arg("--package");
|
||||||
command.arg(&project.name);
|
command.arg(&project.name);
|
||||||
}
|
}
|
||||||
command.args(b.cargo_arguments);
|
command.args(b.cargo_arguments);
|
||||||
assert!(command.status().unwrap().success());
|
assert!(command.status().unwrap().success());
|
||||||
os::make_symlinks(&target_directory, &*projects, &*profile);
|
os::make_symlinks(&target_directory, &*projects, &*profile);
|
||||||
(target_directory, profile, projects)
|
(target_directory, profile, projects)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sniff_out_profile_name(b: &[OsString]) -> String {
|
fn sniff_out_profile_name(b: &[OsString]) -> String {
|
||||||
let parsed_cargo_arguments = cargo().to_options().run_inner(b);
|
let parsed_cargo_arguments = cargo().to_options().run_inner(b);
|
||||||
match parsed_cargo_arguments {
|
match parsed_cargo_arguments {
|
||||||
Ok(Cargo {
|
Ok(Cargo {
|
||||||
release: Some(true),
|
release: Some(true),
|
||||||
..
|
..
|
||||||
}) => "release".to_string(),
|
}) => "release".to_string(),
|
||||||
Ok(Cargo {
|
Ok(Cargo {
|
||||||
profile: Some(profile),
|
profile: Some(profile),
|
||||||
..
|
..
|
||||||
}) => profile,
|
}) => profile,
|
||||||
_ => "debug".to_string(),
|
_ => "debug".to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn zip(zip: Build) {
|
fn zip(zip: Build) {
|
||||||
let (target_dir, profile, projects) = compile(zip);
|
let (target_dir, profile, projects) = compile(zip);
|
||||||
os::zip(target_dir, profile, projects)
|
os::zip(target_dir, profile, projects)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
mod os {
|
mod os {
|
||||||
use flate2::write::GzEncoder;
|
use flate2::write::GzEncoder;
|
||||||
use flate2::Compression;
|
use flate2::Compression;
|
||||||
use std::{
|
use std::{
|
||||||
fs::{self, File},
|
fs::{self, File},
|
||||||
path::PathBuf,
|
path::PathBuf,
|
||||||
};
|
};
|
||||||
use tar::Header;
|
use tar::Header;
|
||||||
|
|
||||||
pub fn make_symlinks(
|
pub fn make_symlinks(
|
||||||
target_directory: &std::path::PathBuf,
|
target_directory: &std::path::PathBuf,
|
||||||
projects: &[super::Project],
|
projects: &[super::Project],
|
||||||
profile: &str,
|
profile: &str,
|
||||||
) {
|
) {
|
||||||
use std::os::unix::fs as unix_fs;
|
use std::os::unix::fs as unix_fs;
|
||||||
for project in projects.iter() {
|
for project in projects.iter() {
|
||||||
let libname = project.file_name();
|
let libname = project.file_name();
|
||||||
for (_, full_path, target) in project.symlinks(target_directory, profile, &libname) {
|
for (_, full_path, target) in project.symlinks(target_directory, profile, &libname) {
|
||||||
let mut dir = full_path.clone();
|
let mut dir = full_path.clone();
|
||||||
assert!(dir.pop());
|
assert!(dir.pop());
|
||||||
fs::create_dir_all(dir).unwrap();
|
fs::create_dir_all(dir).unwrap();
|
||||||
fs::remove_file(&full_path).ok();
|
fs::remove_file(&full_path).ok();
|
||||||
unix_fs::symlink(&target, full_path).unwrap();
|
unix_fs::symlink(&target, full_path).unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
|
pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
|
||||||
let tar_gz =
|
let tar_gz =
|
||||||
File::create(format!("{}/{profile}/zluda.tar.gz", target_dir.display())).unwrap();
|
File::create(format!("{}/{profile}/zluda.tar.gz", target_dir.display())).unwrap();
|
||||||
let enc = GzEncoder::new(tar_gz, Compression::default());
|
let enc = GzEncoder::new(tar_gz, Compression::default());
|
||||||
let mut tar = tar::Builder::new(enc);
|
let mut tar = tar::Builder::new(enc);
|
||||||
for project in projects.iter() {
|
for project in projects.iter() {
|
||||||
let file_name = project.file_name();
|
let file_name = project.file_name();
|
||||||
let mut file =
|
let mut file =
|
||||||
File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap();
|
File::open(format!("{}/{profile}/{file_name}", target_dir.display())).unwrap();
|
||||||
tar.append_file(format!("zluda/{file_name}"), &mut file)
|
tar.append_file(format!("zluda/{file_name}"), &mut file)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
for (source, full_path, target) in project.symlinks(&target_dir, &profile, &file_name) {
|
for (source, full_path, target) in project.symlinks(&target_dir, &profile, &file_name) {
|
||||||
let mut header = Header::new_gnu();
|
let mut header = Header::new_gnu();
|
||||||
let meta = fs::symlink_metadata(&full_path).unwrap();
|
let meta = fs::symlink_metadata(&full_path).unwrap();
|
||||||
header.set_metadata(&meta);
|
header.set_metadata(&meta);
|
||||||
tar.append_link(&mut header, format!("zluda/{source}"), target)
|
tar.append_link(&mut header, format!("zluda/{source}"), target)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tar.finish().unwrap();
|
tar.finish().unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Non-unix implementation: no symlinks, packaging into `zluda.zip`.
#[cfg(not(unix))]
mod os {
    use std::{fs::File, io, path::PathBuf};
    use zip::{write::SimpleFileOptions, ZipWriter};

    /// Symlinks are not created on non-unix targets; this is a no-op.
    pub fn make_symlinks(
        _target_directory: &std::path::PathBuf,
        _projects: &[super::Project],
        _profile: &str,
    ) {
    }

    /// Writes `zluda.zip` into `target_dir/profile`, containing every project
    /// artifact under a `zluda/` directory.
    ///
    /// # Panics
    ///
    /// Panics on any I/O or archive error.
    pub(crate) fn zip(target_dir: PathBuf, profile: String, projects: Vec<crate::Project>) {
        let out_dir = target_dir.join(&profile);
        let zip_file = File::create(out_dir.join("zluda.zip")).unwrap();
        let mut zip = ZipWriter::new(zip_file);
        zip.add_directory("zluda", SimpleFileOptions::default())
            .unwrap();
        for project in &projects {
            let file_name = project.file_name();
            let mut file = File::open(out_dir.join(&file_name)).unwrap();
            // Fall back to default options when the timestamp is unavailable
            // or cannot be represented in the archive.
            let file_options = file_options_from_time(&file).unwrap_or_default();
            zip.start_file(format!("zluda/{file_name}"), file_options)
                .unwrap();
            io::copy(&mut file, &mut zip).unwrap();
        }
        zip.finish().unwrap();
    }

    /// Builds zip entry options carrying the modification time of `from`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file metadata or modification time cannot be
    /// read, or if the time cannot be converted to `zip::DateTime`.
    fn file_options_from_time(from: &File) -> io::Result<SimpleFileOptions> {
        let modified = time::OffsetDateTime::from(from.metadata()?.modified()?);
        let modified = zip::DateTime::try_from(modified).map_err(io::Error::other)?;
        Ok(SimpleFileOptions::default().last_modified_time(modified))
    }
}
|
||||||
|
@ -1,426 +1,426 @@
|
|||||||
use super::{FromCuda, LiveCheck};
|
use super::{FromCuda, LiveCheck};
|
||||||
use crate::r#impl::{context, device};
|
use crate::r#impl::{context, device};
|
||||||
use comgr::Comgr;
|
use comgr::Comgr;
|
||||||
use cuda_types::cuda::*;
|
use cuda_types::cuda::*;
|
||||||
use hip_runtime_sys::*;
|
use hip_runtime_sys::*;
|
||||||
use std::{
|
use std::{
|
||||||
ffi::{c_void, CStr, CString},
|
ffi::{c_void, CStr, CString},
|
||||||
mem, ptr, slice,
|
mem, ptr, slice,
|
||||||
sync::OnceLock,
|
sync::OnceLock,
|
||||||
usize,
|
usize,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg_attr(windows, path = "os_win.rs")]
|
#[cfg_attr(windows, path = "os_win.rs")]
|
||||||
#[cfg_attr(not(windows), path = "os_unix.rs")]
|
#[cfg_attr(not(windows), path = "os_unix.rs")]
|
||||||
mod os;
|
mod os;
|
||||||
|
|
||||||
pub(crate) struct GlobalState {
|
pub(crate) struct GlobalState {
|
||||||
pub devices: Vec<Device>,
|
pub devices: Vec<Device>,
|
||||||
pub comgr: Comgr,
|
pub comgr: Comgr,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct Device {
|
pub(crate) struct Device {
|
||||||
pub(crate) _comgr_isa: CString,
|
pub(crate) _comgr_isa: CString,
|
||||||
primary_context: LiveCheck<context::Context>,
|
primary_context: LiveCheck<context::Context>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Device {
|
impl Device {
|
||||||
pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) {
|
pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) {
|
||||||
unsafe {
|
unsafe {
|
||||||
(
|
(
|
||||||
self.primary_context.data.assume_init_ref(),
|
self.primary_context.data.assume_init_ref(),
|
||||||
self.primary_context.as_handle(),
|
self.primary_context.as_handle(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> {
|
pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> {
|
||||||
global_state()?
|
global_state()?
|
||||||
.devices
|
.devices
|
||||||
.get(dev as usize)
|
.get(dev as usize)
|
||||||
.ok_or(CUerror::INVALID_DEVICE)
|
.ok_or(CUerror::INVALID_DEVICE)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
|
pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
|
||||||
static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new();
|
static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new();
|
||||||
fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] {
|
fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] {
|
||||||
unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) }
|
unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) }
|
||||||
}
|
}
|
||||||
GLOBAL_STATE
|
GLOBAL_STATE
|
||||||
.get_or_init(|| {
|
.get_or_init(|| {
|
||||||
let mut device_count = 0;
|
let mut device_count = 0;
|
||||||
unsafe { hipGetDeviceCount(&mut device_count) }?;
|
unsafe { hipGetDeviceCount(&mut device_count) }?;
|
||||||
let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?;
|
let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?;
|
||||||
Ok(GlobalState {
|
Ok(GlobalState {
|
||||||
comgr,
|
comgr,
|
||||||
devices: (0..device_count)
|
devices: (0..device_count)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
let mut props = unsafe { mem::zeroed() };
|
let mut props = unsafe { mem::zeroed() };
|
||||||
unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?;
|
unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?;
|
||||||
Ok::<_, CUerror>(Device {
|
Ok::<_, CUerror>(Device {
|
||||||
_comgr_isa: CStr::from_bytes_until_nul(cast_slice(
|
_comgr_isa: CStr::from_bytes_until_nul(cast_slice(
|
||||||
&props.gcnArchName[..],
|
&props.gcnArchName[..],
|
||||||
))
|
))
|
||||||
.map_err(|_| CUerror::UNKNOWN)?
|
.map_err(|_| CUerror::UNKNOWN)?
|
||||||
.to_owned(),
|
.to_owned(),
|
||||||
primary_context: LiveCheck::new(context::Context::new(i)),
|
primary_context: LiveCheck::new(context::Context::new(i)),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?,
|
.collect::<Result<Vec<_>, _>>()?,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map_err(|e| *e)
|
.map_err(|e| *e)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult {
|
pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult {
|
||||||
unsafe { hipInit(flags) }?;
|
unsafe { hipInit(flags) }?;
|
||||||
global_state()?;
|
global_state()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Zero-initialized array of `S` `u32` words whose address is handed out as a
/// raw pointer by `get_unknown_buffer1` / `get_unknown_buffer2`.
struct UnknownBuffer<const S: usize> {
    buffer: std::cell::UnsafeCell<[u32; S]>,
}

impl<const S: usize> UnknownBuffer<S> {
    /// Creates a buffer with every word set to zero.
    const fn new() -> Self {
        UnknownBuffer {
            buffer: std::cell::UnsafeCell::new([0; S]),
        }
    }

    /// Number of `u32` elements in the buffer (`S`), not its size in bytes.
    const fn len(&self) -> usize {
        S
    }
}

// SAFETY: NOTE(review) nothing in this file reads or writes the buffer from
// Rust; only its address and length are exported. Soundness therefore rests
// on how external callers use that pointer; confirm this is acceptable.
unsafe impl<const S: usize> Sync for UnknownBuffer<S> {}

static UNKNOWN_BUFFER1: UnknownBuffer<1024> = UnknownBuffer::new();
static UNKNOWN_BUFFER2: UnknownBuffer<14> = UnknownBuffer::new();
|
||||||
|
|
||||||
/// Stateless type that carries the `CudaDarkApi` implementation.
struct DarkApi {}
|
||||||
|
|
||||||
impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
||||||
unsafe extern "system" fn get_module_from_cubin(
|
unsafe extern "system" fn get_module_from_cubin(
|
||||||
_module: *mut cuda_types::cuda::CUmodule,
|
_module: *mut cuda_types::cuda::CUmodule,
|
||||||
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
|
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn cudart_interface_fn2(
|
unsafe extern "system" fn cudart_interface_fn2(
|
||||||
pctx: *mut cuda_types::cuda::CUcontext,
|
pctx: *mut cuda_types::cuda::CUcontext,
|
||||||
hip_dev: hipDevice_t,
|
hip_dev: hipDevice_t,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
let pctx = match pctx.as_mut() {
|
let pctx = match pctx.as_mut() {
|
||||||
Some(p) => p,
|
Some(p) => p,
|
||||||
None => return CUresult::ERROR_INVALID_VALUE,
|
None => return CUresult::ERROR_INVALID_VALUE,
|
||||||
};
|
};
|
||||||
|
|
||||||
device::primary_context_retain(pctx, hip_dev)
|
device::primary_context_retain(pctx, hip_dev)
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn get_module_from_cubin_ext1(
|
unsafe extern "system" fn get_module_from_cubin_ext1(
|
||||||
_result: *mut cuda_types::cuda::CUmodule,
|
_result: *mut cuda_types::cuda::CUmodule,
|
||||||
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
|
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
|
||||||
_arg3: *mut std::ffi::c_void,
|
_arg3: *mut std::ffi::c_void,
|
||||||
_arg4: *mut std::ffi::c_void,
|
_arg4: *mut std::ffi::c_void,
|
||||||
_arg5: u32,
|
_arg5: u32,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult {
|
unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn get_module_from_cubin_ext2(
|
unsafe extern "system" fn get_module_from_cubin_ext2(
|
||||||
_fatbin_header: *const cuda_types::dark_api::FatbinHeader,
|
_fatbin_header: *const cuda_types::dark_api::FatbinHeader,
|
||||||
_result: *mut cuda_types::cuda::CUmodule,
|
_result: *mut cuda_types::cuda::CUmodule,
|
||||||
_arg3: *mut std::ffi::c_void,
|
_arg3: *mut std::ffi::c_void,
|
||||||
_arg4: *mut std::ffi::c_void,
|
_arg4: *mut std::ffi::c_void,
|
||||||
_arg5: u32,
|
_arg5: u32,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn get_unknown_buffer1(
|
unsafe extern "system" fn get_unknown_buffer1(
|
||||||
ptr: *mut *mut std::ffi::c_void,
|
ptr: *mut *mut std::ffi::c_void,
|
||||||
size: *mut usize,
|
size: *mut usize,
|
||||||
) -> () {
|
) -> () {
|
||||||
*ptr = UNKNOWN_BUFFER1.buffer.get() as *mut std::ffi::c_void;
|
*ptr = UNKNOWN_BUFFER1.buffer.get() as *mut std::ffi::c_void;
|
||||||
*size = UNKNOWN_BUFFER1.len();
|
*size = UNKNOWN_BUFFER1.len();
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn get_unknown_buffer2(
|
unsafe extern "system" fn get_unknown_buffer2(
|
||||||
ptr: *mut *mut std::ffi::c_void,
|
ptr: *mut *mut std::ffi::c_void,
|
||||||
size: *mut usize,
|
size: *mut usize,
|
||||||
) -> () {
|
) -> () {
|
||||||
*ptr = UNKNOWN_BUFFER2.buffer.get() as *mut std::ffi::c_void;
|
*ptr = UNKNOWN_BUFFER2.buffer.get() as *mut std::ffi::c_void;
|
||||||
*size = UNKNOWN_BUFFER2.len();
|
*size = UNKNOWN_BUFFER2.len();
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn context_local_storage_put(
|
unsafe extern "system" fn context_local_storage_put(
|
||||||
cu_ctx: CUcontext,
|
cu_ctx: CUcontext,
|
||||||
key: *mut c_void,
|
key: *mut c_void,
|
||||||
value: *mut c_void,
|
value: *mut c_void,
|
||||||
dtor_cb: Option<extern "system" fn(CUcontext, *mut c_void, *mut c_void)>,
|
dtor_cb: Option<extern "system" fn(CUcontext, *mut c_void, *mut c_void)>,
|
||||||
) -> CUresult {
|
) -> CUresult {
|
||||||
let _ctx = if cu_ctx.0 != ptr::null_mut() {
|
let _ctx = if cu_ctx.0 != ptr::null_mut() {
|
||||||
cu_ctx
|
cu_ctx
|
||||||
} else {
|
} else {
|
||||||
let mut current_ctx: CUcontext = CUcontext(ptr::null_mut());
|
let mut current_ctx: CUcontext = CUcontext(ptr::null_mut());
|
||||||
context::get_current(&mut current_ctx)?;
|
context::get_current(&mut current_ctx)?;
|
||||||
current_ctx
|
current_ctx
|
||||||
};
|
};
|
||||||
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
|
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
|
||||||
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
|
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
|
||||||
state.storage.insert(
|
state.storage.insert(
|
||||||
key as usize,
|
key as usize,
|
||||||
context::StorageData {
|
context::StorageData {
|
||||||
value: value as usize,
|
value: value as usize,
|
||||||
reset_cb: dtor_cb,
|
reset_cb: dtor_cb,
|
||||||
handle: _ctx,
|
handle: _ctx,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
})?;
|
})?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn context_local_storage_delete(
|
unsafe extern "system" fn context_local_storage_delete(
|
||||||
cu_ctx: CUcontext,
|
cu_ctx: CUcontext,
|
||||||
key: *mut c_void,
|
key: *mut c_void,
|
||||||
) -> CUresult {
|
) -> CUresult {
|
||||||
let ctx_obj: &context::Context = FromCuda::from_cuda(&cu_ctx)?;
|
let ctx_obj: &context::Context = FromCuda::from_cuda(&cu_ctx)?;
|
||||||
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
|
ctx_obj.with_state_mut(|state: &mut context::ContextState| {
|
||||||
state.storage.remove(&(key as usize));
|
state.storage.remove(&(key as usize));
|
||||||
Ok(())
|
Ok(())
|
||||||
})?;
|
})?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn context_local_storage_get(
|
unsafe extern "system" fn context_local_storage_get(
|
||||||
value: *mut *mut c_void,
|
value: *mut *mut c_void,
|
||||||
cu_ctx: CUcontext,
|
cu_ctx: CUcontext,
|
||||||
key: *mut c_void,
|
key: *mut c_void,
|
||||||
) -> CUresult {
|
) -> CUresult {
|
||||||
let mut _ctx: CUcontext;
|
let mut _ctx: CUcontext;
|
||||||
if cu_ctx.0 == ptr::null_mut() {
|
if cu_ctx.0 == ptr::null_mut() {
|
||||||
_ctx = context::get_current_context()?;
|
_ctx = context::get_current_context()?;
|
||||||
} else {
|
} else {
|
||||||
_ctx = cu_ctx
|
_ctx = cu_ctx
|
||||||
};
|
};
|
||||||
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
|
let ctx_obj: &context::Context = FromCuda::from_cuda(&_ctx)?;
|
||||||
ctx_obj.with_state(|state: &context::ContextState| {
|
ctx_obj.with_state(|state: &context::ContextState| {
|
||||||
match state.storage.get(&(key as usize)) {
|
match state.storage.get(&(key as usize)) {
|
||||||
Some(data) => *value = data.value as *mut c_void,
|
Some(data) => *value = data.value as *mut c_void,
|
||||||
None => return CUresult::ERROR_INVALID_HANDLE,
|
None => return CUresult::ERROR_INVALID_HANDLE,
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
})?;
|
})?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn ctx_create_v2_bypass(
|
unsafe extern "system" fn ctx_create_v2_bypass(
|
||||||
_pctx: *mut cuda_types::cuda::CUcontext,
|
_pctx: *mut cuda_types::cuda::CUcontext,
|
||||||
_flags: ::std::os::raw::c_uint,
|
_flags: ::std::os::raw::c_uint,
|
||||||
_dev: cuda_types::cuda::CUdevice,
|
_dev: cuda_types::cuda::CUdevice,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn heap_alloc(
|
unsafe extern "system" fn heap_alloc(
|
||||||
_heap_alloc_record_ptr: *mut *const std::ffi::c_void,
|
_heap_alloc_record_ptr: *mut *const std::ffi::c_void,
|
||||||
_arg2: usize,
|
_arg2: usize,
|
||||||
_arg3: usize,
|
_arg3: usize,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn heap_free(
|
unsafe extern "system" fn heap_free(
|
||||||
_heap_alloc_record_ptr: *const std::ffi::c_void,
|
_heap_alloc_record_ptr: *const std::ffi::c_void,
|
||||||
_arg2: *mut usize,
|
_arg2: *mut usize,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn device_get_attribute_ext(
|
unsafe extern "system" fn device_get_attribute_ext(
|
||||||
_dev: cuda_types::cuda::CUdevice,
|
_dev: cuda_types::cuda::CUdevice,
|
||||||
_attribute: std::ffi::c_uint,
|
_attribute: std::ffi::c_uint,
|
||||||
_unknown: std::ffi::c_int,
|
_unknown: std::ffi::c_int,
|
||||||
_result: *mut [usize; 2],
|
_result: *mut [usize; 2],
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn device_get_something(
|
unsafe extern "system" fn device_get_something(
|
||||||
_result: *mut std::ffi::c_uchar,
|
_result: *mut std::ffi::c_uchar,
|
||||||
_dev: cuda_types::cuda::CUdevice,
|
_dev: cuda_types::cuda::CUdevice,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn integrity_check(
|
unsafe extern "system" fn integrity_check(
|
||||||
version: u32,
|
version: u32,
|
||||||
unix_seconds: u64,
|
unix_seconds: u64,
|
||||||
result: *mut [u64; 2],
|
result: *mut [u64; 2],
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
let current_process = std::process::id();
|
let current_process = std::process::id();
|
||||||
let current_thread = os::current_thread();
|
let current_thread = os::current_thread();
|
||||||
|
|
||||||
let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast();
|
let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast();
|
||||||
let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast();
|
let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast();
|
||||||
let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1];
|
let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1];
|
||||||
|
|
||||||
let devices = get_device_hash_info()?;
|
let devices = get_device_hash_info()?;
|
||||||
let device_count = devices.len() as u32;
|
let device_count = devices.len() as u32;
|
||||||
let get_device = |dev| devices[dev as usize];
|
let get_device = |dev| devices[dev as usize];
|
||||||
|
|
||||||
let hash = ::dark_api::integrity_check(
|
let hash = ::dark_api::integrity_check(
|
||||||
version,
|
version,
|
||||||
unix_seconds,
|
unix_seconds,
|
||||||
cuda_types::cuda::CUDA_VERSION,
|
cuda_types::cuda::CUDA_VERSION,
|
||||||
current_process,
|
current_process,
|
||||||
current_thread,
|
current_thread,
|
||||||
integrity_check_table,
|
integrity_check_table,
|
||||||
cudart_table,
|
cudart_table,
|
||||||
fn_address,
|
fn_address,
|
||||||
device_count,
|
device_count,
|
||||||
get_device,
|
get_device,
|
||||||
);
|
);
|
||||||
*result = hash;
|
*result = hash;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn context_check(
|
unsafe extern "system" fn context_check(
|
||||||
_ctx_in: cuda_types::cuda::CUcontext,
|
_ctx_in: cuda_types::cuda::CUcontext,
|
||||||
result1: *mut u32,
|
result1: *mut u32,
|
||||||
_result2: *mut *const std::ffi::c_void,
|
_result2: *mut *const std::ffi::c_void,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
*result1 = 0;
|
*result1 = 0;
|
||||||
CUresult::SUCCESS
|
CUresult::SUCCESS
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn check_fn3() -> u32 {
|
unsafe extern "system" fn check_fn3() -> u32 {
|
||||||
0
|
0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
|
fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
|
||||||
let mut device_count = 0;
|
let mut device_count = 0;
|
||||||
device::get_count(&mut device_count)?;
|
device::get_count(&mut device_count)?;
|
||||||
|
|
||||||
(0..device_count)
|
(0..device_count)
|
||||||
.map(|dev| {
|
.map(|dev| {
|
||||||
let mut guid = CUuuid_st { bytes: [0; 16] };
|
let mut guid = CUuuid_st { bytes: [0; 16] };
|
||||||
unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? };
|
unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? };
|
||||||
|
|
||||||
let mut pci_domain = 0;
|
let mut pci_domain = 0;
|
||||||
device::get_attribute(
|
device::get_attribute(
|
||||||
&mut pci_domain,
|
&mut pci_domain,
|
||||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID,
|
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID,
|
||||||
dev,
|
dev,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let mut pci_bus = 0;
|
let mut pci_bus = 0;
|
||||||
device::get_attribute(
|
device::get_attribute(
|
||||||
&mut pci_bus,
|
&mut pci_bus,
|
||||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID,
|
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID,
|
||||||
dev,
|
dev,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let mut pci_device = 0;
|
let mut pci_device = 0;
|
||||||
device::get_attribute(
|
device::get_attribute(
|
||||||
&mut pci_device,
|
&mut pci_device,
|
||||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID,
|
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID,
|
||||||
dev,
|
dev,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(::dark_api::DeviceHashinfo {
|
Ok(::dark_api::DeviceHashinfo {
|
||||||
guid,
|
guid,
|
||||||
pci_domain,
|
pci_domain,
|
||||||
pci_bus,
|
pci_bus,
|
||||||
pci_device,
|
pci_device,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable =
|
static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable =
|
||||||
::dark_api::cuda::CudaDarkApiGlobalTable::new::<DarkApi>();
|
::dark_api::cuda::CudaDarkApiGlobalTable::new::<DarkApi>();
|
||||||
|
|
||||||
pub(crate) fn get_export_table(
|
pub(crate) fn get_export_table(
|
||||||
pp_export_table: &mut *const ::core::ffi::c_void,
|
pp_export_table: &mut *const ::core::ffi::c_void,
|
||||||
p_export_table_id: &CUuuid,
|
p_export_table_id: &CUuuid,
|
||||||
) -> CUresult {
|
) -> CUresult {
|
||||||
if let Some(table) = EXPORT_TABLE.get(p_export_table_id) {
|
if let Some(table) = EXPORT_TABLE.get(p_export_table_id) {
|
||||||
*pp_export_table = table.start();
|
*pp_export_table = table.start();
|
||||||
cuda_types::cuda::CUresult::SUCCESS
|
cuda_types::cuda::CUresult::SUCCESS
|
||||||
} else {
|
} else {
|
||||||
cuda_types::cuda::CUresult::ERROR_INVALID_VALUE
|
cuda_types::cuda::CUresult::ERROR_INVALID_VALUE
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
|
pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
|
||||||
*version = cuda_types::cuda::CUDA_VERSION as i32;
|
*version = cuda_types::cuda::CUDA_VERSION as i32;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn get_proc_address(
|
pub(crate) unsafe fn get_proc_address(
|
||||||
symbol: &CStr,
|
symbol: &CStr,
|
||||||
pfn: &mut *mut ::core::ffi::c_void,
|
pfn: &mut *mut ::core::ffi::c_void,
|
||||||
cuda_version: ::core::ffi::c_int,
|
cuda_version: ::core::ffi::c_int,
|
||||||
flags: cuda_types::cuda::cuuint64_t,
|
flags: cuda_types::cuda::cuuint64_t,
|
||||||
) -> CUresult {
|
) -> CUresult {
|
||||||
get_proc_address_v2(symbol, pfn, cuda_version, flags, None)
|
get_proc_address_v2(symbol, pfn, cuda_version, flags, None)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn get_proc_address_v2(
|
pub(crate) unsafe fn get_proc_address_v2(
|
||||||
symbol: &CStr,
|
symbol: &CStr,
|
||||||
pfn: &mut *mut ::core::ffi::c_void,
|
pfn: &mut *mut ::core::ffi::c_void,
|
||||||
cuda_version: ::core::ffi::c_int,
|
cuda_version: ::core::ffi::c_int,
|
||||||
flags: cuda_types::cuda::cuuint64_t,
|
flags: cuda_types::cuda::cuuint64_t,
|
||||||
symbol_status: Option<&mut cuda_types::cuda::CUdriverProcAddressQueryResult>,
|
symbol_status: Option<&mut cuda_types::cuda::CUdriverProcAddressQueryResult>,
|
||||||
) -> CUresult {
|
) -> CUresult {
|
||||||
// This implementation is mostly the same as cuGetProcAddress_v2 in zluda_dump. We may want to factor out the duplication at some point.
|
// This implementation is mostly the same as cuGetProcAddress_v2 in zluda_dump. We may want to factor out the duplication at some point.
|
||||||
fn raw_match(name: &[u8], flag: u64, version: i32) -> *mut ::core::ffi::c_void {
|
fn raw_match(name: &[u8], flag: u64, version: i32) -> *mut ::core::ffi::c_void {
|
||||||
use crate::*;
|
use crate::*;
|
||||||
include!("../../../zluda_bindgen/src/process_table.rs")
|
include!("../../../zluda_bindgen/src/process_table.rs")
|
||||||
}
|
}
|
||||||
let fn_ptr = raw_match(symbol.to_bytes(), flags, cuda_version);
|
let fn_ptr = raw_match(symbol.to_bytes(), flags, cuda_version);
|
||||||
match fn_ptr as usize {
|
match fn_ptr as usize {
|
||||||
0 => {
|
0 => {
|
||||||
if let Some(symbol_status) = symbol_status {
|
if let Some(symbol_status) = symbol_status {
|
||||||
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND;
|
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND;
|
||||||
}
|
}
|
||||||
*pfn = ptr::null_mut();
|
*pfn = ptr::null_mut();
|
||||||
CUresult::ERROR_NOT_FOUND
|
CUresult::ERROR_NOT_FOUND
|
||||||
}
|
}
|
||||||
usize::MAX => {
|
usize::MAX => {
|
||||||
if let Some(symbol_status) = symbol_status {
|
if let Some(symbol_status) = symbol_status {
|
||||||
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT;
|
*symbol_status = cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT;
|
||||||
}
|
}
|
||||||
*pfn = ptr::null_mut();
|
*pfn = ptr::null_mut();
|
||||||
CUresult::ERROR_NOT_FOUND
|
CUresult::ERROR_NOT_FOUND
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
if let Some(symbol_status) = symbol_status {
|
if let Some(symbol_status) = symbol_status {
|
||||||
*symbol_status =
|
*symbol_status =
|
||||||
cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SUCCESS;
|
cuda_types::cuda::CUdriverProcAddressQueryResult::CU_GET_PROC_ADDRESS_SUCCESS;
|
||||||
}
|
}
|
||||||
*pfn = fn_ptr;
|
*pfn = fn_ptr;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn profiler_start() -> CUresult {
|
pub(crate) fn profiler_start() -> CUresult {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn profiler_stop() -> CUresult {
|
pub(crate) fn profiler_stop() -> CUresult {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
// TODO: remove duplication with zluda_dump
|
// TODO: remove duplication with zluda_dump
|
||||||
#[link(name = "pthread")]
|
#[link(name = "pthread")]
|
||||||
unsafe extern "C" {
|
unsafe extern "C" {
|
||||||
fn pthread_self() -> std::os::unix::thread::RawPthread;
|
fn pthread_self() -> std::os::unix::thread::RawPthread;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn current_thread() -> u32 {
|
pub(crate) fn current_thread() -> u32 {
|
||||||
(unsafe { pthread_self() }) as u32
|
(unsafe { pthread_self() }) as u32
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
// TODO: remove duplication with zluda_dump
|
// TODO: remove duplication with zluda_dump
|
||||||
#[link(name = "kernel32")]
|
#[link(name = "kernel32")]
|
||||||
unsafe extern "system" {
|
unsafe extern "system" {
|
||||||
fn GetCurrentThreadId() -> u32;
|
fn GetCurrentThreadId() -> u32;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn current_thread() -> u32 {
|
pub(crate) fn current_thread() -> u32 {
|
||||||
unsafe { GetCurrentThreadId() }
|
unsafe { GetCurrentThreadId() }
|
||||||
}
|
}
|
||||||
|
@ -1,124 +1,124 @@
|
|||||||
use crate::os;
|
use crate::os;
|
||||||
use crate::{CudaFunctionName, ErrorEntry};
|
use crate::{CudaFunctionName, ErrorEntry};
|
||||||
use cuda_types::cuda::*;
|
use cuda_types::cuda::*;
|
||||||
use rustc_hash::FxHashMap;
|
use rustc_hash::FxHashMap;
|
||||||
use std::cell::RefMut;
|
use std::cell::RefMut;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use std::{collections::hash_map, ffi::c_void, mem};
|
use std::{collections::hash_map, ffi::c_void, mem};
|
||||||
|
|
||||||
pub(crate) struct DarkApiState2 {
|
pub(crate) struct DarkApiState2 {
|
||||||
// Key is Box<CUuuid, because thunk reporting unknown export table needs a
|
// Key is Box<CUuuid, because thunk reporting unknown export table needs a
|
||||||
// stable memory location for the guid
|
// stable memory location for the guid
|
||||||
pub(crate) overrides: FxHashMap<Box<CUuuidWrapper>, (*const *const c_void, Vec<*const c_void>)>,
|
pub(crate) overrides: FxHashMap<Box<CUuuidWrapper>, (*const *const c_void, Vec<*const c_void>)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl Send for DarkApiState2 {}
|
unsafe impl Send for DarkApiState2 {}
|
||||||
unsafe impl Sync for DarkApiState2 {}
|
unsafe impl Sync for DarkApiState2 {}
|
||||||
|
|
||||||
impl DarkApiState2 {
|
impl DarkApiState2 {
|
||||||
pub(crate) fn new() -> Self {
|
pub(crate) fn new() -> Self {
|
||||||
DarkApiState2 {
|
DarkApiState2 {
|
||||||
overrides: FxHashMap::default(),
|
overrides: FxHashMap::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn override_export_table(
|
pub(crate) fn override_export_table(
|
||||||
&mut self,
|
&mut self,
|
||||||
known_exports: &::dark_api::cuda::CudaDarkApiGlobalTable,
|
known_exports: &::dark_api::cuda::CudaDarkApiGlobalTable,
|
||||||
original_export_table: *const *const c_void,
|
original_export_table: *const *const c_void,
|
||||||
guid: &CUuuid_st,
|
guid: &CUuuid_st,
|
||||||
) -> (*const *const c_void, Option<ErrorEntry>) {
|
) -> (*const *const c_void, Option<ErrorEntry>) {
|
||||||
let entry = match self.overrides.entry(Box::new(CUuuidWrapper(*guid))) {
|
let entry = match self.overrides.entry(Box::new(CUuuidWrapper(*guid))) {
|
||||||
hash_map::Entry::Occupied(entry) => {
|
hash_map::Entry::Occupied(entry) => {
|
||||||
let (_, override_table) = entry.get();
|
let (_, override_table) = entry.get();
|
||||||
return (override_table.as_ptr(), None);
|
return (override_table.as_ptr(), None);
|
||||||
}
|
}
|
||||||
hash_map::Entry::Vacant(entry) => entry,
|
hash_map::Entry::Vacant(entry) => entry,
|
||||||
};
|
};
|
||||||
let mut error = None;
|
let mut error = None;
|
||||||
let byte_size: usize = unsafe { *(original_export_table.cast::<usize>()) };
|
let byte_size: usize = unsafe { *(original_export_table.cast::<usize>()) };
|
||||||
// Some export tables don't start with a byte count, but directly with a
|
// Some export tables don't start with a byte count, but directly with a
|
||||||
// pointer, and are instead terminated by 0 or MAX
|
// pointer, and are instead terminated by 0 or MAX
|
||||||
let export_functions_start_idx;
|
let export_functions_start_idx;
|
||||||
let export_functions_size;
|
let export_functions_size;
|
||||||
if byte_size > 0x10000 {
|
if byte_size > 0x10000 {
|
||||||
export_functions_start_idx = 0;
|
export_functions_start_idx = 0;
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
loop {
|
loop {
|
||||||
let current_ptr = unsafe { original_export_table.add(i) };
|
let current_ptr = unsafe { original_export_table.add(i) };
|
||||||
let current_ptr_numeric = unsafe { *current_ptr } as usize;
|
let current_ptr_numeric = unsafe { *current_ptr } as usize;
|
||||||
if current_ptr_numeric == 0usize || current_ptr_numeric == usize::MAX {
|
if current_ptr_numeric == 0usize || current_ptr_numeric == usize::MAX {
|
||||||
export_functions_size = i;
|
export_functions_size = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
i += 1;
|
i += 1;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
export_functions_start_idx = 1;
|
export_functions_start_idx = 1;
|
||||||
export_functions_size = byte_size / mem::size_of::<usize>();
|
export_functions_size = byte_size / mem::size_of::<usize>();
|
||||||
}
|
}
|
||||||
let our_functions = known_exports.get(guid);
|
let our_functions = known_exports.get(guid);
|
||||||
if let Some(ref our_functions) = our_functions {
|
if let Some(ref our_functions) = our_functions {
|
||||||
if our_functions.len() != export_functions_size {
|
if our_functions.len() != export_functions_size {
|
||||||
error = Some(ErrorEntry::UnexpectedExportTableSize {
|
error = Some(ErrorEntry::UnexpectedExportTableSize {
|
||||||
expected: our_functions.len(),
|
expected: our_functions.len(),
|
||||||
computed: export_functions_size,
|
computed: export_functions_size,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let mut override_table =
|
let mut override_table =
|
||||||
unsafe { std::slice::from_raw_parts(original_export_table, export_functions_size) }
|
unsafe { std::slice::from_raw_parts(original_export_table, export_functions_size) }
|
||||||
.to_vec();
|
.to_vec();
|
||||||
for i in export_functions_start_idx..export_functions_size {
|
for i in export_functions_start_idx..export_functions_size {
|
||||||
let current_fn = (|| {
|
let current_fn = (|| {
|
||||||
if let Some(ref our_functions) = our_functions {
|
if let Some(ref our_functions) = our_functions {
|
||||||
if let Some(fn_) = our_functions.get_fn(i) {
|
if let Some(fn_) = our_functions.get_fn(i) {
|
||||||
return fn_;
|
return fn_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
os::get_thunk(
|
os::get_thunk(
|
||||||
override_table[i],
|
override_table[i],
|
||||||
Self::report_unknown_export_table_call,
|
Self::report_unknown_export_table_call,
|
||||||
std::ptr::from_ref(entry.key().as_ref()).cast(),
|
std::ptr::from_ref(entry.key().as_ref()).cast(),
|
||||||
i,
|
i,
|
||||||
)
|
)
|
||||||
})();
|
})();
|
||||||
override_table[i] = current_fn;
|
override_table[i] = current_fn;
|
||||||
}
|
}
|
||||||
(
|
(
|
||||||
entry
|
entry
|
||||||
.insert((original_export_table, override_table))
|
.insert((original_export_table, override_table))
|
||||||
.1
|
.1
|
||||||
.as_ptr(),
|
.as_ptr(),
|
||||||
error,
|
error,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn report_unknown_export_table_call(guid: &CUuuid, index: usize) {
|
unsafe extern "system" fn report_unknown_export_table_call(guid: &CUuuid, index: usize) {
|
||||||
let global_state = crate::GLOBAL_STATE2.lock();
|
let global_state = crate::GLOBAL_STATE2.lock();
|
||||||
let global_state_ref_cell = &*global_state;
|
let global_state_ref_cell = &*global_state;
|
||||||
let mut global_state_ref_mut = global_state_ref_cell.borrow_mut();
|
let mut global_state_ref_mut = global_state_ref_cell.borrow_mut();
|
||||||
let global_state = &mut *global_state_ref_mut;
|
let global_state = &mut *global_state_ref_mut;
|
||||||
let log_guard = crate::OuterCallGuard {
|
let log_guard = crate::OuterCallGuard {
|
||||||
writer: &mut global_state.log_writer,
|
writer: &mut global_state.log_writer,
|
||||||
log_root: &global_state.log_stack,
|
log_root: &global_state.log_stack,
|
||||||
};
|
};
|
||||||
{
|
{
|
||||||
let mut logger = RefMut::map(global_state.log_stack.borrow_mut(), |log_stack| {
|
let mut logger = RefMut::map(global_state.log_stack.borrow_mut(), |log_stack| {
|
||||||
log_stack.enter()
|
log_stack.enter()
|
||||||
});
|
});
|
||||||
logger.name = CudaFunctionName::Dark { guid: *guid, index };
|
logger.name = CudaFunctionName::Dark { guid: *guid, index };
|
||||||
};
|
};
|
||||||
drop(log_guard);
|
drop(log_guard);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Eq, PartialEq)]
|
#[derive(Eq, PartialEq)]
|
||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
pub(crate) struct CUuuidWrapper(pub CUuuid);
|
pub(crate) struct CUuuidWrapper(pub CUuuid);
|
||||||
|
|
||||||
impl Hash for CUuuidWrapper {
|
impl Hash for CUuuidWrapper {
|
||||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||||
self.0.bytes.hash(state);
|
self.0.bytes.hash(state);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,81 +1,81 @@
|
|||||||
use cuda_types::cuda::CUuuid;
|
use cuda_types::cuda::CUuuid;
|
||||||
use std::ffi::{c_void, CStr, CString};
|
use std::ffi::{c_void, CStr, CString};
|
||||||
use std::mem;
|
use std::mem;
|
||||||
|
|
||||||
pub(crate) const LIBCUDA_DEFAULT_PATH: &str = "/usr/lib/x86_64-linux-gnu/libcuda.so.1";
|
pub(crate) const LIBCUDA_DEFAULT_PATH: &str = "/usr/lib/x86_64-linux-gnu/libcuda.so.1";
|
||||||
|
|
||||||
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
|
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
|
||||||
let libcuda_path = CString::new(libcuda_path).unwrap();
|
let libcuda_path = CString::new(libcuda_path).unwrap();
|
||||||
libc::dlopen(
|
libc::dlopen(
|
||||||
libcuda_path.as_ptr() as *const _,
|
libcuda_path.as_ptr() as *const _,
|
||||||
libc::RTLD_LOCAL | libc::RTLD_NOW,
|
libc::RTLD_LOCAL | libc::RTLD_NOW,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
|
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
|
||||||
libc::dlsym(handle, func.as_ptr() as *const _)
|
libc::dlsym(handle, func.as_ptr() as *const _)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! os_log {
|
macro_rules! os_log {
|
||||||
($format:tt) => {
|
($format:tt) => {
|
||||||
{
|
{
|
||||||
eprintln!("[ZLUDA_DUMP] {}", format!($format));
|
eprintln!("[ZLUDA_DUMP] {}", format!($format));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
($format:tt, $($obj: expr),+) => {
|
($format:tt, $($obj: expr),+) => {
|
||||||
{
|
{
|
||||||
eprintln!("[ZLUDA_DUMP] {}", format!($format, $($obj,)+));
|
eprintln!("[ZLUDA_DUMP] {}", format!($format, $($obj,)+));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
//RDI, RSI, RDX, RCX, R8, R9
|
//RDI, RSI, RDX, RCX, R8, R9
|
||||||
#[cfg(target_arch = "x86_64")]
|
#[cfg(target_arch = "x86_64")]
|
||||||
pub fn get_thunk(
|
pub fn get_thunk(
|
||||||
original_fn: *const c_void,
|
original_fn: *const c_void,
|
||||||
report_fn: unsafe extern "system" fn(&CUuuid, usize),
|
report_fn: unsafe extern "system" fn(&CUuuid, usize),
|
||||||
guid: *const CUuuid,
|
guid: *const CUuuid,
|
||||||
idx: usize,
|
idx: usize,
|
||||||
) -> *const c_void {
|
) -> *const c_void {
|
||||||
use dynasmrt::{dynasm, DynasmApi};
|
use dynasmrt::{dynasm, DynasmApi};
|
||||||
let mut ops = dynasmrt::x64::Assembler::new().unwrap();
|
let mut ops = dynasmrt::x64::Assembler::new().unwrap();
|
||||||
let start = ops.offset();
|
let start = ops.offset();
|
||||||
dynasm!(ops
|
dynasm!(ops
|
||||||
// stack alignment
|
// stack alignment
|
||||||
; sub rsp, 8
|
; sub rsp, 8
|
||||||
; push rdi
|
; push rdi
|
||||||
; push rsi
|
; push rsi
|
||||||
; push rdx
|
; push rdx
|
||||||
; push rcx
|
; push rcx
|
||||||
; push r8
|
; push r8
|
||||||
; push r9
|
; push r9
|
||||||
; mov rdi, QWORD guid as i64
|
; mov rdi, QWORD guid as i64
|
||||||
; mov rsi, QWORD idx as i64
|
; mov rsi, QWORD idx as i64
|
||||||
; mov rax, QWORD report_fn as i64
|
; mov rax, QWORD report_fn as i64
|
||||||
; call rax
|
; call rax
|
||||||
; pop r9
|
; pop r9
|
||||||
; pop r8
|
; pop r8
|
||||||
; pop rcx
|
; pop rcx
|
||||||
; pop rdx
|
; pop rdx
|
||||||
; pop rsi
|
; pop rsi
|
||||||
; pop rdi
|
; pop rdi
|
||||||
; add rsp, 8
|
; add rsp, 8
|
||||||
; mov rax, QWORD original_fn as i64
|
; mov rax, QWORD original_fn as i64
|
||||||
; jmp rax
|
; jmp rax
|
||||||
; int 3
|
; int 3
|
||||||
);
|
);
|
||||||
let exe_buf = ops.finalize().unwrap();
|
let exe_buf = ops.finalize().unwrap();
|
||||||
let result_fn = exe_buf.ptr(start);
|
let result_fn = exe_buf.ptr(start);
|
||||||
mem::forget(exe_buf);
|
mem::forget(exe_buf);
|
||||||
result_fn as *const _
|
result_fn as *const _
|
||||||
}
|
}
|
||||||
|
|
||||||
#[link(name = "pthread")]
|
#[link(name = "pthread")]
|
||||||
unsafe extern "C" {
|
unsafe extern "C" {
|
||||||
fn pthread_self() -> std::os::unix::thread::RawPthread;
|
fn pthread_self() -> std::os::unix::thread::RawPthread;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn current_thread() -> u32 {
|
pub(crate) fn current_thread() -> u32 {
|
||||||
(unsafe { pthread_self() }) as u32
|
(unsafe { pthread_self() }) as u32
|
||||||
}
|
}
|
||||||
|
@ -1,190 +1,190 @@
|
|||||||
use std::{
|
use std::{
|
||||||
ffi::{c_void, CStr},
|
ffi::{c_void, CStr},
|
||||||
mem, ptr,
|
mem, ptr,
|
||||||
sync::LazyLock,
|
sync::LazyLock,
|
||||||
};
|
};
|
||||||
|
|
||||||
use std::os::windows::io::AsRawHandle;
|
use std::os::windows::io::AsRawHandle;
|
||||||
use winapi::{
|
use winapi::{
|
||||||
shared::minwindef::{FARPROC, HMODULE},
|
shared::minwindef::{FARPROC, HMODULE},
|
||||||
um::debugapi::OutputDebugStringA,
|
um::debugapi::OutputDebugStringA,
|
||||||
um::libloaderapi::{GetProcAddress, LoadLibraryW},
|
um::libloaderapi::{GetProcAddress, LoadLibraryW},
|
||||||
};
|
};
|
||||||
|
|
||||||
use cuda_types::cuda::CUuuid;
|
use cuda_types::cuda::CUuuid;
|
||||||
|
|
||||||
pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll";
|
pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll";
|
||||||
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
|
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
|
||||||
const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0";
|
const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0";
|
||||||
|
|
||||||
static PLATFORM_LIBRARY: LazyLock<PlatformLibrary> =
|
static PLATFORM_LIBRARY: LazyLock<PlatformLibrary> =
|
||||||
LazyLock::new(|| unsafe { PlatformLibrary::new() });
|
LazyLock::new(|| unsafe { PlatformLibrary::new() });
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
struct PlatformLibrary {
|
struct PlatformLibrary {
|
||||||
LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE,
|
LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE,
|
||||||
GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC,
|
GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PlatformLibrary {
|
impl PlatformLibrary {
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
unsafe fn new() -> Self {
|
unsafe fn new() -> Self {
|
||||||
let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() {
|
let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() {
|
||||||
None => (
|
None => (
|
||||||
LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE,
|
LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE,
|
||||||
mem::transmute(
|
mem::transmute(
|
||||||
GetProcAddress
|
GetProcAddress
|
||||||
as unsafe extern "system" fn(
|
as unsafe extern "system" fn(
|
||||||
hModule: HMODULE,
|
hModule: HMODULE,
|
||||||
lpProcName: *const i8,
|
lpProcName: *const i8,
|
||||||
) -> FARPROC,
|
) -> FARPROC,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
Some(zluda_with) => (
|
Some(zluda_with) => (
|
||||||
mem::transmute(GetProcAddress(
|
mem::transmute(GetProcAddress(
|
||||||
zluda_with,
|
zluda_with,
|
||||||
LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _,
|
LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _,
|
||||||
)),
|
)),
|
||||||
mem::transmute(GetProcAddress(
|
mem::transmute(GetProcAddress(
|
||||||
zluda_with,
|
zluda_with,
|
||||||
GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _,
|
GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _,
|
||||||
)),
|
)),
|
||||||
),
|
),
|
||||||
};
|
};
|
||||||
PlatformLibrary {
|
PlatformLibrary {
|
||||||
LoadLibraryW,
|
LoadLibraryW,
|
||||||
GetProcAddress,
|
GetProcAddress,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn get_detourer_module() -> Option<HMODULE> {
|
unsafe fn get_detourer_module() -> Option<HMODULE> {
|
||||||
let mut module = ptr::null_mut();
|
let mut module = ptr::null_mut();
|
||||||
loop {
|
loop {
|
||||||
module = detours_sys::DetourEnumerateModules(module);
|
module = detours_sys::DetourEnumerateModules(module);
|
||||||
if module == ptr::null_mut() {
|
if module == ptr::null_mut() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
|
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
|
||||||
if payload != ptr::null_mut() {
|
if payload != ptr::null_mut() {
|
||||||
return Some(module as _);
|
return Some(module as _);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
|
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
|
||||||
let libcuda_path_uf16 = libcuda_path
|
let libcuda_path_uf16 = libcuda_path
|
||||||
.encode_utf16()
|
.encode_utf16()
|
||||||
.chain(std::iter::once(0))
|
.chain(std::iter::once(0))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
(PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _
|
(PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
|
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
|
||||||
(PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _
|
(PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Formats its arguments with `format!` and hands the resulting `String` to
/// `crate::os::__log_impl`.
///
/// Accepts a format token alone, or a format token followed by one or more
/// comma-separated expressions (no trailing comma).
#[macro_export]
macro_rules! os_log {
    ($format:tt $(, $obj:expr)*) => {{
        crate::os::__log_impl(format!($format $(, $obj)*));
    }};
}
|
||||||
|
|
||||||
pub fn __log_impl(s: String) {
|
pub fn __log_impl(s: String) {
|
||||||
let log_to_stderr = std::io::stderr().as_raw_handle() != ptr::null_mut();
|
let log_to_stderr = std::io::stderr().as_raw_handle() != ptr::null_mut();
|
||||||
if log_to_stderr {
|
if log_to_stderr {
|
||||||
eprintln!("[ZLUDA_DUMP] {}", s);
|
eprintln!("[ZLUDA_DUMP] {}", s);
|
||||||
} else {
|
} else {
|
||||||
let mut win_str = String::with_capacity("[ZLUDA_DUMP] ".len() + s.len() + 2);
|
let mut win_str = String::with_capacity("[ZLUDA_DUMP] ".len() + s.len() + 2);
|
||||||
win_str.push_str("[ZLUDA_DUMP] ");
|
win_str.push_str("[ZLUDA_DUMP] ");
|
||||||
win_str.push_str(&s);
|
win_str.push_str(&s);
|
||||||
win_str.push_str("\n\0");
|
win_str.push_str("\n\0");
|
||||||
unsafe { OutputDebugStringA(win_str.as_ptr() as *const _) };
|
unsafe { OutputDebugStringA(win_str.as_ptr() as *const _) };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Assembles a small x86 trampoline at runtime and returns its entry address.
///
/// The emitted code pushes `idx` and `guid`, calls `report_fn`, and then jumps
/// to `original_fn`. The buffer holding the machine code is passed to
/// `mem::forget`, so it is never released.
///
/// Panics if the assembler cannot be created or finalized.
#[cfg(target_arch = "x86")]
pub fn get_thunk(
    original_fn: *const c_void,
    report_fn: unsafe extern "system" fn(&CUuuid, usize),
    guid: *const CUuuid,
    idx: usize,
) -> *const c_void {
    use dynasmrt::{dynasm, DynasmApi};
    let mut assembler = dynasmrt::x86::Assembler::new().unwrap();
    let entry = assembler.offset();
    dynasm!(assembler
        ; .arch x86
        ; push idx as i32
        ; push guid as i32
        ; mov eax, report_fn as i32
        ; call eax
        ; mov eax, original_fn as i32
        ; jmp eax
        ; int 3
    );
    let code = assembler.finalize().unwrap();
    let thunk = code.ptr(entry);
    // Deliberately leak the buffer: dropping it would free the code that the
    // returned pointer refers to.
    mem::forget(code);
    thunk as *const _
}
|
||||||
|
|
||||||
//RCX, RDX, R8, R9
|
//RCX, RDX, R8, R9
|
||||||
#[cfg(target_arch = "x86_64")]
|
#[cfg(target_arch = "x86_64")]
|
||||||
pub fn get_thunk(
|
pub fn get_thunk(
|
||||||
original_fn: *const c_void,
|
original_fn: *const c_void,
|
||||||
report_fn: unsafe extern "system" fn(&CUuuid, usize),
|
report_fn: unsafe extern "system" fn(&CUuuid, usize),
|
||||||
guid: *const CUuuid,
|
guid: *const CUuuid,
|
||||||
idx: usize,
|
idx: usize,
|
||||||
) -> *const c_void {
|
) -> *const c_void {
|
||||||
use dynasmrt::{dynasm, DynasmApi};
|
use dynasmrt::{dynasm, DynasmApi};
|
||||||
let mut ops = dynasmrt::x86::Assembler::new().unwrap();
|
let mut ops = dynasmrt::x86::Assembler::new().unwrap();
|
||||||
let start = ops.offset();
|
let start = ops.offset();
|
||||||
// Let's hope there's never more than 4 arguments
|
// Let's hope there's never more than 4 arguments
|
||||||
dynasm!(ops
|
dynasm!(ops
|
||||||
; .arch x64
|
; .arch x64
|
||||||
; push rbp
|
; push rbp
|
||||||
; mov rbp, rsp
|
; mov rbp, rsp
|
||||||
; push rcx
|
; push rcx
|
||||||
; push rdx
|
; push rdx
|
||||||
; push r8
|
; push r8
|
||||||
; push r9
|
; push r9
|
||||||
; mov rcx, QWORD guid as i64
|
; mov rcx, QWORD guid as i64
|
||||||
; mov rdx, QWORD idx as i64
|
; mov rdx, QWORD idx as i64
|
||||||
; mov rax, QWORD report_fn as i64
|
; mov rax, QWORD report_fn as i64
|
||||||
; call rax
|
; call rax
|
||||||
; pop r9
|
; pop r9
|
||||||
; pop r8
|
; pop r8
|
||||||
; pop rdx
|
; pop rdx
|
||||||
; pop rcx
|
; pop rcx
|
||||||
; mov rax, QWORD original_fn as i64
|
; mov rax, QWORD original_fn as i64
|
||||||
; call rax
|
; call rax
|
||||||
; pop rbp
|
; pop rbp
|
||||||
; ret
|
; ret
|
||||||
; int 3
|
; int 3
|
||||||
);
|
);
|
||||||
let exe_buf = ops.finalize().unwrap();
|
let exe_buf = ops.finalize().unwrap();
|
||||||
let result_fn = exe_buf.ptr(start);
|
let result_fn = exe_buf.ptr(start);
|
||||||
mem::forget(exe_buf);
|
mem::forget(exe_buf);
|
||||||
result_fn as *const _
|
result_fn as *const _
|
||||||
}
|
}
|
||||||
|
|
||||||
// Import of the thread-id query from kernel32.
#[link(name = "kernel32")]
unsafe extern "system" {
    fn GetCurrentThreadId() -> u32;
}

/// Returns the value `GetCurrentThreadId` reports for the calling thread.
pub(crate) fn current_thread() -> u32 {
    // SAFETY: the imported function takes no arguments; as declared above it
    // only hands back an integer.
    let thread_id = unsafe { GetCurrentThreadId() };
    thread_id
}
|
||||||
|
@ -1,334 +1,334 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
log::{self, UInt},
|
log::{self, UInt},
|
||||||
trace, ErrorEntry, FnCallLog, Settings,
|
trace, ErrorEntry, FnCallLog, Settings,
|
||||||
};
|
};
|
||||||
use cuda_types::{
|
use cuda_types::{
|
||||||
cuda::*,
|
cuda::*,
|
||||||
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper},
|
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper},
|
||||||
};
|
};
|
||||||
use dark_api::fatbin::{
|
use dark_api::fatbin::{
|
||||||
decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule,
|
decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule,
|
||||||
};
|
};
|
||||||
use rustc_hash::{FxHashMap, FxHashSet};
|
use rustc_hash::{FxHashMap, FxHashSet};
|
||||||
use std::{
|
use std::{
|
||||||
borrow::Cow,
|
borrow::Cow,
|
||||||
ffi::{c_void, CStr, CString},
|
ffi::{c_void, CStr, CString},
|
||||||
fs::{self, File},
|
fs::{self, File},
|
||||||
io::{self, Read, Write},
|
io::{self, Read, Write},
|
||||||
path::PathBuf,
|
path::PathBuf,
|
||||||
};
|
};
|
||||||
use unwrap_or::unwrap_some_or;
|
use unwrap_or::unwrap_some_or;
|
||||||
|
|
||||||
// This struct is the heart of CUDA state tracking, it:
|
// This struct is the heart of CUDA state tracking, it:
|
||||||
// * receives calls from the probes about changes to CUDA state
|
// * receives calls from the probes about changes to CUDA state
|
||||||
// * records updates to the state change
|
// * records updates to the state change
|
||||||
// * writes out relevant state change and details to disk and log
|
// * writes out relevant state change and details to disk and log
|
||||||
pub(crate) struct StateTracker {
|
pub(crate) struct StateTracker {
|
||||||
writer: DumpWriter,
|
writer: DumpWriter,
|
||||||
pub(crate) libraries: FxHashMap<CUlibrary, CodePointer>,
|
pub(crate) libraries: FxHashMap<CUlibrary, CodePointer>,
|
||||||
saved_modules: FxHashSet<CUmodule>,
|
saved_modules: FxHashSet<CUmodule>,
|
||||||
module_counter: usize,
|
module_counter: usize,
|
||||||
submodule_counter: usize,
|
submodule_counter: usize,
|
||||||
pub(crate) override_cc: Option<(u32, u32)>,
|
pub(crate) override_cc: Option<(u32, u32)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
pub(crate) struct CodePointer(pub *const c_void);
|
pub(crate) struct CodePointer(pub *const c_void);
|
||||||
|
|
||||||
unsafe impl Send for CodePointer {}
|
unsafe impl Send for CodePointer {}
|
||||||
unsafe impl Sync for CodePointer {}
|
unsafe impl Sync for CodePointer {}
|
||||||
|
|
||||||
impl StateTracker {
|
impl StateTracker {
|
||||||
pub(crate) fn new(settings: &Settings) -> Self {
|
pub(crate) fn new(settings: &Settings) -> Self {
|
||||||
StateTracker {
|
StateTracker {
|
||||||
writer: DumpWriter::new(settings.dump_dir.clone()),
|
writer: DumpWriter::new(settings.dump_dir.clone()),
|
||||||
libraries: FxHashMap::default(),
|
libraries: FxHashMap::default(),
|
||||||
saved_modules: FxHashSet::default(),
|
saved_modules: FxHashSet::default(),
|
||||||
module_counter: 0,
|
module_counter: 0,
|
||||||
submodule_counter: 0,
|
submodule_counter: 0,
|
||||||
override_cc: settings.override_cc,
|
override_cc: settings.override_cc,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn record_new_module_file(
|
pub(crate) fn record_new_module_file(
|
||||||
&mut self,
|
&mut self,
|
||||||
module: CUmodule,
|
module: CUmodule,
|
||||||
file_name: *const i8,
|
file_name: *const i8,
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
) {
|
) {
|
||||||
let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() {
|
let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() {
|
||||||
Ok(f) => f,
|
Ok(f) => f,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
fn_logger.log(log::ErrorEntry::MalformedModulePath(err));
|
fn_logger.log(log::ErrorEntry::MalformedModulePath(err));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let maybe_io_error = self.try_record_new_module_file(module, fn_logger, file_name);
|
let maybe_io_error = self.try_record_new_module_file(module, fn_logger, file_name);
|
||||||
fn_logger.log_io_error(maybe_io_error)
|
fn_logger.log_io_error(maybe_io_error)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_record_new_module_file(
|
fn try_record_new_module_file(
|
||||||
&mut self,
|
&mut self,
|
||||||
module: CUmodule,
|
module: CUmodule,
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
file_name: &str,
|
file_name: &str,
|
||||||
) -> io::Result<()> {
|
) -> io::Result<()> {
|
||||||
let mut module_file = fs::File::open(file_name)?;
|
let mut module_file = fs::File::open(file_name)?;
|
||||||
let mut read_buff = Vec::new();
|
let mut read_buff = Vec::new();
|
||||||
module_file.read_to_end(&mut read_buff)?;
|
module_file.read_to_end(&mut read_buff)?;
|
||||||
self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger);
|
self.record_new_module(module, read_buff.as_ptr() as *const _, fn_logger);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn record_new_submodule(
|
pub(crate) fn record_new_submodule(
|
||||||
&mut self,
|
&mut self,
|
||||||
module: CUmodule,
|
module: CUmodule,
|
||||||
submodule: &[u8],
|
submodule: &[u8],
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
type_: &'static str,
|
type_: &'static str,
|
||||||
) {
|
) {
|
||||||
if self.saved_modules.insert(module) {
|
if self.saved_modules.insert(module) {
|
||||||
self.module_counter += 1;
|
self.module_counter += 1;
|
||||||
self.submodule_counter = 0;
|
self.submodule_counter = 0;
|
||||||
}
|
}
|
||||||
self.submodule_counter += 1;
|
self.submodule_counter += 1;
|
||||||
fn_logger.log_io_error(self.writer.save_module(
|
fn_logger.log_io_error(self.writer.save_module(
|
||||||
self.module_counter,
|
self.module_counter,
|
||||||
Some(self.submodule_counter),
|
Some(self.submodule_counter),
|
||||||
submodule,
|
submodule,
|
||||||
type_,
|
type_,
|
||||||
));
|
));
|
||||||
if type_ == "ptx" {
|
if type_ == "ptx" {
|
||||||
match CString::new(submodule) {
|
match CString::new(submodule) {
|
||||||
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
|
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
|
||||||
Ok(submodule_cstring) => match submodule_cstring.to_str() {
|
Ok(submodule_cstring) => match submodule_cstring.to_str() {
|
||||||
Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)),
|
Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)),
|
||||||
Ok(submodule_text) => self.try_parse_and_record_kernels(
|
Ok(submodule_text) => self.try_parse_and_record_kernels(
|
||||||
fn_logger,
|
fn_logger,
|
||||||
self.module_counter,
|
self.module_counter,
|
||||||
Some(self.submodule_counter),
|
Some(self.submodule_counter),
|
||||||
submodule_text,
|
submodule_text,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn record_new_module(
|
pub(crate) fn record_new_module(
|
||||||
&mut self,
|
&mut self,
|
||||||
module: CUmodule,
|
module: CUmodule,
|
||||||
raw_image: *const c_void,
|
raw_image: *const c_void,
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
) {
|
) {
|
||||||
self.module_counter += 1;
|
self.module_counter += 1;
|
||||||
if unsafe { *(raw_image as *const [u8; 4]) } == *goblin::elf64::header::ELFMAG {
|
if unsafe { *(raw_image as *const [u8; 4]) } == *goblin::elf64::header::ELFMAG {
|
||||||
self.saved_modules.insert(module);
|
self.saved_modules.insert(module);
|
||||||
// TODO: Parse ELF and write it to disk
|
// TODO: Parse ELF and write it to disk
|
||||||
fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
||||||
module,
|
module,
|
||||||
raw_image,
|
raw_image,
|
||||||
kind: "ELF",
|
kind: "ELF",
|
||||||
})
|
})
|
||||||
} else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC {
|
} else if unsafe { *(raw_image as *const [u8; 8]) } == *goblin::archive::MAGIC {
|
||||||
self.saved_modules.insert(module);
|
self.saved_modules.insert(module);
|
||||||
// TODO: Figure out how to get size of archive module and write it to disk
|
// TODO: Figure out how to get size of archive module and write it to disk
|
||||||
fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
||||||
module,
|
module,
|
||||||
raw_image,
|
raw_image,
|
||||||
kind: "archive",
|
kind: "archive",
|
||||||
})
|
})
|
||||||
} else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC {
|
} else if unsafe { *(raw_image as *const u32) } == FatbincWrapper::MAGIC {
|
||||||
unsafe {
|
unsafe {
|
||||||
fn_logger.try_(|fn_logger| {
|
fn_logger.try_(|fn_logger| {
|
||||||
trace::record_submodules_from_wrapped_fatbin(
|
trace::record_submodules_from_wrapped_fatbin(
|
||||||
module,
|
module,
|
||||||
raw_image as *const FatbincWrapper,
|
raw_image as *const FatbincWrapper,
|
||||||
fn_logger,
|
fn_logger,
|
||||||
self,
|
self,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
self.record_module_ptx(module, raw_image, fn_logger)
|
self.record_module_ptx(module, raw_image, fn_logger)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn record_module_ptx(
|
fn record_module_ptx(
|
||||||
&mut self,
|
&mut self,
|
||||||
module: CUmodule,
|
module: CUmodule,
|
||||||
raw_image: *const c_void,
|
raw_image: *const c_void,
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
) {
|
) {
|
||||||
self.saved_modules.insert(module);
|
self.saved_modules.insert(module);
|
||||||
let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str();
|
let module_text = unsafe { CStr::from_ptr(raw_image as *const _) }.to_str();
|
||||||
let module_text = match module_text {
|
let module_text = match module_text {
|
||||||
Ok(m) => m,
|
Ok(m) => m,
|
||||||
Err(utf8_err) => {
|
Err(utf8_err) => {
|
||||||
fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err));
|
fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(utf8_err));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
fn_logger.log_io_error(self.writer.save_module(
|
fn_logger.log_io_error(self.writer.save_module(
|
||||||
self.module_counter,
|
self.module_counter,
|
||||||
None,
|
None,
|
||||||
module_text.as_bytes(),
|
module_text.as_bytes(),
|
||||||
"ptx",
|
"ptx",
|
||||||
));
|
));
|
||||||
self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text);
|
self.try_parse_and_record_kernels(fn_logger, self.module_counter, None, module_text);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_parse_and_record_kernels(
|
fn try_parse_and_record_kernels(
|
||||||
&mut self,
|
&mut self,
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
module_index: usize,
|
module_index: usize,
|
||||||
submodule_index: Option<usize>,
|
submodule_index: Option<usize>,
|
||||||
module_text: &str,
|
module_text: &str,
|
||||||
) {
|
) {
|
||||||
let errors = ptx_parser::parse_for_errors(module_text);
|
let errors = ptx_parser::parse_for_errors(module_text);
|
||||||
if !errors.is_empty() {
|
if !errors.is_empty() {
|
||||||
fn_logger.log(log::ErrorEntry::ModuleParsingError(
|
fn_logger.log(log::ErrorEntry::ModuleParsingError(
|
||||||
DumpWriter::get_file_name(module_index, submodule_index, "log"),
|
DumpWriter::get_file_name(module_index, submodule_index, "log"),
|
||||||
));
|
));
|
||||||
fn_logger.log_io_error(self.writer.save_module_error_log(
|
fn_logger.log_io_error(self.writer.save_module_error_log(
|
||||||
module_index,
|
module_index,
|
||||||
submodule_index,
|
submodule_index,
|
||||||
&*errors,
|
&*errors,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This structs writes out information about CUDA execution to the dump dir
|
// This structs writes out information about CUDA execution to the dump dir
|
||||||
struct DumpWriter {
|
struct DumpWriter {
|
||||||
dump_dir: Option<PathBuf>,
|
dump_dir: Option<PathBuf>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DumpWriter {
|
impl DumpWriter {
|
||||||
fn new(dump_dir: Option<PathBuf>) -> Self {
|
fn new(dump_dir: Option<PathBuf>) -> Self {
|
||||||
Self { dump_dir }
|
Self { dump_dir }
|
||||||
}
|
}
|
||||||
|
|
||||||
fn save_module(
|
fn save_module(
|
||||||
&self,
|
&self,
|
||||||
module_index: usize,
|
module_index: usize,
|
||||||
submodule_index: Option<usize>,
|
submodule_index: Option<usize>,
|
||||||
buffer: &[u8],
|
buffer: &[u8],
|
||||||
kind: &'static str,
|
kind: &'static str,
|
||||||
) -> io::Result<()> {
|
) -> io::Result<()> {
|
||||||
let mut dump_file = match &self.dump_dir {
|
let mut dump_file = match &self.dump_dir {
|
||||||
None => return Ok(()),
|
None => return Ok(()),
|
||||||
Some(d) => d.clone(),
|
Some(d) => d.clone(),
|
||||||
};
|
};
|
||||||
dump_file.push(Self::get_file_name(module_index, submodule_index, kind));
|
dump_file.push(Self::get_file_name(module_index, submodule_index, kind));
|
||||||
let mut file = File::create(dump_file)?;
|
let mut file = File::create(dump_file)?;
|
||||||
file.write_all(buffer)?;
|
file.write_all(buffer)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn save_module_error_log<'input>(
|
fn save_module_error_log<'input>(
|
||||||
&self,
|
&self,
|
||||||
module_index: usize,
|
module_index: usize,
|
||||||
submodule_index: Option<usize>,
|
submodule_index: Option<usize>,
|
||||||
errors: &[ptx_parser::PtxError<'input>],
|
errors: &[ptx_parser::PtxError<'input>],
|
||||||
) -> io::Result<()> {
|
) -> io::Result<()> {
|
||||||
let mut log_file = match &self.dump_dir {
|
let mut log_file = match &self.dump_dir {
|
||||||
None => return Ok(()),
|
None => return Ok(()),
|
||||||
Some(d) => d.clone(),
|
Some(d) => d.clone(),
|
||||||
};
|
};
|
||||||
log_file.push(Self::get_file_name(module_index, submodule_index, "log"));
|
log_file.push(Self::get_file_name(module_index, submodule_index, "log"));
|
||||||
let mut file = File::create(log_file)?;
|
let mut file = File::create(log_file)?;
|
||||||
for error in errors {
|
for error in errors {
|
||||||
writeln!(file, "{}", error)?;
|
writeln!(file, "{}", error)?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_file_name(module_index: usize, submodule_index: Option<usize>, kind: &str) -> String {
|
fn get_file_name(module_index: usize, submodule_index: Option<usize>, kind: &str) -> String {
|
||||||
match submodule_index {
|
match submodule_index {
|
||||||
None => {
|
None => {
|
||||||
format!("module_{:04}.{:02}", module_index, kind)
|
format!("module_{:04}.{:02}", module_index, kind)
|
||||||
}
|
}
|
||||||
Some(submodule_index) => {
|
Some(submodule_index) => {
|
||||||
format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind)
|
format!("module_{:04}_{:02}.{}", module_index, submodule_index, kind)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
|
pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
|
||||||
module: CUmodule,
|
module: CUmodule,
|
||||||
fatbinc_wrapper: *const FatbincWrapper,
|
fatbinc_wrapper: *const FatbincWrapper,
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
state: &mut StateTracker,
|
state: &mut StateTracker,
|
||||||
) -> Result<(), ErrorEntry> {
|
) -> Result<(), ErrorEntry> {
|
||||||
let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
|
let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
|
||||||
let mut submodules = fatbin.get_submodules()?;
|
let mut submodules = fatbin.get_submodules()?;
|
||||||
while let Some(current) = submodules.next()? {
|
while let Some(current) = submodules.next()? {
|
||||||
record_submodules_from_fatbin(module, current, fn_logger, state)?;
|
record_submodules_from_fatbin(module, current, fn_logger, state)?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn record_submodules_from_fatbin(
|
pub(crate) unsafe fn record_submodules_from_fatbin(
|
||||||
module: CUmodule,
|
module: CUmodule,
|
||||||
submodule: FatbinSubmodule,
|
submodule: FatbinSubmodule,
|
||||||
logger: &mut FnCallLog,
|
logger: &mut FnCallLog,
|
||||||
state: &mut StateTracker,
|
state: &mut StateTracker,
|
||||||
) -> Result<(), ErrorEntry> {
|
) -> Result<(), ErrorEntry> {
|
||||||
record_submodules(module, logger, state, submodule.get_files())?;
|
record_submodules(module, logger, state, submodule.get_files())?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn record_submodules(
|
pub(crate) unsafe fn record_submodules(
|
||||||
module: CUmodule,
|
module: CUmodule,
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
state: &mut StateTracker,
|
state: &mut StateTracker,
|
||||||
mut files: FatbinFileIterator,
|
mut files: FatbinFileIterator,
|
||||||
) -> Result<(), ErrorEntry> {
|
) -> Result<(), ErrorEntry> {
|
||||||
while let Some(file) = files.next()? {
|
while let Some(file) = files.next()? {
|
||||||
let mut payload = if file
|
let mut payload = if file
|
||||||
.header
|
.header
|
||||||
.flags
|
.flags
|
||||||
.contains(FatbinFileHeaderFlags::CompressedLz4)
|
.contains(FatbinFileHeaderFlags::CompressedLz4)
|
||||||
{
|
{
|
||||||
Cow::Owned(unwrap_some_or!(
|
Cow::Owned(unwrap_some_or!(
|
||||||
fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())),
|
fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())),
|
||||||
continue
|
continue
|
||||||
))
|
))
|
||||||
} else if file
|
} else if file
|
||||||
.header
|
.header
|
||||||
.flags
|
.flags
|
||||||
.contains(FatbinFileHeaderFlags::CompressedZstd)
|
.contains(FatbinFileHeaderFlags::CompressedZstd)
|
||||||
{
|
{
|
||||||
Cow::Owned(unwrap_some_or!(
|
Cow::Owned(unwrap_some_or!(
|
||||||
fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())),
|
fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())),
|
||||||
continue
|
continue
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
Cow::Borrowed(file.get_payload())
|
Cow::Borrowed(file.get_payload())
|
||||||
};
|
};
|
||||||
match file.header.kind {
|
match file.header.kind {
|
||||||
FatbinFileHeader::HEADER_KIND_PTX => {
|
FatbinFileHeader::HEADER_KIND_PTX => {
|
||||||
while payload.last() == Some(&0) {
|
while payload.last() == Some(&0) {
|
||||||
// remove trailing zeros
|
// remove trailing zeros
|
||||||
payload.to_mut().pop();
|
payload.to_mut().pop();
|
||||||
}
|
}
|
||||||
state.record_new_submodule(module, &*payload, fn_logger, "ptx")
|
state.record_new_submodule(module, &*payload, fn_logger, "ptx")
|
||||||
}
|
}
|
||||||
FatbinFileHeader::HEADER_KIND_ELF => {
|
FatbinFileHeader::HEADER_KIND_ELF => {
|
||||||
state.record_new_submodule(module, &*payload, fn_logger, "elf")
|
state.record_new_submodule(module, &*payload, fn_logger, "elf")
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
fn_logger.log(log::ErrorEntry::UnexpectedBinaryField {
|
fn_logger.log(log::ErrorEntry::UnexpectedBinaryField {
|
||||||
field_name: "FATBIN_FILE_HEADER_KIND",
|
field_name: "FATBIN_FILE_HEADER_KIND",
|
||||||
expected: vec![
|
expected: vec![
|
||||||
UInt::U16(FatbinFileHeader::HEADER_KIND_PTX),
|
UInt::U16(FatbinFileHeader::HEADER_KIND_PTX),
|
||||||
UInt::U16(FatbinFileHeader::HEADER_KIND_ELF),
|
UInt::U16(FatbinFileHeader::HEADER_KIND_ELF),
|
||||||
],
|
],
|
||||||
observed: UInt::U16(file.header.kind),
|
observed: UInt::U16(file.header.kind),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,81 +1,81 @@
|
|||||||
use std::{
|
use std::{
|
||||||
env::{self, VarError},
|
env::{self, VarError},
|
||||||
fs::{self, DirEntry},
|
fs::{self, DirEntry},
|
||||||
io,
|
io,
|
||||||
path::{self, PathBuf},
|
path::{self, PathBuf},
|
||||||
process::Command,
|
process::Command,
|
||||||
};
|
};
|
||||||
|
|
||||||
fn main() -> Result<(), VarError> {
|
fn main() -> Result<(), VarError> {
|
||||||
if std::env::var_os("CARGO_CFG_WINDOWS").is_none() {
|
if std::env::var_os("CARGO_CFG_WINDOWS").is_none() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
println!("cargo:rerun-if-changed=build.rs");
|
println!("cargo:rerun-if-changed=build.rs");
|
||||||
if env::var("PROFILE")? != "debug" {
|
if env::var("PROFILE")? != "debug" {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
let rustc_exe = env::var("RUSTC")?;
|
let rustc_exe = env::var("RUSTC")?;
|
||||||
let out_dir = env::var("OUT_DIR")?;
|
let out_dir = env::var("OUT_DIR")?;
|
||||||
let target = env::var("TARGET")?;
|
let target = env::var("TARGET")?;
|
||||||
let is_msvc = env::var("CARGO_CFG_TARGET_ENV")? == "msvc";
|
let is_msvc = env::var("CARGO_CFG_TARGET_ENV")? == "msvc";
|
||||||
let opt_level = env::var("OPT_LEVEL")?;
|
let opt_level = env::var("OPT_LEVEL")?;
|
||||||
let debug = str::parse::<bool>(env::var("DEBUG")?.as_str()).unwrap();
|
let debug = str::parse::<bool>(env::var("DEBUG")?.as_str()).unwrap();
|
||||||
let mut helpers_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
|
let mut helpers_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
|
||||||
helpers_dir.push("tests");
|
helpers_dir.push("tests");
|
||||||
helpers_dir.push("helpers");
|
helpers_dir.push("helpers");
|
||||||
let helpers_dir_as_string = helpers_dir.to_string_lossy();
|
let helpers_dir_as_string = helpers_dir.to_string_lossy();
|
||||||
println!("cargo:rerun-if-changed={}", helpers_dir_as_string);
|
println!("cargo:rerun-if-changed={}", helpers_dir_as_string);
|
||||||
for rust_file in fs::read_dir(&helpers_dir).unwrap().filter_map(rust_file) {
|
for rust_file in fs::read_dir(&helpers_dir).unwrap().filter_map(rust_file) {
|
||||||
let full_file_path = format!(
|
let full_file_path = format!(
|
||||||
"{}{}{}",
|
"{}{}{}",
|
||||||
helpers_dir_as_string,
|
helpers_dir_as_string,
|
||||||
path::MAIN_SEPARATOR,
|
path::MAIN_SEPARATOR,
|
||||||
rust_file
|
rust_file
|
||||||
);
|
);
|
||||||
let mut rustc_cmd = Command::new(&*rustc_exe);
|
let mut rustc_cmd = Command::new(&*rustc_exe);
|
||||||
if debug {
|
if debug {
|
||||||
rustc_cmd.arg("-g");
|
rustc_cmd.arg("-g");
|
||||||
}
|
}
|
||||||
rustc_cmd.arg(format!("-Lnative={}", helpers_dir_as_string));
|
rustc_cmd.arg(format!("-Lnative={}", helpers_dir_as_string));
|
||||||
if !is_msvc {
|
if !is_msvc {
|
||||||
// HACK ALERT
|
// HACK ALERT
|
||||||
// I have no idea why the extra library below have to be linked
|
// I have no idea why the extra library below have to be linked
|
||||||
rustc_cmd.arg(r"-lucrt");
|
rustc_cmd.arg(r"-lucrt");
|
||||||
}
|
}
|
||||||
rustc_cmd
|
rustc_cmd
|
||||||
.arg("-C")
|
.arg("-C")
|
||||||
.arg(format!("opt-level={}", opt_level))
|
.arg(format!("opt-level={}", opt_level))
|
||||||
.arg("-L")
|
.arg("-L")
|
||||||
.arg(format!("{}", out_dir))
|
.arg(format!("{}", out_dir))
|
||||||
.arg("--out-dir")
|
.arg("--out-dir")
|
||||||
.arg(format!("{}", out_dir))
|
.arg(format!("{}", out_dir))
|
||||||
.arg("--target")
|
.arg("--target")
|
||||||
.arg(format!("{}", target))
|
.arg(format!("{}", target))
|
||||||
.arg(full_file_path);
|
.arg(full_file_path);
|
||||||
assert!(rustc_cmd.status().unwrap().success());
|
assert!(rustc_cmd.status().unwrap().success());
|
||||||
}
|
}
|
||||||
std::fs::copy(
|
std::fs::copy(
|
||||||
format!(
|
format!(
|
||||||
"{}{}do_cuinit_late_clr.exe",
|
"{}{}do_cuinit_late_clr.exe",
|
||||||
helpers_dir_as_string,
|
helpers_dir_as_string,
|
||||||
path::MAIN_SEPARATOR
|
path::MAIN_SEPARATOR
|
||||||
),
|
),
|
||||||
format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR),
|
format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir);
|
println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Maps a directory entry to its file name when it is a regular file whose name
/// ends in `.rs`; yields `None` for read errors, non-files and other names.
///
/// The name is converted lossily, so invalid Unicode is replaced rather than
/// rejected.
fn rust_file(entry: io::Result<DirEntry>) -> Option<String> {
    let entry = entry.ok()?;
    let name = entry.file_name().to_string_lossy().into_owned();
    // An entry whose type cannot be determined is treated as "not a file".
    let is_regular_file = matches!(entry.file_type(), Ok(kind) if kind.is_file());
    (is_regular_file && name.ends_with(".rs")).then_some(name)
}
|
||||||
|
@ -1,311 +1,311 @@
|
|||||||
use std::env;
|
use std::env;
|
||||||
use std::os::windows;
|
use std::os::windows;
|
||||||
use std::os::windows::ffi::OsStrExt;
|
use std::os::windows::ffi::OsStrExt;
|
||||||
use std::{error::Error, process};
|
use std::{error::Error, process};
|
||||||
use std::{fs, io, ptr};
|
use std::{fs, io, ptr};
|
||||||
use std::{mem, path::PathBuf};
|
use std::{mem, path::PathBuf};
|
||||||
|
|
||||||
use argh::FromArgs;
|
use argh::FromArgs;
|
||||||
use mem::size_of_val;
|
use mem::size_of_val;
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
use winapi::um::processenv::SearchPathW;
|
use winapi::um::processenv::SearchPathW;
|
||||||
use winapi::um::{
|
use winapi::um::{
|
||||||
jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
|
jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
|
||||||
processthreadsapi::{GetExitCodeProcess, ResumeThread},
|
processthreadsapi::{GetExitCodeProcess, ResumeThread},
|
||||||
synchapi::WaitForSingleObject,
|
synchapi::WaitForSingleObject,
|
||||||
winbase::CreateJobObjectA,
|
winbase::CreateJobObjectA,
|
||||||
winnt::{
|
winnt::{
|
||||||
JobObjectExtendedLimitInformation, HANDLE, JOBOBJECT_EXTENDED_LIMIT_INFORMATION,
|
JobObjectExtendedLimitInformation, HANDLE, JOBOBJECT_EXTENDED_LIMIT_INFORMATION,
|
||||||
JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
|
JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
|
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
|
||||||
|
|
||||||
// File names of the DLLs this launcher works with.
// NOTE(review): their exact use is in code below this point — confirm there.
static REDIRECT_DLL: &str = "zluda_redirect.dll";
static NVCUDA_DLL: &str = "nvcuda.dll";
static NVML_DLL: &str = "nvml.dll";
|
||||||
|
|
||||||
include!("../../zluda_redirect/src/payload_guid.rs");
|
include!("../../zluda_redirect/src/payload_guid.rs");
|
||||||
|
|
||||||
#[derive(FromArgs)]
|
#[derive(FromArgs)]
|
||||||
/// Launch application with custom CUDA libraries
|
/// Launch application with custom CUDA libraries
|
||||||
struct ProgramArguments {
|
struct ProgramArguments {
|
||||||
/// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory
|
/// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory
|
||||||
#[argh(option)]
|
#[argh(option)]
|
||||||
nvcuda: Option<PathBuf>,
|
nvcuda: Option<PathBuf>,
|
||||||
|
|
||||||
/// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory
|
/// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory
|
||||||
#[argh(option)]
|
#[argh(option)]
|
||||||
nvml: Option<PathBuf>,
|
nvml: Option<PathBuf>,
|
||||||
|
|
||||||
/// executable to be injected with custom CUDA libraries
|
/// executable to be injected with custom CUDA libraries
|
||||||
#[argh(positional)]
|
#[argh(positional)]
|
||||||
exe: String,
|
exe: String,
|
||||||
|
|
||||||
/// arguments to the executable
|
/// arguments to the executable
|
||||||
#[argh(positional)]
|
#[argh(positional)]
|
||||||
args: Vec<String>,
|
args: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
||||||
let raw_args = argh::from_env::<ProgramArguments>();
|
let raw_args = argh::from_env::<ProgramArguments>();
|
||||||
let normalized_args = NormalizedArguments::new(raw_args)?;
|
let normalized_args = NormalizedArguments::new(raw_args)?;
|
||||||
let mut environment = Environment::setup(normalized_args)?;
|
let mut environment = Environment::setup(normalized_args)?;
|
||||||
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
|
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
|
||||||
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
|
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
|
||||||
let mut dlls_to_inject = [
|
let mut dlls_to_inject = [
|
||||||
environment.nvml_path_zero_terminated.as_ptr() as *const i8,
|
environment.nvml_path_zero_terminated.as_ptr() as *const i8,
|
||||||
environment.nvcuda_path_zero_terminated.as_ptr() as _,
|
environment.nvcuda_path_zero_terminated.as_ptr() as _,
|
||||||
environment.redirect_path_zero_terminated.as_ptr() as _,
|
environment.redirect_path_zero_terminated.as_ptr() as _,
|
||||||
];
|
];
|
||||||
os_call!(
|
os_call!(
|
||||||
detours_sys::DetourCreateProcessWithDllsW(
|
detours_sys::DetourCreateProcessWithDllsW(
|
||||||
ptr::null(),
|
ptr::null(),
|
||||||
environment.winapi_command_line_zero_terminated.as_mut_ptr(),
|
environment.winapi_command_line_zero_terminated.as_mut_ptr(),
|
||||||
ptr::null_mut(),
|
ptr::null_mut(),
|
||||||
ptr::null_mut(),
|
ptr::null_mut(),
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
ptr::null_mut(),
|
ptr::null_mut(),
|
||||||
ptr::null(),
|
ptr::null(),
|
||||||
&mut startup_info as *mut _,
|
&mut startup_info as *mut _,
|
||||||
&mut proc_info as *mut _,
|
&mut proc_info as *mut _,
|
||||||
dlls_to_inject.len() as u32,
|
dlls_to_inject.len() as u32,
|
||||||
dlls_to_inject.as_mut_ptr(),
|
dlls_to_inject.as_mut_ptr(),
|
||||||
Option::None
|
Option::None
|
||||||
),
|
),
|
||||||
|x| x != 0
|
|x| x != 0
|
||||||
);
|
);
|
||||||
kill_child_on_process_exit(proc_info.hProcess)?;
|
kill_child_on_process_exit(proc_info.hProcess)?;
|
||||||
os_call!(
|
os_call!(
|
||||||
detours_sys::DetourCopyPayloadToProcess(
|
detours_sys::DetourCopyPayloadToProcess(
|
||||||
proc_info.hProcess,
|
proc_info.hProcess,
|
||||||
&PAYLOAD_NVCUDA_GUID,
|
&PAYLOAD_NVCUDA_GUID,
|
||||||
environment.nvcuda_path_zero_terminated.as_ptr() as *mut _,
|
environment.nvcuda_path_zero_terminated.as_ptr() as *mut _,
|
||||||
environment.nvcuda_path_zero_terminated.len() as u32
|
environment.nvcuda_path_zero_terminated.len() as u32
|
||||||
),
|
),
|
||||||
|x| x != 0
|
|x| x != 0
|
||||||
);
|
);
|
||||||
os_call!(
|
os_call!(
|
||||||
detours_sys::DetourCopyPayloadToProcess(
|
detours_sys::DetourCopyPayloadToProcess(
|
||||||
proc_info.hProcess,
|
proc_info.hProcess,
|
||||||
&PAYLOAD_NVML_GUID,
|
&PAYLOAD_NVML_GUID,
|
||||||
environment.nvml_path_zero_terminated.as_ptr() as *mut _,
|
environment.nvml_path_zero_terminated.as_ptr() as *mut _,
|
||||||
environment.nvml_path_zero_terminated.len() as u32
|
environment.nvml_path_zero_terminated.len() as u32
|
||||||
),
|
),
|
||||||
|x| x != 0
|
|x| x != 0
|
||||||
);
|
);
|
||||||
os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1);
|
os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1);
|
||||||
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x
|
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x
|
||||||
!= WAIT_FAILED);
|
!= WAIT_FAILED);
|
||||||
let mut child_exit_code: u32 = 0;
|
let mut child_exit_code: u32 = 0;
|
||||||
os_call!(
|
os_call!(
|
||||||
GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _),
|
GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _),
|
||||||
|x| x != 0
|
|x| x != 0
|
||||||
);
|
);
|
||||||
process::exit(child_exit_code as i32)
|
process::exit(child_exit_code as i32)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Command-line arguments after path resolution: every DLL path is filled in
/// (either from the user's option or from the default next to the launcher)
/// and the target's command line is already encoded for the Windows API.
struct NormalizedArguments {
    /// Path of the nvml replacement DLL.
    nvml_path: PathBuf,
    /// Path of the nvcuda replacement DLL.
    nvcuda_path: PathBuf,
    /// Path of the redirect DLL located beside the launcher executable.
    redirect_path: PathBuf,
    /// UTF-16, zero-terminated command line (executable followed by its arguments).
    winapi_command_line_zero_terminated: Vec<u16>,
}
|
|
||||||
impl NormalizedArguments {
|
impl NormalizedArguments {
|
||||||
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
|
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
|
||||||
let current_exe = env::current_exe()?;
|
let current_exe = env::current_exe()?;
|
||||||
let nvml_path = Self::get_absolute_path(¤t_exe, prog_args.nvml, NVML_DLL)?;
|
let nvml_path = Self::get_absolute_path(¤t_exe, prog_args.nvml, NVML_DLL)?;
|
||||||
let nvcuda_path = Self::get_absolute_path(¤t_exe, prog_args.nvcuda, NVCUDA_DLL)?;
|
let nvcuda_path = Self::get_absolute_path(¤t_exe, prog_args.nvcuda, NVCUDA_DLL)?;
|
||||||
let winapi_command_line_zero_terminated =
|
let winapi_command_line_zero_terminated =
|
||||||
construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args));
|
construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args));
|
||||||
let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
|
let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
|
||||||
redirect_path.push(REDIRECT_DLL);
|
redirect_path.push(REDIRECT_DLL);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
nvml_path,
|
nvml_path,
|
||||||
nvcuda_path,
|
nvcuda_path,
|
||||||
redirect_path,
|
redirect_path,
|
||||||
winapi_command_line_zero_terminated,
|
winapi_command_line_zero_terminated,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
const WIN_MAX_PATH: usize = 260;
|
const WIN_MAX_PATH: usize = 260;
|
||||||
|
|
||||||
fn get_absolute_path(
|
fn get_absolute_path(
|
||||||
current_exe: &PathBuf,
|
current_exe: &PathBuf,
|
||||||
dll: Option<PathBuf>,
|
dll: Option<PathBuf>,
|
||||||
default: &str,
|
default: &str,
|
||||||
) -> Result<PathBuf, Box<dyn Error>> {
|
) -> Result<PathBuf, Box<dyn Error>> {
|
||||||
Ok(if let Some(dll) = dll {
|
Ok(if let Some(dll) = dll {
|
||||||
if dll.is_absolute() {
|
if dll.is_absolute() {
|
||||||
dll
|
dll
|
||||||
} else {
|
} else {
|
||||||
let mut full_dll_path = vec![0; Self::WIN_MAX_PATH];
|
let mut full_dll_path = vec![0; Self::WIN_MAX_PATH];
|
||||||
let mut dll_utf16 = dll.as_os_str().encode_wide().collect::<Vec<_>>();
|
let mut dll_utf16 = dll.as_os_str().encode_wide().collect::<Vec<_>>();
|
||||||
dll_utf16.push(0);
|
dll_utf16.push(0);
|
||||||
loop {
|
loop {
|
||||||
let copied_len = os_call!(
|
let copied_len = os_call!(
|
||||||
SearchPathW(
|
SearchPathW(
|
||||||
ptr::null_mut(),
|
ptr::null_mut(),
|
||||||
dll_utf16.as_ptr(),
|
dll_utf16.as_ptr(),
|
||||||
ptr::null(),
|
ptr::null(),
|
||||||
full_dll_path.len() as u32,
|
full_dll_path.len() as u32,
|
||||||
full_dll_path.as_mut_ptr(),
|
full_dll_path.as_mut_ptr(),
|
||||||
ptr::null_mut()
|
ptr::null_mut()
|
||||||
),
|
),
|
||||||
|x| x != 0
|
|x| x != 0
|
||||||
) as usize;
|
) as usize;
|
||||||
if copied_len > full_dll_path.len() {
|
if copied_len > full_dll_path.len() {
|
||||||
full_dll_path.resize(copied_len + 1, 0);
|
full_dll_path.resize(copied_len + 1, 0);
|
||||||
} else {
|
} else {
|
||||||
full_dll_path.truncate(copied_len);
|
full_dll_path.truncate(copied_len);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
PathBuf::from(String::from_utf16_lossy(&full_dll_path))
|
PathBuf::from(String::from_utf16_lossy(&full_dll_path))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let mut dll_path = current_exe.parent().unwrap().to_path_buf();
|
let mut dll_path = current_exe.parent().unwrap().to_path_buf();
|
||||||
dll_path.push(default);
|
dll_path.push(default);
|
||||||
dll_path
|
dll_path
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Environment {
|
struct Environment {
|
||||||
nvml_path_zero_terminated: String,
|
nvml_path_zero_terminated: String,
|
||||||
nvcuda_path_zero_terminated: String,
|
nvcuda_path_zero_terminated: String,
|
||||||
redirect_path_zero_terminated: String,
|
redirect_path_zero_terminated: String,
|
||||||
winapi_command_line_zero_terminated: Vec<u16>,
|
winapi_command_line_zero_terminated: Vec<u16>,
|
||||||
_temp_dir: TempDir,
|
_temp_dir: TempDir,
|
||||||
}
|
}
|
||||||
|
|
||||||
// This struct represents the "environment". By environment we mean all paths
// (nvcuda.dll, nvml.dll, etc.) and all related resources like the temporary
// directory which contains nvcuda.dll
impl Environment {
|
impl Environment {
|
||||||
fn setup(args: NormalizedArguments) -> io::Result<Self> {
|
fn setup(args: NormalizedArguments) -> io::Result<Self> {
|
||||||
let _temp_dir = TempDir::new()?;
|
let _temp_dir = TempDir::new()?;
|
||||||
let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||||
args.nvml_path,
|
args.nvml_path,
|
||||||
&_temp_dir,
|
&_temp_dir,
|
||||||
NVML_DLL,
|
NVML_DLL,
|
||||||
)?);
|
)?);
|
||||||
let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||||
args.nvcuda_path,
|
args.nvcuda_path,
|
||||||
&_temp_dir,
|
&_temp_dir,
|
||||||
NVCUDA_DLL,
|
NVCUDA_DLL,
|
||||||
)?);
|
)?);
|
||||||
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
|
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
nvml_path_zero_terminated,
|
nvml_path_zero_terminated,
|
||||||
nvcuda_path_zero_terminated,
|
nvcuda_path_zero_terminated,
|
||||||
redirect_path_zero_terminated,
|
redirect_path_zero_terminated,
|
||||||
winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated,
|
winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated,
|
||||||
_temp_dir,
|
_temp_dir,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn copy_to_correct_name(
|
fn copy_to_correct_name(
|
||||||
path_buf: PathBuf,
|
path_buf: PathBuf,
|
||||||
temp_dir: &TempDir,
|
temp_dir: &TempDir,
|
||||||
correct_name: &str,
|
correct_name: &str,
|
||||||
) -> io::Result<PathBuf> {
|
) -> io::Result<PathBuf> {
|
||||||
let file_name = path_buf.file_name().unwrap();
|
let file_name = path_buf.file_name().unwrap();
|
||||||
if file_name == correct_name {
|
if file_name == correct_name {
|
||||||
Ok(path_buf)
|
Ok(path_buf)
|
||||||
} else {
|
} else {
|
||||||
let mut temp_file_path = temp_dir.path().to_path_buf();
|
let mut temp_file_path = temp_dir.path().to_path_buf();
|
||||||
temp_file_path.push(correct_name);
|
temp_file_path.push(correct_name);
|
||||||
match windows::fs::symlink_file(&path_buf, &temp_file_path) {
|
match windows::fs::symlink_file(&path_buf, &temp_file_path) {
|
||||||
Ok(()) => {}
|
Ok(()) => {}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
fs::copy(&path_buf, &temp_file_path)?;
|
fs::copy(&path_buf, &temp_file_path)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(temp_file_path)
|
Ok(temp_file_path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn zero_terminate(p: PathBuf) -> String {
|
fn zero_terminate(p: PathBuf) -> String {
|
||||||
let mut s = p.to_string_lossy().to_string();
|
let mut s = p.to_string_lossy().to_string();
|
||||||
s.push('\0');
|
s.push('\0');
|
||||||
s
|
s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
|
fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
|
||||||
let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
|
let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
|
||||||
!= ptr::null_mut());
|
!= ptr::null_mut());
|
||||||
let mut info = unsafe { mem::zeroed::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() };
|
let mut info = unsafe { mem::zeroed::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() };
|
||||||
info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
|
info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
|
||||||
os_call!(
|
os_call!(
|
||||||
SetInformationJobObject(
|
SetInformationJobObject(
|
||||||
job_handle,
|
job_handle,
|
||||||
JobObjectExtendedLimitInformation,
|
JobObjectExtendedLimitInformation,
|
||||||
&mut info as *mut _ as *mut _,
|
&mut info as *mut _ as *mut _,
|
||||||
size_of_val(&info) as u32
|
size_of_val(&info) as u32
|
||||||
),
|
),
|
||||||
|x| x != 0
|
|x| x != 0
|
||||||
);
|
);
|
||||||
os_call!(AssignProcessToJobObject(job_handle, child), |x| x != 0);
|
os_call!(AssignProcessToJobObject(job_handle, child), |x| x != 0);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
/// Joins `args` into a single zero-terminated UTF-16 command line.
///
/// Arguments are separated by one space. An argument is wrapped in double
/// quotes when it is empty or contains a space, tab, newline, vertical tab or
/// a double quote; inside the quotes, backslashes that precede a `"` (or the
/// closing quote) are doubled and embedded `"` characters are escaped with a
/// backslash. All other arguments are copied verbatim.
///
/// The returned vector always ends with a single `0` terminator, even for an
/// empty iterator.
fn construct_command_line(args: impl Iterator<Item = String>) -> Vec<u16> {
    const BACKSLASH: u16 = '\\' as u16;
    const QUOTE: u16 = '"' as u16;

    let mut cmd_line = Vec::new();
    for (idx, arg) in args.enumerate() {
        // Separate from the previous argument. Deciding on `idx` rather than on
        // `size_hint` keeps this correct for iterators with an inexact hint.
        if idx > 0 {
            cmd_line.push(' ' as u16);
        }

        // '\u{000B}' is the vertical tab from the referenced article;
        // '\u{2B7F}' is kept because earlier versions quoted on it.
        let needs_quotes = arg.is_empty()
            || arg.contains(&[' ', '\t', '\n', '\u{000B}', '\u{2B7F}', '"'][..]);
        if !needs_quotes {
            cmd_line.extend(arg.encode_utf16());
            continue;
        }

        cmd_line.push(QUOTE);
        let mut chars = arg.chars().peekable();
        loop {
            // Count the run of backslashes in front of the next character.
            let mut backslashes = 0usize;
            while let Some('\\') = chars.peek() {
                chars.next();
                backslashes += 1;
            }
            match chars.next() {
                None => {
                    // The run precedes the closing quote: double it so the
                    // quote is not read as escaped.
                    cmd_line.extend(std::iter::repeat(BACKSLASH).take(backslashes * 2));
                    break;
                }
                Some('"') => {
                    // Double the run and add one more backslash to escape the quote.
                    cmd_line.extend(std::iter::repeat(BACKSLASH).take(backslashes * 2 + 1));
                    cmd_line.push(QUOTE);
                }
                Some(c) => {
                    // Backslashes not followed by a quote are literal.
                    cmd_line.extend(std::iter::repeat(BACKSLASH).take(backslashes));
                    let mut units = [0u16; 2];
                    cmd_line.extend_from_slice(c.encode_utf16(&mut units));
                }
            }
        }
        cmd_line.push(QUOTE);
    }
    cmd_line.push(0);
    cmd_line
}
|
@ -1,13 +1,13 @@
|
|||||||
#[macro_use]
|
#[macro_use]
|
||||||
#[cfg(target_os = "windows")]
|
#[cfg(target_os = "windows")]
|
||||||
mod win;
|
mod win;
|
||||||
#[cfg(target_os = "windows")]
|
#[cfg(target_os = "windows")]
|
||||||
mod bin;
|
mod bin;
|
||||||
|
|
||||||
/// Windows build: all work is delegated to `bin::main_impl`, whose error (if
/// any) becomes the result of `main`.
#[cfg(target_os = "windows")]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    bin::main_impl()
}
|
|
||||||
/// Non-Windows build: the binary is an empty stub that does nothing.
#[cfg(not(target_os = "windows"))]
fn main() {}
|
@ -1,151 +1,151 @@
|
|||||||
#![allow(non_snake_case)]
|
#![allow(non_snake_case)]
|
||||||
|
|
||||||
use std::error;
|
use std::error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::ptr;
|
use std::ptr;
|
||||||
|
|
||||||
/// Hand-written declarations of the few Windows API types, constants and
/// functions used for error reporting in this file.
mod c {
    use std::ffi::c_void;
    use std::os::raw::c_ulong;

    // Scalar and pointer aliases, named after their Windows counterparts.
    pub type WCHAR = u16;
    pub type DWORD = c_ulong;
    pub type LPVOID = *mut c_void;
    pub type LPWSTR = *mut WCHAR;
    pub type LPCWSTR = *const WCHAR;

    // Handle aliases; each one is defined in terms of the previous.
    pub type HANDLE = LPVOID;
    pub type HINSTANCE = HANDLE;
    pub type HMODULE = HINSTANCE;

    /// Bit tested by `error_string` to detect NTSTATUS-derived error codes.
    pub const FACILITY_NT_BIT: DWORD = 0x1000_0000;

    // Flags accepted by `FormatMessageW`.
    pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x0000_0200;
    pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x0000_0800;
    pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x0000_1000;

    extern "system" {
        pub fn GetLastError() -> DWORD;
        pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE;
        pub fn FormatMessageW(
            flags: DWORD,
            lpSrc: LPVOID,
            msgId: DWORD,
            langId: DWORD,
            buf: LPWSTR,
            nsize: DWORD,
            args: *const c_void,
        ) -> DWORD;
    }
}
|
|
||||||
/// Expands to the last identifier of a comma-separated list, as a string
/// literal. `last_ident!(a, b, c)` becomes `"c"`.
macro_rules! last_ident {
    // Exactly one identifier left: stringify it.
    ($only:ident) => {
        stringify!($only)
    };
    // More than one: drop the first and recurse on the remainder.
    ($head:ident, $($tail:ident),+) => {
        last_ident!($($tail),+)
    };
}
|
|
||||||
/// Calls an OS function inside `unsafe` and checks its return value.
///
/// Usage: `os_call!(path::to::Function(arg1, arg2), |ret| success_condition)`.
/// The arguments must be comma-separated without a trailing comma. When the
/// predicate returns `false` for the function's result, an `OsError` carrying
/// the function's name, the current `errno()` value and its message is built
/// and propagated with `?`, so the macro may only be used where `?` can
/// convert an `OsError` into the enclosing error type. Otherwise the macro
/// evaluates to the function's return value.
macro_rules! os_call {
    ($($path:ident)::+ ($($args:expr),*), $success:expr) => {
        {
            let ret = unsafe{ $($path)::+ ($($args),*) };
            if !($success)(ret) {
                // Only the final path segment is reported as the function name.
                let fn_name = last_ident!($($path),+);
                let code = $crate::win::errno();
                Err($crate::win::OsError{
                    function: fn_name,
                    error_code: code as u32,
                    message: $crate::win::error_string(code)
                })?;
            }
            ret
        }
    };
}
|
|
||||||
/// Failure of an OS call made through `os_call!`.
#[derive(Debug)]
pub struct OsError {
    /// Name of the function that failed.
    pub function: &'static str,
    /// Error code captured right after the failing call.
    pub error_code: u32,
    /// Textual description of `error_code`.
    pub message: String,
}

/// `Display` intentionally prints the same text as the derived `Debug`.
impl fmt::Display for OsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}

/// `OsError` wraps no underlying error, so the default `source` (always
/// `None`) is used.
impl error::Error for OsError {}
|
|
||||||
pub fn errno() -> i32 {
|
pub fn errno() -> i32 {
|
||||||
unsafe { c::GetLastError() as i32 }
|
unsafe { c::GetLastError() as i32 }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets a detailed string description for the given error number.
|
/// Gets a detailed string description for the given error number.
|
||||||
pub fn error_string(mut errnum: i32) -> String {
|
pub fn error_string(mut errnum: i32) -> String {
|
||||||
// This value is calculated from the macro
|
// This value is calculated from the macro
|
||||||
// MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT)
|
// MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT)
|
||||||
let langId = 0x0800 as c::DWORD;
|
let langId = 0x0800 as c::DWORD;
|
||||||
|
|
||||||
let mut buf = [0 as c::WCHAR; 2048];
|
let mut buf = [0 as c::WCHAR; 2048];
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
let mut module = ptr::null_mut();
|
let mut module = ptr::null_mut();
|
||||||
let mut flags = 0;
|
let mut flags = 0;
|
||||||
|
|
||||||
// NTSTATUS errors may be encoded as HRESULT, which may returned from
|
// NTSTATUS errors may be encoded as HRESULT, which may returned from
|
||||||
// GetLastError. For more information about Windows error codes, see
|
// GetLastError. For more information about Windows error codes, see
|
||||||
// `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx
|
// `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx
|
||||||
if (errnum & c::FACILITY_NT_BIT as i32) != 0 {
|
if (errnum & c::FACILITY_NT_BIT as i32) != 0 {
|
||||||
// format according to https://support.microsoft.com/en-us/help/259693
|
// format according to https://support.microsoft.com/en-us/help/259693
|
||||||
const NTDLL_DLL: &[u16] = &[
|
const NTDLL_DLL: &[u16] = &[
|
||||||
'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _,
|
'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _,
|
||||||
'L' as _, 0,
|
'L' as _, 0,
|
||||||
];
|
];
|
||||||
module = c::GetModuleHandleW(NTDLL_DLL.as_ptr());
|
module = c::GetModuleHandleW(NTDLL_DLL.as_ptr());
|
||||||
|
|
||||||
if module != ptr::null_mut() {
|
if module != ptr::null_mut() {
|
||||||
errnum ^= c::FACILITY_NT_BIT as i32;
|
errnum ^= c::FACILITY_NT_BIT as i32;
|
||||||
flags = c::FORMAT_MESSAGE_FROM_HMODULE;
|
flags = c::FORMAT_MESSAGE_FROM_HMODULE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let res = c::FormatMessageW(
|
let res = c::FormatMessageW(
|
||||||
flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS,
|
flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS,
|
||||||
module,
|
module,
|
||||||
errnum as c::DWORD,
|
errnum as c::DWORD,
|
||||||
langId,
|
langId,
|
||||||
buf.as_mut_ptr(),
|
buf.as_mut_ptr(),
|
||||||
buf.len() as c::DWORD,
|
buf.len() as c::DWORD,
|
||||||
ptr::null(),
|
ptr::null(),
|
||||||
) as usize;
|
) as usize;
|
||||||
if res == 0 {
|
if res == 0 {
|
||||||
// Sometimes FormatMessageW can fail e.g., system doesn't like langId,
|
// Sometimes FormatMessageW can fail e.g., system doesn't like langId,
|
||||||
let fm_err = errno();
|
let fm_err = errno();
|
||||||
return format!(
|
return format!(
|
||||||
"OS Error {} (FormatMessageW() returned error {})",
|
"OS Error {} (FormatMessageW() returned error {})",
|
||||||
errnum, fm_err
|
errnum, fm_err
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
match String::from_utf16(&buf[..res]) {
|
match String::from_utf16(&buf[..res]) {
|
||||||
Ok(mut msg) => {
|
Ok(mut msg) => {
|
||||||
// Trim trailing CRLF inserted by FormatMessageW
|
// Trim trailing CRLF inserted by FormatMessageW
|
||||||
let len = msg.trim_end().len();
|
let len = msg.trim_end().len();
|
||||||
msg.truncate(len);
|
msg.truncate(len);
|
||||||
msg
|
msg
|
||||||
}
|
}
|
||||||
Err(..) => format!(
|
Err(..) => format!(
|
||||||
"OS Error {} (FormatMessageW() returned \
|
"OS Error {} (FormatMessageW() returned \
|
||||||
invalid UTF-16)",
|
invalid UTF-16)",
|
||||||
errnum
|
errnum
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,51 +1,51 @@
|
|||||||
#![cfg(windows)]
|
#![cfg(windows)]
|
||||||
use std::{env, io, path::PathBuf, process::Command};
|
use std::{env, io, path::PathBuf, process::Command};
|
||||||
|
|
||||||
// Each test launches the helper executable of the same name through
// `zluda_with` and expects the dump marker on its stderr; see
// `run_process_and_check_for_zluda_dump`.

/// Helper `direct_cuinit.exe`.
#[test]
fn direct_cuinit() -> io::Result<()> {
    run_process_and_check_for_zluda_dump("direct_cuinit")
}

/// Helper `do_cuinit_early.exe`.
#[test]
fn do_cuinit_early() -> io::Result<()> {
    run_process_and_check_for_zluda_dump("do_cuinit_early")
}

/// Helper `do_cuinit_late.exe`.
#[test]
fn do_cuinit_late() -> io::Result<()> {
    run_process_and_check_for_zluda_dump("do_cuinit_late")
}

/// Helper `do_cuinit_late_clr.exe`.
#[test]
fn do_cuinit_late_clr() -> io::Result<()> {
    run_process_and_check_for_zluda_dump("do_cuinit_late_clr")
}

/// Helper `indirect_cuinit.exe`.
#[test]
fn indirect_cuinit() -> io::Result<()> {
    run_process_and_check_for_zluda_dump("indirect_cuinit")
}

/// Helper `subprocess.exe`.
#[test]
fn subprocess() -> io::Result<()> {
    run_process_and_check_for_zluda_dump("subprocess")
}
|
|
||||||
fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
|
fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
|
||||||
let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with"));
|
let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with"));
|
||||||
let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf();
|
let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf();
|
||||||
zluda_dump_dll.push("zluda_dump.dll");
|
zluda_dump_dll.push("zluda_dump.dll");
|
||||||
let helpers_dir = env!("HELPERS_OUT_DIR");
|
let helpers_dir = env!("HELPERS_OUT_DIR");
|
||||||
let exe_under_test = format!("{}{}{}.exe", helpers_dir, std::path::MAIN_SEPARATOR, name);
|
let exe_under_test = format!("{}{}{}.exe", helpers_dir, std::path::MAIN_SEPARATOR, name);
|
||||||
let mut test_cmd = Command::new(&zluda_with_exe);
|
let mut test_cmd = Command::new(&zluda_with_exe);
|
||||||
let test_cmd = test_cmd
|
let test_cmd = test_cmd
|
||||||
.arg("--nvcuda")
|
.arg("--nvcuda")
|
||||||
.arg(&zluda_dump_dll)
|
.arg(&zluda_dump_dll)
|
||||||
.arg("--")
|
.arg("--")
|
||||||
.arg(&exe_under_test);
|
.arg(&exe_under_test);
|
||||||
let test_output = test_cmd.output()?;
|
let test_output = test_cmd.output()?;
|
||||||
assert!(test_output.status.success());
|
assert!(test_output.status.success());
|
||||||
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
|
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
|
||||||
assert!(stderr_text.contains("ZLUDA_DUMP"));
|
assert!(stderr_text.contains("ZLUDA_DUMP"));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user