@@ -30,7 +30,7 @@ class Parameter(Tensor):
         else:
             t = tensor(value, dtype=dtype, device=device, requires_grad=requires_grad)
         self.__dict__.update(t.__dict__)

     @property
     def shape(self):
         r"""Return shape of parameter.
@@ -12,9 +12,9 @@
 #
 # Copyright (c) 2018 Facebook
 # ---------------------------------------------------------------------
-from collections import OrderedDict, defaultdict
 import json
 import os
+from collections import OrderedDict, defaultdict

 import cv2
 import numpy as np
@@ -87,7 +87,7 @@ class ImageNet(ImageFolder):
         if not os.path.exists(self.root):
             raise FileNotFoundError("dir %s does not exist" % self.root)
         self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)
         if not os.path.exists(self.devkit_dir):
@@ -159,8 +159,14 @@ class ImageNet(ImageFolder):
         classes = [tuple(clss.split(", ")) for clss in classes]
         idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
         wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
-        logger.info("saving cached meta file to %s", os.path.join(self.devkit_dir, "meta.pkl"))
-        save((idx_to_wnid, wnid_to_classes), os.path.join(self.devkit_dir, "meta.pkl"))
+        logger.info(
+            "saving cached meta file to %s",
+            os.path.join(self.devkit_dir, "meta.pkl"),
+        )
+        save(
+            (idx_to_wnid, wnid_to_classes),
+            os.path.join(self.devkit_dir, "meta.pkl"),
+        )
         return idx_to_wnid, wnid_to_classes

     def check_raw_file(self) -> bool:
@@ -177,7 +183,10 @@ class ImageNet(ImageFolder):
         val_wnids = [id2wnid[idx] for idx in val_idcs]
         val_images = sorted(
-            [os.path.join(self.target_folder, image) for image in os.listdir(self.target_folder)]
+            [
+                os.path.join(self.target_folder, image)
+                for image in os.listdir(self.target_folder)
+            ]
         )

         logger.debug("mkdir for val set wnids")
@@ -198,23 +207,24 @@ class ImageNet(ImageFolder):
         raw_filename, checksum = self.raw_file_meta["val"]
         raw_file = os.path.join(self.root, raw_filename)
         logger.info("checksum valid tar file {} ..".format(raw_file))
-        assert calculate_md5(raw_file) == checksum, \
-            "checksum mismatch, {} may be damaged".format(raw_file)
+        assert (
+            calculate_md5(raw_file) == checksum
+        ), "checksum mismatch, {} may be damaged".format(raw_file)
         logger.info("extract valid tar file.. this may take 10-20 minutes")
         untar(os.path.join(self.root, raw_file), self.target_folder)
         self._organize_val_data()

     def _prepare_train(self):
         assert self.train
         raw_filename, checksum = self.raw_file_meta["train"]
         raw_file = os.path.join(self.root, raw_filename)
         logger.info("checksum train tar file {} ..".format(raw_file))
-        assert calculate_md5(raw_file) == checksum, \
-            "checksum mismatch, {} may be damaged".format(raw_file)
+        assert (
+            calculate_md5(raw_file) == checksum
+        ), "checksum mismatch, {} may be damaged".format(raw_file)
         logger.info("extract train tar file.. this may take several hours")
         untar(
-            os.path.join(self.root, raw_file),
-            self.target_folder,
+            os.path.join(self.root, raw_file), self.target_folder,
         )
         paths = [
             os.path.join(self.target_folder, child_dir)
@@ -227,7 +237,8 @@ class ImageNet(ImageFolder):
         raw_filename, checksum = self.raw_file_meta["devkit"]
         raw_file = os.path.join(self.root, raw_filename)
         logger.info("checksum devkit tar file {} ..".format(raw_file))
-        assert calculate_md5(raw_file) == checksum, \
-            "checksum mismatch, {} may be damaged".format(raw_file)
+        assert (
+            calculate_md5(raw_file) == checksum
+        ), "checksum mismatch, {} may be damaged".format(raw_file)
         logger.info("extract devkit file..")
         untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0]))
@@ -7,8 +7,8 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import hashlib
-import tarfile
 import os
+import tarfile

 from ....distributed.util import is_distributed
 from ....logger import get_logger
@@ -46,16 +46,16 @@ __all__ = [
 def _elemwise(mode):  # DONT export
     """Decorator helps to wrap megbrain element-wise oprs"""
+
     def elemwise_decorator(func):
         @functools.wraps(func)
         @wrap_io_tensor
         def elemwise_func(*inputs) -> Tensor:
-            if all(isinstance(i, (int,float)) for i in inputs):
+            if all(isinstance(i, (int, float)) for i in inputs):
                 device, comp_graph = _use_default_if_none(None, None)
-                ret = mgb.opr.elemwise(*inputs,
-                                       mode=mode,
-                                       comp_node=device,
-                                       comp_graph=comp_graph)
+                ret = mgb.opr.elemwise(
+                    *inputs, mode=mode, comp_node=device, comp_graph=comp_graph
+                )
                 return ret.inferred_value[0]
             return mgb.opr.elemwise(*inputs, mode=mode)
@@ -14,6 +14,6 @@ from .embedding import Embedding
 from .identity import Identity
 from .linear import Linear
 from .module import Module
+from .parampack import ParamPack
 from .pooling import AvgPool2d, MaxPool2d
 from .sequential import Sequential
-from .parampack import ParamPack
@@ -12,7 +12,7 @@ from typing import Optional, Tuple, Union
 import numpy as np

-from ..core import Tensor, Graph
+from ..core import Graph, Tensor
 from ..random import gaussian, uniform
@@ -168,10 +168,9 @@
         """
         yield from self._flatten(predicate=_is_buffer, recursive=recursive)

