mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-18 09:46:21 +03:00
Be more precise about types admitted in register definitions and method arguments
This commit is contained in:
256
ptx/src/ast.rs
256
ptx/src/ast.rs
@ -12,9 +12,117 @@ quick_error! {
|
|||||||
SyntaxError {}
|
SyntaxError {}
|
||||||
NonF32Ftz {}
|
NonF32Ftz {}
|
||||||
WrongArrayType {}
|
WrongArrayType {}
|
||||||
|
WrongVectorElement {}
|
||||||
|
MultiArrayVariable {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
macro_rules! sub_scalar_type {
|
||||||
|
($name:ident { $($variant:ident),+ $(,)? }) => {
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub enum $name {
|
||||||
|
$(
|
||||||
|
$variant,
|
||||||
|
)+
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<$name> for ScalarType {
|
||||||
|
fn from(t: $name) -> ScalarType {
|
||||||
|
match t {
|
||||||
|
$(
|
||||||
|
$name::$variant => ScalarType::$variant,
|
||||||
|
)+
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! sub_type {
|
||||||
|
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub enum $type_name {
|
||||||
|
$(
|
||||||
|
$variant ($($field_type),+),
|
||||||
|
)+
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<$type_name> for Type {
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
fn from(t: $type_name) -> Type {
|
||||||
|
match t {
|
||||||
|
$(
|
||||||
|
$type_name::$variant ( $($field_type),+ ) => Type::$variant ( $($field_type.into()),+),
|
||||||
|
)+
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
sub_type! {
|
||||||
|
VariableRegType {
|
||||||
|
Scalar(ScalarType),
|
||||||
|
Vector(SizedScalarType, u8),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sub_type! {
|
||||||
|
VariableLocalType {
|
||||||
|
Scalar(SizedScalarType),
|
||||||
|
Vector(SizedScalarType, u8),
|
||||||
|
Array(SizedScalarType, u32),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For some weird reson this is illegal:
|
||||||
|
// .param .f16x2 foobar;
|
||||||
|
// but this is legal:
|
||||||
|
// .param .f16x2 foobar[1];
|
||||||
|
sub_type! {
|
||||||
|
VariableParamType {
|
||||||
|
Scalar(ParamScalarType),
|
||||||
|
Array(SizedScalarType, u32),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sub_scalar_type!(SizedScalarType {
|
||||||
|
B8,
|
||||||
|
B16,
|
||||||
|
B32,
|
||||||
|
B64,
|
||||||
|
U8,
|
||||||
|
U16,
|
||||||
|
U32,
|
||||||
|
U64,
|
||||||
|
S8,
|
||||||
|
S16,
|
||||||
|
S32,
|
||||||
|
S64,
|
||||||
|
F16,
|
||||||
|
F16x2,
|
||||||
|
F32,
|
||||||
|
F64,
|
||||||
|
});
|
||||||
|
|
||||||
|
sub_scalar_type!(ParamScalarType {
|
||||||
|
B8,
|
||||||
|
B16,
|
||||||
|
B32,
|
||||||
|
B64,
|
||||||
|
U8,
|
||||||
|
U16,
|
||||||
|
U32,
|
||||||
|
U64,
|
||||||
|
S8,
|
||||||
|
S16,
|
||||||
|
S32,
|
||||||
|
S64,
|
||||||
|
F16,
|
||||||
|
F32,
|
||||||
|
F64,
|
||||||
|
});
|
||||||
|
|
||||||
pub trait UnwrapWithVec<E, To> {
|
pub trait UnwrapWithVec<E, To> {
|
||||||
fn unwrap_with(self, errs: &mut Vec<E>) -> To;
|
fn unwrap_with(self, errs: &mut Vec<E>) -> To;
|
||||||
}
|
}
|
||||||
@ -56,6 +164,9 @@ pub enum MethodDecl<'a, P: ArgParams> {
|
|||||||
Kernel(&'a str, Vec<KernelArgument<P>>),
|
Kernel(&'a str, Vec<KernelArgument<P>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub type FnArgument<P: ArgParams> = Variable<FnArgumentType, P>;
|
||||||
|
pub type KernelArgument<P: ArgParams> = Variable<VariableParamType, P>;
|
||||||
|
|
||||||
pub struct Function<'a, P: ArgParams, S> {
|
pub struct Function<'a, P: ArgParams, S> {
|
||||||
pub func_directive: MethodDecl<'a, P>,
|
pub func_directive: MethodDecl<'a, P>,
|
||||||
pub body: Option<Vec<S>>,
|
pub body: Option<Vec<S>>,
|
||||||
@ -63,43 +174,28 @@ pub struct Function<'a, P: ArgParams, S> {
|
|||||||
|
|
||||||
pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement<ParsedArgParams<'a>>>;
|
pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement<ParsedArgParams<'a>>>;
|
||||||
|
|
||||||
pub struct FnArgument<P: ArgParams> {
|
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||||
pub base: KernelArgument<P>,
|
pub enum FnArgumentType {
|
||||||
pub state_space: FnArgStateSpace,
|
Reg(VariableRegType),
|
||||||
|
Param(VariableParamType),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
impl From<FnArgumentType> for Type {
|
||||||
pub enum FnArgStateSpace {
|
fn from(t: FnArgumentType) -> Self {
|
||||||
Reg,
|
match t {
|
||||||
Param,
|
FnArgumentType::Reg(x) => x.into(),
|
||||||
}
|
FnArgumentType::Param(x) => x.into(),
|
||||||
|
}
|
||||||
#[derive(Default, Copy, Clone)]
|
}
|
||||||
pub struct KernelArgument<P: ArgParams> {
|
|
||||||
pub name: P::ID,
|
|
||||||
pub a_type: ScalarType,
|
|
||||||
// TODO: turn length into part of type definition
|
|
||||||
pub length: u32,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||||
pub enum Type {
|
pub enum Type {
|
||||||
Scalar(ScalarType),
|
Scalar(ScalarType),
|
||||||
ExtendedScalar(ExtendedScalarType),
|
Vector(ScalarType, u8),
|
||||||
Array(ScalarType, u32),
|
Array(ScalarType, u32),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<FloatType> for Type {
|
|
||||||
fn from(t: FloatType) -> Self {
|
|
||||||
match t {
|
|
||||||
FloatType::F16 => Type::Scalar(ScalarType::F16),
|
|
||||||
FloatType::F16x2 => Type::ExtendedScalar(ExtendedScalarType::F16x2),
|
|
||||||
FloatType::F32 => Type::Scalar(ScalarType::F32),
|
|
||||||
FloatType::F64 => Type::Scalar(ScalarType::F64),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||||
pub enum ScalarType {
|
pub enum ScalarType {
|
||||||
B8,
|
B8,
|
||||||
@ -117,25 +213,11 @@ pub enum ScalarType {
|
|||||||
F16,
|
F16,
|
||||||
F32,
|
F32,
|
||||||
F64,
|
F64,
|
||||||
|
F16x2,
|
||||||
|
Pred,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<IntType> for ScalarType {
|
sub_scalar_type!(IntType {
|
||||||
fn from(t: IntType) -> Self {
|
|
||||||
match t {
|
|
||||||
IntType::S8 => ScalarType::S8,
|
|
||||||
IntType::S16 => ScalarType::S16,
|
|
||||||
IntType::S32 => ScalarType::S32,
|
|
||||||
IntType::S64 => ScalarType::S64,
|
|
||||||
IntType::U8 => ScalarType::U8,
|
|
||||||
IntType::U16 => ScalarType::U16,
|
|
||||||
IntType::U32 => ScalarType::U32,
|
|
||||||
IntType::U64 => ScalarType::U64,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
|
||||||
pub enum IntType {
|
|
||||||
U8,
|
U8,
|
||||||
U16,
|
U16,
|
||||||
U32,
|
U32,
|
||||||
@ -143,8 +225,8 @@ pub enum IntType {
|
|||||||
S8,
|
S8,
|
||||||
S16,
|
S16,
|
||||||
S32,
|
S32,
|
||||||
S64,
|
S64
|
||||||
}
|
});
|
||||||
|
|
||||||
impl IntType {
|
impl IntType {
|
||||||
pub fn is_signed(self) -> bool {
|
pub fn is_signed(self) -> bool {
|
||||||
@ -168,19 +250,12 @@ impl IntType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
sub_scalar_type!(FloatType {
|
||||||
pub enum FloatType {
|
|
||||||
F16,
|
F16,
|
||||||
F16x2,
|
F16x2,
|
||||||
F32,
|
F32,
|
||||||
F64,
|
F64
|
||||||
}
|
});
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
|
||||||
pub enum ExtendedScalarType {
|
|
||||||
F16x2,
|
|
||||||
Pred,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for ScalarType {
|
impl Default for ScalarType {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
@ -190,19 +265,39 @@ impl Default for ScalarType {
|
|||||||
|
|
||||||
pub enum Statement<P: ArgParams> {
|
pub enum Statement<P: ArgParams> {
|
||||||
Label(P::ID),
|
Label(P::ID),
|
||||||
Variable(Variable<P>),
|
Variable(MultiVariable<P>),
|
||||||
Instruction(Option<PredAt<P::ID>>, Instruction<P>),
|
Instruction(Option<PredAt<P::ID>>, Instruction<P>),
|
||||||
Block(Vec<Statement<P>>),
|
Block(Vec<Statement<P>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Variable<P: ArgParams> {
|
pub struct MultiVariable<P: ArgParams> {
|
||||||
pub space: StateSpace,
|
pub var: Variable<VariableType, P>,
|
||||||
pub align: Option<u32>,
|
|
||||||
pub v_type: Type,
|
|
||||||
pub name: P::ID,
|
|
||||||
pub count: Option<u32>,
|
pub count: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct Variable<T, P: ArgParams> {
|
||||||
|
pub align: Option<u32>,
|
||||||
|
pub v_type: T,
|
||||||
|
pub name: P::ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Eq, PartialEq, Copy, Clone)]
|
||||||
|
pub enum VariableType {
|
||||||
|
Reg(VariableRegType),
|
||||||
|
Local(VariableLocalType),
|
||||||
|
Param(VariableParamType),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<VariableType> for Type {
|
||||||
|
fn from(t: VariableType) -> Self {
|
||||||
|
match t {
|
||||||
|
VariableType::Reg(t) => t.into(),
|
||||||
|
VariableType::Local(t) => t.into(),
|
||||||
|
VariableType::Param(t) => t.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||||
pub enum StateSpace {
|
pub enum StateSpace {
|
||||||
Reg,
|
Reg,
|
||||||
@ -322,7 +417,7 @@ pub enum CallOperand<ID> {
|
|||||||
|
|
||||||
pub enum MovOperand<ID> {
|
pub enum MovOperand<ID> {
|
||||||
Op(Operand<ID>),
|
Op(Operand<ID>),
|
||||||
Vec(String, String),
|
Vec(ID, u8),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum VectorPrefix {
|
pub enum VectorPrefix {
|
||||||
@ -334,7 +429,7 @@ pub struct LdData {
|
|||||||
pub qualifier: LdStQualifier,
|
pub qualifier: LdStQualifier,
|
||||||
pub state_space: LdStateSpace,
|
pub state_space: LdStateSpace,
|
||||||
pub caching: LdCacheOperator,
|
pub caching: LdCacheOperator,
|
||||||
pub vector: Option<VectorPrefix>,
|
pub vector: Option<u8>,
|
||||||
pub typ: ScalarType,
|
pub typ: ScalarType,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -376,6 +471,37 @@ pub struct MovData {
|
|||||||
pub typ: Type,
|
pub typ: Type,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sub_scalar_type!(MovScalarType {
|
||||||
|
B16,
|
||||||
|
B32,
|
||||||
|
B64,
|
||||||
|
U16,
|
||||||
|
U32,
|
||||||
|
U64,
|
||||||
|
S16,
|
||||||
|
S32,
|
||||||
|
S64,
|
||||||
|
F32,
|
||||||
|
F64,
|
||||||
|
Pred,
|
||||||
|
});
|
||||||
|
|
||||||
|
enum MovType {
|
||||||
|
Scalar(MovScalarType),
|
||||||
|
Vector(MovScalarType, u8),
|
||||||
|
Array(MovScalarType, u32),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<MovType> for Type {
|
||||||
|
fn from(t: MovType) -> Self {
|
||||||
|
match t {
|
||||||
|
MovType::Scalar(t) => Type::Scalar(t.into()),
|
||||||
|
MovType::Vector(t, len) => Type::Vector(t.into(), len),
|
||||||
|
MovType::Array(t, len) => Type::Array(t.into(), len),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub enum MulDetails {
|
pub enum MulDetails {
|
||||||
Int(MulIntDesc),
|
Int(MulIntDesc),
|
||||||
Float(MulFloatDesc),
|
Float(MulFloatDesc),
|
||||||
@ -587,7 +713,7 @@ pub struct StData {
|
|||||||
pub qualifier: LdStQualifier,
|
pub qualifier: LdStQualifier,
|
||||||
pub state_space: StStateSpace,
|
pub state_space: StStateSpace,
|
||||||
pub caching: StCacheOperator,
|
pub caching: StCacheOperator,
|
||||||
pub vector: Option<VectorPrefix>,
|
pub vector: Option<u8>,
|
||||||
pub typ: ScalarType,
|
pub typ: ScalarType,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ match {
|
|||||||
"@",
|
"@",
|
||||||
"[", "]",
|
"[", "]",
|
||||||
"{", "}",
|
"{", "}",
|
||||||
|
"<", ">",
|
||||||
"|",
|
"|",
|
||||||
".acquire",
|
".acquire",
|
||||||
".address_size",
|
".address_size",
|
||||||
@ -133,8 +134,6 @@ match {
|
|||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers
|
||||||
r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+" => ID,
|
r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+" => ID,
|
||||||
r"\.[a-zA-Z][a-zA-Z0-9_$]*" => DotID,
|
r"\.[a-zA-Z][a-zA-Z0-9_$]*" => DotID,
|
||||||
} else {
|
|
||||||
r"(?:[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+)<[0-9]+>" => ParametrizedID,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ExtendedID : &'input str = {
|
ExtendedID : &'input str = {
|
||||||
@ -214,7 +213,9 @@ LinkingDirective = {
|
|||||||
|
|
||||||
MethodDecl: ast::MethodDecl<'input, ast::ParsedArgParams<'input>> = {
|
MethodDecl: ast::MethodDecl<'input, ast::ParsedArgParams<'input>> = {
|
||||||
".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params),
|
".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params),
|
||||||
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
|
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
|
||||||
|
ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
KernelArguments: Vec<ast::KernelArgument<ast::ParsedArgParams<'input>>> = {
|
KernelArguments: Vec<ast::KernelArgument<ast::ParsedArgParams<'input>>> = {
|
||||||
@ -225,32 +226,25 @@ FnArguments: Vec<ast::FnArgument<ast::ParsedArgParams<'input>>> = {
|
|||||||
"(" <args:Comma<FnInput>> ")" => args
|
"(" <args:Comma<FnInput>> ")" => args
|
||||||
};
|
};
|
||||||
|
|
||||||
FnInput: ast::FnArgument<ast::ParsedArgParams<'input>> = {
|
KernelInput: ast::Variable<ast::VariableParamType, ast::ParsedArgParams<'input>> = {
|
||||||
".reg" <_type:ScalarType> <name:ExtendedID> => {
|
<v:ParamVariable> => {
|
||||||
ast::FnArgument {
|
let (align, v_type, name) = v;
|
||||||
base: ast::KernelArgument {a_type: _type, name: name, length: 1 },
|
ast::Variable{ align, v_type, name }
|
||||||
state_space: ast::FnArgStateSpace::Reg,
|
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
<p:KernelInput> => {
|
|
||||||
ast::FnArgument {
|
|
||||||
base: p,
|
|
||||||
state_space: ast::FnArgStateSpace::Param,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
FnInput: ast::Variable<ast::FnArgumentType, ast::ParsedArgParams<'input>> = {
|
||||||
KernelInput: ast::KernelArgument<ast::ParsedArgParams<'input>> = {
|
<v:RegVariable> => {
|
||||||
".param" <_type:ScalarType> <name:ExtendedID> => {
|
let (align, v_type, name) = v;
|
||||||
ast::KernelArgument {a_type: _type, name: name, length: 1 }
|
let v_type = ast::FnArgumentType::Reg(v_type);
|
||||||
|
ast::Variable{ align, v_type, name }
|
||||||
},
|
},
|
||||||
".param" <a_type:ScalarType> <name:ExtendedID> "[" <length:Num> "]" => {
|
<v:ParamVariable> => {
|
||||||
let length = length.parse::<u32>();
|
let (align, v_type, name) = v;
|
||||||
let length = length.unwrap_with(errors);
|
let v_type = ast::FnArgumentType::Param(v_type);
|
||||||
ast::KernelArgument { a_type: a_type, name: name, length: length }
|
ast::Variable{ align, v_type, name }
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
pub(crate) FunctionBody: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>> = {
|
pub(crate) FunctionBody: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>> = {
|
||||||
"{" <s:Statement*> "}" => { Some(without_none(s)) },
|
"{" <s:Statement*> "}" => { Some(without_none(s)) },
|
||||||
@ -267,22 +261,13 @@ StateSpaceSpecifier: ast::StateSpace = {
|
|||||||
".param" => ast::StateSpace::Param, // used to prepare function call
|
".param" => ast::StateSpace::Param, // used to prepare function call
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
Type: ast::Type = {
|
|
||||||
<t:ScalarType> => ast::Type::Scalar(t),
|
|
||||||
<t:ExtendedScalarType> => ast::Type::ExtendedScalar(t),
|
|
||||||
};
|
|
||||||
|
|
||||||
ScalarType: ast::ScalarType = {
|
ScalarType: ast::ScalarType = {
|
||||||
".f16" => ast::ScalarType::F16,
|
".f16" => ast::ScalarType::F16,
|
||||||
|
".f16x2" => ast::ScalarType::F16x2,
|
||||||
|
".pred" => ast::ScalarType::Pred,
|
||||||
MemoryType
|
MemoryType
|
||||||
};
|
};
|
||||||
|
|
||||||
ExtendedScalarType: ast::ExtendedScalarType = {
|
|
||||||
".f16x2" => ast::ExtendedScalarType::F16x2,
|
|
||||||
".pred" => ast::ExtendedScalarType::Pred,
|
|
||||||
};
|
|
||||||
|
|
||||||
MemoryType: ast::ScalarType = {
|
MemoryType: ast::ScalarType = {
|
||||||
".b8" => ast::ScalarType::B8,
|
".b8" => ast::ScalarType::B8,
|
||||||
".b16" => ast::ScalarType::B16,
|
".b16" => ast::ScalarType::B16,
|
||||||
@ -303,7 +288,7 @@ MemoryType: ast::ScalarType = {
|
|||||||
Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
|
Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
|
||||||
<l:Label> => Some(ast::Statement::Label(l)),
|
<l:Label> => Some(ast::Statement::Label(l)),
|
||||||
DebugDirective => None,
|
DebugDirective => None,
|
||||||
<v:Variable> ";" => Some(ast::Statement::Variable(v)),
|
<v:MultiVariable> ";" => Some(ast::Statement::Variable(v)),
|
||||||
<p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i)),
|
<p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i)),
|
||||||
"{" <s:Statement*> "}" => Some(ast::Statement::Block(without_none(s)))
|
"{" <s:Statement*> "}" => Some(ast::Statement::Block(without_none(s)))
|
||||||
};
|
};
|
||||||
@ -328,21 +313,109 @@ Align: u32 = {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Variable: ast::Variable<ast::ParsedArgParams<'input>> = {
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
|
||||||
<s:StateSpaceSpecifier> <a:Align?> <t:Type> <v:VariableName> <arr: ArraySpecifier?> => {
|
MultiVariable: ast::MultiVariable<ast::ParsedArgParams<'input>> = {
|
||||||
let (name, count) = v;
|
<var:Variable> <count:VariableParam?> => ast::MultiVariable{<>}
|
||||||
let t = match (t, arr) {
|
}
|
||||||
(ast::Type::Scalar(st), Some(arr_size)) => ast::Type::Array(st, arr_size),
|
|
||||||
(t, Some(_)) => {
|
VariableParam: u32 = {
|
||||||
errors.push(ast::PtxError::WrongArrayType);
|
"<" <n:Num> ">" => {
|
||||||
t
|
let size = n.parse::<u32>();
|
||||||
},
|
size.unwrap_with(errors)
|
||||||
(t, None) => t,
|
|
||||||
};
|
|
||||||
ast::Variable { space: s, align: a, v_type: t, name: name, count: count }
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Variable: ast::Variable<ast::VariableType, ast::ParsedArgParams<'input>> = {
|
||||||
|
<v:RegVariable> => {
|
||||||
|
let (align, v_type, name) = v;
|
||||||
|
let v_type = ast::VariableType::Reg(v_type);
|
||||||
|
ast::Variable {align, v_type, name}
|
||||||
|
},
|
||||||
|
LocalVariable,
|
||||||
|
<v:ParamVariable> => {
|
||||||
|
let (align, v_type, name) = v;
|
||||||
|
let v_type = ast::VariableType::Param(v_type);
|
||||||
|
ast::Variable {align, v_type, name}
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
|
||||||
|
".reg" <align:Align?> <t:ScalarType> <name:ExtendedID> => {
|
||||||
|
let v_type = ast::VariableRegType::Scalar(t);
|
||||||
|
(align, v_type, name)
|
||||||
|
},
|
||||||
|
".reg" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
|
||||||
|
let v_type = ast::VariableRegType::Vector(t, v_len);
|
||||||
|
(align, v_type, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LocalVariable: ast::Variable<ast::VariableType, ast::ParsedArgParams<'input>> = {
|
||||||
|
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
|
||||||
|
let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
|
||||||
|
ast::Variable {align, v_type, name}
|
||||||
|
},
|
||||||
|
".local" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
|
||||||
|
let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
|
||||||
|
ast::Variable {align, v_type, name}
|
||||||
|
},
|
||||||
|
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
|
||||||
|
let v_type = ast::VariableType::Local(ast::VariableLocalType::Array(t, arr));
|
||||||
|
ast::Variable {align, v_type, name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
||||||
|
ParamVariable: (Option<u32>, ast::VariableParamType, &'input str) = {
|
||||||
|
".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
|
||||||
|
let v_type = ast::VariableParamType::Scalar(t);
|
||||||
|
(align, v_type, name)
|
||||||
|
},
|
||||||
|
".param" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
|
||||||
|
let v_type = ast::VariableParamType::Array(t, arr);
|
||||||
|
(align, v_type, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
SizedScalarType: ast::SizedScalarType = {
|
||||||
|
".b8" => ast::SizedScalarType::B8,
|
||||||
|
".b16" => ast::SizedScalarType::B16,
|
||||||
|
".b32" => ast::SizedScalarType::B32,
|
||||||
|
".b64" => ast::SizedScalarType::B64,
|
||||||
|
".u8" => ast::SizedScalarType::U8,
|
||||||
|
".u16" => ast::SizedScalarType::U16,
|
||||||
|
".u32" => ast::SizedScalarType::U32,
|
||||||
|
".u64" => ast::SizedScalarType::U64,
|
||||||
|
".s8" => ast::SizedScalarType::S8,
|
||||||
|
".s16" => ast::SizedScalarType::S16,
|
||||||
|
".s32" => ast::SizedScalarType::S32,
|
||||||
|
".s64" => ast::SizedScalarType::S64,
|
||||||
|
".f16" => ast::SizedScalarType::F16,
|
||||||
|
".f16x2" => ast::SizedScalarType::F16x2,
|
||||||
|
".f32" => ast::SizedScalarType::F32,
|
||||||
|
".f64" => ast::SizedScalarType::F64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
ParamScalarType: ast::ParamScalarType = {
|
||||||
|
".b8" => ast::ParamScalarType::B8,
|
||||||
|
".b16" => ast::ParamScalarType::B16,
|
||||||
|
".b32" => ast::ParamScalarType::B32,
|
||||||
|
".b64" => ast::ParamScalarType::B64,
|
||||||
|
".u8" => ast::ParamScalarType::U8,
|
||||||
|
".u16" => ast::ParamScalarType::U16,
|
||||||
|
".u32" => ast::ParamScalarType::U32,
|
||||||
|
".u64" => ast::ParamScalarType::U64,
|
||||||
|
".s8" => ast::ParamScalarType::S8,
|
||||||
|
".s16" => ast::ParamScalarType::S16,
|
||||||
|
".s32" => ast::ParamScalarType::S32,
|
||||||
|
".s64" => ast::ParamScalarType::S64,
|
||||||
|
".f16" => ast::ParamScalarType::F16,
|
||||||
|
".f32" => ast::ParamScalarType::F32,
|
||||||
|
".f64" => ast::ParamScalarType::F64,
|
||||||
|
}
|
||||||
|
|
||||||
ArraySpecifier: u32 = {
|
ArraySpecifier: u32 = {
|
||||||
"[" <n:Num> "]" => {
|
"[" <n:Num> "]" => {
|
||||||
let size = n.parse::<u32>();
|
let size = n.parse::<u32>();
|
||||||
@ -350,20 +423,6 @@ ArraySpecifier: u32 = {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
VariableName: (&'input str, Option<u32>) = {
|
|
||||||
<id:ExtendedID> => (id, None),
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
|
|
||||||
<id:ParametrizedID> => {
|
|
||||||
let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap();
|
|
||||||
let count = id[left_angle+1..id.len()-1].parse::<u32>();
|
|
||||||
let count = match count {
|
|
||||||
Ok(c) => Some(c),
|
|
||||||
Err(e) => { errors.push(e.into()); None },
|
|
||||||
};
|
|
||||||
(&id[0..left_angle], count)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||||
InstLd,
|
InstLd,
|
||||||
InstMov,
|
InstMov,
|
||||||
@ -445,7 +504,7 @@ MovType: ast::Type = {
|
|||||||
".s64" => ast::Type::Scalar(ast::ScalarType::S64),
|
".s64" => ast::Type::Scalar(ast::ScalarType::S64),
|
||||||
".f32" => ast::Type::Scalar(ast::ScalarType::F32),
|
".f32" => ast::Type::Scalar(ast::ScalarType::F32),
|
||||||
".f64" => ast::Type::Scalar(ast::ScalarType::F64),
|
".f64" => ast::Type::Scalar(ast::ScalarType::F64),
|
||||||
".pred" => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)
|
".pred" => ast::Type::Scalar(ast::ScalarType::Pred)
|
||||||
};
|
};
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
|
||||||
@ -934,7 +993,17 @@ MovOperand: ast::MovOperand<&'input str> = {
|
|||||||
<o:Operand> => ast::MovOperand::Op(o),
|
<o:Operand> => ast::MovOperand::Op(o),
|
||||||
<o:VectorOperand> => {
|
<o:VectorOperand> => {
|
||||||
let (pref, suf) = o;
|
let (pref, suf) = o;
|
||||||
ast::MovOperand::Vec(pref.to_string(), suf.to_string())
|
let suf_idx = match suf {
|
||||||
|
"x" | "r" => 0,
|
||||||
|
"y" | "g" => 1,
|
||||||
|
"z" | "b" => 2,
|
||||||
|
"w" | "a" => 3,
|
||||||
|
_ => {
|
||||||
|
errors.push(ast::PtxError::WrongVectorElement);
|
||||||
|
0
|
||||||
|
}
|
||||||
|
};
|
||||||
|
ast::MovOperand::Vec(pref, suf_idx)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -980,9 +1049,9 @@ OptionalDst: &'input str = {
|
|||||||
"|" <dst2:ExtendedID> => dst2
|
"|" <dst2:ExtendedID> => dst2
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorPrefix: ast::VectorPrefix = {
|
VectorPrefix: u8 = {
|
||||||
".v2" => ast::VectorPrefix::V2,
|
".v2" => 2,
|
||||||
".v4" => ast::VectorPrefix::V4
|
".v4" => 4
|
||||||
};
|
};
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file
|
||||||
|
@ -8,16 +8,16 @@ fn parse_and_assert(s: &str) {
|
|||||||
assert!(errors.len() == 0);
|
assert!(errors.len() == 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
fn compile_and_assert(s: &str) -> Result<(), rspirv::dr::Error> {
|
||||||
fn empty() {
|
let mut errors = Vec::new();
|
||||||
parse_and_assert(".version 6.5 .target sm_30, debug");
|
let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
|
||||||
|
crate::to_spirv(ast)?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[allow(non_snake_case)]
|
fn empty() {
|
||||||
fn vectorAdd_kernel64_ptx() {
|
parse_and_assert(".version 6.5 .target sm_30, debug");
|
||||||
let vector_add = include_str!("vectorAdd_kernel64.ptx");
|
|
||||||
parse_and_assert(vector_add);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -28,8 +28,14 @@ fn operands_ptx() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
fn _Z9vectorAddPKfS0_Pfi_ptx() {
|
fn vectorAdd_kernel64_ptx() -> Result<(), rspirv::dr::Error> {
|
||||||
let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx");
|
let vector_add = include_str!("vectorAdd_kernel64.ptx");
|
||||||
parse_and_assert(vector_add);
|
compile_and_assert(vector_add)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
fn _Z9vectorAddPKfS0_Pfi_ptx() -> Result<(), rspirv::dr::Error> {
|
||||||
|
let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx");
|
||||||
|
compile_and_assert(vector_add)
|
||||||
|
}
|
||||||
|
@ -54,6 +54,7 @@ test_ptx!(cvta, [3.0f32], [3.0f32]);
|
|||||||
test_ptx!(block, [1u64], [2u64]);
|
test_ptx!(block, [1u64], [2u64]);
|
||||||
test_ptx!(local_align, [1u64], [1u64]);
|
test_ptx!(local_align, [1u64], [1u64]);
|
||||||
test_ptx!(call, [1u64], [2u64]);
|
test_ptx!(call, [1u64], [2u64]);
|
||||||
|
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
|
||||||
|
|
||||||
struct DisplayError<T: Debug> {
|
struct DisplayError<T: Debug> {
|
||||||
err: T,
|
err: T,
|
||||||
|
44
ptx/src/test/spirv_run/vector.ptx
Normal file
44
ptx/src/test/spirv_run/vector.ptx
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
// Excersise as many features of vector types as possible
|
||||||
|
|
||||||
|
.version 6.5
|
||||||
|
.target sm_53
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.func (.reg .v2 .u32 output) impl(
|
||||||
|
.reg .v2 .u32 input
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .v2 .u32 temp_v;
|
||||||
|
.reg .u32 temp1;
|
||||||
|
.reg .u32 temp2;
|
||||||
|
|
||||||
|
mov.u32 temp1, input.x;
|
||||||
|
mov.u32 temp2, input.y;
|
||||||
|
add.u32 temp2, temp1, temp2;
|
||||||
|
mov.u32 temp_v.x, temp2;
|
||||||
|
mov.u32 temp_v.y, temp2;
|
||||||
|
mov.v2.u32 output, temp_v;
|
||||||
|
ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
.visible .entry vector(
|
||||||
|
.param .u64 input_p,
|
||||||
|
.param .u64 output_p
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .v2 .u32 temp;
|
||||||
|
.reg .u32 temp1;
|
||||||
|
.reg .u32 temp2;
|
||||||
|
.reg .b64 packed;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input_p];
|
||||||
|
ld.param.u64 out_addr, [output_p];
|
||||||
|
|
||||||
|
ld.v2.u32 temp, [in_addr];
|
||||||
|
call (temp), impl, (temp);
|
||||||
|
mov.b64 packed, temp;
|
||||||
|
st.v2.u32 [out_addr], temp;
|
||||||
|
ret;
|
||||||
|
}
|
46
ptx/src/test/spirv_run/vector.spvtxt
Normal file
46
ptx/src/test/spirv_run/vector.spvtxt
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
OpCapability GenericPointer
|
||||||
|
OpCapability Linkage
|
||||||
|
OpCapability Addresses
|
||||||
|
OpCapability Kernel
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Int8
|
||||||
|
%25 = OpExtInstImport "OpenCL.std"
|
||||||
|
OpMemoryModel Physical64 OpenCL
|
||||||
|
OpEntryPoint Kernel %1 "add"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%ulong = OpTypeInt 64 0
|
||||||
|
%28 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
|
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||||
|
%ulong_1 = OpConstant %ulong 1
|
||||||
|
%1 = OpFunction %void None %28
|
||||||
|
%8 = OpFunctionParameter %ulong
|
||||||
|
%9 = OpFunctionParameter %ulong
|
||||||
|
%23 = OpLabel
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%5 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%6 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%7 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
OpStore %2 %8
|
||||||
|
OpStore %3 %9
|
||||||
|
%11 = OpLoad %ulong %2
|
||||||
|
%10 = OpCopyObject %ulong %11
|
||||||
|
OpStore %4 %10
|
||||||
|
%13 = OpLoad %ulong %3
|
||||||
|
%12 = OpCopyObject %ulong %13
|
||||||
|
OpStore %5 %12
|
||||||
|
%15 = OpLoad %ulong %4
|
||||||
|
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
|
||||||
|
%14 = OpLoad %ulong %21
|
||||||
|
OpStore %6 %14
|
||||||
|
%17 = OpLoad %ulong %6
|
||||||
|
%16 = OpIAdd %ulong %17 %ulong_1
|
||||||
|
OpStore %7 %16
|
||||||
|
%18 = OpLoad %ulong %5
|
||||||
|
%19 = OpLoad %ulong %7
|
||||||
|
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
|
||||||
|
OpStore %22 %19
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
@ -8,6 +8,7 @@ use rspirv::binary::Assemble;
|
|||||||
#[derive(PartialEq, Eq, Hash, Clone)]
|
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||||
enum SpirvType {
|
enum SpirvType {
|
||||||
Base(SpirvScalarKey),
|
Base(SpirvScalarKey),
|
||||||
|
Vector(SpirvScalarKey, u8),
|
||||||
Array(SpirvScalarKey, u32),
|
Array(SpirvScalarKey, u32),
|
||||||
Pointer(Box<SpirvType>, spirv::StorageClass),
|
Pointer(Box<SpirvType>, spirv::StorageClass),
|
||||||
Func(Option<Box<SpirvType>>, Vec<SpirvType>),
|
Func(Option<Box<SpirvType>>, Vec<SpirvType>),
|
||||||
@ -17,7 +18,7 @@ impl SpirvType {
|
|||||||
fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
|
fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
|
||||||
let key = match t {
|
let key = match t {
|
||||||
ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
|
ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
|
||||||
ast::Type::ExtendedScalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
|
ast::Type::Vector(typ, len) => SpirvType::Vector(SpirvScalarKey::from(typ), len),
|
||||||
ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len),
|
ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len),
|
||||||
};
|
};
|
||||||
SpirvType::Pointer(Box::new(key), sc)
|
SpirvType::Pointer(Box::new(key), sc)
|
||||||
@ -28,7 +29,7 @@ impl From<ast::Type> for SpirvType {
|
|||||||
fn from(t: ast::Type) -> Self {
|
fn from(t: ast::Type) -> Self {
|
||||||
match t {
|
match t {
|
||||||
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
|
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
|
||||||
ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()),
|
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
|
||||||
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
|
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -77,15 +78,8 @@ impl From<ast::ScalarType> for SpirvScalarKey {
|
|||||||
ast::ScalarType::F16 => SpirvScalarKey::F16,
|
ast::ScalarType::F16 => SpirvScalarKey::F16,
|
||||||
ast::ScalarType::F32 => SpirvScalarKey::F32,
|
ast::ScalarType::F32 => SpirvScalarKey::F32,
|
||||||
ast::ScalarType::F64 => SpirvScalarKey::F64,
|
ast::ScalarType::F64 => SpirvScalarKey::F64,
|
||||||
}
|
ast::ScalarType::F16x2 => SpirvScalarKey::F16x2,
|
||||||
}
|
ast::ScalarType::Pred => SpirvScalarKey::Pred,
|
||||||
}
|
|
||||||
|
|
||||||
impl From<ast::ExtendedScalarType> for SpirvScalarKey {
|
|
||||||
fn from(t: ast::ExtendedScalarType) -> Self {
|
|
||||||
match t {
|
|
||||||
ast::ExtendedScalarType::Pred => SpirvScalarKey::Pred,
|
|
||||||
ast::ExtendedScalarType::F16x2 => SpirvScalarKey::F16x2,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -135,6 +129,13 @@ impl TypeWordMap {
|
|||||||
.entry(t)
|
.entry(t)
|
||||||
.or_insert_with(|| b.type_pointer(None, storage, base))
|
.or_insert_with(|| b.type_pointer(None, storage, base))
|
||||||
}
|
}
|
||||||
|
SpirvType::Vector(typ, len) => {
|
||||||
|
let base = self.get_or_add_spirv_scalar(b, typ);
|
||||||
|
*self
|
||||||
|
.complex
|
||||||
|
.entry(t)
|
||||||
|
.or_insert_with(|| b.type_vector(base, len as u32))
|
||||||
|
}
|
||||||
SpirvType::Array(typ, len) => {
|
SpirvType::Array(typ, len) => {
|
||||||
let base = self.get_or_add_spirv_scalar(b, typ);
|
let base = self.get_or_add_spirv_scalar(b, typ);
|
||||||
*self
|
*self
|
||||||
@ -232,8 +233,8 @@ fn emit_function_header<'a>(
|
|||||||
spirv::FunctionControl::NONE,
|
spirv::FunctionControl::NONE,
|
||||||
func_type,
|
func_type,
|
||||||
)?;
|
)?;
|
||||||
func_directive.visit_args(|arg| {
|
func_directive.visit_args(&mut |arg| {
|
||||||
let result_type = map.get_or_add_scalar(builder, arg.a_type);
|
let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type).into());
|
||||||
let inst = dr::Instruction::new(
|
let inst = dr::Instruction::new(
|
||||||
spirv::Op::FunctionParameter,
|
spirv::Op::FunctionParameter,
|
||||||
Some(result_type),
|
Some(result_type),
|
||||||
@ -285,9 +286,9 @@ fn expand_kernel_params<'a, 'b>(
|
|||||||
args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
|
args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
|
||||||
) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
|
) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
|
||||||
args.map(|a| ast::KernelArgument {
|
args.map(|a| ast::KernelArgument {
|
||||||
name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))),
|
name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
|
||||||
a_type: a.a_type,
|
v_type: a.v_type,
|
||||||
length: a.length,
|
align: a.align,
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
@ -297,12 +298,9 @@ fn expand_fn_params<'a, 'b>(
|
|||||||
args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
|
args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
|
||||||
) -> Vec<ast::FnArgument<ExpandedArgParams>> {
|
) -> Vec<ast::FnArgument<ExpandedArgParams>> {
|
||||||
args.map(|a| ast::FnArgument {
|
args.map(|a| ast::FnArgument {
|
||||||
state_space: a.state_space,
|
name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
|
||||||
base: ast::KernelArgument {
|
v_type: a.v_type,
|
||||||
name: fn_resolver.add_def(a.base.name, Some(ast::Type::Scalar(a.base.a_type))),
|
align: a.align,
|
||||||
a_type: a.base.a_type,
|
|
||||||
length: a.base.length,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
@ -375,16 +373,12 @@ fn resolve_fn_calls(
|
|||||||
|
|
||||||
fn to_resolved_fn_args<T>(
|
fn to_resolved_fn_args<T>(
|
||||||
params: Vec<T>,
|
params: Vec<T>,
|
||||||
params_decl: &[(ast::FnArgStateSpace, ast::ScalarType)],
|
params_decl: &[ast::FnArgumentType],
|
||||||
) -> Vec<ArgCall<T>> {
|
) -> Vec<(T, ast::FnArgumentType)> {
|
||||||
params
|
params
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(params_decl.iter())
|
.zip(params_decl.iter())
|
||||||
.map(|(id, &(space, typ))| ArgCall {
|
.map(|(id, typ)| (id, *typ))
|
||||||
id,
|
|
||||||
typ: ast::Type::Scalar(typ),
|
|
||||||
space: space,
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -476,12 +470,11 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||||||
let out_param = match &mut f_args {
|
let out_param = match &mut f_args {
|
||||||
ast::MethodDecl::Kernel(_, in_params) => {
|
ast::MethodDecl::Kernel(_, in_params) => {
|
||||||
for p in in_params.iter_mut() {
|
for p in in_params.iter_mut() {
|
||||||
let typ = ast::Type::Scalar(p.a_type);
|
let typ = ast::Type::from(p.v_type);
|
||||||
let new_id = id_def.new_id(Some(typ));
|
let new_id = id_def.new_id(Some(typ));
|
||||||
result.push(Statement::Variable(VariableDecl {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
space: ast::StateSpace::Reg,
|
align: p.align,
|
||||||
align: None,
|
v_type: ast::VariableType::Param(p.v_type),
|
||||||
v_type: typ,
|
|
||||||
name: p.name,
|
name: p.name,
|
||||||
}));
|
}));
|
||||||
result.push(Statement::StoreVar(
|
result.push(Statement::StoreVar(
|
||||||
@ -497,32 +490,31 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||||||
}
|
}
|
||||||
ast::MethodDecl::Func(out_params, _, in_params) => {
|
ast::MethodDecl::Func(out_params, _, in_params) => {
|
||||||
for p in in_params.iter_mut() {
|
for p in in_params.iter_mut() {
|
||||||
let typ = ast::Type::Scalar(p.base.a_type);
|
let typ = ast::Type::from(p.v_type);
|
||||||
let new_id = id_def.new_id(Some(typ));
|
let new_id = id_def.new_id(Some(typ));
|
||||||
result.push(Statement::Variable(VariableDecl {
|
let var_typ = ast::VariableType::from(p.v_type);
|
||||||
space: ast::StateSpace::Reg,
|
result.push(Statement::Variable(ast::Variable {
|
||||||
align: None,
|
align: p.align,
|
||||||
v_type: typ,
|
v_type: var_typ,
|
||||||
name: p.base.name,
|
name: p.name,
|
||||||
}));
|
}));
|
||||||
result.push(Statement::StoreVar(
|
result.push(Statement::StoreVar(
|
||||||
ast::Arg2St {
|
ast::Arg2St {
|
||||||
src1: p.base.name,
|
src1: p.name,
|
||||||
src2: new_id,
|
src2: new_id,
|
||||||
},
|
},
|
||||||
typ,
|
typ,
|
||||||
));
|
));
|
||||||
p.base.name = new_id;
|
p.name = new_id;
|
||||||
}
|
}
|
||||||
match &mut **out_params {
|
match &mut **out_params {
|
||||||
[p] => {
|
[p] => {
|
||||||
result.push(Statement::Variable(VariableDecl {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
space: ast::StateSpace::Reg,
|
align: p.align,
|
||||||
align: None,
|
v_type: ast::VariableType::from(p.v_type),
|
||||||
v_type: ast::Type::Scalar(p.base.a_type),
|
name: p.name,
|
||||||
name: p.base.name,
|
|
||||||
}));
|
}));
|
||||||
Some(p.base.name)
|
Some(p.name)
|
||||||
}
|
}
|
||||||
[] => None,
|
[] => None,
|
||||||
_ => todo!(),
|
_ => todo!(),
|
||||||
@ -552,15 +544,13 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||||||
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst),
|
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst),
|
||||||
},
|
},
|
||||||
Statement::Conditional(mut bra) => {
|
Statement::Conditional(mut bra) => {
|
||||||
let generated_id = id_def.new_id(Some(ast::Type::ExtendedScalar(
|
let generated_id = id_def.new_id(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
|
||||||
ast::ExtendedScalarType::Pred,
|
|
||||||
)));
|
|
||||||
result.push(Statement::LoadVar(
|
result.push(Statement::LoadVar(
|
||||||
Arg2 {
|
Arg2 {
|
||||||
dst: generated_id,
|
dst: generated_id,
|
||||||
src: bra.predicate,
|
src: bra.predicate,
|
||||||
},
|
},
|
||||||
ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred),
|
ast::Type::Scalar(ast::ScalarType::Pred),
|
||||||
));
|
));
|
||||||
bra.predicate = generated_id;
|
bra.predicate = generated_id;
|
||||||
result.push(Statement::Conditional(bra));
|
result.push(Statement::Conditional(bra));
|
||||||
@ -642,7 +632,15 @@ fn expand_arguments<'a, 'b>(
|
|||||||
let new_inst = inst.map(&mut visitor);
|
let new_inst = inst.map(&mut visitor);
|
||||||
result.push(Statement::Instruction(new_inst));
|
result.push(Statement::Instruction(new_inst));
|
||||||
}
|
}
|
||||||
Statement::Variable(v_decl) => result.push(Statement::Variable(v_decl)),
|
Statement::Variable(ast::Variable {
|
||||||
|
align,
|
||||||
|
v_type,
|
||||||
|
name,
|
||||||
|
}) => result.push(Statement::Variable(ast::Variable {
|
||||||
|
align,
|
||||||
|
v_type,
|
||||||
|
name,
|
||||||
|
})),
|
||||||
Statement::Label(id) => result.push(Statement::Label(id)),
|
Statement::Label(id) => result.push(Statement::Label(id)),
|
||||||
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
|
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
|
||||||
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
|
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
|
||||||
@ -745,7 +743,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||||||
) -> spirv::Word {
|
) -> spirv::Word {
|
||||||
match &desc.op {
|
match &desc.op {
|
||||||
ast::MovOperand::Op(opr) => self.operand(desc.new_op(*opr)),
|
ast::MovOperand::Op(opr) => self.operand(desc.new_op(*opr)),
|
||||||
ast::MovOperand::Vec(_, _) => todo!(),
|
ast::MovOperand::Vec(opr, _) => self.variable(desc.new_op(*opr)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -835,13 +833,19 @@ fn get_function_type(
|
|||||||
match method_decl {
|
match method_decl {
|
||||||
ast::MethodDecl::Func(out_params, _, in_params) => map.get_or_add_fn(
|
ast::MethodDecl::Func(out_params, _, in_params) => map.get_or_add_fn(
|
||||||
builder,
|
builder,
|
||||||
out_params.iter().map(|p| SpirvType::from(p.base.a_type)),
|
out_params
|
||||||
in_params.iter().map(|p| SpirvType::from(p.base.a_type)),
|
.iter()
|
||||||
|
.map(|p| SpirvType::from(ast::Type::from(p.v_type))),
|
||||||
|
in_params
|
||||||
|
.iter()
|
||||||
|
.map(|p| SpirvType::from(ast::Type::from(p.v_type))),
|
||||||
),
|
),
|
||||||
ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn(
|
ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn(
|
||||||
builder,
|
builder,
|
||||||
iter::empty(),
|
iter::empty(),
|
||||||
params.iter().map(|p| SpirvType::from(p.a_type)),
|
params
|
||||||
|
.iter()
|
||||||
|
.map(|p| SpirvType::from(ast::Type::from(p.v_type))),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -870,31 +874,38 @@ fn emit_function_body_ops(
|
|||||||
Statement::Label(_) => (),
|
Statement::Label(_) => (),
|
||||||
Statement::Call(call) => {
|
Statement::Call(call) => {
|
||||||
let (result_type, result_id) = match &*call.ret_params {
|
let (result_type, result_id) = match &*call.ret_params {
|
||||||
[p] => (map.get_or_add(builder, SpirvType::from(p.typ)), p.id),
|
[(id, typ)] => (
|
||||||
|
map.get_or_add(builder, SpirvType::from(ast::Type::from(*typ))),
|
||||||
|
*id,
|
||||||
|
),
|
||||||
_ => todo!(),
|
_ => todo!(),
|
||||||
};
|
};
|
||||||
let arg_list = call.param_list.iter().map(|p| p.id).collect::<Vec<_>>();
|
let arg_list = call
|
||||||
|
.param_list
|
||||||
|
.iter()
|
||||||
|
.map(|(id, _)| *id)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
builder.function_call(result_type, Some(result_id), call.func, arg_list)?;
|
builder.function_call(result_type, Some(result_id), call.func, arg_list)?;
|
||||||
}
|
}
|
||||||
Statement::Variable(VariableDecl {
|
Statement::Variable(ast::Variable {
|
||||||
name: id,
|
|
||||||
v_type: typ,
|
|
||||||
space: ss,
|
|
||||||
align,
|
align,
|
||||||
|
v_type,
|
||||||
|
name,
|
||||||
}) => {
|
}) => {
|
||||||
let type_id = map.get_or_add(
|
let type_id = map.get_or_add(
|
||||||
builder,
|
builder,
|
||||||
SpirvType::new_pointer(*typ, spirv::StorageClass::Function),
|
SpirvType::new_pointer(ast::Type::from(*v_type), spirv::StorageClass::Function),
|
||||||
);
|
);
|
||||||
let st_class = match ss {
|
let st_class = match v_type {
|
||||||
ast::StateSpace::Reg | ast::StateSpace::Param => spirv::StorageClass::Function,
|
ast::VariableType::Reg(_) | ast::VariableType::Param(_) => {
|
||||||
ast::StateSpace::Local => spirv::StorageClass::Workgroup,
|
spirv::StorageClass::Function
|
||||||
_ => todo!(),
|
}
|
||||||
|
ast::VariableType::Local(_) => spirv::StorageClass::Workgroup,
|
||||||
};
|
};
|
||||||
builder.variable(type_id, Some(*id), st_class, None);
|
builder.variable(type_id, Some(*name), st_class, None);
|
||||||
if let Some(align) = align {
|
if let Some(align) = align {
|
||||||
builder.decorate(
|
builder.decorate(
|
||||||
*id,
|
*name,
|
||||||
spirv::Decoration::Alignment,
|
spirv::Decoration::Alignment,
|
||||||
&[dr::Operand::LiteralInt32(*align)],
|
&[dr::Operand::LiteralInt32(*align)],
|
||||||
);
|
);
|
||||||
@ -1051,7 +1062,7 @@ fn emit_cvt(
|
|||||||
if desc.saturate || desc.flush_to_zero {
|
if desc.saturate || desc.flush_to_zero {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
let dest_t: ast::Type = desc.dst.into();
|
let dest_t: ast::ScalarType = desc.dst.into();
|
||||||
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
|
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
|
||||||
builder.f_convert(result_type, Some(arg.dst), arg.src)?;
|
builder.f_convert(result_type, Some(arg.dst), arg.src)?;
|
||||||
emit_rounding_decoration(builder, arg.dst, desc.rounding);
|
emit_rounding_decoration(builder, arg.dst, desc.rounding);
|
||||||
@ -1060,7 +1071,7 @@ fn emit_cvt(
|
|||||||
if desc.saturate || desc.flush_to_zero {
|
if desc.saturate || desc.flush_to_zero {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
let dest_t: ast::Type = desc.dst.into();
|
let dest_t: ast::ScalarType = desc.dst.into();
|
||||||
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
|
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
|
||||||
if desc.src.is_signed() {
|
if desc.src.is_signed() {
|
||||||
builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
|
builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
|
||||||
@ -1367,7 +1378,7 @@ fn normalize_identifiers<'a, 'b>(
|
|||||||
|
|
||||||
fn expand_map_variables<'a, 'b>(
|
fn expand_map_variables<'a, 'b>(
|
||||||
id_defs: &mut FnStringIdResolver<'a, 'b>,
|
id_defs: &mut FnStringIdResolver<'a, 'b>,
|
||||||
fn_defs: &GlobalFnDeclResolver,
|
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
|
||||||
result: &mut Vec<NormalizedStatement>,
|
result: &mut Vec<NormalizedStatement>,
|
||||||
s: ast::Statement<ast::ParsedArgParams<'a>>,
|
s: ast::Statement<ast::ParsedArgParams<'a>>,
|
||||||
) {
|
) {
|
||||||
@ -1386,21 +1397,19 @@ fn expand_map_variables<'a, 'b>(
|
|||||||
))),
|
))),
|
||||||
ast::Statement::Variable(var) => match var.count {
|
ast::Statement::Variable(var) => match var.count {
|
||||||
Some(count) => {
|
Some(count) => {
|
||||||
for new_id in id_defs.add_defs(var.name, count, var.v_type) {
|
for new_id in id_defs.add_defs(var.var.name, count, var.var.v_type.into()) {
|
||||||
result.push(Statement::Variable(VariableDecl {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
space: var.space,
|
align: var.var.align,
|
||||||
align: var.align,
|
v_type: var.var.v_type,
|
||||||
v_type: var.v_type,
|
|
||||||
name: new_id,
|
name: new_id,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
let new_id = id_defs.add_def(var.name, Some(var.v_type));
|
let new_id = id_defs.add_def(var.var.name, Some(var.var.v_type.into()));
|
||||||
result.push(Statement::Variable(VariableDecl {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
space: var.space,
|
align: var.var.align,
|
||||||
align: var.align,
|
v_type: var.var.v_type,
|
||||||
v_type: var.v_type,
|
|
||||||
name: new_id,
|
name: new_id,
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
@ -1408,15 +1417,38 @@ fn expand_map_variables<'a, 'b>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash)]
|
||||||
|
enum PtxSpecialRegister {
|
||||||
|
Tid,
|
||||||
|
Ntid,
|
||||||
|
Ctaid,
|
||||||
|
Nctaid,
|
||||||
|
Gridid,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PtxSpecialRegister {
|
||||||
|
fn try_parse(s: &str) -> Option<Self> {
|
||||||
|
match s {
|
||||||
|
"%tid" => Some(Self::Tid),
|
||||||
|
"%ntid" => Some(Self::Ntid),
|
||||||
|
"%ctaid" => Some(Self::Ctaid),
|
||||||
|
"%nctaid" => Some(Self::Nctaid),
|
||||||
|
"%gridid" => Some(Self::Gridid),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct GlobalStringIdResolver<'input> {
|
struct GlobalStringIdResolver<'input> {
|
||||||
current_id: spirv::Word,
|
current_id: spirv::Word,
|
||||||
variables: HashMap<Cow<'input, str>, spirv::Word>,
|
variables: HashMap<Cow<'input, str>, spirv::Word>,
|
||||||
|
special_registers: HashMap<PtxSpecialRegister, spirv::Word>,
|
||||||
fns: HashMap<spirv::Word, FnDecl>,
|
fns: HashMap<spirv::Word, FnDecl>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct FnDecl {
|
pub struct FnDecl {
|
||||||
ret_vals: Vec<(ast::FnArgStateSpace, ast::ScalarType)>,
|
ret_vals: Vec<ast::FnArgumentType>,
|
||||||
params: Vec<(ast::FnArgStateSpace, ast::ScalarType)>,
|
params: Vec<ast::FnArgumentType>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> GlobalStringIdResolver<'a> {
|
impl<'a> GlobalStringIdResolver<'a> {
|
||||||
@ -1424,6 +1456,7 @@ impl<'a> GlobalStringIdResolver<'a> {
|
|||||||
Self {
|
Self {
|
||||||
current_id: start_id,
|
current_id: start_id,
|
||||||
variables: HashMap::new(),
|
variables: HashMap::new(),
|
||||||
|
special_registers: HashMap::new(),
|
||||||
fns: HashMap::new(),
|
fns: HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1461,6 +1494,7 @@ impl<'a> GlobalStringIdResolver<'a> {
|
|||||||
let mut fn_resolver = FnStringIdResolver {
|
let mut fn_resolver = FnStringIdResolver {
|
||||||
current_id: &mut self.current_id,
|
current_id: &mut self.current_id,
|
||||||
global_variables: &self.variables,
|
global_variables: &self.variables,
|
||||||
|
special_registers: &mut self.special_registers,
|
||||||
variables: vec![HashMap::new(); 1],
|
variables: vec![HashMap::new(); 1],
|
||||||
type_check: HashMap::new(),
|
type_check: HashMap::new(),
|
||||||
};
|
};
|
||||||
@ -1474,14 +1508,8 @@ impl<'a> GlobalStringIdResolver<'a> {
|
|||||||
self.fns.insert(
|
self.fns.insert(
|
||||||
name_id,
|
name_id,
|
||||||
FnDecl {
|
FnDecl {
|
||||||
ret_vals: ret_params_ids
|
ret_vals: ret_params_ids.iter().map(|p| p.v_type).collect(),
|
||||||
.iter()
|
params: params_ids.iter().map(|p| p.v_type).collect(),
|
||||||
.map(|p| (p.state_space, p.base.a_type))
|
|
||||||
.collect(),
|
|
||||||
params: params_ids
|
|
||||||
.iter()
|
|
||||||
.map(|p| (p.state_space, p.base.a_type))
|
|
||||||
.collect(),
|
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
|
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
|
||||||
@ -1516,7 +1544,7 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
|
|||||||
struct FnStringIdResolver<'input, 'b> {
|
struct FnStringIdResolver<'input, 'b> {
|
||||||
current_id: &'b mut spirv::Word,
|
current_id: &'b mut spirv::Word,
|
||||||
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
|
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
|
||||||
//global: &'b mut GlobalStringIdResolver<'a>,
|
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
|
||||||
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
|
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
|
||||||
type_check: HashMap<u32, ast::Type>,
|
type_check: HashMap<u32, ast::Type>,
|
||||||
}
|
}
|
||||||
@ -1537,14 +1565,28 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||||||
self.variables.pop();
|
self.variables.pop();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_id(&self, id: &str) -> spirv::Word {
|
fn get_id(&mut self, id: &str) -> spirv::Word {
|
||||||
for scope in self.variables.iter().rev() {
|
for scope in self.variables.iter().rev() {
|
||||||
match scope.get(id) {
|
match scope.get(id) {
|
||||||
Some(id) => return *id,
|
Some(id) => return *id,
|
||||||
None => continue,
|
None => continue,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
self.global_variables[id]
|
match self.global_variables.get(id) {
|
||||||
|
Some(id) => *id,
|
||||||
|
None => {
|
||||||
|
let sreg = PtxSpecialRegister::try_parse(id).unwrap_or_else(|| todo!());
|
||||||
|
match self.special_registers.entry(sreg) {
|
||||||
|
hash_map::Entry::Occupied(e) => *e.get(),
|
||||||
|
hash_map::Entry::Vacant(e) => {
|
||||||
|
let numeric_id = *self.current_id;
|
||||||
|
*self.current_id += 1;
|
||||||
|
e.insert(numeric_id);
|
||||||
|
numeric_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
|
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
|
||||||
@ -1602,7 +1644,7 @@ impl<'b> NumericIdResolver<'b> {
|
|||||||
|
|
||||||
enum Statement<I, P: ast::ArgParams> {
|
enum Statement<I, P: ast::ArgParams> {
|
||||||
Label(u32),
|
Label(u32),
|
||||||
Variable(VariableDecl),
|
Variable(ast::Variable<ast::VariableType, P>),
|
||||||
Instruction(I),
|
Instruction(I),
|
||||||
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
|
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
|
||||||
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
|
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
|
||||||
@ -1614,18 +1656,11 @@ enum Statement<I, P: ast::ArgParams> {
|
|||||||
RetValue(ast::RetData, spirv::Word),
|
RetValue(ast::RetData, spirv::Word),
|
||||||
}
|
}
|
||||||
|
|
||||||
struct VariableDecl {
|
|
||||||
pub space: ast::StateSpace,
|
|
||||||
pub align: Option<u32>,
|
|
||||||
pub v_type: ast::Type,
|
|
||||||
pub name: spirv::Word,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ResolvedCall<P: ast::ArgParams> {
|
struct ResolvedCall<P: ast::ArgParams> {
|
||||||
pub uniform: bool,
|
pub uniform: bool,
|
||||||
pub ret_params: Vec<ArgCall<spirv::Word>>,
|
pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>,
|
||||||
pub func: spirv::Word,
|
pub func: spirv::Word,
|
||||||
pub param_list: Vec<ArgCall<P::CallOperand>>,
|
pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
|
impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
|
||||||
@ -1636,18 +1671,14 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
|
|||||||
let ret_params = self
|
let ret_params = self
|
||||||
.ret_params
|
.ret_params
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|p| {
|
.map(|(id, typ)| {
|
||||||
let new_id = visitor.variable(ArgumentDescriptor {
|
let new_id = visitor.variable(ArgumentDescriptor {
|
||||||
op: p.id,
|
op: id,
|
||||||
typ: Some(p.typ),
|
typ: Some(typ.into()),
|
||||||
is_dst: true,
|
is_dst: true,
|
||||||
is_pointer: false,
|
is_pointer: false,
|
||||||
});
|
});
|
||||||
ArgCall {
|
(new_id, typ)
|
||||||
id: new_id,
|
|
||||||
typ: p.typ,
|
|
||||||
space: p.space,
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let func = visitor.variable(ArgumentDescriptor {
|
let func = visitor.variable(ArgumentDescriptor {
|
||||||
@ -1659,18 +1690,14 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
|
|||||||
let param_list = self
|
let param_list = self
|
||||||
.param_list
|
.param_list
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|p| {
|
.map(|(id, typ)| {
|
||||||
let new_id = visitor.src_call_operand(ArgumentDescriptor {
|
let new_id = visitor.src_call_operand(ArgumentDescriptor {
|
||||||
op: p.id,
|
op: id,
|
||||||
typ: Some(p.typ),
|
typ: Some(typ.into()),
|
||||||
is_dst: false,
|
is_dst: false,
|
||||||
is_pointer: false,
|
is_pointer: false,
|
||||||
});
|
});
|
||||||
ArgCall {
|
(new_id, typ)
|
||||||
id: new_id,
|
|
||||||
typ: p.typ,
|
|
||||||
space: p.space,
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
ResolvedCall {
|
ResolvedCall {
|
||||||
@ -1700,12 +1727,6 @@ impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ArgCall<ID> {
|
|
||||||
id: ID,
|
|
||||||
typ: ast::Type,
|
|
||||||
space: ast::FnArgStateSpace,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait ArgParamsEx: ast::ArgParams {
|
pub trait ArgParamsEx: ast::ArgParams {
|
||||||
fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl;
|
fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl;
|
||||||
}
|
}
|
||||||
@ -1817,7 +1838,9 @@ where
|
|||||||
) -> ast::MovOperand<spirv::Word> {
|
) -> ast::MovOperand<spirv::Word> {
|
||||||
match desc.op {
|
match desc.op {
|
||||||
ast::MovOperand::Op(op) => ast::MovOperand::Op(self.operand(desc.new_op(op))),
|
ast::MovOperand::Op(op) => ast::MovOperand::Op(self.operand(desc.new_op(op))),
|
||||||
ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2),
|
ast::MovOperand::Vec(reg, x2) => {
|
||||||
|
ast::MovOperand::Vec(self.variable(desc.new_op(reg)), x2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1881,13 +1904,18 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||||||
}
|
}
|
||||||
ast::Instruction::Cvt(d, a) => {
|
ast::Instruction::Cvt(d, a) => {
|
||||||
let (dst_t, src_t) = match &d {
|
let (dst_t, src_t) = match &d {
|
||||||
ast::CvtDetails::FloatFromFloat(desc) => (desc.dst.into(), desc.src.into()),
|
ast::CvtDetails::FloatFromFloat(desc) => (
|
||||||
ast::CvtDetails::FloatFromInt(desc) => {
|
ast::Type::Scalar(desc.dst.into()),
|
||||||
(desc.dst.into(), ast::Type::Scalar(desc.src.into()))
|
ast::Type::Scalar(desc.src.into()),
|
||||||
}
|
),
|
||||||
ast::CvtDetails::IntFromFloat(desc) => {
|
ast::CvtDetails::FloatFromInt(desc) => (
|
||||||
(ast::Type::Scalar(desc.dst.into()), desc.src.into())
|
ast::Type::Scalar(desc.dst.into()),
|
||||||
}
|
ast::Type::Scalar(desc.src.into()),
|
||||||
|
),
|
||||||
|
ast::CvtDetails::IntFromFloat(desc) => (
|
||||||
|
ast::Type::Scalar(desc.dst.into()),
|
||||||
|
ast::Type::Scalar(desc.src.into()),
|
||||||
|
),
|
||||||
ast::CvtDetails::IntFromInt(desc) => (
|
ast::CvtDetails::IntFromInt(desc) => (
|
||||||
ast::Type::Scalar(desc.dst.into()),
|
ast::Type::Scalar(desc.dst.into()),
|
||||||
ast::Type::Scalar(desc.src.into()),
|
ast::Type::Scalar(desc.src.into()),
|
||||||
@ -2261,14 +2289,14 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
|
|||||||
ast::Arg4 {
|
ast::Arg4 {
|
||||||
dst1: visitor.variable(ArgumentDescriptor {
|
dst1: visitor.variable(ArgumentDescriptor {
|
||||||
op: self.dst1,
|
op: self.dst1,
|
||||||
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
|
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
|
||||||
is_dst: true,
|
is_dst: true,
|
||||||
is_pointer: false,
|
is_pointer: false,
|
||||||
}),
|
}),
|
||||||
dst2: self.dst2.map(|dst2| {
|
dst2: self.dst2.map(|dst2| {
|
||||||
visitor.variable(ArgumentDescriptor {
|
visitor.variable(ArgumentDescriptor {
|
||||||
op: dst2,
|
op: dst2,
|
||||||
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
|
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
|
||||||
is_dst: true,
|
is_dst: true,
|
||||||
is_pointer: false,
|
is_pointer: false,
|
||||||
})
|
})
|
||||||
@ -2298,14 +2326,14 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
|
|||||||
ast::Arg5 {
|
ast::Arg5 {
|
||||||
dst1: visitor.variable(ArgumentDescriptor {
|
dst1: visitor.variable(ArgumentDescriptor {
|
||||||
op: self.dst1,
|
op: self.dst1,
|
||||||
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
|
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
|
||||||
is_dst: true,
|
is_dst: true,
|
||||||
is_pointer: false,
|
is_pointer: false,
|
||||||
}),
|
}),
|
||||||
dst2: self.dst2.map(|dst2| {
|
dst2: self.dst2.map(|dst2| {
|
||||||
visitor.variable(ArgumentDescriptor {
|
visitor.variable(ArgumentDescriptor {
|
||||||
op: dst2,
|
op: dst2,
|
||||||
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
|
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
|
||||||
is_dst: true,
|
is_dst: true,
|
||||||
is_pointer: false,
|
is_pointer: false,
|
||||||
})
|
})
|
||||||
@ -2324,7 +2352,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
|
|||||||
}),
|
}),
|
||||||
src3: visitor.operand(ArgumentDescriptor {
|
src3: visitor.operand(ArgumentDescriptor {
|
||||||
op: self.src3,
|
op: self.src3,
|
||||||
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
|
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
|
||||||
is_dst: false,
|
is_dst: false,
|
||||||
is_pointer: false,
|
is_pointer: false,
|
||||||
}),
|
}),
|
||||||
@ -2332,65 +2360,6 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
impl<T: ArgParamsEx> ast::ArgCall<T> {
|
|
||||||
fn map<'a, U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
|
||||||
self,
|
|
||||||
visitor: &mut V,
|
|
||||||
fn_resolve: &GlobalFnDeclResolver<'a>,
|
|
||||||
) -> ast::ArgCall<U> {
|
|
||||||
// TODO: error out if lengths don't match
|
|
||||||
let fn_decl = T::get_fn_decl(&self.func, fn_resolve);
|
|
||||||
let ret_params = self
|
|
||||||
.ret_params
|
|
||||||
.into_iter()
|
|
||||||
.zip(fn_decl.ret_vals.iter().copied())
|
|
||||||
.map(|(a, (space, typ))| {
|
|
||||||
visitor.variable(ArgumentDescriptor {
|
|
||||||
op: a,
|
|
||||||
typ: Some(ast::Type::Scalar(typ)),
|
|
||||||
is_dst: true,
|
|
||||||
is_pointer: if space == ast::FnArgStateSpace::Reg {
|
|
||||||
false
|
|
||||||
} else {
|
|
||||||
true
|
|
||||||
},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let func = visitor.variable(ArgumentDescriptor {
|
|
||||||
op: self.func,
|
|
||||||
typ: None,
|
|
||||||
is_dst: false,
|
|
||||||
is_pointer: false,
|
|
||||||
});
|
|
||||||
let param_list = self
|
|
||||||
.param_list
|
|
||||||
.into_iter()
|
|
||||||
.zip(fn_decl.params.iter().copied())
|
|
||||||
.map(|(a, (space, typ))| {
|
|
||||||
visitor.src_call_operand(ArgumentDescriptor {
|
|
||||||
op: a,
|
|
||||||
typ: Some(ast::Type::Scalar(typ)),
|
|
||||||
is_dst: false,
|
|
||||||
is_pointer: if space == ast::FnArgStateSpace::Reg {
|
|
||||||
false
|
|
||||||
} else {
|
|
||||||
true
|
|
||||||
},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
ast::ArgCall {
|
|
||||||
uniform: false,
|
|
||||||
ret_params,
|
|
||||||
func: func,
|
|
||||||
param_list: param_list,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
impl<T> ast::CallOperand<T> {
|
impl<T> ast::CallOperand<T> {
|
||||||
fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::CallOperand<U> {
|
fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::CallOperand<U> {
|
||||||
match self {
|
match self {
|
||||||
@ -2418,6 +2387,8 @@ enum ScalarKind {
|
|||||||
Unsigned,
|
Unsigned,
|
||||||
Signed,
|
Signed,
|
||||||
Float,
|
Float,
|
||||||
|
Float2,
|
||||||
|
Pred,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ast::ScalarType {
|
impl ast::ScalarType {
|
||||||
@ -2438,6 +2409,8 @@ impl ast::ScalarType {
|
|||||||
ast::ScalarType::S64 => 8,
|
ast::ScalarType::S64 => 8,
|
||||||
ast::ScalarType::B64 => 8,
|
ast::ScalarType::B64 => 8,
|
||||||
ast::ScalarType::F64 => 8,
|
ast::ScalarType::F64 => 8,
|
||||||
|
ast::ScalarType::F16x2 => 4,
|
||||||
|
ast::ScalarType::Pred => 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2458,6 +2431,8 @@ impl ast::ScalarType {
|
|||||||
ast::ScalarType::F16 => ScalarKind::Float,
|
ast::ScalarType::F16 => ScalarKind::Float,
|
||||||
ast::ScalarType::F32 => ScalarKind::Float,
|
ast::ScalarType::F32 => ScalarKind::Float,
|
||||||
ast::ScalarType::F64 => ScalarKind::Float,
|
ast::ScalarType::F64 => ScalarKind::Float,
|
||||||
|
ast::ScalarType::F16x2 => ScalarKind::Float,
|
||||||
|
ast::ScalarType::Pred => ScalarKind::Pred,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2490,6 +2465,11 @@ impl ast::ScalarType {
|
|||||||
8 => ast::ScalarType::U64,
|
8 => ast::ScalarType::U64,
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
},
|
},
|
||||||
|
ScalarKind::Float2 => match width {
|
||||||
|
4 => ast::ScalarType::F16x2,
|
||||||
|
_ => unreachable!(),
|
||||||
|
},
|
||||||
|
ScalarKind::Pred => ast::ScalarType::Pred,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2497,7 +2477,7 @@ impl ast::ScalarType {
|
|||||||
impl ast::NotType {
|
impl ast::NotType {
|
||||||
fn to_type(self) -> ast::Type {
|
fn to_type(self) -> ast::Type {
|
||||||
match self {
|
match self {
|
||||||
ast::NotType::Pred => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred),
|
ast::NotType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
|
||||||
ast::NotType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
|
ast::NotType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
|
||||||
ast::NotType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
|
ast::NotType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
|
||||||
ast::NotType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
|
ast::NotType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
|
||||||
@ -2519,7 +2499,9 @@ impl ast::AddDetails {
|
|||||||
fn get_type(&self) -> ast::Type {
|
fn get_type(&self) -> ast::Type {
|
||||||
match self {
|
match self {
|
||||||
ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
|
ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
|
||||||
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => (*typ).into(),
|
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => {
|
||||||
|
ast::Type::Scalar((*typ).into())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2528,7 +2510,9 @@ impl ast::MulDetails {
|
|||||||
fn get_type(&self) -> ast::Type {
|
fn get_type(&self) -> ast::Type {
|
||||||
match self {
|
match self {
|
||||||
ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
|
ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
|
||||||
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => (*typ).into(),
|
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => {
|
||||||
|
ast::Type::Scalar((*typ).into())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2560,6 +2544,15 @@ impl ast::LdStateSpace {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<ast::FnArgumentType> for ast::VariableType {
|
||||||
|
fn from(t: ast::FnArgumentType) -> Self {
|
||||||
|
match t {
|
||||||
|
ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
|
||||||
|
ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
||||||
match (instr, operand) {
|
match (instr, operand) {
|
||||||
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
||||||
@ -2575,6 +2568,8 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
|||||||
ScalarKind::Unsigned => {
|
ScalarKind::Unsigned => {
|
||||||
operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed
|
operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed
|
||||||
}
|
}
|
||||||
|
ScalarKind::Float2 => todo!(),
|
||||||
|
ScalarKind::Pred => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => false,
|
_ => false,
|
||||||
@ -2758,6 +2753,8 @@ fn should_convert_relaxed_src(
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ScalarKind::Float2 => todo!(),
|
||||||
|
ScalarKind::Pred => None,
|
||||||
},
|
},
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
@ -2807,6 +2804,8 @@ fn should_convert_relaxed_dst(
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ScalarKind::Float2 => todo!(),
|
||||||
|
ScalarKind::Pred => None,
|
||||||
},
|
},
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
@ -2862,16 +2861,21 @@ impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, P: ArgParamsEx> ast::MethodDecl<'a, P> {
|
impl<'a, P: ArgParamsEx<ID = spirv::Word>> ast::MethodDecl<'a, P> {
|
||||||
fn visit_args(&self, f: impl FnMut(&ast::KernelArgument<P>)) {
|
fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument<P>)) {
|
||||||
match self {
|
match self {
|
||||||
ast::MethodDecl::Kernel(_, params) => params.iter().for_each(f),
|
ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f),
|
||||||
ast::MethodDecl::Func(_, _, params) => params.iter().map(|a| &a.base).for_each(f),
|
ast::MethodDecl::Kernel(_, params) => params.iter().for_each(|arg| {
|
||||||
|
f(&ast::FnArgument {
|
||||||
|
align: arg.align,
|
||||||
|
name: arg.name,
|
||||||
|
v_type: ast::FnArgumentType::Param(arg.v_type),
|
||||||
|
})
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CFGs below taken from "Modern Compiler Implementation in Java"
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
Reference in New Issue
Block a user