From 21ef5f60a3a5efa17855a30f6b5c7d1968cd46ba Mon Sep 17 00:00:00 2001 From: Violet Date: Wed, 30 Jul 2025 14:55:09 -0700 Subject: [PATCH] Check Rust formatting on pull requests (#451) * Check Rust formatting on pull requests This should help us maintain consistent style, without having unrelated style changes in pull requests from running `rustfmt`. * cargo fmt non-generated files * Ignore generated files --- .github/workflows/pr_master.yml | 9 + comgr/src/lib.rs | 815 ++-- cuda_macros/.rustfmt.toml | 1 + cuda_types/.rustfmt.toml | 1 + dark_api/src/fatbin.rs | 26 +- dark_api/src/lib.rs | 6 +- ext/amd_comgr-sys/src/lib.rs | 2 +- ext/hip_runtime-sys/.rustfmt.toml | 1 + ext/rocblas-sys/.rustfmt.toml | 1 + format/.rustfmt.toml | 1 + format/src/dark_api.rs | 2 +- format/src/lib.rs | 2635 +++++------ ptx/src/lib.rs | 1 - ptx/src/pass/llvm/attributes.rs | 18 +- ptx/src/pass/llvm/emit.rs | 6043 ++++++++++++------------ ptx/src/pass/llvm/mod.rs | 6 +- ptx/src/pass/mod.rs | 1929 ++++---- ptx/src/pass/normalize_basic_blocks.rs | 4 +- ptx/src/test/mod.rs | 4 +- ptx/src/test/spirv_run/mod.rs | 1335 +++--- ptx_parser/src/ast.rs | 4045 ++++++++-------- ptx_parser/src/lib.rs | 2 +- ptx_parser_macros/src/lib.rs | 2064 ++++---- ptx_parser_macros_impl/src/lib.rs | 1762 +++---- ptx_parser_macros_impl/src/parser.rs | 1827 ++++--- zluda/src/impl/device.rs | 21 +- zluda/src/impl/library.rs | 7 +- zluda/src/impl/module.rs | 5 +- zluda/src/lib.rs | 116 +- zluda/src/os_unix.rs | 1 + zluda_blas/src/impl.rs | 8 +- zluda_blas/src/lib.rs | 13 +- zluda_blaslt/src/impl.rs | 4 +- zluda_blaslt/src/lib.rs | 15 +- zluda_dnn/src/lib.rs | 15 +- zluda_dump/src/lib.rs | 10 +- zluda_fft/src/lib.rs | 4 +- zluda_ml/src/impl.rs | 80 +- zluda_ml/src/lib.rs | 6 +- zluda_sparse/src/lib.rs | 17 +- 40 files changed, 11463 insertions(+), 11399 deletions(-) create mode 100644 cuda_macros/.rustfmt.toml create mode 100644 cuda_types/.rustfmt.toml create mode 100644 ext/hip_runtime-sys/.rustfmt.toml create mode 100644 ext/rocblas-sys/.rustfmt.toml create mode 100644 format/.rustfmt.toml diff --git a/.github/workflows/pr_master.yml b/.github/workflows/pr_master.yml index 9b50a84..8787c81 100644 --- a/.github/workflows/pr_master.yml +++ b/.github/workflows/pr_master.yml @@ -11,6 +11,15 @@ env: ROCM_VERSION: "6.3.1" jobs: + formatting: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v4 + - uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + components: rustfmt + - name: Check Rust formatting + uses: actions-rust-lang/rustfmt@v1 build_linux: name: Build (Linux) runs-on: ubuntu-22.04 diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index 776f76c..366476b 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -1,407 +1,408 @@ -use amd_comgr_sys::*; -use std::{ffi::CStr, mem, ptr}; - -macro_rules! call_dispatch_arg { - (2, $arg:ident) => { - $arg.comgr2() - }; - (2, $arg:tt) => { - #[allow(unused_braces)] - $arg - }; - (3, $arg:ident) => { - $arg.comgr3() - }; - (3, $arg:tt) => { - #[allow(unused_braces)] - $arg - }; -} - -macro_rules! call_dispatch { - ($src:expr => $fn_:ident( $($arg:tt),+ )) => { - match $src { - Comgr::V2(this) => unsafe { this. $fn_( - $( - call_dispatch_arg!(2, $arg), - )+ - ) }?, - Comgr::V3(this) => unsafe { this. $fn_( - $( - call_dispatch_arg!(3, $arg), - )+ - ) }?, - } - }; -} - -macro_rules! 
comgr_owned { - ($name:ident, $comgr_type:ident, $ctor:ident, $dtor:ident) => { - struct $name<'a> { - handle: u64, - comgr: &'a Comgr, - } - - impl<'a> $name<'a> { - fn new(comgr: &'a Comgr) -> Result { - let handle = match comgr { - Comgr::V2(comgr) => { - let mut result = unsafe { mem::zeroed() }; - unsafe { comgr.$ctor(&mut result)? }; - result.handle - } - Comgr::V3(comgr) => { - let mut result = unsafe { mem::zeroed() }; - unsafe { comgr.$ctor(&mut result)? }; - result.handle - } - }; - Ok(Self { handle, comgr }) - } - - fn comgr2(&self) -> amd_comgr_sys::comgr2::$comgr_type { - amd_comgr_sys::comgr2::$comgr_type { - handle: self.handle, - } - } - - fn comgr3(&self) -> amd_comgr_sys::comgr3::$comgr_type { - amd_comgr_sys::comgr3::$comgr_type { - handle: self.handle, - } - } - } - - impl<'a> Drop for $name<'a> { - fn drop(&mut self) { - match self.comgr { - Comgr::V2(comgr) => { - unsafe { - comgr.$dtor(amd_comgr_sys::comgr2::$comgr_type { - handle: self.handle, - }) - } - .ok(); - } - Comgr::V3(comgr) => { - unsafe { - comgr.$dtor(amd_comgr_sys::comgr3::$comgr_type { - handle: self.handle, - }) - } - .ok(); - } - } - } - } - }; -} - -comgr_owned!( - ActionInfo, - amd_comgr_action_info_t, - amd_comgr_create_action_info, - amd_comgr_destroy_action_info -); - -impl<'a> ActionInfo<'a> { - fn set_isa_name(&self, isa: &CStr) -> Result<(), Error> { - let mut full_isa = "amdgcn-amd-amdhsa--".to_string().into_bytes(); - full_isa.extend(isa.to_bytes_with_nul()); - call_dispatch!(self.comgr => amd_comgr_action_info_set_isa_name(self, { full_isa.as_ptr().cast() })); - Ok(()) - } - - fn set_language(&self, language: Language) -> Result<(), Error> { - call_dispatch!(self.comgr => amd_comgr_action_info_set_language(self, language)); - Ok(()) - } - - fn set_options<'b>(&self, options: impl Iterator) -> Result<(), Error> { - let options = options.map(|x| x.as_ptr()).collect::>(); - call_dispatch!(self.comgr => amd_comgr_action_info_set_option_list(self, { options.as_ptr().cast_mut() }, { options.len() })); - Ok(()) - } -} - -comgr_owned!( - DataSet, - amd_comgr_data_set_t, - amd_comgr_create_data_set, - amd_comgr_destroy_data_set -); - -impl<'a> DataSet<'a> { - fn add(&self, data: &Data) -> Result<(), Error> { - call_dispatch!(self.comgr => amd_comgr_data_set_add(self, data)); - Ok(()) - } - - fn get_data(&self, kind: DataKind, index: usize) -> Result { - let mut handle = 0u64; - call_dispatch!(self.comgr => amd_comgr_action_data_get_data(self, kind, { index }, { std::ptr::from_mut(&mut handle).cast() })); - Ok(Data(handle)) - } -} - -struct Data(u64); - -impl Data { - fn new(comgr: &Comgr, kind: DataKind, name: &CStr, content: &[u8]) -> Result { - let mut handle = 0u64; - call_dispatch!(comgr => amd_comgr_create_data(kind, { std::ptr::from_mut(&mut handle).cast() })); - let data = Data(handle); - call_dispatch!(comgr => amd_comgr_set_data_name(data, { name.as_ptr() })); - call_dispatch!(comgr => amd_comgr_set_data(data, { content.len() }, { content.as_ptr().cast() })); - Ok(data) - } - - fn comgr2(&self) -> comgr2::amd_comgr_data_t { - comgr2::amd_comgr_data_s { handle: self.0 } - } - - fn comgr3(&self) -> comgr3::amd_comgr_data_t { - comgr3::amd_comgr_data_s { handle: self.0 } - } - - fn copy_content(&self, comgr: &Comgr) -> Result, Error> { - let mut size = unsafe { mem::zeroed() }; - call_dispatch!(comgr => amd_comgr_get_data(self, { &mut size }, { ptr::null_mut() })); - let mut result: Vec = Vec::with_capacity(size); - unsafe { result.set_len(size) }; - call_dispatch!(comgr => 
amd_comgr_get_data(self, { &mut size }, { result.as_mut_ptr().cast() })); - Ok(result) - } -} - -pub fn compile_bitcode( - comgr: &Comgr, - gcn_arch: &CStr, - main_buffer: &[u8], - attributes_buffer: &[u8], - ptx_impl: &[u8], -) -> Result, Error> { - let bitcode_data_set = DataSet::new(comgr)?; - let main_bitcode_data = Data::new(comgr, DataKind::Bc, c"zluda.bc", main_buffer)?; - bitcode_data_set.add(&main_bitcode_data)?; - let attributes_bitcode_data = Data::new(comgr, DataKind::Bc, c"attributes.bc", attributes_buffer)?; - bitcode_data_set.add(&attributes_bitcode_data)?; - let stdlib_bitcode_data = Data::new(comgr, DataKind::Bc, c"ptx_impl.bc", ptx_impl)?; - bitcode_data_set.add(&stdlib_bitcode_data)?; - let linking_info = ActionInfo::new(comgr)?; - let linked_data_set = - comgr.do_action(ActionKind::LinkBcToBc, &linking_info, &bitcode_data_set)?; - let compile_to_exec = ActionInfo::new(comgr)?; - compile_to_exec.set_isa_name(gcn_arch)?; - compile_to_exec.set_language(Language::LlvmIr)?; - let common_options = [ - // This makes no sense, but it makes ockl linking work - c"-Xclang", - c"-mno-link-builtin-bitcode-postopt", - // Otherwise LLVM omits dynamic fp mode for ockl functions during linking - // and then fails to inline them - c"-Xclang", - c"-fdenormal-fp-math=dynamic", - c"-O3", - c"-mno-wavefrontsize64", - c"-mcumode", - // Useful for inlining reports, combined with AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_EMIT_VERBOSE_LOGS=1 AMD_COMGR_REDIRECT_LOGS=stderr - // c"-fsave-optimization-record=yaml", - ] - .into_iter(); - let opt_options = if cfg!(debug_assertions) { - //[c"-g", c"-mllvm", c"-print-before-all", c"", c""] - [c"-g", c"", c"", c"", c""] - } else { - [ - c"-g0", - // default inlining threshold times 10 - c"-mllvm", - c"-inline-threshold=2250", - c"-mllvm", - c"-inlinehint-threshold=3250", - ] - }; - compile_to_exec.set_options(common_options.chain(opt_options))?; - let exec_data_set = comgr.do_action( - ActionKind::CompileSourceToExecutable, - &compile_to_exec, - &linked_data_set, - )?; - let executable = exec_data_set.get_data(DataKind::Executable, 0)?; - executable.copy_content(comgr) -} - -pub enum Comgr { - V2(amd_comgr_sys::comgr2::Comgr2), - V3(amd_comgr_sys::comgr3::Comgr3), -} - -impl Comgr { - pub fn new() -> Result { - unsafe { libloading::Library::new(os::COMGR3) } - .and_then(|lib| { - Ok(Comgr::V3(unsafe { - amd_comgr_sys::comgr3::Comgr3::from_library(lib)? - })) - }) - .or_else(|_| { - unsafe { libloading::Library::new(os::COMGR2) }.and_then(|lib| { - Ok(if Self::is_broken_v2(&lib) { - Comgr::V3(unsafe { amd_comgr_sys::comgr3::Comgr3::from_library(lib)? }) - } else { - Comgr::V2(unsafe { amd_comgr_sys::comgr2::Comgr2::from_library(lib)? }) - }) - }) - }) - .map_err(Into::into) - } - - // For reasons unknown, on AMD Adrenalin 25.5.1, AMD ships amd_comgr_2.dll that shows up as - // version 2.9.0, but actually uses the 3.X ABI. This is our best effort to detect it. 
- // Version 25.3.1 returns 2.8.0, which seem to be the last version that actually uses the 2 ABI - fn is_broken_v2(lib: &libloading::Library) -> bool { - if cfg!(not(windows)) { - return false; - } - let amd_comgr_get_version = match unsafe { - lib.get::( - b"amd_comgr_get_version\0", - ) - } { - Ok(symbol) => symbol, - Err(_) => return false, - }; - let mut major = 0; - let mut minor = 0; - unsafe { (amd_comgr_get_version)(&mut major, &mut minor) }; - (major, minor) >= (2, 9) - } - - fn do_action( - &self, - kind: ActionKind, - action: &ActionInfo, - data_set: &DataSet, - ) -> Result { - let result = DataSet::new(self)?; - call_dispatch!(self => amd_comgr_do_action(kind, action, data_set, result)); - Ok(result) - } -} - -#[derive(Debug)] -pub struct Error(pub ::std::num::NonZeroU32); - -impl Error { - #[doc = " A generic error has occurred."] - pub const UNKNOWN: Error = Error(unsafe { ::std::num::NonZeroU32::new_unchecked(1) }); - #[doc = " One of the actual arguments does not meet a precondition stated\n in the documentation of the corresponding formal argument. This\n includes both invalid Action types, and invalid arguments to\n valid Action types."] - pub const INVALID_ARGUMENT: Error = Error(unsafe { ::std::num::NonZeroU32::new_unchecked(2) }); - #[doc = " Failed to allocate the necessary resources."] - pub const OUT_OF_RESOURCES: Error = Error(unsafe { ::std::num::NonZeroU32::new_unchecked(3) }); -} - -impl From for Error { - fn from(_: libloading::Error) -> Self { - Self::UNKNOWN - } -} - -impl From for Error { - fn from(status: comgr2::amd_comgr_status_s) -> Self { - Error(status.0) - } -} - -impl From for Error { - fn from(status: comgr3::amd_comgr_status_s) -> Self { - Error(status.0) - } -} - -macro_rules! impl_into { - ($self_type:ident, $to_type:ident, [$($from:ident => $to:ident),+]) => { - #[derive(Copy, Clone)] - #[allow(unused)] - enum $self_type { - $( - $from, - )+ - } - - impl $self_type { - fn comgr2(self) -> comgr2::$to_type { - match self { - $( - Self:: $from => comgr2 :: $to_type :: $to, - )+ - } - } - - fn comgr3(self) -> comgr3::$to_type { - match self { - $( - Self:: $from => comgr3 :: $to_type :: $to, - )+ - } - } - } - }; -} - -impl_into!( - ActionKind, - amd_comgr_action_kind_t, - [ - LinkBcToBc => AMD_COMGR_ACTION_LINK_BC_TO_BC, - CompileSourceToExecutable => AMD_COMGR_ACTION_COMPILE_SOURCE_TO_EXECUTABLE - ] -); - -impl_into!( - DataKind, - amd_comgr_data_kind_t, - [ - Undef => AMD_COMGR_DATA_KIND_UNDEF, - Source => AMD_COMGR_DATA_KIND_SOURCE, - Include => AMD_COMGR_DATA_KIND_INCLUDE, - PrecompiledHeader => AMD_COMGR_DATA_KIND_PRECOMPILED_HEADER, - Diagnostic => AMD_COMGR_DATA_KIND_DIAGNOSTIC, - Log => AMD_COMGR_DATA_KIND_LOG, - Bc => AMD_COMGR_DATA_KIND_BC, - Relocatable => AMD_COMGR_DATA_KIND_RELOCATABLE, - Executable => AMD_COMGR_DATA_KIND_EXECUTABLE, - Bytes => AMD_COMGR_DATA_KIND_BYTES, - Fatbin => AMD_COMGR_DATA_KIND_FATBIN, - Ar => AMD_COMGR_DATA_KIND_AR, - BcBundle => AMD_COMGR_DATA_KIND_BC_BUNDLE, - ArBundle => AMD_COMGR_DATA_KIND_AR_BUNDLE, - ObjBundle => AMD_COMGR_DATA_KIND_OBJ_BUNDLE - - ] -); - -impl_into!( - Language, - amd_comgr_language_t, - [ - None => AMD_COMGR_LANGUAGE_NONE, - OpenCl12 => AMD_COMGR_LANGUAGE_OPENCL_1_2, - OpenCl20 => AMD_COMGR_LANGUAGE_OPENCL_2_0, - Hip => AMD_COMGR_LANGUAGE_HIP, - LlvmIr => AMD_COMGR_LANGUAGE_LLVM_IR - ] -); - -#[cfg(unix)] -mod os { - pub static COMGR3: &'static str = "libamd_comgr.so.3"; - pub static COMGR2: &'static str = "libamd_comgr.so.2"; -} - -#[cfg(windows)] -mod os { - pub static COMGR3: 
&'static str = "amd_comgr_3.dll"; - pub static COMGR2: &'static str = "amd_comgr_2.dll"; -} +use amd_comgr_sys::*; +use std::{ffi::CStr, mem, ptr}; + +macro_rules! call_dispatch_arg { + (2, $arg:ident) => { + $arg.comgr2() + }; + (2, $arg:tt) => { + #[allow(unused_braces)] + $arg + }; + (3, $arg:ident) => { + $arg.comgr3() + }; + (3, $arg:tt) => { + #[allow(unused_braces)] + $arg + }; +} + +macro_rules! call_dispatch { + ($src:expr => $fn_:ident( $($arg:tt),+ )) => { + match $src { + Comgr::V2(this) => unsafe { this. $fn_( + $( + call_dispatch_arg!(2, $arg), + )+ + ) }?, + Comgr::V3(this) => unsafe { this. $fn_( + $( + call_dispatch_arg!(3, $arg), + )+ + ) }?, + } + }; +} + +macro_rules! comgr_owned { + ($name:ident, $comgr_type:ident, $ctor:ident, $dtor:ident) => { + struct $name<'a> { + handle: u64, + comgr: &'a Comgr, + } + + impl<'a> $name<'a> { + fn new(comgr: &'a Comgr) -> Result { + let handle = match comgr { + Comgr::V2(comgr) => { + let mut result = unsafe { mem::zeroed() }; + unsafe { comgr.$ctor(&mut result)? }; + result.handle + } + Comgr::V3(comgr) => { + let mut result = unsafe { mem::zeroed() }; + unsafe { comgr.$ctor(&mut result)? }; + result.handle + } + }; + Ok(Self { handle, comgr }) + } + + fn comgr2(&self) -> amd_comgr_sys::comgr2::$comgr_type { + amd_comgr_sys::comgr2::$comgr_type { + handle: self.handle, + } + } + + fn comgr3(&self) -> amd_comgr_sys::comgr3::$comgr_type { + amd_comgr_sys::comgr3::$comgr_type { + handle: self.handle, + } + } + } + + impl<'a> Drop for $name<'a> { + fn drop(&mut self) { + match self.comgr { + Comgr::V2(comgr) => { + unsafe { + comgr.$dtor(amd_comgr_sys::comgr2::$comgr_type { + handle: self.handle, + }) + } + .ok(); + } + Comgr::V3(comgr) => { + unsafe { + comgr.$dtor(amd_comgr_sys::comgr3::$comgr_type { + handle: self.handle, + }) + } + .ok(); + } + } + } + } + }; +} + +comgr_owned!( + ActionInfo, + amd_comgr_action_info_t, + amd_comgr_create_action_info, + amd_comgr_destroy_action_info +); + +impl<'a> ActionInfo<'a> { + fn set_isa_name(&self, isa: &CStr) -> Result<(), Error> { + let mut full_isa = "amdgcn-amd-amdhsa--".to_string().into_bytes(); + full_isa.extend(isa.to_bytes_with_nul()); + call_dispatch!(self.comgr => amd_comgr_action_info_set_isa_name(self, { full_isa.as_ptr().cast() })); + Ok(()) + } + + fn set_language(&self, language: Language) -> Result<(), Error> { + call_dispatch!(self.comgr => amd_comgr_action_info_set_language(self, language)); + Ok(()) + } + + fn set_options<'b>(&self, options: impl Iterator) -> Result<(), Error> { + let options = options.map(|x| x.as_ptr()).collect::>(); + call_dispatch!(self.comgr => amd_comgr_action_info_set_option_list(self, { options.as_ptr().cast_mut() }, { options.len() })); + Ok(()) + } +} + +comgr_owned!( + DataSet, + amd_comgr_data_set_t, + amd_comgr_create_data_set, + amd_comgr_destroy_data_set +); + +impl<'a> DataSet<'a> { + fn add(&self, data: &Data) -> Result<(), Error> { + call_dispatch!(self.comgr => amd_comgr_data_set_add(self, data)); + Ok(()) + } + + fn get_data(&self, kind: DataKind, index: usize) -> Result { + let mut handle = 0u64; + call_dispatch!(self.comgr => amd_comgr_action_data_get_data(self, kind, { index }, { std::ptr::from_mut(&mut handle).cast() })); + Ok(Data(handle)) + } +} + +struct Data(u64); + +impl Data { + fn new(comgr: &Comgr, kind: DataKind, name: &CStr, content: &[u8]) -> Result { + let mut handle = 0u64; + call_dispatch!(comgr => amd_comgr_create_data(kind, { std::ptr::from_mut(&mut handle).cast() })); + let data = Data(handle); + 
call_dispatch!(comgr => amd_comgr_set_data_name(data, { name.as_ptr() })); + call_dispatch!(comgr => amd_comgr_set_data(data, { content.len() }, { content.as_ptr().cast() })); + Ok(data) + } + + fn comgr2(&self) -> comgr2::amd_comgr_data_t { + comgr2::amd_comgr_data_s { handle: self.0 } + } + + fn comgr3(&self) -> comgr3::amd_comgr_data_t { + comgr3::amd_comgr_data_s { handle: self.0 } + } + + fn copy_content(&self, comgr: &Comgr) -> Result, Error> { + let mut size = unsafe { mem::zeroed() }; + call_dispatch!(comgr => amd_comgr_get_data(self, { &mut size }, { ptr::null_mut() })); + let mut result: Vec = Vec::with_capacity(size); + unsafe { result.set_len(size) }; + call_dispatch!(comgr => amd_comgr_get_data(self, { &mut size }, { result.as_mut_ptr().cast() })); + Ok(result) + } +} + +pub fn compile_bitcode( + comgr: &Comgr, + gcn_arch: &CStr, + main_buffer: &[u8], + attributes_buffer: &[u8], + ptx_impl: &[u8], +) -> Result, Error> { + let bitcode_data_set = DataSet::new(comgr)?; + let main_bitcode_data = Data::new(comgr, DataKind::Bc, c"zluda.bc", main_buffer)?; + bitcode_data_set.add(&main_bitcode_data)?; + let attributes_bitcode_data = + Data::new(comgr, DataKind::Bc, c"attributes.bc", attributes_buffer)?; + bitcode_data_set.add(&attributes_bitcode_data)?; + let stdlib_bitcode_data = Data::new(comgr, DataKind::Bc, c"ptx_impl.bc", ptx_impl)?; + bitcode_data_set.add(&stdlib_bitcode_data)?; + let linking_info = ActionInfo::new(comgr)?; + let linked_data_set = + comgr.do_action(ActionKind::LinkBcToBc, &linking_info, &bitcode_data_set)?; + let compile_to_exec = ActionInfo::new(comgr)?; + compile_to_exec.set_isa_name(gcn_arch)?; + compile_to_exec.set_language(Language::LlvmIr)?; + let common_options = [ + // This makes no sense, but it makes ockl linking work + c"-Xclang", + c"-mno-link-builtin-bitcode-postopt", + // Otherwise LLVM omits dynamic fp mode for ockl functions during linking + // and then fails to inline them + c"-Xclang", + c"-fdenormal-fp-math=dynamic", + c"-O3", + c"-mno-wavefrontsize64", + c"-mcumode", + // Useful for inlining reports, combined with AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_EMIT_VERBOSE_LOGS=1 AMD_COMGR_REDIRECT_LOGS=stderr + // c"-fsave-optimization-record=yaml", + ] + .into_iter(); + let opt_options = if cfg!(debug_assertions) { + //[c"-g", c"-mllvm", c"-print-before-all", c"", c""] + [c"-g", c"", c"", c"", c""] + } else { + [ + c"-g0", + // default inlining threshold times 10 + c"-mllvm", + c"-inline-threshold=2250", + c"-mllvm", + c"-inlinehint-threshold=3250", + ] + }; + compile_to_exec.set_options(common_options.chain(opt_options))?; + let exec_data_set = comgr.do_action( + ActionKind::CompileSourceToExecutable, + &compile_to_exec, + &linked_data_set, + )?; + let executable = exec_data_set.get_data(DataKind::Executable, 0)?; + executable.copy_content(comgr) +} + +pub enum Comgr { + V2(amd_comgr_sys::comgr2::Comgr2), + V3(amd_comgr_sys::comgr3::Comgr3), +} + +impl Comgr { + pub fn new() -> Result { + unsafe { libloading::Library::new(os::COMGR3) } + .and_then(|lib| { + Ok(Comgr::V3(unsafe { + amd_comgr_sys::comgr3::Comgr3::from_library(lib)? + })) + }) + .or_else(|_| { + unsafe { libloading::Library::new(os::COMGR2) }.and_then(|lib| { + Ok(if Self::is_broken_v2(&lib) { + Comgr::V3(unsafe { amd_comgr_sys::comgr3::Comgr3::from_library(lib)? }) + } else { + Comgr::V2(unsafe { amd_comgr_sys::comgr2::Comgr2::from_library(lib)? 
}) + }) + }) + }) + .map_err(Into::into) + } + + // For reasons unknown, on AMD Adrenalin 25.5.1, AMD ships amd_comgr_2.dll that shows up as + // version 2.9.0, but actually uses the 3.X ABI. This is our best effort to detect it. + // Version 25.3.1 returns 2.8.0, which seem to be the last version that actually uses the 2 ABI + fn is_broken_v2(lib: &libloading::Library) -> bool { + if cfg!(not(windows)) { + return false; + } + let amd_comgr_get_version = match unsafe { + lib.get::( + b"amd_comgr_get_version\0", + ) + } { + Ok(symbol) => symbol, + Err(_) => return false, + }; + let mut major = 0; + let mut minor = 0; + unsafe { (amd_comgr_get_version)(&mut major, &mut minor) }; + (major, minor) >= (2, 9) + } + + fn do_action( + &self, + kind: ActionKind, + action: &ActionInfo, + data_set: &DataSet, + ) -> Result { + let result = DataSet::new(self)?; + call_dispatch!(self => amd_comgr_do_action(kind, action, data_set, result)); + Ok(result) + } +} + +#[derive(Debug)] +pub struct Error(pub ::std::num::NonZeroU32); + +impl Error { + #[doc = " A generic error has occurred."] + pub const UNKNOWN: Error = Error(unsafe { ::std::num::NonZeroU32::new_unchecked(1) }); + #[doc = " One of the actual arguments does not meet a precondition stated\n in the documentation of the corresponding formal argument. This\n includes both invalid Action types, and invalid arguments to\n valid Action types."] + pub const INVALID_ARGUMENT: Error = Error(unsafe { ::std::num::NonZeroU32::new_unchecked(2) }); + #[doc = " Failed to allocate the necessary resources."] + pub const OUT_OF_RESOURCES: Error = Error(unsafe { ::std::num::NonZeroU32::new_unchecked(3) }); +} + +impl From for Error { + fn from(_: libloading::Error) -> Self { + Self::UNKNOWN + } +} + +impl From for Error { + fn from(status: comgr2::amd_comgr_status_s) -> Self { + Error(status.0) + } +} + +impl From for Error { + fn from(status: comgr3::amd_comgr_status_s) -> Self { + Error(status.0) + } +} + +macro_rules! 
impl_into { + ($self_type:ident, $to_type:ident, [$($from:ident => $to:ident),+]) => { + #[derive(Copy, Clone)] + #[allow(unused)] + enum $self_type { + $( + $from, + )+ + } + + impl $self_type { + fn comgr2(self) -> comgr2::$to_type { + match self { + $( + Self:: $from => comgr2 :: $to_type :: $to, + )+ + } + } + + fn comgr3(self) -> comgr3::$to_type { + match self { + $( + Self:: $from => comgr3 :: $to_type :: $to, + )+ + } + } + } + }; +} + +impl_into!( + ActionKind, + amd_comgr_action_kind_t, + [ + LinkBcToBc => AMD_COMGR_ACTION_LINK_BC_TO_BC, + CompileSourceToExecutable => AMD_COMGR_ACTION_COMPILE_SOURCE_TO_EXECUTABLE + ] +); + +impl_into!( + DataKind, + amd_comgr_data_kind_t, + [ + Undef => AMD_COMGR_DATA_KIND_UNDEF, + Source => AMD_COMGR_DATA_KIND_SOURCE, + Include => AMD_COMGR_DATA_KIND_INCLUDE, + PrecompiledHeader => AMD_COMGR_DATA_KIND_PRECOMPILED_HEADER, + Diagnostic => AMD_COMGR_DATA_KIND_DIAGNOSTIC, + Log => AMD_COMGR_DATA_KIND_LOG, + Bc => AMD_COMGR_DATA_KIND_BC, + Relocatable => AMD_COMGR_DATA_KIND_RELOCATABLE, + Executable => AMD_COMGR_DATA_KIND_EXECUTABLE, + Bytes => AMD_COMGR_DATA_KIND_BYTES, + Fatbin => AMD_COMGR_DATA_KIND_FATBIN, + Ar => AMD_COMGR_DATA_KIND_AR, + BcBundle => AMD_COMGR_DATA_KIND_BC_BUNDLE, + ArBundle => AMD_COMGR_DATA_KIND_AR_BUNDLE, + ObjBundle => AMD_COMGR_DATA_KIND_OBJ_BUNDLE + + ] +); + +impl_into!( + Language, + amd_comgr_language_t, + [ + None => AMD_COMGR_LANGUAGE_NONE, + OpenCl12 => AMD_COMGR_LANGUAGE_OPENCL_1_2, + OpenCl20 => AMD_COMGR_LANGUAGE_OPENCL_2_0, + Hip => AMD_COMGR_LANGUAGE_HIP, + LlvmIr => AMD_COMGR_LANGUAGE_LLVM_IR + ] +); + +#[cfg(unix)] +mod os { + pub static COMGR3: &'static str = "libamd_comgr.so.3"; + pub static COMGR2: &'static str = "libamd_comgr.so.2"; +} + +#[cfg(windows)] +mod os { + pub static COMGR3: &'static str = "amd_comgr_3.dll"; + pub static COMGR2: &'static str = "amd_comgr_2.dll"; +} diff --git a/cuda_macros/.rustfmt.toml b/cuda_macros/.rustfmt.toml new file mode 100644 index 0000000..c7ad93b --- /dev/null +++ b/cuda_macros/.rustfmt.toml @@ -0,0 +1 @@ +disable_all_formatting = true diff --git a/cuda_types/.rustfmt.toml b/cuda_types/.rustfmt.toml new file mode 100644 index 0000000..c7ad93b --- /dev/null +++ b/cuda_types/.rustfmt.toml @@ -0,0 +1 @@ +disable_all_formatting = true diff --git a/dark_api/src/fatbin.rs b/dark_api/src/fatbin.rs index c9ff08f..7fb4493 100644 --- a/dark_api/src/fatbin.rs +++ b/dark_api/src/fatbin.rs @@ -77,21 +77,22 @@ impl<'a> Fatbin<'a> { pub fn get_submodules(&self) -> Result, FatbinError> { match self.wrapper.version { - FatbincWrapper::VERSION_V2 => - Ok(FatbinIter::V2(FatbinSubmoduleIterator { - fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void, - _phantom: std::marker::PhantomData, - })), + FatbincWrapper::VERSION_V2 => Ok(FatbinIter::V2(FatbinSubmoduleIterator { + fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void, + _phantom: std::marker::PhantomData, + })), FatbincWrapper::VERSION_V1 => { - let header = parse_fatbin_header(&self.wrapper.data) - .map_err(FatbinError::ParseFailure)?; + let header = + parse_fatbin_header(&self.wrapper.data).map_err(FatbinError::ParseFailure)?; Ok(FatbinIter::V1(Some(FatbinSubmodule::new(header)))) } - version => Err(FatbinError::ParseFailure(ParseError::UnexpectedBinaryField{ - field_name: "FATBINC_VERSION", - observed: version, - expected: [FatbincWrapper::VERSION_V1, FatbincWrapper::VERSION_V2].into(), - })), + version => Err(FatbinError::ParseFailure( + ParseError::UnexpectedBinaryField { + 
field_name: "FATBINC_VERSION", + observed: version, + expected: [FatbincWrapper::VERSION_V1, FatbincWrapper::VERSION_V2].into(), + }, + )), } } } @@ -176,7 +177,6 @@ impl<'a> FatbinFile<'a> { unsafe { self.get_payload().to_vec() } }; - while payload.last() == Some(&0) { // remove trailing zeros payload.pop(); diff --git a/dark_api/src/lib.rs b/dark_api/src/lib.rs index b90cb53..756d228 100644 --- a/dark_api/src/lib.rs +++ b/dark_api/src/lib.rs @@ -259,12 +259,12 @@ dark_api! { "{C693336E-1121-DF11-A8C3-68F355D89593}" => CONTEXT_LOCAL_STORAGE_INTERFACE_V0301[4] { [0] = context_local_storage_put( context: cuda_types::cuda::CUcontext, - key: *mut std::ffi::c_void, - value: *mut std::ffi::c_void, + key: *mut std::ffi::c_void, + value: *mut std::ffi::c_void, // clsContextDestroyCallback, have to be called on cuDevicePrimaryCtxReset dtor_cb: Option ) -> cuda_types::cuda::CUresult, diff --git a/ext/amd_comgr-sys/src/lib.rs b/ext/amd_comgr-sys/src/lib.rs index c37a884..dff39ca 100644 --- a/ext/amd_comgr-sys/src/lib.rs +++ b/ext/amd_comgr-sys/src/lib.rs @@ -1,4 +1,4 @@ #[allow(warnings)] pub mod comgr2; #[allow(warnings)] -pub mod comgr3; \ No newline at end of file +pub mod comgr3; diff --git a/ext/hip_runtime-sys/.rustfmt.toml b/ext/hip_runtime-sys/.rustfmt.toml new file mode 100644 index 0000000..c7ad93b --- /dev/null +++ b/ext/hip_runtime-sys/.rustfmt.toml @@ -0,0 +1 @@ +disable_all_formatting = true diff --git a/ext/rocblas-sys/.rustfmt.toml b/ext/rocblas-sys/.rustfmt.toml new file mode 100644 index 0000000..c7ad93b --- /dev/null +++ b/ext/rocblas-sys/.rustfmt.toml @@ -0,0 +1 @@ +disable_all_formatting = true diff --git a/format/.rustfmt.toml b/format/.rustfmt.toml new file mode 100644 index 0000000..c7ad93b --- /dev/null +++ b/format/.rustfmt.toml @@ -0,0 +1 @@ +disable_all_formatting = true diff --git a/format/src/dark_api.rs b/format/src/dark_api.rs index 35ad7f2..973e4df 100644 --- a/format/src/dark_api.rs +++ b/format/src/dark_api.rs @@ -37,4 +37,4 @@ impl CudaDisplay for FatbinHeader { CudaDisplay::write(&self.files_size, "", 0, writer)?; writer.write_all(b" }") } -} \ No newline at end of file +} diff --git a/format/src/lib.rs b/format/src/lib.rs index 004e3af..196e845 100644 --- a/format/src/lib.rs +++ b/format/src/lib.rs @@ -1,1317 +1,1318 @@ -use cuda_types::cuda::*; -use std::{ - any::TypeId, - ffi::{c_void, CStr}, - fmt::LowerHex, - mem, ptr, slice, -}; - -pub trait CudaDisplay { - fn write( - &self, - fn_name: &'static str, - index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()>; -} - -impl CudaDisplay for () { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "()") - } -} - -impl CudaDisplay for CUuuid { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - let guid = self.bytes; - let uuid = uuid::Uuid::from_bytes(guid); - let braced = uuid.as_braced(); - write!(writer, "{braced:#X}") - } -} - -impl CudaDisplay for CUdeviceptr_v1 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{:p}", self.0 as usize as *const ()) - } -} - -impl CudaDisplay for bool { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{}", *self) - } -} - -impl CudaDisplay 
for u8 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{}", *self) - } -} - -impl CudaDisplay for u16 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{}", *self) - } -} - -impl CudaDisplay for i32 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{}", *self) - } -} - -impl CudaDisplay for u32 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{}", *self) - } -} - -impl CudaDisplay for i64 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{}", *self) - } -} - -impl CudaDisplay for u64 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{}", *self) - } -} - -impl CudaDisplay for usize { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{}", *self) - } -} - -impl CudaDisplay for f32 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{}", *self) - } -} - -impl CudaDisplay for f64 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{}", *self) - } -} - -// user by Dark API -impl CudaDisplay - for Option< - extern "system" fn( - cuda_types::cuda::CUcontext, - *mut std::ffi::c_void, - *mut std::ffi::c_void, - ), - > -{ - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - if let Some(fn_ptr) = self { - write!(writer, "{:p}", *fn_ptr) - } else { - writer.write_all(b"NULL") - } - } -} - -impl CudaDisplay for Option { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - if let Some(fn_ptr) = self { - write!(writer, "{:p}", *fn_ptr) - } else { - writer.write_all(b"NULL") - } - } -} - -impl CudaDisplay for Option { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - if let Some(fn_ptr) = self { - write!(writer, "{:p}", *fn_ptr) - } else { - writer.write_all(b"NULL") - } - } -} - -pub fn write_handle( - this: &[T; 64], - writer: &mut (impl std::io::Write + ?Sized), -) -> std::io::Result<()> { - writer.write_all(b"0x")?; - for i in (0..64).rev() { - write!(writer, "{:02x}", this[i])?; - } - Ok(()) -} - -impl CudaDisplay for CUipcMemHandle { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write_handle(&self.reserved, writer) - } -} - -impl CudaDisplay for CUipcEventHandle { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write_handle(&self.reserved, writer) - } -} - -impl CudaDisplay for 
CUmemPoolPtrExportData_v1 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write_handle(&self.reserved, writer) - } -} - -impl CudaDisplay for *mut c_void { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{:p}", *self) - } -} - -impl CudaDisplay for *const c_void { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write!(writer, "{:p}", *self) - } -} - -impl CudaDisplay for *const i8 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - if self.is_null() { - writer.write_all(b"NULL") - } else { - write!( - writer, - "\"{}\"", - unsafe { CStr::from_ptr(*self as _) }.to_string_lossy() - ) - } - } -} - -impl CudaDisplay for *mut cuda_types::FILE { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - if self.is_null() { - writer.write_all(b"NULL") - } else { - write!(writer, "{:p}", *self) - } - } -} - -#[repr(C)] -#[derive(Copy, Clone)] -struct Luid { - low_part: u32, - high_part: u32, -} - -impl CudaDisplay for *mut i8 { - fn write( - &self, - fn_name: &'static str, - index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - if fn_name == "cuDeviceGetLuid" && index == 0 { - let luid_ptr = *self as *mut Luid; - let luid = unsafe { *luid_ptr }; - write!(writer, "{{{:08X}-{:08X}}}", luid.low_part, luid.high_part) - } else { - write!( - writer, - "\"{}\"", - unsafe { CStr::from_ptr(*self as _) }.to_string_lossy() - ) - } - } -} - -impl CudaDisplay for CUstreamBatchMemOpParams { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - unsafe { - match self.operation { - // The below is not a typo, `WAIT_VALUE` and `WRITE_VALUE` are - // distinct operations with nominally distinct union variants, but - // in reality they are structurally different, so we take a little - // shortcut here - CUstreamBatchMemOpType::CU_STREAM_MEM_OP_WAIT_VALUE_32 - | CUstreamBatchMemOpType::CU_STREAM_MEM_OP_WRITE_VALUE_32 => { - write_wait_value(&self.waitValue, writer, false) - } - CUstreamBatchMemOpType::CU_STREAM_MEM_OP_WAIT_VALUE_64 - | CUstreamBatchMemOpType::CU_STREAM_MEM_OP_WRITE_VALUE_64 => { - write_wait_value(&self.waitValue, writer, true) - } - CUstreamBatchMemOpType::CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES => { - CudaDisplay::write(&self.flushRemoteWrites, "", 0, writer) - } - _ => { - writer.write_all(b"{ operation: ")?; - CudaDisplay::write(&self.operation, "", 0, writer)?; - writer.write_all(b", ... 
}") - } - } - } - } -} - -impl CudaDisplay for CUcheckpointRestoreArgs_st { - fn write( - &self, - fn_name: &'static str, - index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - CudaDisplay::write(&self.reserved, fn_name, index, writer) - } -} - -impl CudaDisplay for CUcheckpointUnlockArgs_st { - fn write( - &self, - fn_name: &'static str, - index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - CudaDisplay::write(&self.reserved, fn_name, index, writer) - } -} - -impl CudaDisplay for CUcheckpointCheckpointArgs_st { - fn write( - &self, - fn_name: &'static str, - index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - CudaDisplay::write(&self.reserved, fn_name, index, writer) - } -} - -impl CudaDisplay for CUmemcpy3DOperand_st { - fn write( - &self, - fn_name: &'static str, - index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - writer.write_all(b"{ type_: ")?; - CudaDisplay::write(&self.type_, "", 0, writer)?; - writer.write_all(b", op: ")?; - match self.type_ { - CUmemcpy3DOperandType::CU_MEMCPY_OPERAND_TYPE_ARRAY => { - CudaDisplay::write(unsafe { &self.op.array }, fn_name, index, writer)?; - } - CUmemcpy3DOperandType::CU_MEMCPY_OPERAND_TYPE_POINTER => { - CudaDisplay::write(unsafe { &self.op.ptr }, fn_name, index, writer)?; - } - _ => { - const CU_MEMCPY_3D_OP_SIZE: usize = mem::size_of::(); - CudaDisplay::write( - &unsafe { mem::transmute::<_, [u8; CU_MEMCPY_3D_OP_SIZE]>(self.op) }, - fn_name, - index, - writer, - )?; - } - } - writer.write_all(b" }") - } -} - -pub fn write_wait_value( - this: &CUstreamBatchMemOpParams_union_CUstreamMemOpWaitValueParams_st, - writer: &mut (impl std::io::Write + ?Sized), - is_64_bit: bool, -) -> std::io::Result<()> { - writer.write_all(b"{ operation: ")?; - CudaDisplay::write(&this.operation, "", 0, writer)?; - writer.write_all(b", address: ")?; - CudaDisplay::write(&this.address, "", 0, writer)?; - write_wait_value_32_or_64(&this.__bindgen_anon_1, writer, is_64_bit)?; - writer.write_all(b", flags: ")?; - CudaDisplay::write(&this.flags, "", 0, writer)?; - writer.write_all(b", alias: ")?; - CudaDisplay::write(&this.alias, "", 0, writer)?; - writer.write_all(b" }") -} - -pub fn write_wait_value_32_or_64( - this: &CUstreamBatchMemOpParams_union_CUstreamMemOpWaitValueParams_st__bindgen_ty_1, - writer: &mut (impl std::io::Write + ?Sized), - is_64_bit: bool, -) -> std::io::Result<()> { - if is_64_bit { - writer.write_all(b", value64: ")?; - CudaDisplay::write(unsafe { &this.value64 }, "", 0, writer) - } else { - writer.write_all(b", value: ")?; - CudaDisplay::write(unsafe { &this.value }, "", 0, writer) - } -} - -impl CudaDisplay for CUDA_RESOURCE_DESC_st { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - writer.write_all(b"{ resType: ")?; - CudaDisplay::write(&self.resType, "", 0, writer)?; - match self.resType { - CUresourcetype::CU_RESOURCE_TYPE_ARRAY => { - writer.write_all(b", res: ")?; - CudaDisplay::write(unsafe { &self.res.array }, "", 0, writer)?; - writer.write_all(b", flags: ")?; - CudaDisplay::write(&self.flags, "", 0, writer)?; - writer.write_all(b" }") - } - CUresourcetype::CU_RESOURCE_TYPE_MIPMAPPED_ARRAY => { - writer.write_all(b", res: ")?; - CudaDisplay::write(unsafe { &self.res.mipmap }, "", 0, writer)?; - writer.write_all(b", flags: ")?; - CudaDisplay::write(&self.flags, "", 0, writer)?; - 
writer.write_all(b" }") - } - CUresourcetype::CU_RESOURCE_TYPE_LINEAR => { - writer.write_all(b", res: ")?; - CudaDisplay::write(unsafe { &self.res.linear }, "", 0, writer)?; - writer.write_all(b", flags: ")?; - CudaDisplay::write(&self.flags, "", 0, writer)?; - writer.write_all(b" }") - } - CUresourcetype::CU_RESOURCE_TYPE_PITCH2D => { - writer.write_all(b", res: ")?; - CudaDisplay::write(unsafe { &self.res.pitch2D }, "", 0, writer)?; - writer.write_all(b", flags: ")?; - CudaDisplay::write(&self.flags, "", 0, writer)?; - writer.write_all(b" }") - } - _ => { - writer.write_all(b", flags: ")?; - CudaDisplay::write(&self.flags, "", 0, writer)?; - writer.write_all(b", ... }") - } - } - } -} - -impl crate::CudaDisplay for cuda_types::cuda::CUlaunchConfig_st { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - writer.write_all(concat!("{ ", stringify!(gridDimX), ": ").as_bytes())?; - crate::CudaDisplay::write(&self.gridDimX, "", 0, writer)?; - writer.write_all(concat!(", ", stringify!(gridDimY), ": ").as_bytes())?; - crate::CudaDisplay::write(&self.gridDimY, "", 0, writer)?; - writer.write_all(concat!(", ", stringify!(gridDimZ), ": ").as_bytes())?; - crate::CudaDisplay::write(&self.gridDimZ, "", 0, writer)?; - writer.write_all(concat!(", ", stringify!(blockDimX), ": ").as_bytes())?; - crate::CudaDisplay::write(&self.blockDimX, "", 0, writer)?; - writer.write_all(concat!(", ", stringify!(blockDimY), ": ").as_bytes())?; - crate::CudaDisplay::write(&self.blockDimY, "", 0, writer)?; - writer.write_all(concat!(", ", stringify!(blockDimZ), ": ").as_bytes())?; - crate::CudaDisplay::write(&self.blockDimZ, "", 0, writer)?; - writer.write_all(concat!(", ", stringify!(sharedMemBytes), ": ").as_bytes())?; - crate::CudaDisplay::write(&self.sharedMemBytes, "", 0, writer)?; - writer.write_all(concat!(", ", stringify!(hStream), ": ").as_bytes())?; - crate::CudaDisplay::write(&self.hStream, "", 0, writer)?; - writer.write_all(concat!(", ", stringify!(numAttrs), ": ").as_bytes())?; - crate::CudaDisplay::write(&self.numAttrs, "", 0, writer)?; - writer.write_all(concat!(", ", stringify!(attrs), ": ").as_bytes())?; - writer.write_all(b"[")?; - for i in 0..self.numAttrs { - if i != 0 { - writer.write_all(b", ")?; - } - crate::CudaDisplay::write(&unsafe { *self.attrs.add(i as usize) }, "", 0, writer)?; - } - writer.write_all(b"]")?; - writer.write_all(b" }") - } -} - -impl CudaDisplay for CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - writer.write_all(b"{ type: ")?; - CudaDisplay::write(&self.type_, "", 0, writer)?; - match self.type_ { - CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD => { - writer.write_all(b", handle: ")?; - CudaDisplay::write(unsafe { &self.handle.fd }, "", 0, writer)?; - } - CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 - | CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP - | CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE - | CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE => { - write_win32_handle(unsafe { self.handle.win32 }, writer)?; - } - CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT - | CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT => { - writer.write_all(b", handle: ")?; - 
CudaDisplay::write(unsafe { &self.handle.win32.handle }, "", 0, writer)?; - } - CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF => { - writer.write_all(b", handle: ")?; - CudaDisplay::write(unsafe { &self.handle.nvSciBufObject }, "", 0, writer)?; - } - _ => { - writer.write_all(b", size: ")?; - CudaDisplay::write(&self.size, "", 0, writer)?; - writer.write_all(b", flags: ")?; - CudaDisplay::write(&self.flags, "", 0, writer)?; - return writer.write_all(b", ... }"); - } - } - writer.write_all(b", size: ")?; - CudaDisplay::write(&self.size, "", 0, writer)?; - writer.write_all(b", flags: ")?; - CudaDisplay::write(&self.flags, "", 0, writer)?; - writer.write_all(b" }") - } -} - -pub fn write_win32_handle( - win32: CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st__bindgen_ty_1__bindgen_ty_1, - writer: &mut (impl std::io::Write + ?Sized), -) -> std::io::Result<()> { - if win32.handle != ptr::null_mut() { - writer.write_all(b", handle: ")?; - CudaDisplay::write(&win32.handle, "", 0, writer)?; - } - if win32.name != ptr::null_mut() { - let name_ptr = win32.name as *const u16; - let mut strlen = 0usize; - while unsafe { *name_ptr.add(strlen) } != 0 { - strlen += 1; - } - let text = String::from_utf16_lossy(unsafe { slice::from_raw_parts(name_ptr, strlen) }); - write!(writer, ", name: \"{}\"", text)?; - } - Ok(()) -} - -impl CudaDisplay for CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - writer.write_all(b"{ type: ")?; - CudaDisplay::write(&self.type_, "", 0, writer)?; - match self.type_ { - CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD => { - writer.write_all(b", handle: ")?; - CudaDisplay::write(unsafe { &self.handle.fd }, "", 0,writer)?; - } - CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 - | CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE - | CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE - | CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX - | CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT => { - write_win32_handle(unsafe { mem::transmute(self.handle.win32) }, writer)?; - } - CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT => { - writer.write_all(b", handle: ")?; - CudaDisplay::write(unsafe { &self.handle.win32.handle }, "", 0,writer)?; - } - CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC => { - writer.write_all(b", handle: ")?; - CudaDisplay::write(unsafe { &self.handle.nvSciSyncObj }, "", 0,writer)?; - } - _ => { - writer.write_all(b", flags: ")?; - CudaDisplay::write(&self.flags, "", 0,writer)?; - return writer.write_all(b", ... 
}") - } - } - writer.write_all(b", flags: ")?; - CudaDisplay::write(&self.flags, "", 0, writer)?; - writer.write_all(b" }") - } -} - -impl CudaDisplay for CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS_st__bindgen_ty_1__bindgen_ty_2 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - writer.write_all(b"{ fence: ")?; - CudaDisplay::write(&unsafe { self.fence }, "", 0, writer)?; - writer.write_all(b" }") - } -} - -impl CudaDisplay for CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_st__bindgen_ty_1__bindgen_ty_2 { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - writer.write_all(b"{ fence: ")?; - CudaDisplay::write(&unsafe { self.fence }, "", 0, writer)?; - writer.write_all(b" }") - } -} - -impl CudaDisplay for CUgraphNodeParams_st { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - _writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - todo!() - } -} - -impl CudaDisplay for CUeglFrame_st { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - _writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - todo!() - } -} - -impl CudaDisplay for CUdevResource_st { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - _writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - todo!() - } -} - -impl CudaDisplay for CUlaunchAttribute_st { - fn write( - &self, - fn_name: &'static str, - index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - write_launch_attribute(writer, fn_name, index, self.id, self.value) - } -} - -impl CudaDisplay for *mut T { - fn write( - &self, - fn_name: &'static str, - index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - CudaDisplay::write(&self.cast_const(), fn_name, index, writer) - } -} - -impl CudaDisplay for *const T { - fn write( - &self, - fn_name: &'static str, - index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - if *self == ptr::null() { - writer.write_all(b"NULL") - } else { - if fn_name.len() > 2 - && fn_name.starts_with("cu") - && fn_name.as_bytes()[2].is_ascii_lowercase() - && (TypeId::of::() == TypeId::of::() - || TypeId::of::() == TypeId::of::()) - { - CudaDisplay::write(&self.cast::(), fn_name, index, writer) - } else { - let this: &T = unsafe { &**self }; - this.write(fn_name, index, writer) - } - } - } -} - -impl CudaDisplay for [T; N] { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - writer.write_all(b"[")?; - for i in 0..N { - CudaDisplay::write(&self[i], "", 0, writer)?; - if i != N - 1 { - writer.write_all(b", ")?; - } - } - writer.write_all(b"]") - } -} - -impl CudaDisplay for [T] { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - writer.write_all(b"[")?; - for i in 0..self.len() { - CudaDisplay::write(&self[i], "", 0, writer)?; - if i != self.len() - 1 { - writer.write_all(b", ")?; - } - } - writer.write_all(b"]") - } -} - -impl CudaDisplay for CUarrayMapInfo_st { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - _writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - todo!() - } -} - -impl CudaDisplay for CUexecAffinityParam_st { - fn write( - &self, - 
_fn_name: &'static str, - _index: usize, - _writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - todo!() - } -} - -impl CudaDisplay for *mut cuda_types::cudnn9::cudnnRuntimeTag_t { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - _writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - todo!() - } -} - -#[allow(non_snake_case)] -pub fn write_cuGraphKernelNodeGetAttribute( - writer: &mut (impl std::io::Write + ?Sized), - hNode: CUgraphNode, - attr: CUkernelNodeAttrID, - value_out: *mut CUkernelNodeAttrValue, -) -> std::io::Result<()> { - writer.write_all(b"(hNode: ")?; - CudaDisplay::write(&hNode, "cuGraphKernelNodeGetAttribute", 0, writer)?; - writer.write_all(b", attr: ")?; - CudaDisplay::write(&attr, "cuGraphKernelNodeGetAttribute", 1, writer)?; - writer.write_all(b", value_out: ")?; - write_launch_attribute(writer, "cuGraphKernelNodeGetAttribute", 2, attr, unsafe { - *value_out - })?; - writer.write_all(b") ") -} - -#[allow(non_snake_case)] -pub fn write_cuGraphKernelNodeSetAttribute( - writer: &mut (impl std::io::Write + ?Sized), - hNode: CUgraphNode, - attr: CUkernelNodeAttrID, - value_out: *const CUkernelNodeAttrValue, -) -> std::io::Result<()> { - write_cuGraphKernelNodeGetAttribute(writer, hNode, attr, value_out as *mut _) -} - -#[allow(non_snake_case)] -pub fn write_cuStreamGetAttribute( - writer: &mut (impl std::io::Write + ?Sized), - hStream: CUstream, - attr: CUstreamAttrID, - value_out: *mut CUstreamAttrValue, -) -> std::io::Result<()> { - writer.write_all(b"(hStream: ")?; - CudaDisplay::write(&hStream, "cuStreamGetAttribute", 0, writer)?; - writer.write_all(b", attr: ")?; - CudaDisplay::write(&attr, "cuStreamGetAttribute", 1, writer)?; - writer.write_all(b", value_out: ")?; - write_launch_attribute(writer, "cuStreamGetAttribute", 2, attr, unsafe { - *value_out - })?; - writer.write_all(b") ") -} - -fn write_launch_attribute( - writer: &mut (impl std::io::Write + ?Sized), - fn_name: &'static str, - index: usize, - attribute: CUlaunchAttributeID, - value: CUlaunchAttributeValue, -) -> std::io::Result<()> { - match attribute { - CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW => { - writer.write_all(b"CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW = ")?; - CudaDisplay::write(unsafe { &value.accessPolicyWindow }, fn_name, index, writer) - } - CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_COOPERATIVE => { - writer.write_all(b"CU_LAUNCH_ATTRIBUTE_COOPERATIVE = ")?; - CudaDisplay::write(unsafe { &value.cooperative }, fn_name, index, writer) - } - CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY => { - writer.write_all(b"CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY = ")?; - CudaDisplay::write(unsafe { &value.syncPolicy }, fn_name, index, writer) - } - CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION => { - writer.write_all(b"CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION = ")?; - CudaDisplay::write(unsafe { &value.clusterDim }, fn_name, index, writer) - } - CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE => { - writer.write_all(b"CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE = ")?; - CudaDisplay::write( - unsafe { &value.clusterSchedulingPolicyPreference }, - fn_name, - index, - writer, - ) - } - CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION => { - writer.write_all(b"CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION = ")?; - CudaDisplay::write( - unsafe { &value.programmaticStreamSerializationAllowed }, - fn_name, - index, - writer, - 
) - } - CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT => { - writer.write_all(b"CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT = ")?; - CudaDisplay::write(unsafe { &value.programmaticEvent }, fn_name, index, writer) - } - CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_PRIORITY => { - writer.write_all(b"CU_LAUNCH_ATTRIBUTE_PRIORITY = ")?; - CudaDisplay::write(unsafe { &value.priority }, fn_name, index, writer) - } - CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP => { - writer.write_all(b"CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP = ")?; - CudaDisplay::write(unsafe { &value.memSyncDomainMap }, fn_name, index, writer) - } - CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN => { - writer.write_all(b"CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN = ")?; - CudaDisplay::write(unsafe { &value.memSyncDomain }, fn_name, index, writer) - } - CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_LAUNCH_COMPLETION_EVENT => { - writer.write_all(b"CU_LAUNCH_ATTRIBUTE_LAUNCH_COMPLETION_EVENT = ")?; - CudaDisplay::write( - unsafe { &value.launchCompletionEvent }, - fn_name, - index, - writer, - ) - } - CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE => { - writer.write_all(b"CU_LAUNCH_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE = ")?; - CudaDisplay::write( - unsafe { &value.deviceUpdatableKernelNode }, - fn_name, - index, - writer, - ) - } - _ => writer.write_all(b""), - } -} - -#[allow(non_snake_case)] -pub fn write_cuStreamGetAttribute_ptsz( - writer: &mut (impl std::io::Write + ?Sized), - hStream: CUstream, - attr: CUstreamAttrID, - value_out: *mut CUstreamAttrValue, -) -> std::io::Result<()> { - write_cuStreamGetAttribute(writer, hStream, attr, value_out) -} - -#[allow(non_snake_case)] -pub fn write_cuStreamSetAttribute( - writer: &mut (impl std::io::Write + ?Sized), - hStream: CUstream, - attr: CUstreamAttrID, - value_out: *const CUstreamAttrValue, -) -> std::io::Result<()> { - write_cuStreamGetAttribute(writer, hStream, attr, value_out as *mut _) -} - -#[allow(non_snake_case)] -pub fn write_cuStreamSetAttribute_ptsz( - writer: &mut (impl std::io::Write + ?Sized), - hStream: CUstream, - attr: CUstreamAttrID, - value_out: *const CUstreamAttrValue, -) -> std::io::Result<()> { - write_cuStreamSetAttribute(writer, hStream, attr, value_out) -} - -#[allow(non_snake_case)] -pub fn write_cuGLGetDevices( - _writer: &mut (impl std::io::Write + ?Sized), - _pCudaDeviceCount: *mut ::std::os::raw::c_uint, - _pCudaDevices: *mut CUdevice, - _cudaDeviceCount: ::std::os::raw::c_uint, - _deviceList: CUGLDeviceList, -) -> std::io::Result<()> { - todo!() -} - -#[allow(non_snake_case)] -pub fn write_cuGLGetDevices_v2( - _writer: &mut (impl std::io::Write + ?Sized), - _pCudaDeviceCount: *mut ::std::os::raw::c_uint, - _pCudaDevices: *mut CUdevice, - _cudaDeviceCount: ::std::os::raw::c_uint, - _deviceList: CUGLDeviceList, -) -> std::io::Result<()> { - todo!() -} - -#[allow(non_snake_case)] -pub fn write_cudnnBackendGetAttribute( - writer: &mut (impl std::io::Write + ?Sized), - descriptor: cuda_types::cudnn9::cudnnBackendDescriptor_t, - attributeName: cuda_types::cudnn9::cudnnBackendAttributeName_t, - attributeType: cuda_types::cudnn9::cudnnBackendAttributeType_t, - requestedElementCount: i64, - elementCount: *mut i64, - arrayOfElements: *mut ::core::ffi::c_void, -) -> std::io::Result<()> { - let mut arg_idx = 0usize; - writer.write_all(b"(")?; - writer.write_all(concat!(stringify!(descriptor), ": ").as_bytes())?; - crate::CudaDisplay::write(&descriptor, "cudnnBackendGetAttribute", arg_idx, writer)?; - arg_idx += 1; 
- writer.write_all(b", ")?; - writer.write_all(concat!(stringify!(attributeName), ": ").as_bytes())?; - crate::CudaDisplay::write(&attributeName, "cudnnBackendGetAttribute", arg_idx, writer)?; - arg_idx += 1; - writer.write_all(b", ")?; - writer.write_all(concat!(stringify!(attributeType), ": ").as_bytes())?; - crate::CudaDisplay::write(&attributeType, "cudnnBackendGetAttribute", arg_idx, writer)?; - arg_idx += 1; - writer.write_all(b", ")?; - writer.write_all(concat!(stringify!(requestedElementCount), ": ").as_bytes())?; - crate::CudaDisplay::write( - &requestedElementCount, - "cudnnBackendGetAttribute", - arg_idx, - writer, - )?; - arg_idx += 1; - writer.write_all(b", ")?; - writer.write_all(concat!(stringify!(elementCount), ": ").as_bytes())?; - crate::CudaDisplay::write(&elementCount, "cudnnBackendGetAttribute", arg_idx, writer)?; - writer.write_all(b", ")?; - writer.write_all(concat!(stringify!(arrayOfElements), ": ").as_bytes())?; - cudnn9_print_elements( - writer, - attributeType, - unsafe { elementCount.as_ref() } - .copied() - .unwrap_or(requestedElementCount), - arrayOfElements, - )?; - writer.write_all(b")") -} - -#[allow(non_snake_case)] -pub fn write_cudnnBackendSetAttribute( - writer: &mut (impl std::io::Write + ?Sized), - descriptor: cuda_types::cudnn9::cudnnBackendDescriptor_t, - attributeName: cuda_types::cudnn9::cudnnBackendAttributeName_t, - attributeType: cuda_types::cudnn9::cudnnBackendAttributeType_t, - elementCount: i64, - arrayOfElements: *const ::core::ffi::c_void, -) -> std::io::Result<()> { - let mut arg_idx = 0usize; - writer.write_all(b"(")?; - writer.write_all(concat!(stringify!(descriptor), ": ").as_bytes())?; - crate::CudaDisplay::write(&descriptor, "cudnnBackendSetAttribute", arg_idx, writer)?; - arg_idx += 1; - writer.write_all(b", ")?; - writer.write_all(concat!(stringify!(attributeName), ": ").as_bytes())?; - crate::CudaDisplay::write(&attributeName, "cudnnBackendSetAttribute", arg_idx, writer)?; - arg_idx += 1; - writer.write_all(b", ")?; - writer.write_all(concat!(stringify!(attributeType), ": ").as_bytes())?; - crate::CudaDisplay::write(&attributeType, "cudnnBackendSetAttribute", arg_idx, writer)?; - arg_idx += 1; - writer.write_all(b", ")?; - writer.write_all(concat!(stringify!(elementCount), ": ").as_bytes())?; - crate::CudaDisplay::write(&elementCount, "cudnnBackendSetAttribute", arg_idx, writer)?; - writer.write_all(b", ")?; - writer.write_all(concat!(stringify!(arrayOfElements), ": ").as_bytes())?; - cudnn9_print_elements(writer, attributeType, elementCount, arrayOfElements)?; - writer.write_all(b")") -} - -fn cudnn9_print_elements( - writer: &mut (impl std::io::Write + ?Sized), - type_: cuda_types::cudnn9::cudnnBackendAttributeType_t, - element_count: i64, - array_of_elements: *const ::core::ffi::c_void, -) -> std::io::Result<()> { - fn print_typed( - writer: &mut (impl std::io::Write + ?Sized), - element_count: i64, - array_of_elements: *const ::core::ffi::c_void, - ) -> std::io::Result<()> { - if array_of_elements.is_null() { - return writer.write_all(b"NULL"); - } - let elements = - unsafe { slice::from_raw_parts(array_of_elements as *const T, element_count as usize) }; - CudaDisplay::write(elements, "", 0, writer) - } - match type_ { - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_HANDLE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_DATA_TYPE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - 
cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_BOOLEAN => { - print_typed::(writer, element_count, array_of_elements) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_INT64 => { - print_typed::(writer, element_count, array_of_elements) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT => { - print_typed::(writer, element_count, array_of_elements) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE => { - print_typed::(writer, element_count, array_of_elements) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_VOID_PTR => { - print_typed::<*const c_void>(writer, element_count, array_of_elements) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_CONVOLUTION_MODE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_HEUR_MODE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_KNOB_TYPE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_NAN_PROPOGATION => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_NUMERICAL_NOTE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_LAYOUT_TYPE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_ATTRIB_NAME => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_POINTWISE_MODE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_GENSTATS_MODE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_BN_FINALIZE_STATS_MODE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_REDUCTION_OPERATOR_TYPE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_BEHAVIOR_NOTE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_TENSOR_REORDERING_MODE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_RESAMPLE_MODE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_PADDING_MODE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_INT32 => { - print_typed::(writer, element_count, array_of_elements) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_CHAR => { - CudaDisplay::write(&array_of_elements.cast::(), "", 0, writer) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_SIGNAL_MODE => { - 
print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_FRACTION => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_NORM_MODE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_NORM_FWD_PHASE => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_RNG_DISTRIBUTION => { - print_typed::( - writer, - element_count, - array_of_elements, - ) - } - _ => unimplemented!(), - } -} - -mod dark_api; -mod format_generated; -pub use format_generated::*; -mod format_generated_blas; -pub use format_generated_blas::*; -mod format_generated_blaslt; -pub use format_generated_blaslt::*; -mod format_generated_blaslt_internal; -pub use format_generated_blaslt_internal::*; -mod format_generated_dnn9; -pub use format_generated_dnn9::*; -mod format_generated_fft; -pub use format_generated_fft::*; -mod format_generated_sparse; -pub use format_generated_sparse::*; +use cuda_types::cuda::*; +use std::{ + any::TypeId, + ffi::{c_void, CStr}, + fmt::LowerHex, + mem, ptr, slice, +}; + +pub trait CudaDisplay { + fn write( + &self, + fn_name: &'static str, + index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()>; +} + +impl CudaDisplay for () { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "()") + } +} + +impl CudaDisplay for CUuuid { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + let guid = self.bytes; + let uuid = uuid::Uuid::from_bytes(guid); + let braced = uuid.as_braced(); + write!(writer, "{braced:#X}") + } +} + +impl CudaDisplay for CUdeviceptr_v1 { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{:p}", self.0 as usize as *const ()) + } +} + +impl CudaDisplay for bool { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{}", *self) + } +} + +impl CudaDisplay for u8 { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{}", *self) + } +} + +impl CudaDisplay for u16 { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{}", *self) + } +} + +impl CudaDisplay for i32 { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{}", *self) + } +} + +impl CudaDisplay for u32 { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{}", *self) + } +} + +impl CudaDisplay for i64 { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{}", *self) + } +} + +impl CudaDisplay for u64 { + fn write( + &self, + _fn_name: &'static str, + 
_index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{}", *self) + } +} + +impl CudaDisplay for usize { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{}", *self) + } +} + +impl CudaDisplay for f32 { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{}", *self) + } +} + +impl CudaDisplay for f64 { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{}", *self) + } +} + +// user by Dark API +impl CudaDisplay + for Option< + extern "system" fn( + cuda_types::cuda::CUcontext, + *mut std::ffi::c_void, + *mut std::ffi::c_void, + ), + > +{ + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + if let Some(fn_ptr) = self { + write!(writer, "{:p}", *fn_ptr) + } else { + writer.write_all(b"NULL") + } + } +} + +impl CudaDisplay for Option { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + if let Some(fn_ptr) = self { + write!(writer, "{:p}", *fn_ptr) + } else { + writer.write_all(b"NULL") + } + } +} + +impl CudaDisplay for Option { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + if let Some(fn_ptr) = self { + write!(writer, "{:p}", *fn_ptr) + } else { + writer.write_all(b"NULL") + } + } +} + +pub fn write_handle( + this: &[T; 64], + writer: &mut (impl std::io::Write + ?Sized), +) -> std::io::Result<()> { + writer.write_all(b"0x")?; + for i in (0..64).rev() { + write!(writer, "{:02x}", this[i])?; + } + Ok(()) +} + +impl CudaDisplay for CUipcMemHandle { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write_handle(&self.reserved, writer) + } +} + +impl CudaDisplay for CUipcEventHandle { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write_handle(&self.reserved, writer) + } +} + +impl CudaDisplay for CUmemPoolPtrExportData_v1 { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write_handle(&self.reserved, writer) + } +} + +impl CudaDisplay for *mut c_void { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{:p}", *self) + } +} + +impl CudaDisplay for *const c_void { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write!(writer, "{:p}", *self) + } +} + +impl CudaDisplay for *const i8 { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + if self.is_null() { + writer.write_all(b"NULL") + } else { + write!( + writer, + "\"{}\"", + unsafe { CStr::from_ptr(*self as _) }.to_string_lossy() + ) + } + } +} + +impl CudaDisplay for *mut cuda_types::FILE { + fn write( + &self, + _fn_name: &'static str, + _index: 
usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + if self.is_null() { + writer.write_all(b"NULL") + } else { + write!(writer, "{:p}", *self) + } + } +} + +#[repr(C)] +#[derive(Copy, Clone)] +struct Luid { + low_part: u32, + high_part: u32, +} + +impl CudaDisplay for *mut i8 { + fn write( + &self, + fn_name: &'static str, + index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + if fn_name == "cuDeviceGetLuid" && index == 0 { + let luid_ptr = *self as *mut Luid; + let luid = unsafe { *luid_ptr }; + write!(writer, "{{{:08X}-{:08X}}}", luid.low_part, luid.high_part) + } else { + write!( + writer, + "\"{}\"", + unsafe { CStr::from_ptr(*self as _) }.to_string_lossy() + ) + } + } +} + +impl CudaDisplay for CUstreamBatchMemOpParams { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + unsafe { + match self.operation { + // The below is not a typo, `WAIT_VALUE` and `WRITE_VALUE` are + // distinct operations with nominally distinct union variants, but + // in reality they are structurally different, so we take a little + // shortcut here + CUstreamBatchMemOpType::CU_STREAM_MEM_OP_WAIT_VALUE_32 + | CUstreamBatchMemOpType::CU_STREAM_MEM_OP_WRITE_VALUE_32 => { + write_wait_value(&self.waitValue, writer, false) + } + CUstreamBatchMemOpType::CU_STREAM_MEM_OP_WAIT_VALUE_64 + | CUstreamBatchMemOpType::CU_STREAM_MEM_OP_WRITE_VALUE_64 => { + write_wait_value(&self.waitValue, writer, true) + } + CUstreamBatchMemOpType::CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES => { + CudaDisplay::write(&self.flushRemoteWrites, "", 0, writer) + } + _ => { + writer.write_all(b"{ operation: ")?; + CudaDisplay::write(&self.operation, "", 0, writer)?; + writer.write_all(b", ... 
}") + } + } + } + } +} + +impl CudaDisplay for CUcheckpointRestoreArgs_st { + fn write( + &self, + fn_name: &'static str, + index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + CudaDisplay::write(&self.reserved, fn_name, index, writer) + } +} + +impl CudaDisplay for CUcheckpointUnlockArgs_st { + fn write( + &self, + fn_name: &'static str, + index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + CudaDisplay::write(&self.reserved, fn_name, index, writer) + } +} + +impl CudaDisplay for CUcheckpointCheckpointArgs_st { + fn write( + &self, + fn_name: &'static str, + index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + CudaDisplay::write(&self.reserved, fn_name, index, writer) + } +} + +impl CudaDisplay for CUmemcpy3DOperand_st { + fn write( + &self, + fn_name: &'static str, + index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + writer.write_all(b"{ type_: ")?; + CudaDisplay::write(&self.type_, "", 0, writer)?; + writer.write_all(b", op: ")?; + match self.type_ { + CUmemcpy3DOperandType::CU_MEMCPY_OPERAND_TYPE_ARRAY => { + CudaDisplay::write(unsafe { &self.op.array }, fn_name, index, writer)?; + } + CUmemcpy3DOperandType::CU_MEMCPY_OPERAND_TYPE_POINTER => { + CudaDisplay::write(unsafe { &self.op.ptr }, fn_name, index, writer)?; + } + _ => { + const CU_MEMCPY_3D_OP_SIZE: usize = + mem::size_of::(); + CudaDisplay::write( + &unsafe { mem::transmute::<_, [u8; CU_MEMCPY_3D_OP_SIZE]>(self.op) }, + fn_name, + index, + writer, + )?; + } + } + writer.write_all(b" }") + } +} + +pub fn write_wait_value( + this: &CUstreamBatchMemOpParams_union_CUstreamMemOpWaitValueParams_st, + writer: &mut (impl std::io::Write + ?Sized), + is_64_bit: bool, +) -> std::io::Result<()> { + writer.write_all(b"{ operation: ")?; + CudaDisplay::write(&this.operation, "", 0, writer)?; + writer.write_all(b", address: ")?; + CudaDisplay::write(&this.address, "", 0, writer)?; + write_wait_value_32_or_64(&this.__bindgen_anon_1, writer, is_64_bit)?; + writer.write_all(b", flags: ")?; + CudaDisplay::write(&this.flags, "", 0, writer)?; + writer.write_all(b", alias: ")?; + CudaDisplay::write(&this.alias, "", 0, writer)?; + writer.write_all(b" }") +} + +pub fn write_wait_value_32_or_64( + this: &CUstreamBatchMemOpParams_union_CUstreamMemOpWaitValueParams_st__bindgen_ty_1, + writer: &mut (impl std::io::Write + ?Sized), + is_64_bit: bool, +) -> std::io::Result<()> { + if is_64_bit { + writer.write_all(b", value64: ")?; + CudaDisplay::write(unsafe { &this.value64 }, "", 0, writer) + } else { + writer.write_all(b", value: ")?; + CudaDisplay::write(unsafe { &this.value }, "", 0, writer) + } +} + +impl CudaDisplay for CUDA_RESOURCE_DESC_st { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + writer.write_all(b"{ resType: ")?; + CudaDisplay::write(&self.resType, "", 0, writer)?; + match self.resType { + CUresourcetype::CU_RESOURCE_TYPE_ARRAY => { + writer.write_all(b", res: ")?; + CudaDisplay::write(unsafe { &self.res.array }, "", 0, writer)?; + writer.write_all(b", flags: ")?; + CudaDisplay::write(&self.flags, "", 0, writer)?; + writer.write_all(b" }") + } + CUresourcetype::CU_RESOURCE_TYPE_MIPMAPPED_ARRAY => { + writer.write_all(b", res: ")?; + CudaDisplay::write(unsafe { &self.res.mipmap }, "", 0, writer)?; + writer.write_all(b", flags: ")?; + CudaDisplay::write(&self.flags, "", 0, writer)?; + 
writer.write_all(b" }") + } + CUresourcetype::CU_RESOURCE_TYPE_LINEAR => { + writer.write_all(b", res: ")?; + CudaDisplay::write(unsafe { &self.res.linear }, "", 0, writer)?; + writer.write_all(b", flags: ")?; + CudaDisplay::write(&self.flags, "", 0, writer)?; + writer.write_all(b" }") + } + CUresourcetype::CU_RESOURCE_TYPE_PITCH2D => { + writer.write_all(b", res: ")?; + CudaDisplay::write(unsafe { &self.res.pitch2D }, "", 0, writer)?; + writer.write_all(b", flags: ")?; + CudaDisplay::write(&self.flags, "", 0, writer)?; + writer.write_all(b" }") + } + _ => { + writer.write_all(b", flags: ")?; + CudaDisplay::write(&self.flags, "", 0, writer)?; + writer.write_all(b", ... }") + } + } + } +} + +impl crate::CudaDisplay for cuda_types::cuda::CUlaunchConfig_st { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + writer.write_all(concat!("{ ", stringify!(gridDimX), ": ").as_bytes())?; + crate::CudaDisplay::write(&self.gridDimX, "", 0, writer)?; + writer.write_all(concat!(", ", stringify!(gridDimY), ": ").as_bytes())?; + crate::CudaDisplay::write(&self.gridDimY, "", 0, writer)?; + writer.write_all(concat!(", ", stringify!(gridDimZ), ": ").as_bytes())?; + crate::CudaDisplay::write(&self.gridDimZ, "", 0, writer)?; + writer.write_all(concat!(", ", stringify!(blockDimX), ": ").as_bytes())?; + crate::CudaDisplay::write(&self.blockDimX, "", 0, writer)?; + writer.write_all(concat!(", ", stringify!(blockDimY), ": ").as_bytes())?; + crate::CudaDisplay::write(&self.blockDimY, "", 0, writer)?; + writer.write_all(concat!(", ", stringify!(blockDimZ), ": ").as_bytes())?; + crate::CudaDisplay::write(&self.blockDimZ, "", 0, writer)?; + writer.write_all(concat!(", ", stringify!(sharedMemBytes), ": ").as_bytes())?; + crate::CudaDisplay::write(&self.sharedMemBytes, "", 0, writer)?; + writer.write_all(concat!(", ", stringify!(hStream), ": ").as_bytes())?; + crate::CudaDisplay::write(&self.hStream, "", 0, writer)?; + writer.write_all(concat!(", ", stringify!(numAttrs), ": ").as_bytes())?; + crate::CudaDisplay::write(&self.numAttrs, "", 0, writer)?; + writer.write_all(concat!(", ", stringify!(attrs), ": ").as_bytes())?; + writer.write_all(b"[")?; + for i in 0..self.numAttrs { + if i != 0 { + writer.write_all(b", ")?; + } + crate::CudaDisplay::write(&unsafe { *self.attrs.add(i as usize) }, "", 0, writer)?; + } + writer.write_all(b"]")?; + writer.write_all(b" }") + } +} + +impl CudaDisplay for CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + writer.write_all(b"{ type: ")?; + CudaDisplay::write(&self.type_, "", 0, writer)?; + match self.type_ { + CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD => { + writer.write_all(b", handle: ")?; + CudaDisplay::write(unsafe { &self.handle.fd }, "", 0, writer)?; + } + CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 + | CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP + | CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE + | CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE => { + write_win32_handle(unsafe { self.handle.win32 }, writer)?; + } + CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT + | CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT => { + writer.write_all(b", handle: ")?; + 
CudaDisplay::write(unsafe { &self.handle.win32.handle }, "", 0, writer)?; + } + CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF => { + writer.write_all(b", handle: ")?; + CudaDisplay::write(unsafe { &self.handle.nvSciBufObject }, "", 0, writer)?; + } + _ => { + writer.write_all(b", size: ")?; + CudaDisplay::write(&self.size, "", 0, writer)?; + writer.write_all(b", flags: ")?; + CudaDisplay::write(&self.flags, "", 0, writer)?; + return writer.write_all(b", ... }"); + } + } + writer.write_all(b", size: ")?; + CudaDisplay::write(&self.size, "", 0, writer)?; + writer.write_all(b", flags: ")?; + CudaDisplay::write(&self.flags, "", 0, writer)?; + writer.write_all(b" }") + } +} + +pub fn write_win32_handle( + win32: CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st__bindgen_ty_1__bindgen_ty_1, + writer: &mut (impl std::io::Write + ?Sized), +) -> std::io::Result<()> { + if win32.handle != ptr::null_mut() { + writer.write_all(b", handle: ")?; + CudaDisplay::write(&win32.handle, "", 0, writer)?; + } + if win32.name != ptr::null_mut() { + let name_ptr = win32.name as *const u16; + let mut strlen = 0usize; + while unsafe { *name_ptr.add(strlen) } != 0 { + strlen += 1; + } + let text = String::from_utf16_lossy(unsafe { slice::from_raw_parts(name_ptr, strlen) }); + write!(writer, ", name: \"{}\"", text)?; + } + Ok(()) +} + +impl CudaDisplay for CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + writer.write_all(b"{ type: ")?; + CudaDisplay::write(&self.type_, "", 0, writer)?; + match self.type_ { + CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD => { + writer.write_all(b", handle: ")?; + CudaDisplay::write(unsafe { &self.handle.fd }, "", 0,writer)?; + } + CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 + | CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE + | CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE + | CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX + | CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT => { + write_win32_handle(unsafe { mem::transmute(self.handle.win32) }, writer)?; + } + CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT => { + writer.write_all(b", handle: ")?; + CudaDisplay::write(unsafe { &self.handle.win32.handle }, "", 0,writer)?; + } + CUexternalSemaphoreHandleType::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC => { + writer.write_all(b", handle: ")?; + CudaDisplay::write(unsafe { &self.handle.nvSciSyncObj }, "", 0,writer)?; + } + _ => { + writer.write_all(b", flags: ")?; + CudaDisplay::write(&self.flags, "", 0,writer)?; + return writer.write_all(b", ... 
}") + } + } + writer.write_all(b", flags: ")?; + CudaDisplay::write(&self.flags, "", 0, writer)?; + writer.write_all(b" }") + } +} + +impl CudaDisplay for CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS_st__bindgen_ty_1__bindgen_ty_2 { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + writer.write_all(b"{ fence: ")?; + CudaDisplay::write(&unsafe { self.fence }, "", 0, writer)?; + writer.write_all(b" }") + } +} + +impl CudaDisplay for CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_st__bindgen_ty_1__bindgen_ty_2 { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + writer.write_all(b"{ fence: ")?; + CudaDisplay::write(&unsafe { self.fence }, "", 0, writer)?; + writer.write_all(b" }") + } +} + +impl CudaDisplay for CUgraphNodeParams_st { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + _writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + todo!() + } +} + +impl CudaDisplay for CUeglFrame_st { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + _writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + todo!() + } +} + +impl CudaDisplay for CUdevResource_st { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + _writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + todo!() + } +} + +impl CudaDisplay for CUlaunchAttribute_st { + fn write( + &self, + fn_name: &'static str, + index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + write_launch_attribute(writer, fn_name, index, self.id, self.value) + } +} + +impl CudaDisplay for *mut T { + fn write( + &self, + fn_name: &'static str, + index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + CudaDisplay::write(&self.cast_const(), fn_name, index, writer) + } +} + +impl CudaDisplay for *const T { + fn write( + &self, + fn_name: &'static str, + index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + if *self == ptr::null() { + writer.write_all(b"NULL") + } else { + if fn_name.len() > 2 + && fn_name.starts_with("cu") + && fn_name.as_bytes()[2].is_ascii_lowercase() + && (TypeId::of::() == TypeId::of::() + || TypeId::of::() == TypeId::of::()) + { + CudaDisplay::write(&self.cast::(), fn_name, index, writer) + } else { + let this: &T = unsafe { &**self }; + this.write(fn_name, index, writer) + } + } + } +} + +impl CudaDisplay for [T; N] { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + writer.write_all(b"[")?; + for i in 0..N { + CudaDisplay::write(&self[i], "", 0, writer)?; + if i != N - 1 { + writer.write_all(b", ")?; + } + } + writer.write_all(b"]") + } +} + +impl CudaDisplay for [T] { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + writer.write_all(b"[")?; + for i in 0..self.len() { + CudaDisplay::write(&self[i], "", 0, writer)?; + if i != self.len() - 1 { + writer.write_all(b", ")?; + } + } + writer.write_all(b"]") + } +} + +impl CudaDisplay for CUarrayMapInfo_st { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + _writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + todo!() + } +} + +impl CudaDisplay for CUexecAffinityParam_st { + fn write( + &self, + 
_fn_name: &'static str, + _index: usize, + _writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + todo!() + } +} + +impl CudaDisplay for *mut cuda_types::cudnn9::cudnnRuntimeTag_t { + fn write( + &self, + _fn_name: &'static str, + _index: usize, + _writer: &mut (impl std::io::Write + ?Sized), + ) -> std::io::Result<()> { + todo!() + } +} + +#[allow(non_snake_case)] +pub fn write_cuGraphKernelNodeGetAttribute( + writer: &mut (impl std::io::Write + ?Sized), + hNode: CUgraphNode, + attr: CUkernelNodeAttrID, + value_out: *mut CUkernelNodeAttrValue, +) -> std::io::Result<()> { + writer.write_all(b"(hNode: ")?; + CudaDisplay::write(&hNode, "cuGraphKernelNodeGetAttribute", 0, writer)?; + writer.write_all(b", attr: ")?; + CudaDisplay::write(&attr, "cuGraphKernelNodeGetAttribute", 1, writer)?; + writer.write_all(b", value_out: ")?; + write_launch_attribute(writer, "cuGraphKernelNodeGetAttribute", 2, attr, unsafe { + *value_out + })?; + writer.write_all(b") ") +} + +#[allow(non_snake_case)] +pub fn write_cuGraphKernelNodeSetAttribute( + writer: &mut (impl std::io::Write + ?Sized), + hNode: CUgraphNode, + attr: CUkernelNodeAttrID, + value_out: *const CUkernelNodeAttrValue, +) -> std::io::Result<()> { + write_cuGraphKernelNodeGetAttribute(writer, hNode, attr, value_out as *mut _) +} + +#[allow(non_snake_case)] +pub fn write_cuStreamGetAttribute( + writer: &mut (impl std::io::Write + ?Sized), + hStream: CUstream, + attr: CUstreamAttrID, + value_out: *mut CUstreamAttrValue, +) -> std::io::Result<()> { + writer.write_all(b"(hStream: ")?; + CudaDisplay::write(&hStream, "cuStreamGetAttribute", 0, writer)?; + writer.write_all(b", attr: ")?; + CudaDisplay::write(&attr, "cuStreamGetAttribute", 1, writer)?; + writer.write_all(b", value_out: ")?; + write_launch_attribute(writer, "cuStreamGetAttribute", 2, attr, unsafe { + *value_out + })?; + writer.write_all(b") ") +} + +fn write_launch_attribute( + writer: &mut (impl std::io::Write + ?Sized), + fn_name: &'static str, + index: usize, + attribute: CUlaunchAttributeID, + value: CUlaunchAttributeValue, +) -> std::io::Result<()> { + match attribute { + CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW => { + writer.write_all(b"CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW = ")?; + CudaDisplay::write(unsafe { &value.accessPolicyWindow }, fn_name, index, writer) + } + CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_COOPERATIVE => { + writer.write_all(b"CU_LAUNCH_ATTRIBUTE_COOPERATIVE = ")?; + CudaDisplay::write(unsafe { &value.cooperative }, fn_name, index, writer) + } + CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY => { + writer.write_all(b"CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY = ")?; + CudaDisplay::write(unsafe { &value.syncPolicy }, fn_name, index, writer) + } + CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION => { + writer.write_all(b"CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION = ")?; + CudaDisplay::write(unsafe { &value.clusterDim }, fn_name, index, writer) + } + CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE => { + writer.write_all(b"CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE = ")?; + CudaDisplay::write( + unsafe { &value.clusterSchedulingPolicyPreference }, + fn_name, + index, + writer, + ) + } + CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION => { + writer.write_all(b"CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION = ")?; + CudaDisplay::write( + unsafe { &value.programmaticStreamSerializationAllowed }, + fn_name, + index, + writer, + 
) + } + CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT => { + writer.write_all(b"CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT = ")?; + CudaDisplay::write(unsafe { &value.programmaticEvent }, fn_name, index, writer) + } + CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_PRIORITY => { + writer.write_all(b"CU_LAUNCH_ATTRIBUTE_PRIORITY = ")?; + CudaDisplay::write(unsafe { &value.priority }, fn_name, index, writer) + } + CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP => { + writer.write_all(b"CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP = ")?; + CudaDisplay::write(unsafe { &value.memSyncDomainMap }, fn_name, index, writer) + } + CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN => { + writer.write_all(b"CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN = ")?; + CudaDisplay::write(unsafe { &value.memSyncDomain }, fn_name, index, writer) + } + CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_LAUNCH_COMPLETION_EVENT => { + writer.write_all(b"CU_LAUNCH_ATTRIBUTE_LAUNCH_COMPLETION_EVENT = ")?; + CudaDisplay::write( + unsafe { &value.launchCompletionEvent }, + fn_name, + index, + writer, + ) + } + CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE => { + writer.write_all(b"CU_LAUNCH_ATTRIBUTE_DEVICE_UPDATABLE_KERNEL_NODE = ")?; + CudaDisplay::write( + unsafe { &value.deviceUpdatableKernelNode }, + fn_name, + index, + writer, + ) + } + _ => writer.write_all(b""), + } +} + +#[allow(non_snake_case)] +pub fn write_cuStreamGetAttribute_ptsz( + writer: &mut (impl std::io::Write + ?Sized), + hStream: CUstream, + attr: CUstreamAttrID, + value_out: *mut CUstreamAttrValue, +) -> std::io::Result<()> { + write_cuStreamGetAttribute(writer, hStream, attr, value_out) +} + +#[allow(non_snake_case)] +pub fn write_cuStreamSetAttribute( + writer: &mut (impl std::io::Write + ?Sized), + hStream: CUstream, + attr: CUstreamAttrID, + value_out: *const CUstreamAttrValue, +) -> std::io::Result<()> { + write_cuStreamGetAttribute(writer, hStream, attr, value_out as *mut _) +} + +#[allow(non_snake_case)] +pub fn write_cuStreamSetAttribute_ptsz( + writer: &mut (impl std::io::Write + ?Sized), + hStream: CUstream, + attr: CUstreamAttrID, + value_out: *const CUstreamAttrValue, +) -> std::io::Result<()> { + write_cuStreamSetAttribute(writer, hStream, attr, value_out) +} + +#[allow(non_snake_case)] +pub fn write_cuGLGetDevices( + _writer: &mut (impl std::io::Write + ?Sized), + _pCudaDeviceCount: *mut ::std::os::raw::c_uint, + _pCudaDevices: *mut CUdevice, + _cudaDeviceCount: ::std::os::raw::c_uint, + _deviceList: CUGLDeviceList, +) -> std::io::Result<()> { + todo!() +} + +#[allow(non_snake_case)] +pub fn write_cuGLGetDevices_v2( + _writer: &mut (impl std::io::Write + ?Sized), + _pCudaDeviceCount: *mut ::std::os::raw::c_uint, + _pCudaDevices: *mut CUdevice, + _cudaDeviceCount: ::std::os::raw::c_uint, + _deviceList: CUGLDeviceList, +) -> std::io::Result<()> { + todo!() +} + +#[allow(non_snake_case)] +pub fn write_cudnnBackendGetAttribute( + writer: &mut (impl std::io::Write + ?Sized), + descriptor: cuda_types::cudnn9::cudnnBackendDescriptor_t, + attributeName: cuda_types::cudnn9::cudnnBackendAttributeName_t, + attributeType: cuda_types::cudnn9::cudnnBackendAttributeType_t, + requestedElementCount: i64, + elementCount: *mut i64, + arrayOfElements: *mut ::core::ffi::c_void, +) -> std::io::Result<()> { + let mut arg_idx = 0usize; + writer.write_all(b"(")?; + writer.write_all(concat!(stringify!(descriptor), ": ").as_bytes())?; + crate::CudaDisplay::write(&descriptor, "cudnnBackendGetAttribute", arg_idx, writer)?; + arg_idx += 1; 
+ writer.write_all(b", ")?; + writer.write_all(concat!(stringify!(attributeName), ": ").as_bytes())?; + crate::CudaDisplay::write(&attributeName, "cudnnBackendGetAttribute", arg_idx, writer)?; + arg_idx += 1; + writer.write_all(b", ")?; + writer.write_all(concat!(stringify!(attributeType), ": ").as_bytes())?; + crate::CudaDisplay::write(&attributeType, "cudnnBackendGetAttribute", arg_idx, writer)?; + arg_idx += 1; + writer.write_all(b", ")?; + writer.write_all(concat!(stringify!(requestedElementCount), ": ").as_bytes())?; + crate::CudaDisplay::write( + &requestedElementCount, + "cudnnBackendGetAttribute", + arg_idx, + writer, + )?; + arg_idx += 1; + writer.write_all(b", ")?; + writer.write_all(concat!(stringify!(elementCount), ": ").as_bytes())?; + crate::CudaDisplay::write(&elementCount, "cudnnBackendGetAttribute", arg_idx, writer)?; + writer.write_all(b", ")?; + writer.write_all(concat!(stringify!(arrayOfElements), ": ").as_bytes())?; + cudnn9_print_elements( + writer, + attributeType, + unsafe { elementCount.as_ref() } + .copied() + .unwrap_or(requestedElementCount), + arrayOfElements, + )?; + writer.write_all(b")") +} + +#[allow(non_snake_case)] +pub fn write_cudnnBackendSetAttribute( + writer: &mut (impl std::io::Write + ?Sized), + descriptor: cuda_types::cudnn9::cudnnBackendDescriptor_t, + attributeName: cuda_types::cudnn9::cudnnBackendAttributeName_t, + attributeType: cuda_types::cudnn9::cudnnBackendAttributeType_t, + elementCount: i64, + arrayOfElements: *const ::core::ffi::c_void, +) -> std::io::Result<()> { + let mut arg_idx = 0usize; + writer.write_all(b"(")?; + writer.write_all(concat!(stringify!(descriptor), ": ").as_bytes())?; + crate::CudaDisplay::write(&descriptor, "cudnnBackendSetAttribute", arg_idx, writer)?; + arg_idx += 1; + writer.write_all(b", ")?; + writer.write_all(concat!(stringify!(attributeName), ": ").as_bytes())?; + crate::CudaDisplay::write(&attributeName, "cudnnBackendSetAttribute", arg_idx, writer)?; + arg_idx += 1; + writer.write_all(b", ")?; + writer.write_all(concat!(stringify!(attributeType), ": ").as_bytes())?; + crate::CudaDisplay::write(&attributeType, "cudnnBackendSetAttribute", arg_idx, writer)?; + arg_idx += 1; + writer.write_all(b", ")?; + writer.write_all(concat!(stringify!(elementCount), ": ").as_bytes())?; + crate::CudaDisplay::write(&elementCount, "cudnnBackendSetAttribute", arg_idx, writer)?; + writer.write_all(b", ")?; + writer.write_all(concat!(stringify!(arrayOfElements), ": ").as_bytes())?; + cudnn9_print_elements(writer, attributeType, elementCount, arrayOfElements)?; + writer.write_all(b")") +} + +fn cudnn9_print_elements( + writer: &mut (impl std::io::Write + ?Sized), + type_: cuda_types::cudnn9::cudnnBackendAttributeType_t, + element_count: i64, + array_of_elements: *const ::core::ffi::c_void, +) -> std::io::Result<()> { + fn print_typed( + writer: &mut (impl std::io::Write + ?Sized), + element_count: i64, + array_of_elements: *const ::core::ffi::c_void, + ) -> std::io::Result<()> { + if array_of_elements.is_null() { + return writer.write_all(b"NULL"); + } + let elements = + unsafe { slice::from_raw_parts(array_of_elements as *const T, element_count as usize) }; + CudaDisplay::write(elements, "", 0, writer) + } + match type_ { + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_HANDLE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_DATA_TYPE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + 
cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_BOOLEAN => { + print_typed::(writer, element_count, array_of_elements) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_INT64 => { + print_typed::(writer, element_count, array_of_elements) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT => { + print_typed::(writer, element_count, array_of_elements) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE => { + print_typed::(writer, element_count, array_of_elements) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_VOID_PTR => { + print_typed::<*const c_void>(writer, element_count, array_of_elements) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_CONVOLUTION_MODE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_HEUR_MODE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_KNOB_TYPE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_NAN_PROPOGATION => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_NUMERICAL_NOTE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_LAYOUT_TYPE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_ATTRIB_NAME => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_POINTWISE_MODE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_GENSTATS_MODE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_BN_FINALIZE_STATS_MODE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_REDUCTION_OPERATOR_TYPE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_BEHAVIOR_NOTE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_TENSOR_REORDERING_MODE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_RESAMPLE_MODE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_PADDING_MODE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_INT32 => { + print_typed::(writer, element_count, array_of_elements) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_CHAR => { + CudaDisplay::write(&array_of_elements.cast::(), "", 0, writer) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_SIGNAL_MODE => { + 
print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_FRACTION => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_NORM_MODE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_NORM_FWD_PHASE => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + cuda_types::cudnn9::cudnnBackendAttributeType_t::CUDNN_TYPE_RNG_DISTRIBUTION => { + print_typed::( + writer, + element_count, + array_of_elements, + ) + } + _ => unimplemented!(), + } +} + +mod dark_api; +mod format_generated; +pub use format_generated::*; +mod format_generated_blas; +pub use format_generated_blas::*; +mod format_generated_blaslt; +pub use format_generated_blaslt::*; +mod format_generated_blaslt_internal; +pub use format_generated_blaslt_internal::*; +mod format_generated_dnn9; +pub use format_generated_dnn9::*; +mod format_generated_fft; +pub use format_generated_fft::*; +mod format_generated_sparse; +pub use format_generated_sparse::*; diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 7aa9ee1..9edc7aa 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -4,4 +4,3 @@ mod test; pub use pass::to_llvm_module; pub use pass::Attributes; - diff --git a/ptx/src/pass/llvm/attributes.rs b/ptx/src/pass/llvm/attributes.rs index 4479ece..92ee20f 100644 --- a/ptx/src/pass/llvm/attributes.rs +++ b/ptx/src/pass/llvm/attributes.rs @@ -1,10 +1,13 @@ use std::ffi::CStr; -use super::*; use super::super::*; -use llvm_zluda::{core::*}; +use super::*; +use llvm_zluda::core::*; -pub(crate) fn run(context: &Context, attributes: Attributes) -> Result { +pub(crate) fn run( + context: &Context, + attributes: Attributes, +) -> Result { let module = llvm::Module::new(context, LLVM_UNNAMED); emit_attribute(context, &module, "clock_rate", attributes.clock_rate)?; @@ -16,7 +19,12 @@ pub(crate) fn run(context: &Context, attributes: Attributes) -> Result Result<(), TranslateError> { +fn emit_attribute( + context: &Context, + module: &llvm::Module, + name: &str, + attribute: u32, +) -> Result<(), TranslateError> { let name = format!("{}attribute_{}\0", ZLUDA_PTX_PREFIX, name).to_ascii_uppercase(); let name = unsafe { CStr::from_bytes_with_nul_unchecked(name.as_bytes()) }; let attribute_type = get_scalar_type(context.get(), ast::ScalarType::U32); @@ -31,4 +39,4 @@ fn emit_attribute(context: &Context, module: &llvm::Module, name: &str, attribut unsafe { LLVMSetInitializer(global, LLVMConstInt(attribute_type, attribute as u64, 0)) }; unsafe { LLVMSetGlobalConstant(global, 1) }; Ok(()) -} \ No newline at end of file +} diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index d73a881..f234507 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -1,3019 +1,3024 @@ -// We use Raw LLVM-C bindings here because using inkwell is just not worth it. -// Specifically the issue is with builder functions. We maintain the mapping -// between ZLUDA identifiers and LLVM values. When using inkwell, LLVM values -// are kept as instances `AnyValueEnum`. Now look at the signature of -// `Builder::build_int_add(...)`: -// pub fn build_int_add>(&self, lhs: T, rhs: T, name: &str, ) -> Result -// At this point both lhs and rhs are `AnyValueEnum`. 
To call -// `build_int_add(...)` we would have to do something like this: -// if let (Ok(lhs), Ok(rhs)) = (lhs.as_int(), rhs.as_int()) { -// builder.build_int_add(lhs, rhs, dst)?; -// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_pointer(), rhs.as_pointer()) { -// builder.build_int_add(lhs, rhs, dst)?; -// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_vector(), rhs.as_vector()) { -// builder.build_int_add(lhs, rhs, dst)?; -// } else { -// return Err(error_unrachable()); -// } -// while with plain LLVM-C it's just: -// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) }; - -// AMDGPU LLVM backend support for llvm.experimental.constrained.* is incomplete. -// Emitting @llvm.experimental.constrained.fdiv.f32(...) makes LLVm fail with -// "LLVM ERROR: unsupported libcall legalization". Running with "-mllvm -print-before-all" -// shows it fails inside amdgpu-isel. You can get a little bit furthr with "-mllvm -global-isel", -// but it will too fail similarly, but with "unable to legalize instruction" - -use std::array::TryFromSliceError; -use std::convert::TryInto; -use std::ffi::{CStr, NulError}; -use std::{i8, ptr, u64}; - -use super::*; -use crate::pass::*; -use llvm_zluda::{core::*, *}; -use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; -use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca}; -use ptx_parser::{CpAsyncArgs, CpAsyncDetails, Mul24Control}; - -struct Builder(LLVMBuilderRef); - -impl Builder { - fn new(ctx: &Context) -> Self { - Self::new_raw(ctx.get()) - } - - fn new_raw(ctx: LLVMContextRef) -> Self { - Self(unsafe { LLVMCreateBuilderInContext(ctx) }) - } - - fn get(&self) -> LLVMBuilderRef { - self.0 - } -} - -impl Drop for Builder { - fn drop(&mut self) { - unsafe { - LLVMDisposeBuilder(self.0); - } - } -} - -pub(crate) fn run<'input>( - context: &Context, - id_defs: GlobalStringIdentResolver2<'input>, - directives: Vec, SpirvWord>>, -) -> Result { - let module = llvm::Module::new(context, LLVM_UNNAMED); - let mut emit_ctx = ModuleEmitContext::new(context, &module, &id_defs); - for directive in directives { - match directive { - Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?, - Directive2::Method(method) => emit_ctx.emit_method(method)?, - } - } - if let Err(err) = module.verify() { - panic!("{:?}", err); - } - Ok(module) -} - -struct ModuleEmitContext<'a, 'input> { - context: LLVMContextRef, - module: LLVMModuleRef, - builder: Builder, - id_defs: &'a GlobalStringIdentResolver2<'input>, - resolver: ResolveIdent, -} - -impl<'a, 'input> ModuleEmitContext<'a, 'input> { - fn new(context: &Context, module: &llvm::Module, id_defs: &'a GlobalStringIdentResolver2<'input>) -> Self { - ModuleEmitContext { - context: context.get(), - module: module.get(), - builder: Builder::new(context), - id_defs, - resolver: ResolveIdent::new(&id_defs), - } - } - - fn kernel_call_convention() -> u32 { - LLVMCallConv::LLVMAMDGPUKERNELCallConv as u32 - } - - fn func_call_convention() -> u32 { - LLVMCallConv::LLVMCCallConv as u32 - } - - fn emit_method( - &mut self, - method: Function2, SpirvWord>, - ) -> Result<(), TranslateError> { - let name = method - .import_as - .as_deref() - .or_else(|| self.id_defs.ident_map[&method.name].name.as_deref()) - .ok_or_else(|| error_unreachable())?; - let name = CString::new(name).map_err(|_| error_unreachable())?; - let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; - if fn_ == ptr::null_mut() { - let fn_type = get_function_type( - self.context, - method.return_arguments.iter().map(|v| &v.v_type), - method - 
.input_arguments - .iter() - .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)), - )?; - fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; - self.emit_fn_attribute(fn_, "amdgpu-unsafe-fp-atomics", "true"); - self.emit_fn_attribute(fn_, "uniform-work-group-size", "true"); - self.emit_fn_attribute(fn_, "no-trapping-math", "true"); - } - if !method.is_kernel { - self.resolver.register(method.name, fn_); - self.emit_fn_attribute(fn_, "denormal-fp-math-f32", "dynamic"); - self.emit_fn_attribute(fn_, "denormal-fp-math", "dynamic"); - } else { - self.emit_fn_attribute( - fn_, - "denormal-fp-math-f32", - llvm_ftz(method.flush_to_zero_f32), - ); - self.emit_fn_attribute( - fn_, - "denormal-fp-math", - llvm_ftz(method.flush_to_zero_f16f64), - ); - } - for (i, param) in method.input_arguments.iter().enumerate() { - let value = unsafe { LLVMGetParam(fn_, i as u32) }; - let name = self.resolver.get_or_add(param.name); - unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; - self.resolver.register(param.name, value); - if method.is_kernel { - let attr_kind = unsafe { - LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len()) - }; - let attr = unsafe { - LLVMCreateTypeAttribute( - self.context, - attr_kind, - get_type(self.context, ¶m.v_type)?, - ) - }; - unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) }; - } - } - let call_conv = if method.is_kernel { - Self::kernel_call_convention() - } else { - Self::func_call_convention() - }; - unsafe { LLVMSetFunctionCallConv(fn_, call_conv) }; - if let Some(statements) = method.body { - let variables_bb = - unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) }; - let variables_builder = Builder::new_raw(self.context); - unsafe { LLVMPositionBuilderAtEnd(variables_builder.get(), variables_bb) }; - let real_bb = - unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) }; - unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) }; - let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder); - for var in method.return_arguments { - method_emitter.emit_variable(var)?; - } - for statement in statements.iter() { - if let Statement::Label(label) = statement { - method_emitter.emit_label_initial(*label); - } - } - let mut statements = statements.into_iter(); - if let Some(Statement::Label(label)) = statements.next() { - method_emitter.emit_label_delayed(label)?; - } else { - return Err(error_unreachable()); - } - method_emitter.emit_kernel_rounding_prelude( - method.is_kernel, - method.rounding_mode_f32, - method.rounding_mode_f16f64, - )?; - for statement in statements { - method_emitter.emit_statement(statement)?; - } - unsafe { LLVMBuildBr(method_emitter.variables_builder.get(), real_bb) }; - } - Ok(()) - } - - fn emit_global( - &mut self, - _linking: ast::LinkingDirective, - var: ast::Variable, - ) -> Result<(), TranslateError> { - let name = self - .id_defs - .ident_map - .get(&var.name) - .map(|entry| { - entry - .name - .as_ref() - .map(|text| Ok::<_, NulError>(Cow::Owned(CString::new(&**text)?))) - }) - .flatten() - .transpose() - .map_err(|_| error_unreachable())? 
- .unwrap_or(Cow::Borrowed(LLVM_UNNAMED)); - let global = unsafe { - LLVMAddGlobalInAddressSpace( - self.module, - get_type(self.context, &var.v_type)?, - name.as_ptr(), - get_state_space(var.state_space)?, - ) - }; - self.resolver.register(var.name, global); - if let Some(align) = var.align { - unsafe { LLVMSetAlignment(global, align) }; - } - if !var.array_init.is_empty() { - self.emit_array_init(&var.v_type, &*var.array_init, global)?; - } - Ok(()) - } - - // TODO: instead of Vec we should emit a typed initializer - fn emit_array_init( - &mut self, - type_: &ast::Type, - array_init: &[u8], - global: *mut llvm_zluda::LLVMValue, - ) -> Result<(), TranslateError> { - match type_ { - ast::Type::Array(None, scalar, dimensions) => { - if dimensions.len() != 1 { - todo!() - } - if dimensions[0] as usize * scalar.size_of() as usize != array_init.len() { - return Err(error_unreachable()); - } - let type_ = get_scalar_type(self.context, *scalar); - let mut elements = array_init - .chunks(scalar.size_of() as usize) - .map(|chunk| self.constant_from_bytes(*scalar, chunk, type_)) - .collect::, _>>() - .map_err(|_| error_unreachable())?; - let initializer = - unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) }; - unsafe { LLVMSetInitializer(global, initializer) }; - } - _ => todo!(), - } - Ok(()) - } - - fn constant_from_bytes( - &self, - scalar: ast::ScalarType, - bytes: &[u8], - llvm_type: LLVMTypeRef, - ) -> Result { - Ok(match scalar { - ptx_parser::ScalarType::Pred - | ptx_parser::ScalarType::S8 - | ptx_parser::ScalarType::B8 - | ptx_parser::ScalarType::U8 => unsafe { - LLVMConstInt(llvm_type, u8::from_le_bytes(bytes.try_into()?) as u64, 0) - }, - ptx_parser::ScalarType::S16 - | ptx_parser::ScalarType::B16 - | ptx_parser::ScalarType::U16 => unsafe { - LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0) - }, - ptx_parser::ScalarType::S32 - | ptx_parser::ScalarType::B32 - | ptx_parser::ScalarType::U32 => unsafe { - LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) as u64, 0) - }, - ptx_parser::ScalarType::F16 => todo!(), - ptx_parser::ScalarType::BF16 => todo!(), - ptx_parser::ScalarType::U64 => todo!(), - ptx_parser::ScalarType::S64 => todo!(), - ptx_parser::ScalarType::S16x2 => todo!(), - ptx_parser::ScalarType::F32 => todo!(), - ptx_parser::ScalarType::B64 => todo!(), - ptx_parser::ScalarType::F64 => todo!(), - ptx_parser::ScalarType::B128 => todo!(), - ptx_parser::ScalarType::U16x2 => todo!(), - ptx_parser::ScalarType::F16x2 => todo!(), - ptx_parser::ScalarType::BF16x2 => todo!(), - }) - } - - fn emit_fn_attribute(&self, llvm_object: LLVMValueRef, key: &str, value: &str) { - let attribute = unsafe { - LLVMCreateStringAttribute( - self.context, - key.as_ptr() as _, - key.len() as u32, - value.as_ptr() as _, - value.len() as u32, - ) - }; - unsafe { LLVMAddAttributeAtIndex(llvm_object, LLVMAttributeFunctionIndex, attribute) }; - } -} - -fn llvm_ftz(ftz: bool) -> &'static str { - if ftz { - "preserve-sign" - } else { - "ieee" - } -} - -fn get_input_argument_type( - context: LLVMContextRef, - v_type: &ast::Type, - state_space: ast::StateSpace, -) -> Result { - match state_space { - ast::StateSpace::ParamEntry => { - Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) 
}) - } - ast::StateSpace::Reg => get_type(context, v_type), - _ => return Err(error_unreachable()), - } -} - -struct MethodEmitContext<'a> { - context: LLVMContextRef, - module: LLVMModuleRef, - method: LLVMValueRef, - builder: LLVMBuilderRef, - variables_builder: Builder, - resolver: &'a mut ResolveIdent, -} - -impl<'a> MethodEmitContext<'a> { - fn new( - parent: &'a mut ModuleEmitContext, - method: LLVMValueRef, - variables_builder: Builder, - ) -> MethodEmitContext<'a> { - MethodEmitContext { - context: parent.context, - module: parent.module, - builder: parent.builder.get(), - variables_builder, - resolver: &mut parent.resolver, - method, - } - } - - fn emit_statement( - &mut self, - statement: Statement, SpirvWord>, - ) -> Result<(), TranslateError> { - Ok(match statement { - Statement::Variable(var) => self.emit_variable(var)?, - Statement::Label(label) => self.emit_label_delayed(label)?, - Statement::Instruction(inst) => self.emit_instruction(inst)?, - Statement::Conditional(cond) => self.emit_conditional(cond)?, - Statement::Conversion(conversion) => self.emit_conversion(conversion)?, - Statement::Constant(constant) => self.emit_constant(constant)?, - Statement::RetValue(_, values) => self.emit_ret_value(values)?, - Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?, - Statement::RepackVector(repack) => self.emit_vector_repack(repack)?, - Statement::FunctionPointer(_) => todo!(), - Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?, - Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?, - Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?, - Statement::FpSaturate { dst, src, type_ } => self.emit_fp_saturate(type_, dst, src)?, - }) - } - - // This should be a kernel attribute, but sadly AMDGPU LLVM target does - // not support attribute for it. 
So we have to set it as the first - // instruction in the body of a kernel - fn emit_kernel_rounding_prelude( - &mut self, - is_kernel: bool, - rounding_mode_f32: ast::RoundingMode, - rounding_mode_f16f64: ast::RoundingMode, - ) -> Result<(), TranslateError> { - if is_kernel { - if rounding_mode_f32 != ast::RoundingMode::NearestEven - || rounding_mode_f16f64 != ast::RoundingMode::NearestEven - { - self.emit_set_mode(ModeRegister::Rounding { - f32: rounding_mode_f32, - f16f64: rounding_mode_f16f64, - })?; - } - } - Ok(()) - } - - fn emit_variable(&mut self, var: ast::Variable) -> Result<(), TranslateError> { - let alloca = unsafe { - LLVMZludaBuildAlloca( - self.variables_builder.get(), - get_type(self.context, &var.v_type)?, - get_state_space(var.state_space)?, - self.resolver.get_or_add_raw(var.name), - ) - }; - self.resolver.register(var.name, alloca); - if let Some(align) = var.align { - unsafe { LLVMSetAlignment(alloca, align) }; - } - if !var.array_init.is_empty() { - todo!() - } - Ok(()) - } - - fn emit_label_initial(&mut self, label: SpirvWord) { - let block = unsafe { - LLVMAppendBasicBlockInContext( - self.context, - self.method, - self.resolver.get_or_add_raw(label), - ) - }; - self.resolver - .register(label, unsafe { LLVMBasicBlockAsValue(block) }); - } - - fn emit_label_delayed(&mut self, label: SpirvWord) -> Result<(), TranslateError> { - let block = self.resolver.value(label)?; - let block = unsafe { LLVMValueAsBasicBlock(block) }; - let last_block = unsafe { LLVMGetInsertBlock(self.builder) }; - if unsafe { LLVMGetBasicBlockTerminator(last_block) } == ptr::null_mut() { - unsafe { LLVMBuildBr(self.builder, block) }; - } - unsafe { LLVMPositionBuilderAtEnd(self.builder, block) }; - Ok(()) - } - - fn emit_instruction( - &mut self, - inst: ast::Instruction, - ) -> Result<(), TranslateError> { - match inst { - ast::Instruction::Mov { data: _, arguments } => self.emit_mov(arguments), - ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments), - ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments), - ast::Instruction::St { data, arguments } => self.emit_st(data, arguments), - ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments), - ast::Instruction::Mul24 { data, arguments } => self.emit_mul24(data, arguments), - ast::Instruction::Set { data, arguments } => self.emit_set(data, arguments), - ast::Instruction::SetBool { data, arguments } => self.emit_set_bool(data, arguments), - ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments), - ast::Instruction::SetpBool { data, arguments } => self.emit_setp_bool(data, arguments), - ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments), - ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments), - ast::Instruction::And { arguments, .. 
} => self.emit_and(arguments), - ast::Instruction::Bra { arguments } => self.emit_bra(arguments), - ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments), - ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments), - ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments), - ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments), - ast::Instruction::Ret { data } => Ok(self.emit_ret(data)), - ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments), - ast::Instruction::Abs { data, arguments } => self.emit_abs(data, arguments), - ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments), - ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments), - ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments), - ast::Instruction::Min { data, arguments } => self.emit_min(data, arguments), - ast::Instruction::Max { data, arguments } => self.emit_max(data, arguments), - ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments), - ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments), - ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments), - ast::Instruction::Selp { data, arguments } => self.emit_selp(data, arguments), - ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments), - ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments), - ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments), - ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments), - ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments), - ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments), - ast::Instruction::Lg2 { data, arguments } => self.emit_lg2(data, arguments), - ast::Instruction::Ex2 { data, arguments } => self.emit_ex2(data, arguments), - ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments), - ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments), - ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments), - ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments), - ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments), - ast::Instruction::PrmtSlow { .. } => { - Err(error_todo_msg("PrmtSlow is not implemented yet")) - } - ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments), - ast::Instruction::Membar { data } => self.emit_membar(data), - ast::Instruction::Trap {} => Err(error_todo_msg("Trap is not implemented yet")), - ast::Instruction::Tanh { data, arguments } => self.emit_tanh(data, arguments), - ast::Instruction::CpAsync { data, arguments } => self.emit_cp_async(data, arguments), - ast::Instruction::CpAsyncCommitGroup { } => Ok(()), // nop - ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop - ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop - // replaced by a function call - ast::Instruction::Bfe { .. } - | ast::Instruction::Bar { .. } - | ast::Instruction::BarRed { .. } - | ast::Instruction::Bfi { .. } - | ast::Instruction::Activemask { .. } - | ast::Instruction::ShflSync { .. } - | ast::Instruction::Nanosleep { .. 
} => return Err(error_unreachable()), - } - } - - fn emit_ld( - &mut self, - data: ast::LdDetails, - arguments: ast::LdArgs, - ) -> Result<(), TranslateError> { - if data.qualifier != ast::LdStQualifier::Weak { - todo!() - } - let builder = self.builder; - let type_ = get_type(self.context, &data.typ)?; - let ptr = self.resolver.value(arguments.src)?; - self.resolver.with_result(arguments.dst, |dst| { - let load = unsafe { LLVMBuildLoad2(builder, type_, ptr, dst) }; - unsafe { LLVMSetAlignment(load, data.typ.layout().align() as u32) }; - load - }); - Ok(()) - } - - fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> { - let builder = self.builder; - match conversion.kind { - ConversionKind::Default => self.emit_conversion_default( - self.resolver.value(conversion.src)?, - conversion.dst, - &conversion.from_type, - conversion.from_space, - &conversion.to_type, - conversion.to_space, - ), - ConversionKind::SignExtend => { - let src = self.resolver.value(conversion.src)?; - let type_ = get_type(self.context, &conversion.to_type)?; - self.resolver.with_result(conversion.dst, |dst| unsafe { - LLVMBuildSExt(builder, src, type_, dst) - }); - Ok(()) - } - ConversionKind::BitToPtr => { - let src = self.resolver.value(conversion.src)?; - let type_ = get_pointer_type(self.context, conversion.to_space)?; - self.resolver.with_result(conversion.dst, |dst| unsafe { - LLVMBuildIntToPtr(builder, src, type_, dst) - }); - Ok(()) - } - ConversionKind::PtrToPtr => { - let src = self.resolver.value(conversion.src)?; - let dst_type = get_pointer_type(self.context, conversion.to_space)?; - self.resolver.with_result(conversion.dst, |dst| unsafe { - LLVMBuildAddrSpaceCast(builder, src, dst_type, dst) - }); - Ok(()) - } - ConversionKind::AddressOf => { - let src = self.resolver.value(conversion.src)?; - let dst_type = get_type(self.context, &conversion.to_type)?; - self.resolver.with_result(conversion.dst, |dst| unsafe { - LLVMBuildPtrToInt(self.builder, src, dst_type, dst) - }); - Ok(()) - } - } - } - - fn emit_conversion_default( - &mut self, - src: LLVMValueRef, - dst: SpirvWord, - from_type: &ast::Type, - from_space: ast::StateSpace, - to_type: &ast::Type, - to_space: ast::StateSpace, - ) -> Result<(), TranslateError> { - match (from_type, to_type) { - (ast::Type::Scalar(from_type), ast::Type::Scalar(to_type_scalar)) => { - let from_layout = from_type.layout(); - let to_layout = to_type.layout(); - if from_layout.size() == to_layout.size() { - let dst_type = get_type(self.context, &to_type)?; - if from_type.kind() != ast::ScalarKind::Float - && to_type_scalar.kind() != ast::ScalarKind::Float - { - // It is noop, but another instruction expects result of this conversion - self.resolver.register(dst, src); - } else { - self.resolver.with_result(dst, |dst| unsafe { - LLVMBuildBitCast(self.builder, src, dst_type, dst) - }); - } - Ok(()) - } else { - // This block is safe because it's illegal to implictly convert between floating point values - let same_width_bit_type = unsafe { - LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32) - }; - let same_width_bit_value = unsafe { - LLVMBuildBitCast( - self.builder, - src, - same_width_bit_type, - LLVM_UNNAMED.as_ptr(), - ) - }; - let wide_bit_type = match to_type_scalar.layout().size() { - 1 => ast::ScalarType::B8, - 2 => ast::ScalarType::B16, - 4 => ast::ScalarType::B32, - 8 => ast::ScalarType::B64, - _ => return Err(error_unreachable()), - }; - let wide_bit_type_llvm = unsafe { - LLVMIntTypeInContext(self.context, 
(to_layout.size() * 8) as u32) - }; - if to_type_scalar.kind() == ast::ScalarKind::Unsigned - || to_type_scalar.kind() == ast::ScalarKind::Bit - { - let llvm_fn = if to_type_scalar.size_of() >= from_type.size_of() { - LLVMBuildZExtOrBitCast - } else { - LLVMBuildTrunc - }; - self.resolver.with_result(dst, |dst| unsafe { - llvm_fn(self.builder, same_width_bit_value, wide_bit_type_llvm, dst) - }); - Ok(()) - } else { - let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed - && to_type_scalar.kind() == ast::ScalarKind::Signed - { - if to_type_scalar.size_of() >= from_type.size_of() { - LLVMBuildSExtOrBitCast - } else { - LLVMBuildTrunc - } - } else { - if to_type_scalar.size_of() >= from_type.size_of() { - LLVMBuildZExtOrBitCast - } else { - LLVMBuildTrunc - } - }; - let wide_bit_value = unsafe { - conversion_fn( - self.builder, - same_width_bit_value, - wide_bit_type_llvm, - LLVM_UNNAMED.as_ptr(), - ) - }; - self.emit_conversion_default( - wide_bit_value, - dst, - &wide_bit_type.into(), - from_space, - to_type, - to_space, - ) - } - } - } - (ast::Type::Vector(..), ast::Type::Scalar(..)) - | (ast::Type::Scalar(..), ast::Type::Array(..)) - | (ast::Type::Array(..), ast::Type::Scalar(..)) => { - let dst_type = get_type(self.context, to_type)?; - self.resolver.with_result(dst, |dst| unsafe { - LLVMBuildBitCast(self.builder, src, dst_type, dst) - }); - Ok(()) - } - _ => todo!(), - } - } - - fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> { - let type_ = get_scalar_type(self.context, constant.typ); - let value = match constant.value { - ast::ImmediateValue::U64(x) => unsafe { LLVMConstInt(type_, x, 0) }, - ast::ImmediateValue::S64(x) => unsafe { LLVMConstInt(type_, x as u64, 0) }, - ast::ImmediateValue::F32(x) => unsafe { LLVMConstReal(type_, x as f64) }, - ast::ImmediateValue::F64(x) => unsafe { LLVMConstReal(type_, x) }, - }; - self.resolver.register(constant.dst, value); - Ok(()) - } - - fn emit_add( - &mut self, - data: ast::ArithDetails, - arguments: ast::AddArgs, - ) -> Result<(), TranslateError> { - let builder = self.builder; - let fn_ = match data { - ast::ArithDetails::Integer(ast::ArithInteger { - saturate: true, - type_, - }) => { - let op = if type_.kind() == ast::ScalarKind::Signed { - "sadd" - } else { - "uadd" - }; - return self.emit_intrinsic_saturate( - op, - type_, - arguments.dst, - arguments.src1, - arguments.src2, - ); - } - ast::ArithDetails::Integer(ast::ArithInteger { - saturate: false, .. - }) => LLVMBuildAdd, - ast::ArithDetails::Float(..) 
=> LLVMBuildFAdd, - }; - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - fn_(builder, src1, src2, dst) - }); - Ok(()) - } - - fn emit_st( - &self, - data: ast::StData, - arguments: ast::StArgs, - ) -> Result<(), TranslateError> { - let ptr = self.resolver.value(arguments.src1)?; - let value = self.resolver.value(arguments.src2)?; - if data.qualifier != ast::LdStQualifier::Weak { - todo!() - } - let store = unsafe { LLVMBuildStore(self.builder, value, ptr) }; - unsafe { LLVMSetAlignment(store, data.typ.layout().align() as u32); } - Ok(()) - } - - fn emit_ret(&self, _data: ast::RetData) { - unsafe { LLVMBuildRetVoid(self.builder) }; - } - - fn emit_call( - &mut self, - data: ast::CallDetails, - arguments: ast::CallArgs, - ) -> Result<(), TranslateError> { - if cfg!(debug_assertions) { - for (_, space) in data.return_arguments.iter() { - if *space != ast::StateSpace::Reg { - panic!() - } - } - for (_, space) in data.input_arguments.iter() { - if *space != ast::StateSpace::Reg { - panic!() - } - } - } - let name = match &*arguments.return_arguments { - [dst] => self.resolver.get_or_add_raw(*dst), - _ => LLVM_UNNAMED.as_ptr(), - }; - let type_ = get_function_type( - self.context, - data.return_arguments.iter().map(|(type_, ..)| type_), - data.input_arguments - .iter() - .map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)), - )?; - let mut input_arguments = arguments - .input_arguments - .iter() - .map(|arg| self.resolver.value(*arg)) - .collect::, _>>()?; - let llvm_call = unsafe { - LLVMBuildCall2( - self.builder, - type_, - self.resolver.value(arguments.func)?, - input_arguments.as_mut_ptr(), - input_arguments.len() as u32, - name, - ) - }; - match &*arguments.return_arguments { - [] => {} - [name] => self.resolver.register(*name, llvm_call), - [b32, pred] => { - self.resolver.with_result(*b32, |name| unsafe { - LLVMBuildExtractValue(self.builder, llvm_call, 0, name) - }); - self.resolver.with_result(*pred, |name| unsafe { - let extracted = - LLVMBuildExtractValue(self.builder, llvm_call, 1, LLVM_UNNAMED.as_ptr()); - LLVMBuildTrunc( - self.builder, - extracted, - get_scalar_type(self.context, ast::ScalarType::Pred), - name, - ) - }); - } - _ => { - return Err(error_todo_msg( - "Only two return arguments (.b32, .pred) currently supported", - )) - } - } - Ok(()) - } - - fn emit_mov(&mut self, arguments: ast::MovArgs) -> Result<(), TranslateError> { - self.resolver - .register(arguments.dst, self.resolver.value(arguments.src)?); - Ok(()) - } - - fn emit_ptr_access(&mut self, ptr_access: PtrAccess) -> Result<(), TranslateError> { - let ptr_src = self.resolver.value(ptr_access.ptr_src)?; - let mut offset_src = self.resolver.value(ptr_access.offset_src)?; - let pointee_type = get_scalar_type(self.context, ast::ScalarType::B8); - self.resolver.with_result(ptr_access.dst, |dst| unsafe { - LLVMBuildInBoundsGEP2(self.builder, pointee_type, ptr_src, &mut offset_src, 1, dst) - }); - Ok(()) - } - - fn emit_and(&mut self, arguments: ast::AndArgs) -> Result<(), TranslateError> { - let builder = self.builder; - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildAnd(builder, src1, src2, dst) - }); - Ok(()) - } - - fn emit_atom( - &mut self, - data: ast::AtomDetails, - arguments: ast::AtomArgs, - ) -> Result<(), TranslateError> { - let builder = self.builder; - 
let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - let op = match data.op { - ast::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd, - ast::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr, - ast::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor, - ast::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg, - ast::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd, - ast::AtomicOp::IncrementWrap => { - LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap - } - ast::AtomicOp::DecrementWrap => { - LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap - } - ast::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin, - ast::AtomicOp::UnsignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin, - ast::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax, - ast::AtomicOp::UnsignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax, - ast::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd, - ast::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin, - ast::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax, - }; - self.resolver.register(arguments.dst, unsafe { - LLVMZludaBuildAtomicRMW( - builder, - op, - src1, - src2, - get_scope(data.scope)?, - get_ordering(data.semantics), - ) - }); - Ok(()) - } - - fn emit_atom_cas( - &mut self, - data: ast::AtomCasDetails, - arguments: ast::AtomCasArgs, - ) -> Result<(), TranslateError> { - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - let src3 = self.resolver.value(arguments.src3)?; - let success_ordering = get_ordering(data.semantics); - let failure_ordering = get_ordering_failure(data.semantics); - let temp = unsafe { - LLVMZludaBuildAtomicCmpXchg( - self.builder, - src1, - src2, - src3, - get_scope(data.scope)?, - success_ordering, - failure_ordering, - ) - }; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildExtractValue(self.builder, temp, 0, dst) - }); - Ok(()) - } - - fn emit_bra(&self, arguments: ast::BraArgs) -> Result<(), TranslateError> { - let src = self.resolver.value(arguments.src)?; - let src = unsafe { LLVMValueAsBasicBlock(src) }; - unsafe { LLVMBuildBr(self.builder, src) }; - Ok(()) - } - - fn emit_brev( - &mut self, - data: ast::ScalarType, - arguments: ast::BrevArgs, - ) -> Result<(), TranslateError> { - let llvm_fn = match data.size_of() { - 4 => c"llvm.bitreverse.i32", - 8 => c"llvm.bitreverse.i64", - _ => return Err(error_unreachable()), - }; - let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) }; - let type_ = get_scalar_type(self.context, data); - let fn_type = get_function_type( - self.context, - iter::once(&data.into()), - iter::once(Ok(type_)), - )?; - if fn_ == ptr::null_mut() { - fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) }; - } - let mut src = self.resolver.value(arguments.src)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst) - }); - Ok(()) - } - - fn emit_ret_value( - &mut self, - values: Vec<(SpirvWord, ptx_parser::Type)>, - ) -> Result<(), TranslateError> { - let loads = values - .iter() - .map(|(value, type_)| { - let value = self.resolver.value(*value)?; - let lowered_type = get_type(self.context, type_)?; - let load = unsafe { - 
LLVMBuildLoad2(self.builder, lowered_type, value, LLVM_UNNAMED.as_ptr()) - }; - unsafe { - LLVMSetAlignment(load, type_.layout().align() as u32); - } - Ok(load) - }) - .collect::, _>>()?; - - match &*loads { - [] => unsafe { LLVMBuildRetVoid(self.builder) }, - [value] => unsafe { LLVMBuildRet(self.builder, *value) }, - _ => { - check_multiple_return_types(values.iter().map(|(_, type_)| type_))?; - let array_ty = - get_array_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?; - let insert_b32 = unsafe { - LLVMBuildInsertValue( - self.builder, - LLVMGetPoison(array_ty), - loads[0], - 0, - LLVM_UNNAMED.as_ptr(), - ) - }; - let zext_pred = unsafe { - LLVMBuildZExt( - self.builder, - loads[1], - get_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32))?, - LLVM_UNNAMED.as_ptr(), - ) - }; - let insert_pred = unsafe { - LLVMBuildInsertValue( - self.builder, - insert_b32, - zext_pred, - 1, - LLVM_UNNAMED.as_ptr(), - ) - }; - unsafe { LLVMBuildRet(self.builder, insert_pred) } - } - }; - Ok(()) - } - - fn emit_clz( - &mut self, - data: ptx_parser::ScalarType, - arguments: ptx_parser::ClzArgs, - ) -> Result<(), TranslateError> { - let llvm_fn = match data.size_of() { - 4 => c"llvm.ctlz.i32", - 8 => c"llvm.ctlz.i64", - _ => return Err(error_unreachable()), - }; - let type_ = get_scalar_type(self.context, data.into()); - let pred = get_scalar_type(self.context, ast::ScalarType::Pred); - let fn_type = get_function_type( - self.context, - iter::once(&ast::ScalarType::U32.into()), - [Ok(type_), Ok(pred)].into_iter(), - )?; - let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) }; - if fn_ == ptr::null_mut() { - fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) }; - } - let src = self.resolver.value(arguments.src)?; - let false_ = unsafe { LLVMConstInt(pred, 0, 0) }; - let mut args = [src, false_]; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildCall2( - self.builder, - fn_type, - fn_, - args.as_mut_ptr(), - args.len() as u32, - dst, - ) - }); - Ok(()) - } - - fn emit_mul( - &mut self, - data: ast::MulDetails, - arguments: ast::MulArgs, - ) -> Result<(), TranslateError> { - self.emit_mul_impl(data, Some(arguments.dst), arguments.src1, arguments.src2)?; - Ok(()) - } - - fn emit_mul_impl( - &mut self, - data: ast::MulDetails, - dst: Option, - src1: SpirvWord, - src2: SpirvWord, - ) -> Result { - let mul_fn = match data { - ast::MulDetails::Integer { control, type_ } => match control { - ast::MulIntControl::Low => LLVMBuildMul, - ast::MulIntControl::High => return self.emit_mul_high(type_, dst, src1, src2), - ast::MulIntControl::Wide => { - return Ok(self.emit_mul_wide_impl(type_, dst, src1, src2)?.1) - } - }, - ast::MulDetails::Float(..) 
=> LLVMBuildFMul, - }; - let src1 = self.resolver.value(src1)?; - let src2 = self.resolver.value(src2)?; - Ok(self - .resolver - .with_result_option(dst, |dst| unsafe { mul_fn(self.builder, src1, src2, dst) })) - } - - fn emit_mul_high( - &mut self, - type_: ptx_parser::ScalarType, - dst: Option, - src1: SpirvWord, - src2: SpirvWord, - ) -> Result { - let (wide_type, wide_value) = self.emit_mul_wide_impl(type_, None, src1, src2)?; - let shift_constant = - unsafe { LLVMConstInt(wide_type, (type_.layout().size() * 8) as u64, 0) }; - let shifted = unsafe { - LLVMBuildLShr( - self.builder, - wide_value, - shift_constant, - LLVM_UNNAMED.as_ptr(), - ) - }; - let narrow_type = get_scalar_type(self.context, type_); - Ok(self.resolver.with_result_option(dst, |dst| unsafe { - LLVMBuildTrunc(self.builder, shifted, narrow_type, dst) - })) - } - - fn emit_mul_wide_impl( - &mut self, - type_: ptx_parser::ScalarType, - dst: Option, - src1: SpirvWord, - src2: SpirvWord, - ) -> Result<(LLVMTypeRef, LLVMValueRef), TranslateError> { - let src1 = self.resolver.value(src1)?; - let src2 = self.resolver.value(src2)?; - let wide_type = - unsafe { LLVMIntTypeInContext(self.context, (type_.layout().size() * 8 * 2) as u32) }; - let llvm_cast = match type_.kind() { - ptx_parser::ScalarKind::Signed => LLVMBuildSExt, - ptx_parser::ScalarKind::Unsigned => LLVMBuildZExt, - _ => return Err(error_unreachable()), - }; - let src1 = unsafe { llvm_cast(self.builder, src1, wide_type, LLVM_UNNAMED.as_ptr()) }; - let src2 = unsafe { llvm_cast(self.builder, src2, wide_type, LLVM_UNNAMED.as_ptr()) }; - Ok(( - wide_type, - self.resolver.with_result_option(dst, |dst| unsafe { - LLVMBuildMul(self.builder, src1, src2, dst) - }), - )) - } - - fn emit_cos( - &mut self, - _data: ast::FlushToZero, - arguments: ast::CosArgs, - ) -> Result<(), TranslateError> { - let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32); - let cos = self.emit_intrinsic( - c"llvm.cos.f32", - Some(arguments.dst), - Some(&ast::ScalarType::F32.into()), - vec![(self.resolver.value(arguments.src)?, llvm_f32)], - )?; - unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) } - Ok(()) - } - - fn emit_or( - &mut self, - _data: ptx_parser::ScalarType, - arguments: ptx_parser::OrArgs, - ) -> Result<(), TranslateError> { - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildOr(self.builder, src1, src2, dst) - }); - Ok(()) - } - - fn emit_xor( - &mut self, - _data: ptx_parser::ScalarType, - arguments: ptx_parser::XorArgs, - ) -> Result<(), TranslateError> { - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildXor(self.builder, src1, src2, dst) - }); - Ok(()) - } - - fn emit_vector_read(&mut self, vec_acccess: VectorRead) -> Result<(), TranslateError> { - let src = self.resolver.value(vec_acccess.vector_src)?; - let index = unsafe { - LLVMConstInt( - get_scalar_type(self.context, ast::ScalarType::B8), - vec_acccess.member as _, - 0, - ) - }; - self.resolver - .with_result(vec_acccess.scalar_dst, |dst| unsafe { - LLVMBuildExtractElement(self.builder, src, index, dst) - }); - Ok(()) - } - - fn emit_vector_write(&mut self, vector_write: VectorWrite) -> Result<(), TranslateError> { - let vector_src = self.resolver.value(vector_write.vector_src)?; - let scalar_src = self.resolver.value(vector_write.scalar_src)?; - let index 
= unsafe { - LLVMConstInt( - get_scalar_type(self.context, ast::ScalarType::B8), - vector_write.member as _, - 0, - ) - }; - self.resolver - .with_result(vector_write.vector_dst, |dst| unsafe { - LLVMBuildInsertElement(self.builder, vector_src, scalar_src, index, dst) - }); - Ok(()) - } - - fn emit_vector_repack(&mut self, repack: RepackVectorDetails) -> Result<(), TranslateError> { - let i8_type = get_scalar_type(self.context, ast::ScalarType::B8); - if repack.is_extract { - let src = self.resolver.value(repack.packed)?; - for (index, dst) in repack.unpacked.iter().enumerate() { - let index: *mut LLVMValue = unsafe { LLVMConstInt(i8_type, index as _, 0) }; - self.resolver.with_result(*dst, |dst| unsafe { - LLVMBuildExtractElement(self.builder, src, index, dst) - }); - } - } else { - let vector_type = get_type( - self.context, - &ast::Type::Vector(repack.unpacked.len() as u8, repack.typ), - )?; - let mut temp_vec = unsafe { LLVMGetUndef(vector_type) }; - for (index, src_id) in repack.unpacked.iter().enumerate() { - let dst = if index == repack.unpacked.len() - 1 { - Some(repack.packed) - } else { - None - }; - let scalar_src = self.resolver.value(*src_id)?; - let index = unsafe { LLVMConstInt(i8_type, index as _, 0) }; - temp_vec = self.resolver.with_result_option(dst, |dst| unsafe { - LLVMBuildInsertElement(self.builder, temp_vec, scalar_src, index, dst) - }); - } - } - Ok(()) - } - - fn emit_div( - &mut self, - data: ptx_parser::DivDetails, - arguments: ptx_parser::DivArgs, - ) -> Result<(), TranslateError> { - let integer_div = match data { - ptx_parser::DivDetails::Unsigned(_) => LLVMBuildUDiv, - ptx_parser::DivDetails::Signed(_) => LLVMBuildSDiv, - ptx_parser::DivDetails::Float(float_div) => { - return self.emit_div_float(float_div, arguments) - } - }; - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - integer_div(self.builder, src1, src2, dst) - }); - Ok(()) - } - - fn emit_div_float( - &mut self, - float_div: ptx_parser::DivFloatDetails, - arguments: ptx_parser::DivArgs, - ) -> Result<(), TranslateError> { - let builder = self.builder; - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - let _rnd = match float_div.kind { - ptx_parser::DivFloatKind::Approx => ast::RoundingMode::NearestEven, - ptx_parser::DivFloatKind::ApproxFull => ast::RoundingMode::NearestEven, - ptx_parser::DivFloatKind::Rounding(rounding_mode) => rounding_mode, - }; - let approx = match float_div.kind { - ptx_parser::DivFloatKind::Approx => { - LLVMZludaFastMathAllowReciprocal | LLVMZludaFastMathApproxFunc - } - ptx_parser::DivFloatKind::ApproxFull => LLVMZludaFastMathNone, - ptx_parser::DivFloatKind::Rounding(_) => LLVMZludaFastMathNone, - }; - let fdiv = self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildFDiv(builder, src1, src2, dst) - }); - unsafe { LLVMZludaSetFastMathFlags(fdiv, approx) }; - if let ptx_parser::DivFloatKind::ApproxFull = float_div.kind { - // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-div: - // div.full.f32 implements a relatively fast, full-range approximation that scales - // operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not - // support rounding modifiers. The maximum ulp error is 2 across the full range of - // inputs. 
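- // As a rough sketch (illustrative, not the verbatim output), the fdiv built below
- // ends up annotated with that 2 ulp bound through !fpmath metadata, i.e. something like:
- //   %q = fdiv float %a, %b, !fpmath !0
- //   !0 = !{float 2.000000e+00}
- // which may let the backend pick a faster, less accurate division expansion.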
- // https://llvm.org/docs/LangRef.html#fpmath-metadata - let fpmath_value = - unsafe { LLVMConstReal(get_scalar_type(self.context, ast::ScalarType::F32), 2.0) }; - let fpmath_value = unsafe { LLVMValueAsMetadata(fpmath_value) }; - let mut md_node_content = [fpmath_value]; - let md_node = unsafe { - LLVMMDNodeInContext2( - self.context, - md_node_content.as_mut_ptr(), - md_node_content.len(), - ) - }; - let md_node = unsafe { LLVMMetadataAsValue(self.context, md_node) }; - let kind = unsafe { - LLVMGetMDKindIDInContext( - self.context, - "fpmath".as_ptr().cast(), - "fpmath".len() as u32, - ) - }; - unsafe { LLVMSetMetadata(fdiv, kind, md_node) }; - } - Ok(()) - } - - fn emit_cvta( - &mut self, - data: ptx_parser::CvtaDetails, - arguments: ptx_parser::CvtaArgs, - ) -> Result<(), TranslateError> { - let (from_space, to_space) = match data.direction { - ptx_parser::CvtaDirection::GenericToExplicit => { - (ast::StateSpace::Generic, data.state_space) - } - ptx_parser::CvtaDirection::ExplicitToGeneric => { - (data.state_space, ast::StateSpace::Generic) - } - }; - let from_type = get_pointer_type(self.context, from_space)?; - let dest_type = get_pointer_type(self.context, to_space)?; - let src = self.resolver.value(arguments.src)?; - let temp_ptr = - unsafe { LLVMBuildIntToPtr(self.builder, src, from_type, LLVM_UNNAMED.as_ptr()) }; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildAddrSpaceCast(self.builder, temp_ptr, dest_type, dst) - }); - Ok(()) - } - - fn emit_sub( - &mut self, - data: ptx_parser::ArithDetails, - arguments: ptx_parser::SubArgs, - ) -> Result<(), TranslateError> { - match data { - ptx_parser::ArithDetails::Integer(arith_integer) => { - self.emit_sub_integer(arith_integer, arguments) - } - ptx_parser::ArithDetails::Float(arith_float) => { - self.emit_sub_float(arith_float, arguments) - } - } - } - - fn emit_sub_integer( - &mut self, - arith_integer: ptx_parser::ArithInteger, - arguments: ptx_parser::SubArgs, - ) -> Result<(), TranslateError> { - if arith_integer.saturate { - let op = if arith_integer.type_.kind() == ast::ScalarKind::Signed { - "ssub" - } else { - "usub" - }; - return self.emit_intrinsic_saturate( - op, - arith_integer.type_, - arguments.dst, - arguments.src1, - arguments.src2, - ); - } - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildSub(self.builder, src1, src2, dst) - }); - Ok(()) - } - - fn emit_sub_float( - &mut self, - _arith_float: ptx_parser::ArithFloat, - arguments: ptx_parser::SubArgs, - ) -> Result<(), TranslateError> { - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildFSub(self.builder, src1, src2, dst) - }); - Ok(()) - } - - fn emit_sin( - &mut self, - _data: ptx_parser::FlushToZero, - arguments: ptx_parser::SinArgs, - ) -> Result<(), TranslateError> { - let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32); - let sin = self.emit_intrinsic( - c"llvm.sin.f32", - Some(arguments.dst), - Some(&ast::ScalarType::F32.into()), - vec![(self.resolver.value(arguments.src)?, llvm_f32)], - )?; - unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) } - Ok(()) - } - - fn emit_intrinsic( - &mut self, - name: &CStr, - dst: Option, - return_type: Option<&ast::Type>, - arguments: Vec<(LLVMValueRef, LLVMTypeRef)>, - ) -> Result { - let fn_type = get_function_type( - self.context, - 
return_type.into_iter(), - arguments.iter().map(|(_, type_)| Ok(*type_)), - )?; - let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; - if fn_ == ptr::null_mut() { - fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; - } - let mut arguments = arguments.iter().map(|(arg, _)| *arg).collect::>(); - Ok(self.resolver.with_result_option(dst, |dst| unsafe { - LLVMBuildCall2( - self.builder, - fn_type, - fn_, - arguments.as_mut_ptr(), - arguments.len() as u32, - dst, - ) - })) - } - - fn emit_neg( - &mut self, - data: ptx_parser::TypeFtz, - arguments: ptx_parser::NegArgs, - ) -> Result<(), TranslateError> { - let src = self.resolver.value(arguments.src)?; - let is_floating_point = data.type_.kind() == ptx_parser::ScalarKind::Float; - let llvm_fn = if is_floating_point { - LLVMBuildFNeg - } else { - LLVMBuildNeg - }; - if is_floating_point && data.flush_to_zero == Some(true) { - let negated = unsafe { llvm_fn(self.builder, src, LLVM_UNNAMED.as_ptr()) }; - let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(data.type_)); - self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, - Some(arguments.dst), - Some(&data.type_.into()), - vec![(negated, get_scalar_type(self.context, data.type_))], - )?; - } else { - self.resolver.with_result(arguments.dst, |dst| unsafe { - llvm_fn(self.builder, src, dst) - }); - } - Ok(()) - } - - fn emit_not( - &mut self, - type_: ptx_parser::ScalarType, - arguments: ptx_parser::NotArgs, - ) -> Result<(), TranslateError> { - let src = self.resolver.value(arguments.src)?; - let type_ = get_scalar_type(self.context, type_); - let constant = unsafe { LLVMConstInt(type_, u64::MAX, 0) }; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildXor(self.builder, src, constant, dst) - }); - Ok(()) - } - - fn emit_setp( - &mut self, - data: ptx_parser::SetpData, - arguments: ptx_parser::SetpArgs, - ) -> Result<(), TranslateError> { - let dst = self.emit_setp_impl(data, arguments.dst2, arguments.src1, arguments.src2)?; - self.resolver.register(arguments.dst1, dst); - Ok(()) - } - - fn emit_setp_impl( - &mut self, - data: ptx_parser::SetpData, - dst2: Option, - src1: SpirvWord, - src2: SpirvWord, - ) -> Result { - if dst2.is_some() { - return Err(error_todo_msg( - "setp with two dst arguments not yet supported", - )); - } - match data.cmp_op { - ptx_parser::SetpCompareOp::Integer(setp_compare_int) => { - self.emit_setp_int(setp_compare_int, src1, src2) - } - ptx_parser::SetpCompareOp::Float(setp_compare_float) => { - self.emit_setp_float(setp_compare_float, src1, src2) - } - } - } - - fn emit_setp_int( - &mut self, - setp: ptx_parser::SetpCompareInt, - src1: SpirvWord, - src2: SpirvWord, - ) -> Result { - let op = match setp { - ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ, - ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE, - ptx_parser::SetpCompareInt::UnsignedLess => LLVMIntPredicate::LLVMIntULT, - ptx_parser::SetpCompareInt::UnsignedLessOrEq => LLVMIntPredicate::LLVMIntULE, - ptx_parser::SetpCompareInt::UnsignedGreater => LLVMIntPredicate::LLVMIntUGT, - ptx_parser::SetpCompareInt::UnsignedGreaterOrEq => LLVMIntPredicate::LLVMIntUGE, - ptx_parser::SetpCompareInt::SignedLess => LLVMIntPredicate::LLVMIntSLT, - ptx_parser::SetpCompareInt::SignedLessOrEq => LLVMIntPredicate::LLVMIntSLE, - ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT, - ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE, - }; - 
let src1 = self.resolver.value(src1)?; - let src2 = self.resolver.value(src2)?; - Ok(unsafe { LLVMBuildICmp(self.builder, op, src1, src2, LLVM_UNNAMED.as_ptr()) }) - } - - fn emit_setp_float( - &mut self, - setp: ptx_parser::SetpCompareFloat, - src1: SpirvWord, - src2: SpirvWord, - ) -> Result { - let op = match setp { - ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ, - ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE, - ptx_parser::SetpCompareFloat::Less => LLVMRealPredicate::LLVMRealOLT, - ptx_parser::SetpCompareFloat::LessOrEq => LLVMRealPredicate::LLVMRealOLE, - ptx_parser::SetpCompareFloat::Greater => LLVMRealPredicate::LLVMRealOGT, - ptx_parser::SetpCompareFloat::GreaterOrEq => LLVMRealPredicate::LLVMRealOGE, - ptx_parser::SetpCompareFloat::NanEq => LLVMRealPredicate::LLVMRealUEQ, - ptx_parser::SetpCompareFloat::NanNotEq => LLVMRealPredicate::LLVMRealUNE, - ptx_parser::SetpCompareFloat::NanLess => LLVMRealPredicate::LLVMRealULT, - ptx_parser::SetpCompareFloat::NanLessOrEq => LLVMRealPredicate::LLVMRealULE, - ptx_parser::SetpCompareFloat::NanGreater => LLVMRealPredicate::LLVMRealUGT, - ptx_parser::SetpCompareFloat::NanGreaterOrEq => LLVMRealPredicate::LLVMRealUGE, - ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD, - ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO, - }; - let src1 = self.resolver.value(src1)?; - let src2 = self.resolver.value(src2)?; - Ok(unsafe { LLVMBuildFCmp(self.builder, op, src1, src2, LLVM_UNNAMED.as_ptr()) }) - } - - fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> { - let predicate = self.resolver.value(cond.predicate)?; - let if_true = self.resolver.value(cond.if_true)?; - let if_false = self.resolver.value(cond.if_false)?; - unsafe { - LLVMBuildCondBr( - self.builder, - predicate, - LLVMValueAsBasicBlock(if_true), - LLVMValueAsBasicBlock(if_false), - ) - }; - Ok(()) - } - - fn emit_cvt( - &mut self, - data: ptx_parser::CvtDetails, - arguments: ptx_parser::CvtArgs, - ) -> Result<(), TranslateError> { - let dst_type = get_scalar_type(self.context, data.to); - let llvm_fn = match data.mode { - ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt, - ptx_parser::CvtMode::SignExtend => LLVMBuildSExt, - ptx_parser::CvtMode::Truncate => LLVMBuildTrunc, - ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast, - ptx_parser::CvtMode::IntSaturateToSigned => { - return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments) - } - ptx_parser::CvtMode::IntSaturateToUnsigned => { - return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments) - } - ptx_parser::CvtMode::FPExtend { .. } => LLVMBuildFPExt, - ptx_parser::CvtMode::FPTruncate { .. } => LLVMBuildFPTrunc, - ptx_parser::CvtMode::FPRound { - integer_rounding: None, - flush_to_zero: None | Some(false), - .. - } => { - return self.emit_mov(ast::MovArgs { - dst: arguments.dst, - src: arguments.src, - }) - } - ptx_parser::CvtMode::FPRound { - integer_rounding: None, - flush_to_zero: Some(true), - .. - } => return self.flush_denormals(data.to, arguments.src, arguments.dst), - ptx_parser::CvtMode::FPRound { - integer_rounding: Some(rounding), - .. - } => return self.emit_cvt_float_to_int(data.from, data.to, rounding, arguments, None), - ptx_parser::CvtMode::SignedFromFP { rounding, .. } => { - return self.emit_cvt_float_to_int( - data.from, - data.to, - rounding, - arguments, - Some(true), - ) - } - ptx_parser::CvtMode::UnsignedFromFP { rounding, .. 
} => { - return self.emit_cvt_float_to_int( - data.from, - data.to, - rounding, - arguments, - Some(false), - ) - } - ptx_parser::CvtMode::FPFromSigned { .. } => { - return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildSIToFP) - } - ptx_parser::CvtMode::FPFromUnsigned { .. } => { - return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildUIToFP) - } - }; - let src = self.resolver.value(arguments.src)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - llvm_fn(self.builder, src, dst_type, dst) - }); - Ok(()) - } - - fn emit_cvt_unsigned_to_signed_sat( - &mut self, - from: ptx_parser::ScalarType, - to: ptx_parser::ScalarType, - arguments: ptx_parser::CvtArgs, - ) -> Result<(), TranslateError> { - let clamped = self.emit_saturate_integer(from, to, &arguments)?; - let resize_fn = if to.layout().size() >= from.layout().size() { - LLVMBuildSExtOrBitCast - } else { - LLVMBuildTrunc - }; - let to_llvm = get_scalar_type(self.context, to); - self.resolver.with_result(arguments.dst, |dst| unsafe { - resize_fn(self.builder, clamped, to_llvm, dst) - }); - Ok(()) - } - - fn emit_saturate_integer( - &mut self, - from: ptx_parser::ScalarType, - to: ptx_parser::ScalarType, - arguments: &ptx_parser::CvtArgs, - ) -> Result { - let from_llvm = get_scalar_type(self.context, from); - match from.kind() { - ptx_parser::ScalarKind::Unsigned => { - let max_value = match to { - ptx_parser::ScalarType::U8 => u8::MAX as u64, - ptx_parser::ScalarType::S8 => i8::MAX as u64, - ptx_parser::ScalarType::U16 => u16::MAX as u64, - ptx_parser::ScalarType::S16 => i16::MAX as u64, - ptx_parser::ScalarType::U32 => u32::MAX as u64, - ptx_parser::ScalarType::S32 => i32::MAX as u64, - ptx_parser::ScalarType::U64 => u64::MAX as u64, - ptx_parser::ScalarType::S64 => i64::MAX as u64, - _ => return Err(error_unreachable()), - }; - let intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from)); - let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) }; - let clamped = self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, - None, - Some(&from.into()), - vec![ - (self.resolver.value(arguments.src)?, from_llvm), - (max, from_llvm), - ], - )?; - Ok(clamped) - } - ptx_parser::ScalarKind::Signed => { - let (min_value_from, max_value_from) = match from { - ptx_parser::ScalarType::S8 => (i8::MIN as i128, i8::MAX as i128), - ptx_parser::ScalarType::S16 => (i16::MIN as i128, i16::MAX as i128), - ptx_parser::ScalarType::S32 => (i32::MIN as i128, i32::MAX as i128), - ptx_parser::ScalarType::S64 => (i64::MIN as i128, i64::MAX as i128), - _ => return Err(error_unreachable()), - }; - let (min_value_to, max_value_to) = match to { - ptx_parser::ScalarType::U8 => (u8::MIN as i128, u8::MAX as i128), - ptx_parser::ScalarType::S8 => (i8::MIN as i128, i8::MAX as i128), - ptx_parser::ScalarType::U16 => (u16::MIN as i128, u16::MAX as i128), - ptx_parser::ScalarType::S16 => (i16::MIN as i128, i16::MAX as i128), - ptx_parser::ScalarType::U32 => (u32::MIN as i128, u32::MAX as i128), - ptx_parser::ScalarType::S32 => (i32::MIN as i128, i32::MAX as i128), - ptx_parser::ScalarType::U64 => (u64::MIN as i128, u64::MAX as i128), - ptx_parser::ScalarType::S64 => (i64::MIN as i128, i64::MAX as i128), - _ => return Err(error_unreachable()), - }; - let min_value = min_value_from.max(min_value_to); - let max_value = max_value_from.min(max_value_to); - let max_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from)); - let min = unsafe { LLVMConstInt(from_llvm, min_value as u64, 1) }; - let min_intrinsic 
= format!("llvm.smin.{}\0", LLVMTypeDisplay(from)); - let max = unsafe { LLVMConstInt(from_llvm, max_value as u64, 1) }; - let clamped = self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(max_intrinsic.as_bytes()) }, - None, - Some(&from.into()), - vec![ - (self.resolver.value(arguments.src)?, from_llvm), - (min, from_llvm), - ], - )?; - let clamped = self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(min_intrinsic.as_bytes()) }, - None, - Some(&from.into()), - vec![(clamped, from_llvm), (max, from_llvm)], - )?; - Ok(clamped) - } - _ => return Err(error_unreachable()), - } - } - - fn emit_cvt_signed_to_unsigned_sat( - &mut self, - from: ptx_parser::ScalarType, - to: ptx_parser::ScalarType, - arguments: ptx_parser::CvtArgs, - ) -> Result<(), TranslateError> { - let clamped = self.emit_saturate_integer(from, to, &arguments)?; - let resize_fn = if to.layout().size() >= from.layout().size() { - LLVMBuildZExtOrBitCast - } else { - LLVMBuildTrunc - }; - let to_llvm = get_scalar_type(self.context, to); - self.resolver.with_result(arguments.dst, |dst| unsafe { - resize_fn(self.builder, clamped, to_llvm, dst) - }); - Ok(()) - } - - fn emit_cvt_float_to_int( - &mut self, - from: ast::ScalarType, - to: ast::ScalarType, - rounding: ast::RoundingMode, - arguments: ptx_parser::CvtArgs, - signed_cast: Option, - ) -> Result<(), TranslateError> { - let dst_int_rounded = - self.emit_fp_int_rounding(from, rounding, &arguments, signed_cast.is_some())?; - // In PTX all the int-from-float casts are saturating casts. On the other hand, in LLVM, - // out-of-range fptoui and fptosi have undefined behavior. - // We could handle this all with llvm.fptosi.sat and llvm.fptoui.sat intrinsics, but - // the problem is that, when using *.sat variants AMDGPU target _always_ emits saturation - // checks. Often they are unnecessary because v_cvt_* instructions saturates anyway. 
- // For that reason, all from-to combinations that we know have a direct corresponding - // v_cvt_* instruction get special treatment - let is_saturating_cast = match (to, from) { - (ast::ScalarType::S16, ast::ScalarType::F16) - | (ast::ScalarType::S32, ast::ScalarType::F32) - | (ast::ScalarType::S32, ast::ScalarType::F64) - | (ast::ScalarType::U16, ast::ScalarType::F16) - | (ast::ScalarType::U32, ast::ScalarType::F32) - | (ast::ScalarType::U32, ast::ScalarType::F64) => true, - _ => false, - }; - let signed_cast = match signed_cast { - Some(s) => s, - None => { - self.resolver.register( - arguments.dst, - dst_int_rounded.ok_or_else(error_unreachable)?, - ); - return Ok(()); - } - }; - if is_saturating_cast { - let to = get_scalar_type(self.context, to); - let src = - dst_int_rounded.unwrap_or_else(|| self.resolver.value(arguments.src).unwrap()); - let llvm_cast = if signed_cast { - LLVMBuildFPToSI - } else { - LLVMBuildFPToUI - }; - let poisoned_dst = unsafe { llvm_cast(self.builder, src, to, LLVM_UNNAMED.as_ptr()) }; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildFreeze(self.builder, poisoned_dst, dst) - }); - } else { - let cvt_op = if to.kind() == ptx_parser::ScalarKind::Unsigned { - "fptoui" - } else { - "fptosi" - }; - let cast_intrinsic = format!( - "llvm.{cvt_op}.sat.{}.{}\0", - LLVMTypeDisplay(to), - LLVMTypeDisplay(from) - ); - let src = - dst_int_rounded.unwrap_or_else(|| self.resolver.value(arguments.src).unwrap()); - self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) }, - Some(arguments.dst), - Some(&to.into()), - vec![(src, get_scalar_type(self.context, from))], - )?; - } - Ok(()) - } - - fn emit_fp_int_rounding( - &mut self, - from: ptx_parser::ScalarType, - rounding: ptx_parser::RoundingMode, - arguments: &ptx_parser::CvtArgs, - will_saturate_with_cvt: bool, - ) -> Result, TranslateError> { - let prefix = match rounding { - ptx_parser::RoundingMode::NearestEven => "llvm.roundeven", - ptx_parser::RoundingMode::Zero => { - // cvt has round-to-zero semantics - if will_saturate_with_cvt { - return Ok(None); - } else { - "llvm.trunc" - } - } - ptx_parser::RoundingMode::NegativeInf => "llvm.floor", - ptx_parser::RoundingMode::PositiveInf => "llvm.ceil", - }; - let intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(from)); - let rounded_float = self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, - None, - Some(&from.into()), - vec![( - self.resolver.value(arguments.src)?, - get_scalar_type(self.context, from), - )], - )?; - Ok(Some(rounded_float)) - } - - fn emit_cvt_int_to_float( - &mut self, - to: ptx_parser::ScalarType, - arguments: ptx_parser::CvtArgs, - llvm_func: unsafe extern "C" fn( - arg1: LLVMBuilderRef, - Val: LLVMValueRef, - DestTy: LLVMTypeRef, - Name: *const i8, - ) -> LLVMValueRef, - ) -> Result<(), TranslateError> { - let type_ = get_scalar_type(self.context, to); - let src = self.resolver.value(arguments.src)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - llvm_func(self.builder, src, type_, dst) - }); - Ok(()) - } - - fn emit_rsqrt( - &mut self, - data: ptx_parser::TypeFtz, - arguments: ptx_parser::RsqrtArgs, - ) -> Result<(), TranslateError> { - let type_ = get_scalar_type(self.context, data.type_); - let intrinsic = match data.type_ { - ast::ScalarType::F32 => c"llvm.amdgcn.rsq.f32", - ast::ScalarType::F64 => c"llvm.amdgcn.rsq.f64", - _ => return Err(error_unreachable()), - }; - self.emit_intrinsic( - intrinsic, - Some(arguments.dst), - 
Some(&data.type_.into()), - vec![(self.resolver.value(arguments.src)?, type_)], - )?; - Ok(()) - } - - fn emit_sqrt( - &mut self, - data: ptx_parser::RcpData, - arguments: ptx_parser::SqrtArgs, - ) -> Result<(), TranslateError> { - let type_ = get_scalar_type(self.context, data.type_); - let intrinsic = match (data.type_, data.kind) { - (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.sqrt.f32", - (ast::ScalarType::F32, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f32", - (ast::ScalarType::F64, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f64", - _ => return Err(error_unreachable()), - }; - self.emit_intrinsic( - intrinsic, - Some(arguments.dst), - Some(&data.type_.into()), - vec![(self.resolver.value(arguments.src)?, type_)], - )?; - Ok(()) - } - - fn emit_rcp( - &mut self, - data: ptx_parser::RcpData, - arguments: ptx_parser::RcpArgs, - ) -> Result<(), TranslateError> { - let type_ = get_scalar_type(self.context, data.type_); - let intrinsic = match (data.type_, data.kind) { - (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.rcp.f32", - (_, ast::RcpKind::Compliant(rnd)) => { - return self.emit_rcp_compliant(data, arguments, rnd) - } - _ => return Err(error_unreachable()), - }; - self.emit_intrinsic( - intrinsic, - Some(arguments.dst), - Some(&data.type_.into()), - vec![(self.resolver.value(arguments.src)?, type_)], - )?; - Ok(()) - } - - fn emit_rcp_compliant( - &mut self, - data: ptx_parser::RcpData, - arguments: ptx_parser::RcpArgs, - _rnd: ast::RoundingMode, - ) -> Result<(), TranslateError> { - let type_ = get_scalar_type(self.context, data.type_); - let one = unsafe { LLVMConstReal(type_, 1.0) }; - let src = self.resolver.value(arguments.src)?; - let rcp = self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildFDiv(self.builder, one, src, dst) - }); - unsafe { LLVMZludaSetFastMathFlags(rcp, LLVMZludaFastMathAllowReciprocal) }; - Ok(()) - } - - fn emit_shr( - &mut self, - data: ptx_parser::ShrData, - arguments: ptx_parser::ShrArgs, - ) -> Result<(), TranslateError> { - let llvm_type = get_scalar_type(self.context, data.type_); - let (out_of_range, shift_fn): ( - *mut LLVMValue, - unsafe extern "C" fn( - LLVMBuilderRef, - LLVMValueRef, - LLVMValueRef, - *const i8, - ) -> LLVMValueRef, - ) = match data.kind { - ptx_parser::RightShiftKind::Logical => { - (unsafe { LLVMConstNull(llvm_type) }, LLVMBuildLShr) - } - ptx_parser::RightShiftKind::Arithmetic => { - let src1 = self.resolver.value(arguments.src1)?; - let shift_size = - unsafe { LLVMConstInt(llvm_type, (data.type_.size_of() as u64 * 8) - 1, 0) }; - let out_of_range = - unsafe { LLVMBuildAShr(self.builder, src1, shift_size, LLVM_UNNAMED.as_ptr()) }; - (out_of_range, LLVMBuildAShr) - } - }; - self.emit_shift( - data.type_, - arguments.dst, - arguments.src1, - arguments.src2, - out_of_range, - shift_fn, - ) - } - - fn emit_shl( - &mut self, - type_: ptx_parser::ScalarType, - arguments: ptx_parser::ShlArgs, - ) -> Result<(), TranslateError> { - let llvm_type = get_scalar_type(self.context, type_); - self.emit_shift( - type_, - arguments.dst, - arguments.src1, - arguments.src2, - unsafe { LLVMConstNull(llvm_type) }, - LLVMBuildShl, - ) - } - - fn emit_shift( - &mut self, - type_: ast::ScalarType, - dst: SpirvWord, - src1: SpirvWord, - src2: SpirvWord, - out_of_range_value: LLVMValueRef, - llvm_fn: unsafe extern "C" fn( - LLVMBuilderRef, - LLVMValueRef, - LLVMValueRef, - *const i8, - ) -> LLVMValueRef, - ) -> Result<(), TranslateError> { - let src1 = self.resolver.value(src1)?; - let shift_size = 
self.resolver.value(src2)?; - let integer_bits = type_.layout().size() * 8; - let integer_bits_constant = unsafe { - LLVMConstInt( - get_scalar_type(self.context, ast::ScalarType::U32), - integer_bits as u64, - 0, - ) - }; - let should_clamp = unsafe { - LLVMBuildICmp( - self.builder, - LLVMIntPredicate::LLVMIntUGE, - shift_size, - integer_bits_constant, - LLVM_UNNAMED.as_ptr(), - ) - }; - let llvm_type = get_scalar_type(self.context, type_); - let normalized_shift_size = if type_.layout().size() >= 4 { - unsafe { - LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) - } - } else { - unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) } - }; - let shifted = unsafe { - llvm_fn( - self.builder, - src1, - normalized_shift_size, - LLVM_UNNAMED.as_ptr(), - ) - }; - self.resolver.with_result(dst, |dst| unsafe { - LLVMBuildSelect(self.builder, should_clamp, out_of_range_value, shifted, dst) - }); - Ok(()) - } - - fn emit_ex2( - &mut self, - data: ptx_parser::TypeFtz, - arguments: ptx_parser::Ex2Args, - ) -> Result<(), TranslateError> { - let intrinsic = match data.type_ { - ast::ScalarType::F16 => c"llvm.amdgcn.exp2.f16", - ast::ScalarType::F32 => c"llvm.amdgcn.exp2.f32", - _ => return Err(error_unreachable()), - }; - self.emit_intrinsic( - intrinsic, - Some(arguments.dst), - Some(&data.type_.into()), - vec![( - self.resolver.value(arguments.src)?, - get_scalar_type(self.context, data.type_), - )], - )?; - Ok(()) - } - - fn emit_lg2( - &mut self, - _data: ptx_parser::FlushToZero, - arguments: ptx_parser::Lg2Args, - ) -> Result<(), TranslateError> { - self.emit_intrinsic( - c"llvm.amdgcn.log.f32", - Some(arguments.dst), - Some(&ast::ScalarType::F32.into()), - vec![( - self.resolver.value(arguments.src)?, - get_scalar_type(self.context, ast::ScalarType::F32.into()), - )], - )?; - Ok(()) - } - - fn emit_selp( - &mut self, - _data: ptx_parser::ScalarType, - arguments: ptx_parser::SelpArgs, - ) -> Result<(), TranslateError> { - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - let src3 = self.resolver.value(arguments.src3)?; - self.resolver.with_result(arguments.dst, |dst_name| unsafe { - LLVMBuildSelect(self.builder, src3, src1, src2, dst_name) - }); - Ok(()) - } - - fn emit_rem( - &mut self, - data: ptx_parser::ScalarType, - arguments: ptx_parser::RemArgs, - ) -> Result<(), TranslateError> { - let llvm_fn = match data.kind() { - ptx_parser::ScalarKind::Unsigned => LLVMBuildURem, - ptx_parser::ScalarKind::Signed => LLVMBuildSRem, - _ => return Err(error_unreachable()), - }; - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - self.resolver.with_result(arguments.dst, |dst_name| unsafe { - llvm_fn(self.builder, src1, src2, dst_name) - }); - Ok(()) - } - - fn emit_popc( - &mut self, - type_: ptx_parser::ScalarType, - arguments: ptx_parser::PopcArgs, - ) -> Result<(), TranslateError> { - let intrinsic = match type_ { - ast::ScalarType::B32 => c"llvm.ctpop.i32", - ast::ScalarType::B64 => c"llvm.ctpop.i64", - _ => return Err(error_unreachable()), - }; - let llvm_type = get_scalar_type(self.context, type_); - self.emit_intrinsic( - intrinsic, - Some(arguments.dst), - Some(&type_.into()), - vec![(self.resolver.value(arguments.src)?, llvm_type)], - )?; - Ok(()) - } - - fn emit_min( - &mut self, - data: ptx_parser::MinMaxDetails, - arguments: ptx_parser::MinArgs, - ) -> Result<(), TranslateError> { - let llvm_prefix = match data { - 
ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin", - ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin", - ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => { - "llvm.minimum" - } - ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum", - }; - let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_())); - let llvm_type = get_scalar_type(self.context, data.type_()); - self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, - Some(arguments.dst), - Some(&data.type_().into()), - vec![ - (self.resolver.value(arguments.src1)?, llvm_type), - (self.resolver.value(arguments.src2)?, llvm_type), - ], - )?; - Ok(()) - } - - fn emit_max( - &mut self, - data: ptx_parser::MinMaxDetails, - arguments: ptx_parser::MaxArgs, - ) -> Result<(), TranslateError> { - let llvm_prefix = match data { - ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax", - ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax", - ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => { - "llvm.maximum" - } - ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum", - }; - let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_())); - let llvm_type = get_scalar_type(self.context, data.type_()); - self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, - Some(arguments.dst), - Some(&data.type_().into()), - vec![ - (self.resolver.value(arguments.src1)?, llvm_type), - (self.resolver.value(arguments.src2)?, llvm_type), - ], - )?; - Ok(()) - } - - fn emit_fma( - &mut self, - data: ptx_parser::ArithFloat, - arguments: ptx_parser::FmaArgs, - ) -> Result<(), TranslateError> { - let intrinsic = format!("llvm.fma.{}\0", LLVMTypeDisplay(data.type_)); - self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, - Some(arguments.dst), - Some(&data.type_.into()), - vec![ - ( - self.resolver.value(arguments.src1)?, - get_scalar_type(self.context, data.type_), - ), - ( - self.resolver.value(arguments.src2)?, - get_scalar_type(self.context, data.type_), - ), - ( - self.resolver.value(arguments.src3)?, - get_scalar_type(self.context, data.type_), - ), - ], - )?; - Ok(()) - } - - fn emit_mad( - &mut self, - data: ptx_parser::MadDetails, - arguments: ptx_parser::MadArgs, - ) -> Result<(), TranslateError> { - let mul_control = match data { - ptx_parser::MadDetails::Float(mad_float) => { - return self.emit_fma( - mad_float, - ast::FmaArgs { - dst: arguments.dst, - src1: arguments.src1, - src2: arguments.src2, - src3: arguments.src3, - }, - ) - } - ptx_parser::MadDetails::Integer { - saturate: true, - control: ast::MulIntControl::High, - type_: ast::ScalarType::S32, - } => { - return self.emit_mad_hi_sat_s32( - arguments.dst, - (arguments.src1, arguments.src2, arguments.src3), - ); - } - ptx_parser::MadDetails::Integer { saturate: true, .. } => { - return Err(error_unreachable()) - } - ptx_parser::MadDetails::Integer { type_, control, .. 
} => { - ast::MulDetails::Integer { control, type_ } - } - }; - let temp = self.emit_mul_impl(mul_control, None, arguments.src1, arguments.src2)?; - let src3 = self.resolver.value(arguments.src3)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildAdd(self.builder, temp, src3, dst) - }); - Ok(()) - } - - fn emit_membar(&self, data: ptx_parser::MemScope) -> Result<(), TranslateError> { - unsafe { - LLVMZludaBuildFence( - self.builder, - LLVMAtomicOrdering::LLVMAtomicOrderingSequentiallyConsistent, - get_scope_membar(data)?, - LLVM_UNNAMED.as_ptr(), - ) - }; - Ok(()) - } - - fn emit_prmt( - &mut self, - control: u16, - arguments: ptx_parser::PrmtArgs, - ) -> Result<(), TranslateError> { - let components = [ - (control >> 0) & 0b1111, - (control >> 4) & 0b1111, - (control >> 8) & 0b1111, - (control >> 12) & 0b1111, - ]; - if components.iter().any(|&c| c > 7) { - return Err(TranslateError::Todo("".to_string())); - } - let u32_type = get_scalar_type(self.context, ast::ScalarType::U32); - let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?; - let mut components = [ - unsafe { LLVMConstInt(u32_type, components[0] as _, 0) }, - unsafe { LLVMConstInt(u32_type, components[1] as _, 0) }, - unsafe { LLVMConstInt(u32_type, components[2] as _, 0) }, - unsafe { LLVMConstInt(u32_type, components[3] as _, 0) }, - ]; - let components_indices = - unsafe { LLVMConstVector(components.as_mut_ptr(), components.len() as u32) }; - let src1 = self.resolver.value(arguments.src1)?; - let src1_vector = - unsafe { LLVMBuildBitCast(self.builder, src1, v4u8_type, LLVM_UNNAMED.as_ptr()) }; - let src2 = self.resolver.value(arguments.src2)?; - let src2_vector = - unsafe { LLVMBuildBitCast(self.builder, src2, v4u8_type, LLVM_UNNAMED.as_ptr()) }; - self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildShuffleVector( - self.builder, - src1_vector, - src2_vector, - components_indices, - dst, - ) - }); - Ok(()) - } - - fn emit_abs( - &mut self, - data: ast::TypeFtz, - arguments: ptx_parser::AbsArgs, - ) -> Result<(), TranslateError> { - let llvm_type = get_scalar_type(self.context, data.type_); - let src = self.resolver.value(arguments.src)?; - let is_floating_point = data.type_.kind() == ast::ScalarKind::Float; - let (prefix, intrinsic_arguments) = if is_floating_point { - ("llvm.fabs", vec![(src, llvm_type)]) - } else { - let pred = get_scalar_type(self.context, ast::ScalarType::Pred); - let zero = unsafe { LLVMConstInt(pred, 0, 0) }; - ("llvm.abs", vec![(src, llvm_type), (zero, pred)]) - }; - let llvm_intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(data.type_)); - let abs_result = self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) }, - None, - Some(&data.type_.into()), - intrinsic_arguments, - )?; - if is_floating_point && data.flush_to_zero == Some(true) { - let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(data.type_)); - self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, - Some(arguments.dst), - Some(&data.type_.into()), - vec![(abs_result, llvm_type)], - )?; - } else { - self.resolver.register(arguments.dst, abs_result); - } - Ok(()) - } - - fn emit_mul24( - &mut self, - data: ast::Mul24Details, - arguments: ast::Mul24Args, - ) -> Result<(), TranslateError> { - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; - let name_lo = match data.type_ { - ast::ScalarType::U32 => c"llvm.amdgcn.mul.u24", - 
ast::ScalarType::S32 => c"llvm.amdgcn.mul.i24", - _ => return Err(error_unreachable()), - }; - let res_lo = self.emit_intrinsic( - name_lo, - if data.control == Mul24Control::Lo { - Some(arguments.dst) - } else { - None - }, - Some(&ast::Type::Scalar(data.type_)), - vec![ - (src1, get_scalar_type(self.context, data.type_)), - (src2, get_scalar_type(self.context, data.type_)), - ], - )?; - if data.control == Mul24Control::Hi { - // There is an important difference between NVIDIA's mul24.hi and AMD's mulhi.[ui]24. - // NVIDIA: Returns bits 47..16 of the 64-bit result - // AMD: Returns bits 63..32 of the 64-bit result - // Hence we need to compute both hi and lo, shift the results and add them together to replicate NVIDIA's mul24 - let name_hi = match data.type_ { - ast::ScalarType::U32 => c"llvm.amdgcn.mulhi.u24", - ast::ScalarType::S32 => c"llvm.amdgcn.mulhi.i24", - _ => return Err(error_unreachable()), - }; - let res_hi = self.emit_intrinsic( - name_hi, - None, - Some(&ast::Type::Scalar(data.type_)), - vec![ - (src1, get_scalar_type(self.context, data.type_)), - (src2, get_scalar_type(self.context, data.type_)), - ], - )?; - let shift_number = unsafe { LLVMConstInt(LLVMInt32TypeInContext(self.context), 16, 0) }; - let res_lo_shr = - unsafe { LLVMBuildLShr(self.builder, res_lo, shift_number, LLVM_UNNAMED.as_ptr()) }; - let res_hi_shl = - unsafe { LLVMBuildShl(self.builder, res_hi, shift_number, LLVM_UNNAMED.as_ptr()) }; - - self.resolver - .with_result(arguments.dst, |dst: *const i8| unsafe { - LLVMBuildOr(self.builder, res_lo_shr, res_hi_shl, dst) - }); - } - Ok(()) - } - - fn emit_set_mode(&mut self, mode_reg: ModeRegister) -> Result<(), TranslateError> { - fn hwreg(reg: u32, offset: u32, size: u32) -> u32 { - reg | (offset << 6) | ((size - 1) << 11) - } - fn denormal_to_value(ftz: bool) -> u32 { - if ftz { - 0 - } else { - 3 - } - } - fn rounding_to_value(ftz: ast::RoundingMode) -> u32 { - match ftz { - ptx_parser::RoundingMode::NearestEven => 0, - ptx_parser::RoundingMode::Zero => 3, - ptx_parser::RoundingMode::NegativeInf => 2, - ptx_parser::RoundingMode::PositiveInf => 1, - } - } - fn merge_regs(f32: u32, f16f64: u32) -> u32 { - f32 | f16f64 << 2 - } - let intrinsic = c"llvm.amdgcn.s.setreg"; - let (hwreg, value) = match mode_reg { - ModeRegister::Denormal { f32, f16f64 } => { - let hwreg = hwreg(1, 4, 4); - let f32 = denormal_to_value(f32); - let f16f64 = denormal_to_value(f16f64); - let value = merge_regs(f32, f16f64); - (hwreg, value) - } - ModeRegister::Rounding { f32, f16f64 } => { - let hwreg = hwreg(1, 0, 4); - let f32 = rounding_to_value(f32); - let f16f64 = rounding_to_value(f16f64); - let value = merge_regs(f32, f16f64); - (hwreg, value) - } - }; - let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32); - let hwreg_llvm = unsafe { LLVMConstInt(llvm_i32, hwreg as _, 0) }; - let value_llvm = unsafe { LLVMConstInt(llvm_i32, value as _, 0) }; - self.emit_intrinsic( - intrinsic, - None, - None, - vec![(hwreg_llvm, llvm_i32), (value_llvm, llvm_i32)], - )?; - Ok(()) - } - - fn emit_fp_saturate( - &mut self, - type_: ast::ScalarType, - dst: SpirvWord, - src: SpirvWord, - ) -> Result<(), TranslateError> { - let llvm_type = get_scalar_type(self.context, type_); - let zero = unsafe { LLVMConstReal(llvm_type, 0.0) }; - let one = unsafe { LLVMConstReal(llvm_type, 1.0) }; - let maxnum_intrinsic = format!("llvm.maxnum.{}\0", LLVMTypeDisplay(type_)); - let minnum_intrinsic = format!("llvm.minnum.{}\0", LLVMTypeDisplay(type_)); - let src = self.resolver.value(src)?; - let maxnum 
= self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(maxnum_intrinsic.as_bytes()) }, - None, - Some(&type_.into()), - vec![(src, llvm_type), (zero, llvm_type)], - )?; - self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(minnum_intrinsic.as_bytes()) }, - Some(dst), - Some(&type_.into()), - vec![(maxnum, llvm_type), (one, llvm_type)], - )?; - Ok(()) - } - - fn emit_intrinsic_saturate( - &mut self, - op: &str, - type_: ast::ScalarType, - dst: SpirvWord, - src1: SpirvWord, - src2: SpirvWord, - ) -> Result<(), TranslateError> { - let llvm_type = get_scalar_type(self.context, type_); - let src1 = self.resolver.value(src1)?; - let src2 = self.resolver.value(src2)?; - let intrinsic = format!("llvm.{}.sat.{}\0", op, LLVMTypeDisplay(type_)); - self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, - Some(dst), - Some(&type_.into()), - vec![(src1, llvm_type), (src2, llvm_type)], - )?; - Ok(()) - } - - fn emit_cp_async( - &mut self, - data: CpAsyncDetails, - arguments: CpAsyncArgs, - ) -> Result<(), TranslateError> { - // Asynchronous copies are not supported by all AMD hardware, so we just do a synchronous copy for now - let to = self.resolver.value(arguments.src_to)?; - let from = self.resolver.value(arguments.src_from)?; - let cp_size = data.cp_size; - let src_size = data.src_size.unwrap_or(cp_size.as_u64()); - - let from_type = unsafe { LLVMIntTypeInContext(self.context, (src_size as u32) * 8) }; - - let to_type = match cp_size { - ptx_parser::CpAsyncCpSize::Bytes4 => unsafe { LLVMInt32TypeInContext(self.context) }, - ptx_parser::CpAsyncCpSize::Bytes8 => unsafe { LLVMInt64TypeInContext(self.context) }, - ptx_parser::CpAsyncCpSize::Bytes16 => unsafe { LLVMInt128TypeInContext(self.context) }, - }; - - let load = unsafe { LLVMBuildLoad2(self.builder, from_type, from, LLVM_UNNAMED.as_ptr()) }; - unsafe { - LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8); - } - - let extended = unsafe { LLVMBuildZExt(self.builder, load, to_type, LLVM_UNNAMED.as_ptr()) }; - - unsafe { LLVMBuildStore(self.builder, extended, to) }; - unsafe { - LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8); - } - Ok(()) - } - - - fn flush_denormals( - &mut self, - type_: ptx_parser::ScalarType, - src: SpirvWord, - dst: SpirvWord, - ) -> Result<(), TranslateError> { - let llvm_type = get_scalar_type(self.context, type_); - let src = self.resolver.value(src)?; - let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(type_)); - self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, - Some(dst), - Some(&type_.into()), - vec![(src, llvm_type)], - )?; - Ok(()) - } - - fn emit_mad_hi_sat_s32( - &mut self, - dst: SpirvWord, - (src1, src2, src3): (SpirvWord, SpirvWord, SpirvWord), - ) -> Result<(), TranslateError> { - let src1 = self.resolver.value(src1)?; - let src2 = self.resolver.value(src2)?; - let src3 = self.resolver.value(src3)?; - let llvm_type_s32 = get_scalar_type(self.context, ast::ScalarType::S32); - let llvm_type_s64 = get_scalar_type(self.context, ast::ScalarType::S64); - let src1_wide = - unsafe { LLVMBuildSExt(self.builder, src1, llvm_type_s64, LLVM_UNNAMED.as_ptr()) }; - let src2_wide = - unsafe { LLVMBuildSExt(self.builder, src2, llvm_type_s64, LLVM_UNNAMED.as_ptr()) }; - let mul_wide = - unsafe { LLVMBuildMul(self.builder, src1_wide, src2_wide, LLVM_UNNAMED.as_ptr()) }; - let const_32 = unsafe { LLVMConstInt(llvm_type_s64, 32, 0) }; - let mul_wide = - unsafe { LLVMBuildLShr(self.builder, 
mul_wide, const_32, LLVM_UNNAMED.as_ptr()) }; - let mul_narrow = - unsafe { LLVMBuildTrunc(self.builder, mul_wide, llvm_type_s32, LLVM_UNNAMED.as_ptr()) }; - self.emit_intrinsic( - c"llvm.sadd.sat.i32", - Some(dst), - Some(&ast::ScalarType::S32.into()), - vec![(mul_narrow, llvm_type_s32), (src3, llvm_type_s32)], - )?; - Ok(()) - } - - fn emit_set( - &mut self, - data: ptx_parser::SetData, - arguments: ptx_parser::SetArgs, - ) -> Result<(), TranslateError> { - let setp_result = self.emit_setp_impl(data.base, None, arguments.src1, arguments.src2)?; - self.setp_to_set(arguments.dst, data.dtype, setp_result)?; - Ok(()) - } - - fn emit_set_bool( - &mut self, - data: ptx_parser::SetBoolData, - arguments: ptx_parser::SetBoolArgs, - ) -> Result<(), TranslateError> { - let result = - self.emit_setp_bool_impl(data.base, arguments.src1, arguments.src2, arguments.src3)?; - self.setp_to_set(arguments.dst, data.dtype, result)?; - Ok(()) - } - - fn emit_setp_bool( - &mut self, - data: ast::SetpBoolData, - args: ast::SetpBoolArgs, - ) -> Result<(), TranslateError> { - let dst = self.emit_setp_bool_impl(data, args.src1, args.src2, args.src3)?; - self.resolver.register(args.dst1, dst); - Ok(()) - } - - fn emit_setp_bool_impl( - &mut self, - data: ptx_parser::SetpBoolData, - src1: SpirvWord, - src2: SpirvWord, - src3: SpirvWord, - ) -> Result { - let bool_result = self.emit_setp_impl(data.base, None, src1, src2)?; - let bool_result = if data.negate_src3 { - let constant = - unsafe { LLVMConstInt(LLVMIntTypeInContext(self.context, 1), u64::MAX, 0) }; - unsafe { LLVMBuildXor(self.builder, bool_result, constant, LLVM_UNNAMED.as_ptr()) } - } else { - bool_result - }; - let post_op = match data.bool_op { - ptx_parser::SetpBoolPostOp::Xor => LLVMBuildXor, - ptx_parser::SetpBoolPostOp::And => LLVMBuildAnd, - ptx_parser::SetpBoolPostOp::Or => LLVMBuildOr, - }; - let src3 = self.resolver.value(src3)?; - Ok(unsafe { post_op(self.builder, bool_result, src3, LLVM_UNNAMED.as_ptr()) }) - } - - fn setp_to_set( - &mut self, - dst: SpirvWord, - dtype: ast::ScalarType, - setp_result: LLVMValueRef, - ) -> Result<(), TranslateError> { - let llvm_dtype = get_scalar_type(self.context, dtype); - let zero = unsafe { LLVMConstNull(llvm_dtype) }; - let one = if dtype.kind() == ast::ScalarKind::Float { - unsafe { LLVMConstReal(llvm_dtype, 1.0) } - } else { - unsafe { LLVMConstInt(llvm_dtype, u64::MAX, 0) } - }; - self.resolver.with_result(dst, |dst| unsafe { - LLVMBuildSelect(self.builder, setp_result, one, zero, dst) - }); - Ok(()) - } - - // TODO: revisit this on gfx1250 which has native tanh support - fn emit_tanh( - &mut self, - data: ast::ScalarType, - arguments: ast::TanhArgs, - ) -> Result<(), TranslateError> { - let src = self.resolver.value(arguments.src)?; - let llvm_type = get_scalar_type(self.context, data); - let name = format!("__ocml_tanh_{}\0", LLVMTypeDisplay(data)); - let tanh = self.emit_intrinsic( - unsafe { CStr::from_bytes_with_nul_unchecked(name.as_bytes()) }, - Some(arguments.dst), - Some(&data.into()), - vec![(src, llvm_type)], - )?; - // Not sure if it ultimately does anything - unsafe { LLVMZludaSetFastMathFlags(tanh, LLVMZludaFastMathApproxFunc) } - Ok(()) - } - - /* - // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` - // Should be available in LLVM 19 - fn with_rounding(&mut self, rnd: ast::RoundingMode, fn_: impl FnOnce(&mut Self) -> T) -> T { - let mut u32_type = get_scalar_type(self.context, ast::ScalarType::U32); - let void_type = unsafe { LLVMVoidTypeInContext(self.context) 
}; - let get_rounding = c"llvm.get.rounding"; - let get_rounding_fn_type = unsafe { LLVMFunctionType(u32_type, ptr::null_mut(), 0, 0) }; - let mut get_rounding_fn = - unsafe { LLVMGetNamedFunction(self.module, get_rounding.as_ptr()) }; - if get_rounding_fn == ptr::null_mut() { - get_rounding_fn = unsafe { - LLVMAddFunction(self.module, get_rounding.as_ptr(), get_rounding_fn_type) - }; - } - let set_rounding = c"llvm.set.rounding"; - let set_rounding_fn_type = unsafe { LLVMFunctionType(void_type, &mut u32_type, 1, 0) }; - let mut set_rounding_fn = - unsafe { LLVMGetNamedFunction(self.module, set_rounding.as_ptr()) }; - if set_rounding_fn == ptr::null_mut() { - set_rounding_fn = unsafe { - LLVMAddFunction(self.module, set_rounding.as_ptr(), set_rounding_fn_type) - }; - } - let mut preserved_rounding_mode = unsafe { - LLVMBuildCall2( - self.builder, - get_rounding_fn_type, - get_rounding_fn, - ptr::null_mut(), - 0, - LLVM_UNNAMED.as_ptr(), - ) - }; - let mut requested_rounding = unsafe { - LLVMConstInt( - get_scalar_type(self.context, ast::ScalarType::B32), - rounding_to_llvm(rnd) as u64, - 0, - ) - }; - unsafe { - LLVMBuildCall2( - self.builder, - set_rounding_fn_type, - set_rounding_fn, - &mut requested_rounding, - 1, - LLVM_UNNAMED.as_ptr(), - ) - }; - let result = fn_(self); - unsafe { - LLVMBuildCall2( - self.builder, - set_rounding_fn_type, - set_rounding_fn, - &mut preserved_rounding_mode, - 1, - LLVM_UNNAMED.as_ptr(), - ) - }; - result - } - */ -} - -fn get_pointer_type<'ctx>( - context: LLVMContextRef, - to_space: ast::StateSpace, -) -> Result { - Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) }) -} - -// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes -fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> { - Ok(match scope { - ast::MemScope::Cta => c"workgroup-one-as", - ast::MemScope::Gpu => c"agent-one-as", - ast::MemScope::Sys => c"one-as", - ast::MemScope::Cluster => todo!(), - } - .as_ptr()) -} - -fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> { - Ok(match scope { - ast::MemScope::Cta => c"workgroup", - ast::MemScope::Gpu => c"agent", - ast::MemScope::Sys => c"", - ast::MemScope::Cluster => todo!(), - } - .as_ptr()) -} - -fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering { - match semantics { - ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic, - ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, - ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingRelease, - ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquireRelease, - } -} - -fn get_ordering_failure(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering { - match semantics { - ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic, - ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, - ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, - ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, - } -} - -fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result { - Ok(match type_ { - ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar), - ast::Type::Vector(size, scalar) => { - let base_type = get_scalar_type(context, *scalar); - unsafe { LLVMVectorType(base_type, *size as u32) } - } - ast::Type::Array(vec, scalar, dimensions) => { - let mut underlying_type = get_scalar_type(context, *scalar); - if let Some(size) = vec 
{ - underlying_type = unsafe { LLVMVectorType(underlying_type, size.get() as u32) }; - } - if dimensions.is_empty() { - return Ok(unsafe { LLVMArrayType2(underlying_type, 0) }); - } - dimensions - .iter() - .rfold(underlying_type, |result, dimension| unsafe { - LLVMArrayType2(result, *dimension as u64) - }) - } - }) -} - -fn get_array_type<'a>( - context: LLVMContextRef, - elem_type: &'a ast::Type, - count: u64, -) -> Result { - let elem_type = get_type(context, elem_type)?; - Ok(unsafe { LLVMArrayType2(elem_type, count) }) -} - -fn check_multiple_return_types<'a>( - mut return_args: impl ExactSizeIterator, -) -> Result<(), TranslateError> { - let err_msg = "Only (.b32, .pred) multiple return types are supported"; - - let first = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?; - let second = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?; - match (first, second) { - (ast::Type::Scalar(first), ast::Type::Scalar(second)) => { - if first.size_of() != 4 || second.size_of() != 1 { - return Err(error_todo_msg(err_msg)); - } - } - _ => return Err(error_todo_msg(err_msg)), - } - Ok(()) -} - -fn get_function_type<'a>( - context: LLVMContextRef, - mut return_args: impl ExactSizeIterator, - input_args: impl ExactSizeIterator>, -) -> Result { - let mut input_args = input_args.collect::, _>>()?; - let return_type = match return_args.len() { - 0 => unsafe { LLVMVoidTypeInContext(context) }, - 1 => get_type(context, &return_args.next().unwrap())?, - _ => { - check_multiple_return_types(return_args)?; - get_array_type(context, &ast::Type::Scalar(ast::ScalarType::B32), 2)? - } - }; - - Ok(unsafe { - LLVMFunctionType( - return_type, - input_args.as_mut_ptr(), - input_args.len() as u32, - 0, - ) - }) -} - -struct ResolveIdent { - words: HashMap, - values: HashMap, -} - -impl ResolveIdent { - fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self { - ResolveIdent { - words: HashMap::new(), - values: HashMap::new(), - } - } - - fn get_or_ad_impl<'a, T>(&'a mut self, word: SpirvWord, fn_: impl FnOnce(&'a str) -> T) -> T { - let str = match self.words.entry(word) { - hash_map::Entry::Occupied(entry) => entry.into_mut(), - hash_map::Entry::Vacant(entry) => { - let mut text = word.0.to_string(); - text.push('\0'); - entry.insert(text) - } - }; - fn_(&str[..str.len() - 1]) - } - - fn get_or_add(&mut self, word: SpirvWord) -> &str { - self.get_or_ad_impl(word, |x| x) - } - - fn get_or_add_raw(&mut self, word: SpirvWord) -> *const i8 { - self.get_or_add(word).as_ptr().cast() - } - - fn register(&mut self, word: SpirvWord, v: LLVMValueRef) { - self.values.insert(word, v); - } - - fn value(&self, word: SpirvWord) -> Result { - self.values - .get(&word) - .copied() - .ok_or_else(|| error_unreachable()) - } - - fn with_result( - &mut self, - word: SpirvWord, - fn_: impl FnOnce(*const i8) -> LLVMValueRef, - ) -> LLVMValueRef { - let t = self.get_or_ad_impl(word, |dst| fn_(dst.as_ptr().cast())); - self.register(word, t); - t - } - - fn with_result_option( - &mut self, - word: Option, - fn_: impl FnOnce(*const i8) -> LLVMValueRef, - ) -> LLVMValueRef { - match word { - Some(word) => self.with_result(word, fn_), - None => fn_(LLVM_UNNAMED.as_ptr()), - } - } -} - -struct LLVMTypeDisplay(ast::ScalarType); - -impl std::fmt::Display for LLVMTypeDisplay { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.0 { - ast::ScalarType::Pred => write!(f, "i1"), - ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"), - ast::ScalarType::B16 
| ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"),
-            ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"),
-            ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"),
-            ptx_parser::ScalarType::B128 => write!(f, "i128"),
-            ast::ScalarType::F16 => write!(f, "f16"),
-            ptx_parser::ScalarType::BF16 => write!(f, "bfloat"),
-            ast::ScalarType::F32 => write!(f, "f32"),
-            ast::ScalarType::F64 => write!(f, "f64"),
-            ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"),
-            ast::ScalarType::F16x2 => write!(f, "v2f16"),
-            ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"),
-        }
-    }
-}
-
-/*
-fn rounding_to_llvm(this: ast::RoundingMode) -> u32 {
-    match this {
-        ptx_parser::RoundingMode::Zero => 0,
-        ptx_parser::RoundingMode::NearestEven => 1,
-        ptx_parser::RoundingMode::PositiveInf => 2,
-        ptx_parser::RoundingMode::NegativeInf => 3,
-    }
-}
-*/
+// We use raw LLVM-C bindings here because using inkwell is just not worth it.
+// Specifically the issue is with builder functions. We maintain the mapping
+// between ZLUDA identifiers and LLVM values. When using inkwell, LLVM values
+// are kept as instances of `AnyValueEnum`. Now look at the signature of
+// `Builder::build_int_add(...)`:
+// pub fn build_int_add<T: IntMathValue<'ctx>>(&self, lhs: T, rhs: T, name: &str) -> Result<T, BuilderError>
+// At this point both lhs and rhs are `AnyValueEnum`. To call
+// `build_int_add(...)` we would have to do something like this:
+// if let (Ok(lhs), Ok(rhs)) = (lhs.as_int(), rhs.as_int()) {
+//     builder.build_int_add(lhs, rhs, dst)?;
+// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_pointer(), rhs.as_pointer()) {
+//     builder.build_int_add(lhs, rhs, dst)?;
+// } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_vector(), rhs.as_vector()) {
+//     builder.build_int_add(lhs, rhs, dst)?;
+// } else {
+//     return Err(error_unreachable());
+// }
+// while with plain LLVM-C it's just:
+// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
+
+// AMDGPU LLVM backend support for llvm.experimental.constrained.* is incomplete.
+// Emitting @llvm.experimental.constrained.fdiv.f32(...) makes LLVM fail with
+// "LLVM ERROR: unsupported libcall legalization". Running with "-mllvm -print-before-all"
+// shows it fails inside amdgpu-isel.
You can get a little bit furthr with "-mllvm -global-isel", +// but it will too fail similarly, but with "unable to legalize instruction" + +use std::array::TryFromSliceError; +use std::convert::TryInto; +use std::ffi::{CStr, NulError}; +use std::{i8, ptr, u64}; + +use super::*; +use crate::pass::*; +use llvm_zluda::{core::*, *}; +use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; +use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca}; +use ptx_parser::{CpAsyncArgs, CpAsyncDetails, Mul24Control}; + +struct Builder(LLVMBuilderRef); + +impl Builder { + fn new(ctx: &Context) -> Self { + Self::new_raw(ctx.get()) + } + + fn new_raw(ctx: LLVMContextRef) -> Self { + Self(unsafe { LLVMCreateBuilderInContext(ctx) }) + } + + fn get(&self) -> LLVMBuilderRef { + self.0 + } +} + +impl Drop for Builder { + fn drop(&mut self) { + unsafe { + LLVMDisposeBuilder(self.0); + } + } +} + +pub(crate) fn run<'input>( + context: &Context, + id_defs: GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result { + let module = llvm::Module::new(context, LLVM_UNNAMED); + let mut emit_ctx = ModuleEmitContext::new(context, &module, &id_defs); + for directive in directives { + match directive { + Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?, + Directive2::Method(method) => emit_ctx.emit_method(method)?, + } + } + if let Err(err) = module.verify() { + panic!("{:?}", err); + } + Ok(module) +} + +struct ModuleEmitContext<'a, 'input> { + context: LLVMContextRef, + module: LLVMModuleRef, + builder: Builder, + id_defs: &'a GlobalStringIdentResolver2<'input>, + resolver: ResolveIdent, +} + +impl<'a, 'input> ModuleEmitContext<'a, 'input> { + fn new( + context: &Context, + module: &llvm::Module, + id_defs: &'a GlobalStringIdentResolver2<'input>, + ) -> Self { + ModuleEmitContext { + context: context.get(), + module: module.get(), + builder: Builder::new(context), + id_defs, + resolver: ResolveIdent::new(&id_defs), + } + } + + fn kernel_call_convention() -> u32 { + LLVMCallConv::LLVMAMDGPUKERNELCallConv as u32 + } + + fn func_call_convention() -> u32 { + LLVMCallConv::LLVMCCallConv as u32 + } + + fn emit_method( + &mut self, + method: Function2, SpirvWord>, + ) -> Result<(), TranslateError> { + let name = method + .import_as + .as_deref() + .or_else(|| self.id_defs.ident_map[&method.name].name.as_deref()) + .ok_or_else(|| error_unreachable())?; + let name = CString::new(name).map_err(|_| error_unreachable())?; + let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; + if fn_ == ptr::null_mut() { + let fn_type = get_function_type( + self.context, + method.return_arguments.iter().map(|v| &v.v_type), + method + .input_arguments + .iter() + .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)), + )?; + fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; + self.emit_fn_attribute(fn_, "amdgpu-unsafe-fp-atomics", "true"); + self.emit_fn_attribute(fn_, "uniform-work-group-size", "true"); + self.emit_fn_attribute(fn_, "no-trapping-math", "true"); + } + if !method.is_kernel { + self.resolver.register(method.name, fn_); + self.emit_fn_attribute(fn_, "denormal-fp-math-f32", "dynamic"); + self.emit_fn_attribute(fn_, "denormal-fp-math", "dynamic"); + } else { + self.emit_fn_attribute( + fn_, + "denormal-fp-math-f32", + llvm_ftz(method.flush_to_zero_f32), + ); + self.emit_fn_attribute( + fn_, + "denormal-fp-math", + llvm_ftz(method.flush_to_zero_f16f64), + ); + } + for (i, param) in method.input_arguments.iter().enumerate() 
{ + let value = unsafe { LLVMGetParam(fn_, i as u32) }; + let name = self.resolver.get_or_add(param.name); + unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; + self.resolver.register(param.name, value); + if method.is_kernel { + let attr_kind = unsafe { + LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len()) + }; + let attr = unsafe { + LLVMCreateTypeAttribute( + self.context, + attr_kind, + get_type(self.context, ¶m.v_type)?, + ) + }; + unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) }; + } + } + let call_conv = if method.is_kernel { + Self::kernel_call_convention() + } else { + Self::func_call_convention() + }; + unsafe { LLVMSetFunctionCallConv(fn_, call_conv) }; + if let Some(statements) = method.body { + let variables_bb = + unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) }; + let variables_builder = Builder::new_raw(self.context); + unsafe { LLVMPositionBuilderAtEnd(variables_builder.get(), variables_bb) }; + let real_bb = + unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) }; + unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) }; + let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder); + for var in method.return_arguments { + method_emitter.emit_variable(var)?; + } + for statement in statements.iter() { + if let Statement::Label(label) = statement { + method_emitter.emit_label_initial(*label); + } + } + let mut statements = statements.into_iter(); + if let Some(Statement::Label(label)) = statements.next() { + method_emitter.emit_label_delayed(label)?; + } else { + return Err(error_unreachable()); + } + method_emitter.emit_kernel_rounding_prelude( + method.is_kernel, + method.rounding_mode_f32, + method.rounding_mode_f16f64, + )?; + for statement in statements { + method_emitter.emit_statement(statement)?; + } + unsafe { LLVMBuildBr(method_emitter.variables_builder.get(), real_bb) }; + } + Ok(()) + } + + fn emit_global( + &mut self, + _linking: ast::LinkingDirective, + var: ast::Variable, + ) -> Result<(), TranslateError> { + let name = self + .id_defs + .ident_map + .get(&var.name) + .map(|entry| { + entry + .name + .as_ref() + .map(|text| Ok::<_, NulError>(Cow::Owned(CString::new(&**text)?))) + }) + .flatten() + .transpose() + .map_err(|_| error_unreachable())? 
+ .unwrap_or(Cow::Borrowed(LLVM_UNNAMED)); + let global = unsafe { + LLVMAddGlobalInAddressSpace( + self.module, + get_type(self.context, &var.v_type)?, + name.as_ptr(), + get_state_space(var.state_space)?, + ) + }; + self.resolver.register(var.name, global); + if let Some(align) = var.align { + unsafe { LLVMSetAlignment(global, align) }; + } + if !var.array_init.is_empty() { + self.emit_array_init(&var.v_type, &*var.array_init, global)?; + } + Ok(()) + } + + // TODO: instead of Vec we should emit a typed initializer + fn emit_array_init( + &mut self, + type_: &ast::Type, + array_init: &[u8], + global: *mut llvm_zluda::LLVMValue, + ) -> Result<(), TranslateError> { + match type_ { + ast::Type::Array(None, scalar, dimensions) => { + if dimensions.len() != 1 { + todo!() + } + if dimensions[0] as usize * scalar.size_of() as usize != array_init.len() { + return Err(error_unreachable()); + } + let type_ = get_scalar_type(self.context, *scalar); + let mut elements = array_init + .chunks(scalar.size_of() as usize) + .map(|chunk| self.constant_from_bytes(*scalar, chunk, type_)) + .collect::, _>>() + .map_err(|_| error_unreachable())?; + let initializer = + unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) }; + unsafe { LLVMSetInitializer(global, initializer) }; + } + _ => todo!(), + } + Ok(()) + } + + fn constant_from_bytes( + &self, + scalar: ast::ScalarType, + bytes: &[u8], + llvm_type: LLVMTypeRef, + ) -> Result { + Ok(match scalar { + ptx_parser::ScalarType::Pred + | ptx_parser::ScalarType::S8 + | ptx_parser::ScalarType::B8 + | ptx_parser::ScalarType::U8 => unsafe { + LLVMConstInt(llvm_type, u8::from_le_bytes(bytes.try_into()?) as u64, 0) + }, + ptx_parser::ScalarType::S16 + | ptx_parser::ScalarType::B16 + | ptx_parser::ScalarType::U16 => unsafe { + LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0) + }, + ptx_parser::ScalarType::S32 + | ptx_parser::ScalarType::B32 + | ptx_parser::ScalarType::U32 => unsafe { + LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) as u64, 0) + }, + ptx_parser::ScalarType::F16 => todo!(), + ptx_parser::ScalarType::BF16 => todo!(), + ptx_parser::ScalarType::U64 => todo!(), + ptx_parser::ScalarType::S64 => todo!(), + ptx_parser::ScalarType::S16x2 => todo!(), + ptx_parser::ScalarType::F32 => todo!(), + ptx_parser::ScalarType::B64 => todo!(), + ptx_parser::ScalarType::F64 => todo!(), + ptx_parser::ScalarType::B128 => todo!(), + ptx_parser::ScalarType::U16x2 => todo!(), + ptx_parser::ScalarType::F16x2 => todo!(), + ptx_parser::ScalarType::BF16x2 => todo!(), + }) + } + + fn emit_fn_attribute(&self, llvm_object: LLVMValueRef, key: &str, value: &str) { + let attribute = unsafe { + LLVMCreateStringAttribute( + self.context, + key.as_ptr() as _, + key.len() as u32, + value.as_ptr() as _, + value.len() as u32, + ) + }; + unsafe { LLVMAddAttributeAtIndex(llvm_object, LLVMAttributeFunctionIndex, attribute) }; + } +} + +fn llvm_ftz(ftz: bool) -> &'static str { + if ftz { + "preserve-sign" + } else { + "ieee" + } +} + +fn get_input_argument_type( + context: LLVMContextRef, + v_type: &ast::Type, + state_space: ast::StateSpace, +) -> Result { + match state_space { + ast::StateSpace::ParamEntry => { + Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) 
}) + } + ast::StateSpace::Reg => get_type(context, v_type), + _ => return Err(error_unreachable()), + } +} + +struct MethodEmitContext<'a> { + context: LLVMContextRef, + module: LLVMModuleRef, + method: LLVMValueRef, + builder: LLVMBuilderRef, + variables_builder: Builder, + resolver: &'a mut ResolveIdent, +} + +impl<'a> MethodEmitContext<'a> { + fn new( + parent: &'a mut ModuleEmitContext, + method: LLVMValueRef, + variables_builder: Builder, + ) -> MethodEmitContext<'a> { + MethodEmitContext { + context: parent.context, + module: parent.module, + builder: parent.builder.get(), + variables_builder, + resolver: &mut parent.resolver, + method, + } + } + + fn emit_statement( + &mut self, + statement: Statement, SpirvWord>, + ) -> Result<(), TranslateError> { + Ok(match statement { + Statement::Variable(var) => self.emit_variable(var)?, + Statement::Label(label) => self.emit_label_delayed(label)?, + Statement::Instruction(inst) => self.emit_instruction(inst)?, + Statement::Conditional(cond) => self.emit_conditional(cond)?, + Statement::Conversion(conversion) => self.emit_conversion(conversion)?, + Statement::Constant(constant) => self.emit_constant(constant)?, + Statement::RetValue(_, values) => self.emit_ret_value(values)?, + Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?, + Statement::RepackVector(repack) => self.emit_vector_repack(repack)?, + Statement::FunctionPointer(_) => todo!(), + Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?, + Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?, + Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?, + Statement::FpSaturate { dst, src, type_ } => self.emit_fp_saturate(type_, dst, src)?, + }) + } + + // This should be a kernel attribute, but sadly AMDGPU LLVM target does + // not support attribute for it. 
So we have to set it as the first + // instruction in the body of a kernel + fn emit_kernel_rounding_prelude( + &mut self, + is_kernel: bool, + rounding_mode_f32: ast::RoundingMode, + rounding_mode_f16f64: ast::RoundingMode, + ) -> Result<(), TranslateError> { + if is_kernel { + if rounding_mode_f32 != ast::RoundingMode::NearestEven + || rounding_mode_f16f64 != ast::RoundingMode::NearestEven + { + self.emit_set_mode(ModeRegister::Rounding { + f32: rounding_mode_f32, + f16f64: rounding_mode_f16f64, + })?; + } + } + Ok(()) + } + + fn emit_variable(&mut self, var: ast::Variable) -> Result<(), TranslateError> { + let alloca = unsafe { + LLVMZludaBuildAlloca( + self.variables_builder.get(), + get_type(self.context, &var.v_type)?, + get_state_space(var.state_space)?, + self.resolver.get_or_add_raw(var.name), + ) + }; + self.resolver.register(var.name, alloca); + if let Some(align) = var.align { + unsafe { LLVMSetAlignment(alloca, align) }; + } + if !var.array_init.is_empty() { + todo!() + } + Ok(()) + } + + fn emit_label_initial(&mut self, label: SpirvWord) { + let block = unsafe { + LLVMAppendBasicBlockInContext( + self.context, + self.method, + self.resolver.get_or_add_raw(label), + ) + }; + self.resolver + .register(label, unsafe { LLVMBasicBlockAsValue(block) }); + } + + fn emit_label_delayed(&mut self, label: SpirvWord) -> Result<(), TranslateError> { + let block = self.resolver.value(label)?; + let block = unsafe { LLVMValueAsBasicBlock(block) }; + let last_block = unsafe { LLVMGetInsertBlock(self.builder) }; + if unsafe { LLVMGetBasicBlockTerminator(last_block) } == ptr::null_mut() { + unsafe { LLVMBuildBr(self.builder, block) }; + } + unsafe { LLVMPositionBuilderAtEnd(self.builder, block) }; + Ok(()) + } + + fn emit_instruction( + &mut self, + inst: ast::Instruction, + ) -> Result<(), TranslateError> { + match inst { + ast::Instruction::Mov { data: _, arguments } => self.emit_mov(arguments), + ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments), + ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments), + ast::Instruction::St { data, arguments } => self.emit_st(data, arguments), + ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments), + ast::Instruction::Mul24 { data, arguments } => self.emit_mul24(data, arguments), + ast::Instruction::Set { data, arguments } => self.emit_set(data, arguments), + ast::Instruction::SetBool { data, arguments } => self.emit_set_bool(data, arguments), + ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments), + ast::Instruction::SetpBool { data, arguments } => self.emit_setp_bool(data, arguments), + ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments), + ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments), + ast::Instruction::And { arguments, .. 
} => self.emit_and(arguments), + ast::Instruction::Bra { arguments } => self.emit_bra(arguments), + ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments), + ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments), + ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments), + ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments), + ast::Instruction::Ret { data } => Ok(self.emit_ret(data)), + ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments), + ast::Instruction::Abs { data, arguments } => self.emit_abs(data, arguments), + ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments), + ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments), + ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments), + ast::Instruction::Min { data, arguments } => self.emit_min(data, arguments), + ast::Instruction::Max { data, arguments } => self.emit_max(data, arguments), + ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments), + ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments), + ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments), + ast::Instruction::Selp { data, arguments } => self.emit_selp(data, arguments), + ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments), + ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments), + ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments), + ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments), + ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments), + ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments), + ast::Instruction::Lg2 { data, arguments } => self.emit_lg2(data, arguments), + ast::Instruction::Ex2 { data, arguments } => self.emit_ex2(data, arguments), + ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments), + ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments), + ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments), + ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments), + ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments), + ast::Instruction::PrmtSlow { .. } => { + Err(error_todo_msg("PrmtSlow is not implemented yet")) + } + ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments), + ast::Instruction::Membar { data } => self.emit_membar(data), + ast::Instruction::Trap {} => Err(error_todo_msg("Trap is not implemented yet")), + ast::Instruction::Tanh { data, arguments } => self.emit_tanh(data, arguments), + ast::Instruction::CpAsync { data, arguments } => self.emit_cp_async(data, arguments), + ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop + ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop + ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop + // replaced by a function call + ast::Instruction::Bfe { .. } + | ast::Instruction::Bar { .. } + | ast::Instruction::BarRed { .. } + | ast::Instruction::Bfi { .. } + | ast::Instruction::Activemask { .. } + | ast::Instruction::ShflSync { .. } + | ast::Instruction::Nanosleep { .. 
} => return Err(error_unreachable()), + } + } + + fn emit_ld( + &mut self, + data: ast::LdDetails, + arguments: ast::LdArgs, + ) -> Result<(), TranslateError> { + if data.qualifier != ast::LdStQualifier::Weak { + todo!() + } + let builder = self.builder; + let type_ = get_type(self.context, &data.typ)?; + let ptr = self.resolver.value(arguments.src)?; + self.resolver.with_result(arguments.dst, |dst| { + let load = unsafe { LLVMBuildLoad2(builder, type_, ptr, dst) }; + unsafe { LLVMSetAlignment(load, data.typ.layout().align() as u32) }; + load + }); + Ok(()) + } + + fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> { + let builder = self.builder; + match conversion.kind { + ConversionKind::Default => self.emit_conversion_default( + self.resolver.value(conversion.src)?, + conversion.dst, + &conversion.from_type, + conversion.from_space, + &conversion.to_type, + conversion.to_space, + ), + ConversionKind::SignExtend => { + let src = self.resolver.value(conversion.src)?; + let type_ = get_type(self.context, &conversion.to_type)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildSExt(builder, src, type_, dst) + }); + Ok(()) + } + ConversionKind::BitToPtr => { + let src = self.resolver.value(conversion.src)?; + let type_ = get_pointer_type(self.context, conversion.to_space)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildIntToPtr(builder, src, type_, dst) + }); + Ok(()) + } + ConversionKind::PtrToPtr => { + let src = self.resolver.value(conversion.src)?; + let dst_type = get_pointer_type(self.context, conversion.to_space)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildAddrSpaceCast(builder, src, dst_type, dst) + }); + Ok(()) + } + ConversionKind::AddressOf => { + let src = self.resolver.value(conversion.src)?; + let dst_type = get_type(self.context, &conversion.to_type)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildPtrToInt(self.builder, src, dst_type, dst) + }); + Ok(()) + } + } + } + + fn emit_conversion_default( + &mut self, + src: LLVMValueRef, + dst: SpirvWord, + from_type: &ast::Type, + from_space: ast::StateSpace, + to_type: &ast::Type, + to_space: ast::StateSpace, + ) -> Result<(), TranslateError> { + match (from_type, to_type) { + (ast::Type::Scalar(from_type), ast::Type::Scalar(to_type_scalar)) => { + let from_layout = from_type.layout(); + let to_layout = to_type.layout(); + if from_layout.size() == to_layout.size() { + let dst_type = get_type(self.context, &to_type)?; + if from_type.kind() != ast::ScalarKind::Float + && to_type_scalar.kind() != ast::ScalarKind::Float + { + // It is noop, but another instruction expects result of this conversion + self.resolver.register(dst, src); + } else { + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildBitCast(self.builder, src, dst_type, dst) + }); + } + Ok(()) + } else { + // This block is safe because it's illegal to implictly convert between floating point values + let same_width_bit_type = unsafe { + LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32) + }; + let same_width_bit_value = unsafe { + LLVMBuildBitCast( + self.builder, + src, + same_width_bit_type, + LLVM_UNNAMED.as_ptr(), + ) + }; + let wide_bit_type = match to_type_scalar.layout().size() { + 1 => ast::ScalarType::B8, + 2 => ast::ScalarType::B16, + 4 => ast::ScalarType::B32, + 8 => ast::ScalarType::B64, + _ => return Err(error_unreachable()), + }; + let wide_bit_type_llvm = unsafe { + LLVMIntTypeInContext(self.context, 
(to_layout.size() * 8) as u32) + }; + if to_type_scalar.kind() == ast::ScalarKind::Unsigned + || to_type_scalar.kind() == ast::ScalarKind::Bit + { + let llvm_fn = if to_type_scalar.size_of() >= from_type.size_of() { + LLVMBuildZExtOrBitCast + } else { + LLVMBuildTrunc + }; + self.resolver.with_result(dst, |dst| unsafe { + llvm_fn(self.builder, same_width_bit_value, wide_bit_type_llvm, dst) + }); + Ok(()) + } else { + let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed + && to_type_scalar.kind() == ast::ScalarKind::Signed + { + if to_type_scalar.size_of() >= from_type.size_of() { + LLVMBuildSExtOrBitCast + } else { + LLVMBuildTrunc + } + } else { + if to_type_scalar.size_of() >= from_type.size_of() { + LLVMBuildZExtOrBitCast + } else { + LLVMBuildTrunc + } + }; + let wide_bit_value = unsafe { + conversion_fn( + self.builder, + same_width_bit_value, + wide_bit_type_llvm, + LLVM_UNNAMED.as_ptr(), + ) + }; + self.emit_conversion_default( + wide_bit_value, + dst, + &wide_bit_type.into(), + from_space, + to_type, + to_space, + ) + } + } + } + (ast::Type::Vector(..), ast::Type::Scalar(..)) + | (ast::Type::Scalar(..), ast::Type::Array(..)) + | (ast::Type::Array(..), ast::Type::Scalar(..)) => { + let dst_type = get_type(self.context, to_type)?; + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildBitCast(self.builder, src, dst_type, dst) + }); + Ok(()) + } + _ => todo!(), + } + } + + fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, constant.typ); + let value = match constant.value { + ast::ImmediateValue::U64(x) => unsafe { LLVMConstInt(type_, x, 0) }, + ast::ImmediateValue::S64(x) => unsafe { LLVMConstInt(type_, x as u64, 0) }, + ast::ImmediateValue::F32(x) => unsafe { LLVMConstReal(type_, x as f64) }, + ast::ImmediateValue::F64(x) => unsafe { LLVMConstReal(type_, x) }, + }; + self.resolver.register(constant.dst, value); + Ok(()) + } + + fn emit_add( + &mut self, + data: ast::ArithDetails, + arguments: ast::AddArgs, + ) -> Result<(), TranslateError> { + let builder = self.builder; + let fn_ = match data { + ast::ArithDetails::Integer(ast::ArithInteger { + saturate: true, + type_, + }) => { + let op = if type_.kind() == ast::ScalarKind::Signed { + "sadd" + } else { + "uadd" + }; + return self.emit_intrinsic_saturate( + op, + type_, + arguments.dst, + arguments.src1, + arguments.src2, + ); + } + ast::ArithDetails::Integer(ast::ArithInteger { + saturate: false, .. + }) => LLVMBuildAdd, + ast::ArithDetails::Float(..) 
=> LLVMBuildFAdd, + }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + fn_(builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_st( + &self, + data: ast::StData, + arguments: ast::StArgs, + ) -> Result<(), TranslateError> { + let ptr = self.resolver.value(arguments.src1)?; + let value = self.resolver.value(arguments.src2)?; + if data.qualifier != ast::LdStQualifier::Weak { + todo!() + } + let store = unsafe { LLVMBuildStore(self.builder, value, ptr) }; + unsafe { + LLVMSetAlignment(store, data.typ.layout().align() as u32); + } + Ok(()) + } + + fn emit_ret(&self, _data: ast::RetData) { + unsafe { LLVMBuildRetVoid(self.builder) }; + } + + fn emit_call( + &mut self, + data: ast::CallDetails, + arguments: ast::CallArgs, + ) -> Result<(), TranslateError> { + if cfg!(debug_assertions) { + for (_, space) in data.return_arguments.iter() { + if *space != ast::StateSpace::Reg { + panic!() + } + } + for (_, space) in data.input_arguments.iter() { + if *space != ast::StateSpace::Reg { + panic!() + } + } + } + let name = match &*arguments.return_arguments { + [dst] => self.resolver.get_or_add_raw(*dst), + _ => LLVM_UNNAMED.as_ptr(), + }; + let type_ = get_function_type( + self.context, + data.return_arguments.iter().map(|(type_, ..)| type_), + data.input_arguments + .iter() + .map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)), + )?; + let mut input_arguments = arguments + .input_arguments + .iter() + .map(|arg| self.resolver.value(*arg)) + .collect::, _>>()?; + let llvm_call = unsafe { + LLVMBuildCall2( + self.builder, + type_, + self.resolver.value(arguments.func)?, + input_arguments.as_mut_ptr(), + input_arguments.len() as u32, + name, + ) + }; + match &*arguments.return_arguments { + [] => {} + [name] => self.resolver.register(*name, llvm_call), + [b32, pred] => { + self.resolver.with_result(*b32, |name| unsafe { + LLVMBuildExtractValue(self.builder, llvm_call, 0, name) + }); + self.resolver.with_result(*pred, |name| unsafe { + let extracted = + LLVMBuildExtractValue(self.builder, llvm_call, 1, LLVM_UNNAMED.as_ptr()); + LLVMBuildTrunc( + self.builder, + extracted, + get_scalar_type(self.context, ast::ScalarType::Pred), + name, + ) + }); + } + _ => { + return Err(error_todo_msg( + "Only two return arguments (.b32, .pred) currently supported", + )) + } + } + Ok(()) + } + + fn emit_mov(&mut self, arguments: ast::MovArgs) -> Result<(), TranslateError> { + self.resolver + .register(arguments.dst, self.resolver.value(arguments.src)?); + Ok(()) + } + + fn emit_ptr_access(&mut self, ptr_access: PtrAccess) -> Result<(), TranslateError> { + let ptr_src = self.resolver.value(ptr_access.ptr_src)?; + let mut offset_src = self.resolver.value(ptr_access.offset_src)?; + let pointee_type = get_scalar_type(self.context, ast::ScalarType::B8); + self.resolver.with_result(ptr_access.dst, |dst| unsafe { + LLVMBuildInBoundsGEP2(self.builder, pointee_type, ptr_src, &mut offset_src, 1, dst) + }); + Ok(()) + } + + fn emit_and(&mut self, arguments: ast::AndArgs) -> Result<(), TranslateError> { + let builder = self.builder; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildAnd(builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_atom( + &mut self, + data: ast::AtomDetails, + arguments: ast::AtomArgs, + ) -> Result<(), TranslateError> { + let builder = 
self.builder; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let op = match data.op { + ast::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd, + ast::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr, + ast::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor, + ast::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg, + ast::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd, + ast::AtomicOp::IncrementWrap => { + LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap + } + ast::AtomicOp::DecrementWrap => { + LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap + } + ast::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin, + ast::AtomicOp::UnsignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin, + ast::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax, + ast::AtomicOp::UnsignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax, + ast::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd, + ast::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin, + ast::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax, + }; + self.resolver.register(arguments.dst, unsafe { + LLVMZludaBuildAtomicRMW( + builder, + op, + src1, + src2, + get_scope(data.scope)?, + get_ordering(data.semantics), + ) + }); + Ok(()) + } + + fn emit_atom_cas( + &mut self, + data: ast::AtomCasDetails, + arguments: ast::AtomCasArgs, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let src3 = self.resolver.value(arguments.src3)?; + let success_ordering = get_ordering(data.semantics); + let failure_ordering = get_ordering_failure(data.semantics); + let temp = unsafe { + LLVMZludaBuildAtomicCmpXchg( + self.builder, + src1, + src2, + src3, + get_scope(data.scope)?, + success_ordering, + failure_ordering, + ) + }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildExtractValue(self.builder, temp, 0, dst) + }); + Ok(()) + } + + fn emit_bra(&self, arguments: ast::BraArgs) -> Result<(), TranslateError> { + let src = self.resolver.value(arguments.src)?; + let src = unsafe { LLVMValueAsBasicBlock(src) }; + unsafe { LLVMBuildBr(self.builder, src) }; + Ok(()) + } + + fn emit_brev( + &mut self, + data: ast::ScalarType, + arguments: ast::BrevArgs, + ) -> Result<(), TranslateError> { + let llvm_fn = match data.size_of() { + 4 => c"llvm.bitreverse.i32", + 8 => c"llvm.bitreverse.i64", + _ => return Err(error_unreachable()), + }; + let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) }; + let type_ = get_scalar_type(self.context, data); + let fn_type = get_function_type( + self.context, + iter::once(&data.into()), + iter::once(Ok(type_)), + )?; + if fn_ == ptr::null_mut() { + fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) }; + } + let mut src = self.resolver.value(arguments.src)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst) + }); + Ok(()) + } + + fn emit_ret_value( + &mut self, + values: Vec<(SpirvWord, ptx_parser::Type)>, + ) -> Result<(), TranslateError> { + let loads = values + .iter() + .map(|(value, type_)| { + let value = self.resolver.value(*value)?; + let lowered_type = get_type(self.context, type_)?; + let load = unsafe { + 
LLVMBuildLoad2(self.builder, lowered_type, value, LLVM_UNNAMED.as_ptr()) + }; + unsafe { + LLVMSetAlignment(load, type_.layout().align() as u32); + } + Ok(load) + }) + .collect::, _>>()?; + + match &*loads { + [] => unsafe { LLVMBuildRetVoid(self.builder) }, + [value] => unsafe { LLVMBuildRet(self.builder, *value) }, + _ => { + check_multiple_return_types(values.iter().map(|(_, type_)| type_))?; + let array_ty = + get_array_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?; + let insert_b32 = unsafe { + LLVMBuildInsertValue( + self.builder, + LLVMGetPoison(array_ty), + loads[0], + 0, + LLVM_UNNAMED.as_ptr(), + ) + }; + let zext_pred = unsafe { + LLVMBuildZExt( + self.builder, + loads[1], + get_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32))?, + LLVM_UNNAMED.as_ptr(), + ) + }; + let insert_pred = unsafe { + LLVMBuildInsertValue( + self.builder, + insert_b32, + zext_pred, + 1, + LLVM_UNNAMED.as_ptr(), + ) + }; + unsafe { LLVMBuildRet(self.builder, insert_pred) } + } + }; + Ok(()) + } + + fn emit_clz( + &mut self, + data: ptx_parser::ScalarType, + arguments: ptx_parser::ClzArgs, + ) -> Result<(), TranslateError> { + let llvm_fn = match data.size_of() { + 4 => c"llvm.ctlz.i32", + 8 => c"llvm.ctlz.i64", + _ => return Err(error_unreachable()), + }; + let type_ = get_scalar_type(self.context, data.into()); + let pred = get_scalar_type(self.context, ast::ScalarType::Pred); + let fn_type = get_function_type( + self.context, + iter::once(&ast::ScalarType::U32.into()), + [Ok(type_), Ok(pred)].into_iter(), + )?; + let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) }; + if fn_ == ptr::null_mut() { + fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) }; + } + let src = self.resolver.value(arguments.src)?; + let false_ = unsafe { LLVMConstInt(pred, 0, 0) }; + let mut args = [src, false_]; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildCall2( + self.builder, + fn_type, + fn_, + args.as_mut_ptr(), + args.len() as u32, + dst, + ) + }); + Ok(()) + } + + fn emit_mul( + &mut self, + data: ast::MulDetails, + arguments: ast::MulArgs, + ) -> Result<(), TranslateError> { + self.emit_mul_impl(data, Some(arguments.dst), arguments.src1, arguments.src2)?; + Ok(()) + } + + fn emit_mul_impl( + &mut self, + data: ast::MulDetails, + dst: Option, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result { + let mul_fn = match data { + ast::MulDetails::Integer { control, type_ } => match control { + ast::MulIntControl::Low => LLVMBuildMul, + ast::MulIntControl::High => return self.emit_mul_high(type_, dst, src1, src2), + ast::MulIntControl::Wide => { + return Ok(self.emit_mul_wide_impl(type_, dst, src1, src2)?.1) + } + }, + ast::MulDetails::Float(..) 
=> LLVMBuildFMul, + }; + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + Ok(self + .resolver + .with_result_option(dst, |dst| unsafe { mul_fn(self.builder, src1, src2, dst) })) + } + + fn emit_mul_high( + &mut self, + type_: ptx_parser::ScalarType, + dst: Option, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result { + let (wide_type, wide_value) = self.emit_mul_wide_impl(type_, None, src1, src2)?; + let shift_constant = + unsafe { LLVMConstInt(wide_type, (type_.layout().size() * 8) as u64, 0) }; + let shifted = unsafe { + LLVMBuildLShr( + self.builder, + wide_value, + shift_constant, + LLVM_UNNAMED.as_ptr(), + ) + }; + let narrow_type = get_scalar_type(self.context, type_); + Ok(self.resolver.with_result_option(dst, |dst| unsafe { + LLVMBuildTrunc(self.builder, shifted, narrow_type, dst) + })) + } + + fn emit_mul_wide_impl( + &mut self, + type_: ptx_parser::ScalarType, + dst: Option, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result<(LLVMTypeRef, LLVMValueRef), TranslateError> { + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + let wide_type = + unsafe { LLVMIntTypeInContext(self.context, (type_.layout().size() * 8 * 2) as u32) }; + let llvm_cast = match type_.kind() { + ptx_parser::ScalarKind::Signed => LLVMBuildSExt, + ptx_parser::ScalarKind::Unsigned => LLVMBuildZExt, + _ => return Err(error_unreachable()), + }; + let src1 = unsafe { llvm_cast(self.builder, src1, wide_type, LLVM_UNNAMED.as_ptr()) }; + let src2 = unsafe { llvm_cast(self.builder, src2, wide_type, LLVM_UNNAMED.as_ptr()) }; + Ok(( + wide_type, + self.resolver.with_result_option(dst, |dst| unsafe { + LLVMBuildMul(self.builder, src1, src2, dst) + }), + )) + } + + fn emit_cos( + &mut self, + _data: ast::FlushToZero, + arguments: ast::CosArgs, + ) -> Result<(), TranslateError> { + let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32); + let cos = self.emit_intrinsic( + c"llvm.cos.f32", + Some(arguments.dst), + Some(&ast::ScalarType::F32.into()), + vec![(self.resolver.value(arguments.src)?, llvm_f32)], + )?; + unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) } + Ok(()) + } + + fn emit_or( + &mut self, + _data: ptx_parser::ScalarType, + arguments: ptx_parser::OrArgs, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildOr(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_xor( + &mut self, + _data: ptx_parser::ScalarType, + arguments: ptx_parser::XorArgs, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildXor(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_vector_read(&mut self, vec_acccess: VectorRead) -> Result<(), TranslateError> { + let src = self.resolver.value(vec_acccess.vector_src)?; + let index = unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::B8), + vec_acccess.member as _, + 0, + ) + }; + self.resolver + .with_result(vec_acccess.scalar_dst, |dst| unsafe { + LLVMBuildExtractElement(self.builder, src, index, dst) + }); + Ok(()) + } + + fn emit_vector_write(&mut self, vector_write: VectorWrite) -> Result<(), TranslateError> { + let vector_src = self.resolver.value(vector_write.vector_src)?; + let scalar_src = self.resolver.value(vector_write.scalar_src)?; + let index 
= unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::B8), + vector_write.member as _, + 0, + ) + }; + self.resolver + .with_result(vector_write.vector_dst, |dst| unsafe { + LLVMBuildInsertElement(self.builder, vector_src, scalar_src, index, dst) + }); + Ok(()) + } + + fn emit_vector_repack(&mut self, repack: RepackVectorDetails) -> Result<(), TranslateError> { + let i8_type = get_scalar_type(self.context, ast::ScalarType::B8); + if repack.is_extract { + let src = self.resolver.value(repack.packed)?; + for (index, dst) in repack.unpacked.iter().enumerate() { + let index: *mut LLVMValue = unsafe { LLVMConstInt(i8_type, index as _, 0) }; + self.resolver.with_result(*dst, |dst| unsafe { + LLVMBuildExtractElement(self.builder, src, index, dst) + }); + } + } else { + let vector_type = get_type( + self.context, + &ast::Type::Vector(repack.unpacked.len() as u8, repack.typ), + )?; + let mut temp_vec = unsafe { LLVMGetUndef(vector_type) }; + for (index, src_id) in repack.unpacked.iter().enumerate() { + let dst = if index == repack.unpacked.len() - 1 { + Some(repack.packed) + } else { + None + }; + let scalar_src = self.resolver.value(*src_id)?; + let index = unsafe { LLVMConstInt(i8_type, index as _, 0) }; + temp_vec = self.resolver.with_result_option(dst, |dst| unsafe { + LLVMBuildInsertElement(self.builder, temp_vec, scalar_src, index, dst) + }); + } + } + Ok(()) + } + + fn emit_div( + &mut self, + data: ptx_parser::DivDetails, + arguments: ptx_parser::DivArgs, + ) -> Result<(), TranslateError> { + let integer_div = match data { + ptx_parser::DivDetails::Unsigned(_) => LLVMBuildUDiv, + ptx_parser::DivDetails::Signed(_) => LLVMBuildSDiv, + ptx_parser::DivDetails::Float(float_div) => { + return self.emit_div_float(float_div, arguments) + } + }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + integer_div(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_div_float( + &mut self, + float_div: ptx_parser::DivFloatDetails, + arguments: ptx_parser::DivArgs, + ) -> Result<(), TranslateError> { + let builder = self.builder; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let _rnd = match float_div.kind { + ptx_parser::DivFloatKind::Approx => ast::RoundingMode::NearestEven, + ptx_parser::DivFloatKind::ApproxFull => ast::RoundingMode::NearestEven, + ptx_parser::DivFloatKind::Rounding(rounding_mode) => rounding_mode, + }; + let approx = match float_div.kind { + ptx_parser::DivFloatKind::Approx => { + LLVMZludaFastMathAllowReciprocal | LLVMZludaFastMathApproxFunc + } + ptx_parser::DivFloatKind::ApproxFull => LLVMZludaFastMathNone, + ptx_parser::DivFloatKind::Rounding(_) => LLVMZludaFastMathNone, + }; + let fdiv = self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFDiv(builder, src1, src2, dst) + }); + unsafe { LLVMZludaSetFastMathFlags(fdiv, approx) }; + if let ptx_parser::DivFloatKind::ApproxFull = float_div.kind { + // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-div: + // div.full.f32 implements a relatively fast, full-range approximation that scales + // operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not + // support rounding modifiers. The maximum ulp error is 2 across the full range of + // inputs. 
+ // https://llvm.org/docs/LangRef.html#fpmath-metadata + let fpmath_value = + unsafe { LLVMConstReal(get_scalar_type(self.context, ast::ScalarType::F32), 2.0) }; + let fpmath_value = unsafe { LLVMValueAsMetadata(fpmath_value) }; + let mut md_node_content = [fpmath_value]; + let md_node = unsafe { + LLVMMDNodeInContext2( + self.context, + md_node_content.as_mut_ptr(), + md_node_content.len(), + ) + }; + let md_node = unsafe { LLVMMetadataAsValue(self.context, md_node) }; + let kind = unsafe { + LLVMGetMDKindIDInContext( + self.context, + "fpmath".as_ptr().cast(), + "fpmath".len() as u32, + ) + }; + unsafe { LLVMSetMetadata(fdiv, kind, md_node) }; + } + Ok(()) + } + + fn emit_cvta( + &mut self, + data: ptx_parser::CvtaDetails, + arguments: ptx_parser::CvtaArgs, + ) -> Result<(), TranslateError> { + let (from_space, to_space) = match data.direction { + ptx_parser::CvtaDirection::GenericToExplicit => { + (ast::StateSpace::Generic, data.state_space) + } + ptx_parser::CvtaDirection::ExplicitToGeneric => { + (data.state_space, ast::StateSpace::Generic) + } + }; + let from_type = get_pointer_type(self.context, from_space)?; + let dest_type = get_pointer_type(self.context, to_space)?; + let src = self.resolver.value(arguments.src)?; + let temp_ptr = + unsafe { LLVMBuildIntToPtr(self.builder, src, from_type, LLVM_UNNAMED.as_ptr()) }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildAddrSpaceCast(self.builder, temp_ptr, dest_type, dst) + }); + Ok(()) + } + + fn emit_sub( + &mut self, + data: ptx_parser::ArithDetails, + arguments: ptx_parser::SubArgs, + ) -> Result<(), TranslateError> { + match data { + ptx_parser::ArithDetails::Integer(arith_integer) => { + self.emit_sub_integer(arith_integer, arguments) + } + ptx_parser::ArithDetails::Float(arith_float) => { + self.emit_sub_float(arith_float, arguments) + } + } + } + + fn emit_sub_integer( + &mut self, + arith_integer: ptx_parser::ArithInteger, + arguments: ptx_parser::SubArgs, + ) -> Result<(), TranslateError> { + if arith_integer.saturate { + let op = if arith_integer.type_.kind() == ast::ScalarKind::Signed { + "ssub" + } else { + "usub" + }; + return self.emit_intrinsic_saturate( + op, + arith_integer.type_, + arguments.dst, + arguments.src1, + arguments.src2, + ); + } + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildSub(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_sub_float( + &mut self, + _arith_float: ptx_parser::ArithFloat, + arguments: ptx_parser::SubArgs, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFSub(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_sin( + &mut self, + _data: ptx_parser::FlushToZero, + arguments: ptx_parser::SinArgs, + ) -> Result<(), TranslateError> { + let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32); + let sin = self.emit_intrinsic( + c"llvm.sin.f32", + Some(arguments.dst), + Some(&ast::ScalarType::F32.into()), + vec![(self.resolver.value(arguments.src)?, llvm_f32)], + )?; + unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) } + Ok(()) + } + + fn emit_intrinsic( + &mut self, + name: &CStr, + dst: Option, + return_type: Option<&ast::Type>, + arguments: Vec<(LLVMValueRef, LLVMTypeRef)>, + ) -> Result { + let fn_type = get_function_type( + self.context, + 
return_type.into_iter(), + arguments.iter().map(|(_, type_)| Ok(*type_)), + )?; + let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; + if fn_ == ptr::null_mut() { + fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; + } + let mut arguments = arguments.iter().map(|(arg, _)| *arg).collect::>(); + Ok(self.resolver.with_result_option(dst, |dst| unsafe { + LLVMBuildCall2( + self.builder, + fn_type, + fn_, + arguments.as_mut_ptr(), + arguments.len() as u32, + dst, + ) + })) + } + + fn emit_neg( + &mut self, + data: ptx_parser::TypeFtz, + arguments: ptx_parser::NegArgs, + ) -> Result<(), TranslateError> { + let src = self.resolver.value(arguments.src)?; + let is_floating_point = data.type_.kind() == ptx_parser::ScalarKind::Float; + let llvm_fn = if is_floating_point { + LLVMBuildFNeg + } else { + LLVMBuildNeg + }; + if is_floating_point && data.flush_to_zero == Some(true) { + let negated = unsafe { llvm_fn(self.builder, src, LLVM_UNNAMED.as_ptr()) }; + let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(data.type_)); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + Some(&data.type_.into()), + vec![(negated, get_scalar_type(self.context, data.type_))], + )?; + } else { + self.resolver.with_result(arguments.dst, |dst| unsafe { + llvm_fn(self.builder, src, dst) + }); + } + Ok(()) + } + + fn emit_not( + &mut self, + type_: ptx_parser::ScalarType, + arguments: ptx_parser::NotArgs, + ) -> Result<(), TranslateError> { + let src = self.resolver.value(arguments.src)?; + let type_ = get_scalar_type(self.context, type_); + let constant = unsafe { LLVMConstInt(type_, u64::MAX, 0) }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildXor(self.builder, src, constant, dst) + }); + Ok(()) + } + + fn emit_setp( + &mut self, + data: ptx_parser::SetpData, + arguments: ptx_parser::SetpArgs, + ) -> Result<(), TranslateError> { + let dst = self.emit_setp_impl(data, arguments.dst2, arguments.src1, arguments.src2)?; + self.resolver.register(arguments.dst1, dst); + Ok(()) + } + + fn emit_setp_impl( + &mut self, + data: ptx_parser::SetpData, + dst2: Option, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result { + if dst2.is_some() { + return Err(error_todo_msg( + "setp with two dst arguments not yet supported", + )); + } + match data.cmp_op { + ptx_parser::SetpCompareOp::Integer(setp_compare_int) => { + self.emit_setp_int(setp_compare_int, src1, src2) + } + ptx_parser::SetpCompareOp::Float(setp_compare_float) => { + self.emit_setp_float(setp_compare_float, src1, src2) + } + } + } + + fn emit_setp_int( + &mut self, + setp: ptx_parser::SetpCompareInt, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result { + let op = match setp { + ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ, + ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE, + ptx_parser::SetpCompareInt::UnsignedLess => LLVMIntPredicate::LLVMIntULT, + ptx_parser::SetpCompareInt::UnsignedLessOrEq => LLVMIntPredicate::LLVMIntULE, + ptx_parser::SetpCompareInt::UnsignedGreater => LLVMIntPredicate::LLVMIntUGT, + ptx_parser::SetpCompareInt::UnsignedGreaterOrEq => LLVMIntPredicate::LLVMIntUGE, + ptx_parser::SetpCompareInt::SignedLess => LLVMIntPredicate::LLVMIntSLT, + ptx_parser::SetpCompareInt::SignedLessOrEq => LLVMIntPredicate::LLVMIntSLE, + ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT, + ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE, + }; + 
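+        // Signedness is already encoded in the SetpCompareInt variant, so e.g. an
+        // unsigned less-than becomes an `icmp ult` (illustrative, placeholder operands):
+        //   %p = icmp ult i32 %a, %b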
let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + Ok(unsafe { LLVMBuildICmp(self.builder, op, src1, src2, LLVM_UNNAMED.as_ptr()) }) + } + + fn emit_setp_float( + &mut self, + setp: ptx_parser::SetpCompareFloat, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result { + let op = match setp { + ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ, + ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE, + ptx_parser::SetpCompareFloat::Less => LLVMRealPredicate::LLVMRealOLT, + ptx_parser::SetpCompareFloat::LessOrEq => LLVMRealPredicate::LLVMRealOLE, + ptx_parser::SetpCompareFloat::Greater => LLVMRealPredicate::LLVMRealOGT, + ptx_parser::SetpCompareFloat::GreaterOrEq => LLVMRealPredicate::LLVMRealOGE, + ptx_parser::SetpCompareFloat::NanEq => LLVMRealPredicate::LLVMRealUEQ, + ptx_parser::SetpCompareFloat::NanNotEq => LLVMRealPredicate::LLVMRealUNE, + ptx_parser::SetpCompareFloat::NanLess => LLVMRealPredicate::LLVMRealULT, + ptx_parser::SetpCompareFloat::NanLessOrEq => LLVMRealPredicate::LLVMRealULE, + ptx_parser::SetpCompareFloat::NanGreater => LLVMRealPredicate::LLVMRealUGT, + ptx_parser::SetpCompareFloat::NanGreaterOrEq => LLVMRealPredicate::LLVMRealUGE, + ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD, + ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO, + }; + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + Ok(unsafe { LLVMBuildFCmp(self.builder, op, src1, src2, LLVM_UNNAMED.as_ptr()) }) + } + + fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> { + let predicate = self.resolver.value(cond.predicate)?; + let if_true = self.resolver.value(cond.if_true)?; + let if_false = self.resolver.value(cond.if_false)?; + unsafe { + LLVMBuildCondBr( + self.builder, + predicate, + LLVMValueAsBasicBlock(if_true), + LLVMValueAsBasicBlock(if_false), + ) + }; + Ok(()) + } + + fn emit_cvt( + &mut self, + data: ptx_parser::CvtDetails, + arguments: ptx_parser::CvtArgs, + ) -> Result<(), TranslateError> { + let dst_type = get_scalar_type(self.context, data.to); + let llvm_fn = match data.mode { + ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt, + ptx_parser::CvtMode::SignExtend => LLVMBuildSExt, + ptx_parser::CvtMode::Truncate => LLVMBuildTrunc, + ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast, + ptx_parser::CvtMode::IntSaturateToSigned => { + return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments) + } + ptx_parser::CvtMode::IntSaturateToUnsigned => { + return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments) + } + ptx_parser::CvtMode::FPExtend { .. } => LLVMBuildFPExt, + ptx_parser::CvtMode::FPTruncate { .. } => LLVMBuildFPTrunc, + ptx_parser::CvtMode::FPRound { + integer_rounding: None, + flush_to_zero: None | Some(false), + .. + } => { + return self.emit_mov(ast::MovArgs { + dst: arguments.dst, + src: arguments.src, + }) + } + ptx_parser::CvtMode::FPRound { + integer_rounding: None, + flush_to_zero: Some(true), + .. + } => return self.flush_denormals(data.to, arguments.src, arguments.dst), + ptx_parser::CvtMode::FPRound { + integer_rounding: Some(rounding), + .. + } => return self.emit_cvt_float_to_int(data.from, data.to, rounding, arguments, None), + ptx_parser::CvtMode::SignedFromFP { rounding, .. } => { + return self.emit_cvt_float_to_int( + data.from, + data.to, + rounding, + arguments, + Some(true), + ) + } + ptx_parser::CvtMode::UnsignedFromFP { rounding, .. 
} => { + return self.emit_cvt_float_to_int( + data.from, + data.to, + rounding, + arguments, + Some(false), + ) + } + ptx_parser::CvtMode::FPFromSigned { .. } => { + return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildSIToFP) + } + ptx_parser::CvtMode::FPFromUnsigned { .. } => { + return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildUIToFP) + } + }; + let src = self.resolver.value(arguments.src)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + llvm_fn(self.builder, src, dst_type, dst) + }); + Ok(()) + } + + fn emit_cvt_unsigned_to_signed_sat( + &mut self, + from: ptx_parser::ScalarType, + to: ptx_parser::ScalarType, + arguments: ptx_parser::CvtArgs, + ) -> Result<(), TranslateError> { + let clamped = self.emit_saturate_integer(from, to, &arguments)?; + let resize_fn = if to.layout().size() >= from.layout().size() { + LLVMBuildSExtOrBitCast + } else { + LLVMBuildTrunc + }; + let to_llvm = get_scalar_type(self.context, to); + self.resolver.with_result(arguments.dst, |dst| unsafe { + resize_fn(self.builder, clamped, to_llvm, dst) + }); + Ok(()) + } + + fn emit_saturate_integer( + &mut self, + from: ptx_parser::ScalarType, + to: ptx_parser::ScalarType, + arguments: &ptx_parser::CvtArgs, + ) -> Result { + let from_llvm = get_scalar_type(self.context, from); + match from.kind() { + ptx_parser::ScalarKind::Unsigned => { + let max_value = match to { + ptx_parser::ScalarType::U8 => u8::MAX as u64, + ptx_parser::ScalarType::S8 => i8::MAX as u64, + ptx_parser::ScalarType::U16 => u16::MAX as u64, + ptx_parser::ScalarType::S16 => i16::MAX as u64, + ptx_parser::ScalarType::U32 => u32::MAX as u64, + ptx_parser::ScalarType::S32 => i32::MAX as u64, + ptx_parser::ScalarType::U64 => u64::MAX as u64, + ptx_parser::ScalarType::S64 => i64::MAX as u64, + _ => return Err(error_unreachable()), + }; + let intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from)); + let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) }; + let clamped = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + None, + Some(&from.into()), + vec![ + (self.resolver.value(arguments.src)?, from_llvm), + (max, from_llvm), + ], + )?; + Ok(clamped) + } + ptx_parser::ScalarKind::Signed => { + let (min_value_from, max_value_from) = match from { + ptx_parser::ScalarType::S8 => (i8::MIN as i128, i8::MAX as i128), + ptx_parser::ScalarType::S16 => (i16::MIN as i128, i16::MAX as i128), + ptx_parser::ScalarType::S32 => (i32::MIN as i128, i32::MAX as i128), + ptx_parser::ScalarType::S64 => (i64::MIN as i128, i64::MAX as i128), + _ => return Err(error_unreachable()), + }; + let (min_value_to, max_value_to) = match to { + ptx_parser::ScalarType::U8 => (u8::MIN as i128, u8::MAX as i128), + ptx_parser::ScalarType::S8 => (i8::MIN as i128, i8::MAX as i128), + ptx_parser::ScalarType::U16 => (u16::MIN as i128, u16::MAX as i128), + ptx_parser::ScalarType::S16 => (i16::MIN as i128, i16::MAX as i128), + ptx_parser::ScalarType::U32 => (u32::MIN as i128, u32::MAX as i128), + ptx_parser::ScalarType::S32 => (i32::MIN as i128, i32::MAX as i128), + ptx_parser::ScalarType::U64 => (u64::MIN as i128, u64::MAX as i128), + ptx_parser::ScalarType::S64 => (i64::MIN as i128, i64::MAX as i128), + _ => return Err(error_unreachable()), + }; + let min_value = min_value_from.max(min_value_to); + let max_value = max_value_from.min(max_value_to); + let max_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from)); + let min = unsafe { LLVMConstInt(from_llvm, min_value as u64, 1) }; + let min_intrinsic 
= format!("llvm.smin.{}\0", LLVMTypeDisplay(from)); + let max = unsafe { LLVMConstInt(from_llvm, max_value as u64, 1) }; + let clamped = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(max_intrinsic.as_bytes()) }, + None, + Some(&from.into()), + vec![ + (self.resolver.value(arguments.src)?, from_llvm), + (min, from_llvm), + ], + )?; + let clamped = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(min_intrinsic.as_bytes()) }, + None, + Some(&from.into()), + vec![(clamped, from_llvm), (max, from_llvm)], + )?; + Ok(clamped) + } + _ => return Err(error_unreachable()), + } + } + + fn emit_cvt_signed_to_unsigned_sat( + &mut self, + from: ptx_parser::ScalarType, + to: ptx_parser::ScalarType, + arguments: ptx_parser::CvtArgs, + ) -> Result<(), TranslateError> { + let clamped = self.emit_saturate_integer(from, to, &arguments)?; + let resize_fn = if to.layout().size() >= from.layout().size() { + LLVMBuildZExtOrBitCast + } else { + LLVMBuildTrunc + }; + let to_llvm = get_scalar_type(self.context, to); + self.resolver.with_result(arguments.dst, |dst| unsafe { + resize_fn(self.builder, clamped, to_llvm, dst) + }); + Ok(()) + } + + fn emit_cvt_float_to_int( + &mut self, + from: ast::ScalarType, + to: ast::ScalarType, + rounding: ast::RoundingMode, + arguments: ptx_parser::CvtArgs, + signed_cast: Option, + ) -> Result<(), TranslateError> { + let dst_int_rounded = + self.emit_fp_int_rounding(from, rounding, &arguments, signed_cast.is_some())?; + // In PTX all the int-from-float casts are saturating casts. On the other hand, in LLVM, + // out-of-range fptoui and fptosi have undefined behavior. + // We could handle this all with llvm.fptosi.sat and llvm.fptoui.sat intrinsics, but + // the problem is that, when using *.sat variants AMDGPU target _always_ emits saturation + // checks. Often they are unnecessary because v_cvt_* instructions saturates anyway. 
+ // For that reason, all from-to combinations that we know have a direct corresponding + // v_cvt_* instruction get special treatment + let is_saturating_cast = match (to, from) { + (ast::ScalarType::S16, ast::ScalarType::F16) + | (ast::ScalarType::S32, ast::ScalarType::F32) + | (ast::ScalarType::S32, ast::ScalarType::F64) + | (ast::ScalarType::U16, ast::ScalarType::F16) + | (ast::ScalarType::U32, ast::ScalarType::F32) + | (ast::ScalarType::U32, ast::ScalarType::F64) => true, + _ => false, + }; + let signed_cast = match signed_cast { + Some(s) => s, + None => { + self.resolver.register( + arguments.dst, + dst_int_rounded.ok_or_else(error_unreachable)?, + ); + return Ok(()); + } + }; + if is_saturating_cast { + let to = get_scalar_type(self.context, to); + let src = + dst_int_rounded.unwrap_or_else(|| self.resolver.value(arguments.src).unwrap()); + let llvm_cast = if signed_cast { + LLVMBuildFPToSI + } else { + LLVMBuildFPToUI + }; + let poisoned_dst = unsafe { llvm_cast(self.builder, src, to, LLVM_UNNAMED.as_ptr()) }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFreeze(self.builder, poisoned_dst, dst) + }); + } else { + let cvt_op = if to.kind() == ptx_parser::ScalarKind::Unsigned { + "fptoui" + } else { + "fptosi" + }; + let cast_intrinsic = format!( + "llvm.{cvt_op}.sat.{}.{}\0", + LLVMTypeDisplay(to), + LLVMTypeDisplay(from) + ); + let src = + dst_int_rounded.unwrap_or_else(|| self.resolver.value(arguments.src).unwrap()); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) }, + Some(arguments.dst), + Some(&to.into()), + vec![(src, get_scalar_type(self.context, from))], + )?; + } + Ok(()) + } + + fn emit_fp_int_rounding( + &mut self, + from: ptx_parser::ScalarType, + rounding: ptx_parser::RoundingMode, + arguments: &ptx_parser::CvtArgs, + will_saturate_with_cvt: bool, + ) -> Result, TranslateError> { + let prefix = match rounding { + ptx_parser::RoundingMode::NearestEven => "llvm.roundeven", + ptx_parser::RoundingMode::Zero => { + // cvt has round-to-zero semantics + if will_saturate_with_cvt { + return Ok(None); + } else { + "llvm.trunc" + } + } + ptx_parser::RoundingMode::NegativeInf => "llvm.floor", + ptx_parser::RoundingMode::PositiveInf => "llvm.ceil", + }; + let intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(from)); + let rounded_float = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + None, + Some(&from.into()), + vec![( + self.resolver.value(arguments.src)?, + get_scalar_type(self.context, from), + )], + )?; + Ok(Some(rounded_float)) + } + + fn emit_cvt_int_to_float( + &mut self, + to: ptx_parser::ScalarType, + arguments: ptx_parser::CvtArgs, + llvm_func: unsafe extern "C" fn( + arg1: LLVMBuilderRef, + Val: LLVMValueRef, + DestTy: LLVMTypeRef, + Name: *const i8, + ) -> LLVMValueRef, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, to); + let src = self.resolver.value(arguments.src)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + llvm_func(self.builder, src, type_, dst) + }); + Ok(()) + } + + fn emit_rsqrt( + &mut self, + data: ptx_parser::TypeFtz, + arguments: ptx_parser::RsqrtArgs, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let intrinsic = match data.type_ { + ast::ScalarType::F32 => c"llvm.amdgcn.rsq.f32", + ast::ScalarType::F64 => c"llvm.amdgcn.rsq.f64", + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + 
Some(&data.type_.into()), + vec![(self.resolver.value(arguments.src)?, type_)], + )?; + Ok(()) + } + + fn emit_sqrt( + &mut self, + data: ptx_parser::RcpData, + arguments: ptx_parser::SqrtArgs, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let intrinsic = match (data.type_, data.kind) { + (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.sqrt.f32", + (ast::ScalarType::F32, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f32", + (ast::ScalarType::F64, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f64", + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + Some(&data.type_.into()), + vec![(self.resolver.value(arguments.src)?, type_)], + )?; + Ok(()) + } + + fn emit_rcp( + &mut self, + data: ptx_parser::RcpData, + arguments: ptx_parser::RcpArgs, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let intrinsic = match (data.type_, data.kind) { + (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.rcp.f32", + (_, ast::RcpKind::Compliant(rnd)) => { + return self.emit_rcp_compliant(data, arguments, rnd) + } + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + Some(&data.type_.into()), + vec![(self.resolver.value(arguments.src)?, type_)], + )?; + Ok(()) + } + + fn emit_rcp_compliant( + &mut self, + data: ptx_parser::RcpData, + arguments: ptx_parser::RcpArgs, + _rnd: ast::RoundingMode, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let one = unsafe { LLVMConstReal(type_, 1.0) }; + let src = self.resolver.value(arguments.src)?; + let rcp = self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFDiv(self.builder, one, src, dst) + }); + unsafe { LLVMZludaSetFastMathFlags(rcp, LLVMZludaFastMathAllowReciprocal) }; + Ok(()) + } + + fn emit_shr( + &mut self, + data: ptx_parser::ShrData, + arguments: ptx_parser::ShrArgs, + ) -> Result<(), TranslateError> { + let llvm_type = get_scalar_type(self.context, data.type_); + let (out_of_range, shift_fn): ( + *mut LLVMValue, + unsafe extern "C" fn( + LLVMBuilderRef, + LLVMValueRef, + LLVMValueRef, + *const i8, + ) -> LLVMValueRef, + ) = match data.kind { + ptx_parser::RightShiftKind::Logical => { + (unsafe { LLVMConstNull(llvm_type) }, LLVMBuildLShr) + } + ptx_parser::RightShiftKind::Arithmetic => { + let src1 = self.resolver.value(arguments.src1)?; + let shift_size = + unsafe { LLVMConstInt(llvm_type, (data.type_.size_of() as u64 * 8) - 1, 0) }; + let out_of_range = + unsafe { LLVMBuildAShr(self.builder, src1, shift_size, LLVM_UNNAMED.as_ptr()) }; + (out_of_range, LLVMBuildAShr) + } + }; + self.emit_shift( + data.type_, + arguments.dst, + arguments.src1, + arguments.src2, + out_of_range, + shift_fn, + ) + } + + fn emit_shl( + &mut self, + type_: ptx_parser::ScalarType, + arguments: ptx_parser::ShlArgs, + ) -> Result<(), TranslateError> { + let llvm_type = get_scalar_type(self.context, type_); + self.emit_shift( + type_, + arguments.dst, + arguments.src1, + arguments.src2, + unsafe { LLVMConstNull(llvm_type) }, + LLVMBuildShl, + ) + } + + fn emit_shift( + &mut self, + type_: ast::ScalarType, + dst: SpirvWord, + src1: SpirvWord, + src2: SpirvWord, + out_of_range_value: LLVMValueRef, + llvm_fn: unsafe extern "C" fn( + LLVMBuilderRef, + LLVMValueRef, + LLVMValueRef, + *const i8, + ) -> LLVMValueRef, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(src1)?; + let shift_size = 
self.resolver.value(src2)?; + let integer_bits = type_.layout().size() * 8; + let integer_bits_constant = unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::U32), + integer_bits as u64, + 0, + ) + }; + let should_clamp = unsafe { + LLVMBuildICmp( + self.builder, + LLVMIntPredicate::LLVMIntUGE, + shift_size, + integer_bits_constant, + LLVM_UNNAMED.as_ptr(), + ) + }; + let llvm_type = get_scalar_type(self.context, type_); + let normalized_shift_size = if type_.layout().size() >= 4 { + unsafe { + LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) + } + } else { + unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) } + }; + let shifted = unsafe { + llvm_fn( + self.builder, + src1, + normalized_shift_size, + LLVM_UNNAMED.as_ptr(), + ) + }; + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildSelect(self.builder, should_clamp, out_of_range_value, shifted, dst) + }); + Ok(()) + } + + fn emit_ex2( + &mut self, + data: ptx_parser::TypeFtz, + arguments: ptx_parser::Ex2Args, + ) -> Result<(), TranslateError> { + let intrinsic = match data.type_ { + ast::ScalarType::F16 => c"llvm.amdgcn.exp2.f16", + ast::ScalarType::F32 => c"llvm.amdgcn.exp2.f32", + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + Some(&data.type_.into()), + vec![( + self.resolver.value(arguments.src)?, + get_scalar_type(self.context, data.type_), + )], + )?; + Ok(()) + } + + fn emit_lg2( + &mut self, + _data: ptx_parser::FlushToZero, + arguments: ptx_parser::Lg2Args, + ) -> Result<(), TranslateError> { + self.emit_intrinsic( + c"llvm.amdgcn.log.f32", + Some(arguments.dst), + Some(&ast::ScalarType::F32.into()), + vec![( + self.resolver.value(arguments.src)?, + get_scalar_type(self.context, ast::ScalarType::F32.into()), + )], + )?; + Ok(()) + } + + fn emit_selp( + &mut self, + _data: ptx_parser::ScalarType, + arguments: ptx_parser::SelpArgs, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let src3 = self.resolver.value(arguments.src3)?; + self.resolver.with_result(arguments.dst, |dst_name| unsafe { + LLVMBuildSelect(self.builder, src3, src1, src2, dst_name) + }); + Ok(()) + } + + fn emit_rem( + &mut self, + data: ptx_parser::ScalarType, + arguments: ptx_parser::RemArgs, + ) -> Result<(), TranslateError> { + let llvm_fn = match data.kind() { + ptx_parser::ScalarKind::Unsigned => LLVMBuildURem, + ptx_parser::ScalarKind::Signed => LLVMBuildSRem, + _ => return Err(error_unreachable()), + }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst_name| unsafe { + llvm_fn(self.builder, src1, src2, dst_name) + }); + Ok(()) + } + + fn emit_popc( + &mut self, + type_: ptx_parser::ScalarType, + arguments: ptx_parser::PopcArgs, + ) -> Result<(), TranslateError> { + let intrinsic = match type_ { + ast::ScalarType::B32 => c"llvm.ctpop.i32", + ast::ScalarType::B64 => c"llvm.ctpop.i64", + _ => return Err(error_unreachable()), + }; + let llvm_type = get_scalar_type(self.context, type_); + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + Some(&type_.into()), + vec![(self.resolver.value(arguments.src)?, llvm_type)], + )?; + Ok(()) + } + + fn emit_min( + &mut self, + data: ptx_parser::MinMaxDetails, + arguments: ptx_parser::MinArgs, + ) -> Result<(), TranslateError> { + let llvm_prefix = match data { + 
ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin", + ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin", + ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => { + "llvm.minimum" + } + ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum", + }; + let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_())); + let llvm_type = get_scalar_type(self.context, data.type_()); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + Some(&data.type_().into()), + vec![ + (self.resolver.value(arguments.src1)?, llvm_type), + (self.resolver.value(arguments.src2)?, llvm_type), + ], + )?; + Ok(()) + } + + fn emit_max( + &mut self, + data: ptx_parser::MinMaxDetails, + arguments: ptx_parser::MaxArgs, + ) -> Result<(), TranslateError> { + let llvm_prefix = match data { + ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax", + ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax", + ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => { + "llvm.maximum" + } + ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum", + }; + let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_())); + let llvm_type = get_scalar_type(self.context, data.type_()); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + Some(&data.type_().into()), + vec![ + (self.resolver.value(arguments.src1)?, llvm_type), + (self.resolver.value(arguments.src2)?, llvm_type), + ], + )?; + Ok(()) + } + + fn emit_fma( + &mut self, + data: ptx_parser::ArithFloat, + arguments: ptx_parser::FmaArgs, + ) -> Result<(), TranslateError> { + let intrinsic = format!("llvm.fma.{}\0", LLVMTypeDisplay(data.type_)); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + Some(&data.type_.into()), + vec![ + ( + self.resolver.value(arguments.src1)?, + get_scalar_type(self.context, data.type_), + ), + ( + self.resolver.value(arguments.src2)?, + get_scalar_type(self.context, data.type_), + ), + ( + self.resolver.value(arguments.src3)?, + get_scalar_type(self.context, data.type_), + ), + ], + )?; + Ok(()) + } + + fn emit_mad( + &mut self, + data: ptx_parser::MadDetails, + arguments: ptx_parser::MadArgs, + ) -> Result<(), TranslateError> { + let mul_control = match data { + ptx_parser::MadDetails::Float(mad_float) => { + return self.emit_fma( + mad_float, + ast::FmaArgs { + dst: arguments.dst, + src1: arguments.src1, + src2: arguments.src2, + src3: arguments.src3, + }, + ) + } + ptx_parser::MadDetails::Integer { + saturate: true, + control: ast::MulIntControl::High, + type_: ast::ScalarType::S32, + } => { + return self.emit_mad_hi_sat_s32( + arguments.dst, + (arguments.src1, arguments.src2, arguments.src3), + ); + } + ptx_parser::MadDetails::Integer { saturate: true, .. } => { + return Err(error_unreachable()) + } + ptx_parser::MadDetails::Integer { type_, control, .. 
} => { + ast::MulDetails::Integer { control, type_ } + } + }; + let temp = self.emit_mul_impl(mul_control, None, arguments.src1, arguments.src2)?; + let src3 = self.resolver.value(arguments.src3)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildAdd(self.builder, temp, src3, dst) + }); + Ok(()) + } + + fn emit_membar(&self, data: ptx_parser::MemScope) -> Result<(), TranslateError> { + unsafe { + LLVMZludaBuildFence( + self.builder, + LLVMAtomicOrdering::LLVMAtomicOrderingSequentiallyConsistent, + get_scope_membar(data)?, + LLVM_UNNAMED.as_ptr(), + ) + }; + Ok(()) + } + + fn emit_prmt( + &mut self, + control: u16, + arguments: ptx_parser::PrmtArgs, + ) -> Result<(), TranslateError> { + let components = [ + (control >> 0) & 0b1111, + (control >> 4) & 0b1111, + (control >> 8) & 0b1111, + (control >> 12) & 0b1111, + ]; + if components.iter().any(|&c| c > 7) { + return Err(TranslateError::Todo("".to_string())); + } + let u32_type = get_scalar_type(self.context, ast::ScalarType::U32); + let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?; + let mut components = [ + unsafe { LLVMConstInt(u32_type, components[0] as _, 0) }, + unsafe { LLVMConstInt(u32_type, components[1] as _, 0) }, + unsafe { LLVMConstInt(u32_type, components[2] as _, 0) }, + unsafe { LLVMConstInt(u32_type, components[3] as _, 0) }, + ]; + let components_indices = + unsafe { LLVMConstVector(components.as_mut_ptr(), components.len() as u32) }; + let src1 = self.resolver.value(arguments.src1)?; + let src1_vector = + unsafe { LLVMBuildBitCast(self.builder, src1, v4u8_type, LLVM_UNNAMED.as_ptr()) }; + let src2 = self.resolver.value(arguments.src2)?; + let src2_vector = + unsafe { LLVMBuildBitCast(self.builder, src2, v4u8_type, LLVM_UNNAMED.as_ptr()) }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildShuffleVector( + self.builder, + src1_vector, + src2_vector, + components_indices, + dst, + ) + }); + Ok(()) + } + + fn emit_abs( + &mut self, + data: ast::TypeFtz, + arguments: ptx_parser::AbsArgs, + ) -> Result<(), TranslateError> { + let llvm_type = get_scalar_type(self.context, data.type_); + let src = self.resolver.value(arguments.src)?; + let is_floating_point = data.type_.kind() == ast::ScalarKind::Float; + let (prefix, intrinsic_arguments) = if is_floating_point { + ("llvm.fabs", vec![(src, llvm_type)]) + } else { + let pred = get_scalar_type(self.context, ast::ScalarType::Pred); + let zero = unsafe { LLVMConstInt(pred, 0, 0) }; + ("llvm.abs", vec![(src, llvm_type), (zero, pred)]) + }; + let llvm_intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(data.type_)); + let abs_result = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) }, + None, + Some(&data.type_.into()), + intrinsic_arguments, + )?; + if is_floating_point && data.flush_to_zero == Some(true) { + let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(data.type_)); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + Some(&data.type_.into()), + vec![(abs_result, llvm_type)], + )?; + } else { + self.resolver.register(arguments.dst, abs_result); + } + Ok(()) + } + + fn emit_mul24( + &mut self, + data: ast::Mul24Details, + arguments: ast::Mul24Args, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let name_lo = match data.type_ { + ast::ScalarType::U32 => c"llvm.amdgcn.mul.u24", + 
ast::ScalarType::S32 => c"llvm.amdgcn.mul.i24", + _ => return Err(error_unreachable()), + }; + let res_lo = self.emit_intrinsic( + name_lo, + if data.control == Mul24Control::Lo { + Some(arguments.dst) + } else { + None + }, + Some(&ast::Type::Scalar(data.type_)), + vec![ + (src1, get_scalar_type(self.context, data.type_)), + (src2, get_scalar_type(self.context, data.type_)), + ], + )?; + if data.control == Mul24Control::Hi { + // There is an important difference between NVIDIA's mul24.hi and AMD's mulhi.[ui]24. + // NVIDIA: Returns bits 47..16 of the 64-bit result + // AMD: Returns bits 63..32 of the 64-bit result + // Hence we need to compute both hi and lo, shift the results and add them together to replicate NVIDIA's mul24 + let name_hi = match data.type_ { + ast::ScalarType::U32 => c"llvm.amdgcn.mulhi.u24", + ast::ScalarType::S32 => c"llvm.amdgcn.mulhi.i24", + _ => return Err(error_unreachable()), + }; + let res_hi = self.emit_intrinsic( + name_hi, + None, + Some(&ast::Type::Scalar(data.type_)), + vec![ + (src1, get_scalar_type(self.context, data.type_)), + (src2, get_scalar_type(self.context, data.type_)), + ], + )?; + let shift_number = unsafe { LLVMConstInt(LLVMInt32TypeInContext(self.context), 16, 0) }; + let res_lo_shr = + unsafe { LLVMBuildLShr(self.builder, res_lo, shift_number, LLVM_UNNAMED.as_ptr()) }; + let res_hi_shl = + unsafe { LLVMBuildShl(self.builder, res_hi, shift_number, LLVM_UNNAMED.as_ptr()) }; + + self.resolver + .with_result(arguments.dst, |dst: *const i8| unsafe { + LLVMBuildOr(self.builder, res_lo_shr, res_hi_shl, dst) + }); + } + Ok(()) + } + + fn emit_set_mode(&mut self, mode_reg: ModeRegister) -> Result<(), TranslateError> { + fn hwreg(reg: u32, offset: u32, size: u32) -> u32 { + reg | (offset << 6) | ((size - 1) << 11) + } + fn denormal_to_value(ftz: bool) -> u32 { + if ftz { + 0 + } else { + 3 + } + } + fn rounding_to_value(ftz: ast::RoundingMode) -> u32 { + match ftz { + ptx_parser::RoundingMode::NearestEven => 0, + ptx_parser::RoundingMode::Zero => 3, + ptx_parser::RoundingMode::NegativeInf => 2, + ptx_parser::RoundingMode::PositiveInf => 1, + } + } + fn merge_regs(f32: u32, f16f64: u32) -> u32 { + f32 | f16f64 << 2 + } + let intrinsic = c"llvm.amdgcn.s.setreg"; + let (hwreg, value) = match mode_reg { + ModeRegister::Denormal { f32, f16f64 } => { + let hwreg = hwreg(1, 4, 4); + let f32 = denormal_to_value(f32); + let f16f64 = denormal_to_value(f16f64); + let value = merge_regs(f32, f16f64); + (hwreg, value) + } + ModeRegister::Rounding { f32, f16f64 } => { + let hwreg = hwreg(1, 0, 4); + let f32 = rounding_to_value(f32); + let f16f64 = rounding_to_value(f16f64); + let value = merge_regs(f32, f16f64); + (hwreg, value) + } + }; + let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32); + let hwreg_llvm = unsafe { LLVMConstInt(llvm_i32, hwreg as _, 0) }; + let value_llvm = unsafe { LLVMConstInt(llvm_i32, value as _, 0) }; + self.emit_intrinsic( + intrinsic, + None, + None, + vec![(hwreg_llvm, llvm_i32), (value_llvm, llvm_i32)], + )?; + Ok(()) + } + + fn emit_fp_saturate( + &mut self, + type_: ast::ScalarType, + dst: SpirvWord, + src: SpirvWord, + ) -> Result<(), TranslateError> { + let llvm_type = get_scalar_type(self.context, type_); + let zero = unsafe { LLVMConstReal(llvm_type, 0.0) }; + let one = unsafe { LLVMConstReal(llvm_type, 1.0) }; + let maxnum_intrinsic = format!("llvm.maxnum.{}\0", LLVMTypeDisplay(type_)); + let minnum_intrinsic = format!("llvm.minnum.{}\0", LLVMTypeDisplay(type_)); + let src = self.resolver.value(src)?; + let maxnum 
= self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(maxnum_intrinsic.as_bytes()) }, + None, + Some(&type_.into()), + vec![(src, llvm_type), (zero, llvm_type)], + )?; + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(minnum_intrinsic.as_bytes()) }, + Some(dst), + Some(&type_.into()), + vec![(maxnum, llvm_type), (one, llvm_type)], + )?; + Ok(()) + } + + fn emit_intrinsic_saturate( + &mut self, + op: &str, + type_: ast::ScalarType, + dst: SpirvWord, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result<(), TranslateError> { + let llvm_type = get_scalar_type(self.context, type_); + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + let intrinsic = format!("llvm.{}.sat.{}\0", op, LLVMTypeDisplay(type_)); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(dst), + Some(&type_.into()), + vec![(src1, llvm_type), (src2, llvm_type)], + )?; + Ok(()) + } + + fn emit_cp_async( + &mut self, + data: CpAsyncDetails, + arguments: CpAsyncArgs, + ) -> Result<(), TranslateError> { + // Asynchronous copies are not supported by all AMD hardware, so we just do a synchronous copy for now + let to = self.resolver.value(arguments.src_to)?; + let from = self.resolver.value(arguments.src_from)?; + let cp_size = data.cp_size; + let src_size = data.src_size.unwrap_or(cp_size.as_u64()); + + let from_type = unsafe { LLVMIntTypeInContext(self.context, (src_size as u32) * 8) }; + + let to_type = match cp_size { + ptx_parser::CpAsyncCpSize::Bytes4 => unsafe { LLVMInt32TypeInContext(self.context) }, + ptx_parser::CpAsyncCpSize::Bytes8 => unsafe { LLVMInt64TypeInContext(self.context) }, + ptx_parser::CpAsyncCpSize::Bytes16 => unsafe { LLVMInt128TypeInContext(self.context) }, + }; + + let load = unsafe { LLVMBuildLoad2(self.builder, from_type, from, LLVM_UNNAMED.as_ptr()) }; + unsafe { + LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8); + } + + let extended = unsafe { LLVMBuildZExt(self.builder, load, to_type, LLVM_UNNAMED.as_ptr()) }; + + unsafe { LLVMBuildStore(self.builder, extended, to) }; + unsafe { + LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8); + } + Ok(()) + } + + fn flush_denormals( + &mut self, + type_: ptx_parser::ScalarType, + src: SpirvWord, + dst: SpirvWord, + ) -> Result<(), TranslateError> { + let llvm_type = get_scalar_type(self.context, type_); + let src = self.resolver.value(src)?; + let intrinsic = format!("llvm.canonicalize.{}\0", LLVMTypeDisplay(type_)); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(dst), + Some(&type_.into()), + vec![(src, llvm_type)], + )?; + Ok(()) + } + + fn emit_mad_hi_sat_s32( + &mut self, + dst: SpirvWord, + (src1, src2, src3): (SpirvWord, SpirvWord, SpirvWord), + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + let src3 = self.resolver.value(src3)?; + let llvm_type_s32 = get_scalar_type(self.context, ast::ScalarType::S32); + let llvm_type_s64 = get_scalar_type(self.context, ast::ScalarType::S64); + let src1_wide = + unsafe { LLVMBuildSExt(self.builder, src1, llvm_type_s64, LLVM_UNNAMED.as_ptr()) }; + let src2_wide = + unsafe { LLVMBuildSExt(self.builder, src2, llvm_type_s64, LLVM_UNNAMED.as_ptr()) }; + let mul_wide = + unsafe { LLVMBuildMul(self.builder, src1_wide, src2_wide, LLVM_UNNAMED.as_ptr()) }; + let const_32 = unsafe { LLVMConstInt(llvm_type_s64, 32, 0) }; + let mul_wide = + unsafe { LLVMBuildLShr(self.builder, mul_wide, 
const_32, LLVM_UNNAMED.as_ptr()) }; + let mul_narrow = + unsafe { LLVMBuildTrunc(self.builder, mul_wide, llvm_type_s32, LLVM_UNNAMED.as_ptr()) }; + self.emit_intrinsic( + c"llvm.sadd.sat.i32", + Some(dst), + Some(&ast::ScalarType::S32.into()), + vec![(mul_narrow, llvm_type_s32), (src3, llvm_type_s32)], + )?; + Ok(()) + } + + fn emit_set( + &mut self, + data: ptx_parser::SetData, + arguments: ptx_parser::SetArgs, + ) -> Result<(), TranslateError> { + let setp_result = self.emit_setp_impl(data.base, None, arguments.src1, arguments.src2)?; + self.setp_to_set(arguments.dst, data.dtype, setp_result)?; + Ok(()) + } + + fn emit_set_bool( + &mut self, + data: ptx_parser::SetBoolData, + arguments: ptx_parser::SetBoolArgs, + ) -> Result<(), TranslateError> { + let result = + self.emit_setp_bool_impl(data.base, arguments.src1, arguments.src2, arguments.src3)?; + self.setp_to_set(arguments.dst, data.dtype, result)?; + Ok(()) + } + + fn emit_setp_bool( + &mut self, + data: ast::SetpBoolData, + args: ast::SetpBoolArgs, + ) -> Result<(), TranslateError> { + let dst = self.emit_setp_bool_impl(data, args.src1, args.src2, args.src3)?; + self.resolver.register(args.dst1, dst); + Ok(()) + } + + fn emit_setp_bool_impl( + &mut self, + data: ptx_parser::SetpBoolData, + src1: SpirvWord, + src2: SpirvWord, + src3: SpirvWord, + ) -> Result { + let bool_result = self.emit_setp_impl(data.base, None, src1, src2)?; + let bool_result = if data.negate_src3 { + let constant = + unsafe { LLVMConstInt(LLVMIntTypeInContext(self.context, 1), u64::MAX, 0) }; + unsafe { LLVMBuildXor(self.builder, bool_result, constant, LLVM_UNNAMED.as_ptr()) } + } else { + bool_result + }; + let post_op = match data.bool_op { + ptx_parser::SetpBoolPostOp::Xor => LLVMBuildXor, + ptx_parser::SetpBoolPostOp::And => LLVMBuildAnd, + ptx_parser::SetpBoolPostOp::Or => LLVMBuildOr, + }; + let src3 = self.resolver.value(src3)?; + Ok(unsafe { post_op(self.builder, bool_result, src3, LLVM_UNNAMED.as_ptr()) }) + } + + fn setp_to_set( + &mut self, + dst: SpirvWord, + dtype: ast::ScalarType, + setp_result: LLVMValueRef, + ) -> Result<(), TranslateError> { + let llvm_dtype = get_scalar_type(self.context, dtype); + let zero = unsafe { LLVMConstNull(llvm_dtype) }; + let one = if dtype.kind() == ast::ScalarKind::Float { + unsafe { LLVMConstReal(llvm_dtype, 1.0) } + } else { + unsafe { LLVMConstInt(llvm_dtype, u64::MAX, 0) } + }; + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildSelect(self.builder, setp_result, one, zero, dst) + }); + Ok(()) + } + + // TODO: revisit this on gfx1250 which has native tanh support + fn emit_tanh( + &mut self, + data: ast::ScalarType, + arguments: ast::TanhArgs, + ) -> Result<(), TranslateError> { + let src = self.resolver.value(arguments.src)?; + let llvm_type = get_scalar_type(self.context, data); + let name = format!("__ocml_tanh_{}\0", LLVMTypeDisplay(data)); + let tanh = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(name.as_bytes()) }, + Some(arguments.dst), + Some(&data.into()), + vec![(src, llvm_type)], + )?; + // Not sure if it ultimately does anything + unsafe { LLVMZludaSetFastMathFlags(tanh, LLVMZludaFastMathApproxFunc) } + Ok(()) + } + + /* + // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` + // Should be available in LLVM 19 + fn with_rounding(&mut self, rnd: ast::RoundingMode, fn_: impl FnOnce(&mut Self) -> T) -> T { + let mut u32_type = get_scalar_type(self.context, ast::ScalarType::U32); + let void_type = unsafe { LLVMVoidTypeInContext(self.context) }; + let 
get_rounding = c"llvm.get.rounding"; + let get_rounding_fn_type = unsafe { LLVMFunctionType(u32_type, ptr::null_mut(), 0, 0) }; + let mut get_rounding_fn = + unsafe { LLVMGetNamedFunction(self.module, get_rounding.as_ptr()) }; + if get_rounding_fn == ptr::null_mut() { + get_rounding_fn = unsafe { + LLVMAddFunction(self.module, get_rounding.as_ptr(), get_rounding_fn_type) + }; + } + let set_rounding = c"llvm.set.rounding"; + let set_rounding_fn_type = unsafe { LLVMFunctionType(void_type, &mut u32_type, 1, 0) }; + let mut set_rounding_fn = + unsafe { LLVMGetNamedFunction(self.module, set_rounding.as_ptr()) }; + if set_rounding_fn == ptr::null_mut() { + set_rounding_fn = unsafe { + LLVMAddFunction(self.module, set_rounding.as_ptr(), set_rounding_fn_type) + }; + } + let mut preserved_rounding_mode = unsafe { + LLVMBuildCall2( + self.builder, + get_rounding_fn_type, + get_rounding_fn, + ptr::null_mut(), + 0, + LLVM_UNNAMED.as_ptr(), + ) + }; + let mut requested_rounding = unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::B32), + rounding_to_llvm(rnd) as u64, + 0, + ) + }; + unsafe { + LLVMBuildCall2( + self.builder, + set_rounding_fn_type, + set_rounding_fn, + &mut requested_rounding, + 1, + LLVM_UNNAMED.as_ptr(), + ) + }; + let result = fn_(self); + unsafe { + LLVMBuildCall2( + self.builder, + set_rounding_fn_type, + set_rounding_fn, + &mut preserved_rounding_mode, + 1, + LLVM_UNNAMED.as_ptr(), + ) + }; + result + } + */ +} + +fn get_pointer_type<'ctx>( + context: LLVMContextRef, + to_space: ast::StateSpace, +) -> Result { + Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) }) +} + +// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes +fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> { + Ok(match scope { + ast::MemScope::Cta => c"workgroup-one-as", + ast::MemScope::Gpu => c"agent-one-as", + ast::MemScope::Sys => c"one-as", + ast::MemScope::Cluster => todo!(), + } + .as_ptr()) +} + +fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> { + Ok(match scope { + ast::MemScope::Cta => c"workgroup", + ast::MemScope::Gpu => c"agent", + ast::MemScope::Sys => c"", + ast::MemScope::Cluster => todo!(), + } + .as_ptr()) +} + +fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering { + match semantics { + ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic, + ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, + ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingRelease, + ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquireRelease, + } +} + +fn get_ordering_failure(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering { + match semantics { + ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic, + ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, + ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, + ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, + } +} + +fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result { + Ok(match type_ { + ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar), + ast::Type::Vector(size, scalar) => { + let base_type = get_scalar_type(context, *scalar); + unsafe { LLVMVectorType(base_type, *size as u32) } + } + ast::Type::Array(vec, scalar, dimensions) => { + let mut underlying_type = get_scalar_type(context, *scalar); + if let Some(size) = vec { + 
underlying_type = unsafe { LLVMVectorType(underlying_type, size.get() as u32) }; + } + if dimensions.is_empty() { + return Ok(unsafe { LLVMArrayType2(underlying_type, 0) }); + } + dimensions + .iter() + .rfold(underlying_type, |result, dimension| unsafe { + LLVMArrayType2(result, *dimension as u64) + }) + } + }) +} + +fn get_array_type<'a>( + context: LLVMContextRef, + elem_type: &'a ast::Type, + count: u64, +) -> Result { + let elem_type = get_type(context, elem_type)?; + Ok(unsafe { LLVMArrayType2(elem_type, count) }) +} + +fn check_multiple_return_types<'a>( + mut return_args: impl ExactSizeIterator, +) -> Result<(), TranslateError> { + let err_msg = "Only (.b32, .pred) multiple return types are supported"; + + let first = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?; + let second = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?; + match (first, second) { + (ast::Type::Scalar(first), ast::Type::Scalar(second)) => { + if first.size_of() != 4 || second.size_of() != 1 { + return Err(error_todo_msg(err_msg)); + } + } + _ => return Err(error_todo_msg(err_msg)), + } + Ok(()) +} + +fn get_function_type<'a>( + context: LLVMContextRef, + mut return_args: impl ExactSizeIterator, + input_args: impl ExactSizeIterator>, +) -> Result { + let mut input_args = input_args.collect::, _>>()?; + let return_type = match return_args.len() { + 0 => unsafe { LLVMVoidTypeInContext(context) }, + 1 => get_type(context, &return_args.next().unwrap())?, + _ => { + check_multiple_return_types(return_args)?; + get_array_type(context, &ast::Type::Scalar(ast::ScalarType::B32), 2)? + } + }; + + Ok(unsafe { + LLVMFunctionType( + return_type, + input_args.as_mut_ptr(), + input_args.len() as u32, + 0, + ) + }) +} + +struct ResolveIdent { + words: HashMap, + values: HashMap, +} + +impl ResolveIdent { + fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self { + ResolveIdent { + words: HashMap::new(), + values: HashMap::new(), + } + } + + fn get_or_ad_impl<'a, T>(&'a mut self, word: SpirvWord, fn_: impl FnOnce(&'a str) -> T) -> T { + let str = match self.words.entry(word) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + let mut text = word.0.to_string(); + text.push('\0'); + entry.insert(text) + } + }; + fn_(&str[..str.len() - 1]) + } + + fn get_or_add(&mut self, word: SpirvWord) -> &str { + self.get_or_ad_impl(word, |x| x) + } + + fn get_or_add_raw(&mut self, word: SpirvWord) -> *const i8 { + self.get_or_add(word).as_ptr().cast() + } + + fn register(&mut self, word: SpirvWord, v: LLVMValueRef) { + self.values.insert(word, v); + } + + fn value(&self, word: SpirvWord) -> Result { + self.values + .get(&word) + .copied() + .ok_or_else(|| error_unreachable()) + } + + fn with_result( + &mut self, + word: SpirvWord, + fn_: impl FnOnce(*const i8) -> LLVMValueRef, + ) -> LLVMValueRef { + let t = self.get_or_ad_impl(word, |dst| fn_(dst.as_ptr().cast())); + self.register(word, t); + t + } + + fn with_result_option( + &mut self, + word: Option, + fn_: impl FnOnce(*const i8) -> LLVMValueRef, + ) -> LLVMValueRef { + match word { + Some(word) => self.with_result(word, fn_), + None => fn_(LLVM_UNNAMED.as_ptr()), + } + } +} + +struct LLVMTypeDisplay(ast::ScalarType); + +impl std::fmt::Display for LLVMTypeDisplay { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + ast::ScalarType::Pred => write!(f, "i1"), + ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"), + ast::ScalarType::B16 | 
ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"), + ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"), + ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"), + ptx_parser::ScalarType::B128 => write!(f, "i128"), + ast::ScalarType::F16 => write!(f, "f16"), + ptx_parser::ScalarType::BF16 => write!(f, "bfloat"), + ast::ScalarType::F32 => write!(f, "f32"), + ast::ScalarType::F64 => write!(f, "f64"), + ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"), + ast::ScalarType::F16x2 => write!(f, "v2f16"), + ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"), + } + } +} + +/* +fn rounding_to_llvm(this: ast::RoundingMode) -> u32 { + match this { + ptx_parser::RoundingMode::Zero => 0, + ptx_parser::RoundingMode::NearestEven => 1, + ptx_parser::RoundingMode::PositiveInf => 2, + ptx_parser::RoundingMode::NegativeInf => 3, + } +} +*/ diff --git a/ptx/src/pass/llvm/mod.rs b/ptx/src/pass/llvm/mod.rs index daaa91f..a40e38a 100644 --- a/ptx/src/pass/llvm/mod.rs +++ b/ptx/src/pass/llvm/mod.rs @@ -1,5 +1,5 @@ -pub(super) mod emit; pub(super) mod attributes; +pub(super) mod emit; use std::ffi::CStr; use std::ops::Deref; @@ -44,9 +44,7 @@ pub struct Module(LLVMModuleRef); impl Module { fn new(ctx: &Context, name: &CStr) -> Self { - Self( - unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) }, - ) + Self(unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) }) } fn get(&self) -> LLVMModuleRef { diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index ace910e..c10eb56 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,963 +1,966 @@ -use ptx_parser as ast; -use quick_error::quick_error; -use rustc_hash::FxHashMap; -use std::hash::Hash; -use std::{ - borrow::Cow, - collections::{hash_map, HashMap}, - ffi::CString, - iter, -}; -use strum::IntoEnumIterator; -use strum_macros::EnumIter; - -mod deparamize_functions; -mod expand_operands; -mod fix_special_registers2; -mod hoist_globals; -mod insert_explicit_load_store; -mod insert_implicit_conversions2; -mod insert_post_saturation; -mod instruction_mode_to_global_mode; -mod llvm; -mod normalize_basic_blocks; -mod normalize_identifiers2; -mod normalize_predicates2; -mod remove_unreachable_basic_blocks; -mod replace_instructions_with_function_calls; -mod replace_known_functions; -mod resolve_function_pointers; - -static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); -const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_"; - -quick_error! { - #[derive(Debug)] - pub enum TranslateError { - UnknownSymbol(symbol: String) { - display("Unknown symbol: \"{}\"", symbol) - } - UntypedSymbol {} - MismatchedType {} - Unreachable {} - Todo(msg: String) { - display("TODO: {}", msg) - } - } -} - -/// GPU attributes needed at compile time. -pub struct Attributes { - /// Clock frequency in kHz. 
- pub clock_rate: u32, -} - -pub fn to_llvm_module<'input>(ast: ast::Module<'input>, attributes: Attributes) -> Result { - let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); - let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); - let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?; - let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; - let directives = replace_known_functions::run(&mut flat_resolver, directives); - let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; - let directives = resolve_function_pointers::run(directives)?; - let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?; - let directives = expand_operands::run(&mut flat_resolver, directives)?; - let directives = insert_post_saturation::run(&mut flat_resolver, directives)?; - let directives = deparamize_functions::run(&mut flat_resolver, directives)?; - let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?; - let directives = remove_unreachable_basic_blocks::run(directives)?; - let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?; - let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; - let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?; - let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?; - let directives = hoist_globals::run(directives)?; - - let context = llvm::Context::new(); - let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?; - let attributes_ir = llvm::attributes::run(&context, attributes)?; - Ok(Module { - llvm_ir, - attributes_ir, - kernel_info: HashMap::new(), - _context: context, - }) -} - -pub struct Module { - pub llvm_ir: llvm::Module, - pub attributes_ir: llvm::Module, - pub kernel_info: HashMap, - _context: llvm::Context, -} - -impl Module { - pub fn linked_bitcode(&self) -> &[u8] { - ZLUDA_PTX_IMPL - } -} - -pub struct KernelInfo { - pub arguments_sizes: Vec<(usize, bool)>, - pub uses_shared_mem: bool, -} - -#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)] -enum PtxSpecialRegister { - Tid, - Ntid, - Ctaid, - Nctaid, - Clock, - LanemaskLt, -} - -impl PtxSpecialRegister { - fn as_str(self) -> &'static str { - match self { - Self::Tid => "%tid", - Self::Ntid => "%ntid", - Self::Ctaid => "%ctaid", - Self::Nctaid => "%nctaid", - Self::Clock => "%clock", - Self::LanemaskLt => "%lanemask_lt", - } - } - - fn get_type(self) -> ast::Type { - match self { - PtxSpecialRegister::Tid - | PtxSpecialRegister::Ntid - | PtxSpecialRegister::Ctaid - | PtxSpecialRegister::Nctaid => ast::Type::Vector(4, self.get_function_return_type()), - _ => ast::Type::Scalar(self.get_function_return_type()), - } - } - - fn get_function_return_type(self) -> ast::ScalarType { - match self { - PtxSpecialRegister::Tid => ast::ScalarType::U32, - PtxSpecialRegister::Ntid => ast::ScalarType::U32, - PtxSpecialRegister::Ctaid => ast::ScalarType::U32, - PtxSpecialRegister::Nctaid => ast::ScalarType::U32, - PtxSpecialRegister::Clock => ast::ScalarType::U32, - PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32, - } - } - - fn get_function_input_type(self) -> Option { - match self { - PtxSpecialRegister::Tid - | PtxSpecialRegister::Ntid - | PtxSpecialRegister::Ctaid - | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8), - PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None, - } - } 
- - fn get_unprefixed_function_name(self) -> &'static str { - match self { - PtxSpecialRegister::Tid => "sreg_tid", - PtxSpecialRegister::Ntid => "sreg_ntid", - PtxSpecialRegister::Ctaid => "sreg_ctaid", - PtxSpecialRegister::Nctaid => "sreg_nctaid", - PtxSpecialRegister::Clock => "sreg_clock", - PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt", - } - } -} - -#[cfg(debug_assertions)] -fn error_unreachable() -> TranslateError { - unreachable!() -} - -#[cfg(not(debug_assertions))] -fn error_unreachable() -> TranslateError { - TranslateError::Unreachable -} - -#[cfg(debug_assertions)] -fn error_todo_msg>(msg: T) -> TranslateError { - unreachable!("{}", msg.into()) -} - -#[cfg(not(debug_assertions))] -fn error_todo_msg>(msg: T) -> TranslateError { - TranslateError::Todo(msg.into()) -} - -#[cfg(debug_assertions)] -fn error_todo() -> TranslateError { - unreachable!() -} - -#[cfg(not(debug_assertions))] -fn error_todo() -> TranslateError { - TranslateError::Todo("".to_string()) -} - -#[cfg(debug_assertions)] -fn error_unknown_symbol>(symbol: T) -> TranslateError { - panic!("Unknown symbol: \"{}\"", symbol.into()) -} - -#[cfg(not(debug_assertions))] -fn error_unknown_symbol>(symbol: T) -> TranslateError { - TranslateError::UnknownSymbol(symbol.into()) -} - -#[cfg(debug_assertions)] -fn error_mismatched_type() -> TranslateError { - panic!() -} - -#[cfg(not(debug_assertions))] -fn error_mismatched_type() -> TranslateError { - TranslateError::MismatchedType -} - -enum Statement { - Label(SpirvWord), - Variable(ast::Variable), - Instruction(I), - // SPIR-V compatible replacement for PTX predicates - Conditional(BrachCondition), - Conversion(ImplicitConversion), - Constant(ConstantDefinition), - RetValue(ast::RetData, Vec<(SpirvWord, ast::Type)>), - PtrAccess(PtrAccess
<T>
), - RepackVector(RepackVectorDetails), - FunctionPointer(FunctionPointerDetails), - VectorRead(VectorRead), - VectorWrite(VectorWrite), - SetMode(ModeRegister), - FpSaturate { - dst: SpirvWord, - src: SpirvWord, - type_: ast::ScalarType, - }, -} - -#[derive(Eq, PartialEq, Clone, Copy)] -#[cfg_attr(test, derive(Debug))] -enum ModeRegister { - Denormal { - f32: bool, - f16f64: bool, - }, - Rounding { - f32: ast::RoundingMode, - f16f64: ast::RoundingMode, - }, -} - -impl> Statement, T> { - fn visit_map, Err>( - self, - visitor: &mut impl ast::VisitorMap, - ) -> std::result::Result, To>, Err> { - Ok(match self { - Statement::Instruction(i) => { - return ast::visit_map(i, visitor).map(Statement::Instruction) - } - Statement::Label(label) => { - Statement::Label(visitor.visit_ident(label, None, false, false)?) - } - Statement::Variable(var) => { - let name = visitor.visit_ident( - var.name, - Some((&var.v_type, var.state_space)), - true, - false, - )?; - Statement::Variable(ast::Variable { - align: var.align, - v_type: var.v_type, - state_space: var.state_space, - name, - array_init: var.array_init, - }) - } - Statement::Conditional(conditional) => { - let predicate = visitor.visit_ident( - conditional.predicate, - Some((&ast::ScalarType::Pred.into(), ast::StateSpace::Reg)), - false, - false, - )?; - let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?; - let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?; - Statement::Conditional(BrachCondition { - predicate, - if_true, - if_false, - }) - } - Statement::Conversion(ImplicitConversion { - src, - dst, - from_type, - to_type, - from_space, - to_space, - kind, - }) => { - let dst = visitor.visit_ident( - dst, - Some((&to_type, ast::StateSpace::Reg)), - true, - false, - )?; - let src = visitor.visit_ident( - src, - Some((&from_type, ast::StateSpace::Reg)), - false, - false, - )?; - Statement::Conversion(ImplicitConversion { - src, - dst, - from_type, - to_type, - from_space, - to_space, - kind, - }) - } - Statement::Constant(ConstantDefinition { dst, typ, value }) => { - let dst = visitor.visit_ident( - dst, - Some((&typ.into(), ast::StateSpace::Reg)), - true, - false, - )?; - Statement::Constant(ConstantDefinition { dst, typ, value }) - } - Statement::RetValue(data, value) => { - let value = value - .into_iter() - .map(|(ident, type_)| { - Ok(( - visitor.visit_ident( - ident, - Some((&type_, ast::StateSpace::Local)), - false, - false, - )?, - type_, - )) - }) - .collect::, _>>()?; - Statement::RetValue(data, value) - } - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src, - offset_src, - }) => { - let dst = - visitor.visit_ident(dst, Some((&underlying_type, state_space)), true, false)?; - let ptr_src = visitor.visit_ident( - ptr_src, - Some((&underlying_type, state_space)), - false, - false, - )?; - let offset_src = visitor.visit( - offset_src, - Some(( - &ast::Type::Scalar(ast::ScalarType::S64), - ast::StateSpace::Reg, - )), - false, - false, - )?; - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src, - offset_src, - }) - } - Statement::VectorRead(VectorRead { - scalar_type, - vector_width, - scalar_dst: dst, - vector_src, - member, - }) => { - let scalar_t = scalar_type.into(); - let vector_t = ast::Type::Vector(vector_width, scalar_type); - let dst: SpirvWord = visitor.visit_ident( - dst, - Some((&scalar_t, ast::StateSpace::Reg)), - true, - false, - )?; - let src = visitor.visit_ident( - vector_src, - Some((&vector_t, 
ast::StateSpace::Reg)), - false, - false, - )?; - Statement::VectorRead(VectorRead { - scalar_type, - vector_width, - scalar_dst: dst, - vector_src: src, - member, - }) - } - Statement::VectorWrite(VectorWrite { - scalar_type, - vector_width, - vector_dst, - vector_src, - scalar_src, - member, - }) => { - let scalar_t = scalar_type.into(); - let vector_t = ast::Type::Vector(vector_width, scalar_type); - let vector_dst = visitor.visit_ident( - vector_dst, - Some((&vector_t, ast::StateSpace::Reg)), - true, - false, - )?; - let vector_src = visitor.visit_ident( - vector_src, - Some((&vector_t, ast::StateSpace::Reg)), - false, - false, - )?; - let scalar_src = visitor.visit_ident( - scalar_src, - Some((&scalar_t, ast::StateSpace::Reg)), - false, - false, - )?; - Statement::VectorWrite(VectorWrite { - vector_dst, - vector_src, - scalar_src, - scalar_type, - vector_width, - member, - }) - } - Statement::RepackVector(RepackVectorDetails { - is_extract, - typ, - packed, - unpacked, - relaxed_type_check, - }) => { - let (packed, unpacked) = if is_extract { - let unpacked = unpacked - .into_iter() - .map(|ident| { - visitor.visit_ident( - ident, - Some((&typ.into(), ast::StateSpace::Reg)), - true, - relaxed_type_check, - ) - }) - .collect::, _>>()?; - let packed = visitor.visit_ident( - packed, - Some(( - &ast::Type::Vector(unpacked.len() as u8, typ), - ast::StateSpace::Reg, - )), - false, - false, - )?; - (packed, unpacked) - } else { - let packed = visitor.visit_ident( - packed, - Some(( - &ast::Type::Vector(unpacked.len() as u8, typ), - ast::StateSpace::Reg, - )), - true, - false, - )?; - let unpacked = unpacked - .into_iter() - .map(|ident| { - visitor.visit_ident( - ident, - Some((&typ.into(), ast::StateSpace::Reg)), - false, - relaxed_type_check, - ) - }) - .collect::, _>>()?; - (packed, unpacked) - }; - Statement::RepackVector(RepackVectorDetails { - is_extract, - typ, - packed, - unpacked, - relaxed_type_check, - }) - } - Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { - let dst = visitor.visit_ident( - dst, - Some(( - &ast::Type::Scalar(ast::ScalarType::U64), - ast::StateSpace::Reg, - )), - true, - false, - )?; - let src = visitor.visit_ident(src, None, false, false)?; - Statement::FunctionPointer(FunctionPointerDetails { dst, src }) - } - Statement::SetMode(mode_register) => Statement::SetMode(mode_register), - Statement::FpSaturate { dst, src, type_ } => { - let dst = visitor.visit_ident( - dst, - Some((&type_.into(), ast::StateSpace::Reg)), - true, - false, - )?; - let src = visitor.visit_ident( - src, - Some((&type_.into(), ast::StateSpace::Reg)), - false, - false, - )?; - Statement::FpSaturate { dst, src, type_ } - } - }) - } -} - -struct BrachCondition { - predicate: SpirvWord, - if_true: SpirvWord, - if_false: SpirvWord, -} - -#[derive(Clone)] -struct ImplicitConversion { - src: SpirvWord, - dst: SpirvWord, - from_type: ast::Type, - to_type: ast::Type, - from_space: ast::StateSpace, - to_space: ast::StateSpace, - kind: ConversionKind, -} - -#[derive(PartialEq, Clone)] -enum ConversionKind { - Default, - // zero-extend/chop/bitcast depending on types - SignExtend, - BitToPtr, - PtrToPtr, - AddressOf, -} - -struct ConstantDefinition { - pub dst: SpirvWord, - pub typ: ast::ScalarType, - pub value: ast::ImmediateValue, -} - -pub struct PtrAccess { - underlying_type: ast::Type, - state_space: ast::StateSpace, - dst: SpirvWord, - ptr_src: SpirvWord, - offset_src: T, -} - -struct RepackVectorDetails { - is_extract: bool, - typ: ast::ScalarType, - packed: SpirvWord, - 
unpacked: Vec, - relaxed_type_check: bool, -} - -struct FunctionPointerDetails { - dst: SpirvWord, - src: SpirvWord, -} - -#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)] -pub struct SpirvWord(u32); - -impl From for SpirvWord { - fn from(value: u32) -> Self { - Self(value) - } -} -impl From for u32 { - fn from(value: SpirvWord) -> Self { - value.0 - } -} - -impl ast::Operand for SpirvWord { - type Ident = Self; - - fn from_ident(ident: Self::Ident) -> Self { - ident - } -} - -type ExpandedStatement = Statement, SpirvWord>; - -type NormalizedStatement = Statement< - ( - Option>, - ast::Instruction>, - ), - ast::ParsedOperand, ->; - -enum Directive2 { - Variable(ast::LinkingDirective, ast::Variable), - Method(Function2), -} - -struct Function2 { - pub return_arguments: Vec>, - pub name: Operand::Ident, - pub input_arguments: Vec>, - pub body: Option>>, - is_kernel: bool, - import_as: Option, - tuning: Vec, - linkage: ast::LinkingDirective, - flush_to_zero_f32: bool, - flush_to_zero_f16f64: bool, - rounding_mode_f32: ast::RoundingMode, - rounding_mode_f16f64: ast::RoundingMode, -} - -type NormalizedDirective2 = Directive2< - ( - Option>, - ast::Instruction>, - ), - ast::ParsedOperand, ->; - -type NormalizedFunction2 = Function2< - ( - Option>, - ast::Instruction>, - ), - ast::ParsedOperand, ->; - -type UnconditionalDirective = - Directive2>, ast::ParsedOperand>; - -type UnconditionalFunction = - Function2>, ast::ParsedOperand>; - -struct GlobalStringIdentResolver2<'input> { - pub(crate) current_id: SpirvWord, - pub(crate) ident_map: FxHashMap>, -} - -impl<'input> GlobalStringIdentResolver2<'input> { - fn new(spirv_word: SpirvWord) -> Self { - Self { - current_id: spirv_word, - ident_map: FxHashMap::default(), - } - } - - fn register_named( - &mut self, - name: Cow<'input, str>, - type_space: Option<(ast::Type, ast::StateSpace)>, - ) -> SpirvWord { - let new_id = self.current_id; - self.ident_map.insert( - new_id, - IdentEntry { - name: Some(name), - type_space, - }, - ); - self.current_id.0 += 1; - new_id - } - - fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord { - let new_id = self.current_id; - self.ident_map.insert( - new_id, - IdentEntry { - name: None, - type_space, - }, - ); - self.current_id.0 += 1; - new_id - } - - fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> { - match self.ident_map.get(&id) { - Some(IdentEntry { - type_space: Some(type_space), - .. 
- }) => Ok(type_space), - _ => Err(error_unknown_symbol(format!("{:?}", id))), - } - } -} - -struct IdentEntry<'input> { - name: Option>, - type_space: Option<(ast::Type, ast::StateSpace)>, -} - -struct ScopedResolver<'input, 'b> { - flat_resolver: &'b mut GlobalStringIdentResolver2<'input>, - scopes: Vec>, -} - -impl<'input, 'b> ScopedResolver<'input, 'b> { - fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self { - Self { - flat_resolver, - scopes: vec![ScopeMarker::new()], - } - } - - fn start_scope(&mut self) { - self.scopes.push(ScopeMarker::new()); - } - - fn end_scope(&mut self) { - let scope = self.scopes.pop().unwrap(); - scope.flush(self.flat_resolver); - } - - fn add_or_get_in_current_scope_untyped( - &mut self, - name: &'input str, - ) -> Result { - let current_scope = self.scopes.last_mut().unwrap(); - Ok( - match current_scope.name_to_ident.entry(Cow::Borrowed(name)) { - hash_map::Entry::Occupied(occupied_entry) => { - let ident = *occupied_entry.get(); - let entry = current_scope - .ident_map - .get(&ident) - .ok_or_else(|| error_unreachable())?; - if entry.type_space.is_some() { - return Err(error_unknown_symbol(name)); - } - ident - } - hash_map::Entry::Vacant(vacant_entry) => { - let new_id = self.flat_resolver.current_id; - self.flat_resolver.current_id.0 += 1; - vacant_entry.insert(new_id); - current_scope.ident_map.insert( - new_id, - IdentEntry { - name: Some(Cow::Borrowed(name)), - type_space: None, - }, - ); - new_id - } - }, - ) - } - - fn add( - &mut self, - name: Cow<'input, str>, - type_space: Option<(ast::Type, ast::StateSpace)>, - ) -> Result { - let result = self.flat_resolver.current_id; - self.flat_resolver.current_id.0 += 1; - let current_scope = self.scopes.last_mut().unwrap(); - if current_scope - .name_to_ident - .insert(name.clone(), result) - .is_some() - { - return Err(error_unknown_symbol(name)); - } - current_scope.ident_map.insert( - result, - IdentEntry { - name: Some(name), - type_space, - }, - ); - Ok(result) - } - - fn get(&mut self, name: &str) -> Result { - self.scopes - .iter() - .rev() - .find_map(|resolver| resolver.name_to_ident.get(name).copied()) - .ok_or_else(|| error_unknown_symbol(name)) - } - - fn get_in_current_scope(&self, label: &'input str) -> Result { - let current_scope = self.scopes.last().unwrap(); - current_scope - .name_to_ident - .get(label) - .copied() - .ok_or_else(|| error_unreachable()) - } -} - -struct ScopeMarker<'input> { - ident_map: FxHashMap>, - name_to_ident: FxHashMap, SpirvWord>, -} - -impl<'input> ScopeMarker<'input> { - fn new() -> Self { - Self { - ident_map: FxHashMap::default(), - name_to_ident: FxHashMap::default(), - } - } - - fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) { - resolver.ident_map.extend(self.ident_map); - } -} - -struct SpecialRegistersMap2 { - reg_to_id: FxHashMap, - id_to_reg: FxHashMap, -} - -impl SpecialRegistersMap2 { - fn new(resolver: &mut ScopedResolver) -> Result { - let mut result = SpecialRegistersMap2 { - reg_to_id: FxHashMap::default(), - id_to_reg: FxHashMap::default(), - }; - for sreg in PtxSpecialRegister::iter() { - let text = sreg.as_str(); - let id = resolver.add( - Cow::Borrowed(text), - Some((sreg.get_type(), ast::StateSpace::Reg)), - )?; - result.reg_to_id.insert(sreg, id); - result.id_to_reg.insert(id, sreg); - } - Ok(result) - } - - fn get(&self, id: SpirvWord) -> Option { - self.id_to_reg.get(&id).copied() - } - - fn len() -> usize { - PtxSpecialRegister::iter().len() - } - - fn foreach_declaration<'a, 'input>( - resolver: 
&'a mut GlobalStringIdentResolver2<'input>, - mut fn_: impl FnMut( - PtxSpecialRegister, - ( - Vec>, - SpirvWord, - Vec>, - ), - ), - ) { - for sreg in PtxSpecialRegister::iter() { - let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); - let name = resolver.register_named(Cow::Owned(external_fn_name), None); - let return_type = sreg.get_function_return_type(); - let input_type = sreg.get_function_input_type(); - let return_arguments = vec![ast::Variable { - align: None, - v_type: return_type.into(), - state_space: ast::StateSpace::Reg, - name: resolver.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))), - array_init: Vec::new(), - }]; - let input_arguments = input_type - .into_iter() - .map(|type_| ast::Variable { - align: None, - v_type: type_.into(), - state_space: ast::StateSpace::Reg, - name: resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))), - array_init: Vec::new(), - }) - .collect::>(); - fn_(sreg, (return_arguments, name, input_arguments)); - } - } -} - -pub struct VectorRead { - scalar_type: ast::ScalarType, - vector_width: u8, - scalar_dst: SpirvWord, - vector_src: SpirvWord, - member: u8, -} - -pub struct VectorWrite { - scalar_type: ast::ScalarType, - vector_width: u8, - vector_dst: SpirvWord, - vector_src: SpirvWord, - scalar_src: SpirvWord, - member: u8, -} - -fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str { - match this { - ast::ScalarType::B8 => "b8", - ast::ScalarType::B16 => "b16", - ast::ScalarType::B32 => "b32", - ast::ScalarType::B64 => "b64", - ast::ScalarType::B128 => "b128", - ast::ScalarType::U8 => "u8", - ast::ScalarType::U16 => "u16", - ast::ScalarType::U16x2 => "u16x2", - ast::ScalarType::U32 => "u32", - ast::ScalarType::U64 => "u64", - ast::ScalarType::S8 => "s8", - ast::ScalarType::S16 => "s16", - ast::ScalarType::S16x2 => "s16x2", - ast::ScalarType::S32 => "s32", - ast::ScalarType::S64 => "s64", - ast::ScalarType::F16 => "f16", - ast::ScalarType::F16x2 => "f16x2", - ast::ScalarType::F32 => "f32", - ast::ScalarType::F64 => "f64", - ast::ScalarType::BF16 => "bf16", - ast::ScalarType::BF16x2 => "bf16x2", - ast::ScalarType::Pred => "pred", - } -} - -type UnconditionalStatement = - Statement>, ast::ParsedOperand>; +use ptx_parser as ast; +use quick_error::quick_error; +use rustc_hash::FxHashMap; +use std::hash::Hash; +use std::{ + borrow::Cow, + collections::{hash_map, HashMap}, + ffi::CString, + iter, +}; +use strum::IntoEnumIterator; +use strum_macros::EnumIter; + +mod deparamize_functions; +mod expand_operands; +mod fix_special_registers2; +mod hoist_globals; +mod insert_explicit_load_store; +mod insert_implicit_conversions2; +mod insert_post_saturation; +mod instruction_mode_to_global_mode; +mod llvm; +mod normalize_basic_blocks; +mod normalize_identifiers2; +mod normalize_predicates2; +mod remove_unreachable_basic_blocks; +mod replace_instructions_with_function_calls; +mod replace_known_functions; +mod resolve_function_pointers; + +static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); +const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_"; + +quick_error! { + #[derive(Debug)] + pub enum TranslateError { + UnknownSymbol(symbol: String) { + display("Unknown symbol: \"{}\"", symbol) + } + UntypedSymbol {} + MismatchedType {} + Unreachable {} + Todo(msg: String) { + display("TODO: {}", msg) + } + } +} + +/// GPU attributes needed at compile time. +pub struct Attributes { + /// Clock frequency in kHz. 
+    pub clock_rate: u32,
+}
+
+pub fn to_llvm_module<'input>(
+    ast: ast::Module<'input>,
+    attributes: Attributes,
+) -> Result<Module, TranslateError> {
+    let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
+    let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
+    let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
+    let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
+    let directives = replace_known_functions::run(&mut flat_resolver, directives);
+    let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
+    let directives = resolve_function_pointers::run(directives)?;
+    let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
+    let directives = expand_operands::run(&mut flat_resolver, directives)?;
+    let directives = insert_post_saturation::run(&mut flat_resolver, directives)?;
+    let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
+    let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
+    let directives = remove_unreachable_basic_blocks::run(directives)?;
+    let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?;
+    let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
+    let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
+    let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
+    let directives = hoist_globals::run(directives)?;
+
+    let context = llvm::Context::new();
+    let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?;
+    let attributes_ir = llvm::attributes::run(&context, attributes)?;
+    Ok(Module {
+        llvm_ir,
+        attributes_ir,
+        kernel_info: HashMap::new(),
+        _context: context,
+    })
+}
+
+pub struct Module {
+    pub llvm_ir: llvm::Module,
+    pub attributes_ir: llvm::Module,
+    pub kernel_info: HashMap<String, KernelInfo>,
+    _context: llvm::Context,
+}
+
+impl Module {
+    pub fn linked_bitcode(&self) -> &[u8] {
+        ZLUDA_PTX_IMPL
+    }
+}
+
+pub struct KernelInfo {
+    pub arguments_sizes: Vec<(usize, bool)>,
+    pub uses_shared_mem: bool,
+}
+
+#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)]
+enum PtxSpecialRegister {
+    Tid,
+    Ntid,
+    Ctaid,
+    Nctaid,
+    Clock,
+    LanemaskLt,
+}
+
+impl PtxSpecialRegister {
+    fn as_str(self) -> &'static str {
+        match self {
+            Self::Tid => "%tid",
+            Self::Ntid => "%ntid",
+            Self::Ctaid => "%ctaid",
+            Self::Nctaid => "%nctaid",
+            Self::Clock => "%clock",
+            Self::LanemaskLt => "%lanemask_lt",
+        }
+    }
+
+    fn get_type(self) -> ast::Type {
+        match self {
+            PtxSpecialRegister::Tid
+            | PtxSpecialRegister::Ntid
+            | PtxSpecialRegister::Ctaid
+            | PtxSpecialRegister::Nctaid => ast::Type::Vector(4, self.get_function_return_type()),
+            _ => ast::Type::Scalar(self.get_function_return_type()),
+        }
+    }
+
+    fn get_function_return_type(self) -> ast::ScalarType {
+        match self {
+            PtxSpecialRegister::Tid => ast::ScalarType::U32,
+            PtxSpecialRegister::Ntid => ast::ScalarType::U32,
+            PtxSpecialRegister::Ctaid => ast::ScalarType::U32,
+            PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
+            PtxSpecialRegister::Clock => ast::ScalarType::U32,
+            PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32,
+        }
+    }
+
+    fn get_function_input_type(self) -> Option<ast::ScalarType> {
+        match self {
+            PtxSpecialRegister::Tid
+            | PtxSpecialRegister::Ntid
+            | PtxSpecialRegister::Ctaid
+            | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8),
+            PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None,
+ } + } + + fn get_unprefixed_function_name(self) -> &'static str { + match self { + PtxSpecialRegister::Tid => "sreg_tid", + PtxSpecialRegister::Ntid => "sreg_ntid", + PtxSpecialRegister::Ctaid => "sreg_ctaid", + PtxSpecialRegister::Nctaid => "sreg_nctaid", + PtxSpecialRegister::Clock => "sreg_clock", + PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt", + } + } +} + +#[cfg(debug_assertions)] +fn error_unreachable() -> TranslateError { + unreachable!() +} + +#[cfg(not(debug_assertions))] +fn error_unreachable() -> TranslateError { + TranslateError::Unreachable +} + +#[cfg(debug_assertions)] +fn error_todo_msg>(msg: T) -> TranslateError { + unreachable!("{}", msg.into()) +} + +#[cfg(not(debug_assertions))] +fn error_todo_msg>(msg: T) -> TranslateError { + TranslateError::Todo(msg.into()) +} + +#[cfg(debug_assertions)] +fn error_todo() -> TranslateError { + unreachable!() +} + +#[cfg(not(debug_assertions))] +fn error_todo() -> TranslateError { + TranslateError::Todo("".to_string()) +} + +#[cfg(debug_assertions)] +fn error_unknown_symbol>(symbol: T) -> TranslateError { + panic!("Unknown symbol: \"{}\"", symbol.into()) +} + +#[cfg(not(debug_assertions))] +fn error_unknown_symbol>(symbol: T) -> TranslateError { + TranslateError::UnknownSymbol(symbol.into()) +} + +#[cfg(debug_assertions)] +fn error_mismatched_type() -> TranslateError { + panic!() +} + +#[cfg(not(debug_assertions))] +fn error_mismatched_type() -> TranslateError { + TranslateError::MismatchedType +} + +enum Statement { + Label(SpirvWord), + Variable(ast::Variable), + Instruction(I), + // SPIR-V compatible replacement for PTX predicates + Conditional(BrachCondition), + Conversion(ImplicitConversion), + Constant(ConstantDefinition), + RetValue(ast::RetData, Vec<(SpirvWord, ast::Type)>), + PtrAccess(PtrAccess
<T>
), + RepackVector(RepackVectorDetails), + FunctionPointer(FunctionPointerDetails), + VectorRead(VectorRead), + VectorWrite(VectorWrite), + SetMode(ModeRegister), + FpSaturate { + dst: SpirvWord, + src: SpirvWord, + type_: ast::ScalarType, + }, +} + +#[derive(Eq, PartialEq, Clone, Copy)] +#[cfg_attr(test, derive(Debug))] +enum ModeRegister { + Denormal { + f32: bool, + f16f64: bool, + }, + Rounding { + f32: ast::RoundingMode, + f16f64: ast::RoundingMode, + }, +} + +impl> Statement, T> { + fn visit_map, Err>( + self, + visitor: &mut impl ast::VisitorMap, + ) -> std::result::Result, To>, Err> { + Ok(match self { + Statement::Instruction(i) => { + return ast::visit_map(i, visitor).map(Statement::Instruction) + } + Statement::Label(label) => { + Statement::Label(visitor.visit_ident(label, None, false, false)?) + } + Statement::Variable(var) => { + let name = visitor.visit_ident( + var.name, + Some((&var.v_type, var.state_space)), + true, + false, + )?; + Statement::Variable(ast::Variable { + align: var.align, + v_type: var.v_type, + state_space: var.state_space, + name, + array_init: var.array_init, + }) + } + Statement::Conditional(conditional) => { + let predicate = visitor.visit_ident( + conditional.predicate, + Some((&ast::ScalarType::Pred.into(), ast::StateSpace::Reg)), + false, + false, + )?; + let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?; + let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?; + Statement::Conditional(BrachCondition { + predicate, + if_true, + if_false, + }) + } + Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + to_type, + from_space, + to_space, + kind, + }) => { + let dst = visitor.visit_ident( + dst, + Some((&to_type, ast::StateSpace::Reg)), + true, + false, + )?; + let src = visitor.visit_ident( + src, + Some((&from_type, ast::StateSpace::Reg)), + false, + false, + )?; + Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + to_type, + from_space, + to_space, + kind, + }) + } + Statement::Constant(ConstantDefinition { dst, typ, value }) => { + let dst = visitor.visit_ident( + dst, + Some((&typ.into(), ast::StateSpace::Reg)), + true, + false, + )?; + Statement::Constant(ConstantDefinition { dst, typ, value }) + } + Statement::RetValue(data, value) => { + let value = value + .into_iter() + .map(|(ident, type_)| { + Ok(( + visitor.visit_ident( + ident, + Some((&type_, ast::StateSpace::Local)), + false, + false, + )?, + type_, + )) + }) + .collect::, _>>()?; + Statement::RetValue(data, value) + } + Statement::PtrAccess(PtrAccess { + underlying_type, + state_space, + dst, + ptr_src, + offset_src, + }) => { + let dst = + visitor.visit_ident(dst, Some((&underlying_type, state_space)), true, false)?; + let ptr_src = visitor.visit_ident( + ptr_src, + Some((&underlying_type, state_space)), + false, + false, + )?; + let offset_src = visitor.visit( + offset_src, + Some(( + &ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + )), + false, + false, + )?; + Statement::PtrAccess(PtrAccess { + underlying_type, + state_space, + dst, + ptr_src, + offset_src, + }) + } + Statement::VectorRead(VectorRead { + scalar_type, + vector_width, + scalar_dst: dst, + vector_src, + member, + }) => { + let scalar_t = scalar_type.into(); + let vector_t = ast::Type::Vector(vector_width, scalar_type); + let dst: SpirvWord = visitor.visit_ident( + dst, + Some((&scalar_t, ast::StateSpace::Reg)), + true, + false, + )?; + let src = visitor.visit_ident( + vector_src, + Some((&vector_t, 
ast::StateSpace::Reg)), + false, + false, + )?; + Statement::VectorRead(VectorRead { + scalar_type, + vector_width, + scalar_dst: dst, + vector_src: src, + member, + }) + } + Statement::VectorWrite(VectorWrite { + scalar_type, + vector_width, + vector_dst, + vector_src, + scalar_src, + member, + }) => { + let scalar_t = scalar_type.into(); + let vector_t = ast::Type::Vector(vector_width, scalar_type); + let vector_dst = visitor.visit_ident( + vector_dst, + Some((&vector_t, ast::StateSpace::Reg)), + true, + false, + )?; + let vector_src = visitor.visit_ident( + vector_src, + Some((&vector_t, ast::StateSpace::Reg)), + false, + false, + )?; + let scalar_src = visitor.visit_ident( + scalar_src, + Some((&scalar_t, ast::StateSpace::Reg)), + false, + false, + )?; + Statement::VectorWrite(VectorWrite { + vector_dst, + vector_src, + scalar_src, + scalar_type, + vector_width, + member, + }) + } + Statement::RepackVector(RepackVectorDetails { + is_extract, + typ, + packed, + unpacked, + relaxed_type_check, + }) => { + let (packed, unpacked) = if is_extract { + let unpacked = unpacked + .into_iter() + .map(|ident| { + visitor.visit_ident( + ident, + Some((&typ.into(), ast::StateSpace::Reg)), + true, + relaxed_type_check, + ) + }) + .collect::, _>>()?; + let packed = visitor.visit_ident( + packed, + Some(( + &ast::Type::Vector(unpacked.len() as u8, typ), + ast::StateSpace::Reg, + )), + false, + false, + )?; + (packed, unpacked) + } else { + let packed = visitor.visit_ident( + packed, + Some(( + &ast::Type::Vector(unpacked.len() as u8, typ), + ast::StateSpace::Reg, + )), + true, + false, + )?; + let unpacked = unpacked + .into_iter() + .map(|ident| { + visitor.visit_ident( + ident, + Some((&typ.into(), ast::StateSpace::Reg)), + false, + relaxed_type_check, + ) + }) + .collect::, _>>()?; + (packed, unpacked) + }; + Statement::RepackVector(RepackVectorDetails { + is_extract, + typ, + packed, + unpacked, + relaxed_type_check, + }) + } + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { + let dst = visitor.visit_ident( + dst, + Some(( + &ast::Type::Scalar(ast::ScalarType::U64), + ast::StateSpace::Reg, + )), + true, + false, + )?; + let src = visitor.visit_ident(src, None, false, false)?; + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) + } + Statement::SetMode(mode_register) => Statement::SetMode(mode_register), + Statement::FpSaturate { dst, src, type_ } => { + let dst = visitor.visit_ident( + dst, + Some((&type_.into(), ast::StateSpace::Reg)), + true, + false, + )?; + let src = visitor.visit_ident( + src, + Some((&type_.into(), ast::StateSpace::Reg)), + false, + false, + )?; + Statement::FpSaturate { dst, src, type_ } + } + }) + } +} + +struct BrachCondition { + predicate: SpirvWord, + if_true: SpirvWord, + if_false: SpirvWord, +} + +#[derive(Clone)] +struct ImplicitConversion { + src: SpirvWord, + dst: SpirvWord, + from_type: ast::Type, + to_type: ast::Type, + from_space: ast::StateSpace, + to_space: ast::StateSpace, + kind: ConversionKind, +} + +#[derive(PartialEq, Clone)] +enum ConversionKind { + Default, + // zero-extend/chop/bitcast depending on types + SignExtend, + BitToPtr, + PtrToPtr, + AddressOf, +} + +struct ConstantDefinition { + pub dst: SpirvWord, + pub typ: ast::ScalarType, + pub value: ast::ImmediateValue, +} + +pub struct PtrAccess { + underlying_type: ast::Type, + state_space: ast::StateSpace, + dst: SpirvWord, + ptr_src: SpirvWord, + offset_src: T, +} + +struct RepackVectorDetails { + is_extract: bool, + typ: ast::ScalarType, + packed: SpirvWord, + 
unpacked: Vec, + relaxed_type_check: bool, +} + +struct FunctionPointerDetails { + dst: SpirvWord, + src: SpirvWord, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)] +pub struct SpirvWord(u32); + +impl From for SpirvWord { + fn from(value: u32) -> Self { + Self(value) + } +} +impl From for u32 { + fn from(value: SpirvWord) -> Self { + value.0 + } +} + +impl ast::Operand for SpirvWord { + type Ident = Self; + + fn from_ident(ident: Self::Ident) -> Self { + ident + } +} + +type ExpandedStatement = Statement, SpirvWord>; + +type NormalizedStatement = Statement< + ( + Option>, + ast::Instruction>, + ), + ast::ParsedOperand, +>; + +enum Directive2 { + Variable(ast::LinkingDirective, ast::Variable), + Method(Function2), +} + +struct Function2 { + pub return_arguments: Vec>, + pub name: Operand::Ident, + pub input_arguments: Vec>, + pub body: Option>>, + is_kernel: bool, + import_as: Option, + tuning: Vec, + linkage: ast::LinkingDirective, + flush_to_zero_f32: bool, + flush_to_zero_f16f64: bool, + rounding_mode_f32: ast::RoundingMode, + rounding_mode_f16f64: ast::RoundingMode, +} + +type NormalizedDirective2 = Directive2< + ( + Option>, + ast::Instruction>, + ), + ast::ParsedOperand, +>; + +type NormalizedFunction2 = Function2< + ( + Option>, + ast::Instruction>, + ), + ast::ParsedOperand, +>; + +type UnconditionalDirective = + Directive2>, ast::ParsedOperand>; + +type UnconditionalFunction = + Function2>, ast::ParsedOperand>; + +struct GlobalStringIdentResolver2<'input> { + pub(crate) current_id: SpirvWord, + pub(crate) ident_map: FxHashMap>, +} + +impl<'input> GlobalStringIdentResolver2<'input> { + fn new(spirv_word: SpirvWord) -> Self { + Self { + current_id: spirv_word, + ident_map: FxHashMap::default(), + } + } + + fn register_named( + &mut self, + name: Cow<'input, str>, + type_space: Option<(ast::Type, ast::StateSpace)>, + ) -> SpirvWord { + let new_id = self.current_id; + self.ident_map.insert( + new_id, + IdentEntry { + name: Some(name), + type_space, + }, + ); + self.current_id.0 += 1; + new_id + } + + fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord { + let new_id = self.current_id; + self.ident_map.insert( + new_id, + IdentEntry { + name: None, + type_space, + }, + ); + self.current_id.0 += 1; + new_id + } + + fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> { + match self.ident_map.get(&id) { + Some(IdentEntry { + type_space: Some(type_space), + .. 
+ }) => Ok(type_space), + _ => Err(error_unknown_symbol(format!("{:?}", id))), + } + } +} + +struct IdentEntry<'input> { + name: Option>, + type_space: Option<(ast::Type, ast::StateSpace)>, +} + +struct ScopedResolver<'input, 'b> { + flat_resolver: &'b mut GlobalStringIdentResolver2<'input>, + scopes: Vec>, +} + +impl<'input, 'b> ScopedResolver<'input, 'b> { + fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self { + Self { + flat_resolver, + scopes: vec![ScopeMarker::new()], + } + } + + fn start_scope(&mut self) { + self.scopes.push(ScopeMarker::new()); + } + + fn end_scope(&mut self) { + let scope = self.scopes.pop().unwrap(); + scope.flush(self.flat_resolver); + } + + fn add_or_get_in_current_scope_untyped( + &mut self, + name: &'input str, + ) -> Result { + let current_scope = self.scopes.last_mut().unwrap(); + Ok( + match current_scope.name_to_ident.entry(Cow::Borrowed(name)) { + hash_map::Entry::Occupied(occupied_entry) => { + let ident = *occupied_entry.get(); + let entry = current_scope + .ident_map + .get(&ident) + .ok_or_else(|| error_unreachable())?; + if entry.type_space.is_some() { + return Err(error_unknown_symbol(name)); + } + ident + } + hash_map::Entry::Vacant(vacant_entry) => { + let new_id = self.flat_resolver.current_id; + self.flat_resolver.current_id.0 += 1; + vacant_entry.insert(new_id); + current_scope.ident_map.insert( + new_id, + IdentEntry { + name: Some(Cow::Borrowed(name)), + type_space: None, + }, + ); + new_id + } + }, + ) + } + + fn add( + &mut self, + name: Cow<'input, str>, + type_space: Option<(ast::Type, ast::StateSpace)>, + ) -> Result { + let result = self.flat_resolver.current_id; + self.flat_resolver.current_id.0 += 1; + let current_scope = self.scopes.last_mut().unwrap(); + if current_scope + .name_to_ident + .insert(name.clone(), result) + .is_some() + { + return Err(error_unknown_symbol(name)); + } + current_scope.ident_map.insert( + result, + IdentEntry { + name: Some(name), + type_space, + }, + ); + Ok(result) + } + + fn get(&mut self, name: &str) -> Result { + self.scopes + .iter() + .rev() + .find_map(|resolver| resolver.name_to_ident.get(name).copied()) + .ok_or_else(|| error_unknown_symbol(name)) + } + + fn get_in_current_scope(&self, label: &'input str) -> Result { + let current_scope = self.scopes.last().unwrap(); + current_scope + .name_to_ident + .get(label) + .copied() + .ok_or_else(|| error_unreachable()) + } +} + +struct ScopeMarker<'input> { + ident_map: FxHashMap>, + name_to_ident: FxHashMap, SpirvWord>, +} + +impl<'input> ScopeMarker<'input> { + fn new() -> Self { + Self { + ident_map: FxHashMap::default(), + name_to_ident: FxHashMap::default(), + } + } + + fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) { + resolver.ident_map.extend(self.ident_map); + } +} + +struct SpecialRegistersMap2 { + reg_to_id: FxHashMap, + id_to_reg: FxHashMap, +} + +impl SpecialRegistersMap2 { + fn new(resolver: &mut ScopedResolver) -> Result { + let mut result = SpecialRegistersMap2 { + reg_to_id: FxHashMap::default(), + id_to_reg: FxHashMap::default(), + }; + for sreg in PtxSpecialRegister::iter() { + let text = sreg.as_str(); + let id = resolver.add( + Cow::Borrowed(text), + Some((sreg.get_type(), ast::StateSpace::Reg)), + )?; + result.reg_to_id.insert(sreg, id); + result.id_to_reg.insert(id, sreg); + } + Ok(result) + } + + fn get(&self, id: SpirvWord) -> Option { + self.id_to_reg.get(&id).copied() + } + + fn len() -> usize { + PtxSpecialRegister::iter().len() + } + + fn foreach_declaration<'a, 'input>( + resolver: 
&'a mut GlobalStringIdentResolver2<'input>, + mut fn_: impl FnMut( + PtxSpecialRegister, + ( + Vec>, + SpirvWord, + Vec>, + ), + ), + ) { + for sreg in PtxSpecialRegister::iter() { + let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); + let name = resolver.register_named(Cow::Owned(external_fn_name), None); + let return_type = sreg.get_function_return_type(); + let input_type = sreg.get_function_input_type(); + let return_arguments = vec![ast::Variable { + align: None, + v_type: return_type.into(), + state_space: ast::StateSpace::Reg, + name: resolver.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))), + array_init: Vec::new(), + }]; + let input_arguments = input_type + .into_iter() + .map(|type_| ast::Variable { + align: None, + v_type: type_.into(), + state_space: ast::StateSpace::Reg, + name: resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))), + array_init: Vec::new(), + }) + .collect::>(); + fn_(sreg, (return_arguments, name, input_arguments)); + } + } +} + +pub struct VectorRead { + scalar_type: ast::ScalarType, + vector_width: u8, + scalar_dst: SpirvWord, + vector_src: SpirvWord, + member: u8, +} + +pub struct VectorWrite { + scalar_type: ast::ScalarType, + vector_width: u8, + vector_dst: SpirvWord, + vector_src: SpirvWord, + scalar_src: SpirvWord, + member: u8, +} + +fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str { + match this { + ast::ScalarType::B8 => "b8", + ast::ScalarType::B16 => "b16", + ast::ScalarType::B32 => "b32", + ast::ScalarType::B64 => "b64", + ast::ScalarType::B128 => "b128", + ast::ScalarType::U8 => "u8", + ast::ScalarType::U16 => "u16", + ast::ScalarType::U16x2 => "u16x2", + ast::ScalarType::U32 => "u32", + ast::ScalarType::U64 => "u64", + ast::ScalarType::S8 => "s8", + ast::ScalarType::S16 => "s16", + ast::ScalarType::S16x2 => "s16x2", + ast::ScalarType::S32 => "s32", + ast::ScalarType::S64 => "s64", + ast::ScalarType::F16 => "f16", + ast::ScalarType::F16x2 => "f16x2", + ast::ScalarType::F32 => "f32", + ast::ScalarType::F64 => "f64", + ast::ScalarType::BF16 => "bf16", + ast::ScalarType::BF16x2 => "bf16x2", + ast::ScalarType::Pred => "pred", + } +} + +type UnconditionalStatement = + Statement>, ast::ParsedOperand>; diff --git a/ptx/src/pass/normalize_basic_blocks.rs b/ptx/src/pass/normalize_basic_blocks.rs index 920cf21..5566b54 100644 --- a/ptx/src/pass/normalize_basic_blocks.rs +++ b/ptx/src/pass/normalize_basic_blocks.rs @@ -21,7 +21,9 @@ pub(crate) fn run( for directive in directives.iter_mut() { let (body_ref, is_kernel) = match directive { Directive2::Method(Function2 { - body: Some(body), is_kernel, .. + body: Some(body), + is_kernel, + .. 
}) => (body, *is_kernel), _ => continue, }; diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index ba45ec0..d81fdf8 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -9,7 +9,9 @@ fn parse_and_assert(ptx_text: &str) { fn compile_and_assert(ptx_text: &str) -> Result<(), TranslateError> { let ast = ast::parse_module_checked(ptx_text).unwrap(); - let attributes = pass::Attributes { clock_rate: 2124000 }; + let attributes = pass::Attributes { + clock_rate: 2124000, + }; crate::to_llvm_module(ast, attributes)?; Ok(()) } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index efa16a4..7f7a424 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -1,647 +1,688 @@ -use crate::pass; -use comgr::Comgr; -use cuda_types::cuda::CUstream; -use hip_runtime_sys::hipError_t; -use pretty_assertions; -use std::env; -use std::error; -use std::ffi::{CStr, CString}; -use std::fmt::{self, Debug, Display, Formatter}; -use std::fs::{self, File}; -use std::io::Write; -use std::mem; -use std::path::{Path, PathBuf}; -use std::ptr; -use std::str; - -#[cfg(not(feature = "ci_build"))] -macro_rules! read_test_file { - ($file:expr) => { - { - // CARGO_MANIFEST_DIR is the crate directory (ptx), but file! is relative to the workspace root (and therefore also includes ptx). - let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - path.pop(); - path.push(file!()); - path.pop(); - path.push($file); - std::fs::read_to_string(path).unwrap() - } - }; -} - -#[cfg(feature = "ci_build")] -macro_rules! read_test_file { - ($file:expr) => { - include_str!($file).to_string() - }; -} - -macro_rules! test_ptx_llvm { - ($fn_name:ident) => { - paste::item! { - #[test] - fn [<$fn_name _llvm>]() -> Result<(), Box> { - let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); - let ll = read_test_file!(concat!("../ll/", stringify!($fn_name), ".ll")); - test_llvm_assert(stringify!($fn_name), &ptx, ll.trim()) - } - } - } -} - -macro_rules! test_ptx { - ($fn_name:ident, $input:expr, $output:expr) => { - paste::item! { - #[test] - fn [<$fn_name _amdgpu>]() -> Result<(), Box> { - let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); - let input = $input; - let output = $output; - test_hip_assert(stringify!($fn_name), &ptx, Some(&input), &output, 1) - } - } - - paste::item! { - #[test] - fn [<$fn_name _cuda>]() -> Result<(), Box> { - let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); - let input = $input; - let output = $output; - test_cuda_assert(stringify!($fn_name), &ptx, Some(&input), &output, 1) - } - } - - test_ptx_llvm!($fn_name); - }; - - ($fn_name:ident) => { - test_ptx_llvm!($fn_name); - }; -} - -macro_rules! test_ptx_warp { - ($fn_name:ident, $output:expr) => { - paste::item! { - #[test] - fn [<$fn_name _amdgpu>]() -> Result<(), Box> { - let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); - let mut output = $output; - test_hip_assert(stringify!($fn_name), &ptx, None::<&[u8]>, &mut output, 64) - } - } - - paste::item! 
{ - #[test] - fn [<$fn_name _cuda>]() -> Result<(), Box> { - let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); - let mut output = $output; - test_cuda_assert(stringify!($fn_name), &ptx, None::<&[u8]>, &mut output, 64) - } - } - - test_ptx_llvm!($fn_name); - }; -} - -test_ptx!(ld_st, [1u64], [1u64]); -test_ptx!(ld_st_implicit, [0.5f32, 0.25f32], [0.5f32]); -test_ptx!(mov, [1u64], [1u64]); -test_ptx!(mul_lo, [1u64], [2u64]); -test_ptx!(mul_hi, [u64::max_value()], [1u64]); -test_ptx!(add, [1u64], [2u64]); -test_ptx!( - mul24_lo_u32, - [0b01110101_01010101_01010101u32], - [0b00011100_00100011_10001110_00111001u32] -); -test_ptx!( - mul24_hi_u32, - [0b01110101_01010101_01010101u32], - [0b00110101_11000111_00011100_00100011u32] -); -test_ptx!( - mul24_lo_s32, - [0b01110101_01010101_01010101i32], - [-0b0011100_00100011_10001110_00111001i32] -); -test_ptx!( - mul24_hi_s32, - [0b01110101_01010101_01010101i32], - [-0b0110101_11000111_00011100_00100100i32] -); -test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]); -test_ptx!(setp_gt, [f32::NAN, 1f32], [1f32]); -test_ptx!(setp_leu, [1f32, f32::NAN], [1f32]); -test_ptx!(bra, [10u64], [11u64]); -test_ptx!(not, [0u64], [u64::max_value()]); -test_ptx!(shl, [11u64], [44u64]); -test_ptx!(cvt_sat_s_u, [-1i32], [0i32]); -test_ptx!(cvta, [3.0f32], [3.0f32]); -test_ptx!(block, [1u64], [2u64]); -test_ptx!(local_align, [1u64], [1u64]); -test_ptx!(call, [1u64], [2u64]); -test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]); -test_ptx!(vector4, [1u32, 2u32, 3u32, 4u32], [4u32]); -test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]); -test_ptx!(ntid, [3u32], [4u32]); -test_ptx!(reg_local, [12u64], [13u64]); -test_ptx!(mov_address, [0xDEADu64], [0u64]); -test_ptx!(b64tof64, [111u64], [111u64]); -// This segfaults NV compiler -// test_ptx!(implicit_param, [34u32], [34u32]); -test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]); -test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32]); -test_ptx!(mad_wide, [-1i32, 3, 4, 5], [21474836481i64]); -test_ptx!( - mul_wide, - [0x01_00_00_00__01_00_00_00i64], - [0x1_00_00_00_00_00_00i64] -); -test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]); -test_ptx!(shr, [-2i32], [-1i32]); -test_ptx!(shr_oob, [-32768i16], [-1i16]); -test_ptx!(or, [1u64, 2u64], [3u64]); -test_ptx!(sub, [2u64], [1u64]); -test_ptx!(min, [555i32, 444i32], [444i32]); -test_ptx!(max, [555i32, 444i32], [555i32]); -test_ptx!(global_array, [0xDEADu32], [1u32]); -test_ptx!(extern_shared, [127u64], [127u64]); -test_ptx!(extern_shared_call, [121u64], [123u64]); -test_ptx!(rcp, [2f32], [0.5f32]); -// 0b1_00000000_10000000000000000000000u32 is a large denormal -// 0x3f000000 is 0.5 -test_ptx!( - mul_ftz, - [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], - [0b1_00000000_00000000000000000000000u32] -); -test_ptx!( - mul_non_ftz, - [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], - [0b1_00000000_01000000000000000000000u32] -); -test_ptx!(constant_f32, [10f32], [5f32]); -test_ptx!(abs, [i32::MIN], [i32::MIN]); -test_ptx!(constant_negative, [-101i32], [101i32]); -test_ptx!(and, [6u32, 3u32], [2u32]); -test_ptx!(selp, [100u16, 200u16], [200u16]); -test_ptx!(selp_true, [100u16, 200u16], [100u16]); -test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]); -test_ptx!(shared_variable, [513u64], [513u64]); -test_ptx!(shared_ptr_32, [513u64], [513u64]); -test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]); -test_ptx!(atom_inc, [100u32], [100u32, 101u32, 0u32]); -test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]); -test_ptx!(div_approx, [1f32, 2f32], [0.5f32]); 
-test_ptx!(sqrt, [0.25f32], [0.5f32]); -test_ptx!(rsqrt, [0.25f64], [2f64]); -test_ptx!(neg, [181i32], [-181i32]); -test_ptx!(sin, [std::f32::consts::PI / 2f32], [1f32]); -test_ptx!(cos, [std::f32::consts::PI], [-1f32]); -test_ptx!(lg2, [512f32], [9f32]); -test_ptx!(ex2, [10f32], [1024f32]); -test_ptx!(fmax, [0u16, half::f16::NAN.to_bits()], [0u16]); -test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]); -test_ptx!(cvt_rzi, [-13.8f32, 12.9f32], [-13f32, 12f32]); -test_ptx!(cvt_s32_f32, [-13.8f32, 12.9f32], [-13i32, 13i32]); -test_ptx!(cvt_rni_u16_f32, [0x477FFF80u32], [65535u16]); -test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]); -test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]); -test_ptx!( - brev, - [0b11000111_01011100_10101110_11111011u32], - [0b11011111_01110101_00111010_11100011u32] -); -test_ptx!( - xor, - [ - 0b01010010_00011010_01000000_00001101u32, - 0b11100110_10011011_00001100_00100011u32 - ], - [0b10110100100000010100110000101110u32] -); -test_ptx!(rem, [21692i32, 13i32], [8i32]); -test_ptx!( - bfe, - [0b11111000_11000001_00100010_10100000u32, 16u32, 8u32], - [0b11000001u32] -); -test_ptx!(bfi, [0b10u32, 0b101u32, 0u32, 2u32], [0b110u32]); -test_ptx!(stateful_ld_st_simple, [121u64], [121u64]); -test_ptx!(stateful_ld_st_ntid, [123u64], [123u64]); -test_ptx!(stateful_ld_st_ntid_chain, [12651u64], [12651u64]); -test_ptx!(stateful_ld_st_ntid_sub, [96311u64], [96311u64]); -test_ptx!(shared_ptr_take_address, [97815231u64], [97815231u64]); -test_ptx!(cvt_s64_s32, [-1i32], [-1i64]); -test_ptx!(add_tuning, [2u64], [3u64]); -test_ptx!(add_non_coherent, [3u64], [4u64]); -test_ptx!(sign_extend, [-1i16], [-1i32]); -test_ptx!(atom_add_float, [1.25f32, 0.5f32], [1.25f32, 1.75f32]); -test_ptx!( - setp_nan, - [ - 0.5f32, - f32::NAN, - f32::NAN, - 0.5f32, - f32::NAN, - f32::NAN, - 0.5f32, - 0.5f32 - ], - [1u32, 1u32, 1u32, 0u32] -); -test_ptx!( - setp_num, - [ - 0.5f32, - f32::NAN, - f32::NAN, - 0.5f32, - f32::NAN, - f32::NAN, - 0.5f32, - 0.5f32 - ], - [0u32, 0u32, 0u32, 2u32] -); -test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]); -test_ptx!(stateful_neg_offset, [1237518u64], [1237518u64]); -test_ptx!(const, [0u16], [10u16, 20, 30, 40]); -test_ptx!(cvt_s16_s8, [0x139231C2u32], [0xFFFFFFC2u32]); -test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]); -test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]); -test_ptx!(activemask, [0u32], [1u32]); -test_ptx!(membar, [152731u32], [152731u32]); -test_ptx!(shared_unify_extern, [7681u64, 7682u64], [15363u64]); -test_ptx!(shared_unify_local, [16752u64, 714u64], [17466u64]); -// FIXME: This test currently fails for reasons outside of ZLUDA's control. -// One of the LLVM passes does not understand that setreg instruction changes -// global floating point state and assumes that both floating point -// additions are the exact same expressions and optimizes second addition away. 
-// test_ptx!( -// add_ftz, -// [f32::from_bits(0x800000), f32::from_bits(0x007FFFFF)], -// [0x800000u32, 0xFFFFFF] -// ); -test_ptx!(add_s32_sat, [i32::MIN, -1], [i32::MIN, i32::MAX]); -test_ptx!(malformed_label, [2u64], [3u64]); -test_ptx!( - call_rnd, - [ - 1.0f32, - f32::from_bits(0x33800000), - 1.0f32, - f32::from_bits(0x33800000) - ], - [1.0000001, 1.0f32] -); -test_ptx!(multiple_return, [5u32], [6u32, 123u32]); -test_ptx!(warp_sz, [0u8], [32u8]); -test_ptx!(tanh, [f32::INFINITY], [1.0f32]); -test_ptx!(cp_async, [0u32], [1u32, 2u32, 3u32, 0u32]); - -test_ptx!(nanosleep, [0u64], [0u64]); - -test_ptx!(assertfail); -// TODO: not yet supported -//test_ptx!(func_ptr); -test_ptx!(lanemask_lt); -test_ptx!(extern_func); - -test_ptx_warp!(tid, [ - 0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8, 11u8, 12u8, 13u8, 14u8, 15u8, - 16u8, 17u8, 18u8, 19u8, 20u8, 21u8, 22u8, 23u8, 24u8, 25u8, 26u8, 27u8, 28u8, 29u8, 30u8, 31u8, - 32u8, 33u8, 34u8, 35u8, 36u8, 37u8, 38u8, 39u8, 40u8, 41u8, 42u8, 43u8, 44u8, 45u8, 46u8, 47u8, - 48u8, 49u8, 50u8, 51u8, 52u8, 53u8, 54u8, 55u8, 56u8, 57u8, 58u8, 59u8, 60u8, 61u8, 62u8, 63u8, -]); -test_ptx_warp!(bar_red_and_pred, [ - 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, - 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, - 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, - 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, -]); -test_ptx_warp!(shfl_sync_up_b32_pred, [ - 1000u32, 1001u32, 1002u32, 0u32, 1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32, - 13u32, 14u32, 15u32, 16u32, 17u32, 18u32, 19u32, 20u32, 21u32, 22u32, 23u32, 24u32, 25u32, 26u32, 27u32, 28u32, - 1032u32, 1033u32, 1034u32, 32u32, 33u32, 34u32, 35u32, 36u32, 37u32, 38u32, 39u32, 40u32, 41u32, 42u32, 43u32, 44u32, - 45u32, 46u32, 47u32, 48u32, 49u32, 50u32, 51u32, 52u32, 53u32, 54u32, 55u32, 56u32, 57u32, 58u32, 59u32, 60u32, -]); -test_ptx_warp!(shfl_sync_down_b32_pred, [ - 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32, 13u32, 14u32, 15u32, 16u32, 17u32, 18u32, - 19u32, 20u32, 21u32, 22u32, 23u32, 24u32, 25u32, 26u32, 27u32, 28u32, 29u32, 30u32, 31u32, 1029u32, 1030u32, 1031u32, - 35u32, 36u32, 37u32, 38u32, 39u32, 40u32, 41u32, 42u32, 43u32, 44u32, 45u32, 46u32, 47u32, 48u32, 49u32, 50u32, - 51u32, 52u32, 53u32, 54u32, 55u32, 56u32, 57u32, 58u32, 59u32, 60u32, 61u32, 62u32, 63u32, 1061u32, 1062u32, 1063u32, -]); -test_ptx_warp!(shfl_sync_bfly_b32_pred, [ - 3u32, 2u32, 1u32, 0u32, 7u32, 6u32, 5u32, 4u32, 11u32, 10u32, 9u32, 8u32, 15u32, 14u32, 13u32, 12u32, - 19u32, 18u32, 17u32, 16u32, 23u32, 22u32, 21u32, 20u32, 27u32, 26u32, 25u32, 24u32, 31u32, 30u32, 29u32, 28u32, - 35u32, 34u32, 33u32, 32u32, 39u32, 38u32, 37u32, 36u32, 43u32, 42u32, 41u32, 40u32, 47u32, 46u32, 45u32, 44u32, - 51u32, 50u32, 49u32, 48u32, 55u32, 54u32, 53u32, 52u32, 59u32, 58u32, 57u32, 56u32, 63u32, 62u32, 61u32, 60u32, -]); -test_ptx_warp!(shfl_sync_idx_b32_pred, [ - 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, - 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, - 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, - 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 
44u32, 44u32, 44u32, 44u32, -]); -test_ptx_warp!(shfl_sync_mode_b32, [ - 9u32, 7u32, 8u32, 9u32, 21u32, 19u32, 20u32, 21u32, 33u32, 31u32, 32u32, 33u32, 45u32, 43u32, 44u32, 45u32, - 73u32, 71u32, 72u32, 73u32, 85u32, 83u32, 84u32, 85u32, 97u32, 95u32, 96u32, 97u32, 109u32, 107u32, 108u32, 109u32, - 137u32, 135u32, 136u32, 137u32, 149u32, 147u32, 148u32, 149u32, 161u32, 159u32, 160u32, 161u32, 173u32, 171u32, 172u32, 173u32, - 201u32, 199u32, 200u32, 201u32, 213u32, 211u32, 212u32, 213u32, 225u32, 223u32, 224u32, 225u32, 237u32, 235u32, 236u32, 237u32, -]); - -struct DisplayError { - err: T, -} - -impl Display for DisplayError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Debug::fmt(&self.err, f) - } -} - -impl Debug for DisplayError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Debug::fmt(&self.err, f) - } -} - -impl error::Error for DisplayError {} - -fn test_hip_assert< - Input: From + Debug + Copy + PartialEq, - Output: From + Debug + Copy + PartialEq + Default, ->( - name: &str, - ptx_text: &str, - input: Option<&[Input]>, - output: &[Output], - block_dim_x: u32, -) -> Result<(), Box> { - let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); - let llvm_ir = pass::to_llvm_module(ast, pass::Attributes { clock_rate: 2124000 }).unwrap(); - let name = CString::new(name)?; - let result = - run_hip(name.as_c_str(), llvm_ir, input, output, block_dim_x).map_err(|err| DisplayError { err })?; - assert_eq!(result.as_slice(), output); - Ok(()) -} - -fn test_llvm_assert( - name: &str, - ptx_text: &str, - expected_ll: &str, -) -> Result<(), Box> { - let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); - let llvm_ir = pass::to_llvm_module(ast, pass::Attributes { clock_rate: 2124000 }).unwrap(); - let actual_ll = llvm_ir.llvm_ir.print_module_to_string(); - let actual_ll = actual_ll.to_str(); - compare_llvm(name, actual_ll, expected_ll); - - let expected_attributes_ll = read_test_file!(concat!("../ll/_attributes.ll")); - let actual_attributes_ll = llvm_ir.attributes_ir.print_module_to_string(); - let actual_attributes_ll = actual_attributes_ll.to_str(); - compare_llvm("_attributes", actual_attributes_ll, &expected_attributes_ll); - Ok(()) -} - -fn compare_llvm(name: &str, actual_ll: &str, expected_ll: &str) { - if actual_ll != expected_ll { - let output_dir = env::var("TEST_PTX_LLVM_FAIL_DIR"); - if let Ok(output_dir) = output_dir { - let output_dir = Path::new(&output_dir); - fs::create_dir_all(&output_dir).unwrap(); - let output_file = output_dir.join(format!("{}.ll", name)); - let mut output_file = File::create(output_file).unwrap(); - output_file.write_all(actual_ll.as_bytes()).unwrap(); - } - let comparison = pretty_assertions::StrComparison::new(&expected_ll, &actual_ll); - panic!("assertion failed: `(left == right)`\n\n{}", comparison); - } -} - -fn test_cuda_assert< - Input: From + Debug + Copy + PartialEq, - Output: From + Debug + Copy + PartialEq + Default, ->( - name: &str, - ptx_text: &str, - input: Option<&[Input]>, - output: &[Output], - block_dim_x: u32, -) -> Result<(), Box> { - let name = CString::new(name)?; - let result = run_cuda(name.as_c_str(), ptx_text, input, output, block_dim_x); - assert_eq!(result.as_slice(), output); - Ok(()) -} - -fn run_cuda + Copy + Debug, Output: From + Copy + Debug + Default>( - name: &CStr, - ptx_module: &str, - input: Option<&[Input]>, - output: &[Output], - block_dim_x: u32, -) -> Vec { - unsafe { CUDA.cuInit(0) }.unwrap().unwrap(); - let ptx_module = CString::new(ptx_module).unwrap(); - let mut result = 
vec![0u8.into(); output.len()]; - { - let mut ctx = unsafe { mem::zeroed() }; - unsafe { CUDA.cuCtxCreate_v2(&mut ctx, 0, 0) } - .unwrap() - .unwrap(); - let mut module = unsafe { mem::zeroed() }; - unsafe { CUDA.cuModuleLoadData(&mut module, ptx_module.as_ptr() as _) } - .unwrap() - .unwrap(); - let mut kernel = unsafe { mem::zeroed() }; - unsafe { CUDA.cuModuleGetFunction(&mut kernel, module, name.as_ptr()) } - .unwrap() - .unwrap(); - let mut out_b = unsafe { mem::zeroed() }; - unsafe { CUDA.cuMemAlloc_v2(&mut out_b, output.len() * mem::size_of::()) } - .unwrap() - .unwrap(); - let mut inp_b = unsafe { mem::zeroed() }; - if let Some(input) = input { - unsafe { CUDA.cuMemAlloc_v2(&mut inp_b, input.len() * mem::size_of::()) } - .unwrap() - .unwrap(); - unsafe { - CUDA.cuMemcpyHtoD_v2( - inp_b, - input.as_ptr() as _, - input.len() * mem::size_of::(), - ) - } - .unwrap() - .unwrap(); - } - unsafe { CUDA.cuMemsetD8_v2(out_b, 0, output.len() * mem::size_of::()) } - .unwrap() - .unwrap(); - let mut args = if input.is_some() { - [&inp_b, &out_b] - } else { - [&out_b, &out_b] - }; - unsafe { - CUDA.cuLaunchKernel( - kernel, - 1, - 1, - 1, - block_dim_x, - 1, - 1, - 1024, - CUstream(ptr::null_mut()), - args.as_mut_ptr() as _, - ptr::null_mut(), - ) - } - .unwrap() - .unwrap(); - unsafe { - CUDA.cuMemcpyDtoH_v2( - result.as_mut_ptr() as _, - out_b, - output.len() * mem::size_of::(), - ) - } - .unwrap() - .unwrap(); - unsafe { CUDA.cuStreamSynchronize(CUstream(ptr::null_mut())) } - .unwrap() - .unwrap(); - unsafe { CUDA.cuMemFree_v2(inp_b) }.unwrap().unwrap(); - unsafe { CUDA.cuMemFree_v2(out_b) }.unwrap().unwrap(); - unsafe { CUDA.cuModuleUnload(module) }.unwrap().unwrap(); - unsafe { CUDA.cuCtxDestroy_v2(ctx) }.unwrap().unwrap(); - } - result -} - -struct DynamicCuda { - lib: libloading::Library, -} - -impl DynamicCuda { - #[cfg(not(windows))] - const CUDA_PATH: &'static str = "/usr/lib/x86_64-linux-gnu/libcuda.so.1"; - #[cfg(windows)] - const CUDA_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll"; - - pub fn new() -> Result { - let lib = unsafe { libloading::Library::new(Self::CUDA_PATH) }?; - Ok(Self { lib }) - } -} - -macro_rules! 
dynamic_fns { - ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => { - impl DynamicCuda { - $( - #[allow(dead_code)] - unsafe fn $fn_name(&self, $($arg_id : $arg_type),*) -> Result<$ret_type, libloading::Error> { - let func = unsafe { self.lib.get:: $ret_type>(concat!(stringify!($fn_name), "\0").as_bytes()) }; - func.map(|f| f($($arg_id),*) ) - } - )* - } - }; -} - -cuda_macros::cuda_function_declarations!(dynamic_fns); - -static COMGR: std::sync::LazyLock = std::sync::LazyLock::new(|| Comgr::new().unwrap()); -static CUDA: std::sync::LazyLock = - std::sync::LazyLock::new(|| DynamicCuda::new().unwrap()); - -fn run_hip + Copy + Debug, Output: From + Copy + Debug + Default>( - name: &CStr, - module: pass::Module, - input: Option<&[Input]>, - output: &[Output], - block_dim_x: u32, -) -> Result, hipError_t> { - use hip_runtime_sys::*; - unsafe { hipInit(0) }.unwrap(); - let comgr = &*COMGR; - let mut result = vec![0u8.into(); output.len()]; - { - let dev = 0; - let mut stream = unsafe { mem::zeroed() }; - unsafe { hipStreamCreate(&mut stream) }.unwrap(); - let mut dev_props = unsafe { mem::zeroed() }; - unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap(); - let elf_module = comgr::compile_bitcode( - &comgr, - unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }, - &*module.llvm_ir.write_bitcode_to_memory(), - &*module.attributes_ir.write_bitcode_to_memory(), - module.linked_bitcode(), - ) - .unwrap(); - let mut module = unsafe { mem::zeroed() }; - unsafe { hipModuleLoadData(&mut module, elf_module.as_ptr() as _) }.unwrap(); - let mut kernel = unsafe { mem::zeroed() }; - unsafe { hipModuleGetFunction(&mut kernel, module, name.as_ptr()) }.unwrap(); - let mut out_b = ptr::null_mut(); - unsafe { hipMalloc(&mut out_b, output.len() * mem::size_of::()) }.unwrap(); - let mut inp_b = ptr::null_mut(); - if let Some(input) = input { - unsafe { hipMalloc(&mut inp_b, input.len() * mem::size_of::()) }.unwrap(); - unsafe { - hipMemcpyWithStream( - inp_b, - input.as_ptr() as _, - input.len() * mem::size_of::(), - hipMemcpyKind::hipMemcpyHostToDevice, - stream, - ) - } - .unwrap(); - } - unsafe { hipMemset(out_b, 0, output.len() * mem::size_of::()) }.unwrap(); - let mut args = if input.is_some() { - [&inp_b, &out_b] - } else { - [&out_b, &out_b] - }; - unsafe { - hipModuleLaunchKernel( - kernel, - 1, - 1, - 1, - block_dim_x, - 1, - 1, - 1024, - stream, - args.as_mut_ptr() as _, - ptr::null_mut(), - ) - } - .unwrap(); - unsafe { - hipMemcpyAsync( - result.as_mut_ptr() as _, - out_b, - output.len() * mem::size_of::(), - hipMemcpyKind::hipMemcpyDeviceToHost, - stream, - ) - } - .unwrap(); - unsafe { hipStreamSynchronize(stream) }.unwrap(); - unsafe { hipFree(inp_b) }.unwrap(); - unsafe { hipFree(out_b) }.unwrap(); - unsafe { hipModuleUnload(module) }.unwrap(); - } - Ok(result) -} +use crate::pass; +use comgr::Comgr; +use cuda_types::cuda::CUstream; +use hip_runtime_sys::hipError_t; +use pretty_assertions; +use std::env; +use std::error; +use std::ffi::{CStr, CString}; +use std::fmt::{self, Debug, Display, Formatter}; +use std::fs::{self, File}; +use std::io::Write; +use std::mem; +use std::path::{Path, PathBuf}; +use std::ptr; +use std::str; + +#[cfg(not(feature = "ci_build"))] +macro_rules! read_test_file { + ($file:expr) => { + { + // CARGO_MANIFEST_DIR is the crate directory (ptx), but file! is relative to the workspace root (and therefore also includes ptx). 
+ let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + path.pop(); + path.push(file!()); + path.pop(); + path.push($file); + std::fs::read_to_string(path).unwrap() + } + }; +} + +#[cfg(feature = "ci_build")] +macro_rules! read_test_file { + ($file:expr) => { + include_str!($file).to_string() + }; +} + +macro_rules! test_ptx_llvm { + ($fn_name:ident) => { + paste::item! { + #[test] + fn [<$fn_name _llvm>]() -> Result<(), Box> { + let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); + let ll = read_test_file!(concat!("../ll/", stringify!($fn_name), ".ll")); + test_llvm_assert(stringify!($fn_name), &ptx, ll.trim()) + } + } + }; +} + +macro_rules! test_ptx { + ($fn_name:ident, $input:expr, $output:expr) => { + paste::item! { + #[test] + fn [<$fn_name _amdgpu>]() -> Result<(), Box> { + let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); + let input = $input; + let output = $output; + test_hip_assert(stringify!($fn_name), &ptx, Some(&input), &output, 1) + } + } + + paste::item! { + #[test] + fn [<$fn_name _cuda>]() -> Result<(), Box> { + let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); + let input = $input; + let output = $output; + test_cuda_assert(stringify!($fn_name), &ptx, Some(&input), &output, 1) + } + } + + test_ptx_llvm!($fn_name); + }; + + ($fn_name:ident) => { + test_ptx_llvm!($fn_name); + }; +} + +macro_rules! test_ptx_warp { + ($fn_name:ident, $output:expr) => { + paste::item! { + #[test] + fn [<$fn_name _amdgpu>]() -> Result<(), Box> { + let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); + let mut output = $output; + test_hip_assert(stringify!($fn_name), &ptx, None::<&[u8]>, &mut output, 64) + } + } + + paste::item! { + #[test] + fn [<$fn_name _cuda>]() -> Result<(), Box> { + let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); + let mut output = $output; + test_cuda_assert(stringify!($fn_name), &ptx, None::<&[u8]>, &mut output, 64) + } + } + + test_ptx_llvm!($fn_name); + }; +} + +test_ptx!(ld_st, [1u64], [1u64]); +test_ptx!(ld_st_implicit, [0.5f32, 0.25f32], [0.5f32]); +test_ptx!(mov, [1u64], [1u64]); +test_ptx!(mul_lo, [1u64], [2u64]); +test_ptx!(mul_hi, [u64::max_value()], [1u64]); +test_ptx!(add, [1u64], [2u64]); +test_ptx!( + mul24_lo_u32, + [0b01110101_01010101_01010101u32], + [0b00011100_00100011_10001110_00111001u32] +); +test_ptx!( + mul24_hi_u32, + [0b01110101_01010101_01010101u32], + [0b00110101_11000111_00011100_00100011u32] +); +test_ptx!( + mul24_lo_s32, + [0b01110101_01010101_01010101i32], + [-0b0011100_00100011_10001110_00111001i32] +); +test_ptx!( + mul24_hi_s32, + [0b01110101_01010101_01010101i32], + [-0b0110101_11000111_00011100_00100100i32] +); +test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]); +test_ptx!(setp_gt, [f32::NAN, 1f32], [1f32]); +test_ptx!(setp_leu, [1f32, f32::NAN], [1f32]); +test_ptx!(bra, [10u64], [11u64]); +test_ptx!(not, [0u64], [u64::max_value()]); +test_ptx!(shl, [11u64], [44u64]); +test_ptx!(cvt_sat_s_u, [-1i32], [0i32]); +test_ptx!(cvta, [3.0f32], [3.0f32]); +test_ptx!(block, [1u64], [2u64]); +test_ptx!(local_align, [1u64], [1u64]); +test_ptx!(call, [1u64], [2u64]); +test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]); +test_ptx!(vector4, [1u32, 2u32, 3u32, 4u32], [4u32]); +test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]); +test_ptx!(ntid, [3u32], [4u32]); +test_ptx!(reg_local, [12u64], [13u64]); +test_ptx!(mov_address, [0xDEADu64], [0u64]); +test_ptx!(b64tof64, [111u64], [111u64]); +// This segfaults NV compiler +// test_ptx!(implicit_param, [34u32], [34u32]); 
+test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]); +test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32]); +test_ptx!(mad_wide, [-1i32, 3, 4, 5], [21474836481i64]); +test_ptx!( + mul_wide, + [0x01_00_00_00__01_00_00_00i64], + [0x1_00_00_00_00_00_00i64] +); +test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]); +test_ptx!(shr, [-2i32], [-1i32]); +test_ptx!(shr_oob, [-32768i16], [-1i16]); +test_ptx!(or, [1u64, 2u64], [3u64]); +test_ptx!(sub, [2u64], [1u64]); +test_ptx!(min, [555i32, 444i32], [444i32]); +test_ptx!(max, [555i32, 444i32], [555i32]); +test_ptx!(global_array, [0xDEADu32], [1u32]); +test_ptx!(extern_shared, [127u64], [127u64]); +test_ptx!(extern_shared_call, [121u64], [123u64]); +test_ptx!(rcp, [2f32], [0.5f32]); +// 0b1_00000000_10000000000000000000000u32 is a large denormal +// 0x3f000000 is 0.5 +test_ptx!( + mul_ftz, + [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], + [0b1_00000000_00000000000000000000000u32] +); +test_ptx!( + mul_non_ftz, + [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], + [0b1_00000000_01000000000000000000000u32] +); +test_ptx!(constant_f32, [10f32], [5f32]); +test_ptx!(abs, [i32::MIN], [i32::MIN]); +test_ptx!(constant_negative, [-101i32], [101i32]); +test_ptx!(and, [6u32, 3u32], [2u32]); +test_ptx!(selp, [100u16, 200u16], [200u16]); +test_ptx!(selp_true, [100u16, 200u16], [100u16]); +test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]); +test_ptx!(shared_variable, [513u64], [513u64]); +test_ptx!(shared_ptr_32, [513u64], [513u64]); +test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]); +test_ptx!(atom_inc, [100u32], [100u32, 101u32, 0u32]); +test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]); +test_ptx!(div_approx, [1f32, 2f32], [0.5f32]); +test_ptx!(sqrt, [0.25f32], [0.5f32]); +test_ptx!(rsqrt, [0.25f64], [2f64]); +test_ptx!(neg, [181i32], [-181i32]); +test_ptx!(sin, [std::f32::consts::PI / 2f32], [1f32]); +test_ptx!(cos, [std::f32::consts::PI], [-1f32]); +test_ptx!(lg2, [512f32], [9f32]); +test_ptx!(ex2, [10f32], [1024f32]); +test_ptx!(fmax, [0u16, half::f16::NAN.to_bits()], [0u16]); +test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]); +test_ptx!(cvt_rzi, [-13.8f32, 12.9f32], [-13f32, 12f32]); +test_ptx!(cvt_s32_f32, [-13.8f32, 12.9f32], [-13i32, 13i32]); +test_ptx!(cvt_rni_u16_f32, [0x477FFF80u32], [65535u16]); +test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]); +test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]); +test_ptx!( + brev, + [0b11000111_01011100_10101110_11111011u32], + [0b11011111_01110101_00111010_11100011u32] +); +test_ptx!( + xor, + [ + 0b01010010_00011010_01000000_00001101u32, + 0b11100110_10011011_00001100_00100011u32 + ], + [0b10110100100000010100110000101110u32] +); +test_ptx!(rem, [21692i32, 13i32], [8i32]); +test_ptx!( + bfe, + [0b11111000_11000001_00100010_10100000u32, 16u32, 8u32], + [0b11000001u32] +); +test_ptx!(bfi, [0b10u32, 0b101u32, 0u32, 2u32], [0b110u32]); +test_ptx!(stateful_ld_st_simple, [121u64], [121u64]); +test_ptx!(stateful_ld_st_ntid, [123u64], [123u64]); +test_ptx!(stateful_ld_st_ntid_chain, [12651u64], [12651u64]); +test_ptx!(stateful_ld_st_ntid_sub, [96311u64], [96311u64]); +test_ptx!(shared_ptr_take_address, [97815231u64], [97815231u64]); +test_ptx!(cvt_s64_s32, [-1i32], [-1i64]); +test_ptx!(add_tuning, [2u64], [3u64]); +test_ptx!(add_non_coherent, [3u64], [4u64]); +test_ptx!(sign_extend, [-1i16], [-1i32]); +test_ptx!(atom_add_float, [1.25f32, 0.5f32], [1.25f32, 1.75f32]); +test_ptx!( + setp_nan, + [ + 0.5f32, + f32::NAN, + f32::NAN, + 0.5f32, + f32::NAN, + 
f32::NAN, + 0.5f32, + 0.5f32 + ], + [1u32, 1u32, 1u32, 0u32] +); +test_ptx!( + setp_num, + [ + 0.5f32, + f32::NAN, + f32::NAN, + 0.5f32, + f32::NAN, + f32::NAN, + 0.5f32, + 0.5f32 + ], + [0u32, 0u32, 0u32, 2u32] +); +test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]); +test_ptx!(stateful_neg_offset, [1237518u64], [1237518u64]); +test_ptx!(const, [0u16], [10u16, 20, 30, 40]); +test_ptx!(cvt_s16_s8, [0x139231C2u32], [0xFFFFFFC2u32]); +test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]); +test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]); +test_ptx!(activemask, [0u32], [1u32]); +test_ptx!(membar, [152731u32], [152731u32]); +test_ptx!(shared_unify_extern, [7681u64, 7682u64], [15363u64]); +test_ptx!(shared_unify_local, [16752u64, 714u64], [17466u64]); +// FIXME: This test currently fails for reasons outside of ZLUDA's control. +// One of the LLVM passes does not understand that setreg instruction changes +// global floating point state and assumes that both floating point +// additions are the exact same expressions and optimizes second addition away. +// test_ptx!( +// add_ftz, +// [f32::from_bits(0x800000), f32::from_bits(0x007FFFFF)], +// [0x800000u32, 0xFFFFFF] +// ); +test_ptx!(add_s32_sat, [i32::MIN, -1], [i32::MIN, i32::MAX]); +test_ptx!(malformed_label, [2u64], [3u64]); +test_ptx!( + call_rnd, + [ + 1.0f32, + f32::from_bits(0x33800000), + 1.0f32, + f32::from_bits(0x33800000) + ], + [1.0000001, 1.0f32] +); +test_ptx!(multiple_return, [5u32], [6u32, 123u32]); +test_ptx!(warp_sz, [0u8], [32u8]); +test_ptx!(tanh, [f32::INFINITY], [1.0f32]); +test_ptx!(cp_async, [0u32], [1u32, 2u32, 3u32, 0u32]); + +test_ptx!(nanosleep, [0u64], [0u64]); + +test_ptx!(assertfail); +// TODO: not yet supported +//test_ptx!(func_ptr); +test_ptx!(lanemask_lt); +test_ptx!(extern_func); + +test_ptx_warp!( + tid, + [ + 0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8, 11u8, 12u8, 13u8, 14u8, 15u8, 16u8, + 17u8, 18u8, 19u8, 20u8, 21u8, 22u8, 23u8, 24u8, 25u8, 26u8, 27u8, 28u8, 29u8, 30u8, 31u8, + 32u8, 33u8, 34u8, 35u8, 36u8, 37u8, 38u8, 39u8, 40u8, 41u8, 42u8, 43u8, 44u8, 45u8, 46u8, + 47u8, 48u8, 49u8, 50u8, 51u8, 52u8, 53u8, 54u8, 55u8, 56u8, 57u8, 58u8, 59u8, 60u8, 61u8, + 62u8, 63u8, + ] +); +test_ptx_warp!( + bar_red_and_pred, + [ + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, + ] +); +test_ptx_warp!( + shfl_sync_up_b32_pred, + [ + 1000u32, 1001u32, 1002u32, 0u32, 1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, + 10u32, 11u32, 12u32, 13u32, 14u32, 15u32, 16u32, 17u32, 18u32, 19u32, 20u32, 21u32, 22u32, + 23u32, 24u32, 25u32, 26u32, 27u32, 28u32, 1032u32, 1033u32, 1034u32, 32u32, 33u32, 34u32, + 35u32, 36u32, 37u32, 38u32, 39u32, 40u32, 41u32, 42u32, 43u32, 44u32, 45u32, 46u32, 47u32, + 48u32, 49u32, 50u32, 51u32, 52u32, 53u32, 54u32, 55u32, 56u32, 57u32, 58u32, 59u32, 60u32, + ] +); +test_ptx_warp!( + shfl_sync_down_b32_pred, + [ + 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32, 13u32, 14u32, 15u32, 16u32, + 17u32, 18u32, 19u32, 20u32, 21u32, 22u32, 23u32, 24u32, 25u32, 26u32, 27u32, 28u32, 29u32, + 30u32, 31u32, 1029u32, 1030u32, 1031u32, 35u32, 36u32, 37u32, 38u32, 39u32, 40u32, 41u32, + 42u32, 43u32, 44u32, 45u32, 
46u32, 47u32, 48u32, 49u32, 50u32, 51u32, 52u32, 53u32, 54u32, + 55u32, 56u32, 57u32, 58u32, 59u32, 60u32, 61u32, 62u32, 63u32, 1061u32, 1062u32, 1063u32, + ] +); +test_ptx_warp!( + shfl_sync_bfly_b32_pred, + [ + 3u32, 2u32, 1u32, 0u32, 7u32, 6u32, 5u32, 4u32, 11u32, 10u32, 9u32, 8u32, 15u32, 14u32, + 13u32, 12u32, 19u32, 18u32, 17u32, 16u32, 23u32, 22u32, 21u32, 20u32, 27u32, 26u32, 25u32, + 24u32, 31u32, 30u32, 29u32, 28u32, 35u32, 34u32, 33u32, 32u32, 39u32, 38u32, 37u32, 36u32, + 43u32, 42u32, 41u32, 40u32, 47u32, 46u32, 45u32, 44u32, 51u32, 50u32, 49u32, 48u32, 55u32, + 54u32, 53u32, 52u32, 59u32, 58u32, 57u32, 56u32, 63u32, 62u32, 61u32, 60u32, + ] +); +test_ptx_warp!( + shfl_sync_idx_b32_pred, + [ + 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, + 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, + 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, + 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, + 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, + ] +); +test_ptx_warp!( + shfl_sync_mode_b32, + [ + 9u32, 7u32, 8u32, 9u32, 21u32, 19u32, 20u32, 21u32, 33u32, 31u32, 32u32, 33u32, 45u32, + 43u32, 44u32, 45u32, 73u32, 71u32, 72u32, 73u32, 85u32, 83u32, 84u32, 85u32, 97u32, 95u32, + 96u32, 97u32, 109u32, 107u32, 108u32, 109u32, 137u32, 135u32, 136u32, 137u32, 149u32, + 147u32, 148u32, 149u32, 161u32, 159u32, 160u32, 161u32, 173u32, 171u32, 172u32, 173u32, + 201u32, 199u32, 200u32, 201u32, 213u32, 211u32, 212u32, 213u32, 225u32, 223u32, 224u32, + 225u32, 237u32, 235u32, 236u32, 237u32, + ] +); + +struct DisplayError { + err: T, +} + +impl Display for DisplayError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Debug::fmt(&self.err, f) + } +} + +impl Debug for DisplayError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Debug::fmt(&self.err, f) + } +} + +impl error::Error for DisplayError {} + +fn test_hip_assert< + Input: From + Debug + Copy + PartialEq, + Output: From + Debug + Copy + PartialEq + Default, +>( + name: &str, + ptx_text: &str, + input: Option<&[Input]>, + output: &[Output], + block_dim_x: u32, +) -> Result<(), Box> { + let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); + let llvm_ir = pass::to_llvm_module( + ast, + pass::Attributes { + clock_rate: 2124000, + }, + ) + .unwrap(); + let name = CString::new(name)?; + let result = run_hip(name.as_c_str(), llvm_ir, input, output, block_dim_x) + .map_err(|err| DisplayError { err })?; + assert_eq!(result.as_slice(), output); + Ok(()) +} + +fn test_llvm_assert( + name: &str, + ptx_text: &str, + expected_ll: &str, +) -> Result<(), Box> { + let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); + let llvm_ir = pass::to_llvm_module( + ast, + pass::Attributes { + clock_rate: 2124000, + }, + ) + .unwrap(); + let actual_ll = llvm_ir.llvm_ir.print_module_to_string(); + let actual_ll = actual_ll.to_str(); + compare_llvm(name, actual_ll, expected_ll); + + let expected_attributes_ll = read_test_file!(concat!("../ll/_attributes.ll")); + let actual_attributes_ll = llvm_ir.attributes_ir.print_module_to_string(); + let actual_attributes_ll = actual_attributes_ll.to_str(); + compare_llvm("_attributes", actual_attributes_ll, &expected_attributes_ll); + Ok(()) +} + +fn compare_llvm(name: &str, actual_ll: &str, expected_ll: &str) { + if actual_ll != expected_ll { + let output_dir = 
env::var("TEST_PTX_LLVM_FAIL_DIR"); + if let Ok(output_dir) = output_dir { + let output_dir = Path::new(&output_dir); + fs::create_dir_all(&output_dir).unwrap(); + let output_file = output_dir.join(format!("{}.ll", name)); + let mut output_file = File::create(output_file).unwrap(); + output_file.write_all(actual_ll.as_bytes()).unwrap(); + } + let comparison = pretty_assertions::StrComparison::new(&expected_ll, &actual_ll); + panic!("assertion failed: `(left == right)`\n\n{}", comparison); + } +} + +fn test_cuda_assert< + Input: From + Debug + Copy + PartialEq, + Output: From + Debug + Copy + PartialEq + Default, +>( + name: &str, + ptx_text: &str, + input: Option<&[Input]>, + output: &[Output], + block_dim_x: u32, +) -> Result<(), Box> { + let name = CString::new(name)?; + let result = run_cuda(name.as_c_str(), ptx_text, input, output, block_dim_x); + assert_eq!(result.as_slice(), output); + Ok(()) +} + +fn run_cuda + Copy + Debug, Output: From + Copy + Debug + Default>( + name: &CStr, + ptx_module: &str, + input: Option<&[Input]>, + output: &[Output], + block_dim_x: u32, +) -> Vec { + unsafe { CUDA.cuInit(0) }.unwrap().unwrap(); + let ptx_module = CString::new(ptx_module).unwrap(); + let mut result = vec![0u8.into(); output.len()]; + { + let mut ctx = unsafe { mem::zeroed() }; + unsafe { CUDA.cuCtxCreate_v2(&mut ctx, 0, 0) } + .unwrap() + .unwrap(); + let mut module = unsafe { mem::zeroed() }; + unsafe { CUDA.cuModuleLoadData(&mut module, ptx_module.as_ptr() as _) } + .unwrap() + .unwrap(); + let mut kernel = unsafe { mem::zeroed() }; + unsafe { CUDA.cuModuleGetFunction(&mut kernel, module, name.as_ptr()) } + .unwrap() + .unwrap(); + let mut out_b = unsafe { mem::zeroed() }; + unsafe { CUDA.cuMemAlloc_v2(&mut out_b, output.len() * mem::size_of::()) } + .unwrap() + .unwrap(); + let mut inp_b = unsafe { mem::zeroed() }; + if let Some(input) = input { + unsafe { CUDA.cuMemAlloc_v2(&mut inp_b, input.len() * mem::size_of::()) } + .unwrap() + .unwrap(); + unsafe { + CUDA.cuMemcpyHtoD_v2( + inp_b, + input.as_ptr() as _, + input.len() * mem::size_of::(), + ) + } + .unwrap() + .unwrap(); + } + unsafe { CUDA.cuMemsetD8_v2(out_b, 0, output.len() * mem::size_of::()) } + .unwrap() + .unwrap(); + let mut args = if input.is_some() { + [&inp_b, &out_b] + } else { + [&out_b, &out_b] + }; + unsafe { + CUDA.cuLaunchKernel( + kernel, + 1, + 1, + 1, + block_dim_x, + 1, + 1, + 1024, + CUstream(ptr::null_mut()), + args.as_mut_ptr() as _, + ptr::null_mut(), + ) + } + .unwrap() + .unwrap(); + unsafe { + CUDA.cuMemcpyDtoH_v2( + result.as_mut_ptr() as _, + out_b, + output.len() * mem::size_of::(), + ) + } + .unwrap() + .unwrap(); + unsafe { CUDA.cuStreamSynchronize(CUstream(ptr::null_mut())) } + .unwrap() + .unwrap(); + unsafe { CUDA.cuMemFree_v2(inp_b) }.unwrap().unwrap(); + unsafe { CUDA.cuMemFree_v2(out_b) }.unwrap().unwrap(); + unsafe { CUDA.cuModuleUnload(module) }.unwrap().unwrap(); + unsafe { CUDA.cuCtxDestroy_v2(ctx) }.unwrap().unwrap(); + } + result +} + +struct DynamicCuda { + lib: libloading::Library, +} + +impl DynamicCuda { + #[cfg(not(windows))] + const CUDA_PATH: &'static str = "/usr/lib/x86_64-linux-gnu/libcuda.so.1"; + #[cfg(windows)] + const CUDA_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll"; + + pub fn new() -> Result { + let lib = unsafe { libloading::Library::new(Self::CUDA_PATH) }?; + Ok(Self { lib }) + } +} + +macro_rules! 
dynamic_fns { + ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => { + impl DynamicCuda { + $( + #[allow(dead_code)] + unsafe fn $fn_name(&self, $($arg_id : $arg_type),*) -> Result<$ret_type, libloading::Error> { + let func = unsafe { self.lib.get:: $ret_type>(concat!(stringify!($fn_name), "\0").as_bytes()) }; + func.map(|f| f($($arg_id),*) ) + } + )* + } + }; +} + +cuda_macros::cuda_function_declarations!(dynamic_fns); + +static COMGR: std::sync::LazyLock = std::sync::LazyLock::new(|| Comgr::new().unwrap()); +static CUDA: std::sync::LazyLock = + std::sync::LazyLock::new(|| DynamicCuda::new().unwrap()); + +fn run_hip + Copy + Debug, Output: From + Copy + Debug + Default>( + name: &CStr, + module: pass::Module, + input: Option<&[Input]>, + output: &[Output], + block_dim_x: u32, +) -> Result, hipError_t> { + use hip_runtime_sys::*; + unsafe { hipInit(0) }.unwrap(); + let comgr = &*COMGR; + let mut result = vec![0u8.into(); output.len()]; + { + let dev = 0; + let mut stream = unsafe { mem::zeroed() }; + unsafe { hipStreamCreate(&mut stream) }.unwrap(); + let mut dev_props = unsafe { mem::zeroed() }; + unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap(); + let elf_module = comgr::compile_bitcode( + &comgr, + unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }, + &*module.llvm_ir.write_bitcode_to_memory(), + &*module.attributes_ir.write_bitcode_to_memory(), + module.linked_bitcode(), + ) + .unwrap(); + let mut module = unsafe { mem::zeroed() }; + unsafe { hipModuleLoadData(&mut module, elf_module.as_ptr() as _) }.unwrap(); + let mut kernel = unsafe { mem::zeroed() }; + unsafe { hipModuleGetFunction(&mut kernel, module, name.as_ptr()) }.unwrap(); + let mut out_b = ptr::null_mut(); + unsafe { hipMalloc(&mut out_b, output.len() * mem::size_of::()) }.unwrap(); + let mut inp_b = ptr::null_mut(); + if let Some(input) = input { + unsafe { hipMalloc(&mut inp_b, input.len() * mem::size_of::()) }.unwrap(); + unsafe { + hipMemcpyWithStream( + inp_b, + input.as_ptr() as _, + input.len() * mem::size_of::(), + hipMemcpyKind::hipMemcpyHostToDevice, + stream, + ) + } + .unwrap(); + } + unsafe { hipMemset(out_b, 0, output.len() * mem::size_of::()) }.unwrap(); + let mut args = if input.is_some() { + [&inp_b, &out_b] + } else { + [&out_b, &out_b] + }; + unsafe { + hipModuleLaunchKernel( + kernel, + 1, + 1, + 1, + block_dim_x, + 1, + 1, + 1024, + stream, + args.as_mut_ptr() as _, + ptr::null_mut(), + ) + } + .unwrap(); + unsafe { + hipMemcpyAsync( + result.as_mut_ptr() as _, + out_b, + output.len() * mem::size_of::(), + hipMemcpyKind::hipMemcpyDeviceToHost, + stream, + ) + } + .unwrap(); + unsafe { hipStreamSynchronize(stream) }.unwrap(); + unsafe { hipFree(inp_b) }.unwrap(); + unsafe { hipFree(out_b) }.unwrap(); + unsafe { hipModuleUnload(module) }.unwrap(); + } + Ok(result) +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 77721e3..d140adb 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,2023 +1,2022 @@ -use super::{ - AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, - StateSpace, VectorPrefix, -}; -use crate::{Mul24Control, Reduction, PtxError, PtxParserState, ShuffleMode}; -use bitflags::bitflags; -use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8}; - -pub enum Statement { - Label(P::Ident), - Variable(MultiVariable), - Instruction(Option>, Instruction

), - Block(Vec>), -} - -// We define the instruction enum through the macro instead of normally, because we have some of how -// we use this type in the compilee. Each instruction can be logically split into two parts: -// properties that define instruction semantics (e.g. is memory load volatile?) that don't change -// during compilation and arguments (e.g. memory load source and destination) that evolve during -// compilation. To support compilation passes we need to be able to visit (and change) every -// argument in a generic way. This macro has visibility over all the fields. Consequently, we use it -// to generate visitor functions. There re three functions to support three different semantics: -// visit-by-ref, visit-by-mutable-ref, visit-and-map. In a previous version of the compiler it was -// done by hand and was very limiting (we supported only visit-and-map). -// The visitor must implement appropriate visitor trait defined below this macro. For convenience, -// we implemented visitors for some corresponding FnMut(...) types. -// Properties in this macro are used to encode information about the instruction arguments (what -// Rust type is used for it post-parsing, what PTX type does it expect, what PTX address space does -// it expect, etc.). -// This information is then available to a visitor. -ptx_parser_macros::generate_instruction_type!( - pub enum Instruction { - Abs { - data: TypeFtz, - type: { Type::Scalar(data.type_) }, - arguments: { - dst: T, - src: T, - } - }, - Activemask { - type: Type::Scalar(ScalarType::B32), - arguments: { - dst: T - } - }, - Add { - type: { Type::from(data.type_()) }, - data: ArithDetails, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - And { - data: ScalarType, - type: { Type::Scalar(data.clone()) }, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - Atom { - type: &data.type_, - data: AtomDetails, - arguments: { - dst: T, - src1: { - repr: T, - space: { data.space }, - }, - src2: T, - } - }, - AtomCas { - type: Type::Scalar(data.type_), - data: AtomCasDetails, - arguments: { - dst: T, - src1: { - repr: T, - space: { data.space }, - }, - src2: T, - src3: T, - } - }, - Bar { - type: Type::Scalar(ScalarType::U32), - data: BarData, - arguments: { - src1: T, - src2: Option, - } - }, - BarRed { - type: Type::Scalar(ScalarType::U32), - data: BarRedData, - arguments: { - dst1: { - repr: T, - type: Type::from(ScalarType::Pred) - }, - src_barrier: T, - src_threadcount: Option, - src_predicate: { - repr: T, - type: Type::from(ScalarType::Pred) - }, - src_negate_predicate: { - repr: T, - type: Type::from(ScalarType::Pred) - }, - } - }, - Bfe { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: T, - src1: T, - src2: { - repr: T, - type: Type::Scalar(ScalarType::U32) - }, - src3: { - repr: T, - type: Type::Scalar(ScalarType::U32) - }, - } - }, - Bfi { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: T, - src1: T, - src2: T, - src3: { - repr: T, - type: Type::Scalar(ScalarType::U32) - }, - src4: { - repr: T, - type: Type::Scalar(ScalarType::U32) - }, - } - }, - Bra { - type: !, - arguments: { - src: T - } - }, - Brev { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: T, - src: T - } - }, - Call { - data: CallDetails, - arguments: CallArgs, - visit: arguments.visit(data, visitor)?, - visit_mut: arguments.visit_mut(data, visitor)?, - map: Instruction::Call{ arguments: arguments.map(&data, visitor)?, data } - }, - Clz { - type: Type::Scalar(data.clone()), - 
data: ScalarType, - arguments: { - dst: { - repr: T, - type: Type::Scalar(ScalarType::U32) - }, - src: T - } - }, - Cos { - type: Type::Scalar(ScalarType::F32), - data: FlushToZero, - arguments: { - dst: T, - src: T - } - }, - CpAsync { - type: Type::Scalar(ScalarType::U32), - data: CpAsyncDetails, - arguments: { - src_to: { - repr: T, - space: StateSpace::Shared - }, - src_from: { - repr: T, - space: StateSpace::Global - } - } - }, - CpAsyncCommitGroup { }, - CpAsyncWaitGroup { - type: Type::Scalar(ScalarType::U64), - arguments: { - src_group: T - } - }, - CpAsyncWaitAll { }, - Cvt { - data: CvtDetails, - arguments: { - dst: { - repr: T, - type: { Type::Scalar(data.to) }, - // TODO: double check - relaxed_type_check: true, - }, - src: { - repr: T, - type: { Type::Scalar(data.from) }, - relaxed_type_check: true, - }, - } - }, - Cvta { - data: CvtaDetails, - type: { Type::Scalar(ScalarType::B64) }, - arguments: { - dst: T, - src: T, - } - }, - Div { - type: Type::Scalar(data.type_()), - data: DivDetails, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - Ex2 { - type: Type::Scalar(ScalarType::F32), - data: TypeFtz, - arguments: { - dst: T, - src: T - } - }, - Fma { - type: { Type::from(data.type_) }, - data: ArithFloat, - arguments: { - dst: T, - src1: T, - src2: T, - src3: T, - } - }, - Ld { - type: { &data.typ }, - data: LdDetails, - arguments: { - dst: { - repr: T, - relaxed_type_check: true, - }, - src: { - repr: T, - space: { data.state_space }, - } - } - }, - Lg2 { - type: Type::Scalar(ScalarType::F32), - data: FlushToZero, - arguments: { - dst: T, - src: T - } - }, - Mad { - type: { Type::from(data.type_()) }, - data: MadDetails, - arguments: { - dst: { - repr: T, - type: { Type::from(data.dst_type()) }, - }, - src1: T, - src2: T, - src3: { - repr: T, - type: { Type::from(data.dst_type()) }, - } - } - }, - Max { - type: { Type::from(data.type_()) }, - data: MinMaxDetails, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - Membar { - data: MemScope - }, - Min { - type: { Type::from(data.type_()) }, - data: MinMaxDetails, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - Mov { - type: { &data.typ }, - data: MovDetails, - arguments: { - dst: T, - src: T - } - }, - Mul { - type: { Type::from(data.type_()) }, - data: MulDetails, - arguments: { - dst: { - repr: T, - type: { Type::from(data.dst_type()) }, - }, - src1: T, - src2: T, - } - }, - Mul24 { - type: { Type::from(data.type_) }, - data: Mul24Details, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - Nanosleep { - type: Type::Scalar(ScalarType::U32), - arguments: { - src: T - } - }, - Neg { - type: Type::Scalar(data.type_), - data: TypeFtz, - arguments: { - dst: T, - src: T - } - }, - Not { - data: ScalarType, - type: { Type::Scalar(data.clone()) }, - arguments: { - dst: T, - src: T, - } - }, - Or { - data: ScalarType, - type: { Type::Scalar(data.clone()) }, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - Popc { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: { - repr: T, - type: Type::Scalar(ScalarType::U32) - }, - src: T - } - }, - Prmt { - type: Type::Scalar(ScalarType::B32), - data: u16, - arguments: { - dst: T, - src1: T, - src2: T - } - }, - PrmtSlow { - type: Type::Scalar(ScalarType::U32), - arguments: { - dst: T, - src1: T, - src2: T, - src3: T - } - }, - Rcp { - type: { Type::from(data.type_) }, - data: RcpData, - arguments: { - dst: T, - src: T, - } - }, - Rem { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: T, - src1: T, - 
src2: T - } - }, - Ret { - data: RetData - }, - Rsqrt { - type: { Type::from(data.type_) }, - data: TypeFtz, - arguments: { - dst: T, - src: T, - } - }, - Selp { - type: { Type::Scalar(data.clone()) }, - data: ScalarType, - arguments: { - dst: T, - src1: T, - src2: T, - src3: { - repr: T, - type: Type::Scalar(ScalarType::Pred) - }, - } - }, - Set { - data: SetData, - arguments: { - dst: { - repr: T, - type: Type::from(data.dtype) - }, - src1: { - repr: T, - type: Type::from(data.base.type_), - }, - src2: { - repr: T, - type: Type::from(data.base.type_), - } - } - }, - SetBool { - data: SetBoolData, - arguments: { - dst: { - repr: T, - type: Type::from(data.dtype) - }, - src1: { - repr: T, - type: Type::from(data.base.base.type_), - }, - src2: { - repr: T, - type: Type::from(data.base.base.type_), - }, - src3: { - repr: T, - type: Type::from(ScalarType::Pred) - } - } - }, - Setp { - data: SetpData, - arguments: { - dst1: { - repr: T, - type: Type::from(ScalarType::Pred) - }, - dst2: { - repr: Option, - type: Type::from(ScalarType::Pred) - }, - src1: { - repr: T, - type: Type::from(data.type_), - }, - src2: { - repr: T, - type: Type::from(data.type_), - } - } - }, - SetpBool { - data: SetpBoolData, - arguments: { - dst1: { - repr: T, - type: Type::from(ScalarType::Pred) - }, - dst2: { - repr: Option, - type: Type::from(ScalarType::Pred) - }, - src1: { - repr: T, - type: Type::from(data.base.type_), - }, - src2: { - repr: T, - type: Type::from(data.base.type_), - }, - src3: { - repr: T, - type: Type::from(ScalarType::Pred) - } - } - }, - ShflSync { - data: ShflSyncDetails, - type: Type::Scalar(ScalarType::B32), - arguments: { - dst: T, - dst_pred: { - repr: Option, - type: Type::from(ScalarType::Pred) - }, - src: T, - src_lane: T, - src_opts: T, - src_membermask: T - } - }, - Shl { - data: ScalarType, - type: { Type::Scalar(data.clone()) }, - arguments: { - dst: T, - src1: T, - src2: { - repr: T, - type: { Type::Scalar(ScalarType::U32) }, - }, - } - }, - Shr { - data: ShrData, - type: { Type::Scalar(data.type_.clone()) }, - arguments: { - dst: T, - src1: T, - src2: { - repr: T, - type: { Type::Scalar(ScalarType::U32) }, - }, - } - }, - Sin { - type: Type::Scalar(ScalarType::F32), - data: FlushToZero, - arguments: { - dst: T, - src: T - } - }, - Sqrt { - type: { Type::from(data.type_) }, - data: RcpData, - arguments: { - dst: T, - src: T, - } - }, - St { - type: { &data.typ }, - data: StData, - arguments: { - src1: { - repr: T, - space: { data.state_space }, - }, - src2: { - repr: T, - relaxed_type_check: true, - } - } - }, - Sub { - type: { Type::from(data.type_()) }, - data: ArithDetails, - arguments: { - dst: T, - src1: T, - src2: T, - } - }, - Trap { }, - Xor { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: T, - src1: T, - src2: T - } - }, - Tanh { - type: Type::Scalar(data.clone()), - data: ScalarType, - arguments: { - dst: T, - src: T - } - }, - } -); - -pub trait Visitor { - fn visit( - &mut self, - args: &T, - type_space: Option<(&Type, StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result<(), Err>; - fn visit_ident( - &mut self, - args: &T::Ident, - type_space: Option<(&Type, StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result<(), Err>; -} - -impl< - T: Operand, - Err, - Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>, - > Visitor for Fn -{ - fn visit( - &mut self, - args: &T, - type_space: Option<(&Type, StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result<(), Err> { - 
(self)(args, type_space, is_dst, relaxed_type_check) - } - - fn visit_ident( - &mut self, - args: &T::Ident, - type_space: Option<(&Type, StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result<(), Err> { - (self)( - &T::from_ident(*args), - type_space, - is_dst, - relaxed_type_check, - ) - } -} - -pub trait VisitorMut { - fn visit( - &mut self, - args: &mut T, - type_space: Option<(&Type, StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result<(), Err>; - fn visit_ident( - &mut self, - args: &mut T::Ident, - type_space: Option<(&Type, StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result<(), Err>; -} - -pub trait VisitorMap { - fn visit( - &mut self, - args: From, - type_space: Option<(&Type, StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result; - fn visit_ident( - &mut self, - args: From::Ident, - type_space: Option<(&Type, StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result; -} - -impl VisitorMap, ParsedOperand, Err> for Fn -where - Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result, -{ - fn visit( - &mut self, - args: ParsedOperand, - type_space: Option<(&Type, StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result, Err> { - Ok(match args { - ParsedOperand::Reg(ident) => { - ParsedOperand::Reg((self)(ident, type_space, is_dst, relaxed_type_check)?) - } - ParsedOperand::RegOffset(ident, imm) => ParsedOperand::RegOffset( - (self)(ident, type_space, is_dst, relaxed_type_check)?, - imm, - ), - ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm), - ParsedOperand::VecMember(ident, index) => ParsedOperand::VecMember( - (self)(ident, type_space, is_dst, relaxed_type_check)?, - index, - ), - ParsedOperand::VecPack(vec) => ParsedOperand::VecPack( - vec.into_iter() - .map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check)) - .collect::, _>>()?, - ), - }) - } - - fn visit_ident( - &mut self, - args: T, - type_space: Option<(&Type, StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result { - (self)(args, type_space, is_dst, relaxed_type_check) - } -} - -impl, U: Operand, Err, Fn> VisitorMap for Fn -where - Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result, -{ - fn visit( - &mut self, - args: T, - type_space: Option<(&Type, StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result { - (self)(args, type_space, is_dst, relaxed_type_check) - } - - fn visit_ident( - &mut self, - args: T, - type_space: Option<(&Type, StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result { - (self)(args, type_space, is_dst, relaxed_type_check) - } -} - -trait VisitOperand { - type Operand: Operand; - #[allow(unused)] // Used by generated code - fn visit(&self, fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err>; - #[allow(unused)] // Used by generated code - fn visit_mut( - &mut self, - fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, - ) -> Result<(), Err>; -} - -impl VisitOperand for T { - type Operand = Self; - fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { - fn_(self) - } - fn visit_mut( - &mut self, - mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, - ) -> Result<(), Err> { - fn_(self) - } -} - -impl VisitOperand for Option { - type Operand = T; - fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { - if let Some(x) = self { - fn_(x)?; - } - Ok(()) - } - fn visit_mut( - &mut self, - mut fn_: impl FnMut(&mut 
Self::Operand) -> Result<(), Err>, - ) -> Result<(), Err> { - if let Some(x) = self { - fn_(x)?; - } - Ok(()) - } -} - -impl VisitOperand for Vec { - type Operand = T; - fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { - for o in self { - fn_(o)?; - } - Ok(()) - } - fn visit_mut( - &mut self, - mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, - ) -> Result<(), Err> { - for o in self { - fn_(o)?; - } - Ok(()) - } -} - -trait MapOperand: Sized { - type Input; - type Output; - #[allow(unused)] // Used by generated code - fn map( - self, - fn_: impl FnOnce(Self::Input) -> Result, - ) -> Result, Err>; -} - -impl MapOperand for T { - type Input = Self; - type Output = U; - fn map(self, fn_: impl FnOnce(T) -> Result) -> Result { - fn_(self) - } -} - -impl MapOperand for Option { - type Input = T; - type Output = Option; - fn map(self, fn_: impl FnOnce(T) -> Result) -> Result, Err> { - self.map(|x| fn_(x)).transpose() - } -} - -pub struct MultiVariable { - pub var: Variable, - pub count: Option, -} - -#[derive(Clone)] -pub struct Variable { - pub align: Option, - pub v_type: Type, - pub state_space: StateSpace, - pub name: ID, - pub array_init: Vec, -} - -pub struct PredAt { - pub not: bool, - pub label: ID, -} - -#[derive(PartialEq, Eq, Clone, Hash)] -pub enum Type { - // .param.b32 foo; - Scalar(ScalarType), - // .param.v2.b32 foo; - Vector(u8, ScalarType), - // .param.b32 foo[4]; - Array(Option, ScalarType, Vec), -} - -impl Type { - pub(crate) fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { - match vector { - Some(prefix) => Type::Vector(prefix.len().get(), scalar), - None => Type::Scalar(scalar), - } - } - - pub(crate) fn maybe_vector_parsed(prefix: Option, scalar: ScalarType) -> Self { - match prefix { - Some(prefix) => Type::Vector(prefix.get(), scalar), - None => Type::Scalar(scalar), - } - } - - pub(crate) fn maybe_array( - prefix: Option, - scalar: ScalarType, - array: Option>, - ) -> Self { - match array { - Some(dimensions) => Type::Array(prefix, scalar, dimensions), - None => Self::maybe_vector_parsed(prefix, scalar), - } - } - - pub fn layout(&self) -> Layout { - match self { - Type::Scalar(type_) => type_.layout(), - Type::Vector(elements, scalar_type) => { - let scalar_layout = scalar_type.layout(); - unsafe { - Layout::from_size_align_unchecked( - scalar_layout.size() * *elements as usize, - scalar_layout.align() * *elements as usize, - ) - } - } - Type::Array(non_zero, scalar, vec) => { - let element_layout = Type::maybe_vector_parsed(*non_zero, *scalar).layout(); - let len = vec.iter().copied().reduce(std::ops::Mul::mul).unwrap_or(0); - unsafe { - Layout::from_size_align_unchecked( - element_layout.size() * (len as usize), - element_layout.align(), - ) - } - } - } - } -} - -impl ScalarType { - pub fn size_of(self) -> u8 { - match self { - ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1, - ScalarType::U16 - | ScalarType::S16 - | ScalarType::B16 - | ScalarType::F16 - | ScalarType::BF16 => 2, - ScalarType::U32 - | ScalarType::S32 - | ScalarType::B32 - | ScalarType::F32 - | ScalarType::U16x2 - | ScalarType::S16x2 - | ScalarType::F16x2 - | ScalarType::BF16x2 => 4, - ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => 8, - ScalarType::B128 => 16, - ScalarType::Pred => 1, - } - } - - pub fn layout(self) -> Layout { - match self { - ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => Layout::new::(), - ScalarType::U16 - | ScalarType::S16 - | ScalarType::B16 - | ScalarType::F16 - | 
ScalarType::BF16 => Layout::new::(), - ScalarType::U32 - | ScalarType::S32 - | ScalarType::B32 - | ScalarType::F32 - | ScalarType::U16x2 - | ScalarType::S16x2 - | ScalarType::F16x2 - | ScalarType::BF16x2 => Layout::new::(), - ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => { - Layout::new::() - } - ScalarType::B128 => unsafe { Layout::from_size_align_unchecked(16, 16) }, - // Close enough - ScalarType::Pred => Layout::new::(), - } - } - - pub fn kind(self) -> ScalarKind { - match self { - ScalarType::U8 => ScalarKind::Unsigned, - ScalarType::U16 => ScalarKind::Unsigned, - ScalarType::U16x2 => ScalarKind::Unsigned, - ScalarType::U32 => ScalarKind::Unsigned, - ScalarType::U64 => ScalarKind::Unsigned, - ScalarType::S8 => ScalarKind::Signed, - ScalarType::S16 => ScalarKind::Signed, - ScalarType::S16x2 => ScalarKind::Signed, - ScalarType::S32 => ScalarKind::Signed, - ScalarType::S64 => ScalarKind::Signed, - ScalarType::B8 => ScalarKind::Bit, - ScalarType::B16 => ScalarKind::Bit, - ScalarType::B32 => ScalarKind::Bit, - ScalarType::B64 => ScalarKind::Bit, - ScalarType::B128 => ScalarKind::Bit, - ScalarType::F16 => ScalarKind::Float, - ScalarType::F16x2 => ScalarKind::Float, - ScalarType::F32 => ScalarKind::Float, - ScalarType::F64 => ScalarKind::Float, - ScalarType::BF16 => ScalarKind::Float, - ScalarType::BF16x2 => ScalarKind::Float, - ScalarType::Pred => ScalarKind::Pred, - } - } -} - -#[derive(Clone, Copy, PartialEq, Eq)] -pub enum ScalarKind { - Bit, - Unsigned, - Signed, - Float, - Pred, -} -impl From for Type { - fn from(value: ScalarType) -> Self { - Type::Scalar(value) - } -} - -#[derive(Clone)] -pub struct MovDetails { - pub typ: super::Type, - pub src_is_address: bool, - // two fields below are in use by member moves - pub dst_width: u8, - pub src_width: u8, - // This is in use by auto-generated movs - pub relaxed_src2_conv: bool, -} - -impl MovDetails { - pub(crate) fn new(vector: Option, scalar: ScalarType) -> Self { - MovDetails { - typ: Type::maybe_vector(vector, scalar), - src_is_address: false, - dst_width: 0, - src_width: 0, - relaxed_src2_conv: false, - } - } -} - -#[derive(Copy, Clone)] -pub struct ShflSyncDetails { - pub mode: ShuffleMode, -} - -pub enum CpAsyncCpSize { - Bytes4, - Bytes8, - Bytes16, -} - -impl CpAsyncCpSize { - pub fn from_u64(n: u64) -> Option { - match n { - 4 => Some(Self::Bytes4), - 8 => Some(Self::Bytes8), - 16 => Some(Self::Bytes16), - _ => None, - } - } - - pub fn as_u64(&self) -> u64 { - match self { - CpAsyncCpSize::Bytes4 => 4, - CpAsyncCpSize::Bytes8 => 8, - CpAsyncCpSize::Bytes16 => 16, - } - } -} - -pub struct CpAsyncDetails { - pub caching: CpAsyncCacheOperator, - pub space: StateSpace, - pub cp_size: CpAsyncCpSize, - pub src_size: Option, -} - -#[derive(Clone)] -pub enum ParsedOperand { - Reg(Ident), - RegOffset(Ident, i32), - Imm(ImmediateValue), - VecMember(Ident, u8), - VecPack(Vec), -} - -impl ParsedOperand { - pub fn as_immediate(&self) -> Option { - match self { - ParsedOperand::Imm(imm) => Some(*imm), - _ => None, - } - } -} - -impl Operand for ParsedOperand { - type Ident = Ident; - - fn from_ident(ident: Self::Ident) -> Self { - ParsedOperand::Reg(ident) - } -} - -pub trait Operand: Sized { - type Ident: Copy; - - fn from_ident(ident: Self::Ident) -> Self; -} - -#[derive(Copy, Clone)] -pub enum ImmediateValue { - U64(u64), - S64(i64), - F32(f32), - F64(f64), -} - -impl ImmediateValue { - /// If the value is a U64 or S64, returns the value as a u64, ignoring the sign. 
- pub fn as_u64(&self) -> Option { - match *self { - ImmediateValue::U64(n) => Some(n), - ImmediateValue::S64(n) => Some(n as u64), - ImmediateValue::F32(_) | ImmediateValue::F64(_) => None, - } - } -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum StCacheOperator { - Writeback, - L2Only, - Streaming, - Writethrough, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdCacheOperator { - Cached, - L2Only, - Streaming, - LastUse, - Uncached, -} - -pub enum CpAsyncCacheOperator { - Cached, - L2Only, -} - -#[derive(Copy, Clone)] -pub enum ArithDetails { - Integer(ArithInteger), - Float(ArithFloat), -} - -impl ArithDetails { - pub fn type_(&self) -> ScalarType { - match self { - ArithDetails::Integer(t) => t.type_, - ArithDetails::Float(arith) => arith.type_, - } - } -} - -#[derive(Copy, Clone)] -pub struct ArithInteger { - pub type_: ScalarType, - pub saturate: bool, -} - -#[derive(Copy, Clone)] -pub struct ArithFloat { - pub type_: ScalarType, - pub rounding: RoundingMode, - pub flush_to_zero: Option, - pub saturate: bool, - // From PTX documentation: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions-add - // Note that an add instruction with an explicit rounding modifier is treated conservatively by - // the code optimizer. An add instruction with no rounding modifier defaults to - // round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular, - // mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add - // instructions on the target device. - pub is_fusable: bool, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdStQualifier { - Weak, - Volatile, - Relaxed(MemScope), - Acquire(MemScope), - Release(MemScope), -} - -#[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub enum RoundingMode { - NearestEven, - Zero, - NegativeInf, - PositiveInf, -} - -pub struct LdDetails { - pub qualifier: LdStQualifier, - pub state_space: StateSpace, - pub caching: LdCacheOperator, - pub typ: Type, - pub non_coherent: bool, -} - -pub struct StData { - pub qualifier: LdStQualifier, - pub state_space: StateSpace, - pub caching: StCacheOperator, - pub typ: Type, -} - -#[derive(Copy, Clone)] -pub struct RetData { - pub uniform: bool, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum TuningDirective { - MaxNReg(u32), - MaxNtid(u32, u32, u32), - ReqNtid(u32, u32, u32), - MinNCtaPerSm(u32), -} - -pub struct MethodDeclaration<'input, ID> { - pub return_arguments: Vec>, - pub name: MethodName<'input, ID>, - pub input_arguments: Vec>, - pub shared_mem: Option, -} - -impl<'input> MethodDeclaration<'input, &'input str> { - pub fn name(&self) -> &'input str { - match self.name { - MethodName::Kernel(n) => n, - MethodName::Func(n) => n, - } - } -} - -#[derive(Hash, PartialEq, Eq, Copy, Clone)] -pub enum MethodName<'input, ID> { - Kernel(&'input str), - Func(ID), -} - -impl<'input, ID> MethodName<'input, ID> { - pub fn is_kernel(&self) -> bool { - match self { - MethodName::Kernel(_) => true, - MethodName::Func(_) => false, - } - } -} - -impl<'input> MethodName<'input, &'input str> { - pub fn text(&self) -> &'input str { - match self { - MethodName::Kernel(name) => *name, - MethodName::Func(name) => *name, - } - } -} - -bitflags! 
{ - pub struct LinkingDirective: u8 { - const NONE = 0b000; - const EXTERN = 0b001; - const VISIBLE = 0b10; - const WEAK = 0b100; - } -} - -pub struct Function<'a, ID, S> { - pub func_directive: MethodDeclaration<'a, ID>, - pub tuning: Vec, - pub body: Option>, -} - -pub enum Directive<'input, O: Operand> { - Variable(LinkingDirective, Variable), - Method( - LinkingDirective, - Function<'input, &'input str, Statement>, - ), -} - -pub struct Module<'input> { - pub version: (u8, u8), - pub directives: Vec>>, -} - -#[derive(Copy, Clone)] -pub enum MulDetails { - Integer { - type_: ScalarType, - control: MulIntControl, - }, - Float(ArithFloat), -} - -impl MulDetails { - pub fn type_(&self) -> ScalarType { - match self { - MulDetails::Integer { type_, .. } => *type_, - MulDetails::Float(arith) => arith.type_, - } - } - - pub fn dst_type(&self) -> ScalarType { - match self { - MulDetails::Integer { - type_, - control: MulIntControl::Wide, - } => match type_ { - ScalarType::U16 => ScalarType::U32, - ScalarType::S16 => ScalarType::S32, - ScalarType::U32 => ScalarType::U64, - ScalarType::S32 => ScalarType::S64, - _ => unreachable!(), - }, - _ => self.type_(), - } - } -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum MulIntControl { - Low, - High, - Wide, -} - -#[derive(Copy, Clone)] -pub struct Mul24Details { - pub type_: ScalarType, - pub control: Mul24Control, -} - -pub struct SetData { - pub dtype: ScalarType, - pub base: SetpData, -} - -pub struct SetpData { - pub type_: ScalarType, - pub flush_to_zero: Option, - pub cmp_op: SetpCompareOp, -} - -impl SetpData { - pub(crate) fn try_parse( - state: &mut PtxParserState, - cmp_op: super::RawSetpCompareOp, - ftz: bool, - type_: ScalarType, - ) -> Self { - let flush_to_zero = match (ftz, type_) { - (_, ScalarType::F32) => Some(ftz), - (true, _) => { - state.errors.push(PtxError::NonF32Ftz); - None - } - _ => None, - }; - let type_kind = type_.kind(); - let cmp_op = if type_kind == ScalarKind::Float { - SetpCompareOp::Float(SetpCompareFloat::from(cmp_op)) - } else { - match SetpCompareInt::try_from((cmp_op, type_kind)) { - Ok(op) => SetpCompareOp::Integer(op), - Err(err) => { - state.errors.push(err); - SetpCompareOp::Integer(SetpCompareInt::Eq) - } - } - }; - Self { - type_, - flush_to_zero, - cmp_op, - } - } -} - - -pub struct SetBoolData { - pub dtype: ScalarType, - pub base: SetpBoolData, -} - -pub struct SetpBoolData { - pub base: SetpData, - pub bool_op: SetpBoolPostOp, - pub negate_src3: bool, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum SetpCompareOp { - Integer(SetpCompareInt), - Float(SetpCompareFloat), -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum SetpCompareInt { - Eq, - NotEq, - UnsignedLess, - UnsignedLessOrEq, - UnsignedGreater, - UnsignedGreaterOrEq, - SignedLess, - SignedLessOrEq, - SignedGreater, - SignedGreaterOrEq, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum SetpCompareFloat { - Eq, - NotEq, - Less, - LessOrEq, - Greater, - GreaterOrEq, - NanEq, - NanNotEq, - NanLess, - NanLessOrEq, - NanGreater, - NanGreaterOrEq, - IsNotNan, - IsAnyNan, -} - -impl TryFrom<(RawSetpCompareOp, ScalarKind)> for SetpCompareInt { - type Error = PtxError<'static>; - - fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) -> Result> { - match (value, kind) { - (RawSetpCompareOp::Eq, _) => Ok(SetpCompareInt::Eq), - (RawSetpCompareOp::Ne, _) => Ok(SetpCompareInt::NotEq), - (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, ScalarKind::Signed) => { - Ok(SetpCompareInt::SignedLess) - } - (RawSetpCompareOp::Lt | 
RawSetpCompareOp::Lo, _) => Ok(SetpCompareInt::UnsignedLess), - (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, ScalarKind::Signed) => { - Ok(SetpCompareInt::SignedLessOrEq) - } - (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, _) => { - Ok(SetpCompareInt::UnsignedLessOrEq) - } - (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, ScalarKind::Signed) => { - Ok(SetpCompareInt::SignedGreater) - } - (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, _) => Ok(SetpCompareInt::UnsignedGreater), - (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, ScalarKind::Signed) => { - Ok(SetpCompareInt::SignedGreaterOrEq) - } - (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, _) => { - Ok(SetpCompareInt::UnsignedGreaterOrEq) - } - (RawSetpCompareOp::Equ, _) => Err(PtxError::WrongType), - (RawSetpCompareOp::Neu, _) => Err(PtxError::WrongType), - (RawSetpCompareOp::Ltu, _) => Err(PtxError::WrongType), - (RawSetpCompareOp::Leu, _) => Err(PtxError::WrongType), - (RawSetpCompareOp::Gtu, _) => Err(PtxError::WrongType), - (RawSetpCompareOp::Geu, _) => Err(PtxError::WrongType), - (RawSetpCompareOp::Num, _) => Err(PtxError::WrongType), - (RawSetpCompareOp::Nan, _) => Err(PtxError::WrongType), - } - } -} - -impl From for SetpCompareFloat { - fn from(value: RawSetpCompareOp) -> Self { - match value { - RawSetpCompareOp::Eq => SetpCompareFloat::Eq, - RawSetpCompareOp::Ne => SetpCompareFloat::NotEq, - RawSetpCompareOp::Lt => SetpCompareFloat::Less, - RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq, - RawSetpCompareOp::Gt => SetpCompareFloat::Greater, - RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq, - RawSetpCompareOp::Lo => SetpCompareFloat::Less, - RawSetpCompareOp::Ls => SetpCompareFloat::LessOrEq, - RawSetpCompareOp::Hi => SetpCompareFloat::Greater, - RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq, - RawSetpCompareOp::Equ => SetpCompareFloat::NanEq, - RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq, - RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess, - RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq, - RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater, - RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq, - RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan, - RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan, - } - } -} - -pub struct CallDetails { - pub uniform: bool, - pub return_arguments: Vec<(Type, StateSpace)>, - pub input_arguments: Vec<(Type, StateSpace)>, -} - -pub struct CallArgs { - pub return_arguments: Vec, - pub func: T::Ident, - pub input_arguments: Vec, -} - -impl CallArgs { - #[allow(dead_code)] // Used by generated code - fn visit( - &self, - details: &CallDetails, - visitor: &mut impl Visitor, - ) -> Result<(), Err> { - for (param, (type_, space)) in self - .return_arguments - .iter() - .zip(details.return_arguments.iter()) - { - visitor.visit_ident( - param, - Some((type_, *space)), - *space == StateSpace::Reg, - false, - )?; - } - visitor.visit_ident(&self.func, None, false, false)?; - for (param, (type_, space)) in self - .input_arguments - .iter() - .zip(details.input_arguments.iter()) - { - visitor.visit(param, Some((type_, *space)), false, false)?; - } - Ok(()) - } - - #[allow(dead_code)] // Used by generated code - fn visit_mut( - &mut self, - details: &CallDetails, - visitor: &mut impl VisitorMut, - ) -> Result<(), Err> { - for (param, (type_, space)) in self - .return_arguments - .iter_mut() - .zip(details.return_arguments.iter()) - { - visitor.visit_ident( - param, - Some((type_, *space)), - *space == StateSpace::Reg, - false, - )?; - } - visitor.visit_ident(&mut 
self.func, None, false, false)?; - for (param, (type_, space)) in self - .input_arguments - .iter_mut() - .zip(details.input_arguments.iter()) - { - visitor.visit(param, Some((type_, *space)), false, false)?; - } - Ok(()) - } - - #[allow(dead_code)] // Used by generated code - fn map( - self, - details: &CallDetails, - visitor: &mut impl VisitorMap, - ) -> Result, Err> { - let return_arguments = self - .return_arguments - .into_iter() - .zip(details.return_arguments.iter()) - .map(|(param, (type_, space))| { - visitor.visit_ident( - param, - Some((type_, *space)), - *space == StateSpace::Reg, - false, - ) - }) - .collect::, _>>()?; - let func = visitor.visit_ident(self.func, None, false, false)?; - let input_arguments = self - .input_arguments - .into_iter() - .zip(details.input_arguments.iter()) - .map(|(param, (type_, space))| { - visitor.visit(param, Some((type_, *space)), false, false) - }) - .collect::, _>>()?; - Ok(CallArgs { - return_arguments, - func, - input_arguments, - }) - } -} - -pub struct CvtDetails { - pub from: ScalarType, - pub to: ScalarType, - pub mode: CvtMode, -} - -#[derive(Clone, Copy)] -pub enum CvtMode { - // int from int - ZeroExtend, - SignExtend, - Truncate, - Bitcast, - IntSaturateToSigned, - IntSaturateToUnsigned, - // float from float - FPExtend { - flush_to_zero: Option, - saturate: bool, - }, - FPTruncate { - // float rounding - rounding: RoundingMode, - is_integer_rounding: bool, - flush_to_zero: Option, - saturate: bool, - }, - FPRound { - integer_rounding: Option, - flush_to_zero: Option, - saturate: bool, - }, - // int from float - SignedFromFP { - rounding: RoundingMode, - flush_to_zero: Option, - }, // integer rounding - UnsignedFromFP { - rounding: RoundingMode, - flush_to_zero: Option, - }, // integer rounding - // float from int, ftz is allowed in the grammar, but clearly nonsensical - FPFromSigned { - rounding: RoundingMode, - saturate: bool, - }, // float rounding - FPFromUnsigned { - rounding: RoundingMode, - saturate: bool, - }, // float rounding -} - -impl CvtDetails { - pub(crate) fn new( - errors: &mut Vec, - rnd: Option, - ftz: bool, - saturate: bool, - dst: ScalarType, - src: ScalarType, - ) -> Self { - // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results. 
- let flush_to_zero = match (dst, src) { - (ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz), - _ => { - if ftz { - errors.push(PtxError::NonF32Ftz); - } - None - } - }; - let rounding = rnd.map(RawRoundingMode::normalize); - let mut unwrap_rounding = || match rounding { - Some((rnd, is_integer)) => (rnd, is_integer), - None => { - errors.push(PtxError::SyntaxError); - (RoundingMode::NearestEven, false) - } - }; - let mode = match (dst.kind(), src.kind()) { - (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) { - Ordering::Less => { - let (rounding, is_integer_rounding) = unwrap_rounding(); - CvtMode::FPTruncate { - rounding, - is_integer_rounding, - flush_to_zero, - saturate, - } - } - Ordering::Equal => CvtMode::FPRound { - integer_rounding: rounding.map(|(rnd, _)| rnd), - flush_to_zero, - saturate, - }, - Ordering::Greater => { - if rounding.is_some() { - errors.push(PtxError::SyntaxError); - } - CvtMode::FPExtend { - flush_to_zero, - saturate, - } - } - }, - (ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP { - rounding: unwrap_rounding().0, - flush_to_zero, - }, - (ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP { - rounding: unwrap_rounding().0, - flush_to_zero, - }, - (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned { - rounding: unwrap_rounding().0, - saturate, - }, - (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned { - rounding: unwrap_rounding().0, - saturate, - }, - (ScalarKind::Signed, ScalarKind::Unsigned) - | (ScalarKind::Signed, ScalarKind::Signed) - if saturate => - { - CvtMode::IntSaturateToSigned - } - (ScalarKind::Unsigned, ScalarKind::Signed) - | (ScalarKind::Unsigned, ScalarKind::Unsigned) - if saturate => - { - CvtMode::IntSaturateToUnsigned - } - (ScalarKind::Unsigned, ScalarKind::Unsigned) - | (ScalarKind::Signed, ScalarKind::Signed) - | (ScalarKind::Unsigned, ScalarKind::Signed) - | (ScalarKind::Signed, ScalarKind::Unsigned) - if dst.size_of() == src.size_of() => - { - CvtMode::Bitcast - } - (ScalarKind::Unsigned, ScalarKind::Unsigned) - | (ScalarKind::Signed, ScalarKind::Signed) - | (ScalarKind::Unsigned, ScalarKind::Signed) - | (ScalarKind::Signed, ScalarKind::Unsigned) => match dst.size_of().cmp(&src.size_of()) - { - Ordering::Less => CvtMode::Truncate, - Ordering::Equal => CvtMode::Bitcast, - Ordering::Greater => { - if src.kind() == ScalarKind::Signed { - CvtMode::SignExtend - } else { - CvtMode::ZeroExtend - } - } - }, - (_, _) => { - errors.push(PtxError::SyntaxError); - CvtMode::Bitcast - } - }; - CvtDetails { - mode, - to: dst, - from: src, - } - } -} - -pub struct CvtIntToIntDesc { - pub dst: ScalarType, - pub src: ScalarType, - pub saturate: bool, -} - -pub struct CvtDesc { - pub rounding: Option, - pub flush_to_zero: Option, - pub saturate: bool, - pub dst: ScalarType, - pub src: ScalarType, -} - -pub struct ShrData { - pub type_: ScalarType, - pub kind: RightShiftKind, -} - -pub enum RightShiftKind { - Arithmetic, - Logical, -} - -pub struct CvtaDetails { - pub state_space: StateSpace, - pub direction: CvtaDirection, -} - -pub enum CvtaDirection { - GenericToExplicit, - ExplicitToGeneric, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub struct TypeFtz { - pub flush_to_zero: Option, - pub type_: ScalarType, -} - -#[derive(Copy, Clone)] -pub enum MadDetails { - Integer { - control: MulIntControl, - saturate: bool, - type_: ScalarType, - }, - Float(ArithFloat), -} - -impl MadDetails { - pub fn dst_type(&self) -> ScalarType { - match self { - 
MadDetails::Integer { - type_, - control: MulIntControl::Wide, - .. - } => match type_ { - ScalarType::U16 => ScalarType::U32, - ScalarType::S16 => ScalarType::S32, - ScalarType::U32 => ScalarType::U64, - ScalarType::S32 => ScalarType::S64, - _ => unreachable!(), - }, - _ => self.type_(), - } - } - - fn type_(&self) -> ScalarType { - match self { - MadDetails::Integer { type_, .. } => *type_, - MadDetails::Float(arith) => arith.type_, - } - } -} - -#[derive(Copy, Clone)] -pub enum MinMaxDetails { - Signed(ScalarType), - Unsigned(ScalarType), - Float(MinMaxFloat), -} - -impl MinMaxDetails { - pub fn type_(&self) -> ScalarType { - match self { - MinMaxDetails::Signed(t) => *t, - MinMaxDetails::Unsigned(t) => *t, - MinMaxDetails::Float(float) => float.type_, - } - } -} - -#[derive(Copy, Clone)] -pub struct MinMaxFloat { - pub flush_to_zero: Option, - pub nan: bool, - pub type_: ScalarType, -} - -#[derive(Copy, Clone)] -pub struct RcpData { - pub kind: RcpKind, - pub flush_to_zero: Option, - pub type_: ScalarType, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum RcpKind { - Approx, - Compliant(RoundingMode), -} - -pub struct BarData { - pub aligned: bool, -} - -#[derive(Copy, Clone)] -pub struct BarRedData { - pub aligned: bool, - pub pred_reduction: Reduction, -} - -pub struct AtomDetails { - pub type_: Type, - pub semantics: AtomSemantics, - pub scope: MemScope, - pub space: StateSpace, - pub op: AtomicOp, -} - -#[derive(Copy, Clone)] -pub enum AtomicOp { - And, - Or, - Xor, - Exchange, - Add, - IncrementWrap, - DecrementWrap, - SignedMin, - UnsignedMin, - SignedMax, - UnsignedMax, - FloatAdd, - FloatMin, - FloatMax, -} - -impl AtomicOp { - pub(crate) fn new(op: super::RawAtomicOp, kind: ScalarKind) -> Self { - use super::RawAtomicOp; - match (op, kind) { - (RawAtomicOp::And, _) => Self::And, - (RawAtomicOp::Or, _) => Self::Or, - (RawAtomicOp::Xor, _) => Self::Xor, - (RawAtomicOp::Exch, _) => Self::Exchange, - (RawAtomicOp::Add, ScalarKind::Float) => Self::FloatAdd, - (RawAtomicOp::Add, _) => Self::Add, - (RawAtomicOp::Inc, _) => Self::IncrementWrap, - (RawAtomicOp::Dec, _) => Self::DecrementWrap, - (RawAtomicOp::Min, ScalarKind::Signed) => Self::SignedMin, - (RawAtomicOp::Min, ScalarKind::Float) => Self::FloatMin, - (RawAtomicOp::Min, _) => Self::UnsignedMin, - (RawAtomicOp::Max, ScalarKind::Signed) => Self::SignedMax, - (RawAtomicOp::Max, ScalarKind::Float) => Self::FloatMax, - (RawAtomicOp::Max, _) => Self::UnsignedMax, - } - } -} - -pub struct AtomCasDetails { - pub type_: ScalarType, - pub semantics: AtomSemantics, - pub scope: MemScope, - pub space: StateSpace, -} - -#[derive(Copy, Clone)] -pub enum DivDetails { - Unsigned(ScalarType), - Signed(ScalarType), - Float(DivFloatDetails), -} - -impl DivDetails { - pub fn type_(&self) -> ScalarType { - match self { - DivDetails::Unsigned(t) => *t, - DivDetails::Signed(t) => *t, - DivDetails::Float(float) => float.type_, - } - } -} - -#[derive(Copy, Clone)] -pub struct DivFloatDetails { - pub type_: ScalarType, - pub flush_to_zero: Option, - pub kind: DivFloatKind, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum DivFloatKind { - Approx, - ApproxFull, - Rounding(RoundingMode), -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub struct FlushToZero { - pub flush_to_zero: bool, -} +use super::{ + AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, + StateSpace, VectorPrefix, +}; +use crate::{Mul24Control, PtxError, PtxParserState, Reduction, ShuffleMode}; +use bitflags::bitflags; +use 
std::{alloc::Layout, cmp::Ordering, num::NonZeroU8};
+
+pub enum Statement<P: Operand> {
+    Label(P::Ident),
+    Variable(MultiVariable<P::Ident>),
+    Instruction(Option<PredAt<P::Ident>>, Instruction<P>
), + Block(Vec>), +} + +// We define the instruction enum through the macro instead of normally, because we have some of how +// we use this type in the compilee. Each instruction can be logically split into two parts: +// properties that define instruction semantics (e.g. is memory load volatile?) that don't change +// during compilation and arguments (e.g. memory load source and destination) that evolve during +// compilation. To support compilation passes we need to be able to visit (and change) every +// argument in a generic way. This macro has visibility over all the fields. Consequently, we use it +// to generate visitor functions. There re three functions to support three different semantics: +// visit-by-ref, visit-by-mutable-ref, visit-and-map. In a previous version of the compiler it was +// done by hand and was very limiting (we supported only visit-and-map). +// The visitor must implement appropriate visitor trait defined below this macro. For convenience, +// we implemented visitors for some corresponding FnMut(...) types. +// Properties in this macro are used to encode information about the instruction arguments (what +// Rust type is used for it post-parsing, what PTX type does it expect, what PTX address space does +// it expect, etc.). +// This information is then available to a visitor. +ptx_parser_macros::generate_instruction_type!( + pub enum Instruction { + Abs { + data: TypeFtz, + type: { Type::Scalar(data.type_) }, + arguments: { + dst: T, + src: T, + } + }, + Activemask { + type: Type::Scalar(ScalarType::B32), + arguments: { + dst: T + } + }, + Add { + type: { Type::from(data.type_()) }, + data: ArithDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + And { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Atom { + type: &data.type_, + data: AtomDetails, + arguments: { + dst: T, + src1: { + repr: T, + space: { data.space }, + }, + src2: T, + } + }, + AtomCas { + type: Type::Scalar(data.type_), + data: AtomCasDetails, + arguments: { + dst: T, + src1: { + repr: T, + space: { data.space }, + }, + src2: T, + src3: T, + } + }, + Bar { + type: Type::Scalar(ScalarType::U32), + data: BarData, + arguments: { + src1: T, + src2: Option, + } + }, + BarRed { + type: Type::Scalar(ScalarType::U32), + data: BarRedData, + arguments: { + dst1: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + src_barrier: T, + src_threadcount: Option, + src_predicate: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + src_negate_predicate: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + } + }, + Bfe { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src3: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + } + }, + Bfi { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T, + src3: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src4: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + } + }, + Bra { + type: !, + arguments: { + src: T + } + }, + Brev { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src: T + } + }, + Call { + data: CallDetails, + arguments: CallArgs, + visit: arguments.visit(data, visitor)?, + visit_mut: arguments.visit_mut(data, visitor)?, + map: Instruction::Call{ arguments: arguments.map(&data, visitor)?, data } + }, + Clz { + type: Type::Scalar(data.clone()), + 
data: ScalarType, + arguments: { + dst: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src: T + } + }, + Cos { + type: Type::Scalar(ScalarType::F32), + data: FlushToZero, + arguments: { + dst: T, + src: T + } + }, + CpAsync { + type: Type::Scalar(ScalarType::U32), + data: CpAsyncDetails, + arguments: { + src_to: { + repr: T, + space: StateSpace::Shared + }, + src_from: { + repr: T, + space: StateSpace::Global + } + } + }, + CpAsyncCommitGroup { }, + CpAsyncWaitGroup { + type: Type::Scalar(ScalarType::U64), + arguments: { + src_group: T + } + }, + CpAsyncWaitAll { }, + Cvt { + data: CvtDetails, + arguments: { + dst: { + repr: T, + type: { Type::Scalar(data.to) }, + // TODO: double check + relaxed_type_check: true, + }, + src: { + repr: T, + type: { Type::Scalar(data.from) }, + relaxed_type_check: true, + }, + } + }, + Cvta { + data: CvtaDetails, + type: { Type::Scalar(ScalarType::B64) }, + arguments: { + dst: T, + src: T, + } + }, + Div { + type: Type::Scalar(data.type_()), + data: DivDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Ex2 { + type: Type::Scalar(ScalarType::F32), + data: TypeFtz, + arguments: { + dst: T, + src: T + } + }, + Fma { + type: { Type::from(data.type_) }, + data: ArithFloat, + arguments: { + dst: T, + src1: T, + src2: T, + src3: T, + } + }, + Ld { + type: { &data.typ }, + data: LdDetails, + arguments: { + dst: { + repr: T, + relaxed_type_check: true, + }, + src: { + repr: T, + space: { data.state_space }, + } + } + }, + Lg2 { + type: Type::Scalar(ScalarType::F32), + data: FlushToZero, + arguments: { + dst: T, + src: T + } + }, + Mad { + type: { Type::from(data.type_()) }, + data: MadDetails, + arguments: { + dst: { + repr: T, + type: { Type::from(data.dst_type()) }, + }, + src1: T, + src2: T, + src3: { + repr: T, + type: { Type::from(data.dst_type()) }, + } + } + }, + Max { + type: { Type::from(data.type_()) }, + data: MinMaxDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Membar { + data: MemScope + }, + Min { + type: { Type::from(data.type_()) }, + data: MinMaxDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Mov { + type: { &data.typ }, + data: MovDetails, + arguments: { + dst: T, + src: T + } + }, + Mul { + type: { Type::from(data.type_()) }, + data: MulDetails, + arguments: { + dst: { + repr: T, + type: { Type::from(data.dst_type()) }, + }, + src1: T, + src2: T, + } + }, + Mul24 { + type: { Type::from(data.type_) }, + data: Mul24Details, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Nanosleep { + type: Type::Scalar(ScalarType::U32), + arguments: { + src: T + } + }, + Neg { + type: Type::Scalar(data.type_), + data: TypeFtz, + arguments: { + dst: T, + src: T + } + }, + Not { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src: T, + } + }, + Or { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Popc { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src: T + } + }, + Prmt { + type: Type::Scalar(ScalarType::B32), + data: u16, + arguments: { + dst: T, + src1: T, + src2: T + } + }, + PrmtSlow { + type: Type::Scalar(ScalarType::U32), + arguments: { + dst: T, + src1: T, + src2: T, + src3: T + } + }, + Rcp { + type: { Type::from(data.type_) }, + data: RcpData, + arguments: { + dst: T, + src: T, + } + }, + Rem { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + 
src2: T + } + }, + Ret { + data: RetData + }, + Rsqrt { + type: { Type::from(data.type_) }, + data: TypeFtz, + arguments: { + dst: T, + src: T, + } + }, + Selp { + type: { Type::Scalar(data.clone()) }, + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T, + src3: { + repr: T, + type: Type::Scalar(ScalarType::Pred) + }, + } + }, + Set { + data: SetData, + arguments: { + dst: { + repr: T, + type: Type::from(data.dtype) + }, + src1: { + repr: T, + type: Type::from(data.base.type_), + }, + src2: { + repr: T, + type: Type::from(data.base.type_), + } + } + }, + SetBool { + data: SetBoolData, + arguments: { + dst: { + repr: T, + type: Type::from(data.dtype) + }, + src1: { + repr: T, + type: Type::from(data.base.base.type_), + }, + src2: { + repr: T, + type: Type::from(data.base.base.type_), + }, + src3: { + repr: T, + type: Type::from(ScalarType::Pred) + } + } + }, + Setp { + data: SetpData, + arguments: { + dst1: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + dst2: { + repr: Option, + type: Type::from(ScalarType::Pred) + }, + src1: { + repr: T, + type: Type::from(data.type_), + }, + src2: { + repr: T, + type: Type::from(data.type_), + } + } + }, + SetpBool { + data: SetpBoolData, + arguments: { + dst1: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + dst2: { + repr: Option, + type: Type::from(ScalarType::Pred) + }, + src1: { + repr: T, + type: Type::from(data.base.type_), + }, + src2: { + repr: T, + type: Type::from(data.base.type_), + }, + src3: { + repr: T, + type: Type::from(ScalarType::Pred) + } + } + }, + ShflSync { + data: ShflSyncDetails, + type: Type::Scalar(ScalarType::B32), + arguments: { + dst: T, + dst_pred: { + repr: Option, + type: Type::from(ScalarType::Pred) + }, + src: T, + src_lane: T, + src_opts: T, + src_membermask: T + } + }, + Shl { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src1: T, + src2: { + repr: T, + type: { Type::Scalar(ScalarType::U32) }, + }, + } + }, + Shr { + data: ShrData, + type: { Type::Scalar(data.type_.clone()) }, + arguments: { + dst: T, + src1: T, + src2: { + repr: T, + type: { Type::Scalar(ScalarType::U32) }, + }, + } + }, + Sin { + type: Type::Scalar(ScalarType::F32), + data: FlushToZero, + arguments: { + dst: T, + src: T + } + }, + Sqrt { + type: { Type::from(data.type_) }, + data: RcpData, + arguments: { + dst: T, + src: T, + } + }, + St { + type: { &data.typ }, + data: StData, + arguments: { + src1: { + repr: T, + space: { data.state_space }, + }, + src2: { + repr: T, + relaxed_type_check: true, + } + } + }, + Sub { + type: { Type::from(data.type_()) }, + data: ArithDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Trap { }, + Xor { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T + } + }, + Tanh { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src: T + } + }, + } +); + +pub trait Visitor { + fn visit( + &mut self, + args: &T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<(), Err>; + fn visit_ident( + &mut self, + args: &T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<(), Err>; +} + +impl< + T: Operand, + Err, + Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>, + > Visitor for Fn +{ + fn visit( + &mut self, + args: &T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<(), Err> { + 
(self)(args, type_space, is_dst, relaxed_type_check) + } + + fn visit_ident( + &mut self, + args: &T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<(), Err> { + (self)( + &T::from_ident(*args), + type_space, + is_dst, + relaxed_type_check, + ) + } +} + +pub trait VisitorMut { + fn visit( + &mut self, + args: &mut T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<(), Err>; + fn visit_ident( + &mut self, + args: &mut T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<(), Err>; +} + +pub trait VisitorMap { + fn visit( + &mut self, + args: From, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result; + fn visit_ident( + &mut self, + args: From::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result; +} + +impl VisitorMap, ParsedOperand, Err> for Fn +where + Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result, +{ + fn visit( + &mut self, + args: ParsedOperand, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result, Err> { + Ok(match args { + ParsedOperand::Reg(ident) => { + ParsedOperand::Reg((self)(ident, type_space, is_dst, relaxed_type_check)?) + } + ParsedOperand::RegOffset(ident, imm) => ParsedOperand::RegOffset( + (self)(ident, type_space, is_dst, relaxed_type_check)?, + imm, + ), + ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm), + ParsedOperand::VecMember(ident, index) => ParsedOperand::VecMember( + (self)(ident, type_space, is_dst, relaxed_type_check)?, + index, + ), + ParsedOperand::VecPack(vec) => ParsedOperand::VecPack( + vec.into_iter() + .map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check)) + .collect::, _>>()?, + ), + }) + } + + fn visit_ident( + &mut self, + args: T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + (self)(args, type_space, is_dst, relaxed_type_check) + } +} + +impl, U: Operand, Err, Fn> VisitorMap for Fn +where + Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result, +{ + fn visit( + &mut self, + args: T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + (self)(args, type_space, is_dst, relaxed_type_check) + } + + fn visit_ident( + &mut self, + args: T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + (self)(args, type_space, is_dst, relaxed_type_check) + } +} + +trait VisitOperand { + type Operand: Operand; + #[allow(unused)] // Used by generated code + fn visit(&self, fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err>; + #[allow(unused)] // Used by generated code + fn visit_mut( + &mut self, + fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err>; +} + +impl VisitOperand for T { + type Operand = Self; + fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { + fn_(self) + } + fn visit_mut( + &mut self, + mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err> { + fn_(self) + } +} + +impl VisitOperand for Option { + type Operand = T; + fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { + if let Some(x) = self { + fn_(x)?; + } + Ok(()) + } + fn visit_mut( + &mut self, + mut fn_: impl FnMut(&mut 
Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err> { + if let Some(x) = self { + fn_(x)?; + } + Ok(()) + } +} + +impl VisitOperand for Vec { + type Operand = T; + fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { + for o in self { + fn_(o)?; + } + Ok(()) + } + fn visit_mut( + &mut self, + mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err> { + for o in self { + fn_(o)?; + } + Ok(()) + } +} + +trait MapOperand: Sized { + type Input; + type Output; + #[allow(unused)] // Used by generated code + fn map( + self, + fn_: impl FnOnce(Self::Input) -> Result, + ) -> Result, Err>; +} + +impl MapOperand for T { + type Input = Self; + type Output = U; + fn map(self, fn_: impl FnOnce(T) -> Result) -> Result { + fn_(self) + } +} + +impl MapOperand for Option { + type Input = T; + type Output = Option; + fn map(self, fn_: impl FnOnce(T) -> Result) -> Result, Err> { + self.map(|x| fn_(x)).transpose() + } +} + +pub struct MultiVariable { + pub var: Variable, + pub count: Option, +} + +#[derive(Clone)] +pub struct Variable { + pub align: Option, + pub v_type: Type, + pub state_space: StateSpace, + pub name: ID, + pub array_init: Vec, +} + +pub struct PredAt { + pub not: bool, + pub label: ID, +} + +#[derive(PartialEq, Eq, Clone, Hash)] +pub enum Type { + // .param.b32 foo; + Scalar(ScalarType), + // .param.v2.b32 foo; + Vector(u8, ScalarType), + // .param.b32 foo[4]; + Array(Option, ScalarType, Vec), +} + +impl Type { + pub(crate) fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { + match vector { + Some(prefix) => Type::Vector(prefix.len().get(), scalar), + None => Type::Scalar(scalar), + } + } + + pub(crate) fn maybe_vector_parsed(prefix: Option, scalar: ScalarType) -> Self { + match prefix { + Some(prefix) => Type::Vector(prefix.get(), scalar), + None => Type::Scalar(scalar), + } + } + + pub(crate) fn maybe_array( + prefix: Option, + scalar: ScalarType, + array: Option>, + ) -> Self { + match array { + Some(dimensions) => Type::Array(prefix, scalar, dimensions), + None => Self::maybe_vector_parsed(prefix, scalar), + } + } + + pub fn layout(&self) -> Layout { + match self { + Type::Scalar(type_) => type_.layout(), + Type::Vector(elements, scalar_type) => { + let scalar_layout = scalar_type.layout(); + unsafe { + Layout::from_size_align_unchecked( + scalar_layout.size() * *elements as usize, + scalar_layout.align() * *elements as usize, + ) + } + } + Type::Array(non_zero, scalar, vec) => { + let element_layout = Type::maybe_vector_parsed(*non_zero, *scalar).layout(); + let len = vec.iter().copied().reduce(std::ops::Mul::mul).unwrap_or(0); + unsafe { + Layout::from_size_align_unchecked( + element_layout.size() * (len as usize), + element_layout.align(), + ) + } + } + } + } +} + +impl ScalarType { + pub fn size_of(self) -> u8 { + match self { + ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1, + ScalarType::U16 + | ScalarType::S16 + | ScalarType::B16 + | ScalarType::F16 + | ScalarType::BF16 => 2, + ScalarType::U32 + | ScalarType::S32 + | ScalarType::B32 + | ScalarType::F32 + | ScalarType::U16x2 + | ScalarType::S16x2 + | ScalarType::F16x2 + | ScalarType::BF16x2 => 4, + ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => 8, + ScalarType::B128 => 16, + ScalarType::Pred => 1, + } + } + + pub fn layout(self) -> Layout { + match self { + ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => Layout::new::(), + ScalarType::U16 + | ScalarType::S16 + | ScalarType::B16 + | ScalarType::F16 + | 
ScalarType::BF16 => Layout::new::(), + ScalarType::U32 + | ScalarType::S32 + | ScalarType::B32 + | ScalarType::F32 + | ScalarType::U16x2 + | ScalarType::S16x2 + | ScalarType::F16x2 + | ScalarType::BF16x2 => Layout::new::(), + ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => { + Layout::new::() + } + ScalarType::B128 => unsafe { Layout::from_size_align_unchecked(16, 16) }, + // Close enough + ScalarType::Pred => Layout::new::(), + } + } + + pub fn kind(self) -> ScalarKind { + match self { + ScalarType::U8 => ScalarKind::Unsigned, + ScalarType::U16 => ScalarKind::Unsigned, + ScalarType::U16x2 => ScalarKind::Unsigned, + ScalarType::U32 => ScalarKind::Unsigned, + ScalarType::U64 => ScalarKind::Unsigned, + ScalarType::S8 => ScalarKind::Signed, + ScalarType::S16 => ScalarKind::Signed, + ScalarType::S16x2 => ScalarKind::Signed, + ScalarType::S32 => ScalarKind::Signed, + ScalarType::S64 => ScalarKind::Signed, + ScalarType::B8 => ScalarKind::Bit, + ScalarType::B16 => ScalarKind::Bit, + ScalarType::B32 => ScalarKind::Bit, + ScalarType::B64 => ScalarKind::Bit, + ScalarType::B128 => ScalarKind::Bit, + ScalarType::F16 => ScalarKind::Float, + ScalarType::F16x2 => ScalarKind::Float, + ScalarType::F32 => ScalarKind::Float, + ScalarType::F64 => ScalarKind::Float, + ScalarType::BF16 => ScalarKind::Float, + ScalarType::BF16x2 => ScalarKind::Float, + ScalarType::Pred => ScalarKind::Pred, + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum ScalarKind { + Bit, + Unsigned, + Signed, + Float, + Pred, +} +impl From for Type { + fn from(value: ScalarType) -> Self { + Type::Scalar(value) + } +} + +#[derive(Clone)] +pub struct MovDetails { + pub typ: super::Type, + pub src_is_address: bool, + // two fields below are in use by member moves + pub dst_width: u8, + pub src_width: u8, + // This is in use by auto-generated movs + pub relaxed_src2_conv: bool, +} + +impl MovDetails { + pub(crate) fn new(vector: Option, scalar: ScalarType) -> Self { + MovDetails { + typ: Type::maybe_vector(vector, scalar), + src_is_address: false, + dst_width: 0, + src_width: 0, + relaxed_src2_conv: false, + } + } +} + +#[derive(Copy, Clone)] +pub struct ShflSyncDetails { + pub mode: ShuffleMode, +} + +pub enum CpAsyncCpSize { + Bytes4, + Bytes8, + Bytes16, +} + +impl CpAsyncCpSize { + pub fn from_u64(n: u64) -> Option { + match n { + 4 => Some(Self::Bytes4), + 8 => Some(Self::Bytes8), + 16 => Some(Self::Bytes16), + _ => None, + } + } + + pub fn as_u64(&self) -> u64 { + match self { + CpAsyncCpSize::Bytes4 => 4, + CpAsyncCpSize::Bytes8 => 8, + CpAsyncCpSize::Bytes16 => 16, + } + } +} + +pub struct CpAsyncDetails { + pub caching: CpAsyncCacheOperator, + pub space: StateSpace, + pub cp_size: CpAsyncCpSize, + pub src_size: Option, +} + +#[derive(Clone)] +pub enum ParsedOperand { + Reg(Ident), + RegOffset(Ident, i32), + Imm(ImmediateValue), + VecMember(Ident, u8), + VecPack(Vec), +} + +impl ParsedOperand { + pub fn as_immediate(&self) -> Option { + match self { + ParsedOperand::Imm(imm) => Some(*imm), + _ => None, + } + } +} + +impl Operand for ParsedOperand { + type Ident = Ident; + + fn from_ident(ident: Self::Ident) -> Self { + ParsedOperand::Reg(ident) + } +} + +pub trait Operand: Sized { + type Ident: Copy; + + fn from_ident(ident: Self::Ident) -> Self; +} + +#[derive(Copy, Clone)] +pub enum ImmediateValue { + U64(u64), + S64(i64), + F32(f32), + F64(f64), +} + +impl ImmediateValue { + /// If the value is a U64 or S64, returns the value as a u64, ignoring the sign. 
+ pub fn as_u64(&self) -> Option { + match *self { + ImmediateValue::U64(n) => Some(n), + ImmediateValue::S64(n) => Some(n as u64), + ImmediateValue::F32(_) | ImmediateValue::F64(_) => None, + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum StCacheOperator { + Writeback, + L2Only, + Streaming, + Writethrough, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum LdCacheOperator { + Cached, + L2Only, + Streaming, + LastUse, + Uncached, +} + +pub enum CpAsyncCacheOperator { + Cached, + L2Only, +} + +#[derive(Copy, Clone)] +pub enum ArithDetails { + Integer(ArithInteger), + Float(ArithFloat), +} + +impl ArithDetails { + pub fn type_(&self) -> ScalarType { + match self { + ArithDetails::Integer(t) => t.type_, + ArithDetails::Float(arith) => arith.type_, + } + } +} + +#[derive(Copy, Clone)] +pub struct ArithInteger { + pub type_: ScalarType, + pub saturate: bool, +} + +#[derive(Copy, Clone)] +pub struct ArithFloat { + pub type_: ScalarType, + pub rounding: RoundingMode, + pub flush_to_zero: Option, + pub saturate: bool, + // From PTX documentation: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions-add + // Note that an add instruction with an explicit rounding modifier is treated conservatively by + // the code optimizer. An add instruction with no rounding modifier defaults to + // round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular, + // mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add + // instructions on the target device. + pub is_fusable: bool, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum LdStQualifier { + Weak, + Volatile, + Relaxed(MemScope), + Acquire(MemScope), + Release(MemScope), +} + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub enum RoundingMode { + NearestEven, + Zero, + NegativeInf, + PositiveInf, +} + +pub struct LdDetails { + pub qualifier: LdStQualifier, + pub state_space: StateSpace, + pub caching: LdCacheOperator, + pub typ: Type, + pub non_coherent: bool, +} + +pub struct StData { + pub qualifier: LdStQualifier, + pub state_space: StateSpace, + pub caching: StCacheOperator, + pub typ: Type, +} + +#[derive(Copy, Clone)] +pub struct RetData { + pub uniform: bool, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum TuningDirective { + MaxNReg(u32), + MaxNtid(u32, u32, u32), + ReqNtid(u32, u32, u32), + MinNCtaPerSm(u32), +} + +pub struct MethodDeclaration<'input, ID> { + pub return_arguments: Vec>, + pub name: MethodName<'input, ID>, + pub input_arguments: Vec>, + pub shared_mem: Option, +} + +impl<'input> MethodDeclaration<'input, &'input str> { + pub fn name(&self) -> &'input str { + match self.name { + MethodName::Kernel(n) => n, + MethodName::Func(n) => n, + } + } +} + +#[derive(Hash, PartialEq, Eq, Copy, Clone)] +pub enum MethodName<'input, ID> { + Kernel(&'input str), + Func(ID), +} + +impl<'input, ID> MethodName<'input, ID> { + pub fn is_kernel(&self) -> bool { + match self { + MethodName::Kernel(_) => true, + MethodName::Func(_) => false, + } + } +} + +impl<'input> MethodName<'input, &'input str> { + pub fn text(&self) -> &'input str { + match self { + MethodName::Kernel(name) => *name, + MethodName::Func(name) => *name, + } + } +} + +bitflags! 
{ + pub struct LinkingDirective: u8 { + const NONE = 0b000; + const EXTERN = 0b001; + const VISIBLE = 0b10; + const WEAK = 0b100; + } +} + +pub struct Function<'a, ID, S> { + pub func_directive: MethodDeclaration<'a, ID>, + pub tuning: Vec, + pub body: Option>, +} + +pub enum Directive<'input, O: Operand> { + Variable(LinkingDirective, Variable), + Method( + LinkingDirective, + Function<'input, &'input str, Statement>, + ), +} + +pub struct Module<'input> { + pub version: (u8, u8), + pub directives: Vec>>, +} + +#[derive(Copy, Clone)] +pub enum MulDetails { + Integer { + type_: ScalarType, + control: MulIntControl, + }, + Float(ArithFloat), +} + +impl MulDetails { + pub fn type_(&self) -> ScalarType { + match self { + MulDetails::Integer { type_, .. } => *type_, + MulDetails::Float(arith) => arith.type_, + } + } + + pub fn dst_type(&self) -> ScalarType { + match self { + MulDetails::Integer { + type_, + control: MulIntControl::Wide, + } => match type_ { + ScalarType::U16 => ScalarType::U32, + ScalarType::S16 => ScalarType::S32, + ScalarType::U32 => ScalarType::U64, + ScalarType::S32 => ScalarType::S64, + _ => unreachable!(), + }, + _ => self.type_(), + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum MulIntControl { + Low, + High, + Wide, +} + +#[derive(Copy, Clone)] +pub struct Mul24Details { + pub type_: ScalarType, + pub control: Mul24Control, +} + +pub struct SetData { + pub dtype: ScalarType, + pub base: SetpData, +} + +pub struct SetpData { + pub type_: ScalarType, + pub flush_to_zero: Option, + pub cmp_op: SetpCompareOp, +} + +impl SetpData { + pub(crate) fn try_parse( + state: &mut PtxParserState, + cmp_op: super::RawSetpCompareOp, + ftz: bool, + type_: ScalarType, + ) -> Self { + let flush_to_zero = match (ftz, type_) { + (_, ScalarType::F32) => Some(ftz), + (true, _) => { + state.errors.push(PtxError::NonF32Ftz); + None + } + _ => None, + }; + let type_kind = type_.kind(); + let cmp_op = if type_kind == ScalarKind::Float { + SetpCompareOp::Float(SetpCompareFloat::from(cmp_op)) + } else { + match SetpCompareInt::try_from((cmp_op, type_kind)) { + Ok(op) => SetpCompareOp::Integer(op), + Err(err) => { + state.errors.push(err); + SetpCompareOp::Integer(SetpCompareInt::Eq) + } + } + }; + Self { + type_, + flush_to_zero, + cmp_op, + } + } +} + +pub struct SetBoolData { + pub dtype: ScalarType, + pub base: SetpBoolData, +} + +pub struct SetpBoolData { + pub base: SetpData, + pub bool_op: SetpBoolPostOp, + pub negate_src3: bool, +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum SetpCompareOp { + Integer(SetpCompareInt), + Float(SetpCompareFloat), +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum SetpCompareInt { + Eq, + NotEq, + UnsignedLess, + UnsignedLessOrEq, + UnsignedGreater, + UnsignedGreaterOrEq, + SignedLess, + SignedLessOrEq, + SignedGreater, + SignedGreaterOrEq, +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum SetpCompareFloat { + Eq, + NotEq, + Less, + LessOrEq, + Greater, + GreaterOrEq, + NanEq, + NanNotEq, + NanLess, + NanLessOrEq, + NanGreater, + NanGreaterOrEq, + IsNotNan, + IsAnyNan, +} + +impl TryFrom<(RawSetpCompareOp, ScalarKind)> for SetpCompareInt { + type Error = PtxError<'static>; + + fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) -> Result> { + match (value, kind) { + (RawSetpCompareOp::Eq, _) => Ok(SetpCompareInt::Eq), + (RawSetpCompareOp::Ne, _) => Ok(SetpCompareInt::NotEq), + (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, ScalarKind::Signed) => { + Ok(SetpCompareInt::SignedLess) + } + (RawSetpCompareOp::Lt | 
RawSetpCompareOp::Lo, _) => Ok(SetpCompareInt::UnsignedLess), + (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, ScalarKind::Signed) => { + Ok(SetpCompareInt::SignedLessOrEq) + } + (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, _) => { + Ok(SetpCompareInt::UnsignedLessOrEq) + } + (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, ScalarKind::Signed) => { + Ok(SetpCompareInt::SignedGreater) + } + (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, _) => Ok(SetpCompareInt::UnsignedGreater), + (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, ScalarKind::Signed) => { + Ok(SetpCompareInt::SignedGreaterOrEq) + } + (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, _) => { + Ok(SetpCompareInt::UnsignedGreaterOrEq) + } + (RawSetpCompareOp::Equ, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Neu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Ltu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Leu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Gtu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Geu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Num, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Nan, _) => Err(PtxError::WrongType), + } + } +} + +impl From for SetpCompareFloat { + fn from(value: RawSetpCompareOp) -> Self { + match value { + RawSetpCompareOp::Eq => SetpCompareFloat::Eq, + RawSetpCompareOp::Ne => SetpCompareFloat::NotEq, + RawSetpCompareOp::Lt => SetpCompareFloat::Less, + RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq, + RawSetpCompareOp::Gt => SetpCompareFloat::Greater, + RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq, + RawSetpCompareOp::Lo => SetpCompareFloat::Less, + RawSetpCompareOp::Ls => SetpCompareFloat::LessOrEq, + RawSetpCompareOp::Hi => SetpCompareFloat::Greater, + RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq, + RawSetpCompareOp::Equ => SetpCompareFloat::NanEq, + RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq, + RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess, + RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq, + RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater, + RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq, + RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan, + RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan, + } + } +} + +pub struct CallDetails { + pub uniform: bool, + pub return_arguments: Vec<(Type, StateSpace)>, + pub input_arguments: Vec<(Type, StateSpace)>, +} + +pub struct CallArgs { + pub return_arguments: Vec, + pub func: T::Ident, + pub input_arguments: Vec, +} + +impl CallArgs { + #[allow(dead_code)] // Used by generated code + fn visit( + &self, + details: &CallDetails, + visitor: &mut impl Visitor, + ) -> Result<(), Err> { + for (param, (type_, space)) in self + .return_arguments + .iter() + .zip(details.return_arguments.iter()) + { + visitor.visit_ident( + param, + Some((type_, *space)), + *space == StateSpace::Reg, + false, + )?; + } + visitor.visit_ident(&self.func, None, false, false)?; + for (param, (type_, space)) in self + .input_arguments + .iter() + .zip(details.input_arguments.iter()) + { + visitor.visit(param, Some((type_, *space)), false, false)?; + } + Ok(()) + } + + #[allow(dead_code)] // Used by generated code + fn visit_mut( + &mut self, + details: &CallDetails, + visitor: &mut impl VisitorMut, + ) -> Result<(), Err> { + for (param, (type_, space)) in self + .return_arguments + .iter_mut() + .zip(details.return_arguments.iter()) + { + visitor.visit_ident( + param, + Some((type_, *space)), + *space == StateSpace::Reg, + false, + )?; + } + visitor.visit_ident(&mut 
self.func, None, false, false)?; + for (param, (type_, space)) in self + .input_arguments + .iter_mut() + .zip(details.input_arguments.iter()) + { + visitor.visit(param, Some((type_, *space)), false, false)?; + } + Ok(()) + } + + #[allow(dead_code)] // Used by generated code + fn map( + self, + details: &CallDetails, + visitor: &mut impl VisitorMap, + ) -> Result, Err> { + let return_arguments = self + .return_arguments + .into_iter() + .zip(details.return_arguments.iter()) + .map(|(param, (type_, space))| { + visitor.visit_ident( + param, + Some((type_, *space)), + *space == StateSpace::Reg, + false, + ) + }) + .collect::, _>>()?; + let func = visitor.visit_ident(self.func, None, false, false)?; + let input_arguments = self + .input_arguments + .into_iter() + .zip(details.input_arguments.iter()) + .map(|(param, (type_, space))| { + visitor.visit(param, Some((type_, *space)), false, false) + }) + .collect::, _>>()?; + Ok(CallArgs { + return_arguments, + func, + input_arguments, + }) + } +} + +pub struct CvtDetails { + pub from: ScalarType, + pub to: ScalarType, + pub mode: CvtMode, +} + +#[derive(Clone, Copy)] +pub enum CvtMode { + // int from int + ZeroExtend, + SignExtend, + Truncate, + Bitcast, + IntSaturateToSigned, + IntSaturateToUnsigned, + // float from float + FPExtend { + flush_to_zero: Option, + saturate: bool, + }, + FPTruncate { + // float rounding + rounding: RoundingMode, + is_integer_rounding: bool, + flush_to_zero: Option, + saturate: bool, + }, + FPRound { + integer_rounding: Option, + flush_to_zero: Option, + saturate: bool, + }, + // int from float + SignedFromFP { + rounding: RoundingMode, + flush_to_zero: Option, + }, // integer rounding + UnsignedFromFP { + rounding: RoundingMode, + flush_to_zero: Option, + }, // integer rounding + // float from int, ftz is allowed in the grammar, but clearly nonsensical + FPFromSigned { + rounding: RoundingMode, + saturate: bool, + }, // float rounding + FPFromUnsigned { + rounding: RoundingMode, + saturate: bool, + }, // float rounding +} + +impl CvtDetails { + pub(crate) fn new( + errors: &mut Vec, + rnd: Option, + ftz: bool, + saturate: bool, + dst: ScalarType, + src: ScalarType, + ) -> Self { + // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results. 
+ let flush_to_zero = match (dst, src) { + (ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz), + _ => { + if ftz { + errors.push(PtxError::NonF32Ftz); + } + None + } + }; + let rounding = rnd.map(RawRoundingMode::normalize); + let mut unwrap_rounding = || match rounding { + Some((rnd, is_integer)) => (rnd, is_integer), + None => { + errors.push(PtxError::SyntaxError); + (RoundingMode::NearestEven, false) + } + }; + let mode = match (dst.kind(), src.kind()) { + (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) { + Ordering::Less => { + let (rounding, is_integer_rounding) = unwrap_rounding(); + CvtMode::FPTruncate { + rounding, + is_integer_rounding, + flush_to_zero, + saturate, + } + } + Ordering::Equal => CvtMode::FPRound { + integer_rounding: rounding.map(|(rnd, _)| rnd), + flush_to_zero, + saturate, + }, + Ordering::Greater => { + if rounding.is_some() { + errors.push(PtxError::SyntaxError); + } + CvtMode::FPExtend { + flush_to_zero, + saturate, + } + } + }, + (ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP { + rounding: unwrap_rounding().0, + flush_to_zero, + }, + (ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP { + rounding: unwrap_rounding().0, + flush_to_zero, + }, + (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned { + rounding: unwrap_rounding().0, + saturate, + }, + (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned { + rounding: unwrap_rounding().0, + saturate, + }, + (ScalarKind::Signed, ScalarKind::Unsigned) + | (ScalarKind::Signed, ScalarKind::Signed) + if saturate => + { + CvtMode::IntSaturateToSigned + } + (ScalarKind::Unsigned, ScalarKind::Signed) + | (ScalarKind::Unsigned, ScalarKind::Unsigned) + if saturate => + { + CvtMode::IntSaturateToUnsigned + } + (ScalarKind::Unsigned, ScalarKind::Unsigned) + | (ScalarKind::Signed, ScalarKind::Signed) + | (ScalarKind::Unsigned, ScalarKind::Signed) + | (ScalarKind::Signed, ScalarKind::Unsigned) + if dst.size_of() == src.size_of() => + { + CvtMode::Bitcast + } + (ScalarKind::Unsigned, ScalarKind::Unsigned) + | (ScalarKind::Signed, ScalarKind::Signed) + | (ScalarKind::Unsigned, ScalarKind::Signed) + | (ScalarKind::Signed, ScalarKind::Unsigned) => match dst.size_of().cmp(&src.size_of()) + { + Ordering::Less => CvtMode::Truncate, + Ordering::Equal => CvtMode::Bitcast, + Ordering::Greater => { + if src.kind() == ScalarKind::Signed { + CvtMode::SignExtend + } else { + CvtMode::ZeroExtend + } + } + }, + (_, _) => { + errors.push(PtxError::SyntaxError); + CvtMode::Bitcast + } + }; + CvtDetails { + mode, + to: dst, + from: src, + } + } +} + +pub struct CvtIntToIntDesc { + pub dst: ScalarType, + pub src: ScalarType, + pub saturate: bool, +} + +pub struct CvtDesc { + pub rounding: Option, + pub flush_to_zero: Option, + pub saturate: bool, + pub dst: ScalarType, + pub src: ScalarType, +} + +pub struct ShrData { + pub type_: ScalarType, + pub kind: RightShiftKind, +} + +pub enum RightShiftKind { + Arithmetic, + Logical, +} + +pub struct CvtaDetails { + pub state_space: StateSpace, + pub direction: CvtaDirection, +} + +pub enum CvtaDirection { + GenericToExplicit, + ExplicitToGeneric, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct TypeFtz { + pub flush_to_zero: Option, + pub type_: ScalarType, +} + +#[derive(Copy, Clone)] +pub enum MadDetails { + Integer { + control: MulIntControl, + saturate: bool, + type_: ScalarType, + }, + Float(ArithFloat), +} + +impl MadDetails { + pub fn dst_type(&self) -> ScalarType { + match self { + 
MadDetails::Integer { + type_, + control: MulIntControl::Wide, + .. + } => match type_ { + ScalarType::U16 => ScalarType::U32, + ScalarType::S16 => ScalarType::S32, + ScalarType::U32 => ScalarType::U64, + ScalarType::S32 => ScalarType::S64, + _ => unreachable!(), + }, + _ => self.type_(), + } + } + + fn type_(&self) -> ScalarType { + match self { + MadDetails::Integer { type_, .. } => *type_, + MadDetails::Float(arith) => arith.type_, + } + } +} + +#[derive(Copy, Clone)] +pub enum MinMaxDetails { + Signed(ScalarType), + Unsigned(ScalarType), + Float(MinMaxFloat), +} + +impl MinMaxDetails { + pub fn type_(&self) -> ScalarType { + match self { + MinMaxDetails::Signed(t) => *t, + MinMaxDetails::Unsigned(t) => *t, + MinMaxDetails::Float(float) => float.type_, + } + } +} + +#[derive(Copy, Clone)] +pub struct MinMaxFloat { + pub flush_to_zero: Option, + pub nan: bool, + pub type_: ScalarType, +} + +#[derive(Copy, Clone)] +pub struct RcpData { + pub kind: RcpKind, + pub flush_to_zero: Option, + pub type_: ScalarType, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum RcpKind { + Approx, + Compliant(RoundingMode), +} + +pub struct BarData { + pub aligned: bool, +} + +#[derive(Copy, Clone)] +pub struct BarRedData { + pub aligned: bool, + pub pred_reduction: Reduction, +} + +pub struct AtomDetails { + pub type_: Type, + pub semantics: AtomSemantics, + pub scope: MemScope, + pub space: StateSpace, + pub op: AtomicOp, +} + +#[derive(Copy, Clone)] +pub enum AtomicOp { + And, + Or, + Xor, + Exchange, + Add, + IncrementWrap, + DecrementWrap, + SignedMin, + UnsignedMin, + SignedMax, + UnsignedMax, + FloatAdd, + FloatMin, + FloatMax, +} + +impl AtomicOp { + pub(crate) fn new(op: super::RawAtomicOp, kind: ScalarKind) -> Self { + use super::RawAtomicOp; + match (op, kind) { + (RawAtomicOp::And, _) => Self::And, + (RawAtomicOp::Or, _) => Self::Or, + (RawAtomicOp::Xor, _) => Self::Xor, + (RawAtomicOp::Exch, _) => Self::Exchange, + (RawAtomicOp::Add, ScalarKind::Float) => Self::FloatAdd, + (RawAtomicOp::Add, _) => Self::Add, + (RawAtomicOp::Inc, _) => Self::IncrementWrap, + (RawAtomicOp::Dec, _) => Self::DecrementWrap, + (RawAtomicOp::Min, ScalarKind::Signed) => Self::SignedMin, + (RawAtomicOp::Min, ScalarKind::Float) => Self::FloatMin, + (RawAtomicOp::Min, _) => Self::UnsignedMin, + (RawAtomicOp::Max, ScalarKind::Signed) => Self::SignedMax, + (RawAtomicOp::Max, ScalarKind::Float) => Self::FloatMax, + (RawAtomicOp::Max, _) => Self::UnsignedMax, + } + } +} + +pub struct AtomCasDetails { + pub type_: ScalarType, + pub semantics: AtomSemantics, + pub scope: MemScope, + pub space: StateSpace, +} + +#[derive(Copy, Clone)] +pub enum DivDetails { + Unsigned(ScalarType), + Signed(ScalarType), + Float(DivFloatDetails), +} + +impl DivDetails { + pub fn type_(&self) -> ScalarType { + match self { + DivDetails::Unsigned(t) => *t, + DivDetails::Signed(t) => *t, + DivDetails::Float(float) => float.type_, + } + } +} + +#[derive(Copy, Clone)] +pub struct DivFloatDetails { + pub type_: ScalarType, + pub flush_to_zero: Option, + pub kind: DivFloatKind, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum DivFloatKind { + Approx, + ApproxFull, + Rounding(RoundingMode), +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct FlushToZero { + pub flush_to_zero: bool, +} diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 76887e5..4b79d47 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -3581,7 +3581,7 @@ derive_parser!( state.errors.push(PtxError::SyntaxError); CpAsyncCpSize::Bytes4 }); - + let 
src_size = src_size .and_then(|op| op.as_immediate()) .and_then(|imm| imm.as_u64()); diff --git a/ptx_parser_macros/src/lib.rs b/ptx_parser_macros/src/lib.rs index faebd21..51e17ad 100644 --- a/ptx_parser_macros/src/lib.rs +++ b/ptx_parser_macros/src/lib.rs @@ -1,1028 +1,1036 @@ -use either::Either; -use ptx_parser_macros_impl::parser; -use proc_macro2::{Span, TokenStream}; -use quote::{format_ident, quote, ToTokens}; -use rustc_hash::{FxHashMap, FxHashSet}; -use std::{collections::hash_map, hash::Hash, iter, rc::Rc}; -use syn::{ - parse_macro_input, parse_quote, punctuated::Punctuated, Ident, ItemEnum, Token, Type, TypePath, - Variant, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#alternate-floating-point-data-formats -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-floating-point-data-types -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-integer-data-types -#[rustfmt::skip] -static POSTFIX_MODIFIERS: &[&str] = &[ - ".v2", ".v4", ".v8", - ".s8", ".s16", ".s16x2", ".s32", ".s64", - ".u8", ".u16", ".u16x2", ".u32", ".u64", - ".f16", ".f16x2", ".f32", ".f64", - ".b8", ".b16", ".b32", ".b64", ".b128", - ".pred", - ".bf16", ".bf16x2", ".e4m3", ".e5m2", ".tf32", -]; - -static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"]; - -struct OpcodeDefinitions { - definitions: Vec, - block_selection: Vec>, usize)>>, -} - -impl OpcodeDefinitions { - fn new(opcode: &Ident, definitions: Vec) -> Self { - let mut selections = vec![None; definitions.len()]; - let mut generation = 0usize; - loop { - let mut selected_something = false; - let unselected = selections - .iter() - .enumerate() - .filter_map(|(idx, s)| if s.is_none() { Some(idx) } else { None }) - .collect::>(); - match &*unselected { - [] => break, - [remaining] => { - selections[*remaining] = Some((None, generation)); - break; - } - _ => {} - } - 'check_definitions: for i in unselected.iter().copied() { - let mut candidates = definitions[i] - .unordered_modifiers - .iter() - .chain(definitions[i].ordered_modifiers.iter()) - .filter(|modifier| match modifier { - DotModifierRef::Direct { - optional: false, .. - } - | DotModifierRef::Indirect { - optional: false, .. - } => true, - _ => false, - }) - .collect::>(); - candidates.sort_by_key(|modifier| match modifier { - DotModifierRef::Direct { .. } => 1, - DotModifierRef::Indirect { value, .. } => value.alternatives.len(), - }); - // Attempt every modifier - 'check_candidates: for candidate_modifier in candidates { - // check all other unselected patterns - for j in unselected.iter().copied() { - if i == j { - continue; - } - let candidate_set = match candidate_modifier { - DotModifierRef::Direct { value, .. } => Either::Left(iter::once(value)), - DotModifierRef::Indirect { value, .. } => { - Either::Right(value.alternatives.iter()) - } - }; - for candidate_value in candidate_set { - if definitions[j].possible_modifiers.contains(candidate_value) { - continue 'check_candidates; - } - } - } - // it's unique - let candidate_vec = match candidate_modifier { - DotModifierRef::Direct { value, .. } => vec![value.clone()], - DotModifierRef::Indirect { value, .. 
} => { - value.alternatives.iter().cloned().collect::>() - } - }; - selections[i] = Some((Some(candidate_vec), generation)); - selected_something = true; - continue 'check_definitions; - } - } - if !selected_something { - panic!( - "Failed to generate pattern selection for `{}`. State: {:?}", - opcode, - selections.into_iter().rev().collect::>() - ); - } - generation += 1; - } - let mut block_selection = Vec::new(); - for current_generation in 0usize.. { - let mut current_generation_definitions = Vec::new(); - for (idx, selection) in selections.iter_mut().enumerate() { - match selection { - Some((modifier_set, generation)) => { - if *generation == current_generation { - current_generation_definitions.push((modifier_set.clone(), idx)); - *selection = None; - } - } - None => {} - } - } - if current_generation_definitions.is_empty() { - break; - } - block_selection.push(current_generation_definitions); - } - #[cfg(debug_assertions)] - { - let selected = block_selection - .iter() - .map(|x| x.len()) - .reduce(|x, y| x + y) - .unwrap(); - if selected != definitions.len() { - panic!( - "Internal error when generating pattern selection for `{}`: {:?}", - opcode, &block_selection - ); - } - } - Self { - definitions, - block_selection, - } - } - - fn get_enum_types( - parse_definitions: &[parser::OpcodeDefinition], - ) -> FxHashMap> { - let mut result = FxHashMap::default(); - for parser::OpcodeDefinition(_, rules) in parse_definitions.iter() { - for rule in rules { - let type_ = match rule.type_ { - Some(ref type_) => type_.clone(), - None => continue, - }; - let insert_values = |set: &mut FxHashSet<_>| { - for value in rule.alternatives.iter().cloned() { - set.insert(value); - } - }; - match result.entry(type_) { - hash_map::Entry::Occupied(mut entry) => insert_values(entry.get_mut()), - hash_map::Entry::Vacant(entry) => { - insert_values(entry.insert(FxHashSet::default())) - } - }; - } - } - result - } -} - -struct SingleOpcodeDefinition { - possible_modifiers: FxHashSet, - unordered_modifiers: Vec, - ordered_modifiers: Vec, - arguments: parser::Arguments, - code_block: parser::CodeBlock, -} - -impl SingleOpcodeDefinition { - fn function_arguments_declarations(&self) -> impl Iterator + '_ { - self.unordered_modifiers - .iter() - .chain(self.ordered_modifiers.iter()) - .filter_map(|modf| { - let type_ = modf.type_of(); - type_.map(|t| { - let name = modf.ident(); - quote! { #name : #t } - }) - }) - .chain(self.arguments.0.iter().map(|arg| { - let name = &arg.ident.ident(); - let arg_type = if arg.unified { - quote! { (ParsedOperandStr<'input>, bool) } - } else if arg.can_be_negated { - quote! { (bool, ParsedOperandStr<'input>) } - } else { - quote! { ParsedOperandStr<'input> } - }; - if arg.optional { - quote! { #name : Option<#arg_type> } - } else { - quote! { #name : #arg_type } - } - })) - } - - fn function_arguments(&self) -> impl Iterator + '_ { - self.unordered_modifiers - .iter() - .chain(self.ordered_modifiers.iter()) - .filter_map(|modf| { - let type_ = modf.type_of(); - type_.map(|_| { - let name = modf.ident(); - quote! { #name } - }) - }) - .chain(self.arguments.0.iter().map(|arg| { - let name = &arg.ident.ident(); - quote! 
{ #name } - })) - } - - fn extract_and_insert( - definitions: &mut FxHashMap>, - special_definitions: &mut FxHashMap, - parser::OpcodeDefinition(pattern_seq, rules): parser::OpcodeDefinition, - ) { - let (mut named_rules, mut unnamed_rules) = gather_rules(rules); - let mut last_opcode = pattern_seq.0.last().unwrap().0 .0.name.clone(); - for (opcode_decl, code_block) in pattern_seq.0.into_iter().rev() { - let current_opcode = opcode_decl.0.name.clone(); - if last_opcode != current_opcode { - named_rules = FxHashMap::default(); - unnamed_rules = FxHashMap::default(); - } - let parser::OpcodeDecl(instruction, arguments) = opcode_decl; - if code_block.special { - if !instruction.modifiers.is_empty() || !arguments.0.is_empty() { - panic!( - "`{}`: no modifiers or arguments are allowed in parser definition.", - instruction.name - ); - } - special_definitions.insert(instruction.name, code_block.code); - continue; - } - let mut possible_modifiers = FxHashSet::default(); - let mut unordered_modifiers = instruction - .modifiers - .into_iter() - .map(|parser::MaybeDotModifier { optional, modifier }| { - match named_rules.get(&modifier) { - Some(alts) => { - possible_modifiers.extend(alts.alternatives.iter().cloned()); - if alts.alternatives.len() == 1 && alts.type_.is_none() { - DotModifierRef::Direct { - optional, - value: alts.alternatives[0].clone(), - name: modifier, - type_: alts.type_.clone(), - } - } else { - DotModifierRef::Indirect { - optional, - value: alts.clone(), - name: modifier, - } - } - } - None => { - let type_ = unnamed_rules.get(&modifier).cloned(); - possible_modifiers.insert(modifier.clone()); - DotModifierRef::Direct { - optional, - value: modifier.clone(), - name: modifier, - type_, - } - } - } - }) - .collect::>(); - let ordered_modifiers = Self::extract_ordered_modifiers(&mut unordered_modifiers); - let entry = Self { - possible_modifiers, - unordered_modifiers, - ordered_modifiers, - arguments, - code_block, - }; - multihash_extend(definitions, current_opcode.clone(), entry); - last_opcode = current_opcode; - } - } - - fn extract_ordered_modifiers( - unordered_modifiers: &mut Vec, - ) -> Vec { - let mut result = Vec::new(); - loop { - let is_ordered = match unordered_modifiers.last() { - Some(DotModifierRef::Direct { value, .. }) => { - let name = value.to_string(); - POSTFIX_MODIFIERS.contains(&&*name) - } - Some(DotModifierRef::Indirect { value, .. 
}) => { - let type_ = value.type_.to_token_stream().to_string(); - //panic!("{} {}", type_, POSTFIX_TYPES.contains(&&*type_)); - POSTFIX_TYPES.contains(&&*type_) - } - None => break, - }; - if is_ordered { - result.push(unordered_modifiers.pop().unwrap()); - } else { - break; - } - } - if unordered_modifiers.len() == 1 { - result.push(unordered_modifiers.pop().unwrap()); - } - result.reverse(); - result - } -} - -fn gather_rules( - rules: Vec, -) -> ( - FxHashMap>, - FxHashMap, -) { - let mut named = FxHashMap::default(); - let mut unnamed = FxHashMap::default(); - for rule in rules { - match rule.modifier { - Some(ref modifier) => { - named.insert(modifier.clone(), Rc::new(rule)); - } - None => unnamed.extend( - rule.alternatives - .into_iter() - .map(|alt| (alt, rule.type_.as_ref().unwrap().clone())), - ), - } - } - (named, unnamed) -} - -#[proc_macro] -pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { - let parse_definitions = parse_macro_input!(tokens as ptx_parser_macros_impl::parser::ParseDefinitions); - let mut definitions = FxHashMap::default(); - let mut special_definitions = FxHashMap::default(); - let types = OpcodeDefinitions::get_enum_types(&parse_definitions.definitions); - let enum_types_tokens = emit_enum_types(types, parse_definitions.additional_enums); - for definition in parse_definitions.definitions.into_iter() { - SingleOpcodeDefinition::extract_and_insert( - &mut definitions, - &mut special_definitions, - definition, - ); - } - let definitions = definitions - .into_iter() - .map(|(k, v)| { - let v = OpcodeDefinitions::new(&k, v); - (k, v) - }) - .collect::>(); - let mut token_enum = parse_definitions.token_type; - let (all_opcode, all_modifier) = write_definitions_into_tokens( - &definitions, - special_definitions.keys(), - &mut token_enum.variants, - ); - let token_impl = emit_parse_function(&token_enum.ident, &definitions, &special_definitions, all_opcode, all_modifier); - let tokens = quote! { - #enum_types_tokens - - #token_enum - - #token_impl - }; - tokens.into() -} - -fn emit_enum_types( - types: FxHashMap>, - mut existing_enums: FxHashMap, -) -> TokenStream { - let token_types = types.into_iter().filter_map(|(type_, variants)| { - match type_ { - syn::Type::Path(TypePath { - qself: None, - ref path, - }) => { - if let Some(ident) = path.get_ident() { - if let Some(enum_) = existing_enums.get_mut(ident) { - enum_.variants.extend(variants.into_iter().map(|modifier| { - let ident = modifier.variant_capitalized(); - let variant: syn::Variant = syn::parse_quote! { - #ident - }; - variant - })); - return None; - } - } - } - _ => {} - } - let variants = variants.iter().map(|v| v.variant_capitalized()); - Some(quote! { - #[derive(Copy, Clone, PartialEq, Eq, Hash)] - enum #type_ { - #(#variants),* - } - }) - }); - let mut result = TokenStream::new(); - for tokens in token_types { - tokens.to_tokens(&mut result); - } - for (_, enum_) in existing_enums { - quote! 
{ #enum_ }.to_tokens(&mut result); - } - result -} - -fn emit_parse_function( - type_name: &Ident, - defs: &FxHashMap, - special_defs: &FxHashMap, - all_opcode: Vec<&Ident>, - all_modifier: FxHashSet<&parser::DotModifier>, -) -> TokenStream { - use std::fmt::Write; - let fns_ = defs - .iter() - .map(|(opcode, defs)| { - defs.definitions.iter().enumerate().map(|(idx, def)| { - let mut fn_name = opcode.to_string(); - write!(&mut fn_name, "_{}", idx).ok(); - let fn_name = Ident::new(&fn_name, Span::call_site()); - let code_block = &def.code_block.code; - let args = def.function_arguments_declarations(); - quote! { - fn #fn_name<'input>(state: &mut PtxParserState, #(#args),* ) -> Instruction> #code_block - } - }) - }) - .flatten(); - let selectors = defs.iter().map(|(opcode, def)| { - let opcode_variant = Ident::new(&capitalize(&opcode.to_string()), opcode.span()); - let mut result = TokenStream::new(); - let mut selectors = TokenStream::new(); - quote! { - if false { - unsafe { std::hint::unreachable_unchecked() } - } - } - .to_tokens(&mut selectors); - let mut has_default_selector = false; - for selection_layer in def.block_selection.iter() { - for (selection_key, selected_definition) in selection_layer { - let def_parser = emit_definition_parser(type_name, (opcode,*selected_definition), &def.definitions[*selected_definition]); - match selection_key { - Some(selection_keys) => { - let selection_keys = selection_keys.iter().map(|k| k.dot_capitalized()); - quote! { - else if false #(|| modifiers.iter().any(|(t, _)| *t == #type_name :: #selection_keys))* { - #def_parser - } - } - .to_tokens(&mut selectors); - } - None => { - has_default_selector = true; - quote! { - else { - #def_parser - } - } - .to_tokens(&mut selectors); - } - } - } - } - if !has_default_selector { - quote! { - else { - return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) - } - } - .to_tokens(&mut selectors); - } - quote! { - #opcode_variant => { - let modifers_start = stream.checkpoint(); - let modifiers = take_while(0.., |(t,_)| Token::modifier(t)).parse_next(stream)?; - #selectors - } - } - .to_tokens(&mut result); - result - }).chain(special_defs.iter().map(|(opcode, code)| { - let opcode_variant = Ident::new(&capitalize(&opcode.to_string()), opcode.span()); - quote! { - #opcode_variant => { #code? } - } - })); - let opcodes = all_opcode.into_iter().map(|op_ident| { - let op = op_ident.to_string(); - let variant = Ident::new(&capitalize(&op), op_ident.span()); - let value = op; - quote! { - #type_name :: #variant => Some(#value), - } - }); - let modifier_names = iter::once(Ident::new("DotUnified", Span::call_site())) - .chain(all_modifier.iter().map(|m| m.dot_capitalized())); - quote! 
{ - impl<'input> #type_name<'input> { - fn opcode_text(self) -> Option<&'static str> { - match self { - #(#opcodes)* - _ => None - } - } - - fn modifier(self) -> bool { - match self { - #( - #type_name :: #modifier_names => true, - )* - _ => false - } - } - } - - #(#fns_)* - - fn parse_instruction<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> winnow::error::PResult>> - { - use winnow::Parser; - use winnow::token::*; - use winnow::combinator::*; - let opcode = any.parse_next(stream)?.0; - let modifiers_start = stream.checkpoint(); - Ok(match opcode { - #( - #type_name :: #selectors - )* - _ => return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) - }) - } - } -} - -fn emit_definition_parser( - token_type: &Ident, - (opcode, fn_idx): (&Ident, usize), - definition: &SingleOpcodeDefinition, -) -> TokenStream { - let return_error_ref = quote! { - return Err(winnow::error::ErrMode::from_error_kind(&stream, winnow::error::ErrorKind::Token)) - }; - let return_error = quote! { - return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) - }; - let ordered_parse_declarations = definition.ordered_modifiers.iter().map(|modifier| { - modifier.type_of().map(|type_| { - let name = modifier.ident(); - quote! { - let #name : #type_; - } - }) - }); - let ordered_parse = definition.ordered_modifiers.iter().rev().map(|modifier| { - let arg_name = modifier.ident(); - match modifier { - DotModifierRef::Direct { optional, value, type_: None, .. } => { - let variant = value.dot_capitalized(); - if *optional { - quote! { - #arg_name = opt(any.verify(|(t, _)| *t == #token_type :: #variant)).parse_next(&mut stream)?.is_some(); - } - } else { - quote! { - any.verify(|(t, _)| *t == #token_type :: #variant).parse_next(&mut stream)?; - } - } - } - DotModifierRef::Direct { optional: false, type_: Some(type_), name, value } => { - let variable = name.ident(); - let variant = value.dot_capitalized(); - let parsed_variant = value.variant_capitalized(); - quote! { - any.verify(|(t, _)| *t == #token_type :: #variant).parse_next(&mut stream)?; - #variable = #type_ :: #parsed_variant; - } - } - DotModifierRef::Direct { optional: true, type_: Some(_), .. } => { todo!() } - DotModifierRef::Indirect { optional, value, .. } => { - let variants = value.alternatives.iter().map(|alt| { - let type_ = value.type_.as_ref().unwrap(); - let token_variant = alt.dot_capitalized(); - let parsed_variant = alt.variant_capitalized(); - quote! { - #token_type :: #token_variant => #type_ :: #parsed_variant, - } - }); - if *optional { - quote! { - #arg_name = opt(any.verify_map(|(tok, _)| { - Some(match tok { - #(#variants)* - _ => return None - }) - })).parse_next(&mut stream)?; - } - } else { - quote! { - #arg_name = any.verify_map(|(tok, _)| { - Some(match tok { - #(#variants)* - _ => return None - }) - }).parse_next(&mut stream)?; - } - } - } - } - }); - let unordered_parse_declarations = definition.unordered_modifiers.iter().map(|modifier| { - let name = modifier.ident(); - let type_ = modifier.type_of_check(); - quote! { - let mut #name : #type_ = std::default::Default::default(); - } - }); - let unordered_parse = definition - .unordered_modifiers - .iter() - .map(|modifier| match modifier { - DotModifierRef::Direct { - name, - value, - type_: None, - .. - } => { - let name = name.ident(); - let token_variant = value.dot_capitalized(); - quote! 
{ - #token_type :: #token_variant => { - if #name { - #return_error_ref; - } - #name = true; - } - } - } - DotModifierRef::Direct { - name, - value, - type_: Some(type_), - .. - } => { - let variable = name.ident(); - let token_variant = value.dot_capitalized(); - let enum_variant = value.variant_capitalized(); - quote! { - #token_type :: #token_variant => { - if #variable.is_some() { - #return_error_ref; - } - #variable = Some(#type_ :: #enum_variant); - } - } - } - DotModifierRef::Indirect { value, name, .. } => { - let variable = name.ident(); - let type_ = value.type_.as_ref().unwrap(); - let alternatives = value.alternatives.iter().map(|alt| { - let token_variant = alt.dot_capitalized(); - let enum_variant = alt.variant_capitalized(); - quote! { - #token_type :: #token_variant => { - if #variable.is_some() { - #return_error_ref; - } - #variable = Some(#type_ :: #enum_variant); - } - } - }); - quote! { - #(#alternatives)* - } - } - }); - let unordered_parse_validations = - definition - .unordered_modifiers - .iter() - .map(|modifier| match modifier { - DotModifierRef::Direct { - optional: false, - name, - type_: None, - .. - } => { - let variable = name.ident(); - quote! { - if !#variable { - #return_error; - } - } - } - DotModifierRef::Direct { - optional: false, - name, - type_: Some(_), - .. - } => { - let variable = name.ident(); - quote! { - let #variable = match #variable { - Some(x) => x, - None => #return_error - }; - } - } - DotModifierRef::Indirect { - optional: false, - name, - .. - } => { - let variable = name.ident(); - quote! { - let #variable = match #variable { - Some(x) => x, - None => #return_error - }; - } - } - DotModifierRef::Direct { optional: true, .. } - | DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(), - }); - let (arguments_pattern, arguments_parser) = definition.arguments.0.iter().enumerate().rfold((quote! { () }, quote! { empty }), |(emitted_pattern, emitted_parser), (idx, arg)| { - let comma = if idx == 0 || arg.pre_pipe { - quote! { empty } - } else { - quote! { any.verify(|(t, _)| *t == #token_type::Comma).void() } - }; - - let pre_bracket = if arg.pre_bracket { - quote! { - any.verify(|(t, _)| *t == #token_type::LBracket).void() - } - } else { - quote! { - empty - } - }; - let pre_pipe = if arg.pre_pipe { - quote! { - any.verify(|(t, _)| *t == #token_type::Pipe).void() - } - } else { - quote! { - empty - } - }; - let can_be_negated = if arg.can_be_negated { - quote! { - opt(any.verify(|(t, _)| *t == #token_type::Exclamation)).map(|o| o.is_some()) - } - } else { - quote! { - empty - } - }; - let operand = { - quote! { - ParsedOperandStr::parse - } - }; - let post_bracket = if arg.post_bracket { - quote! { - any.verify(|(t, _)| *t == #token_type::RBracket).void() - } - } else { - quote! { - empty - } - }; - let unified = if arg.unified { - quote! { - opt(any.verify(|(t, _)| *t == #token_type::DotUnified).void()).map(|u| u.is_some()) - } - } else { - quote! { - empty - } - }; - let pattern = quote! { - (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket, #unified) - }; - let arg_name = &arg.ident.ident(); - if arg.unified && arg.can_be_negated { - panic!("TODO: argument can't be both prefixed by `!` and suffixed by `.unified`") - } - let inner_parser = if arg.unified { - quote! { - #pattern.map(|(_, _, _, _, name, _, unified)| (name, unified)) - } - } else if arg.can_be_negated { - quote! { - #pattern.map(|(_, _, _, negated, name, _, _)| (negated, name)) - } - } else { - quote! 
{ - #pattern.map(|(_, _, _, _, name, _, _)| name) - } - }; - - let parser = if arg.optional { - quote! { first_optional(#inner_parser, #emitted_parser) } - } else { - quote! { (#inner_parser, #emitted_parser) } - }; - - let pattern = quote! { ( #arg_name, #emitted_pattern ) }; - - (pattern, parser) - }); - - let arguments_parse = quote! { let #arguments_pattern = ( #arguments_parser ).parse_next(stream)?; }; - - let fn_args = definition.function_arguments(); - let fn_name = format_ident!("{}_{}", opcode, fn_idx); - let fn_call = quote! { - #fn_name(&mut stream.state, #(#fn_args),* ) - }; - quote! { - #(#unordered_parse_declarations)* - #(#ordered_parse_declarations)* - { - let mut stream = ReverseStream(modifiers); - #(#ordered_parse)* - let mut stream: &[_] = stream.0; - for (token, _) in stream.iter().cloned() { - match token { - #(#unordered_parse)* - _ => #return_error_ref - } - } - } - #(#unordered_parse_validations)* - #arguments_parse - #fn_call - } -} - -fn write_definitions_into_tokens<'a>( - defs: &'a FxHashMap, - special_definitions: impl Iterator, - variants: &mut Punctuated, -) -> (Vec<&'a Ident>, FxHashSet<&'a parser::DotModifier>) { - let mut all_opcodes = Vec::new(); - let mut all_modifiers = FxHashSet::default(); - for (opcode, definitions) in defs.iter() { - all_opcodes.push(opcode); - let opcode_as_string = opcode.to_string(); - let variant_name = Ident::new(&capitalize(&opcode_as_string), opcode.span()); - let arg: Variant = syn::parse_quote! { - #[token(#opcode_as_string)] - #variant_name - }; - variants.push(arg); - for definition in definitions.definitions.iter() { - for modifier in definition.possible_modifiers.iter() { - all_modifiers.insert(modifier); - } - } - } - for opcode in special_definitions { - all_opcodes.push(opcode); - let opcode_as_string = opcode.to_string(); - let variant_name = Ident::new(&capitalize(&opcode_as_string), opcode.span()); - let arg: Variant = syn::parse_quote! { - #[token(#opcode_as_string)] - #variant_name - }; - variants.push(arg); - } - for modifier in all_modifiers.iter() { - let modifier_as_string = modifier.to_string(); - let variant_name = modifier.dot_capitalized(); - let arg: Variant = syn::parse_quote! { - #[token(#modifier_as_string)] - #variant_name - }; - variants.push(arg); - } - variants.push(parse_quote! { - #[token(".unified")] - DotUnified - }); - (all_opcodes, all_modifiers) -} - -fn capitalize(s: &str) -> String { - let mut c = s.chars(); - match c.next() { - None => String::new(), - Some(f) => f.to_uppercase().collect::() + c.as_str(), - } -} - -fn multihash_extend(multimap: &mut FxHashMap>, k: K, v: V) { - match multimap.entry(k) { - hash_map::Entry::Occupied(mut entry) => entry.get_mut().push(v), - hash_map::Entry::Vacant(entry) => { - entry.insert(vec![v]); - } - } -} - -enum DotModifierRef { - Direct { - optional: bool, - value: parser::DotModifier, - name: parser::DotModifier, - type_: Option, - }, - Indirect { - optional: bool, - name: parser::DotModifier, - value: Rc, - }, -} - -impl DotModifierRef { - fn ident(&self) -> Ident { - match self { - DotModifierRef::Direct { name, .. } => name.ident(), - DotModifierRef::Indirect { name, .. } => name.ident(), - } - } - - fn type_of(&self) -> Option { - Some(match self { - DotModifierRef::Direct { - optional: true, - type_: None, - .. - } => syn::parse_quote! { bool }, - DotModifierRef::Direct { - optional: false, - type_: None, - .. - } => return None, - DotModifierRef::Direct { - optional: true, - type_: Some(type_), - .. - } => syn::parse_quote! 
{ Option<#type_> }, - DotModifierRef::Direct { - optional: false, - type_: Some(type_), - .. - } => type_.clone(), - DotModifierRef::Indirect { - optional, value, .. - } => { - let type_ = value - .type_ - .as_ref() - .expect("Indirect modifer must have a type"); - if *optional { - syn::parse_quote! { Option<#type_> } - } else { - type_.clone() - } - } - }) - } - - fn type_of_check(&self) -> syn::Type { - match self { - DotModifierRef::Direct { type_: None, .. } => syn::parse_quote! { bool }, - DotModifierRef::Direct { - type_: Some(type_), .. - } => syn::parse_quote! { Option<#type_> }, - DotModifierRef::Indirect { value, .. } => { - let type_ = value - .type_ - .as_ref() - .expect("Indirect modifer must have a type"); - syn::parse_quote! { Option<#type_> } - } - } - } -} - -#[proc_macro] -pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { - let input = parse_macro_input!(tokens as ptx_parser_macros_impl::GenerateInstructionType); - let mut result = proc_macro2::TokenStream::new(); - input.emit_arg_types(&mut result); - input.emit_instruction_type(&mut result); - input.emit_visit(&mut result); - input.emit_visit_mut(&mut result); - input.emit_visit_map(&mut result); - result.into() -} +use either::Either; +use proc_macro2::{Span, TokenStream}; +use ptx_parser_macros_impl::parser; +use quote::{format_ident, quote, ToTokens}; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::{collections::hash_map, hash::Hash, iter, rc::Rc}; +use syn::{ + parse_macro_input, parse_quote, punctuated::Punctuated, Ident, ItemEnum, Token, Type, TypePath, + Variant, +}; + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#alternate-floating-point-data-formats +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-floating-point-data-types +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-integer-data-types +#[rustfmt::skip] +static POSTFIX_MODIFIERS: &[&str] = &[ + ".v2", ".v4", ".v8", + ".s8", ".s16", ".s16x2", ".s32", ".s64", + ".u8", ".u16", ".u16x2", ".u32", ".u64", + ".f16", ".f16x2", ".f32", ".f64", + ".b8", ".b16", ".b32", ".b64", ".b128", + ".pred", + ".bf16", ".bf16x2", ".e4m3", ".e5m2", ".tf32", +]; + +static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"]; + +struct OpcodeDefinitions { + definitions: Vec, + block_selection: Vec>, usize)>>, +} + +impl OpcodeDefinitions { + fn new(opcode: &Ident, definitions: Vec) -> Self { + let mut selections = vec![None; definitions.len()]; + let mut generation = 0usize; + loop { + let mut selected_something = false; + let unselected = selections + .iter() + .enumerate() + .filter_map(|(idx, s)| if s.is_none() { Some(idx) } else { None }) + .collect::>(); + match &*unselected { + [] => break, + [remaining] => { + selections[*remaining] = Some((None, generation)); + break; + } + _ => {} + } + 'check_definitions: for i in unselected.iter().copied() { + let mut candidates = definitions[i] + .unordered_modifiers + .iter() + .chain(definitions[i].ordered_modifiers.iter()) + .filter(|modifier| match modifier { + DotModifierRef::Direct { + optional: false, .. + } + | DotModifierRef::Indirect { + optional: false, .. + } => true, + _ => false, + }) + .collect::>(); + candidates.sort_by_key(|modifier| match modifier { + DotModifierRef::Direct { .. } => 1, + DotModifierRef::Indirect { value, .. 
} => value.alternatives.len(), + }); + // Attempt every modifier + 'check_candidates: for candidate_modifier in candidates { + // check all other unselected patterns + for j in unselected.iter().copied() { + if i == j { + continue; + } + let candidate_set = match candidate_modifier { + DotModifierRef::Direct { value, .. } => Either::Left(iter::once(value)), + DotModifierRef::Indirect { value, .. } => { + Either::Right(value.alternatives.iter()) + } + }; + for candidate_value in candidate_set { + if definitions[j].possible_modifiers.contains(candidate_value) { + continue 'check_candidates; + } + } + } + // it's unique + let candidate_vec = match candidate_modifier { + DotModifierRef::Direct { value, .. } => vec![value.clone()], + DotModifierRef::Indirect { value, .. } => { + value.alternatives.iter().cloned().collect::>() + } + }; + selections[i] = Some((Some(candidate_vec), generation)); + selected_something = true; + continue 'check_definitions; + } + } + if !selected_something { + panic!( + "Failed to generate pattern selection for `{}`. State: {:?}", + opcode, + selections.into_iter().rev().collect::>() + ); + } + generation += 1; + } + let mut block_selection = Vec::new(); + for current_generation in 0usize.. { + let mut current_generation_definitions = Vec::new(); + for (idx, selection) in selections.iter_mut().enumerate() { + match selection { + Some((modifier_set, generation)) => { + if *generation == current_generation { + current_generation_definitions.push((modifier_set.clone(), idx)); + *selection = None; + } + } + None => {} + } + } + if current_generation_definitions.is_empty() { + break; + } + block_selection.push(current_generation_definitions); + } + #[cfg(debug_assertions)] + { + let selected = block_selection + .iter() + .map(|x| x.len()) + .reduce(|x, y| x + y) + .unwrap(); + if selected != definitions.len() { + panic!( + "Internal error when generating pattern selection for `{}`: {:?}", + opcode, &block_selection + ); + } + } + Self { + definitions, + block_selection, + } + } + + fn get_enum_types( + parse_definitions: &[parser::OpcodeDefinition], + ) -> FxHashMap> { + let mut result = FxHashMap::default(); + for parser::OpcodeDefinition(_, rules) in parse_definitions.iter() { + for rule in rules { + let type_ = match rule.type_ { + Some(ref type_) => type_.clone(), + None => continue, + }; + let insert_values = |set: &mut FxHashSet<_>| { + for value in rule.alternatives.iter().cloned() { + set.insert(value); + } + }; + match result.entry(type_) { + hash_map::Entry::Occupied(mut entry) => insert_values(entry.get_mut()), + hash_map::Entry::Vacant(entry) => { + insert_values(entry.insert(FxHashSet::default())) + } + }; + } + } + result + } +} + +struct SingleOpcodeDefinition { + possible_modifiers: FxHashSet, + unordered_modifiers: Vec, + ordered_modifiers: Vec, + arguments: parser::Arguments, + code_block: parser::CodeBlock, +} + +impl SingleOpcodeDefinition { + fn function_arguments_declarations(&self) -> impl Iterator + '_ { + self.unordered_modifiers + .iter() + .chain(self.ordered_modifiers.iter()) + .filter_map(|modf| { + let type_ = modf.type_of(); + type_.map(|t| { + let name = modf.ident(); + quote! { #name : #t } + }) + }) + .chain(self.arguments.0.iter().map(|arg| { + let name = &arg.ident.ident(); + let arg_type = if arg.unified { + quote! { (ParsedOperandStr<'input>, bool) } + } else if arg.can_be_negated { + quote! { (bool, ParsedOperandStr<'input>) } + } else { + quote! { ParsedOperandStr<'input> } + }; + if arg.optional { + quote! 
{ #name : Option<#arg_type> } + } else { + quote! { #name : #arg_type } + } + })) + } + + fn function_arguments(&self) -> impl Iterator + '_ { + self.unordered_modifiers + .iter() + .chain(self.ordered_modifiers.iter()) + .filter_map(|modf| { + let type_ = modf.type_of(); + type_.map(|_| { + let name = modf.ident(); + quote! { #name } + }) + }) + .chain(self.arguments.0.iter().map(|arg| { + let name = &arg.ident.ident(); + quote! { #name } + })) + } + + fn extract_and_insert( + definitions: &mut FxHashMap>, + special_definitions: &mut FxHashMap, + parser::OpcodeDefinition(pattern_seq, rules): parser::OpcodeDefinition, + ) { + let (mut named_rules, mut unnamed_rules) = gather_rules(rules); + let mut last_opcode = pattern_seq.0.last().unwrap().0 .0.name.clone(); + for (opcode_decl, code_block) in pattern_seq.0.into_iter().rev() { + let current_opcode = opcode_decl.0.name.clone(); + if last_opcode != current_opcode { + named_rules = FxHashMap::default(); + unnamed_rules = FxHashMap::default(); + } + let parser::OpcodeDecl(instruction, arguments) = opcode_decl; + if code_block.special { + if !instruction.modifiers.is_empty() || !arguments.0.is_empty() { + panic!( + "`{}`: no modifiers or arguments are allowed in parser definition.", + instruction.name + ); + } + special_definitions.insert(instruction.name, code_block.code); + continue; + } + let mut possible_modifiers = FxHashSet::default(); + let mut unordered_modifiers = instruction + .modifiers + .into_iter() + .map(|parser::MaybeDotModifier { optional, modifier }| { + match named_rules.get(&modifier) { + Some(alts) => { + possible_modifiers.extend(alts.alternatives.iter().cloned()); + if alts.alternatives.len() == 1 && alts.type_.is_none() { + DotModifierRef::Direct { + optional, + value: alts.alternatives[0].clone(), + name: modifier, + type_: alts.type_.clone(), + } + } else { + DotModifierRef::Indirect { + optional, + value: alts.clone(), + name: modifier, + } + } + } + None => { + let type_ = unnamed_rules.get(&modifier).cloned(); + possible_modifiers.insert(modifier.clone()); + DotModifierRef::Direct { + optional, + value: modifier.clone(), + name: modifier, + type_, + } + } + } + }) + .collect::>(); + let ordered_modifiers = Self::extract_ordered_modifiers(&mut unordered_modifiers); + let entry = Self { + possible_modifiers, + unordered_modifiers, + ordered_modifiers, + arguments, + code_block, + }; + multihash_extend(definitions, current_opcode.clone(), entry); + last_opcode = current_opcode; + } + } + + fn extract_ordered_modifiers( + unordered_modifiers: &mut Vec, + ) -> Vec { + let mut result = Vec::new(); + loop { + let is_ordered = match unordered_modifiers.last() { + Some(DotModifierRef::Direct { value, .. }) => { + let name = value.to_string(); + POSTFIX_MODIFIERS.contains(&&*name) + } + Some(DotModifierRef::Indirect { value, .. 
}) => { + let type_ = value.type_.to_token_stream().to_string(); + //panic!("{} {}", type_, POSTFIX_TYPES.contains(&&*type_)); + POSTFIX_TYPES.contains(&&*type_) + } + None => break, + }; + if is_ordered { + result.push(unordered_modifiers.pop().unwrap()); + } else { + break; + } + } + if unordered_modifiers.len() == 1 { + result.push(unordered_modifiers.pop().unwrap()); + } + result.reverse(); + result + } +} + +fn gather_rules( + rules: Vec, +) -> ( + FxHashMap>, + FxHashMap, +) { + let mut named = FxHashMap::default(); + let mut unnamed = FxHashMap::default(); + for rule in rules { + match rule.modifier { + Some(ref modifier) => { + named.insert(modifier.clone(), Rc::new(rule)); + } + None => unnamed.extend( + rule.alternatives + .into_iter() + .map(|alt| (alt, rule.type_.as_ref().unwrap().clone())), + ), + } + } + (named, unnamed) +} + +#[proc_macro] +pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { + let parse_definitions = + parse_macro_input!(tokens as ptx_parser_macros_impl::parser::ParseDefinitions); + let mut definitions = FxHashMap::default(); + let mut special_definitions = FxHashMap::default(); + let types = OpcodeDefinitions::get_enum_types(&parse_definitions.definitions); + let enum_types_tokens = emit_enum_types(types, parse_definitions.additional_enums); + for definition in parse_definitions.definitions.into_iter() { + SingleOpcodeDefinition::extract_and_insert( + &mut definitions, + &mut special_definitions, + definition, + ); + } + let definitions = definitions + .into_iter() + .map(|(k, v)| { + let v = OpcodeDefinitions::new(&k, v); + (k, v) + }) + .collect::>(); + let mut token_enum = parse_definitions.token_type; + let (all_opcode, all_modifier) = write_definitions_into_tokens( + &definitions, + special_definitions.keys(), + &mut token_enum.variants, + ); + let token_impl = emit_parse_function( + &token_enum.ident, + &definitions, + &special_definitions, + all_opcode, + all_modifier, + ); + let tokens = quote! { + #enum_types_tokens + + #token_enum + + #token_impl + }; + tokens.into() +} + +fn emit_enum_types( + types: FxHashMap>, + mut existing_enums: FxHashMap, +) -> TokenStream { + let token_types = types.into_iter().filter_map(|(type_, variants)| { + match type_ { + syn::Type::Path(TypePath { + qself: None, + ref path, + }) => { + if let Some(ident) = path.get_ident() { + if let Some(enum_) = existing_enums.get_mut(ident) { + enum_.variants.extend(variants.into_iter().map(|modifier| { + let ident = modifier.variant_capitalized(); + let variant: syn::Variant = syn::parse_quote! { + #ident + }; + variant + })); + return None; + } + } + } + _ => {} + } + let variants = variants.iter().map(|v| v.variant_capitalized()); + Some(quote! { + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + enum #type_ { + #(#variants),* + } + }) + }); + let mut result = TokenStream::new(); + for tokens in token_types { + tokens.to_tokens(&mut result); + } + for (_, enum_) in existing_enums { + quote! 
{ #enum_ }.to_tokens(&mut result); + } + result +} + +fn emit_parse_function( + type_name: &Ident, + defs: &FxHashMap, + special_defs: &FxHashMap, + all_opcode: Vec<&Ident>, + all_modifier: FxHashSet<&parser::DotModifier>, +) -> TokenStream { + use std::fmt::Write; + let fns_ = defs + .iter() + .map(|(opcode, defs)| { + defs.definitions.iter().enumerate().map(|(idx, def)| { + let mut fn_name = opcode.to_string(); + write!(&mut fn_name, "_{}", idx).ok(); + let fn_name = Ident::new(&fn_name, Span::call_site()); + let code_block = &def.code_block.code; + let args = def.function_arguments_declarations(); + quote! { + fn #fn_name<'input>(state: &mut PtxParserState, #(#args),* ) -> Instruction> #code_block + } + }) + }) + .flatten(); + let selectors = defs.iter().map(|(opcode, def)| { + let opcode_variant = Ident::new(&capitalize(&opcode.to_string()), opcode.span()); + let mut result = TokenStream::new(); + let mut selectors = TokenStream::new(); + quote! { + if false { + unsafe { std::hint::unreachable_unchecked() } + } + } + .to_tokens(&mut selectors); + let mut has_default_selector = false; + for selection_layer in def.block_selection.iter() { + for (selection_key, selected_definition) in selection_layer { + let def_parser = emit_definition_parser(type_name, (opcode,*selected_definition), &def.definitions[*selected_definition]); + match selection_key { + Some(selection_keys) => { + let selection_keys = selection_keys.iter().map(|k| k.dot_capitalized()); + quote! { + else if false #(|| modifiers.iter().any(|(t, _)| *t == #type_name :: #selection_keys))* { + #def_parser + } + } + .to_tokens(&mut selectors); + } + None => { + has_default_selector = true; + quote! { + else { + #def_parser + } + } + .to_tokens(&mut selectors); + } + } + } + } + if !has_default_selector { + quote! { + else { + return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) + } + } + .to_tokens(&mut selectors); + } + quote! { + #opcode_variant => { + let modifers_start = stream.checkpoint(); + let modifiers = take_while(0.., |(t,_)| Token::modifier(t)).parse_next(stream)?; + #selectors + } + } + .to_tokens(&mut result); + result + }).chain(special_defs.iter().map(|(opcode, code)| { + let opcode_variant = Ident::new(&capitalize(&opcode.to_string()), opcode.span()); + quote! { + #opcode_variant => { #code? } + } + })); + let opcodes = all_opcode.into_iter().map(|op_ident| { + let op = op_ident.to_string(); + let variant = Ident::new(&capitalize(&op), op_ident.span()); + let value = op; + quote! { + #type_name :: #variant => Some(#value), + } + }); + let modifier_names = iter::once(Ident::new("DotUnified", Span::call_site())) + .chain(all_modifier.iter().map(|m| m.dot_capitalized())); + quote! 
{ + impl<'input> #type_name<'input> { + fn opcode_text(self) -> Option<&'static str> { + match self { + #(#opcodes)* + _ => None + } + } + + fn modifier(self) -> bool { + match self { + #( + #type_name :: #modifier_names => true, + )* + _ => false + } + } + } + + #(#fns_)* + + fn parse_instruction<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> winnow::error::PResult>> + { + use winnow::Parser; + use winnow::token::*; + use winnow::combinator::*; + let opcode = any.parse_next(stream)?.0; + let modifiers_start = stream.checkpoint(); + Ok(match opcode { + #( + #type_name :: #selectors + )* + _ => return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) + }) + } + } +} + +fn emit_definition_parser( + token_type: &Ident, + (opcode, fn_idx): (&Ident, usize), + definition: &SingleOpcodeDefinition, +) -> TokenStream { + let return_error_ref = quote! { + return Err(winnow::error::ErrMode::from_error_kind(&stream, winnow::error::ErrorKind::Token)) + }; + let return_error = quote! { + return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) + }; + let ordered_parse_declarations = definition.ordered_modifiers.iter().map(|modifier| { + modifier.type_of().map(|type_| { + let name = modifier.ident(); + quote! { + let #name : #type_; + } + }) + }); + let ordered_parse = definition.ordered_modifiers.iter().rev().map(|modifier| { + let arg_name = modifier.ident(); + match modifier { + DotModifierRef::Direct { optional, value, type_: None, .. } => { + let variant = value.dot_capitalized(); + if *optional { + quote! { + #arg_name = opt(any.verify(|(t, _)| *t == #token_type :: #variant)).parse_next(&mut stream)?.is_some(); + } + } else { + quote! { + any.verify(|(t, _)| *t == #token_type :: #variant).parse_next(&mut stream)?; + } + } + } + DotModifierRef::Direct { optional: false, type_: Some(type_), name, value } => { + let variable = name.ident(); + let variant = value.dot_capitalized(); + let parsed_variant = value.variant_capitalized(); + quote! { + any.verify(|(t, _)| *t == #token_type :: #variant).parse_next(&mut stream)?; + #variable = #type_ :: #parsed_variant; + } + } + DotModifierRef::Direct { optional: true, type_: Some(_), .. } => { todo!() } + DotModifierRef::Indirect { optional, value, .. } => { + let variants = value.alternatives.iter().map(|alt| { + let type_ = value.type_.as_ref().unwrap(); + let token_variant = alt.dot_capitalized(); + let parsed_variant = alt.variant_capitalized(); + quote! { + #token_type :: #token_variant => #type_ :: #parsed_variant, + } + }); + if *optional { + quote! { + #arg_name = opt(any.verify_map(|(tok, _)| { + Some(match tok { + #(#variants)* + _ => return None + }) + })).parse_next(&mut stream)?; + } + } else { + quote! { + #arg_name = any.verify_map(|(tok, _)| { + Some(match tok { + #(#variants)* + _ => return None + }) + }).parse_next(&mut stream)?; + } + } + } + } + }); + let unordered_parse_declarations = definition.unordered_modifiers.iter().map(|modifier| { + let name = modifier.ident(); + let type_ = modifier.type_of_check(); + quote! { + let mut #name : #type_ = std::default::Default::default(); + } + }); + let unordered_parse = definition + .unordered_modifiers + .iter() + .map(|modifier| match modifier { + DotModifierRef::Direct { + name, + value, + type_: None, + .. + } => { + let name = name.ident(); + let token_variant = value.dot_capitalized(); + quote! 
{ + #token_type :: #token_variant => { + if #name { + #return_error_ref; + } + #name = true; + } + } + } + DotModifierRef::Direct { + name, + value, + type_: Some(type_), + .. + } => { + let variable = name.ident(); + let token_variant = value.dot_capitalized(); + let enum_variant = value.variant_capitalized(); + quote! { + #token_type :: #token_variant => { + if #variable.is_some() { + #return_error_ref; + } + #variable = Some(#type_ :: #enum_variant); + } + } + } + DotModifierRef::Indirect { value, name, .. } => { + let variable = name.ident(); + let type_ = value.type_.as_ref().unwrap(); + let alternatives = value.alternatives.iter().map(|alt| { + let token_variant = alt.dot_capitalized(); + let enum_variant = alt.variant_capitalized(); + quote! { + #token_type :: #token_variant => { + if #variable.is_some() { + #return_error_ref; + } + #variable = Some(#type_ :: #enum_variant); + } + } + }); + quote! { + #(#alternatives)* + } + } + }); + let unordered_parse_validations = + definition + .unordered_modifiers + .iter() + .map(|modifier| match modifier { + DotModifierRef::Direct { + optional: false, + name, + type_: None, + .. + } => { + let variable = name.ident(); + quote! { + if !#variable { + #return_error; + } + } + } + DotModifierRef::Direct { + optional: false, + name, + type_: Some(_), + .. + } => { + let variable = name.ident(); + quote! { + let #variable = match #variable { + Some(x) => x, + None => #return_error + }; + } + } + DotModifierRef::Indirect { + optional: false, + name, + .. + } => { + let variable = name.ident(); + quote! { + let #variable = match #variable { + Some(x) => x, + None => #return_error + }; + } + } + DotModifierRef::Direct { optional: true, .. } + | DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(), + }); + let (arguments_pattern, arguments_parser) = definition.arguments.0.iter().enumerate().rfold((quote! { () }, quote! { empty }), |(emitted_pattern, emitted_parser), (idx, arg)| { + let comma = if idx == 0 || arg.pre_pipe { + quote! { empty } + } else { + quote! { any.verify(|(t, _)| *t == #token_type::Comma).void() } + }; + + let pre_bracket = if arg.pre_bracket { + quote! { + any.verify(|(t, _)| *t == #token_type::LBracket).void() + } + } else { + quote! { + empty + } + }; + let pre_pipe = if arg.pre_pipe { + quote! { + any.verify(|(t, _)| *t == #token_type::Pipe).void() + } + } else { + quote! { + empty + } + }; + let can_be_negated = if arg.can_be_negated { + quote! { + opt(any.verify(|(t, _)| *t == #token_type::Exclamation)).map(|o| o.is_some()) + } + } else { + quote! { + empty + } + }; + let operand = { + quote! { + ParsedOperandStr::parse + } + }; + let post_bracket = if arg.post_bracket { + quote! { + any.verify(|(t, _)| *t == #token_type::RBracket).void() + } + } else { + quote! { + empty + } + }; + let unified = if arg.unified { + quote! { + opt(any.verify(|(t, _)| *t == #token_type::DotUnified).void()).map(|u| u.is_some()) + } + } else { + quote! { + empty + } + }; + let pattern = quote! { + (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket, #unified) + }; + let arg_name = &arg.ident.ident(); + if arg.unified && arg.can_be_negated { + panic!("TODO: argument can't be both prefixed by `!` and suffixed by `.unified`") + } + let inner_parser = if arg.unified { + quote! { + #pattern.map(|(_, _, _, _, name, _, unified)| (name, unified)) + } + } else if arg.can_be_negated { + quote! { + #pattern.map(|(_, _, _, negated, name, _, _)| (negated, name)) + } + } else { + quote! 
{ + #pattern.map(|(_, _, _, _, name, _, _)| name) + } + }; + + let parser = if arg.optional { + quote! { first_optional(#inner_parser, #emitted_parser) } + } else { + quote! { (#inner_parser, #emitted_parser) } + }; + + let pattern = quote! { ( #arg_name, #emitted_pattern ) }; + + (pattern, parser) + }); + + let arguments_parse = + quote! { let #arguments_pattern = ( #arguments_parser ).parse_next(stream)?; }; + + let fn_args = definition.function_arguments(); + let fn_name = format_ident!("{}_{}", opcode, fn_idx); + let fn_call = quote! { + #fn_name(&mut stream.state, #(#fn_args),* ) + }; + quote! { + #(#unordered_parse_declarations)* + #(#ordered_parse_declarations)* + { + let mut stream = ReverseStream(modifiers); + #(#ordered_parse)* + let mut stream: &[_] = stream.0; + for (token, _) in stream.iter().cloned() { + match token { + #(#unordered_parse)* + _ => #return_error_ref + } + } + } + #(#unordered_parse_validations)* + #arguments_parse + #fn_call + } +} + +fn write_definitions_into_tokens<'a>( + defs: &'a FxHashMap, + special_definitions: impl Iterator, + variants: &mut Punctuated, +) -> (Vec<&'a Ident>, FxHashSet<&'a parser::DotModifier>) { + let mut all_opcodes = Vec::new(); + let mut all_modifiers = FxHashSet::default(); + for (opcode, definitions) in defs.iter() { + all_opcodes.push(opcode); + let opcode_as_string = opcode.to_string(); + let variant_name = Ident::new(&capitalize(&opcode_as_string), opcode.span()); + let arg: Variant = syn::parse_quote! { + #[token(#opcode_as_string)] + #variant_name + }; + variants.push(arg); + for definition in definitions.definitions.iter() { + for modifier in definition.possible_modifiers.iter() { + all_modifiers.insert(modifier); + } + } + } + for opcode in special_definitions { + all_opcodes.push(opcode); + let opcode_as_string = opcode.to_string(); + let variant_name = Ident::new(&capitalize(&opcode_as_string), opcode.span()); + let arg: Variant = syn::parse_quote! { + #[token(#opcode_as_string)] + #variant_name + }; + variants.push(arg); + } + for modifier in all_modifiers.iter() { + let modifier_as_string = modifier.to_string(); + let variant_name = modifier.dot_capitalized(); + let arg: Variant = syn::parse_quote! { + #[token(#modifier_as_string)] + #variant_name + }; + variants.push(arg); + } + variants.push(parse_quote! { + #[token(".unified")] + DotUnified + }); + (all_opcodes, all_modifiers) +} + +fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +fn multihash_extend(multimap: &mut FxHashMap>, k: K, v: V) { + match multimap.entry(k) { + hash_map::Entry::Occupied(mut entry) => entry.get_mut().push(v), + hash_map::Entry::Vacant(entry) => { + entry.insert(vec![v]); + } + } +} + +enum DotModifierRef { + Direct { + optional: bool, + value: parser::DotModifier, + name: parser::DotModifier, + type_: Option, + }, + Indirect { + optional: bool, + name: parser::DotModifier, + value: Rc, + }, +} + +impl DotModifierRef { + fn ident(&self) -> Ident { + match self { + DotModifierRef::Direct { name, .. } => name.ident(), + DotModifierRef::Indirect { name, .. } => name.ident(), + } + } + + fn type_of(&self) -> Option { + Some(match self { + DotModifierRef::Direct { + optional: true, + type_: None, + .. + } => syn::parse_quote! { bool }, + DotModifierRef::Direct { + optional: false, + type_: None, + .. + } => return None, + DotModifierRef::Direct { + optional: true, + type_: Some(type_), + .. + } => syn::parse_quote! 
{ Option<#type_> }, + DotModifierRef::Direct { + optional: false, + type_: Some(type_), + .. + } => type_.clone(), + DotModifierRef::Indirect { + optional, value, .. + } => { + let type_ = value + .type_ + .as_ref() + .expect("Indirect modifer must have a type"); + if *optional { + syn::parse_quote! { Option<#type_> } + } else { + type_.clone() + } + } + }) + } + + fn type_of_check(&self) -> syn::Type { + match self { + DotModifierRef::Direct { type_: None, .. } => syn::parse_quote! { bool }, + DotModifierRef::Direct { + type_: Some(type_), .. + } => syn::parse_quote! { Option<#type_> }, + DotModifierRef::Indirect { value, .. } => { + let type_ = value + .type_ + .as_ref() + .expect("Indirect modifer must have a type"); + syn::parse_quote! { Option<#type_> } + } + } + } +} + +#[proc_macro] +pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(tokens as ptx_parser_macros_impl::GenerateInstructionType); + let mut result = proc_macro2::TokenStream::new(); + input.emit_arg_types(&mut result); + input.emit_instruction_type(&mut result); + input.emit_visit(&mut result); + input.emit_visit_mut(&mut result); + input.emit_visit_map(&mut result); + result.into() +} diff --git a/ptx_parser_macros_impl/src/lib.rs b/ptx_parser_macros_impl/src/lib.rs index 2f2c87a..34d97da 100644 --- a/ptx_parser_macros_impl/src/lib.rs +++ b/ptx_parser_macros_impl/src/lib.rs @@ -1,881 +1,881 @@ -use proc_macro2::TokenStream; -use quote::{format_ident, quote, ToTokens}; -use syn::{ - braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, LitBool, PathSegment, Token, - Type, TypeParam, Visibility, -}; - -pub mod parser; - -pub struct GenerateInstructionType { - pub visibility: Option, - pub name: Ident, - pub type_parameters: Punctuated, - pub short_parameters: Punctuated, - pub variants: Punctuated, -} - -impl GenerateInstructionType { - pub fn emit_arg_types(&self, tokens: &mut TokenStream) { - for v in self.variants.iter() { - v.emit_type(&self.visibility, tokens); - } - } - - pub fn emit_instruction_type(&self, tokens: &mut TokenStream) { - let vis = &self.visibility; - let type_name = &self.name; - let type_parameters = &self.type_parameters; - let variants = self.variants.iter().map(|v| v.emit_variant()); - quote! { - #vis enum #type_name<#type_parameters> { - #(#variants),* - } - } - .to_tokens(tokens); - } - - pub fn emit_visit(&self, tokens: &mut TokenStream) { - self.emit_visit_impl(VisitKind::Ref, tokens, InstructionVariant::emit_visit) - } - - pub fn emit_visit_mut(&self, tokens: &mut TokenStream) { - self.emit_visit_impl( - VisitKind::RefMut, - tokens, - InstructionVariant::emit_visit_mut, - ) - } - - pub fn emit_visit_map(&self, tokens: &mut TokenStream) { - self.emit_visit_impl(VisitKind::Map, tokens, InstructionVariant::emit_visit_map) - } - - fn emit_visit_impl( - &self, - kind: VisitKind, - tokens: &mut TokenStream, - mut fn_: impl FnMut(&InstructionVariant, &Ident, &mut TokenStream), - ) { - let type_name = &self.name; - let type_parameters = &self.type_parameters; - let short_parameters = &self.short_parameters; - let mut inner_tokens = TokenStream::new(); - for v in self.variants.iter() { - fn_(v, type_name, &mut inner_tokens); - } - let visit_ref = kind.reference(); - let visitor_type = format_ident!("Visitor{}", kind.type_suffix()); - let visit_fn = format_ident!("visit{}", kind.fn_suffix()); - let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map { - ( - quote! 
{ <#type_parameters, To: Operand, Err> },
-                quote! { <#short_parameters, To, Err> },
-                quote! { std::result::Result<#type_name<To>, Err> },
-            )
-        } else {
-            (
-                quote! { <#type_parameters, Err> },
-                quote! { <#short_parameters, Err> },
-                quote! { std::result::Result<(), Err> },
-            )
-        };
-        quote! {
-            pub fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type {
-                Ok(match i {
-                    #inner_tokens
-                })
-            }
-        }.to_tokens(tokens);
-        if kind == VisitKind::Map {
-            return;
-        }
-    }
-}
-
-#[derive(Clone, Copy, PartialEq, Eq)]
-enum VisitKind {
-    Ref,
-    RefMut,
-    Map,
-}
-
-impl VisitKind {
-    fn fn_suffix(self) -> &'static str {
-        match self {
-            VisitKind::Ref => "",
-            VisitKind::RefMut => "_mut",
-            VisitKind::Map => "_map",
-        }
-    }
-
-    fn type_suffix(self) -> &'static str {
-        match self {
-            VisitKind::Ref => "",
-            VisitKind::RefMut => "Mut",
-            VisitKind::Map => "Map",
-        }
-    }
-
-    fn reference(self) -> Option<TokenStream> {
-        match self {
-            VisitKind::Ref => Some(quote! { & }),
-            VisitKind::RefMut => Some(quote! { &mut }),
-            VisitKind::Map => None,
-        }
-    }
-}
-
-impl Parse for GenerateInstructionType {
-    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
-        let visibility = if !input.peek(Token![enum]) {
-            Some(input.parse::<Visibility>()?)
-        } else {
-            None
-        };
-        input.parse::<Token![enum]>()?;
-        let name = input.parse::<Ident>()?;
-        input.parse::<Token![<]>()?;
-        let type_parameters = Punctuated::parse_separated_nonempty(input)?;
-        let short_parameters = type_parameters
-            .iter()
-            .map(|p: &TypeParam| p.ident.clone())
-            .collect();
-        input.parse::<Token![>]>()?;
-        let variants_buffer;
-        braced!(variants_buffer in input);
-        let variants = variants_buffer.parse_terminated(InstructionVariant::parse, Token![,])?;
-        Ok(Self {
-            visibility,
-            name,
-            type_parameters,
-            short_parameters,
-            variants,
-        })
-    }
-}
-
-pub struct InstructionVariant {
-    pub name: Ident,
-    pub type_: Option<Option<Expr>>,
-    pub space: Option<Expr>,
-    pub data: Option<Type>,
-    pub arguments: Option<Arguments>,
-    pub visit: Option<Expr>,
-    pub visit_mut: Option<Expr>,
-    pub map: Option<Expr>,
-}
-
-impl InstructionVariant {
-    fn args_name(&self) -> Ident {
-        format_ident!("{}Args", self.name)
-    }
-
-    fn emit_variant(&self) -> TokenStream {
-        let name = &self.name;
-        let data = match &self.data {
-            None => {
-                quote! {}
-            }
-            Some(data_type) => {
-                quote! {
-                    data: #data_type,
-                }
-            }
-        };
-        let arguments = match &self.arguments {
-            None => {
-                quote! {}
-            }
-            Some(args) => {
-                let args_name = self.args_name();
-                match &args {
-                    Arguments::Def(InstructionArguments { generic: None, .. }) => {
-                        quote! {
-                            arguments: #args_name,
-                        }
-                    }
-                    Arguments::Def(InstructionArguments {
-                        generic: Some(generics),
-                        ..
-                    }) => {
-                        quote! {
-                            arguments: #args_name <#generics>,
-                        }
-                    }
-                    Arguments::Decl(type_) => quote! {
-                        arguments: #type_,
-                    },
-                }
-            }
-        };
-        quote! {
-            #name { #data #arguments }
-        }
-    }
-
-    fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) {
-        self.emit_visit_impl(&self.visit, enum_, tokens, InstructionArguments::emit_visit)
-    }
-
-    fn emit_visit_mut(&self, enum_: &Ident, tokens: &mut TokenStream) {
-        self.emit_visit_impl(
-            &self.visit_mut,
-            enum_,
-            tokens,
-            InstructionArguments::emit_visit_mut,
-        )
-    }
-
-    fn emit_visit_impl(
-        &self,
-        visit_fn: &Option<Expr>,
-        enum_: &Ident,
-        tokens: &mut TokenStream,
-        mut fn_: impl FnMut(&InstructionArguments, &Option<Option<Expr>>, &Option<Expr>) -> TokenStream,
-    ) {
-        let name = &self.name;
-        let arguments = match &self.arguments {
-            None => {
-                quote! {
-                    #enum_ :: #name { ..
} => { } - } - .to_tokens(tokens); - return; - } - Some(Arguments::Decl(_)) => { - quote! { - #enum_ :: #name { data, arguments } => { #visit_fn } - } - .to_tokens(tokens); - return; - } - Some(Arguments::Def(args)) => args, - }; - let data = &self.data.as_ref().map(|_| quote! { data,}); - let arg_calls = fn_(arguments, &self.type_, &self.space); - quote! { - #enum_ :: #name { #data arguments } => { - #arg_calls - } - } - .to_tokens(tokens); - } - - fn emit_visit_map(&self, enum_: &Ident, tokens: &mut TokenStream) { - let name = &self.name; - let data = &self.data.as_ref().map(|_| quote! { data,}); - let arguments = match self.arguments { - None => None, - Some(Arguments::Decl(_)) => { - let map = self.map.as_ref().unwrap(); - quote! { - #enum_ :: #name { #data arguments } => { - #map - } - } - .to_tokens(tokens); - return; - } - Some(Arguments::Def(ref def)) => Some(def), - }; - let arguments_ident = &self.arguments.as_ref().map(|_| quote! { arguments,}); - let mut arg_calls = None; - let arguments_init = arguments.as_ref().map(|arguments| { - let arg_type = self.args_name(); - arg_calls = Some(arguments.emit_visit_map(&self.type_, &self.space)); - let arg_names = arguments.fields.iter().map(|arg| &arg.name); - quote! { - arguments: #arg_type { #(#arg_names),* } - } - }); - quote! { - #enum_ :: #name { #data #arguments_ident } => { - #arg_calls - #enum_ :: #name { #data #arguments_init } - } - } - .to_tokens(tokens); - } - - fn emit_type(&self, vis: &Option, tokens: &mut TokenStream) { - let arguments = match self.arguments { - Some(Arguments::Def(ref a)) => a, - Some(Arguments::Decl(_)) => return, - None => return, - }; - let name = self.args_name(); - let type_parameters = if arguments.generic.is_some() { - Some(quote! { }) - } else { - None - }; - let fields = arguments.fields.iter().map(|f| f.emit_field(vis)); - quote! { - #vis struct #name #type_parameters { - #(#fields),* - } - } - .to_tokens(tokens); - } -} - -impl Parse for InstructionVariant { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let name = input.parse::()?; - let properties_buffer; - braced!(properties_buffer in input); - let properties = properties_buffer.parse_terminated(VariantProperty::parse, Token![,])?; - let mut type_ = None; - let mut space = None; - let mut data = None; - let mut arguments = None; - let mut visit = None; - let mut visit_mut = None; - let mut map = None; - for property in properties { - match property { - VariantProperty::Type(t) => type_ = Some(t), - VariantProperty::Space(s) => space = Some(s), - VariantProperty::Data(d) => data = Some(d), - VariantProperty::Arguments(a) => arguments = Some(a), - VariantProperty::Visit(e) => visit = Some(e), - VariantProperty::VisitMut(e) => visit_mut = Some(e), - VariantProperty::Map(e) => map = Some(e), - } - } - Ok(Self { - name, - type_, - space, - data, - arguments, - visit, - visit_mut, - map, - }) - } -} - -enum VariantProperty { - Type(Option), - Space(Expr), - Data(Type), - Arguments(Arguments), - Visit(Expr), - VisitMut(Expr), - Map(Expr), -} - -impl VariantProperty { - pub fn parse(input: syn::parse::ParseStream) -> syn::Result { - let lookahead = input.lookahead1(); - Ok(if lookahead.peek(Token![type]) { - input.parse::()?; - input.parse::()?; - VariantProperty::Type(if input.peek(Token![!]) { - input.parse::()?; - None - } else { - Some(input.parse::()?) - }) - } else if lookahead.peek(Ident) { - let key = input.parse::()?; - match &*key.to_string() { - "data" => { - input.parse::()?; - VariantProperty::Data(input.parse::()?) 
- } - "space" => { - input.parse::()?; - VariantProperty::Space(input.parse::()?) - } - "arguments" => { - let generics = if input.peek(Token![<]) { - input.parse::()?; - let gen_params = - Punctuated::::parse_separated_nonempty(input)?; - input.parse::]>()?; - Some(gen_params) - } else { - None - }; - input.parse::()?; - if input.peek(token::Brace) { - let fields; - braced!(fields in input); - VariantProperty::Arguments(Arguments::Def(InstructionArguments::parse( - generics, &fields, - )?)) - } else { - VariantProperty::Arguments(Arguments::Decl(input.parse::()?)) - } - } - "visit" => { - input.parse::()?; - VariantProperty::Visit(input.parse::()?) - } - "visit_mut" => { - input.parse::()?; - VariantProperty::VisitMut(input.parse::()?) - } - "map" => { - input.parse::()?; - VariantProperty::Map(input.parse::()?) - } - x => { - return Err(syn::Error::new( - key.span(), - format!( - "Unexpected key `{}`. Expected `type`, `data`, `arguments`, `visit, `visit_mut` or `map`.", - x - ), - )) - } - } - } else { - return Err(lookahead.error()); - }) - } -} - -pub enum Arguments { - Decl(Type), - Def(InstructionArguments), -} - -pub struct InstructionArguments { - pub generic: Option>, - pub fields: Punctuated, -} - -impl InstructionArguments { - pub fn parse( - generic: Option>, - input: syn::parse::ParseStream, - ) -> syn::Result { - let fields = Punctuated::::parse_terminated_with( - input, - ArgumentField::parse, - )?; - Ok(Self { generic, fields }) - } - - fn emit_visit( - &self, - parent_type: &Option>, - parent_space: &Option, - ) -> TokenStream { - self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit) - } - - fn emit_visit_mut( - &self, - parent_type: &Option>, - parent_space: &Option, - ) -> TokenStream { - self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_mut) - } - - fn emit_visit_map( - &self, - parent_type: &Option>, - parent_space: &Option, - ) -> TokenStream { - self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_map) - } - - fn emit_visit_impl( - &self, - parent_type: &Option>, - parent_space: &Option, - mut fn_: impl FnMut(&ArgumentField, &Option>, &Option, bool) -> TokenStream, - ) -> TokenStream { - let is_ident = if let Some(ref generic) = self.generic { - generic.len() > 1 - } else { - false - }; - let field_calls = self - .fields - .iter() - .map(|f| fn_(f, parent_type, parent_space, is_ident)); - quote! { - #(#field_calls)* - } - } -} - -pub struct ArgumentField { - pub name: Ident, - pub is_dst: bool, - pub repr: Type, - pub space: Option, - pub type_: Option, - pub relaxed_type_check: bool, -} - -impl ArgumentField { - fn parse_block( - input: syn::parse::ParseStream, - ) -> syn::Result<(Type, Option, Option, Option, bool)> { - let content; - braced!(content in input); - let all_fields = - Punctuated::::parse_terminated_with(&content, |content| { - let lookahead = content.lookahead1(); - Ok(if lookahead.peek(Token![type]) { - content.parse::()?; - content.parse::()?; - ExprOrPath::Type(content.parse::()?) 
- } else if lookahead.peek(Ident) { - let name_ident = content.parse::()?; - content.parse::()?; - match &*name_ident.to_string() { - "relaxed_type_check" => { - ExprOrPath::RelaxedTypeCheck(content.parse::()?.value) - } - "repr" => ExprOrPath::Repr(content.parse::()?), - "space" => ExprOrPath::Space(content.parse::()?), - "dst" => { - let ident = content.parse::()?; - ExprOrPath::Dst(ident.value) - } - name => { - return Err(syn::Error::new( - name_ident.span(), - format!("Unexpected key `{}`, expected `repr` or `space", name), - )) - } - } - } else { - return Err(lookahead.error()); - }) - })?; - let mut repr = None; - let mut type_ = None; - let mut space = None; - let mut is_dst = None; - let mut relaxed_type_check = false; - for exp_or_path in all_fields { - match exp_or_path { - ExprOrPath::Repr(r) => repr = Some(r), - ExprOrPath::Type(t) => type_ = Some(t), - ExprOrPath::Space(s) => space = Some(s), - ExprOrPath::Dst(x) => is_dst = Some(x), - ExprOrPath::RelaxedTypeCheck(relaxed) => relaxed_type_check = relaxed, - } - } - Ok((repr.unwrap(), type_, space, is_dst, relaxed_type_check)) - } - - fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result { - input.parse::() - } - - fn emit_visit( - &self, - parent_type: &Option>, - parent_space: &Option, - is_ident: bool, - ) -> TokenStream { - self.emit_visit_impl(parent_type, parent_space, is_ident, false) - } - - fn emit_visit_mut( - &self, - parent_type: &Option>, - parent_space: &Option, - is_ident: bool, - ) -> TokenStream { - self.emit_visit_impl(parent_type, parent_space, is_ident, true) - } - - fn emit_visit_impl( - &self, - parent_type: &Option>, - parent_space: &Option, - is_ident: bool, - is_mut: bool, - ) -> TokenStream { - let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) { - (Some(type_), _) => (false, Some(type_)), - (None, None) => panic!("No type set"), - (None, Some(None)) => (true, None), - (None, Some(Some(type_))) => (false, Some(type_)), - }; - let space = self - .space - .as_ref() - .or(parent_space.as_ref()) - .map(|space| quote! { #space }) - .unwrap_or_else(|| quote! { StateSpace::Reg }); - let is_dst = self.is_dst; - let relaxed_type_check = self.relaxed_type_check; - let name = &self.name; - let type_space = if is_typeless { - quote! { - let type_space = None; - } - } else { - quote! { - let type_ = #type_; - let space = #space; - let type_space = Some((std::borrow::Borrow::::borrow(&type_), space)); - } - }; - if is_ident { - if is_mut { - quote! { - { - #type_space - visitor.visit_ident(&mut arguments.#name, type_space, #is_dst, #relaxed_type_check)?; - } - } - } else { - quote! { - { - #type_space - visitor.visit_ident(& arguments.#name, type_space, #is_dst, #relaxed_type_check)?; - } - } - } - } else { - let (operand_fn, arguments_name) = if is_mut { - ( - quote! { - VisitOperand::visit_mut - }, - quote! { - &mut arguments.#name - }, - ) - } else { - ( - quote! { - VisitOperand::visit - }, - quote! { - & arguments.#name - }, - ) - }; - quote! 
{{ - #type_space - #operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?; - }} - } - } - - fn emit_visit_map( - &self, - parent_type: &Option>, - parent_space: &Option, - is_ident: bool, - ) -> TokenStream { - let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) { - (Some(type_), _) => (false, Some(type_)), - (None, None) => panic!("No type set"), - (None, Some(None)) => (true, None), - (None, Some(Some(type_))) => (false, Some(type_)), - }; - let space = self - .space - .as_ref() - .or(parent_space.as_ref()) - .map(|space| quote! { #space }) - .unwrap_or_else(|| quote! { StateSpace::Reg }); - let is_dst = self.is_dst; - let relaxed_type_check = self.relaxed_type_check; - let name = &self.name; - let type_space = if is_typeless { - quote! { - let type_space = None; - } - } else { - quote! { - let type_ = #type_; - let space = #space; - let type_space = Some((std::borrow::Borrow::::borrow(&type_), space)); - } - }; - let map_call = if is_ident { - quote! { - visitor.visit_ident(arguments.#name, type_space, #is_dst, #relaxed_type_check)? - } - } else { - quote! { - MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))? - } - }; - quote! { - let #name = { - #type_space - #map_call - }; - } - } - - fn is_dst(name: &Ident) -> syn::Result { - if name.to_string().starts_with("dst") { - Ok(true) - } else if name.to_string().starts_with("src") { - Ok(false) - } else { - return Err(syn::Error::new( - name.span(), - format!( - "Could not guess if `{}` is a read or write argument. Name should start with `dst` or `src`", - name - ), - )); - } - } - - fn emit_field(&self, vis: &Option) -> TokenStream { - let name = &self.name; - let type_ = &self.repr; - quote! { - #vis #name: #type_ - } - } -} - -impl Parse for ArgumentField { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let name = input.parse::()?; - - input.parse::()?; - let lookahead = input.lookahead1(); - let (repr, type_, space, is_dst, relaxed_type_check) = if lookahead.peek(token::Brace) { - Self::parse_block(input)? - } else if lookahead.peek(syn::Ident) { - (Self::parse_basic(input)?, None, None, None, false) - } else { - return Err(lookahead.error()); - }; - let is_dst = match is_dst { - Some(x) => x, - None => Self::is_dst(&name)?, - }; - Ok(Self { - name, - is_dst, - repr, - type_, - space, - relaxed_type_check - }) - } -} - -enum ExprOrPath { - Repr(Type), - Type(Expr), - Space(Expr), - Dst(bool), - RelaxedTypeCheck(bool), -} - -#[cfg(test)] -mod tests { - use super::*; - use proc_macro2::Span; - use quote::{quote, ToTokens}; - - fn to_string(x: impl ToTokens) -> String { - quote! { #x }.to_string() - } - - #[test] - fn parse_argument_field_basic() { - let input = quote! { - dst: P::Operand - }; - let arg = syn::parse2::(input).unwrap(); - assert_eq!("dst", arg.name.to_string()); - assert_eq!("P :: Operand", to_string(arg.repr)); - assert!(matches!(arg.type_, None)); - } - - #[test] - fn parse_argument_field_block() { - let input = quote! { - dst: { - type: ScalarType::U32, - space: StateSpace::Global, - repr: P::Operand, - } - }; - let arg = syn::parse2::(input).unwrap(); - assert_eq!("dst", arg.name.to_string()); - assert_eq!("ScalarType :: U32", to_string(arg.type_.unwrap())); - assert_eq!("StateSpace :: Global", to_string(arg.space.unwrap())); - assert_eq!("P :: Operand", to_string(arg.repr)); - } - - #[test] - fn parse_argument_field_block_untyped() { - let input = quote! 
{
-            dst: {
-                repr: P::Operand,
-            }
-        };
-        let arg = syn::parse2::<ArgumentField>(input).unwrap();
-        assert_eq!("dst", arg.name.to_string());
-        assert_eq!("P :: Operand", to_string(arg.repr));
-        assert!(matches!(arg.type_, None));
-    }
-
-    #[test]
-    fn parse_variant_complex() {
-        let input = quote! {
-            Ld {
-                type: ScalarType::U32,
-                space: StateSpace::Global,
-                data: LdDetails,
-                arguments

: { - dst: { - repr: P::Operand, - type: ScalarType::U32, - space: StateSpace::Shared, - }, - src: P::Operand, - }, - } - }; - let variant = syn::parse2::(input).unwrap(); - assert_eq!("Ld", variant.name.to_string()); - assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap())); - assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap())); - assert_eq!("LdDetails", to_string(variant.data.unwrap())); - let arguments = if let Some(Arguments::Def(a)) = variant.arguments { - a - } else { - panic!() - }; - assert_eq!("P", to_string(arguments.generic)); - let mut fields = arguments.fields.into_iter(); - let dst = fields.next().unwrap(); - assert_eq!("P :: Operand", to_string(dst.repr)); - assert_eq!("ScalarType :: U32", to_string(dst.type_)); - assert_eq!("StateSpace :: Shared", to_string(dst.space)); - let src = fields.next().unwrap(); - assert_eq!("P :: Operand", to_string(src.repr)); - assert!(matches!(src.type_, None)); - assert!(matches!(src.space, None)); - } - - #[test] - fn visit_variant_empty() { - let input = quote! { - Ret { - data: RetData - } - }; - let variant = syn::parse2::(input).unwrap(); - let mut output = TokenStream::new(); - variant.emit_visit(&Ident::new("Instruction", Span::call_site()), &mut output); - assert_eq!(output.to_string(), "Instruction :: Ret { .. } => { }"); - } -} +use proc_macro2::TokenStream; +use quote::{format_ident, quote, ToTokens}; +use syn::{ + braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, LitBool, PathSegment, Token, + Type, TypeParam, Visibility, +}; + +pub mod parser; + +pub struct GenerateInstructionType { + pub visibility: Option, + pub name: Ident, + pub type_parameters: Punctuated, + pub short_parameters: Punctuated, + pub variants: Punctuated, +} + +impl GenerateInstructionType { + pub fn emit_arg_types(&self, tokens: &mut TokenStream) { + for v in self.variants.iter() { + v.emit_type(&self.visibility, tokens); + } + } + + pub fn emit_instruction_type(&self, tokens: &mut TokenStream) { + let vis = &self.visibility; + let type_name = &self.name; + let type_parameters = &self.type_parameters; + let variants = self.variants.iter().map(|v| v.emit_variant()); + quote! { + #vis enum #type_name<#type_parameters> { + #(#variants),* + } + } + .to_tokens(tokens); + } + + pub fn emit_visit(&self, tokens: &mut TokenStream) { + self.emit_visit_impl(VisitKind::Ref, tokens, InstructionVariant::emit_visit) + } + + pub fn emit_visit_mut(&self, tokens: &mut TokenStream) { + self.emit_visit_impl( + VisitKind::RefMut, + tokens, + InstructionVariant::emit_visit_mut, + ) + } + + pub fn emit_visit_map(&self, tokens: &mut TokenStream) { + self.emit_visit_impl(VisitKind::Map, tokens, InstructionVariant::emit_visit_map) + } + + fn emit_visit_impl( + &self, + kind: VisitKind, + tokens: &mut TokenStream, + mut fn_: impl FnMut(&InstructionVariant, &Ident, &mut TokenStream), + ) { + let type_name = &self.name; + let type_parameters = &self.type_parameters; + let short_parameters = &self.short_parameters; + let mut inner_tokens = TokenStream::new(); + for v in self.variants.iter() { + fn_(v, type_name, &mut inner_tokens); + } + let visit_ref = kind.reference(); + let visitor_type = format_ident!("Visitor{}", kind.type_suffix()); + let visit_fn = format_ident!("visit{}", kind.fn_suffix()); + let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map { + ( + quote! { <#type_parameters, To: Operand, Err> }, + quote! { <#short_parameters, To, Err> }, + quote! 
{ std::result::Result<#type_name, Err> }, + ) + } else { + ( + quote! { <#type_parameters, Err> }, + quote! { <#short_parameters, Err> }, + quote! { std::result::Result<(), Err> }, + ) + }; + quote! { + pub fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type { + Ok(match i { + #inner_tokens + }) + } + }.to_tokens(tokens); + if kind == VisitKind::Map { + return; + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum VisitKind { + Ref, + RefMut, + Map, +} + +impl VisitKind { + fn fn_suffix(self) -> &'static str { + match self { + VisitKind::Ref => "", + VisitKind::RefMut => "_mut", + VisitKind::Map => "_map", + } + } + + fn type_suffix(self) -> &'static str { + match self { + VisitKind::Ref => "", + VisitKind::RefMut => "Mut", + VisitKind::Map => "Map", + } + } + + fn reference(self) -> Option { + match self { + VisitKind::Ref => Some(quote! { & }), + VisitKind::RefMut => Some(quote! { &mut }), + VisitKind::Map => None, + } + } +} + +impl Parse for GenerateInstructionType { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let visibility = if !input.peek(Token![enum]) { + Some(input.parse::()?) + } else { + None + }; + input.parse::()?; + let name = input.parse::()?; + input.parse::()?; + let type_parameters = Punctuated::parse_separated_nonempty(input)?; + let short_parameters = type_parameters + .iter() + .map(|p: &TypeParam| p.ident.clone()) + .collect(); + input.parse::]>()?; + let variants_buffer; + braced!(variants_buffer in input); + let variants = variants_buffer.parse_terminated(InstructionVariant::parse, Token![,])?; + Ok(Self { + visibility, + name, + type_parameters, + short_parameters, + variants, + }) + } +} + +pub struct InstructionVariant { + pub name: Ident, + pub type_: Option>, + pub space: Option, + pub data: Option, + pub arguments: Option, + pub visit: Option, + pub visit_mut: Option, + pub map: Option, +} + +impl InstructionVariant { + fn args_name(&self) -> Ident { + format_ident!("{}Args", self.name) + } + + fn emit_variant(&self) -> TokenStream { + let name = &self.name; + let data = match &self.data { + None => { + quote! {} + } + Some(data_type) => { + quote! { + data: #data_type, + } + } + }; + let arguments = match &self.arguments { + None => { + quote! {} + } + Some(args) => { + let args_name = self.args_name(); + match &args { + Arguments::Def(InstructionArguments { generic: None, .. }) => { + quote! { + arguments: #args_name, + } + } + Arguments::Def(InstructionArguments { + generic: Some(generics), + .. + }) => { + quote! { + arguments: #args_name <#generics>, + } + } + Arguments::Decl(type_) => quote! { + arguments: #type_, + }, + } + } + }; + quote! { + #name { #data #arguments } + } + } + + fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) { + self.emit_visit_impl(&self.visit, enum_, tokens, InstructionArguments::emit_visit) + } + + fn emit_visit_mut(&self, enum_: &Ident, tokens: &mut TokenStream) { + self.emit_visit_impl( + &self.visit_mut, + enum_, + tokens, + InstructionArguments::emit_visit_mut, + ) + } + + fn emit_visit_impl( + &self, + visit_fn: &Option, + enum_: &Ident, + tokens: &mut TokenStream, + mut fn_: impl FnMut(&InstructionArguments, &Option>, &Option) -> TokenStream, + ) { + let name = &self.name; + let arguments = match &self.arguments { + None => { + quote! { + #enum_ :: #name { .. } => { } + } + .to_tokens(tokens); + return; + } + Some(Arguments::Decl(_)) => { + quote! 
{ + #enum_ :: #name { data, arguments } => { #visit_fn } + } + .to_tokens(tokens); + return; + } + Some(Arguments::Def(args)) => args, + }; + let data = &self.data.as_ref().map(|_| quote! { data,}); + let arg_calls = fn_(arguments, &self.type_, &self.space); + quote! { + #enum_ :: #name { #data arguments } => { + #arg_calls + } + } + .to_tokens(tokens); + } + + fn emit_visit_map(&self, enum_: &Ident, tokens: &mut TokenStream) { + let name = &self.name; + let data = &self.data.as_ref().map(|_| quote! { data,}); + let arguments = match self.arguments { + None => None, + Some(Arguments::Decl(_)) => { + let map = self.map.as_ref().unwrap(); + quote! { + #enum_ :: #name { #data arguments } => { + #map + } + } + .to_tokens(tokens); + return; + } + Some(Arguments::Def(ref def)) => Some(def), + }; + let arguments_ident = &self.arguments.as_ref().map(|_| quote! { arguments,}); + let mut arg_calls = None; + let arguments_init = arguments.as_ref().map(|arguments| { + let arg_type = self.args_name(); + arg_calls = Some(arguments.emit_visit_map(&self.type_, &self.space)); + let arg_names = arguments.fields.iter().map(|arg| &arg.name); + quote! { + arguments: #arg_type { #(#arg_names),* } + } + }); + quote! { + #enum_ :: #name { #data #arguments_ident } => { + #arg_calls + #enum_ :: #name { #data #arguments_init } + } + } + .to_tokens(tokens); + } + + fn emit_type(&self, vis: &Option, tokens: &mut TokenStream) { + let arguments = match self.arguments { + Some(Arguments::Def(ref a)) => a, + Some(Arguments::Decl(_)) => return, + None => return, + }; + let name = self.args_name(); + let type_parameters = if arguments.generic.is_some() { + Some(quote! { }) + } else { + None + }; + let fields = arguments.fields.iter().map(|f| f.emit_field(vis)); + quote! { + #vis struct #name #type_parameters { + #(#fields),* + } + } + .to_tokens(tokens); + } +} + +impl Parse for InstructionVariant { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let name = input.parse::()?; + let properties_buffer; + braced!(properties_buffer in input); + let properties = properties_buffer.parse_terminated(VariantProperty::parse, Token![,])?; + let mut type_ = None; + let mut space = None; + let mut data = None; + let mut arguments = None; + let mut visit = None; + let mut visit_mut = None; + let mut map = None; + for property in properties { + match property { + VariantProperty::Type(t) => type_ = Some(t), + VariantProperty::Space(s) => space = Some(s), + VariantProperty::Data(d) => data = Some(d), + VariantProperty::Arguments(a) => arguments = Some(a), + VariantProperty::Visit(e) => visit = Some(e), + VariantProperty::VisitMut(e) => visit_mut = Some(e), + VariantProperty::Map(e) => map = Some(e), + } + } + Ok(Self { + name, + type_, + space, + data, + arguments, + visit, + visit_mut, + map, + }) + } +} + +enum VariantProperty { + Type(Option), + Space(Expr), + Data(Type), + Arguments(Arguments), + Visit(Expr), + VisitMut(Expr), + Map(Expr), +} + +impl VariantProperty { + pub fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + Ok(if lookahead.peek(Token![type]) { + input.parse::()?; + input.parse::()?; + VariantProperty::Type(if input.peek(Token![!]) { + input.parse::()?; + None + } else { + Some(input.parse::()?) + }) + } else if lookahead.peek(Ident) { + let key = input.parse::()?; + match &*key.to_string() { + "data" => { + input.parse::()?; + VariantProperty::Data(input.parse::()?) + } + "space" => { + input.parse::()?; + VariantProperty::Space(input.parse::()?) 
+ } + "arguments" => { + let generics = if input.peek(Token![<]) { + input.parse::()?; + let gen_params = + Punctuated::::parse_separated_nonempty(input)?; + input.parse::]>()?; + Some(gen_params) + } else { + None + }; + input.parse::()?; + if input.peek(token::Brace) { + let fields; + braced!(fields in input); + VariantProperty::Arguments(Arguments::Def(InstructionArguments::parse( + generics, &fields, + )?)) + } else { + VariantProperty::Arguments(Arguments::Decl(input.parse::()?)) + } + } + "visit" => { + input.parse::()?; + VariantProperty::Visit(input.parse::()?) + } + "visit_mut" => { + input.parse::()?; + VariantProperty::VisitMut(input.parse::()?) + } + "map" => { + input.parse::()?; + VariantProperty::Map(input.parse::()?) + } + x => { + return Err(syn::Error::new( + key.span(), + format!( + "Unexpected key `{}`. Expected `type`, `data`, `arguments`, `visit, `visit_mut` or `map`.", + x + ), + )) + } + } + } else { + return Err(lookahead.error()); + }) + } +} + +pub enum Arguments { + Decl(Type), + Def(InstructionArguments), +} + +pub struct InstructionArguments { + pub generic: Option>, + pub fields: Punctuated, +} + +impl InstructionArguments { + pub fn parse( + generic: Option>, + input: syn::parse::ParseStream, + ) -> syn::Result { + let fields = Punctuated::::parse_terminated_with( + input, + ArgumentField::parse, + )?; + Ok(Self { generic, fields }) + } + + fn emit_visit( + &self, + parent_type: &Option>, + parent_space: &Option, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit) + } + + fn emit_visit_mut( + &self, + parent_type: &Option>, + parent_space: &Option, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_mut) + } + + fn emit_visit_map( + &self, + parent_type: &Option>, + parent_space: &Option, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_map) + } + + fn emit_visit_impl( + &self, + parent_type: &Option>, + parent_space: &Option, + mut fn_: impl FnMut(&ArgumentField, &Option>, &Option, bool) -> TokenStream, + ) -> TokenStream { + let is_ident = if let Some(ref generic) = self.generic { + generic.len() > 1 + } else { + false + }; + let field_calls = self + .fields + .iter() + .map(|f| fn_(f, parent_type, parent_space, is_ident)); + quote! { + #(#field_calls)* + } + } +} + +pub struct ArgumentField { + pub name: Ident, + pub is_dst: bool, + pub repr: Type, + pub space: Option, + pub type_: Option, + pub relaxed_type_check: bool, +} + +impl ArgumentField { + fn parse_block( + input: syn::parse::ParseStream, + ) -> syn::Result<(Type, Option, Option, Option, bool)> { + let content; + braced!(content in input); + let all_fields = + Punctuated::::parse_terminated_with(&content, |content| { + let lookahead = content.lookahead1(); + Ok(if lookahead.peek(Token![type]) { + content.parse::()?; + content.parse::()?; + ExprOrPath::Type(content.parse::()?) 
+ } else if lookahead.peek(Ident) { + let name_ident = content.parse::()?; + content.parse::()?; + match &*name_ident.to_string() { + "relaxed_type_check" => { + ExprOrPath::RelaxedTypeCheck(content.parse::()?.value) + } + "repr" => ExprOrPath::Repr(content.parse::()?), + "space" => ExprOrPath::Space(content.parse::()?), + "dst" => { + let ident = content.parse::()?; + ExprOrPath::Dst(ident.value) + } + name => { + return Err(syn::Error::new( + name_ident.span(), + format!("Unexpected key `{}`, expected `repr` or `space", name), + )) + } + } + } else { + return Err(lookahead.error()); + }) + })?; + let mut repr = None; + let mut type_ = None; + let mut space = None; + let mut is_dst = None; + let mut relaxed_type_check = false; + for exp_or_path in all_fields { + match exp_or_path { + ExprOrPath::Repr(r) => repr = Some(r), + ExprOrPath::Type(t) => type_ = Some(t), + ExprOrPath::Space(s) => space = Some(s), + ExprOrPath::Dst(x) => is_dst = Some(x), + ExprOrPath::RelaxedTypeCheck(relaxed) => relaxed_type_check = relaxed, + } + } + Ok((repr.unwrap(), type_, space, is_dst, relaxed_type_check)) + } + + fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result { + input.parse::() + } + + fn emit_visit( + &self, + parent_type: &Option>, + parent_space: &Option, + is_ident: bool, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, is_ident, false) + } + + fn emit_visit_mut( + &self, + parent_type: &Option>, + parent_space: &Option, + is_ident: bool, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, is_ident, true) + } + + fn emit_visit_impl( + &self, + parent_type: &Option>, + parent_space: &Option, + is_ident: bool, + is_mut: bool, + ) -> TokenStream { + let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) { + (Some(type_), _) => (false, Some(type_)), + (None, None) => panic!("No type set"), + (None, Some(None)) => (true, None), + (None, Some(Some(type_))) => (false, Some(type_)), + }; + let space = self + .space + .as_ref() + .or(parent_space.as_ref()) + .map(|space| quote! { #space }) + .unwrap_or_else(|| quote! { StateSpace::Reg }); + let is_dst = self.is_dst; + let relaxed_type_check = self.relaxed_type_check; + let name = &self.name; + let type_space = if is_typeless { + quote! { + let type_space = None; + } + } else { + quote! { + let type_ = #type_; + let space = #space; + let type_space = Some((std::borrow::Borrow::::borrow(&type_), space)); + } + }; + if is_ident { + if is_mut { + quote! { + { + #type_space + visitor.visit_ident(&mut arguments.#name, type_space, #is_dst, #relaxed_type_check)?; + } + } + } else { + quote! { + { + #type_space + visitor.visit_ident(& arguments.#name, type_space, #is_dst, #relaxed_type_check)?; + } + } + } + } else { + let (operand_fn, arguments_name) = if is_mut { + ( + quote! { + VisitOperand::visit_mut + }, + quote! { + &mut arguments.#name + }, + ) + } else { + ( + quote! { + VisitOperand::visit + }, + quote! { + & arguments.#name + }, + ) + }; + quote! 
{{ + #type_space + #operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?; + }} + } + } + + fn emit_visit_map( + &self, + parent_type: &Option>, + parent_space: &Option, + is_ident: bool, + ) -> TokenStream { + let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) { + (Some(type_), _) => (false, Some(type_)), + (None, None) => panic!("No type set"), + (None, Some(None)) => (true, None), + (None, Some(Some(type_))) => (false, Some(type_)), + }; + let space = self + .space + .as_ref() + .or(parent_space.as_ref()) + .map(|space| quote! { #space }) + .unwrap_or_else(|| quote! { StateSpace::Reg }); + let is_dst = self.is_dst; + let relaxed_type_check = self.relaxed_type_check; + let name = &self.name; + let type_space = if is_typeless { + quote! { + let type_space = None; + } + } else { + quote! { + let type_ = #type_; + let space = #space; + let type_space = Some((std::borrow::Borrow::::borrow(&type_), space)); + } + }; + let map_call = if is_ident { + quote! { + visitor.visit_ident(arguments.#name, type_space, #is_dst, #relaxed_type_check)? + } + } else { + quote! { + MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))? + } + }; + quote! { + let #name = { + #type_space + #map_call + }; + } + } + + fn is_dst(name: &Ident) -> syn::Result { + if name.to_string().starts_with("dst") { + Ok(true) + } else if name.to_string().starts_with("src") { + Ok(false) + } else { + return Err(syn::Error::new( + name.span(), + format!( + "Could not guess if `{}` is a read or write argument. Name should start with `dst` or `src`", + name + ), + )); + } + } + + fn emit_field(&self, vis: &Option) -> TokenStream { + let name = &self.name; + let type_ = &self.repr; + quote! { + #vis #name: #type_ + } + } +} + +impl Parse for ArgumentField { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let name = input.parse::()?; + + input.parse::()?; + let lookahead = input.lookahead1(); + let (repr, type_, space, is_dst, relaxed_type_check) = if lookahead.peek(token::Brace) { + Self::parse_block(input)? + } else if lookahead.peek(syn::Ident) { + (Self::parse_basic(input)?, None, None, None, false) + } else { + return Err(lookahead.error()); + }; + let is_dst = match is_dst { + Some(x) => x, + None => Self::is_dst(&name)?, + }; + Ok(Self { + name, + is_dst, + repr, + type_, + space, + relaxed_type_check, + }) + } +} + +enum ExprOrPath { + Repr(Type), + Type(Expr), + Space(Expr), + Dst(bool), + RelaxedTypeCheck(bool), +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use quote::{quote, ToTokens}; + + fn to_string(x: impl ToTokens) -> String { + quote! { #x }.to_string() + } + + #[test] + fn parse_argument_field_basic() { + let input = quote! { + dst: P::Operand + }; + let arg = syn::parse2::(input).unwrap(); + assert_eq!("dst", arg.name.to_string()); + assert_eq!("P :: Operand", to_string(arg.repr)); + assert!(matches!(arg.type_, None)); + } + + #[test] + fn parse_argument_field_block() { + let input = quote! { + dst: { + type: ScalarType::U32, + space: StateSpace::Global, + repr: P::Operand, + } + }; + let arg = syn::parse2::(input).unwrap(); + assert_eq!("dst", arg.name.to_string()); + assert_eq!("ScalarType :: U32", to_string(arg.type_.unwrap())); + assert_eq!("StateSpace :: Global", to_string(arg.space.unwrap())); + assert_eq!("P :: Operand", to_string(arg.repr)); + } + + #[test] + fn parse_argument_field_block_untyped() { + let input = quote! 
{
+            dst: {
+                repr: P::Operand,
+            }
+        };
+        let arg = syn::parse2::<ArgumentField>(input).unwrap();
+        assert_eq!("dst", arg.name.to_string());
+        assert_eq!("P :: Operand", to_string(arg.repr));
+        assert!(matches!(arg.type_, None));
+    }
+
+    #[test]
+    fn parse_variant_complex() {
+        let input = quote! {
+            Ld {
+                type: ScalarType::U32,
+                space: StateSpace::Global,
+                data: LdDetails,
+                arguments

: { + dst: { + repr: P::Operand, + type: ScalarType::U32, + space: StateSpace::Shared, + }, + src: P::Operand, + }, + } + }; + let variant = syn::parse2::(input).unwrap(); + assert_eq!("Ld", variant.name.to_string()); + assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap())); + assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap())); + assert_eq!("LdDetails", to_string(variant.data.unwrap())); + let arguments = if let Some(Arguments::Def(a)) = variant.arguments { + a + } else { + panic!() + }; + assert_eq!("P", to_string(arguments.generic)); + let mut fields = arguments.fields.into_iter(); + let dst = fields.next().unwrap(); + assert_eq!("P :: Operand", to_string(dst.repr)); + assert_eq!("ScalarType :: U32", to_string(dst.type_)); + assert_eq!("StateSpace :: Shared", to_string(dst.space)); + let src = fields.next().unwrap(); + assert_eq!("P :: Operand", to_string(src.repr)); + assert!(matches!(src.type_, None)); + assert!(matches!(src.space, None)); + } + + #[test] + fn visit_variant_empty() { + let input = quote! { + Ret { + data: RetData + } + }; + let variant = syn::parse2::(input).unwrap(); + let mut output = TokenStream::new(); + variant.emit_visit(&Ident::new("Instruction", Span::call_site()), &mut output); + assert_eq!(output.to_string(), "Instruction :: Ret { .. } => { }"); + } +} diff --git a/ptx_parser_macros_impl/src/parser.rs b/ptx_parser_macros_impl/src/parser.rs index 20a7dab..b9a5d2f 100644 --- a/ptx_parser_macros_impl/src/parser.rs +++ b/ptx_parser_macros_impl/src/parser.rs @@ -1,914 +1,913 @@ -use proc_macro2::Span; -use proc_macro2::TokenStream; -use quote::quote; -use quote::ToTokens; -use rustc_hash::FxHashMap; -use std::fmt::Write; -use syn::bracketed; -use syn::parse::Peek; -use syn::punctuated::Punctuated; -use syn::spanned::Spanned; -use syn::LitInt; -use syn::Type; -use syn::{braced, parse::Parse, token, Ident, ItemEnum, Token}; - -pub struct ParseDefinitions { - pub token_type: ItemEnum, - pub additional_enums: FxHashMap, - pub definitions: Vec, -} - -impl Parse for ParseDefinitions { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let token_type = input.parse::()?; - let mut additional_enums = FxHashMap::default(); - let mut definitions = Vec::new(); - loop { - if input.is_empty() { - break; - } - - let lookahead = input.lookahead1(); - if lookahead.peek(Token![#]) { - let enum_ = input.parse::()?; - additional_enums.insert(enum_.ident.clone(), enum_); - } else if lookahead.peek(Ident) { - definitions.push(input.parse::()?); - } else { - return Err(lookahead.error()); - } - } - - Ok(Self { - token_type, - additional_enums, - definitions, - }) - } -} - -pub struct OpcodeDefinition(pub Patterns, pub Vec); - -impl Parse for OpcodeDefinition { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let patterns = input.parse::()?; - let mut rules = Vec::new(); - while Rule::peek(input) { - rules.push(input.parse::()?); - input.parse::()?; - } - Ok(Self(patterns, rules)) - } -} - -pub struct Patterns(pub Vec<(OpcodeDecl, CodeBlock)>); - -impl Parse for Patterns { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let mut result = Vec::new(); - loop { - if !OpcodeDecl::peek(input) { - break; - } - let decl = input.parse::()?; - let code_block = input.parse::()?; - result.push((decl, code_block)) - } - Ok(Self(result)) - } -} - -pub struct OpcodeDecl(pub Instruction, pub Arguments); - -impl OpcodeDecl { - fn peek(input: syn::parse::ParseStream) -> bool { - Instruction::peek(input) && !input.peek2(Token![=]) - } -} 
- -impl Parse for OpcodeDecl { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - Ok(Self( - input.parse::()?, - input.parse::()?, - )) - } -} - -pub struct CodeBlock { - pub special: bool, - pub code: proc_macro2::Group, -} - -impl Parse for CodeBlock { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let lookahead = input.lookahead1(); - let (special, code) = if lookahead.peek(Token![<]) { - input.parse::()?; - input.parse::()?; - //input.parse::]>()?; - (true, input.parse::()?) - } else if lookahead.peek(Token![=]) { - input.parse::()?; - input.parse::]>()?; - (false, input.parse::()?) - } else { - return Err(lookahead.error()); - }; - Ok(Self { special, code }) - } -} - -pub struct Rule { - pub modifier: Option, - pub type_: Option, - pub alternatives: Vec, -} - -impl Rule { - fn peek(input: syn::parse::ParseStream) -> bool { - DotModifier::peek(input) - || (input.peek(Ident) && input.peek2(Token![=]) && !input.peek3(Token![>])) - } - - fn parse_alternatives(input: syn::parse::ParseStream) -> syn::Result> { - let mut result = Vec::new(); - Self::parse_with_alternative(input, &mut result)?; - loop { - if !input.peek(Token![,]) { - break; - } - input.parse::()?; - Self::parse_with_alternative(input, &mut result)?; - } - Ok(result) - } - - fn parse_with_alternative( - input: &syn::parse::ParseBuffer, - result: &mut Vec, - ) -> Result<(), syn::Error> { - input.parse::()?; - let part1 = input.parse::()?; - if input.peek(token::Brace) { - result.push(DotModifier { - part1: part1.clone(), - part2: None, - }); - let suffix_content; - braced!(suffix_content in input); - let suffixes = Punctuated::::parse_separated_nonempty( - &suffix_content, - )?; - for part2 in suffixes { - result.push(DotModifier { - part1: part1.clone(), - part2: Some(part2), - }); - } - } else if IdentOrTypeSuffix::peek(input) { - let part2 = Some(IdentOrTypeSuffix::parse(input)?); - result.push(DotModifier { part1, part2 }); - } else { - result.push(DotModifier { part1, part2: None }); - } - Ok(()) - } -} - -#[derive(PartialEq, Eq, Hash, Clone)] -struct IdentOrTypeSuffix(IdentLike); - -impl IdentOrTypeSuffix { - fn span(&self) -> Span { - self.0.span() - } - - fn peek(input: syn::parse::ParseStream) -> bool { - input.peek(Token![::]) - } -} - -impl ToTokens for IdentOrTypeSuffix { - fn to_tokens(&self, tokens: &mut TokenStream) { - let ident = &self.0; - quote! 
{ :: #ident }.to_tokens(tokens) - } -} - -impl Parse for IdentOrTypeSuffix { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - input.parse::()?; - Ok(Self(input.parse::()?)) - } -} - -impl Parse for Rule { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let (modifier, type_) = if DotModifier::peek(input) { - let modifier = Some(input.parse::()?); - if input.peek(Token![:]) { - input.parse::()?; - (modifier, Some(input.parse::()?)) - } else { - (modifier, None) - } - } else { - (None, Some(input.parse::()?)) - }; - input.parse::()?; - let content; - braced!(content in input); - let alternatives = Self::parse_alternatives(&content)?; - Ok(Self { - modifier, - type_, - alternatives, - }) - } -} - -pub struct Instruction { - pub name: Ident, - pub modifiers: Vec, -} -impl Instruction { - fn peek(input: syn::parse::ParseStream) -> bool { - input.peek(Ident) - } -} - -impl Parse for Instruction { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let instruction = input.parse::()?; - let mut modifiers = Vec::new(); - loop { - if !MaybeDotModifier::peek(input) { - break; - } - modifiers.push(MaybeDotModifier::parse(input)?); - } - Ok(Self { - name: instruction, - modifiers, - }) - } -} - -pub struct MaybeDotModifier { - pub optional: bool, - pub modifier: DotModifier, -} - -impl MaybeDotModifier { - fn peek(input: syn::parse::ParseStream) -> bool { - input.peek(token::Brace) || DotModifier::peek(input) - } -} - -impl Parse for MaybeDotModifier { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - Ok(if input.peek(token::Brace) { - let content; - braced!(content in input); - let modifier = DotModifier::parse(&content)?; - Self { - modifier, - optional: true, - } - } else { - let modifier = DotModifier::parse(input)?; - Self { - modifier, - optional: false, - } - }) - } -} - -#[derive(PartialEq, Eq, Hash, Clone)] -pub struct DotModifier { - part1: IdentLike, - part2: Option, -} - -impl std::fmt::Display for DotModifier { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, ".")?; - self.part1.fmt(f)?; - if let Some(ref part2) = self.part2 { - write!(f, "::")?; - part2.0.fmt(f)?; - } - Ok(()) - } -} - -impl std::fmt::Debug for DotModifier { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - std::fmt::Display::fmt(&self, f) - } -} - -impl DotModifier { - pub fn span(&self) -> Span { - let part1 = self.part1.span(); - if let Some(ref part2) = self.part2 { - part1.join(part2.span()).unwrap_or(part1) - } else { - part1 - } - } - - pub fn ident(&self) -> Ident { - let mut result = String::new(); - write!(&mut result, "{}", self.part1).unwrap(); - if let Some(ref part2) = self.part2 { - write!(&mut result, "_{}", part2.0).unwrap(); - } else { - match self.part1 { - IdentLike::Type(_) | IdentLike::Const(_) | IdentLike::Async(_) => result.push('_'), - IdentLike::Ident(_) | IdentLike::Integer(_) => {} - } - } - Ident::new(&result.to_ascii_lowercase(), self.span()) - } - - pub fn variant_capitalized(&self) -> Ident { - self.capitalized_impl(String::new()) - } - - pub fn dot_capitalized(&self) -> Ident { - self.capitalized_impl("Dot".to_string()) - } - - fn capitalized_impl(&self, prefix: String) -> Ident { - let mut temp = String::new(); - write!(&mut temp, "{}", &self.part1).unwrap(); - if let Some(IdentOrTypeSuffix(ref part2)) = self.part2 { - write!(&mut temp, "_{}", part2).unwrap(); - } - let mut result = prefix; - let mut capitalize = true; - for c in temp.chars() { - if c == '_' { - capitalize = true; - 
continue; - } - // Special hack to emit `BF16`` instead of `Bf16`` - let c = if capitalize || c == 'f' && result.ends_with('B') { - capitalize = false; - c.to_ascii_uppercase() - } else { - c - }; - result.push(c); - } - Ident::new(&result, self.span()) - } - - pub fn tokens(&self) -> TokenStream { - let part1 = &self.part1; - let part2 = &self.part2; - match self.part2 { - None => quote! { . #part1 }, - Some(_) => quote! { . #part1 #part2 }, - } - } - - fn peek(input: syn::parse::ParseStream) -> bool { - input.peek(Token![.]) - } -} - -impl Parse for DotModifier { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - input.parse::()?; - let part1 = input.parse::()?; - if IdentOrTypeSuffix::peek(input) { - let part2 = Some(IdentOrTypeSuffix::parse(input)?); - Ok(Self { part1, part2 }) - } else { - Ok(Self { part1, part2: None }) - } - } -} - -#[derive(PartialEq, Eq)] -pub struct HyphenatedIdent { - idents: Punctuated, -} - -impl HyphenatedIdent { - fn span(&self) -> Span { - self.idents.span() - } - - pub fn ident(&self) -> Ident { - Ident::new(&self.to_string().to_string(), self.span()) - } -} - -impl std::fmt::Display for HyphenatedIdent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut idents = self.idents.iter(); - - if let Some(id) = idents.next() { - write!(f, "{}", id)?; - } - - for id in idents { - write!(f, "_{}", id)?; - } - - Ok(()) - } - -} - -impl Parse for HyphenatedIdent { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let idents = Punctuated::parse_separated_nonempty(input)?; - Ok(Self { idents }) - } -} - -#[derive(PartialEq, Eq, Hash, Clone)] -enum IdentLike { - Type(Token![type]), - Const(Token![const]), - Async(Token![async]), - Ident(Ident), - Integer(LitInt), -} - -impl IdentLike { - fn span(&self) -> Span { - match self { - IdentLike::Type(c) => c.span(), - IdentLike::Const(t) => t.span(), - IdentLike::Async(a) => a.span(), - IdentLike::Ident(i) => i.span(), - IdentLike::Integer(l) => l.span(), - } - } -} - -impl std::fmt::Display for IdentLike { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - IdentLike::Type(_) => f.write_str("type"), - IdentLike::Const(_) => f.write_str("const"), - IdentLike::Async(_) => f.write_str("async"), - IdentLike::Ident(ident) => write!(f, "{}", ident), - IdentLike::Integer(integer) => write!(f, "{}", integer), - } - } -} - -impl ToTokens for IdentLike { - fn to_tokens(&self, tokens: &mut TokenStream) { - match self { - IdentLike::Type(_) => quote! { type }.to_tokens(tokens), - IdentLike::Const(_) => quote! { const }.to_tokens(tokens), - IdentLike::Async(_) => quote! { async }.to_tokens(tokens), - IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens), - IdentLike::Integer(int) => quote! { #int }.to_tokens(tokens), - } - } -} - -impl Parse for IdentLike { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let lookahead = input.lookahead1(); - Ok(if lookahead.peek(Token![const]) { - IdentLike::Const(input.parse::()?) - } else if lookahead.peek(Token![type]) { - IdentLike::Type(input.parse::()?) - } else if lookahead.peek(Token![async]) { - IdentLike::Async(input.parse::()?) - } else if lookahead.peek(Ident) { - IdentLike::Ident(input.parse::()?) - } else if lookahead.peek(LitInt) { - IdentLike::Integer(input.parse::()?) 
- } else { - return Err(lookahead.error()); - }) - } -} - -// Arguments declaration can loook like this: -// a{, b} -// That's why we don't parse Arguments as Punctuated -#[derive(PartialEq, Eq)] -pub struct Arguments(pub Vec); - -impl Parse for Arguments { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let mut result = Vec::new(); - loop { - if input.peek(Token![,]) { - input.parse::()?; - } - let mut optional = false; - let mut can_be_negated = false; - let mut pre_pipe = false; - let ident; - let lookahead = input.lookahead1(); - if lookahead.peek(token::Brace) { - let content; - braced!(content in input); - let lookahead = content.lookahead1(); - if lookahead.peek(Token![!]) { - content.parse::()?; - can_be_negated = true; - ident = input.parse::()?; - } else if lookahead.peek(Token![,]) { - optional = true; - content.parse::()?; - ident = content.parse::()?; - } else { - return Err(lookahead.error()); - } - } else if lookahead.peek(token::Bracket) { - let bracketed; - bracketed!(bracketed in input); - if bracketed.peek(Token![|]) { - optional = true; - bracketed.parse::()?; - pre_pipe = true; - ident = bracketed.parse::()?; - } else { - let mut sub_args = Self::parse(&bracketed)?; - sub_args.0.first_mut().unwrap().pre_bracket = true; - sub_args.0.last_mut().unwrap().post_bracket = true; - if peek_brace_token(input, Token![.]) { - let optional_suffix; - braced!(optional_suffix in input); - optional_suffix.parse::()?; - let unified_ident = optional_suffix.parse::()?; - if unified_ident.to_string() != "unified" { - return Err(syn::Error::new( - unified_ident.span(), - format!("Expected `unified`, got `{}`", unified_ident), - )); - } - for a in sub_args.0.iter_mut() { - a.unified = true; - } - } - result.extend(sub_args.0); - continue; - } - } else if lookahead.peek(Ident) { - ident = input.parse::()?; - } else if lookahead.peek(Token![|]) { - input.parse::()?; - pre_pipe = true; - ident = input.parse::()?; - } else { - break; - } - result.push(Argument { - optional, - pre_pipe, - can_be_negated, - pre_bracket: false, - ident, - post_bracket: false, - unified: false, - }); - } - Ok(Self(result)) - } -} - -// This is effectively input.peek(token::Brace) && input.peek2(Token![.]) -// input.peek2 is supposed to skip over next token, but it skips over whole -// braced token group. Not sure if it's a bug -fn peek_brace_token(input: syn::parse::ParseStream, _t: T) -> bool { - use syn::token::Token; - let cursor = input.cursor(); - cursor - .group(proc_macro2::Delimiter::Brace) - .map_or(false, |(content, ..)| T::Token::peek(content)) -} - -#[derive(PartialEq, Eq)] -pub struct Argument { - pub optional: bool, - pub pre_bracket: bool, - pub pre_pipe: bool, - pub can_be_negated: bool, - pub ident: HyphenatedIdent, - pub post_bracket: bool, - pub unified: bool, -} - -#[cfg(test)] -mod tests { - use super::{Arguments, DotModifier, MaybeDotModifier}; - use quote::{quote, ToTokens}; - - #[test] - fn parse_modifier_complex() { - let input = quote! { - .level::eviction_priority - }; - let modifier = syn::parse2::(input).unwrap(); - assert_eq!( - ". level :: eviction_priority", - modifier.tokens().to_string() - ); - } - - #[test] - fn parse_modifier_optional() { - let input = quote! { - { .level::eviction_priority } - }; - let maybe_modifider = syn::parse2::(input).unwrap(); - assert_eq!( - ". level :: eviction_priority", - maybe_modifider.modifier.tokens().to_string() - ); - assert!(maybe_modifider.optional); - } - - #[test] - fn parse_type_token() { - let input = quote! { - . 
type - }; - let maybe_modifier = syn::parse2::(input).unwrap(); - assert_eq!(". type", maybe_modifier.modifier.tokens().to_string()); - assert!(!maybe_modifier.optional); - } - - #[test] - fn arguments_memory() { - let input = quote! { - [a], b - }; - let arguments = syn::parse2::(input).unwrap(); - let a = &arguments.0[0]; - assert!(!a.optional); - assert_eq!("a", a.ident.to_string()); - assert!(a.pre_bracket); - assert!(!a.pre_pipe); - assert!(a.post_bracket); - assert!(!a.can_be_negated); - let b = &arguments.0[1]; - assert!(!b.optional); - assert_eq!("b", b.ident.to_string()); - assert!(!b.pre_bracket); - assert!(!b.pre_pipe); - assert!(!b.post_bracket); - assert!(!b.can_be_negated); - } - - #[test] - fn arguments_optional() { - let input = quote! { - b{, cache_policy} - }; - let arguments = syn::parse2::(input).unwrap(); - let b = &arguments.0[0]; - assert!(!b.optional); - assert_eq!("b", b.ident.to_string()); - assert!(!b.pre_bracket); - assert!(!b.pre_pipe); - assert!(!b.post_bracket); - assert!(!b.can_be_negated); - let cache_policy = &arguments.0[1]; - assert!(cache_policy.optional); - assert_eq!("cache_policy", cache_policy.ident.to_string()); - assert!(!cache_policy.pre_bracket); - assert!(!cache_policy.pre_pipe); - assert!(!cache_policy.post_bracket); - assert!(!cache_policy.can_be_negated); - } - - #[test] - fn arguments_optional_pred() { - let input = quote! { - p[|q], a - }; - let arguments = syn::parse2::(input).unwrap(); - assert_eq!(arguments.0.len(), 3); - let p = &arguments.0[0]; - assert!(!p.optional); - assert_eq!("p", p.ident.to_string()); - assert!(!p.pre_bracket); - assert!(!p.pre_pipe); - assert!(!p.post_bracket); - assert!(!p.can_be_negated); - let q = &arguments.0[1]; - assert!(q.optional); - assert_eq!("q", q.ident.to_string()); - assert!(!q.pre_bracket); - assert!(q.pre_pipe); - assert!(!q.post_bracket); - assert!(!q.can_be_negated); - let a = &arguments.0[2]; - assert!(!a.optional); - assert_eq!("a", a.ident.to_string()); - assert!(!a.pre_bracket); - assert!(!a.pre_pipe); - assert!(!a.post_bracket); - assert!(!a.can_be_negated); - } - - #[test] - fn arguments_optional_with_negate() { - let input = quote! { - b, {!}c - }; - let arguments = syn::parse2::(input).unwrap(); - assert_eq!(arguments.0.len(), 2); - let b = &arguments.0[0]; - assert!(!b.optional); - assert_eq!("b", b.ident.to_string()); - assert!(!b.pre_bracket); - assert!(!b.pre_pipe); - assert!(!b.post_bracket); - assert!(!b.can_be_negated); - let c = &arguments.0[1]; - assert!(!c.optional); - assert_eq!("c", c.ident.to_string()); - assert!(!c.pre_bracket); - assert!(!c.pre_pipe); - assert!(!c.post_bracket); - assert!(c.can_be_negated); - } - - #[test] - fn arguments_tex() { - let input = quote! 
{ - d[|p], [a{, b}, c], dpdx, dpdy {, e} - }; - let arguments = syn::parse2::(input).unwrap(); - assert_eq!(arguments.0.len(), 8); - { - let d = &arguments.0[0]; - assert!(!d.optional); - assert_eq!("d", d.ident.to_string()); - assert!(!d.pre_bracket); - assert!(!d.pre_pipe); - assert!(!d.post_bracket); - assert!(!d.can_be_negated); - } - { - let p = &arguments.0[1]; - assert!(p.optional); - assert_eq!("p", p.ident.to_string()); - assert!(!p.pre_bracket); - assert!(p.pre_pipe); - assert!(!p.post_bracket); - assert!(!p.can_be_negated); - } - { - let a = &arguments.0[2]; - assert!(!a.optional); - assert_eq!("a", a.ident.to_string()); - assert!(a.pre_bracket); - assert!(!a.pre_pipe); - assert!(!a.post_bracket); - assert!(!a.can_be_negated); - } - { - let b = &arguments.0[3]; - assert!(b.optional); - assert_eq!("b", b.ident.to_string()); - assert!(!b.pre_bracket); - assert!(!b.pre_pipe); - assert!(!b.post_bracket); - assert!(!b.can_be_negated); - } - { - let c = &arguments.0[4]; - assert!(!c.optional); - assert_eq!("c", c.ident.to_string()); - assert!(!c.pre_bracket); - assert!(!c.pre_pipe); - assert!(c.post_bracket); - assert!(!c.can_be_negated); - } - { - let dpdx = &arguments.0[5]; - assert!(!dpdx.optional); - assert_eq!("dpdx", dpdx.ident.to_string()); - assert!(!dpdx.pre_bracket); - assert!(!dpdx.pre_pipe); - assert!(!dpdx.post_bracket); - assert!(!dpdx.can_be_negated); - } - { - let dpdy = &arguments.0[6]; - assert!(!dpdy.optional); - assert_eq!("dpdy", dpdy.ident.to_string()); - assert!(!dpdy.pre_bracket); - assert!(!dpdy.pre_pipe); - assert!(!dpdy.post_bracket); - assert!(!dpdy.can_be_negated); - } - { - let e = &arguments.0[7]; - assert!(e.optional); - assert_eq!("e", e.ident.to_string()); - assert!(!e.pre_bracket); - assert!(!e.pre_pipe); - assert!(!e.post_bracket); - assert!(!e.can_be_negated); - } - } - - #[test] - fn rule_multi() { - let input = quote! { - .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} } - }; - let rule = syn::parse2::(input).unwrap(); - assert_eq!(". ss", rule.modifier.unwrap().tokens().to_string()); - assert_eq!( - "StateSpace", - rule.type_.unwrap().to_token_stream().to_string() - ); - let alts = rule - .alternatives - .iter() - .map(|m| m.tokens().to_string()) - .collect::>(); - assert_eq!( - vec![ - ". global", - ". local", - ". param", - ". param :: func", - ". shared", - ". shared :: cta", - ". shared :: cluster" - ], - alts - ); - } - - #[test] - fn rule_multi2() { - let input = quote! { - .cop: StCacheOperator = { .wb, .cg, .cs, .wt } - }; - let rule = syn::parse2::(input).unwrap(); - assert_eq!(". cop", rule.modifier.unwrap().tokens().to_string()); - assert_eq!( - "StCacheOperator", - rule.type_.unwrap().to_token_stream().to_string() - ); - let alts = rule - .alternatives - .iter() - .map(|m| m.tokens().to_string()) - .collect::>(); - assert_eq!(vec![". wb", ". cg", ". cs", ". wt",], alts); - } - - #[test] - fn args_unified() { - let input = quote! { - d, [a]{.unified}{, cache_policy} - }; - let args = syn::parse2::(input).unwrap(); - let a = &args.0[1]; - assert!(!a.optional); - assert_eq!("a", a.ident.to_string()); - assert!(a.pre_bracket); - assert!(!a.pre_pipe); - assert!(a.post_bracket); - assert!(!a.can_be_negated); - assert!(a.unified); - } - - #[test] - fn args_hyphenated() { - let input = quote! 
{ - d, cp-size, b - }; - let args = syn::parse2::(input).unwrap(); - let cp_size = &args.0[1]; - assert!(!cp_size.optional); - assert_eq!("cp_size", cp_size.ident.to_string()); - assert!(!cp_size.pre_bracket); - assert!(!cp_size.pre_pipe); - assert!(!cp_size.post_bracket); - assert!(!cp_size.can_be_negated); - assert!(!cp_size.unified); - } - - #[test] - fn special_block() { - let input = quote! { - bra <= { bra(stream) } - }; - syn::parse2::(input).unwrap(); - } -} +use proc_macro2::Span; +use proc_macro2::TokenStream; +use quote::quote; +use quote::ToTokens; +use rustc_hash::FxHashMap; +use std::fmt::Write; +use syn::bracketed; +use syn::parse::Peek; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::LitInt; +use syn::Type; +use syn::{braced, parse::Parse, token, Ident, ItemEnum, Token}; + +pub struct ParseDefinitions { + pub token_type: ItemEnum, + pub additional_enums: FxHashMap, + pub definitions: Vec, +} + +impl Parse for ParseDefinitions { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let token_type = input.parse::()?; + let mut additional_enums = FxHashMap::default(); + let mut definitions = Vec::new(); + loop { + if input.is_empty() { + break; + } + + let lookahead = input.lookahead1(); + if lookahead.peek(Token![#]) { + let enum_ = input.parse::()?; + additional_enums.insert(enum_.ident.clone(), enum_); + } else if lookahead.peek(Ident) { + definitions.push(input.parse::()?); + } else { + return Err(lookahead.error()); + } + } + + Ok(Self { + token_type, + additional_enums, + definitions, + }) + } +} + +pub struct OpcodeDefinition(pub Patterns, pub Vec); + +impl Parse for OpcodeDefinition { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let patterns = input.parse::()?; + let mut rules = Vec::new(); + while Rule::peek(input) { + rules.push(input.parse::()?); + input.parse::()?; + } + Ok(Self(patterns, rules)) + } +} + +pub struct Patterns(pub Vec<(OpcodeDecl, CodeBlock)>); + +impl Parse for Patterns { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut result = Vec::new(); + loop { + if !OpcodeDecl::peek(input) { + break; + } + let decl = input.parse::()?; + let code_block = input.parse::()?; + result.push((decl, code_block)) + } + Ok(Self(result)) + } +} + +pub struct OpcodeDecl(pub Instruction, pub Arguments); + +impl OpcodeDecl { + fn peek(input: syn::parse::ParseStream) -> bool { + Instruction::peek(input) && !input.peek2(Token![=]) + } +} + +impl Parse for OpcodeDecl { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(Self( + input.parse::()?, + input.parse::()?, + )) + } +} + +pub struct CodeBlock { + pub special: bool, + pub code: proc_macro2::Group, +} + +impl Parse for CodeBlock { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + let (special, code) = if lookahead.peek(Token![<]) { + input.parse::()?; + input.parse::()?; + //input.parse::]>()?; + (true, input.parse::()?) + } else if lookahead.peek(Token![=]) { + input.parse::()?; + input.parse::]>()?; + (false, input.parse::()?) 
+ } else { + return Err(lookahead.error()); + }; + Ok(Self { special, code }) + } +} + +pub struct Rule { + pub modifier: Option, + pub type_: Option, + pub alternatives: Vec, +} + +impl Rule { + fn peek(input: syn::parse::ParseStream) -> bool { + DotModifier::peek(input) + || (input.peek(Ident) && input.peek2(Token![=]) && !input.peek3(Token![>])) + } + + fn parse_alternatives(input: syn::parse::ParseStream) -> syn::Result> { + let mut result = Vec::new(); + Self::parse_with_alternative(input, &mut result)?; + loop { + if !input.peek(Token![,]) { + break; + } + input.parse::()?; + Self::parse_with_alternative(input, &mut result)?; + } + Ok(result) + } + + fn parse_with_alternative( + input: &syn::parse::ParseBuffer, + result: &mut Vec, + ) -> Result<(), syn::Error> { + input.parse::()?; + let part1 = input.parse::()?; + if input.peek(token::Brace) { + result.push(DotModifier { + part1: part1.clone(), + part2: None, + }); + let suffix_content; + braced!(suffix_content in input); + let suffixes = Punctuated::::parse_separated_nonempty( + &suffix_content, + )?; + for part2 in suffixes { + result.push(DotModifier { + part1: part1.clone(), + part2: Some(part2), + }); + } + } else if IdentOrTypeSuffix::peek(input) { + let part2 = Some(IdentOrTypeSuffix::parse(input)?); + result.push(DotModifier { part1, part2 }); + } else { + result.push(DotModifier { part1, part2: None }); + } + Ok(()) + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +struct IdentOrTypeSuffix(IdentLike); + +impl IdentOrTypeSuffix { + fn span(&self) -> Span { + self.0.span() + } + + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(Token![::]) + } +} + +impl ToTokens for IdentOrTypeSuffix { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = &self.0; + quote! { :: #ident }.to_tokens(tokens) + } +} + +impl Parse for IdentOrTypeSuffix { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + input.parse::()?; + Ok(Self(input.parse::()?)) + } +} + +impl Parse for Rule { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let (modifier, type_) = if DotModifier::peek(input) { + let modifier = Some(input.parse::()?); + if input.peek(Token![:]) { + input.parse::()?; + (modifier, Some(input.parse::()?)) + } else { + (modifier, None) + } + } else { + (None, Some(input.parse::()?)) + }; + input.parse::()?; + let content; + braced!(content in input); + let alternatives = Self::parse_alternatives(&content)?; + Ok(Self { + modifier, + type_, + alternatives, + }) + } +} + +pub struct Instruction { + pub name: Ident, + pub modifiers: Vec, +} +impl Instruction { + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(Ident) + } +} + +impl Parse for Instruction { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let instruction = input.parse::()?; + let mut modifiers = Vec::new(); + loop { + if !MaybeDotModifier::peek(input) { + break; + } + modifiers.push(MaybeDotModifier::parse(input)?); + } + Ok(Self { + name: instruction, + modifiers, + }) + } +} + +pub struct MaybeDotModifier { + pub optional: bool, + pub modifier: DotModifier, +} + +impl MaybeDotModifier { + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(token::Brace) || DotModifier::peek(input) + } +} + +impl Parse for MaybeDotModifier { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(if input.peek(token::Brace) { + let content; + braced!(content in input); + let modifier = DotModifier::parse(&content)?; + Self { + modifier, + optional: true, + } + } else { + let modifier = 
DotModifier::parse(input)?; + Self { + modifier, + optional: false, + } + }) + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +pub struct DotModifier { + part1: IdentLike, + part2: Option, +} + +impl std::fmt::Display for DotModifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, ".")?; + self.part1.fmt(f)?; + if let Some(ref part2) = self.part2 { + write!(f, "::")?; + part2.0.fmt(f)?; + } + Ok(()) + } +} + +impl std::fmt::Debug for DotModifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self, f) + } +} + +impl DotModifier { + pub fn span(&self) -> Span { + let part1 = self.part1.span(); + if let Some(ref part2) = self.part2 { + part1.join(part2.span()).unwrap_or(part1) + } else { + part1 + } + } + + pub fn ident(&self) -> Ident { + let mut result = String::new(); + write!(&mut result, "{}", self.part1).unwrap(); + if let Some(ref part2) = self.part2 { + write!(&mut result, "_{}", part2.0).unwrap(); + } else { + match self.part1 { + IdentLike::Type(_) | IdentLike::Const(_) | IdentLike::Async(_) => result.push('_'), + IdentLike::Ident(_) | IdentLike::Integer(_) => {} + } + } + Ident::new(&result.to_ascii_lowercase(), self.span()) + } + + pub fn variant_capitalized(&self) -> Ident { + self.capitalized_impl(String::new()) + } + + pub fn dot_capitalized(&self) -> Ident { + self.capitalized_impl("Dot".to_string()) + } + + fn capitalized_impl(&self, prefix: String) -> Ident { + let mut temp = String::new(); + write!(&mut temp, "{}", &self.part1).unwrap(); + if let Some(IdentOrTypeSuffix(ref part2)) = self.part2 { + write!(&mut temp, "_{}", part2).unwrap(); + } + let mut result = prefix; + let mut capitalize = true; + for c in temp.chars() { + if c == '_' { + capitalize = true; + continue; + } + // Special hack to emit `BF16`` instead of `Bf16`` + let c = if capitalize || c == 'f' && result.ends_with('B') { + capitalize = false; + c.to_ascii_uppercase() + } else { + c + }; + result.push(c); + } + Ident::new(&result, self.span()) + } + + pub fn tokens(&self) -> TokenStream { + let part1 = &self.part1; + let part2 = &self.part2; + match self.part2 { + None => quote! { . #part1 }, + Some(_) => quote! { . 
#part1 #part2 }, + } + } + + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(Token![.]) + } +} + +impl Parse for DotModifier { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + input.parse::()?; + let part1 = input.parse::()?; + if IdentOrTypeSuffix::peek(input) { + let part2 = Some(IdentOrTypeSuffix::parse(input)?); + Ok(Self { part1, part2 }) + } else { + Ok(Self { part1, part2: None }) + } + } +} + +#[derive(PartialEq, Eq)] +pub struct HyphenatedIdent { + idents: Punctuated, +} + +impl HyphenatedIdent { + fn span(&self) -> Span { + self.idents.span() + } + + pub fn ident(&self) -> Ident { + Ident::new(&self.to_string().to_string(), self.span()) + } +} + +impl std::fmt::Display for HyphenatedIdent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut idents = self.idents.iter(); + + if let Some(id) = idents.next() { + write!(f, "{}", id)?; + } + + for id in idents { + write!(f, "_{}", id)?; + } + + Ok(()) + } +} + +impl Parse for HyphenatedIdent { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let idents = Punctuated::parse_separated_nonempty(input)?; + Ok(Self { idents }) + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +enum IdentLike { + Type(Token![type]), + Const(Token![const]), + Async(Token![async]), + Ident(Ident), + Integer(LitInt), +} + +impl IdentLike { + fn span(&self) -> Span { + match self { + IdentLike::Type(c) => c.span(), + IdentLike::Const(t) => t.span(), + IdentLike::Async(a) => a.span(), + IdentLike::Ident(i) => i.span(), + IdentLike::Integer(l) => l.span(), + } + } +} + +impl std::fmt::Display for IdentLike { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + IdentLike::Type(_) => f.write_str("type"), + IdentLike::Const(_) => f.write_str("const"), + IdentLike::Async(_) => f.write_str("async"), + IdentLike::Ident(ident) => write!(f, "{}", ident), + IdentLike::Integer(integer) => write!(f, "{}", integer), + } + } +} + +impl ToTokens for IdentLike { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + IdentLike::Type(_) => quote! { type }.to_tokens(tokens), + IdentLike::Const(_) => quote! { const }.to_tokens(tokens), + IdentLike::Async(_) => quote! { async }.to_tokens(tokens), + IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens), + IdentLike::Integer(int) => quote! { #int }.to_tokens(tokens), + } + } +} + +impl Parse for IdentLike { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + Ok(if lookahead.peek(Token![const]) { + IdentLike::Const(input.parse::()?) + } else if lookahead.peek(Token![type]) { + IdentLike::Type(input.parse::()?) + } else if lookahead.peek(Token![async]) { + IdentLike::Async(input.parse::()?) + } else if lookahead.peek(Ident) { + IdentLike::Ident(input.parse::()?) + } else if lookahead.peek(LitInt) { + IdentLike::Integer(input.parse::()?) 
+        } else {
+            return Err(lookahead.error());
+        })
+    }
+}
+
+// Arguments declaration can look like this:
+// a{, b}
+// That's why we don't parse Arguments as Punctuated
+#[derive(PartialEq, Eq)]
+pub struct Arguments(pub Vec<Argument>);
+
+impl Parse for Arguments {
+    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+        let mut result = Vec::new();
+        loop {
+            if input.peek(Token![,]) {
+                input.parse::<Token![,]>()?;
+            }
+            let mut optional = false;
+            let mut can_be_negated = false;
+            let mut pre_pipe = false;
+            let ident;
+            let lookahead = input.lookahead1();
+            if lookahead.peek(token::Brace) {
+                let content;
+                braced!(content in input);
+                let lookahead = content.lookahead1();
+                if lookahead.peek(Token![!]) {
+                    content.parse::<Token![!]>()?;
+                    can_be_negated = true;
+                    ident = input.parse::<HyphenatedIdent>()?;
+                } else if lookahead.peek(Token![,]) {
+                    optional = true;
+                    content.parse::<Token![,]>()?;
+                    ident = content.parse::<HyphenatedIdent>()?;
+                } else {
+                    return Err(lookahead.error());
+                }
+            } else if lookahead.peek(token::Bracket) {
+                let bracketed;
+                bracketed!(bracketed in input);
+                if bracketed.peek(Token![|]) {
+                    optional = true;
+                    bracketed.parse::<Token![|]>()?;
+                    pre_pipe = true;
+                    ident = bracketed.parse::<HyphenatedIdent>()?;
+                } else {
+                    let mut sub_args = Self::parse(&bracketed)?;
+                    sub_args.0.first_mut().unwrap().pre_bracket = true;
+                    sub_args.0.last_mut().unwrap().post_bracket = true;
+                    if peek_brace_token(input, Token![.]) {
+                        let optional_suffix;
+                        braced!(optional_suffix in input);
+                        optional_suffix.parse::<Token![.]>()?;
+                        let unified_ident = optional_suffix.parse::<Ident>()?;
+                        if unified_ident.to_string() != "unified" {
+                            return Err(syn::Error::new(
+                                unified_ident.span(),
+                                format!("Expected `unified`, got `{}`", unified_ident),
+                            ));
+                        }
+                        for a in sub_args.0.iter_mut() {
+                            a.unified = true;
+                        }
+                    }
+                    result.extend(sub_args.0);
+                    continue;
+                }
+            } else if lookahead.peek(Ident) {
+                ident = input.parse::<HyphenatedIdent>()?;
+            } else if lookahead.peek(Token![|]) {
+                input.parse::<Token![|]>()?;
+                pre_pipe = true;
+                ident = input.parse::<HyphenatedIdent>()?;
+            } else {
+                break;
+            }
+            result.push(Argument {
+                optional,
+                pre_pipe,
+                can_be_negated,
+                pre_bracket: false,
+                ident,
+                post_bracket: false,
+                unified: false,
+            });
+        }
+        Ok(Self(result))
+    }
+}
+
+// This is effectively input.peek(token::Brace) && input.peek2(Token![.])
+// input.peek2 is supposed to skip over the next token, but it skips over the whole
+// braced token group. Not sure if it's a bug
+fn peek_brace_token<T: Peek>(input: syn::parse::ParseStream, _t: T) -> bool {
+    use syn::token::Token;
+    let cursor = input.cursor();
+    cursor
+        .group(proc_macro2::Delimiter::Brace)
+        .map_or(false, |(content, ..)| T::Token::peek(content))
+}
+
+#[derive(PartialEq, Eq)]
+pub struct Argument {
+    pub optional: bool,
+    pub pre_bracket: bool,
+    pub pre_pipe: bool,
+    pub can_be_negated: bool,
+    pub ident: HyphenatedIdent,
+    pub post_bracket: bool,
+    pub unified: bool,
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{Arguments, DotModifier, MaybeDotModifier};
+    use quote::{quote, ToTokens};
+
+    #[test]
+    fn parse_modifier_complex() {
+        let input = quote! {
+            .level::eviction_priority
+        };
+        let modifier = syn::parse2::<DotModifier>(input).unwrap();
+        assert_eq!(
+            ". level :: eviction_priority",
+            modifier.tokens().to_string()
+        );
+    }
+
+    #[test]
+    fn parse_modifier_optional() {
+        let input = quote! {
+            { .level::eviction_priority }
+        };
+        let maybe_modifier = syn::parse2::<MaybeDotModifier>(input).unwrap();
+        assert_eq!(
+            ". level :: eviction_priority",
+            maybe_modifier.modifier.tokens().to_string()
+        );
+        assert!(maybe_modifier.optional);
+    }
+
+    #[test]
+    fn parse_type_token() {
+        let input = quote! {
+            . type
+        };
+        let maybe_modifier = syn::parse2::<MaybeDotModifier>(input).unwrap();
+        assert_eq!(". type", maybe_modifier.modifier.tokens().to_string());
+        assert!(!maybe_modifier.optional);
+    }
+
+    #[test]
+    fn arguments_memory() {
+        let input = quote! {
+            [a], b
+        };
+        let arguments = syn::parse2::<Arguments>(input).unwrap();
+        let a = &arguments.0[0];
+        assert!(!a.optional);
+        assert_eq!("a", a.ident.to_string());
+        assert!(a.pre_bracket);
+        assert!(!a.pre_pipe);
+        assert!(a.post_bracket);
+        assert!(!a.can_be_negated);
+        let b = &arguments.0[1];
+        assert!(!b.optional);
+        assert_eq!("b", b.ident.to_string());
+        assert!(!b.pre_bracket);
+        assert!(!b.pre_pipe);
+        assert!(!b.post_bracket);
+        assert!(!b.can_be_negated);
+    }
+
+    #[test]
+    fn arguments_optional() {
+        let input = quote! {
+            b{, cache_policy}
+        };
+        let arguments = syn::parse2::<Arguments>(input).unwrap();
+        let b = &arguments.0[0];
+        assert!(!b.optional);
+        assert_eq!("b", b.ident.to_string());
+        assert!(!b.pre_bracket);
+        assert!(!b.pre_pipe);
+        assert!(!b.post_bracket);
+        assert!(!b.can_be_negated);
+        let cache_policy = &arguments.0[1];
+        assert!(cache_policy.optional);
+        assert_eq!("cache_policy", cache_policy.ident.to_string());
+        assert!(!cache_policy.pre_bracket);
+        assert!(!cache_policy.pre_pipe);
+        assert!(!cache_policy.post_bracket);
+        assert!(!cache_policy.can_be_negated);
+    }
+
+    #[test]
+    fn arguments_optional_pred() {
+        let input = quote! {
+            p[|q], a
+        };
+        let arguments = syn::parse2::<Arguments>(input).unwrap();
+        assert_eq!(arguments.0.len(), 3);
+        let p = &arguments.0[0];
+        assert!(!p.optional);
+        assert_eq!("p", p.ident.to_string());
+        assert!(!p.pre_bracket);
+        assert!(!p.pre_pipe);
+        assert!(!p.post_bracket);
+        assert!(!p.can_be_negated);
+        let q = &arguments.0[1];
+        assert!(q.optional);
+        assert_eq!("q", q.ident.to_string());
+        assert!(!q.pre_bracket);
+        assert!(q.pre_pipe);
+        assert!(!q.post_bracket);
+        assert!(!q.can_be_negated);
+        let a = &arguments.0[2];
+        assert!(!a.optional);
+        assert_eq!("a", a.ident.to_string());
+        assert!(!a.pre_bracket);
+        assert!(!a.pre_pipe);
+        assert!(!a.post_bracket);
+        assert!(!a.can_be_negated);
+    }
+
+    #[test]
+    fn arguments_optional_with_negate() {
+        let input = quote! {
+            b, {!}c
+        };
+        let arguments = syn::parse2::<Arguments>(input).unwrap();
+        assert_eq!(arguments.0.len(), 2);
+        let b = &arguments.0[0];
+        assert!(!b.optional);
+        assert_eq!("b", b.ident.to_string());
+        assert!(!b.pre_bracket);
+        assert!(!b.pre_pipe);
+        assert!(!b.post_bracket);
+        assert!(!b.can_be_negated);
+        let c = &arguments.0[1];
+        assert!(!c.optional);
+        assert_eq!("c", c.ident.to_string());
+        assert!(!c.pre_bracket);
+        assert!(!c.pre_pipe);
+        assert!(!c.post_bracket);
+        assert!(c.can_be_negated);
+    }
+
+    #[test]
+    fn arguments_tex() {
+        let input = quote!
{ + d[|p], [a{, b}, c], dpdx, dpdy {, e} + }; + let arguments = syn::parse2::(input).unwrap(); + assert_eq!(arguments.0.len(), 8); + { + let d = &arguments.0[0]; + assert!(!d.optional); + assert_eq!("d", d.ident.to_string()); + assert!(!d.pre_bracket); + assert!(!d.pre_pipe); + assert!(!d.post_bracket); + assert!(!d.can_be_negated); + } + { + let p = &arguments.0[1]; + assert!(p.optional); + assert_eq!("p", p.ident.to_string()); + assert!(!p.pre_bracket); + assert!(p.pre_pipe); + assert!(!p.post_bracket); + assert!(!p.can_be_negated); + } + { + let a = &arguments.0[2]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(a.pre_bracket); + assert!(!a.pre_pipe); + assert!(!a.post_bracket); + assert!(!a.can_be_negated); + } + { + let b = &arguments.0[3]; + assert!(b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + } + { + let c = &arguments.0[4]; + assert!(!c.optional); + assert_eq!("c", c.ident.to_string()); + assert!(!c.pre_bracket); + assert!(!c.pre_pipe); + assert!(c.post_bracket); + assert!(!c.can_be_negated); + } + { + let dpdx = &arguments.0[5]; + assert!(!dpdx.optional); + assert_eq!("dpdx", dpdx.ident.to_string()); + assert!(!dpdx.pre_bracket); + assert!(!dpdx.pre_pipe); + assert!(!dpdx.post_bracket); + assert!(!dpdx.can_be_negated); + } + { + let dpdy = &arguments.0[6]; + assert!(!dpdy.optional); + assert_eq!("dpdy", dpdy.ident.to_string()); + assert!(!dpdy.pre_bracket); + assert!(!dpdy.pre_pipe); + assert!(!dpdy.post_bracket); + assert!(!dpdy.can_be_negated); + } + { + let e = &arguments.0[7]; + assert!(e.optional); + assert_eq!("e", e.ident.to_string()); + assert!(!e.pre_bracket); + assert!(!e.pre_pipe); + assert!(!e.post_bracket); + assert!(!e.can_be_negated); + } + } + + #[test] + fn rule_multi() { + let input = quote! { + .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} } + }; + let rule = syn::parse2::(input).unwrap(); + assert_eq!(". ss", rule.modifier.unwrap().tokens().to_string()); + assert_eq!( + "StateSpace", + rule.type_.unwrap().to_token_stream().to_string() + ); + let alts = rule + .alternatives + .iter() + .map(|m| m.tokens().to_string()) + .collect::>(); + assert_eq!( + vec![ + ". global", + ". local", + ". param", + ". param :: func", + ". shared", + ". shared :: cta", + ". shared :: cluster" + ], + alts + ); + } + + #[test] + fn rule_multi2() { + let input = quote! { + .cop: StCacheOperator = { .wb, .cg, .cs, .wt } + }; + let rule = syn::parse2::(input).unwrap(); + assert_eq!(". cop", rule.modifier.unwrap().tokens().to_string()); + assert_eq!( + "StCacheOperator", + rule.type_.unwrap().to_token_stream().to_string() + ); + let alts = rule + .alternatives + .iter() + .map(|m| m.tokens().to_string()) + .collect::>(); + assert_eq!(vec![". wb", ". cg", ". cs", ". wt",], alts); + } + + #[test] + fn args_unified() { + let input = quote! { + d, [a]{.unified}{, cache_policy} + }; + let args = syn::parse2::(input).unwrap(); + let a = &args.0[1]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(a.pre_bracket); + assert!(!a.pre_pipe); + assert!(a.post_bracket); + assert!(!a.can_be_negated); + assert!(a.unified); + } + + #[test] + fn args_hyphenated() { + let input = quote! 
{ + d, cp-size, b + }; + let args = syn::parse2::(input).unwrap(); + let cp_size = &args.0[1]; + assert!(!cp_size.optional); + assert_eq!("cp_size", cp_size.ident.to_string()); + assert!(!cp_size.pre_bracket); + assert!(!cp_size.pre_pipe); + assert!(!cp_size.post_bracket); + assert!(!cp_size.can_be_negated); + assert!(!cp_size.unified); + } + + #[test] + fn special_block() { + let input = quote! { + bra <= { bra(stream) } + }; + syn::parse2::(input).unwrap(); + } +} diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 429770a..c7f178d 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -1,7 +1,7 @@ +use super::{context, driver}; use cuda_types::cuda::*; use hip_runtime_sys::*; use std::{mem, ptr}; -use super::{driver, context}; const PROJECT_SUFFIX: &[u8] = b" [ZLUDA]\0"; pub const COMPUTE_CAPABILITY_MAJOR: i32 = 8; @@ -462,22 +462,21 @@ fn clamp_usize(x: usize) -> i32 { usize::min(x, i32::MAX as usize) as i32 } -pub(crate) fn get_primary_context(hip_dev: hipDevice_t) -> Result<(&'static context::Context, CUcontext), CUerror> { +pub(crate) fn get_primary_context( + hip_dev: hipDevice_t, +) -> Result<(&'static context::Context, CUcontext), CUerror> { let dev: &'static driver::Device = driver::device(hip_dev)?; Ok(dev.primary_context()) } -pub(crate) fn primary_context_retain( - pctx: &mut CUcontext, - hip_dev: hipDevice_t, -) -> CUresult { +pub(crate) fn primary_context_retain(pctx: &mut CUcontext, hip_dev: hipDevice_t) -> CUresult { let (ctx, cu_ctx) = get_primary_context(hip_dev)?; - + ctx.with_state_mut(|state: &mut context::ContextState| { state.ref_count += 1; Ok(()) })?; - + *pctx = cu_ctx; Ok(()) } @@ -497,8 +496,6 @@ pub(crate) fn primary_context_release(hip_dev: hipDevice_t) -> CUresult { pub(crate) fn primary_context_reset(hip_dev: hipDevice_t) -> CUresult { let (ctx, _) = get_primary_context(hip_dev)?; - ctx.with_state_mut(|state| { - state.reset() - })?; + ctx.with_state_mut(|state| state.reset())?; Ok(()) -} \ No newline at end of file +} diff --git a/zluda/src/impl/library.rs b/zluda/src/impl/library.rs index 2a5ad08..c5f60b1 100644 --- a/zluda/src/impl/library.rs +++ b/zluda/src/impl/library.rs @@ -38,10 +38,7 @@ pub(crate) unsafe fn unload(library: CUlibrary) -> CUresult { super::drop_checked::(library) } -pub(crate) unsafe fn get_module( - out: &mut CUmodule, - library: &Library, -) -> CUresult { - *out = module::Module{base: library.base}.wrap(); +pub(crate) unsafe fn get_module(out: &mut CUmodule, library: &Library) -> CUresult { + *out = module::Module { base: library.base }.wrap(); Ok(()) } diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index da37a22..797bec0 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -68,7 +68,9 @@ pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result CUre pub(crate) fn unload(hmod: CUmodule) -> CUresult { super::drop_checked::(hmod) - } pub(crate) fn get_function( diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs index b6ced82..0de04b1 100644 --- a/zluda/src/lib.rs +++ b/zluda/src/lib.rs @@ -1,11 +1,10 @@ use cuda_types::cuda::CUerror; use std::sync::atomic::{AtomicBool, Ordering}; - +pub(crate) mod r#impl; #[cfg_attr(windows, path = "os_win.rs")] #[cfg_attr(not(windows), path = "os_unix.rs")] mod os; -pub(crate) mod r#impl; static INITIALIZED: AtomicBool = AtomicBool::new(true); pub(crate) fn initialized() -> bool { @@ -66,61 +65,60 @@ macro_rules! 
implemented_in_function { cuda_macros::cuda_function_declarations!( unimplemented, - implemented <= [ - cuCtxCreate_v2, - cuCtxDestroy_v2, - cuCtxGetLimit, - cuCtxSetCurrent, - cuCtxGetCurrent, - cuCtxGetDevice, - cuCtxSetLimit, - cuCtxSynchronize, - cuCtxPushCurrent, - cuCtxPushCurrent_v2, - cuCtxPopCurrent, - cuCtxPopCurrent_v2, - cuDeviceComputeCapability, - cuDeviceGet, - cuDeviceGetAttribute, - cuDeviceGetCount, - cuDeviceGetLuid, - cuDeviceGetName, - cuDeviceGetProperties, - cuDeviceGetUuid, - cuDeviceGetUuid_v2, - cuDevicePrimaryCtxRelease, - cuDevicePrimaryCtxRetain, - cuDevicePrimaryCtxReset, - cuDeviceTotalMem_v2, - cuDriverGetVersion, - cuFuncGetAttribute, - cuGetExportTable, - cuGetProcAddress, - cuGetProcAddress_v2, - cuInit, - cuLibraryLoadData, - cuLibraryGetModule, - cuLibraryUnload, - cuMemAlloc_v2, - cuMemFree_v2, - cuMemHostAlloc, - cuMemFreeHost, - cuMemGetAddressRange_v2, - cuMemGetInfo_v2, - cuMemcpyDtoH_v2, - cuMemcpyHtoD_v2, - cuMemsetD32_v2, - cuMemsetD8_v2, - cuModuleGetFunction, - cuModuleGetLoadingMode, - cuModuleLoadData, - cuModuleUnload, - cuPointerGetAttribute, - cuStreamSynchronize, - cuProfilerStart, - cuProfilerStop, - ], - implemented_in_function <= [ - cuLaunchKernel, - ] + implemented + <= [ + cuCtxCreate_v2, + cuCtxDestroy_v2, + cuCtxGetLimit, + cuCtxSetCurrent, + cuCtxGetCurrent, + cuCtxGetDevice, + cuCtxSetLimit, + cuCtxSynchronize, + cuCtxPushCurrent, + cuCtxPushCurrent_v2, + cuCtxPopCurrent, + cuCtxPopCurrent_v2, + cuDeviceComputeCapability, + cuDeviceGet, + cuDeviceGetAttribute, + cuDeviceGetCount, + cuDeviceGetLuid, + cuDeviceGetName, + cuDeviceGetProperties, + cuDeviceGetUuid, + cuDeviceGetUuid_v2, + cuDevicePrimaryCtxRelease, + cuDevicePrimaryCtxRetain, + cuDevicePrimaryCtxReset, + cuDeviceTotalMem_v2, + cuDriverGetVersion, + cuFuncGetAttribute, + cuGetExportTable, + cuGetProcAddress, + cuGetProcAddress_v2, + cuInit, + cuLibraryLoadData, + cuLibraryGetModule, + cuLibraryUnload, + cuMemAlloc_v2, + cuMemFree_v2, + cuMemHostAlloc, + cuMemFreeHost, + cuMemGetAddressRange_v2, + cuMemGetInfo_v2, + cuMemcpyDtoH_v2, + cuMemcpyHtoD_v2, + cuMemsetD32_v2, + cuMemsetD8_v2, + cuModuleGetFunction, + cuModuleGetLoadingMode, + cuModuleLoadData, + cuModuleUnload, + cuPointerGetAttribute, + cuStreamSynchronize, + cuProfilerStart, + cuProfilerStop, + ], + implemented_in_function <= [cuLaunchKernel,] ); diff --git a/zluda/src/os_unix.rs b/zluda/src/os_unix.rs index e69de29..8b13789 100644 --- a/zluda/src/os_unix.rs +++ b/zluda/src/os_unix.rs @@ -0,0 +1 @@ + diff --git a/zluda_blas/src/impl.rs b/zluda_blas/src/impl.rs index feb95e2..55b1edc 100644 --- a/zluda_blas/src/impl.rs +++ b/zluda_blas/src/impl.rs @@ -10,15 +10,11 @@ pub(crate) fn unimplemented() -> cublasStatus_t { cublasStatus_t::ERROR_NOT_SUPPORTED } -pub(crate) fn get_status_name( - _status: cublasStatus_t, -) -> *const ::core::ffi::c_char { +pub(crate) fn get_status_name(_status: cublasStatus_t) -> *const ::core::ffi::c_char { todo!() } -pub(crate) fn get_status_string( - _status: cublasStatus_t, -) -> *const ::core::ffi::c_char { +pub(crate) fn get_status_string(_status: cublasStatus_t) -> *const ::core::ffi::c_char { todo!() } diff --git a/zluda_blas/src/lib.rs b/zluda_blas/src/lib.rs index ed86c01..09bb28f 100644 --- a/zluda_blas/src/lib.rs +++ b/zluda_blas/src/lib.rs @@ -28,10 +28,11 @@ macro_rules! 
implemented { cuda_macros::cublas_function_declarations!( unimplemented, - implemented <= [ - cublasGetStatusName, - cublasGetStatusString, - cublasXerbla, - cublasGetCudartVersion, - ] + implemented + <= [ + cublasGetStatusName, + cublasGetStatusString, + cublasXerbla, + cublasGetCudartVersion + ] ); diff --git a/zluda_blaslt/src/impl.rs b/zluda_blaslt/src/impl.rs index 8b67915..f66328d 100644 --- a/zluda_blaslt/src/impl.rs +++ b/zluda_blaslt/src/impl.rs @@ -31,8 +31,6 @@ pub(crate) fn get_cudart_version() -> usize { } #[allow(non_snake_case)] -pub(crate) fn disable_cpu_instructions_set_mask( - _mask: ::core::ffi::c_uint, -) -> ::core::ffi::c_uint { +pub(crate) fn disable_cpu_instructions_set_mask(_mask: ::core::ffi::c_uint) -> ::core::ffi::c_uint { todo!() } diff --git a/zluda_blaslt/src/lib.rs b/zluda_blaslt/src/lib.rs index 326ac0a..603e191 100644 --- a/zluda_blaslt/src/lib.rs +++ b/zluda_blaslt/src/lib.rs @@ -28,11 +28,12 @@ macro_rules! implemented { cuda_macros::cublaslt_function_declarations!( unimplemented, - implemented <= [ - cublasLtGetStatusName, - cublasLtGetStatusString, - cublasLtDisableCpuInstructionsSetMask, - cublasLtGetVersion, - cublasLtGetCudartVersion - ] + implemented + <= [ + cublasLtGetStatusName, + cublasLtGetStatusString, + cublasLtDisableCpuInstructionsSetMask, + cublasLtGetVersion, + cublasLtGetCudartVersion + ] ); diff --git a/zluda_dnn/src/lib.rs b/zluda_dnn/src/lib.rs index a744a59..897e106 100644 --- a/zluda_dnn/src/lib.rs +++ b/zluda_dnn/src/lib.rs @@ -28,11 +28,12 @@ macro_rules! implemented { cuda_macros::cudnn9_function_declarations!( unimplemented, - implemented <= [ - cudnnGetVersion, - cudnnGetMaxDeviceVersion, - cudnnGetCudartVersion, - cudnnGetErrorString, - cudnnGetLastErrorString - ] + implemented + <= [ + cudnnGetVersion, + cudnnGetMaxDeviceVersion, + cudnnGetCudartVersion, + cudnnGetErrorString, + cudnnGetLastErrorString + ] ); diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs index f405f72..c2acbcf 100644 --- a/zluda_dump/src/lib.rs +++ b/zluda_dump/src/lib.rs @@ -420,13 +420,13 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApiDump { CONTEXT_LOCAL_STORAGE_INTERFACE_V0301 { [0] = context_local_storage_put( context: cuda_types::cuda::CUcontext, - key: *mut std::ffi::c_void, + key: *mut std::ffi::c_void, value: *mut std::ffi::c_void, // clsContextDestroyCallback, have to be called on cuDevicePrimaryCtxReset dtor_cb: Option ) -> cuda_types::cuda::CUresult, [1] = context_local_storage_delete( @@ -434,9 +434,9 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApiDump { key: *mut std::ffi::c_void ) -> cuda_types::cuda::CUresult, [2] = context_local_storage_get( - value: *mut *mut std::ffi::c_void, + value: *mut *mut std::ffi::c_void, cu_ctx: cuda_types::cuda::CUcontext, - key: *mut std::ffi::c_void + key: *mut std::ffi::c_void ) -> cuda_types::cuda::CUresult } } diff --git a/zluda_fft/src/lib.rs b/zluda_fft/src/lib.rs index 8961714..7de894e 100644 --- a/zluda_fft/src/lib.rs +++ b/zluda_fft/src/lib.rs @@ -13,6 +13,4 @@ macro_rules! 
unimplemented { }; } -cuda_macros::cufft_function_declarations!( - unimplemented -); +cuda_macros::cufft_function_declarations!(unimplemented); diff --git a/zluda_ml/src/impl.rs b/zluda_ml/src/impl.rs index accb048..369b868 100644 --- a/zluda_ml/src/impl.rs +++ b/zluda_ml/src/impl.rs @@ -1,41 +1,39 @@ -use cuda_types::nvml::*; -use std::{ffi::CStr, ptr}; - -#[cfg(debug_assertions)] -pub(crate) fn unimplemented() -> nvmlReturn_t { - unimplemented!() -} - -#[cfg(not(debug_assertions))] -pub(crate) fn unimplemented() -> nvmlReturn_t { - nvmlReturn_t::ERROR_NOT_SUPPORTED -} - -pub(crate) fn error_string( - _result: cuda_types::nvml::nvmlReturn_t, -) -> *const ::core::ffi::c_char { - c"".as_ptr() -} - -pub(crate) fn init_v2() -> cuda_types::nvml::nvmlReturn_t { - nvmlReturn_t::SUCCESS -} - -const VERSION: &'static CStr = c"550.77"; - -pub(crate) fn system_get_driver_version( - result: *mut ::core::ffi::c_char, - length: ::core::ffi::c_uint, -) -> cuda_types::nvml::nvmlReturn_t { - if result == ptr::null_mut() { - return nvmlReturn_t::ERROR_INVALID_ARGUMENT; - } - let version = VERSION.to_bytes_with_nul(); - let copy_length = usize::min(length as usize, version.len()); - let slice = unsafe { std::slice::from_raw_parts_mut(result.cast(), copy_length) }; - slice.copy_from_slice(&version[..copy_length]); - if let Some(null) = slice.last_mut() { - *null = 0; - } - nvmlReturn_t::SUCCESS -} +use cuda_types::nvml::*; +use std::{ffi::CStr, ptr}; + +#[cfg(debug_assertions)] +pub(crate) fn unimplemented() -> nvmlReturn_t { + unimplemented!() +} + +#[cfg(not(debug_assertions))] +pub(crate) fn unimplemented() -> nvmlReturn_t { + nvmlReturn_t::ERROR_NOT_SUPPORTED +} + +pub(crate) fn error_string(_result: cuda_types::nvml::nvmlReturn_t) -> *const ::core::ffi::c_char { + c"".as_ptr() +} + +pub(crate) fn init_v2() -> cuda_types::nvml::nvmlReturn_t { + nvmlReturn_t::SUCCESS +} + +const VERSION: &'static CStr = c"550.77"; + +pub(crate) fn system_get_driver_version( + result: *mut ::core::ffi::c_char, + length: ::core::ffi::c_uint, +) -> cuda_types::nvml::nvmlReturn_t { + if result == ptr::null_mut() { + return nvmlReturn_t::ERROR_INVALID_ARGUMENT; + } + let version = VERSION.to_bytes_with_nul(); + let copy_length = usize::min(length as usize, version.len()); + let slice = unsafe { std::slice::from_raw_parts_mut(result.cast(), copy_length) }; + slice.copy_from_slice(&version[..copy_length]); + if let Some(null) = slice.last_mut() { + *null = 0; + } + nvmlReturn_t::SUCCESS +} diff --git a/zluda_ml/src/lib.rs b/zluda_ml/src/lib.rs index d65fae8..81f2f2a 100644 --- a/zluda_ml/src/lib.rs +++ b/zluda_ml/src/lib.rs @@ -26,9 +26,5 @@ macro_rules! implemented_fn { cuda_macros::nvml_function_declarations!( unimplemented_fn, - implemented_fn <= [ - nvmlErrorString, - nvmlInit_v2, - nvmlSystemGetDriverVersion - ] + implemented_fn <= [nvmlErrorString, nvmlInit_v2, nvmlSystemGetDriverVersion] ); diff --git a/zluda_sparse/src/lib.rs b/zluda_sparse/src/lib.rs index 795f680..6c30741 100644 --- a/zluda_sparse/src/lib.rs +++ b/zluda_sparse/src/lib.rs @@ -28,12 +28,13 @@ macro_rules! implemented { cuda_macros::cusparse_function_declarations!( unimplemented, - implemented <= [ - cusparseGetErrorName, - cusparseGetErrorString, - cusparseGetMatIndexBase, - cusparseGetMatDiagType, - cusparseGetMatFillMode, - cusparseGetMatType - ] + implemented + <= [ + cusparseGetErrorName, + cusparseGetErrorString, + cusparseGetMatIndexBase, + cusparseGetMatDiagType, + cusparseGetMatFillMode, + cusparseGetMatType + ] );