Browse Source

chore(scripts): clarify and fix default value of bit combined enum

GitOrigin-RevId: 3716bf9bb5
release-1.3
Megvii Engine Team 4 years ago
parent
commit
d275a82318
7 changed files with 177 additions and 64 deletions
  1. +14
    -11
      dnn/scripts/gen_flatbuffers_schema.py
  2. +106
    -30
      dnn/scripts/gen_param_defs.py
  3. +13
    -2
      dnn/scripts/gen_tablegen.py
  4. +29
    -12
      imperative/python/src/ops.cpp
  5. +6
    -5
      imperative/tablegen/targets/python_c_extension.cpp
  6. +3
    -3
      src/opr/impl/dnn/dnn.oprdecl
  7. +6
    -1
      tools/param_defs/mgb_opr_param_defs.py

+ 14
- 11
dnn/scripts/gen_flatbuffers_schema.py View File

@@ -52,14 +52,11 @@ class FlatBuffersWriter(IndentWriterBase):
name = p + e name = p + e
e = self._enums[(p, e)] e = self._enums[(p, e)]
self._write_doc(e.name) self._write_doc(e.name)
self._write("enum %s%s : uint {", p, e.name, indent=1)
attribute = "(bit_flags)" if e.combined else ""
self._write("enum %s%s : uint %s {", p, e.name, attribute, indent=1)
for idx, member in enumerate(e.members): for idx, member in enumerate(e.members):
self._write_doc(member) self._write_doc(member)
if e.combined:
self._write("%s=%d,", scramble_enum_member_name(str(member)),
1<<idx)
else:
self._write("%s,", scramble_enum_member_name(str(member)))
self._write("%s,", scramble_enum_member_name(str(member)))
self._write("}\n", indent=-1) self._write("}\n", indent=-1)


def _write_doc(self, doc): def _write_doc(self, doc):
@@ -97,8 +94,11 @@ class FlatBuffersWriter(IndentWriterBase):
return return
self._write_doc(e.name) self._write_doc(e.name)
self._used_enum.add(key) self._used_enum.add(key)
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name,
scramble_enum_member_name(str(e.members[e.default])))
if e.combined:
default = e.compose_combined_enum(e.default)
else:
default = scramble_enum_member_name(str(e.members[e.default]))
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default)


def _resolve_const(self, v): def _resolve_const(self, v):
while v in self._cur_const_val: while v in self._cur_const_val:
@@ -120,9 +120,12 @@ class FlatBuffersWriter(IndentWriterBase):
return return
self._used_enum.add((e.src_class, e.src_name)) self._used_enum.add((e.src_class, e.src_name))
enum_name = e.src_class + e.src_name enum_name = e.src_class + e.src_name
self._write(
"%s:%s = %s;", e.name_field, enum_name,
scramble_enum_member_name(str(e.src_enum.members[e.get_default()])))
s = e.src_enum
if s.combined:
default = s.compose_combined_enum(e.get_default())
else:
default = scramble_enum_member_name(str(s.members[e.get_default()]))
self._write("%s:%s = %s;", e.name_field, enum_name, default)


def _get_fb_default(self, cppdefault): def _get_fb_default(self, cppdefault):
if not isinstance(cppdefault, str): if not isinstance(cppdefault, str):


+ 106
- 30
dnn/scripts/gen_param_defs.py View File

@@ -73,11 +73,21 @@ class member_defs:
"""define an enum; the result would contain both an enum class def and its """define an enum; the result would contain both an enum class def and its
corresponding data field corresponding data field


:param default: index of default member value
:param default:
for normal enum class: index of default member value
for bit combined class: tuple of index of default member value

For example, following representations of the default value for bit
combined class are all equivalent:
Enum(members=('a', 'b', 'c'), default=('a', 'b'), ...)
Enum(members=('a', 'b', 'c'), default=(0, 1), ...)
Enum(members=('a', 'b', 'c'), default=(1 << 0) | (1 << 1), ...)


