Update type lookup map when emitting new instructions during translation

This commit is contained in:
Andrzej Janik
2020-07-20 20:15:23 +02:00
parent 872d69c714
commit 4e9a71ed38

View File

@ -180,11 +180,23 @@ fn to_ssa<'a>(
collect_arg_ids(&mut contant_ids, &mut type_check, &f_args); collect_arg_ids(&mut contant_ids, &mut type_check, &f_args);
collect_label_ids(&mut contant_ids, &f_body); collect_label_ids(&mut contant_ids, &f_body);
let registers = collect_var_definitions(&f_args, &f_body); let registers = collect_var_definitions(&f_args, &f_body);
let (normalized_ids, unique_ids) = let (normalized_ids, mut unique_ids) =
normalize_identifiers(f_body, &contant_ids, &mut type_check, registers); normalize_identifiers(f_body, &contant_ids, &mut type_check, registers);
let (normalized_stmts, unique_ids) = normalize_statements(normalized_ids, unique_ids); let type_check = RefCell::new(type_check);
let (mut func_body, unique_ids) = let new_id = &mut |typ: Option<ast::Type>| {
insert_implicit_conversions(normalized_stmts, unique_ids, &|x| type_check[&x]); let to_insert = unique_ids;
{
let mut type_check = type_check.borrow_mut();
typ.map(|t| (*type_check).insert(to_insert, t));
}
unique_ids += 1;
to_insert
};
let normalized_stmts = normalize_statements(normalized_ids, new_id);
let mut func_body = insert_implicit_conversions(normalized_stmts, new_id, &|x| {
let type_check = type_check.borrow();
type_check[&x]
});
let bbs = get_basic_blocks(&func_body); let bbs = get_basic_blocks(&func_body);
let rpostorder = to_reverse_postorder(&bbs); let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder); let doms = immediate_dominators(&bbs, &rpostorder);
@ -202,22 +214,16 @@ fn to_ssa<'a>(
fn normalize_statements( fn normalize_statements(
func: Vec<ast::Statement<spirv::Word>>, func: Vec<ast::Statement<spirv::Word>>,
unique_ids: spirv::Word, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
) -> (Vec<Statement>, spirv::Word) { ) -> Vec<Statement> {
let mut result = Vec::with_capacity(func.len()); let mut result = Vec::with_capacity(func.len());
let mut id = unique_ids;
let new_id = &mut || {
let to_insert = id;
id += 1;
to_insert
};
for s in func { for s in func {
match s { match s {
ast::Statement::Label(id) => result.push(Statement::Label(id)), ast::Statement::Label(id) => result.push(Statement::Label(id)),
ast::Statement::Instruction(pred, inst) => { ast::Statement::Instruction(pred, inst) => {
if let Some(pred) = pred { if let Some(pred) = pred {
let mut if_true = new_id(); let mut if_true = new_id(None);
let mut if_false = new_id(); let mut if_false = new_id(None);
if pred.not { if pred.not {
std::mem::swap(&mut if_true, &mut if_false); std::mem::swap(&mut if_true, &mut if_false);
} }
@ -245,13 +251,13 @@ fn normalize_statements(
ast::Statement::Variable(_) => unreachable!(), ast::Statement::Variable(_) => unreachable!(),
} }
} }
(result, id) result
} }
#[must_use] #[must_use]
fn normalize_insert_instruction( fn normalize_insert_instruction(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
instr: ast::Instruction<spirv::Word>, instr: ast::Instruction<spirv::Word>,
) -> Instruction { ) -> Instruction {
match instr { match instr {
@ -302,7 +308,7 @@ fn normalize_insert_instruction(
fn normalize_expand_arg2( fn normalize_expand_arg2(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2<spirv::Word>, a: ast::Arg2<spirv::Word>,
) -> Arg2 { ) -> Arg2 {
@ -314,7 +320,7 @@ fn normalize_expand_arg2(
fn normalize_expand_arg2mov( fn normalize_expand_arg2mov(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2Mov<spirv::Word>, a: ast::Arg2Mov<spirv::Word>,
) -> Arg2 { ) -> Arg2 {
@ -326,7 +332,7 @@ fn normalize_expand_arg2mov(
fn normalize_expand_arg2st( fn normalize_expand_arg2st(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2St<spirv::Word>, a: ast::Arg2St<spirv::Word>,
) -> Arg2St { ) -> Arg2St {
@ -338,7 +344,7 @@ fn normalize_expand_arg2st(
fn normalize_expand_arg3( fn normalize_expand_arg3(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg3<spirv::Word>, a: ast::Arg3<spirv::Word>,
) -> Arg3 { ) -> Arg3 {
@ -351,7 +357,7 @@ fn normalize_expand_arg3(
fn normalize_expand_arg4( fn normalize_expand_arg4(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg4<spirv::Word>, a: ast::Arg4<spirv::Word>,
) -> Arg4 { ) -> Arg4 {
@ -365,7 +371,7 @@ fn normalize_expand_arg4(
fn normalize_expand_arg5( fn normalize_expand_arg5(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg5<spirv::Word>, a: ast::Arg5<spirv::Word>,
) -> Arg5 { ) -> Arg5 {
@ -380,7 +386,7 @@ fn normalize_expand_arg5(
fn normalize_expand_operand( fn normalize_expand_operand(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
opr: ast::Operand<spirv::Word>, opr: ast::Operand<spirv::Word>,
) -> spirv::Word { ) -> spirv::Word {
@ -388,7 +394,7 @@ fn normalize_expand_operand(
ast::Operand::Reg(r) => r, ast::Operand::Reg(r) => r,
ast::Operand::Imm(x) => { ast::Operand::Imm(x) => {
if let Some(typ) = inst_type() { if let Some(typ) = inst_type() {
let id = new_id(); let id = new_id(Some(ast::Type::Scalar(typ)));
func.push(Statement::Constant(ConstantDefinition { func.push(Statement::Constant(ConstantDefinition {
dst: id, dst: id,
typ: typ, typ: typ,
@ -405,7 +411,7 @@ fn normalize_expand_operand(
fn normalize_expand_mov_operand( fn normalize_expand_mov_operand(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>, inst_type: &impl Fn() -> Option<ast::ScalarType>,
opr: ast::MovOperand<spirv::Word>, opr: ast::MovOperand<spirv::Word>,
) -> spirv::Word { ) -> spirv::Word {
@ -456,15 +462,9 @@ fn collect_var_definitions<'a>(
*/ */
fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>( fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
normalized_ids: Vec<Statement>, normalized_ids: Vec<Statement>,
unique_ids: spirv::Word, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
type_check: &TypeCheck, type_check: &TypeCheck,
) -> (Vec<Statement>, spirv::Word) { ) -> Vec<Statement> {
let mut id = unique_ids;
let new_id = &mut || {
let temp = id;
id += 1;
temp
};
let mut result = Vec::with_capacity(normalized_ids.len()); let mut result = Vec::with_capacity(normalized_ids.len());
for s in normalized_ids.into_iter() { for s in normalized_ids.into_iter() {
match s { match s {
@ -518,7 +518,7 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
Statement::Converison(_) => unreachable!(), Statement::Converison(_) => unreachable!(),
} }
} }
(result, id) result
} }
fn get_function_type( fn get_function_type(
@ -2007,14 +2007,11 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
} }
} }
fn insert_implicit_conversions_ld_src< fn insert_implicit_conversions_ld_src<TypeCheck: Fn(spirv::Word) -> ast::Type>(
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
>(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
instr_type: ast::Type, instr_type: ast::Type,
type_check: &TypeCheck, type_check: &TypeCheck,
new_id: &mut NewId, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
state_space: ast::LdStateSpace, state_space: ast::LdStateSpace,
src: spirv::Word, src: spirv::Word,
) -> spirv::Word { ) -> spirv::Word {
@ -2055,12 +2052,11 @@ fn insert_implicit_conversions_ld_src<
fn insert_implicit_conversions_ld_src_impl< fn insert_implicit_conversions_ld_src_impl<
TypeCheck: Fn(spirv::Word) -> ast::Type, TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>, ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
>( >(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
type_check: &TypeCheck, type_check: &TypeCheck,
new_id: &mut NewId, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
instr_type: ast::Type, instr_type: ast::Type,
src: spirv::Word, src: spirv::Word,
should_convert: ShouldConvert, should_convert: ShouldConvert,
@ -2099,15 +2095,15 @@ fn should_convert_ld_generic_src_to_bitcast(
} }
#[must_use] #[must_use]
fn insert_conversion_src<NewId: FnMut() -> spirv::Word>( fn insert_conversion_src(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
new_id: &mut NewId, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
src: spirv::Word, src: spirv::Word,
src_type: ast::Type, src_type: ast::Type,
instr_type: ast::Type, instr_type: ast::Type,
conv: ConversionKind, conv: ConversionKind,
) -> spirv::Word { ) -> spirv::Word {
let temp_src = new_id(); let temp_src = new_id(Some(instr_type));
func.push(Statement::Converison(ImplicitConversion { func.push(Statement::Converison(ImplicitConversion {
src: src, src: src,
dst: temp_src, dst: temp_src,
@ -2121,7 +2117,6 @@ fn insert_conversion_src<NewId: FnMut() -> spirv::Word>(
fn insert_with_implicit_conversion_dst< fn insert_with_implicit_conversion_dst<
T, T,
TypeCheck: Fn(spirv::Word) -> ast::Type, TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>, ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word, Setter: Fn(&mut T) -> &mut spirv::Word,
ToInstruction: FnOnce(T) -> Instruction, ToInstruction: FnOnce(T) -> Instruction,
@ -2129,7 +2124,7 @@ fn insert_with_implicit_conversion_dst<
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
instr_type: ast::ScalarType, instr_type: ast::ScalarType,
type_check: &TypeCheck, type_check: &TypeCheck,
new_id: &mut NewId, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
should_convert: ShouldConvert, should_convert: ShouldConvert,
mut t: T, mut t: T,
setter: Setter, setter: Setter,
@ -2146,15 +2141,15 @@ fn insert_with_implicit_conversion_dst<
} }
#[must_use] #[must_use]
fn get_conversion_dst<NewId: FnMut() -> spirv::Word>( fn get_conversion_dst(
new_id: &mut NewId, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
dst: &mut spirv::Word, dst: &mut spirv::Word,
instr_type: ast::Type, instr_type: ast::Type,
dst_type: ast::Type, dst_type: ast::Type,
kind: ConversionKind, kind: ConversionKind,
) -> Statement { ) -> Statement {
let original_dst = *dst; let original_dst = *dst;
let temp_dst = new_id(); let temp_dst = new_id(Some(instr_type));
*dst = temp_dst; *dst = temp_dst;
Statement::Converison(ImplicitConversion { Statement::Converison(ImplicitConversion {
src: temp_dst, src: temp_dst,
@ -2250,13 +2245,10 @@ fn should_convert_relaxed_dst(
} }
} }
fn insert_implicit_bitcasts< fn insert_implicit_bitcasts<TypeCheck: Fn(spirv::Word) -> ast::Type>(
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
>(
func: &mut Vec<Statement>, func: &mut Vec<Statement>,
type_check: &TypeCheck, type_check: &TypeCheck,
new_id: &mut NewId, new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
mut instr: Instruction, mut instr: Instruction,
) { ) {
let mut dst_coercion = None; let mut dst_coercion = None;
@ -2662,9 +2654,20 @@ mod tests {
let mut constant_ids = HashMap::new(); let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &ast); collect_label_ids(&mut constant_ids, &ast);
let registers = collect_var_definitions(&[], &ast); let registers = collect_var_definitions(&[], &ast);
let (normalized_ids, unique_ids) = let mut type_check = HashMap::new();
normalize_identifiers(ast, &constant_ids, &mut HashMap::new(), registers); let (normalized_ids, mut unique_ids) =
let (normalized_stmts, _) = normalize_statements(normalized_ids, unique_ids); normalize_identifiers(ast, &constant_ids, &mut type_check, registers);
let type_check = RefCell::new(type_check);
let new_id = &mut |typ: Option<ast::Type>| {
let to_insert = unique_ids;
{
let mut type_check = type_check.borrow_mut();
typ.map(|t| (*type_check).insert(to_insert, t));
}
unique_ids += 1;
to_insert
};
let normalized_stmts = normalize_statements(normalized_ids, new_id);
let mut bbs = get_basic_blocks(&normalized_stmts); let mut bbs = get_basic_blocks(&normalized_stmts);
bbs.iter_mut().for_each(sort_pred_succ); bbs.iter_mut().for_each(sort_pred_succ);
assert_eq!( assert_eq!(
@ -2811,10 +2814,22 @@ mod tests {
let mut constant_ids = HashMap::new(); let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &fn_ast); collect_label_ids(&mut constant_ids, &fn_ast);
assert_eq!(constant_ids.len(), 4); assert_eq!(constant_ids.len(), 4);
let mut type_check = HashMap::new();
let registers = collect_var_definitions(&[], &fn_ast); let registers = collect_var_definitions(&[], &fn_ast);
let (normalized_ids, unique_ids) = let (normalized_ids, mut unique_ids) =
normalize_identifiers(fn_ast, &constant_ids, &mut HashMap::new(), registers); normalize_identifiers(fn_ast, &constant_ids, &mut type_check, registers);
let (normalized_stmts, max_id) = normalize_statements(normalized_ids, unique_ids); let type_check = RefCell::new(type_check);
let new_id = &mut |typ: Option<ast::Type>| {
let to_insert = unique_ids;
{
let mut type_check = type_check.borrow_mut();
typ.map(|t| (*type_check).insert(to_insert, t));
}
unique_ids += 1;
to_insert
};
let normalized_stmts = normalize_statements(normalized_ids, new_id);
let bbs = get_basic_blocks(&normalized_stmts); let bbs = get_basic_blocks(&normalized_stmts);
let rpostorder = to_reverse_postorder(&bbs); let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder); let doms = immediate_dominators(&bbs, &rpostorder);
@ -2822,7 +2837,7 @@ mod tests {
let phi = gather_phi_sets( let phi = gather_phi_sets(
&normalized_stmts, &normalized_stmts,
constant_ids.len() as u32, constant_ids.len() as u32,
max_id, unique_ids,
&bbs, &bbs,
&dom_fronts, &dom_fronts,
); );