ZLUDA/ptx/src/translate.rs
2020-11-06 00:56:45 +01:00

7346 lines
273 KiB
Rust

use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
use std::{borrow::Cow, convert::TryFrom, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryInto,
};
use rspirv::binary::Assemble;
// Bytes of `../lib/notcuda_ptx_impl.spv`, embedded at compile time.
// `to_spirv_module` hands this blob back to the caller (through
// `Module::should_link_ptx_impl`) whenever translation recorded at least one
// import into `ptx_impl_imports`.
// NOTE(review): presumably a prebuilt SPIR-V module implementing PTX helpers — confirm.
static NOTCUDA_PTX_IMPL: &'static [u8] = include_bytes!("../lib/notcuda_ptx_impl.spv");
// Error type returned by the PTX -> SPIR-V translation functions in this file.
quick_error! {
#[derive(Debug)]
pub enum TranslateError {
// NOTE(review): presumably raised when an identifier cannot be resolved — the
// raising sites are outside this view.
UnknownSymbol {}
// NOTE(review): presumably raised when a symbol has no recorded type — confirm.
UntypedSymbol {}
// NOTE(review): presumably raised when operand types do not agree — confirm.
MismatchedType {}
// Wraps an error reported by the rspirv module builder; `from()` lets `?`
// convert `rspirv::dr::Error` into this variant.
Spirv(err: rspirv::dr::Error) {
from()
display("{}", err)
cause(err)
}
// Returned for states the translator considers impossible (e.g. an array type
// with no dimensions, or a pointer to a non-scalar pointer).
Unreachable {}
// Returned for input that is not supported yet (e.g. `.f16x2` constants).
Todo {}
}
}
/// Hashable description of a SPIR-V type; used as the key of the type cache
/// kept in `TypeWordMap::complex`.
#[derive(PartialEq, Eq, Hash, Clone)]
enum SpirvType {
/// A scalar type.
Base(SpirvScalarKey),
/// A vector: element type and number of components.
Vector(SpirvScalarKey, u8),
/// An array: element type and dimensions. With several dimensions the first
/// entry is the outermost one (see `TypeWordMap::get_or_add`).
Array(SpirvScalarKey, Vec<u32>),
/// A pointer: pointee type and storage class.
Pointer(Box<SpirvType>, spirv::StorageClass),
/// A function: optional return type (`None` is emitted as void) and parameter types.
Func(Option<Box<SpirvType>>, Vec<SpirvType>),
/// A struct made of scalar members, in order.
Struct(Vec<SpirvScalarKey>),
}
impl SpirvType {
    /// Builds the key of a pointer in storage class `sc` whose pointee is the
    /// SPIR-V counterpart of the PTX type `t`.
    fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
        SpirvType::Pointer(Box::new(SpirvType::from(t)), sc)
    }
}
impl From<ast::Type> for SpirvType {
fn from(t: ast::Type) -> Self {
match t {
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
ast::Type::Pointer(pointer_t, state_space) => SpirvType::Pointer(
Box::new(SpirvType::from(ast::Type::from(pointer_t))),
state_space.to_spirv(),
),
}
}
}
impl From<ast::PointerType> for ast::Type {
    /// Returns the type a pointer of kind `t` points at.
    ///
    /// A pointer-to-pointer yields a pointer to a scalar located in the state
    /// space stored alongside it.
    fn from(t: ast::PointerType) -> Self {
        match t {
            ast::PointerType::Scalar(scalar) => ast::Type::Scalar(scalar),
            ast::PointerType::Vector(elem, width) => ast::Type::Vector(elem, width),
            ast::PointerType::Array(elem, dimensions) => ast::Type::Array(elem, dimensions),
            ast::PointerType::Pointer(scalar, inner_space) => {
                let pointee = ast::PointerType::Scalar(scalar);
                ast::Type::Pointer(pointee, inner_space)
            }
        }
    }
}
impl ast::Type {
    /// Wraps `self` in a pointer that lives in state space `space`.
    ///
    /// Scalars, vectors and arrays become pointers to themselves. A pointer to a
    /// scalar becomes a pointer-to-pointer: the inner state space is kept inside
    /// `PointerType::Pointer` and the new, outer pointer uses `space`.
    ///
    /// # Errors
    ///
    /// Returns `TranslateError::Unreachable` for a pointer whose pointee is not
    /// a scalar, because `ast::PointerType` cannot express deeper nesting.
    fn pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
        Ok(match self {
            ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
            ast::Type::Vector(t, len) => {
                ast::Type::Pointer(ast::PointerType::Vector(t, len), space)
            }
            ast::Type::Array(t, dims) => {
                ast::Type::Pointer(ast::PointerType::Array(t, dims), space)
            }
            // The inner binding gets its own name so it no longer shadows the
            // `space` parameter: the outer pointer must use the requested space,
            // not the pointee's.
            ast::Type::Pointer(ast::PointerType::Scalar(t), inner_space) => {
                ast::Type::Pointer(ast::PointerType::Pointer(t, inner_space), space)
            }
            ast::Type::Pointer(_, _) => return Err(TranslateError::Unreachable),
        })
    }
}
/// Maps the state space of a PTX pointer onto a SPIR-V storage class.
///
/// Implemented as `From` rather than `Into`: the standard blanket impl still
/// provides `Into<spirv::StorageClass>` for every existing `.into()` call site,
/// and `spirv::StorageClass::from(..)` becomes available as well.
impl From<ast::PointerStateSpace> for spirv::StorageClass {
    fn from(space: ast::PointerStateSpace) -> Self {
        match space {
            ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant,
            ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
            ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup,
            ast::PointerStateSpace::Param => spirv::StorageClass::Function,
            ast::PointerStateSpace::Generic => spirv::StorageClass::Generic,
        }
    }
}
impl From<ast::ScalarType> for SpirvType {
fn from(t: ast::ScalarType) -> Self {
SpirvType::Base(t.into())
}
}
// Cache of SPIR-V type and constant ids already emitted into the module builder.
struct TypeWordMap {
// Id of the void type, created once in `TypeWordMap::new`.
void: spirv::Word,
// Ids of every other type emitted so far, keyed by its description.
complex: HashMap<SpirvType, spirv::Word>,
// Ids of scalar constants, keyed by scalar type and the value's bits widened to u64.
constants: HashMap<(SpirvType, u64), spirv::Word>,
}
// Scalar type key with signedness erased: the signed, unsigned and untyped-bits
// PTX scalars of one width all map to the same `B*` key.
// SPIR-V integer type definitions are signless, more below:
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvScalarKey {
// Integers of 8/16/32/64 bits (emitted with `type_int(width, 0)`).
B8,
B16,
B32,
B64,
// Floating point of 16/32/64 bits.
F16,
F32,
F64,
// Predicate, emitted as the SPIR-V bool type.
Pred,
// Packed pair of f16; emitting this type is not implemented yet (`todo!()`).
F16x2,
}
impl From<ast::ScalarType> for SpirvScalarKey {
    /// Collapses a PTX scalar type to its signless SPIR-V key: the `.bN`, `.uN`
    /// and `.sN` types of one width share a key, everything else maps one to one.
    fn from(t: ast::ScalarType) -> Self {
        match t {
            // Non-integer types keep their identity.
            ast::ScalarType::Pred => SpirvScalarKey::Pred,
            ast::ScalarType::F16x2 => SpirvScalarKey::F16x2,
            ast::ScalarType::F16 => SpirvScalarKey::F16,
            ast::ScalarType::F32 => SpirvScalarKey::F32,
            ast::ScalarType::F64 => SpirvScalarKey::F64,
            // Integer types only keep their width.
            ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8,
            ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
                SpirvScalarKey::B16
            }
            ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => {
                SpirvScalarKey::B32
            }
            ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => {
                SpirvScalarKey::B64
            }
        }
    }
}
impl TypeWordMap {
/// Creates an empty cache; the void type is emitted into `b` right away and remembered.
fn new(b: &mut dr::Builder) -> TypeWordMap {
    TypeWordMap {
        void: b.type_void(),
        complex: HashMap::new(),
        constants: HashMap::new(),
    }
}
fn void(&self) -> spirv::Word {
self.void
}
fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word {
let key: SpirvScalarKey = t.into();
self.get_or_add_spirv_scalar(b, key)
}
/// Returns the id of the scalar type `key`, emitting it into `b` on first use.
///
/// # Panics
///
/// Panics (`todo!()`) for `SpirvScalarKey::F16x2` when it is not cached yet.
fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> spirv::Word {
    let cache_key = SpirvType::Base(key);
    if let Some(&known) = self.complex.get(&cache_key) {
        return known;
    }
    let type_id = match key {
        SpirvScalarKey::B8 => b.type_int(8, 0),
        SpirvScalarKey::B16 => b.type_int(16, 0),
        SpirvScalarKey::B32 => b.type_int(32, 0),
        SpirvScalarKey::B64 => b.type_int(64, 0),
        SpirvScalarKey::F16 => b.type_float(16),
        SpirvScalarKey::F32 => b.type_float(32),
        SpirvScalarKey::F64 => b.type_float(64),
        SpirvScalarKey::Pred => b.type_bool(),
        SpirvScalarKey::F16x2 => todo!(),
    };
    self.complex.insert(cache_key, type_id);
    type_id
}
// Returns the id of the SPIR-V type described by `t`, emitting it (and, first,
// every type it is built from) into `b` when it is not cached in `self.complex`.
fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
match t {
SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key),
SpirvType::Pointer(ref typ, storage) => {
// The pointee must exist before the pointer type can refer to it.
let base = self.get_or_add(b, *typ.clone());
*self
.complex
.entry(t)
.or_insert_with(|| b.type_pointer(None, storage, base))
}
SpirvType::Vector(typ, len) => {
let base = self.get_or_add_spirv_scalar(b, typ);
*self
.complex
.entry(t)
.or_insert_with(|| b.type_vector(base, len as u32))
}
SpirvType::Array(typ, array_dimensions) => {
// Array lengths are passed to `type_array` as ids of u32 constants.
// Note that the length constant is created on every call, even when
// the array type itself turns out to be cached already.
let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let (base_type, length) = match &*array_dimensions {
// Innermost dimension: the element is the scalar itself.
&[len] => {
let base = self.get_or_add_spirv_scalar(b, typ);
let len_const = b.constant_u32(u32_type, None, len);
(base, len_const)
}
// Otherwise the element is the array formed by the remaining
// dimensions and the first dimension is this array's length.
// NOTE(review): an empty dimension list would panic on the slice
// below — presumably ruled out earlier; confirm.
array_dimensions => {
let base = self
.get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
let len_const = b.constant_u32(u32_type, None, array_dimensions[0]);
(base, len_const)
}
};
*self
.complex
.entry(SpirvType::Array(typ, array_dimensions))
.or_insert_with(|| b.type_array(base_type, length))
}
SpirvType::Func(ref out_params, ref in_params) => {
// No return type means the function returns void.
let out_t = match out_params {
Some(p) => self.get_or_add(b, *p.clone()),
None => self.void(),
};
let in_t = in_params
.iter()
.map(|t| self.get_or_add(b, t.clone()))
.collect::<Vec<_>>();
*self
.complex
.entry(t)
.or_insert_with(|| b.type_function(out_t, in_t))
}
SpirvType::Struct(ref underlying) => {
let underlying_ids = underlying
.iter()
.map(|t| self.get_or_add_spirv_scalar(b, *t))
.collect::<Vec<_>>();
*self
.complex
.entry(t)
.or_insert_with(|| b.type_struct(underlying_ids))
}
}
}
/// Returns `(return type id, function type id)` for a function with the given
/// parameter and result types, emitting whatever is not cached yet.
///
/// No output means a void return type.
///
/// # Panics
///
/// Panics (`todo!()`) when more than one output is given, and when an output
/// iterator reporting a length of one yields nothing.
fn get_or_add_fn(
    &mut self,
    b: &mut dr::Builder,
    in_params: impl ExactSizeIterator<Item = SpirvType>,
    mut out_params: impl ExactSizeIterator<Item = SpirvType>,
) -> (spirv::Word, spirv::Word) {
    let (return_key, return_type_id) = match out_params.len() {
        0 => (None, self.void()),
        1 => {
            let only_output = out_params.next().unwrap();
            let type_id = self.get_or_add(b, only_output.clone());
            (Some(Box::new(only_output)), type_id)
        }
        _ => todo!(),
    };
    let param_keys = in_params.collect::<Vec<_>>();
    let fn_type_id = self.get_or_add(b, SpirvType::Func(return_key, param_keys));
    (return_type_id, fn_type_id)
}
/// Emits a constant (or, for pointers, an initialized variable) of type `typ`
/// whose value is read from the raw bytes in `init`, and returns its id.
///
/// * Scalars go through `get_or_add_constant_single`, which caches them by
///   type and bit pattern.
/// * Vectors and arrays are built element by element, each element reading
///   `init` at its byte offset, and combined with `constant_composite`.
/// * Pointers emit a variable in the pointer's storage class, initialized with
///   a constant of the pointee type built from an empty byte slice.
///
/// # Errors
///
/// Returns `TranslateError::Todo` for `.f16x2` scalars and
/// `TranslateError::Unreachable` for an array without dimensions.
///
/// # Panics
///
/// Panics when `init` is shorter than the byte offset of an element.
fn get_or_add_constant(
    &mut self,
    b: &mut dr::Builder,
    typ: &ast::Type,
    init: &[u8],
) -> Result<spirv::Word, TranslateError> {
    Ok(match typ {
        ast::Type::Scalar(t) => match t {
            ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self
                .get_or_add_constant_single::<u8, _, _>(
                    b,
                    *t,
                    init,
                    |v| v as u64,
                    |b, result_type, v| b.constant_u32(result_type, None, v as u32),
                ),
            ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self
                .get_or_add_constant_single::<u16, _, _>(
                    b,
                    *t,
                    init,
                    |v| v as u64,
                    |b, result_type, v| b.constant_u32(result_type, None, v as u32),
                ),
            ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self
                .get_or_add_constant_single::<u32, _, _>(
                    b,
                    *t,
                    init,
                    |v| v as u64,
                    |b, result_type, v| b.constant_u32(result_type, None, v),
                ),
            ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self
                .get_or_add_constant_single::<u64, _, _>(
                    b,
                    *t,
                    init,
                    |v| v,
                    |b, result_type, v| b.constant_u64(result_type, None, v),
                ),
            // Floats are cached by their bit pattern.
            ast::ScalarType::F16 => self.get_or_add_constant_single::<f16, _, _>(
                b,
                *t,
                init,
                |v| unsafe { mem::transmute::<_, u16>(v) } as u64,
                |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()),
            ),
            ast::ScalarType::F32 => self.get_or_add_constant_single::<f32, _, _>(
                b,
                *t,
                init,
                |v| unsafe { mem::transmute::<_, u32>(v) } as u64,
                |b, result_type, v| b.constant_f32(result_type, None, v),
            ),
            ast::ScalarType::F64 => self.get_or_add_constant_single::<f64, _, _>(
                b,
                *t,
                init,
                |v| unsafe { mem::transmute::<_, u64>(v) },
                |b, result_type, v| b.constant_f64(result_type, None, v),
            ),
            ast::ScalarType::F16x2 => return Err(TranslateError::Todo),
            // Predicates are stored as one byte: zero is false, anything else true.
            ast::ScalarType::Pred => self.get_or_add_constant_single::<u8, _, _>(
                b,
                *t,
                init,
                |v| v as u64,
                |b, result_type, v| {
                    if v == 0 {
                        b.constant_false(result_type, None)
                    } else {
                        b.constant_true(result_type, None)
                    }
                },
            ),
        },
        ast::Type::Vector(typ, len) => {
            let result_type =
                self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len));
            let size_of_t = typ.size_of();
            let components = (0..*len)
                .map(|x| {
                    self.get_or_add_constant(
                        b,
                        &ast::Type::Scalar(*typ),
                        &init[((size_of_t as usize) * (x as usize))..],
                    )
                })
                .collect::<Result<Vec<_>, _>>()?;
            b.constant_composite(result_type, None, &components)
        }
        ast::Type::Array(typ, dims) => match dims.as_slice() {
            [] => return Err(TranslateError::Unreachable),
            [dim] => {
                let result_type = self
                    .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim]));
                let size_of_t = typ.size_of();
                let components = (0..*dim)
                    .map(|x| {
                        self.get_or_add_constant(
                            b,
                            &ast::Type::Scalar(*typ),
                            &init[((size_of_t as usize) * (x as usize))..],
                        )
                    })
                    .collect::<Result<Vec<_>, _>>()?;
                b.constant_composite(result_type, None, &components)
            }
            [first_dim, rest @ ..] => {
                // The composite holds `first_dim` sub-arrays of shape `rest`,
                // so its own type is the array with ALL dimensions. Using
                // `rest` here would give the composite the type of one of its
                // elements.
                let result_type = self.get_or_add(
                    b,
                    SpirvType::Array(SpirvScalarKey::from(*typ), dims.to_vec()),
                );
                // Byte size of one sub-array: scalar size times the remaining dimensions.
                let size_of_t = rest
                    .iter()
                    .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y));
                let components = (0..*first_dim)
                    .map(|x| {
                        self.get_or_add_constant(
                            b,
                            &ast::Type::Array(*typ, rest.to_vec()),
                            &init[((size_of_t as usize) * (x as usize))..],
                        )
                    })
                    .collect::<Result<Vec<_>, _>>()?;
                b.constant_composite(result_type, None, &components)
            }
        },
        ast::Type::Pointer(typ, state_space) => {
            let base_t = typ.clone().into();
            let base = self.get_or_add_constant(b, &base_t, &[])?;
            let result_type = self.get_or_add(
                b,
                SpirvType::Pointer(
                    Box::new(SpirvType::from(base_t)),
                    (*state_space).to_spirv(),
                ),
            );
            b.variable(result_type, None, (*state_space).to_spirv(), Some(base))
        }
    })
}
fn get_or_add_constant_single<
T: Copy,
CastAsU64: FnOnce(T) -> u64,
InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word,
>(
&mut self,
b: &mut dr::Builder,
key: ast::ScalarType,
init: &[u8],
cast: CastAsU64,
f: InsertConstant,
) -> spirv::Word {
let value = unsafe { *(init.as_ptr() as *const T) };
let value_64 = cast(value);
let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64);
match self.constants.get(&ht_key) {
Some(value) => *value,
None => {
let spirv_type = self.get_or_add_scalar(b, key);
let result = f(b, spirv_type, value);
self.constants.insert(ht_key, result);
result
}
}
}
}
/// Result of translating a PTX module (see `to_spirv_module`).
pub struct Module {
/// The generated SPIR-V module.
pub spirv: dr::Module,
/// Per-kernel metadata, keyed by kernel name.
pub kernel_info: HashMap<String, KernelInfo>,
/// `Some(bytes)` with the embedded `notcuda_ptx_impl.spv` blob when the
/// translation recorded at least one import from it, `None` otherwise.
pub should_link_ptx_impl: Option<&'static [u8]>,
}
/// Metadata recorded for every kernel while its function header is emitted.
pub struct KernelInfo {
/// `size_of()` of each kernel input, in declaration order.
pub arguments_sizes: Vec<usize>,
/// Copied from the kernel's `SpirvMethodDecl::uses_shared_mem`, which is set
/// when an extra argument for dynamically sized shared memory was injected.
pub uses_shared_mem: bool,
}
/// Translates a parsed PTX module into a SPIR-V module plus per-kernel metadata.
///
/// Every directive is translated first; the helper functions imported along the
/// way are placed in front of the translated directives. The module-level passes
/// (call map, dynamic shared memory, variable ordering, denorm statistics) run
/// next, and finally the SPIR-V sections are emitted in their required order.
///
/// # Errors
///
/// Propagates the first `TranslateError` reported by directive translation or
/// by SPIR-V emission.
pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateError> {
    let mut id_defs = GlobalStringIdResolver::new(1);
    let mut ptx_impl_imports = HashMap::new();
    let mut translated = Vec::new();
    for directive in ast.directives {
        translated.push(translate_directive(
            &mut id_defs,
            &mut ptx_impl_imports,
            directive,
        )?);
    }
    // The helper library only has to be linked when something was imported from it.
    let should_link_ptx_impl = if ptx_impl_imports.is_empty() {
        None
    } else {
        Some(NOTCUDA_PTX_IMPL)
    };
    let mut directives = ptx_impl_imports
        .into_iter()
        .map(|(_, import)| import)
        .collect::<Vec<_>>();
    directives.extend(translated);

    let mut builder = dr::Builder::new();
    builder.reserve_ids(id_defs.current_id());
    let call_map = get_call_map(&directives);
    let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
    normalize_variable_decls(&mut directives);
    let denorm_information = compute_denorm_information(&directives);

    // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
    builder.set_version(1, 3);
    emit_capabilities(&mut builder);
    emit_extensions(&mut builder);
    let opencl_id = emit_opencl_import(&mut builder);
    emit_memory_model(&mut builder);
    let mut map = TypeWordMap::new(&mut builder);
    emit_builtins(&mut builder, &mut map, &id_defs);
    let mut kernel_info = HashMap::new();
    emit_directives(
        &mut builder,
        &mut map,
        &id_defs,
        opencl_id,
        &denorm_information,
        &call_map,
        directives,
        &mut kernel_info,
    )?;
    Ok(Module {
        spirv: builder.module(),
        kernel_info,
        should_link_ptx_impl,
    })
}
/// Emits every translated directive into `builder`.
///
/// Module-level variables are emitted directly. A method is emitted as its
/// extracted globals followed by the function header and body. A method without
/// a body is skipped unless it is an import (`import_as` is set): imports are
/// emitted with an empty body and get an `Import` linkage decoration under their
/// external name.
fn emit_directives<'input>(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    id_defs: &GlobalStringIdResolver<'input>,
    opencl_id: spirv::Word,
    denorm_information: &HashMap<MethodName<'input>, HashMap<u8, spirv::FPDenormMode>>,
    call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
    directives: Vec<Directive>,
    kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
    let no_statements = Vec::new();
    for directive in directives.iter() {
        let method = match directive {
            Directive::Variable(var) => {
                emit_variable(builder, map, var)?;
                continue;
            }
            Directive::Method(method) => method,
        };
        let body = match (&method.body, &method.import_as) {
            (Some(statements), _) => statements,
            (None, Some(_)) => &no_statements,
            (None, None) => continue,
        };
        for global in method.globals.iter() {
            emit_variable(builder, map, global)?;
        }
        emit_function_header(
            builder,
            map,
            id_defs,
            &method.globals,
            &method.spirv_decl,
            denorm_information,
            call_map,
            &directives,
            kernel_info,
        )?;
        emit_function_body_ops(builder, map, opencl_id, body)?;
        builder.end_function()?;
        if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
            (&method.func_decl, &method.import_as)
        {
            builder.decorate(
                *fn_id,
                spirv::Decoration::LinkageAttributes,
                &[
                    dr::Operand::LiteralString(name.clone()),
                    dr::Operand::LinkageType(spirv::LinkageType::Import),
                ],
            );
        }
    }
    Ok(())
}
/// For every kernel that calls at least one function, collects the ids of all
/// functions reachable from it through direct or indirect calls.
///
/// Kernels without any call statement do not appear in the result.
fn get_call_map<'input>(
    module: &[Directive<'input>],
) -> HashMap<&'input str, HashSet<spirv::Word>> {
    // Caller -> functions it calls directly.
    let mut directly_called_by = HashMap::new();
    for directive in module {
        if let Directive::Method(Function {
            func_decl,
            body: Some(statements),
            ..
        }) = directive
        {
            let caller = MethodName::new(func_decl);
            for statement in statements {
                if let Statement::Call(call) = statement {
                    multi_hash_map_append(&mut directly_called_by, caller, call.func);
                }
            }
        }
    }
    // Expand the direct calls of every kernel into the transitive closure.
    let mut result = HashMap::new();
    for (method_key, callees) in directly_called_by.iter() {
        if let MethodName::Kernel(name) = method_key {
            let mut reachable = HashSet::new();
            for callee in callees {
                add_call_map_single(&directly_called_by, &mut reachable, *callee);
            }
            result.insert(*name, reachable);
        }
    }
    result
}
/// Adds `current` and everything it calls, directly or indirectly, to `visited`.
///
/// Functions already present in `visited` are not expanded again, so recursive
/// call graphs terminate.
fn add_call_map_single<'input>(
    directly_called_by: &MultiHashMap<MethodName<'input>, spirv::Word>,
    visited: &mut HashSet<spirv::Word>,
    current: spirv::Word,
) {
    let first_visit = visited.insert(current);
    if first_visit {
        let callees = directly_called_by.get(&MethodName::Func(current));
        for callee in callees.into_iter().flatten() {
            add_call_map_single(directly_called_by, visited, *callee);
        }
    }
}
/// A map from a key to every value recorded for it, in insertion order.
type MultiHashMap<K, V> = HashMap<K, Vec<V>>;

/// Appends `value` to the values stored under `key`, creating the entry when
/// the key is seen for the first time.
fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>, key: K, value: V) {
    m.entry(key).or_insert_with(Vec::new).push(value);
}
// PTX represents dynamically allocated shared local memory as
// .extern .shared .align 4 .b8 shared_mem[];
// In the SPIR-V/OpenCL world this is expressed as an additional argument.
// This pass looks for all uses of .extern .shared and converts them to
// an additional method argument.
//
// `module` is the list of translated directives and `new_id` allocates fresh ids.
// The module is returned unchanged when it declares no extern shared variable.
// Only methods that have a body are inspected and rewritten.
fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>,
new_id: &mut impl FnMut() -> spirv::Word,
) -> Vec<Directive<'input>> {
// Step 1: collect the extern shared declarations (id -> declared element type).
let mut extern_shared_decls = HashMap::new();
for dir in module.iter() {
match dir {
Directive::Variable(var) => {
if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) =
var.v_type
{
extern_shared_decls.insert(var.name, p_type);
}
}
_ => {}
}
}
if extern_shared_decls.len() == 0 {
return module;
}
// Step 2: record which methods mention an extern shared variable directly,
// and who calls whom (callee id -> callers).
let mut methods_using_extern_shared = HashSet::new();
let mut directly_called_by = MultiHashMap::new();
let module = module
.into_iter()
.map(|directive| match directive {
Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
spirv_decl,
}) => {
let call_key = MethodName::new(&func_decl);
let statements = statements
.into_iter()
.map(|statement| match statement {
Statement::Call(call) => {
multi_hash_map_append(&mut directly_called_by, call.func, call_key);
Statement::Call(call)
}
// `map_id` is used only to visit every id; ids are returned unchanged.
statement => statement.map_id(&mut |id| {
if extern_shared_decls.contains_key(&id) {
methods_using_extern_shared.insert(call_key);
}
id
}),
})
.collect();
Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
spirv_decl,
})
}
directive => directive,
})
.collect::<Vec<_>>();
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
// make sure it gets propagated to `fn1` and `kernel`
get_callers_of_extern_shared(&mut methods_using_extern_shared, &directly_called_by);
// now visit every method declaration and inject those additional arguments
module
.into_iter()
.map(|directive| match directive {
Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
mut spirv_decl,
}) => {
// Methods that never reach extern shared memory are left as they are.
if !methods_using_extern_shared.contains(&spirv_decl.name) {
return Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
spirv_decl,
});
}
// Extra input parameter: a pointer to bytes in the shared state space.
let shared_id_param = new_id();
spirv_decl.input.push({
ast::Variable {
align: None,
v_type: ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::U8),
ast::LdStateSpace::Shared,
),
array_init: Vec::new(),
name: shared_id_param,
}
});
spirv_decl.uses_shared_mem = true;
// A local variable that receives the parameter; the rewritten statements
// refer to this variable instead of the extern declaration.
let shared_var_id = new_id();
let shared_var = ExpandedStatement::Variable(ast::Variable {
align: None,
name: shared_var_id,
array_init: Vec::new(),
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
ast::SizedScalarType::B8,
ast::PointerStateSpace::Shared,
)),
});
let shared_var_st = ExpandedStatement::StoreVar(
ast::Arg2St {
src1: shared_var_id,
src2: shared_id_param,
},
ast::Type::Scalar(ast::ScalarType::B8),
);
// The new body starts with the variable and the store into it, followed
// by the original statements with their extern shared uses replaced.
let mut new_statements = vec![shared_var, shared_var_st];
replace_uses_of_shared_memory(
&mut new_statements,
new_id,
&extern_shared_decls,
&mut methods_using_extern_shared,
shared_id_param,
shared_var_id,
statements,
);
Directive::Method(Function {
func_decl,
globals,
body: Some(new_statements),
import_as,
spirv_decl,
})
}
directive => directive,
})
.collect::<Vec<_>>()
}
/// Rewrites `statements` so that they no longer refer to extern shared
/// declarations, appending the rewritten statements to `result`.
///
/// * A call to a function that (transitively) uses extern shared memory gets
///   `shared_id_param` appended to its parameter list.
/// * In every other statement, an id naming an extern shared declaration is
///   replaced. A `.b8` declaration already has the type of the local pointer
///   variable, so `shared_var_id` is used directly. Any other element type
///   gets a fresh id, defined by a pointer-to-pointer conversion from
///   `shared_var_id` that is pushed in front of the statement using it.
///
/// `new_id` allocates fresh ids and `extern_shared_decls` maps the id of each
/// extern shared declaration to its declared element type.
fn replace_uses_of_shared_memory<'a>(
    result: &mut Vec<ExpandedStatement>,
    new_id: &mut impl FnMut() -> spirv::Word,
    extern_shared_decls: &HashMap<spirv::Word, ast::SizedScalarType>,
    methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
    shared_id_param: spirv::Word,
    shared_var_id: spirv::Word,
    statements: Vec<ExpandedStatement>,
) {
    for statement in statements {
        match statement {
            Statement::Call(mut call) => {
                // We can safely skip checking call arguments,
                // because there's simply no way to pass shared ptr
                // without converting it to .b64 first
                if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
                    call.param_list
                        .push((shared_id_param, ast::FnArgumentType::Shared));
                }
                result.push(Statement::Call(call))
            }
            statement => {
                let new_statement = statement.map_id(&mut |id| {
                    let typ = match extern_shared_decls.get(&id) {
                        Some(typ) => typ,
                        None => return id,
                    };
                    // No conversion is emitted for `.b8`, so a fresh id would
                    // never be defined by any statement; reuse the variable
                    // that already holds the byte pointer.
                    if *typ == ast::SizedScalarType::B8 {
                        return shared_var_id;
                    }
                    let replacement_id = new_id();
                    result.push(Statement::Conversion(ImplicitConversion {
                        src: shared_var_id,
                        dst: replacement_id,
                        from: ast::Type::Pointer(
                            ast::PointerType::Scalar(ast::ScalarType::B8),
                            ast::LdStateSpace::Shared,
                        ),
                        to: ast::Type::Pointer(
                            ast::PointerType::Scalar((*typ).into()),
                            ast::LdStateSpace::Shared,
                        ),
                        kind: ConversionKind::PtrToPtr { spirv_ptr: true },
                    }));
                    replacement_id
                });
                result.push(new_statement);
            }
        }
    }
}
/// Extends `methods_using_extern_shared` with every method that reaches one of
/// its functions through a chain of calls.
///
/// `directly_called_by` maps a function id to the methods calling it directly.
/// Only functions are used as starting points.
fn get_callers_of_extern_shared<'a>(
    methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
    directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
) {
    // Snapshot the starting points first: the set is modified while walking.
    let mut direct_users = Vec::new();
    for method in methods_using_extern_shared.iter() {
        if let MethodName::Func(f_id) = method {
            direct_users.push(*f_id);
        }
    }
    for fn_id in direct_users {
        get_callers_of_extern_shared_single(methods_using_extern_shared, directly_called_by, fn_id);
    }
}
/// Marks every direct and indirect caller of function `fn_id` as using extern
/// shared memory.
///
/// A caller that is already in the set is not followed again, so call cycles
/// terminate. Kernels are recorded but not followed any further.
fn get_callers_of_extern_shared_single<'a>(
    methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
    directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
    fn_id: spirv::Word,
) {
    let callers = match directly_called_by.get(&fn_id) {
        Some(callers) => callers,
        None => return,
    };
    for caller in callers {
        let newly_marked = methods_using_extern_shared.insert(*caller);
        if !newly_marked {
            continue;
        }
        if let MethodName::Func(caller_fn) = caller {
            get_callers_of_extern_shared_single(
                methods_using_extern_shared,
                directly_called_by,
                *caller_fn,
            );
        }
    }
}
/// Running tally per key: positive when `true` votes dominate, negative when
/// `false` votes do.
type DenormCountMap<T> = HashMap<T, isize>;

/// Records one vote for `key`: `+1` when `value` is true, `-1` otherwise.
fn denorm_count_map_update<T: Eq + Hash>(map: &mut DenormCountMap<T>, key: T, value: bool) {
    let delta = match value {
        true => 1,
        false => -1,
    };
    denorm_count_map_update_impl(map, key, delta);
}

