Browse Source

feat(mgb): add tensorboard tool python layer interface

GitOrigin-RevId: 065bc4d153
revert-211-master
Megvii Engine Team 3 years ago
parent
commit
8084e4e2ee
1 changed files with 236 additions and 0 deletions
  1. +236
    -0
      imperative/python/megengine/utils/tensorboard.py

+ 236
- 0
imperative/python/megengine/utils/tensorboard.py View File

@@ -0,0 +1,236 @@
#!/usr/bin/env python
# -*-coding=utf-8-*-

from megengine.logger import get_logger

logger = get_logger(__name__)

try:
from tensorboardX import SummaryWriter
from tensorboardX.proto.attr_value_pb2 import AttrValue
from tensorboardX.proto.graph_pb2 import GraphDef
from tensorboardX.proto.node_def_pb2 import NodeDef
from tensorboardX.proto.plugin_text_pb2 import TextPluginData
from tensorboardX.proto.step_stats_pb2 import (
DeviceStepStats,
RunMetadata,
StepStats,
)
from tensorboardX.proto.summary_pb2 import Summary, SummaryMetadata
from tensorboardX.proto.tensor_pb2 import TensorProto
from tensorboardX.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboardX.proto.versions_pb2 import VersionDef
except ImportError:
logger.error(
"TensorBoard and TensorboardX are required for visualize.", exc_info=True,
)


def tensor_shape_proto(shape):
"""Creates an object matching
https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto
"""
return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in shape])


def attr_value_proto(shape, dtype, attr):
"""Creates a dict of objects matching
https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
specifically designed for a NodeDef. The values have been
reverse engineered from standard TensorBoard logged data.
"""
attr_proto = {}
if shape is not None:
shapeproto = tensor_shape_proto(shape)
attr_proto["_output_shapes"] = AttrValue(
list=AttrValue.ListValue(shape=[shapeproto])
)
if dtype is not None:
attr_proto["dtype"] = AttrValue(s=dtype.encode(encoding="utf-8"))
if attr is not None:
for key in attr.keys():
attr_proto[key] = AttrValue(s=attr[key].encode(encoding="utf-8"))

return attr_proto


def node_proto(
name, op="UnSpecified", input=None, outputshape=None, dtype=None, attributes={}
):
"""Creates an object matching
https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto
"""
if input is None:
input = []
if not isinstance(input, list):
input = [input]
return NodeDef(
name=name.encode(encoding="utf_8"),
op=op,
input=input,
attr=attr_value_proto(outputshape, dtype, attributes),
)


def node(
name, op="UnSpecified", input=None, outputshape=None, dtype=None, attributes={}
):
return node_proto(name, op, input, outputshape, dtype, attributes)


def graph(node_list):
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
stepstats = RunMetadata(
step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")])
)
return graph_def, stepstats


def text(tag, text):
plugin_data = SummaryMetadata.PluginData(
plugin_name="text", content=TextPluginData(version=0).SerializeToString()
)
smd = SummaryMetadata(plugin_data=plugin_data)
string_val = []
for item in text:
string_val.append(item.encode(encoding="utf_8"))
tensor = TensorProto(
dtype="DT_STRING",
string_val=string_val,
tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=len(text))]),
)

return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])


class NodeRaw:
def __init__(self, name, op, input, outputshape, dtype, attributes):
self.name = name
self.op = op
self.input = input
self.outputshape = outputshape
self.dtype = dtype
self.attributes = attributes


class SummaryWriterExtend(SummaryWriter):
def __init__(
self,
logdir=None,
comment="",
purge_step=None,
max_queue=10,
flush_secs=120,
filename_suffix="",
write_to_disk=True,
log_dir=None,
**kwargs
):
self.node_raw_dict = {}
super().__init__(
logdir,
comment,
purge_step,
max_queue,
flush_secs,
filename_suffix,
write_to_disk,
log_dir,
**kwargs,
)

def add_text(self, tag, text_string_list, global_step=None, walltime=None):
"""Add text data to summary.
Args:
tag (string): Data identifier
text_string_list (string list): String to save
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time())
seconds after epoch of event
Examples::
# text can be divided into three levels by tag and global_step
from writer import SummaryWriterExtend
writer = SummaryWriterExtend()

writer.add_text('level1.0/level2.0', ['text0'], 0)
writer.add_text('level1.0/level2.0', ['text1'], 1)
writer.add_text('level1.0/level2.1', ['text2'])
writer.add_text('level1.1', ['text3'])
"""

self._get_file_writer().add_summary(
text(tag, text_string_list), global_step, walltime
)

def add_node_raw(
self,
name,
op="UnSpecified",
input=[],
outputshape=None,
dtype=None,
attributes={},
):
"""Add node raw datas that can help build graph.After add all nodes, call
add_graph_by_node_raw_list() to build graph and add graph data to summary.
Args:
name (string): opr name.
op (string): opr class name.
input (string list): input opr name.
outputshape (list): output shape.
dtype (string): output data dtype.
attributes (dict): attributes info.
Examples::
from writer import SummaryWriterExtend
writer = SummaryWriterExtend()

writer.add_node_raw('node1', 'opr1', outputshape=[6, 2, 3], dtype="float32", attributes={
"peak_size": "12MB", "mmory_alloc": "2MB, percent: 16.7%"})
writer.add_node_raw('node2', 'opr2', outputshape=[6, 2, 3], dtype="float32", input="node1", attributes={
"peak_size": "12MB", "mmory_alloc": "2MB, percent: 16.7%"})
writer.add_graph_by_node_raw_list()

"""
# self.node_raw_list.append(
# node(name, op, input, outputshape, dtype, attributes))
self.node_raw_dict[name] = NodeRaw(
name, op, input, outputshape, dtype, dict(attributes)
)

def add_node_raw_name_suffix(self, name, suffix):
"""Give node name suffix in order to finding this node by 'search nodes'
Args:
name (string): opr name.
suffix (string): nam suffix.
"""
old_name = self.node_raw_dict[name].name
new_name = old_name + suffix
# self.node_raw_dict[new_name] = self.node_raw_dict.pop(name)
self.node_raw_dict[name].name = new_name
for node_name, node in self.node_raw_dict.items():
node.input = [new_name if x == old_name else x for x in node.input]

def add_node_raw_attributes(self, name, attributes):
"""
Args:
name (string): opr name.
attributes (dict): attributes info that need to be added.
"""
for key, value in attributes.items():
self.node_raw_dict[name].attributes[key] = value

def add_graph_by_node_raw_list(self):
"""Build graph and add graph data to summary."""
node_raw_list = []
for key, value in self.node_raw_dict.items():
node_raw_list.append(
node(
value.name,
value.op,
value.input,
value.outputshape,
value.dtype,
value.attributes,
)
)
self._get_file_writer().add_graph(graph(node_raw_list))

Loading…
Cancel
Save