Browse Source

fix(traced_module): fix Module compatible issue and traced module getattr check

GitOrigin-RevId: 62eb3bfb10
release-1.8
Megvii Engine Team “wenjuan” 3 years ago
parent
commit
84d99d1cc4
5 changed files with 52 additions and 10 deletions
  1. +7
    -5
      imperative/python/megengine/module/module.py
  2. +1
    -1
      imperative/python/megengine/traced_module/serialization.py
  3. +6
    -4
      imperative/python/megengine/traced_module/traced_module.py
  4. +14
    -0
      imperative/python/test/unit/core/test_serialization.py
  5. +24
    -0
      imperative/python/test/unit/module/test_module.py

+ 7
- 5
imperative/python/megengine/module/module.py View File

@@ -138,11 +138,7 @@ class Module(metaclass=ABCMeta):
return HookHandler(self._forward_hooks, hook) return HookHandler(self._forward_hooks, hook)


def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
AutoNaming.push_scope(
self.name
if self.name is not None
else (self._short_name if hasattr(self, "_short_name") else self._name)
)
AutoNaming.push_scope(self.name if self.name is not None else self._short_name)
for hook in self._forward_pre_hooks.values(): for hook in self._forward_pre_hooks.values():
modified_inputs = hook(self, inputs) modified_inputs = hook(self, inputs)
if modified_inputs is not None: if modified_inputs is not None:
@@ -685,6 +681,12 @@ class Module(metaclass=ABCMeta):
set_name(self, prefix, k, v) set_name(self, prefix, k, v)
super().__setattr__(name, value) super().__setattr__(name, value)


def __setstate__(self, state):
if "_short_name" not in state:
state["_short_name"] = state["_name"]
state["_name"] = None
self.__dict__.update(state)

def __delattr__(self, name: str): def __delattr__(self, name: str):
if name in self.__dict__ and _is_module(self.__dict__[name]): if name in self.__dict__ and _is_module(self.__dict__[name]):
modules = self.__dict__.get("_modules") modules = self.__dict__.get("_modules")


+ 1
- 1
imperative/python/megengine/traced_module/serialization.py View File

@@ -50,7 +50,7 @@ class _ModuleState:
if self.obj is None: if self.obj is None:
typem = getattr(import_module(self.module[0]), self.module[1]) typem = getattr(import_module(self.module[0]), self.module[1])
m_obj = typem.__new__(typem) m_obj = typem.__new__(typem)
m_obj.__dict__.update(self.state)
m_obj.__setstate__(self.state)
self.obj = m_obj self.obj = m_obj
return self.obj return self.obj




+ 6
- 4
imperative/python/megengine/traced_module/traced_module.py View File

@@ -1681,11 +1681,13 @@ class TracedModuleBuilder(NodeMixin):


if isinstance(wrapped, TracedModuleBuilder): if isinstance(wrapped, TracedModuleBuilder):
if not isinstance(mod_attr, (List, Dict, QATModule)): if not isinstance(mod_attr, (List, Dict, QATModule)):
assert mod_attr is wrapped._mod
else:
assert (
mod_attr is wrapped._mod
), "TracedModule do not support modify module attributes, please check your code."
if isinstance(wrapped, RawTensor):
assert ( assert (
mod_attr is wrapped mod_attr is wrapped
), "TracedModule do not support modify attributes, please check your code."
), "TracedModule do not support modify tensor attributes, please check your code."


if isinstance(wrapped, (NodeMixin, RawTensor)): if isinstance(wrapped, (NodeMixin, RawTensor)):
NodeMixin.wrap( NodeMixin.wrap(
@@ -2296,7 +2298,7 @@ class TracedModule(Module):
for k, v in state.items(): for k, v in state.items():
if isinstance(v, _ModuleState): if isinstance(v, _ModuleState):
state[k] = v.to_module() state[k] = v.to_module()
self.__dict__.update(state)
super().__setstate__(state)
self._update_ref() self._update_ref()


for _, graph in self.argdef_graph_map.items(): for _, graph in self.argdef_graph_map.items():


+ 14
- 0
imperative/python/test/unit/core/test_serialization.py View File

@@ -87,3 +87,17 @@ def test_compatibility():


test_old_tensor("tensor_v1_1.mge") test_old_tensor("tensor_v1_1.mge")
test_old_tensor("tensor_v1_2.mge") test_old_tensor("tensor_v1_2.mge")

t = mge.tensor([1])
getattr(t, "qparams")
new_args = t.__getnewargs__()
assert (
len(new_args) == 3
and isinstance(new_args[0], np.ndarray)
and new_args[1] == np.int32
and isinstance(new_args[2], str)
), "Modify Tensor __getnewargs__ may break pickle serialization compatible"
state = t.__getstate__()
assert set(state.keys()) == set(
["qparams"]
), "Modify Tensor __getstate__ may break pickle serialization compatible"

+ 24
- 0
imperative/python/test/unit/module/test_module.py View File

@@ -681,3 +681,27 @@ def test_repr_module_reset_attr():
m1 = ResetAttrModule(False) m1 = ResetAttrModule(False)
output = [m0.__repr__(), m1.__repr__()] output = [m0.__repr__(), m1.__repr__()]
assert output == ground_truth assert output == ground_truth


def test_module_compatible():
class Empty(Module):
def forward(self):
pass

empty_module = Empty()
old_attributes = set(
[
"_modules",
"name",
"training",
"quantize_disabled",
"_forward_pre_hooks",
"_forward_hooks",
"_name",
"_short_name",
]
)
current_attributes = set(empty_module.__dict__.keys())
assert (
old_attributes == current_attributes
), "Add or delete attributes in Module class may break compatibility of pickle serialization"

Loading…
Cancel
Save