From 6e27f78ae70255910db6a15aab30437d2b98ccd9 Mon Sep 17 00:00:00 2001 From: Violet Date: Wed, 9 Jul 2025 08:17:15 -0700 Subject: [PATCH] Add support for multiple return arguments (#406) --- ptx/src/pass/emit_llvm.rs | 75 +++++++++++++++++----- ptx/src/test/ll/multiple_return.ll | 63 ++++++++++++++++++ ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/test/spirv_run/multiple_return.ptx | 33 ++++++++++ 4 files changed, 157 insertions(+), 15 deletions(-) create mode 100644 ptx/src/test/ll/multiple_return.ll create mode 100644 ptx/src/test/spirv_run/multiple_return.ptx diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 1c2b52c..e95f269 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -889,9 +889,8 @@ impl<'a> MethodEmitContext<'a> { } } let name = match &*arguments.return_arguments { - [] => LLVM_UNNAMED.as_ptr(), [dst] => self.resolver.get_or_add_raw(*dst), - _ => todo!(), + _ => LLVM_UNNAMED.as_ptr(), }; let type_ = get_function_type( self.context, @@ -905,7 +904,7 @@ impl<'a> MethodEmitContext<'a> { .iter() .map(|arg| self.resolver.value(*arg)) .collect::, _>>()?; - let llvm_fn = unsafe { + let llvm_call = unsafe { LLVMBuildCall2( self.builder, type_, @@ -918,9 +917,15 @@ impl<'a> MethodEmitContext<'a> { match &*arguments.return_arguments { [] => {} [name] => { - self.resolver.register(*name, llvm_fn); + self.resolver.register(*name, llvm_call) + } + args => { + for (idx, arg) in args.iter().enumerate() { + self.resolver.with_result(*arg, |name| unsafe { + LLVMBuildExtractValue(self.builder, llvm_call, idx as u32, name) + }); + } } - _ => todo!(), } Ok(()) } @@ -1057,16 +1062,38 @@ impl<'a> MethodEmitContext<'a> { &mut self, values: Vec<(SpirvWord, ptx_parser::Type)>, ) -> Result<(), TranslateError> { - match &*values { + let loads = values.iter().map(|(value, type_)| { + let value = self.resolver.value(*value)?; + let type_ = get_type(self.context, type_)?; + Ok(unsafe { + LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) + }) + }).collect::, _>>()?; + + match &*loads { [] => unsafe { LLVMBuildRetVoid(self.builder) }, - [(value, type_)] => { - let value = self.resolver.value(*value)?; - let type_ = get_type(self.context, type_)?; - let value = - unsafe { LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) }; - unsafe { LLVMBuildRet(self.builder, value) } + [value] => { + unsafe { LLVMBuildRet(self.builder, *value) } + } + _ => { + let struct_ty = + get_struct_type(self.context, values.iter().map(|(_, type_)| type_))?; + let struct_ = loads.into_iter().enumerate().fold( + unsafe { LLVMGetPoison(struct_ty) }, + |struct_, (idx, elem)| { + unsafe { + LLVMBuildInsertValue( + self.builder, + struct_, + elem, + idx as u32, + LLVM_UNNAMED.as_ptr(), + ) + } + }, + ); + unsafe { LLVMBuildRet(self.builder, struct_) } } - _ => todo!(), }; Ok(()) } @@ -2705,6 +2732,23 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR } } +fn get_struct_type<'a>( + context: LLVMContextRef, + return_args: impl ExactSizeIterator, +) -> Result { + let mut types = return_args + .map(|type_| get_type(context, type_)) + .collect::, _>>()?; + Ok(unsafe { + LLVMStructTypeInContext( + context, + types.as_mut_ptr(), + types.len() as u32, + false as i32, + ) + }) +} + fn get_function_type<'a>( context: LLVMContextRef, mut return_args: impl ExactSizeIterator, @@ -2713,9 +2757,10 @@ fn get_function_type<'a>( let mut input_args = input_args.collect::, _>>()?; let return_type = match return_args.len() { 0 => unsafe { LLVMVoidTypeInContext(context) }, - 1 => get_type(context, return_args.next().unwrap())?, - _ => todo!(), + 1 => get_type(context, &return_args.next().unwrap())?, + _ => get_struct_type(context, return_args)?, }; + Ok(unsafe { LLVMFunctionType( return_type, diff --git a/ptx/src/test/ll/multiple_return.ll b/ptx/src/test/ll/multiple_return.ll new file mode 100644 index 0000000..fed61d6 --- /dev/null +++ b/ptx/src/test/ll/multiple_return.ll @@ -0,0 +1,63 @@ +define { i64, i64 } @do_something(i64 %"10") #0 { + %"42" = alloca i64, align 8, addrspace(5) + %"43" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"39" + +"39": ; preds = %1 + %"44" = add i64 %"10", 1 + store i64 %"44", ptr addrspace(5) %"42", align 4 + %"45" = add i64 %"10", 2 + store i64 %"45", ptr addrspace(5) %"43", align 4 + %2 = load i64, ptr addrspace(5) %"42", align 4 + %3 = load i64, ptr addrspace(5) %"43", align 4 + %4 = insertvalue { i64, i64 } poison, i64 %2, 0 + %5 = insertvalue { i64, i64 } %4, i64 %3, 1 + ret { i64, i64 } %5 +} + +define amdgpu_kernel void @multiple_return(ptr addrspace(4) byref(i64) %"46", ptr addrspace(4) byref(i64) %"47") #1 { + %"48" = alloca i64, align 8, addrspace(5) + %"49" = alloca i64, align 8, addrspace(5) + %"50" = alloca i64, align 8, addrspace(5) + %"51" = alloca i64, align 8, addrspace(5) + %"52" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"40" + +"40": ; preds = %1 + %"53" = load i64, ptr addrspace(4) %"46", align 4 + store i64 %"53", ptr addrspace(5) %"48", align 4 + %"54" = load i64, ptr addrspace(4) %"47", align 4 + store i64 %"54", ptr addrspace(5) %"49", align 4 + %"56" = load i64, ptr addrspace(5) %"48", align 4 + %"64" = inttoptr i64 %"56" to ptr + %"55" = load i64, ptr %"64", align 4 + store i64 %"55", ptr addrspace(5) %"50", align 4 + %"59" = load i64, ptr addrspace(5) %"50", align 4 + %2 = call { i64, i64 } @do_something(i64 %"59") + %"57" = extractvalue { i64, i64 } %2, 0 + %"58" = extractvalue { i64, i64 } %2, 1 + store i64 %"57", ptr addrspace(5) %"51", align 4 + store i64 %"58", ptr addrspace(5) %"52", align 4 + br label %"41" + +"41": ; preds = %"40" + %"60" = load i64, ptr addrspace(5) %"49", align 4 + %"61" = load i64, ptr addrspace(5) %"51", align 4 + %"65" = inttoptr i64 %"60" to ptr + store i64 %"61", ptr %"65", align 4 + %"62" = load i64, ptr addrspace(5) %"49", align 4 + %"66" = inttoptr i64 %"62" to ptr + %"38" = getelementptr inbounds i8, ptr %"66", i64 8 + %"63" = load i64, ptr addrspace(5) %"52", align 4 + store i64 %"63", ptr %"38", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 8cd88eb..bd78639 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -294,6 +294,7 @@ test_ptx!( ], [1.0000001, 1.0f32] ); +test_ptx!(multiple_return, [5u64], [6u64, 7u64]); test_ptx!(assertfail); // TODO: not yet supported diff --git a/ptx/src/test/spirv_run/multiple_return.ptx b/ptx/src/test/spirv_run/multiple_return.ptx new file mode 100644 index 0000000..831a967 --- /dev/null +++ b/ptx/src/test/spirv_run/multiple_return.ptx @@ -0,0 +1,33 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.func (.reg .u64 a, .reg .u64 b) do_something( + .reg .u64 x +) +{ + add.u64 a, x, 1; + add.u64 b, x, 2; + ret; +} + +.visible .entry multiple_return( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + .reg .u64 temp3; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + call (temp2, temp3), do_something, (temp); + st.u64 [out_addr], temp2; + st.u64 [out_addr+8], temp3; + ret; +}