Handle WARP_SZ (#412)

* Add tests for `WARP_SZ`

* Handle WARP_SZ in parser
This commit is contained in:
Violet
2025-07-16 11:02:17 -07:00
committed by GitHub
parent 06b28cfec7
commit 6fb09f393a
4 changed files with 49 additions and 0 deletions

View File

@ -0,0 +1,17 @@
define amdgpu_kernel void @warp_sz(ptr addrspace(4) byref(i64) %"29", ptr addrspace(4) byref(i64) %"30") #0 {
%"31" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"28"
"28": ; preds = %1
%"32" = load i64, ptr addrspace(4) %"30", align 4
store i64 %"32", ptr addrspace(5) %"31", align 4
%"33" = load i64, ptr addrspace(5) %"31", align 4
%"34" = inttoptr i64 %"33" to ptr
store i8 32, ptr %"34", align 1
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -295,6 +295,7 @@ test_ptx!(
[1.0000001, 1.0f32]
);
test_ptx!(multiple_return, [5u64], [6u64, 7u64]);
test_ptx!(warp_sz, [0u8], [32u8]);
test_ptx!(assertfail);
// TODO: not yet supported

View File

@ -0,0 +1,16 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry warp_sz(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 out_addr;
ld.param.u64 out_addr, [output];
st.u8 [out_addr], WARP_SZ;
ret;
}

View File

@ -285,11 +285,24 @@ fn u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u32> {
.parse_next(stream)
}
fn constant<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::ImmediateValue> {
// Currently the only built-in constant is WARP_SZ
// If new ones are added, we can change this to use a Token::Constant(&str) instead
any.verify_map(|(t, _)| {
if t == Token::WarpSz {
Some(ast::ImmediateValue::U64(32))
} else {
None
}
}).parse_next(stream)
}
fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::ImmediateValue> {
alt((
int_immediate,
f32.map(ast::ImmediateValue::F32),
f64.map(ast::ImmediateValue::F64),
constant,
))
.parse_next(stream)
}
@ -1648,6 +1661,8 @@ derive_parser!(
Plus,
#[token("=")]
Eq,
#[token("WARP_SZ")]
WarpSz,
#[token(".version")]
DotVersion,
#[token(".loc")]