|
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- from functools import partial
-
- import numpy as np
- import tabulate
-
- import megengine as mge
- import megengine.core.tensor.dtype as dtype
- import megengine.module as m
- import megengine.module.qat as qatm
- import megengine.module.quantized as qm
- from megengine.functional.tensor import zeros
-
try:
    # Presumably lifts the formatter's line limit so the multi-line stats
    # tables logged below are not truncated.
    mge.logger.MegEngineLogFormatter.max_lines = float("inf")
except AttributeError as e:
    # Keep the original exception as the cause so the failing attribute
    # lookup is visible in the traceback.
    raise ValueError("set logger max lines failed") from e

logger = mge.get_logger(__name__)
logger.setLevel("INFO")


# Maps a module class to the function that counts its FLOPs; entries are
# added by the ``_register_modules`` decorator.
CALC_FLOPS = {}
-
-
def _register_modules(*modules):
    """Return a decorator that registers a FLOPs counter for ``modules``.

    Every class in ``modules`` is mapped to the decorated function in
    ``CALC_FLOPS``. The function itself is returned unchanged, so it remains
    callable directly.
    """

    def register(counter):
        CALC_FLOPS.update(dict.fromkeys(modules, counter))
        return counter

    return register
-
-
@_register_modules(
    m.Conv2d,
    m.LocalConv2d,
    qm.Conv2d,
    qm.ConvRelu2d,
    qm.ConvBn2d,
    qm.ConvBnRelu2d,
    qatm.Conv2d,
    qatm.ConvRelu2d,
    qatm.ConvBn2d,
    qatm.ConvBnRelu2d,
)
def count_convNd(module, input, output):
    """Count the multiply-accumulate ops of a (possibly grouped) convolution.

    ``m.ConvTranspose2d`` is intentionally not registered here; it is counted
    by ``count_deconvNd``, which is registered for it later in this file.

    :param module: convolution module; ``bias``, ``groups`` and
        ``kernel_size`` are read from it.
    :param input: tuple of the module's positional inputs; the channel count
        is read from dimension 1 of ``input[0]``.
    :param output: the module output, either a bare tensor or a list/tuple
        whose first item is the output tensor.
    :return: ``N * Cout * prod(spatial_out) * (Cin / groups * prod(kernel_size) + bias)``.
    """
    # The forward hook may pass the bare output tensor; indexing that with
    # [0] would select the first sample rather than the tensor itself.
    out = output[0] if isinstance(output, (list, tuple)) else output
    bias = 1 if module.bias is not None else 0
    ic = input[0].shape[1]
    oc = out.shape[1]
    N = out.shape[0]
    HW = np.prod(out.shape[2:])
    # Each output channel only reads ic // groups input channels, but all oc
    # output channels are still computed, so oc is NOT divided by groups.
    gic = ic // module.groups
    # N x Cout x H x W x (Cin / groups x Kh x Kw + bias)
    return N * HW * oc * (gic * np.prod(module.kernel_size) + bias)
