diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 76d72bf..ac78cbe 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -3,6 +3,7 @@ use bit_vec::BitVec; use rspirv::dr; use std::cell::RefCell; use std::collections::{BTreeMap, HashMap, HashSet}; +use std::fmt; #[derive(PartialEq, Eq, Hash, Clone, Copy)] enum SpirvType { @@ -266,20 +267,20 @@ fn get_bb_body_mut<'a>( all_bb: &[BasicBlock], bb: BBIndex, ) -> &'a mut [Statement] { - let (start, end) = get_bb_body_idx(all_bb, bb); + let (start, end) = get_bb_body_idx(func, all_bb, bb); &mut func[start..end] } fn get_bb_body<'a>(func: &'a [Statement], all_bb: &[BasicBlock], bb: BBIndex) -> &'a [Statement] { - let (start, end) = get_bb_body_idx(all_bb, bb); + let (start, end) = get_bb_body_idx(func, all_bb, bb); &func[start..end] } -fn get_bb_body_idx(all_bb: &[BasicBlock], bb: BBIndex) -> (usize, usize) { +fn get_bb_body_idx(func: &[Statement], all_bb: &[BasicBlock], bb: BBIndex) -> (usize, usize) { let BBIndex(bb_idx) = bb; let start = all_bb[bb_idx].start.0; let end = if bb_idx == all_bb.len() - 1 { - all_bb.len() + func.len() } else { all_bb[bb_idx + 1].start.0 }; @@ -466,16 +467,17 @@ fn dominance_frontiers(bbs: &[BasicBlock], doms: &[BBIndex]) -> Vec, order: &Vec) -> Vec { - let mut doms = vec![BBIndex(usize::max_value()); bbs.len()]; + let undefined = BBIndex(usize::max_value()); + let mut doms = vec![undefined; bbs.len()]; doms[0] = BBIndex(0); let mut changed = true; while changed { changed = false; for BBIndex(bb_idx) in order.iter().skip(1) { let bb = &bbs[*bb_idx]; - if let Some(first_pred) = bb.pred.get(0) { + if let Some(first_pred) = bb.pred.iter().find(|bb| doms[bb.0] != undefined) { let mut new_idom = *first_pred; - for BBIndex(p_idx) in bb.pred.iter().copied().skip(1) { + for BBIndex(p_idx) in bb.pred.iter().copied().filter(|bb| bb != first_pred) { if doms[p_idx] != BBIndex(usize::max_value()) { new_idom = intersect(&mut doms, BBIndex(p_idx), new_idom); } @@ -546,9 +548,22 @@ struct BasicBlock { #[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)] struct StmtIndex(pub usize); + +impl fmt::Display for StmtIndex { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + #[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)] struct BBIndex(pub usize); +impl fmt::Display for BBIndex { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + enum Statement { Label(u32), Instruction( @@ -1187,30 +1202,32 @@ mod tests { } // page 403 + const fig_19_4: &'static str = "{ + mov.u32 i, 1; + mov.u32 j, 1; + mov.u32 k, 0; + block_2: + setp.ge.u32 p, k, 100; + @p bra block_4; + block_3: + setp.ge.u32 q, j, 20; + @q bra block_6; + block_5: + mov.u32 j, i; + add.u32 k, k, 1; + bra block_7; + block_6: + mov.u32 j, k; + add.u32 k, k, 2; + block_7: + bra block_2; + block_4: + ret; + }"; + #[test] - fn gather_phi_sets_19_4() { - let func = "{ - mov.u32 i, 1; - mov.u32 j, 1; - mov.u32 k, 0; - block_2: - setp.ge.u32 p, k, 100; - @p bra block_4; - block_3: - setp.ge.u32 q, j, 20; - @q bra block_6; - block_5: - mov.u32 j, i; - add.u32 k, k, 1; - bra block_7; - block_6: - mov.u32 j, k; - add.u32 k, k, 2; - block_7: - bra block_2; - block_4: - ret; - }"; + fn gather_phi_sets_fig_19_4() { + let func = fig_19_4; let mut errors = Vec::new(); let ast = ptx::FunctionBodyParser::new() .parse(&mut errors, func) @@ -1301,6 +1318,36 @@ mod tests { ] } + // cfg from 19.4 with slighlty shuffled order of succ/pred + #[test] + fn reverse_postorder_fig_19_4() { + let mut cfg = cfg_fig_19_4(); + cfg[1].pred.swap(0, 1); + cfg[2].succ.swap(0, 1); + let rpostorder = vec![ + BBIndex(0), + BBIndex(1), + BBIndex(6), + BBIndex(2), + BBIndex(3), + BBIndex(4), + BBIndex(5), + ]; + let doms = immediate_dominators(&cfg, &rpostorder); + assert_eq!( + doms, + vec![ + BBIndex(0), + BBIndex(0), + BBIndex(1), + BBIndex(2), + BBIndex(2), + BBIndex(2), + BBIndex(1) + ] + ); + } + // page 403 #[test] fn dominance_frontiers_fig_19_4() {