mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:19:20 +03:00
Refactor compilation passes (#270)
The overarching goal is to refactor all passes so that they are module-scoped rather than function-scoped. Additionally, this commit improves the most egregiously buggy/unfit passes (so the code is ready for the next major features: linking, ftz handling) and continues adding more code to the LLVM backend.
This commit is contained in:
60
.github/workflows/rust.yml
vendored
60
.github/workflows/rust.yml
vendored
@ -1,60 +0,0 @@
|
|||||||
name: Rust
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ master ]
|
|
||||||
pull_request:
|
|
||||||
branches: [ master ]
|
|
||||||
|
|
||||||
env:
|
|
||||||
CARGO_TERM_COLOR: always
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build_lin:
|
|
||||||
name: Build and publish (Linux)
|
|
||||||
runs-on: ubuntu-20.04
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v2
|
|
||||||
with:
|
|
||||||
submodules: true
|
|
||||||
- name: Install GPU drivers
|
|
||||||
run: |
|
|
||||||
sudo apt-get install -y gpg-agent wget
|
|
||||||
wget -qO - https://repositories.intel.com/graphics/intel-graphics.key | sudo apt-key add -
|
|
||||||
sudo apt-add-repository 'deb [arch=amd64] https://repositories.intel.com/graphics/ubuntu focal main'
|
|
||||||
sudo apt-get update
|
|
||||||
sudo apt-get install intel-opencl-icd intel-level-zero-gpu level-zero intel-media-va-driver-non-free libmfx1 libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev ocl-icd-opencl-dev
|
|
||||||
- name: Build
|
|
||||||
run: cargo build --workspace --verbose --release
|
|
||||||
- name: Rename to libcuda.so
|
|
||||||
run: |
|
|
||||||
mv target/release/libnvcuda.so target/release/libcuda.so
|
|
||||||
ln -s libcuda.so target/release/libcuda.so.1
|
|
||||||
- uses: actions/upload-artifact@v2
|
|
||||||
with:
|
|
||||||
name: Linux
|
|
||||||
path: |
|
|
||||||
target/release/libcuda.so
|
|
||||||
target/release/libcuda.so.1
|
|
||||||
target/release/libnvml.so
|
|
||||||
build_win:
|
|
||||||
name: Build and publish (Windows)
|
|
||||||
runs-on: windows-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v2
|
|
||||||
with:
|
|
||||||
submodules: true
|
|
||||||
- name: Build
|
|
||||||
run: cargo build --workspace --verbose --release
|
|
||||||
- uses: actions/upload-artifact@v2
|
|
||||||
with:
|
|
||||||
name: Windows
|
|
||||||
path: |
|
|
||||||
target/release/nvcuda.dll
|
|
||||||
target/release/nvml.dll
|
|
||||||
target/release/zluda_redirect.dll
|
|
||||||
target/release/zluda_with.exe
|
|
||||||
target/release/zluda_dump.dll
|
|
||||||
# TODO(take-cheeze): Support testing
|
|
||||||
# - name: Run tests
|
|
||||||
# run: cargo test --verbose
|
|
@ -17,6 +17,9 @@ thiserror = "1.0"
|
|||||||
bit-vec = "0.6"
|
bit-vec = "0.6"
|
||||||
half ="1.6"
|
half ="1.6"
|
||||||
bitflags = "1.2"
|
bitflags = "1.2"
|
||||||
|
rustc-hash = "2.0.0"
|
||||||
|
strum = "0.26"
|
||||||
|
strum_macros = "0.26"
|
||||||
|
|
||||||
[dependencies.lalrpop-util]
|
[dependencies.lalrpop-util]
|
||||||
version = "0.19.12"
|
version = "0.19.12"
|
||||||
|
@ -489,7 +489,7 @@ fn convert_to_stateful_memory_access_postprocess(
|
|||||||
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?;
|
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?;
|
||||||
let converting_id = id_defs
|
let converting_id = id_defs
|
||||||
.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
|
.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
|
||||||
let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) {
|
let kind = if new_operand_space == ast::StateSpace::Reg {
|
||||||
ConversionKind::Default
|
ConversionKind::Default
|
||||||
} else {
|
} else {
|
||||||
ConversionKind::PtrToPtr
|
ConversionKind::PtrToPtr
|
||||||
|
141
ptx/src/pass/deparamize_functions.rs
Normal file
141
ptx/src/pass/deparamize_functions.rs
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub(super) fn run<'a, 'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
|
directives
|
||||||
|
.into_iter()
|
||||||
|
.map(|directive| run_directive(resolver, directive))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_directive<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2,
|
||||||
|
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
|
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
|
Ok(match directive {
|
||||||
|
var @ Directive2::Variable(..) => var,
|
||||||
|
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_method<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2,
|
||||||
|
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
|
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
|
if method.func_decl.name.is_kernel() {
|
||||||
|
return Ok(method);
|
||||||
|
}
|
||||||
|
let is_declaration = method.body.is_none();
|
||||||
|
let mut body = Vec::new();
|
||||||
|
let mut remap_returns = Vec::new();
|
||||||
|
for arg in method.func_decl.return_arguments.iter_mut() {
|
||||||
|
match arg.state_space {
|
||||||
|
ptx_parser::StateSpace::Param => {
|
||||||
|
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||||
|
let old_name = arg.name;
|
||||||
|
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||||
|
if is_declaration {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
|
||||||
|
body.push(Statement::Variable(ast::Variable {
|
||||||
|
align: None,
|
||||||
|
name: old_name,
|
||||||
|
v_type: arg.v_type.clone(),
|
||||||
|
state_space: ptx_parser::StateSpace::Param,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
ptx_parser::StateSpace::Reg => {}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for arg in method.func_decl.input_arguments.iter_mut() {
|
||||||
|
match arg.state_space {
|
||||||
|
ptx_parser::StateSpace::Param => {
|
||||||
|
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||||
|
let old_name = arg.name;
|
||||||
|
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||||
|
if is_declaration {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
body.push(Statement::Variable(ast::Variable {
|
||||||
|
align: None,
|
||||||
|
name: old_name,
|
||||||
|
v_type: arg.v_type.clone(),
|
||||||
|
state_space: ptx_parser::StateSpace::Param,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
}));
|
||||||
|
body.push(Statement::Instruction(ast::Instruction::St {
|
||||||
|
data: ast::StData {
|
||||||
|
qualifier: ast::LdStQualifier::Weak,
|
||||||
|
state_space: ast::StateSpace::Param,
|
||||||
|
caching: ast::StCacheOperator::Writethrough,
|
||||||
|
typ: arg.v_type.clone(),
|
||||||
|
},
|
||||||
|
arguments: ast::StArgs {
|
||||||
|
src1: old_name,
|
||||||
|
src2: arg.name,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
ptx_parser::StateSpace::Reg => {}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if remap_returns.is_empty() {
|
||||||
|
return Ok(method);
|
||||||
|
}
|
||||||
|
let body = method
|
||||||
|
.body
|
||||||
|
.map(|statements| {
|
||||||
|
for statement in statements {
|
||||||
|
run_statement(&remap_returns, &mut body, statement)?;
|
||||||
|
}
|
||||||
|
Ok::<_, TranslateError>(body)
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
Ok(Function2 {
|
||||||
|
func_decl: method.func_decl,
|
||||||
|
globals: method.globals,
|
||||||
|
body,
|
||||||
|
import_as: method.import_as,
|
||||||
|
tuning: method.tuning,
|
||||||
|
linkage: method.linkage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_statement<'input>(
|
||||||
|
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
|
||||||
|
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
match statement {
|
||||||
|
Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
||||||
|
for (old_name, new_name, type_) in remap_returns.iter().cloned() {
|
||||||
|
result.push(Statement::Instruction(ast::Instruction::Ld {
|
||||||
|
data: ast::LdDetails {
|
||||||
|
qualifier: ast::LdStQualifier::Weak,
|
||||||
|
state_space: ast::StateSpace::Reg,
|
||||||
|
caching: ast::LdCacheOperator::Cached,
|
||||||
|
typ: type_,
|
||||||
|
non_coherent: false,
|
||||||
|
},
|
||||||
|
arguments: ast::LdArgs {
|
||||||
|
dst: new_name,
|
||||||
|
src: old_name,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
result.push(statement);
|
||||||
|
}
|
||||||
|
statement => {
|
||||||
|
result.push(statement);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -164,17 +164,16 @@ impl Deref for MemoryBuffer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn run<'input>(
|
pub(super) fn run<'input>(
|
||||||
id_defs: &GlobalStringIdResolver<'input>,
|
id_defs: GlobalStringIdentResolver2<'input>,
|
||||||
call_map: MethodsCallMap<'input>,
|
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
directives: Vec<Directive<'input>>,
|
|
||||||
) -> Result<MemoryBuffer, TranslateError> {
|
) -> Result<MemoryBuffer, TranslateError> {
|
||||||
let context = Context::new();
|
let context = Context::new();
|
||||||
let module = Module::new(&context, LLVM_UNNAMED);
|
let module = Module::new(&context, LLVM_UNNAMED);
|
||||||
let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs);
|
let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
|
||||||
for directive in directives {
|
for directive in directives {
|
||||||
match directive {
|
match directive {
|
||||||
Directive::Variable(..) => todo!(),
|
Directive2::Variable(..) => todo!(),
|
||||||
Directive::Method(method) => emit_ctx.emit_method(method)?,
|
Directive2::Method(method) => emit_ctx.emit_method(method)?,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
module.write_to_stderr();
|
module.write_to_stderr();
|
||||||
@ -188,7 +187,7 @@ struct ModuleEmitContext<'a, 'input> {
|
|||||||
context: LLVMContextRef,
|
context: LLVMContextRef,
|
||||||
module: LLVMModuleRef,
|
module: LLVMModuleRef,
|
||||||
builder: Builder,
|
builder: Builder,
|
||||||
id_defs: &'a GlobalStringIdResolver<'input>,
|
id_defs: &'a GlobalStringIdentResolver2<'input>,
|
||||||
resolver: ResolveIdent,
|
resolver: ResolveIdent,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -196,7 +195,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||||||
fn new(
|
fn new(
|
||||||
context: &Context,
|
context: &Context,
|
||||||
module: &Module,
|
module: &Module,
|
||||||
id_defs: &'a GlobalStringIdResolver<'input>,
|
id_defs: &'a GlobalStringIdentResolver2<'input>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
ModuleEmitContext {
|
ModuleEmitContext {
|
||||||
context: context.get(),
|
context: context.get(),
|
||||||
@ -215,26 +214,50 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||||||
LLVMCallConv::LLVMCCallConv as u32
|
LLVMCallConv::LLVMCCallConv as u32
|
||||||
}
|
}
|
||||||
|
|
||||||
fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> {
|
fn emit_method(
|
||||||
let func_decl = method.func_decl.borrow();
|
&mut self,
|
||||||
|
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let func_decl = method.func_decl;
|
||||||
let name = method
|
let name = method
|
||||||
.import_as
|
.import_as
|
||||||
.as_deref()
|
.as_deref()
|
||||||
.unwrap_or_else(|| match func_decl.name {
|
.or_else(|| match func_decl.name {
|
||||||
ast::MethodName::Kernel(name) => name,
|
ast::MethodName::Kernel(name) => Some(name),
|
||||||
ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
|
ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
|
||||||
});
|
})
|
||||||
|
.ok_or_else(|| error_unreachable())?;
|
||||||
let name = CString::new(name).map_err(|_| error_unreachable())?;
|
let name = CString::new(name).map_err(|_| error_unreachable())?;
|
||||||
let fn_type = self.function_type(
|
let fn_type = get_function_type(
|
||||||
|
self.context,
|
||||||
func_decl.return_arguments.iter().map(|v| &v.v_type),
|
func_decl.return_arguments.iter().map(|v| &v.v_type),
|
||||||
func_decl.input_arguments.iter().map(|v| &v.v_type),
|
func_decl
|
||||||
);
|
.input_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
|
||||||
|
)?;
|
||||||
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
|
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
|
||||||
|
if let ast::MethodName::Func(name) = func_decl.name {
|
||||||
|
self.resolver.register(name, fn_);
|
||||||
|
}
|
||||||
for (i, param) in func_decl.input_arguments.iter().enumerate() {
|
for (i, param) in func_decl.input_arguments.iter().enumerate() {
|
||||||
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
||||||
let name = self.resolver.get_or_add(param.name);
|
let name = self.resolver.get_or_add(param.name);
|
||||||
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
|
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
|
||||||
self.resolver.register(param.name, value);
|
self.resolver.register(param.name, value);
|
||||||
|
if func_decl.name.is_kernel() {
|
||||||
|
let attr_kind = unsafe {
|
||||||
|
LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
|
||||||
|
};
|
||||||
|
let attr = unsafe {
|
||||||
|
LLVMCreateTypeAttribute(
|
||||||
|
self.context,
|
||||||
|
attr_kind,
|
||||||
|
get_type(self.context, &param.v_type)?,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let call_conv = if func_decl.name.is_kernel() {
|
let call_conv = if func_decl.name.is_kernel() {
|
||||||
Self::kernel_call_convention()
|
Self::kernel_call_convention()
|
||||||
@ -258,66 +281,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn function_type(
|
fn get_input_argument_type(
|
||||||
&self,
|
context: LLVMContextRef,
|
||||||
return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
v_type: &ptx_parser::Type,
|
||||||
input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
state_space: ptx_parser::StateSpace,
|
||||||
) -> LLVMTypeRef {
|
) -> Result<LLVMTypeRef, TranslateError> {
|
||||||
if return_args.len() == 0 {
|
match state_space {
|
||||||
let mut input_args = input_args
|
ptx_parser::StateSpace::ParamEntry => {
|
||||||
.map(|type_| match type_ {
|
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
|
||||||
ast::Type::Scalar(scalar) => match scalar {
|
|
||||||
ast::ScalarType::Pred => {
|
|
||||||
unsafe { LLVMInt1TypeInContext(self.context) }
|
|
||||||
}
|
|
||||||
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
|
|
||||||
unsafe { LLVMInt8TypeInContext(self.context) }
|
|
||||||
}
|
|
||||||
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
|
|
||||||
unsafe { LLVMInt16TypeInContext(self.context) }
|
|
||||||
}
|
|
||||||
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
|
|
||||||
unsafe { LLVMInt32TypeInContext(self.context) }
|
|
||||||
}
|
|
||||||
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
|
|
||||||
unsafe { LLVMInt64TypeInContext(self.context) }
|
|
||||||
}
|
|
||||||
ast::ScalarType::B128 => {
|
|
||||||
unsafe { LLVMInt128TypeInContext(self.context) }
|
|
||||||
}
|
|
||||||
ast::ScalarType::F16 => {
|
|
||||||
unsafe { LLVMHalfTypeInContext(self.context) }
|
|
||||||
}
|
|
||||||
ast::ScalarType::F32 => {
|
|
||||||
unsafe { LLVMFloatTypeInContext(self.context) }
|
|
||||||
}
|
|
||||||
ast::ScalarType::F64 => {
|
|
||||||
unsafe { LLVMDoubleTypeInContext(self.context) }
|
|
||||||
}
|
|
||||||
ast::ScalarType::BF16 => {
|
|
||||||
unsafe { LLVMBFloatTypeInContext(self.context) }
|
|
||||||
}
|
|
||||||
ast::ScalarType::U16x2 => todo!(),
|
|
||||||
ast::ScalarType::S16x2 => todo!(),
|
|
||||||
ast::ScalarType::F16x2 => todo!(),
|
|
||||||
ast::ScalarType::BF16x2 => todo!(),
|
|
||||||
},
|
|
||||||
ast::Type::Vector(_, _) => todo!(),
|
|
||||||
ast::Type::Array(_, _, _) => todo!(),
|
|
||||||
ast::Type::Pointer(_, _) => todo!(),
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
return unsafe {
|
|
||||||
LLVMFunctionType(
|
|
||||||
LLVMVoidTypeInContext(self.context),
|
|
||||||
input_args.as_mut_ptr(),
|
|
||||||
input_args.len() as u32,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
todo!()
|
ptx_parser::StateSpace::Reg => get_type(context, v_type),
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -326,7 +302,7 @@ struct MethodEmitContext<'a, 'input> {
|
|||||||
module: LLVMModuleRef,
|
module: LLVMModuleRef,
|
||||||
method: LLVMValueRef,
|
method: LLVMValueRef,
|
||||||
builder: LLVMBuilderRef,
|
builder: LLVMBuilderRef,
|
||||||
id_defs: &'a GlobalStringIdResolver<'input>,
|
id_defs: &'a GlobalStringIdentResolver2<'input>,
|
||||||
variables_builder: Builder,
|
variables_builder: Builder,
|
||||||
resolver: &'a mut ResolveIdent,
|
resolver: &'a mut ResolveIdent,
|
||||||
}
|
}
|
||||||
@ -365,6 +341,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||||||
Statement::PtrAccess(_) => todo!(),
|
Statement::PtrAccess(_) => todo!(),
|
||||||
Statement::RepackVector(_) => todo!(),
|
Statement::RepackVector(_) => todo!(),
|
||||||
Statement::FunctionPointer(_) => todo!(),
|
Statement::FunctionPointer(_) => todo!(),
|
||||||
|
Statement::VectorAccess(_) => todo!(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -414,7 +391,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||||||
inst: ast::Instruction<SpirvWord>,
|
inst: ast::Instruction<SpirvWord>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
match inst {
|
match inst {
|
||||||
ast::Instruction::Mov { data, arguments } => todo!(),
|
ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments),
|
||||||
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
|
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
|
||||||
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
|
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
|
||||||
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
|
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
|
||||||
@ -425,7 +402,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||||||
ast::Instruction::Or { data, arguments } => todo!(),
|
ast::Instruction::Or { data, arguments } => todo!(),
|
||||||
ast::Instruction::And { data, arguments } => todo!(),
|
ast::Instruction::And { data, arguments } => todo!(),
|
||||||
ast::Instruction::Bra { arguments } => todo!(),
|
ast::Instruction::Bra { arguments } => todo!(),
|
||||||
ast::Instruction::Call { data, arguments } => todo!(),
|
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
|
||||||
ast::Instruction::Cvt { data, arguments } => todo!(),
|
ast::Instruction::Cvt { data, arguments } => todo!(),
|
||||||
ast::Instruction::Shr { data, arguments } => todo!(),
|
ast::Instruction::Shr { data, arguments } => todo!(),
|
||||||
ast::Instruction::Shl { data, arguments } => todo!(),
|
ast::Instruction::Shl { data, arguments } => todo!(),
|
||||||
@ -563,6 +540,70 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||||||
fn emit_ret(&self, _data: ptx_parser::RetData) {
|
fn emit_ret(&self, _data: ptx_parser::RetData) {
|
||||||
unsafe { LLVMBuildRetVoid(self.builder) };
|
unsafe { LLVMBuildRetVoid(self.builder) };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn emit_call(
|
||||||
|
&mut self,
|
||||||
|
data: ptx_parser::CallDetails,
|
||||||
|
arguments: ptx_parser::CallArgs<SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
if cfg!(debug_assertions) {
|
||||||
|
for (_, space) in data.return_arguments.iter() {
|
||||||
|
if *space != ast::StateSpace::Reg {
|
||||||
|
panic!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (_, space) in data.input_arguments.iter() {
|
||||||
|
if *space != ast::StateSpace::Reg {
|
||||||
|
panic!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let name = match (&*data.return_arguments, &*arguments.return_arguments) {
|
||||||
|
([], []) => LLVM_UNNAMED.as_ptr(),
|
||||||
|
([(type_, _)], [dst]) => self.resolver.get_or_add_raw(*dst),
|
||||||
|
_ => todo!(),
|
||||||
|
};
|
||||||
|
let type_ = get_function_type(
|
||||||
|
self.context,
|
||||||
|
data.return_arguments.iter().map(|(type_, space)| type_),
|
||||||
|
data.input_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)),
|
||||||
|
)?;
|
||||||
|
let mut input_arguments = arguments
|
||||||
|
.input_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|arg| self.resolver.value(*arg))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
let llvm_fn = unsafe {
|
||||||
|
LLVMBuildCall2(
|
||||||
|
self.builder,
|
||||||
|
type_,
|
||||||
|
self.resolver.value(arguments.func)?,
|
||||||
|
input_arguments.as_mut_ptr(),
|
||||||
|
input_arguments.len() as u32,
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
match &*arguments.return_arguments {
|
||||||
|
[] => {}
|
||||||
|
[name] => {
|
||||||
|
self.resolver.register(*name, llvm_fn);
|
||||||
|
}
|
||||||
|
_ => todo!(),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_mov(
|
||||||
|
&mut self,
|
||||||
|
_data: ptx_parser::MovDetails,
|
||||||
|
arguments: ptx_parser::MovArgs<SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
self.resolver
|
||||||
|
.register(arguments.dst, self.resolver.value(arguments.src)?);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_pointer_type<'ctx>(
|
fn get_pointer_type<'ctx>(
|
||||||
@ -624,13 +665,34 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_function_type<'a>(
|
||||||
|
context: LLVMContextRef,
|
||||||
|
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||||
|
input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
|
||||||
|
) -> Result<LLVMTypeRef, TranslateError> {
|
||||||
|
let mut input_args: Vec<*mut llvm_zluda::LLVMType> =
|
||||||
|
input_args.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
let return_type = match return_args.len() {
|
||||||
|
0 => unsafe { LLVMVoidTypeInContext(context) },
|
||||||
|
1 => get_type(context, return_args.next().unwrap())?,
|
||||||
|
_ => todo!(),
|
||||||
|
};
|
||||||
|
Ok(unsafe {
|
||||||
|
LLVMFunctionType(
|
||||||
|
return_type,
|
||||||
|
input_args.as_mut_ptr(),
|
||||||
|
input_args.len() as u32,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
|
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
|
||||||
match space {
|
match space {
|
||||||
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
|
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
|
||||||
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
|
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
|
||||||
ast::StateSpace::Sreg => Ok(PRIVATE_ADDRESS_SPACE),
|
|
||||||
ast::StateSpace::Param => Err(TranslateError::Todo),
|
ast::StateSpace::Param => Err(TranslateError::Todo),
|
||||||
ast::StateSpace::ParamEntry => Err(TranslateError::Todo),
|
ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
|
||||||
ast::StateSpace::ParamFunc => Err(TranslateError::Todo),
|
ast::StateSpace::ParamFunc => Err(TranslateError::Todo),
|
||||||
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
|
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
|
||||||
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
|
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
|
||||||
@ -647,7 +709,7 @@ struct ResolveIdent {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ResolveIdent {
|
impl ResolveIdent {
|
||||||
fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self {
|
fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self {
|
||||||
ResolveIdent {
|
ResolveIdent {
|
||||||
words: HashMap::new(),
|
words: HashMap::new(),
|
||||||
values: HashMap::new(),
|
values: HashMap::new(),
|
||||||
|
@ -469,7 +469,6 @@ fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass {
|
|||||||
ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
|
ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
|
||||||
ast::StateSpace::Param => spirv::StorageClass::Function,
|
ast::StateSpace::Param => spirv::StorageClass::Function,
|
||||||
ast::StateSpace::Reg => spirv::StorageClass::Function,
|
ast::StateSpace::Reg => spirv::StorageClass::Function,
|
||||||
ast::StateSpace::Sreg => spirv::StorageClass::Input,
|
|
||||||
ast::StateSpace::ParamEntry
|
ast::StateSpace::ParamEntry
|
||||||
| ast::StateSpace::ParamFunc
|
| ast::StateSpace::ParamFunc
|
||||||
| ast::StateSpace::SharedCluster
|
| ast::StateSpace::SharedCluster
|
||||||
@ -693,7 +692,6 @@ fn emit_variable<'input>(
|
|||||||
ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
|
ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
|
||||||
ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant),
|
ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant),
|
||||||
ast::StateSpace::Generic => todo!(),
|
ast::StateSpace::Generic => todo!(),
|
||||||
ast::StateSpace::Sreg => todo!(),
|
|
||||||
ast::StateSpace::ParamEntry
|
ast::StateSpace::ParamEntry
|
||||||
| ast::StateSpace::ParamFunc
|
| ast::StateSpace::ParamFunc
|
||||||
| ast::StateSpace::SharedCluster
|
| ast::StateSpace::SharedCluster
|
||||||
@ -1563,6 +1561,7 @@ fn emit_function_body_ops<'input>(
|
|||||||
builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
|
builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Statement::VectorAccess(vector_access) => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -63,9 +63,9 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
|||||||
} else {
|
} else {
|
||||||
return Err(TranslateError::UntypedSymbol);
|
return Err(TranslateError::UntypedSymbol);
|
||||||
};
|
};
|
||||||
if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg {
|
if state_space == ast::StateSpace::Reg {
|
||||||
let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
|
let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
|
||||||
if !space_is_compatible(reg_space, ast::StateSpace::Reg) {
|
if reg_space != ast::StateSpace::Reg {
|
||||||
return Err(error_mismatched_type());
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
let reg_scalar_type = match reg_type {
|
let reg_scalar_type = match reg_type {
|
||||||
|
289
ptx/src/pass/expand_operands.rs
Normal file
289
ptx/src/pass/expand_operands.rs
Normal file
@ -0,0 +1,289 @@
|
|||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub(super) fn run<'a, 'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
directives: Vec<UnconditionalDirective<'input>>,
|
||||||
|
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
|
directives
|
||||||
|
.into_iter()
|
||||||
|
.map(|directive| run_directive(resolver, directive))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_directive<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
directive: Directive2<
|
||||||
|
'input,
|
||||||
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
|
ast::ParsedOperand<SpirvWord>,
|
||||||
|
>,
|
||||||
|
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
|
Ok(match directive {
|
||||||
|
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||||
|
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_method<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
method: Function2<
|
||||||
|
'input,
|
||||||
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
|
ast::ParsedOperand<SpirvWord>,
|
||||||
|
>,
|
||||||
|
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
|
let body = method
|
||||||
|
.body
|
||||||
|
.map(|statements| {
|
||||||
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
|
for statement in statements {
|
||||||
|
run_statement(resolver, &mut result, statement)?;
|
||||||
|
}
|
||||||
|
Ok::<_, TranslateError>(result)
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
Ok(Function2 {
|
||||||
|
func_decl: method.func_decl,
|
||||||
|
globals: method.globals,
|
||||||
|
body,
|
||||||
|
import_as: method.import_as,
|
||||||
|
tuning: method.tuning,
|
||||||
|
linkage: method.linkage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_statement<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
statement: UnconditionalStatement,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let mut visitor = FlattenArguments::new(resolver, result);
|
||||||
|
let new_statement = statement.visit_map(&mut visitor)?;
|
||||||
|
visitor.result.push(new_statement);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FlattenArguments<'a, 'input> {
|
||||||
|
result: &'a mut Vec<ExpandedStatement>,
|
||||||
|
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||||
|
post_stmts: Vec<ExpandedStatement>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'input> FlattenArguments<'a, 'input> {
|
||||||
|
fn new(
|
||||||
|
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||||
|
result: &'a mut Vec<ExpandedStatement>,
|
||||||
|
) -> Self {
|
||||||
|
FlattenArguments {
|
||||||
|
result,
|
||||||
|
resolver,
|
||||||
|
post_stmts: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
|
||||||
|
Ok(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reg_offset(
|
||||||
|
&mut self,
|
||||||
|
reg: SpirvWord,
|
||||||
|
offset: i32,
|
||||||
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
_is_dst: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
|
||||||
|
(type_, state_space)
|
||||||
|
} else {
|
||||||
|
return Err(TranslateError::UntypedSymbol);
|
||||||
|
};
|
||||||
|
if state_space == ast::StateSpace::Reg {
|
||||||
|
let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
|
||||||
|
if *reg_space != ast::StateSpace::Reg {
|
||||||
|
return Err(error_mismatched_type());
|
||||||
|
}
|
||||||
|
let reg_scalar_type = match reg_type {
|
||||||
|
ast::Type::Scalar(underlying_type) => *underlying_type,
|
||||||
|
_ => return Err(error_mismatched_type()),
|
||||||
|
};
|
||||||
|
let reg_type = reg_type.clone();
|
||||||
|
let id_constant_stmt = self
|
||||||
|
.resolver
|
||||||
|
.register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
|
||||||
|
self.result.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: id_constant_stmt,
|
||||||
|
typ: reg_scalar_type,
|
||||||
|
value: ast::ImmediateValue::S64(offset as i64),
|
||||||
|
}));
|
||||||
|
let arith_details = match reg_scalar_type.kind() {
|
||||||
|
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: reg_scalar_type,
|
||||||
|
saturate: false,
|
||||||
|
}),
|
||||||
|
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
|
||||||
|
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: reg_scalar_type,
|
||||||
|
saturate: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
};
|
||||||
|
let id_add_result = self
|
||||||
|
.resolver
|
||||||
|
.register_unnamed(Some((reg_type, state_space)));
|
||||||
|
self.result
|
||||||
|
.push(Statement::Instruction(ast::Instruction::Add {
|
||||||
|
data: arith_details,
|
||||||
|
arguments: ast::AddArgs {
|
||||||
|
dst: id_add_result,
|
||||||
|
src1: reg,
|
||||||
|
src2: id_constant_stmt,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
Ok(id_add_result)
|
||||||
|
} else {
|
||||||
|
let id_constant_stmt = self.resolver.register_unnamed(Some((
|
||||||
|
ast::Type::Scalar(ast::ScalarType::S64),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)));
|
||||||
|
self.result.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: id_constant_stmt,
|
||||||
|
typ: ast::ScalarType::S64,
|
||||||
|
value: ast::ImmediateValue::S64(offset as i64),
|
||||||
|
}));
|
||||||
|
let dst = self
|
||||||
|
.resolver
|
||||||
|
.register_unnamed(Some((type_.clone(), state_space)));
|
||||||
|
self.result.push(Statement::PtrAccess(PtrAccess {
|
||||||
|
underlying_type: type_.clone(),
|
||||||
|
state_space: state_space,
|
||||||
|
dst,
|
||||||
|
ptr_src: reg,
|
||||||
|
offset_src: id_constant_stmt,
|
||||||
|
}));
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn immediate(
|
||||||
|
&mut self,
|
||||||
|
value: ast::ImmediateValue,
|
||||||
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
let (scalar_t, state_space) =
|
||||||
|
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
|
||||||
|
(*scalar, state_space)
|
||||||
|
} else {
|
||||||
|
return Err(TranslateError::UntypedSymbol);
|
||||||
|
};
|
||||||
|
let id = self
|
||||||
|
.resolver
|
||||||
|
.register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
|
||||||
|
self.result.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: id,
|
||||||
|
typ: scalar_t,
|
||||||
|
value,
|
||||||
|
}));
|
||||||
|
Ok(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_member(
|
||||||
|
&mut self,
|
||||||
|
vector_src: SpirvWord,
|
||||||
|
member: u8,
|
||||||
|
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
if is_dst {
|
||||||
|
return Err(error_mismatched_type());
|
||||||
|
}
|
||||||
|
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? {
|
||||||
|
(ast::Type::Vector(vector_width, scalar_t), space) => {
|
||||||
|
(*vector_width, *scalar_t, *space)
|
||||||
|
}
|
||||||
|
_ => return Err(error_mismatched_type()),
|
||||||
|
};
|
||||||
|
let temporary = self
|
||||||
|
.resolver
|
||||||
|
.register_unnamed(Some((scalar_type.into(), space)));
|
||||||
|
self.result.push(Statement::VectorAccess(VectorAccess {
|
||||||
|
scalar_type,
|
||||||
|
vector_width,
|
||||||
|
dst: temporary,
|
||||||
|
src: vector_src,
|
||||||
|
member: member,
|
||||||
|
}));
|
||||||
|
Ok(temporary)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vec_pack(
|
||||||
|
&mut self,
|
||||||
|
vecs: Vec<SpirvWord>,
|
||||||
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
relaxed_type_check: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
let (scalar_t, state_space) = match type_space {
|
||||||
|
Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space),
|
||||||
|
_ => return Err(error_mismatched_type()),
|
||||||
|
};
|
||||||
|
let temp_vec = self
|
||||||
|
.resolver
|
||||||
|
.register_unnamed(Some((scalar_t.into(), state_space)));
|
||||||
|
let statement = Statement::RepackVector(RepackVectorDetails {
|
||||||
|
is_extract: is_dst,
|
||||||
|
typ: scalar_t,
|
||||||
|
packed: temp_vec,
|
||||||
|
unpacked: vecs,
|
||||||
|
relaxed_type_check,
|
||||||
|
});
|
||||||
|
if is_dst {
|
||||||
|
self.post_stmts.push(statement);
|
||||||
|
} else {
|
||||||
|
self.result.push(statement);
|
||||||
|
}
|
||||||
|
Ok(temp_vec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
|
||||||
|
for FlattenArguments<'a, 'b>
|
||||||
|
{
|
||||||
|
fn visit(
|
||||||
|
&mut self,
|
||||||
|
args: ast::ParsedOperand<SpirvWord>,
|
||||||
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
relaxed_type_check: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
match args {
|
||||||
|
ast::ParsedOperand::Reg(r) => self.reg(r),
|
||||||
|
ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
|
||||||
|
ast::ParsedOperand::RegOffset(reg, offset) => {
|
||||||
|
self.reg_offset(reg, offset, type_space, is_dst)
|
||||||
|
}
|
||||||
|
ast::ParsedOperand::VecMember(vec, member) => {
|
||||||
|
self.vec_member(vec, member, type_space, is_dst)
|
||||||
|
}
|
||||||
|
ast::ParsedOperand::VecPack(vecs) => {
|
||||||
|
self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_ident(
|
||||||
|
&mut self,
|
||||||
|
name: <TypedOperand as ast::Operand>::Ident,
|
||||||
|
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
_is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
|
||||||
|
self.reg(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for FlattenArguments<'_, '_> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.result.extend(self.post_stmts.drain(..));
|
||||||
|
}
|
||||||
|
}
|
@ -273,7 +273,6 @@ fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
|
|||||||
ast::StateSpace::Const => "const",
|
ast::StateSpace::Const => "const",
|
||||||
ast::StateSpace::Local => "local",
|
ast::StateSpace::Local => "local",
|
||||||
ast::StateSpace::Param => "param",
|
ast::StateSpace::Param => "param",
|
||||||
ast::StateSpace::Sreg => "sreg",
|
|
||||||
ast::StateSpace::SharedCluster => "shared_cluster",
|
ast::StateSpace::SharedCluster => "shared_cluster",
|
||||||
ast::StateSpace::ParamEntry => "param_entry",
|
ast::StateSpace::ParamEntry => "param_entry",
|
||||||
ast::StateSpace::SharedCta => "shared_cta",
|
ast::StateSpace::SharedCta => "shared_cta",
|
||||||
|
209
ptx/src/pass/fix_special_registers2.rs
Normal file
209
ptx/src/pass/fix_special_registers2.rs
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub(super) fn run<'a, 'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
special_registers: &'a SpecialRegistersMap2,
|
||||||
|
directives: Vec<UnconditionalDirective<'input>>,
|
||||||
|
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
||||||
|
let declarations = SpecialRegistersMap2::generate_declarations(resolver);
|
||||||
|
let mut result = Vec::with_capacity(declarations.len() + directives.len());
|
||||||
|
let mut sreg_to_function =
|
||||||
|
FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default());
|
||||||
|
for (sreg, declaration) in declarations {
|
||||||
|
let name = if let ast::MethodName::Func(name) = declaration.name {
|
||||||
|
name
|
||||||
|
} else {
|
||||||
|
return Err(error_unreachable());
|
||||||
|
};
|
||||||
|
result.push(UnconditionalDirective::Method(UnconditionalFunction {
|
||||||
|
func_decl: declaration,
|
||||||
|
globals: Vec::new(),
|
||||||
|
body: None,
|
||||||
|
import_as: None,
|
||||||
|
tuning: Vec::new(),
|
||||||
|
linkage: ast::LinkingDirective::EXTERN,
|
||||||
|
}));
|
||||||
|
sreg_to_function.insert(sreg, name);
|
||||||
|
}
|
||||||
|
let mut visitor = SpecialRegisterResolver {
|
||||||
|
resolver,
|
||||||
|
special_registers,
|
||||||
|
sreg_to_function,
|
||||||
|
result: Vec::new(),
|
||||||
|
};
|
||||||
|
directives
|
||||||
|
.into_iter()
|
||||||
|
.map(|directive| run_directive(&mut visitor, directive))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_directive<'a, 'input>(
|
||||||
|
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||||
|
directive: UnconditionalDirective<'input>,
|
||||||
|
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
||||||
|
Ok(match directive {
|
||||||
|
var @ Directive2::Variable(..) => var,
|
||||||
|
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_method<'a, 'input>(
|
||||||
|
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||||
|
method: UnconditionalFunction<'input>,
|
||||||
|
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
||||||
|
let body = method
|
||||||
|
.body
|
||||||
|
.map(|statements| {
|
||||||
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
|
for statement in statements {
|
||||||
|
run_statement(visitor, &mut result, statement)?;
|
||||||
|
}
|
||||||
|
Ok::<_, TranslateError>(result)
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
Ok(Function2 {
|
||||||
|
func_decl: method.func_decl,
|
||||||
|
globals: method.globals,
|
||||||
|
body,
|
||||||
|
import_as: method.import_as,
|
||||||
|
tuning: method.tuning,
|
||||||
|
linkage: method.linkage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_statement<'a, 'input>(
|
||||||
|
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||||
|
result: &mut Vec<UnconditionalStatement>,
|
||||||
|
statement: UnconditionalStatement,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let converted_statement = statement.visit_map(visitor)?;
|
||||||
|
result.extend(visitor.result.drain(..));
|
||||||
|
result.push(converted_statement);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SpecialRegisterResolver<'a, 'input> {
|
||||||
|
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||||
|
special_registers: &'a SpecialRegistersMap2,
|
||||||
|
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
|
||||||
|
result: Vec<UnconditionalStatement>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b, 'input>
|
||||||
|
ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
|
||||||
|
for SpecialRegisterResolver<'a, 'input>
|
||||||
|
{
|
||||||
|
fn visit(
|
||||||
|
&mut self,
|
||||||
|
operand: ast::ParsedOperand<SpirvWord>,
|
||||||
|
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
|
||||||
|
map_operand(operand, &mut |ident, vector_index| {
|
||||||
|
self.replace_sreg(ident, vector_index, is_dst)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_ident(
|
||||||
|
&mut self,
|
||||||
|
args: SpirvWord,
|
||||||
|
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
self.replace_sreg(args, None, is_dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
|
||||||
|
fn replace_sreg(
|
||||||
|
&mut self,
|
||||||
|
name: SpirvWord,
|
||||||
|
vector_index: Option<u8>,
|
||||||
|
is_dst: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
if let Some(sreg) = self.special_registers.get(name) {
|
||||||
|
if is_dst {
|
||||||
|
return Err(error_mismatched_type());
|
||||||
|
}
|
||||||
|
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
|
||||||
|
(Some(idx), Some(inp_type)) => {
|
||||||
|
if inp_type != ast::ScalarType::U8 {
|
||||||
|
return Err(TranslateError::Unreachable);
|
||||||
|
}
|
||||||
|
let constant = self.resolver.register_unnamed(Some((
|
||||||
|
ast::Type::Scalar(inp_type),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)));
|
||||||
|
self.result.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: constant,
|
||||||
|
typ: inp_type,
|
||||||
|
value: ast::ImmediateValue::U64(idx as u64),
|
||||||
|
}));
|
||||||
|
vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
|
||||||
|
}
|
||||||
|
(None, None) => Vec::new(),
|
||||||
|
_ => return Err(error_mismatched_type()),
|
||||||
|
};
|
||||||
|
let return_type = sreg.get_function_return_type();
|
||||||
|
let fn_result = self
|
||||||
|
.resolver
|
||||||
|
.register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
|
||||||
|
let return_arguments = vec![(
|
||||||
|
fn_result,
|
||||||
|
ast::Type::Scalar(return_type),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)];
|
||||||
|
let data = ast::CallDetails {
|
||||||
|
uniform: false,
|
||||||
|
return_arguments: return_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||||
|
.collect(),
|
||||||
|
input_arguments: input_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||||
|
.collect(),
|
||||||
|
};
|
||||||
|
let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
|
||||||
|
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
|
||||||
|
func: self.sreg_to_function[&sreg],
|
||||||
|
input_arguments: input_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
|
||||||
|
.collect(),
|
||||||
|
};
|
||||||
|
self.result
|
||||||
|
.push(Statement::Instruction(ast::Instruction::Call {
|
||||||
|
data,
|
||||||
|
arguments,
|
||||||
|
}));
|
||||||
|
Ok(fn_result)
|
||||||
|
} else {
|
||||||
|
Ok(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies `fn_` to every identifier inside `this` and rebuilds the operand.
///
/// The callback receives the identifier and, for a vector-member operand, the
/// member index (`None` otherwise). Immediates are copied through without
/// calling `fn_`. Note that a `VecMember` operand comes back as a plain `Reg`
/// wrapping the callback's result: the member index is handed to the callback
/// and not kept in the returned operand.
///
/// The first error returned by `fn_` aborts the mapping.
pub fn map_operand<T, U, Err>(
    this: ast::ParsedOperand<T>,
    fn_: &mut impl FnMut(T, Option<u8>) -> Result<U, Err>,
) -> Result<ast::ParsedOperand<U>, Err> {
    let mapped = match this {
        ast::ParsedOperand::Imm(value) => ast::ParsedOperand::Imm(value),
        ast::ParsedOperand::Reg(name) => ast::ParsedOperand::Reg(fn_(name, None)?),
        ast::ParsedOperand::RegOffset(name, displacement) => {
            let base = fn_(name, None)?;
            ast::ParsedOperand::RegOffset(base, displacement)
        }
        ast::ParsedOperand::VecMember(name, index) => {
            ast::ParsedOperand::Reg(fn_(name, Some(index))?)
        }
        ast::ParsedOperand::VecPack(names) => {
            let mut converted = Vec::with_capacity(names.len());
            for name in names {
                converted.push(fn_(name, None)?);
            }
            ast::ParsedOperand::VecPack(converted)
        }
    };
    Ok(mapped)
}
|
45
ptx/src/pass/hoist_globals.rs
Normal file
45
ptx/src/pass/hoist_globals.rs
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub(super) fn run<'input>(
|
||||||
|
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
|
let mut result = Vec::with_capacity(directives.len());
|
||||||
|
for mut directive in directives.into_iter() {
|
||||||
|
run_directive(&mut result, &mut directive);
|
||||||
|
result.push(directive);
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_directive<'input>(
|
||||||
|
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
directive: &mut Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
match directive {
|
||||||
|
Directive2::Variable(..) => {}
|
||||||
|
Directive2::Method(function2) => run_function(result, function2),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_function<'input>(
|
||||||
|
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
function: &mut Function2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
|
||||||
|
) {
|
||||||
|
function.body = function.body.take().map(|statements| {
|
||||||
|
statements
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|statement| match statement {
|
||||||
|
Statement::Variable(var @ ast::Variable {
|
||||||
|
state_space:
|
||||||
|
ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared,
|
||||||
|
..
|
||||||
|
}) => {
|
||||||
|
result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
|
||||||
|
None
|
||||||
|
}
|
||||||
|
s => Some(s),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
});
|
||||||
|
}
|
338
ptx/src/pass/insert_explicit_load_store.rs
Normal file
338
ptx/src/pass/insert_explicit_load_store.rs
Normal file
@ -0,0 +1,338 @@
|
|||||||
|
use super::*;
|
||||||
|
use ptx_parser::VisitorMap;
|
||||||
|
use rustc_hash::FxHashSet;
|
||||||
|
|
||||||
|
// This pass:
|
||||||
|
// * Turns all .local, .param and .reg in-body variables into .local variables
|
||||||
|
// (if _not_ an input method argument)
|
||||||
|
// * Inserts explicit `ld`/`st` for newly converted .reg variables
|
||||||
|
// * Fixup state space of all existing `ld`/`st` instructions into newly
|
||||||
|
// converted variables
|
||||||
|
// * Turns `.entry` input arguments into param::entry and all related `.param`
|
||||||
|
// loads into `param::entry` loads
|
||||||
|
// * All `.func` input arguments are turned into `.reg` arguments by another
|
||||||
|
// pass, so we do nothing there
|
||||||
|
pub(super) fn run<'a, 'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
|
directives
|
||||||
|
.into_iter()
|
||||||
|
.map(|directive| run_directive(resolver, directive))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_directive<'a, 'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
|
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
|
Ok(match directive {
|
||||||
|
var @ Directive2::Variable(..) => var,
|
||||||
|
Directive2::Method(method) => {
|
||||||
|
let visitor = InsertMemSSAVisitor::new(resolver);
|
||||||
|
Directive2::Method(run_method(visitor, method)?)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_method<'a, 'input>(
|
||||||
|
mut visitor: InsertMemSSAVisitor<'a, 'input>,
|
||||||
|
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
|
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
|
let mut func_decl = method.func_decl;
|
||||||
|
for arg in func_decl.return_arguments.iter_mut() {
|
||||||
|
visitor.visit_variable(arg)?;
|
||||||
|
}
|
||||||
|
let is_kernel = func_decl.name.is_kernel();
|
||||||
|
if is_kernel {
|
||||||
|
for arg in func_decl.input_arguments.iter_mut() {
|
||||||
|
let old_name = arg.name;
|
||||||
|
let old_space = arg.state_space;
|
||||||
|
let new_space = ast::StateSpace::ParamEntry;
|
||||||
|
let new_name = visitor
|
||||||
|
.resolver
|
||||||
|
.register_unnamed(Some((arg.v_type.clone(), new_space)));
|
||||||
|
visitor.input_argument(old_name, new_name, old_space);
|
||||||
|
arg.name = new_name;
|
||||||
|
arg.state_space = new_space;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let body = method
|
||||||
|
.body
|
||||||
|
.map(move |statements| {
|
||||||
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
|
for statement in statements {
|
||||||
|
run_statement(&mut visitor, &mut result, statement)?;
|
||||||
|
}
|
||||||
|
Ok::<_, TranslateError>(result)
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
Ok(Function2 {
|
||||||
|
func_decl: func_decl,
|
||||||
|
globals: method.globals,
|
||||||
|
body,
|
||||||
|
import_as: method.import_as,
|
||||||
|
tuning: method.tuning,
|
||||||
|
linkage: method.linkage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_statement<'a, 'input>(
|
||||||
|
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
|
||||||
|
result: &mut Vec<ExpandedStatement>,
|
||||||
|
statement: ExpandedStatement,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
match statement {
|
||||||
|
Statement::Variable(mut var) => {
|
||||||
|
visitor.visit_variable(&mut var)?;
|
||||||
|
result.push(Statement::Variable(var));
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
|
||||||
|
let instruction = visitor.visit_ld(data, arguments)?;
|
||||||
|
let instruction = ast::visit_map(instruction, visitor)?;
|
||||||
|
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||||
|
result.push(Statement::Instruction(instruction));
|
||||||
|
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::St { data, arguments }) => {
|
||||||
|
let instruction = visitor.visit_st(data, arguments)?;
|
||||||
|
let instruction = ast::visit_map(instruction, visitor)?;
|
||||||
|
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||||
|
result.push(Statement::Instruction(instruction));
|
||||||
|
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||||
|
}
|
||||||
|
s => {
|
||||||
|
let new_statement = s.visit_map(visitor)?;
|
||||||
|
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||||
|
result.push(new_statement);
|
||||||
|
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct InsertMemSSAVisitor<'a, 'input> {
|
||||||
|
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||||
|
variables: FxHashMap<SpirvWord, RemapAction>,
|
||||||
|
pre: Vec<ast::Instruction<SpirvWord>>,
|
||||||
|
post: Vec<ast::Instruction<SpirvWord>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||||
|
fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
|
||||||
|
Self {
|
||||||
|
resolver,
|
||||||
|
variables: FxHashMap::default(),
|
||||||
|
pre: Vec::new(),
|
||||||
|
post: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_argument(
|
||||||
|
&mut self,
|
||||||
|
old_name: SpirvWord,
|
||||||
|
new_name: SpirvWord,
|
||||||
|
old_space: ast::StateSpace,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
if old_space != ast::StateSpace::Param {
|
||||||
|
return Err(error_unreachable());
|
||||||
|
}
|
||||||
|
self.variables.insert(
|
||||||
|
old_name,
|
||||||
|
RemapAction::LDStSpaceChange {
|
||||||
|
name: new_name,
|
||||||
|
old_space,
|
||||||
|
new_space: ast::StateSpace::ParamEntry,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn variable(
|
||||||
|
&mut self,
|
||||||
|
type_: &ast::Type,
|
||||||
|
old_name: SpirvWord,
|
||||||
|
new_name: SpirvWord,
|
||||||
|
old_space: ast::StateSpace,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
Ok(match old_space {
|
||||||
|
ast::StateSpace::Reg => {
|
||||||
|
self.variables.insert(
|
||||||
|
old_name,
|
||||||
|
RemapAction::PreLdPostSt {
|
||||||
|
name: new_name,
|
||||||
|
type_: type_.clone(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
ast::StateSpace::Param => {
|
||||||
|
self.variables.insert(
|
||||||
|
old_name,
|
||||||
|
RemapAction::LDStSpaceChange {
|
||||||
|
old_space,
|
||||||
|
new_space: ast::StateSpace::Local,
|
||||||
|
name: new_name,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// Good as-is
|
||||||
|
ast::StateSpace::Local => {}
|
||||||
|
// Will be pulled into global scope later
|
||||||
|
ast::StateSpace::Generic
|
||||||
|
| ast::StateSpace::SharedCluster
|
||||||
|
| ast::StateSpace::Global
|
||||||
|
| ast::StateSpace::Const
|
||||||
|
| ast::StateSpace::SharedCta
|
||||||
|
| ast::StateSpace::Shared => {}
|
||||||
|
ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
|
||||||
|
return Err(error_unreachable())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_st(
|
||||||
|
&self,
|
||||||
|
mut data: ast::StData,
|
||||||
|
mut arguments: ast::StArgs<SpirvWord>,
|
||||||
|
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
||||||
|
if let Some(remap) = self.variables.get(&arguments.src1) {
|
||||||
|
match remap {
|
||||||
|
RemapAction::PreLdPostSt { .. } => {}
|
||||||
|
RemapAction::LDStSpaceChange {
|
||||||
|
old_space,
|
||||||
|
new_space,
|
||||||
|
name,
|
||||||
|
} => {
|
||||||
|
if data.state_space != *old_space {
|
||||||
|
return Err(error_mismatched_type());
|
||||||
|
}
|
||||||
|
data.state_space = *new_space;
|
||||||
|
arguments.src1 = *name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(ast::Instruction::St { data, arguments })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_ld(
|
||||||
|
&self,
|
||||||
|
mut data: ast::LdDetails,
|
||||||
|
mut arguments: ast::LdArgs<SpirvWord>,
|
||||||
|
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
||||||
|
if let Some(remap) = self.variables.get(&arguments.src) {
|
||||||
|
match remap {
|
||||||
|
RemapAction::PreLdPostSt { .. } => {}
|
||||||
|
RemapAction::LDStSpaceChange {
|
||||||
|
old_space,
|
||||||
|
new_space,
|
||||||
|
name,
|
||||||
|
} => {
|
||||||
|
if data.state_space != *old_space {
|
||||||
|
return Err(error_mismatched_type());
|
||||||
|
}
|
||||||
|
data.state_space = *new_space;
|
||||||
|
arguments.src = *name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(ast::Instruction::Ld { data, arguments })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
|
||||||
|
if var.state_space != ast::StateSpace::Local {
|
||||||
|
let old_name = var.name;
|
||||||
|
let old_space = var.state_space;
|
||||||
|
let new_space = ast::StateSpace::Local;
|
||||||
|
let new_name = self
|
||||||
|
.resolver
|
||||||
|
.register_unnamed(Some((var.v_type.clone(), new_space)));
|
||||||
|
self.variable(&var.v_type, old_name, new_name, old_space)?;
|
||||||
|
var.name = new_name;
|
||||||
|
var.state_space = new_space;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
|
||||||
|
for InsertMemSSAVisitor<'a, 'input>
|
||||||
|
{
|
||||||
|
fn visit(
|
||||||
|
&mut self,
|
||||||
|
ident: SpirvWord,
|
||||||
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
relaxed_type_check: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
if let Some(remap) = self.variables.get(&ident) {
|
||||||
|
match remap {
|
||||||
|
RemapAction::PreLdPostSt { name, type_ } => {
|
||||||
|
if is_dst {
|
||||||
|
let temp = self
|
||||||
|
.resolver
|
||||||
|
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
|
||||||
|
self.post.push(ast::Instruction::St {
|
||||||
|
data: ast::StData {
|
||||||
|
state_space: ast::StateSpace::Local,
|
||||||
|
qualifier: ast::LdStQualifier::Weak,
|
||||||
|
caching: ast::StCacheOperator::Writethrough,
|
||||||
|
typ: type_.clone(),
|
||||||
|
},
|
||||||
|
arguments: ast::StArgs {
|
||||||
|
src1: *name,
|
||||||
|
src2: temp,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
Ok(temp)
|
||||||
|
} else {
|
||||||
|
let temp = self
|
||||||
|
.resolver
|
||||||
|
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
|
||||||
|
self.pre.push(ast::Instruction::Ld {
|
||||||
|
data: ast::LdDetails {
|
||||||
|
state_space: ast::StateSpace::Local,
|
||||||
|
qualifier: ast::LdStQualifier::Weak,
|
||||||
|
caching: ast::LdCacheOperator::Cached,
|
||||||
|
typ: type_.clone(),
|
||||||
|
non_coherent: false,
|
||||||
|
},
|
||||||
|
arguments: ast::LdArgs {
|
||||||
|
dst: temp,
|
||||||
|
src: *name,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
Ok(temp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RemapAction::LDStSpaceChange { .. } => {
|
||||||
|
return Err(error_mismatched_type());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Ok(ident)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_ident(
|
||||||
|
&mut self,
|
||||||
|
args: SpirvWord,
|
||||||
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
relaxed_type_check: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
self.visit(args, type_space, is_dst, relaxed_type_check)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
enum RemapAction {
|
||||||
|
PreLdPostSt {
|
||||||
|
name: SpirvWord,
|
||||||
|
type_: ast::Type,
|
||||||
|
},
|
||||||
|
LDStSpaceChange {
|
||||||
|
old_space: ast::StateSpace,
|
||||||
|
new_space: ast::StateSpace,
|
||||||
|
name: SpirvWord,
|
||||||
|
},
|
||||||
|
}
|
@ -45,6 +45,13 @@ pub(super) fn run(
|
|||||||
Statement::RepackVector(repack),
|
Statement::RepackVector(repack),
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
|
Statement::VectorAccess(vector_access) => {
|
||||||
|
insert_implicit_conversions_impl(
|
||||||
|
&mut result,
|
||||||
|
id_def,
|
||||||
|
Statement::VectorAccess(vector_access),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
s @ Statement::Conditional(_)
|
s @ Statement::Conditional(_)
|
||||||
| s @ Statement::Conversion(_)
|
| s @ Statement::Conversion(_)
|
||||||
| s @ Statement::Label(_)
|
| s @ Statement::Label(_)
|
||||||
@ -128,7 +135,7 @@ pub(crate) fn default_implicit_conversion(
|
|||||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
if instruction_space == ast::StateSpace::Reg {
|
if instruction_space == ast::StateSpace::Reg {
|
||||||
if space_is_compatible(operand_space, ast::StateSpace::Reg) {
|
if operand_space == ast::StateSpace::Reg {
|
||||||
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
|
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
|
||||||
(operand_type, instruction_type)
|
(operand_type, instruction_type)
|
||||||
{
|
{
|
||||||
@ -142,7 +149,7 @@ pub(crate) fn default_implicit_conversion(
|
|||||||
return Ok(Some(ConversionKind::AddressOf));
|
return Ok(Some(ConversionKind::AddressOf));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !space_is_compatible(instruction_space, operand_space) {
|
if instruction_space != operand_space {
|
||||||
default_implicit_conversion_space(
|
default_implicit_conversion_space(
|
||||||
(operand_space, operand_type),
|
(operand_space, operand_type),
|
||||||
(instruction_space, instruction_type),
|
(instruction_space, instruction_type),
|
||||||
@ -161,7 +168,7 @@ fn is_addressable(this: ast::StateSpace) -> bool {
|
|||||||
| ast::StateSpace::Global
|
| ast::StateSpace::Global
|
||||||
| ast::StateSpace::Local
|
| ast::StateSpace::Local
|
||||||
| ast::StateSpace::Shared => true,
|
| ast::StateSpace::Shared => true,
|
||||||
ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false,
|
ast::StateSpace::Param | ast::StateSpace::Reg => false,
|
||||||
ast::StateSpace::SharedCluster
|
ast::StateSpace::SharedCluster
|
||||||
| ast::StateSpace::SharedCta
|
| ast::StateSpace::SharedCta
|
||||||
| ast::StateSpace::ParamEntry
|
| ast::StateSpace::ParamEntry
|
||||||
@ -178,7 +185,7 @@ fn default_implicit_conversion_space(
|
|||||||
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
|
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
|
||||||
{
|
{
|
||||||
Ok(Some(ConversionKind::PtrToPtr))
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
} else if space_is_compatible(operand_space, ast::StateSpace::Reg) {
|
} else if operand_space == ast::StateSpace::Reg {
|
||||||
match operand_type {
|
match operand_type {
|
||||||
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
|
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
|
||||||
if *operand_ptr_space == instruction_space =>
|
if *operand_ptr_space == instruction_space =>
|
||||||
@ -210,7 +217,7 @@ fn default_implicit_conversion_space(
|
|||||||
},
|
},
|
||||||
_ => Err(error_mismatched_type()),
|
_ => Err(error_mismatched_type()),
|
||||||
}
|
}
|
||||||
} else if space_is_compatible(instruction_space, ast::StateSpace::Reg) {
|
} else if instruction_space == ast::StateSpace::Reg {
|
||||||
match instruction_type {
|
match instruction_type {
|
||||||
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
|
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
|
||||||
if operand_space == *instruction_ptr_space =>
|
if operand_space == *instruction_ptr_space =>
|
||||||
@ -234,7 +241,7 @@ fn default_implicit_conversion_type(
|
|||||||
operand_type: &ast::Type,
|
operand_type: &ast::Type,
|
||||||
instruction_type: &ast::Type,
|
instruction_type: &ast::Type,
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
if space_is_compatible(space, ast::StateSpace::Reg) {
|
if space == ast::StateSpace::Reg {
|
||||||
if should_bitcast(instruction_type, operand_type) {
|
if should_bitcast(instruction_type, operand_type) {
|
||||||
Ok(Some(ConversionKind::Default))
|
Ok(Some(ConversionKind::Default))
|
||||||
} else {
|
} else {
|
||||||
@ -257,8 +264,7 @@ fn coerces_to_generic(this: ast::StateSpace) -> bool {
|
|||||||
| ast::StateSpace::Param
|
| ast::StateSpace::Param
|
||||||
| ast::StateSpace::ParamEntry
|
| ast::StateSpace::ParamEntry
|
||||||
| ast::StateSpace::ParamFunc
|
| ast::StateSpace::ParamFunc
|
||||||
| ast::StateSpace::Generic
|
| ast::StateSpace::Generic => false,
|
||||||
| ast::StateSpace::Sreg => false,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -294,7 +300,7 @@ pub(crate) fn should_convert_relaxed_dst_wrapper(
|
|||||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
if !space_is_compatible(operand_space, instruction_space) {
|
if operand_space != instruction_space {
|
||||||
return Err(TranslateError::MismatchedType);
|
return Err(TranslateError::MismatchedType);
|
||||||
}
|
}
|
||||||
if operand_type == instruction_type {
|
if operand_type == instruction_type {
|
||||||
@ -371,7 +377,7 @@ pub(crate) fn should_convert_relaxed_src_wrapper(
|
|||||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
if !space_is_compatible(operand_space, instruction_space) {
|
if operand_space != instruction_space {
|
||||||
return Err(error_mismatched_type());
|
return Err(error_mismatched_type());
|
||||||
}
|
}
|
||||||
if operand_type == instruction_type {
|
if operand_type == instruction_type {
|
||||||
|
426
ptx/src/pass/insert_implicit_conversions2.rs
Normal file
426
ptx/src/pass/insert_implicit_conversions2.rs
Normal file
@ -0,0 +1,426 @@
|
|||||||
|
use std::mem;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
/*
|
||||||
|
There are several kinds of implicit conversions in PTX:
|
||||||
|
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
|
||||||
|
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
|
||||||
|
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
|
||||||
|
semantics are to first zext/chop/bitcast `y` as needed and then do
|
||||||
|
documented special ld/st/cvt conversion rules for destination operands
|
||||||
|
- st.param [x] y (used as function return arguments) same rule as above applies
|
||||||
|
- generic/global ld: for instruction `ld x, [y]`, y must be of type
|
||||||
|
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
|
||||||
|
documented special ld/st/cvt conversion rules are applied to dst
|
||||||
|
- generic/global st: for instruction `st [x], y`, x must be of type
|
||||||
|
b64/u64/s64, which is bitcast to a pointer
|
||||||
|
*/
|
||||||
|
pub(super) fn run<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
|
directives
|
||||||
|
.into_iter()
|
||||||
|
.map(|directive| run_directive(resolver, directive))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_directive<'a, 'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
|
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
|
Ok(match directive {
|
||||||
|
var @ Directive2::Variable(..) => var,
|
||||||
|
Directive2::Method(mut method) => {
|
||||||
|
method.body = method
|
||||||
|
.body
|
||||||
|
.map(|statements| run_statements(resolver, statements))
|
||||||
|
.transpose()?;
|
||||||
|
Directive2::Method(method)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_statements<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
func: Vec<ExpandedStatement>,
|
||||||
|
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||||
|
let mut result = Vec::with_capacity(func.len());
|
||||||
|
for s in func.into_iter() {
|
||||||
|
insert_implicit_conversions_impl(resolver, &mut result, s)?;
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_implicit_conversions_impl<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
func: &mut Vec<ExpandedStatement>,
|
||||||
|
stmt: ExpandedStatement,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let mut post_conv = Vec::new();
|
||||||
|
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
|
||||||
|
&mut |operand,
|
||||||
|
type_state: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
is_dst,
|
||||||
|
relaxed_type_check| {
|
||||||
|
let (instr_type, instruction_space) = match type_state {
|
||||||
|
None => return Ok(operand),
|
||||||
|
Some(t) => t,
|
||||||
|
};
|
||||||
|
let (operand_type, operand_space) = resolver.get_typed(operand)?;
|
||||||
|
let conversion_fn = if relaxed_type_check {
|
||||||
|
if is_dst {
|
||||||
|
should_convert_relaxed_dst_wrapper
|
||||||
|
} else {
|
||||||
|
should_convert_relaxed_src_wrapper
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
default_implicit_conversion
|
||||||
|
};
|
||||||
|
match conversion_fn(
|
||||||
|
(*operand_space, &operand_type),
|
||||||
|
(instruction_space, instr_type),
|
||||||
|
)? {
|
||||||
|
Some(conv_kind) => {
|
||||||
|
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
|
||||||
|
let mut from_type = instr_type.clone();
|
||||||
|
let mut from_space = instruction_space;
|
||||||
|
let mut to_type = operand_type.clone();
|
||||||
|
let mut to_space = *operand_space;
|
||||||
|
let mut src =
|
||||||
|
resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
|
||||||
|
let mut dst = operand;
|
||||||
|
let result = Ok::<_, TranslateError>(src);
|
||||||
|
if !is_dst {
|
||||||
|
mem::swap(&mut src, &mut dst);
|
||||||
|
mem::swap(&mut from_type, &mut to_type);
|
||||||
|
mem::swap(&mut from_space, &mut to_space);
|
||||||
|
}
|
||||||
|
conv_output.push(Statement::Conversion(ImplicitConversion {
|
||||||
|
src,
|
||||||
|
dst,
|
||||||
|
from_type,
|
||||||
|
from_space,
|
||||||
|
to_type,
|
||||||
|
to_space,
|
||||||
|
kind: conv_kind,
|
||||||
|
}));
|
||||||
|
result
|
||||||
|
}
|
||||||
|
None => Ok(operand),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
func.push(statement);
|
||||||
|
func.append(&mut post_conv);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn default_implicit_conversion(
|
||||||
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
|
if instruction_space == ast::StateSpace::Reg {
|
||||||
|
if operand_space == ast::StateSpace::Reg {
|
||||||
|
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
|
||||||
|
(operand_type, instruction_type)
|
||||||
|
{
|
||||||
|
if scalar.kind() == ast::ScalarKind::Bit
|
||||||
|
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
|
||||||
|
{
|
||||||
|
return Ok(Some(ConversionKind::Default));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if is_addressable(operand_space) {
|
||||||
|
return Ok(Some(ConversionKind::AddressOf));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if instruction_space != operand_space {
|
||||||
|
default_implicit_conversion_space(
|
||||||
|
(operand_space, operand_type),
|
||||||
|
(instruction_space, instruction_type),
|
||||||
|
)
|
||||||
|
} else if instruction_type != operand_type {
|
||||||
|
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tells whether `this` state space is one for which
/// [`default_implicit_conversion`] emits `ConversionKind::AddressOf`.
///
/// The match is kept exhaustive on purpose so that a new state space forces
/// a decision here.
fn is_addressable(this: ast::StateSpace) -> bool {
    match this {
        ast::StateSpace::Reg | ast::StateSpace::Param => false,
        ast::StateSpace::Global
        | ast::StateSpace::Generic
        | ast::StateSpace::Const
        | ast::StateSpace::Shared
        | ast::StateSpace::Local => true,
        // NOTE(review): these spaces are not handled yet and panic when
        // reached.
        ast::StateSpace::ParamEntry
        | ast::StateSpace::ParamFunc
        | ast::StateSpace::SharedCta
        | ast::StateSpace::SharedCluster => todo!(),
    }
}
|
||||||
|
|
||||||
|
// Space is different
|
||||||
|
fn default_implicit_conversion_space(
|
||||||
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
|
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|
||||||
|
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
|
||||||
|
{
|
||||||
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
|
} else if operand_space == ast::StateSpace::Reg {
|
||||||
|
match operand_type {
|
||||||
|
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
|
||||||
|
if *operand_ptr_space == instruction_space =>
|
||||||
|
{
|
||||||
|
if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
|
||||||
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: 32 bit
|
||||||
|
ast::Type::Scalar(ast::ScalarType::B64)
|
||||||
|
| ast::Type::Scalar(ast::ScalarType::U64)
|
||||||
|
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
|
||||||
|
ast::StateSpace::Global
|
||||||
|
| ast::StateSpace::Generic
|
||||||
|
| ast::StateSpace::Const
|
||||||
|
| ast::StateSpace::Local
|
||||||
|
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
|
||||||
|
_ => Err(error_mismatched_type()),
|
||||||
|
},
|
||||||
|
ast::Type::Scalar(ast::ScalarType::B32)
|
||||||
|
| ast::Type::Scalar(ast::ScalarType::U32)
|
||||||
|
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
|
||||||
|
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
|
||||||
|
Ok(Some(ConversionKind::BitToPtr))
|
||||||
|
}
|
||||||
|
_ => Err(error_mismatched_type()),
|
||||||
|
},
|
||||||
|
_ => Err(error_mismatched_type()),
|
||||||
|
}
|
||||||
|
} else if instruction_space == ast::StateSpace::Reg {
|
||||||
|
match instruction_type {
|
||||||
|
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
|
||||||
|
if operand_space == *instruction_ptr_space =>
|
||||||
|
{
|
||||||
|
if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
|
||||||
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(error_mismatched_type()),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(error_mismatched_type())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Space is same, but type is different
|
||||||
|
fn default_implicit_conversion_type(
|
||||||
|
space: ast::StateSpace,
|
||||||
|
operand_type: &ast::Type,
|
||||||
|
instruction_type: &ast::Type,
|
||||||
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
|
if space == ast::StateSpace::Reg {
|
||||||
|
if should_bitcast(instruction_type, operand_type) {
|
||||||
|
Ok(Some(ConversionKind::Default))
|
||||||
|
} else {
|
||||||
|
Err(TranslateError::MismatchedType)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tells whether a pointer in state space `this` can be converted to/from
/// the generic space by [`default_implicit_conversion_space`].
///
/// Note that `Generic` itself answers `false`. The match is exhaustive so a
/// new state space has to be classified explicitly.
fn coerces_to_generic(this: ast::StateSpace) -> bool {
    match this {
        ast::StateSpace::Generic
        | ast::StateSpace::Reg
        | ast::StateSpace::Param
        | ast::StateSpace::ParamEntry
        | ast::StateSpace::ParamFunc => false,
        ast::StateSpace::Global
        | ast::StateSpace::Const
        | ast::StateSpace::Local
        | ast::StateSpace::Shared
        | ast::StateSpace::SharedCta
        | ast::StateSpace::SharedCluster => true,
    }
}
|
||||||
|
|
||||||
|
/// Decides whether a register operand of type `operand` may be implicitly
/// bit-cast to the instruction type `instr`.
///
/// Scalars are compared directly; vector/vector and array/array pairs are
/// judged by their element types only. Any other pairing is rejected.
///
/// NOTE(review): vector widths and array dimensions are not compared here —
/// confirm callers guarantee they match.
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
    let (inst_scalar, operand_scalar) = match (instr, operand) {
        (ast::Type::Scalar(inst), ast::Type::Scalar(operand))
        | (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
        | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => (*inst, *operand),
        _ => return false,
    };
    if inst_scalar.size_of() != operand_scalar.size_of() {
        return false;
    }
    let operand_kind = operand_scalar.kind();
    match inst_scalar.kind() {
        // Untyped bits accept anything that is itself typed.
        ast::ScalarKind::Bit => operand_kind != ast::ScalarKind::Bit,
        ast::ScalarKind::Float => operand_kind == ast::ScalarKind::Bit,
        ast::ScalarKind::Signed => {
            operand_kind == ast::ScalarKind::Bit || operand_kind == ast::ScalarKind::Unsigned
        }
        ast::ScalarKind::Unsigned => {
            operand_kind == ast::ScalarKind::Bit || operand_kind == ast::ScalarKind::Signed
        }
        ast::ScalarKind::Pred => false,
    }
}
|
||||||
|
|
||||||
|
pub(crate) fn should_convert_relaxed_dst_wrapper(
|
||||||
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
|
if operand_space != instruction_space {
|
||||||
|
return Err(TranslateError::MismatchedType);
|
||||||
|
}
|
||||||
|
if operand_type == instruction_type {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
match should_convert_relaxed_dst(operand_type, instruction_type) {
|
||||||
|
conv @ Some(_) => Ok(conv),
|
||||||
|
None => Err(TranslateError::MismatchedType),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
|
||||||
|
fn should_convert_relaxed_dst(
|
||||||
|
dst_type: &ast::Type,
|
||||||
|
instr_type: &ast::Type,
|
||||||
|
) -> Option<ConversionKind> {
|
||||||
|
if dst_type == instr_type {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
match (dst_type, instr_type) {
|
||||||
|
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||||
|
ast::ScalarKind::Bit => {
|
||||||
|
if instr_type.size_of() <= dst_type.size_of() {
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Signed => {
|
||||||
|
if dst_type.kind() != ast::ScalarKind::Float {
|
||||||
|
if instr_type.size_of() == dst_type.size_of() {
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else if instr_type.size_of() < dst_type.size_of() {
|
||||||
|
Some(ConversionKind::SignExtend)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Unsigned => {
|
||||||
|
if instr_type.size_of() <= dst_type.size_of()
|
||||||
|
&& dst_type.kind() != ast::ScalarKind::Float
|
||||||
|
{
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Float => {
|
||||||
|
if instr_type.size_of() <= dst_type.size_of()
|
||||||
|
&& dst_type.kind() == ast::ScalarKind::Bit
|
||||||
|
{
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Pred => None,
|
||||||
|
},
|
||||||
|
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||||
|
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||||
|
should_convert_relaxed_dst(
|
||||||
|
&ast::Type::Scalar(*dst_type),
|
||||||
|
&ast::Type::Scalar(*instr_type),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn should_convert_relaxed_src_wrapper(
|
||||||
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
|
if operand_space != instruction_space {
|
||||||
|
return Err(error_mismatched_type());
|
||||||
|
}
|
||||||
|
if operand_type == instruction_type {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
match should_convert_relaxed_src(operand_type, instruction_type) {
|
||||||
|
conv @ Some(_) => Ok(conv),
|
||||||
|
None => Err(error_mismatched_type()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
|
||||||
|
fn should_convert_relaxed_src(
|
||||||
|
src_type: &ast::Type,
|
||||||
|
instr_type: &ast::Type,
|
||||||
|
) -> Option<ConversionKind> {
|
||||||
|
if src_type == instr_type {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
match (src_type, instr_type) {
|
||||||
|
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||||
|
ast::ScalarKind::Bit => {
|
||||||
|
if instr_type.size_of() <= src_type.size_of() {
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
|
||||||
|
if instr_type.size_of() <= src_type.size_of()
|
||||||
|
&& src_type.kind() != ast::ScalarKind::Float
|
||||||
|
{
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Float => {
|
||||||
|
if instr_type.size_of() <= src_type.size_of()
|
||||||
|
&& src_type.kind() == ast::ScalarKind::Bit
|
||||||
|
{
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Pred => None,
|
||||||
|
},
|
||||||
|
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||||
|
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||||
|
should_convert_relaxed_src(
|
||||||
|
&ast::Type::Scalar(*dst_type),
|
||||||
|
&ast::Type::Scalar(*instr_type),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
@ -189,7 +189,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
|||||||
return Ok(symbol);
|
return Ok(symbol);
|
||||||
};
|
};
|
||||||
let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
|
let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
|
||||||
if !space_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable {
|
if var_space != ast::StateSpace::Reg || !is_variable {
|
||||||
return Ok(symbol);
|
return Ok(symbol);
|
||||||
};
|
};
|
||||||
let member_index = match member_index {
|
let member_index = match member_index {
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use ptx_parser as ast;
|
use ptx_parser as ast;
|
||||||
use rspirv::{binary::Assemble, dr};
|
use rspirv::{binary::Assemble, dr};
|
||||||
|
use rustc_hash::FxHashMap;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use std::num::NonZeroU8;
|
use std::num::NonZeroU8;
|
||||||
use std::{
|
use std::{
|
||||||
@ -12,20 +13,31 @@ use std::{
|
|||||||
mem,
|
mem,
|
||||||
rc::Rc,
|
rc::Rc,
|
||||||
};
|
};
|
||||||
|
use strum::IntoEnumIterator;
|
||||||
|
use strum_macros::EnumIter;
|
||||||
|
|
||||||
mod convert_dynamic_shared_memory_usage;
|
mod convert_dynamic_shared_memory_usage;
|
||||||
mod convert_to_stateful_memory_access;
|
mod convert_to_stateful_memory_access;
|
||||||
mod convert_to_typed;
|
mod convert_to_typed;
|
||||||
|
mod deparamize_functions;
|
||||||
pub(crate) mod emit_llvm;
|
pub(crate) mod emit_llvm;
|
||||||
mod emit_spirv;
|
mod emit_spirv;
|
||||||
mod expand_arguments;
|
mod expand_arguments;
|
||||||
|
mod expand_operands;
|
||||||
mod extract_globals;
|
mod extract_globals;
|
||||||
mod fix_special_registers;
|
mod fix_special_registers;
|
||||||
|
mod fix_special_registers2;
|
||||||
|
mod hoist_globals;
|
||||||
|
mod insert_explicit_load_store;
|
||||||
mod insert_implicit_conversions;
|
mod insert_implicit_conversions;
|
||||||
|
mod insert_implicit_conversions2;
|
||||||
mod insert_mem_ssa_statements;
|
mod insert_mem_ssa_statements;
|
||||||
mod normalize_identifiers;
|
mod normalize_identifiers;
|
||||||
|
mod normalize_identifiers2;
|
||||||
mod normalize_labels;
|
mod normalize_labels;
|
||||||
mod normalize_predicates;
|
mod normalize_predicates;
|
||||||
|
mod normalize_predicates2;
|
||||||
|
mod resolve_function_pointers;
|
||||||
|
|
||||||
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
||||||
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||||
@ -57,7 +69,30 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
|
|||||||
})?;
|
})?;
|
||||||
normalize_variable_decls(&mut directives);
|
normalize_variable_decls(&mut directives);
|
||||||
let denorm_information = compute_denorm_information(&directives);
|
let denorm_information = compute_denorm_information(&directives);
|
||||||
let llvm_ir = emit_llvm::run(&id_defs, call_map, directives)?;
|
todo!()
|
||||||
|
/*
|
||||||
|
let llvm_ir: emit_llvm::MemoryBuffer = emit_llvm::run(&id_defs, call_map, directives)?;
|
||||||
|
Ok(Module {
|
||||||
|
llvm_ir,
|
||||||
|
kernel_info: HashMap::new(),
|
||||||
|
}) */
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
|
||||||
|
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
|
||||||
|
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
|
||||||
|
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
|
||||||
|
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
|
||||||
|
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
|
||||||
|
let directives = resolve_function_pointers::run(directives)?;
|
||||||
|
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
|
||||||
|
let directives: Vec<Directive2<'_, ptx_parser::Instruction<SpirvWord>, SpirvWord>> =
|
||||||
|
expand_operands::run(&mut flat_resolver, directives)?;
|
||||||
|
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
||||||
|
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
||||||
|
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
|
||||||
|
let directives = hoist_globals::run(directives)?;
|
||||||
|
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
|
||||||
Ok(Module {
|
Ok(Module {
|
||||||
llvm_ir,
|
llvm_ir,
|
||||||
kernel_info: HashMap::new(),
|
kernel_info: HashMap::new(),
|
||||||
@ -319,7 +354,7 @@ pub struct KernelInfo {
|
|||||||
pub uses_shared_mem: bool,
|
pub uses_shared_mem: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
|
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)]
|
||||||
enum PtxSpecialRegister {
|
enum PtxSpecialRegister {
|
||||||
Tid,
|
Tid,
|
||||||
Ntid,
|
Ntid,
|
||||||
@ -342,6 +377,17 @@ impl PtxSpecialRegister {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn as_str(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Tid => "%tid",
|
||||||
|
Self::Ntid => "%ntid",
|
||||||
|
Self::Ctaid => "%ctaid",
|
||||||
|
Self::Nctaid => "%nctaid",
|
||||||
|
Self::Clock => "%clock",
|
||||||
|
Self::LanemaskLt => "%lanemask_lt",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn get_type(self) -> ast::Type {
|
fn get_type(self) -> ast::Type {
|
||||||
match self {
|
match self {
|
||||||
PtxSpecialRegister::Tid
|
PtxSpecialRegister::Tid
|
||||||
@ -525,7 +571,7 @@ impl<'b> NumericIdResolver<'b> {
|
|||||||
Some(Some(x)) => Ok(x.clone()),
|
Some(Some(x)) => Ok(x.clone()),
|
||||||
Some(None) => Err(TranslateError::UntypedSymbol),
|
Some(None) => Err(TranslateError::UntypedSymbol),
|
||||||
None => match self.special_registers.get(id) {
|
None => match self.special_registers.get(id) {
|
||||||
Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
|
Some(x) => Ok((x.get_type(), ast::StateSpace::Reg, true)),
|
||||||
None => match self.global_type_check.get(&id) {
|
None => match self.global_type_check.get(&id) {
|
||||||
Some(Some(result)) => Ok(result.clone()),
|
Some(Some(result)) => Ok(result.clone()),
|
||||||
Some(None) | None => Err(TranslateError::UntypedSymbol),
|
Some(None) | None => Err(TranslateError::UntypedSymbol),
|
||||||
@ -722,6 +768,7 @@ enum Statement<I, P: ast::Operand> {
|
|||||||
PtrAccess(PtrAccess<P>),
|
PtrAccess(PtrAccess<P>),
|
||||||
RepackVector(RepackVectorDetails),
|
RepackVector(RepackVectorDetails),
|
||||||
FunctionPointer(FunctionPointerDetails),
|
FunctionPointer(FunctionPointerDetails),
|
||||||
|
VectorAccess(VectorAccess),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||||
@ -890,6 +937,36 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
|||||||
offset_src,
|
offset_src,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
Statement::VectorAccess(VectorAccess {
|
||||||
|
scalar_type,
|
||||||
|
vector_width,
|
||||||
|
dst,
|
||||||
|
src: vector_src,
|
||||||
|
member,
|
||||||
|
}) => {
|
||||||
|
let dst: SpirvWord = visitor.visit_ident(
|
||||||
|
dst,
|
||||||
|
Some((&scalar_type.into(), ast::StateSpace::Reg)),
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
let src = visitor.visit_ident(
|
||||||
|
vector_src,
|
||||||
|
Some((
|
||||||
|
&ast::Type::Vector(vector_width, scalar_type),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
Statement::VectorAccess(VectorAccess {
|
||||||
|
scalar_type,
|
||||||
|
vector_width,
|
||||||
|
dst,
|
||||||
|
src,
|
||||||
|
member,
|
||||||
|
})
|
||||||
|
}
|
||||||
Statement::RepackVector(RepackVectorDetails {
|
Statement::RepackVector(RepackVectorDetails {
|
||||||
is_extract,
|
is_extract,
|
||||||
typ,
|
typ,
|
||||||
@ -1207,12 +1284,6 @@ impl<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Two state spaces are compatible when they are equal, or when one is
/// `Reg` and the other is `Sreg` (in either order).
fn space_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
    if this == other {
        return true;
    }
    match (this, other) {
        (ast::StateSpace::Reg, ast::StateSpace::Sreg)
        | (ast::StateSpace::Sreg, ast::StateSpace::Reg) => true,
        _ => false,
    }
}
|
|
||||||
|
|
||||||
fn register_external_fn_call<'a>(
|
fn register_external_fn_call<'a>(
|
||||||
id_defs: &mut NumericIdResolver,
|
id_defs: &mut NumericIdResolver,
|
||||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||||
@ -1450,6 +1521,7 @@ fn compute_denorm_information<'input>(
|
|||||||
Statement::Label(_) => {}
|
Statement::Label(_) => {}
|
||||||
Statement::Variable(_) => {}
|
Statement::Variable(_) => {}
|
||||||
Statement::PtrAccess { .. } => {}
|
Statement::PtrAccess { .. } => {}
|
||||||
|
Statement::VectorAccess { .. } => {}
|
||||||
Statement::RepackVector(_) => {}
|
Statement::RepackVector(_) => {}
|
||||||
Statement::FunctionPointer(_) => {}
|
Statement::FunctionPointer(_) => {}
|
||||||
}
|
}
|
||||||
@ -1663,3 +1735,278 @@ fn denorm_count_map_update_impl<T: Eq + Hash>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> {
|
||||||
|
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
|
||||||
|
Method(Function2<'input, Instruction, Operand>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> {
|
||||||
|
pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
|
||||||
|
pub globals: Vec<ast::Variable<SpirvWord>>,
|
||||||
|
pub body: Option<Vec<Statement<Instruction, Operand>>>,
|
||||||
|
import_as: Option<String>,
|
||||||
|
tuning: Vec<ast::TuningDirective>,
|
||||||
|
linkage: ast::LinkingDirective,
|
||||||
|
}
|
||||||
|
|
||||||
|
type NormalizedDirective2<'input> = Directive2<
|
||||||
|
'input,
|
||||||
|
(
|
||||||
|
Option<ast::PredAt<SpirvWord>>,
|
||||||
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
|
),
|
||||||
|
ast::ParsedOperand<SpirvWord>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
type NormalizedFunction2<'input> = Function2<
|
||||||
|
'input,
|
||||||
|
(
|
||||||
|
Option<ast::PredAt<SpirvWord>>,
|
||||||
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
|
),
|
||||||
|
ast::ParsedOperand<SpirvWord>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
type UnconditionalDirective<'input> = Directive2<
|
||||||
|
'input,
|
||||||
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
|
ast::ParsedOperand<SpirvWord>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
type UnconditionalFunction<'input> = Function2<
|
||||||
|
'input,
|
||||||
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
|
ast::ParsedOperand<SpirvWord>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
struct GlobalStringIdentResolver2<'input> {
|
||||||
|
pub(crate) current_id: SpirvWord,
|
||||||
|
pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'input> GlobalStringIdentResolver2<'input> {
|
||||||
|
fn new(spirv_word: SpirvWord) -> Self {
|
||||||
|
Self {
|
||||||
|
current_id: spirv_word,
|
||||||
|
ident_map: FxHashMap::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_named(
|
||||||
|
&mut self,
|
||||||
|
name: Cow<'input, str>,
|
||||||
|
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||||
|
) -> SpirvWord {
|
||||||
|
let new_id = self.current_id;
|
||||||
|
self.ident_map.insert(
|
||||||
|
new_id,
|
||||||
|
IdentEntry {
|
||||||
|
name: Some(name),
|
||||||
|
type_space,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
self.current_id.0 += 1;
|
||||||
|
new_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
|
||||||
|
let new_id = self.current_id;
|
||||||
|
self.ident_map.insert(
|
||||||
|
new_id,
|
||||||
|
IdentEntry {
|
||||||
|
name: None,
|
||||||
|
type_space,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
self.current_id.0 += 1;
|
||||||
|
new_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> {
|
||||||
|
match self.ident_map.get(&id) {
|
||||||
|
Some(IdentEntry {
|
||||||
|
type_space: Some(type_space),
|
||||||
|
..
|
||||||
|
}) => Ok(type_space),
|
||||||
|
_ => Err(error_unknown_symbol()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct IdentEntry<'input> {
|
||||||
|
name: Option<Cow<'input, str>>,
|
||||||
|
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ScopedResolver<'input, 'b> {
|
||||||
|
flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
|
||||||
|
scopes: Vec<ScopeMarker<'input>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'input, 'b> ScopedResolver<'input, 'b> {
|
||||||
|
fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
|
||||||
|
Self {
|
||||||
|
flat_resolver,
|
||||||
|
scopes: vec![ScopeMarker::new()],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_scope(&mut self) {
|
||||||
|
self.scopes.push(ScopeMarker::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn end_scope(&mut self) {
|
||||||
|
let scope = self.scopes.pop().unwrap();
|
||||||
|
scope.flush(self.flat_resolver);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add(
|
||||||
|
&mut self,
|
||||||
|
name: Cow<'input, str>,
|
||||||
|
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
let result = self.flat_resolver.current_id;
|
||||||
|
self.flat_resolver.current_id.0 += 1;
|
||||||
|
let current_scope = self.scopes.last_mut().unwrap();
|
||||||
|
if current_scope
|
||||||
|
.name_to_ident
|
||||||
|
.insert(name.clone(), result)
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
return Err(error_unknown_symbol());
|
||||||
|
}
|
||||||
|
current_scope.ident_map.insert(
|
||||||
|
result,
|
||||||
|
IdentEntry {
|
||||||
|
name: Some(name),
|
||||||
|
type_space,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
|
||||||
|
self.scopes
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.find_map(|resolver| resolver.name_to_ident.get(name).copied())
|
||||||
|
.ok_or_else(|| error_unreachable())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
|
||||||
|
let current_scope = self.scopes.last().unwrap();
|
||||||
|
current_scope
|
||||||
|
.name_to_ident
|
||||||
|
.get(label)
|
||||||
|
.copied()
|
||||||
|
.ok_or_else(|| error_unreachable())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ScopeMarker<'input> {
|
||||||
|
ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
|
||||||
|
name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'input> ScopeMarker<'input> {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
ident_map: FxHashMap::default(),
|
||||||
|
name_to_ident: FxHashMap::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
|
||||||
|
resolver.ident_map.extend(self.ident_map);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SpecialRegistersMap2 {
|
||||||
|
reg_to_id: FxHashMap<PtxSpecialRegister, SpirvWord>,
|
||||||
|
id_to_reg: FxHashMap<SpirvWord, PtxSpecialRegister>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SpecialRegistersMap2 {
|
||||||
|
fn new(resolver: &mut ScopedResolver) -> Result<Self, TranslateError> {
|
||||||
|
let mut result = SpecialRegistersMap2 {
|
||||||
|
reg_to_id: FxHashMap::default(),
|
||||||
|
id_to_reg: FxHashMap::default(),
|
||||||
|
};
|
||||||
|
for sreg in PtxSpecialRegister::iter() {
|
||||||
|
let text = sreg.as_str();
|
||||||
|
let id = resolver.add(
|
||||||
|
Cow::Borrowed(text),
|
||||||
|
Some((sreg.get_type(), ast::StateSpace::Reg)),
|
||||||
|
)?;
|
||||||
|
result.reg_to_id.insert(sreg, id);
|
||||||
|
result.id_to_reg.insert(id, sreg);
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
|
||||||
|
self.id_to_reg.get(&id).copied()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
|
||||||
|
match self.reg_to_id.entry(reg) {
|
||||||
|
hash_map::Entry::Occupied(e) => *e.get(),
|
||||||
|
hash_map::Entry::Vacant(e) => {
|
||||||
|
let numeric_id = SpirvWord(current_id.0);
|
||||||
|
current_id.0 += 1;
|
||||||
|
e.insert(numeric_id);
|
||||||
|
self.id_to_reg.insert(numeric_id, reg);
|
||||||
|
numeric_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_declarations<'a, 'input>(
|
||||||
|
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||||
|
) -> impl ExactSizeIterator<
|
||||||
|
Item = (
|
||||||
|
PtxSpecialRegister,
|
||||||
|
ast::MethodDeclaration<'input, SpirvWord>,
|
||||||
|
),
|
||||||
|
> + 'a {
|
||||||
|
PtxSpecialRegister::iter().map(|sreg| {
|
||||||
|
let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
|
||||||
|
let name =
|
||||||
|
ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
|
||||||
|
let return_type = sreg.get_function_return_type();
|
||||||
|
let input_type = sreg.get_function_return_type();
|
||||||
|
(
|
||||||
|
sreg,
|
||||||
|
ast::MethodDeclaration {
|
||||||
|
return_arguments: vec![ast::Variable {
|
||||||
|
align: None,
|
||||||
|
v_type: return_type.into(),
|
||||||
|
state_space: ast::StateSpace::Reg,
|
||||||
|
name: resolver
|
||||||
|
.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
|
||||||
|
array_init: Vec::new(),
|
||||||
|
}],
|
||||||
|
name: name,
|
||||||
|
input_arguments: vec![ast::Variable {
|
||||||
|
align: None,
|
||||||
|
v_type: input_type.into(),
|
||||||
|
state_space: ast::StateSpace::Reg,
|
||||||
|
name: resolver
|
||||||
|
.register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))),
|
||||||
|
array_init: Vec::new(),
|
||||||
|
}],
|
||||||
|
shared_mem: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct VectorAccess {
|
||||||
|
scalar_type: ast::ScalarType,
|
||||||
|
vector_width: u8,
|
||||||
|
dst: SpirvWord,
|
||||||
|
src: SpirvWord,
|
||||||
|
member: u8,
|
||||||
|
}
|
||||||
|
199
ptx/src/pass/normalize_identifiers2.rs
Normal file
199
ptx/src/pass/normalize_identifiers2.rs
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
use rustc_hash::FxHashMap;
|
||||||
|
|
||||||
|
pub(crate) fn run<'input, 'b>(
|
||||||
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
|
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
||||||
|
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
|
||||||
|
resolver.start_scope();
|
||||||
|
let result = directives
|
||||||
|
.into_iter()
|
||||||
|
.map(|directive| run_directive(resolver, directive))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
resolver.end_scope();
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_directive<'input, 'b>(
|
||||||
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
|
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
||||||
|
) -> Result<NormalizedDirective2<'input>, TranslateError> {
|
||||||
|
Ok(match directive {
|
||||||
|
ast::Directive::Variable(linking, var) => {
|
||||||
|
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
|
||||||
|
}
|
||||||
|
ast::Directive::Method(linking, directive) => {
|
||||||
|
NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_method<'input, 'b>(
|
||||||
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
|
linkage: ast::LinkingDirective,
|
||||||
|
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||||
|
) -> Result<NormalizedFunction2<'input>, TranslateError> {
|
||||||
|
let name = match method.func_directive.name {
|
||||||
|
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
|
||||||
|
ast::MethodName::Func(text) => {
|
||||||
|
ast::MethodName::Func(resolver.add(Cow::Borrowed(text), None)?)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
resolver.start_scope();
|
||||||
|
let func_decl = run_function_decl(resolver, method.func_directive, name)?;
|
||||||
|
let body = method
|
||||||
|
.body
|
||||||
|
.map(|statements| {
|
||||||
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
|
run_statements(resolver, &mut result, statements)?;
|
||||||
|
Ok::<_, TranslateError>(result)
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
resolver.end_scope();
|
||||||
|
Ok(Function2 {
|
||||||
|
func_decl,
|
||||||
|
globals: Vec::new(),
|
||||||
|
body,
|
||||||
|
import_as: None,
|
||||||
|
tuning: method.tuning,
|
||||||
|
linkage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_function_decl<'input, 'b>(
|
||||||
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
|
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
||||||
|
name: ast::MethodName<'input, SpirvWord>,
|
||||||
|
) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
|
||||||
|
assert!(func_directive.shared_mem.is_none());
|
||||||
|
let return_arguments = func_directive
|
||||||
|
.return_arguments
|
||||||
|
.into_iter()
|
||||||
|
.map(|var| run_variable(resolver, var))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
let input_arguments = func_directive
|
||||||
|
.input_arguments
|
||||||
|
.into_iter()
|
||||||
|
.map(|var| run_variable(resolver, var))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
Ok(ast::MethodDeclaration {
|
||||||
|
return_arguments,
|
||||||
|
name,
|
||||||
|
input_arguments,
|
||||||
|
shared_mem: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_variable<'input, 'b>(
|
||||||
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
|
variable: ast::Variable<&'input str>,
|
||||||
|
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
|
||||||
|
Ok(ast::Variable {
|
||||||
|
name: resolver.add(
|
||||||
|
Cow::Borrowed(variable.name),
|
||||||
|
Some((variable.v_type.clone(), variable.state_space)),
|
||||||
|
)?,
|
||||||
|
align: variable.align,
|
||||||
|
v_type: variable.v_type,
|
||||||
|
state_space: variable.state_space,
|
||||||
|
array_init: variable.array_init,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_statements<'input, 'b>(
|
||||||
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
|
result: &mut Vec<NormalizedStatement>,
|
||||||
|
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
for statement in statements.iter() {
|
||||||
|
match statement {
|
||||||
|
ast::Statement::Label(label) => {
|
||||||
|
resolver.add(Cow::Borrowed(*label), None)?;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for statement in statements {
|
||||||
|
match statement {
|
||||||
|
ast::Statement::Label(label) => {
|
||||||
|
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
|
||||||
|
}
|
||||||
|
ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
|
||||||
|
ast::Statement::Instruction(predicate, instruction) => {
|
||||||
|
result.push(Statement::Instruction((
|
||||||
|
predicate
|
||||||
|
.map(|pred| {
|
||||||
|
Ok::<_, TranslateError>(ast::PredAt {
|
||||||
|
not: pred.not,
|
||||||
|
label: resolver.get(pred.label)?,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.transpose()?,
|
||||||
|
run_instruction(resolver, instruction)?,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
ast::Statement::Block(block) => {
|
||||||
|
resolver.start_scope();
|
||||||
|
run_statements(resolver, result, block)?;
|
||||||
|
resolver.end_scope();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_instruction<'input, 'b>(
|
||||||
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
|
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
|
||||||
|
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
||||||
|
ast::visit_map(instruction, &mut |name: &'input str,
|
||||||
|
_: Option<(
|
||||||
|
&ast::Type,
|
||||||
|
ast::StateSpace,
|
||||||
|
)>,
|
||||||
|
_,
|
||||||
|
_| {
|
||||||
|
resolver.get(&name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_multivariable<'input, 'b>(
|
||||||
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
|
result: &mut Vec<NormalizedStatement>,
|
||||||
|
variable: ast::MultiVariable<&'input str>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
match variable.count {
|
||||||
|
Some(count) => {
|
||||||
|
for i in 0..count {
|
||||||
|
let name = Cow::Owned(format!("{}{}", variable.var.name, i));
|
||||||
|
let ident = resolver.add(
|
||||||
|
name,
|
||||||
|
Some((variable.var.v_type.clone(), variable.var.state_space)),
|
||||||
|
)?;
|
||||||
|
result.push(Statement::Variable(ast::Variable {
|
||||||
|
align: variable.var.align,
|
||||||
|
v_type: variable.var.v_type.clone(),
|
||||||
|
state_space: variable.var.state_space,
|
||||||
|
name: ident,
|
||||||
|
array_init: variable.var.array_init.clone(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
let name = Cow::Borrowed(variable.var.name);
|
||||||
|
let ident = resolver.add(
|
||||||
|
name,
|
||||||
|
Some((variable.var.v_type.clone(), variable.var.state_space)),
|
||||||
|
)?;
|
||||||
|
result.push(Statement::Variable(ast::Variable {
|
||||||
|
align: variable.var.align,
|
||||||
|
v_type: variable.var.v_type.clone(),
|
||||||
|
state_space: variable.var.state_space,
|
||||||
|
name: ident,
|
||||||
|
array_init: variable.var.array_init.clone(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -26,6 +26,7 @@ pub(super) fn run(
|
|||||||
| Statement::Constant(..)
|
| Statement::Constant(..)
|
||||||
| Statement::Label(..)
|
| Statement::Label(..)
|
||||||
| Statement::PtrAccess { .. }
|
| Statement::PtrAccess { .. }
|
||||||
|
| Statement::VectorAccess { .. }
|
||||||
| Statement::RepackVector(..)
|
| Statement::RepackVector(..)
|
||||||
| Statement::FunctionPointer(..) => {}
|
| Statement::FunctionPointer(..) => {}
|
||||||
}
|
}
|
||||||
|
84
ptx/src/pass/normalize_predicates2.rs
Normal file
84
ptx/src/pass/normalize_predicates2.rs
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
pub(crate) fn run<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
directives: Vec<NormalizedDirective2<'input>>,
|
||||||
|
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
||||||
|
directives
|
||||||
|
.into_iter()
|
||||||
|
.map(|directive| run_directive(resolver, directive))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_directive<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
directive: NormalizedDirective2<'input>,
|
||||||
|
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
||||||
|
Ok(match directive {
|
||||||
|
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||||
|
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_method<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
method: NormalizedFunction2<'input>,
|
||||||
|
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
||||||
|
let body = method
|
||||||
|
.body
|
||||||
|
.map(|statements| {
|
||||||
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
|
for statement in statements {
|
||||||
|
run_statement(resolver, &mut result, statement)?;
|
||||||
|
}
|
||||||
|
Ok::<_, TranslateError>(result)
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
Ok(Function2 {
|
||||||
|
func_decl: method.func_decl,
|
||||||
|
globals: method.globals,
|
||||||
|
body,
|
||||||
|
import_as: method.import_as,
|
||||||
|
tuning: method.tuning,
|
||||||
|
linkage: method.linkage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_statement<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
result: &mut Vec<UnconditionalStatement>,
|
||||||
|
statement: NormalizedStatement,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
Ok(match statement {
|
||||||
|
Statement::Label(label) => result.push(Statement::Label(label)),
|
||||||
|
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
||||||
|
Statement::Instruction((predicate, instruction)) => {
|
||||||
|
if let Some(pred) = predicate {
|
||||||
|
let if_true = resolver.register_unnamed(None);
|
||||||
|
let if_false = resolver.register_unnamed(None);
|
||||||
|
let folded_bra = match &instruction {
|
||||||
|
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
let mut branch = BrachCondition {
|
||||||
|
predicate: pred.label,
|
||||||
|
if_true: folded_bra.unwrap_or(if_true),
|
||||||
|
if_false,
|
||||||
|
};
|
||||||
|
if pred.not {
|
||||||
|
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||||
|
}
|
||||||
|
result.push(Statement::Conditional(branch));
|
||||||
|
if folded_bra.is_none() {
|
||||||
|
result.push(Statement::Label(if_true));
|
||||||
|
result.push(Statement::Instruction(instruction));
|
||||||
|
}
|
||||||
|
result.push(Statement::Label(if_false));
|
||||||
|
} else {
|
||||||
|
result.push(Statement::Instruction(instruction));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
})
|
||||||
|
}
|
82
ptx/src/pass/resolve_function_pointers.rs
Normal file
82
ptx/src/pass/resolve_function_pointers.rs
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
use rustc_hash::FxHashSet;
|
||||||
|
|
||||||
|
pub(crate) fn run<'input>(
|
||||||
|
directives: Vec<UnconditionalDirective<'input>>,
|
||||||
|
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
||||||
|
let mut functions = FxHashSet::default();
|
||||||
|
directives
|
||||||
|
.into_iter()
|
||||||
|
.map(|directive| run_directive(&mut functions, directive))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_directive<'input>(
|
||||||
|
functions: &mut FxHashSet<SpirvWord>,
|
||||||
|
directive: UnconditionalDirective<'input>,
|
||||||
|
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
||||||
|
Ok(match directive {
|
||||||
|
var @ Directive2::Variable(..) => var,
|
||||||
|
Directive2::Method(method) => {
|
||||||
|
{
|
||||||
|
let func_decl = &method.func_decl;
|
||||||
|
match func_decl.name {
|
||||||
|
ptx_parser::MethodName::Kernel(_) => {}
|
||||||
|
ptx_parser::MethodName::Func(name) => {
|
||||||
|
functions.insert(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Directive2::Method(run_method(functions, method)?)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_method<'input>(
|
||||||
|
functions: &mut FxHashSet<SpirvWord>,
|
||||||
|
method: UnconditionalFunction<'input>,
|
||||||
|
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
||||||
|
let body = method
|
||||||
|
.body
|
||||||
|
.map(|statements| {
|
||||||
|
statements
|
||||||
|
.into_iter()
|
||||||
|
.map(|statement| run_statement(functions, statement))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
Ok(Function2 {
|
||||||
|
func_decl: method.func_decl,
|
||||||
|
globals: method.globals,
|
||||||
|
body,
|
||||||
|
import_as: method.import_as,
|
||||||
|
tuning: method.tuning,
|
||||||
|
linkage: method.linkage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_statement<'input>(
|
||||||
|
functions: &mut FxHashSet<SpirvWord>,
|
||||||
|
statement: UnconditionalStatement,
|
||||||
|
) -> Result<UnconditionalStatement, TranslateError> {
|
||||||
|
Ok(match statement {
|
||||||
|
Statement::Instruction(ast::Instruction::Mov {
|
||||||
|
data,
|
||||||
|
arguments:
|
||||||
|
ast::MovArgs {
|
||||||
|
dst: ast::ParsedOperand::Reg(dst_reg),
|
||||||
|
src: ast::ParsedOperand::Reg(src_reg),
|
||||||
|
},
|
||||||
|
}) if functions.contains(&src_reg) => {
|
||||||
|
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
|
||||||
|
return Err(error_mismatched_type());
|
||||||
|
}
|
||||||
|
UnconditionalStatement::FunctionPointer(FunctionPointerDetails {
|
||||||
|
dst: dst_reg,
|
||||||
|
src: src_reg,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
s => s,
|
||||||
|
})
|
||||||
|
}
|
@ -236,7 +236,7 @@ fn test_hip_assert<
|
|||||||
output: &mut [Output],
|
output: &mut [Output],
|
||||||
) -> Result<(), Box<dyn error::Error + 'a>> {
|
) -> Result<(), Box<dyn error::Error + 'a>> {
|
||||||
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
|
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
|
||||||
let llvm_ir = pass::to_llvm_module(ast).unwrap();
|
let llvm_ir = pass::to_llvm_module2(ast).unwrap();
|
||||||
let name = CString::new(name)?;
|
let name = CString::new(name)?;
|
||||||
let result =
|
let result =
|
||||||
run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?;
|
run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?;
|
||||||
|
@ -1049,6 +1049,15 @@ impl<'input, ID> MethodName<'input, ID> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'input> MethodName<'input, &'input str> {
|
||||||
|
pub fn text(&self) -> &'input str {
|
||||||
|
match self {
|
||||||
|
MethodName::Kernel(name) => *name,
|
||||||
|
MethodName::Func(name) => *name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bitflags! {
|
bitflags! {
|
||||||
pub struct LinkingDirective: u8 {
|
pub struct LinkingDirective: u8 {
|
||||||
const NONE = 0b000;
|
const NONE = 0b000;
|
||||||
@ -1291,7 +1300,12 @@ impl<T: Operand> CallArgs<T> {
|
|||||||
.iter()
|
.iter()
|
||||||
.zip(details.return_arguments.iter())
|
.zip(details.return_arguments.iter())
|
||||||
{
|
{
|
||||||
visitor.visit_ident(param, Some((type_, *space)), true, false)?;
|
visitor.visit_ident(
|
||||||
|
param,
|
||||||
|
Some((type_, *space)),
|
||||||
|
*space == StateSpace::Reg,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
}
|
}
|
||||||
visitor.visit_ident(&self.func, None, false, false)?;
|
visitor.visit_ident(&self.func, None, false, false)?;
|
||||||
for (param, (type_, space)) in self
|
for (param, (type_, space)) in self
|
||||||
@ -1315,7 +1329,12 @@ impl<T: Operand> CallArgs<T> {
|
|||||||
.iter_mut()
|
.iter_mut()
|
||||||
.zip(details.return_arguments.iter())
|
.zip(details.return_arguments.iter())
|
||||||
{
|
{
|
||||||
visitor.visit_ident(param, Some((type_, *space)), true, false)?;
|
visitor.visit_ident(
|
||||||
|
param,
|
||||||
|
Some((type_, *space)),
|
||||||
|
*space == StateSpace::Reg,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
}
|
}
|
||||||
visitor.visit_ident(&mut self.func, None, false, false)?;
|
visitor.visit_ident(&mut self.func, None, false, false)?;
|
||||||
for (param, (type_, space)) in self
|
for (param, (type_, space)) in self
|
||||||
@ -1339,7 +1358,12 @@ impl<T: Operand> CallArgs<T> {
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(details.return_arguments.iter())
|
.zip(details.return_arguments.iter())
|
||||||
.map(|(param, (type_, space))| {
|
.map(|(param, (type_, space))| {
|
||||||
visitor.visit_ident(param, Some((type_, *space)), true, false)
|
visitor.visit_ident(
|
||||||
|
param,
|
||||||
|
Some((type_, *space)),
|
||||||
|
*space == StateSpace::Reg,
|
||||||
|
false,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
let func = visitor.visit_ident(self.func, None, false, false)?;
|
let func = visitor.visit_ident(self.func, None, false, false)?;
|
||||||
|
@ -1499,7 +1499,6 @@ derive_parser!(
|
|||||||
pub enum StateSpace {
|
pub enum StateSpace {
|
||||||
Reg,
|
Reg,
|
||||||
Generic,
|
Generic,
|
||||||
Sreg,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
|
Reference in New Issue
Block a user