Browse Source

refactor(mgb/dnn): refactor enum used in serializing

GitOrigin-RevId: e57af4a59c
release-1.6
Megvii Engine Team 3 years ago
parent
commit
fb49a2834f
10 changed files with 410 additions and 370 deletions
  1. +10
    -2
      dnn/scripts/gen_flatbuffers_schema.py
  2. +35
    -24
      dnn/scripts/gen_param_defs.py
  3. +5
    -3
      dnn/scripts/gen_tablegen.py
  4. +281
    -281
      dnn/scripts/opr_param_defs.py
  5. +11
    -8
      imperative/tablegen/helper.h
  6. +7
    -3
      imperative/tablegen/targets/cpp_class.cpp
  7. +8
    -7
      imperative/tablegen/targets/pybind11.cpp
  8. +13
    -3
      imperative/tablegen/targets/python_c_extension.cpp
  9. +2
    -1
      tools/gen_header_for_bin_reduce.py
  10. +38
    -38
      tools/param_defs/mgb_opr_param_defs.py

+ 10
- 2
dnn/scripts/gen_flatbuffers_schema.py View File

@@ -23,8 +23,14 @@ def _cname_to_fbname(cname):
}[cname] }[cname]


def scramble_enum_member_name(name): def scramble_enum_member_name(name):
s = name.find('<<')
if s != -1:
name = name[0:name.find('=') + 1] + ' ' + name[s+2:]
if name in ("MIN", "MAX"): if name in ("MIN", "MAX"):
return name + "_" return name + "_"
o_name = name.split(' ')[0].split('=')[0]
if o_name in ("MIN", "MAX"):
return name.replace(o_name, o_name + "_")
return name return name


class FlatBuffersWriter(IndentWriterBase): class FlatBuffersWriter(IndentWriterBase):
@@ -97,7 +103,8 @@ class FlatBuffersWriter(IndentWriterBase):
if e.combined: if e.combined:
default = e.compose_combined_enum(e.default) default = e.compose_combined_enum(e.default)
else: else:
default = scramble_enum_member_name(str(e.members[e.default]))
default = scramble_enum_member_name(
str(e.members[e.default]).split(' ')[0].split('=')[0])
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default)


def _resolve_const(self, v): def _resolve_const(self, v):
@@ -124,7 +131,8 @@ class FlatBuffersWriter(IndentWriterBase):
if s.combined: if s.combined:
default = s.compose_combined_enum(e.get_default()) default = s.compose_combined_enum(e.get_default())
else: else:
default = scramble_enum_member_name(str(s.members[e.get_default()]))
default = scramble_enum_member_name(
str(s.members[e.get_default()]).split(' ')[0].split('=')[0])
self._write("%s:%s = %s;", e.name_field, enum_name, default) self._write("%s:%s = %s;", e.name_field, enum_name, default)


def _get_fb_default(self, cppdefault): def _get_fb_default(self, cppdefault):


+ 35
- 24
dnn/scripts/gen_param_defs.py View File

@@ -121,10 +121,12 @@ class member_defs:
def normalize_enum_value(self, value): def normalize_enum_value(self, value):
def normalize(v): def normalize(v):
if isinstance(v, str): if isinstance(v, str):
if v not in self.members:
raise ValueError(
"enum member '{}' does not exist.".format(v))
v = self.members.index(v)
for idx, m in enumerate(self.members):
m = str(m).split(' ')[0].split('=')[0]
if v == m :
return idx
raise ValueError(
"enum member '{}' does not exist.".format(v))
assert isinstance(v, int) assert isinstance(v, int)
return v return v
if self.combined: if self.combined:
@@ -524,21 +526,25 @@ class SerializedDType(_ParamDefBase):


self._write_doc(e.name) self._write_doc(e.name)


for idx, emem in enumerate(e.members):
for emem in e.members:
if e.combined: if e.combined:
self._write('%s = 1 << %d', emem, idx)
self._write('%s', emem)
self._write_doc(emem) self._write_doc(emem)
else: else:
self._write('%s = "%s"', emem, emem)
v = str(emem).split(' ')[0].split('=')[0]
n = int(str(emem).split('=')[1])
self._write('%s = "%s"', v, v)
self._write_doc(emem) self._write_doc(emem)
self._enum_member2num.append('id({}.{}):{}'.format( self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, idx))
qualname, v, n))


for emem, emem_alias in e.member_alias: for emem, emem_alias in e.member_alias:
em_a = emem_alias.split(' ')[0].split('=')[0]
if e.combined: if e.combined:
self._write('%s = %s', emem_alias, e.compose_combined_enum(emem))
self._write('%s = %s', em_a, e.compose_combined_enum(emem))
else: else:
self._write('%s = %s', emem_alias, emem)
em = str(emem).split(' ')[0].split('=')[0]
self._write('%s = %s', em_a, em)


self._unindent() self._unindent()
self._write('') self._write('')
@@ -546,7 +552,7 @@ class SerializedDType(_ParamDefBase):
if e.combined: if e.combined:
default = e.compose_combined_enum(e.default) default = e.compose_combined_enum(e.default)
else: else:
default = "'{}'".format(e.members[e.default])
default = "'{}'".format(str(e.members[e.default]).split(' ')[0].split('=')[0])


self._cur_fields.append(self.FieldDef( self._cur_fields.append(self.FieldDef(
name=e.name_field, name=e.name_field,
@@ -564,7 +570,7 @@ class SerializedDType(_ParamDefBase):
if s.combined: if s.combined:
default = s.compose_combined_enum(e.get_default()) default = s.compose_combined_enum(e.get_default())
else: else:
default = "'{}'".format(s.members[e.get_default()])
default = "'{}'".format(str(s.members[e.get_default()]).split(' ')[0].split('=')[0])
self._cur_fields.append(self.FieldDef( self._cur_fields.append(self.FieldDef(
name=e.name_field, name=e.name_field,
cvt='{}.convert({})'.format(qualname, e.name_field), cvt='{}.convert({})'.format(qualname, e.name_field),
@@ -700,11 +706,9 @@ class CPPWriter(IndentWriterBase):
def _on_member_enum(self, e): def _on_member_enum(self, e):
self._write_doc(e.name) self._write_doc(e.name)
self._write('enum class %s: uint32_t {', e.name, indent=1) self._write('enum class %s: uint32_t {', e.name, indent=1)
for idx, i in enumerate(e.members):
for i in e.members:
self._write_doc(i) self._write_doc(i)
v = '{} = {}'.format(i, idx)
if e.combined:
v = '{} = 1 << {}'.format(i, idx)
v = str(i)
if i is not e.members[-1] or e.member_alias: if i is not e.members[-1] or e.member_alias:
v += ',' v += ','
self._write(v) self._write(v)
@@ -712,7 +716,7 @@ class CPPWriter(IndentWriterBase):
if e.combined: if e.combined:
self._write('%s = %s,', alias, e.compose_combined_enum(mem)) self._write('%s = %s,', alias, e.compose_combined_enum(mem))
else: else:
self._write('%s = %s,', alias, mem)
self._write('%s = %s,', str(alias).split(' ')[0].split('=')[0], str(mem).split(' ')[0].split('=')[0])
self._write('};', indent=-1) self._write('};', indent=-1)
self._non_static_members.append(e) self._non_static_members.append(e)
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
@@ -720,7 +724,9 @@ class CPPWriter(IndentWriterBase):
if e.combined: if e.combined:
default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default)) default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default))
else: else:
default = '{}::{}'.format(e.name, e.members[e.default])
value = str(e.members[e.default])
value = value.split(' ')[0].split('=')[0]
default = '{}::{}'.format(e.name, value)
self._add_ctor_args(e.name, default, e.name_field) self._add_ctor_args(e.name, default, e.name_field)


def _on_member_enum_alias(self, e): def _on_member_enum_alias(self, e):
@@ -732,7 +738,9 @@ class CPPWriter(IndentWriterBase):
if s.combined: if s.combined:
default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default)) default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default))
else: else:
default = '{}::{}'.format(e.name, s.members[e.get_default()])
value = str(s.members[e.get_default()])
value = value.split(' ')[0].split('=')[0]
default = '{}::{}'.format(e.name, value)
self._add_ctor_args(e.name, default, e.name_field) self._add_ctor_args(e.name, default, e.name_field)


def _on_member_field(self, f): def _on_member_field(self, f):
@@ -754,11 +762,12 @@ class CPPEnumValueWriter(CPPWriter):
def _on_member_enum(self, e): def _on_member_enum(self, e):
self._write_doc(e.name) self._write_doc(e.name)
self._write('struct %s {', e.name, indent=1) self._write('struct %s {', e.name, indent=1)
for idx, val in enumerate(e.members):
for val in e.members:
self._write_doc(val) self._write_doc(val)
self._write('static const uint32_t %s = %d;', val, idx)
v = str(val)
self._write('static const uint32_t %s;', v)
for mem, alias in e.member_alias: for mem, alias in e.member_alias:
self._write('static const uint32_t %s = %s;', alias, mem)
self._write('static const uint32_t %s = %s;', str(alias).split(' ')[0].split('=')[0], str(mem).split(' ')[0].split('=')[0])
self._write('};', indent=-1) self._write('};', indent=-1)


def _on_member_enum_alias(self, e): def _on_member_enum_alias(self, e):
@@ -848,9 +857,11 @@ class CPPParamJsonFuncWriter(IndentWriterBase):
members = e.src_enum.members members = e.src_enum.members
else: else:
members = e.members members = e.members
for idx, i in enumerate(members):
for i in members:
v = str(i)
v = v.split(' ')[0].split('=')[0]
self._write('case %s::%s::%s: return "%s";', self._write('case %s::%s::%s: return "%s";',
self._param_name, e.name, i, i, indent=0)
self._param_name, e.name, v, v, indent=0)
self._write('default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast<int>(arg));', self._write('default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast<int>(arg));',
self._param_name, e.name, indent=0) self._param_name, e.name, indent=0)
self._write('}', indent=-1) self._write('}', indent=-1)


+ 5
- 3
dnn/scripts/gen_tablegen.py View File

@@ -89,7 +89,7 @@ class ConverterWriter(IndentWriterBase):
fullname = "::megdnn::param::{}".format(p.name) fullname = "::megdnn::param::{}".format(p.name)
enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name) enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name)
def format(v): def format(v):
return '\"{}\"'.format(str(v))
return '\"{}\"'.format(str(v).split(' ')[0].split('=')[0])
enum_def += ','.join(format(i) for i in e.members) enum_def += ','.join(format(i) for i in e.members)


if e.combined: if e.combined:
@@ -110,7 +110,8 @@ class ConverterWriter(IndentWriterBase):
default_val = "static_cast<{}::{}>({})".format( default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, e.compose_combined_enum(e.default)) fullname, e.name, e.compose_combined_enum(e.default))
else: else:
default_val = "{}::{}::{}".format(fullname, e.name, e.members[e.default])
default_val = "{}::{}::{}".format(
fullname, e.name, str(e.members[e.default]).split(' ')[0].split('=')[0])


