GitOrigin-RevId: 62eb3bfb10
release-1.8
@@ -138,11 +138,7 @@ class Module(metaclass=ABCMeta): | |||||
return HookHandler(self._forward_hooks, hook) | return HookHandler(self._forward_hooks, hook) | ||||
def __call__(self, *inputs, **kwargs): | def __call__(self, *inputs, **kwargs): | ||||
AutoNaming.push_scope( | |||||
self.name | |||||
if self.name is not None | |||||
else (self._short_name if hasattr(self, "_short_name") else self._name) | |||||
) | |||||
AutoNaming.push_scope(self.name if self.name is not None else self._short_name) | |||||
for hook in self._forward_pre_hooks.values(): | for hook in self._forward_pre_hooks.values(): | ||||
modified_inputs = hook(self, inputs) | modified_inputs = hook(self, inputs) | ||||
if modified_inputs is not None: | if modified_inputs is not None: | ||||
@@ -685,6 +681,12 @@ class Module(metaclass=ABCMeta): | |||||
set_name(self, prefix, k, v) | set_name(self, prefix, k, v) | ||||
super().__setattr__(name, value) | super().__setattr__(name, value) | ||||
def __setstate__(self, state): | |||||
if "_short_name" not in state: | |||||
state["_short_name"] = state["_name"] | |||||
state["_name"] = None | |||||
self.__dict__.update(state) | |||||
def __delattr__(self, name: str): | def __delattr__(self, name: str): | ||||
if name in self.__dict__ and _is_module(self.__dict__[name]): | if name in self.__dict__ and _is_module(self.__dict__[name]): | ||||
modules = self.__dict__.get("_modules") | modules = self.__dict__.get("_modules") | ||||
@@ -50,7 +50,7 @@ class _ModuleState: | |||||
if self.obj is None: | if self.obj is None: | ||||
typem = getattr(import_module(self.module[0]), self.module[1]) | typem = getattr(import_module(self.module[0]), self.module[1]) | ||||
m_obj = typem.__new__(typem) | m_obj = typem.__new__(typem) | ||||
m_obj.__dict__.update(self.state) | |||||
m_obj.__setstate__(self.state) | |||||
self.obj = m_obj | self.obj = m_obj | ||||
return self.obj | return self.obj | ||||
@@ -1681,11 +1681,13 @@ class TracedModuleBuilder(NodeMixin): | |||||
if isinstance(wrapped, TracedModuleBuilder): | if isinstance(wrapped, TracedModuleBuilder): | ||||
if not isinstance(mod_attr, (List, Dict, QATModule)): | if not isinstance(mod_attr, (List, Dict, QATModule)): | ||||
assert mod_attr is wrapped._mod | |||||
else: | |||||
assert ( | |||||
mod_attr is wrapped._mod | |||||
), "TracedModule do not support modify module attributes, please check your code." | |||||
if isinstance(wrapped, RawTensor): | |||||
assert ( | assert ( | ||||
mod_attr is wrapped | mod_attr is wrapped | ||||
), "TracedModule do not support modify attributes, please check your code." | |||||
), "TracedModule do not support modify tensor attributes, please check your code." | |||||
if isinstance(wrapped, (NodeMixin, RawTensor)): | if isinstance(wrapped, (NodeMixin, RawTensor)): | ||||
NodeMixin.wrap( | NodeMixin.wrap( | ||||
@@ -2296,7 +2298,7 @@ class TracedModule(Module): | |||||
for k, v in state.items(): | for k, v in state.items(): | ||||
if isinstance(v, _ModuleState): | if isinstance(v, _ModuleState): | ||||
state[k] = v.to_module() | state[k] = v.to_module() | ||||
self.__dict__.update(state) | |||||
super().__setstate__(state) | |||||
self._update_ref() | self._update_ref() | ||||
for _, graph in self.argdef_graph_map.items(): | for _, graph in self.argdef_graph_map.items(): | ||||
@@ -87,3 +87,17 @@ def test_compatibility(): | |||||
test_old_tensor("tensor_v1_1.mge") | test_old_tensor("tensor_v1_1.mge") | ||||
test_old_tensor("tensor_v1_2.mge") | test_old_tensor("tensor_v1_2.mge") | ||||
t = mge.tensor([1]) | |||||
getattr(t, "qparams") | |||||
new_args = t.__getnewargs__() | |||||
assert ( | |||||
len(new_args) == 3 | |||||
and isinstance(new_args[0], np.ndarray) | |||||
and new_args[1] == np.int32 | |||||
and isinstance(new_args[2], str) | |||||
), "Modify Tensor __getnewargs__ may break pickle serialization compatible" | |||||
state = t.__getstate__() | |||||
assert set(state.keys()) == set( | |||||
["qparams"] | |||||
), "Modify Tensor __getstate__ may break pickle serialization compatible" |
@@ -681,3 +681,27 @@ def test_repr_module_reset_attr(): | |||||
m1 = ResetAttrModule(False) | m1 = ResetAttrModule(False) | ||||
output = [m0.__repr__(), m1.__repr__()] | output = [m0.__repr__(), m1.__repr__()] | ||||
assert output == ground_truth | assert output == ground_truth | ||||
def test_module_compatible(): | |||||
class Empty(Module): | |||||
def forward(self): | |||||
pass | |||||
empty_module = Empty() | |||||
old_attributes = set( | |||||
[ | |||||
"_modules", | |||||
"name", | |||||
"training", | |||||
"quantize_disabled", | |||||
"_forward_pre_hooks", | |||||
"_forward_hooks", | |||||
"_name", | |||||
"_short_name", | |||||
] | |||||
) | |||||
current_attributes = set(empty_module.__dict__.keys()) | |||||
assert ( | |||||
old_attributes == current_attributes | |||||
), "Add or delete attributes in Module class may break compatibility of pickle serialization" |