/// Adds `num_value` to the tally stored for `key`, starting from zero when the
/// key has not been seen before.
fn denorm_count_map_update_impl<T: Eq + Hash>(
    map: &mut DenormCountMap<T>,
    key: T,
    num_value: isize,
) {
    *map.entry(key).or_insert(0) += num_value;
}
// HACK ALERT!
// This function is a "good enough" heuristic for deciding whether to mark the
// f16/f32 operations of a method as flushing denormals to zero or preserving them.
// PTX supports per-instruction ftz information. Unfortunately SPIR-V has no
// such capability, so instead we guesstimate which use is more common in the
// method and emit a suitable execution mode.
//
// The result maps every method that has a body to a table from operand width
// (the `u8` reported by `flush_to_zero()`) to the mode used by the majority of
// its instructions; a tie resolves to `Preserve`.
fn compute_denorm_information<'input>(
    module: &[Directive<'input>],
) -> HashMap<MethodName<'input>, HashMap<u8, spirv::FPDenormMode>> {
    let mut denorm_methods = HashMap::new();
    for directive in module {
        let (func_decl, statements) = match directive {
            Directive::Method(Function {
                func_decl,
                body: Some(statements),
                ..
            }) => (func_decl, statements),
            Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => continue,
        };
        // Votes per width: +1 for every flushing instruction, -1 for every preserving one.
        let mut flush_counter = DenormCountMap::new();
        for statement in statements {
            match statement {
                Statement::Instruction(inst) => {
                    if let Some((flush, width)) = inst.flush_to_zero() {
                        denorm_count_map_update(&mut flush_counter, width, flush);
                    }
                }
                // Listed one by one on purpose: a new statement kind must be
                // classified here before the code compiles again.
                Statement::LoadVar(_, _)
                | Statement::StoreVar(_, _)
                | Statement::Call(_)
                | Statement::Composite(_)
                | Statement::Conditional(_)
                | Statement::Conversion(_)
                | Statement::Constant(_)
                | Statement::RetValue(_, _)
                | Statement::Undef(_, _)
                | Statement::Label(_)
                | Statement::Variable(_)
                | Statement::PtrAdd { .. } => {}
            }
        }
        let width_to_denorm = flush_counter
            .into_iter()
            .map(|(width, ftz_over_preserve)| {
                let mode = if ftz_over_preserve > 0 {
                    spirv::FPDenormMode::FlushToZero
                } else {
                    spirv::FPDenormMode::Preserve
                };
                (width, mode)
            })
            .collect();
        denorm_methods.insert(MethodName::new(func_decl), width_to_denorm);
    }
    denorm_methods
}
/// Identity of a method, usable as a hash-map key.
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
enum MethodName<'input> {
/// A kernel, identified by its name in the PTX source.
Kernel(&'input str),
/// A non-kernel function, identified by its numeric id.
Func(spirv::Word),
}
impl<'input> MethodName<'input> {
    /// Extracts the identity of a method from its declaration: the name for a
    /// kernel, the numeric id for a function.
    fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
        match decl {
            ast::MethodDecl::Func(_, fn_id, _) => MethodName::Func(*fn_id),
            ast::MethodDecl::Kernel { name, .. } => MethodName::Kernel(*name),
        }
    }
}
/// Declares every special register known to `id_defs` as a `UniformConstant`
/// variable decorated with the matching SPIR-V built-in.
fn emit_builtins(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    id_defs: &GlobalStringIdResolver,
) {
    let storage = spirv::StorageClass::UniformConstant;
    for (reg, id) in id_defs.special_registers.iter() {
        let pointee = SpirvType::from(reg.get_type());
        let result_type = map.get_or_add(builder, SpirvType::Pointer(Box::new(pointee), storage));
        builder.variable(result_type, Some(*id), storage, None);
        let builtin = dr::Operand::BuiltIn(reg.get_builtin());
        builder.decorate(*id, spirv::Decoration::BuiltIn, &[builtin]);
    }
}
/// Opens a function in `builder` and emits its parameters.
///
/// For a kernel this additionally records its `KernelInfo` and emits the entry
/// point. The entry point's interface lists, in this order: the typed global
/// variables, the special registers, the globals extracted from this method
/// (`synthetic_globals`) and the globals of every function the kernel can reach
/// according to `call_map`.
///
/// The caller is responsible for emitting the body and closing the function.
fn emit_function_header<'a>(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    defined_globals: &GlobalStringIdResolver<'a>,
    synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
    func_decl: &SpirvMethodDecl<'a>,
    _denorm_information: &HashMap<MethodName<'a>, HashMap<u8, spirv::FPDenormMode>>,
    call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
    direcitves: &[Directive],
    kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
    if let MethodName::Kernel(kernel_name) = func_decl.name {
        let arguments_sizes: Vec<usize> = func_decl
            .input
            .iter()
            .map(|input| input.v_type.size_of())
            .collect();
        let info = KernelInfo {
            arguments_sizes,
            uses_shared_mem: func_decl.uses_shared_mem,
        };
        kernel_info.insert(kernel_name.to_string(), info);
    }
    let (ret_type, func_type) =
        get_function_type(builder, map, &func_decl.input, &func_decl.output);
    let fn_id = match func_decl.name {
        MethodName::Func(id) => id,
        MethodName::Kernel(kernel_name) => {
            let fn_id = defined_globals.get_id(kernel_name)?;
            let mut entry_interface = defined_globals
                .variables_type_check
                .iter()
                .filter_map(|(id, typ)| typ.as_ref().map(|_| *id))
                .collect::<Vec<_>>();
            entry_interface.extend(defined_globals.special_registers.iter().map(|(_, id)| *id));
            entry_interface.extend(synthetic_globals.iter().map(|var| var.name));
            // Globals owned by functions reachable from this kernel belong to
            // the interface too.
            if let Some(child_fns) = call_map.get(kernel_name) {
                for directive in direcitves {
                    if let Directive::Method(Function {
                        func_decl: ast::MethodDecl::Func(_, callee_id, _),
                        globals,
                        ..
                    }) = directive
                    {
                        if child_fns.contains(callee_id) {
                            entry_interface.extend(globals.iter().map(|var| var.name));
                        }
                    }
                }
            }
            builder.entry_point(
                spirv::ExecutionModel::Kernel,
                fn_id,
                kernel_name,
                entry_interface,
            );
            fn_id
        }
    };
    builder.begin_function(
        ret_type,
        Some(fn_id),
        spirv::FunctionControl::NONE,
        func_type,
    )?;
    // TODO: re-enable when Intel float control extension works
    /*
    if let Some(denorm_modes) = denorm_information.get(&func_decl.name) {
        for (size_of, denorm_mode) in denorm_modes {
            builder.decorate(
                fn_id,
                spirv::Decoration::FunctionDenormModeINTEL,
                [
                    dr::Operand::LiteralInt32((*size_of as u32) * 8),
                    dr::Operand::FPDenormMode(*denorm_mode),
                ],
            )
        }
    }
    */
    for input in func_decl.input.iter() {
        let param_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone()));
        let parameter = dr::Instruction::new(
            spirv::Op::FunctionParameter,
            Some(param_type),
            Some(input.name),
            Vec::new(),
        );
        let current_fn = builder.function.as_mut().unwrap();
        current_fn.parameters.push(parameter);
    }
    Ok(())
}
pub fn to_spirv<'a>(
ast: ast::Module<'a>,
) -> Result<(Option<&'static [u8]>, Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
let module = to_spirv_module(ast)?;
Ok((
module.should_link_ptx_impl,
module.spirv.assemble(),
module
.kernel_info
.into_iter()
.map(|(k, v)| (k, v.arguments_sizes))
.collect(),
))
}
/// Declares the SPIR-V capabilities every generated module requests.
fn emit_capabilities(builder: &mut dr::Builder) {
    let capabilities = [
        spirv::Capability::GenericPointer,
        spirv::Capability::Linkage,
        spirv::Capability::Addresses,
        spirv::Capability::Kernel,
        spirv::Capability::Int8,
        spirv::Capability::Int16,
        spirv::Capability::Int64,
        spirv::Capability::Float16,
        spirv::Capability::Float64,
    ];
    for capability in capabilities.iter() {
        builder.capability(*capability);
    }
    // TODO: re-enable when Intel float control extension works
    //builder.capability(spirv::Capability::FunctionFloatControlINTEL);
}
// Declares the SPIR-V extensions used by the module. Currently this emits
// nothing: the only candidate extension is disabled (see the TODO below), which
// is also why the builder parameter is unused.
// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html
fn emit_extensions(_builder: &mut dr::Builder) {
// TODO: re-enable when Intel float control extension works
//builder.extension("SPV_INTEL_float_controls2");
}
/// Imports the `OpenCL.std` extended instruction set and returns its id.
fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word {
    let opencl_std_id = builder.ext_inst_import("OpenCL.std");
    opencl_std_id
}
/// Selects 64-bit physical addressing together with the OpenCL memory model.
fn emit_memory_model(builder: &mut dr::Builder) {
    let addressing = spirv::AddressingModel::Physical64;
    let memory = spirv::MemoryModel::OpenCL;
    builder.memory_model(addressing, memory);
}
/// Translates one top-level PTX directive, either a variable or a method.
///
/// Helper functions that a method needs from the implementation library are
/// recorded in `ptx_impl_imports`.
fn translate_directive<'input>(
    id_defs: &mut GlobalStringIdResolver<'input>,
    ptx_impl_imports: &mut HashMap<String, Directive>,
    d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
) -> Result<Directive<'input>, TranslateError> {
    let translated = match d {
        ast::Directive::Method(method) => {
            let function = translate_function(id_defs, ptx_impl_imports, method)?;
            Directive::Method(function)
        }
        ast::Directive::Variable(variable) => {
            let variable = translate_variable(id_defs, variable)?;
            Directive::Variable(variable)
        }
    };
    Ok(translated)
}
/// Replaces the textual name of a module-level variable with a numeric id,
/// registering the variable's state space and type with `id_defs`.
fn translate_variable<'a>(
    id_defs: &mut GlobalStringIdResolver<'a>,
    var: ast::Variable<ast::VariableType, &'a str>,
) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
    let ast::Variable {
        align,
        v_type,
        name,
        array_init,
    } = var;
    let (state_space, typ) = v_type.to_type();
    let id = id_defs.get_or_add_def_typed(name, (state_space.into(), typ));
    Ok(ast::Variable {
        align,
        v_type,
        name: id,
        array_init,
    })
}
/// Translates one parsed method: registers its declaration with `id_defs` and
/// runs its body, if any, through the `to_ssa` pipeline.
fn translate_function<'a>(
    id_defs: &mut GlobalStringIdResolver<'a>,
    ptx_impl_imports: &mut HashMap<String, Directive>,
    f: ast::ParsedFunction<'a>,
) -> Result<Function<'a>, TranslateError> {
    let (string_ids, fn_decls, method_decl) = id_defs.start_fn(&f.func_directive)?;
    let function = to_ssa(ptx_impl_imports, string_ids, fn_decls, method_decl, f.body)?;
    Ok(function)
}
/// Gives every kernel argument a numeric id.
///
/// Each argument is registered in `fn_resolver` as living in the `.param`
/// state space with the type "pointer (in `.param`) to the declared type"; the
/// returned arguments keep their declared type and alignment.
///
/// # Errors
///
/// Fails when the pointer type cannot be formed (see `ast::Type::pointer_to`).
fn expand_kernel_params<'a, 'b>(
    fn_resolver: &mut FnStringIdResolver<'a, 'b>,
    args: impl Iterator<Item = &'b ast::KernelArgument<&'a str>>,
) -> Result<Vec<ast::KernelArgument<spirv::Word>>, TranslateError> {
    let mut expanded = Vec::new();
    for arg in args {
        let registered_type =
            ast::Type::from(arg.v_type.clone()).pointer_to(ast::LdStateSpace::Param)?;
        let id = fn_resolver.add_def(arg.name, Some((StateSpace::Param, registered_type)));
        expanded.push(ast::KernelArgument {
            name: id,
            v_type: arg.v_type.clone(),
            align: arg.align,
            array_init: Vec::new(),
        });
    }
    Ok(expanded)
}
/// Gives every function argument a numeric id, registering it in `fn_resolver`
/// with the state space implied by its declaration kind and the type reported
/// by `to_func_type()`.
fn expand_fn_params<'a, 'b>(
    fn_resolver: &mut FnStringIdResolver<'a, 'b>,
    args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
    let mut expanded = Vec::new();
    for arg in args {
        let registered_type = arg.v_type.to_func_type();
        let state_space = match arg.v_type {
            ast::FnArgumentType::Reg(_) => StateSpace::Reg,
            ast::FnArgumentType::Param(_) => StateSpace::Param,
            ast::FnArgumentType::Shared => StateSpace::Shared,
        };
        let id = fn_resolver.add_def(arg.name, Some((state_space, registered_type)));
        expanded.push(ast::FnArgument {
            name: id,
            v_type: arg.v_type.clone(),
            align: arg.align,
            array_init: Vec::new(),
        });
    }
    Ok(expanded)
}
/// Runs a method body through the translation pipeline and packages the result.
///
/// A method without a body is returned as a bare declaration. Otherwise the
/// passes run in this order: identifier normalization, predicate normalization,
/// typing, memory-SSA insertion, argument expansion, implicit conversion
/// insertion, label normalization and, last, extraction of global variables and
/// calls into the helper implementation library.
fn to_ssa<'input, 'b>(
    ptx_impl_imports: &mut HashMap<String, Directive>,
    mut id_defs: FnStringIdResolver<'input, 'b>,
    fn_defs: GlobalFnDeclResolver<'input, 'b>,
    f_args: ast::MethodDecl<'input, spirv::Word>,
    f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
) -> Result<Function<'input>, TranslateError> {
    let mut spirv_decl = SpirvMethodDecl::new(&f_args);
    let parsed_body = match f_body {
        None => {
            return Ok(Function {
                func_decl: f_args,
                body: None,
                globals: Vec::new(),
                import_as: None,
                spirv_decl,
            })
        }
        Some(statements) => statements,
    };
    let normalized = normalize_identifiers(&mut id_defs, &fn_defs, parsed_body)?;
    let mut typed_ids = id_defs.finish();
    let unadorned = normalize_predicates(normalized, &mut typed_ids);
    let typed = convert_to_typed_statements(unadorned, &fn_defs, &typed_ids)?;
    let mut mutable_ids = typed_ids.finish();
    let ssa = insert_mem_ssa_statements(typed, &mut mutable_ids, &mut spirv_decl)?;
    let expanded = expand_arguments(ssa, &mut mutable_ids)?;
    let converted = insert_implicit_conversions(expanded, &mut mutable_ids)?;
    let mut final_ids = mutable_ids.unmut();
    let labeled = normalize_labels(converted, &mut final_ids);
    let (body, globals) = extract_globals(labeled, ptx_impl_imports, &mut final_ids);
    Ok(Function {
        func_decl: f_args,
        globals,
        body: Some(body),
        import_as: None,
        spirv_decl,
    })
}
/// Splits a function body into the statements that stay in the function and the variables
/// that are hoisted out of it.
///
/// * Variables whose type is `Shared` or `Global` go to the second list.
/// * `bfe` is replaced by a call built with `to_ptx_impl_bfe_call`.
/// * Unsigned `atom` with the `Inc` or `Dec` operation is replaced by a call built with
///   `to_ptx_impl_atomic_call` (operation name `"inc"` / `"dec"`).
/// * Everything else is kept unchanged, in its original order.
fn extract_globals<'input, 'b>(
    sorted_statements: Vec<ExpandedStatement>,
    ptx_impl_imports: &mut HashMap<String, Directive>,
    id_def: &mut NumericIdResolver,
) -> (
    Vec<ExpandedStatement>,
    Vec<ast::Variable<ast::VariableType, spirv::Word>>,
) {
    let mut kept = Vec::with_capacity(sorted_statements.len());
    let mut hoisted = Vec::new();
    for statement in sorted_statements {
        match statement {
            Statement::Variable(var) => match var.v_type {
                ast::VariableType::Shared(_) | ast::VariableType::Global(_) => hoisted.push(var),
                _ => kept.push(Statement::Variable(var)),
            },
            Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => {
                kept.push(to_ptx_impl_bfe_call(id_def, ptx_impl_imports, typ, arg));
            }
            Statement::Instruction(ast::Instruction::Atom(details, args)) => {
                // Only unsigned inc/dec are routed to the implementation library.
                let impl_op = match details.inner {
                    ast::AtomInnerDetails::Unsigned {
                        op: ast::AtomUIntOp::Inc,
                        ..
                    } => Some("inc"),
                    ast::AtomInnerDetails::Unsigned {
                        op: ast::AtomUIntOp::Dec,
                        ..
                    } => Some("dec"),
                    _ => None,
                };
                match impl_op {
                    Some(op_name) => kept.push(to_ptx_impl_atomic_call(
                        id_def,
                        ptx_impl_imports,
                        details,
                        args,
                        op_name,
                    )),
                    None => kept.push(Statement::Instruction(ast::Instruction::Atom(
                        details, args,
                    ))),
                }
            }
            other => kept.push(other),
        }
    }
    (kept, hoisted)
}
/// Moves variable declarations in front of all other statements in every method body,
/// leaving the very first statement of the body in place.
///
/// The sort is stable, so declarations keep their relative order and so do the remaining
/// statements. Directives that are not methods, and methods without a body, are not touched.
fn normalize_variable_decls(directives: &mut Vec<Directive>) {
    for directive in directives {
        if let Directive::Method(Function {
            body: Some(func), ..
        }) = directive
        {
            // `get_mut(1..)` returns `None` for an empty body, where indexing with
            // `func[1..]` would panic. An empty body has nothing to reorder.
            if let Some(rest) = func.get_mut(1..) {
                rest.sort_by_key(|s| match s {
                    Statement::Variable(_) => 0,
                    _ => 1,
                });
            }
        }
    }
}
/// Converts statements with parsed operands into typed statements.
///
/// Most instructions only have their arguments passed through `cast()`. A few get extra work:
/// * `call` is matched against the callee declaration and becomes `Statement::Call`; return
///   arguments whose declared type `is_param()` are moved to the front of the input list,
/// * generic `ld`/`st` whose address operand is registered in the local state space are
///   retagged as local `ld`/`st`,
/// * `mov` records whether its source is an address and, for member moves, the vector widths.
///
/// # Errors
/// Propagates lookup failures from `fn_defs` / `id_defs`, returns `MismatchedType` when a
/// member `mov` names an id that is not a vector, and `Unreachable` when it meets an `Undef`
/// or `PtrAdd` statement.
fn convert_to_typed_statements(
    func: Vec<UnconditionalStatement>,
    fn_defs: &GlobalFnDeclResolver,
    id_defs: &NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
    let mut result = Vec::<TypedStatement>::with_capacity(func.len());
    for statement in func {
        match statement {
            Statement::Instruction(inst) => {
                // Every arm yields the typed instruction; `call` is the only one that pushes
                // a different kind of statement and skips the common push below.
                let typed_inst = match inst {
                    ast::Instruction::Call(call) => {
                        // TODO: error out if lengths don't match
                        let fn_def = fn_defs.get_fn_decl(call.func)?;
                        let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
                        let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params);
                        let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args
                            .into_iter()
                            .partition(|(_, arg_type)| arg_type.is_param());
                        let normalized_input_args = out_params
                            .into_iter()
                            .map(|(id, typ)| (ast::CallOperand::Reg(id), typ))
                            .chain(in_args.into_iter())
                            .collect();
                        result.push(Statement::Call(ResolvedCall {
                            uniform: call.uniform,
                            ret_params: out_non_params,
                            func: call.func,
                            param_list: normalized_input_args,
                        }));
                        continue;
                    }
                    // Supported ld/st:
                    //   global:  only compatible with reg b64/u64/s64 source/dest
                    //   generic: compatible with global/local sources
                    //   param:   compiled as mov
                    //   local:   compiled as mov
                    // Ideally ld/st local/param would be turned into movs right here, but
                    // they have different semantics for implicit conversions. For now a
                    // generic ld/st from a local id is turned into ld.local/st.local, so the
                    // later stages can rely on ld.generic and ld.global having a byte
                    // address as their source.
                    // One complication: an immediate address is only allowed in local, it is
                    // not supported in generic ld:
                    //   ld.local foo, [1];
                    ast::Instruction::Ld(mut d, arg) => {
                        if let Some(address_id) = arg.src.underlying() {
                            let (address_space, _) = id_defs.get_typed(*address_id)?;
                            if let (ast::LdStateSpace::Generic, StateSpace::Local) =
                                (d.state_space, address_space)
                            {
                                d.state_space = ast::LdStateSpace::Local;
                            }
                        }
                        ast::Instruction::Ld(d, arg.cast())
                    }
                    ast::Instruction::St(mut d, arg) => {
                        if let Some(address_id) = arg.src1.underlying() {
                            let (address_space, _) = id_defs.get_typed(*address_id)?;
                            if let (ast::StStateSpace::Generic, StateSpace::Local) =
                                (d.state_space, address_space)
                            {
                                d.state_space = ast::StStateSpace::Local;
                            }
                        }
                        ast::Instruction::St(d, arg.cast())
                    }
                    ast::Instruction::Mov(mut d, args) => match args {
                        ast::Arg2Mov::Normal(arg) => {
                            if let Some(src_id) = arg.src.single_underlying() {
                                let (scope, _) = id_defs.get_typed(*src_id)?;
                                // Anything that is not a plain register is an address.
                                d.src_is_address = match scope {
                                    StateSpace::Reg => false,
                                    StateSpace::Const
                                    | StateSpace::Global
                                    | StateSpace::Local
                                    | StateSpace::Shared
                                    | StateSpace::Param
                                    | StateSpace::ParamReg => true,
                                };
                            }
                            ast::Instruction::Mov(d, ast::Arg2Mov::Normal(arg.cast()))
                        }
                        ast::Arg2Mov::Member(args) => {
                            if let Some(dst_id) = args.vector_dst() {
                                if let (_, ast::Type::Vector(_, len)) =
                                    id_defs.get_typed(*dst_id)?
                                {
                                    d.dst_width = len;
                                } else {
                                    return Err(TranslateError::MismatchedType);
                                }
                            }
                            if let Some((src_id, _)) = args.vector_src() {
                                if let (_, ast::Type::Vector(_, len)) =
                                    id_defs.get_typed(*src_id)?
                                {
                                    d.src_width = len;
                                } else {
                                    return Err(TranslateError::MismatchedType);
                                }
                            }
                            ast::Instruction::Mov(d, ast::Arg2Mov::Member(args.cast()))
                        }
                    },
                    ast::Instruction::Mul(d, a) => ast::Instruction::Mul(d, a.cast()),
                    ast::Instruction::Add(d, a) => ast::Instruction::Add(d, a.cast()),
                    ast::Instruction::Setp(d, a) => ast::Instruction::Setp(d, a.cast()),
                    ast::Instruction::SetpBool(d, a) => ast::Instruction::SetpBool(d, a.cast()),
                    ast::Instruction::Not(d, a) => ast::Instruction::Not(d, a.cast()),
                    ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.cast()),
                    ast::Instruction::Cvt(d, a) => ast::Instruction::Cvt(d, a.cast()),
                    ast::Instruction::Cvta(d, a) => ast::Instruction::Cvta(d, a.cast()),
                    ast::Instruction::Shl(d, a) => ast::Instruction::Shl(d, a.cast()),
                    ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
                    ast::Instruction::Abs(d, a) => ast::Instruction::Abs(d, a.cast()),
                    ast::Instruction::Mad(d, a) => ast::Instruction::Mad(d, a.cast()),
                    ast::Instruction::Shr(d, a) => ast::Instruction::Shr(d, a.cast()),
                    ast::Instruction::Or(d, a) => ast::Instruction::Or(d, a.cast()),
                    ast::Instruction::Sub(d, a) => ast::Instruction::Sub(d, a.cast()),
                    ast::Instruction::Min(d, a) => ast::Instruction::Min(d, a.cast()),
                    ast::Instruction::Max(d, a) => ast::Instruction::Max(d, a.cast()),
                    ast::Instruction::Rcp(d, a) => ast::Instruction::Rcp(d, a.cast()),
                    ast::Instruction::And(d, a) => ast::Instruction::And(d, a.cast()),
                    ast::Instruction::Selp(d, a) => ast::Instruction::Selp(d, a.cast()),
                    ast::Instruction::Bar(d, a) => ast::Instruction::Bar(d, a.cast()),
                    ast::Instruction::Atom(d, a) => ast::Instruction::Atom(d, a.cast()),
                    ast::Instruction::AtomCas(d, a) => ast::Instruction::AtomCas(d, a.cast()),
                    ast::Instruction::Div(d, a) => ast::Instruction::Div(d, a.cast()),
                    ast::Instruction::Sqrt(d, a) => ast::Instruction::Sqrt(d, a.cast()),
                    ast::Instruction::Rsqrt(d, a) => ast::Instruction::Rsqrt(d, a.cast()),
                    ast::Instruction::Neg(d, a) => ast::Instruction::Neg(d, a.cast()),
                    ast::Instruction::Sin { flush_to_zero, arg } => ast::Instruction::Sin {
                        flush_to_zero,
                        arg: arg.cast(),
                    },
                    ast::Instruction::Cos { flush_to_zero, arg } => ast::Instruction::Cos {
                        flush_to_zero,
                        arg: arg.cast(),
                    },
                    ast::Instruction::Lg2 { flush_to_zero, arg } => ast::Instruction::Lg2 {
                        flush_to_zero,
                        arg: arg.cast(),
                    },
                    ast::Instruction::Ex2 { flush_to_zero, arg } => ast::Instruction::Ex2 {
                        flush_to_zero,
                        arg: arg.cast(),
                    },
                    ast::Instruction::Clz { typ, arg } => ast::Instruction::Clz {
                        typ,
                        arg: arg.cast(),
                    },
                    ast::Instruction::Brev { typ, arg } => ast::Instruction::Brev {
                        typ,
                        arg: arg.cast(),
                    },
                    ast::Instruction::Popc { typ, arg } => ast::Instruction::Popc {
                        typ,
                        arg: arg.cast(),
                    },
                    ast::Instruction::Xor { typ, arg } => ast::Instruction::Xor {
                        typ,
                        arg: arg.cast(),
                    },
                    ast::Instruction::Bfe { typ, arg } => ast::Instruction::Bfe {
                        typ,
                        arg: arg.cast(),
                    },
                    ast::Instruction::Rem { typ, arg } => ast::Instruction::Rem {
                        typ,
                        arg: arg.cast(),
                    },
                };
                result.push(Statement::Instruction(typed_inst));
            }
            // The remaining statement kinds carry no parsed operands and are rebuilt as-is.
            Statement::Label(i) => result.push(Statement::Label(i)),
            Statement::Variable(v) => result.push(Statement::Variable(v)),
            Statement::LoadVar(a, t) => result.push(Statement::LoadVar(a, t)),
            Statement::StoreVar(a, t) => result.push(Statement::StoreVar(a, t)),
            Statement::Call(c) => result.push(Statement::Call(c.cast())),
            Statement::Composite(c) => result.push(Statement::Composite(c)),
            Statement::Conditional(c) => result.push(Statement::Conditional(c)),
            Statement::Conversion(c) => result.push(Statement::Conversion(c)),
            Statement::Constant(c) => result.push(Statement::Constant(c)),
            Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
            Statement::Undef(_, _) | Statement::PtrAdd { .. } => {
                return Err(TranslateError::Unreachable);
            }
        }
    }
    Ok(result)
}
//TODO: share common code between this and to_ptx_impl_bfe_call
fn to_ptx_impl_atomic_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
details: ast::AtomDetails,
arg: ast::Arg3<ExpandedArgParams>,
op: &'static str,
) -> ExpandedStatement {
let semantics = ptx_semantics_name(details.semantics);
let scope = ptx_scope_name(details.scope);
let space = ptx_space_name(details.space);
let fn_name = format!(
"__notcuda_ptx_impl__atom_{}_{}_{}_{}",
semantics, scope, space, op
);
// TODO: extract to a function
let ptr_space = match details.space {
ast::AtomSpace::Generic => ast::PointerStateSpace::Generic,
ast::AtomSpace::Global => ast::PointerStateSpace::Global,
ast::AtomSpace::Shared => ast::PointerStateSpace::Shared,
};
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.new_id(None);
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
name: id_defs.new_id(None),
array_init: Vec::new(),
}],
fn_id,
vec![
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
ast::SizedScalarType::U32,
ptr_space,
)),
name: id_defs.new_id(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
name: id_defs.new_id(None),
array_init: Vec::new(),
},
],
);
let spirv_decl = SpirvMethodDecl::new(&func_decl);
let func = Function {
func_decl,
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
spirv_decl,
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
Directive::Method(Function {
func_decl: ast::MethodDecl::Func(_, name, _),
..
}) => *name,
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
uniform: false,
func: fn_id,
ret_params: vec![(
arg.dst,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
)],
param_list: vec![
(
arg.src1,
ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
ast::SizedScalarType::U32,
ptr_space,
)),
),
(
arg.src2,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
),
],
})
}
fn to_ptx_impl_bfe_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
typ: ast::IntType,
arg: ast::Arg4<ExpandedArgParams>,
) -> ExpandedStatement {
let prefix = "__notcuda_ptx_impl__";
let suffix = match typ {
ast::IntType::U32 => "bfe_u32",
ast::IntType::U64 => "bfe_u64",
ast::IntType::S32 => "bfe_s32",
ast::IntType::S64 => "bfe_s64",
_ => unreachable!(),
};
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.new_id(None);
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
name: id_defs.new_id(None),
array_init: Vec::new(),
}],
fn_id,
vec![
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
name: id_defs.new_id(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
name: id_defs.new_id(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
name: id_defs.new_id(None),
array_init: Vec::new(),
},
],
);
let spirv_decl = SpirvMethodDecl::new(&func_decl);
let func = Function {
func_decl,
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
spirv_decl,
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
Directive::Method(Function {
func_decl: ast::MethodDecl::Func(_, name, _),
..
}) => *name,
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
uniform: false,
func: fn_id,
ret_params: vec![(
arg.dst,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
)],
param_list: vec![
(
arg.src1,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
),
(
arg.src2,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
),
(
arg.src3,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
),
],
})
}
/// Pairs every call argument with the declared type at the same position.
///
/// The pairing stops at the end of the shorter of the two sequences, so surplus arguments
/// or surplus declarations are silently dropped.
fn to_resolved_fn_args<T>(
    params: Vec<T>,
    params_decl: &[ast::FnArgumentType],
) -> Vec<(T, ast::FnArgumentType)> {
    let declared_types = params_decl.iter().cloned();
    params.into_iter().zip(declared_types).collect()
}
fn normalize_labels(
func: Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
) -> Vec<ExpandedStatement> {
let mut labels_in_use = HashSet::new();
for s in func.iter() {
match s {
Statement::Instruction(i) => {
if let Some(target) = i.jump_target() {
labels_in_use.insert(target);
}
}
Statement::Conditional(cond) => {
labels_in_use.insert(cond.if_true);
labels_in_use.insert(cond.if_false);
}
Statement::Composite(_)
| Statement::Call(_)
| Statement::Variable(_)
| Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
| Statement::RetValue(_, _)
| Statement::Conversion(_)
| Statement::Constant(_)
| Statement::Label(_)
| Statement::Undef(_, _)
| Statement::PtrAdd { .. } => {}
}
}
iter::once(Statement::Label(id_def.new_id(None)))
.chain(func.into_iter().filter(|s| match s {
Statement::Label(i) => labels_in_use.contains(i),
_ => true,
}))
.collect::<Vec<_>>()
}
/// Replaces predicated instructions with explicit conditional branches.
///
/// A predicated instruction becomes `Conditional`, a "then" label, the bare instruction and
/// an "else" label. A predicated `bra` is folded: the conditional jumps straight to the
/// branch target and neither the "then" label nor the instruction is emitted. A negated
/// predicate swaps the two targets of the conditional. Two ids are allocated for every
/// predicated instruction, folded or not.
///
/// # Panics
/// Panics on any statement other than a label, an instruction or a variable.
fn normalize_predicates(
    func: Vec<NormalizedStatement>,
    id_def: &mut NumericIdResolver,
) -> Vec<UnconditionalStatement> {
    let mut result = Vec::with_capacity(func.len());
    for statement in func {
        match statement {
            Statement::Label(id) => result.push(Statement::Label(id)),
            Statement::Variable(var) => result.push(Statement::Variable(var)),
            Statement::Instruction((None, inst)) => result.push(Statement::Instruction(inst)),
            Statement::Instruction((Some(pred), inst)) => {
                let then_label = id_def.new_id(None);
                let else_label = id_def.new_id(None);
                let folded_target = match &inst {
                    ast::Instruction::Bra(_, arg) => Some(arg.src),
                    _ => None,
                };
                let taken = folded_target.unwrap_or(then_label);
                let (if_true, if_false) = if pred.not {
                    (else_label, taken)
                } else {
                    (taken, else_label)
                };
                result.push(Statement::Conditional(BrachCondition {
                    predicate: pred.label,
                    if_true,
                    if_false,
                }));
                if folded_target.is_none() {
                    result.push(Statement::Label(then_label));
                    result.push(Statement::Instruction(inst));
                }
                result.push(Statement::Label(else_label));
            }
            // Blocks are flattened when resolving ids
            _ => unreachable!(),
        }
    }
    result
}
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut MutableNumericIdResolver,
fn_decl: &mut SpirvMethodDecl,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.output.iter() {
match type_to_variable_type(&arg.v_type)? {
Some(var_type) => {
result.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: var_type,
name: arg.name,
array_init: arg.array_init.clone(),
}));
}
None => return Err(TranslateError::Unreachable),
}
}
for arg in fn_decl.input.iter_mut() {
match type_to_variable_type(&arg.v_type)? {
Some(var_type) => {
let typ = arg.v_type.clone();
let new_id = id_def.new_id(typ.clone());
result.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: var_type,
name: arg.name,
array_init: arg.array_init.clone(),
}));
result.push(Statement::StoreVar(
ast::Arg2St {
src1: arg.name,
src2: new_id,
},
typ,
));
arg.name = new_id;
}
None => {}
}
}
for s in func {
match s {
Statement::Call(call) => {
insert_mem_ssa_statement_default(id_def, &mut result, call.cast())?
}
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
// TODO: handle multiple output args
if let &[out_param] = &fn_decl.output.as_slice() {
let typ = id_def.get_typed(out_param.name)?;
let new_id = id_def.new_id(typ.clone());
result.push(Statement::LoadVar(
ast::Arg2 {
dst: new_id,
src: out_param.name,
},
typ.clone(),
));
result.push(Statement::RetValue(d, new_id));
} else {
result.push(Statement::Instruction(ast::Instruction::Ret(d)))
}
}
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
},
Statement::Conditional(mut bra) => {
let generated_id = id_def.new_id(ast::Type::Scalar(ast::ScalarType::Pred));
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
src: bra.predicate,
},
ast::Type::Scalar(ast::ScalarType::Pred),
));
bra.predicate = generated_id;
result.push(Statement::Conditional(bra));
}
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
| Statement::Conversion(_)
| Statement::RetValue(_, _)
| Statement::Constant(_)
| Statement::Undef(_, _)
| Statement::PtrAdd { .. } => {}
Statement::Composite(_) => todo!(),
}
}
Ok(result)
}
/// Maps a type to the register variable type used to declare a variable of that type.
///
/// Scalars, vectors and arrays map to the matching `VariableType::Reg` variant; pointers
/// have no variable type and yield `Ok(None)`.
///
/// # Errors
/// Returns `MismatchedType` when the element type of a vector or an array cannot be
/// converted with `try_into`.
fn type_to_variable_type(t: &ast::Type) -> Result<Option<ast::VariableType>, TranslateError> {
    let reg_type = match t {
        ast::Type::Pointer(_, _) => return Ok(None),
        ast::Type::Scalar(typ) => ast::VariableRegType::Scalar(*typ),
        ast::Type::Vector(typ, len) => ast::VariableRegType::Vector(
            (*typ)
                .try_into()
                .map_err(|_| TranslateError::MismatchedType)?,
            *len,
        ),
        ast::Type::Array(typ, len) => ast::VariableRegType::Array(
            (*typ)
                .try_into()
                .map_err(|_| TranslateError::MismatchedType)?,
            len.clone(),
        ),
    };
    Ok(Some(ast::VariableType::Reg(reg_type)))
}
/// A value that can be turned into a `TypedStatement` while handing its id operands to a
/// callback.
///
/// `insert_mem_ssa_statement_default` drives this trait: it passes a callback `f` that
/// receives an `ArgumentDescriptor` and an optional type and returns an id (either the
/// descriptor's own id or a freshly generated one).
trait VisitVariable: Sized {
// Consumes `self`; an error returned by `f` is expected to surface as this method's error
// (NOTE(review): implementors are outside this section — confirm).
fn visit_variable<
'a,
F: FnMut(
ArgumentDescriptor<spirv::Word>,
Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
f: &mut F,
) -> Result<TypedStatement, TranslateError>;
}
/// Counterpart of `VisitVariable` that produces an `ExpandedStatement`.
///
/// The callback `f` has the same shape: it receives an `ArgumentDescriptor` and an optional
/// type and returns an id or an error.
trait VisitVariableExpanded {
// Consumes `self` and yields the rebuilt statement.
fn visit_variable_extended<
F: FnMut(
ArgumentDescriptor<spirv::Word>,
Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
f: &mut F,
) -> Result<ExpandedStatement, TranslateError>;
}
/// A single-operand visitable: one argument descriptor, its type, and a constructor that
/// builds the final statement from the id returned by the visiting callback.
struct VisitArgumentDescriptor<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> {
// The operand handed to the visiting callback.
desc: ArgumentDescriptor<spirv::Word>,
// The type passed to the callback alongside `desc`.
typ: &'a ast::Type,
// Called once with the id the callback returned.
stmt_ctor: Ctor,
}
impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded
for VisitArgumentDescriptor<'a, Ctor>
{
fn visit_variable_extended<
F: FnMut(
ArgumentDescriptor<spirv::Word>,
Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
f: &mut F,
) -> Result<ExpandedStatement, TranslateError> {
f(self.desc, Some(self.typ)).map(self.stmt_ctor)
}
}
/// Rewrites one statement so that its operands are fresh ids connected to the original ids
/// through `LoadVar` (sources, emitted before the statement) and `StoreVar` (destinations,
/// emitted after the statement).
///
/// An operand is left untouched when the callback receives no type for it, when its
/// semantics are `RegisterPointer` or `Address`, or when its registered type is an array.
///
/// # Errors
/// Propagates type lookup failures and errors returned by `visit_variable`.
fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
    id_def: &mut MutableNumericIdResolver,
    result: &mut Vec<TypedStatement>,
    stmt: F,
) -> Result<(), TranslateError> {
    // Stores must come after the statement that produces the value.
    let mut deferred_stores = Vec::new();
    let rewritten =
        stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, instr_type| {
            if instr_type.is_none() {
                return Ok(desc.op);
            }
            if desc.sema == ArgumentSemantics::RegisterPointer {
                return Ok(desc.op);
            }
            // The lookup happens before the semantics check, so an unknown id is an error
            // even for address operands.
            let id_type = id_def.get_typed(desc.op)?;
            match desc.sema {
                ArgumentSemantics::Address => return Ok(desc.op),
                ArgumentSemantics::RegisterPointer
                | ArgumentSemantics::Default
                | ArgumentSemantics::DefaultRelaxed
                | ArgumentSemantics::PhysicalPointer => {}
            }
            if let ast::Type::Array(_, _) = id_type {
                return Ok(desc.op);
            }
            let generated_id = id_def.new_id(id_type.clone());
            if desc.is_dst {
                deferred_stores.push(Statement::StoreVar(
                    Arg2St {
                        src1: desc.op,
                        src2: generated_id,
                    },
                    id_type,
                ));
            } else {
                result.push(Statement::LoadVar(
                    Arg2 {
                        dst: generated_id,
                        src: desc.op,
                    },
                    id_type,
                ));
            }
            Ok(generated_id)
        })?;
    result.push(rewritten);
    result.append(&mut deferred_stores);
    Ok(())
}
/// Flattens the operands of calls, instructions and `PtrAdd` statements with
/// `FlattenArguments`, so that every operand of the resulting statements is a plain id.
///
/// Helper statements generated for an operand are emitted before the statement that uses it;
/// statements collected in the visitor's `post_stmts` are emitted right after it (for calls
/// and instructions). Labels, conditionals, variables, `LoadVar`, `StoreVar` and `RetValue`
/// are carried over unchanged.
///
/// # Errors
/// Propagates errors from the argument visitor.
///
/// # Panics
/// Panics on `Composite`, `Conversion`, `Constant` and `Undef` statements in the input.
fn expand_arguments<'a, 'b>(
    func: Vec<TypedStatement>,
    id_def: &'b mut MutableNumericIdResolver<'a>,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
    let mut result = Vec::with_capacity(func.len());
    for statement in func {
        match statement {
            Statement::Label(id) => result.push(Statement::Label(id)),
            Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
            Statement::Variable(var) => result.push(Statement::Variable(var)),
            Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
            Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
            Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
            Statement::Call(call) => {
                let mut flattener = FlattenArguments::new(&mut result, id_def);
                let flat_call = call.map(&mut flattener)?;
                let trailing = flattener.post_stmts;
                result.push(Statement::Call(flat_call));
                result.extend(trailing);
            }
            Statement::Instruction(inst) => {
                let mut flattener = FlattenArguments::new(&mut result, id_def);
                let flat_inst = inst.map(&mut flattener)?;
                let trailing = flattener.post_stmts;
                result.push(Statement::Instruction(flat_inst));
                result.extend(trailing);
            }
            Statement::PtrAdd {
                underlying_type,
                state_space,
                dst,
                ptr_src,
                constant_src,
            } => {
                let mut flattener = FlattenArguments::new(&mut result, id_def);
                let pointer_sema = match state_space {
                    ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
                        ArgumentSemantics::RegisterPointer
                    }
                    ast::LdStateSpace::Const
                    | ast::LdStateSpace::Global
                    | ast::LdStateSpace::Shared
                    | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
                };
                let ptr_type = ast::Type::Pointer(underlying_type.clone(), state_space);
                // Visit order: destination, pointer source, constant source.
                let dst_desc = ArgumentDescriptor {
                    op: dst,
                    is_dst: true,
                    sema: pointer_sema,
                };
                let new_dst = flattener.id(dst_desc, Some(&ptr_type))?;
                let ptr_desc = ArgumentDescriptor {
                    op: ptr_src,
                    is_dst: false,
                    sema: pointer_sema,
                };
                let new_ptr_src = flattener.id(ptr_desc, Some(&ptr_type))?;
                let constant_desc = ArgumentDescriptor {
                    op: constant_src,
                    is_dst: false,
                    sema: ArgumentSemantics::Default,
                };
                let new_constant_src = flattener.id(
                    constant_desc,
                    Some(&ast::Type::Scalar(ast::ScalarType::S64)),
                )?;
                result.push(Statement::PtrAdd {
                    underlying_type,
                    state_space,
                    dst: new_dst,
                    ptr_src: new_ptr_src,
                    constant_src: new_constant_src,
                })
            }
            Statement::Composite(_)
            | Statement::Conversion(_)
            | Statement::Constant(_)
            | Statement::Undef(_, _) => unreachable!(),
        }
    }
    Ok(result)
}
/// Argument visitor used by `expand_arguments` to turn complex operands (immediates,
/// register+offset, vector members, vector literals) into plain ids.
struct FlattenArguments<'a, 'b> {
// Output statement list; helper statements that must precede the visited statement are
// pushed here.
func: &'b mut Vec<ExpandedStatement>,
// Source of fresh ids and of the types of existing ids.
id_def: &'b mut MutableNumericIdResolver<'a>,
// Statements that must follow the visited statement; `expand_arguments` appends them
// after pushing a call or an instruction.
post_stmts: Vec<ExpandedStatement>,
}
impl<'a, 'b> FlattenArguments<'a, 'b> {
    /// Creates a visitor that emits helper statements into `func` and starts with no
    /// post-statements.
    fn new(
        func: &'b mut Vec<ExpandedStatement>,
        id_def: &'b mut MutableNumericIdResolver<'a>,
    ) -> Self {
        FlattenArguments {
            func,
            id_def,
            post_stmts: Vec::new(),
        }
    }

