diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs
index 8f71fe7..4a184d2 100644
--- a/ptx/src/pass/insert_post_saturation.rs
+++ b/ptx/src/pass/insert_post_saturation.rs
@@ -83,6 +83,10 @@ fn run_instruction<'input>(
         | ast::Instruction::Call { .. }
         | ast::Instruction::Clz { .. }
         | ast::Instruction::Cos { .. }
+        | ast::Instruction::CpAsync { .. }
+        | ast::Instruction::CpAsyncCommitGroup { .. }
+        | ast::Instruction::CpAsyncWaitGroup { .. }
+        | ast::Instruction::CpAsyncWaitAll { .. }
         | ast::Instruction::Cvt {
             data:
                 ast::CvtDetails {
diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs
index 86ea659..9202ad4 100644
--- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs
+++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs
@@ -1819,6 +1819,10 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes {
        | ast::Instruction::Bfi { .. }
        | ast::Instruction::Shr { .. }
        | ast::Instruction::ShflSync { .. }
+       | ast::Instruction::CpAsync { .. }
+       | ast::Instruction::CpAsyncCommitGroup { .. }
+       | ast::Instruction::CpAsyncWaitGroup { .. }
+       | ast::Instruction::CpAsyncWaitAll { .. }
        | ast::Instruction::Shl { .. }
        | ast::Instruction::Selp { .. }
        | ast::Instruction::Ret { .. }
diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs
index bd0160a..d73a881 100644
--- a/ptx/src/pass/llvm/emit.rs
+++ b/ptx/src/pass/llvm/emit.rs
@@ -34,7 +34,7 @@ use crate::pass::*;
 use llvm_zluda::{core::*, *};
 use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
 use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
-use ptx_parser::Mul24Control;
+use ptx_parser::{CpAsyncArgs, CpAsyncDetails, Mul24Control};
 
 struct Builder(LLVMBuilderRef);
 
@@ -515,6 +515,10 @@ impl<'a> MethodEmitContext<'a> {
             ast::Instruction::Membar { data } => self.emit_membar(data),
             ast::Instruction::Trap {} => Err(error_todo_msg("Trap is not implemented yet")),
             ast::Instruction::Tanh { data, arguments } => self.emit_tanh(data, arguments),
+            ast::Instruction::CpAsync { data, arguments } => self.emit_cp_async(data, arguments),
+            ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop
+            ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop
+            ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop
             // replaced by a function call
             ast::Instruction::Bfe { .. }
             | ast::Instruction::Bar { .. }
@@ -2550,6 +2554,42 @@ impl<'a> MethodEmitContext<'a> {
         Ok(())
     }
 
+    fn emit_cp_async(
+        &mut self,
+        data: CpAsyncDetails,
+        arguments: CpAsyncArgs,
+    ) -> Result<(), TranslateError> {
+        // Asynchronous copies are not supported by all AMD hardware,
+        // so we just do a synchronous copy for now.
+        let to = self.resolver.value(arguments.src_to)?;
+        let from = self.resolver.value(arguments.src_from)?;
+        let cp_size = data.cp_size;
+        let src_size = data.src_size.unwrap_or(cp_size.as_u64());
+
+        // Load exactly src-size bytes from the source...
+        let from_type = unsafe { LLVMIntTypeInContext(self.context, (src_size as u32) * 8) };
+
+        // ...and store a full cp-size worth of bytes to the destination.
+        let to_type = match cp_size {
+            ptx_parser::CpAsyncCpSize::Bytes4 => unsafe { LLVMInt32TypeInContext(self.context) },
+            ptx_parser::CpAsyncCpSize::Bytes8 => unsafe { LLVMInt64TypeInContext(self.context) },
+            ptx_parser::CpAsyncCpSize::Bytes16 => unsafe { LLVMInt128TypeInContext(self.context) },
+        };
+
+        let load = unsafe { LLVMBuildLoad2(self.builder, from_type, from, LLVM_UNNAMED.as_ptr()) };
+        unsafe {
+            // LLVMSetAlignment takes a byte count; cp.async guarantees that
+            // the source address is aligned to cp-size.
+            LLVMSetAlignment(load, cp_size.as_u64() as u32);
+        }
+
+        // When src-size < cp-size, the zext zero-fills the trailing bytes.
+        let extended = unsafe { LLVMBuildZExt(self.builder, load, to_type, LLVM_UNNAMED.as_ptr()) };
+
+        unsafe { LLVMBuildStore(self.builder, extended, to) };
+        Ok(())
+    }
+
     fn flush_denormals(
         &mut self,
         type_: ptx_parser::ScalarType,
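The lowering above replaces the asynchronous copy with an ordinary load/zext/store. A minimal host-side sketch of the resulting semantics (illustration only; `sync_cp` is a hypothetical helper, not part of the patch): the first src-size bytes are copied and the rest of the cp-size window is zero-filled, which is exactly what the i96-to-i128 zext does for the 16/12 case in the test below.

    // Hypothetical sketch of the zero-fill semantics of the synchronous lowering.
    fn sync_cp<const CP_SIZE: usize>(dst: &mut [u8; CP_SIZE], src: &[u8]) {
        dst[..src.len()].copy_from_slice(src); // bytes 0..src_size are copied
        dst[src.len()..].fill(0);              // bytes src_size..cp_size are zero-filled
    }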
%"66" to ptr + %"44" = getelementptr inbounds i8, ptr %"78", i64 8 + %"67" = load i32, ptr addrspace(5) %"54", align 4 + store i32 %"67", ptr %"44", align 4 + %"68" = load i64, ptr addrspace(5) %"51", align 4 + %"79" = inttoptr i64 %"68" to ptr + %"46" = getelementptr inbounds i8, ptr %"79", i64 12 + %"69" = load i32, ptr addrspace(5) %"55", align 4 + store i32 %"69", ptr %"46", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/cp_async.ptx b/ptx/src/test/spirv_run/cp_async.ptx new file mode 100644 index 0000000..469ae1d --- /dev/null +++ b/ptx/src/test/spirv_run/cp_async.ptx @@ -0,0 +1,39 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry cp_async( + .param .u64 input, + .param .u64 output +) +{ + .global .b32 from[4] = { 1, 2, 3, 4}; + .shared .b32 to[4]; + + .reg .u64 in_addr; + .reg .u64 out_addr; + + .reg .b32 temp1; + .reg .b32 temp2; + .reg .b32 temp3; + .reg .b32 temp4; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + cp.async.ca.shared.global [to], [from], 16, 12; + cp.async.commit_group; + cp.async.wait_group 0; + + ld.b32 temp1, [to]; + ld.b32 temp2, [to+4]; + ld.b32 temp3, [to+8]; + ld.b32 temp4, [to+12]; + + st.b32 [out_addr], temp1; + st.b32 [out_addr+4], temp2; + st.b32 [out_addr+8], temp3; + st.b32 [out_addr+12], temp4; + + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 381a224..54bf991 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -299,6 +299,7 @@ test_ptx!( test_ptx!(multiple_return, [5u32], [6u32, 123u32]); test_ptx!(warp_sz, [0u8], [32u8]); test_ptx!(tanh, [f32::INFINITY], [1.0f32]); +test_ptx!(cp_async, [0u32], [1u32, 2u32, 3u32, 0u32]); test_ptx!(nanosleep, [0u64], [0u64]); diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 7e99d6b..77721e3 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -188,6 +188,28 @@ ptx_parser_macros::generate_instruction_type!( src: T } }, + CpAsync { + type: Type::Scalar(ScalarType::U32), + data: CpAsyncDetails, + arguments: { + src_to: { + repr: T, + space: StateSpace::Shared + }, + src_from: { + repr: T, + space: StateSpace::Global + } + } + }, + CpAsyncCommitGroup { }, + CpAsyncWaitGroup { + type: Type::Scalar(ScalarType::U64), + arguments: { + src_group: T + } + }, + CpAsyncWaitAll { }, Cvt { data: CvtDetails, arguments: { @@ -1049,6 +1071,38 @@ pub struct ShflSyncDetails { pub mode: ShuffleMode, } +pub enum CpAsyncCpSize { + Bytes4, + Bytes8, + Bytes16, +} + +impl CpAsyncCpSize { + pub fn from_u64(n: u64) -> Option { + match n { + 4 => Some(Self::Bytes4), + 8 => Some(Self::Bytes8), + 16 => Some(Self::Bytes16), + _ => None, + } + } + + pub fn as_u64(&self) -> u64 { + match self { + CpAsyncCpSize::Bytes4 => 4, + CpAsyncCpSize::Bytes8 => 8, + CpAsyncCpSize::Bytes16 => 16, + } + } +} + +pub struct CpAsyncDetails { + pub caching: CpAsyncCacheOperator, + pub space: StateSpace, + pub cp_size: CpAsyncCpSize, + pub src_size: Option, +} + #[derive(Clone)] pub enum ParsedOperand { Reg(Ident), @@ -1058,6 +1112,15 @@ pub enum ParsedOperand { VecPack(Vec), } +impl ParsedOperand { + pub fn as_immediate(&self) -> Option { + match self { + ParsedOperand::Imm(imm) => Some(*imm), + _ => None, + } + } +} + impl Operand for ParsedOperand { type Ident = Ident; 
diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs
index 7e99d6b..77721e3 100644
--- a/ptx_parser/src/ast.rs
+++ b/ptx_parser/src/ast.rs
@@ -188,6 +188,28 @@ ptx_parser_macros::generate_instruction_type!(
             src: T
         }
     },
+    CpAsync {
+        type: Type::Scalar(ScalarType::U32),
+        data: CpAsyncDetails,
+        arguments: {
+            src_to: {
+                repr: T,
+                space: StateSpace::Shared
+            },
+            src_from: {
+                repr: T,
+                space: StateSpace::Global
+            }
+        }
+    },
+    CpAsyncCommitGroup { },
+    CpAsyncWaitGroup {
+        type: Type::Scalar(ScalarType::U64),
+        arguments: {
+            src_group: T
+        }
+    },
+    CpAsyncWaitAll { },
     Cvt {
         data: CvtDetails,
         arguments: {
@@ -1049,6 +1071,38 @@ pub struct ShflSyncDetails {
     pub mode: ShuffleMode,
 }
 
+pub enum CpAsyncCpSize {
+    Bytes4,
+    Bytes8,
+    Bytes16,
+}
+
+impl CpAsyncCpSize {
+    pub fn from_u64(n: u64) -> Option<Self> {
+        match n {
+            4 => Some(Self::Bytes4),
+            8 => Some(Self::Bytes8),
+            16 => Some(Self::Bytes16),
+            _ => None,
+        }
+    }
+
+    pub fn as_u64(&self) -> u64 {
+        match self {
+            CpAsyncCpSize::Bytes4 => 4,
+            CpAsyncCpSize::Bytes8 => 8,
+            CpAsyncCpSize::Bytes16 => 16,
+        }
+    }
+}
+
+pub struct CpAsyncDetails {
+    pub caching: CpAsyncCacheOperator,
+    pub space: StateSpace,
+    pub cp_size: CpAsyncCpSize,
+    pub src_size: Option<u64>,
+}
+
 #[derive(Clone)]
 pub enum ParsedOperand<Ident> {
     Reg(Ident),
@@ -1058,6 +1112,15 @@ pub enum ParsedOperand<Ident> {
     VecPack(Vec<Ident>),
 }
 
+impl<Ident> ParsedOperand<Ident> {
+    pub fn as_immediate(&self) -> Option<ImmediateValue> {
+        match self {
+            ParsedOperand::Imm(imm) => Some(*imm),
+            _ => None,
+        }
+    }
+}
+
 impl<Ident: Copy> Operand for ParsedOperand<Ident> {
     type Ident = Ident;
 
@@ -1080,6 +1143,17 @@ pub enum ImmediateValue {
     F64(f64),
 }
 
+impl ImmediateValue {
+    /// If the value is a U64 or S64, returns the value as a u64, ignoring the sign.
+    pub fn as_u64(&self) -> Option<u64> {
+        match *self {
+            ImmediateValue::U64(n) => Some(n),
+            ImmediateValue::S64(n) => Some(n as u64),
+            ImmediateValue::F32(_) | ImmediateValue::F64(_) => None,
+        }
+    }
+}
+
 #[derive(Copy, Clone, PartialEq, Eq)]
 pub enum StCacheOperator {
     Writeback,
@@ -1097,6 +1171,11 @@ pub enum LdCacheOperator {
     Uncached,
 }
 
+pub enum CpAsyncCacheOperator {
+    Cached,
+    L2Only,
+}
+
 #[derive(Copy, Clone)]
 pub enum ArithDetails {
     Integer(ArithInteger),
diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs
index 4572842..76887e5 100644
--- a/ptx_parser/src/lib.rs
+++ b/ptx_parser/src/lib.rs
@@ -62,6 +62,15 @@ impl From<RawLdStQualifier> for ast::LdStQualifier {
     }
 }
 
+impl From<RawCpAsyncCacheOperator> for ast::CpAsyncCacheOperator {
+    fn from(value: RawCpAsyncCacheOperator) -> Self {
+        match value {
+            RawCpAsyncCacheOperator::Ca => ast::CpAsyncCacheOperator::Cached,
+            RawCpAsyncCacheOperator::Cg => ast::CpAsyncCacheOperator::L2Only,
+        }
+    }
+}
+
 impl From<RawRoundingMode> for ast::RoundingMode {
     fn from(value: RawRoundingMode) -> Self {
         value.normalize().0
@@ -3556,6 +3565,65 @@ derive_parser!(
         }
     }
     .type: ScalarType = { .f32, .f16, .f16x2, .bf16, .bf16x2 };
 
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async
+    // cp.async.ca.shared{::cta}.global{.level::cache_hint}{.level::prefetch_size}
+    //                      [dst], [src], cp-size{, ignore-src}{, cache-policy} ;
+    // cp.async.cg.shared{::cta}.global{.level::cache_hint}{.level::prefetch_size}
+    //                      [dst], [src], 16{, ignore-src}{, cache-policy} ;
+    cp.async.cop.space.global{.level::cache_hint}{.level::prefetch_size}
+                      [dst], [src], cp-size{, src-size}{, cache-policy} => {
+        if level_cache_hint || cache_policy.is_some() || level_prefetch_size.is_some() {
+            state.errors.push(PtxError::Todo);
+        }
+
+        let cp_size = cp_size
+            .as_immediate()
+            .and_then(|imm| imm.as_u64())
+            .and_then(|n| CpAsyncCpSize::from_u64(n))
+            .unwrap_or_else(|| {
+                state.errors.push(PtxError::SyntaxError);
+                CpAsyncCpSize::Bytes4
+            });
+
+        let src_size = src_size
+            .and_then(|op| op.as_immediate())
+            .and_then(|imm| imm.as_u64());
+
+        Instruction::CpAsync {
+            data: CpAsyncDetails {
+                caching: cop.into(),
+                space,
+                cp_size,
+                src_size,
+            },
+            arguments: CpAsyncArgs {
+                src_to: dst,
+                src_from: src,
+            }
+        }
+    }
+
+    .level::cache_hint = { .L2::cache_hint };
+    .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B };
+    // TODO: express cp-size = { 4, 8, 16 } in the grammar itself;
+    // for now it is validated in the rule body via CpAsyncCpSize::from_u64.
+    .space: StateSpace = { .shared{::cta} };
+    .cop: RawCpAsyncCacheOperator = { .ca, .cg };
+
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-commit-group
+    cp.async.commit_group => {
+        Instruction::CpAsyncCommitGroup {}
+    }
+
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-wait-group
+    cp.async.wait_group n => {
+        Instruction::CpAsyncWaitGroup {
+            arguments: CpAsyncWaitGroupArgs { src_group: n },
+        }
+    }
+    cp.async.wait_all => {
+        Instruction::CpAsyncWaitAll {}
+    }
 );
 
 #[cfg(test)]
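The cp-size operand reaches the rule body as a ParsedOperand immediate and is decoded through the new helpers. A minimal sketch of that chain, assuming a ParsedOperand<()> wrapping the immediate 16 (illustration only):

    // Decoding chain used by the cp.async rule above.
    let op: ParsedOperand<()> = ParsedOperand::Imm(ImmediateValue::U64(16));
    let size = op
        .as_immediate()
        .and_then(|imm| imm.as_u64())
        .and_then(CpAsyncCpSize::from_u64);
    assert!(matches!(size, Some(CpAsyncCpSize::Bytes16)));
    // Any cp-size other than 4, 8 or 16 decodes to None and is reported as a SyntaxError.
    assert!(CpAsyncCpSize::from_u64(12).is_none());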
diff --git a/ptx_parser_macros_impl/src/parser.rs b/ptx_parser_macros_impl/src/parser.rs
index 786545b..20a7dab 100644
--- a/ptx_parser_macros_impl/src/parser.rs
+++ b/ptx_parser_macros_impl/src/parser.rs
@@ -327,7 +327,7 @@ impl DotModifier {
             write!(&mut result, "_{}", part2.0).unwrap();
         } else {
             match self.part1 {
-                IdentLike::Type(_) | IdentLike::Const(_) => result.push('_'),
+                IdentLike::Type(_) | IdentLike::Const(_) | IdentLike::Async(_) => result.push('_'),
                 IdentLike::Ident(_) | IdentLike::Integer(_) => {}
             }
         }
@@ -437,6 +437,7 @@ impl Parse for HyphenatedIdent {
 enum IdentLike {
     Type(Token![type]),
     Const(Token![const]),
+    Async(Token![async]),
     Ident(Ident),
     Integer(LitInt),
 }
@@ -446,6 +447,7 @@ impl IdentLike {
         match self {
             IdentLike::Type(c) => c.span(),
             IdentLike::Const(t) => t.span(),
+            IdentLike::Async(a) => a.span(),
             IdentLike::Ident(i) => i.span(),
             IdentLike::Integer(l) => l.span(),
         }
@@ -457,6 +459,7 @@ impl std::fmt::Display for IdentLike {
         match self {
             IdentLike::Type(_) => f.write_str("type"),
             IdentLike::Const(_) => f.write_str("const"),
+            IdentLike::Async(_) => f.write_str("async"),
             IdentLike::Ident(ident) => write!(f, "{}", ident),
             IdentLike::Integer(integer) => write!(f, "{}", integer),
         }
@@ -468,6 +471,7 @@ impl ToTokens for IdentLike {
         match self {
             IdentLike::Type(_) => quote! { type }.to_tokens(tokens),
             IdentLike::Const(_) => quote! { const }.to_tokens(tokens),
+            IdentLike::Async(_) => quote! { async }.to_tokens(tokens),
             IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens),
             IdentLike::Integer(int) => quote! { #int }.to_tokens(tokens),
         }
@@ -481,6 +485,8 @@ impl Parse for IdentLike {
             IdentLike::Const(input.parse::<Token![const]>()?)
         } else if lookahead.peek(Token![type]) {
             IdentLike::Type(input.parse::<Token![type]>()?)
+        } else if lookahead.peek(Token![async]) {
+            IdentLike::Async(input.parse::<Token![async]>()?)
         } else if lookahead.peek(Ident) {
             IdentLike::Ident(input.parse::<Ident>()?)
        } else if lookahead.peek(LitInt) {
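Background on why the macro needs an explicit Token![async] branch: `async` is a reserved Rust keyword, so syn refuses to parse it as a plain Ident, yet it appears as a modifier in opcodes like cp.async.commit_group. A small check (illustration only):

    // syn rejects keywords when parsing a bare identifier.
    assert!(syn::parse_str::<syn::Ident>("async").is_err()); // keyword: needs Token![async]
    assert!(syn::parse_str::<syn::Ident>("asynchronous").is_ok());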