|
|
@@ -40,30 +40,31 @@ def visualize( |
|
|
|
:param log_params: whether print and record params size. |
|
|
|
:param log_flops: whether print and record op flops. |
|
|
|
""" |
|
|
|
try: |
|
|
|
from tensorboard.compat.proto.attr_value_pb2 import AttrValue |
|
|
|
from tensorboard.compat.proto.config_pb2 import RunMetadata |
|
|
|
from tensorboard.compat.proto.graph_pb2 import GraphDef |
|
|
|
from tensorboard.compat.proto.node_def_pb2 import NodeDef |
|
|
|
from tensorboard.compat.proto.step_stats_pb2 import ( |
|
|
|
AllocatorMemoryUsed, |
|
|
|
DeviceStepStats, |
|
|
|
NodeExecStats, |
|
|
|
StepStats, |
|
|
|
) |
|
|
|
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto |
|
|
|
from tensorboard.compat.proto.versions_pb2 import VersionDef |
|
|
|
from tensorboardX import SummaryWriter |
|
|
|
except ImportError: |
|
|
|
logger.error( |
|
|
|
"TensorBoard and TensorboardX are required for visualize.", exc_info=True |
|
|
|
) |
|
|
|
return |
|
|
|
if log_path: |
|
|
|
try: |
|
|
|
from tensorboard.compat.proto.attr_value_pb2 import AttrValue |
|
|
|
from tensorboard.compat.proto.config_pb2 import RunMetadata |
|
|
|
from tensorboard.compat.proto.graph_pb2 import GraphDef |
|
|
|
from tensorboard.compat.proto.node_def_pb2 import NodeDef |
|
|
|
from tensorboard.compat.proto.step_stats_pb2 import ( |
|
|
|
AllocatorMemoryUsed, |
|
|
|
DeviceStepStats, |
|
|
|
NodeExecStats, |
|
|
|
StepStats, |
|
|
|
) |
|
|
|
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto |
|
|
|
from tensorboard.compat.proto.versions_pb2 import VersionDef |
|
|
|
from tensorboardX import SummaryWriter |
|
|
|
except ImportError: |
|
|
|
logger.error( |
|
|
|
"TensorBoard and TensorboardX are required for visualize.", |
|
|
|
exc_info=True, |
|
|
|
) |
|
|
|
return |
|
|
|
# FIXME: remove this after resolving "span dist too large" warning |
|
|
|
old_level = set_mgb_log_level(logging.ERROR) |
|
|
|
|
|
|
|
graph = Network.load(model_path) |
|
|
|
writer = SummaryWriter(log_path) |
|
|
|
|
|
|
|
def process_name(name): |
|
|
|
return name.replace(".", "/").encode(encoding="utf-8") |
|
|
@@ -84,21 +85,27 @@ def visualize( |
|
|
|
node_oup = node.outputs[0] |
|
|
|
|
|
|
|
inp_list = [process_name(var.owner.name) for var in node.inputs] |
|
|
|
attr = { |
|
|
|
"_output_shapes": AttrValue( |
|
|
|
list=AttrValue.ListValue( |
|
|
|
shape=[ |
|
|
|
TensorShapeProto( |
|
|
|
dim=[TensorShapeProto.Dim(size=d) for d in node_oup.shape] |
|
|
|
) |
|
|
|
] |
|
|
|
) |
|
|
|
), |
|
|
|
} |
|
|
|
if log_path: |
|
|
|
attr = { |
|
|
|
"_output_shapes": AttrValue( |
|
|
|
list=AttrValue.ListValue( |
|
|
|
shape=[ |
|
|
|
TensorShapeProto( |
|
|
|
dim=[ |
|
|
|
TensorShapeProto.Dim(size=d) for d in node_oup.shape |
|
|
|
] |
|
|
|
) |
|
|
|
] |
|
|
|
) |
|
|
|
), |
|
|
|
} |
|
|
|
if hasattr(node, "calc_flops"): |
|
|
|
flops_num = node.calc_flops() |
|
|
|
# add op flops attr |
|
|
|
attr["flops"] = AttrValue(s=sizeof_fmt(flops_num).encode(encoding="utf-8")) |
|
|
|
if log_path: |
|
|
|
attr["flops"] = AttrValue( |
|
|
|
s=sizeof_fmt(flops_num).encode(encoding="utf-8") |
|
|
|
) |
|
|
|
flops_list.append( |
|
|
|
dict( |
|
|
|
name=node.name, |
|
|
@@ -114,9 +121,10 @@ def visualize( |
|
|
|
# TODO: consider other quantize dtypes |
|
|
|
param_bytes = 1 if is_quantize(node_oup.dtype) else 4 |
|
|
|
# add tensor size attr |
|
|
|
attr["size"] = AttrValue( |
|
|
|
s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8") |
|
|
|
) |
|
|
|
if log_path: |
|
|
|
attr["size"] = AttrValue( |
|
|
|
s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8") |
|
|
|
) |
|
|
|
params_list.append( |
|
|
|
dict( |
|
|
|
name=node.name, |
|
|
@@ -132,25 +140,33 @@ def visualize( |
|
|
|
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug |
|
|
|
if not len(node.name.split(".")) > 2 and not node in graph.input_vars: |
|
|
|
continue |
|
|
|
node_list.append( |
|
|
|
NodeDef( |
|
|
|
name=process_name(node.name), op=node.type, input=inp_list, attr=attr, |
|
|
|
if log_path: |
|
|
|
node_list.append( |
|
|
|
NodeDef( |
|
|
|
name=process_name(node.name), |
|
|
|
op=node.type, |
|
|
|
input=inp_list, |
|
|
|
attr=attr, |
|
|
|
) |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
total_flops, total_params = 0, 0 |
|
|
|
total_flops, total_params = None, None |
|
|
|
if log_params: |
|
|
|
total_params = print_params_stats(params_list, bar_length_max) |
|
|
|
if log_flops: |
|
|
|
total_flops = print_flops_stats(flops_list, bar_length_max) |
|
|
|
|
|
|
|
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) |
|
|
|
if log_path: |
|
|
|
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) |
|
|
|
|
|
|
|
device = "/device:CPU:0" |
|
|
|
stepstats = RunMetadata( |
|
|
|
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) |
|
|
|
) |
|
|
|
writer._get_file_writer().add_graph((graph_def, stepstats)) |
|
|
|
device = "/device:CPU:0" |
|
|
|
stepstats = RunMetadata( |
|
|
|
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) |
|
|
|
) |
|
|
|
writer = SummaryWriter(log_path) |
|
|
|
writer._get_file_writer().add_graph((graph_def, stepstats)) |
|
|
|
|
|
|
|
# summary |
|
|
|
|
|
|
|
# FIXME: remove this after resolving "span dist too large" warning |
|
|
|
_imperative_rt_logger.set_log_level(old_level) |
|
|
@@ -164,7 +180,7 @@ def main(): |
|
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter, |
|
|
|
) |
|
|
|
parser.add_argument("model_path", help="dumped model path.") |
|
|
|
parser.add_argument("log_path", help="tensorboard log path.") |
|
|
|
parser.add_argument("--log_path", help="tensorboard log path.") |
|
|
|
parser.add_argument( |
|
|
|
"--bar_length_max", |
|
|
|
type=int, |
|
|
@@ -179,7 +195,20 @@ def main(): |
|
|
|
parser.add_argument( |
|
|
|
"--log_flops", action="store_true", help="whether print and record op flops.", |
|
|
|
) |
|
|
|
visualize(**vars(parser.parse_args())) |
|
|
|
parser.add_argument( |
|
|
|
"--all", |
|
|
|
action="store_true", |
|
|
|
help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.", |
|
|
|
) |
|
|
|
args = parser.parse_args() |
|
|
|
if args.all: |
|
|
|
args.log_params = True |
|
|
|
args.log_flops = True |
|
|
|
if not args.log_path: |
|
|
|
args.log_path = "./log" |
|
|
|
kwargs = vars(args) |
|
|
|
kwargs.pop("all") |
|
|
|
visualize(**kwargs) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|