GitOrigin-RevId: 3716bf9bb5
release-1.3
@@ -52,14 +52,11 @@ class FlatBuffersWriter(IndentWriterBase): | |||
name = p + e | |||
e = self._enums[(p, e)] | |||
self._write_doc(e.name) | |||
self._write("enum %s%s : uint {", p, e.name, indent=1) | |||
attribute = "(bit_flags)" if e.combined else "" | |||
self._write("enum %s%s : uint %s {", p, e.name, attribute, indent=1) | |||
for idx, member in enumerate(e.members): | |||
self._write_doc(member) | |||
if e.combined: | |||
self._write("%s=%d,", scramble_enum_member_name(str(member)), | |||
1<<idx) | |||
else: | |||
self._write("%s,", scramble_enum_member_name(str(member))) | |||
self._write("%s,", scramble_enum_member_name(str(member))) | |||
self._write("}\n", indent=-1) | |||
def _write_doc(self, doc): | |||
@@ -97,8 +94,11 @@ class FlatBuffersWriter(IndentWriterBase): | |||
return | |||
self._write_doc(e.name) | |||
self._used_enum.add(key) | |||
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, | |||
scramble_enum_member_name(str(e.members[e.default]))) | |||
if e.combined: | |||
default = e.compose_combined_enum(e.default) | |||
else: | |||
default = scramble_enum_member_name(str(e.members[e.default])) | |||
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) | |||
def _resolve_const(self, v): | |||
while v in self._cur_const_val: | |||
@@ -120,9 +120,12 @@ class FlatBuffersWriter(IndentWriterBase): | |||
return | |||
self._used_enum.add((e.src_class, e.src_name)) | |||
enum_name = e.src_class + e.src_name | |||
self._write( | |||
"%s:%s = %s;", e.name_field, enum_name, | |||
scramble_enum_member_name(str(e.src_enum.members[e.get_default()]))) | |||
s = e.src_enum | |||
if s.combined: | |||
default = s.compose_combined_enum(e.get_default()) | |||
else: | |||
default = scramble_enum_member_name(str(s.members[e.get_default()])) | |||
self._write("%s:%s = %s;", e.name_field, enum_name, default) | |||
def _get_fb_default(self, cppdefault): | |||
if not isinstance(cppdefault, str): | |||
@@ -73,11 +73,21 @@ class member_defs: | |||
"""define an enum; the result would contain both an enum class def and its | |||
corresponding data field | |||
:param default: index of default member value | |||
:param default: | |||
for normal enum class: index of default member value | |||
for bit combined class: tuple of index of default member value | |||
For example, following representations of the default value for bit | |||
combined class are all equivalent: | |||
Enum(members=('a', 'b', 'c'), default=('a', 'b'), ...) | |||
Enum(members=('a', 'b', 'c'), default=(0, 1), ...) | |||
Enum(members=('a', 'b', 'c'), default=(1 << 0) | (1 << 1), ...) | |||
:attr name_field: name of the data field of this enum in the param | |||
struct | |||
:attr member_alias: list of (member, alias) pairs | |||
:attr member_alias: | |||
for normal enum class: list of (member, alias) pairs | |||
for bit combined class: list of (tuple of members, alias) paris | |||
""" | |||
__slots__ = ['name', 'name_field', 'members', 'default', | |||
'member_alias', 'combined'] | |||
@@ -90,17 +100,11 @@ class member_defs: | |||
name = member_defs.Doc.make(name) | |||
assert name.id[0].isupper() | |||
members = tuple(map(member_defs.Doc.make, members)) | |||
if isinstance(default, str): | |||
if default not in name_field: | |||
raise ValueError( | |||
"Default value '{}' does not exist.".format(default)) | |||
default = name_field.index(default) | |||
assert isinstance(default, int) | |||
self.name = name | |||
self.combined = combined | |||
self.name_field = self.get_name_field(name.id, name_field) | |||
self.members = members | |||
self.default = default | |||
self.default = self.normalize_enum_value(default) | |||
self.all_enums[(param_name, name.id)] = self | |||
@@ -114,6 +118,43 @@ class member_defs: | |||
assert isinstance(name_field, str) | |||
return name_field | |||
def normalize_enum_value(self, value): | |||
def normalize(v): | |||
if isinstance(v, str): | |||
if v not in self.members: | |||
raise ValueError( | |||
"enum member '{}' does not exist.".format(v)) | |||
v = self.members.index(v) | |||
assert isinstance(v, int) | |||
return v | |||
if self.combined: | |||
if isinstance(value, int): | |||
value = self.decompose_combined_enum(value) | |||
assert isinstance(value, tuple) | |||
value = tuple(normalize(i) for i in value) | |||
return value | |||
else: | |||
return normalize(value) | |||
@staticmethod | |||
def decompose_combined_enum(v): | |||
"""Integer => tuple of the indexes of the enum members""" | |||
assert isinstance(v, int) | |||
idx = 0 | |||
members = [] | |||
while v > 0: | |||
if v & 1: | |||
members.append(idx) | |||
idx += 1 | |||
v >>= 1 | |||
return tuple(members) | |||
def compose_combined_enum(self, v): | |||
"""tuple of members => Integer""" | |||
assert self.combined and isinstance(v, tuple) | |||
norm_v = self.normalize_enum_value(v) | |||
return sum(1 << i for i in norm_v) | |||
class Field(Base): | |||
"""define a normal data field""" | |||
__slots__ = ['name', 'dtype', 'default'] | |||
@@ -146,6 +187,10 @@ class member_defs: | |||
src_name = name | |||
self.src_name = src_name | |||
self.default = default | |||
# TODO: remove this assertion if needed; adding mock param_defs in | |||
# current testing framework is too complicated, and currently we | |||
# only allow aliasing of normal enum | |||
assert not self.src_enum.combined | |||
@property | |||
def src_enum(self): | |||
@@ -157,7 +202,7 @@ class member_defs: | |||
set""" | |||
if self.default is None: | |||
return self.src_enum.default | |||
return self.default | |||
return self.src_enum.normalize_enum_value(self.default) | |||
class ParamDef: | |||
@@ -198,7 +243,7 @@ class ParamDef: | |||
self.name.id, name, name_field, members, default, member_alias)) | |||
return self | |||
def add_bit_combination_enum(self, name, *members, default=0, | |||
def add_bit_combination_enum(self, name, *members, default=tuple(), | |||
name_field=None, member_alias=[]): | |||
self.members.append(member_defs.Enum( | |||
self.name.id, name, name_field, members, default, member_alias, True)) | |||
@@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase): | |||
' for idx, v in enumerate(pdata):\n' | |||
' if isinstance(v, _EnumBase):\n' | |||
' pdata[idx] = _enum_member2num[id(v)]\n' | |||
' elif isinstance(v, _BitCombinedEnumBase):\n' | |||
' pdata[idx] = v._value_\n' | |||
' return tag + self._packer.pack(*pdata)\n' | |||
'\n' | |||
) | |||
self._write( | |||
'class _EnumBase(enum.Enum):\n' | |||
# it's hard to mix custom implemention into enum, just do copy-paste instead | |||
classbody = ( | |||
' @classmethod\n' | |||
' def __normalize(cls, val):\n' | |||
' if isinstance(val, str):\n' | |||
@@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase): | |||
' return super()._missing_(value)\n' | |||
'\n' | |||
) | |||
self._write( | |||
'class _EnumBase(enum.Enum):\n' + classbody | |||
) | |||
self._write( | |||
'class _BitCombinedEnumBase(enum.Flag):\n' + classbody | |||
) | |||
if not self._imperative: | |||
self._write( | |||
'def _as_dtype_num(dtype):\n' | |||
@@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase): | |||
def _on_member_enum(self, e): | |||
qualname = '{}.{}'.format(self._cur_param_name, e.name) | |||
self._write('class %s(_EnumBase):', e.name, indent=1) | |||
if e.combined: | |||
self._write('class %s(_BitCombinedEnumBase):', e.name, indent=1) | |||
else: | |||
self._write('class %s(_EnumBase):', e.name, indent=1) | |||
self._write_doc(e.name) | |||
for idx, emem in enumerate(e.members): | |||
self._write('%s = "%s"', emem, emem) | |||
self._write_doc(emem) | |||
if e.combined: | |||
self._enum_member2num.append('id({}.{}):{}'.format( | |||
qualname, emem, 1<<idx)) | |||
self._write('%s = 1 << %d', emem, idx) | |||
self._write_doc(emem) | |||
else: | |||
self._write('%s = "%s"', emem, emem) | |||
self._write_doc(emem) | |||
self._enum_member2num.append('id({}.{}):{}'.format( | |||
qualname, emem, idx)) | |||
for emem, emem_alis in e.member_alias: | |||
self._write('%s = %s', emem_alis, emem) | |||
for emem, emem_alias in e.member_alias: | |||
if e.combined: | |||
self._write('%s = %s', emem_alias, e.compose_combined_enum(emem)) | |||
else: | |||
self._write('%s = %s', emem_alias, emem) | |||
self._unindent() | |||
self._write('') | |||
if e.combined: | |||
default = e.compose_combined_enum(e.default) | |||
else: | |||
default = "'{}'".format(e.members[e.default]) | |||
self._cur_fields.append(self.FieldDef( | |||
name=e.name_field, | |||
cvt='{}.convert({})'.format(qualname, e.name_field), | |||
fmt='I', | |||
default="'{}'".format(e.members[e.default]), | |||
default=default, | |||
type=qualname, | |||
doc=None)) | |||
@@ -495,11 +560,16 @@ class SerializedDType(_ParamDefBase): | |||
self._write('%s = %s.%s', e.name, e.src_class, e.src_name) | |||
s = e.src_enum | |||
qualname = '{}.{}'.format(e.src_class, e.src_name) | |||
if s.combined: | |||
default = s.compose_combined_enum(e.get_default()) | |||
else: | |||
default = "'{}'".format(s.members[e.get_default()]) | |||
self._cur_fields.append(self.FieldDef( | |||
name=e.name_field, | |||
cvt='{}.convert({})'.format(qualname, e.name_field), | |||
fmt='I', | |||
default="'{}'".format(s.members[e.get_default()]), | |||
default=default, | |||
type=qualname, | |||
doc=None)) | |||
@@ -639,14 +709,19 @@ class CPPWriter(IndentWriterBase): | |||
v += ',' | |||
self._write(v) | |||
for mem, alias in e.member_alias: | |||
self._write('%s = %s,', alias, mem) | |||
if e.combined: | |||
self._write('%s = %s,', alias, e.compose_combined_enum(mem)) | |||
else: | |||
self._write('%s = %s,', alias, mem) | |||
self._write('};', indent=-1) | |||
self._non_static_members.append(e) | |||
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', | |||
str(e.name).upper(), len(e.members)) | |||
self._add_ctor_args(e.name, | |||
'{}::{}'.format(e.name, e.members[e.default]), | |||
e.name_field) | |||
if e.combined: | |||
default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default)) | |||
else: | |||
default = '{}::{}'.format(e.name, e.members[e.default]) | |||
self._add_ctor_args(e.name, default, e.name_field) | |||
def _on_member_enum_alias(self, e): | |||
s = e.src_enum | |||
@@ -654,10 +729,11 @@ class CPPWriter(IndentWriterBase): | |||
self._non_static_members.append(e) | |||
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', | |||
str(e.name).upper(), len(s.members)) | |||
self._add_ctor_args(e.name, | |||
'{}::{}'.format(e.name, | |||
s.members[e.get_default()]), | |||
e.name_field) | |||
if s.combined: | |||
default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default)) | |||
else: | |||
default = '{}::{}'.format(e.name, s.members[e.get_default()]) | |||
self._add_ctor_args(e.name, default, e.name_field) | |||
def _on_member_field(self, f): | |||
self._non_static_members.append(f) | |||
@@ -106,7 +106,12 @@ class ConverterWriter(IndentWriterBase): | |||
return | |||
# wrapped with default value | |||
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default) | |||
if e.combined: | |||
default_val = "static_cast<{}::{}>({})".format( | |||
fullname, e.name, e.compose_combined_enum(e.default)) | |||
else: | |||
default_val = "{}::{}::{}".format(fullname, e.name, e.members[e.default]) | |||
wrapped = self._wrapped_with_default_value(td_class, default_val) | |||
self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | |||
@@ -124,7 +129,13 @@ class ConverterWriter(IndentWriterBase): | |||
self._write("def {} : {};".format(td_class, enum_def)) | |||
# wrapped with default value | |||
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default()) | |||
s = e.src_enum | |||
if s.combined: | |||
default_val = "static_cast<{}::{}>({})".format( | |||
fullname, e.name, s.compose_combined_enum(e.get_default())) | |||
else: | |||
default_val = "{}::{}::{}".format(fullname, e.name, s.members[e.get_default()]) | |||
wrapped = self._wrapped_with_default_value(td_class, default_val) | |||
self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | |||
@@ -87,9 +87,13 @@ struct pyobj_convert_generic { | |||
} | |||
}; | |||
template<typename T, typename SFINAE=void> | |||
struct EnumTrait; | |||
template <typename T> | |||
struct EnumTrait { | |||
struct EnumTrait<T, std::enable_if_t<std::is_enum_v<T>>> { | |||
static constexpr bool is_bit_combined = false; | |||
static constexpr std::underlying_type_t<T> max = 0; | |||
}; | |||
template <typename T> | |||
@@ -264,18 +268,25 @@ struct BitCombinedEnumWrapper { | |||
return ret; | |||
} | |||
} | |||
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) { | |||
PyObject* obj = type->tp_alloc(type, 0); | |||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1); | |||
return obj; | |||
} | |||
static int py_init(PyObject* self, PyObject* args, PyObject*) { | |||
int input = 1; | |||
if (PyArg_ParseTuple(args, "|i", &input)){ | |||
reinterpret_cast<BitCombinedEnumWrapper*>(self)->value = | |||
static_cast<T>(input); | |||
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) { | |||
if (!PyTuple_Size(args)) { | |||
PyObject* obj = type->tp_alloc(type, 0); | |||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T(); | |||
return obj; | |||
} | |||
else { | |||
PyObject* input; | |||
if (!PyArg_ParseTuple(args, "|O", &input)) { | |||
return nullptr; | |||
} | |||
T value; | |||
try { | |||
value = pyobj_convert_generic<T>::from(input); | |||
} CATCH_ALL(nullptr); | |||
PyObject* obj = type->tp_alloc(type, 0); | |||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value; | |||
return obj; | |||
} | |||
return 0; | |||
} | |||
static PyObject* py_repr(PyObject* self) { | |||
return pyobj_convert_generic<std::string>::to( | |||
@@ -325,6 +336,12 @@ struct pyobj_convert_generic<T, | |||
static T from(PyObject* obj) { | |||
if (PyObject_TypeCheck(obj, &Wrapper::type)) { | |||
return reinterpret_cast<Wrapper*>(obj)->value; | |||
} else if(PyLong_Check(obj)) { | |||
auto value = pyobj_convert_generic<std::underlying_type_t<T>>::from(obj); | |||
mgb_throw_if(value > EnumTrait<T>::max, mgb::MegBrainError, | |||
"out of range, cannot convert %zu to %s", | |||
static_cast<uint32_t>(value), Wrapper::name); | |||
return static_cast<T>(value); | |||
} | |||
// try as string | |||
// TODO: type checkcd | |||
@@ -90,10 +90,12 @@ void EnumAttrEmitter::emit_tpl_spl() { | |||
"template<> PyNumberMethods " | |||
"$enumTpl<$opClass::$enumClass>::number_methods={};\n", | |||
&ctx); | |||
os << tgfmt( | |||
"template<> struct EnumTrait<$opClass::$enumClass> { static constexpr " | |||
"bool is_bit_combined = true;};\n", | |||
&ctx); | |||
os << tgfmt(R"( | |||
template<> struct EnumTrait<$opClass::$enumClass> { | |||
static constexpr bool is_bit_combined = true; | |||
static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1; | |||
}; | |||
)", &ctx, attr->getEnumMembers().size()); | |||
} | |||
auto str2type = [&](auto&& i) -> std::string { | |||
@@ -138,7 +140,6 @@ void $0(PyTypeObject& py_type) { | |||
// others should always use singleton | |||
os << tgfmt(R"( | |||
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; | |||
e_type.tp_init = $enumTpl<$opClass::$enumClass>::py_init; | |||
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; | |||
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; | |||
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and; | |||
@@ -6,7 +6,7 @@ decl_opr('Convolution', | |||
'convolution kernel in ' | |||
'(out channel, in channel, kern row, kern col) format')], | |||
params=[('param', 'ConvolutionV0'), | |||
('execution_polity', 'ExecutionPolicy')], | |||
('execution_polity', 'ExecutionPolicyV0')], | |||
desc='batched convolution on channeled 2D images') | |||
decl_opr('Convolution', | |||
@@ -28,7 +28,7 @@ decl_opr('ConvolutionBackwardData', | |||
'convolution kernel in ' | |||
'(out channel, in channel, kern row, kern col) format')], | |||
params=[('param', 'ConvolutionV0'), | |||
('execution_polity', 'ExecutionPolicy')], | |||
('execution_polity', 'ExecutionPolicyV0')], | |||
body=[ | |||
'a, b = all_inputs', | |||
'all_inputs = [b, a]' | |||
@@ -201,7 +201,7 @@ decl_opr('ConvBiasForward', | |||
Doc('bias', 'bias'), | |||
], | |||
params=[('param', 'ConvBiasV1'), | |||
('execution_policy', 'ExecutionPolicy')], | |||
('execution_policy', 'ExecutionPolicyV0')], | |||
desc=('activation(convolution(src, filter) + bias) with specified ' | |||
'dtype'), | |||
has_out_dtype=True) | |||
@@ -42,7 +42,12 @@ pdef('PersistentOutputStorage').add_fields( | |||
'when profile or heuristic algo selection it require the algos' | |||
'must be reproducible'), | |||
Doc('OPTMIZED', | |||
'profile require algos are optmized to achieve fast-profile')). | |||
'profile require algos are optmized to achieve fast-profile'), | |||
default=('HEURISTIC',), | |||
member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'), | |||
(('PROFILE', 'REPRODUCIBLE'), 'PROFILE_REPRODUCIBLE'), | |||
(('PROFILE', 'HEURISTIC'), 'PROFILE_HEURISTIC'), | |||
]). | |||
add_fields('uint64', | |||
Doc('workspace_limit', 'workspace limit in bytes'), | |||
str(2**64-1)+'ull')) | |||