Only allow (.u32, .pred) for multiple return (#417)

This commit is contained in:
Violet
2025-07-16 17:03:28 -07:00
committed by GitHub
parent 7c6b95a8e3
commit 95d66df18e
4 changed files with 153 additions and 110 deletions

View File

@ -916,15 +916,26 @@ impl<'a> MethodEmitContext<'a> {
}; };
match &*arguments.return_arguments { match &*arguments.return_arguments {
[] => {} [] => {}
[name] => { [name] => self.resolver.register(*name, llvm_call),
self.resolver.register(*name, llvm_call) [b32, pred] => {
self.resolver.with_result(*b32, |name| unsafe {
LLVMBuildExtractValue(self.builder, llvm_call, 0, name)
});
self.resolver.with_result(*pred, |name| unsafe {
let extracted =
LLVMBuildExtractValue(self.builder, llvm_call, 1, LLVM_UNNAMED.as_ptr());
LLVMBuildTrunc(
self.builder,
extracted,
get_scalar_type(self.context, ast::ScalarType::Pred),
name,
)
});
} }
args => { _ => {
for (idx, arg) in args.iter().enumerate() { return Err(error_todo_msg(
self.resolver.with_result(*arg, |name| unsafe { "Only two return arguments (.b32, .pred) currently supported",
LLVMBuildExtractValue(self.builder, llvm_call, idx as u32, name) ))
});
}
} }
} }
Ok(()) Ok(())
@ -1062,37 +1073,49 @@ impl<'a> MethodEmitContext<'a> {
&mut self, &mut self,
values: Vec<(SpirvWord, ptx_parser::Type)>, values: Vec<(SpirvWord, ptx_parser::Type)>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
let loads = values.iter().map(|(value, type_)| { let loads = values
let value = self.resolver.value(*value)?; .iter()
let type_ = get_type(self.context, type_)?; .map(|(value, type_)| {
Ok(unsafe { let value = self.resolver.value(*value)?;
LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) let type_ = get_type(self.context, type_)?;
Ok(unsafe { LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) })
}) })
}).collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
match &*loads { match &*loads {
[] => unsafe { LLVMBuildRetVoid(self.builder) }, [] => unsafe { LLVMBuildRetVoid(self.builder) },
[value] => { [value] => unsafe { LLVMBuildRet(self.builder, *value) },
unsafe { LLVMBuildRet(self.builder, *value) }
}
_ => { _ => {
let struct_ty = check_multiple_return_types(values.iter().map(|(_, type_)| type_))?;
get_struct_type(self.context, values.iter().map(|(_, type_)| type_))?; let array_ty =
let struct_ = loads.into_iter().enumerate().fold( get_array_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?;
unsafe { LLVMGetPoison(struct_ty) }, let insert_b32 = unsafe {
|struct_, (idx, elem)| { LLVMBuildInsertValue(
unsafe { self.builder,
LLVMBuildInsertValue( LLVMGetPoison(array_ty),
self.builder, loads[0],
struct_, 0,
elem, LLVM_UNNAMED.as_ptr(),
idx as u32, )
LLVM_UNNAMED.as_ptr(), };
) let zext_pred = unsafe {
} LLVMBuildZExt(
}, self.builder,
); loads[1],
unsafe { LLVMBuildRet(self.builder, struct_) } get_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32))?,
LLVM_UNNAMED.as_ptr(),
)
};
let insert_pred = unsafe {
LLVMBuildInsertValue(
self.builder,
insert_b32,
zext_pred,
1,
LLVM_UNNAMED.as_ptr(),
)
};
unsafe { LLVMBuildRet(self.builder, insert_pred) }
} }
}; };
Ok(()) Ok(())
@ -2732,21 +2755,31 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
} }
} }
fn get_struct_type<'a>( fn get_array_type<'a>(
context: LLVMContextRef, context: LLVMContextRef,
return_args: impl ExactSizeIterator<Item = &'a ast::Type>, elem_type: &'a ast::Type,
count: u64,
) -> Result<LLVMTypeRef, TranslateError> { ) -> Result<LLVMTypeRef, TranslateError> {
let mut types = return_args let elem_type = get_type(context, elem_type)?;
.map(|type_| get_type(context, type_)) Ok(unsafe { LLVMArrayType2(elem_type, count) })
.collect::<Result<Vec<_>, _>>()?; }
Ok(unsafe {
LLVMStructTypeInContext( fn check_multiple_return_types<'a>(
context, mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
types.as_mut_ptr(), ) -> Result<(), TranslateError> {
types.len() as u32, let err_msg = "Only (.b32, .pred) multiple return types are supported";
false as i32,
) let first = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?;
}) let second = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?;
match (first, second) {
(ast::Type::Scalar(first), ast::Type::Scalar(second)) => {
if first.size_of() != 4 || second.size_of() != 1 {
return Err(error_todo_msg(err_msg));
}
}
_ => return Err(error_todo_msg(err_msg)),
}
Ok(())
} }
fn get_function_type<'a>( fn get_function_type<'a>(
@ -2758,7 +2791,10 @@ fn get_function_type<'a>(
let return_type = match return_args.len() { let return_type = match return_args.len() {
0 => unsafe { LLVMVoidTypeInContext(context) }, 0 => unsafe { LLVMVoidTypeInContext(context) },
1 => get_type(context, &return_args.next().unwrap())?, 1 => get_type(context, &return_args.next().unwrap())?,
_ => get_struct_type(context, return_args)?, _ => {
check_multiple_return_types(return_args)?;
get_array_type(context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?
},
}; };
Ok(unsafe { Ok(unsafe {

View File

@ -1,61 +1,68 @@
define { i64, i64 } @do_something(i64 %"10") #0 { define [2 x i32] @do_something(i32 %"10") #0 {
%"42" = alloca i64, align 8, addrspace(5) %"46" = alloca i32, align 4, addrspace(5)
%"43" = alloca i64, align 8, addrspace(5) %"47" = alloca i1, align 1, addrspace(5)
br label %1 br label %1
1: ; preds = %0 1: ; preds = %0
br label %"39" br label %"43"
"39": ; preds = %1 "43": ; preds = %1
%"44" = add i64 %"10", 1 %"48" = add i32 %"10", 1
store i64 %"44", ptr addrspace(5) %"42", align 4 store i32 %"48", ptr addrspace(5) %"46", align 4
%"45" = add i64 %"10", 2 store i1 true, ptr addrspace(5) %"47", align 1
store i64 %"45", ptr addrspace(5) %"43", align 4 %2 = load i32, ptr addrspace(5) %"46", align 4
%2 = load i64, ptr addrspace(5) %"42", align 4 %3 = load i1, ptr addrspace(5) %"47", align 1
%3 = load i64, ptr addrspace(5) %"43", align 4 %4 = insertvalue [2 x i32] poison, i32 %2, 0
%4 = insertvalue { i64, i64 } poison, i64 %2, 0 %5 = zext i1 %3 to i32
%5 = insertvalue { i64, i64 } %4, i64 %3, 1 %6 = insertvalue [2 x i32] %4, i32 %5, 1
ret { i64, i64 } %5 ret [2 x i32] %6
} }
define amdgpu_kernel void @multiple_return(ptr addrspace(4) byref(i64) %"46", ptr addrspace(4) byref(i64) %"47") #1 { define amdgpu_kernel void @multiple_return(ptr addrspace(4) byref(i64) %"50", ptr addrspace(4) byref(i64) %"51") #1 {
%"48" = alloca i64, align 8, addrspace(5)
%"49" = alloca i64, align 8, addrspace(5)
%"50" = alloca i64, align 8, addrspace(5)
%"51" = alloca i64, align 8, addrspace(5)
%"52" = alloca i64, align 8, addrspace(5) %"52" = alloca i64, align 8, addrspace(5)
%"53" = alloca i64, align 8, addrspace(5)
%"54" = alloca i32, align 4, addrspace(5)
%"55" = alloca i32, align 4, addrspace(5)
%"56" = alloca i1, align 1, addrspace(5)
br label %1 br label %1
1: ; preds = %0 1: ; preds = %0
br label %"40" br label %"44"
"40": ; preds = %1 "44": ; preds = %1
%"53" = load i64, ptr addrspace(4) %"46", align 4 %"57" = load i64, ptr addrspace(4) %"50", align 4
store i64 %"53", ptr addrspace(5) %"48", align 4 store i64 %"57", ptr addrspace(5) %"52", align 4
%"54" = load i64, ptr addrspace(4) %"47", align 4 %"58" = load i64, ptr addrspace(4) %"51", align 4
store i64 %"54", ptr addrspace(5) %"49", align 4 store i64 %"58", ptr addrspace(5) %"53", align 4
%"56" = load i64, ptr addrspace(5) %"48", align 4 %"60" = load i64, ptr addrspace(5) %"52", align 4
%"64" = inttoptr i64 %"56" to ptr %"68" = inttoptr i64 %"60" to ptr
%"55" = load i64, ptr %"64", align 4 %"59" = load i32, ptr %"68", align 4
store i64 %"55", ptr addrspace(5) %"50", align 4 store i32 %"59", ptr addrspace(5) %"54", align 4
%"59" = load i64, ptr addrspace(5) %"50", align 4 %"63" = load i32, ptr addrspace(5) %"54", align 4
%2 = call { i64, i64 } @do_something(i64 %"59") %2 = call [2 x i32] @do_something(i32 %"63")
%"57" = extractvalue { i64, i64 } %2, 0 %"61" = extractvalue [2 x i32] %2, 0
%"58" = extractvalue { i64, i64 } %2, 1 %3 = extractvalue [2 x i32] %2, 1
store i64 %"57", ptr addrspace(5) %"51", align 4 %"62" = trunc i32 %3 to i1
store i64 %"58", ptr addrspace(5) %"52", align 4 store i32 %"61", ptr addrspace(5) %"55", align 4
br label %"41" store i1 %"62", ptr addrspace(5) %"56", align 1
br label %"45"
"41": ; preds = %"40" "45": ; preds = %"44"
%"60" = load i64, ptr addrspace(5) %"49", align 4 %"64" = load i64, ptr addrspace(5) %"53", align 4
%"61" = load i64, ptr addrspace(5) %"51", align 4 %"65" = load i32, ptr addrspace(5) %"55", align 4
%"65" = inttoptr i64 %"60" to ptr %"69" = inttoptr i64 %"64" to ptr
store i64 %"61", ptr %"65", align 4 store i32 %"65", ptr %"69", align 4
%"62" = load i64, ptr addrspace(5) %"49", align 4 %"66" = load i1, ptr addrspace(5) %"56", align 1
%"66" = inttoptr i64 %"62" to ptr br i1 %"66", label %"19", label %"20"
%"38" = getelementptr inbounds i8, ptr %"66", i64 8
%"63" = load i64, ptr addrspace(5) %"52", align 4 "19": ; preds = %"45"
store i64 %"63", ptr %"38", align 4 %"67" = load i64, ptr addrspace(5) %"53", align 4
%"70" = inttoptr i64 %"67" to ptr
%"41" = getelementptr inbounds i8, ptr %"70", i64 4
store i32 123, ptr %"41", align 4
br label %"20"
"20": ; preds = %"19", %"45"
ret void ret void
} }

View File

@ -294,7 +294,7 @@ test_ptx!(
], ],
[1.0000001, 1.0f32] [1.0000001, 1.0f32]
); );
test_ptx!(multiple_return, [5u64], [6u64, 7u64]); test_ptx!(multiple_return, [5u32], [6u32, 123u32]);
test_ptx!(warp_sz, [0u8], [32u8]); test_ptx!(warp_sz, [0u8], [32u8]);
test_ptx!(assertfail); test_ptx!(assertfail);

