@@ -11,10 +11,10 @@ import numpy as np
 import tabulate
 
 import megengine as mge
-import megengine.core.tensor.dtype as dtype
 import megengine.module as m
 import megengine.module.qat as qatm
 import megengine.module.quantized as qm
+from megengine.core.tensor.dtype import get_dtype_bit
 from megengine.functional.tensor import zeros
 
 try:
@@ -115,13 +115,13 @@ def print_flops_stats(flops, bar_length_max=20):
         total_flops_num += int(d["flops_num"])
         d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")
 
-    for i in flops:
-        f = i["flops_num"]
-        i["flops"] = sizeof_fmt(f, suffix="OPs")
-        r = i["ratio"] = f / total_flops_num
-        i["percentage"] = "{:.2f}%".format(r * 100)
+    for d in flops:
+        f = d["flops_num"]
+        d["flops"] = sizeof_fmt(f, suffix="OPs")
+        r = d["ratio"] = f / total_flops_num
+        d["percentage"] = "{:.2f}%".format(r * 100)
         bar_length = int(f / max_flops_num * bar_length_max)
-        i["bar"] = "#" * bar_length
+        d["bar"] = "#" * bar_length
 
     header = [
         "name",
@@ -136,7 +136,7 @@ def print_flops_stats(flops, bar_length_max=20):
 
     total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
     total_var_size = sum(
-        sum(s[1] if len(s) > 1 else 0 for s in i["output_shapes"]) for i in flops
+        sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
     )
     flops.append(
         dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
@@ -147,16 +147,29 @@ def print_flops_stats(flops, bar_length_max=20):
     return total_flops_num
 
 
+def get_param_stats(param: np.ndarray):
+    nbits = get_dtype_bit(param.dtype.name)
+    shape = param.shape
+    param_dim = np.prod(param.shape)
+    param_size = param_dim * nbits // 8
+    return {
+        "shape": shape,
+        "mean": param.mean(),
+        "std": param.std(),
+        "param_dim": param_dim,
+        "nbits": nbits,
+        "size": param_size,
+    }
+
+
 def print_params_stats(params, bar_length_max=20):
     total_param_dims, total_param_size = 0, 0
     for d in params:
         total_param_dims += int(d["param_dim"])
         total_param_size += int(d["size"])
         ratio = d["size"] / total_param_size
         d["size"] = sizeof_fmt(d["size"])
         d["size_cum"] = sizeof_fmt(total_param_size)
 
     for d in params:
         ratio = d["param_dim"] / total_param_dims
         d["ratio"] = ratio
         d["percentage"] = "{:.2f}%".format(ratio * 100)
@@ -186,7 +199,13 @@ def print_params_stats(params, bar_length_max=20): |
|
|
|
"param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) |
|
|
|
) |
|
|
|
|
|
|
|
return total_param_size |
|
|
|
return total_param_dims, total_param_size |
|
|
|
|
|
|
|
|
|
|
|
def print_summary(**kwargs): |
|
|
|
data = [["item", "value"]] |
|
|
|
data.extend(list(kwargs.items())) |
|
|
|
logger.info("summary\n" + tabulate.tabulate(data)) |
|
|
|
|
|
|
|
|
|
|
|
def module_stats( |
|
|
@@ -206,14 +225,6 @@ def module_stats(
     :param log_flops: whether print and record op flops.
     """
 
-    def get_byteswidth(tensor):
-        if dtype.is_quantize(tensor.dtype):
-            return 1
-        # elif dtype.is_bfloat16(tensor.dtype):
-        #     return 2
-        else:
-            return 4
-
     def module_stats_hook(module, input, output, name=""):
         class_name = str(module.__class__).split(".")[-1].split("'")[0]
@@ -237,39 +248,15 @@ def module_stats(
 
         if hasattr(module, "weight") and module.weight is not None:
             w = module.weight
-            value = w.numpy()
-            param_dim = np.prod(w.shape)
-            param_bytes = get_byteswidth(w)
-            params.append(
-                dict(
-                    name=name + "-w",
-                    shape=w.shape,
-                    param_dim=param_dim,
-                    bits=param_bytes * 8,
-                    size=param_dim * param_bytes,
-                    size_cum=0,
-                    mean="{:.2g}".format(value.mean()),
-                    std="{:.2g}".format(value.std()),
-                )
-            )
+            param_stats = get_param_stats(w.numpy())
+            param_stats["name"] = name + "-w"
+            params.append(param_stats)
 
         if hasattr(module, "bias") and module.bias is not None:
             b = module.bias
-            value = b.numpy()
-            param_dim = np.prod(b.shape)
-            param_bytes = get_byteswidth(b)
-            params.append(
-                dict(
-                    name=name + "-b",
-                    shape=b.shape,
-                    param_dim=param_dim,
-                    bits=param_bytes * 8,
-                    size=param_dim * param_bytes,
-                    size_cum=0,
-                    mean="{:.2g}".format(value.mean()),
-                    std="{:.2g}".format(value.std()),
-                )
-            )
+            param_stats = get_param_stats(b.numpy())
+            param_stats["name"] = name + "-b"
+            params.append(param_stats)
 
     # multiple inputs to the network
     if not isinstance(input_size[0], tuple):
@@ -293,8 +280,17 @@ def module_stats(
 
     total_flops, total_params = 0, 0
     if log_params:
-        total_params = print_params_stats(params, bar_length_max)
+        total_param_dims, total_param_size = print_params_stats(params, bar_length_max)
    if log_flops:
         total_flops = print_flops_stats(flops, bar_length_max)
 
+    extra_info = {
+        "#params": len(params),
+        "total_param_dims": sizeof_fmt(total_param_dims),
+        "total_param_size": sizeof_fmt(total_param_size),
+        "total_flops": sizeof_fmt(total_flops, suffix="OPs"),
+        "flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
+    }
+    print_summary(**extra_info)
+
     return total_params, total_flops
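
Note: the arithmetic that the new get_param_stats helper performs for each parameter tensor can be reproduced with plain numpy. The sketch below is illustrative only; numpy's dtype.itemsize stands in for megengine's get_dtype_bit, so it does not cover quantized dtypes the way the real helper does, and param_stats_sketch is a hypothetical name.

import numpy as np

def param_stats_sketch(param: np.ndarray) -> dict:
    # Mirrors the bookkeeping in get_param_stats: element count, bit width,
    # byte size, and simple value statistics for one parameter tensor.
    nbits = param.dtype.itemsize * 8        # stand-in for get_dtype_bit(param.dtype.name)
    param_dim = int(np.prod(param.shape))   # number of elements
    return {
        "shape": param.shape,
        "mean": float(param.mean()),
        "std": float(param.std()),
        "param_dim": param_dim,
        "nbits": nbits,
        "size": param_dim * nbits // 8,     # bytes
    }

# A float32 3x3 conv weight with 64 output and 3 input channels:
# 64 * 3 * 3 * 3 = 1728 elements, 1728 * 32 // 8 = 6912 bytes.
w = np.random.standard_normal((64, 3, 3, 3)).astype(np.float32)
print(param_stats_sketch(w))

With these per-parameter dicts collected by module_stats_hook, print_params_stats aggregates total_param_dims and total_param_size, and print_summary tabulates them together with total_flops.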