Merge commit 'b8244243679bb02c3c38cf8ab00c2488a51a7741' into minmax

This commit is contained in:
Andrzej Janik
2025-07-01 22:42:04 +00:00
25 changed files with 884 additions and 229 deletions

48
Cargo.lock generated
View File

@ -327,8 +327,10 @@ dependencies = [
"cglue",
"cuda_types",
"format",
"lz4-sys",
"paste",
"uuid",
"zstd-safe",
]
[[package]]
@ -494,6 +496,18 @@ dependencies = [
"uuid",
]
[[package]]
name = "getrandom"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasi",
]
[[package]]
name = "glob"
version = "0.3.1"
@ -582,10 +596,11 @@ checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
[[package]]
name = "jobserver"
version = "0.1.32"
version = "0.1.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0"
checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a"
dependencies = [
"getrandom",
"libc",
]
@ -759,9 +774,9 @@ dependencies = [
[[package]]
name = "microlp"
version = "0.2.10"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edaa5264bc1f7668bc12e10757f8f529a526656c796cc2106cf2be10c5b8d483"
checksum = "51d1790c73b93164ff65868f63164497cb32339458a9297e17e212d91df62258"
dependencies = [
"log",
"sprs",
@ -1118,6 +1133,12 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rawpointer"
version = "0.2.1"
@ -1516,6 +1537,15 @@ version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "wasi"
version = "0.14.2+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wchar"
version = "0.6.1"
@ -1660,6 +1690,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [
"bitflags 2.9.1",
]
[[package]]
name = "xattr"
version = "1.5.0"
@ -1783,7 +1822,6 @@ dependencies = [
"format",
"goblin",
"libc",
"lz4-sys",
"parking_lot",
"paste",
"ptx",

View File

@ -199,7 +199,8 @@ impl VisitMut for FixFnSignatures {
}
const MODULES: &[&str] = &[
"context", "device", "driver", "function", "link", "memory", "module", "pointer", "stream",
"context", "device", "driver", "function", "library", "link", "memory", "module", "pointer",
"stream",
];
#[proc_macro]

View File

@ -78,122 +78,17 @@ bitflags! {
impl FatbincWrapper {
pub const MAGIC: c_uint = 0x466243B1;
const VERSION_V1: c_uint = 0x1;
pub const VERSION_V1: c_uint = 0x1;
pub const VERSION_V2: c_uint = 0x2;
pub fn new<'a, T: Sized>(ptr: &*const T) -> Result<&'a Self, ParseError> {
unsafe { ptr.cast::<Self>().as_ref() }
.ok_or(ParseError::NullPointer("FatbincWrapper"))
.and_then(|ptr| {
ParseError::check_fields("FATBINC_MAGIC", ptr.magic, [Self::MAGIC])?;
ParseError::check_fields(
"FATBINC_VERSION",
ptr.version,
[Self::VERSION_V1, Self::VERSION_V2],
)?;
Ok(ptr)
})
}
}
impl FatbinHeader {
const MAGIC: c_uint = 0xBA55ED50;
const VERSION: c_ushort = 0x01;
pub fn new<'a, T: Sized>(ptr: &'a *const T) -> Result<&'a Self, ParseError> {
unsafe { ptr.cast::<Self>().as_ref() }
.ok_or(ParseError::NullPointer("FatbinHeader"))
.and_then(|ptr| {
ParseError::check_fields("FATBIN_MAGIC", ptr.magic, [Self::MAGIC])?;
ParseError::check_fields("FATBIN_VERSION", ptr.version, [Self::VERSION])?;
Ok(ptr)
})
}
pub unsafe fn get_content<'a>(&'a self) -> &'a [u8] {
let start = std::ptr::from_ref(self)
.cast::<u8>()
.add(self.header_size as usize);
std::slice::from_raw_parts(start, self.files_size as usize)
}
pub const MAGIC: c_uint = 0xBA55ED50;
pub const VERSION: c_ushort = 0x01;
}
impl FatbinFileHeader {
pub const HEADER_KIND_PTX: c_ushort = 0x01;
pub const HEADER_KIND_ELF: c_ushort = 0x02;
const HEADER_VERSION_CURRENT: c_ushort = 0x101;
pub fn new_ptx<T: Sized>(ptr: *const T) -> Result<Option<&'static Self>, ParseError> {
unsafe { ptr.cast::<Self>().as_ref() }
.ok_or(ParseError::NullPointer("FatbinFileHeader"))
.and_then(|ptr| {
ParseError::check_fields(
"FATBIN_FILE_HEADER_VERSION_CURRENT",
ptr.version,
[Self::HEADER_VERSION_CURRENT],
)?;
match ptr.kind {
Self::HEADER_KIND_PTX => Ok(Some(ptr)),
Self::HEADER_KIND_ELF => Ok(None),
_ => Err(ParseError::UnexpectedBinaryField {
field_name: "FATBIN_FILE_HEADER_KIND",
observed: ptr.kind.into(),
expected: vec![Self::HEADER_KIND_PTX.into(), Self::HEADER_KIND_ELF.into()],
}),
}
})
}
pub unsafe fn next<'a>(slice: &'a mut &[u8]) -> Result<Option<&'a Self>, ParseError> {
if slice.len() < std::mem::size_of::<Self>() {
return Ok(None);
}
let this = &*slice.as_ptr().cast::<Self>();
let next_element = slice
.split_at_checked(this.header_size as usize + this.padded_payload_size as usize)
.map(|(_, next)| next);
*slice = next_element.unwrap_or(&[]);
ParseError::check_fields(
"FATBIN_FILE_HEADER_VERSION_CURRENT",
this.version,
[Self::HEADER_VERSION_CURRENT],
)?;
Ok(Some(this))
}
pub unsafe fn get_payload<'a>(&'a self) -> &'a [u8] {
let start = std::ptr::from_ref(self)
.cast::<u8>()
.add(self.header_size as usize);
std::slice::from_raw_parts(start, self.payload_size as usize)
}
}
pub enum ParseError {
NullPointer(&'static str),
UnexpectedBinaryField {
field_name: &'static str,
observed: u32,
expected: Vec<u32>,
},
}
impl ParseError {
pub(crate) fn check_fields<const N: usize, T: Into<u32> + Eq + Copy>(
name: &'static str,
observed: T,
expected: [T; N],
) -> Result<(), Self> {
if expected.contains(&observed) {
Ok(())
} else {
let observed = observed.into();
let expected = expected.into_iter().map(Into::into).collect();
Err(ParseError::UnexpectedBinaryField {
field_name: name,
expected,
observed,
})
}
}
pub const HEADER_VERSION_CURRENT: c_ushort = 0x101;
}

View File

@ -10,3 +10,5 @@ uuid = "1.16"
paste = "1.0"
bit-vec = "0.8.0"
cglue = "0.3.5"
lz4-sys = "1.9"
zstd-safe = { version = "7.2.4", features = ["std"] }

259
dark_api/src/fatbin.rs Normal file
View File

@ -0,0 +1,259 @@
// This file contains a higher-level interface for parsing fatbins
use std::ptr;
use cuda_types::dark_api::*;
pub enum ParseError {
NullPointer(&'static str),
UnexpectedBinaryField {
field_name: &'static str,
observed: u32,
expected: Vec<u32>,
},
}
impl ParseError {
pub(crate) fn check_fields<const N: usize, T: Into<u32> + Eq + Copy>(
name: &'static str,
observed: T,
expected: [T; N],
) -> Result<(), Self> {
if expected.contains(&observed) {
Ok(())
} else {
let observed = observed.into();
let expected = expected.into_iter().map(Into::into).collect();
Err(ParseError::UnexpectedBinaryField {
field_name: name,
expected,
observed,
})
}
}
}
pub enum FatbinError {
ParseFailure(ParseError),
Lz4DecompressionFailure,
ZstdDecompressionFailure(usize),
}
pub fn parse_fatbinc_wrapper<T: Sized>(ptr: &*const T) -> Result<&FatbincWrapper, ParseError> {
unsafe { ptr.cast::<FatbincWrapper>().as_ref() }
.ok_or(ParseError::NullPointer("FatbincWrapper"))
.and_then(|ptr| {
ParseError::check_fields("FATBINC_MAGIC", ptr.magic, [FatbincWrapper::MAGIC])?;
ParseError::check_fields(
"FATBINC_VERSION",
ptr.version,
[FatbincWrapper::VERSION_V1, FatbincWrapper::VERSION_V2],
)?;
Ok(ptr)
})
}
fn parse_fatbin_header<T: Sized>(ptr: &*const T) -> Result<&FatbinHeader, ParseError> {
unsafe { ptr.cast::<FatbinHeader>().as_ref() }
.ok_or(ParseError::NullPointer("FatbinHeader"))
.and_then(|ptr| {
ParseError::check_fields("FATBIN_MAGIC", ptr.magic, [FatbinHeader::MAGIC])?;
ParseError::check_fields("FATBIN_VERSION", ptr.version, [FatbinHeader::VERSION])?;
Ok(ptr)
})
}
pub struct Fatbin<'a> {
pub wrapper: &'a FatbincWrapper,
}
impl<'a> Fatbin<'a> {
pub fn new<T>(ptr: &'a *const T) -> Result<Self, FatbinError> {
let wrapper: &FatbincWrapper =
parse_fatbinc_wrapper(ptr).map_err(|e| FatbinError::ParseFailure(e))?;
Ok(Fatbin { wrapper })
}
pub fn get_submodules(&self) -> Result<FatbinIter<'a>, FatbinError> {
match self.wrapper.version {
FatbincWrapper::VERSION_V2 =>
Ok(FatbinIter::V2(FatbinSubmoduleIterator {
fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void,
_phantom: std::marker::PhantomData,
})),
FatbincWrapper::VERSION_V1 => {
let header = parse_fatbin_header(&self.wrapper.data)
.map_err(FatbinError::ParseFailure)?;
Ok(FatbinIter::V1(Some(FatbinSubmodule::new(header))))
}
version => Err(FatbinError::ParseFailure(ParseError::UnexpectedBinaryField{
field_name: "FATBINC_VERSION",
observed: version,
expected: [FatbincWrapper::VERSION_V1, FatbincWrapper::VERSION_V2].into(),
})),
}
}
}
pub struct FatbinSubmodule<'a> {
pub header: &'a FatbinHeader, // TODO: maybe make private
}
impl<'a> FatbinSubmodule<'a> {
pub fn new(header: &'a FatbinHeader) -> Self {
FatbinSubmodule { header }
}
pub fn get_files(&self) -> FatbinFileIterator {
unsafe { FatbinFileIterator::new(self.header) }
}
}
pub enum FatbinIter<'a> {
V1(Option<FatbinSubmodule<'a>>),
V2(FatbinSubmoduleIterator<'a>),
}
impl<'a> FatbinIter<'a> {
pub fn next(&mut self) -> Result<Option<FatbinSubmodule<'a>>, ParseError> {
match self {
FatbinIter::V1(opt) => Ok(opt.take()),
FatbinIter::V2(iter) => unsafe { iter.next() },
}
}
}
pub struct FatbinSubmoduleIterator<'a> {
fatbins: *const *const std::ffi::c_void,
_phantom: std::marker::PhantomData<&'a FatbinHeader>,
}
impl<'a> FatbinSubmoduleIterator<'a> {
pub unsafe fn next(&mut self) -> Result<Option<FatbinSubmodule<'a>>, ParseError> {
if *self.fatbins != ptr::null() {
let header = *self.fatbins as *const FatbinHeader;
self.fatbins = self.fatbins.add(1);
Ok(Some(FatbinSubmodule::new(header.as_ref().ok_or(
ParseError::NullPointer("FatbinSubmoduleIterator"),
)?)))
} else {
Ok(None)
}
}
}
pub struct FatbinFile<'a> {
pub header: &'a FatbinFileHeader,
}
impl<'a> FatbinFile<'a> {
pub fn new(header: &'a FatbinFileHeader) -> Self {
Self { header }
}
pub unsafe fn get_payload(&'a self) -> &'a [u8] {
let start = std::ptr::from_ref(self.header)
.cast::<u8>()
.add(self.header.header_size as usize);
std::slice::from_raw_parts(start, self.header.payload_size as usize)
}
pub unsafe fn decompress(&'a self) -> Result<Vec<u8>, FatbinError> {
let mut payload = if self
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedLz4)
{
unsafe { decompress_lz4(self) }?
} else if self
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedZstd)
{
unsafe { decompress_zstd(self) }?
} else {
unsafe { self.get_payload().to_vec() }
};
while payload.last() == Some(&0) {
// remove trailing zeros
payload.pop();
}
Ok(payload)
}
}
pub struct FatbinFileIterator<'a> {
file_buffer: &'a [u8],
}
impl<'a> FatbinFileIterator<'a> {
pub unsafe fn new(header: &'a FatbinHeader) -> Self {
let start = std::ptr::from_ref(header)
.cast::<u8>()
.add(header.header_size as usize);
let file_buffer = std::slice::from_raw_parts(start, header.files_size as usize);
Self { file_buffer }
}
pub unsafe fn next(&mut self) -> Result<Option<FatbinFile>, ParseError> {
if self.file_buffer.len() < std::mem::size_of::<FatbinFileHeader>() {
return Ok(None);
}
let this = &*self.file_buffer.as_ptr().cast::<FatbinFileHeader>();
let next_element = self
.file_buffer
.split_at_checked(this.header_size as usize + this.padded_payload_size as usize)
.map(|(_, next)| next);
self.file_buffer = next_element.unwrap_or(&[]);
ParseError::check_fields(
"FATBIN_FILE_HEADER_VERSION_CURRENT",
this.version,
[FatbinFileHeader::HEADER_VERSION_CURRENT],
)?;
Ok(Some(FatbinFile::new(this)))
}
}
const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024;
pub unsafe fn decompress_lz4(file: &FatbinFile) -> Result<Vec<u8>, FatbinError> {
let decompressed_size = usize::max(1024, file.header.uncompressed_payload as usize);
let mut decompressed_vec = vec![0u8; decompressed_size];
loop {
match lz4_sys::LZ4_decompress_safe(
file.get_payload().as_ptr() as *const _,
decompressed_vec.as_mut_ptr() as *mut _,
file.header.payload_size as _,
decompressed_vec.len() as _,
) {
error if error < 0 => {
let new_size = decompressed_vec.len() * 2;
if new_size > MAX_MODULE_DECOMPRESSION_BOUND {
return Err(FatbinError::Lz4DecompressionFailure);
}
decompressed_vec.resize(decompressed_vec.len() * 2, 0);
}
real_decompressed_size => {
decompressed_vec.truncate(real_decompressed_size as usize);
return Ok(decompressed_vec);
}
}
}
}
pub unsafe fn decompress_zstd(file: &FatbinFile) -> Result<Vec<u8>, FatbinError> {
let mut result = Vec::with_capacity(file.header.uncompressed_payload as usize);
let payload = file.get_payload();
match zstd_safe::decompress(&mut result, payload) {
Ok(actual_size) => {
result.truncate(actual_size);
Ok(result)
}
Err(err) => Err(FatbinError::ZstdDecompressionFailure(err)),
}
}

