|
|
@@ -5,12 +5,21 @@ |
|
|
|
# Unless required by applicable law or agreed to in writing, |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import pickle |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
import megengine.functional as F |
|
|
|
import megengine.module as M |
|
|
|
from megengine.module.identity import Identity |
|
|
|
from megengine.traced_module import trace_module |
|
|
|
from megengine.traced_module.expr import CallFunction, GetAttr |
|
|
|
from megengine.traced_module.expr import CallFunction, Expr, GetAttr |
|
|
|
from megengine.traced_module.node import Node |
|
|
|
|
|
|
|
|
|
|
|
class IdentityMod(M.Module): |
|
|
|
def forward(self, x): |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class MyBlock(M.Module): |
|
|
@@ -18,11 +27,13 @@ class MyBlock(M.Module): |
|
|
|
super(MyBlock, self).__init__() |
|
|
|
self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False) |
|
|
|
self.bn1 = M.BatchNorm2d(channels) |
|
|
|
self.nothing = IdentityMod() |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
x = self.conv1(x) |
|
|
|
x = self.bn1(x) |
|
|
|
x = F.relu(x) + 1 |
|
|
|
x = self.nothing(x) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
@@ -31,10 +42,24 @@ class MyModule(M.Module): |
|
|
|
super(MyModule, self).__init__() |
|
|
|
self.block0 = MyBlock() |
|
|
|
self.block1 = MyBlock() |
|
|
|
self.nothing = IdentityMod() |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
x = self.block0(x) |
|
|
|
x = self.block1(x) |
|
|
|
x = self.nothing(x) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class NewModule(M.Module): |
|
|
|
def __init__(self, traced_module): |
|
|
|
super(NewModule, self).__init__() |
|
|
|
self.module = traced_module |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
x = x - 1 |
|
|
|
x = self.module(x) |
|
|
|
x = x + 1 |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
@@ -82,6 +107,12 @@ def test_delete(): |
|
|
|
graph.compile() |
|
|
|
np.testing.assert_allclose(expect - 1, F.relu(traced_module(x) - 1), atol=1e-6) |
|
|
|
|
|
|
|
# clear graph |
|
|
|
graph.replace_node({graph.outputs[0]: graph.inputs[1]}) |
|
|
|
graph.compile() |
|
|
|
np.testing.assert_equal(len(list(graph._exprs)), 0) |
|
|
|
np.testing.assert_equal(traced_module(x).numpy(), x.numpy()) |
|
|
|
|
|
|
|
|
|
|
|
def test_flatten(): |
|
|
|
traced_module, x, expect = _init_module() |
|
|
@@ -89,6 +120,74 @@ def test_flatten(): |
|
|
|
traced_module.graph.compile() |
|
|
|
assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs) |
|
|
|
assert len(traced_module.graph._exprs) == 12 |
|
|
|
np.testing.assert_equal(expect.numpy(), traced_module(x).numpy()) |
|
|
|
|
|
|
|
|
|
|
|
def test_id_and_name(): |
|
|
|
def _check_id(traced_module): |
|
|
|
_total_ids = traced_module.graph._total_ids |
|
|
|
node_ids = [n._id for n in traced_module.graph.nodes().as_list()] |
|
|
|
assert len(set(node_ids)) == len(node_ids) |
|
|
|
assert max(node_ids) + 1 == len(node_ids) |
|
|
|
|
|
|
|
expr_ids = [n._id for n in traced_module.graph.exprs().as_list()] |
|
|
|
assert len(set(expr_ids)) == len(expr_ids) |
|
|
|
assert max(expr_ids) + 1 == _total_ids[1] |
|
|
|
|
|
|
|
def _check_name(flatened_module): |
|
|
|
node_names = [n._name for n in flatened_module.graph.nodes().as_list()] |
|
|
|
assert len(set(node_names)) == len(node_names) |
|
|
|
|
|
|
|
traced_module, x, expect = _init_module() |
|
|
|
_check_id(traced_module) |
|
|
|
|
|
|
|
flattened_module = traced_module.flatten() |
|
|
|
_check_id(flattened_module) |
|
|
|
_check_name(flattened_module) |
|
|
|
|
|
|
|
# pickle check |
|
|
|
obj = pickle.dumps(traced_module) |
|
|
|
traced_module = pickle.loads(obj) |
|
|
|
Node._set_next_id(159) |
|
|
|
Expr._set_next_id(1024) |
|
|
|
|
|
|
|
graph = traced_module.graph |
|
|
|
for expr in graph.get_function_by_type(F.relu).as_list(): |
|
|
|
relu_out = expr.outputs[0] |
|
|
|
cur_graph = expr.top_graph |
|
|
|
with cur_graph.insert_exprs(): |
|
|
|
neg_out = F.neg(relu_out) |
|
|
|
cur_graph.replace_node({relu_out: neg_out}) |
|
|
|
cur_graph.compile() |
|
|
|
_check_id(traced_module) |
|
|
|
|
|
|
|
flattened_module = traced_module.flatten() |
|
|
|
_check_id(flattened_module) |
|
|
|
_check_name(flattened_module) |
|
|
|
|
|
|
|
# check trace TracedModule |
|
|
|
obj = pickle.dumps(traced_module) |
|
|
|
traced_module = pickle.loads(obj) |
|
|
|
module = NewModule(traced_module) |
|
|
|
traced_module = trace_module(module, x) |
|
|
|
_check_id(traced_module) |
|
|
|
|
|
|
|
flattened_module = traced_module.flatten() |
|
|
|
_check_id(flattened_module) |
|
|
|
_check_name(flattened_module) |
|
|
|
|
|
|
|
|
|
|
|
def test_set_name(): |
|
|
|
traced_module, x, expect = _init_module() |
|
|
|
graph = traced_module.graph |
|
|
|
output_node = graph.outputs[0] |
|
|
|
|
|
|
|
def rename(name): |
|
|
|
output_node.name = name |
|
|
|
|
|
|
|
np.testing.assert_raises(AssertionError, rename, "block1_out") |
|
|
|
rename("output") |
|
|
|
np.testing.assert_equal(str(graph.outputs[0]), "output") |
|
|
|
|
|
|
|
|
|
|
|
def test_extra_block(): |
|
|
|