|
@@ -6,6 +6,7 @@ |
|
|
# software distributed under the License is distributed on an |
|
|
# software distributed under the License is distributed on an |
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
import pickle |
|
|
import pickle |
|
|
|
|
|
from collections import defaultdict |
|
|
from itertools import chain |
|
|
from itertools import chain |
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import numpy as np |
|
@@ -52,6 +53,25 @@ class MyModule(M.Module): |
|
|
return x |
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MyBlock1(M.Module): |
|
|
|
|
|
def forward(self, a): |
|
|
|
|
|
y = F.concat([a, a]) |
|
|
|
|
|
return a, y |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MyModule1(M.Module): |
|
|
|
|
|
def __init__(self): |
|
|
|
|
|
super().__init__() |
|
|
|
|
|
self.block0 = MyBlock1() |
|
|
|
|
|
self.block1 = MyBlock1() |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, a): |
|
|
|
|
|
a, y1 = self.block0(a) |
|
|
|
|
|
a = a + 1 |
|
|
|
|
|
a, y2 = self.block1(a) |
|
|
|
|
|
return a, y1 + y2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NewModule(M.Module): |
|
|
class NewModule(M.Module): |
|
|
def __init__(self, traced_module): |
|
|
def __init__(self, traced_module): |
|
|
super(NewModule, self).__init__() |
|
|
super(NewModule, self).__init__() |
|
@@ -64,6 +84,17 @@ class NewModule(M.Module): |
|
|
return x |
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_expr_users(traced_module): |
|
|
|
|
|
node_user = defaultdict(list) |
|
|
|
|
|
for expr in traced_module.graph._exprs: |
|
|
|
|
|
for node in expr.inputs: |
|
|
|
|
|
node_user[node].append(expr) |
|
|
|
|
|
for node in traced_module.graph.nodes(): |
|
|
|
|
|
node.users.sort(key=lambda m: m._id) |
|
|
|
|
|
node_user[node].sort(key=lambda m: m._id) |
|
|
|
|
|
assert node.users == node_user[node] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_cls(cls): |
|
|
def _init_cls(cls): |
|
|
module = cls() |
|
|
module = cls() |
|
|
x = F.ones((1, 3, 3, 3)) |
|
|
x = F.ones((1, 3, 3, 3)) |
|
@@ -201,6 +232,10 @@ def test_flatten(): |
|
|
assert len(traced_module.graph._exprs) == 12 |
|
|
assert len(traced_module.graph._exprs) == 12 |
|
|
np.testing.assert_equal(expect.numpy(), traced_module(x).numpy()) |
|
|
np.testing.assert_equal(expect.numpy(), traced_module(x).numpy()) |
|
|
|
|
|
|
|
|
|
|
|
traced_module, x, expect = _init_cls(MyModule1) |
|
|
|
|
|
traced_module = traced_module.flatten() |
|
|
|
|
|
_check_expr_users(traced_module) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_id_and_name(): |
|
|
def test_id_and_name(): |
|
|
def _check_id(traced_module): |
|
|
def _check_id(traced_module): |
|
|