Refactor compilation passes (#270)

The overarching goal is to refactor all passes so that they are module-scoped rather than function-scoped. Additionally, this commit improves the most egregiously buggy/unfit passes (so the code is ready for the next major features: linking and ftz handling) and continues adding more code to the LLVM backend.
This commit is contained in:
Andrzej Janik
2024-09-23 16:33:46 +02:00
committed by GitHub
parent 46def3e7e0
commit c92abba2bb
23 changed files with 2365 additions and 172 deletions

View File

@ -1,60 +0,0 @@
name: Rust
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
env:
CARGO_TERM_COLOR: always
jobs:
build_lin:
name: Build and publish (Linux)
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
with:
submodules: true
- name: Install GPU drivers
run: |
sudo apt-get install -y gpg-agent wget
wget -qO - https://repositories.intel.com/graphics/intel-graphics.key | sudo apt-key add -
sudo apt-add-repository 'deb [arch=amd64] https://repositories.intel.com/graphics/ubuntu focal main'
sudo apt-get update
sudo apt-get install intel-opencl-icd intel-level-zero-gpu level-zero intel-media-va-driver-non-free libmfx1 libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev ocl-icd-opencl-dev
- name: Build
run: cargo build --workspace --verbose --release
- name: Rename to libcuda.so
run: |
mv target/release/libnvcuda.so target/release/libcuda.so
ln -s libcuda.so target/release/libcuda.so.1
- uses: actions/upload-artifact@v2
with:
name: Linux
path: |
target/release/libcuda.so
target/release/libcuda.so.1
target/release/libnvml.so
build_win:
name: Build and publish (Windows)
runs-on: windows-latest
steps:
- uses: actions/checkout@v2
with:
submodules: true
- name: Build
run: cargo build --workspace --verbose --release
- uses: actions/upload-artifact@v2
with:
name: Windows
path: |
target/release/nvcuda.dll
target/release/nvml.dll
target/release/zluda_redirect.dll
target/release/zluda_with.exe
target/release/zluda_dump.dll
# TODO(take-cheeze): Support testing
# - name: Run tests
# run: cargo test --verbose

View File

@ -17,6 +17,9 @@ thiserror = "1.0"
bit-vec = "0.6" bit-vec = "0.6"
half ="1.6" half ="1.6"
bitflags = "1.2" bitflags = "1.2"
rustc-hash = "2.0.0"
strum = "0.26"
strum_macros = "0.26"
[dependencies.lalrpop-util] [dependencies.lalrpop-util]
version = "0.19.12" version = "0.19.12"

View File

@ -489,7 +489,7 @@ fn convert_to_stateful_memory_access_postprocess(
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?; let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?;
let converting_id = id_defs let converting_id = id_defs
.register_intermediate(Some((old_operand_type.clone(), old_operand_space))); .register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) { let kind = if new_operand_space == ast::StateSpace::Reg {
ConversionKind::Default ConversionKind::Default
} else { } else {
ConversionKind::PtrToPtr ConversionKind::PtrToPtr

View File

@ -0,0 +1,141 @@
use std::collections::BTreeMap;
use super::*;
/// Entry point of the pass: applies `run_directive` to every directive of the module,
/// in order, and stops at the first error.
pub(super) fn run<'a, 'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
    let mut converted = Vec::with_capacity(directives.len());
    for directive in directives {
        converted.push(run_directive(resolver, directive)?);
    }
    Ok(converted)
}
/// Processes a single directive: methods go through `run_method`, variables pass through
/// unchanged.
fn run_directive<'input>(
    resolver: &mut GlobalStringIdentResolver2,
    directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
    match directive {
        Directive2::Method(method) => run_method(resolver, method).map(Directive2::Method),
        variable @ Directive2::Variable(..) => Ok(variable),
    }
}
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2,
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
if method.func_decl.name.is_kernel() {
return Ok(method);
}
let is_declaration = method.body.is_none();
let mut body = Vec::new();
let mut remap_returns = Vec::new();
for arg in method.func_decl.return_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name;
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration {
continue;
}
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
body.push(Statement::Variable(ast::Variable {
align: None,
name: old_name,
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(),
}));
}
ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()),
}
}
for arg in method.func_decl.input_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name;
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration {
continue;
}
body.push(Statement::Variable(ast::Variable {
align: None,
name: old_name,
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(),
}));
body.push(Statement::Instruction(ast::Instruction::St {
data: ast::StData {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::StCacheOperator::Writethrough,
typ: arg.v_type.clone(),
},
arguments: ast::StArgs {
src1: old_name,
src2: arg.name,
},
}));
}
ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()),
}
}
if remap_returns.is_empty() {
return Ok(method);
}
let body = method
.body
.map(|statements| {
for statement in statements {
run_statement(&remap_returns, &mut body, statement)?;
}
Ok::<_, TranslateError>(body)
})
.transpose()?;
Ok(Function2 {
func_decl: method.func_decl,
globals: method.globals,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
})
}
/// Appends `statement` to `result`.
///
/// When `statement` is a `ret`, it is preceded by one `ld` per entry of `remap_returns`,
/// copying the value of the original `.param` return variable (`old_name`) into the
/// register that now carries the return value (`new_name`).
///
/// `remap_returns` holds `(old_name, new_name, type)` triples.
fn run_statement<'input>(
    remap_returns: &[(SpirvWord, SpirvWord, ast::Type)],
    result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
    statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
    if let Statement::Instruction(ast::Instruction::Ret { .. }) = statement {
        for (old_name, new_name, type_) in remap_returns {
            result.push(Statement::Instruction(ast::Instruction::Ld {
                data: ast::LdDetails {
                    qualifier: ast::LdStQualifier::Weak,
                    // The source is the local variable declared in `.param` space by
                    // `run_method`, so the load must address `.param` (mirrors the
                    // `st` emitted for input arguments).
                    state_space: ast::StateSpace::Param,
                    caching: ast::LdCacheOperator::Cached,
                    typ: type_.clone(),
                    non_coherent: false,
                },
                arguments: ast::LdArgs {
                    dst: *new_name,
                    src: *old_name,
                },
            }));
        }
    }
    result.push(statement);
    Ok(())
}

View File

@ -164,17 +164,16 @@ impl Deref for MemoryBuffer {
} }
pub(super) fn run<'input>( pub(super) fn run<'input>(
id_defs: &GlobalStringIdResolver<'input>, id_defs: GlobalStringIdentResolver2<'input>,
call_map: MethodsCallMap<'input>, directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
directives: Vec<Directive<'input>>,
) -> Result<MemoryBuffer, TranslateError> { ) -> Result<MemoryBuffer, TranslateError> {
let context = Context::new(); let context = Context::new();
let module = Module::new(&context, LLVM_UNNAMED); let module = Module::new(&context, LLVM_UNNAMED);
let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs); let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
for directive in directives { for directive in directives {
match directive { match directive {
Directive::Variable(..) => todo!(), Directive2::Variable(..) => todo!(),
Directive::Method(method) => emit_ctx.emit_method(method)?, Directive2::Method(method) => emit_ctx.emit_method(method)?,
} }
} }
module.write_to_stderr(); module.write_to_stderr();
@ -188,7 +187,7 @@ struct ModuleEmitContext<'a, 'input> {
context: LLVMContextRef, context: LLVMContextRef,
module: LLVMModuleRef, module: LLVMModuleRef,
builder: Builder, builder: Builder,
id_defs: &'a GlobalStringIdResolver<'input>, id_defs: &'a GlobalStringIdentResolver2<'input>,
resolver: ResolveIdent, resolver: ResolveIdent,
} }
@ -196,7 +195,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn new( fn new(
context: &Context, context: &Context,
module: &Module, module: &Module,
id_defs: &'a GlobalStringIdResolver<'input>, id_defs: &'a GlobalStringIdentResolver2<'input>,
) -> Self { ) -> Self {
ModuleEmitContext { ModuleEmitContext {
context: context.get(), context: context.get(),
@ -215,26 +214,50 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
LLVMCallConv::LLVMCCallConv as u32 LLVMCallConv::LLVMCCallConv as u32
} }
fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> { fn emit_method(
let func_decl = method.func_decl.borrow(); &mut self,
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
let func_decl = method.func_decl;
let name = method let name = method
.import_as .import_as
.as_deref() .as_deref()
.unwrap_or_else(|| match func_decl.name { .or_else(|| match func_decl.name {
ast::MethodName::Kernel(name) => name, ast::MethodName::Kernel(name) => Some(name),
ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id], ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
}); })
.ok_or_else(|| error_unreachable())?;
let name = CString::new(name).map_err(|_| error_unreachable())?; let name = CString::new(name).map_err(|_| error_unreachable())?;
let fn_type = self.function_type( let fn_type = get_function_type(
self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type), func_decl.return_arguments.iter().map(|v| &v.v_type),
func_decl.input_arguments.iter().map(|v| &v.v_type), func_decl
); .input_arguments
.iter()
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
)?;
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
if let ast::MethodName::Func(name) = func_decl.name {
self.resolver.register(name, fn_);
}
for (i, param) in func_decl.input_arguments.iter().enumerate() { for (i, param) in func_decl.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) }; let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name); let name = self.resolver.get_or_add(param.name);
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value); self.resolver.register(param.name, value);
if func_decl.name.is_kernel() {
let attr_kind = unsafe {
LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
};
let attr = unsafe {
LLVMCreateTypeAttribute(
self.context,
attr_kind,
get_type(self.context, &param.v_type)?,
)
};
unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
}
} }
let call_conv = if func_decl.name.is_kernel() { let call_conv = if func_decl.name.is_kernel() {
Self::kernel_call_convention() Self::kernel_call_convention()
@ -258,66 +281,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
} }
Ok(()) Ok(())
} }
}
fn function_type( fn get_input_argument_type(
&self, context: LLVMContextRef,
return_args: impl ExactSizeIterator<Item = &'a ast::Type>, v_type: &ptx_parser::Type,
input_args: impl ExactSizeIterator<Item = &'a ast::Type>, state_space: ptx_parser::StateSpace,
) -> LLVMTypeRef { ) -> Result<LLVMTypeRef, TranslateError> {
if return_args.len() == 0 { match state_space {
let mut input_args = input_args ptx_parser::StateSpace::ParamEntry => {
.map(|type_| match type_ { Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
ast::Type::Scalar(scalar) => match scalar {
ast::ScalarType::Pred => {
unsafe { LLVMInt1TypeInContext(self.context) }
}
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
unsafe { LLVMInt8TypeInContext(self.context) }
}
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
unsafe { LLVMInt16TypeInContext(self.context) }
}
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
unsafe { LLVMInt32TypeInContext(self.context) }
}
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
unsafe { LLVMInt64TypeInContext(self.context) }
}
ast::ScalarType::B128 => {
unsafe { LLVMInt128TypeInContext(self.context) }
}
ast::ScalarType::F16 => {
unsafe { LLVMHalfTypeInContext(self.context) }
}
ast::ScalarType::F32 => {
unsafe { LLVMFloatTypeInContext(self.context) }
}
ast::ScalarType::F64 => {
unsafe { LLVMDoubleTypeInContext(self.context) }
}
ast::ScalarType::BF16 => {
unsafe { LLVMBFloatTypeInContext(self.context) }
}
ast::ScalarType::U16x2 => todo!(),
ast::ScalarType::S16x2 => todo!(),
ast::ScalarType::F16x2 => todo!(),
ast::ScalarType::BF16x2 => todo!(),
},
ast::Type::Vector(_, _) => todo!(),
ast::Type::Array(_, _, _) => todo!(),
ast::Type::Pointer(_, _) => todo!(),
})
.collect::<Vec<_>>();
return unsafe {
LLVMFunctionType(
LLVMVoidTypeInContext(self.context),
input_args.as_mut_ptr(),
input_args.len() as u32,
0,
)
};
} }
todo!() ptx_parser::StateSpace::Reg => get_type(context, v_type),
_ => return Err(error_unreachable()),
} }
} }
@ -326,7 +302,7 @@ struct MethodEmitContext<'a, 'input> {
module: LLVMModuleRef, module: LLVMModuleRef,
method: LLVMValueRef, method: LLVMValueRef,
builder: LLVMBuilderRef, builder: LLVMBuilderRef,
id_defs: &'a GlobalStringIdResolver<'input>, id_defs: &'a GlobalStringIdentResolver2<'input>,
variables_builder: Builder, variables_builder: Builder,
resolver: &'a mut ResolveIdent, resolver: &'a mut ResolveIdent,
} }
@ -365,6 +341,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
Statement::PtrAccess(_) => todo!(), Statement::PtrAccess(_) => todo!(),
Statement::RepackVector(_) => todo!(), Statement::RepackVector(_) => todo!(),
Statement::FunctionPointer(_) => todo!(), Statement::FunctionPointer(_) => todo!(),
Statement::VectorAccess(_) => todo!(),
}) })
} }
@ -414,7 +391,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
inst: ast::Instruction<SpirvWord>, inst: ast::Instruction<SpirvWord>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
match inst { match inst {
ast::Instruction::Mov { data, arguments } => todo!(), ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments),
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments), ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments), ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments), ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
@ -425,7 +402,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
ast::Instruction::Or { data, arguments } => todo!(), ast::Instruction::Or { data, arguments } => todo!(),
ast::Instruction::And { data, arguments } => todo!(), ast::Instruction::And { data, arguments } => todo!(),
ast::Instruction::Bra { arguments } => todo!(), ast::Instruction::Bra { arguments } => todo!(),
ast::Instruction::Call { data, arguments } => todo!(), ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
ast::Instruction::Cvt { data, arguments } => todo!(), ast::Instruction::Cvt { data, arguments } => todo!(),
ast::Instruction::Shr { data, arguments } => todo!(), ast::Instruction::Shr { data, arguments } => todo!(),
ast::Instruction::Shl { data, arguments } => todo!(), ast::Instruction::Shl { data, arguments } => todo!(),
@ -563,6 +540,70 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
fn emit_ret(&self, _data: ptx_parser::RetData) { fn emit_ret(&self, _data: ptx_parser::RetData) {
unsafe { LLVMBuildRetVoid(self.builder) }; unsafe { LLVMBuildRetVoid(self.builder) };
} }
/// Emits an LLVM call for a PTX `call` instruction.
///
/// Only calls with zero or one return value are supported; anything else reaches `todo!()`.
/// When there is a return value, the resulting LLVM value is registered under the
/// destination identifier.
///
/// # Panics
/// In builds with debug assertions, panics if any declared return or input argument is not
/// in `.reg` space.
fn emit_call(
    &mut self,
    data: ptx_parser::CallDetails,
    arguments: ptx_parser::CallArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    if cfg!(debug_assertions) {
        let is_not_reg = |space: &ast::StateSpace| *space != ast::StateSpace::Reg;
        if data.return_arguments.iter().any(|(_, space)| is_not_reg(space)) {
            panic!()
        }
        if data.input_arguments.iter().any(|(_, space)| is_not_reg(space)) {
            panic!()
        }
    }
    // Name given to the call's result value: unnamed for calls without a return value.
    let result_name = match (&*data.return_arguments, &*arguments.return_arguments) {
        ([], []) => LLVM_UNNAMED.as_ptr(),
        ([_], [dst]) => self.resolver.get_or_add_raw(*dst),
        _ => todo!(),
    };
    let callee_type = get_function_type(
        self.context,
        data.return_arguments.iter().map(|(type_, _)| type_),
        data.input_arguments
            .iter()
            .map(|(type_, space)| get_input_argument_type(self.context, type_, *space)),
    )?;
    let mut call_operands = arguments
        .input_arguments
        .iter()
        .map(|arg| self.resolver.value(*arg))
        .collect::<Result<Vec<_>, _>>()?;
    let callee = self.resolver.value(arguments.func)?;
    let call_result = unsafe {
        LLVMBuildCall2(
            self.builder,
            callee_type,
            callee,
            call_operands.as_mut_ptr(),
            call_operands.len() as u32,
            result_name,
        )
    };
    match &*arguments.return_arguments {
        [] => {}
        [dst] => self.resolver.register(*dst, call_result),
        _ => todo!(),
    }
    Ok(())
}
/// Handles a PTX `mov` without emitting any instruction: the destination identifier is
/// registered as an alias of the LLVM value already bound to the source.
fn emit_mov(
    &mut self,
    _data: ptx_parser::MovDetails,
    arguments: ptx_parser::MovArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let source_value = self.resolver.value(arguments.src)?;
    self.resolver.register(arguments.dst, source_value);
    Ok(())
}
} }
fn get_pointer_type<'ctx>( fn get_pointer_type<'ctx>(
@ -624,13 +665,34 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
} }
} }
/// Builds the LLVM function type for the given return and input argument types.
///
/// * No return argument maps to a `void` return type.
/// * Exactly one return argument maps to its LLVM type via `get_type`.
/// * More than one return argument is not implemented yet (`todo!()`).
///
/// Input argument types arrive already converted; the first conversion error is propagated
/// before the return type is looked at. The resulting function type is never variadic.
fn get_function_type<'a>(
    context: LLVMContextRef,
    mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
    input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
) -> Result<LLVMTypeRef, TranslateError> {
    let mut parameter_types: Vec<*mut llvm_zluda::LLVMType> =
        input_args.collect::<Result<Vec<_>, _>>()?;
    let return_count = return_args.len();
    let return_type = if return_count == 0 {
        unsafe { LLVMVoidTypeInContext(context) }
    } else if return_count == 1 {
        get_type(context, return_args.next().unwrap())?
    } else {
        todo!()
    };
    let parameter_count = parameter_types.len() as u32;
    let function_type = unsafe {
        LLVMFunctionType(
            return_type,
            parameter_types.as_mut_ptr(),
            parameter_count,
            0,
        )
    };
    Ok(function_type)
}
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> { fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
match space { match space {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
ast::StateSpace::Sreg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Param => Err(TranslateError::Todo), ast::StateSpace::Param => Err(TranslateError::Todo),
ast::StateSpace::ParamEntry => Err(TranslateError::Todo), ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::ParamFunc => Err(TranslateError::Todo), ast::StateSpace::ParamFunc => Err(TranslateError::Todo),
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE), ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
@ -647,7 +709,7 @@ struct ResolveIdent {
} }
impl ResolveIdent { impl ResolveIdent {
fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self { fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self {
ResolveIdent { ResolveIdent {
words: HashMap::new(), words: HashMap::new(),
values: HashMap::new(), values: HashMap::new(),

View File

@ -469,7 +469,6 @@ fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass {
ast::StateSpace::Shared => spirv::StorageClass::Workgroup, ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
ast::StateSpace::Param => spirv::StorageClass::Function, ast::StateSpace::Param => spirv::StorageClass::Function,
ast::StateSpace::Reg => spirv::StorageClass::Function, ast::StateSpace::Reg => spirv::StorageClass::Function,
ast::StateSpace::Sreg => spirv::StorageClass::Input,
ast::StateSpace::ParamEntry ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc | ast::StateSpace::ParamFunc
| ast::StateSpace::SharedCluster | ast::StateSpace::SharedCluster
@ -693,7 +692,6 @@ fn emit_variable<'input>(
ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant), ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant),
ast::StateSpace::Generic => todo!(), ast::StateSpace::Generic => todo!(),
ast::StateSpace::Sreg => todo!(),
ast::StateSpace::ParamEntry ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc | ast::StateSpace::ParamFunc
| ast::StateSpace::SharedCluster | ast::StateSpace::SharedCluster
@ -1563,6 +1561,7 @@ fn emit_function_body_ops<'input>(
builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?; builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
} }
} }
Statement::VectorAccess(vector_access) => todo!(),
} }
} }
Ok(()) Ok(())

