From 4cd4a38a7d43bcee46f17ce6e21deaca0be280fa Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 5 Jul 2022 18:17:18 +0800 Subject: [PATCH] fix(mge/tools): fix network_visualize for op without out shapes GitOrigin-RevId: fdde52c214a78531d9939938af3170c564bdcf4e --- imperative/python/megengine/core/tensor/dtype.py | 3 + .../python/megengine/tools/network_visualize.py | 125 +++++++++++---------- 2 files changed, 70 insertions(+), 58 deletions(-) diff --git a/imperative/python/megengine/core/tensor/dtype.py b/imperative/python/megengine/core/tensor/dtype.py index 8d0b8d56..ea9eb358 100644 --- a/imperative/python/megengine/core/tensor/dtype.py +++ b/imperative/python/megengine/core/tensor/dtype.py @@ -17,6 +17,9 @@ from .._imperative_rt.common import ( def get_dtype_bit(dtype_name: str): + special_cases = {"bool": 1} + if dtype_name in special_cases: + return special_cases[dtype_name] numbers = re.findall(r"\d+", dtype_name) assert len(numbers) == 1, "Unsupport dtype name with more than one number." return int(numbers[0]) diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py index 62eb1ecd..c3d98bf9 100755 --- a/imperative/python/megengine/tools/network_visualize.py +++ b/imperative/python/megengine/tools/network_visualize.py @@ -129,6 +129,7 @@ def visualize( ) stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) + disable_stats = False for node in tqdm(graph.all_oprs): if hasattr(node, "output_idx"): node_oup = node.outputs[node.output_idx] @@ -145,7 +146,11 @@ def visualize( if log_path: # detail format see tensorboard/compat/proto/attr_value.proto attr = { - "_output_shapes": AttrValue( + "params": AttrValue(s=str(node.params).encode(encoding="utf-8")), + "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), + } + if node_oup.shape: + attr["_output_shapes"] = AttrValue( list=AttrValue.ListValue( shape=[ TensorShapeProto( @@ -155,39 +160,42 @@ def visualize( ) ] ) - ), - "params": AttrValue(s=str(node.params).encode(encoding="utf-8")), - "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), - } + ) + else: + disable_stats = True + logger.warning( + f"OpNode {node.name} do not has shape attr, would not calculate flops/params/activations for this net." + ) - if cal_flops: - flops_stats = get_op_stats(node, node.inputs, node.outputs) - if flops_stats is not None: - # add op flops attr - if log_path and hasattr(flops_stats, "flops_num"): - attr["flops"] = AttrValue( - s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") - ) - flops_stats["name"] = node.name - flops_stats["class_name"] = node.type - flops_list.append(flops_stats) + if not disable_stats: + if cal_flops: + flops_stats = get_op_stats(node, node.inputs, node.outputs) + if flops_stats is not None: + # add op flops attr + if log_path and hasattr(flops_stats, "flops_num"): + attr["flops"] = AttrValue( + s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") + ) + flops_stats["name"] = node.name + flops_stats["class_name"] = node.type + flops_list.append(flops_stats) - if cal_activations: - acts = get_activation_stats(node_oup, has_input=has_input) - acts["name"] = node.name - acts["class_name"] = node.type - activations_list.append(acts) + if cal_activations: + acts = get_activation_stats(node_oup, has_input=has_input) + acts["name"] = node.name + acts["class_name"] = node.type + activations_list.append(acts) - if cal_params: - if node.type == "ImmutableTensor": - param_stats = get_param_stats(node_oup) - # add tensor size attr - if log_path: - attr["size"] = AttrValue( - s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") - ) - param_stats["name"] = node.name - params_list.append(param_stats) + if cal_params: + if node.type == "ImmutableTensor": + param_stats = get_param_stats(node_oup) + # add tensor size attr + if log_path: + attr["size"] = AttrValue( + s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") + ) + param_stats["name"] = node.name + params_list.append(param_stats) if log_path: node_list.append( @@ -212,34 +220,37 @@ def visualize( total_act_size, ) = (0, 0, 0, 0, 0) - if cal_params: - total_param_dims, total_param_size, params_list = sum_param_stats( - params_list, bar_length_max - ) - extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") - extra_info["total_param_size"] = sizeof_fmt(total_param_size) - if logging_to_stdout: - print_param_stats(params_list) + if not disable_stats: + if cal_params: + total_param_dims, total_param_size, params_list = sum_param_stats( + params_list, bar_length_max + ) + extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") + extra_info["total_param_size"] = sizeof_fmt(total_param_size) + if logging_to_stdout: + print_param_stats(params_list) - if cal_flops: - total_flops, flops_list = sum_op_stats(flops_list, bar_length_max) - extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") - if logging_to_stdout: - print_op_stats(flops_list) + if cal_flops: + total_flops, flops_list = sum_op_stats(flops_list, bar_length_max) + extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") + if logging_to_stdout: + print_op_stats(flops_list) - if cal_activations: - total_act_dims, total_act_size, activations_list = sum_activations_stats( - activations_list, bar_length_max - ) - extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") - extra_info["total_act_size"] = sizeof_fmt(total_act_size) - if logging_to_stdout: - print_activations_stats(activations_list, has_input=has_input) + if cal_activations: + total_act_dims, total_act_size, activations_list = sum_activations_stats( + activations_list, bar_length_max + ) + extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") + extra_info["total_act_size"] = sizeof_fmt(total_act_size) + if logging_to_stdout: + print_activations_stats(activations_list, has_input=has_input) - if cal_flops and cal_params: - extra_info["flops/param_size"] = "{:3.3f}".format( - total_flops / total_param_size - ) + if cal_flops and cal_params: + extra_info["flops/param_size"] = "{:3.3f}".format( + total_flops / total_param_size + ) + + print_summary(**extra_info) if log_path: graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) @@ -251,8 +262,6 @@ def visualize( writer = SummaryWriter(log_path) writer._get_file_writer().add_graph((graph_def, stepstats)) - print_summary(**extra_info) - return ( total_stats( param_size=total_param_size,