Support vector member read/write

This commit is contained in:
Andrzej Janik
2024-10-06 06:44:14 +02:00
parent 56c41b5690
commit 6490519885
7 changed files with 147 additions and 46 deletions

View File

@ -456,7 +456,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?,
Statement::RepackVector(_) => todo!(),
Statement::FunctionPointer(_) => todo!(),
Statement::VectorAccess(_) => todo!(),
Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
})
}
@ -986,6 +987,39 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
});
Ok(())
}
fn emit_vector_read(&mut self, vec_acccess: VectorRead) -> Result<(), TranslateError> {
let src = self.resolver.value(vec_acccess.vector_src)?;
let index = unsafe {
LLVMConstInt(
get_scalar_type(self.context, ast::ScalarType::B8),
vec_acccess.member as _,
0,
)
};
self.resolver
.with_result(vec_acccess.scalar_dst, |dst| unsafe {
LLVMBuildExtractElement(self.builder, src, index, dst)
});
Ok(())
}
fn emit_vector_write(&mut self, vector_write: VectorWrite) -> Result<(), TranslateError> {
let vector_src = self.resolver.value(vector_write.vector_src)?;
let scalar_src = self.resolver.value(vector_write.scalar_src)?;
let index = unsafe {
LLVMConstInt(
get_scalar_type(self.context, ast::ScalarType::B8),
vector_write.member as _,
0,
)
};
self.resolver
.with_result(vector_write.vector_dst, |dst| unsafe {
LLVMBuildInsertElement(self.builder, vector_src, scalar_src, index, dst)
});
Ok(())
}
}
fn get_pointer_type<'ctx>(

View File

@ -1562,7 +1562,8 @@ fn emit_function_body_ops<'input>(
builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
}
}
Statement::VectorAccess(vector_access) => todo!(),
Statement::VectorRead(vector_access) => todo!(),
Statement::VectorWrite(vector_write) => todo!(),
}
}
Ok(())

View File

@ -189,15 +189,12 @@ impl<'a, 'input> FlattenArguments<'a, 'input> {
fn vec_member(
&mut self,
vector_src: SpirvWord,
vector_ident: SpirvWord,
member: u8,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
if is_dst {
return Err(error_mismatched_type());
}
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? {
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? {
(ast::Type::Vector(vector_width, scalar_t), space) => {
(*vector_width, *scalar_t, *space)
}
@ -206,13 +203,24 @@ impl<'a, 'input> FlattenArguments<'a, 'input> {
let temporary = self
.resolver
.register_unnamed(Some((scalar_type.into(), space)));
self.result.push(Statement::VectorAccess(VectorAccess {
scalar_type,
vector_width,
dst: temporary,
src: vector_src,
member: member,
}));
if is_dst {
self.post_stmts.push(Statement::VectorWrite(VectorWrite {
scalar_type,
vector_width,
vector_dst: vector_ident,
vector_src: vector_ident,
scalar_src: temporary,
member,
}));
} else {
self.result.push(Statement::VectorRead(VectorRead {
scalar_type,
vector_width,
scalar_dst: temporary,
vector_src: vector_ident,
member,
}));
}
Ok(temporary)
}

View File

@ -112,7 +112,7 @@ impl<'a, 'b, 'input>
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
self.replace_sreg(args, None, is_dst)
Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args))
}
}
@ -122,7 +122,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
name: SpirvWord,
vector_index: Option<u8>,
is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
) -> Result<Option<SpirvWord>, TranslateError> {
if let Some(sreg) = self.special_registers.get(name) {
if is_dst {
return Err(error_mismatched_type());
@ -179,30 +179,33 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
data,
arguments,
}));
Ok(fn_result)
Ok(Some(fn_result))
} else {
Ok(name)
Ok(None)
}
}
}
pub fn map_operand<T, U, Err>(
pub fn map_operand<T: Copy, Err>(
this: ast::ParsedOperand<T>,
fn_: &mut impl FnMut(T, Option<u8>) -> Result<U, Err>,
) -> Result<ast::ParsedOperand<U>, Err> {
fn_: &mut impl FnMut(T, Option<u8>) -> Result<Option<T>, Err>,
) -> Result<ast::ParsedOperand<T>, Err> {
Ok(match this {
ast::ParsedOperand::Reg(ident) => ast::ParsedOperand::Reg(fn_(ident, None)?),
ast::ParsedOperand::Reg(ident) => {
ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident))
}
ast::ParsedOperand::RegOffset(ident, offset) => {
ast::ParsedOperand::RegOffset(fn_(ident, None)?, offset)
ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset)
}
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
ast::ParsedOperand::VecMember(ident, member) => {
ast::ParsedOperand::Reg(fn_(ident, Some(member))?)
}
ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? {
Some(ident) => ast::ParsedOperand::Reg(ident),
None => ast::ParsedOperand::VecMember(ident, member),
},
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
idents
.into_iter()
.map(|ident| fn_(ident, None))
.map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
.collect::<Result<Vec<_>, _>>()?,
),
})