-
-
@_register_modules(m.ConvTranspose2d)
def count_deconvNd(module, input, output):
    """Count the multiply-accumulate ops of a transposed convolution.

    Every input element is multiplied by ``Cout / groups`` kernels of
    ``prod(kernel_size)`` weights each.

    :param module: transposed-convolution module; ``kernel_size`` and
        (if present) ``groups`` are read from it.
    :param input: tuple of the module's positional inputs; ``input[0]`` is the
        input tensor.
    :param output: the module output, either a bare tensor or a list/tuple
        whose first item is the output tensor.
    :return: ``numel(input) * Cout / groups * prod(kernel_size)``.
    """
    # The forward hook may pass the bare output tensor; indexing that with
    # [0] would select the first sample and shape[1] would then be H_out.
    out = output[0] if isinstance(output, (list, tuple)) else output
    # Modules without a ``groups`` attribute are treated as ungrouped.
    groups = getattr(module, "groups", 1)
    return (
        np.prod(input[0].shape)
        * (out.shape[1] // groups)
        * np.prod(module.kernel_size)
    )
-
-
@_register_modules(m.Linear, qatm.Linear, qm.Linear)
def count_linear(module, input, output):
    """Count the multiply-accumulate ops of a fully connected layer.

    :param module: linear module; ``in_features`` is read from it.
    :param input: tuple of the module's positional inputs (unused).
    :param output: the module output, either a bare tensor or a list/tuple
        whose first item is the output tensor.
    :return: number of output elements times ``in_features``.
    """
    # The forward hook may pass the bare output tensor; indexing that with
    # [0] would drop the batch dimension from the count.
    out = output[0] if isinstance(output, (list, tuple)) else output
    return np.prod(out.shape) * module.in_features
-
-
# Module classes that receive a stats forward hook; matched with isinstance().
# qat and quantized variants are not listed because they are expected to
# subclass these float modules (NOTE(review): confirm for every quantized
# class that has a FLOPs counter, e.g. ``qm.Linear``).
hook_modules = (m.Conv2d, m.ConvTranspose2d, m.LocalConv2d, m.BatchNorm2d, m.Linear)
-
-
def dict2table(list_of_dict, header):
    """Convert a list of dicts into table rows suitable for ``tabulate``.

    The first row is ``header`` itself. Each following row holds, for every
    key in ``header``, that dict's value, or ``""`` when the key is missing.
    """
    body = [[record.get(key, "") for key in header] for record in list_of_dict]
    return [header] + body
-
-
def sizeof_fmt(num, suffix="B"):
    """Format ``num`` with a binary (1024-based) unit prefix.

    :param num: value to format; may be negative.
    :param suffix: unit appended after the prefix, e.g. ``"B"`` or ``"OPs"``.
    :return: e.g. ``"1.500 KiB"``. Values of 1024**8 or more in magnitude
        fall through to the ``Yi`` prefix with one decimal place.
    """
    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
        if abs(num) < 1024.0:
            return "{:3.3f} {}{}".format(num, unit, suffix)
        num /= 1024.0
    # num still carries its own sign here; prepending another "-" would
    # produce "--" for large negative values.
    return "{:.1f} {}{}".format(num, "Yi", suffix)
-
-
def print_flops_stats(flops, bar_length_max=20):
    """Log a table of per-module FLOPs and return the total.

    :param flops: list of dicts recorded by the forward hook in
        ``module_stats``; each must have a numeric ``flops_num`` and an
        ``output_shapes`` list. The dicts are updated in place with formatted
        ``flops``, ``flops_cum``, ``ratio``, ``percentage`` and ``bar``
        entries, and a ``total`` row is appended to the list.
    :param bar_length_max: number of ``#`` characters drawn for the module
        with the largest FLOPs count.
    :return: total FLOPs, i.e. the sum of ``int(flops_num)`` over all rows.
    """
    max_flops_num = max([i["flops_num"] for i in flops] + [0])

    # calc total flops and set the running total (flops_cum) of every row
    total_flops_num = 0
    for d in flops:
        total_flops_num += int(d["flops_num"])
        d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")

    for i in flops:
        f = i["flops_num"]
        i["flops"] = sizeof_fmt(f, suffix="OPs")
        # A total or maximum of zero (e.g. only zero-FLOP rows) would
        # otherwise raise ZeroDivisionError.
        r = f / total_flops_num if total_flops_num else 0.0
        i["ratio"] = r
        i["percentage"] = "{:.2f}%".format(r * 100)
        bar_length = int(f / max_flops_num * bar_length_max) if max_flops_num else 0
        i["bar"] = "#" * bar_length

    header = [
        "name",
        "class_name",
        "input_shapes",
        "output_shapes",
        "flops",
        "flops_cum",
        "percentage",
        "bar",
    ]

    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
    # NOTE(review): this sums dimension 1 of every output shape, not the
    # number of output elements; confirm this is the intended "total".
    total_var_size = sum(
        sum(s[1] if len(s) > 1 else 0 for s in i["output_shapes"]) for i in flops
    )
    flops.append(
        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
    )

    logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)))

    return total_flops_num
