Add support for bar.red.and.pred (#402)

Implements bar.red.and.pred and bar.red.or.pred, using the undocumented __ockl_wgred functions. Doesn't yet add support for numbered barriers and threadcount, as these are not needed for llm.c.
This commit is contained in:
Violet
2025-07-03 11:56:20 -07:00
committed by GitHub
parent 7bdd20f0dd
commit 5cb0a9b8e8
12 changed files with 269 additions and 6 deletions

Binary file not shown.

View File

@ -157,6 +157,19 @@ extern "C"
__builtin_amdgcn_s_barrier();
}
// Hand-written declarations of the OCKL `wgred` entry points used below.
// NOTE(review): presumably these reduce the argument across the whole
// work-group and return the combined value to every caller — confirm against
// the device library sources.
int32_t __ockl_wgred_and_i32(int32_t) __device__;
int32_t __ockl_wgred_or_i32(int32_t) __device__;

// Expands to the implementation of PTX `bar.red.<reducer>.pred`.
//   barrier          - barrier operand of the instruction; currently ignored.
//   predicate        - the calling thread's source predicate.
//   invert_predicate - true when the source predicate was written negated;
//                      the predicate is flipped before being reduced.
// Returns the reduced predicate as a bool.
#define BAR_RED_IMPL(reducer) \
    bool FUNC(bar_red_##reducer##_pred)(uint32_t barrier __attribute__((unused)), bool predicate, bool invert_predicate) \
    { \
        /* TODO: handle barrier */ \
        const bool effective_predicate = (predicate != invert_predicate); \
        return __ockl_wgred_##reducer##_i32(effective_predicate) != 0; \
    }
BAR_RED_IMPL(and);
BAR_RED_IMPL(or);
void FUNC(__assertfail)(uint64_t message,
uint64_t file,
uint32_t line,

View File

@ -639,6 +639,7 @@ impl<'a> MethodEmitContext<'a> {
// replaced by a function call
ast::Instruction::Bfe { .. }
| ast::Instruction::Bar { .. }
| ast::Instruction::BarRed { .. }
| ast::Instruction::Bfi { .. }
| ast::Instruction::Activemask { .. } => return Err(error_unreachable()),
}

View File

@ -75,6 +75,7 @@ fn run_instruction<'input>(
| ast::Instruction::Atom { .. }
| ast::Instruction::AtomCas { .. }
| ast::Instruction::Bar { .. }
| ast::Instruction::BarRed { .. }
| ast::Instruction::Bfe { .. }
| ast::Instruction::Bfi { .. }
| ast::Instruction::Bra { .. }

View File

@ -1804,6 +1804,7 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
| ast::Instruction::Selp { .. }
| ast::Instruction::Ret { .. }
| ast::Instruction::Bar { .. }
| ast::Instruction::BarRed { .. }
| ast::Instruction::Cvta { .. }
| ast::Instruction::Atom { .. }
| ast::Instruction::Mul24 { .. }

View File

@ -108,6 +108,16 @@ fn run_instruction<'input>(
i @ ptx_parser::Instruction::Bar { .. } => {
to_call(resolver, fn_declarations, "bar_sync".into(), i)?
}
ptx_parser::Instruction::BarRed { data, arguments } => {
if arguments.src_threadcount.is_some() {
return Err(error_todo());
}
let name = match data.pred_reduction {
ptx_parser::Reduction::And => "bar_red_and_pred",
ptx_parser::Reduction::Or => "bar_red_or_pred",
};
to_call(resolver, fn_declarations, name.into(), ptx_parser::Instruction::BarRed { data, arguments })?
}
i => i,
})
}

View File