View File

@ -63,9 +63,9 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
} else { } else {
return Err(TranslateError::UntypedSymbol); return Err(TranslateError::UntypedSymbol);
}; };
if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg { if state_space == ast::StateSpace::Reg {
let (reg_type, reg_space) = self.id_def.get_typed(reg)?; let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
if !space_is_compatible(reg_space, ast::StateSpace::Reg) { if reg_space != ast::StateSpace::Reg {
return Err(error_mismatched_type()); return Err(error_mismatched_type());
} }
let reg_scalar_type = match reg_type { let reg_scalar_type = match reg_type {

View File

@ -0,0 +1,289 @@
use super::*;
/// Entry point of the pass: converts every directive from parsed operands
/// (`ast::ParsedOperand<SpirvWord>`) to plain identifiers (`SpirvWord`), stopping at the
/// first error.
pub(super) fn run<'a, 'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    directives: Vec<UnconditionalDirective<'input>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
    let mut expanded = Vec::with_capacity(directives.len());
    for directive in directives {
        expanded.push(run_directive(resolver, directive)?);
    }
    Ok(expanded)
}
/// Converts one directive: variables are carried over as-is, methods are rewritten by
/// `run_method`.
fn run_directive<'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    directive: Directive2<
        'input,
        ast::Instruction<ast::ParsedOperand<SpirvWord>>,
        ast::ParsedOperand<SpirvWord>,
    >,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
    match directive {
        Directive2::Method(method) => run_method(resolver, method).map(Directive2::Method),
        Directive2::Variable(linking, var) => Ok(Directive2::Variable(linking, var)),
    }
}
/// Rewrites the body of a method (if it has one) statement by statement through
/// `run_statement`; every other part of the method is moved over unchanged.
fn run_method<'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    method: Function2<
        'input,
        ast::Instruction<ast::ParsedOperand<SpirvWord>>,
        ast::ParsedOperand<SpirvWord>,
    >,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
    let Function2 {
        func_decl,
        globals,
        body,
        import_as,
        tuning,
        linkage,
    } = method;
    let body = match body {
        None => None,
        Some(statements) => {
            let mut expanded = Vec::with_capacity(statements.len());
            for statement in statements {
                run_statement(resolver, &mut expanded, statement)?;
            }
            Some(expanded)
        }
    };
    Ok(Function2 {
        func_decl,
        globals,
        body,
        import_as,
        tuning,
        linkage,
    })
}
/// Expands the operands of one statement.
///
/// Helper statements produced while visiting operands are pushed to `result` first, then
/// the rewritten statement itself. Statements queued in the visitor's `post_stmts` follow
/// when the visitor is dropped at the end of this function (also on the error path).
fn run_statement<'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
    statement: UnconditionalStatement,
) -> Result<(), TranslateError> {
    let mut flattener = FlattenArguments::new(resolver, result);
    let expanded = statement.visit_map(&mut flattener)?;
    flattener.result.push(expanded);
    Ok(())
}
// Operand visitor that replaces parsed operands with plain identifiers, emitting the helper
// statements needed to compute them.
struct FlattenArguments<'a, 'input> {
// Output statement list; helper statements are pushed here before the statement that uses them.
result: &'a mut Vec<ExpandedStatement>,
// Used to register the unnamed temporaries created during expansion and to look up types.
resolver: &'a mut GlobalStringIdentResolver2<'input>,
// Statements that must follow the visited statement (vector unpacking of destinations);
// appended to `result` when the visitor is dropped.
post_stmts: Vec<ExpandedStatement>,
}
impl<'a, 'input> FlattenArguments<'a, 'input> {
    /// Creates a visitor that writes generated statements to `result` and registers
    /// temporaries in `resolver`.
    fn new(
        resolver: &'a mut GlobalStringIdentResolver2<'input>,
        result: &'a mut Vec<ExpandedStatement>,
    ) -> Self {
        FlattenArguments {
            result,
            resolver,
            post_stmts: Vec::new(),
        }
    }

