Browse Source

fix(mge/traced_module): fix TracedModule flatten

GitOrigin-RevId: 7b15fe492b
release-1.7
Megvii Engine Team 3 years ago
parent
commit
01d2473ce7
2 changed files with 36 additions and 3 deletions
  1. +1
    -3
      imperative/python/megengine/traced_module/traced_module.py
  2. +35
    -0
      imperative/python/test/unit/traced_module/test_modification.py

+ 1
- 3
imperative/python/megengine/traced_module/traced_module.py View File

@@ -2078,9 +2078,7 @@ class TracedModule(Module):


for node, repl_node in repl_dict.items(): for node, repl_node in repl_dict.items():
assert node in graph._inputs or node in graph._outputs assert node in graph._inputs or node in graph._outputs
for i in node.users:
if i not in repl_node.users:
repl_node.users.append(i)
repl_node.users.extend(node.users)


rename_blacklist = list(chain(call.inputs, call.outputs)) rename_blacklist = list(chain(call.inputs, call.outputs))




+ 35
- 0
imperative/python/test/unit/traced_module/test_modification.py View File

@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle import pickle
from collections import defaultdict
from itertools import chain from itertools import chain


import numpy as np import numpy as np
@@ -52,6 +53,25 @@ class MyModule(M.Module):
return x return x




class MyBlock1(M.Module):
def forward(self, a):
y = F.concat([a, a])
return a, y


class MyModule1(M.Module):
def __init__(self):
super().__init__()
self.block0 = MyBlock1()
self.block1 = MyBlock1()

def forward(self, a):
a, y1 = self.block0(a)
a = a + 1
a, y2 = self.block1(a)
return a, y1 + y2


class NewModule(M.Module): class NewModule(M.Module):
def __init__(self, traced_module): def __init__(self, traced_module):
super(NewModule, self).__init__() super(NewModule, self).__init__()
@@ -64,6 +84,17 @@ class NewModule(M.Module):
return x return x




def _check_expr_users(traced_module):
node_user = defaultdict(list)
for expr in traced_module.graph._exprs:
for node in expr.inputs:
node_user[node].append(expr)
for node in traced_module.graph.nodes():
node.users.sort(key=lambda m: m._id)
node_user[node].sort(key=lambda m: m._id)
assert node.users == node_user[node]


def _init_cls(cls): def _init_cls(cls):
module = cls() module = cls()
x = F.ones((1, 3, 3, 3)) x = F.ones((1, 3, 3, 3))
@@ -201,6 +232,10 @@ def test_flatten():
assert len(traced_module.graph._exprs) == 12 assert len(traced_module.graph._exprs) == 12
np.testing.assert_equal(expect.numpy(), traced_module(x).numpy()) np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())


traced_module, x, expect = _init_cls(MyModule1)
traced_module = traced_module.flatten()
_check_expr_users(traced_module)



def test_id_and_name(): def test_id_and_name():
def _check_id(traced_module): def _check_id(traced_module):


Loading…
Cancel
Save