Refactor various functions for visiting/mapping statements and instructions into one

This commit is contained in:
Andrzej Janik
2020-07-30 03:01:37 +02:00
parent 52faaab547
commit 66fa0706a4

View File

@ -156,16 +156,17 @@ fn emit_function<'a>(
let (mut func_body, unique_ids) = to_ssa(&f.args, f.body); let (mut func_body, unique_ids) = to_ssa(&f.args, f.body);
let id_offset = builder.reserve_ids(unique_ids); let id_offset = builder.reserve_ids(unique_ids);
emit_function_args(builder, id_offset, map, &f.args); emit_function_args(builder, id_offset, map, &f.args);
apply_id_offset(&mut func_body, id_offset); func_body = apply_id_offset(func_body, id_offset);
emit_function_body_ops(builder, map, opencl_id, &func_body)?; emit_function_body_ops(builder, map, opencl_id, &func_body)?;
builder.end_function()?; builder.end_function()?;
Ok(func_id) Ok(func_id)
} }
fn apply_id_offset(func_body: &mut Vec<ExpandedStatement>, id_offset: u32) { fn apply_id_offset(func_body: Vec<ExpandedStatement>, id_offset: u32) -> Vec<ExpandedStatement> {
for s in func_body { func_body
s.visit_id(&mut |id| *id += id_offset); .into_iter()
} .map(|s| s.visit_variable(&mut |id| id + id_offset))
.collect()
} }
fn to_ssa<'a, 'b>( fn to_ssa<'a, 'b>(
@ -274,32 +275,32 @@ fn insert_mem_ssa_statements(
) => { ) => {
result.push(Statement::Instruction(ast::Instruction::Ld(ld, arg))); result.push(Statement::Instruction(ast::Instruction::Ld(ld, arg)));
} }
mut inst => { inst => {
let mut post_statements = Vec::new(); let mut post_statements = Vec::new();
inst.visit_id(&mut |is_dst, id, id_type| { let inst = inst.visit_variable(&mut |id, is_dst, id_type| {
let id_type = match id_type { let id_type = match id_type {
Some(t) => t, Some(t) => t,
None => return, None => return id,
}; };
let generated_id = id_def.new_id(Some(id_type)); let generated_id = id_def.new_id(Some(id_type));
if !is_dst { if !is_dst {
result.push(Statement::LoadVar( result.push(Statement::LoadVar(
Arg2 { Arg2 {
dst: generated_id, dst: generated_id,
src: *id, src: id,
}, },
id_type, id_type,
)); ));
} else { } else {
post_statements.push(Statement::StoreVar( post_statements.push(Statement::StoreVar(
Arg2St { Arg2St {
src1: *id, src1: id,
src2: generated_id, src2: generated_id,
}, },
id_type, id_type,
)); ));
} }
*id = generated_id; generated_id
}); });
result.push(Statement::Instruction(inst)); result.push(Statement::Instruction(inst));
result.append(&mut post_statements); result.append(&mut post_statements);
@ -847,12 +848,12 @@ fn normalize_identifiers<'a, 'b>(
} }
let mut result = Vec::new(); let mut result = Vec::new();
for s in func { for s in func {
expand_map_ids(&mut id_defs, &mut result, s); expand_map_variables(&mut id_defs, &mut result, s);
} }
(result, id_defs.finish()) (result, id_defs.finish())
} }
fn expand_map_ids<'a>( fn expand_map_variables<'a>(
id_defs: &mut StringIdResolver<'a>, id_defs: &mut StringIdResolver<'a>,
result: &mut Vec<ast::Statement<NormalizedArgParams>>, result: &mut Vec<ast::Statement<NormalizedArgParams>>,
s: ast::Statement<ast::ParsedArgParams<'a>>, s: ast::Statement<ast::ParsedArgParams<'a>>,
@ -862,8 +863,8 @@ fn expand_map_ids<'a>(
result.push(ast::Statement::Label(id_defs.add_def(name, None))) result.push(ast::Statement::Label(id_defs.add_def(name, None)))
} }
ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction( ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction(
p.map(|p| p.map_id(&mut |id| id_defs.get_id(id))), p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))),
i.map_id(&mut |id| id_defs.get_id(id)), i.map_variable(&mut |id| id_defs.get_id(id)),
)), )),
ast::Statement::Variable(var) => match var.count { ast::Statement::Variable(var) => match var.count {
Some(count) => { Some(count) => {
@ -969,8 +970,8 @@ impl NumericIdResolver {
enum Statement<I> { enum Statement<I> {
Variable(spirv::Word, ast::Type, ast::StateSpace), Variable(spirv::Word, ast::Type, ast::StateSpace),
LoadVar(Arg2, ast::Type), LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(Arg2St, ast::Type), StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
Label(u32), Label(u32),
Instruction(I), Instruction(I),
// SPIR-V compatible replacement for PTX predicates // SPIR-V compatible replacement for PTX predicates
@ -980,16 +981,20 @@ enum Statement<I> {
} }
impl Statement<ast::Instruction<ExpandedArgParams>> { impl Statement<ast::Instruction<ExpandedArgParams>> {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) { fn visit_variable<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
match self { match self {
Statement::Variable(id, _, _) => f(id), Statement::Variable(id, t, ss) => Statement::Variable(f(id), t, ss),
Statement::LoadVar(a, _) => a.visit_id(&mut |_, id, _| f(id), None), Statement::LoadVar(a, t) => {
Statement::StoreVar(a, _) => a.visit_id(&mut |_, id, _| f(id), None), Statement::LoadVar(a.map(&mut reduced_visitor(f), Some(t)), t)
Statement::Label(id) => f(id), }
Statement::Instruction(inst) => inst.visit_id(f), Statement::StoreVar(a, t) => {
Statement::Conditional(bra) => bra.visit_id(&mut |_, id, _| f(id)), Statement::StoreVar(a.map(&mut reduced_visitor(f), Some(t)), t)
Statement::Conversion(conv) => conv.visit_id(f), }
Statement::Constant(cons) => cons.visit_id(f), Statement::Label(id) => Statement::Label(f(id)),
Statement::Instruction(inst) => Statement::Instruction(inst.visit_variable(f)),
Statement::Conditional(bra) => Statement::Conditional(bra.map(f)),
Statement::Conversion(conv) => Statement::Conversion(conv.map(f)),
Statement::Constant(cons) => Statement::Constant(cons.map(f)),
} }
} }
} }
@ -1012,69 +1017,211 @@ impl ast::ArgParams for ExpandedArgParams {
type MovOperand = spirv::Word; type MovOperand = spirv::Word;
} }
impl ast::Instruction<NormalizedArgParams> { trait ArgumentMapVisitor<T: ast::ArgParams, U: ast::ArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(&mut self, f: &mut F) { fn dst_variable(&mut self, v: T::ID, typ: Option<ast::Type>) -> U::ID;
match self { fn src_operand(&mut self, o: T::Operand, typ: Option<ast::Type>) -> U::Operand;
ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), fn src_mov_operand(&mut self, o: T::MovOperand, typ: Option<ast::Type>) -> U::MovOperand;
ast::Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)), }
ast::Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())),
ast::Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), struct FlattenArguments<'a> {
ast::Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), func: &'a mut Vec<ExpandedStatement>,
ast::Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), id_def: &'a mut NumericIdResolver,
ast::Instruction::Not(_, _) => todo!(), }
ast::Instruction::Cvt(_, _) => todo!(),
ast::Instruction::Shl(_, _) => todo!(), impl<'a> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> for FlattenArguments<'a> {
ast::Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), fn dst_variable(&mut self, x: spirv::Word, _: Option<ast::Type>) -> spirv::Word {
ast::Instruction::Bra(_, a) => a.visit_id(f, None), x
ast::Instruction::Ret(_) => (), }
fn src_operand(&mut self, op: ast::Operand<spirv::Word>, t: Option<ast::Type>) -> spirv::Word {
match op {
ast::Operand::Reg(r) => r,
ast::Operand::Imm(x) => {
if let Some(typ) = t {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
scalar
} else {
todo!()
};
let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
value: x,
}));
id
} else {
todo!()
}
}
_ => todo!(),
}
}
fn src_mov_operand(
&mut self,
op: ast::MovOperand<spirv::Word>,
t: Option<ast::Type>,
) -> spirv::Word {
match op {
ast::MovOperand::Op(opr) => self.src_operand(opr, t),
ast::MovOperand::Vec(_, _) => todo!(),
} }
} }
} }
impl ast::Instruction<ExpandedArgParams> { impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) { where
let f_visitor = &mut Self::typed_visitor(f); T: FnMut(spirv::Word, bool, Option<ast::Type>) -> spirv::Word,
match self { {
ast::Instruction::Ld(_, a) => a.visit_id(f_visitor, None), fn dst_variable(&mut self, x: spirv::Word, t: Option<ast::Type>) -> spirv::Word {
ast::Instruction::Mov(_, a) => a.visit_id(f_visitor, None), self(x, t.is_some(), t)
ast::Instruction::Mul(_, a) => a.visit_id(f_visitor, None), }
ast::Instruction::Add(_, a) => a.visit_id(f_visitor, None), fn src_operand(&mut self, x: spirv::Word, t: Option<ast::Type>) -> spirv::Word {
ast::Instruction::Setp(_, a) => a.visit_id(f_visitor, None), self(x, false, t)
ast::Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None), }
ast::Instruction::Not(_, a) => a.visit_id(f_visitor, None), fn src_mov_operand(&mut self, x: spirv::Word, t: Option<ast::Type>) -> spirv::Word {
ast::Instruction::Cvt(_, a) => a.visit_id(f_visitor, None), self(x, false, t)
ast::Instruction::Shl(_, a) => a.visit_id(f_visitor, None), }
ast::Instruction::St(_, a) => a.visit_id(f_visitor, None), }
ast::Instruction::Bra(_, a) => a.visit_id(f_visitor, None),
ast::Instruction::Ret(_) => (), impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> for T
where
T: FnMut(&str) -> spirv::Word,
{
fn dst_variable(&mut self, x: &str, _: Option<ast::Type>) -> spirv::Word {
self(x)
}
fn src_operand(
&mut self,
x: ast::Operand<&str>,
_: Option<ast::Type>,
) -> ast::Operand<spirv::Word> {
match x {
ast::Operand::Reg(id) => ast::Operand::Reg(self(id)),
ast::Operand::Imm(imm) => ast::Operand::Imm(imm),
ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id), imm),
} }
} }
fn typed_visitor<'a>( fn src_mov_operand(
f: &'a mut impl FnMut(&mut spirv::Word),
) -> impl FnMut(bool, &mut spirv::Word, Option<ast::Type>) + 'a {
move |_, id, _| f(id)
}
fn visit_id_extended<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self, &mut self,
f: &mut F, x: ast::MovOperand<&str>,
) { t: Option<ast::Type>,
) -> ast::MovOperand<spirv::Word> {
match x {
ast::MovOperand::Op(op) => ast::MovOperand::Op(self.src_operand(op, t)),
ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2),
}
}
}
impl<T: ast::ArgParams> ast::Instruction<T> {
fn map_variable_new<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
) -> ast::Instruction<U> {
match self { match self {
ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), ast::Instruction::Ld(d, a) => {
ast::Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)), let inst_type = d.typ;
ast::Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())), ast::Instruction::Ld(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
ast::Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), }
ast::Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), ast::Instruction::Mov(d, a) => {
ast::Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), let inst_type = d.typ;
ast::Instruction::Mov(d, a.map(visitor, Some(inst_type)))
}
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
ast::Instruction::Mul(d, a.map(visitor, Some(inst_type)))
}
ast::Instruction::Add(d, a) => {
let inst_type = d.get_type();
ast::Instruction::Add(d, a.map(visitor, Some(inst_type)))
}
ast::Instruction::Setp(d, a) => {
let inst_type = d.typ;
ast::Instruction::Setp(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
}
ast::Instruction::SetpBool(d, a) => {
let inst_type = d.typ;
ast::Instruction::SetpBool(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
}
ast::Instruction::Not(_, _) => todo!(), ast::Instruction::Not(_, _) => todo!(),
ast::Instruction::Cvt(_, _) => todo!(), ast::Instruction::Cvt(_, _) => todo!(),
ast::Instruction::Shl(_, _) => todo!(), ast::Instruction::Shl(_, _) => todo!(),
ast::Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), ast::Instruction::St(d, a) => {
ast::Instruction::Bra(_, a) => a.visit_id(f, None), let inst_type = d.typ;
ast::Instruction::Ret(_) => (), ast::Instruction::St(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
}
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
} }
} }
}
impl ast::Instruction<NormalizedArgParams> {
fn visit_variable<F: FnMut(spirv::Word, bool, Option<ast::Type>) -> spirv::Word>(
self,
f: &mut F,
) -> ast::Instruction<NormalizedArgParams> {
self.map_variable_new(f)
}
}
impl<T> ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams> for T
where
T: FnMut(spirv::Word, bool, Option<ast::Type>) -> spirv::Word,
{
fn dst_variable(&mut self, x: spirv::Word, t: Option<ast::Type>) -> spirv::Word {
self(x, t.is_some(), t)
}
fn src_operand(
&mut self,
x: ast::Operand<spirv::Word>,
t: Option<ast::Type>,
) -> ast::Operand<spirv::Word> {
match x {
ast::Operand::Reg(id) => ast::Operand::Reg(self(id, false, t)),
ast::Operand::Imm(imm) => ast::Operand::Imm(imm),
ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id, false, t), imm),
}
}
fn src_mov_operand(
&mut self,
x: ast::MovOperand<spirv::Word>,
t: Option<ast::Type>,
) -> ast::MovOperand<spirv::Word> {
match x {
ast::MovOperand::Op(op) => ast::MovOperand::Op(ArgumentMapVisitor::<
NormalizedArgParams,
NormalizedArgParams,
>::src_operand(self, op, t)),
ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2),
}
}
}
fn reduced_visitor<'a>(
f: &'a mut impl FnMut(spirv::Word) -> spirv::Word,
) -> impl FnMut(spirv::Word, bool, Option<ast::Type>) -> spirv::Word + 'a {
move |id, _, _| f(id)
}
impl ast::Instruction<ExpandedArgParams> {
fn visit_variable<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
let mut visitor = reduced_visitor(f);
self.map_variable_new(&mut visitor)
}
fn visit_variable_extended<F: FnMut(spirv::Word, bool, Option<ast::Type>) -> spirv::Word>(
self,
f: &mut F,
) -> Self {
self.map_variable_new(f)
}
fn jump_target(&self) -> Option<spirv::Word> { fn jump_target(&self) -> Option<spirv::Word> {
match self { match self {
@ -1094,126 +1241,9 @@ impl ast::Instruction<ExpandedArgParams> {
} }
} }
type Arg1 = ast::Arg1<ExpandedArgParams>;
impl Arg1 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(false, &mut self.src, t);
}
}
type Arg2 = ast::Arg2<ExpandedArgParams>; type Arg2 = ast::Arg2<ExpandedArgParams>;
impl Arg2 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(true, &mut self.dst, t);
f(false, &mut self.src, t);
}
}
type Arg2Mov = ast::Arg2Mov<ExpandedArgParams>;
impl Arg2Mov {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(true, &mut self.dst, t);
f(false, &mut self.src, t);
}
}
type Arg2St = ast::Arg2St<ExpandedArgParams>; type Arg2St = ast::Arg2St<ExpandedArgParams>;
impl Arg2St {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(false, &mut self.src1, t);
f(false, &mut self.src2, t);
}
}
type Arg3 = ast::Arg3<ExpandedArgParams>;
impl Arg3 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(true, &mut self.dst, t);
f(false, &mut self.src1, t);
f(false, &mut self.src2, t);
}
}
type Arg4 = ast::Arg4<ExpandedArgParams>;
impl Arg4 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(
true,
&mut self.dst1,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
self.dst2.as_mut().map(|dst2| {
f(
true,
dst2,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
});
f(false, &mut self.src1, t);
f(false, &mut self.src2, t);
}
}
type Arg5 = ast::Arg5<ExpandedArgParams>;
impl Arg5 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(
true,
&mut self.dst1,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
self.dst2.as_mut().map(|dst2| {
f(
true,
dst2,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
});
f(false, &mut self.src1, t);
f(false, &mut self.src2, t);
f(
false,
&mut self.src3,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
}
}
struct ConstantDefinition { struct ConstantDefinition {
pub dst: spirv::Word, pub dst: spirv::Word,
pub typ: ast::ScalarType, pub typ: ast::ScalarType,
@ -1221,8 +1251,12 @@ struct ConstantDefinition {
} }
impl ConstantDefinition { impl ConstantDefinition {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) { fn map<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
f(&mut self.dst); Self {
dst: f(self.dst),
typ: self.typ,
value: self.value,
}
} }
} }
@ -1233,14 +1267,12 @@ struct BrachCondition {
} }
impl BrachCondition { impl BrachCondition {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(&mut self, f: &mut F) { fn map<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
f( Self {
false, predicate: f(self.predicate),
&mut self.predicate, if_true: f(self.if_true),
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), if_false: f(self.if_false),
); }
f(false, &mut self.if_true, None);
f(false, &mut self.if_false, None);
} }
} }
@ -1261,14 +1293,19 @@ enum ConversionKind {
} }
impl ImplicitConversion { impl ImplicitConversion {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) { fn map<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
f(&mut self.dst); Self {
f(&mut self.src); src: f(self.src),
dst: f(self.dst),
from: self.from,
to: self.to,
kind: self.kind,
}
} }
} }
impl<T> ast::PredAt<T> { impl<T> ast::PredAt<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::PredAt<U> { fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::PredAt<U> {
ast::PredAt { ast::PredAt {
not: self.not, not: self.not,
label: f(self.label), label: f(self.label),
@ -1276,247 +1313,127 @@ impl<T> ast::PredAt<T> {
} }
} }
// REMOVE
impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> { impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>( fn map_variable<F: FnMut(&str) -> spirv::Word>(
self, self,
f: &mut F, f: &mut F,
) -> ast::Instruction<NormalizedArgParams> { ) -> ast::Instruction<NormalizedArgParams> {
match self { self.map_variable_new(f)
ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)), }
ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)), }
ast::Instruction::Mul(d, a) => ast::Instruction::Mul(d, a.map_id(f)),
ast::Instruction::Add(d, a) => ast::Instruction::Add(d, a.map_id(f)), impl<T: ast::ArgParams> ast::Arg1<T> {
ast::Instruction::Setp(d, a) => ast::Instruction::Setp(d, a.map_id(f)), fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
ast::Instruction::SetpBool(d, a) => ast::Instruction::SetpBool(d, a.map_id(f)), self,
ast::Instruction::Not(d, a) => ast::Instruction::Not(d, a.map_id(f)), visitor: &mut V,
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map_id(f)), t: Option<ast::Type>,
ast::Instruction::Cvt(d, a) => ast::Instruction::Cvt(d, a.map_id(f)), ) -> ast::Arg1<U> {
ast::Instruction::Shl(d, a) => ast::Instruction::Shl(d, a.map_id(f)), ast::Arg1 {
ast::Instruction::St(d, a) => ast::Instruction::St(d, a.map_id(f)), src: visitor.dst_variable(self.src, t),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
} }
} }
} }
impl<'a> ast::Arg1<ast::ParsedArgParams<'a>> { impl<T: ast::ArgParams> ast::Arg2<T> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg1<NormalizedArgParams> { fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
ast::Arg1 { src: f(self.src) } self,
} visitor: &mut V,
}
impl ast::Arg1<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>, t: Option<ast::Type>,
) { ) -> ast::Arg2<U> {
f(false, &mut self.src, t);
}
}
impl<'a> ast::Arg2<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg2<NormalizedArgParams> {
ast::Arg2 { ast::Arg2 {
dst: f(self.dst), dst: visitor.dst_variable(self.dst, t),
src: self.src.map_id(f), src: visitor.src_operand(self.src, t),
} }
} }
} }
impl ast::Arg2<NormalizedArgParams> { impl<T: ast::ArgParams> ast::Arg2St<T> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(true, &mut self.dst, t);
self.src.visit_id(f, t);
}
}
impl<'a> ast::Arg2St<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(
self, self,
f: &mut F, visitor: &mut V,
) -> ast::Arg2St<NormalizedArgParams> { t: Option<ast::Type>,
) -> ast::Arg2St<U> {
ast::Arg2St { ast::Arg2St {
src1: self.src1.map_id(f), src1: visitor.src_operand(self.src1, t),
src2: self.src2.map_id(f), src2: visitor.src_operand(self.src2, t),
} }
} }
} }
impl ast::Arg2St<NormalizedArgParams> { impl<T: ast::ArgParams> ast::Arg2Mov<T> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
self.src1.visit_id(f, t);
self.src2.visit_id(f, t);
}
}
impl<'a> ast::Arg2Mov<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(
self, self,
f: &mut F, visitor: &mut V,
) -> ast::Arg2Mov<NormalizedArgParams> { t: Option<ast::Type>,
) -> ast::Arg2Mov<U> {
ast::Arg2Mov { ast::Arg2Mov {
dst: f(self.dst), dst: visitor.dst_variable(self.dst, t),
src: self.src.map_id(f), src: visitor.src_mov_operand(self.src, t),
} }
} }
} }
impl ast::Arg2Mov<NormalizedArgParams> { impl<T: ast::ArgParams> ast::Arg3<T> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
&mut self, self,
f: &mut F, visitor: &mut V,
t: Option<ast::Type>, t: Option<ast::Type>,
) { ) -> ast::Arg3<U> {
f(true, &mut self.dst, t);
self.src.visit_id(f, t);
}
}
impl<'a> ast::Arg3<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg3<NormalizedArgParams> {
ast::Arg3 { ast::Arg3 {
dst: f(self.dst), dst: visitor.dst_variable(self.dst, t),
src1: self.src1.map_id(f), src1: visitor.src_operand(self.src1, t),
src2: self.src2.map_id(f), src2: visitor.src_operand(self.src2, t),
} }
} }
} }
impl ast::Arg3<NormalizedArgParams> { impl<T: ast::ArgParams> ast::Arg4<T> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
&mut self, self,
f: &mut F, visitor: &mut V,
t: Option<ast::Type>, t: Option<ast::Type>,
) { ) -> ast::Arg4<U> {
f(true, &mut self.dst, t);
self.src1.visit_id(f, t);
self.src2.visit_id(f, t);
}
}
impl<'a> ast::Arg4<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg4<NormalizedArgParams> {
ast::Arg4 { ast::Arg4 {
dst1: f(self.dst1), dst1: visitor.dst_variable(
dst2: self.dst2.map(|i| f(i)), self.dst1,
src1: self.src1.map_id(f), Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
src2: self.src2.map_id(f), ),
dst2: self.dst2.map(|dst2| {
visitor.dst_variable(
dst2,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
}),
src1: visitor.src_operand(self.src1, t),
src2: visitor.src_operand(self.src2, t),
} }
} }
} }
impl ast::Arg4<NormalizedArgParams> { impl<T: ast::ArgParams> ast::Arg5<T> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>( fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
&mut self, self,
f: &mut F, visitor: &mut V,
t: Option<ast::Type>, t: Option<ast::Type>,
) { ) -> ast::Arg5<U> {
f(
true,
&mut self.dst1,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
self.dst2.as_mut().map(|i| {
f(
true,
i,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
});
self.src1.visit_id(f, t);
self.src2.visit_id(f, t);
}
}
impl<'a> ast::Arg5<ast::ParsedArgParams<'a>> {
fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg5<NormalizedArgParams> {
ast::Arg5 { ast::Arg5 {
dst1: f(self.dst1), dst1: visitor.dst_variable(
dst2: self.dst2.map(|i| f(i)), self.dst1,
src1: self.src1.map_id(f),
src2: self.src2.map_id(f),
src3: self.src3.map_id(f),
}
}
}
impl ast::Arg5<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(
true,
&mut self.dst1,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
self.dst2.as_mut().map(|i| {
f(
true,
i,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
) ),
}); dst2: self.dst2.map(|dst2| {
self.src1.visit_id(f, t); visitor.dst_variable(
self.src2.visit_id(f, t); dst2,
self.src3.visit_id( Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
f, )
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), }),
); src1: visitor.src_operand(self.src1, t),
} src2: visitor.src_operand(self.src2, t),
} src3: visitor.src_operand(
self.src3,
impl<T> ast::Operand<T> { Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Operand<U> { ),
match self {
ast::Operand::Reg(i) => ast::Operand::Reg(f(i)),
ast::Operand::RegOffset(i, o) => ast::Operand::RegOffset(f(i), o),
ast::Operand::Imm(v) => ast::Operand::Imm(v),
}
}
}
impl<T: Copy> ast::Operand<T> {
fn visit_id<F: FnMut(bool, &mut T, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
match self {
ast::Operand::Reg(i) => f(false, i, t),
ast::Operand::RegOffset(i, _) => f(false, i, t),
ast::Operand::Imm(_) => (),
}
}
}
impl<T> ast::MovOperand<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::MovOperand<U> {
match self {
ast::MovOperand::Op(o) => ast::MovOperand::Op(o.map_id(f)),
ast::MovOperand::Vec(s1, s2) => ast::MovOperand::Vec(s1, s2),
}
}
}
impl<T: Copy> ast::MovOperand<T> {
fn visit_id<F: FnMut(bool, &mut T, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
match self {
ast::MovOperand::Op(o) => o.visit_id(f, t),
ast::MovOperand::Vec(_, _) => todo!(),
} }
} }
} }
@ -1906,34 +1823,37 @@ fn should_convert_relaxed_dst(
fn insert_implicit_bitcasts( fn insert_implicit_bitcasts(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
mut instr: ast::Instruction<ExpandedArgParams>, instr: ast::Instruction<ExpandedArgParams>,
) { ) {
let mut dst_coercion = None; let mut dst_coercion = None;
instr.visit_id_extended(&mut |is_dst, id, id_type| { let instr = instr.visit_variable_extended(&mut |mut id, is_dst, id_type| {
let id_type_from_instr = match id_type { let id_type_from_instr = match id_type {
Some(t) => t, Some(t) => t,
None => return, None => return id,
}; };
let id_actual_type = id_def.get_type(*id); let id_actual_type = id_def.get_type(id);
if should_bitcast(id_type_from_instr, id_def.get_type(*id)) { if should_bitcast(id_type_from_instr, id_def.get_type(id)) {
if is_dst { if is_dst {
dst_coercion = Some(get_conversion_dst( dst_coercion = Some(get_conversion_dst(
id_def, id_def,
id, &mut id,
id_type_from_instr, id_type_from_instr,
id_actual_type, id_actual_type,
ConversionKind::Default, ConversionKind::Default,
)); ));
id
} else { } else {
*id = insert_conversion_src( insert_conversion_src(
func, func,
id_def, id_def,
*id, id,
id_actual_type, id_actual_type,
id_type_from_instr, id_type_from_instr,
ConversionKind::Default, ConversionKind::Default,
); )
} }
} else {
id
} }
}); });
func.push(Statement::Instruction(instr)); func.push(Statement::Instruction(instr));