Add sub, min, max

This commit is contained in:
Andrzej Janik
2020-10-02 00:11:28 +02:00
parent bd3d440dba
commit 9a65dd32f5
12 changed files with 820 additions and 181 deletions

View File

@ -241,6 +241,10 @@ sub_scalar_type!(IntType {
S64 S64
}); });
// Unsigned (`UIntType`) and signed (`SIntType`) integer subsets of the scalar
// types, each covering the 8-, 16-, 32- and 64-bit widths.
// NOTE(review): `sub_scalar_type!` is defined outside this view; presumably it
// declares the enum together with its conversion to `ScalarType` — confirm.
sub_scalar_type!(UIntType { U8, U16, U32, U64 });
sub_scalar_type!(SIntType { S8, S16, S32, S64 });
impl IntType { impl IntType {
pub fn is_signed(self) -> bool { pub fn is_signed(self) -> bool {
match self { match self {
@ -331,7 +335,7 @@ pub enum Instruction<P: ArgParams> {
Ld(LdDetails, Arg2Ld<P>), Ld(LdDetails, Arg2Ld<P>),
Mov(MovDetails, Arg2Mov<P>), Mov(MovDetails, Arg2Mov<P>),
Mul(MulDetails, Arg3<P>), Mul(MulDetails, Arg3<P>),
Add(AddDetails, Arg3<P>), Add(ArithDetails, Arg3<P>),
Setp(SetpData, Arg4Setp<P>), Setp(SetpData, Arg4Setp<P>),
SetpBool(SetpBoolData, Arg5<P>), SetpBool(SetpBoolData, Arg5<P>),
Not(NotType, Arg2<P>), Not(NotType, Arg2<P>),
@ -346,6 +350,9 @@ pub enum Instruction<P: ArgParams> {
Abs(AbsDetails, Arg2<P>), Abs(AbsDetails, Arg2<P>),
Mad(MulDetails, Arg4<P>), Mad(MulDetails, Arg4<P>),
Or(OrType, Arg3<P>), Or(OrType, Arg3<P>),
Sub(ArithDetails, Arg3<P>),
Min(MinMaxDetails, Arg3<P>),
Max(MinMaxDetails, Arg3<P>),
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
@ -554,11 +561,6 @@ impl MovDetails {
} }
} }
pub enum MulDetails {
Int(MulIntDesc),
Float(MulFloatDesc),
}
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct MulIntDesc { pub struct MulIntDesc {
pub typ: IntType, pub typ: IntType,
@ -572,14 +574,6 @@ pub enum MulIntControl {
Wide, Wide,
} }
#[derive(Copy, Clone)]
pub struct MulFloatDesc {
pub typ: FloatType,
pub rounding: Option<RoundingMode>,
pub flush_to_zero: bool,
pub saturate: bool,
}
#[derive(PartialEq, Eq, Copy, Clone)] #[derive(PartialEq, Eq, Copy, Clone)]
pub enum RoundingMode { pub enum RoundingMode {
NearestEven, NearestEven,
@ -588,23 +582,11 @@ pub enum RoundingMode {
PositiveInf, PositiveInf,
} }
pub enum AddDetails {
Int(AddIntDesc),
Float(AddFloatDesc),
}
pub struct AddIntDesc { pub struct AddIntDesc {
pub typ: IntType, pub typ: IntType,
pub saturate: bool, pub saturate: bool,
} }
pub struct AddFloatDesc {
pub typ: FloatType,
pub rounding: Option<RoundingMode>,
pub flush_to_zero: bool,
pub saturate: bool,
}
pub struct SetpData { pub struct SetpData {
pub typ: ScalarType, pub typ: ScalarType,
pub flush_to_zero: bool, pub flush_to_zero: bool,
@ -810,3 +792,57 @@ sub_scalar_type!(OrType {
B32, B32,
B64, B64,
}); });
/// Type-dependent modifiers of a multiply-style instruction.
///
/// Carried by both `Instruction::Mul` and `Instruction::Mad`. Integer operands
/// are split by signedness so that each variant only holds a type of the
/// matching sign; floating-point operands share `ArithFloat` with `add`/`sub`.
#[derive(Copy, Clone)]
pub enum MulDetails {
/// Unsigned integer multiply.
Unsigned(MulUInt),
/// Signed integer multiply.
Signed(MulSInt),
/// Floating-point multiply.
Float(ArithFloat),
}
/// Modifiers of an unsigned integer `mul`/`mad`.
#[derive(Copy, Clone)]
pub struct MulUInt {
/// Unsigned operand type.
pub typ: UIntType,
/// Multiplication mode selector (see `MulIntControl`).
pub control: MulIntControl,
}
/// Modifiers of a signed integer `mul`/`mad`.
#[derive(Copy, Clone)]
pub struct MulSInt {
/// Signed operand type.
pub typ: SIntType,
/// Multiplication mode selector (see `MulIntControl`).
pub control: MulIntControl,
}
/// Type-dependent modifiers shared by the `add` and `sub` instructions.
#[derive(Copy, Clone)]
pub enum ArithDetails {
/// Unsigned integer operands; there are no further modifiers, so only the
/// type is stored.
Unsigned(UIntType),
/// Signed integer operands, optionally saturating.
Signed(ArithSInt),
/// Floating-point operands.
Float(ArithFloat),
}
/// Modifiers of a signed integer `add`/`sub`.
#[derive(Copy, Clone)]
pub struct ArithSInt {
/// Signed operand type.
pub typ: SIntType,
/// Whether `.sat` was requested. The grammar only produces `true` for the
/// `.sat .s32` form.
pub saturate: bool,
}
/// Modifiers of a floating-point `add`, `sub`, `mul` or `mad`.
#[derive(Copy, Clone)]
pub struct ArithFloat {
/// Floating-point operand type.
pub typ: FloatType,
/// Explicit rounding modifier, or `None` when the instruction had none.
pub rounding: Option<RoundingMode>,
/// Whether `.ftz` was present.
pub flush_to_zero: bool,
/// Whether `.sat` was present.
pub saturate: bool,
}
/// Type-dependent modifiers shared by the `min` and `max` instructions.
#[derive(Copy, Clone)]
pub enum MinMaxDetails {
/// Signed integer operands.
Signed(SIntType),
/// Unsigned integer operands.
Unsigned(UIntType),
/// Floating-point operands together with their `.ftz`/`.NaN` modifiers.
Float(MinMaxFloat),
}
/// Modifiers of a floating-point `min`/`max`.
#[derive(Copy, Clone)]
pub struct MinMaxFloat {
/// Whether `.ftz` was present.
pub ftz: bool,
/// Whether `.NaN` was present.
pub nan: bool,
/// Floating-point operand type.
pub typ: FloatType,
}

View File

@ -70,6 +70,7 @@ match {
".ltu", ".ltu",
".lu", ".lu",
".nan", ".nan",
".NaN",
".ne", ".ne",
".neu", ".neu",
".num", ".num",
@ -124,6 +125,8 @@ match {
"ld", "ld",
"mad", "mad",
"map_f64_to_f32", "map_f64_to_f32",
"max",
"min",
"mov", "mov",
"mul", "mul",
"not", "not",
@ -134,6 +137,7 @@ match {
"shr", "shr",
r"sm_[0-9]+" => ShaderModel, r"sm_[0-9]+" => ShaderModel,
"st", "st",
"sub",
"texmode_independent", "texmode_independent",
"texmode_unified", "texmode_unified",
} else { } else {
@ -153,6 +157,8 @@ ExtendedID : &'input str = {
"ld", "ld",
"mad", "mad",
"map_f64_to_f32", "map_f64_to_f32",
"max",
"min",
"mov", "mov",
"mul", "mul",
"not", "not",
@ -163,6 +169,7 @@ ExtendedID : &'input str = {
"shr", "shr",
ShaderModel, ShaderModel,
"st", "st",
"sub",
"texmode_independent", "texmode_independent",
"texmode_unified", "texmode_unified",
ID ID
@ -448,7 +455,10 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstCall, InstCall,
InstAbs, InstAbs,
InstMad, InstMad,
InstOr InstOr,
InstSub,
InstMin,
InstMax,
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@ -570,38 +580,19 @@ MovVectorType: ast::ScalarType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
InstMul: ast::Instruction<ast::ParsedArgParams<'input>> = { InstMul: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mul" <d:InstMulMode> <a:Arg3> => ast::Instruction::Mul(d, a) "mul" <d:MulDetails> <a:Arg3> => ast::Instruction::Mul(d, a)
}; };
InstMulMode: ast::MulDetails = { MulDetails: ast::MulDetails = {
<ctr:MulIntControl> <t:IntType> => ast::MulDetails::Int(ast::MulIntDesc { <ctr:MulIntControl> <t:UIntType> => ast::MulDetails::Unsigned(ast::MulUInt{
typ: t, typ: t,
control: ctr control: ctr
}), }),
<r:RoundingModeFloat?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulDetails::Float(ast::MulFloatDesc { <ctr:MulIntControl> <t:SIntType> => ast::MulDetails::Signed(ast::MulSInt{
typ: ast::FloatType::F32, typ: t,
rounding: r, control: ctr
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
}), }),
<r:RoundingModeFloat?> ".f64" => ast::MulDetails::Float(ast::MulFloatDesc { <f:ArithFloat> => ast::MulDetails::Float(f)
typ: ast::FloatType::F64,
rounding: r,
flush_to_zero: false,
saturate: false
}),
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16" => ast::MulDetails::Float(ast::MulFloatDesc {
typ: ast::FloatType::F16,
rounding: r.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
}),
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16x2" => ast::MulDetails::Float(ast::MulFloatDesc {
typ: ast::FloatType::F16x2,
rounding: r.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
})
}; };
MulIntControl: ast::MulIntControl = { MulIntControl: ast::MulIntControl = {
@ -634,41 +625,23 @@ IntType : ast::IntType = {
".s64" => ast::IntType::S64, ".s64" => ast::IntType::S64,
}; };
// Unsigned integer type suffix. Only the 16-, 32- and 64-bit suffixes are
// accepted here; `ast::UIntType::U8` has no alternative in this rule.
UIntType: ast::UIntType = {
".u16" => ast::UIntType::U16,
".u32" => ast::UIntType::U32,
".u64" => ast::UIntType::U64,
};
// Signed integer type suffix. Only the 16-, 32- and 64-bit suffixes are
// accepted here; `ast::SIntType::S8` has no alternative in this rule.
SIntType: ast::SIntType = {
".s16" => ast::SIntType::S16,
".s32" => ast::SIntType::S32,
".s64" => ast::SIntType::S64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add
InstAdd: ast::Instruction<ast::ParsedArgParams<'input>> = { InstAdd: ast::Instruction<ast::ParsedArgParams<'input>> = {
"add" <d:InstAddMode> <a:Arg3> => ast::Instruction::Add(d, a) "add" <d:ArithDetails> <a:Arg3> => ast::Instruction::Add(d, a)
};
InstAddMode: ast::AddDetails = {
<t:IntType> => ast::AddDetails::Int(ast::AddIntDesc {
typ: t,
saturate: false,
}),
".sat" ".s32" => ast::AddDetails::Int(ast::AddIntDesc {
typ: ast::IntType::S32,
saturate: true,
}),
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::AddDetails::Float(ast::AddFloatDesc {
typ: ast::FloatType::F32,
rounding: rn,
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
}),
<rn:RoundingModeFloat?> ".f64" => ast::AddDetails::Float(ast::AddFloatDesc {
typ: ast::FloatType::F64,
rounding: rn,
flush_to_zero: false,
saturate: false,
}),
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?>".f16" => ast::AddDetails::Float(ast::AddFloatDesc {
typ: ast::FloatType::F16,
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
}),
".rn"? ".ftz"? ".sat"? ".f16x2" => todo!()
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
@ -1041,7 +1014,7 @@ InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad
InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = { InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mad" <d:InstMulMode> <a:Arg4> => ast::Instruction::Mad(d, a), "mad" <d:MulDetails> <a:Arg4> => ast::Instruction::Mad(d, a),
"mad" ".hi" ".sat" ".s32" => todo!() "mad" ".hi" ".sat" ".s32" => todo!()
}; };
@ -1063,6 +1036,84 @@ OrType: ast::OrType = {
".b64" => ast::OrType::B64, ".b64" => ast::OrType::B64,
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub
// `sub` reuses the `ArithDetails` modifier rule of `add` and takes three operands.
InstSub: ast::Instruction<ast::ParsedArgParams<'input>> = {
"sub" <d:ArithDetails> <a:Arg3> => ast::Instruction::Sub(d, a),
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min
// `min` shares the `MinMaxDetails` modifier rule with `max` and takes three operands.
InstMin: ast::Instruction<ast::ParsedArgParams<'input>> = {
"min" <d:MinMaxDetails> <a:Arg3> => ast::Instruction::Min(d, a),
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max
// `max` shares the `MinMaxDetails` modifier rule with `min` and takes three operands.
InstMax: ast::Instruction<ast::ParsedArgParams<'input>> = {
"max" <d:MinMaxDetails> <a:Arg3> => ast::Instruction::Max(d, a),
};
// Modifiers accepted after `min`/`max`.
// Integer forms carry only their type. The `.f32`, `.f16` and `.f16x2` forms
// may be preceded by optional `.ftz` and `.NaN` (in that order); `.f64` takes
// no modifiers, so both flags are recorded as false for it.
MinMaxDetails: ast::MinMaxDetails = {
<t:UIntType> => ast::MinMaxDetails::Unsigned(t),
<t:SIntType> => ast::MinMaxDetails::Signed(t),
<ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F32 }
),
".f64" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ ftz: false, nan: false, typ: ast::FloatType::F64 }
),
<ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16 }
),
<ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16x2 }
)
}
// Modifiers accepted after `add`/`sub`.
// Unsigned types carry no modifiers. Signed types are non-saturating unless
// written as `.sat .s32`, which is the only saturating integer form accepted.
// Everything else is delegated to `ArithFloat`.
ArithDetails: ast::ArithDetails = {
<t:UIntType> => ast::ArithDetails::Unsigned(t),
<t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt {
typ: t,
saturate: false,
}),
".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::S32,
saturate: true,
}),
<f:ArithFloat> => ast::ArithDetails::Float(f)
}
// Floating-point modifiers shared by `add`, `sub`, `mul` and `mad`.
// Modifier order is fixed: rounding, then `.ftz`, then `.sat`, then the type.
ArithFloat: ast::ArithFloat = {
// .f32 accepts any float rounding mode plus .ftz and .sat.
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
typ: ast::FloatType::F32,
rounding: rn,
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
},
// .f64 accepts only a rounding mode; .ftz and .sat are not part of this form.
<rn:RoundingModeFloat?> ".f64" => ast::ArithFloat {
typ: ast::FloatType::F64,
rounding: rn,
flush_to_zero: false,
saturate: false,
},
// Half-precision forms accept only `.rn` as rounding, recorded as NearestEven.
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
typ: ast::FloatType::F16,
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
},
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
typ: ast::FloatType::F16x2,
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
},
}
Operand: ast::Operand<&'input str> = { Operand: ast::Operand<&'input str> = {
<r:ExtendedID> => ast::Operand::Reg(r), <r:ExtendedID> => ast::Operand::Reg(r),
<r:ExtendedID> "+" <o:Num> => { <r:ExtendedID> "+" <o:Num> => {

View File

@ -0,0 +1,23 @@
.version 6.5
.target sm_30
.address_size 64
// Test kernel for `max.s32`: reads two consecutive signed 32-bit values from
// `input` and stores the larger one to `output`.
.visible .entry max(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .s32 temp1;
.reg .s32 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.s32 temp1, [in_addr];
// Second operand sits 4 bytes after the first.
ld.s32 temp2, [in_addr+4];
max.s32 temp1, temp1, temp2;
st.s32 [out_addr], temp1;
ret;
}

View File

@ -0,0 +1,57 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%30 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "max"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%33 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
%1 = OpFunction %void None %33
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%28 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_uint Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%25 = OpConvertUToPtr %_ptr_Generic_uint %15
%14 = OpLoad %uint %25
OpStore %6 %14
%17 = OpLoad %ulong %4
%24 = OpIAdd %ulong %17 %ulong_4
%26 = OpConvertUToPtr %_ptr_Generic_uint %24
%16 = OpLoad %uint %26
OpStore %7 %16
%19 = OpLoad %uint %6
%20 = OpLoad %uint %7
%18 = OpExtInst %uint %30 s_max %19 %20
OpStore %6 %18
%21 = OpLoad %ulong %5
%22 = OpLoad %uint %6
%27 = OpConvertUToPtr %_ptr_Generic_uint %21
OpStore %27 %22
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,23 @@
.version 6.5
.target sm_30
.address_size 64
// Test kernel for `min.s32`: reads two consecutive signed 32-bit values from
// `input` and stores the smaller one to `output`.
.visible .entry min(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .s32 temp1;
.reg .s32 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.s32 temp1, [in_addr];
// Second operand sits 4 bytes after the first.
ld.s32 temp2, [in_addr+4];
min.s32 temp1, temp1, temp2;
st.s32 [out_addr], temp1;
ret;
}

View File

@ -0,0 +1,57 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%30 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "min"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%33 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
%1 = OpFunction %void None %33
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%28 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_uint Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%25 = OpConvertUToPtr %_ptr_Generic_uint %15
%14 = OpLoad %uint %25
OpStore %6 %14
%17 = OpLoad %ulong %4
%24 = OpIAdd %ulong %17 %ulong_4
%26 = OpConvertUToPtr %_ptr_Generic_uint %24
%16 = OpLoad %uint %26
OpStore %7 %16
%19 = OpLoad %uint %6
%20 = OpLoad %uint %7
%18 = OpExtInst %uint %30 s_min %19 %20
OpStore %6 %18
%21 = OpLoad %ulong %5
%22 = OpLoad %uint %6
%27 = OpConvertUToPtr %_ptr_Generic_uint %21
OpStore %27 %22
OpReturn
OpFunctionEnd

View File

@ -70,6 +70,9 @@ test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64])
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]); test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
test_ptx!(shr, [-2i32], [-1i32]); test_ptx!(shr, [-2i32], [-1i32]);
test_ptx!(or, [1u64, 2u64], [3u64]); test_ptx!(or, [1u64, 2u64], [3u64]);
// Tests for the instructions added alongside `sub`/`min`/`max`; the first
// argument matches the entry-point name of the corresponding PTX kernel.
// NOTE(review): `test_ptx!` is defined outside this view; the two arrays look
// like kernel input and expected output (e.g. 2 - 1 == 1) — confirm.
test_ptx!(sub, [2u64], [1u64]);
test_ptx!(min, [555i32, 444i32], [444i32]);
test_ptx!(max, [555i32, 444i32], [555i32]);
struct DisplayError<T: Debug> { struct DisplayError<T: Debug> {

View File

@ -0,0 +1,23 @@
.version 6.5
.target sm_30
.address_size 64
// Test kernel for `or.b64`: reads two consecutive 64-bit values from `input`
// and stores their bitwise OR to `output`.
.visible .entry or(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp1;
.reg .u64 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u64 temp1, [in_addr];
// Second operand sits 8 bytes after the first.
ld.u64 temp2, [in_addr+8];
or.b64 temp1, temp1, temp2;
st.u64 [out_addr], temp1;
ret;
}

View File

@ -0,0 +1,58 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%33 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "or"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%36 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_8 = OpConstant %ulong 8
%1 = OpFunction %void None %36
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%31 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%25 = OpConvertUToPtr %_ptr_Generic_ulong %15
%14 = OpLoad %ulong %25
OpStore %6 %14
%17 = OpLoad %ulong %4
%24 = OpIAdd %ulong %17 %ulong_8
%26 = OpConvertUToPtr %_ptr_Generic_ulong %24
%16 = OpLoad %ulong %26
OpStore %7 %16
%19 = OpLoad %ulong %6
%20 = OpLoad %ulong %7
%28 = OpCopyObject %ulong %19
%29 = OpCopyObject %ulong %20
%27 = OpBitwiseOr %ulong %28 %29
%18 = OpCopyObject %ulong %27
OpStore %6 %18
%21 = OpLoad %ulong %5
%22 = OpLoad %ulong %6
%30 = OpConvertUToPtr %_ptr_Generic_ulong %21
OpStore %30 %22
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
// Test kernel for `sub.u64`: reads one unsigned 64-bit value from `input`,
// subtracts the immediate 1 and stores the result to `output`.
.visible .entry sub(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp;
.reg .u64 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u64 temp, [in_addr];
sub.u64 temp2, temp, 1;
st.u64 [out_addr], temp2;
ret;
}

View File

@ -0,0 +1,49 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%25 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "sub"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%28 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_1 = OpConstant %ulong 1
%1 = OpFunction %void None %28
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%23 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
%14 = OpLoad %ulong %21
OpStore %6 %14
%17 = OpLoad %ulong %6
%16 = OpISub %ulong %17 %ulong_1
OpStore %7 %16
%18 = OpLoad %ulong %5
%19 = OpLoad %ulong %7
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
OpStore %22 %19
OpReturn
OpFunctionEnd

View File

@ -595,6 +595,15 @@ fn convert_to_typed_statements(
ast::Instruction::Or(d, a) => { ast::Instruction::Or(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast()))) result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast())))
} }
ast::Instruction::Sub(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Sub(d, a.cast())))
}
ast::Instruction::Min(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Min(d, a.cast())))
}
ast::Instruction::Max(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast())))
}
}, },
Statement::Label(i) => result.push(Statement::Label(i)), Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)), Statement::Variable(v) => result.push(Statement::Variable(v)),
@ -968,62 +977,74 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn reg_offset( fn reg_offset(
&mut self, &mut self,
desc: ArgumentDescriptor<(spirv::Word, i32)>, desc: ArgumentDescriptor<(spirv::Word, i32)>,
typ: ast::Type, mut typ: ast::Type,
) -> Result<spirv::Word, TranslateError> { ) -> Result<spirv::Word, TranslateError> {
let (reg, offset) = desc.op; let (reg, offset) = desc.op;
match desc.sema { match desc.sema {
ArgumentSemantics::Default | ArgumentSemantics::DefaultRelaxed => { ArgumentSemantics::Default
let scalar_t = if let ast::Type::Scalar(scalar) = typ { | ArgumentSemantics::DefaultRelaxed
scalar | ArgumentSemantics::PhysicalPointer => {
} else { if desc.sema == ArgumentSemantics::PhysicalPointer {
todo!() typ = ast::Type::Scalar(ast::ScalarType::U64);
}
let (width, kind) = match typ {
ast::Type::Scalar(scalar_t) => {
let kind = match scalar_t.kind() {
kind @ ScalarKind::Bit
| kind @ ScalarKind::Unsigned
| kind @ ScalarKind::Signed => kind,
ScalarKind::Float => return Err(TranslateError::MismatchedType),
ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
ScalarKind::Pred => return Err(TranslateError::MismatchedType),
};
(scalar_t.width(), kind)
}
_ => return Err(TranslateError::MismatchedType),
}; };
let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t)); let arith_detail = if kind == ScalarKind::Signed {
ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::from_size(width),
saturate: false,
})
} else {
ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
};
let id_constant_stmt = self.id_def.new_id(typ);
let result_id = self.id_def.new_id(typ); let result_id = self.id_def.new_id(typ);
self.func.push(Statement::Constant(ConstantDefinition { // TODO: check for edge cases around min value/max value/wrapping
dst: id_constant_stmt, if offset < 0 && kind != ScalarKind::Signed {
typ: scalar_t, self.func.push(Statement::Constant(ConstantDefinition {
value: offset as i64, dst: id_constant_stmt,
})); typ: ast::ScalarType::from_parts(width, kind),
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); value: -(offset as i64),
self.func.push(Statement::Instruction( }));
ast::Instruction::<ExpandedArgParams>::Add( self.func.push(Statement::Instruction(
ast::AddDetails::Int(ast::AddIntDesc { ast::Instruction::<ExpandedArgParams>::Sub(
typ: int_type, arith_detail,
saturate: false, ast::Arg3 {
}), dst: result_id,
ast::Arg3 { src1: reg,
dst: result_id, src2: id_constant_stmt,
src1: reg, },
src2: id_constant_stmt, ),
}, ));
), } else {
)); self.func.push(Statement::Constant(ConstantDefinition {
Ok(result_id) dst: id_constant_stmt,
} typ: ast::ScalarType::from_parts(width, kind),
ArgumentSemantics::PhysicalPointer => { value: offset as i64,
let scalar_t = ast::ScalarType::U64; }));
let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t)); self.func.push(Statement::Instruction(
let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t)); ast::Instruction::<ExpandedArgParams>::Add(
self.func.push(Statement::Constant(ConstantDefinition { arith_detail,
dst: id_constant_stmt, ast::Arg3 {
typ: scalar_t, dst: result_id,
value: offset as i64, src1: reg,
})); src2: id_constant_stmt,
let int_type = ast::IntType::U64; },
self.func.push(Statement::Instruction( ),
ast::Instruction::<ExpandedArgParams>::Add( ));
ast::AddDetails::Int(ast::AddIntDesc { }
typ: int_type,
saturate: false,
}),
ast::Arg3 {
dst: result_id,
src1: reg,
src2: id_constant_stmt,
},
),
));
Ok(result_id) Ok(result_id)
} }
ArgumentSemantics::RegisterPointer => { ArgumentSemantics::RegisterPointer => {
@ -1522,14 +1543,22 @@ fn emit_function_body_ops(
} }
}, },
ast::Instruction::Mul(mul, arg) => match mul { ast::Instruction::Mul(mul, arg) => match mul {
ast::MulDetails::Int(ref ctr) => { ast::MulDetails::Signed(ref ctr) => {
emit_mul_int(builder, map, opencl, ctr, arg)?; emit_mul_sint(builder, map, opencl, ctr, arg)?
}
ast::MulDetails::Unsigned(ref ctr) => {
emit_mul_uint(builder, map, opencl, ctr, arg)?
} }
ast::MulDetails::Float(_) => todo!(), ast::MulDetails::Float(_) => todo!(),
}, },
ast::Instruction::Add(add, arg) => match add { ast::Instruction::Add(add, arg) => match add {
ast::AddDetails::Int(ref desc) => emit_add_int(builder, map, desc, arg)?, ast::ArithDetails::Signed(ref desc) => {
ast::AddDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?, emit_add_int(builder, map, desc.typ.into(), desc.saturate, arg)?
}
ast::ArithDetails::Unsigned(ref desc) => {
emit_add_int(builder, map, (*desc).into(), false, arg)?
}
ast::ArithDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?,
}, },
ast::Instruction::Setp(setp, arg) => { ast::Instruction::Setp(setp, arg) => {
if arg.dst2.is_some() { if arg.dst2.is_some() {
@ -1581,8 +1610,11 @@ fn emit_function_body_ops(
} }
ast::Instruction::SetpBool(_, _) => todo!(), ast::Instruction::SetpBool(_, _) => todo!(),
ast::Instruction::Mad(mad, arg) => match mad { ast::Instruction::Mad(mad, arg) => match mad {
ast::MulDetails::Int(ref desc) => { ast::MulDetails::Signed(ref desc) => {
emit_mad_int(builder, map, opencl, desc, arg)? emit_mad_sint(builder, map, opencl, desc, arg)?
}
ast::MulDetails::Unsigned(ref desc) => {
emit_mad_uint(builder, map, opencl, desc, arg)?
} }
ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?, ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?,
}, },
@ -1594,6 +1626,23 @@ fn emit_function_body_ops(
builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?; builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
} }
} }
ast::Instruction::Sub(d, arg) => match d {
ast::ArithDetails::Signed(desc) => {
emit_sub_int(builder, map, desc.typ.into(), desc.saturate, arg)?;
}
ast::ArithDetails::Unsigned(desc) => {
emit_sub_int(builder, map, (*desc).into(), false, arg)?;
}
ast::ArithDetails::Float(desc) => {
emit_sub_float(builder, map, desc, arg)?;
}
},
ast::Instruction::Min(d, a) => {
emit_min(builder, map, opencl, d, a)?;
}
ast::Instruction::Max(d, a) => {
emit_max(builder, map, opencl, d, a)?;
}
}, },
Statement::LoadVar(arg, typ) => { Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ)); let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@ -1624,11 +1673,11 @@ fn emit_function_body_ops(
Ok(()) Ok(())
} }
fn emit_mad_int( fn emit_mad_uint(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
opencl: spirv::Word, opencl: spirv::Word,
desc: &ast::MulIntDesc, desc: &ast::MulUInt,
arg: &ast::Arg4<ExpandedArgParams>, arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> { ) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
@ -1638,16 +1687,38 @@ fn emit_mad_int(
builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?; builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
} }
ast::MulIntControl::High => { ast::MulIntControl::High => {
let cl_op = if desc.typ.is_signed() {
spirv::CLOp::s_mad_hi
} else {
spirv::CLOp::u_mad_hi
};
builder.ext_inst( builder.ext_inst(
inst_type, inst_type,
Some(arg.dst), Some(arg.dst),
opencl, opencl,
cl_op as spirv::Word, spirv::CLOp::u_mad_hi as spirv::Word,
[arg.src1, arg.src2, arg.src3],
)?;
}
ast::MulIntControl::Wide => todo!(),
};
Ok(())
}
fn emit_mad_sint(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
desc: &ast::MulSInt,
arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
match desc.control {
ast::MulIntControl::Low => {
let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?;
builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
}
ast::MulIntControl::High => {
builder.ext_inst(
inst_type,
Some(arg.dst),
opencl,
spirv::CLOp::s_mad_hi as spirv::Word,
[arg.src1, arg.src2, arg.src3], [arg.src1, arg.src2, arg.src3],
)?; )?;
} }
@ -1659,7 +1730,7 @@ fn emit_mad_int(
fn emit_mad_float( fn emit_mad_float(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
desc: &ast::MulFloatDesc, desc: &ast::ArithFloat,
arg: &ast::Arg4<ExpandedArgParams>, arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> { ) -> Result<(), dr::Error> {
todo!() todo!()
@ -1668,7 +1739,7 @@ fn emit_mad_float(
fn emit_add_float( fn emit_add_float(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
desc: &ast::AddFloatDesc, desc: &ast::ArithFloat,
arg: &ast::Arg3<ExpandedArgParams>, arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> { ) -> Result<(), dr::Error> {
if desc.flush_to_zero { if desc.flush_to_zero {
@ -1680,6 +1751,67 @@ fn emit_add_float(
Ok(()) Ok(())
} }
/// Emits SPIR-V for a floating-point `sub`.
///
/// Writes an `OpFSub` of `arg.src1` and `arg.src2` into `arg.dst`, typed after
/// `desc.typ`, and then applies the rounding decoration requested by
/// `desc.rounding` to the result id.
///
/// # Panics
/// Panics via `todo!()` when `desc.flush_to_zero` is set; that mode is not
/// implemented.
///
/// NOTE(review): `desc.saturate` is not consulted here.
fn emit_sub_float(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    desc: &ast::ArithFloat,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    if desc.flush_to_zero {
        todo!()
    }
    let scalar_type = ast::ScalarType::from(desc.typ);
    let result_type = map.get_or_add(builder, SpirvType::from(scalar_type));
    let result_id = arg.dst;
    builder.f_sub(result_type, Some(result_id), arg.src1, arg.src2)?;
    emit_rounding_decoration(builder, result_id, desc.rounding);
    Ok(())
}
/// Emits a `min` instruction as an OpenCL extended instruction.
///
/// Signed integers map to `s_min`, unsigned integers to `u_min` and floats to
/// `fmin`. The result type is derived from `desc.get_type()` and the result is
/// written to `arg.dst` from `arg.src1` and `arg.src2`.
///
/// # Panics
/// Panics via `todo!()` for a float `min` carrying the `.NaN` modifier.
/// OpenCL `fmin` returns the non-NaN operand when exactly one input is NaN,
/// whereas `.NaN` requires a NaN result, so lowering it to `fmin` would
/// silently produce wrong values.
///
/// NOTE(review): the `.ftz` flag is not consulted here.
fn emit_min(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::MinMaxDetails,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let cl_op = match desc {
        ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min,
        ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
        ast::MinMaxDetails::Float(float_desc) => {
            // NaN-propagating min has no direct OpenCL.std equivalent; refuse
            // rather than miscompile.
            if float_desc.nan {
                todo!()
            }
            spirv::CLOp::fmin
        }
    };
    let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
    builder.ext_inst(
        inst_type,
        Some(arg.dst),
        opencl,
        cl_op as spirv::Word,
        [arg.src1, arg.src2],
    )?;
    Ok(())
}
/// Emits `max` as an extended instruction via `builder.ext_inst`.
///
/// The `spirv::CLOp` opcode is chosen from the variant of `desc`:
/// `s_max` for `Signed`, `u_max` for `Unsigned` and `fmax` for `Float`.
/// The result type comes from `desc.get_type()`, the result id is
/// `arg.dst` and the operands are `arg.src1` and `arg.src2`.
///
/// `opencl` is forwarded unchanged as the extended-instruction-set argument
/// of `ext_inst` (presumably the id of the imported OpenCL set; confirm
/// against the caller).
///
/// # Errors
///
/// Propagates any `dr::Error` reported by `builder.ext_inst`.
fn emit_max(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::MinMaxDetails,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    let result_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
    // Only the variant matters here; its payload is not inspected.
    let selected_op = match desc {
        ast::MinMaxDetails::Float(..) => spirv::CLOp::fmax,
        ast::MinMaxDetails::Unsigned(..) => spirv::CLOp::u_max,
        ast::MinMaxDetails::Signed(..) => spirv::CLOp::s_max,
    };
    let opcode = selected_op as spirv::Word;
    let operands = [arg.src1, arg.src2];
    builder.ext_inst(result_type, Some(arg.dst), opencl, opcode, operands)?;
    Ok(())
}
fn emit_cvt( fn emit_cvt(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
@ -1880,11 +2012,11 @@ fn emit_setp(
Ok(()) Ok(())
} }
fn emit_mul_int( fn emit_mul_sint(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
opencl: spirv::Word, opencl: spirv::Word,
desc: &ast::MulIntDesc, desc: &ast::MulSInt,
arg: &ast::Arg3<ExpandedArgParams>, arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> { ) -> Result<(), dr::Error> {
let instruction_type = ast::ScalarType::from(desc.typ); let instruction_type = ast::ScalarType::from(desc.typ);
@ -1894,16 +2026,11 @@ fn emit_mul_int(
builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?; builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
} }
ast::MulIntControl::High => { ast::MulIntControl::High => {
let ocl_mul_hi = if desc.typ.is_signed() {
spirv::CLOp::s_mul_hi
} else {
spirv::CLOp::u_mul_hi
};
builder.ext_inst( builder.ext_inst(
inst_type, inst_type,
Some(arg.dst), Some(arg.dst),
opencl, opencl,
ocl_mul_hi as spirv::Word, spirv::CLOp::s_mul_hi as spirv::Word,
[arg.src1, arg.src2], [arg.src1, arg.src2],
)?; )?;
} }
@ -1913,11 +2040,54 @@ fn emit_mul_int(
SpirvScalarKey::from(instruction_type), SpirvScalarKey::from(instruction_type),
]); ]);
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type); let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
let mul = if desc.typ.is_signed() { let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)? let instr_width = instruction_type.width();
} else { let instr_kind = instruction_type.kind();
builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)? let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
}; let dst_type_id = map.get_or_add_scalar(builder, dst_type);
struct2_bitcast_to_wide(
builder,
map,
SpirvScalarKey::from(instruction_type),
inst_type,
arg.dst,
dst_type_id,
mul,
)?;
}
}
Ok(())
}
fn emit_mul_uint(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
desc: &ast::MulUInt,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let instruction_type = ast::ScalarType::from(desc.typ);
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
match desc.control {
ast::MulIntControl::Low => {
builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
}
ast::MulIntControl::High => {
builder.ext_inst(
inst_type,
Some(arg.dst),
opencl,
spirv::CLOp::u_mul_hi as spirv::Word,
[arg.src1, arg.src2],
)?;
}
ast::MulIntControl::Wide => {
let mul_ext_type = SpirvType::Struct(vec![
SpirvScalarKey::from(instruction_type),
SpirvScalarKey::from(instruction_type),
]);
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
let instr_width = instruction_type.width(); let instr_width = instruction_type.width();
let instr_kind = instruction_type.kind(); let instr_kind = instruction_type.kind();
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
@ -1981,14 +2151,33 @@ fn emit_abs(
fn emit_add_int( fn emit_add_int(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
ctr: &ast::AddIntDesc, typ: ast::ScalarType,
saturate: bool,
arg: &ast::Arg3<ExpandedArgParams>, arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> { ) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(ctr.typ))); if saturate {
todo!()
}
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
Ok(()) Ok(())
} }
/// Emits the integer form of `sub` through `builder.i_sub`.
///
/// The result type is looked up (or registered) in `map` from `typ`, the
/// result id is `arg.dst` and the operands are `arg.src1` and `arg.src2`,
/// in that order.
///
/// # Errors
///
/// Propagates any `dr::Error` reported by `builder.i_sub`.
///
/// # Panics
///
/// Saturating subtraction is not implemented yet: `saturate == true` hits
/// `todo!()`.
fn emit_sub_int(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    typ: ast::ScalarType,
    saturate: bool,
    arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
    if saturate {
        // Bail out before anything is written to the builder.
        todo!()
    }
    // `typ` is already a scalar type, so it converts to `SpirvType` directly.
    let result_type = map.get_or_add(builder, SpirvType::from(typ));
    let result_id = Some(arg.dst);
    builder.i_sub(result_type, result_id, arg.src1, arg.src2)?;
    Ok(())
}
fn emit_implicit_conversion( fn emit_implicit_conversion(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
@ -2920,6 +3109,18 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
t, t,
a.map_non_shift(visitor, ast::Type::Scalar(t.into()), false)?, a.map_non_shift(visitor, ast::Type::Scalar(t.into()), false)?,
), ),
ast::Instruction::Sub(d, a) => {
let typ = d.get_type();
ast::Instruction::Sub(d, a.map_non_shift(visitor, typ, false)?)
}
ast::Instruction::Min(d, a) => {
let typ = d.get_type();
ast::Instruction::Min(d, a.map_non_shift(visitor, typ, false)?)
}
ast::Instruction::Max(d, a) => {
let typ = d.get_type();
ast::Instruction::Max(d, a.map_non_shift(visitor, typ, false)?)
}
}) })
} }
} }
@ -3129,6 +3330,9 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::Abs(_, _) | ast::Instruction::Abs(_, _)
| ast::Instruction::Call(_) | ast::Instruction::Call(_)
| ast::Instruction::Or(_, _) | ast::Instruction::Or(_, _)
| ast::Instruction::Sub(_, _)
| ast::Instruction::Min(_, _)
| ast::Instruction::Max(_, _)
| ast::Instruction::Mad(_, _) => None, | ast::Instruction::Mad(_, _) => None,
} }
} }
@ -4049,25 +4253,33 @@ impl ast::ShrType {
} }
} }
impl ast::AddDetails { impl ast::ArithDetails {
fn get_type(&self) -> ast::Type { fn get_type(&self) -> ast::Type {
match self { ast::Type::Scalar(match self {
ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()), ast::ArithDetails::Unsigned(t) => (*t).into(),
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => { ast::ArithDetails::Signed(d) => d.typ.into(),
ast::Type::Scalar((*typ).into()) ast::ArithDetails::Float(d) => d.typ.into(),
} })
}
} }
} }
impl ast::MulDetails { impl ast::MulDetails {
fn get_type(&self) -> ast::Type { fn get_type(&self) -> ast::Type {
match self { ast::Type::Scalar(match self {
ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()), ast::MulDetails::Unsigned(d) => d.typ.into(),
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => { ast::MulDetails::Signed(d) => d.typ.into(),
ast::Type::Scalar((*typ).into()) ast::MulDetails::Float(d) => d.typ.into(),
} })
} }
}
impl ast::MinMaxDetails {
fn get_type(&self) -> ast::Type {
ast::Type::Scalar(match self {
ast::MinMaxDetails::Signed(t) => (*t).into(),
ast::MinMaxDetails::Unsigned(t) => (*t).into(),
ast::MinMaxDetails::Float(d) => d.typ.into(),
})
} }
} }
@ -4085,6 +4297,30 @@ impl ast::IntType {
} }
} }
impl ast::SIntType {
fn from_size(width: u8) -> Self {
match width {
1 => ast::SIntType::S8,
2 => ast::SIntType::S16,
4 => ast::SIntType::S32,
8 => ast::SIntType::S64,
_ => unreachable!(),
}
}
}
impl ast::UIntType {
fn from_size(width: u8) -> Self {
match width {
1 => ast::UIntType::U8,
2 => ast::UIntType::U16,
4 => ast::UIntType::U32,
8 => ast::UIntType::U64,
_ => unreachable!(),
}
}
}
impl ast::LdStateSpace { impl ast::LdStateSpace {
fn to_spirv(self) -> spirv::StorageClass { fn to_spirv(self) -> spirv::StorageClass {
match self { match self {
@ -4128,7 +4364,8 @@ impl<T> ast::OperandOrVector<T> {
impl ast::MulDetails { impl ast::MulDetails {
fn is_wide(&self) -> bool { fn is_wide(&self) -> bool {
match self { match self {
ast::MulDetails::Int(desc) => desc.control == ast::MulIntControl::Wide, ast::MulDetails::Unsigned(d) => d.control == ast::MulIntControl::Wide,
ast::MulDetails::Signed(d) => d.control == ast::MulIntControl::Wide,
ast::MulDetails::Float(_) => false, ast::MulDetails::Float(_) => false,
} }
} }