    /// A plain register operand needs no expansion.
    fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
        Ok(name)
    }

    /// Expands a `reg+offset` operand into a temporary holding the computed value.
    ///
    /// * In `.reg` space the offset is materialized as a constant of the register's scalar
    ///   type and added to the register with an integer `add`.
    /// * In any other space the offset is materialized as an `s64` constant and combined
    ///   with the pointer through a `PtrAccess` statement.
    ///
    /// # Errors
    /// `UntypedSymbol` when `type_space` is `None`; a mismatched-type error when the
    /// register is not a `.reg` scalar; `error_unreachable()` for non-integer scalar kinds.
    fn reg_offset(
        &mut self,
        reg: SpirvWord,
        offset: i32,
        type_space: Option<(&ast::Type, ast::StateSpace)>,
        _is_dst: bool,
    ) -> Result<SpirvWord, TranslateError> {
        let (type_, state_space) = type_space.ok_or(TranslateError::UntypedSymbol)?;
        if state_space == ast::StateSpace::Reg {
            let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
            if *reg_space != ast::StateSpace::Reg {
                return Err(error_mismatched_type());
            }
            let reg_scalar_type = match reg_type {
                ast::Type::Scalar(underlying_type) => *underlying_type,
                _ => return Err(error_mismatched_type()),
            };
            let reg_type = reg_type.clone();
            let id_constant_stmt = self
                .resolver
                .register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
            self.result.push(Statement::Constant(ConstantDefinition {
                dst: id_constant_stmt,
                typ: reg_scalar_type,
                value: ast::ImmediateValue::S64(i64::from(offset)),
            }));
            // All integer-like kinds use the same non-saturating integer addition.
            let arith_details = match reg_scalar_type.kind() {
                ast::ScalarKind::Signed | ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
                    ast::ArithDetails::Integer(ast::ArithInteger {
                        type_: reg_scalar_type,
                        saturate: false,
                    })
                }
                _ => return Err(error_unreachable()),
            };
            let id_add_result = self
                .resolver
                .register_unnamed(Some((reg_type, state_space)));
            self.result
                .push(Statement::Instruction(ast::Instruction::Add {
                    data: arith_details,
                    arguments: ast::AddArgs {
                        dst: id_add_result,
                        src1: reg,
                        src2: id_constant_stmt,
                    },
                }));
            Ok(id_add_result)
        } else {
            let id_constant_stmt = self.resolver.register_unnamed(Some((
                ast::Type::Scalar(ast::ScalarType::S64),
                ast::StateSpace::Reg,
            )));
            self.result.push(Statement::Constant(ConstantDefinition {
                dst: id_constant_stmt,
                typ: ast::ScalarType::S64,
                value: ast::ImmediateValue::S64(i64::from(offset)),
            }));
            let dst = self
                .resolver
                .register_unnamed(Some((type_.clone(), state_space)));
            self.result.push(Statement::PtrAccess(PtrAccess {
                underlying_type: type_.clone(),
                state_space,
                dst,
                ptr_src: reg,
                offset_src: id_constant_stmt,
            }));
            Ok(dst)
        }
    }

    /// Materializes an immediate operand as a `Constant` statement and returns the
    /// temporary that holds it.
    ///
    /// # Errors
    /// `UntypedSymbol` when the expected type is missing or is not a scalar.
    fn immediate(
        &mut self,
        value: ast::ImmediateValue,
        type_space: Option<(&ast::Type, ast::StateSpace)>,
    ) -> Result<SpirvWord, TranslateError> {
        let (scalar_t, state_space) = match type_space {
            Some((ast::Type::Scalar(scalar), state_space)) => (*scalar, state_space),
            _ => return Err(TranslateError::UntypedSymbol),
        };
        let id = self
            .resolver
            .register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
        self.result.push(Statement::Constant(ConstantDefinition {
            dst: id,
            typ: scalar_t,
            value,
        }));
        Ok(id)
    }

    /// Expands a read of a single vector member (`vec.x`) into a `VectorAccess` statement
    /// and returns the scalar temporary that receives the member.
    ///
    /// # Errors
    /// Mismatched-type error when used as a destination or when `vector_src` is not a
    /// vector.
    fn vec_member(
        &mut self,
        vector_src: SpirvWord,
        member: u8,
        _type_space: Option<(&ast::Type, ast::StateSpace)>,
        is_dst: bool,
    ) -> Result<SpirvWord, TranslateError> {
        if is_dst {
            return Err(error_mismatched_type());
        }
        let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? {
            (ast::Type::Vector(vector_width, scalar_t), space) => {
                (*vector_width, *scalar_t, *space)
            }
            _ => return Err(error_mismatched_type()),
        };
        let temporary = self
            .resolver
            .register_unnamed(Some((scalar_type.into(), space)));
        self.result.push(Statement::VectorAccess(VectorAccess {
            scalar_type,
            vector_width,
            dst: temporary,
            src: vector_src,
            member,
        }));
        Ok(temporary)
    }

    /// Expands a brace-packed vector operand (`{a, b, ...}`) into a temporary vector plus a
    /// `RepackVector` statement.
    ///
    /// For source operands the packing statement is emitted before the visited statement;
    /// for destination operands the unpacking statement is queued in `post_stmts` so it runs
    /// after it.
    ///
    /// # Errors
    /// Mismatched-type error when the expected type is missing or is not a vector.
    fn vec_pack(
        &mut self,
        vecs: Vec<SpirvWord>,
        type_space: Option<(&ast::Type, ast::StateSpace)>,
        is_dst: bool,
        relaxed_type_check: bool,
    ) -> Result<SpirvWord, TranslateError> {
        let (width, scalar_t, state_space) = match type_space {
            Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
            _ => return Err(error_mismatched_type()),
        };
        // The temporary holds the whole packed vector, so it is registered with the vector
        // type rather than the type of a single element.
        let temp_vec = self
            .resolver
            .register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
        let statement = Statement::RepackVector(RepackVectorDetails {
            is_extract: is_dst,
            typ: scalar_t,
            packed: temp_vec,
            unpacked: vecs,
            relaxed_type_check,
        });
        if is_dst {
            self.post_stmts.push(statement);
        } else {
            self.result.push(statement);
        }
        Ok(temp_vec)
    }
}
// Hooks `FlattenArguments` into the AST visitor: every parsed operand is mapped to a plain
// `SpirvWord`, with helper statements emitted by the inherent methods as a side effect.
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
for FlattenArguments<'a, 'b>
{
// Dispatches on the operand form to the matching expansion method.
fn visit(
&mut self,
args: ast::ParsedOperand<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
match args {
ast::ParsedOperand::Reg(r) => self.reg(r),
ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
ast::ParsedOperand::RegOffset(reg, offset) => {
self.reg_offset(reg, offset, type_space, is_dst)
}
ast::ParsedOperand::VecMember(vec, member) => {
self.vec_member(vec, member, type_space, is_dst)
}
ast::ParsedOperand::VecPack(vecs) => {
self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
}
}
}
// Bare identifiers are handled like plain registers: returned unchanged.
fn visit_ident(
&mut self,
name: <TypedOperand as ast::Operand>::Ident,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
_is_dst: bool,
_relaxed_type_check: bool,
) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
self.reg(name)
}
}
/// Flushes the queued post-statements into the output list when the visitor goes away, so
/// they land after the statement that was visited.
impl Drop for FlattenArguments<'_, '_> {
    fn drop(&mut self) {
        // `append` moves every queued statement and leaves `post_stmts` empty.
        self.result.append(&mut self.post_stmts);
    }
}

View File

@ -273,7 +273,6 @@ fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
ast::StateSpace::Const => "const", ast::StateSpace::Const => "const",
ast::StateSpace::Local => "local", ast::StateSpace::Local => "local",
ast::StateSpace::Param => "param", ast::StateSpace::Param => "param",
ast::StateSpace::Sreg => "sreg",
ast::StateSpace::SharedCluster => "shared_cluster", ast::StateSpace::SharedCluster => "shared_cluster",
ast::StateSpace::ParamEntry => "param_entry", ast::StateSpace::ParamEntry => "param_entry",
ast::StateSpace::SharedCta => "shared_cta", ast::StateSpace::SharedCta => "shared_cta",

View File

@ -0,0 +1,209 @@
use super::*;
/// Entry point of the special-register pass.
///
/// First emits one body-less, `EXTERN`-linked function declaration per entry produced by
/// `SpecialRegistersMap2::generate_declarations`, remembering which function belongs to which
/// special register. Then runs every input directive through `SpecialRegisterResolver`.
///
/// The returned list contains the generated declarations followed by the rewritten
/// directives, in their original order.
///
/// # Errors
/// `error_unreachable()` if a generated declaration is not a `MethodName::Func`; otherwise
/// the first error reported while rewriting a directive.
pub(super) fn run<'a, 'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    special_registers: &'a SpecialRegistersMap2,
    directives: Vec<UnconditionalDirective<'input>>,
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
    let declarations = SpecialRegistersMap2::generate_declarations(resolver);
    let mut result = Vec::with_capacity(declarations.len() + directives.len());
    let mut sreg_to_function =
        FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default());
    for (sreg, declaration) in declarations {
        let name = if let ast::MethodName::Func(name) = declaration.name {
            name
        } else {
            return Err(error_unreachable());
        };
        result.push(UnconditionalDirective::Method(UnconditionalFunction {
            func_decl: declaration,
            globals: Vec::new(),
            body: None,
            import_as: None,
            tuning: Vec::new(),
            linkage: ast::LinkingDirective::EXTERN,
        }));
        sreg_to_function.insert(sreg, name);
    }
    let mut visitor = SpecialRegisterResolver {
        resolver,
        special_registers,
        sreg_to_function,
        result: Vec::new(),
    };
    // Append the rewritten directives after the declarations. Collecting them into a fresh
    // vector instead would silently drop the declarations built above.
    for directive in directives {
        result.push(run_directive(&mut visitor, directive)?);
    }
    Ok(result)
}
/// Rewrites one directive: methods are processed by `run_method`, variables are returned
/// untouched.
fn run_directive<'a, 'input>(
    visitor: &mut SpecialRegisterResolver<'a, 'input>,
    directive: UnconditionalDirective<'input>,
) -> Result<UnconditionalDirective<'input>, TranslateError> {
    match directive {
        Directive2::Method(method) => run_method(visitor, method).map(Directive2::Method),
        variable @ Directive2::Variable(..) => Ok(variable),
    }
}
/// Rewrites every statement of the method body (when present) through `run_statement`;
/// the declaration and all other fields are moved over unchanged.
fn run_method<'a, 'input>(
    visitor: &mut SpecialRegisterResolver<'a, 'input>,
    method: UnconditionalFunction<'input>,
) -> Result<UnconditionalFunction<'input>, TranslateError> {
    let Function2 {
        func_decl,
        globals,
        body,
        import_as,
        tuning,
        linkage,
    } = method;
    let body = match body {
        None => None,
        Some(statements) => {
            let mut rewritten = Vec::with_capacity(statements.len());
            for statement in statements {
                run_statement(visitor, &mut rewritten, statement)?;
            }
            Some(rewritten)
        }
    };
    Ok(Function2 {
        func_decl,
        globals,
        body,
        import_as,
        tuning,
        linkage,
    })
}
/// Rewrites a single statement and appends it to `result`, preceded by any
/// helper statements the visitor queued while visiting its operands.
fn run_statement<'a, 'input>(
    visitor: &mut SpecialRegisterResolver<'a, 'input>,
    result: &mut Vec<UnconditionalStatement>,
    statement: UnconditionalStatement,
) -> Result<(), TranslateError> {
    let rewritten = statement.visit_map(visitor)?;
    // Queued statements must come first: they define the values that
    // `rewritten` now reads. `append` leaves the visitor's queue empty.
    result.append(&mut visitor.result);
    result.push(rewritten);
    Ok(())
}
/// Operand visitor that swaps reads of special registers for calls to the
/// functions declared for them.
struct SpecialRegisterResolver<'a, 'input> {
// Used to allocate fresh unnamed ids for constants and call results.
resolver: &'a mut GlobalStringIdentResolver2<'input>,
// Looked up by identifier to decide whether an operand is a special register.
special_registers: &'a SpecialRegistersMap2,
// Name of the function declared for each special register.
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
// Statements queued while visiting one statement; `run_statement` drains
// them into the output ahead of the rewritten statement.
result: Vec<UnconditionalStatement>,
}
impl<'a, 'input>
    ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
    for SpecialRegisterResolver<'a, 'input>
{
    /// Rewrites one operand, replacing any special register it names with the
    /// result of a call queued in `self.result`.
    ///
    /// A vector-member access whose base is *not* a special register is
    /// returned unchanged. `map_operand` turns every `VecMember` into a plain
    /// `Reg`, which is only correct when the member index has been consumed
    /// as the argument of a special-register call; for an ordinary vector it
    /// would silently discard the member index.
    fn visit(
        &mut self,
        operand: ast::ParsedOperand<SpirvWord>,
        _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
        is_dst: bool,
        _relaxed_type_check: bool,
    ) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
        match operand {
            ast::ParsedOperand::VecMember(ident, member)
                if self.special_registers.get(ident).is_none() =>
            {
                Ok(ast::ParsedOperand::VecMember(ident, member))
            }
            operand => map_operand(operand, &mut |ident, vector_index| {
                self.replace_sreg(ident, vector_index, is_dst)
            }),
        }
    }

    /// Rewrites a bare identifier; it is treated as an access without a
    /// vector index.
    fn visit_ident(
        &mut self,
        args: SpirvWord,
        _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
        is_dst: bool,
        _relaxed_type_check: bool,
    ) -> Result<SpirvWord, TranslateError> {
        self.replace_sreg(args, None, is_dst)
    }
}
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
fn replace_sreg(
&mut self,
name: SpirvWord,
vector_index: Option<u8>,
is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
if let Some(sreg) = self.special_registers.get(name) {
if is_dst {
return Err(error_mismatched_type());
}
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
(Some(idx), Some(inp_type)) => {
if inp_type != ast::ScalarType::U8 {
return Err(TranslateError::Unreachable);
}
let constant = self.resolver.register_unnamed(Some((
ast::Type::Scalar(inp_type),
ast::StateSpace::Reg,
)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: constant,
typ: inp_type,
value: ast::ImmediateValue::U64(idx as u64),
}));
vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
}
(None, None) => Vec::new(),
_ => return Err(error_mismatched_type()),
};
let return_type = sreg.get_function_return_type();
let fn_result = self
.resolver
.register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
let return_arguments = vec![(
fn_result,
ast::Type::Scalar(return_type),
ast::StateSpace::Reg,
)];
let data = ast::CallDetails {
uniform: false,
return_arguments: return_arguments
.iter()
.map(|(_, typ, space)| (typ.clone(), *space))
.collect(),
input_arguments: input_arguments
.iter()
.map(|(_, typ, space)| (typ.clone(), *space))
.collect(),
};
let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
func: self.sreg_to_function[&sreg],
input_arguments: input_arguments
.iter()
.map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
.collect(),
};
self.result
.push(Statement::Instruction(ast::Instruction::Call {
data,
arguments,
}));
Ok(fn_result)
} else {
Ok(name)
}
}
}
/// Applies `fn_` to every identifier inside `this` and rebuilds the operand
/// from the mapped identifiers.
///
/// `fn_` receives the identifier and, for a vector-member access only, the
/// member index. Immediates are passed through without calling `fn_`.
///
/// NOTE(review): a `VecMember` is always rebuilt as a plain `Reg`, so the
/// member index survives only if `fn_` consumes it - confirm that callers
/// never pass vector members that `fn_` leaves untouched.
pub fn map_operand<T, U, Err>(
    this: ast::ParsedOperand<T>,
    fn_: &mut impl FnMut(T, Option<u8>) -> Result<U, Err>,
) -> Result<ast::ParsedOperand<U>, Err> {
    match this {
        ast::ParsedOperand::Imm(value) => Ok(ast::ParsedOperand::Imm(value)),
        ast::ParsedOperand::Reg(reg) => fn_(reg, None).map(ast::ParsedOperand::Reg),
        ast::ParsedOperand::RegOffset(reg, offset) => {
            let mapped = fn_(reg, None)?;
            Ok(ast::ParsedOperand::RegOffset(mapped, offset))
        }
        ast::ParsedOperand::VecMember(vector, member) => {
            fn_(vector, Some(member)).map(ast::ParsedOperand::Reg)
        }
        ast::ParsedOperand::VecPack(components) => {
            let mut mapped = Vec::with_capacity(components.len());
            for component in components {
                mapped.push(fn_(component, None)?);
            }
            Ok(ast::ParsedOperand::VecPack(mapped))
        }
    }
}