    /// Pushes a `Composite` read of element `composite_src.1` of `composite_src.0` into
    /// `func` and returns the destination id: `scalar_dst` when given, a fresh scalar id
    /// otherwise.
    fn insert_composite_read(
        func: &mut Vec<ExpandedStatement>,
        id_def: &mut MutableNumericIdResolver<'a>,
        typ: (ast::ScalarType, u8),
        scalar_dst: Option<spirv::Word>,
        scalar_sema_override: Option<ArgumentSemantics>,
        composite_src: (spirv::Word, u8),
    ) -> spirv::Word {
        let (scalar_type, vector_len) = typ;
        let (src_composite, src_index) = composite_src;
        let dst = match scalar_dst {
            Some(id) => id,
            None => id_def.new_id(ast::Type::Scalar(scalar_type)),
        };
        func.push(Statement::Composite(CompositeRead {
            typ: scalar_type,
            dst,
            dst_semantics_override: scalar_sema_override,
            src_composite,
            src_index: src_index as u32,
            src_len: vector_len as u32,
        }));
        dst
    }

    /// A plain register needs no flattening: its id is returned as-is.
    fn reg(
        &mut self,
        desc: ArgumentDescriptor<spirv::Word>,
        _: Option<&ast::Type>,
    ) -> Result<spirv::Word, TranslateError> {
        Ok(desc.op)
    }

    /// Materializes `reg + offset` and returns the id holding the sum.
    ///
    /// When both the expected type and the type of `reg` are pointers, an S64 constant and
    /// a `PtrAdd` are emitted. Otherwise an integer constant followed by `add` is emitted
    /// (`sub` of the negated offset when the offset is negative and the type is not signed).
    ///
    /// # Errors
    /// Returns `MismatchedType` when the arithmetic type is not an integer/bit scalar and
    /// propagates type lookup failures.
    fn reg_offset(
        &mut self,
        desc: ArgumentDescriptor<(spirv::Word, i32)>,
        typ: &ast::Type,
    ) -> Result<spirv::Word, TranslateError> {
        let (reg, offset) = desc.op;
        let add_type = match typ {
            ast::Type::Pointer(underlying_type, state_space) => {
                let reg_typ = self.id_def.get_typed(reg)?;
                if let ast::Type::Pointer(_, _) = reg_typ {
                    // NOTE(review): the id of the S64 constant is registered with the
                    // pointer type `typ`, not with S64 — confirm this is intended.
                    let offset_id = self.id_def.new_id(typ.clone());
                    self.func.push(Statement::Constant(ConstantDefinition {
                        dst: offset_id,
                        typ: ast::ScalarType::S64,
                        value: ast::ImmediateValue::S64(offset as i64),
                    }));
                    let dst = self.id_def.new_id(typ.clone());
                    self.func.push(Statement::PtrAdd {
                        underlying_type: underlying_type.clone(),
                        state_space: *state_space,
                        dst,
                        ptr_src: reg,
                        constant_src: offset_id,
                    });
                    return Ok(dst);
                }
                // The register is not a pointer: do the arithmetic in the register's type.
                self.id_def.get_typed(reg)?
            }
            _ => typ.clone(),
        };
        let scalar_t = match add_type {
            ast::Type::Scalar(scalar_t) => scalar_t,
            _ => return Err(TranslateError::MismatchedType),
        };
        let kind = match scalar_t.kind() {
            ScalarKind::Float | ScalarKind::Float2 | ScalarKind::Pred => {
                return Err(TranslateError::MismatchedType)
            }
            integer_kind @ ScalarKind::Bit
            | integer_kind @ ScalarKind::Unsigned
            | integer_kind @ ScalarKind::Signed => integer_kind,
        };
        let width = scalar_t.size_of();
        let arith_detail = if kind == ScalarKind::Signed {
            ast::ArithDetails::Signed(ast::ArithSInt {
                typ: ast::SIntType::from_size(width),
                saturate: false,
            })
        } else {
            ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
        };
        let constant_id = self.id_def.new_id(add_type.clone());
        let result_id = self.id_def.new_id(add_type);
        let operands = ast::Arg3 {
            dst: result_id,
            src1: reg,
            src2: constant_id,
        };
        // TODO: check for edge cases around min value/max value/wrapping
        let (constant_value, arithmetic) = if offset < 0 && kind != ScalarKind::Signed {
            // Non-signed type with a negative offset: subtract the magnitude instead.
            (
                ast::ImmediateValue::U64(-(offset as i64) as u64),
                ast::Instruction::<ExpandedArgParams>::Sub(arith_detail, operands),
            )
        } else {
            (
                ast::ImmediateValue::S64(offset as i64),
                ast::Instruction::<ExpandedArgParams>::Add(arith_detail, operands),
            )
        };
        self.func.push(Statement::Constant(ConstantDefinition {
            dst: constant_id,
            typ: ast::ScalarType::from_parts(width, kind),
            value: constant_value,
        }));
        self.func.push(Statement::Instruction(arithmetic));
        Ok(result_id)
    }

    /// Emits a `Constant` statement for an immediate operand and returns its fresh id.
    ///
    /// # Panics
    /// Panics (`todo!`) when the expected type is not a scalar.
    fn immediate(
        &mut self,
        desc: ArgumentDescriptor<ast::ImmediateValue>,
        typ: &ast::Type,
    ) -> Result<spirv::Word, TranslateError> {
        let scalar_t = match typ {
            ast::Type::Scalar(scalar) => *scalar,
            _ => todo!(),
        };
        let id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
        self.func.push(Statement::Constant(ConstantDefinition {
            dst: id,
            typ: scalar_t,
            value: desc.op,
        }));
        Ok(id)
    }

    /// Reads one vector member into a fresh scalar id.
    ///
    /// # Errors
    /// Returns `Unreachable` when the descriptor is a destination.
    fn member_src(
        &mut self,
        desc: ArgumentDescriptor<(spirv::Word, u8)>,
        typ: (ast::ScalarType, u8),
    ) -> Result<spirv::Word, TranslateError> {
        if desc.is_dst {
            return Err(TranslateError::Unreachable);
        }
        Ok(Self::insert_composite_read(
            self.func,
            self.id_def,
            typ,
            None,
            Some(desc.sema),
            desc.op,
        ))
    }

