mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 14:57:43 +03:00
Merge commit 'b8244243679bb02c3c38cf8ab00c2488a51a7741' into minmax
This commit is contained in:
48
Cargo.lock
generated
48
Cargo.lock
generated
@ -327,8 +327,10 @@ dependencies = [
|
||||
"cglue",
|
||||
"cuda_types",
|
||||
"format",
|
||||
"lz4-sys",
|
||||
"paste",
|
||||
"uuid",
|
||||
"zstd-safe",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -494,6 +496,18 @@ dependencies = [
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"r-efi",
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.1"
|
||||
@ -582,10 +596,11 @@ checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
version = "0.1.32"
|
||||
version = "0.1.33"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0"
|
||||
checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"libc",
|
||||
]
|
||||
|
||||
@ -759,9 +774,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "microlp"
|
||||
version = "0.2.10"
|
||||
version = "0.2.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "edaa5264bc1f7668bc12e10757f8f529a526656c796cc2106cf2be10c5b8d483"
|
||||
checksum = "51d1790c73b93164ff65868f63164497cb32339458a9297e17e212d91df62258"
|
||||
dependencies = [
|
||||
"log",
|
||||
"sprs",
|
||||
@ -1118,6 +1133,12 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "r-efi"
|
||||
version = "5.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
||||
|
||||
[[package]]
|
||||
name = "rawpointer"
|
||||
version = "0.2.1"
|
||||
@ -1516,6 +1537,15 @@ version = "0.9.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.14.2+wasi-0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
|
||||
dependencies = [
|
||||
"wit-bindgen-rt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wchar"
|
||||
version = "0.6.1"
|
||||
@ -1660,6 +1690,15 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen-rt"
|
||||
version = "0.39.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xattr"
|
||||
version = "1.5.0"
|
||||
@ -1783,7 +1822,6 @@ dependencies = [
|
||||
"format",
|
||||
"goblin",
|
||||
"libc",
|
||||
"lz4-sys",
|
||||
"parking_lot",
|
||||
"paste",
|
||||
"ptx",
|
||||
|
@ -199,7 +199,8 @@ impl VisitMut for FixFnSignatures {
|
||||
}
|
||||
|
||||
const MODULES: &[&str] = &[
|
||||
"context", "device", "driver", "function", "link", "memory", "module", "pointer", "stream",
|
||||
"context", "device", "driver", "function", "library", "link", "memory", "module", "pointer",
|
||||
"stream",
|
||||
];
|
||||
|
||||
#[proc_macro]
|
||||
|
@ -78,122 +78,17 @@ bitflags! {
|
||||
|
||||
impl FatbincWrapper {
|
||||
pub const MAGIC: c_uint = 0x466243B1;
|
||||
const VERSION_V1: c_uint = 0x1;
|
||||
pub const VERSION_V1: c_uint = 0x1;
|
||||
pub const VERSION_V2: c_uint = 0x2;
|
||||
|
||||
pub fn new<'a, T: Sized>(ptr: &*const T) -> Result<&'a Self, ParseError> {
|
||||
unsafe { ptr.cast::<Self>().as_ref() }
|
||||
.ok_or(ParseError::NullPointer("FatbincWrapper"))
|
||||
.and_then(|ptr| {
|
||||
ParseError::check_fields("FATBINC_MAGIC", ptr.magic, [Self::MAGIC])?;
|
||||
ParseError::check_fields(
|
||||
"FATBINC_VERSION",
|
||||
ptr.version,
|
||||
[Self::VERSION_V1, Self::VERSION_V2],
|
||||
)?;
|
||||
Ok(ptr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FatbinHeader {
|
||||
const MAGIC: c_uint = 0xBA55ED50;
|
||||
const VERSION: c_ushort = 0x01;
|
||||
|
||||
pub fn new<'a, T: Sized>(ptr: &'a *const T) -> Result<&'a Self, ParseError> {
|
||||
unsafe { ptr.cast::<Self>().as_ref() }
|
||||
.ok_or(ParseError::NullPointer("FatbinHeader"))
|
||||
.and_then(|ptr| {
|
||||
ParseError::check_fields("FATBIN_MAGIC", ptr.magic, [Self::MAGIC])?;
|
||||
ParseError::check_fields("FATBIN_VERSION", ptr.version, [Self::VERSION])?;
|
||||
Ok(ptr)
|
||||
})
|
||||
}
|
||||
|
||||
pub unsafe fn get_content<'a>(&'a self) -> &'a [u8] {
|
||||
let start = std::ptr::from_ref(self)
|
||||
.cast::<u8>()
|
||||
.add(self.header_size as usize);
|
||||
std::slice::from_raw_parts(start, self.files_size as usize)
|
||||
}
|
||||
pub const MAGIC: c_uint = 0xBA55ED50;
|
||||
pub const VERSION: c_ushort = 0x01;
|
||||
}
|
||||
|
||||
impl FatbinFileHeader {
|
||||
pub const HEADER_KIND_PTX: c_ushort = 0x01;
|
||||
pub const HEADER_KIND_ELF: c_ushort = 0x02;
|
||||
const HEADER_VERSION_CURRENT: c_ushort = 0x101;
|
||||
|
||||
pub fn new_ptx<T: Sized>(ptr: *const T) -> Result<Option<&'static Self>, ParseError> {
|
||||
unsafe { ptr.cast::<Self>().as_ref() }
|
||||
.ok_or(ParseError::NullPointer("FatbinFileHeader"))
|
||||
.and_then(|ptr| {
|
||||
ParseError::check_fields(
|
||||
"FATBIN_FILE_HEADER_VERSION_CURRENT",
|
||||
ptr.version,
|
||||
[Self::HEADER_VERSION_CURRENT],
|
||||
)?;
|
||||
match ptr.kind {
|
||||
Self::HEADER_KIND_PTX => Ok(Some(ptr)),
|
||||
Self::HEADER_KIND_ELF => Ok(None),
|
||||
_ => Err(ParseError::UnexpectedBinaryField {
|
||||
field_name: "FATBIN_FILE_HEADER_KIND",
|
||||
observed: ptr.kind.into(),
|
||||
expected: vec![Self::HEADER_KIND_PTX.into(), Self::HEADER_KIND_ELF.into()],
|
||||
}),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub unsafe fn next<'a>(slice: &'a mut &[u8]) -> Result<Option<&'a Self>, ParseError> {
|
||||
if slice.len() < std::mem::size_of::<Self>() {
|
||||
return Ok(None);
|
||||
}
|
||||
let this = &*slice.as_ptr().cast::<Self>();
|
||||
let next_element = slice
|
||||
.split_at_checked(this.header_size as usize + this.padded_payload_size as usize)
|
||||
.map(|(_, next)| next);
|
||||
*slice = next_element.unwrap_or(&[]);
|
||||
ParseError::check_fields(
|
||||
"FATBIN_FILE_HEADER_VERSION_CURRENT",
|
||||
this.version,
|
||||
[Self::HEADER_VERSION_CURRENT],
|
||||
)?;
|
||||
Ok(Some(this))
|
||||
}
|
||||
|
||||
pub unsafe fn get_payload<'a>(&'a self) -> &'a [u8] {
|
||||
let start = std::ptr::from_ref(self)
|
||||
.cast::<u8>()
|
||||
.add(self.header_size as usize);
|
||||
std::slice::from_raw_parts(start, self.payload_size as usize)
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ParseError {
|
||||
NullPointer(&'static str),
|
||||
UnexpectedBinaryField {
|
||||
field_name: &'static str,
|
||||
observed: u32,
|
||||
expected: Vec<u32>,
|
||||
},
|
||||
}
|
||||
|
||||
impl ParseError {
|
||||
pub(crate) fn check_fields<const N: usize, T: Into<u32> + Eq + Copy>(
|
||||
name: &'static str,
|
||||
observed: T,
|
||||
expected: [T; N],
|
||||
) -> Result<(), Self> {
|
||||
if expected.contains(&observed) {
|
||||
Ok(())
|
||||
} else {
|
||||
let observed = observed.into();
|
||||
let expected = expected.into_iter().map(Into::into).collect();
|
||||
Err(ParseError::UnexpectedBinaryField {
|
||||
field_name: name,
|
||||
expected,
|
||||
observed,
|
||||
})
|
||||
}
|
||||
}
|
||||
pub const HEADER_VERSION_CURRENT: c_ushort = 0x101;
|
||||
}
|
||||
|
@ -10,3 +10,5 @@ uuid = "1.16"
|
||||
paste = "1.0"
|
||||
bit-vec = "0.8.0"
|
||||
cglue = "0.3.5"
|
||||
lz4-sys = "1.9"
|
||||
zstd-safe = { version = "7.2.4", features = ["std"] }
|
||||
|
259
dark_api/src/fatbin.rs
Normal file
259
dark_api/src/fatbin.rs
Normal file
@ -0,0 +1,259 @@
|
||||
// This file contains a higher-level interface for parsing fatbins
|
||||
|
||||
use std::ptr;
|
||||
|
||||
use cuda_types::dark_api::*;
|
||||
|
||||
/// Error produced while validating the raw fatbin structures.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// A pointer that should reference the named structure was null.
    NullPointer(&'static str),
    /// A field held a value outside the set of accepted values.
    UnexpectedBinaryField {
        /// Symbolic name of the field that failed validation.
        field_name: &'static str,
        /// Value actually read, widened to `u32`.
        observed: u32,
        /// Every value that would have been accepted, widened to `u32`.
        expected: Vec<u32>,
    },
}

impl ParseError {
    /// Checks that `observed` is one of the `expected` values.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::UnexpectedBinaryField`] tagged with `name` when
    /// `observed` is not contained in `expected`.
    pub(crate) fn check_fields<const N: usize, T: Into<u32> + Eq + Copy>(
        name: &'static str,
        observed: T,
        expected: [T; N],
    ) -> Result<(), Self> {
        if expected.contains(&observed) {
            return Ok(());
        }
        Err(ParseError::UnexpectedBinaryField {
            field_name: name,
            observed: observed.into(),
            expected: expected.into_iter().map(Into::into).collect(),
        })
    }
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ParseError::NullPointer(what) => {
                write!(f, "unexpected null pointer while parsing {what}")
            }
            ParseError::UnexpectedBinaryField {
                field_name,
                observed,
                expected,
            } => write!(
                f,
                "unexpected value {observed:#x} in field {field_name}, expected one of {expected:x?}"
            ),
        }
    }
}

impl std::error::Error for ParseError {}
|
||||
|
||||
pub enum FatbinError {
|
||||
ParseFailure(ParseError),
|
||||
Lz4DecompressionFailure,
|
||||
ZstdDecompressionFailure(usize),
|
||||
}
|
||||
|
||||
pub fn parse_fatbinc_wrapper<T: Sized>(ptr: &*const T) -> Result<&FatbincWrapper, ParseError> {
|
||||
unsafe { ptr.cast::<FatbincWrapper>().as_ref() }
|
||||
.ok_or(ParseError::NullPointer("FatbincWrapper"))
|
||||
.and_then(|ptr| {
|
||||
ParseError::check_fields("FATBINC_MAGIC", ptr.magic, [FatbincWrapper::MAGIC])?;
|
||||
ParseError::check_fields(
|
||||
"FATBINC_VERSION",
|
||||
ptr.version,
|
||||
[FatbincWrapper::VERSION_V1, FatbincWrapper::VERSION_V2],
|
||||
)?;
|
||||
Ok(ptr)
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_fatbin_header<T: Sized>(ptr: &*const T) -> Result<&FatbinHeader, ParseError> {
|
||||
unsafe { ptr.cast::<FatbinHeader>().as_ref() }
|
||||
.ok_or(ParseError::NullPointer("FatbinHeader"))
|
||||
.and_then(|ptr| {
|
||||
ParseError::check_fields("FATBIN_MAGIC", ptr.magic, [FatbinHeader::MAGIC])?;
|
||||
ParseError::check_fields("FATBIN_VERSION", ptr.version, [FatbinHeader::VERSION])?;
|
||||
Ok(ptr)
|
||||
})
|
||||
}
|
||||
|
||||
pub struct Fatbin<'a> {
|
||||
pub wrapper: &'a FatbincWrapper,
|
||||
}
|
||||
|
||||
impl<'a> Fatbin<'a> {
|
||||
pub fn new<T>(ptr: &'a *const T) -> Result<Self, FatbinError> {
|
||||
let wrapper: &FatbincWrapper =
|
||||
parse_fatbinc_wrapper(ptr).map_err(|e| FatbinError::ParseFailure(e))?;
|
||||
|
||||
Ok(Fatbin { wrapper })
|
||||
}
|
||||
|
||||
pub fn get_submodules(&self) -> Result<FatbinIter<'a>, FatbinError> {
|
||||
match self.wrapper.version {
|
||||
FatbincWrapper::VERSION_V2 =>
|
||||
Ok(FatbinIter::V2(FatbinSubmoduleIterator {
|
||||
fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void,
|
||||
_phantom: std::marker::PhantomData,
|
||||
})),
|
||||
FatbincWrapper::VERSION_V1 => {
|
||||
let header = parse_fatbin_header(&self.wrapper.data)
|
||||
.map_err(FatbinError::ParseFailure)?;
|
||||
Ok(FatbinIter::V1(Some(FatbinSubmodule::new(header))))
|
||||
}
|
||||
version => Err(FatbinError::ParseFailure(ParseError::UnexpectedBinaryField{
|
||||
field_name: "FATBINC_VERSION",
|
||||
observed: version,
|
||||
expected: [FatbincWrapper::VERSION_V1, FatbincWrapper::VERSION_V2].into(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FatbinSubmodule<'a> {
|
||||
pub header: &'a FatbinHeader, // TODO: maybe make private
|
||||
}
|
||||
|
||||
impl<'a> FatbinSubmodule<'a> {
|
||||
pub fn new(header: &'a FatbinHeader) -> Self {
|
||||
FatbinSubmodule { header }
|
||||
}
|
||||
|
||||
pub fn get_files(&self) -> FatbinFileIterator {
|
||||
unsafe { FatbinFileIterator::new(self.header) }
|
||||
}
|
||||
}
|
||||
|
||||
pub enum FatbinIter<'a> {
|
||||
V1(Option<FatbinSubmodule<'a>>),
|
||||
V2(FatbinSubmoduleIterator<'a>),
|
||||
}
|
||||
|
||||
impl<'a> FatbinIter<'a> {
|
||||
pub fn next(&mut self) -> Result<Option<FatbinSubmodule<'a>>, ParseError> {
|
||||
match self {
|
||||
FatbinIter::V1(opt) => Ok(opt.take()),
|
||||
FatbinIter::V2(iter) => unsafe { iter.next() },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FatbinSubmoduleIterator<'a> {
|
||||
fatbins: *const *const std::ffi::c_void,
|
||||
_phantom: std::marker::PhantomData<&'a FatbinHeader>,
|
||||
}
|
||||
|
||||
impl<'a> FatbinSubmoduleIterator<'a> {
|
||||
pub unsafe fn next(&mut self) -> Result<Option<FatbinSubmodule<'a>>, ParseError> {
|
||||
if *self.fatbins != ptr::null() {
|
||||
let header = *self.fatbins as *const FatbinHeader;
|
||||
self.fatbins = self.fatbins.add(1);
|
||||
Ok(Some(FatbinSubmodule::new(header.as_ref().ok_or(
|
||||
ParseError::NullPointer("FatbinSubmoduleIterator"),
|
||||
)?)))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FatbinFile<'a> {
|
||||
pub header: &'a FatbinFileHeader,
|
||||
}
|
||||
|
||||
impl<'a> FatbinFile<'a> {
|
||||
pub fn new(header: &'a FatbinFileHeader) -> Self {
|
||||
Self { header }
|
||||
}
|
||||
|
||||
pub unsafe fn get_payload(&'a self) -> &'a [u8] {
|
||||
let start = std::ptr::from_ref(self.header)
|
||||
.cast::<u8>()
|
||||
.add(self.header.header_size as usize);
|
||||
std::slice::from_raw_parts(start, self.header.payload_size as usize)
|
||||
}
|
||||
|
||||
pub unsafe fn decompress(&'a self) -> Result<Vec<u8>, FatbinError> {
|
||||
let mut payload = if self
|
||||
.header
|
||||
.flags
|
||||
.contains(FatbinFileHeaderFlags::CompressedLz4)
|
||||
{
|
||||
unsafe { decompress_lz4(self) }?
|
||||
} else if self
|
||||
.header
|
||||
.flags
|
||||
.contains(FatbinFileHeaderFlags::CompressedZstd)
|
||||
{
|
||||
unsafe { decompress_zstd(self) }?
|
||||
} else {
|
||||
unsafe { self.get_payload().to_vec() }
|
||||
};
|
||||
|
||||
|
||||
while payload.last() == Some(&0) {
|
||||
// remove trailing zeros
|
||||
payload.pop();
|
||||
}
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FatbinFileIterator<'a> {
|
||||
file_buffer: &'a [u8],
|
||||
}
|
||||
|
||||
impl<'a> FatbinFileIterator<'a> {
|
||||
pub unsafe fn new(header: &'a FatbinHeader) -> Self {
|
||||
let start = std::ptr::from_ref(header)
|
||||
.cast::<u8>()
|
||||
.add(header.header_size as usize);
|
||||
let file_buffer = std::slice::from_raw_parts(start, header.files_size as usize);
|
||||
|
||||
Self { file_buffer }
|
||||
}
|
||||
|
||||
pub unsafe fn next(&mut self) -> Result<Option<FatbinFile>, ParseError> {
|
||||
if self.file_buffer.len() < std::mem::size_of::<FatbinFileHeader>() {
|
||||
return Ok(None);
|
||||
}
|
||||
let this = &*self.file_buffer.as_ptr().cast::<FatbinFileHeader>();
|
||||
let next_element = self
|
||||
.file_buffer
|
||||
.split_at_checked(this.header_size as usize + this.padded_payload_size as usize)
|
||||
.map(|(_, next)| next);
|
||||
self.file_buffer = next_element.unwrap_or(&[]);
|
||||
ParseError::check_fields(
|
||||
"FATBIN_FILE_HEADER_VERSION_CURRENT",
|
||||
this.version,
|
||||
[FatbinFileHeader::HEADER_VERSION_CURRENT],
|
||||
)?;
|
||||
Ok(Some(FatbinFile::new(this)))
|
||||
}
|
||||
}
|
||||
|
||||
const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024;
|
||||
|
||||
pub unsafe fn decompress_lz4(file: &FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
||||
let decompressed_size = usize::max(1024, file.header.uncompressed_payload as usize);
|
||||
let mut decompressed_vec = vec![0u8; decompressed_size];
|
||||
loop {
|
||||
match lz4_sys::LZ4_decompress_safe(
|
||||
file.get_payload().as_ptr() as *const _,
|
||||
decompressed_vec.as_mut_ptr() as *mut _,
|
||||
file.header.payload_size as _,
|
||||
decompressed_vec.len() as _,
|
||||
) {
|
||||
error if error < 0 => {
|
||||
let new_size = decompressed_vec.len() * 2;
|
||||
if new_size > MAX_MODULE_DECOMPRESSION_BOUND {
|
||||
return Err(FatbinError::Lz4DecompressionFailure);
|
||||
}
|
||||
decompressed_vec.resize(decompressed_vec.len() * 2, 0);
|
||||
}
|
||||
real_decompressed_size => {
|
||||
decompressed_vec.truncate(real_decompressed_size as usize);
|
||||
return Ok(decompressed_vec);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn decompress_zstd(file: &FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
||||
let mut result = Vec::with_capacity(file.header.uncompressed_payload as usize);
|
||||
let payload = file.get_payload();
|
||||
match zstd_safe::decompress(&mut result, payload) {
|
||||
Ok(actual_size) => {
|
||||
result.truncate(actual_size);
|
||||
Ok(result)
|
||||
}
|
||||
Err(err) => Err(FatbinError::ZstdDecompressionFailure(err)),
|
||||
}
|
||||
}
|
@ -2,6 +2,8 @@ use std::ffi::c_void;
|
||||
|
||||
use cuda_types::cuda::CUuuid;
|
||||
|
||||
pub mod fatbin;
|
||||
|
||||
macro_rules! dark_api_init {
|
||||
(SIZE_OF, $table_len:literal, $type_:ty) => {
|
||||
(std::mem::size_of::<usize>() * $table_len) as *const std::ffi::c_void
|
||||
|
@ -18,7 +18,7 @@ rustc-hash = "2.0.0"
|
||||
strum = "0.26"
|
||||
strum_macros = "0.26"
|
||||
petgraph = "0.7.1"
|
||||
microlp = "0.2.10"
|
||||
microlp = "0.2.11"
|
||||
int-enum = "1.1"
|
||||
unwrap_or = "1.0.1"
|
||||
|
||||
@ -31,3 +31,6 @@ tempfile = "3"
|
||||
paste = "1.0"
|
||||
pretty_assertions = "1.4.1"
|
||||
libloading = "0.8"
|
||||
|
||||
[features]
|
||||
ci_build = []
|
||||
|
64
ptx/src/test/ll/assertfail.ll
Normal file
64
ptx/src/test/ll/assertfail.ll
Normal file
@ -0,0 +1,64 @@
|
||||
declare void @__zluda_ptx_impl___assertfail(i64, i64, i32, i64, i64) #0
|
||||
|
||||
define amdgpu_kernel void @assertfail(ptr addrspace(4) byref(i64) %"86", ptr addrspace(4) byref(i64) %"87") #1 {
|
||||
%"88" = alloca i64, align 8, addrspace(5)
|
||||
%"89" = alloca i64, align 8, addrspace(5)
|
||||
%"90" = alloca i64, align 8, addrspace(5)
|
||||
%"91" = alloca i64, align 8, addrspace(5)
|
||||
%"94" = alloca i32, align 4, addrspace(5)
|
||||
%"96" = alloca i64, align 8, addrspace(5)
|
||||
%"99" = alloca i64, align 8, addrspace(5)
|
||||
%"102" = alloca i32, align 4, addrspace(5)
|
||||
%"105" = alloca i64, align 8, addrspace(5)
|
||||
%"108" = alloca i64, align 8, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"84"
|
||||
|
||||
"84": ; preds = %1
|
||||
%"92" = load i64, ptr addrspace(4) %"86", align 4
|
||||
store i64 %"92", ptr addrspace(5) %"88", align 4
|
||||
%"93" = load i64, ptr addrspace(4) %"87", align 4
|
||||
store i64 %"93", ptr addrspace(5) %"89", align 4
|
||||
store i32 0, ptr addrspace(5) %"94", align 4
|
||||
%"97" = getelementptr inbounds i8, ptr addrspace(5) %"96", i64 0
|
||||
%"98" = load i64, ptr addrspace(5) %"88", align 4
|
||||
store i64 %"98", ptr addrspace(5) %"97", align 4
|
||||
%"100" = getelementptr inbounds i8, ptr addrspace(5) %"99", i64 0
|
||||
%"101" = load i64, ptr addrspace(5) %"88", align 4
|
||||
store i64 %"101", ptr addrspace(5) %"100", align 4
|
||||
%"103" = getelementptr inbounds i8, ptr addrspace(5) %"102", i64 0
|
||||
%"104" = load i32, ptr addrspace(5) %"94", align 4
|
||||
store i32 %"104", ptr addrspace(5) %"103", align 4
|
||||
%"106" = getelementptr inbounds i8, ptr addrspace(5) %"105", i64 0
|
||||
%"107" = load i64, ptr addrspace(5) %"88", align 4
|
||||
store i64 %"107", ptr addrspace(5) %"106", align 4
|
||||
%"109" = getelementptr inbounds i8, ptr addrspace(5) %"108", i64 0
|
||||
%"110" = load i64, ptr addrspace(5) %"88", align 4
|
||||
store i64 %"110", ptr addrspace(5) %"109", align 4
|
||||
%"74" = load i64, ptr addrspace(5) %"96", align 4
|
||||
%"75" = load i64, ptr addrspace(5) %"99", align 4
|
||||
%"76" = load i32, ptr addrspace(5) %"102", align 4
|
||||
%"77" = load i64, ptr addrspace(5) %"105", align 4
|
||||
%"78" = load i64, ptr addrspace(5) %"108", align 4
|
||||
call void @__zluda_ptx_impl___assertfail(i64 %"74", i64 %"75", i32 %"76", i64 %"77", i64 %"78")
|
||||
br label %"85"
|
||||
|
||||
"85": ; preds = %"84"
|
||||
%"112" = load i64, ptr addrspace(5) %"88", align 4
|
||||
%"122" = inttoptr i64 %"112" to ptr
|
||||
%"111" = load i64, ptr %"122", align 4
|
||||
store i64 %"111", ptr addrspace(5) %"90", align 4
|
||||
%"114" = load i64, ptr addrspace(5) %"90", align 4
|
||||
%"113" = add i64 %"114", 1
|
||||
store i64 %"113", ptr addrspace(5) %"91", align 4
|
||||
%"115" = load i64, ptr addrspace(5) %"89", align 4
|
||||
%"116" = load i64, ptr addrspace(5) %"91", align 4
|
||||
%"123" = inttoptr i64 %"115" to ptr
|
||||
store i64 %"116", ptr %"123", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
43
ptx/src/test/ll/extern_func.ll
Normal file
43
ptx/src/test/ll/extern_func.ll
Normal file
@ -0,0 +1,43 @@
|
||||
declare [16 x i8] @foobar(i64) #0
|
||||
|
||||
define amdgpu_kernel void @extern_func(ptr addrspace(4) byref(i64) %"44", ptr addrspace(4) byref(i64) %"45") #1 {
|
||||
%"46" = alloca i64, align 8, addrspace(5)
|
||||
%"47" = alloca i64, align 8, addrspace(5)
|
||||
%"48" = alloca i64, align 8, addrspace(5)
|
||||
%"49" = alloca i64, align 8, addrspace(5)
|
||||
%"54" = alloca i64, align 8, addrspace(5)
|
||||
%"57" = alloca [16 x i8], align 16, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"41"
|
||||
|
||||
"41": ; preds = %1
|
||||
%"50" = load i64, ptr addrspace(4) %"44", align 4
|
||||
store i64 %"50", ptr addrspace(5) %"46", align 4
|
||||
%"51" = load i64, ptr addrspace(4) %"45", align 4
|
||||
store i64 %"51", ptr addrspace(5) %"47", align 4
|
||||
%"53" = load i64, ptr addrspace(5) %"46", align 4
|
||||
%"61" = inttoptr i64 %"53" to ptr addrspace(1)
|
||||
%"52" = load i64, ptr addrspace(1) %"61", align 4
|
||||
store i64 %"52", ptr addrspace(5) %"48", align 4
|
||||
%"55" = getelementptr inbounds i8, ptr addrspace(5) %"54", i64 0
|
||||
%"56" = load i64, ptr addrspace(5) %"48", align 4
|
||||
store i64 %"56", ptr addrspace(5) %"55", align 4
|
||||
%"39" = load i64, ptr addrspace(5) %"54", align 4
|
||||
%"40" = call [16 x i8] @foobar(i64 %"39")
|
||||
br label %"42"
|
||||
|
||||
"42": ; preds = %"41"
|
||||
store [16 x i8] %"40", ptr addrspace(5) %"57", align 1
|
||||
%"58" = load i64, ptr addrspace(5) %"57", align 4
|
||||
store i64 %"58", ptr addrspace(5) %"49", align 4
|
||||
%"59" = load i64, ptr addrspace(5) %"47", align 4
|
||||
%"60" = load i64, ptr addrspace(5) %"49", align 4
|
||||
%"64" = inttoptr i64 %"59" to ptr
|
||||
store i64 %"60", ptr %"64", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
43
ptx/src/test/ll/lanemask_lt.ll
Normal file
43
ptx/src/test/ll/lanemask_lt.ll
Normal file
@ -0,0 +1,43 @@
|
||||
declare i32 @__zluda_ptx_impl_sreg_lanemask_lt() #0
|
||||
|
||||
define amdgpu_kernel void @lanemask_lt(ptr addrspace(4) byref(i64) %"36", ptr addrspace(4) byref(i64) %"37") #1 {
|
||||
%"38" = alloca i64, align 8, addrspace(5)
|
||||
%"39" = alloca i64, align 8, addrspace(5)
|
||||
%"40" = alloca i32, align 4, addrspace(5)
|
||||
%"41" = alloca i32, align 4, addrspace(5)
|
||||
%"42" = alloca i32, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"33"
|
||||
|
||||
"33": ; preds = %1
|
||||
%"43" = load i64, ptr addrspace(4) %"36", align 4
|
||||
store i64 %"43", ptr addrspace(5) %"38", align 4
|
||||
%"44" = load i64, ptr addrspace(4) %"37", align 4
|
||||
store i64 %"44", ptr addrspace(5) %"39", align 4
|
||||
%"46" = load i64, ptr addrspace(5) %"38", align 4
|
||||
%"56" = inttoptr i64 %"46" to ptr
|
||||
%"55" = load i32, ptr %"56", align 4
|
||||
store i32 %"55", ptr addrspace(5) %"40", align 4
|
||||
%"48" = load i32, ptr addrspace(5) %"40", align 4
|
||||
%"57" = add i32 %"48", 1
|
||||
store i32 %"57", ptr addrspace(5) %"41", align 4
|
||||
%"31" = call i32 @__zluda_ptx_impl_sreg_lanemask_lt()
|
||||
br label %"34"
|
||||
|
||||
"34": ; preds = %"33"
|
||||
store i32 %"31", ptr addrspace(5) %"42", align 4
|
||||
%"51" = load i32, ptr addrspace(5) %"41", align 4
|
||||
%"52" = load i32, ptr addrspace(5) %"42", align 4
|
||||
%"60" = add i32 %"51", %"52"
|
||||
store i32 %"60", ptr addrspace(5) %"41", align 4
|
||||
%"53" = load i64, ptr addrspace(5) %"39", align 4
|
||||
%"54" = load i32, ptr addrspace(5) %"41", align 4
|
||||
%"63" = inttoptr i64 %"53" to ptr
|
||||
store i32 %"54", ptr %"63", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
@ -10,32 +10,65 @@ use std::fmt::{self, Debug, Display, Formatter};
|
||||
use std::fs::{self, File};
|
||||
use std::io::Write;
|
||||
use std::mem;
|
||||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::ptr;
|
||||
use std::str;
|
||||
|
||||
// Non-CI builds read the test file from disk when the test runs.
#[cfg(not(feature = "ci_build"))]
macro_rules! read_test_file {
    ($file:expr) => {{
        // `file!()` is relative to the workspace root and already contains the
        // crate directory (ptx), while CARGO_MANIFEST_DIR points at the crate
        // directory. So: step up to the workspace root, append this source
        // file's path, drop its file name, then append the requested file.
        let mut test_file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        test_file_path.pop();
        test_file_path.push(file!());
        test_file_path.pop();
        test_file_path.push($file);
        std::fs::read_to_string(test_file_path).unwrap()
    }};
}
|
||||
|
||||
// CI builds embed the test file into the binary at compile time.
#[cfg(feature = "ci_build")]
macro_rules! read_test_file {
    ($file:expr) => {
        String::from(include_str!($file))
    };
}
|
||||
|
||||
macro_rules! test_ptx {
|
||||
($fn_name:ident, $input:expr, $output:expr) => {
|
||||
paste::item! {
|
||||
#[test]
|
||||
fn [<$fn_name _hip>]() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let ptx = include_str!(concat!(stringify!($fn_name), ".ptx"));
|
||||
let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx"));
|
||||
let input = $input;
|
||||
let mut output = $output;
|
||||
test_hip_assert(stringify!($fn_name), ptx, &input, &mut output)
|
||||
test_hip_assert(stringify!($fn_name), &ptx, &input, &mut output)
|
||||
}
|
||||
}
|
||||
|
||||
paste::item! {
|
||||
#[test]
|
||||
fn [<$fn_name _cuda>]() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let ptx = include_str!(concat!(stringify!($fn_name), ".ptx"));
|
||||
let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx"));
|
||||
let input = $input;
|
||||
let mut output = $output;
|
||||
test_cuda_assert(stringify!($fn_name), ptx, &input, &mut output)
|
||||
test_cuda_assert(stringify!($fn_name), &ptx, &input, &mut output)
|
||||
}
|
||||
}
|
||||
|
||||
paste::item! {
|
||||
#[test]
|
||||
fn [<$fn_name _llvm>]() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx"));
|
||||
let ll = read_test_file!(concat!("../ll/", stringify!($fn_name), ".ll"));
|
||||
test_llvm_assert(stringify!($fn_name), &ptx, ll.trim())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
($fn_name:ident) => {
|
||||
paste::item! {
|
||||
#[test]
|
||||
fn [<$fn_name _llvm>]() -> Result<(), Box<dyn std::error::Error>> {
|
||||
@ -45,8 +78,6 @@ macro_rules! test_ptx {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
($fn_name:ident) => {};
|
||||
}
|
||||
|
||||
test_ptx!(ld_st, [1u64], [1u64]);
|
||||
@ -242,7 +273,8 @@ test_ptx!(
|
||||
);
|
||||
|
||||
test_ptx!(assertfail);
|
||||
test_ptx!(func_ptr);
|
||||
// TODO: not yet supported
|
||||
//test_ptx!(func_ptr);
|
||||
test_ptx!(lanemask_lt);
|
||||
test_ptx!(extern_func);
|
||||
|
||||
@ -265,15 +297,14 @@ impl<T: Debug> Debug for DisplayError<T> {
|
||||
impl<T: Debug> error::Error for DisplayError<T> {}
|
||||
|
||||
fn test_hip_assert<
|
||||
'a,
|
||||
Input: From<u8> + Debug + Copy + PartialEq,
|
||||
Output: From<u8> + Debug + Copy + PartialEq + Default,
|
||||
>(
|
||||
name: &str,
|
||||
ptx_text: &'a str,
|
||||
ptx_text: &str,
|
||||
input: &[Input],
|
||||
output: &mut [Output],
|
||||
) -> Result<(), Box<dyn error::Error + 'a>> {
|
||||
) -> Result<(), Box<dyn error::Error>> {
|
||||
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
|
||||
let llvm_ir = pass::to_llvm_module(ast).unwrap();
|
||||
let name = CString::new(name)?;
|
||||
@ -283,11 +314,11 @@ fn test_hip_assert<
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn test_llvm_assert<'a>(
|
||||
fn test_llvm_assert(
|
||||
name: &str,
|
||||
ptx_text: &'a str,
|
||||
ptx_text: &str,
|
||||
expected_ll: &str,
|
||||
) -> Result<(), Box<dyn error::Error + 'a>> {
|
||||
) -> Result<(), Box<dyn error::Error>> {
|
||||
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
|
||||
let llvm_ir = pass::to_llvm_module(ast).unwrap();
|
||||
let actual_ll = llvm_ir.llvm_ir.print_module_to_string();
|
||||
@ -301,22 +332,21 @@ fn test_llvm_assert<'a>(
|
||||
let mut output_file = File::create(output_file).unwrap();
|
||||
output_file.write_all(actual_ll.as_bytes()).unwrap();
|
||||
}
|
||||
let comparison = pretty_assertions::StrComparison::new(expected_ll, actual_ll);
|
||||
let comparison = pretty_assertions::StrComparison::new(&expected_ll, &actual_ll);
|
||||
panic!("assertion failed: `(left == right)`\n\n{}", comparison);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn test_cuda_assert<
|
||||
'a,
|
||||
Input: From<u8> + Debug + Copy + PartialEq,
|
||||
Output: From<u8> + Debug + Copy + PartialEq + Default,
|
||||
>(
|
||||
name: &str,
|
||||
ptx_text: &'a str,
|
||||
ptx_text: &str,
|
||||
input: &[Input],
|
||||
output: &mut [Output],
|
||||
) -> Result<(), Box<dyn error::Error + 'a>> {
|
||||
) -> Result<(), Box<dyn error::Error>> {
|
||||
let name = CString::new(name)?;
|
||||
let result = run_cuda(name.as_c_str(), ptx_text, input, output);
|
||||
assert_eq!(result.as_slice(), output);
|
||||
|
@ -1492,6 +1492,46 @@ pub struct TokenError(std::ops::Range<usize>);
|
||||
|
||||
impl std::error::Error for TokenError {}
|
||||
|
||||
fn first_optional<
|
||||
'a,
|
||||
'input,
|
||||
Input: Stream,
|
||||
OptionalOutput,
|
||||
RequiredOutput,
|
||||
Error,
|
||||
ParseOptional,
|
||||
ParseRequired,
|
||||
>(
|
||||
mut optional: ParseOptional,
|
||||
mut required: ParseRequired,
|
||||
) -> impl Parser<Input, (Option<OptionalOutput>, RequiredOutput), Error>
|
||||
where
|
||||
ParseOptional: Parser<Input, OptionalOutput, Error>,
|
||||
ParseRequired: Parser<Input, RequiredOutput, Error>,
|
||||
Error: ParserError<Input>,
|
||||
{
|
||||
move |input: &mut Input| -> Result<(Option<OptionalOutput>, RequiredOutput), ErrMode<Error>> {
|
||||
let start = input.checkpoint();
|
||||
|
||||
let parsed_optional = match optional.parse_next(input) {
|
||||
Ok(v) => Some(v),
|
||||
Err(ErrMode::Backtrack(_)) => {
|
||||
input.reset(&start);
|
||||
None
|
||||
},
|
||||
Err(e) => return Err(e)
|
||||
};
|
||||
|
||||
match required.parse_next(input) {
|
||||
Ok(v) => return Ok((parsed_optional, v)),
|
||||
Err(ErrMode::Backtrack(_)) => input.reset(&start),
|
||||
Err(e) => return Err(e)
|
||||
};
|
||||
|
||||
Ok((None, required.parse_next(input)?))
|
||||
}
|
||||
}
|
||||
|
||||
// This macro is responsible for generating parser code for instruction parser.
|
||||
// Instruction parsing is by far the most complex part of parsing PTX code:
|
||||
// * There are tens of instruction kinds, each with slightly different parsing rules
|
||||
@ -3413,6 +3453,7 @@ derive_parser!(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::first_optional;
|
||||
use crate::parse_module_checked;
|
||||
use crate::PtxError;
|
||||
|
||||
@ -3423,6 +3464,55 @@ mod tests {
|
||||
use logos::Span;
|
||||
use winnow::prelude::*;
|
||||
|
||||
/// Optional `'A'` and required `'B'` are both present in the input.
#[test]
fn first_optional_present() {
    let mut parser = first_optional::<_, _, _, (), _, _>('A', 'B');
    let parsed = parser.parse("AB");
    assert_eq!(parsed, Ok((Some('A'), 'B')));
}
|
||||
|
||||
/// Only the required `'B'` is present; the optional part comes back as `None`.
#[test]
fn first_optional_absent() {
    let mut parser = first_optional::<_, _, _, (), _, _>('A', 'B');
    let parsed = parser.parse("B");
    assert_eq!(parsed, Ok((None, 'B')));
}
|
||||
|
||||
/// Optional and required parsers accept the same token and only one token is available:
/// the single `'A'` must end up as the required output.
#[test]
fn first_optional_repeated_absent() {
    let mut parser = first_optional::<_, _, _, (), _, _>('A', 'A');
    let parsed = parser.parse("A");
    assert_eq!(parsed, Ok((None, 'A')));
}
|
||||
|
||||
/// Optional and required parsers accept the same token and two tokens are available:
/// the first goes to the optional output, the second to the required output.
#[test]
fn first_optional_repeated_present() {
    let mut parser = first_optional::<_, _, _, (), _, _>('A', 'A');
    let parsed = parser.parse("AA");
    assert_eq!(parsed, Ok((Some('A'), 'A')));
}
|
||||
|
||||
/// `first_optional` used as the tail of a tuple parser, with only one token left for it.
#[test]
fn first_optional_sequence_absent() {
    let mut parser = ('A', first_optional::<_, _, _, (), _, _>('A', 'A'));
    let parsed = parser.parse("AA");
    assert_eq!(parsed, Ok(('A', (None, 'A'))));
}
|
||||
|
||||
/// `first_optional` used as the tail of a tuple parser, with two tokens left for it.
#[test]
fn first_optional_sequence_present() {
    let mut parser = ('A', first_optional::<_, _, _, (), _, _>('A', 'A'));
    let parsed = parser.parse("AAA");
    assert_eq!(parsed, Ok(('A', (Some('A'), 'A'))));
}
|
||||
|
||||
/// Input that matches neither the optional nor the required parser is an error.
#[test]
fn first_optional_no_match() {
    let mut parser = first_optional::<_, _, _, (), _, _>('A', 'B');
    let parsed = parser.parse("C");
    assert!(parsed.is_err());
}
|
||||
|
||||
#[test]
|
||||
fn sm_11() {
|
||||
let text = ".target sm_11";
|
||||
|
@ -757,12 +757,13 @@ fn emit_definition_parser(
|
||||
DotModifierRef::Direct { optional: true, .. }
|
||||
| DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(),
|
||||
});
|
||||
let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| {
|
||||
let (arguments_pattern, arguments_parser) = definition.arguments.0.iter().enumerate().rfold((quote! { () }, quote! { empty }), |(emitted_pattern, emitted_parser), (idx, arg)| {
|
||||
let comma = if idx == 0 || arg.pre_pipe {
|
||||
quote! { empty }
|
||||
} else {
|
||||
quote! { any.verify(|(t, _)| *t == #token_type::Comma).void() }
|
||||
};
|
||||
|
||||
let pre_bracket = if arg.pre_bracket {
|
||||
quote! {
|
||||
any.verify(|(t, _)| *t == #token_type::LBracket).void()
|
||||
@ -833,16 +834,20 @@ fn emit_definition_parser(
|
||||
#pattern.map(|(_, _, _, _, name, _, _)| name)
|
||||
}
|
||||
};
|
||||
if arg.optional {
|
||||
quote! {
|
||||
let #arg_name = opt(#inner_parser).parse_next(stream)?;
|
||||
}
|
||||
|
||||
let parser = if arg.optional {
|
||||
quote! { first_optional(#inner_parser, #emitted_parser) }
|
||||
} else {
|
||||
quote! {
|
||||
let #arg_name = #inner_parser.parse_next(stream)?;
|
||||
}
|
||||
}
|
||||
quote! { (#inner_parser, #emitted_parser) }
|
||||
};
|
||||
|
||||
let pattern = quote! { ( #arg_name, #emitted_pattern ) };
|
||||
|
||||
(pattern, parser)
|
||||
});
|
||||
|
||||
let arguments_parse = quote! { let #arguments_pattern = ( #arguments_parser ).parse_next(stream)?; };
|
||||
|
||||
let fn_args = definition.function_arguments();
|
||||
let fn_name = format_ident!("{}_{}", opcode, fn_idx);
|
||||
let fn_call = quote! {
|
||||
@ -863,7 +868,7 @@ fn emit_definition_parser(
|
||||
}
|
||||
}
|
||||
#(#unordered_parse_validations)*
|
||||
#(#arguments_parse)*
|
||||
#arguments_parse
|
||||
#fn_call
|
||||
}
|
||||
}
|
||||
|
@ -22,14 +22,23 @@ impl Parse for ParseDefinitions {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let token_type = input.parse::<ItemEnum>()?;
|
||||
let mut additional_enums = FxHashMap::default();
|
||||
while input.peek(Token![#]) {
|
||||
let enum_ = input.parse::<ItemEnum>()?;
|
||||
additional_enums.insert(enum_.ident.clone(), enum_);
|
||||
}
|
||||
let mut definitions = Vec::new();
|
||||
while !input.is_empty() {
|
||||
definitions.push(input.parse::<OpcodeDefinition>()?);
|
||||
loop {
|
||||
if input.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
let lookahead = input.lookahead1();
|
||||
if lookahead.peek(Token![#]) {
|
||||
let enum_ = input.parse::<ItemEnum>()?;
|
||||
additional_enums.insert(enum_.ident.clone(), enum_);
|
||||
} else if lookahead.peek(Ident) {
|
||||
definitions.push(input.parse::<OpcodeDefinition>()?);
|
||||
} else {
|
||||
return Err(lookahead.error());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
token_type,
|
||||
additional_enums,
|
||||
|
@ -10,6 +10,10 @@ use std::{
|
||||
usize,
|
||||
};
|
||||
|
||||
#[cfg_attr(windows, path = "os_win.rs")]
|
||||
#[cfg_attr(not(windows), path = "os_unix.rs")]
|
||||
mod os;
|
||||
|
||||
pub(crate) struct GlobalState {
|
||||
pub devices: Vec<Device>,
|
||||
pub comgr: Comgr,
|
||||
@ -232,7 +236,31 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
||||
unix_seconds: u64,
|
||||
result: *mut [u64; 2],
|
||||
) -> cuda_types::cuda::CUresult {
|
||||
todo!()
|
||||
let current_process = std::process::id();
|
||||
let current_thread = os::current_thread();
|
||||
|
||||
let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast();
|
||||
let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast();
|
||||
let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1];
|
||||
|
||||
let devices = get_device_hash_info()?;
|
||||
let device_count = devices.len() as u32;
|
||||
let get_device = |dev| devices[dev as usize];
|
||||
|
||||
let hash = ::dark_api::integrity_check(
|
||||
version,
|
||||
unix_seconds,
|
||||
cuda_types::cuda::CUDA_VERSION,
|
||||
current_process,
|
||||
current_thread,
|
||||
integrity_check_table,
|
||||
cudart_table,
|
||||
fn_address,
|
||||
device_count,
|
||||
get_device,
|
||||
);
|
||||
*result = hash;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe extern "system" fn context_check(
|
||||
@ -244,10 +272,50 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
||||
}
|
||||
|
||||
unsafe extern "system" fn check_fn3() -> u32 {
|
||||
todo!()
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
|
||||
let mut device_count = 0;
|
||||
device::get_count(&mut device_count)?;
|
||||
|
||||
(0..device_count)
|
||||
.map(|dev| {
|
||||
let mut guid = CUuuid_st { bytes: [0; 16] };
|
||||
unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? };
|
||||
|
||||
let mut pci_domain = 0;
|
||||
device::get_attribute(
|
||||
&mut pci_domain,
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID,
|
||||
dev,
|
||||
)?;
|
||||
|
||||
let mut pci_bus = 0;
|
||||
device::get_attribute(
|
||||
&mut pci_bus,
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID,
|
||||
dev,
|
||||
)?;
|
||||
|
||||
let mut pci_device = 0;
|
||||
device::get_attribute(
|
||||
&mut pci_device,
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID,
|
||||
dev,
|
||||
)?;
|
||||
|
||||
Ok(::dark_api::DeviceHashinfo {
|
||||
guid,
|
||||
pci_domain,
|
||||
pci_bus,
|
||||
pci_device,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable =
|
||||
::dark_api::cuda::CudaDarkApiGlobalTable::new::<DarkApi>();
|
||||
|
||||
|
38
zluda/src/impl/library.rs
Normal file
38
zluda/src/impl/library.rs
Normal file
@ -0,0 +1,38 @@
|
||||
use super::module;
|
||||
|
||||
use super::ZludaObject;
|
||||
|
||||
use cuda_types::cuda::*;
|
||||
use hip_runtime_sys::*;
|
||||
|
||||
pub(crate) struct Library {
|
||||
base: hipModule_t,
|
||||
}
|
||||
|
||||
impl ZludaObject for Library {
|
||||
const COOKIE: usize = 0xb328a916cc234d7c;
|
||||
|
||||
type CudaHandle = CUlibrary;
|
||||
|
||||
fn drop_checked(&mut self) -> CUresult {
|
||||
// TODO: we will want to test that we handle `cuModuleUnload` on a module that came from a library correctly, without calling `hipModuleUnload` twice.
|
||||
unsafe { hipModuleUnload(self.base) }?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// This implementation simply loads the code as a HIP module for now. The various JIT and library options are ignored.
|
||||
pub(crate) fn load_data(
|
||||
library: &mut CUlibrary,
|
||||
code: *const ::core::ffi::c_void,
|
||||
_jit_options: &mut CUjit_option,
|
||||
_jit_options_values: &mut *mut ::core::ffi::c_void,
|
||||
_num_jit_options: ::core::ffi::c_uint,
|
||||
_library_options: &mut CUlibraryOption,
|
||||
_library_option_values: &mut *mut ::core::ffi::c_void,
|
||||
_num_library_options: ::core::ffi::c_uint,
|
||||
) -> CUresult {
|
||||
let hip_module = module::load_hip_module(code)?;
|
||||
*library = Library { base: hip_module }.wrap();
|
||||
Ok(())
|
||||
}
|
@ -10,6 +10,7 @@ pub(super) mod context;
|
||||
pub(super) mod device;
|
||||
pub(super) mod driver;
|
||||
pub(super) mod function;
|
||||
pub(super) mod library;
|
||||
pub(super) mod memory;
|
||||
pub(super) mod module;
|
||||
pub(super) mod pointer;
|
||||
@ -135,6 +136,9 @@ from_cuda_nop!(
|
||||
cuda_types::cuda::CUdevprop,
|
||||
CUdevice_attribute,
|
||||
CUdriverProcAddressQueryResult,
|
||||
CUjit_option,
|
||||
CUlibrary,
|
||||
CUlibraryOption,
|
||||
CUmoduleLoadingMode,
|
||||
CUuuid
|
||||
);
|
||||
@ -169,6 +173,15 @@ impl<'a> FromCuda<'a, *const ::core::ffi::c_char> for &CStr {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> FromCuda<'a, *const ::core::ffi::c_void> for &'a ::core::ffi::c_void {
|
||||
fn from_cuda(x: &'a *const ::core::ffi::c_void) -> Result<Self, CUerror> {
|
||||
match unsafe { x.as_ref() } {
|
||||
Some(x) => Ok(x),
|
||||
None => Err(CUerror::INVALID_VALUE),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait ZludaObject: Sized + Send + Sync {
|
||||
const COOKIE: usize;
|
||||
const LIVENESS_FAIL: CUerror = cuda_types::cuda::CUerror::INVALID_VALUE;
|
||||
|
@ -1,5 +1,9 @@
|
||||
use super::{driver, ZludaObject};
|
||||
use cuda_types::cuda::*;
|
||||
use cuda_types::{
|
||||
cuda::*,
|
||||
dark_api::{FatbinFileHeader, FatbincWrapper},
|
||||
};
|
||||
use dark_api::fatbin::Fatbin;
|
||||
use hip_runtime_sys::*;
|
||||
use std::{ffi::CStr, mem};
|
||||
|
||||
@ -18,12 +22,48 @@ impl ZludaObject for Module {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn load_data(module: &mut CUmodule, image: *const std::ffi::c_void) -> CUresult {
|
||||
fn get_ptx_from_wrapped_fatbin(image: *const ::core::ffi::c_void) -> Result<Vec<u8>, CUerror> {
|
||||
let fatbin = Fatbin::new(&image).map_err(|_| CUerror::UNKNOWN)?;
|
||||
let mut submodules = fatbin.get_submodules().map_err(|_| CUerror::UNKNOWN)?;
|
||||
|
||||
while let Some(current) = unsafe { submodules.next().map_err(|_| CUerror::UNKNOWN)? } {
|
||||
let mut files = current.get_files();
|
||||
while let Some(file) = unsafe { files.next().map_err(|_| CUerror::UNKNOWN)? } {
|
||||
if file.header.kind == FatbinFileHeader::HEADER_KIND_PTX {
|
||||
let decompressed = unsafe { file.decompress() }.map_err(|_| CUerror::UNKNOWN)?;
|
||||
return Ok(decompressed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(CUerror::NO_BINARY_FOR_GPU)
|
||||
}
|
||||
|
||||
/// get_ptx takes an `image` that can be either a fatbin or a NULL-terminated ptx, and returns a String containing a ptx extracted from `image`.
|
||||
fn get_ptx(image: *const ::core::ffi::c_void) -> Result<String, CUerror> {
|
||||
if image.is_null() {
|
||||
return Err(CUerror::INVALID_VALUE);
|
||||
}
|
||||
|
||||
let ptx = if unsafe { *(image as *const u32) } == FatbincWrapper::MAGIC {
|
||||
let ptx_bytes = get_ptx_from_wrapped_fatbin(image)?;
|
||||
str::from_utf8(&ptx_bytes)
|
||||
.map_err(|_| CUerror::UNKNOWN)?
|
||||
.to_owned()
|
||||
} else {
|
||||
unsafe { CStr::from_ptr(image.cast()) }
|
||||
.to_str()
|
||||
.map_err(|_| CUerror::INVALID_VALUE)?
|
||||
.to_owned()
|
||||
};
|
||||
|
||||
Ok(ptx)
|
||||
}
|
||||
|
||||
pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result<hipModule_t, CUerror> {
|
||||
let global_state = driver::global_state()?;
|
||||
let text = unsafe { CStr::from_ptr(image.cast()) }
|
||||
.to_str()
|
||||
.map_err(|_| CUerror::INVALID_VALUE)?;
|
||||
let ast = ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
|
||||
let text = get_ptx(image)?;
|
||||
let ast = ptx_parser::parse_module_checked(&text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
|
||||
let llvm_module = ptx::to_llvm_module(ast).map_err(|_| CUerror::UNKNOWN)?;
|
||||
let mut dev = 0;
|
||||
unsafe { hipCtxGetDevice(&mut dev) }?;
|
||||
@ -38,6 +78,11 @@ pub(crate) fn load_data(module: &mut CUmodule, image: *const std::ffi::c_void) -
|
||||
.map_err(|_| CUerror::UNKNOWN)?;
|
||||
let mut hip_module = unsafe { mem::zeroed() };
|
||||
unsafe { hipModuleLoadData(&mut hip_module, elf_module.as_ptr().cast()) }?;
|
||||
Ok(hip_module)
|
||||
}
|
||||
|
||||
pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUresult {
|
||||
let hip_module = load_hip_module(image)?;
|
||||
*module = Module { base: hip_module }.wrap();
|
||||
Ok(())
|
||||
}
|
||||
|
@ -11,3 +11,13 @@ pub unsafe fn heap_alloc(_heap: *mut c_void, _bytes: usize) -> *mut c_void {
|
||||
/// Unix counterpart of the Windows heap release routine.
///
/// Not implemented yet: calling it panics through `todo!()`.
pub unsafe fn heap_free(_heap: *mut c_void, _alloc: *mut c_void) {
    todo!()
}
|
||||
|
||||
// TODO: remove duplication with zluda_dump
|
||||
#[link(name = "pthread")]
|
||||
unsafe extern "C" {
|
||||
fn pthread_self() -> std::os::unix::thread::RawPthread;
|
||||
}
|
||||
|
||||
pub(crate) fn current_thread() -> u32 {
|
||||
(unsafe { pthread_self() }) as u32
|
||||
}
|
||||
|
@ -14,3 +14,13 @@ pub unsafe fn heap_alloc(heap: *mut c_void, bytes: usize) -> *mut c_void {
|
||||
pub unsafe fn heap_free(heap: *mut c_void, alloc: *mut c_void) {
|
||||
HeapFree(heap, 0, alloc);
|
||||
}
|
||||
|
||||
// TODO: remove duplication with zluda_dump
|
||||
#[link(name = "kernel32")]
|
||||
unsafe extern "system" {
|
||||
fn GetCurrentThreadId() -> u32;
|
||||
}
|
||||
|
||||
pub(crate) fn current_thread() -> u32 {
|
||||
unsafe { GetCurrentThreadId() }
|
||||
}
|
||||
|
@ -66,6 +66,7 @@ cuda_base::cuda_function_declarations!(
|
||||
cuGetProcAddress,
|
||||
cuGetProcAddress_v2,
|
||||
cuInit,
|
||||
cuLibraryLoadData,
|
||||
cuMemAlloc_v2,
|
||||
cuMemFree_v2,
|
||||
cuMemGetAddressRange_v2,
|
||||
@ -84,4 +85,4 @@ cuda_base::cuda_function_declarations!(
|
||||
implemented_in_function <= [
|
||||
cuLaunchKernel,
|
||||
]
|
||||
);
|
||||
);
|
||||
|
@ -14,7 +14,6 @@ ptx_parser = { path = "../ptx_parser" }
|
||||
zluda_dump_common = { path = "../zluda_dump_common" }
|
||||
format = { path = "../format" }
|
||||
dark_api = { path = "../dark_api" }
|
||||
lz4-sys = "1.9"
|
||||
regex = "1.4"
|
||||
dynasm = "1.2"
|
||||
dynasmrt = "1.2"
|
||||
|
@ -1,3 +1,4 @@
|
||||
use ::dark_api::fatbin::FatbinFileIterator;
|
||||
use ::dark_api::FnFfi;
|
||||
use cuda_types::cuda::*;
|
||||
use dark_api::DarkApiState2;
|
||||
@ -360,7 +361,16 @@ impl DarkApiDump {
|
||||
});
|
||||
}
|
||||
fn_logger.try_(|fn_logger| unsafe {
|
||||
trace::record_submodules_from_fatbin(*module, fatbin_header, fn_logger, state)
|
||||
trace::record_submodules(
|
||||
*module,
|
||||
fn_logger,
|
||||
state,
|
||||
FatbinFileIterator::new(
|
||||
fatbin_header
|
||||
.as_ref()
|
||||
.ok_or(ErrorEntry::NullPointer("get_module_from_cubin_ext2_post"))?,
|
||||
),
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -308,11 +308,11 @@ pub(crate) enum ErrorEntry {
|
||||
unsafe impl Send for ErrorEntry {}
|
||||
unsafe impl Sync for ErrorEntry {}
|
||||
|
||||
impl From<cuda_types::dark_api::ParseError> for ErrorEntry {
|
||||
fn from(e: cuda_types::dark_api::ParseError) -> Self {
|
||||
impl From<dark_api::fatbin::ParseError> for ErrorEntry {
|
||||
fn from(e: dark_api::fatbin::ParseError) -> Self {
|
||||
match e {
|
||||
cuda_types::dark_api::ParseError::NullPointer(s) => ErrorEntry::NullPointer(s),
|
||||
cuda_types::dark_api::ParseError::UnexpectedBinaryField {
|
||||
dark_api::fatbin::ParseError::NullPointer(s) => ErrorEntry::NullPointer(s),
|
||||
dark_api::fatbin::ParseError::UnexpectedBinaryField {
|
||||
field_name,
|
||||
observed,
|
||||
expected,
|
||||
@ -325,6 +325,20 @@ impl From<cuda_types::dark_api::ParseError> for ErrorEntry {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<dark_api::fatbin::FatbinError> for ErrorEntry {
    /// Maps each `FatbinError` variant onto the matching `ErrorEntry`: parse failures go
    /// through the `ParseError` conversion, decompression failures map one-to-one.
    fn from(e: dark_api::fatbin::FatbinError) -> Self {
        match e {
            dark_api::fatbin::FatbinError::ParseFailure(inner) => Self::from(inner),
            dark_api::fatbin::FatbinError::Lz4DecompressionFailure => {
                Self::Lz4DecompressionFailure
            }
            dark_api::fatbin::FatbinError::ZstdDecompressionFailure(code) => {
                Self::ZstdDecompressionFailure(code)
            }
        }
    }
}
|
||||
|
||||
impl Display for ErrorEntry {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
|
@ -4,7 +4,10 @@ use crate::{
|
||||
};
|
||||
use cuda_types::{
|
||||
cuda::*,
|
||||
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbinHeader, FatbincWrapper},
|
||||
dark_api::{FatbinFileHeader, FatbinFileHeaderFlags, FatbincWrapper},
|
||||
};
|
||||
use dark_api::fatbin::{
|
||||
decompress_lz4, decompress_zstd, Fatbin, FatbinFileIterator, FatbinSubmodule,
|
||||
};
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use std::{
|
||||
@ -13,7 +16,6 @@ use std::{
|
||||
fs::{self, File},
|
||||
io::{self, Read, Write},
|
||||
path::PathBuf,
|
||||
ptr,
|
||||
};
|
||||
use unwrap_or::unwrap_some_or;
|
||||
|
||||
@ -259,52 +261,53 @@ pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
|
||||
fn_logger: &mut FnCallLog,
|
||||
state: &mut StateTracker,
|
||||
) -> Result<(), ErrorEntry> {
|
||||
let fatbinc_wrapper = FatbincWrapper::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
|
||||
let is_version_2 = fatbinc_wrapper.version == FatbincWrapper::VERSION_V2;
|
||||
record_submodules_from_fatbin(module, (*fatbinc_wrapper).data, fn_logger, state)?;
|
||||
if is_version_2 {
|
||||
let mut current = (*fatbinc_wrapper).filename_or_fatbins as *const *const c_void;
|
||||
while *current != ptr::null() {
|
||||
record_submodules_from_fatbin(module, *current as *const _, fn_logger, state)?;
|
||||
current = current.add(1);
|
||||
}
|
||||
let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
|
||||
let mut submodules = fatbin.get_submodules()?;
|
||||
while let Some(current) = submodules.next()? {
|
||||
record_submodules_from_fatbin(module, current, fn_logger, state)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn record_submodules_from_fatbin(
|
||||
module: CUmodule,
|
||||
fatbin_header: *const FatbinHeader,
|
||||
submodule: FatbinSubmodule,
|
||||
logger: &mut FnCallLog,
|
||||
state: &mut StateTracker,
|
||||
) -> Result<(), ErrorEntry> {
|
||||
let header = FatbinHeader::new(&fatbin_header).map_err(ErrorEntry::from)?;
|
||||
let file = header.get_content();
|
||||
record_submodules(module, logger, state, file)?;
|
||||
record_submodules(module, logger, state, submodule.get_files())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe fn record_submodules(
|
||||
pub(crate) unsafe fn record_submodules(
|
||||
module: CUmodule,
|
||||
fn_logger: &mut FnCallLog,
|
||||
state: &mut StateTracker,
|
||||
mut file_buffer: &[u8],
|
||||
mut files: FatbinFileIterator,
|
||||
) -> Result<(), ErrorEntry> {
|
||||
while let Some(file) = FatbinFileHeader::next(&mut file_buffer)? {
|
||||
let mut payload = if file.flags.contains(FatbinFileHeaderFlags::CompressedLz4) {
|
||||
while let Some(file) = files.next()? {
|
||||
let mut payload = if file
|
||||
.header
|
||||
.flags
|
||||
.contains(FatbinFileHeaderFlags::CompressedLz4)
|
||||
{
|
||||
Cow::Owned(unwrap_some_or!(
|
||||
fn_logger.try_return(|| decompress_lz4(file)),
|
||||
fn_logger.try_return(|| decompress_lz4(&file).map_err(|e| e.into())),
|
||||
continue
|
||||
))
|
||||
} else if file.flags.contains(FatbinFileHeaderFlags::CompressedZstd) {
|
||||
} else if file
|
||||
.header
|
||||
.flags
|
||||
.contains(FatbinFileHeaderFlags::CompressedZstd)
|
||||
{
|
||||
Cow::Owned(unwrap_some_or!(
|
||||
fn_logger.try_return(|| decompress_zstd(file)),
|
||||
fn_logger.try_return(|| decompress_zstd(&file).map_err(|e| e.into())),
|
||||
continue
|
||||
))
|
||||
} else {
|
||||
Cow::Borrowed(file.get_payload())
|
||||
};
|
||||
match file.kind {
|
||||
match file.header.kind {
|
||||
FatbinFileHeader::HEADER_KIND_PTX => {
|
||||
while payload.last() == Some(&0) {
|
||||
// remove trailing zeros
|
||||
@ -322,50 +325,10 @@ unsafe fn record_submodules(
|
||||
UInt::U16(FatbinFileHeader::HEADER_KIND_PTX),
|
||||
UInt::U16(FatbinFileHeader::HEADER_KIND_ELF),
|
||||
],
|
||||
observed: UInt::U16(file.kind),
|
||||
observed: UInt::U16(file.header.kind),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Largest output buffer, in bytes (64 MiB), that `decompress_lz4` will grow to before giving up.
const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024;
|
||||
|
||||
unsafe fn decompress_lz4(file: &FatbinFileHeader) -> Result<Vec<u8>, ErrorEntry> {
|
||||
let decompressed_size = usize::max(1024, (*file).uncompressed_payload as usize);
|
||||
let mut decompressed_vec = vec![0u8; decompressed_size];
|
||||
loop {
|
||||
match lz4_sys::LZ4_decompress_safe(
|
||||
file.get_payload().as_ptr() as *const _,
|
||||
decompressed_vec.as_mut_ptr() as *mut _,
|
||||
(*file).payload_size as _,
|
||||
decompressed_vec.len() as _,
|
||||
) {
|
||||
error if error < 0 => {
|
||||
let new_size = decompressed_vec.len() * 2;
|
||||
if new_size > MAX_MODULE_DECOMPRESSION_BOUND {
|
||||
return Err(ErrorEntry::Lz4DecompressionFailure);
|
||||
}
|
||||
decompressed_vec.resize(decompressed_vec.len() * 2, 0);
|
||||
}
|
||||
real_decompressed_size => {
|
||||
decompressed_vec.truncate(real_decompressed_size as usize);
|
||||
return Ok(decompressed_vec);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn decompress_zstd(file: &FatbinFileHeader) -> Result<Vec<u8>, ErrorEntry> {
|
||||
let mut result = Vec::with_capacity(file.uncompressed_payload as usize);
|
||||
let payload = file.get_payload();
|
||||
dbg!((payload.len(), file.uncompressed_payload, file.payload_size));
|
||||
match zstd_safe::decompress(&mut result, payload) {
|
||||
Ok(actual_size) => {
|
||||
result.truncate(actual_size);
|
||||
Ok(result)
|
||||
}
|
||||
Err(err) => Err(ErrorEntry::ZstdDecompressionFailure(err)),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user