mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-20 10:46:21 +03:00
Add support for bar.red.and.pred
(#402)
Implements bar.red.and.pred and bar.red.or.pred, using the undocumented __ockl_wgred functions. Doesn't yet add support for numbered barriers and threadcount, as these are not needed for llm.c.
This commit is contained in:
Binary file not shown.
@ -157,6 +157,19 @@ extern "C"
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
int32_t __ockl_wgred_and_i32(int32_t) __device__;
|
||||
int32_t __ockl_wgred_or_i32(int32_t) __device__;
|
||||
|
||||
// BAR_RED_IMPL(reducer) defines FUNC(bar_red_<reducer>_pred), a function that
// XORs `predicate` with `invert_predicate` and passes the result to
// __ockl_wgred_<reducer>_i32, returning that call's result as a bool.
// The `barrier` argument is currently ignored (it is marked unused; see the
// TODO in the body).
// NOTE(review): __ockl_wgred_* is presumably a work-group-wide reduction that
// also synchronizes the group - confirm against the device library sources.
#define BAR_RED_IMPL(reducer) \
|
||||
bool FUNC(bar_red_##reducer##_pred)(uint32_t barrier __attribute__((unused)), bool predicate, bool invert_predicate) \
|
||||
{ \
|
||||
/* TODO: handle barrier */ \
|
||||
return __ockl_wgred_##reducer##_i32(predicate ^ invert_predicate); \
|
||||
}
|
||||
|
||||
BAR_RED_IMPL(and);
|
||||
BAR_RED_IMPL(or);
|
||||
|
||||
void FUNC(__assertfail)(uint64_t message,
|
||||
uint64_t file,
|
||||
uint32_t line,
|
||||
|
@ -639,6 +639,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||
// replaced by a function call
|
||||
ast::Instruction::Bfe { .. }
|
||||
| ast::Instruction::Bar { .. }
|
||||
| ast::Instruction::BarRed { .. }
|
||||
| ast::Instruction::Bfi { .. }
|
||||
| ast::Instruction::Activemask { .. } => return Err(error_unreachable()),
|
||||
}
|
||||
|
@ -75,6 +75,7 @@ fn run_instruction<'input>(
|
||||
| ast::Instruction::Atom { .. }
|
||||
| ast::Instruction::AtomCas { .. }
|
||||
| ast::Instruction::Bar { .. }
|
||||
| ast::Instruction::BarRed { .. }
|
||||
| ast::Instruction::Bfe { .. }
|
||||
| ast::Instruction::Bfi { .. }
|
||||
| ast::Instruction::Bra { .. }
|
||||
|
@ -1804,6 +1804,7 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
||||
| ast::Instruction::Selp { .. }
|
||||
| ast::Instruction::Ret { .. }
|
||||
| ast::Instruction::Bar { .. }
|
||||
| ast::Instruction::BarRed { .. }
|
||||
| ast::Instruction::Cvta { .. }
|
||||
| ast::Instruction::Atom { .. }
|
||||
| ast::Instruction::Mul24 { .. }
|
||||
|
@ -108,6 +108,16 @@ fn run_instruction<'input>(
|
||||
i @ ptx_parser::Instruction::Bar { .. } => {
|
||||
to_call(resolver, fn_declarations, "bar_sync".into(), i)?
|
||||
}
|
||||
ptx_parser::Instruction::BarRed { data, arguments } => {
|
||||
if arguments.src_threadcount.is_some() {
|
||||
return Err(error_todo());
|
||||
}
|
||||
let name = match data.pred_reduction {
|
||||
ptx_parser::Reduction::And => "bar_red_and_pred",
|
||||
ptx_parser::Reduction::Or => "bar_red_or_pred",
|
||||
};
|
||||
to_call(resolver, fn_declarations, name.into(), ptx_parser::Instruction::BarRed { data, arguments })?
|
||||
}
|
||||
i => i,
|
||||
})
|
||||
}
|
||||
|
121
ptx/src/test/ll/bar_red_and_pred.ll
Normal file
121
ptx/src/test/ll/bar_red_and_pred.ll
Normal file
@ -0,0 +1,121 @@
|
||||
declare i1 @__zluda_ptx_impl_bar_red_and_pred(i32, i1, i1) #0
|
||||
|
||||
declare i1 @__zluda_ptx_impl_bar_red_or_pred(i32, i1, i1) #0
|
||||
|
||||
declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0
|
||||
|
||||
define amdgpu_kernel void @bar_red_and_pred(ptr addrspace(4) byref(i64) %"73", ptr addrspace(4) byref(i64) %"74") #1 {
|
||||
%"75" = alloca i64, align 8, addrspace(5)
|
||||
%"76" = alloca i64, align 8, addrspace(5)
|
||||
%"77" = alloca i32, align 4, addrspace(5)
|
||||
%"78" = alloca i32, align 4, addrspace(5)
|
||||
%"79" = alloca i1, align 1, addrspace(5)
|
||||
%"80" = alloca i1, align 1, addrspace(5)
|
||||
%"81" = alloca i32, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"70"
|
||||
|
||||
"70": ; preds = %1
|
||||
%"82" = load i64, ptr addrspace(4) %"74", align 4
|
||||
store i64 %"82", ptr addrspace(5) %"75", align 4
|
||||
%"44" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
|
||||
br label %"71"
|
||||
|
||||
"71": ; preds = %"70"
|
||||
store i32 %"44", ptr addrspace(5) %"77", align 4
|
||||
%"85" = load i32, ptr addrspace(5) %"77", align 4
|
||||
%"84" = urem i32 %"85", 2
|
||||
store i32 %"84", ptr addrspace(5) %"78", align 4
|
||||
%"87" = load i32, ptr addrspace(5) %"78", align 4
|
||||
%"86" = icmp eq i32 %"87", 0
|
||||
store i1 %"86", ptr addrspace(5) %"80", align 1
|
||||
store i32 0, ptr addrspace(5) %"81", align 4
|
||||
%"90" = load i1, ptr addrspace(5) %"80", align 1
|
||||
%"89" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"90", i1 false)
|
||||
store i1 %"89", ptr addrspace(5) %"79", align 1
|
||||
%"91" = load i1, ptr addrspace(5) %"79", align 1
|
||||
br i1 %"91", label %"17", label %"18"
|
||||
|
||||
"17": ; preds = %"71"
|
||||
%"93" = load i32, ptr addrspace(5) %"81", align 4
|
||||
%"92" = add i32 %"93", 1
|
||||
store i32 %"92", ptr addrspace(5) %"81", align 4
|
||||
br label %"18"
|
||||
|
||||
"18": ; preds = %"17", %"71"
|
||||
%"95" = load i1, ptr addrspace(5) %"80", align 1
|
||||
%"94" = call i1 @__zluda_ptx_impl_bar_red_or_pred(i32 1, i1 %"95", i1 false)
|
||||
store i1 %"94", ptr addrspace(5) %"79", align 1
|
||||
%"96" = load i1, ptr addrspace(5) %"79", align 1
|
||||
br i1 %"96", label %"19", label %"20"
|
||||
|
||||
"19": ; preds = %"18"
|
||||
%"98" = load i32, ptr addrspace(5) %"81", align 4
|
||||
%"97" = add i32 %"98", 1
|
||||
store i32 %"97", ptr addrspace(5) %"81", align 4
|
||||
br label %"20"
|
||||
|
||||
"20": ; preds = %"19", %"18"
|
||||
store i1 true, ptr addrspace(5) %"80", align 1
|
||||
%"101" = load i1, ptr addrspace(5) %"80", align 1
|
||||
%"100" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"101", i1 false)
|
||||
store i1 %"100", ptr addrspace(5) %"79", align 1
|
||||
%"102" = load i1, ptr addrspace(5) %"79", align 1
|
||||
br i1 %"102", label %"21", label %"22"
|
||||
|
||||
"21": ; preds = %"20"
|
||||
%"104" = load i32, ptr addrspace(5) %"81", align 4
|
||||
%"103" = add i32 %"104", 1
|
||||
store i32 %"103", ptr addrspace(5) %"81", align 4
|
||||
br label %"22"
|
||||
|
||||
"22": ; preds = %"21", %"20"
|
||||
store i1 false, ptr addrspace(5) %"80", align 1
|
||||
%"107" = load i1, ptr addrspace(5) %"80", align 1
|
||||
%"106" = call i1 @__zluda_ptx_impl_bar_red_or_pred(i32 1, i1 %"107", i1 false)
|
||||
store i1 %"106", ptr addrspace(5) %"79", align 1
|
||||
%"108" = load i1, ptr addrspace(5) %"79", align 1
|
||||
br i1 %"108", label %"23", label %"24"
|
||||
|
||||
"23": ; preds = %"22"
|
||||
%"110" = load i32, ptr addrspace(5) %"81", align 4
|
||||
%"109" = add i32 %"110", 1
|
||||
store i32 %"109", ptr addrspace(5) %"81", align 4
|
||||
br label %"24"
|
||||
|
||||
"24": ; preds = %"23", %"22"
|
||||
store i1 true, ptr addrspace(5) %"80", align 1
|
||||
%"113" = load i1, ptr addrspace(5) %"80", align 1
|
||||
%"112" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"113", i1 true)
|
||||
store i1 %"112", ptr addrspace(5) %"79", align 1
|
||||
%"114" = load i1, ptr addrspace(5) %"79", align 1
|
||||
br i1 %"114", label %"25", label %"26"
|
||||
|
||||
"25": ; preds = %"24"
|
||||
%"116" = load i32, ptr addrspace(5) %"81", align 4
|
||||
%"115" = add i32 %"116", 1
|
||||
store i32 %"115", ptr addrspace(5) %"81", align 4
|
||||
br label %"26"
|
||||
|
||||
"26": ; preds = %"25", %"24"
|
||||
%"118" = load i32, ptr addrspace(5) %"77", align 4
|
||||
%"117" = zext i32 %"118" to i64
|
||||
store i64 %"117", ptr addrspace(5) %"76", align 4
|
||||
%"120" = load i64, ptr addrspace(5) %"76", align 4
|
||||
%"119" = mul i64 %"120", 4
|
||||
store i64 %"119", ptr addrspace(5) %"76", align 4
|
||||
%"122" = load i64, ptr addrspace(5) %"75", align 4
|
||||
%"123" = load i64, ptr addrspace(5) %"76", align 4
|
||||
%"121" = add i64 %"122", %"123"
|
||||
store i64 %"121", ptr addrspace(5) %"75", align 4
|
||||
%"124" = load i64, ptr addrspace(5) %"75", align 4
|
||||
%"125" = load i32, ptr addrspace(5) %"81", align 4
|
||||
%"126" = inttoptr i64 %"124" to ptr
|
||||
store i32 %"125", ptr %"126", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
60
ptx/src/test/spirv_run/bar_red_and_pred.ptx
Normal file
60
ptx/src/test/spirv_run/bar_red_and_pred.ptx
Normal file
@ -0,0 +1,60 @@
|
||||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry bar_red_and_pred(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 out_addr;
|
||||
.reg .u64 out_index;
|
||||
.reg .u32 thread_id;
|
||||
.reg .u32 thread_mod_2;
|
||||
.reg .pred pred;
|
||||
.reg .pred cond;
|
||||
.reg .u32 result;
|
||||
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
mov.u32 thread_id, %tid.x;
|
||||
rem.u32 thread_mod_2, thread_id, 2;
|
||||
setp.eq.u32 cond, thread_mod_2, 0;
|
||||
|
||||
mov.u32 result, 0;
|
||||
|
||||
// Basic functionality
|
||||
|
||||
// result += AND(tid.x % 2 == 0) forall threads
|
||||
bar.red.and.pred pred, 1, cond;
|
||||
@pred add.u32 result, result, 1;
|
||||
// result += OR(tid.x % 2 == 0) forall threads
|
||||
bar.red.or.pred pred, 1, cond;
|
||||
@pred add.u32 result, result, 1;
|
||||
|
||||
// result += AND(true) forall threads
|
||||
setp.eq.u32 cond, 1, 1;
|
||||
bar.red.and.pred pred, 1, cond;
|
||||
@pred add.u32 result, result, 1;
|
||||
// result += OR(false) forall threads
|
||||
setp.eq.u32 cond, 1, 0;
|
||||
bar.red.or.pred pred, 1, cond;
|
||||
@pred add.u32 result, result, 1;
|
||||
|
||||
// Negated condition
|
||||
// result += AND(!true) forall threads
|
||||
setp.eq.u32 cond, 1, 1;
|
||||
bar.red.and.pred pred, 1, !cond;
|
||||
@pred add.u32 result, result, 1;
|
||||
|
||||
// Return result
|
||||
|
||||
cvt.u64.u32 out_index, thread_id;
|
||||
mul.lo.u64 out_index, out_index, 4;
|
||||
add.u64 out_addr, out_addr, out_index;
|
||||
st.u32 [out_addr], result;
|
||||
|
||||
// result should be 2
|
||||
|
||||
ret;
|
||||
}
|
@ -307,6 +307,13 @@ test_ptx_warp!(tid, [
|
||||
32u8, 33u8, 34u8, 35u8, 36u8, 37u8, 38u8, 39u8, 40u8, 41u8, 42u8, 43u8, 44u8, 45u8, 46u8, 47u8,
|
||||
48u8, 49u8, 50u8, 51u8, 52u8, 53u8, 54u8, 55u8, 56u8, 57u8, 58u8, 59u8, 60u8, 61u8, 62u8, 63u8,
|
||||
]);
|
||||
test_ptx_warp!(bar_red_and_pred, [
|
||||
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
|
||||
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
|
||||
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
|
||||
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
|
||||
]);
|
||||
|
||||
struct DisplayError<T: Debug> {
|
||||
err: T,
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ use super::{
|
||||
AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp,
|
||||
StateSpace, VectorPrefix,
|
||||
};
|
||||
use crate::{Mul24Control, PtxError, PtxParserState};
|
||||
use crate::{Mul24Control, Reduction, PtxError, PtxParserState};
|
||||
use bitflags::bitflags;
|
||||
use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8};
|
||||
|
||||
@ -95,6 +95,26 @@ ptx_parser_macros::generate_instruction_type!(
|
||||
src2: Option<T>,
|
||||
}
|
||||
},
|
||||
BarRed {
|
||||
type: Type::Scalar(ScalarType::U32),
|
||||
data: BarRedData,
|
||||
arguments<T>: {
|
||||
dst1: {
|
||||
repr: T,
|
||||
type: Type::from(ScalarType::Pred)
|
||||
},
|
||||
src_barrier: T,
|
||||
src_threadcount: Option<T>,
|
||||
src_predicate: {
|
||||
repr: T,
|
||||
type: Type::from(ScalarType::Pred)
|
||||
},
|
||||
src_negate_predicate: {
|
||||
repr: T,
|
||||
type: Type::from(ScalarType::Pred)
|
||||
},
|
||||
}
|
||||
},
|
||||
Bfe {
|
||||
type: Type::Scalar(data.clone()),
|
||||
data: ScalarType,
|
||||
@ -1745,6 +1765,12 @@ pub struct BarData {
|
||||
pub aligned: bool,
|
||||
}
|
||||
|
||||
/// Modifier data carried by a barrier-reduction instruction.
/// NOTE(review): presumably the `data` payload of `Instruction::BarRed` - confirm at the use site.
#[derive(Copy, Clone)]
|
||||
pub struct BarRedData {
|
||||
/// NOTE(review): presumably has the same meaning as `BarData::aligned` - confirm.
pub aligned: bool,
|
||||
/// Which reduction is applied to the predicate operand.
pub pred_reduction: Reduction,
|
||||
}
|
||||
|
||||
pub struct AtomDetails {
|
||||
pub type_: Type,
|
||||
pub semantics: AtomSemantics,
|
||||
|
@ -1705,6 +1705,9 @@ derive_parser!(
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum Mul24Control { }
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum Reduction { }
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
|
||||
mov{.vec}.type d, a => {
|
||||
Instruction::Mov {
|
||||
@ -2987,7 +2990,9 @@ derive_parser!(
|
||||
barrier{.cta}.sync{.aligned} a{, b} => {
|
||||
let _ = cta;
|
||||
ast::Instruction::Bar {
|
||||
data: ast::BarData { aligned },
|
||||
data: ast::BarData {
|
||||
aligned,
|
||||
},
|
||||
arguments: BarArgs { src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
@ -2997,14 +3002,32 @@ derive_parser!(
|
||||
bar{.cta}.sync a{, b} => {
|
||||
let _ = cta;
|
||||
ast::Instruction::Bar {
|
||||
data: ast::BarData { aligned: true },
|
||||
data: ast::BarData {
|
||||
aligned: true,
|
||||
},
|
||||
arguments: BarArgs { src1: a, src2: b }
|
||||
}
|
||||
}
|
||||
//bar{.cta}.arrive a, b;
|
||||
//bar{.cta}.red.popc.u32 d, a{, b}, {!}c;
|
||||
//bar{.cta}.red.op.pred p, a{, b}, {!}c;
|
||||
//.op = { .and, .or };
|
||||
bar{.cta}.red.op.pred p, a{, b}, {!}c => {
|
||||
let _ = cta;
|
||||
let (negate_src3, c) = c;
|
||||
ast::Instruction::BarRed {
|
||||
data: ast::BarRedData {
|
||||
aligned: true,
|
||||
pred_reduction: op,
|
||||
},
|
||||
arguments: BarRedArgs {
|
||||
dst1: p,
|
||||
src_barrier: a,
|
||||
src_threadcount: b,
|
||||
src_predicate: c,
|
||||
src_negate_predicate: ParsedOperand::Imm(ImmediateValue::U64(negate_src3 as u64))
|
||||
}
|
||||
}
|
||||
}
|
||||
.op: Reduction = { .and, .or };
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
|
||||
atom{.sem}{.scope}{.space}.op{.level::cache_hint}.type d, [a], b{, cache_policy} => {
|
||||
|
@ -784,7 +784,7 @@ fn emit_definition_parser(
|
||||
};
|
||||
let can_be_negated = if arg.can_be_negated {
|
||||
quote! {
|
||||
opt(any.verify(|(t, _)| *t == #token_type::Not)).map(|o| o.is_some())
|
||||
opt(any.verify(|(t, _)| *t == #token_type::Exclamation)).map(|o| o.is_some())
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
|
Reference in New Issue
Block a user