View File

@ -0,0 +1,45 @@
use super::*;
/// Moves variables declared inside function bodies in the global, const or
/// shared state space out to module level.
///
/// Each hoisted variable is emitted as a module-level directive immediately
/// before the method it was taken from; all other directives keep their
/// relative order.
///
/// # Errors
/// Propagates any error reported while processing a directive.
pub(super) fn run<'input>(
    directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
    let mut result = Vec::with_capacity(directives.len());
    for mut directive in directives {
        // Propagate failures instead of discarding the returned `Result`.
        run_directive(&mut result, &mut directive)?;
        result.push(directive);
    }
    Ok(result)
}
/// Hoists variables out of `directive` when it is a method; module-level
/// variables need no work. Hoisted variables are appended to `result`.
fn run_directive<'input>(
    result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
    directive: &mut Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
    match directive {
        Directive2::Method(function) => run_function(result, function),
        // Already at module scope.
        Directive2::Variable(..) => {}
    }
    Ok(())
}
/// Removes every global, const or shared variable declaration from the body
/// of `function` and appends it to `result` as a module-level variable with
/// no linkage. Functions without a body are left as they are.
fn run_function<'input>(
    result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
    function: &mut Function2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
) {
    let statements = match function.body.take() {
        Some(statements) => statements,
        None => return,
    };
    let mut kept = Vec::with_capacity(statements.len());
    for statement in statements {
        match statement {
            Statement::Variable(
                var @ ast::Variable {
                    state_space:
                        ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared,
                    ..
                },
            ) => result.push(Directive2::Variable(ast::LinkingDirective::NONE, var)),
            other => kept.push(other),
        }
    }
    function.body = Some(kept);
}

View File

@ -0,0 +1,338 @@
use super::*;
use ptx_parser::VisitorMap;
use rustc_hash::FxHashSet;
// This pass:
// * Turns all .local, .param and .reg in-body variables into .local variables
// (if _not_ an input method argument)
// * Inserts explicit `ld`/`st` for newly converted .reg variables
// * Fixes up the state space of all existing `ld`/`st` instructions that
// access newly converted variables
// * Turns `.entry` input arguments into param::entry and all related `.param`
// loads into `param::entry` loads
// * All `.func` input arguments are turned into `.reg` arguments by another
// pass, so we do nothing there
/// Runs the explicit load/store pass over every directive of the module,
/// stopping at the first directive that fails.
pub(super) fn run<'a, 'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
    let mut converted = Vec::with_capacity(directives.len());
    for directive in directives {
        converted.push(run_directive(resolver, directive)?);
    }
    Ok(converted)
}
/// Processes one directive. Every method gets its own fresh visitor, so
/// variable remappings never leak from one method into another; module-level
/// variables are returned unchanged.
fn run_directive<'a, 'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
    match directive {
        Directive2::Method(method) => {
            let visitor = InsertMemSSAVisitor::new(resolver);
            run_method(visitor, method).map(Directive2::Method)
        }
        variable @ Directive2::Variable(..) => Ok(variable),
    }
}
/// Rewrites one method: converts its return arguments, renames kernel input
/// arguments into the `ParamEntry` state space, and then rewrites every
/// statement of the body (if any).
///
/// # Errors
/// Fails if a return argument or body statement cannot be converted, or if a
/// kernel input argument is not declared in the `Param` state space.
fn run_method<'a, 'input>(
    mut visitor: InsertMemSSAVisitor<'a, 'input>,
    method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
    let mut func_decl = method.func_decl;
    for arg in func_decl.return_arguments.iter_mut() {
        visitor.visit_variable(arg)?;
    }
    if func_decl.name.is_kernel() {
        for arg in func_decl.input_arguments.iter_mut() {
            let old_name = arg.name;
            let old_space = arg.state_space;
            let new_space = ast::StateSpace::ParamEntry;
            let new_name = visitor
                .resolver
                .register_unnamed(Some((arg.v_type.clone(), new_space)));
            // Must be propagated: ignoring this result would rename an
            // argument for which no remapping was recorded.
            visitor.input_argument(old_name, new_name, old_space)?;
            arg.name = new_name;
            arg.state_space = new_space;
        }
    }
    let body = method
        .body
        .map(move |statements| {
            let mut result = Vec::with_capacity(statements.len());
            for statement in statements {
                run_statement(&mut visitor, &mut result, statement)?;
            }
            Ok::<_, TranslateError>(result)
        })
        .transpose()?;
    Ok(Function2 {
        func_decl,
        globals: method.globals,
        body,
        import_as: method.import_as,
        tuning: method.tuning,
        linkage: method.linkage,
    })
}
/// Rewrites one statement and appends it to `result`.
///
/// Variable declarations are converted in place. Every other statement is
/// visited and emitted between the instructions the visitor queued in `pre`
/// and `post`. `ld` and `st` are first given a chance to have their state
/// space and address operand remapped, then visited like any instruction.
fn run_statement<'a, 'input>(
    visitor: &mut InsertMemSSAVisitor<'a, 'input>,
    result: &mut Vec<ExpandedStatement>,
    statement: ExpandedStatement,
) -> Result<(), TranslateError> {
    let rewritten = match statement {
        Statement::Variable(mut var) => {
            visitor.visit_variable(&mut var)?;
            result.push(Statement::Variable(var));
            return Ok(());
        }
        Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
            let remapped = visitor.visit_ld(data, arguments)?;
            Statement::Instruction(ast::visit_map(remapped, visitor)?)
        }
        Statement::Instruction(ast::Instruction::St { data, arguments }) => {
            let remapped = visitor.visit_st(data, arguments)?;
            Statement::Instruction(ast::visit_map(remapped, visitor)?)
        }
        other => other.visit_map(visitor)?,
    };
    // Queued loads go before the statement, queued stores after it.
    result.extend(visitor.pre.drain(..).map(Statement::Instruction));
    result.push(rewritten);
    result.extend(visitor.post.drain(..).map(Statement::Instruction));
    Ok(())
}
/// Per-method state of the explicit load/store pass.
struct InsertMemSSAVisitor<'a, 'input> {
// Used to allocate fresh unnamed ids for renamed variables and temporaries.
resolver: &'a mut GlobalStringIdentResolver2<'input>,
// Keyed by a variable's original name: how accesses to it must be rewritten.
variables: FxHashMap<SpirvWord, RemapAction>,
// Instructions to emit before the statement currently being visited.
pre: Vec<ast::Instruction<SpirvWord>>,
// Instructions to emit after the statement currently being visited.
post: Vec<ast::Instruction<SpirvWord>>,
}
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
    /// Creates a visitor with no remapped variables and empty `pre`/`post`
    /// queues.
    fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
        Self {
            resolver,
            variables: FxHashMap::default(),
            pre: Vec::new(),
            post: Vec::new(),
        }
    }

    /// Records that the kernel input argument `old_name` is now called
    /// `new_name` and lives in the `ParamEntry` state space.
    ///
    /// # Errors
    /// Fails if the argument was not declared in the `Param` state space.
    fn input_argument(
        &mut self,
        old_name: SpirvWord,
        new_name: SpirvWord,
        old_space: ast::StateSpace,
    ) -> Result<(), TranslateError> {
        if old_space != ast::StateSpace::Param {
            return Err(error_unreachable());
        }
        self.variables.insert(
            old_name,
            RemapAction::LDStSpaceChange {
                name: new_name,
                old_space,
                new_space: ast::StateSpace::ParamEntry,
            },
        );
        Ok(())
    }

    /// Records how accesses to the variable `old_name` must be rewritten now
    /// that it has been renamed to `new_name`.
    ///
    /// `Reg` variables get explicit loads/stores around each use, `Param`
    /// variables have their `ld`/`st` retargeted to `Local`. No remapping is
    /// recorded for any other state space.
    ///
    /// # Errors
    /// Fails for `ParamEntry` and `ParamFunc`, which are not expected here.
    fn variable(
        &mut self,
        type_: &ast::Type,
        old_name: SpirvWord,
        new_name: SpirvWord,
        old_space: ast::StateSpace,
    ) -> Result<(), TranslateError> {
        match old_space {
            ast::StateSpace::Reg => {
                self.variables.insert(
                    old_name,
                    RemapAction::PreLdPostSt {
                        name: new_name,
                        type_: type_.clone(),
                    },
                );
            }
            ast::StateSpace::Param => {
                self.variables.insert(
                    old_name,
                    RemapAction::LDStSpaceChange {
                        old_space,
                        new_space: ast::StateSpace::Local,
                        name: new_name,
                    },
                );
            }
            // Good as-is
            ast::StateSpace::Local => {}
            // Will be pulled into global scope later
            ast::StateSpace::Generic
            | ast::StateSpace::SharedCluster
            | ast::StateSpace::Global
            | ast::StateSpace::Const
            | ast::StateSpace::SharedCta
            | ast::StateSpace::Shared => {}
            ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
                return Err(error_unreachable())
            }
        }
        Ok(())
    }

    /// Retargets an `st` whose address operand is a variable that changed
    /// state space; any other `st` is returned unchanged.
    ///
    /// # Errors
    /// Fails if the instruction's state space differs from the variable's
    /// original state space.
    fn visit_st(
        &self,
        mut data: ast::StData,
        mut arguments: ast::StArgs<SpirvWord>,
    ) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
        if let Some(RemapAction::LDStSpaceChange {
            old_space,
            new_space,
            name,
        }) = self.variables.get(&arguments.src1)
        {
            if data.state_space != *old_space {
                return Err(error_mismatched_type());
            }
            data.state_space = *new_space;
            arguments.src1 = *name;
        }
        Ok(ast::Instruction::St { data, arguments })
    }

    /// Retargets an `ld` whose address operand is a variable that changed
    /// state space; any other `ld` is returned unchanged.
    ///
    /// # Errors
    /// Fails if the instruction's state space differs from the variable's
    /// original state space.
    fn visit_ld(
        &self,
        mut data: ast::LdDetails,
        mut arguments: ast::LdArgs<SpirvWord>,
    ) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
        if let Some(RemapAction::LDStSpaceChange {
            old_space,
            new_space,
            name,
        }) = self.variables.get(&arguments.src)
        {
            if data.state_space != *old_space {
                return Err(error_mismatched_type());
            }
            data.state_space = *new_space;
            arguments.src = *name;
        }
        Ok(ast::Instruction::Ld { data, arguments })
    }

    /// Converts a `Reg` or `Param` variable declaration into a freshly named
    /// `Local` variable and records the remapping for its later uses.
    ///
    /// Variables in every other state space are left exactly as declared.
    /// Only `Reg` and `Param` have a remapping, so renaming anything else
    /// would leave its uses pointing at a name that is no longer declared,
    /// and would turn global/const/shared variables into `Local` ones.
    ///
    /// # Errors
    /// Fails for `ParamEntry` and `ParamFunc` declarations.
    fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
        let old_space = var.state_space;
        match old_space {
            ast::StateSpace::Reg | ast::StateSpace::Param => {}
            // Already local, or handled by a later pass: keep name and space.
            ast::StateSpace::Local
            | ast::StateSpace::Generic
            | ast::StateSpace::SharedCluster
            | ast::StateSpace::Global
            | ast::StateSpace::Const
            | ast::StateSpace::SharedCta
            | ast::StateSpace::Shared => return Ok(()),
            ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
                return Err(error_unreachable())
            }
        }
        let old_name = var.name;
        let new_space = ast::StateSpace::Local;
        let new_name = self
            .resolver
            .register_unnamed(Some((var.v_type.clone(), new_space)));
        self.variable(&var.v_type, old_name, new_name, old_space)?;
        var.name = new_name;
        var.state_space = new_space;
        Ok(())
    }
}
impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
for InsertMemSSAVisitor<'a, 'input>
{
fn visit(
&mut self,
ident: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
if let Some(remap) = self.variables.get(&ident) {
match remap {
RemapAction::PreLdPostSt { name, type_ } => {
if is_dst {
let temp = self
.resolver
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
self.post.push(ast::Instruction::St {
data: ast::StData {
state_space: ast::StateSpace::Local,
qualifier: ast::LdStQualifier::Weak,
caching: ast::StCacheOperator::Writethrough,
typ: type_.clone(),
},
arguments: ast::StArgs {
src1: *name,
src2: temp,
},
});
Ok(temp)
} else {
let temp = self
.resolver
.register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
self.pre.push(ast::Instruction::Ld {
data: ast::LdDetails {
state_space: ast::StateSpace::Local,
qualifier: ast::LdStQualifier::Weak,
caching: ast::LdCacheOperator::Cached,
typ: type_.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
dst: temp,
src: *name,
},
});
Ok(temp)
}
}
RemapAction::LDStSpaceChange { .. } => {
return Err(error_mismatched_type());
}
}
} else {
Ok(ident)
}
}
fn visit_ident(
&mut self,
args: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
self.visit(args, type_space, is_dst, relaxed_type_check)
}
}
/// How accesses to a renamed variable have to be rewritten.
#[derive(Clone)]
enum RemapAction {
// Operand uses are replaced by a temporary of type `type_`, with a load
// from `name` queued before a read and a store to `name` queued after a write.
PreLdPostSt {
name: SpirvWord,
type_: ast::Type,
},
// `ld`/`st` instructions addressing the variable are switched from
// `old_space` to `new_space` and made to address `name` instead.
LDStSpaceChange {
old_space: ast::StateSpace,
new_space: ast::StateSpace,
name: SpirvWord,
},
}

