GitOrigin-RevId: 0d6bb20b2b
release-1.6
@@ -201,7 +201,8 @@ class Apply(Expr):
                 NodeMixin.wrap_safe(i, Constant.make(i))
         apply_node = cls.make(opdef)
         for i in inputs:
-            apply_node.add_input(NodeMixin.get(i))
+            assert isinstance(i, RawTensor)
+            apply_node.inputs.append(NodeMixin.get(i))
         unset_module_tracing()
         outputs = apply(opdef, *inputs)
@@ -1,3 +1,13 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
 from typing import Callable, NamedTuple
 
 SUPPORTED_TYPE = {}
@@ -9,11 +19,22 @@ def register_supported_type(type, flatten, unflatten):
     SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
 
 
+def _dict_flatten(inp):
+    aux_data = []
+    results = []
+    for key, value in sorted(inp.items()):
+        results.append(value)
+        aux_data.append(key)
+    return results, aux_data
+
+
+def _dict_unflatten(inps, aux_data):
+    return dict(zip(aux_data, inps))
+
+
 register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
 register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
-register_supported_type(
-    dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
-)
+register_supported_type(dict, _dict_flatten, _dict_unflatten)
 register_supported_type(
     slice,
     lambda x: ([x.start, x.stop, x.step], None),
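For reference, a minimal sketch of what the new dict handling buys (a self-contained copy of the two helpers from the hunk above; the example values are illustrative only). Sorting by key makes the flattened leaf order deterministic, where the old `list(x.values())` lambda depended on dict insertion order:

    def _dict_flatten(inp):
        aux_data = []
        results = []
        # sort by key so {"a": 1, "b": 2} and {"b": 2, "a": 1} flatten identically
        for key, value in sorted(inp.items()):
            results.append(value)
            aux_data.append(key)
        return results, aux_data

    def _dict_unflatten(inps, aux_data):
        # zip the saved keys back onto the leaves
        return dict(zip(aux_data, inps))

    leaves, keys = _dict_flatten({"b": 2, "a": 1})
    assert leaves == [1, 2] and keys == ["a", "b"]
    assert _dict_unflatten(leaves, keys) == {"a": 1, "b": 2}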
@@ -68,6 +89,8 @@ class TreeDef:
 
 class LeafDef(TreeDef):
     def __init__(self, type):
+        if not isinstance(type, collections.abc.Sequence):
+            type = (type,)
         super().__init__(type, None, [])
         self.num_leaves = 1
@@ -77,4 +100,4 @@ class LeafDef(TreeDef):
         return leaves[0]
 
     def __repr__(self):
-        return "Leaf({})".format(self.type.__name__)
+        return "Leaf({})".format(", ".join(t.__name__ for t in self.type))
@@ -14,6 +14,7 @@ import megengine as mge
 import megengine.autodiff as ad
 import megengine.functional as F
 from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD
@@ -71,8 +72,13 @@ class XORNet(Module): | |||
return x | |||
def test_training_converge(): | |||
@pytest.mark.parametrize("test_traced_module", [True, False]) | |||
def test_training_converge(test_traced_module): | |||
net = XORNet() | |||
if test_training_converge: | |||
inp = Tensor(np.random.random((14, 2))) | |||
net = trace_module(net, inp) | |||
opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) | |||
gm = ad.GradManager().attach(net.parameters()) | |||
@@ -105,9 +111,8 @@ def test_training_converge():
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
     data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
-    pred = infer(data).numpy()
-    precision = calculate_precision(data.numpy(), pred)
+    pred = infer(data)
+    precision = calculate_precision(data.numpy(), pred.numpy())
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
     )
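The parametrization pattern above (repeated in the hunks below) hinges on trace_module returning a drop-in replacement for the original Module. A hedged sketch of the shared shape of these tests, assuming megengine with the experimental traced_module package is importable:

    import numpy as np
    import pytest
    from megengine import Tensor
    from megengine.experimental.traced_module import trace_module

    @pytest.mark.parametrize("test_traced_module", [True, False])
    def test_something(test_traced_module):
        net = XORNet()  # any Module under test
        if test_traced_module:
            # Trace once with a representative input; the resulting
            # TracedModule still exposes parameters(), train()/eval(),
            # and forward, so the rest of the test body is unchanged.
            net = trace_module(net, Tensor(np.random.random((14, 2))))
        ...  # identical assertions for both parametrizations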
@@ -15,6 +15,7 @@ import megengine.autodiff as ad
 import megengine.functional as F
 import megengine.optimizer as optim
 from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
 from megengine.jit import trace
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD
@@ -73,8 +74,12 @@ class XORNet(Module):
         return x
 
 
-def test_training_converge():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_training_converge(test_traced_module):
     net = XORNet()
+    if test_traced_module:
+        inp = Tensor(np.random.random((14, 2)))
+        net = trace_module(net, inp)
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())
@@ -110,9 +115,8 @@ def test_training_converge():
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
     data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
-    pred = infer(data).numpy()
-    precision = calculate_precision(data.numpy(), pred)
+    pred = infer(data)
+    precision = calculate_precision(data.numpy(), pred.numpy())
     print("precision=", precision)
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
@@ -19,6 +19,7 @@ import megengine.module as M
 import megengine.optimizer as optim
 from megengine import tensor
 from megengine.autodiff import GradManager
+from megengine.experimental.traced_module import trace_module
 from megengine.jit import trace
@@ -15,6 +15,7 @@ import pytest
 import megengine as mge
 import megengine.functional as F
 from megengine import Parameter, Tensor, tensor
+from megengine.experimental.traced_module import TracedModule, trace_module
 from megengine.module import (
     BatchNorm1d,
     BatchNorm2d,
@@ -67,8 +68,18 @@ class MyModule(Module):
         return x
 
 
-def test_module_api():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        buff = m.buff
+        param = m.param
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
+        assert "buff" not in m.__dict__
+        assert "param" not in m.__dict__
+        m.buff = buff
+        m.param = param
     assert list(m.children()) == [m.bn, m.i]
     assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]
     assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
@@ -141,8 +152,11 @@ def test_module_api():
     assert m.bn.training == False and m.i.bn.training == False
 
 
-def test_module_api_reuse_submodule():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_reuse_submodule(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
     m.h = m.i  # pylint: disable=attribute-defined-outside-init
     assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
     assert list(m.named_modules()) == [
@@ -153,15 +167,21 @@ def test_module_api_reuse_submodule():
     ]
 
 
-def test_module_api_iterable_stability():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_iterable_stability(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
     l = list(m.modules())
     for _ in range(100):
         assert list(m.modules()) == l
 
 
-def test_module_api_hooks():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_hooks(test_traced_module):
     net = MyModule()
+    if test_traced_module:
+        net = trace_module(net, Tensor(np.zeros((1, 4, 1, 1))))
     pre_hook_num = 0
     post_hook_num = 0
     hooks = []
@@ -383,11 +403,16 @@ class Simple(Module):
         self.conv1.weight = self.conv0.weight
 
     def forward(self, inputs):
-        pass
+        x = self.conv0(inputs)
+        y = self.conv1(inputs)
+        return x + y
 
 
-def test_shared_param():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_shared_param(test_traced_module):
     net = Simple()
+    if test_traced_module:
+        net = trace_module(net, tensor(np.random.random((1, 1, 8, 8))))
     assert net.conv0.weight is net.conv1.weight
     data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
     np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())
@@ -449,15 +474,21 @@ def test_shared_param_1d():
     np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())
 
 
-def test_pickle_module():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_pickle_module(test_traced_module):
     data_shape = (2, 28)
     data = tensor(np.random.random(data_shape))
     mlp = MLP()
+    pred_gt = mlp(data)
+    if test_traced_module:
+        mlp = trace_module(mlp, data)
     # pickle before forward
     with BytesIO() as fout:
         mge.save(mlp, fout)
         fout.seek(0)
         mlp1 = mge.load(fout)
+        if test_traced_module:
+            assert type(mlp1) == TracedModule
         pred0 = mlp1(data)
     pred1 = mlp(data)
@@ -467,8 +498,11 @@ def test_pickle_module():
         mge.save(mlp, fout)
         fout.seek(0)
         mlp1 = mge.load(fout)
+        if test_traced_module:
+            assert type(mlp1) == TracedModule
         pred2 = mlp1(data)
+    np.testing.assert_allclose(pred_gt.numpy(), pred1.numpy(), atol=5e-6)
     np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
     np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
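Sketched in isolation, the invariant the pickle hunks add (a sketch under the assumption that mge.save/mge.load pickle arbitrary modules, which is how the test above uses them):

    from io import BytesIO

    import megengine as mge
    from megengine.experimental.traced_module import TracedModule, trace_module

    def save_load_roundtrip(module):
        # A TracedModule must survive mge.save/mge.load with its type intact,
        # so downstream code can keep relying on isinstance checks.
        with BytesIO() as fout:
            mge.save(module, fout)
            fout.seek(0)
            restored = mge.load(fout)
        assert type(restored) == type(module)
        return restored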
@@ -0,0 +1,59 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import io
+
+import numpy as np
+
+import megengine.functional as F
+import megengine.module as M
+import megengine.utils.comp_graph_tools as cgtools
+from megengine.experimental.traced_module import trace_module
+from megengine.jit import trace
+from megengine.module import Module
+
+
+class MyBlock(Module):
+    def __init__(self, in_channels, channels):
+        super(MyBlock, self).__init__()
+        self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
+        self.bn1 = M.BatchNorm2d(channels)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x) + 1
+        return x
+
+
+class MyModule(Module):
+    def __init__(self):
+        super(MyModule, self).__init__()
+        self.block0 = MyBlock(8, 4)
+        self.block1 = MyBlock(4, 2)
+
+    def forward(self, x):
+        x = self.block0(x)
+        x = self.block1(x)
+        return x
+
+
+def test_jit_trace():
+    module = MyModule()
+    module.eval()
+    x = F.ones((1, 8, 14, 14))
+    expect = module(x)
+    traced_module = trace_module(module, x)
+    func = trace(traced_module, capture_as_const=True)
+    np.testing.assert_array_equal(func(x), expect)
+    model = io.BytesIO()
+    func.dump(model)
+    model.seek(0)
+    infer_cg = cgtools.GraphInference(model)
+    np.testing.assert_allclose(
+        list(infer_cg.run(x.numpy()).values())[0], expect, atol=1e-6
+    )