
fix(mge/grad_manager): allow multiple calls of `release`

GitOrigin-RevId: 38ca4c78ff
Branch: release-1.1
Author: Megvii Engine Team (4 years ago)
Commit: 2627e1f7df
4 changed files with 66 additions and 23 deletions
1. imperative/python/megengine/autodiff/grad_manager.py (+23, -11)
2. imperative/python/megengine/tensor.py (+2, -5)
3. imperative/python/test/unit/autodiff/test_grad_manger.py (+41, -5)
4. imperative/python/test/unit/module/test_qat.py (+0, -2)

imperative/python/megengine/autodiff/grad_manager.py (+23, -11)

@@ -3,9 +3,12 @@ from contextlib import contextmanager
 from typing import Callable
 
 from ..core.autodiff.grad import Grad
-from ..tensor import Tensor, tensor
+from ..logger import get_logger
+from ..tensor import Tensor
 from ..utils.future import Future
 
+logger = get_logger(__name__)
+
 backwarding_grad_manager = None
 
 
@@ -67,7 +70,7 @@ class GradManager:
         self._after_backward_callback = []
         self._gradients = dict()
 
-    def attach(self, params, callbacks=None):
+    def attach(self, params: list, callbacks=None):
         r"""Registers parameters that gradients should be calculated with respect to.
         Callback functions should have a signature like this:
 
@@ -77,7 +80,7 @@ class GradManager:
             # do something
             return grad
 
-        :param params: registered parameters
+        :param params: parameters to be registered
         :param callbacks: list of callback functions
         """
         if callbacks is None:
@@ -95,6 +98,20 @@ class GradManager:
             self._record_param(id(p))
         return self
 
+    def detach(self, params: list):
+        r"""Removes specific registered parameters and their callback functions.
+
+        :param params: registered parameters
+        """
+        if isinstance(params, Tensor):
+            params = [params]
+        for idx, param in enumerate(params):
+            if id(param) in self._param_dict:
+                self._param_dict.pop(id(param))
+                self._call_back_dict.pop(id(param))
+            else:
+                logger.warning("param with index {} is not attached.".format(idx))
+
     def _register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
         return self
@@ -136,7 +153,7 @@ class GradManager:
             else:
                 param.grad += grad
         finally:
-            self._stop_record()
+            self.release()
             backwarding_grad_manager = cache
 
     def record(self):
@@ -167,15 +184,10 @@ class GradManager:
     def release(self):
         r"""Stops recording and releases resources for gradients calculation.
         """
-        if not self._recording:
-            raise RuntimeError("not recording")
-        self._stop_record()
-
-    def _stop_record(self):
         if self._grad is not None:
             self._grad.__exit__(None, None, None)
-            self._grad = None
         self._recording = False
+        self._grad = None
         self._gradients = dict()
 
     def __enter__(self):
@@ -183,4 +195,4 @@ class GradManager:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        self._stop_record()
+        self.release()

imperative/python/megengine/tensor.py (+2, -5)

@@ -85,11 +85,8 @@ class Tensor(_Tensor):
 
     def detach(self):
         r"""
-        Returns a new tensor which is treated as constant during backward gradient calculation,
-        i.e. its gradient is zero.
-
-        :param inp: input tensor
-
+        Returns a new tensor sharing the same data memory, which is treated as a constant
+        during backward gradient calculation, i.e. its gradient is zero.
         """
         Wrapper = type(self)
         Tensor = type(self.__wrapped__)
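
The rewritten docstring also drops the stray ":param inp: input tensor" line, since detach() takes no arguments. A short sketch of the documented behavior, not part of the commit (values are illustrative):

import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x = mge.tensor([2.0])
gm = GradManager().attach([x])
with gm:
    y = x * 3
    z = y.detach()  # shares y's data, but acts as a constant under autodiff
    gm.backward(F.sum(y + z))

print(x.grad)  # gradient flows only through y: 3.0 rather than 6.0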


imperative/python/test/unit/autodiff/test_grad_manger.py (+41, -5)

@@ -1,15 +1,51 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
+import pytest
 
 import megengine as mge
-from megengine import autodiff as ad
+import megengine.functional as F
+from megengine.autodiff import GradManager
 
 
+def test_basic():
+    x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
+    w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1)
+    b = mge.tensor(-1.0)
+
+    gm = GradManager().attach([w, b])
+    gm.record()
+
+    p = F.matmul(x, w)
+    y = p + b
+
+    gm.backward(y)
+    gm.release()  # not necessary; extra calls are allowed after this fix
+    np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
+    np.testing.assert_equal(b.grad.numpy(), [1])
+
+    w.grad = None
+    b.grad = None
+    with gm:
+        p = F.matmul(x, w)
+        y = p + b
+        gm.backward(y)
+
+    np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
+    np.testing.assert_equal(b.grad.numpy(), [1])
+
+
 def test_attach_in_with_block():
     a = mge.Parameter([1.0])
-    g = ad.GradManager()
-    with g:
+    gm = GradManager()
+    with gm:
         b = a * 3
-        g.attach(b)
+        gm.attach(b)
         c = b + 1
-        g.backward(c)
+        gm.backward(c)
     assert int(b.grad.numpy()) == 1

imperative/python/test/unit/module/test_qat.py (+0, -2)

@@ -27,8 +27,6 @@ def test_qat_convbn2d():
     disable_fake_quant(qat_module)
     inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
     normal_outputs = module(inputs)
-    # import pdb
-    # pdb.set_trace()
     qat_outputs = qat_module(inputs)
     np.testing.assert_allclose(
         normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6

