|
|
@@ -291,19 +291,21 @@ class Module(metaclass=ABCMeta): |
|
|
|
if param.grad is not None: |
|
|
|
param.grad.reset_zero() |
|
|
|
|
|
|
|
def train(self, mode: bool = True) -> None: |
|
|
|
def train(self, mode: bool = True, recursive: bool = True) -> None: |
|
|
|
"""Set training mode of all the modules within this module (including itself) to |
|
|
|
``mode``. This effectively sets the ``training`` attributes of those modules |
|
|
|
to ``mode``, but only has effect on certain modules (e.g. |
|
|
|
:class:`~.BatchNorm2d`, :class:`~.Dropout`) |
|
|
|
:class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`) |
|
|
|
|
|
|
|
:param mode: The training mode to be set on modules. |
|
|
|
:param mode: the training mode to be set on modules. |
|
|
|
:param recursive: whether to recursively call submodules' ``train()``. |
|
|
|
""" |
|
|
|
self.training = mode |
|
|
|
if not recursive: |
|
|
|
self.training = mode |
|
|
|
return |
|
|
|
|
|
|
|
def fn(x) -> None: |
|
|
|
if x is not self: |
|
|
|
x.train(mode=mode) |
|
|
|
def fn(module: Module) -> None: |
|
|
|
module.train(mode, recursive=False) |
|
|
|
|
|
|
|
self.apply(fn) |
|
|
|
|
|
|
|