diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 2e2995f..fd86f15 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -17,6 +17,7 @@ thiserror = "1.0" bit-vec = "0.6" half ="1.6" bitflags = "1.2" +rustc-hash = "2.0.0" [dependencies.lalrpop-util] version = "0.19.12" diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 3dcbf84..409425f 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -24,6 +24,7 @@ mod fix_special_registers; mod insert_implicit_conversions; mod insert_mem_ssa_statements; mod normalize_identifiers; +mod normalize_identifiers2; mod normalize_labels; mod normalize_predicates; @@ -1657,3 +1658,35 @@ fn denorm_count_map_update_impl( } } } + +pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> { + Variable(ast::LinkingDirective, ast::Variable), + Method(Function2<'input, Instruction, Operand>), +} + +pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> { + pub func_decl: Rc>>, + pub globals: Vec>, + pub body: Option>>, + import_as: Option, + tuning: Vec, + linkage: ast::LinkingDirective, +} + +type NormalizedDirective2<'input> = Directive2< + 'input, + ( + Option>, + ast::Instruction>, + ), + ast::ParsedOperand, +>; + +type NormalizedFunction2<'input> = Function2< + 'input, + ( + Option>, + ast::Instruction>, + ), + ast::ParsedOperand, +>; diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs new file mode 100644 index 0000000..925feb7 --- /dev/null +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -0,0 +1,292 @@ +use super::*; +use ptx_parser as ast; +use rustc_hash::FxHashMap; + +pub(crate) fn run<'input>( + fn_defs: &mut GlobalStringIdentResolver<'input>, + directives: Vec>>, +) -> Result>, TranslateError> { + let mut resolver = NameResolver::new(fn_defs); + let result = directives + .into_iter() + .map(|directive| remap_directive(&mut resolver, directive)) + .collect::, _>>()?; + resolver.end_scope(); + Ok(result) +} + +fn remap_directive<'input, 'b>( + resolver: &mut NameResolver<'input, 'b>, + directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>, +) -> Result, TranslateError> { + Ok(match directive { + ast::Directive::Variable(linking, var) => { + NormalizedDirective2::Variable(linking, remap_variable(resolver, var)?) + } + ast::Directive::Method(linking, directive) => { + NormalizedDirective2::Method(remap_method(resolver, linking, directive)?) + } + }) +} + +fn remap_method<'input, 'b>( + resolver: &mut NameResolver<'input, 'b>, + linkage: ast::LinkingDirective, + method: ast::Function<'input, &'input str, ast::Statement>>, +) -> Result, TranslateError> { + let name = match method.func_directive.name { + ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), + ast::MethodName::Func(text) => ast::MethodName::Func( + resolver.add(Cow::Borrowed(method.func_directive.name.text()), None)?, + ), + }; + resolver.start_scope(); + let func_decl = Rc::new(RefCell::new(remap_function_decl( + resolver, + method.func_directive, + name, + )?)); + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + remap_statements(resolver, &mut result, statements)?; + Ok::<_, TranslateError>(result) + }) + .transpose()?; + resolver.end_scope(); + Ok(Function2 { + func_decl, + globals: Vec::new(), + body, + import_as: None, + tuning: method.tuning, + linkage, + }) +} + +fn remap_function_decl<'input, 'b>( + resolver: &mut NameResolver<'input, 'b>, + func_directive: ast::MethodDeclaration<'input, &'input str>, + name: ast::MethodName<'input, SpirvWord>, +) -> Result, TranslateError> { + assert!(func_directive.shared_mem.is_none()); + let return_arguments = func_directive + .return_arguments + .into_iter() + .map(|var| remap_variable(resolver, var)) + .collect::, _>>()?; + let input_arguments = func_directive + .input_arguments + .into_iter() + .map(|var| remap_variable(resolver, var)) + .collect::, _>>()?; + Ok(ast::MethodDeclaration { + return_arguments, + name, + input_arguments, + shared_mem: None, + }) +} + +fn remap_variable<'input, 'b>( + resolver: &mut NameResolver<'input, 'b>, + variable: ast::Variable<&'input str>, +) -> Result, TranslateError> { + Ok(ast::Variable { + name: resolver.add( + Cow::Borrowed(variable.name), + Some((variable.v_type.clone(), variable.state_space)), + )?, + align: variable.align, + v_type: variable.v_type, + state_space: variable.state_space, + array_init: variable.array_init, + }) +} + +fn remap_statements<'input, 'b>( + resolver: &mut NameResolver<'input, 'b>, + result: &mut Vec, + statements: Vec>>, +) -> Result<(), TranslateError> { + for statement in statements.iter() { + match statement { + ast::Statement::Label(label) => { + resolver.add(Cow::Borrowed(*label), None)?; + } + _ => {} + } + } + for statement in statements { + match statement { + ast::Statement::Label(label) => { + result.push(Statement::Label(resolver.get_in_current_scope(label)?)) + } + ast::Statement::Variable(variable) => remap_multivariable(resolver, result, variable)?, + ast::Statement::Instruction(predicate, instruction) => { + result.push(Statement::Instruction(( + predicate + .map(|pred| { + Ok::<_, TranslateError>(ast::PredAt { + not: pred.not, + label: resolver.get(pred.label)?, + }) + }) + .transpose()?, + remap_instruction(resolver, instruction)?, + ))) + } + ast::Statement::Block(block) => { + resolver.start_scope(); + remap_statements(resolver, result, block)?; + resolver.end_scope(); + } + } + } + Ok(()) +} + +fn remap_instruction<'input, 'b>( + resolver: &mut NameResolver<'input, 'b>, + instruction: ast::Instruction>, +) -> Result>, TranslateError> { + ast::visit_map(instruction, &mut |name: &'input str, + _: Option<( + &ast::Type, + ast::StateSpace, + )>, + _, + _| { + resolver.get(&name) + }) +} + +fn remap_multivariable<'input, 'b>( + resolver: &mut NameResolver<'input, 'b>, + result: &mut Vec, + variable: ast::MultiVariable<&'input str>, +) -> Result<(), TranslateError> { + match variable.count { + Some(count) => { + for i in 0..count { + let name = Cow::Owned(format!("{}{}", variable.var.name, i)); + let ident = resolver.add( + name, + Some((variable.var.v_type.clone(), variable.var.state_space)), + )?; + result.push(Statement::Variable(ast::Variable { + align: variable.var.align, + v_type: variable.var.v_type.clone(), + state_space: variable.var.state_space, + name: ident, + array_init: variable.var.array_init.clone(), + })); + } + } + None => { + let name = Cow::Borrowed(variable.var.name); + let ident = resolver.add( + name, + Some((variable.var.v_type.clone(), variable.var.state_space)), + )?; + result.push(Statement::Variable(ast::Variable { + align: variable.var.align, + v_type: variable.var.v_type.clone(), + state_space: variable.var.state_space, + name: ident, + array_init: variable.var.array_init.clone(), + })); + } + } + Ok(()) +} + +struct NameResolver<'input, 'b> { + flat_resolver: &'b mut GlobalStringIdentResolver<'input>, + scopes: Vec>, +} + +impl<'input, 'b> NameResolver<'input, 'b> { + fn new(flat_resolver: &'b mut GlobalStringIdentResolver<'input>) -> Self { + Self { + flat_resolver, + scopes: vec![ScopeStringIdentResolver::new()], + } + } + + fn start_scope(&mut self) { + self.scopes.push(ScopeStringIdentResolver::new()); + } + + fn end_scope(&mut self) { + let scope = self.scopes.pop().unwrap(); + scope.flush(self.flat_resolver); + } + + fn add( + &mut self, + name: Cow<'input, str>, + type_space: Option<(ast::Type, ast::StateSpace)>, + ) -> Result { + let result = self.flat_resolver.current_id; + self.flat_resolver.current_id.0 += 1; + let current_scope = self.scopes.last_mut().unwrap(); + if current_scope + .name_to_ident + .insert(name.clone(), result) + .is_some() + { + return Err(error_unknown_symbol()); + } + current_scope + .ident_map + .insert(result, IdentEntry { name, type_space }); + Ok(result) + } + + fn get(&mut self, name: &str) -> Result { + self.scopes + .iter() + .rev() + .find_map(|resolver| resolver.name_to_ident.get(name).copied()) + .ok_or_else(|| error_unreachable()) + } + + fn get_in_current_scope(&self, label: &'input str) -> Result { + let current_scope = self.scopes.last().unwrap(); + current_scope + .name_to_ident + .get(label) + .copied() + .ok_or_else(|| error_unreachable()) + } +} + +struct ScopeStringIdentResolver<'input> { + ident_map: FxHashMap>, + name_to_ident: FxHashMap, SpirvWord>, +} + +impl<'input> ScopeStringIdentResolver<'input> { + fn new() -> Self { + Self { + ident_map: FxHashMap::default(), + name_to_ident: FxHashMap::default(), + } + } + + fn flush(self, resolver: &mut GlobalStringIdentResolver<'input>) { + resolver.ident_map.extend(self.ident_map); + } +} + +struct GlobalStringIdentResolver<'input> { + pub(crate) current_id: SpirvWord, + pub(crate) ident_map: FxHashMap>, +} + +struct IdentEntry<'input> { + name: Cow<'input, str>, + type_space: Option<(ast::Type, ast::StateSpace)>, +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index cc5a1d0..65c624e 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1049,6 +1049,15 @@ impl<'input, ID> MethodName<'input, ID> { } } +impl<'input> MethodName<'input, &'input str> { + pub fn text(&self) -> &'input str { + match self { + MethodName::Kernel(name) => *name, + MethodName::Func(name) => *name, + } + } +} + bitflags! { pub struct LinkingDirective: u8 { const NONE = 0b000;