View File

@ -2,6 +2,8 @@ use std::ffi::c_void;
use cuda_types::cuda::CUuuid;
pub mod fatbin;
macro_rules! dark_api_init {
(SIZE_OF, $table_len:literal, $type_:ty) => {
(std::mem::size_of::<usize>() * $table_len) as *const std::ffi::c_void

View File

@ -18,7 +18,7 @@ rustc-hash = "2.0.0"
strum = "0.26"
strum_macros = "0.26"
petgraph = "0.7.1"
microlp = "0.2.10"
microlp = "0.2.11"
int-enum = "1.1"
unwrap_or = "1.0.1"
@ -31,3 +31,6 @@ tempfile = "3"
paste = "1.0"
pretty_assertions = "1.4.1"
libloading = "0.8"
[features]
ci_build = []

View File

@ -0,0 +1,64 @@
declare void @__zluda_ptx_impl___assertfail(i64, i64, i32, i64, i64) #0
define amdgpu_kernel void @assertfail(ptr addrspace(4) byref(i64) %"86", ptr addrspace(4) byref(i64) %"87") #1 {
%"88" = alloca i64, align 8, addrspace(5)
%"89" = alloca i64, align 8, addrspace(5)
%"90" = alloca i64, align 8, addrspace(5)
%"91" = alloca i64, align 8, addrspace(5)
%"94" = alloca i32, align 4, addrspace(5)
%"96" = alloca i64, align 8, addrspace(5)
%"99" = alloca i64, align 8, addrspace(5)
%"102" = alloca i32, align 4, addrspace(5)
%"105" = alloca i64, align 8, addrspace(5)
%"108" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"84"
"84": ; preds = %1
%"92" = load i64, ptr addrspace(4) %"86", align 4
store i64 %"92", ptr addrspace(5) %"88", align 4
%"93" = load i64, ptr addrspace(4) %"87", align 4
store i64 %"93", ptr addrspace(5) %"89", align 4
store i32 0, ptr addrspace(5) %"94", align 4
%"97" = getelementptr inbounds i8, ptr addrspace(5) %"96", i64 0
%"98" = load i64, ptr addrspace(5) %"88", align 4
store i64 %"98", ptr addrspace(5) %"97", align 4
%"100" = getelementptr inbounds i8, ptr addrspace(5) %"99", i64 0
%"101" = load i64, ptr addrspace(5) %"88", align 4
store i64 %"101", ptr addrspace(5) %"100", align 4
%"103" = getelementptr inbounds i8, ptr addrspace(5) %"102", i64 0
%"104" = load i32, ptr addrspace(5) %"94", align 4
store i32 %"104", ptr addrspace(5) %"103", align 4
%"106" = getelementptr inbounds i8, ptr addrspace(5) %"105", i64 0
%"107" = load i64, ptr addrspace(5) %"88", align 4
store i64 %"107", ptr addrspace(5) %"106", align 4
%"109" = getelementptr inbounds i8, ptr addrspace(5) %"108", i64 0
%"110" = load i64, ptr addrspace(5) %"88", align 4
store i64 %"110", ptr addrspace(5) %"109", align 4
%"74" = load i64, ptr addrspace(5) %"96", align 4
%"75" = load i64, ptr addrspace(5) %"99", align 4
%"76" = load i32, ptr addrspace(5) %"102", align 4
%"77" = load i64, ptr addrspace(5) %"105", align 4
%"78" = load i64, ptr addrspace(5) %"108", align 4
call void @__zluda_ptx_impl___assertfail(i64 %"74", i64 %"75", i32 %"76", i64 %"77", i64 %"78")
br label %"85"
"85": ; preds = %"84"
%"112" = load i64, ptr addrspace(5) %"88", align 4
%"122" = inttoptr i64 %"112" to ptr
%"111" = load i64, ptr %"122", align 4
store i64 %"111", ptr addrspace(5) %"90", align 4
%"114" = load i64, ptr addrspace(5) %"90", align 4
%"113" = add i64 %"114", 1
store i64 %"113", ptr addrspace(5) %"91", align 4
%"115" = load i64, ptr addrspace(5) %"89", align 4
%"116" = load i64, ptr addrspace(5) %"91", align 4
%"123" = inttoptr i64 %"115" to ptr
store i64 %"116", ptr %"123", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -0,0 +1,43 @@
declare [16 x i8] @foobar(i64) #0
define amdgpu_kernel void @extern_func(ptr addrspace(4) byref(i64) %"44", ptr addrspace(4) byref(i64) %"45") #1 {
%"46" = alloca i64, align 8, addrspace(5)
%"47" = alloca i64, align 8, addrspace(5)
%"48" = alloca i64, align 8, addrspace(5)
%"49" = alloca i64, align 8, addrspace(5)
%"54" = alloca i64, align 8, addrspace(5)
%"57" = alloca [16 x i8], align 16, addrspace(5)
br label %1
1: ; preds = %0
br label %"41"
"41": ; preds = %1
%"50" = load i64, ptr addrspace(4) %"44", align 4
store i64 %"50", ptr addrspace(5) %"46", align 4
%"51" = load i64, ptr addrspace(4) %"45", align 4
store i64 %"51", ptr addrspace(5) %"47", align 4
%"53" = load i64, ptr addrspace(5) %"46", align 4
%"61" = inttoptr i64 %"53" to ptr addrspace(1)
%"52" = load i64, ptr addrspace(1) %"61", align 4
store i64 %"52", ptr addrspace(5) %"48", align 4
%"55" = getelementptr inbounds i8, ptr addrspace(5) %"54", i64 0
%"56" = load i64, ptr addrspace(5) %"48", align 4
store i64 %"56", ptr addrspace(5) %"55", align 4
%"39" = load i64, ptr addrspace(5) %"54", align 4
%"40" = call [16 x i8] @foobar(i64 %"39")
br label %"42"
"42": ; preds = %"41"
store [16 x i8] %"40", ptr addrspace(5) %"57", align 1
%"58" = load i64, ptr addrspace(5) %"57", align 4
store i64 %"58", ptr addrspace(5) %"49", align 4
%"59" = load i64, ptr addrspace(5) %"47", align 4
%"60" = load i64, ptr addrspace(5) %"49", align 4
%"64" = inttoptr i64 %"59" to ptr
store i64 %"60", ptr %"64", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -0,0 +1,43 @@
declare i32 @__zluda_ptx_impl_sreg_lanemask_lt() #0
define amdgpu_kernel void @lanemask_lt(ptr addrspace(4) byref(i64) %"36", ptr addrspace(4) byref(i64) %"37") #1 {
%"38" = alloca i64, align 8, addrspace(5)
%"39" = alloca i64, align 8, addrspace(5)
%"40" = alloca i32, align 4, addrspace(5)
%"41" = alloca i32, align 4, addrspace(5)
%"42" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"33"
"33": ; preds = %1
%"43" = load i64, ptr addrspace(4) %"36", align 4
store i64 %"43", ptr addrspace(5) %"38", align 4
%"44" = load i64, ptr addrspace(4) %"37", align 4
store i64 %"44", ptr addrspace(5) %"39", align 4
%"46" = load i64, ptr addrspace(5) %"38", align 4
%"56" = inttoptr i64 %"46" to ptr
%"55" = load i32, ptr %"56", align 4
store i32 %"55", ptr addrspace(5) %"40", align 4
%"48" = load i32, ptr addrspace(5) %"40", align 4
%"57" = add i32 %"48", 1
store i32 %"57", ptr addrspace(5) %"41", align 4
%"31" = call i32 @__zluda_ptx_impl_sreg_lanemask_lt()
br label %"34"
"34": ; preds = %"33"
store i32 %"31", ptr addrspace(5) %"42", align 4
%"51" = load i32, ptr addrspace(5) %"41", align 4
%"52" = load i32, ptr addrspace(5) %"42", align 4
%"60" = add i32 %"51", %"52"
store i32 %"60", ptr addrspace(5) %"41", align 4
%"53" = load i64, ptr addrspace(5) %"39", align 4
%"54" = load i32, ptr addrspace(5) %"41", align 4
%"63" = inttoptr i64 %"53" to ptr
store i32 %"54", ptr %"63", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -10,32 +10,65 @@ use std::fmt::{self, Debug, Display, Formatter};
use std::fs::{self, File};
use std::io::Write;
use std::mem;
use std::path::Path;
use std::path::{Path, PathBuf};
use std::ptr;
use std::str;
#[cfg(not(feature = "ci_build"))]
macro_rules! read_test_file {
($file:expr) => {
{
// CARGO_MANIFEST_DIR is the crate directory (ptx), but file! is relative to the workspace root (and therefore also includes ptx).
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.pop();
path.push(file!());
path.pop();
path.push($file);
std::fs::read_to_string(path).unwrap()
}
};
}
#[cfg(feature = "ci_build")]
macro_rules! read_test_file {
($file:expr) => {
include_str!($file).to_string()
};
}
macro_rules! test_ptx {
($fn_name:ident, $input:expr, $output:expr) => {
paste::item! {
#[test]
fn [<$fn_name _hip>]() -> Result<(), Box<dyn std::error::Error>> {
let ptx = include_str!(concat!(stringify!($fn_name), ".ptx"));
let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx"));
let input = $input;
let mut output = $output;
test_hip_assert(stringify!($fn_name), ptx, &input, &mut output)
test_hip_assert(stringify!($fn_name), &ptx, &input, &mut output)
}
}
paste::item! {
#[test]
fn [<$fn_name _cuda>]() -> Result<(), Box<dyn std::error::Error>> {
let ptx = include_str!(concat!(stringify!($fn_name), ".ptx"));
let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx"));
let input = $input;
let mut output = $output;
test_cuda_assert(stringify!($fn_name), ptx, &input, &mut output)
test_cuda_assert(stringify!($fn_name), &ptx, &input, &mut output)
}
}
paste::item! {
#[test]
fn [<$fn_name _llvm>]() -> Result<(), Box<dyn std::error::Error>> {
let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx"));
let ll = read_test_file!(concat!("../ll/", stringify!($fn_name), ".ll"));
test_llvm_assert(stringify!($fn_name), &ptx, ll.trim())
}
}
};
($fn_name:ident) => {
paste::item! {
#[test]
fn [<$fn_name _llvm>]() -> Result<(), Box<dyn std::error::Error>> {
@ -45,8 +78,6 @@ macro_rules! test_ptx {
}
}
};
($fn_name:ident) => {};
}
test_ptx!(ld_st, [1u64], [1u64]);
@ -242,7 +273,8 @@ test_ptx!(
);
test_ptx!(assertfail);
test_ptx!(func_ptr);
// TODO: not yet supported
//test_ptx!(func_ptr);
test_ptx!(lanemask_lt);
test_ptx!(extern_func);
@ -265,15 +297,14 @@ impl<T: Debug> Debug for DisplayError<T> {
impl<T: Debug> error::Error for DisplayError<T> {}
fn test_hip_assert<
'a,
Input: From<u8> + Debug + Copy + PartialEq,
Output: From<u8> + Debug + Copy + PartialEq + Default,
>(
name: &str,
ptx_text: &'a str,
ptx_text: &str,
input: &[Input],
output: &mut [Output],
) -> Result<(), Box<dyn error::Error + 'a>> {
) -> Result<(), Box<dyn error::Error>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
let llvm_ir = pass::to_llvm_module(ast).unwrap();
let name = CString::new(name)?;
@ -283,11 +314,11 @@ fn test_hip_assert<
Ok(())
}
fn test_llvm_assert<'a>(
fn test_llvm_assert(
name: &str,
ptx_text: &'a str,
ptx_text: &str,
expected_ll: &str,
) -> Result<(), Box<dyn error::Error + 'a>> {
) -> Result<(), Box<dyn error::Error>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
let llvm_ir = pass::to_llvm_module(ast).unwrap();
let actual_ll = llvm_ir.llvm_ir.print_module_to_string();
@ -301,22 +332,21 @@ fn test_llvm_assert<'a>(
let mut output_file = File::create(output_file).unwrap();
output_file.write_all(actual_ll.as_bytes()).unwrap();
}
let comparison = pretty_assertions::StrComparison::new(expected_ll, actual_ll);
let comparison = pretty_assertions::StrComparison::new(&expected_ll, &actual_ll);
panic!("assertion failed: `(left == right)`\n\n{}", comparison);
}
Ok(())
}
fn test_cuda_assert<
'a,
Input: From<u8> + Debug + Copy + PartialEq,
Output: From<u8> + Debug + Copy + PartialEq + Default,
>(
name: &str,
ptx_text: &'a str,
ptx_text: &str,
input: &[Input],
output: &mut [Output],
) -> Result<(), Box<dyn error::Error + 'a>> {
) -> Result<(), Box<dyn error::Error>> {
let name = CString::new(name)?;
let result = run_cuda(name.as_c_str(), ptx_text, input, output);
assert_eq!(result.as_slice(), output);

View File

@ -1492,6 +1492,46 @@ pub struct TokenError(std::ops::Range<usize>);
impl std::error::Error for TokenError {}
fn first_optional<
'a,
'input,
Input: Stream,
OptionalOutput,
RequiredOutput,
Error,
ParseOptional,
ParseRequired,
>(
mut optional: ParseOptional,
mut required: ParseRequired,
) -> impl Parser<Input, (Option<OptionalOutput>, RequiredOutput), Error>
where
ParseOptional: Parser<Input, OptionalOutput, Error>,
ParseRequired: Parser<Input, RequiredOutput, Error>,
Error: ParserError<Input>,
{
move |input: &mut Input| -> Result<(Option<OptionalOutput>, RequiredOutput), ErrMode<Error>> {
let start = input.checkpoint();
let parsed_optional = match optional.parse_next(input) {
Ok(v) => Some(v),
Err(ErrMode::Backtrack(_)) => {
input.reset(&start);
None
},
Err(e) => return Err(e)
};
match required.parse_next(input) {
Ok(v) => return Ok((parsed_optional, v)),
Err(ErrMode::Backtrack(_)) => input.reset(&start),
Err(e) => return Err(e)
};
Ok((None, required.parse_next(input)?))
}
}
// This macro is responsible for generating parser code for instruction parser.
// Instruction parsing is by far the most complex part of parsing PTX code:
// * There are tens of instruction kinds, each with slightly different parsing rules
@ -3413,6 +3453,7 @@ derive_parser!(
#[cfg(test)]
mod tests {
use crate::first_optional;
use crate::parse_module_checked;
use crate::PtxError;
@ -3423,6 +3464,55 @@ mod tests {
use logos::Span;
use winnow::prelude::*;
#[test]
fn first_optional_present() {
let text = "AB";
let result = first_optional::<_, _, _, (), _, _>('A', 'B').parse(text);
assert_eq!(result, Ok((Some('A'), 'B')));
}
#[test]
fn first_optional_absent() {
let text = "B";
let result = first_optional::<_, _, _, (), _, _>('A', 'B').parse(text);
assert_eq!(result, Ok((None, 'B')));
}
#[test]
fn first_optional_repeated_absent() {
let text = "A";
let result = first_optional::<_, _, _, (), _, _>('A', 'A').parse(text);
assert_eq!(result, Ok((None, 'A')));
}
#[test]
fn first_optional_repeated_present() {
let text = "AA";
let result = first_optional::<_, _, _, (), _, _>('A', 'A').parse(text);
assert_eq!(result, Ok((Some('A'), 'A')));
}
#[test]
fn first_optional_sequence_absent() {
let text = "AA";
let result = ('A', first_optional::<_, _, _, (), _, _>('A', 'A')).parse(text);
assert_eq!(result, Ok(('A', (None, 'A'))));
}
#[test]
fn first_optional_sequence_present() {
let text = "AAA";
let result = ('A', first_optional::<_, _, _, (), _, _>('A', 'A')).parse(text);
assert_eq!(result, Ok(('A', (Some('A'), 'A'))));
}
#[test]
fn first_optional_no_match() {
let text = "C";
let result = first_optional::<_, _, _, (), _, _>('A', 'B').parse(text);
assert!(result.is_err());
}
#[test]
fn sm_11() {
let text = ".target sm_11";

View File

@ -757,12 +757,13 @@ fn emit_definition_parser(
DotModifierRef::Direct { optional: true, .. }
| DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(),
});
let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| {
let (arguments_pattern, arguments_parser) = definition.arguments.0.iter().enumerate().rfold((quote! { () }, quote! { empty }), |(emitted_pattern, emitted_parser), (idx, arg)| {
let comma = if idx == 0 || arg.pre_pipe {
quote! { empty }
} else {
quote! { any.verify(|(t, _)| *t == #token_type::Comma).void() }
};
let pre_bracket = if arg.pre_bracket {
quote! {
any.verify(|(t, _)| *t == #token_type::LBracket).void()
@ -833,16 +834,20 @@ fn emit_definition_parser(
#pattern.map(|(_, _, _, _, name, _, _)| name)
}
};
if arg.optional {
quote! {
let #arg_name = opt(#inner_parser).parse_next(stream)?;
}
let parser = if arg.optional {
quote! { first_optional(#inner_parser, #emitted_parser) }
} else {
quote! {
let #arg_name = #inner_parser.parse_next(stream)?;
}
}
quote! { (#inner_parser, #emitted_parser) }
};
let pattern = quote! { ( #arg_name, #emitted_pattern ) };
(pattern, parser)
});
let arguments_parse = quote! { let #arguments_pattern = ( #arguments_parser ).parse_next(stream)?; };
let fn_args = definition.function_arguments();
let fn_name = format_ident!("{}_{}", opcode, fn_idx);
let fn_call = quote! {
@ -863,7 +868,7 @@ fn emit_definition_parser(
}
}
#(#unordered_parse_validations)*
#(#arguments_parse)*
#arguments_parse
#fn_call
}
}

View File

@ -22,14 +22,23 @@ impl Parse for ParseDefinitions {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let token_type = input.parse::<ItemEnum>()?;
let mut additional_enums = FxHashMap::default();
while input.peek(Token![#]) {
let enum_ = input.parse::<ItemEnum>()?;
additional_enums.insert(enum_.ident.clone(), enum_);
}
let mut definitions = Vec::new();
while !input.is_empty() {
definitions.push(input.parse::<OpcodeDefinition>()?);
loop {
if input.is_empty() {
break;
}
let lookahead = input.lookahead1();
if lookahead.peek(Token![#]) {
let enum_ = input.parse::<ItemEnum>()?;
additional_enums.insert(enum_.ident.clone(), enum_);
} else if lookahead.peek(Ident) {
definitions.push(input.parse::<OpcodeDefinition>()?);
} else {
return Err(lookahead.error());
}
}
Ok(Self {
token_type,
additional_enums,

View File

@ -10,6 +10,10 @@ use std::{
usize,
};
#[cfg_attr(windows, path = "os_win.rs")]
#[cfg_attr(not(windows), path = "os_unix.rs")]
mod os;
pub(crate) struct GlobalState {
pub devices: Vec<Device>,
pub comgr: Comgr,
@ -232,7 +236,31 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
unix_seconds: u64,
result: *mut [u64; 2],
) -> cuda_types::cuda::CUresult {
todo!()
let current_process = std::process::id();
let current_thread = os::current_thread();
let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast();
let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast();
let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1];
let devices = get_device_hash_info()?;
let device_count = devices.len() as u32;
let get_device = |dev| devices[dev as usize];
let hash = ::dark_api::integrity_check(
version,
unix_seconds,
cuda_types::cuda::CUDA_VERSION,
current_process,
current_thread,
integrity_check_table,
cudart_table,
fn_address,
device_count,
get_device,
);
*result = hash;
Ok(())
}
unsafe extern "system" fn context_check(
@ -244,10 +272,50 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
}
unsafe extern "system" fn check_fn3() -> u32 {
todo!()
0
}
}
fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
let mut device_count = 0;
device::get_count(&mut device_count)?;
(0..device_count)
.map(|dev| {
let mut guid = CUuuid_st { bytes: [0; 16] };
unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? };
let mut pci_domain = 0;
device::get_attribute(
&mut pci_domain,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID,
dev,
)?;
let mut pci_bus = 0;
device::get_attribute(
&mut pci_bus,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID,
dev,
)?;
let mut pci_device = 0;
device::get_attribute(
&mut pci_device,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID,
dev,
)?;
Ok(::dark_api::DeviceHashinfo {
guid,
pci_domain,
pci_bus,
pci_device,
})
})
.collect()
}
static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable =
::dark_api::cuda::CudaDarkApiGlobalTable::new::<DarkApi>();

38
zluda/src/impl/library.rs Normal file
View File

@ -0,0 +1,38 @@
use super::module;
use super::ZludaObject;
use cuda_types::cuda::*;
use hip_runtime_sys::*;
pub(crate) struct Library {
base: hipModule_t,
}
impl ZludaObject for Library {
const COOKIE: usize = 0xb328a916cc234d7c;
type CudaHandle = CUlibrary;
fn drop_checked(&mut self) -> CUresult {
// TODO: we will want to test that we handle `cuModuleUnload` on a module that came from a library correctly, without calling `hipModuleUnload` twice.
unsafe { hipModuleUnload(self.base) }?;
Ok(())
}
}
/// This implementation simply loads the code as a HIP module for now. The various JIT and library options are ignored.
pub(crate) fn load_data(
library: &mut CUlibrary,
code: *const ::core::ffi::c_void,
_jit_options: &mut CUjit_option,
_jit_options_values: &mut *mut ::core::ffi::c_void,
_num_jit_options: ::core::ffi::c_uint,
_library_options: &mut CUlibraryOption,
_library_option_values: &mut *mut ::core::ffi::c_void,
_num_library_options: ::core::ffi::c_uint,
) -> CUresult {
let hip_module = module::load_hip_module(code)?;
*library = Library { base: hip_module }.wrap();
Ok(())
}

View File

@ -10,6 +10,7 @@ pub(super) mod context;
pub(super) mod device;
pub(super) mod driver;
pub(super) mod function;
pub(super) mod library;
pub(super) mod memory;
pub(super) mod module;
pub(super) mod pointer;
@ -135,6 +136,9 @@ from_cuda_nop!(
cuda_types::cuda::CUdevprop,
CUdevice_attribute,
CUdriverProcAddressQueryResult,
CUjit_option,
CUlibrary,
CUlibraryOption,
CUmoduleLoadingMode,
CUuuid
);
@ -169,6 +173,15 @@ impl<'a> FromCuda<'a, *const ::core::ffi::c_char> for &CStr {
}
}
impl<'a> FromCuda<'a, *const ::core::ffi::c_void> for &'a ::core::ffi::c_void {
fn from_cuda(x: &'a *const ::core::ffi::c_void) -> Result<Self, CUerror> {
match unsafe { x.as_ref() } {
Some(x) => Ok(x),
None => Err(CUerror::INVALID_VALUE),
}
}
}
pub(crate) trait ZludaObject: Sized + Send + Sync {
const COOKIE: usize;
const LIVENESS_FAIL: CUerror = cuda_types::cuda::CUerror::INVALID_VALUE;

View File

@ -1,5 +1,9 @@
use super::{driver, ZludaObject};
use cuda_types::cuda::*;
use cuda_types::{
cuda::*,
dark_api::{FatbinFileHeader, FatbincWrapper},
};
use dark_api::fatbin::Fatbin;
use hip_runtime_sys::*;
use std::{ffi::CStr, mem};
@ -18,12 +22,48 @@ impl ZludaObject for Module {
}
}
pub(crate) fn load_data(module: &mut CUmodule, image: *const std::ffi::c_void) -> CUresult {
fn get_ptx_from_wrapped_fatbin(image: *const ::core::ffi::c_void) -> Result<Vec<u8>, CUerror> {
let fatbin = Fatbin::new(&image).map_err(|_| CUerror::UNKNOWN)?;
let mut submodules = fatbin.get_submodules().map_err(|_| CUerror::UNKNOWN)?;
while let Some(current) = unsafe { submodules.next().map_err(|_| CUerror::UNKNOWN)? } {
let mut files = current.get_files();
while let Some(file) = unsafe { files.next().map_err(|_| CUerror::UNKNOWN)? } {
if file.header.kind == FatbinFileHeader::HEADER_KIND_PTX {
let decompressed = unsafe { file.decompress() }.map_err(|_| CUerror::UNKNOWN)?;
return Ok(decompressed);
}
}
}
Err(CUerror::NO_BINARY_FOR_GPU)
}
/// get_ptx takes an `image` that can be either a fatbin or a NULL-terminated ptx, and returns a String containing a ptx extracted from `image`.
fn get_ptx(image: *const ::core::ffi::c_void) -> Result<String, CUerror> {
if image.is_null() {
return Err(CUerror::INVALID_VALUE);
}
let ptx = if unsafe { *(image as *const u32) } == FatbincWrapper::MAGIC {
let ptx_bytes = get_ptx_from_wrapped_fatbin(image)?;
str::from_utf8(&ptx_bytes)
.map_err(|_| CUerror::UNKNOWN)?
.to_owned()
} else {
unsafe { CStr::from_ptr(image.cast()) }
.to_str()
.map_err(|_| CUerror::INVALID_VALUE)?
.to_owned()
};
Ok(ptx)
}
pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result<hipModule_t, CUerror> {
let global_state = driver::global_state()?;
let text = unsafe { CStr::from_ptr(image.cast()) }
.to_str()
.map_err(|_| CUerror::INVALID_VALUE)?;
let ast = ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
let text = get_ptx(image)?;
let ast = ptx_parser::parse_module_checked(&text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
let llvm_module = ptx::to_llvm_module(ast).map_err(|_| CUerror::UNKNOWN)?;
let mut dev = 0;
unsafe { hipCtxGetDevice(&mut dev) }?;
@ -38,6 +78,11 @@ pub(crate) fn load_data(module: &mut CUmodule, image: *const std::ffi::c_void) -
.map_err(|_| CUerror::UNKNOWN)?;
let mut hip_module = unsafe { mem::zeroed() };
unsafe { hipModuleLoadData(&mut hip_module, elf_module.as_ptr().cast()) }?;
Ok(hip_module)
}
pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUresult {
let hip_module = load_hip_module(image)?;
*module = Module { base: hip_module }.wrap();
Ok(())
}

View File

@ -11,3 +11,13 @@ pub unsafe fn heap_alloc(_heap: *mut c_void, _bytes: usize) -> *mut c_void {
pub unsafe fn heap_free(_heap: *mut c_void, _alloc: *mut c_void) {
todo!()
}
// TODO: remove duplication with zluda_dump
#[link(name = "pthread")]
unsafe extern "C" {
fn pthread_self() -> std::os::unix::thread::RawPthread;
}
pub(crate) fn current_thread() -> u32 {
(unsafe { pthread_self() }) as u32
}

View File

@ -14,3 +14,13 @@ pub unsafe fn heap_alloc(heap: *mut c_void, bytes: usize) -> *mut c_void {
pub unsafe fn heap_free(heap: *mut c_void, alloc: *mut c_void) {
HeapFree(heap, 0, alloc);
}
// TODO: remove duplication with zluda_dump
#[link(name = "kernel32")]
unsafe extern "system" {
fn GetCurrentThreadId() -> u32;
}
pub(crate) fn current_thread() -> u32 {
unsafe { GetCurrentThreadId() }
}

View File

@ -66,6 +66,7 @@ cuda_base::cuda_function_declarations!(
cuGetProcAddress,
cuGetProcAddress_v2,
cuInit,
cuLibraryLoadData,
cuMemAlloc_v2,
cuMemFree_v2,
cuMemGetAddressRange_v2,
@ -84,4 +85,4 @@ cuda_base::cuda_function_declarations!(
implemented_in_function <= [
cuLaunchKernel,
]
);
);

View File

@ -14,7 +14,6 @@ ptx_parser = { path = "../ptx_parser" }
zluda_dump_common = { path = "../zluda_dump_common" }
format = { path = "../format" }
dark_api = { path = "../dark_api" }
lz4-sys = "1.9"
regex = "1.4"
dynasm = "1.2"
dynasmrt = "1.2"

View File

@ -1,3 +1,4 @@
use ::dark_api::fatbin::FatbinFileIterator;
use ::dark_api::FnFfi;
use cuda_types::cuda::*;
use dark_api::DarkApiState2;
@ -360,7 +361,16 @@ impl DarkApiDump {
});
}
fn_logger.try_(|fn_logger| unsafe {
trace::record_submodules_from_fatbin(*module, fatbin_header, fn_logger, state)
trace::record_submodules(
*module,
fn_logger,
state,
FatbinFileIterator::new(
fatbin_header
.as_ref()
.ok_or(ErrorEntry::NullPointer("get_module_from_cubin_ext2_post"))?,
),
)
});
}
}

View File

@ -308,11 +308,11 @@ pub(crate) enum ErrorEntry {
unsafe impl Send for ErrorEntry {}
unsafe impl Sync for ErrorEntry {}
impl From<cuda_types::dark_api::ParseError> for ErrorEntry {
fn from(e: cuda_types::dark_api::ParseError) -> Self {
impl From<dark_api::fatbin::ParseError> for ErrorEntry {
fn from(e: dark_api::fatbin::ParseError) -> Self {
match e {
cuda_types::dark_api::ParseError::NullPointer(s) => ErrorEntry::NullPointer(s),
cuda_types::dark_api::ParseError::UnexpectedBinaryField {
dark_api::fatbin::ParseError::NullPointer(s) => ErrorEntry::NullPointer(s),
dark_api::fatbin::ParseError::UnexpectedBinaryField {
field_name,
observed,
expected,
@ -325,6 +325,20 @@ impl From<cuda_types::dark_api::ParseError> for ErrorEntry {
}
}
impl From<dark_api::fatbin::FatbinError> for ErrorEntry {
fn from(e: dark_api::fatbin::FatbinError) -> Self {
match e {
dark_api::fatbin::FatbinError::ParseFailure(parse_error) => parse_error.into(),
dark_api::fatbin::FatbinError::Lz4DecompressionFailure => {
ErrorEntry::Lz4DecompressionFailure
}
dark_api::fatbin::FatbinError::ZstdDecompressionFailure(c) => {
ErrorEntry::ZstdDecompressionFailure(c)
}
}
}
}
impl Display for ErrorEntry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {

View File

@ -4,7 +4,10 @@ use crate::{
};
use cuda_types::{
cuda::*,
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbinHeader, FatbincWrapper},
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper},
};
use dark_api::fatbin::{
decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule,
};
use rustc_hash::{FxHashMap, FxHashSet};
use std::{
@ -13,7 +16,6 @@ use std::{
fs::{self, File},
io::{self, Read, Write},
path::PathBuf,
ptr,
};
use unwrap_or::unwrap_some_or;
@ -259,52 +261,53 @@ pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
fn_logger: &mut FnCallLog,
state: &mut StateTracker,
) -> Result<(), ErrorEntry> {
let fatbinc_wrapper = FatbincWrapper::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
let is_version_2 = fatbinc_wrapper.version == FatbincWrapper::VERSION_V2;
record_submodules_from_fatbin(module, (*fatbinc_wrapper).data, fn_logger, state)?;
if is_version_2 {
let mut current = (*fatbinc_wrapper).filename_or_fatbins as *const *const c_void;
while *current != ptr::null() {
record_submodules_from_fatbin(module, *current as *const _, fn_logger, state)?;
current = current.add(1);
}
let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
let mut submodules = fatbin.get_submodules()?;
while let Some(current) = submodules.next()? {
record_submodules_from_fatbin(module, current, fn_logger, state)?;
}
Ok(())
}
pub(crate) unsafe fn record_submodules_from_fatbin(
module: CUmodule,
fatbin_header: *const FatbinHeader,
submodule: FatbinSubmodule,
logger: &mut FnCallLog,
state: &mut StateTracker,
) -> Result<(), ErrorEntry> {
let header = FatbinHeader::new(&fatbin_header).map_err(ErrorEntry::from)?;
let file = header.get_content();
record_submodules(module, logger, state, file)?;
record_submodules(module, logger, state, submodule.get_files())?;
Ok(())
}
unsafe fn record_submodules(
pub(crate) unsafe fn record_submodules(
module: CUmodule,
fn_logger: &mut FnCallLog,
state: &mut StateTracker,
mut file_buffer: &[u8],
mut files: FatbinFileIterator,
) -> Result<(), ErrorEntry> {
while let Some(file) = FatbinFileHeader::next(&mut file_buffer)? {
let mut payload = if file.flags.contains(FatbinFileHeaderFlags::CompressedLz4) {
while let Some(file) = files.next()? {
let mut payload = if file
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedLz4)
{
Cow::Owned(unwrap_some_or!(
fn_logger.try_return(|| decompress_lz4(file)),
fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())),
continue
))
} else if file.flags.contains(FatbinFileHeaderFlags::CompressedZstd) {
} else if file
.header
.flags
.contains(FatbinFileHeaderFlags::CompressedZstd)
{
Cow::Owned(unwrap_some_or!(
fn_logger.try_return(|| decompress_zstd(file)),
fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())),
continue
))
} else {
Cow::Borrowed(file.get_payload())
};
match file.kind {
match file.header.kind {
FatbinFileHeader::HEADER_KIND_PTX => {
while payload.last() == Some(&0) {
// remove trailing zeros
@ -322,50 +325,10 @@ unsafe fn record_submodules(
UInt::U16(FatbinFileHeader::HEADER_KIND_PTX),
UInt::U16(FatbinFileHeader::HEADER_KIND_ELF),
],
observed: UInt::U16(file.kind),
observed: UInt::U16(file.header.kind),
});
}
}
}
Ok(())
}
const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024;
unsafe fn decompress_lz4(file: &FatbinFileHeader) -> Result<Vec<u8>, ErrorEntry> {
let decompressed_size = usize::max(1024, (*file).uncompressed_payload as usize);
let mut decompressed_vec = vec![0u8; decompressed_size];
loop {
match lz4_sys::LZ4_decompress_safe(
file.get_payload().as_ptr() as *const _,
decompressed_vec.as_mut_ptr() as *mut _,
(*file).payload_size as _,
decompressed_vec.len() as _,
) {
error if error < 0 => {
let new_size = decompressed_vec.len() * 2;
if new_size > MAX_MODULE_DECOMPRESSION_BOUND {
return Err(ErrorEntry::Lz4DecompressionFailure);
}
decompressed_vec.resize(decompressed_vec.len() * 2, 0);
}
real_decompressed_size => {
decompressed_vec.truncate(real_decompressed_size as usize);
return Ok(decompressed_vec);
}
}
}
}
unsafe fn decompress_zstd(file: &FatbinFileHeader) -> Result<Vec<u8>, ErrorEntry> {
let mut result = Vec::with_capacity(file.uncompressed_payload as usize);
let payload = file.get_payload();
dbg!((payload.len(), file.uncompressed_payload, file.payload_size));
match zstd_safe::decompress(&mut result, payload) {
Ok(actual_size) => {
result.truncate(actual_size);
Ok(result)
}
Err(err) => Err(ErrorEntry::ZstdDecompressionFailure(err)),
}
}