Remove the need for custom Arg types in middle-end

This commit is contained in:
Andrzej Janik
2020-07-28 02:44:24 +02:00
parent d514a5610a
commit 52faaab547
3 changed files with 244 additions and 229 deletions

View File

@ -1,5 +1,5 @@
use std::convert::From;
use std::num::ParseIntError;
use std::{marker::PhantomData, num::ParseIntError};
quick_error! {
#[derive(Debug)]
@ -52,7 +52,7 @@ pub struct Function<'a> {
pub kernel: bool,
pub name: &'a str,
pub args: Vec<Argument<'a>>,
pub body: Vec<Statement<&'a str>>,
pub body: Vec<Statement<ParsedArgParams<'a>>>,
}
#[derive(Default)]
@ -141,16 +141,16 @@ impl Default for ScalarType {
}
}
pub enum Statement<ID> {
Label(ID),
Variable(Variable<ID>),
Instruction(Option<PredAt<ID>>, Instruction<ID>),
pub enum Statement<P: ArgParams> {
Label(P::ID),
Variable(Variable<P>),
Instruction(Option<PredAt<P::ID>>, Instruction<P>),
}
pub struct Variable<ID> {
pub struct Variable<P: ArgParams> {
pub space: StateSpace,
pub v_type: Type,
pub name: ID,
pub name: P::ID,
pub count: Option<u32>,
}
@ -169,59 +169,75 @@ pub struct PredAt<ID> {
pub label: ID,
}
pub enum Instruction<ID> {
Ld(LdData, Arg2<ID>),
Mov(MovData, Arg2Mov<ID>),
Mul(MulDetails, Arg3<ID>),
Add(AddDetails, Arg3<ID>),
Setp(SetpData, Arg4<ID>),
SetpBool(SetpBoolData, Arg5<ID>),
Not(NotData, Arg2<ID>),
Bra(BraData, Arg1<ID>),
Cvt(CvtData, Arg2<ID>),
Shl(ShlData, Arg3<ID>),
St(StData, Arg2St<ID>),
pub enum Instruction<P: ArgParams> {
Ld(LdData, Arg2<P>),
Mov(MovData, Arg2Mov<P>),
Mul(MulDetails, Arg3<P>),
Add(AddDetails, Arg3<P>),
Setp(SetpData, Arg4<P>),
SetpBool(SetpBoolData, Arg5<P>),
Not(NotData, Arg2<P>),
Bra(BraData, Arg1<P>),
Cvt(CvtData, Arg2<P>),
Shl(ShlData, Arg3<P>),
St(StData, Arg2St<P>),
Ret(RetData),
}
pub struct Arg1<ID> {
pub src: ID, // it is a jump destination, but in terms of operands it is a source operand
pub trait ArgParams {
type ID;
type Operand;
type MovOperand;
}
pub struct Arg2<ID> {
pub dst: ID,
pub src: Operand<ID>,
pub struct ParsedArgParams<'a> {
_marker: PhantomData<&'a ()>,
}
pub struct Arg2St<ID> {
pub src1: Operand<ID>,
pub src2: Operand<ID>,
impl<'a> ArgParams for ParsedArgParams<'a> {
type ID = &'a str;
type Operand = Operand<&'a str>;
type MovOperand = MovOperand<&'a str>;
}
pub struct Arg2Mov<ID> {
pub dst: ID,
pub src: MovOperand<ID>,
pub struct Arg1<P: ArgParams> {
pub src: P::ID, // it is a jump destination, but in terms of operands it is a source operand
}
pub struct Arg3<ID> {
pub dst: ID,
pub src1: Operand<ID>,
pub src2: Operand<ID>,
pub struct Arg2<P: ArgParams> {
pub dst: P::ID,
pub src: P::Operand,
}
pub struct Arg4<ID> {
pub dst1: ID,
pub dst2: Option<ID>,
pub src1: Operand<ID>,
pub src2: Operand<ID>,
pub struct Arg2St<P: ArgParams> {
pub src1: P::Operand,
pub src2: P::Operand,
}
pub struct Arg5<ID> {
pub dst1: ID,
pub dst2: Option<ID>,
pub src1: Operand<ID>,
pub src2: Operand<ID>,
pub src3: Operand<ID>,
pub struct Arg2Mov<P: ArgParams> {
pub dst: P::ID,
pub src: P::MovOperand,
}
pub struct Arg3<P: ArgParams> {
pub dst: P::ID,
pub src1: P::Operand,
pub src2: P::Operand,
}
pub struct Arg4<P: ArgParams> {
pub dst1: P::ID,
pub dst2: Option<P::ID>,
pub src1: P::Operand,
pub src2: P::Operand,
}
pub struct Arg5<P: ArgParams> {
pub dst1: P::ID,
pub dst2: Option<P::ID>,
pub src1: P::Operand,
pub src2: P::Operand,
pub src3: P::Operand,
}
pub enum Operand<ID> {

View File

@ -223,7 +223,7 @@ FunctionInput: ast::Argument<'input> = {
}
};
pub(crate) FunctionBody: Vec<ast::Statement<&'input str>> = {
pub(crate) FunctionBody: Vec<ast::Statement<ast::ParsedArgParams<'input>>> = {
"{" <s:Statement*> "}" => { without_none(s) }
};
@ -269,7 +269,7 @@ MemoryType: ast::ScalarType = {
".f64" => ast::ScalarType::F64,
};
Statement: Option<ast::Statement<&'input str>> = {
Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
<l:Label> => Some(ast::Statement::Label(l)),
DebugDirective => None,
<v:Variable> ";" => Some(ast::Statement::Variable(v)),
@ -289,7 +289,7 @@ Label: &'input str = {
<id:ExtendedID> ":" => id
};
Variable: ast::Variable<&'input str> = {
Variable: ast::Variable<ast::ParsedArgParams<'input>> = {
<s:StateSpaceSpecifier> <t:Type> <v:VariableName> => {
let (name, count) = v;
ast::Variable { space: s, v_type: t, name: name, count: count }
@ -310,7 +310,7 @@ VariableName: (&'input str, Option<u32>) = {
}
};
Instruction: ast::Instruction<&'input str> = {
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstLd,
InstMov,
InstMul,
@ -325,7 +325,7 @@ Instruction: ast::Instruction<&'input str> = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
InstLd: ast::Instruction<&'input str> = {
InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
"ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ExtendedID> "," "[" <src:Operand> "]" => {
ast::Instruction::Ld(
ast::LdData {
@ -370,7 +370,7 @@ LdCacheOperator: ast::LdCacheOperator = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
InstMov: ast::Instruction<&'input str> = {
InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mov" <t:MovType> <a:Arg2Mov> => {
ast::Instruction::Mov(ast::MovData{ typ:t }, a)
}
@ -394,7 +394,7 @@ MovType: ast::Type = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
InstMul: ast::Instruction<&'input str> = {
InstMul: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mul" <d:InstMulMode> <a:Arg3> => ast::Instruction::Mul(d, a)
};
@ -455,7 +455,7 @@ IntType : ast::IntType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add
InstAdd: ast::Instruction<&'input str> = {
InstAdd: ast::Instruction<ast::ParsedArgParams<'input>> = {
"add" <d:InstAddMode> <a:Arg3> => ast::Instruction::Add(d, a)
};
@ -492,7 +492,7 @@ InstAddMode: ast::AddDetails = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp
// TODO: support f16 setp
InstSetp: ast::Instruction<&'input str> = {
InstSetp: ast::Instruction<ast::ParsedArgParams<'input>> = {
"setp" <d:SetpMode> <a:Arg4> => ast::Instruction::Setp(d, a),
"setp" <d:SetpBoolMode> <a:Arg5> => ast::Instruction::SetpBool(d, a),
};
@ -556,7 +556,7 @@ SetpType: ast::ScalarType = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not
InstNot: ast::Instruction<&'input str> = {
InstNot: ast::Instruction<ast::ParsedArgParams<'input>> = {
"not" NotType <a:Arg2> => ast::Instruction::Not(ast::NotData{}, a)
};
@ -571,12 +571,12 @@ PredAt: ast::PredAt<&'input str> = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra
InstBra: ast::Instruction<&'input str> = {
InstBra: ast::Instruction<ast::ParsedArgParams<'input>> = {
"bra" <u:".uni"?> <a:Arg1> => ast::Instruction::Bra(ast::BraData{ uniform: u.is_some() }, a)
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
InstCvt: ast::Instruction<&'input str> = {
InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"cvt" CvtRnd? ".ftz"? ".sat"? CvtType CvtType <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtData{}, a)
}
@ -602,7 +602,7 @@ CvtType = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
InstShl: ast::Instruction<&'input str> = {
InstShl: ast::Instruction<ast::ParsedArgParams<'input>> = {
"shl" ShlType <a:Arg3> => ast::Instruction::Shl(ast::ShlData{}, a)
};
@ -612,7 +612,7 @@ ShlType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
InstSt: ast::Instruction<&'input str> = {
InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <src1:Operand> "]" "," <src2:Operand> => {
ast::Instruction::St(
ast::StData {
@ -642,7 +642,7 @@ StCacheOperator: ast::StCacheOperator = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
InstRet: ast::Instruction<&'input str> = {
InstRet: ast::Instruction<ast::ParsedArgParams<'input>> = {
"ret" <u:".uni"?> => ast::Instruction::Ret(ast::RetData { uniform: u.is_some() })
};
@ -675,28 +675,28 @@ VectorOperand: (&'input str, &'input str) = {
<pref:ExtendedID> <suf:DotID> => (pref, &suf[1..]),
};
Arg1: ast::Arg1<&'input str> = {
Arg1: ast::Arg1<ast::ParsedArgParams<'input>> = {
<src:ExtendedID> => ast::Arg1{<>}
};
Arg2: ast::Arg2<&'input str> = {
Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>}
};
Arg2Mov: ast::Arg2Mov<&'input str> = {
Arg2Mov: ast::Arg2Mov<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src:MovOperand> => ast::Arg2Mov{<>}
};
Arg3: ast::Arg3<&'input str> = {
Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>}
};
Arg4: ast::Arg4<&'input str> = {
Arg4: ast::Arg4<ast::ParsedArgParams<'input>> = {
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> => ast::Arg4{<>}
};
// TODO: pass src3 negation somewhere
Arg5: ast::Arg5<&'input str> = {
Arg5: ast::Arg5<ast::ParsedArgParams<'input>> = {
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5{<>}
};

