You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorboard.py 8.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. #!/usr/bin/env python
  2. # -*-coding=utf-8-*-
  3. from megengine.logger import get_logger
  4. logger = get_logger(__name__)
  5. try:
  6. from tensorboardX import SummaryWriter
  7. from tensorboardX.proto.attr_value_pb2 import AttrValue
  8. from tensorboardX.proto.graph_pb2 import GraphDef
  9. from tensorboardX.proto.node_def_pb2 import NodeDef
  10. from tensorboardX.proto.plugin_text_pb2 import TextPluginData
  11. from tensorboardX.proto.step_stats_pb2 import (
  12. DeviceStepStats,
  13. RunMetadata,
  14. StepStats,
  15. )
  16. from tensorboardX.proto.summary_pb2 import Summary, SummaryMetadata
  17. from tensorboardX.proto.tensor_pb2 import TensorProto
  18. from tensorboardX.proto.tensor_shape_pb2 import TensorShapeProto
  19. from tensorboardX.proto.versions_pb2 import VersionDef
  20. except ImportError:
  21. logger.error(
  22. "TensorBoard and TensorboardX are required for visualize.", exc_info=True,
  23. )
  24. def tensor_shape_proto(shape):
  25. """Creates an object matching
  26. https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto
  27. """
  28. return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in shape])
  29. def attr_value_proto(shape, dtype, attr):
  30. """Creates a dict of objects matching
  31. https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
  32. specifically designed for a NodeDef. The values have been
  33. reverse engineered from standard TensorBoard logged data.
  34. """
  35. attr_proto = {}
  36. if shape is not None:
  37. shapeproto = tensor_shape_proto(shape)
  38. attr_proto["_output_shapes"] = AttrValue(
  39. list=AttrValue.ListValue(shape=[shapeproto])
  40. )
  41. if dtype is not None:
  42. attr_proto["dtype"] = AttrValue(s=dtype.encode(encoding="utf-8"))
  43. if attr is not None:
  44. for key in attr.keys():
  45. attr_proto[key] = AttrValue(s=attr[key].encode(encoding="utf-8"))
  46. return attr_proto
  47. def node_proto(
  48. name, op="UnSpecified", input=None, outputshape=None, dtype=None, attributes={}
  49. ):
  50. """Creates an object matching
  51. https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto
  52. """
  53. if input is None:
  54. input = []
  55. if not isinstance(input, list):
  56. input = [input]
  57. return NodeDef(
  58. name=name.encode(encoding="utf_8"),
  59. op=op,
  60. input=input,
  61. attr=attr_value_proto(outputshape, dtype, attributes),
  62. )
  63. def node(
  64. name, op="UnSpecified", input=None, outputshape=None, dtype=None, attributes={}
  65. ):
  66. return node_proto(name, op, input, outputshape, dtype, attributes)
  67. def graph(node_list):
  68. graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
  69. stepstats = RunMetadata(
  70. step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")])
  71. )
  72. return graph_def, stepstats
  73. def text(tag, text):
  74. plugin_data = SummaryMetadata.PluginData(
  75. plugin_name="text", content=TextPluginData(version=0).SerializeToString()
  76. )
  77. smd = SummaryMetadata(plugin_data=plugin_data)
  78. string_val = []
  79. for item in text:
  80. string_val.append(item.encode(encoding="utf_8"))
  81. tensor = TensorProto(
  82. dtype="DT_STRING",
  83. string_val=string_val,
  84. tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=len(text))]),
  85. )
  86. return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
  87. class NodeRaw:
  88. def __init__(self, name, op, input, outputshape, dtype, attributes):
  89. self.name = name
  90. self.op = op
  91. self.input = input
  92. self.outputshape = outputshape
  93. self.dtype = dtype
  94. self.attributes = attributes
  95. class SummaryWriterExtend(SummaryWriter):
  96. def __init__(
  97. self,
  98. logdir=None,
  99. comment="",
  100. purge_step=None,
  101. max_queue=10,
  102. flush_secs=120,
  103. filename_suffix="",
  104. write_to_disk=True,
  105. log_dir=None,
  106. **kwargs
  107. ):
  108. self.node_raw_dict = {}
  109. super().__init__(
  110. logdir,
  111. comment,
  112. purge_step,
  113. max_queue,
  114. flush_secs,
  115. filename_suffix,
  116. write_to_disk,
  117. log_dir,
  118. **kwargs,
  119. )
  120. def add_text(self, tag, text_string_list, global_step=None, walltime=None):
  121. """Add text data to summary.
  122. Args:
  123. tag (string): Data identifier
  124. text_string_list (string list): String to save
  125. global_step (int): Global step value to record
  126. walltime (float): Optional override default walltime (time.time())
  127. seconds after epoch of event
  128. Examples::
  129. # text can be divided into three levels by tag and global_step
  130. from writer import SummaryWriterExtend
  131. writer = SummaryWriterExtend()
  132. writer.add_text('level1.0/level2.0', ['text0'], 0)
  133. writer.add_text('level1.0/level2.0', ['text1'], 1)
  134. writer.add_text('level1.0/level2.1', ['text2'])
  135. writer.add_text('level1.1', ['text3'])
  136. """
  137. self._get_file_writer().add_summary(
  138. text(tag, text_string_list), global_step, walltime
  139. )
  140. def add_node_raw(
  141. self,
  142. name,
  143. op="UnSpecified",
  144. input=[],
  145. outputshape=None,
  146. dtype=None,
  147. attributes={},
  148. ):
  149. """Add node raw datas that can help build graph.After add all nodes, call
  150. add_graph_by_node_raw_list() to build graph and add graph data to summary.
  151. Args:
  152. name (string): opr name.
  153. op (string): opr class name.
  154. input (string list): input opr name.
  155. outputshape (list): output shape.
  156. dtype (string): output data dtype.
  157. attributes (dict): attributes info.
  158. Examples::
  159. from writer import SummaryWriterExtend
  160. writer = SummaryWriterExtend()
  161. writer.add_node_raw('node1', 'opr1', outputshape=[6, 2, 3], dtype="float32", attributes={
  162. "peak_size": "12MB", "mmory_alloc": "2MB, percent: 16.7%"})
  163. writer.add_node_raw('node2', 'opr2', outputshape=[6, 2, 3], dtype="float32", input="node1", attributes={
  164. "peak_size": "12MB", "mmory_alloc": "2MB, percent: 16.7%"})
  165. writer.add_graph_by_node_raw_list()
  166. """
  167. # self.node_raw_list.append(
  168. # node(name, op, input, outputshape, dtype, attributes))
  169. self.node_raw_dict[name] = NodeRaw(
  170. name, op, input, outputshape, dtype, dict(attributes)
  171. )
  172. def add_node_raw_name_suffix(self, name, suffix):
  173. """Give node name suffix in order to finding this node by 'search nodes'
  174. Args:
  175. name (string): opr name.
  176. suffix (string): nam suffix.
  177. """
  178. old_name = self.node_raw_dict[name].name
  179. new_name = old_name + suffix
  180. # self.node_raw_dict[new_name] = self.node_raw_dict.pop(name)
  181. self.node_raw_dict[name].name = new_name
  182. for node_name, node in self.node_raw_dict.items():
  183. node.input = [new_name if x == old_name else x for x in node.input]
  184. def add_node_raw_attributes(self, name, attributes):
  185. """
  186. Args:
  187. name (string): opr name.
  188. attributes (dict): attributes info that need to be added.
  189. """
  190. for key, value in attributes.items():
  191. self.node_raw_dict[name].attributes[key] = value
  192. def add_graph_by_node_raw_list(self):
  193. """Build graph and add graph data to summary."""
  194. node_raw_list = []
  195. for key, value in self.node_raw_dict.items():
  196. node_raw_list.append(
  197. node(
  198. value.name,
  199. value.op,
  200. value.input,
  201. value.outputshape,
  202. value.dtype,
  203. value.attributes,
  204. )
  205. )
  206. self._get_file_writer().add_graph(graph(node_raw_list))

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。