Browse Source

fix(mge/core): fix Tensor deepcopy issue

GitOrigin-RevId: 6bea7970b8
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
9440842e27
1 changed files with 17 additions and 0 deletions
  1. +17
    -0
      python_module/megengine/core/tensor.py

+ 17
- 0
python_module/megengine/core/tensor.py View File

@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
import functools
import itertools
import weakref
@@ -674,6 +675,22 @@ class Tensor:
snd = mgb.make_shared(device, value=data, dtype=dtype)
self._reset(snd, requires_grad=requires_grad)

def __deepcopy__(self, memo):
"""
Since Tensor have __getstate__ and __setstate__ method,
deepcopy only process the that and ignore the attribute of Parameter.
So we need to add __deepcopy__ method to deepcopy correct attribute.
"""
assert (self.__val is not None) and (
self.__sym is None
), "Only SharedND initialized Tensor can be serialized or deep copied"
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
return result


def tensor(
data: Union[list, np.ndarray] = None,


Loading…
Cancel
Save