From c45f1eb29835fa95201c0be784ff031da026afb7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 16 Mar 2021 16:27:38 +0800 Subject: [PATCH] fix(mge/tools): improve `module_visualize` result's robustness and beauty GitOrigin-RevId: ef7b57377619fabcf50d4a48235b8d196659f1d4 --- imperative/python/megengine/tools/network_visualize.py | 12 +++++++++++- imperative/python/megengine/utils/module_stats.py | 4 +++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py index a77641b0..210343f8 100755 --- a/imperative/python/megengine/tools/network_visualize.py +++ b/imperative/python/megengine/tools/network_visualize.py @@ -7,11 +7,12 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import argparse +import logging import numpy as np from megengine.core.tensor.dtype import is_quantize -from megengine.logger import get_logger +from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level from megengine.utils.module_stats import ( print_flops_stats, print_params_stats, @@ -58,6 +59,8 @@ def visualize( "TensorBoard and TensorboardX are required for visualize.", exc_info=True ) return + # FIXME: remove this after resolving "span dist too large" warning + old_level = set_mgb_log_level(logging.ERROR) graph = Network.load(model_path) writer = SummaryWriter(log_path) @@ -126,6 +129,9 @@ def visualize( std="{:.2g}".format(node.numpy().std()), ) ) + # FIXME(MGE-2165): nodes outside network module may lead to unknown display bug + if not len(node.name.split(".")) > 2 and not node in graph.input_vars: + continue node_list.append( NodeDef( name=process_name(node.name), op=node.type, input=inp_list, attr=attr, @@ -145,6 +151,10 @@ def visualize( step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) ) writer._get_file_writer().add_graph((graph_def, stepstats)) + + # FIXME: remove this after resolving "span dist too large" warning + _imperative_rt_logger.set_log_level(old_level) + return total_params, total_flops diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index c091e321..af2a1cb5 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -135,7 +135,9 @@ def print_flops_stats(flops, bar_length_max=20): ] total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") - total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops) + total_var_size = sum( + sum(s[1] if len(s) > 1 else 0 for s in i["output_shapes"]) for i in flops + ) flops.append( dict(name="total", flops=total_flops_str, output_shapes=total_var_size) )