View File

@ -45,6 +45,13 @@ pub(super) fn run(
Statement::RepackVector(repack), Statement::RepackVector(repack),
)?; )?;
} }
Statement::VectorAccess(vector_access) => {
insert_implicit_conversions_impl(
&mut result,
id_def,
Statement::VectorAccess(vector_access),
)?;
}
s @ Statement::Conditional(_) s @ Statement::Conditional(_)
| s @ Statement::Conversion(_) | s @ Statement::Conversion(_)
| s @ Statement::Label(_) | s @ Statement::Label(_)
@ -128,7 +135,7 @@ pub(crate) fn default_implicit_conversion(
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> { ) -> Result<Option<ConversionKind>, TranslateError> {
if instruction_space == ast::StateSpace::Reg { if instruction_space == ast::StateSpace::Reg {
if space_is_compatible(operand_space, ast::StateSpace::Reg) { if operand_space == ast::StateSpace::Reg {
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) = if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
(operand_type, instruction_type) (operand_type, instruction_type)
{ {
@ -142,7 +149,7 @@ pub(crate) fn default_implicit_conversion(
return Ok(Some(ConversionKind::AddressOf)); return Ok(Some(ConversionKind::AddressOf));
} }
} }
if !space_is_compatible(instruction_space, operand_space) { if instruction_space != operand_space {
default_implicit_conversion_space( default_implicit_conversion_space(
(operand_space, operand_type), (operand_space, operand_type),
(instruction_space, instruction_type), (instruction_space, instruction_type),
@ -161,7 +168,7 @@ fn is_addressable(this: ast::StateSpace) -> bool {
| ast::StateSpace::Global | ast::StateSpace::Global
| ast::StateSpace::Local | ast::StateSpace::Local
| ast::StateSpace::Shared => true, | ast::StateSpace::Shared => true,
ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false, ast::StateSpace::Param | ast::StateSpace::Reg => false,
ast::StateSpace::SharedCluster ast::StateSpace::SharedCluster
| ast::StateSpace::SharedCta | ast::StateSpace::SharedCta
| ast::StateSpace::ParamEntry | ast::StateSpace::ParamEntry
@ -178,7 +185,7 @@ fn default_implicit_conversion_space(
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space)) || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
{ {
Ok(Some(ConversionKind::PtrToPtr)) Ok(Some(ConversionKind::PtrToPtr))
} else if space_is_compatible(operand_space, ast::StateSpace::Reg) { } else if operand_space == ast::StateSpace::Reg {
match operand_type { match operand_type {
ast::Type::Pointer(operand_ptr_type, operand_ptr_space) ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
if *operand_ptr_space == instruction_space => if *operand_ptr_space == instruction_space =>
@ -210,7 +217,7 @@ fn default_implicit_conversion_space(
}, },
_ => Err(error_mismatched_type()), _ => Err(error_mismatched_type()),
} }
} else if space_is_compatible(instruction_space, ast::StateSpace::Reg) { } else if instruction_space == ast::StateSpace::Reg {
match instruction_type { match instruction_type {
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
if operand_space == *instruction_ptr_space => if operand_space == *instruction_ptr_space =>
@ -234,7 +241,7 @@ fn default_implicit_conversion_type(
operand_type: &ast::Type, operand_type: &ast::Type,
instruction_type: &ast::Type, instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> { ) -> Result<Option<ConversionKind>, TranslateError> {
if space_is_compatible(space, ast::StateSpace::Reg) { if space == ast::StateSpace::Reg {
if should_bitcast(instruction_type, operand_type) { if should_bitcast(instruction_type, operand_type) {
Ok(Some(ConversionKind::Default)) Ok(Some(ConversionKind::Default))
} else { } else {
@ -257,8 +264,7 @@ fn coerces_to_generic(this: ast::StateSpace) -> bool {
| ast::StateSpace::Param | ast::StateSpace::Param
| ast::StateSpace::ParamEntry | ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc | ast::StateSpace::ParamFunc
| ast::StateSpace::Generic | ast::StateSpace::Generic => false,
| ast::StateSpace::Sreg => false,
} }
} }
@ -294,7 +300,7 @@ pub(crate) fn should_convert_relaxed_dst_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type), (operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> { ) -> Result<Option<ConversionKind>, TranslateError> {
if !space_is_compatible(operand_space, instruction_space) { if operand_space != instruction_space {
return Err(TranslateError::MismatchedType); return Err(TranslateError::MismatchedType);
} }
if operand_type == instruction_type { if operand_type == instruction_type {
@ -371,7 +377,7 @@ pub(crate) fn should_convert_relaxed_src_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type), (operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> { ) -> Result<Option<ConversionKind>, TranslateError> {
if !space_is_compatible(operand_space, instruction_space) { if operand_space != instruction_space {
return Err(error_mismatched_type()); return Err(error_mismatched_type());
} }
if operand_type == instruction_type { if operand_type == instruction_type {

View File

@ -0,0 +1,426 @@
use std::mem;
use super::*;
use ptx_parser as ast;
/*
There are several kinds of implicit conversions in PTX:
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
semantics are to first zext/chop/bitcast `y` as needed and then do
documented special ld/st/cvt conversion rules for destination operands
- st.param [x] y (used as function return arguments) same rule as above applies
- generic/global ld: for instruction `ld x, [y]`, y must be of type
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
documented special ld/st/cvt conversion rules are applied to dst
- generic/global st: for instruction `st [x], y`, x must be of type
b64/u64/s64, which is bitcast to a pointer
*/
/// Inserts implicit conversions into every method of the module, stopping at
/// the first directive that fails.
pub(super) fn run<'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
    let mut converted = Vec::with_capacity(directives.len());
    for directive in directives {
        converted.push(run_directive(resolver, directive)?);
    }
    Ok(converted)
}
/// Processes one directive: the body of a method (if present) is rewritten,
/// module-level variables are returned unchanged.
fn run_directive<'a, 'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
    match directive {
        Directive2::Method(mut method) => {
            if let Some(statements) = method.body.take() {
                method.body = Some(run_statements(resolver, statements)?);
            }
            Ok(Directive2::Method(method))
        }
        variable @ Directive2::Variable(..) => Ok(variable),
    }
}
/// Rewrites a function body, returning the statements with all required
/// conversion statements interleaved.
fn run_statements<'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    func: Vec<ExpandedStatement>,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
    let mut result = Vec::with_capacity(func.len());
    func.into_iter().try_for_each(|statement| {
        insert_implicit_conversions_impl(resolver, &mut result, statement)
    })?;
    Ok(result)
}
/// Appends `stmt` to `func`, surrounded by whatever implicit conversions its
/// operands need.
///
/// Each operand that carries type information is compared against the type
/// and state space registered for it in `resolver`. When a conversion is
/// needed, the operand is replaced by a fresh id typed as the instruction
/// expects:
/// * source operand: `operand -> fresh id` is emitted before the statement;
/// * destination operand: `fresh id -> operand` is emitted after it.
fn insert_implicit_conversions_impl<'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    func: &mut Vec<ExpandedStatement>,
    stmt: ExpandedStatement,
) -> Result<(), TranslateError> {
    let mut post_conv = Vec::new();
    let statement = stmt.visit_map::<SpirvWord, TranslateError>(
        &mut |operand,
              type_state: Option<(&ast::Type, ast::StateSpace)>,
              is_dst,
              relaxed_type_check| {
            let (instr_type, instruction_space) = match type_state {
                Some(type_state) => type_state,
                // Untyped operands never need a conversion.
                None => return Ok(operand),
            };
            let (operand_type, operand_space) = resolver.get_typed(operand)?;
            let conversion_fn = if !relaxed_type_check {
                default_implicit_conversion
            } else if is_dst {
                should_convert_relaxed_dst_wrapper
            } else {
                should_convert_relaxed_src_wrapper
            };
            let conv_kind = match conversion_fn(
                (*operand_space, &operand_type),
                (instruction_space, instr_type),
            )? {
                Some(conv_kind) => conv_kind,
                None => return Ok(operand),
            };
            // Copy the operand's type and space out before asking the
            // resolver for a fresh id, as the original ordering did.
            let operand_type = operand_type.clone();
            let operand_space = *operand_space;
            let temp = resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
            if is_dst {
                post_conv.push(Statement::Conversion(ImplicitConversion {
                    src: temp,
                    dst: operand,
                    from_type: instr_type.clone(),
                    from_space: instruction_space,
                    to_type: operand_type,
                    to_space: operand_space,
                    kind: conv_kind,
                }));
            } else {
                func.push(Statement::Conversion(ImplicitConversion {
                    src: operand,
                    dst: temp,
                    from_type: operand_type,
                    from_space: operand_space,
                    to_type: instr_type.clone(),
                    to_space: instruction_space,
                    kind: conv_kind,
                }));
            }
            Ok::<_, TranslateError>(temp)
        },
    )?;
    func.push(statement);
    func.append(&mut post_conv);
    Ok(())
}
/// Decides which conversion (if any) turns an operand of the given state
/// space and type into what the instruction expects, under the default
/// (non-relaxed) rules.
///
/// Returns `Ok(None)` when the operand can be used as is.
pub(crate) fn default_implicit_conversion(
    (operand_space, operand_type): (ast::StateSpace, &ast::Type),
    (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
    let instruction_in_reg = instruction_space == ast::StateSpace::Reg;
    let operand_in_reg = operand_space == ast::StateSpace::Reg;
    if instruction_in_reg && operand_in_reg {
        // A whole vector may be read as one bit-typed scalar of the same size.
        match (operand_type, instruction_type) {
            (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar))
                if scalar.kind() == ast::ScalarKind::Bit
                    && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) =>
            {
                return Ok(Some(ConversionKind::Default));
            }
            _ => {}
        }
    } else if instruction_in_reg && is_addressable(operand_space) {
        // A register is expected but the operand names addressable memory.
        return Ok(Some(ConversionKind::AddressOf));
    }
    if instruction_space != operand_space {
        return default_implicit_conversion_space(
            (operand_space, operand_type),
            (instruction_space, instruction_type),
        );
    }
    if instruction_type != operand_type {
        default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
    } else {
        Ok(None)
    }
}
/// Tells whether a variable in state space `this` is converted with
/// `AddressOf` when an instruction expects a register operand.
///
/// Not implemented yet for `SharedCluster`, `SharedCta`, `ParamEntry` and
/// `ParamFunc`; those panic via `todo!()`.
fn is_addressable(this: ast::StateSpace) -> bool {
    match this {
        ast::StateSpace::Reg | ast::StateSpace::Param => false,
        ast::StateSpace::Const
        | ast::StateSpace::Generic
        | ast::StateSpace::Global
        | ast::StateSpace::Local
        | ast::StateSpace::Shared => true,
        ast::StateSpace::SharedCluster
        | ast::StateSpace::SharedCta
        | ast::StateSpace::ParamEntry
        | ast::StateSpace::ParamFunc => todo!(),
    }
}
// Space is different
fn default_implicit_conversion_space(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
{
Ok(Some(ConversionKind::PtrToPtr))
} else if operand_space == ast::StateSpace::Reg {
match operand_type {
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
if *operand_ptr_space == instruction_space =>
{
if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
Ok(Some(ConversionKind::PtrToPtr))
} else {
Ok(None)
}
}
// TODO: 32 bit
ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
ast::StateSpace::Global
| ast::StateSpace::Generic
| ast::StateSpace::Const
| ast::StateSpace::Local
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
_ => Err(error_mismatched_type()),
},
ast::Type::Scalar(ast::ScalarType::B32)
| ast::Type::Scalar(ast::ScalarType::U32)
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
Ok(Some(ConversionKind::BitToPtr))
}
_ => Err(error_mismatched_type()),
},
_ => Err(error_mismatched_type()),
}
} else if instruction_space == ast::StateSpace::Reg {
match instruction_type {
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
if operand_space == *instruction_ptr_space =>
{
if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
Ok(Some(ConversionKind::PtrToPtr))
} else {
Ok(None)
}
}
_ => Err(error_mismatched_type()),
}
} else {
Err(error_mismatched_type())
}
}
// Handles the case where the state space matches but the types differ:
// outside registers the access is retyped with a pointer cast, inside
// registers only a bitcast-compatible pair is accepted.
fn default_implicit_conversion_type(
    space: ast::StateSpace,
    operand_type: &ast::Type,
    instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> {
    if space != ast::StateSpace::Reg {
        return Ok(Some(ConversionKind::PtrToPtr));
    }
    if should_bitcast(instruction_type, operand_type) {
        Ok(Some(ConversionKind::Default))
    } else {
        Err(TranslateError::MismatchedType)
    }
}
/// Tells whether a pointer into state space `this` may be converted to or
/// from a generic pointer with a pointer-to-pointer cast.
fn coerces_to_generic(this: ast::StateSpace) -> bool {
    match this {
        ast::StateSpace::Reg
        | ast::StateSpace::Param
        | ast::StateSpace::ParamEntry
        | ast::StateSpace::ParamFunc
        | ast::StateSpace::Generic => false,
        ast::StateSpace::Global
        | ast::StateSpace::Const
        | ast::StateSpace::Local
        | ast::StateSpace::SharedCta
        | ast::StateSpace::SharedCluster
        | ast::StateSpace::Shared => true,
    }
}
/// Tells whether an operand of type `operand` may be bit-cast to the
/// instruction type `instr`.
///
/// Scalars must have the same size and a compatible kind pairing. Vectors
/// and arrays are compared by their element type only.
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
    match (instr, operand) {
        (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
            if inst.size_of() != operand.size_of() {
                return false;
            }
            let operand_kind = operand.kind();
            match inst.kind() {
                ast::ScalarKind::Bit => operand_kind != ast::ScalarKind::Bit,
                ast::ScalarKind::Float => operand_kind == ast::ScalarKind::Bit,
                ast::ScalarKind::Signed => matches!(
                    operand_kind,
                    ast::ScalarKind::Bit | ast::ScalarKind::Unsigned
                ),
                ast::ScalarKind::Unsigned => matches!(
                    operand_kind,
                    ast::ScalarKind::Bit | ast::ScalarKind::Signed
                ),
                ast::ScalarKind::Pred => false,
            }
        }
        (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
        | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
            let inst_element = ast::Type::Scalar(*inst);
            let operand_element = ast::Type::Scalar(*operand);
            should_bitcast(&inst_element, &operand_element)
        }
        _ => false,
    }
}
/// Relaxed type check for a destination operand.
///
/// Returns `Ok(None)` when the types already match, the required conversion
/// when the relaxed rules allow one, and a mismatched-type error when the
/// state spaces differ or no conversion is allowed.
pub(crate) fn should_convert_relaxed_dst_wrapper(
    (operand_space, operand_type): (ast::StateSpace, &ast::Type),
    (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
    if operand_space != instruction_space {
        return Err(TranslateError::MismatchedType);
    }
    if operand_type == instruction_type {
        return Ok(None);
    }
    should_convert_relaxed_dst(operand_type, instruction_type)
        .map(Some)
        .ok_or(TranslateError::MismatchedType)
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
//
// Returns the conversion needed to write a value of `instr_type` into a
// destination of `dst_type`, or `None` when the types are identical or the
// combination is not allowed. Vectors and arrays are judged by element type.
fn should_convert_relaxed_dst(
    dst_type: &ast::Type,
    instr_type: &ast::Type,
) -> Option<ConversionKind> {
    if dst_type == instr_type {
        return None;
    }
    match (dst_type, instr_type) {
        (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
            ast::ScalarKind::Bit => {
                (instr_type.size_of() <= dst_type.size_of()).then(|| ConversionKind::Default)
            }
            ast::ScalarKind::Signed if dst_type.kind() == ast::ScalarKind::Float => None,
            ast::ScalarKind::Signed => match instr_type.size_of().cmp(&dst_type.size_of()) {
                std::cmp::Ordering::Equal => Some(ConversionKind::Default),
                // A wider destination receives the sign-extended value.
                std::cmp::Ordering::Less => Some(ConversionKind::SignExtend),
                std::cmp::Ordering::Greater => None,
            },
            ast::ScalarKind::Unsigned => (instr_type.size_of() <= dst_type.size_of()
                && dst_type.kind() != ast::ScalarKind::Float)
                .then(|| ConversionKind::Default),
            ast::ScalarKind::Float => (instr_type.size_of() <= dst_type.size_of()
                && dst_type.kind() == ast::ScalarKind::Bit)
                .then(|| ConversionKind::Default),
            ast::ScalarKind::Pred => None,
        },
        (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
        | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
            let dst_element = ast::Type::Scalar(*dst_type);
            let instr_element = ast::Type::Scalar(*instr_type);
            should_convert_relaxed_dst(&dst_element, &instr_element)
        }
        _ => None,
    }
}
pub(crate) fn should_convert_relaxed_src_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if operand_space != instruction_space {
return Err(error_mismatched_type());
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(error_mismatched_type()),
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
//
// Returns the conversion needed to read a source of `src_type` as
// `instr_type`, or `None` when the types are identical or the combination is
// not allowed. Vectors and arrays are judged by element type.
fn should_convert_relaxed_src(
    src_type: &ast::Type,
    instr_type: &ast::Type,
) -> Option<ConversionKind> {
    if src_type == instr_type {
        return None;
    }
    match (src_type, instr_type) {
        (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
            ast::ScalarKind::Bit => {
                (instr_type.size_of() <= src_type.size_of()).then(|| ConversionKind::Default)
            }
            ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => (instr_type.size_of()
                <= src_type.size_of()
                && src_type.kind() != ast::ScalarKind::Float)
                .then(|| ConversionKind::Default),
            ast::ScalarKind::Float => (instr_type.size_of() <= src_type.size_of()
                && src_type.kind() == ast::ScalarKind::Bit)
                .then(|| ConversionKind::Default),
            ast::ScalarKind::Pred => None,
        },
        (ast::Type::Vector(_, src_type), ast::Type::Vector(_, instr_type))
        | (ast::Type::Array(_, src_type, _), ast::Type::Array(_, instr_type, _)) => {
            let src_element = ast::Type::Scalar(*src_type);
            let instr_element = ast::Type::Scalar(*instr_type);
            should_convert_relaxed_src(&src_element, &instr_element)
        }
        _ => None,
    }
}

View File

@ -189,7 +189,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
return Ok(symbol); return Ok(symbol);
}; };
let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?; let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
if !space_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable { if var_space != ast::StateSpace::Reg || !is_variable {
return Ok(symbol); return Ok(symbol);
}; };
let member_index = match member_index { let member_index = match member_index {

View File

@ -1,5 +1,6 @@
use ptx_parser as ast; use ptx_parser as ast;
use rspirv::{binary::Assemble, dr}; use rspirv::{binary::Assemble, dr};
use rustc_hash::FxHashMap;
use std::hash::Hash; use std::hash::Hash;
use std::num::NonZeroU8; use std::num::NonZeroU8;
use std::{ use std::{
@ -12,20 +13,31 @@ use std::{
mem, mem,
rc::Rc, rc::Rc,
}; };
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
mod convert_dynamic_shared_memory_usage; mod convert_dynamic_shared_memory_usage;
mod convert_to_stateful_memory_access; mod convert_to_stateful_memory_access;
mod convert_to_typed; mod convert_to_typed;
mod deparamize_functions;
pub(crate) mod emit_llvm; pub(crate) mod emit_llvm;
mod emit_spirv; mod emit_spirv;
mod expand_arguments; mod expand_arguments;
mod expand_operands;
mod extract_globals; mod extract_globals;
mod fix_special_registers; mod fix_special_registers;
mod fix_special_registers2;
mod hoist_globals;
mod insert_explicit_load_store;
mod insert_implicit_conversions; mod insert_implicit_conversions;
mod insert_implicit_conversions2;
mod insert_mem_ssa_statements; mod insert_mem_ssa_statements;
mod normalize_identifiers; mod normalize_identifiers;
mod normalize_identifiers2;
mod normalize_labels; mod normalize_labels;
mod normalize_predicates; mod normalize_predicates;
mod normalize_predicates2;
mod resolve_function_pointers;
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
@ -57,7 +69,30 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
})?; })?;
normalize_variable_decls(&mut directives); normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives); let denorm_information = compute_denorm_information(&directives);
let llvm_ir = emit_llvm::run(&id_defs, call_map, directives)?; todo!()
/*
let llvm_ir: emit_llvm::MemoryBuffer = emit_llvm::run(&id_defs, call_map, directives)?;
Ok(Module {
llvm_ir,
kernel_info: HashMap::new(),
}) */
}
pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
let directives = resolve_function_pointers::run(directives)?;
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
let directives: Vec<Directive2<'_, ptx_parser::Instruction<SpirvWord>, SpirvWord>> =
expand_operands::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
let directives = hoist_globals::run(directives)?;
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
Ok(Module { Ok(Module {
llvm_ir, llvm_ir,
kernel_info: HashMap::new(), kernel_info: HashMap::new(),
@ -319,7 +354,7 @@ pub struct KernelInfo {
pub uses_shared_mem: bool, pub uses_shared_mem: bool,
} }
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)]
enum PtxSpecialRegister { enum PtxSpecialRegister {
Tid, Tid,
Ntid, Ntid,
@ -342,6 +377,17 @@ impl PtxSpecialRegister {
} }
} }
fn as_str(self) -> &'static str {
match self {
Self::Tid => "%tid",
Self::Ntid => "%ntid",
Self::Ctaid => "%ctaid",
Self::Nctaid => "%nctaid",
Self::Clock => "%clock",
Self::LanemaskLt => "%lanemask_lt",
}
}
fn get_type(self) -> ast::Type { fn get_type(self) -> ast::Type {
match self { match self {
PtxSpecialRegister::Tid PtxSpecialRegister::Tid
@ -525,7 +571,7 @@ impl<'b> NumericIdResolver<'b> {
Some(Some(x)) => Ok(x.clone()), Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol), Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(id) { None => match self.special_registers.get(id) {
Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)), Some(x) => Ok((x.get_type(), ast::StateSpace::Reg, true)),
None => match self.global_type_check.get(&id) { None => match self.global_type_check.get(&id) {
Some(Some(result)) => Ok(result.clone()), Some(Some(result)) => Ok(result.clone()),
Some(None) | None => Err(TranslateError::UntypedSymbol), Some(None) | None => Err(TranslateError::UntypedSymbol),
@ -722,6 +768,7 @@ enum Statement<I, P: ast::Operand> {
PtrAccess(PtrAccess<P>), PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails), RepackVector(RepackVectorDetails),
FunctionPointer(FunctionPointerDetails), FunctionPointer(FunctionPointerDetails),
VectorAccess(VectorAccess),
} }
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> { impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
@ -890,6 +937,36 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
offset_src, offset_src,
}) })
} }
Statement::VectorAccess(VectorAccess {
scalar_type,
vector_width,
dst,
src: vector_src,
member,
}) => {
let dst: SpirvWord = visitor.visit_ident(
dst,
Some((&scalar_type.into(), ast::StateSpace::Reg)),
true,
false,
)?;
let src = visitor.visit_ident(
vector_src,
Some((
&ast::Type::Vector(vector_width, scalar_type),
ast::StateSpace::Reg,
)),
false,
false,
)?;
Statement::VectorAccess(VectorAccess {
scalar_type,
vector_width,
dst,
src,
member,
})
}
Statement::RepackVector(RepackVectorDetails { Statement::RepackVector(RepackVectorDetails {
is_extract, is_extract,
typ, typ,
@ -1207,12 +1284,6 @@ impl<
} }
} }
fn space_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
this == other
|| this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
|| this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
}
fn register_external_fn_call<'a>( fn register_external_fn_call<'a>(
id_defs: &mut NumericIdResolver, id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>, ptx_impl_imports: &mut HashMap<String, Directive>,
@ -1450,6 +1521,7 @@ fn compute_denorm_information<'input>(
Statement::Label(_) => {} Statement::Label(_) => {}
Statement::Variable(_) => {} Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {} Statement::PtrAccess { .. } => {}
Statement::VectorAccess { .. } => {}
Statement::RepackVector(_) => {} Statement::RepackVector(_) => {}
Statement::FunctionPointer(_) => {} Statement::FunctionPointer(_) => {}
} }
@ -1663,3 +1735,278 @@ fn denorm_count_map_update_impl<T: Eq + Hash>(
} }
} }
} }
/// A single directive of a translated module, generic over the instruction
/// and operand representation used by the current pass.
pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> {
// A variable declaration together with its linking directive.
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
// A method (kernel or function), see `Function2`.
Method(Function2<'input, Instruction, Operand>),
}
/// A method as carried between passes: its declaration plus an optional body
/// of `Statement`s over the pass-specific instruction and operand types.
pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> {
// Name, return arguments and input arguments of the method.
pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
// Variables attached to this method; `normalize_identifiers2` creates it
// empty. NOTE(review): presumably filled by a later pass -- confirm.
pub globals: Vec<ast::Variable<SpirvWord>>,
// Statements of the method. NOTE(review): `None` presumably means a
// declaration without a body -- confirm.
pub body: Option<Vec<Statement<Instruction, Operand>>>,
// `normalize_identifiers2` sets this to `None`; the other passes visible here
// forward it unchanged.
import_as: Option<String>,
// Tuning directives copied from the parsed function.
tuning: Vec<ast::TuningDirective>,
// Linking directive copied from the parsed directive.
linkage: ast::LinkingDirective,
}
/// Directive as produced by `normalize_identifiers2::run`: identifiers are
/// `SpirvWord`s and every instruction is paired with its optional predicate.
type NormalizedDirective2<'input> = Directive2<
'input,
(
Option<ast::PredAt<SpirvWord>>,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
),
ast::ParsedOperand<SpirvWord>,
>;
/// Function counterpart of `NormalizedDirective2`.
type NormalizedFunction2<'input> = Function2<
'input,
(
Option<ast::PredAt<SpirvWord>>,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
),
ast::ParsedOperand<SpirvWord>,
>;
/// Directive as produced by `normalize_predicates2::run`: instructions no
/// longer carry a predicate.
type UnconditionalDirective<'input> = Directive2<
'input,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>;
/// Function counterpart of `UnconditionalDirective`.
type UnconditionalFunction<'input> = Function2<
'input,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>;
/// Flat (scope-less) table of every identifier known to the module, keyed by
/// its numeric id.
struct GlobalStringIdentResolver2<'input> {
// Next id that will be handed out; incremented on every registration.
pub(crate) current_id: SpirvWord,
// Optional name and optional (type, state space) for each registered id.
pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
}
impl<'input> GlobalStringIdentResolver2<'input> {
fn new(spirv_word: SpirvWord) -> Self {
Self {
current_id: spirv_word,
ident_map: FxHashMap::default(),
}
}
fn register_named(
&mut self,
name: Cow<'input, str>,
type_space: Option<(ast::Type, ast::StateSpace)>,
) -> SpirvWord {
let new_id = self.current_id;
self.ident_map.insert(
new_id,
IdentEntry {
name: Some(name),
type_space,
},
);
self.current_id.0 += 1;
new_id
}
fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
let new_id = self.current_id;
self.ident_map.insert(
new_id,
IdentEntry {
name: None,
type_space,
},
);
self.current_id.0 += 1;
new_id
}
fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> {
match self.ident_map.get(&id) {
Some(IdentEntry {
type_space: Some(type_space),
..
}) => Ok(type_space),
_ => Err(error_unknown_symbol()),
}
}
}
/// What is recorded about one identifier.
struct IdentEntry<'input> {
// Source-level or generated name; `None` for identifiers registered unnamed.
name: Option<Cow<'input, str>>,
// Type and state space, when the identifier was registered with them.
type_space: Option<(ast::Type, ast::StateSpace)>,
}
/// Name resolver with lexical scoping layered on top of a flat resolver.
struct ScopedResolver<'input, 'b> {
// Supplies fresh ids and receives each scope's entries when the scope ends.
flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
// Stack of open scopes; the last element is the innermost one.
scopes: Vec<ScopeMarker<'input>>,
}
impl<'input, 'b> ScopedResolver<'input, 'b> {
fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
Self {
flat_resolver,
scopes: vec![ScopeMarker::new()],
}
}
fn start_scope(&mut self) {
self.scopes.push(ScopeMarker::new());
}
fn end_scope(&mut self) {
let scope = self.scopes.pop().unwrap();
scope.flush(self.flat_resolver);
}
fn add(
&mut self,
name: Cow<'input, str>,
type_space: Option<(ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
let result = self.flat_resolver.current_id;
self.flat_resolver.current_id.0 += 1;
let current_scope = self.scopes.last_mut().unwrap();
if current_scope
.name_to_ident
.insert(name.clone(), result)
.is_some()
{
return Err(error_unknown_symbol());
}
current_scope.ident_map.insert(
result,
IdentEntry {
name: Some(name),
type_space,
},
);
Ok(result)
}
fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
self.scopes
.iter()
.rev()
.find_map(|resolver| resolver.name_to_ident.get(name).copied())
.ok_or_else(|| error_unreachable())
}
fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
let current_scope = self.scopes.last().unwrap();
current_scope
.name_to_ident
.get(label)
.copied()
.ok_or_else(|| error_unreachable())
}
}
/// Identifiers declared inside one lexical scope.
struct ScopeMarker<'input> {
// Entries for the ids declared in this scope; moved into the flat resolver by `flush`.
ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
// Name lookup for this scope; dropped when the scope is flushed.
name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
}
impl<'input> ScopeMarker<'input> {
fn new() -> Self {
Self {
ident_map: FxHashMap::default(),
name_to_ident: FxHashMap::default(),
}
}
fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
resolver.ident_map.extend(self.ident_map);
}
}
/// Two-way mapping between PTX special registers and their numeric ids.
struct SpecialRegistersMap2 {
// Special register -> id.
reg_to_id: FxHashMap<PtxSpecialRegister, SpirvWord>,
// Id -> special register; kept in sync with `reg_to_id`.
id_to_reg: FxHashMap<SpirvWord, PtxSpecialRegister>,
}
impl SpecialRegistersMap2 {
/// Declares every `PtxSpecialRegister` in the resolver's innermost scope
/// under its `as_str()` name, typed with `get_type()` in the `Reg` state
/// space, and records the id assigned to each.
///
/// Fails if any of the names is already declared in that scope.
fn new(resolver: &mut ScopedResolver) -> Result<Self, TranslateError> {
let mut result = SpecialRegistersMap2 {
reg_to_id: FxHashMap::default(),
id_to_reg: FxHashMap::default(),
};
for sreg in PtxSpecialRegister::iter() {
let text = sreg.as_str();
let id = resolver.add(
Cow::Borrowed(text),
Some((sreg.get_type(), ast::StateSpace::Reg)),
)?;
result.reg_to_id.insert(sreg, id);
result.id_to_reg.insert(id, sreg);
}
Ok(result)
}
/// Returns the special register that `id` denotes, if any.
fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
self.id_to_reg.get(&id).copied()
}
/// Returns the id of `reg`, allocating a new one from `current_id` (and
/// advancing it) when `reg` has no id yet.
fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
match self.reg_to_id.entry(reg) {
hash_map::Entry::Occupied(e) => *e.get(),
hash_map::Entry::Vacant(e) => {
let numeric_id = SpirvWord(current_id.0);
current_id.0 += 1;
e.insert(numeric_id);
self.id_to_reg.insert(numeric_id, reg);
numeric_id
}
}
}
/// Lazily builds, for every special register, the declaration of the
/// function named `ZLUDA_PTX_PREFIX` + `get_unprefixed_function_name()`.
///
/// Each declaration has one return argument and one input argument in the
/// `Reg` state space. Ids are registered in `resolver` as the iterator is
/// consumed, not when this function is called.
fn generate_declarations<'a, 'input>(
resolver: &'a mut GlobalStringIdentResolver2<'input>,
) -> impl ExactSizeIterator<
Item = (
PtxSpecialRegister,
ast::MethodDeclaration<'input, SpirvWord>,
),
> + 'a {
PtxSpecialRegister::iter().map(|sreg| {
let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
let name =
ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
let return_type = sreg.get_function_return_type();
// NOTE(review): the input argument type is taken from
// `get_function_return_type()`, same as `return_type` above. This looks
// like a copy-paste slip; confirm whether a dedicated input-type query
// on `PtxSpecialRegister` should be used instead.
let input_type = sreg.get_function_return_type();
(
sreg,
ast::MethodDeclaration {
return_arguments: vec![ast::Variable {
align: None,
v_type: return_type.into(),
state_space: ast::StateSpace::Reg,
name: resolver
.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
array_init: Vec::new(),
}],
name: name,
input_arguments: vec![ast::Variable {
align: None,
v_type: input_type.into(),
state_space: ast::StateSpace::Reg,
name: resolver
.register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))),
array_init: Vec::new(),
}],
shared_mem: None,
},
)
})
}
}
/// Payload of `Statement::VectorAccess`: reads one member of a vector
/// register into a scalar register.
pub struct VectorAccess {
// Element type of the vector; also the type of `dst`.
scalar_type: ast::ScalarType,
// Number of elements in the source vector type.
vector_width: u8,
// Destination register, visited as a `scalar_type` value.
dst: SpirvWord,
// Source register, visited as a `vector_width` x `scalar_type` vector.
src: SpirvWord,
// NOTE(review): presumably the index of the selected element -- confirm.
member: u8,
}

