Add support for atom.cas

This commit is contained in:
Andrzej Janik
2024-09-26 21:38:50 +02:00
parent 820eaf8ada
commit f0b3bf8013
4 changed files with 67 additions and 4 deletions

View File

@ -130,4 +130,19 @@ LLVMValueRef LLVMZludaBuildAtomicRMW(LLVMBuilderRef B, LLVMZludaAtomicRMWBinOp o
context.getOrInsertSyncScopeID(scope)));
}
LLVMValueRef LLVMZludaBuildAtomicCmpXchg(LLVMBuilderRef B, LLVMValueRef Ptr,
LLVMValueRef Cmp, LLVMValueRef New,
char *scope,
LLVMAtomicOrdering SuccessOrdering,
LLVMAtomicOrdering FailureOrdering)
{
auto builder = llvm::unwrap(B);
LLVMContext &context = builder->getContext();
return wrap(builder->CreateAtomicCmpXchg(
unwrap(Ptr), unwrap(Cmp), unwrap(New), MaybeAlign(),
mapFromLLVMOrdering(SuccessOrdering),
mapFromLLVMOrdering(FailureOrdering),
context.getOrInsertSyncScopeID(scope)));
}
LLVM_C_EXTERN_C_END

View File

@ -39,4 +39,14 @@ extern "C" {
scope: *const i8,
ordering: LLVMAtomicOrdering,
) -> LLVMValueRef;
pub fn LLVMZludaBuildAtomicCmpXchg(
B: LLVMBuilderRef,
Ptr: LLVMValueRef,
Cmp: LLVMValueRef,
New: LLVMValueRef,
scope: *const i8,
SuccessOrdering: LLVMAtomicOrdering,
FailureOrdering: LLVMAtomicOrdering,
) -> LLVMValueRef;
}

View File

@ -26,7 +26,10 @@ use std::ptr;
use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
use llvm_zluda::{core::*, LLVMAtomicOrdering, LLVMAtomicRMWBinOp, LLVMZludaAtomicRMWBinOp};
use llvm_zluda::{
core::*, LLVMAtomicOrdering, LLVMAtomicRMWBinOp, LLVMZludaAtomicRMWBinOp,
LLVMZludaBuildAtomicCmpXchg,
};
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
@ -457,7 +460,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
ast::Instruction::Selp { data, arguments } => todo!(),
ast::Instruction::Bar { data, arguments } => todo!(),
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
ast::Instruction::AtomCas { data, arguments } => todo!(),
ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
ast::Instruction::Div { data, arguments } => todo!(),
ast::Instruction::Neg { data, arguments } => todo!(),
ast::Instruction::Sin { data, arguments } => todo!(),
@ -724,6 +727,33 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
});
Ok(())
}
fn emit_atom_cas(
&mut self,
data: ptx_parser::AtomCasDetails,
arguments: ptx_parser::AtomCasArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let src3 = self.resolver.value(arguments.src3)?;
let success_ordering = get_ordering(data.semantics);
let failure_ordering = get_ordering_failure(data.semantics);
let temp = unsafe {
LLVMZludaBuildAtomicCmpXchg(
self.builder,
src1,
src2,
src3,
get_scope(data.scope)?,
success_ordering,
failure_ordering,
)
};
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildExtractValue(self.builder, temp, 0, dst)
});
Ok(())
}
}
fn get_pointer_type<'ctx>(
@ -753,6 +783,15 @@ fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
}
}
fn get_ordering_failure(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
match semantics {
ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
}
}
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
Ok(match type_ {
ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),

View File

@ -86,8 +86,7 @@ pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
let directives = resolve_function_pointers::run(directives)?;
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
let directives: Vec<Directive2<'_, ptx_parser::Instruction<SpirvWord>, SpirvWord>> =
expand_operands::run(&mut flat_resolver, directives)?;
let directives = expand_operands::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;