mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 14:57:43 +03:00
Unified fatbin versions behind a single iterator. (#398)
This commit is contained in:
@ -75,21 +75,24 @@ impl<'a> Fatbin<'a> {
|
||||
Ok(Fatbin { wrapper })
|
||||
}
|
||||
|
||||
pub fn get_first(&self) -> Result<FatbinSubmodule, FatbinError> {
|
||||
let header: &FatbinHeader =
|
||||
parse_fatbin_header(&self.wrapper.data).map_err(|e| FatbinError::ParseFailure(e))?;
|
||||
Ok(FatbinSubmodule::new(header))
|
||||
}
|
||||
|
||||
pub fn get_submodules(&self) -> Option<FatbinSubmoduleIterator> {
|
||||
let is_version_2 = self.wrapper.version == FatbincWrapper::VERSION_V2;
|
||||
if !is_version_2 {
|
||||
return None;
|
||||
pub fn get_submodules(&self) -> Result<FatbinIter<'a>, FatbinError> {
|
||||
match self.wrapper.version {
|
||||
FatbincWrapper::VERSION_V2 =>
|
||||
Ok(FatbinIter::V2(FatbinSubmoduleIterator {
|
||||
fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void,
|
||||
_phantom: std::marker::PhantomData,
|
||||
})),
|
||||
FatbincWrapper::VERSION_V1 => {
|
||||
let header = parse_fatbin_header(&self.wrapper.data)
|
||||
.map_err(FatbinError::ParseFailure)?;
|
||||
Ok(FatbinIter::V1(Some(FatbinSubmodule::new(header))))
|
||||
}
|
||||
version => Err(FatbinError::ParseFailure(ParseError::UnexpectedBinaryField{
|
||||
field_name: "FATBINC_VERSION",
|
||||
observed: version,
|
||||
expected: [FatbincWrapper::VERSION_V1, FatbincWrapper::VERSION_V2].into(),
|
||||
})),
|
||||
}
|
||||
|
||||
Some(FatbinSubmoduleIterator {
|
||||
fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,12 +110,27 @@ impl<'a> FatbinSubmodule<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FatbinSubmoduleIterator {
|
||||
fatbins: *const *const std::ffi::c_void,
|
||||
pub enum FatbinIter<'a> {
|
||||
V1(Option<FatbinSubmodule<'a>>),
|
||||
V2(FatbinSubmoduleIterator<'a>),
|
||||
}
|
||||
|
||||
impl FatbinSubmoduleIterator {
|
||||
pub unsafe fn next(&mut self) -> Result<Option<FatbinSubmodule>, ParseError> {
|
||||
impl<'a> FatbinIter<'a> {
|
||||
pub fn next(&mut self) -> Result<Option<FatbinSubmodule<'a>>, ParseError> {
|
||||
match self {
|
||||
FatbinIter::V1(opt) => Ok(opt.take()),
|
||||
FatbinIter::V2(iter) => unsafe { iter.next() },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FatbinSubmoduleIterator<'a> {
|
||||
fatbins: *const *const std::ffi::c_void,
|
||||
_phantom: std::marker::PhantomData<&'a FatbinHeader>,
|
||||
}
|
||||
|
||||
impl<'a> FatbinSubmoduleIterator<'a> {
|
||||
pub unsafe fn next(&mut self) -> Result<Option<FatbinSubmodule<'a>>, ParseError> {
|
||||
if *self.fatbins != ptr::null() {
|
||||
let header = *self.fatbins as *const FatbinHeader;
|
||||
self.fatbins = self.fatbins.add(1);
|
||||
|
@ -24,14 +24,15 @@ impl ZludaObject for Module {
|
||||
|
||||
fn get_ptx_from_wrapped_fatbin(image: *const ::core::ffi::c_void) -> Result<Vec<u8>, CUerror> {
|
||||
let fatbin = Fatbin::new(&image).map_err(|_| CUerror::UNKNOWN)?;
|
||||
let first = fatbin.get_first().map_err(|_| CUerror::UNKNOWN)?;
|
||||
let mut files = first.get_files();
|
||||
let mut submodules = fatbin.get_submodules().map_err(|_| CUerror::UNKNOWN)?;
|
||||
|
||||
while let Some(file) = unsafe { files.next().map_err(|_| CUerror::UNKNOWN)? } {
|
||||
// Eventually we will want to get the PTX for the highest hardware version that we can get to compile. But for now we just get the first PTX we can find.
|
||||
if file.header.kind == FatbinFileHeader::HEADER_KIND_PTX {
|
||||
let decompressed = unsafe { file.decompress() }.map_err(|_| CUerror::UNKNOWN)?;
|
||||
return Ok(decompressed);
|
||||
while let Some(current) = unsafe { submodules.next().map_err(|_| CUerror::UNKNOWN)? } {
|
||||
let mut files = current.get_files();
|
||||
while let Some(file) = unsafe { files.next().map_err(|_| CUerror::UNKNOWN)? } {
|
||||
if file.header.kind == FatbinFileHeader::HEADER_KIND_PTX {
|
||||
let decompressed = unsafe { file.decompress() }.map_err(|_| CUerror::UNKNOWN)?;
|
||||
return Ok(decompressed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -262,12 +262,9 @@ pub(crate) unsafe fn record_submodules_from_wrapped_fatbin(
|
||||
state: &mut StateTracker,
|
||||
) -> Result<(), ErrorEntry> {
|
||||
let fatbin = Fatbin::new(&fatbinc_wrapper).map_err(ErrorEntry::from)?;
|
||||
let first = fatbin.get_first().map_err(ErrorEntry::from)?;
|
||||
record_submodules_from_fatbin(module, first, fn_logger, state)?;
|
||||
if let Some(mut submodules) = fatbin.get_submodules() {
|
||||
while let Some(current) = submodules.next()? {
|
||||
record_submodules_from_fatbin(module, current, fn_logger, state)?;
|
||||
}
|
||||
let mut submodules = fatbin.get_submodules()?;
|
||||
while let Some(current) = submodules.next()? {
|
||||
record_submodules_from_fatbin(module, current, fn_logger, state)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user