Browse Source

fix(mge/utils): fix using wrong function in register_hook_module

GitOrigin-RevId: 75097de1c1
master
Megvii Engine Team 2 years ago
parent
commit
68d7320e69
2 changed files with 19 additions and 5 deletions
  1. +2
    -2
      imperative/python/megengine/utils/module_stats.py
  2. +17
    -3
      imperative/python/test/unit/utils/test_module_stats.py

+ 2
- 2
imperative/python/megengine/utils/module_stats.py View File

@@ -163,10 +163,10 @@ hook_modules = [

def register_hook_module(module):
if isinstance(module, (tuple, list)):
modules = list(module)
modules = module
for module in modules:
register_hook_module(module)
elif isinstance(module, M.Module):
elif issubclass(module, M.Module):
hook_modules.append(module)
else:
raise TypeError("the param type should in [list,tuple,M.Module]")


+ 17
- 3
imperative/python/test/unit/utils/test_module_stats.py View File

@@ -10,7 +10,11 @@ import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats
from megengine.utils.module_stats import (
hook_modules,
module_stats,
register_hook_module,
)


@pytest.mark.skipif(
@@ -75,6 +79,7 @@ def test_getattribute_param():
self.conv1 = M.Conv2d(
3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=True
)
self.conv1.reset_parameters()
self.bn1 = M.BatchNorm2d(self.in_channels)

def forward(self, input):
@@ -90,8 +95,10 @@ def test_getattribute_param():
def get_name(obj):
return obj["name"]

param_name = list(map(get_name, params))
assert "conv1-w" in param_name and "conv1-b" in param_name
param_names = list(map(get_name, params))
assert "conv1-w" in param_names and "conv1-b" in param_names
conv1_b_param = params[param_names.index("conv1-b")]
assert int(conv1_b_param["mean"]) == 0 and int(conv1_b_param["std"]) == 0


class TestNet0(M.Module):
@@ -493,3 +500,10 @@ def cal_pool_stats(module, inputs, outputs):
np.prod(outputs[0].shape) * (module.kernel_size ** 2),
np.prod(outputs[0].shape),
)


def test_register_hook_module():
modules = [TestNet0, TestNet1, TestNet2, FakeNet, BasicBlock, ResNet]
register_hook_module(modules)
for module in modules:
assert module in hook_modules

Loading…
Cancel
Save