Add a simple (and failing) PTX end-to-end test

This commit is contained in:
Andrzej Janik
2020-05-10 22:30:34 +02:00
parent 0c0f0e5a6b
commit d0aa5ba564
9 changed files with 493 additions and 89 deletions

View File

@ -17,3 +17,6 @@ bit-vec = "0.6"
[build-dependencies.lalrpop] [build-dependencies.lalrpop]
version = "0.18.1" version = "0.18.1"
features = ["lexer"] features = ["lexer"]
[dev-dependencies]
ocl = { version = "0.19", features = ["opencl_version_1_1", "opencl_version_1_2", "opencl_version_2_1"] }

View File

@ -189,19 +189,19 @@ pub enum MovOperand<ID> {
pub enum VectorPrefix { pub enum VectorPrefix {
V2, V2,
V4 V4,
} }
pub struct LdData { pub struct LdData {
pub qualifier: LdQualifier, pub qualifier: LdStQualifier,
pub state_space: LdStateSpace, pub state_space: LdStateSpace,
pub caching: LdCacheOperator, pub caching: LdCacheOperator,
pub vector: Option<VectorPrefix>, pub vector: Option<VectorPrefix>,
pub typ: ScalarType pub typ: ScalarType,
} }
#[derive(PartialEq, Eq)] #[derive(PartialEq, Eq)]
pub enum LdQualifier { pub enum LdStQualifier {
Weak, Weak,
Volatile, Volatile,
Relaxed(LdScope), Relaxed(LdScope),
@ -212,7 +212,7 @@ pub enum LdQualifier {
pub enum LdScope { pub enum LdScope {
Cta, Cta,
Gpu, Gpu,
Sys Sys,
} }
#[derive(PartialEq, Eq)] #[derive(PartialEq, Eq)]
@ -225,14 +225,13 @@ pub enum LdStateSpace {
Shared, Shared,
} }
#[derive(PartialEq, Eq)] #[derive(PartialEq, Eq)]
pub enum LdCacheOperator { pub enum LdCacheOperator {
Cached, Cached,
L2Only, L2Only,
Streaming, Streaming,
LastUse, LastUse,
Uncached Uncached,
} }
pub struct MovData {} pub struct MovData {}
@ -248,13 +247,38 @@ pub struct SetpBoolData {}
pub struct NotData {} pub struct NotData {}
pub struct BraData { pub struct BraData {
pub uniform: bool pub uniform: bool,
} }
pub struct CvtData {} pub struct CvtData {}
pub struct ShlData {} pub struct ShlData {}
pub struct StData {} pub struct StData {
pub qualifier: LdStQualifier,
pub state_space: StStateSpace,
pub caching: StCacheOperator,
pub vector: Option<VectorPrefix>,
pub typ: ScalarType,
}
pub struct RetData {} #[derive(PartialEq, Eq)]
pub enum StStateSpace {
Generic,
Global,
Local,
Param,
Shared,
}
#[derive(PartialEq, Eq)]
pub enum StCacheOperator {
Writeback,
L2Only,
Streaming,
Writethrough,
}
pub struct RetData {
pub uniform: bool,
}

View File

@ -4,6 +4,8 @@ extern crate lalrpop_util;
extern crate quick_error; extern crate quick_error;
extern crate bit_vec; extern crate bit_vec;
#[cfg(test)]
extern crate ocl;
extern crate rspirv; extern crate rspirv;
extern crate spirv_headers as spirv; extern crate spirv_headers as spirv;

View File

@ -188,10 +188,10 @@ Instruction: ast::Instruction<&'input str> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
InstLd: ast::Instruction<&'input str> = { InstLd: ast::Instruction<&'input str> = {
"ld" <q:LdQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ID> "," "[" <src:Operand> "]" => { "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ID> "," "[" <src:Operand> "]" => {
ast::Instruction::Ld( ast::Instruction::Ld(
ast::LdData { ast::LdData {
qualifier: q.unwrap_or(ast::LdQualifier::Weak), qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
state_space: ss.unwrap_or(ast::LdStateSpace::Generic), state_space: ss.unwrap_or(ast::LdStateSpace::Generic),
caching: cop.unwrap_or(ast::LdCacheOperator::Cached), caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
vector: v, vector: v,
@ -202,11 +202,11 @@ InstLd: ast::Instruction<&'input str> = {
} }
}; };
LdQualifier: ast::LdQualifier = { LdStQualifier: ast::LdStQualifier = {
".weak" => ast::LdQualifier::Weak, ".weak" => ast::LdStQualifier::Weak,
".volatile" => ast::LdQualifier::Volatile, ".volatile" => ast::LdStQualifier::Volatile,
".relaxed" <s:LdScope> => ast::LdQualifier::Relaxed(s), ".relaxed" <s:LdScope> => ast::LdStQualifier::Relaxed(s),
".acquire" <s:LdScope> => ast::LdQualifier::Acquire(s), ".acquire" <s:LdScope> => ast::LdStQualifier::Acquire(s),
}; };
LdScope: ast::LdScope = { LdScope: ast::LdScope = {
@ -379,29 +379,39 @@ ShlType = {
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
InstSt: ast::Instruction<&'input str> = { InstSt: ast::Instruction<&'input str> = {
"st" LdQualifier? StStateSpace? StCacheOperator? VectorPrefix? MemoryType "[" <dst:ID> "]" "," <src:Operand> => { "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <dst:ID> "]" "," <src:Operand> => {
ast::Instruction::St(ast::StData{}, ast::Arg2{dst:dst, src:src}) ast::Instruction::St(
ast::StData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
state_space: ss.unwrap_or(ast::StStateSpace::Generic),
caching: cop.unwrap_or(ast::StCacheOperator::Writeback),
vector: v,
typ: t
},
ast::Arg2{dst:dst, src:src}
)
} }
}; };
StStateSpace = { StStateSpace: ast::StStateSpace = {
".global", ".global" => ast::StStateSpace::Global,
".local", ".local" => ast::StStateSpace::Local,
".param", ".param" => ast::StStateSpace::Param,
".shared", ".shared" => ast::StStateSpace::Shared,
}; };
StCacheOperator = { StCacheOperator: ast::StCacheOperator = {
".wb", ".wb" => ast::StCacheOperator::Writeback,
".cg", ".cg" => ast::StCacheOperator::L2Only,
".cs", ".cs" => ast::StCacheOperator::Streaming,
".wt", ".wt" => ast::StCacheOperator::Writethrough,
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
InstRet: ast::Instruction<&'input str> = { InstRet: ast::Instruction<&'input str> = {
"ret" ".uni"? => ast::Instruction::Ret(ast::RetData{}) "ret" <u:".uni"?> => ast::Instruction::Ret(ast::RetData { uniform: u.is_some() })
}; };
Operand: ast::Operand<&'input str> = { Operand: ast::Operand<&'input str> = {

View File

@ -1,5 +1,7 @@
use super::ptx; use super::ptx;
mod ops;
fn parse_and_assert(s: &str) { fn parse_and_assert(s: &str) {
let mut errors = Vec::new(); let mut errors = Vec::new();
ptx::ModuleParser::new().parse(&mut errors, s).unwrap(); ptx::ModuleParser::new().parse(&mut errors, s).unwrap();

View File

@ -0,0 +1,20 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry ld_st(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u64 temp, [in_addr];
st.u64 [out_addr], temp;
ret;
}

View File

@ -0,0 +1 @@
// End-to-end test: the ld_st kernel copies one u64 from the input buffer to
// the output buffer, so an input of [1u64] must produce an output of [1u64].
test_ptx!(ld_st, [1u64], [1u64]);

280
ptx/src/test/ops/mod.rs Normal file
View File

@ -0,0 +1,280 @@
use crate::ptx;
use crate::translate;
use ocl::{Buffer, Context, Device, Kernel, OclPrm, Platform, Program, Queue};
use std::error;
use std::ffi::{c_void, CString};
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::mem;
use std::slice;
use std::{ptr, str};
// Generates a #[test] that loads the PTX fixture named `<fn_name>.ptx` from
// this directory, then parses, translates and runs it on a device, asserting
// that the device-side result equals `$output`.
// `$input` is the initial content of the input buffer; `$output` is both the
// expected result and the buffer the comparison is made against.
macro_rules! test_ptx {
    ($fn_name:ident, $input:expr, $output:expr) => {
        #[test]
        fn $fn_name() -> Result<(), Box<dyn std::error::Error>> {
            // Fixture is resolved at compile time, relative to this module.
            let ptx = include_str!(concat!(stringify!($fn_name), ".ptx"));
            let input = $input;
            let mut output = $output;
            crate::test::ops::test_ptx_assert(stringify!($fn_name), ptx, &input, &mut output)
        }
    };
}
mod ld_st;
// cl_device_info value of CL_DEVICE_IL_VERSION (OpenCL 2.1+), queried raw
// because the ocl crate does not expose this device-info enum.
const CL_DEVICE_IL_VERSION: u32 = 0x105B;
// Newtype adapter that turns any Display + Debug value (e.g. ocl error types)
// into a std::error::Error, so it can be returned as Box<dyn Error>.
struct DisplayError<T: Display + Debug> {
    err: T,
}
impl<T: Display + Debug> Display for DisplayError<T> {
    /// Renders the wrapped value using its own `Display` implementation.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.err)
    }
}
impl<T: Display + Debug> Debug for DisplayError<T> {
    /// Renders the wrapped value using its own `Debug` implementation.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}", self.err)
    }
}
// Error's methods all have default impls; Display/Debug above are sufficient.
impl<T: Display + Debug> error::Error for DisplayError<T> {}
// Parses `ptx_text`, translates it to SPIR-V, launches the kernel named
// `name`, and asserts that the device result equals `output`.
// `input` is the host data uploaded to the kernel's input buffer.
fn test_ptx_assert<'a, T: OclPrm + From<u8>>(
    name: &str,
    ptx_text: &'a str,
    input: &[T],
    output: &mut [T],
) -> Result<(), Box<dyn error::Error + 'a>> {
    let mut errors = Vec::new();
    // The lalrpop parser collects recoverable errors instead of failing fast;
    // a clean test run requires there to be none.
    let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
    assert!(errors.len() == 0);
    let spirv = translate::to_spirv(ast)?;
    // ocl errors are not std::error::Error, so wrap them via DisplayError.
    let result = run_spirv(name, &spirv, input, output).map_err(|err| DisplayError { err })?;
    assert_eq!(&output, &&*result);
    Ok(())
}
/// Runs the kernel named `name` on an Intel OpenCL device with unified shared
/// memory (USM) support and returns the content of the output buffer.
///
/// `input` is copied into a freshly allocated device buffer, the output device
/// buffer is zero-initialized, the kernel is launched with the two device
/// pointers as arguments, and the output buffer is copied back into a host
/// `Vec` for the caller to compare.
///
/// NOTE: the translated SPIR-V (`spirv`) is not executed yet — an equivalent
/// hand-written OpenCL C kernel is compiled instead (see the commented-out
/// `Program::with_il` call) until translation produces a loadable module.
/// NOTE: the USM device allocations are never freed (clMemFreeINTEL is not
/// loaded); acceptable for a short-lived test process.
fn run_spirv<T: OclPrm + From<u8>>(
    name: &str,
    spirv: &[u32],
    input: &[T],
    output: &mut [T],
) -> ocl::Result<Vec<T>> {
    let (plat, dev) = get_ocl_platform_device();
    let ctx = Context::builder().platform(plat).devices(dev).build()?;
    let build_options = CString::new("-cl-intel-greater-than-4GB-buffer-required").unwrap();
    // View the SPIR-V words as raw bytes; needed once with_il is re-enabled.
    let byte_il = unsafe {
        slice::from_raw_parts::<u8>(
            spirv.as_ptr() as *const _,
            spirv.len() * mem::size_of::<u32>(),
        )
    };
    let src = CString::new(
        "
        __kernel void ld_st(ulong a, ulong b)
        {
            __global ulong* a_copy = (__global ulong*)a;
            __global ulong* b_copy = (__global ulong*)b;
            *b_copy = *a_copy;
        }",
    )
    .unwrap();
    //let prog = Program::with_il(byte_il, Some(&[dev]), &build_options, &ctx)?;
    let prog = Program::with_source(&ctx, &[src], Some(&[dev]), &build_options)?;
    let queue = Queue::new(&ctx, dev, None)?;
    // The cl_intel_unified_shared_memory entry points are not exposed by the
    // ocl crate, so they are loaded dynamically for this platform.
    let cl_device_mem_alloc_intel = get_cl_device_mem_alloc_intel(&plat)?;
    let cl_enqueue_memcpy_intel = get_cl_enqueue_memcpy_intel(&plat)?;
    let cl_enqueue_memset_intel = get_cl_enqueue_memset_intel(&plat)?;
    let cl_set_kernel_arg_mem_pointer_intel = get_cl_set_kernel_arg_mem_pointer_intel(&plat)?;
    let mut err_code = 0;
    let inp_b = cl_device_mem_alloc_intel(
        ctx.as_ptr(),
        dev.as_raw(),
        ptr::null_mut(),
        input.len() * mem::size_of::<T>(),
        mem::align_of::<T>() as u32,
        &mut err_code,
    );
    assert_eq!(err_code, 0);
    let out_b = cl_device_mem_alloc_intel(
        ctx.as_ptr(),
        dev.as_raw(),
        ptr::null_mut(),
        output.len() * mem::size_of::<T>(),
        mem::align_of::<T>() as u32,
        &mut err_code,
    );
    assert_eq!(err_code, 0);
    // Blocking (second argument == 1) host -> device copy of the input data.
    err_code = cl_enqueue_memcpy_intel(
        queue.as_ptr(),
        1,
        inp_b as *mut _,
        input.as_ptr() as *const _,
        input.len() * mem::size_of::<T>(),
        0,
        ptr::null(),
        ptr::null_mut(),
    );
    assert_eq!(err_code, 0);
    // Zero the output buffer so stale device memory cannot fake a pass.
    // Fixed: the memset is sized by output.len(), not input.len().
    err_code = cl_enqueue_memset_intel(
        queue.as_ptr(),
        out_b as *mut _,
        0,
        output.len() * mem::size_of::<T>(),
        0,
        ptr::null(),
        ptr::null_mut(),
    );
    assert_eq!(err_code, 0);
    let kernel = ocl::core::create_kernel(prog.as_core(), name)?;
    err_code = cl_set_kernel_arg_mem_pointer_intel(kernel.as_ptr(), 0, inp_b);
    assert_eq!(err_code, 0);
    err_code = cl_set_kernel_arg_mem_pointer_intel(kernel.as_ptr(), 1, out_b);
    assert_eq!(err_code, 0);
    unsafe {
        // Single work-item launch; dims == 1, so only the first entry of the
        // global work size array is meaningful.
        ocl::core::enqueue_kernel::<(), ()>(
            queue.as_core(),
            &kernel,
            1,
            None,
            &[1, 0, 0],
            None,
            None,
            None,
        )
    }?;
    let mut result: Vec<T> = vec![0u8.into(); output.len()];
    // Blocking device -> host copy of the kernel's result.
    // Fixed: read back from the output buffer (out_b); the original copied
    // from inp_b, which only passed because the test's input equals its output.
    err_code = cl_enqueue_memcpy_intel(
        queue.as_ptr(),
        1,
        result.as_mut_ptr() as *mut _,
        out_b,
        result.len() * mem::size_of::<T>(),
        0,
        ptr::null(),
        ptr::null_mut(),
    );
    assert_eq!(err_code, 0);
    queue.finish()?;
    Ok(result)
}
// Finds the first (platform, device) pair that supports both the Intel USM
// preview extension and SPIR-V ingestion; panics when none exists.
fn get_ocl_platform_device() -> (Platform, Device) {
    for p in Platform::list() {
        // Only platforms advertising the USM preview extension can service
        // the clDeviceMemAllocINTEL & co. calls made by run_spirv.
        if p.extensions()
            .unwrap()
            .iter()
            .find(|ext| *ext == "cl_intel_unified_shared_memory_preview")
            .is_none()
        {
            continue;
        }
        for d in Device::list_all(p).unwrap() {
            let typ = d.info(ocl::enums::DeviceInfo::Type).unwrap();
            if let ocl::enums::DeviceInfoResult::Type(typ) = typ {
                // NOTE(review): presumably this skips CPU devices; confirm
                // that `typ.cpu()` masks out the CPU bit of the bitfield
                // rather than returning a bool.
                if typ.cpu() == ocl::flags::DeviceType::CPU {
                    continue;
                }
            }
            // CL_DEVICE_IL_VERSION is queried raw (see const above); a value
            // starting with "SPIR-V" means the device accepts SPIR-V IL.
            if let Ok(version) = d.info_raw(CL_DEVICE_IL_VERSION) {
                let name = str::from_utf8(&version).unwrap();
                if name.starts_with("SPIR-V") {
                    return (p, d);
                }
            }
        }
    }
    panic!("No OpenCL device with SPIR-V and USM support found")
}
// Loads the clDeviceMemAllocINTEL extension entry point for platform `p` and
// casts it to a typed function pointer:
// (context, device, properties, size, alignment, out-error) -> device pointer.
fn get_cl_device_mem_alloc_intel(
    p: &Platform,
) -> ocl::core::Result<
    extern "C" fn(
        ocl::core::ffi::cl_context,
        ocl::core::ffi::cl_device_id,
        *const ocl::core::ffi::cl_bitfield,
        ocl::core::ffi::size_t,
        ocl::core::ffi::cl_uint,
        *mut ocl::core::ffi::cl_int,
    ) -> *const c_void,
> {
    let ptr = unsafe {
        ocl::core::get_extension_function_address_for_platform(
            p.as_core(),
            "clDeviceMemAllocINTEL",
            None,
        )
    }?;
    // SAFETY: sound only if the signature above matches the driver's actual
    // clDeviceMemAllocINTEL ABI; a mismatch is undefined behavior.
    Ok(unsafe { std::mem::transmute(ptr) })
}
// Loads the clEnqueueMemcpyINTEL extension entry point for platform `p`:
// (queue, blocking, dst, src, size, num-wait-events, wait-list, out-event).
fn get_cl_enqueue_memcpy_intel(
    p: &Platform,
) -> ocl::core::Result<
    extern "C" fn(
        ocl::core::ffi::cl_command_queue,
        ocl::core::ffi::cl_bool,
        *mut c_void,
        *const c_void,
        ocl::core::ffi::size_t,
        ocl::core::ffi::cl_uint,
        *const ocl::core::ffi::cl_event,
        *mut ocl::core::ffi::cl_event,
    ) -> ocl::core::ffi::cl_int,
> {
    let ptr = unsafe {
        ocl::core::get_extension_function_address_for_platform(
            p.as_core(),
            "clEnqueueMemcpyINTEL",
            None,
        )
    }?;
    // SAFETY: sound only if the signature above matches the driver's actual
    // clEnqueueMemcpyINTEL ABI; a mismatch is undefined behavior.
    Ok(unsafe { std::mem::transmute(ptr) })
}
// Loads the clEnqueueMemsetINTEL extension entry point for platform `p`:
// (queue, dst, fill-value, size, num-wait-events, wait-list, out-event).
fn get_cl_enqueue_memset_intel(
    p: &Platform,
) -> ocl::core::Result<
    extern "C" fn(
        ocl::core::ffi::cl_command_queue,
        *mut c_void,
        ocl::core::ffi::cl_int,
        ocl::core::ffi::size_t,
        ocl::core::ffi::cl_uint,
        *const ocl::core::ffi::cl_event,
        *mut ocl::core::ffi::cl_event,
    ) -> ocl::core::ffi::cl_int,
> {
    let ptr = unsafe {
        ocl::core::get_extension_function_address_for_platform(
            p.as_core(),
            "clEnqueueMemsetINTEL",
            None,
        )
    }?;
    // SAFETY: sound only if the signature above matches the driver's actual
    // clEnqueueMemsetINTEL ABI; a mismatch is undefined behavior.
    Ok(unsafe { std::mem::transmute(ptr) })
}
// Loads the clSetKernelArgMemPointerINTEL extension entry point for platform
// `p`: (kernel, arg-index, usm-pointer) -> error code. Used to bind USM
// device pointers as kernel arguments.
fn get_cl_set_kernel_arg_mem_pointer_intel(
    p: &Platform,
) -> ocl::core::Result<
    extern "C" fn(
        ocl::core::ffi::cl_kernel,
        ocl::core::ffi::cl_uint,
        *const c_void,
    ) -> ocl::core::ffi::cl_int,
> {
    let ptr = unsafe {
        ocl::core::get_extension_function_address_for_platform(
            p.as_core(),
            "clSetKernelArgMemPointerINTEL",
            None,
        )
    }?;
    // SAFETY: sound only if the signature above matches the driver's actual
    // clSetKernelArgMemPointerINTEL ABI; a mismatch is undefined behavior.
    Ok(unsafe { std::mem::transmute(ptr) })
}

View File

@ -5,6 +5,8 @@ use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap, HashSet}; use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt; use std::fmt;
use rspirv::binary::{Assemble, Disassemble};
#[derive(PartialEq, Eq, Hash, Clone, Copy)] #[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvType { enum SpirvType {
Base(ast::ScalarType), Base(ast::ScalarType),
@ -13,7 +15,6 @@ enum SpirvType {
struct TypeWordMap { struct TypeWordMap {
void: spirv::Word, void: spirv::Word,
fn_void: spirv::Word,
complex: HashMap<SpirvType, spirv::Word>, complex: HashMap<SpirvType, spirv::Word>,
} }
@ -22,7 +23,6 @@ impl TypeWordMap {
let void = b.type_void(); let void = b.type_void();
TypeWordMap { TypeWordMap {
void: void, void: void,
fn_void: b.type_function(void, vec![]),
complex: HashMap::<SpirvType, spirv::Word>::new(), complex: HashMap::<SpirvType, spirv::Word>::new(),
} }
} }
@ -30,32 +30,24 @@ impl TypeWordMap {
fn void(&self) -> spirv::Word { fn void(&self) -> spirv::Word {
self.void self.void
} }
fn fn_void(&self) -> spirv::Word {
self.fn_void
}
fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word { fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word {
*self.complex.entry(SpirvType::Base(t)).or_insert_with(|| match t { *self
ast::ScalarType::B8 | ast::ScalarType::U8 => { .complex
b.type_int(8, 0) .entry(SpirvType::Base(t))
} .or_insert_with(|| match t {
ast::ScalarType::B16 | ast::ScalarType::U16 => { ast::ScalarType::B8 | ast::ScalarType::U8 => b.type_int(8, 0),
b.type_int(16, 0) ast::ScalarType::B16 | ast::ScalarType::U16 => b.type_int(16, 0),
} ast::ScalarType::B32 | ast::ScalarType::U32 => b.type_int(32, 0),
ast::ScalarType::B32 | ast::ScalarType::U32 => { ast::ScalarType::B64 | ast::ScalarType::U64 => b.type_int(64, 0),
b.type_int(32, 0) ast::ScalarType::S8 => b.type_int(8, 1),
} ast::ScalarType::S16 => b.type_int(16, 1),
ast::ScalarType::B64 | ast::ScalarType::U64 => { ast::ScalarType::S32 => b.type_int(32, 1),
b.type_int(64, 0) ast::ScalarType::S64 => b.type_int(64, 1),
} ast::ScalarType::F16 => b.type_float(16),
ast::ScalarType::S8 => b.type_int(8, 1), ast::ScalarType::F32 => b.type_float(32),
ast::ScalarType::S16 => b.type_int(16, 1), ast::ScalarType::F64 => b.type_float(64),
ast::ScalarType::S32 => b.type_int(32, 1), })
ast::ScalarType::S64 => b.type_int(64, 1),
ast::ScalarType::F16 => b.type_float(16),
ast::ScalarType::F32 => b.type_float(32),
ast::ScalarType::F64 => b.type_float(64),
})
} }
fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word { fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
@ -63,15 +55,25 @@ impl TypeWordMap {
SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar), SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar),
SpirvType::Pointer(scalar, storage) => { SpirvType::Pointer(scalar, storage) => {
let base = self.get_or_add_scalar(b, scalar); let base = self.get_or_add_scalar(b, scalar);
*self.complex.entry(t).or_insert_with(|| { *self
b.type_pointer(None, storage, base) .complex
}) .entry(t)
.or_insert_with(|| b.type_pointer(None, storage, base))
} }
} }
} }
fn get_or_add_fn<Args: Iterator<Item = SpirvType>>(
&mut self,
b: &mut dr::Builder,
args: Args,
) -> spirv::Word {
let params = args.map(|a| self.get_or_add(b, a)).collect::<Vec<_>>();
b.type_function(self.void(), params)
}
} }
pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> { pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, dr::Error> {
let mut builder = dr::Builder::new(); let mut builder = dr::Builder::new();
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 0); builder.set_version(1, 0);
@ -83,10 +85,12 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
for f in ast.functions { for f in ast.functions {
emit_function(&mut builder, &mut map, f)?; emit_function(&mut builder, &mut map, f)?;
} }
Ok(vec![]) let module = builder.module();
Ok(module.assemble())
} }
fn emit_capabilities(builder: &mut dr::Builder) { fn emit_capabilities(builder: &mut dr::Builder) {
builder.capability(spirv::Capability::GenericPointer);
builder.capability(spirv::Capability::Linkage); builder.capability(spirv::Capability::Linkage);
builder.capability(spirv::Capability::Addresses); builder.capability(spirv::Capability::Addresses);
builder.capability(spirv::Capability::Kernel); builder.capability(spirv::Capability::Kernel);
@ -112,12 +116,12 @@ fn emit_function<'a>(
map: &mut TypeWordMap, map: &mut TypeWordMap,
f: ast::Function<'a>, f: ast::Function<'a>,
) -> Result<spirv::Word, rspirv::dr::Error> { ) -> Result<spirv::Word, rspirv::dr::Error> {
let func_id = builder.begin_function( let func_type = get_function_type(builder, map, &f.args);
map.void(), let func_id =
None, builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, func_type)?;
spirv::FunctionControl::NONE, if f.kernel {
map.fn_void(), builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[]);
)?; }
let mut contant_ids = HashMap::new(); let mut contant_ids = HashMap::new();
collect_arg_ids(&mut contant_ids, &f.args); collect_arg_ids(&mut contant_ids, &f.args);
collect_label_ids(&mut contant_ids, &f.body); collect_label_ids(&mut contant_ids, &f.body);
@ -126,7 +130,7 @@ fn emit_function<'a>(
let rpostorder = to_reverse_postorder(&bbs); let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder); let doms = immediate_dominators(&bbs, &rpostorder);
let dom_fronts = dominance_frontiers(&bbs, &doms); let dom_fronts = dominance_frontiers(&bbs, &doms);
let phis = ssa_legalize( let (mut phis, unique_ids) = ssa_legalize(
&mut normalized_ids, &mut normalized_ids,
contant_ids.len() as u32, contant_ids.len() as u32,
unique_ids, unique_ids,
@ -138,11 +142,17 @@ fn emit_function<'a>(
emit_function_args(builder, id_offset, map, &f.args); emit_function_args(builder, id_offset, map, &f.args);
emit_function_body_ops(builder, id_offset, map, &normalized_ids, &bbs)?; emit_function_body_ops(builder, id_offset, map, &normalized_ids, &bbs)?;
builder.end_function()?; builder.end_function()?;
builder.ret()?;
builder.end_function()?;
Ok(func_id) Ok(func_id)
} }
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
args: &[ast::Argument],
) -> spirv::Word {
map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::Base(arg.a_type)))
}
fn emit_function_args( fn emit_function_args(
builder: &mut dr::Builder, builder: &mut dr::Builder,
id_offset: spirv::Word, id_offset: spirv::Word,
@ -151,7 +161,7 @@ fn emit_function_args(
) { ) {
let mut id = id_offset; let mut id = id_offset;
for arg in args { for arg in args {
let result_type = map.get_or_add(builder, SpirvType::Base(arg.a_type)); let result_type = map.get_or_add_scalar(builder, arg.a_type);
let inst = dr::Instruction::new( let inst = dr::Instruction::new(
spirv::Op::FunctionParameter, spirv::Op::FunctionParameter,
Some(result_type), Some(result_type),
@ -195,6 +205,8 @@ fn emit_function_body_ops(
func: &[Statement], func: &[Statement],
cfg: &[BasicBlock], cfg: &[BasicBlock],
) -> Result<(), dr::Error> { ) -> Result<(), dr::Error> {
// TODO: entry basic block can't be target of jumps,
// we need to emit additional BB for this purpose
for bb_idx in 0..cfg.len() { for bb_idx in 0..cfg.len() {
let body = get_bb_body(func, cfg, BBIndex(bb_idx)); let body = get_bb_body(func, cfg, BBIndex(bb_idx));
if body.len() == 0 { if body.len() == 0 {
@ -215,24 +227,63 @@ fn emit_function_body_ops(
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?; builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
} }
Statement::Instruction(inst) => match inst { Statement::Instruction(inst) => match inst {
// Sadly, SPIR-V does not support marking jumps as guaranteed-converged // SPIR-V does not support marking jumps as guaranteed-converged
ast::Instruction::Bra(_, arg) => { ast::Instruction::Bra(_, arg) => {
builder.branch(arg.src)?; builder.branch(arg.src + id_offset)?;
} }
ast::Instruction::Ld(data, arg) => { ast::Instruction::Ld(data, arg) => {
if data.qualifier != ast::LdQualifier::Weak || data.vector.is_some() { if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
todo!() todo!()
} }
let storage_class = match data.state_space { let src = match arg.src {
ast::LdStateSpace::Generic => spirv::StorageClass::Generic, ast::Operand::Reg(id) => id + id_offset,
ast::LdStateSpace::Param => spirv::StorageClass::CrossWorkgroup,
_ => todo!(), _ => todo!(),
}; };
let result_type = map.get_or_add(builder, SpirvType::Base(data.typ)); let result_type = map.get_or_add_scalar(builder, data.typ);
let pointer_type = match data.state_space {
map.get_or_add(builder, SpirvType::Pointer(data.typ, storage_class)); ast::LdStateSpace::Generic => {
builder.load(result_type, None, pointer_type, None, [])?; // TODO: make the cast optional
let ptr_result_type = map.get_or_add(
builder,
SpirvType::Pointer(data.typ, spirv::StorageClass::CrossWorkgroup),
);
let bitcast = builder.convert_u_to_ptr(ptr_result_type, None, src - 5)?;
builder.load(
result_type,
Some(arg.dst + id_offset),
bitcast,
None,
[],
)?;
}
ast::LdStateSpace::Param => {
//builder.copy_object(result_type, Some(arg.dst + id_offset), src)?;
}
_ => todo!(),
}
} }
ast::Instruction::St(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak
|| data.vector.is_some()
|| data.state_space != ast::StStateSpace::Generic
{
todo!()
}
let src = match arg.src {
ast::Operand::Reg(id) => id + id_offset,
_ => todo!(),
};
// TODO make cast optional
let ptr_result_type = map.get_or_add(
builder,
SpirvType::Pointer(data.typ, spirv::StorageClass::CrossWorkgroup),
);
let bitcast =
builder.convert_u_to_ptr(ptr_result_type, None, arg.dst + id_offset - 5)?;
builder.store(bitcast, src, None, &[])?;
}
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
_ => todo!(), _ => todo!(),
}, },
} }
@ -279,7 +330,7 @@ fn ssa_legalize(
bbs: &[BasicBlock], bbs: &[BasicBlock],
doms: &[BBIndex], doms: &[BBIndex],
dom_fronts: &[HashSet<BBIndex>], dom_fronts: &[HashSet<BBIndex>],
) -> Vec<Vec<PhiDef>> { ) -> (Vec<Vec<PhiDef>>, spirv::Word) {
let phis = gather_phi_sets(&func, constant_ids, unique_ids, &bbs, dom_fronts); let phis = gather_phi_sets(&func, constant_ids, unique_ids, &bbs, dom_fronts);
apply_ssa_renaming(func, &bbs, doms, constant_ids, unique_ids, &phis) apply_ssa_renaming(func, &bbs, doms, constant_ids, unique_ids, &phis)
} }
@ -301,7 +352,7 @@ fn apply_ssa_renaming(
constant_ids: spirv::Word, constant_ids: spirv::Word,
all_ids: spirv::Word, all_ids: spirv::Word,
old_phi: &[HashSet<spirv::Word>], old_phi: &[HashSet<spirv::Word>],
) -> Vec<Vec<PhiDef>> { ) -> (Vec<Vec<PhiDef>>, spirv::Word) {
let mut dom_tree = vec![Vec::new(); bbs.len()]; let mut dom_tree = vec![Vec::new(); bbs.len()];
for (bb, idom) in doms.iter().enumerate().skip(1) { for (bb, idom) in doms.iter().enumerate().skip(1) {
dom_tree[idom.0].push(BBIndex(bb)); dom_tree[idom.0].push(BBIndex(bb));
@ -345,7 +396,7 @@ fn apply_ssa_renaming(
break; break;
} }
} }
new_phi let phi = new_phi
.into_iter() .into_iter()
.map(|map| { .map(|map| {
map.into_iter() map.into_iter()
@ -355,7 +406,8 @@ fn apply_ssa_renaming(
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>();
(phi, ssa_state.next_id())
} }
// before ssa-renaming every phi is x <- phi(x,x,x,x) // before ssa-renaming every phi is x <- phi(x,x,x,x)
@ -479,6 +531,10 @@ impl<'a> SSARewriteState {
self.stack[(x - self.constant_ids) as usize].pop(); self.stack[(x - self.constant_ids) as usize].pop();
} }
} }
fn next_id(&self) -> spirv::Word {
self.next
}
} }
// "Engineering a Compiler" - Figure 9.9 // "Engineering a Compiler" - Figure 9.9
@ -895,7 +951,10 @@ impl<T> ast::Instruction<T> {
ast::Instruction::Not(_, a) => a.visit_id(f), ast::Instruction::Not(_, a) => a.visit_id(f),
ast::Instruction::Cvt(_, a) => a.visit_id(f), ast::Instruction::Cvt(_, a) => a.visit_id(f),
ast::Instruction::Shl(_, a) => a.visit_id(f), ast::Instruction::Shl(_, a) => a.visit_id(f),
ast::Instruction::St(_, a) => a.visit_id(f), ast::Instruction::St(_, a) => {
f(false, &a.dst);
a.src.visit_id(f);
}
ast::Instruction::Bra(_, a) => a.visit_id(f), ast::Instruction::Bra(_, a) => a.visit_id(f),
ast::Instruction::Ret(_) => (), ast::Instruction::Ret(_) => (),
} }
@ -912,7 +971,10 @@ impl<T> ast::Instruction<T> {
ast::Instruction::Not(_, a) => a.visit_id_mut(f), ast::Instruction::Not(_, a) => a.visit_id_mut(f),
ast::Instruction::Cvt(_, a) => a.visit_id_mut(f), ast::Instruction::Cvt(_, a) => a.visit_id_mut(f),
ast::Instruction::Shl(_, a) => a.visit_id_mut(f), ast::Instruction::Shl(_, a) => a.visit_id_mut(f),
ast::Instruction::St(_, a) => a.visit_id_mut(f), ast::Instruction::St(_, a) => {
f(false, &mut a.dst);
a.src.visit_id_mut(f);
}
ast::Instruction::Bra(_, a) => a.visit_id_mut(f), ast::Instruction::Bra(_, a) => a.visit_id_mut(f),
ast::Instruction::Ret(_) => (), ast::Instruction::Ret(_) => (),
} }
@ -965,7 +1027,7 @@ impl<T: Copy> ast::Instruction<T> {
ast::Instruction::Not(_, a) => a.for_dst_id(f), ast::Instruction::Not(_, a) => a.for_dst_id(f),
ast::Instruction::Cvt(_, a) => a.for_dst_id(f), ast::Instruction::Cvt(_, a) => a.for_dst_id(f),
ast::Instruction::Shl(_, a) => a.for_dst_id(f), ast::Instruction::Shl(_, a) => a.for_dst_id(f),
ast::Instruction::St(_, a) => a.for_dst_id(f), ast::Instruction::St(_, _) => (),
ast::Instruction::Bra(_, _) => (), ast::Instruction::Bra(_, _) => (),
ast::Instruction::Ret(_) => (), ast::Instruction::Ret(_) => (),
} }
@ -1736,7 +1798,7 @@ mod tests {
let rpostorder = to_reverse_postorder(&bbs); let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder); let doms = immediate_dominators(&bbs, &rpostorder);
let dom_fronts = dominance_frontiers(&bbs, &doms); let dom_fronts = dominance_frontiers(&bbs, &doms);
let mut ssa_phis = ssa_legalize( let (mut ssa_phis, _) = ssa_legalize(
&mut func, &mut func,
constant_ids.len() as u32, constant_ids.len() as u32,
unique_ids, unique_ids,