Carry state space with pointer

This commit is contained in:
Andrzej Janik
2021-05-15 15:58:11 +02:00
parent 425edfcdd4
commit 82b5cef0bd
3 changed files with 108 additions and 106 deletions

View File

@ -108,10 +108,49 @@ pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a
#[derive(PartialEq, Eq, Clone)]
pub enum Type {
// .param.b32 foo;
// -> OpTypeInt
Scalar(ScalarType),
// .param.v2.b32 foo;
// -> OpTypeVector
Vector(ScalarType, u8),
// .param.b32 foo[4];
// -> OpTypeArray
Array(ScalarType, Vec<u32>),
Pointer(ScalarType),
/*
Variables of this type almost never exist in the original .ptx and are
usually artificially created. Some examples below:
- extern pointers to the .shared memory in the form:
.extern .shared .b32 shared_mem[];
which we first parse as
.extern .shared .b32 shared_mem;
and then convert to an additional function parameter:
.param .ptr<.b32.shared> shared_mem;
and do a load at the start of the function (and renames inside fn):
.reg .ptr<.b32.shared> temp;
ld.param.ptr<.b32.shared> temp, [shared_mem];
note, we don't support non-.shared extern pointers, because there's
zero use for them in the ptxas
- artifical pointers created by stateful conversion, which work
similiarly to the above
- function parameters:
foobar(.param .align 4 .b8 numbers[])
which get parsed to
foobar(.param .align 4 .b8 numbers)
and then converted to
foobar(.reg .align 4 .ptr<.b8.param> numbers)
- ld/st with offset:
.reg.b32 x;
.param.b64 arg0;
st.param.b32 [arg0+4], x;
Yes, this code is legal and actually emitted by the NV compiler!
We convert the st to:
.reg ptr<.b64.param> temp = ptr_offset(arg0, 4);
st.param.b32 [temp], x;
*/
// .reg ptr<.b64.param>
// -> OpTypePointer Function
Pointer(ScalarType, StateSpace),
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]

View File

@ -624,9 +624,9 @@ ModuleVariable: ast::Variable<&'input str> = {
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
}
if space == ".global" {
(ast::Type::Pointer(t), ast::StateSpace::Global, Vec::new())
(ast::Type::Scalar(t), ast::StateSpace::Global, Vec::new())
} else {
(ast::Type::Pointer(t), ast::StateSpace::Shared, Vec::new())
(ast::Type::Scalar(t), ast::StateSpace::Shared, Vec::new())
}
}
};
@ -648,7 +648,7 @@ ParamVariable: (Option<u32>, Vec<u8>, ast::Type, &'input str) = {
(ast::Type::Array(t, dimensions), init)
}
ast::ArrayOrPointer::Pointer => {
(ast::Type::Pointer(t), Vec::new())
(ast::Type::Scalar(t), Vec::new())
}
};
(align, array_init, v_type, name)

View File

