Browse Source

feat(mge/grad): attach grad immediately

GitOrigin-RevId: e3a168c03a
release-1.1
Megvii Engine Team 4 years ago
parent
commit
3af1056377
2 changed files with 34 additions and 10 deletions
  1. +19
    -10
      imperative/python/megengine/autodiff/grad_manager.py
  2. +15
    -0
      imperative/python/test/unit/autodiff/test_grad_manger.py

+ 19
- 10
imperative/python/megengine/autodiff/grad_manager.py View File

@@ -3,7 +3,7 @@ from contextlib import contextmanager
from typing import Callable

from ..core.autodiff.grad import Grad
from ..tensor import tensor
from ..tensor import Tensor, tensor
from ..utils.future import Future

backwarding_grad_manager = None
@@ -84,10 +84,15 @@ class GradManager:
callbacks = []
if isinstance(callbacks, Callable):
callbacks = [callbacks]
if isinstance(params, Tensor):
params = [params]
for p in params:
self._param_dict[id(p)] = p
for cb in callbacks:
self._call_back_dict[id(p)].append(cb)
if self._grad is not None:
for p in params:
self._record_param(id(p))
return self

def _register_after_backward_callback(self, callback):
@@ -143,17 +148,21 @@ class GradManager:
self._recording = True
self._grad = grad
for param_id in self._param_dict.keys():
param_wrapper = self._param_dict[param_id]
callbacks = self._call_back_dict[param_id]
self._record_param(param_id)
grad.__enter__()

def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
ret = grad
for cb in callbacks:
ret = cb(param, ret)
gm._gradients[id(p)] = ret
def _record_param(self, param_id):
param_wrapper = self._param_dict[param_id]
callbacks = self._call_back_dict[param_id]

grad.wrt(param_wrapper, callback=callback)
grad.__enter__()
def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
ret = grad
for cb in callbacks:
ret = cb(param, ret)
gm._gradients[id(p)] = ret

# NOTE: override prev callback wrt when called serval times
self._grad.wrt(param_wrapper, callback=callback)

def release(self):
r"""Stops recording and releases resources for gradients calculation.


+ 15
- 0
imperative/python/test/unit/autodiff/test_grad_manger.py View File

@@ -0,0 +1,15 @@
import numpy as np

import megengine as mge
from megengine import autodiff as ad


def test_attach_in_with_block():
a = mge.Parameter([1.0])
g = ad.GradManager()
with g:
b = a * 3
g.attach(b)
c = b + 1
g.backward(c)
assert int(b.grad.numpy()) == 1

Loading…
Cancel
Save