diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index fa66ee57..531f937c 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -163,10 +163,10 @@ hook_modules = [ def register_hook_module(module): if isinstance(module, (tuple, list)): - modules = list(module) + modules = module for module in modules: register_hook_module(module) - elif isinstance(module, M.Module): + elif issubclass(module, M.Module): hook_modules.append(module) else: raise TypeError("the param type should in [list,tuple,M.Module]") diff --git a/imperative/python/test/unit/utils/test_module_stats.py b/imperative/python/test/unit/utils/test_module_stats.py index 53797b88..d98064b0 100644 --- a/imperative/python/test/unit/utils/test_module_stats.py +++ b/imperative/python/test/unit/utils/test_module_stats.py @@ -10,7 +10,11 @@ import megengine.functional as F import megengine.hub as hub import megengine.module as M from megengine.core._trace_option import use_symbolic_shape -from megengine.utils.module_stats import module_stats +from megengine.utils.module_stats import ( + hook_modules, + module_stats, + register_hook_module, +) @pytest.mark.skipif( @@ -75,6 +79,7 @@ def test_getattribute_param(): self.conv1 = M.Conv2d( 3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=True ) + self.conv1.reset_parameters() self.bn1 = M.BatchNorm2d(self.in_channels) def forward(self, input): @@ -90,8 +95,10 @@ def test_getattribute_param(): def get_name(obj): return obj["name"] - param_name = list(map(get_name, params)) - assert "conv1-w" in param_name and "conv1-b" in param_name + param_names = list(map(get_name, params)) + assert "conv1-w" in param_names and "conv1-b" in param_names + conv1_b_param = params[param_names.index("conv1-b")] + assert int(conv1_b_param["mean"]) == 0 and int(conv1_b_param["std"]) == 0 class TestNet0(M.Module): @@ -493,3 +500,10 @@ def cal_pool_stats(module, inputs, outputs): np.prod(outputs[0].shape) * (module.kernel_size ** 2), np.prod(outputs[0].shape), ) + + +def test_register_hook_module(): + modules = [TestNet0, TestNet1, TestNet2, FakeNet, BasicBlock, ResNet] + register_hook_module(modules) + for module in modules: + assert module in hook_modules