|
|
@@ -7,11 +7,12 @@ |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import argparse |
|
|
|
import logging |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from megengine.core.tensor.dtype import is_quantize |
|
|
|
from megengine.logger import get_logger |
|
|
|
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level |
|
|
|
from megengine.utils.module_stats import ( |
|
|
|
print_flops_stats, |
|
|
|
print_params_stats, |
|
|
@@ -58,6 +59,8 @@ def visualize( |
|
|
|
"TensorBoard and TensorboardX are required for visualize.", exc_info=True |
|
|
|
) |
|
|
|
return |
|
|
|
# FIXME: remove this after resolving "span dist too large" warning |
|
|
|
old_level = set_mgb_log_level(logging.ERROR) |
|
|
|
|
|
|
|
graph = Network.load(model_path) |
|
|
|
writer = SummaryWriter(log_path) |
|
|
@@ -126,6 +129,9 @@ def visualize( |
|
|
|
std="{:.2g}".format(node.numpy().std()), |
|
|
|
) |
|
|
|
) |
|
|
|
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug |
|
|
|
if not len(node.name.split(".")) > 2 and not node in graph.input_vars: |
|
|
|
continue |
|
|
|
node_list.append( |
|
|
|
NodeDef( |
|
|
|
name=process_name(node.name), op=node.type, input=inp_list, attr=attr, |
|
|
@@ -145,6 +151,10 @@ def visualize( |
|
|
|
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) |
|
|
|
) |
|
|
|
writer._get_file_writer().add_graph((graph_def, stepstats)) |
|
|
|
|
|
|
|
# FIXME: remove this after resolving "span dist too large" warning |
|
|
|
_imperative_rt_logger.set_log_level(old_level) |
|
|
|
|
|
|
|
return total_params, total_flops |
|
|
|
|
|
|
|
|
|
|
|