    /// Flattens a vector literal into a single vector id.
    ///
    /// As a source, the vector is built up front: an `Undef` followed by one member `mov`
    /// per element, each producing a new vector id. As a destination, a fresh vector id is
    /// returned and one composite read per element is queued in `post_stmts`.
    ///
    /// # Errors
    /// Propagates the error of `get_vector` when `typ` is not a vector type.
    fn vector(
        &mut self,
        desc: ArgumentDescriptor<&Vec<spirv::Word>>,
        typ: &ast::Type,
    ) -> Result<spirv::Word, TranslateError> {
        let (scalar_type, vec_len) = typ.get_vector()?;
        if desc.is_dst {
            let composite = self.id_def.new_id(typ.clone());
            for (idx, id) in desc.op.iter().enumerate() {
                Self::insert_composite_read(
                    &mut self.post_stmts,
                    self.id_def,
                    (scalar_type, vec_len),
                    Some(*id),
                    Some(desc.sema),
                    (composite, idx as u8),
                );
            }
            return Ok(composite);
        }
        let mut current = self.id_def.new_id(typ.clone());
        self.func.push(Statement::Undef(typ.clone(), current));
        for (idx, id) in desc.op.iter().enumerate() {
            let next = self.id_def.new_id(typ.clone());
            let details = ast::MovDetails {
                typ: ast::Type::Scalar(scalar_type),
                src_is_address: false,
                dst_width: vec_len,
                src_width: 0,
                relaxed_src2_conv: desc.sema == ArgumentSemantics::DefaultRelaxed,
            };
            let member_write =
                ast::Arg2Mov::Member(ast::Arg2MovMember::Dst((next, idx as u8), current, *id));
            self.func
                .push(Statement::Instruction(ast::Instruction::Mov(
                    details,
                    member_write,
                )));
            current = next;
        }
        Ok(current)
    }
}
/// Dispatches every operand shape to the matching flattening helper of `FlattenArguments`.
impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenArguments<'a, 'b> {
    /// Plain ids are passed through `reg`.
    fn id(
        &mut self,
        desc: ArgumentDescriptor<spirv::Word>,
        t: Option<&ast::Type>,
    ) -> Result<spirv::Word, TranslateError> {
        self.reg(desc, t)
    }

    /// Register, immediate or register+offset operand.
    fn operand(
        &mut self,
        desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
        typ: &ast::Type,
    ) -> Result<spirv::Word, TranslateError> {
        match desc.op {
            ast::Operand::Imm(value) => self.immediate(desc.new_op(value), typ),
            ast::Operand::RegOffset(base, offset) => {
                let offset_desc = desc.new_op((base, offset));
                self.reg_offset(offset_desc, typ)
            }
            ast::Operand::Reg(id) => self.reg(desc.new_op(id), Some(typ)),
        }
    }

    /// Register or immediate operand of a call.
    fn src_call_operand(
        &mut self,
        desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
        typ: &ast::Type,
    ) -> Result<spirv::Word, TranslateError> {
        match desc.op {
            ast::CallOperand::Imm(value) => self.immediate(desc.new_op(value), typ),
            ast::CallOperand::Reg(id) => self.reg(desc.new_op(id), Some(typ)),
        }
    }

    /// Vector member used as a source.
    fn src_member_operand(
        &mut self,
        desc: ArgumentDescriptor<(spirv::Word, u8)>,
        typ: (ast::ScalarType, u8),
    ) -> Result<spirv::Word, TranslateError> {
        self.member_src(desc, typ)
    }

    /// Register or vector literal.
    fn id_or_vector(
        &mut self,
        desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
        typ: &ast::Type,
    ) -> Result<spirv::Word, TranslateError> {
        match desc.op {
            ast::IdOrVector::Vec(ref elements) => self.vector(desc.new_op(elements), typ),
            ast::IdOrVector::Reg(id) => self.reg(desc.new_op(id), Some(typ)),
        }
    }

    /// Register, register+offset, immediate or vector literal.
    fn operand_or_vector(
        &mut self,
        desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
        typ: &ast::Type,
    ) -> Result<spirv::Word, TranslateError> {
        match desc.op {
            ast::OperandOrVector::Vec(ref elements) => self.vector(desc.new_op(elements), typ),
            ast::OperandOrVector::Imm(value) => self.immediate(desc.new_op(value), typ),
            ast::OperandOrVector::RegOffset(base, offset) => {
                let offset_desc = desc.new_op((base, offset));
                self.reg_offset(offset_desc, typ)
            }
            ast::OperandOrVector::Reg(id) => self.reg(desc.new_op(id), Some(typ)),
        }
    }
}
/*
There are several kinds of implicit conversions in PTX:
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
semantics are to first zext/chop/bitcast `y` as needed and then do
documented special ld/st/cvt conversion rules for destination operands
- st.param: for instruction `st.param [x], y` (used for function return arguments), the same rule as above applies
- generic/global ld: for instruction `ld x, [y]`, y must be of type
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
documented special ld/st/cvt conversion rules are applied to dst
- generic/global st: for instruction `st [x], y`, x must be of type
b64/u64/s64, which is bitcast to a pointer
*/
fn insert_implicit_conversions(
func: Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
match s {
Statement::Call(call) => insert_implicit_conversions_impl(
&mut result,
id_def,
call,
should_bitcast_wrapper,
None,
)?,
Statement::Instruction(inst) => {
let mut default_conversion_fn =
should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _;
let mut state_space = None;
if let ast::Instruction::Ld(d, _) = &inst {
state_space = Some(d.state_space);
}
if let ast::Instruction::St(d, _) = &inst {
state_space = Some(d.state_space.to_ld_ss());
}
if let ast::Instruction::Atom(d, _) = &inst {
state_space = Some(d.space.to_ld_ss());
}
if let ast::Instruction::AtomCas(d, _) = &inst {
state_space = Some(d.space.to_ld_ss());
}
if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst {
default_conversion_fn = should_bitcast_packed;
}
insert_implicit_conversions_impl(
&mut result,
id_def,
inst,
default_conversion_fn,
state_space,
)?;
}
Statement::Composite(composite) => insert_implicit_conversions_impl(
&mut result,
id_def,
composite,
should_bitcast_wrapper,
None,
)?,
Statement::PtrAdd {
underlying_type,
state_space,
dst,
ptr_src,
constant_src,
} => {
let visit_desc = VisitArgumentDescriptor {
desc: ArgumentDescriptor {
op: ptr_src,
is_dst: false,
sema: ArgumentSemantics::PhysicalPointer,
},
typ: &ast::Type::Pointer(underlying_type.clone(), state_space),
stmt_ctor: |new_ptr_src| Statement::PtrAdd {
underlying_type,
state_space,
dst,
ptr_src: new_ptr_src,
constant_src,
},
};
insert_implicit_conversions_impl(
&mut result,
id_def,
visit_desc,
bitcast_physical_pointer,
Some(state_space),
)?;
}
s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
| s @ Statement::LoadVar(_, _)
| s @ Statement::StoreVar(_, _)
| s @ Statement::Undef(_, _)
| s @ Statement::RetValue(_, _) => result.push(s),
Statement::Conversion(_) => unreachable!(),
}
}
Ok(result)
}
/// Visits every typed operand of `stmt` and inserts a conversion wherever the
/// operand's declared type and the type expected by the statement disagree.
///
/// For a source operand the conversion is pushed to `func` before the statement;
/// for a destination operand it is queued and appended after the statement. In
/// both cases the statement itself is rewritten to use a freshly created id of the
/// instruction's type.
///
/// * `func` - output statement list.
/// * `id_def` - resolves operand types and creates new ids.
/// * `stmt` - statement whose operands are visited.
/// * `default_conversion_fn` - rule used for operands with `ArgumentSemantics::Default`.
/// * `state_space` - state space handed to the conversion rule.
fn insert_implicit_conversions_impl(
    func: &mut Vec<ExpandedStatement>,
    id_def: &mut MutableNumericIdResolver,
    stmt: impl VisitVariableExpanded,
    default_conversion_fn: for<'a> fn(
        &'a ast::Type,
        &'a ast::Type,
        Option<ast::LdStateSpace>,
    ) -> Result<Option<ConversionKind>, TranslateError>,
    state_space: Option<ast::LdStateSpace>,
) -> Result<(), TranslateError> {
    // Conversions of destination operands; they must run after the statement.
    let mut post_conv: Vec<ExpandedStatement> = Vec::new();
    let statement = stmt.visit_variable_extended(&mut |desc, typ| {
        // Untyped operands are never converted.
        let instr_type = match typ {
            Some(t) => t,
            None => return Ok(desc.op),
        };
        let operand_type = id_def.get_typed(desc.op)?;
        // The operand's semantics pick the conversion rule.
        let conversion_fn: for<'a> fn(
            &'a ast::Type,
            &'a ast::Type,
            Option<ast::LdStateSpace>,
        ) -> Result<Option<ConversionKind>, TranslateError> = match desc.sema {
            ArgumentSemantics::Default => default_conversion_fn,
            ArgumentSemantics::DefaultRelaxed if desc.is_dst => {
                should_convert_relaxed_dst_wrapper
            }
            ArgumentSemantics::DefaultRelaxed => should_convert_relaxed_src_wrapper,
            ArgumentSemantics::PhysicalPointer => bitcast_physical_pointer,
            ArgumentSemantics::RegisterPointer => bitcast_register_pointer,
            ArgumentSemantics::Address => force_bitcast_ptr_to_bit,
        };
        let conv_kind = match conversion_fn(&operand_type, instr_type, state_space)? {
            Some(kind) => kind,
            None => return Ok(desc.op),
        };
        // The statement will operate on this temporary, which has the instruction's type.
        let converted_id = id_def.new_id(instr_type.clone());
        if desc.is_dst {
            // temporary (instruction type) -> original operand, after the statement
            post_conv.push(Statement::Conversion(ImplicitConversion {
                src: converted_id,
                dst: desc.op,
                from: instr_type.clone(),
                to: operand_type,
                kind: conv_kind,
            }));
        } else {
            // original operand -> temporary (instruction type), before the statement
            func.push(Statement::Conversion(ImplicitConversion {
                src: desc.op,
                dst: converted_id,
                from: operand_type,
                to: instr_type.clone(),
                kind: conv_kind,
            }));
        }
        Ok(converted_id)
    })?;
    func.push(statement);
    func.append(&mut post_conv);
    Ok(())
}
/// Looks up (or registers) the SPIR-V function type for the given variables.
///
/// The type of every variable in `spirv_input` and `spirv_output` is converted to
/// a `SpirvType`, and both sequences are handed to `TypeWordMap::get_or_add_fn` in
/// that order. Its pair of ids is returned unchanged.
fn get_function_type(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    spirv_input: &[ast::Variable<ast::Type, spirv::Word>],
    spirv_output: &[ast::Variable<ast::Type, spirv::Word>],
) -> (spirv::Word, spirv::Word) {
    let input_types = spirv_input
        .iter()
        .map(|param| SpirvType::from(param.v_type.clone()));
    let output_types = spirv_output
        .iter()
        .map(|param| SpirvType::from(param.v_type.clone()));
    map.get_or_add_fn(builder, input_types, output_types)
}
/// Emits SPIR-V for every statement of an already lowered function body.
///
/// Statements are processed in order. A `Statement::Label` closes the currently
/// open block (if any) with a branch to that label and opens a new block; any
/// other statement first opens an anonymous block when a function is being built
/// but no block is open.
///
/// * `builder` - SPIR-V builder that receives the instructions.
/// * `map` - cache used to obtain type and constant ids.
/// * `opencl` - id passed as the instruction set of every `ext_inst` call
///   (the operations used come from `spirv::CLOp`).
/// * `func` - statements to emit.
///
/// # Errors
/// Returns `TranslateError::MismatchedType` for a constant whose type/value
/// combination is not handled, and propagates any error raised by the builder or
/// by the helper emitters.
///
/// # Panics
/// Unsupported forms (`todo!()`) and `ast::Instruction::Call` (`unreachable!()`)
/// panic.
fn emit_function_body_ops(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    func: &[ExpandedStatement],
) -> Result<(), TranslateError> {
    for s in func {
        // Block bookkeeping: every instruction must live inside a block.
        match s {
            Statement::Label(id) => {
                if builder.block.is_some() {
                    builder.branch(*id)?;
                }
                builder.begin_block(Some(*id))?;
            }
            _ => {
                if builder.block.is_none() && builder.function.is_some() {
                    builder.begin_block(None)?;
                }
            }
        }
        match s {
            Statement::Label(_) => (),
            Statement::Call(call) => {
                // At most one return value is supported.
                let (result_type, result_id) = match &*call.ret_params {
                    [(id, typ)] => (
                        map.get_or_add(builder, SpirvType::from(typ.to_func_type())),
                        Some(*id),
                    ),
                    [] => (map.void(), None),
                    _ => todo!(),
                };
                let arg_list = call
                    .param_list
                    .iter()
                    .map(|(id, _)| *id)
                    .collect::<Vec<_>>();
                builder.function_call(result_type, result_id, call.func, arg_list)?;
            }
            Statement::Variable(var) => {
                emit_variable(builder, map, var)?;
            }
            Statement::Constant(cnst) => {
                let typ_id = map.get_or_add_scalar(builder, cnst.typ);
                // The immediate is first narrowed to the width of the target type and
                // then widened to the literal width expected by the builder call.
                match (cnst.typ, cnst.value) {
                    (ast::ScalarType::B8, ast::ImmediateValue::U64(value))
                    | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => {
                        builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32);
                    }
                    (ast::ScalarType::B16, ast::ImmediateValue::U64(value))
                    | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => {
                        builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32);
                    }
                    (ast::ScalarType::B32, ast::ImmediateValue::U64(value))
                    | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => {
                        builder.constant_u32(typ_id, Some(cnst.dst), value as u32);
                    }
                    (ast::ScalarType::B64, ast::ImmediateValue::U64(value))
                    | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => {
                        builder.constant_u64(typ_id, Some(cnst.dst), value);
                    }
                    (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => {
                        builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32);
                    }
                    (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => {
                        builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32);
                    }
                    (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => {
                        builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32);
                    }
                    (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => {
                        builder.constant_u64(typ_id, Some(cnst.dst), value as i64 as u64);
                    }
                    (ast::ScalarType::B8, ast::ImmediateValue::S64(value))
                    | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => {
                        builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32);
                    }
                    (ast::ScalarType::B16, ast::ImmediateValue::S64(value))
                    | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => {
                        builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32);
                    }
                    (ast::ScalarType::B32, ast::ImmediateValue::S64(value))
                    | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => {
                        builder.constant_u32(typ_id, Some(cnst.dst), value as u32);
                    }
                    (ast::ScalarType::B64, ast::ImmediateValue::S64(value))
                    | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => {
                        builder.constant_u64(typ_id, Some(cnst.dst), value as u64);
                    }
                    (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => {
                        builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32);
                    }
                    (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => {
                        builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32);
                    }
                    (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => {
                        builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32);
                    }
                    (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => {
                        builder.constant_u64(typ_id, Some(cnst.dst), value as u64);
                    }
                    (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => {
                        builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f32(value).to_f32());
                    }
                    (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => {
                        builder.constant_f32(typ_id, Some(cnst.dst), value);
                    }
                    (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => {
                        builder.constant_f64(typ_id, Some(cnst.dst), value as f64);
                    }
                    (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => {
                        builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f64(value).to_f32());
                    }
                    (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => {
                        builder.constant_f32(typ_id, Some(cnst.dst), value as f32);
                    }
                    (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => {
                        builder.constant_f64(typ_id, Some(cnst.dst), value);
                    }
                    // Predicates: zero is false, anything else is true.
                    (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => {
                        let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
                        if value == 0 {
                            builder.constant_false(bool_type, Some(cnst.dst));
                        } else {
                            builder.constant_true(bool_type, Some(cnst.dst));
                        }
                    }
                    (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => {
                        let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
                        if value == 0 {
                            builder.constant_false(bool_type, Some(cnst.dst));
                        } else {
                            builder.constant_true(bool_type, Some(cnst.dst));
                        }
                    }
                    _ => return Err(TranslateError::MismatchedType),
                }
            }
            Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
            Statement::Conditional(bra) => {
                builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
            }
            Statement::Instruction(inst) => match inst {
                ast::Instruction::Abs(d, arg) => emit_abs(builder, map, opencl, d, arg)?,
                ast::Instruction::Call(_) => unreachable!(),
                // SPIR-V does not support marking jumps as guaranteed-converged
                ast::Instruction::Bra(_, arg) => {
                    builder.branch(arg.src)?;
                }
                ast::Instruction::Ld(data, arg) => {
                    if data.qualifier != ast::LdStQualifier::Weak {
                        todo!()
                    }
                    let result_type =
                        map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone())));
                    builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
                }
                ast::Instruction::St(data, arg) => {
                    if data.qualifier != ast::LdStQualifier::Weak {
                        todo!()
                    }
                    builder.store(arg.src1, arg.src2, None, &[])?;
                }
                // SPIR-V does not support ret as guaranteed-converged
                ast::Instruction::Ret(_) => builder.ret()?,
                ast::Instruction::Mov(d, arg) => match arg {
                    ast::Arg2Mov::Normal(ast::Arg2MovNormal { dst, src })
                    | ast::Arg2Mov::Member(ast::Arg2MovMember::Src(dst, src)) => {
                        let result_type = map
                            .get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone())));
                        builder.copy_object(result_type, Some(*dst), *src)?;
                    }
                    // Writing a single vector member: insert the scalar into the composite.
                    ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
                        dst,
                        composite_src,
                        scalar_src,
                    ))
                    | ast::Arg2Mov::Member(ast::Arg2MovMember::Both(
                        dst,
                        composite_src,
                        scalar_src,
                    )) => {
                        let scalar_type = d.typ.get_scalar()?;
                        let result_type = map.get_or_add(
                            builder,
                            SpirvType::from(ast::Type::Vector(scalar_type, d.dst_width)),
                        );
                        let result_id = Some(dst.0);
                        builder.composite_insert(
                            result_type,
                            result_id,
                            *scalar_src,
                            *composite_src,
                            [dst.1 as u32],
                        )?;
                    }
                },
                ast::Instruction::Mul(mul, arg) => match mul {
                    ast::MulDetails::Signed(ref ctr) => {
                        emit_mul_sint(builder, map, opencl, ctr, arg)?
                    }
                    ast::MulDetails::Unsigned(ref ctr) => {
                        emit_mul_uint(builder, map, opencl, ctr, arg)?
                    }
                    ast::MulDetails::Float(ref ctr) => emit_mul_float(builder, map, ctr, arg)?,
                },
                ast::Instruction::Add(add, arg) => match add {
                    ast::ArithDetails::Signed(ref desc) => {
                        emit_add_int(builder, map, desc.typ.into(), desc.saturate, arg)?
                    }
                    ast::ArithDetails::Unsigned(ref desc) => {
                        emit_add_int(builder, map, (*desc).into(), false, arg)?
                    }
                    ast::ArithDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?,
                },
                ast::Instruction::Setp(setp, arg) => {
                    if arg.dst2.is_some() {
                        todo!()
                    }
                    emit_setp(builder, map, setp, arg)?;
                }
                ast::Instruction::Not(t, a) => {
                    let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
                    let result_id = Some(a.dst);
                    let operand = a.src;
                    match t {
                        ast::BooleanType::Pred => {
                            // HACK: temporary workaround for an IGC issue.
                            // IGC currently carries two copies of the SPIRV-LLVM translator:
                            // a new one in /llvm-spirv/ and an old one in /IGC/AdaptorOCL/SPIRV/.
                            // The old, buggy one is used for compiling L0 SPIR-V, so logical
                            // negation is expressed as a select between constants instead.
                            // https://github.com/intel/intel-graphics-compiler/issues/148
                            let type_pred = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
                            let const_true = builder.constant_true(type_pred, None);
                            let const_false = builder.constant_false(type_pred, None);
                            builder.select(result_type, result_id, operand, const_false, const_true)
                        }
                        _ => builder.not(result_type, result_id, operand),
                    }?;
                }
                ast::Instruction::Shl(t, a) => {
                    let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
                    builder.shift_left_logical(result_type, Some(a.dst), a.src1, a.src2)?;
                }
                ast::Instruction::Shr(t, a) => {
                    let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
                    if t.signed() {
                        builder.shift_right_arithmetic(result_type, Some(a.dst), a.src1, a.src2)?;
                    } else {
                        builder.shift_right_logical(result_type, Some(a.dst), a.src1, a.src2)?;
                    }
                }
                ast::Instruction::Cvt(dets, arg) => {
                    emit_cvt(builder, map, opencl, dets, arg)?;
                }
                ast::Instruction::Cvta(_, arg) => {
                    // This would be only meaningful if const/slm/global pointers
                    // had a different format than generic pointers, but they don't pretty much by ptx definition
                    // Honestly, I have no idea why this instruction exists and is emitted by the compiler
                    let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
                    builder.copy_object(result_type, Some(arg.dst), arg.src)?;
                }
                ast::Instruction::SetpBool(_, _) => todo!(),
                ast::Instruction::Mad(mad, arg) => match mad {
                    ast::MulDetails::Signed(ref desc) => {
                        emit_mad_sint(builder, map, opencl, desc, arg)?
                    }
                    ast::MulDetails::Unsigned(ref desc) => {
                        emit_mad_uint(builder, map, opencl, desc, arg)?
                    }
                    ast::MulDetails::Float(desc) => {
                        emit_mad_float(builder, map, opencl, desc, arg)?
                    }
                },
                ast::Instruction::Or(t, a) => {
                    let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
                    if *t == ast::BooleanType::Pred {
                        builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?;
                    } else {
                        builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
                    }
                }
                ast::Instruction::Sub(d, arg) => match d {
                    ast::ArithDetails::Signed(desc) => {
                        emit_sub_int(builder, map, desc.typ.into(), desc.saturate, arg)?;
                    }
                    ast::ArithDetails::Unsigned(desc) => {
                        emit_sub_int(builder, map, (*desc).into(), false, arg)?;
                    }
                    ast::ArithDetails::Float(desc) => {
                        emit_sub_float(builder, map, desc, arg)?;
                    }
                },
                ast::Instruction::Min(d, a) => {
                    emit_min(builder, map, opencl, d, a)?;
                }
                ast::Instruction::Max(d, a) => {
                    emit_max(builder, map, opencl, d, a)?;
                }
                ast::Instruction::Rcp(d, a) => {
                    emit_rcp(builder, map, d, a)?;
                }
                ast::Instruction::And(t, a) => {
                    let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
                    if *t == ast::BooleanType::Pred {
                        builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?;
                    } else {
                        builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?;
                    }
                }
                ast::Instruction::Selp(t, a) => {
                    let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
                    // `selp d, a, b, c` yields `a` when predicate `c` is true and `b`
                    // otherwise, so OpSelect takes (condition = src3, src1, src2).
                    builder.select(result_type, Some(a.dst), a.src3, a.src1, a.src2)?;
                }
                // TODO: implement named barriers
                ast::Instruction::Bar(d, _) => {
                    let workgroup_scope = map.get_or_add_constant(
                        builder,
                        &ast::Type::Scalar(ast::ScalarType::U32),
                        &vec_repr(spirv::Scope::Workgroup as u32),
                    )?;
                    let barrier_semantics = match d {
                        ast::BarDetails::SyncAligned => map.get_or_add_constant(
                            builder,
                            &ast::Type::Scalar(ast::ScalarType::U32),
                            &vec_repr(
                                spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
                                    | spirv::MemorySemantics::WORKGROUP_MEMORY
                                    | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
                            ),
                        )?,
                    };
                    // The workgroup scope serves as both execution and memory scope.
                    builder.control_barrier(workgroup_scope, workgroup_scope, barrier_semantics)?;
                }
                ast::Instruction::Atom(details, arg) => {
                    emit_atom(builder, map, details, arg)?;
                }
                ast::Instruction::AtomCas(details, arg) => {
                    let result_type = map.get_or_add_scalar(builder, details.typ.into());
                    let memory_const = map.get_or_add_constant(
                        builder,
                        &ast::Type::Scalar(ast::ScalarType::U32),
                        &vec_repr(details.scope.to_spirv() as u32),
                    )?;
                    let semantics_const = map.get_or_add_constant(
                        builder,
                        &ast::Type::Scalar(ast::ScalarType::U32),
                        &vec_repr(details.semantics.to_spirv().bits()),
                    )?;
                    // The same semantics are used for the "equal" and "unequal" cases.
                    // The last two arguments are passed as (src3, src2), i.e. swapped
                    // relative to PTX operand order.
                    builder.atomic_compare_exchange(
                        result_type,
                        Some(arg.dst),
                        arg.src1,
                        memory_const,
                        semantics_const,
                        semantics_const,
                        arg.src3,
                        arg.src2,
                    )?;
                }
                ast::Instruction::Div(details, arg) => match details {
                    ast::DivDetails::Unsigned(t) => {
                        let result_type = map.get_or_add_scalar(builder, (*t).into());
                        builder.u_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
                    }
                    ast::DivDetails::Signed(t) => {
                        let result_type = map.get_or_add_scalar(builder, (*t).into());
                        builder.s_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
                    }
                    ast::DivDetails::Float(t) => {
                        let result_type = map.get_or_add_scalar(builder, t.typ.into());
                        builder.f_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
                        emit_float_div_decoration(builder, arg.dst, t.kind);
                    }
                },
                ast::Instruction::Sqrt(details, a) => {
                    emit_sqrt(builder, map, opencl, details, a)?;
                }
                ast::Instruction::Rsqrt(details, a) => {
                    let result_type = map.get_or_add_scalar(builder, details.typ.into());
                    builder.ext_inst(
                        result_type,
                        Some(a.dst),
                        opencl,
                        spirv::CLOp::native_rsqrt as spirv::Word,
                        &[a.src],
                    )?;
                }
                ast::Instruction::Neg(details, arg) => {
                    let result_type = map.get_or_add_scalar(builder, details.typ);
                    let negate_func = if details.typ.kind() == ScalarKind::Float {
                        dr::Builder::f_negate
                    } else {
                        dr::Builder::s_negate
                    };
                    negate_func(builder, result_type, Some(arg.dst), arg.src)?;
                }
                ast::Instruction::Sin { arg, .. } => {
                    let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
                    builder.ext_inst(
                        result_type,
                        Some(arg.dst),
                        opencl,
                        spirv::CLOp::sin as u32,
                        [arg.src],
                    )?;
                }
                ast::Instruction::Cos { arg, .. } => {
                    let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
                    builder.ext_inst(
                        result_type,
                        Some(arg.dst),
                        opencl,
                        spirv::CLOp::cos as u32,
                        [arg.src],
                    )?;
                }
                ast::Instruction::Lg2 { arg, .. } => {
                    let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
                    builder.ext_inst(
                        result_type,
                        Some(arg.dst),
                        opencl,
                        spirv::CLOp::log2 as u32,
                        [arg.src],
                    )?;
                }
                ast::Instruction::Ex2 { arg, .. } => {
                    let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
                    builder.ext_inst(
                        result_type,
                        Some(arg.dst),
                        opencl,
                        spirv::CLOp::exp2 as u32,
                        [arg.src],
                    )?;
                }
                ast::Instruction::Clz { typ, arg } => {
                    let result_type = map.get_or_add_scalar(builder, (*typ).into());
                    builder.ext_inst(
                        result_type,
                        Some(arg.dst),
                        opencl,
                        spirv::CLOp::clz as u32,
                        [arg.src],
                    )?;
                }
                ast::Instruction::Brev { typ, arg } => {
                    let result_type = map.get_or_add_scalar(builder, (*typ).into());
                    builder.bit_reverse(result_type, Some(arg.dst), arg.src)?;
                }
                ast::Instruction::Popc { typ, arg } => {
                    let result_type = map.get_or_add_scalar(builder, (*typ).into());
                    builder.bit_count(result_type, Some(arg.dst), arg.src)?;
                }
                ast::Instruction::Xor { typ, arg } => {
                    let builder_fn = match typ {
                        ast::BooleanType::Pred => emit_logical_xor_spirv,
                        _ => dr::Builder::bitwise_xor,
                    };
                    let result_type = map.get_or_add_scalar(builder, (*typ).into());
                    builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
                }
                ast::Instruction::Bfe { typ, arg } => {
                    let builder_fn = if typ.is_signed() {
                        dr::Builder::bit_field_s_extract
                    } else {
                        dr::Builder::bit_field_u_extract
                    };
                    let result_type = map.get_or_add_scalar(builder, (*typ).into());
                    builder_fn(
                        builder,
                        result_type,
                        Some(arg.dst),
                        arg.src1,
                        arg.src2,
                        arg.src3,
                    )?;
                }
                ast::Instruction::Rem { typ, arg } => {
                    let builder_fn = if typ.is_signed() {
                        dr::Builder::s_mod
                    } else {
                        dr::Builder::u_mod
                    };
                    let result_type = map.get_or_add_scalar(builder, (*typ).into());
                    builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
                }
            },
            Statement::LoadVar(arg, typ) => {
                let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
                builder.load(type_id, Some(arg.dst), arg.src, None, [])?;
            }
            Statement::StoreVar(arg, _) => {
                builder.store(arg.src1, arg.src2, None, [])?;
            }
            Statement::RetValue(_, id) => {
                builder.ret_value(*id)?;
            }
            Statement::Composite(c) => {
                let result_type = map.get_or_add_scalar(builder, c.typ.into());
                let result_id = Some(c.dst);
                builder.composite_extract(
                    result_type,
                    result_id,
                    c.src_composite,
                    [c.src_index],
                )?;
            }
            Statement::Undef(t, id) => {
                let result_type = map.get_or_add(builder, SpirvType::from(t.clone()));
                builder.undef(result_type, Some(*id));
            }
            Statement::PtrAdd {
                underlying_type,
                state_space,
                dst,
                ptr_src,
                constant_src,
            } => {
                // Pointer arithmetic is done on a 64-bit integer: cast the pointer,
                // add the offset, then cast the sum back to the pointer type.
                let s64_type = map.get_or_add_scalar(builder, ast::ScalarType::S64);
                let ptr_as_s64 = builder.bitcast(s64_type, None, *ptr_src)?;
                let added_ptr = builder.i_add(s64_type, None, ptr_as_s64, *constant_src)?;
                let result_type = map.get_or_add(
                    builder,
                    SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)),
                );
                builder.bitcast(result_type, Some(*dst), added_ptr)?;
            }
        }
    }
    Ok(())
}
// TODO: check what kind of assembly we emit for this sequence
/// Emits a logical XOR of `op1` and `op2` built from OR, AND and NOT:
/// `(op1 || op2) && !(op1 && op2)`.
///
/// The three intermediate values get builder-assigned ids; only the final AND is
/// written to `result_id`. Returns the id produced by that final instruction.
fn emit_logical_xor_spirv(
    builder: &mut dr::Builder,
    result_type: spirv::Word,
    result_id: Option<spirv::Word>,
    op1: spirv::Word,
    op2: spirv::Word,
) -> Result<spirv::Word, dr::Error> {
    let either = builder.logical_or(result_type, None, op1, op2)?;
    let both = builder.logical_and(result_type, None, op1, op2)?;
    let not_both = builder.logical_not(result_type, None, both)?;
    builder.logical_and(result_type, result_id, either, not_both)
}
/// Emits a square root as an OpenCL extended instruction.
///
/// `SqrtKind::Approx` maps to `native_sqrt` with no rounding mode;
/// `SqrtKind::Rounding(rnd)` maps to `sqrt` and passes `rnd` on to
/// `emit_rounding_decoration` for the result id.
fn emit_sqrt(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    details: &ast::SqrtDetails,
    a: &ast::Arg2<ExpandedArgParams>,
) -> Result<(), TranslateError> {
    let result_type = map.get_or_add_scalar(builder, details.typ.into());
    let (cl_instruction, rounding_mode) = match details.kind {
        ast::SqrtKind::Rounding(rnd) => (spirv::CLOp::sqrt, Some(rnd)),
        ast::SqrtKind::Approx => (spirv::CLOp::native_sqrt, None),
    };
    let opcode = cl_instruction as spirv::Word;
    builder.ext_inst(result_type, Some(a.dst), opencl, opcode, &[a.src])?;
    emit_rounding_decoration(builder, a.dst, rounding_mode);
    Ok(())
}
/// Decorates the result of a floating point division according to its kind.
///
/// * `Full` - no decoration.
/// * `Rounding(rnd)` - delegates to `emit_rounding_decoration` with `Some(rnd)`.
/// * `Approx` - adds `FPFastMathMode` with the `ALLOW_RECIP` flag.
fn emit_float_div_decoration(builder: &mut dr::Builder, dst: spirv::Word, kind: ast::DivFloatKind) {
    match kind {
        ast::DivFloatKind::Full => {}
        ast::DivFloatKind::Rounding(mode) => {
            emit_rounding_decoration(builder, dst, Some(mode));
        }
        ast::DivFloatKind::Approx => {
            let fast_math = [dr::Operand::FPFastMathMode(
                spirv::FPFastMathMode::ALLOW_RECIP,
            )];
            builder.decorate(dst, spirv::Decoration::FPFastMathMode, &fast_math);
        }
    }
}
/// Emits a read-modify-write atomic operation.
///
/// The builder method and the scalar result type are selected from
/// `details.inner`; scope and memory semantics are materialised as `U32`
/// constants. The instruction is called with `arg.src1` as the pointer and
/// `arg.src2` as the value, and writes its result to `arg.dst`.
///
/// # Errors
/// Unsigned `Inc`/`Dec` yield `TranslateError::Unreachable`.
///
/// # Panics
/// Floating point atomics are not implemented (`todo!()`).
fn emit_atom(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    details: &ast::AtomDetails,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), TranslateError> {
    let (atomic_fn, scalar_type) = match details.inner {
        ast::AtomInnerDetails::Bit { op, typ } => {
            let atomic_fn = match op {
                ast::AtomBitOp::And => dr::Builder::atomic_and,
                ast::AtomBitOp::Or => dr::Builder::atomic_or,
                ast::AtomBitOp::Xor => dr::Builder::atomic_xor,
                ast::AtomBitOp::Exchange => dr::Builder::atomic_exchange,
            };
            (atomic_fn, ast::ScalarType::from(typ))
        }
        ast::AtomInnerDetails::Unsigned { op, typ } => {
            let atomic_fn = match op {
                ast::AtomUIntOp::Add => dr::Builder::atomic_i_add,
                // Not expected to reach this function; reported as an error.
                ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => {
                    return Err(TranslateError::Unreachable);
                }
                ast::AtomUIntOp::Min => dr::Builder::atomic_u_min,
                ast::AtomUIntOp::Max => dr::Builder::atomic_u_max,
            };
            (atomic_fn, typ.into())
        }
        ast::AtomInnerDetails::Signed { op, typ } => {
            let atomic_fn = match op {
                ast::AtomSIntOp::Add => dr::Builder::atomic_i_add,
                ast::AtomSIntOp::Min => dr::Builder::atomic_s_min,
                ast::AtomSIntOp::Max => dr::Builder::atomic_s_max,
            };
            (atomic_fn, typ.into())
        }
        // TODO: Hardware is capable of this, implement it through builtin
        ast::AtomInnerDetails::Float { .. } => todo!(),
    };
    let result_type = map.get_or_add_scalar(builder, scalar_type);
    let u32_type = ast::Type::Scalar(ast::ScalarType::U32);
    let scope_id = map.get_or_add_constant(
        builder,
        &u32_type,
        &vec_repr(details.scope.to_spirv() as u32),
    )?;
    let semantics_id = map.get_or_add_constant(
        builder,
        &u32_type,
        &vec_repr(details.semantics.to_spirv().bits()),
    )?;
    atomic_fn(
        builder,
        result_type,
        Some(arg.dst),
        arg.src1,
        scope_id,
        semantics_id,
        arg.src2,
    )?;
    Ok(())
}
// Record describing one imported function: its id plus the types of its output
// and input arguments.
// NOTE(review): presumably describes a function taken from the bundled PTX helper
// module - confirm against the code that constructs this struct.
#[derive(Clone)]
struct PtxImplImport {
// Type of the output argument.
out_arg: ast::Type,
// Numeric id of the function. NOTE(review): presumably a SPIR-V id - confirm.
fn_id: u32,
// Types of the input arguments.
in_args: Vec<ast::Type>,
}
/// Returns the lowercase textual name of an atomic memory-ordering qualifier
/// (`"relaxed"`, `"acquire"`, `"release"` or `"acq_rel"`).
fn ptx_semantics_name(sema: ast::AtomSemantics) -> &'static str {
    match sema {
        ast::AtomSemantics::AcquireRelease => "acq_rel",
        ast::AtomSemantics::Release => "release",
        ast::AtomSemantics::Acquire => "acquire",
        ast::AtomSemantics::Relaxed => "relaxed",
    }
}
/// Returns the lowercase textual name of a memory scope
/// (`"cta"`, `"gpu"` or `"sys"`).
fn ptx_scope_name(scope: ast::MemScope) -> &'static str {
    match scope {
        ast::MemScope::Sys => "sys",
        ast::MemScope::Gpu => "gpu",
        ast::MemScope::Cta => "cta",
    }
}
/// Returns the lowercase textual name of an atomic state space
/// (`"generic"`, `"global"` or `"shared"`).
fn ptx_space_name(space: ast::AtomSpace) -> &'static str {
    match space {
        ast::AtomSpace::Shared => "shared",
        ast::AtomSpace::Global => "global",
        ast::AtomSpace::Generic => "generic",
    }
}
/// Emits a floating point multiplication of `arg.src1` and `arg.src2` into
/// `arg.dst`, followed by the rounding decoration for `ctr.rounding`.
///
/// # Panics
/// Saturating multiplication is not implemented (`todo!()`).
fn emit_mul_float(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    ctr: &ast::ArithFloat,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    if ctr.saturate {
        todo!()
    }
    let float_type = map.get_or_add_scalar(builder, ctr.typ.into());
    let product_id = arg.dst;
    builder.f_mul(float_type, Some(product_id), arg.src1, arg.src2)?;
    emit_rounding_decoration(builder, product_id, ctr.rounding);
    Ok(())
}
/// Emits a reciprocal as the division `1.0 / a.src`.
///
/// The constant and the result type are `F64` when `desc.is_f64` is set and
/// `F32` otherwise. The result receives the rounding decoration for
/// `desc.rounding` and an `FPFastMathMode` decoration with `ALLOW_RECIP`.
fn emit_rcp(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    desc: &ast::RcpDetails,
    a: &ast::Arg2<ExpandedArgParams>,
) -> Result<(), TranslateError> {
    let instr_type = if desc.is_f64 {
        ast::ScalarType::F64
    } else {
        ast::ScalarType::F32
    };
    let one_bytes = if desc.is_f64 {
        vec_repr(1.0f64)
    } else {
        vec_repr(1.0f32)
    };
    let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &one_bytes)?;
    let result_type = map.get_or_add_scalar(builder, instr_type);
    builder.f_div(result_type, Some(a.dst), one, a.src)?;
    emit_rounding_decoration(builder, a.dst, desc.rounding);
    let fast_math = [dr::Operand::FPFastMathMode(
        spirv::FPFastMathMode::ALLOW_RECIP,
    )];
    builder.decorate(a.dst, spirv::Decoration::FPFastMathMode, &fast_math);
    Ok(())
}
/// Returns the in-memory (native-endian) byte representation of `t`.
///
/// The returned vector holds exactly `mem::size_of::<T>()` bytes.
///
/// The copy is done byte by byte. Copying a whole `T` into the buffer would
/// require the `Vec<u8>` allocation to be aligned for `T`, which a byte buffer
/// does not guarantee.
///
/// Callers in this file pass primitive numbers and bit flags. A `T` containing
/// padding bytes must not be used, because padding is uninitialized memory.
fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
    let size = mem::size_of::<T>();
    let mut result = vec![0u8; size];
    // SAFETY: `&t` is valid for reads of `size` bytes, `result` owns `size`
    // writable bytes, the two regions cannot overlap, and `u8` has no alignment
    // requirement.
    unsafe {
        std::ptr::copy_nonoverlapping(&t as *const T as *const u8, result.as_mut_ptr(), size);
    }
    result
}
/// Emits a variable declaration together with its optional initializer and
/// alignment decoration.
///
/// Storage class by variable kind: `Reg`/`Param`/`Local` use `Function`,
/// `Global` uses `CrossWorkgroup`, `Shared` uses `Workgroup`. A non-empty
/// `array_init` becomes a constant initializer; otherwise global variables are
/// initialized with a null constant and all other kinds get no initializer.
fn emit_variable(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    var: &ast::Variable<ast::VariableType, spirv::Word>,
) -> Result<(), TranslateError> {
    let (must_init, st_class) = match var.v_type {
        ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
        ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup),
        ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
            (false, spirv::StorageClass::Function)
        }
    };
    let var_type = ast::Type::from(var.v_type.clone());
    let initializer = if !var.array_init.is_empty() {
        Some(map.get_or_add_constant(builder, &var_type, &*var.array_init)?)
    } else if must_init {
        let type_id = map.get_or_add(builder, SpirvType::from(var_type.clone()));
        Some(builder.constant_null(type_id, None))
    } else {
        None
    };
    let ptr_type_id = map.get_or_add(builder, SpirvType::new_pointer(var_type, st_class));
    builder.variable(ptr_type_id, Some(var.name), st_class, initializer);
    if let Some(alignment) = var.align {
        let alignment_operand = [dr::Operand::LiteralInt32(alignment)];
        builder.decorate(var.name, spirv::Decoration::Alignment, &alignment_operand);
    }
    Ok(())
}
/// Emits an unsigned integer multiply-add (`src1 * src2 + src3`).
///
/// * `Low` - a multiplication followed by an addition.
/// * `High` - the OpenCL `u_mad_hi` extended instruction.
///
/// # Panics
/// `Wide` is not implemented (`todo!()`).
fn emit_mad_uint(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::MulUInt,
    arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let int_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
    match desc.control {
        ast::MulIntControl::Wide => todo!(),
        ast::MulIntControl::High => {
            let opcode = spirv::CLOp::u_mad_hi as spirv::Word;
            let operands = [arg.src1, arg.src2, arg.src3];
            builder.ext_inst(int_type, Some(arg.dst), opencl, opcode, operands)?;
        }
        ast::MulIntControl::Low => {
            let product = builder.i_mul(int_type, None, arg.src1, arg.src2)?;
            builder.i_add(int_type, Some(arg.dst), arg.src3, product)?;
        }
    };
    Ok(())
}
/// Emits a signed integer multiply-add (`src1 * src2 + src3`).
///
/// * `Low` - a multiplication followed by an addition.
/// * `High` - the OpenCL `s_mad_hi` extended instruction.
///
/// # Panics
/// `Wide` is not implemented (`todo!()`).
fn emit_mad_sint(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::MulSInt,
    arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let int_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
    match desc.control {
        ast::MulIntControl::Wide => todo!(),
        ast::MulIntControl::High => {
            let opcode = spirv::CLOp::s_mad_hi as spirv::Word;
            let operands = [arg.src1, arg.src2, arg.src3];
            builder.ext_inst(int_type, Some(arg.dst), opencl, opcode, operands)?;
        }
        ast::MulIntControl::Low => {
            let product = builder.i_mul(int_type, None, arg.src1, arg.src2)?;
            builder.i_add(int_type, Some(arg.dst), arg.src3, product)?;
        }
    };
    Ok(())
}
/// Emits a floating point multiply-add as the OpenCL `mad` extended instruction
/// with operands `src1`, `src2`, `src3`.
///
/// Only `desc.typ` is read here; the remaining fields of `desc` are not
/// consulted by this function.
fn emit_mad_float(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::ArithFloat,
    arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let float_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
    let opcode = spirv::CLOp::mad as spirv::Word;
    let operands = [arg.src1, arg.src2, arg.src3];
    builder.ext_inst(float_type, Some(arg.dst), opencl, opcode, operands)?;
    Ok(())
}
/// Emits a floating point addition of `arg.src1` and `arg.src2` into `arg.dst`,
/// followed by the rounding decoration for `desc.rounding`.
fn emit_add_float(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    desc: &ast::ArithFloat,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let float_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
    let sum_id = arg.dst;
    builder.f_add(float_type, Some(sum_id), arg.src1, arg.src2)?;
    emit_rounding_decoration(builder, sum_id, desc.rounding);
    Ok(())
}
/// Emits a floating point subtraction `arg.src1 - arg.src2` into `arg.dst`,
/// followed by the rounding decoration for `desc.rounding`.
fn emit_sub_float(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    desc: &ast::ArithFloat,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let float_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
    let difference_id = arg.dst;
    builder.f_sub(float_type, Some(difference_id), arg.src1, arg.src2)?;
    emit_rounding_decoration(builder, difference_id, desc.rounding);
    Ok(())
}
/// Emits a two-operand minimum as an OpenCL extended instruction:
/// `s_min` for signed, `u_min` for unsigned and `fmin` for float operands.
fn emit_min(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::MinMaxDetails,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let opcode = match desc {
        ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
        ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
        ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min,
    } as spirv::Word;
    let operand_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
    let operands = [arg.src1, arg.src2];
    builder.ext_inst(operand_type, Some(arg.dst), opencl, opcode, operands)?;
    Ok(())
}
/// Emits a two-operand maximum as an OpenCL extended instruction:
/// `s_max` for signed, `u_max` for unsigned and `fmax` for float operands.
fn emit_max(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::MinMaxDetails,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let opcode = match desc {
        ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
        ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
        ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max,
    } as spirv::Word;
    let operand_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
    let operands = [arg.src1, arg.src2];
    builder.ext_inst(operand_type, Some(arg.dst), opencl, opcode, operands)?;
    Ok(())
}
/// Emits the SPIR-V for a `cvt` instruction.
///
/// * `FloatFromFloat` - same source and destination type: rounds to an integral
///   value with `rint`/`trunc`/`floor`/`ceil` depending on the rounding mode, or
///   copies the value when no mode is given. Different types: `f_convert` plus a
///   rounding decoration.
/// * `FloatFromInt` - `convert_s_to_f` or `convert_u_to_f` plus a rounding
///   decoration.
/// * `IntFromFloat` - `convert_f_to_s` or `convert_f_to_u` plus rounding and
///   saturation decorations.
/// * `IntFromInt` - an optional width change through `emit_implicit_conversion`,
///   followed by a signedness change (saturating conversion or bitcast). When
///   source and destination are the very same type the value is copied, so that
///   `arg.dst` is always defined.
///
/// # Panics
/// Saturating float results (`FloatFromFloat`, `FloatFromInt`) are not
/// implemented (`todo!()`).
fn emit_cvt(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    dets: &ast::CvtDetails,
    arg: &ast::Arg2<ExpandedArgParams>,
) -> Result<(), TranslateError> {
    match dets {
        ast::CvtDetails::FloatFromFloat(desc) => {
            if desc.saturate {
                todo!()
            }
            let dest_t: ast::ScalarType = desc.dst.into();
            let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
            if desc.dst == desc.src {
                // Same float type: the instruction only rounds to an integral value.
                match desc.rounding {
                    Some(ast::RoundingMode::NearestEven) => {
                        builder.ext_inst(
                            result_type,
                            Some(arg.dst),
                            opencl,
                            spirv::CLOp::rint as u32,
                            [arg.src],
                        )?;
                    }
                    Some(ast::RoundingMode::Zero) => {
                        builder.ext_inst(
                            result_type,
                            Some(arg.dst),
                            opencl,
                            spirv::CLOp::trunc as u32,
                            [arg.src],
                        )?;
                    }
                    Some(ast::RoundingMode::NegativeInf) => {
                        builder.ext_inst(
                            result_type,
                            Some(arg.dst),
                            opencl,
                            spirv::CLOp::floor as u32,
                            [arg.src],
                        )?;
                    }
                    Some(ast::RoundingMode::PositiveInf) => {
                        builder.ext_inst(
                            result_type,
                            Some(arg.dst),
                            opencl,
                            spirv::CLOp::ceil as u32,
                            [arg.src],
                        )?;
                    }
                    None => {
                        builder.copy_object(result_type, Some(arg.dst), arg.src)?;
                    }
                }
            } else {
                builder.f_convert(result_type, Some(arg.dst), arg.src)?;
                emit_rounding_decoration(builder, arg.dst, desc.rounding);
            }
        }
        ast::CvtDetails::FloatFromInt(desc) => {
            if desc.saturate {
                todo!()
            }
            let dest_t: ast::ScalarType = desc.dst.into();
            let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
            if desc.src.is_signed() {
                builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
            } else {
                builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?;
            }
            emit_rounding_decoration(builder, arg.dst, desc.rounding);
        }
        ast::CvtDetails::IntFromFloat(desc) => {
            let dest_t: ast::ScalarType = desc.dst.into();
            let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
            if desc.dst.is_signed() {
                builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?;
            } else {
                builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?;
            }
            emit_rounding_decoration(builder, arg.dst, desc.rounding);
            emit_saturating_decoration(builder, arg.dst, desc.saturate);
        }
        ast::CvtDetails::IntFromInt(desc) => {
            let dest_t: ast::ScalarType = desc.dst.into();
            let src_t: ast::ScalarType = desc.src.into();
            let same_kind = dest_t.kind() == src_t.kind();
            let same_width = desc.dst.width() == desc.src.width();
            // first do shortening/widening
            let src = if !same_width {
                // With matching signedness the width change is the whole conversion
                // and writes straight to the destination; otherwise it goes through
                // a temporary that is re-interpreted below.
                let new_dst = if same_kind { arg.dst } else { builder.id() };
                let cv = ImplicitConversion {
                    src: arg.src,
                    dst: new_dst,
                    from: ast::Type::Scalar(src_t),
                    to: ast::Type::Scalar(ast::ScalarType::from_parts(
                        dest_t.size_of(),
                        src_t.kind(),
                    )),
                    kind: ConversionKind::Default,
                };
                emit_implicit_conversion(builder, map, &cv)?;
                new_dst
            } else {
                arg.src
            };
            if same_kind {
                if same_width {
                    // Identical source and destination type: nothing has been
                    // emitted so far, so define `arg.dst` with a plain copy.
                    let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
                    builder.copy_object(result_type, Some(arg.dst), src)?;
                }
                // NOTE(review): `desc.saturate` is not consulted on this path
                // (same signedness, different width) - confirm this is intended.
                return Ok(());
            }
            // now do actual conversion
            let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
            if desc.saturate {
                if desc.dst.is_signed() {
                    builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?;
                } else {
                    builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?;
                }
            } else {
                builder.bitcast(result_type, Some(arg.dst), src)?;
            }
        }
    }
    Ok(())
}
/// Decorates `dst` with `SaturatedConversion` when `saturate` is set;
/// does nothing otherwise.
fn emit_saturating_decoration(builder: &mut dr::Builder, dst: u32, saturate: bool) {
    if !saturate {
        return;
    }
    builder.decorate(dst, spirv::Decoration::SaturatedConversion, []);
}
/// Decorates `dst` with `FPRoundingMode` for the given PTX rounding mode.
///
/// No decoration is emitted when `rounding` is `None`.
fn emit_rounding_decoration(
    builder: &mut dr::Builder,
    dst: spirv::Word,
    rounding: Option<ast::RoundingMode>,
) {
    let mode = match rounding {
        Some(mode) => mode,
        None => return,
    };
    builder.decorate(dst, spirv::Decoration::FPRoundingMode, [mode.to_spirv()]);
}
impl ast::RoundingMode {
    /// Converts a PTX rounding mode into the matching SPIR-V
    /// `FPRoundingMode` operand (RTE, RTZ, RTP or RTN).
    fn to_spirv(self) -> rspirv::dr::Operand {
        rspirv::dr::Operand::FPRoundingMode(match self {
            ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE,
            ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ,
            ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN,
            ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP,
        })
    }
}
/// Emits the comparison instruction for a PTX `setp`.
///
/// The SPIR-V opcode is chosen from the comparison operator and the scalar
/// kind of `setp.typ`: integer compares for signed/unsigned/bit operands
/// (bit operands use the unsigned ordering compares), ordered float compares
/// for float operands, and unordered float compares for the `Nan*` operators.
/// The predicate result is written to `arg.dst1`.
///
/// # Panics
/// Panics via `todo!()` for operator/kind combinations that are not handled.
fn emit_setp(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    setp: &ast::SetpData,
    arg: &ast::Arg4Setp<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let pred_type = map.get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred));
    let dst = Some(arg.dst1);
    let lhs = arg.src1;
    let rhs = arg.src2;
    let comparison = match (setp.cmp_op, setp.typ.kind()) {
        // equality
        (ast::SetpCompareOp::Eq, ScalarKind::Signed)
        | (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
        | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => builder.i_equal(pred_type, dst, lhs, rhs),
        (ast::SetpCompareOp::Eq, ScalarKind::Float) => {
            builder.f_ord_equal(pred_type, dst, lhs, rhs)
        }
        // inequality
        (ast::SetpCompareOp::NotEq, ScalarKind::Signed)
        | (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
        | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => {
            builder.i_not_equal(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
            builder.f_ord_not_equal(pred_type, dst, lhs, rhs)
        }
        // less than
        (ast::SetpCompareOp::Less, ScalarKind::Unsigned)
        | (ast::SetpCompareOp::Less, ScalarKind::Bit) => {
            builder.u_less_than(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::Less, ScalarKind::Signed) => {
            builder.s_less_than(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::Less, ScalarKind::Float) => {
            builder.f_ord_less_than(pred_type, dst, lhs, rhs)
        }
        // less than or equal
        (ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
        | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => {
            builder.u_less_than_equal(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
            builder.s_less_than_equal(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => {
            builder.f_ord_less_than_equal(pred_type, dst, lhs, rhs)
        }
        // greater than
        (ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
        | (ast::SetpCompareOp::Greater, ScalarKind::Bit) => {
            builder.u_greater_than(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
            builder.s_greater_than(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::Greater, ScalarKind::Float) => {
            builder.f_ord_greater_than(pred_type, dst, lhs, rhs)
        }
        // greater than or equal
        (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
        | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => {
            builder.u_greater_than_equal(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
            builder.s_greater_than_equal(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
            builder.f_ord_greater_than_equal(pred_type, dst, lhs, rhs)
        }
        // unordered ("NaN-aware") float comparisons, regardless of kind
        (ast::SetpCompareOp::NanEq, _) => builder.f_unord_equal(pred_type, dst, lhs, rhs),
        (ast::SetpCompareOp::NanNotEq, _) => {
            builder.f_unord_not_equal(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::NanLess, _) => {
            builder.f_unord_less_than(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::NanLessOrEq, _) => {
            builder.f_unord_less_than_equal(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::NanGreater, _) => {
            builder.f_unord_greater_than(pred_type, dst, lhs, rhs)
        }
        (ast::SetpCompareOp::NanGreaterOrEq, _) => {
            builder.f_unord_greater_than_equal(pred_type, dst, lhs, rhs)
        }
        _ => todo!(),
    };
    comparison.map(|_| ())
}
/// Emits a signed integer multiplication according to `desc.control`.
///
/// * `Low`  - `OpIMul`, keeping the low half of the product.
/// * `High` - OpenCL `s_mul_hi`, keeping the high half.
/// * `Wide` - `OpSMulExtended`, whose two-member struct result is repacked
///   into a single scalar twice as wide as the operand type.
fn emit_mul_sint(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::MulSInt,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let operand_type = ast::ScalarType::from(desc.typ);
    let operand_type_id = map.get_or_add(builder, SpirvType::from(operand_type));
    match desc.control {
        ast::MulIntControl::Low => builder
            .i_mul(operand_type_id, Some(arg.dst), arg.src1, arg.src2)
            .map(|_| ()),
        ast::MulIntControl::High => builder
            .ext_inst(
                operand_type_id,
                Some(arg.dst),
                opencl,
                spirv::CLOp::s_mul_hi as spirv::Word,
                [arg.src1, arg.src2],
            )
            .map(|_| ()),
        ast::MulIntControl::Wide => {
            // The extended multiply yields a struct of (low bits, high bits).
            let pair_members = vec![
                SpirvScalarKey::from(operand_type),
                SpirvScalarKey::from(operand_type),
            ];
            let pair_type_id = map.get_or_add(builder, SpirvType::Struct(pair_members));
            let product = builder.s_mul_extended(pair_type_id, None, arg.src1, arg.src2)?;
            let wide_type =
                ast::ScalarType::from_parts(operand_type.size_of() * 2, operand_type.kind());
            let wide_type_id = map.get_or_add_scalar(builder, wide_type);
            struct2_bitcast_to_wide(
                builder,
                map,
                SpirvScalarKey::from(operand_type),
                operand_type_id,
                arg.dst,
                wide_type_id,
                product,
            )
        }
    }
}
/// Emits an unsigned integer multiplication according to `desc.control`.
///
/// * `Low`  - `OpIMul`, keeping the low half of the product.
/// * `High` - OpenCL `u_mul_hi`, keeping the high half.
/// * `Wide` - `OpUMulExtended`, whose two-member struct result is repacked
///   into a single scalar twice as wide as the operand type.
fn emit_mul_uint(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::MulUInt,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let operand_type = ast::ScalarType::from(desc.typ);
    let operand_type_id = map.get_or_add(builder, SpirvType::from(operand_type));
    match desc.control {
        ast::MulIntControl::Low => builder
            .i_mul(operand_type_id, Some(arg.dst), arg.src1, arg.src2)
            .map(|_| ()),
        ast::MulIntControl::High => builder
            .ext_inst(
                operand_type_id,
                Some(arg.dst),
                opencl,
                spirv::CLOp::u_mul_hi as spirv::Word,
                [arg.src1, arg.src2],
            )
            .map(|_| ()),
        ast::MulIntControl::Wide => {
            // The extended multiply yields a struct of (low bits, high bits).
            let pair_members = vec![
                SpirvScalarKey::from(operand_type),
                SpirvScalarKey::from(operand_type),
            ];
            let pair_type_id = map.get_or_add(builder, SpirvType::Struct(pair_members));
            let product = builder.u_mul_extended(pair_type_id, None, arg.src1, arg.src2)?;
            let wide_type =
                ast::ScalarType::from_parts(operand_type.size_of() * 2, operand_type.kind());
            let wide_type_id = map.get_or_add_scalar(builder, wide_type);
            struct2_bitcast_to_wide(
                builder,
                map,
                SpirvScalarKey::from(operand_type),
                operand_type_id,
                arg.dst,
                wide_type_id,
                product,
            )
        }
    }
}
/// Reinterprets a two-member struct `src` as one scalar of type `dst_type_id`.
///
/// Structs cannot be bitcast directly, so both members are extracted,
/// packed into a two-element vector of `base_type_key`, and that vector is
/// bitcast into `dst`. `instruction_type` is the type id of each member.
fn struct2_bitcast_to_wide(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    base_type_key: SpirvScalarKey,
    instruction_type: spirv::Word,
    dst: spirv::Word,
    dst_type_id: spirv::Word,
    src: spirv::Word,
) -> Result<(), dr::Error> {
    let member_0 = builder.composite_extract(instruction_type, None, src, [0])?;
    let member_1 = builder.composite_extract(instruction_type, None, src, [1])?;
    let pair_type = map.get_or_add(builder, SpirvType::Vector(base_type_key, 2));
    let pair = builder.composite_construct(pair_type, None, [member_0, member_1])?;
    builder.bitcast(dst_type_id, Some(dst), pair).map(|_| ())
}
/// Emits an OpenCL absolute-value instruction for `arg.src` into `arg.dst`.
///
/// Signed integer operands use `s_abs`; every other scalar kind uses `fabs`.
fn emit_abs(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    d: &ast::AbsDetails,
    arg: &ast::Arg2<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let operand_type = ast::ScalarType::from(d.typ);
    let result_type = map.get_or_add(builder, SpirvType::from(operand_type));
    let abs_op = match operand_type.kind() {
        ScalarKind::Signed => spirv::CLOp::s_abs,
        _ => spirv::CLOp::fabs,
    };
    builder
        .ext_inst(
            result_type,
            Some(arg.dst),
            opencl,
            abs_op as spirv::Word,
            [arg.src],
        )
        .map(|_| ())
}
/// Emits `OpIAdd` of `arg.src1` and `arg.src2` into `arg.dst` using type `typ`.
///
/// # Panics
/// Panics via `todo!()` when `saturate` is requested, which is not implemented.
fn emit_add_int(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    typ: ast::ScalarType,
    saturate: bool,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    if saturate {
        todo!()
    }
    let result_type = map.get_or_add(builder, SpirvType::from(typ));
    builder
        .i_add(result_type, Some(arg.dst), arg.src1, arg.src2)
        .map(|_| ())
}
/// Emits `OpISub` of `arg.src1` and `arg.src2` into `arg.dst` using type `typ`.
///
/// # Panics
/// Panics via `todo!()` when `saturate` is requested, which is not implemented.
fn emit_sub_int(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    typ: ast::ScalarType,
    saturate: bool,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    if saturate {
        todo!()
    }
    let result_type = map.get_or_add(builder, SpirvType::from(typ));
    builder
        .i_sub(result_type, Some(arg.dst), arg.src1, arg.src2)
        .map(|_| ())
}
fn emit_implicit_conversion(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
cv: &ImplicitConversion,
) -> Result<(), TranslateError> {
let from_parts = cv.from.to_parts();
let to_parts = cv.to.to_parts();
match (from_parts.kind, to_parts.kind, cv.kind) {
(_, _, ConversionKind::PtrToBit(typ)) => {
let dst_type = map.get_or_add_scalar(builder, typ.into());
builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
}
(_, _, ConversionKind::BitToPtr(_)) => {
let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
if from_parts.width == to_parts.width {
let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
if from_parts.scalar_kind != ScalarKind::Float
&& to_parts.scalar_kind != ScalarKind::Float
{
// It is noop, but another instruction expects result of this conversion
builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
} else {
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
}
} else {
// This block is safe because it's illegal to implictly convert between floating point instructions
let same_width_bit_type = map.get_or_add(
builder,
SpirvType::from(ast::Type::from_parts(TypeParts {
scalar_kind: ScalarKind::Bit,
..from_parts
})),
);
let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
let wide_bit_type = ast::Type::from_parts(TypeParts {
scalar_kind: ScalarKind::Bit,
..to_parts
});
let wide_bit_type_spirv =
map.get_or_add(builder, SpirvType::from(wide_bit_type.clone()));
if to_parts.scalar_kind == ScalarKind::Unsigned
|| to_parts.scalar_kind == ScalarKind::Bit
{
builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?;
} else {
let wide_bit_value =
builder.u_convert(wide_bit_type_spirv, None, same_width_bit_value)?;
emit_implicit_conversion(
builder,
map,
&ImplicitConversion {
src: wide_bit_value,
dst: cv.dst,
from: wide_bit_type,
to: cv.to.clone(),
kind: ConversionKind::Default,
},
)?;
}
}
}
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(),
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
| (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default)
| (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
}
(_, _, ConversionKind::PtrToPtr { spirv_ptr }) => {
let result_type = if spirv_ptr {
map.get_or_add(
builder,
SpirvType::Pointer(
Box::new(SpirvType::from(cv.to.clone())),
spirv::StorageClass::Function,
),
)
} else {
map.get_or_add(builder, SpirvType::from(cv.to.clone()))
};
builder.bitcast(result_type, Some(cv.dst), cv.src)?;
}
_ => unreachable!(),
}
Ok(())
}
/// Replaces string identifiers in a function body with numeric ids.
///
/// Labels are registered in a first pass so that instructions can refer to
/// labels defined later in the body; the second pass expands every statement
/// through `expand_map_variables`.
fn normalize_identifiers<'a, 'b>(
    id_defs: &mut FnStringIdResolver<'a, 'b>,
    fn_defs: &GlobalFnDeclResolver<'a, 'b>,
    func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
) -> Result<Vec<NormalizedStatement>, TranslateError> {
    for statement in &func {
        if let ast::Statement::Label(label) = statement {
            id_defs.add_def(*label, None);
        }
    }
    let mut normalized = Vec::new();
    for statement in func {
        expand_map_variables(id_defs, fn_defs, &mut normalized, statement)?;
    }
    Ok(normalized)
}
fn expand_map_variables<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>,
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
result: &mut Vec<NormalizedStatement>,
s: ast::Statement<ast::ParsedArgParams<'a>>,
) -> Result<(), TranslateError> {
match s {
ast::Statement::Block(block) => {
id_defs.start_block();
for s in block {
expand_map_variables(id_defs, fn_defs, result, s)?;
}
id_defs.end_block();
}
ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)),
ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id)))
.transpose()?,
i.map_variable(&mut |id| id_defs.get_id(id))?,
))),
ast::Statement::Variable(var) => {
let ss = match var.var.v_type {
ast::VariableType::Reg(_) => StateSpace::Reg,
ast::VariableType::Global(_) => StateSpace::Global,
ast::VariableType::Shared(_) => StateSpace::Shared,
ast::VariableType::Param(_) => StateSpace::ParamReg,
ast::VariableType::Local(_) => StateSpace::Local,
};
let mut var_type = ast::Type::from(var.var.v_type.clone());
var_type = match var.var.v_type {
ast::VariableType::Reg(_) | ast::VariableType::Shared(_) => var_type,
ast::VariableType::Global(_) => var_type.pointer_to(ast::LdStateSpace::Global)?,
ast::VariableType::Param(_) => var_type.pointer_to(ast::LdStateSpace::Param)?,
ast::VariableType::Local(_) => var_type.pointer_to(ast::LdStateSpace::Local)?,
};
match var.count {
Some(count) => {
for new_id in id_defs.add_defs(var.var.name, count, ss, var_type) {
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
name: new_id,
array_init: var.var.array_init.clone(),
}))
}
}
None => {
let new_id = id_defs.add_def(var.var.name, Some((ss, var_type)));
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
name: new_id,
array_init: var.var.array_init,
}));
}
}
}
};
Ok(())
}
/// PTX special registers recognized by the translator
/// (`%tid`, `%ntid`, `%ctaid`, `%nctaid`).
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
enum PtxSpecialRegister {
// `%tid`
Tid,
// `%ntid`
Ntid,
// `%ctaid`
Ctaid,
// `%nctaid`
Nctaid,
}
impl PtxSpecialRegister {
    /// Parses a PTX special register name (including the leading `%`).
    /// Returns `None` for any other string.
    fn try_parse(s: &str) -> Option<Self> {
        const NAMES: [(&str, PtxSpecialRegister); 4] = [
            ("%tid", PtxSpecialRegister::Tid),
            ("%ntid", PtxSpecialRegister::Ntid),
            ("%ctaid", PtxSpecialRegister::Ctaid),
            ("%nctaid", PtxSpecialRegister::Nctaid),
        ];
        NAMES
            .iter()
            .find(|(name, _)| *name == s)
            .map(|(_, register)| *register)
    }
    /// Type of the register as seen by PTX code: every supported register
    /// is a four-element vector of `u32`.
    fn get_type(self) -> ast::Type {
        match self {
            PtxSpecialRegister::Tid
            | PtxSpecialRegister::Ntid
            | PtxSpecialRegister::Ctaid
            | PtxSpecialRegister::Nctaid => ast::Type::Vector(ast::ScalarType::U32, 4),
        }
    }
    /// SPIR-V built-in variable that backs this register.
    fn get_builtin(self) -> spirv::BuiltIn {
        match self {
            PtxSpecialRegister::Nctaid => spirv::BuiltIn::NumWorkgroups,
            PtxSpecialRegister::Ctaid => spirv::BuiltIn::WorkgroupId,
            PtxSpecialRegister::Ntid => spirv::BuiltIn::WorkgroupSize,
            PtxSpecialRegister::Tid => spirv::BuiltIn::LocalInvocationId,
        }
    }
}
/// Module-level mapping from textual identifiers to numeric ids.
struct GlobalStringIdResolver<'input> {
// Next numeric id to hand out.
current_id: spirv::Word,
// Global name -> numeric id.
variables: HashMap<Cow<'input, str>, spirv::Word>,
// Numeric id -> state space and type; `None` for ids registered without a type.
variables_type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
// Ids allocated for special registers; lent mutably to per-function resolvers.
special_registers: HashMap<PtxSpecialRegister, spirv::Word>,
// Function id -> declared return values and parameters.
fns: HashMap<spirv::Word, FnDecl>,
}
/// Argument types of a declared function: return values and parameters.
pub struct FnDecl {
// Types of the values the function returns.
ret_vals: Vec<ast::FnArgumentType>,
// Types of the function's input parameters.
params: Vec<ast::FnArgumentType>,
}
impl<'a> GlobalStringIdResolver<'a> {
// Creates an empty resolver whose first allocated id will be `start_id`.
fn new(start_id: spirv::Word) -> Self {
Self {
current_id: start_id,
variables: HashMap::new(),
variables_type_check: HashMap::new(),
special_registers: HashMap::new(),
fns: HashMap::new(),
}
}
// Returns the id for `id`, allocating one if needed, and records it as untyped.
fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word {
self.get_or_add_impl(id, None)
}
// Returns the id for `id`, allocating one if needed, and records its state space and type.
fn get_or_add_def_typed(&mut self, id: &'a str, typ: (StateSpace, ast::Type)) -> spirv::Word {
self.get_or_add_impl(id, Some(typ))
}
// Shared implementation of the two methods above.
// Note: the type entry is overwritten with `typ` even when the name already existed.
fn get_or_add_impl(
&mut self,
id: &'a str,
typ: Option<(StateSpace, ast::Type)>,
) -> spirv::Word {
let id = match self.variables.entry(Cow::Borrowed(id)) {
hash_map::Entry::Occupied(e) => *(e.get()),
hash_map::Entry::Vacant(e) => {
let numeric_id = self.current_id;
e.insert(numeric_id);
self.current_id += 1;
numeric_id
}
};
self.variables_type_check.insert(id, typ);
id
}
// Looks up an already registered global name; fails with `UnknownSymbol` otherwise.
fn get_id(&self, id: &str) -> Result<spirv::Word, TranslateError> {
self.variables
.get(id)
.copied()
.ok_or(TranslateError::UnknownSymbol)
}
// Returns the next id that would be allocated.
fn current_id(&self) -> spirv::Word {
self.current_id
}
// Starts translation of a function or kernel.
// Returns a function-scoped string resolver, a read-only view of the known
// function declarations, and the method declaration with numeric ids.
// For `Func` declarations the argument types are also recorded in `self.fns`.
fn start_fn<'b>(
&'b mut self,
header: &'b ast::MethodDecl<'a, &'a str>,
) -> Result<
(
FnStringIdResolver<'a, 'b>,
GlobalFnDeclResolver<'a, 'b>,
ast::MethodDecl<'a, spirv::Word>,
),
TranslateError,
> {
// In case a function decl was inserted earlier we want to use its id
let name_id = self.get_or_add_def(header.name());
// The function resolver borrows individual fields of `self`, which leaves
// `self.fns` free to be modified below.
let mut fn_resolver = FnStringIdResolver {
current_id: &mut self.current_id,
global_variables: &self.variables,
global_type_check: &self.variables_type_check,
special_registers: &mut self.special_registers,
variables: vec![HashMap::new(); 1],
type_check: HashMap::new(),
};
let new_fn_decl = match header {
ast::MethodDecl::Kernel { name, in_args } => ast::MethodDecl::Kernel {
name,
in_args: expand_kernel_params(&mut fn_resolver, in_args.iter())?,
},
ast::MethodDecl::Func(ret_params, _, params) => {
let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter())?;
let params_ids = expand_fn_params(&mut fn_resolver, params.iter())?;
self.fns.insert(
name_id,
FnDecl {
ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(),
params: params_ids.iter().map(|p| p.v_type.clone()).collect(),
},
);
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
}
};
Ok((
fn_resolver,
GlobalFnDeclResolver {
variables: &self.variables,
fns: &self.fns,
},
new_fn_decl,
))
}
}
/// Read-only view used to look up function declarations by numeric id or by name.
pub struct GlobalFnDeclResolver<'input, 'a> {
// Global name -> numeric id.
variables: &'a HashMap<Cow<'input, str>, spirv::Word>,
// Function id -> declaration.
fns: &'a HashMap<spirv::Word, FnDecl>,
}
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
    /// Looks up a function declaration by numeric id.
    /// Fails with `UnknownSymbol` when no function has that id.
    fn get_fn_decl(&self, id: spirv::Word) -> Result<&FnDecl, TranslateError> {
        match self.fns.get(&id) {
            Some(decl) => Ok(decl),
            None => Err(TranslateError::UnknownSymbol),
        }
    }
    /// Looks up a function declaration by name.
    /// Fails with `UnknownSymbol` when the name is unknown or is not a function.
    fn get_fn_decl_str(&self, id: &str) -> Result<&'a FnDecl, TranslateError> {
        self.variables
            .get(id)
            .and_then(|var_id| self.fns.get(var_id))
            .ok_or(TranslateError::UnknownSymbol)
    }
}
/// Function-level mapping from textual identifiers to numeric ids, with nested block scopes.
struct FnStringIdResolver<'input, 'b> {
// Id counter borrowed from the global resolver.
current_id: &'b mut spirv::Word,
// Global name -> id table, consulted after the local scopes.
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
// Types of global ids.
global_type_check: &'b HashMap<u32, Option<(StateSpace, ast::Type)>>,
// Ids allocated for special registers, borrowed from the global resolver.
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
// Stack of scopes; the last element is the innermost one.
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
// Types of ids defined in this function; `None` for untyped ids.
type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
    /// Drops the name tables and keeps only what is needed to work with
    /// numeric ids. The special register table is inverted to map id -> register.
    fn finish(self) -> NumericIdResolver<'b> {
        let mut registers_by_id = HashMap::new();
        for (register, id) in self.special_registers.iter() {
            registers_by_id.insert(*id, *register);
        }
        NumericIdResolver {
            current_id: self.current_id,
            global_type_check: self.global_type_check,
            type_check: self.type_check,
            special_registers: registers_by_id,
        }
    }
    /// Opens a new innermost scope.
    fn start_block(&mut self) {
        self.variables.push(HashMap::new())
    }
    /// Closes the innermost scope.
    fn end_block(&mut self) {
        self.variables.pop();
    }
    /// Resolves `id`, searching local scopes from innermost to outermost,
    /// then globals, then special registers (allocating an id for a special
    /// register on first use). Fails with `UnknownSymbol` otherwise.
    fn get_id(&mut self, id: &str) -> Result<spirv::Word, TranslateError> {
        let local = self
            .variables
            .iter()
            .rev()
            .find_map(|scope| scope.get(id));
        if let Some(found) = local {
            return Ok(*found);
        }
        if let Some(found) = self.global_variables.get(id) {
            return Ok(*found);
        }
        let sreg = PtxSpecialRegister::try_parse(id).ok_or(TranslateError::UnknownSymbol)?;
        match self.special_registers.entry(sreg) {
            hash_map::Entry::Occupied(slot) => Ok(*slot.get()),
            hash_map::Entry::Vacant(slot) => {
                let fresh_id = *self.current_id;
                *self.current_id += 1;
                slot.insert(fresh_id);
                Ok(fresh_id)
            }
        }
    }
    /// Defines `id` in the innermost scope with an optional type and returns
    /// the freshly allocated numeric id.
    fn add_def(&mut self, id: &'a str, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
        let fresh_id = *self.current_id;
        let innermost = self.variables.last_mut().unwrap();
        innermost.insert(Cow::Borrowed(id), fresh_id);
        self.type_check.insert(fresh_id, typ);
        *self.current_id += 1;
        fresh_id
    }
    /// Defines `count` variables named `<base_id>0`, `<base_id>1`, ... in the
    /// innermost scope, all with the same state space and type, and returns
    /// their consecutive numeric ids.
    #[must_use]
    fn add_defs(
        &mut self,
        base_id: &'a str,
        count: u32,
        ss: StateSpace,
        typ: ast::Type,
    ) -> impl Iterator<Item = spirv::Word> {
        let first_id = *self.current_id;
        for offset in 0..count {
            let name = format!("{}{}", base_id, offset);
            let innermost = self.variables.last_mut().unwrap();
            innermost.insert(Cow::Owned(name), first_id + offset);
            self.type_check
                .insert(first_id + offset, Some((ss, typ.clone())));
        }
        *self.current_id += count;
        (0..count).map(move |offset| offset + first_id)
    }
}
/// Id allocator and type lookup used once identifiers have been turned into numeric ids.
struct NumericIdResolver<'b> {
// Id counter shared with the global resolver.
current_id: &'b mut spirv::Word,
// Types of global ids.
global_type_check: &'b HashMap<u32, Option<(StateSpace, ast::Type)>>,
// Types of function-local ids; `None` for untyped ids.
type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
// Numeric id -> special register it stands for.
special_registers: HashMap<spirv::Word, PtxSpecialRegister>,
}
impl<'b> NumericIdResolver<'b> {
    /// Wraps this resolver in a `MutableNumericIdResolver`.
    fn finish(self) -> MutableNumericIdResolver<'b> {
        MutableNumericIdResolver { base: self }
    }
    /// Returns the state space and type of `id`.
    ///
    /// Lookup order: function-local ids, special registers (reported in the
    /// `Reg` state space), then globals. Fails with `UntypedSymbol` when the
    /// id is unknown or was registered without a type.
    fn get_typed(&self, id: spirv::Word) -> Result<(StateSpace, ast::Type), TranslateError> {
        if let Some(local) = self.type_check.get(&id) {
            return match local {
                Some(typed) => Ok(typed.clone()),
                None => Err(TranslateError::UntypedSymbol),
            };
        }
        if let Some(register) = self.special_registers.get(&id) {
            return Ok((StateSpace::Reg, register.get_type()));
        }
        match self.global_type_check.get(&id) {
            Some(Some(typed)) => Ok(typed.clone()),
            _ => Err(TranslateError::UntypedSymbol),
        }
    }
    /// Allocates a fresh id and records its (optional) state space and type.
    fn new_id(&mut self, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
        let fresh_id = *self.current_id;
        *self.current_id += 1;
        self.type_check.insert(fresh_id, typ);
        fresh_id
    }
}
/// Wrapper around `NumericIdResolver` whose lookups return only the type
/// and whose new ids are always placed in the `Reg` state space.
struct MutableNumericIdResolver<'b> {
// The wrapped resolver.
base: NumericIdResolver<'b>,
}
impl<'b> MutableNumericIdResolver<'b> {
fn unmut(self) -> NumericIdResolver<'b> {
self.base
}
fn get_typed(&self, id: spirv::Word) -> Result<ast::Type, TranslateError> {
self.base.get_typed(id).map(|(_, t)| t)
}
fn new_id(&mut self, typ: ast::Type) -> spirv::Word {
self.base.new_id(Some((StateSpace::Reg, typ)))
}
}
/// A statement in the translator's intermediate representation.
/// `I` is the instruction payload type and `P` selects the argument
/// representation used by variables and calls.
enum Statement<I, P: ast::ArgParams> {
// A label, identified by its numeric id.
Label(u32),
// A variable declaration.
Variable(ast::Variable<ast::VariableType, P::Id>),
// An instruction in the representation chosen by `I`.
Instruction(I),
// Load from a variable (dst, src) together with the value type.
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
// Store to a variable (src1, src2) together with the value type.
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
// A function call with resolved return values and parameters.
Call(ResolvedCall<P>),
// Read from a composite value (`dst`, `src_composite`).
Composite(CompositeRead),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
// An implicit conversion between `src` and `dst`.
Conversion(ImplicitConversion),
// Definition of a constant with id `dst`.
Constant(ConstantDefinition),
// Return data together with an id.
RetValue(ast::RetData, spirv::Word),
// An undefined value of the given type with the given id.
Undef(ast::Type, spirv::Word),
// Pointer arithmetic: `dst` from `ptr_src` and `constant_src`.
PtrAdd {
underlying_type: ast::PointerType,
state_space: ast::LdStateSpace,
dst: spirv::Word,
ptr_src: spirv::Word,
constant_src: spirv::Word,
},
}
impl ExpandedStatement {
    /// Applies `f` to every id referenced by this statement and returns the
    /// statement rebuilt with the mapped ids.
    fn map_id(self, f: &mut impl FnMut(spirv::Word) -> spirv::Word) -> ExpandedStatement {
        match self {
            Statement::Label(label) => Statement::Label(f(label)),
            Statement::Variable(mut variable) => {
                variable.name = f(variable.name);
                Statement::Variable(variable)
            }
            Statement::Instruction(instruction) => instruction
                .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| Ok(f(arg.op)))
                .unwrap(),
            Statement::LoadVar(mut load, typ) => {
                load.dst = f(load.dst);
                load.src = f(load.src);
                Statement::LoadVar(load, typ)
            }
            Statement::StoreVar(mut store, typ) => {
                store.src1 = f(store.src1);
                store.src2 = f(store.src2);
                Statement::StoreVar(store, typ)
            }
            Statement::Call(mut call) => {
                call.ret_params
                    .iter_mut()
                    .for_each(|(id, _)| *id = f(*id));
                call.func = f(call.func);
                call.param_list
                    .iter_mut()
                    .for_each(|(id, _)| *id = f(*id));
                Statement::Call(call)
            }
            Statement::Composite(mut read) => {
                read.dst = f(read.dst);
                read.src_composite = f(read.src_composite);
                Statement::Composite(read)
            }
            Statement::Conditional(mut branch) => {
                branch.predicate = f(branch.predicate);
                branch.if_true = f(branch.if_true);
                branch.if_false = f(branch.if_false);
                Statement::Conditional(branch)
            }
            Statement::Conversion(mut conversion) => {
                conversion.dst = f(conversion.dst);
                conversion.src = f(conversion.src);
                Statement::Conversion(conversion)
            }
            Statement::Constant(mut definition) => {
                definition.dst = f(definition.dst);
                Statement::Constant(definition)
            }
            Statement::RetValue(data, id) => Statement::RetValue(data, f(id)),
            Statement::Undef(typ, id) => Statement::Undef(typ, f(id)),
            Statement::PtrAdd {
                underlying_type,
                state_space,
                dst,
                ptr_src,
                constant_src,
            } => Statement::PtrAdd {
                underlying_type,
                state_space,
                // Field initializers run in source order: dst, ptr_src, constant_src.
                dst: f(dst),
                ptr_src: f(ptr_src),
                constant_src: f(constant_src),
            },
        }
    }
}
/// A call instruction whose target and arguments have been matched with
/// the argument types of the called function.
struct ResolvedCall<P: ast::ArgParams> {
// The `uniform` flag of the call.
pub uniform: bool,
// Ids receiving the return values, paired with their declared types.
pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>,
// Id of the called function.
pub func: spirv::Word,
// Call operands, paired with the declared parameter types.
pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>,
}
impl<T: ast::ArgParams> ResolvedCall<T> {
    /// Re-tags the call with another argument parameter set that uses the
    /// same `CallOperand` type; all field values are moved over unchanged.
    fn cast<U: ast::ArgParams<CallOperand = T::CallOperand>>(self) -> ResolvedCall<U> {
        let ResolvedCall {
            uniform,
            ret_params,
            func,
            param_list,
        } = self;
        ResolvedCall {
            uniform,
            ret_params,
            func,
            param_list,
        }
    }
}
impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
    /// Runs `visitor` over every id of the call and returns the call rebuilt
    /// for the `To` argument representation.
    ///
    /// Ids are visited in this order: return parameters, the function id,
    /// then the call operands. The first visitor error aborts the mapping.
    fn map<To: ArgParamsEx<Id = spirv::Word>, V: ArgumentMapVisitor<From, To>>(
        self,
        visitor: &mut V,
    ) -> Result<ResolvedCall<To>, TranslateError> {
        let mut ret_params = Vec::with_capacity(self.ret_params.len());
        for (id, typ) in self.ret_params {
            let descriptor = ArgumentDescriptor {
                op: id,
                is_dst: !typ.is_param(),
                sema: typ.semantics(),
            };
            let new_id = visitor.id(descriptor, Some(&typ.to_func_type()))?;
            ret_params.push((new_id, typ));
        }
        let func = visitor.id(
            ArgumentDescriptor {
                op: self.func,
                is_dst: false,
                sema: ArgumentSemantics::Default,
            },
            None,
        )?;
        let mut param_list = Vec::with_capacity(self.param_list.len());
        for (operand, typ) in self.param_list {
            let descriptor = ArgumentDescriptor {
                op: operand,
                is_dst: false,
                sema: typ.semantics(),
            };
            let new_operand = visitor.src_call_operand(descriptor, &typ.to_func_type())?;
            param_list.push((new_operand, typ));
        }
        Ok(ResolvedCall {
            uniform: self.uniform,
            ret_params,
            func,
            param_list,
        })
    }
}
impl VisitVariable for ResolvedCall<TypedArgParams> {
    /// Maps every id of the call through `f` and wraps the result in
    /// `Statement::Call`.
    fn visit_variable<
        'a,
        F: FnMut(
            ArgumentDescriptor<spirv::Word>,
            Option<&ast::Type>,
        ) -> Result<spirv::Word, TranslateError>,
    >(
        self,
        f: &mut F,
    ) -> Result<TypedStatement, TranslateError> {
        self.map(f).map(Statement::Call)
    }
}
impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
    /// Maps every id of the call through `f` and wraps the result in
    /// `Statement::Call`.
    fn visit_variable_extended<
        F: FnMut(
            ArgumentDescriptor<spirv::Word>,
            Option<&ast::Type>,
        ) -> Result<spirv::Word, TranslateError>,
    >(
        self,
        f: &mut F,
    ) -> Result<ExpandedStatement, TranslateError> {
        self.map(f).map(Statement::Call)
    }
}
/// Extension of `ast::ArgParams` that knows how to find a function
/// declaration from an identifier of the implementing representation.
pub trait ArgParamsEx: ast::ArgParams + Sized {
// Looks up the declaration of the function identified by `id` in `decl`.
fn get_fn_decl<'x, 'b>(
id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'x, 'b>,
) -> Result<&'b FnDecl, TranslateError>;
}
// Parsed arguments look functions up through the resolver's by-name lookup.
impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
fn get_fn_decl<'x, 'b>(
id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'x, 'b>,
) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl_str(id)
}
}
/// Marker type for statements whose identifiers have been replaced by
/// numeric ids but whose operands still use the `ast` operand wrappers.
enum NormalizedArgParams {}
impl ast::ArgParams for NormalizedArgParams {
type Id = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
type CallOperand = ast::CallOperand<spirv::Word>;
type IdOrVector = ast::IdOrVector<spirv::Word>;
type OperandOrVector = ast::OperandOrVector<spirv::Word>;
type SrcMemberOperand = (spirv::Word, u8);
}
// Functions are looked up by numeric id.
impl ArgParamsEx for NormalizedArgParams {
fn get_fn_decl<'a, 'b>(
id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'a, 'b>,
) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl(*id)
}
}
/// Statement with numeric ids whose instructions still carry an optional predicate.
type NormalizedStatement = Statement<
(
Option<ast::PredAt<spirv::Word>>,
ast::Instruction<NormalizedArgParams>,
),
NormalizedArgParams,
>;
/// Statement with numeric ids whose instructions carry no predicate.
type UnconditionalStatement = Statement<ast::Instruction<NormalizedArgParams>, NormalizedArgParams>;
/// Marker type for the "typed" statement representation; its associated
/// types are the same as those of `NormalizedArgParams`.
enum TypedArgParams {}
impl ast::ArgParams for TypedArgParams {
type Id = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
type CallOperand = ast::CallOperand<spirv::Word>;
type IdOrVector = ast::IdOrVector<spirv::Word>;
type OperandOrVector = ast::OperandOrVector<spirv::Word>;
type SrcMemberOperand = (spirv::Word, u8);
}
// Functions are looked up by numeric id.
impl ArgParamsEx for TypedArgParams {
fn get_fn_decl<'a, 'b>(
id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'a, 'b>,
) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl(*id)
}
}
/// Statement in the typed representation.
type TypedStatement = Statement<ast::Instruction<TypedArgParams>, TypedArgParams>;
/// Marker type for the fully expanded representation, in which every kind
/// of operand is a plain numeric id.
enum ExpandedArgParams {}
/// Statement in the fully expanded representation.
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
impl ast::ArgParams for ExpandedArgParams {
type Id = spirv::Word;
type Operand = spirv::Word;
type CallOperand = spirv::Word;
type IdOrVector = spirv::Word;
type OperandOrVector = spirv::Word;
type SrcMemberOperand = spirv::Word;
}
// Functions are looked up by numeric id.
impl ArgParamsEx for ExpandedArgParams {
fn get_fn_decl<'a, 'b>(
id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'a, 'b>,
) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl(*id)
}
}
/// State space of a variable as tracked by the translator.
/// Has one more variant (`ParamReg`) than `ast::StateSpace`.
#[derive(Copy, Clone)]
pub enum StateSpace {
Reg,
Const,
Global,
Local,
Shared,
Param,
// Has no counterpart in `ast::StateSpace`.
ParamReg,
}
// One-to-one conversion from the parser's state space; never yields `ParamReg`.
impl From<ast::StateSpace> for StateSpace {
fn from(ss: ast::StateSpace) -> Self {
match ss {
ast::StateSpace::Reg => StateSpace::Reg,
ast::StateSpace::Const => StateSpace::Const,
ast::StateSpace::Global => StateSpace::Global,
ast::StateSpace::Local => StateSpace::Local,
ast::StateSpace::Shared => StateSpace::Shared,
ast::StateSpace::Param => StateSpace::Param,
}
}
}
/// A top-level item of a translated module.
enum Directive<'input> {
// A module-level variable.
Variable(ast::Variable<ast::VariableType, spirv::Word>),
// A function or kernel.
Method(Function<'input>),
}
/// A function or kernel after translation to expanded statements.
struct Function<'input> {
// Declaration with numeric ids.
pub func_decl: ast::MethodDecl<'input, spirv::Word>,
// Declaration in the form used for SPIR-V emission.
pub spirv_decl: SpirvMethodDecl<'input>,
// Variables associated with this function that are held as globals.
pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
// Function body; `None` when there is no body.
pub body: Option<Vec<ExpandedStatement>>,
// Optional import name. NOTE(review): presumably the name under which the function is imported - confirm against users.
import_as: Option<String>,
}
/// Visitor that converts arguments from representation `T` to representation `U`.
/// There is one method per kind of argument; each receives a descriptor of
/// the argument plus type information and returns the converted argument.
pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
// Converts a plain identifier; the type is optional.
fn id(
&mut self,
desc: ArgumentDescriptor<T::Id>,
typ: Option<&ast::Type>,
) -> Result<U::Id, TranslateError>;
// Converts a general operand.
fn operand(
&mut self,
desc: ArgumentDescriptor<T::Operand>,
typ: &ast::Type,
) -> Result<U::Operand, TranslateError>;
// Converts an argument that is either an id or a vector of ids.
fn id_or_vector(
&mut self,
desc: ArgumentDescriptor<T::IdOrVector>,
typ: &ast::Type,
) -> Result<U::IdOrVector, TranslateError>;
// Converts an argument that is either an operand or a vector.
fn operand_or_vector(
&mut self,
desc: ArgumentDescriptor<T::OperandOrVector>,
typ: &ast::Type,
) -> Result<U::OperandOrVector, TranslateError>;
// Converts an operand passed to a call.
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<T::CallOperand>,
typ: &ast::Type,
) -> Result<U::CallOperand, TranslateError>;
// Converts a source vector-member operand; `typ` is (scalar type, u8).
fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<T::SrcMemberOperand>,
typ: (ast::ScalarType, u8),
) -> Result<U::SrcMemberOperand, TranslateError>;
}
// Lets any closure of the form `FnMut(descriptor, Option<&Type>) -> Result<Word, _>`
// act as a visitor over fully expanded arguments. Every argument kind is a
// plain id in this representation, so each method forwards to the closure.
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
t: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
fn operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
t: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(t))
}
fn id_or_vector(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(typ))
}
fn operand_or_vector(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(typ))
}
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
t: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(t))
}
// The closure is given the scalar element type; the `u8` part of the
// type tuple is ignored.
fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
(scalar_type, _): (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
self(desc.new_op(desc.op), Some(&ast::Type::Scalar(scalar_type)))
}
}
/// Lets any `FnMut(&str) -> Result<spirv::Word, TranslateError>` closure act as
/// the visitor that turns textual identifiers into ids.
///
/// The closure is called once for every identifier found inside an argument.
/// Immediates, register offsets and member indices are carried over unchanged,
/// and the type hint passed by the caller is ignored.
impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> for T
where
    T: FnMut(&str) -> Result<spirv::Word, TranslateError>,
{
    fn id(
        &mut self,
        desc: ArgumentDescriptor<&str>,
        _: Option<&ast::Type>,
    ) -> Result<spirv::Word, TranslateError> {
        (*self)(desc.op)
    }

    fn operand(
        &mut self,
        desc: ArgumentDescriptor<ast::Operand<&str>>,
        _: &ast::Type,
    ) -> Result<ast::Operand<spirv::Word>, TranslateError> {
        let resolved = match desc.op {
            ast::Operand::Reg(name) => ast::Operand::Reg((*self)(name)?),
            ast::Operand::RegOffset(name, offset) => {
                ast::Operand::RegOffset((*self)(name)?, offset)
            }
            ast::Operand::Imm(value) => ast::Operand::Imm(value),
        };
        Ok(resolved)
    }

    fn id_or_vector(
        &mut self,
        desc: ArgumentDescriptor<ast::IdOrVector<&'a str>>,
        _: &ast::Type,
    ) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
        let resolved = match desc.op {
            ast::IdOrVector::Reg(name) => ast::IdOrVector::Reg((*self)(name)?),
            ast::IdOrVector::Vec(names) => {
                // Stops at the first identifier the closure rejects.
                let ids = names
                    .into_iter()
                    .map(|name| (*self)(name))
                    .collect::<Result<_, _>>()?;
                ast::IdOrVector::Vec(ids)
            }
        };
        Ok(resolved)
    }

    fn operand_or_vector(
        &mut self,
        desc: ArgumentDescriptor<ast::OperandOrVector<&'a str>>,
        _: &ast::Type,
    ) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
        let resolved = match desc.op {
            ast::OperandOrVector::Reg(name) => ast::OperandOrVector::Reg((*self)(name)?),
            ast::OperandOrVector::RegOffset(name, offset) => {
                ast::OperandOrVector::RegOffset((*self)(name)?, offset)
            }
            ast::OperandOrVector::Imm(value) => ast::OperandOrVector::Imm(value),
            ast::OperandOrVector::Vec(names) => {
                // Stops at the first identifier the closure rejects.
                let ids = names
                    .into_iter()
                    .map(|name| (*self)(name))
                    .collect::<Result<_, _>>()?;
                ast::OperandOrVector::Vec(ids)
            }
        };
        Ok(resolved)
    }

    fn src_call_operand(
        &mut self,
        desc: ArgumentDescriptor<ast::CallOperand<&str>>,
        _: &ast::Type,
    ) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
        let resolved = match desc.op {
            ast::CallOperand::Reg(name) => ast::CallOperand::Reg((*self)(name)?),
            ast::CallOperand::Imm(value) => ast::CallOperand::Imm(value),
        };
        Ok(resolved)
    }

    fn src_member_operand(
        &mut self,
        desc: ArgumentDescriptor<(&str, u8)>,
        _: (ast::ScalarType, u8),
    ) -> Result<(spirv::Word, u8), TranslateError> {
        let (name, member_index) = desc.op;
        let id = (*self)(name)?;
        Ok((id, member_index))
    }
}
/// An instruction argument together with the metadata a visitor needs to
/// process it.
pub struct ArgumentDescriptor<Op> {
    /// The argument itself (an identifier, operand, vector, ...).
    op: Op,
    /// `true` when the instruction writes to this argument, `false` when it reads it.
    is_dst: bool,
    /// How the argument is accessed by the instruction.
    sema: ArgumentSemantics,
}

