Browse Source

fix(mge/tools): fix node display bug in tensorboard

GitOrigin-RevId: c997d6cccb
release-1.3
Megvii Engine Team 4 years ago
parent
commit
007a2376c3
2 changed files with 7 additions and 6 deletions
  1. +5
    -6
      imperative/python/megengine/tools/network_visualize.py
  2. +2
    -0
      imperative/python/megengine/utils/module_stats.py

+ 5
- 6
imperative/python/megengine/tools/network_visualize.py View File

@@ -7,8 +7,8 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import json
import logging
import re

import numpy as np

@@ -71,7 +71,10 @@ def visualize(
graph = Network.load(model_path)

def process_name(name):
return name.replace(".", "/").encode(encoding="utf-8")
# nodes that start with point or contain float const will lead to display bug
if not re.match(r"^[+-]?\d*\.\d*", name):
name = name.replace(".", "/")
return name.encode(encoding="utf-8")

summary = [["item", "value"]]
node_list = []
@@ -128,10 +131,6 @@ def visualize(
param_stats["name"] = node.name
params_list.append(param_stats)

# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue

if log_path:
node_list.append(
NodeDef(


+ 2
- 0
imperative/python/megengine/utils/module_stats.py View File

@@ -230,6 +230,7 @@ def get_param_stats(param: np.ndarray):
param_dim = np.prod(param.shape)
param_size = param_dim * nbits // 8
return {
"dtype": param.dtype,
"shape": shape,
"mean": "{:.3g}".format(param.mean()),
"std": "{:.3g}".format(param.std()),
@@ -260,6 +261,7 @@ def print_params_stats(params, bar_length_max=20):

header = [
"name",
"dtype",
"shape",
"mean",
"std",


Loading…
Cancel
Save