diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index 7041b93f..8c7bf8cd 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -294,7 +294,8 @@ class Module(metaclass=ABCMeta): self.training = mode def fn(x) -> None: - x.training = mode + if x is not self: + x.train(mode=mode) self.apply(fn) diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py index 64e6adda..a1d9b84b 100644 --- a/python_module/megengine/quantization/observer.py +++ b/python_module/megengine/quantization/observer.py @@ -56,6 +56,13 @@ class Observer(Module): def disable(self): self.enabled = False + def train(self, mode: bool = True) -> None: + super().train(mode) + if mode: + self.enable() + else: + self.disable() + @abstractmethod def forward(self, x): pass