diff --git a/.gitignore b/.gitignore
index 76550e8..6c86c12 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,6 @@
 target/
 Cargo.lock
 .vscode/
-.idea/
\ No newline at end of file
+.idea/
+
+ptx/lib/zluda_ptx_impl.ll
diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc
index 84d2f2d..6bf56ca 100644
Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ
diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp
index 638ef1e..75e88f1 100644
--- a/ptx/lib/zluda_ptx_impl.cpp
+++ b/ptx/lib/zluda_ptx_impl.cpp
@@ -4,6 +4,7 @@
 #include 
 #include 
+#include 
 
 #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
@@ -170,6 +171,55 @@ extern "C"
     BAR_RED_IMPL(and);
     BAR_RED_IMPL(or);
 
+    struct ShflSyncResult {
+        uint32_t output;
+        bool in_bounds;
+    };
+
+    // The shfl.sync opts operand packs two values: the warp end ID and the subsection mask.
+    //
+    // The current warp is partitioned into some number of subsections, each w lanes wide. The
+    // subsection mask is 32 - w and indicates which bits of the lane ID form the subsection
+    // address. For example, if each subsection is 8 lanes wide, the subsection mask is 24
+    // (0b11000), meaning the two most significant bits of the 5-bit lane ID are the subsection
+    // address; for lane ID 13 (0b01101) the subsection therefore starts at lane 0b01000 (8).
+    //
+    // The warp end ID is the maximum lane ID for a given mode. For the CUDA __shfl_sync
+    // intrinsics it is always 31 for idx, bfly and down, and 0 for up. It is used for the
+    // bounds check.
+
+    #define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \
+        ShflSyncResult FUNC(shfl_sync_##mode##_b32_pred)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask __attribute__((unused))) \
+        { \
+            int32_t section_mask = (opts >> 8) & 0b11111; \
+            int32_t warp_end = opts & 0b11111; \
+            int32_t self = (int32_t)__lane_id(); \
+            int32_t subsection = section_mask & self; \
+            int32_t subsection_end = subsection | (~section_mask & warp_end); \
+            int32_t idx = calculate_index; \
+            bool out_of_bounds = idx CMP subsection_end; \
+            if (out_of_bounds) { \
+                idx = self; \
+            } \
+            int32_t output = __builtin_amdgcn_ds_bpermute(idx << 2, (int32_t)input); \
+            return {(uint32_t)output, !out_of_bounds}; \
+        } \
+        \
+        uint32_t FUNC(shfl_sync_##mode##_b32)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask) \
+        { \
+            return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).output; \
+        }
+
+    // These are implemented following the HIP __shfl intrinsics rather than the __shfl_sync
+    // intrinsics, since the latter only add an assertion checking that the membermask is used
+    // correctly. Neither returns the result of the range check, so that logic is replicated here.
+
+    SHFL_SYNC_IMPL(up, self - delta, <);
+    SHFL_SYNC_IMPL(down, self + delta, >);
+    SHFL_SYNC_IMPL(bfly, self ^ delta, >);
+    SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >);
+
     void FUNC(__assertfail)(uint64_t message,
                             uint64_t file,
                             uint32_t line,
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index 57a7d0f..b888202 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -641,7 +641,8 @@ impl<'a> MethodEmitContext<'a> {
             | ast::Instruction::Bar { .. }
             | ast::Instruction::BarRed { .. }
             | ast::Instruction::Bfi { .. }
-            | ast::Instruction::Activemask { .. } => return Err(error_unreachable()),
+            | ast::Instruction::Activemask { .. }
+            | ast::Instruction::ShflSync { .. } => return Err(error_unreachable()),
         }
     }
diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs
index 92be749..4ad5339 100644
--- a/ptx/src/pass/insert_post_saturation.rs
+++ b/ptx/src/pass/insert_post_saturation.rs
@@ -165,6 +165,7 @@ fn run_instruction<'input>(
         | ast::Instruction::Selp { .. }
         | ast::Instruction::Setp { .. }
         | ast::Instruction::SetpBool { .. }
+        | ast::Instruction::ShflSync { .. }
         | ast::Instruction::Shl { .. }
         | ast::Instruction::Shr { .. }
         | ast::Instruction::Sin { .. }
diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs
index d1c5489..91bf1a8 100644
--- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs
+++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs
@@ -1800,6 +1800,7 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes {
         | ast::Instruction::Bfe { .. }
         | ast::Instruction::Bfi { .. }
         | ast::Instruction::Shr { .. }
+        | ast::Instruction::ShflSync { .. }
         | ast::Instruction::Shl { .. }
         | ast::Instruction::Selp { .. }
         | ast::Instruction::Ret { .. }
diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs
index db6b473..0480e5f 100644
--- a/ptx/src/pass/replace_instructions_with_function_calls.rs
+++ b/ptx/src/pass/replace_instructions_with_function_calls.rs
@@ -118,6 +118,25 @@ fn run_instruction<'input>(
             };
             to_call(resolver, fn_declarations, name.into(), ptx_parser::Instruction::BarRed { data, arguments })?
         }
+        ptx_parser::Instruction::ShflSync { data, arguments } => {
+            let mode = match data.mode {
+                ptx_parser::ShuffleMode::Up => "up",
+                ptx_parser::ShuffleMode::Down => "down",
+                ptx_parser::ShuffleMode::BFly => "bfly",
+                ptx_parser::ShuffleMode::Idx => "idx",
+            };
+            let pred = if arguments.dst_pred.is_some() {
+                "_pred"
+            } else {
+                ""
+            };
+            to_call(
+                resolver,
+                fn_declarations,
+                format!("shfl_sync_{}_b32{}", mode, pred).into(),
+                ptx_parser::Instruction::ShflSync { data, arguments },
+            )?
+ } i => i, }) } diff --git a/ptx/src/test/ll/shfl_sync_bfly_b32_pred.ll b/ptx/src/test/ll/shfl_sync_bfly_b32_pred.ll new file mode 100644 index 0000000..da51305 --- /dev/null +++ b/ptx/src/test/ll/shfl_sync_bfly_b32_pred.ll @@ -0,0 +1,59 @@ +declare [2 x i32] @__zluda_ptx_impl_shfl_sync_bfly_b32_pred(i32, i32, i32, i32) #0 + +declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @shfl_sync_bfly_b32_pred(ptr addrspace(4) byref(i64) %"42") #1 { + %"43" = alloca i64, align 8, addrspace(5) + %"44" = alloca i64, align 8, addrspace(5) + %"45" = alloca i32, align 4, addrspace(5) + %"46" = alloca i32, align 4, addrspace(5) + %"47" = alloca i1, align 1, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"39" + +"39": ; preds = %1 + %"48" = load i64, ptr addrspace(4) %"42", align 4 + store i64 %"48", ptr addrspace(5) %"43", align 4 + %"33" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"40" + +"40": ; preds = %"39" + store i32 %"33", ptr addrspace(5) %"45", align 4 + %"52" = load i32, ptr addrspace(5) %"45", align 4 + %2 = call [2 x i32] @__zluda_ptx_impl_shfl_sync_bfly_b32_pred(i32 %"52", i32 3, i32 31, i32 -1) + %"65" = extractvalue [2 x i32] %2, 0 + %3 = extractvalue [2 x i32] %2, 1 + %"51" = trunc i32 %3 to i1 + store i32 %"65", ptr addrspace(5) %"46", align 4 + store i1 %"51", ptr addrspace(5) %"47", align 1 + %"53" = load i1, ptr addrspace(5) %"47", align 1 + br i1 %"53", label %"15", label %"14" + +"14": ; preds = %"40" + %"55" = load i32, ptr addrspace(5) %"46", align 4 + %"54" = add i32 %"55", 1000 + store i32 %"54", ptr addrspace(5) %"46", align 4 + br label %"15" + +"15": ; preds = %"14", %"40" + %"57" = load i32, ptr addrspace(5) %"45", align 4 + %"56" = zext i32 %"57" to i64 + store i64 %"56", ptr addrspace(5) %"44", align 4 + %"59" = load i64, ptr addrspace(5) %"44", align 4 + %"58" = mul i64 %"59", 4 + store i64 %"58", ptr addrspace(5) %"44", align 4 + %"61" = load i64, ptr addrspace(5) %"43", align 4 + %"62" = load i64, ptr addrspace(5) %"44", align 4 + %"60" = add i64 %"61", %"62" + store i64 %"60", ptr addrspace(5) %"43", align 4 + %"63" = load i64, ptr addrspace(5) %"43", align 4 + %"64" = load i32, ptr addrspace(5) %"46", align 4 + %"67" = inttoptr i64 %"63" to ptr + store i32 %"64", ptr %"67", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/ll/shfl_sync_down_b32_pred.ll b/ptx/src/test/ll/shfl_sync_down_b32_pred.ll new file mode 100644 index 0000000..2f9edef --- /dev/null +++ b/ptx/src/test/ll/shfl_sync_down_b32_pred.ll @@ -0,0 +1,59 @@ +declare [2 x i32] @__zluda_ptx_impl_shfl_sync_down_b32_pred(i32, i32, i32, i32) #0 + +declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @shfl_sync_down_b32_pred(ptr addrspace(4) byref(i64) %"42") #1 { + %"43" = alloca i64, align 8, addrspace(5) + %"44" = alloca i64, align 8, addrspace(5) + %"45" = alloca i32, align 4, addrspace(5) + %"46" = alloca i32, align 4, addrspace(5) + %"47" = alloca i1, align 1, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"39" + +"39": ; preds = %1 + %"48" = load i64, ptr addrspace(4) %"42", align 4 + store i64 %"48", ptr addrspace(5) %"43", align 4 + %"33" = 
call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"40" + +"40": ; preds = %"39" + store i32 %"33", ptr addrspace(5) %"45", align 4 + %"52" = load i32, ptr addrspace(5) %"45", align 4 + %2 = call [2 x i32] @__zluda_ptx_impl_shfl_sync_down_b32_pred(i32 %"52", i32 3, i32 31, i32 -1) + %"65" = extractvalue [2 x i32] %2, 0 + %3 = extractvalue [2 x i32] %2, 1 + %"51" = trunc i32 %3 to i1 + store i32 %"65", ptr addrspace(5) %"46", align 4 + store i1 %"51", ptr addrspace(5) %"47", align 1 + %"53" = load i1, ptr addrspace(5) %"47", align 1 + br i1 %"53", label %"15", label %"14" + +"14": ; preds = %"40" + %"55" = load i32, ptr addrspace(5) %"46", align 4 + %"54" = add i32 %"55", 1000 + store i32 %"54", ptr addrspace(5) %"46", align 4 + br label %"15" + +"15": ; preds = %"14", %"40" + %"57" = load i32, ptr addrspace(5) %"45", align 4 + %"56" = zext i32 %"57" to i64 + store i64 %"56", ptr addrspace(5) %"44", align 4 + %"59" = load i64, ptr addrspace(5) %"44", align 4 + %"58" = mul i64 %"59", 4 + store i64 %"58", ptr addrspace(5) %"44", align 4 + %"61" = load i64, ptr addrspace(5) %"43", align 4 + %"62" = load i64, ptr addrspace(5) %"44", align 4 + %"60" = add i64 %"61", %"62" + store i64 %"60", ptr addrspace(5) %"43", align 4 + %"63" = load i64, ptr addrspace(5) %"43", align 4 + %"64" = load i32, ptr addrspace(5) %"46", align 4 + %"67" = inttoptr i64 %"63" to ptr + store i32 %"64", ptr %"67", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/ll/shfl_sync_idx_b32_pred.ll b/ptx/src/test/ll/shfl_sync_idx_b32_pred.ll new file mode 100644 index 0000000..e7ac9c6 --- /dev/null +++ b/ptx/src/test/ll/shfl_sync_idx_b32_pred.ll @@ -0,0 +1,59 @@ +declare [2 x i32] @__zluda_ptx_impl_shfl_sync_idx_b32_pred(i32, i32, i32, i32) #0 + +declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @shfl_sync_idx_b32_pred(ptr addrspace(4) byref(i64) %"42") #1 { + %"43" = alloca i64, align 8, addrspace(5) + %"44" = alloca i64, align 8, addrspace(5) + %"45" = alloca i32, align 4, addrspace(5) + %"46" = alloca i32, align 4, addrspace(5) + %"47" = alloca i1, align 1, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"39" + +"39": ; preds = %1 + %"48" = load i64, ptr addrspace(4) %"42", align 4 + store i64 %"48", ptr addrspace(5) %"43", align 4 + %"33" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"40" + +"40": ; preds = %"39" + store i32 %"33", ptr addrspace(5) %"45", align 4 + %"52" = load i32, ptr addrspace(5) %"45", align 4 + %2 = call [2 x i32] @__zluda_ptx_impl_shfl_sync_idx_b32_pred(i32 %"52", i32 12, i32 31, i32 -1) + %"65" = extractvalue [2 x i32] %2, 0 + %3 = extractvalue [2 x i32] %2, 1 + %"51" = trunc i32 %3 to i1 + store i32 %"65", ptr addrspace(5) %"46", align 4 + store i1 %"51", ptr addrspace(5) %"47", align 1 + %"53" = load i1, ptr addrspace(5) %"47", align 1 + br i1 %"53", label %"15", label %"14" + +"14": ; preds = %"40" + %"55" = load i32, ptr addrspace(5) %"46", align 4 + %"54" = add i32 %"55", 1000 + store i32 %"54", ptr addrspace(5) %"46", align 4 + br label %"15" + +"15": ; preds = %"14", %"40" + %"57" = load i32, ptr addrspace(5) %"45", align 4 + %"56" = zext i32 %"57" to i64 + 
store i64 %"56", ptr addrspace(5) %"44", align 4 + %"59" = load i64, ptr addrspace(5) %"44", align 4 + %"58" = mul i64 %"59", 4 + store i64 %"58", ptr addrspace(5) %"44", align 4 + %"61" = load i64, ptr addrspace(5) %"43", align 4 + %"62" = load i64, ptr addrspace(5) %"44", align 4 + %"60" = add i64 %"61", %"62" + store i64 %"60", ptr addrspace(5) %"43", align 4 + %"63" = load i64, ptr addrspace(5) %"43", align 4 + %"64" = load i32, ptr addrspace(5) %"46", align 4 + %"67" = inttoptr i64 %"63" to ptr + store i32 %"64", ptr %"67", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/ll/shfl_sync_mode_b32.ll b/ptx/src/test/ll/shfl_sync_mode_b32.ll new file mode 100644 index 0000000..a65ad1e --- /dev/null +++ b/ptx/src/test/ll/shfl_sync_mode_b32.ll @@ -0,0 +1,74 @@ +declare i32 @__zluda_ptx_impl_shfl_sync_down_b32(i32, i32, i32, i32) #0 + +declare i32 @__zluda_ptx_impl_shfl_sync_up_b32(i32, i32, i32, i32) #0 + +declare i32 @__zluda_ptx_impl_shfl_sync_bfly_b32(i32, i32, i32, i32) #0 + +declare i32 @__zluda_ptx_impl_shfl_sync_idx_b32(i32, i32, i32, i32) #0 + +declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @shfl_sync_mode_b32(ptr addrspace(4) byref(i64) %"48") #1 { + %"49" = alloca i64, align 8, addrspace(5) + %"50" = alloca i64, align 8, addrspace(5) + %"51" = alloca i32, align 4, addrspace(5) + %"52" = alloca i32, align 4, addrspace(5) + %"53" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"45" + +"45": ; preds = %1 + %"54" = load i64, ptr addrspace(4) %"48", align 4 + store i64 %"54", ptr addrspace(5) %"49", align 4 + %"31" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"46" + +"46": ; preds = %"45" + store i32 %"31", ptr addrspace(5) %"51", align 4 + %"57" = load i32, ptr addrspace(5) %"51", align 4 + %"84" = call i32 @__zluda_ptx_impl_shfl_sync_up_b32(i32 %"57", i32 3, i32 7680, i32 -1) + store i32 %"84", ptr addrspace(5) %"52", align 4 + %"59" = load i32, ptr addrspace(5) %"52", align 4 + store i32 %"59", ptr addrspace(5) %"53", align 4 + %"61" = load i32, ptr addrspace(5) %"51", align 4 + %"86" = call i32 @__zluda_ptx_impl_shfl_sync_down_b32(i32 %"61", i32 3, i32 7199, i32 -1) + store i32 %"86", ptr addrspace(5) %"52", align 4 + %"63" = load i32, ptr addrspace(5) %"53", align 4 + %"64" = load i32, ptr addrspace(5) %"52", align 4 + %"62" = add i32 %"63", %"64" + store i32 %"62", ptr addrspace(5) %"53", align 4 + %"66" = load i32, ptr addrspace(5) %"51", align 4 + %"88" = call i32 @__zluda_ptx_impl_shfl_sync_bfly_b32(i32 %"66", i32 3, i32 6175, i32 -1) + store i32 %"88", ptr addrspace(5) %"52", align 4 + %"68" = load i32, ptr addrspace(5) %"53", align 4 + %"69" = load i32, ptr addrspace(5) %"52", align 4 + %"67" = add i32 %"68", %"69" + store i32 %"67", ptr addrspace(5) %"53", align 4 + %"71" = load i32, ptr addrspace(5) %"51", align 4 + %"90" = call i32 @__zluda_ptx_impl_shfl_sync_idx_b32(i32 %"71", i32 3, i32 4127, i32 -1) + store i32 %"90", ptr addrspace(5) %"52", align 4 + %"73" = load i32, ptr addrspace(5) %"53", align 4 + %"74" = load i32, ptr addrspace(5) %"52", align 4 + %"72" = add i32 %"73", %"74" + store i32 %"72", ptr 
addrspace(5) %"53", align 4 + %"76" = load i32, ptr addrspace(5) %"51", align 4 + %"75" = zext i32 %"76" to i64 + store i64 %"75", ptr addrspace(5) %"50", align 4 + %"78" = load i64, ptr addrspace(5) %"50", align 4 + %"77" = mul i64 %"78", 4 + store i64 %"77", ptr addrspace(5) %"50", align 4 + %"80" = load i64, ptr addrspace(5) %"49", align 4 + %"81" = load i64, ptr addrspace(5) %"50", align 4 + %"79" = add i64 %"80", %"81" + store i64 %"79", ptr addrspace(5) %"49", align 4 + %"82" = load i64, ptr addrspace(5) %"49", align 4 + %"83" = load i32, ptr addrspace(5) %"53", align 4 + %"92" = inttoptr i64 %"82" to ptr + store i32 %"83", ptr %"92", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/ll/shfl_sync_up_b32_pred.ll b/ptx/src/test/ll/shfl_sync_up_b32_pred.ll new file mode 100644 index 0000000..399b03a --- /dev/null +++ b/ptx/src/test/ll/shfl_sync_up_b32_pred.ll @@ -0,0 +1,59 @@ +declare [2 x i32] @__zluda_ptx_impl_shfl_sync_up_b32_pred(i32, i32, i32, i32) #0 + +declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @shfl_sync_up_b32_pred(ptr addrspace(4) byref(i64) %"42") #1 { + %"43" = alloca i64, align 8, addrspace(5) + %"44" = alloca i64, align 8, addrspace(5) + %"45" = alloca i32, align 4, addrspace(5) + %"46" = alloca i32, align 4, addrspace(5) + %"47" = alloca i1, align 1, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"39" + +"39": ; preds = %1 + %"48" = load i64, ptr addrspace(4) %"42", align 4 + store i64 %"48", ptr addrspace(5) %"43", align 4 + %"33" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"40" + +"40": ; preds = %"39" + store i32 %"33", ptr addrspace(5) %"45", align 4 + %"52" = load i32, ptr addrspace(5) %"45", align 4 + %2 = call [2 x i32] @__zluda_ptx_impl_shfl_sync_up_b32_pred(i32 %"52", i32 3, i32 0, i32 -1) + %"65" = extractvalue [2 x i32] %2, 0 + %3 = extractvalue [2 x i32] %2, 1 + %"51" = trunc i32 %3 to i1 + store i32 %"65", ptr addrspace(5) %"46", align 4 + store i1 %"51", ptr addrspace(5) %"47", align 1 + %"53" = load i1, ptr addrspace(5) %"47", align 1 + br i1 %"53", label %"15", label %"14" + +"14": ; preds = %"40" + %"55" = load i32, ptr addrspace(5) %"46", align 4 + %"54" = add i32 %"55", 1000 + store i32 %"54", ptr addrspace(5) %"46", align 4 + br label %"15" + +"15": ; preds = %"14", %"40" + %"57" = load i32, ptr addrspace(5) %"45", align 4 + %"56" = zext i32 %"57" to i64 + store i64 %"56", ptr addrspace(5) %"44", align 4 + %"59" = load i64, ptr addrspace(5) %"44", align 4 + %"58" = mul i64 %"59", 4 + store i64 %"58", ptr addrspace(5) %"44", align 4 + %"61" = load i64, ptr addrspace(5) %"43", align 4 + %"62" = load i64, ptr addrspace(5) %"44", align 4 + %"60" = add i64 %"61", %"62" + store i64 %"60", ptr addrspace(5) %"43", align 4 + %"63" = load i64, ptr addrspace(5) %"43", align 4 + %"64" = load i32, ptr addrspace(5) %"46", align 4 + %"67" = inttoptr i64 %"63" to ptr + store i32 %"64", ptr %"67", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { 
"amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 11820e7..d861593 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -315,6 +315,36 @@ test_ptx_warp!(bar_red_and_pred, [ 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, ]); +test_ptx_warp!(shfl_sync_up_b32_pred, [ + 1000u32, 1001u32, 1002u32, 0u32, 1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32, + 13u32, 14u32, 15u32, 16u32, 17u32, 18u32, 19u32, 20u32, 21u32, 22u32, 23u32, 24u32, 25u32, 26u32, 27u32, 28u32, + 1032u32, 1033u32, 1034u32, 32u32, 33u32, 34u32, 35u32, 36u32, 37u32, 38u32, 39u32, 40u32, 41u32, 42u32, 43u32, 44u32, + 45u32, 46u32, 47u32, 48u32, 49u32, 50u32, 51u32, 52u32, 53u32, 54u32, 55u32, 56u32, 57u32, 58u32, 59u32, 60u32, +]); +test_ptx_warp!(shfl_sync_down_b32_pred, [ + 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32, 13u32, 14u32, 15u32, 16u32, 17u32, 18u32, + 19u32, 20u32, 21u32, 22u32, 23u32, 24u32, 25u32, 26u32, 27u32, 28u32, 29u32, 30u32, 31u32, 1029u32, 1030u32, 1031u32, + 35u32, 36u32, 37u32, 38u32, 39u32, 40u32, 41u32, 42u32, 43u32, 44u32, 45u32, 46u32, 47u32, 48u32, 49u32, 50u32, + 51u32, 52u32, 53u32, 54u32, 55u32, 56u32, 57u32, 58u32, 59u32, 60u32, 61u32, 62u32, 63u32, 1061u32, 1062u32, 1063u32, +]); +test_ptx_warp!(shfl_sync_bfly_b32_pred, [ + 3u32, 2u32, 1u32, 0u32, 7u32, 6u32, 5u32, 4u32, 11u32, 10u32, 9u32, 8u32, 15u32, 14u32, 13u32, 12u32, + 19u32, 18u32, 17u32, 16u32, 23u32, 22u32, 21u32, 20u32, 27u32, 26u32, 25u32, 24u32, 31u32, 30u32, 29u32, 28u32, + 35u32, 34u32, 33u32, 32u32, 39u32, 38u32, 37u32, 36u32, 43u32, 42u32, 41u32, 40u32, 47u32, 46u32, 45u32, 44u32, + 51u32, 50u32, 49u32, 48u32, 55u32, 54u32, 53u32, 52u32, 59u32, 58u32, 57u32, 56u32, 63u32, 62u32, 61u32, 60u32, +]); +test_ptx_warp!(shfl_sync_idx_b32_pred, [ + 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, + 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, 12u32, + 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, + 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, 44u32, +]); +test_ptx_warp!(shfl_sync_mode_b32, [ + 9u32, 7u32, 8u32, 9u32, 21u32, 19u32, 20u32, 21u32, 33u32, 31u32, 32u32, 33u32, 45u32, 43u32, 44u32, 45u32, + 73u32, 71u32, 72u32, 73u32, 85u32, 83u32, 84u32, 85u32, 97u32, 95u32, 96u32, 97u32, 109u32, 107u32, 108u32, 109u32, + 137u32, 135u32, 136u32, 137u32, 149u32, 147u32, 148u32, 149u32, 161u32, 159u32, 160u32, 161u32, 173u32, 171u32, 172u32, 173u32, + 201u32, 199u32, 200u32, 201u32, 213u32, 211u32, 212u32, 213u32, 225u32, 223u32, 224u32, 225u32, 237u32, 235u32, 236u32, 237u32, +]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/shfl_sync_bfly_b32_pred.ptx b/ptx/src/test/spirv_run/shfl_sync_bfly_b32_pred.ptx new file mode 100644 index 0000000..13416d0 --- /dev/null +++ b/ptx/src/test/spirv_run/shfl_sync_bfly_b32_pred.ptx @@ -0,0 +1,32 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry shfl_sync_bfly_b32_pred( + .param .u64 
output +) +{ + .reg .u64 out_addr; + .reg .u64 out_index; + .reg .u32 thread_id; + .reg .u32 result; + .reg .pred in_range; + + ld.param.u64 out_addr, [output]; + mov.u32 thread_id, %tid.x; + + // result = __shfl_xor_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=32) + // c is ((32-width) << 8) | 31 + shfl.sync.bfly.b32 result|in_range, thread_id, 3, 31, 0xFFFFFFFF; + + @!in_range add.u32 result, result, 1000; + + // Return result + + cvt.u64.u32 out_index, thread_id; + mul.lo.u64 out_index, out_index, 4; + add.u64 out_addr, out_addr, out_index; + st.u32 [out_addr], result; + + ret; +} diff --git a/ptx/src/test/spirv_run/shfl_sync_down_b32_pred.ptx b/ptx/src/test/spirv_run/shfl_sync_down_b32_pred.ptx new file mode 100644 index 0000000..3c74914 --- /dev/null +++ b/ptx/src/test/spirv_run/shfl_sync_down_b32_pred.ptx @@ -0,0 +1,32 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry shfl_sync_down_b32_pred( + .param .u64 output +) +{ + .reg .u64 out_addr; + .reg .u64 out_index; + .reg .u32 thread_id; + .reg .u32 result; + .reg .pred in_range; + + ld.param.u64 out_addr, [output]; + mov.u32 thread_id, %tid.x; + + // result = __shfl_down_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=32) + // c is ((32-width) << 8) | 31 + shfl.sync.down.b32 result|in_range, thread_id, 3, 31, 0xFFFFFFFF; + + @!in_range add.u32 result, result, 1000; + + // Return result + + cvt.u64.u32 out_index, thread_id; + mul.lo.u64 out_index, out_index, 4; + add.u64 out_addr, out_addr, out_index; + st.u32 [out_addr], result; + + ret; +} diff --git a/ptx/src/test/spirv_run/shfl_sync_idx_b32_pred.ptx b/ptx/src/test/spirv_run/shfl_sync_idx_b32_pred.ptx new file mode 100644 index 0000000..d5d6b05 --- /dev/null +++ b/ptx/src/test/spirv_run/shfl_sync_idx_b32_pred.ptx @@ -0,0 +1,32 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry shfl_sync_idx_b32_pred( + .param .u64 output +) +{ + .reg .u64 out_addr; + .reg .u64 out_index; + .reg .u32 thread_id; + .reg .u32 result; + .reg .pred in_range; + + ld.param.u64 out_addr, [output]; + mov.u32 thread_id, %tid.x; + + // result = __shfl_sync(mask=0xFFFFFFFF, thread_id, srcLane=12, width=32) + // c is ((32-width) << 8) | 31 + shfl.sync.idx.b32 result|in_range, thread_id, 12, 31, 0xFFFFFFFF; + + @!in_range add.u32 result, result, 1000; + + // Return result + + cvt.u64.u32 out_index, thread_id; + mul.lo.u64 out_index, out_index, 4; + add.u64 out_addr, out_addr, out_index; + st.u32 [out_addr], result; + + ret; +} diff --git a/ptx/src/test/spirv_run/shfl_sync_mode_b32.ptx b/ptx/src/test/spirv_run/shfl_sync_mode_b32.ptx new file mode 100644 index 0000000..f500098 --- /dev/null +++ b/ptx/src/test/spirv_run/shfl_sync_mode_b32.ptx @@ -0,0 +1,46 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry shfl_sync_mode_b32( + .param .u64 output +) +{ + .reg .u64 out_addr; + .reg .u64 out_index; + .reg .u32 thread_id; + .reg .u32 temp; + .reg .u32 result; + + ld.param.u64 out_addr, [output]; + mov.u32 thread_id, %tid.x; + + // result = __shfl_up_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=2) + // c is ((32-width) << 8) + shfl.sync.up.b32 temp, thread_id, 3, 7680, 0xFFFFFFFF; + mov.u32 result, temp; + + // result += __shfl_down_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=4) + // c is ((32-width) << 8) | 31 + shfl.sync.down.b32 temp, thread_id, 3, 7199, 0xFFFFFFFF; + add.u32 result, result, temp; + + // result = __shfl_xor_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=8) + // c is ((32-width) << 8) | 31 + shfl.sync.bfly.b32 temp, thread_id, 3, 6175, 
0xFFFFFFFF; + add.u32 result, result, temp; + + // result = __shfl_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=16) + // c is ((32-width) << 8) | 31 + shfl.sync.idx.b32 temp, thread_id, 3, 4127, 0xFFFFFFFF; + add.u32 result, result, temp; + + // Return result + + cvt.u64.u32 out_index, thread_id; + mul.lo.u64 out_index, out_index, 4; + add.u64 out_addr, out_addr, out_index; + st.u32 [out_addr], result; + + ret; +} diff --git a/ptx/src/test/spirv_run/shfl_sync_up_b32_pred.ptx b/ptx/src/test/spirv_run/shfl_sync_up_b32_pred.ptx new file mode 100644 index 0000000..540298e --- /dev/null +++ b/ptx/src/test/spirv_run/shfl_sync_up_b32_pred.ptx @@ -0,0 +1,32 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry shfl_sync_up_b32_pred( + .param .u64 output +) +{ + .reg .u64 out_addr; + .reg .u64 out_index; + .reg .u32 thread_id; + .reg .u32 result; + .reg .pred in_range; + + ld.param.u64 out_addr, [output]; + mov.u32 thread_id, %tid.x; + + // result = __shfl_up_sync(mask=0xFFFFFFFF, thread_id, delta=3, width=32) + // c is ((32-width) << 8) + shfl.sync.up.b32 result|in_range, thread_id, 3, 0, 0xFFFFFFFF; + + @!in_range add.u32 result, result, 1000; + + // Return result + + cvt.u64.u32 out_index, thread_id; + mul.lo.u64 out_index, out_index, 4; + add.u64 out_addr, out_addr, out_index; + st.u32 [out_addr], result; + + ret; +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 6e42871..63155f4 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -2,7 +2,7 @@ use super::{ AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix, }; -use crate::{Mul24Control, Reduction, PtxError, PtxParserState}; +use crate::{Mul24Control, Reduction, PtxError, PtxParserState, ShuffleMode}; use bitflags::bitflags; use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8}; @@ -468,6 +468,21 @@ ptx_parser_macros::generate_instruction_type!( } } }, + ShflSync { + data: ShflSyncDetails, + type: Type::Scalar(ScalarType::B32), + arguments: { + dst: T, + dst_pred: { + repr: Option, + type: Type::from(ScalarType::Pred) + }, + src: T, + src_lane: T, + src_opts: T, + src_membermask: T + } + }, Shl { data: ScalarType, type: { Type::Scalar(data.clone()) }, @@ -979,6 +994,11 @@ impl MovDetails { } } +#[derive(Copy, Clone)] +pub struct ShflSyncDetails { + pub mode: ShuffleMode, +} + #[derive(Clone)] pub enum ParsedOperand { Reg(Ident), diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index a377387..5f03b49 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1589,7 +1589,7 @@ where // * Opcode: `ld` // * Modifiers, always start with a dot: `.global`, `.relaxed`. Optionals are enclosed in braces // * Arguments: `a`, `b`. Optionals are enclosed in braces -// * Code block: => { }. Code blocks implictly take all modifiers ansd arguments +// * Code block: => { }. Code blocks implictly take all modifiers and arguments // as parameters. 
All modifiers and arguments are passed to the code block:
 //   * If it is an alternative (as defined in rules list later):
 //   * If it is mandatory then its type is Foo (as defined by the relevant rule)
@@ -1723,6 +1723,9 @@ derive_parser!(
     #[derive(Copy, Clone, PartialEq, Eq, Hash)]
     pub enum Reduction { }
 
+    #[derive(Copy, Clone, PartialEq, Eq, Hash)]
+    pub enum ShuffleMode { }
+
     // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
     mov{.vec}.type d, a => {
         Instruction::Mov {
@@ -3487,6 +3490,14 @@ derive_parser!(
     .mode: Mul24Control = { .hi, .lo };
     .type: ScalarType = { .u32, .s32 };
 
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync
+    shfl.sync.mode.b32 d[|p], a, b, c, membermask => {
+        Instruction::ShflSync {
+            data: ast::ShflSyncDetails { mode },
+            arguments: ShflSyncArgs { dst: d, dst_pred: p, src: a, src_lane: b, src_opts: c, src_membermask: membermask }
+        }
+    }
+    .mode: ShuffleMode = { .up, .down, .bfly, .idx };
 );
 
 #[cfg(test)]
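
Note (editor's sketch, not part of the patch): the `c`/opts operand used by the shfl.sync test kernels above follows the encoding described in the zluda_ptx_impl.cpp comments: the high byte is the subsection mask (32 - width) and the low five bits are the warp end ID (31 for idx, bfly and down, 0 for up). The helper below is hypothetical and only illustrates that encoding in C++; the constants in the asserts are the ones appearing in the PTX tests.

#include <cstdint>

// Hypothetical helper (illustration only): builds the shfl.sync `c` operand for a 32-lane warp.
// `width` must be a power of two between 1 and 32; `up_mode` selects the warp end ID of 0 used
// by shfl.sync.up (idx, bfly and down use 31).
constexpr uint32_t make_shfl_opts(uint32_t width, bool up_mode) {
    uint32_t section_mask = 32u - width;   // e.g. width 8 -> 24 (0b11000)
    uint32_t warp_end = up_mode ? 0u : 31u;
    return (section_mask << 8) | warp_end;
}

// Values matching the *_pred tests (width 32) and shfl_sync_mode_b32.ptx (widths 2, 4, 8, 16).
static_assert(make_shfl_opts(32, true) == 0, "up, width 32");
static_assert(make_shfl_opts(32, false) == 31, "idx/down/bfly, width 32");
static_assert(make_shfl_opts(2, true) == 7680, "up, width 2");
static_assert(make_shfl_opts(4, false) == 7199, "down, width 4");
static_assert(make_shfl_opts(8, false) == 6175, "bfly, width 8");
static_assert(make_shfl_opts(16, false) == 4127, "idx, width 16");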