From 5cb0a9b8e83298ecf17b7aa18360866003c842e1 Mon Sep 17 00:00:00 2001 From: Violet Date: Thu, 3 Jul 2025 11:56:20 -0700 Subject: [PATCH] Add support for `bar.red.and.pred` (#402) Implements bar.red.and.pred and bar.red.or.pred, using the undocument __ockl_wgred functions. Doesn't yet add support for numbered barriers and threadcount, as these are not needed for llm.c. --- ptx/lib/zluda_ptx_impl.bc | Bin 7524 -> 7496 bytes ptx/lib/zluda_ptx_impl.cpp | 13 ++ ptx/src/pass/emit_llvm.rs | 1 + ptx/src/pass/insert_post_saturation.rs | 1 + .../instruction_mode_to_global_mode/mod.rs | 1 + ...eplace_instructions_with_function_calls.rs | 10 ++ ptx/src/test/ll/bar_red_and_pred.ll | 121 ++++++++++++++++++ ptx/src/test/spirv_run/bar_red_and_pred.ptx | 60 +++++++++ ptx/src/test/spirv_run/mod.rs | 7 + ptx_parser/src/ast.rs | 28 +++- ptx_parser/src/lib.rs | 31 ++++- ptx_parser_macros/src/lib.rs | 2 +- 12 files changed, 269 insertions(+), 6 deletions(-) create mode 100644 ptx/src/test/ll/bar_red_and_pred.ll create mode 100644 ptx/src/test/spirv_run/bar_red_and_pred.ptx diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 6cefc813eb60ce7f302e265090ab069682291723..84d2f2d734e8c76faeaa25d9778b95755e71d944 100644 GIT binary patch delta 3784 zcma)84QyN06+Z9z`T5y>PG20Sd2XEMIcb8^h0J#TyD^@fCN*saNmf=miu@BNB;Zo# zr?d?}esLU}WED7b7@;aGjnayC6kH4)0!@n3Py&RCmV!!EjTG9gOr6w$5Wicd`|O(r z4NbeU^j)3zopZl?&bjyc^7^^;(X~YTXpvdR*OQLlt#Hh8Hhm4lZ`RnnqpsYbH_Wdy zuZvnAj&C0oPR#+pumDg500r$7_mU$toY~?C4UcxsG&wcu;)YX7hqsuSala&u7mA-3 zjh8ZJ#ZacC<9jG~lV(8RypMyvc*Tf~f4hY|q#tpZo!j)vW_^9LURgreHgbAmctoP= zo1zYZ3F#{w&D@lp2?U0;{{XRsebK^IU!uUly^Z|bJ=8EAl95s=AAyMy9eSM%&NT|1ACpSN6ac%DQPpmLTVHNjg${T zaEf{G2zQ>0AN=g!i;IgvxShDJ5XoAb zviglGZqRC4DCVB|sy`VP*kZQ99SS$fj4>TuVZj*v*7q(Uq`InIqG z!~HBxCbrOz6G7JIo0^MGWX{A}PBba5-}?TSh4YE8XTG>LE;WwN3nc&y0EAc=KbZL` z3xS5ohSiL*mP`eF?ON+3mKl{I-Le+d?J4pJj%uuNTujgX{B+z@5Te%_x2jcQCn znm%lsl%Q$3rYk?aNWrg2Oo|e9>z+Ac7C;XSAf*h_^k?m_iRqUy2Wk3y?*R=1bXrqT z7WK-&Hn3DKu?&`}`|_VBrOu@d&q<~`)wS6iu*>t*&LqHn-;$fn?R=%-yb@vxY(jZz z?zRZ9X1S&}biXI2sZ{J>s(9<3v1#Ejs%@e1zD6TyisNYiUaEa$mAs3xMIrxwq(8A%R+@Vm{D!W850|zKs918XAi!UgoVG`u^0nq-kpQz_K$i~r8 z&M!Rwxdd@+0h{O@5HGzzf>g?70f24IIB&>ojP6qxZ#+fS>CS04_1tsh1;vwpUC$l& z-|@sd$M){uP7nU=i|nSo_MbU@|c$wMmLjunum@9?((uPWUo5 zJHrPH(S^(+yeqzw=xU`vYU!du(RCd317GDao!Vc3u2e4 z3CSKU;>}j?f?r_hJ4zUrNc&R1_`5p2B0Fs6>dq)k3r=PxP#34&mUrBGd&Z{22J>DA zmlI8TnrMT?D|HAAol$DOj~h#tCaq2O>mjT07LF_zkOX_4lq^+rQ?=%&dnPWc?c294 z(2->BlJM6iF4`@sLYjh*JMcZJng#SG-2GUg`vg1e^BayzAKA8_J>JF-NFUk8`on(1 zUj+9Fwo8ytvB?T@Tdf$eGg9n$R&G%L0YD+GkniBcw`_McHZRK3>0_2@y)# zPBVV^(Qj9j1|d%T__sIy$G1bNZ@F?%o!k%gazCK1505R1-|UOE>rKHkg95+B6W7m{ z64v+nxsau1nwEVsWuFbpeDX`srEpRk7U%?9$FbGi06%qb)QE}$MwP_sk*mj(si-5g@8sR%!De0I9G zF3#js1a9LD=C0sg)S5$Llt@<>T^dZX7_ znV?RJg9S`ubVPR(wc9?EQ`4uM%F*wxizldAaj+oB`qCQF8E*Vq5UzvTz~AOP~rqU@`}?e-2` z_;FtN4I8q5LJznp7Pk)90G(p&stQ1^x+ADa31|e6cXwBs zjS!u?_Y^6+Qvq_{Zc~8V^Qe8$xVrQPrTA#n0|0s#Frt4qOew(a0HB2%i0nbI0Nf2A zH!?=-&;a_s-0ALN=Kx{t>0ILC(5Zh7r#3(%bAVZ{7SGL>bF*d8KNS;L6zs47p9049(zC{M7%t0VyW(~J?>(Zcg& z>R&2G%)dsa%7_v37fFCTfE$?Z6JxshF)l7{O)0urky45gIUuEI;1hxn*-<^v95pCL zWS1y_a&#;Kpgd+@owF5t-mgfV7c}l+0WPHu@7L(UaJgD{KG4%!=5l#@dfK{sUG1Ju ze_8vEUYDz*GvI37)#-D2+iE->TU%9?rPkx|SQ1aI-bHTpdb(ZRZ9bRh0iP>?Ca#^W z_xoMD@5RQ>a+}N5=ilY?xB|UJG+rf>^oyww%`BYjcYEBC+~puFpFNP^atY)JrrA3_3gpz|BaW$dIKqb}MFl@P zwr0FNI0!+Ehldq4+CwqYR3}6-nKWak9Qo);X^J*Fq^2E`l8&~fr4dsdXG-VUyG1kY z_zpAA&3}LI|9#%)eb{}YTNk$mvxyT!Iksd|CFy!L$u-8?O%BGqgSFQWHCcW20du}B zKUlPH;@%~Pc^H}tvr|Ds?TLkz1!_Q2%{XI z1v$bawAaIC1`LnTVR4sd*dJ~r8`%JDZ*xY40Pl(pPt6+WO7SR@Ya(87nOoc$y#280 zoVjJjJD8f6;EIl9@(o;f(0RWgG))obK|J?g6iEX-;#E*DcFBC|&F+3zV0QopwwRU;l92~qW%Sov@#z2R)m2>;^bd&oPr~}E zus}>?%=~9%WyJ?~5KGj89<<|BSWn&qCHx;TNxghLG9LpzgkTAu3PBQT%mC@38cqF1 zwnJcBT*2F!3AP)pnY@Ojqk8`w{SNRn6d2FxS7_mp~ATlQQNBaqMqT7m#w^>Y7Ss=^>hyNg$MjmCW)h zk27$XR?RT84@Mf3d0>4vrFPrJOc+U4F(tD}dyNFPn=+g2eG0s_s}yE-z@0-7VUU0U z07=R0l%<1|_uQ1(=OyQ98mMn(<~#C`$iiWb>ZTO!n%4;KBvrzq9l6n!D}Xs(v{=o) z>Zt>8neQjYVLfdT<+ssv?_UBZ~2+z1dQ38ftN4I^fY!5tjb=6>%SUCw4f z_8SkfkzG`F>g z?T#W*EG;e)+lx0BZzNZ5lY}wwTD? zdcDxVo|Q)B=y17nxT0=2q1VL;4Q5kqsqEh4vE2asS?A~&6&+6Owd;fiA;&d(JZh?W z!Nm5ubV5V^+4@3LtwVN4hTZQCuj$Ry3k?nvyG?diePYu0Ue^f?-!%=C$!_8~zL(Sq z4Ha#T)ux#P?z|LZ?{1yYa7V(t*Pi_$+;eJWO{-yyy=2vfj^AsVG*`5m3f?%_bX@}q8>y$gNVN4 zLsl3=_NWBg7_GwIg0?PXgUC)OY!FfLHAG)S+ox@4YeUqCD39opjJ7hOk`K{0FxmIg zIV&GYM|3d}U6U}`FC!C#&oF|l=SgIrMBAOn5=sUiAnQT4^Yh3AVOVy;vf$l^=&vzZ zi-J1P_8g)-q7(l@bP3VUClGxCU0>dSt_6s`foM1plf4@8O%R4fWMu-0$gL(x?=ml- z#RAIxku==pRz}iKn=hd}UJlDbP>~~PoTPVI=1~r!JmOdAG^8yTP(F^k3p-Jc6!MbZ zwO)}gp!YbD#`&zg(}DQmeJdGA!>erI~?R6p^sUoeYpkyK~;Y%)is zkb{ix^jr`B1)ndHv!i?aP55~q3D{gJ%YX%CH_spw=aMq3q^%F&UqKZ(fg zesTOj?5}Yo$EL}r?HPTqCT4gNl3nHpna4@4-7=UNCEET><N;N^5mJyP~pzS|pvK z_lE1@0q!|UwcGw@4^ z=SJT5%w-Bs87LB*qOGxq+6}-0z&(9k%*e*_(&}?R&DcP;|5nt$f2{3`{qjY7AQ0%V z+?DA)!EMaQfs7(hJ~>GPP6j9lS(7ck<-Gt|ngqVv3lBsz0L=nG0^lpgsmZP41`6ni z^nN$Tb-@3BHvKKn-u{+nug0EPl?Cx5WTY#VcN)H_TW5UXeO00J$r(+fy@J{OQ9VCN zpON-Endt?v#u70ob1J?u=NZo$064uSAZRAJa-Q45_n620SAuY#MbjLm64pnw_jWEB z{O6N7Dy?7Wl#@Fkbcbb zeIrj_c;DVoi7C?;L5H<jS+wQELDGzGc>~7;oj&NP_S9Ox>QV%5!7`tD}5j6Sy?#_Up4VY6elYO`L`<8RB z9AuXJy!UAb{A`7A0p{a}j9*409SQoqOqPb`yS*NP#m2x7w~=)!zXbu3IN_BV+*_S<^=$E$61VcvzDlsK1qVo6lX|K z8mdI)&?G*P$FUIrw0Ij(egFVoH<8X!RDji}^4)QNHLCpcL^*&N0G=me zQJrvGXS8-a6F*?nh?nC+(GPZ_Fmy)p@)Tq$e z0Ci-n%4$^gy=1ItYD67R#YY09;YBkZRsK-CHxH3es5GfM0SESv8c~<&n1*jfYiP)_ zZoAsbACj@I#PI_T;GNjN%3tt>-PCmdp3bJ0$M&=~ImFG)o10r&ikq7q?Zu^SlPi|T P$k3(Kw2&@sD*^unQPbP8 diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 7af9729..638ef1e 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -157,6 +157,19 @@ extern "C" __builtin_amdgcn_s_barrier(); } + int32_t __ockl_wgred_and_i32(int32_t) __device__; + int32_t __ockl_wgred_or_i32(int32_t) __device__; + + #define BAR_RED_IMPL(reducer) \ + bool FUNC(bar_red_##reducer##_pred)(uint32_t barrier __attribute__((unused)), bool predicate, bool invert_predicate) \ + { \ + /* TODO: handle barrier */ \ + return __ockl_wgred_##reducer##_i32(predicate ^ invert_predicate); \ + } + + BAR_RED_IMPL(and); + BAR_RED_IMPL(or); + void FUNC(__assertfail)(uint64_t message, uint64_t file, uint32_t line, diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 58341e4..1c2b52c 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -639,6 +639,7 @@ impl<'a> MethodEmitContext<'a> { // replaced by a function call ast::Instruction::Bfe { .. } | ast::Instruction::Bar { .. } + | ast::Instruction::BarRed { .. } | ast::Instruction::Bfi { .. } | ast::Instruction::Activemask { .. } => return Err(error_unreachable()), } diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index cc9afa7..92be749 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -75,6 +75,7 @@ fn run_instruction<'input>( | ast::Instruction::Atom { .. } | ast::Instruction::AtomCas { .. } | ast::Instruction::Bar { .. } + | ast::Instruction::BarRed { .. } | ast::Instruction::Bfe { .. } | ast::Instruction::Bfi { .. } | ast::Instruction::Bra { .. } diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index fdaafd1..3d56dd0 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1804,6 +1804,7 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::Selp { .. } | ast::Instruction::Ret { .. } | ast::Instruction::Bar { .. } + | ast::Instruction::BarRed { .. } | ast::Instruction::Cvta { .. } | ast::Instruction::Atom { .. } | ast::Instruction::Mul24 { .. } diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs index 0f9311a..db6b473 100644 --- a/ptx/src/pass/replace_instructions_with_function_calls.rs +++ b/ptx/src/pass/replace_instructions_with_function_calls.rs @@ -108,6 +108,16 @@ fn run_instruction<'input>( i @ ptx_parser::Instruction::Bar { .. } => { to_call(resolver, fn_declarations, "bar_sync".into(), i)? } + ptx_parser::Instruction::BarRed { data, arguments } => { + if arguments.src_threadcount.is_some() { + return Err(error_todo()); + } + let name = match data.pred_reduction { + ptx_parser::Reduction::And => "bar_red_and_pred", + ptx_parser::Reduction::Or => "bar_red_or_pred", + }; + to_call(resolver, fn_declarations, name.into(), ptx_parser::Instruction::BarRed { data, arguments })? + } i => i, }) } diff --git a/ptx/src/test/ll/bar_red_and_pred.ll b/ptx/src/test/ll/bar_red_and_pred.ll new file mode 100644 index 0000000..649efc0 --- /dev/null +++ b/ptx/src/test/ll/bar_red_and_pred.ll @@ -0,0 +1,121 @@ +declare i1 @__zluda_ptx_impl_bar_red_and_pred(i32, i1, i1) #0 + +declare i1 @__zluda_ptx_impl_bar_red_or_pred(i32, i1, i1) #0 + +declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @bar_red_and_pred(ptr addrspace(4) byref(i64) %"73", ptr addrspace(4) byref(i64) %"74") #1 { + %"75" = alloca i64, align 8, addrspace(5) + %"76" = alloca i64, align 8, addrspace(5) + %"77" = alloca i32, align 4, addrspace(5) + %"78" = alloca i32, align 4, addrspace(5) + %"79" = alloca i1, align 1, addrspace(5) + %"80" = alloca i1, align 1, addrspace(5) + %"81" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"70" + +"70": ; preds = %1 + %"82" = load i64, ptr addrspace(4) %"74", align 4 + store i64 %"82", ptr addrspace(5) %"75", align 4 + %"44" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"71" + +"71": ; preds = %"70" + store i32 %"44", ptr addrspace(5) %"77", align 4 + %"85" = load i32, ptr addrspace(5) %"77", align 4 + %"84" = urem i32 %"85", 2 + store i32 %"84", ptr addrspace(5) %"78", align 4 + %"87" = load i32, ptr addrspace(5) %"78", align 4 + %"86" = icmp eq i32 %"87", 0 + store i1 %"86", ptr addrspace(5) %"80", align 1 + store i32 0, ptr addrspace(5) %"81", align 4 + %"90" = load i1, ptr addrspace(5) %"80", align 1 + %"89" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"90", i1 false) + store i1 %"89", ptr addrspace(5) %"79", align 1 + %"91" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"91", label %"17", label %"18" + +"17": ; preds = %"71" + %"93" = load i32, ptr addrspace(5) %"81", align 4 + %"92" = add i32 %"93", 1 + store i32 %"92", ptr addrspace(5) %"81", align 4 + br label %"18" + +"18": ; preds = %"17", %"71" + %"95" = load i1, ptr addrspace(5) %"80", align 1 + %"94" = call i1 @__zluda_ptx_impl_bar_red_or_pred(i32 1, i1 %"95", i1 false) + store i1 %"94", ptr addrspace(5) %"79", align 1 + %"96" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"96", label %"19", label %"20" + +"19": ; preds = %"18" + %"98" = load i32, ptr addrspace(5) %"81", align 4 + %"97" = add i32 %"98", 1 + store i32 %"97", ptr addrspace(5) %"81", align 4 + br label %"20" + +"20": ; preds = %"19", %"18" + store i1 true, ptr addrspace(5) %"80", align 1 + %"101" = load i1, ptr addrspace(5) %"80", align 1 + %"100" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"101", i1 false) + store i1 %"100", ptr addrspace(5) %"79", align 1 + %"102" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"102", label %"21", label %"22" + +"21": ; preds = %"20" + %"104" = load i32, ptr addrspace(5) %"81", align 4 + %"103" = add i32 %"104", 1 + store i32 %"103", ptr addrspace(5) %"81", align 4 + br label %"22" + +"22": ; preds = %"21", %"20" + store i1 false, ptr addrspace(5) %"80", align 1 + %"107" = load i1, ptr addrspace(5) %"80", align 1 + %"106" = call i1 @__zluda_ptx_impl_bar_red_or_pred(i32 1, i1 %"107", i1 false) + store i1 %"106", ptr addrspace(5) %"79", align 1 + %"108" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"108", label %"23", label %"24" + +"23": ; preds = %"22" + %"110" = load i32, ptr addrspace(5) %"81", align 4 + %"109" = add i32 %"110", 1 + store i32 %"109", ptr addrspace(5) %"81", align 4 + br label %"24" + +"24": ; preds = %"23", %"22" + store i1 true, ptr addrspace(5) %"80", align 1 + %"113" = load i1, ptr addrspace(5) %"80", align 1 + %"112" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"113", i1 true) + store i1 %"112", ptr addrspace(5) %"79", align 1 + %"114" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"114", label %"25", label %"26" + +"25": ; preds = %"24" + %"116" = load i32, ptr addrspace(5) %"81", align 4 + %"115" = add i32 %"116", 1 + store i32 %"115", ptr addrspace(5) %"81", align 4 + br label %"26" + +"26": ; preds = %"25", %"24" + %"118" = load i32, ptr addrspace(5) %"77", align 4 + %"117" = zext i32 %"118" to i64 + store i64 %"117", ptr addrspace(5) %"76", align 4 + %"120" = load i64, ptr addrspace(5) %"76", align 4 + %"119" = mul i64 %"120", 4 + store i64 %"119", ptr addrspace(5) %"76", align 4 + %"122" = load i64, ptr addrspace(5) %"75", align 4 + %"123" = load i64, ptr addrspace(5) %"76", align 4 + %"121" = add i64 %"122", %"123" + store i64 %"121", ptr addrspace(5) %"75", align 4 + %"124" = load i64, ptr addrspace(5) %"75", align 4 + %"125" = load i32, ptr addrspace(5) %"81", align 4 + %"126" = inttoptr i64 %"124" to ptr + store i32 %"125", ptr %"126", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/bar_red_and_pred.ptx b/ptx/src/test/spirv_run/bar_red_and_pred.ptx new file mode 100644 index 0000000..777b771 --- /dev/null +++ b/ptx/src/test/spirv_run/bar_red_and_pred.ptx @@ -0,0 +1,60 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry bar_red_and_pred( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 out_addr; + .reg .u64 out_index; + .reg .u32 thread_id; + .reg .u32 thread_mod_2; + .reg .pred pred; + .reg .pred cond; + .reg .u32 result; + + ld.param.u64 out_addr, [output]; + + mov.u32 thread_id, %tid.x; + rem.u32 thread_mod_2, thread_id, 2; + setp.eq.u32 cond, thread_mod_2, 0; + + mov.u32 result, 0; + + // Basic functionality + + // result += AND(tid.x % 2 == 0) forall threads + bar.red.and.pred pred, 1, cond; + @pred add.u32 result, result, 1; + // result += OR(tid.x % 2 == 0) forall threads + bar.red.or.pred pred, 1, cond; + @pred add.u32 result, result, 1; + + // result += AND(true) forall threads + setp.eq.u32 cond, 1, 1; + bar.red.and.pred pred, 1, cond; + @pred add.u32 result, result, 1; + // result += OR(false) forall threads + setp.eq.u32 cond, 1, 0; + bar.red.or.pred pred, 1, cond; + @pred add.u32 result, result, 1; + + // Negated condition + // result += AND(!true) forall threads + setp.eq.u32 cond, 1, 1; + bar.red.and.pred pred, 1, !cond; + @pred add.u32 result, result, 1; + + // Return result + + cvt.u64.u32 out_index, thread_id; + mul.lo.u64 out_index, out_index, 4; + add.u64 out_addr, out_addr, out_index; + st.u32 [out_addr], result; + + // result should be 2 + + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 84e0731..c594ebb 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -307,6 +307,13 @@ test_ptx_warp!(tid, [ 32u8, 33u8, 34u8, 35u8, 36u8, 37u8, 38u8, 39u8, 40u8, 41u8, 42u8, 43u8, 44u8, 45u8, 46u8, 47u8, 48u8, 49u8, 50u8, 51u8, 52u8, 53u8, 54u8, 55u8, 56u8, 57u8, 58u8, 59u8, 60u8, 61u8, 62u8, 63u8, ]); +test_ptx_warp!(bar_red_and_pred, [ + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, +]); + struct DisplayError { err: T, } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index ca7b9df..6e42871 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -2,7 +2,7 @@ use super::{ AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix, }; -use crate::{Mul24Control, PtxError, PtxParserState}; +use crate::{Mul24Control, Reduction, PtxError, PtxParserState}; use bitflags::bitflags; use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8}; @@ -95,6 +95,26 @@ ptx_parser_macros::generate_instruction_type!( src2: Option, } }, + BarRed { + type: Type::Scalar(ScalarType::U32), + data: BarRedData, + arguments: { + dst1: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + src_barrier: T, + src_threadcount: Option, + src_predicate: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + src_negate_predicate: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + } + }, Bfe { type: Type::Scalar(data.clone()), data: ScalarType, @@ -1745,6 +1765,12 @@ pub struct BarData { pub aligned: bool, } +#[derive(Copy, Clone)] +pub struct BarRedData { + pub aligned: bool, + pub pred_reduction: Reduction, +} + pub struct AtomDetails { pub type_: Type, pub semantics: AtomSemantics, diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 6dedbbb..da14406 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1705,6 +1705,9 @@ derive_parser!( #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum Mul24Control { } + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum Reduction { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { @@ -2987,7 +2990,9 @@ derive_parser!( barrier{.cta}.sync{.aligned} a{, b} => { let _ = cta; ast::Instruction::Bar { - data: ast::BarData { aligned }, + data: ast::BarData { + aligned, + }, arguments: BarArgs { src1: a, src2: b } } } @@ -2997,14 +3002,32 @@ derive_parser!( bar{.cta}.sync a{, b} => { let _ = cta; ast::Instruction::Bar { - data: ast::BarData { aligned: true }, + data: ast::BarData { + aligned: true, + }, arguments: BarArgs { src1: a, src2: b } } } //bar{.cta}.arrive a, b; //bar{.cta}.red.popc.u32 d, a{, b}, {!}c; - //bar{.cta}.red.op.pred p, a{, b}, {!}c; - //.op = { .and, .or }; + bar{.cta}.red.op.pred p, a{, b}, {!}c => { + let _ = cta; + let (negate_src3, c) = c; + ast::Instruction::BarRed { + data: ast::BarRedData { + aligned: true, + pred_reduction: op, + }, + arguments: BarRedArgs { + dst1: p, + src_barrier: a, + src_threadcount: b, + src_predicate: c, + src_negate_predicate: ParsedOperand::Imm(ImmediateValue::U64(negate_src3 as u64)) + } + } + } + .op: Reduction = { .and, .or }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom atom{.sem}{.scope}{.space}.op{.level::cache_hint}.type d, [a], b{, cache_policy} => { diff --git a/ptx_parser_macros/src/lib.rs b/ptx_parser_macros/src/lib.rs index f88395d..0e916b4 100644 --- a/ptx_parser_macros/src/lib.rs +++ b/ptx_parser_macros/src/lib.rs @@ -784,7 +784,7 @@ fn emit_definition_parser( }; let can_be_negated = if arg.can_be_negated { quote! { - opt(any.verify(|(t, _)| *t == #token_type::Not)).map(|o| o.is_some()) + opt(any.verify(|(t, _)| *t == #token_type::Exclamation)).map(|o| o.is_some()) } } else { quote! {