mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-18 17:56:22 +03:00
Fix bugs in basic block resolution
This commit is contained in:
@ -86,7 +86,7 @@ FunctionInput: ast::Argument<'input> = {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
FunctionBody: Vec<ast::Statement<&'input str>> = {
|
pub(crate) FunctionBody: Vec<ast::Statement<&'input str>> = {
|
||||||
"{" <s:Statement*> "}" => { without_none(s) }
|
"{" <s:Statement*> "}" => { without_none(s) }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -251,7 +251,11 @@ fn rename_succesor_phi_src(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pop_stacks(ssa_state: &mut SSARewriteState, old_phi: &HashSet<spirv::Word>, old_ids: &[spirv::Word]) {
|
fn pop_stacks(
|
||||||
|
ssa_state: &mut SSARewriteState,
|
||||||
|
old_phi: &HashSet<spirv::Word>,
|
||||||
|
old_ids: &[spirv::Word],
|
||||||
|
) {
|
||||||
for id in old_phi.iter().chain(old_ids) {
|
for id in old_phi.iter().chain(old_ids) {
|
||||||
ssa_state.pop(*id);
|
ssa_state.pop(*id);
|
||||||
}
|
}
|
||||||
@ -335,7 +339,7 @@ fn gather_phi_sets(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (id, to_work ) in def_sites.iter_mut().enumerate() {
|
for (id, to_work) in def_sites.iter_mut().enumerate() {
|
||||||
let id = id as spirv::Word;
|
let id = id as spirv::Word;
|
||||||
let (ref mut set, ref mut stack) = to_work;
|
let (ref mut set, ref mut stack) = to_work;
|
||||||
loop {
|
loop {
|
||||||
@ -358,18 +362,26 @@ fn gather_phi_sets(
|
|||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
|
fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
|
||||||
let mut direct_bb_start = Vec::new();
|
// edge signify pred/succ relationship between bbs
|
||||||
let mut indirect_bb_start = Vec::new();
|
let mut bb_edge = HashSet::new();
|
||||||
|
let mut unresolved_bb_edge = Vec::new();
|
||||||
|
// bb start means that a bb is starting at this statement, but there's no predecessor
|
||||||
|
let mut bb_start = Vec::new();
|
||||||
let mut labels = HashMap::new();
|
let mut labels = HashMap::new();
|
||||||
for (idx, s) in fun.iter().enumerate() {
|
for (idx, s) in fun.iter().enumerate() {
|
||||||
match s {
|
match s {
|
||||||
Statement::Instruction(_, i) => {
|
Statement::Instruction(pred, i) => {
|
||||||
if let Some(id) = i.jump_target() {
|
if let Some(id) = i.jump_target() {
|
||||||
indirect_bb_start.push((StmtIndex(idx), id));
|
unresolved_bb_edge.push((StmtIndex(idx), id));
|
||||||
if idx + 1 < fun.len() {
|
if idx + 1 < fun.len() {
|
||||||
direct_bb_start.push((StmtIndex(idx), StmtIndex(idx + 1)));
|
if pred.is_some() {
|
||||||
|
bb_edge.insert((StmtIndex(idx), StmtIndex(idx + 1)));
|
||||||
|
}
|
||||||
|
bb_start.push(StmtIndex(idx + 1));
|
||||||
}
|
}
|
||||||
|
} else if i.is_terminal() && idx + 1 < fun.len() {
|
||||||
|
bb_start.push(StmtIndex(idx + 1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Statement::Label(id) => {
|
Statement::Label(id) => {
|
||||||
@ -377,6 +389,25 @@ fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
// Resolve every <jump into label> into <jump into statement index>
|
||||||
|
// TODO: handle jumps into nowhere
|
||||||
|
for (idx, id) in unresolved_bb_edge {
|
||||||
|
let target = labels[&id];
|
||||||
|
bb_edge.insert((idx, target));
|
||||||
|
bb_start.push(target);
|
||||||
|
// now check if the preceding statement forms an edge
|
||||||
|
if target != StmtIndex(0) {
|
||||||
|
match &fun[target.0 - 1] {
|
||||||
|
Statement::Instruction(pred, i) => {
|
||||||
|
if !((pred.is_none() && i.jump_target().is_some()) || i.is_terminal()) {
|
||||||
|
bb_edge.insert((StmtIndex(target.0 - 1), target));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Statement::Label(_) => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Create list of bbs without succ/pred
|
||||||
let mut bbs_map = BTreeMap::new();
|
let mut bbs_map = BTreeMap::new();
|
||||||
bbs_map.insert(
|
bbs_map.insert(
|
||||||
StmtIndex(0),
|
StmtIndex(0),
|
||||||
@ -386,32 +417,22 @@ fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
|
|||||||
succ: Vec::new(),
|
succ: Vec::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
// TODO: handle jumps into nowhere
|
for bb_first_stmt in bb_start {
|
||||||
let resolved_indirect_bb_start = indirect_bb_start
|
bbs_map.entry(bb_first_stmt).or_insert_with(|| BasicBlock {
|
||||||
.into_iter()
|
start: bb_first_stmt,
|
||||||
.map(|(idx, id)| (idx, labels[&id]))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
for (_, to) in direct_bb_start
|
|
||||||
.iter()
|
|
||||||
.chain(resolved_indirect_bb_start.iter())
|
|
||||||
{
|
|
||||||
bbs_map.entry(*to).or_insert_with(|| BasicBlock {
|
|
||||||
start: *to,
|
|
||||||
pred: Vec::new(),
|
pred: Vec::new(),
|
||||||
succ: Vec::new(),
|
succ: Vec::new(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
// Populate succ/pred
|
||||||
let indexed_bbs_map = bbs_map
|
let indexed_bbs_map = bbs_map
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(idx, (key, val))| (key, (BBIndex(idx), RefCell::new(val))))
|
.map(|(idx, (key, val))| (key, (BBIndex(idx), RefCell::new(val))))
|
||||||
.collect::<BTreeMap<_, _>>();
|
.collect::<BTreeMap<_, _>>();
|
||||||
for (from, to) in direct_bb_start
|
for (from, to) in bb_edge {
|
||||||
.iter()
|
let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=from).next_back().unwrap();
|
||||||
.chain(resolved_indirect_bb_start.iter())
|
let (to_idx, to_ref) = indexed_bbs_map.get(&to).unwrap();
|
||||||
{
|
|
||||||
let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=*from).next_back().unwrap();
|
|
||||||
let (to_idx, to_ref) = indexed_bbs_map.get(to).unwrap();
|
|
||||||
{
|
{
|
||||||
from_ref.borrow_mut().succ.push(*to_idx);
|
from_ref.borrow_mut().succ.push(*to_idx);
|
||||||
}
|
}
|
||||||
@ -527,9 +548,9 @@ struct BasicBlock {
|
|||||||
succ: Vec<BBIndex>,
|
succ: Vec<BBIndex>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)]
|
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)]
|
||||||
struct StmtIndex(pub usize);
|
struct StmtIndex(pub usize);
|
||||||
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)]
|
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)]
|
||||||
struct BBIndex(pub usize);
|
struct BBIndex(pub usize);
|
||||||
|
|
||||||
enum Statement {
|
enum Statement {
|
||||||
@ -646,6 +667,23 @@ impl<T: Copy> ast::Instruction<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_terminal(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
ast::Instruction::Ret(_) => true,
|
||||||
|
ast::Instruction::Ld(_, _)
|
||||||
|
| ast::Instruction::Mov(_, _)
|
||||||
|
| ast::Instruction::Mul(_, _)
|
||||||
|
| ast::Instruction::Add(_, _)
|
||||||
|
| ast::Instruction::Setp(_, _)
|
||||||
|
| ast::Instruction::SetpBool(_, _)
|
||||||
|
| ast::Instruction::Not(_, _)
|
||||||
|
| ast::Instruction::Cvt(_, _)
|
||||||
|
| ast::Instruction::Shl(_, _)
|
||||||
|
| ast::Instruction::St(_, _)
|
||||||
|
| ast::Instruction::Bra(_, _) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
|
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
|
||||||
match self {
|
match self {
|
||||||
ast::Instruction::Ld(_, a) => a.for_dst_id(f),
|
ast::Instruction::Ld(_, a) => a.for_dst_id(f),
|
||||||
@ -826,6 +864,8 @@ impl<T> ast::MovOperand<T> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::ast;
|
||||||
|
use crate::ptx;
|
||||||
|
|
||||||
// page 411
|
// page 411
|
||||||
#[test]
|
#[test]
|
||||||
@ -1140,4 +1180,84 @@ mod tests {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn sort_pred_succ(bb: &mut BasicBlock) {
|
||||||
|
bb.pred.sort();
|
||||||
|
bb.succ.sort();
|
||||||
|
}
|
||||||
|
|
||||||
|
// page 403
|
||||||
|
#[test]
|
||||||
|
fn gather_phi_sets_19_4() {
|
||||||
|
let func = "{
|
||||||
|
mov.u32 i, 1;
|
||||||
|
mov.u32 j, 1;
|
||||||
|
mov.u32 k, 0;
|
||||||
|
block_2:
|
||||||
|
setp.ge.u32 p, k, 100;
|
||||||
|
@p bra block_4;
|
||||||
|
block_3:
|
||||||
|
setp.ge.u32 q, j, 20;
|
||||||
|
@q bra block_6;
|
||||||
|
block_5:
|
||||||
|
mov.u32 j, i;
|
||||||
|
add.u32 k, k, 1;
|
||||||
|
bra block_7;
|
||||||
|
block_6:
|
||||||
|
mov.u32 j, k;
|
||||||
|
add.u32 k, k, 2;
|
||||||
|
block_7:
|
||||||
|
bra block_2;
|
||||||
|
block_4:
|
||||||
|
ret;
|
||||||
|
}";
|
||||||
|
let mut errors = Vec::new();
|
||||||
|
let ast = ptx::FunctionBodyParser::new()
|
||||||
|
.parse(&mut errors, func)
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(errors.len(), 0);
|
||||||
|
let (normalized_ids, _) = normalize_identifiers(ast);
|
||||||
|
let mut bbs = get_basic_blocks(&normalized_ids);
|
||||||
|
bbs.iter_mut().for_each(sort_pred_succ);
|
||||||
|
assert_eq!(
|
||||||
|
bbs,
|
||||||
|
vec![
|
||||||
|
BasicBlock {
|
||||||
|
start: StmtIndex(0),
|
||||||
|
pred: vec![],
|
||||||
|
succ: vec![BBIndex(1)]
|
||||||
|
},
|
||||||
|
BasicBlock {
|
||||||
|
start: StmtIndex(3),
|
||||||
|
pred: vec![BBIndex(0), BBIndex(5)],
|
||||||
|
succ: vec![BBIndex(2), BBIndex(6)]
|
||||||
|
},
|
||||||
|
BasicBlock {
|
||||||
|
start: StmtIndex(6),
|
||||||
|
pred: vec![BBIndex(1)],
|
||||||
|
succ: vec![BBIndex(3), BBIndex(4)]
|
||||||
|
},
|
||||||
|
BasicBlock {
|
||||||
|
start: StmtIndex(9),
|
||||||
|
pred: vec![BBIndex(2)],
|
||||||
|
succ: vec![BBIndex(5)]
|
||||||
|
},
|
||||||
|
BasicBlock {
|
||||||
|
start: StmtIndex(13),
|
||||||
|
pred: vec![BBIndex(2)],
|
||||||
|
succ: vec![BBIndex(5)]
|
||||||
|
},
|
||||||
|
BasicBlock {
|
||||||
|
start: StmtIndex(16),
|
||||||
|
pred: vec![BBIndex(3), BBIndex(4)],
|
||||||
|
succ: vec![BBIndex(1)]
|
||||||
|
},
|
||||||
|
BasicBlock {
|
||||||
|
start: StmtIndex(18),
|
||||||
|
pred: vec![BBIndex(1)],
|
||||||
|
succ: vec![]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user