|
|
@@ -42,14 +42,15 @@ class GradManager: |
|
|
|
self._recording = True |
|
|
|
self._grad = grad |
|
|
|
for params, callbacks in self._call_back_pair: |
|
|
|
for p in params: |
|
|
|
|
|
|
|
def callback(param, grad, callbacks=callbacks): |
|
|
|
ret = grad |
|
|
|
for cb in callbacks: |
|
|
|
ret = cb(param, ret) |
|
|
|
param.grad = ret |
|
|
|
def callback(param, grad, callbacks=callbacks, p=p): |
|
|
|
ret = grad |
|
|
|
for cb in callbacks: |
|
|
|
ret = cb(param, ret) |
|
|
|
p.grad = ret |
|
|
|
|
|
|
|
grad.wrt(*params, callback=callback) |
|
|
|
grad.wrt(p, callback=callback) |
|
|
|
with grad: |
|
|
|
yield |
|
|
|
finally: |
|
|
|