Add support for shfl.sync.MODE.b32 (#409)

This commit is contained in:
Violet
2025-07-16 17:23:11 -07:00
committed by GitHub
parent 36f0ba9cbb
commit dc69808e54
20 changed files with 623 additions and 4 deletions

2
.gitignore vendored
View File

@ -3,3 +3,5 @@ Cargo.lock
.vscode/ .vscode/
.idea/ .idea/
ptx/lib/zluda_ptx_impl.ll

Binary file not shown.

View File

@ -4,6 +4,7 @@
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <hip/amd_detail/amd_device_functions.h> #include <hip/amd_detail/amd_device_functions.h>
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
@ -170,6 +171,55 @@ extern "C"
BAR_RED_IMPL(and); BAR_RED_IMPL(and);
BAR_RED_IMPL(or); BAR_RED_IMPL(or);
// Return value of the shfl_sync_*_b32_pred helpers: the shuffled value together with the
// result of the range check on the source lane.
struct ShflSyncResult {
// Value read from the selected lane (the caller's own input when that lane was out of bounds).
uint32_t output;
// True when the computed source lane passed the bounds check.
bool in_bounds;
};
// shfl.sync opts consists of two values, the warp end ID and the subsection mask.
//
// The current warp is partitioned into some number of subsections with a width of w. The
// subsection mask is 32 - w, and indicates which bits of the lane id are part of the subsection
// address. For example, if each subsection is 8 lanes wide, the subsection mask will be 24
// (0b11000 in binary). This indicates that the two most significant bits in the 5-bit lane ID are
// the subsection address. For example, for a lane ID 13 (0b01101) the address of the beginning
// of the subsection is 0b01000 (8).
//
// The warp end ID is the max lane ID for a specific mode. For the CUDA __shfl_sync
// intrinsics, it is always 31 for idx, bfly, and down, and 0 for up. This is used for the
// bounds check.
// Defines the two entry points for one shfl.sync mode:
//   * shfl_sync_<mode>_b32_pred returns the shuffled value plus the in-bounds predicate,
//   * shfl_sync_<mode>_b32 returns only the shuffled value.
// `calculate_index` is an expression producing the source lane; it may refer to `self`,
// `delta`, `section_mask` and `subsection`. `CMP` is the comparison against the end of the
// subsection that marks the computed lane as out of bounds; such lanes read their own input.
// PTX defines the lane operand as b[4:0], so `delta` is reduced to its low five bits before
// `calculate_index` is evaluated. Only block comments are used inside the macro body because
// a line comment would swallow the continuation lines that follow it.
#define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \
ShflSyncResult FUNC(shfl_sync_##mode##_b32_pred)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask __attribute__((unused))) \
{ \
    /* Only bits [4:0] of the lane operand are significant (bval = b[4:0] in the PTX spec). */ \
    delta &= 0b11111; \
    /* opts[12:8] is the subsection mask, opts[4:0] is the warp end (clamp) lane. */ \
    int32_t section_mask = (opts >> 8) & 0b11111; \
    int32_t warp_end = opts & 0b11111; \
    int32_t self = (int32_t)__lane_id(); \
    /* First lane of the current subsection and the clamp lane inside that subsection. */ \
    int32_t subsection = section_mask & self; \
    int32_t subsection_end = subsection | (~section_mask & warp_end); \
    int32_t idx = calculate_index; \
    bool out_of_bounds = idx CMP subsection_end; \
    if (out_of_bounds) { \
        /* Out-of-range lanes read back their own value. */ \
        idx = self; \
    } \
    /* ds_bpermute takes a byte address, hence lane index * 4. */ \
    int32_t output = __builtin_amdgcn_ds_bpermute(idx<<2, (int32_t)input); \
    return {(uint32_t)output, !out_of_bounds}; \
} \
\
uint32_t FUNC(shfl_sync_##mode##_b32)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask) \
{ \
    return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).output; \
}
// Neither the HIP __shfl nor the __shfl_sync intrinsics return the result of the range check
// (the _sync variants only add an assertion checking that the membermask is used correctly),
// so the lane selection and the bounds check are replicated here on top of ds_bpermute.
// up: read from the lane `delta` positions below; out of bounds when it falls below the subsection end lane.
SHFL_SYNC_IMPL(up, self - delta, <);
// down: read from the lane `delta` positions above; out of bounds when it lies past the subsection end lane.
SHFL_SYNC_IMPL(down, self + delta, >);
// bfly: read from the lane whose ID is `self ^ delta`.
SHFL_SYNC_IMPL(bfly, self ^ delta, >);
// idx: read from lane `delta`, taken relative to the start of the current subsection.
SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >);
void FUNC(__assertfail)(uint64_t message, void FUNC(__assertfail)(uint64_t message,
uint64_t file, uint64_t file,
uint32_t line, uint32_t line,

View File

@ -641,7 +641,8 @@ impl<'a> MethodEmitContext<'a> {
| ast::Instruction::Bar { .. } | ast::Instruction::Bar { .. }
| ast::Instruction::BarRed { .. } | ast::Instruction::BarRed { .. }
| ast::Instruction::Bfi { .. } | ast::Instruction::Bfi { .. }
| ast::Instruction::Activemask { .. } => return Err(error_unreachable()), | ast::Instruction::Activemask { .. }
| ast::Instruction::ShflSync { .. } => return Err(error_unreachable()),
} }
} }