View File

@ -0,0 +1,199 @@
use super::*;
use ptx_parser as ast;
use rustc_hash::FxHashMap;
/// Replaces every string identifier in the parsed module with a numeric id.
///
/// All directives are processed inside one freshly opened scope, which is
/// closed again once every directive has been converted. Stops at the first
/// directive that fails.
pub(crate) fn run<'input, 'b>(
    resolver: &mut ScopedResolver<'input, 'b>,
    directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
    resolver.start_scope();
    let mut normalized = Vec::with_capacity(directives.len());
    for directive in directives {
        normalized.push(run_directive(resolver, directive)?);
    }
    resolver.end_scope();
    Ok(normalized)
}
/// Converts one parsed directive, dispatching on whether it is a variable or
/// a method.
fn run_directive<'input, 'b>(
    resolver: &mut ScopedResolver<'input, 'b>,
    directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<NormalizedDirective2<'input>, TranslateError> {
    let normalized = match directive {
        ast::Directive::Variable(linking, var) => {
            let var = run_variable(resolver, var)?;
            NormalizedDirective2::Variable(linking, var)
        }
        ast::Directive::Method(linking, method) => {
            let method = run_method(resolver, linking, method)?;
            NormalizedDirective2::Method(method)
        }
    };
    Ok(normalized)
}
/// Converts one method.
///
/// A function name is declared in the enclosing scope; a kernel name is kept
/// as text. Arguments and body are then processed inside a new scope of
/// their own.
fn run_method<'input, 'b>(
    resolver: &mut ScopedResolver<'input, 'b>,
    linkage: ast::LinkingDirective,
    method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<NormalizedFunction2<'input>, TranslateError> {
    let name = match method.func_directive.name {
        ast::MethodName::Kernel(kernel_name) => ast::MethodName::Kernel(kernel_name),
        ast::MethodName::Func(func_name) => {
            let id = resolver.add(Cow::Borrowed(func_name), None)?;
            ast::MethodName::Func(id)
        }
    };
    resolver.start_scope();
    let func_decl = run_function_decl(resolver, method.func_directive, name)?;
    let body = match method.body {
        Some(statements) => {
            let mut normalized = Vec::with_capacity(statements.len());
            run_statements(resolver, &mut normalized, statements)?;
            Some(normalized)
        }
        None => None,
    };
    resolver.end_scope();
    Ok(Function2 {
        func_decl,
        globals: Vec::new(),
        body,
        import_as: None,
        tuning: method.tuning,
        linkage,
    })
}
/// Converts a method declaration, declaring its return arguments first and
/// its input arguments second in the current scope.
///
/// Panics if the parsed declaration already carries `shared_mem`.
fn run_function_decl<'input, 'b>(
    resolver: &mut ScopedResolver<'input, 'b>,
    func_directive: ast::MethodDeclaration<'input, &'input str>,
    name: ast::MethodName<'input, SpirvWord>,
) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
    assert!(func_directive.shared_mem.is_none());
    let mut return_arguments = Vec::new();
    for var in func_directive.return_arguments {
        return_arguments.push(run_variable(resolver, var)?);
    }
    let mut input_arguments = Vec::new();
    for var in func_directive.input_arguments {
        input_arguments.push(run_variable(resolver, var)?);
    }
    Ok(ast::MethodDeclaration {
        return_arguments,
        name,
        input_arguments,
        shared_mem: None,
    })
}
/// Declares the variable's name, with its type and state space, in the
/// current scope and returns the same variable keyed by the new id.
fn run_variable<'input, 'b>(
    resolver: &mut ScopedResolver<'input, 'b>,
    variable: ast::Variable<&'input str>,
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
    let ast::Variable {
        align,
        v_type,
        state_space,
        name,
        array_init,
    } = variable;
    let name = resolver.add(Cow::Borrowed(name), Some((v_type.clone(), state_space)))?;
    Ok(ast::Variable {
        align,
        v_type,
        state_space,
        name,
        array_init,
    })
}
/// Converts a list of statements, appending the results to `result`.
///
/// Labels of this statement list are declared up front, so instructions can
/// refer to labels that appear later in the same list. Nested blocks are
/// flattened into `result`, each processed inside its own scope.
fn run_statements<'input, 'b>(
    resolver: &mut ScopedResolver<'input, 'b>,
    result: &mut Vec<NormalizedStatement>,
    statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<(), TranslateError> {
    // First pass: declare every label of this list.
    for statement in statements.iter() {
        if let ast::Statement::Label(label) = statement {
            resolver.add(Cow::Borrowed(*label), None)?;
        }
    }
    // Second pass: convert the statements in order.
    for statement in statements {
        match statement {
            ast::Statement::Label(label) => {
                let id = resolver.get_in_current_scope(label)?;
                result.push(Statement::Label(id));
            }
            ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
            ast::Statement::Instruction(predicate, instruction) => {
                let predicate = match predicate {
                    Some(pred) => Some(ast::PredAt {
                        not: pred.not,
                        label: resolver.get(pred.label)?,
                    }),
                    None => None,
                };
                let instruction = run_instruction(resolver, instruction)?;
                result.push(Statement::Instruction((predicate, instruction)));
            }
            ast::Statement::Block(block) => {
                resolver.start_scope();
                run_statements(resolver, result, block)?;
                resolver.end_scope();
            }
        }
    }
    Ok(())
}
/// Rewrites every identifier of `instruction` to its numeric id by looking
/// the name up through the open scopes.
///
/// Fails with the resolver's error on the first name that cannot be resolved.
fn run_instruction<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
// Only the name matters here; the type/state-space hint and the two
// trailing visitor arguments are ignored.
ast::visit_map(instruction, &mut |name: &'input str,
_: Option<(
&ast::Type,
ast::StateSpace,
)>,
_,
_| {
resolver.get(&name)
})
}
/// Expands a (possibly parameterized) variable declaration into individual
/// `Statement::Variable`s.
///
/// With a count, one variable is declared per index `0..count`, named by
/// appending the index to the base name; without one, a single variable with
/// the base name is declared.
fn run_multivariable<'input, 'b>(
    resolver: &mut ScopedResolver<'input, 'b>,
    result: &mut Vec<NormalizedStatement>,
    variable: ast::MultiVariable<&'input str>,
) -> Result<(), TranslateError> {
    // Declares one variable called `name` and emits its statement.
    let mut declare = |name: Cow<'input, str>| -> Result<(), TranslateError> {
        let ident = resolver.add(
            name,
            Some((variable.var.v_type.clone(), variable.var.state_space)),
        )?;
        result.push(Statement::Variable(ast::Variable {
            align: variable.var.align,
            v_type: variable.var.v_type.clone(),
            state_space: variable.var.state_space,
            name: ident,
            array_init: variable.var.array_init.clone(),
        }));
        Ok(())
    };
    match variable.count {
        Some(count) => {
            for i in 0..count {
                declare(Cow::Owned(format!("{}{}", variable.var.name, i)))?;
            }
            Ok(())
        }
        None => declare(Cow::Borrowed(variable.var.name)),
    }
}