View File

@ -2,12 +2,12 @@
.target sm_30 .target sm_30
.address_size 64 .address_size 64
.func (.reg .u64 a, .reg .u64 b) do_something( .func (.reg .u32 a, .reg .pred b) do_something(
.reg .u64 x .reg .u32 x
) )
{ {
add.u64 a, x, 1; add.u32 a, x, 1;
add.u64 b, x, 2; setp.eq.u32 b, 0, 0;
ret; ret;
} }
@ -16,18 +16,18 @@
.param .u64 output .param .u64 output
) )
{ {
.reg .u64 in_addr; .reg .u64 in_addr;
.reg .u64 out_addr; .reg .u64 out_addr;
.reg .u64 temp; .reg .u32 temp;
.reg .u64 temp2; .reg .u32 temp2;
.reg .u64 temp3; .reg .pred temp3;
ld.param.u64 in_addr, [input]; ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output]; ld.param.u64 out_addr, [output];
ld.u64 temp, [in_addr]; ld.u32 temp, [in_addr];
call (temp2, temp3), do_something, (temp); call (temp2, temp3), do_something, (temp);
st.u64 [out_addr], temp2; st.u32 [out_addr], temp2;
st.u64 [out_addr+8], temp3; @temp3 st.u32 [out_addr+4], 123;
ret; ret;
} }