
fix(mge/function): do not deeply copy saved tensor in Function

GitOrigin-RevId: 3c89d1ceaa
tags/v0.5.0
Megvii Engine Team (Xu Xinran), 5 years ago
parent commit: 35bc0e1f60
3 changed files with 65 additions and 10 deletions:
  1. python_module/megengine/core/function.py (+15, -0)
  2. python_module/megengine/core/tensor.py (+6, -1)
  3. python_module/test/unit/core/test_function.py (+44, -9)

python_module/megengine/core/function.py (+15, -0)

@@ -6,6 +6,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
 from abc import ABCMeta, abstractmethod
 from typing import Iterable, Tuple, Union


@@ -142,6 +143,20 @@ class Function(metaclass=ABCMeta):
         """
         self.saved_tensors = tensors

+    def __deepcopy__(self, memo):
+        """
+        Defines how the operator is deeply copied
+        """
+        cls = self.__class__
+        result = cls.__new__(cls)
+        tmp = self.saved_tensors
+        self.saved_tensors = None
+        memo[id(self)] = result
+        for k, v in self.__dict__.items():
+            setattr(result, k, copy.deepcopy(v, memo))
+        self.saved_tensors = tmp
+        return result
+
     def __call__(self, *inputs):
         assert (
             not self._has_saved_state

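The __deepcopy__ added above is a standard Python pattern for excluding one attribute from a deep copy: stash the attribute, null it out, register the half-built copy in memo before recursing (so any back-reference to the Function resolves to the new object), deep-copy the remaining attributes, then restore the original. A minimal self-contained sketch of the same pattern; Holder and payload are hypothetical stand-ins for Function and saved_tensors, not MegEngine code:

import copy

class Holder:
    # Toy stand-in for Function; "payload" plays the role of saved_tensors.
    def __init__(self, param, payload):
        self.param = param
        self.payload = payload  # large object we do NOT want duplicated

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        tmp = self.payload
        self.payload = None        # hide the attribute from the copy
        memo[id(self)] = result    # register early so cycles map to the copy
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        self.payload = tmp         # restore the original afterwards
        return result

h = Holder(param=3, payload=[0.0] * 1_000_000)
h2 = copy.deepcopy(h)
assert h2.param == 3 and h2.payload is None  # payload was not duplicated

One caveat of this pattern: it briefly mutates self during the copy, so it is not safe if another thread reads saved_tensors at the same time.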

python_module/megengine/core/tensor.py (+6, -1)

@@ -495,7 +495,12 @@
         )

     def __getstate__(self):
-        assert (self.__val is not None) and (self.__sym is None)
+        r""" __getstate__ will be called for pickle serialization or deep copy
+        """
+
+        assert (self.__val is not None) and (
+            self.__sym is None
+        ), "Only SharedND initialized Tensor can be serialized or deep copied"
         metadata = {"requires_grad": self.requires_grad}
         state = {
             "data": self.numpy(),

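The reworked docstring and assertion message point at a detail of Python's copy protocol: when a class defines no __deepcopy__ or __copy__, copy.deepcopy falls back to the pickle machinery (__reduce_ex__), which calls __getstate__, so Tensor.__getstate__ really does run for both serialization and deep copy. A quick way to observe this, using a hypothetical Point class:

import copy
import pickle

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

    def __getstate__(self):
        print("__getstate__ called")
        return {"x": self.x, "y": self.y}

p = Point(1, 2)
pickle.dumps(p)   # pickle path: prints "__getstate__ called"
copy.deepcopy(p)  # copy path: prints it as well

This is presumably why Function needs its own __deepcopy__: without it, deep-copying a Function would route every saved tensor through Tensor.__getstate__, and a symbolic tensor (sym set, val unset) would trip the assertion above.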

python_module/test/unit/core/test_function.py (+44, -9)

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -6,10 +5,13 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
+
 import numpy as np

 import megengine.functional as F
 from megengine.core import Function, tensor
+from megengine.jit import trace
 from megengine.test import assertTensorClose




@@ -76,6 +78,27 @@ def test_ste():
     )


+def test_deepcopy():
+    class Sigmoid(Function):
+        def __init__(self, param):
+            super().__init__()
+            self.param = param
+
+        def forward(self, x):
+            y = 1 / (1 + F.exp(-x))
+            self.save_for_backward(y)
+            return y
+
+        def backward(self, grad_y):
+            (y,) = self.saved_tensors
+            return grad_y * y * (1 - y)
+
+    origin = Sigmoid(0)
+    new = copy.deepcopy(Sigmoid(0))
+    assert new.param == origin.param
+    assert new.saved_tensors == None
+
+
 def test_save_context():
     class Sigmoid(Function):
         def forward(self, x):
@@ -87,14 +110,26 @@ def test_save_context():
             (y,) = self.saved_tensors
             return grad_y * y * (1 - y)

-    a = tensor(np.array([1926.0817], dtype=np.float32))
-    s = Sigmoid()(a)
-    s2 = F.sigmoid(a)
-    assertTensorClose(s.numpy(), s2.numpy())
-    assertTensorClose(
-        F.grad(s, a, use_virtual_grad=False).numpy(),
-        F.grad(s2, a, use_virtual_grad=False).numpy(),
-    )
+    def run_saved_context(a, net=None):
+        return net(a)
+
+    def run(use_trace, symbolic):
+        a = tensor(np.array([1926.0817], dtype=np.float32))
+        net = Sigmoid()
+        func_run = run_saved_context
+        if use_trace:
+            func_run = trace(run_saved_context, symbolic=symbolic)
+        s = func_run(a, net=net)
+        s2 = F.sigmoid(a)
+        assertTensorClose(s.numpy(), s2.numpy())
+        assertTensorClose(
+            F.grad(s, a, use_virtual_grad=False).numpy(),
+            F.grad(s2, a, use_virtual_grad=False).numpy(),
+        )
+
+    run(False, False)
+    run(True, False)
+    run(True, True)


 def test_none_in_out_grad():


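A final note on ordering inside __deepcopy__: memo[id(self)] = result must be assigned before the recursive copy.deepcopy calls. If any attribute refers back to the object being copied, the memo entry is what stops infinite recursion and makes the back-reference point at the new copy. A standalone sketch with a hypothetical Node class:

import copy

class Node:
    def __init__(self):
        self.peer = None

    def __deepcopy__(self, memo):
        result = self.__class__.__new__(self.__class__)
        memo[id(self)] = result  # without this line, a <-> b recurses forever
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        return result

a, b = Node(), Node()
a.peer, b.peer = b, a        # reference cycle
a2 = copy.deepcopy(a)
assert a2.peer.peer is a2    # the cycle is preserved in the copy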