/// The way an instruction accesses one of its arguments.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum ArgumentSemantics {
    /// Normal register access.
    Default,
    /// Normal register access with relaxed conversion rules (ld/st).
    DefaultRelaxed,
    /// st/ld global.
    PhysicalPointer,
    /// st/ld .param, .local.
    RegisterPointer,
    /// mov of .local/.global variables.
    Address,
}

impl<T> ArgumentDescriptor<T> {
    /// Builds a descriptor for a different argument `u` that inherits the
    /// destination flag and the access semantics of `self`.
    fn new_op<U>(&self, u: U) -> ArgumentDescriptor<U> {
        // `..` leaves `op` untouched, so `T` does not have to be `Copy`.
        let &ArgumentDescriptor { is_dst, sema, .. } = self;
        ArgumentDescriptor {
            op: u,
            is_dst,
            sema,
        }
    }
}
// Generic argument mapping shared by all instruction kinds.
impl<T: ArgParamsEx> ast::Instruction<T> {
// Rebuilds the instruction with every argument passed through `visitor`,
// converting it from argument family `T` to argument family `U`.
//
// Each arm derives the type(s) expected for the arguments from the
// instruction's own details and delegates to the matching `ArgN::map*`
// helper; instruction details are carried over unchanged.
//
// Errors: propagates any error returned by the visitor or the helpers and
// returns `TranslateError::Unreachable` for `Call`.
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
) -> Result<ast::Instruction<U>, TranslateError> {
Ok(match self {
ast::Instruction::Abs(d, arg) => {
ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?)
}
// Call instruction is converted to a call statement early on
ast::Instruction::Call(_) => return Err(TranslateError::Unreachable),
// For ld/mov/st the details are only borrowed while the arguments are
// mapped and are moved into the rebuilt instruction afterwards.
ast::Instruction::Ld(d, a) => {
let new_args = a.map(visitor, &d)?;
ast::Instruction::Ld(d, new_args)
}
ast::Instruction::Mov(d, a) => {
let mapped = a.map(visitor, &d)?;
ast::Instruction::Mov(d, mapped)
}
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
let is_wide = d.is_wide();
ast::Instruction::Mul(d, a.map_non_shift(visitor, &inst_type, is_wide)?)
}
ast::Instruction::Add(d, a) => {
let inst_type = d.get_type();
ast::Instruction::Add(d, a.map_non_shift(visitor, &inst_type, false)?)
}
ast::Instruction::Setp(d, a) => {
let inst_type = d.typ;
ast::Instruction::Setp(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
ast::Instruction::SetpBool(d, a) => {
let inst_type = d.typ;
ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?),
// cvt is the general case of an instruction whose destination and
// source have different types.
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::FloatFromInt(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::IntFromFloat(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::IntFromInt(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
};
ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?)
}
ast::Instruction::Shl(t, a) => {
ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?)
}
ast::Instruction::Shr(t, a) => {
ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?)
}
ast::Instruction::St(d, a) => {
let new_args = a.map(visitor, &d)?;
ast::Instruction::St(d, new_args)
}
// The branch target is visited without a type.
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
// cvta arguments are always visited as `.b64`.
ast::Instruction::Cvta(d, a) => {
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
ast::Instruction::Cvta(d, a.map(visitor, &inst_type)?)
}
ast::Instruction::Mad(d, a) => {
let inst_type = d.get_type();
let is_wide = d.is_wide();
ast::Instruction::Mad(d, a.map(visitor, &inst_type, is_wide)?)
}
ast::Instruction::Or(t, a) => ast::Instruction::Or(
t,
a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?,
),
ast::Instruction::Sub(d, a) => {
let typ = d.get_type();
ast::Instruction::Sub(d, a.map_non_shift(visitor, &typ, false)?)
}
ast::Instruction::Min(d, a) => {
let typ = d.get_type();
ast::Instruction::Min(d, a.map_non_shift(visitor, &typ, false)?)
}
ast::Instruction::Max(d, a) => {
let typ = d.get_type();
ast::Instruction::Max(d, a.map_non_shift(visitor, &typ, false)?)
}
// rcp only distinguishes `.f64` from `.f32` through the `is_f64` flag.
ast::Instruction::Rcp(d, a) => {
let typ = ast::Type::Scalar(if d.is_f64 {
ast::ScalarType::F64
} else {
ast::ScalarType::F32
});
ast::Instruction::Rcp(d, a.map(visitor, &typ)?)
}
ast::Instruction::And(t, a) => ast::Instruction::And(
t,
a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?,
),
ast::Instruction::Selp(t, a) => ast::Instruction::Selp(t, a.map_selp(visitor, t)?),
ast::Instruction::Bar(d, a) => ast::Instruction::Bar(d, a.map(visitor)?),
ast::Instruction::Atom(d, a) => {
ast::Instruction::Atom(d, a.map_atom(visitor, d.inner.get_type(), d.space)?)
}
ast::Instruction::AtomCas(d, a) => {
ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?)
}
ast::Instruction::Div(d, a) => {
ast::Instruction::Div(d, a.map_non_shift(visitor, &d.get_type(), false)?)
}
ast::Instruction::Sqrt(d, a) => {
ast::Instruction::Sqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
}
ast::Instruction::Rsqrt(d, a) => {
ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
}
ast::Instruction::Neg(d, a) => {
ast::Instruction::Neg(d, a.map(visitor, &ast::Type::Scalar(d.typ))?)
}
// sin/cos/lg2/ex2 arguments are always visited as `.f32`.
ast::Instruction::Sin { flush_to_zero, arg } => {
let typ = ast::Type::Scalar(ast::ScalarType::F32);
ast::Instruction::Sin {
flush_to_zero,
arg: arg.map(visitor, &typ)?,
}
}
ast::Instruction::Cos { flush_to_zero, arg } => {
let typ = ast::Type::Scalar(ast::ScalarType::F32);
ast::Instruction::Cos {
flush_to_zero,
arg: arg.map(visitor, &typ)?,
}
}
ast::Instruction::Lg2 { flush_to_zero, arg } => {
let typ = ast::Type::Scalar(ast::ScalarType::F32);
ast::Instruction::Lg2 {
flush_to_zero,
arg: arg.map(visitor, &typ)?,
}
}
ast::Instruction::Ex2 { flush_to_zero, arg } => {
let typ = ast::Type::Scalar(ast::ScalarType::F32);
ast::Instruction::Ex2 {
flush_to_zero,
arg: arg.map(visitor, &typ)?,
}
}
// clz: the destination is always `.b32`, the source has the instruction type.
ast::Instruction::Clz { typ, arg } => {
let dst_type = ast::Type::Scalar(ast::ScalarType::B32);
let src_type = ast::Type::Scalar(typ.into());
ast::Instruction::Clz {
typ,
arg: arg.map_different_types(visitor, &dst_type, &src_type)?,
}
}
ast::Instruction::Brev { typ, arg } => {
let full_type = ast::Type::Scalar(typ.into());
ast::Instruction::Brev {
typ,
arg: arg.map(visitor, &full_type)?,
}
}
// popc: the destination is always `.b32`, the source has the instruction type.
ast::Instruction::Popc { typ, arg } => {
let dst_type = ast::Type::Scalar(ast::ScalarType::B32);
let src_type = ast::Type::Scalar(typ.into());
ast::Instruction::Popc {
typ,
arg: arg.map_different_types(visitor, &dst_type, &src_type)?,
}
}
ast::Instruction::Xor { typ, arg } => {
let full_type = ast::Type::Scalar(typ.into());
ast::Instruction::Xor {
typ,
arg: arg.map_non_shift(visitor, &full_type, false)?,
}
}
ast::Instruction::Bfe { typ, arg } => {
let full_type = ast::Type::Scalar(typ.into());
ast::Instruction::Bfe {
typ,
arg: arg.map_bfe(visitor, &full_type)?,
}
}
ast::Instruction::Rem { typ, arg } => {
let full_type = ast::Type::Scalar(typ.into());
ast::Instruction::Rem {
typ,
arg: arg.map_non_shift(visitor, &full_type, false)?,
}
}
})
}
}
/// Visiting a typed instruction maps all of its arguments through the closure
/// and wraps the result back into a `Statement::Instruction`.
impl VisitVariable for ast::Instruction<TypedArgParams> {
    fn visit_variable<
        'a,
        F: FnMut(
            ArgumentDescriptor<spirv_headers::Word>,
            Option<&ast::Type>,
        ) -> Result<spirv_headers::Word, TranslateError>,
    >(
        self,
        f: &mut F,
    ) -> Result<TypedStatement, TranslateError> {
        let mapped = self.map(f)?;
        Ok(Statement::Instruction(mapped))
    }
}
/// Lets any closure of shape
/// `FnMut(ArgumentDescriptor<spirv::Word>, Option<&ast::Type>) -> Result<spirv::Word, _>`
/// act as a visitor over typed instructions.
///
/// The closure is invoked for every id contained in an argument, using a
/// descriptor that inherits `is_dst` and `sema` from the whole argument.
/// Immediates and register offsets are passed through untouched.
impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T
where
    T: FnMut(
        ArgumentDescriptor<spirv::Word>,
        Option<&ast::Type>,
    ) -> Result<spirv::Word, TranslateError>,
{
    fn id(
        &mut self,
        desc: ArgumentDescriptor<spirv::Word>,
        typ: Option<&ast::Type>,
    ) -> Result<spirv::Word, TranslateError> {
        (*self)(desc, typ)
    }

    fn operand(
        &mut self,
        desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
        typ: &ast::Type,
    ) -> Result<ast::Operand<spirv::Word>, TranslateError> {
        let mapped = match desc.op {
            ast::Operand::Reg(reg) => ast::Operand::Reg((*self)(desc.new_op(reg), Some(typ))?),
            ast::Operand::Imm(value) => ast::Operand::Imm(value),
            ast::Operand::RegOffset(reg, offset) => {
                let new_reg = (*self)(desc.new_op(reg), Some(typ))?;
                ast::Operand::RegOffset(new_reg, offset)
            }
        };
        Ok(mapped)
    }

    fn src_call_operand(
        &mut self,
        desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
        typ: &ast::Type,
    ) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
        let mapped = match desc.op {
            ast::CallOperand::Reg(reg) => {
                ast::CallOperand::Reg((*self)(desc.new_op(reg), Some(typ))?)
            }
            ast::CallOperand::Imm(value) => ast::CallOperand::Imm(value),
        };
        Ok(mapped)
    }

    fn id_or_vector(
        &mut self,
        desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
        typ: &ast::Type,
    ) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
        let mapped = match desc.op {
            ast::IdOrVector::Reg(reg) => {
                ast::IdOrVector::Reg((*self)(desc.new_op(reg), Some(typ))?)
            }
            // Every component is visited with the type of the whole argument.
            ast::IdOrVector::Vec(ref components) => {
                let new_components = components
                    .iter()
                    .map(|component| (*self)(desc.new_op(*component), Some(typ)))
                    .collect::<Result<_, _>>()?;
                ast::IdOrVector::Vec(new_components)
            }
        };
        Ok(mapped)
    }

    fn operand_or_vector(
        &mut self,
        desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
        typ: &ast::Type,
    ) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
        let mapped = match desc.op {
            ast::OperandOrVector::Reg(reg) => {
                ast::OperandOrVector::Reg((*self)(desc.new_op(reg), Some(typ))?)
            }
            ast::OperandOrVector::RegOffset(reg, offset) => {
                let new_reg = (*self)(desc.new_op(reg), Some(typ))?;
                ast::OperandOrVector::RegOffset(new_reg, offset)
            }
            ast::OperandOrVector::Imm(value) => ast::OperandOrVector::Imm(value),
            // Every component is visited with the type of the whole argument.
            ast::OperandOrVector::Vec(ref components) => {
                let new_components = components
                    .iter()
                    .map(|component| (*self)(desc.new_op(*component), Some(typ)))
                    .collect::<Result<_, _>>()?;
                ast::OperandOrVector::Vec(new_components)
            }
        };
        Ok(mapped)
    }

    fn src_member_operand(
        &mut self,
        desc: ArgumentDescriptor<(spirv::Word, u8)>,
        member: (ast::ScalarType, u8),
    ) -> Result<(spirv::Word, u8), TranslateError> {
        // The id names the whole vector, so it is visited with the vector
        // type; the member index is carried over unchanged.
        let (scalar_type, vector_len) = member;
        let vector_type = ast::Type::Vector(scalar_type.into(), vector_len);
        let vector_id = (*self)(desc.new_op(desc.op.0), Some(&vector_type))?;
        Ok((vector_id, desc.op.1))
    }
}
impl ast::Type {
fn widen(self) -> Result<Self, TranslateError> {
match self {
ast::Type::Scalar(scalar) => {
let kind = scalar.kind();
let width = scalar.size_of();
if (kind != ScalarKind::Signed
&& kind != ScalarKind::Unsigned
&& kind != ScalarKind::Bit)
|| (width == 8)
{
return Err(TranslateError::MismatchedType);
}
Ok(ast::Type::Scalar(ast::ScalarType::from_parts(
width * 2,
kind,
)))
}
_ => Err(TranslateError::Unreachable),
}
}
fn to_parts(&self) -> TypeParts {
match self {
ast::Type::Scalar(scalar) => TypeParts {
kind: TypeKind::Scalar,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
state_space: ast::LdStateSpace::Global,
},
ast::Type::Vector(scalar, components) => TypeParts {
kind: TypeKind::Vector,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: vec![*components as u32],
state_space: ast::LdStateSpace::Global,
},
ast::Type::Array(scalar, components) => TypeParts {
kind: TypeKind::Array,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: components.clone(),
state_space: ast::LdStateSpace::Global,
},
ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts {
kind: TypeKind::PointerScalar,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
state_space: *state_space,
},
ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts {
kind: TypeKind::PointerVector,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: vec![*len as u32],
state_space: *state_space,
},
ast::Type::Pointer(ast::PointerType::Array(scalar, components), state_space) => {
TypeParts {
kind: TypeKind::PointerArray,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: components.clone(),
state_space: *state_space,
}
}
ast::Type::Pointer(ast::PointerType::Pointer(scalar, inner_space), state_space) => {
TypeParts {
kind: TypeKind::PointerPointer,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: vec![*inner_space as u32],
state_space: *state_space,
}
}
}
}
fn from_parts(t: TypeParts) -> Self {
match t.kind {
TypeKind::Scalar => {
ast::Type::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind))
}
TypeKind::Vector => ast::Type::Vector(
ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components[0] as u8,
),
TypeKind::Array => ast::Type::Array(
ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components,
),
TypeKind::PointerScalar => ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)),
t.state_space,
),
TypeKind::PointerVector => ast::Type::Pointer(
ast::PointerType::Vector(
ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components[0] as u8,
),
t.state_space,
),
TypeKind::PointerArray => ast::Type::Pointer(
ast::PointerType::Array(
ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components,
),
t.state_space,
),
TypeKind::PointerPointer => ast::Type::Pointer(
ast::PointerType::Pointer(
ast::ScalarType::from_parts(t.width, t.scalar_kind),
unsafe { mem::transmute::<_, ast::LdStateSpace>(t.components[0] as u8) },
),
t.state_space,
),
}
}
fn size_of(&self) -> usize {
match self {
ast::Type::Scalar(typ) => typ.size_of() as usize,
ast::Type::Vector(typ, len) => (typ.size_of() as usize) * (*len as usize),
ast::Type::Array(typ, len) => len
.iter()
.fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)),
ast::Type::Pointer(_, _) => mem::size_of::<usize>(),
}
}
}
/// Flat decomposition of an `ast::Type`, produced by `ast::Type::to_parts`
/// and consumed by `ast::Type::from_parts`.
#[derive(Eq, PartialEq, Clone)]
struct TypeParts {
// Which shape of type this record describes.
kind: TypeKind,
// Kind of the underlying scalar (`scalar.kind()`).
scalar_kind: ScalarKind,
// `size_of()` of the underlying scalar.
width: u8,
// Vector length (one element), array dimensions, or - for
// `TypeKind::PointerPointer` - the inner state space cast to `u32`.
// Empty for scalars and pointers to scalars.
components: Vec<u32>,
// State space of a pointer; `to_parts` fills in `Global` for non-pointer types.
state_space: ast::LdStateSpace,
}
/// Shape discriminator stored in `TypeParts`: one variant per combination of
/// `ast::Type` / `ast::PointerType` handled by `to_parts` and `from_parts`.
#[derive(Eq, PartialEq, Copy, Clone)]
enum TypeKind {
Scalar,
Vector,
Array,
// Pointer whose pointee is a scalar.
PointerScalar,
// Pointer whose pointee is a vector.
PointerVector,
// Pointer whose pointee is an array.
PointerArray,
// Pointer whose pointee is itself a pointer.
PointerPointer,
}
impl ast::Instruction<ExpandedArgParams> {
    /// Returns the id a `bra` instruction jumps to, `None` for anything else.
    fn jump_target(&self) -> Option<spirv::Word> {
        if let ast::Instruction::Bra(_, args) = self {
            Some(args.src)
        } else {
            None
        }
    }

