Compare commits

...

6 Commits

Author SHA1 Message Date
  Megvii Engine Team cc6084b4d5 revert: chore(mge/misc): api coverage 2 years ago
  Megvii Engine Team 2a4295fbf6 feat(ci): update submodules 2 years ago
  Megvii Engine Team 75c413e3bb feat(imperative): region restricted conv support bias in python 2 years ago
  Megvii Engine Team 6f9f25a882 fix(gopt): fix global layout transform fold conv typecvt 2 years ago
  Megvii Engine Team cc9f743ed5 fix(serialization): fix elemwise multitype compatibility in v1.11.0 2 years ago
  wanchenxi 7b3f9e85bd chore(release): bump version 2 years ago
38 changed files with 266 additions and 184 deletions
  1. dnn/src/cuda/conv_bias/conv_nchwqs8.cpp  (+1, -1)
  2. imperative/python/megengine/autodiff/grad_manager.py  (+9, -9)
  3. imperative/python/megengine/data/dataloader.py  (+7, -7)
  4. imperative/python/megengine/data/dataset/vision/coco.py  (+2, -2)
  5. imperative/python/megengine/data/dataset/vision/mnist.py  (+6, -6)
  6. imperative/python/megengine/data/transform/vision/functional.py  (+0, -0)
  7. imperative/python/megengine/data/transform/vision/transform.py  (+1, -1)
  8. imperative/python/megengine/distributed/__init__.py  (+2, -1)
  9. imperative/python/megengine/distributed/group.py  (+5, -5)
  10. imperative/python/megengine/distributed/helper.py  (+6, -6)
  11. imperative/python/megengine/distributed/launcher.py  (+1, -1)
  12. imperative/python/megengine/distributed/server.py  (+9, -9)
  13. imperative/python/megengine/functional/__init__.py  (+3, -1)
  14. imperative/python/megengine/functional/debug_param.py  (+0, -8)
  15. imperative/python/megengine/functional/elemwise.py  (+1, -1)
  16. imperative/python/megengine/functional/nn.py  (+49, -44)
  17. imperative/python/megengine/functional/tensor_cache.py  (+2, -2)
  18. imperative/python/megengine/functional/utils.py  (+2, -0)
  19. imperative/python/megengine/hub/fetcher.py  (+3, -3)
  20. imperative/python/megengine/jit/tracing.py  (+3, -3)
  21. imperative/python/megengine/logger.py  (+14, -14)
  22. imperative/python/megengine/module/conv.py  (+6, -3)
  23. imperative/python/megengine/module/rnn.py  (+8, -8)
  24. imperative/python/megengine/quantization/observer.py  (+1, -1)
  25. imperative/python/megengine/serialization.py  (+2, -2)
  26. imperative/python/megengine/utils/deprecation.py  (+0, -2)
  27. imperative/python/megengine/utils/module_stats.py  (+2, -2)
  28. imperative/python/test/unit/core/test_util.py  (+2, -2)
  29. imperative/python/test/unit/distributed/test_distributed.py  (+1, -1)
  30. imperative/python/test/unit/functional/test_functional.py  (+43, -18)
  31. src/core/include/megbrain/version.h  (+2, -2)
  32. src/gopt/impl/folding_conv_typecvt.cpp  (+27, -11)
  33. src/gopt/impl/framework.cpp  (+3, -0)
  34. src/opr/impl/nn_int.sereg.h  (+1, -1)
  35. src/serialization/impl/opr_registry.cpp  (+6, -5)
  36. src/serialization/include/megbrain/serialization/opr_registry.h  (+2, -1)
  37. src/serialization/include/megbrain/serialization/sereg.h  (+33, -0)
  38. third_party/cutlass  (+1, -1)

dnn/src/cuda/conv_bias/conv_nchwqs8.cpp  (+1, -1)

@@ -110,7 +110,7 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
     bool is_version_ok = CUDNN_VERSION >= 7500;
     bool is_dtype_ok =
             (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
-             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 ||
+             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 &&
              args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm));
     bool is_bias_ok =
             args.bias_layout->ndim == 0 ||


imperative/python/megengine/autodiff/grad_manager.py  (+9, -9)

@@ -19,11 +19,11 @@ logger = get_logger(__name__)
 backwarding_grad_manager = None


-def _get_backwarding_grad_manager():
+def get_backwarding_grad_manager():
     return backwarding_grad_manager


-class _AttachSpec:
+class AttachSpec:
     __slots__ = "tensor", "callbacks"

@@ -118,7 +118,7 @@ class GradManager:
     """

     def __init__(self):
-        self._attach_specs = {}  # id(Tensor) -> _AttachSpec
+        self._attach_specs = {}  # id(Tensor) -> AttachSpec
         self._recording = False
         self._grad = None
         self._after_backward_callback = []

@@ -200,7 +200,7 @@ class GradManager:
             if self is not None:
                 del self._attach_specs[key]

-        spec = _AttachSpec()
+        spec = AttachSpec()
         spec.tensor = weakref.ref(tensor, deleter)
         spec.callbacks = []
         return spec

@@ -354,22 +354,22 @@ class GradManager:
     def __or__(self, other):
         if isinstance(other, GradManager):
-            return _GradManagerGroup([self, other])
+            return GradManagerGroup([self, other])
         return NotImplemented

     __ror__ = __or__


-class _GradManagerGroup:
+class GradManagerGroup:
     def __init__(self, gms) -> None:
         self._gms = list(gms)

     def merge_with(self, other):
         if isinstance(other, GradManager):
-            other = _GradManagerGroup([other])
-        elif not isinstance(other, _GradManagerGroup):
+            other = GradManagerGroup([other])
+        elif not isinstance(other, GradManagerGroup):
             return NotImplemented
-        return _GradManagerGroup([*self._gms, *other._gms])
+        return GradManagerGroup([*self._gms, *other._gms])

     __or__ = merge_with
     __ror__ = merge_with
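For reference, a minimal usage sketch (not taken from this diff) of how the restored public GradManagerGroup is reached through the | operator:

    from megengine.autodiff import GradManager

    gm_a, gm_b = GradManager(), GradManager()
    group = gm_a | gm_b              # GradManager.__or__ returns a GradManagerGroup
    group = group | GradManager()    # GradManagerGroup.merge_with via __or__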


imperative/python/megengine/data/dataloader.py  (+7, -7)

@@ -34,7 +34,7 @@ logger = get_logger(__name__)
 GLOBAL_TIMEOUT = 5


-def _raise_timeout_error():
+def raise_timeout_error():
     raise RuntimeError("dataloader timeout")

@@ -191,7 +191,7 @@ class DataLoader:
         )


-class _PreLoader:
+class PreLoader:
     def __init__(self, loader, preload):
         self.dataset = loader.dataset
         self.sampler = loader.sampler

@@ -319,7 +319,7 @@ class _ParallelDataLoaderIter:
             if success:
                 return data
             else:
-                _raise_timeout_error()
+                raise_timeout_error()
         else:
             while True:
                 success, data = self._try_get_data()

@@ -417,7 +417,7 @@ class _ParallelDataLoaderIter:
         self._shutdown_workers()


-class _BaseMapDataLoaderIter(_PreLoader):
+class _BaseMapDataLoaderIter(PreLoader):
     def __init__(self, loader, preload):
         super().__init__(loader, preload)

@@ -510,7 +510,7 @@ def get_worker_info():
     return _worker_info


-class _BaseStreamDataLoaderIter(_PreLoader):
+class _BaseStreamDataLoaderIter(PreLoader):
     def __init__(self, loader, preload):
         super().__init__(loader, preload)
         self.dataset_iter = iter(self.dataset)

@@ -552,7 +552,7 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             timer.cancel()
         waited_time = time.time() - start_time
         if waited_time > self.timeout:
-            _raise_timeout_error()
+            raise_timeout_error()
         return raw_data

     def _get_next_batch(self):

@@ -583,7 +583,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter, _ParallelDataLoad
                 place_holder = [next(self.dataset_iter)]
                 waited_time = time.time() - start_time
                 if self.timeout > 0 and waited_time > self.timeout:
-                    _raise_timeout_error()
+                    raise_timeout_error()
                 place_holder = self._get_remaind_data(place_holder)
         else:
             place_holder = next(self._sampler_iter)


imperative/python/megengine/data/dataset/vision/coco.py  (+2, -2)

@@ -21,7 +21,7 @@ def _count_visible_keypoints(anno):
     return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)


-def _has_valid_annotation(anno, order):
+def has_valid_annotation(anno, order):
     # if it"s empty, there is no annotation
     if len(anno) == 0:
         return False

@@ -101,7 +101,7 @@ class COCO(VisionDataset):
             anno = [
                 obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0
             ]
-            if _has_valid_annotation(anno, order):
+            if has_valid_annotation(anno, order):
                 ids.append(img_id)
                 self.img_to_anns[img_id] = anno
             else:


imperative/python/megengine/data/dataset/vision/mnist.py  (+6, -6)

@@ -140,17 +140,17 @@ class MNIST(VisionDataset):
         # load raw files and transform them into meta data and datasets Tuple(np.array)
         logger.info("process the raw files of %s set...", "train" if train else "test")
         if train:
-            meta_data_images, images = _parse_idx3(
+            meta_data_images, images = parse_idx3(
                 os.path.join(self.root, self.raw_file_name[0])
             )
-            meta_data_labels, labels = _parse_idx1(
+            meta_data_labels, labels = parse_idx1(
                 os.path.join(self.root, self.raw_file_name[1])
             )
         else:
-            meta_data_images, images = _parse_idx3(
+            meta_data_images, images = parse_idx3(
                 os.path.join(self.root, self.raw_file_name[2])
             )
-            meta_data_labels, labels = _parse_idx1(
+            meta_data_labels, labels = parse_idx1(
                 os.path.join(self.root, self.raw_file_name[3])
             )

@@ -161,7 +161,7 @@ class MNIST(VisionDataset):
         self.arrays = (images, labels.astype(np.int32))


-def _parse_idx3(idx3_file):
+def parse_idx3(idx3_file):
     # parse idx3 file to meta data and data in numpy array (images)
     logger.debug("parse idx3 file %s ...", idx3_file)
     assert idx3_file.endswith(".gz")

@@ -187,7 +187,7 @@ def _parse_idx3(idx3_file):
     return meta_data, images


-def _parse_idx1(idx1_file):
+def parse_idx1(idx1_file):
     # parse idx1 file to meta data and data in numpy array (labels)
     logger.debug("parse idx1 file %s ...", idx1_file)
     assert idx1_file.endswith(".gz")


imperative/python/megengine/data/transform/vision/_functional.py → imperative/python/megengine/data/transform/vision/functional.py  (renamed, +0, -0)


imperative/python/megengine/data/transform/vision/transform.py  (+1, -1)

@@ -7,7 +7,7 @@ import cv2
 import numpy as np

 from megengine.data.transform import Transform
-from megengine.data.transform.vision import _functional as F
+from megengine.data.transform.vision import functional as F

 __all__ = [
     "VisionTransform",


imperative/python/megengine/distributed/__init__.py  (+2, -1)

@@ -2,6 +2,7 @@
 from mprop import mproperty

 from ..core._imperative_rt.core2 import group_end, group_start
+from . import group
 from .group import (
     WORLD,
     Group,

@@ -19,7 +20,7 @@ from .group import (
 )
 from .helper import bcast_list_, make_allreduce_cb, synchronized
 from .launcher import launcher
-from .server import Server
+from .server import Client, Server


 @mproperty


imperative/python/megengine/distributed/group.py  (+5, -5)

@@ -7,10 +7,10 @@ from mprop import mproperty
 from ..device import _sh, set_default_device, what_is_xpu
 from ..random import seed
-from .server import Server, _Client
+from .server import Client, Server


-class _StaticData:
+class StaticData:
     server = None
     client = None
     master_ip = None

@@ -139,13 +139,13 @@ def init_process_group(
     global _sd
     assert _sd is None, "init_process_group should be called only once"
-    _sd = _StaticData()
+    _sd = StaticData()

     assert world_size > 1
     assert rank >= 0 and rank < world_size
     assert port > 0

-    _sd.client = _Client(master_ip, port)
+    _sd.client = Client(master_ip, port)
     _sd.master_ip = master_ip
     _sd.py_server_port = port
     _sd.mm_server_port = _sd.client.get_mm_server_port()

@@ -225,7 +225,7 @@ def get_mm_server_addr() -> Tuple[str, int]:
     return _sd.master_ip, _sd.mm_server_port


-def get_client() -> _Client:
+def get_client() -> Client:
     r"""Get client of python XML RPC server."""
     assert _sd is not None, "please call init_process_group first"
     return _sd.client


imperative/python/megengine/distributed/helper.py  (+6, -6)

@@ -7,7 +7,7 @@ from weakref import WeakSet
 import numpy as np

-from megengine.autodiff.grad_manager import GradManager, _get_backwarding_grad_manager
+from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager

 from ..core._imperative_rt.core2 import apply
 from ..core.ops.builtin import ParamPackConcat, ParamPackSplit

@@ -78,7 +78,7 @@ def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
     return apply(op, *inps, offsets)[0]


-def _get_offsets(shapes):
+def get_offsets(shapes):
     offsets = []
     offset = 0
     for shape in shapes:

@@ -108,7 +108,7 @@ def _check_enable_p2p():


 def pack_allreduce_split(pack_list, shapes, group, reduce_method):
-    offsets_val = _get_offsets(shapes)
+    offsets_val = get_offsets(shapes)
     offsets = Tensor(offsets_val)
     packed_grads = param_pack_concat(pack_list, offsets, offsets_val)

@@ -119,7 +119,7 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method):
     return grads


-class _TensorFuture(Future):
+class TensorFuture(Future):
     def device(self):
         raise "Sorry, this tensor is not ready"

@@ -234,13 +234,13 @@ class AllreduceCallback:
         self._packing_size[dtype] = 0

     def __call__(self, param, grad):
-        gm = _get_backwarding_grad_manager()
+        gm = get_backwarding_grad_manager()
         assert isinstance(gm, GradManager)
         if gm not in self._marked_gm:
             gm._register_after_backward_callback(self._flush)
             self._marked_gm.add(gm)
         self._params.append(param)
-        self._futures_dict[param] = _TensorFuture(ack=False)
+        self._futures_dict[param] = TensorFuture(ack=False)
         self._gradients_dict[param] = grad
         self._grad_origin_device[param] = str(grad.device)




imperative/python/megengine/distributed/launcher.py  (+1, -1)

@@ -10,7 +10,7 @@ from ..device import get_device_count
 from ..logger import get_logger
 from .group import _set_machine_ranks, group_barrier, init_process_group
 from .helper import _check_device_initialized, _check_interpreter_status
-from .server import Server
+from .server import Client, Server

 WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
     "subprocess exited with code 0 but did not return a value"


imperative/python/megengine/distributed/server.py  (+9, -9)

@@ -12,7 +12,7 @@ from ..core._imperative_rt.utils import create_mm_server
 from ..utils.future import Future


-class _Methods:
+class Methods:
     r"""Distributed Server Method.
     Used for exchange information between distributed nodes.

@@ -149,7 +149,7 @@ class _Methods:
         return ret


-class _ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
+class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
     pass

@@ -163,10 +163,10 @@ def _start_server(py_server_port, queue):
     """
     try:
         mm_server_port = create_mm_server("0.0.0.0", 0)
-        server = _ThreadXMLRPCServer(
+        server = ThreadXMLRPCServer(
             ("0.0.0.0", py_server_port), logRequests=False, allow_none=True
         )
-        server.register_instance(_Methods(mm_server_port))
+        server.register_instance(Methods(mm_server_port))
         _, py_server_port = server.server_address
         queue.put((py_server_port, mm_server_port))
         server.serve_forever()

@@ -196,7 +196,7 @@ class Server:
             self.proc.terminate()


-class _Client:
+class Client:
     r"""Distributed Client for distributed training.

     Args:

@@ -298,10 +298,10 @@ class _Client:
         return self.proxy.bcast_val(val, key, size)


-def _main(port=0, verbose=True):
+def main(port=0, verbose=True):
     mm_server_port = create_mm_server("0.0.0.0", 0)
-    server = _ThreadXMLRPCServer(("0.0.0.0", port), logRequests=verbose)
-    server.register_instance(_Methods(mm_server_port))
+    server = ThreadXMLRPCServer(("0.0.0.0", port), logRequests=verbose)
+    server.register_instance(Methods(mm_server_port))
     _, port = server.server_address
     print("serving on port", port)
     server.serve_forever()

@@ -314,4 +314,4 @@ if __name__ == "__main__":
     ap.add_argument("-p", "--port", type=int, default=0)
     ap.add_argument("-v", "--verbose", type=bool, default=True)
     args = ap.parse_args()
-    _main(port=args.port, verbose=args.verbose)
+    main(port=args.port, verbose=args.verbose)
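For reference, a minimal single-machine sketch of the re-exported Server/Client pair; Client(master_ip, port) and get_mm_server_port() appear in the diffs above, while the py_server_port attribute on Server is an assumption:

    import megengine.distributed as dist

    server = dist.Server(port=0)                               # XML-RPC rendezvous server
    client = dist.Client("localhost", server.py_server_port)   # py_server_port is assumed
    print(client.get_mm_server_port())                         # memory-management server port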

imperative/python/megengine/functional/__init__.py  (+3, -1)

@@ -1,11 +1,13 @@
 # -*- coding: utf-8 -*-
 # pylint: disable=redefined-builtin
+from . import metric, utils, vision
 from .elemwise import *
 from .math import *
 from .nn import *
 from .tensor import *
+from .utils import *

-from . import utils, vision, distributed  # isort:skip
+from . import distributed  # isort:skip

 # delete namespace
 # pylint: disable=undefined-variable


imperative/python/megengine/functional/debug_param.py  (+0, -8)

@@ -21,10 +21,6 @@ _valid_string_option = {
 }


-@deprecated(
-    version="1.10",
-    reason="use ``megengine.config.benchmark_kernel`` and ``megengine.config.deterministic_kernel`` instead",
-)
 def get_execution_strategy() -> Strategy:
     r"""Returns the execution strategy of :class:`~.module.Conv2d` and :func:`~.matmul`

@@ -40,10 +36,6 @@ def get_execution_strategy() -> Strategy:
     return strategy


-@deprecated(
-    version="1.10",
-    reason="use ``megengine.config.benchmark_kernel`` and ``megengine.config.deterministic_kernel`` instead",
-)
 def set_execution_strategy(option):
     r"""Sets the execution strategy of :class:`~.module.Conv2d` and :func:`~.matmul`




imperative/python/megengine/functional/elemwise.py  (+1, -1)

@@ -9,7 +9,7 @@ from ..core.tensor.array_method import _elwise
 from ..core.tensor.utils import convert_inputs
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
-from ._tensor_cache import get_scalar_one
+from .tensor_cache import get_scalar_one

 __all__ = [
     "abs",


imperative/python/megengine/functional/nn.py  (+49, -44)

@@ -43,7 +43,6 @@ from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import _elwise, exp, log, log1p, maximum, minimum
 from .math import max, normalize, sum
-from .metric import topk_accuracy
 from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros

 __all__ = [

@@ -87,7 +86,6 @@ __all__ = [
     "softmax",
     "softplus",
     "sync_batch_norm",
-    "topk_accuracy",
     "warp_affine",
     "warp_perspective",
     "pixel_shuffle",

@@ -95,7 +93,7 @@ __all__ = [
 ]


-def _expand_hw(x):
+def expand_hw(x):
     # judge int is 5 times faster than judge Sequence
     if isinstance(x, int):
         return x, x

@@ -104,7 +102,7 @@ def _expand_hw(x):
     return int(x), int(x)


-def _expand_dhw(x):
+def expand_dhw(x):
     if isinstance(x, int):
         return x, x, x
     if isinstance(x, Sequence):

@@ -246,9 +244,9 @@ def conv2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )

-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    dilate_h, dilate_w = expand_hw(dilation)

     sparse_type = "dense" if groups == 1 else "group"
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)

@@ -308,9 +306,9 @@ def conv3d(
     D, H, W = 0, 1, 2

-    pad = _expand_dhw(padding)
-    stride = _expand_dhw(stride)
-    dilate = _expand_dhw(dilation)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)

     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3D(

@@ -378,10 +376,10 @@ def conv_transpose2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )

-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    output_pad_h, output_pad_w = _expand_hw(output_padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    output_pad_h, output_pad_w = expand_hw(output_padding)
+    dilate_h, dilate_w = expand_hw(dilation)

     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     sparse_type = "dense" if groups == 1 else "group"

@@ -479,9 +477,9 @@ def deformable_conv2d(
     offset = offset.astype("float32")
     mask = mask.astype("float32")

-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    dilate_h, dilate_w = expand_hw(dilation)

     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     sparse_type = "dense" if groups == 1 else "group"

@@ -533,9 +531,9 @@ def local_conv2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )

-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    dilate_h, dilate_w = expand_hw(dilation)

     # local conv only support "dense" mode, but weight could contain group dimension.
     op = builtin.GroupLocal(

@@ -589,10 +587,10 @@ def conv_transpose3d(
         output tensor.
     """
     D, H, W = 0, 1, 2
-    pad = _expand_dhw(padding)
-    stride = _expand_dhw(stride)
-    dilate = _expand_dhw(dilation)
-    output_padding = _expand_dhw(output_padding)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
+    output_padding = expand_dhw(output_padding)

     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3DBackwardData(

@@ -677,9 +675,9 @@ def max_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _expand_hw(kernel_size)
-    stride_h, stride_w = _expand_hw(stride)
-    padding_h, padding_w = _expand_hw(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)

     op = builtin.Pooling(
         window_h=window_h,

@@ -727,9 +725,9 @@ def avg_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _expand_hw(kernel_size)
-    stride_h, stride_w = _expand_hw(stride)
-    padding_h, padding_w = _expand_hw(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)

     op = builtin.Pooling(
         window_h=window_h,

@@ -1725,10 +1723,10 @@ def sliding_window(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    padding_h, padding_w = _expand_hw(padding)
-    stride_h, stride_w = _expand_hw(stride)
-    dilation_h, dilation_w = _expand_hw(dilation)
-    window_h, window_w = _expand_hw(kernel_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)

     op = builtin.Images2Neibs(
         pad_h=padding_h,

@@ -1764,11 +1762,11 @@ def sliding_window_transpose(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    output_h, output_w = _expand_hw(output_size)
-    padding_h, padding_w = _expand_hw(padding)
-    stride_h, stride_w = _expand_hw(stride)
-    dilation_h, dilation_w = _expand_hw(dilation)
-    window_h, window_w = _expand_hw(kernel_size)
+    output_h, output_w = expand_hw(output_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)

     expected_h = (
         output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1

@@ -1921,7 +1919,7 @@ def _get_layerPixelShuffle(device, dtype, dim_order):
     return layerPixelShuffle


-def _layerPixelShuffle_traceable(inp, upscale_factor):
+def layerPixelShuffle_traceable(inp, upscale_factor):
     assert upscale_factor > 0, "upscale_factor should larger than 0"
     assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3"
     assert (

@@ -1972,7 +1970,7 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
     :param upscale_factor: upscale factor of pixel_shuffle.
     :return: output tensor.
     """
-    return pixel_shuffle_cpp(inp, upscale_factor, _layerPixelShuffle_traceable)
+    return pixel_shuffle_cpp(inp, upscale_factor, layerPixelShuffle_traceable)


 def region_restricted_conv(

@@ -1980,6 +1978,7 @@ def region_restricted_conv(
     weight: Tensor,
     rin: Tensor,
     rout: Tensor,
+    bias: Optional[Tensor] = None,
     stride: Union[int, Tuple[int, int, int]] = 1,
     padding: Union[int, Tuple[int, int, int]] = 0,
     dilation: Union[int, Tuple[int, int, int]] = 1,

@@ -1994,6 +1993,9 @@ def region_restricted_conv(
     Args:
         inp: feature map of the convolution operation.
         weight: convolution kernel.
+        rin: input mask
+        rout: output mask
+        bias: bias added to the result of convolution (if given).
         stride: stride of the 2D region restricted convolution operation. Default: 1
         padding: size of the paddings added to the input on both sides of its
             spatial dimensions. Only zero-padding is supported. Default: 0

@@ -2010,9 +2012,9 @@ def region_restricted_conv(
     """
     assert conv_mode.lower() == "cross_correlation"

-    pad_h, pad_w = _expand_hw(padding)
-    stride_h, stride_w = _expand_hw(stride)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    pad_h, pad_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilate_h, dilate_w = expand_hw(dilation)

     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.RegionRestrictedConvolution(

@@ -2027,9 +2029,12 @@ def region_restricted_conv(
         sparse=sparse_type,
     )
     (output,) = apply(op, inp, weight, rin, rout)
+    if bias is not None:
+        output += bias
     return output


 from .quantized import conv_bias_activation  # isort:skip
 from .loss import *  # isort:skip
+from .metric import *  # isort:skip
 from .vision import *  # isort:skip
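For reference, a minimal sketch of the new bias argument on F.region_restricted_conv; the grouped weight layout, mask shapes, and mask dtype used here are illustrative assumptions rather than values taken from this diff:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    src = tensor(np.arange(8, dtype=np.float32).reshape(1, 2, 2, 2))   # NCHW input
    weight = tensor(np.ones((2, 1, 1, 2, 2), dtype=np.float32))        # (groups, ocpg, icpg, kh, kw)
    rin = tensor(np.ones((1, 2, 2), dtype=np.int32))                   # input region mask
    rout = tensor(np.ones((1, 1, 1), dtype=np.int32))                  # output region mask
    bias = tensor(np.ones((1, 2, 1, 1), dtype=np.float32))             # broadcast over N, H, W

    out = F.region_restricted_conv(src, weight, rin, rout, bias=bias, groups=2)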

imperative/python/megengine/functional/_tensor_cache.py → imperative/python/megengine/functional/tensor_cache.py  (+2, -2)

@@ -1,12 +1,12 @@
 from ..core._imperative_rt.core2 import Const
-from ..jit.tracing import _is_tracing
+from ..jit.tracing import is_tracing

 small_tensor_cache = {}


 def _get_scalar_tensor_with_value(value, dtype=None, device=None):
     global small_tensor_cache
-    if _is_tracing():
+    if is_tracing():
         ret = Const(value, dtype, device)
     else:
         cache_key = (value, dtype, device)

imperative/python/megengine/functional/utils.py  (+2, -0)

@@ -7,6 +7,8 @@ from ..utils.deprecation import deprecated_func
 from .elemwise import abs, maximum, minimum
 from .tensor import ones, zeros

+__all__ = ["topk_accuracy"]
+

 def _assert_equal(
     expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False


imperative/python/megengine/hub/fetcher.py  (+3, -3)

@@ -36,7 +36,7 @@ pattern = re.compile(
 )


-class _RepoFetcherBase:
+class RepoFetcherBase:
     @classmethod
     def fetch(
         cls,

@@ -84,7 +84,7 @@ class _RepoFetcherBase:
         return hashlib.sha1(repo_dir.encode()).hexdigest()[:16]


-class GitSSHFetcher(_RepoFetcherBase):
+class GitSSHFetcher(RepoFetcherBase):
     @classmethod
     @synchronized
     def fetch(

@@ -193,7 +193,7 @@ class GitSSHFetcher(_RepoFetcherBase):
         )


-class GitHTTPSFetcher(_RepoFetcherBase):
+class GitHTTPSFetcher(RepoFetcherBase):
     @classmethod
     @synchronized
     def fetch(


imperative/python/megengine/jit/tracing.py  (+3, -3)

@@ -49,7 +49,7 @@ active_trace = None
 skip_tracing = False


-def _is_tracing():
+def is_tracing():
     if active_trace is None:
         return False
     else:

@@ -73,7 +73,7 @@ def exclude_from_trace():
     skip_tracing = False


-def _array_comparator(lhs, rhs):
+def array_comparator(lhs, rhs):
     return np.all(lhs == rhs)

@@ -184,7 +184,7 @@ class trace:
         self._trace.no_exec = record_only
         self._trace.options_visitor = apply_options
         self._trace.profile = profiling
-        self._trace.array_comparator = _array_comparator
+        self._trace.array_comparator = array_comparator
         self._trace.record_input_shapes = _input_node_use_static_shape()

     def __call__(self, *args, **kwargs):


imperative/python/megengine/logger.py  (+14, -14)

@@ -18,10 +18,10 @@ def set_log_file(fout, mode="a"):
     """
     if isinstance(fout, str):
         fout = open(fout, mode)
-    _MegEngineLogFormatter.log_fout = fout
+    MegEngineLogFormatter.log_fout = fout


-class _MegEngineLogFormatter(logging.Formatter):
+class MegEngineLogFormatter(logging.Formatter):
     log_fout = None
     date_full = "[%(asctime)s %(lineno)d@%(filename)s:%(name)s] "
     date = "%(asctime)s "

@@ -71,7 +71,7 @@ class _MegEngineLogFormatter(logging.Formatter):

         if self.log_fout:
             self.__set_fmt(self.date_full + mtxt + self.msg)
-            formatted = super(_MegEngineLogFormatter, self).format(record)
+            formatted = super(MegEngineLogFormatter, self).format(record)
             nr_line = formatted.count("\n") + 1
             if nr_line >= self.max_lines:
                 head, body = formatted.split("\n", 1)

@@ -88,7 +88,7 @@ class _MegEngineLogFormatter(logging.Formatter):
             self.log_fout.flush()

         self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
-        formatted = super(_MegEngineLogFormatter, self).format(record)
+        formatted = super(MegEngineLogFormatter, self).format(record)

         if record.exc_text or record.exc_info:
             # handle exception format

@@ -125,7 +125,7 @@ class _MegEngineLogFormatter(logging.Formatter):
         self._style._fmt = fmt


-def get_logger(name=None, formatter=_MegEngineLogFormatter):
+def get_logger(name=None, formatter=MegEngineLogFormatter):
     r"""Gets megengine logger with given name."""

     logger = logging.getLogger(name)

@@ -167,16 +167,16 @@ try:
     from .core._imperative_rt.utils import Logger as _imperative_rt_logger

-    class _MegBrainLogFormatter(_MegEngineLogFormatter):
+    class MegBrainLogFormatter(MegEngineLogFormatter):
         date = "%(asctime)s[mgb] "

         def _color_date(self, msg):
             return "\x1b[33m{}\x1b[0m".format(msg)

-    _megbrain_logger = get_logger("megbrain", _MegBrainLogFormatter)
+    _megbrain_logger = get_logger("megbrain", MegBrainLogFormatter)
     _imperative_rt_logger.set_log_handler(_megbrain_logger)

-    def _set_mgb_log_level(level):
+    def set_mgb_log_level(level):
         r"""Sets megbrain log level

         Args:

@@ -200,30 +200,30 @@ try:
         )
         return rst

-    _set_mgb_log_level(_default_level)
+    set_mgb_log_level(_default_level)


 except ImportError as exc:

-    def _set_mgb_log_level(level):
+    def set_mgb_log_level(level):
         raise NotImplementedError("imperative_rt has not been imported")


 @contextlib.contextmanager
-def _replace_mgb_log_level(level):
+def replace_mgb_log_level(level):
     r"""Replaces megbrain log level in a block and restore after exiting.

     Args:
         level: new log level
     """
-    old = _set_mgb_log_level(level)
+    old = set_mgb_log_level(level)
     try:
         yield
     finally:
-        _set_mgb_log_level(old)
+        set_mgb_log_level(old)


 def enable_debug_log():
     r"""Sets logging level to debug for all components."""
     set_log_level(logging.DEBUG)
-    _set_mgb_log_level(logging.DEBUG)
+    set_mgb_log_level(logging.DEBUG)
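For reference, the public logger entry points touched above are used like this (the log file name is only illustrative):

    from megengine.logger import get_logger, set_log_file, enable_debug_log

    set_log_file("train.log")        # sets MegEngineLogFormatter.log_fout
    logger = get_logger(__name__)    # uses MegEngineLogFormatter by default
    enable_debug_log()               # also calls set_mgb_log_level(logging.DEBUG)
    logger.debug("dataloader workers started")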

imperative/python/megengine/module/conv.py  (+6, -3)

@@ -1040,6 +1040,7 @@ class RegionRestrictedConv(_ConvNd):
             ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
             and the shape of weight should be ``(groups, out_channel // groups,
             in_channels // groups, height, width)``. Default: 1
+        bias: whether to add a bias onto the result of convolution. Default: True
         conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
         compute_mode: When set to "default", no special requirements will be
             placed on the precision of intermediate results. When set to "float32",

@@ -1071,6 +1072,7 @@ class RegionRestrictedConv(_ConvNd):
         out_channels: int,
         kernel_size: Union[int, Tuple[int, int]],
         groups: int,
+        bias: bool = True,
         stride: Union[int, Tuple[int, int]] = 1,
         padding: Union[int, Tuple[int, int]] = 0,
         dilation: Union[int, Tuple[int, int]] = 1,

@@ -1095,7 +1097,7 @@ class RegionRestrictedConv(_ConvNd):
             0,
             dilation,
             groups,
-            False,
+            bias,
             **kwargs,
         )

@@ -1133,7 +1135,7 @@ class RegionRestrictedConv(_ConvNd):
             (self.padding[1], self.padding[1]),
         )

-    def calc_conv(self, inp, weight, rin, rout):
+    def calc_conv(self, inp, weight, rin, rout, bias):
         assert self.padding_mode in [
             "zeros",
             "reflect",

@@ -1144,6 +1146,7 @@ class RegionRestrictedConv(_ConvNd):
             weight,
             rin,
             rout,
+            bias,
             self.stride,
             self.padding,
             self.dilation,

@@ -1153,4 +1156,4 @@ class RegionRestrictedConv(_ConvNd):
         )

     def forward(self, inp, rin, rout):
-        return self.calc_conv(inp, self.weight, rin, rout)
+        return self.calc_conv(inp, self.weight, rin, rout, self.bias)
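For reference, a minimal sketch of the module-level bias switch added above; the input and mask shapes are illustrative assumptions, not values from this diff:

    import numpy as np
    import megengine.module as M
    from megengine import tensor

    conv = M.RegionRestrictedConv(
        in_channels=2, out_channels=2, kernel_size=2, groups=2, bias=True
    )
    inp = tensor(np.random.rand(1, 2, 2, 2).astype(np.float32))
    rin = tensor(np.ones((1, 2, 2), dtype=np.int32))    # input region mask
    rout = tensor(np.ones((1, 1, 1), dtype=np.int32))   # output region mask
    out = conv(inp, rin, rout)    # forward(inp, rin, rout); bias is applied internally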

imperative/python/megengine/module/rnn.py  (+8, -8)

@@ -15,12 +15,12 @@ from . import init
 from .module import Module


-class _RNNCellBase(Module):
+class RNNCellBase(Module):
     def __init__(
         self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
     ) -> None:
         # num_chunks indicates the number of gates
-        super(_RNNCellBase, self).__init__()
+        super(RNNCellBase, self).__init__()

         self.input_size = input_size
         self.hidden_size = hidden_size

@@ -57,7 +57,7 @@ class _RNNCellBase(Module):
         raise NotImplementedError("forward not implemented !")


-class RNNCell(_RNNCellBase):
+class RNNCell(RNNCellBase):

     r"""An Elman RNN cell with tanh or ReLU non-linearity.

@@ -135,7 +135,7 @@ class RNNCell(_RNNCellBase):
         )[0]


-class LSTMCell(_RNNCellBase):
+class LSTMCell(RNNCellBase):

     r"""A long short-term memory (LSTM) cell.

@@ -216,7 +216,7 @@ class LSTMCell(_RNNCellBase):
         )[:2]


-class _RNNBase(Module):
+class RNNBase(Module):
     def __init__(
         self,
         input_size: int,

@@ -228,7 +228,7 @@ class _RNNBase(Module):
         bidirectional: bool = False,
         proj_size: int = 0,
     ) -> None:
-        super(_RNNBase, self).__init__()
+        super(RNNBase, self).__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.num_layers = num_layers

@@ -323,7 +323,7 @@ class _RNNBase(Module):
         return output, h


-class RNN(_RNNBase):
+class RNN(RNNBase):

     r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
     input sequence.

@@ -453,7 +453,7 @@ class RNN(_RNNBase):
         return output, h


-class LSTM(_RNNBase):
+class LSTM(RNNBase):

     r"""Applies a multi-layer long short-term memory LSTM to an input
     sequence.


imperative/python/megengine/quantization/observer.py  (+1, -1)

@@ -7,7 +7,7 @@ import numpy as np
 from .. import functional as F
 from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
-from ..distributed import WORLD, is_distributed
+from ..distributed import WORLD, get_rank, is_distributed
 from ..functional.distributed import all_reduce_max, all_reduce_min
 from ..logger import get_logger
 from ..module import Module


imperative/python/megengine/serialization.py  (+2, -2)

@@ -66,7 +66,7 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.DEFAULT_PROTOCOL):
     pickle_module.dump(obj, f, pickle_protocol)


-class _dmap:
+class dmap:
     def __init__(self, map_location):
         self.map_location = map_location

@@ -177,5 +177,5 @@ def load(f, map_location=None, pickle_module=pickle):
     map_location = _get_callable_map_location(map_location)  # callable map_location

-    with _dmap(map_location) as dm:
+    with dmap(map_location) as dm:
         return pickle_module.load(f)

imperative/python/megengine/utils/deprecation.py  (+0, -2)

@@ -33,7 +33,6 @@ def deprecated_func(version, origin, name, tbd):
             )
             return func(*args, **kwargs)

-        wrapper.__deprecated__ = True
         return wrapper

@@ -58,7 +57,6 @@ def deprecated_kwargs_default(version, kwargs_name, kwargs_pos):
             )
             return func(*args, **kwargs)

-        wrapper.__deprecated__ = True
         return wrapper

     return deprecated

imperative/python/megengine/utils/module_stats.py  (+2, -2)

@@ -11,11 +11,11 @@ from .. import functional as F
 from .. import get_logger
 from .. import module as M
 from ..core.tensor.dtype import get_dtype_bit
-from ..logger import _MegEngineLogFormatter
+from ..logger import MegEngineLogFormatter
 from .module_utils import set_module_mode_safe

 try:
-    _MegEngineLogFormatter.max_lines = float("inf")
+    MegEngineLogFormatter.max_lines = float("inf")
 except AttributeError as e:
     raise ValueError("set logger max lines failed")




imperative/python/test/unit/core/test_util.py  (+2, -2)

@@ -2,14 +2,14 @@
 import logging

 from megengine.core._imperative_rt import Logger
-from megengine.logger import _imperative_rt_logger, _set_mgb_log_level
+from megengine.logger import _imperative_rt_logger, set_mgb_log_level


 def test_logger():
     orig_level = Logger().set_log_level(Logger.LogLevel.Debug)
     assert Logger().set_log_level(Logger.LogLevel.Debug) == Logger.LogLevel.Debug
     Logger().set_log_level(orig_level)
-    orig_level = _set_mgb_log_level(logging.DEBUG)
+    orig_level = set_mgb_log_level(logging.DEBUG)
     assert (
         _imperative_rt_logger.set_log_level(Logger.LogLevel.Debug)
         == Logger.LogLevel.Debug


imperative/python/test/unit/distributed/test_distributed.py  (+1, -1)

@@ -50,7 +50,7 @@ def test_init_process_group(backend):
         assert mm_server_addr[0] == "localhost"
         assert mm_server_addr[1] > 0

-        assert isinstance(dist.get_client(), dist.server._Client)
+        assert isinstance(dist.get_client(), dist.Client)

     procs = []
     for rank in range(world_size):


+ 43
- 18
imperative/python/test/unit/functional/test_functional.py View File

@@ -930,7 +930,8 @@ def test_batch_conv_bias():
run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True) run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)




def test_region_restricted_conv_forward_backward_naive():
@pytest.mark.parametrize("bias", [True, False])
def test_region_restricted_conv_forward_backward_naive(bias):
import megengine as mge import megengine as mge
import megengine.module as M import megengine.module as M
from megengine.autodiff import GradManager from megengine.autodiff import GradManager
@@ -943,15 +944,22 @@ def test_region_restricted_conv_forward_backward_naive():
cpu_src = tensor(src_1, device=handle) cpu_src = tensor(src_1, device=handle)
cpu_filter = tensor(filter_1, device=handle) cpu_filter = tensor(filter_1, device=handle)
gm = GradManager().attach([cpu_src, cpu_filter]) gm = GradManager().attach([cpu_src, cpu_filter])
cpu_bias = (
tensor(np.ones((1, 2, 1, 1), dtype=np.float32), device=handle) if bias else None
)
with gm: with gm:
cpu_out = F.region_restricted_conv( cpu_out = F.region_restricted_conv(
cpu_src, cpu_src,
cpu_filter, cpu_filter,
tensor(rin_1, device=handle), tensor(rin_1, device=handle),
tensor(rout_1, device=handle), tensor(rout_1, device=handle),
bias=cpu_bias,
groups=2, groups=2,
) )
gm.backward(cpu_out, tensor(np.ones((1, 2, 1, 1)), device=handle)) gm.backward(cpu_out, tensor(np.ones((1, 2, 1, 1)), device=handle))
if cpu_bias is not None:
cpu_out = cpu_out - cpu_bias
np.testing.assert_allclose(cpu_out, np.array([14, 126]).reshape(1, 2, 1, 1))
np.testing.assert_allclose( np.testing.assert_allclose(
cpu_src.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(1, 2, 2, 2) cpu_src.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(1, 2, 2, 2)
) )
@@ -963,7 +971,8 @@ def test_region_restricted_conv_forward_backward_naive():
@pytest.mark.skipif( @pytest.mark.skipif(
not is_cuda_available(), reason="rrconv cuda kernel requires cuda available" not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
) )
def test_region_restricted_conv_forward_backward_cuda():
@pytest.mark.parametrize("bias", [True, False])
def test_region_restricted_conv_forward_backward_cuda(bias):
import megengine as mge import megengine as mge
import megengine.module as M import megengine.module as M
from megengine.autodiff import GradManager from megengine.autodiff import GradManager
@@ -998,18 +1007,23 @@ def test_region_restricted_conv_forward_backward_cuda():
filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0") filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0") rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0") rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
bias_cpu = (
tensor(np.ones(diff_shape).astype(np.float32), device="cpu0")
if bias
else None
)
gm = GradManager().attach([src, filter]) gm = GradManager().attach([src, filter])
with gm: with gm:
expected_out = F.region_restricted_conv( expected_out = F.region_restricted_conv(
src, filter, rin, rout, groups=GROUP
src, filter, rin, rout, bias=bias_cpu, groups=GROUP
) )
gm.backward( gm.backward(
expected_out, expected_out,
tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"), tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
) )
return src, filter
return src, filter, expected_out


expected_src, expected_filter = get_groundtruth()
expected_src, expected_filter, expected_out = get_groundtruth()


src = tensor( src = tensor(
np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32), np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
@@ -1018,18 +1032,25 @@ def test_region_restricted_conv_forward_backward_cuda():
filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle) filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
rin = tensor(np.ones(rin_shape).astype(np.int32), device=handle) rin = tensor(np.ones(rin_shape).astype(np.int32), device=handle)
rout = tensor(np.ones(rout_shape).astype(np.int32), device=handle) rout = tensor(np.ones(rout_shape).astype(np.int32), device=handle)
bias_gpu = (
tensor(np.ones(diff_shape).astype(np.float32), device=handle) if bias else None
)
gm = GradManager().attach([src, filter]) gm = GradManager().attach([src, filter])
with gm: with gm:
gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP)
gpu_out = F.region_restricted_conv(
src, filter, rin, rout, bias=bias_gpu, groups=GROUP
)
gm.backward(gpu_out, tensor(np.ones(diff_shape), device=handle)) gm.backward(gpu_out, tensor(np.ones(diff_shape), device=handle))
np.testing.assert_allclose(src.grad, expected_src.grad) np.testing.assert_allclose(src.grad, expected_src.grad)
np.testing.assert_allclose(filter.grad, expected_filter.grad) np.testing.assert_allclose(filter.grad, expected_filter.grad)
np.testing.assert_allclose(gpu_out, expected_out)




@pytest.mark.skipif( @pytest.mark.skipif(
not is_cuda_available(), reason="rrconv cuda kernel requires cuda available" not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
) )
def test_region_restricted_conv_forward_backward_uint8():
@pytest.mark.parametrize("bias", [True, False])
def test_region_restricted_conv_forward_backward_uint8(bias):
import megengine as mge import megengine as mge
import megengine.module as M import megengine.module as M
from megengine.autodiff import GradManager from megengine.autodiff import GradManager
@@ -1063,18 +1084,23 @@ def test_region_restricted_conv_forward_backward_uint8():
filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0") filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0") rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0") rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
bias_cpu = (
tensor(np.ones(diff_shape).astype(np.float32), device="cpu0")
if bias
else None
)
gm = GradManager().attach([src, filter]) gm = GradManager().attach([src, filter])
with gm: with gm:
expected_out = F.region_restricted_conv( expected_out = F.region_restricted_conv(
src, filter, rin, rout, groups=GROUP
src, filter, rin, rout, bias=bias_cpu, groups=GROUP
) )
gm.backward( gm.backward(
expected_out, expected_out,
tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"), tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
) )
return src, filter
return src, filter, expected_out


expected_src, expected_filter = get_groundtruth()
expected_src, expected_filter, expected_out = get_groundtruth()


# forward and dgrad/wgrad # forward and dgrad/wgrad
src = tensor( src = tensor(
@@ -1084,23 +1110,22 @@ def test_region_restricted_conv_forward_backward_uint8():
filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle) filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
rin = tensor(np.ones(rin_shape).astype(np.uint8), device=handle) rin = tensor(np.ones(rin_shape).astype(np.uint8), device=handle)
rout = tensor(np.ones(rout_shape).astype(np.uint8), device=handle) rout = tensor(np.ones(rout_shape).astype(np.uint8), device=handle)
bias_gpu = (
tensor(np.ones(diff_shape).astype(np.float32), device=handle) if bias else None
)


gm = GradManager().attach([src, filter])
with gm:
gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP)
gpu_out = F.region_restricted_conv(
src, filter, rin, rout, bias=bias_gpu, groups=GROUP
)
gm.backward(
gpu_out, tensor(np.ones(diff_shape, dtype=np.float32), device=handle)
)
# assert uint8 gpu result close to cpu result
np.testing.assert_allclose(src.grad, expected_src.grad)
np.testing.assert_allclose(filter.grad, expected_filter.grad)


def test_region_restricted_conv():
test_region_restricted_conv_forward_backward_naive()
if is_cuda_available():
test_region_restricted_conv_forward_backward_cuda()
test_region_restricted_conv_forward_backward_uint8()
np.testing.assert_allclose(gpu_out, expected_out)




def test_conv2d_autocast():


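For readers skimming the test changes above: the new `bias` argument simply threads a bias tensor through `F.region_restricted_conv`, and passing `None` keeps the previous behavior (that is what the `bias=False` parametrization exercises). Below is a minimal, self-contained sketch, not taken from the repository; the depthwise-style shapes, all-ones mask values, and the output-shaped bias are illustrative assumptions modeled on the test code, not a documented contract.

```python
# Hedged sketch of the bias-enabled call added in this change.
# Shapes are illustrative: groups == channels (depthwise-style), and the bias
# is shaped like the output, mirroring the bias tensors used in the tests above.
import numpy as np

import megengine.functional as F
from megengine import tensor

N, C, IH, IW, FH, FW, GROUP = 1, 2, 4, 4, 2, 2, 2
OH, OW = IH - FH + 1, IW - FW + 1

src = tensor(np.ones((N, C, IH, IW), dtype=np.float32))
# grouped-conv weight layout: (groups, OC // groups, IC // groups, FH, FW)
weight = tensor(np.ones((GROUP, 1, 1, FH, FW), dtype=np.float32))
rin = tensor(np.ones((N, IH, IW), dtype=np.int32))   # input region mask
rout = tensor(np.ones((N, OH, OW), dtype=np.int32))  # output region mask
bias = tensor(np.ones((N, C, OH, OW), dtype=np.float32))

out = F.region_restricted_conv(src, weight, rin, rout, bias=bias, groups=GROUP)
print(out.shape)  # expected (N, C, OH, OW)
```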
+ 2
- 2
src/core/include/megbrain/version.h View File

@@ -3,8 +3,8 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"


#define MGE_MAJOR 1
#define MGE_MINOR 9999
#define MGE_PATCH 0
#define MGE_MINOR 11
#define MGE_PATCH 1


// for rc version, could be like "rc1", "rc2", etc
#define MGE_EXTRA_NAME ""


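The macro bump above corresponds to release 1.11.1. At the Python level the same value surfaces as the package version string, which can be checked with a one-liner (assuming a wheel built from this revision):

```python
import megengine

# MGE_MAJOR.MGE_MINOR.MGE_PATCH, plus MGE_EXTRA_NAME when it is non-empty.
print(megengine.__version__)  # expected "1.11.1" for a build of this revision
```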
+ 27
- 11
src/gopt/impl/folding_conv_typecvt.cpp View File

@@ -76,7 +76,7 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
if (conv_bias == nullptr)
return false;
auto inp_dtype_conv = conv_bias->input(0)->dtype(),
out_dtype_conv = conv_bias->input(0)->dtype();
out_dtype_conv = conv_bias->output(0)->dtype();
bool is_s8nhwc =
inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
@@ -86,7 +86,11 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) &&
out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
conv_bias->param().format == megdnn::param::ConvBias::Format::NHWC;
if (!(is_s8nhwc || is_s4nhwc))
bool is_s8nchw =
inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
conv_bias->param().format == megdnn::param::ConvBias::Format::NCHW;
if (!(is_s8nhwc || is_s4nhwc || is_s8nchw))
return false;
if (conv_bias->input().size() != 3)
return false;
@@ -107,15 +111,27 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
auto new_bias = (out_dtype_typecvt.enumv() == DTypeEnum::Float32)
? opr::TypeCvt::make(bias, dtype::Float32()).node()
: bias;
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NHWC;
auto conv_bias_typecvt = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
OperatorNodeConfig{out_dtype_typecvt});
rewriter.replace_var(
opr->output(0), conv_bias_typecvt.node(),
mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
"to conv_bias(NHWC)"));
if (is_s8nchw && is_s82s4) {
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NCHW;
auto conv_bias_typecvt = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
OperatorNodeConfig{out_dtype_typecvt});
rewriter.replace_var(
opr->output(0), conv_bias_typecvt.node(),
mgb_cstr_log("replace conv_bias(NCHW) + typecvt "
"to conv_bias(NCHW)"));
} else {
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NHWC;
auto conv_bias_typecvt = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
OperatorNodeConfig{out_dtype_typecvt});
rewriter.replace_var(
opr->output(0), conv_bias_typecvt.node(),
mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
"to conv_bias(NHWC)"));
}
return true;
};




+ 3
- 0
src/gopt/impl/framework.cpp View File

@@ -823,6 +823,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
add_pass<FuseConvBiasNonlinPass>();
if (options.target == Target::CUDA)
add_pass<FuseConvBiasZPass>();
#if CUDA_VERSION >= 10020
add_pass<FoldingConvBiasTypecvtPass>();
#endif
add_pass(LayoutTransformPass::make(options.target));
add_pass<ShuffleShuffleRemovePass>();
if (options.target == Target::CUDA) {


+ 1
- 1
src/opr/impl/nn_int.sereg.h View File

@@ -72,7 +72,7 @@ struct OprLoadDumpImplV2<opr::ElemwiseMultiType, 0> {


namespace opr {
MGB_SEREG_OPR_CONDITION(ElemwiseMultiType, 0, false);
MGB_SEREG_OPR_V2(
MGB_SEREG_OPR_V2_HASH_WITHOUT_TAIL_0(
ElemwiseMultiType, 0,
(mgb::serialization::OprLoadDumpImplV2<opr::ElemwiseMultiType, 0>::replace_opr),
VERSION_1, VERSION_1);


+ 6
- 5
src/serialization/impl/opr_registry.cpp View File

@@ -64,8 +64,8 @@ const OprRegistryV2* dynamic_registry_v2() {


auto id = MGB_HASH_STR("dynamic");
OprRegistryV2::versioned_add(
{nullptr, id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION,
CURRENT_VERSION);
{nullptr, id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION, CURRENT_VERSION,
true);
ret = OprRegistryV2::versioned_find_by_id(id, CURRENT_VERSION);
mgb_assert(ret);
return ret;
@@ -182,7 +182,8 @@ const OprRegistryV2* OprRegistryV2::versioned_find_by_typeinfo(
}


void OprRegistryV2::versioned_add(
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version) {
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version,
bool dynamic) {
mgb_assert(max_version >= min_version);


auto&& sd = static_data();
@@ -190,7 +191,7 @@ void OprRegistryV2::versioned_add(
uint64_t type_id = id;
//! record.type->name is nullptr when MGB_VERBOSE_TYPEINFO_NAME==0
#if MGB_VERBOSE_TYPEINFO_NAME
if (record.type && record.type->name) {
if (dynamic && record.type && record.type->name) {
type_id = MGB_HASH_RUNTIME(std::string(record.type->name));
}
#endif
@@ -236,7 +237,7 @@ void OprRegistry::add_using_dynamic_loader(
OprRegistryV2::versioned_add(
{type, dynamic_registry_v2()->type_id, type->name, dumper,
dynamic_registry_v2()->loader, nullptr},
CURRENT_VERSION, CURRENT_VERSION);
CURRENT_VERSION, CURRENT_VERSION, true);
}


#if MGB_ENABLE_DEBUG_UTIL


+ 2
- 1
src/serialization/include/megbrain/serialization/opr_registry.h View File

@@ -111,7 +111,8 @@ struct OprRegistryV2 {


//! register opr load/dump to version2regmap
MGE_WIN_DECLSPEC_FUC static void versioned_add(
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version);
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version,
bool dynamic = false);


MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_id(
const size_t id, uint8_t version);


+ 33
- 0
src/serialization/include/megbrain/serialization/sereg.h View File

@@ -180,6 +180,18 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
_version_min, _version_max); \
} while (0)


//! for compatibility with MGB_SEREG_OPR_INTL_CALL_ADD, this macro uses the
//! same hash as MGB_SEREG_OPR_INTL_CALL_ADD; note that
//! MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION differs from MGB_HASH_STR
#define MGB_SEREG_OPR_INTL_CALL_ADD_V2_WITHOUT_TAIL_0_AND_VERSION_HASH( \
_cls, _dump, _load, _convert, _version_min, _version_max) \
do { \
::mgb::serialization::OprRegistryV2::versioned_add( \
{_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, _convert}, \
_version_min, _version_max); \
} while (0)

/*!
* \brief register opr serialization methods
*/
@@ -223,6 +235,27 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls)


//! use the MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION macro to compute the type id
#define MGB_SEREG_OPR_V2_HASH_WITHOUT_TAIL_0( \
_cls, _arity, _converter, _version_min, _version_max) \
namespace { \
namespace ser = ::mgb::serialization; \
struct _OprRegV2##_cls { \
using Impl = ser::OprLoadDumpImplV2<_cls, _arity>; \
static ser::OprWithOutputAccessor wrap_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \
} \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD_V2_WITHOUT_TAIL_0_AND_VERSION_HASH( \
_cls, Impl::dump, wrap_loader, _converter, _version_min, \
_version_max); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls)

//! use to check type is complete or not, midout need a complete type
template <class T, class = void>
struct IsComplete : std::false_type {};


+ 1
- 1
third_party/cutlass

@@ -1 +1 @@
Subproject commit ebb7404f428b572beeb39fd048429b5ad3a9c2eb
Subproject commit 00273750862913a4289429394c449a6a25eadb0c
