Browse Source

fix(mge/utils): fix module_stats missing param problem

GitOrigin-RevId: abbf4935da
master
Megvii Engine Team 2 years ago
parent
commit
069e4e07a1
2 changed files with 90 additions and 6 deletions
  1. +54
    -5
      imperative/python/megengine/utils/module_stats.py
  2. +36
    -1
      imperative/python/test/unit/utils/test_module_stats.py

+ 54
- 5
imperative/python/megengine/utils/module_stats.py View File

@@ -1,5 +1,7 @@
import collections import collections
import functools
from collections import namedtuple from collections import namedtuple
from contextlib import contextmanager
from functools import partial from functools import partial
from typing import Iterable from typing import Iterable


@@ -22,7 +24,6 @@ except AttributeError as e:
logger = get_logger(__name__) logger = get_logger(__name__)
logger.setLevel("INFO") logger.setLevel("INFO")



_calc_flops_dict = {} _calc_flops_dict = {}
_calc_receptive_field_dict = {} _calc_receptive_field_dict = {}


@@ -147,7 +148,7 @@ def flops_batchmatmul(module: M.BatchMatMulActivation, inputs, outputs):




# does not need import qat and quantized module since they inherit from float module. # does not need import qat and quantized module since they inherit from float module.
hook_modules = (
hook_modules = [
M.conv._ConvNd, M.conv._ConvNd,
M.Linear, M.Linear,
M.BatchMatMulActivation, M.BatchMatMulActivation,
@@ -157,7 +158,18 @@ hook_modules = (
M.InstanceNorm, M.InstanceNorm,
M.pooling._PoolNd, M.pooling._PoolNd,
M.adaptive_pooling._AdaptivePoolNd, M.adaptive_pooling._AdaptivePoolNd,
)
]


def register_hook_module(module):
if isinstance(module, (tuple, list)):
modules = list(module)
for module in modules:
register_hook_module(module)
elif isinstance(module, M.Module):
hook_modules.append(module)
else:
raise TypeError("the param type should in [list,tuple,M.Module]")




def _mean(inp): def _mean(inp):
@@ -519,12 +531,49 @@ def module_stats(
) )
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) stats_details = namedtuple("module_stats", ["params", "flops", "activations"])


module_to_name = dict()
for (name, module) in model.named_modules(): for (name, module) in model.named_modules():
if isinstance(module, hook_modules):
if isinstance(module, tuple(hook_modules)):
hooks.append( hooks.append(
module.register_forward_hook(partial(module_stats_hook, name=name)) module.register_forward_hook(partial(module_stats_hook, name=name))
) )
with set_module_mode_safe(model, training=False) as model:
module_to_name[module] = name

@contextmanager
def param_stat_context():
def wrapper(fun):
@functools.wraps(fun)
def param_access_record(module, item):
member = fun(module, item)
if (
item in ["weight", "bias"]
and member is not None
and member not in recorded_parameters
):
name = module_to_name[module]
if item == "weight":
suffix = "-w"
elif item == "bias":
suffix = "-b"

param_name = name + suffix
param_stats = get_param_stats(member)
param_stats["name"] = param_name
params.append(param_stats)
recorded_parameters.add(member)

return member

return param_access_record

origin_get_attr = object.__getattribute__
try:
M.Module.__getattribute__ = wrapper(origin_get_attr)
yield
finally:
M.Module.__getattribute__ = origin_get_attr

with set_module_mode_safe(model, training=False) as model, param_stat_context():
model(*inputs) model(*inputs)


for h in hooks: for h in hooks:


+ 36
- 1
imperative/python/test/unit/utils/test_module_stats.py View File

@@ -64,6 +64,36 @@ def test_duplicated_module():
assert net0_stats.param_size == net2_stats.param_size assert net0_stats.param_size == net2_stats.param_size




@pytest.mark.skipif(
use_symbolic_shape(), reason="This test do not support symbolic shape.",
)
def test_getattribute_param():
class MyConvBn(M.Module):
def __init__(self):
super().__init__()
self.in_channels = 64
self.conv1 = M.Conv2d(
3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=True
)
self.bn1 = M.BatchNorm2d(self.in_channels)

def forward(self, input):
input = self.conv1.calc_conv(input, self.conv1.weight, self.conv1.bias)
input = self.bn1(input)
return input

model = MyConvBn()
input_shape = (1, 3, 224, 224)
total_stats, stats_detail = module_stats(model, input_shapes=input_shape)
params = stats_detail.params

def get_name(obj):
return obj["name"]

param_name = list(map(get_name, params))
assert "conv1-w" in param_name and "conv1-b" in param_name


class TestNet0(M.Module): class TestNet0(M.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -108,7 +138,12 @@ class FakeNet(M.Module):
def forward(self, x): def forward(self, x):
assert isinstance( assert isinstance(
x, x,
(np.ndarray, collections.abc.Mapping, collections.abc.Sequence, mge.Tensor),
(
np.ndarray,
collections.abc.Mapping,
collections.abc.Sequence,
mge.Tensor,
),
) or (isinstance(x, tuple) and hasattr(x, "_fields")) ) or (isinstance(x, tuple) and hasattr(x, "_fields"))






Loading…
Cancel
Save