Compare commits

...

23 Commits

Author SHA1 Message Date
  Megvii Engine Team 7995f8579b chore(release): bump version 4 years ago
  Megvii Engine Team f3f9acf112 fix(tree): fix copybara 4 years ago
  Megvii Engine Team 36fbd5a65a fix(mgb): fix spell error 4 years ago
  Megvii Engine Team 957d1d40e6 fix(src/gopt): add replace func of typecvt opr for nhwcd4 pass 4 years ago
  Megvii Engine Team 865fbe16c0 feat(imperative/utils): add optimize-for-inference interface for opgraph 4 years ago
  Megvii Engine Team c92317edc0 fix(mge/utils): filter out parameter "arg_names" and "output_name" in network.dump 4 years ago
  Megvii Engine Team 24b91b98c7 feat(mge/utils): add array method for varnode 4 years ago
  Megvii Engine Team 5e54fdc32b fix(imperative/utils): fix name filter of network varnode 4 years ago
  Megvii Engine Team 2369c4f607 fix(mge/utils): fix auto naming bug when expanding structure 4 years ago
  Megvii Engine Team d275a82318 chore(scripts): clarify and fix default value of bit combined enum 4 years ago
  Megvii Engine Team 7c715bd42e fix(mge/utils): fix module stats calculate flops bug for group conv and remove model status change 4 years ago
  Megvii Engine Team 245a3f8129 fix(mge/tools): fix module stats' receptive field bug for Module 4 years ago
  Megvii Engine Team 007a2376c3 fix(mge/tools): fix node display bug in tensorboard 4 years ago
  Megvii Engine Team b10238acd1 feat(mge/tools): add support of receptive_field stats for NetworkNode 4 years ago
  Megvii Engine Team 84c2a5c27a feat(mge/tools): add summary print for module_stats and network_visualize 4 years ago
  Megvii Engine Team edea528b40 feat(mge/tools): set network_visualize's log_path as optional flag 4 years ago
  Megvii Engine Team e6d77604c2 refactor(imperative): refactor tablegen code generator 4 years ago
  Megvii Engine Team cad8568c34 fix(mge/optimizer): fix optimizer's state_dict bug 4 years ago
  Megvii Engine Team 0ed3699895 fix(mge/quantization): fix quantized concat forward problem 4 years ago
  Megvii Engine Team 888c7f1f7a fix(mge/module): fix auto-naming error when there are containers in the module 4 years ago
  Megvii Engine Team d13b6a4a51 fix(mgb/core): fix allocate_task_block_unsafe in thread_impl_1.h 4 years ago
  Megvii Engine Team 1edcfa19a8 fix(imperative/opr): fix apply_on_var_node for broadcast 4 years ago
  Megvii Engine Team 26a81b8941 chore(release): bump version 4 years ago
59 changed files with 2459 additions and 1357 deletions
Unified View
  1. +14 -11    dnn/scripts/gen_flatbuffers_schema.py
  2. +106 -30   dnn/scripts/gen_param_defs.py
  3. +13 -2     dnn/scripts/gen_tablegen.py
  4. +1 -1      dnn/src/x86/conv_bias/opr_impl.cpp
  5. +17 -2     imperative/python/megengine/core/ops/special.py
  6. +4 -2      imperative/python/megengine/core/tensor/array_method.py
  7. +7 -0      imperative/python/megengine/core/tensor/dtype.py
  8. +16 -11    imperative/python/megengine/core/tensor/indexing.py
  9. +32 -17    imperative/python/megengine/core/tensor/utils.py
  10. +1 -1     imperative/python/megengine/functional/debug_param.py
  11. +2 -3     imperative/python/megengine/functional/elemwise.py
  12. +31 -21   imperative/python/megengine/functional/tensor.py
  13. +3 -3     imperative/python/megengine/module/module.py
  14. +1 -1     imperative/python/megengine/module/quantized/concat.py
  15. +2 -1     imperative/python/megengine/module/sequential.py
  16. +4 -2     imperative/python/megengine/optimizer/optimizer.py
  17. +10 -2    imperative/python/megengine/quantization/__init__.py
  18. +119 -82  imperative/python/megengine/tools/network_visualize.py
  19. +216 -118 imperative/python/megengine/utils/module_stats.py
  20. +108 -60  imperative/python/megengine/utils/network.py
  21. +119 -56  imperative/python/megengine/utils/network_node.py
  22. +1 -1     imperative/python/src/graph_rt.cpp
  23. +29 -12   imperative/python/src/ops.cpp
  24. +57 -19   imperative/python/src/tensor.cpp
  25. +6 -0     imperative/python/src/tensor.h
  26. +21 -4    imperative/python/test/helpers/utils.py
  27. +4 -0     imperative/python/test/integration/test_optimizer.py
  28. +79 -52   imperative/python/test/unit/core/test_indexing_op.py
  29. +222 -69  imperative/python/test/unit/functional/test_tensor.py
  30. +45 -1    imperative/python/test/unit/quantization/test_module.py
  31. +30 -1    imperative/python/test/unit/utils/test_dump_naming.py
  32. +9 -15    imperative/python/test/unit/utils/test_network.py
  33. +13 -0    imperative/python/test/unit/utils/test_network_node.py
  34. +1 -1     imperative/python/version_template.py
  35. +2 -2     imperative/src/impl/ops/broadcast.cpp
  36. +2 -1     imperative/tablegen/CMakeLists.txt
  37. +15 -730  imperative/tablegen/autogen.cpp
  38. +40 -0    imperative/tablegen/emitter.h
  39. +36 -0    imperative/tablegen/helper.h
  40. +309 -0   imperative/tablegen/targets/cpp_class.cpp
  41. +21 -0    imperative/tablegen/targets/cpp_class.h
  42. +142 -0   imperative/tablegen/targets/pybind11.cpp
  43. +19 -0    imperative/tablegen/targets/pybind11.h
  44. +314 -0   imperative/tablegen/targets/python_c_extension.cpp
  45. +19 -0    imperative/tablegen/targets/python_c_extension.h
  46. +1 -1     sdk/load-and-run/src/mgblar.cpp
  47. +3 -1     src/core/include/megbrain/utils/thread_impl_1.h
  48. +2 -2     src/core/include/megbrain/version.h
  49. +2 -1     src/gopt/impl/inference.cpp
  50. +51 -2    src/gopt/test/inference.cpp
  51. +3 -3     src/opr/impl/dnn/dnn.oprdecl
  52. +1 -1     src/opr/impl/search_policy/algo_chooser.cpp
  53. +5 -5     src/opr/test/dnn/convolution.cpp
  54. +2 -5     src/serialization/impl/extern_c_opr.cpp
  55. +8 -0     tools/mlir/mgb-file-check/CMakeLists.txt
  56. +3 -0     tools/mlir/mgb-file-check/mgb-file-check.sh
  57. +23 -0    tools/mlir/mgb-opt/CMakeLists.txt
  58. +85 -0    tools/mlir/mgb-opt/mgb-opt.cpp
  59. +8 -2     tools/param_defs/mgb_opr_param_defs.py

+14 -11  dnn/scripts/gen_flatbuffers_schema.py

@@ -52,14 +52,11 @@ class FlatBuffersWriter(IndentWriterBase):
name = p + e name = p + e
e = self._enums[(p, e)] e = self._enums[(p, e)]
self._write_doc(e.name) self._write_doc(e.name)
self._write("enum %s%s : uint {", p, e.name, indent=1)
attribute = "(bit_flags)" if e.combined else ""
self._write("enum %s%s : uint %s {", p, e.name, attribute, indent=1)
for idx, member in enumerate(e.members): for idx, member in enumerate(e.members):
self._write_doc(member) self._write_doc(member)
if e.combined:
self._write("%s=%d,", scramble_enum_member_name(str(member)),
1<<idx)
else:
self._write("%s,", scramble_enum_member_name(str(member)))
self._write("%s,", scramble_enum_member_name(str(member)))
self._write("}\n", indent=-1) self._write("}\n", indent=-1)


def _write_doc(self, doc): def _write_doc(self, doc):
@@ -97,8 +94,11 @@ class FlatBuffersWriter(IndentWriterBase):
return return
self._write_doc(e.name) self._write_doc(e.name)
self._used_enum.add(key) self._used_enum.add(key)
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name,
scramble_enum_member_name(str(e.members[e.default])))
if e.combined:
default = e.compose_combined_enum(e.default)
else:
default = scramble_enum_member_name(str(e.members[e.default]))
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default)


def _resolve_const(self, v): def _resolve_const(self, v):
while v in self._cur_const_val: while v in self._cur_const_val:
@@ -120,9 +120,12 @@ class FlatBuffersWriter(IndentWriterBase):
return return
self._used_enum.add((e.src_class, e.src_name)) self._used_enum.add((e.src_class, e.src_name))
enum_name = e.src_class + e.src_name enum_name = e.src_class + e.src_name
self._write(
"%s:%s = %s;", e.name_field, enum_name,
scramble_enum_member_name(str(e.src_enum.members[e.get_default()])))
s = e.src_enum
if s.combined:
default = s.compose_combined_enum(e.get_default())
else:
default = scramble_enum_member_name(str(s.members[e.get_default()]))
self._write("%s:%s = %s;", e.name_field, enum_name, default)


def _get_fb_default(self, cppdefault): def _get_fb_default(self, cppdefault):
if not isinstance(cppdefault, str): if not isinstance(cppdefault, str):
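
For reference, a minimal Python sketch (not part of the diff; "ExecutionPolicy", "Strategy", and "strategy" are placeholder names) of the schema text the writer now emits for a bit-combined enum: the enum declaration gains a "(bit_flags)" attribute and the field default becomes the composed integer rather than a member name.

```python
# Illustrative sketch only; names are placeholders, format strings match the diff.
p, name, combined = "ExecutionPolicy", "Strategy", True

attribute = "(bit_flags)" if combined else ""
print("enum %s%s : uint %s {" % (p, name, attribute))
# -> enum ExecutionPolicyStrategy : uint (bit_flags) {

default = 1  # e.compose_combined_enum(e.default) for a combined enum
print("%s:%s%s = %s;" % ("strategy", p, name, default))
# -> strategy:ExecutionPolicyStrategy = 1;
```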


+106 -30  dnn/scripts/gen_param_defs.py

@@ -73,11 +73,21 @@ class member_defs:
"""define an enum; the result would contain both an enum class def and its """define an enum; the result would contain both an enum class def and its
corresponding data field corresponding data field


:param default: index of default member value
:param default:
for normal enum class: index of default member value
for bit combined class: tuple of index of default member value

For example, following representations of the default value for bit
combined class are all equivalent:
Enum(members=('a', 'b', 'c'), default=('a', 'b'), ...)
Enum(members=('a', 'b', 'c'), default=(0, 1), ...)
Enum(members=('a', 'b', 'c'), default=(1 << 0) | (1 << 1), ...)


:attr name_field: name of the data field of this enum in the param :attr name_field: name of the data field of this enum in the param
struct struct
:attr member_alias: list of (member, alias) pairs
:attr member_alias:
for normal enum class: list of (member, alias) pairs
for bit combined class: list of (tuple of members, alias) paris
""" """
__slots__ = ['name', 'name_field', 'members', 'default', __slots__ = ['name', 'name_field', 'members', 'default',
'member_alias', 'combined'] 'member_alias', 'combined']
@@ -90,17 +100,11 @@ class member_defs:
name = member_defs.Doc.make(name) name = member_defs.Doc.make(name)
assert name.id[0].isupper() assert name.id[0].isupper()
members = tuple(map(member_defs.Doc.make, members)) members = tuple(map(member_defs.Doc.make, members))
if isinstance(default, str):
if default not in name_field:
raise ValueError(
"Default value '{}' does not exist.".format(default))
default = name_field.index(default)
assert isinstance(default, int)
self.name = name self.name = name
self.combined = combined self.combined = combined
self.name_field = self.get_name_field(name.id, name_field) self.name_field = self.get_name_field(name.id, name_field)
self.members = members self.members = members
self.default = default
self.default = self.normalize_enum_value(default)


self.all_enums[(param_name, name.id)] = self self.all_enums[(param_name, name.id)] = self


@@ -114,6 +118,43 @@ class member_defs:
assert isinstance(name_field, str) assert isinstance(name_field, str)
return name_field return name_field


def normalize_enum_value(self, value):
def normalize(v):
if isinstance(v, str):
if v not in self.members:
raise ValueError(
"enum member '{}' does not exist.".format(v))
v = self.members.index(v)
assert isinstance(v, int)
return v
if self.combined:
if isinstance(value, int):
value = self.decompose_combined_enum(value)
assert isinstance(value, tuple)
value = tuple(normalize(i) for i in value)
return value
else:
return normalize(value)

@staticmethod
def decompose_combined_enum(v):
"""Integer => tuple of the indexes of the enum members"""
assert isinstance(v, int)
idx = 0
members = []
while v > 0:
if v & 1:
members.append(idx)
idx += 1
v >>= 1
return tuple(members)

def compose_combined_enum(self, v):
"""tuple of members => Integer"""
assert self.combined and isinstance(v, tuple)
norm_v = self.normalize_enum_value(v)
return sum(1 << i for i in norm_v)

class Field(Base): class Field(Base):
"""define a normal data field""" """define a normal data field"""
__slots__ = ['name', 'dtype', 'default'] __slots__ = ['name', 'dtype', 'default']
@@ -146,6 +187,10 @@ class member_defs:
src_name = name src_name = name
self.src_name = src_name self.src_name = src_name
self.default = default self.default = default
# TODO: remove this assertion if needed; adding mock param_defs in
# current testing framework is too complicated, and currently we
# only allow aliasing of normal enum
assert not self.src_enum.combined


@property @property
def src_enum(self): def src_enum(self):
@@ -157,7 +202,7 @@ class member_defs:
set""" set"""
if self.default is None: if self.default is None:
return self.src_enum.default return self.src_enum.default
return self.default
return self.src_enum.normalize_enum_value(self.default)




class ParamDef: class ParamDef:
@@ -198,7 +243,7 @@ class ParamDef:
self.name.id, name, name_field, members, default, member_alias)) self.name.id, name, name_field, members, default, member_alias))
return self return self


def add_bit_combination_enum(self, name, *members, default=0,
def add_bit_combination_enum(self, name, *members, default=tuple(),
name_field=None, member_alias=[]): name_field=None, member_alias=[]):
self.members.append(member_defs.Enum( self.members.append(member_defs.Enum(
self.name.id, name, name_field, members, default, member_alias, True)) self.name.id, name, name_field, members, default, member_alias, True))
@@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase):
' for idx, v in enumerate(pdata):\n' ' for idx, v in enumerate(pdata):\n'
' if isinstance(v, _EnumBase):\n' ' if isinstance(v, _EnumBase):\n'
' pdata[idx] = _enum_member2num[id(v)]\n' ' pdata[idx] = _enum_member2num[id(v)]\n'
' elif isinstance(v, _BitCombinedEnumBase):\n'
' pdata[idx] = v._value_\n'
' return tag + self._packer.pack(*pdata)\n' ' return tag + self._packer.pack(*pdata)\n'
'\n' '\n'
) )
self._write(
'class _EnumBase(enum.Enum):\n'
# it's hard to mix custom implemention into enum, just do copy-paste instead
classbody = (
' @classmethod\n' ' @classmethod\n'
' def __normalize(cls, val):\n' ' def __normalize(cls, val):\n'
' if isinstance(val, str):\n' ' if isinstance(val, str):\n'
@@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase):
' return super()._missing_(value)\n' ' return super()._missing_(value)\n'
'\n' '\n'
) )
self._write(
'class _EnumBase(enum.Enum):\n' + classbody
)
self._write(
'class _BitCombinedEnumBase(enum.Flag):\n' + classbody
)
if not self._imperative: if not self._imperative:
self._write( self._write(
'def _as_dtype_num(dtype):\n' 'def _as_dtype_num(dtype):\n'
@@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase):
def _on_member_enum(self, e): def _on_member_enum(self, e):
qualname = '{}.{}'.format(self._cur_param_name, e.name) qualname = '{}.{}'.format(self._cur_param_name, e.name)


self._write('class %s(_EnumBase):', e.name, indent=1)
if e.combined:
self._write('class %s(_BitCombinedEnumBase):', e.name, indent=1)
else:
self._write('class %s(_EnumBase):', e.name, indent=1)

self._write_doc(e.name) self._write_doc(e.name)


for idx, emem in enumerate(e.members): for idx, emem in enumerate(e.members):
self._write('%s = "%s"', emem, emem)
self._write_doc(emem)
if e.combined: if e.combined:
self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, 1<<idx))
self._write('%s = 1 << %d', emem, idx)
self._write_doc(emem)
else: else:
self._write('%s = "%s"', emem, emem)
self._write_doc(emem)
self._enum_member2num.append('id({}.{}):{}'.format( self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, idx)) qualname, emem, idx))


for emem, emem_alis in e.member_alias:
self._write('%s = %s', emem_alis, emem)
for emem, emem_alias in e.member_alias:
if e.combined:
self._write('%s = %s', emem_alias, e.compose_combined_enum(emem))
else:
self._write('%s = %s', emem_alias, emem)


self._unindent() self._unindent()
self._write('') self._write('')


if e.combined:
default = e.compose_combined_enum(e.default)
else:
default = "'{}'".format(e.members[e.default])

self._cur_fields.append(self.FieldDef( self._cur_fields.append(self.FieldDef(
name=e.name_field, name=e.name_field,
cvt='{}.convert({})'.format(qualname, e.name_field), cvt='{}.convert({})'.format(qualname, e.name_field),
fmt='I', fmt='I',
default="'{}'".format(e.members[e.default]),
default=default,
type=qualname, type=qualname,
doc=None)) doc=None))


@@ -495,11 +560,16 @@ class SerializedDType(_ParamDefBase):
self._write('%s = %s.%s', e.name, e.src_class, e.src_name) self._write('%s = %s.%s', e.name, e.src_class, e.src_name)
s = e.src_enum s = e.src_enum
qualname = '{}.{}'.format(e.src_class, e.src_name) qualname = '{}.{}'.format(e.src_class, e.src_name)

if s.combined:
default = s.compose_combined_enum(e.get_default())
else:
default = "'{}'".format(s.members[e.get_default()])
self._cur_fields.append(self.FieldDef( self._cur_fields.append(self.FieldDef(
name=e.name_field, name=e.name_field,
cvt='{}.convert({})'.format(qualname, e.name_field), cvt='{}.convert({})'.format(qualname, e.name_field),
fmt='I', fmt='I',
default="'{}'".format(s.members[e.get_default()]),
default=default,
type=qualname, type=qualname,
doc=None)) doc=None))


@@ -639,14 +709,19 @@ class CPPWriter(IndentWriterBase):
v += ',' v += ','
self._write(v) self._write(v)
for mem, alias in e.member_alias: for mem, alias in e.member_alias:
self._write('%s = %s,', alias, mem)
if e.combined:
self._write('%s = %s,', alias, e.compose_combined_enum(mem))
else:
self._write('%s = %s,', alias, mem)
self._write('};', indent=-1) self._write('};', indent=-1)
self._non_static_members.append(e) self._non_static_members.append(e)
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
str(e.name).upper(), len(e.members)) str(e.name).upper(), len(e.members))
self._add_ctor_args(e.name,
'{}::{}'.format(e.name, e.members[e.default]),
e.name_field)
if e.combined:
default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default))
else:
default = '{}::{}'.format(e.name, e.members[e.default])
self._add_ctor_args(e.name, default, e.name_field)


def _on_member_enum_alias(self, e): def _on_member_enum_alias(self, e):
s = e.src_enum s = e.src_enum
@@ -654,10 +729,11 @@ class CPPWriter(IndentWriterBase):
self._non_static_members.append(e) self._non_static_members.append(e)
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
str(e.name).upper(), len(s.members)) str(e.name).upper(), len(s.members))
self._add_ctor_args(e.name,
'{}::{}'.format(e.name,
s.members[e.get_default()]),
e.name_field)
if s.combined:
default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default))
else:
default = '{}::{}'.format(e.name, s.members[e.get_default()])
self._add_ctor_args(e.name, default, e.name_field)


def _on_member_field(self, f): def _on_member_field(self, f):
self._non_static_members.append(f) self._non_static_members.append(f)
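
A standalone sketch of the two helpers added above, using the ('a', 'b', 'c') example from the new docstring; this is plain Python mirroring the diff, not MegEngine code.

```python
members = ('a', 'b', 'c')

def decompose_combined_enum(v):
    """Integer -> tuple of indexes of the set bits, as in the diff."""
    idx, out = 0, []
    while v > 0:
        if v & 1:
            out.append(idx)
        idx += 1
        v >>= 1
    return tuple(out)

def compose_combined_enum(names):
    """Tuple of member names -> packed integer."""
    return sum(1 << members.index(n) for n in names)

# The three default representations listed in the docstring are equivalent:
assert decompose_combined_enum((1 << 0) | (1 << 1)) == (0, 1)
assert compose_combined_enum(('a', 'b')) == (1 << 0) | (1 << 1)
```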


+13 -2  dnn/scripts/gen_tablegen.py

@@ -106,7 +106,12 @@ class ConverterWriter(IndentWriterBase):
return return


# wrapped with default value # wrapped with default value
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default)
if e.combined:
default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, e.compose_combined_enum(e.default))
else:
default_val = "{}::{}::{}".format(fullname, e.name, e.members[e.default])

wrapped = self._wrapped_with_default_value(td_class, default_val) wrapped = self._wrapped_with_default_value(td_class, default_val)


self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) self._current_tparams.append("{}:${}".format(wrapped, e.name_field))
@@ -124,7 +129,13 @@ class ConverterWriter(IndentWriterBase):
self._write("def {} : {};".format(td_class, enum_def)) self._write("def {} : {};".format(td_class, enum_def))


# wrapped with default value # wrapped with default value
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default())
s = e.src_enum
if s.combined:
default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, s.compose_combined_enum(e.get_default()))
else:
default_val = "{}::{}::{}".format(fullname, e.name, s.members[e.get_default()])

wrapped = self._wrapped_with_default_value(td_class, default_val) wrapped = self._wrapped_with_default_value(td_class, default_val)


self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) self._current_tparams.append("{}:${}".format(wrapped, e.name_field))


+1 -1  dnn/src/x86/conv_bias/opr_impl.cpp

@@ -185,7 +185,7 @@ SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
} }
//! conv1x1 //! conv1x1
im2col_prefer |= (FH == 1 && FW == 1); im2col_prefer |= (FH == 1 && FW == 1);
//! x86 8x8x16 not optmized, so it will use fallback im2col+matmul
//! x86 8x8x16 not optimized, so it will use fallback im2col+matmul
if (param.deduce_algo_data_type() == AlgoDataType::INT8X8X16) { if (param.deduce_algo_data_type() == AlgoDataType::INT8X8X16) {
im2col_prefer = true; im2col_prefer = true;
} }


+17 -2  imperative/python/megengine/core/ops/special.py

@@ -8,6 +8,9 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np import numpy as np


from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor



class Const: class Const:
def __init__(self, value=None, *, dtype=None, device=None): def __init__(self, value=None, *, dtype=None, device=None):
@@ -19,7 +22,19 @@ class Const:
from ...tensor import Tensor from ...tensor import Tensor


device = self.device device = self.device
if device is None:
device = reference[0].device

if len(reference) != 0:
reference = reference[0]
assert isinstance(
reference, (SymbolVar, Tensor)
), "Reference should be Tensor or VarNode"

if device is None:
device = reference.device

if isinstance(reference, SymbolVar):
cls = type(reference)
rst = cls(make_const(reference.graph, self.value, device, self.dtype))
return (rst,)


return (Tensor(self.value, self.dtype, self.device, True),) return (Tensor(self.value, self.dtype, self.device, True),)
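
A hedged usage sketch of the updated Const (requires a MegEngine build matching this diff): the reference argument, used the same way in the indexing.py changes later in this diff, decides whether the constant is materialized as an imperative Tensor or as a graph constant via make_const.

```python
# With a Tensor reference the constant stays an imperative Tensor; with a
# SymbolVar reference (e.g. a VarNode wrapper from a loaded graph) it becomes
# a constant on that graph instead.
import numpy as np
import megengine as mge
from megengine.core.ops.special import Const

x = mge.tensor(np.arange(4, dtype="float32"))
(c,) = Const(1.0, dtype="float32", device=x.device)(x)
print(type(c), c.numpy())  # an imperative Tensor holding 1.0
```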

+4 -2  imperative/python/megengine/core/tensor/array_method.py

@@ -13,7 +13,7 @@ from typing import Union
import numpy as np import numpy as np


from .._imperative_rt.common import CompNode from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import Tensor, apply
from .._imperative_rt.core2 import SymbolVar, Tensor, apply
from ..ops import builtin from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape from ..ops.builtin import Elemwise, GetVarShape
from . import utils from . import utils
@@ -230,7 +230,9 @@ def _todo(*_):


def _expand_args(args): def _expand_args(args):
if len(args) == 1: if len(args) == 1:
if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),):
if isinstance(
args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray),
):
args = args[0] args = args[0]
return args return args




+7 -0  imperative/python/megengine/core/tensor/dtype.py

@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import re
from collections import namedtuple from collections import namedtuple
from typing import Union from typing import Union


@@ -22,6 +23,12 @@ from .._imperative_rt.common import (
) )




def get_dtype_bit(dtype_name: str):
numbers = re.findall(r"\d+", dtype_name)
assert len(numbers) == 1, "Unsupport dtype name with more than one number."
return int(numbers[0])


# normal dtype related # normal dtype related
def is_lowbit(dtype): def is_lowbit(dtype):
return (dtype is intb1) or (dtype is intb2) or (dtype is intb4) return (dtype is intb1) or (dtype is intb2) or (dtype is intb4)
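
The new helper is self-contained; a quick standalone sketch of its behavior (mirroring the code above, runnable without MegEngine):

```python
import re

def get_dtype_bit(dtype_name: str) -> int:
    # The bit width is the single number embedded in the dtype name.
    numbers = re.findall(r"\d+", dtype_name)
    assert len(numbers) == 1, "Unsupport dtype name with more than one number."
    return int(numbers[0])

assert get_dtype_bit("int8") == 8
assert get_dtype_bit("quint8") == 8
assert get_dtype_bit("float32") == 32
```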


+16 -11  imperative/python/megengine/core/tensor/indexing.py

@@ -10,7 +10,7 @@ from typing import Iterable


import numpy as np import numpy as np


from .._imperative_rt.core2 import Tensor, apply
from .._imperative_rt.core2 import SymbolVar, Tensor, apply
from .._trace_option import use_symbolic_shape from .._trace_option import use_symbolic_shape
from ..ops import builtin from ..ops import builtin
from ..ops.special import Const from ..ops.special import Const
@@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
return True return True


def get_index(i): def get_index(i):
if not isinstance(i, (Tensor)):
if not isinstance(i, (Tensor, SymbolVar)):
if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_:
(i,) = Const(i, dtype=np.bool_, device=inp.device)()
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp)
else: else:
(i,) = Const(i, dtype=np.int32, device=inp.device)()
(i,) = Const(i, dtype=np.int32, device=inp.device)(inp)
return i return i
assert isinstance(i, Tensor)
assert isinstance(i, (Tensor, SymbolVar))
if i.dtype != np.bool_: if i.dtype != np.bool_:
return i return i
_, ind = apply(builtin.CondTake(), i, i) _, ind = apply(builtin.CondTake(), i, i)
@@ -197,9 +197,9 @@ def try_condtake(tensor, index):
): ):
return [] return []
if isinstance(index, np.ndarray): if isinstance(index, np.ndarray):
(index,) = Const(index, dtype=np.bool_, device=tensor.device)()
assert isinstance(index, Tensor)
if not isinstance(tensor, Tensor):
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
assert isinstance(index, (Tensor, SymbolVar))
if not isinstance(tensor, (Tensor, SymbolVar)):
raise TypeError("input must be a tensor") raise TypeError("input must be a tensor")
if tensor.device != index.device: if tensor.device != index.device:
raise ValueError( raise ValueError(
@@ -214,11 +214,16 @@ def getitem(tensor, index):
return try_result[0] return try_result[0]
tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index) tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
for v in tensors: for v in tensors:
if v.shape is None:
break
if isinstance(v.shape, v.__class__): if isinstance(v.shape, v.__class__):
break break
if len(v.shape) > 0 and v.shape[0] == 0: if len(v.shape) > 0 and v.shape[0] == 0:
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)()
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)(
tensor
)
return empty_tensor return empty_tensor

if use_subtensor: if use_subtensor:
op = builtin.Subtensor(items=items) op = builtin.Subtensor(items=items)
else: else:
@@ -235,8 +240,8 @@ def setitem(tensor, index, value):
if len(try_result) == 2: if len(try_result) == 2:
index = try_result[1] index = try_result[1]
tensor = tensor.reshape(-1) tensor = tensor.reshape(-1)
if not isinstance(value, Tensor):
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)()
if not isinstance(value, (Tensor, SymbolVar)):
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor)
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
if use_subtensor: if use_subtensor:
op = builtin.Subtensor(items=items) op = builtin.Subtensor(items=items)


+32 -17  imperative/python/megengine/core/tensor/utils.py

@@ -11,8 +11,9 @@ from typing import Iterable, Union


import numpy as np import numpy as np


from .._imperative_rt import VarNode
from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device
from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
from .._wrap import device as as_device
from ..ops import builtin from ..ops import builtin
from ..ops.special import Const from ..ops.special import Const
from .dtype import is_dtype_equal, is_quantize from .dtype import is_dtype_equal, is_quantize
@@ -38,13 +39,9 @@ def set_convert_inputs(flag):




def concatenate(inputs, axis=0, *, device=None): def concatenate(inputs, axis=0, *, device=None):
dtype = dtype_promotion(inputs)
device = get_device(inputs)

def convert(x):
return convert_single_value(x, dtype=dtype, device=device)

inputs = tuple(map(convert, inputs))
inputs = convert_inputs(*inputs)
if device is None:
device = get_device(inputs)
(result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs)
return result return result


@@ -60,7 +57,7 @@ def astype(x, dtype):




def convert_single_value(v, *, dtype=None, device=None): def convert_single_value(v, *, dtype=None, device=None):
if isinstance(v, (Tensor, VarNode)):
if isinstance(v, (Tensor, SymbolVar)):
if not is_quantize(v.dtype): if not is_quantize(v.dtype):
v = astype(v, dtype) v = astype(v, dtype)
else: else:
@@ -68,17 +65,35 @@ def convert_single_value(v, *, dtype=None, device=None):
return v return v




def convert_inputs(*args: Tensor):
def convert_inputs(*args, device=None):
if not _enable_convert_inputs: if not _enable_convert_inputs:
return args return args


dtype = dtype_promotion(args) dtype = dtype_promotion(args)
device = get_device(args)
if device is None:
device = get_device(args)
device = as_device(device)

graph = None
sym_type = None
for a in args:
if isinstance(a, SymbolVar):
if graph is None:
graph = a.var.graph
sym_type = type(a)
else:
assert graph == a.var.graph
args = list(args)
if graph is not None:
for i in range(len(args)):
if not isinstance(args[i], SymbolVar):
rst = make_const(graph, np.array(args[i]), device.to_c(), dtype)
args[i] = sym_type(rst)


def convert(value): def convert(value):
if value is None: if value is None:
return value return value
return convert_single_value(value, dtype=dtype, device=device)
return convert_single_value(value, dtype=dtype, device=device.to_c())


return tuple(map(convert, args)) return tuple(map(convert, args))


@@ -98,14 +113,14 @@ def result_type(*args):


def isscalar(x): def isscalar(x):


if isinstance(x, Tensor):
if isinstance(x, (Tensor, SymbolVar)):
return x._isscalar() return x._isscalar()


return np.isscalar(x) return np.isscalar(x)




def setscalar(x): def setscalar(x):
if isinstance(x, Tensor):
if isinstance(x, (Tensor, SymbolVar)):
x._setscalar() x._setscalar()
else: else:
raise NotImplementedError("Unsupport type {}".format(type(x))) raise NotImplementedError("Unsupport type {}".format(type(x)))
@@ -132,7 +147,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
if not isinstance(x, collections.abc.Sequence): if not isinstance(x, collections.abc.Sequence):
raise TypeError raise TypeError


if any(isinstance(i, Tensor) for i in x):
if any(isinstance(i, (Tensor, SymbolVar)) for i in x):
x = concatenate(x, device=device) x = concatenate(x, device=device)
if dtype is not None: if dtype is not None:
x = astype(x, dtype) x = astype(x, dtype)
@@ -142,7 +157,7 @@ def astensor1d(x, *reference, dtype=None, device=None):




def _expand_int(s, i): def _expand_int(s, i):
if isinstance(i, Tensor):
if isinstance(i, (Tensor, SymbolVar)):
i_np = i.numpy() i_np = i.numpy()
if i_np.ndim == 0: if i_np.ndim == 0:
s.append(int(i_np)) s.append(int(i_np))


+1 -1  imperative/python/megengine/functional/debug_param.py

@@ -40,7 +40,7 @@ def set_execution_strategy(option):
* HEURISTIC uses heuristic to choose the fastest algorithm. * HEURISTIC uses heuristic to choose the fastest algorithm.
* PROFILE runs possible algorithms on real device to find the best one. * PROFILE runs possible algorithms on real device to find the best one.
* REPRODUCIBLE uses the algorithms that is reproducible. * REPRODUCIBLE uses the algorithms that is reproducible.
* OPTMIZED uses the algorithms that is optimized.
* OPTIMIZED uses the algorithms that is optimized.


The default strategy is HEURISTIC, this options can be combined to The default strategy is HEURISTIC, this options can be combined to
form a combination option, e.g. PROFILE | REPRODUCIBLE form a combination option, e.g. PROFILE | REPRODUCIBLE


+2 -3  imperative/python/megengine/functional/elemwise.py

@@ -9,8 +9,7 @@
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
import numpy as np import numpy as np


from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.graph import VarNode
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.builtin import Elemwise from ..core.ops.builtin import Elemwise
from ..core.tensor import utils from ..core.tensor import utils
@@ -72,7 +71,7 @@ __all__ = [




def _elwise(*args, mode): def _elwise(*args, mode):
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args))
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args))
if len(tensor_args) == 0: if len(tensor_args) == 0:
dtype = utils.dtype_promotion(args) dtype = utils.dtype_promotion(args)
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())


+31 -21  imperative/python/megengine/functional/tensor.py

@@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Union
import numpy as np import numpy as np


from ..core._imperative_rt import CompNode from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core._wrap import device as as_device from ..core._wrap import device as as_device
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity from ..core.ops.builtin import Copy, Identity
@@ -101,7 +101,7 @@ def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Ten
return result return result




def full(shape, value, dtype="float32", device=None):
def full(shape, value, dtype="float32", device=None) -> Tensor:
""" """
Returns a tensor with given shape and value. Returns a tensor with given shape and value.
""" """
@@ -115,7 +115,7 @@ def full(shape, value, dtype="float32", device=None):
return broadcast_to(x, shape) return broadcast_to(x, shape)




def ones(shape, dtype="float32", device=None):
def ones(shape, dtype="float32", device=None) -> Tensor:
""" """
Returns a ones tensor with given shape. Returns a ones tensor with given shape.


@@ -142,14 +142,14 @@ def ones(shape, dtype="float32", device=None):
return full(shape, 1.0, dtype=dtype, device=device) return full(shape, 1.0, dtype=dtype, device=device)




def zeros(shape, dtype="float32", device=None):
def zeros(shape, dtype="float32", device=None) -> Tensor:
""" """
Returns a zero tensor with given shape. Returns a zero tensor with given shape.
""" """
return full(shape, 0.0, dtype=dtype, device=device) return full(shape, 0.0, dtype=dtype, device=device)




def zeros_like(inp: Tensor) -> Tensor:
def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
""" """
Returns a zero tensor with the same shape as input tensor. Returns a zero tensor with the same shape as input tensor.


@@ -176,21 +176,26 @@ def zeros_like(inp: Tensor) -> Tensor:
[0 0 0]] [0 0 0]]


""" """
return zeros(inp.shape, dtype=inp.dtype, device=inp.device)
return full_like(inp, 0.0)




def ones_like(inp: Tensor) -> Tensor:
def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
""" """
Returns a ones tensor with the same shape as input tensor. Returns a ones tensor with the same shape as input tensor.
""" """
return ones(inp.shape, dtype=inp.dtype, device=inp.device)
return full_like(inp, 1.0)




def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
def full_like(
inp: Union[Tensor, SymbolVar], value: Union[int, float]
) -> Union[Tensor, SymbolVar]:
""" """
Returns a tensor filled with given value with the same shape as input tensor. Returns a tensor filled with given value with the same shape as input tensor.
""" """
return full(inp.shape, value, dtype=inp.dtype, device=inp.device)
(x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
if inp.shape is ():
return x
return broadcast_to(x, inp.shape)




def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
@@ -259,15 +264,10 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
if len(inps) == 1: if len(inps) == 1:
return inps[0] return inps[0]


dtype = dtype_promotion(inps)
inps = convert_inputs(*inps, device=device)
if device is None: if device is None:
device = get_device(inps) device = get_device(inps)
device = as_device(device) device = as_device(device)

def convert(x):
return convert_single_value(x, dtype=dtype, device=device)

inps = tuple(map(convert, inps))
(result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
return result return result


@@ -379,8 +379,14 @@ def split(inp, nsplits_or_sections, axis=0):
Ntotal, axis, Nsections Ntotal, axis, Nsections
) )
) )

func = (
floor_div
if isinstance(Nsections, (SymbolVar, Tensor))
else lambda x, y: x // y
)
div_points = [0] + [ div_points = [0] + [
floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
] ]
for i in range(2, Nsections + 1): for i in range(2, Nsections + 1):
div_points[i] = div_points[i - 1] + div_points[i] div_points[i] = div_points[i - 1] + div_points[i]
@@ -925,11 +931,15 @@ def linspace(
if not (cur_device is None or device == cur_device): if not (cur_device is None or device == cur_device):
raise ("ambiguous device for linspace opr") raise ("ambiguous device for linspace opr")


if not isinstance(start, Tensor):
is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num])
if any(is_symbolvar) and not all(is_symbolvar):
raise TypeError("start, stop and num should all be VarNode or none of them")

if not isinstance(start, (Tensor, SymbolVar)):
start = Tensor(start, device=device) start = Tensor(start, device=device)
if not isinstance(stop, Tensor):
if not isinstance(stop, (Tensor, SymbolVar)):
stop = Tensor(stop, device=device) stop = Tensor(stop, device=device)
if not isinstance(num, Tensor):
if not isinstance(num, (Tensor, SymbolVar)):
num = Tensor(num, device=device) num = Tensor(num, device=device)


op = builtin.Linspace(comp_node=device) op = builtin.Linspace(comp_node=device)
@@ -983,7 +993,7 @@ def arange(
stop = stop.astype("float32") stop = stop.astype("float32")
if isinstance(step, Tensor): if isinstance(step, Tensor):
step = step.astype("float32") step = step.astype("float32")
num = ceil(Tensor((stop - start) / step, device=device))
num = ceil((stop - start) / step)
stop = start + step * (num - 1) stop = start + step * (num - 1)
result = linspace(start, stop, num, device=device) result = linspace(start, stop, num, device=device)
if np.dtype(dtype) == np.int32: if np.dtype(dtype) == np.int32:
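
A hedged usage sketch (MegEngine required; behavior as described by the diff): zeros_like and ones_like now route through full_like, so the same code path also serves SymbolVar inputs from a loaded graph, not only imperative Tensors.

```python
import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.arange(6, dtype="float32").reshape(2, 3))
print(F.ones_like(x).numpy())       # 2x3 array of ones
print(F.full_like(x, 7.0).numpy())  # 2x3 array filled with 7.0
```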


+3 -3  imperative/python/megengine/module/module.py

@@ -607,10 +607,10 @@ class Module(metaclass=ABCMeta):


def __getattribute__(self, name: str): def __getattribute__(self, name: str):
value = super().__getattribute__(name) value = super().__getattribute__(name)
if name == "_name":
if name == "__dict__":
return value return value
if isinstance(value, (Tensor, Module)):
value._name = name
for prefix, variable in _expand_structure(name, value):
variable._name = prefix
return value return value


def __setattr__(self, name: str, value): def __setattr__(self, name: str, value):


+1 -1  imperative/python/megengine/module/quantized/concat.py

@@ -23,7 +23,7 @@ class Concat(QuantizedModule):
self.output_dtype = dtype self.output_dtype = dtype


def forward(self, inps: Iterable[Tensor], axis: int = 0): def forward(self, inps: Iterable[Tensor], axis: int = 0):
new_inps = (x.astype(self.output_dtype) for x in inps)
new_inps = tuple(x.astype(self.output_dtype) for x in inps)
return F.concat(new_inps, axis) return F.concat(new_inps, axis)


@classmethod @classmethod
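
The one-line fix above matters because F.concat inspects its inputs more than once (it calls len(inps) and then converts them, as seen in the functional/tensor.py hunk earlier); a plain-Python illustration of why a generator breaks there while a tuple does not:

```python
gen = (x for x in [1, 2, 3])
tup = (1, 2, 3)

assert len(tup) == 3
try:
    len(gen)            # generators have no len(); F.concat's len(inps) check fails
except TypeError:
    pass

assert list(gen) == [1, 2, 3]
assert list(gen) == []  # exhausted after one pass; a tuple can be re-read freely
```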


+2 -1  imperative/python/megengine/module/sequential.py

@@ -92,6 +92,7 @@ class Sequential(Module):
return [getattr(self, key) for key in self.layer_keys] return [getattr(self, key) for key in self.layer_keys]


def forward(self, inp): def forward(self, inp):
for layer in self.layer_values:
# avoid layer_values as a name prefix, see Module.__getattribute__
for layer in [getattr(self, key) for key in self.layer_keys]:
inp = layer(inp) inp = layer(inp)
return inp return inp

+4 -2  imperative/python/megengine/optimizer/optimizer.py

@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from typing import Dict from typing import Dict
@@ -197,10 +198,11 @@ class Optimizer(metaclass=ABCMeta):
cur_id += 1 cur_id += 1


for param, st in self._state.items(): for param, st in self._state.items():
_st = copy.copy(st)
if not keep_var: if not keep_var:
for k, v in st.items(): for k, v in st.items():
st[k] = v.numpy()
state[param2id[param]] = st
_st[k] = v.numpy()
state[param2id[param]] = _st


for group in self.param_groups: for group in self.param_groups:
param_group = {k: v for k, v in group.items() if k != "params"} param_group = {k: v for k, v in group.items() if k != "params"}
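
A plain-Python illustration of the state_dict bug fixed above: writing the numpy-converted values back into the live state dict mutated the optimizer's own state, while copy.copy(st) keeps the original tensors untouched.

```python
import copy

st = {"exp_avg": "<Tensor>"}   # stands in for a Tensor-valued optimizer state
_st = copy.copy(st)            # shallow copy, as in the fix
_st["exp_avg"] = "<ndarray>"   # numpy conversion only touches the copy
assert st["exp_avg"] == "<Tensor>"
```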


+10 -2  imperative/python/megengine/quantization/__init__.py

@@ -6,8 +6,16 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


from .fake_quant import FakeQuantize
from .observer import Observer
from .fake_quant import TQT, FakeQuantize
from .observer import (
ExponentialMovingAverageObserver,
HistogramObserver,
MinMaxObserver,
Observer,
PassiveObserver,
SyncExponentialMovingAverageObserver,
SyncMinMaxObserver,
)
from .qconfig import ( from .qconfig import (
QConfig, QConfig,
calibration_qconfig, calibration_qconfig,
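
A quick smoke test of the expanded re-exports (names exactly as in the import list above; MegEngine required):

```python
from megengine.quantization import (
    TQT,
    FakeQuantize,
    HistogramObserver,
    MinMaxObserver,
    PassiveObserver,
    SyncMinMaxObserver,
)

print([cls.__name__ for cls in (TQT, FakeQuantize, MinMaxObserver)])
```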


+119 -82  imperative/python/megengine/tools/network_visualize.py

@@ -8,14 +8,19 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse import argparse
import logging import logging
import re


import numpy as np import numpy as np


from megengine.core.tensor.dtype import is_quantize from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import ( from megengine.utils.module_stats import (
print_flops_stats,
print_params_stats,
enable_receptive_field,
get_op_stats,
get_param_stats,
print_op_stats,
print_param_stats,
print_summary,
sizeof_fmt, sizeof_fmt,
) )
from megengine.utils.network import Network from megengine.utils.network import Network
@@ -40,34 +45,41 @@ def visualize(
:param log_params: whether print and record params size. :param log_params: whether print and record params size.
:param log_flops: whether print and record op flops. :param log_flops: whether print and record op flops.
""" """
try:
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.step_stats_pb2 import (
AllocatorMemoryUsed,
DeviceStepStats,
NodeExecStats,
StepStats,
)
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.compat.proto.versions_pb2 import VersionDef
from tensorboardX import SummaryWriter
except ImportError:
logger.error(
"TensorBoard and TensorboardX are required for visualize.", exc_info=True
)
return
if log_path:
try:
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.step_stats_pb2 import (
AllocatorMemoryUsed,
DeviceStepStats,
NodeExecStats,
StepStats,
)
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.compat.proto.versions_pb2 import VersionDef
from tensorboardX import SummaryWriter
except ImportError:
logger.error(
"TensorBoard and TensorboardX are required for visualize.",
exc_info=True,
)
return
# FIXME: remove this after resolving "span dist too large" warning # FIXME: remove this after resolving "span dist too large" warning
old_level = set_mgb_log_level(logging.ERROR) old_level = set_mgb_log_level(logging.ERROR)


enable_receptive_field()

graph = Network.load(model_path) graph = Network.load(model_path)
writer = SummaryWriter(log_path)


def process_name(name): def process_name(name):
return name.replace(".", "/").encode(encoding="utf-8")
# nodes that start with point or contain float const will lead to display bug
if not re.match(r"^[+-]?\d*\.\d*", name):
name = name.replace(".", "/")
return name.encode(encoding="utf-8")


summary = [["item", "value"]]
node_list = [] node_list = []
flops_list = [] flops_list = []
params_list = [] params_list = []
@@ -84,78 +96,90 @@ def visualize(
node_oup = node.outputs[0] node_oup = node.outputs[0]


inp_list = [process_name(var.owner.name) for var in node.inputs] inp_list = [process_name(var.owner.name) for var in node.inputs]
attr = {
"_output_shapes": AttrValue(
list=AttrValue.ListValue(
shape=[
TensorShapeProto(
dim=[TensorShapeProto.Dim(size=d) for d in node_oup.shape]
)
]
)
),
}
if hasattr(node, "calc_flops"):
flops_num = node.calc_flops()
if log_path:
# detail format see tensorboard/compat/proto/attr_value.proto
attr = {
"_output_shapes": AttrValue(
list=AttrValue.ListValue(
shape=[
TensorShapeProto(
dim=[
TensorShapeProto.Dim(size=d) for d in node_oup.shape
]
)
]
)
),
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr # add op flops attr
attr["flops"] = AttrValue(s=sizeof_fmt(flops_num).encode(encoding="utf-8"))
flops_list.append(
dict(
name=node.name,
class_name=node.type,
input_shapes=[i.shape for i in node.inputs],
output_shapes=[o.shape for o in node.outputs],
flops_num=flops_num,
flops_cum=0,
if log_path and hasattr(flops_stats, "flops_num"):
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
) )
)
flops_stats["name"] = node.name
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)

if node.type == "ImmutableTensor": if node.type == "ImmutableTensor":
param_dim = np.prod(node_oup.shape)
# TODO: consider other quantize dtypes
param_bytes = 1 if is_quantize(node_oup.dtype) else 4
param_stats = get_param_stats(node.numpy())
# add tensor size attr # add tensor size attr
attr["size"] = AttrValue(
s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8")
)
params_list.append(
dict(
name=node.name,
shape=node_oup.shape,
param_dim=param_dim,
bits=param_bytes * 8,
size=param_dim * param_bytes,
size_cum=0,
mean="{:.2g}".format(node.numpy().mean()),
std="{:.2g}".format(node.numpy().std()),
if log_path:
attr["size"] = AttrValue(
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
)
param_stats["name"] = node.name
params_list.append(param_stats)

if log_path:
node_list.append(
NodeDef(
name=process_name(node.name),
op=node.type,
input=inp_list,
attr=attr,
) )
) )
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue
node_list.append(
NodeDef(
name=process_name(node.name), op=node.type, input=inp_list, attr=attr,
)
)
# summary
extra_info = {
"#ops": len(graph.all_oprs),
"#params": len(params_list),
}


total_flops, total_params = 0, 0
total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params: if log_params:
total_params = print_params_stats(params_list, bar_length_max)
total_param_dims, total_param_size = print_param_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops: if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max)
total_flops = print_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)


graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))


device = "/device:CPU:0"
stepstats = RunMetadata(
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
)
writer._get_file_writer().add_graph((graph_def, stepstats))
device = "/device:CPU:0"
stepstats = RunMetadata(
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
)
writer = SummaryWriter(log_path)
writer._get_file_writer().add_graph((graph_def, stepstats))

print_summary(**extra_info)


# FIXME: remove this after resolving "span dist too large" warning # FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger.set_log_level(old_level) _imperative_rt_logger.set_log_level(old_level)


return total_params, total_flops
return total_param_size, total_flops




def main(): def main():
@@ -164,7 +188,7 @@ def main():
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument("model_path", help="dumped model path.") parser.add_argument("model_path", help="dumped model path.")
parser.add_argument("log_path", help="tensorboard log path.")
parser.add_argument("--log_path", help="tensorboard log path.")
parser.add_argument( parser.add_argument(
"--bar_length_max", "--bar_length_max",
type=int, type=int,
@@ -179,7 +203,20 @@ def main():
parser.add_argument( parser.add_argument(
"--log_flops", action="store_true", help="whether print and record op flops.", "--log_flops", action="store_true", help="whether print and record op flops.",
) )
visualize(**vars(parser.parse_args()))
parser.add_argument(
"--all",
action="store_true",
help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
)
args = parser.parse_args()
if args.all:
args.log_params = True
args.log_flops = True
if not args.log_path:
args.log_path = "./log"
kwargs = vars(args)
kwargs.pop("all")
visualize(**kwargs)




if __name__ == "__main__": if __name__ == "__main__":
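
A hedged sketch of the updated entry point (keyword names taken from main() above; "model.mge" is a placeholder path and the remaining optional parameters are assumed to keep their defaults): with log_path now optional, the flops/params summary can be printed without writing TensorBoard logs.

```python
from megengine.tools.network_visualize import visualize

# Roughly the CLI form: python -m megengine.tools.network_visualize model.mge --all
total_param_size, total_flops = visualize(
    "model.mge",          # dumped model path (placeholder)
    log_path="./log",     # tensorboard output; summaries print even without it
    log_params=True,
    log_flops=True,
)
```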


+216 -118  imperative/python/megengine/utils/module_stats.py

@@ -5,16 +5,17 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
from functools import partial from functools import partial


import numpy as np import numpy as np
import tabulate import tabulate


import megengine as mge import megengine as mge
import megengine.core.tensor.dtype as dtype
import megengine.module as m import megengine.module as m
import megengine.module.qat as qatm import megengine.module.qat as qatm
import megengine.module.quantized as qm import megengine.module.quantized as qm
from megengine.core.tensor.dtype import get_dtype_bit
from megengine.functional.tensor import zeros from megengine.functional.tensor import zeros


try: try:
@@ -26,61 +27,99 @@ logger = mge.get_logger(__name__)
logger.setLevel("INFO") logger.setLevel("INFO")




CALC_FLOPS = {}
_calc_flops_dict = {}
_calc_receptive_field_dict = {}




def _register_modules(*modules):
def _receptive_field_fallback(module, inputs, outputs):
if not _receptive_field_enabled:
return
assert not hasattr(module, "_rf")
assert not hasattr(module, "_stride")
if len(inputs) == 0:
# TODO: support other dimension
module._rf = (1, 1)
module._stride = (1, 1)
return module._rf, module._stride
rf, stride = preprocess_receptive_field(module, inputs, outputs)
module._rf = rf
module._stride = stride
return rf, stride


# key tuple, impl_dict, fallback
_iter_list = [
("flops_num", _calc_flops_dict, None),
(
("receptive_field", "stride"),
_calc_receptive_field_dict,
_receptive_field_fallback,
),
]

_receptive_field_enabled = False


def _register_dict(*modules, dict=None):
def callback(impl): def callback(impl):
for module in modules: for module in modules:
CALC_FLOPS[module] = impl
dict[module] = impl
return impl return impl


return callback return callback




@_register_modules(
m.Conv2d,
m.ConvTranspose2d,
m.LocalConv2d,
qm.Conv2d,
qm.ConvRelu2d,
qm.ConvBn2d,
qm.ConvBnRelu2d,
qatm.Conv2d,
qatm.ConvRelu2d,
qatm.ConvBn2d,
qatm.ConvBnRelu2d,
def register_flops(*modules):
return _register_dict(*modules, dict=_calc_flops_dict)


def register_receptive_field(*modules):
return _register_dict(*modules, dict=_calc_receptive_field_dict)


def enable_receptive_field():
global _receptive_field_enabled
_receptive_field_enabled = True


def disable_receptive_field():
global _receptive_field_enabled
_receptive_field_enabled = False


@register_flops(
m.Conv1d, m.Conv2d, m.Conv3d, m.ConvTranspose2d, m.LocalConv2d, m.DeformableConv2d
) )
def count_convNd(module, input, output):
def flops_convNd(module: m.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0 bias = 1 if module.bias is not None else 0
group = module.groups
ic = input[0].shape[1]
oc = output[0].shape[1]
goc = oc // group
gic = ic // group
N = output[0].shape[0]
HW = np.prod(output[0].shape[2:])
# N x Cout x H x W x (Cin x Kw x Kh + bias) # N x Cout x H x W x (Cin x Kw x Kh + bias)
return N * HW * goc * (gic * np.prod(module.kernel_size) + bias)
return np.prod(outputs[0].shape) * (
module.in_channels // module.groups * np.prod(module.kernel_size) + bias
)




@_register_modules(m.ConvTranspose2d)
def count_deconvNd(module, input, output):
return np.prod(input[0].shape) * output[0].shape[1] * np.prod(module.kernel_size)
@register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs):
bias = module.out_features if module.bias is not None else 0
return np.prod(outputs[0].shape) * module.in_features + bias




@_register_modules(m.Linear, qatm.Linear, qm.Linear)
def count_linear(module, input, output):
return np.prod(output[0].shape) * module.in_features
@register_flops(m.BatchMatMulActivation)
def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs):
bias = 1 if module.bias is not None else 0
x = inputs[0]
w = module.weight
batch_size = x.shape[0]
n, p = x.shape[1:]
_, m = w.shape[1:]
return n * (p + bias) * m * batch_size




# does not need import qat and quantized module since they inherit from float module. # does not need import qat and quantized module since they inherit from float module.
hook_modules = ( hook_modules = (
m.Conv2d,
m.ConvTranspose2d,
m.LocalConv2d,
m.BatchNorm2d,
m.conv._ConvNd,
m.Linear, m.Linear,
m.BatchMatMulActivation,
) )




@@ -106,22 +145,63 @@ def sizeof_fmt(num, suffix="B"):
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix) return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)




def print_flops_stats(flops, bar_length_max=20):
flops_list = [i["flops_num"] for i in flops]
max_flops_num = max(flops_list + [0])
# calc total flops and set flops_cum
def preprocess_receptive_field(module, inputs, outputs):
# TODO: support other dimensions
pre_rf = (
max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs),
max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs),
)
pre_stride = (
max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs),
max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs),
)
return pre_rf, pre_stride


def get_op_stats(module, inputs, outputs):
rst = {
"input_shapes": [i.shape for i in inputs],
"output_shapes": [o.shape for o in outputs],
}
valid_flag = False
for key, _dict, fallback in _iter_list:
for _type in _dict:
if isinstance(module, _type):
value = _dict[_type](module, inputs, outputs)
valid_flag = True
break
else:
if fallback is not None:
value = fallback(module, inputs, outputs)
continue

if isinstance(key, tuple):
assert isinstance(value, tuple)
for k, v in zip(key, value):
rst[k] = v
else:
rst[key] = value

if valid_flag:
return rst
else:
return None
return
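
For orientation, the dict returned by get_op_stats for a matched module looks roughly as below; the shapes and numbers are illustrative only (a 7x7, stride-2, bias-free stem convolution) and assume the receptive-field pass has been enabled:

# {
#     "input_shapes": [(1, 3, 224, 224)],
#     "output_shapes": [(1, 64, 112, 112)],
#     "flops_num": 118013952,        # 64*112*112 * (3*7*7)
#     "receptive_field": (7, 7),
#     "stride": (2, 2),
# }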


def print_op_stats(flops, bar_length_max=20):
max_flops_num = max([i["flops_num"] for i in flops] + [0])
    total_flops_num = 0
    for d in flops:
        total_flops_num += int(d["flops_num"])
        d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")


for i in flops:
f = i["flops_num"]
i["flops"] = sizeof_fmt(f, suffix="OPs")
r = i["ratio"] = f / total_flops_num
i["percentage"] = "{:.2f}%".format(r * 100)
bar_length = int(f / max_flops_num * bar_length_max)
i["bar"] = "#" * bar_length
for d in flops:
ratio = d["ratio"] = d["flops_num"] / total_flops_num
d["percentage"] = "{:.2f}%".format(ratio * 100)
bar_length = int(d["flops_num"] / max_flops_num * bar_length_max)
d["bar"] = "#" * bar_length
d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")


    header = [
        "name",
@@ -133,10 +213,13 @@ def print_flops_stats(flops, bar_length_max=20):
        "percentage",
        "bar",
    ]
if _receptive_field_enabled:
header.insert(4, "receptive_field")
header.insert(5, "stride")


    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
    total_var_size = sum(
        sum(s[1] if len(s) > 1 else 0 for s in i["output_shapes"]) for i in flops
        sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
    )
    flops.append(
        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
@@ -147,30 +230,44 @@ def print_flops_stats(flops, bar_length_max=20):
    return total_flops_num




def print_params_stats(params, bar_length_max=20):
def get_param_stats(param: np.ndarray):
nbits = get_dtype_bit(param.dtype.name)
shape = param.shape
param_dim = np.prod(param.shape)
param_size = param_dim * nbits // 8
return {
"dtype": param.dtype,
"shape": shape,
"mean": "{:.3g}".format(param.mean()),
"std": "{:.3g}".format(param.std()),
"param_dim": param_dim,
"nbits": nbits,
"size": param_size,
}
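
Worked example for get_param_stats (illustrative): a float32 weight of shape (64, 3, 7, 7) yields nbits = 32, param_dim = 64 * 3 * 7 * 7 = 9408 and size = 9408 * 32 // 8 = 37632 bytes, about 36.8 KiB.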


def print_param_stats(params, bar_length_max=20):
max_size = max([d["size"] for d in params] + [0])
    total_param_dims, total_param_size = 0, 0
    for d in params:
        total_param_dims += int(d["param_dim"])
        total_param_size += int(d["size"])
        d["size"] = sizeof_fmt(d["size"])
        d["size_cum"] = sizeof_fmt(total_param_size)

    for d in params:
        ratio = d["param_dim"] / total_param_dims
        ratio = d["size"] / total_param_size
        d["ratio"] = ratio
        d["percentage"] = "{:.2f}%".format(ratio * 100)

# construct bar
max_ratio = max([d["ratio"] for d in params])
for d in params:
bar_length = int(d["ratio"] / max_ratio * bar_length_max)
bar_length = int(d["size"] / max_size * bar_length_max)
d["size_bar"] = "#" * bar_length d["size_bar"] = "#" * bar_length
d["size"] = sizeof_fmt(d["size"])


    param_size = sizeof_fmt(total_param_size)
    params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))


    header = [
        "name",
        "dtype",
        "shape",
        "mean",
        "std",
@@ -186,7 +283,13 @@ def print_params_stats(params, bar_length_max=20):
        "param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
    )


return total_param_size
return total_param_dims, total_param_size


def print_summary(**kwargs):
data = [["item", "value"]]
data.extend(list(kwargs.items()))
logger.info("summary\n" + tabulate.tabulate(data))




def module_stats(
@@ -205,71 +308,53 @@ def module_stats(
    :param log_params: whether to print and record the parameter sizes.
    :param log_flops: whether to print and record the operator FLOPs.
    """
disable_receptive_field()


def get_byteswidth(tensor):
if dtype.is_quantize(tensor.dtype):
return 1
# elif dtype.is_bfloat16(tensor.dtype):
# return 2
else:
return 4

def module_stats_hook(module, input, output, name=""):
def module_stats_hook(module, inputs, outputs, name=""):
        class_name = str(module.__class__).split(".")[-1].split("'")[0]


flops_fun = CALC_FLOPS.get(type(module))
if callable(flops_fun):
flops_num = flops_fun(module, input, output)

if not isinstance(output, (list, tuple)):
output = [output]

flops.append(
dict(
name=name,
class_name=class_name,
input_shapes=[i.shape for i in input],
output_shapes=[o.shape for o in output],
flops_num=flops_num,
flops_cum=0,
)
)
flops_stats = get_op_stats(module, inputs, outputs)
if flops_stats is not None:
flops_stats["name"] = name
flops_stats["class_name"] = class_name
flops.append(flops_stats)


if hasattr(module, "weight") and module.weight is not None: if hasattr(module, "weight") and module.weight is not None:
w = module.weight w = module.weight
value = w.numpy()
param_dim = np.prod(w.shape)
param_bytes = get_byteswidth(w)
params.append(
dict(
name=name + "-w",
shape=w.shape,
param_dim=param_dim,
bits=param_bytes * 8,
size=param_dim * param_bytes,
size_cum=0,
mean="{:.2g}".format(value.mean()),
std="{:.2g}".format(value.std()),
)
)
param_stats = get_param_stats(w.numpy())
param_stats["name"] = name + "-w"
params.append(param_stats)


if hasattr(module, "bias") and module.bias is not None: if hasattr(module, "bias") and module.bias is not None:
b = module.bias b = module.bias
value = b.numpy()
param_dim = np.prod(b.shape)
param_bytes = get_byteswidth(b)
params.append(
dict(
name=name + "-b",
shape=b.shape,
param_dim=param_dim,
bits=param_bytes * 8,
size=param_dim * param_bytes,
size_cum=0,
mean="{:.2g}".format(value.mean()),
std="{:.2g}".format(value.std()),
)
)
param_stats = get_param_stats(b.numpy())
param_stats["name"] = name + "-b"
params.append(param_stats)

@contextlib.contextmanager
def adjust_stats(module, training=False):
"""Adjust module to training/eval mode temporarily.

Args:
module (M.Module): used module.
training (bool): training mode. True for train mode, False fro eval mode.
"""

def recursive_backup_stats(module, mode):
for m in module.modules():
# save prev status to _prev_training
m._prev_training = m.training
m.train(mode, recursive=False)

def recursive_recover_stats(module):
for m in module.modules():
# recover prev status and delete attribute
m.training = m._prev_training
delattr(m, "_prev_training")

recursive_backup_stats(module, mode=training)
yield module
recursive_recover_stats(module)
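
An illustrative use of this helper outside of module_stats (model and inputs are assumed to exist and are not part of the diff):

model.train()                                   # submodules currently in train mode
with adjust_stats(model, training=False) as m:
    m(inputs)                                   # forward pass runs in eval mode
assert model.training                           # the previous mode is restored on exit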


    # multiple inputs to the network
    if not isinstance(input_size[0], tuple):
@@ -286,15 +371,28 @@ def module_stats(
    )

    inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
model.eval()
model(*inputs)
with adjust_stats(model, training=False) as model:
model(*inputs)

    for h in hooks:
        h.remove()


total_flops, total_params = 0, 0
extra_info = {
"#params": len(params),
}
total_flops, total_param_dims, total_param_size = 0, 0, 0
    if log_params:
        total_params = print_params_stats(params, bar_length_max)
        total_param_dims, total_param_size = print_param_stats(params, bar_length_max)
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
    if log_flops:
        total_flops = print_flops_stats(flops, bar_length_max)
        total_flops = print_op_stats(flops, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)

print_summary(**extra_info)


return total_params, total_flops
return total_param_size, total_flops
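
A minimal call of the updated interface might look as follows (illustrative; the model and the input size are placeholders, and the import path assumes this file stays at megengine/utils/module_stats.py):

from megengine.utils.module_stats import module_stats

total_param_size, total_flops = module_stats(model, (1, 3, 224, 224))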

+ 108
- 60
imperative/python/megengine/utils/network.py View File

@@ -11,12 +11,14 @@ import fnmatch
import itertools import itertools
import re import re
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List
from typing import Dict, List, Sequence


import numpy as np import numpy as np


from ..core._imperative_rt import ComputingGraph from ..core._imperative_rt import ComputingGraph
from ..core._imperative_rt.core2 import SymbolVar
from ..core.tensor import megbrain_graph as G from ..core.tensor import megbrain_graph as G
from ..logger import get_logger
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
from .network_node import ( from .network_node import (
Host2DeviceCopy, Host2DeviceCopy,
@@ -27,6 +29,8 @@ from .network_node import (
str_to_mge_class, str_to_mge_class,
) )


logger = get_logger(__name__)



class Network: class Network:
def __init__(self): def __init__(self):
@@ -60,12 +64,12 @@ class Network:
) )
outputs = [new_outputs[i] for i in outspec] outputs = [new_outputs[i] for i in outspec]
self._orig_outputs = outputs self._orig_outputs = outputs
self.add_dep_oprs(*outputs)
for x in self._orig_outputs:
self.output_vars.append(self._get_var(x))
self.add_dep_oprs()
for x in self._orig_inputs: for x in self._orig_inputs:
self.input_vars.append(self._get_var(x)) self.input_vars.append(self._get_var(x))


for x in self._orig_outputs:
self.output_vars.append(self._get_var(x))
self.graph = self._orig_outputs[0].graph self.graph = self._orig_outputs[0].graph
return self return self


@@ -83,6 +87,58 @@ class Network:
for o in opr.outputs: for o in opr.outputs:
self.all_vars_map[o.var.id] = o self.all_vars_map[o.var.id] = o


def optimize_for_inference(self, dest_vars, **kwargs):
r"""
Applies optimize_for_inference pass for operator graph.

:param dest_vars: list of output vars in the operator graph

:Keyword Arguments:

* enable_io16xc32 --
whether to use float16 for I/O between oprs and use
float32 as internal computation precision. Note the output var would be
changed to float16.
* enable_ioc16 --
whether to use float16 for both I/O and computation
precision.

* enable_hwcd4 --
whether to use NHWCD4 data layout. This is faster on some
OpenCL backend.
* enable_nchw88 --
whether to use NCHW88 data layout, currently
used in X86 AVX backend.
* enable_nchw44 --
whether to use NCHW44 data layout, currently
used in arm backend.
* enable_nchw44_dot --
whether to use NCHW44_dot data layout, currently
used in armv8.2+dotprod backend.
* enable_nchw4 --
whether to use NCHW4 data layout, currently
used in nvidia backend(based on cudnn).
* enable_nchw32 --
whether to use NCHW32 data layout, currently
used in nvidia backend with tensorcore(based on cudnn).
* enable_chwn4 --
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.

        * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
          into one opr.
        * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
          input for inference on nvidia backend (this optimization pass will
          cause a precision mismatch between the outputs of training and
          inference).
        """

if not isinstance(dest_vars, Sequence):
dest_vars = [dest_vars]
dest_vars = list(G.VarNode(var.var) for var in dest_vars)
new_vars = G.optimize_for_inference(dest_vars, **kwargs)
return list(self._get_var(var) for var in new_vars)
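
A hedged usage sketch of the new pass (illustrative only; the file names are placeholders and the loading call assumes the Network loading entry point defined elsewhere in this file):

net = Network.load("model.mge")
fp16_outputs = net.optimize_for_inference(net.output_vars, enable_io16xc32=True)
net.output_vars = fp16_outputs          # hypothetical re-wiring; adapt to your workflow
net.dump("model_opt.mge", optimize_for_inference=False)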

def dump( def dump(
self, self,
file, file,
@@ -122,47 +178,22 @@ class Network:


:Keyword Arguments: :Keyword Arguments:


* enable_io16xc32 --
whether to use float16 for I/O between oprs and use
float32 as internal computation precision. Note the output var would be
changed to float16.
* enable_ioc16 --
whether to use float16 for both I/O and computation
precision.

* enable_hwcd4 --
whether to use NHWCD4 data layout. This is faster on some
OpenCL backend.
* enable_nchw88 --
whether to use NCHW88 data layout, currently
used in X86 AVX backend.
* enable_nchw44 --
whether to use NCHW44 data layout, currently
used in arm backend.
* enable_nchw44_dot --
whether to use NCHW44_dot data layout, currently
used in armv8.2+dotprod backend.
* enable_nchw4 --
whether to use NCHW4 data layout, currently
used in nvidia backend(based on cudnn).
* enable_nchw32 --
whether to use NCHW32 data layout, currently
used in nvidia backend with tensorcore(based on cudnn).
* enable_chwn4 --
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.

* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
into one opr.
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
input for inference on nvidia backend(this optimization pass will
result in mismatch of the precision of output of training and
inference)
See also :py:meth:`optimize_for_inference`.

""" """


self._compile() self._compile()
out = [G.VarNode(var.var) for var in self.output_vars] out = [G.VarNode(var.var) for var in self.output_vars]


if kwargs.pop("arg_names", False):
logger.warning(
'"arg_names" is not supported in Network.dump, rename input vars directly'
)
if kwargs.pop("output_names", False):
logger.warning(
'"output_names" is not supported in Network.dump, rename output vars directly'
)

if optimize_for_inference: if optimize_for_inference:
out = G.optimize_for_inference(out, **kwargs) out = G.optimize_for_inference(out, **kwargs)


@@ -197,6 +228,8 @@ class Network:
def add_output(self, *vars: VarNode): def add_output(self, *vars: VarNode):
"""Adds vars into the network output node list """Adds vars into the network output node list
""" """
if not all([var.owner for var in vars]):
self.add_dep_oprs(*vars)
for var in vars: for var in vars:
if var not in self.output_vars: if var not in self.output_vars:
self.output_vars.append(var) self.output_vars.append(var)
@@ -209,21 +242,25 @@ class Network:
self.output_vars.remove(var) self.output_vars.remove(var)


def add_dep_oprs(self, *vars): def add_dep_oprs(self, *vars):
"""Adds dependent opnodes and varnodes of vars into network
"""
oprs = get_oprs_seq(vars, False, False)
for mge_opr in oprs:
if len(vars) == 0:
vars = self.output_vars
q = list(vars)
while len(q) > 0:
cur = q.pop(0)
if cur.owner is not None:
continue
if cur.name is None:
cur.name = cur.var.name
self.all_vars_map[cur.var.id] = cur
mge_opr = cur.var.owner
if get_opr_type(mge_opr) == "Host2DeviceCopy": if get_opr_type(mge_opr) == "Host2DeviceCopy":
self._orig_inputs.extend(mge_opr.outputs) self._orig_inputs.extend(mge_opr.outputs)
opr = self._add_opr(mge_opr)
if opr is not None:
for x in mge_opr.inputs:
opr.add_inp_var(self._get_var(x))
# set out var
for x in mge_opr.outputs:
opr.add_out_var(self._get_var(x))

return [self.all_vars_map[var.id] for var in vars]
cur.owner = self._add_opr(mge_opr)
if cur.owner is None:
cur.owner = self.all_oprs_map[mge_opr.id]
continue
q.extend(cur.owner.inputs)
return list(vars)


def modify_opr_names(self, modifier): def modify_opr_names(self, modifier):
"""Modifies names of operators **inplace**; useful for merging loaded """Modifies names of operators **inplace**; useful for merging loaded
@@ -275,6 +312,9 @@ class Network:
Replaces vars in the graph. Replaces vars in the graph.
:param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars.
""" """
        if not all([var.owner for var in repl_dict.values()]):
            self.add_dep_oprs(*list(repl_dict.values()))
for var in self.all_vars: for var in self.all_vars:
if var in repl_dict: if var in repl_dict:
repl_var = repl_dict[var] repl_var = repl_dict[var]
@@ -282,6 +322,7 @@ class Network:
idx = owner.outputs.index(repl_var) idx = owner.outputs.index(repl_var)
owner.outputs[idx] = var owner.outputs[idx] = var
var.__dict__.update(repl_var.__dict__) var.__dict__.update(repl_var.__dict__)
var.var = repl_var.var


def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
""" """
@@ -297,6 +338,7 @@ class Network:
for ind, var in enumerate(opr.outputs): for ind, var in enumerate(opr.outputs):
var.owner = repl_dict[opr] var.owner = repl_dict[opr]
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
var.var = repl_dict[opr].outputs[ind].var


def get_opr_by_type(self, oprcls, unique=True): def get_opr_by_type(self, oprcls, unique=True):
assert issubclass(oprcls, OpNode) assert issubclass(oprcls, OpNode)
@@ -381,11 +423,16 @@ class Network:
return self.opr_filter.as_dict() return self.opr_filter.as_dict()


# used for loading and building graph # used for loading and building graph
def _add_opr(self, x):
def _add_opr(self, opr):
# TODO: use megbrain C++ RTTI to replace type string # TODO: use megbrain C++ RTTI to replace type string
if x.id not in self.all_oprs_map:
self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x)
return self.all_oprs_map[x.id]
if opr.id not in self.all_oprs_map:
opnode = str_to_mge_class(get_opr_type(opr)).load(opr)
self.all_oprs_map[opr.id] = opnode
for var in opr.inputs:
opnode.add_inp_var(self._get_var(var))
for var in opr.outputs:
opnode.add_out_var(self._get_var(var))
return opnode
else: else:
return None return None


@@ -397,7 +444,7 @@ class Network:


def _get_var(self, x): def _get_var(self, x):
# auto convert to VarNode of Network # auto convert to VarNode of Network
if x.id not in self.all_vars_map:
if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x:
self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner))
return self.all_vars_map[x.id] return self.all_vars_map[x.id]


@@ -652,7 +699,7 @@ class NodeFilterHasInput(NodeFilter):
assert isinstance( assert isinstance(
i, OpNode i, OpNode
), "has_input() must be used with OpNode; " "got {!r}".format(i) ), "has_input() must be used with OpNode; " "got {!r}".format(i)
if self.var in i.inputs:
if any(self.var is _ for _ in i.inputs):
yield i yield i




@@ -663,6 +710,7 @@ class NodeFilterName(NodeFilter):


def __init__(self, node_iter, pattern, ignorecase): def __init__(self, node_iter, pattern, ignorecase):
super().__init__(node_iter) super().__init__(node_iter)
self.pattern = pattern
self._re = self.make_re(pattern, ignorecase) self._re = self.make_re(pattern, ignorecase)


@classmethod @classmethod
@@ -676,5 +724,5 @@ class NodeFilterName(NodeFilter):


def __iter__(self): def __iter__(self):
for i in self._iter: for i in self._iter:
if self._re.match(i.name):
if self.pattern == i.name or self._re.match(i.name):
yield i yield i

+ 119
- 56
imperative/python/megengine/utils/network_node.py View File

@@ -6,27 +6,41 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import json import json
import sys import sys
from typing import Callable
from typing import Callable, Sequence


import numpy as np import numpy as np


from ..core import _imperative_rt as rt from ..core import _imperative_rt as rt
from ..core._imperative_rt.core2 import SymbolVar
from ..core._wrap import Device from ..core._wrap import Device
from ..core.ops import builtin from ..core.ops import builtin
from ..core.tensor.megbrain_graph import InputNode
from ..core.tensor.array_method import ArrayMethodMixin
from ..core.tensor.indexing import getitem as _getitem
from ..core.tensor.indexing import setitem as _setitem
from ..core.tensor.megbrain_graph import InputNode, OutputNode
from ..tensor import Tensor from ..tensor import Tensor
from .comp_graph_tools import replace_vars from .comp_graph_tools import replace_vars
from .module_stats import (
preprocess_receptive_field,
register_flops,
register_receptive_field,
)




class NetworkNode: class NetworkNode:
pass pass




class VarNode(NetworkNode):
def __init__(self, owner_opr=None, name=None):
self.var = None
class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
pass


class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
def __init__(self, var=None, *, owner_opr=None, name=None):
SymbolVar.__init__(self, var)
self.owner = owner_opr self.owner = owner_opr
self.name = name self.name = name
self.id = id(self) self.id = id(self)
@@ -53,6 +67,40 @@ class VarNode(NetworkNode):
def dtype(self): def dtype(self):
return self.var.dtype if self.var else None return self.var.dtype if self.var else None


def __bool__(self):
return False

__index__ = None
__int__ = None
__float__ = None
__complex__ = None

def __hash__(self):
return id(self)

@property
def _tuple_shape(self):
return self.var.shape

def numpy(self):
o = OutputNode(self.var)
self.graph.compile(o.outputs).execute()
return o.get_value().numpy()

def __getitem__(self, index):
return _getitem(self, index)

def __setitem__(self, index, value):
if index is not Ellipsis:
value = _setitem(self, index, value)
if self.owner is not None:
idx = self.owner.outputs.index(self)
self.owner.outputs[idx] = VarNode(
self.var, owner_opr=self.owner, name=self.var.name
)
self.var = value.var
self.owner = None

def set_owner_opr(self, owner_opr): def set_owner_opr(self, owner_opr):
self.owner = owner_opr self.owner = owner_opr


@@ -130,7 +178,7 @@ class Host2DeviceCopy(OpNode):
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
self._opr = outputs.owner self._opr = outputs.owner
if len(self.outputs) == 0: if len(self.outputs) == 0:
self.outputs.append(VarNode(self, self.name))
self.outputs.append(VarNode(owner_opr=self, name=self.name))
self.outputs[0].var = outputs self.outputs[0].var = outputs
assert self.outputs[0].owner is self assert self.outputs[0].owner is self


@@ -168,8 +216,8 @@ class ImmutableTensor(OpNode):
def set_value(self, data, device=None): def set_value(self, data, device=None):
assert self.graph is not None assert self.graph is not None
cn = device if device else self.device cn = device if device else self.device
assert isinstance(data, (int, float, np.ndarray))
if isinstance(data, (int, float)):
assert isinstance(data, (int, float, Sequence, np.ndarray))
if not isinstance(data, np.ndarray):
data = np.array(data) data = np.array(data)
if data.dtype == np.float64: if data.dtype == np.float64:
data = data.astype(np.float32) data = data.astype(np.float32)
@@ -177,7 +225,7 @@ class ImmutableTensor(OpNode):
data = data.astype(np.int32) data = data.astype(np.int32)
varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
if len(self.outputs) == 0: if len(self.outputs) == 0:
self.outputs.append(VarNode(self, self.name))
self.outputs.append(VarNode(owner_opr=self, name=self.name))
self.outputs[0].var = varnode self.outputs[0].var = varnode
self._opr = varnode.owner self._opr = varnode.owner


@@ -225,8 +273,21 @@ class Elemwise(OpNode):
type = "Elemwise" type = "Elemwise"
opdef = builtin.Elemwise opdef = builtin.Elemwise


def calc_flops(self):
return np.prod(self.outputs[0].shape)

class ElemwiseMultiType(OpNode):
type = "ElemwiseMultiType"
opdef = builtin.ElemwiseMultiType

@classmethod
def load(cls, opr):
obj = super(ElemwiseMultiType, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
return obj


@register_flops(Elemwise, ElemwiseMultiType)
def flops_elemwise(opnode: Elemwise, inputs, outputs):
return np.prod(outputs[0].shape)




class Reduce(OpNode): class Reduce(OpNode):
@@ -255,20 +316,24 @@ class MatrixMul(OpNode):
type = "MatrixMul" type = "MatrixMul"
opdef = builtin.MatrixMul opdef = builtin.MatrixMul


def calc_flops(self):
assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2
mid_shape = self.inputs[0].shape[1]
return np.prod(self.outputs[0].shape) * mid_shape

@register_flops(MatrixMul)
def flops_matmul(opnode: MatrixMul, inputs, outputs):
assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
mid_shape = inputs[0].shape[1]
return np.prod(outputs[0].shape) * mid_shape
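
Worked example (illustrative): multiplying a (64, 128) matrix by a (128, 256) matrix gives outputs[0].shape == (64, 256) and mid_shape == 128, so the estimate is 64 * 256 * 128 = 2,097,152 OPs.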




class BatchedMatrixMul(OpNode): class BatchedMatrixMul(OpNode):
type = "BatchedMatmul" type = "BatchedMatmul"
opdef = builtin.BatchedMatrixMul opdef = builtin.BatchedMatrixMul


def calc_flops(self):
assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3
mid_shape = self.inputs[0].shape[2]
return np.prod(self.outputs[0].shape) * mid_shape

@register_flops(BatchedMatrixMul)
def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
mid_shape = inputs[0].shape[2]
return np.prod(outputs[0].shape) * mid_shape




class Dot(OpNode): class Dot(OpNode):
@@ -285,18 +350,6 @@ class ConvolutionForward(OpNode):
type = "Convolution" type = "Convolution"
opdef = builtin.Convolution opdef = builtin.Convolution


def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh)
return NCHW * (num_input * kw * kh)



class ConvolutionBackwardData(OpNode): class ConvolutionBackwardData(OpNode):
type = "ConvTranspose" type = "ConvTranspose"
@@ -343,17 +396,41 @@ class ConvBiasForward(OpNode):
obj.params["dtype"] = opr.outputs[0].dtype obj.params["dtype"] = opr.outputs[0].dtype
return obj return obj


def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return NCHW * (num_input * kw * kh + 1)

@register_flops(
ConvolutionForward, ConvBiasForward,
)
def flops_conv(opnode: ConvolutionForward, inputs, outputs):
param_W_shape = inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(outputs[0].shape)
bias = 1 if isinstance(opnode, ConvBiasForward) else 0
# N x Cout x H x W x (Cin x Kw x Kh)
return NCHW * (num_input * kw * kh + bias)


@register_receptive_field(ConvolutionForward, ConvBiasForward)
def receptive_field(opnode: ConvolutionForward, inputs, outputs):
pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
param_W_shape = inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
rf = (
kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
)
stride = (
opnode.params["stride_h"] * pre_stride[0],
opnode.params["stride_w"] * pre_stride[1],
)
opnode._rf = rf
opnode._stride = stride
return rf, stride
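
Worked example of the receptive-field recursion above (illustrative): a 7x7 convolution with stride 2 applied to the network input sees pre_rf = (1, 1) and pre_stride = (1, 1), so rf = (7, 7) and stride = (2, 2); a following 3x3 convolution with stride 1 then gets rf = (3 * 2 + 7 - 2, 3 * 2 + 7 - 2) = (11, 11) and stride = (1 * 2, 1 * 2) = (2, 2).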




class BatchConvBiasForward(OpNode): class BatchConvBiasForward(OpNode):
@@ -652,20 +729,6 @@ class AssertEqual(OpNode):
opdef = builtin.AssertEqual opdef = builtin.AssertEqual




class ElemwiseMultiType(OpNode):
type = "ElemwiseMultiType"
opdef = builtin.ElemwiseMultiType

@classmethod
def load(cls, opr):
obj = super(ElemwiseMultiType, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
return obj

def calc_flops(self):
return np.prod(self.outputs[0].shape)


class CvtColorForward(OpNode): class CvtColorForward(OpNode):
type = "CvtColor" type = "CvtColor"
opdef = builtin.CvtColor opdef = builtin.CvtColor

+ 1
- 1
imperative/python/src/graph_rt.cpp View File

@@ -266,7 +266,7 @@ void init_graph_rt(py::module m) {
{"HEURISTIC", [&]() { stg = _AlgoStrategy::HEURISTIC; }}, {"HEURISTIC", [&]() { stg = _AlgoStrategy::HEURISTIC; }},
{"PROFILE", [&]() { stg = _AlgoStrategy::PROFILE; }}, {"PROFILE", [&]() { stg = _AlgoStrategy::PROFILE; }},
{"REPRODUCIBLE", [&]() { stg = _AlgoStrategy::REPRODUCIBLE; }}, {"REPRODUCIBLE", [&]() { stg = _AlgoStrategy::REPRODUCIBLE; }},
{"OPTMIZED", [&]() { stg = _AlgoStrategy::OPTMIZED; }},
{"OPTIMIZED", [&]() { stg = _AlgoStrategy::OPTIMIZED; }},
}; };
auto it = m.find(strategy); auto it = m.find(strategy);
mgb_assert(it != m.end(), "Invalid strategy string!"); mgb_assert(it != m.end(), "Invalid strategy string!");


+ 29
- 12
imperative/python/src/ops.cpp View File

@@ -87,9 +87,13 @@ struct pyobj_convert_generic {
} }
}; };


template<typename T, typename SFINAE=void>
struct EnumTrait;

template <typename T> template <typename T>
struct EnumTrait {
struct EnumTrait<T, std::enable_if_t<std::is_enum_v<T>>> {
static constexpr bool is_bit_combined = false; static constexpr bool is_bit_combined = false;
static constexpr std::underlying_type_t<T> max = 0;
}; };


template <typename T> template <typename T>
@@ -264,18 +268,25 @@ struct BitCombinedEnumWrapper {
return ret; return ret;
} }
} }
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) {
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1);
return obj;
}
static int py_init(PyObject* self, PyObject* args, PyObject*) {
int input = 1;
if (PyArg_ParseTuple(args, "|i", &input)){
reinterpret_cast<BitCombinedEnumWrapper*>(self)->value =
static_cast<T>(input);
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) {
if (!PyTuple_Size(args)) {
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
return obj;
}
else {
PyObject* input;
if (!PyArg_ParseTuple(args, "|O", &input)) {
return nullptr;
}
T value;
try {
value = pyobj_convert_generic<T>::from(input);
} CATCH_ALL(nullptr);
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
return obj;
} }
return 0;
} }
static PyObject* py_repr(PyObject* self) { static PyObject* py_repr(PyObject* self) {
return pyobj_convert_generic<std::string>::to( return pyobj_convert_generic<std::string>::to(
@@ -325,6 +336,12 @@ struct pyobj_convert_generic<T,
static T from(PyObject* obj) { static T from(PyObject* obj) {
if (PyObject_TypeCheck(obj, &Wrapper::type)) { if (PyObject_TypeCheck(obj, &Wrapper::type)) {
return reinterpret_cast<Wrapper*>(obj)->value; return reinterpret_cast<Wrapper*>(obj)->value;
} else if(PyLong_Check(obj)) {
auto value = pyobj_convert_generic<std::underlying_type_t<T>>::from(obj);
mgb_throw_if(value > EnumTrait<T>::max, mgb::MegBrainError,
"out of range, cannot convert %zu to %s",
static_cast<uint32_t>(value), Wrapper::name);
return static_cast<T>(value);
} }
        // try as string
        // TODO: type check


+ 57
- 19
imperative/python/src/tensor.cpp View File

@@ -160,16 +160,21 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
if (ctx.op->same_type<BackwardGraph>()) { if (ctx.op->same_type<BackwardGraph>()) {
ctx.backward = true; ctx.backward = true;
} }
if (py::isinstance<cg::VarNode>(py::handle(args[0]))){
SmallVector<cg::VarNode*> vinputs(nargs);
for (size_t i = 0; i < nargs; ++i) {
vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>();
}
auto op = ctx.op.get();
return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr();
}

if (py::isinstance<PySymbolVar>(py::handle(args[0]))){
SmallVector<cg::VarNode*> vinputs(nargs);
for (size_t i = 0; i < nargs; ++i) {
vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
}
auto op = ctx.op.get();
auto rst = OpDef::apply_on_var_node(*op, vinputs);
auto ret = pybind11::tuple(rst.size());
auto typeobj = py::handle(args[0]).get_type();
for (size_t i = 0; i<rst.size(); ++i) {
ret[i] = typeobj(pybind11::cast(rst[i], pybind11::return_value_policy::automatic));
}
return ret.release().ptr();
}


for (size_t i = 0; i < nargs; ++i) { for (size_t i = 0; i < nargs; ++i) {
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
@@ -686,9 +691,9 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
continue; continue;
} }


if (py::isinstance<cg::VarNode>(py::handle(handle))){
auto var = py::handle(handle).cast<cg::VarNode *>();
mgb::DType type = var->dtype();
if (py::isinstance<PySymbolVar>(py::handle(handle))){
auto var = py::handle(handle).cast<PySymbolVar*>();
mgb::DType type = var->m_node->dtype();
auto && descr = npy::dtype_mgb2np_descr(type); auto && descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get()); Py_INCREF(descr.get());
tensors.emplace_back(descr.get()); tensors.emplace_back(descr.get());
@@ -737,19 +742,26 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
bool valid = false; bool valid = false;
CompNode cn; CompNode cn;
for (size_t i = 0; i < nargs; ++i) { for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
TensorWrapper* tw = TensorWrapper::try_cast(handle); TensorWrapper* tw = TensorWrapper::try_cast(handle);


bool is_var = py::isinstance<cg::VarNode>(py::handle(handle));
if (tw || is_var) {
bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
if (tw || is_symvar) {
if (!valid) { if (!valid) {
cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node();
cn = tw ? tw->m_tensor->comp_node()
: py::handle(handle)
.cast<PySymbolVar*>()
->m_node->comp_node();
valid = true; valid = true;
} else { } else {
CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node();
CompNode cn1 = tw ? tw->m_tensor->comp_node()
: py::handle(handle)
.cast<PySymbolVar*>()
->m_node->comp_node();
if (cn1 != cn) { if (cn1 != cn) {
throw py::value_error(ssprintf("ambiguous device: %s vs %s", throw py::value_error(ssprintf("ambiguous device: %s vs %s",
cn.to_string().c_str(), cn1.to_string().c_str()));
cn.to_string().c_str(),
cn1.to_string().c_str()));
} }
} }
} }
@@ -849,6 +861,32 @@ void init_tensor(py::module m) {
.def("__call__", &TensorWeakRef::operator()) .def("__call__", &TensorWeakRef::operator())
.def("_use_cnt", &TensorWeakRef::_use_cnt); .def("_use_cnt", &TensorWeakRef::_use_cnt);


py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
.def_property_readonly(
"dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
.def_property("var", [](PySymbolVar* v) { return v->m_node; },
[](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
.def_property_readonly(
"device",
[](PySymbolVar* v) { return v->m_node->comp_node(); })
.def_property_readonly(
"graph",
[](PySymbolVar* v) { return v->m_node->owner_graph(); })
.def_property_readonly(
"shape",
[](PySymbolVar* v) -> const TensorShape* {
auto&& mgr = v->m_node->owner_graph()
->static_infer_manager();
return mgr.infer_shape_fallible(v->m_node);
})
.def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
.def("_setscalar",
[](PySymbolVar* v) { return v->is_scalar = true; })
.def(py::init([](cg::VarNode* node) {
return std::make_shared<PySymbolVar>(node);
}),
py::arg() = nullptr);

static PyMethodDef method_defs[] = { static PyMethodDef method_defs[] = {
MGE_PY_INTERFACE(apply, py_apply), MGE_PY_INTERFACE(apply, py_apply),
MGE_PY_INTERFACE(dtype_promotion, dtype_promotion), MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),


+ 6
- 0
imperative/python/src/tensor.h View File

@@ -181,6 +181,12 @@ struct TensorWrapper {
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
}; };


struct PySymbolVar {
cg::VarNode* m_node = nullptr;
bool is_scalar = false;
PySymbolVar() = default;
PySymbolVar(VarNode *m): m_node(m){}
};


PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */); PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */);




+ 21
- 4
imperative/python/test/helpers/utils.py View File

@@ -2,9 +2,11 @@ import io


import numpy as np import numpy as np


import megengine.core.tensor.megbrain_graph as G
import megengine.utils.comp_graph_tools as cgtools import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor from megengine import tensor
from megengine.jit import trace from megengine.jit import trace
from megengine.utils.network_node import VarNode




def _default_compare_fn(x, y): def _default_compare_fn(x, y):
@@ -14,8 +16,23 @@ def _default_compare_fn(x, y):
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)




def make_tensor(x, network=None, device=None):
if network is not None:
if isinstance(x, VarNode):
return VarNode(x.var)
return network.make_const(x, device=device)
else:
return tensor(x, device=device)


def opr_test( def opr_test(
cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs
cases,
func,
compare_fn=_default_compare_fn,
ref_fn=None,
test_trace=True,
network=None,
**kwargs
): ):
""" """
:param cases: the list which have dict element, the list length should be 2 for dynamic shape test. :param cases: the list which have dict element, the list length should be 2 for dynamic shape test.
@@ -44,7 +61,7 @@ def opr_test(
if not isinstance(results, (tuple, list)): if not isinstance(results, (tuple, list)):
results = (results,) results = (results,)
for r, e in zip(results, expected): for r, e in zip(results, expected):
if not isinstance(r, tensor):
if not isinstance(r, (tensor, VarNode)):
r = tensor(r) r = tensor(r)
compare_fn(r, e) compare_fn(r, e)


@@ -72,9 +89,9 @@ def opr_test(
raise ValueError("the input func should be callable") raise ValueError("the input func should be callable")


inp, outp = get_param(cases, 0) inp, outp = get_param(cases, 0)
inp_tensor = [tensor(inpi) for inpi in inp]
inp_tensor = [make_tensor(inpi, network) for inpi in inp]


if test_trace:
if test_trace and not network:
copied_inp = inp_tensor.copy() copied_inp = inp_tensor.copy()
for symbolic in [False, True]: for symbolic in [False, True]:
traced_func = trace(symbolic=symbolic)(func) traced_func = trace(symbolic=symbolic)(func)


+ 4
- 0
imperative/python/test/integration/test_optimizer.py View File

@@ -104,6 +104,10 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
) )
step += 1 step += 1
check_func(ori_params, net.parameters(), step) check_func(ori_params, net.parameters(), step)
try_state_dict = {
"net": net.state_dict(),
"opt": opt.state_dict(),
}




def test_sgd(): def test_sgd():


+ 79
- 52
imperative/python/test/unit/core/test_indexing_op.py View File

@@ -10,12 +10,17 @@ import collections


import numpy as np import numpy as np
import pytest import pytest
from utils import make_tensor


import megengine import megengine
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
from megengine.core._imperative_rt.core2 import apply from megengine.core._imperative_rt.core2 import apply
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops import builtin from megengine.core.ops import builtin
from megengine.tensor import Tensor from megengine.tensor import Tensor
from megengine.utils.network import Network
from megengine.utils.network_node import VarNode




def cvt_to_shape_desc(val, inpvar, config=None): def cvt_to_shape_desc(val, inpvar, config=None):
@@ -387,108 +392,130 @@ def test_batched_mesh_indexing():




# high level # high level
def get_value(x):
if isinstance(x, VarNode):
var = x.var
o = G.OutputNode(var)
graph = x.graph
graph.compile(o.outputs).execute()
return o.get_value().numpy()
else:
return x.numpy()


@pytest.mark.parametrize("test_varnode", [True, False])
def test_advance_indexing_high_level(test_varnode):
if test_varnode:
network = Network()
else:
network = None



def test_advance_indexing_high_level():
x = np.arange(25).reshape(5, 5).astype("int32") x = np.arange(25).reshape(5, 5).astype("int32")
d = np.arange(15).reshape(3, 5).astype("int32") d = np.arange(15).reshape(3, 5).astype("int32")
xx = Tensor(x)
xx = make_tensor(x, network)


np.testing.assert_equal(x[1, :], xx[1, :].numpy())
np.testing.assert_equal(x[:, 1], xx[:, 1].numpy())
np.testing.assert_equal(x[1:3, :], xx[1:3, :].numpy())
np.testing.assert_equal(x[1, :], get_value(xx[1, :]))
np.testing.assert_equal(x[:, 1], get_value(xx[:, 1]))
np.testing.assert_equal(x[1:3, :], get_value(xx[1:3, :]))


np.testing.assert_equal(x[:, :], xx[:, :].numpy())
np.testing.assert_equal(x[1, 1], xx[1, 1].numpy())
np.testing.assert_equal(x[:, :], get_value(xx[:, :]))
np.testing.assert_equal(x[1, 1], get_value(xx[1, 1]))
yy = xx[(0, 4, 2), :] yy = xx[(0, 4, 2), :]
np.testing.assert_equal(x[(0, 4, 2), :], yy.numpy())
np.testing.assert_equal(x[(0, 4, 2), :], get_value(yy))


x_ = x.copy() x_ = x.copy()
x_[(0, 4, 2), :] = d x_[(0, 4, 2), :] = d
xx_ = Tensor(xx)
xx_ = make_tensor(xx, network)
xx_[(0, 4, 2), :] = d xx_[(0, 4, 2), :] = d
np.testing.assert_equal(x_, xx_.numpy())
np.testing.assert_equal(x_, get_value(xx_))


x = np.arange(27).reshape(3, 3, 3).astype("int32") x = np.arange(27).reshape(3, 3, 3).astype("int32")
xx = Tensor(x)
xx = make_tensor(x, network)


np.testing.assert_equal(x[1, :, :], xx[1, :, :].numpy())
np.testing.assert_equal(x[1, :, 1], xx[1, :, 1].numpy())
np.testing.assert_equal(x[1, 0:1, :], xx[1, 0:1, :].numpy())
np.testing.assert_equal(x[0:1, 1, 1], xx[0:1, 1, 1].numpy())
np.testing.assert_equal(x[:, 1, 1], xx[:, 1, 1].numpy())
np.testing.assert_equal(x[:, 1], xx[:, 1].numpy())
np.testing.assert_equal(x[1, 1:2], xx[1, 1:2].numpy())
np.testing.assert_equal(x[1, :, :], get_value(xx[1, :, :]))
np.testing.assert_equal(x[1, :, 1], get_value(xx[1, :, 1]))
np.testing.assert_equal(x[1, 0:1, :], get_value(xx[1, 0:1, :]))
np.testing.assert_equal(x[0:1, 1, 1], get_value(xx[0:1, 1, 1]))
np.testing.assert_equal(x[:, 1, 1], get_value(xx[:, 1, 1]))
np.testing.assert_equal(x[:, 1], get_value(xx[:, 1]))
np.testing.assert_equal(x[1, 1:2], get_value(xx[1, 1:2]))


x_ = x.copy() x_ = x.copy()
x_[1, 1, 1] = -1 x_[1, 1, 1] = -1
xx[1, 1, 1] = -1 xx[1, 1, 1] = -1
np.testing.assert_equal(x_, xx.numpy())
np.testing.assert_equal(x_, get_value(xx))


x_[:, 1, 1] = -2 x_[:, 1, 1] = -2
xx[:, 1, 1] = x_[:, 1, 1] xx[:, 1, 1] = x_[:, 1, 1]
np.testing.assert_equal(x_, xx.numpy())
np.testing.assert_equal(x_, get_value(xx))


x_[0:1, :, 1] = -3 x_[0:1, :, 1] = -3
xx[0:1, :, 1] = x_[0:1, :, 1] xx[0:1, :, 1] = x_[0:1, :, 1]
np.testing.assert_equal(x_, xx.numpy())
np.testing.assert_equal(x_, get_value(xx))


x_[0:1, :, 1] = -4 x_[0:1, :, 1] = -4
y = Tensor(x_)
y = make_tensor(x_, network)
xx[0:1, :, 1] = y[0:1, :, 1] xx[0:1, :, 1] = y[0:1, :, 1]
np.testing.assert_equal(y.numpy(), xx.numpy())
np.testing.assert_equal(get_value(y), get_value(xx))


x[:] = 1 x[:] = 1
xx[:] = 1 xx[:] = 1
np.testing.assert_equal(x, xx.numpy())
np.testing.assert_equal(x, get_value(xx))


x = np.arange(9).reshape(3, 3).astype("int32") x = np.arange(9).reshape(3, 3).astype("int32")
xx = Tensor(x)
xx = make_tensor(x, network)
y = np.array([1, 2]) y = np.array([1, 2])
yy = Tensor(y)
np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
np.testing.assert_equal(x[:, y], xx[:, y].numpy())
np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
yy = make_tensor(y, network)
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]]))
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]]))
np.testing.assert_equal(x[:, y], get_value(xx[:, y]))
np.testing.assert_equal(x[:, y], get_value(xx[:, yy]))


x_ = x.copy() x_ = x.copy()
x_[:, y[0]] = -1 x_[:, y[0]] = -1
xx_ = Tensor(x_)
xx_ = make_tensor(x_, network)
xx[:, yy[0]] = xx_[:, yy[0]] xx[:, yy[0]] = xx_[:, yy[0]]
np.testing.assert_equal(x_, xx.numpy())
np.testing.assert_equal(x_, get_value(xx))


x_[:, y] = -1 x_[:, y] = -1
xx_ = Tensor(x_)
xx_ = make_tensor(x_, network)
xx[:, yy] = xx_[:, yy] xx[:, yy] = xx_[:, yy]
np.testing.assert_equal(x_, xx.numpy())
np.testing.assert_equal(x_, get_value(xx))


x = np.arange(9).reshape(3, 3).astype("int32") x = np.arange(9).reshape(3, 3).astype("int32")
xx = Tensor(x)
xx = make_tensor(x, network)
y = np.array([1]) y = np.array([1])
yy = Tensor(y)
np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
np.testing.assert_equal(x[:, y], xx[:, y].numpy())
yy = make_tensor(y, network)
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]]))
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]]))
np.testing.assert_equal(x[:, y], get_value(xx[:, y]))


np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
np.testing.assert_equal(x[:, y], get_value(xx[:, yy]))


x = np.arange(9).reshape(3, 3).astype("int32") x = np.arange(9).reshape(3, 3).astype("int32")
xx = Tensor(x)
np.testing.assert_equal(x[[0, 1], 0], xx[[0, 1], 0].numpy())
np.testing.assert_equal(x[0:2, 0], xx[0:2, 0].numpy())


def test_advance_indexing_with_bool():
xx = make_tensor(x, network)
np.testing.assert_equal(x[[0, 1], 0], get_value(xx[[0, 1], 0]))
np.testing.assert_equal(x[0:2, 0], get_value(xx[0:2, 0]))


@pytest.mark.parametrize(
"test_varnode", [True, False],
)
def test_advance_indexing_with_bool(test_varnode):
if test_varnode:
network = Network()
else:
network = None
a = np.arange(9).reshape(3, 3).astype(np.float32) a = np.arange(9).reshape(3, 3).astype(np.float32)
b = np.array([1, 2, 3]) b = np.array([1, 2, 3])
c = np.array([1, 2, 3]) c = np.array([1, 2, 3])
aa = Tensor(a)
bb = Tensor(b)
cc = Tensor(c)
np.testing.assert_equal(a[b == 1, c == 2], aa[bb == 1, cc == 2].numpy())
aa = make_tensor(a, network)
bb = make_tensor(b, network)
cc = make_tensor(c, network)
np.testing.assert_equal(a[b == 1, c == 2], get_value(aa[bb == 1, cc == 2]))
a[b == 1, c == 2] = -1.0 a[b == 1, c == 2] = -1.0
aa[bb == 1, cc == 2] = -1.0 aa[bb == 1, cc == 2] = -1.0
np.testing.assert_equal(a, aa.numpy())
np.testing.assert_equal(a, get_value(aa))


a = np.arange(9).reshape(3, 3).astype(np.float32) a = np.arange(9).reshape(3, 3).astype(np.float32)
b = np.array([False, True, True]) b = np.array([False, True, True])


+ 222
- 69
imperative/python/test/unit/functional/test_tensor.py View File

@@ -11,13 +11,16 @@ import platform


import numpy as np import numpy as np
import pytest import pytest
from utils import opr_test
from utils import make_tensor, opr_test


import megengine.functional as F import megengine.functional as F
from megengine import tensor from megengine import tensor
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.utils import astensor1d from megengine.core.tensor.utils import astensor1d
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
from megengine.utils.network import Network
from megengine.utils.network_node import VarNode




def test_eye(): def test_eye():
@@ -38,7 +41,13 @@ def test_eye():
) )




def test_concat():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat(is_varnode):
if is_varnode:
network = Network()
else:
network = None

def get_data_shape(length: int): def get_data_shape(length: int):
return (length, 2, 3) return (length, 2, 3)


@@ -50,18 +59,30 @@ def test_concat():
return F.concat([data1, data2]) return F.concat([data1, data2])


cases = [{"input": [data1, data2]}, {"input": [data1, data3]}] cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]))
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)




def test_concat_device():
data1 = tensor(np.random.random((3, 2, 2)).astype("float32"), device="cpu0")
data2 = tensor(np.random.random((2, 2, 2)).astype("float32"), device="cpu1")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat_device(is_varnode):
if is_varnode:
network = Network()
else:
network = None

data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")


out = F.concat([data1, data2], device="cpu0") out = F.concat([data1, data2], device="cpu0")
assert str(out.device).split(":")[0] == "cpu0" assert str(out.device).split(":")[0] == "cpu0"




def test_stack():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_stack(is_varnode):
if is_varnode:
network = Network()
else:
network = None

data1 = np.random.random((3, 2, 2)).astype("float32") data1 = np.random.random((3, 2, 2)).astype("float32")
data2 = np.random.random((3, 2, 2)).astype("float32") data2 = np.random.random((3, 2, 2)).astype("float32")
data3 = np.random.random((3, 2, 2)).astype("float32") data3 = np.random.random((3, 2, 2)).astype("float32")
@@ -72,12 +93,20 @@ def test_stack():
def run(data1, data2): def run(data1, data2):
return F.stack([data1, data2], axis=ai) return F.stack([data1, data2], axis=ai)


opr_test(cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai))
opr_test(
cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
)



@pytest.mark.parametrize("is_varnode", [True, False])
def test_split(is_varnode):
if is_varnode:
network = Network()
else:
network = None


def test_split():
data = np.random.random((2, 3, 4, 5)).astype(np.float32) data = np.random.random((2, 3, 4, 5)).astype(np.float32)
inp = tensor(data)
inp = make_tensor(data, network)


mge_out0 = F.split(inp, 2, axis=3) mge_out0 = F.split(inp, 2, axis=3)
mge_out1 = F.split(inp, [3], axis=3) mge_out1 = F.split(inp, [3], axis=3)
@@ -106,26 +135,42 @@ def test_split():
assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"




def test_reshape():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = np.arange(6, dtype="float32") x = np.arange(6, dtype="float32")
xx = tensor(x)
xx = make_tensor(x, network)
y = x.reshape(1, 2, 3) y = x.reshape(1, 2, 3)


for shape in [ for shape in [
(1, 2, 3), (1, 2, 3),
(1, -1, 3), (1, -1, 3),
(1, tensor(-1), 3),
(1, make_tensor(-1, network), 3),
np.array([1, -1, 3], dtype="int32"), np.array([1, -1, 3], dtype="int32"),
tensor([1, -1, 3]),
make_tensor([1, -1, 3], network),
]: ]:
yy = F.reshape(xx, shape) yy = F.reshape(xx, shape)
np.testing.assert_equal(yy.numpy(), y) np.testing.assert_equal(yy.numpy(), y)




def test_reshape_shape_inference():
x_shape_known = tensor([1, 2, 3, 4], dtype="float32")
x_shape_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum())
tshp_unknown = astensor1d((tensor([2]), tensor([2])), x_shape_known)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape_shape_inference(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x_shape_known = make_tensor([1, 2, 3, 4], network)
x_shape_unknown = F.broadcast_to(
make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
)
tshp_unknown = astensor1d(
(make_tensor([2], network), make_tensor([2], network)), x_shape_known
)
tshp_known = astensor1d((2, 2), x_shape_known) tshp_known = astensor1d((2, 2), x_shape_known)
tshp_known_unspec = astensor1d((2, -1), x_shape_known) tshp_known_unspec = astensor1d((2, -1), x_shape_known)


@@ -146,12 +191,18 @@ def test_reshape_shape_inference():
{"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]}, {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
] ]
opr_test(cases, func, compare_fn=check_shape, test_trace=True)
opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)



@pytest.mark.parametrize("is_varnode", [True, False])
def test_squeeze(is_varnode):
if is_varnode:
network = Network()
else:
network = None


def test_squeeze():
x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1) x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
xx = tensor(x)
xx = make_tensor(x, network)


for axis in [None, 3, -4, (3, -4)]: for axis in [None, 3, -4, (3, -4)]:
y = np.squeeze(x, axis) y = np.squeeze(x, axis)
@@ -159,9 +210,15 @@ def test_squeeze():
np.testing.assert_equal(y, yy.numpy()) np.testing.assert_equal(y, yy.numpy())




def test_expand_dims():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_expand_dims(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = np.arange(6, dtype="float32").reshape(2, 3) x = np.arange(6, dtype="float32").reshape(2, 3)
xx = tensor(x)
xx = make_tensor(x, network)


for axis in [2, -3, (3, -4), (1, -4)]: for axis in [2, -3, (3, -4), (1, -4)]:
y = np.expand_dims(x, axis) y = np.expand_dims(x, axis)
@@ -169,11 +226,17 @@ def test_expand_dims():
np.testing.assert_equal(y, yy.numpy()) np.testing.assert_equal(y, yy.numpy())




def test_elemwise_dtype_promotion():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_elemwise_dtype_promotion(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = np.random.rand(2, 3).astype("float32") x = np.random.rand(2, 3).astype("float32")
y = np.random.rand(1, 3).astype("float16") y = np.random.rand(1, 3).astype("float16")
xx = tensor(x)
yy = tensor(y)
xx = make_tensor(x, network)
yy = make_tensor(y, network)
z = xx * yy z = xx * yy
np.testing.assert_equal(z.numpy(), x * y) np.testing.assert_equal(z.numpy(), x * y)


@@ -184,7 +247,13 @@ def test_elemwise_dtype_promotion():
np.testing.assert_equal(z.numpy(), x - y) np.testing.assert_equal(z.numpy(), x - y)




def test_linspace():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_linspace(is_varnode):
if is_varnode:
network = Network()
else:
network = None

cases = [ cases = [
{"input": [1, 9, 9]}, {"input": [1, 9, 9]},
{"input": [3, 10, 8]}, {"input": [3, 10, 8]},
@@ -193,6 +262,7 @@ def test_linspace():
cases, cases,
F.linspace, F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
network=network,
) )


cases = [ cases = [
@@ -203,20 +273,28 @@ def test_linspace():
cases, cases,
F.linspace, F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
network=network,
) )


cases = [ cases = [
{"input": [1, tensor(9), 9]},
{"input": [tensor(1), 9, tensor(9)]},
{"input": [1, make_tensor(9, network), 9]},
{"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
] ]
opr_test( opr_test(
cases, cases,
F.linspace, F.linspace,
ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32), ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
network=network,
) )




def test_arange():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_arange(is_varnode):
if is_varnode:
network = Network()
else:
network = None

cases = [
{"input": [1, 9, 1]},
{"input": [2, 10, 2]},
@@ -225,6 +303,7 @@ def test_arange():
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
network=network,
)


cases = [
@@ -235,6 +314,7 @@ def test_arange():
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
network=network,
)


cases = [
@@ -245,20 +325,33 @@ def test_arange():
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
network=network,
)




def test_round():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_round(is_varnode):
if is_varnode:
network = Network()
else:
network = None

data1_shape = (15,)
data2_shape = (25,)
data1 = np.random.random(data1_shape).astype(np.float32)
data2 = np.random.random(data2_shape).astype(np.float32)


cases = [{"input": data1}, {"input": data2}]
opr_test(cases, F.round, ref_fn=np.round)
opr_test(cases, F.round, ref_fn=np.round, network=network)




def test_flatten():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_flatten(is_varnode):
if is_varnode:
network = Network()
else:
network = None

data0_shape = (2, 3, 4, 5)
data1_shape = (4, 5, 6, 7)
data0 = np.random.random(data0_shape).astype(np.float32)
@@ -273,7 +366,7 @@ def test_flatten():
{"input": data0, "output": output0},
{"input": data1, "output": output1},
]
opr_test(cases, F.flatten, compare_fn=compare_fn)
opr_test(cases, F.flatten, compare_fn=compare_fn, network=network)


output0 = (2, 3 * 4 * 5)
output1 = (4, 5 * 6 * 7)
@@ -281,7 +374,7 @@ def test_flatten():
{"input": data0, "output": output0},
{"input": data1, "output": output1},
]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1)
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network)


output0 = (2, 3, 4 * 5)
output1 = (4, 5, 6 * 7)
@@ -289,7 +382,7 @@ def test_flatten():
{"input": data0, "output": output0},
{"input": data1, "output": output1},
]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2)
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network)


output0 = (2, 3 * 4, 5)
output1 = (4, 5 * 6, 7)
@@ -297,10 +390,23 @@ def test_flatten():
{"input": data0, "output": output0},
{"input": data1, "output": output1},
]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2)
opr_test(
cases,
F.flatten,
compare_fn=compare_fn,
start_axis=1,
end_axis=2,
network=network,
)
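
F.flatten's start_axis/end_axis collapse one contiguous run of dimensions into a single axis; the expected shapes in the cases above follow the same rule as a NumPy reshape:

import numpy as np

x = np.zeros((2, 3, 4, 5), dtype="float32")
x.reshape(2, -1).shape       # start_axis=1: (2, 60)
x.reshape(2, 3, -1).shape    # start_axis=2: (2, 3, 20)
x.reshape(2, -1, 5).shape    # start_axis=1, end_axis=2: (2, 12, 5)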



@pytest.mark.parametrize("is_varnode", [True, False])
def test_broadcast(is_varnode):
if is_varnode:
network = Network()
else:
network = None


def test_broadcast():
input1_shape = (20, 30)
output1_shape = (30, 20, 30)
data1 = np.random.random(input1_shape).astype(np.float32)
@@ -309,14 +415,19 @@ def test_broadcast():
output2_shape = (20, 10, 20)
data2 = np.random.random(input2_shape).astype(np.float32)


input3_shape = (10, 10)
output3_shape = (10, 10)
data3 = np.random.random(input3_shape).astype(np.float32)

def compare_fn(x, y):
assert x.shape[0] == y


cases = [
{"input": [data1, output1_shape], "output": output1_shape},
{"input": [data2, output2_shape], "output": output2_shape},
{"input": [data3, output3_shape], "output": output3_shape},
]
opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)


x = F.ones((2, 1, 3))
with pytest.raises(RuntimeError):
@@ -329,35 +440,41 @@ def test_broadcast():
F.broadcast_to(x, (1, 3))
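
The new (10, 10) -> (10, 10) case checks that broadcasting to an identical shape is accepted, while broadcasts that would drop or shrink dimensions (the pytest.raises cases around it) stay rejected. F.broadcast_to follows NumPy's broadcasting rules, e.g.:

import numpy as np

x = np.ones((2, 1, 3), dtype="float32")
np.broadcast_to(x, (2, 1, 3)).shape     # identical shape: allowed
np.broadcast_to(x, (4, 2, 4, 3)).shape  # size-1 axes expand, leading axes prepend
try:
    np.broadcast_to(x, (1, 3))          # fewer dims than the source: rejected
except ValueError as err:
    print("rejected:", err)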




def test_utils_astensor1d():
reference = tensor(0)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_utils_astensor1d(is_varnode):
if is_varnode:
network = Network()
else:
network = None

reference = make_tensor(0, network)


# literal
x = [1, 2, 3]
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), x)


# numpy array
x = np.asarray([1, 2, 3], dtype="int32")
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x)


# tensor
x = tensor([1, 2, 3], dtype="int32")
x = make_tensor([1, 2, 3], network)
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), x.numpy())


# mixed
x = [1, tensor(2), 3]
x = [1, make_tensor(2, network), 3]
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), [1, 2, 3])
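
Throughout these hunks, tensor(...) is replaced by a make_tensor(x, network) helper so that each test runs both on imperative Tensors and on VarNodes of a utils.network graph, and the assertions switch from an exact type check to isinstance against the reference object. A minimal sketch of such a helper, assuming a Network.make_const-style constant constructor (the real helper lives in the test utilities and may differ):

import megengine as mge

def make_tensor(x, network=None, device=None):
    if network is not None:
        # assumption: make_const wraps a constant value into a graph VarNode
        return network.make_const(x, device=device)
    return mge.tensor(x, device=device)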




@@ -377,35 +494,60 @@ def test_device():
np.testing.assert_almost_equal(y5.numpy(), y6.numpy())




def test_identity():
x = tensor(np.random.random((5, 10)).astype(np.float32))
@pytest.mark.parametrize("is_varnode", [True, False])
def test_identity(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = make_tensor(np.random.random((5, 10)).astype(np.float32), network)
y = F.copy(x)
np.testing.assert_equal(y.numpy(), x)




def copy_test(dst, src):
def copy_test(dst, src, network):
data = np.random.random((2, 3)).astype(np.float32)
x = tensor(data, device=src)
x = make_tensor(data, device=src, network=network)
y = F.copy(x, dst)
assert np.allclose(data, y.numpy())
z = x.to(dst)
assert np.allclose(data, z.numpy())
if network is None:
z = x.to(dst)
assert np.allclose(data, z.numpy())




@pytest.mark.require_ngpu(1)
def test_copy_h2d():
copy_test("cpu0", "gpu0")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_h2d(is_varnode):
if is_varnode:
network = Network()
else:
network = None

copy_test("cpu0", "gpu0", network=network)




@pytest.mark.require_ngpu(1)
def test_copy_d2h():
copy_test("gpu0", "cpu0")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_d2h(is_varnode):
if is_varnode:
network = Network()
else:
network = None

copy_test("gpu0", "cpu0", network=network)




@pytest.mark.require_ngpu(2)
def test_copy_d2d():
copy_test("gpu0", "gpu1")
copy_test("gpu0:0", "gpu0:1")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_d2d(is_varnode):
if is_varnode:
network = Network()
else:
network = None

copy_test("gpu0", "gpu1", network=network)
copy_test("gpu0:0", "gpu0:1", network=network)




@pytest.mark.parametrize(
@@ -420,7 +562,13 @@ def test_copy_d2d():
((), 10, None),
],
)
def test_repeat(shape, repeats, axis):
@pytest.mark.parametrize("is_varnode", [True, False])
def test_repeat(shape, repeats, axis, is_varnode):
if is_varnode:
network = Network()
else:
network = None

def repeat_func(inp):
return F.repeat(inp=inp, repeats=repeats, axis=axis)


@@ -432,7 +580,10 @@ def test_repeat(shape, repeats, axis):
cases = [{"input": np.array(1.23)}]


opr_test(
cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis),
cases,
repeat_func,
ref_fn=lambda inp: np.repeat(inp, repeats, axis),
network=network,
)




@@ -445,14 +596,16 @@ def test_repeat(shape, repeats, axis):
((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
],
)
def test_tile(shape, reps):
@pytest.mark.parametrize("is_varnode", [True])
def test_tile(shape, reps, is_varnode):
if is_varnode:
network = Network()
else:
network = None

def tile_func(inp):
return F.tile(inp=inp, reps=reps)


cases = [
{"input": np.random.randn(*shape).astype("float32")},
]
cases = [{"input": np.random.randn(*shape).astype("float32")}]


opr_test(
cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps),
)
opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)
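
Both parametrized tests delegate the reference result to NumPy; the behaviours they pin down are repeat (duplicate elements along one axis) versus tile (replicate the whole array per axis), e.g.:

import numpy as np

x = np.array([[1, 2], [3, 4]])
np.repeat(x, 2, axis=0)   # [[1, 2], [1, 2], [3, 4], [3, 4]]
np.tile(x, (2, 1))        # [[1, 2], [3, 4], [1, 2], [3, 4]]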

+ 45
- 1
imperative/python/test/unit/quantization/test_module.py View File

@@ -30,7 +30,10 @@ min_max_fakequant_qconfig = QConfig(
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)


inp_scale = np.float32(np.random.rand() + 1)

def gen_inp_scale():
return np.float32(np.random.rand() + 1)



min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
max_val = np.random.randint(1, 127, size=(2,)).astype("float32")
@@ -116,6 +119,7 @@ def test_dequant_stub():
q_net.eval()


x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
inp_scale = gen_inp_scale()
x = fake_quant_act(x, inp_scale)
x.qparams.scale = inp_scale


@@ -192,6 +196,7 @@ def test_linear():
init_qat_net(qat_net)


x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
inp_scale = gen_inp_scale()
x = fake_quant_act(x, inp_scale)
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))


@@ -235,6 +240,7 @@ def test_conv(module):
init_qat_net(qat_net)


x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
inp_scale = gen_inp_scale()
x = fake_quant_act(x, inp_scale)
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))


@@ -269,3 +275,41 @@ def test_conv(module):
np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5)
np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale)
np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale)


def test_concat():
normal_net = Float.Concat()
normal_net.eval()

qat_net = QAT.Concat()
qat_net.eval()
disable_observer(qat_net)

propagate_qconfig(qat_net, min_max_fakequant_qconfig)
init_qat_net(qat_net)

inps = []
inps_int8 = []
for i in range(3):
inp_scale = gen_inp_scale()
inps.append(mge.tensor(np.random.normal(size=(3, 3)).astype("float32")))
inps[i] = fake_quant_act(inps[i], inp_scale)
inps[i].qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
inps_int8.append(quant(inps[i], inp_scale))

qat_from_float = QAT.Concat.from_float_module(normal_net)
qat_from_float.eval()
disable_fake_quant(qat_from_float)
disable_observer(qat_from_float)

q_net = Q.Concat.from_qat_module(qat_net)
q_net.eval()

normal = normal_net(inps)
qat_without_fakequant = qat_from_float(inps)
fake_quant_normal = fake_quant_act(normal_net(inps), act_scale)
qat = qat_net(inps)
q = q_net(inps_int8).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal)
np.testing.assert_allclose(qat, fake_quant_normal.numpy())
np.testing.assert_allclose(q, fake_quant_normal.numpy())
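
The switch at the top of this file from a module-level inp_scale constant to a gen_inp_scale() factory means each test (including the new test_concat, which calls it once per input) draws a fresh random scale instead of reusing one sampled at import time. A toy illustration of the difference, independent of MegEngine:

import numpy as np

INP_SCALE = np.float32(np.random.rand() + 1)   # fixed once at import for every test

def gen_inp_scale():
    return np.float32(np.random.rand() + 1)    # fresh value on every call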

+ 30
- 1
imperative/python/test/unit/utils/test_dump_naming.py View File

@@ -124,6 +124,35 @@ def test_with_submodule(symbolic):




@pytest.mark.parametrize("symbolic", [False, True]) @pytest.mark.parametrize("symbolic", [False, True])
def test_with_submodule_in_container(symbolic):
class Simple(M.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.l0 = [M.Linear(3, 3) for _ in range(2)]
self.l1 = tuple(self.l0)
self.l2 = dict(zip(["l2-0", "l2-1"], self.l0))

def forward(self, x):
for i in range(2):
x = self.l0[i](x)
x = self.l1[i](x)
x = self.l2["l2-%d" % i](x)
return x

m = Simple("simple")

ops = _dump_and_load(m, symbolic)
assert ops[-1].outputs[0].name == "simple.l2.l2-1.ADD"
assert ops[-1].name == "simple.l2.l2-1.ADD"
assert ops[-2].name == "simple.l2.l2-1.MatrixMul"
assert ops[-3].name == "simple.l1.1.ADD"
assert ops[-4].name == "simple.l1.1.MatrixMul"
assert ops[-5].name == "simple.l0.1.ADD"
assert ops[-6].name == "simple.l0.1.MatrixMul"


@pytest.mark.parametrize("symbolic", [False, True])
def test_named_submodule(symbolic):
class Simple(M.Module):
def __init__(self, name):
@@ -264,4 +293,4 @@ def test_quantized_module_user_naming_param(symbolic):
(matrix_mul_op,) = [op for op in ops if op.name == "simple.linear.MatrixMul"]
for var in matrix_mul_op.inputs:
assert var.name in ("simple.quant.TypeCvt", "simple.linear.user-weight")
# BUG bias' name does not meet expectations because of astype operator after quantization
# WONTFIX: bias' name does not meet expectations because of astype operator after quantization

+ 9
- 15
imperative/python/test/unit/utils/test_network.py View File

@@ -34,13 +34,11 @@ def test_replace_var():
vara = graph.var_filter.name("a").as_unique() vara = graph.var_filter.name("a").as_unique()
varb = graph.var_filter.name("b").as_unique() varb = graph.var_filter.name("b").as_unique()


out = F.mul(vara.var, varb.var)
out = F.mul(vara, varb)
out = F.relu(out) out = F.relu(out)


var_list = graph.add_dep_oprs(out)

opnode = list(graph.opr_filter.has_input(vara)) opnode = list(graph.opr_filter.has_input(vara))
repl_dict = {opnode[0].outputs[0]: var_list[0]}
repl_dict = {opnode[0].outputs[0]: out}
graph.replace_vars(repl_dict) graph.replace_vars(repl_dict)


modified_model = io.BytesIO() modified_model = io.BytesIO()
@@ -72,14 +70,12 @@ def test_replace_opr():
vara = graph.var_filter.name("a").as_unique() vara = graph.var_filter.name("a").as_unique()
varb = graph.var_filter.name("b").as_unique() varb = graph.var_filter.name("b").as_unique()


out1 = F.sub(vara.var, varb.var)
out1 = F.sub(vara, varb)
out1 = F.relu(out1) out1 = F.relu(out1)

var_list = graph.add_dep_oprs(out1)
repl_opr = as_oprnode(var_list)
out1 = graph.add_dep_oprs(out1)
orig_opr = graph.opr_filter.has_input(vara).as_unique()


repl_dict = {orig_opr: repl_opr}
repl_dict = {orig_opr: out1[0].owner}
graph.replace_oprs(repl_dict)
modified_model1 = io.BytesIO()
graph.dump(modified_model1)
@@ -171,8 +167,7 @@ def test_add_input():
inp_c = graph.make_input_node((2,), np.int32, name="c")
varo = graph.var_filter.name("o").as_unique()


out = F.add(varo.var, inp_c.var)
out = graph.add_dep_oprs(out)[0]
out = F.add(varo, inp_c)
out.name = "o1" out.name = "o1"
graph.remove_output(varo) graph.remove_output(varo)
graph.add_output(out) graph.add_output(out)
@@ -206,12 +201,11 @@ def test_add_output():
var_a = net.var_filter.name("a").as_unique() var_a = net.var_filter.name("a").as_unique()
var_b = net.var_filter.name("b").as_unique() var_b = net.var_filter.name("b").as_unique()


y = F.add(var_a.var, var_b.var)
y = F.add(var_a, var_b)
y = F.sigmoid(y)


new_vars = net.add_dep_oprs(y)[0]
new_vars.name = "o1"
net.add_output(new_vars)
y.name = "o1"
net.add_output(y)


modified_model = io.BytesIO()
net.dump(modified_model)
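
These hunks all follow from the "add array method for varnode" change: functional ops now accept the VarNode wrappers returned by var_filter directly and return wrappers that can be renamed and registered as outputs, so the explicit add_dep_oprs round-trips disappear. A hedged sketch of the resulting editing flow (the model path and variable names are illustrative):

import io
import megengine.functional as F
from megengine.utils.network import Network as Net

net = Net.load("model.mge")                   # assumed: a previously dumped model
var_a = net.var_filter.name("a").as_unique()
var_b = net.var_filter.name("b").as_unique()

y = F.sigmoid(F.add(var_a, var_b))            # ops take VarNode wrappers directly
y.name = "o1"
net.add_output(y)

buf = io.BytesIO()
net.dump(buf)                                 # serialize the edited graph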


+ 13
- 0
imperative/python/test/unit/utils/test_network_node.py View File

@@ -466,6 +466,19 @@ def test_topk():
check_pygraph_dump(fwd, [x], [top, indices])




def test_nvof():
if not is_cuda_available():
return
src_shape = (4, 5, 224, 224, 4)
src = np.random.randint(0, 255, src_shape).astype("uint8")
src = Tensor(src)

@trace(symbolic=True, capture_as_const=True)
def fwd(src):
return F.nn.nvof(src, precision=1)

result = fwd(src)
check_pygraph_dump(fwd, [src], [result])




def test_random():


+ 1
- 1
imperative/python/version_template.py View File

@@ -6,5 +6,5 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
__version__ = "1.3.0.dev"
__version__ = "1.3.1"



+ 2
- 2
imperative/src/impl/ops/broadcast.cpp View File

@@ -24,14 +24,14 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
return Broadcast::make();
}


cg::OperatorNodeBase* apply_on_var_node(
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Broadcast>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
OperatorNodeConfig config{op.make_name()};
return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr();
return opr::Broadcast::make(inputs[0], inputs[1], config);
}


bool valid_broadcast(const TensorShape& src_shape,


+ 2
- 1
imperative/tablegen/CMakeLists.txt View File

@@ -1,6 +1,7 @@
# mgb tablegen executable
set(TABLE_TARGET mgb-mlir-autogen)
add_executable(${TABLE_TARGET} autogen.cpp)
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.h ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
add_executable(${TABLE_TARGET} ${SRCS})
target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR})
target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport)
set(MGB_TABLEGEN_EXE ${TABLE_TARGET})


+ 15
- 730
imperative/tablegen/autogen.cpp View File

@@ -1,8 +1,17 @@
#include <iostream>
#include <unordered_map>
#include <functional>

#include "./helper.h"
/**
* \file imperative/tablegen/autogen.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./targets/cpp_class.h"
#include "./targets/pybind11.h"
#include "./targets/python_c_extension.h"


using llvm::raw_ostream;
using llvm::RecordKeeper;
@@ -27,731 +36,7 @@ llvm::cl::opt<ActionType> action(
clEnumValN(CPython, "gen-python-c-extension",
"Generate python c extensions")));


using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
using MgbOp = mlir::tblgen::MgbOpBase;
using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;

llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
// Note: we have already registered the corresponding attr wrappers
// for following basic ctypes so we needn't handle them here
/* auto&& attr_type_name = attr.getAttrDefName();
if (attr_type_name == "UI32Attr") {
return "uint32_t";
}
if (attr_type_name == "UI64Attr") {
return "uint64_t";
}
if (attr_type_name == "I32Attr") {
return "int32_t";
}
if (attr_type_name == "F32Attr") {
return "float";
}
if (attr_type_name == "F64Attr") {
return "double";
}
if (attr_type_name == "StrAttr") {
return "std::string";
}
if (attr_type_name == "BoolAttr") {
return "bool";
}*/

auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
return e->getEnumName();
}
return attr.getUnderlyingType();
}

static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
os << formatv(
"class {0} : public OpDefImplBase<{0}> {{\n"
" MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
"public:\n",
op.getCppClassName()
);
// handle enum alias
for (auto &&i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
os << formatv(
" using {0} = {1};\n",
attr->getEnumName(), attr->getUnderlyingType()
);
}
}
for (auto &&i : op.getMgbAttributes()) {
auto defaultValue = i.attr.getDefaultValue().str();
if (!defaultValue.empty()) {
defaultValue = formatv(" = {0}", defaultValue);
}
os << formatv(
" {0} {1}{2};\n",
attr_to_ctype(i.attr), i.name, defaultValue
);
}

auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
os << formatv(
" {0}({1}){2}{3}\n",
op.getCppClassName(), paramList, memInitList, body
);
};

gen_ctor("", "", " = default;");

if (!op.getMgbAttributes().empty()) {
std::vector<std::string> paramList, initList;
for (auto &&i : op.getMgbAttributes()) {
paramList.push_back(formatv(
"{0} {1}_", attr_to_ctype(i.attr), i.name
));
initList.push_back(formatv(
"{0}({0}_)", i.name
));
}
paramList.push_back("std::string scope_ = {}");
gen_ctor(llvm::join(paramList, ", "),
": " + llvm::join(initList, ", "),
" { set_scope(scope_); }");
}

auto packedParams = op.getPackedParams();
if (!packedParams.empty()) {
std::vector<std::string> paramList, initList;
for (auto &&p : packedParams) {
auto&& paramFields = p.getFields();
auto&& paramType = p.getFullName();
auto&& paramName = formatv("packed_param_{0}", paramList.size());
paramList.push_back(
paramFields.empty() ? paramType.str()
: formatv("{0} {1}", paramType, paramName)
);
for (auto&& i : paramFields) {
initList.push_back(formatv(
"{0}({1}.{0})", i.name, paramName
));
}
}
for (auto&& i : op.getExtraArguments()) {
paramList.push_back(formatv(
"{0} {1}_", attr_to_ctype(i.attr), i.name
));
initList.push_back(formatv(
"{0}({0}_)", i.name
));
}
gen_ctor(llvm::join(paramList, ", "),
initList.empty() ? "" : ": " + llvm::join(initList, ", "),
" {}");
}

if (!packedParams.empty()) {
for (auto&& p : packedParams) {
auto accessor = p.getAccessor();
if (!accessor.empty()) {
os << formatv(
" {0} {1}() const {{\n",
p.getFullName(), accessor
);
std::vector<llvm::StringRef> fields;
for (auto&& i : p.getFields()) {
fields.push_back(i.name);
}
os << formatv(
" return {{{0}};\n",
llvm::join(fields, ", ")
);
os << " }\n";
}
}
}

if (auto decl = op.getExtraOpdefDecl()) {
os << decl.getValue();
}

os << formatv(
"};\n\n"
);
}

static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) {
for (auto &&i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
if (attr->supportToString()) {
std::vector<std::string> case_body;
std::string ename = formatv("{0}::{1}",
op.getCppClassName(), attr->getEnumName());
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){
case_body.push_back(formatv(
"case {0}::{1}: return \"{1}\";", ename, v));
});
os << formatv(R"(
template <>
struct ToStringTrait<{0}> {
std::string operator()({0} e) const {
switch (e) {
{1}
default:
return "{0}::Unknown";
}
}
};
)", ename, llvm::join(case_body, "\n"));
}
}
}
}

static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
auto&& className = op.getCppClassName();
os << formatv(
"MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
);
auto formatMethImpl = [&](auto&& meth) {
return formatv(
"{0}_{1}_impl", className, meth
);
};
std::vector<std::string> methods;
if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
os << "namespace {\n";

// generate hash()
mlir::tblgen::FmtContext ctx;
os << formatv(
"size_t {0}(const OpDef& def_) {{\n",
formatMethImpl("hash")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
os << "}\n";

// generate is_same_st()
os << formatv(
"bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
formatMethImpl("is_same_st")
);
os << formatv(
" auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
" &&b_ = rhs_.cast_final_safe<{0}>();\n"
" static_cast<void>(a_);\n"
" static_cast<void>(b_);\n",
className
);
os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
os << "}\n";

// generate props()
os << formatv(
"std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n",
formatMethImpl("props")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
os << "}\n";

// generate make_name()
os << formatv(
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
os << "}\n";

os << "} // anonymous namespace\n";

methods.push_back("hash");
methods.push_back("is_same_st");
methods.push_back("props");
methods.push_back("make_name");
}
if (!methods.empty()) {
os << formatv(
"OP_TRAIT_REG({0}, {0})", op.getCppClassName()
);
for (auto&& i : methods) {
os << formatv(
"\n .{0}({1})", i, formatMethImpl(i)
);
}
os << ";\n\n";
}
}

struct EnumContext {
std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias;
};

static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
auto className = op.getCppClassName();
os << formatv(
"py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
className
);
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID =
llvm::cast<MgbEnumAttr>(aliasBase)
.getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
auto&& enumAlias = ctx.enumAlias;
auto&& iter = enumAlias.find(enumID);
if (iter == enumAlias.end()) {
os << formatv(
"py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
className, attr->getEnumName()
);
std::vector<std::string> body;
for (auto&& i: attr->getEnumMembers()) {
os << formatv(
"\n .value(\"{2}\", {0}::{1}::{2})",
className, attr->getEnumName(), i
);
body.push_back(formatv(
"if (str == \"{2}\") return {0}::{1}::{2};",
className, attr->getEnumName(), i
));
}
if (attr->getEnumCombinedFlag()) {
//! define operator |
os << formatv(
"\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
"\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));"
"\n })",
className, attr->getEnumName());
//! define operator &
os << formatv(
"\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
"\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));"
"\n })",
className, attr->getEnumName());
}
os << formatv(
"\n .def(py::init([](const std::string& in) {"
"\n auto&& str = normalize_enum(in);"
"\n {0}"
"\n throw py::cast_error(\"invalid enum value \" + in);"
"\n }));\n",
llvm::join(body, "\n ")
);
os << formatv(
"py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
className, attr->getEnumName()
);
enumAlias.emplace(enumID,
std::make_pair(className, attr->getEnumName()));
} else {
os << formatv(
"{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n",
className, attr->getEnumName(),
iter->second.first, iter->second.second
);
}
}
}
// generate op class binding
os << formatv("{0}Inst", className);
bool hasDefaultCtor = op.getMgbAttributes().empty();
if (!hasDefaultCtor) {
os << "\n .def(py::init<";
std::vector<llvm::StringRef> targs;
for (auto &&i : op.getMgbAttributes()) {
targs.push_back(i.attr.getReturnType());
}
os << llvm::join(targs, ", ");
os << ", std::string>()";
for (auto &&i : op.getMgbAttributes()) {
os << formatv(", py::arg(\"{0}\")", i.name);
auto defaultValue = i.attr.getDefaultValue();
if (!defaultValue.empty()) {
os << formatv(" = {0}", defaultValue);
} else {
hasDefaultCtor = true;
}
}
os << ", py::arg(\"scope\") = {})";
}
if (hasDefaultCtor) {
os << "\n .def(py::init<>())";
}
for (auto &&i : op.getMgbAttributes()) {
os << formatv(
"\n .def_readwrite(\"{0}\", &{1}::{0})",
i.name, className
);
}
os << ";\n\n";
}

static std::string gen_op_def_python_c_extension_enum(
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr,
llvm::StringRef className) {
std::string body;
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
auto&& enumAlias = ctx.enumAlias;
auto&& iter = enumAlias.find(enumID);
auto enumName = attr->getEnumName();
body += "{\n";
body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className,
enumName);
if (iter == enumAlias.end()) {
os << formatv(
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n",
className, enumName);
os << formatv(
"template<> const char* EnumWrapper<{0}::{1}>::name = "
"\"{0}.{1}\";\n",
className, enumName);
std::vector<std::string> pairStr;
for (auto&& i : attr->getEnumMembers()) {
pairStr.push_back(
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
className, enumName, i));
}
os << formatv(R"(
template<> std::unordered_map<std::string, {0}::{1}>
EnumWrapper<{0}::{1}>::str2type = {{
{2}
};
)",
className, enumName, llvm::join(pairStr, ", "));
pairStr.clear();
for (auto&& i : attr->getEnumMembers()) {
pairStr.push_back(
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
className, enumName, i));
}
os << formatv(R"(
template<> std::unordered_map<{0}::{1}, std::string>
EnumWrapper<{0}::{1}>::type2str = {{
{2}
};
)",
className, enumName, llvm::join(pairStr, ", "));
body += formatv(R"(
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>);
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "{0}.{1}";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr;
e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare;
mgb_assert(PyType_Ready(&e_type) >= 0);
)",
className, enumName);
for (auto&& i : attr->getEnumMembers()) {
body += formatv(R"({{
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
})",
className, enumName, i);
}
enumAlias.emplace(enumID, std::make_pair(className, enumName));
}
body += formatv(R"(
PyType_Modified(&e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
)",
enumName);
body += "}\n";
return body;
}

static std::string gen_op_def_python_c_extension_bit_combined_enum(
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr,
llvm::StringRef className) {
std::string body;
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
auto&& enumAlias = ctx.enumAlias;
auto&& iter = enumAlias.find(enumID);
auto enumName = attr->getEnumName();
body += "{\n";
body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;",
className, enumName);
if (iter == enumAlias.end()) {
os << formatv(
"template<> PyTypeObject "
"BitCombinedEnumWrapper<{0}::{1}>::type={{};\n",
className, enumName);
os << formatv(
"template<> PyNumberMethods "
"BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n",
className, enumName);
os << formatv(
"template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name "
"= \"{0}.{1}\";\n",
className, enumName);
os << formatv(
"template<> struct EnumTrait<{0}::{1}> {{ static constexpr "
"bool is_bit_combined = true;};\n",
className, enumName);
std::vector<std::string> pairStr;
for (auto&& i : attr->getEnumMembers()) {
pairStr.push_back(
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
className, enumName, i));
}
os << formatv(R"(
template<> std::unordered_map<std::string, {0}::{1}>
BitCombinedEnumWrapper<{0}::{1}>::str2type = {{
{2}
};
)",
className, enumName, llvm::join(pairStr, ", "));
pairStr.clear();
for (auto&& i : attr->getEnumMembers()) {
pairStr.push_back(
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
className, enumName, i));
}
os << formatv(R"(
template<> std::unordered_map<{0}::{1}, std::string>
BitCombinedEnumWrapper<{0}::{1}>::type2str = {{
{2}
};
)",
className, enumName, llvm::join(pairStr, ", "));
body += formatv(R"(
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>);
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "{0}.{1}";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum;
e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init;
e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr;
e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare;
auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods;
number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or;
number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and;
e_type.tp_as_number = &number_method;
mgb_assert(PyType_Ready(&e_type) >= 0);
)",
className, enumName);
for (auto&& i : attr->getEnumMembers()) {
body += formatv(R"({{
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<BitCombinedEnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
})",
className, enumName, i);
}
enumAlias.emplace(enumID, std::make_pair(className, enumName));
}
body += formatv(R"(
PyType_Modified(&e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
)",
enumName);
body += "}\n";
return body;
}

static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
auto className = op.getCppClassName();
std::string body;

// generate PyType for enum class member
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
if (attr->getEnumCombinedFlag()) {
body += gen_op_def_python_c_extension_bit_combined_enum(
os, ctx, attr, className);
} else {
body += gen_op_def_python_c_extension_enum(os, ctx, attr,
className);
}
}
}

// generate getsetters
std::vector<std::string> getsetters;
for (auto &&i : op.getMgbAttributes()) {
getsetters.push_back(formatv(
"{{const_cast<char*>(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast<char*>(\"{1}\"), NULL},",
className, i.name));
}

// generate tp_init
std::string initBody;
if (!op.getMgbAttributes().empty()) {
initBody += "static const char* kwlist[] = {";

std::vector<llvm::StringRef> attr_name_list;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
attr_name_list.push_back(attr.name);
});
attr_name_list.push_back("scope");

llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr);
});
initBody += "NULL};\n";
initBody += " PyObject ";
std::vector<std::string> attr_init;
llvm::for_each(attr_name_list, [&](auto&& attr) {
attr_init.push_back(formatv("*{0} = NULL", attr));
});
initBody += llvm::join(attr_init, ", ") + ";\n";
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
// an extra slot created for name
initBody += std::string(attr_name_list.size(), 'O');
initBody += "\", const_cast<char**>(kwlist)";
llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv(", &{0}", attr);
});
initBody += "))\n";
initBody += " return -1;\n";

llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv(R"(
if ({1}) {{
try {{
reinterpret_cast<PyOp({0})*>(self)->inst().{1} =
pyobj_convert_generic<decltype({0}::{1})>::from({1});
} CATCH_ALL(-1)
}
)", className, attr.name);
});

initBody += formatv(R"(
if (scope) {{
try {{
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(pyobj_convert_generic<std::string>::from(scope));
} CATCH_ALL(-1)
}
)", className);

}
initBody += "\n return 0;";

os << formatv(R"(
PyOpDefBegin({0}) // {{
static PyGetSetDef py_getsetters[];
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd({0})
PyGetSetDef PyOp({0})::py_getsetters[] = {{
{1}
{{NULL} /* Sentinel */
};
int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{
{2}
}

void _init_py_{0}(py::module m) {{
using py_op = PyOp({0});
auto& py_type = PyOpType({0});
py_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.{0}";
py_type.tp_basicsize = sizeof(PyOp({0}));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "{0}";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
{3}
PyType_Modified(&py_type);
m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second);
}
)",
op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body);
}

static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
std::function<void(raw_ostream&, MgbOp&)> callback) {
auto op_base_class = keeper.getClass("Op");
ASSERT(op_base_class, "could not find base class Op");
for (auto&& i: keeper.getDefs()) {
auto&& r = i.second;
if (r->isSubClassOf(op_base_class)) {
auto op = mlir::tblgen::Operator(r.get());
if (op.getDialectName().str() == "mgb") {
std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
callback(os, llvm::cast<MgbOp>(op));
}
}
}
}

static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) {
for_each_operator(os, keeper, gen_op_def_c_header_single);
for_each_operator(os, keeper, gen_to_string_trait_for_enum);
return false;
}

static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) {
for_each_operator(os, keeper, gen_op_def_c_body_single);
return false;
}

static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) {
EnumContext ctx;
using namespace std::placeholders;
for_each_operator(os, keeper,
std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx)));
return false;
}

static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) {
EnumContext ctx;
using namespace std::placeholders;
for_each_operator(os, keeper,
std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx)));
os << "#define INIT_ALL_OP(m)";
for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) {
os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName());
});
os << "\n";
return false;
}
using namespace mlir::tblgen;


int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);


+ 40
- 0
imperative/tablegen/emitter.h View File

@@ -0,0 +1,40 @@
/**
* \file imperative/tablegen/emitter.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include <unordered_map>
#include <stdexcept>

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir::tblgen {

struct Environment {
std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias;
};

struct EmitterBase {
EmitterBase(raw_ostream& os_): os(os_) {}
EmitterBase(raw_ostream& os_, Environment& env): os(os_), env_p(&env) {}
protected:
void newline() { os << "\n"; }
Environment& env() {
if (env_p) {
return *env_p;
}
throw std::runtime_error("access global environment via non-environment emitter");
}
raw_ostream& os;
Environment* env_p = nullptr;
};

} // namespace mlir::tblgen

+ 36
- 0
imperative/tablegen/helper.h View File

@@ -1,3 +1,16 @@
/**
* \file imperative/tablegen/helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include <iostream>
#include <string>
#include <vector>


@@ -278,5 +291,28 @@ public:
}
};


using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
using MgbOp = mlir::tblgen::MgbOpBase;
using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;

static inline void foreach_operator(llvm::RecordKeeper &keeper,
std::function<void(MgbOp&)> callback) {
auto op_base_class = keeper.getClass("Op");
ASSERT(op_base_class, "could not find base class Op");
for (auto&& i: keeper.getDefs()) {
auto&& r = i.second;
if (r->isSubClassOf(op_base_class)) {
auto op = mlir::tblgen::Operator(r.get());
if (op.getDialectName().str() == "mgb") {
std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
callback(llvm::cast<MgbOp>(op));
}
}
}
}

} // namespace tblgen
} // namespace mlir

+ 309
- 0
imperative/tablegen/targets/cpp_class.cpp View File

@@ -0,0 +1,309 @@
/**
* \file imperative/tablegen/targets/cpp_class.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./cpp_class.h"
#include "../emitter.h"

namespace mlir::tblgen {
namespace {
llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
// Note: we have already registered the corresponding attr wrappers
// for following basic ctypes so we needn't handle them here
/* auto&& attr_type_name = attr.getAttrDefName();
if (attr_type_name == "UI32Attr") {
return "uint32_t";
}
if (attr_type_name == "UI64Attr") {
return "uint64_t";
}
if (attr_type_name == "I32Attr") {
return "int32_t";
}
if (attr_type_name == "F32Attr") {
return "float";
}
if (attr_type_name == "F64Attr") {
return "double";
}
if (attr_type_name == "StrAttr") {
return "std::string";
}
if (attr_type_name == "BoolAttr") {
return "bool";
}*/

auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
return e->getEnumName();
}
return attr.getUnderlyingType();
}

class OpDefEmitter final: public EmitterBase {
public:
OpDefEmitter(MgbOp& op_, raw_ostream& os_):
EmitterBase(os_), op(op_) {}
void emit_header();
void emit_tpl_spl();
void emit_body();
private:
MgbOp& op;
};

void OpDefEmitter::emit_header() {
os << formatv(
"class {0} : public OpDefImplBase<{0}> {{\n"
" MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
"public:\n",
op.getCppClassName()
);
// handle enum alias
for (auto &&i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
os << formatv(
" using {0} = {1};\n",
attr->getEnumName(), attr->getUnderlyingType()
);
}
}
for (auto &&i : op.getMgbAttributes()) {
auto defaultValue = i.attr.getDefaultValue().str();
if (!defaultValue.empty()) {
defaultValue = formatv(" = {0}", defaultValue);
}
os << formatv(
" {0} {1}{2};\n",
attr_to_ctype(i.attr), i.name, defaultValue
);
}

auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
os << formatv(
" {0}({1}){2}{3}\n",
op.getCppClassName(), paramList, memInitList, body
);
};

gen_ctor("", "", " = default;");

if (!op.getMgbAttributes().empty()) {
std::vector<std::string> paramList, initList;
for (auto &&i : op.getMgbAttributes()) {
paramList.push_back(formatv(
"{0} {1}_", attr_to_ctype(i.attr), i.name
));
initList.push_back(formatv(
"{0}({0}_)", i.name
));
}
paramList.push_back("std::string scope_ = {}");
gen_ctor(llvm::join(paramList, ", "),
": " + llvm::join(initList, ", "),
" { set_scope(scope_); }");
}

auto packedParams = op.getPackedParams();
if (!packedParams.empty()) {
std::vector<std::string> paramList, initList;
for (auto &&p : packedParams) {
auto&& paramFields = p.getFields();
auto&& paramType = p.getFullName();
auto&& paramName = formatv("packed_param_{0}", paramList.size());
paramList.push_back(
paramFields.empty() ? paramType.str()
: formatv("{0} {1}", paramType, paramName)
);
for (auto&& i : paramFields) {
initList.push_back(formatv(
"{0}({1}.{0})", i.name, paramName
));
}
}
for (auto&& i : op.getExtraArguments()) {
paramList.push_back(formatv(
"{0} {1}_", attr_to_ctype(i.attr), i.name
));
initList.push_back(formatv(
"{0}({0}_)", i.name
));
}
gen_ctor(llvm::join(paramList, ", "),
initList.empty() ? "" : ": " + llvm::join(initList, ", "),
" {}");
}

if (!packedParams.empty()) {
for (auto&& p : packedParams) {
auto accessor = p.getAccessor();
if (!accessor.empty()) {
os << formatv(
" {0} {1}() const {{\n",
p.getFullName(), accessor
);
std::vector<llvm::StringRef> fields;
for (auto&& i : p.getFields()) {
fields.push_back(i.name);
}
os << formatv(
" return {{{0}};\n",
llvm::join(fields, ", ")
);
os << " }\n";
}
}
}

if (auto decl = op.getExtraOpdefDecl()) {
os << decl.getValue();
}

os << formatv(
"};\n\n"
);
}

void OpDefEmitter::emit_tpl_spl() {
for (auto &&i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
if (attr->supportToString()) {
std::vector<std::string> case_body;
std::string ename = formatv("{0}::{1}",
op.getCppClassName(), attr->getEnumName());
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){
case_body.push_back(formatv(
"case {0}::{1}: return \"{1}\";", ename, v));
});
os << formatv(R"(
template <>
struct ToStringTrait<{0}> {
std::string operator()({0} e) const {
switch (e) {
{1}
default:
return "{0}::Unknown";
}
}
};
)", ename, llvm::join(case_body, "\n"));
}
}
}
}

void OpDefEmitter::emit_body() {
auto&& className = op.getCppClassName();
os << formatv(
"MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
);
auto formatMethImpl = [&](auto&& meth) {
return formatv(
"{0}_{1}_impl", className, meth
);
};
std::vector<std::string> methods;
if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
os << "namespace {\n";

// generate hash()
mlir::tblgen::FmtContext ctx;
os << formatv(
"size_t {0}(const OpDef& def_) {{\n",
formatMethImpl("hash")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
os << "}\n";

// generate is_same_st()
os << formatv(
"bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
formatMethImpl("is_same_st")
);
os << formatv(
" auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
" &&b_ = rhs_.cast_final_safe<{0}>();\n"
" static_cast<void>(a_);\n"
" static_cast<void>(b_);\n",
className
);
os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
os << "}\n";

// generate props()
os << formatv(
"std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n",
formatMethImpl("props")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
os << "}\n";

// generate make_name()
os << formatv(
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
os << "}\n";

os << "} // anonymous namespace\n";

methods.push_back("hash");
methods.push_back("is_same_st");
methods.push_back("props");
methods.push_back("make_name");
}
if (!methods.empty()) {
os << formatv(
"OP_TRAIT_REG({0}, {0})", op.getCppClassName()
);
for (auto&& i : methods) {
os << formatv(
"\n .{0}({1})", i, formatMethImpl(i)
);
}
os << ";\n\n";
}
}
} // namespace

bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper) {
foreach_operator(keeper, [&](MgbOp& op) {
OpDefEmitter emitter(op, os);
emitter.emit_header();
emitter.emit_tpl_spl();
});
return false;
}

bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper) {
foreach_operator(keeper, [&](MgbOp& op) {
OpDefEmitter emitter(op, os);
emitter.emit_body();
});
return false;
}
} // namespace mlir::tblgen

+ 21
- 0
imperative/tablegen/targets/cpp_class.h View File

@@ -0,0 +1,21 @@
/**
* \file imperative/tablegen/targets/cpp_class.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "../helper.h"

namespace mlir::tblgen {

bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper);

bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper);

} // namespace mlir::tblgen

+ 142
- 0
imperative/tablegen/targets/pybind11.cpp View File

@@ -0,0 +1,142 @@
/**
* \file imperative/tablegen/targets/pybind11.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./pybind11.h"
#include "../emitter.h"

namespace mlir::tblgen {
namespace {
class OpDefEmitter final: public EmitterBase {
public:
OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_):
EmitterBase(os_, env_), op(op_) {}

void emit();
private:
MgbOp& op;
};

void OpDefEmitter::emit() {
auto className = op.getCppClassName();
os << formatv(
"py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
className
);
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID =
llvm::cast<MgbEnumAttr>(aliasBase)
.getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
auto&& enumAlias = env().enumAlias;
auto&& iter = enumAlias.find(enumID);
if (iter == enumAlias.end()) {
os << formatv(
"py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
className, attr->getEnumName()
);
std::vector<std::string> body;
for (auto&& i: attr->getEnumMembers()) {
os << formatv(
"\n .value(\"{2}\", {0}::{1}::{2})",
className, attr->getEnumName(), i
);
body.push_back(formatv(
"if (str == \"{2}\") return {0}::{1}::{2};",
className, attr->getEnumName(), i
));
}
if (attr->getEnumCombinedFlag()) {
//! define operator |
os << formatv(
"\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
"\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));"
"\n })",
className, attr->getEnumName());
//! define operator &
os << formatv(
"\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
"\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));"
"\n })",
className, attr->getEnumName());
}
os << formatv(
"\n .def(py::init([](const std::string& in) {"
"\n auto&& str = normalize_enum(in);"
"\n {0}"
"\n throw py::cast_error(\"invalid enum value \" + in);"
"\n }));\n",
llvm::join(body, "\n ")
);
os << formatv(
"py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
className, attr->getEnumName()
);
enumAlias.emplace(enumID,
std::make_pair(className, attr->getEnumName()));
} else {
os << formatv(
"{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n",
className, attr->getEnumName(),
iter->second.first, iter->second.second
);
}
}
}
// generate op class binding
os << formatv("{0}Inst", className);
bool hasDefaultCtor = op.getMgbAttributes().empty();
if (!hasDefaultCtor) {
os << "\n .def(py::init<";
std::vector<llvm::StringRef> targs;
for (auto &&i : op.getMgbAttributes()) {
targs.push_back(i.attr.getReturnType());
}
os << llvm::join(targs, ", ");
os << ", std::string>()";
for (auto &&i : op.getMgbAttributes()) {
os << formatv(", py::arg(\"{0}\")", i.name);
auto defaultValue = i.attr.getDefaultValue();
if (!defaultValue.empty()) {
os << formatv(" = {0}", defaultValue);
} else {
hasDefaultCtor = true;
}
}
os << ", py::arg(\"scope\") = {})";
}
if (hasDefaultCtor) {
os << "\n .def(py::init<>())";
}
for (auto &&i : op.getMgbAttributes()) {
os << formatv(
"\n .def_readwrite(\"{0}\", &{1}::{0})",
i.name, className
);
}
os << ";\n\n";
}
} // namespace

bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper) {
Environment env;
using namespace std::placeholders;
foreach_operator(keeper, [&](MgbOp& op) {
OpDefEmitter(op, os, env).emit();
});
return false;
}
} // namespace mlir::tblgen

+ 19
- 0
imperative/tablegen/targets/pybind11.h View File

@@ -0,0 +1,19 @@
/**
* \file imperative/tablegen/targets/pybind11.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "../helper.h"

namespace mlir::tblgen {

bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper);

} // namespace mlir::tblgen

+ 314
- 0
imperative/tablegen/targets/python_c_extension.cpp View File

@@ -0,0 +1,314 @@
/**
* \file imperative/tablegen/targets/python_c_extension.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "python_c_extension.h"
#include "../emitter.h"

namespace mlir::tblgen {
namespace {
struct Initproc {
std::string func;
Initproc(std::string&& s): func(std::move(s)) {}
std::string operator()(std::string argument) {
return formatv("{0}({1})", func, argument);
}
};

class OpDefEmitter: public EmitterBase {
public:
OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_):
EmitterBase(os_, env_), op(op_) {
ctx.withSelf(op.getCppClassName());
}

Initproc emit();
private:
void emit_class();
void emit_py_init();
void emit_py_getsetters();
Initproc emit_initproc();

MgbOp& op;
std::vector<Initproc> subclasses;
mlir::tblgen::FmtContext ctx;
};

class EnumAttrEmitter: public EmitterBase {
public:
EnumAttrEmitter(llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_, Environment& env_):
EmitterBase(os_, env_), attr(attr_) {
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
ctx.addSubst("enumTpl", attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper");
ctx.addSubst("opClass", parent);
ctx.addSubst("enumClass", attr->getEnumName());
firstOccur = env().enumAlias.emplace(enumID, std::make_pair(parent, attr->getEnumName())).second;
}

Initproc emit();
protected:
void emit_tpl_spl();
Initproc emit_initproc();

MgbEnumAttr* attr;
bool firstOccur;
mlir::tblgen::FmtContext ctx;
};

Initproc EnumAttrEmitter::emit() {
emit_tpl_spl();
return emit_initproc();
}

void EnumAttrEmitter::emit_tpl_spl() {
if (!firstOccur) return;

os << tgfmt(
"template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type={};\n",
&ctx);

os << tgfmt(
"template<> const char* $enumTpl<$opClass::$enumClass>::name = "
"\"$opClass.$enumClass\";\n",
&ctx);

if (attr->getEnumCombinedFlag()) {
os << tgfmt(
"template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods={};\n",
&ctx);
os << tgfmt(R"(
template<> struct EnumTrait<$opClass::$enumClass> {
static constexpr bool is_bit_combined = true;
static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1;
};
)", &ctx, attr->getEnumMembers().size());
}

auto str2type = [&](auto&& i) -> std::string {
return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i);
};
os << tgfmt(R"(
template<> std::unordered_map<std::string, $opClass::$enumClass>
$enumTpl<$opClass::$enumClass>::str2type = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), str2type), ", "));

auto type2str = [&](auto&& i) -> std::string {
return tgfmt("{$opClass::$enumClass::$0, normalize_enum(\"$0\")}", &ctx, i);
};
os << tgfmt(R"(
template<> std::unordered_map<$opClass::$enumClass, std::string>
$enumTpl<$opClass::$enumClass>::type2str = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), type2str), ", "));
}

Initproc EnumAttrEmitter::emit_initproc() {
std::string initproc = formatv("_init_py_{0}_{1}",
ctx.getSubstFor("opClass"), ctx.getSubstFor("enumClass"));

os << tgfmt(R"(
void $0(PyTypeObject& py_type) {
auto& e_type = $enumTpl<$opClass::$enumClass>::type;
)", &ctx, initproc);

if (firstOccur) {
os << tgfmt(R"(
e_type = {PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass";
e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>);
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "$opClass.$enumClass";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr;
e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare;
)", &ctx);
if (attr->getEnumCombinedFlag()) {
        // only bit-combined enums may create new instances (needed for the
        // bitwise operators); all other enums always use the singleton members
os << tgfmt(R"(
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum;
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods;
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or;
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and;
e_type.tp_as_number = &number_method;
)", &ctx);
}

os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n";


for (auto&& i : attr->getEnumMembers()) {
os << tgfmt(R"({
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0);
PyType_Modified(&e_type);
})", &ctx, i);
}
}

os << tgfmt(R"(
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(&e_type)) >= 0);
)", &ctx);
os << "}\n";
return initproc;
}

Initproc OpDefEmitter::emit() {
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
subclasses.push_back(EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit());
}
}

emit_class();
emit_py_init();
emit_py_getsetters();
return emit_initproc();
}

void OpDefEmitter::emit_class() {
os << tgfmt(R"(
PyOpDefBegin($_self) // {
static PyGetSetDef py_getsetters[];
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd($_self)
)", &ctx);
}

void OpDefEmitter::emit_py_init() {
std::string initBody;
if (!op.getMgbAttributes().empty()) {
initBody += "static const char* kwlist[] = {";

std::vector<llvm::StringRef> attr_name_list;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
attr_name_list.push_back(attr.name);
});
attr_name_list.push_back("scope");

llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr);
});
initBody += "NULL};\n";
initBody += " PyObject ";
auto initializer = [&](auto&& attr) -> std::string {
return formatv("*{0} = NULL", attr);
};
initBody += llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n";
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
    // attr_name_list already carries the extra "scope" entry, so one 'O' per name suffices
initBody += std::string(attr_name_list.size(), 'O');
initBody += "\", const_cast<char**>(kwlist)";
llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv(", &{0}", attr);
});
initBody += "))\n";
initBody += " return -1;\n";

llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += tgfmt(R"(
if ($0) {
try {
reinterpret_cast<PyOp($_self)*>(self)->inst().$0 =
pyobj_convert_generic<decltype($_self::$0)>::from($0);
} CATCH_ALL(-1)
}
)", &ctx, attr.name);
});

initBody += tgfmt(R"(
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(pyobj_convert_generic<std::string>::from(scope));
} CATCH_ALL(-1)
}
)", &ctx);

}
initBody += "\n return 0;";


os << tgfmt(R"(
int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
$0
}
)", &ctx, initBody);
}

void OpDefEmitter::emit_py_getsetters() {
auto f = [&](auto&& attr) -> std::string {
return tgfmt(
"{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},",
&ctx, attr.name);
};
os << tgfmt(R"(
PyGetSetDef PyOp($_self)::py_getsetters[] = {
$0
{NULL} /* Sentinel */
};
)", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n "));
}

Initproc OpDefEmitter::emit_initproc() {
std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
std::string subclass_init_call;
for (auto&& i : subclasses) {
subclass_init_call += formatv(" {0};\n", i("py_type"));
}
os << tgfmt(R"(
void $0(py::module m) {
using py_op = PyOp($_self);
auto& py_type = PyOpType($_self);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.$_self";
py_type.tp_basicsize = sizeof(PyOp($_self));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "$_self";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
$1
PyType_Modified(&py_type);
m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second);
}
)", &ctx, initproc, subclass_init_call);
return initproc;
}
} // namespace

bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper) {
Environment env;
using namespace std::placeholders;
std::vector<Initproc> initprocs;
foreach_operator(keeper, [&](MgbOp& op) {
initprocs.emplace_back(OpDefEmitter(op, os, env).emit());
});
os << "#define INIT_ALL_OP(m)";
for(auto&& init : initprocs) {
os << formatv(" \\\n {0};", init("m"));
}
os << "\n";
return false;
}
} // namespace mlir::tblgen
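
For orientation, the tail of a generated file would look roughly like the following for two hypothetical ops Foo and Bar (the op names are invented; the per-op functions and the macro shape follow the formatv calls above, and the consumer that expands INIT_ALL_OP is assumed):

// per-op registration functions emitted by OpDefEmitter::emit_initproc():
//   void _init_py_Foo(py::module m) { ... }
//   void _init_py_Bar(py::module m) { ... }

// aggregate macro emitted by gen_op_def_python_c_extension():
#define INIT_ALL_OP(m) \
    _init_py_Foo(m); \
    _init_py_Bar(m);

// a consumer can then register every generated op type in one place, e.g.:
//   INIT_ALL_OP(module)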

+ 19
- 0
imperative/tablegen/targets/python_c_extension.h View File

@@ -0,0 +1,19 @@
/**
* \file imperative/tablegen/targets/python_c_extension.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "../helper.h"

namespace mlir::tblgen {

bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper);

} // namespace mlir::tblgen

+ 1
- 1
sdk/load-and-run/src/mgblar.cpp View File

@@ -709,7 +709,7 @@ void run_test_st(Args &env) {
            strategy = S::PROFILE;
        }
    } else if (env.use_fast_run) {
-       strategy = S::PROFILE | S::OPTMIZED;
+       strategy = S::PROFILE | S::OPTIMIZED;
    } else if (env.reproducible) {
        strategy = S::HEURISTIC | S::REPRODUCIBLE;
    }


+ 3
- 1
src/core/include/megbrain/utils/thread_impl_1.h View File

@@ -365,14 +365,16 @@ namespace mgb {
                    if (!m_free_task_block.empty()) {
                        ret = std::move(m_free_task_block.back());
                        m_free_task_block.pop_back();
+                       break;
                    } else if (m_block_quota > 0) {
                        ret = std::make_unique<TaskBlock>();
                        m_block_quota--;
+                       break;
                    } else {
                        m_cv.wait(m_mutex);
                        continue;
                    }
-               } while (false);
+               } while (true);
                ret->first_tid = m_new_block_first_tid;
                m_new_block_first_tid += BLOCK_SIZE;
                ret->prev = prev;
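
The added break statements and the switch from while (false) to while (true) turn what used to be a single-pass do-block into a real retry loop: a thread woken from m_cv.wait now re-checks the free list and the quota instead of falling through. A minimal standalone sketch of the same pattern, using std::condition_variable rather than MegEngine's own synchronization wrappers (all names below are illustrative, not the engine's types):

#include <condition_variable>
#include <cstddef>
#include <memory>
#include <mutex>
#include <vector>

struct TaskBlock {};

struct BlockPool {
    std::vector<std::unique_ptr<TaskBlock>> free_blocks;
    std::size_t quota = 4;
    std::mutex mtx;
    std::condition_variable cv;

    std::unique_ptr<TaskBlock> allocate() {
        std::unique_lock<std::mutex> lock(mtx);
        for (;;) {                       // retry until a block is obtained
            if (!free_blocks.empty()) {  // reuse a previously freed block
                auto ret = std::move(free_blocks.back());
                free_blocks.pop_back();
                return ret;
            }
            if (quota > 0) {             // allocate a fresh block within quota
                --quota;
                return std::make_unique<TaskBlock>();
            }
            cv.wait(lock);               // nothing available: sleep, then re-check
        }
    }

    void release(std::unique_ptr<TaskBlock> blk) {
        {
            std::lock_guard<std::mutex> lock(mtx);
            free_blocks.push_back(std::move(blk));
        }
        cv.notify_one();                 // wake one waiter so it retries
    }
};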


+ 2
- 2
src/core/include/megbrain/version.h View File

@@ -12,8 +12,8 @@
#pragma once

#define MGB_MAJOR 8
-#define MGB_MINOR 9999
-#define MGB_PATCH 0
+#define MGB_MINOR 10
+#define MGB_PATCH 1
//! whether it is development version
#ifndef MGB_IS_DEV
#define MGB_IS_DEV 0


+ 2
- 1
src/gopt/impl/inference.cpp View File

@@ -1565,7 +1565,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
                    if (new_inp[i]->shape()[1] % 4 != 0) {
                        can_exec_cd4 = false;
                    }
-               //! cd4 elemwise with scaler is supported
+               //! cd4 elemwise with scaler is unsupported
                } else if (!new_inp[i]->shape().is_scalar()) {
                    can_exec_cd4 = false;
                }
@@ -1627,6 +1627,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
    replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw;
+   replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            replace_warp_perspective_opr;


+ 51
- 2
src/gopt/test/inference.cpp View File

@@ -1265,6 +1265,55 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) {
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}


TEST(TestGoptInference, ConvertFormatNHWCD4TypeCvt) {
NaiveMegDNNHandleScope naive_megdnn_handle;

HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name);
};
auto host_x = gen({8, 8, 8, 8}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);

opr::Convolution::Param param;

param.pad_h = param.pad_w = 0;
auto w1 = mkcvar("w1", {8, 8, 3, 3}),
conv1 = opr::Convolution::make(x, w1, param),
tcvt1 = opr::TypeCvt::make(conv1, dtype::Float16());
auto w2 = mkcvar("w2", {8, 8, 3, 3}),
conv2 = opr::Convolution::make(x, w2, param),
tcvt2 = opr::TypeCvt::make(conv2, dtype::Float16());
auto y = opr::Elemwise::make({tcvt1, tcvt2}, opr::Elemwise::Param::Mode::ADD);

SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);

ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
find_opr<opr::Convolution>(y_opt).param().format);

graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.ConvertFormatNHWCD4TypeCvt.json"));

HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);

*host_x = *gen({8, 8, 16, 16}, cn);
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
}

TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) {
    // hwcd4 is only supported in naive handle
    NaiveMegDNNHandleScope naive_megdnn_handle;
@@ -1707,8 +1756,8 @@ TEST(TestGoptInference, FastProfileCache) {
    using S = opr::Convolution::ExecutionPolicy::Strategy;
    ASSERT_EQ(S::HEURISTIC, conv.execution_policy_transient().strategy);
    gopt::modify_opr_algo_strategy_inplace({z + 2.3f},
-                                          S::PROFILE | S::OPTMIZED);
-   ASSERT_EQ(S::PROFILE | S::OPTMIZED, conv.execution_policy().strategy);
+                                          S::PROFILE | S::OPTIMIZED);
+   ASSERT_EQ(S::PROFILE | S::OPTIMIZED, conv.execution_policy().strategy);
}


TEST(TestGoptInference, AlgoWorkspaceLimit) {


+ 3
- 3
src/opr/impl/dnn/dnn.oprdecl View File

@@ -6,7 +6,7 @@ decl_opr('Convolution',
             'convolution kernel in '
             '(out channel, in channel, kern row, kern col) format')],
         params=[('param', 'ConvolutionV0'),
-                ('execution_polity', 'ExecutionPolicy')],
+                ('execution_polity', 'ExecutionPolicyV0')],
         desc='batched convolution on channeled 2D images')

decl_opr('Convolution',
@@ -28,7 +28,7 @@ decl_opr('ConvolutionBackwardData',
             'convolution kernel in '
             '(out channel, in channel, kern row, kern col) format')],
         params=[('param', 'ConvolutionV0'),
-                ('execution_polity', 'ExecutionPolicy')],
+                ('execution_polity', 'ExecutionPolicyV0')],
         body=[
             'a, b = all_inputs',
             'all_inputs = [b, a]'
@@ -201,7 +201,7 @@ decl_opr('ConvBiasForward',
             Doc('bias', 'bias'),
         ],
         params=[('param', 'ConvBiasV1'),
-                ('execution_policy', 'ExecutionPolicy')],
+                ('execution_policy', 'ExecutionPolicyV0')],
         desc=('activation(convolution(src, filter) + bias) with specified '
               'dtype'),
         has_out_dtype=True)


+ 1
- 1
src/opr/impl/search_policy/algo_chooser.cpp View File

@@ -283,7 +283,7 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
static bool algo_attribute_match_strategy(AlgoAttribute attribute,
                                          ExecutionStrategy selected_strategy) {
    bool ret = true;
-   if (selected_strategy & ExecutionStrategy::OPTMIZED) {
+   if (selected_strategy & ExecutionStrategy::OPTIMIZED) {
        ret &= (!static_cast<bool>(AlgoAttribute::NAIVE & attribute));
    } else if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) {
        ret &= static_cast<bool>(AlgoAttribute::REPRODUCIBLE & attribute);


+ 5
- 5
src/opr/test/dnn/convolution.cpp View File

@@ -357,7 +357,7 @@ TEST(TestOprDNN, ConvBiasExePolicy) {
#if MGB_ENABLE_FASTRUN
    for (auto strategy :
         SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
-                       S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
+                       S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) {
#else
    for (auto strategy :
         SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
@@ -444,7 +444,7 @@ TEST(TestOprDNN, ConvolutionExePolicy) {
#if MGB_ENABLE_FASTRUN
    for (auto strategy :
         SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
-                       S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
+                       S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) {
#else
    for (auto strategy :
         SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
@@ -1717,7 +1717,7 @@ TEST(TestOprDNN, LocalShareForwardExecPolicy) {
#if MGB_ENABLE_FASTRUN
    for (auto strategy :
         SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
-                       S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
+                       S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) {
#else
    for (auto strategy :
         SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
@@ -1828,7 +1828,7 @@ TEST(TestOprDNN, DeformableConvForward) {
#if MGB_ENABLE_FASTRUN
    for (auto strategy :
         SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
-                       S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
+                       S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) {
#else
    for (auto strategy :
         SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
@@ -1997,7 +1997,7 @@ TEST(TestOprDNN, BatchConvBiasForward) {
#if MGB_ENABLE_FASTRUN
    for (auto strategy :
         SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
-                       S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) {
+                       S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) {
#else
    for (auto strategy :
         SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {


+ 2
- 5
src/serialization/impl/extern_c_opr.cpp View File

@@ -290,11 +290,8 @@ ExternCOprRunner::ExternCOprRunner(std::string& name,
          m_dump_name{name},
          m_param{nullptr} {
    mgb_assert(m_desc->size == sizeof(MGBOprDesc),
-              "invalid MGBOprDesc size: expect=%zu got=%u, may caused by "
-              "extern_c_opr.h mismatch, please confirm that the "
-              "extern_c_opr.h used when compiling the loader is consistent "
-              "with the runtime caller build used",
-              sizeof(MGBOprDesc), m_desc->size);
+              "invalid MGBOprDesc size: expect=%zu got=%u", sizeof(MGBOprDesc),
+              m_desc->size);
    for (auto i : inputs) {
        add_input({i});
    }


+ 8
- 0
tools/mlir/mgb-file-check/CMakeLists.txt View File

@@ -0,0 +1,8 @@
add_custom_command(
OUTPUT link_sh
COMMAND ${CMAKE_COMMAND} -E create_symlink
${PROJECT_SOURCE_DIR}/tools/mlir/mgb-file-check/mgb-file-check.sh
${PROJECT_BINARY_DIR}/tools/mlir/mgb-file-check/mgb-file-check
)

add_custom_target(mgb-file-check DEPENDS link_sh)

+ 3
- 0
tools/mlir/mgb-file-check/mgb-file-check.sh View File

@@ -0,0 +1,3 @@
#!/bin/bash -e

FileCheck --enable-var-scope --dump-input=fail "$@"

+ 23
- 0
tools/mlir/mgb-opt/CMakeLists.txt View File

@@ -0,0 +1,23 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
set(LIBS
${dialect_libs}
${conversion_libs}
LLVMSupport
MLIROptLib
MLIRIR
MLIRPass
MLIRSupport
)
add_executable(mgb-opt mgb-opt.cpp)

target_include_directories(
mgb-opt
PRIVATE ${MLIR_LLVM_INCLUDE_DIR} ${PROJECT_SOURCE_DIR}/src/jit/include
${PROJECT_BINARY_DIR}/src/jit/include)

add_dependencies(mgb-opt mgb_dialect)

target_link_libraries(mgb-opt PRIVATE ${LIBS} megbrain megdnn ${MGE_CUDA_LIBS})

llvm_update_compile_flags(mgb-opt)

+ 85
- 0
tools/mlir/mgb-opt/mgb-opt.cpp View File

@@ -0,0 +1,85 @@
/**
* \file tools/mlir/mgb-opt/mgb-opt.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"

#include <llvm/Support/CommandLine.h>
#include <llvm/Support/InitLLVM.h>
#include <llvm/Support/PrettyStackTrace.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/ToolOutputFile.h>

#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/AsmState.h>
#include <mlir/InitAllDialects.h>
#include <mlir/InitAllPasses.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/MlirOptMain.h>

using namespace llvm;
using namespace mlir;

//! TODO: Implement a custom MlirOptMain that supports the following flags.
static cl::opt<bool> print_mlir{
"print-mlir",
cl::desc("Prints MLIR IR after translation"),
cl::init(false),
};

static cl::list<std::string> input_values{
"input-value",
cl::desc("Input shapes and optional values"),
cl::ZeroOrMore,
};

static cl::opt<std::string> input_values_file{
"input-value-file",
cl::desc("Provides a file for input shapes and optional values (see "
"ParseToVariantListFromFile in vm_util.h for details)"),
cl::init(""),
};

static cl::opt<bool> run{
"run",
cl::desc("Runs the module (vs. just compiling and verifing)"),
cl::init(true),
};

static cl::list<std::string> run_args{
"run-arg",
cl::desc("Argument passed to the execution flag parser"),
cl::ZeroOrMore,
};

namespace mgb {
namespace jit {
void register_test_mgb_to_affine_lowering_pass();
void register_test_affine_to_llvm_lowering_pass();
} // namespace jit
} // namespace mgb

int main(int argc, char** argv) {
mlir::registerAllPasses();

mlir::DialectRegistry registry;
mlir::registerAllDialects(registry);
registry.insert<mgb::jit::MgbDialect>();

mgb::jit::register_test_mgb_to_affine_lowering_pass();
mgb::jit::register_test_affine_to_llvm_lowering_pass();

return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver", registry));
}

+ 8
- 2
tools/param_defs/mgb_opr_param_defs.py View File

@@ -41,8 +41,14 @@ pdef('PersistentOutputStorage').add_fields(
        Doc('REPRODUCIBLE',
            'when profile or heuristic algo selection it require the algos'
            'must be reproducible'),
-       Doc('OPTMIZED',
-           'profile require algos are optmized to achieve fast-profile')).
+       Doc('OPTIMIZED',
+           'profile require algos are optmized to achieve fast-profile'),
+       default=('HEURISTIC',),
+       member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'),
+                     (('PROFILE', 'REPRODUCIBLE'), 'PROFILE_REPRODUCIBLE'),
+                     (('PROFILE', 'HEURISTIC'), 'PROFILE_HEURISTIC'),
+                     (('OPTIMIZED',), 'OPTMIZED'),
+                     ]).
    add_fields('uint64',
               Doc('workspace_limit', 'workspace limit in bytes'),
               str(2**64-1)+'ull'))
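
The aliases added above simply name pre-combined values of the bit-combined strategy enum; in terms of the C++ spelling used by the tests earlier in this diff, the mapping is roughly as follows (the backward-compatibility reading of the 'OPTMIZED' alias is an assumption):

using S = opr::Convolution::ExecutionPolicy::Strategy;

// default=('HEURISTIC',)                  -> S::HEURISTIC
// ('HEURISTIC', 'REPRODUCIBLE') alias     -> S::HEURISTIC | S::REPRODUCIBLE
// ('PROFILE', 'REPRODUCIBLE') alias       -> S::PROFILE | S::REPRODUCIBLE
// ('PROFILE', 'HEURISTIC') alias          -> S::PROFILE | S::HEURISTIC
// ('OPTIMIZED',) alias named 'OPTMIZED'   -> S::OPTIMIZED, presumably kept so code
//                                            using the old spelling keeps working
S strategy = S::PROFILE | S::REPRODUCIBLE;  // equivalent of PROFILE_REPRODUCIBLE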

