|
@@ -129,6 +129,7 @@ def visualize( |
|
|
) |
|
|
) |
|
|
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) |
|
|
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) |
|
|
|
|
|
|
|
|
|
|
|
disable_stats = False |
|
|
for node in tqdm(graph.all_oprs): |
|
|
for node in tqdm(graph.all_oprs): |
|
|
if hasattr(node, "output_idx"): |
|
|
if hasattr(node, "output_idx"): |
|
|
node_oup = node.outputs[node.output_idx] |
|
|
node_oup = node.outputs[node.output_idx] |
|
@@ -145,7 +146,11 @@ def visualize( |
|
|
if log_path: |
|
|
if log_path: |
|
|
# detail format see tensorboard/compat/proto/attr_value.proto |
|
|
# detail format see tensorboard/compat/proto/attr_value.proto |
|
|
attr = { |
|
|
attr = { |
|
|
"_output_shapes": AttrValue( |
|
|
|
|
|
|
|
|
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")), |
|
|
|
|
|
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), |
|
|
|
|
|
} |
|
|
|
|
|
if node_oup.shape: |
|
|
|
|
|
attr["_output_shapes"] = AttrValue( |
|
|
list=AttrValue.ListValue( |
|
|
list=AttrValue.ListValue( |
|
|
shape=[ |
|
|
shape=[ |
|
|
TensorShapeProto( |
|
|
TensorShapeProto( |
|
@@ -155,39 +160,42 @@ def visualize( |
|
|
) |
|
|
) |
|
|
] |
|
|
] |
|
|
) |
|
|
) |
|
|
), |
|
|
|
|
|
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")), |
|
|
|
|
|
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
) |
|
|
|
|
|
else: |
|
|
|
|
|
disable_stats = True |
|
|
|
|
|
logger.warning( |
|
|
|
|
|
f"OpNode {node.name} do not has shape attr, would not calculate flops/params/activations for this net." |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
if cal_flops: |
|
|
|
|
|
flops_stats = get_op_stats(node, node.inputs, node.outputs) |
|
|
|
|
|
if flops_stats is not None: |
|
|
|
|
|
# add op flops attr |
|
|
|
|
|
if log_path and hasattr(flops_stats, "flops_num"): |
|
|
|
|
|
attr["flops"] = AttrValue( |
|
|
|
|
|
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") |
|
|
|
|
|
) |
|
|
|
|
|
flops_stats["name"] = node.name |
|
|
|
|
|
flops_stats["class_name"] = node.type |
|
|
|
|
|
flops_list.append(flops_stats) |
|
|
|
|
|
|
|
|
if not disable_stats: |
|
|
|
|
|
if cal_flops: |
|
|
|
|
|
flops_stats = get_op_stats(node, node.inputs, node.outputs) |
|
|
|
|
|
if flops_stats is not None: |
|
|
|
|
|
# add op flops attr |
|
|
|
|
|
if log_path and hasattr(flops_stats, "flops_num"): |
|
|
|
|
|
attr["flops"] = AttrValue( |
|
|
|
|
|
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") |
|
|
|
|
|
) |
|
|
|
|
|
flops_stats["name"] = node.name |
|
|
|
|
|
flops_stats["class_name"] = node.type |
|
|
|
|
|
flops_list.append(flops_stats) |
|
|
|
|
|
|
|
|
if cal_activations: |
|
|
|
|
|
acts = get_activation_stats(node_oup, has_input=has_input) |
|
|
|
|
|
acts["name"] = node.name |
|
|
|
|
|
acts["class_name"] = node.type |
|
|
|
|
|
activations_list.append(acts) |
|
|
|
|
|
|
|
|
if cal_activations: |
|
|
|
|
|
acts = get_activation_stats(node_oup, has_input=has_input) |
|
|
|
|
|
acts["name"] = node.name |
|
|
|
|
|
acts["class_name"] = node.type |
|
|
|
|
|
activations_list.append(acts) |
|
|
|
|
|
|
|
|
if cal_params: |
|
|
|
|
|
if node.type == "ImmutableTensor": |
|
|
|
|
|
param_stats = get_param_stats(node_oup) |
|
|
|
|
|
# add tensor size attr |
|
|
|
|
|
if log_path: |
|
|
|
|
|
attr["size"] = AttrValue( |
|
|
|
|
|
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") |
|
|
|
|
|
) |
|
|
|
|
|
param_stats["name"] = node.name |
|
|
|
|
|
params_list.append(param_stats) |
|
|
|
|
|
|
|
|
if cal_params: |
|
|
|
|
|
if node.type == "ImmutableTensor": |
|
|
|
|
|
param_stats = get_param_stats(node_oup) |
|
|
|
|
|
# add tensor size attr |
|
|
|
|
|
if log_path: |
|
|
|
|
|
attr["size"] = AttrValue( |
|
|
|
|
|
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") |
|
|
|
|
|
) |
|
|
|
|
|
param_stats["name"] = node.name |
|
|
|
|
|
params_list.append(param_stats) |
|
|
|
|
|
|
|
|
if log_path: |
|
|
if log_path: |
|
|
node_list.append( |
|
|
node_list.append( |
|
@@ -212,34 +220,37 @@ def visualize( |
|
|
total_act_size, |
|
|
total_act_size, |
|
|
) = (0, 0, 0, 0, 0) |
|
|
) = (0, 0, 0, 0, 0) |
|
|
|
|
|
|
|
|
if cal_params: |
|
|
|
|
|
total_param_dims, total_param_size, params_list = sum_param_stats( |
|
|
|
|
|
params_list, bar_length_max |
|
|
|
|
|
) |
|
|
|
|
|
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") |
|
|
|
|
|
extra_info["total_param_size"] = sizeof_fmt(total_param_size) |
|
|
|
|
|
if logging_to_stdout: |
|
|
|
|
|
print_param_stats(params_list) |
|
|
|
|
|
|
|
|
if not disable_stats: |
|
|
|
|
|
if cal_params: |
|
|
|
|
|
total_param_dims, total_param_size, params_list = sum_param_stats( |
|
|
|
|
|
params_list, bar_length_max |
|
|
|
|
|
) |
|
|
|
|
|
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") |
|
|
|
|
|
extra_info["total_param_size"] = sizeof_fmt(total_param_size) |
|
|
|
|
|
if logging_to_stdout: |
|
|
|
|
|
print_param_stats(params_list) |
|
|
|
|
|
|
|
|
if cal_flops: |
|
|
|
|
|
total_flops, flops_list = sum_op_stats(flops_list, bar_length_max) |
|
|
|
|
|
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") |
|
|
|
|
|
if logging_to_stdout: |
|
|
|
|
|
print_op_stats(flops_list) |
|
|
|
|
|
|
|
|
if cal_flops: |
|
|
|
|
|
total_flops, flops_list = sum_op_stats(flops_list, bar_length_max) |
|
|
|
|
|
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") |
|
|
|
|
|
if logging_to_stdout: |
|
|
|
|
|
print_op_stats(flops_list) |
|
|
|
|
|
|
|
|
if cal_activations: |
|
|
|
|
|
total_act_dims, total_act_size, activations_list = sum_activations_stats( |
|
|
|
|
|
activations_list, bar_length_max |
|
|
|
|
|
) |
|
|
|
|
|
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") |
|
|
|
|
|
extra_info["total_act_size"] = sizeof_fmt(total_act_size) |
|
|
|
|
|
if logging_to_stdout: |
|
|
|
|
|
print_activations_stats(activations_list, has_input=has_input) |
|
|
|
|
|
|
|
|
if cal_activations: |
|
|
|
|
|
total_act_dims, total_act_size, activations_list = sum_activations_stats( |
|
|
|
|
|
activations_list, bar_length_max |
|
|
|
|
|
) |
|
|
|
|
|
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") |
|
|
|
|
|
extra_info["total_act_size"] = sizeof_fmt(total_act_size) |
|
|
|
|
|
if logging_to_stdout: |
|
|
|
|
|
print_activations_stats(activations_list, has_input=has_input) |
|
|
|
|
|
|
|
|
if cal_flops and cal_params: |
|
|
|
|
|
extra_info["flops/param_size"] = "{:3.3f}".format( |
|
|
|
|
|
total_flops / total_param_size |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
if cal_flops and cal_params: |
|
|
|
|
|
extra_info["flops/param_size"] = "{:3.3f}".format( |
|
|
|
|
|
total_flops / total_param_size |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
print_summary(**extra_info) |
|
|
|
|
|
|
|
|
if log_path: |
|
|
if log_path: |
|
|
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) |
|
|
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) |
|
@@ -251,8 +262,6 @@ def visualize( |
|
|
writer = SummaryWriter(log_path) |
|
|
writer = SummaryWriter(log_path) |
|
|
writer._get_file_writer().add_graph((graph_def, stepstats)) |
|
|
writer._get_file_writer().add_graph((graph_def, stepstats)) |
|
|
|
|
|
|
|
|
print_summary(**extra_info) |
|
|
|
|
|
|
|
|
|
|
|
return ( |
|
|
return ( |
|
|
total_stats( |
|
|
total_stats( |
|
|
param_size=total_param_size, |
|
|
param_size=total_param_size, |
|
|