Browse Source

fix(mge/tools): fix node display bug in tensorboard

GitOrigin-RevId: c997d6cccb
tags/v1.3.1
Megvii Engine Team 4 years ago
parent
commit
2df847544d
2 changed files with 7 additions and 6 deletions
  1. +5
    -6
      imperative/python/megengine/tools/network_visualize.py
  2. +2
    -0
      imperative/python/megengine/utils/module_stats.py

+ 5
- 6
imperative/python/megengine/tools/network_visualize.py View File

@@ -7,8 +7,8 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse import argparse
import json
import logging import logging
import re


import numpy as np import numpy as np


@@ -71,7 +71,10 @@ def visualize(
graph = Network.load(model_path) graph = Network.load(model_path)


def process_name(name): def process_name(name):
return name.replace(".", "/").encode(encoding="utf-8")
# nodes that start with point or contain float const will lead to display bug
if not re.match(r"^[+-]?\d*\.\d*", name):
name = name.replace(".", "/")
return name.encode(encoding="utf-8")


summary = [["item", "value"]] summary = [["item", "value"]]
node_list = [] node_list = []
@@ -128,10 +131,6 @@ def visualize(
param_stats["name"] = node.name param_stats["name"] = node.name
params_list.append(param_stats) params_list.append(param_stats)


# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue

if log_path: if log_path:
node_list.append( node_list.append(
NodeDef( NodeDef(


+ 2
- 0
imperative/python/megengine/utils/module_stats.py View File

@@ -230,6 +230,7 @@ def get_param_stats(param: np.ndarray):
param_dim = np.prod(param.shape) param_dim = np.prod(param.shape)
param_size = param_dim * nbits // 8 param_size = param_dim * nbits // 8
return { return {
"dtype": param.dtype,
"shape": shape, "shape": shape,
"mean": "{:.3g}".format(param.mean()), "mean": "{:.3g}".format(param.mean()),
"std": "{:.3g}".format(param.std()), "std": "{:.3g}".format(param.std()),
@@ -260,6 +261,7 @@ def print_params_stats(params, bar_length_max=20):


header = [ header = [
"name", "name",
"dtype",
"shape", "shape",
"mean", "mean",
"std", "std",


Loading…
Cancel
Save