Browse Source

fix(mge/tools): fix network_visualize for op without out shapes

GitOrigin-RevId: fdde52c214
HuaHua404-patch-4
Megvii Engine Team 2 years ago
parent
commit
4cd4a38a7d
2 changed files with 70 additions and 58 deletions
  1. +3
    -0
      imperative/python/megengine/core/tensor/dtype.py
  2. +67
    -58
      imperative/python/megengine/tools/network_visualize.py

+ 3
- 0
imperative/python/megengine/core/tensor/dtype.py View File

@@ -17,6 +17,9 @@ from .._imperative_rt.common import (




def get_dtype_bit(dtype_name: str): def get_dtype_bit(dtype_name: str):
special_cases = {"bool": 1}
if dtype_name in special_cases:
return special_cases[dtype_name]
numbers = re.findall(r"\d+", dtype_name) numbers = re.findall(r"\d+", dtype_name)
assert len(numbers) == 1, "Unsupport dtype name with more than one number." assert len(numbers) == 1, "Unsupport dtype name with more than one number."
return int(numbers[0]) return int(numbers[0])


+ 67
- 58
imperative/python/megengine/tools/network_visualize.py View File

@@ -129,6 +129,7 @@ def visualize(
) )
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) stats_details = namedtuple("module_stats", ["params", "flops", "activations"])


disable_stats = False
for node in tqdm(graph.all_oprs): for node in tqdm(graph.all_oprs):
if hasattr(node, "output_idx"): if hasattr(node, "output_idx"):
node_oup = node.outputs[node.output_idx] node_oup = node.outputs[node.output_idx]
@@ -145,7 +146,11 @@ def visualize(
if log_path: if log_path:
# detail format see tensorboard/compat/proto/attr_value.proto # detail format see tensorboard/compat/proto/attr_value.proto
attr = { attr = {
"_output_shapes": AttrValue(
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
if node_oup.shape:
attr["_output_shapes"] = AttrValue(
list=AttrValue.ListValue( list=AttrValue.ListValue(
shape=[ shape=[
TensorShapeProto( TensorShapeProto(
@@ -155,39 +160,42 @@ def visualize(
) )
] ]
) )
),
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
)
else:
disable_stats = True
logger.warning(
f"OpNode {node.name} do not has shape attr, would not calculate flops/params/activations for this net."
)


if cal_flops:
flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path and hasattr(flops_stats, "flops_num"):
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
)
flops_stats["name"] = node.name
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)
if not disable_stats:
if cal_flops:
flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path and hasattr(flops_stats, "flops_num"):
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
)
flops_stats["name"] = node.name
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)


if cal_activations:
acts = get_activation_stats(node_oup, has_input=has_input)
acts["name"] = node.name
acts["class_name"] = node.type
activations_list.append(acts)
if cal_activations:
acts = get_activation_stats(node_oup, has_input=has_input)
acts["name"] = node.name
acts["class_name"] = node.type
activations_list.append(acts)


if cal_params:
if node.type == "ImmutableTensor":
param_stats = get_param_stats(node_oup)
# add tensor size attr
if log_path:
attr["size"] = AttrValue(
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
)
param_stats["name"] = node.name
params_list.append(param_stats)
if cal_params:
if node.type == "ImmutableTensor":
param_stats = get_param_stats(node_oup)
# add tensor size attr
if log_path:
attr["size"] = AttrValue(
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
)
param_stats["name"] = node.name
params_list.append(param_stats)


if log_path: if log_path:
node_list.append( node_list.append(
@@ -212,34 +220,37 @@ def visualize(
total_act_size, total_act_size,
) = (0, 0, 0, 0, 0) ) = (0, 0, 0, 0, 0)


if cal_params:
total_param_dims, total_param_size, params_list = sum_param_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if logging_to_stdout:
print_param_stats(params_list)
if not disable_stats:
if cal_params:
total_param_dims, total_param_size, params_list = sum_param_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if logging_to_stdout:
print_param_stats(params_list)


if cal_flops:
total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if logging_to_stdout:
print_op_stats(flops_list)
if cal_flops:
total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if logging_to_stdout:
print_op_stats(flops_list)


if cal_activations:
total_act_dims, total_act_size, activations_list = sum_activations_stats(
activations_list, bar_length_max
)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if logging_to_stdout:
print_activations_stats(activations_list, has_input=has_input)
if cal_activations:
total_act_dims, total_act_size, activations_list = sum_activations_stats(
activations_list, bar_length_max
)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if logging_to_stdout:
print_activations_stats(activations_list, has_input=has_input)


if cal_flops and cal_params:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
if cal_flops and cal_params:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)

print_summary(**extra_info)


if log_path: if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
@@ -251,8 +262,6 @@ def visualize(
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)
writer._get_file_writer().add_graph((graph_def, stepstats)) writer._get_file_writer().add_graph((graph_def, stepstats))


print_summary(**extra_info)

return ( return (
total_stats( total_stats(
param_size=total_param_size, param_size=total_param_size,


Loading…
Cancel
Save