Browse Source

fix(mge/utils): fix module_stats missing param problem

GitOrigin-RevId: abbf4935da
master
Megvii Engine Team 2 years ago
parent
commit
069e4e07a1
2 changed files with 90 additions and 6 deletions
  1. +54
    -5
      imperative/python/megengine/utils/module_stats.py
  2. +36
    -1
      imperative/python/test/unit/utils/test_module_stats.py

+ 54
- 5
imperative/python/megengine/utils/module_stats.py View File

@@ -1,5 +1,7 @@
import collections
import functools
from collections import namedtuple
from contextlib import contextmanager
from functools import partial
from typing import Iterable

@@ -22,7 +24,6 @@ except AttributeError as e:
logger = get_logger(__name__)
logger.setLevel("INFO")


_calc_flops_dict = {}
_calc_receptive_field_dict = {}

@@ -147,7 +148,7 @@ def flops_batchmatmul(module: M.BatchMatMulActivation, inputs, outputs):


# does not need import qat and quantized module since they inherit from float module.
hook_modules = (
hook_modules = [
M.conv._ConvNd,
M.Linear,
M.BatchMatMulActivation,
@@ -157,7 +158,18 @@ hook_modules = (
M.InstanceNorm,
M.pooling._PoolNd,
M.adaptive_pooling._AdaptivePoolNd,
)
]


def register_hook_module(module):
if isinstance(module, (tuple, list)):
modules = list(module)
for module in modules:
register_hook_module(module)
elif isinstance(module, M.Module):
hook_modules.append(module)
else:
raise TypeError("the param type should in [list,tuple,M.Module]")


def _mean(inp):
@@ -519,12 +531,49 @@ def module_stats(
)
stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

module_to_name = dict()
for (name, module) in model.named_modules():
if isinstance(module, hook_modules):
if isinstance(module, tuple(hook_modules)):
hooks.append(
module.register_forward_hook(partial(module_stats_hook, name=name))
)
with set_module_mode_safe(model, training=False) as model:
module_to_name[module] = name

@contextmanager
def param_stat_context():
def wrapper(fun):
@functools.wraps(fun)
def param_access_record(module, item):
member = fun(module, item)
if (
item in ["weight", "bias"]
and member is not None
and member not in recorded_parameters
):
name = module_to_name[module]
if item == "weight":
suffix = "-w"
elif item == "bias":
suffix = "-b"

param_name = name + suffix
param_stats = get_param_stats(member)
param_stats["name"] = param_name
params.append(param_stats)
recorded_parameters.add(member)

return member

return param_access_record

origin_get_attr = object.__getattribute__
try:
M.Module.__getattribute__ = wrapper(origin_get_attr)
yield
finally:
M.Module.__getattribute__ = origin_get_attr

with set_module_mode_safe(model, training=False) as model, param_stat_context():
model(*inputs)

for h in hooks:


+ 36
- 1
imperative/python/test/unit/utils/test_module_stats.py View File

@@ -64,6 +64,36 @@ def test_duplicated_module():
assert net0_stats.param_size == net2_stats.param_size


@pytest.mark.skipif(
use_symbolic_shape(), reason="This test do not support symbolic shape.",
)
def test_getattribute_param():
class MyConvBn(M.Module):
def __init__(self):
super().__init__()
self.in_channels = 64
self.conv1 = M.Conv2d(
3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=True
)
self.bn1 = M.BatchNorm2d(self.in_channels)

def forward(self, input):
input = self.conv1.calc_conv(input, self.conv1.weight, self.conv1.bias)
input = self.bn1(input)
return input

model = MyConvBn()
input_shape = (1, 3, 224, 224)
total_stats, stats_detail = module_stats(model, input_shapes=input_shape)
params = stats_detail.params

def get_name(obj):
return obj["name"]

param_name = list(map(get_name, params))
assert "conv1-w" in param_name and "conv1-b" in param_name


class TestNet0(M.Module):
def __init__(self):
super().__init__()
@@ -108,7 +138,12 @@ class FakeNet(M.Module):
def forward(self, x):
assert isinstance(
x,
(np.ndarray, collections.abc.Mapping, collections.abc.Sequence, mge.Tensor),
(
np.ndarray,
collections.abc.Mapping,
collections.abc.Sequence,
mge.Tensor,
),
) or (isinstance(x, tuple) and hasattr(x, "_fields"))




Loading…
Cancel
Save