diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 6cefc81..84d2f2d 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 7af9729..638ef1e 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -157,6 +157,19 @@ extern "C" __builtin_amdgcn_s_barrier(); } + int32_t __ockl_wgred_and_i32(int32_t) __device__; + int32_t __ockl_wgred_or_i32(int32_t) __device__; + + #define BAR_RED_IMPL(reducer) \ + bool FUNC(bar_red_##reducer##_pred)(uint32_t barrier __attribute__((unused)), bool predicate, bool invert_predicate) \ + { \ + /* TODO: handle barrier */ \ + return __ockl_wgred_##reducer##_i32(predicate ^ invert_predicate); \ + } + + BAR_RED_IMPL(and); + BAR_RED_IMPL(or); + void FUNC(__assertfail)(uint64_t message, uint64_t file, uint32_t line, diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 58341e4..1c2b52c 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -639,6 +639,7 @@ impl<'a> MethodEmitContext<'a> { // replaced by a function call ast::Instruction::Bfe { .. } | ast::Instruction::Bar { .. } + | ast::Instruction::BarRed { .. } | ast::Instruction::Bfi { .. } | ast::Instruction::Activemask { .. } => return Err(error_unreachable()), } diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index cc9afa7..92be749 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -75,6 +75,7 @@ fn run_instruction<'input>( | ast::Instruction::Atom { .. } | ast::Instruction::AtomCas { .. } | ast::Instruction::Bar { .. } + | ast::Instruction::BarRed { .. } | ast::Instruction::Bfe { .. } | ast::Instruction::Bfi { .. } | ast::Instruction::Bra { .. } diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index fdaafd1..3d56dd0 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1804,6 +1804,7 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::Selp { .. } | ast::Instruction::Ret { .. } | ast::Instruction::Bar { .. } + | ast::Instruction::BarRed { .. } | ast::Instruction::Cvta { .. } | ast::Instruction::Atom { .. } | ast::Instruction::Mul24 { .. } diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs index 0f9311a..db6b473 100644 --- a/ptx/src/pass/replace_instructions_with_function_calls.rs +++ b/ptx/src/pass/replace_instructions_with_function_calls.rs @@ -108,6 +108,16 @@ fn run_instruction<'input>( i @ ptx_parser::Instruction::Bar { .. } => { to_call(resolver, fn_declarations, "bar_sync".into(), i)? } + ptx_parser::Instruction::BarRed { data, arguments } => { + if arguments.src_threadcount.is_some() { + return Err(error_todo()); + } + let name = match data.pred_reduction { + ptx_parser::Reduction::And => "bar_red_and_pred", + ptx_parser::Reduction::Or => "bar_red_or_pred", + }; + to_call(resolver, fn_declarations, name.into(), ptx_parser::Instruction::BarRed { data, arguments })? + } i => i, }) } diff --git a/ptx/src/test/ll/bar_red_and_pred.ll b/ptx/src/test/ll/bar_red_and_pred.ll new file mode 100644 index 0000000..649efc0 --- /dev/null +++ b/ptx/src/test/ll/bar_red_and_pred.ll @@ -0,0 +1,121 @@ +declare i1 @__zluda_ptx_impl_bar_red_and_pred(i32, i1, i1) #0 + +declare i1 @__zluda_ptx_impl_bar_red_or_pred(i32, i1, i1) #0 + +declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @bar_red_and_pred(ptr addrspace(4) byref(i64) %"73", ptr addrspace(4) byref(i64) %"74") #1 { + %"75" = alloca i64, align 8, addrspace(5) + %"76" = alloca i64, align 8, addrspace(5) + %"77" = alloca i32, align 4, addrspace(5) + %"78" = alloca i32, align 4, addrspace(5) + %"79" = alloca i1, align 1, addrspace(5) + %"80" = alloca i1, align 1, addrspace(5) + %"81" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"70" + +"70": ; preds = %1 + %"82" = load i64, ptr addrspace(4) %"74", align 4 + store i64 %"82", ptr addrspace(5) %"75", align 4 + %"44" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"71" + +"71": ; preds = %"70" + store i32 %"44", ptr addrspace(5) %"77", align 4 + %"85" = load i32, ptr addrspace(5) %"77", align 4 + %"84" = urem i32 %"85", 2 + store i32 %"84", ptr addrspace(5) %"78", align 4 + %"87" = load i32, ptr addrspace(5) %"78", align 4 + %"86" = icmp eq i32 %"87", 0 + store i1 %"86", ptr addrspace(5) %"80", align 1 + store i32 0, ptr addrspace(5) %"81", align 4 + %"90" = load i1, ptr addrspace(5) %"80", align 1 + %"89" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"90", i1 false) + store i1 %"89", ptr addrspace(5) %"79", align 1 + %"91" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"91", label %"17", label %"18" + +"17": ; preds = %"71" + %"93" = load i32, ptr addrspace(5) %"81", align 4 + %"92" = add i32 %"93", 1 + store i32 %"92", ptr addrspace(5) %"81", align 4 + br label %"18" + +"18": ; preds = %"17", %"71" + %"95" = load i1, ptr addrspace(5) %"80", align 1 + %"94" = call i1 @__zluda_ptx_impl_bar_red_or_pred(i32 1, i1 %"95", i1 false) + store i1 %"94", ptr addrspace(5) %"79", align 1 + %"96" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"96", label %"19", label %"20" + +"19": ; preds = %"18" + %"98" = load i32, ptr addrspace(5) %"81", align 4 + %"97" = add i32 %"98", 1 + store i32 %"97", ptr addrspace(5) %"81", align 4 + br label %"20" + +"20": ; preds = %"19", %"18" + store i1 true, ptr addrspace(5) %"80", align 1 + %"101" = load i1, ptr addrspace(5) %"80", align 1 + %"100" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"101", i1 false) + store i1 %"100", ptr addrspace(5) %"79", align 1 + %"102" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"102", label %"21", label %"22" + +"21": ; preds = %"20" + %"104" = load i32, ptr addrspace(5) %"81", align 4 + %"103" = add i32 %"104", 1 + store i32 %"103", ptr addrspace(5) %"81", align 4 + br label %"22" + +"22": ; preds = %"21", %"20" + store i1 false, ptr addrspace(5) %"80", align 1 + %"107" = load i1, ptr addrspace(5) %"80", align 1 + %"106" = call i1 @__zluda_ptx_impl_bar_red_or_pred(i32 1, i1 %"107", i1 false) + store i1 %"106", ptr addrspace(5) %"79", align 1 + %"108" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"108", label %"23", label %"24" + +"23": ; preds = %"22" + %"110" = load i32, ptr addrspace(5) %"81", align 4 + %"109" = add i32 %"110", 1 + store i32 %"109", ptr addrspace(5) %"81", align 4 + br label %"24" + +"24": ; preds = %"23", %"22" + store i1 true, ptr addrspace(5) %"80", align 1 + %"113" = load i1, ptr addrspace(5) %"80", align 1 + %"112" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"113", i1 true) + store i1 %"112", ptr addrspace(5) %"79", align 1 + %"114" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"114", label %"25", label %"26" + +"25": ; preds = %"24" + %"116" = load i32, ptr addrspace(5) %"81", align 4 + %"115" = add i32 %"116", 1 + store i32 %"115", ptr addrspace(5) %"81", align 4 + br label %"26" + +"26": ; preds = %"25", %"24" + %"118" = load i32, ptr addrspace(5) %"77", align 4 + %"117" = zext i32 %"118" to i64 + store i64 %"117", ptr addrspace(5) %"76", align 4 + %"120" = load i64, ptr addrspace(5) %"76", align 4 + %"119" = mul i64 %"120", 4 + store i64 %"119", ptr addrspace(5) %"76", align 4 + %"122" = load i64, ptr addrspace(5) %"75", align 4 + %"123" = load i64, ptr addrspace(5) %"76", align 4 + %"121" = add i64 %"122", %"123" + store i64 %"121", ptr addrspace(5) %"75", align 4 + %"124" = load i64, ptr addrspace(5) %"75", align 4 + %"125" = load i32, ptr addrspace(5) %"81", align 4 + %"126" = inttoptr i64 %"124" to ptr + store i32 %"125", ptr %"126", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/bar_red_and_pred.ptx b/ptx/src/test/spirv_run/bar_red_and_pred.ptx new file mode 100644 index 0000000..777b771 --- /dev/null +++ b/ptx/src/test/spirv_run/bar_red_and_pred.ptx @@ -0,0 +1,60 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry bar_red_and_pred( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 out_addr; + .reg .u64 out_index; + .reg .u32 thread_id; + .reg .u32 thread_mod_2; + .reg .pred pred; + .reg .pred cond; + .reg .u32 result; + + ld.param.u64 out_addr, [output]; + + mov.u32 thread_id, %tid.x; + rem.u32 thread_mod_2, thread_id, 2; + setp.eq.u32 cond, thread_mod_2, 0; + + mov.u32 result, 0; + + // Basic functionality + + // result += AND(tid.x % 2 == 0) forall threads + bar.red.and.pred pred, 1, cond; + @pred add.u32 result, result, 1; + // result += OR(tid.x % 2 == 0) forall threads + bar.red.or.pred pred, 1, cond; + @pred add.u32 result, result, 1; + + // result += AND(true) forall threads + setp.eq.u32 cond, 1, 1; + bar.red.and.pred pred, 1, cond; + @pred add.u32 result, result, 1; + // result += OR(false) forall threads + setp.eq.u32 cond, 1, 0; + bar.red.or.pred pred, 1, cond; + @pred add.u32 result, result, 1; + + // Negated condition + // result += AND(!true) forall threads + setp.eq.u32 cond, 1, 1; + bar.red.and.pred pred, 1, !cond; + @pred add.u32 result, result, 1; + + // Return result + + cvt.u64.u32 out_index, thread_id; + mul.lo.u64 out_index, out_index, 4; + add.u64 out_addr, out_addr, out_index; + st.u32 [out_addr], result; + + // result should be 2 + + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 84e0731..c594ebb 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -307,6 +307,13 @@ test_ptx_warp!(tid, [ 32u8, 33u8, 34u8, 35u8, 36u8, 37u8, 38u8, 39u8, 40u8, 41u8, 42u8, 43u8, 44u8, 45u8, 46u8, 47u8, 48u8, 49u8, 50u8, 51u8, 52u8, 53u8, 54u8, 55u8, 56u8, 57u8, 58u8, 59u8, 60u8, 61u8, 62u8, 63u8, ]); +test_ptx_warp!(bar_red_and_pred, [ + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, +]); + struct DisplayError { err: T, } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index ca7b9df..6e42871 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -2,7 +2,7 @@ use super::{ AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix, }; -use crate::{Mul24Control, PtxError, PtxParserState}; +use crate::{Mul24Control, Reduction, PtxError, PtxParserState}; use bitflags::bitflags; use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8}; @@ -95,6 +95,26 @@ ptx_parser_macros::generate_instruction_type!( src2: Option, } }, + BarRed { + type: Type::Scalar(ScalarType::U32), + data: BarRedData, + arguments: { + dst1: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + src_barrier: T, + src_threadcount: Option, + src_predicate: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + src_negate_predicate: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + } + }, Bfe { type: Type::Scalar(data.clone()), data: ScalarType, @@ -1745,6 +1765,12 @@ pub struct BarData { pub aligned: bool, } +#[derive(Copy, Clone)] +pub struct BarRedData { + pub aligned: bool, + pub pred_reduction: Reduction, +} + pub struct AtomDetails { pub type_: Type, pub semantics: AtomSemantics, diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 6dedbbb..da14406 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1705,6 +1705,9 @@ derive_parser!( #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum Mul24Control { } + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum Reduction { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { @@ -2987,7 +2990,9 @@ derive_parser!( barrier{.cta}.sync{.aligned} a{, b} => { let _ = cta; ast::Instruction::Bar { - data: ast::BarData { aligned }, + data: ast::BarData { + aligned, + }, arguments: BarArgs { src1: a, src2: b } } } @@ -2997,14 +3002,32 @@ derive_parser!( bar{.cta}.sync a{, b} => { let _ = cta; ast::Instruction::Bar { - data: ast::BarData { aligned: true }, + data: ast::BarData { + aligned: true, + }, arguments: BarArgs { src1: a, src2: b } } } //bar{.cta}.arrive a, b; //bar{.cta}.red.popc.u32 d, a{, b}, {!}c; - //bar{.cta}.red.op.pred p, a{, b}, {!}c; - //.op = { .and, .or }; + bar{.cta}.red.op.pred p, a{, b}, {!}c => { + let _ = cta; + let (negate_src3, c) = c; + ast::Instruction::BarRed { + data: ast::BarRedData { + aligned: true, + pred_reduction: op, + }, + arguments: BarRedArgs { + dst1: p, + src_barrier: a, + src_threadcount: b, + src_predicate: c, + src_negate_predicate: ParsedOperand::Imm(ImmediateValue::U64(negate_src3 as u64)) + } + } + } + .op: Reduction = { .and, .or }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom atom{.sem}{.scope}{.space}.op{.level::cache_hint}.type d, [a], b{, cache_policy} => { diff --git a/ptx_parser_macros/src/lib.rs b/ptx_parser_macros/src/lib.rs index f88395d..0e916b4 100644 --- a/ptx_parser_macros/src/lib.rs +++ b/ptx_parser_macros/src/lib.rs @@ -784,7 +784,7 @@ fn emit_definition_parser( }; let can_be_negated = if arg.can_be_negated { quote! { - opt(any.verify(|(t, _)| *t == #token_type::Not)).map(|o| o.is_some()) + opt(any.verify(|(t, _)| *t == #token_type::Exclamation)).map(|o| o.is_some()) } } else { quote! {