Parametrize Statement by Instruction and not by Arguments

This commit is contained in:
Andrzej Janik
2020-07-28 00:59:41 +02:00
parent 72f5ffe2f9
commit d514a5610a

View File

@ -182,9 +182,9 @@ fn to_ssa<'a>(
} }
fn normalize_labels( fn normalize_labels(
func: Vec<Statement<ExpandedArgs>>, func: Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
) -> Vec<Statement<ExpandedArgs>> { ) -> Vec<ExpandedStatement> {
let mut labels_in_use = HashSet::new(); let mut labels_in_use = HashSet::new();
for s in func.iter() { for s in func.iter() {
match s { match s {
@ -240,11 +240,11 @@ fn normalize_predicates(
result.push(Statement::Conditional(branch)); result.push(Statement::Conditional(branch));
if folded_bra.is_none() { if folded_bra.is_none() {
result.push(Statement::Label(if_true)); result.push(Statement::Label(if_true));
result.push(Statement::Instruction(Instruction::from_ast(inst))); result.push(Statement::Instruction(inst));
} }
result.push(Statement::Label(if_false)); result.push(Statement::Label(if_false));
} else { } else {
result.push(Statement::Instruction(Instruction::from_ast(inst))); result.push(Statement::Instruction(inst));
} }
} }
ast::Statement::Variable(var) => { ast::Statement::Variable(var) => {
@ -263,7 +263,7 @@ fn insert_mem_ssa_statements(
for s in func { for s in func {
match s { match s {
Statement::Instruction(inst) => match inst { Statement::Instruction(inst) => match inst {
Instruction::Ld( ast::Instruction::Ld(
ld ld
@ @
ast::LdData { ast::LdData {
@ -272,7 +272,7 @@ fn insert_mem_ssa_statements(
}, },
arg, arg,
) => { ) => {
result.push(Statement::Instruction(Instruction::Ld(ld, arg))); result.push(Statement::Instruction(ast::Instruction::Ld(ld, arg)));
} }
mut inst => { mut inst => {
let mut post_statements = Vec::new(); let mut post_statements = Vec::new();
@ -343,51 +343,51 @@ fn expand_arguments(
fn normalize_insert_instruction( fn normalize_insert_instruction(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
instr: Instruction<NormalizedArgs>, instr: ast::Instruction<spirv::Word>,
) -> Instruction<ExpandedArgs> { ) -> Instruction {
match instr { match instr {
Instruction::Ld(d, a) => { ast::Instruction::Ld(d, a) => {
let arg = normalize_expand_arg2(func, id_def, &|| Some(d.typ), a); let arg = normalize_expand_arg2(func, id_def, &|| Some(d.typ), a);
Instruction::Ld(d, arg) Instruction::Ld(d, arg)
} }
Instruction::Mov(d, a) => { ast::Instruction::Mov(d, a) => {
let arg = normalize_expand_arg2mov(func, id_def, &|| d.typ.try_as_scalar(), a); let arg = normalize_expand_arg2mov(func, id_def, &|| d.typ.try_as_scalar(), a);
Instruction::Mov(d, arg) Instruction::Mov(d, arg)
} }
Instruction::Mul(d, a) => { ast::Instruction::Mul(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a); let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a);
Instruction::Mul(d, arg) Instruction::Mul(d, arg)
} }
Instruction::Add(d, a) => { ast::Instruction::Add(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a); let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a);
Instruction::Add(d, arg) Instruction::Add(d, arg)
} }
Instruction::Setp(d, a) => { ast::Instruction::Setp(d, a) => {
let arg = normalize_expand_arg4(func, id_def, &|| Some(d.typ), a); let arg = normalize_expand_arg4(func, id_def, &|| Some(d.typ), a);
Instruction::Setp(d, arg) Instruction::Setp(d, arg)
} }
Instruction::SetpBool(d, a) => { ast::Instruction::SetpBool(d, a) => {
let arg = normalize_expand_arg5(func, id_def, &|| Some(d.typ), a); let arg = normalize_expand_arg5(func, id_def, &|| Some(d.typ), a);
Instruction::SetpBool(d, arg) Instruction::SetpBool(d, arg)
} }
Instruction::Not(d, a) => { ast::Instruction::Not(d, a) => {
let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a); let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a);
Instruction::Not(d, arg) Instruction::Not(d, arg)
} }
Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }), ast::Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }),
Instruction::Cvt(d, a) => { ast::Instruction::Cvt(d, a) => {
let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a); let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a);
Instruction::Cvt(d, arg) Instruction::Cvt(d, arg)
} }
Instruction::Shl(d, a) => { ast::Instruction::Shl(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| todo!(), a); let arg = normalize_expand_arg3(func, id_def, &|| todo!(), a);
Instruction::Shl(d, arg) Instruction::Shl(d, arg)
} }
Instruction::St(d, a) => { ast::Instruction::St(d, a) => {
let arg = normalize_expand_arg2st(func, id_def, &|| todo!(), a); let arg = normalize_expand_arg2st(func, id_def, &|| todo!(), a);
Instruction::St(d, arg) Instruction::St(d, arg)
} }
Instruction::Ret(d) => Instruction::Ret(d), ast::Instruction::Ret(d) => Instruction::Ret(d),
} }
} }
@ -967,19 +967,19 @@ impl NumericIdResolver {
} }
} }
enum Statement<A: Args> { enum Statement<I> {
Variable(spirv::Word, ast::Type, ast::StateSpace), Variable(spirv::Word, ast::Type, ast::StateSpace),
LoadVar(Arg2, ast::Type), LoadVar(Arg2, ast::Type),
StoreVar(Arg2St, ast::Type), StoreVar(Arg2St, ast::Type),
Label(u32), Label(u32),
Instruction(Instruction<A>), Instruction(I),
// SPIR-V compatible replacement for PTX predicates // SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition), Conditional(BrachCondition),
Conversion(ImplicitConversion), Conversion(ImplicitConversion),
Constant(ConstantDefinition), Constant(ConstantDefinition),
} }
impl Statement<ExpandedArgs> { impl Statement<Instruction> {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) { fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
match self { match self {
Statement::Variable(id, _, _) => f(id), Statement::Variable(id, _, _) => f(id),
@ -994,95 +994,44 @@ impl Statement<ExpandedArgs> {
} }
} }
trait Args { type NormalizedStatement = Statement<ast::Instruction<spirv::Word>>;
type Arg1; type ExpandedStatement = Statement<Instruction>;
type Arg2;
type Arg2St;
type Arg2Mov;
type Arg3;
type Arg4;
type Arg5;
}
enum NormalizedArgs {} enum Instruction {
Ld(ast::LdData, Arg2),
impl Args for NormalizedArgs { Mov(ast::MovData, Arg2),
type Arg1 = ast::Arg1<spirv::Word>; Mul(ast::MulDetails, Arg3),
type Arg2 = ast::Arg2<spirv::Word>; Add(ast::AddDetails, Arg3),
type Arg2St = ast::Arg2St<spirv::Word>; Setp(ast::SetpData, Arg4),
type Arg2Mov = ast::Arg2Mov<spirv::Word>; SetpBool(ast::SetpBoolData, Arg5),
type Arg3 = ast::Arg3<spirv::Word>; Not(ast::NotData, Arg2),
type Arg4 = ast::Arg4<spirv::Word>; Bra(ast::BraData, Arg1),
type Arg5 = ast::Arg5<spirv::Word>; Cvt(ast::CvtData, Arg2),
} Shl(ast::ShlData, Arg3),
St(ast::StData, Arg2St),
enum ExpandedArgs {}
impl Args for ExpandedArgs {
type Arg1 = Arg1;
type Arg2 = Arg2;
type Arg2St = Arg2St;
type Arg2Mov = Arg2;
type Arg3 = Arg3;
type Arg4 = Arg4;
type Arg5 = Arg5;
}
type NormalizedStatement = Statement<NormalizedArgs>;
type ExpandedStatement = Statement<ExpandedArgs>;
enum Instruction<A: Args> {
Ld(ast::LdData, A::Arg2),
Mov(ast::MovData, A::Arg2Mov),
Mul(ast::MulDetails, A::Arg3),
Add(ast::AddDetails, A::Arg3),
Setp(ast::SetpData, A::Arg4),
SetpBool(ast::SetpBoolData, A::Arg5),
Not(ast::NotData, A::Arg2),
Bra(ast::BraData, A::Arg1),
Cvt(ast::CvtData, A::Arg2),
Shl(ast::ShlData, A::Arg3),
St(ast::StData, A::Arg2St),
Ret(ast::RetData), Ret(ast::RetData),
} }
impl Instruction<NormalizedArgs> { impl ast::Instruction<spirv::Word> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(&mut self, f: &mut F) { fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(&mut self, f: &mut F) {
match self { match self {
Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)), ast::Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)),
Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())), ast::Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), ast::Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), ast::Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), ast::Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Not(_, _) => todo!(), ast::Instruction::Not(_, _) => todo!(),
Instruction::Cvt(_, _) => todo!(), ast::Instruction::Cvt(_, _) => todo!(),
Instruction::Shl(_, _) => todo!(), ast::Instruction::Shl(_, _) => todo!(),
Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), ast::Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Bra(_, a) => a.visit_id(f, None), ast::Instruction::Bra(_, a) => a.visit_id(f, None),
Instruction::Ret(_) => (), ast::Instruction::Ret(_) => (),
}
}
fn from_ast(s: ast::Instruction<spirv::Word>) -> Self {
match s {
ast::Instruction::Ld(d, a) => Instruction::Ld(d, a),
ast::Instruction::Mov(d, a) => Instruction::Mov(d, a),
ast::Instruction::Mul(d, a) => Instruction::Mul(d, a),
ast::Instruction::Add(d, a) => Instruction::Add(d, a),
ast::Instruction::Setp(d, a) => Instruction::Setp(d, a),
ast::Instruction::SetpBool(d, a) => Instruction::SetpBool(d, a),
ast::Instruction::Not(d, a) => Instruction::Not(d, a),
ast::Instruction::Cvt(d, a) => Instruction::Cvt(d, a),
ast::Instruction::Shl(d, a) => Instruction::Shl(d, a),
ast::Instruction::St(d, a) => Instruction::St(d, a),
ast::Instruction::Bra(d, a) => Instruction::Bra(d, a),
ast::Instruction::Ret(d) => Instruction::Ret(d),
} }
} }
} }
impl Instruction<ExpandedArgs> { impl Instruction {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) { fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
let f_visitor = &mut Self::typed_visitor(f); let f_visitor = &mut Self::typed_visitor(f);
match self { match self {
@ -1118,9 +1067,9 @@ impl Instruction<ExpandedArgs> {
Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Not(_, a) => todo!(), Instruction::Not(_, _) => todo!(),
Instruction::Cvt(_, a) => todo!(), Instruction::Cvt(_, _) => todo!(),
Instruction::Shl(_, a) => todo!(), Instruction::Shl(_, _) => todo!(),
Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Bra(_, a) => a.visit_id(f, None), Instruction::Bra(_, a) => a.visit_id(f, None),
Instruction::Ret(_) => (), Instruction::Ret(_) => (),
@ -1830,7 +1779,7 @@ fn insert_with_implicit_conversion_dst<
T, T,
ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>, ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word, Setter: Fn(&mut T) -> &mut spirv::Word,
ToInstruction: FnOnce(T) -> Instruction<ExpandedArgs>, ToInstruction: FnOnce(T) -> Instruction,
>( >(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
instr_type: ast::ScalarType, instr_type: ast::ScalarType,
@ -1958,7 +1907,7 @@ fn should_convert_relaxed_dst(
fn insert_implicit_bitcasts( fn insert_implicit_bitcasts(
func: &mut Vec<ExpandedStatement>, func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
mut instr: Instruction<ExpandedArgs>, mut instr: Instruction,
) { ) {
let mut dst_coercion = None; let mut dst_coercion = None;
instr.visit_id_extended(&mut |is_dst, id, id_type| { instr.visit_id_extended(&mut |is_dst, id, id_type| {