@ -56,33 +56,20 @@ enum SpirvType {
}
impl SpirvType {
fn new(t: ast::Type, decl_space: ast::StateSpace) -> Self {
fn new(t: ast::Type) -> Self {
match t {
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
ast::Type::Pointer(pointer_t) => {
let spirv_space = match decl_space {
ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
spirv::StorageClass::Private
}
ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup,
ast::StateSpace::Const => spirv::StorageClass::UniformConstant,
ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
ast::StateSpace::Generic => spirv::StorageClass::Generic,
ast::StateSpace::Sreg => spirv::StorageClass::Input,
};
SpirvType::Pointer(Box::new(SpirvType::Base(pointer_t.into())), spirv_space)
}
ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer(
Box::new(SpirvType::Base(pointer_t.into())),
space.to_spirv(),
),
}
}
fn pointer_to(
t: ast::Type,
inner_space: ast::StateSpace,
outer_space: spirv::StorageClass,
) -> Self {
let key = Self::new(t, inner_space);
fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self {
let key = Self::new(t);
SpirvType::Pointer(Box::new(key), outer_space)
}
}
@ -394,7 +381,7 @@ impl TypeWordMap {
b.constant_composite(result_type, None, components.into_iter())
}
},
ast::Type::Pointer(typ) => return Err(error_unreachable()),
ast::Type::Pointer(..) => return Err(error_unreachable()),
})
}
@ -453,7 +440,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
})
.collect::<Result<Vec<_>, _>>()?;
let must_link_ptx_impl = ptx_impl_imports.len() > 0;
let directives = ptx_impl_imports
let mut directives = ptx_impl_imports
.into_iter()
.map(|(_, v)| v)
.chain(directives.into_iter())
@ -461,7 +448,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
let call_map = get_kernels_call_map(&directives);
let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
//let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
@ -725,6 +712,7 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
transformation has a semantical meaning - we emit additional
"OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
*/
/*
fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>,
new_id: &mut impl FnMut() -> spirv::Word,
@ -819,7 +807,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
ast::Variable {
name: shared_id_param,
align: None,
v_type: ast::Type::Pointer(ast::ScalarType::B8),
v_type: ast::Type::Pointer(ast::ScalarType::B8, new_todo!()),
state_space: ast::StateSpace::Shared,
array_init: Vec::new(),
}
@ -937,6 +925,7 @@ fn get_callers_of_extern_shared_single<'a>(
}
}
}
*/
type DenormCountMap<T> = HashMap<T, isize>;
@ -1031,11 +1020,7 @@ fn emit_builtins(
for (reg, id) in id_defs.special_registers.builtins() {
let result_type = map.get_or_add(
builder,
SpirvType::pointer_to(
reg.get_type(),
ast::StateSpace::Reg,
spirv::StorageClass::Input,
),
SpirvType::pointer_to(reg.get_type(), spirv::StorageClass::Input),
);
builder.variable(result_type, Some(id), spirv::StorageClass::Input, None);
builder.decorate(
@ -1144,10 +1129,7 @@ fn emit_function_header<'a>(
}
*/
for input in &func_decl.input_arguments {
let result_type = map.get_or_add(
builder,
SpirvType::new(input.v_type.clone(), input.state_space),
);
let result_type = map.get_or_add(builder, SpirvType::new(input.v_type.clone()));
builder.function_parameter(Some(input.name), result_type)?;
}
Ok(fn_id)
@ -1753,8 +1735,8 @@ fn to_ptx_impl_atomic_call(
input_arguments: vec![
ast::Variable {
align: None,
v_type: ast::Type::Pointer(typ),
state_space: ptr_space,
v_type: ast::Type::Pointer(typ, ptr_space),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
@ -1791,7 +1773,11 @@ fn to_ptx_impl_atomic_call(
func: fn_id,
ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
param_list: vec![
(arg.src1, ast::Type::Pointer(typ), ptr_space),
(
arg.src1,
ast::Type::Pointer(typ, ptr_space),
ast::StateSpace::Reg,
),
(
arg.src2,
ast::Type::Scalar(scalar_typ),
@ -2629,8 +2615,8 @@ fn insert_implicit_conversions(
is_dst: false,
sema: ArgumentSemantics::PhysicalPointer,
},
typ: &ast::Type::Pointer(underlying_type),
state_space,
typ: &ast::Type::Pointer(underlying_type, state_space),
state_space: new_todo!(),
stmt_ctor: |new_ptr_src| {
Statement::PtrAccess(PtrAccess {
underlying_type,
@ -2758,10 +2744,10 @@ fn get_function_type(
builder,
spirv_input
.iter()
.map(|var| SpirvType::new(var.v_type.clone(), var.state_space)),
.map(|var| SpirvType::new(var.v_type.clone())),
spirv_output
.iter()
.map(|var| SpirvType::new(var.v_type.clone(), var.state_space)),
.map(|var| SpirvType::new(var.v_type.clone())),
)
}
@ -2790,7 +2776,7 @@ fn emit_function_body_ops(
Statement::Call(call) => {
let (result_type, result_id) = match &*call.ret_params {
[(id, typ, space)] => (
map.get_or_add(builder, SpirvType::new(typ.clone(), *space)),
map.get_or_add(builder, SpirvType::new(typ.clone())),
Some(*id),
),
[] => (map.void(), None),
@ -2922,10 +2908,8 @@ fn emit_function_body_ops(
if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
let result_type = map.get_or_add(
builder,
SpirvType::new(ast::Type::from(data.typ.clone()), data.state_space),
);
let result_type =
map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone())));
builder.load(
result_type,
Some(arg.dst),
@ -2956,10 +2940,8 @@ fn emit_function_body_ops(
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
ast::Instruction::Mov(d, arg) => {
let result_type = map.get_or_add(
builder,
SpirvType::new(ast::Type::from(d.typ.clone()), ast::StateSpace::Reg),
);
let result_type =
map.get_or_add(builder, SpirvType::new(ast::Type::from(d.typ.clone())));
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::Mul(mul, arg) => match mul {
@ -3000,8 +2982,7 @@ fn emit_function_body_ops(
ast::Instruction::Shl(t, a) => {
let full_type = ast::Type::Scalar(*t);
let size_of = full_type.size_of();
let result_type =
map.get_or_add(builder, SpirvType::new(full_type, ast::StateSpace::Reg));
let result_type = map.get_or_add(builder, SpirvType::new(full_type));
let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?;
builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?;
}
@ -3265,7 +3246,6 @@ fn emit_function_body_ops(
builder,
SpirvType::pointer_to(
details.typ.clone(),
details.state_space,
spirv::StorageClass::Function,
),
);
@ -3297,11 +3277,11 @@ fn emit_function_body_ops(
}) => {
let u8_pointer = map.get_or_add(
builder,
SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8), *state_space),
SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)),
);
let result_type = map.get_or_add(
builder,
SpirvType::new(ast::Type::Pointer(*underlying_type), *state_space),
SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)),
);
let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
let temp = builder.in_bounds_ptr_access_chain(
@ -3596,15 +3576,12 @@ fn emit_variable(
&*var.array_init,
)?)
} else if must_init {
let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone(), var.state_space));
let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone()));
Some(builder.constant_null(type_id, None))
} else {
None
};
let ptr_type_id = map.get_or_add(
builder,
SpirvType::pointer_to(var.v_type.clone(), var.state_space, st_class),
);
let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class));
builder.variable(ptr_type_id, Some(var.name), st_class, initalizer);
if let Some(align) = var.align {
builder.decorate(
@ -3742,10 +3719,7 @@ fn emit_min(
ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
};
let inst_type = map.get_or_add(
builder,
SpirvType::new(desc.get_type(), ast::StateSpace::Reg),
);
let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type()));
builder.ext_inst(
inst_type,
Some(arg.dst),
@ -3770,10 +3744,7 @@ fn emit_max(
ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
};
let inst_type = map.get_or_add(
builder,
SpirvType::new(desc.get_type(), ast::StateSpace::Reg),
);
let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type()));
builder.ext_inst(
inst_type,
Some(arg.dst),
@ -4255,14 +4226,13 @@ fn emit_implicit_conversion(
(_, _, ConversionKind::BitToPtr) => {
let dst_type = map.get_or_add(
builder,
SpirvType::pointer_to(cv.to_type.clone(), cv.from_space, cv.to_space.to_spirv()),
SpirvType::pointer_to(cv.to_type.clone(), cv.to_space.to_spirv()),
);
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
if from_parts.width == to_parts.width {
let dst_type =
map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space));
let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
if from_parts.scalar_kind != ast::ScalarKind::Float
&& to_parts.scalar_kind != ast::ScalarKind::Float
{
@ -4275,13 +4245,10 @@ fn emit_implicit_conversion(
// This block is safe because it's illegal to implictly convert between floating point values
let same_width_bit_type = map.get_or_add(
builder,
SpirvType::new(
ast::Type::from_parts(TypeParts {
scalar_kind: ast::ScalarKind::Bit,
..from_parts
}),
cv.from_space,
),
SpirvType::new(ast::Type::from_parts(TypeParts {
scalar_kind: ast::ScalarKind::Bit,
..from_parts
})),
);
let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
let wide_bit_type = ast::Type::from_parts(TypeParts {
@ -4289,7 +4256,7 @@ fn emit_implicit_conversion(
..to_parts
});
let wide_bit_type_spirv =
map.get_or_add(builder, SpirvType::new(wide_bit_type.clone(), cv.to_space));
map.get_or_add(builder, SpirvType::new(wide_bit_type.clone()));
if to_parts.scalar_kind == ast::ScalarKind::Unsigned
|| to_parts.scalar_kind == ast::ScalarKind::Bit
{
@ -4323,15 +4290,13 @@ fn emit_implicit_conversion(
}
}
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => {
let result_type =
map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space));
let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
builder.s_convert(result_type, Some(cv.dst), cv.src)?;
}
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
| (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default)
| (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
let into_type =
map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space));
let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
}
(_, _, ConversionKind::PtrToPtr { spirv_ptr }) => {
@ -4339,12 +4304,12 @@ fn emit_implicit_conversion(
map.get_or_add(
builder,
SpirvType::Pointer(
Box::new(SpirvType::new(cv.to_type.clone(), cv.to_space)),
Box::new(SpirvType::new(cv.to_type.clone())),
spirv::StorageClass::Function,
),
)
} else {
map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space))
map.get_or_add(builder, SpirvType::new(cv.to_type.clone()))
};
builder.bitcast(result_type, Some(cv.dst), cv.src)?;
}
@ -4358,18 +4323,14 @@ fn emit_load_var(
map: &mut TypeWordMap,
details: &LoadVarDetails,
) -> Result<(), TranslateError> {
let result_type = map.get_or_add(
builder,
SpirvType::new(details.typ.clone(), details.state_space),
);
let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone()));
match details.member_index {
Some((index, Some(width))) => {
let vector_type = match details.typ {
ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
_ => return Err(TranslateError::MismatchedType),
};
let vector_type_spirv =
map.get_or_add(builder, SpirvType::new(vector_type, details.state_space));
let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
let vector_temp = builder.load(
vector_type_spirv,
None,
@ -4387,11 +4348,7 @@ fn emit_load_var(
Some((index, None)) => {
let result_ptr_type = map.get_or_add(
builder,
SpirvType::pointer_to(
details.typ.clone(),
details.state_space,
spirv::StorageClass::Function,
),
SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function),
);
let index_spirv = map.get_or_add_constant(
builder,
@ -5661,7 +5618,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
ast::StateSpace::Reg => new_todo!(),
ast::StateSpace::Sreg => new_todo!(),
};
let ptr_type = ast::Type::Pointer(self.underlying_type.clone());
let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), new_todo!());
let new_dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
@ -6231,24 +6188,28 @@ impl ast::Type {
match self {
ast::Type::Scalar(scalar) => TypeParts {
kind: TypeKind::Scalar,
state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
},
ast::Type::Vector(scalar, components) => TypeParts {
kind: TypeKind::Vector,
state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: vec![*components as u32],
},
ast::Type::Array(scalar, components) => TypeParts {
kind: TypeKind::Array,
state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: components.clone(),
},
ast::Type::Pointer(scalar) => TypeParts {
kind: TypeKind::PointerScalar,
ast::Type::Pointer(scalar, space) => TypeParts {
kind: TypeKind::Pointer,
state_space: *space,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
@ -6269,9 +6230,10 @@ impl ast::Type {
ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components,
),
TypeKind::PointerScalar => {
ast::Type::Pointer(ast::ScalarType::from_parts(t.width, t.scalar_kind))
}
TypeKind::Pointer => ast::Type::Pointer(
ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.state_space,
),
}
}
@ -6292,6 +6254,7 @@ struct TypeParts {
kind: TypeKind,
scalar_kind: ast::ScalarKind,
width: u8,
state_space: ast::StateSpace,
components: Vec<u32>,
}
@ -6300,7 +6263,7 @@ enum TypeKind {
Scalar,
Vector,
Array,
PointerScalar,
Pointer,
}
impl ast::Instruction<ExpandedArgParams> {