View File

@ -26,6 +26,7 @@ pub(super) fn run(
| Statement::Constant(..) | Statement::Constant(..)
| Statement::Label(..) | Statement::Label(..)
| Statement::PtrAccess { .. } | Statement::PtrAccess { .. }
| Statement::VectorAccess { .. }
| Statement::RepackVector(..) | Statement::RepackVector(..)
| Statement::FunctionPointer(..) => {} | Statement::FunctionPointer(..) => {}
} }

View File

@ -0,0 +1,84 @@
use super::*;
use ptx_parser as ast;
/// Turns predicated instructions into explicit conditional branches for every
/// method in the module. Stops at the first directive that fails.
pub(crate) fn run<'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    directives: Vec<NormalizedDirective2<'input>>,
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
    let mut converted = Vec::with_capacity(directives.len());
    for directive in directives {
        converted.push(run_directive(resolver, directive)?);
    }
    Ok(converted)
}
/// Passes variables through unchanged and converts methods.
fn run_directive<'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    directive: NormalizedDirective2<'input>,
) -> Result<UnconditionalDirective<'input>, TranslateError> {
    match directive {
        Directive2::Variable(linking, var) => Ok(Directive2::Variable(linking, var)),
        Directive2::Method(method) => {
            let method = run_method(resolver, method)?;
            Ok(Directive2::Method(method))
        }
    }
}
/// Converts the body of one method (if it has one); every other field is
/// carried over as is.
fn run_method<'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    method: NormalizedFunction2<'input>,
) -> Result<UnconditionalFunction<'input>, TranslateError> {
    let Function2 {
        func_decl,
        globals,
        body,
        import_as,
        tuning,
        linkage,
    } = method;
    let body = match body {
        Some(statements) => {
            let mut converted = Vec::with_capacity(statements.len());
            for statement in statements {
                run_statement(resolver, &mut converted, statement)?;
            }
            Some(converted)
        }
        None => None,
    };
    Ok(Function2 {
        func_decl,
        globals,
        body,
        import_as,
        tuning,
        linkage,
    })
}
/// Converts one statement, appending the replacement statements to `result`.
///
/// Labels, variables and unpredicated instructions are copied through. A
/// predicated instruction becomes a conditional branch followed by an
/// "if true" label, the instruction, and an "if false" label. A predicated
/// `bra` is folded: the conditional branch jumps straight to its target and
/// only the "if false" label is emitted. Any other statement kind is an error.
fn run_statement<'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    result: &mut Vec<UnconditionalStatement>,
    statement: NormalizedStatement,
) -> Result<(), TranslateError> {
    match statement {
        Statement::Label(label) => result.push(Statement::Label(label)),
        Statement::Variable(var) => result.push(Statement::Variable(var)),
        Statement::Instruction((None, instruction)) => {
            result.push(Statement::Instruction(instruction));
        }
        Statement::Instruction((Some(pred), instruction)) => {
            // Both labels are always allocated, even when `if_true` ends up unused.
            let if_true = resolver.register_unnamed(None);
            let if_false = resolver.register_unnamed(None);
            let folded_bra = match &instruction {
                ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
                _ => None,
            };
            let taken = folded_bra.unwrap_or(if_true);
            // A negated predicate swaps the two branch targets.
            let (on_true, on_false) = if pred.not {
                (if_false, taken)
            } else {
                (taken, if_false)
            };
            result.push(Statement::Conditional(BrachCondition {
                predicate: pred.label,
                if_true: on_true,
                if_false: on_false,
            }));
            if folded_bra.is_none() {
                result.push(Statement::Label(if_true));
                result.push(Statement::Instruction(instruction));
            }
            result.push(Statement::Label(if_false));
        }
        _ => return Err(error_unreachable()),
    }
    Ok(())
}