View File

@ -165,6 +165,7 @@ fn run_instruction<'input>(
| ast::Instruction::Selp { .. } | ast::Instruction::Selp { .. }
| ast::Instruction::Setp { .. } | ast::Instruction::Setp { .. }
| ast::Instruction::SetpBool { .. } | ast::Instruction::SetpBool { .. }
| ast::Instruction::ShflSync { .. }
| ast::Instruction::Shl { .. } | ast::Instruction::Shl { .. }
| ast::Instruction::Shr { .. } | ast::Instruction::Shr { .. }
| ast::Instruction::Sin { .. } | ast::Instruction::Sin { .. }

View File

@ -1800,6 +1800,7 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
| ast::Instruction::Bfe { .. } | ast::Instruction::Bfe { .. }
| ast::Instruction::Bfi { .. } | ast::Instruction::Bfi { .. }
| ast::Instruction::Shr { .. } | ast::Instruction::Shr { .. }
| ast::Instruction::ShflSync { .. }
| ast::Instruction::Shl { .. } | ast::Instruction::Shl { .. }
| ast::Instruction::Selp { .. } | ast::Instruction::Selp { .. }
| ast::Instruction::Ret { .. } | ast::Instruction::Ret { .. }

View File

@ -118,6 +118,25 @@ fn run_instruction<'input>(
}; };
to_call(resolver, fn_declarations, name.into(), ptx_parser::Instruction::BarRed { data, arguments })? to_call(resolver, fn_declarations, name.into(), ptx_parser::Instruction::BarRed { data, arguments })?
} }
// shfl.sync is passed to `to_call` under the name `shfl_sync_<mode>_b32`, with a `_pred`
// suffix when the instruction also writes a predicate destination.
ptx_parser::Instruction::ShflSync { data, arguments } => {
    let suffix = match &arguments.dst_pred {
        Some(_) => "_pred",
        None => "",
    };
    let mode_name = match data.mode {
        ptx_parser::ShuffleMode::Idx => "idx",
        ptx_parser::ShuffleMode::BFly => "bfly",
        ptx_parser::ShuffleMode::Down => "down",
        ptx_parser::ShuffleMode::Up => "up",
    };
    let callee = format!("shfl_sync_{}_b32{}", mode_name, suffix);
    to_call(
        resolver,
        fn_declarations,
        callee.into(),
        ptx_parser::Instruction::ShflSync { data, arguments },
    )?
}
i => i, i => i,
}) })
} }

View File

