|
@@ -1,5 +1,6 @@ |
|
|
from collections import defaultdict |
|
|
from collections import defaultdict |
|
|
from contextlib import contextmanager |
|
|
from contextlib import contextmanager |
|
|
|
|
|
from typing import Callable |
|
|
|
|
|
|
|
|
from ..core.autodiff.grad import Grad |
|
|
from ..core.autodiff.grad import Grad |
|
|
from ..tensor import tensor |
|
|
from ..tensor import tensor |
|
@@ -21,7 +22,11 @@ class GradManager: |
|
|
self._after_backward_callback = [] |
|
|
self._after_backward_callback = [] |
|
|
self._gradients = dict() |
|
|
self._gradients = dict() |
|
|
|
|
|
|
|
|
def register(self, params, callbacks=[]): |
|
|
|
|
|
|
|
|
def register(self, params, callbacks=None): |
|
|
|
|
|
if callbacks is None: |
|
|
|
|
|
callbacks = [] |
|
|
|
|
|
if isinstance(callbacks, Callable): |
|
|
|
|
|
callbacks = [callbacks] |
|
|
for p in params: |
|
|
for p in params: |
|
|
self._param_dict[id(p)] = p |
|
|
self._param_dict[id(p)] = p |
|
|
for cb in callbacks: |
|
|
for cb in callbacks: |
|
@@ -62,37 +67,37 @@ class GradManager: |
|
|
else: |
|
|
else: |
|
|
param.grad += grad |
|
|
param.grad += grad |
|
|
finally: |
|
|
finally: |
|
|
self._grad = None |
|
|
|
|
|
self._gradients = dict() |
|
|
|
|
|
|
|
|
self._stop_record() |
|
|
backwarding_grad_manager = cache |
|
|
backwarding_grad_manager = cache |
|
|
|
|
|
|
|
|
def record(self): |
|
|
|
|
|
@contextmanager |
|
|
|
|
|
def recorder(): |
|
|
|
|
|
grad = Grad() |
|
|
|
|
|
if self._recording: |
|
|
|
|
|
raise RuntimeError("already recording!") |
|
|
|
|
|
try: |
|
|
|
|
|
self._recording = True |
|
|
|
|
|
self._grad = grad |
|
|
|
|
|
for param_id in self._param_dict.keys(): |
|
|
|
|
|
param_wrapper = self._param_dict[param_id] |
|
|
|
|
|
callbacks = self._call_back_dict[param_id] |
|
|
|
|
|
|
|
|
def __enter__(self): |
|
|
|
|
|
if self._recording: |
|
|
|
|
|
return self |
|
|
|
|
|
grad = Grad() |
|
|
|
|
|
self._recording = True |
|
|
|
|
|
self._grad = grad |
|
|
|
|
|
for param_id in self._param_dict.keys(): |
|
|
|
|
|
param_wrapper = self._param_dict[param_id] |
|
|
|
|
|
callbacks = self._call_back_dict[param_id] |
|
|
|
|
|
|
|
|
def callback( |
|
|
|
|
|
param, grad, callbacks=callbacks, p=param_wrapper, gm=self |
|
|
|
|
|
): |
|
|
|
|
|
ret = grad |
|
|
|
|
|
for cb in callbacks: |
|
|
|
|
|
ret = cb(param, ret) |
|
|
|
|
|
gm._gradients[id(p)] = ret |
|
|
|
|
|
|
|
|
def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self): |
|
|
|
|
|
ret = grad |
|
|
|
|
|
for cb in callbacks: |
|
|
|
|
|
ret = cb(param, ret) |
|
|
|
|
|
gm._gradients[id(p)] = ret |
|
|
|
|
|
|
|
|
grad.wrt(param_wrapper, callback=callback) |
|
|
|
|
|
with grad: |
|
|
|
|
|
yield |
|
|
|
|
|
finally: |
|
|
|
|
|
self._recording = False |
|
|
|
|
|
self._grad = None |
|
|
|
|
|
self._gradients = dict() |
|
|
|
|
|
|
|
|
grad.wrt(param_wrapper, callback=callback) |
|
|
|
|
|
grad.__enter__() |
|
|
|
|
|
return self |
|
|
|
|
|
|
|
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb): |
|
|
|
|
|
self._stop_record() |
|
|
|
|
|
|
|
|
|
|
|
record = __enter__ |
|
|
|
|
|
|
|
|
return recorder() |
|
|
|
|
|
|
|
|
def _stop_record(self): |
|
|
|
|
|
if self._grad is not None: |
|
|
|
|
|
self._grad.__exit__(None, None, None) |
|
|
|
|
|
self._recording = False |
|
|
|
|
|
self._grad = None |
|
|
|
|
|
self._gradients = dict() |