:attr name_field: name of the data field of this enum in the param :attr name_field: name of the data field of this enum in the param
struct struct
:attr member_alias: list of (member, alias) pairs
:attr member_alias:
for normal enum class: list of (member, alias) pairs
for bit combined class: list of (tuple of members, alias) paris
""" """
__slots__ = ['name', 'name_field', 'members', 'default', __slots__ = ['name', 'name_field', 'members', 'default',
'member_alias', 'combined'] 'member_alias', 'combined']
@@ -90,17 +100,11 @@ class member_defs:
name = member_defs.Doc.make(name) name = member_defs.Doc.make(name)
assert name.id[0].isupper() assert name.id[0].isupper()
members = tuple(map(member_defs.Doc.make, members)) members = tuple(map(member_defs.Doc.make, members))
if isinstance(default, str):
if default not in name_field:
raise ValueError(
"Default value '{}' does not exist.".format(default))
default = name_field.index(default)
assert isinstance(default, int)
self.name = name self.name = name
self.combined = combined self.combined = combined
self.name_field = self.get_name_field(name.id, name_field) self.name_field = self.get_name_field(name.id, name_field)
self.members = members self.members = members
self.default = default
self.default = self.normalize_enum_value(default)


self.all_enums[(param_name, name.id)] = self self.all_enums[(param_name, name.id)] = self


@@ -114,6 +118,43 @@ class member_defs:
assert isinstance(name_field, str) assert isinstance(name_field, str)
return name_field return name_field


def normalize_enum_value(self, value):
def normalize(v):
if isinstance(v, str):
if v not in self.members:
raise ValueError(
"enum member '{}' does not exist.".format(v))
v = self.members.index(v)
assert isinstance(v, int)
return v
if self.combined:
if isinstance(value, int):
value = self.decompose_combined_enum(value)
assert isinstance(value, tuple)
value = tuple(normalize(i) for i in value)
return value
else:
return normalize(value)

@staticmethod
def decompose_combined_enum(v):
"""Integer => tuple of the indexes of the enum members"""
assert isinstance(v, int)
idx = 0
members = []
while v > 0:
if v & 1:
members.append(idx)
idx += 1
v >>= 1
return tuple(members)

def compose_combined_enum(self, v):
"""tuple of members => Integer"""
assert self.combined and isinstance(v, tuple)
norm_v = self.normalize_enum_value(v)
return sum(1 << i for i in norm_v)

class Field(Base): class Field(Base):
"""define a normal data field""" """define a normal data field"""
__slots__ = ['name', 'dtype', 'default'] __slots__ = ['name', 'dtype', 'default']
@@ -146,6 +187,10 @@ class member_defs:
src_name = name src_name = name
self.src_name = src_name self.src_name = src_name
self.default = default self.default = default
# TODO: remove this assertion if needed; adding mock param_defs in
# current testing framework is too complicated, and currently we
# only allow aliasing of normal enum
assert not self.src_enum.combined


@property @property
def src_enum(self): def src_enum(self):
@@ -157,7 +202,7 @@ class member_defs:
set""" set"""
if self.default is None: if self.default is None:
return self.src_enum.default return self.src_enum.default
return self.default
return self.src_enum.normalize_enum_value(self.default)




class ParamDef: class ParamDef:
@@ -198,7 +243,7 @@ class ParamDef:
self.name.id, name, name_field, members, default, member_alias)) self.name.id, name, name_field, members, default, member_alias))
return self return self


