|
|
@@ -7,8 +7,8 @@ |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import argparse |
|
|
|
import json |
|
|
|
import logging |
|
|
|
import re |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
@@ -71,7 +71,10 @@ def visualize( |
|
|
|
graph = Network.load(model_path) |
|
|
|
|
|
|
|
def process_name(name): |
|
|
|
return name.replace(".", "/").encode(encoding="utf-8") |
|
|
|
# nodes that start with point or contain float const will lead to display bug |
|
|
|
if not re.match(r"^[+-]?\d*\.\d*", name): |
|
|
|
name = name.replace(".", "/") |
|
|
|
return name.encode(encoding="utf-8") |
|
|
|
|
|
|
|
summary = [["item", "value"]] |
|
|
|
node_list = [] |
|
|
@@ -128,10 +131,6 @@ def visualize( |
|
|
|
param_stats["name"] = node.name |
|
|
|
params_list.append(param_stats) |
|
|
|
|
|
|
|
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug |
|
|
|
if not len(node.name.split(".")) > 2 and not node in graph.input_vars: |
|
|
|
continue |
|
|
|
|
|
|
|
if log_path: |
|
|
|
node_list.append( |
|
|
|
NodeDef( |
|
|
|