Remove the need for custom Arg types in middle-end

This commit is contained in:
Andrzej Janik
2020-07-28 02:44:24 +02:00
parent d514a5610a
commit 52faaab547
3 changed files with 244 additions and 229 deletions

View File

@ -1,5 +1,5 @@
use std::convert::From; use std::convert::From;
use std::num::ParseIntError; use std::{marker::PhantomData, num::ParseIntError};
quick_error! { quick_error! {
#[derive(Debug)] #[derive(Debug)]
@ -52,7 +52,7 @@ pub struct Function<'a> {
pub kernel: bool, pub kernel: bool,
pub name: &'a str, pub name: &'a str,
pub args: Vec<Argument<'a>>, pub args: Vec<Argument<'a>>,
pub body: Vec<Statement<&'a str>>, pub body: Vec<Statement<ParsedArgParams<'a>>>,
} }
#[derive(Default)] #[derive(Default)]
@ -141,16 +141,16 @@ impl Default for ScalarType {
} }
} }
pub enum Statement<ID> { pub enum Statement<P: ArgParams> {
Label(ID), Label(P::ID),
Variable(Variable<ID>), Variable(Variable<P>),
Instruction(Option<PredAt<ID>>, Instruction<ID>), Instruction(Option<PredAt<P::ID>>, Instruction<P>),
} }
pub struct Variable<ID> { pub struct Variable<P: ArgParams> {
pub space: StateSpace, pub space: StateSpace,
pub v_type: Type, pub v_type: Type,
pub name: ID, pub name: P::ID,
pub count: Option<u32>, pub count: Option<u32>,
} }
@ -169,59 +169,75 @@ pub struct PredAt<ID> {
pub label: ID, pub label: ID,
} }
pub enum Instruction<ID> { pub enum Instruction<P: ArgParams> {
Ld(LdData, Arg2<ID>), Ld(LdData, Arg2<P>),
Mov(MovData, Arg2Mov<ID>), Mov(MovData, Arg2Mov<P>),
Mul(MulDetails, Arg3<ID>), Mul(MulDetails, Arg3<P>),
Add(AddDetails, Arg3<ID>), Add(AddDetails, Arg3<P>),
Setp(SetpData, Arg4<ID>), Setp(SetpData, Arg4<P>),
SetpBool(SetpBoolData, Arg5<ID>), SetpBool(SetpBoolData, Arg5<P>),
Not(NotData, Arg2<ID>), Not(NotData, Arg2<P>),
Bra(BraData, Arg1<ID>), Bra(BraData, Arg1<P>),
Cvt(CvtData, Arg2<ID>), Cvt(CvtData, Arg2<P>),
Shl(ShlData, Arg3<ID>), Shl(ShlData, Arg3<P>),
St(StData, Arg2St<ID>), St(StData, Arg2St<P>),
Ret(RetData), Ret(RetData),
} }
pub struct Arg1<ID> { pub trait ArgParams {
pub src: ID, // it is a jump destination, but in terms of operands it is a source operand type ID;
type Operand;
type MovOperand;
} }
pub struct Arg2<ID> { pub struct ParsedArgParams<'a> {
pub dst: ID, _marker: PhantomData<&'a ()>,
pub src: Operand<ID>,
} }
pub struct Arg2St<ID> { impl<'a> ArgParams for ParsedArgParams<'a> {
pub src1: Operand<ID>, type ID = &'a str;
pub src2: Operand<ID>, type Operand = Operand<&'a str>;
type MovOperand = MovOperand<&'a str>;
} }
pub struct Arg2Mov<ID> { pub struct Arg1<P: ArgParams> {
pub dst: ID, pub src: P::ID, // it is a jump destination, but in terms of operands it is a source operand
pub src: MovOperand<ID>,
} }
pub struct Arg3<ID> { pub struct Arg2<P: ArgParams> {
pub dst: ID, pub dst: P::ID,
pub src1: Operand<ID>, pub src: P::Operand,
pub src2: Operand<ID>,
} }
pub struct Arg4<ID> { pub struct Arg2St<P: ArgParams> {
pub dst1: ID, pub src1: P::Operand,
pub dst2: Option<ID>, pub src2: P::Operand,
pub src1: Operand<ID>,
pub src2: Operand<ID>,
} }
pub struct Arg5<ID> { pub struct Arg2Mov<P: ArgParams> {
pub dst1: ID, pub dst: P::ID,
pub dst2: Option<ID>, pub src: P::MovOperand,
pub src1: Operand<ID>, }
pub src2: Operand<ID>,
pub src3: Operand<ID>, pub struct Arg3<P: ArgParams> {
pub dst: P::ID,
pub src1: P::Operand,
pub src2: P::Operand,
}
pub struct Arg4<P: ArgParams> {
pub dst1: P::ID,
pub dst2: Option<P::ID>,
pub src1: P::Operand,
pub src2: P::Operand,
}
pub struct Arg5<P: ArgParams> {
pub dst1: P::ID,
pub dst2: Option<P::ID>,
pub src1: P::Operand,
pub src2: P::Operand,
pub src3: P::Operand,
} }
pub enum Operand<ID> { pub enum Operand<ID> {

View File

@ -223,7 +223,7 @@ FunctionInput: ast::Argument<'input> = {
} }
}; };
pub(crate) FunctionBody: Vec<ast::Statement<&'input str>> = { pub(crate) FunctionBody: Vec<ast::Statement<ast::ParsedArgParams<'input>>> = {
"{" <s:Statement*> "}" => { without_none(s) } "{" <s:Statement*> "}" => { without_none(s) }
}; };
@ -269,7 +269,7 @@ MemoryType: ast::ScalarType = {
".f64" => ast::ScalarType::F64, ".f64" => ast::ScalarType::F64,
}; };
Statement: Option<ast::Statement<&'input str>> = { Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
<l:Label> => Some(ast::Statement::Label(l)), <l:Label> => Some(ast::Statement::Label(l)),
DebugDirective => None, DebugDirective => None,
<v:Variable> ";" => Some(ast::Statement::Variable(v)), <v:Variable> ";" => Some(ast::Statement::Variable(v)),
@ -289,7 +289,7 @@ Label: &'input str = {
<id:ExtendedID> ":" => id <id:ExtendedID> ":" => id
}; };
Variable: ast::Variable<&'input str> = { Variable: ast::Variable<ast::ParsedArgParams<'input>> = {
<s:StateSpaceSpecifier> <t:Type> <v:VariableName> => { <s:StateSpaceSpecifier> <t:Type> <v:VariableName> => {
let (name, count) = v; let (name, count) = v;
ast::Variable { space: s, v_type: t, name: name, count: count } ast::Variable { space: s, v_type: t, name: name, count: count }
@ -310,7 +310,7 @@ VariableName: (&'input str, Option<u32>) = {
} }
}; };
Instruction: ast::Instruction<&'input str> = { Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstLd, InstLd,
InstMov, InstMov,
InstMul, InstMul,
@ -325,7 +325,7 @@ Instruction: ast::Instruction<&'input str> = {
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
InstLd: ast::Instruction<&'input str> = { InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
"ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ExtendedID> "," "[" <src:Operand> "]" => { "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ExtendedID> "," "[" <src:Operand> "]" => {
ast::Instruction::Ld( ast::Instruction::Ld(
ast::LdData { ast::LdData {
@ -370,7 +370,7 @@ LdCacheOperator: ast::LdCacheOperator = {
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
InstMov: ast::Instruction<&'input str> = { InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mov" <t:MovType> <a:Arg2Mov> => { "mov" <t:MovType> <a:Arg2Mov> => {
ast::Instruction::Mov(ast::MovData{ typ:t }, a) ast::Instruction::Mov(ast::MovData{ typ:t }, a)
} }
@ -394,7 +394,7 @@ MovType: ast::Type = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
InstMul: ast::Instruction<&'input str> = { InstMul: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mul" <d:InstMulMode> <a:Arg3> => ast::Instruction::Mul(d, a) "mul" <d:InstMulMode> <a:Arg3> => ast::Instruction::Mul(d, a)
}; };
@ -455,7 +455,7 @@ IntType : ast::IntType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add
InstAdd: ast::Instruction<&'input str> = { InstAdd: ast::Instruction<ast::ParsedArgParams<'input>> = {
"add" <d:InstAddMode> <a:Arg3> => ast::Instruction::Add(d, a) "add" <d:InstAddMode> <a:Arg3> => ast::Instruction::Add(d, a)
}; };
@ -492,7 +492,7 @@ InstAddMode: ast::AddDetails = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp
// TODO: support f16 setp // TODO: support f16 setp
InstSetp: ast::Instruction<&'input str> = { InstSetp: ast::Instruction<ast::ParsedArgParams<'input>> = {
"setp" <d:SetpMode> <a:Arg4> => ast::Instruction::Setp(d, a), "setp" <d:SetpMode> <a:Arg4> => ast::Instruction::Setp(d, a),
"setp" <d:SetpBoolMode> <a:Arg5> => ast::Instruction::SetpBool(d, a), "setp" <d:SetpBoolMode> <a:Arg5> => ast::Instruction::SetpBool(d, a),
}; };
@ -556,7 +556,7 @@ SetpType: ast::ScalarType = {
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not
InstNot: ast::Instruction<&'input str> = { InstNot: ast::Instruction<ast::ParsedArgParams<'input>> = {
"not" NotType <a:Arg2> => ast::Instruction::Not(ast::NotData{}, a) "not" NotType <a:Arg2> => ast::Instruction::Not(ast::NotData{}, a)
}; };
@ -571,12 +571,12 @@ PredAt: ast::PredAt<&'input str> = {
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra
InstBra: ast::Instruction<&'input str> = { InstBra: ast::Instruction<ast::ParsedArgParams<'input>> = {
"bra" <u:".uni"?> <a:Arg1> => ast::Instruction::Bra(ast::BraData{ uniform: u.is_some() }, a) "bra" <u:".uni"?> <a:Arg1> => ast::Instruction::Bra(ast::BraData{ uniform: u.is_some() }, a)
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
InstCvt: ast::Instruction<&'input str> = { InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"cvt" CvtRnd? ".ftz"? ".sat"? CvtType CvtType <a:Arg2> => { "cvt" CvtRnd? ".ftz"? ".sat"? CvtType CvtType <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtData{}, a) ast::Instruction::Cvt(ast::CvtData{}, a)
} }
@ -602,7 +602,7 @@ CvtType = {
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
InstShl: ast::Instruction<&'input str> = { InstShl: ast::Instruction<ast::ParsedArgParams<'input>> = {
"shl" ShlType <a:Arg3> => ast::Instruction::Shl(ast::ShlData{}, a) "shl" ShlType <a:Arg3> => ast::Instruction::Shl(ast::ShlData{}, a)
}; };
@ -612,7 +612,7 @@ ShlType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once // Warning: NVIDIA documentation is incorrect, you can specify scope only once
InstSt: ast::Instruction<&'input str> = { InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <src1:Operand> "]" "," <src2:Operand> => { "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <src1:Operand> "]" "," <src2:Operand> => {
ast::Instruction::St( ast::Instruction::St(
ast::StData { ast::StData {
@ -642,7 +642,7 @@ StCacheOperator: ast::StCacheOperator = {
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
InstRet: ast::Instruction<&'input str> = { InstRet: ast::Instruction<ast::ParsedArgParams<'input>> = {
"ret" <u:".uni"?> => ast::Instruction::Ret(ast::RetData { uniform: u.is_some() }) "ret" <u:".uni"?> => ast::Instruction::Ret(ast::RetData { uniform: u.is_some() })
}; };
@ -675,28 +675,28 @@ VectorOperand: (&'input str, &'input str) = {
<pref:ExtendedID> <suf:DotID> => (pref, &suf[1..]), <pref:ExtendedID> <suf:DotID> => (pref, &suf[1..]),
}; };
Arg1: ast::Arg1<&'input str> = { Arg1: ast::Arg1<ast::ParsedArgParams<'input>> = {
<src:ExtendedID> => ast::Arg1{<>} <src:ExtendedID> => ast::Arg1{<>}
}; };
Arg2: ast::Arg2<&'input str> = { Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>} <dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>}
}; };
Arg2Mov: ast::Arg2Mov<&'input str> = { Arg2Mov: ast::Arg2Mov<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src:MovOperand> => ast::Arg2Mov{<>} <dst:ExtendedID> "," <src:MovOperand> => ast::Arg2Mov{<>}
}; };
Arg3: ast::Arg3<&'input str> = { Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>} <dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>}
}; };
Arg4: ast::Arg4<&'input str> = { Arg4: ast::Arg4<ast::ParsedArgParams<'input>> = {
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> => ast::Arg4{<>} <dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> => ast::Arg4{<>}
}; };
// TODO: pass src3 negation somewhere // TODO: pass src3 negation somewhere
Arg5: ast::Arg5<&'input str> = { Arg5: ast::Arg5<ast::ParsedArgParams<'input>> = {
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5{<>} <dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5{<>}
}; };

