@@ -3,7 +3,7 @@ from contextlib import contextmanager
 from typing import Callable
 
 from ..core.autodiff.grad import Grad
-from ..tensor import tensor
+from ..tensor import Tensor, tensor
 from ..utils.future import Future
 
 backwarding_grad_manager = None
@@ -84,10 +84,15 @@ class GradManager:
             callbacks = []
         if isinstance(callbacks, Callable):
             callbacks = [callbacks]
+        if isinstance(params, Tensor):
+            params = [params]
         for p in params:
             self._param_dict[id(p)] = p
             for cb in callbacks:
                 self._call_back_dict[id(p)].append(cb)
+        if self._grad is not None:
+            for p in params:
+                self._record_param(id(p))
         return self
 
     def _register_after_backward_callback(self, callback):
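The two added branches loosen `attach`'s argument handling: a bare `Tensor` and a bare callable are wrapped into one-element lists, and a parameter attached while a `Grad` context is already live (`self._grad is not None`) is recorded on the spot instead of waiting for the next `record()`. A minimal usage sketch of the first part, assuming the public `megengine.autodiff.GradManager` entry point (the tensor values and `scale_grad` are illustrative, not from this patch):

```python
import megengine as mge
from megengine.autodiff import GradManager

gm = GradManager()
w = mge.tensor([1.0, 2.0])
b = mge.tensor([0.0])

# Callback signature as used by the hunk: cb(param, grad) -> grad.
def scale_grad(param, grad):
    return grad * 0.5

gm.attach(w, callbacks=scale_grad)      # bare Tensor + bare callable, now valid
gm.attach([b], callbacks=[scale_grad])  # the previous, explicit-list spelling
```

Note that `attach` appends to `_call_back_dict[id(p)]`, so attaching the same parameter repeatedly chains its callbacks; the `callback` closure in the next hunk folds them left to right via `ret = cb(param, ret)`.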
@@ -143,17 +148,21 @@ class GradManager:
         self._recording = True
         self._grad = grad
         for param_id in self._param_dict.keys():
-            param_wrapper = self._param_dict[param_id]
-            callbacks = self._call_back_dict[param_id]
+            self._record_param(param_id)
+        grad.__enter__()
 
-            def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
-                ret = grad
-                for cb in callbacks:
-                    ret = cb(param, ret)
-                gm._gradients[id(p)] = ret
+    def _record_param(self, param_id):
+        param_wrapper = self._param_dict[param_id]
+        callbacks = self._call_back_dict[param_id]
 
-            grad.wrt(param_wrapper, callback=callback)
-        grad.__enter__()
+        def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
+            ret = grad
+            for cb in callbacks:
+                ret = cb(param, ret)
+            gm._gradients[id(p)] = ret
+
+        # NOTE: overrides the previous callback when wrt() is called several times
+        self._grad.wrt(param_wrapper, callback=callback)
 
     def release(self):
         r"""Stops recording and releases resources for gradients calculation.