Translate instruction ld

Andrzej Janik
2020-05-07 00:37:10 +02:00
parent 3b433456a1
commit fa075abc22
3 changed files with 152 additions and 59 deletions

View File

@@ -187,7 +187,53 @@ pub enum MovOperand<ID> {
     Vec(String, String),
 }
-pub struct LdData {}
+pub enum VectorPrefix {
+    V2,
+    V4
+}
+pub struct LdData {
+    pub qualifier: LdQualifier,
+    pub state_space: LdStateSpace,
+    pub caching: LdCacheOperator,
+    pub vector: Option<VectorPrefix>,
+    pub typ: ScalarType
+}
+#[derive(PartialEq, Eq)]
+pub enum LdQualifier {
+    Weak,
+    Volatile,
+    Relaxed(LdScope),
+    Acquire(LdScope),
+}
+#[derive(PartialEq, Eq)]
+pub enum LdScope {
+    Cta,
+    Gpu,
+    Sys
+}
+#[derive(PartialEq, Eq)]
+pub enum LdStateSpace {
+    Generic,
+    Const,
+    Global,
+    Local,
+    Param,
+    Shared,
+}
+#[derive(PartialEq, Eq)]
+pub enum LdCacheOperator {
+    Cached,
+    L2Only,
+    Streaming,
+    LastUse,
+    Uncached
+}
 pub struct MovData {}
@@ -201,7 +247,9 @@ pub struct SetpBoolData {}
 pub struct NotData {}
-pub struct BraData {}
+pub struct BraData {
+    pub uniform: bool
+}
 pub struct CvtData {}
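
For orientation: the new AST captures every modifier a PTX ld can carry. A plain `ld.global.u32 a, [b];` (no qualifier, cache operator, or vector prefix) would correspond to the following hand-built value; this is an illustration of the defaults rather than parser output, and it assumes these types live in an `ast` module:

    let data = ast::LdData {
        qualifier: ast::LdQualifier::Weak,      // no explicit qualifier in the source
        state_space: ast::LdStateSpace::Global, // `.global`
        caching: ast::LdCacheOperator::Cached,  // `.ca` semantics when no cache op is written
        vector: None,                           // no `.v2`/`.v4` prefix
        typ: ast::ScalarType::U32,              // `.u32`
    };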

View File

@@ -106,6 +106,16 @@ Type: ast::Type = {
 };
 ScalarType: ast::ScalarType = {
+    ".f16" => ast::ScalarType::F16,
+    MemoryType
+};
+ExtendedScalarType: ast::ExtendedScalarType = {
+    ".f16x2" => ast::ExtendedScalarType::F16x2,
+    ".pred" => ast::ExtendedScalarType::Pred,
+};
+MemoryType: ast::ScalarType = {
     ".b8" => ast::ScalarType::B8,
     ".b16" => ast::ScalarType::B16,
     ".b32" => ast::ScalarType::B32,
@@ -118,23 +128,10 @@ ScalarType: ast::ScalarType = {
     ".s16" => ast::ScalarType::S16,
     ".s32" => ast::ScalarType::S32,
     ".s64" => ast::ScalarType::S64,
-    ".f16" => ast::ScalarType::F16,
     ".f32" => ast::ScalarType::F32,
     ".f64" => ast::ScalarType::F64,
 };
-ExtendedScalarType: ast::ExtendedScalarType = {
-    ".f16x2" => ast::ExtendedScalarType::F16x2,
-    ".pred" => ast::ExtendedScalarType::Pred,
-};
-BaseType = {
-    ".b8", ".b16", ".b32", ".b64",
-    ".u8", ".u16", ".u32", ".u64",
-    ".s8", ".s16", ".s32", ".s64",
-    ".f32", ".f64"
-};
 Statement: Option<ast::Statement<&'input str>> = {
     <l:Label> => Some(ast::Statement::Label(l)),
     DebugDirective => None,
@@ -191,36 +188,47 @@ Instruction: ast::Instruction<&'input str> = {
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
 InstLd: ast::Instruction<&'input str> = {
-    "ld" LdQualifier? LdStateSpace? LdCacheOperator? Vector? BaseType <dst:ID> "," "[" <src:Operand> "]" => {
-        ast::Instruction::Ld(ast::LdData{}, ast::Arg2{dst:dst, src:src})
+    "ld" <q:LdQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ID> "," "[" <src:Operand> "]" => {
+        ast::Instruction::Ld(
+            ast::LdData {
+                qualifier: q.unwrap_or(ast::LdQualifier::Weak),
+                state_space: ss.unwrap_or(ast::LdStateSpace::Generic),
+                caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
+                vector: v,
+                typ: t
+            },
+            ast::Arg2 { dst:dst, src:src }
+        )
     }
 };
-LdQualifier: () = {
-    ".weak",
-    ".volatile",
-    ".relaxed" LdScope,
-    ".acquire" LdScope,
+LdQualifier: ast::LdQualifier = {
+    ".weak" => ast::LdQualifier::Weak,
+    ".volatile" => ast::LdQualifier::Volatile,
+    ".relaxed" <s:LdScope> => ast::LdQualifier::Relaxed(s),
+    ".acquire" <s:LdScope> => ast::LdQualifier::Acquire(s),
 };
-LdScope = {
-    ".cta", ".gpu", ".sys"
+LdScope: ast::LdScope = {
+    ".cta" => ast::LdScope::Cta,
+    ".gpu" => ast::LdScope::Gpu,
+    ".sys" => ast::LdScope::Sys
 };
-LdStateSpace = {
-    ".const",
-    ".global",
-    ".local",
-    ".param",
-    ".shared",
+LdStateSpace: ast::LdStateSpace = {
+    ".const" => ast::LdStateSpace::Const,
+    ".global" => ast::LdStateSpace::Global,
+    ".local" => ast::LdStateSpace::Local,
+    ".param" => ast::LdStateSpace::Param,
+    ".shared" => ast::LdStateSpace::Shared,
 };
-LdCacheOperator = {
-    ".ca",
-    ".cg",
-    ".cs",
-    ".lu",
-    ".cv",
+LdCacheOperator: ast::LdCacheOperator = {
+    ".ca" => ast::LdCacheOperator::Cached,
+    ".cg" => ast::LdCacheOperator::L2Only,
+    ".cs" => ast::LdCacheOperator::Streaming,
+    ".lu" => ast::LdCacheOperator::LastUse,
+    ".cv" => ast::LdCacheOperator::Uncached,
 };
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
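
Since every modifier slot is optional and falls back through unwrap_or, a bare `ld.u32 a, [b];` parses to the same defaults spelled out above. A hypothetical check (parse_ld is an imagined helper standing in for the generated lalrpop parser; the crate does not expose it under that name):

    let d = parse_ld("ld.u32 a, [b];");
    assert!(d.qualifier == ast::LdQualifier::Weak);       // missing qualifier => weak
    assert!(d.state_space == ast::LdStateSpace::Generic); // missing state space => generic
    assert!(d.caching == ast::LdCacheOperator::Cached);   // missing cache operator => .ca
    assert!(d.vector.is_none());                          // no vector prefix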
@@ -332,7 +340,7 @@ PredAt: ast::PredAt<&'input str> = {
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra
 InstBra: ast::Instruction<&'input str> = {
-    "bra" ".uni"? <a:Arg1> => ast::Instruction::Bra(ast::BraData{}, a)
+    "bra" <u:".uni"?> <a:Arg1> => ast::Instruction::Bra(ast::BraData{ uniform: u.is_some() }, a)
 };
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
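
The `.uni` suffix was previously matched and discarded; it now survives into the AST, e.g.:

    // `bra LBB0;`     parses to ast::BraData { uniform: false }
    // `bra.uni LBB0;` parses to ast::BraData { uniform: true }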
@@ -372,7 +380,7 @@ ShlType = {
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
 InstSt: ast::Instruction<&'input str> = {
-    "st" LdQualifier? StStateSpace? StCacheOperator? Vector? BaseType "[" <dst:ID> "]" "," <src:Operand> => {
+    "st" LdQualifier? StStateSpace? StCacheOperator? VectorPrefix? MemoryType "[" <dst:ID> "]" "," <src:Operand> => {
         ast::Instruction::St(ast::StData{}, ast::Arg2{dst:dst, src:src})
     }
 };
@@ -454,9 +462,9 @@ OptionalDst: &'input str = {
     "|" <dst2:ID> => dst2
 }
-Vector = {
-    ".v2",
-    ".v4"
+VectorPrefix: ast::VectorPrefix = {
+    ".v2" => ast::VectorPrefix::V2,
+    ".v4" => ast::VectorPrefix::V4
 };
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file

View File

@@ -8,6 +8,7 @@ use std::fmt;
 #[derive(PartialEq, Eq, Hash, Clone, Copy)]
 enum SpirvType {
     Base(ast::ScalarType),
+    Pointer(ast::ScalarType, spirv::StorageClass),
 }
 struct TypeWordMap {
@@ -33,29 +34,41 @@ impl TypeWordMap {
         self.fn_void
     }
-    fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
-        *self.complex.entry(t).or_insert_with(|| match t {
-            SpirvType::Base(ast::ScalarType::B8) | SpirvType::Base(ast::ScalarType::U8) => {
+    fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word {
+        *self.complex.entry(SpirvType::Base(t)).or_insert_with(|| match t {
+            ast::ScalarType::B8 | ast::ScalarType::U8 => {
                 b.type_int(8, 0)
             }
-            SpirvType::Base(ast::ScalarType::B16) | SpirvType::Base(ast::ScalarType::U16) => {
+            ast::ScalarType::B16 | ast::ScalarType::U16 => {
                 b.type_int(16, 0)
             }
-            SpirvType::Base(ast::ScalarType::B32) | SpirvType::Base(ast::ScalarType::U32) => {
+            ast::ScalarType::B32 | ast::ScalarType::U32 => {
                 b.type_int(32, 0)
             }
-            SpirvType::Base(ast::ScalarType::B64) | SpirvType::Base(ast::ScalarType::U64) => {
+            ast::ScalarType::B64 | ast::ScalarType::U64 => {
                 b.type_int(64, 0)
             }
-            SpirvType::Base(ast::ScalarType::S8) => b.type_int(8, 1),
-            SpirvType::Base(ast::ScalarType::S16) => b.type_int(16, 1),
-            SpirvType::Base(ast::ScalarType::S32) => b.type_int(32, 1),
-            SpirvType::Base(ast::ScalarType::S64) => b.type_int(64, 1),
-            SpirvType::Base(ast::ScalarType::F16) => b.type_float(16),
-            SpirvType::Base(ast::ScalarType::F32) => b.type_float(32),
-            SpirvType::Base(ast::ScalarType::F64) => b.type_float(64),
+            ast::ScalarType::S8 => b.type_int(8, 1),
+            ast::ScalarType::S16 => b.type_int(16, 1),
+            ast::ScalarType::S32 => b.type_int(32, 1),
+            ast::ScalarType::S64 => b.type_int(64, 1),
+            ast::ScalarType::F16 => b.type_float(16),
+            ast::ScalarType::F32 => b.type_float(32),
+            ast::ScalarType::F64 => b.type_float(64),
         })
     }
+    fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
+        match t {
+            SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar),
+            SpirvType::Pointer(scalar, storage) => {
+                let base = self.get_or_add_scalar(b, scalar);
+                *self.complex.entry(t).or_insert_with(|| {
+                    b.type_pointer(None, storage, base)
+                })
+            }
+        }
+    }
 }
 pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
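
Because both variants go through the complex map, repeated requests intern to a single type id. A minimal sketch, assuming a TypeWordMap named map and a dr::Builder named builder are already in scope:

    let t = SpirvType::Pointer(ast::ScalarType::U32, spirv::StorageClass::Generic);
    let first = map.get_or_add(&mut builder, t);  // emits OpTypeInt 32 0 and an OpTypePointer
    let second = map.get_or_add(&mut builder, t); // pure HashMap hit, nothing new emitted
    assert_eq!(first, second);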
@@ -123,7 +136,7 @@ fn emit_function<'a>(
     );
     let id_offset = builder.reserve_ids(unique_ids);
     emit_function_args(builder, id_offset, map, &f.args);
-    emit_function_body_ops(builder, id_offset, &normalized_ids, &bbs)?;
+    emit_function_body_ops(builder, id_offset, map, &normalized_ids, &bbs)?;
     builder.end_function()?;
     builder.ret()?;
     builder.end_function()?;
@@ -178,6 +191,7 @@ fn collect_label_ids<'a>(
 fn emit_function_body_ops(
     builder: &mut dr::Builder,
     id_offset: spirv::Word,
+    map: &mut TypeWordMap,
     func: &[Statement],
     cfg: &[BasicBlock],
 ) -> Result<(), dr::Error> {
@@ -193,12 +207,35 @@ fn emit_function_body_ops(
         };
         builder.begin_block(header_id)?;
         for s in body {
-            /*
             match s {
-                Statement::Instruction(pred, inst) => (),
+                // If a block starts with a label it has already been emitted,
+                // all other labels in the block are unused
                 Statement::Label(_) => (),
+                Statement::Conditional(bra) => {
+                    builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
+                }
+                Statement::Instruction(inst) => match inst {
+                    // Sadly, SPIR-V does not support marking jumps as guaranteed-converged
+                    ast::Instruction::Bra(_, arg) => {
+                        builder.branch(arg.src)?;
+                    }
+                    ast::Instruction::Ld(data, arg) => {
+                        if data.qualifier != ast::LdQualifier::Weak || data.vector.is_some() {
+                            todo!()
+                        }
+                        let storage_class = match data.state_space {
+                            ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
+                            ast::LdStateSpace::Param => spirv::StorageClass::CrossWorkgroup,
+                            _ => todo!(),
+                        };
+                        let result_type = map.get_or_add(builder, SpirvType::Base(data.typ));
+                        let pointer_type =
+                            map.get_or_add(builder, SpirvType::Pointer(data.typ, storage_class));
+                        builder.load(result_type, None, pointer_type, None, [])?;
+                    }
+                    _ => todo!(),
+                },
             }
-            */
         }
     }
     Ok(())
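
For a supported case such as `ld.param.u32 a, [b];`, the Ld arm interns a u32 type and a CrossWorkgroup pointer to it, then issues the load. The resulting SPIR-V looks roughly like the comments below, with invented ids for illustration (note that arg.src is not yet wired into the OpLoad):

    // %u32 = OpTypeInt 32 0                    ; via get_or_add_scalar
    // %ptr = OpTypePointer CrossWorkgroup %u32 ; via get_or_add
    // %a   = OpLoad %u32 ...                   ; builder.load(...)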
@@ -1273,7 +1310,7 @@ mod tests {
     let func = vec![
         Statement::Label(12),
         Statement::Instruction(ast::Instruction::Bra(
-            ast::BraData {},
+            ast::BraData { uniform: false },
             ast::Arg1 { src: 12 },
         )),
     ];