def register_hook_module(module):
    """Register additional module classes to be hooked by ``module_stats``.

    ``hook_modules`` is consumed as ``isinstance(m, tuple(hook_modules))``,
    so every entry must be a class. A module instance is accepted for
    backward compatibility and is registered through its class.

    Args:
        module: an ``M.Module`` subclass, an ``M.Module`` instance, or a
            (possibly nested) list/tuple of those.

    Raises:
        TypeError: if ``module`` (or any nested element) is none of the above.
    """
    if isinstance(module, (tuple, list)):
        # Snapshot first so that passing ``hook_modules`` itself stays safe.
        for entry in tuple(module):
            register_hook_module(entry)
        return

    if isinstance(module, type) and issubclass(module, M.Module):
        module_cls = module
    elif isinstance(module, M.Module):
        # Storing the instance itself would make the later
        # ``isinstance(m, tuple(hook_modules))`` call raise TypeError.
        module_cls = type(module)
    else:
        raise TypeError(
            "the param type should be one of [list, tuple, M.Module subclass]"
        )

    # Avoid growing the registry with duplicates on repeated registration.
    if module_cls not in hook_modules:
        hook_modules.append(module_cls)
@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test do not support symbolic shape.",
)
def test_getattribute_param():
    """Parameters read directly via attribute access must appear in the stats.

    ``MyConvBn.forward`` bypasses ``conv1.__call__`` and passes ``conv1.weight``
    and ``conv1.bias`` to ``calc_conv`` explicitly; the test checks that
    ``module_stats`` still reports both under the names ``conv1-w`` and
    ``conv1-b``.
    """

    class MyConvBn(M.Module):
        def __init__(self):
            super().__init__()
            self.in_channels = 64
            self.conv1 = M.Conv2d(
                3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=True
            )
            self.bn1 = M.BatchNorm2d(self.in_channels)

        def forward(self, inp):
            # Call calc_conv directly so the conv parameters are only reached
            # through plain attribute access, not through the module's __call__.
            conv = self.conv1
            conv_out = conv.calc_conv(inp, conv.weight, conv.bias)
            return self.bn1(conv_out)

    net = MyConvBn()
    _, detail = module_stats(net, input_shapes=(1, 3, 224, 224))

    recorded_names = [entry["name"] for entry in detail.params]
    assert "conv1-w" in recorded_names and "conv1-b" in recorded_names
isinstance( x, - (np.ndarray, collections.abc.Mapping, collections.abc.Sequence, mge.Tensor), + ( + np.ndarray, + collections.abc.Mapping, + collections.abc.Sequence, + mge.Tensor, + ), ) or (isinstance(x, tuple) and hasattr(x, "_fields"))