Be more precise about types admitted in register definitions and method arguments

This commit is contained in:
Andrzej Janik
2020-09-11 00:40:13 +02:00
parent 76afbeba63
commit 1238796dfd
7 changed files with 647 additions and 351 deletions

View File

@ -12,9 +12,117 @@ quick_error! {
SyntaxError {} SyntaxError {}
NonF32Ftz {} NonF32Ftz {}
WrongArrayType {} WrongArrayType {}
WrongVectorElement {}
MultiArrayVariable {}
} }
} }
macro_rules! sub_scalar_type {
($name:ident { $($variant:ident),+ $(,)? }) => {
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum $name {
$(
$variant,
)+
}
impl From<$name> for ScalarType {
fn from(t: $name) -> ScalarType {
match t {
$(
$name::$variant => ScalarType::$variant,
)+
}
}
}
};
}
macro_rules! sub_type {
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum $type_name {
$(
$variant ($($field_type),+),
)+
}
impl From<$type_name> for Type {
#[allow(non_snake_case)]
fn from(t: $type_name) -> Type {
match t {
$(
$type_name::$variant ( $($field_type),+ ) => Type::$variant ( $($field_type.into()),+),
)+
}
}
}
};
}
sub_type! {
VariableRegType {
Scalar(ScalarType),
Vector(SizedScalarType, u8),
}
}
sub_type! {
VariableLocalType {
Scalar(SizedScalarType),
Vector(SizedScalarType, u8),
Array(SizedScalarType, u32),
}
}
// For some weird reson this is illegal:
// .param .f16x2 foobar;
// but this is legal:
// .param .f16x2 foobar[1];
sub_type! {
VariableParamType {
Scalar(ParamScalarType),
Array(SizedScalarType, u32),
}
}
sub_scalar_type!(SizedScalarType {
B8,
B16,
B32,
B64,
U8,
U16,
U32,
U64,
S8,
S16,
S32,
S64,
F16,
F16x2,
F32,
F64,
});
sub_scalar_type!(ParamScalarType {
B8,
B16,
B32,
B64,
U8,
U16,
U32,
U64,
S8,
S16,
S32,
S64,
F16,
F32,
F64,
});
pub trait UnwrapWithVec<E, To> { pub trait UnwrapWithVec<E, To> {
fn unwrap_with(self, errs: &mut Vec<E>) -> To; fn unwrap_with(self, errs: &mut Vec<E>) -> To;
} }
@ -56,6 +164,9 @@ pub enum MethodDecl<'a, P: ArgParams> {
Kernel(&'a str, Vec<KernelArgument<P>>), Kernel(&'a str, Vec<KernelArgument<P>>),
} }
pub type FnArgument<P: ArgParams> = Variable<FnArgumentType, P>;
pub type KernelArgument<P: ArgParams> = Variable<VariableParamType, P>;
pub struct Function<'a, P: ArgParams, S> { pub struct Function<'a, P: ArgParams, S> {
pub func_directive: MethodDecl<'a, P>, pub func_directive: MethodDecl<'a, P>,
pub body: Option<Vec<S>>, pub body: Option<Vec<S>>,
@ -63,43 +174,28 @@ pub struct Function<'a, P: ArgParams, S> {
pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement<ParsedArgParams<'a>>>; pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement<ParsedArgParams<'a>>>;
pub struct FnArgument<P: ArgParams> { #[derive(PartialEq, Eq, Clone, Copy)]
pub base: KernelArgument<P>, pub enum FnArgumentType {
pub state_space: FnArgStateSpace, Reg(VariableRegType),
Param(VariableParamType),
} }
#[derive(Copy, Clone, PartialEq, Eq)] impl From<FnArgumentType> for Type {
pub enum FnArgStateSpace { fn from(t: FnArgumentType) -> Self {
Reg, match t {
Param, FnArgumentType::Reg(x) => x.into(),
FnArgumentType::Param(x) => x.into(),
}
} }
#[derive(Default, Copy, Clone)]
pub struct KernelArgument<P: ArgParams> {
pub name: P::ID,
pub a_type: ScalarType,
// TODO: turn length into part of type definition
pub length: u32,
} }
#[derive(PartialEq, Eq, Hash, Clone, Copy)] #[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum Type { pub enum Type {
Scalar(ScalarType), Scalar(ScalarType),
ExtendedScalar(ExtendedScalarType), Vector(ScalarType, u8),
Array(ScalarType, u32), Array(ScalarType, u32),
} }
impl From<FloatType> for Type {
fn from(t: FloatType) -> Self {
match t {
FloatType::F16 => Type::Scalar(ScalarType::F16),
FloatType::F16x2 => Type::ExtendedScalar(ExtendedScalarType::F16x2),
FloatType::F32 => Type::Scalar(ScalarType::F32),
FloatType::F64 => Type::Scalar(ScalarType::F64),
}
}
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)] #[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum ScalarType { pub enum ScalarType {
B8, B8,
@ -117,25 +213,11 @@ pub enum ScalarType {
F16, F16,
F32, F32,
F64, F64,
F16x2,
Pred,
} }
impl From<IntType> for ScalarType { sub_scalar_type!(IntType {
fn from(t: IntType) -> Self {
match t {
IntType::S8 => ScalarType::S8,
IntType::S16 => ScalarType::S16,
IntType::S32 => ScalarType::S32,
IntType::S64 => ScalarType::S64,
IntType::U8 => ScalarType::U8,
IntType::U16 => ScalarType::U16,
IntType::U32 => ScalarType::U32,
IntType::U64 => ScalarType::U64,
}
}
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum IntType {
U8, U8,
U16, U16,
U32, U32,
@ -143,8 +225,8 @@ pub enum IntType {
S8, S8,
S16, S16,
S32, S32,
S64, S64
} });
impl IntType { impl IntType {
pub fn is_signed(self) -> bool { pub fn is_signed(self) -> bool {
@ -168,19 +250,12 @@ impl IntType {
} }
} }
#[derive(PartialEq, Eq, Hash, Clone, Copy)] sub_scalar_type!(FloatType {
pub enum FloatType {
F16, F16,
F16x2, F16x2,
F32, F32,
F64, F64
} });
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum ExtendedScalarType {
F16x2,
Pred,
}
impl Default for ScalarType { impl Default for ScalarType {
fn default() -> Self { fn default() -> Self {
@ -190,19 +265,39 @@ impl Default for ScalarType {
pub enum Statement<P: ArgParams> { pub enum Statement<P: ArgParams> {
Label(P::ID), Label(P::ID),
Variable(Variable<P>), Variable(MultiVariable<P>),
Instruction(Option<PredAt<P::ID>>, Instruction<P>), Instruction(Option<PredAt<P::ID>>, Instruction<P>),
Block(Vec<Statement<P>>), Block(Vec<Statement<P>>),
} }
pub struct Variable<P: ArgParams> { pub struct MultiVariable<P: ArgParams> {
pub space: StateSpace, pub var: Variable<VariableType, P>,
pub align: Option<u32>,
pub v_type: Type,
pub name: P::ID,
pub count: Option<u32>, pub count: Option<u32>,
} }
pub struct Variable<T, P: ArgParams> {
pub align: Option<u32>,
pub v_type: T,
pub name: P::ID,
}
#[derive(Eq, PartialEq, Copy, Clone)]
pub enum VariableType {
Reg(VariableRegType),
Local(VariableLocalType),
Param(VariableParamType),
}
impl From<VariableType> for Type {
fn from(t: VariableType) -> Self {
match t {
VariableType::Reg(t) => t.into(),
VariableType::Local(t) => t.into(),
VariableType::Param(t) => t.into(),
}
}
}
#[derive(Copy, Clone, PartialEq, Eq)] #[derive(Copy, Clone, PartialEq, Eq)]
pub enum StateSpace { pub enum StateSpace {
Reg, Reg,
@ -322,7 +417,7 @@ pub enum CallOperand<ID> {
pub enum MovOperand<ID> { pub enum MovOperand<ID> {
Op(Operand<ID>), Op(Operand<ID>),
Vec(String, String), Vec(ID, u8),
} }
pub enum VectorPrefix { pub enum VectorPrefix {
@ -334,7 +429,7 @@ pub struct LdData {
pub qualifier: LdStQualifier, pub qualifier: LdStQualifier,
pub state_space: LdStateSpace, pub state_space: LdStateSpace,
pub caching: LdCacheOperator, pub caching: LdCacheOperator,
pub vector: Option<VectorPrefix>, pub vector: Option<u8>,
pub typ: ScalarType, pub typ: ScalarType,
} }
@ -376,6 +471,37 @@ pub struct MovData {
pub typ: Type, pub typ: Type,
} }
sub_scalar_type!(MovScalarType {
B16,
B32,
B64,
U16,
U32,
U64,
S16,
S32,
S64,
F32,
F64,
Pred,
});
enum MovType {
Scalar(MovScalarType),
Vector(MovScalarType, u8),
Array(MovScalarType, u32),
}
impl From<MovType> for Type {
fn from(t: MovType) -> Self {
match t {
MovType::Scalar(t) => Type::Scalar(t.into()),
MovType::Vector(t, len) => Type::Vector(t.into(), len),
MovType::Array(t, len) => Type::Array(t.into(), len),
}
}
}
pub enum MulDetails { pub enum MulDetails {
Int(MulIntDesc), Int(MulIntDesc),
Float(MulFloatDesc), Float(MulFloatDesc),
@ -587,7 +713,7 @@ pub struct StData {
pub qualifier: LdStQualifier, pub qualifier: LdStQualifier,
pub state_space: StStateSpace, pub state_space: StStateSpace,
pub caching: StCacheOperator, pub caching: StCacheOperator,
pub vector: Option<VectorPrefix>, pub vector: Option<u8>,
pub typ: ScalarType, pub typ: ScalarType,
} }

View File

@ -21,6 +21,7 @@ match {
"@", "@",
"[", "]", "[", "]",
"{", "}", "{", "}",
"<", ">",
"|", "|",
".acquire", ".acquire",
".address_size", ".address_size",
@ -133,8 +134,6 @@ match {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers
r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+" => ID, r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+" => ID,
r"\.[a-zA-Z][a-zA-Z0-9_$]*" => DotID, r"\.[a-zA-Z][a-zA-Z0-9_$]*" => DotID,
} else {
r"(?:[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+)<[0-9]+>" => ParametrizedID,
} }
ExtendedID : &'input str = { ExtendedID : &'input str = {
@ -214,7 +213,9 @@ LinkingDirective = {
MethodDecl: ast::MethodDecl<'input, ast::ParsedArgParams<'input>> = { MethodDecl: ast::MethodDecl<'input, ast::ParsedArgParams<'input>> = {
".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params), ".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params),
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) ".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
}
}; };
KernelArguments: Vec<ast::KernelArgument<ast::ParsedArgParams<'input>>> = { KernelArguments: Vec<ast::KernelArgument<ast::ParsedArgParams<'input>>> = {
@ -225,32 +226,25 @@ FnArguments: Vec<ast::FnArgument<ast::ParsedArgParams<'input>>> = {
"(" <args:Comma<FnInput>> ")" => args "(" <args:Comma<FnInput>> ")" => args
}; };
FnInput: ast::FnArgument<ast::ParsedArgParams<'input>> = { KernelInput: ast::Variable<ast::VariableParamType, ast::ParsedArgParams<'input>> = {
".reg" <_type:ScalarType> <name:ExtendedID> => { <v:ParamVariable> => {
ast::FnArgument { let (align, v_type, name) = v;
base: ast::KernelArgument {a_type: _type, name: name, length: 1 }, ast::Variable{ align, v_type, name }
state_space: ast::FnArgStateSpace::Reg,
}
},
<p:KernelInput> => {
ast::FnArgument {
base: p,
state_space: ast::FnArgStateSpace::Param,
} }
} }
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space FnInput: ast::Variable<ast::FnArgumentType, ast::ParsedArgParams<'input>> = {
KernelInput: ast::KernelArgument<ast::ParsedArgParams<'input>> = { <v:RegVariable> => {
".param" <_type:ScalarType> <name:ExtendedID> => { let (align, v_type, name) = v;
ast::KernelArgument {a_type: _type, name: name, length: 1 } let v_type = ast::FnArgumentType::Reg(v_type);
ast::Variable{ align, v_type, name }
}, },
".param" <a_type:ScalarType> <name:ExtendedID> "[" <length:Num> "]" => { <v:ParamVariable> => {
let length = length.parse::<u32>(); let (align, v_type, name) = v;
let length = length.unwrap_with(errors); let v_type = ast::FnArgumentType::Param(v_type);
ast::KernelArgument { a_type: a_type, name: name, length: length } ast::Variable{ align, v_type, name }
}
} }
};
pub(crate) FunctionBody: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>> = { pub(crate) FunctionBody: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>> = {
"{" <s:Statement*> "}" => { Some(without_none(s)) }, "{" <s:Statement*> "}" => { Some(without_none(s)) },
@ -267,22 +261,13 @@ StateSpaceSpecifier: ast::StateSpace = {
".param" => ast::StateSpace::Param, // used to prepare function call ".param" => ast::StateSpace::Param, // used to prepare function call
}; };
Type: ast::Type = {
<t:ScalarType> => ast::Type::Scalar(t),
<t:ExtendedScalarType> => ast::Type::ExtendedScalar(t),
};
ScalarType: ast::ScalarType = { ScalarType: ast::ScalarType = {
".f16" => ast::ScalarType::F16, ".f16" => ast::ScalarType::F16,
".f16x2" => ast::ScalarType::F16x2,
".pred" => ast::ScalarType::Pred,
MemoryType MemoryType
}; };
ExtendedScalarType: ast::ExtendedScalarType = {
".f16x2" => ast::ExtendedScalarType::F16x2,
".pred" => ast::ExtendedScalarType::Pred,
};
MemoryType: ast::ScalarType = { MemoryType: ast::ScalarType = {
".b8" => ast::ScalarType::B8, ".b8" => ast::ScalarType::B8,
".b16" => ast::ScalarType::B16, ".b16" => ast::ScalarType::B16,
@ -303,7 +288,7 @@ MemoryType: ast::ScalarType = {
Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = { Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
<l:Label> => Some(ast::Statement::Label(l)), <l:Label> => Some(ast::Statement::Label(l)),
DebugDirective => None, DebugDirective => None,
<v:Variable> ";" => Some(ast::Statement::Variable(v)), <v:MultiVariable> ";" => Some(ast::Statement::Variable(v)),
<p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i)), <p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i)),
"{" <s:Statement*> "}" => Some(ast::Statement::Block(without_none(s))) "{" <s:Statement*> "}" => Some(ast::Statement::Block(without_none(s)))
}; };
@ -328,21 +313,109 @@ Align: u32 = {
} }
}; };
Variable: ast::Variable<ast::ParsedArgParams<'input>> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
<s:StateSpaceSpecifier> <a:Align?> <t:Type> <v:VariableName> <arr: ArraySpecifier?> => { MultiVariable: ast::MultiVariable<ast::ParsedArgParams<'input>> = {
let (name, count) = v; <var:Variable> <count:VariableParam?> => ast::MultiVariable{<>}
let t = match (t, arr) {
(ast::Type::Scalar(st), Some(arr_size)) => ast::Type::Array(st, arr_size),
(t, Some(_)) => {
errors.push(ast::PtxError::WrongArrayType);
t
},
(t, None) => t,
};
ast::Variable { space: s, align: a, v_type: t, name: name, count: count }
} }
VariableParam: u32 = {
"<" <n:Num> ">" => {
let size = n.parse::<u32>();
size.unwrap_with(errors)
}
}
Variable: ast::Variable<ast::VariableType, ast::ParsedArgParams<'input>> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
let v_type = ast::VariableType::Reg(v_type);
ast::Variable {align, v_type, name}
},
LocalVariable,
<v:ParamVariable> => {
let (align, v_type, name) = v;
let v_type = ast::VariableType::Param(v_type);
ast::Variable {align, v_type, name}
},
}; };
RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
".reg" <align:Align?> <t:ScalarType> <name:ExtendedID> => {
let v_type = ast::VariableRegType::Scalar(t);
(align, v_type, name)
},
".reg" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
let v_type = ast::VariableRegType::Vector(t, v_len);
(align, v_type, name)
}
}
LocalVariable: ast::Variable<ast::VariableType, ast::ParsedArgParams<'input>> = {
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
ast::Variable {align, v_type, name}
},
".local" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
ast::Variable {align, v_type, name}
},
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
let v_type = ast::VariableType::Local(ast::VariableLocalType::Array(t, arr));
ast::Variable {align, v_type, name}
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
ParamVariable: (Option<u32>, ast::VariableParamType, &'input str) = {
".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
let v_type = ast::VariableParamType::Scalar(t);
(align, v_type, name)
},
".param" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
let v_type = ast::VariableParamType::Array(t, arr);
(align, v_type, name)
}
}
#[inline]
SizedScalarType: ast::SizedScalarType = {
".b8" => ast::SizedScalarType::B8,
".b16" => ast::SizedScalarType::B16,
".b32" => ast::SizedScalarType::B32,
".b64" => ast::SizedScalarType::B64,
".u8" => ast::SizedScalarType::U8,
".u16" => ast::SizedScalarType::U16,
".u32" => ast::SizedScalarType::U32,
".u64" => ast::SizedScalarType::U64,
".s8" => ast::SizedScalarType::S8,
".s16" => ast::SizedScalarType::S16,
".s32" => ast::SizedScalarType::S32,
".s64" => ast::SizedScalarType::S64,
".f16" => ast::SizedScalarType::F16,
".f16x2" => ast::SizedScalarType::F16x2,
".f32" => ast::SizedScalarType::F32,
".f64" => ast::SizedScalarType::F64,
}
#[inline]
ParamScalarType: ast::ParamScalarType = {
".b8" => ast::ParamScalarType::B8,
".b16" => ast::ParamScalarType::B16,
".b32" => ast::ParamScalarType::B32,
".b64" => ast::ParamScalarType::B64,
".u8" => ast::ParamScalarType::U8,
".u16" => ast::ParamScalarType::U16,
".u32" => ast::ParamScalarType::U32,
".u64" => ast::ParamScalarType::U64,
".s8" => ast::ParamScalarType::S8,
".s16" => ast::ParamScalarType::S16,
".s32" => ast::ParamScalarType::S32,
".s64" => ast::ParamScalarType::S64,
".f16" => ast::ParamScalarType::F16,
".f32" => ast::ParamScalarType::F32,
".f64" => ast::ParamScalarType::F64,
}
ArraySpecifier: u32 = { ArraySpecifier: u32 = {
"[" <n:Num> "]" => { "[" <n:Num> "]" => {
let size = n.parse::<u32>(); let size = n.parse::<u32>();
@ -350,20 +423,6 @@ ArraySpecifier: u32 = {
} }
}; };
VariableName: (&'input str, Option<u32>) = {
<id:ExtendedID> => (id, None),
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
<id:ParametrizedID> => {
let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap();
let count = id[left_angle+1..id.len()-1].parse::<u32>();
let count = match count {
Ok(c) => Some(c),
Err(e) => { errors.push(e.into()); None },
};
(&id[0..left_angle], count)
}
};
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstLd, InstLd,
InstMov, InstMov,
@ -445,7 +504,7 @@ MovType: ast::Type = {
".s64" => ast::Type::Scalar(ast::ScalarType::S64), ".s64" => ast::Type::Scalar(ast::ScalarType::S64),
".f32" => ast::Type::Scalar(ast::ScalarType::F32), ".f32" => ast::Type::Scalar(ast::ScalarType::F32),
".f64" => ast::Type::Scalar(ast::ScalarType::F64), ".f64" => ast::Type::Scalar(ast::ScalarType::F64),
".pred" => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred) ".pred" => ast::Type::Scalar(ast::ScalarType::Pred)
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
@ -934,7 +993,17 @@ MovOperand: ast::MovOperand<&'input str> = {
<o:Operand> => ast::MovOperand::Op(o), <o:Operand> => ast::MovOperand::Op(o),
<o:VectorOperand> => { <o:VectorOperand> => {
let (pref, suf) = o; let (pref, suf) = o;
ast::MovOperand::Vec(pref.to_string(), suf.to_string()) let suf_idx = match suf {
"x" | "r" => 0,
"y" | "g" => 1,
"z" | "b" => 2,
"w" | "a" => 3,
_ => {
errors.push(ast::PtxError::WrongVectorElement);
0
}
};
ast::MovOperand::Vec(pref, suf_idx)
} }
}; };
@ -980,9 +1049,9 @@ OptionalDst: &'input str = {
"|" <dst2:ExtendedID> => dst2 "|" <dst2:ExtendedID> => dst2
} }
VectorPrefix: ast::VectorPrefix = { VectorPrefix: u8 = {
".v2" => ast::VectorPrefix::V2, ".v2" => 2,
".v4" => ast::VectorPrefix::V4 ".v4" => 4
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file

View File

@ -8,16 +8,16 @@ fn parse_and_assert(s: &str) {
assert!(errors.len() == 0); assert!(errors.len() == 0);
} }
#[test] fn compile_and_assert(s: &str) -> Result<(), rspirv::dr::Error> {
fn empty() { let mut errors = Vec::new();
parse_and_assert(".version 6.5 .target sm_30, debug"); let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
crate::to_spirv(ast)?;
Ok(())
} }
#[test] #[test]
#[allow(non_snake_case)] fn empty() {
fn vectorAdd_kernel64_ptx() { parse_and_assert(".version 6.5 .target sm_30, debug");
let vector_add = include_str!("vectorAdd_kernel64.ptx");
parse_and_assert(vector_add);
} }
#[test] #[test]
@ -28,8 +28,14 @@ fn operands_ptx() {
#[test] #[test]
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn _Z9vectorAddPKfS0_Pfi_ptx() { fn vectorAdd_kernel64_ptx() -> Result<(), rspirv::dr::Error> {
let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx"); let vector_add = include_str!("vectorAdd_kernel64.ptx");
parse_and_assert(vector_add); compile_and_assert(vector_add)
} }
#[test]
#[allow(non_snake_case)]
fn _Z9vectorAddPKfS0_Pfi_ptx() -> Result<(), rspirv::dr::Error> {
let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx");
compile_and_assert(vector_add)
}

View File

@ -54,6 +54,7 @@ test_ptx!(cvta, [3.0f32], [3.0f32]);
test_ptx!(block, [1u64], [2u64]); test_ptx!(block, [1u64], [2u64]);
test_ptx!(local_align, [1u64], [1u64]); test_ptx!(local_align, [1u64], [1u64]);
test_ptx!(call, [1u64], [2u64]); test_ptx!(call, [1u64], [2u64]);
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
struct DisplayError<T: Debug> { struct DisplayError<T: Debug> {
err: T, err: T,

View File

@ -0,0 +1,44 @@
// Excersise as many features of vector types as possible
.version 6.5
.target sm_53
.address_size 64
.func (.reg .v2 .u32 output) impl(
.reg .v2 .u32 input
)
{
.reg .v2 .u32 temp_v;
.reg .u32 temp1;
.reg .u32 temp2;
mov.u32 temp1, input.x;
mov.u32 temp2, input.y;
add.u32 temp2, temp1, temp2;
mov.u32 temp_v.x, temp2;
mov.u32 temp_v.y, temp2;
mov.v2.u32 output, temp_v;
ret;
}
.visible .entry vector(
.param .u64 input_p,
.param .u64 output_p
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .v2 .u32 temp;
.reg .u32 temp1;
.reg .u32 temp2;
.reg .b64 packed;
ld.param.u64 in_addr, [input_p];
ld.param.u64 out_addr, [output_p];
ld.v2.u32 temp, [in_addr];
call (temp), impl, (temp);
mov.b64 packed, temp;
st.v2.u32 [out_addr], temp;
ret;
}

View File

@ -0,0 +1,46 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%25 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "add"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%28 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_1 = OpConstant %ulong 1
%1 = OpFunction %void None %28
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%23 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
%14 = OpLoad %ulong %21
OpStore %6 %14
%17 = OpLoad %ulong %6
%16 = OpIAdd %ulong %17 %ulong_1
OpStore %7 %16
%18 = OpLoad %ulong %5
%19 = OpLoad %ulong %7
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
OpStore %22 %19
OpReturn
OpFunctionEnd

View File

@ -8,6 +8,7 @@ use rspirv::binary::Assemble;
#[derive(PartialEq, Eq, Hash, Clone)] #[derive(PartialEq, Eq, Hash, Clone)]
enum SpirvType { enum SpirvType {
Base(SpirvScalarKey), Base(SpirvScalarKey),
Vector(SpirvScalarKey, u8),
Array(SpirvScalarKey, u32), Array(SpirvScalarKey, u32),
Pointer(Box<SpirvType>, spirv::StorageClass), Pointer(Box<SpirvType>, spirv::StorageClass),
Func(Option<Box<SpirvType>>, Vec<SpirvType>), Func(Option<Box<SpirvType>>, Vec<SpirvType>),
@ -17,7 +18,7 @@ impl SpirvType {
fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self { fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
let key = match t { let key = match t {
ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)), ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
ast::Type::ExtendedScalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)), ast::Type::Vector(typ, len) => SpirvType::Vector(SpirvScalarKey::from(typ), len),
ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len), ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len),
}; };
SpirvType::Pointer(Box::new(key), sc) SpirvType::Pointer(Box::new(key), sc)
@ -28,7 +29,7 @@ impl From<ast::Type> for SpirvType {
fn from(t: ast::Type) -> Self { fn from(t: ast::Type) -> Self {
match t { match t {
ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()), ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
} }
} }
@ -77,15 +78,8 @@ impl From<ast::ScalarType> for SpirvScalarKey {
ast::ScalarType::F16 => SpirvScalarKey::F16, ast::ScalarType::F16 => SpirvScalarKey::F16,
ast::ScalarType::F32 => SpirvScalarKey::F32, ast::ScalarType::F32 => SpirvScalarKey::F32,
ast::ScalarType::F64 => SpirvScalarKey::F64, ast::ScalarType::F64 => SpirvScalarKey::F64,
} ast::ScalarType::F16x2 => SpirvScalarKey::F16x2,
} ast::ScalarType::Pred => SpirvScalarKey::Pred,
}
impl From<ast::ExtendedScalarType> for SpirvScalarKey {
fn from(t: ast::ExtendedScalarType) -> Self {
match t {
ast::ExtendedScalarType::Pred => SpirvScalarKey::Pred,
ast::ExtendedScalarType::F16x2 => SpirvScalarKey::F16x2,
} }
} }
} }
@ -135,6 +129,13 @@ impl TypeWordMap {
.entry(t) .entry(t)
.or_insert_with(|| b.type_pointer(None, storage, base)) .or_insert_with(|| b.type_pointer(None, storage, base))
} }
SpirvType::Vector(typ, len) => {
let base = self.get_or_add_spirv_scalar(b, typ);
*self
.complex
.entry(t)
.or_insert_with(|| b.type_vector(base, len as u32))
}
SpirvType::Array(typ, len) => { SpirvType::Array(typ, len) => {
let base = self.get_or_add_spirv_scalar(b, typ); let base = self.get_or_add_spirv_scalar(b, typ);
*self *self
@ -232,8 +233,8 @@ fn emit_function_header<'a>(
spirv::FunctionControl::NONE, spirv::FunctionControl::NONE,
func_type, func_type,
)?; )?;
func_directive.visit_args(|arg| { func_directive.visit_args(&mut |arg| {
let result_type = map.get_or_add_scalar(builder, arg.a_type); let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type).into());
let inst = dr::Instruction::new( let inst = dr::Instruction::new(
spirv::Op::FunctionParameter, spirv::Op::FunctionParameter,
Some(result_type), Some(result_type),
@ -285,9 +286,9 @@ fn expand_kernel_params<'a, 'b>(
args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>, args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
) -> Vec<ast::KernelArgument<ExpandedArgParams>> { ) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
args.map(|a| ast::KernelArgument { args.map(|a| ast::KernelArgument {
name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))), name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
a_type: a.a_type, v_type: a.v_type,
length: a.length, align: a.align,
}) })
.collect() .collect()
} }
@ -297,12 +298,9 @@ fn expand_fn_params<'a, 'b>(
args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>, args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
) -> Vec<ast::FnArgument<ExpandedArgParams>> { ) -> Vec<ast::FnArgument<ExpandedArgParams>> {
args.map(|a| ast::FnArgument { args.map(|a| ast::FnArgument {
state_space: a.state_space, name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
base: ast::KernelArgument { v_type: a.v_type,
name: fn_resolver.add_def(a.base.name, Some(ast::Type::Scalar(a.base.a_type))), align: a.align,
a_type: a.base.a_type,
length: a.base.length,
},
}) })
.collect() .collect()
} }
@ -375,16 +373,12 @@ fn resolve_fn_calls(
fn to_resolved_fn_args<T>( fn to_resolved_fn_args<T>(
params: Vec<T>, params: Vec<T>,
params_decl: &[(ast::FnArgStateSpace, ast::ScalarType)], params_decl: &[ast::FnArgumentType],
) -> Vec<ArgCall<T>> { ) -> Vec<(T, ast::FnArgumentType)> {
params params
.into_iter() .into_iter()
.zip(params_decl.iter()) .zip(params_decl.iter())
.map(|(id, &(space, typ))| ArgCall { .map(|(id, typ)| (id, *typ))
id,
typ: ast::Type::Scalar(typ),
space: space,
})
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
@ -476,12 +470,11 @@ fn insert_mem_ssa_statements<'a, 'b>(
let out_param = match &mut f_args { let out_param = match &mut f_args {
ast::MethodDecl::Kernel(_, in_params) => { ast::MethodDecl::Kernel(_, in_params) => {
for p in in_params.iter_mut() { for p in in_params.iter_mut() {
let typ = ast::Type::Scalar(p.a_type); let typ = ast::Type::from(p.v_type);
let new_id = id_def.new_id(Some(typ)); let new_id = id_def.new_id(Some(typ));
result.push(Statement::Variable(VariableDecl { result.push(Statement::Variable(ast::Variable {
space: ast::StateSpace::Reg, align: p.align,
align: None, v_type: ast::VariableType::Param(p.v_type),
v_type: typ,
name: p.name, name: p.name,
})); }));
result.push(Statement::StoreVar( result.push(Statement::StoreVar(
@ -497,32 +490,31 @@ fn insert_mem_ssa_statements<'a, 'b>(
} }
ast::MethodDecl::Func(out_params, _, in_params) => { ast::MethodDecl::Func(out_params, _, in_params) => {
for p in in_params.iter_mut() { for p in in_params.iter_mut() {
let typ = ast::Type::Scalar(p.base.a_type); let typ = ast::Type::from(p.v_type);
let new_id = id_def.new_id(Some(typ)); let new_id = id_def.new_id(Some(typ));
result.push(Statement::Variable(VariableDecl { let var_typ = ast::VariableType::from(p.v_type);
space: ast::StateSpace::Reg, result.push(Statement::Variable(ast::Variable {
align: None, align: p.align,
v_type: typ, v_type: var_typ,
name: p.base.name, name: p.name,
})); }));
result.push(Statement::StoreVar( result.push(Statement::StoreVar(
ast::Arg2St { ast::Arg2St {
src1: p.base.name, src1: p.name,
src2: new_id, src2: new_id,
}, },
typ, typ,
)); ));
p.base.name = new_id; p.name = new_id;
} }
match &mut **out_params { match &mut **out_params {
[p] => { [p] => {
result.push(Statement::Variable(VariableDecl { result.push(Statement::Variable(ast::Variable {
space: ast::StateSpace::Reg, align: p.align,
align: None, v_type: ast::VariableType::from(p.v_type),
v_type: ast::Type::Scalar(p.base.a_type), name: p.name,
name: p.base.name,
})); }));
Some(p.base.name) Some(p.name)
} }
[] => None, [] => None,
_ => todo!(), _ => todo!(),
@ -552,15 +544,13 @@ fn insert_mem_ssa_statements<'a, 'b>(
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst), inst => insert_mem_ssa_statement_default(id_def, &mut result, inst),
}, },
Statement::Conditional(mut bra) => { Statement::Conditional(mut bra) => {
let generated_id = id_def.new_id(Some(ast::Type::ExtendedScalar( let generated_id = id_def.new_id(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
ast::ExtendedScalarType::Pred,
)));
result.push(Statement::LoadVar( result.push(Statement::LoadVar(
Arg2 { Arg2 {
dst: generated_id, dst: generated_id,
src: bra.predicate, src: bra.predicate,
}, },
ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred), ast::Type::Scalar(ast::ScalarType::Pred),
)); ));
bra.predicate = generated_id; bra.predicate = generated_id;
result.push(Statement::Conditional(bra)); result.push(Statement::Conditional(bra));
@ -642,7 +632,15 @@ fn expand_arguments<'a, 'b>(
let new_inst = inst.map(&mut visitor); let new_inst = inst.map(&mut visitor);
result.push(Statement::Instruction(new_inst)); result.push(Statement::Instruction(new_inst));
} }
Statement::Variable(v_decl) => result.push(Statement::Variable(v_decl)), Statement::Variable(ast::Variable {
align,
v_type,
name,
}) => result.push(Statement::Variable(ast::Variable {
align,
v_type,
name,
})),
Statement::Label(id) => result.push(Statement::Label(id)), Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
@ -745,7 +743,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
) -> spirv::Word { ) -> spirv::Word {
match &desc.op { match &desc.op {
ast::MovOperand::Op(opr) => self.operand(desc.new_op(*opr)), ast::MovOperand::Op(opr) => self.operand(desc.new_op(*opr)),
ast::MovOperand::Vec(_, _) => todo!(), ast::MovOperand::Vec(opr, _) => self.variable(desc.new_op(*opr)),
} }
} }
} }
@ -835,13 +833,19 @@ fn get_function_type(
match method_decl { match method_decl {
ast::MethodDecl::Func(out_params, _, in_params) => map.get_or_add_fn( ast::MethodDecl::Func(out_params, _, in_params) => map.get_or_add_fn(
builder, builder,
out_params.iter().map(|p| SpirvType::from(p.base.a_type)), out_params
in_params.iter().map(|p| SpirvType::from(p.base.a_type)), .iter()
.map(|p| SpirvType::from(ast::Type::from(p.v_type))),
in_params
.iter()
.map(|p| SpirvType::from(ast::Type::from(p.v_type))),
), ),
ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn( ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn(
builder, builder,
iter::empty(), iter::empty(),
params.iter().map(|p| SpirvType::from(p.a_type)), params
.iter()
.map(|p| SpirvType::from(ast::Type::from(p.v_type))),
), ),
} }
} }
@ -870,31 +874,38 @@ fn emit_function_body_ops(
Statement::Label(_) => (), Statement::Label(_) => (),
Statement::Call(call) => { Statement::Call(call) => {
let (result_type, result_id) = match &*call.ret_params { let (result_type, result_id) = match &*call.ret_params {
[p] => (map.get_or_add(builder, SpirvType::from(p.typ)), p.id), [(id, typ)] => (
map.get_or_add(builder, SpirvType::from(ast::Type::from(*typ))),
*id,
),
_ => todo!(), _ => todo!(),
}; };
let arg_list = call.param_list.iter().map(|p| p.id).collect::<Vec<_>>(); let arg_list = call
.param_list
.iter()
.map(|(id, _)| *id)
.collect::<Vec<_>>();
builder.function_call(result_type, Some(result_id), call.func, arg_list)?; builder.function_call(result_type, Some(result_id), call.func, arg_list)?;
} }
Statement::Variable(VariableDecl { Statement::Variable(ast::Variable {
name: id,
v_type: typ,
space: ss,
align, align,
v_type,
name,
}) => { }) => {
let type_id = map.get_or_add( let type_id = map.get_or_add(
builder, builder,
SpirvType::new_pointer(*typ, spirv::StorageClass::Function), SpirvType::new_pointer(ast::Type::from(*v_type), spirv::StorageClass::Function),
); );
let st_class = match ss { let st_class = match v_type {
ast::StateSpace::Reg | ast::StateSpace::Param => spirv::StorageClass::Function, ast::VariableType::Reg(_) | ast::VariableType::Param(_) => {
ast::StateSpace::Local => spirv::StorageClass::Workgroup, spirv::StorageClass::Function
_ => todo!(), }
ast::VariableType::Local(_) => spirv::StorageClass::Workgroup,
}; };
builder.variable(type_id, Some(*id), st_class, None); builder.variable(type_id, Some(*name), st_class, None);
if let Some(align) = align { if let Some(align) = align {
builder.decorate( builder.decorate(
*id, *name,
spirv::Decoration::Alignment, spirv::Decoration::Alignment,
&[dr::Operand::LiteralInt32(*align)], &[dr::Operand::LiteralInt32(*align)],
); );
@ -1051,7 +1062,7 @@ fn emit_cvt(
if desc.saturate || desc.flush_to_zero { if desc.saturate || desc.flush_to_zero {
todo!() todo!()
} }
let dest_t: ast::Type = desc.dst.into(); let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
builder.f_convert(result_type, Some(arg.dst), arg.src)?; builder.f_convert(result_type, Some(arg.dst), arg.src)?;
emit_rounding_decoration(builder, arg.dst, desc.rounding); emit_rounding_decoration(builder, arg.dst, desc.rounding);
@ -1060,7 +1071,7 @@ fn emit_cvt(
if desc.saturate || desc.flush_to_zero { if desc.saturate || desc.flush_to_zero {
todo!() todo!()
} }
let dest_t: ast::Type = desc.dst.into(); let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.src.is_signed() { if desc.src.is_signed() {
builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?; builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
@ -1367,7 +1378,7 @@ fn normalize_identifiers<'a, 'b>(
fn expand_map_variables<'a, 'b>( fn expand_map_variables<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>, id_defs: &mut FnStringIdResolver<'a, 'b>,
fn_defs: &GlobalFnDeclResolver, fn_defs: &GlobalFnDeclResolver<'a, 'b>,
result: &mut Vec<NormalizedStatement>, result: &mut Vec<NormalizedStatement>,
s: ast::Statement<ast::ParsedArgParams<'a>>, s: ast::Statement<ast::ParsedArgParams<'a>>,
) { ) {
@ -1386,21 +1397,19 @@ fn expand_map_variables<'a, 'b>(
))), ))),
ast::Statement::Variable(var) => match var.count { ast::Statement::Variable(var) => match var.count {
Some(count) => { Some(count) => {
for new_id in id_defs.add_defs(var.name, count, var.v_type) { for new_id in id_defs.add_defs(var.var.name, count, var.var.v_type.into()) {
result.push(Statement::Variable(VariableDecl { result.push(Statement::Variable(ast::Variable {
space: var.space, align: var.var.align,
align: var.align, v_type: var.var.v_type,
v_type: var.v_type,
name: new_id, name: new_id,
})) }))
} }
} }
None => { None => {
let new_id = id_defs.add_def(var.name, Some(var.v_type)); let new_id = id_defs.add_def(var.var.name, Some(var.var.v_type.into()));
result.push(Statement::Variable(VariableDecl { result.push(Statement::Variable(ast::Variable {
space: var.space, align: var.var.align,
align: var.align, v_type: var.var.v_type,
v_type: var.v_type,
name: new_id, name: new_id,
})); }));
} }
@ -1408,15 +1417,38 @@ fn expand_map_variables<'a, 'b>(
} }
} }
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash)]
enum PtxSpecialRegister {
Tid,
Ntid,
Ctaid,
Nctaid,
Gridid,
}
impl PtxSpecialRegister {
fn try_parse(s: &str) -> Option<Self> {
match s {
"%tid" => Some(Self::Tid),
"%ntid" => Some(Self::Ntid),
"%ctaid" => Some(Self::Ctaid),
"%nctaid" => Some(Self::Nctaid),
"%gridid" => Some(Self::Gridid),
_ => None,
}
}
}
struct GlobalStringIdResolver<'input> { struct GlobalStringIdResolver<'input> {
current_id: spirv::Word, current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>, variables: HashMap<Cow<'input, str>, spirv::Word>,
special_registers: HashMap<PtxSpecialRegister, spirv::Word>,
fns: HashMap<spirv::Word, FnDecl>, fns: HashMap<spirv::Word, FnDecl>,
} }
pub struct FnDecl { pub struct FnDecl {
ret_vals: Vec<(ast::FnArgStateSpace, ast::ScalarType)>, ret_vals: Vec<ast::FnArgumentType>,
params: Vec<(ast::FnArgStateSpace, ast::ScalarType)>, params: Vec<ast::FnArgumentType>,
} }
impl<'a> GlobalStringIdResolver<'a> { impl<'a> GlobalStringIdResolver<'a> {
@ -1424,6 +1456,7 @@ impl<'a> GlobalStringIdResolver<'a> {
Self { Self {
current_id: start_id, current_id: start_id,
variables: HashMap::new(), variables: HashMap::new(),
special_registers: HashMap::new(),
fns: HashMap::new(), fns: HashMap::new(),
} }
} }
@ -1461,6 +1494,7 @@ impl<'a> GlobalStringIdResolver<'a> {
let mut fn_resolver = FnStringIdResolver { let mut fn_resolver = FnStringIdResolver {
current_id: &mut self.current_id, current_id: &mut self.current_id,
global_variables: &self.variables, global_variables: &self.variables,
special_registers: &mut self.special_registers,
variables: vec![HashMap::new(); 1], variables: vec![HashMap::new(); 1],
type_check: HashMap::new(), type_check: HashMap::new(),
}; };
@ -1474,14 +1508,8 @@ impl<'a> GlobalStringIdResolver<'a> {
self.fns.insert( self.fns.insert(
name_id, name_id,
FnDecl { FnDecl {
ret_vals: ret_params_ids ret_vals: ret_params_ids.iter().map(|p| p.v_type).collect(),
.iter() params: params_ids.iter().map(|p| p.v_type).collect(),
.map(|p| (p.state_space, p.base.a_type))
.collect(),
params: params_ids
.iter()
.map(|p| (p.state_space, p.base.a_type))
.collect(),
}, },
); );
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids) ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
@ -1516,7 +1544,7 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
struct FnStringIdResolver<'input, 'b> { struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut spirv::Word, current_id: &'b mut spirv::Word,
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>, global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
//global: &'b mut GlobalStringIdResolver<'a>, special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>, variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
type_check: HashMap<u32, ast::Type>, type_check: HashMap<u32, ast::Type>,
} }
@ -1537,14 +1565,28 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
self.variables.pop(); self.variables.pop();
} }
fn get_id(&self, id: &str) -> spirv::Word { fn get_id(&mut self, id: &str) -> spirv::Word {
for scope in self.variables.iter().rev() { for scope in self.variables.iter().rev() {
match scope.get(id) { match scope.get(id) {
Some(id) => return *id, Some(id) => return *id,
None => continue, None => continue,
} }
} }
self.global_variables[id] match self.global_variables.get(id) {
Some(id) => *id,
None => {
let sreg = PtxSpecialRegister::try_parse(id).unwrap_or_else(|| todo!());
match self.special_registers.entry(sreg) {
hash_map::Entry::Occupied(e) => *e.get(),
hash_map::Entry::Vacant(e) => {
let numeric_id = *self.current_id;
*self.current_id += 1;
e.insert(numeric_id);
numeric_id
}
}
}
}
} }
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word { fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
@ -1602,7 +1644,7 @@ impl<'b> NumericIdResolver<'b> {
enum Statement<I, P: ast::ArgParams> { enum Statement<I, P: ast::ArgParams> {
Label(u32), Label(u32),
Variable(VariableDecl), Variable(ast::Variable<ast::VariableType, P>),
Instruction(I), Instruction(I),
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type), LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type), StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
@ -1614,18 +1656,11 @@ enum Statement<I, P: ast::ArgParams> {
RetValue(ast::RetData, spirv::Word), RetValue(ast::RetData, spirv::Word),
} }
struct VariableDecl {
pub space: ast::StateSpace,
pub align: Option<u32>,
pub v_type: ast::Type,
pub name: spirv::Word,
}
struct ResolvedCall<P: ast::ArgParams> { struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool, pub uniform: bool,
pub ret_params: Vec<ArgCall<spirv::Word>>, pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>,
pub func: spirv::Word, pub func: spirv::Word,
pub param_list: Vec<ArgCall<P::CallOperand>>, pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>,
} }
impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> { impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
@ -1636,18 +1671,14 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
let ret_params = self let ret_params = self
.ret_params .ret_params
.into_iter() .into_iter()
.map(|p| { .map(|(id, typ)| {
let new_id = visitor.variable(ArgumentDescriptor { let new_id = visitor.variable(ArgumentDescriptor {
op: p.id, op: id,
typ: Some(p.typ), typ: Some(typ.into()),
is_dst: true, is_dst: true,
is_pointer: false, is_pointer: false,
}); });
ArgCall { (new_id, typ)
id: new_id,
typ: p.typ,
space: p.space,
}
}) })
.collect(); .collect();
let func = visitor.variable(ArgumentDescriptor { let func = visitor.variable(ArgumentDescriptor {
@ -1659,18 +1690,14 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
let param_list = self let param_list = self
.param_list .param_list
.into_iter() .into_iter()
.map(|p| { .map(|(id, typ)| {
let new_id = visitor.src_call_operand(ArgumentDescriptor { let new_id = visitor.src_call_operand(ArgumentDescriptor {
op: p.id, op: id,
typ: Some(p.typ), typ: Some(typ.into()),
is_dst: false, is_dst: false,
is_pointer: false, is_pointer: false,
}); });
ArgCall { (new_id, typ)
id: new_id,
typ: p.typ,
space: p.space,
}
}) })
.collect(); .collect();
ResolvedCall { ResolvedCall {
@ -1700,12 +1727,6 @@ impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
} }
} }
struct ArgCall<ID> {
id: ID,
typ: ast::Type,
space: ast::FnArgStateSpace,
}
pub trait ArgParamsEx: ast::ArgParams { pub trait ArgParamsEx: ast::ArgParams {
fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl; fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl;
} }
@ -1817,7 +1838,9 @@ where
) -> ast::MovOperand<spirv::Word> { ) -> ast::MovOperand<spirv::Word> {
match desc.op { match desc.op {
ast::MovOperand::Op(op) => ast::MovOperand::Op(self.operand(desc.new_op(op))), ast::MovOperand::Op(op) => ast::MovOperand::Op(self.operand(desc.new_op(op))),
ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2), ast::MovOperand::Vec(reg, x2) => {
ast::MovOperand::Vec(self.variable(desc.new_op(reg)), x2)
}
} }
} }
} }
@ -1881,13 +1904,18 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
} }
ast::Instruction::Cvt(d, a) => { ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d { let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (desc.dst.into(), desc.src.into()), ast::CvtDetails::FloatFromFloat(desc) => (
ast::CvtDetails::FloatFromInt(desc) => { ast::Type::Scalar(desc.dst.into()),
(desc.dst.into(), ast::Type::Scalar(desc.src.into())) ast::Type::Scalar(desc.src.into()),
} ),
ast::CvtDetails::IntFromFloat(desc) => { ast::CvtDetails::FloatFromInt(desc) => (
(ast::Type::Scalar(desc.dst.into()), desc.src.into()) ast::Type::Scalar(desc.dst.into()),
} ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::IntFromFloat(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::IntFromInt(desc) => ( ast::CvtDetails::IntFromInt(desc) => (
ast::Type::Scalar(desc.dst.into()), ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()), ast::Type::Scalar(desc.src.into()),
@ -2261,14 +2289,14 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
ast::Arg4 { ast::Arg4 {
dst1: visitor.variable(ArgumentDescriptor { dst1: visitor.variable(ArgumentDescriptor {
op: self.dst1, op: self.dst1,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
is_dst: true, is_dst: true,
is_pointer: false, is_pointer: false,
}), }),
dst2: self.dst2.map(|dst2| { dst2: self.dst2.map(|dst2| {
visitor.variable(ArgumentDescriptor { visitor.variable(ArgumentDescriptor {
op: dst2, op: dst2,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
is_dst: true, is_dst: true,
is_pointer: false, is_pointer: false,
}) })
@ -2298,14 +2326,14 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
ast::Arg5 { ast::Arg5 {
dst1: visitor.variable(ArgumentDescriptor { dst1: visitor.variable(ArgumentDescriptor {
op: self.dst1, op: self.dst1,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
is_dst: true, is_dst: true,
is_pointer: false, is_pointer: false,
}), }),
dst2: self.dst2.map(|dst2| { dst2: self.dst2.map(|dst2| {
visitor.variable(ArgumentDescriptor { visitor.variable(ArgumentDescriptor {
op: dst2, op: dst2,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
is_dst: true, is_dst: true,
is_pointer: false, is_pointer: false,
}) })
@ -2324,7 +2352,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
}), }),
src3: visitor.operand(ArgumentDescriptor { src3: visitor.operand(ArgumentDescriptor {
op: self.src3, op: self.src3,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
is_dst: false, is_dst: false,
is_pointer: false, is_pointer: false,
}), }),
@ -2332,65 +2360,6 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
} }
} }
/*
impl<T: ArgParamsEx> ast::ArgCall<T> {
fn map<'a, U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
fn_resolve: &GlobalFnDeclResolver<'a>,
) -> ast::ArgCall<U> {
// TODO: error out if lengths don't match
let fn_decl = T::get_fn_decl(&self.func, fn_resolve);
let ret_params = self
.ret_params
.into_iter()
.zip(fn_decl.ret_vals.iter().copied())
.map(|(a, (space, typ))| {
visitor.variable(ArgumentDescriptor {
op: a,
typ: Some(ast::Type::Scalar(typ)),
is_dst: true,
is_pointer: if space == ast::FnArgStateSpace::Reg {
false
} else {
true
},
})
})
.collect();
let func = visitor.variable(ArgumentDescriptor {
op: self.func,
typ: None,
is_dst: false,
is_pointer: false,
});
let param_list = self
.param_list
.into_iter()
.zip(fn_decl.params.iter().copied())
.map(|(a, (space, typ))| {
visitor.src_call_operand(ArgumentDescriptor {
op: a,
typ: Some(ast::Type::Scalar(typ)),
is_dst: false,
is_pointer: if space == ast::FnArgStateSpace::Reg {
false
} else {
true
},
})
})
.collect();
ast::ArgCall {
uniform: false,
ret_params,
func: func,
param_list: param_list,
}
}
}
*/
impl<T> ast::CallOperand<T> { impl<T> ast::CallOperand<T> {
fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::CallOperand<U> { fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::CallOperand<U> {
match self { match self {
@ -2418,6 +2387,8 @@ enum ScalarKind {
Unsigned, Unsigned,
Signed, Signed,
Float, Float,
Float2,
Pred,
} }
impl ast::ScalarType { impl ast::ScalarType {
@ -2438,6 +2409,8 @@ impl ast::ScalarType {
ast::ScalarType::S64 => 8, ast::ScalarType::S64 => 8,
ast::ScalarType::B64 => 8, ast::ScalarType::B64 => 8,
ast::ScalarType::F64 => 8, ast::ScalarType::F64 => 8,
ast::ScalarType::F16x2 => 4,
ast::ScalarType::Pred => 1,
} }
} }
@ -2458,6 +2431,8 @@ impl ast::ScalarType {
ast::ScalarType::F16 => ScalarKind::Float, ast::ScalarType::F16 => ScalarKind::Float,
ast::ScalarType::F32 => ScalarKind::Float, ast::ScalarType::F32 => ScalarKind::Float,
ast::ScalarType::F64 => ScalarKind::Float, ast::ScalarType::F64 => ScalarKind::Float,
ast::ScalarType::F16x2 => ScalarKind::Float,
ast::ScalarType::Pred => ScalarKind::Pred,
} }
} }
@ -2490,6 +2465,11 @@ impl ast::ScalarType {
8 => ast::ScalarType::U64, 8 => ast::ScalarType::U64,
_ => unreachable!(), _ => unreachable!(),
}, },
ScalarKind::Float2 => match width {
4 => ast::ScalarType::F16x2,
_ => unreachable!(),
},
ScalarKind::Pred => ast::ScalarType::Pred,
} }
} }
} }
@ -2497,7 +2477,7 @@ impl ast::ScalarType {
impl ast::NotType { impl ast::NotType {
fn to_type(self) -> ast::Type { fn to_type(self) -> ast::Type {
match self { match self {
ast::NotType::Pred => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred), ast::NotType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
ast::NotType::B16 => ast::Type::Scalar(ast::ScalarType::B16), ast::NotType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
ast::NotType::B32 => ast::Type::Scalar(ast::ScalarType::B32), ast::NotType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
ast::NotType::B64 => ast::Type::Scalar(ast::ScalarType::B64), ast::NotType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
@ -2519,7 +2499,9 @@ impl ast::AddDetails {
fn get_type(&self) -> ast::Type { fn get_type(&self) -> ast::Type {
match self { match self {
ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()), ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => (*typ).into(), ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => {
ast::Type::Scalar((*typ).into())
}
} }
} }
} }
@ -2528,7 +2510,9 @@ impl ast::MulDetails {
fn get_type(&self) -> ast::Type { fn get_type(&self) -> ast::Type {
match self { match self {
ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()), ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => (*typ).into(), ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => {
ast::Type::Scalar((*typ).into())
}
} }
} }
} }
@ -2560,6 +2544,15 @@ impl ast::LdStateSpace {
} }
} }
impl From<ast::FnArgumentType> for ast::VariableType {
fn from(t: ast::FnArgumentType) -> Self {
match t {
ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
}
}
}
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
match (instr, operand) { match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
@ -2575,6 +2568,8 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
ScalarKind::Unsigned => { ScalarKind::Unsigned => {
operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed
} }
ScalarKind::Float2 => todo!(),
ScalarKind::Pred => false,
} }
} }
_ => false, _ => false,
@ -2758,6 +2753,8 @@ fn should_convert_relaxed_src(
None None
} }
} }
ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None,
}, },
_ => None, _ => None,
} }
@ -2807,6 +2804,8 @@ fn should_convert_relaxed_dst(
None None
} }
} }
ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None,
}, },
_ => None, _ => None,
} }
@ -2862,16 +2861,21 @@ impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> {
} }
} }
impl<'a, P: ArgParamsEx> ast::MethodDecl<'a, P> { impl<'a, P: ArgParamsEx<ID = spirv::Word>> ast::MethodDecl<'a, P> {
fn visit_args(&self, f: impl FnMut(&ast::KernelArgument<P>)) { fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument<P>)) {
match self { match self {
ast::MethodDecl::Kernel(_, params) => params.iter().for_each(f), ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f),
ast::MethodDecl::Func(_, _, params) => params.iter().map(|a| &a.base).for_each(f), ast::MethodDecl::Kernel(_, params) => params.iter().for_each(|arg| {
f(&ast::FnArgument {
align: arg.align,
name: arg.name,
v_type: ast::FnArgumentType::Param(arg.v_type),
})
}),
} }
} }
} }
// CFGs below taken from "Modern Compiler Implementation in Java"
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;