diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 6bb099a..9fab216 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,5 +1,5 @@ use std::convert::From; -use std::num::ParseIntError; +use std::{marker::PhantomData, num::ParseIntError}; quick_error! { #[derive(Debug)] @@ -52,7 +52,7 @@ pub struct Function<'a> { pub kernel: bool, pub name: &'a str, pub args: Vec>, - pub body: Vec>, + pub body: Vec>>, } #[derive(Default)] @@ -141,16 +141,16 @@ impl Default for ScalarType { } } -pub enum Statement { - Label(ID), - Variable(Variable), - Instruction(Option>, Instruction), +pub enum Statement { + Label(P::ID), + Variable(Variable

), + Instruction(Option>, Instruction

), } -pub struct Variable { +pub struct Variable { pub space: StateSpace, pub v_type: Type, - pub name: ID, + pub name: P::ID, pub count: Option, } @@ -169,59 +169,75 @@ pub struct PredAt { pub label: ID, } -pub enum Instruction { - Ld(LdData, Arg2), - Mov(MovData, Arg2Mov), - Mul(MulDetails, Arg3), - Add(AddDetails, Arg3), - Setp(SetpData, Arg4), - SetpBool(SetpBoolData, Arg5), - Not(NotData, Arg2), - Bra(BraData, Arg1), - Cvt(CvtData, Arg2), - Shl(ShlData, Arg3), - St(StData, Arg2St), +pub enum Instruction { + Ld(LdData, Arg2

), + Mov(MovData, Arg2Mov

), + Mul(MulDetails, Arg3

), + Add(AddDetails, Arg3

), + Setp(SetpData, Arg4

), + SetpBool(SetpBoolData, Arg5

), + Not(NotData, Arg2

), + Bra(BraData, Arg1

), + Cvt(CvtData, Arg2

), + Shl(ShlData, Arg3

), + St(StData, Arg2St

), Ret(RetData), } -pub struct Arg1 { - pub src: ID, // it is a jump destination, but in terms of operands it is a source operand +pub trait ArgParams { + type ID; + type Operand; + type MovOperand; } -pub struct Arg2 { - pub dst: ID, - pub src: Operand, +pub struct ParsedArgParams<'a> { + _marker: PhantomData<&'a ()>, } -pub struct Arg2St { - pub src1: Operand, - pub src2: Operand, +impl<'a> ArgParams for ParsedArgParams<'a> { + type ID = &'a str; + type Operand = Operand<&'a str>; + type MovOperand = MovOperand<&'a str>; } -pub struct Arg2Mov { - pub dst: ID, - pub src: MovOperand, +pub struct Arg1 { + pub src: P::ID, // it is a jump destination, but in terms of operands it is a source operand } -pub struct Arg3 { - pub dst: ID, - pub src1: Operand, - pub src2: Operand, +pub struct Arg2 { + pub dst: P::ID, + pub src: P::Operand, } -pub struct Arg4 { - pub dst1: ID, - pub dst2: Option, - pub src1: Operand, - pub src2: Operand, +pub struct Arg2St { + pub src1: P::Operand, + pub src2: P::Operand, } -pub struct Arg5 { - pub dst1: ID, - pub dst2: Option, - pub src1: Operand, - pub src2: Operand, - pub src3: Operand, +pub struct Arg2Mov { + pub dst: P::ID, + pub src: P::MovOperand, +} + +pub struct Arg3 { + pub dst: P::ID, + pub src1: P::Operand, + pub src2: P::Operand, +} + +pub struct Arg4 { + pub dst1: P::ID, + pub dst2: Option, + pub src1: P::Operand, + pub src2: P::Operand, +} + +pub struct Arg5 { + pub dst1: P::ID, + pub dst2: Option, + pub src1: P::Operand, + pub src2: P::Operand, + pub src3: P::Operand, } pub enum Operand { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index cc58cf2..af26765 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -223,7 +223,7 @@ FunctionInput: ast::Argument<'input> = { } }; -pub(crate) FunctionBody: Vec> = { +pub(crate) FunctionBody: Vec>> = { "{" "}" => { without_none(s) } }; @@ -269,7 +269,7 @@ MemoryType: ast::ScalarType = { ".f64" => ast::ScalarType::F64, }; -Statement: Option> = { +Statement: Option>> = { => Some(ast::Statement::Label(l)), DebugDirective => None, ";" => Some(ast::Statement::Variable(v)), @@ -289,7 +289,7 @@ Label: &'input str = { ":" => id }; -Variable: ast::Variable<&'input str> = { +Variable: ast::Variable> = { => { let (name, count) = v; ast::Variable { space: s, v_type: t, name: name, count: count } @@ -310,7 +310,7 @@ VariableName: (&'input str, Option) = { } }; -Instruction: ast::Instruction<&'input str> = { +Instruction: ast::Instruction> = { InstLd, InstMov, InstMul, @@ -325,7 +325,7 @@ Instruction: ast::Instruction<&'input str> = { }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld -InstLd: ast::Instruction<&'input str> = { +InstLd: ast::Instruction> = { "ld" "," "[" "]" => { ast::Instruction::Ld( ast::LdData { @@ -370,7 +370,7 @@ LdCacheOperator: ast::LdCacheOperator = { }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov -InstMov: ast::Instruction<&'input str> = { +InstMov: ast::Instruction> = { "mov" => { ast::Instruction::Mov(ast::MovData{ typ:t }, a) } @@ -394,7 +394,7 @@ MovType: ast::Type = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul -InstMul: ast::Instruction<&'input str> = { +InstMul: ast::Instruction> = { "mul" => ast::Instruction::Mul(d, a) }; @@ -455,7 +455,7 @@ IntType : ast::IntType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add -InstAdd: ast::Instruction<&'input str> = { +InstAdd: ast::Instruction> = { "add" => ast::Instruction::Add(d, a) }; @@ -492,7 +492,7 @@ InstAddMode: ast::AddDetails = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp // TODO: support f16 setp -InstSetp: ast::Instruction<&'input str> = { +InstSetp: ast::Instruction> = { "setp" => ast::Instruction::Setp(d, a), "setp" => ast::Instruction::SetpBool(d, a), }; @@ -556,7 +556,7 @@ SetpType: ast::ScalarType = { }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not -InstNot: ast::Instruction<&'input str> = { +InstNot: ast::Instruction> = { "not" NotType => ast::Instruction::Not(ast::NotData{}, a) }; @@ -571,12 +571,12 @@ PredAt: ast::PredAt<&'input str> = { }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra -InstBra: ast::Instruction<&'input str> = { +InstBra: ast::Instruction> = { "bra" => ast::Instruction::Bra(ast::BraData{ uniform: u.is_some() }, a) }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt -InstCvt: ast::Instruction<&'input str> = { +InstCvt: ast::Instruction> = { "cvt" CvtRnd? ".ftz"? ".sat"? CvtType CvtType => { ast::Instruction::Cvt(ast::CvtData{}, a) } @@ -602,7 +602,7 @@ CvtType = { }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl -InstShl: ast::Instruction<&'input str> = { +InstShl: ast::Instruction> = { "shl" ShlType => ast::Instruction::Shl(ast::ShlData{}, a) }; @@ -612,7 +612,7 @@ ShlType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // Warning: NVIDIA documentation is incorrect, you can specify scope only once -InstSt: ast::Instruction<&'input str> = { +InstSt: ast::Instruction> = { "st" "[" "]" "," => { ast::Instruction::St( ast::StData { @@ -642,7 +642,7 @@ StCacheOperator: ast::StCacheOperator = { }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret -InstRet: ast::Instruction<&'input str> = { +InstRet: ast::Instruction> = { "ret" => ast::Instruction::Ret(ast::RetData { uniform: u.is_some() }) }; @@ -675,28 +675,28 @@ VectorOperand: (&'input str, &'input str) = { => (pref, &suf[1..]), }; -Arg1: ast::Arg1<&'input str> = { +Arg1: ast::Arg1> = { => ast::Arg1{<>} }; -Arg2: ast::Arg2<&'input str> = { +Arg2: ast::Arg2> = { "," => ast::Arg2{<>} }; -Arg2Mov: ast::Arg2Mov<&'input str> = { +Arg2Mov: ast::Arg2Mov> = { "," => ast::Arg2Mov{<>} }; -Arg3: ast::Arg3<&'input str> = { +Arg3: ast::Arg3> = { "," "," => ast::Arg3{<>} }; -Arg4: ast::Arg4<&'input str> = { +Arg4: ast::Arg4> = { "," "," => ast::Arg4{<>} }; // TODO: pass src3 negation somewhere -Arg5: ast::Arg5<&'input str> = { +Arg5: ast::Arg5> = { "," "," "," "!"? => ast::Arg5{<>} }; diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 3486edd..ebcb090 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -168,9 +168,9 @@ fn apply_id_offset(func_body: &mut Vec, id_offset: u32) { } } -fn to_ssa<'a>( - f_args: &[ast::Argument], - f_body: Vec>, +fn to_ssa<'a, 'b>( + f_args: &'b [ast::Argument<'a>], + f_body: Vec>>, ) -> (Vec, spirv::Word) { let (normalized_ids, mut id_def) = normalize_identifiers(&f_args, f_body); let normalized_statements = normalize_predicates(normalized_ids, &mut id_def); @@ -214,7 +214,7 @@ fn normalize_labels( } fn normalize_predicates( - func: Vec>, + func: Vec>, id_def: &mut NumericIdResolver, ) -> Vec { let mut result = Vec::with_capacity(func.len()); @@ -343,51 +343,51 @@ fn expand_arguments( fn normalize_insert_instruction( func: &mut Vec, id_def: &mut NumericIdResolver, - instr: ast::Instruction, -) -> Instruction { + instr: ast::Instruction, +) -> ast::Instruction { match instr { ast::Instruction::Ld(d, a) => { let arg = normalize_expand_arg2(func, id_def, &|| Some(d.typ), a); - Instruction::Ld(d, arg) + ast::Instruction::Ld(d, arg) } ast::Instruction::Mov(d, a) => { let arg = normalize_expand_arg2mov(func, id_def, &|| d.typ.try_as_scalar(), a); - Instruction::Mov(d, arg) + ast::Instruction::Mov(d, arg) } ast::Instruction::Mul(d, a) => { let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a); - Instruction::Mul(d, arg) + ast::Instruction::Mul(d, arg) } ast::Instruction::Add(d, a) => { let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a); - Instruction::Add(d, arg) + ast::Instruction::Add(d, arg) } ast::Instruction::Setp(d, a) => { let arg = normalize_expand_arg4(func, id_def, &|| Some(d.typ), a); - Instruction::Setp(d, arg) + ast::Instruction::Setp(d, arg) } ast::Instruction::SetpBool(d, a) => { let arg = normalize_expand_arg5(func, id_def, &|| Some(d.typ), a); - Instruction::SetpBool(d, arg) + ast::Instruction::SetpBool(d, arg) } ast::Instruction::Not(d, a) => { let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a); - Instruction::Not(d, arg) + ast::Instruction::Not(d, arg) } - ast::Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }), + ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, ast::Arg1 { src: a.src }), ast::Instruction::Cvt(d, a) => { let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a); - Instruction::Cvt(d, arg) + ast::Instruction::Cvt(d, arg) } ast::Instruction::Shl(d, a) => { let arg = normalize_expand_arg3(func, id_def, &|| todo!(), a); - Instruction::Shl(d, arg) + ast::Instruction::Shl(d, arg) } ast::Instruction::St(d, a) => { let arg = normalize_expand_arg2st(func, id_def, &|| todo!(), a); - Instruction::St(d, arg) + ast::Instruction::St(d, arg) } - ast::Instruction::Ret(d) => Instruction::Ret(d), + ast::Instruction::Ret(d) => ast::Instruction::Ret(d), } } @@ -395,9 +395,9 @@ fn normalize_expand_arg2( func: &mut Vec, id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, - a: ast::Arg2, -) -> Arg2 { - Arg2 { + a: ast::Arg2, +) -> ast::Arg2 { + ast::Arg2 { dst: a.dst, src: normalize_expand_operand(func, id_def, inst_type, a.src), } @@ -407,9 +407,9 @@ fn normalize_expand_arg2mov( func: &mut Vec, id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, - a: ast::Arg2Mov, -) -> Arg2 { - Arg2 { + a: ast::Arg2Mov, +) -> ast::Arg2Mov { + ast::Arg2Mov { dst: a.dst, src: normalize_expand_mov_operand(func, id_def, inst_type, a.src), } @@ -419,9 +419,9 @@ fn normalize_expand_arg2st( func: &mut Vec, id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, - a: ast::Arg2St, -) -> Arg2St { - Arg2St { + a: ast::Arg2St, +) -> ast::Arg2St { + ast::Arg2St { src1: normalize_expand_operand(func, id_def, inst_type, a.src1), src2: normalize_expand_operand(func, id_def, inst_type, a.src2), } @@ -431,9 +431,9 @@ fn normalize_expand_arg3( func: &mut Vec, id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, - a: ast::Arg3, -) -> Arg3 { - Arg3 { + a: ast::Arg3, +) -> ast::Arg3 { + ast::Arg3 { dst: a.dst, src1: normalize_expand_operand(func, id_def, inst_type, a.src1), src2: normalize_expand_operand(func, id_def, inst_type, a.src2), @@ -444,9 +444,9 @@ fn normalize_expand_arg4( func: &mut Vec, id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, - a: ast::Arg4, -) -> Arg4 { - Arg4 { + a: ast::Arg4, +) -> ast::Arg4 { + ast::Arg4 { dst1: a.dst1, dst2: a.dst2, src1: normalize_expand_operand(func, id_def, inst_type, a.src1), @@ -458,9 +458,9 @@ fn normalize_expand_arg5( func: &mut Vec, id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, - a: ast::Arg5, -) -> Arg5 { - Arg5 { + a: ast::Arg5, +) -> ast::Arg5 { + ast::Arg5 { dst1: a.dst1, dst2: a.dst2, src1: normalize_expand_operand(func, id_def, inst_type, a.src1), @@ -527,7 +527,7 @@ fn insert_implicit_conversions( for s in func.into_iter() { match s { Statement::Instruction(inst) => match inst { - Instruction::Ld(ld, mut arg) => { + ast::Instruction::Ld(ld, mut arg) => { arg.src = insert_implicit_conversions_ld_src( &mut result, ast::Type::Scalar(ld.typ), @@ -542,10 +542,10 @@ fn insert_implicit_conversions( should_convert_relaxed_dst, arg, |arg| &mut arg.dst, - |arg| Instruction::Ld(ld, arg), + |arg| ast::Instruction::Ld(ld, arg), ); } - Instruction::St(st, mut arg) => { + ast::Instruction::St(st, mut arg) => { let arg_src2_type = id_def.get_type(arg.src2); if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) { arg.src2 = insert_conversion_src( @@ -564,7 +564,7 @@ fn insert_implicit_conversions( st.state_space.to_ld_ss(), arg.src1, ); - result.push(Statement::Instruction(Instruction::St(st, arg))); + result.push(Statement::Instruction(ast::Instruction::St(st, arg))); } inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst), }, @@ -668,10 +668,10 @@ fn emit_function_body_ops( } Statement::Instruction(inst) => match inst { // SPIR-V does not support marking jumps as guaranteed-converged - Instruction::Bra(_, arg) => { + ast::Instruction::Bra(_, arg) => { builder.branch(arg.src)?; } - Instruction::Ld(data, arg) => { + ast::Instruction::Ld(data, arg) => { if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() { todo!() } @@ -686,7 +686,7 @@ fn emit_function_body_ops( _ => todo!(), } } - Instruction::St(data, arg) => { + ast::Instruction::St(data, arg) => { if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() || data.state_space != ast::StStateSpace::Generic @@ -696,18 +696,18 @@ fn emit_function_body_ops( builder.store(arg.src1, arg.src2, None, &[])?; } // SPIR-V does not support ret as guaranteed-converged - Instruction::Ret(_) => builder.ret()?, - Instruction::Mov(mov, arg) => { + ast::Instruction::Ret(_) => builder.ret()?, + ast::Instruction::Mov(mov, arg) => { let result_type = map.get_or_add(builder, SpirvType::from(mov.typ)); builder.copy_object(result_type, Some(arg.dst), arg.src)?; } - Instruction::Mul(mul, arg) => match mul { + ast::Instruction::Mul(mul, arg) => match mul { ast::MulDetails::Int(ref ctr) => { emit_mul_int(builder, map, opencl, ctr, arg)?; } ast::MulDetails::Float(_) => todo!(), }, - Instruction::Add(add, arg) => match add { + ast::Instruction::Add(add, arg) => match add { ast::AddDetails::Int(ref desc) => { emit_add_int(builder, map, desc, arg)?; } @@ -732,7 +732,7 @@ fn emit_mul_int( map: &mut TypeWordMap, opencl: spirv::Word, desc: &ast::MulIntDesc, - arg: &Arg3, + arg: &ast::Arg3, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::Base(desc.typ.into())); match desc.control { @@ -762,7 +762,7 @@ fn emit_add_int( builder: &mut dr::Builder, map: &mut TypeWordMap, ctr: &ast::AddIntDesc, - arg: &Arg3, + arg: &ast::Arg3, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::Base(ctr.typ.into())); builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; @@ -837,10 +837,10 @@ fn emit_implicit_conversion( } // TODO: support scopes -fn normalize_identifiers<'a>( - args: &'a [ast::Argument<'a>], - func: Vec>, -) -> (Vec>, NumericIdResolver) { +fn normalize_identifiers<'a, 'b>( + args: &'b [ast::Argument<'a>], + func: Vec>>, +) -> (Vec>, NumericIdResolver) { let mut id_defs = StringIdResolver::new(); for arg in args { id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type))); @@ -854,8 +854,8 @@ fn normalize_identifiers<'a>( fn expand_map_ids<'a>( id_defs: &mut StringIdResolver<'a>, - result: &mut Vec>, - s: ast::Statement<&'a str>, + result: &mut Vec>, + s: ast::Statement>, ) { match s { ast::Statement::Label(name) => { @@ -979,7 +979,7 @@ enum Statement { Constant(ConstantDefinition), } -impl Statement { +impl Statement> { fn visit_id(&mut self, f: &mut F) { match self { Statement::Variable(id, _, _) => f(id), @@ -994,25 +994,25 @@ impl Statement { } } -type NormalizedStatement = Statement>; -type ExpandedStatement = Statement; +enum NormalizedArgParams {} +type NormalizedStatement = Statement>; -enum Instruction { - Ld(ast::LdData, Arg2), - Mov(ast::MovData, Arg2), - Mul(ast::MulDetails, Arg3), - Add(ast::AddDetails, Arg3), - Setp(ast::SetpData, Arg4), - SetpBool(ast::SetpBoolData, Arg5), - Not(ast::NotData, Arg2), - Bra(ast::BraData, Arg1), - Cvt(ast::CvtData, Arg2), - Shl(ast::ShlData, Arg3), - St(ast::StData, Arg2St), - Ret(ast::RetData), +impl ast::ArgParams for NormalizedArgParams { + type ID = spirv::Word; + type Operand = ast::Operand; + type MovOperand = ast::MovOperand; } -impl ast::Instruction { +enum ExpandedArgParams {} +type ExpandedStatement = Statement>; + +impl ast::ArgParams for ExpandedArgParams { + type ID = spirv::Word; + type Operand = spirv::Word; + type MovOperand = spirv::Word; +} + +impl ast::Instruction { fn visit_id)>(&mut self, f: &mut F) { match self { ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), @@ -1031,22 +1031,22 @@ impl ast::Instruction { } } -impl Instruction { +impl ast::Instruction { fn visit_id(&mut self, f: &mut F) { let f_visitor = &mut Self::typed_visitor(f); match self { - Instruction::Ld(_, a) => a.visit_id(f_visitor, None), - Instruction::Mov(_, a) => a.visit_id(f_visitor, None), - Instruction::Mul(_, a) => a.visit_id(f_visitor, None), - Instruction::Add(_, a) => a.visit_id(f_visitor, None), - Instruction::Setp(_, a) => a.visit_id(f_visitor, None), - Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None), - Instruction::Not(_, a) => a.visit_id(f_visitor, None), - Instruction::Cvt(_, a) => a.visit_id(f_visitor, None), - Instruction::Shl(_, a) => a.visit_id(f_visitor, None), - Instruction::St(_, a) => a.visit_id(f_visitor, None), - Instruction::Bra(_, a) => a.visit_id(f_visitor, None), - Instruction::Ret(_) => (), + ast::Instruction::Ld(_, a) => a.visit_id(f_visitor, None), + ast::Instruction::Mov(_, a) => a.visit_id(f_visitor, None), + ast::Instruction::Mul(_, a) => a.visit_id(f_visitor, None), + ast::Instruction::Add(_, a) => a.visit_id(f_visitor, None), + ast::Instruction::Setp(_, a) => a.visit_id(f_visitor, None), + ast::Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None), + ast::Instruction::Not(_, a) => a.visit_id(f_visitor, None), + ast::Instruction::Cvt(_, a) => a.visit_id(f_visitor, None), + ast::Instruction::Shl(_, a) => a.visit_id(f_visitor, None), + ast::Instruction::St(_, a) => a.visit_id(f_visitor, None), + ast::Instruction::Bra(_, a) => a.visit_id(f_visitor, None), + ast::Instruction::Ret(_) => (), } } @@ -1061,42 +1061,40 @@ impl Instruction { f: &mut F, ) { match self { - Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)), - Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())), - Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), - Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - Instruction::Not(_, _) => todo!(), - Instruction::Cvt(_, _) => todo!(), - Instruction::Shl(_, _) => todo!(), - Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - Instruction::Bra(_, a) => a.visit_id(f, None), - Instruction::Ret(_) => (), + ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + ast::Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)), + ast::Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())), + ast::Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), + ast::Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + ast::Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + ast::Instruction::Not(_, _) => todo!(), + ast::Instruction::Cvt(_, _) => todo!(), + ast::Instruction::Shl(_, _) => todo!(), + ast::Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + ast::Instruction::Bra(_, a) => a.visit_id(f, None), + ast::Instruction::Ret(_) => (), } } fn jump_target(&self) -> Option { match self { - Instruction::Bra(_, a) => Some(a.src), - Instruction::Ld(_, _) - | Instruction::Mov(_, _) - | Instruction::Mul(_, _) - | Instruction::Add(_, _) - | Instruction::Setp(_, _) - | Instruction::SetpBool(_, _) - | Instruction::Not(_, _) - | Instruction::Cvt(_, _) - | Instruction::Shl(_, _) - | Instruction::St(_, _) - | Instruction::Ret(_) => None, + ast::Instruction::Bra(_, a) => Some(a.src), + ast::Instruction::Ld(_, _) + | ast::Instruction::Mov(_, _) + | ast::Instruction::Mul(_, _) + | ast::Instruction::Add(_, _) + | ast::Instruction::Setp(_, _) + | ast::Instruction::SetpBool(_, _) + | ast::Instruction::Not(_, _) + | ast::Instruction::Cvt(_, _) + | ast::Instruction::Shl(_, _) + | ast::Instruction::St(_, _) + | ast::Instruction::Ret(_) => None, } } } -struct Arg1 { - pub src: spirv::Word, -} +type Arg1 = ast::Arg1; impl Arg1 { fn visit_id)>( @@ -1108,10 +1106,7 @@ impl Arg1 { } } -struct Arg2 { - pub dst: spirv::Word, - pub src: spirv::Word, -} +type Arg2 = ast::Arg2; impl Arg2 { fn visit_id)>( @@ -1124,11 +1119,21 @@ impl Arg2 { } } -pub struct Arg2St { - pub src1: spirv::Word, - pub src2: spirv::Word, +type Arg2Mov = ast::Arg2Mov; + +impl Arg2Mov { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f(true, &mut self.dst, t); + f(false, &mut self.src, t); + } } +type Arg2St = ast::Arg2St; + impl Arg2St { fn visit_id)>( &mut self, @@ -1140,11 +1145,7 @@ impl Arg2St { } } -struct Arg3 { - pub dst: spirv::Word, - pub src1: spirv::Word, - pub src2: spirv::Word, -} +type Arg3 = ast::Arg3; impl Arg3 { fn visit_id)>( @@ -1158,12 +1159,7 @@ impl Arg3 { } } -struct Arg4 { - pub dst1: spirv::Word, - pub dst2: Option, - pub src1: spirv::Word, - pub src2: spirv::Word, -} +type Arg4 = ast::Arg4; impl Arg4 { fn visit_id)>( @@ -1188,13 +1184,7 @@ impl Arg4 { } } -struct Arg5 { - pub dst1: spirv::Word, - pub dst2: Option, - pub src1: spirv::Word, - pub src2: spirv::Word, - pub src3: spirv::Word, -} +type Arg5 = ast::Arg5; impl Arg5 { fn visit_id)>( @@ -1286,8 +1276,11 @@ impl ast::PredAt { } } -impl ast::Instruction { - fn map_id U>(self, f: &mut F) -> ast::Instruction { +impl<'a> ast::Instruction> { + fn map_id spirv::Word>( + self, + f: &mut F, + ) -> ast::Instruction { match self { ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)), ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)), @@ -1305,13 +1298,13 @@ impl ast::Instruction { } } -impl ast::Arg1 { - fn map_id U>(self, f: &mut F) -> ast::Arg1 { +impl<'a> ast::Arg1> { + fn map_id spirv::Word>(self, f: &mut F) -> ast::Arg1 { ast::Arg1 { src: f(self.src) } } } -impl ast::Arg1 { +impl ast::Arg1 { fn visit_id)>( &mut self, f: &mut F, @@ -1321,8 +1314,8 @@ impl ast::Arg1 { } } -impl ast::Arg2 { - fn map_id U>(self, f: &mut F) -> ast::Arg2 { +impl<'a> ast::Arg2> { + fn map_id spirv::Word>(self, f: &mut F) -> ast::Arg2 { ast::Arg2 { dst: f(self.dst), src: self.src.map_id(f), @@ -1330,7 +1323,7 @@ impl ast::Arg2 { } } -impl ast::Arg2 { +impl ast::Arg2 { fn visit_id)>( &mut self, f: &mut F, @@ -1341,8 +1334,11 @@ impl ast::Arg2 { } } -impl ast::Arg2St { - fn map_id U>(self, f: &mut F) -> ast::Arg2St { +impl<'a> ast::Arg2St> { + fn map_id spirv::Word>( + self, + f: &mut F, + ) -> ast::Arg2St { ast::Arg2St { src1: self.src1.map_id(f), src2: self.src2.map_id(f), @@ -1350,7 +1346,7 @@ impl ast::Arg2St { } } -impl ast::Arg2St { +impl ast::Arg2St { fn visit_id)>( &mut self, f: &mut F, @@ -1361,8 +1357,11 @@ impl ast::Arg2St { } } -impl ast::Arg2Mov { - fn map_id U>(self, f: &mut F) -> ast::Arg2Mov { +impl<'a> ast::Arg2Mov> { + fn map_id spirv::Word>( + self, + f: &mut F, + ) -> ast::Arg2Mov { ast::Arg2Mov { dst: f(self.dst), src: self.src.map_id(f), @@ -1370,7 +1369,7 @@ impl ast::Arg2Mov { } } -impl ast::Arg2Mov { +impl ast::Arg2Mov { fn visit_id)>( &mut self, f: &mut F, @@ -1381,8 +1380,8 @@ impl ast::Arg2Mov { } } -impl ast::Arg3 { - fn map_id U>(self, f: &mut F) -> ast::Arg3 { +impl<'a> ast::Arg3> { + fn map_id spirv::Word>(self, f: &mut F) -> ast::Arg3 { ast::Arg3 { dst: f(self.dst), src1: self.src1.map_id(f), @@ -1391,7 +1390,7 @@ impl ast::Arg3 { } } -impl ast::Arg3 { +impl ast::Arg3 { fn visit_id)>( &mut self, f: &mut F, @@ -1403,8 +1402,8 @@ impl ast::Arg3 { } } -impl ast::Arg4 { - fn map_id U>(self, f: &mut F) -> ast::Arg4 { +impl<'a> ast::Arg4> { + fn map_id spirv::Word>(self, f: &mut F) -> ast::Arg4 { ast::Arg4 { dst1: f(self.dst1), dst2: self.dst2.map(|i| f(i)), @@ -1414,7 +1413,7 @@ impl ast::Arg4 { } } -impl ast::Arg4 { +impl ast::Arg4 { fn visit_id)>( &mut self, f: &mut F, @@ -1437,8 +1436,8 @@ impl ast::Arg4 { } } -impl ast::Arg5 { - fn map_id U>(self, f: &mut F) -> ast::Arg5 { +impl<'a> ast::Arg5> { + fn map_id spirv::Word>(self, f: &mut F) -> ast::Arg5 { ast::Arg5 { dst1: f(self.dst1), dst2: self.dst2.map(|i| f(i)), @@ -1449,7 +1448,7 @@ impl ast::Arg5 { } } -impl ast::Arg5 { +impl ast::Arg5 { fn visit_id)>( &mut self, f: &mut F, @@ -1779,7 +1778,7 @@ fn insert_with_implicit_conversion_dst< T, ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option, Setter: Fn(&mut T) -> &mut spirv::Word, - ToInstruction: FnOnce(T) -> Instruction, + ToInstruction: FnOnce(T) -> ast::Instruction, >( func: &mut Vec, instr_type: ast::ScalarType, @@ -1907,7 +1906,7 @@ fn should_convert_relaxed_dst( fn insert_implicit_bitcasts( func: &mut Vec, id_def: &mut NumericIdResolver, - mut instr: Instruction, + mut instr: ast::Instruction, ) { let mut dst_coercion = None; instr.visit_id_extended(&mut |is_dst, id, id_type| {