Add support for multiple return arguments (#406)

This commit is contained in:
Violet
2025-07-09 08:17:15 -07:00
committed by GitHub
parent fa7ecb2e02
commit 6e27f78ae7
4 changed files with 157 additions and 15 deletions

View File

@ -889,9 +889,8 @@ impl<'a> MethodEmitContext<'a> {
} }
} }
let name = match &*arguments.return_arguments { let name = match &*arguments.return_arguments {
[] => LLVM_UNNAMED.as_ptr(),
[dst] => self.resolver.get_or_add_raw(*dst), [dst] => self.resolver.get_or_add_raw(*dst),
_ => todo!(), _ => LLVM_UNNAMED.as_ptr(),
}; };
let type_ = get_function_type( let type_ = get_function_type(
self.context, self.context,
@ -905,7 +904,7 @@ impl<'a> MethodEmitContext<'a> {
.iter() .iter()
.map(|arg| self.resolver.value(*arg)) .map(|arg| self.resolver.value(*arg))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let llvm_fn = unsafe { let llvm_call = unsafe {
LLVMBuildCall2( LLVMBuildCall2(
self.builder, self.builder,
type_, type_,
@ -918,9 +917,15 @@ impl<'a> MethodEmitContext<'a> {
match &*arguments.return_arguments { match &*arguments.return_arguments {
[] => {} [] => {}
[name] => { [name] => {
self.resolver.register(*name, llvm_fn); self.resolver.register(*name, llvm_call)
}
args => {
for (idx, arg) in args.iter().enumerate() {
self.resolver.with_result(*arg, |name| unsafe {
LLVMBuildExtractValue(self.builder, llvm_call, idx as u32, name)
});
}
} }
_ => todo!(),
} }
Ok(()) Ok(())
} }
@ -1057,16 +1062,38 @@ impl<'a> MethodEmitContext<'a> {
&mut self, &mut self,
values: Vec<(SpirvWord, ptx_parser::Type)>, values: Vec<(SpirvWord, ptx_parser::Type)>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
match &*values { let loads = values.iter().map(|(value, type_)| {
let value = self.resolver.value(*value)?;
let type_ = get_type(self.context, type_)?;
Ok(unsafe {
LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr())
})
}).collect::<Result<Vec<_>, _>>()?;
match &*loads {
[] => unsafe { LLVMBuildRetVoid(self.builder) }, [] => unsafe { LLVMBuildRetVoid(self.builder) },
[(value, type_)] => { [value] => {
let value = self.resolver.value(*value)?; unsafe { LLVMBuildRet(self.builder, *value) }
let type_ = get_type(self.context, type_)?; }
let value = _ => {
unsafe { LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) }; let struct_ty =
unsafe { LLVMBuildRet(self.builder, value) } get_struct_type(self.context, values.iter().map(|(_, type_)| type_))?;
let struct_ = loads.into_iter().enumerate().fold(
unsafe { LLVMGetPoison(struct_ty) },
|struct_, (idx, elem)| {
unsafe {
LLVMBuildInsertValue(
self.builder,
struct_,
elem,
idx as u32,
LLVM_UNNAMED.as_ptr(),
)
}
},
);
unsafe { LLVMBuildRet(self.builder, struct_) }
} }
_ => todo!(),
}; };
Ok(()) Ok(())
} }
@ -2705,6 +2732,23 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
} }
} }
/// Builds an anonymous LLVM struct type whose fields are the lowered forms of
/// `return_args`, in iteration order.
///
/// Each element is converted with `get_type`; the first conversion failure is
/// returned unchanged and no struct type is created. The struct is requested
/// as non-packed.
fn get_struct_type<'a>(
    context: LLVMContextRef,
    return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
) -> Result<LLVMTypeRef, TranslateError> {
    let mut field_types = Vec::with_capacity(return_args.len());
    for arg_type in return_args {
        field_types.push(get_type(context, arg_type)?);
    }
    let field_count = field_types.len() as u32;
    // SAFETY: the pointer and the element count both describe `field_types`,
    // which stays alive until after the call returns.
    let struct_type = unsafe {
        LLVMStructTypeInContext(
            context,
            field_types.as_mut_ptr(),
            field_count,
            // `Packed` flag: 0 means the struct is not packed.
            0,
        )
    };
    Ok(struct_type)
}
fn get_function_type<'a>( fn get_function_type<'a>(
context: LLVMContextRef, context: LLVMContextRef,
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>, mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
@ -2713,9 +2757,10 @@ fn get_function_type<'a>(
let mut input_args = input_args.collect::<Result<Vec<_>, _>>()?; let mut input_args = input_args.collect::<Result<Vec<_>, _>>()?;
let return_type = match return_args.len() { let return_type = match return_args.len() {
0 => unsafe { LLVMVoidTypeInContext(context) }, 0 => unsafe { LLVMVoidTypeInContext(context) },
1 => get_type(context, return_args.next().unwrap())?, 1 => get_type(context, &return_args.next().unwrap())?,
_ => todo!(), _ => get_struct_type(context, return_args)?,
}; };
Ok(unsafe { Ok(unsafe {
LLVMFunctionType( LLVMFunctionType(
return_type, return_type,

View File

@ -0,0 +1,63 @@
define { i64, i64 } @do_something(i64 %"10") #0 {
%"42" = alloca i64, align 8, addrspace(5)
%"43" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"39"
"39": ; preds = %1
%"44" = add i64 %"10", 1
store i64 %"44", ptr addrspace(5) %"42", align 4
%"45" = add i64 %"10", 2
store i64 %"45", ptr addrspace(5) %"43", align 4
%2 = load i64, ptr addrspace(5) %"42", align 4
%3 = load i64, ptr addrspace(5) %"43", align 4
%4 = insertvalue { i64, i64 } poison, i64 %2, 0
%5 = insertvalue { i64, i64 } %4, i64 %3, 1
ret { i64, i64 } %5
}
define amdgpu_kernel void @multiple_return(ptr addrspace(4) byref(i64) %"46", ptr addrspace(4) byref(i64) %"47") #1 {
%"48" = alloca i64, align 8, addrspace(5)
%"49" = alloca i64, align 8, addrspace(5)
%"50" = alloca i64, align 8, addrspace(5)
%"51" = alloca i64, align 8, addrspace(5)
%"52" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"40"
"40": ; preds = %1
%"53" = load i64, ptr addrspace(4) %"46", align 4
store i64 %"53", ptr addrspace(5) %"48", align 4
%"54" = load i64, ptr addrspace(4) %"47", align 4
store i64 %"54", ptr addrspace(5) %"49", align 4
%"56" = load i64, ptr addrspace(5) %"48", align 4
%"64" = inttoptr i64 %"56" to ptr
%"55" = load i64, ptr %"64", align 4
store i64 %"55", ptr addrspace(5) %"50", align 4
%"59" = load i64, ptr addrspace(5) %"50", align 4
%2 = call { i64, i64 } @do_something(i64 %"59")
%"57" = extractvalue { i64, i64 } %2, 0
%"58" = extractvalue { i64, i64 } %2, 1
store i64 %"57", ptr addrspace(5) %"51", align 4
store i64 %"58", ptr addrspace(5) %"52", align 4
br label %"41"
"41": ; preds = %"40"
%"60" = load i64, ptr addrspace(5) %"49", align 4
%"61" = load i64, ptr addrspace(5) %"51", align 4
%"65" = inttoptr i64 %"60" to ptr
store i64 %"61", ptr %"65", align 4
%"62" = load i64, ptr addrspace(5) %"49", align 4
%"66" = inttoptr i64 %"62" to ptr
%"38" = getelementptr inbounds i8, ptr %"66", i64 8
%"63" = load i64, ptr addrspace(5) %"52", align 4
store i64 %"63", ptr %"38", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View File

@ -294,6 +294,7 @@ test_ptx!(
], ],
[1.0000001, 1.0f32] [1.0000001, 1.0f32]
); );
test_ptx!(multiple_return, [5u64], [6u64, 7u64]);
test_ptx!(assertfail); test_ptx!(assertfail);
// TODO: not yet supported // TODO: not yet supported

View File

@ -0,0 +1,33 @@
.version 6.5
.target sm_30
.address_size 64
// Function with two return registers (a, b) and one input register (x).
// It produces a = x + 1 and b = x + 2, so a caller must bind both results.
.func (.reg .u64 a, .reg .u64 b) do_something(
.reg .u64 x
)
{
// First return value: x + 1.
add.u64 a, x, 1;
// Second return value: x + 2.
add.u64 b, x, 2;
ret;
}
// Kernel entry point exercising a call with two return arguments.
// It reads one u64 from the address passed in `input`, hands it to
// do_something, and writes the two returned values to the address passed in
// `output`: the first at offset 0 and the second at offset 8.
.visible .entry multiple_return(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp;
.reg .u64 temp2;
.reg .u64 temp3;
// Fetch the two addresses from the kernel parameters.
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
// Load the single input value.
ld.u64 temp, [in_addr];
// Call with a two-element return list: temp2 and temp3 receive the callee's
// first and second return registers, in that order.
call (temp2, temp3), do_something, (temp);
// Store the results in consecutive 8-byte slots.
st.u64 [out_addr], temp2;
st.u64 [out_addr+8], temp3;
ret;
}