Implement linking

This commit is contained in:
Andrzej Janik
2021-09-16 23:26:02 +00:00
parent 04394dbb04
commit 62ce1fd3a9
3 changed files with 36 additions and 22 deletions

View File

@ -2565,6 +2565,7 @@ pub unsafe extern "system" fn cuLinkAddData_v2(
options, options,
optionValues, optionValues,
) )
.encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
@ -2580,12 +2581,12 @@ pub extern "system" fn cuLinkAddFile_v2(
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]
pub extern "system" fn cuLinkComplete( pub unsafe extern "system" fn cuLinkComplete(
state: CUlinkState, state: CUlinkState,
cubinOut: *mut *mut ::std::os::raw::c_void, cubinOut: *mut *mut ::std::os::raw::c_void,
sizeOut: *mut usize, sizeOut: *mut usize,
) -> CUresult { ) -> CUresult {
r#impl::link::complete(state, cubinOut, sizeOut) r#impl::link::complete(state, cubinOut, sizeOut).encuda()
} }
#[cfg_attr(not(test), no_mangle)] #[cfg_attr(not(test), no_mangle)]

View File

@ -3,10 +3,18 @@ use std::{
mem, ptr, slice, mem, ptr, slice,
}; };
use crate::cuda::{CUjitInputType, CUjit_option, CUlinkState, CUresult}; use hip_runtime_sys::{hipCtxGetDevice, hipError_t, hipGetDeviceProperties};
use crate::{
cuda::{CUjitInputType, CUjit_option, CUlinkState, CUresult},
hip_call,
};
use super::module::{self, SpirvModule};
struct LinkState { struct LinkState {
modules: Vec<String>, modules: Vec<SpirvModule>,
result: Option<Vec<u8>>,
} }
pub(crate) unsafe fn create( pub(crate) unsafe fn create(
@ -20,6 +28,7 @@ pub(crate) unsafe fn create(
} }
let state = Box::new(LinkState { let state = Box::new(LinkState {
modules: Vec::new(), modules: Vec::new(),
result: None,
}); });
*state_out = mem::transmute(state); *state_out = mem::transmute(state);
CUresult::CUDA_SUCCESS CUresult::CUDA_SUCCESS
@ -34,31 +43,36 @@ pub(crate) unsafe fn add_data(
num_options: u32, num_options: u32,
options: *mut CUjit_option, options: *mut CUjit_option,
option_values: *mut *mut c_void, option_values: *mut *mut c_void,
) -> CUresult { ) -> Result<(), hipError_t> {
if state == ptr::null_mut() { if state == ptr::null_mut() {
return CUresult::CUDA_ERROR_INVALID_VALUE; return Err(hipError_t::hipErrorInvalidValue);
} }
let state: *mut LinkState = mem::transmute(state); let state: *mut LinkState = mem::transmute(state);
let state = &mut *state; let state = &mut *state;
// V-RAY specific hack // V-RAY specific hack
if state.modules.len() == 2 { if state.modules.len() == 2 {
return CUresult::CUDA_SUCCESS; return Err(hipError_t::hipSuccess);
} }
let ptx = slice::from_raw_parts(data as *mut u8, size); let spirv_data = SpirvModule::new_raw(data as *const _)?;
state.modules.push( state.modules.push(spirv_data);
CStr::from_bytes_with_nul_unchecked(ptx) Ok(())
.to_string_lossy()
.to_string(),
);
CUresult::CUDA_SUCCESS
} }
pub(crate) fn complete( pub(crate) unsafe fn complete(
state: CUlinkState, state: CUlinkState,
cubin_out: *mut *mut c_void, cubin_out: *mut *mut c_void,
size_out: *mut usize, size_out: *mut usize,
) -> CUresult { ) -> Result<(), hipError_t> {
CUresult::CUDA_SUCCESS let mut dev = 0;
hip_call! { hipCtxGetDevice(&mut dev) };
let mut props = unsafe { mem::zeroed() };
hip_call! { hipGetDeviceProperties(&mut props, dev) };
let state: &LinkState = mem::transmute(state);
let spirv_bins = state.modules.iter().map(|m| &m.binaries[..]);
let should_link_ptx_impl = state.modules.iter().find_map(|m| m.should_link_ptx_impl);
let arch_binary = module::compile_amd(&props, spirv_bins, should_link_ptx_impl)
.map_err(|_| hipError_t::hipErrorUnknown)?;
Ok(())
} }
pub(crate) unsafe fn destroy(state: CUlinkState) -> CUresult { pub(crate) unsafe fn destroy(state: CUlinkState) -> CUresult {

View File

@ -7,7 +7,7 @@ use std::ops::Add;
use std::os::raw::c_char; use std::os::raw::c_char;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::Command; use std::process::Command;
use std::{env, fs, mem, ptr, slice}; use std::{env, fs, iter, mem, ptr, slice};
use hip_runtime_sys::{ use hip_runtime_sys::{
hipCtxGetCurrent, hipCtxGetDevice, hipDeviceGetAttribute, hipDeviceGetName, hipDeviceProp_t, hipCtxGetCurrent, hipCtxGetDevice, hipDeviceGetAttribute, hipDeviceGetName, hipDeviceProp_t,
@ -87,7 +87,7 @@ pub fn load_data_impl(pmod: *mut CUmodule, spirv_data: SpirvModule) -> Result<()
let err = unsafe { hipGetDeviceProperties(&mut props, dev) }; let err = unsafe { hipGetDeviceProperties(&mut props, dev) };
let arch_binary = compile_amd( let arch_binary = compile_amd(
&props, &props,
&[&spirv_data.binaries[..]], iter::once(&spirv_data.binaries[..]),
spirv_data.should_link_ptx_impl, spirv_data.should_link_ptx_impl,
) )
.map_err(|_| hipError_t::hipErrorUnknown)?; .map_err(|_| hipError_t::hipErrorUnknown)?;
@ -113,9 +113,9 @@ const AMDGPU_BITCODE: [&'static str; 8] = [
]; ];
const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_"; const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_";
fn compile_amd( pub(crate) fn compile_amd<'a>(
device_pros: &hipDeviceProp_t, device_pros: &hipDeviceProp_t,
spirv_il: &[&[u32]], spirv_il: impl Iterator<Item = &'a [u32]>,
ptx_lib: Option<(&'static [u8], &'static [u8])>, ptx_lib: Option<(&'static [u8], &'static [u8])>,
) -> io::Result<Vec<u8>> { ) -> io::Result<Vec<u8>> {
let null_terminator = device_pros let null_terminator = device_pros
@ -134,7 +134,6 @@ fn compile_amd(
}; };
let dir = tempfile::tempdir()?; let dir = tempfile::tempdir()?;
let spirv_files = spirv_il let spirv_files = spirv_il
.iter()
.map(|spirv| { .map(|spirv| {
let mut spirv_file = NamedTempFile::new_in(&dir)?; let mut spirv_file = NamedTempFile::new_in(&dir)?;
let spirv_u8 = unsafe { let spirv_u8 = unsafe {