def add_bit_combination_enum(self, name, *members, default=0,
def add_bit_combination_enum(self, name, *members, default=tuple(),
name_field=None, member_alias=[]): name_field=None, member_alias=[]):
self.members.append(member_defs.Enum( self.members.append(member_defs.Enum(
self.name.id, name, name_field, members, default, member_alias, True)) self.name.id, name, name_field, members, default, member_alias, True))
@@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase):
' for idx, v in enumerate(pdata):\n' ' for idx, v in enumerate(pdata):\n'
' if isinstance(v, _EnumBase):\n' ' if isinstance(v, _EnumBase):\n'
' pdata[idx] = _enum_member2num[id(v)]\n' ' pdata[idx] = _enum_member2num[id(v)]\n'
' elif isinstance(v, _BitCombinedEnumBase):\n'
' pdata[idx] = v._value_\n'
' return tag + self._packer.pack(*pdata)\n' ' return tag + self._packer.pack(*pdata)\n'
'\n' '\n'
) )
self._write(
'class _EnumBase(enum.Enum):\n'
# it's hard to mix custom implemention into enum, just do copy-paste instead
classbody = (
' @classmethod\n' ' @classmethod\n'
' def __normalize(cls, val):\n' ' def __normalize(cls, val):\n'
' if isinstance(val, str):\n' ' if isinstance(val, str):\n'
@@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase):
' return super()._missing_(value)\n' ' return super()._missing_(value)\n'
'\n' '\n'
) )
self._write(
'class _EnumBase(enum.Enum):\n' + classbody
)
self._write(
'class _BitCombinedEnumBase(enum.Flag):\n' + classbody
)
if not self._imperative: if not self._imperative:
self._write( self._write(
'def _as_dtype_num(dtype):\n' 'def _as_dtype_num(dtype):\n'
@@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase):
def _on_member_enum(self, e): def _on_member_enum(self, e):
qualname = '{}.{}'.format(self._cur_param_name, e.name) qualname = '{}.{}'.format(self._cur_param_name, e.name)


self._write('class %s(_EnumBase):', e.name, indent=1)
if e.combined:
self._write('class %s(_BitCombinedEnumBase):', e.name, indent=1)
else:
self._write('class %s(_EnumBase):', e.name, indent=1)

self._write_doc(e.name) self._write_doc(e.name)


for idx, emem in enumerate(e.members): for idx, emem in enumerate(e.members):
self._write('%s = "%s"', emem, emem)
self._write_doc(emem)
if e.combined: if e.combined:
self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, 1<<idx))
self._write('%s = 1 << %d', emem, idx)
self._write_doc(emem)
else: else:
self._write('%s = "%s"', emem, emem)
self._write_doc(emem)
self._enum_member2num.append('id({}.{}):{}'.format( self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, idx)) qualname, emem, idx))


for emem, emem_alis in e.member_alias:
self._write('%s = %s', emem_alis, emem)
for emem, emem_alias in e.member_alias:
if e.combined:
self._write('%s = %s', emem_alias, e.compose_combined_enum(emem))
else:
self._write('%s = %s', emem_alias, emem)


self._unindent() self._unindent()
self._write('') self._write('')


if e.combined:
default = e.compose_combined_enum(e.default)
else:
default = "'{}'".format(e.members[e.default])

self._cur_fields.append(self.FieldDef( self._cur_fields.append(self.FieldDef(
name=e.name_field, name=e.name_field,
cvt='{}.convert({})'.format(qualname, e.name_field), cvt='{}.convert({})'.format(qualname, e.name_field),
fmt='I', fmt='I',
default="'{}'".format(e.members[e.default]),
default=default,
type=qualname, type=qualname,
doc=None)) doc=None))


@@ -495,11 +560,16 @@ class SerializedDType(_ParamDefBase):
self._write('%s = %s.%s', e.name, e.src_class, e.src_name) self._write('%s = %s.%s', e.name, e.src_class, e.src_name)
s = e.src_enum s = e.src_enum
qualname = '{}.{}'.format(e.src_class, e.src_name) qualname = '{}.{}'.format(e.src_class, e.src_name)

if s.combined:
default = s.compose_combined_enum(e.get_default())
else:
default = "'{}'".format(s.members[e.get_default()])
self._cur_fields.append(self.FieldDef( self._cur_fields.append(self.FieldDef(
name=e.name_field, name=e.name_field,
cvt='{}.convert({})'.format(qualname, e.name_field), cvt='{}.convert({})'.format(qualname, e.name_field),
fmt='I', fmt='I',
default="'{}'".format(s.members[e.get_default()]),
default=default,
type=qualname, type=qualname,
doc=None)) doc=None))


@@ -639,14 +709,19 @@ class CPPWriter(IndentWriterBase):
v += ',' v += ','
self._write(v) self._write(v)
for mem, alias in e.member_alias: for mem, alias in e.member_alias:
self._write('%s = %s,', alias, mem)
if e.combined:
self._write('%s = %s,', alias, e.compose_combined_enum(mem))
else:
self._write('%s = %s,', alias, mem)
self._write('};', indent=-1) self._write('};', indent=-1)
self._non_static_members.append(e) self._non_static_members.append(e)
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
str(e.name).upper(), len(e.members)) str(e.name).upper(), len(e.members))
self._add_ctor_args(e.name,
'{}::{}'.format(e.name, e.members[e.default]),
e.name_field)
if e.combined:
default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default))
else:
default = '{}::{}'.format(e.name, e.members[e.default])
self._add_ctor_args(e.name, default, e.name_field)


