GitOrigin-RevId: aa9540e090
release-1.1
@@ -112,6 +112,14 @@ class GradManager: | |||||
else: | else: | ||||
logger.warning("params with index {} is not attached.".format(idx)) | logger.warning("params with index {} is not attached.".format(idx)) | ||||
def clear_grad(self): | |||||
r""" | |||||
For advanced usage: set the grad attribute to None for all registered parameters. | |||||
This can be more convenient when more than one Optimizer is in use. | |||||
""" | |||||
for param in self._param_dict.values(): | |||||
param.grad = None | |||||
def _register_after_backward_callback(self, callback): | def _register_after_backward_callback(self, callback): | ||||
self._after_backward_callback.append(callback) | self._after_backward_callback.append(callback) | ||||
return self | return self | ||||
@@ -91,7 +91,7 @@ class Optimizer(metaclass=ABCMeta): | |||||
if not isinstance(param, Parameter): | if not isinstance(param, Parameter): | ||||
raise TypeError( | raise TypeError( | ||||
"optimizer can only optimize Parameters, but one of the params is " | "optimizer can only optimize Parameters, but one of the params is " | ||||
+ type(param) | |||||
+ str(type(param)) | |||||
) | ) | ||||
for name, default in self._defaults.items(): | for name, default in self._defaults.items(): | ||||
@@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta): | |||||
def clear_grad(self): | def clear_grad(self): | ||||
r"""Set the grad attribute to None for all parameters. | r"""Set the grad attribute to None for all parameters. | ||||
""" | """ | ||||
for param_group in self.param_groups: | for param_group in self.param_groups: | ||||
for param in param_group["params"]: | for param in param_group["params"]: | ||||
@@ -29,8 +29,7 @@ def test_basic(): | |||||
np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]]) | np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]]) | ||||
np.testing.assert_equal(b.grad.numpy(), [1]) | np.testing.assert_equal(b.grad.numpy(), [1]) | ||||
w.grad = None | |||||
b.grad = None | |||||
gm.clear_grad() | |||||
with gm: | with gm: | ||||
p = F.matmul(x, w) | p = F.matmul(x, w) | ||||
y = p + b | y = p + b | ||||