|
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import numpy as np
-
- import megengine.functional as F
- import megengine.module as M
- from megengine.experimental.traced_module import trace_module
- from megengine.experimental.traced_module.expr import CallFunction, GetAttr
-
-
class MyBlock(M.Module):
    """Conv2d -> BatchNorm2d -> ReLU block whose output is shifted up by one.

    Args:
        in_channels: number of channels of the input fed to ``conv1``.
        channels: number of output channels of ``conv1``, and the channel
            count normalized by ``bn1``.
    """

    def __init__(self, in_channels=3, channels=3):
        super().__init__()
        # Positional args are kernel_size=3, stride=1; padding=1; no bias.
        self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
        self.bn1 = M.BatchNorm2d(channels)

    def forward(self, x):
        """Apply conv1, then bn1, then relu, and add 1 to the result."""
        normalized = self.bn1(self.conv1(x))
        # The trailing "+ 1" is the offset that test_insert / test_delete
        # subtract back out when checking their graph rewrites.
        return F.relu(normalized) + 1
-
-
class MyModule(M.Module):
    """Two default-configured ``MyBlock`` instances applied back to back."""

    def __init__(self):
        super().__init__()
        self.block0 = MyBlock()
        self.block1 = MyBlock()

    def forward(self, x):
        """Run ``block0`` on the input, then ``block1`` on its result."""
        return self.block1(self.block0(x))
-
-
def _init_cls(cls):
    """Instantiate ``cls``, run it once eagerly, then trace it.

    Args:
        cls: a zero-argument-constructible module class (e.g. ``MyBlock``).

    Returns:
        A ``(traced_module, x, y)`` tuple. ``traced_module`` is the result of
        ``trace_module`` on the instance. ``x`` is the all-ones input of shape
        (1, 3, 3, 3) used both for the eager call and for tracing. ``y`` is
        the eager output, computed before tracing.
    """
    net = cls()
    inp = F.ones((1, 3, 3, 3))
    eager_out = net(inp)
    return trace_module(net, inp), inp, eager_out
-
-
def _init_block():
    """Return ``(traced_module, x, y)`` for a freshly built ``MyBlock``."""
    return _init_cls(cls=MyBlock)
-
-
def _init_module():
    """Return ``(traced_module, x, y)`` for a freshly built ``MyModule``."""
    return _init_cls(cls=MyModule)
-
-
def test_search():
    """Looking up ``F.relu`` in the traced MyBlock graph yields its call expr."""
    traced, *_ = _init_block()
    matches = traced.graph.get_function_by_type(F.relu)
    expr = matches.as_unique()
    assert isinstance(expr, CallFunction)
    assert expr.func == F.relu
-
-
def test_insert():
    """Insert a negation after the relu and check the rewired graph's output.

    After the rewrite the block computes ``1 - relu(bn(conv(x)))``, so
    ``1 - out`` must equal the original eager output minus one.
    """
    traced, inp, expected = _init_block()
    graph = traced.graph
    relu_outputs = graph.get_function_by_type(F.relu).as_unique().outputs
    negated = graph.insert_function(lambda x: F.neg(x), *relu_outputs)
    # Swap the relu output for the newly inserted negated node.
    graph.replace_node({relu_outputs[0]: negated})
    graph.compile()
    np.testing.assert_allclose(expected - 1, 1 - traced(inp), atol=1e-6)
-
-
def test_delete():
    """Bypass the relu by replacing its output node with its input node.

    Without the relu the block computes ``bn(conv(x)) + 1``. Re-applying relu
    to ``out - 1`` must then reproduce the original eager output minus one.
    """
    traced, inp, expected = _init_block()
    graph = traced.graph
    relu_expr = graph.get_function_by_type(F.relu).as_unique()
    graph.replace_node({relu_expr.outputs[0]: relu_expr.inputs[0]})
    graph.compile()
    np.testing.assert_allclose(expected - 1, F.relu(traced(inp) - 1), atol=1e-6)
-
-
def test_flatten():
    """Flatten the traced ``MyModule`` (two nested ``MyBlock``s) and validate it.

    Checks that the flattened graph contains no ``GetAttr`` exprs, that it
    has the expected expr count, and that the flattened module still
    reproduces the eager output of the original ``MyModule``.
    """
    traced_module, x, expect = _init_module()
    traced_module = traced_module.flatten()
    traced_module.graph.compile()
    assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs)
    # Expected expr count for the flattened two-block graph.
    assert len(traced_module.graph._exprs) == 12
    # Flattening must preserve semantics, not just structure. Previously
    # x/expect were unpacked but never used, so a numerically broken
    # flatten would still have passed.
    np.testing.assert_allclose(expect, traced_module(x), atol=1e-6)
-
-
def test_extra_block():
    """A traced module nested in a plain Module runs eagerly and re-traces.

    ``Net`` wraps the traced ``MyBlock`` and doubles its output. Both the
    eager ``Net`` and a fresh trace of it must equal twice the original
    eager output.
    """

    class PostProcess(M.Module):
        def forward(self, x):
            return x * 2

    class Net(M.Module):
        def __init__(self, traced_module):
            super().__init__()
            self.post_process = PostProcess()
            self.traced_module = traced_module

        def forward(self, x):
            return self.post_process(self.traced_module(x))

    inner, inp, expected = _init_block()
    wrapper = Net(inner)
    np.testing.assert_allclose(2 * expected, wrapper(inp), atol=1e-6)
    retraced = trace_module(wrapper, inp)
    np.testing.assert_allclose(2 * expected, retraced(inp), atol=1e-6)
|