View File

@ -168,9 +168,9 @@ fn apply_id_offset(func_body: &mut Vec<ExpandedStatement>, id_offset: u32) {
}
}
fn to_ssa<'a>(
f_args: &[ast::Argument],
f_body: Vec<ast::Statement<&'a str>>,
fn to_ssa<'a, 'b>(
f_args: &'b [ast::Argument<'a>],
f_body: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
) -> (Vec<ExpandedStatement>, spirv::Word) {
let (normalized_ids, mut id_def) = normalize_identifiers(&f_args, f_body);
let normalized_statements = normalize_predicates(normalized_ids, &mut id_def);
@ -214,7 +214,7 @@ fn normalize_labels(
}
fn normalize_predicates(
func: Vec<ast::Statement<spirv::Word>>,
func: Vec<ast::Statement<NormalizedArgParams>>,
id_def: &mut NumericIdResolver,
) -> Vec<NormalizedStatement> {
let mut result = Vec::with_capacity(func.len());
@ -343,51 +343,51 @@ fn expand_arguments(
fn normalize_insert_instruction(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
instr: ast::Instruction<spirv::Word>,
) -> Instruction {
instr: ast::Instruction<NormalizedArgParams>,
) -> ast::Instruction<ExpandedArgParams> {
match instr {
ast::Instruction::Ld(d, a) => {
let arg = normalize_expand_arg2(func, id_def, &|| Some(d.typ), a);
Instruction::Ld(d, arg)
ast::Instruction::Ld(d, arg)
}
ast::Instruction::Mov(d, a) => {
let arg = normalize_expand_arg2mov(func, id_def, &|| d.typ.try_as_scalar(), a);
Instruction::Mov(d, arg)
ast::Instruction::Mov(d, arg)
}
ast::Instruction::Mul(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a);
Instruction::Mul(d, arg)
ast::Instruction::Mul(d, arg)
}
ast::Instruction::Add(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a);
Instruction::Add(d, arg)
ast::Instruction::Add(d, arg)
}
ast::Instruction::Setp(d, a) => {
let arg = normalize_expand_arg4(func, id_def, &|| Some(d.typ), a);
Instruction::Setp(d, arg)
ast::Instruction::Setp(d, arg)
}
ast::Instruction::SetpBool(d, a) => {
let arg = normalize_expand_arg5(func, id_def, &|| Some(d.typ), a);
Instruction::SetpBool(d, arg)
ast::Instruction::SetpBool(d, arg)
}
ast::Instruction::Not(d, a) => {
let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a);
Instruction::Not(d, arg)
ast::Instruction::Not(d, arg)
}
ast::Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }),
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, ast::Arg1 { src: a.src }),
ast::Instruction::Cvt(d, a) => {
let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a);
Instruction::Cvt(d, arg)
ast::Instruction::Cvt(d, arg)
}
ast::Instruction::Shl(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| todo!(), a);
Instruction::Shl(d, arg)
ast::Instruction::Shl(d, arg)
}
ast::Instruction::St(d, a) => {
let arg = normalize_expand_arg2st(func, id_def, &|| todo!(), a);
Instruction::St(d, arg)
ast::Instruction::St(d, arg)
}
ast::Instruction::Ret(d) => Instruction::Ret(d),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
}
}
@ -395,9 +395,9 @@ fn normalize_expand_arg2(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2<spirv::Word>,
) -> Arg2 {
Arg2 {
a: ast::Arg2<NormalizedArgParams>,
) -> ast::Arg2<ExpandedArgParams> {
ast::Arg2 {
dst: a.dst,
src: normalize_expand_operand(func, id_def, inst_type, a.src),
}
@ -407,9 +407,9 @@ fn normalize_expand_arg2mov(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2Mov<spirv::Word>,
) -> Arg2 {
Arg2 {
a: ast::Arg2Mov<NormalizedArgParams>,
) -> ast::Arg2Mov<ExpandedArgParams> {
ast::Arg2Mov {
dst: a.dst,
src: normalize_expand_mov_operand(func, id_def, inst_type, a.src),
}
@ -419,9 +419,9 @@ fn normalize_expand_arg2st(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2St<spirv::Word>,
) -> Arg2St {
Arg2St {
a: ast::Arg2St<NormalizedArgParams>,
) -> ast::Arg2St<ExpandedArgParams> {
ast::Arg2St {
src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
src2: normalize_expand_operand(func, id_def, inst_type, a.src2),
}
@ -431,9 +431,9 @@ fn normalize_expand_arg3(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg3<spirv::Word>,
) -> Arg3 {
Arg3 {
a: ast::Arg3<NormalizedArgParams>,
) -> ast::Arg3<ExpandedArgParams> {
ast::Arg3 {
dst: a.dst,
src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
src2: normalize_expand_operand(func, id_def, inst_type, a.src2),
@ -444,9 +444,9 @@ fn normalize_expand_arg4(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg4<spirv::Word>,
) -> Arg4 {
Arg4 {
a: ast::Arg4<NormalizedArgParams>,
) -> ast::Arg4<ExpandedArgParams> {
ast::Arg4 {
dst1: a.dst1,
dst2: a.dst2,
src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
@ -458,9 +458,9 @@ fn normalize_expand_arg5(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg5<spirv::Word>,
) -> Arg5 {
Arg5 {
a: ast::Arg5<NormalizedArgParams>,
) -> ast::Arg5<ExpandedArgParams> {
ast::Arg5 {
dst1: a.dst1,
dst2: a.dst2,
src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
@ -527,7 +527,7 @@ fn insert_implicit_conversions(
for s in func.into_iter() {
match s {
Statement::Instruction(inst) => match inst {
Instruction::Ld(ld, mut arg) => {
ast::Instruction::Ld(ld, mut arg) => {
arg.src = insert_implicit_conversions_ld_src(
&mut result,
ast::Type::Scalar(ld.typ),
@ -542,10 +542,10 @@ fn insert_implicit_conversions(
should_convert_relaxed_dst,
arg,
|arg| &mut arg.dst,
|arg| Instruction::Ld(ld, arg),
|arg| ast::Instruction::Ld(ld, arg),
);
}
Instruction::St(st, mut arg) => {
ast::Instruction::St(st, mut arg) => {
let arg_src2_type = id_def.get_type(arg.src2);
if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) {
arg.src2 = insert_conversion_src(
@ -564,7 +564,7 @@ fn insert_implicit_conversions(
st.state_space.to_ld_ss(),
arg.src1,
);
result.push(Statement::Instruction(Instruction::St(st, arg)));
result.push(Statement::Instruction(ast::Instruction::St(st, arg)));
}
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst),
},
@ -668,10 +668,10 @@ fn emit_function_body_ops(
}
Statement::Instruction(inst) => match inst {
// SPIR-V does not support marking jumps as guaranteed-converged
Instruction::Bra(_, arg) => {
ast::Instruction::Bra(_, arg) => {
builder.branch(arg.src)?;
}
Instruction::Ld(data, arg) => {
ast::Instruction::Ld(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
todo!()
}
@ -686,7 +686,7 @@ fn emit_function_body_ops(
_ => todo!(),
}
}
Instruction::St(data, arg) => {
ast::Instruction::St(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak
|| data.vector.is_some()
|| data.state_space != ast::StStateSpace::Generic
@ -696,18 +696,18 @@ fn emit_function_body_ops(
builder.store(arg.src1, arg.src2, None, &[])?;
}
// SPIR-V does not support ret as guaranteed-converged
Instruction::Ret(_) => builder.ret()?,
Instruction::Mov(mov, arg) => {
ast::Instruction::Ret(_) => builder.ret()?,
ast::Instruction::Mov(mov, arg) => {
let result_type = map.get_or_add(builder, SpirvType::from(mov.typ));
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
Instruction::Mul(mul, arg) => match mul {
ast::Instruction::Mul(mul, arg) => match mul {
ast::MulDetails::Int(ref ctr) => {
emit_mul_int(builder, map, opencl, ctr, arg)?;
}
ast::MulDetails::Float(_) => todo!(),
},
Instruction::Add(add, arg) => match add {
ast::Instruction::Add(add, arg) => match add {
ast::AddDetails::Int(ref desc) => {
emit_add_int(builder, map, desc, arg)?;
}
@ -732,7 +732,7 @@ fn emit_mul_int(
map: &mut TypeWordMap,
opencl: spirv::Word,
desc: &ast::MulIntDesc,
arg: &Arg3,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::Base(desc.typ.into()));
match desc.control {
@ -762,7 +762,7 @@ fn emit_add_int(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
ctr: &ast::AddIntDesc,
arg: &Arg3,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::Base(ctr.typ.into()));
builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
@ -837,10 +837,10 @@ fn emit_implicit_conversion(
}
// TODO: support scopes
fn normalize_identifiers<'a>(
args: &'a [ast::Argument<'a>],
func: Vec<ast::Statement<&'a str>>,
) -> (Vec<ast::Statement<spirv::Word>>, NumericIdResolver) {
fn normalize_identifiers<'a, 'b>(
args: &'b [ast::Argument<'a>],
func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
) -> (Vec<ast::Statement<NormalizedArgParams>>, NumericIdResolver) {
let mut id_defs = StringIdResolver::new();
for arg in args {
id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type)));
@ -854,8 +854,8 @@ fn normalize_identifiers<'a>(
fn expand_map_ids<'a>(
id_defs: &mut StringIdResolver<'a>,
result: &mut Vec<ast::Statement<spirv::Word>>,
s: ast::Statement<&'a str>,
result: &mut Vec<ast::Statement<NormalizedArgParams>>,
s: ast::Statement<ast::ParsedArgParams<'a>>,
) {
match s {
ast::Statement::Label(name) => {
@ -979,7 +979,7 @@ enum Statement<I> {
Constant(ConstantDefinition),
}
impl Statement<Instruction> {
impl Statement<ast::Instruction<ExpandedArgParams>> {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
match self {
Statement::Variable(id, _, _) => f(id),
@ -994,25 +994,25 @@ impl Statement<Instruction> {
}
}
type NormalizedStatement = Statement<ast::Instruction<spirv::Word>>;
type ExpandedStatement = Statement<Instruction>;
enum NormalizedArgParams {}
type NormalizedStatement = Statement<ast::Instruction<NormalizedArgParams>>;
enum Instruction {
Ld(ast::LdData, Arg2),
Mov(ast::MovData, Arg2),
Mul(ast::MulDetails, Arg3),
Add(ast::AddDetails, Arg3),
Setp(ast::SetpData, Arg4),
SetpBool(ast::SetpBoolData, Arg5),
Not(ast::NotData, Arg2),
Bra(ast::BraData, Arg1),
Cvt(ast::CvtData, Arg2),
Shl(ast::ShlData, Arg3),
St(ast::StData, Arg2St),
Ret(ast::RetData),
impl ast::ArgParams for NormalizedArgParams {
type ID = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
type MovOperand = ast::MovOperand<spirv::Word>;
}
impl ast::Instruction<spirv::Word> {
enum ExpandedArgParams {}
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>>;
impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word;
type Operand = spirv::Word;
type MovOperand = spirv::Word;
}
impl ast::Instruction<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(&mut self, f: &mut F) {
match self {
ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
@ -1031,22 +1031,22 @@ impl ast::Instruction<spirv::Word> {
}
}
impl Instruction {
impl ast::Instruction<ExpandedArgParams> {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
let f_visitor = &mut Self::typed_visitor(f);
match self {
Instruction::Ld(_, a) => a.visit_id(f_visitor, None),
Instruction::Mov(_, a) => a.visit_id(f_visitor, None),
Instruction::Mul(_, a) => a.visit_id(f_visitor, None),
Instruction::Add(_, a) => a.visit_id(f_visitor, None),
Instruction::Setp(_, a) => a.visit_id(f_visitor, None),
Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None),
Instruction::Not(_, a) => a.visit_id(f_visitor, None),
Instruction::Cvt(_, a) => a.visit_id(f_visitor, None),
Instruction::Shl(_, a) => a.visit_id(f_visitor, None),
Instruction::St(_, a) => a.visit_id(f_visitor, None),
Instruction::Bra(_, a) => a.visit_id(f_visitor, None),
Instruction::Ret(_) => (),
ast::Instruction::Ld(_, a) => a.visit_id(f_visitor, None),
ast::Instruction::Mov(_, a) => a.visit_id(f_visitor, None),
ast::Instruction::Mul(_, a) => a.visit_id(f_visitor, None),
ast::Instruction::Add(_, a) => a.visit_id(f_visitor, None),
ast::Instruction::Setp(_, a) => a.visit_id(f_visitor, None),
ast::Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None),
ast::Instruction::Not(_, a) => a.visit_id(f_visitor, None),
ast::Instruction::Cvt(_, a) => a.visit_id(f_visitor, None),
ast::Instruction::Shl(_, a) => a.visit_id(f_visitor, None),
ast::Instruction::St(_, a) => a.visit_id(f_visitor, None),
ast::Instruction::Bra(_, a) => a.visit_id(f_visitor, None),
ast::Instruction::Ret(_) => (),
}
}
@ -1061,42 +1061,40 @@ impl Instruction {
f: &mut F,
) {
match self {
Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)),
Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Not(_, _) => todo!(),
Instruction::Cvt(_, _) => todo!(),
Instruction::Shl(_, _) => todo!(),
Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Bra(_, a) => a.visit_id(f, None),
Instruction::Ret(_) => (),
ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
ast::Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)),
ast::Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())),
ast::Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())),
ast::Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
ast::Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
ast::Instruction::Not(_, _) => todo!(),
ast::Instruction::Cvt(_, _) => todo!(),
ast::Instruction::Shl(_, _) => todo!(),
ast::Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
ast::Instruction::Bra(_, a) => a.visit_id(f, None),
ast::Instruction::Ret(_) => (),
}
}
fn jump_target(&self) -> Option<spirv::Word> {
match self {
Instruction::Bra(_, a) => Some(a.src),
Instruction::Ld(_, _)
| Instruction::Mov(_, _)
| Instruction::Mul(_, _)
| Instruction::Add(_, _)
| Instruction::Setp(_, _)
| Instruction::SetpBool(_, _)
| Instruction::Not(_, _)
| Instruction::Cvt(_, _)
| Instruction::Shl(_, _)
| Instruction::St(_, _)
| Instruction::Ret(_) => None,
ast::Instruction::Bra(_, a) => Some(a.src),
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)
| ast::Instruction::Mul(_, _)
| ast::Instruction::Add(_, _)
| ast::Instruction::Setp(_, _)
| ast::Instruction::SetpBool(_, _)
| ast::Instruction::Not(_, _)
| ast::Instruction::Cvt(_, _)
| ast::Instruction::Shl(_, _)
| ast::Instruction::St(_, _)
| ast::Instruction::Ret(_) => None,
}
}
}
struct Arg1 {
pub src: spirv::Word,
}
type Arg1 = ast::Arg1<ExpandedArgParams>;
impl Arg1 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@ -1108,10 +1106,7 @@ impl Arg1 {
}
}
struct Arg2 {
pub dst: spirv::Word,
pub src: spirv::Word,
}
type Arg2 = ast::Arg2<ExpandedArgParams>;
impl Arg2 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@ -1124,11 +1119,21 @@ impl Arg2 {
}
}
pub struct Arg2St {
pub src1: spirv::Word,
pub src2: spirv::Word,
type Arg2Mov = ast::Arg2Mov<ExpandedArgParams>;
impl Arg2Mov {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(true, &mut self.dst, t);
f(false, &mut self.src, t);
}
}
type Arg2St = ast::Arg2St<ExpandedArgParams>;
impl Arg2St {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
@ -1140,11 +1145,7 @@ impl Arg2St {
}
}
struct Arg3 {
pub dst: spirv::Word,
pub src1: spirv::Word,
pub src2: spirv::Word,
}
type Arg3 = ast::Arg3<ExpandedArgParams>;
impl Arg3 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@ -1158,12 +1159,7 @@ impl Arg3 {
}
}
struct Arg4 {
pub dst1: spirv::Word,
pub dst2: Option<spirv::Word>,
pub src1: spirv::Word,
pub src2: spirv::Word,
}
type Arg4 = ast::Arg4<ExpandedArgParams>;
impl Arg4 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@ -1188,13 +1184,7 @@ impl Arg4 {
}
}
struct Arg5 {
pub dst1: spirv::Word,
pub dst2: Option<spirv::Word>,
pub src1: spirv::Word,
pub src2: spirv::Word,
pub src3: spirv::Word,
}
type Arg5 = ast::Arg5<ExpandedArgParams>;
impl Arg5 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@ -1286,8 +1276,11 @@ impl<T> ast::PredAt<T> {
}
}
impl<T> ast::Instruction<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> {
impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(
self,
f: &mut F,
) -> ast::Instruction<NormalizedArgParams> {
match self {
ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)),
ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)),
@ -1305,13 +1298,13 @@ impl<T> ast::Instruction<T> {
}
}
impl<T> ast::Arg1<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg1<U> {
impl<'a> ast::Arg1<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg1<NormalizedArgParams> {
ast::Arg1 { src: f(self.src) }
}
}
impl ast::Arg1<spirv::Word> {
impl ast::Arg1<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@ -1321,8 +1314,8 @@ impl ast::Arg1<spirv::Word> {
}
}
impl<T> ast::Arg2<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2<U> {
impl<'a> ast::Arg2<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg2<NormalizedArgParams> {
ast::Arg2 {
dst: f(self.dst),
src: self.src.map_id(f),
@ -1330,7 +1323,7 @@ impl<T> ast::Arg2<T> {
}
}
impl ast::Arg2<spirv::Word> {
impl ast::Arg2<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@ -1341,8 +1334,11 @@ impl ast::Arg2<spirv::Word> {
}
}
impl<T> ast::Arg2St<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2St<U> {
impl<'a> ast::Arg2St<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(
self,
f: &mut F,
) -> ast::Arg2St<NormalizedArgParams> {
ast::Arg2St {
src1: self.src1.map_id(f),
src2: self.src2.map_id(f),
@ -1350,7 +1346,7 @@ impl<T> ast::Arg2St<T> {
}
}
impl ast::Arg2St<spirv::Word> {
impl ast::Arg2St<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@ -1361,8 +1357,11 @@ impl ast::Arg2St<spirv::Word> {
}
}
impl<T> ast::Arg2Mov<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2Mov<U> {
impl<'a> ast::Arg2Mov<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(
self,
f: &mut F,
) -> ast::Arg2Mov<NormalizedArgParams> {
ast::Arg2Mov {
dst: f(self.dst),
src: self.src.map_id(f),
@ -1370,7 +1369,7 @@ impl<T> ast::Arg2Mov<T> {
}
}
impl ast::Arg2Mov<spirv::Word> {
impl ast::Arg2Mov<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@ -1381,8 +1380,8 @@ impl ast::Arg2Mov<spirv::Word> {
}
}
impl<T> ast::Arg3<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg3<U> {
impl<'a> ast::Arg3<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg3<NormalizedArgParams> {
ast::Arg3 {
dst: f(self.dst),
src1: self.src1.map_id(f),
@ -1391,7 +1390,7 @@ impl<T> ast::Arg3<T> {
}
}
impl ast::Arg3<spirv::Word> {
impl ast::Arg3<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@ -1403,8 +1402,8 @@ impl ast::Arg3<spirv::Word> {
}
}
impl<T> ast::Arg4<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg4<U> {
impl<'a> ast::Arg4<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg4<NormalizedArgParams> {
ast::Arg4 {
dst1: f(self.dst1),
dst2: self.dst2.map(|i| f(i)),
@ -1414,7 +1413,7 @@ impl<T> ast::Arg4<T> {
}
}
impl ast::Arg4<spirv::Word> {
impl ast::Arg4<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@ -1437,8 +1436,8 @@ impl ast::Arg4<spirv::Word> {
}
}
impl<T> ast::Arg5<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg5<U> {
impl<'a> ast::Arg5<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg5<NormalizedArgParams> {
ast::Arg5 {
dst1: f(self.dst1),
dst2: self.dst2.map(|i| f(i)),
@ -1449,7 +1448,7 @@ impl<T> ast::Arg5<T> {
}
}
impl ast::Arg5<spirv::Word> {
impl ast::Arg5<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@ -1779,7 +1778,7 @@ fn insert_with_implicit_conversion_dst<
T,
ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word,
ToInstruction: FnOnce(T) -> Instruction,
ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>,
>(
func: &mut Vec<ExpandedStatement>,
instr_type: ast::ScalarType,
@ -1907,7 +1906,7 @@ fn should_convert_relaxed_dst(
fn insert_implicit_bitcasts(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
mut instr: Instruction,
mut instr: ast::Instruction<ExpandedArgParams>,
) {
let mut dst_coercion = None;
instr.visit_id_extended(&mut |is_dst, id, id_type| {