From 81baecf2c821fcae29465ab9f0af85d810754182 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 25 Sep 2024 02:46:08 +0200 Subject: [PATCH] Add ptx_impl bitcode module --- comgr/src/lib.rs | 68 +++++++++++++++++-------- ptx/lib/zluda_ptx_impl.bc | Bin 34052 -> 2660 bytes ptx/lib/zluda_ptx_impl.cpp | 18 +++++++ ptx/src/pass/deparamize_functions.rs | 63 ++++++++++++++++++++++- ptx/src/pass/fix_special_registers2.rs | 8 +-- ptx/src/pass/hoist_globals.rs | 2 +- ptx/src/pass/mod.rs | 32 +++++++----- ptx/src/test/spirv_run/mod.rs | 1 + 8 files changed, 152 insertions(+), 40 deletions(-) create mode 100644 ptx/lib/zluda_ptx_impl.cpp diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index f27a127..bdec0fb 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -79,6 +79,10 @@ impl ActionInfo { unsafe { amd_comgr_action_info_set_isa_name(self.get(), full_isa.as_ptr().cast()) } } + fn set_language(&self, language: amd_comgr_language_t) -> Result<(), amd_comgr_status_s> { + unsafe { amd_comgr_action_info_set_language(self.get(), language) } + } + fn get(&self) -> amd_comgr_action_info_t { self.0 } @@ -90,36 +94,56 @@ impl Drop for ActionInfo { } } -pub fn compile_bitcode(gcn_arch: &CStr, buffer: &[u8]) -> Result, amd_comgr_status_s> { +pub fn compile_bitcode( + gcn_arch: &CStr, + main_buffer: &[u8], + ptx_impl: &[u8], +) -> Result, amd_comgr_status_s> { use amd_comgr_sys::*; let bitcode_data_set = DataSet::new()?; - let bitcode_data = Data::new( + let main_bitcode_data = Data::new( amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC, c"zluda.bc", - buffer, + main_buffer, + )?; + bitcode_data_set.add(&main_bitcode_data)?; + let stdlib_bitcode_data = Data::new( + amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC, + c"ptx_impl.bc", + ptx_impl, + )?; + bitcode_data_set.add(&stdlib_bitcode_data)?; + let lang_action_info = ActionInfo::new()?; + lang_action_info.set_isa_name(gcn_arch)?; + lang_action_info.set_language(amd_comgr_language_t::AMD_COMGR_LANGUAGE_LLVM_IR)?; + let linked_data_set = do_action( + &bitcode_data_set, + &lang_action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, )?; - bitcode_data_set.add(&bitcode_data)?; - let reloc_data_set = DataSet::new()?; let action_info = ActionInfo::new()?; action_info.set_isa_name(gcn_arch)?; - unsafe { - amd_comgr_do_action( - amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, - action_info.get(), - bitcode_data_set.get(), - reloc_data_set.get(), - ) - }?; - let exec_data_set = DataSet::new()?; - unsafe { - amd_comgr_do_action( - amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, - action_info.get(), - reloc_data_set.get(), - exec_data_set.get(), - ) - }?; + let reloc_data_set = do_action( + &linked_data_set, + &action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, + )?; + let exec_data_set = do_action( + &reloc_data_set, + &action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, + )?; let executable = exec_data_set.get_data(amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_EXECUTABLE, 0)?; executable.copy_content() } + +fn do_action( + data_set: &DataSet, + action: &ActionInfo, + kind: amd_comgr_action_kind_t, +) -> Result { + let result = DataSet::new()?; + unsafe { amd_comgr_do_action(kind, action.get(), data_set.get(), result.get()) }?; + Ok(result) +} diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 2d194c40c4406fc81c0a2b832e37afad1928fce4..cbbf2dc61f1365e90155feebfacc3088b43d0d05 100644 GIT binary patch delta 1512 zcmYLJZD?Cn7=CY(-IKIgKGt?zmZrODYh8!No3u?6o4K2$)9UI-H-rh!n>0YS|&{Kb10;NJIX&y0r@6 zMXK7}rmW9R59`QLw|7{XS1sf27LSvyGB6a+lD^Dl6s2`22`EjYDPMhN5ML%Y1)SpR zbEKe2vMB9Ci8A?6hmsPC`B22CoO(D4CHE zhBEoki()iNSLTRM8v}&b)uAX|@uN|cMJWmKfD!_VtZU+Q&EQg{)ppd@~?s> z9YtvY?J}vmgN=Q#)pKnBiu0{!oYxBtHx#fSs~k4~2-~T{5!X&aA=D$VM0h^6m({R9 z)j}p z@$={uC2|X5G%6;f$~i8~%(@)wCHsc{mMze86^CzRd z&uZ^79#d|LLh&Jylf=`4$S+)&C!G0`bG~FRd!D_V8@O7^|DYZ+wYdFl+IuGBFU!Wm zGtDqf16YeKR7WeN>n5Hc09-E*(t)u5JM{gg~+s`moCcxPIwsA|3)o=GRw!YDB0O(Q7#S;a*0dOLT=`VBs9o?F>wH{x3&3N<2~}s+;+PtstGb(RTsg<3!&75^ zE}BSjp-42$`2&eiG8pm4PlUtCh?x)i$2dMX5sD{*yg5|VCz?9rd~kw0nFx(>++^TH gn41VDxbS4aALinrbkNdoX3VzYtyz;Q8m;Z?e~U1w6aWAK literal 34052 zcmeI53s@Uf-v7@eA(KK#QlM=h2~MD`w6-)rrJ;Z(q*bg{V^{aZwYwxhT4=e1khT;r zn@MPc6}wSU+1ho%cC}s?6_s74fuZ~JYZDNPLI{mSSR`@ZuwMc>o~L}nCH zixGNf5<)RNHR2(JCL@GMI`v_FpWr5;EHmd3qES38N|cx^c(BBwT*xbUB&OhyEHgnX zlHVMiX=^I9?og*Lh;ui~^pEh$`b4iJP=*_tsCS4=Tcdnt+DuQzy&X#*rgrxsgdU&} z%0~!gru0pq%GL=UP2Jikcr^83W4k9IZt=$En8rhkWgWdrLN03)JCH-*66uUWMCjJ( z1*SCqAfk4N(TqK9MQKeLO>F{;QJXL-UtnCHnU)g$h?tQ5Mcfj-Zk9`Qhfx^4zTl5~ z!HP7##h|`bjSzC=)d|{+BAE!a8D}7bkZ#KohmJyyMGFxczC$2uS~3m!R%}Ewcx?+` zj&q2zd7369+)43_35M(^-(QyX{OvD@9iKjAb&0BUExH@u%0vhqixnggqp5=l-AcVl z(CyLPfMUss-IRO*a+x=xqy>dhB{G6S2;DA6XgYzA;@2tX&YwSjo*3aH33@=}UQM~{ z9PWm4SA*5PRz%jNxwAE7eYwlp=3Z0ow#(gX(Xc-GSi2JSDi5iZ``eZKQS#?%)d4zr zk3Jc#sVCQ@k&be=-RjOZxoujv)#T3ZaNFc$Lp`~++;vYOX=@|xXb@2chE)3<$%oX+ zK_dBs9@Ts8vHR7^VM2u(Jnn`x($Pvbh}^YB?raCSww`QAbJw=H*Ld97GIxW9tV_ea ztt}+)OCyog5!)z3@2-Db{-5*b&;K3`A`gll=BFt4)5`aVI1X#2+>Ba-R8C-QJu$aEqB|>`StZ~yOnH6BiEt>L#hMqO7yhqgC5m> zhicD|vX@rv?@r!hRt^uT4z2z_J~?F+O9--D0{o$BpwxhhY%@I zoQ}9p*J#M~cyqc9`JMXwE^l6=E2qVvXyxa1>a{*^Zl_^>vo{ZoQqt2#2{KDg$s}JJ zB~?iJjYo1qCp}C_9l6_Gxh;HMr(u4VLD6Bz*=*2u7!(7BoMwLhc2{0!PySYZj*mYd zS&{6-h@?s=TZRNWL7D zj2b1UP?tBi%aDT>={9hid)c<8x1CQgn0bLreqmp`@D9V*G<3u1b?k z6rmL^-oKfY3sKZQ!bd7Lh(x>+e#!+n+9NFej_qv)U(a$y%Q4JIQ!Jhg*A3{3@Q>G{ zl^w+292Di8%bXg}O>%@J2%-2!U8_tS&nqb~E*3RU=(R2uc_!#Hd*lL(dD8^%U|zf+ zZIi4po>yL=xka>Xg8o`{$8^!YsQ9O?<)lsPuC=kSZj#wUKoc*`Oh9KH-Rl`Gz4?Na}CHMj0=9X!V$=FHI!dn@L)0A{%7n zYM<~7g<%~;3yRN!KzhZ4TU^)`OQxn~@D%Wg;6Wysa`xSIO{0c`*s3GzeZmtO+<3K)tkT5bUCW|Z`Gn8P$f`lfmu=*lwu!x- z+utzqN3Q?ybZY-==Xx`@@($<_qL2g0@M#KJ`KH3d!fFa38PXvUK2xDCLbLUYsChDk zM2JEHgxeoL@+HqTDntkgkc>HX%TP5EKR41FD?o^cM2JAR{XDe(jw3@t9zp~XIPjTI z;7}mL%%oaAfgA|&beQqFV;$usz?l z*yuoeukRB^1=@Rfg^L$p@1wh4*AoHuHoUpNEwFlzl~fr6tM}?_*=d2*d+2UOLSXew zF9+MZ?~`DAr&k8sixmd%H)UsAtSm6!lk!f27m&!Ft6kZ;K!-0a^9eF5yE6+2$OqjLzuKmLQIZ0>CwR^3vOx;o)?-3s+J@Nn#{@H(GzHL(d%w)eW1y&&&o(?KqY5pU-iXc zU?o5L%Am{(tmM-Bn_OgICI2UcN$2?xCe4K*Oiq3v*yKY~Cz5f2CjUP53mq?ToqTG6 zvM|tOWlIx(LZHd#;%X@(;1t~ZUn!EQft9Sdc?T^B816-D9?*IMEBTTpgvsls2b=up zy$~kv-4w#)%6OY8&}4LNfn-*o$+}zn=%@gb&mWkO(-CO$vph99IiQly@4Na?cVH!7 zKi{Y<3asS*4i7&ju#$5o2b-f2x>BCI*PRNGby>X`lOrBdC&(-k`Jw zR`QMyoA{EzO5UFoY_jB6slg`SUJ=5iC4|X;zx420Y<%ohuJx0?UWAR0xqQ9zhonC; zi_MtJ*G20dnkB|7;@HdPhhBO3Ap+NI?Dg<<*{^2ei{G)=;dj2fezg!G@kP7-{9mo_ zArW)M&3rLe&V`EKBq5}^_=>%vbP=CoF5&%O_)@&4%DNIE9WmDAiq~WJ#^UI)CLeqA zmA?@P2`@4!o2Q+JQ1?Y9cOJaqs1PCfMJ7)j82lI^2mhk_yrh_<>4-A=II{z?f%AWvBW+MuEHjP4d5z5aGMf5Z5PVua_4^ zA>{=)4^7P$CH(t2g!*^{LQ;G`Gr5&YoVeK2JYo8cy95Sai%#8QnAnVDMf&?baXmdK&pF{E6UK;$6w}P5=;z~GM9uvC{`zLSmApqpt}S$1tz?6hzq;OS zrQC`A6D7wzk|B+>S|dF;A{jACvSd3{pR1JvYE`c}d70(`e5cB1$m!(gZFl80@^kw~ z5>}hYd& zxgB2d?h3cH$ZfN_?B!&=3E%#-x49GF`<3(rlAbb34tXR8bBM9qU}$es6AvA>U`n`Jp%OkRi9po7>@S-Wly)+v;8;ciTnyR;8_# zUV2a_Io>TfY?OSfN$vSSpL}R&%D|B7b0T?JVPbBFE5FmFX!L5E4LSYZysn;{K5y}4hp{QA}b4f$RCoEBG3f6;cUm0VL#Hi*cE zLiZZ0J8}0s=_!rm_=w~nLTD|8(2EEmTXG>~Xc8w*uw@ofsdV$h55K>lO_o-cmyubp z(WBlnpqa^A6c^`>pCtS*SqvYoDJ?hKit8#GEXJD*)MCqOOJ!-Lxq3OZ+G4A#sHvth zveGr_8Y->aZnxIWpEGB9g}r=LNqT8b)tn{fGD}VMZB|RQp{B~Z%5IrcS-HAuw$)a% z!cuCdY&E5`=cZ?+YpBeO%$#efWtlm1vzBFR=g!TVSC(0tkuf*Rd~LR6ZicojGiO<8 zc1c!N#$3(3%o!Ec_L{2d3TutsQf;TO!CzHWSt~7+VUFRNYp6R4^!jC$HD-HR&8m`0 zOL?8Sc%{WwZK*6StEjV@?WN_#R&(h}i@mt6uDr%(udJzFZmwQSS!*h)?G_tuUstib z+FV&&htM-?{VRaRiC)|(;&o3>73uD7rciZ;6K&rq;>o*riV9}-;Y~mO(T6wt@MZvS z4&coo-W=NpeLM!c!)7Jew^KO+24>{4{~%ysGR%tBe>Mfjnc@n$jJ>{Tw#y7vL) zVw>rmo@r#Ovdq@@T+cJY%}MqZgU|KMBTYKetRYK`WU0n$B3E?m)eO){)-3VrHsNQl z5R|Zl`B&};^N$jibl){Osf||n|5;^8ZA4NV@#ijD(N}14W|ow`wz96`p41mguc^x@ zjmdpdu~5#>Ec`O|<^HPXy@#us_wIaoadrCFd#hU7#7oS$^-ZbQ)NN5r(=g|P=i~_C z*KuN^ZE5Z`4zi)h+G@z}@@hK_IXIYC-<6_ zzIJ7w{{ST3mmxG2zfRVlx7C%qnV;Lj&u??(wHUM>FJbi@FlYx1IjycdysUclHt9E% z^ee06)QIFXl9CIxn+-W#{CuTm$SOHJC>iw$zZt~$)`v7w(uZFec+tzQ6`ei!7JH{& z+wY}MN=_IhRWix9R`l7}{tXKJGJ`?U=*>fB3ZXcB4_CmS;K;aHCcXpld;)K5k4|{C zRr8swthY;VlM$4hr)i!p=qY!i{!9_A*`z@V!47;A5FvE8kTSx_mt%h#IX|NP=_2gU z7-4@@LWJv=5aIeo*q<@N{;t&T=N)fXj^CeIzX%^6SB~GGHy;1r)i1*5w=3u0EAx-} z;&{DwW&Qh!`I5w3KO1J@7vu45B*8WU%|*Lr=`}_|q>Epw>9!h-o*ui)BseIXP#|!a zMBSQAGQpM!Jk7h5ptrmiRmx=ii>{^B83^HP>X`LN^H7ulHM4iT@OWMB{<=-#Uvz)S z{LAL2Gyk&rIgGz-e%;1jHa~RxFPoq4_Fp!?E8~YEy1xA4<9B8KjO_SbSwACM&nxR^ zWb1k5{1@4JURggQThA-&XN2qdGp`5n=aJ$ChG_aPNV9B$laAJ*O%#9U?~FR3Lp85Z zoPq=fwEq1@N`8yXahFgnyCGgk@po2Vcn-FnGU4Z5_&FGUeIsT)qToFP`1#kR`xlmu zZ~vlk_8&dY{u>L%w?6Aj#<%~barVDqoc&*^|4*+Ek*R;O~|M2}sX8j4@e`MC5@cobX{J&Cvy#9=L{oUid{)F#8GV4$H{v)&ggzrBx z>reRpBQyWQ_aB+{Cw%{rS%1R!ADQ(heE(PKkJq2@^@&XVcm4$}V4iXD(6xoXi3@dDH1dQs#1jYOa0Rs4=J=KW;rDJz%fU*Y?A^vK7xf7Cn9 z{(HySfBQK5Zy#s>Pmi@r2f0ddHo6Be`MC5@cl<-{R!WHWY(YX z{YPg0hwnc!>reRpBeVX5?|mgMyeZ;)Z*0F>u!ioLGONL`ao6)cV&XOtX;#s# zgFL4+V=mu#Kz1NXRO}m${linT<3e=?`x>2f>}zz8gPvGN$v5kj0>#C-F??%}0yyhA z4=et?hbcGqd2Sa9c(@-+y`@ya%SWlgg$Yg*ik2@bLylN~f2oY;g-lJck9|o*8l8gB zNsa&goAClFJzMlhGS1&SwMr-3GSRc%nB;9D@Zh}IGw)_yvs`wY$k~&0?`wh`W3add z_3A`7=*Vs%k-_A1_j-rVHQTEj?s>eO8vXJ0$@mTW3!g2=b<~IPlqzgYaGE;E<(kDG z2%|ERhzikPGGOn``jr!UTWH&i!i}HDMbB^@UV`L58qz0qKJFou&?9iRQ=bntEmt8AhInua*-k zq-ccCy#>~NUYLD_to!KWGHv3lCwSfEv*~S_#`LInGH%w5K5ku%A*z_Ie;>n}80a&x zEs5`-C9bGO>zv%h(-?Vau4zsB!5Ce%L#u9$5hPYqyo^31yHU`g*NNGB3R5~>Heo(?_)wLJ#B5=SyTM zjp+MW`N*(?w?(40e&rCnD=8ZJ+9CW_Qg1!y5Irev9r?}??Ui;}|LqWeB^?+!?~pty zJ7ianPDJ`9{C0=XGrKi3dUu3(Ep6*JhKXntG_Etix?G}3Gw+UURis)G{ zOd7B)rPB{iI(Mp=);u{`Ubm9Y{BW}N+iH65gHwv?8t82Al=^S)qqW~pX{~#Jp8uM> z>)U2J?<@I0T`PUvmc(=4K1u67O_baJOcy+wq#b>pUa&icsy{tU z7d@NYdipE+H=iVTojyS?bze1b`hV%)y?NES(?8IrqgTn#oTp2kptNVAon^bJqBF71 zWoN1SGjeD7v+C9}3g?PX)Lmziot5sCfip9l)o-SpJ2T5^J(?mvJI85zB2{~KzSF)t zwdicVbM@KO`m=@3`e&!Mp1s-G@X6G!v$s0$bx#{OdzbV6H>aIDyVU6*j!u)yM~a-s}J$%9ASef7HI2AG9r_(J6R42+{toK;D2Q~DDVR; z2L(QX@82&d@NAZY0xxAbC~yzUL4m)`a!}yMSPlw&ioov&6!-#`g95K*IVkX_SPlyO zU6z9aKgV)V;M1b~en5dQVL2%9-?JPP_=_wD1^zFVg8~-{{rd$4p2c!d;3lDeK7#_^ z$Z}BN+gT0@{20qYfy+sM|DeEcU^ysoJIg_VKh1Jb;69dv0{;)oL4nT_`Tc+b{|(DQ zfj5c#^#ld}I?F+Uf5CE4;L-{Hen5feu^be5192EFJSq=(( zgyo>X)iVEnL4hw}IVf-!%Rzy6vK$onVU~jepFGjOUr^vTvK$n6Jg96{ka!}y^Ohn8* zA5h=}EC&UClI5Vlqm%sofC5itIVkWvmV*MnGs*vXZBXDfNeIOwbU`r}%|gh@wg-jw zo7wiD(EhJUh`H|q3hm!z`vC=hI0+#ILKhS`hRiQ4WEHqlfshKJ3kp1kZ3hbcP6a}f z5W1k48T93Vf~7pBE_bM_3LD`~{YS0^gzZe+~r{IR5&r z7@^xp=AJ`!145{{xS?`YnYq|%uP?^G?5?=jY_F**wpl98^_H^YQoFf$xux1-t0*n5 zs4flD+FV+@s={XB*xF{Pw3zF-wl&neGOBu(xammKunQPBDfLmkmj0feCHpb44OEw12 zTzh6bb7~Bp@t|DN26*Pw7(8=oJa%qe(gt{Dj&An+Qf9fF^Go>F;F@#aID%tq@XZ{w zoPBd^4ZgXxhH`Uj4ZfKJm$Pqft-&|9*5I34Yw*pCrGMQEKkzWOIW-2?oEn2^PL07c zGcsH~b7~BpIW-2)oEn2?W>mO(=F}KGb7~BpIW-2)cu?2{huhq^WMlBmwPzf_tuc7U zgK|k5FrT?LW^CZnU9xAe!OUlNIWU)%6)(%myo^iypN<$?M9%wQQsC|i@?hrpW#z#= z0`lM<0kfEU1mwZY`^(CMdj#acJp%IJ9szkUBj{g7etMF?6z3cOS#S=3Bsd2^4$SCX zP7a&{AP3F?kOSuc$blKT%gKRr0OY_q0CM0Q06E}+V;3~;bBR*`dAP?>lkO%h&$b%X4G3#vDNdm_==K#op za{wg4IRJ8CM(=WR;2Z!sa1MYRI0rxu%*b6%4x9rZ2hIVI1Lpw90T0~Rb(Ygy;uHWm zT+SYF5cdGc0S_D}OK?1J4uBju2f#RR4uBl+z+F}joC6>S&H<1Emw=LGmf}^Rp2jf` zgqlqFfx2)Jb)h3Fic3OAR1}wljIdc&;}HCZn$k-u>^2MjJ5QGMijci9jkD(lH4fFY zxzt{<+EQh%TN!*Nud`W}7uze!!aPzP{-IL4IegepJy=> zdSy*8E3X*K>PnU|gEE$_3c=RY*j6sL)vU6nSCpmKk8OR?BWuGx(h&Gah22sWvNL!f zs4sZH-<8c`E-SWKmRHo-Ew=QE?799wAhx{1&iN03`A2NFxvC8RFRwFS{L*N8X{G<& G)c+4?XI`TK diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp new file mode 100644 index 0000000..937bda1 --- /dev/null +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -0,0 +1,18 @@ +// Every time this file changes it must te rebuilt, you need llvm-17: +// /opt/rocm/llvm/bin/clang -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && llvm-dis-17 zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | llvm-as-17 - -o zluda_ptx_impl.bc && llvm-dis-17 zluda_ptx_impl.bc + +#include +#include + +#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_ ## NAME + +extern "C" { + uint32_t FUNC(activemask)() { + return __builtin_amdgcn_read_exec_lo(); + } + + size_t __ockl_get_local_size(uint32_t) __device__; + uint32_t FUNC(sreg_ntid)(uint8_t member) { + return (uint32_t)__ockl_get_local_size(member); + } +} diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs index 04c8831..6e0beab 100644 --- a/ptx/src/pass/deparamize_functions.rs +++ b/ptx/src/pass/deparamize_functions.rs @@ -94,7 +94,7 @@ fn run_method<'input>( .body .map(|statements| { for statement in statements { - run_statement(&remap_returns, &mut body, statement)?; + run_statement(resolver, &remap_returns, &mut body, statement)?; } Ok::<_, TranslateError>(body) }) @@ -110,6 +110,7 @@ fn run_method<'input>( } fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>, result: &mut Vec, SpirvWord>>, statement: Statement, SpirvWord>, @@ -133,6 +134,66 @@ fn run_statement<'input>( } result.push(statement); } + Statement::Instruction(ast::Instruction::Call { + mut data, + mut arguments, + }) => { + let mut post_st = Vec::new(); + for ((type_, space), ident) in data + .input_arguments + .iter_mut() + .zip(arguments.input_arguments.iter_mut()) + { + if *space == ptx_parser::StateSpace::Param { + *space = ptx_parser::StateSpace::Reg; + let old_name = *ident; + *ident = resolver + .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); + result.push(Statement::Instruction(ast::Instruction::Ld { + data: ast::LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::LdCacheOperator::Cached, + typ: type_.clone(), + non_coherent: false, + }, + arguments: ast::LdArgs { + dst: *ident, + src: old_name, + }, + })); + } + } + for ((type_, space), ident) in data + .return_arguments + .iter_mut() + .zip(arguments.return_arguments.iter_mut()) + { + if *space == ptx_parser::StateSpace::Param { + *space = ptx_parser::StateSpace::Reg; + let old_name = *ident; + *ident = resolver + .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); + post_st.push(Statement::Instruction(ast::Instruction::St { + data: ast::StData { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::StCacheOperator::Writethrough, + typ: type_.clone(), + }, + arguments: ast::StArgs { + src1: old_name, + src2: *ident, + }, + })); + } + } + result.push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })); + result.extend(post_st.into_iter()); + } statement => { result.push(statement); } diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs index 97f6356..3553139 100644 --- a/ptx/src/pass/fix_special_registers2.rs +++ b/ptx/src/pass/fix_special_registers2.rs @@ -31,10 +31,10 @@ pub(super) fn run<'a, 'input>( sreg_to_function, result: Vec::new(), }; - directives - .into_iter() - .map(|directive| run_directive(&mut visitor, directive)) - .collect::, _>>() + for directive in directives.into_iter() { + result.push(run_directive(&mut visitor, directive)?); + } + Ok(result) } fn run_directive<'a, 'input>( diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs index 753172a..718c052 100644 --- a/ptx/src/pass/hoist_globals.rs +++ b/ptx/src/pass/hoist_globals.rs @@ -5,7 +5,7 @@ pub(super) fn run<'input>( ) -> Result, SpirvWord>>, TranslateError> { let mut result = Vec::with_capacity(directives.len()); for mut directive in directives.into_iter() { - run_directive(&mut result, &mut directive); + run_directive(&mut result, &mut directive)?; result.push(directive); } Ok(result) diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 0e233ed..7ba9ed0 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -39,9 +39,8 @@ mod normalize_predicates; mod normalize_predicates2; mod resolve_function_pointers; -static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); -static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); -const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__"; +static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); +const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_"; pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result { let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1)); @@ -220,6 +219,12 @@ pub struct Module { pub kernel_info: HashMap, } +impl Module { + pub fn linked_bitcode(&self) -> &[u8] { + ZLUDA_PTX_IMPL + } +} + struct GlobalStringIdResolver<'input> { current_id: SpirvWord, variables: HashMap, SpirvWord>, @@ -1975,7 +1980,7 @@ impl SpecialRegistersMap2 { let name = ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None)); let return_type = sreg.get_function_return_type(); - let input_type = sreg.get_function_return_type(); + let input_type = sreg.get_function_input_type(); ( sreg, ast::MethodDeclaration { @@ -1988,14 +1993,17 @@ impl SpecialRegistersMap2 { array_init: Vec::new(), }], name: name, - input_arguments: vec![ast::Variable { - align: None, - v_type: input_type.into(), - state_space: ast::StateSpace::Reg, - name: resolver - .register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))), - array_init: Vec::new(), - }], + input_arguments: input_type + .into_iter() + .map(|type_| ast::Variable { + align: None, + v_type: type_.into(), + state_space: ast::StateSpace::Reg, + name: resolver + .register_unnamed(Some((type_.into(), ast::StateSpace::Reg))), + array_init: Vec::new(), + }) + .collect::>(), shared_mem: None, }, ) diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index e15d6ea..60f5052 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -326,6 +326,7 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def let elf_module = comgr::compile_bitcode( unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }, &*module.llvm_ir, + module.linked_bitcode(), ) .unwrap(); let mut module = ptr::null_mut();