Refactor normalize_identifiers

This commit is contained in:
Andrzej Janik
2024-09-16 16:42:34 +02:00
parent 2cd7910d46
commit 3b5efbf88b
4 changed files with 335 additions and 0 deletions

View File

@ -17,6 +17,7 @@ thiserror = "1.0"
bit-vec = "0.6"
half ="1.6"
bitflags = "1.2"
rustc-hash = "2.0.0"
[dependencies.lalrpop-util]
version = "0.19.12"

View File

@ -24,6 +24,7 @@ mod fix_special_registers;
mod insert_implicit_conversions;
mod insert_mem_ssa_statements;
mod normalize_identifiers;
mod normalize_identifiers2;
mod normalize_labels;
mod normalize_predicates;
@ -1657,3 +1658,35 @@ fn denorm_count_map_update_impl<T: Eq + Hash>(
}
}
}
pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> {
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
Method(Function2<'input, Instruction, Operand>),
}
pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> {
pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
pub globals: Vec<ast::Variable<SpirvWord>>,
pub body: Option<Vec<Statement<Instruction, Operand>>>,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
linkage: ast::LinkingDirective,
}
type NormalizedDirective2<'input> = Directive2<
'input,
(
Option<ast::PredAt<SpirvWord>>,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
),
ast::ParsedOperand<SpirvWord>,
>;
type NormalizedFunction2<'input> = Function2<
'input,
(
Option<ast::PredAt<SpirvWord>>,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
),
ast::ParsedOperand<SpirvWord>,
>;

View File

@ -0,0 +1,292 @@
use super::*;
use ptx_parser as ast;
use rustc_hash::FxHashMap;
pub(crate) fn run<'input>(
fn_defs: &mut GlobalStringIdentResolver<'input>,
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
let mut resolver = NameResolver::new(fn_defs);
let result = directives
.into_iter()
.map(|directive| remap_directive(&mut resolver, directive))
.collect::<Result<Vec<_>, _>>()?;
resolver.end_scope();
Ok(result)
}
fn remap_directive<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<NormalizedDirective2<'input>, TranslateError> {
Ok(match directive {
ast::Directive::Variable(linking, var) => {
NormalizedDirective2::Variable(linking, remap_variable(resolver, var)?)
}
ast::Directive::Method(linking, directive) => {
NormalizedDirective2::Method(remap_method(resolver, linking, directive)?)
}
})
}
fn remap_method<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
linkage: ast::LinkingDirective,
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<NormalizedFunction2<'input>, TranslateError> {
let name = match method.func_directive.name {
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
ast::MethodName::Func(text) => ast::MethodName::Func(
resolver.add(Cow::Borrowed(method.func_directive.name.text()), None)?,
),
};
resolver.start_scope();
let func_decl = Rc::new(RefCell::new(remap_function_decl(
resolver,
method.func_directive,
name,
)?));
let body = method
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
remap_statements(resolver, &mut result, statements)?;
Ok::<_, TranslateError>(result)
})
.transpose()?;
resolver.end_scope();
Ok(Function2 {
func_decl,
globals: Vec::new(),
body,
import_as: None,
tuning: method.tuning,
linkage,
})
}
fn remap_function_decl<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
func_directive: ast::MethodDeclaration<'input, &'input str>,
name: ast::MethodName<'input, SpirvWord>,
) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
assert!(func_directive.shared_mem.is_none());
let return_arguments = func_directive
.return_arguments
.into_iter()
.map(|var| remap_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?;
let input_arguments = func_directive
.input_arguments
.into_iter()
.map(|var| remap_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?;
Ok(ast::MethodDeclaration {
return_arguments,
name,
input_arguments,
shared_mem: None,
})
}
fn remap_variable<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
variable: ast::Variable<&'input str>,
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
Ok(ast::Variable {
name: resolver.add(
Cow::Borrowed(variable.name),
Some((variable.v_type.clone(), variable.state_space)),
)?,
align: variable.align,
v_type: variable.v_type,
state_space: variable.state_space,
array_init: variable.array_init,
})
}
fn remap_statements<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<(), TranslateError> {
for statement in statements.iter() {
match statement {
ast::Statement::Label(label) => {
resolver.add(Cow::Borrowed(*label), None)?;
}
_ => {}
}
}
for statement in statements {
match statement {
ast::Statement::Label(label) => {
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
}
ast::Statement::Variable(variable) => remap_multivariable(resolver, result, variable)?,
ast::Statement::Instruction(predicate, instruction) => {
result.push(Statement::Instruction((
predicate
.map(|pred| {
Ok::<_, TranslateError>(ast::PredAt {
not: pred.not,
label: resolver.get(pred.label)?,
})
})
.transpose()?,
remap_instruction(resolver, instruction)?,
)))
}
ast::Statement::Block(block) => {
resolver.start_scope();
remap_statements(resolver, result, block)?;
resolver.end_scope();
}
}
}
Ok(())
}
fn remap_instruction<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
ast::visit_map(instruction, &mut |name: &'input str,
_: Option<(
&ast::Type,
ast::StateSpace,
)>,
_,
_| {
resolver.get(&name)
})
}
fn remap_multivariable<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
variable: ast::MultiVariable<&'input str>,
) -> Result<(), TranslateError> {
match variable.count {
Some(count) => {
for i in 0..count {
let name = Cow::Owned(format!("{}{}", variable.var.name, i));
let ident = resolver.add(
name,
Some((variable.var.v_type.clone(), variable.var.state_space)),
)?;
result.push(Statement::Variable(ast::Variable {
align: variable.var.align,
v_type: variable.var.v_type.clone(),
state_space: variable.var.state_space,
name: ident,
array_init: variable.var.array_init.clone(),
}));
}
}
None => {
let name = Cow::Borrowed(variable.var.name);
let ident = resolver.add(
name,
Some((variable.var.v_type.clone(), variable.var.state_space)),
)?;
result.push(Statement::Variable(ast::Variable {
align: variable.var.align,
v_type: variable.var.v_type.clone(),
state_space: variable.var.state_space,
name: ident,
array_init: variable.var.array_init.clone(),
}));
}
}
Ok(())
}
struct NameResolver<'input, 'b> {
flat_resolver: &'b mut GlobalStringIdentResolver<'input>,
scopes: Vec<ScopeStringIdentResolver<'input>>,
}
impl<'input, 'b> NameResolver<'input, 'b> {
fn new(flat_resolver: &'b mut GlobalStringIdentResolver<'input>) -> Self {
Self {
flat_resolver,
scopes: vec![ScopeStringIdentResolver::new()],
}
}
fn start_scope(&mut self) {
self.scopes.push(ScopeStringIdentResolver::new());
}
fn end_scope(&mut self) {
let scope = self.scopes.pop().unwrap();
scope.flush(self.flat_resolver);
}
fn add(
&mut self,
name: Cow<'input, str>,
type_space: Option<(ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
let result = self.flat_resolver.current_id;
self.flat_resolver.current_id.0 += 1;
let current_scope = self.scopes.last_mut().unwrap();
if current_scope
.name_to_ident
.insert(name.clone(), result)
.is_some()
{
return Err(error_unknown_symbol());
}
current_scope
.ident_map
.insert(result, IdentEntry { name, type_space });
Ok(result)
}
fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
self.scopes
.iter()
.rev()
.find_map(|resolver| resolver.name_to_ident.get(name).copied())
.ok_or_else(|| error_unreachable())
}
fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
let current_scope = self.scopes.last().unwrap();
current_scope
.name_to_ident
.get(label)
.copied()
.ok_or_else(|| error_unreachable())
}
}
struct ScopeStringIdentResolver<'input> {
ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
}
impl<'input> ScopeStringIdentResolver<'input> {
fn new() -> Self {
Self {
ident_map: FxHashMap::default(),
name_to_ident: FxHashMap::default(),
}
}
fn flush(self, resolver: &mut GlobalStringIdentResolver<'input>) {
resolver.ident_map.extend(self.ident_map);
}
}
struct GlobalStringIdentResolver<'input> {
pub(crate) current_id: SpirvWord,
pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
}
struct IdentEntry<'input> {
name: Cow<'input, str>,
type_space: Option<(ast::Type, ast::StateSpace)>,
}

View File

@ -1049,6 +1049,15 @@ impl<'input, ID> MethodName<'input, ID> {
}
}
impl<'input> MethodName<'input, &'input str> {
pub fn text(&self) -> &'input str {
match self {
MethodName::Kernel(name) => *name,
MethodName::Func(name) => *name,
}
}
}
bitflags! {
pub struct LinkingDirective: u8 {
const NONE = 0b000;