Browse Source

feat(mge/grad_manager): add `clear_grad` method for GradManager

GitOrigin-RevId: aa9540e090
release-1.1
Megvii Engine Team 4 years ago
parent
commit
b327822994
3 changed files with 10 additions and 4 deletions
  1. +8
    -0
      imperative/python/megengine/autodiff/grad_manager.py
  2. +1
    -2
      imperative/python/megengine/optimizer/optimizer.py
  3. +1
    -2
      imperative/python/test/unit/autodiff/test_grad_manger.py

+ 8
- 0
imperative/python/megengine/autodiff/grad_manager.py View File

@@ -112,6 +112,14 @@ class GradManager:
else:
logger.warning("params with index {} is not attached.".format(idx))

def clear_grad(self):
    r"""
    For advanced usage: reset the ``grad`` attribute of every registered
    parameter to ``None``.

    This is handy when more than one Optimizer operates on the same
    parameters and gradients must be dropped explicitly between them.
    """
    # Walk every attached parameter and discard its accumulated gradient.
    for parameter in self._param_dict.values():
        parameter.grad = None

def _register_after_backward_callback(self, callback):
self._after_backward_callback.append(callback)
return self


+ 1
- 2
imperative/python/megengine/optimizer/optimizer.py View File

@@ -91,7 +91,7 @@ class Optimizer(metaclass=ABCMeta):
if not isinstance(param, Parameter):
raise TypeError(
"optimizer can only optimize Parameters, but one of the params is "
+ type(param)
+ str(type(param))
)

for name, default in self._defaults.items():
@@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta):

def clear_grad(self):
r"""Set the grad attribute to None for all parameters.

"""
for param_group in self.param_groups:
for param in param_group["params"]:


+ 1
- 2
imperative/python/test/unit/autodiff/test_grad_manger.py View File

@@ -29,8 +29,7 @@ def test_basic():
np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
np.testing.assert_equal(b.grad.numpy(), [1])

w.grad = None
b.grad = None
gm.clear_grad()
with gm:
p = F.matmul(x, w)
y = p + b


Loading…
Cancel
Save