mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-20 02:36:35 +03:00
Support kernel tuning directives
This commit is contained in:
@ -283,6 +283,7 @@ pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>;
|
|||||||
|
|
||||||
pub struct Function<'a, ID, S> {
|
pub struct Function<'a, ID, S> {
|
||||||
pub func_directive: MethodDecl<'a, ID>,
|
pub func_directive: MethodDecl<'a, ID>,
|
||||||
|
pub tuning: Vec<TuningDirective>,
|
||||||
pub body: Option<Vec<S>>,
|
pub body: Option<Vec<S>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1369,6 +1370,14 @@ bitflags! {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||||
|
pub enum TuningDirective {
|
||||||
|
MaxNReg(u32),
|
||||||
|
MaxNtid(u32, u32, u32),
|
||||||
|
ReqNtid(u32, u32, u32),
|
||||||
|
MinNCtaPerSm(u32),
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -87,6 +87,9 @@ match {
|
|||||||
".ltu",
|
".ltu",
|
||||||
".lu",
|
".lu",
|
||||||
".max",
|
".max",
|
||||||
|
".maxnreg",
|
||||||
|
".maxntid",
|
||||||
|
".minnctapersm",
|
||||||
".min",
|
".min",
|
||||||
".nan",
|
".nan",
|
||||||
".NaN",
|
".NaN",
|
||||||
@ -100,6 +103,7 @@ match {
|
|||||||
".reg",
|
".reg",
|
||||||
".relaxed",
|
".relaxed",
|
||||||
".release",
|
".release",
|
||||||
|
".reqntid",
|
||||||
".rm",
|
".rm",
|
||||||
".rmi",
|
".rmi",
|
||||||
".rn",
|
".rn",
|
||||||
@ -356,15 +360,27 @@ AddressSize = {
|
|||||||
Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = {
|
Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = {
|
||||||
LinkingDirectives
|
LinkingDirectives
|
||||||
<func_directive:MethodDecl>
|
<func_directive:MethodDecl>
|
||||||
|
<tuning:TuningDirective*>
|
||||||
<body:FunctionBody> => ast::Function{<>}
|
<body:FunctionBody> => ast::Function{<>}
|
||||||
};
|
};
|
||||||
|
|
||||||
LinkingDirective: ast::LinkingDirective = {
|
LinkingDirective: ast::LinkingDirective = {
|
||||||
".extern" => ast::LinkingDirective::EXTERN,
|
".extern" => ast::LinkingDirective::EXTERN,
|
||||||
".visible" => ast::LinkingDirective::VISIBLE,
|
".visible" => ast::LinkingDirective::VISIBLE,
|
||||||
".weak" => ast::LinkingDirective::WEAK,
|
".weak" => ast::LinkingDirective::WEAK,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
TuningDirective: ast::TuningDirective = {
|
||||||
|
".maxnreg" <ncta:U32Num> => ast::TuningDirective::MaxNReg(ncta),
|
||||||
|
".maxntid" <nx:U32Num> => ast::TuningDirective::MaxNtid(nx, 1, 1),
|
||||||
|
".maxntid" <nx:U32Num> "," <ny:U32Num> => ast::TuningDirective::MaxNtid(nx, ny, 1),
|
||||||
|
".maxntid" <nx:U32Num> "," <ny:U32Num> "," <nz:U32Num> => ast::TuningDirective::MaxNtid(nx, ny, nz),
|
||||||
|
".reqntid" <nx:U32Num> => ast::TuningDirective::ReqNtid(nx, 1, 1),
|
||||||
|
".reqntid" <nx:U32Num> "," <ny:U32Num> => ast::TuningDirective::ReqNtid(nx, ny, 1),
|
||||||
|
".reqntid" <nx:U32Num> "," <ny:U32Num> "," <nz:U32Num> => ast::TuningDirective::ReqNtid(nx, ny, nz),
|
||||||
|
".minnctapersm" <ncta:U32Num> => ast::TuningDirective::MinNCtaPerSm(ncta),
|
||||||
|
};
|
||||||
|
|
||||||
LinkingDirectives: ast::LinkingDirective = {
|
LinkingDirectives: ast::LinkingDirective = {
|
||||||
<ldirs:LinkingDirective*> => {
|
<ldirs:LinkingDirective*> => {
|
||||||
ldirs.into_iter().fold(ast::LinkingDirective::NONE, |x, y| x | y)
|
ldirs.into_iter().fold(ast::LinkingDirective::NONE, |x, y| x | y)
|
||||||
|
24
ptx/src/test/spirv_run/add_tuning.ptx
Normal file
24
ptx/src/test/spirv_run/add_tuning.ptx
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry add_tuning(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
.maxntid 256, 1, 1
|
||||||
|
.minnctapersm 4
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .u64 temp;
|
||||||
|
.reg .u64 temp2;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.u64 temp, [in_addr];
|
||||||
|
add.u64 temp2, temp, 1;
|
||||||
|
st.u64 [out_addr], temp2;
|
||||||
|
ret;
|
||||||
|
}
|
48
ptx/src/test/spirv_run/add_tuning.spvtxt
Normal file
48
ptx/src/test/spirv_run/add_tuning.spvtxt
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
OpCapability GenericPointer
|
||||||
|
OpCapability Linkage
|
||||||
|
OpCapability Addresses
|
||||||
|
OpCapability Kernel
|
||||||
|
OpCapability Int8
|
||||||
|
OpCapability Int16
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability Float64
|
||||||
|
%23 = OpExtInstImport "OpenCL.std"
|
||||||
|
OpMemoryModel Physical64 OpenCL
|
||||||
|
OpEntryPoint Kernel %1 "add_tuning"
|
||||||
|
OpExecutionMode %1 MaxWorkgroupSizeINTEL 256 1 1
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%ulong = OpTypeInt 64 0
|
||||||
|
%26 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
|
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||||
|
%ulong_1 = OpConstant %ulong 1
|
||||||
|
%1 = OpFunction %void None %26
|
||||||
|
%8 = OpFunctionParameter %ulong
|
||||||
|
%9 = OpFunctionParameter %ulong
|
||||||
|
%21 = OpLabel
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%5 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%6 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%7 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
OpStore %2 %8
|
||||||
|
OpStore %3 %9
|
||||||
|
%10 = OpLoad %ulong %2 Aligned 8
|
||||||
|
OpStore %4 %10
|
||||||
|
%11 = OpLoad %ulong %3 Aligned 8
|
||||||
|
OpStore %5 %11
|
||||||
|
%13 = OpLoad %ulong %4
|
||||||
|
%19 = OpConvertUToPtr %_ptr_Generic_ulong %13
|
||||||
|
%12 = OpLoad %ulong %19 Aligned 8
|
||||||
|
OpStore %6 %12
|
||||||
|
%15 = OpLoad %ulong %6
|
||||||
|
%14 = OpIAdd %ulong %15 %ulong_1
|
||||||
|
OpStore %7 %14
|
||||||
|
%16 = OpLoad %ulong %5
|
||||||
|
%17 = OpLoad %ulong %7
|
||||||
|
%20 = OpConvertUToPtr %_ptr_Generic_ulong %16
|
||||||
|
OpStore %20 %17 Aligned 8
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
@ -152,6 +152,7 @@ test_ptx!(shared_ptr_take_address, [97815231u64], [97815231u64]);
|
|||||||
// For now, we just make sure that it builds and links
|
// For now, we just make sure that it builds and links
|
||||||
test_ptx!(assertfail, [716523871u64], [716523872u64]);
|
test_ptx!(assertfail, [716523871u64], [716523872u64]);
|
||||||
test_ptx!(cvt_s64_s32, [-1i32], [-1i64]);
|
test_ptx!(cvt_s64_s32, [-1i32], [-1i64]);
|
||||||
|
test_ptx!(add_tuning, [2u64], [3u64]);
|
||||||
|
|
||||||
struct DisplayError<T: Debug> {
|
struct DisplayError<T: Debug> {
|
||||||
err: T,
|
err: T,
|
||||||
|
@ -589,7 +589,7 @@ fn emit_directives<'input>(
|
|||||||
for var in f.globals.iter() {
|
for var in f.globals.iter() {
|
||||||
emit_variable(builder, map, var)?;
|
emit_variable(builder, map, var)?;
|
||||||
}
|
}
|
||||||
emit_function_header(
|
let fn_id = emit_function_header(
|
||||||
builder,
|
builder,
|
||||||
map,
|
map,
|
||||||
&id_defs,
|
&id_defs,
|
||||||
@ -600,6 +600,27 @@ fn emit_directives<'input>(
|
|||||||
&directives,
|
&directives,
|
||||||
kernel_info,
|
kernel_info,
|
||||||
)?;
|
)?;
|
||||||
|
for t in f.tuning.iter() {
|
||||||
|
match *t {
|
||||||
|
ast::TuningDirective::MaxNtid(nx, ny, nz) => {
|
||||||
|
builder.execution_mode(
|
||||||
|
fn_id,
|
||||||
|
spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL,
|
||||||
|
[nx, ny, nz],
|
||||||
|
);
|
||||||
|
}
|
||||||
|
ast::TuningDirective::ReqNtid(nx, ny, nz) => {
|
||||||
|
builder.execution_mode(
|
||||||
|
fn_id,
|
||||||
|
spirv_headers::ExecutionMode::LocalSize,
|
||||||
|
[nx, ny, nz],
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// Too architecture specific
|
||||||
|
ast::TuningDirective::MaxNReg(..)
|
||||||
|
| ast::TuningDirective::MinNCtaPerSm(..) => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
emit_function_body_ops(builder, map, opencl_id, &f_body)?;
|
emit_function_body_ops(builder, map, opencl_id, &f_body)?;
|
||||||
builder.end_function()?;
|
builder.end_function()?;
|
||||||
if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
|
if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
|
||||||
@ -729,6 +750,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||||||
body: Some(statements),
|
body: Some(statements),
|
||||||
import_as,
|
import_as,
|
||||||
spirv_decl,
|
spirv_decl,
|
||||||
|
tuning,
|
||||||
}) => {
|
}) => {
|
||||||
let call_key = MethodName::new(&func_decl);
|
let call_key = MethodName::new(&func_decl);
|
||||||
let statements = statements
|
let statements = statements
|
||||||
@ -752,6 +774,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||||||
body: Some(statements),
|
body: Some(statements),
|
||||||
import_as,
|
import_as,
|
||||||
spirv_decl,
|
spirv_decl,
|
||||||
|
tuning,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
directive => directive,
|
directive => directive,
|
||||||
@ -770,6 +793,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||||||
body: Some(statements),
|
body: Some(statements),
|
||||||
import_as,
|
import_as,
|
||||||
mut spirv_decl,
|
mut spirv_decl,
|
||||||
|
tuning,
|
||||||
}) => {
|
}) => {
|
||||||
if !methods_using_extern_shared.contains(&spirv_decl.name) {
|
if !methods_using_extern_shared.contains(&spirv_decl.name) {
|
||||||
return Directive::Method(Function {
|
return Directive::Method(Function {
|
||||||
@ -778,6 +802,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||||||
body: Some(statements),
|
body: Some(statements),
|
||||||
import_as,
|
import_as,
|
||||||
spirv_decl,
|
spirv_decl,
|
||||||
|
tuning,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
let shared_id_param = new_id();
|
let shared_id_param = new_id();
|
||||||
@ -827,6 +852,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||||||
body: Some(new_statements),
|
body: Some(new_statements),
|
||||||
import_as,
|
import_as,
|
||||||
spirv_decl,
|
spirv_decl,
|
||||||
|
tuning,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
directive => directive,
|
directive => directive,
|
||||||
@ -1044,9 +1070,7 @@ fn emit_builtins(
|
|||||||
builder.decorate(
|
builder.decorate(
|
||||||
id,
|
id,
|
||||||
spirv::Decoration::BuiltIn,
|
spirv::Decoration::BuiltIn,
|
||||||
[dr::Operand::BuiltIn(reg.get_builtin())]
|
[dr::Operand::BuiltIn(reg.get_builtin())].iter().cloned(),
|
||||||
.iter()
|
|
||||||
.cloned(),
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1061,7 +1085,7 @@ fn emit_function_header<'a>(
|
|||||||
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
|
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
|
||||||
direcitves: &[Directive],
|
direcitves: &[Directive],
|
||||||
kernel_info: &mut HashMap<String, KernelInfo>,
|
kernel_info: &mut HashMap<String, KernelInfo>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<spirv::Word, TranslateError> {
|
||||||
if let MethodName::Kernel(name) = func_decl.name {
|
if let MethodName::Kernel(name) = func_decl.name {
|
||||||
let input_args = if !func_decl.uses_shared_mem {
|
let input_args = if !func_decl.uses_shared_mem {
|
||||||
func_decl.input.as_slice()
|
func_decl.input.as_slice()
|
||||||
@ -1143,7 +1167,7 @@ fn emit_function_header<'a>(
|
|||||||
let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone()));
|
let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone()));
|
||||||
builder.function_parameter(Some(input.name), result_type)?;
|
builder.function_parameter(Some(input.name), result_type)?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(fn_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn emit_capabilities(builder: &mut dr::Builder) {
|
fn emit_capabilities(builder: &mut dr::Builder) {
|
||||||
@ -1235,7 +1259,14 @@ fn translate_function<'a>(
|
|||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
|
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
|
||||||
let mut func = to_ssa(ptx_impl_imports, str_resolver, fn_resolver, fn_decl, f.body)?;
|
let mut func = to_ssa(
|
||||||
|
ptx_impl_imports,
|
||||||
|
str_resolver,
|
||||||
|
fn_resolver,
|
||||||
|
fn_decl,
|
||||||
|
f.body,
|
||||||
|
f.tuning,
|
||||||
|
)?;
|
||||||
func.import_as = import_as;
|
func.import_as = import_as;
|
||||||
if func.import_as.is_some() {
|
if func.import_as.is_some() {
|
||||||
ptx_impl_imports.insert(
|
ptx_impl_imports.insert(
|
||||||
@ -1293,6 +1324,7 @@ fn to_ssa<'input, 'b>(
|
|||||||
fn_defs: GlobalFnDeclResolver<'input, 'b>,
|
fn_defs: GlobalFnDeclResolver<'input, 'b>,
|
||||||
f_args: ast::MethodDecl<'input, spirv::Word>,
|
f_args: ast::MethodDecl<'input, spirv::Word>,
|
||||||
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
|
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
|
||||||
|
tuning: Vec<ast::TuningDirective>,
|
||||||
) -> Result<Function<'input>, TranslateError> {
|
) -> Result<Function<'input>, TranslateError> {
|
||||||
let mut spirv_decl = SpirvMethodDecl::new(&f_args);
|
let mut spirv_decl = SpirvMethodDecl::new(&f_args);
|
||||||
let f_body = match f_body {
|
let f_body = match f_body {
|
||||||
@ -1304,6 +1336,7 @@ fn to_ssa<'input, 'b>(
|
|||||||
globals: Vec::new(),
|
globals: Vec::new(),
|
||||||
import_as: None,
|
import_as: None,
|
||||||
spirv_decl,
|
spirv_decl,
|
||||||
|
tuning,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1335,6 +1368,7 @@ fn to_ssa<'input, 'b>(
|
|||||||
body: Some(f_body),
|
body: Some(f_body),
|
||||||
import_as: None,
|
import_as: None,
|
||||||
spirv_decl,
|
spirv_decl,
|
||||||
|
tuning,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1716,6 +1750,7 @@ fn to_ptx_impl_atomic_call(
|
|||||||
body: None,
|
body: None,
|
||||||
import_as: Some(entry.key().clone()),
|
import_as: Some(entry.key().clone()),
|
||||||
spirv_decl,
|
spirv_decl,
|
||||||
|
tuning: Vec::new(),
|
||||||
};
|
};
|
||||||
entry.insert(Directive::Method(func));
|
entry.insert(Directive::Method(func));
|
||||||
fn_id
|
fn_id
|
||||||
@ -1809,6 +1844,7 @@ fn to_ptx_impl_bfe_call(
|
|||||||
body: None,
|
body: None,
|
||||||
import_as: Some(entry.key().clone()),
|
import_as: Some(entry.key().clone()),
|
||||||
spirv_decl,
|
spirv_decl,
|
||||||
|
tuning: Vec::new(),
|
||||||
};
|
};
|
||||||
entry.insert(Directive::Method(func));
|
entry.insert(Directive::Method(func));
|
||||||
fn_id
|
fn_id
|
||||||
@ -1907,6 +1943,7 @@ fn to_ptx_impl_bfi_call(
|
|||||||
body: None,
|
body: None,
|
||||||
import_as: Some(entry.key().clone()),
|
import_as: Some(entry.key().clone()),
|
||||||
spirv_decl,
|
spirv_decl,
|
||||||
|
tuning: Vec::new(),
|
||||||
};
|
};
|
||||||
entry.insert(Directive::Method(func));
|
entry.insert(Directive::Method(func));
|
||||||
fn_id
|
fn_id
|
||||||
@ -4112,16 +4149,11 @@ fn struct2_bitcast_to_wide(
|
|||||||
dst_type_id: spirv::Word,
|
dst_type_id: spirv::Word,
|
||||||
src: spirv::Word,
|
src: spirv::Word,
|
||||||
) -> Result<(), dr::Error> {
|
) -> Result<(), dr::Error> {
|
||||||
let low_bits =
|
let low_bits = builder.composite_extract(instruction_type, None, src, [0].iter().copied())?;
|
||||||
builder.composite_extract(instruction_type, None, src, [0].iter().copied())?;
|
let high_bits = builder.composite_extract(instruction_type, None, src, [1].iter().copied())?;
|
||||||
let high_bits =
|
|
||||||
builder.composite_extract(instruction_type, None, src, [1].iter().copied())?;
|
|
||||||
let vector_type = map.get_or_add(builder, SpirvType::Vector(base_type_key, 2));
|
let vector_type = map.get_or_add(builder, SpirvType::Vector(base_type_key, 2));
|
||||||
let vector = builder.composite_construct(
|
let vector =
|
||||||
vector_type,
|
builder.composite_construct(vector_type, None, [low_bits, high_bits].iter().copied())?;
|
||||||
None,
|
|
||||||
[low_bits, high_bits].iter().copied(),
|
|
||||||
)?;
|
|
||||||
builder.bitcast(dst_type_id, Some(dst), vector)?;
|
builder.bitcast(dst_type_id, Some(dst), vector)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -5668,6 +5700,7 @@ struct Function<'input> {
|
|||||||
pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
|
pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
|
||||||
pub body: Option<Vec<ExpandedStatement>>,
|
pub body: Option<Vec<ExpandedStatement>>,
|
||||||
import_as: Option<String>,
|
import_as: Option<String>,
|
||||||
|
tuning: Vec<ast::TuningDirective>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
|
pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
|
||||||
|
Reference in New Issue
Block a user