Browse Source

fix(mge/traced_module): fix TracedModule flatten

GitOrigin-RevId: 7b15fe492b
release-1.7
Megvii Engine Team 3 years ago
parent
commit
01d2473ce7
2 changed files with 36 additions and 3 deletions
  1. +1
    -3
      imperative/python/megengine/traced_module/traced_module.py
  2. +35
    -0
      imperative/python/test/unit/traced_module/test_modification.py

+ 1
- 3
imperative/python/megengine/traced_module/traced_module.py View File

@@ -2078,9 +2078,7 @@ class TracedModule(Module):

for node, repl_node in repl_dict.items():
assert node in graph._inputs or node in graph._outputs
for i in node.users:
if i not in repl_node.users:
repl_node.users.append(i)
repl_node.users.extend(node.users)

rename_blacklist = list(chain(call.inputs, call.outputs))



+ 35
- 0
imperative/python/test/unit/traced_module/test_modification.py View File

@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle
from collections import defaultdict
from itertools import chain

import numpy as np
@@ -52,6 +53,25 @@ class MyModule(M.Module):
return x


class MyBlock1(M.Module):
def forward(self, a):
y = F.concat([a, a])
return a, y


class MyModule1(M.Module):
def __init__(self):
super().__init__()
self.block0 = MyBlock1()
self.block1 = MyBlock1()

def forward(self, a):
a, y1 = self.block0(a)
a = a + 1
a, y2 = self.block1(a)
return a, y1 + y2


class NewModule(M.Module):
def __init__(self, traced_module):
super(NewModule, self).__init__()
@@ -64,6 +84,17 @@ class NewModule(M.Module):
return x


def _check_expr_users(traced_module):
node_user = defaultdict(list)
for expr in traced_module.graph._exprs:
for node in expr.inputs:
node_user[node].append(expr)
for node in traced_module.graph.nodes():
node.users.sort(key=lambda m: m._id)
node_user[node].sort(key=lambda m: m._id)
assert node.users == node_user[node]


def _init_cls(cls):
module = cls()
x = F.ones((1, 3, 3, 3))
@@ -201,6 +232,10 @@ def test_flatten():
assert len(traced_module.graph._exprs) == 12
np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())

traced_module, x, expect = _init_cls(MyModule1)
traced_module = traced_module.flatten()
_check_expr_users(traced_module)


def test_id_and_name():
def _check_id(traced_module):


Loading…
Cancel
Save