| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| | cc6084b4d5 | revert: chore(mge/misc): api converage. This reverts commit -- | 2 years ago |
| | 2a4295fbf6 | feat(ci): update submodules | 2 years ago |
| | 75c413e3bb | feat(imperative): region restrictd conv support bias in python. GitOrigin-RevId: | 2 years ago |
| | 6f9f25a882 | fix(gopt): fix global layout transform fold conv typecvt. GitOrigin-RevId: | 2 years ago |
| | cc9f743ed5 | fix(serilization): fix elemwise multitype compatibilty in v1.11.0. GitOrigin-RevId: | 2 years ago |
| | 7b3f9e85bd | chore(release): bump version | 2 years ago |
@@ -110,7 +110,7 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
     bool is_version_ok = CUDNN_VERSION >= 7500;
     bool is_dtype_ok =
             (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
-             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 ||
+             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 &&
              args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm));
     bool is_bias_ok =
             args.bias_layout->ndim == 0 ||

@@ -19,11 +19,11 @@ logger = get_logger(__name__)
 backwarding_grad_manager = None


-def _get_backwarding_grad_manager():
+def get_backwarding_grad_manager():
     return backwarding_grad_manager


-class _AttachSpec:
+class AttachSpec:
     __slots__ = "tensor", "callbacks"

@@ -118,7 +118,7 @@ class GradManager:
     """

     def __init__(self):
-        self._attach_specs = {}  # id(Tensor) -> _AttachSpec
+        self._attach_specs = {}  # id(Tensor) -> AttachSpec
         self._recording = False
         self._grad = None
         self._after_backward_callback = []

@@ -200,7 +200,7 @@ class GradManager:
             if self is not None:
                 del self._attach_specs[key]

-        spec = _AttachSpec()
+        spec = AttachSpec()
         spec.tensor = weakref.ref(tensor, deleter)
         spec.callbacks = []
         return spec

@@ -354,22 +354,22 @@ class GradManager:
     def __or__(self, other):
         if isinstance(other, GradManager):
-            return _GradManagerGroup([self, other])
+            return GradManagerGroup([self, other])
         return NotImplemented

     __ror__ = __or__


-class _GradManagerGroup:
+class GradManagerGroup:
     def __init__(self, gms) -> None:
         self._gms = list(gms)

     def merge_with(self, other):
         if isinstance(other, GradManager):
-            other = _GradManagerGroup([other])
-        elif not isinstance(other, _GradManagerGroup):
+            other = GradManagerGroup([other])
+        elif not isinstance(other, GradManagerGroup):
             return NotImplemented
-        return _GradManagerGroup([*self._gms, *other._gms])
+        return GradManagerGroup([*self._gms, *other._gms])

     __or__ = merge_with
     __ror__ = merge_with
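As a usage note, a minimal sketch of the `|` combinator whose helper class is made public in the hunk above; the three manager instances are illustrative, and only the `__or__`/`merge_with` behaviour comes from the diff.

```python
from megengine.autodiff import GradManager

gm_a, gm_b, gm_c = GradManager(), GradManager(), GradManager()

group = gm_a | gm_b   # GradManager.__or__ returns GradManagerGroup([gm_a, gm_b])
group = group | gm_c  # GradManagerGroup.merge_with yields a group of all three managers
```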
@@ -34,7 +34,7 @@ logger = get_logger(__name__)
 GLOBAL_TIMEOUT = 5


-def _raise_timeout_error():
+def raise_timeout_error():
     raise RuntimeError("dataloader timeout")

@@ -191,7 +191,7 @@ class DataLoader:
         )


-class _PreLoader:
+class PreLoader:
     def __init__(self, loader, preload):
         self.dataset = loader.dataset
         self.sampler = loader.sampler

@@ -319,7 +319,7 @@ class _ParallelDataLoaderIter:
             if success:
                 return data
             else:
-                _raise_timeout_error()
+                raise_timeout_error()
         else:
             while True:
                 success, data = self._try_get_data()

@@ -417,7 +417,7 @@ class _ParallelDataLoaderIter:
         self._shutdown_workers()


-class _BaseMapDataLoaderIter(_PreLoader):
+class _BaseMapDataLoaderIter(PreLoader):
     def __init__(self, loader, preload):
         super().__init__(loader, preload)

@@ -510,7 +510,7 @@ def get_worker_info():
     return _worker_info


-class _BaseStreamDataLoaderIter(_PreLoader):
+class _BaseStreamDataLoaderIter(PreLoader):
     def __init__(self, loader, preload):
         super().__init__(loader, preload)
         self.dataset_iter = iter(self.dataset)

@@ -552,7 +552,7 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             timer.cancel()
         waited_time = time.time() - start_time
         if waited_time > self.timeout:
-            _raise_timeout_error()
+            raise_timeout_error()
         return raw_data

     def _get_next_batch(self):

@@ -583,7 +583,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter, _ParallelDataLoad
             place_holder = [next(self.dataset_iter)]
             waited_time = time.time() - start_time
             if self.timeout > 0 and waited_time > self.timeout:
-                _raise_timeout_error()
+                raise_timeout_error()
             place_holder = self._get_remaind_data(place_holder)
         else:
             place_holder = next(self._sampler_iter)

@@ -21,7 +21,7 @@ def _count_visible_keypoints(anno):
     return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)


-def _has_valid_annotation(anno, order):
+def has_valid_annotation(anno, order):
     # if it"s empty, there is no annotation
     if len(anno) == 0:
         return False

@@ -101,7 +101,7 @@ class COCO(VisionDataset):
                 anno = [
                     obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0
                 ]
-                if _has_valid_annotation(anno, order):
+                if has_valid_annotation(anno, order):
                     ids.append(img_id)
                     self.img_to_anns[img_id] = anno
                 else:
@@ -140,17 +140,17 @@ class MNIST(VisionDataset):
         # load raw files and transform them into meta data and datasets Tuple(np.array)
         logger.info("process the raw files of %s set...", "train" if train else "test")
         if train:
-            meta_data_images, images = _parse_idx3(
+            meta_data_images, images = parse_idx3(
                 os.path.join(self.root, self.raw_file_name[0])
             )
-            meta_data_labels, labels = _parse_idx1(
+            meta_data_labels, labels = parse_idx1(
                 os.path.join(self.root, self.raw_file_name[1])
             )
         else:
-            meta_data_images, images = _parse_idx3(
+            meta_data_images, images = parse_idx3(
                 os.path.join(self.root, self.raw_file_name[2])
             )
-            meta_data_labels, labels = _parse_idx1(
+            meta_data_labels, labels = parse_idx1(
                 os.path.join(self.root, self.raw_file_name[3])
             )

@@ -161,7 +161,7 @@ class MNIST(VisionDataset):
         self.arrays = (images, labels.astype(np.int32))


-def _parse_idx3(idx3_file):
+def parse_idx3(idx3_file):
     # parse idx3 file to meta data and data in numpy array (images)
     logger.debug("parse idx3 file %s ...", idx3_file)
     assert idx3_file.endswith(".gz")

@@ -187,7 +187,7 @@ def _parse_idx3(idx3_file):
     return meta_data, images


-def _parse_idx1(idx1_file):
+def parse_idx1(idx1_file):
     # parse idx1 file to meta data and data in numpy array (labels)
     logger.debug("parse idx1 file %s ...", idx1_file)
     assert idx1_file.endswith(".gz")

@@ -7,7 +7,7 @@ import cv2
 import numpy as np

 from megengine.data.transform import Transform
-from megengine.data.transform.vision import _functional as F
+from megengine.data.transform.vision import functional as F

 __all__ = [
     "VisionTransform",
@@ -2,6 +2,7 @@
 from mprop import mproperty

 from ..core._imperative_rt.core2 import group_end, group_start
+from . import group
 from .group import (
     WORLD,
     Group,

@@ -19,7 +20,7 @@ from .group import (
 )
 from .helper import bcast_list_, make_allreduce_cb, synchronized
 from .launcher import launcher
-from .server import Server
+from .server import Client, Server


 @mproperty

@@ -7,10 +7,10 @@ from mprop import mproperty
 from ..device import _sh, set_default_device, what_is_xpu
 from ..random import seed
-from .server import Server, _Client
+from .server import Client, Server


-class _StaticData:
+class StaticData:
     server = None
     client = None
     master_ip = None

@@ -139,13 +139,13 @@ def init_process_group(
     global _sd
     assert _sd is None, "init_process_group should be called only once"
-    _sd = _StaticData()
+    _sd = StaticData()

     assert world_size > 1
     assert rank >= 0 and rank < world_size
     assert port > 0

-    _sd.client = _Client(master_ip, port)
+    _sd.client = Client(master_ip, port)
     _sd.master_ip = master_ip
     _sd.py_server_port = port
     _sd.mm_server_port = _sd.client.get_mm_server_port()

@@ -225,7 +225,7 @@ def get_mm_server_addr() -> Tuple[str, int]:
     return _sd.master_ip, _sd.mm_server_port


-def get_client() -> _Client:
+def get_client() -> Client:
     r"""Get client of python XML RPC server."""
     assert _sd is not None, "please call init_process_group first"
     return _sd.client
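A hypothetical sketch of how the now-public `Client` is reached in practice; it assumes `init_process_group` has already been called in the current process, as the assertion in `get_client` above requires.

```python
import megengine.distributed as dist

# Only valid after dist.init_process_group(...) has run in this process.
client = dist.get_client()        # the Client(master_ip, port) stored in _sd above
addr = dist.get_mm_server_addr()  # (master_ip, mm_server_port) from the same state
```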
@@ -7,7 +7,7 @@ from weakref import WeakSet
 import numpy as np

-from megengine.autodiff.grad_manager import GradManager, _get_backwarding_grad_manager
+from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager

 from ..core._imperative_rt.core2 import apply
 from ..core.ops.builtin import ParamPackConcat, ParamPackSplit

@@ -78,7 +78,7 @@ def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
     return apply(op, *inps, offsets)[0]


-def _get_offsets(shapes):
+def get_offsets(shapes):
     offsets = []
     offset = 0
     for shape in shapes:

@@ -108,7 +108,7 @@ def _check_enable_p2p():
 def pack_allreduce_split(pack_list, shapes, group, reduce_method):
-    offsets_val = _get_offsets(shapes)
+    offsets_val = get_offsets(shapes)
     offsets = Tensor(offsets_val)
     packed_grads = param_pack_concat(pack_list, offsets, offsets_val)

@@ -119,7 +119,7 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method):
     return grads


-class _TensorFuture(Future):
+class TensorFuture(Future):
     def device(self):
         raise "Sorry, this tensor is not ready"

@@ -234,13 +234,13 @@ class AllreduceCallback:
             self._packing_size[dtype] = 0

     def __call__(self, param, grad):
-        gm = _get_backwarding_grad_manager()
+        gm = get_backwarding_grad_manager()
         assert isinstance(gm, GradManager)
         if gm not in self._marked_gm:
             gm._register_after_backward_callback(self._flush)
             self._marked_gm.add(gm)
         self._params.append(param)
-        self._futures_dict[param] = _TensorFuture(ack=False)
+        self._futures_dict[param] = TensorFuture(ack=False)
         self._gradients_dict[param] = grad
         self._grad_origin_device[param] = str(grad.device)

@@ -10,7 +10,7 @@ from ..device import get_device_count
 from ..logger import get_logger
 from .group import _set_machine_ranks, group_barrier, init_process_group
 from .helper import _check_device_initialized, _check_interpreter_status
-from .server import Server
+from .server import Client, Server

 WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
     "subprocess exited with code 0 but did not return a value"

@@ -12,7 +12,7 @@ from ..core._imperative_rt.utils import create_mm_server
 from ..utils.future import Future


-class _Methods:
+class Methods:
     r"""Distributed Server Method.
     Used for exchange information between distributed nodes.

@@ -149,7 +149,7 @@ class _Methods:
         return ret


-class _ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
+class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
     pass

@@ -163,10 +163,10 @@ def _start_server(py_server_port, queue):
     """
     try:
         mm_server_port = create_mm_server("0.0.0.0", 0)
-        server = _ThreadXMLRPCServer(
+        server = ThreadXMLRPCServer(
             ("0.0.0.0", py_server_port), logRequests=False, allow_none=True
         )
-        server.register_instance(_Methods(mm_server_port))
+        server.register_instance(Methods(mm_server_port))
         _, py_server_port = server.server_address
         queue.put((py_server_port, mm_server_port))
         server.serve_forever()

@@ -196,7 +196,7 @@ class Server:
         self.proc.terminate()


-class _Client:
+class Client:
     r"""Distributed Client for distributed training.
     Args:

@@ -298,10 +298,10 @@ class _Client:
         return self.proxy.bcast_val(val, key, size)


-def _main(port=0, verbose=True):
+def main(port=0, verbose=True):
     mm_server_port = create_mm_server("0.0.0.0", 0)
-    server = _ThreadXMLRPCServer(("0.0.0.0", port), logRequests=verbose)
-    server.register_instance(_Methods(mm_server_port))
+    server = ThreadXMLRPCServer(("0.0.0.0", port), logRequests=verbose)
+    server.register_instance(Methods(mm_server_port))
     _, port = server.server_address
     print("serving on port", port)
     server.serve_forever()

@@ -314,4 +314,4 @@ if __name__ == "__main__":
     ap.add_argument("-p", "--port", type=int, default=0)
     ap.add_argument("-v", "--verbose", type=bool, default=True)
     args = ap.parse_args()
-    _main(port=args.port, verbose=args.verbose)
+    main(port=args.port, verbose=args.verbose)
@@ -1,11 +1,13 @@
 # -*- coding: utf-8 -*-
 # pylint: disable=redefined-builtin
-from . import metric, utils, vision
 from .elemwise import *
 from .math import *
 from .nn import *
 from .tensor import *
+from .utils import *

-from . import distributed  # isort:skip
+from . import utils, vision, distributed  # isort:skip

 # delete namespace
 # pylint: disable=undefined-variable
@@ -21,10 +21,6 @@ _valid_string_option = {
 }


-@deprecated(
-    version="1.10",
-    reason="use ``megengine.config.benchmark_kernel`` and ``megengine.config.deterministic_kernel`` instead",
-)
 def get_execution_strategy() -> Strategy:
     r"""Returns the execution strategy of :class:`~.module.Conv2d` and :func:`~.matmul`

@@ -40,10 +36,6 @@ def get_execution_strategy() -> Strategy:
     return strategy


-@deprecated(
-    version="1.10",
-    reason="use ``megengine.config.benchmark_kernel`` and ``megengine.config.deterministic_kernel`` instead",
-)
 def set_execution_strategy(option):
     r"""Sets the execution strategy of :class:`~.module.Conv2d` and :func:`~.matmul`
@@ -9,7 +9,7 @@ from ..core.tensor.array_method import _elwise
 from ..core.tensor.utils import convert_inputs
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
-from ._tensor_cache import get_scalar_one
+from .tensor_cache import get_scalar_one

 __all__ = [
     "abs",
@@ -43,7 +43,6 @@ from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import _elwise, exp, log, log1p, maximum, minimum
 from .math import max, normalize, sum
-from .metric import topk_accuracy
 from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros

 __all__ = [

@@ -87,7 +86,6 @@ __all__ = [
     "softmax",
     "softplus",
     "sync_batch_norm",
-    "topk_accuracy",
     "warp_affine",
     "warp_perspective",
     "pixel_shuffle",

@@ -95,7 +93,7 @@
 ]


-def _expand_hw(x):
+def expand_hw(x):
     # judge int is 5 times faster than judge Sequence
     if isinstance(x, int):
         return x, x

@@ -104,7 +102,7 @@ def _expand_hw(x):
     return int(x), int(x)


-def _expand_dhw(x):
+def expand_dhw(x):
     if isinstance(x, int):
         return x, x, x
     if isinstance(x, Sequence):
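For orientation, a simplified re-implementation of the renamed helper; only the int fast path and the final cast are visible in the hunks above, so the `Sequence` branch here is an assumption modelled on `expand_dhw`.

```python
from typing import Sequence, Tuple, Union

def expand_hw(x: Union[int, Sequence[int]]) -> Tuple[int, int]:
    # int is checked first: "judge int is 5 times faster than judge Sequence"
    if isinstance(x, int):
        return x, x
    if isinstance(x, Sequence):  # assumed branch, mirroring expand_dhw above
        return int(x[0]), int(x[1])
    return int(x), int(x)
```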
@@ -246,9 +244,9 @@ def conv2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    dilate_h, dilate_w = expand_hw(dilation)
     sparse_type = "dense" if groups == 1 else "group"
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)

@@ -308,9 +306,9 @@ def conv3d(
     D, H, W = 0, 1, 2

-    pad = _expand_dhw(padding)
-    stride = _expand_dhw(stride)
-    dilate = _expand_dhw(dilation)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3D(

@@ -378,10 +376,10 @@ def conv_transpose2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    output_pad_h, output_pad_w = _expand_hw(output_padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    output_pad_h, output_pad_w = expand_hw(output_padding)
+    dilate_h, dilate_w = expand_hw(dilation)
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     sparse_type = "dense" if groups == 1 else "group"

@@ -479,9 +477,9 @@ def deformable_conv2d(
         offset = offset.astype("float32")
         mask = mask.astype("float32")
-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    dilate_h, dilate_w = expand_hw(dilation)
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     sparse_type = "dense" if groups == 1 else "group"

@@ -533,9 +531,9 @@ def local_conv2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    dilate_h, dilate_w = expand_hw(dilation)
     # local conv only support "dense" mode, but weight could contain group dimension.
     op = builtin.GroupLocal(

@@ -589,10 +587,10 @@ def conv_transpose3d(
         output tensor.
     """
     D, H, W = 0, 1, 2
-    pad = _expand_dhw(padding)
-    stride = _expand_dhw(stride)
-    dilate = _expand_dhw(dilation)
-    output_padding = _expand_dhw(output_padding)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
+    output_padding = expand_dhw(output_padding)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3DBackwardData(
@@ -677,9 +675,9 @@ def max_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _expand_hw(kernel_size)
-    stride_h, stride_w = _expand_hw(stride)
-    padding_h, padding_w = _expand_hw(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)

     op = builtin.Pooling(
         window_h=window_h,

@@ -727,9 +725,9 @@ def avg_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _expand_hw(kernel_size)
-    stride_h, stride_w = _expand_hw(stride)
-    padding_h, padding_w = _expand_hw(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)

     op = builtin.Pooling(
         window_h=window_h,

@@ -1725,10 +1723,10 @@ def sliding_window(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    padding_h, padding_w = _expand_hw(padding)
-    stride_h, stride_w = _expand_hw(stride)
-    dilation_h, dilation_w = _expand_hw(dilation)
-    window_h, window_w = _expand_hw(kernel_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)

     op = builtin.Images2Neibs(
         pad_h=padding_h,

@@ -1764,11 +1762,11 @@ def sliding_window_transpose(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    output_h, output_w = _expand_hw(output_size)
-    padding_h, padding_w = _expand_hw(padding)
-    stride_h, stride_w = _expand_hw(stride)
-    dilation_h, dilation_w = _expand_hw(dilation)
-    window_h, window_w = _expand_hw(kernel_size)
+    output_h, output_w = expand_hw(output_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)

     expected_h = (
         output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1

@@ -1921,7 +1919,7 @@ def _get_layerPixelShuffle(device, dtype, dim_order):
     return layerPixelShuffle


-def _layerPixelShuffle_traceable(inp, upscale_factor):
+def layerPixelShuffle_traceable(inp, upscale_factor):
     assert upscale_factor > 0, "upscale_factor should larger than 0"
     assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3"
     assert (

@@ -1972,7 +1970,7 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
     :param upscale_factor: upscale factor of pixel_shuffle.
     :return: output tensor.
     """
-    return pixel_shuffle_cpp(inp, upscale_factor, _layerPixelShuffle_traceable)
+    return pixel_shuffle_cpp(inp, upscale_factor, layerPixelShuffle_traceable)


 def region_restricted_conv(
@@ -1980,6 +1978,7 @@ def region_restricted_conv(
     weight: Tensor,
     rin: Tensor,
     rout: Tensor,
+    bias: Optional[Tensor] = None,
     stride: Union[int, Tuple[int, int, int]] = 1,
     padding: Union[int, Tuple[int, int, int]] = 0,
     dilation: Union[int, Tuple[int, int, int]] = 1,

@@ -1994,6 +1993,9 @@
     Args:
         inp: feature map of the convolution operation.
         weight: convolution kernel.
+        rin: input mask
+        rout: output mask
+        bias: bias added to the result of convolution (if given).
         stride: stride of the 2D region restricted convolution operation. Default: 1
         padding: size of the paddings added to the input on both sides of its
             spatial dimensions. Only zero-padding is supported. Default: 0

@@ -2010,9 +2012,9 @@
     """
     assert conv_mode.lower() == "cross_correlation"

-    pad_h, pad_w = _expand_hw(padding)
-    stride_h, stride_w = _expand_hw(stride)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    pad_h, pad_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilate_h, dilate_w = expand_hw(dilation)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.RegionRestrictedConvolution(

@@ -2027,9 +2029,12 @@
         sparse=sparse_type,
     )
     (output,) = apply(op, inp, weight, rin, rout)
+    if bias is not None:
+        output += bias
     return output


 from .quantized import conv_bias_activation  # isort:skip
 from .loss import *  # isort:skip
+from .metric import *  # isort:skip
 from .vision import *  # isort:skip
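A minimal usage sketch of the extended functional API; the mask and bias shapes are illustrative assumptions loosely modelled on the tests further down, while the `bias=` and `groups=` keyword arguments come from this diff.

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

N, GROUPS, H, W = 1, 2, 2, 2
src = tensor(np.arange(N * GROUPS * H * W, dtype=np.float32).reshape(N, GROUPS, H, W))
weight = tensor(np.ones((GROUPS, 1, 1, H, W), dtype=np.float32))  # group-conv kernel
rin = tensor(np.ones((N, H, W), dtype=np.int32))                  # input region mask (assumed shape)
rout = tensor(np.ones((N, 1, 1), dtype=np.int32))                 # output region mask (assumed shape)
bias = tensor(np.ones((1, GROUPS, 1, 1), dtype=np.float32))       # broadcast-added after the conv

out = F.region_restricted_conv(src, weight, rin, rout, bias=bias, groups=GROUPS)
```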
@@ -1,12 +1,12 @@
 from ..core._imperative_rt.core2 import Const
-from ..jit.tracing import _is_tracing
+from ..jit.tracing import is_tracing

 small_tensor_cache = {}


 def _get_scalar_tensor_with_value(value, dtype=None, device=None):
     global small_tensor_cache
-    if _is_tracing():
+    if is_tracing():
         ret = Const(value, dtype, device)
     else:
         cache_key = (value, dtype, device)
@@ -7,6 +7,8 @@ from ..utils.deprecation import deprecated_func
 from .elemwise import abs, maximum, minimum
 from .tensor import ones, zeros

+__all__ = ["topk_accuracy"]
+

 def _assert_equal(
     expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False

@@ -36,7 +36,7 @@ pattern = re.compile(
 )


-class _RepoFetcherBase:
+class RepoFetcherBase:
     @classmethod
     def fetch(
         cls,

@@ -84,7 +84,7 @@ class _RepoFetcherBase:
         return hashlib.sha1(repo_dir.encode()).hexdigest()[:16]


-class GitSSHFetcher(_RepoFetcherBase):
+class GitSSHFetcher(RepoFetcherBase):
     @classmethod
     @synchronized
     def fetch(

@@ -193,7 +193,7 @@ class GitSSHFetcher(_RepoFetcherBase):
         )


-class GitHTTPSFetcher(_RepoFetcherBase):
+class GitHTTPSFetcher(RepoFetcherBase):
     @classmethod
     @synchronized
     def fetch(
@@ -49,7 +49,7 @@ active_trace = None
 skip_tracing = False


-def _is_tracing():
+def is_tracing():
     if active_trace is None:
         return False
     else:

@@ -73,7 +73,7 @@ def exclude_from_trace():
         skip_tracing = False


-def _array_comparator(lhs, rhs):
+def array_comparator(lhs, rhs):
     return np.all(lhs == rhs)

@@ -184,7 +184,7 @@ class trace:
         self._trace.no_exec = record_only
         self._trace.options_visitor = apply_options
         self._trace.profile = profiling
-        self._trace.array_comparator = _array_comparator
+        self._trace.array_comparator = array_comparator
         self._trace.record_input_shapes = _input_node_use_static_shape()

     def __call__(self, *args, **kwargs):
@@ -18,10 +18,10 @@ def set_log_file(fout, mode="a"):
     """
     if isinstance(fout, str):
         fout = open(fout, mode)
-    _MegEngineLogFormatter.log_fout = fout
+    MegEngineLogFormatter.log_fout = fout


-class _MegEngineLogFormatter(logging.Formatter):
+class MegEngineLogFormatter(logging.Formatter):
     log_fout = None
     date_full = "[%(asctime)s %(lineno)d@%(filename)s:%(name)s] "
     date = "%(asctime)s "

@@ -71,7 +71,7 @@ class _MegEngineLogFormatter(logging.Formatter):
         if self.log_fout:
             self.__set_fmt(self.date_full + mtxt + self.msg)
-            formatted = super(_MegEngineLogFormatter, self).format(record)
+            formatted = super(MegEngineLogFormatter, self).format(record)
             nr_line = formatted.count("\n") + 1
             if nr_line >= self.max_lines:
                 head, body = formatted.split("\n", 1)

@@ -88,7 +88,7 @@ class _MegEngineLogFormatter(logging.Formatter):
             self.log_fout.flush()

         self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
-        formatted = super(_MegEngineLogFormatter, self).format(record)
+        formatted = super(MegEngineLogFormatter, self).format(record)

         if record.exc_text or record.exc_info:
             # handle exception format

@@ -125,7 +125,7 @@ class _MegEngineLogFormatter(logging.Formatter):
             self._style._fmt = fmt


-def get_logger(name=None, formatter=_MegEngineLogFormatter):
+def get_logger(name=None, formatter=MegEngineLogFormatter):
     r"""Gets megengine logger with given name."""
     logger = logging.getLogger(name)

@@ -167,16 +167,16 @@ try:
     from .core._imperative_rt.utils import Logger as _imperative_rt_logger

-    class _MegBrainLogFormatter(_MegEngineLogFormatter):
+    class MegBrainLogFormatter(MegEngineLogFormatter):
         date = "%(asctime)s[mgb] "

         def _color_date(self, msg):
             return "\x1b[33m{}\x1b[0m".format(msg)

-    _megbrain_logger = get_logger("megbrain", _MegBrainLogFormatter)
+    _megbrain_logger = get_logger("megbrain", MegBrainLogFormatter)
     _imperative_rt_logger.set_log_handler(_megbrain_logger)

-    def _set_mgb_log_level(level):
+    def set_mgb_log_level(level):
         r"""Sets megbrain log level

         Args:

@@ -200,30 +200,30 @@ try:
             )
         return rst

-    _set_mgb_log_level(_default_level)
+    set_mgb_log_level(_default_level)

 except ImportError as exc:

-    def _set_mgb_log_level(level):
+    def set_mgb_log_level(level):
         raise NotImplementedError("imperative_rt has not been imported")


 @contextlib.contextmanager
-def _replace_mgb_log_level(level):
+def replace_mgb_log_level(level):
     r"""Replaces megbrain log level in a block and restore after exiting.

     Args:
         level: new log level
     """
-    old = _set_mgb_log_level(level)
+    old = set_mgb_log_level(level)
     try:
         yield
     finally:
-        _set_mgb_log_level(old)
+        set_mgb_log_level(old)


 def enable_debug_log():
     r"""Sets logging level to debug for all components."""
     set_log_level(logging.DEBUG)
-    _set_mgb_log_level(logging.DEBUG)
+    set_mgb_log_level(logging.DEBUG)
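A minimal sketch of the now-public log-level helpers from the hunks above; the particular levels are arbitrary choices for illustration.

```python
import logging
from megengine.logger import replace_mgb_log_level, set_mgb_log_level

old = set_mgb_log_level(logging.WARNING)    # returns the previous MegBrain level
with replace_mgb_log_level(logging.ERROR):  # temporarily raise it inside a block
    pass                                    # noisy C++-side logging is suppressed here
set_mgb_log_level(old)                      # restore the original level
```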
@@ -1040,6 +1040,7 @@ class RegionRestrictedConv(_ConvNd):
             ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
             and the shape of weight should be ``(groups, out_channel // groups,
             in_channels // groups, height, width)``. Default: 1
+        bias: whether to add a bias onto the result of convolution. Default: True
         conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
         compute_mode: When set to "default", no special requirements will be
             placed on the precision of intermediate results. When set to "float32",

@@ -1071,6 +1072,7 @@ class RegionRestrictedConv(_ConvNd):
         out_channels: int,
         kernel_size: Union[int, Tuple[int, int]],
         groups: int,
+        bias: bool = True,
         stride: Union[int, Tuple[int, int]] = 1,
         padding: Union[int, Tuple[int, int]] = 0,
         dilation: Union[int, Tuple[int, int]] = 1,

@@ -1095,7 +1097,7 @@ class RegionRestrictedConv(_ConvNd):
             0,
             dilation,
             groups,
-            False,
+            bias,
             **kwargs,
         )

@@ -1133,7 +1135,7 @@ class RegionRestrictedConv(_ConvNd):
             (self.padding[1], self.padding[1]),
         )

-    def calc_conv(self, inp, weight, rin, rout):
+    def calc_conv(self, inp, weight, rin, rout, bias):
         assert self.padding_mode in [
             "zeros",
             "reflect",

@@ -1144,6 +1146,7 @@ class RegionRestrictedConv(_ConvNd):
             weight,
             rin,
             rout,
+            bias,
             self.stride,
             self.padding,
             self.dilation,

@@ -1153,4 +1156,4 @@ class RegionRestrictedConv(_ConvNd):
         )

     def forward(self, inp, rin, rout):
-        return self.calc_conv(inp, self.weight, rin, rout)
+        return self.calc_conv(inp, self.weight, rin, rout, self.bias)
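A hypothetical construction sketch for the module-level change; channel counts, kernel size and mask shapes are illustrative assumptions, while `bias=True` and the `forward(inp, rin, rout)` call shape come from the diff.

```python
import numpy as np
import megengine.module as M
from megengine import tensor

conv = M.RegionRestrictedConv(
    in_channels=2, out_channels=2, kernel_size=2, groups=2, bias=True
)
inp = tensor(np.ones((1, 2, 2, 2), dtype=np.float32))
rin = tensor(np.ones((1, 2, 2), dtype=np.int32))   # assumed (N, IH, IW) region mask
rout = tensor(np.ones((1, 1, 1), dtype=np.int32))  # assumed (N, OH, OW) region mask
out = conv(inp, rin, rout)  # forward now routes self.bias through calc_conv
```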
@@ -15,12 +15,12 @@ from . import init
 from .module import Module


-class _RNNCellBase(Module):
+class RNNCellBase(Module):
     def __init__(
         self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
     ) -> None:
         # num_chunks indicates the number of gates
-        super(_RNNCellBase, self).__init__()
+        super(RNNCellBase, self).__init__()

         self.input_size = input_size
         self.hidden_size = hidden_size

@@ -57,7 +57,7 @@ class _RNNCellBase(Module):
         raise NotImplementedError("forward not implemented !")


-class RNNCell(_RNNCellBase):
+class RNNCell(RNNCellBase):
     r"""An Elman RNN cell with tanh or ReLU non-linearity.

@@ -135,7 +135,7 @@ class RNNCell(_RNNCellBase):
         )[0]


-class LSTMCell(_RNNCellBase):
+class LSTMCell(RNNCellBase):
     r"""A long short-term memory (LSTM) cell.

@@ -216,7 +216,7 @@ class LSTMCell(_RNNCellBase):
         )[:2]


-class _RNNBase(Module):
+class RNNBase(Module):
     def __init__(
         self,
         input_size: int,

@@ -228,7 +228,7 @@ class _RNNBase(Module):
         bidirectional: bool = False,
         proj_size: int = 0,
     ) -> None:
-        super(_RNNBase, self).__init__()
+        super(RNNBase, self).__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.num_layers = num_layers

@@ -323,7 +323,7 @@ class _RNNBase(Module):
         return output, h


-class RNN(_RNNBase):
+class RNN(RNNBase):
     r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
     input sequence.

@@ -453,7 +453,7 @@ class RNN(_RNNBase):
         return output, h


-class LSTM(_RNNBase):
+class LSTM(RNNBase):
     r"""Applies a multi-layer long short-term memory LSTM to an input
     sequence.

@@ -7,7 +7,7 @@ import numpy as np
 from .. import functional as F
 from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
-from ..distributed import WORLD, is_distributed
+from ..distributed import WORLD, get_rank, is_distributed
 from ..functional.distributed import all_reduce_max, all_reduce_min
 from ..logger import get_logger
 from ..module import Module
@@ -66,7 +66,7 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.DEFAULT_PROTOCOL):
     pickle_module.dump(obj, f, pickle_protocol)


-class _dmap:
+class dmap:
     def __init__(self, map_location):
         self.map_location = map_location

@@ -177,5 +177,5 @@ def load(f, map_location=None, pickle_module=pickle):
     map_location = _get_callable_map_location(map_location)  # callable map_location

-    with _dmap(map_location) as dm:
+    with dmap(map_location) as dm:
         return pickle_module.load(f)
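A minimal sketch of the `load` path touched above; the filename is a placeholder, and passing a device string for `map_location` is an assumption (only the callable handling is visible in this hunk).

```python
import megengine as mge

# Remap every saved tensor onto cpu0 while unpickling; internally this goes
# through the dmap context manager shown above.
state = mge.load("checkpoint.pkl", map_location="cpu0")
```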
@@ -33,7 +33,6 @@ def deprecated_func(version, origin, name, tbd):
             )
             return func(*args, **kwargs)

-        wrapper.__deprecated__ = True
         return wrapper

@@ -58,7 +57,6 @@ def deprecated_kwargs_default(version, kwargs_name, kwargs_pos):
             )
             return func(*args, **kwargs)

-        wrapper.__deprecated__ = True
         return wrapper

     return deprecated
@@ -11,11 +11,11 @@ from .. import functional as F
 from .. import get_logger
 from .. import module as M
 from ..core.tensor.dtype import get_dtype_bit
-from ..logger import _MegEngineLogFormatter
+from ..logger import MegEngineLogFormatter
 from .module_utils import set_module_mode_safe

 try:
-    _MegEngineLogFormatter.max_lines = float("inf")
+    MegEngineLogFormatter.max_lines = float("inf")
 except AttributeError as e:
     raise ValueError("set logger max lines failed")
@@ -2,14 +2,14 @@
 import logging

 from megengine.core._imperative_rt import Logger
-from megengine.logger import _imperative_rt_logger, _set_mgb_log_level
+from megengine.logger import _imperative_rt_logger, set_mgb_log_level


 def test_logger():
     orig_level = Logger().set_log_level(Logger.LogLevel.Debug)
     assert Logger().set_log_level(Logger.LogLevel.Debug) == Logger.LogLevel.Debug
     Logger().set_log_level(orig_level)
-    orig_level = _set_mgb_log_level(logging.DEBUG)
+    orig_level = set_mgb_log_level(logging.DEBUG)
     assert (
         _imperative_rt_logger.set_log_level(Logger.LogLevel.Debug)
         == Logger.LogLevel.Debug

@@ -50,7 +50,7 @@ def test_init_process_group(backend):
         assert mm_server_addr[0] == "localhost"
         assert mm_server_addr[1] > 0
-        assert isinstance(dist.get_client(), dist.server._Client)
+        assert isinstance(dist.get_client(), dist.Client)

     procs = []
     for rank in range(world_size):
@@ -930,7 +930,8 @@ def test_batch_conv_bias():
     run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)


-def test_region_restricted_conv_forward_backward_naive():
+@pytest.mark.parametrize("bias", [True, False])
+def test_region_restricted_conv_forward_backward_naive(bias):
     import megengine as mge
     import megengine.module as M
     from megengine.autodiff import GradManager

@@ -943,15 +944,22 @@
     cpu_src = tensor(src_1, device=handle)
     cpu_filter = tensor(filter_1, device=handle)
     gm = GradManager().attach([cpu_src, cpu_filter])
+    cpu_bias = (
+        tensor(np.ones((1, 2, 1, 1), dtype=np.float32), device=handle) if bias else None
+    )
     with gm:
         cpu_out = F.region_restricted_conv(
             cpu_src,
             cpu_filter,
             tensor(rin_1, device=handle),
             tensor(rout_1, device=handle),
+            bias=cpu_bias,
             groups=2,
         )
         gm.backward(cpu_out, tensor(np.ones((1, 2, 1, 1)), device=handle))
+    if cpu_bias is not None:
+        cpu_out = cpu_out - cpu_bias
+    np.testing.assert_allclose(cpu_out, np.array([14, 126]).reshape(1, 2, 1, 1))
     np.testing.assert_allclose(
         cpu_src.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(1, 2, 2, 2)
     )
@@ -963,7 +971,8 @@
 @pytest.mark.skipif(
     not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
 )
-def test_region_restricted_conv_forward_backward_cuda():
+@pytest.mark.parametrize("bias", [True, False])
+def test_region_restricted_conv_forward_backward_cuda(bias):
     import megengine as mge
     import megengine.module as M
     from megengine.autodiff import GradManager

@@ -998,18 +1007,23 @@
         filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
         rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
         rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
+        bias_cpu = (
+            tensor(np.ones(diff_shape).astype(np.float32), device="cpu0")
+            if bias
+            else None
+        )
         gm = GradManager().attach([src, filter])
         with gm:
             expected_out = F.region_restricted_conv(
-                src, filter, rin, rout, groups=GROUP
+                src, filter, rin, rout, bias=bias_cpu, groups=GROUP
             )
             gm.backward(
                 expected_out,
                 tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
             )
-        return src, filter
+        return src, filter, expected_out

-    expected_src, expected_filter = get_groundtruth()
+    expected_src, expected_filter, expected_out = get_groundtruth()

     src = tensor(
         np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
@@ -1018,18 +1032,25 @@
     filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
     rin = tensor(np.ones(rin_shape).astype(np.int32), device=handle)
     rout = tensor(np.ones(rout_shape).astype(np.int32), device=handle)
+    bias_gpu = (
+        tensor(np.ones(diff_shape).astype(np.float32), device=handle) if bias else None
+    )
     gm = GradManager().attach([src, filter])
     with gm:
-        gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP)
+        gpu_out = F.region_restricted_conv(
+            src, filter, rin, rout, bias=bias_gpu, groups=GROUP
+        )
         gm.backward(gpu_out, tensor(np.ones(diff_shape), device=handle))
     np.testing.assert_allclose(src.grad, expected_src.grad)
     np.testing.assert_allclose(filter.grad, expected_filter.grad)
+    np.testing.assert_allclose(gpu_out, expected_out)


 @pytest.mark.skipif(
     not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
 )
-def test_region_restricted_conv_forward_backward_uint8():
+@pytest.mark.parametrize("bias", [True, False])
+def test_region_restricted_conv_forward_backward_uint8(bias):
     import megengine as mge
     import megengine.module as M
     from megengine.autodiff import GradManager
@@ -1063,18 +1084,23 @@ def test_region_restricted_conv_forward_backward_uint8(): | |||||
filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0") | filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0") | ||||
rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0") | rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0") | ||||
rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0") | rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0") | ||||
bias_cpu = ( | |||||
tensor(np.ones(diff_shape).astype(np.float32), device="cpu0") | |||||
if bias | |||||
else None | |||||
) | |||||
gm = GradManager().attach([src, filter]) | gm = GradManager().attach([src, filter]) | ||||
with gm: | with gm: | ||||
expected_out = F.region_restricted_conv( | expected_out = F.region_restricted_conv( | ||||
src, filter, rin, rout, groups=GROUP | |||||
src, filter, rin, rout, bias=bias_cpu, groups=GROUP | |||||
) | ) | ||||
gm.backward( | gm.backward( | ||||
expected_out, | expected_out, | ||||
tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"), | tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"), | ||||
) | ) | ||||
return src, filter | |||||
return src, filter, expected_out | |||||
expected_src, expected_filter = get_groundtruth() | |||||
expected_src, expected_filter, expected_out = get_groundtruth() | |||||
# forward and dgrad/wgrad | # forward and dgrad/wgrad | ||||
src = tensor( | src = tensor( | ||||
@@ -1084,23 +1110,22 @@ def test_region_restricted_conv_forward_backward_uint8(): | |||||
filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle) | filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle) | ||||
rin = tensor(np.ones(rin_shape).astype(np.uint8), device=handle) | rin = tensor(np.ones(rin_shape).astype(np.uint8), device=handle) | ||||
rout = tensor(np.ones(rout_shape).astype(np.uint8), device=handle) | rout = tensor(np.ones(rout_shape).astype(np.uint8), device=handle) | ||||
bias_gpu = ( | |||||
tensor(np.ones(diff_shape).astype(np.float32), device=handle) if bias else None | |||||
) | |||||
gm = GradManager().attach([src, filter]) | gm = GradManager().attach([src, filter]) | ||||
with gm: | with gm: | ||||
gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP) | |||||
gpu_out = F.region_restricted_conv( | |||||
src, filter, rin, rout, bias=bias_gpu, groups=GROUP | |||||
) | |||||
gm.backward( | gm.backward( | ||||
gpu_out, tensor(np.ones(diff_shape, dtype=np.float32), device=handle) | gpu_out, tensor(np.ones(diff_shape, dtype=np.float32), device=handle) | ||||
) | ) | ||||
# assert uint8 gpu result close to cpu result | # assert uint8 gpu result close to cpu result | ||||
np.testing.assert_allclose(src.grad, expected_src.grad) | np.testing.assert_allclose(src.grad, expected_src.grad) | ||||
np.testing.assert_allclose(filter.grad, expected_filter.grad) | np.testing.assert_allclose(filter.grad, expected_filter.grad) | ||||
def test_region_restricted_conv(): | |||||
test_region_restricted_conv_forward_backward_naive() | |||||
if is_cuda_available(): | |||||
test_region_restricted_conv_forward_backward_cuda() | |||||
test_region_restricted_conv_forward_backward_uint8() | |||||
np.testing.assert_allclose(gpu_out, expected_out) | |||||
def test_conv2d_autocast(): | def test_conv2d_autocast(): | ||||
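For reference, the bias-enabled call path that the updated CUDA and uint8 tests exercise boils down to the sketch below. The shapes, the cpu0 device, and the all-ones region masks are illustrative assumptions (the test's actual src_shape/filter_shape/rin_shape/rout_shape are defined earlier in the file and are not reproduced here); only the F.region_restricted_conv signature with bias= and groups=, plus the GradManager flow, are taken from the diff above.

    import numpy as np

    import megengine.functional as F
    from megengine import tensor
    from megengine.autodiff import GradManager

    GROUP = 2
    # Assumed toy shapes: NCHW src, grouped filter (groups, ocpg, icpg, fh, fw),
    # one region id per input/output pixel, bias shaped like the output.
    src = tensor(np.arange(8, dtype=np.float32).reshape(1, 2, 2, 2), device="cpu0")
    filt = tensor(np.ones((2, 1, 1, 2, 2), dtype=np.float32), device="cpu0")
    rin = tensor(np.ones((1, 2, 2), dtype=np.int32), device="cpu0")   # region id per input pixel
    rout = tensor(np.ones((1, 1, 1), dtype=np.int32), device="cpu0")  # region id per output pixel
    bias = tensor(np.ones((1, 2, 1, 1), dtype=np.float32), device="cpu0")

    gm = GradManager().attach([src, filt])
    with gm:
        out = F.region_restricted_conv(src, filt, rin, rout, bias=bias, groups=GROUP)
        gm.backward(out, tensor(np.ones((1, 2, 1, 1), dtype=np.float32), device="cpu0"))

    print(out.numpy().shape)       # (1, 2, 1, 1): 2x2 kernel over a 2x2 input, 2 groups
    print(src.grad.numpy().shape)  # same shape as src

With all-ones masks every output pixel is allowed to see every input pixel of its group, so the call effectively degenerates to an ordinary grouped convolution plus bias, which is why the tests can compare the CUDA result directly against the plain CPU ground truth.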
@@ -3,8 +3,8 @@ | |||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#define MGE_MAJOR 1 | #define MGE_MAJOR 1 | ||||
#define MGE_MINOR 9999 | |||||
#define MGE_PATCH 0 | |||||
#define MGE_MINOR 11 | |||||
#define MGE_PATCH 1 | |||||
// for rc version, could be like "rc1", "rc2", etc | // for rc version, could be like "rc1", "rc2", etc | ||||
#define MGE_EXTRA_NAME "" | #define MGE_EXTRA_NAME "" | ||||
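These macros now compose to a 1.11.1 release string (MGE_EXTRA_NAME stays empty, so there is no rc suffix). A quick sanity check from Python, assuming the installed package keeps __version__ in sync with version.h:

    import megengine

    # expected to print "1.11.1" for a build of this release
    print(megengine.__version__)
    assert megengine.__version__.startswith("1.11")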
@@ -76,7 +76,7 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const { | |||||
if (conv_bias == nullptr) | if (conv_bias == nullptr) | ||||
return false; | return false; | ||||
auto inp_dtype_conv = conv_bias->input(0)->dtype(), | auto inp_dtype_conv = conv_bias->input(0)->dtype(), | ||||
out_dtype_conv = conv_bias->input(0)->dtype(); | |||||
out_dtype_conv = conv_bias->output(0)->dtype(); | |||||
bool is_s8nhwc = | bool is_s8nhwc = | ||||
inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 && | inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 && | ||||
out_dtype_conv.enumv() == inp_dtype_conv.enumv() && | out_dtype_conv.enumv() == inp_dtype_conv.enumv() && | ||||
@@ -86,7 +86,11 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const { | |||||
inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) && | inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) && | ||||
out_dtype_conv.enumv() == inp_dtype_conv.enumv() && | out_dtype_conv.enumv() == inp_dtype_conv.enumv() && | ||||
conv_bias->param().format == megdnn::param::ConvBias::Format::NHWC; | conv_bias->param().format == megdnn::param::ConvBias::Format::NHWC; | ||||
if (!(is_s8nhwc || is_s4nhwc)) | |||||
bool is_s8nchw = | |||||
inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 && | |||||
out_dtype_conv.enumv() == inp_dtype_conv.enumv() && | |||||
conv_bias->param().format == megdnn::param::ConvBias::Format::NCHW; | |||||
if (!(is_s8nhwc || is_s4nhwc || is_s8nchw)) | |||||
return false; | return false; | ||||
if (conv_bias->input().size() != 3) | if (conv_bias->input().size() != 3) | ||||
return false; | return false; | ||||
@@ -107,15 +111,27 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const { | |||||
auto new_bias = (out_dtype_typecvt.enumv() == DTypeEnum::Float32) | auto new_bias = (out_dtype_typecvt.enumv() == DTypeEnum::Float32) | ||||
? opr::TypeCvt::make(bias, dtype::Float32()).node() | ? opr::TypeCvt::make(bias, dtype::Float32()).node() | ||||
: bias; | : bias; | ||||
auto new_param = conv_bias->param(); | |||||
new_param.format = megdnn::param::ConvBias::Format::NHWC; | |||||
auto conv_bias_typecvt = opr::ConvBias::make( | |||||
src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||||
OperatorNodeConfig{out_dtype_typecvt}); | |||||
rewriter.replace_var( | |||||
opr->output(0), conv_bias_typecvt.node(), | |||||
mgb_cstr_log("replace conv_bias(NHWC) + typecvt " | |||||
"to conv_bias(NHWC)")); | |||||
if (is_s8nchw && is_s82s4) { | |||||
auto new_param = conv_bias->param(); | |||||
new_param.format = megdnn::param::ConvBias::Format::NCHW; | |||||
auto conv_bias_typecvt = opr::ConvBias::make( | |||||
src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||||
OperatorNodeConfig{out_dtype_typecvt}); | |||||
rewriter.replace_var( | |||||
opr->output(0), conv_bias_typecvt.node(), | |||||
mgb_cstr_log("replace conv_bias(NCHW) + typecvt " | |||||
"to conv_bias(NCHW)")); | |||||
} else { | |||||
auto new_param = conv_bias->param(); | |||||
new_param.format = megdnn::param::ConvBias::Format::NHWC; | |||||
auto conv_bias_typecvt = opr::ConvBias::make( | |||||
src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||||
OperatorNodeConfig{out_dtype_typecvt}); | |||||
rewriter.replace_var( | |||||
opr->output(0), conv_bias_typecvt.node(), | |||||
mgb_cstr_log("replace conv_bias(NHWC) + typecvt " | |||||
"to conv_bias(NHWC)")); | |||||
} | |||||
return true; | return true; | ||||
}; | }; | ||||
@@ -823,6 +823,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options( | |||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
if (options.target == Target::CUDA) | if (options.target == Target::CUDA) | ||||
add_pass<FuseConvBiasZPass>(); | add_pass<FuseConvBiasZPass>(); | ||||
#if CUDA_VERSION >= 10020 | |||||
add_pass<FoldingConvBiasTypecvtPass>(); | |||||
#endif | |||||
add_pass(LayoutTransformPass::make(options.target)); | add_pass(LayoutTransformPass::make(options.target)); | ||||
add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
if (options.target == Target::CUDA) { | if (options.target == Target::CUDA) { | ||||
@@ -72,7 +72,7 @@ struct OprLoadDumpImplV2<opr::ElemwiseMultiType, 0> { | |||||
namespace opr { | namespace opr { | ||||
MGB_SEREG_OPR_CONDITION(ElemwiseMultiType, 0, false); | MGB_SEREG_OPR_CONDITION(ElemwiseMultiType, 0, false); | ||||
MGB_SEREG_OPR_V2( | |||||
MGB_SEREG_OPR_V2_HASH_WITHOUT_TAIL_0( | |||||
ElemwiseMultiType, 0, | ElemwiseMultiType, 0, | ||||
(mgb::serialization::OprLoadDumpImplV2<opr::ElemwiseMultiType, 0>::replace_opr), | (mgb::serialization::OprLoadDumpImplV2<opr::ElemwiseMultiType, 0>::replace_opr), | ||||
VERSION_1, VERSION_1); | VERSION_1, VERSION_1); | ||||
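The switch to MGB_SEREG_OPR_V2_HASH_WITHOUT_TAIL_0 only changes how ElemwiseMultiType is keyed in the v2 registry when a dumped graph is loaded back; producing such a graph from Python remains the ordinary trace-and-dump flow. A minimal sketch of that flow, with the caveat that the function and file names are placeholders and that ElemwiseMultiType itself only appears in the dump when the traced network contains fused multi-dtype (e.g. quantized) arithmetic:

    import numpy as np

    import megengine as mge
    import megengine.functional as F
    from megengine.jit import trace


    @trace(symbolic=True, capture_as_const=True)
    def fused(x, y):
        # plain float arithmetic here; quantized variants of fused ops like
        # this are what lower to ElemwiseMultiType in the serialized graph
        return F.relu(x + y)


    x = mge.tensor(np.random.rand(4, 8).astype(np.float32))
    y = mge.tensor(np.random.rand(4, 8).astype(np.float32))
    fused(x, y)  # run once so the graph is captured as const
    fused.dump("fused.mge", arg_names=["x", "y"])  # hypothetical output path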
@@ -64,8 +64,8 @@ const OprRegistryV2* dynamic_registry_v2() { | |||||
auto id = MGB_HASH_STR("dynamic"); | auto id = MGB_HASH_STR("dynamic"); | ||||
OprRegistryV2::versioned_add( | OprRegistryV2::versioned_add( | ||||
{nullptr, id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION, | |||||
CURRENT_VERSION); | |||||
{nullptr, id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION, CURRENT_VERSION, | |||||
true); | |||||
ret = OprRegistryV2::versioned_find_by_id(id, CURRENT_VERSION); | ret = OprRegistryV2::versioned_find_by_id(id, CURRENT_VERSION); | ||||
mgb_assert(ret); | mgb_assert(ret); | ||||
return ret; | return ret; | ||||
@@ -182,7 +182,8 @@ const OprRegistryV2* OprRegistryV2::versioned_find_by_typeinfo( | |||||
} | } | ||||
void OprRegistryV2::versioned_add( | void OprRegistryV2::versioned_add( | ||||
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version) { | |||||
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version, | |||||
bool dynamic) { | |||||
mgb_assert(max_version >= min_version); | mgb_assert(max_version >= min_version); | ||||
auto&& sd = static_data(); | auto&& sd = static_data(); | ||||
@@ -190,7 +191,7 @@ void OprRegistryV2::versioned_add( | |||||
uint64_t type_id = id; | uint64_t type_id = id; | ||||
//! record.type->name is nullptr when MGB_VERBOSE_TYPEINFO_NAME==0 | //! record.type->name is nullptr when MGB_VERBOSE_TYPEINFO_NAME==0 | ||||
#if MGB_VERBOSE_TYPEINFO_NAME | #if MGB_VERBOSE_TYPEINFO_NAME | ||||
if (record.type && record.type->name) { | |||||
if (dynamic && record.type && record.type->name) { | |||||
type_id = MGB_HASH_RUNTIME(std::string(record.type->name)); | type_id = MGB_HASH_RUNTIME(std::string(record.type->name)); | ||||
} | } | ||||
#endif | #endif | ||||
@@ -236,7 +237,7 @@ void OprRegistry::add_using_dynamic_loader( | |||||
OprRegistryV2::versioned_add( | OprRegistryV2::versioned_add( | ||||
{type, dynamic_registry_v2()->type_id, type->name, dumper, | {type, dynamic_registry_v2()->type_id, type->name, dumper, | ||||
dynamic_registry_v2()->loader, nullptr}, | dynamic_registry_v2()->loader, nullptr}, | ||||
CURRENT_VERSION, CURRENT_VERSION); | |||||
CURRENT_VERSION, CURRENT_VERSION, true); | |||||
} | } | ||||
#if MGB_ENABLE_DEBUG_UTIL | #if MGB_ENABLE_DEBUG_UTIL | ||||
@@ -111,7 +111,8 @@ struct OprRegistryV2 { | |||||
//! register opr load/dump to version2regmap | //! register opr load/dump to version2regmap | ||||
MGE_WIN_DECLSPEC_FUC static void versioned_add( | MGE_WIN_DECLSPEC_FUC static void versioned_add( | ||||
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version); | |||||
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version, | |||||
bool dynamic = false); | |||||
MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_id( | MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_id( | ||||
const size_t id, uint8_t version); | const size_t id, uint8_t version); | ||||
@@ -180,6 +180,18 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||||
_version_min, _version_max); \ | _version_min, _version_max); \ | ||||
} while (0) | } while (0) | ||||
//! for compatibility with MGB_SEREG_OPR_INTL_CALL_ADD, this macro registers | |||||
//! under the same hash as MGB_SEREG_OPR_INTL_CALL_ADD; note that | |||||
//! MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION differs from MGB_HASH_STR | |||||
#define MGB_SEREG_OPR_INTL_CALL_ADD_V2_WITHOUT_TAIL_0_AND_VERSION_HASH( \ | |||||
_cls, _dump, _load, _convert, _version_min, _version_max) \ | |||||
do { \ | |||||
::mgb::serialization::OprRegistryV2::versioned_add( \ | |||||
{_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \ | |||||
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, _convert}, \ | |||||
_version_min, _version_max); \ | |||||
} while (0) | |||||
/*! | /*! | ||||
* \brief register opr serialization methods | * \brief register opr serialization methods | ||||
*/ | */ | ||||
@@ -223,6 +235,27 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||||
} \ | } \ | ||||
MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls) | MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls) | ||||
//! uses the MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION macro to compute the type id | |||||
#define MGB_SEREG_OPR_V2_HASH_WITHOUT_TAIL_0( \ | |||||
_cls, _arity, _converter, _version_min, _version_max) \ | |||||
namespace { \ | |||||
namespace ser = ::mgb::serialization; \ | |||||
struct _OprRegV2##_cls { \ | |||||
using Impl = ser::OprLoadDumpImplV2<_cls, _arity>; \ | |||||
static ser::OprWithOutputAccessor wrap_loader( \ | |||||
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | |||||
const mgb::cg::OperatorNodeConfig& config) { \ | |||||
return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \ | |||||
} \ | |||||
static void entry() { \ | |||||
MGB_SEREG_OPR_INTL_CALL_ADD_V2_WITHOUT_TAIL_0_AND_VERSION_HASH( \ | |||||
_cls, Impl::dump, wrap_loader, _converter, _version_min, \ | |||||
_version_max); \ | |||||
} \ | |||||
}; \ | |||||
} \ | |||||
MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls) | |||||
//! use to check type is complete or not, midout need a complete type | //! use to check type is complete or not, midout need a complete type | ||||
template <class T, class = void> | template <class T, class = void> | ||||
struct IsComplete : std::false_type {}; | struct IsComplete : std::false_type {}; | ||||
@@ -1 +1 @@ | |||||
Subproject commit ebb7404f428b572beeb39fd048429b5ad3a9c2eb | |||||
Subproject commit 00273750862913a4289429394c449a6a25eadb0c |