|
|
@@ -31,6 +31,8 @@ _calc_receptive_field_dict = {} |
|
|
|
|
|
|
|
|
|
|
|
def _receptive_field_fallback(module, inputs, outputs): |
|
|
|
if not _receptive_field_enabled: |
|
|
|
return |
|
|
|
assert not hasattr(module, "_rf") |
|
|
|
assert not hasattr(module, "_stride") |
|
|
|
if len(inputs) == 0: |
|
|
@@ -54,6 +56,8 @@ _iter_list = [ |
|
|
|
), |
|
|
|
] |
|
|
|
|
|
|
|
_receptive_field_enabled = False |
|
|
|
|
|
|
|
|
|
|
|
def _register_dict(*modules, dict=None): |
|
|
|
def callback(impl): |
|
|
@@ -72,6 +76,16 @@ def register_receptive_field(*modules): |
|
|
|
return _register_dict(*modules, dict=_calc_receptive_field_dict) |
|
|
|
|
|
|
|
|
|
|
|
def enable_receptive_field(): |
|
|
|
global _receptive_field_enabled |
|
|
|
_receptive_field_enabled = True |
|
|
|
|
|
|
|
|
|
|
|
def disable_receptive_field(): |
|
|
|
global _receptive_field_enabled |
|
|
|
_receptive_field_enabled = False |
|
|
|
|
|
|
|
|
|
|
|
@register_flops( |
|
|
|
m.Conv1d, m.Conv2d, m.Conv3d, |
|
|
|
) |
|
|
@@ -144,16 +158,16 @@ def preprocess_receptive_field(module, inputs, outputs): |
|
|
|
# TODO: support other dimensions |
|
|
|
pre_rf = ( |
|
|
|
max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs), |
|
|
|
max(i.owner._rf[1] for i in inputs), |
|
|
|
max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs), |
|
|
|
) |
|
|
|
pre_stride = ( |
|
|
|
max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs), |
|
|
|
max(i.owner._stride[1] for i in inputs), |
|
|
|
max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs), |
|
|
|
) |
|
|
|
return pre_rf, pre_stride |
|
|
|
|
|
|
|
|
|
|
|
def get_flops_stats(module, inputs, outputs): |
|
|
|
def get_op_stats(module, inputs, outputs): |
|
|
|
rst = { |
|
|
|
"input_shapes": [i.shape for i in inputs], |
|
|
|
"output_shapes": [o.shape for o in outputs], |
|
|
@@ -184,7 +198,7 @@ def get_flops_stats(module, inputs, outputs): |
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
def print_flops_stats(flops, bar_length_max=20): |
|
|
|
def print_op_stats(flops, bar_length_max=20): |
|
|
|
max_flops_num = max([i["flops_num"] for i in flops] + [0]) |
|
|
|
total_flops_num = 0 |
|
|
|
for d in flops: |
|
|
@@ -203,13 +217,14 @@ def print_flops_stats(flops, bar_length_max=20): |
|
|
|
"class_name", |
|
|
|
"input_shapes", |
|
|
|
"output_shapes", |
|
|
|
"receptive_field", |
|
|
|
"stride", |
|
|
|
"flops", |
|
|
|
"flops_cum", |
|
|
|
"percentage", |
|
|
|
"bar", |
|
|
|
] |
|
|
|
if _receptive_field_enabled: |
|
|
|
header.insert(4, "receptive_field") |
|
|
|
header.insert(5, "stride") |
|
|
|
|
|
|
|
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") |
|
|
|
total_var_size = sum( |
|
|
@@ -240,7 +255,7 @@ def get_param_stats(param: np.ndarray): |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def print_params_stats(params, bar_length_max=20): |
|
|
|
def print_param_stats(params, bar_length_max=20): |
|
|
|
max_size = max([d["size"] for d in params] + [0]) |
|
|
|
total_param_dims, total_param_size = 0, 0 |
|
|
|
for d in params: |
|
|
@@ -302,11 +317,12 @@ def module_stats( |
|
|
|
:param log_params: whether print and record params size. |
|
|
|
:param log_flops: whether print and record op flops. |
|
|
|
""" |
|
|
|
disable_receptive_field() |
|
|
|
|
|
|
|
def module_stats_hook(module, inputs, outputs, name=""): |
|
|
|
class_name = str(module.__class__).split(".")[-1].split("'")[0] |
|
|
|
|
|
|
|
flops_stats = get_flops_stats(module, inputs, outputs) |
|
|
|
flops_stats = get_op_stats(module, inputs, outputs) |
|
|
|
if flops_stats is not None: |
|
|
|
flops_stats["name"] = name |
|
|
|
flops_stats["class_name"] = class_name |
|
|
@@ -349,11 +365,11 @@ def module_stats( |
|
|
|
} |
|
|
|
total_flops, total_param_dims, total_param_size = 0, 0, 0 |
|
|
|
if log_params: |
|
|
|
total_param_dims, total_param_size = print_params_stats(params, bar_length_max) |
|
|
|
total_param_dims, total_param_size = print_param_stats(params, bar_length_max) |
|
|
|
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims) |
|
|
|
extra_info["total_param_size"] = sizeof_fmt(total_param_size) |
|
|
|
if log_flops: |
|
|
|
total_flops = print_flops_stats(flops, bar_length_max) |
|
|
|
total_flops = print_op_stats(flops, bar_length_max) |
|
|
|
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") |
|
|
|
if log_params and log_flops: |
|
|
|
extra_info["flops/param_size"] = "{:3.3f}".format( |
|
|
|