From 9ca1c2da5a1fcbcaab059ee190b74d90e6575007 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 5 Dec 2024 05:43:20 +0100 Subject: [PATCH] Resolve crashes --- ptx/lib/zluda_ptx_impl.bc | Bin 5360 -> 7524 bytes ptx/lib/zluda_ptx_impl.cpp | 11 ++++++ ptx/src/pass/insert_explicit_load_store.rs | 42 +++++++++++++++++++++ ptx/src/pass/mod.rs | 4 +- ptx/src/pass/replace_known_functions.rs | 38 +++++++++++++++++++ ptx/src/test/spirv_run/mod.rs | 6 +-- zluda/src/impl/memory.rs | 4 ++ zluda/src/impl/mod.rs | 1 + zluda/src/lib.rs | 1 + 9 files changed, 103 insertions(+), 4 deletions(-) create mode 100644 ptx/src/pass/replace_known_functions.rs diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 24c20d8b1bd94be0ef590b36498b2ddb325a1a31..6cefc813eb60ce7f302e265090ab069682291723 100644 GIT binary patch delta 4634 zcmbuCe^4Cd8ONXf!Lc06?s9iYjvRzNLMTMU3mQBE=-wTPB%XQ%;%K#Vcfb*ekX!&A z>yO+XcMyW8mx3+EXdz-uqmHC74t4B!7n1av(iE*aNwpnfh>fi=7)k1nPSbhzSTs)i zPx}rt-{pBf&+|Ud^Sac z<}Qf*d4czHFrFJ=19votoAQu ztW(q)m=pO+E)G`la{_l}glhr*;w(4k;?79?1(9zNxl11IV;6UEO7lg83w-V3U>TUY z%~0i~Yz!1Mi#>$cnHILe@W|=Xe_gq9B>=Y*v(z*XDd$4b#a(H%4jvr|R?(hQ)2YX? z3_iLP{>K6W|91sy0WV?!F(VeRhr$9vFgy}`NGBUgga{cq+|}NlDCT{zBt?0!Lhc|> zlJ`L0&E$^U*(xmLX1E{rlGDjDJ!kWzMm7F{Z(XFEzBJ-7F=R`E0;n4?XB|Y3Y^Z^M zg1pcj=-WZ<2n>WSCAJrLcLziXk$6msI7-JS%Hoc)FjY1dER8N~*i3}~OyJH(xHD1B zx6dGaHp{h0{6&GA1^$z1?ypmt^8z;~b5Id?RPJB(6212F%q4Kd<`Ecl_eo9V`YCirNQzGu7rVVg2)?U(0UH&%9bTNI##eFu*U7X@h zN4XXcH@mQ2=}f3B>?j+au#OPlK0~6QyIevM^OKNR>$ZMHhb- zgfAriw1xZolx9}sFG&1Z3pa-gXg+fBAA7j7BHv=+KN;lD!Za~rKHewLivz-+FehFf zi=9+^A1M(sAwt4>BS1C)=sP@iGh69(8oAtOgnYf$Cu-{t?_(>yh0ZhOQ(&PLta#iqvX5WCzNmU#58lRYFcuz7M3rF_|_v*FHqPE^1xs4f1iIprH zI;q)b$t(OoG8{X6!wKp(;dL*0D6Mjpc{-2%KpO3!UnT++3tOw2yMxx@Eumo`Mdc42 z#|Y)A@+JR12q6Fv&_Z~UqmeC!!X!o}R4mHL=1oP-sRGbI%V?$LEO4xRc-AO_78pR1 zfimc-Ego7efe{2Cm4FWd`-_uxfj020QG>y!VBGjzR}w)ek(61Oc22bmP348YH?s|fly(+Wg%ReRM6l}FJuxE>I z&Co#2LW6KOHw2zLNHUP3VoK9)ejo#!dex+C=kq;P0 zV9ZE&MtW27ZaGFfw96s4mw+@@vM*P>4jBR70)XooIrGTr%tixI(e1E?Cq^@wKa$2q z>t>ldD{!*{_i2>7VBx-q;9DL>FKA;#AeR8p)9VZI5EcPI!v@*)e`R%lQm362Ym9~srLwchm46wEkJ^TYsd#^KXE7(%h%4+v zV{ya!XAIg-J15qxJi4*Su)!)jWBtyz`_npec(KN6&{oLKz$XU2^CBnK++^q}lbyuV zLMO?IHI=^FD#P(z&Vohy&N@!4xjkvqMuWCYc2+;V$y=RYRQ%%Fq5QaE$q<=yJt-Z) zwtsE`$N+$xTjG&wj$VeH`9nPc+gi3Z`M0(EwziA)_jy}unp*tLKGAW_HDayZW(zbo zKiE>#yltDmsO5pSc7N-(&8@z+Ha%U*-pTA7wM%;!**|(kFl=^6&Bm^yCoPuVL*{(A zKKMfD0;^;7Zmc5G!ek@9=9UG?jiWITf-D93COkU$MiS zDC|dczXd9owFRsxAOsb*8d(>z{s^*fA}bCf`y2Rb|AY0 z!*0sP!juf&L)3++{}Q5eHkcDVk0SaOCYzz4R*ZTKQ327B&k>zP)c**gk6_fmbIUMn zA+j$aJD7~wo<(*BS(!i}cAY`u?Z#;|PoudzmW5TWVI=;r@k2BpaKM}xQOvO{R^sg$ zC(#^1^FUCcvyjes7tLc>-JmT?slP~&c>7Yt{w@YTKp<|jBEEDA@plm)QOtxA+%$>i zNuiN)k-f~sR7m)eNHMcF(iTE9${rDBnvrPe+mx1z_>zp3t#WR=4Pvb!wK2mBnsRs>JU~L*M6Vs{=n|I%31C zcdokC!}c4yCBbeZOo^)SO{b^VDm*r;VI3TQ%JtRmf+((^geRs_U~kL}RX+_nqWo+z)bM7;JI%q)wm~Y0S(-662meXZvkd`nj~&37vsj)YYfumj=Gqm9>ek zU{(>0Hn&4RutK~j1=sn>C7F9ud-NU8t`IdVg*tyYs0|x4Pm^8ObYbb!wT|0a@3lt=fK2Ui60cA@i%U@x1nccU(3N zrJmgW#{5^>#{!%J03Fpoj*HU@5?Ys`gjr;~oJlsJcK3kyuz& zm0wN(ECN_imAB&m%m53j@-=v|0SKrM3l5D0KqkO~stz6V3#h6t@G}Cy2(X|k{~D1v zN2;p)ujl!M1`-Qni{6``P*n|dem;<@YM^iCc~w9! X0}Woczx^RMUKQPHGMG!>u)+Cn6_+SA delta 2442 zcmcIlU2Gdw8a;PrJTtb&VeB}`B(d9ZV}cV}kw$77@xV^vHt_}vf zR{OH^az^Jn-~Gq z4U5Q`SINUf&kyo@6Ng*=O^bMq#2w~JxIdDTYtURAx~OHpo_Qdic{rDONR}Sj&hIia zUE4G9Dd|BR^{@QW?>*`F9Z5Pbyc7M)uYD)Y6-$2aB3W^wN7U;|Kq>6yn+A0f4Udb zGbQgosJE5oIz-4>9yPF^;qyQ6k9j(d)50aiCcK36L??S^{5$f`+v;!A7Fgkgc#0qj z{9OZF0BTdQUOl3Uccccj+H9#iXWua7Ju0kPTeG!yttPz?(Hy5I2Lnfv!P?Ze-eN?4 z#-6$}qI*vpt-^=t(%KpnqT;J0(ZavwpPk(wI@I;K1Ct!sa|8C>@N&#&KxPbHhVsfD ze+aRrDm6&7=vRlQK@3!s2Bx@<>8RaF%)6^9!xjr~tbdvlKtOa#6Vp-5wO`RCB=91s z>gr5?-MtZq3IT#iSpn?`9r}4u#iUqOQD-gDHUlbDP{V53SzNdkvG%HpIxFojfW&_1 z?5ElT#Qe6ZG7`2}c-!TlbGW{0gIYVb`8k^m9-Kj?T-n)sF$EE++S-m#`!l5`OsZ9i zIvYOSZ6NSFXXnI65sOw;)Y<$s^`!VFfGri89@;I>4{)4WwLuTf);rD%2G3eehdC!E z-~WaOZq`Aqbgr26G`b?W1NPSbBSOg0g1`X+r_u>vRs-HHY;*jRE8OWUkp$J(3O%c- zUpi&$4xmM}si|0eYe@u#bih{8h=D%EahXsuN-0Gi#$VTF|4Z%QpKGnvMaM5Hy2}GM zmQB%bxgkH~X-)i?*0t4NMAd>YXGlHr`-^CDoz)Kj>tQ5*^fY%5Ck+sw1B1O+>FOTlOXstv{nA(Gzl)PBCI&rO> zDy7zC)?OV`Ce*B<_Vy6V1dV^KvB|z5axDyi*Rv z2gVD+TsBFb(6#QdrkO@ zo|Mn+i```(O^)XG)cX>8qhqHIC^UsN>Z&yNsu(uzF~xVqtUhcy=EbA9*Q{9*&&TgK zU1L&kcEt3|Nt=G$TS@aa1$YC1o1ci&)X|W;IDr1nBUVB;S;42M0p;t7CbxkD5;O`y z4&?-eE#b_C`>WhQpEdOV?;@lZsv1DKhem+*Qlz}R8n_)O9~VQWf0!O2z%WrR*aN^Q zP>uvc)-VmA4yYpKss3%t5JRT_)EdB$>Hh`53skY9PXIir3}DC_G*a0;LpGp;05w1r znf@wE&yeYFw0MS$zmKky04gum(b1jlu^Je%h6#&j$oSVRLkuZT*Lla%Gi3S`r1EJf zuAz};h3~DCU`ToTE$Nj)bAr69so&31awP{it;B3rfTt@ld`ZafZt5rmn?E6W%@_F% DIYdlo diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 329a810..7af9729 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -4,6 +4,7 @@ #include #include +#include #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME @@ -155,4 +156,14 @@ extern "C" __builtin_amdgcn_fence(__ATOMIC_SEQ_CST, "workgroup"); __builtin_amdgcn_s_barrier(); } + + void FUNC(__assertfail)(uint64_t message, + uint64_t file, + uint32_t line, + uint64_t function, + uint64_t char_size) + { + (void)char_size; + __assert_fail((const char *)message, (const char *)file, line, (const char *)function); + } } diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index 60c4a14..702f733 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -122,6 +122,13 @@ fn run_statement<'a, 'input>( result.push(Statement::Instruction(instruction)); result.extend(visitor.post.drain(..).map(Statement::Instruction)); } + Statement::PtrAccess(ptr_access) => { + let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?); + let statement = statement.visit_map(visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(statement); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } s => { let new_statement = s.visit_map(visitor)?; result.extend(visitor.pre.drain(..).map(Statement::Instruction)); @@ -259,6 +266,41 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { Ok(ast::Instruction::Ld { data, arguments }) } + fn visit_ptr_access( + &mut self, + ptr_access: PtrAccess, + ) -> Result, TranslateError> { + let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) { + Some(RemapAction::LDStSpaceChange { + old_space, + new_space, + name, + }) => (*old_space, *new_space, *name), + Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access), + }; + if ptr_access.state_space != old_space { + return Err(error_mismatched_type()); + } + // Propagate space changes in dst + let new_dst = self + .resolver + .register_unnamed(Some((ptr_access.underlying_type.clone(), new_space))); + self.variables.insert( + ptr_access.dst, + RemapAction::LDStSpaceChange { + old_space, + new_space, + name: new_dst, + }, + ); + Ok(PtrAccess { + ptr_src: name, + dst: new_dst, + state_space: new_space, + ..ptr_access + }) + } + fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { let old_space = match var.state_space { space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index ef131b4..c32cc39 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -22,6 +22,7 @@ mod normalize_identifiers2; mod normalize_predicates2; mod replace_instructions_with_function_calls; mod resolve_function_pointers; +mod replace_known_functions; static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_"; @@ -42,9 +43,10 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result>, ptx_parser::ParsedOperand>> = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?; let directives = expand_operands::run(&mut flat_resolver, directives)?; let directives = deparamize_functions::run(&mut flat_resolver, directives)?; let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; diff --git a/ptx/src/pass/replace_known_functions.rs b/ptx/src/pass/replace_known_functions.rs new file mode 100644 index 0000000..56bb7e6 --- /dev/null +++ b/ptx/src/pass/replace_known_functions.rs @@ -0,0 +1,38 @@ +use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord}; + +pub(crate) fn run<'input>( + resolver: &GlobalStringIdentResolver2<'input>, + mut directives: Vec>, +) -> Vec> { + for directive in directives.iter_mut() { + match directive { + NormalizedDirective2::Method(func) => { + func.import_as = + replace_with_ptx_impl(resolver, &func.func_decl.name, func.import_as.take()); + } + _ => {} + } + } + directives +} + +fn replace_with_ptx_impl<'input>( + resolver: &GlobalStringIdentResolver2<'input>, + fn_name: &ptx_parser::MethodName<'input, SpirvWord>, + name: Option, +) -> Option { + let known_names = ["__assertfail"]; + match name { + Some(name) if known_names.contains(&&*name) => Some(format!("__zluda_ptx_impl_{}", name)), + Some(name) => Some(name), + None => match fn_name { + ptx_parser::MethodName::Func(name) => match resolver.ident_map.get(name) { + Some(super::IdentEntry { + name: Some(name), .. + }) => Some(format!("__zluda_ptx_impl_{}", name)), + _ => None, + }, + ptx_parser::MethodName::Kernel(..) => None, + }, + } +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f4b7921..e4171cd 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -298,7 +298,7 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def let mut result = vec![0u8.into(); output.len()]; { let dev = 0; - let mut stream = ptr::null_mut(); + let mut stream = unsafe { mem::zeroed() }; unsafe { hipStreamCreate(&mut stream) }.unwrap(); let mut dev_props = unsafe { mem::zeroed() }; unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap(); @@ -308,9 +308,9 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def module.linked_bitcode(), ) .unwrap(); - let mut module = ptr::null_mut(); + let mut module = unsafe { mem::zeroed() }; unsafe { hipModuleLoadData(&mut module, elf_module.as_ptr() as _) }.unwrap(); - let mut kernel = ptr::null_mut(); + let mut kernel = unsafe { mem::zeroed() }; unsafe { hipModuleGetFunction(&mut kernel, module, name.as_ptr()) }.unwrap(); let mut inp_b = ptr::null_mut(); unsafe { hipMalloc(&mut inp_b, input.len() * mem::size_of::()) }.unwrap(); diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs index 33d5a4e..18e58e7 100644 --- a/zluda/src/impl/memory.rs +++ b/zluda/src/impl/memory.rs @@ -38,3 +38,7 @@ pub(crate) fn get_address_range_v2( pub(crate) fn set_d32_v2(dst: hipDeviceptr_t, ui: ::core::ffi::c_uint, n: usize) -> hipError_t { unsafe { hipMemsetD32(dst, mem::transmute(ui), n) } } + +pub(crate) fn set_d8_v2(dst: hipDeviceptr_t, value: ::core::ffi::c_uchar, n: usize) -> hipError_t { + unsafe { hipMemsetD8(dst, value, n) } +} diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index 766b4a5..282f8d5 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -107,6 +107,7 @@ from_cuda_nop!( *const ::core::ffi::c_char, *mut ::core::ffi::c_void, *mut *mut ::core::ffi::c_void, + u8, i32, u32, usize, diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs index 1f6a7ff..8efbd26 100644 --- a/zluda/src/lib.rs +++ b/zluda/src/lib.rs @@ -73,6 +73,7 @@ cuda_base::cuda_function_declarations!( cuPointerGetAttribute, cuMemGetAddressRange_v2, cuMemsetD32_v2, + cuMemsetD8_v2 ], implemented_in_function <= [ cuLaunchKernel,