@ -0,0 +1,121 @@
declare i1 @__zluda_ptx_impl_bar_red_and_pred(i32, i1, i1) #0
declare i1 @__zluda_ptx_impl_bar_red_or_pred(i32, i1, i1) #0
declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @bar_red_and_pred(ptr addrspace(4) byref(i64) %"73", ptr addrspace(4) byref(i64) %"74") #1 {
%"75" = alloca i64, align 8, addrspace(5)
%"76" = alloca i64, align 8, addrspace(5)
%"77" = alloca i32, align 4, addrspace(5)
%"78" = alloca i32, align 4, addrspace(5)
%"79" = alloca i1, align 1, addrspace(5)
%"80" = alloca i1, align 1, addrspace(5)
%"81" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"70"
"70": ; preds = %1
%"82" = load i64, ptr addrspace(4) %"74", align 4
store i64 %"82", ptr addrspace(5) %"75", align 4
%"44" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"71"
"71": ; preds = %"70"
store i32 %"44", ptr addrspace(5) %"77", align 4
%"85" = load i32, ptr addrspace(5) %"77", align 4
%"84" = urem i32 %"85", 2
store i32 %"84", ptr addrspace(5) %"78", align 4
%"87" = load i32, ptr addrspace(5) %"78", align 4
%"86" = icmp eq i32 %"87", 0
store i1 %"86", ptr addrspace(5) %"80", align 1
store i32 0, ptr addrspace(5) %"81", align 4
%"90" = load i1, ptr addrspace(5) %"80", align 1
%"89" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"90", i1 false)
store i1 %"89", ptr addrspace(5) %"79", align 1
%"91" = load i1, ptr addrspace(5) %"79", align 1
br i1 %"91", label %"17", label %"18"
"17": ; preds = %"71"
%"93" = load i32, ptr addrspace(5) %"81", align 4
%"92" = add i32 %"93", 1
store i32 %"92", ptr addrspace(5) %"81", align 4
br label %"18"
"18": ; preds = %"17", %"71"
%"95" = load i1, ptr addrspace(5) %"80", align 1
%"94" = call i1 @__zluda_ptx_impl_bar_red_or_pred(i32 1, i1 %"95", i1 false)
store i1 %"94", ptr addrspace(5) %"79", align 1
%"96" = load i1, ptr addrspace(5) %"79", align 1
br i1 %"96", label %"19", label %"20"
"19": ; preds = %"18"
%"98" = load i32, ptr addrspace(5) %"81", align 4
%"97" = add i32 %"98", 1
store i32 %"97", ptr addrspace(5) %"81", align 4
br label %"20"
"20": ; preds = %"19", %"18"
store i1 true, ptr addrspace(5) %"80", align 1
%"101" = load i1, ptr addrspace(5) %"80", align 1
%"100" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"101", i1 false)
store i1 %"100", ptr addrspace(5) %"79", align 1
%"102" = load i1, ptr addrspace(5) %"79", align 1
br i1 %"102", label %"21", label %"22"
"21": ; preds = %"20"
%"104" = load i32, ptr addrspace(5) %"81", align 4
%"103" = add i32 %"104", 1
store i32 %"103", ptr addrspace(5) %"81", align 4
br label %"22"
"22": ; preds = %"21", %"20"
store i1 false, ptr addrspace(5) %"80", align 1
%"107" = load i1, ptr addrspace(5) %"80", align 1
%"106" = call i1 @__zluda_ptx_impl_bar_red_or_pred(i32 1, i1 %"107", i1 false)
store i1 %"106", ptr addrspace(5) %"79", align 1
%"108" = load i1, ptr addrspace(5) %"79", align 1
br i1 %"108", label %"23", label %"24"
"23": ; preds = %"22"
%"110" = load i32, ptr addrspace(5) %"81", align 4
%"109" = add i32 %"110", 1
store i32 %"109", ptr addrspace(5) %"81", align 4
br label %"24"
"24": ; preds = %"23", %"22"
store i1 true, ptr addrspace(5) %"80", align 1
%"113" = load i1, ptr addrspace(5) %"80", align 1
%"112" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"113", i1 true)
store i1 %"112", ptr addrspace(5) %"79", align 1
%"114" = load i1, ptr addrspace(5) %"79", align 1
br i1 %"114", label %"25", label %"26"
"25": ; preds = %"24"
%"116" = load i32, ptr addrspace(5) %"81", align 4
%"115" = add i32 %"116", 1
store i32 %"115", ptr addrspace(5) %"81", align 4
br label %"26"
"26": ; preds = %"25", %"24"
%"118" = load i32, ptr addrspace(5) %"77", align 4
%"117" = zext i32 %"118" to i64
store i64 %"117", ptr addrspace(5) %"76", align 4
%"120" = load i64, ptr addrspace(5) %"76", align 4
%"119" = mul i64 %"120", 4
store i64 %"119", ptr addrspace(5) %"76", align 4
%"122" = load i64, ptr addrspace(5) %"75", align 4
%"123" = load i64, ptr addrspace(5) %"76", align 4
%"121" = add i64 %"122", %"123"
store i64 %"121", ptr addrspace(5) %"75", align 4
%"124" = load i64, ptr addrspace(5) %"75", align 4
%"125" = load i32, ptr addrspace(5) %"81", align 4
%"126" = inttoptr i64 %"124" to ptr
store i32 %"125", ptr %"126", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -0,0 +1,60 @@
.version 6.5
.target sm_30
.address_size 64
// Test kernel for `bar.red.and.pred` / `bar.red.or.pred`.
// Every thread performs five barrier reductions on barrier 1, counts how many
// of them yielded a true predicate, and stores that count (a u32) at
// output + 4 * tid.x. The `input` parameter is declared but never read.
.visible .entry bar_red_and_pred(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 out_addr;
.reg .u64 out_index;
.reg .u32 thread_id;
.reg .u32 thread_mod_2;
.reg .pred pred;
.reg .pred cond;
.reg .u32 result;
ld.param.u64 out_addr, [output];
mov.u32 thread_id, %tid.x;
// cond = (tid.x % 2 == 0): true on even threads, false on odd threads
rem.u32 thread_mod_2, thread_id, 2;
setp.eq.u32 cond, thread_mod_2, 0;
// result counts the reductions that came back true
mov.u32 result, 0;
// Basic functionality
// result += AND(tid.x % 2 == 0) forall threads
// (false as soon as an odd-numbered thread takes part)
bar.red.and.pred pred, 1, cond;
@pred add.u32 result, result, 1;
// result += OR(tid.x % 2 == 0) forall threads
// (true as soon as an even-numbered thread takes part)
bar.red.or.pred pred, 1, cond;
@pred add.u32 result, result, 1;
// result += AND(true) forall threads
// (cond is forced to true on every thread via 1 == 1)
setp.eq.u32 cond, 1, 1;
bar.red.and.pred pred, 1, cond;
@pred add.u32 result, result, 1;
// result += OR(false) forall threads
// (cond is forced to false on every thread via 1 == 0)
setp.eq.u32 cond, 1, 0;
bar.red.or.pred pred, 1, cond;
@pred add.u32 result, result, 1;
// Negated condition
// result += AND(!true) forall threads
// (cond is true, but the `!` operand prefix inverts it before the reduction)
setp.eq.u32 cond, 1, 1;
bar.red.and.pred pred, 1, !cond;
@pred add.u32 result, result, 1;
// Return result
// out_addr = output + tid.x * 4 (one u32 slot per thread)
cvt.u64.u32 out_index, thread_id;
mul.lo.u64 out_index, out_index, 4;
add.u64 out_addr, out_addr, out_index;
st.u32 [out_addr], result;
// result should be 2 when both even and odd threads take part:
// only OR(tid.x % 2 == 0) and AND(true) evaluate to true
ret;
}

