Refactor code to support per-variable type definition in callbacks

This commit is contained in:
Andrzej Janik
2020-07-28 00:37:16 +02:00
parent 04820fba2f
commit 72f5ffe2f9

View File

@ -164,7 +164,7 @@ fn emit_function<'a>(
fn apply_id_offset(func_body: &mut Vec<ExpandedStatement>, id_offset: u32) { fn apply_id_offset(func_body: &mut Vec<ExpandedStatement>, id_offset: u32) {
for s in func_body { for s in func_body {
s.visit_id_mut(&mut |_, id| *id += id_offset); s.visit_id(&mut |id| *id += id_offset);
} }
} }
@ -200,7 +200,7 @@ fn normalize_labels(
Statement::Variable(_, _, _) Statement::Variable(_, _, _)
| Statement::LoadVar(_, _) | Statement::LoadVar(_, _)
| Statement::StoreVar(_, _) | Statement::StoreVar(_, _)
| Statement::Converison(_) | Statement::Conversion(_)
| Statement::Constant(_) | Statement::Constant(_)
| Statement::Label(_) => (), | Statement::Label(_) => (),
} }
@ -275,18 +275,20 @@ fn insert_mem_ssa_statements(
result.push(Statement::Instruction(Instruction::Ld(ld, arg))); result.push(Statement::Instruction(Instruction::Ld(ld, arg)));
} }
mut inst => { mut inst => {
let inst_type = inst.get_type();
let mut post_statements = Vec::new(); let mut post_statements = Vec::new();
inst.visit_id_mut(&mut |is_dst, id| { inst.visit_id(&mut |is_dst, id, id_type| {
let inst_type = inst_type.unwrap(); let id_type = match id_type {
let generated_id = id_def.new_id(Some(inst_type)); Some(t) => t,
None => return,
};
let generated_id = id_def.new_id(Some(id_type));
if !is_dst { if !is_dst {
result.push(Statement::LoadVar( result.push(Statement::LoadVar(
Arg2 { Arg2 {
dst: generated_id, dst: generated_id,
src: *id, src: *id,
}, },
inst_type, id_type,
)); ));
} else { } else {
post_statements.push(Statement::StoreVar( post_statements.push(Statement::StoreVar(
@ -294,7 +296,7 @@ fn insert_mem_ssa_statements(
src1: *id, src1: *id,
src2: generated_id, src2: generated_id,
}, },
inst_type, id_type,
)); ));
} }
*id = generated_id; *id = generated_id;
@ -308,7 +310,7 @@ fn insert_mem_ssa_statements(
| s @ Statement::Conditional(_) => result.push(s), | s @ Statement::Conditional(_) => result.push(s),
Statement::LoadVar(_, _) Statement::LoadVar(_, _)
| Statement::StoreVar(_, _) | Statement::StoreVar(_, _)
| Statement::Converison(_) | Statement::Conversion(_)
| Statement::Constant(_) => unreachable!(), | Statement::Constant(_) => unreachable!(),
} }
} }
@ -331,7 +333,7 @@ fn expand_arguments(
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)), Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
Statement::Converison(_) | Statement::Constant(_) => unreachable!(), Statement::Conversion(_) | Statement::Constant(_) => unreachable!(),
} }
} }
result result
@ -572,7 +574,7 @@ fn insert_implicit_conversions(
| s @ Statement::Variable(_, _, _) | s @ Statement::Variable(_, _, _)
| s @ Statement::LoadVar(_, _) | s @ Statement::LoadVar(_, _)
| s @ Statement::StoreVar(_, _) => result.push(s), | s @ Statement::StoreVar(_, _) => result.push(s),
Statement::Converison(_) => unreachable!(), Statement::Conversion(_) => unreachable!(),
} }
} }
result result
@ -660,7 +662,7 @@ fn emit_function_body_ops(
_ => unreachable!(), _ => unreachable!(),
} }
} }
Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?, Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
Statement::Conditional(bra) => { Statement::Conditional(bra) => {
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?; builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
} }
@ -973,38 +975,33 @@ enum Statement<A: Args> {
Instruction(Instruction<A>), Instruction(Instruction<A>),
// SPIR-V compatible replacement for PTX predicates // SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition), Conditional(BrachCondition),
Converison(ImplicitConversion), Conversion(ImplicitConversion),
Constant(ConstantDefinition), Constant(ConstantDefinition),
} }
impl<A: Args> Statement<A> { impl Statement<ExpandedArgs> {
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
match self { match self {
Statement::Variable(id, _, _) => f(true, id), Statement::Variable(id, _, _) => f(id),
Statement::LoadVar(a, _) => a.visit_id_mut(f), Statement::LoadVar(a, _) => a.visit_id(&mut |_, id, _| f(id), None),
Statement::StoreVar(a, _) => a.visit_id_mut(f), Statement::StoreVar(a, _) => a.visit_id(&mut |_, id, _| f(id), None),
Statement::Label(id) => f(false, id), Statement::Label(id) => f(id),
Statement::Instruction(inst) => inst.visit_id_mut(f), Statement::Instruction(inst) => inst.visit_id(f),
Statement::Conditional(bra) => bra.visit_id_mut(f), Statement::Conditional(bra) => bra.visit_id(&mut |_, id, _| f(id)),
Statement::Converison(conv) => conv.visit_id_mut(f), Statement::Conversion(conv) => conv.visit_id(f),
Statement::Constant(cons) => cons.visit_id_mut(f), Statement::Constant(cons) => cons.visit_id(f),
} }
} }
} }
trait Args { trait Args {
type Arg1: Arg; type Arg1;
type Arg2: Arg; type Arg2;
type Arg2St: Arg; type Arg2St;
type Arg2Mov: Arg; type Arg2Mov;
type Arg3: Arg; type Arg3;
type Arg4: Arg; type Arg4;
type Arg5: Arg; type Arg5;
}
trait Arg {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F);
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F);
} }
enum NormalizedArgs {} enum NormalizedArgs {}
@ -1049,48 +1046,24 @@ enum Instruction<A: Args> {
Ret(ast::RetData), Ret(ast::RetData),
} }
impl<A: Args> Instruction<A> { impl Instruction<NormalizedArgs> {
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(&mut self, f: &mut F) {
match self { match self {
Instruction::Ld(_, a) => a.visit_id_mut(f), Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Mov(_, a) => a.visit_id_mut(f), Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)),
Instruction::Mul(_, a) => a.visit_id_mut(f), Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Add(_, a) => a.visit_id_mut(f), Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Setp(_, a) => a.visit_id_mut(f), Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::SetpBool(_, a) => a.visit_id_mut(f), Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Not(_, a) => a.visit_id_mut(f), Instruction::Not(_, _) => todo!(),
Instruction::Cvt(_, a) => a.visit_id_mut(f), Instruction::Cvt(_, _) => todo!(),
Instruction::Shl(_, a) => a.visit_id_mut(f), Instruction::Shl(_, _) => todo!(),
Instruction::St(_, a) => a.visit_id_mut(f), Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Bra(_, a) => a.visit_id_mut(f), Instruction::Bra(_, a) => a.visit_id(f, None),
Instruction::Ret(_) => (), Instruction::Ret(_) => (),
} }
} }
fn get_type(&self) -> Option<ast::Type> {
match self {
Instruction::Add(add, _) => match add {
ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => {
Some(ast::Type::Scalar((*typ).into()))
}
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => Some((*typ).into()),
},
Instruction::Ret(_) => None,
Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)),
Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)),
Instruction::Mov(mov, _) => Some(mov.typ),
Instruction::Mul(mul, _) => match mul {
ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => {
Some(ast::Type::Scalar((*typ).into()))
}
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => Some((*typ).into()),
},
_ => todo!(),
}
}
}
impl Instruction<NormalizedArgs> {
fn from_ast(s: ast::Instruction<spirv::Word>) -> Self { fn from_ast(s: ast::Instruction<spirv::Word>) -> Self {
match s { match s {
ast::Instruction::Ld(d, a) => Instruction::Ld(d, a), ast::Instruction::Ld(d, a) => Instruction::Ld(d, a),
@ -1110,6 +1083,50 @@ impl Instruction<NormalizedArgs> {
} }
impl Instruction<ExpandedArgs> { impl Instruction<ExpandedArgs> {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
let f_visitor = &mut Self::typed_visitor(f);
match self {
Instruction::Ld(_, a) => a.visit_id(f_visitor, None),
Instruction::Mov(_, a) => a.visit_id(f_visitor, None),
Instruction::Mul(_, a) => a.visit_id(f_visitor, None),
Instruction::Add(_, a) => a.visit_id(f_visitor, None),
Instruction::Setp(_, a) => a.visit_id(f_visitor, None),
Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None),
Instruction::Not(_, a) => a.visit_id(f_visitor, None),
Instruction::Cvt(_, a) => a.visit_id(f_visitor, None),
Instruction::Shl(_, a) => a.visit_id(f_visitor, None),
Instruction::St(_, a) => a.visit_id(f_visitor, None),
Instruction::Bra(_, a) => a.visit_id(f_visitor, None),
Instruction::Ret(_) => (),
}
}
fn typed_visitor<'a>(
f: &'a mut impl FnMut(&mut spirv::Word),
) -> impl FnMut(bool, &mut spirv::Word, Option<ast::Type>) + 'a {
move |_, id, _| f(id)
}
fn visit_id_extended<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
) {
match self {
Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)),
Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Not(_, a) => todo!(),
Instruction::Cvt(_, a) => todo!(),
Instruction::Shl(_, a) => todo!(),
Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Bra(_, a) => a.visit_id(f, None),
Instruction::Ret(_) => (),
}
}
fn jump_target(&self) -> Option<spirv::Word> { fn jump_target(&self) -> Option<spirv::Word> {
match self { match self {
Instruction::Bra(_, a) => Some(a.src), Instruction::Bra(_, a) => Some(a.src),
@ -1132,13 +1149,13 @@ struct Arg1 {
pub src: spirv::Word, pub src: spirv::Word,
} }
impl Arg for Arg1 { impl Arg1 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
f(false, self.src); &mut self,
} f: &mut F,
t: Option<ast::Type>,
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { ) {
f(false, &mut self.src); f(false, &mut self.src, t);
} }
} }
@ -1147,15 +1164,14 @@ struct Arg2 {
pub src: spirv::Word, pub src: spirv::Word,
} }
impl Arg for Arg2 { impl Arg2 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
f(true, self.dst); &mut self,
f(false, self.src); f: &mut F,
} t: Option<ast::Type>,
) {
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { f(true, &mut self.dst, t);
f(false, &mut self.src); f(false, &mut self.src, t);
f(true, &mut self.dst);
} }
} }
@ -1164,15 +1180,14 @@ pub struct Arg2St {
pub src2: spirv::Word, pub src2: spirv::Word,
} }
impl Arg for Arg2St { impl Arg2St {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
f(false, self.src1); &mut self,
f(false, self.src2); f: &mut F,
} t: Option<ast::Type>,
) {
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { f(false, &mut self.src1, t);
f(false, &mut self.src1); f(false, &mut self.src2, t);
f(false, &mut self.src2);
} }
} }
@ -1182,17 +1197,15 @@ struct Arg3 {
pub src2: spirv::Word, pub src2: spirv::Word,
} }
impl Arg for Arg3 { impl Arg3 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
f(true, self.dst); &mut self,
f(false, self.src1); f: &mut F,
f(false, self.src2); t: Option<ast::Type>,
} ) {
f(true, &mut self.dst, t);
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { f(false, &mut self.src1, t);
f(false, &mut self.src1); f(false, &mut self.src2, t);
f(false, &mut self.src2);
f(true, &mut self.dst);
} }
} }
@ -1203,19 +1216,26 @@ struct Arg4 {
pub src2: spirv::Word, pub src2: spirv::Word,
} }
impl Arg for Arg4 { impl Arg4 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
f(true, self.dst1); &mut self,
self.dst2.map(|dst2| f(true, dst2)); f: &mut F,
f(false, self.src1); t: Option<ast::Type>,
f(false, self.src2); ) {
} f(
true,
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { &mut self.dst1,
f(false, &mut self.src1); Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
f(false, &mut self.src2); );
f(true, &mut self.dst1); self.dst2.as_mut().map(|dst2| {
self.dst2.as_mut().map(|dst2| f(true, dst2)); f(
true,
dst2,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
});
f(false, &mut self.src1, t);
f(false, &mut self.src2, t);
} }
} }
@ -1227,21 +1247,31 @@ struct Arg5 {
pub src3: spirv::Word, pub src3: spirv::Word,
} }
impl Arg for Arg5 { impl Arg5 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
f(true, self.dst1); &mut self,
self.dst2.map(|dst2| f(true, dst2)); f: &mut F,
f(false, self.src1); t: Option<ast::Type>,
f(false, self.src2); ) {
f(false, self.src3); f(
} true,
&mut self.dst1,
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
f(false, &mut self.src1); );
f(false, &mut self.src2); self.dst2.as_mut().map(|dst2| {
f(false, &mut self.src3); f(
f(true, &mut self.dst1); true,
self.dst2.as_mut().map(|dst2| f(true, dst2)); dst2,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
});
f(false, &mut self.src1, t);
f(false, &mut self.src2, t);
f(
false,
&mut self.src3,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
} }
} }
@ -1252,12 +1282,8 @@ struct ConstantDefinition {
} }
impl ConstantDefinition { impl ConstantDefinition {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
f(true, self.dst); f(&mut self.dst);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(true, &mut self.dst);
} }
} }
@ -1268,16 +1294,14 @@ struct BrachCondition {
} }
impl BrachCondition { impl BrachCondition {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(&mut self, f: &mut F) {
f(false, self.predicate); f(
f(false, self.if_true); false,
f(false, self.if_false); &mut self.predicate,
} Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { f(false, &mut self.if_true, None);
f(false, &mut self.predicate); f(false, &mut self.if_false, None);
f(false, &mut self.if_true);
f(false, &mut self.if_false);
} }
} }
@ -1298,14 +1322,9 @@ enum ConversionKind {
} }
impl ImplicitConversion { impl ImplicitConversion {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
f(false, self.src); f(&mut self.dst);
f(true, self.dst); f(&mut self.src);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(false, &mut self.src);
f(true, &mut self.dst);
} }
} }
@ -1343,13 +1362,13 @@ impl<T> ast::Arg1<T> {
} }
} }
impl Arg for ast::Arg1<spirv::Word> { impl ast::Arg1<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
f(false, self.src); &mut self,
} f: &mut F,
t: Option<ast::Type>,
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { ) {
f(false, &mut self.src); f(false, &mut self.src, t);
} }
} }
@ -1362,15 +1381,14 @@ impl<T> ast::Arg2<T> {
} }
} }
impl Arg for ast::Arg2<spirv::Word> { impl ast::Arg2<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
f(true, self.dst); &mut self,
self.src.visit_id(f); f: &mut F,
} t: Option<ast::Type>,
) {
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { f(true, &mut self.dst, t);
self.src.visit_id_mut(f); self.src.visit_id(f, t);
f(true, &mut self.dst);
} }
} }
@ -1383,15 +1401,14 @@ impl<T> ast::Arg2St<T> {
} }
} }
impl Arg for ast::Arg2St<spirv::Word> { impl ast::Arg2St<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
self.src1.visit_id(f); &mut self,
self.src2.visit_id(f); f: &mut F,
} t: Option<ast::Type>,
) {
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { self.src1.visit_id(f, t);
self.src1.visit_id_mut(f); self.src2.visit_id(f, t);
self.src2.visit_id_mut(f);
} }
} }
@ -1404,15 +1421,14 @@ impl<T> ast::Arg2Mov<T> {
} }
} }
impl Arg for ast::Arg2Mov<spirv::Word> { impl ast::Arg2Mov<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
f(true, self.dst); &mut self,
self.src.visit_id(f); f: &mut F,
} t: Option<ast::Type>,
) {
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { f(true, &mut self.dst, t);
self.src.visit_id_mut(f); self.src.visit_id(f, t);
f(true, &mut self.dst);
} }
} }
@ -1426,17 +1442,15 @@ impl<T> ast::Arg3<T> {
} }
} }
impl Arg for ast::Arg3<spirv::Word> { impl ast::Arg3<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
f(true, self.dst); &mut self,
self.src1.visit_id(f); f: &mut F,
self.src2.visit_id(f); t: Option<ast::Type>,
} ) {
f(true, &mut self.dst, t);
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { self.src1.visit_id(f, t);
self.src1.visit_id_mut(f); self.src2.visit_id(f, t);
self.src2.visit_id_mut(f);
f(true, &mut self.dst);
} }
} }
@ -1451,19 +1465,26 @@ impl<T> ast::Arg4<T> {
} }
} }
impl Arg for ast::Arg4<spirv::Word> { impl ast::Arg4<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
f(true, self.dst1); &mut self,
self.dst2.map(|i| f(true, i)); f: &mut F,
self.src1.visit_id(f); t: Option<ast::Type>,
self.src2.visit_id(f); ) {
} f(
true,
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { &mut self.dst1,
self.src1.visit_id_mut(f); Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
self.src2.visit_id_mut(f); );
f(true, &mut self.dst1); self.dst2.as_mut().map(|i| {
self.dst2.as_mut().map(|i| f(true, i)); f(
true,
i,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
});
self.src1.visit_id(f, t);
self.src2.visit_id(f, t);
} }
} }
@ -1479,21 +1500,30 @@ impl<T> ast::Arg5<T> {
} }
} }
impl Arg for ast::Arg5<spirv::Word> { impl ast::Arg5<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
f(true, self.dst1); &mut self,
self.dst2.map(|i| f(true, i)); f: &mut F,
self.src1.visit_id(f); t: Option<ast::Type>,
self.src2.visit_id(f); ) {
self.src3.visit_id(f); f(
} true,
&mut self.dst1,
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) { Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
self.src1.visit_id_mut(f); );
self.src2.visit_id_mut(f); self.dst2.as_mut().map(|i| {
self.src3.visit_id_mut(f); f(
f(true, &mut self.dst1); true,
self.dst2.as_mut().map(|i| f(true, i)); i,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
});
self.src1.visit_id(f, t);
self.src2.visit_id(f, t);
self.src3.visit_id(
f,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
} }
} }
@ -1508,18 +1538,14 @@ impl<T> ast::Operand<T> {
} }
impl<T: Copy> ast::Operand<T> { impl<T: Copy> ast::Operand<T> {
fn visit_id<F: FnMut(bool, T)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut T, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
match self { match self {
ast::Operand::Reg(i) => f(false, *i), ast::Operand::Reg(i) => f(false, i, t),
ast::Operand::RegOffset(i, _) => f(false, *i), ast::Operand::RegOffset(i, _) => f(false, i, t),
ast::Operand::Imm(_) => (),
}
}
fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
match self {
ast::Operand::Reg(i) => f(false, i),
ast::Operand::RegOffset(i, _) => f(false, i),
ast::Operand::Imm(_) => (), ast::Operand::Imm(_) => (),
} }
} }
@ -1535,16 +1561,13 @@ impl<T> ast::MovOperand<T> {
} }
impl<T: Copy> ast::MovOperand<T> { impl<T: Copy> ast::MovOperand<T> {
fn visit_id<F: FnMut(bool, T)>(&self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut T, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
match self { match self {
ast::MovOperand::Op(o) => o.visit_id(f), ast::MovOperand::Op(o) => o.visit_id(f, t),
ast::MovOperand::Vec(_, _) => todo!(),
}
}
fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
match self {
ast::MovOperand::Op(o) => o.visit_id_mut(f),
ast::MovOperand::Vec(_, _) => todo!(), ast::MovOperand::Vec(_, _) => todo!(),
} }
} }
@ -1793,7 +1816,7 @@ fn insert_conversion_src(
conv: ConversionKind, conv: ConversionKind,
) -> spirv::Word { ) -> spirv::Word {
let temp_src = id_def.new_id(Some(instr_type)); let temp_src = id_def.new_id(Some(instr_type));
func.push(Statement::Converison(ImplicitConversion { func.push(Statement::Conversion(ImplicitConversion {
src: src, src: src,
dst: temp_src, dst: temp_src,
from: src_type, from: src_type,
@ -1838,7 +1861,7 @@ fn get_conversion_dst(
let original_dst = *dst; let original_dst = *dst;
let temp_dst = id_def.new_id(Some(instr_type)); let temp_dst = id_def.new_id(Some(instr_type));
*dst = temp_dst; *dst = temp_dst;
Statement::Converison(ImplicitConversion { Statement::Conversion(ImplicitConversion {
src: temp_dst, src: temp_dst,
dst: original_dst, dst: original_dst,
from: instr_type, from: instr_type,
@ -1938,31 +1961,33 @@ fn insert_implicit_bitcasts(
mut instr: Instruction<ExpandedArgs>, mut instr: Instruction<ExpandedArgs>,
) { ) {
let mut dst_coercion = None; let mut dst_coercion = None;
if let Some(instr_type) = instr.get_type() { instr.visit_id_extended(&mut |is_dst, id, id_type| {
instr.visit_id_mut(&mut |is_dst, id| { let id_type_from_instr = match id_type {
let id_type = id_def.get_type(*id); Some(t) => t,
if should_bitcast(instr_type, id_def.get_type(*id)) { None => return,
if is_dst { };
dst_coercion = Some(get_conversion_dst( let id_actual_type = id_def.get_type(*id);
id_def, if should_bitcast(id_type_from_instr, id_def.get_type(*id)) {
id, if is_dst {
instr_type, dst_coercion = Some(get_conversion_dst(
id_type, id_def,
ConversionKind::Default, id,
)); id_type_from_instr,
} else { id_actual_type,
*id = insert_conversion_src( ConversionKind::Default,
func, ));
id_def, } else {
*id, *id = insert_conversion_src(
id_type, func,
instr_type, id_def,
ConversionKind::Default, *id,
); id_actual_type,
} id_type_from_instr,
ConversionKind::Default,
);
} }
}); }
} });
func.push(Statement::Instruction(instr)); func.push(Statement::Instruction(instr));
if let Some(cond) = dst_coercion { if let Some(cond) = dst_coercion {
func.push(cond); func.push(cond);