    /// Returns the flush-to-zero setting requested by the instruction as
    /// `(flush_to_zero, operand width in bytes)`, or `None` when the
    /// instruction carries no such setting.
    ///
    /// `.wide` instructions don't support ftz, so it's enough to just look at
    /// the type declared by the instruction.
    fn flush_to_zero(&self) -> Option<(bool, u8)> {
        match self {
            // Instructions without any flush-to-zero setting.
            ast::Instruction::Ld(_, _)
            | ast::Instruction::St(_, _)
            | ast::Instruction::Mov(_, _)
            | ast::Instruction::Not(_, _)
            | ast::Instruction::Bra(_, _)
            | ast::Instruction::Shl(_, _)
            | ast::Instruction::Shr(_, _)
            | ast::Instruction::Ret(_)
            | ast::Instruction::Call(_)
            | ast::Instruction::Or(_, _)
            | ast::Instruction::And(_, _)
            | ast::Instruction::Cvta(_, _)
            | ast::Instruction::Selp(_, _)
            | ast::Instruction::Bar(_, _)
            | ast::Instruction::Atom(_, _)
            | ast::Instruction::AtomCas(_, _)
            | ast::Instruction::Clz { .. }
            | ast::Instruction::Brev { .. }
            | ast::Instruction::Popc { .. }
            | ast::Instruction::Xor { .. }
            | ast::Instruction::Bfe { .. }
            | ast::Instruction::Rem { .. } => None,
            // Integer flavours of arithmetic instructions.
            ast::Instruction::Sub(ast::ArithDetails::Signed(_), _)
            | ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _)
            | ast::Instruction::Add(ast::ArithDetails::Signed(_), _)
            | ast::Instruction::Add(ast::ArithDetails::Unsigned(_), _)
            | ast::Instruction::Mul(ast::MulDetails::Unsigned(_), _)
            | ast::Instruction::Mul(ast::MulDetails::Signed(_), _)
            | ast::Instruction::Mad(ast::MulDetails::Unsigned(_), _)
            | ast::Instruction::Mad(ast::MulDetails::Signed(_), _)
            | ast::Instruction::Min(ast::MinMaxDetails::Signed(_), _)
            | ast::Instruction::Min(ast::MinMaxDetails::Unsigned(_), _)
            | ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _)
            | ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _)
            | ast::Instruction::Div(ast::DivDetails::Unsigned(_), _)
            | ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None,
            // Conversions that produce an integer from an integer or a float from an integer.
            ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _)
            | ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None,
            ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
            | ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
            | ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
            | ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => {
                let ftz = float_control.flush_to_zero?;
                Some((ftz, ast::ScalarType::from(float_control.typ).size_of()))
            }
            ast::Instruction::Setp(details, _) => {
                let ftz = details.flush_to_zero?;
                Some((ftz, details.typ.size_of()))
            }
            ast::Instruction::SetpBool(details, _) => {
                let ftz = details.flush_to_zero?;
                Some((ftz, details.typ.size_of()))
            }
            ast::Instruction::Abs(details, _) => {
                let ftz = details.flush_to_zero?;
                Some((ftz, details.typ.size_of()))
            }
            ast::Instruction::Min(ast::MinMaxDetails::Float(float_control), _)
            | ast::Instruction::Max(ast::MinMaxDetails::Float(float_control), _) => {
                let ftz = float_control.flush_to_zero?;
                Some((ftz, ast::ScalarType::from(float_control.typ).size_of()))
            }
            ast::Instruction::Rcp(details, _) => {
                let ftz = details.flush_to_zero?;
                let width = if details.is_f64 { 8 } else { 4 };
                Some((ftz, width))
            }
            // Modifier .ftz can only be specified when either .dtype or .atype
            // is .f32 and applies only to single precision (.f32) inputs and results.
            ast::Instruction::Cvt(
                ast::CvtDetails::FloatFromFloat(ast::CvtDesc { flush_to_zero, .. }),
                _,
            )
            | ast::Instruction::Cvt(
                ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }),
                _,
            ) => flush_to_zero.map(|enabled| (enabled, 4)),
            ast::Instruction::Div(ast::DivDetails::Float(details), _) => {
                let ftz = details.flush_to_zero?;
                Some((ftz, ast::ScalarType::from(details.typ).size_of()))
            }
            ast::Instruction::Sqrt(details, _) => {
                let ftz = details.flush_to_zero?;
                Some((ftz, ast::ScalarType::from(details.typ).size_of()))
            }
            // rsqrt always carries an explicit flag, so it always reports a setting.
            ast::Instruction::Rsqrt(details, _) => {
                let width = ast::ScalarType::from(details.typ).size_of();
                Some((details.flush_to_zero, width))
            }
            ast::Instruction::Neg(details, _) => {
                let ftz = details.flush_to_zero?;
                Some((ftz, details.typ.size_of()))
            }
            // These always carry an explicit flag and report the width of an `f32`.
            ast::Instruction::Sin { flush_to_zero, .. }
            | ast::Instruction::Cos { flush_to_zero, .. }
            | ast::Instruction::Lg2 { flush_to_zero, .. }
            | ast::Instruction::Ex2 { flush_to_zero, .. } => {
                Some((*flush_to_zero, mem::size_of::<f32>() as u8))
            }
        }
    }
}
/// Visiting an expanded instruction maps all of its arguments through the
/// closure and wraps the result back into a `Statement::Instruction`.
impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
    fn visit_variable_extended<
        F: FnMut(
            ArgumentDescriptor<spirv_headers::Word>,
            Option<&ast::Type>,
        ) -> Result<spirv_headers::Word, TranslateError>,
    >(
        self,
        f: &mut F,
    ) -> Result<ExpandedStatement, TranslateError> {
        let mapped = self.map(f)?;
        Ok(Statement::Instruction(mapped))
    }
}
// Shorthands for argument structs specialised to the expanded argument family.
type Arg2 = ast::Arg2<ExpandedArgParams>;
type Arg2St = ast::Arg2St<ExpandedArgParams>;
/// Statement that reads a single component out of a composite (vector) value.
struct CompositeRead {
// Scalar type of the component being read (and of `dst`).
pub typ: ast::ScalarType,
// Id receiving the component.
pub dst: spirv::Word,
// When set, replaces `ArgumentSemantics::Default` for `dst` during visiting.
pub dst_semantics_override: Option<ArgumentSemantics>,
// Id of the composite being read from.
pub src_composite: spirv::Word,
// NOTE(review): presumably the index of the component inside the composite - confirm at the use site.
pub src_index: u32,
// Number of components of the composite; used as the vector length when visiting.
pub src_len: u32,
}
/// Visits the destination (as a scalar of type `typ`) and then the source
/// composite (as a vector of `src_len` elements of `typ`), and rebuilds the
/// statement with the ids returned by the closure.
impl VisitVariableExpanded for CompositeRead {
    fn visit_variable_extended<
        F: FnMut(
            ArgumentDescriptor<spirv_headers::Word>,
            Option<&ast::Type>,
        ) -> Result<spirv_headers::Word, TranslateError>,
    >(
        self,
        f: &mut F,
    ) -> Result<ExpandedStatement, TranslateError> {
        let CompositeRead {
            typ,
            dst,
            dst_semantics_override,
            src_composite,
            src_index,
            src_len,
        } = self;

        // The destination uses the override when one was requested.
        let dst_desc = ArgumentDescriptor {
            op: dst,
            is_dst: true,
            sema: dst_semantics_override.unwrap_or(ArgumentSemantics::Default),
        };
        let dst = f(dst_desc, Some(&ast::Type::Scalar(typ)))?;

        let src_desc = ArgumentDescriptor {
            op: src_composite,
            is_dst: false,
            sema: ArgumentSemantics::Default,
        };
        let src_composite = f(src_desc, Some(&ast::Type::Vector(typ, src_len as u8)))?;

        Ok(Statement::Composite(CompositeRead {
            typ,
            dst,
            dst_semantics_override,
            src_composite,
            src_index,
            src_len,
        }))
    }
}
/// Statement that defines the id `dst` as a constant of scalar type `typ`
/// with the immediate `value`.
struct ConstantDefinition {
pub dst: spirv::Word,
pub typ: ast::ScalarType,
pub value: ast::ImmediateValue,
}
/// Conditional branch: holds a predicate id and one target id per outcome.
// NOTE(review): the name looks like a typo for `BranchCondition`; it is kept
// because the type may be referenced elsewhere in the file.
struct BrachCondition {
// Id of the predicate deciding the branch.
predicate: spirv::Word,
// Target id taken when the predicate holds.
if_true: spirv::Word,
// Target id taken when the predicate does not hold.
if_false: spirv::Word,
}
/// Conversion inserted by the translator (not written in the input) that
/// turns the value `src` of type `from` into the value `dst` of type `to`.
#[derive(Clone)]
struct ImplicitConversion {
// Id of the value being converted.
src: spirv::Word,
// Id receiving the converted value.
dst: spirv::Word,
// Type of `src`.
from: ast::Type,
// Type of `dst`.
to: ast::Type,
// How the conversion is to be carried out.
kind: ConversionKind,
}
/// The flavour of an `ImplicitConversion`.
#[derive(PartialEq, Copy, Clone)]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
// NOTE(review): the comment above sits between `Default` and `SignExtend`;
// it presumably describes `Default` - confirm.
SignExtend,
// Bit pattern to pointer, carrying the state space of the resulting pointer.
BitToPtr(ast::LdStateSpace),
// Pointer to bit pattern, carrying the unsigned integer type of the result.
PtrToBit(ast::UIntType),
// Pointer to pointer.
// NOTE(review): `spirv_ptr` presumably selects a SPIR-V level pointer cast - confirm at the use site.
PtrToPtr { spirv_ptr: bool },
}
impl<T> ast::PredAt<T> {
    /// Converts the predicate label with `f`, keeping the `not` flag as is.
    ///
    /// Returns the error produced by `f`, if any.
    fn map_variable<U, F: FnMut(T) -> Result<U, TranslateError>>(
        self,
        f: &mut F,
    ) -> Result<ast::PredAt<U>, TranslateError> {
        let ast::PredAt { not, label } = self;
        f(label).map(|label| ast::PredAt { not, label })
    }
}
impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
    /// Replaces every textual identifier of the instruction with the id
    /// returned by `f`.
    ///
    /// `call` is handled here because the generic `map` rejects it; its
    /// identifiers are resolved in the order return parameters, callee,
    /// input parameters. All other instructions are delegated to `map`.
    fn map_variable<F: FnMut(&str) -> Result<spirv::Word, TranslateError>>(
        self,
        f: &mut F,
    ) -> Result<ast::Instruction<NormalizedArgParams>, TranslateError> {
        let call = match self {
            ast::Instruction::Call(call) => call,
            other => return other.map(f),
        };
        let ret_params = call
            .ret_params
            .into_iter()
            .map(|ret| f(ret))
            .collect::<Result<_, _>>()?;
        let func = f(call.func)?;
        let param_list = call
            .param_list
            .into_iter()
            .map(|param| param.map_variable(f))
            .collect::<Result<_, _>>()?;
        Ok(ast::Instruction::Call(ast::CallInst {
            uniform: call.uniform,
            ret_params,
            func,
            param_list,
        }))
    }
}
/// Converts the type of a kernel argument into a general `ast::Type`.
///
/// A `Normal` argument converts through its own `Into` impl; a `Shared`
/// argument becomes a pointer to `.b8` in the shared state space.
impl From<ast::KernelArgumentType> for ast::Type {
    fn from(this: ast::KernelArgumentType) -> Self {
        match this {
            ast::KernelArgumentType::Shared => {
                let pointee = ast::PointerType::Scalar(ast::ScalarType::B8);
                ast::Type::Pointer(pointee, ast::LdStateSpace::Shared)
            }
            ast::KernelArgumentType::Normal(typ) => typ.into(),
        }
    }
}
impl<T: ArgParamsEx> ast::Arg1<T> {
    /// Re-tags the arguments with another argument family that uses the same
    /// `Id` type; the value itself is moved over unchanged.
    fn cast<U: ArgParamsEx<Id = T::Id>>(self) -> ast::Arg1<U> {
        let ast::Arg1 { src } = self;
        ast::Arg1 { src }
    }

