Browse Source

fix(mge/module): fix redundant recursion in `train()`

GitOrigin-RevId: 6b3566930b
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
38f7cbd9aa
2 changed files with 11 additions and 9 deletions
  1. +9
    -7
      python_module/megengine/module/module.py
  2. +2
    -2
      python_module/megengine/quantization/observer.py

+ 9
- 7
python_module/megengine/module/module.py View File

@@ -291,19 +291,21 @@ class Module(metaclass=ABCMeta):
if param.grad is not None:
param.grad.reset_zero()

def train(self, mode: bool = True) -> None:
def train(self, mode: bool = True, recursive: bool = True) -> None:
"""Set training mode of all the modules within this module (including itself) to
``mode``. This effectively sets the ``training`` attributes of those modules
to ``mode``, but only has effect on certain modules (e.g.
:class:`~.BatchNorm2d`, :class:`~.Dropout`)
:class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`)

:param mode: The training mode to be set on modules.
:param mode: the training mode to be set on modules.
:param recursive: whether to recursively call submodules' ``train()``.
"""
self.training = mode
if not recursive:
self.training = mode
return

def fn(x) -> None:
if x is not self:
x.train(mode=mode)
def fn(module: Module) -> None:
module.train(mode, recursive=False)

self.apply(fn)



+ 2
- 2
python_module/megengine/quantization/observer.py View File

@@ -60,8 +60,8 @@ class Observer(Module):
def disable(self):
self.enabled = False

def train(self, mode: bool = True) -> None:
super().train(mode)
def train(self, mode: bool = True, recursive: bool = True) -> None:
super().train(mode, recursive)
if mode:
self.enable()
else:


Loading…
Cancel
Save