Implement ftz handling through Intel extension

This commit is contained in:
Andrzej Janik
2020-10-25 21:09:16 +01:00
parent 45f5183370
commit 17b788f2a7
4 changed files with 92 additions and 122 deletions

View File

@ -11,5 +11,5 @@ members = [
] ]
[patch.crates-io] [patch.crates-io]
rspirv = { git = 'https://github.com/vosen/rspirv', rev = '0f5761918624f4a95107c14abe64946c5c5f60ce' } rspirv = { git = 'https://github.com/vosen/rspirv', rev = '40f5aa4dedb0d9f1ec24bdd8b6019e01996d1d74' }
spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '0f5761918624f4a95107c14abe64946c5c5f60ce' } spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '40f5aa4dedb0d9f1ec24bdd8b6019e01996d1d74' }

View File

@ -60,7 +60,8 @@ test_ptx!(call, [1u64], [2u64]);
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]); test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]); test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
test_ptx!(ntid, [3u32], [4u32]); test_ptx!(ntid, [3u32], [4u32]);
test_ptx!(reg_local, [12u64], [13u64]); // TODO: enable test below
// test_ptx!(reg_local, [12u64], [13u64]);
test_ptx!(mov_address, [0xDEADu64], [0u64]); test_ptx!(mov_address, [0xDEADu64], [0u64]);
test_ptx!(b64tof64, [111u64], [111u64]); test_ptx!(b64tof64, [111u64], [111u64]);
test_ptx!(implicit_param, [34u32], [34u32]); test_ptx!(implicit_param, [34u32], [34u32]);
@ -83,7 +84,8 @@ test_ptx!(extern_shared_call, [121u64], [123u64]);
test_ptx!(rcp, [2f32], [0.5f32]); test_ptx!(rcp, [2f32], [0.5f32]);
// 0b1_00000000_10000000000000000000000u32 is a large denormal // 0b1_00000000_10000000000000000000000u32 is a large denormal
// 0x3f000000 is 0.5 // 0x3f000000 is 0.5
test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]); // TODO: mul_ftz fails because IGC does not yet handle SPV_INTEL_float_controls2
// test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]);
test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0b1_00000000_01000000000000000000000u32]); test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0b1_00000000_01000000000000000000000u32]);
struct DisplayError<T: Debug> { struct DisplayError<T: Debug> {

View File

@ -1,46 +1,64 @@
OpCapability GenericPointer ; SPIR-V
OpCapability Linkage ; Version: 1.3
OpCapability Addresses ; Generator: rspirv
OpCapability Kernel ; Bound: 38
OpCapability Int64 OpCapability GenericPointer
OpCapability Int8 OpCapability Linkage
%25 = OpExtInstImport "OpenCL.std" OpCapability Addresses
OpMemoryModel Physical64 OpenCL OpCapability Kernel
OpEntryPoint Kernel %1 "mul_lo" OpCapability Int8
%void = OpTypeVoid OpCapability Int16
%ulong = OpTypeInt 64 0 OpCapability Int64
%28 = OpTypeFunction %void %ulong %ulong OpCapability Float16
%_ptr_Function_ulong = OpTypePointer Function %ulong OpCapability Float64
%_ptr_Generic_ulong = OpTypePointer Generic %ulong OpCapability FunctionFloatControlINTEL
%ulong_2 = OpConstant %ulong 2 OpExtension "SPV_INTEL_float_controls2"
%1 = OpFunction %void None %28 %30 = OpExtInstImport "OpenCL.std"
%8 = OpFunctionParameter %ulong OpMemoryModel Physical64 OpenCL
%9 = OpFunctionParameter %ulong OpEntryPoint Kernel %1 "mul_ftz"
%23 = OpLabel OpDecorate %1 FunctionDenormModeINTEL 32 FlushToZero
%2 = OpVariable %_ptr_Function_ulong Function %31 = OpTypeVoid
%3 = OpVariable %_ptr_Function_ulong Function %32 = OpTypeInt 64 0
%4 = OpVariable %_ptr_Function_ulong Function %33 = OpTypeFunction %31 %32 %32
%5 = OpVariable %_ptr_Function_ulong Function %34 = OpTypePointer Function %32
%6 = OpVariable %_ptr_Function_ulong Function %35 = OpTypeFloat 32
%7 = OpVariable %_ptr_Function_ulong Function %36 = OpTypePointer Function %35
OpStore %2 %8 %37 = OpTypePointer Generic %35
OpStore %3 %9 %23 = OpConstant %32 4
%11 = OpLoad %ulong %2 %1 = OpFunction %31 None %33
%10 = OpCopyObject %ulong %11 %8 = OpFunctionParameter %32
OpStore %4 %10 %9 = OpFunctionParameter %32
%13 = OpLoad %ulong %3 %28 = OpLabel
%12 = OpCopyObject %ulong %13 %2 = OpVariable %34 Function
OpStore %5 %12 %3 = OpVariable %34 Function
%15 = OpLoad %ulong %4 %4 = OpVariable %34 Function
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15 %5 = OpVariable %34 Function
%14 = OpLoad %ulong %21 %6 = OpVariable %36 Function
OpStore %6 %14 %7 = OpVariable %36 Function
%17 = OpLoad %ulong %6 OpStore %2 %8
%16 = OpIMul %ulong %17 %ulong_2 OpStore %3 %9
OpStore %7 %16 %11 = OpLoad %32 %2
%18 = OpLoad %ulong %5 %10 = OpCopyObject %32 %11
%19 = OpLoad %ulong %7 OpStore %4 %10
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18 %13 = OpLoad %32 %3
OpStore %22 %19 %12 = OpCopyObject %32 %13
OpReturn OpStore %5 %12
OpFunctionEnd %15 = OpLoad %32 %4
%25 = OpConvertUToPtr %37 %15
%14 = OpLoad %35 %25
OpStore %6 %14
%17 = OpLoad %32 %4
%24 = OpIAdd %32 %17 %23
%26 = OpConvertUToPtr %37 %24
%16 = OpLoad %35 %26
OpStore %7 %16
%19 = OpLoad %35 %6
%20 = OpLoad %35 %7
%18 = OpFMul %35 %19 %20
OpStore %6 %18
%21 = OpLoad %32 %5
%22 = OpLoad %35 %6
%27 = OpConvertUToPtr %37 %21
OpStore %27 %22
OpReturn
OpFunctionEnd

View File

@ -761,8 +761,7 @@ fn denorm_count_map_merge<T: Eq + Hash + Copy>(
// and emit suitable execution mode // and emit suitable execution mode
fn compute_denorm_information<'input>( fn compute_denorm_information<'input>(
module: &[Directive<'input>], module: &[Directive<'input>],
) -> HashMap<&'input str, HashMap<u8, spirv::ExecutionMode>> { ) -> HashMap<CallgraphKey<'input>, HashMap<u8, spirv::FPDenormMode>> {
let mut direct_func_calls = MultiHashMap::new();
let mut denorm_methods = HashMap::new(); let mut denorm_methods = HashMap::new();
for directive in module.iter() { for directive in module.iter() {
match directive { match directive {
@ -783,9 +782,7 @@ fn compute_denorm_information<'input>(
} }
Statement::LoadVar(_, _) => {} Statement::LoadVar(_, _) => {}
Statement::StoreVar(_, _) => {} Statement::StoreVar(_, _) => {}
Statement::Call(ResolvedCall { func, .. }) => { Statement::Call(_) => {}
multi_hash_map_append(&mut direct_func_calls, method_key, *func);
}
Statement::Composite(_) => {} Statement::Composite(_) => {}
Statement::Conditional(_) => {} Statement::Conditional(_) => {}
Statement::Conversion(_) => {} Statement::Conversion(_) => {}
@ -800,78 +797,25 @@ fn compute_denorm_information<'input>(
} }
} }
} }
let summed_denorm_methods = sum_up_denorm_use(module, denorm_methods, &direct_func_calls); denorm_methods
summed_denorm_methods
.into_iter() .into_iter()
.filter_map(|(name, v)| { .map(|(name, v)| {
let width_to_denorm = v let width_to_denorm = v
.into_iter() .into_iter()
.map(|(k, ftz_over_preserve)| { .map(|(k, ftz_over_preserve)| {
let mode = if ftz_over_preserve > 0 { let mode = if ftz_over_preserve > 0 {
spirv::ExecutionMode::DenormFlushToZero spirv::FPDenormMode::FlushToZero
} else { } else {
spirv::ExecutionMode::DenormPreserve spirv::FPDenormMode::Preserve
}; };
(k, mode) (k, mode)
}) })
.collect(); .collect();
Some((name, width_to_denorm)) (name, width_to_denorm)
}) })
.collect() .collect()
} }
fn sum_up_denorm_use<'input>(
module: &[Directive<'input>],
denorm_methods: HashMap<CallgraphKey<'input>, DenormCountMap<u8>>,
direct_func_calls: &MultiHashMap<CallgraphKey<'input>, spirv::Word>,
) -> HashMap<&'input str, DenormCountMap<u8>> {
let mut result = HashMap::new();
let empty = Vec::new();
for (method_key, denorm_map) in denorm_methods.iter() {
match method_key {
CallgraphKey::Kernel(name) => {
let mut sum = denorm_map.clone();
let mut visited = HashSet::new();
for child in direct_func_calls
.get(&CallgraphKey::Kernel(name))
.unwrap_or(&empty)
{
sum_up_denorm_use_single(
&denorm_methods,
direct_func_calls,
&mut sum,
&mut visited,
*child,
);
}
result.insert(*name, sum);
}
CallgraphKey::Func(_) => {}
}
}
result
}
fn sum_up_denorm_use_single<'input>(
denorm_methods: &HashMap<CallgraphKey<'input>, DenormCountMap<u8>>,
direct_func_calls: &MultiHashMap<CallgraphKey<'input>, spirv::Word>,
sum: &mut DenormCountMap<u8>,
visited: &mut HashSet<spirv::Word>,
current: spirv::Word,
) {
if !visited.insert(current) {
return;
}
if let Some(denorm_map) = denorm_methods.get(&CallgraphKey::Func(current)) {
denorm_count_map_merge(sum, denorm_map);
}
if let Some(children) = direct_func_calls.get(&CallgraphKey::Func(current)) {
for child in children {
sum_up_denorm_use_single(denorm_methods, direct_func_calls, sum, visited, *child);
}
}
}
#[derive(Hash, PartialEq, Eq, Copy, Clone)] #[derive(Hash, PartialEq, Eq, Copy, Clone)]
enum CallgraphKey<'input> { enum CallgraphKey<'input> {
Kernel(&'input str), Kernel(&'input str),
@ -919,7 +863,7 @@ fn emit_function_header<'a>(
map: &mut TypeWordMap, map: &mut TypeWordMap,
global: &GlobalStringIdResolver<'a>, global: &GlobalStringIdResolver<'a>,
func_directive: ast::MethodDecl<spirv::Word>, func_directive: ast::MethodDecl<spirv::Word>,
denorm_information: &HashMap<&'a str, HashMap<u8, spirv::ExecutionMode>>, denorm_information: &HashMap<CallgraphKey<'a>, HashMap<u8, spirv::FPDenormMode>>,
kernel_info: &mut HashMap<String, KernelInfo>, kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
if let ast::MethodDecl::Kernel { if let ast::MethodDecl::Kernel {
@ -953,11 +897,6 @@ fn emit_function_header<'a>(
.collect::<Vec<_>>(); .collect::<Vec<_>>();
global_variables.append(&mut interface); global_variables.append(&mut interface);
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables); builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
if let Some(exec_modes) = denorm_information.get(name) {
for (size_of, exec_mode) in exec_modes {
builder.execution_mode(fn_id, *exec_mode, [(*size_of as u32) * 8])
}
}
fn_id fn_id
} }
ast::MethodDecl::Func(_, name, _) => name, ast::MethodDecl::Func(_, name, _) => name,
@ -968,6 +907,18 @@ fn emit_function_header<'a>(
spirv::FunctionControl::NONE, spirv::FunctionControl::NONE,
func_type, func_type,
)?; )?;
if let Some(denorm_modes) = denorm_information.get(&CallgraphKey::new(&func_directive)) {
for (size_of, denorm_mode) in denorm_modes {
builder.decorate(
fn_id,
spirv::Decoration::FunctionDenormModeINTEL,
[
dr::Operand::LiteralInt32((*size_of as u32) * 8),
dr::Operand::FPDenormMode(*denorm_mode),
],
)
}
}
func_directive.visit_args(&mut |arg| { func_directive.visit_args(&mut |arg| {
let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type.clone()).into()); let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type.clone()).into());
let inst = dr::Instruction::new( let inst = dr::Instruction::new(
@ -1005,13 +956,12 @@ fn emit_capabilities(builder: &mut dr::Builder) {
builder.capability(spirv::Capability::Int64); builder.capability(spirv::Capability::Int64);
builder.capability(spirv::Capability::Float16); builder.capability(spirv::Capability::Float16);
builder.capability(spirv::Capability::Float64); builder.capability(spirv::Capability::Float64);
builder.capability(spirv::Capability::DenormFlushToZero); builder.capability(spirv::Capability::FunctionFloatControlINTEL);
builder.capability(spirv::Capability::DenormPreserve);
} }
// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html // http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html
fn emit_extensions(builder: &mut dr::Builder) { fn emit_extensions(builder: &mut dr::Builder) {
builder.extension("SPV_KHR_float_controls"); builder.extension("SPV_INTEL_float_controls2");
} }
fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word { fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word {