diff --git a/ptx/src/test/ll/warp_sz.ll b/ptx/src/test/ll/warp_sz.ll new file mode 100644 index 0000000..aac6b34 --- /dev/null +++ b/ptx/src/test/ll/warp_sz.ll @@ -0,0 +1,17 @@ +define amdgpu_kernel void @warp_sz(ptr addrspace(4) byref(i64) %"29", ptr addrspace(4) byref(i64) %"30") #0 { + %"31" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"28" + +"28": ; preds = %1 + %"32" = load i64, ptr addrspace(4) %"30", align 4 + store i64 %"32", ptr addrspace(5) %"31", align 4 + %"33" = load i64, ptr addrspace(5) %"31", align 4 + %"34" = inttoptr i64 %"33" to ptr + store i8 32, ptr %"34", align 1 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index bd78639..c029ec5 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -295,6 +295,7 @@ test_ptx!( [1.0000001, 1.0f32] ); test_ptx!(multiple_return, [5u64], [6u64, 7u64]); +test_ptx!(warp_sz, [0u8], [32u8]); test_ptx!(assertfail); // TODO: not yet supported diff --git a/ptx/src/test/spirv_run/warp_sz.ptx b/ptx/src/test/spirv_run/warp_sz.ptx new file mode 100644 index 0000000..641cda5 --- /dev/null +++ b/ptx/src/test/spirv_run/warp_sz.ptx @@ -0,0 +1,16 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry warp_sz( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 out_addr; + + ld.param.u64 out_addr, [output]; + st.u8 [out_addr], WARP_SZ; + + ret; +} diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index da14406..a377387 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -285,11 +285,24 @@ fn u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { .parse_next(stream) } +fn constant<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + // Currently the only built-in constant is WARP_SZ + // If new ones are added, we can change this to use a Token::Constant(&str) instead + any.verify_map(|(t, _)| { + if t == Token::WarpSz { + Some(ast::ImmediateValue::U64(32)) + } else { + None + } + }).parse_next(stream) +} + fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { alt(( int_immediate, f32.map(ast::ImmediateValue::F32), f64.map(ast::ImmediateValue::F64), + constant, )) .parse_next(stream) } @@ -1648,6 +1661,8 @@ derive_parser!( Plus, #[token("=")] Eq, + #[token("WARP_SZ")] + WarpSz, #[token(".version")] DotVersion, #[token(".loc")]