GitOrigin-RevId: 0d6bb20b2b
release-1.6
@@ -201,7 +201,8 @@ class Apply(Expr):
                 NodeMixin.wrap_safe(i, Constant.make(i))
         apply_node = cls.make(opdef)
         for i in inputs:
-            apply_node.add_input(NodeMixin.get(i))
+            assert isinstance(i, RawTensor)
+            apply_node.inputs.append(NodeMixin.get(i))
 
         unset_module_tracing()
         outputs = apply(opdef, *inputs)
@@ -1,3 +1,13 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+import collections
 from typing import Callable, NamedTuple
 
 SUPPORTED_TYPE = {}
@@ -9,11 +19,22 @@ def register_supported_type(type, flatten, unflatten):
     SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
 
 
+def _dict_flatten(inp):
+    aux_data = []
+    results = []
+    for key, value in sorted(inp.items()):
+        results.append(value)
+        aux_data.append(key)
+    return results, aux_data
+
+
+def _dict_unflatten(inps, aux_data):
+    return dict(zip(aux_data, inps))
+
+
 register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
 register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
-register_supported_type(
-    dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
-)
+register_supported_type(dict, _dict_flatten, _dict_unflatten)
 register_supported_type(
     slice,
     lambda x: ([x.start, x.stop, x.step], None),
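Note on the dict handlers: sorting the items makes flattening deterministic, so two dicts with the same keys produce the same leaf order regardless of insertion order. A quick standalone round-trip check (the helper definitions are copied verbatim from the hunk above):

def _dict_flatten(inp):
    aux_data = []
    results = []
    for key, value in sorted(inp.items()):
        results.append(value)
        aux_data.append(key)
    return results, aux_data

def _dict_unflatten(inps, aux_data):
    return dict(zip(aux_data, inps))

leaves, keys = _dict_flatten({"b": 2, "a": 1})
assert keys == ["a", "b"]                  # keys come out sorted, not in insertion order
assert _dict_unflatten(leaves, keys) == {"a": 1, "b": 2}  # lossless round trip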
@@ -68,6 +89,8 @@ class TreeDef:
 
 
 class LeafDef(TreeDef):
     def __init__(self, type):
+        if not isinstance(type, collections.abc.Sequence):
+            type = (type,)
         super().__init__(type, None, [])
         self.num_leaves = 1
@@ -77,4 +100,4 @@ class LeafDef(TreeDef):
         return leaves[0]
 
     def __repr__(self):
-        return "Leaf({})".format(self.type.__name__)
+        return "Leaf({})".format(", ".join(t.__name__ for t in self.type))
@@ -14,6 +14,7 @@ import megengine as mge
 import megengine.autodiff as ad
 import megengine.functional as F
 from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD
@@ -71,8 +72,13 @@ class XORNet(Module):
         return x
 
 
-def test_training_converge():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_training_converge(test_traced_module):
     net = XORNet()
+    if test_traced_module:
+        inp = Tensor(np.random.random((14, 2)))
+        net = trace_module(net, inp)
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())
@@ -105,9 +111,8 @@ def test_training_converge():
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
     data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
-    pred = infer(data).numpy()
-    precision = calculate_precision(data.numpy(), pred)
+    pred = infer(data)
+    precision = calculate_precision(data.numpy(), pred.numpy())
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
     )
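The pattern these parametrized tests exercise, in isolation: the Module is traced into a TracedModule up front, and the rest of the training loop is unchanged. A sketch assuming the experimental traced_module API imported above; TinyNet is a hypothetical stand-in for XORNet:

import numpy as np

import megengine.autodiff as ad
from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Linear, Module
from megengine.optimizer import SGD

class TinyNet(Module):  # hypothetical stand-in for XORNet
    def __init__(self):
        super().__init__()
        self.fc = Linear(2, 2)

    def forward(self, x):
        return self.fc(x)

net = trace_module(TinyNet(), Tensor(np.random.random((14, 2))))
# The TracedModule drops in wherever a Module is expected:
opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
gm = ad.GradManager().attach(net.parameters())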
@@ -15,6 +15,7 @@ import megengine.autodiff as ad
 import megengine.functional as F
 import megengine.optimizer as optim
 from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
 from megengine.jit import trace
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD
@@ -73,8 +74,12 @@ class XORNet(Module):
         return x
 
 
-def test_training_converge():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_training_converge(test_traced_module):
     net = XORNet()
+    if test_traced_module:
+        inp = Tensor(np.random.random((14, 2)))
+        net = trace_module(net, inp)
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())
@@ -110,9 +115,8 @@ def test_training_converge():
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
     data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
-    pred = infer(data).numpy()
-    precision = calculate_precision(data.numpy(), pred)
+    pred = infer(data)
+    precision = calculate_precision(data.numpy(), pred.numpy())
     print("precision=", precision)
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
@@ -19,6 +19,7 @@ import megengine.module as M
 import megengine.optimizer as optim
 from megengine import tensor
 from megengine.autodiff import GradManager
+from megengine.experimental.traced_module import trace_module
 from megengine.jit import trace
@@ -15,6 +15,7 @@ import pytest
 import megengine as mge
 import megengine.functional as F
 from megengine import Parameter, Tensor, tensor
+from megengine.experimental.traced_module import TracedModule, trace_module
 from megengine.module import (
     BatchNorm1d,
     BatchNorm2d,
@@ -67,8 +68,18 @@ class MyModule(Module):
         return x
 
 
-def test_module_api():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        buff = m.buff
+        param = m.param
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
+        assert "buff" not in m.__dict__
+        assert "param" not in m.__dict__
+        m.buff = buff
+        m.param = param
     assert list(m.children()) == [m.bn, m.i]
     assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]
     assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
@@ -141,8 +152,11 @@ def test_module_api():
     assert m.bn.training == False and m.i.bn.training == False
 
 
-def test_module_api_reuse_submodule():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_reuse_submodule(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
     m.h = m.i  # pylint: disable=attribute-defined-outside-init
     assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
     assert list(m.named_modules()) == [
@@ -153,15 +167,21 @@ def test_module_api_reuse_submodule():
     ]
 
 
-def test_module_api_iterable_stability():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_iterable_stability(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
     l = list(m.modules())
     for _ in range(100):
         assert list(m.modules()) == l
 
 
-def test_module_api_hooks():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_hooks(test_traced_module):
     net = MyModule()
+    if test_traced_module:
+        net = trace_module(net, Tensor(np.zeros((1, 4, 1, 1))))
     pre_hook_num = 0
     post_hook_num = 0
     hooks = []
@@ -383,11 +403,16 @@ class Simple(Module):
         self.conv1.weight = self.conv0.weight
 
     def forward(self, inputs):
-        pass
+        x = self.conv0(inputs)
+        y = self.conv1(inputs)
+        return x + y
 
 
-def test_shared_param():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_shared_param(test_traced_module):
     net = Simple()
+    if test_traced_module:
+        net = trace_module(net, tensor(np.random.random((1, 1, 8, 8))))
     assert net.conv0.weight is net.conv1.weight
     data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
     np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())
@@ -449,15 +474,21 @@ def test_shared_param_1d():
     np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())
 
 
-def test_pickle_module():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_pickle_module(test_traced_module):
     data_shape = (2, 28)
     data = tensor(np.random.random(data_shape))
     mlp = MLP()
+    pred_gt = mlp(data)
+    if test_traced_module:
+        mlp = trace_module(mlp, data)
     # pickle before forward
     with BytesIO() as fout:
         mge.save(mlp, fout)
         fout.seek(0)
         mlp1 = mge.load(fout)
+        if test_traced_module:
+            assert type(mlp1) == TracedModule
         pred0 = mlp1(data)
 
     pred1 = mlp(data)
@@ -467,8 +498,11 @@ def test_pickle_module():
         mge.save(mlp, fout)
         fout.seek(0)
         mlp1 = mge.load(fout)
+        if test_traced_module:
+            assert type(mlp1) == TracedModule
         pred2 = mlp1(data)
 
+    np.testing.assert_allclose(pred_gt.numpy(), pred1.numpy(), atol=5e-6)
     np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
     np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
@@ -0,0 +1,59 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import io
+
+import numpy as np
+
+import megengine.functional as F
+import megengine.module as M
+import megengine.utils.comp_graph_tools as cgtools
+from megengine.experimental.traced_module import trace_module
+from megengine.jit import trace
+from megengine.module import Module
+
+
+class MyBlock(Module):
+    def __init__(self, in_channels, channels):
+        super(MyBlock, self).__init__()
+        self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
+        self.bn1 = M.BatchNorm2d(channels)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x) + 1
+        return x
+
+
+class MyModule(Module):
+    def __init__(self):
+        super(MyModule, self).__init__()
+        self.block0 = MyBlock(8, 4)
+        self.block1 = MyBlock(4, 2)
+
+    def forward(self, x):
+        x = self.block0(x)
+        x = self.block1(x)
+        return x
+
+
+def test_jit_trace():
+    module = MyModule()
+    module.eval()
+    x = F.ones((1, 8, 14, 14))
+    expect = module(x)
+    traced_module = trace_module(module, x)
+    func = trace(traced_module, capture_as_const=True)
+    np.testing.assert_array_equal(func(x), expect)
+    model = io.BytesIO()
+    func.dump(model)
+    model.seek(0)
+    infer_cg = cgtools.GraphInference(model)
+    np.testing.assert_allclose(
+        list(infer_cg.run(x.numpy()).values())[0], expect, atol=1e-6
+    )