def _on_member_enum_alias(self, e): def _on_member_enum_alias(self, e):
s = e.src_enum s = e.src_enum
@@ -654,10 +729,11 @@ class CPPWriter(IndentWriterBase):
self._non_static_members.append(e) self._non_static_members.append(e)
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
str(e.name).upper(), len(s.members)) str(e.name).upper(), len(s.members))
self._add_ctor_args(e.name,
'{}::{}'.format(e.name,
s.members[e.get_default()]),
e.name_field)
if s.combined:
default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default))
else:
default = '{}::{}'.format(e.name, s.members[e.get_default()])
self._add_ctor_args(e.name, default, e.name_field)


def _on_member_field(self, f): def _on_member_field(self, f):
self._non_static_members.append(f) self._non_static_members.append(f)


+ 13
- 2
dnn/scripts/gen_tablegen.py View File

@@ -106,7 +106,12 @@ class ConverterWriter(IndentWriterBase):
return return


# wrapped with default value # wrapped with default value
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default)
if e.combined:
default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, e.compose_combined_enum(e.default))
else:
default_val = "{}::{}::{}".format(fullname, e.name, e.members[e.default])

wrapped = self._wrapped_with_default_value(td_class, default_val) wrapped = self._wrapped_with_default_value(td_class, default_val)


self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) self._current_tparams.append("{}:${}".format(wrapped, e.name_field))
@@ -124,7 +129,13 @@ class ConverterWriter(IndentWriterBase):
self._write("def {} : {};".format(td_class, enum_def)) self._write("def {} : {};".format(td_class, enum_def))


# wrapped with default value # wrapped with default value
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default())
s = e.src_enum
if s.combined:
default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, s.compose_combined_enum(e.get_default()))
else:
default_val = "{}::{}::{}".format(fullname, e.name, s.members[e.get_default()])

wrapped = self._wrapped_with_default_value(td_class, default_val) wrapped = self._wrapped_with_default_value(td_class, default_val)


self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) self._current_tparams.append("{}:${}".format(wrapped, e.name_field))


+ 29
- 12
imperative/python/src/ops.cpp View File

@@ -87,9 +87,13 @@ struct pyobj_convert_generic {
} }
}; };


template<typename T, typename SFINAE=void>
struct EnumTrait;

template <typename T> template <typename T>
struct EnumTrait {
struct EnumTrait<T, std::enable_if_t<std::is_enum_v<T>>> {
static constexpr bool is_bit_combined = false; static constexpr bool is_bit_combined = false;
static constexpr std::underlying_type_t<T> max = 0;
}; };


