From 6490519885a5086a6662347ec9d77bbfc473d2cc Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 6 Oct 2024 06:44:14 +0200 Subject: [PATCH] Support vector member read/write --- ptx/src/pass/emit_llvm.rs | 36 +++++++++- ptx/src/pass/emit_spirv.rs | 3 +- ptx/src/pass/expand_operands.rs | 32 +++++---- ptx/src/pass/fix_special_registers2.rs | 29 ++++---- ptx/src/pass/insert_implicit_conversions.rs | 11 ++- ptx/src/pass/mod.rs | 79 ++++++++++++++++----- ptx/src/pass/normalize_labels.rs | 3 +- 7 files changed, 147 insertions(+), 46 deletions(-) diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index a2b2638..d6af00d 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -456,7 +456,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?, Statement::RepackVector(_) => todo!(), Statement::FunctionPointer(_) => todo!(), - Statement::VectorAccess(_) => todo!(), + Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?, + Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?, }) } @@ -986,6 +987,39 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { }); Ok(()) } + + fn emit_vector_read(&mut self, vec_acccess: VectorRead) -> Result<(), TranslateError> { + let src = self.resolver.value(vec_acccess.vector_src)?; + let index = unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::B8), + vec_acccess.member as _, + 0, + ) + }; + self.resolver + .with_result(vec_acccess.scalar_dst, |dst| unsafe { + LLVMBuildExtractElement(self.builder, src, index, dst) + }); + Ok(()) + } + + fn emit_vector_write(&mut self, vector_write: VectorWrite) -> Result<(), TranslateError> { + let vector_src = self.resolver.value(vector_write.vector_src)?; + let scalar_src = self.resolver.value(vector_write.scalar_src)?; + let index = unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::B8), + vector_write.member as _, + 0, + ) + }; + self.resolver + .with_result(vector_write.vector_dst, |dst| unsafe { + LLVMBuildInsertElement(self.builder, vector_src, scalar_src, index, dst) + }); + Ok(()) + } } fn get_pointer_type<'ctx>( diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs index 3f37684..d522b12 100644 --- a/ptx/src/pass/emit_spirv.rs +++ b/ptx/src/pass/emit_spirv.rs @@ -1562,7 +1562,8 @@ fn emit_function_body_ops<'input>( builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?; } } - Statement::VectorAccess(vector_access) => todo!(), + Statement::VectorRead(vector_access) => todo!(), + Statement::VectorWrite(vector_write) => todo!(), } } Ok(()) diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs index 3dabf40..e9768e0 100644 --- a/ptx/src/pass/expand_operands.rs +++ b/ptx/src/pass/expand_operands.rs @@ -189,15 +189,12 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { fn vec_member( &mut self, - vector_src: SpirvWord, + vector_ident: SpirvWord, member: u8, _type_space: Option<(&ast::Type, ast::StateSpace)>, is_dst: bool, ) -> Result { - if is_dst { - return Err(error_mismatched_type()); - } - let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? { + let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? { (ast::Type::Vector(vector_width, scalar_t), space) => { (*vector_width, *scalar_t, *space) } @@ -206,13 +203,24 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { let temporary = self .resolver .register_unnamed(Some((scalar_type.into(), space))); - self.result.push(Statement::VectorAccess(VectorAccess { - scalar_type, - vector_width, - dst: temporary, - src: vector_src, - member: member, - })); + if is_dst { + self.post_stmts.push(Statement::VectorWrite(VectorWrite { + scalar_type, + vector_width, + vector_dst: vector_ident, + vector_src: vector_ident, + scalar_src: temporary, + member, + })); + } else { + self.result.push(Statement::VectorRead(VectorRead { + scalar_type, + vector_width, + scalar_dst: temporary, + vector_src: vector_ident, + member, + })); + } Ok(temporary) } diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs index 3553139..8c3b794 100644 --- a/ptx/src/pass/fix_special_registers2.rs +++ b/ptx/src/pass/fix_special_registers2.rs @@ -112,7 +112,7 @@ impl<'a, 'b, 'input> is_dst: bool, _relaxed_type_check: bool, ) -> Result { - self.replace_sreg(args, None, is_dst) + Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args)) } } @@ -122,7 +122,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> { name: SpirvWord, vector_index: Option, is_dst: bool, - ) -> Result { + ) -> Result, TranslateError> { if let Some(sreg) = self.special_registers.get(name) { if is_dst { return Err(error_mismatched_type()); @@ -179,30 +179,33 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> { data, arguments, })); - Ok(fn_result) + Ok(Some(fn_result)) } else { - Ok(name) + Ok(None) } } } -pub fn map_operand( +pub fn map_operand( this: ast::ParsedOperand, - fn_: &mut impl FnMut(T, Option) -> Result, -) -> Result, Err> { + fn_: &mut impl FnMut(T, Option) -> Result, Err>, +) -> Result, Err> { Ok(match this { - ast::ParsedOperand::Reg(ident) => ast::ParsedOperand::Reg(fn_(ident, None)?), + ast::ParsedOperand::Reg(ident) => { + ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident)) + } ast::ParsedOperand::RegOffset(ident, offset) => { - ast::ParsedOperand::RegOffset(fn_(ident, None)?, offset) + ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset) } ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm), - ast::ParsedOperand::VecMember(ident, member) => { - ast::ParsedOperand::Reg(fn_(ident, Some(member))?) - } + ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? { + Some(ident) => ast::ParsedOperand::Reg(ident), + None => ast::ParsedOperand::VecMember(ident, member), + }, ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack( idents .into_iter() - .map(|ident| fn_(ident, None)) + .map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident))) .collect::, _>>()?, ), }) diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs index c04fa09..2a6ea5a 100644 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -45,11 +45,18 @@ pub(super) fn run( Statement::RepackVector(repack), )?; } - Statement::VectorAccess(vector_access) => { + Statement::VectorRead(vector_read) => { insert_implicit_conversions_impl( &mut result, id_def, - Statement::VectorAccess(vector_access), + Statement::VectorRead(vector_read), + )?; + } + Statement::VectorWrite(vector_write) => { + insert_implicit_conversions_impl( + &mut result, + id_def, + Statement::VectorWrite(vector_write), )?; } s @ Statement::Conditional(_) diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index a232eb9..ae6adce 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -774,7 +774,8 @@ enum Statement { PtrAccess(PtrAccess

), RepackVector(RepackVectorDetails), FunctionPointer(FunctionPointerDetails), - VectorAccess(VectorAccess), + VectorRead(VectorRead), + VectorWrite(VectorWrite), } impl> Statement, T> { @@ -954,33 +955,69 @@ impl> Statement, T> { offset_src, }) } - Statement::VectorAccess(VectorAccess { + Statement::VectorRead(VectorRead { scalar_type, vector_width, - dst, - src: vector_src, + scalar_dst: dst, + vector_src, member, }) => { + let scalar_t = scalar_type.into(); + let vector_t = ast::Type::Vector(vector_width, scalar_type); let dst: SpirvWord = visitor.visit_ident( dst, - Some((&scalar_type.into(), ast::StateSpace::Reg)), + Some((&scalar_t, ast::StateSpace::Reg)), true, false, )?; let src = visitor.visit_ident( vector_src, - Some(( - &ast::Type::Vector(vector_width, scalar_type), - ast::StateSpace::Reg, - )), + Some((&vector_t, ast::StateSpace::Reg)), false, false, )?; - Statement::VectorAccess(VectorAccess { + Statement::VectorRead(VectorRead { + scalar_type, + vector_width, + scalar_dst: dst, + vector_src: src, + member, + }) + } + Statement::VectorWrite(VectorWrite { + scalar_type, + vector_width, + vector_dst, + vector_src, + scalar_src, + member, + }) => { + let scalar_t = scalar_type.into(); + let vector_t = ast::Type::Vector(vector_width, scalar_type); + let vector_dst = visitor.visit_ident( + vector_dst, + Some((&vector_t, ast::StateSpace::Reg)), + true, + false, + )?; + let vector_src = visitor.visit_ident( + vector_src, + Some((&vector_t, ast::StateSpace::Reg)), + false, + false, + )?; + let scalar_src = visitor.visit_ident( + scalar_src, + Some((&scalar_t, ast::StateSpace::Reg)), + false, + false, + )?; + Statement::VectorWrite(VectorWrite { + vector_dst, + vector_src, + scalar_src, scalar_type, vector_width, - dst, - src, member, }) } @@ -1538,7 +1575,8 @@ fn compute_denorm_information<'input>( Statement::Label(_) => {} Statement::Variable(_) => {} Statement::PtrAccess { .. } => {} - Statement::VectorAccess { .. } => {} + Statement::VectorRead { .. } => {} + Statement::VectorWrite { .. } => {} Statement::RepackVector(_) => {} Statement::FunctionPointer(_) => {} } @@ -2058,11 +2096,20 @@ impl SpecialRegistersMap2 { } } -pub struct VectorAccess { +pub struct VectorRead { scalar_type: ast::ScalarType, vector_width: u8, - dst: SpirvWord, - src: SpirvWord, + scalar_dst: SpirvWord, + vector_src: SpirvWord, + member: u8, +} + +pub struct VectorWrite { + scalar_type: ast::ScalarType, + vector_width: u8, + vector_dst: SpirvWord, + vector_src: SpirvWord, + scalar_src: SpirvWord, member: u8, } diff --git a/ptx/src/pass/normalize_labels.rs b/ptx/src/pass/normalize_labels.rs index 037e918..13295d8 100644 --- a/ptx/src/pass/normalize_labels.rs +++ b/ptx/src/pass/normalize_labels.rs @@ -26,7 +26,8 @@ pub(super) fn run( | Statement::Constant(..) | Statement::Label(..) | Statement::PtrAccess { .. } - | Statement::VectorAccess { .. } + | Statement::VectorRead { .. } + | Statement::VectorWrite { .. } | Statement::RepackVector(..) | Statement::FunctionPointer(..) => {} }