mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-20 18:56:24 +03:00
Only allow (.u32, .pred) for multiple return (#417)
This commit is contained in:
@ -916,15 +916,26 @@ impl<'a> MethodEmitContext<'a> {
|
|||||||
};
|
};
|
||||||
match &*arguments.return_arguments {
|
match &*arguments.return_arguments {
|
||||||
[] => {}
|
[] => {}
|
||||||
[name] => {
|
[name] => self.resolver.register(*name, llvm_call),
|
||||||
self.resolver.register(*name, llvm_call)
|
[b32, pred] => {
|
||||||
|
self.resolver.with_result(*b32, |name| unsafe {
|
||||||
|
LLVMBuildExtractValue(self.builder, llvm_call, 0, name)
|
||||||
|
});
|
||||||
|
self.resolver.with_result(*pred, |name| unsafe {
|
||||||
|
let extracted =
|
||||||
|
LLVMBuildExtractValue(self.builder, llvm_call, 1, LLVM_UNNAMED.as_ptr());
|
||||||
|
LLVMBuildTrunc(
|
||||||
|
self.builder,
|
||||||
|
extracted,
|
||||||
|
get_scalar_type(self.context, ast::ScalarType::Pred),
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
});
|
||||||
}
|
}
|
||||||
args => {
|
_ => {
|
||||||
for (idx, arg) in args.iter().enumerate() {
|
return Err(error_todo_msg(
|
||||||
self.resolver.with_result(*arg, |name| unsafe {
|
"Only two return arguments (.b32, .pred) currently supported",
|
||||||
LLVMBuildExtractValue(self.builder, llvm_call, idx as u32, name)
|
))
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -1062,37 +1073,49 @@ impl<'a> MethodEmitContext<'a> {
|
|||||||
&mut self,
|
&mut self,
|
||||||
values: Vec<(SpirvWord, ptx_parser::Type)>,
|
values: Vec<(SpirvWord, ptx_parser::Type)>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
let loads = values.iter().map(|(value, type_)| {
|
let loads = values
|
||||||
let value = self.resolver.value(*value)?;
|
.iter()
|
||||||
let type_ = get_type(self.context, type_)?;
|
.map(|(value, type_)| {
|
||||||
Ok(unsafe {
|
let value = self.resolver.value(*value)?;
|
||||||
LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr())
|
let type_ = get_type(self.context, type_)?;
|
||||||
|
Ok(unsafe { LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) })
|
||||||
})
|
})
|
||||||
}).collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
match &*loads {
|
match &*loads {
|
||||||
[] => unsafe { LLVMBuildRetVoid(self.builder) },
|
[] => unsafe { LLVMBuildRetVoid(self.builder) },
|
||||||
[value] => {
|
[value] => unsafe { LLVMBuildRet(self.builder, *value) },
|
||||||
unsafe { LLVMBuildRet(self.builder, *value) }
|
|
||||||
}
|
|
||||||
_ => {
|
_ => {
|
||||||
let struct_ty =
|
check_multiple_return_types(values.iter().map(|(_, type_)| type_))?;
|
||||||
get_struct_type(self.context, values.iter().map(|(_, type_)| type_))?;
|
let array_ty =
|
||||||
let struct_ = loads.into_iter().enumerate().fold(
|
get_array_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?;
|
||||||
unsafe { LLVMGetPoison(struct_ty) },
|
let insert_b32 = unsafe {
|
||||||
|struct_, (idx, elem)| {
|
LLVMBuildInsertValue(
|
||||||
unsafe {
|
self.builder,
|
||||||
LLVMBuildInsertValue(
|
LLVMGetPoison(array_ty),
|
||||||
self.builder,
|
loads[0],
|
||||||
struct_,
|
0,
|
||||||
elem,
|
LLVM_UNNAMED.as_ptr(),
|
||||||
idx as u32,
|
)
|
||||||
LLVM_UNNAMED.as_ptr(),
|
};
|
||||||
)
|
let zext_pred = unsafe {
|
||||||
}
|
LLVMBuildZExt(
|
||||||
},
|
self.builder,
|
||||||
);
|
loads[1],
|
||||||
unsafe { LLVMBuildRet(self.builder, struct_) }
|
get_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32))?,
|
||||||
|
LLVM_UNNAMED.as_ptr(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let insert_pred = unsafe {
|
||||||
|
LLVMBuildInsertValue(
|
||||||
|
self.builder,
|
||||||
|
insert_b32,
|
||||||
|
zext_pred,
|
||||||
|
1,
|
||||||
|
LLVM_UNNAMED.as_ptr(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
unsafe { LLVMBuildRet(self.builder, insert_pred) }
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -2732,21 +2755,31 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_struct_type<'a>(
|
fn get_array_type<'a>(
|
||||||
context: LLVMContextRef,
|
context: LLVMContextRef,
|
||||||
return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
elem_type: &'a ast::Type,
|
||||||
|
count: u64,
|
||||||
) -> Result<LLVMTypeRef, TranslateError> {
|
) -> Result<LLVMTypeRef, TranslateError> {
|
||||||
let mut types = return_args
|
let elem_type = get_type(context, elem_type)?;
|
||||||
.map(|type_| get_type(context, type_))
|
Ok(unsafe { LLVMArrayType2(elem_type, count) })
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
}
|
||||||
Ok(unsafe {
|
|
||||||
LLVMStructTypeInContext(
|
fn check_multiple_return_types<'a>(
|
||||||
context,
|
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||||
types.as_mut_ptr(),
|
) -> Result<(), TranslateError> {
|
||||||
types.len() as u32,
|
let err_msg = "Only (.b32, .pred) multiple return types are supported";
|
||||||
false as i32,
|
|
||||||
)
|
let first = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?;
|
||||||
})
|
let second = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?;
|
||||||
|
match (first, second) {
|
||||||
|
(ast::Type::Scalar(first), ast::Type::Scalar(second)) => {
|
||||||
|
if first.size_of() != 4 || second.size_of() != 1 {
|
||||||
|
return Err(error_todo_msg(err_msg));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => return Err(error_todo_msg(err_msg)),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_function_type<'a>(
|
fn get_function_type<'a>(
|
||||||
@ -2758,7 +2791,10 @@ fn get_function_type<'a>(
|
|||||||
let return_type = match return_args.len() {
|
let return_type = match return_args.len() {
|
||||||
0 => unsafe { LLVMVoidTypeInContext(context) },
|
0 => unsafe { LLVMVoidTypeInContext(context) },
|
||||||
1 => get_type(context, &return_args.next().unwrap())?,
|
1 => get_type(context, &return_args.next().unwrap())?,
|
||||||
_ => get_struct_type(context, return_args)?,
|
_ => {
|
||||||
|
check_multiple_return_types(return_args)?;
|
||||||
|
get_array_type(context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(unsafe {
|
Ok(unsafe {
|
||||||
|
@ -1,61 +1,68 @@
|
|||||||
define { i64, i64 } @do_something(i64 %"10") #0 {
|
define [2 x i32] @do_something(i32 %"10") #0 {
|
||||||
%"42" = alloca i64, align 8, addrspace(5)
|
%"46" = alloca i32, align 4, addrspace(5)
|
||||||
%"43" = alloca i64, align 8, addrspace(5)
|
%"47" = alloca i1, align 1, addrspace(5)
|
||||||
br label %1
|
br label %1
|
||||||
|
|
||||||
1: ; preds = %0
|
1: ; preds = %0
|
||||||
br label %"39"
|
br label %"43"
|
||||||
|
|
||||||
"39": ; preds = %1
|
"43": ; preds = %1
|
||||||
%"44" = add i64 %"10", 1
|
%"48" = add i32 %"10", 1
|
||||||
store i64 %"44", ptr addrspace(5) %"42", align 4
|
store i32 %"48", ptr addrspace(5) %"46", align 4
|
||||||
%"45" = add i64 %"10", 2
|
store i1 true, ptr addrspace(5) %"47", align 1
|
||||||
store i64 %"45", ptr addrspace(5) %"43", align 4
|
%2 = load i32, ptr addrspace(5) %"46", align 4
|
||||||
%2 = load i64, ptr addrspace(5) %"42", align 4
|
%3 = load i1, ptr addrspace(5) %"47", align 1
|
||||||
%3 = load i64, ptr addrspace(5) %"43", align 4
|
%4 = insertvalue [2 x i32] poison, i32 %2, 0
|
||||||
%4 = insertvalue { i64, i64 } poison, i64 %2, 0
|
%5 = zext i1 %3 to i32
|
||||||
%5 = insertvalue { i64, i64 } %4, i64 %3, 1
|
%6 = insertvalue [2 x i32] %4, i32 %5, 1
|
||||||
ret { i64, i64 } %5
|
ret [2 x i32] %6
|
||||||
}
|
}
|
||||||
|
|
||||||
define amdgpu_kernel void @multiple_return(ptr addrspace(4) byref(i64) %"46", ptr addrspace(4) byref(i64) %"47") #1 {
|
define amdgpu_kernel void @multiple_return(ptr addrspace(4) byref(i64) %"50", ptr addrspace(4) byref(i64) %"51") #1 {
|
||||||
%"48" = alloca i64, align 8, addrspace(5)
|
|
||||||
%"49" = alloca i64, align 8, addrspace(5)
|
|
||||||
%"50" = alloca i64, align 8, addrspace(5)
|
|
||||||
%"51" = alloca i64, align 8, addrspace(5)
|
|
||||||
%"52" = alloca i64, align 8, addrspace(5)
|
%"52" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"53" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"54" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"55" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"56" = alloca i1, align 1, addrspace(5)
|
||||||
br label %1
|
br label %1
|
||||||
|
|
||||||
1: ; preds = %0
|
1: ; preds = %0
|
||||||
br label %"40"
|
br label %"44"
|
||||||
|
|
||||||
"40": ; preds = %1
|
"44": ; preds = %1
|
||||||
%"53" = load i64, ptr addrspace(4) %"46", align 4
|
%"57" = load i64, ptr addrspace(4) %"50", align 4
|
||||||
store i64 %"53", ptr addrspace(5) %"48", align 4
|
store i64 %"57", ptr addrspace(5) %"52", align 4
|
||||||
%"54" = load i64, ptr addrspace(4) %"47", align 4
|
%"58" = load i64, ptr addrspace(4) %"51", align 4
|
||||||
store i64 %"54", ptr addrspace(5) %"49", align 4
|
store i64 %"58", ptr addrspace(5) %"53", align 4
|
||||||
%"56" = load i64, ptr addrspace(5) %"48", align 4
|
%"60" = load i64, ptr addrspace(5) %"52", align 4
|
||||||
%"64" = inttoptr i64 %"56" to ptr
|
%"68" = inttoptr i64 %"60" to ptr
|
||||||
%"55" = load i64, ptr %"64", align 4
|
%"59" = load i32, ptr %"68", align 4
|
||||||
store i64 %"55", ptr addrspace(5) %"50", align 4
|
store i32 %"59", ptr addrspace(5) %"54", align 4
|
||||||
%"59" = load i64, ptr addrspace(5) %"50", align 4
|
%"63" = load i32, ptr addrspace(5) %"54", align 4
|
||||||
%2 = call { i64, i64 } @do_something(i64 %"59")
|
%2 = call [2 x i32] @do_something(i32 %"63")
|
||||||
%"57" = extractvalue { i64, i64 } %2, 0
|
%"61" = extractvalue [2 x i32] %2, 0
|
||||||
%"58" = extractvalue { i64, i64 } %2, 1
|
%3 = extractvalue [2 x i32] %2, 1
|
||||||
store i64 %"57", ptr addrspace(5) %"51", align 4
|
%"62" = trunc i32 %3 to i1
|
||||||
store i64 %"58", ptr addrspace(5) %"52", align 4
|
store i32 %"61", ptr addrspace(5) %"55", align 4
|
||||||
br label %"41"
|
store i1 %"62", ptr addrspace(5) %"56", align 1
|
||||||
|
br label %"45"
|
||||||
|
|
||||||
"41": ; preds = %"40"
|
"45": ; preds = %"44"
|
||||||
%"60" = load i64, ptr addrspace(5) %"49", align 4
|
%"64" = load i64, ptr addrspace(5) %"53", align 4
|
||||||
%"61" = load i64, ptr addrspace(5) %"51", align 4
|
%"65" = load i32, ptr addrspace(5) %"55", align 4
|
||||||
%"65" = inttoptr i64 %"60" to ptr
|
%"69" = inttoptr i64 %"64" to ptr
|
||||||
store i64 %"61", ptr %"65", align 4
|
store i32 %"65", ptr %"69", align 4
|
||||||
%"62" = load i64, ptr addrspace(5) %"49", align 4
|
%"66" = load i1, ptr addrspace(5) %"56", align 1
|
||||||
%"66" = inttoptr i64 %"62" to ptr
|
br i1 %"66", label %"19", label %"20"
|
||||||
%"38" = getelementptr inbounds i8, ptr %"66", i64 8
|
|
||||||
%"63" = load i64, ptr addrspace(5) %"52", align 4
|
"19": ; preds = %"45"
|
||||||
store i64 %"63", ptr %"38", align 4
|
%"67" = load i64, ptr addrspace(5) %"53", align 4
|
||||||
|
%"70" = inttoptr i64 %"67" to ptr
|
||||||
|
%"41" = getelementptr inbounds i8, ptr %"70", i64 4
|
||||||
|
store i32 123, ptr %"41", align 4
|
||||||
|
br label %"20"
|
||||||
|
|
||||||
|
"20": ; preds = %"19", %"45"
|
||||||
ret void
|
ret void
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -294,7 +294,7 @@ test_ptx!(
|
|||||||
],
|
],
|
||||||
[1.0000001, 1.0f32]
|
[1.0000001, 1.0f32]
|
||||||
);
|
);
|
||||||
test_ptx!(multiple_return, [5u64], [6u64, 7u64]);
|
test_ptx!(multiple_return, [5u32], [6u32, 123u32]);
|
||||||
test_ptx!(warp_sz, [0u8], [32u8]);
|
test_ptx!(warp_sz, [0u8], [32u8]);
|
||||||
|
|
||||||
test_ptx!(assertfail);
|
test_ptx!(assertfail);
|
||||||
|
@ -2,12 +2,12 @@
|
|||||||
.target sm_30
|
.target sm_30
|
||||||
.address_size 64
|
.address_size 64
|
||||||
|
|
||||||
.func (.reg .u64 a, .reg .u64 b) do_something(
|
.func (.reg .u32 a, .reg .pred b) do_something(
|
||||||
.reg .u64 x
|
.reg .u32 x
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
add.u64 a, x, 1;
|
add.u32 a, x, 1;
|
||||||
add.u64 b, x, 2;
|
setp.eq.u32 b, 0, 0;
|
||||||
ret;
|
ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -16,18 +16,18 @@
|
|||||||
.param .u64 output
|
.param .u64 output
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
.reg .u64 in_addr;
|
.reg .u64 in_addr;
|
||||||
.reg .u64 out_addr;
|
.reg .u64 out_addr;
|
||||||
.reg .u64 temp;
|
.reg .u32 temp;
|
||||||
.reg .u64 temp2;
|
.reg .u32 temp2;
|
||||||
.reg .u64 temp3;
|
.reg .pred temp3;
|
||||||
|
|
||||||
ld.param.u64 in_addr, [input];
|
ld.param.u64 in_addr, [input];
|
||||||
ld.param.u64 out_addr, [output];
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
ld.u64 temp, [in_addr];
|
ld.u32 temp, [in_addr];
|
||||||
call (temp2, temp3), do_something, (temp);
|
call (temp2, temp3), do_something, (temp);
|
||||||
st.u64 [out_addr], temp2;
|
st.u32 [out_addr], temp2;
|
||||||
st.u64 [out_addr+8], temp3;
|
@temp3 st.u32 [out_addr+4], 123;
|
||||||
ret;
|
ret;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user