use crate::ast; use rspirv::dr; use std::collections::HashMap; #[derive(PartialEq, Eq, Hash, Clone, Copy)] enum SpirvType { Base(ast::ScalarType), } struct TypeWordMap { void: spirv::Word, fn_void: spirv::Word, complex: HashMap, } impl TypeWordMap { fn new(b: &mut dr::Builder) -> TypeWordMap { let void = b.type_void(); TypeWordMap { void: void, fn_void: b.type_function(void, vec![]), complex: HashMap::::new(), } } fn void(&self) -> spirv::Word { self.void } fn fn_void(&self) -> spirv::Word { self.fn_void } fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word { *self.complex.entry(t).or_insert_with(|| match t { SpirvType::Base(ast::ScalarType::B8) | SpirvType::Base(ast::ScalarType::U8) => { b.type_int(8, 0) } SpirvType::Base(ast::ScalarType::B16) | SpirvType::Base(ast::ScalarType::U16) => { b.type_int(16, 0) } SpirvType::Base(ast::ScalarType::B32) | SpirvType::Base(ast::ScalarType::U32) => { b.type_int(32, 0) } SpirvType::Base(ast::ScalarType::B64) | SpirvType::Base(ast::ScalarType::U64) => { b.type_int(64, 0) } SpirvType::Base(ast::ScalarType::S8) => b.type_int(8, 1), SpirvType::Base(ast::ScalarType::S16) => b.type_int(16, 1), SpirvType::Base(ast::ScalarType::S32) => b.type_int(32, 1), SpirvType::Base(ast::ScalarType::S64) => b.type_int(64, 1), SpirvType::Base(ast::ScalarType::F16) => b.type_float(16), SpirvType::Base(ast::ScalarType::F32) => b.type_float(32), SpirvType::Base(ast::ScalarType::F64) => b.type_float(64), }) } } struct IdWordMap<'a>(HashMap<&'a str, spirv::Word>); impl<'a> IdWordMap<'a> { fn new() -> Self { IdWordMap(HashMap::new()) } } impl<'a> IdWordMap<'a> { fn get_or_add(&mut self, b: &mut dr::Builder, id: &'a str) -> spirv::Word { *self.0.entry(id).or_insert_with(|| b.id()) } } pub fn to_spirv(ast: ast::Module) -> Result, rspirv::dr::Error> { let mut builder = dr::Builder::new(); let mut ids = IdWordMap::new(); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module builder.set_version(1, 0); emit_capabilities(&mut builder); emit_extensions(&mut builder); emit_extended_instruction_sets(&mut builder); emit_memory_model(&mut builder); let mut map = TypeWordMap::new(&mut builder); for f in ast.functions { emit_function(&mut builder, &mut map, &mut ids, &f)?; } Ok(vec![]) } fn emit_capabilities(builder: &mut dr::Builder) { builder.capability(spirv::Capability::Linkage); builder.capability(spirv::Capability::Addresses); builder.capability(spirv::Capability::Kernel); builder.capability(spirv::Capability::Int64); builder.capability(spirv::Capability::Int8); } fn emit_extensions(_: &mut dr::Builder) {} fn emit_extended_instruction_sets(builder: &mut dr::Builder) { builder.ext_inst_import("OpenCL.std"); } fn emit_memory_model(builder: &mut dr::Builder) { builder.memory_model( spirv::AddressingModel::Physical64, spirv::MemoryModel::OpenCL, ); } fn emit_function<'a>( builder: &mut dr::Builder, map: &mut TypeWordMap, ids: &mut IdWordMap<'a>, f: &ast::Function<'a>, ) -> Result<(), rspirv::dr::Error> { let func_id = builder.begin_function( map.void(), None, spirv::FunctionControl::NONE, map.fn_void(), )?; for arg in f.args.iter() { let arg_type = map.get_or_add(builder, SpirvType::Base(arg.a_type)); builder.function_parameter(arg_type)?; } for s in f.body.iter() { match s { ast::Statement::Label(name) => { let id = ids.get_or_add(builder, name); builder.begin_block(Some(id))?; } ast::Statement::Variable(var) => panic!(), ast::Statement::Instruction(_, _) => panic!(), } } builder.ret()?; builder.end_function()?; Ok(()) } // TODO: support scopes fn normalize_identifiers<'a>(func: Vec>) -> Vec { let mut result = Vec::with_capacity(func.len()); let mut id: u32 = 0; let mut known_ids = HashMap::new(); let mut get_or_add = |key| { *known_ids.entry(key).or_insert_with(|| { id += 1; id }) }; for s in func { if let Some(s) = Statement::from_ast(s, &mut get_or_add) { result.push(s); } } result } fn ssa_legalize(func: Vec) -> Vec { vec![] } enum Statement { Label(u32), Instruction(Option>, ast::Instruction), Phi(Vec), } impl Statement { fn from_ast<'a, F: FnMut(&'a str) -> u32>(s: ast::Statement<&'a str>, f: &mut F) -> Option { match s { ast::Statement::Label(name) => Some(Statement::Label(f(name))), ast::Statement::Instruction(p, i) => Some(Statement::Instruction( p.map(|p| p.map_id(f)), i.map_id(f), )), ast::Statement::Variable(_) => None, } } } impl ast::PredAt { fn map_id U>(self, f: &mut F) -> ast::PredAt { ast::PredAt { not: self.not, label: f(self.label), } } } impl ast::Instruction { fn map_id U>(self, f: &mut F) -> ast::Instruction { match self { ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)), ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)), ast::Instruction::Mul(d, a) => ast::Instruction::Mul(d, a.map_id(f)), ast::Instruction::Add(d, a) => ast::Instruction::Add(d, a.map_id(f)), ast::Instruction::Setp(d, a) => ast::Instruction::Setp(d, a.map_id(f)), ast::Instruction::SetpBool(d, a) => ast::Instruction::SetpBool(d, a.map_id(f)), ast::Instruction::Not(d, a) => ast::Instruction::Not(d, a.map_id(f)), ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map_id(f)), ast::Instruction::Cvt(d, a) => ast::Instruction::Cvt(d, a.map_id(f)), ast::Instruction::Shl(d, a) => ast::Instruction::Shl(d, a.map_id(f)), ast::Instruction::St(d, a) => ast::Instruction::St(d, a.map_id(f)), ast::Instruction::At(d, a) => ast::Instruction::At(d, a.map_id(f)), ast::Instruction::Ret(d) => ast::Instruction::Ret(d), } } } impl ast::Arg1 { fn map_id U>(self, f: &mut F) -> ast::Arg1 { ast::Arg1 { dst: f(self.dst) } } } impl ast::Arg2 { fn map_id U>(self, f: &mut F) -> ast::Arg2 { ast::Arg2 { dst: f(self.dst), src: self.src.map_id(f), } } } impl ast::Arg2Mov { fn map_id U>(self, f: &mut F) -> ast::Arg2Mov { ast::Arg2Mov { dst: f(self.dst), src: self.src.map_id(f), } } } impl ast::Arg3 { fn map_id U>(self, f: &mut F) -> ast::Arg3 { ast::Arg3 { dst: f(self.dst), src1: self.src1.map_id(f), src2: self.src2.map_id(f), } } } impl ast::Arg4 { fn map_id U>(self, f: &mut F) -> ast::Arg4 { ast::Arg4 { dst1: f(self.dst1), dst2: self.dst2.map(|i| f(i)), src1: self.src1.map_id(f), src2: self.src2.map_id(f), } } } impl ast::Arg5 { fn map_id U>(self, f: &mut F) -> ast::Arg5 { ast::Arg5 { dst1: f(self.dst1), dst2: self.dst2.map(|i| f(i)), src1: self.src1.map_id(f), src2: self.src2.map_id(f), src3: self.src3.map_id(f), } } } impl ast::Operand { fn map_id U>(self, f: &mut F) -> ast::Operand { match self { ast::Operand::Reg(i) => ast::Operand::Reg(f(i)), ast::Operand::RegOffset(i, o) => ast::Operand::RegOffset(f(i), o), ast::Operand::Imm(v) => ast::Operand::Imm(v), } } } impl ast::MovOperand { fn map_id U>(self, f: &mut F) -> ast::MovOperand { match self { ast::MovOperand::Op(o) => ast::MovOperand::Op(o.map_id(f)), ast::MovOperand::Vec(s1, s2) => ast::MovOperand::Vec(s1, s2) } } }