diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 8f4ced5..57a7d0f 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -916,15 +916,26 @@ impl<'a> MethodEmitContext<'a> { }; match &*arguments.return_arguments { [] => {} - [name] => { - self.resolver.register(*name, llvm_call) + [name] => self.resolver.register(*name, llvm_call), + [b32, pred] => { + self.resolver.with_result(*b32, |name| unsafe { + LLVMBuildExtractValue(self.builder, llvm_call, 0, name) + }); + self.resolver.with_result(*pred, |name| unsafe { + let extracted = + LLVMBuildExtractValue(self.builder, llvm_call, 1, LLVM_UNNAMED.as_ptr()); + LLVMBuildTrunc( + self.builder, + extracted, + get_scalar_type(self.context, ast::ScalarType::Pred), + name, + ) + }); } - args => { - for (idx, arg) in args.iter().enumerate() { - self.resolver.with_result(*arg, |name| unsafe { - LLVMBuildExtractValue(self.builder, llvm_call, idx as u32, name) - }); - } + _ => { + return Err(error_todo_msg( + "Only two return arguments (.b32, .pred) currently supported", + )) } } Ok(()) @@ -1062,37 +1073,49 @@ impl<'a> MethodEmitContext<'a> { &mut self, values: Vec<(SpirvWord, ptx_parser::Type)>, ) -> Result<(), TranslateError> { - let loads = values.iter().map(|(value, type_)| { - let value = self.resolver.value(*value)?; - let type_ = get_type(self.context, type_)?; - Ok(unsafe { - LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) + let loads = values + .iter() + .map(|(value, type_)| { + let value = self.resolver.value(*value)?; + let type_ = get_type(self.context, type_)?; + Ok(unsafe { LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) }) }) - }).collect::, _>>()?; + .collect::, _>>()?; match &*loads { [] => unsafe { LLVMBuildRetVoid(self.builder) }, - [value] => { - unsafe { LLVMBuildRet(self.builder, *value) } - } + [value] => unsafe { LLVMBuildRet(self.builder, *value) }, _ => { - let struct_ty = - get_struct_type(self.context, values.iter().map(|(_, type_)| type_))?; - let struct_ = loads.into_iter().enumerate().fold( - unsafe { LLVMGetPoison(struct_ty) }, - |struct_, (idx, elem)| { - unsafe { - LLVMBuildInsertValue( - self.builder, - struct_, - elem, - idx as u32, - LLVM_UNNAMED.as_ptr(), - ) - } - }, - ); - unsafe { LLVMBuildRet(self.builder, struct_) } + check_multiple_return_types(values.iter().map(|(_, type_)| type_))?; + let array_ty = + get_array_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?; + let insert_b32 = unsafe { + LLVMBuildInsertValue( + self.builder, + LLVMGetPoison(array_ty), + loads[0], + 0, + LLVM_UNNAMED.as_ptr(), + ) + }; + let zext_pred = unsafe { + LLVMBuildZExt( + self.builder, + loads[1], + get_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32))?, + LLVM_UNNAMED.as_ptr(), + ) + }; + let insert_pred = unsafe { + LLVMBuildInsertValue( + self.builder, + insert_b32, + zext_pred, + 1, + LLVM_UNNAMED.as_ptr(), + ) + }; + unsafe { LLVMBuildRet(self.builder, insert_pred) } } }; Ok(()) @@ -2732,21 +2755,31 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR } } -fn get_struct_type<'a>( +fn get_array_type<'a>( context: LLVMContextRef, - return_args: impl ExactSizeIterator, + elem_type: &'a ast::Type, + count: u64, ) -> Result { - let mut types = return_args - .map(|type_| get_type(context, type_)) - .collect::, _>>()?; - Ok(unsafe { - LLVMStructTypeInContext( - context, - types.as_mut_ptr(), - types.len() as u32, - false as i32, - ) - }) + let elem_type = get_type(context, elem_type)?; + Ok(unsafe { LLVMArrayType2(elem_type, count) }) +} + +fn check_multiple_return_types<'a>( + mut return_args: impl ExactSizeIterator, +) -> Result<(), TranslateError> { + let err_msg = "Only (.b32, .pred) multiple return types are supported"; + + let first = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?; + let second = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?; + match (first, second) { + (ast::Type::Scalar(first), ast::Type::Scalar(second)) => { + if first.size_of() != 4 || second.size_of() != 1 { + return Err(error_todo_msg(err_msg)); + } + } + _ => return Err(error_todo_msg(err_msg)), + } + Ok(()) } fn get_function_type<'a>( @@ -2758,7 +2791,10 @@ fn get_function_type<'a>( let return_type = match return_args.len() { 0 => unsafe { LLVMVoidTypeInContext(context) }, 1 => get_type(context, &return_args.next().unwrap())?, - _ => get_struct_type(context, return_args)?, + _ => { + check_multiple_return_types(return_args)?; + get_array_type(context, &ast::Type::Scalar(ast::ScalarType::B32), 2)? + }, }; Ok(unsafe { diff --git a/ptx/src/test/ll/multiple_return.ll b/ptx/src/test/ll/multiple_return.ll index fed61d6..9ec20c7 100644 --- a/ptx/src/test/ll/multiple_return.ll +++ b/ptx/src/test/ll/multiple_return.ll @@ -1,61 +1,68 @@ -define { i64, i64 } @do_something(i64 %"10") #0 { - %"42" = alloca i64, align 8, addrspace(5) - %"43" = alloca i64, align 8, addrspace(5) +define [2 x i32] @do_something(i32 %"10") #0 { + %"46" = alloca i32, align 4, addrspace(5) + %"47" = alloca i1, align 1, addrspace(5) br label %1 1: ; preds = %0 - br label %"39" + br label %"43" -"39": ; preds = %1 - %"44" = add i64 %"10", 1 - store i64 %"44", ptr addrspace(5) %"42", align 4 - %"45" = add i64 %"10", 2 - store i64 %"45", ptr addrspace(5) %"43", align 4 - %2 = load i64, ptr addrspace(5) %"42", align 4 - %3 = load i64, ptr addrspace(5) %"43", align 4 - %4 = insertvalue { i64, i64 } poison, i64 %2, 0 - %5 = insertvalue { i64, i64 } %4, i64 %3, 1 - ret { i64, i64 } %5 +"43": ; preds = %1 + %"48" = add i32 %"10", 1 + store i32 %"48", ptr addrspace(5) %"46", align 4 + store i1 true, ptr addrspace(5) %"47", align 1 + %2 = load i32, ptr addrspace(5) %"46", align 4 + %3 = load i1, ptr addrspace(5) %"47", align 1 + %4 = insertvalue [2 x i32] poison, i32 %2, 0 + %5 = zext i1 %3 to i32 + %6 = insertvalue [2 x i32] %4, i32 %5, 1 + ret [2 x i32] %6 } -define amdgpu_kernel void @multiple_return(ptr addrspace(4) byref(i64) %"46", ptr addrspace(4) byref(i64) %"47") #1 { - %"48" = alloca i64, align 8, addrspace(5) - %"49" = alloca i64, align 8, addrspace(5) - %"50" = alloca i64, align 8, addrspace(5) - %"51" = alloca i64, align 8, addrspace(5) +define amdgpu_kernel void @multiple_return(ptr addrspace(4) byref(i64) %"50", ptr addrspace(4) byref(i64) %"51") #1 { %"52" = alloca i64, align 8, addrspace(5) + %"53" = alloca i64, align 8, addrspace(5) + %"54" = alloca i32, align 4, addrspace(5) + %"55" = alloca i32, align 4, addrspace(5) + %"56" = alloca i1, align 1, addrspace(5) br label %1 1: ; preds = %0 - br label %"40" + br label %"44" -"40": ; preds = %1 - %"53" = load i64, ptr addrspace(4) %"46", align 4 - store i64 %"53", ptr addrspace(5) %"48", align 4 - %"54" = load i64, ptr addrspace(4) %"47", align 4 - store i64 %"54", ptr addrspace(5) %"49", align 4 - %"56" = load i64, ptr addrspace(5) %"48", align 4 - %"64" = inttoptr i64 %"56" to ptr - %"55" = load i64, ptr %"64", align 4 - store i64 %"55", ptr addrspace(5) %"50", align 4 - %"59" = load i64, ptr addrspace(5) %"50", align 4 - %2 = call { i64, i64 } @do_something(i64 %"59") - %"57" = extractvalue { i64, i64 } %2, 0 - %"58" = extractvalue { i64, i64 } %2, 1 - store i64 %"57", ptr addrspace(5) %"51", align 4 - store i64 %"58", ptr addrspace(5) %"52", align 4 - br label %"41" +"44": ; preds = %1 + %"57" = load i64, ptr addrspace(4) %"50", align 4 + store i64 %"57", ptr addrspace(5) %"52", align 4 + %"58" = load i64, ptr addrspace(4) %"51", align 4 + store i64 %"58", ptr addrspace(5) %"53", align 4 + %"60" = load i64, ptr addrspace(5) %"52", align 4 + %"68" = inttoptr i64 %"60" to ptr + %"59" = load i32, ptr %"68", align 4 + store i32 %"59", ptr addrspace(5) %"54", align 4 + %"63" = load i32, ptr addrspace(5) %"54", align 4 + %2 = call [2 x i32] @do_something(i32 %"63") + %"61" = extractvalue [2 x i32] %2, 0 + %3 = extractvalue [2 x i32] %2, 1 + %"62" = trunc i32 %3 to i1 + store i32 %"61", ptr addrspace(5) %"55", align 4 + store i1 %"62", ptr addrspace(5) %"56", align 1 + br label %"45" -"41": ; preds = %"40" - %"60" = load i64, ptr addrspace(5) %"49", align 4 - %"61" = load i64, ptr addrspace(5) %"51", align 4 - %"65" = inttoptr i64 %"60" to ptr - store i64 %"61", ptr %"65", align 4 - %"62" = load i64, ptr addrspace(5) %"49", align 4 - %"66" = inttoptr i64 %"62" to ptr - %"38" = getelementptr inbounds i8, ptr %"66", i64 8 - %"63" = load i64, ptr addrspace(5) %"52", align 4 - store i64 %"63", ptr %"38", align 4 +"45": ; preds = %"44" + %"64" = load i64, ptr addrspace(5) %"53", align 4 + %"65" = load i32, ptr addrspace(5) %"55", align 4 + %"69" = inttoptr i64 %"64" to ptr + store i32 %"65", ptr %"69", align 4 + %"66" = load i1, ptr addrspace(5) %"56", align 1 + br i1 %"66", label %"19", label %"20" + +"19": ; preds = %"45" + %"67" = load i64, ptr addrspace(5) %"53", align 4 + %"70" = inttoptr i64 %"67" to ptr + %"41" = getelementptr inbounds i8, ptr %"70", i64 4 + store i32 123, ptr %"41", align 4 + br label %"20" + +"20": ; preds = %"19", %"45" ret void } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index c029ec5..11820e7 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -294,7 +294,7 @@ test_ptx!( ], [1.0000001, 1.0f32] ); -test_ptx!(multiple_return, [5u64], [6u64, 7u64]); +test_ptx!(multiple_return, [5u32], [6u32, 123u32]); test_ptx!(warp_sz, [0u8], [32u8]); test_ptx!(assertfail); diff --git a/ptx/src/test/spirv_run/multiple_return.ptx b/ptx/src/test/spirv_run/multiple_return.ptx index 831a967..d6803fa 100644 --- a/ptx/src/test/spirv_run/multiple_return.ptx +++ b/ptx/src/test/spirv_run/multiple_return.ptx @@ -2,12 +2,12 @@ .target sm_30 .address_size 64 -.func (.reg .u64 a, .reg .u64 b) do_something( - .reg .u64 x +.func (.reg .u32 a, .reg .pred b) do_something( + .reg .u32 x ) { - add.u64 a, x, 1; - add.u64 b, x, 2; + add.u32 a, x, 1; + setp.eq.u32 b, 0, 0; ret; } @@ -16,18 +16,18 @@ .param .u64 output ) { - .reg .u64 in_addr; - .reg .u64 out_addr; - .reg .u64 temp; - .reg .u64 temp2; - .reg .u64 temp3; + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 temp; + .reg .u32 temp2; + .reg .pred temp3; - ld.param.u64 in_addr, [input]; - ld.param.u64 out_addr, [output]; + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; - ld.u64 temp, [in_addr]; - call (temp2, temp3), do_something, (temp); - st.u64 [out_addr], temp2; - st.u64 [out_addr+8], temp3; + ld.u32 temp, [in_addr]; + call (temp2, temp3), do_something, (temp); + st.u32 [out_addr], temp2; + @temp3 st.u32 [out_addr+4], 123; ret; }