mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-06-03 05:58:53 +03:00
Support vector member read/write
This commit is contained in:
@ -456,7 +456,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||||||
Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?,
|
Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?,
|
||||||
Statement::RepackVector(_) => todo!(),
|
Statement::RepackVector(_) => todo!(),
|
||||||
Statement::FunctionPointer(_) => todo!(),
|
Statement::FunctionPointer(_) => todo!(),
|
||||||
Statement::VectorAccess(_) => todo!(),
|
Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
|
||||||
|
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -986,6 +987,39 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||||||
});
|
});
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn emit_vector_read(&mut self, vec_acccess: VectorRead) -> Result<(), TranslateError> {
|
||||||
|
let src = self.resolver.value(vec_acccess.vector_src)?;
|
||||||
|
let index = unsafe {
|
||||||
|
LLVMConstInt(
|
||||||
|
get_scalar_type(self.context, ast::ScalarType::B8),
|
||||||
|
vec_acccess.member as _,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
self.resolver
|
||||||
|
.with_result(vec_acccess.scalar_dst, |dst| unsafe {
|
||||||
|
LLVMBuildExtractElement(self.builder, src, index, dst)
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_vector_write(&mut self, vector_write: VectorWrite) -> Result<(), TranslateError> {
|
||||||
|
let vector_src = self.resolver.value(vector_write.vector_src)?;
|
||||||
|
let scalar_src = self.resolver.value(vector_write.scalar_src)?;
|
||||||
|
let index = unsafe {
|
||||||
|
LLVMConstInt(
|
||||||
|
get_scalar_type(self.context, ast::ScalarType::B8),
|
||||||
|
vector_write.member as _,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
self.resolver
|
||||||
|
.with_result(vector_write.vector_dst, |dst| unsafe {
|
||||||
|
LLVMBuildInsertElement(self.builder, vector_src, scalar_src, index, dst)
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_pointer_type<'ctx>(
|
fn get_pointer_type<'ctx>(
|
||||||
|
@ -1562,7 +1562,8 @@ fn emit_function_body_ops<'input>(
|
|||||||
builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
|
builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Statement::VectorAccess(vector_access) => todo!(),
|
Statement::VectorRead(vector_access) => todo!(),
|
||||||
|
Statement::VectorWrite(vector_write) => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -189,15 +189,12 @@ impl<'a, 'input> FlattenArguments<'a, 'input> {
|
|||||||
|
|
||||||
fn vec_member(
|
fn vec_member(
|
||||||
&mut self,
|
&mut self,
|
||||||
vector_src: SpirvWord,
|
vector_ident: SpirvWord,
|
||||||
member: u8,
|
member: u8,
|
||||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
if is_dst {
|
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? {
|
||||||
return Err(error_mismatched_type());
|
|
||||||
}
|
|
||||||
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? {
|
|
||||||
(ast::Type::Vector(vector_width, scalar_t), space) => {
|
(ast::Type::Vector(vector_width, scalar_t), space) => {
|
||||||
(*vector_width, *scalar_t, *space)
|
(*vector_width, *scalar_t, *space)
|
||||||
}
|
}
|
||||||
@ -206,13 +203,24 @@ impl<'a, 'input> FlattenArguments<'a, 'input> {
|
|||||||
let temporary = self
|
let temporary = self
|
||||||
.resolver
|
.resolver
|
||||||
.register_unnamed(Some((scalar_type.into(), space)));
|
.register_unnamed(Some((scalar_type.into(), space)));
|
||||||
self.result.push(Statement::VectorAccess(VectorAccess {
|
if is_dst {
|
||||||
scalar_type,
|
self.post_stmts.push(Statement::VectorWrite(VectorWrite {
|
||||||
vector_width,
|
scalar_type,
|
||||||
dst: temporary,
|
vector_width,
|
||||||
src: vector_src,
|
vector_dst: vector_ident,
|
||||||
member: member,
|
vector_src: vector_ident,
|
||||||
}));
|
scalar_src: temporary,
|
||||||
|
member,
|
||||||
|
}));
|
||||||
|
} else {
|
||||||
|
self.result.push(Statement::VectorRead(VectorRead {
|
||||||
|
scalar_type,
|
||||||
|
vector_width,
|
||||||
|
scalar_dst: temporary,
|
||||||
|
vector_src: vector_ident,
|
||||||
|
member,
|
||||||
|
}));
|
||||||
|
}
|
||||||
Ok(temporary)
|
Ok(temporary)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,7 +112,7 @@ impl<'a, 'b, 'input>
|
|||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
_relaxed_type_check: bool,
|
_relaxed_type_check: bool,
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
self.replace_sreg(args, None, is_dst)
|
Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,7 +122,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
|
|||||||
name: SpirvWord,
|
name: SpirvWord,
|
||||||
vector_index: Option<u8>,
|
vector_index: Option<u8>,
|
||||||
is_dst: bool,
|
is_dst: bool,
|
||||||
) -> Result<SpirvWord, TranslateError> {
|
) -> Result<Option<SpirvWord>, TranslateError> {
|
||||||
if let Some(sreg) = self.special_registers.get(name) {
|
if let Some(sreg) = self.special_registers.get(name) {
|
||||||
if is_dst {
|
if is_dst {
|
||||||
return Err(error_mismatched_type());
|
return Err(error_mismatched_type());
|
||||||
@ -179,30 +179,33 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
|
|||||||
data,
|
data,
|
||||||
arguments,
|
arguments,
|
||||||
}));
|
}));
|
||||||
Ok(fn_result)
|
Ok(Some(fn_result))
|
||||||
} else {
|
} else {
|
||||||
Ok(name)
|
Ok(None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn map_operand<T, U, Err>(
|
pub fn map_operand<T: Copy, Err>(
|
||||||
this: ast::ParsedOperand<T>,
|
this: ast::ParsedOperand<T>,
|
||||||
fn_: &mut impl FnMut(T, Option<u8>) -> Result<U, Err>,
|
fn_: &mut impl FnMut(T, Option<u8>) -> Result<Option<T>, Err>,
|
||||||
) -> Result<ast::ParsedOperand<U>, Err> {
|
) -> Result<ast::ParsedOperand<T>, Err> {
|
||||||
Ok(match this {
|
Ok(match this {
|
||||||
ast::ParsedOperand::Reg(ident) => ast::ParsedOperand::Reg(fn_(ident, None)?),
|
ast::ParsedOperand::Reg(ident) => {
|
||||||
|
ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident))
|
||||||
|
}
|
||||||
ast::ParsedOperand::RegOffset(ident, offset) => {
|
ast::ParsedOperand::RegOffset(ident, offset) => {
|
||||||
ast::ParsedOperand::RegOffset(fn_(ident, None)?, offset)
|
ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset)
|
||||||
}
|
}
|
||||||
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
|
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
|
||||||
ast::ParsedOperand::VecMember(ident, member) => {
|
ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? {
|
||||||
ast::ParsedOperand::Reg(fn_(ident, Some(member))?)
|
Some(ident) => ast::ParsedOperand::Reg(ident),
|
||||||
}
|
None => ast::ParsedOperand::VecMember(ident, member),
|
||||||
|
},
|
||||||
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
|
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
|
||||||
idents
|
idents
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|ident| fn_(ident, None))
|
.map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
|
||||||
.collect::<Result<Vec<_>, _>>()?,
|
.collect::<Result<Vec<_>, _>>()?,
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
|
@ -45,11 +45,18 @@ pub(super) fn run(
|
|||||||
Statement::RepackVector(repack),
|
Statement::RepackVector(repack),
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
Statement::VectorAccess(vector_access) => {
|
Statement::VectorRead(vector_read) => {
|
||||||
insert_implicit_conversions_impl(
|
insert_implicit_conversions_impl(
|
||||||
&mut result,
|
&mut result,
|
||||||
id_def,
|
id_def,
|
||||||
Statement::VectorAccess(vector_access),
|
Statement::VectorRead(vector_read),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
Statement::VectorWrite(vector_write) => {
|
||||||
|
insert_implicit_conversions_impl(
|
||||||
|
&mut result,
|
||||||
|
id_def,
|
||||||
|
Statement::VectorWrite(vector_write),
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
s @ Statement::Conditional(_)
|
s @ Statement::Conditional(_)
|
||||||
|
@ -774,7 +774,8 @@ enum Statement<I, P: ast::Operand> {
|
|||||||
PtrAccess(PtrAccess<P>),
|
PtrAccess(PtrAccess<P>),
|
||||||
RepackVector(RepackVectorDetails),
|
RepackVector(RepackVectorDetails),
|
||||||
FunctionPointer(FunctionPointerDetails),
|
FunctionPointer(FunctionPointerDetails),
|
||||||
VectorAccess(VectorAccess),
|
VectorRead(VectorRead),
|
||||||
|
VectorWrite(VectorWrite),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||||
@ -954,33 +955,69 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
|||||||
offset_src,
|
offset_src,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Statement::VectorAccess(VectorAccess {
|
Statement::VectorRead(VectorRead {
|
||||||
scalar_type,
|
scalar_type,
|
||||||
vector_width,
|
vector_width,
|
||||||
dst,
|
scalar_dst: dst,
|
||||||
src: vector_src,
|
vector_src,
|
||||||
member,
|
member,
|
||||||
}) => {
|
}) => {
|
||||||
|
let scalar_t = scalar_type.into();
|
||||||
|
let vector_t = ast::Type::Vector(vector_width, scalar_type);
|
||||||
let dst: SpirvWord = visitor.visit_ident(
|
let dst: SpirvWord = visitor.visit_ident(
|
||||||
dst,
|
dst,
|
||||||
Some((&scalar_type.into(), ast::StateSpace::Reg)),
|
Some((&scalar_t, ast::StateSpace::Reg)),
|
||||||
true,
|
true,
|
||||||
false,
|
false,
|
||||||
)?;
|
)?;
|
||||||
let src = visitor.visit_ident(
|
let src = visitor.visit_ident(
|
||||||
vector_src,
|
vector_src,
|
||||||
Some((
|
Some((&vector_t, ast::StateSpace::Reg)),
|
||||||
&ast::Type::Vector(vector_width, scalar_type),
|
|
||||||
ast::StateSpace::Reg,
|
|
||||||
)),
|
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
)?;
|
)?;
|
||||||
Statement::VectorAccess(VectorAccess {
|
Statement::VectorRead(VectorRead {
|
||||||
|
scalar_type,
|
||||||
|
vector_width,
|
||||||
|
scalar_dst: dst,
|
||||||
|
vector_src: src,
|
||||||
|
member,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Statement::VectorWrite(VectorWrite {
|
||||||
|
scalar_type,
|
||||||
|
vector_width,
|
||||||
|
vector_dst,
|
||||||
|
vector_src,
|
||||||
|
scalar_src,
|
||||||
|
member,
|
||||||
|
}) => {
|
||||||
|
let scalar_t = scalar_type.into();
|
||||||
|
let vector_t = ast::Type::Vector(vector_width, scalar_type);
|
||||||
|
let vector_dst = visitor.visit_ident(
|
||||||
|
vector_dst,
|
||||||
|
Some((&vector_t, ast::StateSpace::Reg)),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
let vector_src = visitor.visit_ident(
|
||||||
|
vector_src,
|
||||||
|
Some((&vector_t, ast::StateSpace::Reg)),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
let scalar_src = visitor.visit_ident(
|
||||||
|
scalar_src,
|
||||||
|
Some((&scalar_t, ast::StateSpace::Reg)),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
Statement::VectorWrite(VectorWrite {
|
||||||
|
vector_dst,
|
||||||
|
vector_src,
|
||||||
|
scalar_src,
|
||||||
scalar_type,
|
scalar_type,
|
||||||
vector_width,
|
vector_width,
|
||||||
dst,
|
|
||||||
src,
|
|
||||||
member,
|
member,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -1538,7 +1575,8 @@ fn compute_denorm_information<'input>(
|
|||||||
Statement::Label(_) => {}
|
Statement::Label(_) => {}
|
||||||
Statement::Variable(_) => {}
|
Statement::Variable(_) => {}
|
||||||
Statement::PtrAccess { .. } => {}
|
Statement::PtrAccess { .. } => {}
|
||||||
Statement::VectorAccess { .. } => {}
|
Statement::VectorRead { .. } => {}
|
||||||
|
Statement::VectorWrite { .. } => {}
|
||||||
Statement::RepackVector(_) => {}
|
Statement::RepackVector(_) => {}
|
||||||
Statement::FunctionPointer(_) => {}
|
Statement::FunctionPointer(_) => {}
|
||||||
}
|
}
|
||||||
@ -2058,11 +2096,20 @@ impl SpecialRegistersMap2 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct VectorAccess {
|
pub struct VectorRead {
|
||||||
scalar_type: ast::ScalarType,
|
scalar_type: ast::ScalarType,
|
||||||
vector_width: u8,
|
vector_width: u8,
|
||||||
dst: SpirvWord,
|
scalar_dst: SpirvWord,
|
||||||
src: SpirvWord,
|
vector_src: SpirvWord,
|
||||||
|
member: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct VectorWrite {
|
||||||
|
scalar_type: ast::ScalarType,
|
||||||
|
vector_width: u8,
|
||||||
|
vector_dst: SpirvWord,
|
||||||
|
vector_src: SpirvWord,
|
||||||
|
scalar_src: SpirvWord,
|
||||||
member: u8,
|
member: u8,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -26,7 +26,8 @@ pub(super) fn run(
|
|||||||
| Statement::Constant(..)
|
| Statement::Constant(..)
|
||||||
| Statement::Label(..)
|
| Statement::Label(..)
|
||||||
| Statement::PtrAccess { .. }
|
| Statement::PtrAccess { .. }
|
||||||
| Statement::VectorAccess { .. }
|
| Statement::VectorRead { .. }
|
||||||
|
| Statement::VectorWrite { .. }
|
||||||
| Statement::RepackVector(..)
|
| Statement::RepackVector(..)
|
||||||
| Statement::FunctionPointer(..) => {}
|
| Statement::FunctionPointer(..) => {}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user