Fix ftz behavior slightly

This commit is contained in:
Andrzej Janik
2020-11-07 16:14:37 +01:00
parent ac6265f257
commit 62d14cdffe
3 changed files with 114 additions and 74 deletions

View File

@ -83,8 +83,11 @@ test_ptx!(extern_shared_call, [121u64], [123u64]);
test_ptx!(rcp, [2f32], [0.5f32]); test_ptx!(rcp, [2f32], [0.5f32]);
// 0b1_00000000_10000000000000000000000u32 is a large denormal // 0b1_00000000_10000000000000000000000u32 is a large denormal
// 0x3f000000 is 0.5 // 0x3f000000 is 0.5
// TODO: mul_ftz fails because IGC does not yet handle SPV_INTEL_float_controls2 test_ptx!(
// test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]); mul_ftz,
[0b1_00000000_10000000000000000000000u32, 0x3f000000u32],
[0b1_00000000_00000000000000000000000u32]
);
test_ptx!( test_ptx!(
mul_non_ftz, mul_non_ftz,
[0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0b1_00000000_10000000000000000000000u32, 0x3f000000u32],
@ -196,7 +199,12 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
let (module, maybe_log) = match module.should_link_ptx_impl { let (module, maybe_log) = match module.should_link_ptx_impl {
Some(ptx_impl) => ze::Module::build_link_spirv(&mut ctx, &dev, &[ptx_impl, byte_il]), Some(ptx_impl) => ze::Module::build_link_spirv(&mut ctx, &dev, &[ptx_impl, byte_il]),
None => { None => {
let (module, log) = ze::Module::build_spirv(&mut ctx, &dev, byte_il, None); let (module, log) = ze::Module::build_spirv(
&mut ctx,
&dev,
byte_il,
Some(module.build_options.as_c_str()),
);
(module, Some(log)) (module, Some(log))
} }
}; };

View File

@ -1,64 +1,55 @@
; SPIR-V OpCapability GenericPointer
; Version: 1.3 OpCapability Linkage
; Generator: rspirv OpCapability Addresses
; Bound: 38 OpCapability Kernel
OpCapability GenericPointer OpCapability Int8
OpCapability Linkage OpCapability Int16
OpCapability Addresses OpCapability Int64
OpCapability Kernel OpCapability Float16
OpCapability Int8 OpCapability Float64
OpCapability Int16 %28 = OpExtInstImport "OpenCL.std"
OpCapability Int64 OpMemoryModel Physical64 OpenCL
OpCapability Float16 OpEntryPoint Kernel %1 "mul_ftz"
OpCapability Float64 %void = OpTypeVoid
; OpCapability FunctionFloatControlINTEL %ulong = OpTypeInt 64 0
; OpExtension "SPV_INTEL_float_controls2" %31 = OpTypeFunction %void %ulong %ulong
%30 = OpExtInstImport "OpenCL.std" %_ptr_Function_ulong = OpTypePointer Function %ulong
OpMemoryModel Physical64 OpenCL %float = OpTypeFloat 32
OpEntryPoint Kernel %1 "mul_ftz" %_ptr_Function_float = OpTypePointer Function %float
OpDecorate %1 FunctionDenormModeINTEL 32 FlushToZero %_ptr_Generic_float = OpTypePointer Generic %float
%31 = OpTypeVoid %ulong_4 = OpConstant %ulong 4
%32 = OpTypeInt 64 0 %1 = OpFunction %void None %31
%33 = OpTypeFunction %31 %32 %32 %8 = OpFunctionParameter %ulong
%34 = OpTypePointer Function %32 %9 = OpFunctionParameter %ulong
%35 = OpTypeFloat 32 %26 = OpLabel
%36 = OpTypePointer Function %35 %2 = OpVariable %_ptr_Function_ulong Function
%37 = OpTypePointer Generic %35 %3 = OpVariable %_ptr_Function_ulong Function
%23 = OpConstant %32 4 %4 = OpVariable %_ptr_Function_ulong Function
%1 = OpFunction %31 None %33 %5 = OpVariable %_ptr_Function_ulong Function
%8 = OpFunctionParameter %32 %6 = OpVariable %_ptr_Function_float Function
%9 = OpFunctionParameter %32 %7 = OpVariable %_ptr_Function_float Function
%28 = OpLabel OpStore %2 %8
%2 = OpVariable %34 Function OpStore %3 %9
%3 = OpVariable %34 Function %10 = OpLoad %ulong %2
%4 = OpVariable %34 Function OpStore %4 %10
%5 = OpVariable %34 Function %11 = OpLoad %ulong %3
%6 = OpVariable %36 Function OpStore %5 %11
%7 = OpVariable %36 Function %13 = OpLoad %ulong %4
OpStore %2 %8 %23 = OpConvertUToPtr %_ptr_Generic_float %13
OpStore %3 %9 %12 = OpLoad %float %23
%11 = OpLoad %32 %2 OpStore %6 %12
%10 = OpCopyObject %32 %11 %15 = OpLoad %ulong %4
OpStore %4 %10 %22 = OpIAdd %ulong %15 %ulong_4
%13 = OpLoad %32 %3 %24 = OpConvertUToPtr %_ptr_Generic_float %22
%12 = OpCopyObject %32 %13 %14 = OpLoad %float %24
OpStore %5 %12 OpStore %7 %14
%15 = OpLoad %32 %4 %17 = OpLoad %float %6
%25 = OpConvertUToPtr %37 %15 %18 = OpLoad %float %7
%14 = OpLoad %35 %25 %16 = OpFMul %float %17 %18
OpStore %6 %14 OpStore %6 %16
%17 = OpLoad %32 %4 %19 = OpLoad %ulong %5
%24 = OpIAdd %32 %17 %23 %20 = OpLoad %float %6
%26 = OpConvertUToPtr %37 %24 %25 = OpConvertUToPtr %_ptr_Generic_float %19
%16 = OpLoad %35 %26 OpStore %25 %20
OpStore %7 %16 OpReturn
%19 = OpLoad %35 %6 OpFunctionEnd
%20 = OpLoad %35 %7
%18 = OpFMul %35 %19 %20
OpStore %6 %18
%21 = OpLoad %32 %5
%22 = OpLoad %35 %6
%27 = OpConvertUToPtr %37 %21
OpStore %27 %22
OpReturn
OpFunctionEnd

View File

@ -1,7 +1,7 @@
use crate::ast; use crate::ast;
use half::f16; use half::f16;
use rspirv::{binary::Disassemble, dr}; use rspirv::{binary::Disassemble, dr};
use std::{borrow::Cow, convert::TryFrom, hash::Hash, iter, mem}; use std::{borrow::Cow, convert::TryFrom, ffi::CString, hash::Hash, iter, mem};
use std::{ use std::{
collections::{hash_map, HashMap, HashSet}, collections::{hash_map, HashMap, HashSet},
convert::TryInto, convert::TryInto,
@ -448,6 +448,7 @@ pub struct Module {
pub spirv: dr::Module, pub spirv: dr::Module,
pub kernel_info: HashMap<String, KernelInfo>, pub kernel_info: HashMap<String, KernelInfo>,
pub should_link_ptx_impl: Option<&'static [u8]>, pub should_link_ptx_impl: Option<&'static [u8]>,
pub build_options: CString,
} }
pub struct KernelInfo { pub struct KernelInfo {
@ -484,6 +485,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
let mut map = TypeWordMap::new(&mut builder); let mut map = TypeWordMap::new(&mut builder);
emit_builtins(&mut builder, &mut map, &id_defs); emit_builtins(&mut builder, &mut map, &id_defs);
let mut kernel_info = HashMap::new(); let mut kernel_info = HashMap::new();
let build_options = emit_denorm_build_string(&call_map, &denorm_information);
emit_directives( emit_directives(
&mut builder, &mut builder,
&mut map, &mut map,
@ -503,15 +505,51 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
} else { } else {
None None
}, },
build_options,
}) })
} }
// TODO: remove this once we have per-function support for denorms
/// Chooses the build option string for the whole module based on the
/// denormal-handling scores collected per method.
///
/// For every method in `denorm_information` the scores recorded under the
/// `f16` and `f32` byte widths are added together (a missing width counts
/// as 0). Those per-method totals are then summed over every kernel in
/// `call_map` together with the functions listed for that kernel. A function
/// listed under several kernels contributes once per kernel.
///
/// Returns `-cl-denorms-are-zero` when the grand total is strictly positive,
/// and an empty `CString` otherwise.
fn emit_denorm_build_string(
    call_map: &HashMap<&str, HashSet<u32>>,
    denorm_information: &HashMap<MethodName, HashMap<u8, (spirv::FPDenormMode, isize)>>,
) -> CString {
    // Score tables are keyed by the byte width of the floating point type.
    let half_width = mem::size_of::<f16>() as u8;
    let single_width = mem::size_of::<f32>() as u8;

    // Combined f16 + f32 score for each method that has denorm information.
    let mut score_by_method = HashMap::new();
    for (method, per_width) in denorm_information {
        let score_for = |width: u8| per_width.get(&width).map_or(0, |entry| entry.1);
        score_by_method.insert(method, score_for(half_width) + score_for(single_width));
    }

    // Accumulate the score of each kernel and of every function listed for it;
    // methods without denorm information contribute nothing.
    let mut total_score: isize = 0;
    for (kernel, callees) in call_map {
        let kernel_score = score_by_method
            .get(&MethodName::Kernel(kernel))
            .copied()
            .unwrap_or(0);
        let callee_score: isize = callees
            .iter()
            .map(|callee| {
                score_by_method
                    .get(&MethodName::Func(*callee))
                    .copied()
                    .unwrap_or(0)
            })
            .sum();
        total_score += kernel_score + callee_score;
    }

    if total_score <= 0 {
        return CString::default();
    }
    CString::new("-cl-denorms-are-zero").unwrap()
}
fn emit_directives<'input>( fn emit_directives<'input>(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
id_defs: &GlobalStringIdResolver<'input>, id_defs: &GlobalStringIdResolver<'input>,
opencl_id: spirv::Word, opencl_id: spirv::Word,
denorm_information: &HashMap<MethodName<'input>, HashMap<u8, spirv::FPDenormMode>>, denorm_information: &HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
call_map: &HashMap<&'input str, HashSet<spirv::Word>>, call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
directives: Vec<Directive>, directives: Vec<Directive>,
kernel_info: &mut HashMap<String, KernelInfo>, kernel_info: &mut HashMap<String, KernelInfo>,
@ -579,6 +617,9 @@ fn get_call_map<'input>(
.. ..
}) => { }) => {
let call_key = MethodName::new(&func_decl); let call_key = MethodName::new(&func_decl);
if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
entry.insert(Vec::new());
}
for statement in statements { for statement in statements {
match statement { match statement {
Statement::Call(call) => { Statement::Call(call) => {
@ -895,7 +936,7 @@ fn denorm_count_map_update_impl<T: Eq + Hash>(
// and emit suitable execution mode // and emit suitable execution mode
fn compute_denorm_information<'input>( fn compute_denorm_information<'input>(
module: &[Directive<'input>], module: &[Directive<'input>],
) -> HashMap<MethodName<'input>, HashMap<u8, spirv::FPDenormMode>> { ) -> HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
let mut denorm_methods = HashMap::new(); let mut denorm_methods = HashMap::new();
for directive in module { for directive in module {
match directive { match directive {
@ -937,13 +978,13 @@ fn compute_denorm_information<'input>(
.map(|(name, v)| { .map(|(name, v)| {
let width_to_denorm = v let width_to_denorm = v
.into_iter() .into_iter()
.map(|(k, ftz_over_preserve)| { .map(|(k, flush_over_preserve)| {
let mode = if ftz_over_preserve > 0 { let mode = if flush_over_preserve > 0 {
spirv::FPDenormMode::FlushToZero spirv::FPDenormMode::FlushToZero
} else { } else {
spirv::FPDenormMode::Preserve spirv::FPDenormMode::Preserve
}; };
(k, mode) (k, (mode, flush_over_preserve))
}) })
.collect(); .collect();
(name, width_to_denorm) (name, width_to_denorm)
@ -999,7 +1040,7 @@ fn emit_function_header<'a>(
defined_globals: &GlobalStringIdResolver<'a>, defined_globals: &GlobalStringIdResolver<'a>,
synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>], synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
func_decl: &SpirvMethodDecl<'a>, func_decl: &SpirvMethodDecl<'a>,
_denorm_information: &HashMap<MethodName<'a>, HashMap<u8, spirv::FPDenormMode>>, _denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
call_map: &HashMap<&'a str, HashSet<spirv::Word>>, call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
direcitves: &[Directive], direcitves: &[Directive],
kernel_info: &mut HashMap<String, KernelInfo>, kernel_info: &mut HashMap<String, KernelInfo>,