mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-19 18:26:26 +03:00
Fix ftz behavior slightly
This commit is contained in:
@ -83,8 +83,11 @@ test_ptx!(extern_shared_call, [121u64], [123u64]);
|
|||||||
test_ptx!(rcp, [2f32], [0.5f32]);
|
test_ptx!(rcp, [2f32], [0.5f32]);
|
||||||
// 0b1_00000000_10000000000000000000000u32 is a large denormal
|
// 0b1_00000000_10000000000000000000000u32 is a large denormal
|
||||||
// 0x3f000000 is 0.5
|
// 0x3f000000 is 0.5
|
||||||
// TODO: mul_ftz fails because IGC does not yet handle SPV_INTEL_float_controls2
|
test_ptx!(
|
||||||
// test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]);
|
mul_ftz,
|
||||||
|
[0b1_00000000_10000000000000000000000u32, 0x3f000000u32],
|
||||||
|
[0b1_00000000_00000000000000000000000u32]
|
||||||
|
);
|
||||||
test_ptx!(
|
test_ptx!(
|
||||||
mul_non_ftz,
|
mul_non_ftz,
|
||||||
[0b1_00000000_10000000000000000000000u32, 0x3f000000u32],
|
[0b1_00000000_10000000000000000000000u32, 0x3f000000u32],
|
||||||
@ -196,7 +199,12 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
|
|||||||
let (module, maybe_log) = match module.should_link_ptx_impl {
|
let (module, maybe_log) = match module.should_link_ptx_impl {
|
||||||
Some(ptx_impl) => ze::Module::build_link_spirv(&mut ctx, &dev, &[ptx_impl, byte_il]),
|
Some(ptx_impl) => ze::Module::build_link_spirv(&mut ctx, &dev, &[ptx_impl, byte_il]),
|
||||||
None => {
|
None => {
|
||||||
let (module, log) = ze::Module::build_spirv(&mut ctx, &dev, byte_il, None);
|
let (module, log) = ze::Module::build_spirv(
|
||||||
|
&mut ctx,
|
||||||
|
&dev,
|
||||||
|
byte_il,
|
||||||
|
Some(module.build_options.as_c_str()),
|
||||||
|
);
|
||||||
(module, Some(log))
|
(module, Some(log))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1,64 +1,55 @@
|
|||||||
; SPIR-V
|
OpCapability GenericPointer
|
||||||
; Version: 1.3
|
OpCapability Linkage
|
||||||
; Generator: rspirv
|
OpCapability Addresses
|
||||||
; Bound: 38
|
OpCapability Kernel
|
||||||
OpCapability GenericPointer
|
OpCapability Int8
|
||||||
OpCapability Linkage
|
OpCapability Int16
|
||||||
OpCapability Addresses
|
OpCapability Int64
|
||||||
OpCapability Kernel
|
OpCapability Float16
|
||||||
OpCapability Int8
|
OpCapability Float64
|
||||||
OpCapability Int16
|
%28 = OpExtInstImport "OpenCL.std"
|
||||||
OpCapability Int64
|
OpMemoryModel Physical64 OpenCL
|
||||||
OpCapability Float16
|
OpEntryPoint Kernel %1 "mul_ftz"
|
||||||
OpCapability Float64
|
%void = OpTypeVoid
|
||||||
; OpCapability FunctionFloatControlINTEL
|
%ulong = OpTypeInt 64 0
|
||||||
; OpExtension "SPV_INTEL_float_controls2"
|
%31 = OpTypeFunction %void %ulong %ulong
|
||||||
%30 = OpExtInstImport "OpenCL.std"
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
OpMemoryModel Physical64 OpenCL
|
%float = OpTypeFloat 32
|
||||||
OpEntryPoint Kernel %1 "mul_ftz"
|
%_ptr_Function_float = OpTypePointer Function %float
|
||||||
OpDecorate %1 FunctionDenormModeINTEL 32 FlushToZero
|
%_ptr_Generic_float = OpTypePointer Generic %float
|
||||||
%31 = OpTypeVoid
|
%ulong_4 = OpConstant %ulong 4
|
||||||
%32 = OpTypeInt 64 0
|
%1 = OpFunction %void None %31
|
||||||
%33 = OpTypeFunction %31 %32 %32
|
%8 = OpFunctionParameter %ulong
|
||||||
%34 = OpTypePointer Function %32
|
%9 = OpFunctionParameter %ulong
|
||||||
%35 = OpTypeFloat 32
|
%26 = OpLabel
|
||||||
%36 = OpTypePointer Function %35
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
%37 = OpTypePointer Generic %35
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
%23 = OpConstant %32 4
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
%1 = OpFunction %31 None %33
|
%5 = OpVariable %_ptr_Function_ulong Function
|
||||||
%8 = OpFunctionParameter %32
|
%6 = OpVariable %_ptr_Function_float Function
|
||||||
%9 = OpFunctionParameter %32
|
%7 = OpVariable %_ptr_Function_float Function
|
||||||
%28 = OpLabel
|
OpStore %2 %8
|
||||||
%2 = OpVariable %34 Function
|
OpStore %3 %9
|
||||||
%3 = OpVariable %34 Function
|
%10 = OpLoad %ulong %2
|
||||||
%4 = OpVariable %34 Function
|
OpStore %4 %10
|
||||||
%5 = OpVariable %34 Function
|
%11 = OpLoad %ulong %3
|
||||||
%6 = OpVariable %36 Function
|
OpStore %5 %11
|
||||||
%7 = OpVariable %36 Function
|
%13 = OpLoad %ulong %4
|
||||||
OpStore %2 %8
|
%23 = OpConvertUToPtr %_ptr_Generic_float %13
|
||||||
OpStore %3 %9
|
%12 = OpLoad %float %23
|
||||||
%11 = OpLoad %32 %2
|
OpStore %6 %12
|
||||||
%10 = OpCopyObject %32 %11
|
%15 = OpLoad %ulong %4
|
||||||
OpStore %4 %10
|
%22 = OpIAdd %ulong %15 %ulong_4
|
||||||
%13 = OpLoad %32 %3
|
%24 = OpConvertUToPtr %_ptr_Generic_float %22
|
||||||
%12 = OpCopyObject %32 %13
|
%14 = OpLoad %float %24
|
||||||
OpStore %5 %12
|
OpStore %7 %14
|
||||||
%15 = OpLoad %32 %4
|
%17 = OpLoad %float %6
|
||||||
%25 = OpConvertUToPtr %37 %15
|
%18 = OpLoad %float %7
|
||||||
%14 = OpLoad %35 %25
|
%16 = OpFMul %float %17 %18
|
||||||
OpStore %6 %14
|
OpStore %6 %16
|
||||||
%17 = OpLoad %32 %4
|
%19 = OpLoad %ulong %5
|
||||||
%24 = OpIAdd %32 %17 %23
|
%20 = OpLoad %float %6
|
||||||
%26 = OpConvertUToPtr %37 %24
|
%25 = OpConvertUToPtr %_ptr_Generic_float %19
|
||||||
%16 = OpLoad %35 %26
|
OpStore %25 %20
|
||||||
OpStore %7 %16
|
OpReturn
|
||||||
%19 = OpLoad %35 %6
|
OpFunctionEnd
|
||||||
%20 = OpLoad %35 %7
|
|
||||||
%18 = OpFMul %35 %19 %20
|
|
||||||
OpStore %6 %18
|
|
||||||
%21 = OpLoad %32 %5
|
|
||||||
%22 = OpLoad %35 %6
|
|
||||||
%27 = OpConvertUToPtr %37 %21
|
|
||||||
OpStore %27 %22
|
|
||||||
OpReturn
|
|
||||||
OpFunctionEnd
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::ast;
|
use crate::ast;
|
||||||
use half::f16;
|
use half::f16;
|
||||||
use rspirv::{binary::Disassemble, dr};
|
use rspirv::{binary::Disassemble, dr};
|
||||||
use std::{borrow::Cow, convert::TryFrom, hash::Hash, iter, mem};
|
use std::{borrow::Cow, convert::TryFrom, ffi::CString, hash::Hash, iter, mem};
|
||||||
use std::{
|
use std::{
|
||||||
collections::{hash_map, HashMap, HashSet},
|
collections::{hash_map, HashMap, HashSet},
|
||||||
convert::TryInto,
|
convert::TryInto,
|
||||||
@ -448,6 +448,7 @@ pub struct Module {
|
|||||||
pub spirv: dr::Module,
|
pub spirv: dr::Module,
|
||||||
pub kernel_info: HashMap<String, KernelInfo>,
|
pub kernel_info: HashMap<String, KernelInfo>,
|
||||||
pub should_link_ptx_impl: Option<&'static [u8]>,
|
pub should_link_ptx_impl: Option<&'static [u8]>,
|
||||||
|
pub build_options: CString,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct KernelInfo {
|
pub struct KernelInfo {
|
||||||
@ -484,6 +485,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
|
|||||||
let mut map = TypeWordMap::new(&mut builder);
|
let mut map = TypeWordMap::new(&mut builder);
|
||||||
emit_builtins(&mut builder, &mut map, &id_defs);
|
emit_builtins(&mut builder, &mut map, &id_defs);
|
||||||
let mut kernel_info = HashMap::new();
|
let mut kernel_info = HashMap::new();
|
||||||
|
let build_options = emit_denorm_build_string(&call_map, &denorm_information);
|
||||||
emit_directives(
|
emit_directives(
|
||||||
&mut builder,
|
&mut builder,
|
||||||
&mut map,
|
&mut map,
|
||||||
@ -503,15 +505,51 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
},
|
},
|
||||||
|
build_options,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: remove this once we have per-function support for denorms
|
||||||
|
fn emit_denorm_build_string(
|
||||||
|
call_map: &HashMap<&str, HashSet<u32>>,
|
||||||
|
denorm_information: &HashMap<MethodName, HashMap<u8, (spirv::FPDenormMode, isize)>>,
|
||||||
|
) -> CString {
|
||||||
|
let denorm_counts = denorm_information
|
||||||
|
.iter()
|
||||||
|
.map(|(method, meth_denorm)| {
|
||||||
|
let f16_count = meth_denorm
|
||||||
|
.get(&(mem::size_of::<f16>() as u8))
|
||||||
|
.unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
|
||||||
|
.1;
|
||||||
|
let f32_count = meth_denorm
|
||||||
|
.get(&(mem::size_of::<f32>() as u8))
|
||||||
|
.unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
|
||||||
|
.1;
|
||||||
|
(method, (f16_count + f32_count))
|
||||||
|
})
|
||||||
|
.collect::<HashMap<_, _>>();
|
||||||
|
let mut flush_over_preserve = 0;
|
||||||
|
for (kernel, children) in call_map {
|
||||||
|
flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0);
|
||||||
|
for child_fn in children {
|
||||||
|
flush_over_preserve += *denorm_counts
|
||||||
|
.get(&MethodName::Func(*child_fn))
|
||||||
|
.unwrap_or(&0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if flush_over_preserve > 0 {
|
||||||
|
CString::new("-cl-denorms-are-zero").unwrap()
|
||||||
|
} else {
|
||||||
|
CString::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn emit_directives<'input>(
|
fn emit_directives<'input>(
|
||||||
builder: &mut dr::Builder,
|
builder: &mut dr::Builder,
|
||||||
map: &mut TypeWordMap,
|
map: &mut TypeWordMap,
|
||||||
id_defs: &GlobalStringIdResolver<'input>,
|
id_defs: &GlobalStringIdResolver<'input>,
|
||||||
opencl_id: spirv::Word,
|
opencl_id: spirv::Word,
|
||||||
denorm_information: &HashMap<MethodName<'input>, HashMap<u8, spirv::FPDenormMode>>,
|
denorm_information: &HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
|
||||||
call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
|
call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
|
||||||
directives: Vec<Directive>,
|
directives: Vec<Directive>,
|
||||||
kernel_info: &mut HashMap<String, KernelInfo>,
|
kernel_info: &mut HashMap<String, KernelInfo>,
|
||||||
@ -579,6 +617,9 @@ fn get_call_map<'input>(
|
|||||||
..
|
..
|
||||||
}) => {
|
}) => {
|
||||||
let call_key = MethodName::new(&func_decl);
|
let call_key = MethodName::new(&func_decl);
|
||||||
|
if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
|
||||||
|
entry.insert(Vec::new());
|
||||||
|
}
|
||||||
for statement in statements {
|
for statement in statements {
|
||||||
match statement {
|
match statement {
|
||||||
Statement::Call(call) => {
|
Statement::Call(call) => {
|
||||||
@ -895,7 +936,7 @@ fn denorm_count_map_update_impl<T: Eq + Hash>(
|
|||||||
// and emit suitable execution mode
|
// and emit suitable execution mode
|
||||||
fn compute_denorm_information<'input>(
|
fn compute_denorm_information<'input>(
|
||||||
module: &[Directive<'input>],
|
module: &[Directive<'input>],
|
||||||
) -> HashMap<MethodName<'input>, HashMap<u8, spirv::FPDenormMode>> {
|
) -> HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
|
||||||
let mut denorm_methods = HashMap::new();
|
let mut denorm_methods = HashMap::new();
|
||||||
for directive in module {
|
for directive in module {
|
||||||
match directive {
|
match directive {
|
||||||
@ -937,13 +978,13 @@ fn compute_denorm_information<'input>(
|
|||||||
.map(|(name, v)| {
|
.map(|(name, v)| {
|
||||||
let width_to_denorm = v
|
let width_to_denorm = v
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(k, ftz_over_preserve)| {
|
.map(|(k, flush_over_preserve)| {
|
||||||
let mode = if ftz_over_preserve > 0 {
|
let mode = if flush_over_preserve > 0 {
|
||||||
spirv::FPDenormMode::FlushToZero
|
spirv::FPDenormMode::FlushToZero
|
||||||
} else {
|
} else {
|
||||||
spirv::FPDenormMode::Preserve
|
spirv::FPDenormMode::Preserve
|
||||||
};
|
};
|
||||||
(k, mode)
|
(k, (mode, flush_over_preserve))
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
(name, width_to_denorm)
|
(name, width_to_denorm)
|
||||||
@ -999,7 +1040,7 @@ fn emit_function_header<'a>(
|
|||||||
defined_globals: &GlobalStringIdResolver<'a>,
|
defined_globals: &GlobalStringIdResolver<'a>,
|
||||||
synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
|
synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
|
||||||
func_decl: &SpirvMethodDecl<'a>,
|
func_decl: &SpirvMethodDecl<'a>,
|
||||||
_denorm_information: &HashMap<MethodName<'a>, HashMap<u8, spirv::FPDenormMode>>,
|
_denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
|
||||||
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
|
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
|
||||||
direcitves: &[Directive],
|
direcitves: &[Directive],
|
||||||
kernel_info: &mut HashMap<String, KernelInfo>,
|
kernel_info: &mut HashMap<String, KernelInfo>,
|
||||||
|
Reference in New Issue
Block a user