
event_parser.py 8.4 kB

# -*- coding: UTF-8 -*-
"""
Copyright 2021 Tianshu AI Platform. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""
import struct
import json

import numpy as np

from oneflow.customized.utils.plugin_hparams_pb2 import HParamsPluginData
from oneflow.customized.utils.graph_pb2 import GraphDef


def get_parser(value, step, wall_time):
    """
    Parse a single value field of an event into a uniform dict.

    :param value: a value field of an event
    :param step: the step at which the event was recorded
    :param wall_time: the wall time at which the event was recorded
    :return: dict = {tag, step, wall_time, value, type}
    """
    data = dict(step=step, wall_time=wall_time)
    if value.HasField('simple_value'):
        value = _get_scalar(value)
    elif value.HasField('image'):
        value = _get_image(value)
    elif value.HasField('audio'):
        value = _get_audio(value)
    elif value.HasField('histo'):
        value = _get_hist(value)
    elif value.HasField('projector'):
        value = _get_projector(value)
    elif value.HasField('metadata'):
        if value.metadata.plugin_data.plugin_name == 'hparams':
            value = _get_hparams(value)
        elif value.metadata.plugin_data.plugin_name == 'text':
            value = _get_text(value)
        elif value.metadata.plugin_data.plugin_name == 'featuremap':
            value = _get_featuremap(value)
        elif value.metadata.plugin_data.plugin_name == 'transformer':
            if 'transformertext' in value.tag:
                value = _get_TransformerText(value)
            else:
                value = _get_transformer(value)
        elif value.metadata.plugin_data.plugin_name == 'hiddenstate':
            value = _get_state(value)
        else:
            raise Exception(f'cannot parse {value.metadata.plugin_data.plugin_name} data.')
    else:
        raise Exception(f'cannot parse this data: {value}')
    data.update(value)
    return data
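
# Example of the parsed output (illustrative values, assuming a scalar
# summary tagged 'loss' recorded at step 10):
#   get_parser(value, step=10, wall_time=1626000000.0)
#   -> {'step': 10, 'wall_time': 1626000000.0,
#       'tag': 'loss', 'value': 0.25, 'type': 'scalar'}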


def _decode_byte(tensor):
    # If the tensor holds a single float value
    if tensor.dtype == 1:
        return struct.unpack('f', tensor.tensor_content)[0]


def _decoder_tensor(tensor):
    # tensor_content is a raw byte stream; rebuild the array from the recorded shape
    tensor_shape = tuple([i.size for i in tensor.tensor_shape.dim])
    tensor_content = np.frombuffer(tensor.tensor_content, dtype=tensor.dtype)
    return tensor_content.reshape(tensor_shape)
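
# Illustration only (not part of the original module): the same byte-stream
# decoding that _decoder_tensor performs, applied to plain bytes for a
# 2 x 3 float32 tensor:
#   content = np.arange(6, dtype=np.float32).tobytes()
#   np.frombuffer(content, dtype=np.float32).reshape((2, 3))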


def get_graph(event):
    graph = GraphDef()
    graph.ParseFromString(event.graph_def)
    return dict(wall_time=event.wall_time,
                value=graph,
                type='graph')


def _get_scalar(value):
    """
    Decode a scalar event
    :param value: A value field of an event
    :return: Decoded scalar
    """
    return dict(tag=value.tag,
                value=value.simple_value,
                type='scalar')


def _get_image(value):
    """
    Decode an image event
    :param value: A value field of an event
    :return: Decoded image
    """
    dic = {
        'width': value.image.width,
        'height': value.image.height,
        'colorspace': value.image.colorspace,
        'encoded_image_string': value.image.encoded_image_string
    }
    return dict(tag=value.tag,
                value=dic,
                type='image')


def _get_text(value):
    """
    Return text data
    :param value: A value field of an event
    :return: text data
    """
    return dict(tag=value.tag,
                value=np.array([v.decode() for v in value.tensor.string_val]),
                type='text')


def _get_audio(value):
    """Decode an audio event into its sample metadata and encoded bytes."""
    dic = {'sample_rate': value.audio.sample_rate,
           'num_channels': value.audio.num_channels,
           'length_frames': value.audio.length_frames,
           'encoded_audio_string': value.audio.encoded_audio_string}
    return dict(tag=value.tag,
                value=dic,
                type='audio')


def _get_hist(value):
    """Decode a histogram event into its summary statistics and buckets."""
    dic = {'min': value.histo.min,
           'max': value.histo.max,
           'num': value.histo.num,
           'sum': value.histo.sum,
           'sum_squares': value.histo.sum_squares,
           'bucket_limit': np.array(value.histo.bucket_limit),
           'bucket': np.array(value.histo.bucket)}
    return dict(tag=value.tag,
                value=dic,
                type='hist')


def _get_hparams(value):
    """Parse the hparams plugin payload carried in the value's metadata."""
    metadata = value.metadata
    plugin_data = HParamsPluginData()
    plugin_data.ParseFromString(metadata.plugin_data.content)
    return dict(tag=value.tag,
                value=plugin_data,
                type='hparams')


def _get_embedding(value):
    """Decode a projector embedding, or its attached sample, from an event."""
    projector = value.projector
    if projector.embedding.HasField('sample'):
        sample_type = {1: 'audio', 2: 'text', 3: 'image'}
        sample = projector.embedding.sample
        data = dict(type=sample_type[sample.type],
                    X=_decoder_tensor(sample.X))
        return dict(tag='sample_' + value.tag,
                    value=data,
                    type='embedding')
    else:
        embedding = projector.embedding
        return dict(tag=value.tag,
                    value=_decoder_tensor(embedding.value),
                    label=_decoder_tensor(embedding.label) if embedding.HasField('label') else np.array([]),
                    type='embedding')


def _get_exception(value):
    return dict(tag=value.tag,
                value=_decoder_tensor(value.projector.exception.value),
                type='exception')


def _get_projector(value):
    projector = value.projector
    if projector.HasField('embedding'):
        return _get_embedding(value)
    else:
        return _get_exception(value)


def filter_graph(file):
    variable_names = {}
    graph = json.loads(file)
    for sub_graph in graph:
        cfg = sub_graph["config"]
        # Iterate over a copy so layers can be removed from the original list
        cfg_copy = cfg["layers"].copy()
        for layer in cfg_copy:
            if layer["class_name"] == "variable":
                _name = layer["name"]
                variable_names[_name] = layer
                cfg["layers"].remove(layer)
    # Second pass: drop any name in `inbound_nodes` that appears in `variable_names`
    for sub_graph in graph:
        cfg = sub_graph["config"]
        for layer in cfg["layers"]:
            in_nodes = layer["inbound_nodes"]
            in_nodes_copy = in_nodes.copy()
            for node in in_nodes_copy:
                # Remove the node if it refers to a filtered variable
                if node in variable_names.keys():
                    in_nodes.remove(node)
    graph_str = json.dumps(graph)
    return graph_str
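
# Illustration only (assumed input layout, inferred from the keys used above):
# given a serialized graph such as
#   [{"config": {"layers": [
#       {"class_name": "variable", "name": "w", "inbound_nodes": []},
#       {"class_name": "dense", "name": "fc", "inbound_nodes": ["w", "x"]}]}}]
# filter_graph drops the "variable" layer and removes "w" from the remaining
# layer's inbound_nodes, returning the filtered graph re-serialized as JSON.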


def _get_featuremap(value):
    return dict(tag=value.tag,
                value=np.array(_decoder_tensor(value.tensor)),
                type='featuremap')


def _get_transformer(value):
    return dict(tag=value.tag,
                value=np.array(_decoder_tensor(value.tensor)),
                type='transformer')


def _get_TransformerText(value):
    if "transformertext-sentence" in value.tag:
        return dict(tag=value.tag,
                    value=np.array(_decoder_tensor(value.tensor)),
                    type='transformer')
    else:
        return dict(tag=value.tag,
                    value=_decoder_TransformerText(value.transformer),
                    type='transformer')


def _get_state(value):
    return dict(tag=value.tag,
                value=np.array(_decoder_tensor(value.tensor)),
                type='hiddenstate')


def _decoder_TransformerText(value):
    data = {}
    for attentionItem in list(value.attentionItem):
        tag = attentionItem.tag
        attention_kid_data = {}
        attention_kid_data['attn'] = _decoder_tensor(attentionItem.attn)
        attention_kid_data['left_text'] = _decoder_tensor(attentionItem.left)
        attention_kid_data['right_text'] = _decoder_tensor(attentionItem.right)
        data[tag] = attention_kid_data
    data["bidirectional"] = value.bidirectional
    data["default_filter"] = value.default_filter
    data["displayMode"] = value.displayMode
    data["head"] = value.head
    data["layer"] = value.layer
    return data
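
A minimal usage sketch (not part of the file), assuming an iterable of decoded event protos named `events` whose layout mirrors how get_graph and get_parser read them; the `summary.value` container and the truthiness check on `graph_def` are assumptions about the event proto, not something this file defines:

    parsed = []
    for event in events:  # hypothetical iterator over decoded event protos
        if event.graph_def:  # assumed: serialized GraphDef bytes are present
            parsed.append(get_graph(event))
        else:
            for value in event.summary.value:  # assumed container of value fields
                parsed.append(get_parser(value, event.step, event.wall_time))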

A suite of platforms and tools including a one-stop algorithm development platform, a high-performance distributed deep learning framework, an advanced algorithm and model library, a visual model knowledge-refinement platform, and a data visualization and analysis platform. It has built distinctive strengths in efficient distributed model training, data processing and visual analysis, and model knowledge refinement and light-weighting, and currently provides AI application enablement to nearly a thousand organizations and individuals across industry, academia, and research.