Browse Source

refactor(mgb/grad): place grad at param.grad

GitOrigin-RevId: fddeace402
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
5ae89c799b
2 changed files with 7 additions and 12 deletions
  1. +7
    -6
      imperative/python/megengine/autodiff/grad_manager.py
  2. +0
    -6
      imperative/python/megengine/optimizer/sgd.py

+ 7
- 6
imperative/python/megengine/autodiff/grad_manager.py View File

@@ -42,14 +42,15 @@ class GradManager:
self._recording = True
self._grad = grad
for params, callbacks in self._call_back_pair:
for p in params:

def callback(param, grad, callbacks=callbacks):
ret = grad
for cb in callbacks:
ret = cb(param, ret)
param.grad = ret
def callback(param, grad, callbacks=callbacks, p=p):
ret = grad
for cb in callbacks:
ret = cb(param, ret)
p.grad = ret

grad.wrt(*params, callback=callback)
grad.wrt(p, callback=callback)
with grad:
yield
finally:


+ 0
- 6
imperative/python/megengine/optimizer/sgd.py View File

@@ -52,12 +52,6 @@ class SGD(Optimizer):
momentum = param_group["momentum"]

for param in param_group["params"]:

if not isinstance(param.grad, Buffer):
raise TypeError(
"grad must be a Buffer, maybe you forget to call backward()?"
)

if not param.requires_grad:
continue



Loading…
Cancel
Save