mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-18 17:56:22 +03:00
Add support for some most common setp variants and fix a bug with branch conditions
This commit is contained in:
@ -355,6 +355,7 @@ pub struct SetpData {
|
|||||||
pub cmp_op: SetpCompareOp,
|
pub cmp_op: SetpCompareOp,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||||
pub enum SetpCompareOp {
|
pub enum SetpCompareOp {
|
||||||
Eq,
|
Eq,
|
||||||
NotEq,
|
NotEq,
|
||||||
|
@ -44,6 +44,7 @@ test_ptx!(mov, [1u64], [1u64]);
|
|||||||
test_ptx!(mul_lo, [1u64], [2u64]);
|
test_ptx!(mul_lo, [1u64], [2u64]);
|
||||||
test_ptx!(mul_hi, [u64::max_value()], [1u64]);
|
test_ptx!(mul_hi, [u64::max_value()], [1u64]);
|
||||||
test_ptx!(add, [1u64], [2u64]);
|
test_ptx!(add, [1u64], [2u64]);
|
||||||
|
test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]);
|
||||||
|
|
||||||
struct DisplayError<T: Display + Debug> {
|
struct DisplayError<T: Display + Debug> {
|
||||||
err: T,
|
err: T,
|
||||||
|
@ -224,20 +224,20 @@ fn normalize_predicates(
|
|||||||
ast::Statement::Label(id) => result.push(Statement::Label(id)),
|
ast::Statement::Label(id) => result.push(Statement::Label(id)),
|
||||||
ast::Statement::Instruction(pred, inst) => {
|
ast::Statement::Instruction(pred, inst) => {
|
||||||
if let Some(pred) = pred {
|
if let Some(pred) = pred {
|
||||||
let mut if_true = id_def.new_id(None);
|
let if_true = id_def.new_id(None);
|
||||||
let mut if_false = id_def.new_id(None);
|
let if_false = id_def.new_id(None);
|
||||||
if pred.not {
|
|
||||||
std::mem::swap(&mut if_true, &mut if_false);
|
|
||||||
}
|
|
||||||
let folded_bra = match &inst {
|
let folded_bra = match &inst {
|
||||||
ast::Instruction::Bra(_, arg) => Some(arg.src),
|
ast::Instruction::Bra(_, arg) => Some(arg.src),
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
let branch = BrachCondition {
|
let mut branch = BrachCondition {
|
||||||
predicate: pred.label,
|
predicate: pred.label,
|
||||||
if_true: folded_bra.unwrap_or(if_true),
|
if_true: folded_bra.unwrap_or(if_true),
|
||||||
if_false,
|
if_false,
|
||||||
};
|
};
|
||||||
|
if pred.not {
|
||||||
|
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||||
|
}
|
||||||
result.push(Statement::Conditional(branch));
|
result.push(Statement::Conditional(branch));
|
||||||
if folded_bra.is_none() {
|
if folded_bra.is_none() {
|
||||||
result.push(Statement::Label(if_true));
|
result.push(Statement::Label(if_true));
|
||||||
@ -306,9 +306,21 @@ fn insert_mem_ssa_statements(
|
|||||||
result.append(&mut post_statements);
|
result.append(&mut post_statements);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
s @ Statement::Variable(_, _, _)
|
Statement::Conditional(mut bra) => {
|
||||||
| s @ Statement::Label(_)
|
let generated_id = id_def.new_id(Some(ast::Type::ExtendedScalar(
|
||||||
| s @ Statement::Conditional(_) => result.push(s),
|
ast::ExtendedScalarType::Pred,
|
||||||
|
)));
|
||||||
|
result.push(Statement::LoadVar(
|
||||||
|
Arg2 {
|
||||||
|
dst: generated_id,
|
||||||
|
src: bra.predicate,
|
||||||
|
},
|
||||||
|
ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred),
|
||||||
|
));
|
||||||
|
bra.predicate = generated_id;
|
||||||
|
result.push(Statement::Conditional(bra));
|
||||||
|
}
|
||||||
|
s @ Statement::Variable(_, _, _) | s @ Statement::Label(_) => result.push(s),
|
||||||
Statement::LoadVar(_, _)
|
Statement::LoadVar(_, _)
|
||||||
| Statement::StoreVar(_, _)
|
| Statement::StoreVar(_, _)
|
||||||
| Statement::Conversion(_)
|
| Statement::Conversion(_)
|
||||||
@ -378,7 +390,39 @@ impl<'a> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> for FlattenA
|
|||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => todo!(),
|
ast::Operand::RegOffset(reg, offset) => {
|
||||||
|
if let Some(typ) = t {
|
||||||
|
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
|
||||||
|
scalar
|
||||||
|
} else {
|
||||||
|
todo!()
|
||||||
|
};
|
||||||
|
let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
|
||||||
|
self.func.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: id_constant_stmt,
|
||||||
|
typ: scalar_t,
|
||||||
|
value: offset as i128,
|
||||||
|
}));
|
||||||
|
let result_id = self.id_def.new_id(t);
|
||||||
|
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
|
||||||
|
self.func.push(Statement::Instruction(
|
||||||
|
ast::Instruction::<ExpandedArgParams>::Add(
|
||||||
|
ast::AddDetails::Int(ast::AddIntDesc {
|
||||||
|
typ: int_type,
|
||||||
|
saturate: false,
|
||||||
|
}),
|
||||||
|
ast::Arg3 {
|
||||||
|
dst: result_id,
|
||||||
|
src1: reg,
|
||||||
|
src2: id_constant_stmt,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
));
|
||||||
|
result_id
|
||||||
|
} else {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -601,6 +645,12 @@ fn emit_function_body_ops(
|
|||||||
}
|
}
|
||||||
ast::AddDetails::Float(_) => todo!(),
|
ast::AddDetails::Float(_) => todo!(),
|
||||||
},
|
},
|
||||||
|
ast::Instruction::Setp(setp, arg) => {
|
||||||
|
if arg.dst2.is_some() {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
emit_setp(builder, map, setp, arg)?;
|
||||||
|
}
|
||||||
_ => todo!(),
|
_ => todo!(),
|
||||||
},
|
},
|
||||||
Statement::LoadVar(arg, typ) => {
|
Statement::LoadVar(arg, typ) => {
|
||||||
@ -615,6 +665,81 @@ fn emit_function_body_ops(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn emit_setp(
|
||||||
|
builder: &mut dr::Builder,
|
||||||
|
map: &mut TypeWordMap,
|
||||||
|
setp: &ast::SetpData,
|
||||||
|
arg: &ast::Arg4<ExpandedArgParams>,
|
||||||
|
) -> Result<(), dr::Error> {
|
||||||
|
if setp.flush_to_zero {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
let result_type = map.get_or_add(builder, SpirvType::Extended(ast::ExtendedScalarType::Pred));
|
||||||
|
let result_id = Some(arg.dst1);
|
||||||
|
let operand_1 = arg.src1;
|
||||||
|
let operand_2 = arg.src2;
|
||||||
|
match (setp.cmp_op, setp.typ.kind()) {
|
||||||
|
(ast::SetpCompareOp::Eq, ScalarKind::Signed)
|
||||||
|
| (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
|
||||||
|
| (ast::SetpCompareOp::Eq, ScalarKind::Byte) => {
|
||||||
|
builder.i_equal(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::Eq, ScalarKind::Float) => {
|
||||||
|
builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::NotEq, ScalarKind::Signed)
|
||||||
|
| (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
|
||||||
|
| (ast::SetpCompareOp::NotEq, ScalarKind::Byte) => {
|
||||||
|
builder.i_not_equal(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
|
||||||
|
builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::Less, ScalarKind::Unsigned)
|
||||||
|
| (ast::SetpCompareOp::Less, ScalarKind::Byte) => {
|
||||||
|
builder.u_less_than(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::Less, ScalarKind::Signed) => {
|
||||||
|
builder.s_less_than(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::Less, ScalarKind::Float) => {
|
||||||
|
builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
|
||||||
|
| (ast::SetpCompareOp::LessOrEq, ScalarKind::Byte) => {
|
||||||
|
builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
|
||||||
|
builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => {
|
||||||
|
builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
|
||||||
|
| (ast::SetpCompareOp::Greater, ScalarKind::Byte) => {
|
||||||
|
builder.u_greater_than(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
|
||||||
|
builder.s_greater_than(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::Greater, ScalarKind::Float) => {
|
||||||
|
builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
|
||||||
|
| (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Byte) => {
|
||||||
|
builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
|
||||||
|
builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
|
||||||
|
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
|
||||||
|
}
|
||||||
|
_ => todo!(),
|
||||||
|
}?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn emit_mul_int(
|
fn emit_mul_int(
|
||||||
builder: &mut dr::Builder,
|
builder: &mut dr::Builder,
|
||||||
map: &mut TypeWordMap,
|
map: &mut TypeWordMap,
|
||||||
@ -1397,6 +1522,18 @@ impl ast::IntType {
|
|||||||
ast::IntType::U16 | ast::IntType::U32 | ast::IntType::U64 => false,
|
ast::IntType::U16 | ast::IntType::U32 | ast::IntType::U64 => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn try_new(t: ast::ScalarType) -> Option<Self> {
|
||||||
|
match t {
|
||||||
|
ast::ScalarType::U16 => Some(ast::IntType::U16),
|
||||||
|
ast::ScalarType::U32 => Some(ast::IntType::U32),
|
||||||
|
ast::ScalarType::U64 => Some(ast::IntType::U64),
|
||||||
|
ast::ScalarType::S16 => Some(ast::IntType::S16),
|
||||||
|
ast::ScalarType::S32 => Some(ast::IntType::S32),
|
||||||
|
ast::ScalarType::S64 => Some(ast::IntType::S64),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
||||||
|
Reference in New Issue
Block a user