wrapped = self._wrapped_with_default_value(td_class, default_val) wrapped = self._wrapped_with_default_value(td_class, default_val)


@@ -134,7 +135,8 @@ class ConverterWriter(IndentWriterBase):
default_val = "static_cast<{}::{}>({})".format( default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, s.compose_combined_enum(e.get_default())) fullname, e.name, s.compose_combined_enum(e.get_default()))
else: else:
default_val = "{}::{}::{}".format(fullname, e.name, s.members[e.get_default()])
default_val = "{}::{}::{}".format(fullname, e.name, str(
s.members[e.get_default()]).split(' ')[0].split('=')[0])


wrapped = self._wrapped_with_default_value(td_class, default_val) wrapped = self._wrapped_with_default_value(td_class, default_val)




+ 281
- 281
dnn/scripts/opr_param_defs.py View File

@@ -3,7 +3,7 @@ pdef('Empty')
pdef('Axis').add_fields('int32', 'axis', 0) pdef('Axis').add_fields('int32', 'axis', 0)


(pdef('Convolution', version=0, is_legacy=True). (pdef('Convolution', version=0, is_legacy=True).
add_enum('Mode', 'CROSS_CORRELATION', 'CONVOLUTION').
add_enum('Mode', 'CROSS_CORRELATION = 0', 'CONVOLUTION = 1').
add_fields( add_fields(
'uint32', 'uint32',
Doc('pad_h', 'padding on one side on the first dimension'), 0, Doc('pad_h', 'padding on one side on the first dimension'), 0,
@@ -16,41 +16,41 @@ pdef('Axis').add_fields('int32', 'axis', 0)
'on the second dimension'), 1 'on the second dimension'), 1
). ).
add_enum('DataType', add_enum('DataType',
Doc('FLOAT', 'input/output both float32/float16'),
'INT8x8x16',
'INT8x8x32',
Doc('FLOAT_IO16xC32', 'input/output both float16, the internal '
Doc('FLOAT = 0', 'input/output both float32/float16'),
'INT8x8x16 = 1',
'INT8x8x32 = 2',
Doc('FLOAT_IO16xC32 = 3', 'input/output both float16, the internal '
'compute is float32'), 'compute is float32'),
Doc('QUINT8x8x32', 'input QuantizedAsymm8, output QuantizedS32'),
Doc('INT8x8xX', 'input int8, output specified by tensor DType'),
Doc('QUINT4x4x32', 'input QuantizedAsymm4, output QuantizedS32'),
Doc('QUINT8x8x32 = 4', 'input QuantizedAsymm8, output QuantizedS32'),
Doc('INT8x8xX = 5', 'input int8, output specified by tensor DType'),
Doc('QUINT4x4x32 = 6', 'input QuantizedAsymm4, output QuantizedS32'),
name_field='data_type'). name_field='data_type').
add_enum('Sparse', add_enum('Sparse',
Doc('DENSE', 'dense convolution: filter shape should be '
Doc('DENSE = 0', 'dense convolution: filter shape should be '
'[oc, ic, spatial...] if format is NCHW, ' '[oc, ic, spatial...] if format is NCHW, '
'[oc, spatial..., ic] if format is NHWC'), '[oc, spatial..., ic] if format is NHWC'),
Doc('GROUP', 'group convolution: filter shape should be '
Doc('GROUP = 1', 'group convolution: filter shape should be '
'[group, oc_per_group, ic_per_group, spatial...] if format is NCHW, ' '[group, oc_per_group, ic_per_group, spatial...] if format is NCHW, '
'[group, oc_per_group, spatial..., ic_per_group] if format is NHWC') '[group, oc_per_group, spatial..., ic_per_group] if format is NHWC')
). ).
add_enum(Doc('Format', 'convolution data/filter/output format; see ' add_enum(Doc('Format', 'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'), ':class:`RelayoutFormat` for more details'),
'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88',
'NCHW44','NCHW44_DOT',
Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'),
Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'),
Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'),
Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NCHW4_NHWC', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'),
Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, '
'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6',
'NCHW44 = 7','NCHW44_DOT = 8',
Doc('NCHW_WINOGRAD = 9', 'NCHW layout with weights tranformed by winograd'),
Doc('NCHW88_WINOGRAD = 10', 'NCHW88 layout with weights tranformed by winograd'),
Doc('NCHW44_WINOGRAD = 11', 'NCHW44 layout with weights tranformed by winograd'),
Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NCHW4_NHWC = 15', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'),
Doc('NHWC_NCHW = 16', 'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'), 'output tensor is nchw layout'),
Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
Doc('NHWC_NCHW4_IC_SMALL = 17', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'), 'output tensor is nchw4 layout, padding c=4'),
Doc('NCHW_NCHW4_IC_SMALL', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
Doc('NCHW_NCHW4_IC_SMALL = 18', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'), 'output tensor is nchw4 layout, padding c=4'),
Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
Doc('CHWN4 = 19', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.')) 'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'))
) )


@@ -72,9 +72,9 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum(Doc('ComputeMode', 'Specifies special computation modes, e.g. ' add_enum(Doc('ComputeMode', 'Specifies special computation modes, e.g. '
'different combinations of intermediate result ' 'different combinations of intermediate result '
'data types.'), 'data types.'),
Doc('DEFAULT', 'No special requirements on the precision of '
Doc('DEFAULT = 0', 'No special requirements on the precision of '
'intermediate results.'), 'intermediate results.'),
Doc('FLOAT32', 'Use Float32 accumulator and intermediate result. '
Doc('FLOAT32 = 1', 'Use Float32 accumulator and intermediate result. '
'Only supported when input and output is Float16.'), 'Only supported when input and output is Float16.'),
name_field='compute_mode') name_field='compute_mode')
) )
@@ -95,21 +95,21 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum_alias('Sparse', 'ConvolutionV0'). add_enum_alias('Sparse', 'ConvolutionV0').
add_enum(Doc('Format', 'convolution data/filter/output format; see ' add_enum(Doc('Format', 'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'), ':class:`RelayoutFormat` for more details'),
'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88',
'NCHW44','NCHW44_DOT',
Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NCHW4_NHWC', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'),
Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, '
'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6',
'NCHW44 = 7','NCHW44_DOT = 8',
Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW = 11', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NCHW4_NHWC = 12', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'),
Doc('NHWC_NCHW = 13', 'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'), 'output tensor is nchw layout'),
Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
Doc('NHWC_NCHW4_IC_SMALL = 14', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'), 'output tensor is nchw4 layout, padding c=4'),
Doc('NCHW_NCHW4_IC_SMALL', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
Doc('NCHW_NCHW4_IC_SMALL = 15', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'), 'output tensor is nchw4 layout, padding c=4'),
Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
Doc('CHWN4 = 16', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'), 'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'),
Doc('NCHW64', 'NCHW64 is designed for convolution implementation to utilizing TensorCore '
Doc('NCHW64 = 17', 'NCHW64 is designed for convolution implementation to utilizing TensorCore '
'instructions for 4-bit integers on Nvidia platforms')). 'instructions for 4-bit integers on Nvidia platforms')).
add_enum_alias('ComputeMode', 'ConvolutionV1',name_field='compute_mode') add_enum_alias('ComputeMode', 'ConvolutionV1',name_field='compute_mode')
) )
@@ -129,15 +129,15 @@ pdef('Axis').add_fields('int32', 'axis', 0)
) )


(pdef('ConvPooling'). (pdef('ConvPooling').
add_enum('Method', 'WITH_TEXTURE_OBJ', 'WITH_SHARED_MEM').
add_enum('Method', 'WITH_TEXTURE_OBJ = 0', 'WITH_SHARED_MEM = 1').
add_enum_alias('ConvMode', 'ConvolutionV0', 'Mode'). add_enum_alias('ConvMode', 'ConvolutionV0', 'Mode').
add_enum('PoolMode', 'AVERAGE', 'MAX').
add_enum('NonlineMode', 'IDENTITY', 'RELU', 'SIGMOID').
add_enum('PoolMode', 'AVERAGE = 0', 'MAX = 1').
add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'SIGMOID = 2').
add_fields('uint32', 'pool_shape_h', 1, 'pool_shape_w', 1, 'pool_stride_h', 1, 'pool_stride_w', 1, \ add_fields('uint32', 'pool_shape_h', 1, 'pool_shape_w', 1, 'pool_stride_h', 1, 'pool_stride_w', 1, \
'pool_pad_h', 0, 'pool_pad_w', 0, 'conv_stride_h', 1, 'conv_stride_w', 1, 'conv_pad_h', 0, 'conv_pad_w', 0)) 'pool_pad_h', 0, 'pool_pad_w', 0, 'conv_stride_h', 1, 'conv_stride_w', 1, 'conv_pad_h', 0, 'conv_pad_w', 0))


(pdef('ConvBias', 'legacy conv_bias', version=0, is_legacy=True). (pdef('ConvBias', 'legacy conv_bias', version=0, is_legacy=True).
add_enum('NonlineMode', 'IDENTITY', 'RELU', 'SIGMOID', 'H_SWISH').
add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'SIGMOID = 2', 'H_SWISH = 3').
add_enum_alias('Mode', 'ConvolutionV0'). add_enum_alias('Mode', 'ConvolutionV0').
add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1)) add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1))