    /// Passes the single source id through `visitor` with the optional type `t`.
    fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
        self,
        visitor: &mut V,
        t: Option<&ast::Type>,
    ) -> Result<ast::Arg1<U>, TranslateError> {
        let src_desc = ArgumentDescriptor {
            op: self.src,
            is_dst: false,
            sema: ArgumentSemantics::Default,
        };
        let src = visitor.id(src_desc, t)?;
        Ok(ast::Arg1 { src })
    }
}
impl<T: ArgParamsEx> ast::Arg1Bar<T> {
    /// Re-tags the arguments with another argument family that uses the same
    /// `Operand` type; the value itself is moved over unchanged.
    fn cast<U: ArgParamsEx<Operand = T::Operand>>(self) -> ast::Arg1Bar<U> {
        let ast::Arg1Bar { src } = self;
        ast::Arg1Bar { src }
    }

    /// Passes the single source operand through `visitor`, typed as `.u32`.
    fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
        self,
        visitor: &mut V,
    ) -> Result<ast::Arg1Bar<U>, TranslateError> {
        let operand_type = ast::Type::Scalar(ast::ScalarType::U32);
        let src_desc = ArgumentDescriptor {
            op: self.src,
            is_dst: false,
            sema: ArgumentSemantics::Default,
        };
        let src = visitor.operand(src_desc, &operand_type)?;
        Ok(ast::Arg1Bar { src })
    }
}
impl<T: ArgParamsEx> ast::Arg2<T> {
    /// Re-tags the arguments with another argument family that uses the same
    /// `Id` and `Operand` types; the values are moved over unchanged.
    fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg2<U> {
        let ast::Arg2 { dst, src } = self;
        ast::Arg2 { src, dst }
    }

    /// Visits the destination and then the source, both with the same type `t`.
    fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
        self,
        visitor: &mut V,
        t: &ast::Type,
    ) -> Result<ast::Arg2<U>, TranslateError> {
        let dst_desc = ArgumentDescriptor {
            op: self.dst,
            is_dst: true,
            sema: ArgumentSemantics::Default,
        };
        let dst = visitor.id(dst_desc, Some(t))?;

        let src_desc = ArgumentDescriptor {
            op: self.src,
            is_dst: false,
            sema: ArgumentSemantics::Default,
        };
        let src = visitor.operand(src_desc, t)?;

        Ok(ast::Arg2 { dst, src })
    }

    /// Visits the destination with `dst_t` and then the source with `src_t`.
    fn map_different_types<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
        self,
        visitor: &mut V,
        dst_t: &ast::Type,
        src_t: &ast::Type,
    ) -> Result<ast::Arg2<U>, TranslateError> {
        let dst_desc = ArgumentDescriptor {
            op: self.dst,
            is_dst: true,
            sema: ArgumentSemantics::Default,
        };
        let new_dst = visitor.id(dst_desc, Some(dst_t))?;

        let src_desc = ArgumentDescriptor {
            op: self.src,
            is_dst: false,
            sema: ArgumentSemantics::Default,
        };
        let new_src = visitor.operand(src_desc, src_t)?;

        Ok(ast::Arg2 {
            dst: new_dst,
            src: new_src,
        })
    }
}
impl<T: ArgParamsEx> ast::Arg2Ld<T> {
    /// Re-tags the arguments with another argument family that uses the same
    /// `Operand` and `IdOrVector` types; the values are moved over unchanged.
    fn cast<U: ArgParamsEx<Operand = T::Operand, IdOrVector = T::IdOrVector>>(
        self,
    ) -> ast::Arg2Ld<U> {
        let ast::Arg2Ld { dst, src } = self;
        ast::Arg2Ld { dst, src }
    }

    /// Visits the arguments of a load.
    ///
    /// The destination is visited with the loaded type and relaxed conversion
    /// rules. The source is visited as a pointer to the loaded type in the
    /// state space of the load: `.param` and `.local` use
    /// `RegisterPointer` semantics, every other space `PhysicalPointer`.
    fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
        self,
        visitor: &mut V,
        details: &ast::LdDetails,
    ) -> Result<ast::Arg2Ld<U>, TranslateError> {
        let loaded_type = ast::Type::from(details.typ.clone());
        let dst_desc = ArgumentDescriptor {
            op: self.dst,
            is_dst: true,
            sema: ArgumentSemantics::DefaultRelaxed,
        };
        let dst = visitor.id_or_vector(dst_desc, &loaded_type)?;

        let space = details.state_space;
        let src_sema = if space == ast::LdStateSpace::Param || space == ast::LdStateSpace::Local
        {
            ArgumentSemantics::RegisterPointer
        } else {
            ArgumentSemantics::PhysicalPointer
        };
        let pointer_type = ast::Type::Pointer(ast::PointerType::from(details.typ.clone()), space);
        let src_desc = ArgumentDescriptor {
            op: self.src,
            is_dst: false,
            sema: src_sema,
        };
        let src = visitor.operand(src_desc, &pointer_type)?;

        Ok(ast::Arg2Ld { dst, src })
    }
}
impl<T: ArgParamsEx> ast::Arg2St<T> {
    /// Re-tags the arguments with another argument family that uses the same
    /// `Operand` and `OperandOrVector` types; the values are moved over unchanged.
    fn cast<U: ArgParamsEx<Operand = T::Operand, OperandOrVector = T::OperandOrVector>>(
        self,
    ) -> ast::Arg2St<U> {
        let ast::Arg2St { src1, src2 } = self;
        ast::Arg2St { src1, src2 }
    }

    /// Visits the arguments of a store.
    ///
    /// `src1` (the address) is visited as a pointer to the stored type in the
    /// state space of the store: `.param` and `.local` use `RegisterPointer`
    /// semantics, every other space `PhysicalPointer`. `src2` (the value) is
    /// visited with the stored type and relaxed conversion rules. Neither
    /// argument is flagged as a destination.
    fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
        self,
        visitor: &mut V,
        details: &ast::StData,
    ) -> Result<ast::Arg2St<U>, TranslateError> {
        let address_sema = if details.state_space == ast::StStateSpace::Param
            || details.state_space == ast::StStateSpace::Local
        {
            ArgumentSemantics::RegisterPointer
        } else {
            ArgumentSemantics::PhysicalPointer
        };
        let pointer_type = ast::Type::Pointer(
            ast::PointerType::from(details.typ.clone()),
            details.state_space.to_ld_ss(),
        );
        let address_desc = ArgumentDescriptor {
            op: self.src1,
            is_dst: false,
            sema: address_sema,
        };
        let src1 = visitor.operand(address_desc, &pointer_type)?;

        let value_type: ast::Type = details.typ.clone().into();
        let value_desc = ArgumentDescriptor {
            op: self.src2,
            is_dst: false,
            sema: ArgumentSemantics::DefaultRelaxed,
        };
        let src2 = visitor.operand_or_vector(value_desc, &value_type)?;

        Ok(ast::Arg2St { src1, src2 })
    }
}
impl<T: ArgParamsEx> ast::Arg2Mov<T> {
    /// Maps the arguments of a `mov` by delegating to the `map` of whichever
    /// variant (`Normal` or `Member`) is stored.
    fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
        self,
        visitor: &mut V,
        details: &ast::MovDetails,
    ) -> Result<ast::Arg2Mov<U>, TranslateError> {
        match self {
            ast::Arg2Mov::Normal(args) => {
                let mapped = args.map(visitor, details)?;
                Ok(ast::Arg2Mov::Normal(mapped))
            }
            ast::Arg2Mov::Member(args) => {
                let mapped = args.map(visitor, details)?;
                Ok(ast::Arg2Mov::Member(mapped))
            }
        }
    }
}
impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
    /// Re-tags the arguments with another argument family that uses the same
    /// `IdOrVector` and `OperandOrVector` types; the values are moved over unchanged.
    fn cast<U: ArgParamsEx<IdOrVector = P::IdOrVector, OperandOrVector = P::OperandOrVector>>(
        self,
    ) -> ast::Arg2MovNormal<U> {
        let ast::Arg2MovNormal { dst, src } = self;
        ast::Arg2MovNormal { dst, src }
    }

    /// Visits the destination and then the source with the type of the `mov`.
    ///
    /// The source uses `Address` semantics when `details.src_is_address` is
    /// set and `Default` semantics otherwise.
    fn map<U: ArgParamsEx, V: ArgumentMapVisitor<P, U>>(
        self,
        visitor: &mut V,
        details: &ast::MovDetails,
    ) -> Result<ast::Arg2MovNormal<U>, TranslateError> {
        let dst_type: ast::Type = details.typ.clone().into();
        let dst_desc = ArgumentDescriptor {
            op: self.dst,
            is_dst: true,
            sema: ArgumentSemantics::Default,
        };
        let dst = visitor.id_or_vector(dst_desc, &dst_type)?;

        let src_sema = if details.src_is_address {
            ArgumentSemantics::Address
        } else {
            ArgumentSemantics::Default
        };
        let src_type: ast::Type = details.typ.clone().into();
        let src_desc = ArgumentDescriptor {
            op: self.src,
            is_dst: false,
            sema: src_sema,
        };
        let src = visitor.operand_or_vector(src_desc, &src_type)?;

        Ok(ast::Arg2MovNormal { dst, src })
    }
}
impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
    /// Re-tags the arguments with another argument family that uses the same
    /// `Id` and `SrcMemberOperand` types; the values are moved over unchanged.
    fn cast<U: ArgParamsEx<Id = T::Id, SrcMemberOperand = T::SrcMemberOperand>>(
        self,
    ) -> ast::Arg2MovMember<U> {
        match self {
            ast::Arg2MovMember::Dst(dst_member, composite, scalar) => {
                ast::Arg2MovMember::Dst(dst_member, composite, scalar)
            }
            ast::Arg2MovMember::Src(dst, src_member) => ast::Arg2MovMember::Src(dst, src_member),
            ast::Arg2MovMember::Both(dst_member, composite, src_member) => {
                ast::Arg2MovMember::Both(dst_member, composite, src_member)
            }
        }
    }

    /// Returns the vector written to when the destination is a vector member
    /// (`Dst` and `Both`), `None` for `Src`.
    fn vector_dst(&self) -> Option<&T::Id> {
        match self {
            ast::Arg2MovMember::Dst((vector, _), _, _)
            | ast::Arg2MovMember::Both((vector, _), _, _) => Some(vector),
            ast::Arg2MovMember::Src(_, _) => None,
        }
    }

    /// Returns the member operand read from when the source is a vector
    /// member (`Src` and `Both`), `None` for `Dst`.
    fn vector_src(&self) -> Option<&T::SrcMemberOperand> {
        match self {
            ast::Arg2MovMember::Dst(_, _, _) => None,
            ast::Arg2MovMember::Src(_, member) | ast::Arg2MovMember::Both(_, _, member) => {
                Some(member)
            }
        }
    }
}
impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
    /// Maps the arguments of a `mov` that reads and/or writes a vector member.
    ///
    /// * `Dst`: the destination vector and the composite source are visited
    ///   as vectors of `details.dst_width` scalars; the scalar source is
    ///   visited with the type of the `mov`.
    /// * `Src`: the destination is visited with the type of the `mov`; the
    ///   source is visited as a member of a vector of `details.src_width` scalars.
    /// * `Both`: combination of the two cases above.
    ///
    /// Fails if `details.typ.get_scalar()` fails or if the visitor reports an error.
    fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
        self,
        visitor: &mut V,
        details: &ast::MovDetails,
    ) -> Result<ast::Arg2MovMember<U>, TranslateError> {
        match self {
            ast::Arg2MovMember::Dst((dst_vector, member), composite_src, scalar_src) => {
                let scalar_type = details.typ.get_scalar()?;
                let vector_type = ast::Type::Vector(scalar_type, details.dst_width);

                let dst_desc = ArgumentDescriptor {
                    op: dst_vector,
                    is_dst: true,
                    sema: ArgumentSemantics::Default,
                };
                let new_dst = visitor.id(dst_desc, Some(&vector_type))?;

                let composite_desc = ArgumentDescriptor {
                    op: composite_src,
                    is_dst: false,
                    sema: ArgumentSemantics::Default,
                };
                let new_composite = visitor.id(composite_desc, Some(&vector_type))?;

                // An address source takes precedence over the relaxed conversion flag.
                let scalar_sema = if details.src_is_address {
                    ArgumentSemantics::Address
                } else if details.relaxed_src2_conv {
                    ArgumentSemantics::DefaultRelaxed
                } else {
                    ArgumentSemantics::Default
                };
                let mov_type: ast::Type = details.typ.clone().into();
                let scalar_desc = ArgumentDescriptor {
                    op: scalar_src,
                    is_dst: false,
                    sema: scalar_sema,
                };
                let new_scalar = visitor.id(scalar_desc, Some(&mov_type))?;

                Ok(ast::Arg2MovMember::Dst(
                    (new_dst, member),
                    new_composite,
                    new_scalar,
                ))
            }
            ast::Arg2MovMember::Src(dst, src_member) => {
                let mov_type: ast::Type = details.typ.clone().into();
                let dst_desc = ArgumentDescriptor {
                    op: dst,
                    is_dst: true,
                    sema: ArgumentSemantics::Default,
                };
                let new_dst = visitor.id(dst_desc, Some(&mov_type))?;

                // The scalar type is only looked up after the destination was visited.
                let scalar_type = details.typ.get_scalar()?;
                let member_desc = ArgumentDescriptor {
                    op: src_member,
                    is_dst: false,
                    sema: ArgumentSemantics::Default,
                };
                let new_member = visitor
                    .src_member_operand(member_desc, (scalar_type.into(), details.src_width))?;

                Ok(ast::Arg2MovMember::Src(new_dst, new_member))
            }
            ast::Arg2MovMember::Both((dst_vector, member), composite_src, src_member) => {
                let scalar_type = details.typ.get_scalar()?;
                let vector_type = ast::Type::Vector(scalar_type, details.dst_width);

                let dst_desc = ArgumentDescriptor {
                    op: dst_vector,
                    is_dst: true,
                    sema: ArgumentSemantics::Default,
                };
                let new_dst = visitor.id(dst_desc, Some(&vector_type))?;

                let composite_desc = ArgumentDescriptor {
                    op: composite_src,
                    is_dst: false,
                    sema: ArgumentSemantics::Default,
                };
                let new_composite = visitor.id(composite_desc, Some(&vector_type))?;

                let member_sema = if details.relaxed_src2_conv {
                    ArgumentSemantics::DefaultRelaxed
                } else {
                    ArgumentSemantics::Default
                };
                let member_desc = ArgumentDescriptor {
                    op: src_member,
                    is_dst: false,
                    sema: member_sema,
                };
                let new_member = visitor
                    .src_member_operand(member_desc, (scalar_type.into(), details.src_width))?;

                Ok(ast::Arg2MovMember::Both(
                    (new_dst, member),
                    new_composite,
                    new_member,
                ))
            }
        }
    }
}
impl<T: ArgParamsEx> ast::Arg3<T> {
fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg3<U> {
ast::Arg3 {
dst: self.dst,
src1: self.src1,
src2: self.src2,
}
}
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
typ: &ast::Type,
is_wide: bool,
) -> Result<ast::Arg3<U>, TranslateError> {
let wide_type = if is_wide {
Some(typ.clone().widen()?)
} else {
None
};
let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
Some(wide_type.as_ref().unwrap_or(typ)),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
sema: ArgumentSemantics::Default,
},
typ,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
sema: ArgumentSemantics::Default,
},
typ,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
fn map_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: &ast::Type,
) -> Result<ast::Arg3<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
Some(t),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
sema: ArgumentSemantics::Default,
},
t,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(ast::ScalarType::U32),
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::ScalarType,
state_space: ast::AtomSpace,
) -> Result<ast::Arg3<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
Some(&ast::Type::Scalar(scalar_type)),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
sema: ArgumentSemantics::PhysicalPointer,
},
&ast::Type::Pointer(
ast::PointerType::Scalar(scalar_type),
state_space.to_ld_ss(),
),
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(scalar_type),
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
}
impl<T: ArgParamsEx> ast::Arg4<T> {
fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg4<U> {
ast::Arg4 {
dst: self.dst,
src1: self.src1,
src2: self.src2,
src3: self.src3,
}
}
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: &ast::Type,
is_wide: bool,
) -> Result<ast::Arg4<U>, TranslateError> {
let wide_type = if is_wide {
Some(t.clone().widen()?)
} else {
None
};
let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
Some(wide_type.as_ref().unwrap_or(t)),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
sema: ArgumentSemantics::Default,
},
t,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
sema: ArgumentSemantics::Default,
},
t,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
sema: ArgumentSemantics::Default,
},
t,
)?;
Ok(ast::Arg4 {
dst,
src1,
src2,
src3,
})
}
fn map_selp<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::SelpType,
) -> Result<ast::Arg4<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
Some(&ast::Type::Scalar(t.into())),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(t.into()),
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(t.into()),
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(ast::ScalarType::Pred),
)?;
Ok(ast::Arg4 {
dst,
src1,
src2,
src3,
})
}
fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::BitType,
state_space: ast::AtomSpace,
) -> Result<ast::Arg4<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
Some(&ast::Type::Scalar(scalar_type)),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
sema: ArgumentSemantics::PhysicalPointer,
},
&ast::Type::Pointer(
ast::PointerType::Scalar(scalar_type),
state_space.to_ld_ss(),
),
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(scalar_type),
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(scalar_type),
)?;
Ok(ast::Arg4 {
dst,
src1,
src2,
src3,
})
}
fn map_bfe<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
typ: &ast::Type,
) -> Result<ast::Arg4<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
Some(typ),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
sema: ArgumentSemantics::Default,
},
typ,
)?;
let u32_type = ast::Type::Scalar(ast::ScalarType::U32);
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
sema: ArgumentSemantics::Default,
},
&u32_type,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
sema: ArgumentSemantics::Default,
},
&u32_type,
)?;
Ok(ast::Arg4 {
dst,
src1,
src2,
src3,
})
}
}
// Argument mapping for instructions with one or two predicate destinations
// and two sources.
impl<T: ArgParamsEx> ast::Arg4Setp<T> {
    /// Moves all fields into an `Arg4Setp` over a different parameter set `U`
    /// that shares `T`'s `Id` and `Operand` types.
    fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg4Setp<U> {
        let ast::Arg4Setp {
            dst1,
            dst2,
            src1,
            src2,
        } = self;
        ast::Arg4Setp {
            dst1,
            dst2,
            src1,
            src2,
        }
    }

    /// Visits `dst1` and, when present, `dst2` as predicates, then both
    /// sources with type `t`.
    fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
        self,
        visitor: &mut V,
        t: &ast::Type,
    ) -> Result<ast::Arg4Setp<U>, TranslateError> {
        let pred_type = ast::Type::Scalar(ast::ScalarType::Pred);
        let dst1_desc = ArgumentDescriptor {
            op: self.dst1,
            is_dst: true,
            sema: ArgumentSemantics::Default,
        };
        let dst1 = visitor.id(dst1_desc, Some(&pred_type))?;
        // The second destination is optional; only visit it when it exists.
        let dst2 = match self.dst2 {
            Some(second) => {
                let dst2_desc = ArgumentDescriptor {
                    op: second,
                    is_dst: true,
                    sema: ArgumentSemantics::Default,
                };
                Some(visitor.id(dst2_desc, Some(&pred_type))?)
            }
            None => None,
        };
        let lhs_desc = ArgumentDescriptor {
            op: self.src1,
            is_dst: false,
            sema: ArgumentSemantics::Default,
        };
        let src1 = visitor.operand(lhs_desc, t)?;
        let rhs_desc = ArgumentDescriptor {
            op: self.src2,
            is_dst: false,
            sema: ArgumentSemantics::Default,
        };
        let src2 = visitor.operand(rhs_desc, t)?;
        Ok(ast::Arg4Setp {
            dst1,
            dst2,
            src1,
            src2,
        })
    }
}
// Argument mapping for instructions with one or two predicate destinations
// and three sources.
impl<T: ArgParamsEx> ast::Arg5<T> {
    /// Moves all fields into an `Arg5` over a different parameter set `U`
    /// that shares `T`'s `Id` and `Operand` types.
    fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg5<U> {
        let ast::Arg5 {
            dst1,
            dst2,
            src1,
            src2,
            src3,
        } = self;
        ast::Arg5 {
            dst1,
            dst2,
            src1,
            src2,
            src3,
        }
    }

    /// Visits `dst1` and, when present, `dst2` as predicates, `src1` and
    /// `src2` with type `t`, and `src3` as a predicate.
    fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
        self,
        visitor: &mut V,
        t: &ast::Type,
    ) -> Result<ast::Arg5<U>, TranslateError> {
        let pred_type = ast::Type::Scalar(ast::ScalarType::Pred);
        let dst1_desc = ArgumentDescriptor {
            op: self.dst1,
            is_dst: true,
            sema: ArgumentSemantics::Default,
        };
        let dst1 = visitor.id(dst1_desc, Some(&pred_type))?;
        // The second destination is optional; only visit it when it exists.
        let dst2 = match self.dst2 {
            Some(second) => {
                let dst2_desc = ArgumentDescriptor {
                    op: second,
                    is_dst: true,
                    sema: ArgumentSemantics::Default,
                };
                Some(visitor.id(dst2_desc, Some(&pred_type))?)
            }
            None => None,
        };
        let lhs_desc = ArgumentDescriptor {
            op: self.src1,
            is_dst: false,
            sema: ArgumentSemantics::Default,
        };
        let src1 = visitor.operand(lhs_desc, t)?;
        let rhs_desc = ArgumentDescriptor {
            op: self.src2,
            is_dst: false,
            sema: ArgumentSemantics::Default,
        };
        let src2 = visitor.operand(rhs_desc, t)?;
        let pred_desc = ArgumentDescriptor {
            op: self.src3,
            is_dst: false,
            sema: ArgumentSemantics::Default,
        };
        let src3 = visitor.operand(pred_desc, &pred_type)?;
        Ok(ast::Arg5 {
            dst1,
            dst2,
            src1,
            src2,
            src3,
        })
    }
}
impl ast::Type {
    /// Returns the element type and length when `self` is a vector;
    /// any other type yields `TranslateError::MismatchedType`.
    fn get_vector(&self) -> Result<(ast::ScalarType, u8), TranslateError> {
        if let ast::Type::Vector(element, length) = self {
            Ok((*element, *length))
        } else {
            Err(TranslateError::MismatchedType)
        }
    }

