Start doing SSA conversion

This commit is contained in:
Andrzej Janik
2020-04-22 00:55:49 +02:00
parent 0c71826bc7
commit 7b2bc69330
4 changed files with 155 additions and 38 deletions

View File

@ -1,7 +1,4 @@
use std::convert::From; use std::convert::From;
use std::convert::Into;
use std::error::Error;
use std::mem;
use std::num::ParseIntError; use std::num::ParseIntError;
quick_error! { quick_error! {

View File

@ -9,6 +9,7 @@ extern crate spirv_headers as spirv;
lalrpop_mod!(ptx); lalrpop_mod!(ptx);
#[cfg(test)]
mod test; mod test;
mod translate; mod translate;
pub mod ast; pub mod ast;

View File

@ -2,7 +2,7 @@ use super::ptx;
fn parse_and_assert(s: &str) { fn parse_and_assert(s: &str) {
let mut errors = Vec::new(); let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap(); ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
assert!(errors.len() == 0); assert!(errors.len() == 0);
} }
@ -12,6 +12,7 @@ fn empty() {
} }
#[test] #[test]
#[allow(non_snake_case)]
fn vectorAdd_kernel64_ptx() { fn vectorAdd_kernel64_ptx() {
let vector_add = include_str!("vectorAdd_kernel64.ptx"); let vector_add = include_str!("vectorAdd_kernel64.ptx");
parse_and_assert(vector_add); parse_and_assert(vector_add);

View File

@ -1,8 +1,8 @@
use crate::ast; use crate::ast;
use bit_vec::BitVec; use bit_vec::BitVec;
use rspirv::dr; use rspirv::dr;
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap, HashSet}; use std::collections::{BTreeMap, HashMap, HashSet};
use std::{cell::RefCell, ptr};
#[derive(PartialEq, Eq, Hash, Clone, Copy)] #[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvType { enum SpirvType {
@ -57,23 +57,8 @@ impl TypeWordMap {
} }
} }
struct IdWordMap<'a>(HashMap<&'a str, spirv::Word>);
impl<'a> IdWordMap<'a> {
fn new() -> Self {
IdWordMap(HashMap::new())
}
}
impl<'a> IdWordMap<'a> {
fn get_or_add(&mut self, b: &mut dr::Builder, id: &'a str) -> spirv::Word {
*self.0.entry(id).or_insert_with(|| b.id())
}
}
pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> { pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
let mut builder = dr::Builder::new(); let mut builder = dr::Builder::new();
let mut ids = IdWordMap::new();
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 0); builder.set_version(1, 0);
emit_capabilities(&mut builder); emit_capabilities(&mut builder);
@ -82,7 +67,7 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
emit_memory_model(&mut builder); emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder); let mut map = TypeWordMap::new(&mut builder);
for f in ast.functions { for f in ast.functions {
emit_function(&mut builder, &mut map, &mut ids, f)?; emit_function(&mut builder, &mut map, f)?;
} }
Ok(vec![]) Ok(vec![])
} }
@ -111,9 +96,8 @@ fn emit_memory_model(builder: &mut dr::Builder) {
fn emit_function<'a>( fn emit_function<'a>(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
ids: &mut IdWordMap<'a>,
f: ast::Function<'a>, f: ast::Function<'a>,
) -> Result<(), rspirv::dr::Error> { ) -> Result<spirv::Word, rspirv::dr::Error> {
let func_id = builder.begin_function( let func_id = builder.begin_function(
map.void(), map.void(),
None, None,
@ -128,15 +112,19 @@ fn emit_function<'a>(
let bbs = get_basic_blocks(&normalized_ids); let bbs = get_basic_blocks(&normalized_ids);
let rpostorder = to_reverse_postorder(&bbs); let rpostorder = to_reverse_postorder(&bbs);
let dom_fronts = dominance_frontiers(&bbs, &rpostorder); let dom_fronts = dominance_frontiers(&bbs, &rpostorder);
let ssa = ssa_legalize(normalized_ids, dom_fronts); let (ops, phis) = ssa_legalize(normalized_ids, bbs, &dom_fronts);
emit_function_body_ops(ssa, builder); emit_function_body_ops(builder, ops, phis);
builder.ret()?; builder.ret()?;
builder.end_function()?; builder.end_function()?;
Ok(()) Ok(func_id)
} }
fn emit_function_body_ops(ssa: Vec<Statement>, builder: &mut dr::Builder) { fn emit_function_body_ops(
unimplemented!() builder: &mut dr::Builder,
ops: Vec<Statement>,
phis: Vec<RefCell<PhiBasicBlock>>,
) {
todo!()
} }
// TODO: support scopes // TODO: support scopes
@ -158,8 +146,47 @@ fn normalize_identifiers<'a>(func: Vec<ast::Statement<&'a str>>) -> Vec<Statemen
result result
} }
fn ssa_legalize(func: Vec<Statement>, dom_fronts: Vec<HashSet<BBIndex>>) -> Vec<Statement> { fn ssa_legalize(
unimplemented!() func: Vec<Statement>,
bbs: Vec<BasicBlock>,
dom_fronts: &Vec<HashSet<BBIndex>>,
) -> (Vec<Statement>, Vec<RefCell<PhiBasicBlock>>) {
let mut phis = gather_phi_sets(&func, &bbs, dom_fronts);
trim_singleton_phi_sets(&mut phis);
todo!()
}
fn gather_phi_sets(
func: &Vec<Statement>,
bbs: &Vec<BasicBlock>,
dom_fronts: &Vec<HashSet<BBIndex>>,
) -> Vec<HashMap<spirv::Word, HashSet<BBIndex>>> {
let mut phis = vec![HashMap::new(); bbs.len()];
for (bb_idx, bb) in bbs.iter().enumerate() {
let StmtIndex(start) = bb.start;
let end = if bb_idx == bbs.len() - 1 {
bbs.len()
} else {
bbs[bb_idx + 1].start.0
};
for s in func[start..end].iter() {
s.for_dst_id(&mut |id| {
for BBIndex(phi_target) in dom_fronts[bb_idx].iter() {
phis[*phi_target]
.entry(id)
.or_insert_with(|| HashSet::new())
.insert(BBIndex(bb_idx));
}
});
}
}
phis
}
fn trim_singleton_phi_sets(phis: &mut Vec<HashMap<spirv::Word, HashSet<BBIndex>>>) {
for phi_map in phis.iter_mut() {
phi_map.retain(|_, set| set.len() > 1);
}
} }
fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> { fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
@ -179,7 +206,6 @@ fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
Statement::Label(id) => { Statement::Label(id) => {
labels.insert(id, StmtIndex(idx)); labels.insert(id, StmtIndex(idx));
} }
Statement::Phi(_) => (),
}; };
} }
let mut bbs_map = BTreeMap::new(); let mut bbs_map = BTreeMap::new();
@ -322,10 +348,10 @@ fn to_reverse_postorder(input: &Vec<BasicBlock>) -> Vec<BBIndex> {
result result
} }
#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)] struct PhiBasicBlock {
struct StmtIndex(pub usize); bb: BasicBlock,
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)] phi: Vec<(spirv::Word, Vec<(spirv::Word, BBIndex)>)>,
struct BBIndex(pub usize); }
#[derive(Eq, PartialEq, Debug, Clone)] #[derive(Eq, PartialEq, Debug, Clone)]
struct BasicBlock { struct BasicBlock {
@ -334,10 +360,17 @@ struct BasicBlock {
succ: Vec<BBIndex>, succ: Vec<BBIndex>,
} }
#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)]
struct StmtIndex(pub usize);
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)]
struct BBIndex(pub usize);
enum Statement { enum Statement {
Label(u32), Label(u32),
Instruction(Option<ast::PredAt<u32>>, ast::Instruction<u32>), Instruction(
Phi(Vec<spirv::Word>), Option<ast::PredAt<spirv::Word>>,
ast::Instruction<spirv::Word>,
),
} }
impl Statement { impl Statement {
@ -353,6 +386,16 @@ impl Statement {
ast::Statement::Variable(_) => None, ast::Statement::Variable(_) => None,
} }
} }
fn for_dst_id<F: FnMut(spirv::Word)>(&self, f: &mut F) {
match self {
Statement::Label(id) => f(*id),
Statement::Instruction(pred, inst) => {
pred.as_ref().map(|p| p.for_dst_id(f));
inst.for_dst_id(f);
}
}
}
} }
impl<T> ast::PredAt<T> { impl<T> ast::PredAt<T> {
@ -364,6 +407,10 @@ impl<T> ast::PredAt<T> {
} }
} }
impl<T: Copy> ast::PredAt<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {}
}
impl<T> ast::Instruction<T> { impl<T> ast::Instruction<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> { fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> {
match self { match self {
@ -387,7 +434,7 @@ impl<T> ast::Instruction<T> {
impl<T: Copy> ast::Instruction<T> { impl<T: Copy> ast::Instruction<T> {
fn jump_target(&self) -> Option<T> { fn jump_target(&self) -> Option<T> {
match self { match self {
ast::Instruction::Bra(d, a) => Some(a.dst), ast::Instruction::Bra(_, a) => Some(a.dst),
ast::Instruction::Ld(_, _) ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _) | ast::Instruction::Mov(_, _)
| ast::Instruction::Mul(_, _) | ast::Instruction::Mul(_, _)
@ -402,6 +449,24 @@ impl<T: Copy> ast::Instruction<T> {
| ast::Instruction::Ret(_) => None, | ast::Instruction::Ret(_) => None,
} }
} }
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
match self {
ast::Instruction::Bra(_, a) => a.for_dst_id(f),
ast::Instruction::Ld(_, a) => a.for_dst_id(f),
ast::Instruction::Mov(_, a) => a.for_dst_id(f),
ast::Instruction::Mul(_, a) => a.for_dst_id(f),
ast::Instruction::Add(_, a) => a.for_dst_id(f),
ast::Instruction::Setp(_, a) => a.for_dst_id(f),
ast::Instruction::SetpBool(_, a) => a.for_dst_id(f),
ast::Instruction::Not(_, a) => a.for_dst_id(f),
ast::Instruction::Cvt(_, a) => a.for_dst_id(f),
ast::Instruction::Shl(_, a) => a.for_dst_id(f),
ast::Instruction::St(_, a) => a.for_dst_id(f),
ast::Instruction::At(_, a) => a.for_dst_id(f),
ast::Instruction::Ret(_) => (),
}
}
} }
impl<T> ast::Arg1<T> { impl<T> ast::Arg1<T> {
@ -410,6 +475,12 @@ impl<T> ast::Arg1<T> {
} }
} }
impl<T: Copy> ast::Arg1<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
f(self.dst)
}
}
impl<T> ast::Arg2<T> { impl<T> ast::Arg2<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2<U> { fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2<U> {
ast::Arg2 { ast::Arg2 {
@ -419,6 +490,12 @@ impl<T> ast::Arg2<T> {
} }
} }
impl<T: Copy> ast::Arg2<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
f(self.dst);
}
}
impl<T> ast::Arg2Mov<T> { impl<T> ast::Arg2Mov<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2Mov<U> { fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2Mov<U> {
ast::Arg2Mov { ast::Arg2Mov {
@ -428,6 +505,12 @@ impl<T> ast::Arg2Mov<T> {
} }
} }
impl<T: Copy> ast::Arg2Mov<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
f(self.dst);
}
}
impl<T> ast::Arg3<T> { impl<T> ast::Arg3<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg3<U> { fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg3<U> {
ast::Arg3 { ast::Arg3 {
@ -438,6 +521,12 @@ impl<T> ast::Arg3<T> {
} }
} }
impl<T: Copy> ast::Arg3<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
f(self.dst);
}
}
impl<T> ast::Arg4<T> { impl<T> ast::Arg4<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg4<U> { fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg4<U> {
ast::Arg4 { ast::Arg4 {
@ -449,6 +538,13 @@ impl<T> ast::Arg4<T> {
} }
} }
impl<T: Copy> ast::Arg4<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
f(self.dst1);
self.dst2.map(|t| f(t));
}
}
impl<T> ast::Arg5<T> { impl<T> ast::Arg5<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg5<U> { fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg5<U> {
ast::Arg5 { ast::Arg5 {
@ -461,6 +557,13 @@ impl<T> ast::Arg5<T> {
} }
} }
impl<T: Copy> ast::Arg5<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
f(self.dst1);
self.dst2.map(|t| f(t));
}
}
impl<T> ast::Operand<T> { impl<T> ast::Operand<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Operand<U> { fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Operand<U> {
match self { match self {
@ -471,6 +574,12 @@ impl<T> ast::Operand<T> {
} }
} }
impl<T: Copy> ast::Operand<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
unreachable!()
}
}
impl<T> ast::MovOperand<T> { impl<T> ast::MovOperand<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::MovOperand<U> { fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::MovOperand<U> {
match self { match self {
@ -480,6 +589,15 @@ impl<T> ast::MovOperand<T> {
} }
} }
impl<T: Copy> ast::MovOperand<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
match self {
ast::MovOperand::Op(o) => o.for_dst_id(f),
ast::MovOperand::Vec(_, _) => (),
}
}
}
// CFGs below taken from "Modern Compiler Implementation in Java" // CFGs below taken from "Modern Compiler Implementation in Java"
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {