This reverts commit 77fd432cb2 (release-1.11.1), together with
0d7126641b, 559d205a8c, 6d8e7398e4, and b72517bfe9.

GitOrigin-RevId: 18dc4ce999
@@ -19,11 +19,11 @@ logger = get_logger(__name__)
 backwarding_grad_manager = None
-def _get_backwarding_grad_manager():
+def get_backwarding_grad_manager():
     return backwarding_grad_manager
-class _AttachSpec:
+class AttachSpec:
     __slots__ = "tensor", "callbacks"
@@ -118,7 +118,7 @@ class GradManager:
     """
     def __init__(self):
-        self._attach_specs = {}  # id(Tensor) -> _AttachSpec
+        self._attach_specs = {}  # id(Tensor) -> AttachSpec
         self._recording = False
         self._grad = None
         self._after_backward_callback = []
@@ -200,7 +200,7 @@ class GradManager:
            if self is not None:
                del self._attach_specs[key]
-        spec = _AttachSpec()
+        spec = AttachSpec()
        spec.tensor = weakref.ref(tensor, deleter)
        spec.callbacks = []
        return spec
@@ -354,22 +354,22 @@ class GradManager:
     def __or__(self, other):
         if isinstance(other, GradManager):
-            return _GradManagerGroup([self, other])
+            return GradManagerGroup([self, other])
         return NotImplemented
     __ror__ = __or__
-class _GradManagerGroup:
+class GradManagerGroup:
     def __init__(self, gms) -> None:
         self._gms = list(gms)
     def merge_with(self, other):
         if isinstance(other, GradManager):
-            other = _GradManagerGroup([other])
-        elif not isinstance(other, _GradManagerGroup):
+            other = GradManagerGroup([other])
+        elif not isinstance(other, GradManagerGroup):
             return NotImplemented
-        return _GradManagerGroup([*self._gms, *other._gms])
+        return GradManagerGroup([*self._gms, *other._gms])
     __or__ = merge_with
     __ror__ = merge_with
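
For reference, the now-public `|` composition shown above can be exercised like this (a minimal sketch; it assumes the names stay importable from megengine.autodiff.grad_manager as in the hunk):

```python
# Minimal sketch of GradManager composition via `|` (names per the hunk above).
from megengine.autodiff.grad_manager import GradManager, GradManagerGroup

gm1, gm2, gm3 = GradManager(), GradManager(), GradManager()

group = gm1 | gm2              # __or__ wraps both managers in a GradManagerGroup
assert isinstance(group, GradManagerGroup)

bigger = group | gm3           # merge_with accepts a bare GradManager...
merged = bigger | (gm1 | gm2)  # ...or another GradManagerGroup
assert len(merged._gms) == 5   # groups concatenate their member lists
```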
@@ -34,7 +34,7 @@ logger = get_logger(__name__)
 GLOBAL_TIMEOUT = 5
-def _raise_timeout_error():
+def raise_timeout_error():
     raise RuntimeError("dataloader timeout")
@@ -191,7 +191,7 @@ class DataLoader:
         )
-class _PreLoader:
+class PreLoader:
     def __init__(self, loader, preload):
         self.dataset = loader.dataset
         self.sampler = loader.sampler
@@ -319,7 +319,7 @@ class _ParallelDataLoaderIter:
             if success:
                 return data
             else:
-                _raise_timeout_error()
+                raise_timeout_error()
         else:
             while True:
                 success, data = self._try_get_data()
@@ -417,7 +417,7 @@ class _ParallelDataLoaderIter:
         self._shutdown_workers()
-class _BaseMapDataLoaderIter(_PreLoader):
+class _BaseMapDataLoaderIter(PreLoader):
     def __init__(self, loader, preload):
         super().__init__(loader, preload)
@@ -510,7 +510,7 @@ def get_worker_info():
     return _worker_info
-class _BaseStreamDataLoaderIter(_PreLoader):
+class _BaseStreamDataLoaderIter(PreLoader):
     def __init__(self, loader, preload):
         super().__init__(loader, preload)
         self.dataset_iter = iter(self.dataset)
@@ -552,7 +552,7 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             timer.cancel()
         waited_time = time.time() - start_time
         if waited_time > self.timeout:
-            _raise_timeout_error()
+            raise_timeout_error()
         return raw_data
     def _get_next_batch(self):
@@ -583,7 +583,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter, _ParallelDataLoad
                 place_holder = [next(self.dataset_iter)]
             waited_time = time.time() - start_time
             if self.timeout > 0 and waited_time > self.timeout:
-                _raise_timeout_error()
+                raise_timeout_error()
             place_holder = self._get_remaind_data(place_holder)
         else:
             place_holder = next(self._sampler_iter)
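
All the timeout branches above funnel into raise_timeout_error(); schematically, the polling loop behaves like this sketch (try_get_data stands in for the worker-queue plumbing, which is not shown in these hunks):

```python
import time

GLOBAL_TIMEOUT = 5

def raise_timeout_error():
    raise RuntimeError("dataloader timeout")

def poll_until_deadline(try_get_data, timeout=GLOBAL_TIMEOUT):
    # Schematic of the loops above: retry the queue until data arrives or
    # the accumulated wait exceeds the configured timeout.
    start_time = time.time()
    while True:
        success, data = try_get_data()
        if success:
            return data
        if timeout > 0 and time.time() - start_time > timeout:
            raise_timeout_error()
```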
@@ -21,7 +21,7 @@ def _count_visible_keypoints(anno):
     return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
-def _has_valid_annotation(anno, order):
+def has_valid_annotation(anno, order):
     # if it"s empty, there is no annotation
     if len(anno) == 0:
         return False
@@ -101,7 +101,7 @@ class COCO(VisionDataset):
                 anno = [
                     obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0
                 ]
-                if _has_valid_annotation(anno, order):
+                if has_valid_annotation(anno, order):
                     ids.append(img_id)
                     self.img_to_anns[img_id] = anno
                 else:
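
The filter above keeps an image only when has_valid_annotation accepts its cleaned annotation list; the keypoint-visibility count it relies on reads the v flag of each (x, y, v) triple:

```python
# Re-statement of _count_visible_keypoints above; `anno` is a list of COCO-style dicts.
def count_visible_keypoints(anno):
    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

# One keypoint visible (v=2), one not labeled (v=0):
anno = [{"keypoints": [10, 20, 2, 30, 40, 0]}]
assert count_visible_keypoints(anno) == 1
```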
@@ -140,17 +140,17 @@ class MNIST(VisionDataset):
         # load raw files and transform them into meta data and datasets Tuple(np.array)
         logger.info("process the raw files of %s set...", "train" if train else "test")
         if train:
-            meta_data_images, images = _parse_idx3(
+            meta_data_images, images = parse_idx3(
                 os.path.join(self.root, self.raw_file_name[0])
             )
-            meta_data_labels, labels = _parse_idx1(
+            meta_data_labels, labels = parse_idx1(
                 os.path.join(self.root, self.raw_file_name[1])
             )
         else:
-            meta_data_images, images = _parse_idx3(
+            meta_data_images, images = parse_idx3(
                 os.path.join(self.root, self.raw_file_name[2])
             )
-            meta_data_labels, labels = _parse_idx1(
+            meta_data_labels, labels = parse_idx1(
                 os.path.join(self.root, self.raw_file_name[3])
             )
@@ -161,7 +161,7 @@ class MNIST(VisionDataset):
         self.arrays = (images, labels.astype(np.int32))
-def _parse_idx3(idx3_file):
+def parse_idx3(idx3_file):
     # parse idx3 file to meta data and data in numpy array (images)
     logger.debug("parse idx3 file %s ...", idx3_file)
     assert idx3_file.endswith(".gz")
@@ -187,7 +187,7 @@ def _parse_idx3(idx3_file):
     return meta_data, images
-def _parse_idx1(idx1_file):
+def parse_idx1(idx1_file):
     # parse idx1 file to meta data and data in numpy array (labels)
     logger.debug("parse idx1 file %s ...", idx1_file)
     assert idx1_file.endswith(".gz")
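
For context, the idx3/idx1 archives parsed by these helpers use the classic MNIST layout: big-endian int32 header words followed by a raw uint8 payload. A self-contained sketch of an idx1 (label) parser, independent of the helpers above:

```python
import gzip
import struct

import numpy as np

def parse_idx1_sketch(idx1_file):
    # idx1 layout: magic 0x00000801, item count, then one uint8 label per item.
    assert idx1_file.endswith(".gz")
    with gzip.open(idx1_file, "rb") as f:
        magic, count = struct.unpack(">ii", f.read(8))
        assert magic == 0x00000801, "not an idx1 label file"
        labels = np.frombuffer(f.read(count), dtype=np.uint8)
    return {"magic": magic, "count": count}, labels
```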
@@ -7,7 +7,7 @@ import cv2
 import numpy as np
 from megengine.data.transform import Transform
-from megengine.data.transform.vision import _functional as F
+from megengine.data.transform.vision import functional as F
 __all__ = [
     "VisionTransform",
@@ -2,6 +2,7 @@
 from mprop import mproperty
 from ..core._imperative_rt.core2 import group_end, group_start
+from . import group
 from .group import (
     WORLD,
     Group,
@@ -19,7 +20,7 @@ from .group import (
 )
 from .helper import bcast_list_, make_allreduce_cb, synchronized
 from .launcher import launcher
-from .server import Server
+from .server import Client, Server
 @mproperty
@@ -7,10 +7,10 @@ from mprop import mproperty
 from ..device import _sh, set_default_device, what_is_xpu
 from ..random import seed
-from .server import Server, _Client
+from .server import Client, Server
-class _StaticData:
+class StaticData:
     server = None
     client = None
     master_ip = None
@@ -139,13 +139,13 @@ def init_process_group(
     global _sd
     assert _sd is None, "init_process_group should be called only once"
-    _sd = _StaticData()
+    _sd = StaticData()
     assert world_size > 1
     assert rank >= 0 and rank < world_size
     assert port > 0
-    _sd.client = _Client(master_ip, port)
+    _sd.client = Client(master_ip, port)
     _sd.master_ip = master_ip
     _sd.py_server_port = port
     _sd.mm_server_port = _sd.client.get_mm_server_port()
@@ -225,7 +225,7 @@ def get_mm_server_addr() -> Tuple[str, int]:
     return _sd.master_ip, _sd.mm_server_port
-def get_client() -> _Client:
+def get_client() -> Client:
     r"""Get client of python XML RPC server."""
     assert _sd is not None, "please call init_process_group first"
     return _sd.client
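
Per the asserts above, init_process_group expects world_size > 1, 0 <= rank < world_size, and a positive port; a per-process setup sketch (the device keyword is an assumption, not shown in these hunks):

```python
import megengine.distributed as dist

def worker(rank, world_size, master_ip="localhost", port=23456):
    # One call per process; the assert above rejects a second call.
    dist.init_process_group(
        master_ip=master_ip, port=port, world_size=world_size, rank=rank,
        device=rank,  # assumed keyword: which local device this rank uses
    )
    return dist.get_client()  # XML-RPC client, typed as the now-public Client
```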
@@ -7,7 +7,7 @@ from weakref import WeakSet
 import numpy as np
-from megengine.autodiff.grad_manager import GradManager, _get_backwarding_grad_manager
+from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
 from ..core._imperative_rt.core2 import apply
 from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
@@ -78,7 +78,7 @@ def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
     return apply(op, *inps, offsets)[0]
-def _get_offsets(shapes):
+def get_offsets(shapes):
     offsets = []
     offset = 0
     for shape in shapes:
@@ -108,7 +108,7 @@ def _check_enable_p2p():
 def pack_allreduce_split(pack_list, shapes, group, reduce_method):
-    offsets_val = _get_offsets(shapes)
+    offsets_val = get_offsets(shapes)
     offsets = Tensor(offsets_val)
     packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
@@ -119,7 +119,7 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method):
     return grads
-class _TensorFuture(Future):
+class TensorFuture(Future):
     def device(self):
         raise "Sorry, this tensor is not ready"
@@ -234,13 +234,13 @@ class AllreduceCallback:
             self._packing_size[dtype] = 0
     def __call__(self, param, grad):
-        gm = _get_backwarding_grad_manager()
+        gm = get_backwarding_grad_manager()
         assert isinstance(gm, GradManager)
         if gm not in self._marked_gm:
             gm._register_after_backward_callback(self._flush)
             self._marked_gm.add(gm)
         self._params.append(param)
-        self._futures_dict[param] = _TensorFuture(ack=False)
+        self._futures_dict[param] = TensorFuture(ack=False)
         self._gradients_dict[param] = grad
         self._grad_origin_device[param] = str(grad.device)
@@ -10,7 +10,7 @@ from ..device import get_device_count
 from ..logger import get_logger
 from .group import _set_machine_ranks, group_barrier, init_process_group
 from .helper import _check_device_initialized, _check_interpreter_status
-from .server import Server
+from .server import Client, Server
 WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
     "subprocess exited with code 0 but did not return a value"
@@ -12,7 +12,7 @@ from ..core._imperative_rt.utils import create_mm_server
 from ..utils.future import Future
-class _Methods:
+class Methods:
     r"""Distributed Server Method.
     Used for exchange information between distributed nodes.
@@ -149,7 +149,7 @@ class _Methods:
         return ret
-class _ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
+class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
     pass
@@ -163,10 +163,10 @@ def _start_server(py_server_port, queue):
     """
     try:
         mm_server_port = create_mm_server("0.0.0.0", 0)
-        server = _ThreadXMLRPCServer(
+        server = ThreadXMLRPCServer(
             ("0.0.0.0", py_server_port), logRequests=False, allow_none=True
         )
-        server.register_instance(_Methods(mm_server_port))
+        server.register_instance(Methods(mm_server_port))
         _, py_server_port = server.server_address
         queue.put((py_server_port, mm_server_port))
         server.serve_forever()
@@ -196,7 +196,7 @@ class Server:
             self.proc.terminate()
-class _Client:
+class Client:
     r"""Distributed Client for distributed training.
     Args:
@@ -298,10 +298,10 @@ class _Client:
         return self.proxy.bcast_val(val, key, size)
-def _main(port=0, verbose=True):
+def main(port=0, verbose=True):
     mm_server_port = create_mm_server("0.0.0.0", 0)
-    server = _ThreadXMLRPCServer(("0.0.0.0", port), logRequests=verbose)
-    server.register_instance(_Methods(mm_server_port))
+    server = ThreadXMLRPCServer(("0.0.0.0", port), logRequests=verbose)
+    server.register_instance(Methods(mm_server_port))
     _, port = server.server_address
     print("serving on port", port)
     server.serve_forever()
@@ -314,4 +314,4 @@ if __name__ == "__main__":
     ap.add_argument("-p", "--port", type=int, default=0)
     ap.add_argument("-v", "--verbose", type=bool, default=True)
     args = ap.parse_args()
-    _main(port=args.port, verbose=args.verbose)
+    main(port=args.port, verbose=args.verbose)
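
With the rename, the module's entry point is plain main(); besides the argparse CLI above, it can be started in-process (the import path is an assumption based on the `from .server import Client, Server` lines in the other hunks):

```python
# Start the rendezvous server in-process; port=0 lets the OS pick a free port,
# which main() prints ("serving on port ...") before blocking in serve_forever().
from megengine.distributed.server import main

main(port=0, verbose=True)
```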
@@ -1,11 +1,13 @@
 # -*- coding: utf-8 -*-
 # pylint: disable=redefined-builtin
+from . import metric, utils, vision
 from .elemwise import *
 from .math import *
 from .nn import *
 from .tensor import *
 from .utils import *
-from . import utils, vision, distributed  # isort:skip
+from . import distributed  # isort:skip
 # delete namespace
 # pylint: disable=undefined-variable
@@ -21,10 +21,6 @@ _valid_string_option = {
 }
-@deprecated(
-    version="1.10",
-    reason="use ``megengine.config.benchmark_kernel`` and ``megengine.config.deterministic_kernel`` instead",
-)
 def get_execution_strategy() -> Strategy:
     r"""Returns the execution strategy of :class:`~.module.Conv2d` and :func:`~.matmul`
@@ -40,10 +36,6 @@ def get_execution_strategy() -> Strategy:
     return strategy
-@deprecated(
-    version="1.10",
-    reason="use ``megengine.config.benchmark_kernel`` and ``megengine.config.deterministic_kernel`` instead",
-)
 def set_execution_strategy(option):
     r"""Sets the execution strategy of :class:`~.module.Conv2d` and :func:`~.matmul`
@@ -9,7 +9,7 @@ from ..core.tensor.array_method import _elwise
 from ..core.tensor.utils import convert_inputs
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
-from ._tensor_cache import get_scalar_one
+from .tensor_cache import get_scalar_one
 __all__ = [
     "abs",
@@ -43,7 +43,6 @@ from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import _elwise, exp, log, log1p, maximum, minimum
 from .math import max, normalize, sum
-from .metric import topk_accuracy
 from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros
 __all__ = [
@@ -87,7 +86,6 @@ __all__ = [
     "softmax",
     "softplus",
     "sync_batch_norm",
-    "topk_accuracy",
     "warp_affine",
     "warp_perspective",
     "pixel_shuffle",
@@ -95,7 +93,7 @@ __all__ = [
 ]
-def _expand_hw(x):
+def expand_hw(x):
     # judge int is 5 times faster than judge Sequence
     if isinstance(x, int):
         return x, x
@@ -104,7 +102,7 @@ def _expand_hw(x):
     return int(x), int(x)
-def _expand_dhw(x):
+def expand_dhw(x):
     if isinstance(x, int):
         return x, x, x
     if isinstance(x, Sequence):
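
expand_hw/expand_dhw normalize int-or-sequence arguments into per-axis tuples; a simplified re-statement of the visible branches (the truncated Sequence handling is filled in by assumption):

```python
from collections.abc import Sequence

def expand_hw_sketch(x):
    # int fast path first (the comment above notes it is ~5x cheaper to test).
    if isinstance(x, int):
        return x, x
    if isinstance(x, Sequence):
        return int(x[0]), int(x[1])  # assumed completion of the truncated branch
    return int(x), int(x)

assert expand_hw_sketch(3) == (3, 3)
assert expand_hw_sketch((2, 4)) == (2, 4)
```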
@@ -246,9 +244,9 @@ def conv2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    dilate_h, dilate_w = expand_hw(dilation)
     sparse_type = "dense" if groups == 1 else "group"
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
@@ -308,9 +306,9 @@ def conv3d(
     D, H, W = 0, 1, 2
-    pad = _expand_dhw(padding)
-    stride = _expand_dhw(stride)
-    dilate = _expand_dhw(dilation)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3D(
@@ -378,10 +376,10 @@ def conv_transpose2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    output_pad_h, output_pad_w = _expand_hw(output_padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    output_pad_h, output_pad_w = expand_hw(output_padding)
+    dilate_h, dilate_w = expand_hw(dilation)
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     sparse_type = "dense" if groups == 1 else "group"
@@ -479,9 +477,9 @@ def deformable_conv2d(
         offset = offset.astype("float32")
         mask = mask.astype("float32")
-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    dilate_h, dilate_w = expand_hw(dilation)
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     sparse_type = "dense" if groups == 1 else "group"
@@ -533,9 +531,9 @@ def local_conv2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    dilate_h, dilate_w = expand_hw(dilation)
     # local conv only support "dense" mode, but weight could contain group dimension.
     op = builtin.GroupLocal(
@@ -589,10 +587,10 @@ def conv_transpose3d(
         output tensor.
     """
     D, H, W = 0, 1, 2
-    pad = _expand_dhw(padding)
-    stride = _expand_dhw(stride)
-    dilate = _expand_dhw(dilation)
-    output_padding = _expand_dhw(output_padding)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
+    output_padding = expand_dhw(output_padding)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3DBackwardData(
@@ -677,9 +675,9 @@ def max_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _expand_hw(kernel_size)
-    stride_h, stride_w = _expand_hw(stride)
-    padding_h, padding_w = _expand_hw(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)
     op = builtin.Pooling(
         window_h=window_h,
@@ -727,9 +725,9 @@ def avg_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _expand_hw(kernel_size)
-    stride_h, stride_w = _expand_hw(stride)
-    padding_h, padding_w = _expand_hw(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)
     op = builtin.Pooling(
         window_h=window_h,
@@ -1725,10 +1723,10 @@ def sliding_window(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    padding_h, padding_w = _expand_hw(padding)
-    stride_h, stride_w = _expand_hw(stride)
-    dilation_h, dilation_w = _expand_hw(dilation)
-    window_h, window_w = _expand_hw(kernel_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)
     op = builtin.Images2Neibs(
         pad_h=padding_h,
@@ -1764,11 +1762,11 @@ def sliding_window_transpose(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    output_h, output_w = _expand_hw(output_size)
-    padding_h, padding_w = _expand_hw(padding)
-    stride_h, stride_w = _expand_hw(stride)
-    dilation_h, dilation_w = _expand_hw(dilation)
-    window_h, window_w = _expand_hw(kernel_size)
+    output_h, output_w = expand_hw(output_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)
     expected_h = (
         output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1
@@ -1921,7 +1919,7 @@ def _get_layerPixelShuffle(device, dtype, dim_order):
     return layerPixelShuffle
-def _layerPixelShuffle_traceable(inp, upscale_factor):
+def layerPixelShuffle_traceable(inp, upscale_factor):
     assert upscale_factor > 0, "upscale_factor should larger than 0"
     assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3"
     assert (
@@ -1972,7 +1970,7 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
     :param upscale_factor: upscale factor of pixel_shuffle.
     :return: output tensor.
     """
-    return pixel_shuffle_cpp(inp, upscale_factor, _layerPixelShuffle_traceable)
+    return pixel_shuffle_cpp(inp, upscale_factor, layerPixelShuffle_traceable)
 def region_restricted_conv(
@@ -2014,9 +2012,9 @@ def region_restricted_conv(
     """
     assert conv_mode.lower() == "cross_correlation"
-    pad_h, pad_w = _expand_hw(padding)
-    stride_h, stride_w = _expand_hw(stride)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    pad_h, pad_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilate_h, dilate_w = expand_hw(dilation)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.RegionRestrictedConvolution(
@@ -2038,4 +2036,5 @@ def region_restricted_conv(
 from .quantized import conv_bias_activation  # isort:skip
 from .loss import *  # isort:skip
+from .metric import *  # isort:skip
 from .vision import *  # isort:skip
@@ -1,12 +1,12 @@
 from ..core._imperative_rt.core2 import Const
-from ..jit.tracing import _is_tracing
+from ..jit.tracing import is_tracing
 small_tensor_cache = {}
 def _get_scalar_tensor_with_value(value, dtype=None, device=None):
     global small_tensor_cache
-    if _is_tracing():
+    if is_tracing():
         ret = Const(value, dtype, device)
     else:
         cache_key = (value, dtype, device)
@@ -7,6 +7,8 @@ from ..utils.deprecation import deprecated_func
 from .elemwise import abs, maximum, minimum
 from .tensor import ones, zeros
+__all__ = ["topk_accuracy"]
 def _assert_equal(
     expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False
@@ -36,7 +36,7 @@ pattern = re.compile(
 )
-class _RepoFetcherBase:
+class RepoFetcherBase:
     @classmethod
     def fetch(
         cls,
@@ -84,7 +84,7 @@ class _RepoFetcherBase:
         return hashlib.sha1(repo_dir.encode()).hexdigest()[:16]
-class GitSSHFetcher(_RepoFetcherBase):
+class GitSSHFetcher(RepoFetcherBase):
     @classmethod
     @synchronized
     def fetch(
@@ -193,7 +193,7 @@ class GitSSHFetcher(_RepoFetcherBase):
         )
-class GitHTTPSFetcher(_RepoFetcherBase):
+class GitHTTPSFetcher(RepoFetcherBase):
     @classmethod
     @synchronized
     def fetch(
@@ -49,7 +49,7 @@ active_trace = None
 skip_tracing = False
-def _is_tracing():
+def is_tracing():
     if active_trace is None:
         return False
     else:
@@ -73,7 +73,7 @@ def exclude_from_trace():
         skip_tracing = False
-def _array_comparator(lhs, rhs):
+def array_comparator(lhs, rhs):
     return np.all(lhs == rhs)
@@ -184,7 +184,7 @@ class trace:
         self._trace.no_exec = record_only
         self._trace.options_visitor = apply_options
         self._trace.profile = profiling
-        self._trace.array_comparator = _array_comparator
+        self._trace.array_comparator = array_comparator
         self._trace.record_input_shapes = _input_node_use_static_shape()
     def __call__(self, *args, **kwargs):
@@ -18,10 +18,10 @@ def set_log_file(fout, mode="a"):
     """
     if isinstance(fout, str):
         fout = open(fout, mode)
-    _MegEngineLogFormatter.log_fout = fout
+    MegEngineLogFormatter.log_fout = fout
-class _MegEngineLogFormatter(logging.Formatter):
+class MegEngineLogFormatter(logging.Formatter):
     log_fout = None
     date_full = "[%(asctime)s %(lineno)d@%(filename)s:%(name)s] "
     date = "%(asctime)s "
@@ -71,7 +71,7 @@ class _MegEngineLogFormatter(logging.Formatter):
         if self.log_fout:
             self.__set_fmt(self.date_full + mtxt + self.msg)
-            formatted = super(_MegEngineLogFormatter, self).format(record)
+            formatted = super(MegEngineLogFormatter, self).format(record)
             nr_line = formatted.count("\n") + 1
             if nr_line >= self.max_lines:
                 head, body = formatted.split("\n", 1)
@@ -88,7 +88,7 @@ class _MegEngineLogFormatter(logging.Formatter):
             self.log_fout.flush()
         self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
-        formatted = super(_MegEngineLogFormatter, self).format(record)
+        formatted = super(MegEngineLogFormatter, self).format(record)
         if record.exc_text or record.exc_info:
             # handle exception format
@@ -125,7 +125,7 @@ class _MegEngineLogFormatter(logging.Formatter):
         self._style._fmt = fmt
-def get_logger(name=None, formatter=_MegEngineLogFormatter):
+def get_logger(name=None, formatter=MegEngineLogFormatter):
     r"""Gets megengine logger with given name."""
     logger = logging.getLogger(name)
@@ -167,16 +167,16 @@ try:
     from .core._imperative_rt.utils import Logger as _imperative_rt_logger
-    class _MegBrainLogFormatter(_MegEngineLogFormatter):
+    class MegBrainLogFormatter(MegEngineLogFormatter):
         date = "%(asctime)s[mgb] "
         def _color_date(self, msg):
            return "\x1b[33m{}\x1b[0m".format(msg)
-    _megbrain_logger = get_logger("megbrain", _MegBrainLogFormatter)
+    _megbrain_logger = get_logger("megbrain", MegBrainLogFormatter)
     _imperative_rt_logger.set_log_handler(_megbrain_logger)
-    def _set_mgb_log_level(level):
+    def set_mgb_log_level(level):
         r"""Sets megbrain log level
         Args:
@@ -200,30 +200,30 @@ try:
         )
         return rst
-    _set_mgb_log_level(_default_level)
+    set_mgb_log_level(_default_level)
 except ImportError as exc:
-    def _set_mgb_log_level(level):
+    def set_mgb_log_level(level):
         raise NotImplementedError("imperative_rt has not been imported")
 @contextlib.contextmanager
-def _replace_mgb_log_level(level):
+def replace_mgb_log_level(level):
     r"""Replaces megbrain log level in a block and restore after exiting.
     Args:
         level: new log level
     """
-    old = _set_mgb_log_level(level)
+    old = set_mgb_log_level(level)
     try:
         yield
     finally:
-        _set_mgb_log_level(old)
+        set_mgb_log_level(old)
 def enable_debug_log():
     r"""Sets logging level to debug for all components."""
     set_log_level(logging.DEBUG)
-    _set_mgb_log_level(logging.DEBUG)
+    set_mgb_log_level(logging.DEBUG)
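
replace_mgb_log_level wraps set_mgb_log_level in a restore-on-exit context manager, so the typical use is a with-block (the noisy step is a placeholder):

```python
import logging

from megengine.logger import replace_mgb_log_level

with replace_mgb_log_level(logging.ERROR):
    # megbrain logs are raised to ERROR only inside this block...
    run_noisy_step()  # placeholder for user code
# ...and the previous level is restored here, even on exception.
```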
@@ -15,12 +15,12 @@ from . import init
 from .module import Module
-class _RNNCellBase(Module):
+class RNNCellBase(Module):
     def __init__(
         self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
     ) -> None:
         # num_chunks indicates the number of gates
-        super(_RNNCellBase, self).__init__()
+        super(RNNCellBase, self).__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
@@ -57,7 +57,7 @@ class _RNNCellBase(Module):
         raise NotImplementedError("forward not implemented !")
-class RNNCell(_RNNCellBase):
+class RNNCell(RNNCellBase):
     r"""An Elman RNN cell with tanh or ReLU non-linearity.
@@ -135,7 +135,7 @@ class RNNCell(_RNNCellBase):
         )[0]
-class LSTMCell(_RNNCellBase):
+class LSTMCell(RNNCellBase):
     r"""A long short-term memory (LSTM) cell.
@@ -216,7 +216,7 @@ class LSTMCell(_RNNCellBase):
         )[:2]
-class _RNNBase(Module):
+class RNNBase(Module):
     def __init__(
         self,
         input_size: int,
@@ -228,7 +228,7 @@ class _RNNBase(Module):
         bidirectional: bool = False,
         proj_size: int = 0,
     ) -> None:
-        super(_RNNBase, self).__init__()
+        super(RNNBase, self).__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.num_layers = num_layers
@@ -323,7 +323,7 @@ class _RNNBase(Module):
         return output, h
-class RNN(_RNNBase):
+class RNN(RNNBase):
     r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
     input sequence.
@@ -453,7 +453,7 @@ class RNN(_RNNBase):
         return output, h
-class LSTM(_RNNBase):
+class LSTM(RNNBase):
     r"""Applies a multi-layer long short-term memory LSTM to an input
     sequence.
@@ -7,7 +7,7 @@ import numpy as np
 from .. import functional as F
 from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
-from ..distributed import WORLD, is_distributed
+from ..distributed import WORLD, get_rank, is_distributed
 from ..functional.distributed import all_reduce_max, all_reduce_min
 from ..logger import get_logger
 from ..module import Module
@@ -66,7 +66,7 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.DEFAULT_PROTOCOL):
         pickle_module.dump(obj, f, pickle_protocol)
-class _dmap:
+class dmap:
     def __init__(self, map_location):
         self.map_location = map_location
@@ -177,5 +177,5 @@ def load(f, map_location=None, pickle_module=pickle):
     map_location = _get_callable_map_location(map_location)  # callable map_location
-    with _dmap(map_location) as dm:
+    with dmap(map_location) as dm:
         return pickle_module.load(f)
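
Since load() resolves map_location through _get_callable_map_location before entering the dmap context, both string and callable forms should work; a usage sketch (the "cpu0" device string and checkpoint name are illustrative):

```python
import megengine as mge

# Remap all tensors in the checkpoint onto CPU while unpickling.
state = mge.load("checkpoint.pkl", map_location="cpu0")

# Equivalent callable form: receives the original device name per tensor.
state = mge.load("checkpoint.pkl", map_location=lambda dev: "cpu0")
```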
@@ -33,7 +33,6 @@ def deprecated_func(version, origin, name, tbd):
             )
             return func(*args, **kwargs)
-        wrapper.__deprecated__ = True
         return wrapper
@@ -58,7 +57,6 @@ def deprecated_kwargs_default(version, kwargs_name, kwargs_pos):
             )
             return func(*args, **kwargs)
-        wrapper.__deprecated__ = True
         return wrapper
     return deprecated
@@ -11,11 +11,11 @@ from .. import functional as F
 from .. import get_logger
 from .. import module as M
 from ..core.tensor.dtype import get_dtype_bit
-from ..logger import _MegEngineLogFormatter
+from ..logger import MegEngineLogFormatter
 from .module_utils import set_module_mode_safe
 try:
-    _MegEngineLogFormatter.max_lines = float("inf")
+    MegEngineLogFormatter.max_lines = float("inf")
 except AttributeError as e:
     raise ValueError("set logger max lines failed")
@@ -2,14 +2,14 @@
 import logging
 from megengine.core._imperative_rt import Logger
-from megengine.logger import _imperative_rt_logger, _set_mgb_log_level
+from megengine.logger import _imperative_rt_logger, set_mgb_log_level
 def test_logger():
     orig_level = Logger().set_log_level(Logger.LogLevel.Debug)
     assert Logger().set_log_level(Logger.LogLevel.Debug) == Logger.LogLevel.Debug
     Logger().set_log_level(orig_level)
-    orig_level = _set_mgb_log_level(logging.DEBUG)
+    orig_level = set_mgb_log_level(logging.DEBUG)
     assert (
         _imperative_rt_logger.set_log_level(Logger.LogLevel.Debug)
         == Logger.LogLevel.Debug
@@ -50,7 +50,7 @@ def test_init_process_group(backend):
         assert mm_server_addr[0] == "localhost"
         assert mm_server_addr[1] > 0
-        assert isinstance(dist.get_client(), dist.server._Client)
+        assert isinstance(dist.get_client(), dist.Client)
     procs = []
     for rank in range(world_size):