mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-19 02:06:32 +03:00
Start introducing support for bitcast coercions in instructions
This commit is contained in:
@ -238,7 +238,9 @@ pub struct MovData {}
|
|||||||
|
|
||||||
pub struct MulData {}
|
pub struct MulData {}
|
||||||
|
|
||||||
pub struct AddData {}
|
pub struct AddData {
|
||||||
|
pub typ: ScalarType,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct SetpData {}
|
pub struct SetpData {}
|
||||||
|
|
||||||
|
@ -161,6 +161,7 @@ Variable: ast::Variable<&'input str> = {
|
|||||||
|
|
||||||
VariableName: (&'input str, Option<u32>) = {
|
VariableName: (&'input str, Option<u32>) = {
|
||||||
<id:ID> => (id, None),
|
<id:ID> => (id, None),
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
|
||||||
<id:ParametrizedID> => {
|
<id:ParametrizedID> => {
|
||||||
let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap();
|
let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap();
|
||||||
let count = id[left_angle+1..id.len()-1].parse::<u32>();
|
let count = id[left_angle+1..id.len()-1].parse::<u32>();
|
||||||
@ -270,9 +271,13 @@ RoundingMode = {
|
|||||||
".rn", ".rz", ".rm", ".rp"
|
".rn", ".rz", ".rm", ".rp"
|
||||||
};
|
};
|
||||||
|
|
||||||
IntType = {
|
IntType : ast::ScalarType = {
|
||||||
".u16", ".u32", ".u64",
|
".u16" => ast::ScalarType::U16,
|
||||||
".s16", ".s32", ".s64",
|
".u32" => ast::ScalarType::U32,
|
||||||
|
".u64" => ast::ScalarType::U64,
|
||||||
|
".s16" => ast::ScalarType::S16,
|
||||||
|
".s32" => ast::ScalarType::S32,
|
||||||
|
".s64" => ast::ScalarType::S64,
|
||||||
};
|
};
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
|
||||||
@ -283,12 +288,12 @@ InstAdd: ast::Instruction<&'input str> = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
InstAddMode: ast::AddData = {
|
InstAddMode: ast::AddData = {
|
||||||
IntType => ast::AddData{},
|
<t:IntType> => ast::AddData{ typ: t },
|
||||||
".sat" ".s32" => ast::AddData{},
|
".sat" ".s32" => ast::AddData{ typ: ast::ScalarType::S32 },
|
||||||
RoundingMode? ".ftz"? ".sat"? ".f32" => ast::AddData{},
|
RoundingMode? ".ftz"? ".sat"? ".f32" => ast::AddData{ typ: ast::ScalarType::F32 },
|
||||||
RoundingMode? ".f64" => ast::AddData{},
|
RoundingMode? ".f64" => ast::AddData{ typ: ast::ScalarType::F64 },
|
||||||
".rn"? ".ftz"? ".sat"? ".f16" => ast::AddData{},
|
".rn"? ".ftz"? ".sat"? ".f16" => ast::AddData{ typ: ast::ScalarType::F16 },
|
||||||
".rn"? ".ftz"? ".sat"? ".f16x2" => ast::AddData{}
|
".rn"? ".ftz"? ".sat"? ".f16x2" => todo!()
|
||||||
};
|
};
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
|
||||||
|
@ -3,7 +3,7 @@ use bit_vec::BitVec;
|
|||||||
use rspirv::dr;
|
use rspirv::dr;
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::collections::{BTreeMap, HashMap, HashSet};
|
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||||
use std::fmt;
|
use std::{borrow::Cow, fmt};
|
||||||
|
|
||||||
use rspirv::binary::Assemble;
|
use rspirv::binary::Assemble;
|
||||||
|
|
||||||
@ -125,13 +125,17 @@ fn emit_function<'a>(
|
|||||||
let mut contant_ids = HashMap::new();
|
let mut contant_ids = HashMap::new();
|
||||||
collect_arg_ids(&mut contant_ids, &f.args);
|
collect_arg_ids(&mut contant_ids, &f.args);
|
||||||
collect_label_ids(&mut contant_ids, &f.body);
|
collect_label_ids(&mut contant_ids, &f.body);
|
||||||
let (mut normalized_ids, unique_ids) = normalize_identifiers(f.body, &contant_ids);
|
let registers = collect_registers(&f.body);
|
||||||
let bbs = get_basic_blocks(&normalized_ids);
|
let (normalized_ids, unique_ids, type_check) =
|
||||||
|
normalize_identifiers(f.body, &contant_ids, registers);
|
||||||
|
let (mut func_body, unique_ids) =
|
||||||
|
insert_implicit_conversion(normalized_ids, unique_ids, &|x| type_check[&x]);
|
||||||
|
let bbs = get_basic_blocks(&func_body);
|
||||||
let rpostorder = to_reverse_postorder(&bbs);
|
let rpostorder = to_reverse_postorder(&bbs);
|
||||||
let doms = immediate_dominators(&bbs, &rpostorder);
|
let doms = immediate_dominators(&bbs, &rpostorder);
|
||||||
let dom_fronts = dominance_frontiers(&bbs, &doms);
|
let dom_fronts = dominance_frontiers(&bbs, &doms);
|
||||||
let (_, unique_ids) = ssa_legalize(
|
let (_, unique_ids) = ssa_legalize(
|
||||||
&mut normalized_ids,
|
&mut func_body,
|
||||||
contant_ids.len() as u32,
|
contant_ids.len() as u32,
|
||||||
unique_ids,
|
unique_ids,
|
||||||
&bbs,
|
&bbs,
|
||||||
@ -140,11 +144,70 @@ fn emit_function<'a>(
|
|||||||
);
|
);
|
||||||
let id_offset = builder.reserve_ids(unique_ids);
|
let id_offset = builder.reserve_ids(unique_ids);
|
||||||
emit_function_args(builder, id_offset, map, &f.args);
|
emit_function_args(builder, id_offset, map, &f.args);
|
||||||
emit_function_body_ops(builder, id_offset, map, &normalized_ids, &bbs)?;
|
emit_function_body_ops(builder, id_offset, map, &func_body, &bbs)?;
|
||||||
builder.end_function()?;
|
builder.end_function()?;
|
||||||
Ok(func_id)
|
Ok(func_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn collect_registers<'a>(body: &[ast::Statement<&'a str>]) -> HashMap<Cow<'a, str>, ast::Type> {
|
||||||
|
let mut result = HashMap::new();
|
||||||
|
for s in body {
|
||||||
|
match s {
|
||||||
|
ast::Statement::Variable(var) => match var.count {
|
||||||
|
Some(count) => {
|
||||||
|
for i in 0..count {
|
||||||
|
result.insert(Cow::Owned(format!("{}{}", var.name, i)), var.v_type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
result.insert(Cow::Borrowed(var.name), var.v_type);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ast::Statement::Label(_) | ast::Statement::Instruction(_, _) => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
There are three kinds of implicit conversions in PTX:
|
||||||
|
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
|
||||||
|
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
|
||||||
|
* pointer dereference in st/ld: not documented, but for instruction `ld.<space>.<type> x, [y]` semantics are x = *(<type>*)y
|
||||||
|
*/
|
||||||
|
fn insert_implicit_conversion<TypeCheck: Fn(spirv::Word) -> ast::Type>(
|
||||||
|
normalized_ids: Vec<Statement>,
|
||||||
|
unique_ids: spirv::Word,
|
||||||
|
type_check: &TypeCheck,
|
||||||
|
) -> (Vec<Statement>, spirv::Word) {
|
||||||
|
let mut id = unique_ids;
|
||||||
|
let new_id = &mut || {
|
||||||
|
let temp = id;
|
||||||
|
id += 1;
|
||||||
|
temp
|
||||||
|
};
|
||||||
|
let mut result = Vec::with_capacity(normalized_ids.len());
|
||||||
|
for s in normalized_ids.into_iter() {
|
||||||
|
match s {
|
||||||
|
Statement::Instruction(inst) => match inst {
|
||||||
|
ast::Instruction::Add(add, arg) => {
|
||||||
|
arg.insert_implicit_conversions(
|
||||||
|
&mut result,
|
||||||
|
ast::Type::Scalar(add.typ),
|
||||||
|
type_check,
|
||||||
|
new_id,
|
||||||
|
|arg| Statement::Instruction(ast::Instruction::Add(add, arg)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
_ => todo!(),
|
||||||
|
},
|
||||||
|
s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s),
|
||||||
|
Statement::Converison(_) => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(result, id)
|
||||||
|
}
|
||||||
|
|
||||||
fn get_function_type(
|
fn get_function_type(
|
||||||
builder: &mut dr::Builder,
|
builder: &mut dr::Builder,
|
||||||
map: &mut TypeWordMap,
|
map: &mut TypeWordMap,
|
||||||
@ -223,6 +286,7 @@ fn emit_function_body_ops(
|
|||||||
// If block startd with a label it has already been emitted,
|
// If block startd with a label it has already been emitted,
|
||||||
// all other labels in the block are unused
|
// all other labels in the block are unused
|
||||||
Statement::Label(_) => (),
|
Statement::Label(_) => (),
|
||||||
|
Statement::Converison(_) => todo!(),
|
||||||
Statement::Conditional(bra) => {
|
Statement::Conditional(bra) => {
|
||||||
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
|
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
|
||||||
}
|
}
|
||||||
@ -300,7 +364,8 @@ fn emit_function_body_ops(
|
|||||||
fn normalize_identifiers<'a>(
|
fn normalize_identifiers<'a>(
|
||||||
func: Vec<ast::Statement<&'a str>>,
|
func: Vec<ast::Statement<&'a str>>,
|
||||||
constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined
|
constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined
|
||||||
) -> (Vec<Statement>, spirv::Word) {
|
types: HashMap<Cow<'a, str>, ast::Type>,
|
||||||
|
) -> (Vec<Statement>, spirv::Word, HashMap<spirv::Word, ast::Type>) {
|
||||||
let mut result = Vec::with_capacity(func.len());
|
let mut result = Vec::with_capacity(func.len());
|
||||||
let mut id: u32 = constant_identifiers.len() as u32;
|
let mut id: u32 = constant_identifiers.len() as u32;
|
||||||
let mut remapped_ids = HashMap::new();
|
let mut remapped_ids = HashMap::new();
|
||||||
@ -324,7 +389,11 @@ fn normalize_identifiers<'a>(
|
|||||||
for s in func {
|
for s in func {
|
||||||
Statement::from_ast(s, &mut result, &mut get_or_add);
|
Statement::from_ast(s, &mut result, &mut get_or_add);
|
||||||
}
|
}
|
||||||
(result, id)
|
let mut type_map = HashMap::with_capacity(types.len());
|
||||||
|
for (old_id, new_id) in remapped_ids {
|
||||||
|
type_map.insert(new_id, types[old_id]);
|
||||||
|
}
|
||||||
|
(result, id, type_map)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ssa_legalize(
|
fn ssa_legalize(
|
||||||
@ -580,6 +649,7 @@ fn gather_phi_sets(
|
|||||||
match s {
|
match s {
|
||||||
Statement::Instruction(inst) => inst.visit_id(&mut visitor),
|
Statement::Instruction(inst) => inst.visit_id(&mut visitor),
|
||||||
Statement::Conditional(brc) => visitor(false, &brc.predicate),
|
Statement::Conditional(brc) => visitor(false, &brc.predicate),
|
||||||
|
Statement::Converison(conv) => conv.visit_id(&mut visitor),
|
||||||
// label redefinition is a compile-time error
|
// label redefinition is a compile-time error
|
||||||
Statement::Label(_) => (),
|
Statement::Label(_) => (),
|
||||||
}
|
}
|
||||||
@ -630,6 +700,7 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
|
|||||||
unresolved_bb_edge.push((StmtIndex(idx), bra.if_false));
|
unresolved_bb_edge.push((StmtIndex(idx), bra.if_false));
|
||||||
unresolved_bb_edge.push((StmtIndex(idx), bra.if_true));
|
unresolved_bb_edge.push((StmtIndex(idx), bra.if_true));
|
||||||
}
|
}
|
||||||
|
Statement::Converison(_) => (),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
let mut bb_edge = HashSet::new();
|
let mut bb_edge = HashSet::new();
|
||||||
@ -647,7 +718,7 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
|
|||||||
bb_edge.insert((StmtIndex(target.0 - 1), target));
|
bb_edge.insert((StmtIndex(target.0 - 1), target));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Statement::Label(_) => {
|
Statement::Converison(_) | Statement::Label(_) => {
|
||||||
bb_edge.insert((StmtIndex(target.0 - 1), target));
|
bb_edge.insert((StmtIndex(target.0 - 1), target));
|
||||||
}
|
}
|
||||||
// This is already in `unresolved_bb_edge`
|
// This is already in `unresolved_bb_edge`
|
||||||
@ -816,6 +887,7 @@ enum Statement {
|
|||||||
Instruction(ast::Instruction<spirv::Word>),
|
Instruction(ast::Instruction<spirv::Word>),
|
||||||
// SPIR-V compatible replacement for PTX predicates
|
// SPIR-V compatible replacement for PTX predicates
|
||||||
Conditional(BrachCondition),
|
Conditional(BrachCondition),
|
||||||
|
Converison(ImplicitConversion),
|
||||||
}
|
}
|
||||||
|
|
||||||
struct BrachCondition {
|
struct BrachCondition {
|
||||||
@ -823,6 +895,7 @@ struct BrachCondition {
|
|||||||
if_true: spirv::Word,
|
if_true: spirv::Word,
|
||||||
if_false: spirv::Word,
|
if_false: spirv::Word,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BrachCondition {
|
impl BrachCondition {
|
||||||
fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
|
fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
|
||||||
f(false, &self.predicate);
|
f(false, &self.predicate);
|
||||||
@ -837,6 +910,25 @@ impl BrachCondition {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ImplicitConversion {
|
||||||
|
dst: spirv::Word,
|
||||||
|
src: spirv::Word,
|
||||||
|
from: ast::Type,
|
||||||
|
to: ast::Type,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ImplicitConversion {
|
||||||
|
fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
|
||||||
|
f(false, &self.src);
|
||||||
|
f(true, &self.dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||||
|
f(false, &mut self.src);
|
||||||
|
f(true, &mut self.dst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Statement {
|
impl Statement {
|
||||||
fn from_ast<'a, F: FnMut(Option<&'a str>) -> u32>(
|
fn from_ast<'a, F: FnMut(Option<&'a str>) -> u32>(
|
||||||
s: ast::Statement<&'a str>,
|
s: ast::Statement<&'a str>,
|
||||||
@ -885,6 +977,7 @@ impl Statement {
|
|||||||
Statement::Label(id) => f(false, id),
|
Statement::Label(id) => f(false, id),
|
||||||
Statement::Instruction(inst) => inst.visit_id(f),
|
Statement::Instruction(inst) => inst.visit_id(f),
|
||||||
Statement::Conditional(bra) => bra.visit_id(f),
|
Statement::Conditional(bra) => bra.visit_id(f),
|
||||||
|
Statement::Converison(conv) => conv.visit_id(f),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -895,6 +988,7 @@ impl Statement {
|
|||||||
Statement::Label(id) => f(false, id),
|
Statement::Label(id) => f(false, id),
|
||||||
Statement::Instruction(inst) => inst.visit_id_mut(f),
|
Statement::Instruction(inst) => inst.visit_id_mut(f),
|
||||||
Statement::Conditional(bra) => bra.visit_id_mut(f),
|
Statement::Conditional(bra) => bra.visit_id_mut(f),
|
||||||
|
Statement::Converison(conv) => conv.visit_id_mut(f),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1068,6 +1162,31 @@ impl<T> ast::Arg3<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ast::Arg3<spirv::Word> {
|
||||||
|
fn insert_implicit_conversions<
|
||||||
|
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||||
|
NewId: FnMut() -> spirv::Word,
|
||||||
|
NewStatement: FnOnce(Self) -> Statement,
|
||||||
|
>(
|
||||||
|
self,
|
||||||
|
func: &mut Vec<Statement>,
|
||||||
|
op_type: ast::Type,
|
||||||
|
type_check: &TypeCheck,
|
||||||
|
new_id: &mut NewId,
|
||||||
|
new_statement: NewStatement,
|
||||||
|
) {
|
||||||
|
let src1 = self
|
||||||
|
.src1
|
||||||
|
.insert_implicit_conversion(func, op_type, type_check, new_id);
|
||||||
|
let src2 = self
|
||||||
|
.src2
|
||||||
|
.insert_implicit_conversion(func, op_type, type_check, new_id);
|
||||||
|
insert_implicit_conversion_dst(func, op_type, type_check, new_id, self.dst, |dst| {
|
||||||
|
new_statement(Self { dst, src1, src2 })
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T> ast::Arg4<T> {
|
impl<T> ast::Arg4<T> {
|
||||||
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg4<U> {
|
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg4<U> {
|
||||||
ast::Arg4 {
|
ast::Arg4 {
|
||||||
@ -1147,6 +1266,37 @@ impl<T> ast::Operand<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ast::Operand<spirv::Word> {
|
||||||
|
fn insert_implicit_conversion<
|
||||||
|
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||||
|
NewId: FnMut() -> spirv::Word,
|
||||||
|
>(
|
||||||
|
self,
|
||||||
|
func: &mut Vec<Statement>,
|
||||||
|
op_type: ast::Type,
|
||||||
|
type_check: &TypeCheck,
|
||||||
|
new_id: &mut NewId,
|
||||||
|
) -> Self {
|
||||||
|
match self {
|
||||||
|
ast::Operand::Reg(src) => {
|
||||||
|
if type_check(src) == op_type {
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
let new_src = new_id();
|
||||||
|
func.push(Statement::Converison(ImplicitConversion {
|
||||||
|
src: src,
|
||||||
|
dst: new_src,
|
||||||
|
from: type_check(src),
|
||||||
|
to: op_type,
|
||||||
|
}));
|
||||||
|
ast::Operand::Reg(new_src)
|
||||||
|
}
|
||||||
|
o @ ast::Operand::Imm(_) => o,
|
||||||
|
ast::Operand::RegOffset(_, _) => todo!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T> ast::MovOperand<T> {
|
impl<T> ast::MovOperand<T> {
|
||||||
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::MovOperand<U> {
|
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::MovOperand<U> {
|
||||||
match self {
|
match self {
|
||||||
@ -1170,6 +1320,32 @@ impl<T> ast::MovOperand<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn insert_implicit_conversion_dst<
|
||||||
|
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||||
|
NewId: FnMut() -> spirv::Word,
|
||||||
|
NewStatement: FnOnce(spirv::Word) -> Statement,
|
||||||
|
>(
|
||||||
|
func: &mut Vec<Statement>,
|
||||||
|
op_type: ast::Type,
|
||||||
|
type_check: &TypeCheck,
|
||||||
|
new_id: &mut NewId,
|
||||||
|
dst: spirv::Word,
|
||||||
|
new_statement: NewStatement,
|
||||||
|
) {
|
||||||
|
if type_check(dst) == op_type {
|
||||||
|
func.push(new_statement(dst));
|
||||||
|
} else {
|
||||||
|
let new_dst = new_id();
|
||||||
|
func.push(new_statement(new_dst));
|
||||||
|
func.push(Statement::Converison(ImplicitConversion {
|
||||||
|
src: new_dst,
|
||||||
|
dst: dst,
|
||||||
|
from: type_check(new_dst),
|
||||||
|
to: op_type,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CFGs below taken from "Modern Compiler Implementation in Java"
|
// CFGs below taken from "Modern Compiler Implementation in Java"
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
Reference in New Issue
Block a user