diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index a64b4fd1..54575a70 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -112,6 +112,14 @@ class GradManager:
             else:
                 logger.warning("params with index {} is not attached.".format(idx))
 
+    def clear_grad(self):
+        r"""
+        For advanced usage: set the grad attribute to None for all registered parameters.
+        This can be more convenient when more than one Optimizer is in use.
+        """
+        for param in self._param_dict.values():
+            param.grad = None
+
     def _register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
         return self
diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py
index cf869cdf..fc9ebf0b 100644
--- a/imperative/python/megengine/optimizer/optimizer.py
+++ b/imperative/python/megengine/optimizer/optimizer.py
@@ -91,7 +91,7 @@ class Optimizer(metaclass=ABCMeta):
             if not isinstance(param, Parameter):
                 raise TypeError(
                     "optimizer can only optimize Parameters, but one of the params is "
-                    + type(param)
+                    + str(type(param))
                 )
 
         for name, default in self._defaults.items():
@@ -159,7 +159,6 @@
 
     def clear_grad(self):
         r"""Set the grad attribute to None for all parameters.
-
         """
         for param_group in self.param_groups:
             for param in param_group["params"]:
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index 372b3816..f54bd02a 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -29,8 +29,7 @@ def test_basic():
     np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
     np.testing.assert_equal(b.grad.numpy(), [1])
 
-    w.grad = None
-    b.grad = None
+    gm.clear_grad()
     with gm:
         p = F.matmul(x, w)
         y = p + b
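
Usage note (not part of the patch): below is a minimal sketch of the multi-optimizer scenario that the new `GradManager.clear_grad` targets, where clearing grads per optimizer would take several calls. The two-optimizer split, shapes, and learning rates are illustrative assumptions, not code from this repository.

```python
import numpy as np
import megengine.functional as F
from megengine import Parameter, tensor
from megengine.autodiff import GradManager
from megengine.optimizer import SGD

# Illustrative parameters, each updated by its own optimizer.
w = Parameter(np.zeros((3, 1), dtype=np.float32))
b = Parameter(np.zeros((1,), dtype=np.float32))
gm = GradManager().attach([w, b])
opt_w = SGD([w], lr=0.1)
opt_b = SGD([b], lr=0.01)

x = tensor(np.random.randn(4, 3).astype(np.float32))
with gm:
    loss = (F.matmul(x, w) + b).sum()
    gm.backward(loss)
opt_w.step()
opt_b.step()

# Before this patch, grads had to be cleared per optimizer (or per
# parameter, as the old test did); one call on the manager now clears
# every attached parameter's grad.
gm.clear_grad()
```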