Parse vector movs (mov.type a.x b.y;)

This commit is contained in:
Andrzej Janik
2020-09-12 02:33:20 +02:00
parent 1238796dfd
commit 48dac43540
6 changed files with 178 additions and 133 deletions

View File

@ -8,7 +8,7 @@ pub struct Module {
pub enum ModuleCompileError<'a> { pub enum ModuleCompileError<'a> {
Parse( Parse(
Vec<ptx::ast::PtxError>, Vec<ptx::ast::PtxError>,
Option<ptx::ParseError<usize, ptx::Token<'a>, &'a str>>, Option<ptx::ParseError<usize, ptx::Token<'a>, ptx::ast::PtxError>>,
), ),
Compile(ptx::SpirvError), Compile(ptx::SpirvError),
} }

View File

@ -316,7 +316,8 @@ pub struct PredAt<ID> {
pub enum Instruction<P: ArgParams> { pub enum Instruction<P: ArgParams> {
Ld(LdData, Arg2<P>), Ld(LdData, Arg2<P>),
Mov(MovData, Arg2Mov<P>), Mov(MovType, Arg2<P>),
MovVector(MovVectorType, Arg2Vec<P>),
Mul(MulDetails, Arg3<P>), Mul(MulDetails, Arg3<P>),
Add(AddDetails, Arg3<P>), Add(AddDetails, Arg3<P>),
Setp(SetpData, Arg4<P>), Setp(SetpData, Arg4<P>),
@ -348,7 +349,6 @@ pub trait ArgParams {
type ID; type ID;
type Operand; type Operand;
type CallOperand; type CallOperand;
type MovOperand;
} }
pub struct ParsedArgParams<'a> { pub struct ParsedArgParams<'a> {
@ -359,7 +359,6 @@ impl<'a> ArgParams for ParsedArgParams<'a> {
type ID = &'a str; type ID = &'a str;
type Operand = Operand<&'a str>; type Operand = Operand<&'a str>;
type CallOperand = CallOperand<&'a str>; type CallOperand = CallOperand<&'a str>;
type MovOperand = MovOperand<&'a str>;
} }
pub struct Arg1<P: ArgParams> { pub struct Arg1<P: ArgParams> {
@ -376,9 +375,10 @@ pub struct Arg2St<P: ArgParams> {
pub src2: P::Operand, pub src2: P::Operand,
} }
pub struct Arg2Mov<P: ArgParams> { pub enum Arg2Vec<P: ArgParams> {
pub dst: P::ID, Dst((P::ID, u8), P::ID),
pub src: P::MovOperand, Src(P::ID, (P::ID, u8)),
Both((P::ID, u8), (P::ID, u8)),
} }
pub struct Arg3<P: ArgParams> { pub struct Arg3<P: ArgParams> {
@ -415,11 +415,6 @@ pub enum CallOperand<ID> {
Imm(i128), Imm(i128),
} }
pub enum MovOperand<ID> {
Op(Operand<ID>),
Vec(ID, u8),
}
pub enum VectorPrefix { pub enum VectorPrefix {
V2, V2,
V4, V4,
@ -467,10 +462,6 @@ pub enum LdCacheOperator {
Uncached, Uncached,
} }
pub struct MovData {
pub typ: Type,
}
sub_scalar_type!(MovScalarType { sub_scalar_type!(MovScalarType {
B16, B16,
B32, B32,
@ -486,19 +477,25 @@ sub_scalar_type!(MovScalarType {
Pred, Pred,
}); });
enum MovType { // pred vectors are illegal
Scalar(MovScalarType), sub_scalar_type!(MovVectorType {
Vector(MovScalarType, u8), B16,
Array(MovScalarType, u32), B32,
} B64,
U16,
U32,
U64,
S16,
S32,
S64,
F32,
F64,
});
impl From<MovType> for Type { sub_type! {
fn from(t: MovType) -> Self { MovType {
match t { Scalar(MovScalarType),
MovType::Scalar(t) => Type::Scalar(t.into()), Vector(MovVectorType, u8),
MovType::Vector(t, len) => Type::Vector(t.into(), len),
MovType::Array(t, len) => Type::Array(t.into(), len),
}
} }
} }

View File

@ -6,13 +6,13 @@ extern crate lalrpop_util;
extern crate quick_error; extern crate quick_error;
extern crate bit_vec; extern crate bit_vec;
extern crate half;
#[cfg(test)] #[cfg(test)]
extern crate level_zero as ze; extern crate level_zero as ze;
#[cfg(test)] #[cfg(test)]
extern crate level_zero_sys as l0; extern crate level_zero_sys as l0;
extern crate rspirv; extern crate rspirv;
extern crate spirv_headers as spirv; extern crate spirv_headers as spirv;
extern crate half;
#[cfg(test)] #[cfg(test)]
extern crate spirv_tools_sys as spirv_tools; extern crate spirv_tools_sys as spirv_tools;
@ -27,12 +27,26 @@ pub mod ast;
mod test; mod test;
mod translate; mod translate;
pub use lalrpop_util::ParseError as ParseError; pub use crate::ptx::ModuleParser;
pub use lalrpop_util::lexer::Token as Token; pub use lalrpop_util::lexer::Token;
pub use crate::ptx::ModuleParser as ModuleParser; pub use lalrpop_util::ParseError;
pub use translate::to_spirv as to_spirv;
pub use rspirv::dr::Error as SpirvError; pub use rspirv::dr::Error as SpirvError;
pub use translate::to_spirv;
pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> { pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
x.into_iter().filter_map(|x| x).collect() x.into_iter().filter_map(|x| x).collect()
} }
pub(crate) fn vector_index<'input>(
inp: &'input str,
) -> Result<u8, ParseError<usize, lalrpop_util::lexer::Token<'input>, ast::PtxError>> {
match inp {
"x" | "r" => Ok(0),
"y" | "g" => Ok(1),
"z" | "b" => Ok(2),
"w" | "a" => Ok(3),
_ => Err(ParseError::User {
error: ast::PtxError::WrongVectorElement,
}),
}
}

View File

@ -1,9 +1,13 @@
use crate::ast; use crate::ast;
use crate::ast::UnwrapWithVec; use crate::ast::UnwrapWithVec;
use crate::without_none; use crate::{without_none, vector_index};
grammar<'a>(errors: &mut Vec<ast::PtxError>); grammar<'a>(errors: &mut Vec<ast::PtxError>);
extern {
type Error = ast::PtxError;
}
match { match {
r"\s+" => { }, r"\s+" => { },
r"//[^\n\r]*[\n\r]*" => { }, r"//[^\n\r]*[\n\r]*" => { },
@ -487,24 +491,49 @@ LdCacheOperator: ast::LdCacheOperator = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = { InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mov" <t:MovType> <a:Arg2Mov> => { "mov" <t:MovType> <a:Arg2> => {
ast::Instruction::Mov(ast::MovData{ typ:t }, a) ast::Instruction::Mov(t, a)
},
"mov" <t:MovVectorType> <a:Arg2Vec> => {
ast::Instruction::MovVector(t, a)
} }
}; };
MovType: ast::Type = { #[inline]
".b16" => ast::Type::Scalar(ast::ScalarType::B16), MovType: ast::MovType = {
".b32" => ast::Type::Scalar(ast::ScalarType::B32), <t:MovScalarType> => ast::MovType::Scalar(t),
".b64" => ast::Type::Scalar(ast::ScalarType::B64), <pref:VectorPrefix> <t:MovVectorType> => ast::MovType::Vector(t, pref)
".u16" => ast::Type::Scalar(ast::ScalarType::U16), }
".u32" => ast::Type::Scalar(ast::ScalarType::U32),
".u64" => ast::Type::Scalar(ast::ScalarType::U64), #[inline]
".s16" => ast::Type::Scalar(ast::ScalarType::S16), MovScalarType: ast::MovScalarType = {
".s32" => ast::Type::Scalar(ast::ScalarType::S32), ".b16" => ast::MovScalarType::B16,
".s64" => ast::Type::Scalar(ast::ScalarType::S64), ".b32" => ast::MovScalarType::B32,
".f32" => ast::Type::Scalar(ast::ScalarType::F32), ".b64" => ast::MovScalarType::B64,
".f64" => ast::Type::Scalar(ast::ScalarType::F64), ".u16" => ast::MovScalarType::U16,
".pred" => ast::Type::Scalar(ast::ScalarType::Pred) ".u32" => ast::MovScalarType::U32,
".u64" => ast::MovScalarType::U64,
".s16" => ast::MovScalarType::S16,
".s32" => ast::MovScalarType::S32,
".s64" => ast::MovScalarType::S64,
".f32" => ast::MovScalarType::F32,
".f64" => ast::MovScalarType::F64,
".pred" => ast::MovScalarType::Pred
};
#[inline]
MovVectorType: ast::MovVectorType = {
".b16" => ast::MovVectorType::B16,
".b32" => ast::MovVectorType::B32,
".b64" => ast::MovVectorType::B64,
".u16" => ast::MovVectorType::U16,
".u32" => ast::MovVectorType::U32,
".u64" => ast::MovVectorType::U64,
".s16" => ast::MovVectorType::S16,
".s32" => ast::MovVectorType::S32,
".s64" => ast::MovVectorType::S64,
".f32" => ast::MovVectorType::F32,
".f64" => ast::MovVectorType::F64,
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
@ -989,29 +1018,6 @@ CallOperand: ast::CallOperand<&'input str> = {
} }
}; };
MovOperand: ast::MovOperand<&'input str> = {
<o:Operand> => ast::MovOperand::Op(o),
<o:VectorOperand> => {
let (pref, suf) = o;
let suf_idx = match suf {
"x" | "r" => 0,
"y" | "g" => 1,
"z" | "b" => 2,
"w" | "a" => 3,
_ => {
errors.push(ast::PtxError::WrongVectorElement);
0
}
};
ast::MovOperand::Vec(pref, suf_idx)
}
};
VectorOperand: (&'input str, &'input str) = {
<pref:ExtendedID> "." <suf:ExtendedID> => (pref, suf),
<pref:ExtendedID> <suf:DotID> => (pref, &suf[1..]),
};
Arg1: ast::Arg1<ast::ParsedArgParams<'input>> = { Arg1: ast::Arg1<ast::ParsedArgParams<'input>> = {
<src:ExtendedID> => ast::Arg1{<>} <src:ExtendedID> => ast::Arg1{<>}
}; };
@ -1020,8 +1026,21 @@ Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>} <dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>}
}; };
Arg2Mov: ast::Arg2Mov<ast::ParsedArgParams<'input>> = { Arg2Vec: ast::Arg2Vec<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src:MovOperand> => ast::Arg2Mov{<>} <dst:VectorOperand> "," <src:ExtendedID> => ast::Arg2Vec::Dst(dst, src),
<dst:ExtendedID> "," <src:VectorOperand> => ast::Arg2Vec::Src(dst, src),
<dst:VectorOperand> "," <src:VectorOperand> => ast::Arg2Vec::Both(dst, src),
};
VectorOperand: (&'input str, u8) = {
<pref:ExtendedID> "." <suf:ExtendedID> =>? {
let suf_idx = vector_index(suf)?;
Ok((pref, suf_idx))
},
<pref:ExtendedID> <suf:DotID> =>? {
let suf_idx = vector_index(&suf[1..])?;
Ok((pref, suf_idx))
}
}; };
Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = { Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = {

View File

@ -1,7 +1,7 @@
// Excersise as many features of vector types as possible // Excersise as many features of vector types as possible
.version 6.5 .version 6.5
.target sm_53 .target sm_60
.address_size 64 .address_size 64
.func (.reg .v2 .u32 output) impl( .func (.reg .v2 .u32 output) impl(
@ -17,6 +17,7 @@
add.u32 temp2, temp1, temp2; add.u32 temp2, temp1, temp2;
mov.u32 temp_v.x, temp2; mov.u32 temp_v.x, temp2;
mov.u32 temp_v.y, temp2; mov.u32 temp_v.y, temp2;
mov.u32 temp_v.x, temp_v.y;
mov.v2.u32 output, temp_v; mov.v2.u32 output, temp_v;
ret; ret;
} }

View File

@ -737,14 +737,11 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
} }
} }
fn src_mov_operand( fn src_vec_operand(
&mut self, &mut self,
desc: ArgumentDescriptor<ast::MovOperand<spirv::Word>>, desc: ArgumentDescriptor<(spirv::Word, u8)>,
) -> spirv::Word { ) -> (spirv::Word, u8) {
match &desc.op { (self.variable(desc.new_op(desc.op.0)), desc.op.1)
ast::MovOperand::Op(opr) => self.operand(desc.new_op(*opr)),
ast::MovOperand::Vec(opr, _) => self.variable(desc.new_op(*opr)),
}
} }
} }
@ -986,8 +983,9 @@ fn emit_function_body_ops(
} }
// SPIR-V does not support ret as guaranteed-converged // SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?, ast::Instruction::Ret(_) => builder.ret()?,
ast::Instruction::Mov(mov, arg) => { ast::Instruction::Mov(mov_type, arg) => {
let result_type = map.get_or_add(builder, SpirvType::from(mov.typ)); let result_type =
map.get_or_add(builder, SpirvType::from(ast::Type::from(*mov_type)));
builder.copy_object(result_type, Some(arg.dst), arg.src)?; builder.copy_object(result_type, Some(arg.dst), arg.src)?;
} }
ast::Instruction::Mul(mul, arg) => match mul { ast::Instruction::Mul(mul, arg) => match mul {
@ -1032,6 +1030,7 @@ fn emit_function_body_ops(
builder.copy_object(result_type, Some(arg.dst), arg.src)?; builder.copy_object(result_type, Some(arg.dst), arg.src)?;
} }
ast::Instruction::SetpBool(_, _) => todo!(), ast::Instruction::SetpBool(_, _) => todo!(),
ast::Instruction::MovVector(_, _) => todo!(),
}, },
Statement::LoadVar(arg, typ) => { Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ)); let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@ -1751,7 +1750,6 @@ impl ast::ArgParams for NormalizedArgParams {
type ID = spirv::Word; type ID = spirv::Word;
type Operand = ast::Operand<spirv::Word>; type Operand = ast::Operand<spirv::Word>;
type CallOperand = ast::CallOperand<spirv::Word>; type CallOperand = ast::CallOperand<spirv::Word>;
type MovOperand = ast::MovOperand<spirv::Word>;
} }
impl ArgParamsEx for NormalizedArgParams { impl ArgParamsEx for NormalizedArgParams {
@ -1768,7 +1766,6 @@ impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word; type ID = spirv::Word;
type Operand = spirv::Word; type Operand = spirv::Word;
type CallOperand = spirv::Word; type CallOperand = spirv::Word;
type MovOperand = spirv::Word;
} }
impl ArgParamsEx for ExpandedArgParams { impl ArgParamsEx for ExpandedArgParams {
@ -1781,7 +1778,7 @@ trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
fn variable(&mut self, desc: ArgumentDescriptor<T::ID>) -> U::ID; fn variable(&mut self, desc: ArgumentDescriptor<T::ID>) -> U::ID;
fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>) -> U::Operand; fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>) -> U::Operand;
fn src_call_operand(&mut self, desc: ArgumentDescriptor<T::CallOperand>) -> U::CallOperand; fn src_call_operand(&mut self, desc: ArgumentDescriptor<T::CallOperand>) -> U::CallOperand;
fn src_mov_operand(&mut self, desc: ArgumentDescriptor<T::MovOperand>) -> U::MovOperand; fn src_vec_operand(&mut self, desc: ArgumentDescriptor<(T::ID, u8)>) -> (U::ID, u8);
} }
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
@ -1794,12 +1791,14 @@ where
fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word { fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
self(desc) self(desc)
} }
fn src_call_operand(&mut self, mut desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word { fn src_call_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
desc.op = self(desc.new_op(desc.op)); self(desc.new_op(desc.op))
desc.op
} }
fn src_mov_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word { fn src_vec_operand(
self(desc) &mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
) -> (spirv::Word, u8) {
(self(desc.new_op(desc.op.0)), desc.op.1)
} }
} }
@ -1832,16 +1831,8 @@ where
} }
} }
fn src_mov_operand( fn src_vec_operand(&mut self, desc: ArgumentDescriptor<(&str, u8)>) -> (spirv::Word, u8) {
&mut self, (self(desc.op.0), desc.op.1)
desc: ArgumentDescriptor<ast::MovOperand<&str>>,
) -> ast::MovOperand<spirv::Word> {
match desc.op {
ast::MovOperand::Op(op) => ast::MovOperand::Op(self.operand(desc.new_op(op))),
ast::MovOperand::Vec(reg, x2) => {
ast::MovOperand::Vec(self.variable(desc.new_op(reg)), x2)
}
}
} }
} }
@ -1869,6 +1860,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
visitor: &mut V, visitor: &mut V,
) -> ast::Instruction<U> { ) -> ast::Instruction<U> {
match self { match self {
ast::Instruction::MovVector(_, _) => todo!(),
ast::Instruction::Abs(_, _) => todo!(), ast::Instruction::Abs(_, _) => todo!(),
ast::Instruction::Call(_) => unreachable!(), ast::Instruction::Call(_) => unreachable!(),
ast::Instruction::Ld(d, a) => { ast::Instruction::Ld(d, a) => {
@ -1879,9 +1871,8 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
a.map_ld(visitor, Some(ast::Type::Scalar(inst_type)), src_is_pointer), a.map_ld(visitor, Some(ast::Type::Scalar(inst_type)), src_is_pointer),
) )
} }
ast::Instruction::Mov(d, a) => { ast::Instruction::Mov(mov_type, a) => {
let inst_type = d.typ; ast::Instruction::Mov(mov_type, a.map(visitor, Some(mov_type.into())))
ast::Instruction::Mov(d, a.map(visitor, Some(inst_type)))
} }
ast::Instruction::Mul(d, a) => { ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type(); let inst_type = d.get_type();
@ -1982,19 +1973,11 @@ where
} }
} }
fn src_mov_operand( fn src_vec_operand(
&mut self, &mut self,
desc: ArgumentDescriptor<ast::MovOperand<spirv::Word>>, desc: ArgumentDescriptor<(spirv::Word, u8)>,
) -> ast::MovOperand<spirv::Word> { ) -> (spirv::Word, u8) {
match desc.op { (self(desc.new_op(desc.op.0)), desc.op.1)
ast::MovOperand::Op(op) => ast::MovOperand::Op(ArgumentMapVisitor::<
NormalizedArgParams,
NormalizedArgParams,
>::operand(
self, desc.new_op(op)
)),
ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2),
}
} }
} }
@ -2004,6 +1987,7 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Bra(_, a) => Some(a.src), ast::Instruction::Bra(_, a) => Some(a.src),
ast::Instruction::Ld(_, _) ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _) | ast::Instruction::Mov(_, _)
| ast::Instruction::MovVector(_, _)
| ast::Instruction::Mul(_, _) | ast::Instruction::Mul(_, _)
| ast::Instruction::Add(_, _) | ast::Instruction::Add(_, _)
| ast::Instruction::Setp(_, _) | ast::Instruction::Setp(_, _)
@ -2201,25 +2185,55 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
} }
} }
impl<T: ArgParamsEx> ast::Arg2Mov<T> { impl<T: ArgParamsEx> ast::Arg2Vec<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>( fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self, self,
visitor: &mut V, visitor: &mut V,
t: Option<ast::Type>, t: ast::Type,
) -> ast::Arg2Mov<U> { ) -> ast::Arg2Vec<U> {
ast::Arg2Mov { match self {
dst: visitor.variable(ArgumentDescriptor { ast::Arg2Vec::Dst(dst, src) => ast::Arg2Vec::Dst(
op: self.dst, visitor.src_vec_operand(ArgumentDescriptor {
typ: t, op: dst,
is_dst: true, typ: Some(t),
is_pointer: false, is_dst: true,
}), is_pointer: false,
src: visitor.src_mov_operand(ArgumentDescriptor { }),
op: self.src, visitor.variable(ArgumentDescriptor {
typ: t, op: src,
is_dst: false, typ: Some(t),
is_pointer: false, is_dst: false,
}), is_pointer: false,
}),
),
ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src (
visitor.variable(ArgumentDescriptor {
op: dst,
typ: Some(t),
is_dst: true,
is_pointer: false,
}),
visitor.src_vec_operand(ArgumentDescriptor {
op: src,
typ: Some(t),
is_dst: false,
is_pointer: false,
}),
),
ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both (
visitor.src_vec_operand(ArgumentDescriptor {
op: dst,
typ: Some(t),
is_dst: true,
is_pointer: false,
}),
visitor.src_vec_operand(ArgumentDescriptor {
op: src,
typ: Some(t),
is_dst: false,
is_pointer: false,
}),
),
} }
} }
} }