@@ -1,3 +1,4 @@
+import weakref
 from collections import defaultdict
 from contextlib import contextmanager
 from typing import Callable
@@ -16,6 +17,10 @@ def get_backwarding_grad_manager():
     return backwarding_grad_manager
 
 
+class AttachSpec:
+    __slots__ = "tensor", "callbacks"
+
+
 class GradManager:
     r"""
     GradManager manages auto differentiation and all resources required to perform it.
@@ -64,14 +69,13 @@ class GradManager:
     """
 
     def __init__(self):
-        self._call_back_dict = defaultdict(list)
-        self._param_dict = dict()
+        self._attach_specs = {}  # id(Tensor) -> AttachSpec
         self._recording = False
         self._grad = None
         self._after_backward_callback = []
-        self._gradients = dict()
+        self._gradients = {}
 
-    def attach(self, params: list, callbacks=None):
+    def attach(self, tensors: list, callbacks=None):
         r"""
         Registers parameters that gradients should be calculated with respect to.
         Callback Functions should have a signature like this:
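A minimal sketch of that signature against the new tensor-based attach (the scaling body and names below are illustrative, not part of the patch):

    def cb(tensor, grad):
        # called with the attached tensor and its freshly computed gradient;
        # the returned value is what ends up accumulated into tensor.grad
        return grad * 0.5

    gm.attach(t, callbacks=cb)  # a single callable or a list of callables is accepted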
@@ -89,22 +93,39 @@ class GradManager:
             callbacks = []
         if isinstance(callbacks, Callable):
             callbacks = [callbacks]
-        if isinstance(params, Tensor):
-            params = [params]
-        for p in params:
-            self._param_dict[id(p)] = p
-            for cb in callbacks:
-                self._call_back_dict[id(p)].append(cb)
-        if self._grad is not None:
-            for p in params:
-                self._record_param(id(p))
+        if isinstance(tensors, Tensor):
+            tensors = [tensors]
+
+        def make_spec(tensor):
+            selfref = weakref.ref(self)
+            key = id(tensor)
+
+            def deleter(_):
+                self = selfref()
+                if self is not None:
+                    del self._attach_specs[key]
+
+            spec = AttachSpec()
+            spec.tensor = weakref.ref(tensor, deleter)
+            spec.callbacks = []
+            return spec
+
+        for x in tensors:
+            spec = self._attach_specs.get(id(x))
+            new_attach = spec is None
+            if spec is None:
+                spec = make_spec(x)
+                self._attach_specs[id(x)] = spec
+            spec.callbacks.extend(callbacks)
+            if new_attach and self._recording:
+                self._do_record(spec)
         return self
 
     def _register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
         return self
 
-    def backward(self, ys=None, dys=None):
+    def backward(self, y=None, dy=None):
         r"""
         Performs back-propagation and computes gradients.
 
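For context, a short usage sketch of the refactored attach()/record()/backward() flow (assuming the usual megengine imports; the toy graph is illustrative, not part of the patch):

    import megengine as mge
    from megengine.autodiff import GradManager

    w = mge.Parameter([2.0])
    x = mge.Tensor([3.0])

    gm = GradManager()
    gm.attach(w)          # w is now tracked via an AttachSpec keyed by id(w)

    gm.record()
    y = (w * x).sum()
    gm.backward(y)        # gradients land on the attached tensors; see the hunk below
    print(w.grad)         # gradient of y w.r.t. w, here equal to x (3.0)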
@@ -135,14 +156,16 @@ class GradManager:
             self._grad(ys, dys)
             for callback in self._after_backward_callback:
                 callback()
-            for p, grad in self._gradients.items():
+            for id_, grad in self._gradients.items():
                 if isinstance(grad, Future):
                     grad = grad.get()
-                param = self._param_dict[p]
-                if param.grad is None:
-                    param.grad = grad
-                else:
-                    param.grad += grad
+                spec = self._attach_specs.get(id_)
+                tensor = spec and spec.tensor()
+                if tensor is not None:
+                    if tensor.grad is None:
+                        tensor.grad = grad
+                    else:
+                        tensor.grad += grad
         finally:
             self.release()
             backwarding_grad_manager = cache
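The weakref wiring in attach() is what makes the "spec and spec.tensor()" guard above sufficient: once the last strong reference to an attached tensor goes away, its AttachSpec is dropped and backward() simply skips it. A rough sketch of that behaviour (illustrative only; it peeks at the private _attach_specs dict purely to show the effect):

    import gc
    import megengine as mge
    from megengine.autodiff import GradManager

    gm = GradManager()
    t = mge.Tensor([1.0])
    gm.attach(t)
    print(len(gm._attach_specs))   # 1

    del t                          # drop the last strong reference
    gc.collect()                   # the weakref deleter removes the spec
    print(len(gm._attach_specs))   # expected: 0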
@@ -156,22 +179,22 @@ class GradManager:
         grad = Grad()
         self._recording = True
         self._grad = grad
-        for param_id in self._param_dict.keys():
-            self._record_param(param_id)
+        for spec in self._attach_specs.values():
+            self._do_record(spec)
         grad.__enter__()
 
-    def _record_param(self, param_id):
-        param_wrapper = self._param_dict[param_id]
-        callbacks = self._call_back_dict[param_id]
+    def _do_record(self, spec):
+        tensor = spec.tensor()
+        if tensor is None:
+            return
 
-        def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
-            ret = grad
+        def callback(_, grad, callbacks=spec.callbacks):
             for cb in callbacks:
-                ret = cb(param, ret)
-            gm._gradients[id(p)] = ret
+                grad = cb(tensor, grad)
+            self._gradients[id(tensor)] = grad
 
         # NOTE: override prev callback wrt when called serval times
-        self._grad.wrt(param_wrapper, callback=callback)
+        self._grad.wrt(tensor, callback=callback)
 
     def release(self):
         r"""