GitOrigin-RevId: 3716bf9bb5
tags/v1.3.1
@@ -52,14 +52,11 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
name = p + e | name = p + e | ||||
e = self._enums[(p, e)] | e = self._enums[(p, e)] | ||||
self._write_doc(e.name) | self._write_doc(e.name) | ||||
self._write("enum %s%s : uint {", p, e.name, indent=1) | |||||
attribute = "(bit_flags)" if e.combined else "" | |||||
self._write("enum %s%s : uint %s {", p, e.name, attribute, indent=1) | |||||
for idx, member in enumerate(e.members): | for idx, member in enumerate(e.members): | ||||
self._write_doc(member) | self._write_doc(member) | ||||
if e.combined: | |||||
self._write("%s=%d,", scramble_enum_member_name(str(member)), | |||||
1<<idx) | |||||
else: | |||||
self._write("%s,", scramble_enum_member_name(str(member))) | |||||
self._write("%s,", scramble_enum_member_name(str(member))) | |||||
self._write("}\n", indent=-1) | self._write("}\n", indent=-1) | ||||
def _write_doc(self, doc): | def _write_doc(self, doc): | ||||
@@ -97,8 +94,11 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
return | return | ||||
self._write_doc(e.name) | self._write_doc(e.name) | ||||
self._used_enum.add(key) | self._used_enum.add(key) | ||||
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, | |||||
scramble_enum_member_name(str(e.members[e.default]))) | |||||
if e.combined: | |||||
default = e.compose_combined_enum(e.default) | |||||
else: | |||||
default = scramble_enum_member_name(str(e.members[e.default])) | |||||
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) | |||||
def _resolve_const(self, v): | def _resolve_const(self, v): | ||||
while v in self._cur_const_val: | while v in self._cur_const_val: | ||||
@@ -120,9 +120,12 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
return | return | ||||
self._used_enum.add((e.src_class, e.src_name)) | self._used_enum.add((e.src_class, e.src_name)) | ||||
enum_name = e.src_class + e.src_name | enum_name = e.src_class + e.src_name | ||||
self._write( | |||||
"%s:%s = %s;", e.name_field, enum_name, | |||||
scramble_enum_member_name(str(e.src_enum.members[e.get_default()]))) | |||||
s = e.src_enum | |||||
if s.combined: | |||||
default = s.compose_combined_enum(e.get_default()) | |||||
else: | |||||
default = scramble_enum_member_name(str(s.members[e.get_default()])) | |||||
self._write("%s:%s = %s;", e.name_field, enum_name, default) | |||||
def _get_fb_default(self, cppdefault): | def _get_fb_default(self, cppdefault): | ||||
if not isinstance(cppdefault, str): | if not isinstance(cppdefault, str): | ||||
@@ -73,11 +73,21 @@ class member_defs: | |||||
"""define an enum; the result would contain both an enum class def and its | """define an enum; the result would contain both an enum class def and its | ||||
corresponding data field | corresponding data field | ||||
:param default: index of default member value | |||||
:param default: | |||||
for normal enum class: index of default member value | |||||
for bit combined class: tuple of index of default member value | |||||
For example, following representations of the default value for bit | |||||
combined class are all equivalent: | |||||
Enum(members=('a', 'b', 'c'), default=('a', 'b'), ...) | |||||
Enum(members=('a', 'b', 'c'), default=(0, 1), ...) | |||||
Enum(members=('a', 'b', 'c'), default=(1 << 0) | (1 << 1), ...) | |||||
:attr name_field: name of the data field of this enum in the param | :attr name_field: name of the data field of this enum in the param | ||||
struct | struct | ||||
:attr member_alias: list of (member, alias) pairs | |||||
:attr member_alias: | |||||
for normal enum class: list of (member, alias) pairs | |||||
for bit combined class: list of (tuple of members, alias) paris | |||||
""" | """ | ||||
__slots__ = ['name', 'name_field', 'members', 'default', | __slots__ = ['name', 'name_field', 'members', 'default', | ||||
'member_alias', 'combined'] | 'member_alias', 'combined'] | ||||
@@ -90,17 +100,11 @@ class member_defs: | |||||
name = member_defs.Doc.make(name) | name = member_defs.Doc.make(name) | ||||
assert name.id[0].isupper() | assert name.id[0].isupper() | ||||
members = tuple(map(member_defs.Doc.make, members)) | members = tuple(map(member_defs.Doc.make, members)) | ||||
if isinstance(default, str): | |||||
if default not in name_field: | |||||
raise ValueError( | |||||
"Default value '{}' does not exist.".format(default)) | |||||
default = name_field.index(default) | |||||
assert isinstance(default, int) | |||||
self.name = name | self.name = name | ||||
self.combined = combined | self.combined = combined | ||||
self.name_field = self.get_name_field(name.id, name_field) | self.name_field = self.get_name_field(name.id, name_field) | ||||
self.members = members | self.members = members | ||||
self.default = default | |||||
self.default = self.normalize_enum_value(default) | |||||
self.all_enums[(param_name, name.id)] = self | self.all_enums[(param_name, name.id)] = self | ||||
@@ -114,6 +118,43 @@ class member_defs: | |||||
assert isinstance(name_field, str) | assert isinstance(name_field, str) | ||||
return name_field | return name_field | ||||
def normalize_enum_value(self, value): | |||||
def normalize(v): | |||||
if isinstance(v, str): | |||||
if v not in self.members: | |||||
raise ValueError( | |||||
"enum member '{}' does not exist.".format(v)) | |||||
v = self.members.index(v) | |||||
assert isinstance(v, int) | |||||
return v | |||||
if self.combined: | |||||
if isinstance(value, int): | |||||
value = self.decompose_combined_enum(value) | |||||
assert isinstance(value, tuple) | |||||
value = tuple(normalize(i) for i in value) | |||||
return value | |||||
else: | |||||
return normalize(value) | |||||
@staticmethod | |||||
def decompose_combined_enum(v): | |||||
"""Integer => tuple of the indexes of the enum members""" | |||||
assert isinstance(v, int) | |||||
idx = 0 | |||||
members = [] | |||||
while v > 0: | |||||
if v & 1: | |||||
members.append(idx) | |||||
idx += 1 | |||||
v >>= 1 | |||||
return tuple(members) | |||||
def compose_combined_enum(self, v): | |||||
"""tuple of members => Integer""" | |||||
assert self.combined and isinstance(v, tuple) | |||||
norm_v = self.normalize_enum_value(v) | |||||
return sum(1 << i for i in norm_v) | |||||
class Field(Base): | class Field(Base): | ||||
"""define a normal data field""" | """define a normal data field""" | ||||
__slots__ = ['name', 'dtype', 'default'] | __slots__ = ['name', 'dtype', 'default'] | ||||
@@ -146,6 +187,10 @@ class member_defs: | |||||
src_name = name | src_name = name | ||||
self.src_name = src_name | self.src_name = src_name | ||||
self.default = default | self.default = default | ||||
# TODO: remove this assertion if needed; adding mock param_defs in | |||||
# current testing framework is too complicated, and currently we | |||||
# only allow aliasing of normal enum | |||||
assert not self.src_enum.combined | |||||
@property | @property | ||||
def src_enum(self): | def src_enum(self): | ||||
@@ -157,7 +202,7 @@ class member_defs: | |||||
set""" | set""" | ||||
if self.default is None: | if self.default is None: | ||||
return self.src_enum.default | return self.src_enum.default | ||||
return self.default | |||||
return self.src_enum.normalize_enum_value(self.default) | |||||
class ParamDef: | class ParamDef: | ||||
@@ -198,7 +243,7 @@ class ParamDef: | |||||
self.name.id, name, name_field, members, default, member_alias)) | self.name.id, name, name_field, members, default, member_alias)) | ||||
return self | return self | ||||
def add_bit_combination_enum(self, name, *members, default=0, | |||||
def add_bit_combination_enum(self, name, *members, default=tuple(), | |||||
name_field=None, member_alias=[]): | name_field=None, member_alias=[]): | ||||
self.members.append(member_defs.Enum( | self.members.append(member_defs.Enum( | ||||
self.name.id, name, name_field, members, default, member_alias, True)) | self.name.id, name, name_field, members, default, member_alias, True)) | ||||
@@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase): | |||||
' for idx, v in enumerate(pdata):\n' | ' for idx, v in enumerate(pdata):\n' | ||||
' if isinstance(v, _EnumBase):\n' | ' if isinstance(v, _EnumBase):\n' | ||||
' pdata[idx] = _enum_member2num[id(v)]\n' | ' pdata[idx] = _enum_member2num[id(v)]\n' | ||||
' elif isinstance(v, _BitCombinedEnumBase):\n' | |||||
' pdata[idx] = v._value_\n' | |||||
' return tag + self._packer.pack(*pdata)\n' | ' return tag + self._packer.pack(*pdata)\n' | ||||
'\n' | '\n' | ||||
) | ) | ||||
self._write( | |||||
'class _EnumBase(enum.Enum):\n' | |||||
# it's hard to mix custom implemention into enum, just do copy-paste instead | |||||
classbody = ( | |||||
' @classmethod\n' | ' @classmethod\n' | ||||
' def __normalize(cls, val):\n' | ' def __normalize(cls, val):\n' | ||||
' if isinstance(val, str):\n' | ' if isinstance(val, str):\n' | ||||
@@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase): | |||||
' return super()._missing_(value)\n' | ' return super()._missing_(value)\n' | ||||
'\n' | '\n' | ||||
) | ) | ||||
self._write( | |||||
'class _EnumBase(enum.Enum):\n' + classbody | |||||
) | |||||
self._write( | |||||
'class _BitCombinedEnumBase(enum.Flag):\n' + classbody | |||||
) | |||||
if not self._imperative: | if not self._imperative: | ||||
self._write( | self._write( | ||||
'def _as_dtype_num(dtype):\n' | 'def _as_dtype_num(dtype):\n' | ||||
@@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase): | |||||
def _on_member_enum(self, e): | def _on_member_enum(self, e): | ||||
qualname = '{}.{}'.format(self._cur_param_name, e.name) | qualname = '{}.{}'.format(self._cur_param_name, e.name) | ||||
self._write('class %s(_EnumBase):', e.name, indent=1) | |||||
if e.combined: | |||||
self._write('class %s(_BitCombinedEnumBase):', e.name, indent=1) | |||||
else: | |||||
self._write('class %s(_EnumBase):', e.name, indent=1) | |||||
self._write_doc(e.name) | self._write_doc(e.name) | ||||
for idx, emem in enumerate(e.members): | for idx, emem in enumerate(e.members): | ||||
self._write('%s = "%s"', emem, emem) | |||||
self._write_doc(emem) | |||||
if e.combined: | if e.combined: | ||||
self._enum_member2num.append('id({}.{}):{}'.format( | |||||
qualname, emem, 1<<idx)) | |||||
self._write('%s = 1 << %d', emem, idx) | |||||
self._write_doc(emem) | |||||
else: | else: | ||||
self._write('%s = "%s"', emem, emem) | |||||
self._write_doc(emem) | |||||
self._enum_member2num.append('id({}.{}):{}'.format( | self._enum_member2num.append('id({}.{}):{}'.format( | ||||
qualname, emem, idx)) | qualname, emem, idx)) | ||||
for emem, emem_alis in e.member_alias: | |||||
self._write('%s = %s', emem_alis, emem) | |||||
for emem, emem_alias in e.member_alias: | |||||
if e.combined: | |||||
self._write('%s = %s', emem_alias, e.compose_combined_enum(emem)) | |||||
else: | |||||
self._write('%s = %s', emem_alias, emem) | |||||
self._unindent() | self._unindent() | ||||
self._write('') | self._write('') | ||||
if e.combined: | |||||
default = e.compose_combined_enum(e.default) | |||||
else: | |||||
default = "'{}'".format(e.members[e.default]) | |||||
self._cur_fields.append(self.FieldDef( | self._cur_fields.append(self.FieldDef( | ||||
name=e.name_field, | name=e.name_field, | ||||
cvt='{}.convert({})'.format(qualname, e.name_field), | cvt='{}.convert({})'.format(qualname, e.name_field), | ||||
fmt='I', | fmt='I', | ||||
default="'{}'".format(e.members[e.default]), | |||||
default=default, | |||||
type=qualname, | type=qualname, | ||||
doc=None)) | doc=None)) | ||||
@@ -495,11 +560,16 @@ class SerializedDType(_ParamDefBase): | |||||
self._write('%s = %s.%s', e.name, e.src_class, e.src_name) | self._write('%s = %s.%s', e.name, e.src_class, e.src_name) | ||||
s = e.src_enum | s = e.src_enum | ||||
qualname = '{}.{}'.format(e.src_class, e.src_name) | qualname = '{}.{}'.format(e.src_class, e.src_name) | ||||
if s.combined: | |||||
default = s.compose_combined_enum(e.get_default()) | |||||
else: | |||||
default = "'{}'".format(s.members[e.get_default()]) | |||||
self._cur_fields.append(self.FieldDef( | self._cur_fields.append(self.FieldDef( | ||||
name=e.name_field, | name=e.name_field, | ||||
cvt='{}.convert({})'.format(qualname, e.name_field), | cvt='{}.convert({})'.format(qualname, e.name_field), | ||||
fmt='I', | fmt='I', | ||||
default="'{}'".format(s.members[e.get_default()]), | |||||
default=default, | |||||
type=qualname, | type=qualname, | ||||
doc=None)) | doc=None)) | ||||
@@ -639,14 +709,19 @@ class CPPWriter(IndentWriterBase): | |||||
v += ',' | v += ',' | ||||
self._write(v) | self._write(v) | ||||
for mem, alias in e.member_alias: | for mem, alias in e.member_alias: | ||||
self._write('%s = %s,', alias, mem) | |||||
if e.combined: | |||||
self._write('%s = %s,', alias, e.compose_combined_enum(mem)) | |||||
else: | |||||
self._write('%s = %s,', alias, mem) | |||||
self._write('};', indent=-1) | self._write('};', indent=-1) | ||||
self._non_static_members.append(e) | self._non_static_members.append(e) | ||||
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', | self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', | ||||
str(e.name).upper(), len(e.members)) | str(e.name).upper(), len(e.members)) | ||||
self._add_ctor_args(e.name, | |||||
'{}::{}'.format(e.name, e.members[e.default]), | |||||
e.name_field) | |||||
if e.combined: | |||||
default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default)) | |||||
else: | |||||
default = '{}::{}'.format(e.name, e.members[e.default]) | |||||
self._add_ctor_args(e.name, default, e.name_field) | |||||
def _on_member_enum_alias(self, e): | def _on_member_enum_alias(self, e): | ||||
s = e.src_enum | s = e.src_enum | ||||
@@ -654,10 +729,11 @@ class CPPWriter(IndentWriterBase): | |||||
self._non_static_members.append(e) | self._non_static_members.append(e) | ||||
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', | self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', | ||||
str(e.name).upper(), len(s.members)) | str(e.name).upper(), len(s.members)) | ||||
self._add_ctor_args(e.name, | |||||
'{}::{}'.format(e.name, | |||||
s.members[e.get_default()]), | |||||
e.name_field) | |||||
if s.combined: | |||||
default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default)) | |||||
else: | |||||
default = '{}::{}'.format(e.name, s.members[e.get_default()]) | |||||
self._add_ctor_args(e.name, default, e.name_field) | |||||
def _on_member_field(self, f): | def _on_member_field(self, f): | ||||
self._non_static_members.append(f) | self._non_static_members.append(f) | ||||
@@ -106,7 +106,12 @@ class ConverterWriter(IndentWriterBase): | |||||
return | return | ||||
# wrapped with default value | # wrapped with default value | ||||
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default) | |||||
if e.combined: | |||||
default_val = "static_cast<{}::{}>({})".format( | |||||
fullname, e.name, e.compose_combined_enum(e.default)) | |||||
else: | |||||
default_val = "{}::{}::{}".format(fullname, e.name, e.members[e.default]) | |||||
wrapped = self._wrapped_with_default_value(td_class, default_val) | wrapped = self._wrapped_with_default_value(td_class, default_val) | ||||
self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | ||||
@@ -124,7 +129,13 @@ class ConverterWriter(IndentWriterBase): | |||||
self._write("def {} : {};".format(td_class, enum_def)) | self._write("def {} : {};".format(td_class, enum_def)) | ||||
# wrapped with default value | # wrapped with default value | ||||
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default()) | |||||
s = e.src_enum | |||||
if s.combined: | |||||
default_val = "static_cast<{}::{}>({})".format( | |||||
fullname, e.name, s.compose_combined_enum(e.get_default())) | |||||
else: | |||||
default_val = "{}::{}::{}".format(fullname, e.name, s.members[e.get_default()]) | |||||
wrapped = self._wrapped_with_default_value(td_class, default_val) | wrapped = self._wrapped_with_default_value(td_class, default_val) | ||||
self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | ||||
@@ -87,9 +87,13 @@ struct pyobj_convert_generic { | |||||
} | } | ||||
}; | }; | ||||
template<typename T, typename SFINAE=void> | |||||
struct EnumTrait; | |||||
template <typename T> | template <typename T> | ||||
struct EnumTrait { | |||||
struct EnumTrait<T, std::enable_if_t<std::is_enum_v<T>>> { | |||||
static constexpr bool is_bit_combined = false; | static constexpr bool is_bit_combined = false; | ||||
static constexpr std::underlying_type_t<T> max = 0; | |||||
}; | }; | ||||
template <typename T> | template <typename T> | ||||
@@ -264,18 +268,25 @@ struct BitCombinedEnumWrapper { | |||||
return ret; | return ret; | ||||
} | } | ||||
} | } | ||||
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) { | |||||
PyObject* obj = type->tp_alloc(type, 0); | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1); | |||||
return obj; | |||||
} | |||||
static int py_init(PyObject* self, PyObject* args, PyObject*) { | |||||
int input = 1; | |||||
if (PyArg_ParseTuple(args, "|i", &input)){ | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(self)->value = | |||||
static_cast<T>(input); | |||||
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) { | |||||
if (!PyTuple_Size(args)) { | |||||
PyObject* obj = type->tp_alloc(type, 0); | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T(); | |||||
return obj; | |||||
} | |||||
else { | |||||
PyObject* input; | |||||
if (!PyArg_ParseTuple(args, "|O", &input)) { | |||||
return nullptr; | |||||
} | |||||
T value; | |||||
try { | |||||
value = pyobj_convert_generic<T>::from(input); | |||||
} CATCH_ALL(nullptr); | |||||
PyObject* obj = type->tp_alloc(type, 0); | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value; | |||||
return obj; | |||||
} | } | ||||
return 0; | |||||
} | } | ||||
static PyObject* py_repr(PyObject* self) { | static PyObject* py_repr(PyObject* self) { | ||||
return pyobj_convert_generic<std::string>::to( | return pyobj_convert_generic<std::string>::to( | ||||
@@ -325,6 +336,12 @@ struct pyobj_convert_generic<T, | |||||
static T from(PyObject* obj) { | static T from(PyObject* obj) { | ||||
if (PyObject_TypeCheck(obj, &Wrapper::type)) { | if (PyObject_TypeCheck(obj, &Wrapper::type)) { | ||||
return reinterpret_cast<Wrapper*>(obj)->value; | return reinterpret_cast<Wrapper*>(obj)->value; | ||||
} else if(PyLong_Check(obj)) { | |||||
auto value = pyobj_convert_generic<std::underlying_type_t<T>>::from(obj); | |||||
mgb_throw_if(value > EnumTrait<T>::max, mgb::MegBrainError, | |||||
"out of range, cannot convert %zu to %s", | |||||
static_cast<uint32_t>(value), Wrapper::name); | |||||
return static_cast<T>(value); | |||||
} | } | ||||
// try as string | // try as string | ||||
// TODO: type checkcd | // TODO: type checkcd | ||||
@@ -90,10 +90,12 @@ void EnumAttrEmitter::emit_tpl_spl() { | |||||
"template<> PyNumberMethods " | "template<> PyNumberMethods " | ||||
"$enumTpl<$opClass::$enumClass>::number_methods={};\n", | "$enumTpl<$opClass::$enumClass>::number_methods={};\n", | ||||
&ctx); | &ctx); | ||||
os << tgfmt( | |||||
"template<> struct EnumTrait<$opClass::$enumClass> { static constexpr " | |||||
"bool is_bit_combined = true;};\n", | |||||
&ctx); | |||||
os << tgfmt(R"( | |||||
template<> struct EnumTrait<$opClass::$enumClass> { | |||||
static constexpr bool is_bit_combined = true; | |||||
static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1; | |||||
}; | |||||
)", &ctx, attr->getEnumMembers().size()); | |||||
} | } | ||||
auto str2type = [&](auto&& i) -> std::string { | auto str2type = [&](auto&& i) -> std::string { | ||||
@@ -138,7 +140,6 @@ void $0(PyTypeObject& py_type) { | |||||
// others should always use singleton | // others should always use singleton | ||||
os << tgfmt(R"( | os << tgfmt(R"( | ||||
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; | e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; | ||||
e_type.tp_init = $enumTpl<$opClass::$enumClass>::py_init; | |||||
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; | auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; | ||||
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; | number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; | ||||
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and; | number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and; | ||||
@@ -6,7 +6,7 @@ decl_opr('Convolution', | |||||
'convolution kernel in ' | 'convolution kernel in ' | ||||
'(out channel, in channel, kern row, kern col) format')], | '(out channel, in channel, kern row, kern col) format')], | ||||
params=[('param', 'ConvolutionV0'), | params=[('param', 'ConvolutionV0'), | ||||
('execution_polity', 'ExecutionPolicy')], | |||||
('execution_polity', 'ExecutionPolicyV0')], | |||||
desc='batched convolution on channeled 2D images') | desc='batched convolution on channeled 2D images') | ||||
decl_opr('Convolution', | decl_opr('Convolution', | ||||
@@ -28,7 +28,7 @@ decl_opr('ConvolutionBackwardData', | |||||
'convolution kernel in ' | 'convolution kernel in ' | ||||
'(out channel, in channel, kern row, kern col) format')], | '(out channel, in channel, kern row, kern col) format')], | ||||
params=[('param', 'ConvolutionV0'), | params=[('param', 'ConvolutionV0'), | ||||
('execution_polity', 'ExecutionPolicy')], | |||||
('execution_polity', 'ExecutionPolicyV0')], | |||||
body=[ | body=[ | ||||
'a, b = all_inputs', | 'a, b = all_inputs', | ||||
'all_inputs = [b, a]' | 'all_inputs = [b, a]' | ||||
@@ -201,7 +201,7 @@ decl_opr('ConvBiasForward', | |||||
Doc('bias', 'bias'), | Doc('bias', 'bias'), | ||||
], | ], | ||||
params=[('param', 'ConvBiasV1'), | params=[('param', 'ConvBiasV1'), | ||||
('execution_policy', 'ExecutionPolicy')], | |||||
('execution_policy', 'ExecutionPolicyV0')], | |||||
desc=('activation(convolution(src, filter) + bias) with specified ' | desc=('activation(convolution(src, filter) + bias) with specified ' | ||||
'dtype'), | 'dtype'), | ||||
has_out_dtype=True) | has_out_dtype=True) | ||||
@@ -42,7 +42,12 @@ pdef('PersistentOutputStorage').add_fields( | |||||
'when profile or heuristic algo selection it require the algos' | 'when profile or heuristic algo selection it require the algos' | ||||
'must be reproducible'), | 'must be reproducible'), | ||||
Doc('OPTMIZED', | Doc('OPTMIZED', | ||||
'profile require algos are optmized to achieve fast-profile')). | |||||
'profile require algos are optmized to achieve fast-profile'), | |||||
default=('HEURISTIC',), | |||||
member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'), | |||||
(('PROFILE', 'REPRODUCIBLE'), 'PROFILE_REPRODUCIBLE'), | |||||
(('PROFILE', 'HEURISTIC'), 'PROFILE_HEURISTIC'), | |||||
]). | |||||
add_fields('uint64', | add_fields('uint64', | ||||
Doc('workspace_limit', 'workspace limit in bytes'), | Doc('workspace_limit', 'workspace limit in bytes'), | ||||
str(2**64-1)+'ull')) | str(2**64-1)+'ull')) | ||||