mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-12 10:48:53 +03:00
Refactor compilation passes (#270)
The overarching goal is to refactor all passes so they are module-scoped and not function-scoped. Additionally, make improvements to the most egregiously buggy/unfit passes (so the code is ready for the next major features: linking, ftz handling) and continue adding more code to the LLVM backend
This commit is contained in:
60
.github/workflows/rust.yml
vendored
60
.github/workflows/rust.yml
vendored
@ -1,60 +0,0 @@
|
||||
name: Rust
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master ]
|
||||
pull_request:
|
||||
branches: [ master ]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
build_lin:
|
||||
name: Build and publish (Linux)
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: true
|
||||
- name: Install GPU drivers
|
||||
run: |
|
||||
sudo apt-get install -y gpg-agent wget
|
||||
wget -qO - https://repositories.intel.com/graphics/intel-graphics.key | sudo apt-key add -
|
||||
sudo apt-add-repository 'deb [arch=amd64] https://repositories.intel.com/graphics/ubuntu focal main'
|
||||
sudo apt-get update
|
||||
sudo apt-get install intel-opencl-icd intel-level-zero-gpu level-zero intel-media-va-driver-non-free libmfx1 libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev ocl-icd-opencl-dev
|
||||
- name: Build
|
||||
run: cargo build --workspace --verbose --release
|
||||
- name: Rename to libcuda.so
|
||||
run: |
|
||||
mv target/release/libnvcuda.so target/release/libcuda.so
|
||||
ln -s libcuda.so target/release/libcuda.so.1
|
||||
- uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: Linux
|
||||
path: |
|
||||
target/release/libcuda.so
|
||||
target/release/libcuda.so.1
|
||||
target/release/libnvml.so
|
||||
build_win:
|
||||
name: Build and publish (Windows)
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: true
|
||||
- name: Build
|
||||
run: cargo build --workspace --verbose --release
|
||||
- uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: Windows
|
||||
path: |
|
||||
target/release/nvcuda.dll
|
||||
target/release/nvml.dll
|
||||
target/release/zluda_redirect.dll
|
||||
target/release/zluda_with.exe
|
||||
target/release/zluda_dump.dll
|
||||
# TODO(take-cheeze): Support testing
|
||||
# - name: Run tests
|
||||
# run: cargo test --verbose
|
@ -17,6 +17,9 @@ thiserror = "1.0"
|
||||
bit-vec = "0.6"
|
||||
half ="1.6"
|
||||
bitflags = "1.2"
|
||||
rustc-hash = "2.0.0"
|
||||
strum = "0.26"
|
||||
strum_macros = "0.26"
|
||||
|
||||
[dependencies.lalrpop-util]
|
||||
version = "0.19.12"
|
||||
|
@ -489,7 +489,7 @@ fn convert_to_stateful_memory_access_postprocess(
|
||||
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?;
|
||||
let converting_id = id_defs
|
||||
.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
|
||||
let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) {
|
||||
let kind = if new_operand_space == ast::StateSpace::Reg {
|
||||
ConversionKind::Default
|
||||
} else {
|
||||
ConversionKind::PtrToPtr
|
||||
|
141
ptx/src/pass/deparamize_functions.rs
Normal file
141
ptx/src/pass/deparamize_functions.rs
Normal file
@ -0,0 +1,141 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
if method.func_decl.name.is_kernel() {
|
||||
return Ok(method);
|
||||
}
|
||||
let is_declaration = method.body.is_none();
|
||||
let mut body = Vec::new();
|
||||
let mut remap_returns = Vec::new();
|
||||
for arg in method.func_decl.return_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = arg.name;
|
||||
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||
if is_declaration {
|
||||
continue;
|
||||
}
|
||||
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: old_name,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ptx_parser::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
}
|
||||
ptx_parser::StateSpace::Reg => {}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
for arg in method.func_decl.input_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = arg.name;
|
||||
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||
if is_declaration {
|
||||
continue;
|
||||
}
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: old_name,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ptx_parser::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
body.push(Statement::Instruction(ast::Instruction::St {
|
||||
data: ast::StData {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::StCacheOperator::Writethrough,
|
||||
typ: arg.v_type.clone(),
|
||||
},
|
||||
arguments: ast::StArgs {
|
||||
src1: old_name,
|
||||
src2: arg.name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
ptx_parser::StateSpace::Reg => {}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
if remap_returns.is_empty() {
|
||||
return Ok(method);
|
||||
}
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
for statement in statements {
|
||||
run_statement(&remap_returns, &mut body, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(body)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
||||
for (old_name, new_name, type_) in remap_returns.iter().cloned() {
|
||||
result.push(Statement::Instruction(ast::Instruction::Ld {
|
||||
data: ast::LdDetails {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Reg,
|
||||
caching: ast::LdCacheOperator::Cached,
|
||||
typ: type_,
|
||||
non_coherent: false,
|
||||
},
|
||||
arguments: ast::LdArgs {
|
||||
dst: new_name,
|
||||
src: old_name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
result.push(statement);
|
||||
}
|
||||
statement => {
|
||||
result.push(statement);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -164,17 +164,16 @@ impl Deref for MemoryBuffer {
|
||||
}
|
||||
|
||||
pub(super) fn run<'input>(
|
||||
id_defs: &GlobalStringIdResolver<'input>,
|
||||
call_map: MethodsCallMap<'input>,
|
||||
directives: Vec<Directive<'input>>,
|
||||
id_defs: GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<MemoryBuffer, TranslateError> {
|
||||
let context = Context::new();
|
||||
let module = Module::new(&context, LLVM_UNNAMED);
|
||||
let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs);
|
||||
let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
|
||||
for directive in directives {
|
||||
match directive {
|
||||
Directive::Variable(..) => todo!(),
|
||||
Directive::Method(method) => emit_ctx.emit_method(method)?,
|
||||
Directive2::Variable(..) => todo!(),
|
||||
Directive2::Method(method) => emit_ctx.emit_method(method)?,
|
||||
}
|
||||
}
|
||||
module.write_to_stderr();
|
||||
@ -188,7 +187,7 @@ struct ModuleEmitContext<'a, 'input> {
|
||||
context: LLVMContextRef,
|
||||
module: LLVMModuleRef,
|
||||
builder: Builder,
|
||||
id_defs: &'a GlobalStringIdResolver<'input>,
|
||||
id_defs: &'a GlobalStringIdentResolver2<'input>,
|
||||
resolver: ResolveIdent,
|
||||
}
|
||||
|
||||
@ -196,7 +195,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||
fn new(
|
||||
context: &Context,
|
||||
module: &Module,
|
||||
id_defs: &'a GlobalStringIdResolver<'input>,
|
||||
id_defs: &'a GlobalStringIdentResolver2<'input>,
|
||||
) -> Self {
|
||||
ModuleEmitContext {
|
||||
context: context.get(),
|
||||
@ -215,26 +214,50 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||
LLVMCallConv::LLVMCCallConv as u32
|
||||
}
|
||||
|
||||
fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> {
|
||||
let func_decl = method.func_decl.borrow();
|
||||
fn emit_method(
|
||||
&mut self,
|
||||
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let func_decl = method.func_decl;
|
||||
let name = method
|
||||
.import_as
|
||||
.as_deref()
|
||||
.unwrap_or_else(|| match func_decl.name {
|
||||
ast::MethodName::Kernel(name) => name,
|
||||
ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
|
||||
});
|
||||
.or_else(|| match func_decl.name {
|
||||
ast::MethodName::Kernel(name) => Some(name),
|
||||
ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
|
||||
})
|
||||
.ok_or_else(|| error_unreachable())?;
|
||||
let name = CString::new(name).map_err(|_| error_unreachable())?;
|
||||
let fn_type = self.function_type(
|
||||
let fn_type = get_function_type(
|
||||
self.context,
|
||||
func_decl.return_arguments.iter().map(|v| &v.v_type),
|
||||
func_decl.input_arguments.iter().map(|v| &v.v_type),
|
||||
);
|
||||
func_decl
|
||||
.input_arguments
|
||||
.iter()
|
||||
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
|
||||
)?;
|
||||
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
|
||||
if let ast::MethodName::Func(name) = func_decl.name {
|
||||
self.resolver.register(name, fn_);
|
||||
}
|
||||
for (i, param) in func_decl.input_arguments.iter().enumerate() {
|
||||
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
||||
let name = self.resolver.get_or_add(param.name);
|
||||
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
|
||||
self.resolver.register(param.name, value);
|
||||
if func_decl.name.is_kernel() {
|
||||
let attr_kind = unsafe {
|
||||
LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
|
||||
};
|
||||
let attr = unsafe {
|
||||
LLVMCreateTypeAttribute(
|
||||
self.context,
|
||||
attr_kind,
|
||||
get_type(self.context, ¶m.v_type)?,
|
||||
)
|
||||
};
|
||||
unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
|
||||
}
|
||||
}
|
||||
let call_conv = if func_decl.name.is_kernel() {
|
||||
Self::kernel_call_convention()
|
||||
@ -258,66 +281,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn function_type(
|
||||
&self,
|
||||
return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||
input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||
) -> LLVMTypeRef {
|
||||
if return_args.len() == 0 {
|
||||
let mut input_args = input_args
|
||||
.map(|type_| match type_ {
|
||||
ast::Type::Scalar(scalar) => match scalar {
|
||||
ast::ScalarType::Pred => {
|
||||
unsafe { LLVMInt1TypeInContext(self.context) }
|
||||
fn get_input_argument_type(
|
||||
context: LLVMContextRef,
|
||||
v_type: &ptx_parser::Type,
|
||||
state_space: ptx_parser::StateSpace,
|
||||
) -> Result<LLVMTypeRef, TranslateError> {
|
||||
match state_space {
|
||||
ptx_parser::StateSpace::ParamEntry => {
|
||||
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
|
||||
}
|
||||
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
|
||||
unsafe { LLVMInt8TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
|
||||
unsafe { LLVMInt16TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
|
||||
unsafe { LLVMInt32TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
|
||||
unsafe { LLVMInt64TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::B128 => {
|
||||
unsafe { LLVMInt128TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::F16 => {
|
||||
unsafe { LLVMHalfTypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::F32 => {
|
||||
unsafe { LLVMFloatTypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::F64 => {
|
||||
unsafe { LLVMDoubleTypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::BF16 => {
|
||||
unsafe { LLVMBFloatTypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::U16x2 => todo!(),
|
||||
ast::ScalarType::S16x2 => todo!(),
|
||||
ast::ScalarType::F16x2 => todo!(),
|
||||
ast::ScalarType::BF16x2 => todo!(),
|
||||
},
|
||||
ast::Type::Vector(_, _) => todo!(),
|
||||
ast::Type::Array(_, _, _) => todo!(),
|
||||
ast::Type::Pointer(_, _) => todo!(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
return unsafe {
|
||||
LLVMFunctionType(
|
||||
LLVMVoidTypeInContext(self.context),
|
||||
input_args.as_mut_ptr(),
|
||||
input_args.len() as u32,
|
||||
0,
|
||||
)
|
||||
};
|
||||
}
|
||||
todo!()
|
||||
ptx_parser::StateSpace::Reg => get_type(context, v_type),
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -326,7 +302,7 @@ struct MethodEmitContext<'a, 'input> {
|
||||
module: LLVMModuleRef,
|
||||
method: LLVMValueRef,
|
||||
builder: LLVMBuilderRef,
|
||||
id_defs: &'a GlobalStringIdResolver<'input>,
|
||||
id_defs: &'a GlobalStringIdentResolver2<'input>,
|
||||
variables_builder: Builder,
|
||||
resolver: &'a mut ResolveIdent,
|
||||
}
|
||||
@ -365,6 +341,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
||||
Statement::PtrAccess(_) => todo!(),
|
||||
Statement::RepackVector(_) => todo!(),
|
||||
Statement::FunctionPointer(_) => todo!(),
|
||||
Statement::VectorAccess(_) => todo!(),
|
||||
})
|
||||
}
|
||||
|
||||
@ -414,7 +391,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
||||
inst: ast::Instruction<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match inst {
|
||||
ast::Instruction::Mov { data, arguments } => todo!(),
|
||||
ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments),
|
||||
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
|
||||
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
|
||||
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
|
||||
@ -425,7 +402,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
||||
ast::Instruction::Or { data, arguments } => todo!(),
|
||||
ast::Instruction::And { data, arguments } => todo!(),
|
||||
ast::Instruction::Bra { arguments } => todo!(),
|
||||
ast::Instruction::Call { data, arguments } => todo!(),
|
||||
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
|
||||
ast::Instruction::Cvt { data, arguments } => todo!(),
|
||||
ast::Instruction::Shr { data, arguments } => todo!(),
|
||||
ast::Instruction::Shl { data, arguments } => todo!(),
|
||||
@ -563,6 +540,70 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
||||
fn emit_ret(&self, _data: ptx_parser::RetData) {
|
||||
unsafe { LLVMBuildRetVoid(self.builder) };
|
||||
}
|
||||
|
||||
fn emit_call(
|
||||
&mut self,
|
||||
data: ptx_parser::CallDetails,
|
||||
arguments: ptx_parser::CallArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
if cfg!(debug_assertions) {
|
||||
for (_, space) in data.return_arguments.iter() {
|
||||
if *space != ast::StateSpace::Reg {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
for (_, space) in data.input_arguments.iter() {
|
||||
if *space != ast::StateSpace::Reg {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
}
|
||||
let name = match (&*data.return_arguments, &*arguments.return_arguments) {
|
||||
([], []) => LLVM_UNNAMED.as_ptr(),
|
||||
([(type_, _)], [dst]) => self.resolver.get_or_add_raw(*dst),
|
||||
_ => todo!(),
|
||||
};
|
||||
let type_ = get_function_type(
|
||||
self.context,
|
||||
data.return_arguments.iter().map(|(type_, space)| type_),
|
||||
data.input_arguments
|
||||
.iter()
|
||||
.map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)),
|
||||
)?;
|
||||
let mut input_arguments = arguments
|
||||
.input_arguments
|
||||
.iter()
|
||||
.map(|arg| self.resolver.value(*arg))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let llvm_fn = unsafe {
|
||||
LLVMBuildCall2(
|
||||
self.builder,
|
||||
type_,
|
||||
self.resolver.value(arguments.func)?,
|
||||
input_arguments.as_mut_ptr(),
|
||||
input_arguments.len() as u32,
|
||||
name,
|
||||
)
|
||||
};
|
||||
match &*arguments.return_arguments {
|
||||
[] => {}
|
||||
[name] => {
|
||||
self.resolver.register(*name, llvm_fn);
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mov(
|
||||
&mut self,
|
||||
_data: ptx_parser::MovDetails,
|
||||
arguments: ptx_parser::MovArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
self.resolver
|
||||
.register(arguments.dst, self.resolver.value(arguments.src)?);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pointer_type<'ctx>(
|
||||
@ -624,13 +665,34 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
|
||||
}
|
||||
}
|
||||
|
||||
fn get_function_type<'a>(
|
||||
context: LLVMContextRef,
|
||||
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||
input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
|
||||
) -> Result<LLVMTypeRef, TranslateError> {
|
||||
let mut input_args: Vec<*mut llvm_zluda::LLVMType> =
|
||||
input_args.collect::<Result<Vec<_>, _>>()?;
|
||||
let return_type = match return_args.len() {
|
||||
0 => unsafe { LLVMVoidTypeInContext(context) },
|
||||
1 => get_type(context, return_args.next().unwrap())?,
|
||||
_ => todo!(),
|
||||
};
|
||||
Ok(unsafe {
|
||||
LLVMFunctionType(
|
||||
return_type,
|
||||
input_args.as_mut_ptr(),
|
||||
input_args.len() as u32,
|
||||
0,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
|
||||
match space {
|
||||
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
|
||||
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
|
||||
ast::StateSpace::Sreg => Ok(PRIVATE_ADDRESS_SPACE),
|
||||
ast::StateSpace::Param => Err(TranslateError::Todo),
|
||||
ast::StateSpace::ParamEntry => Err(TranslateError::Todo),
|
||||
ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
|
||||
ast::StateSpace::ParamFunc => Err(TranslateError::Todo),
|
||||
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
|
||||
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
|
||||
@ -647,7 +709,7 @@ struct ResolveIdent {
|
||||
}
|
||||
|
||||
impl ResolveIdent {
|
||||
fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self {
|
||||
fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self {
|
||||
ResolveIdent {
|
||||
words: HashMap::new(),
|
||||
values: HashMap::new(),
|
||||
|
@ -469,7 +469,6 @@ fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass {
|
||||
ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
|
||||
ast::StateSpace::Param => spirv::StorageClass::Function,
|
||||
ast::StateSpace::Reg => spirv::StorageClass::Function,
|
||||
ast::StateSpace::Sreg => spirv::StorageClass::Input,
|
||||
ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc
|
||||
| ast::StateSpace::SharedCluster
|
||||
@ -693,7 +692,6 @@ fn emit_variable<'input>(
|
||||
ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
|
||||
ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant),
|
||||
ast::StateSpace::Generic => todo!(),
|
||||
ast::StateSpace::Sreg => todo!(),
|
||||
ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc
|
||||
| ast::StateSpace::SharedCluster
|
||||
@ -1563,6 +1561,7 @@ fn emit_function_body_ops<'input>(
|
||||
builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
|
||||
}
|
||||
}
|
||||
Statement::VectorAccess(vector_access) => todo!(),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
@ -63,9 +63,9 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
||||
} else {
|
||||
return Err(TranslateError::UntypedSymbol);
|
||||
};
|
||||
if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg {
|
||||
if state_space == ast::StateSpace::Reg {
|
||||
let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
|
||||
if !space_is_compatible(reg_space, ast::StateSpace::Reg) {
|
||||
if reg_space != ast::StateSpace::Reg {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let reg_scalar_type = match reg_type {
|
||||
|
289
ptx/src/pass/expand_operands.rs
Normal file
289
ptx/src/pass/expand_operands.rs
Normal file
@ -0,0 +1,289 @@
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<UnconditionalDirective<'input>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<
|
||||
'input,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
method: Function2<
|
||||
'input,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>,
|
||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(resolver, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
statement: UnconditionalStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut visitor = FlattenArguments::new(resolver, result);
|
||||
let new_statement = statement.visit_map(&mut visitor)?;
|
||||
visitor.result.push(new_statement);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct FlattenArguments<'a, 'input> {
|
||||
result: &'a mut Vec<ExpandedStatement>,
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
post_stmts: Vec<ExpandedStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'input> FlattenArguments<'a, 'input> {
|
||||
fn new(
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
result: &'a mut Vec<ExpandedStatement>,
|
||||
) -> Self {
|
||||
FlattenArguments {
|
||||
result,
|
||||
resolver,
|
||||
post_stmts: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
|
||||
Ok(name)
|
||||
}
|
||||
|
||||
fn reg_offset(
|
||||
&mut self,
|
||||
reg: SpirvWord,
|
||||
offset: i32,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
|
||||
(type_, state_space)
|
||||
} else {
|
||||
return Err(TranslateError::UntypedSymbol);
|
||||
};
|
||||
if state_space == ast::StateSpace::Reg {
|
||||
let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
|
||||
if *reg_space != ast::StateSpace::Reg {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let reg_scalar_type = match reg_type {
|
||||
ast::Type::Scalar(underlying_type) => *underlying_type,
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let reg_type = reg_type.clone();
|
||||
let id_constant_stmt = self
|
||||
.resolver
|
||||
.register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: reg_scalar_type,
|
||||
value: ast::ImmediateValue::S64(offset as i64),
|
||||
}));
|
||||
let arith_details = match reg_scalar_type.kind() {
|
||||
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: reg_scalar_type,
|
||||
saturate: false,
|
||||
}),
|
||||
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: reg_scalar_type,
|
||||
saturate: false,
|
||||
})
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let id_add_result = self
|
||||
.resolver
|
||||
.register_unnamed(Some((reg_type, state_space)));
|
||||
self.result
|
||||
.push(Statement::Instruction(ast::Instruction::Add {
|
||||
data: arith_details,
|
||||
arguments: ast::AddArgs {
|
||||
dst: id_add_result,
|
||||
src1: reg,
|
||||
src2: id_constant_stmt,
|
||||
},
|
||||
}));
|
||||
Ok(id_add_result)
|
||||
} else {
|
||||
let id_constant_stmt = self.resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::S64),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: ast::ScalarType::S64,
|
||||
value: ast::ImmediateValue::S64(offset as i64),
|
||||
}));
|
||||
let dst = self
|
||||
.resolver
|
||||
.register_unnamed(Some((type_.clone(), state_space)));
|
||||
self.result.push(Statement::PtrAccess(PtrAccess {
|
||||
underlying_type: type_.clone(),
|
||||
state_space: state_space,
|
||||
dst,
|
||||
ptr_src: reg,
|
||||
offset_src: id_constant_stmt,
|
||||
}));
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
fn immediate(
|
||||
&mut self,
|
||||
value: ast::ImmediateValue,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (scalar_t, state_space) =
|
||||
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
|
||||
(*scalar, state_space)
|
||||
} else {
|
||||
return Err(TranslateError::UntypedSymbol);
|
||||
};
|
||||
let id = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id,
|
||||
typ: scalar_t,
|
||||
value,
|
||||
}));
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
fn vec_member(
|
||||
&mut self,
|
||||
vector_src: SpirvWord,
|
||||
member: u8,
|
||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
if is_dst {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? {
|
||||
(ast::Type::Vector(vector_width, scalar_t), space) => {
|
||||
(*vector_width, *scalar_t, *space)
|
||||
}
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let temporary = self
|
||||
.resolver
|
||||
.register_unnamed(Some((scalar_type.into(), space)));
|
||||
self.result.push(Statement::VectorAccess(VectorAccess {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
dst: temporary,
|
||||
src: vector_src,
|
||||
member: member,
|
||||
}));
|
||||
Ok(temporary)
|
||||
}
|
||||
|
||||
fn vec_pack(
|
||||
&mut self,
|
||||
vecs: Vec<SpirvWord>,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (scalar_t, state_space) = match type_space {
|
||||
Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space),
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let temp_vec = self
|
||||
.resolver
|
||||
.register_unnamed(Some((scalar_t.into(), state_space)));
|
||||
let statement = Statement::RepackVector(RepackVectorDetails {
|
||||
is_extract: is_dst,
|
||||
typ: scalar_t,
|
||||
packed: temp_vec,
|
||||
unpacked: vecs,
|
||||
relaxed_type_check,
|
||||
});
|
||||
if is_dst {
|
||||
self.post_stmts.push(statement);
|
||||
} else {
|
||||
self.result.push(statement);
|
||||
}
|
||||
Ok(temp_vec)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
|
||||
for FlattenArguments<'a, 'b>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: ast::ParsedOperand<SpirvWord>,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
match args {
|
||||
ast::ParsedOperand::Reg(r) => self.reg(r),
|
||||
ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
|
||||
ast::ParsedOperand::RegOffset(reg, offset) => {
|
||||
self.reg_offset(reg, offset, type_space, is_dst)
|
||||
}
|
||||
ast::ParsedOperand::VecMember(vec, member) => {
|
||||
self.vec_member(vec, member, type_space, is_dst)
|
||||
}
|
||||
ast::ParsedOperand::VecPack(vecs) => {
|
||||
self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
name: <TypedOperand as ast::Operand>::Ident,
|
||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
|
||||
self.reg(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for FlattenArguments<'_, '_> {
|
||||
fn drop(&mut self) {
|
||||
self.result.extend(self.post_stmts.drain(..));
|
||||
}
|
||||
}
|
@ -273,7 +273,6 @@ fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
|
||||
ast::StateSpace::Const => "const",
|
||||
ast::StateSpace::Local => "local",
|
||||
ast::StateSpace::Param => "param",
|
||||
ast::StateSpace::Sreg => "sreg",
|
||||
ast::StateSpace::SharedCluster => "shared_cluster",
|
||||
ast::StateSpace::ParamEntry => "param_entry",
|
||||
ast::StateSpace::SharedCta => "shared_cta",
|
||||
|
209
ptx/src/pass/fix_special_registers2.rs
Normal file
209
ptx/src/pass/fix_special_registers2.rs
Normal file
@ -0,0 +1,209 @@
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
special_registers: &'a SpecialRegistersMap2,
|
||||
directives: Vec<UnconditionalDirective<'input>>,
|
||||
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
||||
let declarations = SpecialRegistersMap2::generate_declarations(resolver);
|
||||
let mut result = Vec::with_capacity(declarations.len() + directives.len());
|
||||
let mut sreg_to_function =
|
||||
FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default());
|
||||
for (sreg, declaration) in declarations {
|
||||
let name = if let ast::MethodName::Func(name) = declaration.name {
|
||||
name
|
||||
} else {
|
||||
return Err(error_unreachable());
|
||||
};
|
||||
result.push(UnconditionalDirective::Method(UnconditionalFunction {
|
||||
func_decl: declaration,
|
||||
globals: Vec::new(),
|
||||
body: None,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
}));
|
||||
sreg_to_function.insert(sreg, name);
|
||||
}
|
||||
let mut visitor = SpecialRegisterResolver {
|
||||
resolver,
|
||||
special_registers,
|
||||
sreg_to_function,
|
||||
result: Vec::new(),
|
||||
};
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(&mut visitor, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
directive: UnconditionalDirective<'input>,
|
||||
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
method: UnconditionalFunction<'input>,
|
||||
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(visitor, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
result: &mut Vec<UnconditionalStatement>,
|
||||
statement: UnconditionalStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let converted_statement = statement.visit_map(visitor)?;
|
||||
result.extend(visitor.result.drain(..));
|
||||
result.push(converted_statement);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct SpecialRegisterResolver<'a, 'input> {
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
special_registers: &'a SpecialRegistersMap2,
|
||||
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
|
||||
result: Vec<UnconditionalStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'input>
|
||||
ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
|
||||
for SpecialRegisterResolver<'a, 'input>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
operand: ast::ParsedOperand<SpirvWord>,
|
||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
|
||||
map_operand(operand, &mut |ident, vector_index| {
|
||||
self.replace_sreg(ident, vector_index, is_dst)
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: SpirvWord,
|
||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
self.replace_sreg(args, None, is_dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
|
||||
fn replace_sreg(
|
||||
&mut self,
|
||||
name: SpirvWord,
|
||||
vector_index: Option<u8>,
|
||||
is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
if let Some(sreg) = self.special_registers.get(name) {
|
||||
if is_dst {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
|
||||
(Some(idx), Some(inp_type)) => {
|
||||
if inp_type != ast::ScalarType::U8 {
|
||||
return Err(TranslateError::Unreachable);
|
||||
}
|
||||
let constant = self.resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(inp_type),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: constant,
|
||||
typ: inp_type,
|
||||
value: ast::ImmediateValue::U64(idx as u64),
|
||||
}));
|
||||
vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
|
||||
}
|
||||
(None, None) => Vec::new(),
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let return_type = sreg.get_function_return_type();
|
||||
let fn_result = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
|
||||
let return_arguments = vec![(
|
||||
fn_result,
|
||||
ast::Type::Scalar(return_type),
|
||||
ast::StateSpace::Reg,
|
||||
)];
|
||||
let data = ast::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: return_arguments
|
||||
.iter()
|
||||
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||
.collect(),
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||
.collect(),
|
||||
};
|
||||
let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
|
||||
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
|
||||
func: self.sreg_to_function[&sreg],
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
|
||||
.collect(),
|
||||
};
|
||||
self.result
|
||||
.push(Statement::Instruction(ast::Instruction::Call {
|
||||
data,
|
||||
arguments,
|
||||
}));
|
||||
Ok(fn_result)
|
||||
} else {
|
||||
Ok(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_operand<T, U, Err>(
|
||||
this: ast::ParsedOperand<T>,
|
||||
fn_: &mut impl FnMut(T, Option<u8>) -> Result<U, Err>,
|
||||
) -> Result<ast::ParsedOperand<U>, Err> {
|
||||
Ok(match this {
|
||||
ast::ParsedOperand::Reg(ident) => ast::ParsedOperand::Reg(fn_(ident, None)?),
|
||||
ast::ParsedOperand::RegOffset(ident, offset) => {
|
||||
ast::ParsedOperand::RegOffset(fn_(ident, None)?, offset)
|
||||
}
|
||||
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
|
||||
ast::ParsedOperand::VecMember(ident, member) => {
|
||||
ast::ParsedOperand::Reg(fn_(ident, Some(member))?)
|
||||
}
|
||||
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
|
||||
idents
|
||||
.into_iter()
|
||||
.map(|ident| fn_(ident, None))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
),
|
||||
})
|
||||
}
|
45
ptx/src/pass/hoist_globals.rs
Normal file
45
ptx/src/pass/hoist_globals.rs
Normal file
@ -0,0 +1,45 @@
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run<'input>(
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(directives.len());
|
||||
for mut directive in directives.into_iter() {
|
||||
run_directive(&mut result, &mut directive);
|
||||
result.push(directive);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
|
||||
directive: &mut Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match directive {
|
||||
Directive2::Variable(..) => {}
|
||||
Directive2::Method(function2) => run_function(result, function2),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_function<'input>(
|
||||
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
|
||||
function: &mut Function2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
|
||||
) {
|
||||
function.body = function.body.take().map(|statements| {
|
||||
statements
|
||||
.into_iter()
|
||||
.filter_map(|statement| match statement {
|
||||
Statement::Variable(var @ ast::Variable {
|
||||
state_space:
|
||||
ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared,
|
||||
..
|
||||
}) => {
|
||||
result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
|
||||
None
|
||||
}
|
||||
s => Some(s),
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
}
|
338
ptx/src/pass/insert_explicit_load_store.rs
Normal file
338
ptx/src/pass/insert_explicit_load_store.rs
Normal file
@ -0,0 +1,338 @@
|
||||
use super::*;
|
||||
use ptx_parser::VisitorMap;
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
// This pass:
|
||||
// * Turns all .local, .param and .reg in-body variables into .local variables
|
||||
// (if _not_ an input method argument)
|
||||
// * Inserts explicit `ld`/`st` for newly converted .reg variables
|
||||
// * Fixup state space of all existing `ld`/`st` instructions into newly
|
||||
// converted variables
|
||||
// * Turns `.entry` input arguments into param::entry and all related `.param`
|
||||
// loads into `param::entry` loads
|
||||
// * All `.func` input arguments are turned into `.reg` arguments by another
|
||||
// pass, so we do nothing there
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => {
|
||||
let visitor = InsertMemSSAVisitor::new(resolver);
|
||||
Directive2::Method(run_method(visitor, method)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'a, 'input>(
|
||||
mut visitor: InsertMemSSAVisitor<'a, 'input>,
|
||||
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let mut func_decl = method.func_decl;
|
||||
for arg in func_decl.return_arguments.iter_mut() {
|
||||
visitor.visit_variable(arg)?;
|
||||
}
|
||||
let is_kernel = func_decl.name.is_kernel();
|
||||
if is_kernel {
|
||||
for arg in func_decl.input_arguments.iter_mut() {
|
||||
let old_name = arg.name;
|
||||
let old_space = arg.state_space;
|
||||
let new_space = ast::StateSpace::ParamEntry;
|
||||
let new_name = visitor
|
||||
.resolver
|
||||
.register_unnamed(Some((arg.v_type.clone(), new_space)));
|
||||
visitor.input_argument(old_name, new_name, old_space);
|
||||
arg.name = new_name;
|
||||
arg.state_space = new_space;
|
||||
}
|
||||
};
|
||||
let body = method
|
||||
.body
|
||||
.map(move |statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(&mut visitor, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: func_decl,
|
||||
globals: method.globals,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'a, 'input>(
|
||||
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
|
||||
result: &mut Vec<ExpandedStatement>,
|
||||
statement: ExpandedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
match statement {
|
||||
Statement::Variable(mut var) => {
|
||||
visitor.visit_variable(&mut var)?;
|
||||
result.push(Statement::Variable(var));
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
|
||||
let instruction = visitor.visit_ld(data, arguments)?;
|
||||
let instruction = ast::visit_map(instruction, visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(Statement::Instruction(instruction));
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::St { data, arguments }) => {
|
||||
let instruction = visitor.visit_st(data, arguments)?;
|
||||
let instruction = ast::visit_map(instruction, visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(Statement::Instruction(instruction));
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
s => {
|
||||
let new_statement = s.visit_map(visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(new_statement);
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct InsertMemSSAVisitor<'a, 'input> {
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
variables: FxHashMap<SpirvWord, RemapAction>,
|
||||
pre: Vec<ast::Instruction<SpirvWord>>,
|
||||
post: Vec<ast::Instruction<SpirvWord>>,
|
||||
}
|
||||
|
||||
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||
fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
|
||||
Self {
|
||||
resolver,
|
||||
variables: FxHashMap::default(),
|
||||
pre: Vec::new(),
|
||||
post: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn input_argument(
|
||||
&mut self,
|
||||
old_name: SpirvWord,
|
||||
new_name: SpirvWord,
|
||||
old_space: ast::StateSpace,
|
||||
) -> Result<(), TranslateError> {
|
||||
if old_space != ast::StateSpace::Param {
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
self.variables.insert(
|
||||
old_name,
|
||||
RemapAction::LDStSpaceChange {
|
||||
name: new_name,
|
||||
old_space,
|
||||
new_space: ast::StateSpace::ParamEntry,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn variable(
|
||||
&mut self,
|
||||
type_: &ast::Type,
|
||||
old_name: SpirvWord,
|
||||
new_name: SpirvWord,
|
||||
old_space: ast::StateSpace,
|
||||
) -> Result<(), TranslateError> {
|
||||
Ok(match old_space {
|
||||
ast::StateSpace::Reg => {
|
||||
self.variables.insert(
|
||||
old_name,
|
||||
RemapAction::PreLdPostSt {
|
||||
name: new_name,
|
||||
type_: type_.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
ast::StateSpace::Param => {
|
||||
self.variables.insert(
|
||||
old_name,
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space: ast::StateSpace::Local,
|
||||
name: new_name,
|
||||
},
|
||||
);
|
||||
}
|
||||
// Good as-is
|
||||
ast::StateSpace::Local => {}
|
||||
// Will be pulled into global scope later
|
||||
ast::StateSpace::Generic
|
||||
| ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::Global
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::SharedCta
|
||||
| ast::StateSpace::Shared => {}
|
||||
ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
|
||||
return Err(error_unreachable())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_st(
|
||||
&self,
|
||||
mut data: ast::StData,
|
||||
mut arguments: ast::StArgs<SpirvWord>,
|
||||
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
||||
if let Some(remap) = self.variables.get(&arguments.src1) {
|
||||
match remap {
|
||||
RemapAction::PreLdPostSt { .. } => {}
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space,
|
||||
name,
|
||||
} => {
|
||||
if data.state_space != *old_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
data.state_space = *new_space;
|
||||
arguments.src1 = *name;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(ast::Instruction::St { data, arguments })
|
||||
}
|
||||
|
||||
fn visit_ld(
|
||||
&self,
|
||||
mut data: ast::LdDetails,
|
||||
mut arguments: ast::LdArgs<SpirvWord>,
|
||||
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
|
||||
if let Some(remap) = self.variables.get(&arguments.src) {
|
||||
match remap {
|
||||
RemapAction::PreLdPostSt { .. } => {}
|
||||
RemapAction::LDStSpaceChange {
|
||||
old_space,
|
||||
new_space,
|
||||
name,
|
||||
} => {
|
||||
if data.state_space != *old_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
data.state_space = *new_space;
|
||||
arguments.src = *name;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(ast::Instruction::Ld { data, arguments })
|
||||
}
|
||||
|
||||
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
|
||||
if var.state_space != ast::StateSpace::Local {
|
||||
let old_name = var.name;
|
||||
let old_space = var.state_space;
|
||||
let new_space = ast::StateSpace::Local;
|
||||
let new_name = self
|
||||
.resolver
|
||||
.register_unnamed(Some((var.v_type.clone(), new_space)));
|
||||
self.variable(&var.v_type, old_name, new_name, old_space)?;
|
||||
var.name = new_name;
|
||||
var.state_space = new_space;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
|
||||
for InsertMemSSAVisitor<'a, 'input>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
ident: SpirvWord,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
if let Some(remap) = self.variables.get(&ident) {
|
||||
match remap {
|
||||
RemapAction::PreLdPostSt { name, type_ } => {
|
||||
if is_dst {
|
||||
let temp = self
|
||||
.resolver
|
||||
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
|
||||
self.post.push(ast::Instruction::St {
|
||||
data: ast::StData {
|
||||
state_space: ast::StateSpace::Local,
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
caching: ast::StCacheOperator::Writethrough,
|
||||
typ: type_.clone(),
|
||||
},
|
||||
arguments: ast::StArgs {
|
||||
src1: *name,
|
||||
src2: temp,
|
||||
},
|
||||
});
|
||||
Ok(temp)
|
||||
} else {
|
||||
let temp = self
|
||||
.resolver
|
||||
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
|
||||
self.pre.push(ast::Instruction::Ld {
|
||||
data: ast::LdDetails {
|
||||
state_space: ast::StateSpace::Local,
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
caching: ast::LdCacheOperator::Cached,
|
||||
typ: type_.clone(),
|
||||
non_coherent: false,
|
||||
},
|
||||
arguments: ast::LdArgs {
|
||||
dst: temp,
|
||||
src: *name,
|
||||
},
|
||||
});
|
||||
Ok(temp)
|
||||
}
|
||||
}
|
||||
RemapAction::LDStSpaceChange { .. } => {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(ident)
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: SpirvWord,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
self.visit(args, type_space, is_dst, relaxed_type_check)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum RemapAction {
|
||||
PreLdPostSt {
|
||||
name: SpirvWord,
|
||||
type_: ast::Type,
|
||||
},
|
||||
LDStSpaceChange {
|
||||
old_space: ast::StateSpace,
|
||||
new_space: ast::StateSpace,
|
||||
name: SpirvWord,
|
||||
},
|
||||
}
|
@ -45,6 +45,13 @@ pub(super) fn run(
|
||||
Statement::RepackVector(repack),
|
||||
)?;
|
||||
}
|
||||
Statement::VectorAccess(vector_access) => {
|
||||
insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
Statement::VectorAccess(vector_access),
|
||||
)?;
|
||||
}
|
||||
s @ Statement::Conditional(_)
|
||||
| s @ Statement::Conversion(_)
|
||||
| s @ Statement::Label(_)
|
||||
@ -128,7 +135,7 @@ pub(crate) fn default_implicit_conversion(
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if instruction_space == ast::StateSpace::Reg {
|
||||
if space_is_compatible(operand_space, ast::StateSpace::Reg) {
|
||||
if operand_space == ast::StateSpace::Reg {
|
||||
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
|
||||
(operand_type, instruction_type)
|
||||
{
|
||||
@ -142,7 +149,7 @@ pub(crate) fn default_implicit_conversion(
|
||||
return Ok(Some(ConversionKind::AddressOf));
|
||||
}
|
||||
}
|
||||
if !space_is_compatible(instruction_space, operand_space) {
|
||||
if instruction_space != operand_space {
|
||||
default_implicit_conversion_space(
|
||||
(operand_space, operand_type),
|
||||
(instruction_space, instruction_type),
|
||||
@ -161,7 +168,7 @@ fn is_addressable(this: ast::StateSpace) -> bool {
|
||||
| ast::StateSpace::Global
|
||||
| ast::StateSpace::Local
|
||||
| ast::StateSpace::Shared => true,
|
||||
ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false,
|
||||
ast::StateSpace::Param | ast::StateSpace::Reg => false,
|
||||
ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::SharedCta
|
||||
| ast::StateSpace::ParamEntry
|
||||
@ -178,7 +185,7 @@ fn default_implicit_conversion_space(
|
||||
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
|
||||
{
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
} else if space_is_compatible(operand_space, ast::StateSpace::Reg) {
|
||||
} else if operand_space == ast::StateSpace::Reg {
|
||||
match operand_type {
|
||||
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
|
||||
if *operand_ptr_space == instruction_space =>
|
||||
@ -210,7 +217,7 @@ fn default_implicit_conversion_space(
|
||||
},
|
||||
_ => Err(error_mismatched_type()),
|
||||
}
|
||||
} else if space_is_compatible(instruction_space, ast::StateSpace::Reg) {
|
||||
} else if instruction_space == ast::StateSpace::Reg {
|
||||
match instruction_type {
|
||||
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
|
||||
if operand_space == *instruction_ptr_space =>
|
||||
@ -234,7 +241,7 @@ fn default_implicit_conversion_type(
|
||||
operand_type: &ast::Type,
|
||||
instruction_type: &ast::Type,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if space_is_compatible(space, ast::StateSpace::Reg) {
|
||||
if space == ast::StateSpace::Reg {
|
||||
if should_bitcast(instruction_type, operand_type) {
|
||||
Ok(Some(ConversionKind::Default))
|
||||
} else {
|
||||
@ -257,8 +264,7 @@ fn coerces_to_generic(this: ast::StateSpace) -> bool {
|
||||
| ast::StateSpace::Param
|
||||
| ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Sreg => false,
|
||||
| ast::StateSpace::Generic => false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -294,7 +300,7 @@ pub(crate) fn should_convert_relaxed_dst_wrapper(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if !space_is_compatible(operand_space, instruction_space) {
|
||||
if operand_space != instruction_space {
|
||||
return Err(TranslateError::MismatchedType);
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
@ -371,7 +377,7 @@ pub(crate) fn should_convert_relaxed_src_wrapper(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if !space_is_compatible(operand_space, instruction_space) {
|
||||
if operand_space != instruction_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
|
426
ptx/src/pass/insert_implicit_conversions2.rs
Normal file
426
ptx/src/pass/insert_implicit_conversions2.rs
Normal file
@ -0,0 +1,426 @@
|
||||
use std::mem;
|
||||
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
/*
|
||||
There are several kinds of implicit conversions in PTX:
|
||||
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
|
||||
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
|
||||
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
|
||||
semantics are to first zext/chop/bitcast `y` as needed and then do
|
||||
documented special ld/st/cvt conversion rules for destination operands
|
||||
- st.param [x] y (used as function return arguments) same rule as above applies
|
||||
- generic/global ld: for instruction `ld x, [y]`, y must be of type
|
||||
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
|
||||
documented special ld/st/cvt conversion rules are applied to dst
|
||||
- generic/global st: for instruction `st [x], y`, x must be of type
|
||||
b64/u64/s64, which is bitcast to a pointer
|
||||
*/
|
||||
pub(super) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(mut method) => {
|
||||
method.body = method
|
||||
.body
|
||||
.map(|statements| run_statements(resolver, statements))
|
||||
.transpose()?;
|
||||
Directive2::Method(method)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statements<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
func: Vec<ExpandedStatement>,
|
||||
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
for s in func.into_iter() {
|
||||
insert_implicit_conversions_impl(resolver, &mut result, s)?;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn insert_implicit_conversions_impl<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
stmt: ExpandedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut post_conv = Vec::new();
|
||||
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
|
||||
&mut |operand,
|
||||
type_state: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst,
|
||||
relaxed_type_check| {
|
||||
let (instr_type, instruction_space) = match type_state {
|
||||
None => return Ok(operand),
|
||||
Some(t) => t,
|
||||
};
|
||||
let (operand_type, operand_space) = resolver.get_typed(operand)?;
|
||||
let conversion_fn = if relaxed_type_check {
|
||||
if is_dst {
|
||||
should_convert_relaxed_dst_wrapper
|
||||
} else {
|
||||
should_convert_relaxed_src_wrapper
|
||||
}
|
||||
} else {
|
||||
default_implicit_conversion
|
||||
};
|
||||
match conversion_fn(
|
||||
(*operand_space, &operand_type),
|
||||
(instruction_space, instr_type),
|
||||
)? {
|
||||
Some(conv_kind) => {
|
||||
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
|
||||
let mut from_type = instr_type.clone();
|
||||
let mut from_space = instruction_space;
|
||||
let mut to_type = operand_type.clone();
|
||||
let mut to_space = *operand_space;
|
||||
let mut src =
|
||||
resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
|
||||
let mut dst = operand;
|
||||
let result = Ok::<_, TranslateError>(src);
|
||||
if !is_dst {
|
||||
mem::swap(&mut src, &mut dst);
|
||||
mem::swap(&mut from_type, &mut to_type);
|
||||
mem::swap(&mut from_space, &mut to_space);
|
||||
}
|
||||
conv_output.push(Statement::Conversion(ImplicitConversion {
|
||||
src,
|
||||
dst,
|
||||
from_type,
|
||||
from_space,
|
||||
to_type,
|
||||
to_space,
|
||||
kind: conv_kind,
|
||||
}));
|
||||
result
|
||||
}
|
||||
None => Ok(operand),
|
||||
}
|
||||
},
|
||||
)?;
|
||||
func.push(statement);
|
||||
func.append(&mut post_conv);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn default_implicit_conversion(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if instruction_space == ast::StateSpace::Reg {
|
||||
if operand_space == ast::StateSpace::Reg {
|
||||
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
|
||||
(operand_type, instruction_type)
|
||||
{
|
||||
if scalar.kind() == ast::ScalarKind::Bit
|
||||
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
|
||||
{
|
||||
return Ok(Some(ConversionKind::Default));
|
||||
}
|
||||
}
|
||||
} else if is_addressable(operand_space) {
|
||||
return Ok(Some(ConversionKind::AddressOf));
|
||||
}
|
||||
}
|
||||
if instruction_space != operand_space {
|
||||
default_implicit_conversion_space(
|
||||
(operand_space, operand_type),
|
||||
(instruction_space, instruction_type),
|
||||
)
|
||||
} else if instruction_type != operand_type {
|
||||
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn is_addressable(this: ast::StateSpace) -> bool {
|
||||
match this {
|
||||
ast::StateSpace::Const
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Global
|
||||
| ast::StateSpace::Local
|
||||
| ast::StateSpace::Shared => true,
|
||||
ast::StateSpace::Param | ast::StateSpace::Reg => false,
|
||||
ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::SharedCta
|
||||
| ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
// Space is different
|
||||
fn default_implicit_conversion_space(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|
||||
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
|
||||
{
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
} else if operand_space == ast::StateSpace::Reg {
|
||||
match operand_type {
|
||||
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
|
||||
if *operand_ptr_space == instruction_space =>
|
||||
{
|
||||
if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
// TODO: 32 bit
|
||||
ast::Type::Scalar(ast::ScalarType::B64)
|
||||
| ast::Type::Scalar(ast::ScalarType::U64)
|
||||
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
|
||||
ast::StateSpace::Global
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::Local
|
||||
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
|
||||
_ => Err(error_mismatched_type()),
|
||||
},
|
||||
ast::Type::Scalar(ast::ScalarType::B32)
|
||||
| ast::Type::Scalar(ast::ScalarType::U32)
|
||||
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
|
||||
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
|
||||
Ok(Some(ConversionKind::BitToPtr))
|
||||
}
|
||||
_ => Err(error_mismatched_type()),
|
||||
},
|
||||
_ => Err(error_mismatched_type()),
|
||||
}
|
||||
} else if instruction_space == ast::StateSpace::Reg {
|
||||
match instruction_type {
|
||||
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
|
||||
if operand_space == *instruction_ptr_space =>
|
||||
{
|
||||
if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
_ => Err(error_mismatched_type()),
|
||||
}
|
||||
} else {
|
||||
Err(error_mismatched_type())
|
||||
}
|
||||
}
|
||||
|
||||
// Space is same, but type is different
|
||||
fn default_implicit_conversion_type(
|
||||
space: ast::StateSpace,
|
||||
operand_type: &ast::Type,
|
||||
instruction_type: &ast::Type,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if space == ast::StateSpace::Reg {
|
||||
if should_bitcast(instruction_type, operand_type) {
|
||||
Ok(Some(ConversionKind::Default))
|
||||
} else {
|
||||
Err(TranslateError::MismatchedType)
|
||||
}
|
||||
} else {
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
}
|
||||
}
|
||||
|
||||
fn coerces_to_generic(this: ast::StateSpace) -> bool {
|
||||
match this {
|
||||
ast::StateSpace::Global
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::Local
|
||||
| ptx_parser::StateSpace::SharedCta
|
||||
| ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::Shared => true,
|
||||
ast::StateSpace::Reg
|
||||
| ast::StateSpace::Param
|
||||
| ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc
|
||||
| ast::StateSpace::Generic => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
|
||||
match (instr, operand) {
|
||||
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
||||
if inst.size_of() != operand.size_of() {
|
||||
return false;
|
||||
}
|
||||
match inst.kind() {
|
||||
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
|
||||
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
|
||||
ast::ScalarKind::Signed => {
|
||||
operand.kind() == ast::ScalarKind::Bit
|
||||
|| operand.kind() == ast::ScalarKind::Unsigned
|
||||
}
|
||||
ast::ScalarKind::Unsigned => {
|
||||
operand.kind() == ast::ScalarKind::Bit
|
||||
|| operand.kind() == ast::ScalarKind::Signed
|
||||
}
|
||||
ast::ScalarKind::Pred => false,
|
||||
}
|
||||
}
|
||||
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
|
||||
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
|
||||
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn should_convert_relaxed_dst_wrapper(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if operand_space != instruction_space {
|
||||
return Err(TranslateError::MismatchedType);
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_dst(operand_type, instruction_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
|
||||
fn should_convert_relaxed_dst(
|
||||
dst_type: &ast::Type,
|
||||
instr_type: &ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if dst_type == instr_type {
|
||||
return None;
|
||||
}
|
||||
match (dst_type, instr_type) {
|
||||
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||
ast::ScalarKind::Bit => {
|
||||
if instr_type.size_of() <= dst_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Signed => {
|
||||
if dst_type.kind() != ast::ScalarKind::Float {
|
||||
if instr_type.size_of() == dst_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else if instr_type.size_of() < dst_type.size_of() {
|
||||
Some(ConversionKind::SignExtend)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Unsigned => {
|
||||
if instr_type.size_of() <= dst_type.size_of()
|
||||
&& dst_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Float => {
|
||||
if instr_type.size_of() <= dst_type.size_of()
|
||||
&& dst_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Pred => None,
|
||||
},
|
||||
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||
should_convert_relaxed_dst(
|
||||
&ast::Type::Scalar(*dst_type),
|
||||
&ast::Type::Scalar(*instr_type),
|
||||
)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn should_convert_relaxed_src_wrapper(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if operand_space != instruction_space {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_src(operand_type, instruction_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(error_mismatched_type()),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
|
||||
fn should_convert_relaxed_src(
|
||||
src_type: &ast::Type,
|
||||
instr_type: &ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if src_type == instr_type {
|
||||
return None;
|
||||
}
|
||||
match (src_type, instr_type) {
|
||||
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||
ast::ScalarKind::Bit => {
|
||||
if instr_type.size_of() <= src_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
|
||||
if instr_type.size_of() <= src_type.size_of()
|
||||
&& src_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Float => {
|
||||
if instr_type.size_of() <= src_type.size_of()
|
||||
&& src_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Pred => None,
|
||||
},
|
||||
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||
should_convert_relaxed_src(
|
||||
&ast::Type::Scalar(*dst_type),
|
||||
&ast::Type::Scalar(*instr_type),
|
||||
)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
@ -189,7 +189,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||
return Ok(symbol);
|
||||
};
|
||||
let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
|
||||
if !space_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable {
|
||||
if var_space != ast::StateSpace::Reg || !is_variable {
|
||||
return Ok(symbol);
|
||||
};
|
||||
let member_index = match member_index {
|
||||
|
@ -1,5 +1,6 @@
|
||||
use ptx_parser as ast;
|
||||
use rspirv::{binary::Assemble, dr};
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::hash::Hash;
|
||||
use std::num::NonZeroU8;
|
||||
use std::{
|
||||
@ -12,20 +13,31 @@ use std::{
|
||||
mem,
|
||||
rc::Rc,
|
||||
};
|
||||
use strum::IntoEnumIterator;
|
||||
use strum_macros::EnumIter;
|
||||
|
||||
mod convert_dynamic_shared_memory_usage;
|
||||
mod convert_to_stateful_memory_access;
|
||||
mod convert_to_typed;
|
||||
mod deparamize_functions;
|
||||
pub(crate) mod emit_llvm;
|
||||
mod emit_spirv;
|
||||
mod expand_arguments;
|
||||
mod expand_operands;
|
||||
mod extract_globals;
|
||||
mod fix_special_registers;
|
||||
mod fix_special_registers2;
|
||||
mod hoist_globals;
|
||||
mod insert_explicit_load_store;
|
||||
mod insert_implicit_conversions;
|
||||
mod insert_implicit_conversions2;
|
||||
mod insert_mem_ssa_statements;
|
||||
mod normalize_identifiers;
|
||||
mod normalize_identifiers2;
|
||||
mod normalize_labels;
|
||||
mod normalize_predicates;
|
||||
mod normalize_predicates2;
|
||||
mod resolve_function_pointers;
|
||||
|
||||
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
||||
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||
@ -57,7 +69,30 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
|
||||
})?;
|
||||
normalize_variable_decls(&mut directives);
|
||||
let denorm_information = compute_denorm_information(&directives);
|
||||
let llvm_ir = emit_llvm::run(&id_defs, call_map, directives)?;
|
||||
todo!()
|
||||
/*
|
||||
let llvm_ir: emit_llvm::MemoryBuffer = emit_llvm::run(&id_defs, call_map, directives)?;
|
||||
Ok(Module {
|
||||
llvm_ir,
|
||||
kernel_info: HashMap::new(),
|
||||
}) */
|
||||
}
|
||||
|
||||
pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
|
||||
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
|
||||
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
|
||||
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
|
||||
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
|
||||
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
|
||||
let directives = resolve_function_pointers::run(directives)?;
|
||||
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
|
||||
let directives: Vec<Directive2<'_, ptx_parser::Instruction<SpirvWord>, SpirvWord>> =
|
||||
expand_operands::run(&mut flat_resolver, directives)?;
|
||||
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
|
||||
let directives = hoist_globals::run(directives)?;
|
||||
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
|
||||
Ok(Module {
|
||||
llvm_ir,
|
||||
kernel_info: HashMap::new(),
|
||||
@ -319,7 +354,7 @@ pub struct KernelInfo {
|
||||
pub uses_shared_mem: bool,
|
||||
}
|
||||
|
||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
|
||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)]
|
||||
enum PtxSpecialRegister {
|
||||
Tid,
|
||||
Ntid,
|
||||
@ -342,6 +377,17 @@ impl PtxSpecialRegister {
|
||||
}
|
||||
}
|
||||
|
||||
fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::Tid => "%tid",
|
||||
Self::Ntid => "%ntid",
|
||||
Self::Ctaid => "%ctaid",
|
||||
Self::Nctaid => "%nctaid",
|
||||
Self::Clock => "%clock",
|
||||
Self::LanemaskLt => "%lanemask_lt",
|
||||
}
|
||||
}
|
||||
|
||||
fn get_type(self) -> ast::Type {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid
|
||||
@ -525,7 +571,7 @@ impl<'b> NumericIdResolver<'b> {
|
||||
Some(Some(x)) => Ok(x.clone()),
|
||||
Some(None) => Err(TranslateError::UntypedSymbol),
|
||||
None => match self.special_registers.get(id) {
|
||||
Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
|
||||
Some(x) => Ok((x.get_type(), ast::StateSpace::Reg, true)),
|
||||
None => match self.global_type_check.get(&id) {
|
||||
Some(Some(result)) => Ok(result.clone()),
|
||||
Some(None) | None => Err(TranslateError::UntypedSymbol),
|
||||
@ -722,6 +768,7 @@ enum Statement<I, P: ast::Operand> {
|
||||
PtrAccess(PtrAccess<P>),
|
||||
RepackVector(RepackVectorDetails),
|
||||
FunctionPointer(FunctionPointerDetails),
|
||||
VectorAccess(VectorAccess),
|
||||
}
|
||||
|
||||
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||
@ -890,6 +937,36 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||
offset_src,
|
||||
})
|
||||
}
|
||||
Statement::VectorAccess(VectorAccess {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
dst,
|
||||
src: vector_src,
|
||||
member,
|
||||
}) => {
|
||||
let dst: SpirvWord = visitor.visit_ident(
|
||||
dst,
|
||||
Some((&scalar_type.into(), ast::StateSpace::Reg)),
|
||||
true,
|
||||
false,
|
||||
)?;
|
||||
let src = visitor.visit_ident(
|
||||
vector_src,
|
||||
Some((
|
||||
&ast::Type::Vector(vector_width, scalar_type),
|
||||
ast::StateSpace::Reg,
|
||||
)),
|
||||
false,
|
||||
false,
|
||||
)?;
|
||||
Statement::VectorAccess(VectorAccess {
|
||||
scalar_type,
|
||||
vector_width,
|
||||
dst,
|
||||
src,
|
||||
member,
|
||||
})
|
||||
}
|
||||
Statement::RepackVector(RepackVectorDetails {
|
||||
is_extract,
|
||||
typ,
|
||||
@ -1207,12 +1284,6 @@ impl<
|
||||
}
|
||||
}
|
||||
|
||||
fn space_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
|
||||
this == other
|
||||
|| this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
|
||||
|| this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
|
||||
}
|
||||
|
||||
fn register_external_fn_call<'a>(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
@ -1450,6 +1521,7 @@ fn compute_denorm_information<'input>(
|
||||
Statement::Label(_) => {}
|
||||
Statement::Variable(_) => {}
|
||||
Statement::PtrAccess { .. } => {}
|
||||
Statement::VectorAccess { .. } => {}
|
||||
Statement::RepackVector(_) => {}
|
||||
Statement::FunctionPointer(_) => {}
|
||||
}
|
||||
@ -1663,3 +1735,278 @@ fn denorm_count_map_update_impl<T: Eq + Hash>(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> {
|
||||
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
|
||||
Method(Function2<'input, Instruction, Operand>),
|
||||
}
|
||||
|
||||
pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> {
|
||||
pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
|
||||
pub globals: Vec<ast::Variable<SpirvWord>>,
|
||||
pub body: Option<Vec<Statement<Instruction, Operand>>>,
|
||||
import_as: Option<String>,
|
||||
tuning: Vec<ast::TuningDirective>,
|
||||
linkage: ast::LinkingDirective,
|
||||
}
|
||||
|
||||
type NormalizedDirective2<'input> = Directive2<
|
||||
'input,
|
||||
(
|
||||
Option<ast::PredAt<SpirvWord>>,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
),
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
type NormalizedFunction2<'input> = Function2<
|
||||
'input,
|
||||
(
|
||||
Option<ast::PredAt<SpirvWord>>,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
),
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
type UnconditionalDirective<'input> = Directive2<
|
||||
'input,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
type UnconditionalFunction<'input> = Function2<
|
||||
'input,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
struct GlobalStringIdentResolver2<'input> {
|
||||
pub(crate) current_id: SpirvWord,
|
||||
pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
|
||||
}
|
||||
|
||||
impl<'input> GlobalStringIdentResolver2<'input> {
|
||||
fn new(spirv_word: SpirvWord) -> Self {
|
||||
Self {
|
||||
current_id: spirv_word,
|
||||
ident_map: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn register_named(
|
||||
&mut self,
|
||||
name: Cow<'input, str>,
|
||||
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||
) -> SpirvWord {
|
||||
let new_id = self.current_id;
|
||||
self.ident_map.insert(
|
||||
new_id,
|
||||
IdentEntry {
|
||||
name: Some(name),
|
||||
type_space,
|
||||
},
|
||||
);
|
||||
self.current_id.0 += 1;
|
||||
new_id
|
||||
}
|
||||
|
||||
fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
|
||||
let new_id = self.current_id;
|
||||
self.ident_map.insert(
|
||||
new_id,
|
||||
IdentEntry {
|
||||
name: None,
|
||||
type_space,
|
||||
},
|
||||
);
|
||||
self.current_id.0 += 1;
|
||||
new_id
|
||||
}
|
||||
|
||||
fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> {
|
||||
match self.ident_map.get(&id) {
|
||||
Some(IdentEntry {
|
||||
type_space: Some(type_space),
|
||||
..
|
||||
}) => Ok(type_space),
|
||||
_ => Err(error_unknown_symbol()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct IdentEntry<'input> {
|
||||
name: Option<Cow<'input, str>>,
|
||||
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||
}
|
||||
|
||||
struct ScopedResolver<'input, 'b> {
|
||||
flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
|
||||
scopes: Vec<ScopeMarker<'input>>,
|
||||
}
|
||||
|
||||
impl<'input, 'b> ScopedResolver<'input, 'b> {
|
||||
fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
|
||||
Self {
|
||||
flat_resolver,
|
||||
scopes: vec![ScopeMarker::new()],
|
||||
}
|
||||
}
|
||||
|
||||
fn start_scope(&mut self) {
|
||||
self.scopes.push(ScopeMarker::new());
|
||||
}
|
||||
|
||||
fn end_scope(&mut self) {
|
||||
let scope = self.scopes.pop().unwrap();
|
||||
scope.flush(self.flat_resolver);
|
||||
}
|
||||
|
||||
fn add(
|
||||
&mut self,
|
||||
name: Cow<'input, str>,
|
||||
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let result = self.flat_resolver.current_id;
|
||||
self.flat_resolver.current_id.0 += 1;
|
||||
let current_scope = self.scopes.last_mut().unwrap();
|
||||
if current_scope
|
||||
.name_to_ident
|
||||
.insert(name.clone(), result)
|
||||
.is_some()
|
||||
{
|
||||
return Err(error_unknown_symbol());
|
||||
}
|
||||
current_scope.ident_map.insert(
|
||||
result,
|
||||
IdentEntry {
|
||||
name: Some(name),
|
||||
type_space,
|
||||
},
|
||||
);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
|
||||
self.scopes
|
||||
.iter()
|
||||
.rev()
|
||||
.find_map(|resolver| resolver.name_to_ident.get(name).copied())
|
||||
.ok_or_else(|| error_unreachable())
|
||||
}
|
||||
|
||||
fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
|
||||
let current_scope = self.scopes.last().unwrap();
|
||||
current_scope
|
||||
.name_to_ident
|
||||
.get(label)
|
||||
.copied()
|
||||
.ok_or_else(|| error_unreachable())
|
||||
}
|
||||
}
|
||||
|
||||
struct ScopeMarker<'input> {
|
||||
ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
|
||||
name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
|
||||
}
|
||||
|
||||
impl<'input> ScopeMarker<'input> {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
ident_map: FxHashMap::default(),
|
||||
name_to_ident: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
|
||||
resolver.ident_map.extend(self.ident_map);
|
||||
}
|
||||
}
|
||||
|
||||
struct SpecialRegistersMap2 {
|
||||
reg_to_id: FxHashMap<PtxSpecialRegister, SpirvWord>,
|
||||
id_to_reg: FxHashMap<SpirvWord, PtxSpecialRegister>,
|
||||
}
|
||||
|
||||
impl SpecialRegistersMap2 {
|
||||
fn new(resolver: &mut ScopedResolver) -> Result<Self, TranslateError> {
|
||||
let mut result = SpecialRegistersMap2 {
|
||||
reg_to_id: FxHashMap::default(),
|
||||
id_to_reg: FxHashMap::default(),
|
||||
};
|
||||
for sreg in PtxSpecialRegister::iter() {
|
||||
let text = sreg.as_str();
|
||||
let id = resolver.add(
|
||||
Cow::Borrowed(text),
|
||||
Some((sreg.get_type(), ast::StateSpace::Reg)),
|
||||
)?;
|
||||
result.reg_to_id.insert(sreg, id);
|
||||
result.id_to_reg.insert(id, sreg);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
|
||||
self.id_to_reg.get(&id).copied()
|
||||
}
|
||||
|
||||
fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
|
||||
match self.reg_to_id.entry(reg) {
|
||||
hash_map::Entry::Occupied(e) => *e.get(),
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
let numeric_id = SpirvWord(current_id.0);
|
||||
current_id.0 += 1;
|
||||
e.insert(numeric_id);
|
||||
self.id_to_reg.insert(numeric_id, reg);
|
||||
numeric_id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_declarations<'a, 'input>(
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
) -> impl ExactSizeIterator<
|
||||
Item = (
|
||||
PtxSpecialRegister,
|
||||
ast::MethodDeclaration<'input, SpirvWord>,
|
||||
),
|
||||
> + 'a {
|
||||
PtxSpecialRegister::iter().map(|sreg| {
|
||||
let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
|
||||
let name =
|
||||
ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
|
||||
let return_type = sreg.get_function_return_type();
|
||||
let input_type = sreg.get_function_return_type();
|
||||
(
|
||||
sreg,
|
||||
ast::MethodDeclaration {
|
||||
return_arguments: vec![ast::Variable {
|
||||
align: None,
|
||||
v_type: return_type.into(),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: resolver
|
||||
.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
|
||||
array_init: Vec::new(),
|
||||
}],
|
||||
name: name,
|
||||
input_arguments: vec![ast::Variable {
|
||||
align: None,
|
||||
v_type: input_type.into(),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: resolver
|
||||
.register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))),
|
||||
array_init: Vec::new(),
|
||||
}],
|
||||
shared_mem: None,
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VectorAccess {
|
||||
scalar_type: ast::ScalarType,
|
||||
vector_width: u8,
|
||||
dst: SpirvWord,
|
||||
src: SpirvWord,
|
||||
member: u8,
|
||||
}
|
||||
|
199
ptx/src/pass/normalize_identifiers2.rs
Normal file
199
ptx/src/pass/normalize_identifiers2.rs
Normal file
@ -0,0 +1,199 @@
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
pub(crate) fn run<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
|
||||
resolver.start_scope();
|
||||
let result = directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
resolver.end_scope();
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn run_directive<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
||||
) -> Result<NormalizedDirective2<'input>, TranslateError> {
|
||||
Ok(match directive {
|
||||
ast::Directive::Variable(linking, var) => {
|
||||
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
|
||||
}
|
||||
ast::Directive::Method(linking, directive) => {
|
||||
NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
linkage: ast::LinkingDirective,
|
||||
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<NormalizedFunction2<'input>, TranslateError> {
|
||||
let name = match method.func_directive.name {
|
||||
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
|
||||
ast::MethodName::Func(text) => {
|
||||
ast::MethodName::Func(resolver.add(Cow::Borrowed(text), None)?)
|
||||
}
|
||||
};
|
||||
resolver.start_scope();
|
||||
let func_decl = run_function_decl(resolver, method.func_directive, name)?;
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
run_statements(resolver, &mut result, statements)?;
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
resolver.end_scope();
|
||||
Ok(Function2 {
|
||||
func_decl,
|
||||
globals: Vec::new(),
|
||||
body,
|
||||
import_as: None,
|
||||
tuning: method.tuning,
|
||||
linkage,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_function_decl<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
||||
name: ast::MethodName<'input, SpirvWord>,
|
||||
) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
|
||||
assert!(func_directive.shared_mem.is_none());
|
||||
let return_arguments = func_directive
|
||||
.return_arguments
|
||||
.into_iter()
|
||||
.map(|var| run_variable(resolver, var))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let input_arguments = func_directive
|
||||
.input_arguments
|
||||
.into_iter()
|
||||
.map(|var| run_variable(resolver, var))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(ast::MethodDeclaration {
|
||||
return_arguments,
|
||||
name,
|
||||
input_arguments,
|
||||
shared_mem: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_variable<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
variable: ast::Variable<&'input str>,
|
||||
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
|
||||
Ok(ast::Variable {
|
||||
name: resolver.add(
|
||||
Cow::Borrowed(variable.name),
|
||||
Some((variable.v_type.clone(), variable.state_space)),
|
||||
)?,
|
||||
align: variable.align,
|
||||
v_type: variable.v_type,
|
||||
state_space: variable.state_space,
|
||||
array_init: variable.array_init,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statements<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<(), TranslateError> {
|
||||
for statement in statements.iter() {
|
||||
match statement {
|
||||
ast::Statement::Label(label) => {
|
||||
resolver.add(Cow::Borrowed(*label), None)?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
for statement in statements {
|
||||
match statement {
|
||||
ast::Statement::Label(label) => {
|
||||
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
|
||||
}
|
||||
ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
|
||||
ast::Statement::Instruction(predicate, instruction) => {
|
||||
result.push(Statement::Instruction((
|
||||
predicate
|
||||
.map(|pred| {
|
||||
Ok::<_, TranslateError>(ast::PredAt {
|
||||
not: pred.not,
|
||||
label: resolver.get(pred.label)?,
|
||||
})
|
||||
})
|
||||
.transpose()?,
|
||||
run_instruction(resolver, instruction)?,
|
||||
)))
|
||||
}
|
||||
ast::Statement::Block(block) => {
|
||||
resolver.start_scope();
|
||||
run_statements(resolver, result, block)?;
|
||||
resolver.end_scope();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_instruction<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
|
||||
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
||||
ast::visit_map(instruction, &mut |name: &'input str,
|
||||
_: Option<(
|
||||
&ast::Type,
|
||||
ast::StateSpace,
|
||||
)>,
|
||||
_,
|
||||
_| {
|
||||
resolver.get(&name)
|
||||
})
|
||||
}
|
||||
|
||||
fn run_multivariable<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
variable: ast::MultiVariable<&'input str>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match variable.count {
|
||||
Some(count) => {
|
||||
for i in 0..count {
|
||||
let name = Cow::Owned(format!("{}{}", variable.var.name, i));
|
||||
let ident = resolver.add(
|
||||
name,
|
||||
Some((variable.var.v_type.clone(), variable.var.state_space)),
|
||||
)?;
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: variable.var.align,
|
||||
v_type: variable.var.v_type.clone(),
|
||||
state_space: variable.var.state_space,
|
||||
name: ident,
|
||||
array_init: variable.var.array_init.clone(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let name = Cow::Borrowed(variable.var.name);
|
||||
let ident = resolver.add(
|
||||
name,
|
||||
Some((variable.var.v_type.clone(), variable.var.state_space)),
|
||||
)?;
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: variable.var.align,
|
||||
v_type: variable.var.v_type.clone(),
|
||||
state_space: variable.var.state_space,
|
||||
name: ident,
|
||||
array_init: variable.var.array_init.clone(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -26,6 +26,7 @@ pub(super) fn run(
|
||||
| Statement::Constant(..)
|
||||
| Statement::Label(..)
|
||||
| Statement::PtrAccess { .. }
|
||||
| Statement::VectorAccess { .. }
|
||||
| Statement::RepackVector(..)
|
||||
| Statement::FunctionPointer(..) => {}
|
||||
}
|
||||
|
84
ptx/src/pass/normalize_predicates2.rs
Normal file
84
ptx/src/pass/normalize_predicates2.rs
Normal file
@ -0,0 +1,84 @@
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<NormalizedDirective2<'input>>,
|
||||
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: NormalizedDirective2<'input>,
|
||||
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
||||
Ok(match directive {
|
||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
method: NormalizedFunction2<'input>,
|
||||
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(resolver, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
result: &mut Vec<UnconditionalStatement>,
|
||||
statement: NormalizedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
Ok(match statement {
|
||||
Statement::Label(label) => result.push(Statement::Label(label)),
|
||||
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
||||
Statement::Instruction((predicate, instruction)) => {
|
||||
if let Some(pred) = predicate {
|
||||
let if_true = resolver.register_unnamed(None);
|
||||
let if_false = resolver.register_unnamed(None);
|
||||
let folded_bra = match &instruction {
|
||||
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
||||
_ => None,
|
||||
};
|
||||
let mut branch = BrachCondition {
|
||||
predicate: pred.label,
|
||||
if_true: folded_bra.unwrap_or(if_true),
|
||||
if_false,
|
||||
};
|
||||
if pred.not {
|
||||
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||
}
|
||||
result.push(Statement::Conditional(branch));
|
||||
if folded_bra.is_none() {
|
||||
result.push(Statement::Label(if_true));
|
||||
result.push(Statement::Instruction(instruction));
|
||||
}
|
||||
result.push(Statement::Label(if_false));
|
||||
} else {
|
||||
result.push(Statement::Instruction(instruction));
|
||||
}
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
})
|
||||
}
|
82
ptx/src/pass/resolve_function_pointers.rs
Normal file
82
ptx/src/pass/resolve_function_pointers.rs
Normal file
@ -0,0 +1,82 @@
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
directives: Vec<UnconditionalDirective<'input>>,
|
||||
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
||||
let mut functions = FxHashSet::default();
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(&mut functions, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
functions: &mut FxHashSet<SpirvWord>,
|
||||
directive: UnconditionalDirective<'input>,
|
||||
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => {
|
||||
{
|
||||
let func_decl = &method.func_decl;
|
||||
match func_decl.name {
|
||||
ptx_parser::MethodName::Kernel(_) => {}
|
||||
ptx_parser::MethodName::Func(name) => {
|
||||
functions.insert(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
Directive2::Method(run_method(functions, method)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
functions: &mut FxHashSet<SpirvWord>,
|
||||
method: UnconditionalFunction<'input>,
|
||||
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|statement| run_statement(functions, statement))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
functions: &mut FxHashSet<SpirvWord>,
|
||||
statement: UnconditionalStatement,
|
||||
) -> Result<UnconditionalStatement, TranslateError> {
|
||||
Ok(match statement {
|
||||
Statement::Instruction(ast::Instruction::Mov {
|
||||
data,
|
||||
arguments:
|
||||
ast::MovArgs {
|
||||
dst: ast::ParsedOperand::Reg(dst_reg),
|
||||
src: ast::ParsedOperand::Reg(src_reg),
|
||||
},
|
||||
}) if functions.contains(&src_reg) => {
|
||||
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
UnconditionalStatement::FunctionPointer(FunctionPointerDetails {
|
||||
dst: dst_reg,
|
||||
src: src_reg,
|
||||
})
|
||||
}
|
||||
s => s,
|
||||
})
|
||||
}
|
@ -236,7 +236,7 @@ fn test_hip_assert<
|
||||
output: &mut [Output],
|
||||
) -> Result<(), Box<dyn error::Error + 'a>> {
|
||||
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
|
||||
let llvm_ir = pass::to_llvm_module(ast).unwrap();
|
||||
let llvm_ir = pass::to_llvm_module2(ast).unwrap();
|
||||
let name = CString::new(name)?;
|
||||
let result =
|
||||
run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?;
|
||||
|
@ -1049,6 +1049,15 @@ impl<'input, ID> MethodName<'input, ID> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'input> MethodName<'input, &'input str> {
|
||||
pub fn text(&self) -> &'input str {
|
||||
match self {
|
||||
MethodName::Kernel(name) => *name,
|
||||
MethodName::Func(name) => *name,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
pub struct LinkingDirective: u8 {
|
||||
const NONE = 0b000;
|
||||
@ -1291,7 +1300,12 @@ impl<T: Operand> CallArgs<T> {
|
||||
.iter()
|
||||
.zip(details.return_arguments.iter())
|
||||
{
|
||||
visitor.visit_ident(param, Some((type_, *space)), true, false)?;
|
||||
visitor.visit_ident(
|
||||
param,
|
||||
Some((type_, *space)),
|
||||
*space == StateSpace::Reg,
|
||||
false,
|
||||
)?;
|
||||
}
|
||||
visitor.visit_ident(&self.func, None, false, false)?;
|
||||
for (param, (type_, space)) in self
|
||||
@ -1315,7 +1329,12 @@ impl<T: Operand> CallArgs<T> {
|
||||
.iter_mut()
|
||||
.zip(details.return_arguments.iter())
|
||||
{
|
||||
visitor.visit_ident(param, Some((type_, *space)), true, false)?;
|
||||
visitor.visit_ident(
|
||||
param,
|
||||
Some((type_, *space)),
|
||||
*space == StateSpace::Reg,
|
||||
false,
|
||||
)?;
|
||||
}
|
||||
visitor.visit_ident(&mut self.func, None, false, false)?;
|
||||
for (param, (type_, space)) in self
|
||||
@ -1339,7 +1358,12 @@ impl<T: Operand> CallArgs<T> {
|
||||
.into_iter()
|
||||
.zip(details.return_arguments.iter())
|
||||
.map(|(param, (type_, space))| {
|
||||
visitor.visit_ident(param, Some((type_, *space)), true, false)
|
||||
visitor.visit_ident(
|
||||
param,
|
||||
Some((type_, *space)),
|
||||
*space == StateSpace::Reg,
|
||||
false,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let func = visitor.visit_ident(self.func, None, false, false)?;
|
||||
|
@ -1499,7 +1499,6 @@ derive_parser!(
|
||||
pub enum StateSpace {
|
||||
Reg,
|
||||
Generic,
|
||||
Sreg,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
|
Reference in New Issue
Block a user