Browse Source

fix(mge/function): do not deeply copy saved tensor in Function

GitOrigin-RevId: 3c89d1ceaa
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
35bc0e1f60
3 changed files with 65 additions and 10 deletions
  1. +15
    -0
      python_module/megengine/core/function.py
  2. +6
    -1
      python_module/megengine/core/tensor.py
  3. +44
    -9
      python_module/test/unit/core/test_function.py

+ 15
- 0
python_module/megengine/core/function.py View File

@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
from abc import ABCMeta, abstractmethod
from typing import Iterable, Tuple, Union

@@ -142,6 +143,20 @@ class Function(metaclass=ABCMeta):
"""
self.saved_tensors = tensors

def __deepcopy__(self, memo):
"""
Defines how the operator is deeply copied
"""
cls = self.__class__
result = cls.__new__(cls)
tmp = self.saved_tensors
self.saved_tensors = None
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
self.saved_tensors = tmp
return result

def __call__(self, *inputs):
assert (
not self._has_saved_state


+ 6
- 1
python_module/megengine/core/tensor.py View File

@@ -495,7 +495,12 @@ class Tensor:
)

def __getstate__(self):
assert (self.__val is not None) and (self.__sym is None)
r""" __getstate__ will be called for pickle serialization or deep copy
"""

assert (self.__val is not None) and (
self.__sym is None
), "Only SharedND initialized Tensor can be serialized or deep copied"
metadata = {"requires_grad": self.requires_grad}
state = {
"data": self.numpy(),


+ 44
- 9
python_module/test/unit/core/test_function.py View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -6,10 +5,13 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy

import numpy as np

import megengine.functional as F
from megengine.core import Function, tensor
from megengine.jit import trace
from megengine.test import assertTensorClose


@@ -76,6 +78,27 @@ def test_ste():
)


def test_deepcopy():
class Sigmoid(Function):
def __init__(self, param):
super().__init__()
self.param = param

def forward(self, x):
y = 1 / (1 + F.exp(-x))
self.save_for_backward(y)
return y

def backward(self, grad_y):
(y,) = self.saved_tensors
return grad_y * y * (1 - y)

origin = Sigmoid(0)
new = copy.deepcopy(Sigmoid(0))
assert new.param == origin.param
assert new.saved_tensors == None


def test_save_context():
class Sigmoid(Function):
def forward(self, x):
@@ -87,14 +110,26 @@ def test_save_context():
(y,) = self.saved_tensors
return grad_y * y * (1 - y)

a = tensor(np.array([1926.0817], dtype=np.float32))
s = Sigmoid()(a)
s2 = F.sigmoid(a)
assertTensorClose(s.numpy(), s2.numpy())
assertTensorClose(
F.grad(s, a, use_virtual_grad=False).numpy(),
F.grad(s2, a, use_virtual_grad=False).numpy(),
)
def run_saved_context(a, net=None):
return net(a)

def run(use_trace, symbolic):
a = tensor(np.array([1926.0817], dtype=np.float32))
net = Sigmoid()
func_run = run_saved_context
if use_trace:
func_run = trace(run_saved_context, symbolic=symbolic)
s = func_run(a, net=net)
s2 = F.sigmoid(a)
assertTensorClose(s.numpy(), s2.numpy())
assertTensorClose(
F.grad(s, a, use_virtual_grad=False).numpy(),
F.grad(s2, a, use_virtual_grad=False).numpy(),
)

run(False, False)
run(True, False)
run(True, True)


def test_none_in_out_grad():


Loading…
Cancel
Save