diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 9089c01..82580aa 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -138,12 +138,11 @@ pub enum Instruction { Cvt(CvtData, Arg2), Shl(ShlData, Arg3), St(StData, Arg2), - At(AtData, Arg1), Ret(RetData), } pub struct Arg1 { - pub dst: ID, + pub src: ID, // it is a jump destination, but in terms of operands it is a source operand } pub struct Arg2 { @@ -210,6 +209,4 @@ pub struct ShlData {} pub struct StData {} -pub struct AtData {} - pub struct RetData {} diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 2168ba7..f40846d 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -426,7 +426,7 @@ VectorOperand: (&'input str, &'input str) = { }; Arg1: ast::Arg1<&'input str> = { - => ast::Arg1{<>} + => ast::Arg1{<>} }; Arg2: ast::Arg2<&'input str> = { diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 5584af5..02ff958 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -108,34 +108,32 @@ fn emit_function<'a>( let arg_type = map.get_or_add(builder, SpirvType::Base(arg.a_type)); builder.function_parameter(arg_type)?; } - let normalized_ids = normalize_identifiers(f.body); + let (mut normalized_ids, max_id) = normalize_identifiers(f.body); let bbs = get_basic_blocks(&normalized_ids); let rpostorder = to_reverse_postorder(&bbs); - let dom_fronts = dominance_frontiers(&bbs, &rpostorder); - let (ops, phis) = ssa_legalize(normalized_ids, bbs, &dom_fronts); - emit_function_body_ops(builder, ops, phis); + let doms = immediate_dominators(&bbs, &rpostorder); + let dom_fronts = dominance_frontiers(&bbs, &rpostorder, &doms); + ssa_legalize(&mut normalized_ids, max_id, bbs, &doms, &dom_fronts); + emit_function_body_ops(builder); builder.ret()?; builder.end_function()?; Ok(func_id) } -fn emit_function_body_ops( - builder: &mut dr::Builder, - ops: Vec, - phis: Vec>, -) { +fn emit_function_body_ops(builder: &mut dr::Builder) { todo!() } // TODO: support scopes -fn normalize_identifiers<'a>(func: Vec>) -> Vec { +fn normalize_identifiers<'a>(func: Vec>) -> (Vec, spirv::Word) { let mut result = Vec::with_capacity(func.len()); let mut id: u32 = 0; let mut known_ids = HashMap::new(); let mut get_or_add = |key| { *known_ids.entry(key).or_insert_with(|| { + let to_insert = id; id += 1; - id + to_insert }) }; for s in func { @@ -143,50 +141,184 @@ fn normalize_identifiers<'a>(func: Vec>) -> Vec, + func: &mut [Statement], + max_id: spirv::Word, bbs: Vec, + doms: &Vec, dom_fronts: &Vec>, -) -> (Vec, Vec>) { - let mut phis = gather_phi_sets(&func, &bbs, dom_fronts); - trim_singleton_phi_sets(&mut phis); - todo!() +) { + let phis = gather_phi_sets(&func, &bbs, dom_fronts); + apply_ssa_renaming(func, &bbs, doms, max_id, &phis); +} + +// "Modern Compiler Implementation in Java" - Algorithm 19.7 +fn apply_ssa_renaming( + func: &mut [Statement], + bbs: &[BasicBlock], + doms: &[BBIndex], + max_id: spirv::Word, + old_phi: &[Vec], +) { + let mut dom_tree = vec![Vec::new(); bbs.len()]; + for (bb, idom) in doms.iter().enumerate() { + dom_tree[idom.0].push(BBIndex(bb)); + } + let mut old_dst_id = vec![Vec::new(); bbs.len()]; + for bb in 0..bbs.len() { + for s in get_bb_body(func, bbs, BBIndex(bb)) { + s.for_dst_id(&mut |id| old_dst_id[bb].push(id)); + } + } + let mut new_phi = old_phi + .iter() + .map(|ids| { + ids.iter() + .map(|id| (*id, Vec::new())) + .collect::>() + }) + .collect::>(); + let mut ssa_state = SSARewriteState::new(max_id); + // once again, we do explicit stack + let mut state = Vec::new(); + state.push((BBIndex(0), 0)); + loop { + if let Some((BBIndex(bb), dom_succ_idx)) = state.last_mut() { + let bb = *bb; + if *dom_succ_idx == 0 { + rename_phi_dst(max_id, &mut ssa_state, &mut new_phi[bb]); + rename_bb_body(&mut ssa_state, func, bbs, BBIndex(bb)); + for BBIndex(succ_idx) in bbs[bb].succ.iter() { + rename_succesor_phi_src(&ssa_state, &mut new_phi[*succ_idx]); + } + } + if let Some(s) = dom_tree[bb].get(*dom_succ_idx) { + *dom_succ_idx += 1; + state.push((*s, 0)); + } else { + state.pop(); + pop_stacks(&mut ssa_state, &old_phi[bb], &old_dst_id[bb]); + } + } else { + break; + } + } +} + +fn rename_phi_dst( + max_old_id: spirv::Word, + rewriter: &mut SSARewriteState, + phi: &mut HashMap>, +) { + let old_keys = phi + .keys() + .copied() + .filter(|id| *id <= max_old_id) + .collect::>(); + for k in old_keys.into_iter() { + let remapped_id = rewriter.redefine(k); + let values = phi.remove(&k).unwrap(); + phi.insert(remapped_id, values); + } +} + +fn rename_bb_body( + ssa_state: &mut SSARewriteState, + func: &mut [Statement], + all_bb: &[BasicBlock], + bb: BBIndex, +) { + for s in get_bb_body_mut(func, all_bb, bb) { + s.visit_id_mut(&mut |is_dst, id| { + if is_dst { + *id = ssa_state.redefine(*id); + } else { + *id = ssa_state.get(*id); + } + }); + } +} + +fn rename_succesor_phi_src( + ssa_state: &SSARewriteState, + phi: &mut HashMap>, +) { + for (id, v) in phi.iter_mut() { + v.push(ssa_state.get(*id)); + } +} + +fn pop_stacks(ssa_state: &mut SSARewriteState, old_phi: &[spirv::Word], old_ids: &[spirv::Word]) { + for id in old_phi.iter().chain(old_ids) { + ssa_state.pop(*id); + } +} + +fn get_bb_body_mut<'a>( + func: &'a mut [Statement], + all_bb: &[BasicBlock], + bb: BBIndex, +) -> &'a mut [Statement] { + let (start, end) = get_bb_body_idx(all_bb, bb); + &mut func[start..end] +} + +fn get_bb_body<'a>(func: &'a [Statement], all_bb: &[BasicBlock], bb: BBIndex) -> &'a [Statement] { + let (start, end) = get_bb_body_idx(all_bb, bb); + &func[start..end] +} + +fn get_bb_body_idx(all_bb: &[BasicBlock], bb: BBIndex) -> (usize, usize) { + let BBIndex(bb_idx) = bb; + let start = all_bb[bb_idx].start.0; + let end = if bb_idx == all_bb.len() - 1 { + all_bb.len() + } else { + all_bb[bb_idx + 1].start.0 + }; + (start, end) +} + +// We assume here that the variables are defined in the dense sequence 0..max +struct SSARewriteState { + next: spirv::Word, + stack: Vec>, +} + +impl SSARewriteState { + fn new(max: spirv::Word) -> Self { + let stack = vec![Vec::new(); max as usize]; + SSARewriteState { + next: max + 1, + stack, + } + } + + fn get(&self, x: spirv::Word) -> spirv::Word { + *self.stack[x as usize].last().unwrap() + } + + fn redefine(&mut self, x: spirv::Word) -> spirv::Word { + let result = self.next; + self.next += 1; + self.stack[x as usize].push(result); + return result; + } + + fn pop(&mut self, x: spirv::Word) { + self.stack[x as usize].pop(); + } } fn gather_phi_sets( - func: &Vec, - bbs: &Vec, - dom_fronts: &Vec>, -) -> Vec>> { - let mut phis = vec![HashMap::new(); bbs.len()]; - for (bb_idx, bb) in bbs.iter().enumerate() { - let StmtIndex(start) = bb.start; - let end = if bb_idx == bbs.len() - 1 { - bbs.len() - } else { - bbs[bb_idx + 1].start.0 - }; - for s in func[start..end].iter() { - s.for_dst_id(&mut |id| { - for BBIndex(phi_target) in dom_fronts[bb_idx].iter() { - phis[*phi_target] - .entry(id) - .or_insert_with(|| HashSet::new()) - .insert(BBIndex(bb_idx)); - } - }); - } - } - phis -} - -fn trim_singleton_phi_sets(phis: &mut Vec>>) { - for phi_map in phis.iter_mut() { - phi_map.retain(|_, set| set.len() > 1); - } + func: &[Statement], + bbs: &[BasicBlock], + dom_fronts: &[HashSet], +) -> Vec> { + todo!() } fn get_basic_blocks(fun: &Vec) -> Vec { @@ -258,8 +390,11 @@ fn get_basic_blocks(fun: &Vec) -> Vec { // "A Simple, Fast Dominance Algorithm" - Keith D. Cooper, Timothy J. Harvey, and Ken Kennedy // https://www.cs.rice.edu/~keith/EMBED/dom.pdf -fn dominance_frontiers(bbs: &Vec, order: &Vec) -> Vec> { - let doms = immediate_dominators(bbs, order); +fn dominance_frontiers( + bbs: &Vec, + order: &Vec, + doms: &Vec, +) -> Vec> { let mut result = vec![HashSet::new(); bbs.len()]; for (bb_idx, b) in bbs.iter().enumerate() { if b.pred.len() < 2 { @@ -321,16 +456,16 @@ fn intersect(doms: &mut Vec, b1: BBIndex, b2: BBIndex) -> BBIndex { fn to_reverse_postorder(input: &Vec) -> Vec { let mut i = input.len(); let mut old = BitVec::from_elem(input.len(), false); - // I would do just vec![BasicBlock::empty(), input.len()], but Vec is not Copy - let mut result = Vec::with_capacity(input.len()); - unsafe { result.set_len(input.len()) }; + let mut result = vec![BBIndex(usize::max_value()); input.len()]; // original uses recursion and implicit stack, we do it explictly let mut state = Vec::new(); state.push((BBIndex(0), 0usize)); loop { if let Some((BBIndex(bb), succ_iter_idx)) = state.last_mut() { let bb = *bb; - old.set(bb, true); + if *succ_iter_idx == 0 { + old.set(bb, true); + } if let Some(BBIndex(succ)) = &input[bb].succ.get(*succ_iter_idx) { *succ_iter_idx += 1; if !old.get(*succ).unwrap() { @@ -348,11 +483,6 @@ fn to_reverse_postorder(input: &Vec) -> Vec { result } -struct PhiBasicBlock { - bb: BasicBlock, - phi: Vec<(spirv::Word, Vec<(spirv::Word, BBIndex)>)>, -} - #[derive(Eq, PartialEq, Debug, Clone)] struct BasicBlock { start: StmtIndex, @@ -396,6 +526,16 @@ impl Statement { } } } + + fn visit_id_mut(&mut self, f: &mut F) { + match self { + Statement::Label(id) => f(true, id), + Statement::Instruction(pred, inst) => { + pred.as_mut().map(|p| p.visit_id_mut(f)); + inst.visit_id_mut(f); + } + } + } } impl ast::PredAt { @@ -405,10 +545,14 @@ impl ast::PredAt { label: f(self.label), } } + + fn visit_id_mut(&mut self, f: &mut F) { + f(false, &mut self.label) + } } impl ast::PredAt { - fn for_dst_id(&self, f: &mut F) {} + fn for_dst_id(&self, _: &mut F) {} } impl ast::Instruction { @@ -425,16 +569,32 @@ impl ast::Instruction { ast::Instruction::Cvt(d, a) => ast::Instruction::Cvt(d, a.map_id(f)), ast::Instruction::Shl(d, a) => ast::Instruction::Shl(d, a.map_id(f)), ast::Instruction::St(d, a) => ast::Instruction::St(d, a.map_id(f)), - ast::Instruction::At(d, a) => ast::Instruction::At(d, a.map_id(f)), ast::Instruction::Ret(d) => ast::Instruction::Ret(d), } } + + fn visit_id_mut(&mut self, f: &mut F) { + match self { + ast::Instruction::Ld(_, a) => a.visit_id_mut(f), + ast::Instruction::Mov(_, a) => a.visit_id_mut(f), + ast::Instruction::Mul(_, a) => a.visit_id_mut(f), + ast::Instruction::Add(_, a) => a.visit_id_mut(f), + ast::Instruction::Setp(_, a) => a.visit_id_mut(f), + ast::Instruction::SetpBool(_, a) => a.visit_id_mut(f), + ast::Instruction::Not(_, a) => a.visit_id_mut(f), + ast::Instruction::Cvt(_, a) => a.visit_id_mut(f), + ast::Instruction::Shl(_, a) => a.visit_id_mut(f), + ast::Instruction::St(_, a) => a.visit_id_mut(f), + ast::Instruction::Bra(_, a) => a.visit_id_mut(f), + ast::Instruction::Ret(_) => (), + } + } } impl ast::Instruction { fn jump_target(&self) -> Option { match self { - ast::Instruction::Bra(_, a) => Some(a.dst), + ast::Instruction::Bra(_, a) => Some(a.src), ast::Instruction::Ld(_, _) | ast::Instruction::Mov(_, _) | ast::Instruction::Mul(_, _) @@ -445,14 +605,12 @@ impl ast::Instruction { | ast::Instruction::Cvt(_, _) | ast::Instruction::Shl(_, _) | ast::Instruction::St(_, _) - | ast::Instruction::At(_, _) | ast::Instruction::Ret(_) => None, } } fn for_dst_id(&self, f: &mut F) { match self { - ast::Instruction::Bra(_, a) => a.for_dst_id(f), ast::Instruction::Ld(_, a) => a.for_dst_id(f), ast::Instruction::Mov(_, a) => a.for_dst_id(f), ast::Instruction::Mul(_, a) => a.for_dst_id(f), @@ -463,7 +621,7 @@ impl ast::Instruction { ast::Instruction::Cvt(_, a) => a.for_dst_id(f), ast::Instruction::Shl(_, a) => a.for_dst_id(f), ast::Instruction::St(_, a) => a.for_dst_id(f), - ast::Instruction::At(_, a) => a.for_dst_id(f), + ast::Instruction::Bra(_, _) => (), ast::Instruction::Ret(_) => (), } } @@ -471,13 +629,11 @@ impl ast::Instruction { impl ast::Arg1 { fn map_id U>(self, f: &mut F) -> ast::Arg1 { - ast::Arg1 { dst: f(self.dst) } + ast::Arg1 { src: f(self.src) } } -} -impl ast::Arg1 { - fn for_dst_id(&self, f: &mut F) { - f(self.dst) + fn visit_id_mut(&mut self, f: &mut F) { + f(false, &mut self.src); } } @@ -488,6 +644,11 @@ impl ast::Arg2 { src: self.src.map_id(f), } } + + fn visit_id_mut(&mut self, f: &mut F) { + f(true, &mut self.dst); + self.src.visit_id_mut(f); + } } impl ast::Arg2 { @@ -503,6 +664,11 @@ impl ast::Arg2Mov { src: self.src.map_id(f), } } + + fn visit_id_mut(&mut self, f: &mut F) { + f(true, &mut self.dst); + self.src.visit_id_mut(f); + } } impl ast::Arg2Mov { @@ -519,6 +685,12 @@ impl ast::Arg3 { src2: self.src2.map_id(f), } } + + fn visit_id_mut(&mut self, f: &mut F) { + f(true, &mut self.dst); + self.src1.visit_id_mut(f); + self.src2.visit_id_mut(f); + } } impl ast::Arg3 { @@ -536,6 +708,13 @@ impl ast::Arg4 { src2: self.src2.map_id(f), } } + + fn visit_id_mut(&mut self, f: &mut F) { + f(true, &mut self.dst1); + self.dst2.as_mut().map(|i| f(true, i)); + self.src1.visit_id_mut(f); + self.src2.visit_id_mut(f); + } } impl ast::Arg4 { @@ -555,6 +734,14 @@ impl ast::Arg5 { src3: self.src3.map_id(f), } } + + fn visit_id_mut(&mut self, f: &mut F) { + f(true, &mut self.dst1); + self.dst2.as_mut().map(|i| f(true, i)); + self.src1.visit_id_mut(f); + self.src2.visit_id_mut(f); + self.src3.visit_id_mut(f); + } } impl ast::Arg5 { @@ -572,11 +759,13 @@ impl ast::Operand { ast::Operand::Imm(v) => ast::Operand::Imm(v), } } -} -impl ast::Operand { - fn for_dst_id(&self, f: &mut F) { - unreachable!() + fn visit_id_mut(&mut self, f: &mut F) { + match self { + ast::Operand::Reg(i) => f(false, i), + ast::Operand::RegOffset(i, _) => f(false, i), + ast::Operand::Imm(_) => (), + } } } @@ -587,12 +776,10 @@ impl ast::MovOperand { ast::MovOperand::Vec(s1, s2) => ast::MovOperand::Vec(s1, s2), } } -} -impl ast::MovOperand { - fn for_dst_id(&self, f: &mut F) { + fn visit_id_mut(&mut self, f: &mut F) { match self { - ast::MovOperand::Op(o) => o.for_dst_id(f), + ast::MovOperand::Op(o) => o.visit_id_mut(f), ast::MovOperand::Vec(_, _) => (), } } @@ -727,7 +914,7 @@ mod tests { Statement::Label(12), Statement::Instruction( None, - ast::Instruction::Bra(ast::BraData {}, ast::Arg1 { dst: 12 }), + ast::Instruction::Bra(ast::BraData {}, ast::Arg1 { src: 12 }), ), ]; let bbs = get_basic_blocks(&func);