BREAKING CHANGE: module_stats and the network visualize tool now take input_shapes instead of input_size and return (total_stats, stats_details) namedtuples instead of the old (total_param_size, total_flops) pair.
GitOrigin-RevId: ced3da3a12
release-1.5
@@ -9,6 +9,7 @@
import argparse
import logging
import re
from collections import namedtuple

import numpy as np

@@ -16,12 +17,17 @@ from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
    enable_receptive_field,
    get_activation_stats,
    get_op_stats,
    get_param_stats,
    print_activations_stats,
    print_op_stats,
    print_param_stats,
    print_summary,
    sizeof_fmt,
    sum_activations_stats,
    sum_op_stats,
    sum_param_stats,
)
from megengine.utils.network import Network
@@ -34,6 +40,7 @@ def visualize(
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
    log_activations: bool = True,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.

@@ -44,6 +51,7 @@ def visualize(
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :param log_params: whether to print and record params size.
    :param log_flops: whether to print and record op flops.
    :param log_activations: whether to print and record op activations.
    """
    if log_path:
        try:
@@ -83,6 +91,10 @@ def visualize(
    node_list = []
    flops_list = []
    params_list = []
    activations_list = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for node in graph.all_oprs:
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]

@@ -124,6 +136,11 @@ def visualize(
            flops_stats["class_name"] = node.type
            flops_list.append(flops_stats)

        acts = get_activation_stats(node_oup.numpy())
        acts["name"] = node.name
        acts["class_name"] = node.type
        activations_list.append(acts)

        if node.type == "ImmutableTensor":
            param_stats = get_param_stats(node.numpy())
            # add tensor size attr

@@ -149,20 +166,36 @@ def visualize(
        "#params": len(params_list),
    }
    total_flops, total_param_dims, total_param_size = 0, 0, 0
    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)
    total_param_dims, total_param_size, params = sum_param_stats(
        params_list, bar_length_max
    )
    extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
    extra_info["total_param_size"] = sizeof_fmt(total_param_size)
    if log_params:
        total_param_dims, total_param_size = print_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        print_param_stats(params)

    total_flops, flops = sum_op_stats(flops_list, bar_length_max)
    extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_flops:
        total_flops = print_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_params and log_flops:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )
        print_op_stats(flops)

    total_act_dims, total_act_size, activations = sum_activations_stats(
        activations_list, bar_length_max
    )
    extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
    extra_info["total_act_size"] = sizeof_fmt(total_act_size)
    if log_activations:
        print_activations_stats(activations)
    extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

@@ -179,7 +212,12 @@ def visualize(
    # FIXME: remove this after resolving "span dist too large" warning
    _imperative_rt_logger.set_log_level(old_level)

    return total_param_size, total_flops
    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(params=params, flops=flops, activations=activations),
    )


def main():
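
For reference, a minimal sketch of how the updated tool's return value might be consumed when called as a library; the import path and file paths here are assumptions, not part of this diff:

    # illustrative only: import location and paths are assumed
    from megengine.tools.network_visualize import visualize

    total_stats, stats_details = visualize(
        "./model.mge", log_path="./tb_logdir", log_activations=True
    )
    print(total_stats.param_size, total_stats.flops, total_stats.act_size)
    print(len(stats_details.activations))  # per-node rows plus the appended "total" row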
@@ -5,7 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
from collections import namedtuple
from functools import partial

import numpy as np
@@ -18,6 +18,8 @@ import megengine.module.quantized as qm
from megengine.core.tensor.dtype import get_dtype_bit
from megengine.functional.tensor import zeros

from .module_utils import set_module_mode_safe

try:
    mge.logger.MegEngineLogFormatter.max_lines = float("inf")
except AttributeError as e:

@@ -98,6 +100,27 @@ def flops_convNd(module: m.Conv2d, inputs, outputs):
    )


@register_flops(
    m.batchnorm._BatchNorm, m.SyncBatchNorm, m.GroupNorm, m.LayerNorm, m.InstanceNorm,
)
def flops_norm(module: m.Linear, inputs, outputs):
    return np.prod(inputs[0].shape) * 7


@register_flops(m.AvgPool2d, m.MaxPool2d)
def flops_pool(module: m.AvgPool2d, inputs, outputs):
    return np.prod(outputs[0].shape) * (module.kernel_size ** 2)


@register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d)
def flops_adaptivePool(module: m.AdaptiveAvgPool2d, inputs, outputs):
    stride_h = np.floor(inputs[0].shape[2] / (inputs[0].shape[2] - 1))
    kernel_h = inputs[0].shape[2] - (inputs[0].shape[2] - 1) * stride_h
    stride_w = np.floor(inputs[0].shape[3] / (inputs[0].shape[3] - 1))
    kernel_w = inputs[0].shape[3] - (inputs[0].shape[3] - 1) * stride_w
    return np.prod(outputs[0].shape) * kernel_h * kernel_w


@register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs):
    bias = module.out_features if module.bias is not None else 0
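
As a rough sanity check of the per-element costs registered above (a sketch with made-up shapes):

    import numpy as np

    # a norm layer is charged 7 FLOPs per input element
    norm_flops = np.prod((1, 3, 224, 224)) * 7          # 1053696
    # a plain 2x2 pooling layer is charged kernel_size ** 2 FLOPs per output element
    pool_flops = np.prod((1, 3, 112, 112)) * (2 ** 2)   # 150528
    print(norm_flops, pool_flops)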
@@ -120,6 +143,12 @@ hook_modules = (
    m.conv._ConvNd,
    m.Linear,
    m.BatchMatMulActivation,
    m.batchnorm._BatchNorm,
    m.LayerNorm,
    m.GroupNorm,
    m.InstanceNorm,
    m.pooling._PoolNd,
    m.adaptive_pooling._AdaptivePoolNd,
)
@@ -137,12 +166,16 @@ def dict2table(list_of_dict, header):
def sizeof_fmt(num, suffix="B"):
    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
        if abs(num) < 1024.0:
    if suffix == "B":
        scale = 1024.0
        units = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi", "Yi"]
    else:
        scale = 1000.0
        units = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]
    for unit in units:
        if abs(num) < scale or unit == units[-1]:
            return "{:3.3f} {}{}".format(num, unit, suffix)
            num /= 1024.0
    sign_str = "-" if num < 0 else ""
    return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)
        num /= scale


def preprocess_receptive_field(module, inputs, outputs):
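
A quick illustration of the new unit handling, following directly from the code above (treat the exact strings as a sketch):

    sizeof_fmt(1536)                     # '1.500 KiB'  (bytes keep 1024-based prefixes)
    sizeof_fmt(3 * 1024 ** 2)            # '3.000 MiB'
    sizeof_fmt(2_000_000, suffix="OPs")  # '2.000 MOPs' (other suffixes use 1000-based prefixes)
    sizeof_fmt(512, suffix="OPs")        # '512.000 OPs'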
@@ -159,6 +192,8 @@ def preprocess_receptive_field(module, inputs, outputs):
def get_op_stats(module, inputs, outputs):
    if not isinstance(outputs, tuple) and not isinstance(outputs, list):
        outputs = (outputs,)
    rst = {
        "input_shapes": [i.shape for i in inputs],
        "output_shapes": [o.shape for o in outputs],

@@ -189,7 +224,7 @@ def get_op_stats(module, inputs, outputs):
    return


def print_op_stats(flops, bar_length_max=20):
def sum_op_stats(flops, bar_length_max=20):
    max_flops_num = max([i["flops_num"] for i in flops] + [0])
    total_flops_num = 0
    for d in flops:

@@ -203,6 +238,18 @@ def print_op_stats(flops, bar_length_max=20):
        d["bar"] = "#" * bar_length
        d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")

    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
    total_var_size = sum(
        sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
    )
    flops.append(
        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
    )

    return total_flops_num, flops


def print_op_stats(flops):
    header = [
        "name",
        "class_name",

@@ -216,19 +263,8 @@ def print_op_stats(flops, bar_length_max=20):
    if _receptive_field_enabled:
        header.insert(4, "receptive_field")
        header.insert(5, "stride")
    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
    total_var_size = sum(
        sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
    )
    flops.append(
        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
    )
    logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)))
    return total_flops_num


def get_param_stats(param: np.ndarray):
    nbits = get_dtype_bit(param.dtype.name)
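
The aggregation is now split from the printing, so callers can build the summary rows without logging them; a sketch of the intended call pattern, mirroring the call sites changed elsewhere in this diff:

    total_flops, flops_rows = sum_op_stats(flops_list, bar_length_max=20)
    if log_flops:
        print_op_stats(flops_rows)

    total_param_dims, total_param_size, param_rows = sum_param_stats(
        params_list, bar_length_max=20
    )
    if log_params:
        print_param_stats(param_rows)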
@@ -246,7 +282,7 @@ def get_param_stats(param: np.ndarray):
    }


def print_param_stats(params, bar_length_max=20):
def sum_param_stats(params, bar_length_max=20):
    max_size = max([d["size"] for d in params] + [0])
    total_param_dims, total_param_size = 0, 0
    for d in params:

@@ -265,6 +301,10 @@ def print_param_stats(params, bar_length_max=20):
    param_size = sizeof_fmt(total_param_size)
    params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))

    return total_param_dims, total_param_size, params


def print_param_stats(params):
    header = [
        "name",
        "dtype",

@@ -272,18 +312,74 @@ def print_param_stats(params, bar_length_max=20):
        "mean",
        "std",
        "param_dim",
        "bits",
        "nbits",
        "size",
        "size_cum",
        "percentage",
        "size_bar",
    ]
    logger.info(
        "param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
    )
    return total_param_dims, total_param_size


def get_activation_stats(output: np.ndarray):
    out_shape = output.shape
    activations_dtype = output.dtype
    nbits = get_dtype_bit(activations_dtype.name)
    act_dim = np.prod(out_shape)
    act_size = act_dim * nbits // 8
    return {
        "dtype": activations_dtype,
        "shape": out_shape,
        "act_dim": act_dim,
        "mean": "{:.3g}".format(output.mean()),
        "std": "{:.3g}".format(output.std()),
        "nbits": nbits,
        "size": act_size,
    }


def sum_activations_stats(activations, bar_length_max=20):
    max_act_size = max([i["size"] for i in activations] + [0])
    total_act_dims, total_act_size = 0, 0
    for d in activations:
        total_act_size += int(d["size"])
        total_act_dims += int(d["act_dim"])
        d["size_cum"] = sizeof_fmt(total_act_size)

    for d in activations:
        ratio = d["ratio"] = d["size"] / total_act_size
        d["percentage"] = "{:.2f}%".format(ratio * 100)
        bar_length = int(d["size"] / max_act_size * bar_length_max)
        d["size_bar"] = "#" * bar_length
        d["size"] = sizeof_fmt(d["size"])

    act_size = sizeof_fmt(total_act_size)
    activations.append(dict(name="total", act_dim=total_act_dims, size=act_size,))

    return total_act_dims, total_act_size, activations


def print_activations_stats(activations):
    header = [
        "name",
        "class_name",
        "dtype",
        "shape",
        "mean",
        "std",
        "nbits",
        "act_dim",
        "size",
        "size_cum",
        "percentage",
        "size_bar",
    ]
    logger.info(
        "activations stats: \n"
        + tabulate.tabulate(dict2table(activations, header=header))
    )


def print_summary(**kwargs):
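
For reference, the kind of record get_activation_stats produces for a single output array (a sketch; assumes get_dtype_bit reports 32 for float32):

    out = np.zeros((2, 3, 4), dtype=np.float32)
    stats = get_activation_stats(out)
    # 24 elements * 32 bits // 8 -> 96 bytes:
    # {"dtype": dtype('float32'), "shape": (2, 3, 4), "act_dim": 24,
    #  "mean": "0", "std": "0", "nbits": 32, "size": 96}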
@@ -294,25 +390,26 @@ def print_summary(**kwargs):
def module_stats(
    model: m.Module,
    input_size: int,
    input_shapes: list,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
    log_activations: bool = True,
):
    r"""
    Calculate and print ``model``'s statistics by adding hooks that record the sizes of each module's inputs and outputs.

    :param model: model that needs to get stats info.
    :param input_size: size of input for running model and calculating stats.
    :param input_shapes: shapes of inputs for running model and calculating stats.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :param log_params: whether to print and record params size.
    :param log_flops: whether to print and record op flops.
    :param log_activations: whether to print and record op activations.
    """
    disable_receptive_field()

    def module_stats_hook(module, inputs, outputs, name=""):
        class_name = str(module.__class__).split(".")[-1].split("'")[0]

        flops_stats = get_op_stats(module, inputs, outputs)
        if flops_stats is not None:
            flops_stats["name"] = name
@@ -331,38 +428,25 @@ def module_stats(
            param_stats["name"] = name + "-b"
            params.append(param_stats)

    @contextlib.contextmanager
    def adjust_stats(module, training=False):
        """Adjust module to training/eval mode temporarily.

        Args:
            module (M.Module): used module.
            training (bool): training mode. True for train mode, False for eval mode.
        """

        def recursive_backup_stats(module, mode):
            for m in module.modules():
                # save prev status to _prev_training
                m._prev_training = m.training
                m.train(mode, recursive=False)

        def recursive_recover_stats(module):
            for m in module.modules():
                # recover prev status and delete attribute
                m.training = m._prev_training
                delattr(m, "_prev_training")

        recursive_backup_stats(module, mode=training)
        yield module
        recursive_recover_stats(module)

        if not isinstance(outputs, tuple) and not isinstance(outputs, list):
            output = outputs.numpy()
        else:
            output = outputs[0].numpy()
        activation_stats = get_activation_stats(output)
        activation_stats["name"] = name
        activation_stats["class_name"] = class_name
        activations.append(activation_stats)
    # multiple inputs to the network
    if not isinstance(input_size[0], tuple):
        input_size = [input_size]
    if not isinstance(input_shapes[0], tuple):
        input_shapes = [input_shapes]
    params = []
    flops = []
    hooks = []
    activations = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for (name, module) in model.named_modules():
        if isinstance(module, hook_modules):
@@ -370,8 +454,8 @@ def module_stats(
                module.register_forward_hook(partial(module_stats_hook, name=name))
            )

    inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
    with adjust_stats(model, training=False) as model:
    inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes]
    with set_module_mode_safe(model, training=False) as model:
        model(*inputs)

    for h in hooks:

@@ -380,19 +464,40 @@ def module_stats(
    extra_info = {
        "#params": len(params),
    }
    total_flops, total_param_dims, total_param_size = 0, 0, 0
    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)
    total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max)
    extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
    extra_info["total_param_size"] = sizeof_fmt(total_param_size)
    if log_params:
        total_param_dims, total_param_size = print_param_stats(params, bar_length_max)
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        print_param_stats(params)

    total_flops, flops = sum_op_stats(flops, bar_length_max)
    extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_flops:
        total_flops = print_op_stats(flops, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_params and log_flops:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )
        print_op_stats(flops)

    total_act_dims, total_act_size, activations = sum_activations_stats(
        activations, bar_length_max
    )
    extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
    extra_info["total_act_size"] = sizeof_fmt(total_act_size)
    if log_activations:
        print_activations_stats(activations)
    extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)

    print_summary(**extra_info)

    return total_param_size, total_flops
    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(params=params, flops=flops, activations=activations),
    )
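
The return-value and argument changes above are the breaking part of this commit; a before/after sketch (any Module works as the model, the shape is illustrative):

    from megengine.utils.module_stats import module_stats

    # before: total_param_size, total_flops = module_stats(model, (1, 3, 224, 224))
    # after:
    total_stats, stats_details = module_stats(model, (1, 3, 224, 224))
    print(total_stats.param_size, total_stats.flops, total_stats.act_size)
    print(stats_details.params[-1])       # the appended "total" row
    print(stats_details.activations[-1])  # total activation dims and size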
@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
from collections import Iterable

from ..module import Sequential

@@ -41,3 +42,28 @@ def set_expand_structure(obj: Module, key: str, value):
            parent[key] = value

    _access_structure(obj, key, callback=f)


@contextlib.contextmanager
def set_module_mode_safe(
    module: Module, training: bool = False,
):
    """Adjust module to training/eval mode temporarily.

    :param module: used module.
    :param training: training mode. True for train mode, False for eval mode.
    """
    backup_stats = {}

    def recursive_backup_stats(module, mode):
        for m in module.modules():
            backup_stats[m] = m.training
            m.train(mode, recursive=False)

    def recursive_recover_stats(module):
        for m in module.modules():
            m.training = backup_stats.pop(m)

    recursive_backup_stats(module, mode=training)
    yield module
    recursive_recover_stats(module)
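
A short usage sketch of the new context manager; the full import path is inferred from the relative import above, and the toy network is a placeholder:

    import numpy as np
    import megengine.module as M
    from megengine.functional.tensor import zeros
    from megengine.utils.module_utils import set_module_mode_safe

    net = M.Sequential(M.Conv2d(3, 8, 3), M.BatchNorm2d(8))
    net.train()
    with set_module_mode_safe(net, training=False) as m:
        m(zeros((1, 3, 32, 32), dtype=np.float32))  # forward pass runs in eval mode
    assert net.training is True  # original modes are restored on exit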
@@ -0,0 +1,377 @@
import math
from copy import deepcopy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats
@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_module_stats():
    net = ResNet(BasicBlock, [2, 2, 2, 2])
    input_shape = (1, 3, 224, 224)
    total_stats, stats_details = module_stats(net, input_shape)
    x1 = mge.tensor(np.zeros((1, 3, 224, 224)))
    gt_flops, gt_acts = net.get_stats(x1)
    assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
        gt_flops,
        gt_acts,
    )
class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.tmp_in_channels = in_channels
        self.tmp_channels = channels
        self.stride = stride

        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        self.downsample_id = M.Identity()
        self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
        self.downsample_norm = norm(channels)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            identity = self.downsample_conv(identity)
            identity = self.downsample_norm(identity)
        x += identity
        x = F.relu(x)
        return x

    def get_stats(self, x):
        activations, flops = 0, 0

        identity = x

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.conv2(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn2(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            in_x = deepcopy(identity)
            identity = self.downsample_conv(identity)
            tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops
            in_x = deepcopy(identity)
            identity = self.downsample_norm(identity)
            tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

        x += identity
        x = F.relu(x)
        return x, flops, activations
class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers=[2, 2, 2, 2],
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_0 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer1_1 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer2_0 = BasicBlock(64, 128, stride=2)
        self.layer2_1 = BasicBlock(128, 128)
        self.layer3_0 = BasicBlock(128, 256, stride=2)
        self.layer3_1 = BasicBlock(256, 256)
        self.layer4_0 = BasicBlock(256, 512, stride=2)
        self.layer4_1 = BasicBlock(512, 512)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
        )
        self.layer3 = self._make_layer(
            block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
        )
        self.layer4 = self._make_layer(
            block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
        )
        self.fc = M.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        if zero_init_residual:
            for m in self.modules():
                M.init.zeros_(m.bn2.weight)
    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        layers = []
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )

        return M.Sequential(*layers)

    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x

        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        x = self.extract_features(x)["res5"]

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x
    def get_stats(self, x):
        flops, activations = 0, 0

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.maxpool(x)
        tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)

        in_x = deepcopy(x)
        x = self.fc(x)
        tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        return flops, activations
def cal_conv_stats(module, input, output):
    bias = 1 if module.bias is not None else 0
    flops = np.prod(output[0].shape) * (
        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
    )
    acts = np.prod(output[0].shape)
    return flops, acts


def cal_norm_stats(module, input, output):
    return np.prod(input[0].shape) * 7, np.prod(output[0].shape)


def cal_linear_stats(module, inputs, outputs):
    bias = module.out_features if module.bias is not None else 0
    return (
        np.prod(outputs[0].shape) * module.in_features + bias,
        np.prod(outputs[0].shape),
    )


def cal_pool_stats(module, inputs, outputs):
    return (
        np.prod(outputs[0].shape) * (module.kernel_size ** 2),
        np.prod(outputs[0].shape),
    )
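
The new test can be exercised on its own with pytest; the file name below is assumed, since the diff header does not show the path:

    pytest -q test_module_stats.py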