Implement missing pieces in vector support

This commit is contained in:
Andrzej Janik
2020-09-15 02:34:08 +02:00
parent bb5025c9b1
commit fcf3aaeb16
4 changed files with 226 additions and 146 deletions

View File

@ -317,7 +317,7 @@ pub struct PredAt<ID> {
pub enum Instruction<P: ArgParams> { pub enum Instruction<P: ArgParams> {
Ld(LdData, Arg2<P>), Ld(LdData, Arg2<P>),
Mov(MovType, Arg2<P>), Mov(MovType, Arg2<P>),
MovVector(MovVectorType, Arg2Vec<P>), MovVector(MovVectorDetails, Arg2Vec<P>),
Mul(MulDetails, Arg3<P>), Mul(MulDetails, Arg3<P>),
Add(AddDetails, Arg3<P>), Add(AddDetails, Arg3<P>),
Setp(SetpData, Arg4<P>), Setp(SetpData, Arg4<P>),
@ -333,6 +333,11 @@ pub enum Instruction<P: ArgParams> {
Abs(AbsDetails, Arg2<P>), Abs(AbsDetails, Arg2<P>),
} }
#[derive(Copy, Clone)]
pub struct MovVectorDetails {
pub typ: MovVectorType,
pub length: u8,
}
pub struct AbsDetails { pub struct AbsDetails {
pub flush_to_zero: bool, pub flush_to_zero: bool,
pub typ: ScalarType, pub typ: ScalarType,
@ -377,10 +382,12 @@ pub struct Arg2St<P: ArgParams> {
pub src2: P::Operand, pub src2: P::Operand,
} }
// We duplicate dst here because during further compilation
// composite dst and composite src will receive different ids
pub enum Arg2Vec<P: ArgParams> { pub enum Arg2Vec<P: ArgParams> {
Dst(P::VecOperand, P::ID), Dst((P::ID, u8), P::ID, P::ID),
Src(P::ID, P::VecOperand), Src(P::ID, P::VecOperand),
Both(P::VecOperand, P::VecOperand), Both((P::ID, u8), P::ID, P::VecOperand),
} }
pub struct Arg3<P: ArgParams> { pub struct Arg3<P: ArgParams> {

View File

@ -499,7 +499,7 @@ InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = {
ast::Instruction::Mov(t, a) ast::Instruction::Mov(t, a)
}, },
"mov" <t:MovVectorType> <a:Arg2Vec> => { "mov" <t:MovVectorType> <a:Arg2Vec> => {
ast::Instruction::MovVector(t, a) ast::Instruction::MovVector(ast::MovVectorDetails{typ: t, length: 0}, a)
} }
}; };
@ -1030,9 +1030,9 @@ Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = {
}; };
Arg2Vec: ast::Arg2Vec<ast::ParsedArgParams<'input>> = { Arg2Vec: ast::Arg2Vec<ast::ParsedArgParams<'input>> = {
<dst:VectorOperand> "," <src:ExtendedID> => ast::Arg2Vec::Dst(dst, src), <dst:VectorOperand> "," <src:ExtendedID> => ast::Arg2Vec::Dst(dst, dst.0, src),
<dst:ExtendedID> "," <src:VectorOperand> => ast::Arg2Vec::Src(dst, src), <dst:ExtendedID> "," <src:VectorOperand> => ast::Arg2Vec::Src(dst, src),
<dst:VectorOperand> "," <src:VectorOperand> => ast::Arg2Vec::Both(dst, src), <dst:VectorOperand> "," <src:VectorOperand> => ast::Arg2Vec::Both(dst, dst.0, src),
}; };
VectorOperand: (&'input str, u8) = { VectorOperand: (&'input str, u8) = {

View File

@ -4,20 +4,20 @@
OpCapability Kernel OpCapability Kernel
OpCapability Int64 OpCapability Int64
OpCapability Int8 OpCapability Int8
%58 = OpExtInstImport "OpenCL.std" %60 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %31 "vector" OpEntryPoint Kernel %31 "vector"
%void = OpTypeVoid %void = OpTypeVoid
%uint = OpTypeInt 32 0 %uint = OpTypeInt 32 0
%v2uint = OpTypeVector %uint 2 %v2uint = OpTypeVector %uint 2
%62 = OpTypeFunction %v2uint %v2uint %64 = OpTypeFunction %v2uint %v2uint
%_ptr_Function_v2uint = OpTypePointer Function %v2uint %_ptr_Function_v2uint = OpTypePointer Function %v2uint
%_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Function_uint = OpTypePointer Function %uint
%ulong = OpTypeInt 64 0 %ulong = OpTypeInt 64 0
%66 = OpTypeFunction %void %ulong %ulong %68 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint %_ptr_Generic_v2uint = OpTypePointer Generic %v2uint
%1 = OpFunction %v2uint None %62 %1 = OpFunction %v2uint None %64
%7 = OpFunctionParameter %v2uint %7 = OpFunctionParameter %v2uint
%30 = OpLabel %30 = OpLabel
%3 = OpVariable %_ptr_Function_v2uint Function %3 = OpVariable %_ptr_Function_v2uint Function
@ -27,40 +27,40 @@
%6 = OpVariable %_ptr_Function_uint Function %6 = OpVariable %_ptr_Function_uint Function
OpStore %3 %7 OpStore %3 %7
%9 = OpLoad %v2uint %3 %9 = OpLoad %v2uint %3
%24 = OpCompositeExtract %uint %9 0 %27 = OpCompositeExtract %uint %9 0
%8 = OpCopyObject %uint %24 %8 = OpCopyObject %uint %27
OpStore %5 %8 OpStore %5 %8
%11 = OpLoad %v2uint %3 %11 = OpLoad %v2uint %3
%25 = OpCompositeExtract %uint %11 1 %28 = OpCompositeExtract %uint %11 1
%10 = OpCopyObject %uint %25 %10 = OpCopyObject %uint %28
OpStore %6 %10 OpStore %6 %10
%13 = OpLoad %uint %5 %13 = OpLoad %uint %5
%14 = OpLoad %uint %6 %14 = OpLoad %uint %6
%12 = OpIAdd %uint %13 %14 %12 = OpIAdd %uint %13 %14
OpStore %6 %12 OpStore %6 %12
%16 = OpLoad %uint %6 %16 = OpLoad %v2uint %4
%26 = OpCopyObject %uint %16 %17 = OpLoad %uint %6
%15 = OpCompositeInsert %uint %26 %15 0 %15 = OpCompositeInsert %v2uint %17 %16 0
OpStore %4 %15 OpStore %4 %15
%18 = OpLoad %uint %6 %19 = OpLoad %v2uint %4
%27 = OpCopyObject %uint %18 %20 = OpLoad %uint %6
%17 = OpCompositeInsert %uint %27 %17 1 %18 = OpCompositeInsert %v2uint %20 %19 1
OpStore %4 %17 OpStore %4 %18
%20 = OpLoad %v2uint %4
%29 = OpCompositeExtract %uint %20 1
%28 = OpCopyObject %uint %29
%19 = OpCompositeInsert %uint %28 %19 0
OpStore %4 %19
%22 = OpLoad %v2uint %4 %22 = OpLoad %v2uint %4
%21 = OpCopyObject %v2uint %22 %23 = OpLoad %v2uint %4
OpStore %2 %21 %29 = OpCompositeExtract %uint %23 1
%23 = OpLoad %v2uint %2 %21 = OpCompositeInsert %v2uint %29 %22 0
OpReturnValue %23 OpStore %4 %21
%25 = OpLoad %v2uint %4
%24 = OpCopyObject %v2uint %25
OpStore %2 %24
%26 = OpLoad %v2uint %2
OpReturnValue %26
OpFunctionEnd OpFunctionEnd
%31 = OpFunction %void None %66 %31 = OpFunction %void None %68
%40 = OpFunctionParameter %ulong %40 = OpFunctionParameter %ulong
%41 = OpFunctionParameter %ulong %41 = OpFunctionParameter %ulong
%56 = OpLabel %58 = OpLabel
%32 = OpVariable %_ptr_Function_ulong Function %32 = OpVariable %_ptr_Function_ulong Function
%33 = OpVariable %_ptr_Function_ulong Function %33 = OpVariable %_ptr_Function_ulong Function
%34 = OpVariable %_ptr_Function_ulong Function %34 = OpVariable %_ptr_Function_ulong Function
@ -85,11 +85,13 @@
%48 = OpFunctionCall %v2uint %1 %49 %48 = OpFunctionCall %v2uint %1 %49
OpStore %36 %48 OpStore %36 %48
%51 = OpLoad %v2uint %36 %51 = OpLoad %v2uint %36
%50 = OpCopyObject %ulong %51 %55 = OpBitcast %ulong %51
%56 = OpCopyObject %ulong %55
%50 = OpCopyObject %ulong %56
OpStore %39 %50 OpStore %39 %50
%52 = OpLoad %ulong %35 %52 = OpLoad %ulong %35
%53 = OpLoad %v2uint %36 %53 = OpLoad %v2uint %36
%55 = OpConvertUToPtr %_ptr_Generic_v2uint %52 %57 = OpConvertUToPtr %_ptr_Generic_v2uint %52
OpStore %55 %53 OpStore %57 %53
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd

View File

@ -323,7 +323,8 @@ fn to_ssa<'input, 'b>(
let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body); let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body);
let mut numeric_id_defs = id_defs.finish(); let mut numeric_id_defs = id_defs.finish();
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs); let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
let unadorned_statements = resolve_fn_calls(&fn_defs, unadorned_statements); let unadorned_statements =
add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs);
let (f_args, ssa_statements) = let (f_args, ssa_statements) =
insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args); insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args);
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs);
@ -345,9 +346,10 @@ fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedSta
func func
} }
fn resolve_fn_calls( fn add_types_to_statements(
fn_defs: &GlobalFnDeclResolver,
func: Vec<UnadornedStatement>, func: Vec<UnadornedStatement>,
fn_defs: &GlobalFnDeclResolver,
id_defs: &NumericIdResolver,
) -> Vec<UnadornedStatement> { ) -> Vec<UnadornedStatement> {
func.into_iter() func.into_iter()
.map(|s| { .map(|s| {
@ -365,6 +367,17 @@ fn resolve_fn_calls(
}; };
Statement::Call(resolved_call) Statement::Call(resolved_call)
} }
Statement::Instruction(ast::Instruction::MovVector(dets, args)) => {
// TODO fail on type mismatch
let new_dets = match id_defs.get_type(*args.dst()) {
Some(ast::Type::Vector(_, len)) => ast::MovVectorDetails {
length: len,
..dets
},
_ => dets,
};
Statement::Instruction(ast::Instruction::MovVector(new_dets, args))
}
s => s, s => s,
} }
}) })
@ -685,7 +698,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
fn variable( fn variable(
&mut self, &mut self,
desc: ArgumentDescriptor<spirv::Word>, desc: ArgumentDescriptor<spirv::Word>,
typ: Option<ast::Type>, _: Option<ast::Type>,
) -> spirv::Word { ) -> spirv::Word {
desc.op desc.op
} }
@ -757,34 +770,18 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
fn src_vec_operand( fn src_vec_operand(
&mut self, &mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>, desc: ArgumentDescriptor<(spirv::Word, u8)>,
typ: ast::MovVectorType, (scalar_type, vec_len): (ast::MovVectorType, u8),
) -> spirv::Word { ) -> spirv::Word {
let (vector_id, index) = desc.op; let new_id = self
let new_id = self.id_def.new_id(Some(ast::Type::Scalar(typ.into()))); .id_def
let composite = if desc.is_dst { .new_id(Some(ast::Type::Vector(scalar_type.into(), vec_len)));
Statement::Composite(CompositeAccess { self.func.push(Statement::Composite(CompositeRead {
typ: typ, typ: scalar_type,
dst: new_id, dst: new_id,
src: vector_id, src_composite: desc.op.0,
index: index as u32, src_index: desc.op.1 as u32,
is_write: true }));
}) new_id
} else {
Statement::Composite(CompositeAccess {
typ: typ,
dst: new_id,
src: vector_id,
index: index as u32,
is_write: false
})
};
if desc.is_dst {
self.post_stmts.push(composite);
new_id
} else {
self.func.push(composite);
new_id
}
} }
} }
@ -864,6 +861,55 @@ fn insert_implicit_conversions(
|arg| ast::Instruction::St(st, arg), |arg| ast::Instruction::St(st, arg),
) )
} }
ast::Instruction::Mov(d, mut arg) => {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov-2
// TODO: handle the case of mixed vector/scalar implicit conversions
let inst_typ_is_bit = match d {
ast::MovType::Scalar(t) => {
ast::ScalarType::from(t).kind() == ScalarKind::Bit
}
ast::MovType::Vector(_, _) => false,
};
let mut did_vector_implicit = false;
let mut post_conv = None;
if inst_typ_is_bit {
let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!());
if let ast::Type::Vector(_, _) = src_type {
arg.src = insert_conversion_src(
&mut result,
id_def,
arg.src,
src_type,
d.into(),
ConversionKind::Default,
);
did_vector_implicit = true;
}
let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!());
if let ast::Type::Vector(_, _) = src_type {
post_conv = Some(get_conversion_dst(
id_def,
&mut arg.dst,
d.into(),
dst_type,
ConversionKind::Default,
));
did_vector_implicit = true;
}
}
if did_vector_implicit {
result.push(Statement::Instruction(ast::Instruction::Mov(d, arg)));
} else {
insert_implicit_bitcasts(
&mut result,
id_def,
ast::Instruction::Mov(d, arg),
);
}
if let Some(post_conv) = post_conv {
result.push(post_conv);
}
}
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst), inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst),
}, },
s @ Statement::Composite(_) s @ Statement::Composite(_)
@ -1087,10 +1133,31 @@ fn emit_function_body_ops(
builder.copy_object(result_type, Some(arg.dst), arg.src)?; builder.copy_object(result_type, Some(arg.dst), arg.src)?;
} }
ast::Instruction::SetpBool(_, _) => todo!(), ast::Instruction::SetpBool(_, _) => todo!(),
ast::Instruction::MovVector(t, arg) => { ast::Instruction::MovVector(typ, arg) => match arg {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); ast::Arg2Vec::Dst((dst, dst_index), composite_src, src)
builder.copy_object(result_type, Some(arg.dst()), arg.src())?; | ast::Arg2Vec::Both((dst, dst_index), composite_src, src) => {
} let result_type = map.get_or_add(
builder,
SpirvType::Vector(
SpirvScalarKey::from(ast::ScalarType::from(typ.typ)),
typ.length,
),
);
let result_id = Some(*dst);
builder.composite_insert(
result_type,
result_id,
*src,
*composite_src,
[*dst_index as u32],
)?;
}
ast::Arg2Vec::Src(dst, src) => {
let result_type =
map.get_or_add_scalar(builder, ast::ScalarType::from(typ.typ));
builder.copy_object(result_type, Some(*dst), *src)?;
}
},
}, },
Statement::LoadVar(arg, typ) => { Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ)); let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@ -1105,15 +1172,12 @@ fn emit_function_body_ops(
Statement::Composite(c) => { Statement::Composite(c) => {
let result_type = map.get_or_add_scalar(builder, c.typ.into()); let result_type = map.get_or_add_scalar(builder, c.typ.into());
let result_id = Some(c.dst); let result_id = Some(c.dst);
let indexes = [c.index]; builder.composite_extract(
if c.is_write { result_type,
let object = c.src; result_id,
let composite = c.dst; c.src_composite,
builder.composite_insert(result_type, result_id, object, composite, indexes)?; [c.src_index],
} else { )?;
let composite = c.src;
builder.composite_extract(result_type, result_id, composite, indexes)?;
}
} }
} }
} }
@ -1369,15 +1433,15 @@ fn emit_implicit_conversion(
) -> Result<(), dr::Error> { ) -> Result<(), dr::Error> {
let from_parts = cv.from.to_parts(); let from_parts = cv.from.to_parts();
let to_parts = cv.to.to_parts(); let to_parts = cv.to.to_parts();
match cv.kind { match (from_parts.kind, to_parts.kind, cv.kind) {
ConversionKind::Ptr(space) => { (_, _, ConversionKind::Ptr(space)) => {
let dst_type = map.get_or_add( let dst_type = map.get_or_add(
builder, builder,
SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()), SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()),
); );
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
} }
ConversionKind::Default => { (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
if from_parts.width == to_parts.width { if from_parts.width == to_parts.width {
let dst_type = map.get_or_add(builder, SpirvType::from(cv.from)); let dst_type = map.get_or_add(builder, SpirvType::from(cv.from));
if from_parts.scalar_kind != ScalarKind::Float if from_parts.scalar_kind != ScalarKind::Float
@ -1424,7 +1488,13 @@ fn emit_implicit_conversion(
} }
} }
} }
ConversionKind::SignExtend => todo!(), (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(),
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
| (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default) => {
let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
}
_ => unreachable!(),
} }
Ok(()) Ok(())
} }
@ -1723,7 +1793,7 @@ enum Statement<I, P: ast::ArgParams> {
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type), LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type), StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
Call(ResolvedCall<P>), Call(ResolvedCall<P>),
Composite(CompositeAccess), Composite(CompositeRead),
// SPIR-V compatible replacement for PTX predicates // SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition), Conditional(BrachCondition),
Conversion(ImplicitConversion), Conversion(ImplicitConversion),
@ -1874,7 +1944,7 @@ trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
fn src_vec_operand( fn src_vec_operand(
&mut self, &mut self,
desc: ArgumentDescriptor<T::VecOperand>, desc: ArgumentDescriptor<T::VecOperand>,
typ: ast::MovVectorType, typ: (ast::MovVectorType, u8),
) -> U::VecOperand; ) -> U::VecOperand;
} }
@ -1902,9 +1972,12 @@ where
fn src_vec_operand( fn src_vec_operand(
&mut self, &mut self,
desc: ArgumentDescriptor<spirv::Word>, desc: ArgumentDescriptor<spirv::Word>,
t: ast::MovVectorType, (scalar_type, vec_len): (ast::MovVectorType, u8),
) -> spirv::Word { ) -> spirv::Word {
self(desc, Some(ast::Type::Scalar(t.into()))) self(
desc.new_op(desc.op),
Some(ast::Type::Vector(scalar_type.into(), vec_len)),
)
} }
} }
@ -1942,7 +2015,7 @@ where
fn src_vec_operand( fn src_vec_operand(
&mut self, &mut self,
desc: ArgumentDescriptor<(&str, u8)>, desc: ArgumentDescriptor<(&str, u8)>,
_: ast::MovVectorType, _: (ast::MovVectorType, u8),
) -> (spirv::Word, u8) { ) -> (spirv::Word, u8) {
(self(desc.op.0), desc.op.1) (self(desc.op.0), desc.op.1)
} }
@ -1970,7 +2043,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
visitor: &mut V, visitor: &mut V,
) -> ast::Instruction<U> { ) -> ast::Instruction<U> {
match self { match self {
ast::Instruction::MovVector(t, a) => ast::Instruction::MovVector(t, a.map(visitor, t)), ast::Instruction::MovVector(t, a) => {
ast::Instruction::MovVector(t, a.map(visitor, (t.typ, t.length)))
}
ast::Instruction::Abs(_, _) => todo!(), ast::Instruction::Abs(_, _) => todo!(),
// Call instruction is converted to a call statement early on // Call instruction is converted to a call statement early on
ast::Instruction::Call(_) => unreachable!(), ast::Instruction::Call(_) => unreachable!(),
@ -2090,12 +2165,12 @@ where
fn src_vec_operand( fn src_vec_operand(
&mut self, &mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>, desc: ArgumentDescriptor<(spirv::Word, u8)>,
t: ast::MovVectorType, (scalar_type, vector_len): (ast::MovVectorType, u8),
) -> (spirv::Word, u8) { ) -> (spirv::Word, u8) {
( (
self( self(
desc.new_op(desc.op.0), desc.new_op(desc.op.0),
Some(ast::Type::Vector(t.into(), desc.op.1)), Some(ast::Type::Vector(scalar_type.into(), vector_len)),
), ),
desc.op.1, desc.op.1,
) )
@ -2195,27 +2270,11 @@ impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
type Arg2 = ast::Arg2<ExpandedArgParams>; type Arg2 = ast::Arg2<ExpandedArgParams>;
type Arg2St = ast::Arg2St<ExpandedArgParams>; type Arg2St = ast::Arg2St<ExpandedArgParams>;
struct CompositeAccess {
pub typ: ast::MovVectorType,
pub dst: spirv::Word,
pub src: spirv::Word,
pub index: u32,
pub is_write: bool
}
struct CompositeWrite {
pub typ: ast::MovVectorType,
pub dst: spirv::Word,
pub src_composite: spirv::Word,
pub src_scalar: spirv::Word,
pub index: u32,
}
struct CompositeRead { struct CompositeRead {
pub typ: ast::MovVectorType, pub typ: ast::MovVectorType,
pub dst: spirv::Word, pub dst: spirv::Word,
pub src: spirv::Word, pub src_composite: spirv::Word,
pub index: u32, pub src_index: u32,
} }
struct ConstantDefinition { struct ConstantDefinition {
@ -2407,28 +2466,47 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
} }
impl<T: ArgParamsEx> ast::Arg2Vec<T> { impl<T: ArgParamsEx> ast::Arg2Vec<T> {
fn dst(&self) -> &T::ID {
match self {
ast::Arg2Vec::Dst((d, _), _, _)
| ast::Arg2Vec::Src(d, _)
| ast::Arg2Vec::Both((d, _), _, _) => d,
}
}
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>( fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self, self,
visitor: &mut V, visitor: &mut V,
t: ast::MovVectorType, (scalar_type, vec_len): (ast::MovVectorType, u8),
) -> ast::Arg2Vec<U> { ) -> ast::Arg2Vec<U> {
match self { match self {
ast::Arg2Vec::Dst(dst, src) => ast::Arg2Vec::Dst( ast::Arg2Vec::Dst((dst, len), composite_src, scalar_src) => ast::Arg2Vec::Dst(
visitor.src_vec_operand( (
ArgumentDescriptor { visitor.variable(
op: dst, ArgumentDescriptor {
is_dst: true, op: dst,
is_pointer: false, is_dst: true,
}, is_pointer: false,
t, },
Some(ast::Type::Scalar(scalar_type.into())),
),
len,
), ),
visitor.variable( visitor.variable(
ArgumentDescriptor { ArgumentDescriptor {
op: src, op: composite_src,
is_dst: false, is_dst: false,
is_pointer: false, is_pointer: false,
}, },
Some(ast::Type::Scalar(t.into())), Some(ast::Type::Scalar(scalar_type.into())),
),
visitor.variable(
ArgumentDescriptor {
op: scalar_src,
is_dst: false,
is_pointer: false,
},
Some(ast::Type::Scalar(scalar_type.into())),
), ),
), ),
ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src( ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src(
@ -2438,7 +2516,7 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
is_dst: true, is_dst: true,
is_pointer: false, is_pointer: false,
}, },
Some(ast::Type::Scalar(t.into())), Some(ast::Type::Scalar(scalar_type.into())),
), ),
visitor.src_vec_operand( visitor.src_vec_operand(
ArgumentDescriptor { ArgumentDescriptor {
@ -2446,17 +2524,28 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
is_dst: false, is_dst: false,
is_pointer: false, is_pointer: false,
}, },
t, (scalar_type, vec_len),
), ),
), ),
ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both( ast::Arg2Vec::Both((dst, len), composite_src, src) => ast::Arg2Vec::Both(
visitor.src_vec_operand( (
visitor.variable(
ArgumentDescriptor {
op: dst,
is_dst: true,
is_pointer: false,
},
Some(ast::Type::Scalar(scalar_type.into())),
),
len,
),
visitor.variable(
ArgumentDescriptor { ArgumentDescriptor {
op: dst, op: composite_src,
is_dst: true, is_dst: false,
is_pointer: false, is_pointer: false,
}, },
t, Some(ast::Type::Scalar(scalar_type.into())),
), ),
visitor.src_vec_operand( visitor.src_vec_operand(
ArgumentDescriptor { ArgumentDescriptor {
@ -2464,31 +2553,13 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
is_dst: false, is_dst: false,
is_pointer: false, is_pointer: false,
}, },
t, (scalar_type, vec_len),
), ),
), ),
} }
} }
} }
impl ast::Arg2Vec<ExpandedArgParams> {
fn dst(&self) -> spirv::Word {
match self {
ast::Arg2Vec::Dst(dst, _) | ast::Arg2Vec::Src(dst, _) | ast::Arg2Vec::Both(dst, _) => {
*dst
}
}
}
fn src(&self) -> spirv::Word {
match self {
ast::Arg2Vec::Dst(_, src) | ast::Arg2Vec::Src(_, src) | ast::Arg2Vec::Both(_, src) => {
*src
}
}
}
}
impl<T: ArgParamsEx> ast::Arg3<T> { impl<T: ArgParamsEx> ast::Arg3<T> {
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>( fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self, self,