View File

@ -45,11 +45,18 @@ pub(super) fn run(
Statement::RepackVector(repack),
)?;
}
Statement::VectorAccess(vector_access) => {
Statement::VectorRead(vector_read) => {
insert_implicit_conversions_impl(
&mut result,
id_def,
Statement::VectorAccess(vector_access),
Statement::VectorRead(vector_read),
)?;
}
Statement::VectorWrite(vector_write) => {
insert_implicit_conversions_impl(
&mut result,
id_def,
Statement::VectorWrite(vector_write),
)?;
}
s @ Statement::Conditional(_)

View File

@ -774,7 +774,8 @@ enum Statement<I, P: ast::Operand> {
PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails),
FunctionPointer(FunctionPointerDetails),
VectorAccess(VectorAccess),
VectorRead(VectorRead),
VectorWrite(VectorWrite),
}
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
@ -954,33 +955,69 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
offset_src,
})
}
Statement::VectorAccess(VectorAccess {
Statement::VectorRead(VectorRead {
scalar_type,
vector_width,
dst,
src: vector_src,
scalar_dst: dst,
vector_src,
member,
}) => {
let scalar_t = scalar_type.into();
let vector_t = ast::Type::Vector(vector_width, scalar_type);
let dst: SpirvWord = visitor.visit_ident(
dst,
Some((&scalar_type.into(), ast::StateSpace::Reg)),
Some((&scalar_t, ast::StateSpace::Reg)),
true,
false,
)?;
let src = visitor.visit_ident(
vector_src,
Some((
&ast::Type::Vector(vector_width, scalar_type),
ast::StateSpace::Reg,
)),
Some((&vector_t, ast::StateSpace::Reg)),
false,
false,
)?;
Statement::VectorAccess(VectorAccess {
Statement::VectorRead(VectorRead {
scalar_type,
vector_width,
scalar_dst: dst,
vector_src: src,
member,
})
}
Statement::VectorWrite(VectorWrite {
scalar_type,
vector_width,
vector_dst,
vector_src,
scalar_src,
member,
}) => {
let scalar_t = scalar_type.into();
let vector_t = ast::Type::Vector(vector_width, scalar_type);
let vector_dst = visitor.visit_ident(
vector_dst,
Some((&vector_t, ast::StateSpace::Reg)),
true,
false,
)?;
let vector_src = visitor.visit_ident(
vector_src,
Some((&vector_t, ast::StateSpace::Reg)),
false,
false,
)?;
let scalar_src = visitor.visit_ident(
scalar_src,
Some((&scalar_t, ast::StateSpace::Reg)),
false,
false,
)?;
Statement::VectorWrite(VectorWrite {
vector_dst,
vector_src,
scalar_src,
scalar_type,
vector_width,
dst,
src,
member,
})
}
@ -1538,7 +1575,8 @@ fn compute_denorm_information<'input>(
Statement::Label(_) => {}
Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {}
Statement::VectorAccess { .. } => {}
Statement::VectorRead { .. } => {}
Statement::VectorWrite { .. } => {}
Statement::RepackVector(_) => {}
Statement::FunctionPointer(_) => {}
}
@ -2058,11 +2096,20 @@ impl SpecialRegistersMap2 {
}
}
pub struct VectorAccess {
pub struct VectorRead {
scalar_type: ast::ScalarType,
vector_width: u8,
dst: SpirvWord,
src: SpirvWord,
scalar_dst: SpirvWord,
vector_src: SpirvWord,
member: u8,
}
pub struct VectorWrite {
scalar_type: ast::ScalarType,
vector_width: u8,
vector_dst: SpirvWord,
vector_src: SpirvWord,
scalar_src: SpirvWord,
member: u8,
}

View File

@ -26,7 +26,8 @@ pub(super) fn run(
| Statement::Constant(..)
| Statement::Label(..)
| Statement::PtrAccess { .. }
| Statement::VectorAccess { .. }
| Statement::VectorRead { .. }
| Statement::VectorWrite { .. }
| Statement::RepackVector(..)
| Statement::FunctionPointer(..) => {}
}