|
|
@@ -1,5 +1,7 @@ |
|
|
|
import collections |
|
|
|
import functools |
|
|
|
from collections import namedtuple |
|
|
|
from contextlib import contextmanager |
|
|
|
from functools import partial |
|
|
|
from typing import Iterable |
|
|
|
|
|
|
@@ -22,7 +24,6 @@ except AttributeError as e: |
|
|
|
logger = get_logger(__name__) |
|
|
|
logger.setLevel("INFO") |
|
|
|
|
|
|
|
|
|
|
|
_calc_flops_dict = {} |
|
|
|
_calc_receptive_field_dict = {} |
|
|
|
|
|
|
@@ -147,7 +148,7 @@ def flops_batchmatmul(module: M.BatchMatMulActivation, inputs, outputs): |
|
|
|
|
|
|
|
|
|
|
|
# does not need import qat and quantized module since they inherit from float module. |
|
|
|
hook_modules = ( |
|
|
|
hook_modules = [ |
|
|
|
M.conv._ConvNd, |
|
|
|
M.Linear, |
|
|
|
M.BatchMatMulActivation, |
|
|
@@ -157,7 +158,18 @@ hook_modules = ( |
|
|
|
M.InstanceNorm, |
|
|
|
M.pooling._PoolNd, |
|
|
|
M.adaptive_pooling._AdaptivePoolNd, |
|
|
|
) |
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
def register_hook_module(module): |
|
|
|
if isinstance(module, (tuple, list)): |
|
|
|
modules = list(module) |
|
|
|
for module in modules: |
|
|
|
register_hook_module(module) |
|
|
|
elif isinstance(module, M.Module): |
|
|
|
hook_modules.append(module) |
|
|
|
else: |
|
|
|
raise TypeError("the param type should in [list,tuple,M.Module]") |
|
|
|
|
|
|
|
|
|
|
|
def _mean(inp): |
|
|
@@ -519,12 +531,49 @@ def module_stats( |
|
|
|
) |
|
|
|
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) |
|
|
|
|
|
|
|
module_to_name = dict() |
|
|
|
for (name, module) in model.named_modules(): |
|
|
|
if isinstance(module, hook_modules): |
|
|
|
if isinstance(module, tuple(hook_modules)): |
|
|
|
hooks.append( |
|
|
|
module.register_forward_hook(partial(module_stats_hook, name=name)) |
|
|
|
) |
|
|
|
with set_module_mode_safe(model, training=False) as model: |
|
|
|
module_to_name[module] = name |
|
|
|
|
|
|
|
@contextmanager |
|
|
|
def param_stat_context(): |
|
|
|
def wrapper(fun): |
|
|
|
@functools.wraps(fun) |
|
|
|
def param_access_record(module, item): |
|
|
|
member = fun(module, item) |
|
|
|
if ( |
|
|
|
item in ["weight", "bias"] |
|
|
|
and member is not None |
|
|
|
and member not in recorded_parameters |
|
|
|
): |
|
|
|
name = module_to_name[module] |
|
|
|
if item == "weight": |
|
|
|
suffix = "-w" |
|
|
|
elif item == "bias": |
|
|
|
suffix = "-b" |
|
|
|
|
|
|
|
param_name = name + suffix |
|
|
|
param_stats = get_param_stats(member) |
|
|
|
param_stats["name"] = param_name |
|
|
|
params.append(param_stats) |
|
|
|
recorded_parameters.add(member) |
|
|
|
|
|
|
|
return member |
|
|
|
|
|
|
|
return param_access_record |
|
|
|
|
|
|
|
origin_get_attr = object.__getattribute__ |
|
|
|
try: |
|
|
|
M.Module.__getattribute__ = wrapper(origin_get_attr) |
|
|
|
yield |
|
|
|
finally: |
|
|
|
M.Module.__getattribute__ = origin_get_attr |
|
|
|
|
|
|
|
with set_module_mode_safe(model, training=False) as model, param_stat_context(): |
|
|
|
model(*inputs) |
|
|
|
|
|
|
|
for h in hooks: |
|
|
|