View File

@ -307,6 +307,13 @@ test_ptx_warp!(tid, [
32u8, 33u8, 34u8, 35u8, 36u8, 37u8, 38u8, 39u8, 40u8, 41u8, 42u8, 43u8, 44u8, 45u8, 46u8, 47u8,
48u8, 49u8, 50u8, 51u8, 52u8, 53u8, 54u8, 55u8, 56u8, 57u8, 58u8, 59u8, 60u8, 61u8, 62u8, 63u8,
]);
// Expected output for the `bar_red_and_pred` PTX kernel: 64 entries, all equal
// to 2, matching the kernel's own "result should be 2" note (each thread stores
// its count of reductions that evaluated to true).
test_ptx_warp!(bar_red_and_pred, [
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
]);
struct DisplayError<T: Debug> {
err: T,
}

View File

@ -2,7 +2,7 @@ use super::{
AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp,
StateSpace, VectorPrefix,
};
use crate::{Mul24Control, PtxError, PtxParserState};
use crate::{Mul24Control, Reduction, PtxError, PtxParserState};
use bitflags::bitflags;
use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8};
@ -95,6 +95,26 @@ ptx_parser_macros::generate_instruction_type!(
src2: Option<T>,
}
},
// `bar{.cta}.red.op.pred`: barrier instruction that also reduces a predicate
// with `.and` / `.or`; the reduced predicate is written to `dst1`.
BarRed {
// NOTE(review): presumably the type applied to arguments that do not declare
// their own `type` (here `src_barrier` and `src_threadcount`) — confirm
// against `generate_instruction_type!`.
type: Type::Scalar(ScalarType::U32),
data: BarRedData,
arguments<T>: {
// Destination predicate receiving the reduction result (`p` in the parser rule).
dst1: {
repr: T,
type: Type::from(ScalarType::Pred)
},
// Barrier operand (`a` in the parser rule).
src_barrier: T,
// Optional thread-count operand (`b` in the parser rule).
src_threadcount: Option<T>,
// Predicate contributed by the executing thread (`c` in the parser rule).
src_predicate: {
repr: T,
type: Type::from(ScalarType::Pred)
},
// Predicate-typed flag telling whether the source predicate was written as `!c`.
src_negate_predicate: {
repr: T,
type: Type::from(ScalarType::Pred)
},
}
},
Bfe {
type: Type::Scalar(data.clone()),
data: ScalarType,
@ -1745,6 +1765,12 @@ pub struct BarData {
pub aligned: bool,
}
/// Modifier data attached to the `BarRed` instruction (`bar{.cta}.red.op.pred`).
#[derive(Copy, Clone)]
pub struct BarRedData {
/// Alignment flag; the `bar{.cta}.red.op.pred` parser rule always sets it to `true`.
pub aligned: bool,
/// Reduction operator selected by the `.op` modifier (`.and` or `.or`).
pub pred_reduction: Reduction,
}
pub struct AtomDetails {
pub type_: Type,
pub semantics: AtomSemantics,

View File

@ -1705,6 +1705,9 @@ derive_parser!(
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum Mul24Control { }
// Reduction operator of `bar.red.op.pred`.
// NOTE(review): the body is empty here; the `And` / `Or` variants used by the
// compiler passes are presumably generated by `derive_parser!` from the
// `.op: Reduction = { .and, .or }` rule — confirm.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum Reduction { }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
mov{.vec}.type d, a => {
Instruction::Mov {
@ -2987,7 +2990,9 @@ derive_parser!(
barrier{.cta}.sync{.aligned} a{, b} => {
let _ = cta;
ast::Instruction::Bar {
data: ast::BarData { aligned },
data: ast::BarData {
aligned,
},
arguments: BarArgs { src1: a, src2: b }
}
}
@ -2997,14 +3002,32 @@ derive_parser!(
bar{.cta}.sync a{, b} => {
let _ = cta;
ast::Instruction::Bar {
data: ast::BarData { aligned: true },
data: ast::BarData {
aligned: true,
},
arguments: BarArgs { src1: a, src2: b }
}
}
//bar{.cta}.arrive a, b;
//bar{.cta}.red.popc.u32 d, a{, b}, {!}c;
//bar{.cta}.red.op.pred p, a{, b}, {!}c;
//.op = { .and, .or };
// Parses `bar{.cta}.red.op.pred p, a{, b}, {!}c` into `Instruction::BarRed`.
// The optional `.cta` modifier is accepted and discarded, and `aligned` is
// always reported as `true`.
bar{.cta}.red.op.pred p, a{, b}, {!}c => {
let _ = cta;
// `{!}c` arrives as a (was-negated, operand) pair.
let (negate_src3, c) = c;
ast::Instruction::BarRed {
data: ast::BarRedData {
aligned: true,
pred_reduction: op,
},
arguments: BarRedArgs {
dst1: p,
src_barrier: a,
// `b` is the optional thread count; it stays `None` when omitted.
src_threadcount: b,
src_predicate: c,
// The negation flag travels as an immediate operand: 1 if `!` was present, 0 otherwise.
src_negate_predicate: ParsedOperand::Imm(ImmediateValue::U64(negate_src3 as u64))
}
}
}
// Reduction operators accepted for `.op` in the rule above.
.op: Reduction = { .and, .or };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
atom{.sem}{.scope}{.space}.op{.level::cache_hint}.type d, [a], b{, cache_policy} => {

View File

@ -784,7 +784,7 @@ fn emit_definition_parser(
};
let can_be_negated = if arg.can_be_negated {
quote! {
opt(any.verify(|(t, _)| *t == #token_type::Not)).map(|o| o.is_some())
opt(any.verify(|(t, _)| *t == #token_type::Exclamation)).map(|o| o.is_some())
}
} else {
quote! {