Add support for cp.async (#427)

Adds support for

* `cp.async`
* `cp.async.commit_group`
* `cp.async.wait_group`
* `cp.async.wait_all`

Among AMD GPUs, asynchronous copy operations are only supported by Instinct hardware, so for now we lower them to plain synchronous copies that work everywhere. Because the copy has already completed by the time any synchronization instruction runs, `cp.async.commit_group`, `cp.async.wait_group`, and `cp.async.wait_all` are no-ops.
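
Conceptually, the copy completes on the spot, so the synchronization instructions have nothing left to wait on. A minimal sketch of the lowering strategy (the `Inst` enum and string output here are hypothetical stand-ins for illustration; the real emitter operates on ZLUDA's AST and emits LLVM IR, see `emit_cp_async` below):

```rust
// Hedged sketch of the lowering idea, not ZLUDA's actual types or API.
enum Inst {
    CpAsync { dst: String, src: String, cp_size: u64 },
    CpAsyncCommitGroup,
    CpAsyncWaitGroup { group: u64 },
    CpAsyncWaitAll,
}

fn lower(inst: &Inst) -> Vec<String> {
    match inst {
        // The "asynchronous" copy is emitted as an ordinary synchronous
        // load/store pair, so it has finished before the next instruction.
        Inst::CpAsync { dst, src, cp_size } => vec![
            format!("load {cp_size} bytes from {src}"),
            format!("store {cp_size} bytes to {dst}"),
        ],
        // Nothing is ever in flight, so committing and waiting lower to nothing.
        Inst::CpAsyncCommitGroup
        | Inst::CpAsyncWaitGroup { .. }
        | Inst::CpAsyncWaitAll => vec![],
    }
}
```

The trade-off is purely a performance one (no copy/compute overlap), not a correctness one: the memory contents observed after the wait are identical.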
Violet authored on 2025-07-23 16:25:49 -07:00 (committed by GitHub)
parent 3746079b1a · commit 2b90fdb56c
9 changed files with 296 additions and 2 deletions

View File

@@ -83,6 +83,10 @@ fn run_instruction<'input>(
        | ast::Instruction::Call { .. }
        | ast::Instruction::Clz { .. }
        | ast::Instruction::Cos { .. }
        | ast::Instruction::CpAsync { .. }
        | ast::Instruction::CpAsyncCommitGroup { .. }
        | ast::Instruction::CpAsyncWaitGroup { .. }
        | ast::Instruction::CpAsyncWaitAll { .. }
        | ast::Instruction::Cvt {
            data:
                ast::CvtDetails {

View File

@@ -1819,6 +1819,10 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
        | ast::Instruction::Bfi { .. }
        | ast::Instruction::Shr { .. }
        | ast::Instruction::ShflSync { .. }
        | ast::Instruction::CpAsync { .. }
        | ast::Instruction::CpAsyncCommitGroup { .. }
        | ast::Instruction::CpAsyncWaitGroup { .. }
        | ast::Instruction::CpAsyncWaitAll { .. }
        | ast::Instruction::Shl { .. }
        | ast::Instruction::Selp { .. }
        | ast::Instruction::Ret { .. }

View File

@@ -34,7 +34,7 @@ use crate::pass::*;
use llvm_zluda::{core::*, *};
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
use ptx_parser::Mul24Control;
use ptx_parser::{CpAsyncArgs, CpAsyncDetails, Mul24Control};

struct Builder(LLVMBuilderRef);

@@ -515,6 +515,10 @@ impl<'a> MethodEmitContext<'a> {
            ast::Instruction::Membar { data } => self.emit_membar(data),
            ast::Instruction::Trap {} => Err(error_todo_msg("Trap is not implemented yet")),
            ast::Instruction::Tanh { data, arguments } => self.emit_tanh(data, arguments),
            ast::Instruction::CpAsync { data, arguments } => self.emit_cp_async(data, arguments),
            ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop
            ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop
            ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop
            // replaced by a function call
            ast::Instruction::Bfe { .. }
            | ast::Instruction::Bar { .. }
@@ -2550,6 +2554,40 @@ impl<'a> MethodEmitContext<'a> {
        Ok(())
    }

    fn emit_cp_async(
        &mut self,
        data: CpAsyncDetails,
        arguments: CpAsyncArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        // Asynchronous copies are not supported by all AMD hardware, so we
        // just do a synchronous copy for now.
        let to = self.resolver.value(arguments.src_to)?;
        let from = self.resolver.value(arguments.src_from)?;
        let cp_size = data.cp_size;
        // If src-size was omitted, the full cp-size is read from the source.
        let src_size = data.src_size.unwrap_or(cp_size.as_u64());
        let from_type = unsafe { LLVMIntTypeInContext(self.context, (src_size as u32) * 8) };
        let to_type = match cp_size {
            ptx_parser::CpAsyncCpSize::Bytes4 => unsafe { LLVMInt32TypeInContext(self.context) },
            ptx_parser::CpAsyncCpSize::Bytes8 => unsafe { LLVMInt64TypeInContext(self.context) },
            ptx_parser::CpAsyncCpSize::Bytes16 => unsafe { LLVMInt128TypeInContext(self.context) },
        };
        let load = unsafe { LLVMBuildLoad2(self.builder, from_type, from, LLVM_UNNAMED.as_ptr()) };
        unsafe {
            LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8);
        }
        // When src-size < cp-size, PTX zero-fills the remaining destination
        // bytes, which is exactly what the zero-extension gives us.
        let extended = unsafe { LLVMBuildZExt(self.builder, load, to_type, LLVM_UNNAMED.as_ptr()) };
        unsafe { LLVMBuildStore(self.builder, extended, to) };
        Ok(())
    }

    fn flush_denormals(
        &mut self,
        type_: ptx_parser::ScalarType,

View File

@@ -0,0 +1,54 @@
@from = addrspace(1) global [4 x i32] [i32 1, i32 2, i32 3, i32 4]
@to = external addrspace(3) global [4 x i32]

define amdgpu_kernel void @cp_async(ptr addrspace(4) byref(i64) %"48", ptr addrspace(4) byref(i64) %"49") #0 {
  %"50" = alloca i64, align 8, addrspace(5)
  %"51" = alloca i64, align 8, addrspace(5)
  %"52" = alloca i32, align 4, addrspace(5)
  %"53" = alloca i32, align 4, addrspace(5)
  %"54" = alloca i32, align 4, addrspace(5)
  %"55" = alloca i32, align 4, addrspace(5)
  br label %1

1:                                                ; preds = %0
  br label %"47"

"47":                                             ; preds = %1
  %"56" = load i64, ptr addrspace(4) %"48", align 4
  store i64 %"56", ptr addrspace(5) %"50", align 4
  %"57" = load i64, ptr addrspace(4) %"49", align 4
  store i64 %"57", ptr addrspace(5) %"51", align 4
  %2 = load i96, ptr addrspace(1) @from, align 128
  %3 = zext i96 %2 to i128
  store i128 %3, ptr addrspace(3) @to, align 4
  %"58" = load i32, ptr addrspacecast (ptr addrspace(3) @to to ptr), align 4
  store i32 %"58", ptr addrspace(5) %"52", align 4
  %"59" = load i32, ptr getelementptr inbounds (i8, ptr addrspacecast (ptr addrspace(3) @to to ptr), i64 4), align 4
  store i32 %"59", ptr addrspace(5) %"53", align 4
  %"60" = load i32, ptr getelementptr inbounds (i8, ptr addrspacecast (ptr addrspace(3) @to to ptr), i64 8), align 4
  store i32 %"60", ptr addrspace(5) %"54", align 4
  %"61" = load i32, ptr getelementptr inbounds (i8, ptr addrspacecast (ptr addrspace(3) @to to ptr), i64 12), align 4
  store i32 %"61", ptr addrspace(5) %"55", align 4
  %"62" = load i64, ptr addrspace(5) %"51", align 4
  %"63" = load i32, ptr addrspace(5) %"52", align 4
  %"76" = inttoptr i64 %"62" to ptr
  store i32 %"63", ptr %"76", align 4
  %"64" = load i64, ptr addrspace(5) %"51", align 4
  %"77" = inttoptr i64 %"64" to ptr
  %"42" = getelementptr inbounds i8, ptr %"77", i64 4
  %"65" = load i32, ptr addrspace(5) %"53", align 4
  store i32 %"65", ptr %"42", align 4
  %"66" = load i64, ptr addrspace(5) %"51", align 4
  %"78" = inttoptr i64 %"66" to ptr
  %"44" = getelementptr inbounds i8, ptr %"78", i64 8
  %"67" = load i32, ptr addrspace(5) %"54", align 4
  store i32 %"67", ptr %"44", align 4
  %"68" = load i64, ptr addrspace(5) %"51", align 4
  %"79" = inttoptr i64 %"68" to ptr
  %"46" = getelementptr inbounds i8, ptr %"79", i64 12
  %"69" = load i32, ptr addrspace(5) %"55", align 4
  store i32 %"69", ptr %"46", align 4
  ret void
}

attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@@ -0,0 +1,39 @@
.version 7.0
.target sm_80
.address_size 64

.visible .entry cp_async(
    .param .u64 input,
    .param .u64 output
)
{
    .global .b32 from[4] = { 1, 2, 3, 4 };
    .shared .b32 to[4];
    .reg .u64 in_addr;
    .reg .u64 out_addr;
    .reg .b32 temp1;
    .reg .b32 temp2;
    .reg .b32 temp3;
    .reg .b32 temp4;

    ld.param.u64 in_addr, [input];
    ld.param.u64 out_addr, [output];
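    // copy src-size = 12 of cp-size = 16 bytes; PTX zero-fills the rest,
    // so to = { 1, 2, 3, 0 } afterwards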
    cp.async.ca.shared.global [to], [from], 16, 12;
    cp.async.commit_group;
    cp.async.wait_group 0;

    ld.b32 temp1, [to];
    ld.b32 temp2, [to+4];
    ld.b32 temp3, [to+8];
    ld.b32 temp4, [to+12];
    st.b32 [out_addr], temp1;
    st.b32 [out_addr+4], temp2;
    st.b32 [out_addr+8], temp3;
    st.b32 [out_addr+12], temp4;
    ret;
}

View File

@@ -299,6 +299,7 @@ test_ptx!(
test_ptx!(multiple_return, [5u32], [6u32, 123u32]);
test_ptx!(warp_sz, [0u8], [32u8]);
test_ptx!(tanh, [f32::INFINITY], [1.0f32]);
test_ptx!(cp_async, [0u32], [1u32, 2u32, 3u32, 0u32]);
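// cp_async expects { 1, 2, 3, 0 }: only 12 of the 16 bytes are copied, the rest are zero-filled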
test_ptx!(nanosleep, [0u64], [0u64]);

View File

@@ -188,6 +188,28 @@ ptx_parser_macros::generate_instruction_type!(
            src: T
        }
    },
    CpAsync {
        type: Type::Scalar(ScalarType::U32),
        data: CpAsyncDetails,
        arguments<T>: {
            src_to: {
                repr: T,
                space: StateSpace::Shared
            },
            src_from: {
                repr: T,
                space: StateSpace::Global
            }
        }
    },
    CpAsyncCommitGroup { },
    CpAsyncWaitGroup {
        type: Type::Scalar(ScalarType::U64),
        arguments<T>: {
            src_group: T
        }
    },
    CpAsyncWaitAll { },
    Cvt {
        data: CvtDetails,
        arguments<T>: {
@@ -1049,6 +1071,38 @@ pub struct ShflSyncDetails {
    pub mode: ShuffleMode,
}

pub enum CpAsyncCpSize {
    Bytes4,
    Bytes8,
    Bytes16,
}

impl CpAsyncCpSize {
    pub fn from_u64(n: u64) -> Option<Self> {
        match n {
            4 => Some(Self::Bytes4),
            8 => Some(Self::Bytes8),
            16 => Some(Self::Bytes16),
            _ => None,
        }
    }

    pub fn as_u64(&self) -> u64 {
        match self {
            CpAsyncCpSize::Bytes4 => 4,
            CpAsyncCpSize::Bytes8 => 8,
            CpAsyncCpSize::Bytes16 => 16,
        }
    }
}

pub struct CpAsyncDetails {
    pub caching: CpAsyncCacheOperator,
    pub space: StateSpace,
    pub cp_size: CpAsyncCpSize,
    pub src_size: Option<u64>,
}

#[derive(Clone)]
pub enum ParsedOperand<Ident> {
    Reg(Ident),
@@ -1058,6 +1112,15 @@ pub enum ParsedOperand<Ident> {
    VecPack(Vec<Ident>),
}

impl<Ident> ParsedOperand<Ident> {
    pub fn as_immediate(&self) -> Option<ImmediateValue> {
        match self {
            ParsedOperand::Imm(imm) => Some(*imm),
            _ => None,
        }
    }
}

impl<Ident: Copy> Operand for ParsedOperand<Ident> {
    type Ident = Ident;

@@ -1080,6 +1143,17 @@ pub enum ImmediateValue {
    F64(f64),
}

impl ImmediateValue {
    /// If the value is a U64 or S64, returns the value as a u64, ignoring the sign.
    pub fn as_u64(&self) -> Option<u64> {
        match *self {
            ImmediateValue::U64(n) => Some(n),
            ImmediateValue::S64(n) => Some(n as u64),
            ImmediateValue::F32(_) | ImmediateValue::F64(_) => None,
        }
    }
}

#[derive(Copy, Clone, PartialEq, Eq)]
pub enum StCacheOperator {
    Writeback,
@@ -1097,6 +1171,11 @@ pub enum LdCacheOperator {
    Uncached,
}

pub enum CpAsyncCacheOperator {
    Cached,
    L2Only,
}

#[derive(Copy, Clone)]
pub enum ArithDetails {
    Integer(ArithInteger),

View File

@@ -62,6 +62,15 @@ impl From<RawLdStQualifier> for ast::LdStQualifier {
    }
}

impl From<RawCpAsyncCacheOperator> for ast::CpAsyncCacheOperator {
    fn from(value: RawCpAsyncCacheOperator) -> Self {
        match value {
            RawCpAsyncCacheOperator::Ca => ast::CpAsyncCacheOperator::Cached,
            RawCpAsyncCacheOperator::Cg => ast::CpAsyncCacheOperator::L2Only,
        }
    }
}

impl From<RawRoundingMode> for ast::RoundingMode {
    fn from(value: RawRoundingMode) -> Self {
        value.normalize().0

@@ -3556,6 +3565,66 @@ derive_parser!(
        }
    }
    .type: ScalarType = { .f32, .f16, .f16x2, .bf16, .bf16x2 };

    // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async
    cp.async.cop.space.global{.level::cache_hint}{.level::prefetch_size}
        [dst], [src], cp-size{, src-size}{, cache-policy} => {
        if level_cache_hint || cache_policy.is_some() || level_prefetch_size.is_some() {
            state.errors.push(PtxError::Todo);
        }
        let cp_size = cp_size
            .as_immediate()
            .and_then(|imm| imm.as_u64())
            .and_then(|n| CpAsyncCpSize::from_u64(n))
            .unwrap_or_else(|| {
                state.errors.push(PtxError::SyntaxError);
                CpAsyncCpSize::Bytes4
            });
        let src_size = src_size
            .and_then(|op| op.as_immediate())
            .and_then(|imm| imm.as_u64());
        Instruction::CpAsync {
            data: CpAsyncDetails {
                caching: cop.into(),
                space,
                cp_size,
                src_size,
            },
            arguments: CpAsyncArgs {
                src_to: dst,
                src_from: src,
            }
        }
    }
    // cp.async.ca.shared{::cta}.global{.level::cache_hint}{.level::prefetch_size}
    //     [dst], [src], cp-size{, ignore-src}{, cache-policy} ;
    // cp.async.cg.shared{::cta}.global{.level::cache_hint}{.level::prefetch_size}
    //     [dst], [src], 16{, ignore-src}{, cache-policy} ;
    .level::cache_hint = { .L2::cache_hint };
    .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B };
    // TODO: how to handle this?
    // cp-size = { 4, 8, 16 }
    .space: StateSpace = { .shared{::cta} };
    .cop: RawCpAsyncCacheOperator = { .ca, .cg };

    // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-commit-group
    cp.async.commit_group => {
        Instruction::CpAsyncCommitGroup {}
    }

    // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-wait-group
    cp.async.wait_group n => {
        Instruction::CpAsyncWaitGroup {
            arguments: CpAsyncWaitGroupArgs { src_group: n },
        }
    }
    cp.async.wait_all => {
        Instruction::CpAsyncWaitAll {}
    }
);
#[cfg(test)]

View File

@@ -327,7 +327,7 @@ impl DotModifier {
            write!(&mut result, "_{}", part2.0).unwrap();
        } else {
            match self.part1 {
                IdentLike::Type(_) | IdentLike::Const(_) => result.push('_'),
                IdentLike::Type(_) | IdentLike::Const(_) | IdentLike::Async(_) => result.push('_'),
                IdentLike::Ident(_) | IdentLike::Integer(_) => {}
            }
        }

@@ -437,6 +437,7 @@ impl Parse for HyphenatedIdent {
enum IdentLike {
    Type(Token![type]),
    Const(Token![const]),
    Async(Token![async]),
    Ident(Ident),
    Integer(LitInt),
}

@@ -446,6 +447,7 @@ impl IdentLike {
        match self {
            IdentLike::Type(c) => c.span(),
            IdentLike::Const(t) => t.span(),
            IdentLike::Async(a) => a.span(),
            IdentLike::Ident(i) => i.span(),
            IdentLike::Integer(l) => l.span(),
        }

@@ -457,6 +459,7 @@ impl std::fmt::Display for IdentLike {
        match self {
            IdentLike::Type(_) => f.write_str("type"),
            IdentLike::Const(_) => f.write_str("const"),
            IdentLike::Async(_) => f.write_str("async"),
            IdentLike::Ident(ident) => write!(f, "{}", ident),
            IdentLike::Integer(integer) => write!(f, "{}", integer),
        }

@@ -468,6 +471,7 @@ impl ToTokens for IdentLike {
        match self {
            IdentLike::Type(_) => quote! { type }.to_tokens(tokens),
            IdentLike::Const(_) => quote! { const }.to_tokens(tokens),
            IdentLike::Async(_) => quote! { async }.to_tokens(tokens),
            IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens),
            IdentLike::Integer(int) => quote! { #int }.to_tokens(tokens),
        }

@@ -481,6 +485,8 @@ impl Parse for IdentLike {
            IdentLike::Const(input.parse::<Token![const]>()?)
        } else if lookahead.peek(Token![type]) {
            IdentLike::Type(input.parse::<Token![type]>()?)
        } else if lookahead.peek(Token![async]) {
            IdentLike::Async(input.parse::<Token![async]>()?)
        } else if lookahead.peek(Ident) {
            IdentLike::Ident(input.parse::<Ident>()?)
        } else if lookahead.peek(LitInt) {