@@ -215,9 +215,9 @@ pdef('Axis').add_fields('int32', 'axis', 0)
) )
(pdef('SeparableConv'). (pdef('SeparableConv').
add_enum_alias('Mode', 'ConvolutionV0'). add_enum_alias('Mode', 'ConvolutionV0').
add_enum('BorderMode', 'BORDER_REPLICATE', 'BORDER_REFLECT',
'BORDER_REFLECT_101','BORDER_WRAP',
'BORDER_CONSTANT', 'BORDER_TRANSPARENT','BORDER_ISOLATED').
add_enum('BorderMode', 'BORDER_REPLICATE = 0', 'BORDER_REFLECT = 1',
'BORDER_REFLECT_101 = 2','BORDER_WRAP = 3',
'BORDER_CONSTANT = 4', 'BORDER_TRANSPARENT = 5','BORDER_ISOLATED = 6').
add_fields('bool', 'is_symm_kernel', 'true'). add_fields('bool', 'is_symm_kernel', 'true').
add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1, add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1,
'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1)) 'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1))
@@ -233,11 +233,11 @@ pdef('Axis').add_fields('int32', 'axis', 0)
(pdef('Pooling', version=0, is_legacy=True). (pdef('Pooling', version=0, is_legacy=True).
add_enum( add_enum(
'Mode', 'Mode',
Doc('MAX', 'maximum value inside pooling window'),
Doc('AVERAGE',
Doc('MAX = 0', 'maximum value inside pooling window'),
Doc('AVERAGE = 1',
'arithmetic mean of all values inside pooling window. Padding values ' 'arithmetic mean of all values inside pooling window. Padding values '
'are taken into account and are viewed as zero'), 'are taken into account and are viewed as zero'),
Doc('AVERAGE_COUNT_EXCLUDE_PADDING',
Doc('AVERAGE_COUNT_EXCLUDE_PADDING = 2',
'arithmetic mean of all values inside pooling window. No padding is' 'arithmetic mean of all values inside pooling window. No padding is'
'used.') 'used.')
). ).
@@ -273,15 +273,15 @@ pdef('Axis').add_fields('int32', 'axis', 0)
(pdef('BN'). (pdef('BN').
add_enum( add_enum(
'ParamDim', 'ParamDim',
Doc('DIM_11HW', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'),
Doc('DIM_1CHW', 'Dim of params (Sigma, Mu) is 1 x C x H x W'),
Doc('DIM_1C11', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'),
Doc('DIM_11HW = 0', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'),
Doc('DIM_1CHW = 1', 'Dim of params (Sigma, Mu) is 1 x C x H x W'),
Doc('DIM_1C11 = 2', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'),
name_field='param_dim' name_field='param_dim'
). ).
add_enum( add_enum(
'FwdMode', 'FwdMode',
Doc('TRAINING', 'Training phase.'),
Doc('INFERENCE', 'Inference phase.'),
Doc('TRAINING = 0', 'Training phase.'),
Doc('INFERENCE = 1', 'Inference phase.'),
name_field='fwd_mode' name_field='fwd_mode'
). ).
add_fields('float64', 'epsilon', '1e-4f'). add_fields('float64', 'epsilon', '1e-4f').
@@ -293,22 +293,22 @@ pdef('Axis').add_fields('int32', 'axis', 0)
(pdef('ROIPooling'). (pdef('ROIPooling').
add_enum( add_enum(
'Mode', 'Mode',
Doc('MAX', 'maximum value inside pooling window; pooling result would '
Doc('MAX = 0', 'maximum value inside pooling window; pooling result would '
'be 0 if pooling window is empty'), 'be 0 if pooling window is empty'),
Doc('AVERAGE',
Doc('AVERAGE = 1',
'arithmetic mean of all values inside pooling window; pooling result ' 'arithmetic mean of all values inside pooling window; pooling result '
'would be 0 if pooling window is empty') 'would be 0 if pooling window is empty')
). ).
add_fields('float32', 'scale', '1.f')) add_fields('float32', 'scale', '1.f'))


INTERP_MODES = ['NEAREST', 'LINEAR', 'AREA', 'CUBIC', 'LANCZOS4']
BORDER_MODES = [Doc('REPLICATE', 'aaaaaa|abcdefgh|hhhhhhh'),
Doc('REFLECT', 'fedcba|abcdefgh|hgfedcb'),
Doc('REFLECT_101', 'gfedcb|abcdefgh|gfedcba'),
Doc('WRAP', 'cdefgh|abcdefgh|abcdefg'),
Doc('CONSTANT', 'iiiiii|abcdefgh|iiiiiii'),
Doc('TRANSPARENT', ''),
Doc('ISOLATED', '')]
INTERP_MODES = ['NEAREST = 0', 'LINEAR = 1', 'AREA = 2', 'CUBIC = 3', 'LANCZOS4 = 4']
BORDER_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'),
Doc('REFLECT_101 = 2', 'gfedcb|abcdefgh|gfedcba'),
Doc('WRAP = 3', 'cdefgh|abcdefgh|abcdefg'),
Doc('CONSTANT = 4', 'iiiiii|abcdefgh|iiiiiii'),
Doc('TRANSPARENT = 5', ''),
Doc('ISOLATED = 6', '')]
(pdef('WarpPerspective', version=1, is_legacy=True). (pdef('WarpPerspective', version=1, is_legacy=True).
add_enum('InterpolationMode', *INTERP_MODES, add_enum('InterpolationMode', *INTERP_MODES,
name_field='imode', default=1, name_field='imode', default=1,
@@ -328,181 +328,181 @@ BORDER_MODES = [Doc('REPLICATE', 'aaaaaa|abcdefgh|hhhhhhh'),
add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f')) add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f'))




pdef('SpatialTfGridGenerator').add_enum('Mode', 'AFFINE')
pdef('SpatialTfSampler').add_enum('Mode', 'BILINEAR')
pdef('SpatialTfGridGenerator').add_enum('Mode', 'AFFINE = 0')
pdef('SpatialTfSampler').add_enum('Mode', 'BILINEAR = 0')


pdef('AddUpdate').add_fields( pdef('AddUpdate').add_fields(
'float32', 'alpha', '1.f', 'beta', '1.f', 'bias', '0.f') 'float32', 'alpha', '1.f', 'beta', '1.f', 'bias', '0.f')


pdef('Elemwise').add_enum( pdef('Elemwise').add_enum(
'Mode', 'Mode',
Doc('RELU', 'unary: max(x, 0)'),
Doc('ABS', 'unary: abs(x)'),
Doc('ACOS', 'unary: acos(x)'),
Doc('ASIN', 'unary: asin(x)'),
Doc('CEIL', 'unary: ceil(x)'),
Doc('COS', 'unary: cos(x)'),
Doc('EXP', 'unary: exp(x)'),
Doc('EXPM1', 'unary: numerically stable exp(x)-1'),
Doc('FLOOR', 'unary: floor(x)'),
Doc('LOG', 'unary: natural logarithm, log(x)'),
Doc('LOG1P', 'unary: numerically stable log(x+1)'),
Doc('NEGATE', 'unary: -x'),
Doc('SIGMOID', 'unary: 1/(1+exp(-x))'),
Doc('SIN', 'unary: sin(x)'),
Doc('TANH', 'unary: tanh(x)'),
Doc('ABS_GRAD', 'binary: x > 0 ? y : -y'),
Doc('ADD', 'binary: x + y'),
Doc('FLOOR_DIV', 'binary: floor(x / y)'),
Doc('MAX', 'binary: max(x, y)'),
Doc('MIN', 'binary: min(x, y)'),
Doc('MOD', 'binary: x % y or fmodf(x, y)'),
Doc('MUL', 'binary: x * y'),
Doc('POW', 'binary: pow(x, y)'),
Doc('SIGMOID_GRAD', 'binary: x * (1 - x) * y'),
Doc('SUB', 'binary: x - y'),
Doc('SWITCH_GT0', 'binary: (x > 0) * y'),
Doc('TANH_GRAD', 'binary: (1 - x * x) * y'),
Doc('TRUE_DIV', 'binary: x / y'),
Doc('LOG_SUM_EXP', 'binary: numerically stable log(exp(x) + exp(y))'),
Doc('LT', 'binary: x < y'),
Doc('LEQ', 'binary: x <= y'),
Doc('EQ', 'binary: x == y'),
Doc('SHL', 'bitwise binary: x << y. '
Doc('RELU = 0', 'unary: max(x, 0)'),
Doc('ABS = 1', 'unary: abs(x)'),
Doc('ACOS = 2', 'unary: acos(x)'),
Doc('ASIN = 3', 'unary: asin(x)'),
Doc('CEIL = 4', 'unary: ceil(x)'),
Doc('COS = 5', 'unary: cos(x)'),
Doc('EXP = 6', 'unary: exp(x)'),
Doc('EXPM1 = 7', 'unary: numerically stable exp(x)-1'),
Doc('FLOOR = 8', 'unary: floor(x)'),
Doc('LOG = 9', 'unary: natural logarithm, log(x)'),
Doc('LOG1P = 10', 'unary: numerically stable log(x+1)'),
Doc('NEGATE = 11', 'unary: -x'),
Doc('SIGMOID = 12', 'unary: 1/(1+exp(-x))'),
Doc('SIN = 13', 'unary: sin(x)'),
Doc('TANH = 14', 'unary: tanh(x)'),
Doc('ABS_GRAD = 15', 'binary: x > 0 ? y : -y'),
Doc('ADD = 16', 'binary: x + y'),
Doc('FLOOR_DIV = 17', 'binary: floor(x / y)'),
Doc('MAX = 18', 'binary: max(x, y)'),
Doc('MIN = 19', 'binary: min(x, y)'),
Doc('MOD = 20', 'binary: x % y or fmodf(x, y)'),
Doc('MUL = 21', 'binary: x * y'),
Doc('POW = 22', 'binary: pow(x, y)'),
Doc('SIGMOID_GRAD = 23', 'binary: x * (1 - x) * y'),
Doc('SUB = 24', 'binary: x - y'),
Doc('SWITCH_GT0 = 25', 'binary: (x > 0) * y'),
Doc('TANH_GRAD = 26', 'binary: (1 - x * x) * y'),
Doc('TRUE_DIV = 27', 'binary: x / y'),
Doc('LOG_SUM_EXP = 28', 'binary: numerically stable log(exp(x) + exp(y))'),
Doc('LT = 29', 'binary: x < y'),
Doc('LEQ = 30', 'binary: x <= y'),
Doc('EQ = 31', 'binary: x == y'),
Doc('SHL = 32', 'bitwise binary: x << y. '
'Note that result is undefined if y < 0 or y >= bitwidth. Logical ' 'Note that result is undefined if y < 0 or y >= bitwidth. Logical '
'shift is performed for unsigned intergers, and arithmetic shift for ' 'shift is performed for unsigned intergers, and arithmetic shift for '
'signed ones.'), 'signed ones.'),
Doc('SHR', 'bitwise binary: x >> y; see SHL mode for more details'),
Doc('SHR = 33', 'bitwise binary: x >> y; see SHL mode for more details'),


Doc('COND_LEQ_MOV', 'ternary: x <= y ? z : 0'),
Doc('FUSE_MUL_ADD3',
Doc('COND_LEQ_MOV = 34', 'ternary: x <= y ? z : 0'),
Doc('FUSE_MUL_ADD3 = 35',
'compute ``a * b + c`` where c must either have same layout as ' 'compute ``a * b + c`` where c must either have same layout as '
'a or b, or be a scalar'), 'a or b, or be a scalar'),


Doc('FUSE_MUL_ADD4',
Doc('FUSE_MUL_ADD4 = 36',
'compute ``a * A + b * B`` where a and b must have equal layout, ' 'compute ``a * A + b * B`` where a and b must have equal layout, '
'and A and B must have equal layout. In the inputs ``b`` and ``B`` ' 'and A and B must have equal layout. In the inputs ``b`` and ``B`` '
'can be swapped'), 'can be swapped'),


Doc('FUSE_ADD_RELU', 'binary: max(x+y, 0)'),
Doc('FUSE_ADD_SIGMOID', 'binary: 1/(1+exp(-(x+y)))'),
Doc('FUSE_ADD_TANH', 'binary: tanh(x+y)'),
Doc('FAST_TANH', 'unary: rational approximation of tanh(x)'),
Doc('FAST_TANH_GRAD', 'binary: grad of the rational approximation of tanh(x)'),
Doc('FUSE_ADD_RELU = 37', 'binary: max(x+y, 0)'),
Doc('FUSE_ADD_SIGMOID = 38', 'binary: 1/(1+exp(-(x+y)))'),
Doc('FUSE_ADD_TANH = 39', 'binary: tanh(x+y)'),
Doc('FAST_TANH = 40', 'unary: rational approximation of tanh(x)'),
Doc('FAST_TANH_GRAD = 41', 'binary: grad of the rational approximation of tanh(x)'),


Doc('ROUND', 'unary: round(x), the nearest integer value to x, rounding '
Doc('ROUND = 42', 'unary: round(x), the nearest integer value to x, rounding '
'halfway cases away from zero. Float only.'), 'halfway cases away from zero. Float only.'),
Doc('RMULH', 'binary: rounded higher l bits of x * y, where l is the bit '
Doc('RMULH = 43', 'binary: rounded higher l bits of x * y, where l is the bit '
'length of x.'), 'length of x.'),


Doc('ATAN2','binary: atan2(y,x)'),
Doc('ERF', 'unary: erf(x)'),
Doc('ERFINV', 'unary: inverse function of erf(x)'),
Doc('ERFC', 'unary: erfc(x)'),
Doc('ERFCINV', 'unary: inverse function of erfc(x)'),
Doc('H_SWISH', 'unary: x * clip(x + 3, 0, 6) / 6'),
Doc('H_SWISH_GRAD', 'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'),
Doc('FUSE_ADD_H_SWISH', 'binary: hswish(x+y)'),
Doc('NOT', 'unary: !x'),
Doc('AND', 'binary: x && y'),
Doc('OR', 'binary: x || y'),
Doc('XOR', 'binary: x ^ y'),
Doc('SILU', 'unary: x / (1 + exp(-x))'),
Doc('SILU_GRAD', 'binary: grad(x / (1 + exp(-x))'),
Doc('GELU', 'unary: x Phi(x)'),
Doc('GELU_GRAD', 'binary: grad(x Phi(x))'),
Doc('ATAN2 = 44','binary: atan2(y,x)'),
Doc('ERF = 45', 'unary: erf(x)'),
Doc('ERFINV = 46', 'unary: inverse function of erf(x)'),
Doc('ERFC = 47', 'unary: erfc(x)'),
Doc('ERFCINV = 48', 'unary: inverse function of erfc(x)'),
Doc('H_SWISH = 49', 'unary: x * clip(x + 3, 0, 6) / 6'),
Doc('H_SWISH_GRAD = 50', 'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'),
Doc('FUSE_ADD_H_SWISH = 51', 'binary: hswish(x+y)'),
Doc('NOT = 52', 'unary: !x'),
Doc('AND = 53', 'binary: x && y'),
Doc('OR = 54', 'binary: x || y'),
Doc('XOR = 55', 'binary: x ^ y'),
Doc('SILU = 56', 'unary: x / (1 + exp(-x))'),
Doc('SILU_GRAD = 57', 'binary: grad(x / (1 + exp(-x))'),
Doc('GELU = 58', 'unary: x Phi(x)'),
Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'),
) )


pdef('ElemwiseMultiType').add_enum( pdef('ElemwiseMultiType').add_enum(
'Mode', 'Mode',
Doc('FUSE_MUL_ADD3_INT16x32x32x32',
Doc('FUSE_MUL_ADD3_INT16x32x32x32 = 0',
'compute ``a * b + c`` requiring that ``a`` be int16 and ``b`` and ' 'compute ``a * b + c`` requiring that ``a`` be int16 and ``b`` and '
'``c`` int32, and the result is int32. This mode is optimized for ' '``c`` int32, and the result is int32. This mode is optimized for '
'the channel-broadacsted case, i.e. ``a`` has shape (A, B, C) and ' 'the channel-broadacsted case, i.e. ``a`` has shape (A, B, C) and '
'``b`` and ``c`` have shape (1, C, 1)'), '``b`` and ``c`` have shape (1, C, 1)'),
Doc('FUSE_MUL_ADD3_IXxF32xF32xI8',
Doc('FUSE_MUL_ADD3_IXxF32xF32xI8 = 1',
'compuate ``a * b + c`` where the inputs ``a`` is an integer type ' 'compuate ``a * b + c`` where the inputs ``a`` is an integer type '
'``b`` and ``c`` are both ``float32``, the result is ' '``b`` and ``c`` are both ``float32``, the result is '
'``int8``. This is currently only optimized for ``(1, x)`` ' '``int8``. This is currently only optimized for ``(1, x)`` '
'broadcast for ``b`` and ``c``. Computation is carried in floating ' 'broadcast for ``b`` and ``c``. Computation is carried in floating '
'points and results are rounded towards zero with saturated cast to ' 'points and results are rounded towards zero with saturated cast to '
'int.'), 'int.'),
Doc('ROUND_SHR_SATURATE_IXxI8xI8',
Doc('ROUND_SHR_SATURATE_IXxI8xI8 = 2',
'Compute ``a >> b``, round the result according to lower ``b`` bits ' 'Compute ``a >> b``, round the result according to lower ``b`` bits '
'of ``a``` and make a saturating conversion to int8. Where ``a`` should' 'of ``a``` and make a saturating conversion to int8. Where ``a`` should'
' be an integer tensor and ``b`` should be an int8 scalar.'), ' be an integer tensor and ``b`` should be an int8 scalar.'),
Doc('FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8',
Doc('FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8 = 3',
'Fused operation of an int16 elemwise add, an int16 rounding multiply ' 'Fused operation of an int16 elemwise add, an int16 rounding multiply '
'high and an int16 to int8 rounding right shift with saturation.'), 'high and an int16 to int8 rounding right shift with saturation.'),
Doc('FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8',
Doc('FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8 = 4',
'Fused operation of an int32 elemwise add, an int32 rounding multiply ' 'Fused operation of an int32 elemwise add, an int32 rounding multiply '
'high and an int32 to int8 rounding right shift with saturation.'), 'high and an int32 to int8 rounding right shift with saturation.'),
Doc('ROUND_SHR_SATURATE_IXxI8xI16',
Doc('ROUND_SHR_SATURATE_IXxI8xI16 = 5',
'Compute ``a >> b``, round the result according to lower ``b`` bits of ' 'Compute ``a >> b``, round the result according to lower ``b`` bits of '
'``a``` and make a saturating conversion to int16. Where ``a`` should' '``a``` and make a saturating conversion to int16. Where ``a`` should'
' be an integer tensor and ``b`` should be an int8 scalar.'), ' be an integer tensor and ``b`` should be an int8 scalar.'),
Doc('QADD', 'Fused elemwise add two quantized int8 with specified'
Doc('QADD = 6', 'Fused elemwise add two quantized int8 with specified'
'output quantized dtype'), 'output quantized dtype'),
Doc('QFUSE_ADD_RELU', 'Fused elemwise add two quantized int8 followed'
Doc('QFUSE_ADD_RELU = 7', 'Fused elemwise add two quantized int8 followed'
' by ReLU and typecvt to specified dtype'), ' by ReLU and typecvt to specified dtype'),
Doc('QMUL', 'Fused elemwise multiply two quantized int8 with specified'
Doc('QMUL = 8', 'Fused elemwise multiply two quantized int8 with specified'
'output quantized dtype'), 'output quantized dtype'),
Doc('QMIN', 'Fused elemwise min two quantized int8 with specified'
Doc('QMIN = 9', 'Fused elemwise min two quantized int8 with specified'
'output quantized dtype'), 'output quantized dtype'),
Doc('QMAX', 'quantized: max(x, y), with specified output quantized dtype'),
Doc('QSUB', 'quantized: x - y'),
Doc('QTRUE_DIV', 'quantized: x / y'),
Doc('QFUSE_ADD_SIGMOID', 'quantized: sigmoid(x + y)'),
Doc('QFUSE_ADD_TANH', 'quantized: tanh(x + y)'),
Doc('QRELU', 'quantized: x > 0 ? x : 0'),
Doc('QABS', 'quantized: x > 0 ? x : -x'),
Doc('QSIGMOID', 'quantized: sigmoid(x)'),
Doc('QEXP', 'quantized: exp(x)'),
Doc('QTANH', 'quantized: tanh(x)'),
Doc('QFUSE_MUL_ADD3', 'quantized: x * y + z'),
Doc('QFAST_TANH', 'quantized: fast_tanh(x)'),
Doc('QNEGATE', 'quantized: -x'),
Doc('QACOS', 'quantized: acos(x)'),
Doc('QASIN', 'quantized: asin(x)'),
Doc('QCEIL', 'quantized: ceil(x)'),
Doc('QCOS', 'quantized: cos(x)'),
Doc('QEXPM1', 'quantized: expm1(x)'),
Doc('QFLOOR', 'quantized: floor(x)'),
Doc('QLOG', 'quantized: log(x)'),
Doc('QLOG1P', 'quantized: log1p(x)'),
Doc('QSIN', 'quantized: sin(x)'),
Doc('QROUND', 'quantized: round(x)'),
Doc('QERF', 'quantized: erf(x)'),
Doc('QERFINV', 'quantized: erfinv(x)'),
Doc('QERFC', 'quantized: erfc(x)'),
Doc('QERFCINV', 'quantized: erfcinv(x)'),
Doc('QABS_GRAD', 'quantized: abs_grad'),
Doc('QFLOOR_DIV', 'quantized floor_div'),
Doc('QMOD', 'quantized mod'),
Doc('QSIGMOID_GRAD', 'quantized sigmoid_grad'),
Doc('QSWITCH_GT0', 'quantized switch_gt0'),
Doc('QTANH_GRAD', 'quantized tanh_grad'),
Doc('QLT', 'quantized lt'),
Doc('QLEQ', 'quantized leq'),
Doc('QEQ', 'quantized eq'),
Doc('QPOW', 'quantized pow'),
Doc('QLOG_SUM_EXP', 'quantized log_sum_exp'),
Doc('QFAST_TANH_GRAD', 'quantized fast_tanh_grad'),
Doc('QATAN2', 'quantized atan2'),
Doc('QCOND_LEQ_MOV', 'quantized cond_leq_mov'),
Doc('QH_SWISH', 'quantized h_swish'),
Doc('QFUSE_ADD_H_SWISH', 'quantized h_swish(x+y)'),
Doc('QH_SWISH_GRAD', 'quantized h_swish_grad')
Doc('QMAX = 10', 'quantized: max(x, y), with specified output quantized dtype'),
Doc('QSUB = 11', 'quantized: x - y'),
Doc('QTRUE_DIV = 12', 'quantized: x / y'),
Doc('QFUSE_ADD_SIGMOID = 13', 'quantized: sigmoid(x + y)'),
Doc('QFUSE_ADD_TANH = 14', 'quantized: tanh(x + y)'),
Doc('QRELU = 15', 'quantized: x > 0 ? x : 0'),
Doc('QABS = 16', 'quantized: x > 0 ? x : -x'),
Doc('QSIGMOID = 17', 'quantized: sigmoid(x)'),
Doc('QEXP = 18', 'quantized: exp(x)'),
Doc('QTANH = 19', 'quantized: tanh(x)'),
Doc('QFUSE_MUL_ADD3 = 20', 'quantized: x * y + z'),
Doc('QFAST_TANH = 21', 'quantized: fast_tanh(x)'),
Doc('QNEGATE = 22', 'quantized: -x'),
Doc('QACOS = 23', 'quantized: acos(x)'),
Doc('QASIN = 24', 'quantized: asin(x)'),
Doc('QCEIL = 25', 'quantized: ceil(x)'),
Doc('QCOS = 26', 'quantized: cos(x)'),
Doc('QEXPM1 = 27', 'quantized: expm1(x)'),
Doc('QFLOOR = 28', 'quantized: floor(x)'),
Doc('QLOG = 29', 'quantized: log(x)'),
Doc('QLOG1P = 30', 'quantized: log1p(x)'),
Doc('QSIN = 31', 'quantized: sin(x)'),
Doc('QROUND = 32', 'quantized: round(x)'),
Doc('QERF = 33', 'quantized: erf(x)'),
Doc('QERFINV = 34', 'quantized: erfinv(x)'),
Doc('QERFC = 35', 'quantized: erfc(x)'),
Doc('QERFCINV = 36', 'quantized: erfcinv(x)'),
Doc('QABS_GRAD = 37', 'quantized: abs_grad'),
Doc('QFLOOR_DIV = 38', 'quantized floor_div'),
Doc('QMOD = 39', 'quantized mod'),
Doc('QSIGMOID_GRAD = 40', 'quantized sigmoid_grad'),
Doc('QSWITCH_GT0 = 41', 'quantized switch_gt0'),
Doc('QTANH_GRAD = 42', 'quantized tanh_grad'),
Doc('QLT = 43', 'quantized lt'),
Doc('QLEQ = 44', 'quantized leq'),
Doc('QEQ = 45', 'quantized eq'),
Doc('QPOW = 46', 'quantized pow'),
Doc('QLOG_SUM_EXP = 47', 'quantized log_sum_exp'),
Doc('QFAST_TANH_GRAD = 48', 'quantized fast_tanh_grad'),
Doc('QATAN2 = 49', 'quantized atan2'),
Doc('QCOND_LEQ_MOV = 50', 'quantized cond_leq_mov'),
Doc('QH_SWISH = 51', 'quantized h_swish'),
Doc('QFUSE_ADD_H_SWISH = 52', 'quantized h_swish(x+y)'),
Doc('QH_SWISH_GRAD = 53', 'quantized h_swish_grad')
) )


pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)


(pdef('DctChannelSelect', '2d discrete cosine transform', version=0, is_legacy=True).add_enum_alias('Format', 'ConvolutionV0'). (pdef('DctChannelSelect', '2d discrete cosine transform', version=0, is_legacy=True).add_enum_alias('Format', 'ConvolutionV0').
add_enum('FastImpl', 'NONE', 'FIX_32_MASK').add_fields('int32', 'dct_block_size', 8))
add_enum('FastImpl', 'NONE = 0', 'FIX_32_MASK = 1').add_fields('int32', 'dct_block_size', 8))


(pdef('DctChannelSelect', '2d discrete cosine transform', version=1).add_enum_alias('Format', 'Convolution'). (pdef('DctChannelSelect', '2d discrete cosine transform', version=1).add_enum_alias('Format', 'Convolution').
add_enum_alias('FastImpl', 'DctChannelSelectV0').add_fields('int32', 'dct_block_size', 8)) add_enum_alias('FastImpl', 'DctChannelSelectV0').add_fields('int32', 'dct_block_size', 8))
@@ -510,13 +510,13 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
(pdef('MatrixMul', version=0, is_legacy=True). (pdef('MatrixMul', version=0, is_legacy=True).
add_fields('bool', 'transposeA', 'false', 'transposeB', 'false'). add_fields('bool', 'transposeA', 'false', 'transposeB', 'false').
add_enum('DataType', add_enum('DataType',
Doc('FLOAT', 'input/output both float32/float16'),
'INT8x8x16',
'INT8x8x32',
Doc('FLOAT_IO16xC32', 'input/output both float16, the internal compute is '
Doc('FLOAT = 0', 'input/output both float32/float16'),
'INT8x8x16 = 1',
'INT8x8x32 = 2',
Doc('FLOAT_IO16xC32 = 3', 'input/output both float16, the internal compute is '
'float32'), 'float32'),
Doc('QUINT8x8x32', 'input QuantizedAsymm8, output QuantizedS32'),
Doc('QUINT4x4x32', 'input QuantizedAsymm4, output QuantizedS32'),
Doc('QUINT8x8x32 = 4', 'input QuantizedAsymm8, output QuantizedS32'),
Doc('QUINT4x4x32 = 5', 'input QuantizedAsymm4, output QuantizedS32'),
name_field='data_type')) name_field='data_type'))


(pdef('MatrixMul', version=1, is_legacy=True). (pdef('MatrixMul', version=1, is_legacy=True).
@@ -524,9 +524,9 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
add_enum(Doc('ComputeMode', 'Specifies special computation modes, e.g. ' add_enum(Doc('ComputeMode', 'Specifies special computation modes, e.g. '
'different combinations of intermediate result ' 'different combinations of intermediate result '
'data types.'), 'data types.'),
Doc('DEFAULT', 'No special requirements on the precision of '
Doc('DEFAULT = 0', 'No special requirements on the precision of '
'intermediate results.'), 'intermediate results.'),
Doc('FLOAT32', 'Use Float32 accumulator and intermediate result. '
Doc('FLOAT32 = 1', 'Use Float32 accumulator and intermediate result. '
'Only supported when input and output is Float16.'), 'Only supported when input and output is Float16.'),
name_field='compute_mode')) name_field='compute_mode'))


@@ -534,14 +534,14 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
add_fields('bool', 'transposeA', 'false', 'transposeB', 'false'). add_fields('bool', 'transposeA', 'false', 'transposeB', 'false').
add_enum_alias('ComputeMode', 'MatrixMulV1', name_field='compute_mode'). add_enum_alias('ComputeMode', 'MatrixMulV1', name_field='compute_mode').
add_enum('Format', add_enum('Format',
Doc('DEFAULT', 'Normal matrix mul: (M, K) x (K, N) = (M, N)'),
Doc('MK4', 'Split 4 from M and K, better for neon compute:'
Doc('DEFAULT = 0', 'Normal matrix mul: (M, K) x (K, N) = (M, N)'),
Doc('MK4 = 1', 'Split 4 from M and K, better for neon compute:'
'(M/4, K/4, 4(k), 4(m)) x (K/4, N, 4(k)). if transposeA the ' '(M/4, K/4, 4(k), 4(m)) x (K/4, N, 4(k)). if transposeA the '
'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'),
Doc('MK8', 'Split 8 from M and K, better for neon compute:'
Doc('MK8 = 2', 'Split 8 from M and K, better for neon compute:'
'(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the '
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'),
Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:'
Doc('MK4_DOT = 3', 'Split 4 from M and K, better for neon dotprod:'
'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' 'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the '
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'))
) )
@@ -560,9 +560,9 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)


(pdef('Reduce', 'legacy reduce', version=0, is_legacy=True). (pdef('Reduce', 'legacy reduce', version=0, is_legacy=True).
add_enum('Mode', add_enum('Mode',
'SUM',
Doc('SUM_SQR', 'sum of x * x for each element x'),
'PRODUCT', 'MIN', 'MAX').
'SUM = 0',
Doc('SUM_SQR = 1', 'sum of x * x for each element x'),
'PRODUCT = 2', 'MIN = 3', 'MAX = 4').
add_fields('int32', add_fields('int32',
Doc('axis', Doc('axis',
'axis along which reduction is performed; if -1 is given, ' 'axis along which reduction is performed; if -1 is given, '
@@ -571,16 +571,16 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)


(pdef('Reduce', 'reduce along given axis', version=1, is_legacy=True). (pdef('Reduce', 'reduce along given axis', version=1, is_legacy=True).
add_enum('Mode', add_enum('Mode',
'SUM',
Doc('SUM_SQR', 'sum of x * x for each element x'),
'PRODUCT', 'MIN', 'MAX', 'MEAN').
'SUM = 0',
Doc('SUM_SQR = 1', 'sum of x * x for each element x'),
'PRODUCT = 2', 'MIN = 3', 'MAX = 4', 'MEAN = 5').
add_fields('int32', add_fields('int32',
Doc('axis', Doc('axis',
'axis along which reduction is performed; if -1 is given, ' 'axis along which reduction is performed; if -1 is given, '
'reduce to given target shape (only used in megbrain)'), 'reduce to given target shape (only used in megbrain)'),
-1). -1).
add_enum('DataType', add_enum('DataType',
Doc('DEFAULT',
Doc('DEFAULT = 0',
''' '''
input/output are the same data type, and the internal computation type would be chosen by the input/output dtypes and the reduction mode. input/output are the same data type, and the internal computation type would be chosen by the input/output dtypes and the reduction mode.
Currently, ```DEFAULT``` mode means: Currently, ```DEFAULT``` mode means:
@@ -607,26 +607,26 @@ Currently, ```DEFAULT``` mode means:


''' '''
), ),
Doc('FLOAT_IO16xC32', 'Deprecated. This was replaced by '
Doc('FLOAT_IO16xC32 = 1', 'Deprecated. This was replaced by '
'FLOAT_O16xC32, and input\'s dtype decided by actual input tensor.'), 'FLOAT_O16xC32, and input\'s dtype decided by actual input tensor.'),
Doc('FLOAT_O32xC32', 'compute/output both are float32'),
Doc('FLOAT_O16xC32', 'compute are float32, output float16'),
Doc('QUINT_I8xO32', 'input quint8, compute and output are qint32'),
Doc('QINT_I8xO32', 'input qint8, compute and output are qint32'),
Doc('FLOAT_O32xC32 = 2', 'compute/output both are float32'),
Doc('FLOAT_O16xC32 = 3', 'compute are float32, output float16'),
Doc('QUINT_I8xO32 = 4', 'input quint8, compute and output are qint32'),
Doc('QINT_I8xO32 = 5', 'input qint8, compute and output are qint32'),
name_field='data_type')) name_field='data_type'))


(pdef('Reduce', 'reduce along given axis', version=2). (pdef('Reduce', 'reduce along given axis', version=2).
add_enum('Mode', add_enum('Mode',
'SUM',
Doc('SUM_SQR', 'sum of x * x for each element x'),
'PRODUCT', 'MIN', 'MAX', 'MEAN').
'SUM = 0',
Doc('SUM_SQR = 1', 'sum of x * x for each element x'),
'PRODUCT = 2', 'MIN = 3', 'MAX = 4', 'MEAN = 5').
add_fields('int32', add_fields('int32',
Doc('axis', Doc('axis',
'axis along which reduction is performed; if INT_MAX is given, ' 'axis along which reduction is performed; if INT_MAX is given, '
'reduce to given target shape (only used in megbrain)'), 'reduce to given target shape (only used in megbrain)'),
(1<<31)-1). (1<<31)-1).
add_enum('DataType', add_enum('DataType',
Doc('DEFAULT',
Doc('DEFAULT = 0',
''' '''
input/output are the same data type, and the internal computation type would be chosen by the input/output dtypes and the reduction mode. input/output are the same data type, and the internal computation type would be chosen by the input/output dtypes and the reduction mode.
Currently, ```DEFAULT``` mode means: Currently, ```DEFAULT``` mode means:
@@ -653,12 +653,12 @@ Currently, ```DEFAULT``` mode means:


''' '''
), ),
Doc('FLOAT_IO16xC32', 'Deprecated. This was replaced by '
Doc('FLOAT_IO16xC32 = 1', 'Deprecated. This was replaced by '
'FLOAT_O16xC32, and input\'s dtype decided by actual input tensor.'), 'FLOAT_O16xC32, and input\'s dtype decided by actual input tensor.'),
Doc('FLOAT_O32xC32', 'compute/output both are float32'),
Doc('FLOAT_O16xC32', 'compute are float32, output float16'),
Doc('QUINT_I8xO32', 'input quint8, compute and output are qint32'),
Doc('QINT_I8xO32', 'input qint8, compute and output are qint32'),
Doc('FLOAT_O32xC32 = 2', 'compute/output both are float32'),
Doc('FLOAT_O16xC32 = 3', 'compute are float32, output float16'),
Doc('QUINT_I8xO32 = 4', 'input quint8, compute and output are qint32'),
Doc('QINT_I8xO32 = 5', 'input qint8, compute and output are qint32'),
name_field='data_type')) name_field='data_type'))


(pdef('Cumsum', 'calculate accumulated sum along given axis', version=0, is_legacy=True). (pdef('Cumsum', 'calculate accumulated sum along given axis', version=0, is_legacy=True).
@@ -691,12 +691,12 @@ Currently, ```DEFAULT``` mode means:


(pdef('CondTake'). (pdef('CondTake').
add_enum('Mode', add_enum('Mode',
Doc('EQ', 'take if ``abs(data-val)<eps``'),
Doc('NEQ', 'take if ``abs(data-val)>=eps``'),
Doc('LT', 'take if ``data<val``'),
Doc('LEQ', 'take if ``data<=val``'),
Doc('GT', 'take if ``data>val``'),
Doc('GEQ', 'take if ``data>=val``')).
Doc('EQ = 0', 'take if ``abs(data-val)<eps``'),
Doc('NEQ = 1', 'take if ``abs(data-val)>=eps``'),
Doc('LT = 2', 'take if ``data<val``'),
Doc('LEQ = 3', 'take if ``data<=val``'),
Doc('GT = 4', 'take if ``data>val``'),
Doc('GEQ = 5', 'take if ``data>=val``')).
add_fields('float32', add_fields('float32',
Doc('val', 'the value to be compared with; note that for integer ' Doc('val', 'the value to be compared with; note that for integer '
'data, val is also converted to int'), 0). 'data, val is also converted to int'), 0).
@@ -704,7 +704,7 @@ Currently, ```DEFAULT``` mode means:
1e-6)) 1e-6))




pdef('Argsort').add_enum('Order', 'ASCENDING', 'DESCENDING')
pdef('Argsort').add_enum('Order', 'ASCENDING = 0', 'DESCENDING = 1')


(pdef('IndexingRemap'). (pdef('IndexingRemap').
add_fields('bool', add_fields('bool',
@@ -791,17 +791,17 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
.add_fields('uint32', 'row_from', 0, 'row_to', 0, 'col_from', 0, 'col_to', 0)) .add_fields('uint32', 'row_from', 0, 'row_to', 0, 'col_from', 0, 'col_to', 0))


(pdef('CvtColor') (pdef('CvtColor')
.add_enum('Mode', 'RGB2GRAY', 'RGB2YUV', 'YUV2RGB', 'GRAY2RGB', 'RGBA2RGB',
'RGBA2BGR', 'RGBA2GRAY', 'RGB2BGR', 'BGR2GRAY', 'BGR2RGB',
Doc('YUV2GRAY_NV21', 'For historical reasons, referred to as YCC by opencv'),
'YUV2RGB_NV21', 'YUV2BGR_NV21', 'YUV2GRAY_NV12', 'YUV2RGB_NV12',
'YUV2BGR_NV12', 'YUV2GRAY_YV12', 'YUV2RGB_YV12', 'YUV2BGR_YV12',
'YUV2GRAY_YU12', 'YUV2RGB_YU12', 'YUV2BGR_YU12',
'YCrCb2RGB', 'YCrCb2BGR',
Doc('BT601_YUV2RGB_NV21', 'BT601 yuv format, referred to as YUV by opencv'),
'BT601_YUV2BGR_NV21', 'BT601_YUV2RGB_NV12', 'BT601_YUV2BGR_NV12',
'BT601_YUV2RGB_YV12', 'BT601_YUV2BGR_YV12', 'BT601_YUV2RGB_YU12',
'BT601_YUV2BGR_YU12',
.add_enum('Mode', 'RGB2GRAY = 0', 'RGB2YUV = 1', 'YUV2RGB = 2', 'GRAY2RGB = 3', 'RGBA2RGB = 4',
'RGBA2BGR = 5', 'RGBA2GRAY = 6', 'RGB2BGR = 7', 'BGR2GRAY = 8', 'BGR2RGB = 9',
Doc('YUV2GRAY_NV21 = 10', 'For historical reasons, referred to as YCC by opencv'),
'YUV2RGB_NV21 = 11', 'YUV2BGR_NV21 = 12', 'YUV2GRAY_NV12 = 13', 'YUV2RGB_NV12 = 14',
'YUV2BGR_NV12 = 15', 'YUV2GRAY_YV12 = 16', 'YUV2RGB_YV12 = 17', 'YUV2BGR_YV12 = 18',
'YUV2GRAY_YU12 = 19', 'YUV2RGB_YU12 = 20', 'YUV2BGR_YU12 = 21',
'YCrCb2RGB = 22', 'YCrCb2BGR = 23',
Doc('BT601_YUV2RGB_NV21 = 24', 'BT601 yuv format, referred to as YUV by opencv'),
'BT601_YUV2BGR_NV21 = 25', 'BT601_YUV2RGB_NV12 = 26', 'BT601_YUV2BGR_NV12 = 27',
'BT601_YUV2RGB_YV12 = 28', 'BT601_YUV2BGR_YV12 = 29', 'BT601_YUV2RGB_YU12 = 30',
'BT601_YUV2BGR_YU12 = 31',
member_alias=[('YUV2GRAY_NV21', 'BT601_YUV2GRAY_NV21'), member_alias=[('YUV2GRAY_NV21', 'BT601_YUV2GRAY_NV21'),
('YUV2GRAY_NV12', 'BT601_YUV2GRAY_NV12'), ('YUV2GRAY_NV12', 'BT601_YUV2GRAY_NV12'),
('YUV2GRAY_YV12', 'BT601_YUV2GRAY_YV12'), ('YUV2GRAY_YV12', 'BT601_YUV2GRAY_YV12'),
@@ -855,7 +855,7 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
.add_fields('float32', 'scalar', '0.f')) .add_fields('float32', 'scalar', '0.f'))


(pdef('Convolution3D'). (pdef('Convolution3D').
add_enum('Mode', 'CROSS_CORRELATION', 'CONVOLUTION').
add_enum('Mode', 'CROSS_CORRELATION = 0', 'CONVOLUTION = 1').
add_fields( add_fields(
'uint32', 'uint32',
Doc('pad_d', 'padding on one side on the first dimension'), 0, Doc('pad_d', 'padding on one side on the first dimension'), 0,
@@ -872,32 +872,32 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
'on the third dimension'), 1 'on the third dimension'), 1
). ).
add_enum('Sparse', add_enum('Sparse',
Doc('DENSE', 'dense convolution: filter shape should be '
Doc('DENSE = 0', 'dense convolution: filter shape should be '
'[oc, ic, spatial...] if format is NCDHW, ' '[oc, ic, spatial...] if format is NCDHW, '
'[oc, spatial..., ic] if format is NDHWC'), '[oc, spatial..., ic] if format is NDHWC'),
Doc('GROUP', 'group convolution: filter shape should be '
Doc('GROUP = 1', 'group convolution: filter shape should be '
'[group, oc_per_group, ic_per_group, spatial...] if format is NCDHW, ' '[group, oc_per_group, ic_per_group, spatial...] if format is NCDHW, '
'[group, oc_per_group, spatial..., ic_per_group] if format is NDHWC') '[group, oc_per_group, spatial..., ic_per_group] if format is NDHWC')
). ).
add_enum('DataType', add_enum('DataType',
Doc('FLOAT', 'input/output both float32/float16'),
Doc('FLOAT_IO16xC32', 'input/output both float16, the internal '
Doc('FLOAT = 0', 'input/output both float32/float16'),
Doc('FLOAT_IO16xC32 = 1', 'input/output both float16, the internal '
'compute is float32'), 'compute is float32'),
name_field='data_type'). name_field='data_type').
add_enum('Format', 'NCDHW', 'NDHWC')
add_enum('Format', 'NCDHW = 0', 'NDHWC = 1')
) )


(pdef('Conv3DBias'). (pdef('Conv3DBias').
add_enum('NonlineMode', 'IDENTITY', 'RELU', 'SIGMOID').
add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'SIGMOID = 2').
add_enum_alias('Mode', 'Convolution3D'). add_enum_alias('Mode', 'Convolution3D').
add_fields('uint32', 'pad_d', 0, 'pad_h', 0, 'pad_w', 0, add_fields('uint32', 'pad_d', 0, 'pad_h', 0, 'pad_w', 0,
'stride_d', 1, 'stride_h', 1, 'stride_w', 0)) 'stride_d', 1, 'stride_h', 1, 'stride_w', 0))


(pdef('SeparableConv3D'). (pdef('SeparableConv3D').
add_enum_alias('Mode', 'Convolution3D'). add_enum_alias('Mode', 'Convolution3D').
add_enum('BorderMode', 'BORDER_REPLICATE', 'BORDER_REFLECT',
'BORDER_REFLECT_101','BORDER_WRAP',
'BORDER_CONSTANT', 'BORDER_TRANSPARENT','BORDER_ISOLATED').
add_enum('BorderMode', 'BORDER_REPLICATE = 0', 'BORDER_REFLECT = 1',
'BORDER_REFLECT_101 = 2','BORDER_WRAP = 3',
'BORDER_CONSTANT = 4', 'BORDER_TRANSPARENT = 5','BORDER_ISOLATED = 6').
add_fields('bool', 'is_symm_kernel', 'true'). add_fields('bool', 'is_symm_kernel', 'true').
add_fields('uint32', 'pad_d', 0, 'pad_h', 0, 'pad_w', 0, add_fields('uint32', 'pad_d', 0, 'pad_h', 0, 'pad_w', 0,
'stride_d', 0, 'stride_h', 1, 'stride_w', 1, 'stride_d', 0, 'stride_h', 1, 'stride_w', 1,
@@ -907,11 +907,11 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
(pdef('TopK'). (pdef('TopK').
add_enum( add_enum(
'Mode', 'Mode',
Doc('KTH_ONLY', "only the value of the k'th element would be computed"),
Doc('VALUE_IDX_NOSORT',
Doc('KTH_ONLY = 0', "only the value of the k'th element would be computed"),
Doc('VALUE_IDX_NOSORT = 1',
'all the top-k values and corresponding indices would be computed; ' 'all the top-k values and corresponding indices would be computed; '
'no order is guaranteed'), 'no order is guaranteed'),
Doc('VALUE_IDX_SORTED',
Doc('VALUE_IDX_SORTED = 2',
'all the top-k values and corresponding indices sorted')) 'all the top-k values and corresponding indices sorted'))
) )


@@ -983,37 +983,37 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
(pdef('RelayoutFormat', 'Change the tensor layout format', version=0, is_legacy=True). (pdef('RelayoutFormat', 'Change the tensor layout format', version=0, is_legacy=True).
add_enum( add_enum(
Doc('Mode', RELAYOUT_FORMAT_MODE_DOC), Doc('Mode', RELAYOUT_FORMAT_MODE_DOC),
'NHWC_NHWCD4',
'NHWCD4_NHWC',
'NHWC_NHWCD4I',
'NCHW_NHWCD4',
'NCHW_NHWCD4I',
'NHWCD4I_NCHW',
'NHWCD4_NCHW',
'INTER_WEIGHT_DENSE',
'INTER_WEIGHT_DENSEI',
'INTER_WEIGHT_GROUP',
'INTER_WEIGHT_GROUPI',
'INTER_WEIGHT_CHAN',
'INTER_WEIGHT_CHANI',
'INTER_WEIGHT_DENSEI_DOT',
'INTER_WEIGHT_GROUPI_DOT',
'NCHW4_CHWN4',
'CHWN4_NCHW4',
'NCHW_NCHW88_CONV_DENSE_WEIGHT',
'NCHW_NCHW88_CONV_CHAN_WEIGHT',
'NCHW_NCHW88_CONV_GROUP_WEIGHT',
'NCHW_NCHW88',
'NCHW88_NCHW',
'NCHW_NCHW4_IC_SMALL',
'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT',
'NCHW_NCHW4',
'NCHW4_NCHW',
'NCHW_NCHW4_WEIGHT',
'NCHW_NCHW64',
'NCHW64_NCHW',
'NCHW_NHWC',
'NHWC_NCHW',
'NHWC_NHWCD4 = 0',
'NHWCD4_NHWC = 1',
'NHWC_NHWCD4I = 2',
'NCHW_NHWCD4 = 3',
'NCHW_NHWCD4I = 4',
'NHWCD4I_NCHW = 5',
'NHWCD4_NCHW = 6',
'INTER_WEIGHT_DENSE = 7',
'INTER_WEIGHT_DENSEI = 8',
'INTER_WEIGHT_GROUP = 9',
'INTER_WEIGHT_GROUPI = 10',
'INTER_WEIGHT_CHAN = 11',
'INTER_WEIGHT_CHANI = 12',
'INTER_WEIGHT_DENSEI_DOT = 13',
'INTER_WEIGHT_GROUPI_DOT = 14',
'NCHW4_CHWN4 = 15',
'CHWN4_NCHW4 = 16',
'NCHW_NCHW88_CONV_DENSE_WEIGHT = 17',
'NCHW_NCHW88_CONV_CHAN_WEIGHT = 18',
'NCHW_NCHW88_CONV_GROUP_WEIGHT = 19',
'NCHW_NCHW88 = 20',
'NCHW88_NCHW = 21',
'NCHW_NCHW4_IC_SMALL = 22',
'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT = 23',
'NCHW_NCHW4 = 24',
'NCHW4_NCHW = 25',
'NCHW_NCHW4_WEIGHT = 26',
'NCHW_NCHW64 = 27',
'NCHW64_NCHW = 28',
'NCHW_NHWC = 29',
'NHWC_NCHW = 30',
) )
) )


@@ -1077,7 +1077,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o




(pdef('ROIAlign',version=0,is_legacy=True). (pdef('ROIAlign',version=0,is_legacy=True).
add_enum('Mode', 'MAX', 'AVERAGE', name_field='mode').
add_enum('Mode', 'MAX = 0', 'AVERAGE = 1', name_field='mode').
add_enum_alias('Format', 'ConvolutionV0'). add_enum_alias('Format', 'ConvolutionV0').
add_fields('float32', 'spatial_scale', '1.0'). add_fields('float32', 'spatial_scale', '1.0').
add_fields('float32', 'offset', '0.0'). add_fields('float32', 'offset', '0.0').
@@ -1173,9 +1173,9 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
pdef('Fill').add_fields('float32', 'value', '0') pdef('Fill').add_fields('float32', 'value', '0')




PADDING_MODES = [Doc('REPLICATE', 'aaaaaa|abcdefgh|hhhhhhh'),
Doc('REFLECT', 'fedcba|abcdefgh|hgfedcb'),
Doc('CONSTANT', 'iiiiii|abcdefgh|iiiiiii')]
PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'),
Doc('CONSTANT = 2', 'iiiiii|abcdefgh|iiiiiii')]
(pdef('Padding'). (pdef('Padding').
add_fields('uint32', Doc('front_offset_dim0','offset in dim 0'), 0). add_fields('uint32', Doc('front_offset_dim0','offset in dim 0'), 0).
add_fields('uint32', Doc('front_offset_dim1','offset in dim 1'), 0). add_fields('uint32', Doc('front_offset_dim1','offset in dim 1'), 0).


+ 11
- 8
imperative/tablegen/helper.h View File

@@ -241,14 +241,17 @@ private:
if (auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr)) { if (auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr)) {
body += formatv(" switch ({0}){{\n", "$_self." + it.name); body += formatv(" switch ({0}){{\n", "$_self." + it.name);
for (auto&& enumMember: enumAttr->getEnumMembers()) { for (auto&& enumMember: enumAttr->getEnumMembers()) {
body += formatv(
" case {0}::{1}::{2}:\n",
getCppClassName(), enumAttr->getEnumName(), enumMember
);
body += formatv(
" props_.emplace_back(\"{0}\", \"{1}\");\n",
it.name, enumMember
);
size_t d1 = enumMember.find(' ');
size_t d2 = enumMember.find('=');
size_t d = d1 <= d2 ? d1 : d2;
body += formatv(" case {0}::{1}::{2}:\n",
getCppClassName(),
enumAttr->getEnumName(),
enumMember.substr(0, d));
body +=
formatv(" props_.emplace_back(\"{0}\", "
"\"{1}\");\n",
it.name, enumMember.substr(0, d));
body += " break;\n"; body += " break;\n";
} }
body += " default: break;\n"; body += " default: break;\n";


+ 7
- 3
imperative/tablegen/targets/cpp_class.cpp View File

@@ -177,9 +177,13 @@ void OpDefEmitter::emit_tpl_spl() {
std::vector<std::string> case_body; std::vector<std::string> case_body;
std::string ename = formatv("{0}::{1}", std::string ename = formatv("{0}::{1}",
op.getCppClassName(), attr->getEnumName()); op.getCppClassName(), attr->getEnumName());
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){
case_body.push_back(formatv(
"case {0}::{1}: return \"{1}\";", ename, v));
llvm::for_each(attr->getEnumMembers(), [&](auto&& v) {
size_t d1 = v.find(' ');
size_t d2 = v.find('=');
size_t d = d1 <= d2 ? d1 : d2;
case_body.push_back(
formatv("case {0}::{1}: return \"{1}\";", ename,
v.substr(0, d)));
}); });
os << formatv(R"( os << formatv(R"(
template <> template <>


+ 8
- 7
imperative/tablegen/targets/pybind11.cpp View File

@@ -50,14 +50,15 @@ void OpDefEmitter::emit() {
); );
std::vector<std::string> body; std::vector<std::string> body;
for (auto&& i: attr->getEnumMembers()) { for (auto&& i: attr->getEnumMembers()) {
os << formatv(
"\n .value(\"{2}\", {0}::{1}::{2})",
className, attr->getEnumName(), i
);
size_t d1 = i.find(' ');
size_t d2 = i.find('=');
size_t d = d1 <= d2 ? d1 : d2;
os << formatv("\n .value(\"{2}\", {0}::{1}::{2})",
className, attr->getEnumName(),
i.substr(0, d));
body.push_back(formatv( body.push_back(formatv(
"if (str == \"{2}\") return {0}::{1}::{2};",
className, attr->getEnumName(), i
));
"if (str == \"{2}\") return {0}::{1}::{2};",
className, attr->getEnumName(), i.substr(0, d)));
} }
if (attr->getEnumCombinedFlag()) { if (attr->getEnumCombinedFlag()) {
//! define operator | //! define operator |


+ 13
- 3
imperative/tablegen/targets/python_c_extension.cpp View File

@@ -102,7 +102,10 @@ void EnumAttrEmitter::emit_tpl_spl() {
&ctx); &ctx);


auto quote = [&](auto&& i) -> std::string { auto quote = [&](auto&& i) -> std::string {
return formatv("\"{0}\"", i);
size_t d1 = i.find(' ');
size_t d2 = i.find('=');
size_t d = d1 <= d2 ? d1 : d2;
return formatv("\"{0}\"", i.substr(0, d));
}; };
os << tgfmt(R"( os << tgfmt(R"(
template<> const char* template<> const char*
@@ -110,7 +113,11 @@ $enumTpl<$opClass::$enumClass>::members[] = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", ")); )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", "));


auto mem2value = [&](auto&& i) -> std::string { auto mem2value = [&](auto&& i) -> std::string {
return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i);
size_t d1 = i.find(' ');
size_t d2 = i.find('=');
size_t d = d1 <= d2 ? d1 : d2;
return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx,
i.substr(0, d));
}; };
os << tgfmt(R"( os << tgfmt(R"(
template<> std::unordered_map<std::string, $opClass::$enumClass> template<> std::unordered_map<std::string, $opClass::$enumClass>
@@ -192,12 +199,15 @@ os << tgfmt(R"(


auto&& members = attr->getEnumMembers(); auto&& members = attr->getEnumMembers();
for (size_t idx = 0; idx < members.size(); ++ idx) { for (size_t idx = 0; idx < members.size(); ++ idx) {
size_t d1 = members[idx].find(' ');
size_t d2 = members[idx].find('=');
size_t d = d1 <= d2 ? d1 : d2;
os << tgfmt(R"({ os << tgfmt(R"({
PyObject* inst = e_type->tp_alloc(e_type, 0); PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0; reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0); mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst; $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
})", &ctx, members[idx], idx);
})", &ctx, members[idx].substr(0, d), idx);
} }
} }




+ 2
- 1
tools/gen_header_for_bin_reduce.py View File

@@ -136,12 +136,13 @@ class HeaderGen:
mode_list = [i.strip() for i in fin] mode_list = [i.strip() for i in fin]


for i in mode_list: for i in mode_list:
i = i.split(' ')[0].split('=')[0]
if i in self._elemwise_modes: if i in self._elemwise_modes:
content = '_cb({})'.format(i) content = '_cb({})'.format(i)
else: else:
content = '' content = ''
self._write_def( self._write_def(
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i), content)
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i.split(' ')[0].split('=')[0]), content)
self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)', self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)',
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)') '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)')




+ 38
- 38
tools/param_defs/mgb_opr_param_defs.py View File

@@ -20,14 +20,14 @@ pdef('PersistentOutputStorage').add_fields(


(pdef('ExecutionPolicy', version=0, is_legacy=True). (pdef('ExecutionPolicy', version=0, is_legacy=True).
add_enum('Strategy', add_enum('Strategy',
Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'),
Doc('HEURISTIC_REPRODUCIBLE', 'use heuristic to choose the fastest algorithm, '
Doc('HEURISTIC = 0', 'use heuristic to choose the fastest algorithm'),
Doc('HEURISTIC_REPRODUCIBLE = 1', 'use heuristic to choose the fastest algorithm, '
'and the chosen algorithm is reproducible'), 'and the chosen algorithm is reproducible'),
Doc('PROFILE',
Doc('PROFILE = 2',
'run possible algorithms on real device to find the best'), 'run possible algorithms on real device to find the best'),
Doc('PROFILE_REPRODUCIBLE',
Doc('PROFILE_REPRODUCIBLE = 3',
'the fastest of profile result that is also reproducible'), 'the fastest of profile result that is also reproducible'),
Doc('PROFILE_HEURISTIC',
Doc('PROFILE_HEURISTIC = 4',
'use profile result and heuristic to choose the fastest algorithm')). 'use profile result and heuristic to choose the fastest algorithm')).
add_fields('uint64', add_fields('uint64',
Doc('workspace_limit', 'workspace limit in bytes'), Doc('workspace_limit', 'workspace limit in bytes'),
@@ -35,13 +35,13 @@ pdef('PersistentOutputStorage').add_fields(


(pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator', version=1). (pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator', version=1).
add_bit_combination_enum('Strategy', add_bit_combination_enum('Strategy',
Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'),
Doc('PROFILE',
Doc('HEURISTIC = 1 << 0', 'use heuristic to choose the fastest algorithm'),
Doc('PROFILE = 1 << 1',
'run possible algorithms on real device to find the best'), 'run possible algorithms on real device to find the best'),
Doc('REPRODUCIBLE',
Doc('REPRODUCIBLE = 1 << 2',
'when profile or heuristic algo selection it require the algos' 'when profile or heuristic algo selection it require the algos'
'must be reproducible'), 'must be reproducible'),
Doc('OPTIMIZED',
Doc('OPTIMIZED = 1 << 3',
'profile require algos are optmized to achieve fast-profile'), 'profile require algos are optmized to achieve fast-profile'),
default=('HEURISTIC',), default=('HEURISTIC',),
member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'), member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'),
@@ -66,19 +66,19 @@ pdef('PersistentOutputStorage').add_fields(
(pdef('CollectiveComm', 'collective communication between multiple computing ' (pdef('CollectiveComm', 'collective communication between multiple computing '
'nodes on localhost') 'nodes on localhost')
.add_enum(Doc('Mode', 'mode of collective communication'), .add_enum(Doc('Mode', 'mode of collective communication'),
Doc('REDUCE_SUM', 'reduce by sum to output computing node'),
Doc('BROADCAST', 'copy input value to each output computing node'),
Doc('ALL_GATHER', 'each output comp node gets the concatenated '
Doc('REDUCE_SUM = 0', 'reduce by sum to output computing node'),
Doc('BROADCAST = 1', 'copy input value to each output computing node'),
Doc('ALL_GATHER = 2', 'each output comp node gets the concatenated '
'value of all inputs'), 'value of all inputs'),
Doc('REDUCE_SCATTER_SUM',
Doc('REDUCE_SCATTER_SUM = 3',
'reduce inputs by sum and each output gets one part of it'), 'reduce inputs by sum and each output gets one part of it'),
Doc('ALL_REDUCE_SUM', 'every output gets the sum of all inputs'),
Doc('ALL_REDUCE_MAX', 'every output gets the max of all inputs'),
Doc('ALL_REDUCE_MIN', 'every output gets the min of all inputs'),
Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'),
Doc('GATHER', 'concat inputs to one node'),
Doc('SCATTER', 'scatter input to each output computing node'),
Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'),
Doc('ALL_REDUCE_SUM = 4', 'every output gets the sum of all inputs'),
Doc('ALL_REDUCE_MAX = 5', 'every output gets the max of all inputs'),
Doc('ALL_REDUCE_MIN = 6', 'every output gets the min of all inputs'),
Doc('ALL_REDUCE_PROD = 7', 'every output gets the prod of all inputs'),
Doc('GATHER = 8', 'concat inputs to one node'),
Doc('SCATTER = 9', 'scatter input to each output computing node'),
Doc('ALL_TO_ALL = 10', 'scatter inputs and gather them on each computing node'),
name_field='mode')) name_field='mode'))


(pdef('FakeSerializedDType', (pdef('FakeSerializedDType',
@@ -91,13 +91,13 @@ pdef('PersistentOutputStorage').add_fields(
'evaluate a predicate and branch keys to setup ExecutionMask objects ' 'evaluate a predicate and branch keys to setup ExecutionMask objects '
'with associated predicate proxy vars (PPVs)') 'with associated predicate proxy vars (PPVs)')
.add_enum(Doc('Mode', 'how to compare predicate var with branch keys'), .add_enum(Doc('Mode', 'how to compare predicate var with branch keys'),
Doc('CASE',
Doc('CASE = 0',
'The outputs correspond to branch keys, ' 'The outputs correspond to branch keys, '
'and the one which equals predicate would be activated. ' 'and the one which equals predicate would be activated. '
'This behaves like a case-statement in many languages.'), 'This behaves like a case-statement in many languages.'),
Doc('CASE_FALLBACK', 'like :attr:`CASE`, but add an extra output '
Doc('CASE_FALLBACK = 1', 'like :attr:`CASE`, but add an extra output '
'that would be activated if no branch is matched'), 'that would be activated if no branch is matched'),
Doc('PIECEWISE', 'One more outputs would be produced than the '
Doc('PIECEWISE = 2', 'One more outputs would be produced than the '
'number of branch keys, representing the interval in which the ' 'number of branch keys, representing the interval in which the '
'predicate var fits in. The intervals are defined as ' 'predicate var fits in. The intervals are defined as '
r':math:`(-\\infty, k_0), [k_0, k_1), \\ldots, ' r':math:`(-\\infty, k_0), [k_0, k_1), \\ldots, '
@@ -112,20 +112,20 @@ pdef('PersistentOutputStorage').add_fields(


(pdef('CondExecPredLogical', (pdef('CondExecPredLogical',
'compute a logical function over a set of PPVs') 'compute a logical function over a set of PPVs')
.add_enum('Mode', Doc('OR', 'logical or'),
Doc('AND', 'logical and'),
Doc('XOR', 'exclusive-or'),
Doc('NOR', 'not or(inputs)'),
Doc('NAND', 'not and(inputs)'),
Doc('XNOR', 'not xor(inputs)'))
.add_enum('Mode', Doc('OR = 0', 'logical or'),
Doc('AND = 1', 'logical and'),
Doc('XOR = 2', 'exclusive-or'),
Doc('NOR = 3', 'not or(inputs)'),
Doc('NAND = 4', 'not and(inputs)'),
Doc('XNOR = 5', 'not xor(inputs)'))
) )


(pdef('CondExecMark', (pdef('CondExecMark',
'add ExecutionMask of the input PPV to this opr and readers of the ' 'add ExecutionMask of the input PPV to this opr and readers of the '
'outputs of this opr') 'outputs of this opr')
.add_enum(Doc('GradMode', 'mode for computing the gradient'), .add_enum(Doc('GradMode', 'mode for computing the gradient'),
Doc('SUM', 'normal gradient mode: sum all the activated components'),
Doc('SUM_COND_OUT', 'use :attr:`CondExecMerge.SUM_COND_OUT` mode so '
Doc('SUM = 0', 'normal gradient mode: sum all the activated components'),
Doc('SUM_COND_OUT = 1', 'use :attr:`CondExecMerge.SUM_COND_OUT` mode so '
'oprs that depend on the gradient opr would not be executed ' 'oprs that depend on the gradient opr would not be executed '
'if the forward var is not used.'), 'if the forward var is not used.'),
name_field='grad_mode') name_field='grad_mode')
@@ -135,10 +135,10 @@ pdef('PersistentOutputStorage').add_fields(
execution into account, this option can be used to bypass static execution into account, this option can be used to bypass static
inference errors. This is currently only used by automatically inference errors. This is currently only used by automatically
generated gradient oprs."""), generated gradient oprs."""),
Doc('SHAPE_VALUE', 'enable both shape and value inference'),
Doc('SHAPE_ONLY',
Doc('SHAPE_VALUE = 0', 'enable both shape and value inference'),
Doc('SHAPE_ONLY = 1',
'only enable shape inference (disable value inference)'), 'only enable shape inference (disable value inference)'),
Doc('NONE', 'disable both shape and value inference'),
Doc('NONE = 2', 'disable both shape and value inference'),
name_field='static_infer') name_field='static_infer')
) )


@@ -147,17 +147,17 @@ pdef('PersistentOutputStorage').add_fields(
'number of output vars (i.e. vars per branch)'), 'number of output vars (i.e. vars per branch)'),
1) 1)
.add_enum('Mode', .add_enum('Mode',
Doc('EXACT_ONE', 'copy the var whose mask is activated to the output'
Doc('EXACT_ONE = 0', 'copy the var whose mask is activated to the output'
', requiring that exactly one branch is active'), ', requiring that exactly one branch is active'),
Doc('EXACT_ONE_SAME_SHAPE', 'like :attr:`EXACT_ONE` with the '
Doc('EXACT_ONE_SAME_SHAPE = 1', 'like :attr:`EXACT_ONE` with the '
'requirement that all branches have the same shape, so shape ' 'requirement that all branches have the same shape, so shape '
'inference can be easier'), 'inference can be easier'),
Doc('SUM', 'sum all the active branches into output var; require '
Doc('SUM = 2', 'sum all the active branches into output var; require '
'all branches to have the same shape. Extra shape vars are ' 'all branches to have the same shape. Extra shape vars are '
'needed in this mod, so the outputs can be initialized to zero ' 'needed in this mod, so the outputs can be initialized to zero '
'when no input is active (and their shapes are probably ' 'when no input is active (and their shapes are probably '
'unknown).'), 'unknown).'),
Doc('SUM_COND_OUT', 'like :attr:`SUM` but also add an ExecutionMask'
Doc('SUM_COND_OUT = 3', 'like :attr:`SUM` but also add an ExecutionMask'
' to the readers of output vars, so they would be skipped if ' ' to the readers of output vars, so they would be skipped if '
' no branch is taken') ' no branch is taken')
) )


Loading…
Cancel
Save