-
-
def print_params_stats(params, bar_length_max=20):
    """Log a table of per-parameter sizes and return the total size in bytes.

    :param params: list of dicts recorded by the forward hook in
        ``module_stats``; each must have numeric ``param_dim`` and ``size``
        (bytes). The dicts are updated in place: ``size`` is replaced by its
        formatted string, and ``size_cum``, ``ratio``, ``percentage`` and
        ``size_bar`` are added. A ``total`` row is appended to the list.
    :param bar_length_max: number of ``#`` characters drawn for the parameter
        with the largest element count.
    :return: total parameter size in bytes.
    """
    total_param_dims, total_param_size = 0, 0
    for d in params:
        total_param_dims += int(d["param_dim"])
        total_param_size += int(d["size"])
        d["size"] = sizeof_fmt(d["size"])
        d["size_cum"] = sizeof_fmt(total_param_size)

    for d in params:
        # Guard against a zero total (all parameters empty).
        ratio = d["param_dim"] / total_param_dims if total_param_dims else 0.0
        d["ratio"] = ratio
        d["percentage"] = "{:.2f}%".format(ratio * 100)

    # construct bar; default=0.0 keeps a model without parameters from
    # making max() raise on an empty sequence
    max_ratio = max((d["ratio"] for d in params), default=0.0)
    for d in params:
        bar_length = int(d["ratio"] / max_ratio * bar_length_max) if max_ratio else 0
        d["size_bar"] = "#" * bar_length

    param_size = sizeof_fmt(total_param_size)
    params.append(dict(name="total", param_dim=total_param_dims, size=param_size))

    header = [
        "name",
        "shape",
        "mean",
        "std",
        "param_dim",
        "bits",
        "size",
        "size_cum",
        "percentage",
        "size_bar",
    ]

    logger.info(
        "param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
    )

    return total_param_size
-
-
def module_stats(
    model: m.Module,
    input_size,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
):
    r"""
    Calculate and print ``model``'s statistics by adding forward hooks that
    record each hooked Module's input/output shapes, FLOPs and parameters.

    Note: ``model`` is switched to eval mode and left in it.

    :param model: model that need to get stats info.
    :param input_size: shape of the input used to run the model, e.g.
        ``(1, 3, 224, 224)``; pass a list/tuple of shapes for a model with
        multiple inputs.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :param log_params: whether print and record params size.
    :param log_flops: whether print and record op flops.
    :return: ``(total_params, total_flops)``; ``total_params`` is in bytes and
        each entry is 0 when its logging is disabled.
    """

    def get_byteswidth(tensor):
        # Quantized dtypes are counted as 1 byte, everything else as 4 bytes.
        # TODO: count 2-byte dtypes (e.g. bfloat16) once a check is available.
        return 1 if dtype.is_quantize(tensor.dtype) else 4

    def record_param(name, suffix, tensor):
        # Append one row describing ``tensor`` to ``params``.
        value = tensor.numpy()
        param_dim = np.prod(tensor.shape)
        param_bytes = get_byteswidth(tensor)
        params.append(
            dict(
                name=name + suffix,
                shape=tensor.shape,
                param_dim=param_dim,
                bits=param_bytes * 8,
                size=param_dim * param_bytes,
                size_cum=0,
                mean="{:.2g}".format(value.mean()),
                std="{:.2g}".format(value.std()),
            )
        )

    def module_stats_hook(module, inputs, outputs, name=""):
        class_name = type(module).__name__

        # Normalize BEFORE computing FLOPs: the counters index outputs[0] and
        # expect it to be the output tensor, not its first sample.
        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]

        flops_fun = CALC_FLOPS.get(type(module))
        if callable(flops_fun):
            flops.append(
                dict(
                    name=name,
                    class_name=class_name,
                    input_shapes=[i.shape for i in inputs],
                    output_shapes=[o.shape for o in outputs],
                    flops_num=flops_fun(module, inputs, outputs),
                    flops_cum=0,
                )
            )

        for attr, suffix in (("weight", "-w"), ("bias", "-b")):
            tensor = getattr(module, attr, None)
            if tensor is not None:
                record_param(name, suffix, tensor)

    # A single shape is wrapped so the model is always fed a list of inputs;
    # lists are accepted as shapes as well as tuples.
    if not isinstance(input_size[0], (tuple, list)):
        input_size = [input_size]

    params = []
    flops = []
    hooks = []

    for name, module in model.named_modules():
        if isinstance(module, hook_modules):
            hooks.append(
                module.register_forward_hook(partial(module_stats_hook, name=name))
            )

    inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
    try:
        model.eval()
        model(*inputs)
    finally:
        # Always detach the hooks, even if the forward pass fails, so the
        # model is not left instrumented.
        for h in hooks:
            h.remove()

    total_flops, total_params = 0, 0
    if log_params:
        total_params = print_params_stats(params, bar_length_max)
    if log_flops:
        total_flops = print_flops_stats(flops, bar_length_max)

    return total_params, total_flops
|