Browse Source

fix(mge/traced_module): fix module dict

GitOrigin-RevId: d7baf00e3c
release-1.7
Megvii Engine Team 3 years ago
parent
commit
dbca32702a
2 changed files with 20 additions and 10 deletions
  1. +9
    -7
      imperative/python/megengine/traced_module/utils.py
  2. +11
    -3
      imperative/python/test/unit/traced_module/test_trace_module.py

+ 9
- 7
imperative/python/megengine/traced_module/utils.py View File

@@ -147,13 +147,14 @@ class _ModuleDict(Module, MutableMapping):


def __init__(self, modules: Optional[Dict[str, Module]] = None): def __init__(self, modules: Optional[Dict[str, Module]] = None):
super().__init__() super().__init__()
self._size = 0
self._module_keys = []
if modules is not None: if modules is not None:
self.update(modules) self.update(modules)


def __delitem__(self, key): def __delitem__(self, key):
delattr(self, key) delattr(self, key)
self._size -= 1
assert key in self._module_keys
self._module_keys.remove(key)


def __getitem__(self, key): def __getitem__(self, key):
return getattr(self, key) return getattr(self, key)
@@ -162,22 +163,23 @@ class _ModuleDict(Module, MutableMapping):
if not isinstance(value, Module): if not isinstance(value, Module):
raise ValueError("invalid sub-module") raise ValueError("invalid sub-module")
setattr(self, key, value) setattr(self, key, value)
self._size += 1
if key not in self._module_keys:
self._module_keys.append(key)


def __iter__(self): def __iter__(self):
return iter(self.keys()) return iter(self.keys())


def __len__(self): def __len__(self):
return self._size
return len(self._module_keys)


def items(self): def items(self):
return dict(self.named_children()).items()
return [(key, getattr(self, key)) for key in self._module_keys]


def values(self): def values(self):
return dict(self.named_children()).values()
return [getattr(self, key) for key in self._module_keys]


def keys(self): def keys(self):
return dict(self.named_children()).keys()
return self._module_keys


def forward(self): def forward(self):
raise RuntimeError("ModuleList is not callable") raise RuntimeError("ModuleList is not callable")

+ 11
- 3
imperative/python/test/unit/traced_module/test_trace_module.py View File

@@ -1,3 +1,5 @@
from collections import OrderedDict

import numpy as np import numpy as np


import megengine.functional as F import megengine.functional as F
@@ -29,14 +31,19 @@ class MyModule3(M.Module):
self.modules = [ self.modules = [
M.Elemwise("ADD"), M.Elemwise("ADD"),
M.Elemwise("ADD"), M.Elemwise("ADD"),
{"a": M.Elemwise("ADD"), "b": M.Elemwise("ADD")},
OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]),
M.Elemwise("RELU"),
M.Elemwise("RELU"),
] ]


def forward(self, a, b): def forward(self, a, b):
x = self.modules[0](a, b) x = self.modules[0](a, b)
y = self.modules[1](a, b) y = self.modules[1](a, b)
y = self.modules[2]["a"](x, y)
y = self.modules[2]["b"](x, y)
assert list(self.modules[2].keys()) == ["a", "b"]
for _, m in self.modules[2].items():
y = m(x, y)
for m in self.modules[3:]:
y = m(y)
return y return y




@@ -78,6 +85,7 @@ def test_trace_module():
assert isinstance(tm3.modules.__dict__["0"], M.Elemwise) assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
assert isinstance(tm3.modules.__dict__["2"], TracedModule) assert isinstance(tm3.modules.__dict__["2"], TracedModule)
assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise) assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)
assert isinstance(tm3.modules.__dict__["3"], M.Elemwise)


m4 = MyModule4() m4 = MyModule4()
tm4 = trace_module(m4, a, b) tm4 = trace_module(m4, a, b)


Loading…
Cancel
Save