template <typename T> template <typename T>
@@ -264,18 +268,25 @@ struct BitCombinedEnumWrapper {
return ret; return ret;
} }
} }
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) {
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1);
return obj;
}
static int py_init(PyObject* self, PyObject* args, PyObject*) {
int input = 1;
if (PyArg_ParseTuple(args, "|i", &input)){
reinterpret_cast<BitCombinedEnumWrapper*>(self)->value =
static_cast<T>(input);
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) {
if (!PyTuple_Size(args)) {
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
return obj;
}
else {
PyObject* input;
if (!PyArg_ParseTuple(args, "|O", &input)) {
return nullptr;
}
T value;
try {
value = pyobj_convert_generic<T>::from(input);
} CATCH_ALL(nullptr);
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
return obj;
} }
return 0;
} }
static PyObject* py_repr(PyObject* self) { static PyObject* py_repr(PyObject* self) {
return pyobj_convert_generic<std::string>::to( return pyobj_convert_generic<std::string>::to(
@@ -325,6 +336,12 @@ struct pyobj_convert_generic<T,
static T from(PyObject* obj) { static T from(PyObject* obj) {
if (PyObject_TypeCheck(obj, &Wrapper::type)) { if (PyObject_TypeCheck(obj, &Wrapper::type)) {
return reinterpret_cast<Wrapper*>(obj)->value; return reinterpret_cast<Wrapper*>(obj)->value;
} else if(PyLong_Check(obj)) {
auto value = pyobj_convert_generic<std::underlying_type_t<T>>::from(obj);
mgb_throw_if(value > EnumTrait<T>::max, mgb::MegBrainError,
"out of range, cannot convert %zu to %s",
static_cast<uint32_t>(value), Wrapper::name);
return static_cast<T>(value);
} }
// try as string // try as string
// TODO: type checkcd // TODO: type checkcd


+ 6
- 5
imperative/tablegen/targets/python_c_extension.cpp View File

@@ -90,10 +90,12 @@ void EnumAttrEmitter::emit_tpl_spl() {
"template<> PyNumberMethods " "template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods={};\n", "$enumTpl<$opClass::$enumClass>::number_methods={};\n",
&ctx); &ctx);
os << tgfmt(
"template<> struct EnumTrait<$opClass::$enumClass> { static constexpr "
"bool is_bit_combined = true;};\n",
&ctx);
os << tgfmt(R"(
template<> struct EnumTrait<$opClass::$enumClass> {
static constexpr bool is_bit_combined = true;
static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1;
};
)", &ctx, attr->getEnumMembers().size());
} }


auto str2type = [&](auto&& i) -> std::string { auto str2type = [&](auto&& i) -> std::string {
@@ -138,7 +140,6 @@ void $0(PyTypeObject& py_type) {
// others should always use singleton // others should always use singleton
os << tgfmt(R"( os << tgfmt(R"(
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum;
e_type.tp_init = $enumTpl<$opClass::$enumClass>::py_init;
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods;
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or;
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and; number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and;


+ 3
- 3
src/opr/impl/dnn/dnn.oprdecl View File

@@ -6,7 +6,7 @@ decl_opr('Convolution',
'convolution kernel in ' 'convolution kernel in '
'(out channel, in channel, kern row, kern col) format')], '(out channel, in channel, kern row, kern col) format')],
params=[('param', 'ConvolutionV0'), params=[('param', 'ConvolutionV0'),
('execution_polity', 'ExecutionPolicy')],
('execution_polity', 'ExecutionPolicyV0')],
desc='batched convolution on channeled 2D images') desc='batched convolution on channeled 2D images')


decl_opr('Convolution', decl_opr('Convolution',
@@ -28,7 +28,7 @@ decl_opr('ConvolutionBackwardData',
'convolution kernel in ' 'convolution kernel in '
'(out channel, in channel, kern row, kern col) format')], '(out channel, in channel, kern row, kern col) format')],
params=[('param', 'ConvolutionV0'), params=[('param', 'ConvolutionV0'),
('execution_polity', 'ExecutionPolicy')],
('execution_polity', 'ExecutionPolicyV0')],
body=[ body=[
'a, b = all_inputs', 'a, b = all_inputs',
'all_inputs = [b, a]' 'all_inputs = [b, a]'
@@ -201,7 +201,7 @@ decl_opr('ConvBiasForward',
Doc('bias', 'bias'), Doc('bias', 'bias'),
], ],
params=[('param', 'ConvBiasV1'), params=[('param', 'ConvBiasV1'),
('execution_policy', 'ExecutionPolicy')],
('execution_policy', 'ExecutionPolicyV0')],
desc=('activation(convolution(src, filter) + bias) with specified ' desc=('activation(convolution(src, filter) + bias) with specified '
'dtype'), 'dtype'),
has_out_dtype=True) has_out_dtype=True)


+ 6
- 1
tools/param_defs/mgb_opr_param_defs.py View File

@@ -42,7 +42,12 @@ pdef('PersistentOutputStorage').add_fields(
'when profile or heuristic algo selection it require the algos' 'when profile or heuristic algo selection it require the algos'
'must be reproducible'), 'must be reproducible'),
Doc('OPTMIZED', Doc('OPTMIZED',
'profile require algos are optmized to achieve fast-profile')).
'profile require algos are optmized to achieve fast-profile'),
default=('HEURISTIC',),
member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'),
(('PROFILE', 'REPRODUCIBLE'), 'PROFILE_REPRODUCIBLE'),
(('PROFILE', 'HEURISTIC'), 'PROFILE_HEURISTIC'),
]).
add_fields('uint64', add_fields('uint64',
Doc('workspace_limit', 'workspace limit in bytes'), Doc('workspace_limit', 'workspace limit in bytes'),
str(2**64-1)+'ull')) str(2**64-1)+'ull'))


Loading…
Cancel
Save