@ -0,0 +1,59 @@
declare [2 x i32] @__zluda_ptx_impl_shfl_sync_bfly_b32_pred(i32, i32, i32, i32) #0
declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @shfl_sync_bfly_b32_pred(ptr addrspace(4) byref(i64) %"42") #1 {
%"43" = alloca i64, align 8, addrspace(5)
%"44" = alloca i64, align 8, addrspace(5)
%"45" = alloca i32, align 4, addrspace(5)
%"46" = alloca i32, align 4, addrspace(5)
%"47" = alloca i1, align 1, addrspace(5)
br label %1
1: ; preds = %0
br label %"39"
"39": ; preds = %1
%"48" = load i64, ptr addrspace(4) %"42", align 4
store i64 %"48", ptr addrspace(5) %"43", align 4
%"33" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"40"
"40": ; preds = %"39"
store i32 %"33", ptr addrspace(5) %"45", align 4
%"52" = load i32, ptr addrspace(5) %"45", align 4
%2 = call [2 x i32] @__zluda_ptx_impl_shfl_sync_bfly_b32_pred(i32 %"52", i32 3, i32 31, i32 -1)
%"65" = extractvalue [2 x i32] %2, 0
%3 = extractvalue [2 x i32] %2, 1
%"51" = trunc i32 %3 to i1
store i32 %"65", ptr addrspace(5) %"46", align 4
store i1 %"51", ptr addrspace(5) %"47", align 1
%"53" = load i1, ptr addrspace(5) %"47", align 1
br i1 %"53", label %"15", label %"14"
"14": ; preds = %"40"
%"55" = load i32, ptr addrspace(5) %"46", align 4
%"54" = add i32 %"55", 1000
store i32 %"54", ptr addrspace(5) %"46", align 4
br label %"15"
"15": ; preds = %"14", %"40"
%"57" = load i32, ptr addrspace(5) %"45", align 4
%"56" = zext i32 %"57" to i64
store i64 %"56", ptr addrspace(5) %"44", align 4
%"59" = load i64, ptr addrspace(5) %"44", align 4
%"58" = mul i64 %"59", 4
store i64 %"58", ptr addrspace(5) %"44", align 4
%"61" = load i64, ptr addrspace(5) %"43", align 4
%"62" = load i64, ptr addrspace(5) %"44", align 4
%"60" = add i64 %"61", %"62"
store i64 %"60", ptr addrspace(5) %"43", align 4
%"63" = load i64, ptr addrspace(5) %"43", align 4
%"64" = load i32, ptr addrspace(5) %"46", align 4
%"67" = inttoptr i64 %"63" to ptr
store i32 %"64", ptr %"67", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -0,0 +1,59 @@
declare [2 x i32] @__zluda_ptx_impl_shfl_sync_down_b32_pred(i32, i32, i32, i32) #0
declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @shfl_sync_down_b32_pred(ptr addrspace(4) byref(i64) %"42") #1 {
%"43" = alloca i64, align 8, addrspace(5)
%"44" = alloca i64, align 8, addrspace(5)
%"45" = alloca i32, align 4, addrspace(5)
%"46" = alloca i32, align 4, addrspace(5)
%"47" = alloca i1, align 1, addrspace(5)
br label %1
1: ; preds = %0
br label %"39"
"39": ; preds = %1
%"48" = load i64, ptr addrspace(4) %"42", align 4
store i64 %"48", ptr addrspace(5) %"43", align 4
%"33" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"40"
"40": ; preds = %"39"
store i32 %"33", ptr addrspace(5) %"45", align 4
%"52" = load i32, ptr addrspace(5) %"45", align 4
%2 = call [2 x i32] @__zluda_ptx_impl_shfl_sync_down_b32_pred(i32 %"52", i32 3, i32 31, i32 -1)
%"65" = extractvalue [2 x i32] %2, 0
%3 = extractvalue [2 x i32] %2, 1
%"51" = trunc i32 %3 to i1
store i32 %"65", ptr addrspace(5) %"46", align 4
store i1 %"51", ptr addrspace(5) %"47", align 1
%"53" = load i1, ptr addrspace(5) %"47", align 1
br i1 %"53", label %"15", label %"14"
"14": ; preds = %"40"
%"55" = load i32, ptr addrspace(5) %"46", align 4
%"54" = add i32 %"55", 1000
store i32 %"54", ptr addrspace(5) %"46", align 4
br label %"15"
"15": ; preds = %"14", %"40"
%"57" = load i32, ptr addrspace(5) %"45", align 4
%"56" = zext i32 %"57" to i64
store i64 %"56", ptr addrspace(5) %"44", align 4
%"59" = load i64, ptr addrspace(5) %"44", align 4
%"58" = mul i64 %"59", 4
store i64 %"58", ptr addrspace(5) %"44", align 4
%"61" = load i64, ptr addrspace(5) %"43", align 4
%"62" = load i64, ptr addrspace(5) %"44", align 4
%"60" = add i64 %"61", %"62"
store i64 %"60", ptr addrspace(5) %"43", align 4
%"63" = load i64, ptr addrspace(5) %"43", align 4
%"64" = load i32, ptr addrspace(5) %"46", align 4
%"67" = inttoptr i64 %"63" to ptr
store i32 %"64", ptr %"67", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -0,0 +1,59 @@
declare [2 x i32] @__zluda_ptx_impl_shfl_sync_idx_b32_pred(i32, i32, i32, i32) #0
declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @shfl_sync_idx_b32_pred(ptr addrspace(4) byref(i64) %"42") #1 {
%"43" = alloca i64, align 8, addrspace(5)
%"44" = alloca i64, align 8, addrspace(5)
%"45" = alloca i32, align 4, addrspace(5)
%"46" = alloca i32, align 4, addrspace(5)
%"47" = alloca i1, align 1, addrspace(5)
br label %1
1: ; preds = %0
br label %"39"
"39": ; preds = %1
%"48" = load i64, ptr addrspace(4) %"42", align 4
store i64 %"48", ptr addrspace(5) %"43", align 4
%"33" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"40"
"40": ; preds = %"39"
store i32 %"33", ptr addrspace(5) %"45", align 4
%"52" = load i32, ptr addrspace(5) %"45", align 4
%2 = call [2 x i32] @__zluda_ptx_impl_shfl_sync_idx_b32_pred(i32 %"52", i32 12, i32 31, i32 -1)
%"65" = extractvalue [2 x i32] %2, 0
%3 = extractvalue [2 x i32] %2, 1
%"51" = trunc i32 %3 to i1
store i32 %"65", ptr addrspace(5) %"46", align 4
store i1 %"51", ptr addrspace(5) %"47", align 1
%"53" = load i1, ptr addrspace(5) %"47", align 1
br i1 %"53", label %"15", label %"14"
"14": ; preds = %"40"
%"55" = load i32, ptr addrspace(5) %"46", align 4
%"54" = add i32 %"55", 1000
store i32 %"54", ptr addrspace(5) %"46", align 4
br label %"15"
"15": ; preds = %"14", %"40"
%"57" = load i32, ptr addrspace(5) %"45", align 4
%"56" = zext i32 %"57" to i64
store i64 %"56", ptr addrspace(5) %"44", align 4
%"59" = load i64, ptr addrspace(5) %"44", align 4
%"58" = mul i64 %"59", 4
store i64 %"58", ptr addrspace(5) %"44", align 4
%"61" = load i64, ptr addrspace(5) %"43", align 4
%"62" = load i64, ptr addrspace(5) %"44", align 4
%"60" = add i64 %"61", %"62"
store i64 %"60", ptr addrspace(5) %"43", align 4
%"63" = load i64, ptr addrspace(5) %"43", align 4
%"64" = load i32, ptr addrspace(5) %"46", align 4
%"67" = inttoptr i64 %"63" to ptr
store i32 %"64", ptr %"67", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -0,0 +1,74 @@
declare i32 @__zluda_ptx_impl_shfl_sync_down_b32(i32, i32, i32, i32) #0
declare i32 @__zluda_ptx_impl_shfl_sync_up_b32(i32, i32, i32, i32) #0
declare i32 @__zluda_ptx_impl_shfl_sync_bfly_b32(i32, i32, i32, i32) #0
declare i32 @__zluda_ptx_impl_shfl_sync_idx_b32(i32, i32, i32, i32) #0
declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @shfl_sync_mode_b32(ptr addrspace(4) byref(i64) %"48") #1 {
%"49" = alloca i64, align 8, addrspace(5)
%"50" = alloca i64, align 8, addrspace(5)
%"51" = alloca i32, align 4, addrspace(5)
%"52" = alloca i32, align 4, addrspace(5)
%"53" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"45"
"45": ; preds = %1
%"54" = load i64, ptr addrspace(4) %"48", align 4
store i64 %"54", ptr addrspace(5) %"49", align 4
%"31" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"46"
"46": ; preds = %"45"
store i32 %"31", ptr addrspace(5) %"51", align 4
%"57" = load i32, ptr addrspace(5) %"51", align 4
%"84" = call i32 @__zluda_ptx_impl_shfl_sync_up_b32(i32 %"57", i32 3, i32 7680, i32 -1)
store i32 %"84", ptr addrspace(5) %"52", align 4
%"59" = load i32, ptr addrspace(5) %"52", align 4
store i32 %"59", ptr addrspace(5) %"53", align 4
%"61" = load i32, ptr addrspace(5) %"51", align 4
%"86" = call i32 @__zluda_ptx_impl_shfl_sync_down_b32(i32 %"61", i32 3, i32 7199, i32 -1)
store i32 %"86", ptr addrspace(5) %"52", align 4
%"63" = load i32, ptr addrspace(5) %"53", align 4
%"64" = load i32, ptr addrspace(5) %"52", align 4
%"62" = add i32 %"63", %"64"
store i32 %"62", ptr addrspace(5) %"53", align 4
%"66" = load i32, ptr addrspace(5) %"51", align 4
%"88" = call i32 @__zluda_ptx_impl_shfl_sync_bfly_b32(i32 %"66", i32 3, i32 6175, i32 -1)
store i32 %"88", ptr addrspace(5) %"52", align 4
%"68" = load i32, ptr addrspace(5) %"53", align 4
%"69" = load i32, ptr addrspace(5) %"52", align 4
%"67" = add i32 %"68", %"69"
store i32 %"67", ptr addrspace(5) %"53", align 4
%"71" = load i32, ptr addrspace(5) %"51", align 4
%"90" = call i32 @__zluda_ptx_impl_shfl_sync_idx_b32(i32 %"71", i32 3, i32 4127, i32 -1)
store i32 %"90", ptr addrspace(5) %"52", align 4
%"73" = load i32, ptr addrspace(5) %"53", align 4
%"74" = load i32, ptr addrspace(5) %"52", align 4
%"72" = add i32 %"73", %"74"
store i32 %"72", ptr addrspace(5) %"53", align 4
%"76" = load i32, ptr addrspace(5) %"51", align 4
%"75" = zext i32 %"76" to i64
store i64 %"75", ptr addrspace(5) %"50", align 4
%"78" = load i64, ptr addrspace(5) %"50", align 4
%"77" = mul i64 %"78", 4
store i64 %"77", ptr addrspace(5) %"50", align 4
%"80" = load i64, ptr addrspace(5) %"49", align 4
%"81" = load i64, ptr addrspace(5) %"50", align 4
%"79" = add i64 %"80", %"81"
store i64 %"79", ptr addrspace(5) %"49", align 4
%"82" = load i64, ptr addrspace(5) %"49", align 4
%"83" = load i32, ptr addrspace(5) %"53", align 4
%"92" = inttoptr i64 %"82" to ptr
store i32 %"83", ptr %"92", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -0,0 +1,59 @@
declare [2 x i32] @__zluda_ptx_impl_shfl_sync_up_b32_pred(i32, i32, i32, i32) #0
declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @shfl_sync_up_b32_pred(ptr addrspace(4) byref(i64) %"42") #1 {
%"43" = alloca i64, align 8, addrspace(5)
%"44" = alloca i64, align 8, addrspace(5)
%"45" = alloca i32, align 4, addrspace(5)
%"46" = alloca i32, align 4, addrspace(5)
%"47" = alloca i1, align 1, addrspace(5)
br label %1
1: ; preds = %0
br label %"39"
"39": ; preds = %1
%"48" = load i64, ptr addrspace(4) %"42", align 4
store i64 %"48", ptr addrspace(5) %"43", align 4
%"33" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"40"
"40": ; preds = %"39"
store i32 %"33", ptr addrspace(5) %"45", align 4
%"52" = load i32, ptr addrspace(5) %"45", align 4
%2 = call [2 x i32] @__zluda_ptx_impl_shfl_sync_up_b32_pred(i32 %"52", i32 3, i32 0, i32 -1)
%"65" = extractvalue [2 x i32] %2, 0
%3 = extractvalue [2 x i32] %2, 1
%"51" = trunc i32 %3 to i1
store i32 %"65", ptr addrspace(5) %"46", align 4
store i1 %"51", ptr addrspace(5) %"47", align 1
%"53" = load i1, ptr addrspace(5) %"47", align 1
br i1 %"53", label %"15", label %"14"
"14": ; preds = %"40"
%"55" = load i32, ptr addrspace(5) %"46", align 4
%"54" = add i32 %"55", 1000
store i32 %"54", ptr addrspace(5) %"46", align 4
br label %"15"
"15": ; preds = %"14", %"40"
%"57" = load i32, ptr addrspace(5) %"45", align 4
%"56" = zext i32 %"57" to i64
store i64 %"56", ptr addrspace(5) %"44", align 4
%"59" = load i64, ptr addrspace(5) %"44", align 4
%"58" = mul i64 %"59", 4
store i64 %"58", ptr addrspace(5) %"44", align 4
%"61" = load i64, ptr addrspace(5) %"43", align 4
%"62" = load i64, ptr addrspace(5) %"44", align 4
%"60" = add i64 %"61", %"62"
store i64 %"60", ptr addrspace(5) %"43", align 4
%"63" = load i64, ptr addrspace(5) %"43", align 4
%"64" = load i32, ptr addrspace(5) %"46", align 4
%"67" = inttoptr i64 %"63" to ptr
store i32 %"64", ptr %"67", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -315,6 +315,36 @@ test_ptx_warp!(bar_red_and_pred, [
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32,
]); ]);
test_ptx_warp!(shfl_sync_up_b32_pred, [
1000u32, 1001u32, 1002u32, 0u32, 1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32,
13u32, 14u32, 15u32, 16u32, 17u32, 18u32, 19u32, 20u32, 21u32, 22u32, 23u32, 24u32, 25u32, 26u32, 27u32, 28u32,
1032u32, 1033u32, 1034u32, 32u32, 33u32, 34u32, 35u32, 36u32, 37u32, 38u32, 39u32, 40u32, 41u32, 42u32, 43u32, 44u32,
45u32, 46u32, 47u32, 48u32, 49u32, 50u32, 51u32, 52u32, 53u32, 54u32, 55u32, 56u32, 57u32, 58u32, 59u32, 60u32,
]);
test_ptx_warp!(shfl_sync_down_b32_pred, [
3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32, 13u32, 14u32, 15u32, 16u32, 17u32, 18u32,
19u32, 20u32, 21u32, 22u32, 23u32, 24u32, 25u32, 26u32, 27u32, 28u32, 29u32, 30u32, 31u32, 1029u32, 1030u32, 1031u32,
35u32, 36u32, 37u32, 38u32, 39u32, 40u32, 41u32, 42u32, 43u32, 44u32, 45u32, 46u32, 47u32, 48u32, 49u32, 50u32,
51u32, 52u32, 53u32, 54u32, 55u32, 56u32, 57u32, 58u32, 59u32, 60u32, 61u32, 62u32, 63u32, 1061u32, 1062u32, 1063u32,
]);
test_ptx_warp!(shfl_sync_bfly_b32_pred, [
3u32, 2u32, 1u32, 0u32, 7u32, 6u32, 5u32, 4u32, 11u32, 10u32, 9u32, 8u32, 15u32, 14u32, 13u32, 12u32,
19u32, 18u32, 17u32, 16u32, 23u32, 22u32, 21u32, 20u32, 27u32, 26u32, 25u32, 24u32, 31u32, 30u32, 29u32, 28u32,
35u32, 34u32, 33u32, 32u32, 39u32, 38u32, 37u32, 36u32, 43u32, 42u32, 41u32, 40u32, 47u32, 46u32, 45u32, 44u32,
51u32, 50u32, 49u32, 48u32, 55u32, 54u32, 53u32, 52u32, 59u32, 58u32, 57u32, 56u32, 63u32, 62u32, 61u32, 60u32,
]);
test_ptx_warp!(shfl_sync_idx_b32_pred, [
12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32,
12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32,
44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32,
44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32,
]);
test_ptx_warp!(shfl_sync_mode_b32, [
9u32, 7u32, 8u32, 9u32, 21u32, 19u32, 20u32, 21u32, 33u32, 31u32, 32u32, 33u32, 45u32, 43u32, 44u32, 45u32,
73u32, 71u32, 72u32, 73u32, 85u32, 83u32, 84u32, 85u32, 97u32, 95u32, 96u32, 97u32, 109u32, 107u32, 108u32, 109u32,
137u32, 135u32, 136u32, 137u32, 149u32, 147u32, 148u32, 149u32, 161u32, 159u32, 160u32, 161u32, 173u32, 171u32, 172u32, 173u32,
201u32, 199u32, 200u32, 201u32, 213u32, 211u32, 212u32, 213u32, 225u32, 223u32, 224u32, 225u32, 237u32, 235u32, 236u32, 237u32,
]);
struct DisplayError<T: Debug> { struct DisplayError<T: Debug> {
err: T, err: T,

View File

@ -0,0 +1,32 @@
.version 6.5
.target sm_30
.address_size 64
// Every thread shuffles its own thread ID in bfly mode and stores the value it reads at
// output[thread_id]; threads whose source lane was out of range get 1000 added.
.visible .entry shfl_sync_bfly_b32_pred(
.param .u64 output
)
{
.reg .u64 out_addr;
.reg .u64 out_index;
.reg .u32 thread_id;
.reg .u32 result;
.reg .pred in_range;
ld.param.u64 out_addr, [output];
mov.u32 thread_id, %tid.x;
// result = __shfl_xor_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=32)
// c is ((32-width) << 8) | 31
shfl.sync.bfly.b32 result|in_range, thread_id, 3, 31, 0xFFFFFFFF;
// Make out-of-range lanes visible in the output by offsetting their value by 1000
@!in_range add.u32 result, result, 1000;
// Return result: store it at output + thread_id * 4
cvt.u64.u32 out_index, thread_id;
mul.lo.u64 out_index, out_index, 4;
add.u64 out_addr, out_addr, out_index;
st.u32 [out_addr], result;
ret;
}

View File

@ -0,0 +1,32 @@
.version 6.5
.target sm_30
.address_size 64
// Every thread shuffles its own thread ID in down mode and stores the value it reads at
// output[thread_id]; threads whose source lane was out of range get 1000 added.
.visible .entry shfl_sync_down_b32_pred(
.param .u64 output
)
{
.reg .u64 out_addr;
.reg .u64 out_index;
.reg .u32 thread_id;
.reg .u32 result;
.reg .pred in_range;
ld.param.u64 out_addr, [output];
mov.u32 thread_id, %tid.x;
// result = __shfl_down_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=32)
// c is ((32-width) << 8) | 31
shfl.sync.down.b32 result|in_range, thread_id, 3, 31, 0xFFFFFFFF;
// Make out-of-range lanes visible in the output by offsetting their value by 1000
@!in_range add.u32 result, result, 1000;
// Return result: store it at output + thread_id * 4
cvt.u64.u32 out_index, thread_id;
mul.lo.u64 out_index, out_index, 4;
add.u64 out_addr, out_addr, out_index;
st.u32 [out_addr], result;
ret;
}

View File

@ -0,0 +1,32 @@
.version 6.5
.target sm_30
.address_size 64
// Every thread shuffles its own thread ID in idx mode and stores the value it reads at
// output[thread_id]; threads whose source lane was out of range get 1000 added.
.visible .entry shfl_sync_idx_b32_pred(
.param .u64 output
)
{
.reg .u64 out_addr;
.reg .u64 out_index;
.reg .u32 thread_id;
.reg .u32 result;
.reg .pred in_range;
ld.param.u64 out_addr, [output];
mov.u32 thread_id, %tid.x;
// result = __shfl_sync(mask=0xFFFFFFFF, thread_id, srcLane=12, width=32)
// c is ((32-width) << 8) | 31
shfl.sync.idx.b32 result|in_range, thread_id, 12, 31, 0xFFFFFFFF;
// Make out-of-range lanes visible in the output by offsetting their value by 1000
@!in_range add.u32 result, result, 1000;
// Return result: store it at output + thread_id * 4
cvt.u64.u32 out_index, thread_id;
mul.lo.u64 out_index, out_index, 4;
add.u64 out_addr, out_addr, out_index;
st.u32 [out_addr], result;
ret;
}

View File

@ -0,0 +1,46 @@
.version 6.5
.target sm_30
.address_size 64
// Runs one shuffle per mode (up, down, bfly, idx), each with a different width, and stores the
// sum of the four shuffled values at output[thread_id].
.visible .entry shfl_sync_mode_b32(
.param .u64 output
)
{
.reg .u64 out_addr;
.reg .u64 out_index;
.reg .u32 thread_id;
.reg .u32 temp;
.reg .u32 result;
ld.param.u64 out_addr, [output];
mov.u32 thread_id, %tid.x;
// result = __shfl_up_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=2)
// c is ((32-width) << 8)
shfl.sync.up.b32 temp, thread_id, 3, 7680, 0xFFFFFFFF;
mov.u32 result, temp;
// result += __shfl_down_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=4)
// c is ((32-width) << 8) | 31
shfl.sync.down.b32 temp, thread_id, 3, 7199, 0xFFFFFFFF;
add.u32 result, result, temp;
// result += __shfl_xor_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=8)
// c is ((32-width) << 8) | 31
shfl.sync.bfly.b32 temp, thread_id, 3, 6175, 0xFFFFFFFF;
add.u32 result, result, temp;
// result += __shfl_sync(mask=0xFFFFFFFF, thread_id, srcLane=3, width=16)
// c is ((32-width) << 8) | 31
shfl.sync.idx.b32 temp, thread_id, 3, 4127, 0xFFFFFFFF;
add.u32 result, result, temp;
// Return result: store it at output + thread_id * 4
cvt.u64.u32 out_index, thread_id;
mul.lo.u64 out_index, out_index, 4;
add.u64 out_addr, out_addr, out_index;
st.u32 [out_addr], result;
ret;
}

View File

@ -0,0 +1,32 @@
.version 6.5
.target sm_30
.address_size 64
// Every thread shuffles its own thread ID in up mode and stores the value it reads at
// output[thread_id]; threads whose source lane was out of range get 1000 added.
.visible .entry shfl_sync_up_b32_pred(
.param .u64 output
)
{
.reg .u64 out_addr;
.reg .u64 out_index;
.reg .u32 thread_id;
.reg .u32 result;
.reg .pred in_range;
ld.param.u64 out_addr, [output];
mov.u32 thread_id, %tid.x;
// result = __shfl_up_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=32)
// c is ((32-width) << 8)
shfl.sync.up.b32 result|in_range, thread_id, 3, 0, 0xFFFFFFFF;
// Make out-of-range lanes visible in the output by offsetting their value by 1000
@!in_range add.u32 result, result, 1000;
// Return result: store it at output + thread_id * 4
cvt.u64.u32 out_index, thread_id;
mul.lo.u64 out_index, out_index, 4;
add.u64 out_addr, out_addr, out_index;
st.u32 [out_addr], result;
ret;
}

View File

@ -2,7 +2,7 @@ use super::{
AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp,
StateSpace, VectorPrefix, StateSpace, VectorPrefix,
}; };
use crate::{Mul24Control, Reduction, PtxError, PtxParserState}; use crate::{Mul24Control, Reduction, PtxError, PtxParserState, ShuffleMode};
use bitflags::bitflags; use bitflags::bitflags;
use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8}; use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8};
@ -468,6 +468,21 @@ ptx_parser_macros::generate_instruction_type!(
} }
} }
}, },
// shfl.sync.mode.b32 d[|p], a, b, c, membermask
ShflSync {
data: ShflSyncDetails,
type: Type::Scalar(ScalarType::B32),
arguments<T>: {
dst: T,
// Optional predicate destination; typed as a predicate instead of the instruction type
dst_pred: {
repr: Option<T>,
type: Type::from(ScalarType::Pred)
},
src: T,
src_lane: T,
src_opts: T,
src_membermask: T
}
},
Shl { Shl {
data: ScalarType, data: ScalarType,
type: { Type::Scalar(data.clone()) }, type: { Type::Scalar(data.clone()) },
@ -979,6 +994,11 @@ impl MovDetails {
} }
} }
/// Modifier data carried by a `shfl.sync` instruction.
#[derive(Copy, Clone)]
pub struct ShflSyncDetails {
/// Shuffle mode taken from the instruction's `.mode` modifier.
pub mode: ShuffleMode,
}
#[derive(Clone)] #[derive(Clone)]
pub enum ParsedOperand<Ident> { pub enum ParsedOperand<Ident> {
Reg(Ident), Reg(Ident),

View File

@ -1589,7 +1589,7 @@ where
// * Opcode: `ld` // * Opcode: `ld`
// * Modifiers, always start with a dot: `.global`, `.relaxed`. Optionals are enclosed in braces // * Modifiers, always start with a dot: `.global`, `.relaxed`. Optionals are enclosed in braces
// * Arguments: `a`, `b`. Optionals are enclosed in braces // * Arguments: `a`, `b`. Optionals are enclosed in braces
// * Code block: => { <code expression> }. Code blocks implictly take all modifiers ansd arguments // * Code block: => { <code expression> }. Code blocks implictly take all modifiers and arguments
// as parameters. All modifiers and arguments are passed to the code block: // as parameters. All modifiers and arguments are passed to the code block:
// * If it is an alternative (as defined in rules list later): // * If it is an alternative (as defined in rules list later):
// * If it is mandatory then its type is Foo (as defined by the relevant rule) // * If it is mandatory then its type is Foo (as defined by the relevant rule)
@ -1723,6 +1723,9 @@ derive_parser!(
#[derive(Copy, Clone, PartialEq, Eq, Hash)] #[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum Reduction { } pub enum Reduction { }
// Shuffle mode of `shfl.sync`.
// NOTE(review): declared empty here; the variants are presumably generated by derive_parser!
// from the `.mode: ShuffleMode = { ... }` rule -- confirm.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum ShuffleMode { }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
mov{.vec}.type d, a => { mov{.vec}.type d, a => {
Instruction::Mov { Instruction::Mov {
@ -3487,6 +3490,14 @@ derive_parser!(
.mode: Mul24Control = { .hi, .lo }; .mode: Mul24Control = { .hi, .lo };
.type: ScalarType = { .u32, .s32 }; .type: ScalarType = { .u32, .s32 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync
// Operands: d = destination, p = optional predicate destination, a = source value,
// b = source lane or lane offset, c = packed clamp/segment-mask options,
// membermask = member mask operand.
shfl.sync.mode.b32 d[|p], a, b, c, membermask => {
Instruction::ShflSync {
data: ast::ShflSyncDetails { mode },
arguments: ShflSyncArgs { dst: d, dst_pred: p, src: a, src_lane: b, src_opts: c, src_membermask: membermask }
}
}
.mode: ShuffleMode = { .up, .down, .bfly, .idx };
); );
#[cfg(test)] #[cfg(test)]