Browse Source

fix(traced_module): fix TracedModule InternalGraph deepcopy exceed max recursion limit

GitOrigin-RevId: 2b52ad913d
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
04b1424169
3 changed files with 26 additions and 25 deletions
  1. +0
    -2
      imperative/python/megengine/traced_module/module_tracer.py
  2. +24
    -21
      imperative/python/megengine/traced_module/traced_module.py
  3. +2
    -2
      imperative/python/test/unit/traced_module/test_trace_module.py

+ 0
- 2
imperative/python/megengine/traced_module/module_tracer.py View File

@@ -148,7 +148,6 @@ class module_tracer:
self.checker.push_scope()
self._activate_constant_cache.append([])


def pop_scope(self):
self._active_scopes.pop()
self.checker.pop_scope()
@@ -157,7 +156,6 @@ class module_tracer:
if hasattr(obj, "_NodeMixin__node"):
delattr(obj, "_NodeMixin__node")


def current_scope(self):
if self._active_scopes:
return self._active_scopes[-1]


+ 24
- 21
imperative/python/megengine/traced_module/traced_module.py View File

@@ -58,6 +58,7 @@ from ..quantization.observer import (
SyncMinMaxObserver,
)
from ..tensor import Tensor
from ..utils.max_recursion_limit import max_recursion_limit
from ..version import __version__
from .expr import (
Apply,
@@ -1247,17 +1248,18 @@ class InternalGraph:
return result

def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
return result
with max_recursion_limit():
if id(self) in memo:
return memo[id(self)]
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
return result


def _get_meth_name(obj, func):
@@ -2359,16 +2361,17 @@ class TracedModule(Module):
return result

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
result._update_ref()
return result
with max_recursion_limit():
cls = self.__class__
result = cls.__new__(cls)
state = {}
memo[id(self)] = result
for k, v in self.__dict__.items():
if not isinstance(v, weakref.ReferenceType):
state[k] = copy.deepcopy(v, memo)
result.__dict__.update(state)
result._update_ref()
return result


def cpp_apply_module_trace(opdef, *args):


+ 2
- 2
imperative/python/test/unit/traced_module/test_trace_module.py View File

@@ -6,7 +6,7 @@ import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.module.module import Module
from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
from megengine.traced_module.expr import CallFunction


@@ -58,7 +58,7 @@ class MyModule4(M.Module):


def test_trace_module():
enable_expr_checker()
x = Tensor(1)
m1 = MyModule1()
tm1 = trace_module(m1, x)


Loading…
Cancel
Save