
fix(mge/grad_manager): allow multiple calls of `release`

GitOrigin-RevId: 38ca4c78ff
release-1.1
Megvii Engine Team 4 years ago
commit 2627e1f7df
4 changed files with 66 additions and 23 deletions

  1. imperative/python/megengine/autodiff/grad_manager.py (+23 −11)
  2. imperative/python/megengine/tensor.py (+2 −5)
  3. imperative/python/test/unit/autodiff/test_grad_manger.py (+41 −5)
  4. imperative/python/test/unit/module/test_qat.py (+0 −2)

imperative/python/megengine/autodiff/grad_manager.py (+23 −11)

@@ -3,9 +3,12 @@ from contextlib import contextmanager
 from typing import Callable
 
 from ..core.autodiff.grad import Grad
-from ..tensor import Tensor, tensor
+from ..logger import get_logger
+from ..tensor import Tensor
 from ..utils.future import Future
 
+logger = get_logger(__name__)
+
 backwarding_grad_manager = None
 

@@ -67,7 +70,7 @@ class GradManager:
         self._after_backward_callback = []
         self._gradients = dict()
 
-    def attach(self, params, callbacks=None):
+    def attach(self, params: list, callbacks=None):
         r"""Registers parameters that gradients should be calculated with respect to.
         Callback Functions should have a signature like this:
 
@@ -77,7 +80,7 @@ class GradManager:
             # do something
             return grad
 
-        :param params: registered parameters
+        :param params: parameters to be registered
        :param callbacks: list of callback functions
        """
        if callbacks is None:
@@ -95,6 +98,20 @@ class GradManager:
             self._record_param(id(p))
         return self
 
+    def detach(self, params: list):
+        r"""Removes specific registered parameters and their callback functions.
+
+        :param params: registered parameters
+        """
+        if isinstance(params, Tensor):
+            params = [params]
+        for idx, param in enumerate(params):
+            if id(param) in self._param_dict:
+                self._param_dict.pop(id(param))
+                self._call_back_dict.pop(id(param))
+            else:
+                logger.warning("param with index {} is not attached.".format(idx))
+
     def _register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
         return self
@@ -136,7 +153,7 @@ class GradManager:
                     else:
                         param.grad += grad
         finally:
-            self._stop_record()
+            self.release()
             backwarding_grad_manager = cache
 
     def record(self):
@@ -167,15 +184,10 @@ class GradManager:
     def release(self):
         r"""Stops recording and releases resources for gradients calculation.
         """
-        if not self._recording:
-            raise RuntimeError("not recording")
-        self._stop_record()
-
-    def _stop_record(self):
         if self._grad is not None:
             self._grad.__exit__(None, None, None)
-        self._grad = None
         self._recording = False
+        self._grad = None
+        self._gradients = dict()
 
     def __enter__(self):
@@ -183,4 +195,4 @@ class GradManager:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        self._stop_record()
+        self.release()

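Taken together, these changes make release() idempotent: backward() now releases in its finally block, and __exit__ delegates to release() as well, so an extra explicit call becomes a no-op instead of raising RuntimeError("not recording"). A minimal sketch of the newly allowed pattern, mirroring test_basic below:

import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1)

gm = GradManager().attach([w])
gm.record()
y = F.matmul(x, w)
gm.backward(y)  # already calls release() in its finally block
gm.release()    # second call: used to raise "not recording", now a no-op
gm.release()    # further calls are fine too
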
imperative/python/megengine/tensor.py (+2 −5)

@@ -85,11 +85,8 @@ class Tensor(_Tensor):
 
     def detach(self):
         r"""
-        Returns a new tensor which is treated as constant during backward gradient calcuation,
-        i.e. its gradient is zero.
-
-        :param inp: input tensor
-
+        Returns a new tensor sharing the same data memory, which is treated as a constant
+        during backward gradient calculation, i.e. its gradient is zero.
         """
         Wrapper = type(self)
         Tensor = type(self.__wrapped__)

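For reference, a small sketch of the behavior the reworded docstring describes (the commit changes only the documentation; detach semantics are unchanged, and the GradManager usage here is an assumption carried over from the tests below):

import megengine as mge
from megengine.autodiff import GradManager

w = mge.tensor([2.0])
gm = GradManager().attach([w])

with gm:
    y = w * 3.0
    z = y.detach()  # shares y's data, but acts as a constant for autodiff
    gm.backward(z + 1.0)

print(w.grad)  # None: no gradient flows back through the detached tensor
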
imperative/python/test/unit/autodiff/test_grad_manger.py (+41 −5)

@@ -1,15 +1,51 @@
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 #
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
 import pytest
 
 import megengine as mge
-from megengine import autodiff as ad
+import megengine.functional as F
+from megengine.autodiff import GradManager
 
 
+def test_basic():
+    x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
+    w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1)
+    b = mge.tensor(-1.0)
+
+    gm = GradManager().attach([w, b])
+    gm.record()
+
+    p = F.matmul(x, w)
+    y = p + b
+
+    gm.backward(y)
+    gm.release()  # not necessary, but must not raise
+    np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
+    np.testing.assert_equal(b.grad.numpy(), [1])
+
+    w.grad = None
+    b.grad = None
+    with gm:
+        p = F.matmul(x, w)
+        y = p + b
+        gm.backward(y)
+
+    np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
+    np.testing.assert_equal(b.grad.numpy(), [1])
+
+
 def test_attach_in_with_block():
     a = mge.Parameter([1.0])
-    g = ad.GradManager()
-    with g:
+    gm = GradManager()
+    with gm:
         b = a * 3
-        g.attach(b)
+        gm.attach(b)
         c = b + 1
-        g.backward(c)
+        gm.backward(c)
     assert int(b.grad.numpy()) == 1

imperative/python/test/unit/module/test_qat.py (+0 −2)

@@ -27,8 +27,6 @@ def test_qat_convbn2d():
     disable_fake_quant(qat_module)
     inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
     normal_outputs = module(inputs)
-    # import pdb
-    # pdb.set_trace()
     qat_outputs = qat_module(inputs)
     np.testing.assert_allclose(
         normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6

