Browse Source

fix(mge/tools): improve `module_visualize` result's robustness and beauty

GitOrigin-RevId: ef7b573776
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
c45f1eb298
2 changed files with 14 additions and 2 deletions
  1. +11
    -1
      imperative/python/megengine/tools/network_visualize.py
  2. +3
    -1
      imperative/python/megengine/utils/module_stats.py

+ 11
- 1
imperative/python/megengine/tools/network_visualize.py View File

@@ -7,11 +7,12 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import logging

import numpy as np

from megengine.core.tensor.dtype import is_quantize
from megengine.logger import get_logger
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
print_flops_stats,
print_params_stats,
@@ -58,6 +59,8 @@ def visualize(
"TensorBoard and TensorboardX are required for visualize.", exc_info=True
)
return
# FIXME: remove this after resolving "span dist too large" warning
old_level = set_mgb_log_level(logging.ERROR)

graph = Network.load(model_path)
writer = SummaryWriter(log_path)
@@ -126,6 +129,9 @@ def visualize(
std="{:.2g}".format(node.numpy().std()),
)
)
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue
node_list.append(
NodeDef(
name=process_name(node.name), op=node.type, input=inp_list, attr=attr,
@@ -145,6 +151,10 @@ def visualize(
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
)
writer._get_file_writer().add_graph((graph_def, stepstats))

# FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger.set_log_level(old_level)

return total_params, total_flops




+ 3
- 1
imperative/python/megengine/utils/module_stats.py View File

@@ -135,7 +135,9 @@ def print_flops_stats(flops, bar_length_max=20):
]

total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops)
total_var_size = sum(
sum(s[1] if len(s) > 1 else 0 for s in i["output_shapes"]) for i in flops
)
flops.append(
dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
)


Loading…
Cancel
Save