-    def replace_param(self,
-                      params: dict,
-                      start_pos: int,
-                      seen: Optional[Set[int]] = None):
+    def replace_param(
+        self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
+    ):
         offset = 0
         if seen is None:
             seen = set([id(self)])
@@ -183,12 +182,13 @@
             seen.add(hash_id)
             if isinstance(module_dict[key], Parameter):
                 if start_pos + offset in params:
-                    assert module_dict[key].shape == params[start_pos +
-                                                            offset].shape
+                    assert module_dict[key].shape == params[start_pos + offset].shape
                     module_dict[key] = params[start_pos + offset]
                 offset += 1
             if isinstance(module_dict[key], Module):
-                offset += module_dict[key].replace_param(params, start_pos + offset, seen)
+                offset += module_dict[key].replace_param(
+                    params, start_pos + offset, seen
+                )
         return offset

     def named_buffers(
@@ -8,11 +8,12 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
 from typing import Iterable, Optional

 import numpy as np

+from .._internal.opr import param_pack_split
 from ..core import Parameter, Tensor
 from .module import Module
-from .._internal.opr import param_pack_split

+
 class ParamPack(Module):
@@ -24,11 +25,14 @@ class ParamPack(Module):
     :param max_nr_params_per_group: upper bound of the number of parameters of each group.
     """

-    def __init__(self,
-                 model: Module,
-                 nr_ignore_first:int = 8,
-                 max_size_per_group: int = 10,
-                 max_nr_params_per_group: int = 100):
+    def __init__(
+        self,
+        model: Module,
+        nr_ignore_first: int = 8,
+        max_size_per_group: int = 10,
+        max_nr_params_per_group: int = 100,
+    ):
         super().__init__()
         self._model = model
         self._nr_ignore_first = nr_ignore_first
@@ -52,11 +56,11 @@ class ParamPack(Module):
         for param in params:
             if self._nr_ignore_first > ignored:
                 ignored += 1
-                self._grouped_params.append([{'tensor': param, 'id': param_id}])
+                self._grouped_params.append([{"tensor": param, "id": param_id}])
                 self._packed_params.append(param)
             else:
                 key = (param.dtype, param.device, param.requires_grad)
-                groups[key].append({'tensor': param, 'id': param_id})
+                groups[key].append({"tensor": param, "id": param_id})
             param_id += 1
         for (dtype, device, requires_grad) in groups.keys():
             dtype_sz = np.dtype(dtype).itemsize
@@ -75,33 +79,36 @@ class ParamPack(Module):
                 idx = 0
                 while idx < len(group):
                     param = group[idx]
-                    assert param['tensor'].device == device
+                    assert param["tensor"].device == device
                     padding = (align - (offset & (align - 1))) & (align - 1)
                     offset += padding
                     aligned_pos.append(offset)
                     params.append(param)
-                    offset += int(np.prod(param['tensor'].shape))
+                    offset += int(np.prod(param["tensor"].shape))
                     idx += 1
-                    if (offset * dtype_sz >=
-                            self._max_size_per_group * 1024 * 1024
-                            or idx >= self._max_nr_params_per_group):
+                    if (
+                        offset * dtype_sz >= self._max_size_per_group * 1024 * 1024
+                        or idx >= self._max_nr_params_per_group
+                    ):
                         break
                 group = group[idx:]
                 if idx == 1:
                     # ignore param packs with only one item
-                    self._packed_params.append(params[0]['tensor'])
+                    self._packed_params.append(params[0]["tensor"])
                     self._grouped_params.append(params)
                     continue
-                packed_value = np.zeros((offset, ), dtype=dtype)
+                packed_value = np.zeros((offset,), dtype=dtype)
                 for param, pos in zip(params, aligned_pos):
-                    val = param['tensor'].numpy()
-                    packed_value[pos:pos + val.size] = val.flatten()
-                new_param = Parameter(value=packed_value,
-                                      device=device,
-                                      dtype=dtype,
-                                      requires_grad=requires_grad)
+                    val = param["tensor"].numpy()
+                    packed_value[pos : pos + val.size] = val.flatten()
+                new_param = Parameter(
+                    value=packed_value,
+                    device=device,
+                    dtype=dtype,
+                    requires_grad=requires_grad,
+                )
                 self._packed_params.append(new_param)
                 self._grouped_params.append(params)
@@ -112,14 +119,15 @@ class ParamPack(Module):
             grouped_params = self._grouped_params[i]
             if len(grouped_params) == 1:
                 continue
-            split = param_pack_split(packed_param._symvar,
-                                     [i['tensor'].shape for i in grouped_params])
+            split = param_pack_split(
+                packed_param._symvar, [i["tensor"].shape for i in grouped_params]
+            )
             split = [
                 Parameter(Tensor(i, requires_grad=packed_param.requires_grad))
                 for i in split
             ]
             for j in range(len(split)):
-                replace_param[grouped_params[j]['id']] = split[j]
+                replace_param[grouped_params[j]["id"]] = split[j]
         self._model.replace_param(replace_param, 0)
         return self._model.forward(*args, **kwargs)
@@ -75,10 +75,9 @@ class XORNet(Module):
 @pytest.mark.slow
 def test_static_graph_parampack():
     net = XORNet()
-    net = ParamPack(net,
-                    nr_ignore_first=0,
-                    max_size_per_group=10,
-                    max_nr_params_per_group=100)
+    net = ParamPack(
+        net, nr_ignore_first=0, max_size_per_group=10, max_nr_params_per_group=100
+    )
     opt = SGD(
         net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4
     )
@@ -110,12 +109,11 @@ def test_static_graph_parampack():
     pred = infer(data).numpy()
     assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"

+
 @pytest.mark.slow
 def test_nopack_parampack():
     net = XORNet()
-    net = ParamPack(net,
-                    max_size_per_group=0,
-                    max_nr_params_per_group=0)
+    net = ParamPack(net, max_size_per_group=0, max_nr_params_per_group=0)
     opt = SGD(
         net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4
     )
@@ -146,13 +144,13 @@ def test_nopack_parampack():
     pred = infer(data).numpy()
     assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"

+
 @pytest.mark.slow
 def test_dynamic_graph_parampack():
     net = XORNet()
-    net = ParamPack(net,
-                    nr_ignore_first=0,
-                    max_size_per_group=10,
-                    max_nr_params_per_group=100)
+    net = ParamPack(
+        net, nr_ignore_first=0, max_size_per_group=10, max_nr_params_per_group=100
+    )
     opt = SGD(
         net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4
     )
@@ -184,6 +182,7 @@ def test_dynamic_graph_parampack():
     pred = infer(data).numpy()
     assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"

+
 @pytest.mark.slow
 def test_correctness_parampack():
     net1 = XORNet()
@@ -192,10 +191,9 @@ def test_correctness_parampack():
     params2 = net2.parameters()
     for param1, param2 in zip(params1, params2):
         param1.set_value(param2.numpy())
-    net1 = ParamPack(net1,
-                     nr_ignore_first=0,
-                     max_size_per_group=10,
-                     max_nr_params_per_group=100)
+    net1 = ParamPack(
+        net1, nr_ignore_first=0, max_size_per_group=10, max_nr_params_per_group=100
+    )
     opt1 = SGD(
         net1.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4
     )
@@ -10,31 +10,37 @@ import numpy as np
 import megengine.functional as F
 from megengine import tensor
 from megengine.test import assertTensorClose


 def test_abs():
     assertTensorClose(
-        F.abs(tensor([-3., -4., -5.])).numpy(),
-        np.abs(np.array([-3., -4., -5.], dtype=np.float32)))
-    assertTensorClose(F.abs(-3.), np.abs(np.float32(-3.)))
+        F.abs(tensor([-3.0, -4.0, -5.0])).numpy(),
+        np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)),
+    )
+    assertTensorClose(F.abs(-3.0), np.abs(np.float32(-3.0)))


 def test_multiply():
-    assertTensorClose(F.multiply(-3., -4.),
-                      np.multiply(np.float32(-3.), np.float32(-4.)))
+    assertTensorClose(
+        F.multiply(-3.0, -4.0), np.multiply(np.float32(-3.0), np.float32(-4.0))
+    )
     assertTensorClose(
-        F.multiply(tensor([3., 4.]), 4.).numpy(),
-        np.multiply(np.array([3., 4.], dtype=np.float32), 4.))
+        F.multiply(tensor([3.0, 4.0]), 4.0).numpy(),
+        np.multiply(np.array([3.0, 4.0], dtype=np.float32), 4.0),
+    )
     assertTensorClose(
-        F.multiply(4., tensor([3., 4.])).numpy(),
-        np.multiply(4., np.array([3., 4.], dtype=np.float32)))
+        F.multiply(4.0, tensor([3.0, 4.0])).numpy(),
+        np.multiply(4.0, np.array([3.0, 4.0], dtype=np.float32)),
+    )
     assertTensorClose(
-        F.multiply(tensor([3., 4.]), tensor([3., 4.])).numpy(),
-        np.multiply(np.array([3., 4.], dtype=np.float32),
-                    np.array([3., 4.], dtype=np.float32)))
+        F.multiply(tensor([3.0, 4.0]), tensor([3.0, 4.0])).numpy(),
+        np.multiply(
+            np.array([3.0, 4.0], dtype=np.float32),
+            np.array([3.0, 4.0], dtype=np.float32),
+        ),
+    )
@@ -15,10 +15,10 @@ import pytest
 import megengine as mge
 import megengine._internal as mgb
+import megengine.module as M
 from megengine import jit, tensor
 from megengine.core.tensor import Tensor
 from megengine.test import assertTensorClose
-import megengine.module as M


 @contextlib.contextmanager
@@ -158,13 +158,14 @@ def test_shape_infer():
 def test_dump_bn_fused():
     class ConvBNReLU(M.Sequential):
         def __init__(self):
             super(ConvBNReLU, self).__init__(
                 M.Conv2d(3, 4, 3, 1, 1, groups=1, bias=False),
                 M.BatchNorm2d(4),
-                M.ReLU())
+                M.ReLU(),
+            )

     net = ConvBNReLU()
     net.eval()
@@ -178,8 +179,9 @@ def test_dump_bn_fused():
     fun.dump(out, optimize_for_inference=True)

     cg, _, outputs = mgb.load_comp_graph_from_file(out)
-    out, = outputs
+    (out,) = outputs
     inputs = mgb.cgtools.get_inputs(out)
     assert len(inputs) == 2 and (
-        mgb.cgtools.get_type(inputs[0]) == 'MultipleDeviceTensorHolder' and
-        mgb.cgtools.get_type(inputs[1]) == 'ConvolutionForward')
+        mgb.cgtools.get_type(inputs[0]) == "MultipleDeviceTensorHolder"
+        and mgb.cgtools.get_type(inputs[1]) == "ConvolutionForward"
+    )