Browse Source

feat(mge/grad_manager): add `clear_grad` method for GradManager

GitOrigin-RevId: aa9540e090
release-1.1
Megvii Engine Team 4 years ago
parent
commit
b327822994
3 changed files with 10 additions and 4 deletions
  1. +8
    -0
      imperative/python/megengine/autodiff/grad_manager.py
  2. +1
    -2
      imperative/python/megengine/optimizer/optimizer.py
  3. +1
    -2
      imperative/python/test/unit/autodiff/test_grad_manger.py

+ 8
- 0
imperative/python/megengine/autodiff/grad_manager.py View File

@@ -112,6 +112,14 @@ class GradManager:
else: else:
logger.warning("params with index {} is not attached.".format(idx)) logger.warning("params with index {} is not attached.".format(idx))


def clear_grad(self):
r"""
For advanced usage: set the grad attribute to None for registered parameters.
It could be more convenient when there is more than one Optimizer.
"""
for param in self._param_dict.values():
param.grad = None

def _register_after_backward_callback(self, callback): def _register_after_backward_callback(self, callback):
self._after_backward_callback.append(callback) self._after_backward_callback.append(callback)
return self return self


+ 1
- 2
imperative/python/megengine/optimizer/optimizer.py View File

@@ -91,7 +91,7 @@ class Optimizer(metaclass=ABCMeta):
if not isinstance(param, Parameter): if not isinstance(param, Parameter):
raise TypeError( raise TypeError(
"optimizer can only optimize Parameters, but one of the params is " "optimizer can only optimize Parameters, but one of the params is "
+ type(param)
+ str(type(param))
) )


for name, default in self._defaults.items(): for name, default in self._defaults.items():
@@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta):


def clear_grad(self): def clear_grad(self):
r"""Set the grad attribute to None for all parameters. r"""Set the grad attribute to None for all parameters.

""" """
for param_group in self.param_groups: for param_group in self.param_groups:
for param in param_group["params"]: for param in param_group["params"]:


+ 1
- 2
imperative/python/test/unit/autodiff/test_grad_manger.py View File

@@ -29,8 +29,7 @@ def test_basic():
np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]]) np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
np.testing.assert_equal(b.grad.numpy(), [1]) np.testing.assert_equal(b.grad.numpy(), [1])


w.grad = None
b.grad = None
gm.clear_grad()
with gm: with gm:
p = F.matmul(x, w) p = F.matmul(x, w)
y = p + b y = p + b


Loading…
Cancel
Save