|
@@ -3,9 +3,12 @@ from contextlib import contextmanager
 from typing import Callable
 
 from ..core.autodiff.grad import Grad
-from ..tensor import Tensor, tensor
+from ..logger import get_logger
+from ..tensor import Tensor
 from ..utils.future import Future
 
+logger = get_logger(__name__)
+
 backwarding_grad_manager = None
 
 
@@ -67,7 +70,7 @@ class GradManager:
         self._after_backward_callback = []
         self._gradients = dict()
 
-    def attach(self, params, callbacks=None):
+    def attach(self, params: list, callbacks=None):
         r"""Registers parameters that gradients should be calculated with respect to.
         Callback Functions should have a signature like this:
 
@@ -77,7 +80,7 @@ class GradManager:
                 # do something
                 return grad
 
-        :param params: registered parameters
+        :param params: parameters to be registered
         :param callbacks: list of callback functions
         """
         if callbacks is None:
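Note on the attach() change above: a callback receives (param, grad) and must return the gradient that will be stored. Below is a minimal sketch of attaching a toy parameter with a clipping callback; it assumes this module ships as megengine.autodiff and that megengine.functional.clip(x, lower, upper) clamps elementwise, and the names w and clip_grad are illustrative, not part of this change:

    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager

    w = mge.Parameter([1.0, 2.0])  # toy parameter to track

    def clip_grad(param, grad):
        # callback contract from the docstring: take (param, grad) and
        # return the gradient that will actually be stored
        return F.clip(grad, -1.0, 1.0)

    gm = GradManager().attach([w], callbacks=[clip_grad])  # attach returns self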
|
@@ -95,6 +98,20 @@ class GradManager:
             self._record_param(id(p))
         return self
 
+    def detach(self, params: list):
+        r"""Removes specific registered parameters and their callback functions.
+
+        :param params: registered parameters
+        """
+        if isinstance(params, Tensor):
+            params = [params]
+        for idx, param in enumerate(params):
+            if id(param) in self._param_dict:
+                self._param_dict.pop(id(param))
+                self._call_back_dict.pop(id(param))
+            else:
+                logger.warning("param with index {} is not attached.".format(idx))
+
     def _register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
         return self
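The new detach() above is the inverse of attach(): it drops a parameter's entry and its callbacks, and only warns, rather than raising, when a parameter was never attached. Continuing the sketch from the previous note:

    gm.detach([w])   # removes w's entry from _param_dict and _call_back_dict
    gm.detach(w)     # a bare Tensor is wrapped into a list first; w is
                     # already gone, so this hits the else-branch and logs
                     # the warning instead of raising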
|
@@ -136,7 +153,7 @@ class GradManager:
                     else:
                         param.grad += grad
         finally:
-            self._stop_record()
+            self.release()
             backwarding_grad_manager = cache
 
     def record(self):
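With the finally-block now calling release() instead of the private _stop_record() (deleted in a later hunk), backward() and __exit__ share a single teardown path. A sketch of one recording cycle, reusing the toy parameter from the first note:

    gm.record()
    loss = (w * w).sum()
    gm.backward(loss)   # accumulates 2*w into w.grad, then finally: release()
    print(w.grad)       # safe: the gradient was stored before release() ran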
|
@@ -167,15 +184,10 @@ class GradManager:
     def release(self):
         r"""Stops recording and releases resources for gradients calculation.
         """
-        if not self._recording:
-            raise RuntimeError("not recording")
-        self._stop_record()
-
-    def _stop_record(self):
         if self._grad is not None:
             self._grad.__exit__(None, None, None)
-            self._grad = None
         self._recording = False
+        self._grad = None
         self._gradients = dict()
 
     def __enter__(self):
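The net effect of this hunk: _stop_record() is folded into release(), and the "not recording" guard is gone, so release() is now idempotent; calling it with nothing recorded simply clears state. For instance:

    gm.record()
    gm.release()   # exits the Grad context, clears _grad and _gradients
    gm.release()   # previously RuntimeError("not recording"); now a no-op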
|
@@ -183,4 +195,4 @@ class GradManager:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        self._stop_record()
+        self.release()
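And this is why __exit__ may delegate to release(): after backward() has already released in its finally-block, the second release() issued when the with-block closes is harmless, making the two forms below equivalent (same toy names as the earlier sketches):

    with gm:                        # __enter__ -> record()
        gm.backward((w * w).sum())
    # __exit__ -> release(), a no-op here since backward already released

    gm.record()                     # manual equivalent
    gm.backward((w * w).sum())
    gm.release()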