|
- #!/usr/bin/env python
- # -*-coding=utf-8-*-
-
- from megengine.logger import get_logger
-
- logger = get_logger(__name__)
-
- try:
- from tensorboardX import SummaryWriter
- from tensorboardX.proto.attr_value_pb2 import AttrValue
- from tensorboardX.proto.graph_pb2 import GraphDef
- from tensorboardX.proto.node_def_pb2 import NodeDef
- from tensorboardX.proto.plugin_text_pb2 import TextPluginData
- from tensorboardX.proto.step_stats_pb2 import (
- DeviceStepStats,
- RunMetadata,
- StepStats,
- )
- from tensorboardX.proto.summary_pb2 import Summary, SummaryMetadata
- from tensorboardX.proto.tensor_pb2 import TensorProto
- from tensorboardX.proto.tensor_shape_pb2 import TensorShapeProto
- from tensorboardX.proto.versions_pb2 import VersionDef
- except ImportError:
- logger.error(
- "TensorBoard and TensorboardX are required for visualize.", exc_info=True,
- )
-
-
- def tensor_shape_proto(shape):
- """Creates an object matching
- https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto
- """
- return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in shape])
-
-
- def attr_value_proto(shape, dtype, attr):
- """Creates a dict of objects matching
- https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
- specifically designed for a NodeDef. The values have been
- reverse engineered from standard TensorBoard logged data.
- """
- attr_proto = {}
- if shape is not None:
- shapeproto = tensor_shape_proto(shape)
- attr_proto["_output_shapes"] = AttrValue(
- list=AttrValue.ListValue(shape=[shapeproto])
- )
- if dtype is not None:
- attr_proto["dtype"] = AttrValue(s=dtype.encode(encoding="utf-8"))
- if attr is not None:
- for key in attr.keys():
- attr_proto[key] = AttrValue(s=attr[key].encode(encoding="utf-8"))
-
- return attr_proto
-
-
- def node_proto(
- name, op="UnSpecified", input=None, outputshape=None, dtype=None, attributes={}
- ):
- """Creates an object matching
- https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto
- """
- if input is None:
- input = []
- if not isinstance(input, list):
- input = [input]
- return NodeDef(
- name=name.encode(encoding="utf_8"),
- op=op,
- input=input,
- attr=attr_value_proto(outputshape, dtype, attributes),
- )
-
-
- def node(
- name, op="UnSpecified", input=None, outputshape=None, dtype=None, attributes={}
- ):
- return node_proto(name, op, input, outputshape, dtype, attributes)
-
-
- def graph(node_list):
- graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
- stepstats = RunMetadata(
- step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")])
- )
- return graph_def, stepstats
-
-
- def text(tag, text):
- plugin_data = SummaryMetadata.PluginData(
- plugin_name="text", content=TextPluginData(version=0).SerializeToString()
- )
- smd = SummaryMetadata(plugin_data=plugin_data)
- string_val = []
- for item in text:
- string_val.append(item.encode(encoding="utf_8"))
- tensor = TensorProto(
- dtype="DT_STRING",
- string_val=string_val,
- tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=len(text))]),
- )
-
- return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
-
-
- class NodeRaw:
- def __init__(self, name, op, input, outputshape, dtype, attributes):
- self.name = name
- self.op = op
- self.input = input
- self.outputshape = outputshape
- self.dtype = dtype
- self.attributes = attributes
-
-
- class SummaryWriterExtend(SummaryWriter):
- def __init__(
- self,
- logdir=None,
- comment="",
- purge_step=None,
- max_queue=10,
- flush_secs=120,
- filename_suffix="",
- write_to_disk=True,
- log_dir=None,
- **kwargs
- ):
- self.node_raw_dict = {}
- super().__init__(
- logdir,
- comment,
- purge_step,
- max_queue,
- flush_secs,
- filename_suffix,
- write_to_disk,
- log_dir,
- **kwargs,
- )
-
- def add_text(self, tag, text_string_list, global_step=None, walltime=None):
- """Add text data to summary.
- Args:
- tag (string): Data identifier
- text_string_list (string list): String to save
- global_step (int): Global step value to record
- walltime (float): Optional override default walltime (time.time())
- seconds after epoch of event
- Examples::
- # text can be divided into three levels by tag and global_step
- from writer import SummaryWriterExtend
- writer = SummaryWriterExtend()
-
- writer.add_text('level1.0/level2.0', ['text0'], 0)
- writer.add_text('level1.0/level2.0', ['text1'], 1)
- writer.add_text('level1.0/level2.1', ['text2'])
- writer.add_text('level1.1', ['text3'])
- """
-
- self._get_file_writer().add_summary(
- text(tag, text_string_list), global_step, walltime
- )
-
- def add_node_raw(
- self,
- name,
- op="UnSpecified",
- input=[],
- outputshape=None,
- dtype=None,
- attributes={},
- ):
- """Add node raw datas that can help build graph.After add all nodes, call
- add_graph_by_node_raw_list() to build graph and add graph data to summary.
- Args:
- name (string): opr name.
- op (string): opr class name.
- input (string list): input opr name.
- outputshape (list): output shape.
- dtype (string): output data dtype.
- attributes (dict): attributes info.
- Examples::
- from writer import SummaryWriterExtend
- writer = SummaryWriterExtend()
-
- writer.add_node_raw('node1', 'opr1', outputshape=[6, 2, 3], dtype="float32", attributes={
- "peak_size": "12MB", "mmory_alloc": "2MB, percent: 16.7%"})
- writer.add_node_raw('node2', 'opr2', outputshape=[6, 2, 3], dtype="float32", input="node1", attributes={
- "peak_size": "12MB", "mmory_alloc": "2MB, percent: 16.7%"})
- writer.add_graph_by_node_raw_list()
-
- """
- # self.node_raw_list.append(
- # node(name, op, input, outputshape, dtype, attributes))
- self.node_raw_dict[name] = NodeRaw(
- name, op, input, outputshape, dtype, dict(attributes)
- )
-
- def add_node_raw_name_suffix(self, name, suffix):
- """Give node name suffix in order to finding this node by 'search nodes'
- Args:
- name (string): opr name.
- suffix (string): nam suffix.
- """
- old_name = self.node_raw_dict[name].name
- new_name = old_name + suffix
- # self.node_raw_dict[new_name] = self.node_raw_dict.pop(name)
- self.node_raw_dict[name].name = new_name
- for node_name, node in self.node_raw_dict.items():
- node.input = [new_name if x == old_name else x for x in node.input]
-
- def add_node_raw_attributes(self, name, attributes):
- """
- Args:
- name (string): opr name.
- attributes (dict): attributes info that need to be added.
- """
- for key, value in attributes.items():
- self.node_raw_dict[name].attributes[key] = value
-
- def add_graph_by_node_raw_list(self):
- """Build graph and add graph data to summary."""
- node_raw_list = []
- for key, value in self.node_raw_dict.items():
- node_raw_list.append(
- node(
- value.name,
- value.op,
- value.input,
- value.outputshape,
- value.dtype,
- value.attributes,
- )
- )
- self._get_file_writer().add_graph(graph(node_raw_list))
|