Browse Source

fix(traced_module): fix TracedModule InternalGraph deepcopy exceed max recursion limit

GitOrigin-RevId: 2b52ad913d
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
04b1424169
3 changed files with 26 additions and 25 deletions
  1. +0
    -2
      imperative/python/megengine/traced_module/module_tracer.py
  2. +24
    -21
      imperative/python/megengine/traced_module/traced_module.py
  3. +2
    -2
      imperative/python/test/unit/traced_module/test_trace_module.py

+ 0
- 2
imperative/python/megengine/traced_module/module_tracer.py View File

@@ -148,7 +148,6 @@ class module_tracer:
self.checker.push_scope() self.checker.push_scope()
self._activate_constant_cache.append([]) self._activate_constant_cache.append([])



def pop_scope(self): def pop_scope(self):
self._active_scopes.pop() self._active_scopes.pop()
self.checker.pop_scope() self.checker.pop_scope()
@@ -157,7 +156,6 @@ class module_tracer:
if hasattr(obj, "_NodeMixin__node"): if hasattr(obj, "_NodeMixin__node"):
delattr(obj, "_NodeMixin__node") delattr(obj, "_NodeMixin__node")



def current_scope(self): def current_scope(self):
if self._active_scopes: if self._active_scopes:
return self._active_scopes[-1] return self._active_scopes[-1]


+ 24
- 21
imperative/python/megengine/traced_module/traced_module.py View File

@@ -58,6 +58,7 @@ from ..quantization.observer import (
SyncMinMaxObserver, SyncMinMaxObserver,
) )
from ..tensor import Tensor from ..tensor import Tensor
from ..utils.max_recursion_limit import max_recursion_limit
from ..version import __version__ from ..version import __version__
from .expr import ( from .expr import (
Apply, Apply,
@@ -1247,17 +1248,18 @@ class InternalGraph:
return result return result


def __deepcopy__(self, memo): def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
return result
with max_recursion_limit():
if id(self) in memo:
return memo[id(self)]
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
return result




def _get_meth_name(obj, func): def _get_meth_name(obj, func):
@@ -2359,16 +2361,17 @@ class TracedModule(Module):
return result return result


def __deepcopy__(self, memo): def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
result._update_ref()
return result
with max_recursion_limit():
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
result._update_ref()
return result




def cpp_apply_module_trace(opdef, *args): def cpp_apply_module_trace(opdef, *args):


+ 2
- 2
imperative/python/test/unit/traced_module/test_trace_module.py View File

@@ -6,7 +6,7 @@ import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine import Tensor from megengine import Tensor
from megengine.module.module import Module from megengine.module.module import Module
from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
from megengine.traced_module.expr import CallFunction from megengine.traced_module.expr import CallFunction




@@ -58,7 +58,7 @@ class MyModule4(M.Module):




def test_trace_module(): def test_trace_module():
enable_expr_checker()
x = Tensor(1) x = Tensor(1)
m1 = MyModule1() m1 = MyModule1()
tm1 = trace_module(m1, x) tm1 = trace_module(m1, x)


Loading…
Cancel
Save