import collections from collections import namedtuple from functools import partial from typing import Iterable import numpy as np import tabulate from .. import Tensor from .. import functional as F from .. import get_logger from .. import module as M from ..core.tensor.dtype import get_dtype_bit from ..logger import MegEngineLogFormatter from .module_utils import set_module_mode_safe try: MegEngineLogFormatter.max_lines = float("inf") except AttributeError as e: raise ValueError("set logger max lines failed") logger = get_logger(__name__) logger.setLevel("INFO") _calc_flops_dict = {} _calc_receptive_field_dict = {} def _receptive_field_fallback(module, inputs, outputs): if not _receptive_field_enabled: return assert not hasattr(module, "_rf") assert not hasattr(module, "_stride") if len(inputs) == 0: # TODO: support other dimension module._rf = (1, 1) module._stride = (1, 1) return module._rf, module._stride rf, stride = preprocess_receptive_field(module, inputs, outputs) module._rf = rf module._stride = stride return rf, stride # key tuple, impl_dict, fallback _iter_list = [ ("flops_num", _calc_flops_dict, None), ( ("receptive_field", "stride"), _calc_receptive_field_dict, _receptive_field_fallback, ), ] _receptive_field_enabled = False def _register_dict(*modules, dict=None): def callback(impl): for module in modules: dict[module] = impl return impl return callback def register_flops(*modules): return _register_dict(*modules, dict=_calc_flops_dict) def register_receptive_field(*modules): return _register_dict(*modules, dict=_calc_receptive_field_dict) def enable_receptive_field(): global _receptive_field_enabled _receptive_field_enabled = True def disable_receptive_field(): global _receptive_field_enabled _receptive_field_enabled = False @register_flops(M.Conv1d, M.Conv2d, M.Conv3d, M.LocalConv2d, M.DeformableConv2d) def flops_convNd(module: M.Conv2d, inputs, outputs): bias = 1 if module.bias is not None else 0 # N x Cout x H x W x (Cin x Kw x Kh + bias) return np.prod(outputs[0].shape) * ( float(module.in_channels // module.groups) * np.prod(module.kernel_size) + bias ) @register_flops(M.ConvTranspose2d) def flops_convNdTranspose(module: M.Conv2d, inputs, outputs): bias = 1 if module.bias is not None else 0 # N x Cout x H x W x (Cin x Kw x Kh + bias) return ( np.prod(inputs[0].shape) * (module.out_channels // module.groups * np.prod(module.kernel_size)) + np.prod(outputs[0].shape) * bias ) @register_flops( M.batchnorm._BatchNorm, M.SyncBatchNorm, M.GroupNorm, M.LayerNorm, M.InstanceNorm, ) def flops_norm(module: M.Linear, inputs, outputs): return np.prod(inputs[0].shape) * 7 @register_flops(M.AvgPool2d, M.MaxPool2d) def flops_pool(module: M.AvgPool2d, inputs, outputs): kernel_sum = 0 if isinstance(module.kernel_size, tuple) and len(module.kernel_size) == 2: kernel_sum = np.prod(module.kernel_size) else: kernel_sum = module.kernel_size ** 2 return np.prod(outputs[0].shape) * kernel_sum @register_flops(M.AdaptiveAvgPool2d, M.AdaptiveMaxPool2d) def flops_adaptivePool(module: M.AdaptiveAvgPool2d, inputs, outputs): stride_h = np.floor(inputs[0].shape[2] / (inputs[0].shape[2] - 1)) kernel_h = inputs[0].shape[2] - (inputs[0].shape[2] - 1) * stride_h stride_w = np.floor(inputs[0].shape[3] / (inputs[0].shape[3] - 1)) kernel_w = inputs[0].shape[3] - (inputs[0].shape[3] - 1) * stride_w return np.prod(outputs[0].shape) * kernel_h * kernel_w @register_flops(M.Linear) def flops_linear(module: M.Linear, inputs, outputs): bias = module.out_features if module.bias is not None else 0 return np.prod(outputs[0].shape) * module.in_features + bias @register_flops(M.BatchMatMulActivation) def flops_batchmatmul(module: M.BatchMatMulActivation, inputs, outputs): bias = 1 if module.bias is not None else 0 x = inputs[0] w = module.weight batch_size = x.shape[0] n, p = x.shape[1:] _, m = w.shape[1:] return n * (p + bias) * m * batch_size # does not need import qat and quantized module since they inherit from float module. hook_modules = ( M.conv._ConvNd, M.Linear, M.BatchMatMulActivation, M.batchnorm._BatchNorm, M.LayerNorm, M.GroupNorm, M.InstanceNorm, M.pooling._PoolNd, M.adaptive_pooling._AdaptivePoolNd, ) def _mean(inp): inp = Tensor(inp).astype(np.float32) return F.mean(inp).numpy() def _std(inp): inp = Tensor(inp).astype(np.float32) return F.std(inp).numpy() def dict2table(list_of_dict, header): table_data = [header] for d in list_of_dict: row = [] for h in header: v = "" if h in d: v = d[h] row.append(v) table_data.append(row) return table_data def sizeof_fmt(num, suffix="B"): if suffix == "B": scale = 1024.0 units = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi", "Yi"] else: scale = 1000.0 units = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"] for unit in units: if abs(num) < scale or unit == units[-1]: return "{:3.3f} {}{}".format(num, unit, suffix) num /= scale def preprocess_receptive_field(module, inputs, outputs): # TODO: support other dimensions pre_rf = ( max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs), max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs), ) pre_stride = ( max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs), max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs), ) return pre_rf, pre_stride def get_op_stats(module, inputs, outputs): if not isinstance(outputs, tuple) and not isinstance(outputs, list): outputs = (outputs,) rst = { "input_shapes": [i.shape for i in inputs], "output_shapes": [o.shape for o in outputs], } valid_flag = False for key, _dict, fallback in _iter_list: for _type in _dict: if isinstance(module, _type): value = _dict[_type](module, inputs, outputs) valid_flag = True break else: if fallback is not None: value = fallback(module, inputs, outputs) continue if isinstance(key, tuple): assert isinstance(value, tuple) for k, v in zip(key, value): rst[k] = v else: rst[key] = value if valid_flag: return rst else: return None return def sum_op_stats(flops, bar_length_max=20): max_flops_num = max([i["flops_num"] for i in flops] + [0]) total_flops_num = 0 for d in flops: total_flops_num += int(d["flops_num"]) d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs") for d in flops: ratio = d["ratio"] = d["flops_num"] / total_flops_num d["percentage"] = "{:.2f}%".format(ratio * 100) bar_length = int(d["flops_num"] / max_flops_num * bar_length_max) d["bar"] = "#" * bar_length d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs") total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") total_var_size = sum( sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops ) flops.append( dict(name="total", flops=total_flops_str, output_shapes=total_var_size) ) return total_flops_num, flops def print_op_stats(flops): header = [ "name", "class_name", "input_shapes", "output_shapes", "flops", "flops_cum", "percentage", "bar", ] if _receptive_field_enabled: header.insert(4, "receptive_field") header.insert(5, "stride") logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header))) def get_param_stats(param: Tensor): nbits = get_dtype_bit(np.dtype(param.dtype).name) shape = param.shape param_dim = np.prod(param.shape) param_size = param_dim * nbits // 8 return { "dtype": np.dtype(param.dtype), "shape": shape, "mean": "{:.3g}".format(_mean(param)), "std": "{:.3g}".format(_std(param)), "param_dim": param_dim, "nbits": nbits, "size": param_size, } def sum_param_stats(params, bar_length_max=20): max_size = max([d["size"] for d in params] + [0]) total_param_dims, total_param_size = 0, 0 for d in params: total_param_dims += int(d["param_dim"]) total_param_size += int(d["size"]) d["size_cum"] = sizeof_fmt(total_param_size) for d in params: ratio = d["size"] / total_param_size d["ratio"] = ratio d["percentage"] = "{:.2f}%".format(ratio * 100) bar_length = int(d["size"] / max_size * bar_length_max) d["size_bar"] = "#" * bar_length d["size"] = sizeof_fmt(d["size"]) param_size = sizeof_fmt(total_param_size) params.append(dict(name="total", param_dim=total_param_dims, size=param_size,)) return total_param_dims, total_param_size, params def print_param_stats(params): header = [ "name", "dtype", "shape", "mean", "std", "param_dim", "nbits", "size", "size_cum", "percentage", "size_bar", ] logger.info( "param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) ) def get_activation_stats(output: Tensor, has_input=False): out_shape = output.shape activations_dtype = np.dtype(output.dtype) nbits = get_dtype_bit(activations_dtype.name) act_dim = np.prod(out_shape) act_size = act_dim * nbits // 8 activation_stats = { "dtype": activations_dtype, "shape": out_shape, "act_dim": act_dim, "nbits": nbits, "size": act_size, } if has_input: activation_stats["mean"] = "{:.3g}".format(_mean(output)) activation_stats["std"] = "{:.3g}".format(_std(output)) return activation_stats def sum_activations_stats(activations, bar_length_max=20): max_act_size = max([i["size"] for i in activations] + [0]) total_act_dims, total_act_size = 0, 0 for d in activations: total_act_size += int(d["size"]) total_act_dims += int(d["act_dim"]) d["size_cum"] = sizeof_fmt(total_act_size) for d in activations: ratio = d["ratio"] = d["size"] / total_act_size d["percentage"] = "{:.2f}%".format(ratio * 100) bar_length = int(d["size"] / max_act_size * bar_length_max) d["size_bar"] = "#" * bar_length d["size"] = sizeof_fmt(d["size"]) act_size = sizeof_fmt(total_act_size) activations.append(dict(name="total", act_dim=total_act_dims, size=act_size,)) return total_act_dims, total_act_size, activations def print_activations_stats(activations, has_input=False): header = [ "name", "class_name", "dtype", "shape", "nbits", "act_dim", "size", "size_cum", "percentage", "size_bar", ] if has_input: header.insert(4, "mean") header.insert(5, "std") logger.info( "activations stats: \n" + tabulate.tabulate(dict2table(activations, header=header)) ) def print_summary(**kwargs): data = [["item", "value"]] data.extend(list(kwargs.items())) logger.info("summary\n" + tabulate.tabulate(data)) def module_stats( model: M.Module, inputs: Iterable[np.ndarray] = None, input_shapes: list = None, cal_params: bool = True, cal_flops: bool = True, cal_activations: bool = True, logging_to_stdout: bool = True, bar_length_max: int = 20, ): r"""Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. Args: model: model that need to get stats info. inputs: user defined input data for running model and calculating stats, alternative with input_shapes. input_shapes: shapes to generate random inputs for running model and calculating stats, alternative with inputs. cal_params: whether calculate and record params size. cal_flops: whether calculate and record op flops. cal_activations: whether calculate and record op activations. logging_to_stdout: whether print all calculated statistic details. bar_length_max: size of bar indicating max flops or parameter size in net stats. """ has_inputs = False if inputs is not None: has_inputs = True if not isinstance(inputs, (tuple, list)): inputs = [inputs] def load_tensor(x): if isinstance(x, np.ndarray): return Tensor(x) elif isinstance(x, collections.abc.Mapping): return {k: load_tensor(v) for k, v in x.items()} elif isinstance(x, tuple) and hasattr(x, "_fields"): # nametuple return type(x)(*(load_tensor(value) for value in x)) elif isinstance(x, collections.abc.Sequence): return [load_tensor(v) for v in x] else: return Tensor(x, dtype=np.float32) inputs = load_tensor(inputs) else: if input_shapes: if not isinstance(input_shapes[0], tuple): input_shapes = [input_shapes] inputs = [F.zeros(in_size, dtype=np.float32) for in_size in input_shapes] else: logger.error( "Inputs or input_shapes is required for running model and calculating stats.", exc_info=True, ) return if not cal_activations: log_activations = False disable_receptive_field() recorded_parameters = set() def module_stats_hook(module, inputs, outputs, name=""): class_name = str(module.__class__).split(".")[-1].split("'")[0] if cal_flops: flops_stats = get_op_stats(module, inputs, outputs) if flops_stats is not None: flops_stats["name"] = name flops_stats["class_name"] = class_name flops.append(flops_stats) if cal_params: if ( hasattr(module, "weight") and (module.weight is not None) and module.weight not in recorded_parameters ): w = module.weight param_stats = get_param_stats(w) param_stats["name"] = name + "-w" params.append(param_stats) recorded_parameters.add(w) if ( hasattr(module, "bias") and module.bias is not None and module.bias not in recorded_parameters ): b = module.bias param_stats = get_param_stats(b) param_stats["name"] = name + "-b" params.append(param_stats) recorded_parameters.add(b) if cal_activations: if not isinstance(outputs, (tuple, list)): output = outputs else: output = outputs[0] activation_stats = get_activation_stats(output, has_inputs) activation_stats["name"] = name activation_stats["class_name"] = class_name activations.append(activation_stats) params = [] flops = [] hooks = [] activations = [] total_stats = namedtuple( "total_stats", ["param_size", "param_dims", "flops", "act_size", "act_dims"] ) stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) for (name, module) in model.named_modules(): if isinstance(module, hook_modules): hooks.append( module.register_forward_hook(partial(module_stats_hook, name=name)) ) with set_module_mode_safe(model, training=False) as model: model(*inputs) for h in hooks: h.remove() extra_info = { "#params": len(params), } ( total_flops, total_param_dims, total_param_size, total_act_dims, total_act_size, ) = (0, 0, 0, 0, 0) if cal_params: total_param_dims, total_param_size, params = sum_param_stats( params, bar_length_max ) extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") extra_info["total_param_size"] = sizeof_fmt(total_param_size) if logging_to_stdout: print_param_stats(params) if cal_flops: total_flops, flops = sum_op_stats(flops, bar_length_max) extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") if logging_to_stdout: print_op_stats(flops) if cal_activations: total_act_dims, total_act_size, activations = sum_activations_stats( activations, bar_length_max ) extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") extra_info["total_act_size"] = sizeof_fmt(total_act_size) if logging_to_stdout: print_activations_stats(activations, has_inputs) if cal_flops and cal_params and total_param_size != 0: extra_info["flops/param_size"] = "{:3.3f}".format( total_flops / total_param_size ) print_summary(**extra_info) return ( total_stats( param_size=total_param_size, param_dims=total_param_dims, flops=total_flops, act_size=total_act_size, act_dims=total_act_dims, ), stats_details(params=params, flops=flops, activations=activations), )