    /// Returns the scalar type when `self` is a scalar;
    /// any other type yields `TranslateError::MismatchedType`.
    fn get_scalar(&self) -> Result<ast::ScalarType, TranslateError> {
        if let ast::Type::Scalar(scalar) = self {
            Ok(*scalar)
        } else {
            Err(TranslateError::MismatchedType)
        }
    }
}
impl<T> ast::CallOperand<T> {
    /// Applies the fallible function `f` to a register operand and passes an
    /// immediate operand through unchanged. An error from `f` is propagated.
    fn map_variable<U, F: FnMut(T) -> Result<U, TranslateError>>(
        self,
        f: &mut F,
    ) -> Result<ast::CallOperand<U>, TranslateError> {
        let mapped = match self {
            ast::CallOperand::Reg(register) => ast::CallOperand::Reg(f(register)?),
            ast::CallOperand::Imm(immediate) => ast::CallOperand::Imm(immediate),
        };
        Ok(mapped)
    }
}
impl ast::StStateSpace {
    /// Maps a store state space onto the load state space of the same name.
    fn to_ld_ss(self) -> ast::LdStateSpace {
        match self {
            Self::Generic => ast::LdStateSpace::Generic,
            Self::Global => ast::LdStateSpace::Global,
            Self::Local => ast::LdStateSpace::Local,
            Self::Param => ast::LdStateSpace::Param,
            Self::Shared => ast::LdStateSpace::Shared,
        }
    }
}
// Coarse classification of `ast::ScalarType` values, as produced by
// `ast::ScalarType::kind` and consumed by `ast::ScalarType::from_parts`.
#[derive(Clone, Copy, PartialEq, Eq)]
enum ScalarKind {
// Untyped bit types (`B8`, `B16`, `B32`, `B64`).
Bit,
// Unsigned integers (`U8`, `U16`, `U32`, `U64`).
Unsigned,
// Signed integers (`S8`, `S16`, `S32`, `S64`).
Signed,
// Floating point types (`F16`, `F32`, `F64`).
Float,
// `from_parts` maps this kind with width 4 to `F16x2`.
// NOTE(review): `kind()` currently reports `F16x2` as `Float`, so this
// variant is not produced by `kind()` — confirm whether that is intended.
Float2,
// The predicate type (`Pred`).
Pred,
}
impl ast::ScalarType {
fn kind(self) -> ScalarKind {
match self {
ast::ScalarType::U8 => ScalarKind::Unsigned,
ast::ScalarType::U16 => ScalarKind::Unsigned,
ast::ScalarType::U32 => ScalarKind::Unsigned,
ast::ScalarType::U64 => ScalarKind::Unsigned,
ast::ScalarType::S8 => ScalarKind::Signed,
ast::ScalarType::S16 => ScalarKind::Signed,
ast::ScalarType::S32 => ScalarKind::Signed,
ast::ScalarType::S64 => ScalarKind::Signed,
ast::ScalarType::B8 => ScalarKind::Bit,
ast::ScalarType::B16 => ScalarKind::Bit,
ast::ScalarType::B32 => ScalarKind::Bit,
ast::ScalarType::B64 => ScalarKind::Bit,
ast::ScalarType::F16 => ScalarKind::Float,
ast::ScalarType::F32 => ScalarKind::Float,
ast::ScalarType::F64 => ScalarKind::Float,
ast::ScalarType::F16x2 => ScalarKind::Float,
ast::ScalarType::Pred => ScalarKind::Pred,
}
}
fn from_parts(width: u8, kind: ScalarKind) -> Self {
match kind {
ScalarKind::Float => match width {
2 => ast::ScalarType::F16,
4 => ast::ScalarType::F32,
8 => ast::ScalarType::F64,
_ => unreachable!(),
},
ScalarKind::Bit => match width {
1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32,
8 => ast::ScalarType::B64,
_ => unreachable!(),
},
ScalarKind::Signed => match width {
1 => ast::ScalarType::S8,
2 => ast::ScalarType::S16,
4 => ast::ScalarType::S32,
8 => ast::ScalarType::S64,
_ => unreachable!(),
},
ScalarKind::Unsigned => match width {
1 => ast::ScalarType::U8,
2 => ast::ScalarType::U16,
4 => ast::ScalarType::U32,
8 => ast::ScalarType::U64,
_ => unreachable!(),
},
ScalarKind::Float2 => match width {
4 => ast::ScalarType::F16x2,
_ => unreachable!(),
},
ScalarKind::Pred => ast::ScalarType::Pred,
}
}
}
impl ast::BooleanType {
    /// Converts the boolean instruction type into the scalar `ast::Type`
    /// of the same name.
    fn to_type(self) -> ast::Type {
        let scalar = match self {
            Self::Pred => ast::ScalarType::Pred,
            Self::B16 => ast::ScalarType::B16,
            Self::B32 => ast::ScalarType::B32,
            Self::B64 => ast::ScalarType::B64,
        };
        ast::Type::Scalar(scalar)
    }
}
impl ast::ShlType {
    /// Converts the shift-left instruction type into the scalar `ast::Type`
    /// of the same name.
    fn to_type(self) -> ast::Type {
        let scalar = match self {
            Self::B16 => ast::ScalarType::B16,
            Self::B32 => ast::ScalarType::B32,
            Self::B64 => ast::ScalarType::B64,
        };
        ast::Type::Scalar(scalar)
    }
}
impl ast::ShrType {
    /// Returns `true` for the signed variants (`S16`, `S32`, `S64`) and
    /// `false` for every other variant.
    fn signed(&self) -> bool {
        matches!(
            self,
            ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64
        )
    }
}
impl ast::ArithDetails {
    /// Returns the instruction's operand type as a scalar `ast::Type`.
    fn get_type(&self) -> ast::Type {
        let scalar = match self {
            Self::Unsigned(typ) => (*typ).into(),
            Self::Signed(desc) => desc.typ.into(),
            Self::Float(desc) => desc.typ.into(),
        };
        ast::Type::Scalar(scalar)
    }
}
impl ast::MulDetails {
    /// Returns the instruction's operand type as a scalar `ast::Type`.
    fn get_type(&self) -> ast::Type {
        let scalar = match self {
            Self::Unsigned(desc) => desc.typ.into(),
            Self::Signed(desc) => desc.typ.into(),
            Self::Float(desc) => desc.typ.into(),
        };
        ast::Type::Scalar(scalar)
    }
}
impl ast::MinMaxDetails {
    /// Returns the instruction's operand type as a scalar `ast::Type`.
    fn get_type(&self) -> ast::Type {
        let scalar = match self {
            Self::Signed(typ) => (*typ).into(),
            Self::Unsigned(typ) => (*typ).into(),
            Self::Float(desc) => desc.typ.into(),
        };
        ast::Type::Scalar(scalar)
    }
}
impl ast::DivDetails {
    /// Returns the instruction's operand type as a scalar `ast::Type`.
    fn get_type(&self) -> ast::Type {
        let scalar = match self {
            Self::Unsigned(typ) => (*typ).into(),
            Self::Signed(typ) => (*typ).into(),
            Self::Float(desc) => desc.typ.into(),
        };
        ast::Type::Scalar(scalar)
    }
}
impl ast::AtomInnerDetails {
    /// Returns the `typ` field of whichever variant is present, converted to
    /// `ast::ScalarType`.
    fn get_type(&self) -> ast::ScalarType {
        match self {
            Self::Bit { typ, .. } => (*typ).into(),
            Self::Unsigned { typ, .. } => (*typ).into(),
            Self::Signed { typ, .. } => (*typ).into(),
            Self::Float { typ, .. } => (*typ).into(),
        }
    }
}
impl ast::SIntType {
fn from_size(width: u8) -> Self {
match width {
1 => ast::SIntType::S8,
2 => ast::SIntType::S16,
4 => ast::SIntType::S32,
8 => ast::SIntType::S64,
_ => unreachable!(),
}
}
}
impl ast::UIntType {
fn from_size(width: u8) -> Self {
match width {
1 => ast::UIntType::U8,
2 => ast::UIntType::U16,
4 => ast::UIntType::U32,
8 => ast::UIntType::U64,
_ => unreachable!(),
}
}
}
impl ast::LdStateSpace {
    /// Maps the load state space onto a SPIR-V storage class.
    /// `Local` and `Param` both map to `Function`.
    fn to_spirv(self) -> spirv::StorageClass {
        match self {
            Self::Const => spirv::StorageClass::UniformConstant,
            Self::Generic => spirv::StorageClass::Generic,
            Self::Global => spirv::StorageClass::CrossWorkgroup,
            Self::Shared => spirv::StorageClass::Workgroup,
            Self::Local | Self::Param => spirv::StorageClass::Function,
        }
    }
}
// Converts a function argument type into the matching variable type.
// The `Shared` case is not implemented yet and panics via `todo!()`.
impl From<ast::FnArgumentType> for ast::VariableType {
    fn from(t: ast::FnArgumentType) -> Self {
        match t {
            ast::FnArgumentType::Reg(inner) => Self::Reg(inner),
            ast::FnArgumentType::Param(inner) => Self::Param(inner),
            ast::FnArgumentType::Shared => todo!(),
        }
    }
}
impl<T> ast::Operand<T> {
fn underlying(&self) -> Option<&T> {
match self {
ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r),
ast::Operand::Imm(_) => None,
}
}
}
impl<T> ast::OperandOrVector<T> {
fn single_underlying(&self) -> Option<&T> {
match self {
ast::OperandOrVector::Reg(r) | ast::OperandOrVector::RegOffset(r, _) => Some(r),
ast::OperandOrVector::Imm(_) | ast::OperandOrVector::Vec(_) => None,
}
}
}
impl ast::MulDetails {
    /// Returns `true` when an integer multiplication uses the `Wide` control
    /// mode; float multiplications are never wide.
    fn is_wide(&self) -> bool {
        let control = match self {
            Self::Unsigned(desc) => &desc.control,
            Self::Signed(desc) => &desc.control,
            Self::Float(_) => return false,
        };
        *control == ast::MulIntControl::Wide
    }
}
impl ast::AtomSpace {
    /// Maps an atomic state space onto the load state space of the same name.
    fn to_ld_ss(self) -> ast::LdStateSpace {
        match self {
            Self::Generic => ast::LdStateSpace::Generic,
            Self::Global => ast::LdStateSpace::Global,
            Self::Shared => ast::LdStateSpace::Shared,
        }
    }
}
impl ast::MemScope {
    /// Maps the memory scope onto a SPIR-V scope:
    /// `Cta` -> `Workgroup`, `Gpu` -> `Device`, `Sys` -> `CrossDevice`.
    fn to_spirv(self) -> spirv::Scope {
        match self {
            Self::Cta => spirv::Scope::Workgroup,
            Self::Gpu => spirv::Scope::Device,
            Self::Sys => spirv::Scope::CrossDevice,
        }
    }
}
impl ast::AtomSemantics {
    /// Maps the atomic ordering onto the SPIR-V memory semantics flag of the
    /// same meaning.
    fn to_spirv(self) -> spirv::MemorySemantics {
        match self {
            Self::Relaxed => spirv::MemorySemantics::RELAXED,
            Self::Acquire => spirv::MemorySemantics::ACQUIRE,
            Self::Release => spirv::MemorySemantics::RELEASE,
            Self::AcquireRelease => spirv::MemorySemantics::ACQUIRE_RELEASE,
        }
    }
}
impl ast::FnArgumentType {
    /// Selects the argument semantics used for this kind of function
    /// argument: `Reg` -> `Default`, `Param` -> `RegisterPointer`,
    /// `Shared` -> `PhysicalPointer`.
    fn semantics(&self) -> ArgumentSemantics {
        match self {
            Self::Reg(_) => ArgumentSemantics::Default,
            Self::Param(_) => ArgumentSemantics::RegisterPointer,
            Self::Shared => ArgumentSemantics::PhysicalPointer,
        }
    }
}
// Conversion check for register-pointer operands.
// Currently a plain delegation: the arguments are forwarded unchanged to
// `bitcast_physical_pointer` and its result is returned as-is.
fn bitcast_register_pointer(
operand_type: &ast::Type,
instr_type: &ast::Type,
ss: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
bitcast_physical_pointer(operand_type, instr_type, ss)
}
/// Decides which conversion turns an operand of `operand_type` into the
/// pointer type `instr_type` expected by the instruction.
///
/// * Array and pointer operands need `instr_type` to be a pointer. Same space
///   and same pointee needs no conversion (`Ok(None)`); same space with a
///   different pointee, or differing spaces where one side is `Generic`,
///   yields `PtrToPtr`. For arrays the operand's space is `ss`.
/// * 64-bit integer/bit scalars become `BitToPtr(ss)`; a missing `ss` is
///   `TranslateError::Unreachable`.
/// * 32-bit integer/bit scalars become `BitToPtr(Shared)` when `ss` is
///   `Shared`, `Generic`, `Param` or `Local`.
///
/// Everything else is `TranslateError::MismatchedType`.
fn bitcast_physical_pointer(
    operand_type: &ast::Type,
    instr_type: &ast::Type,
    ss: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    // Shared decision table for the pointer-to-pointer cases.
    fn pointer_conversion(
        same_space: bool,
        same_pointee: bool,
        involves_generic: bool,
    ) -> Result<Option<ConversionKind>, TranslateError> {
        match (same_space, same_pointee) {
            (true, true) => Ok(None),
            (true, false) => Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })),
            (false, _) if involves_generic => {
                Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
            }
            (false, _) => Err(TranslateError::MismatchedType),
        }
    }

    match operand_type {
        // array decays to a pointer
        ast::Type::Array(op_scalar_t, _) => {
            let (instr_pointee, instr_space) = match instr_type {
                ast::Type::Pointer(pointee, space) => (pointee, *space),
                _ => return Err(TranslateError::MismatchedType),
            };
            let same_pointee =
                ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_pointee.clone());
            let involves_generic = ss == Some(ast::LdStateSpace::Generic)
                || instr_space == ast::LdStateSpace::Generic;
            pointer_conversion(ss == Some(instr_space), same_pointee, involves_generic)
        }
        ast::Type::Scalar(ast::ScalarType::B64)
        | ast::Type::Scalar(ast::ScalarType::U64)
        | ast::Type::Scalar(ast::ScalarType::S64) => ss
            .map(|space| Some(ConversionKind::BitToPtr(space)))
            .ok_or(TranslateError::Unreachable),
        ast::Type::Scalar(ast::ScalarType::B32)
        | ast::Type::Scalar(ast::ScalarType::U32)
        | ast::Type::Scalar(ast::ScalarType::S32) => match ss {
            Some(ast::LdStateSpace::Shared)
            | Some(ast::LdStateSpace::Generic)
            | Some(ast::LdStateSpace::Param)
            | Some(ast::LdStateSpace::Local) => {
                Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared)))
            }
            _ => Err(TranslateError::MismatchedType),
        },
        ast::Type::Pointer(op_pointee, op_space) => {
            let (instr_pointee, instr_space) = match instr_type {
                ast::Type::Pointer(pointee, space) => (pointee, space),
                _ => return Err(TranslateError::MismatchedType),
            };
            let involves_generic = *op_space == ast::LdStateSpace::Generic
                || *instr_space == ast::LdStateSpace::Generic;
            pointer_conversion(
                op_space == instr_space,
                op_pointee == instr_pointee,
                involves_generic,
            )
        }
        _ => Err(TranslateError::MismatchedType),
    }
}
/// Produces a `PtrToBit` conversion to the instruction's scalar type.
///
/// Only `instr_type` is inspected. It must be a scalar that `try_into()`
/// accepts as the payload of `ConversionKind::PtrToBit`; otherwise
/// `TranslateError::MismatchedType` is returned.
fn force_bitcast_ptr_to_bit(
    _: &ast::Type,
    instr_type: &ast::Type,
    _: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    // TODO: verify this on f32, u16 and the like
    match instr_type {
        ast::Type::Scalar(scalar_t) => match (*scalar_t).try_into() {
            Ok(int_type) => Ok(Some(ConversionKind::PtrToBit(int_type))),
            Err(_) => Err(TranslateError::MismatchedType),
        },
        _ => Err(TranslateError::MismatchedType),
    }
}
/// Tells whether an operand of type `operand` may be bit-cast to the
/// instruction type `instr`.
///
/// Scalars must have equal size, and then:
/// * a bit-typed instruction accepts any non-bit operand,
/// * a float instruction accepts bit operands,
/// * a signed instruction accepts bit and unsigned operands,
/// * an unsigned instruction accepts bit and signed operands,
/// * `Float2` and `Pred` instructions accept nothing.
///
/// Vector/vector and array/array pairs are decided by their element types.
/// NOTE(review): vector lengths and array dimensions are not compared here;
/// confirm that callers never pass differently sized aggregates.
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
    match (instr, operand) {
        (ast::Type::Scalar(instr_scalar), ast::Type::Scalar(op_scalar)) => {
            if instr_scalar.size_of() != op_scalar.size_of() {
                return false;
            }
            let op_kind = op_scalar.kind();
            match instr_scalar.kind() {
                ScalarKind::Bit => op_kind != ScalarKind::Bit,
                ScalarKind::Float => op_kind == ScalarKind::Bit,
                ScalarKind::Signed => {
                    op_kind == ScalarKind::Bit || op_kind == ScalarKind::Unsigned
                }
                ScalarKind::Unsigned => {
                    op_kind == ScalarKind::Bit || op_kind == ScalarKind::Signed
                }
                ScalarKind::Float2 | ScalarKind::Pred => false,
            }
        }
        (ast::Type::Vector(instr_elem, _), ast::Type::Vector(op_elem, _)) => should_bitcast(
            &ast::Type::Scalar(*instr_elem),
            &ast::Type::Scalar(*op_elem),
        ),
        (ast::Type::Array(instr_elem, _), ast::Type::Array(op_elem, _)) => should_bitcast(
            &ast::Type::Scalar(*instr_elem),
            &ast::Type::Scalar(*op_elem),
        ),
        _ => false,
    }
}
/// Like `should_bitcast_wrapper`, but additionally lets a vector operand be
/// reinterpreted as a bit-typed scalar whose size equals the vector's total
/// size (element size times length).
fn should_bitcast_packed(
    operand: &ast::Type,
    instr: &ast::Type,
    ss: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    match (operand, instr) {
        (ast::Type::Vector(elem_type, elem_count), ast::Type::Scalar(scalar))
            if scalar.kind() == ScalarKind::Bit
                && scalar.size_of() == (elem_type.size_of() * elem_count) =>
        {
            Ok(Some(ConversionKind::Default))
        }
        _ => should_bitcast_wrapper(operand, instr, ss),
    }
}
/// Default conversion check: identical types need no conversion, types
/// accepted by `should_bitcast` get `ConversionKind::Default`, and anything
/// else is `TranslateError::MismatchedType`. The state space is ignored.
fn should_bitcast_wrapper(
    operand: &ast::Type,
    instr: &ast::Type,
    _: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    if instr == operand {
        Ok(None)
    } else if should_bitcast(instr, operand) {
        Ok(Some(ConversionKind::Default))
    } else {
        Err(TranslateError::MismatchedType)
    }
}
/// Relaxed source-operand check in the `Result` shape used by the other
/// conversion checks: identical types give `Ok(None)`, a conversion found by
/// `should_convert_relaxed_src` is returned, and no conversion is
/// `TranslateError::MismatchedType`. The state space is ignored.
fn should_convert_relaxed_src_wrapper(
    src_type: &ast::Type,
    instr_type: &ast::Type,
    _: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    if src_type == instr_type {
        return Ok(None);
    }
    should_convert_relaxed_src(src_type, instr_type)
        .map(Some)
        .ok_or(TranslateError::MismatchedType)
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
/// Relaxed type-checking rules for source operands.
///
/// Returns `Some(ConversionKind::Default)` when a source of `src_type` may be
/// used by an instruction of `instr_type`, and `None` when the types are
/// identical or the combination is not allowed. For scalars the instruction
/// type must not be larger than the source, and additionally:
/// * integer instructions reject float sources,
/// * float instructions accept only bit-typed sources.
///
/// Vector/vector and array/array pairs are decided by their element types.
///
/// # Panics
/// `Float2` instruction types are not implemented (`todo!()`).
fn should_convert_relaxed_src(
    src_type: &ast::Type,
    instr_type: &ast::Type,
) -> Option<ConversionKind> {
    if src_type == instr_type {
        return None;
    }
    match (src_type, instr_type) {
        (ast::Type::Scalar(src_scalar), ast::Type::Scalar(instr_scalar)) => {
            // Evaluated lazily so that each arm queries only what it needs.
            let fits = || instr_scalar.size_of() <= src_scalar.size_of();
            let allowed = match instr_scalar.kind() {
                ScalarKind::Bit => fits(),
                ScalarKind::Signed | ScalarKind::Unsigned => {
                    fits() && src_scalar.kind() != ScalarKind::Float
                }
                ScalarKind::Float => fits() && src_scalar.kind() == ScalarKind::Bit,
                ScalarKind::Float2 => todo!(),
                ScalarKind::Pred => false,
            };
            if allowed {
                Some(ConversionKind::Default)
            } else {
                None
            }
        }
        (ast::Type::Vector(src_elem, _), ast::Type::Vector(instr_elem, _)) => {
            should_convert_relaxed_src(
                &ast::Type::Scalar(*src_elem),
                &ast::Type::Scalar(*instr_elem),
            )
        }
        (ast::Type::Array(src_elem, _), ast::Type::Array(instr_elem, _)) => {
            should_convert_relaxed_src(
                &ast::Type::Scalar(*src_elem),
                &ast::Type::Scalar(*instr_elem),
            )
        }
        _ => None,
    }
}
/// Relaxed destination-operand check in the `Result` shape used by the other
/// conversion checks: identical types give `Ok(None)`, a conversion found by
/// `should_convert_relaxed_dst` is returned, and no conversion is
/// `TranslateError::MismatchedType`. The state space is ignored.
fn should_convert_relaxed_dst_wrapper(
    dst_type: &ast::Type,
    instr_type: &ast::Type,
    _: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    if dst_type == instr_type {
        return Ok(None);
    }
    should_convert_relaxed_dst(dst_type, instr_type)
        .map(Some)
        .ok_or(TranslateError::MismatchedType)
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
/// Relaxed type-checking rules for destination operands.
///
/// Returns the conversion needed to store a result of `instr_type` into a
/// destination of `dst_type`, or `None` when the types are identical or the
/// combination is not allowed. For scalars the instruction type must not be
/// larger than the destination, and additionally:
/// * signed instructions reject float destinations and use `SignExtend`
///   when the destination is strictly larger,
/// * unsigned instructions reject float destinations,
/// * float instructions accept only bit-typed destinations.
///
/// Vector/vector and array/array pairs are decided by their element types.
///
/// # Panics
/// `Float2` instruction types are not implemented (`todo!()`).
fn should_convert_relaxed_dst(
    dst_type: &ast::Type,
    instr_type: &ast::Type,
) -> Option<ConversionKind> {
    if dst_type == instr_type {
        return None;
    }
    match (dst_type, instr_type) {
        (ast::Type::Scalar(dst_scalar), ast::Type::Scalar(instr_scalar)) => {
            // Evaluated lazily so that each arm queries only what it needs.
            let fits = || instr_scalar.size_of() <= dst_scalar.size_of();
            match instr_scalar.kind() {
                ScalarKind::Bit if fits() => Some(ConversionKind::Default),
                ScalarKind::Bit => None,
                ScalarKind::Signed if dst_scalar.kind() == ScalarKind::Float => None,
                ScalarKind::Signed => {
                    let instr_size = instr_scalar.size_of();
                    let dst_size = dst_scalar.size_of();
                    if instr_size == dst_size {
                        Some(ConversionKind::Default)
                    } else if instr_size < dst_size {
                        Some(ConversionKind::SignExtend)
                    } else {
                        None
                    }
                }
                ScalarKind::Unsigned if fits() && dst_scalar.kind() != ScalarKind::Float => {
                    Some(ConversionKind::Default)
                }
                ScalarKind::Unsigned => None,
                ScalarKind::Float if fits() && dst_scalar.kind() == ScalarKind::Bit => {
                    Some(ConversionKind::Default)
                }
                ScalarKind::Float => None,
                ScalarKind::Float2 => todo!(),
                ScalarKind::Pred => None,
            }
        }
        (ast::Type::Vector(dst_elem, _), ast::Type::Vector(instr_elem, _)) => {
            should_convert_relaxed_dst(
                &ast::Type::Scalar(*dst_elem),
                &ast::Type::Scalar(*instr_elem),
            )
        }
        (ast::Type::Array(dst_elem, _), ast::Type::Array(instr_elem, _)) => {
            should_convert_relaxed_dst(
                &ast::Type::Scalar(*dst_elem),
                &ast::Type::Scalar(*instr_elem),
            )
        }
        _ => None,
    }
}
impl<'a> ast::MethodDecl<'a, &'a str> {
    /// Returns the declared name of the kernel or function.
    fn name(&self) -> &'a str {
        let declared: &'a str = match self {
            ast::MethodDecl::Kernel { name, .. } => name,
            ast::MethodDecl::Func(_, name, _) => name,
        };
        declared
    }
}
// Method declaration in the form built by `SpirvMethodDecl::new` from an
// `ast::MethodDecl`.
struct SpirvMethodDecl<'input> {
// Input variables. For functions, `new` places the `.param` outputs first,
// followed by the declared inputs.
input: Vec<ast::Variable<ast::Type, spirv::Word>>,
// Output variables. For functions these are the non-`.param` outputs;
// for kernels `new` leaves this empty.
output: Vec<ast::Variable<ast::Type, spirv::Word>>,
// Name built by `MethodName::new` from the source declaration.
name: MethodName<'input>,
// Initialised to `false` by `new`.
uses_shared_mem: bool,
}
impl<'input> SpirvMethodDecl<'input> {
    /// Builds the declaration from an AST method declaration.
    ///
    /// * Kernel: every input argument becomes an input variable whose type is
    ///   obtained through `FnArgumentType::to_kernel_type`; there are no
    ///   outputs.
    /// * Function: output arguments for which `is_param()` holds are moved to
    ///   the front of the input list, the remaining outputs stay outputs, and
    ///   all types are obtained through `to_func_type`.
    ///
    /// `uses_shared_mem` starts out as `false`.
    fn new(ast_decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
        let (input, output) = match ast_decl {
            ast::MethodDecl::Kernel { in_args, .. } => {
                let mut kernel_inputs = Vec::new();
                for arg in in_args.iter() {
                    let fn_arg_type = match &arg.v_type {
                        ast::KernelArgumentType::Normal(t) => {
                            ast::FnArgumentType::Param(t.clone())
                        }
                        ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
                    };
                    kernel_inputs.push(ast::Variable {
                        name: arg.name,
                        align: arg.align,
                        v_type: fn_arg_type.to_kernel_type(),
                        array_init: arg.array_init.clone(),
                    });
                }
                (kernel_inputs, Vec::new())
            }
            ast::MethodDecl::Func(out_args, _, in_args) => {
                // `.param` outputs are passed as inputs; the rest stay outputs.
                let (param_outputs, plain_outputs): (Vec<_>, Vec<_>) =
                    out_args.iter().partition(|arg| arg.v_type.is_param());
                let func_outputs = plain_outputs
                    .into_iter()
                    .cloned()
                    .map(|arg| ast::Variable {
                        name: arg.name,
                        align: arg.align,
                        v_type: arg.v_type.to_func_type(),
                        array_init: arg.array_init.clone(),
                    })
                    .collect();
                let func_inputs = param_outputs
                    .into_iter()
                    .cloned()
                    .chain(in_args.iter().cloned())
                    .map(|arg| ast::Variable {
                        name: arg.name,
                        align: arg.align,
                        v_type: arg.v_type.to_func_type(),
                        array_init: arg.array_init.clone(),
                    })
                    .collect();
                (func_inputs, func_outputs)
            }
        };
        let name = MethodName::new(ast_decl);
        SpirvMethodDecl {
            input,
            output,
            name,
            uses_shared_mem: false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast;

    // Scalar types in the order used by the rows (instruction type) and
    // columns (operand type) of the conversion tables below.
    static SCALAR_TYPES: [ast::ScalarType; 15] = [
        ast::ScalarType::B8,
        ast::ScalarType::B16,
        ast::ScalarType::B32,
        ast::ScalarType::B64,
        ast::ScalarType::S8,
        ast::ScalarType::S16,
        ast::ScalarType::S32,
        ast::ScalarType::S64,
        ast::ScalarType::U8,
        ast::ScalarType::U16,
        ast::ScalarType::U32,
        ast::ScalarType::U64,
        ast::ScalarType::F16,
        ast::ScalarType::F32,
        ast::ScalarType::F64,
    ];

    // Each row starts with the instruction type label, followed by one entry
    // per operand type in `SCALAR_TYPES` order.
    static RELAXED_SRC_CONVERSION_TABLE: &'static str =
"b8 - chop chop chop - chop chop chop - chop chop chop chop chop chop
b16 inv - chop chop inv - chop chop inv - chop chop - chop chop
b32 inv inv - chop inv inv - chop inv inv - chop inv - chop
b64 inv inv inv - inv inv inv - inv inv inv - inv inv -
s8 - chop chop chop - chop chop chop - chop chop chop inv inv inv
s16 inv - chop chop inv - chop chop inv - chop chop inv inv inv
s32 inv inv - chop inv inv - chop inv inv - chop inv inv inv
s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
u8 - chop chop chop - chop chop chop - chop chop chop inv inv inv
u16 inv - chop chop inv - chop chop inv - chop chop inv inv inv
u32 inv inv - chop inv inv - chop inv inv - chop inv inv inv
u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
f16 inv - chop chop inv inv inv inv inv inv inv inv - inv inv
f32 inv inv - chop inv inv inv inv inv inv inv inv inv - inv
f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -";

    static RELAXED_DST_CONVERSION_TABLE: &'static str =
"b8 - zext zext zext - zext zext zext - zext zext zext zext zext zext
b16 inv - zext zext inv - zext zext inv - zext zext - zext zext
b32 inv inv - zext inv inv - zext inv inv - zext inv - zext
b64 inv inv inv - inv inv inv - inv inv inv - inv inv -
s8 - sext sext sext - sext sext sext - sext sext sext inv inv inv
s16 inv - sext sext inv - sext sext inv - sext sext inv inv inv
s32 inv inv - sext inv inv - sext inv inv - sext inv inv inv
s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
u8 - zext zext zext - zext zext zext - zext zext zext inv inv inv
u16 inv - zext zext inv - zext zext inv - zext zext inv inv inv
u32 inv inv - zext inv inv - zext inv inv - zext inv inv inv
u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
f16 inv - zext zext inv inv inv inv inv inv inv inv - inv inv
f32 inv inv - zext inv inv inv inv inv inv inv inv inv - inv
f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -";

    /// Translates one table cell into the conversion the functions under test
    /// are expected to return for an off-diagonal type pair.
    fn table_entry_to_conversion(entry: &'static str) -> Option<ConversionKind> {
        match entry {
            "inv" => None,
            "sext" => Some(ConversionKind::SignExtend),
            "-" | "zext" | "chop" => Some(ConversionKind::Default),
            _ => unreachable!(),
        }
    }

    /// Parses a table into rows of expected conversions, dropping the label
    /// in the first column of every line.
    fn parse_conversion_table(table: &'static str) -> Vec<Vec<Option<ConversionKind>>> {
        let mut rows = Vec::new();
        for line in table.lines() {
            let mut cells = line.split_ascii_whitespace();
            // First cell is the row label.
            cells.next();
            rows.push(cells.map(table_entry_to_conversion).collect());
        }
        rows
    }

    /// Checks `f` against `table` for every (operand, instruction) pair of
    /// `SCALAR_TYPES`; identical types must always yield `None`.
    fn assert_conversion_table<F: Fn(&ast::Type, &ast::Type) -> Option<ConversionKind>>(
        table: &'static str,
        f: F,
    ) {
        let expected_rows = parse_conversion_table(table);
        let no_conversion: Option<ConversionKind> = None;
        for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() {
            for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() {
                let actual = f(
                    &ast::Type::Scalar(*op_type),
                    &ast::Type::Scalar(*instr_type),
                );
                let expected = if instr_idx == op_idx {
                    &no_conversion
                } else {
                    &expected_rows[instr_idx][op_idx]
                };
                assert!(actual == *expected);
            }
        }
    }

    #[test]
    fn should_convert_relaxed_src_all_combinations() {
        assert_conversion_table(RELAXED_SRC_CONVERSION_TABLE, should_convert_relaxed_src);
    }

    #[test]
    fn should_convert_relaxed_dst_all_combinations() {
        assert_conversion_table(RELAXED_DST_CONVERSION_TABLE, should_convert_relaxed_dst);
    }
}