View File

@ -0,0 +1,82 @@
use super::*;
use ptx_parser as ast;
use rustc_hash::FxHashSet;
/// Rewrites `mov` instructions whose source is a function name into
/// `FunctionPointer` statements.
///
/// Function names are collected while walking the directives in order, so
/// only functions that appear at or before the current method are recognized.
pub(crate) fn run<'input>(
    directives: Vec<UnconditionalDirective<'input>>,
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
    let mut functions = FxHashSet::default();
    let mut resolved = Vec::with_capacity(directives.len());
    for directive in directives {
        resolved.push(run_directive(&mut functions, directive)?);
    }
    Ok(resolved)
}
/// Passes variables through unchanged. For a method, records its id in
/// `functions` when it is a function (kernels are not recorded) and then
/// rewrites its body.
fn run_directive<'input>(
    functions: &mut FxHashSet<SpirvWord>,
    directive: UnconditionalDirective<'input>,
) -> Result<UnconditionalDirective<'input>, TranslateError> {
    match directive {
        Directive2::Method(method) => {
            // Registered before the body is processed, so a function can
            // take a pointer to itself.
            if let ptx_parser::MethodName::Func(name) = method.func_decl.name {
                functions.insert(name);
            }
            let method = run_method(functions, method)?;
            Ok(Directive2::Method(method))
        }
        var @ Directive2::Variable(..) => Ok(var),
    }
}
/// Rewrites the statements of one method body (if it has one); every other
/// field is carried over as is.
fn run_method<'input>(
    functions: &mut FxHashSet<SpirvWord>,
    method: UnconditionalFunction<'input>,
) -> Result<UnconditionalFunction<'input>, TranslateError> {
    let Function2 {
        func_decl,
        globals,
        body,
        import_as,
        tuning,
        linkage,
    } = method;
    let body = match body {
        Some(statements) => {
            let mut rewritten = Vec::with_capacity(statements.len());
            for statement in statements {
                rewritten.push(run_statement(functions, statement)?);
            }
            Some(rewritten)
        }
        None => None,
    };
    Ok(Function2 {
        func_decl,
        globals,
        body,
        import_as,
        tuning,
        linkage,
    })
}
/// Replaces a register-to-register `mov` whose source is a known function
/// with a `FunctionPointer` statement; every other statement is returned
/// unchanged.
///
/// Fails with a mismatched-type error when such a `mov` is not typed `u64`.
fn run_statement<'input>(
    functions: &mut FxHashSet<SpirvWord>,
    statement: UnconditionalStatement,
) -> Result<UnconditionalStatement, TranslateError> {
    match statement {
        Statement::Instruction(ast::Instruction::Mov {
            data,
            arguments:
                ast::MovArgs {
                    dst: ast::ParsedOperand::Reg(dst_reg),
                    src: ast::ParsedOperand::Reg(src_reg),
                },
        }) if functions.contains(&src_reg) => {
            if data.typ == ast::Type::Scalar(ast::ScalarType::U64) {
                Ok(UnconditionalStatement::FunctionPointer(
                    FunctionPointerDetails {
                        dst: dst_reg,
                        src: src_reg,
                    },
                ))
            } else {
                Err(error_mismatched_type())
            }
        }
        other => Ok(other),
    }
}

View File

@ -236,7 +236,7 @@ fn test_hip_assert<
output: &mut [Output], output: &mut [Output],
) -> Result<(), Box<dyn error::Error + 'a>> { ) -> Result<(), Box<dyn error::Error + 'a>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
let llvm_ir = pass::to_llvm_module(ast).unwrap(); let llvm_ir = pass::to_llvm_module2(ast).unwrap();
let name = CString::new(name)?; let name = CString::new(name)?;
let result = let result =
run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?; run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?;

View File

@ -1049,6 +1049,15 @@ impl<'input, ID> MethodName<'input, ID> {
} }
} }
impl<'input> MethodName<'input, &'input str> {
    /// Returns the source text of the method name, whether it names a kernel
    /// or a function.
    pub fn text(&self) -> &'input str {
        match self {
            MethodName::Kernel(name) | MethodName::Func(name) => *name,
        }
    }
}
bitflags! { bitflags! {
pub struct LinkingDirective: u8 { pub struct LinkingDirective: u8 {
const NONE = 0b000; const NONE = 0b000;
@ -1291,7 +1300,12 @@ impl<T: Operand> CallArgs<T> {
.iter() .iter()
.zip(details.return_arguments.iter()) .zip(details.return_arguments.iter())
{ {
visitor.visit_ident(param, Some((type_, *space)), true, false)?; visitor.visit_ident(
param,
Some((type_, *space)),
*space == StateSpace::Reg,
false,
)?;
} }
visitor.visit_ident(&self.func, None, false, false)?; visitor.visit_ident(&self.func, None, false, false)?;
for (param, (type_, space)) in self for (param, (type_, space)) in self
@ -1315,7 +1329,12 @@ impl<T: Operand> CallArgs<T> {
.iter_mut() .iter_mut()
.zip(details.return_arguments.iter()) .zip(details.return_arguments.iter())
{ {
visitor.visit_ident(param, Some((type_, *space)), true, false)?; visitor.visit_ident(
param,
Some((type_, *space)),
*space == StateSpace::Reg,
false,
)?;
} }
visitor.visit_ident(&mut self.func, None, false, false)?; visitor.visit_ident(&mut self.func, None, false, false)?;
for (param, (type_, space)) in self for (param, (type_, space)) in self
@ -1339,7 +1358,12 @@ impl<T: Operand> CallArgs<T> {
.into_iter() .into_iter()
.zip(details.return_arguments.iter()) .zip(details.return_arguments.iter())
.map(|(param, (type_, space))| { .map(|(param, (type_, space))| {
visitor.visit_ident(param, Some((type_, *space)), true, false) visitor.visit_ident(
param,
Some((type_, *space)),
*space == StateSpace::Reg,
false,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let func = visitor.visit_ident(self.func, None, false, false)?; let func = visitor.visit_ident(self.func, None, false, false)?;

View File

@ -1499,7 +1499,6 @@ derive_parser!(
pub enum StateSpace { pub enum StateSpace {
Reg, Reg,
Generic, Generic,
Sreg,
} }
#[derive(Copy, Clone, PartialEq, Eq, Hash)] #[derive(Copy, Clone, PartialEq, Eq, Hash)]