View File

@ -168,9 +168,9 @@ fn apply_id_offset(func_body: &mut Vec<ExpandedStatement>, id_offset: u32) {
} }
} }
fn to_ssa<'a>( fn to_ssa<'a, 'b>(
f_args: &[ast::Argument], f_args: &'b [ast::Argument<'a>],
f_body: Vec<ast::Statement<&'a str>>, f_body: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
) -> (Vec<ExpandedStatement>, spirv::Word) { ) -> (Vec<ExpandedStatement>, spirv::Word) {
let (normalized_ids, mut id_def) = normalize_identifiers(&f_args, f_body); let (normalized_ids, mut id_def) = normalize_identifiers(&f_args, f_body);
let normalized_statements = normalize_predicates(normalized_ids, &mut id_def); let normalized_statements = normalize_predicates(normalized_ids, &mut id_def);
@ -214,7 +214,7 @@ fn normalize_labels(
} }
fn normalize_predicates( fn normalize_predicates(
func: Vec<ast::Statement<spirv::Word>>, func: Vec<ast::Statement<NormalizedArgParams>>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
) -> Vec<NormalizedStatement> { ) -> Vec<NormalizedStatement> {
let mut result = Vec::with_capacity(func.len()); let mut result = Vec::with_capacity(func.len());
@ -343,51 +343,51 @@ fn expand_arguments(
fn normalize_insert_instruction( fn normalize_insert_instruction(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
instr: ast::Instruction<spirv::Word>, instr: ast::Instruction<NormalizedArgParams>,
) -> Instruction { ) -> ast::Instruction<ExpandedArgParams> {
match instr { match instr {
ast::Instruction::Ld(d, a) => { ast::Instruction::Ld(d, a) => {
let arg = normalize_expand_arg2(func, id_def, &|| Some(d.typ), a); let arg = normalize_expand_arg2(func, id_def, &|| Some(d.typ), a);
Instruction::Ld(d, arg) ast::Instruction::Ld(d, arg)
} }
ast::Instruction::Mov(d, a) => { ast::Instruction::Mov(d, a) => {
let arg = normalize_expand_arg2mov(func, id_def, &|| d.typ.try_as_scalar(), a); let arg = normalize_expand_arg2mov(func, id_def, &|| d.typ.try_as_scalar(), a);
Instruction::Mov(d, arg) ast::Instruction::Mov(d, arg)
} }
ast::Instruction::Mul(d, a) => { ast::Instruction::Mul(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a); let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a);
Instruction::Mul(d, arg) ast::Instruction::Mul(d, arg)
} }
ast::Instruction::Add(d, a) => { ast::Instruction::Add(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a); let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a);
Instruction::Add(d, arg) ast::Instruction::Add(d, arg)
} }
ast::Instruction::Setp(d, a) => { ast::Instruction::Setp(d, a) => {
let arg = normalize_expand_arg4(func, id_def, &|| Some(d.typ), a); let arg = normalize_expand_arg4(func, id_def, &|| Some(d.typ), a);
Instruction::Setp(d, arg) ast::Instruction::Setp(d, arg)
} }
ast::Instruction::SetpBool(d, a) => { ast::Instruction::SetpBool(d, a) => {
let arg = normalize_expand_arg5(func, id_def, &|| Some(d.typ), a); let arg = normalize_expand_arg5(func, id_def, &|| Some(d.typ), a);
Instruction::SetpBool(d, arg) ast::Instruction::SetpBool(d, arg)
} }
ast::Instruction::Not(d, a) => { ast::Instruction::Not(d, a) => {
let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a); let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a);
Instruction::Not(d, arg) ast::Instruction::Not(d, arg)
} }
ast::Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }), ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, ast::Arg1 { src: a.src }),
ast::Instruction::Cvt(d, a) => { ast::Instruction::Cvt(d, a) => {
let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a); let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a);
Instruction::Cvt(d, arg) ast::Instruction::Cvt(d, arg)
} }
ast::Instruction::Shl(d, a) => { ast::Instruction::Shl(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| todo!(), a); let arg = normalize_expand_arg3(func, id_def, &|| todo!(), a);
Instruction::Shl(d, arg) ast::Instruction::Shl(d, arg)
} }
ast::Instruction::St(d, a) => { ast::Instruction::St(d, a) => {
let arg = normalize_expand_arg2st(func, id_def, &|| todo!(), a); let arg = normalize_expand_arg2st(func, id_def, &|| todo!(), a);
Instruction::St(d, arg) ast::Instruction::St(d, arg)
} }
ast::Instruction::Ret(d) => Instruction::Ret(d), ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
} }
} }
@ -395,9 +395,9 @@ fn normalize_expand_arg2(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2<spirv::Word>, a: ast::Arg2<NormalizedArgParams>,
) -> Arg2 { ) -> ast::Arg2<ExpandedArgParams> {
Arg2 { ast::Arg2 {
dst: a.dst, dst: a.dst,
src: normalize_expand_operand(func, id_def, inst_type, a.src), src: normalize_expand_operand(func, id_def, inst_type, a.src),
} }
@ -407,9 +407,9 @@ fn normalize_expand_arg2mov(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2Mov<spirv::Word>, a: ast::Arg2Mov<NormalizedArgParams>,
) -> Arg2 { ) -> ast::Arg2Mov<ExpandedArgParams> {
Arg2 { ast::Arg2Mov {
dst: a.dst, dst: a.dst,
src: normalize_expand_mov_operand(func, id_def, inst_type, a.src), src: normalize_expand_mov_operand(func, id_def, inst_type, a.src),
} }
@ -419,9 +419,9 @@ fn normalize_expand_arg2st(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2St<spirv::Word>, a: ast::Arg2St<NormalizedArgParams>,
) -> Arg2St { ) -> ast::Arg2St<ExpandedArgParams> {
Arg2St { ast::Arg2St {
src1: normalize_expand_operand(func, id_def, inst_type, a.src1), src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
src2: normalize_expand_operand(func, id_def, inst_type, a.src2), src2: normalize_expand_operand(func, id_def, inst_type, a.src2),
} }
@ -431,9 +431,9 @@ fn normalize_expand_arg3(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg3<spirv::Word>, a: ast::Arg3<NormalizedArgParams>,
) -> Arg3 { ) -> ast::Arg3<ExpandedArgParams> {
Arg3 { ast::Arg3 {
dst: a.dst, dst: a.dst,
src1: normalize_expand_operand(func, id_def, inst_type, a.src1), src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
src2: normalize_expand_operand(func, id_def, inst_type, a.src2), src2: normalize_expand_operand(func, id_def, inst_type, a.src2),
@ -444,9 +444,9 @@ fn normalize_expand_arg4(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg4<spirv::Word>, a: ast::Arg4<NormalizedArgParams>,
) -> Arg4 { ) -> ast::Arg4<ExpandedArgParams> {
Arg4 { ast::Arg4 {
dst1: a.dst1, dst1: a.dst1,
dst2: a.dst2, dst2: a.dst2,
src1: normalize_expand_operand(func, id_def, inst_type, a.src1), src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
@ -458,9 +458,9 @@ fn normalize_expand_arg5(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg5<spirv::Word>, a: ast::Arg5<NormalizedArgParams>,
) -> Arg5 { ) -> ast::Arg5<ExpandedArgParams> {
Arg5 { ast::Arg5 {
dst1: a.dst1, dst1: a.dst1,
dst2: a.dst2, dst2: a.dst2,
src1: normalize_expand_operand(func, id_def, inst_type, a.src1), src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
@ -527,7 +527,7 @@ fn insert_implicit_conversions(
for s in func.into_iter() { for s in func.into_iter() {
match s { match s {
Statement::Instruction(inst) => match inst { Statement::Instruction(inst) => match inst {
Instruction::Ld(ld, mut arg) => { ast::Instruction::Ld(ld, mut arg) => {
arg.src = insert_implicit_conversions_ld_src( arg.src = insert_implicit_conversions_ld_src(
&mut result, &mut result,
ast::Type::Scalar(ld.typ), ast::Type::Scalar(ld.typ),
@ -542,10 +542,10 @@ fn insert_implicit_conversions(
should_convert_relaxed_dst, should_convert_relaxed_dst,
arg, arg,
|arg| &mut arg.dst, |arg| &mut arg.dst,
|arg| Instruction::Ld(ld, arg), |arg| ast::Instruction::Ld(ld, arg),
); );
} }
Instruction::St(st, mut arg) => { ast::Instruction::St(st, mut arg) => {
let arg_src2_type = id_def.get_type(arg.src2); let arg_src2_type = id_def.get_type(arg.src2);
if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) { if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) {
arg.src2 = insert_conversion_src( arg.src2 = insert_conversion_src(
@ -564,7 +564,7 @@ fn insert_implicit_conversions(
st.state_space.to_ld_ss(), st.state_space.to_ld_ss(),
arg.src1, arg.src1,
); );
result.push(Statement::Instruction(Instruction::St(st, arg))); result.push(Statement::Instruction(ast::Instruction::St(st, arg)));
} }
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst), inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst),
}, },
@ -668,10 +668,10 @@ fn emit_function_body_ops(
} }
Statement::Instruction(inst) => match inst { Statement::Instruction(inst) => match inst {
// SPIR-V does not support marking jumps as guaranteed-converged // SPIR-V does not support marking jumps as guaranteed-converged
Instruction::Bra(_, arg) => { ast::Instruction::Bra(_, arg) => {
builder.branch(arg.src)?; builder.branch(arg.src)?;
} }
Instruction::Ld(data, arg) => { ast::Instruction::Ld(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() { if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
todo!() todo!()
} }
@ -686,7 +686,7 @@ fn emit_function_body_ops(
_ => todo!(), _ => todo!(),
} }
} }
Instruction::St(data, arg) => { ast::Instruction::St(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak if data.qualifier != ast::LdStQualifier::Weak
|| data.vector.is_some() || data.vector.is_some()
|| data.state_space != ast::StStateSpace::Generic || data.state_space != ast::StStateSpace::Generic
@ -696,18 +696,18 @@ fn emit_function_body_ops(
builder.store(arg.src1, arg.src2, None, &[])?; builder.store(arg.src1, arg.src2, None, &[])?;
} }
// SPIR-V does not support ret as guaranteed-converged // SPIR-V does not support ret as guaranteed-converged
Instruction::Ret(_) => builder.ret()?, ast::Instruction::Ret(_) => builder.ret()?,
Instruction::Mov(mov, arg) => { ast::Instruction::Mov(mov, arg) => {
let result_type = map.get_or_add(builder, SpirvType::from(mov.typ)); let result_type = map.get_or_add(builder, SpirvType::from(mov.typ));
builder.copy_object(result_type, Some(arg.dst), arg.src)?; builder.copy_object(result_type, Some(arg.dst), arg.src)?;
} }
Instruction::Mul(mul, arg) => match mul { ast::Instruction::Mul(mul, arg) => match mul {
ast::MulDetails::Int(ref ctr) => { ast::MulDetails::Int(ref ctr) => {
emit_mul_int(builder, map, opencl, ctr, arg)?; emit_mul_int(builder, map, opencl, ctr, arg)?;
} }
ast::MulDetails::Float(_) => todo!(), ast::MulDetails::Float(_) => todo!(),
}, },
Instruction::Add(add, arg) => match add { ast::Instruction::Add(add, arg) => match add {
ast::AddDetails::Int(ref desc) => { ast::AddDetails::Int(ref desc) => {
emit_add_int(builder, map, desc, arg)?; emit_add_int(builder, map, desc, arg)?;
} }
@ -732,7 +732,7 @@ fn emit_mul_int(
map: &mut TypeWordMap, map: &mut TypeWordMap,
opencl: spirv::Word, opencl: spirv::Word,
desc: &ast::MulIntDesc, desc: &ast::MulIntDesc,
arg: &Arg3, arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> { ) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::Base(desc.typ.into())); let inst_type = map.get_or_add(builder, SpirvType::Base(desc.typ.into()));
match desc.control { match desc.control {
@ -762,7 +762,7 @@ fn emit_add_int(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
ctr: &ast::AddIntDesc, ctr: &ast::AddIntDesc,
arg: &Arg3, arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> { ) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::Base(ctr.typ.into())); let inst_type = map.get_or_add(builder, SpirvType::Base(ctr.typ.into()));
builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
@ -837,10 +837,10 @@ fn emit_implicit_conversion(
} }
// TODO: support scopes // TODO: support scopes
fn normalize_identifiers<'a>( fn normalize_identifiers<'a, 'b>(
args: &'a [ast::Argument<'a>], args: &'b [ast::Argument<'a>],
func: Vec<ast::Statement<&'a str>>, func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
) -> (Vec<ast::Statement<spirv::Word>>, NumericIdResolver) { ) -> (Vec<ast::Statement<NormalizedArgParams>>, NumericIdResolver) {
let mut id_defs = StringIdResolver::new(); let mut id_defs = StringIdResolver::new();
for arg in args { for arg in args {
id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type))); id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type)));
@ -854,8 +854,8 @@ fn normalize_identifiers<'a>(
fn expand_map_ids<'a>( fn expand_map_ids<'a>(
id_defs: &mut StringIdResolver<'a>, id_defs: &mut StringIdResolver<'a>,
result: &mut Vec<ast::Statement<spirv::Word>>, result: &mut Vec<ast::Statement<NormalizedArgParams>>,
s: ast::Statement<&'a str>, s: ast::Statement<ast::ParsedArgParams<'a>>,
) { ) {
match s { match s {
ast::Statement::Label(name) => { ast::Statement::Label(name) => {
@ -979,7 +979,7 @@ enum Statement<I> {
Constant(ConstantDefinition), Constant(ConstantDefinition),
} }
impl Statement<Instruction> { impl Statement<ast::Instruction<ExpandedArgParams>> {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) { fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
match self { match self {
Statement::Variable(id, _, _) => f(id), Statement::Variable(id, _, _) => f(id),
@ -994,25 +994,25 @@ impl Statement<Instruction> {
} }
} }
type NormalizedStatement = Statement<ast::Instruction<spirv::Word>>; enum NormalizedArgParams {}
type ExpandedStatement = Statement<Instruction>; type NormalizedStatement = Statement<ast::Instruction<NormalizedArgParams>>;
enum Instruction { impl ast::ArgParams for NormalizedArgParams {
Ld(ast::LdData, Arg2), type ID = spirv::Word;
Mov(ast::MovData, Arg2), type Operand = ast::Operand<spirv::Word>;
Mul(ast::MulDetails, Arg3), type MovOperand = ast::MovOperand<spirv::Word>;
Add(ast::AddDetails, Arg3),
Setp(ast::SetpData, Arg4),
SetpBool(ast::SetpBoolData, Arg5),
Not(ast::NotData, Arg2),
Bra(ast::BraData, Arg1),
Cvt(ast::CvtData, Arg2),
Shl(ast::ShlData, Arg3),
St(ast::StData, Arg2St),
Ret(ast::RetData),
} }
impl ast::Instruction<spirv::Word> { enum ExpandedArgParams {}
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>>;
impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word;
type Operand = spirv::Word;
type MovOperand = spirv::Word;
}
impl ast::Instruction<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(&mut self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(&mut self, f: &mut F) {
match self { match self {
ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
@ -1031,22 +1031,22 @@ impl ast::Instruction<spirv::Word> {
} }
} }
impl Instruction { impl ast::Instruction<ExpandedArgParams> {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) { fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
let f_visitor = &mut Self::typed_visitor(f); let f_visitor = &mut Self::typed_visitor(f);
match self { match self {
Instruction::Ld(_, a) => a.visit_id(f_visitor, None), ast::Instruction::Ld(_, a) => a.visit_id(f_visitor, None),
Instruction::Mov(_, a) => a.visit_id(f_visitor, None), ast::Instruction::Mov(_, a) => a.visit_id(f_visitor, None),
Instruction::Mul(_, a) => a.visit_id(f_visitor, None), ast::Instruction::Mul(_, a) => a.visit_id(f_visitor, None),
Instruction::Add(_, a) => a.visit_id(f_visitor, None), ast::Instruction::Add(_, a) => a.visit_id(f_visitor, None),
Instruction::Setp(_, a) => a.visit_id(f_visitor, None), ast::Instruction::Setp(_, a) => a.visit_id(f_visitor, None),
Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None), ast::Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None),
Instruction::Not(_, a) => a.visit_id(f_visitor, None), ast::Instruction::Not(_, a) => a.visit_id(f_visitor, None),
Instruction::Cvt(_, a) => a.visit_id(f_visitor, None), ast::Instruction::Cvt(_, a) => a.visit_id(f_visitor, None),
Instruction::Shl(_, a) => a.visit_id(f_visitor, None), ast::Instruction::Shl(_, a) => a.visit_id(f_visitor, None),
Instruction::St(_, a) => a.visit_id(f_visitor, None), ast::Instruction::St(_, a) => a.visit_id(f_visitor, None),
Instruction::Bra(_, a) => a.visit_id(f_visitor, None), ast::Instruction::Bra(_, a) => a.visit_id(f_visitor, None),
Instruction::Ret(_) => (), ast::Instruction::Ret(_) => (),
} }
} }
@ -1061,42 +1061,40 @@ impl Instruction {
f: &mut F, f: &mut F,
) { ) {
match self { match self {
Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)), ast::Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)),
Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())), ast::Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), ast::Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), ast::Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), ast::Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Not(_, _) => todo!(), ast::Instruction::Not(_, _) => todo!(),
Instruction::Cvt(_, _) => todo!(), ast::Instruction::Cvt(_, _) => todo!(),
Instruction::Shl(_, _) => todo!(), ast::Instruction::Shl(_, _) => todo!(),
Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), ast::Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Bra(_, a) => a.visit_id(f, None), ast::Instruction::Bra(_, a) => a.visit_id(f, None),
Instruction::Ret(_) => (), ast::Instruction::Ret(_) => (),
} }
} }
fn jump_target(&self) -> Option<spirv::Word> { fn jump_target(&self) -> Option<spirv::Word> {
match self { match self {
Instruction::Bra(_, a) => Some(a.src), ast::Instruction::Bra(_, a) => Some(a.src),
Instruction::Ld(_, _) ast::Instruction::Ld(_, _)
| Instruction::Mov(_, _) | ast::Instruction::Mov(_, _)
| Instruction::Mul(_, _) | ast::Instruction::Mul(_, _)
| Instruction::Add(_, _) | ast::Instruction::Add(_, _)
| Instruction::Setp(_, _) | ast::Instruction::Setp(_, _)
| Instruction::SetpBool(_, _) | ast::Instruction::SetpBool(_, _)
| Instruction::Not(_, _) | ast::Instruction::Not(_, _)
| Instruction::Cvt(_, _) | ast::Instruction::Cvt(_, _)
| Instruction::Shl(_, _) | ast::Instruction::Shl(_, _)
| Instruction::St(_, _) | ast::Instruction::St(_, _)
| Instruction::Ret(_) => None, | ast::Instruction::Ret(_) => None,
} }
} }
} }
struct Arg1 { type Arg1 = ast::Arg1<ExpandedArgParams>;
pub src: spirv::Word,
}
impl Arg1 { impl Arg1 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@ -1108,10 +1106,7 @@ impl Arg1 {
} }
} }
struct Arg2 { type Arg2 = ast::Arg2<ExpandedArgParams>;
pub dst: spirv::Word,
pub src: spirv::Word,
}
impl Arg2 { impl Arg2 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@ -1124,11 +1119,21 @@ impl Arg2 {
} }
} }
pub struct Arg2St { type Arg2Mov = ast::Arg2Mov<ExpandedArgParams>;
pub src1: spirv::Word,
pub src2: spirv::Word, impl Arg2Mov {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(true, &mut self.dst, t);
f(false, &mut self.src, t);
}
} }
type Arg2St = ast::Arg2St<ExpandedArgParams>;
impl Arg2St { impl Arg2St {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self, &mut self,
@ -1140,11 +1145,7 @@ impl Arg2St {
} }
} }
struct Arg3 { type Arg3 = ast::Arg3<ExpandedArgParams>;
pub dst: spirv::Word,
pub src1: spirv::Word,
pub src2: spirv::Word,
}
impl Arg3 { impl Arg3 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@ -1158,12 +1159,7 @@ impl Arg3 {
} }
} }
struct Arg4 { type Arg4 = ast::Arg4<ExpandedArgParams>;
pub dst1: spirv::Word,
pub dst2: Option<spirv::Word>,
pub src1: spirv::Word,
pub src2: spirv::Word,
}
impl Arg4 { impl Arg4 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@ -1188,13 +1184,7 @@ impl Arg4 {
} }
} }
struct Arg5 { type Arg5 = ast::Arg5<ExpandedArgParams>;
pub dst1: spirv::Word,
pub dst2: Option<spirv::Word>,
pub src1: spirv::Word,
pub src2: spirv::Word,
pub src3: spirv::Word,
}
impl Arg5 { impl Arg5 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@ -1286,8 +1276,11 @@ impl<T> ast::PredAt<T> {
} }
} }
impl<T> ast::Instruction<T> { impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> { fn map_id<F: FnMut(&'a str) -> spirv::Word>(
self,
f: &mut F,
) -> ast::Instruction<NormalizedArgParams> {
match self { match self {
ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)), ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)),
ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)), ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)),
@ -1305,13 +1298,13 @@ impl<T> ast::Instruction<T> {
} }
} }
impl<T> ast::Arg1<T> { impl<'a> ast::Arg1<ast::ParsedArgParams<'a>> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg1<U> { fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg1<NormalizedArgParams> {
ast::Arg1 { src: f(self.src) } ast::Arg1 { src: f(self.src) }
} }
} }
impl ast::Arg1<spirv::Word> { impl ast::Arg1<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self, &mut self,
f: &mut F, f: &mut F,
@ -1321,8 +1314,8 @@ impl ast::Arg1<spirv::Word> {
} }
} }
impl<T> ast::Arg2<T> { impl<'a> ast::Arg2<ast::ParsedArgParams<'a>> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2<U> { fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg2<NormalizedArgParams> {
ast::Arg2 { ast::Arg2 {
dst: f(self.dst), dst: f(self.dst),
src: self.src.map_id(f), src: self.src.map_id(f),
@ -1330,7 +1323,7 @@ impl<T> ast::Arg2<T> {
} }
} }
impl ast::Arg2<spirv::Word> { impl ast::Arg2<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self, &mut self,
f: &mut F, f: &mut F,
@ -1341,8 +1334,11 @@ impl ast::Arg2<spirv::Word> {
} }
} }
impl<T> ast::Arg2St<T> { impl<'a> ast::Arg2St<ast::ParsedArgParams<'a>> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2St<U> { fn map_id<F: FnMut(&'a str) -> spirv::Word>(
self,
f: &mut F,
) -> ast::Arg2St<NormalizedArgParams> {
ast::Arg2St { ast::Arg2St {
src1: self.src1.map_id(f), src1: self.src1.map_id(f),
src2: self.src2.map_id(f), src2: self.src2.map_id(f),
@ -1350,7 +1346,7 @@ impl<T> ast::Arg2St<T> {
} }
} }
impl ast::Arg2St<spirv::Word> { impl ast::Arg2St<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self, &mut self,
f: &mut F, f: &mut F,
@ -1361,8 +1357,11 @@ impl ast::Arg2St<spirv::Word> {
} }
} }
impl<T> ast::Arg2Mov<T> { impl<'a> ast::Arg2Mov<ast::ParsedArgParams<'a>> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2Mov<U> { fn map_id<F: FnMut(&'a str) -> spirv::Word>(
self,
f: &mut F,
) -> ast::Arg2Mov<NormalizedArgParams> {
ast::Arg2Mov { ast::Arg2Mov {
dst: f(self.dst), dst: f(self.dst),
src: self.src.map_id(f), src: self.src.map_id(f),
@ -1370,7 +1369,7 @@ impl<T> ast::Arg2Mov<T> {
} }
} }
impl ast::Arg2Mov<spirv::Word> { impl ast::Arg2Mov<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self, &mut self,
f: &mut F, f: &mut F,
@ -1381,8 +1380,8 @@ impl ast::Arg2Mov<spirv::Word> {
} }
} }
impl<T> ast::Arg3<T> { impl<'a> ast::Arg3<ast::ParsedArgParams<'a>> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg3<U> { fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg3<NormalizedArgParams> {
ast::Arg3 { ast::Arg3 {
dst: f(self.dst), dst: f(self.dst),
src1: self.src1.map_id(f), src1: self.src1.map_id(f),
@ -1391,7 +1390,7 @@ impl<T> ast::Arg3<T> {
} }
} }
impl ast::Arg3<spirv::Word> { impl ast::Arg3<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self, &mut self,
f: &mut F, f: &mut F,
@ -1403,8 +1402,8 @@ impl ast::Arg3<spirv::Word> {
} }
} }
impl<T> ast::Arg4<T> { impl<'a> ast::Arg4<ast::ParsedArgParams<'a>> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg4<U> { fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg4<NormalizedArgParams> {
ast::Arg4 { ast::Arg4 {
dst1: f(self.dst1), dst1: f(self.dst1),
dst2: self.dst2.map(|i| f(i)), dst2: self.dst2.map(|i| f(i)),
@ -1414,7 +1413,7 @@ impl<T> ast::Arg4<T> {
} }
} }
impl ast::Arg4<spirv::Word> { impl ast::Arg4<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self, &mut self,
f: &mut F, f: &mut F,
@ -1437,8 +1436,8 @@ impl ast::Arg4<spirv::Word> {
} }
} }
impl<T> ast::Arg5<T> { impl<'a> ast::Arg5<ast::ParsedArgParams<'a>> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg5<U> { fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg5<NormalizedArgParams> {
ast::Arg5 { ast::Arg5 {
dst1: f(self.dst1), dst1: f(self.dst1),
dst2: self.dst2.map(|i| f(i)), dst2: self.dst2.map(|i| f(i)),
@ -1449,7 +1448,7 @@ impl<T> ast::Arg5<T> {
} }
} }
impl ast::Arg5<spirv::Word> { impl ast::Arg5<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self, &mut self,
f: &mut F, f: &mut F,
@ -1779,7 +1778,7 @@ fn insert_with_implicit_conversion_dst<
T, T,
ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>, ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word, Setter: Fn(&mut T) -> &mut spirv::Word,
ToInstruction: FnOnce(T) -> Instruction, ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>,
>( >(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
instr_type: ast::ScalarType, instr_type: ast::ScalarType,
@ -1907,7 +1906,7 @@ fn should_convert_relaxed_dst(
fn insert_implicit_bitcasts( fn insert_implicit_bitcasts(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
mut instr: Instruction, mut instr: ast::Instruction<ExpandedArgParams>,
) { ) {
let mut dst_coercion = None; let mut dst_coercion = None;
instr.visit_id_extended(&mut |is_dst, id, id_type| { instr.visit_id_extended(&mut |is_dst, id, id_type| {