|
|
@@ -58,6 +58,7 @@ from ..quantization.observer import ( |
|
|
|
SyncMinMaxObserver, |
|
|
|
) |
|
|
|
from ..tensor import Tensor |
|
|
|
from ..utils.max_recursion_limit import max_recursion_limit |
|
|
|
from ..version import __version__ |
|
|
|
from .expr import ( |
|
|
|
Apply, |
|
|
@@ -1247,17 +1248,18 @@ class InternalGraph: |
|
|
|
return result |
|
|
|
|
|
|
|
def __deepcopy__(self, memo): |
|
|
|
if id(self) in memo: |
|
|
|
return memo[id(self)] |
|
|
|
cls = self.__class__ |
|
|
|
result = cls.__new__(cls) |
|
|
|
state = {} |
|
|
|
memo[id(self)] = result |
|
|
|
for k, v in self.__dict__.items(): |
|
|
|
if not isinstance(v, weakref.ReferenceType): |
|
|
|
state[k] = copy.deepcopy(v, memo) |
|
|
|
result.__dict__.update(state) |
|
|
|
return result |
|
|
|
with max_recursion_limit(): |
|
|
|
if id(self) in memo: |
|
|
|
return memo[id(self)] |
|
|
|
cls = self.__class__ |
|
|
|
result = cls.__new__(cls) |
|
|
|
state = {} |
|
|
|
memo[id(self)] = result |
|
|
|
for k, v in self.__dict__.items(): |
|
|
|
if not isinstance(v, weakref.ReferenceType): |
|
|
|
state[k] = copy.deepcopy(v, memo) |
|
|
|
result.__dict__.update(state) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def _get_meth_name(obj, func): |
|
|
@@ -2359,16 +2361,17 @@ class TracedModule(Module): |
|
|
|
return result |
|
|
|
|
|
|
|
def __deepcopy__(self, memo): |
|
|
|
cls = self.__class__ |
|
|
|
result = cls.__new__(cls) |
|
|
|
state = {} |
|
|
|
memo[id(self)] = result |
|
|
|
for k, v in self.__dict__.items(): |
|
|
|
if not isinstance(v, weakref.ReferenceType): |
|
|
|
state[k] = copy.deepcopy(v, memo) |
|
|
|
result.__dict__.update(state) |
|
|
|
result._update_ref() |
|
|
|
return result |
|
|
|
with max_recursion_limit(): |
|
|
|
cls = self.__class__ |
|
|
|
result = cls.__new__(cls) |
|
|
|
state = {} |
|
|
|
memo[id(self)] = result |
|
|
|
for k, v in self.__dict__.items(): |
|
|
|
if not isinstance(v, weakref.ReferenceType): |
|
|
|
state[k] = copy.deepcopy(v, memo) |
|
|
|
result.__dict__.update(state) |
|
|
|
result._update_ref() |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def cpp_apply_module_trace(opdef, *args): |
|
|
|