mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-17 09:16:23 +03:00
Implement ftz handling through Khronos extensions
This commit is contained in:
@ -518,13 +518,13 @@ pub struct MadFloatDesc {}
|
|||||||
|
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct AbsDetails {
|
pub struct AbsDetails {
|
||||||
pub flush_to_zero: bool,
|
pub flush_to_zero: Option<bool>,
|
||||||
pub typ: ScalarType,
|
pub typ: ScalarType,
|
||||||
}
|
}
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct RcpDetails {
|
pub struct RcpDetails {
|
||||||
pub rounding: Option<RoundingMode>,
|
pub rounding: Option<RoundingMode>,
|
||||||
pub flush_to_zero: bool,
|
pub flush_to_zero: Option<bool>,
|
||||||
pub is_f64: bool,
|
pub is_f64: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -769,7 +769,7 @@ pub struct AddIntDesc {
|
|||||||
|
|
||||||
pub struct SetpData {
|
pub struct SetpData {
|
||||||
pub typ: ScalarType,
|
pub typ: ScalarType,
|
||||||
pub flush_to_zero: bool,
|
pub flush_to_zero: Option<bool>,
|
||||||
pub cmp_op: SetpCompareOp,
|
pub cmp_op: SetpCompareOp,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -799,7 +799,7 @@ pub enum SetpBoolPostOp {
|
|||||||
|
|
||||||
pub struct SetpBoolData {
|
pub struct SetpBoolData {
|
||||||
pub typ: ScalarType,
|
pub typ: ScalarType,
|
||||||
pub flush_to_zero: bool,
|
pub flush_to_zero: Option<bool>,
|
||||||
pub cmp_op: SetpCompareOp,
|
pub cmp_op: SetpCompareOp,
|
||||||
pub bool_op: SetpBoolPostOp,
|
pub bool_op: SetpBoolPostOp,
|
||||||
}
|
}
|
||||||
@ -831,7 +831,7 @@ pub struct CvtIntToIntDesc {
|
|||||||
|
|
||||||
pub struct CvtDesc<Dst, Src> {
|
pub struct CvtDesc<Dst, Src> {
|
||||||
pub rounding: Option<RoundingMode>,
|
pub rounding: Option<RoundingMode>,
|
||||||
pub flush_to_zero: bool,
|
pub flush_to_zero: Option<bool>,
|
||||||
pub saturate: bool,
|
pub saturate: bool,
|
||||||
pub dst: Dst,
|
pub dst: Dst,
|
||||||
pub src: Src,
|
pub src: Src,
|
||||||
@ -873,7 +873,7 @@ impl CvtDetails {
|
|||||||
dst,
|
dst,
|
||||||
src,
|
src,
|
||||||
saturate,
|
saturate,
|
||||||
flush_to_zero,
|
flush_to_zero: Some(flush_to_zero),
|
||||||
rounding: Some(rounding),
|
rounding: Some(rounding),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -893,7 +893,7 @@ impl CvtDetails {
|
|||||||
dst,
|
dst,
|
||||||
src,
|
src,
|
||||||
saturate,
|
saturate,
|
||||||
flush_to_zero,
|
flush_to_zero: Some(flush_to_zero),
|
||||||
rounding: Some(rounding),
|
rounding: Some(rounding),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -1009,7 +1009,7 @@ pub struct ArithSInt {
|
|||||||
pub struct ArithFloat {
|
pub struct ArithFloat {
|
||||||
pub typ: FloatType,
|
pub typ: FloatType,
|
||||||
pub rounding: Option<RoundingMode>,
|
pub rounding: Option<RoundingMode>,
|
||||||
pub flush_to_zero: bool,
|
pub flush_to_zero: Option<bool>,
|
||||||
pub saturate: bool,
|
pub saturate: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1022,7 +1022,7 @@ pub enum MinMaxDetails {
|
|||||||
|
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
pub struct MinMaxFloat {
|
pub struct MinMaxFloat {
|
||||||
pub ftz: bool,
|
pub flush_to_zero: Option<bool>,
|
||||||
pub nan: bool,
|
pub nan: bool,
|
||||||
pub typ: FloatType,
|
pub typ: FloatType,
|
||||||
}
|
}
|
||||||
|
@ -740,17 +740,29 @@ InstSetp: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
SetpMode: ast::SetpData = {
|
SetpMode: ast::SetpData = {
|
||||||
<cmp_op:SetpCompareOp> <ftz:".ftz"?> <t:SetpType> => ast::SetpData{
|
<cmp_op:SetpCompareOp> <t:SetpTypeNoF32> => ast::SetpData {
|
||||||
typ: t,
|
typ: t,
|
||||||
flush_to_zero: ftz.is_some(),
|
flush_to_zero: None,
|
||||||
|
cmp_op: cmp_op,
|
||||||
|
},
|
||||||
|
<cmp_op:SetpCompareOp> <ftz:".ftz"?> ".f32" => ast::SetpData {
|
||||||
|
typ: ast::ScalarType::F32,
|
||||||
|
flush_to_zero: Some(ftz.is_some()),
|
||||||
cmp_op: cmp_op,
|
cmp_op: cmp_op,
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
SetpBoolMode: ast::SetpBoolData = {
|
SetpBoolMode: ast::SetpBoolData = {
|
||||||
<cmp_op:SetpCompareOp> <bool_op:SetpBoolPostOp> <ftz:".ftz"?> <t:SetpType> => ast::SetpBoolData{
|
<cmp_op:SetpCompareOp> <bool_op:SetpBoolPostOp> <t:SetpTypeNoF32> => ast::SetpBoolData {
|
||||||
typ: t,
|
typ: t,
|
||||||
flush_to_zero: ftz.is_some(),
|
flush_to_zero: None,
|
||||||
|
cmp_op: cmp_op,
|
||||||
|
bool_op: bool_op,
|
||||||
|
},
|
||||||
|
<cmp_op:SetpCompareOp> <bool_op:SetpBoolPostOp> <ftz:".ftz"?> ".f32" => ast::SetpBoolData {
|
||||||
|
typ: ast::ScalarType::F32,
|
||||||
|
flush_to_zero: Some(ftz.is_some()),
|
||||||
cmp_op: cmp_op,
|
cmp_op: cmp_op,
|
||||||
bool_op: bool_op,
|
bool_op: bool_op,
|
||||||
}
|
}
|
||||||
@ -783,7 +795,7 @@ SetpBoolPostOp: ast::SetpBoolPostOp = {
|
|||||||
".xor" => ast::SetpBoolPostOp::Xor,
|
".xor" => ast::SetpBoolPostOp::Xor,
|
||||||
};
|
};
|
||||||
|
|
||||||
SetpType: ast::ScalarType = {
|
SetpTypeNoF32: ast::ScalarType = {
|
||||||
".b16" => ast::ScalarType::B16,
|
".b16" => ast::ScalarType::B16,
|
||||||
".b32" => ast::ScalarType::B32,
|
".b32" => ast::ScalarType::B32,
|
||||||
".b64" => ast::ScalarType::B64,
|
".b64" => ast::ScalarType::B64,
|
||||||
@ -793,7 +805,6 @@ SetpType: ast::ScalarType = {
|
|||||||
".s16" => ast::ScalarType::S16,
|
".s16" => ast::ScalarType::S16,
|
||||||
".s32" => ast::ScalarType::S32,
|
".s32" => ast::ScalarType::S32,
|
||||||
".s64" => ast::ScalarType::S64,
|
".s64" => ast::ScalarType::S64,
|
||||||
".f32" => ast::ScalarType::F32,
|
|
||||||
".f64" => ast::ScalarType::F64,
|
".f64" => ast::ScalarType::F64,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -857,7 +868,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||||
ast::CvtDesc {
|
ast::CvtDesc {
|
||||||
rounding: r,
|
rounding: r,
|
||||||
flush_to_zero: false,
|
flush_to_zero: None,
|
||||||
saturate: s.is_some(),
|
saturate: s.is_some(),
|
||||||
dst: ast::FloatType::F16,
|
dst: ast::FloatType::F16,
|
||||||
src: ast::FloatType::F16
|
src: ast::FloatType::F16
|
||||||
@ -868,7 +879,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||||
ast::CvtDesc {
|
ast::CvtDesc {
|
||||||
rounding: None,
|
rounding: None,
|
||||||
flush_to_zero: f.is_some(),
|
flush_to_zero: Some(f.is_some()),
|
||||||
saturate: s.is_some(),
|
saturate: s.is_some(),
|
||||||
dst: ast::FloatType::F32,
|
dst: ast::FloatType::F32,
|
||||||
src: ast::FloatType::F16
|
src: ast::FloatType::F16
|
||||||
@ -879,7 +890,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||||
ast::CvtDesc {
|
ast::CvtDesc {
|
||||||
rounding: None,
|
rounding: None,
|
||||||
flush_to_zero: false,
|
flush_to_zero: None,
|
||||||
saturate: s.is_some(),
|
saturate: s.is_some(),
|
||||||
dst: ast::FloatType::F64,
|
dst: ast::FloatType::F64,
|
||||||
src: ast::FloatType::F16
|
src: ast::FloatType::F16
|
||||||
@ -890,7 +901,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||||
ast::CvtDesc {
|
ast::CvtDesc {
|
||||||
rounding: Some(r),
|
rounding: Some(r),
|
||||||
flush_to_zero: f.is_some(),
|
flush_to_zero: Some(f.is_some()),
|
||||||
saturate: s.is_some(),
|
saturate: s.is_some(),
|
||||||
dst: ast::FloatType::F16,
|
dst: ast::FloatType::F16,
|
||||||
src: ast::FloatType::F32
|
src: ast::FloatType::F32
|
||||||
@ -901,7 +912,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||||
ast::CvtDesc {
|
ast::CvtDesc {
|
||||||
rounding: r,
|
rounding: r,
|
||||||
flush_to_zero: f.is_some(),
|
flush_to_zero: Some(f.is_some()),
|
||||||
saturate: s.is_some(),
|
saturate: s.is_some(),
|
||||||
dst: ast::FloatType::F32,
|
dst: ast::FloatType::F32,
|
||||||
src: ast::FloatType::F32
|
src: ast::FloatType::F32
|
||||||
@ -912,7 +923,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||||
ast::CvtDesc {
|
ast::CvtDesc {
|
||||||
rounding: None,
|
rounding: None,
|
||||||
flush_to_zero: false,
|
flush_to_zero: None,
|
||||||
saturate: s.is_some(),
|
saturate: s.is_some(),
|
||||||
dst: ast::FloatType::F64,
|
dst: ast::FloatType::F64,
|
||||||
src: ast::FloatType::F32
|
src: ast::FloatType::F32
|
||||||
@ -923,7 +934,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||||
ast::CvtDesc {
|
ast::CvtDesc {
|
||||||
rounding: Some(r),
|
rounding: Some(r),
|
||||||
flush_to_zero: false,
|
flush_to_zero: None,
|
||||||
saturate: s.is_some(),
|
saturate: s.is_some(),
|
||||||
dst: ast::FloatType::F16,
|
dst: ast::FloatType::F16,
|
||||||
src: ast::FloatType::F64
|
src: ast::FloatType::F64
|
||||||
@ -934,7 +945,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||||
ast::CvtDesc {
|
ast::CvtDesc {
|
||||||
rounding: Some(r),
|
rounding: Some(r),
|
||||||
flush_to_zero: s.is_some(),
|
flush_to_zero: Some(s.is_some()),
|
||||||
saturate: s.is_some(),
|
saturate: s.is_some(),
|
||||||
dst: ast::FloatType::F32,
|
dst: ast::FloatType::F32,
|
||||||
src: ast::FloatType::F64
|
src: ast::FloatType::F64
|
||||||
@ -945,7 +956,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||||
ast::CvtDesc {
|
ast::CvtDesc {
|
||||||
rounding: r,
|
rounding: r,
|
||||||
flush_to_zero: false,
|
flush_to_zero: None,
|
||||||
saturate: s.is_some(),
|
saturate: s.is_some(),
|
||||||
dst: ast::FloatType::F64,
|
dst: ast::FloatType::F64,
|
||||||
src: ast::FloatType::F64
|
src: ast::FloatType::F64
|
||||||
@ -1082,19 +1093,19 @@ InstCall: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs
|
||||||
InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||||
"abs" <t:SignedIntType> <a:Arg2> => {
|
"abs" <t:SignedIntType> <a:Arg2> => {
|
||||||
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: false, typ: t }, a)
|
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: None, typ: t }, a)
|
||||||
},
|
},
|
||||||
"abs" <f:".ftz"?> ".f32" <a:Arg2> => {
|
"abs" <f:".ftz"?> ".f32" <a:Arg2> => {
|
||||||
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: f.is_some(), typ: ast::ScalarType::F32 }, a)
|
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: Some(f.is_some()), typ: ast::ScalarType::F32 }, a)
|
||||||
},
|
},
|
||||||
"abs" ".f64" <a:Arg2> => {
|
"abs" ".f64" <a:Arg2> => {
|
||||||
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: false, typ: ast::ScalarType::F64 }, a)
|
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: None, typ: ast::ScalarType::F64 }, a)
|
||||||
},
|
},
|
||||||
"abs" <f:".ftz"?> ".f16" <a:Arg2> => {
|
"abs" <f:".ftz"?> ".f16" <a:Arg2> => {
|
||||||
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: f.is_some(), typ: ast::ScalarType::F16 }, a)
|
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: Some(f.is_some()), typ: ast::ScalarType::F16 }, a)
|
||||||
},
|
},
|
||||||
"abs" <f:".ftz"?> ".f16x2" <a:Arg2> => {
|
"abs" <f:".ftz"?> ".f16x2" <a:Arg2> => {
|
||||||
todo!()
|
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: Some(f.is_some()), typ: ast::ScalarType::F16x2 }, a)
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1128,7 +1139,7 @@ InstRcp: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
"rcp" <rounding:RcpRoundingMode> <ftz:".ftz"?> ".f32" <a:Arg2> => {
|
"rcp" <rounding:RcpRoundingMode> <ftz:".ftz"?> ".f32" <a:Arg2> => {
|
||||||
let details = ast::RcpDetails {
|
let details = ast::RcpDetails {
|
||||||
rounding,
|
rounding,
|
||||||
flush_to_zero: ftz.is_some(),
|
flush_to_zero: Some(ftz.is_some()),
|
||||||
is_f64: false,
|
is_f64: false,
|
||||||
};
|
};
|
||||||
ast::Instruction::Rcp(details, a)
|
ast::Instruction::Rcp(details, a)
|
||||||
@ -1136,7 +1147,7 @@ InstRcp: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||||||
"rcp" <rn:RoundingModeFloat> ".f64" <a:Arg2> => {
|
"rcp" <rn:RoundingModeFloat> ".f64" <a:Arg2> => {
|
||||||
let details = ast::RcpDetails {
|
let details = ast::RcpDetails {
|
||||||
rounding: Some(rn),
|
rounding: Some(rn),
|
||||||
flush_to_zero: false,
|
flush_to_zero: None,
|
||||||
is_f64: true,
|
is_f64: true,
|
||||||
};
|
};
|
||||||
ast::Instruction::Rcp(details, a)
|
ast::Instruction::Rcp(details, a)
|
||||||
@ -1173,16 +1184,16 @@ MinMaxDetails: ast::MinMaxDetails = {
|
|||||||
<t:UIntType> => ast::MinMaxDetails::Unsigned(t),
|
<t:UIntType> => ast::MinMaxDetails::Unsigned(t),
|
||||||
<t:SIntType> => ast::MinMaxDetails::Signed(t),
|
<t:SIntType> => ast::MinMaxDetails::Signed(t),
|
||||||
<ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float(
|
<ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float(
|
||||||
ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F32 }
|
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F32 }
|
||||||
),
|
),
|
||||||
".f64" => ast::MinMaxDetails::Float(
|
".f64" => ast::MinMaxDetails::Float(
|
||||||
ast::MinMaxFloat{ ftz: false, nan: false, typ: ast::FloatType::F64 }
|
ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::FloatType::F64 }
|
||||||
),
|
),
|
||||||
<ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float(
|
<ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float(
|
||||||
ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16 }
|
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16 }
|
||||||
),
|
),
|
||||||
<ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float(
|
<ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float(
|
||||||
ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16x2 }
|
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16x2 }
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1203,25 +1214,25 @@ ArithFloat: ast::ArithFloat = {
|
|||||||
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
|
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
|
||||||
typ: ast::FloatType::F32,
|
typ: ast::FloatType::F32,
|
||||||
rounding: rn,
|
rounding: rn,
|
||||||
flush_to_zero: ftz.is_some(),
|
flush_to_zero: Some(ftz.is_some()),
|
||||||
saturate: sat.is_some(),
|
saturate: sat.is_some(),
|
||||||
},
|
},
|
||||||
<rn:RoundingModeFloat?> ".f64" => ast::ArithFloat {
|
<rn:RoundingModeFloat?> ".f64" => ast::ArithFloat {
|
||||||
typ: ast::FloatType::F64,
|
typ: ast::FloatType::F64,
|
||||||
rounding: rn,
|
rounding: rn,
|
||||||
flush_to_zero: false,
|
flush_to_zero: None,
|
||||||
saturate: false,
|
saturate: false,
|
||||||
},
|
},
|
||||||
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
|
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
|
||||||
typ: ast::FloatType::F16,
|
typ: ast::FloatType::F16,
|
||||||
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
|
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
|
||||||
flush_to_zero: ftz.is_some(),
|
flush_to_zero: Some(ftz.is_some()),
|
||||||
saturate: sat.is_some(),
|
saturate: sat.is_some(),
|
||||||
},
|
},
|
||||||
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
|
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
|
||||||
typ: ast::FloatType::F16x2,
|
typ: ast::FloatType::F16x2,
|
||||||
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
|
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
|
||||||
flush_to_zero: ftz.is_some(),
|
flush_to_zero: Some(ftz.is_some()),
|
||||||
saturate: sat.is_some(),
|
saturate: sat.is_some(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -81,6 +81,10 @@ test_ptx!(global_array, [0xDEADu32], [1u32]);
|
|||||||
test_ptx!(extern_shared, [127u64], [127u64]);
|
test_ptx!(extern_shared, [127u64], [127u64]);
|
||||||
test_ptx!(extern_shared_call, [121u64], [123u64]);
|
test_ptx!(extern_shared_call, [121u64], [123u64]);
|
||||||
test_ptx!(rcp, [2f32], [0.5f32]);
|
test_ptx!(rcp, [2f32], [0.5f32]);
|
||||||
|
// 0b1_00000000_10000000000000000000000u32 is a large denormal
|
||||||
|
// 0x3f000000 is 0.5
|
||||||
|
test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]);
|
||||||
|
test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0b1_00000000_01000000000000000000000u32]);
|
||||||
|
|
||||||
struct DisplayError<T: Debug> {
|
struct DisplayError<T: Debug> {
|
||||||
err: T,
|
err: T,
|
||||||
|
23
ptx/src/test/spirv_run/mul_ftz.ptx
Normal file
23
ptx/src/test/spirv_run/mul_ftz.ptx
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry mul_ftz(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .f32 temp1;
|
||||||
|
.reg .f32 temp2;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.f32 temp1, [in_addr];
|
||||||
|
ld.f32 temp2, [in_addr+4];
|
||||||
|
mul.ftz.f32 temp1, temp1, temp2;
|
||||||
|
st.f32 [out_addr], temp1;
|
||||||
|
ret;
|
||||||
|
}
|
46
ptx/src/test/spirv_run/mul_ftz.spvtxt
Normal file
46
ptx/src/test/spirv_run/mul_ftz.spvtxt
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
OpCapability GenericPointer
|
||||||
|
OpCapability Linkage
|
||||||
|
OpCapability Addresses
|
||||||
|
OpCapability Kernel
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Int8
|
||||||
|
%25 = OpExtInstImport "OpenCL.std"
|
||||||
|
OpMemoryModel Physical64 OpenCL
|
||||||
|
OpEntryPoint Kernel %1 "mul_lo"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%ulong = OpTypeInt 64 0
|
||||||
|
%28 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
|
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||||
|
%ulong_2 = OpConstant %ulong 2
|
||||||
|
%1 = OpFunction %void None %28
|
||||||
|
%8 = OpFunctionParameter %ulong
|
||||||
|
%9 = OpFunctionParameter %ulong
|
||||||
|
%23 = OpLabel
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%5 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%6 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%7 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
OpStore %2 %8
|
||||||
|
OpStore %3 %9
|
||||||
|
%11 = OpLoad %ulong %2
|
||||||
|
%10 = OpCopyObject %ulong %11
|
||||||
|
OpStore %4 %10
|
||||||
|
%13 = OpLoad %ulong %3
|
||||||
|
%12 = OpCopyObject %ulong %13
|
||||||
|
OpStore %5 %12
|
||||||
|
%15 = OpLoad %ulong %4
|
||||||
|
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
|
||||||
|
%14 = OpLoad %ulong %21
|
||||||
|
OpStore %6 %14
|
||||||
|
%17 = OpLoad %ulong %6
|
||||||
|
%16 = OpIMul %ulong %17 %ulong_2
|
||||||
|
OpStore %7 %16
|
||||||
|
%18 = OpLoad %ulong %5
|
||||||
|
%19 = OpLoad %ulong %7
|
||||||
|
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
|
||||||
|
OpStore %22 %19
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
23
ptx/src/test/spirv_run/mul_non_ftz.ptx
Normal file
23
ptx/src/test/spirv_run/mul_non_ftz.ptx
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry mul_non_ftz(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .f32 temp1;
|
||||||
|
.reg .f32 temp2;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.f32 temp1, [in_addr];
|
||||||
|
ld.f32 temp2, [in_addr+4];
|
||||||
|
mul.f32 temp1, temp1, temp2;
|
||||||
|
st.f32 [out_addr], temp1;
|
||||||
|
ret;
|
||||||
|
}
|
61
ptx/src/test/spirv_run/mul_non_ftz.spvtxt
Normal file
61
ptx/src/test/spirv_run/mul_non_ftz.spvtxt
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
OpCapability GenericPointer
|
||||||
|
OpCapability Linkage
|
||||||
|
OpCapability Addresses
|
||||||
|
OpCapability Kernel
|
||||||
|
OpCapability Int8
|
||||||
|
OpCapability Int16
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability Float64
|
||||||
|
OpCapability DenormFlushToZero
|
||||||
|
OpCapability DenormPreserve
|
||||||
|
OpExtension "SPV_KHR_float_controls"
|
||||||
|
%30 = OpExtInstImport "OpenCL.std"
|
||||||
|
OpMemoryModel Physical64 OpenCL
|
||||||
|
OpEntryPoint Kernel %1 "mul_non_ftz"
|
||||||
|
OpExecutionMode %1 DenormPreserve 32
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%ulong = OpTypeInt 64 0
|
||||||
|
%33 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
|
%float = OpTypeFloat 32
|
||||||
|
%_ptr_Function_float = OpTypePointer Function %float
|
||||||
|
%_ptr_Generic_float = OpTypePointer Generic %float
|
||||||
|
%ulong_4 = OpConstant %ulong 4
|
||||||
|
%1 = OpFunction %void None %33
|
||||||
|
%8 = OpFunctionParameter %ulong
|
||||||
|
%9 = OpFunctionParameter %ulong
|
||||||
|
%28 = OpLabel
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%5 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%6 = OpVariable %_ptr_Function_float Function
|
||||||
|
%7 = OpVariable %_ptr_Function_float Function
|
||||||
|
OpStore %2 %8
|
||||||
|
OpStore %3 %9
|
||||||
|
%11 = OpLoad %ulong %2
|
||||||
|
%10 = OpCopyObject %ulong %11
|
||||||
|
OpStore %4 %10
|
||||||
|
%13 = OpLoad %ulong %3
|
||||||
|
%12 = OpCopyObject %ulong %13
|
||||||
|
OpStore %5 %12
|
||||||
|
%15 = OpLoad %ulong %4
|
||||||
|
%25 = OpConvertUToPtr %_ptr_Generic_float %15
|
||||||
|
%14 = OpLoad %float %25
|
||||||
|
OpStore %6 %14
|
||||||
|
%17 = OpLoad %ulong %4
|
||||||
|
%24 = OpIAdd %ulong %17 %ulong_4
|
||||||
|
%26 = OpConvertUToPtr %_ptr_Generic_float %24
|
||||||
|
%16 = OpLoad %float %26
|
||||||
|
OpStore %7 %16
|
||||||
|
%19 = OpLoad %float %6
|
||||||
|
%20 = OpLoad %float %7
|
||||||
|
%18 = OpFMul %float %19 %20
|
||||||
|
OpStore %6 %18
|
||||||
|
%21 = OpLoad %ulong %5
|
||||||
|
%22 = OpLoad %float %6
|
||||||
|
%27 = OpConvertUToPtr %_ptr_Generic_float %21
|
||||||
|
OpStore %27 %22
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
@ -7,9 +7,11 @@
|
|||||||
OpCapability Int64
|
OpCapability Int64
|
||||||
OpCapability Float16
|
OpCapability Float16
|
||||||
OpCapability Float64
|
OpCapability Float64
|
||||||
|
OpExtension "SPV_KHR_float_controls"
|
||||||
%23 = OpExtInstImport "OpenCL.std"
|
%23 = OpExtInstImport "OpenCL.std"
|
||||||
OpMemoryModel Physical64 OpenCL
|
OpMemoryModel Physical64 OpenCL
|
||||||
OpEntryPoint Kernel %1 "rcp"
|
OpEntryPoint Kernel %1 "rcp"
|
||||||
|
OpExecutionMode %1 DenormPreserve 32
|
||||||
OpDecorate %15 FPFastMathMode AllowRecip
|
OpDecorate %15 FPFastMathMode AllowRecip
|
||||||
%void = OpTypeVoid
|
%void = OpTypeVoid
|
||||||
%ulong = OpTypeInt 64 0
|
%ulong = OpTypeInt 64 0
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::ast;
|
use crate::ast;
|
||||||
use half::f16;
|
use half::f16;
|
||||||
use rspirv::{binary::Disassemble, dr};
|
use rspirv::{binary::Disassemble, dr};
|
||||||
use std::{borrow::Cow, iter, mem};
|
use std::{borrow::Cow, hash::Hash, iter, mem};
|
||||||
use std::{
|
use std::{
|
||||||
collections::{hash_map, HashMap, HashSet},
|
collections::{hash_map, HashMap, HashSet},
|
||||||
convert::TryFrom,
|
convert::TryFrom,
|
||||||
@ -438,6 +438,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
|
|||||||
let mut directives =
|
let mut directives =
|
||||||
convert_dynamic_shared_memory_usage(&mut id_defs, directives, &mut || builder.id());
|
convert_dynamic_shared_memory_usage(&mut id_defs, directives, &mut || builder.id());
|
||||||
normalize_variable_decls(&mut directives);
|
normalize_variable_decls(&mut directives);
|
||||||
|
let denorm_information = compute_denorm_information(&directives);
|
||||||
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
||||||
builder.set_version(1, 3);
|
builder.set_version(1, 3);
|
||||||
emit_capabilities(&mut builder);
|
emit_capabilities(&mut builder);
|
||||||
@ -463,6 +464,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
|
|||||||
&mut map,
|
&mut map,
|
||||||
&id_defs,
|
&id_defs,
|
||||||
f.func_decl,
|
f.func_decl,
|
||||||
|
&denorm_information,
|
||||||
&mut kernel_info,
|
&mut kernel_info,
|
||||||
)?;
|
)?;
|
||||||
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
|
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
|
||||||
@ -523,10 +525,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||||||
globals,
|
globals,
|
||||||
body: Some(statements),
|
body: Some(statements),
|
||||||
}) => {
|
}) => {
|
||||||
let call_key = match func_decl {
|
let call_key = CallgraphKey::new(&func_decl);
|
||||||
ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name),
|
|
||||||
ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
|
|
||||||
};
|
|
||||||
let statements = statements
|
let statements = statements
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|statement| match statement {
|
.map(|statement| match statement {
|
||||||
@ -563,10 +562,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||||||
globals,
|
globals,
|
||||||
body: Some(statements),
|
body: Some(statements),
|
||||||
}) => {
|
}) => {
|
||||||
let call_key = match func_decl {
|
let call_key = CallgraphKey::new(&func_decl);
|
||||||
ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name),
|
|
||||||
ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
|
|
||||||
};
|
|
||||||
if !methods_using_extern_shared.contains(&call_key) {
|
if !methods_using_extern_shared.contains(&call_key) {
|
||||||
return Directive::Method(Function {
|
return Directive::Method(Function {
|
||||||
func_decl,
|
func_decl,
|
||||||
@ -726,12 +722,171 @@ fn get_callers_of_extern_shared_single<'a>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DenormCountMap<T> = HashMap<T, isize>;
|
||||||
|
|
||||||
|
fn denorm_count_map_update<T: Eq + Hash>(map: &mut DenormCountMap<T>, key: T, value: bool) {
|
||||||
|
let num_value = if value { 1 } else { -1 };
|
||||||
|
denorm_count_map_update_impl(map, key, num_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn denorm_count_map_update_impl<T: Eq + Hash>(
|
||||||
|
map: &mut DenormCountMap<T>,
|
||||||
|
key: T,
|
||||||
|
num_value: isize,
|
||||||
|
) {
|
||||||
|
match map.entry(key) {
|
||||||
|
hash_map::Entry::Occupied(mut counter) => {
|
||||||
|
*(counter.get_mut()) += num_value;
|
||||||
|
}
|
||||||
|
hash_map::Entry::Vacant(entry) => {
|
||||||
|
entry.insert(num_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn denorm_count_map_merge<T: Eq + Hash + Copy>(
|
||||||
|
dst: &mut DenormCountMap<T>,
|
||||||
|
src: &DenormCountMap<T>,
|
||||||
|
) {
|
||||||
|
for (k, count) in src {
|
||||||
|
denorm_count_map_update_impl(dst, *k, *count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HACK ALERT!
|
||||||
|
// This function is a "good enough" heuristic of whetever to mark f16/f32 operations
|
||||||
|
// in the kernel as flushing denorms to zero or preserving them
|
||||||
|
// PTX support per-instruction ftz information. Unfortunately SPIR-V has no
|
||||||
|
// such capability, so instead we guesstimate which use is more common in the kernel
|
||||||
|
// and emit suitable execution mode
|
||||||
|
fn compute_denorm_information<'input>(
|
||||||
|
module: &[Directive<'input>],
|
||||||
|
) -> HashMap<&'input str, HashMap<u8, spirv::ExecutionMode>> {
|
||||||
|
let mut direct_func_calls = MultiHashMap::new();
|
||||||
|
let mut denorm_methods = HashMap::new();
|
||||||
|
for directive in module.iter() {
|
||||||
|
match directive {
|
||||||
|
Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {}
|
||||||
|
Directive::Method(Function {
|
||||||
|
func_decl,
|
||||||
|
body: Some(statements),
|
||||||
|
..
|
||||||
|
}) => {
|
||||||
|
let mut flush_counter = DenormCountMap::new();
|
||||||
|
let method_key = CallgraphKey::new(func_decl);
|
||||||
|
for statement in statements {
|
||||||
|
match statement {
|
||||||
|
Statement::Instruction(inst) => {
|
||||||
|
if let Some((flush, width)) = inst.flush_to_zero() {
|
||||||
|
denorm_count_map_update(&mut flush_counter, width, flush);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Statement::LoadVar(_, _) => {}
|
||||||
|
Statement::StoreVar(_, _) => {}
|
||||||
|
Statement::Call(ResolvedCall { func, .. }) => {
|
||||||
|
multi_hash_map_append(&mut direct_func_calls, method_key, *func);
|
||||||
|
}
|
||||||
|
Statement::Composite(_) => {}
|
||||||
|
Statement::Conditional(_) => {}
|
||||||
|
Statement::Conversion(_) => {}
|
||||||
|
Statement::Constant(_) => {}
|
||||||
|
Statement::RetValue(_, _) => {}
|
||||||
|
Statement::Undef(_, _) => {}
|
||||||
|
Statement::Label(_) => {}
|
||||||
|
Statement::Variable(_) => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
denorm_methods.insert(method_key, flush_counter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let summed_denorm_methods = sum_up_denorm_use(module, denorm_methods, &direct_func_calls);
|
||||||
|
summed_denorm_methods
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|(name, v)| {
|
||||||
|
let width_to_denorm = v
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, ftz_over_preserve)| {
|
||||||
|
let mode = if ftz_over_preserve > 0 {
|
||||||
|
spirv::ExecutionMode::DenormFlushToZero
|
||||||
|
} else {
|
||||||
|
spirv::ExecutionMode::DenormPreserve
|
||||||
|
};
|
||||||
|
(k, mode)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Some((name, width_to_denorm))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sum_up_denorm_use<'input>(
|
||||||
|
module: &[Directive<'input>],
|
||||||
|
denorm_methods: HashMap<CallgraphKey<'input>, DenormCountMap<u8>>,
|
||||||
|
direct_func_calls: &MultiHashMap<CallgraphKey<'input>, spirv::Word>,
|
||||||
|
) -> HashMap<&'input str, DenormCountMap<u8>> {
|
||||||
|
let mut result = HashMap::new();
|
||||||
|
let empty = Vec::new();
|
||||||
|
for (method_key, denorm_map) in denorm_methods.iter() {
|
||||||
|
match method_key {
|
||||||
|
CallgraphKey::Kernel(name) => {
|
||||||
|
let mut sum = denorm_map.clone();
|
||||||
|
let mut visited = HashSet::new();
|
||||||
|
for child in direct_func_calls
|
||||||
|
.get(&CallgraphKey::Kernel(name))
|
||||||
|
.unwrap_or(&empty)
|
||||||
|
{
|
||||||
|
sum_up_denorm_use_single(
|
||||||
|
&denorm_methods,
|
||||||
|
direct_func_calls,
|
||||||
|
&mut sum,
|
||||||
|
&mut visited,
|
||||||
|
*child,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
result.insert(*name, sum);
|
||||||
|
}
|
||||||
|
CallgraphKey::Func(_) => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sum_up_denorm_use_single<'input>(
|
||||||
|
denorm_methods: &HashMap<CallgraphKey<'input>, DenormCountMap<u8>>,
|
||||||
|
direct_func_calls: &MultiHashMap<CallgraphKey<'input>, spirv::Word>,
|
||||||
|
sum: &mut DenormCountMap<u8>,
|
||||||
|
visited: &mut HashSet<spirv::Word>,
|
||||||
|
current: spirv::Word,
|
||||||
|
) {
|
||||||
|
if !visited.insert(current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if let Some(denorm_map) = denorm_methods.get(&CallgraphKey::Func(current)) {
|
||||||
|
denorm_count_map_merge(sum, denorm_map);
|
||||||
|
}
|
||||||
|
if let Some(children) = direct_func_calls.get(&CallgraphKey::Func(current)) {
|
||||||
|
for child in children {
|
||||||
|
sum_up_denorm_use_single(denorm_methods, direct_func_calls, sum, visited, *child);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
|
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
|
||||||
enum CallgraphKey<'input> {
|
enum CallgraphKey<'input> {
|
||||||
Kernel(&'input str),
|
Kernel(&'input str),
|
||||||
Func(spirv::Word),
|
Func(spirv::Word),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'input> CallgraphKey<'input> {
|
||||||
|
fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
|
||||||
|
match decl {
|
||||||
|
ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name),
|
||||||
|
ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(*id),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn emit_builtins(
|
fn emit_builtins(
|
||||||
builder: &mut dr::Builder,
|
builder: &mut dr::Builder,
|
||||||
map: &mut TypeWordMap,
|
map: &mut TypeWordMap,
|
||||||
@ -764,6 +919,7 @@ fn emit_function_header<'a>(
|
|||||||
map: &mut TypeWordMap,
|
map: &mut TypeWordMap,
|
||||||
global: &GlobalStringIdResolver<'a>,
|
global: &GlobalStringIdResolver<'a>,
|
||||||
func_directive: ast::MethodDecl<spirv::Word>,
|
func_directive: ast::MethodDecl<spirv::Word>,
|
||||||
|
denorm_information: &HashMap<&'a str, HashMap<u8, spirv::ExecutionMode>>,
|
||||||
kernel_info: &mut HashMap<String, KernelInfo>,
|
kernel_info: &mut HashMap<String, KernelInfo>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
if let ast::MethodDecl::Kernel {
|
if let ast::MethodDecl::Kernel {
|
||||||
@ -797,6 +953,11 @@ fn emit_function_header<'a>(
|
|||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
global_variables.append(&mut interface);
|
global_variables.append(&mut interface);
|
||||||
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
|
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
|
||||||
|
if let Some(exec_modes) = denorm_information.get(name) {
|
||||||
|
for (size_of, exec_mode) in exec_modes {
|
||||||
|
builder.execution_mode(fn_id, *exec_mode, [(*size_of as u32) * 8])
|
||||||
|
}
|
||||||
|
}
|
||||||
fn_id
|
fn_id
|
||||||
}
|
}
|
||||||
ast::MethodDecl::Func(_, name, _) => name,
|
ast::MethodDecl::Func(_, name, _) => name,
|
||||||
@ -844,9 +1005,14 @@ fn emit_capabilities(builder: &mut dr::Builder) {
|
|||||||
builder.capability(spirv::Capability::Int64);
|
builder.capability(spirv::Capability::Int64);
|
||||||
builder.capability(spirv::Capability::Float16);
|
builder.capability(spirv::Capability::Float16);
|
||||||
builder.capability(spirv::Capability::Float64);
|
builder.capability(spirv::Capability::Float64);
|
||||||
|
builder.capability(spirv::Capability::DenormFlushToZero);
|
||||||
|
builder.capability(spirv::Capability::DenormPreserve);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn emit_extensions(_: &mut dr::Builder) {}
|
// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html
|
||||||
|
fn emit_extensions(builder: &mut dr::Builder) {
|
||||||
|
builder.extension("SPV_KHR_float_controls");
|
||||||
|
}
|
||||||
|
|
||||||
fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word {
|
fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word {
|
||||||
builder.ext_inst_import("OpenCL.std")
|
builder.ext_inst_import("OpenCL.std")
|
||||||
@ -2088,7 +2254,7 @@ fn emit_function_body_ops(
|
|||||||
ast::MulDetails::Unsigned(ref ctr) => {
|
ast::MulDetails::Unsigned(ref ctr) => {
|
||||||
emit_mul_uint(builder, map, opencl, ctr, arg)?
|
emit_mul_uint(builder, map, opencl, ctr, arg)?
|
||||||
}
|
}
|
||||||
ast::MulDetails::Float(_) => todo!(),
|
ast::MulDetails::Float(ref ctr) => emit_mul_float(builder, map, ctr, arg)?,
|
||||||
},
|
},
|
||||||
ast::Instruction::Add(add, arg) => match add {
|
ast::Instruction::Add(add, arg) => match add {
|
||||||
ast::ArithDetails::Signed(ref desc) => {
|
ast::ArithDetails::Signed(ref desc) => {
|
||||||
@ -2215,15 +2381,27 @@ fn emit_function_body_ops(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn emit_mul_float(
|
||||||
|
builder: &mut dr::Builder,
|
||||||
|
map: &mut TypeWordMap,
|
||||||
|
ctr: &ast::ArithFloat,
|
||||||
|
arg: &ast::Arg3<ExpandedArgParams>,
|
||||||
|
) -> Result<(), dr::Error> {
|
||||||
|
if ctr.saturate {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
let result_type = map.get_or_add_scalar(builder, ctr.typ.into());
|
||||||
|
builder.f_mul(result_type, Some(arg.dst), arg.src1, arg.src2)?;
|
||||||
|
emit_rounding_decoration(builder, arg.dst, ctr.rounding);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn emit_rcp(
|
fn emit_rcp(
|
||||||
builder: &mut dr::Builder,
|
builder: &mut dr::Builder,
|
||||||
map: &mut TypeWordMap,
|
map: &mut TypeWordMap,
|
||||||
desc: &ast::RcpDetails,
|
desc: &ast::RcpDetails,
|
||||||
a: &ast::Arg2<ExpandedArgParams>,
|
a: &ast::Arg2<ExpandedArgParams>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
if desc.flush_to_zero {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
let (instr_type, constant) = if desc.is_f64 {
|
let (instr_type, constant) = if desc.is_f64 {
|
||||||
(ast::ScalarType::F64, vec_repr(1.0f64))
|
(ast::ScalarType::F64, vec_repr(1.0f64))
|
||||||
} else {
|
} else {
|
||||||
@ -2360,9 +2538,6 @@ fn emit_add_float(
|
|||||||
desc: &ast::ArithFloat,
|
desc: &ast::ArithFloat,
|
||||||
arg: &ast::Arg3<ExpandedArgParams>,
|
arg: &ast::Arg3<ExpandedArgParams>,
|
||||||
) -> Result<(), dr::Error> {
|
) -> Result<(), dr::Error> {
|
||||||
if desc.flush_to_zero {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
|
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
|
||||||
builder.f_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
|
builder.f_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
|
||||||
emit_rounding_decoration(builder, arg.dst, desc.rounding);
|
emit_rounding_decoration(builder, arg.dst, desc.rounding);
|
||||||
@ -2375,9 +2550,6 @@ fn emit_sub_float(
|
|||||||
desc: &ast::ArithFloat,
|
desc: &ast::ArithFloat,
|
||||||
arg: &ast::Arg3<ExpandedArgParams>,
|
arg: &ast::Arg3<ExpandedArgParams>,
|
||||||
) -> Result<(), dr::Error> {
|
) -> Result<(), dr::Error> {
|
||||||
if desc.flush_to_zero {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
|
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
|
||||||
builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
|
builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
|
||||||
emit_rounding_decoration(builder, arg.dst, desc.rounding);
|
emit_rounding_decoration(builder, arg.dst, desc.rounding);
|
||||||
@ -2441,7 +2613,7 @@ fn emit_cvt(
|
|||||||
if desc.dst == desc.src {
|
if desc.dst == desc.src {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
if desc.saturate || desc.flush_to_zero {
|
if desc.saturate {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
let dest_t: ast::ScalarType = desc.dst.into();
|
let dest_t: ast::ScalarType = desc.dst.into();
|
||||||
@ -2450,7 +2622,7 @@ fn emit_cvt(
|
|||||||
emit_rounding_decoration(builder, arg.dst, desc.rounding);
|
emit_rounding_decoration(builder, arg.dst, desc.rounding);
|
||||||
}
|
}
|
||||||
ast::CvtDetails::FloatFromInt(desc) => {
|
ast::CvtDetails::FloatFromInt(desc) => {
|
||||||
if desc.saturate || desc.flush_to_zero {
|
if desc.saturate {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
let dest_t: ast::ScalarType = desc.dst.into();
|
let dest_t: ast::ScalarType = desc.dst.into();
|
||||||
@ -2463,9 +2635,6 @@ fn emit_cvt(
|
|||||||
emit_rounding_decoration(builder, arg.dst, desc.rounding);
|
emit_rounding_decoration(builder, arg.dst, desc.rounding);
|
||||||
}
|
}
|
||||||
ast::CvtDetails::IntFromFloat(desc) => {
|
ast::CvtDetails::IntFromFloat(desc) => {
|
||||||
if desc.flush_to_zero {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
let dest_t: ast::ScalarType = desc.dst.into();
|
let dest_t: ast::ScalarType = desc.dst.into();
|
||||||
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
|
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
|
||||||
if desc.dst.is_signed() {
|
if desc.dst.is_signed() {
|
||||||
@ -2561,9 +2730,6 @@ fn emit_setp(
|
|||||||
setp: &ast::SetpData,
|
setp: &ast::SetpData,
|
||||||
arg: &ast::Arg4Setp<ExpandedArgParams>,
|
arg: &ast::Arg4Setp<ExpandedArgParams>,
|
||||||
) -> Result<(), dr::Error> {
|
) -> Result<(), dr::Error> {
|
||||||
if setp.flush_to_zero {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
let result_type = map.get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred));
|
let result_type = map.get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred));
|
||||||
let result_id = Some(arg.dst1);
|
let result_id = Some(arg.dst1);
|
||||||
let operand_1 = arg.src1;
|
let operand_1 = arg.src1;
|
||||||
@ -4122,6 +4288,73 @@ impl ast::Instruction<ExpandedArgParams> {
|
|||||||
| ast::Instruction::Mad(_, _) => None,
|
| ast::Instruction::Mad(_, _) => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// .wide instructions don't support ftz, so it's enough to just look at the
|
||||||
|
// type declared by the instruction
|
||||||
|
fn flush_to_zero(&self) -> Option<(bool, u8)> {
|
||||||
|
match self {
|
||||||
|
ast::Instruction::Ld(_, _) => None,
|
||||||
|
ast::Instruction::St(_, _) => None,
|
||||||
|
ast::Instruction::Mov(_, _) => None,
|
||||||
|
ast::Instruction::Not(_, _) => None,
|
||||||
|
ast::Instruction::Bra(_, _) => None,
|
||||||
|
ast::Instruction::Shl(_, _) => None,
|
||||||
|
ast::Instruction::Shr(_, _) => None,
|
||||||
|
ast::Instruction::Ret(_) => None,
|
||||||
|
ast::Instruction::Call(_) => None,
|
||||||
|
ast::Instruction::Or(_, _) => None,
|
||||||
|
ast::Instruction::Cvta(_, _) => None,
|
||||||
|
ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None,
|
||||||
|
ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None,
|
||||||
|
ast::Instruction::Add(ast::ArithDetails::Signed(_), _) => None,
|
||||||
|
ast::Instruction::Add(ast::ArithDetails::Unsigned(_), _) => None,
|
||||||
|
ast::Instruction::Mul(ast::MulDetails::Unsigned(_), _) => None,
|
||||||
|
ast::Instruction::Mul(ast::MulDetails::Signed(_), _) => None,
|
||||||
|
ast::Instruction::Mad(ast::MulDetails::Unsigned(_), _) => None,
|
||||||
|
ast::Instruction::Mad(ast::MulDetails::Signed(_), _) => None,
|
||||||
|
ast::Instruction::Min(ast::MinMaxDetails::Signed(_), _) => None,
|
||||||
|
ast::Instruction::Min(ast::MinMaxDetails::Unsigned(_), _) => None,
|
||||||
|
ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _) => None,
|
||||||
|
ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _) => None,
|
||||||
|
ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _) => None,
|
||||||
|
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
|
||||||
|
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
|
||||||
|
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
|
||||||
|
| ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => float_control
|
||||||
|
.flush_to_zero
|
||||||
|
.map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())),
|
||||||
|
ast::Instruction::Setp(details, _) => details
|
||||||
|
.flush_to_zero
|
||||||
|
.map(|ftz| (ftz, details.typ.size_of())),
|
||||||
|
ast::Instruction::SetpBool(details, _) => details
|
||||||
|
.flush_to_zero
|
||||||
|
.map(|ftz| (ftz, details.typ.size_of())),
|
||||||
|
ast::Instruction::Abs(details, _) => details
|
||||||
|
.flush_to_zero
|
||||||
|
.map(|ftz| (ftz, details.typ.size_of())),
|
||||||
|
ast::Instruction::Min(ast::MinMaxDetails::Float(float_control), _)
|
||||||
|
| ast::Instruction::Max(ast::MinMaxDetails::Float(float_control), _) => float_control
|
||||||
|
.flush_to_zero
|
||||||
|
.map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())),
|
||||||
|
ast::Instruction::Rcp(details, _) => details
|
||||||
|
.flush_to_zero
|
||||||
|
.map(|ftz| (ftz, if details.is_f64 { 8 } else { 4 })),
|
||||||
|
// Modifier .ftz can only be specified when either .dtype or .atype
|
||||||
|
// is .f32 and applies only to single precision (.f32) inputs and results.
|
||||||
|
ast::Instruction::Cvt(
|
||||||
|
ast::CvtDetails::FloatFromFloat(ast::CvtDesc { flush_to_zero, .. }),
|
||||||
|
_,
|
||||||
|
)
|
||||||
|
| ast::Instruction::Cvt(
|
||||||
|
ast::CvtDetails::FloatFromInt(ast::CvtDesc { flush_to_zero, .. }),
|
||||||
|
_,
|
||||||
|
)
|
||||||
|
| ast::Instruction::Cvt(
|
||||||
|
ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }),
|
||||||
|
_,
|
||||||
|
) => flush_to_zero.map(|ftz| (ftz, 4)),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
|
impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
|
||||||
|
Reference in New Issue
Block a user