mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-05-31 04:28:55 +03:00
Support vector member read/write
This commit is contained in:
@ -456,7 +456,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
||||
Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?,
|
||||
Statement::RepackVector(_) => todo!(),
|
||||
Statement::FunctionPointer(_) => todo!(),
|
||||
Statement::VectorAccess(_) => todo!(),
|
||||
Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
|
||||
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
|
||||
})
|
||||
}
|
||||
|
||||
@ -986,6 +987,39 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_vector_read(&mut self, vec_acccess: VectorRead) -> Result<(), TranslateError> {
|
||||
let src = self.resolver.value(vec_acccess.vector_src)?;
|
||||
let index = unsafe {
|
||||
LLVMConstInt(
|
||||
get_scalar_type(self.context, ast::ScalarType::B8),
|
||||
vec_acccess.member as _,
|
||||
0,
|
||||
)
|
||||
};
|
||||
self.resolver
|
||||
.with_result(vec_acccess.scalar_dst, |dst| unsafe {
|
||||
LLVMBuildExtractElement(self.builder, src, index, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_vector_write(&mut self, vector_write: VectorWrite) -> Result<(), TranslateError> {
|
||||
let vector_src = self.resolver.value(vector_write.vector_src)?;
|
||||
let scalar_src = self.resolver.value(vector_write.scalar_src)?;
|
||||
let index = unsafe {
|
||||
LLVMConstInt(
|
||||
get_scalar_type(self.context, ast::ScalarType::B8),
|
||||
vector_write.member as _,
|
||||
0,
|
||||
)
|
||||
};
|
||||
self.resolver
|
||||
.with_result(vector_write.vector_dst, |dst| unsafe {
|
||||
LLVMBuildInsertElement(self.builder, vector_src, scalar_src, index, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pointer_type<'ctx>(
|
||||
|
@ -1562,7 +1562,8 @@ fn emit_function_body_ops<'input>(
|
||||
builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
|
||||
}
|
||||
}
|
||||
Statement::VectorAccess(vector_access) => todo!(),
|
||||
Statement::VectorRead(vector_access) => todo!(),
|
||||
Statement::VectorWrite(vector_write) => todo!(),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
@ -189,15 +189,12 @@ impl<'a, 'input> FlattenArguments<'a, 'input> {
|
||||
|
||||
fn vec_member(
|
||||
&mut self,
|
||||
vector_src: SpirvWord,
|
||||
vector_ident: SpirvWord,
|
||||
member: u8,
|
||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
if is_dst {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? {
|
||||
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? {
|
||||
(ast::Type::Vector(vector_width, scalar_t), space) => {
|
||||
(*vector_width, *scalar_t, *space)
|
||||
}
|
||||
@ -206,13 +203,24 @@ impl<'a, 'input> FlattenArguments<'a, 'input> {
|
||||
let temporary = self
|
||||
.resolver
|
||||
.register_unnamed(Some((scalar_type.into(), space)));
|
||||
self.result.push(Statement::VectorAccess(VectorAccess {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
dst: temporary,
|
||||
src: vector_src,
|
||||
member: member,
|
||||
}));
|
||||
if is_dst {
|
||||
self.post_stmts.push(Statement::VectorWrite(VectorWrite {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
vector_dst: vector_ident,
|
||||
vector_src: vector_ident,
|
||||
scalar_src: temporary,
|
||||
member,
|
||||
}));
|
||||
} else {
|
||||
self.result.push(Statement::VectorRead(VectorRead {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
scalar_dst: temporary,
|
||||
vector_src: vector_ident,
|
||||
member,
|
||||
}));
|
||||
}
|
||||
Ok(temporary)
|
||||
}
|
||||
|
||||
|
@ -112,7 +112,7 @@ impl<'a, 'b, 'input>
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
self.replace_sreg(args, None, is_dst)
|
||||
Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args))
|
||||
}
|
||||
}
|
||||
|
||||
@ -122,7 +122,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
|
||||
name: SpirvWord,
|
||||
vector_index: Option<u8>,
|
||||
is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
) -> Result<Option<SpirvWord>, TranslateError> {
|
||||
if let Some(sreg) = self.special_registers.get(name) {
|
||||
if is_dst {
|
||||
return Err(error_mismatched_type());
|
||||
@ -179,30 +179,33 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
|
||||
data,
|
||||
arguments,
|
||||
}));
|
||||
Ok(fn_result)
|
||||
Ok(Some(fn_result))
|
||||
} else {
|
||||
Ok(name)
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_operand<T, U, Err>(
|
||||
pub fn map_operand<T: Copy, Err>(
|
||||
this: ast::ParsedOperand<T>,
|
||||
fn_: &mut impl FnMut(T, Option<u8>) -> Result<U, Err>,
|
||||
) -> Result<ast::ParsedOperand<U>, Err> {
|
||||
fn_: &mut impl FnMut(T, Option<u8>) -> Result<Option<T>, Err>,
|
||||
) -> Result<ast::ParsedOperand<T>, Err> {
|
||||
Ok(match this {
|
||||
ast::ParsedOperand::Reg(ident) => ast::ParsedOperand::Reg(fn_(ident, None)?),
|
||||
ast::ParsedOperand::Reg(ident) => {
|
||||
ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident))
|
||||
}
|
||||
ast::ParsedOperand::RegOffset(ident, offset) => {
|
||||
ast::ParsedOperand::RegOffset(fn_(ident, None)?, offset)
|
||||
ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset)
|
||||
}
|
||||
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
|
||||
ast::ParsedOperand::VecMember(ident, member) => {
|
||||
ast::ParsedOperand::Reg(fn_(ident, Some(member))?)
|
||||
}
|
||||
ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? {
|
||||
Some(ident) => ast::ParsedOperand::Reg(ident),
|
||||
None => ast::ParsedOperand::VecMember(ident, member),
|
||||
},
|
||||
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
|
||||
idents
|
||||
.into_iter()
|
||||
.map(|ident| fn_(ident, None))
|
||||
.map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
),
|
||||
})
|
||||
|
@ -45,11 +45,18 @@ pub(super) fn run(
|
||||
Statement::RepackVector(repack),
|
||||
)?;
|
||||
}
|
||||
Statement::VectorAccess(vector_access) => {
|
||||
Statement::VectorRead(vector_read) => {
|
||||
insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
Statement::VectorAccess(vector_access),
|
||||
Statement::VectorRead(vector_read),
|
||||
)?;
|
||||
}
|
||||
Statement::VectorWrite(vector_write) => {
|
||||
insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
Statement::VectorWrite(vector_write),
|
||||
)?;
|
||||
}
|
||||
s @ Statement::Conditional(_)
|
||||
|
@ -774,7 +774,8 @@ enum Statement<I, P: ast::Operand> {
|
||||
PtrAccess(PtrAccess<P>),
|
||||
RepackVector(RepackVectorDetails),
|
||||
FunctionPointer(FunctionPointerDetails),
|
||||
VectorAccess(VectorAccess),
|
||||
VectorRead(VectorRead),
|
||||
VectorWrite(VectorWrite),
|
||||
}
|
||||
|
||||
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||
@ -954,33 +955,69 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||
offset_src,
|
||||
})
|
||||
}
|
||||
Statement::VectorAccess(VectorAccess {
|
||||
Statement::VectorRead(VectorRead {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
dst,
|
||||
src: vector_src,
|
||||
scalar_dst: dst,
|
||||
vector_src,
|
||||
member,
|
||||
}) => {
|
||||
let scalar_t = scalar_type.into();
|
||||
let vector_t = ast::Type::Vector(vector_width, scalar_type);
|
||||
let dst: SpirvWord = visitor.visit_ident(
|
||||
dst,
|
||||
Some((&scalar_type.into(), ast::StateSpace::Reg)),
|
||||
Some((&scalar_t, ast::StateSpace::Reg)),
|
||||
true,
|
||||
false,
|
||||
)?;
|
||||
let src = visitor.visit_ident(
|
||||
vector_src,
|
||||
Some((
|
||||
&ast::Type::Vector(vector_width, scalar_type),
|
||||
ast::StateSpace::Reg,
|
||||
)),
|
||||
Some((&vector_t, ast::StateSpace::Reg)),
|
||||
false,
|
||||
false,
|
||||
)?;
|
||||
Statement::VectorAccess(VectorAccess {
|
||||
Statement::VectorRead(VectorRead {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
scalar_dst: dst,
|
||||
vector_src: src,
|
||||
member,
|
||||
})
|
||||
}
|
||||
Statement::VectorWrite(VectorWrite {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
vector_dst,
|
||||
vector_src,
|
||||
scalar_src,
|
||||
member,
|
||||
}) => {
|
||||
let scalar_t = scalar_type.into();
|
||||
let vector_t = ast::Type::Vector(vector_width, scalar_type);
|
||||
let vector_dst = visitor.visit_ident(
|
||||
vector_dst,
|
||||
Some((&vector_t, ast::StateSpace::Reg)),
|
||||
true,
|
||||
false,
|
||||
)?;
|
||||
let vector_src = visitor.visit_ident(
|
||||
vector_src,
|
||||
Some((&vector_t, ast::StateSpace::Reg)),
|
||||
false,
|
||||
false,
|
||||
)?;
|
||||
let scalar_src = visitor.visit_ident(
|
||||
scalar_src,
|
||||
Some((&scalar_t, ast::StateSpace::Reg)),
|
||||
false,
|
||||
false,
|
||||
)?;
|
||||
Statement::VectorWrite(VectorWrite {
|
||||
vector_dst,
|
||||
vector_src,
|
||||
scalar_src,
|
||||
scalar_type,
|
||||
vector_width,
|
||||
dst,
|
||||
src,
|
||||
member,
|
||||
})
|
||||
}
|
||||
@ -1538,7 +1575,8 @@ fn compute_denorm_information<'input>(
|
||||
Statement::Label(_) => {}
|
||||
Statement::Variable(_) => {}
|
||||
Statement::PtrAccess { .. } => {}
|
||||
Statement::VectorAccess { .. } => {}
|
||||
Statement::VectorRead { .. } => {}
|
||||
Statement::VectorWrite { .. } => {}
|
||||
Statement::RepackVector(_) => {}
|
||||
Statement::FunctionPointer(_) => {}
|
||||
}
|
||||
@ -2058,11 +2096,20 @@ impl SpecialRegistersMap2 {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VectorAccess {
|
||||
pub struct VectorRead {
|
||||
scalar_type: ast::ScalarType,
|
||||
vector_width: u8,
|
||||
dst: SpirvWord,
|
||||
src: SpirvWord,
|
||||
scalar_dst: SpirvWord,
|
||||
vector_src: SpirvWord,
|
||||
member: u8,
|
||||
}
|
||||
|
||||
pub struct VectorWrite {
|
||||
scalar_type: ast::ScalarType,
|
||||
vector_width: u8,
|
||||
vector_dst: SpirvWord,
|
||||
vector_src: SpirvWord,
|
||||
scalar_src: SpirvWord,
|
||||
member: u8,
|
||||
}
|
||||
|
||||
|
@ -26,7 +26,8 @@ pub(super) fn run(
|
||||
| Statement::Constant(..)
|
||||
| Statement::Label(..)
|
||||
| Statement::PtrAccess { .. }
|
||||
| Statement::VectorAccess { .. }
|
||||
| Statement::VectorRead { .. }
|
||||
| Statement::VectorWrite { .. }
|
||||
| Statement::RepackVector(..)
|
||||
| Statement::FunctionPointer(..) => {}
|
||||
}
|
||||
|
Reference in New Issue
Block a user