GitOrigin-RevId: 38ca4c78ff
release-1.1
@@ -3,9 +3,12 @@ from contextlib import contextmanager
from typing import Callable
from ..core.autodiff.grad import Grad
from ..tensor import Tensor, tensor
from ..logger import get_logger
from ..tensor import Tensor
from ..utils.future import Future
logger = get_logger(__name__)
backwarding_grad_manager = None
@@ -67,7 +70,7 @@ class GradManager:
self._after_backward_callback = []
self._gradients = dict()
def attach(self, params, callbacks=None):
def attach(self, params: list, callbacks=None):
r"""Registers parameters that gradients should be calculated with respect to.
Callback Functions should have a signature like this:
@@ -77,7 +80,7 @@ class GradManager:
# do something
return grad
:param params: registered parameters
:param params: parameters to be registered
:param callbacks: list of callback functions
"""
if callbacks is None:
@@ -95,6 +98,20 @@ class GradManager:
self._record_param(id(p))
return self
def detach(self, params: list):
r"""Removes specific registered parameters and their callback functions.
:param params: registered parameters
"""
if isinstance(params, Tensor):
params = [params]
for idx, param in enumerate(params):
if id(param) in self._param_dict:
self._param_dict.pop(id(param))
self._call_back_dict.pop(id(param))
else:
logger.warning("parameter at index {} is not attached.".format(idx))
def _register_after_backward_callback(self, callback):
self._after_backward_callback.append(callback)
return self
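For orientation, a minimal usage sketch of the attach()/detach() pair introduced in this hunk; the stand-in parameters and the gradient-clipping callback are illustrative assumptions, not part of this patch:

import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

w = mge.Parameter([1.0, 2.0])   # stand-ins for real model parameters
b = mge.Parameter([0.0])

def clip_grad(param, grad):
    # callback signature from the attach() docstring: takes (param, grad), returns grad
    return F.clip(grad, -1.0, 1.0)

gm = GradManager()
gm.attach([w, b], callbacks=[clip_grad])  # gradients for w and b pass through clip_grad
gm.detach([b])                            # new in this patch: stop tracking b again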
@@ -136,7 +153,7 @@ class GradManager:
else:
param.grad += grad
finally:
self._stop_record()
self.release()
backwarding_grad_manager = cache
def record(self):
@@ -167,15 +184,10 @@ class GradManager:
def release(self):
r"""Stops recording and releases resources for gradients calculation.
"""
if not self._recording:
raise RuntimeError("not recording")
self._stop_record()
def _stop_record(self):
if self._grad is not None:
self._grad.__exit__(None, None, None)
self._grad = None
self._recording = False
self._grad = None
self._gradients = dict()
def __enter__(self):
@@ -183,4 +195,4 @@ class GradManager:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._stop_record()
self.release()
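A small sketch of the flow this refactor enables: __exit__() now goes through release(), and release() no longer raises when the manager is not recording, so an extra call after backward() is harmless. The tensors below are made-up examples, not taken from this patch:

import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

w = mge.Parameter([1.0, 2.0])
x = mge.tensor([3.0, 4.0])
gm = GradManager().attach([w])
with gm:                    # __enter__ -> record()
    loss = F.sum(w * x)
    gm.backward(loss)       # backward() already calls release() in its finally block
gm.release()                # after this patch: a safe no-op instead of RuntimeError("not recording")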
@@ -85,11 +85,8 @@ class Tensor(_Tensor):
def detach(self):
r"""
Returns a new tensor which is treated as constant during backward gradient calculation,
i.e. its gradient is zero.
:param inp: input tensor
Returns a new tensor sharing the same data memory, which is treated as a constant
during backward gradient calculation, i.e. its gradient is zero.
"""
Wrapper = type(self)
Tensor = type(self.__wrapped__)
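To illustrate the reworded detach() docstring, a minimal sketch with arbitrarily chosen values: the detached tensor shares data with the original but contributes no gradient.

import megengine as mge
from megengine.autodiff import GradManager

w = mge.Parameter([2.0])
gm = GradManager().attach([w])
with gm:
    y = (w * 3).detach() + w   # gradient can only flow through the second term
    gm.backward(y)
print(w.grad.numpy())          # [1.] -- the detached branch is treated as a constant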
@@ -1,15 +1,51 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
import megengine as mge
from megengine import autodiff as ad
import megengine.functional as F
from megengine.autodiff import GradManager
def test_basic():
x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1)
b = mge.tensor(-1.0)
gm = GradManager().attach([w, b])
gm.record()
p = F.matmul(x, w)
y = p + b
gm.backward(y)
gm.release()  # optional: backward() has already released resources
np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
np.testing.assert_equal(b.grad.numpy(), [1])
w.grad = None
b.grad = None
with gm:
p = F.matmul(x, w)
y = p + b
gm.backward(y)
np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
np.testing.assert_equal(b.grad.numpy(), [1])
def test_attach_in_with_block():
a = mge.Parameter([1.0])
g = ad.GradManager()
with g:
gm = GradManager()
with gm:
b = a * 3
g.attach(b)
gm.attach(b)
c = b + 1
g.backward(c)
gm.backward(c)
assert int(b.grad.numpy()) == 1
@@ -27,8 +27,6 @@ def test_qat_convbn2d():
disable_fake_quant(qat_module)
inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
normal_outputs = module(inputs)
# import pdb
# pdb.set_trace()
qat_outputs = qat_module(inputs)
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 | |||