Browse Source

fix(mge/tools): fix module stats' receptive field bug for Module

GitOrigin-RevId: b471363830
release-1.3
Megvii Engine Team 4 years ago
parent
commit
245a3f8129
2 changed files with 35 additions and 16 deletions
  1. +9
    -6
      imperative/python/megengine/tools/network_visualize.py
  2. +26
    -10
      imperative/python/megengine/utils/module_stats.py

+ 9
- 6
imperative/python/megengine/tools/network_visualize.py View File

@@ -15,10 +15,11 @@ import numpy as np
from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
get_flops_stats,
enable_receptive_field,
get_op_stats,
get_param_stats,
print_flops_stats,
print_params_stats,
print_op_stats,
print_param_stats,
print_summary,
sizeof_fmt,
)
@@ -68,6 +69,8 @@ def visualize(
# FIXME: remove this after resolving "span dist too large" warning
old_level = set_mgb_log_level(logging.ERROR)

enable_receptive_field()

graph = Network.load(model_path)

def process_name(name):
@@ -110,7 +113,7 @@ def visualize(
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
flops_stats = get_flops_stats(node, node.inputs, node.outputs)
flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path and hasattr(flops_stats, "flops_num"):
@@ -148,13 +151,13 @@ def visualize(

total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params:
total_param_dims, total_param_size = print_params_stats(
total_param_dims, total_param_size = print_param_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max)
total_flops = print_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(


+ 26
- 10
imperative/python/megengine/utils/module_stats.py View File

@@ -31,6 +31,8 @@ _calc_receptive_field_dict = {}


def _receptive_field_fallback(module, inputs, outputs):
if not _receptive_field_enabled:
return
assert not hasattr(module, "_rf")
assert not hasattr(module, "_stride")
if len(inputs) == 0:
@@ -54,6 +56,8 @@ _iter_list = [
),
]

_receptive_field_enabled = False


def _register_dict(*modules, dict=None):
def callback(impl):
@@ -72,6 +76,16 @@ def register_receptive_field(*modules):
return _register_dict(*modules, dict=_calc_receptive_field_dict)


def enable_receptive_field():
global _receptive_field_enabled
_receptive_field_enabled = True


def disable_receptive_field():
global _receptive_field_enabled
_receptive_field_enabled = False


@register_flops(
m.Conv1d, m.Conv2d, m.Conv3d,
)
@@ -144,16 +158,16 @@ def preprocess_receptive_field(module, inputs, outputs):
# TODO: support other dimensions
pre_rf = (
max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs),
max(i.owner._rf[1] for i in inputs),
max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs),
)
pre_stride = (
max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs),
max(i.owner._stride[1] for i in inputs),
max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs),
)
return pre_rf, pre_stride


def get_flops_stats(module, inputs, outputs):
def get_op_stats(module, inputs, outputs):
rst = {
"input_shapes": [i.shape for i in inputs],
"output_shapes": [o.shape for o in outputs],
@@ -184,7 +198,7 @@ def get_flops_stats(module, inputs, outputs):
return


def print_flops_stats(flops, bar_length_max=20):
def print_op_stats(flops, bar_length_max=20):
max_flops_num = max([i["flops_num"] for i in flops] + [0])
total_flops_num = 0
for d in flops:
@@ -203,13 +217,14 @@ def print_flops_stats(flops, bar_length_max=20):
"class_name",
"input_shapes",
"output_shapes",
"receptive_field",
"stride",
"flops",
"flops_cum",
"percentage",
"bar",
]
if _receptive_field_enabled:
header.insert(4, "receptive_field")
header.insert(5, "stride")

total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(
@@ -240,7 +255,7 @@ def get_param_stats(param: np.ndarray):
}


def print_params_stats(params, bar_length_max=20):
def print_param_stats(params, bar_length_max=20):
max_size = max([d["size"] for d in params] + [0])
total_param_dims, total_param_size = 0, 0
for d in params:
@@ -302,11 +317,12 @@ def module_stats(
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
"""
disable_receptive_field()

def module_stats_hook(module, inputs, outputs, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]

flops_stats = get_flops_stats(module, inputs, outputs)
flops_stats = get_op_stats(module, inputs, outputs)
if flops_stats is not None:
flops_stats["name"] = name
flops_stats["class_name"] = class_name
@@ -349,11 +365,11 @@ def module_stats(
}
total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params:
total_param_dims, total_param_size = print_params_stats(params, bar_length_max)
total_param_dims, total_param_size = print_param_stats(params, bar_length_max)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops:
total_flops = print_flops_stats(flops, bar_length_max)
total_flops = print_op_stats(flops, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(


Loading…
Cancel
Save