You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

network_visualize.py 7.8 kB


  1. #! /usr/bin/env python3
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import logging
  11. import re
  12. import numpy as np
  13. from megengine.core.tensor.dtype import is_quantize
  14. from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
  15. from megengine.utils.module_stats import (
  16. get_flops_stats,
  17. get_param_stats,
  18. print_flops_stats,
  19. print_params_stats,
  20. print_summary,
  21. sizeof_fmt,
  22. )
  23. from megengine.utils.network import Network
  24. logger = get_logger(__name__)
  25. def visualize(
  26. model_path: str,
  27. log_path: str,
  28. bar_length_max: int = 20,
  29. log_params: bool = True,
  30. log_flops: bool = True,
  31. ):
  32. r"""
  33. Load megengine dumped model and visualize graph structure with tensorboard log files.
  34. Can also record and print model's statistics like :func:`~.module_stats`
  35. :param model_path: dir path for megengine dumped model.
  36. :param log_path: dir path for tensorboard graph log.
  37. :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
  38. :param log_params: whether print and record params size.
  39. :param log_flops: whether print and record op flops.
  40. """
  41. if log_path:
  42. try:
  43. from tensorboard.compat.proto.attr_value_pb2 import AttrValue
  44. from tensorboard.compat.proto.config_pb2 import RunMetadata
  45. from tensorboard.compat.proto.graph_pb2 import GraphDef
  46. from tensorboard.compat.proto.node_def_pb2 import NodeDef
  47. from tensorboard.compat.proto.step_stats_pb2 import (
  48. AllocatorMemoryUsed,
  49. DeviceStepStats,
  50. NodeExecStats,
  51. StepStats,
  52. )
  53. from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
  54. from tensorboard.compat.proto.versions_pb2 import VersionDef
  55. from tensorboardX import SummaryWriter
  56. except ImportError:
  57. logger.error(
  58. "TensorBoard and TensorboardX are required for visualize.",
  59. exc_info=True,
  60. )
  61. return
  62. # FIXME: remove this after resolving "span dist too large" warning
  63. old_level = set_mgb_log_level(logging.ERROR)
  64. graph = Network.load(model_path)
  65. def process_name(name):
  66. # nodes that start with point or contain float const will lead to display bug
  67. if not re.match(r"^[+-]?\d*\.\d*", name):
  68. name = name.replace(".", "/")
  69. return name.encode(encoding="utf-8")
  70. summary = [["item", "value"]]
  71. node_list = []
  72. flops_list = []
  73. params_list = []
  74. for node in graph.all_oprs:
  75. if hasattr(node, "output_idx"):
  76. node_oup = node.outputs[node.output_idx]
  77. else:
  78. if len(node.outputs) != 1:
  79. logger.warning(
  80. "OpNode {} has more than one output and not has 'output_idx' attr.".format(
  81. node
  82. )
  83. )
  84. node_oup = node.outputs[0]
  85. inp_list = [process_name(var.owner.name) for var in node.inputs]
  86. if log_path:
  87. # detail format see tensorboard/compat/proto/attr_value.proto
  88. attr = {
  89. "_output_shapes": AttrValue(
  90. list=AttrValue.ListValue(
  91. shape=[
  92. TensorShapeProto(
  93. dim=[
  94. TensorShapeProto.Dim(size=d) for d in node_oup.shape
  95. ]
  96. )
  97. ]
  98. )
  99. ),
  100. "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
  101. "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
  102. }
  103. flops_stats = get_flops_stats(node, node.inputs, node.outputs)
  104. if flops_stats is not None:
  105. # add op flops attr
  106. if log_path and hasattr(flops_stats, "flops_num"):
  107. attr["flops"] = AttrValue(
  108. s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
  109. )
  110. flops_stats["name"] = node.name
  111. flops_stats["class_name"] = node.type
  112. flops_list.append(flops_stats)
  113. if node.type == "ImmutableTensor":
  114. param_stats = get_param_stats(node.numpy())
  115. # add tensor size attr
  116. if log_path:
  117. attr["size"] = AttrValue(
  118. s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
  119. )
  120. param_stats["name"] = node.name
  121. params_list.append(param_stats)
  122. if log_path:
  123. node_list.append(
  124. NodeDef(
  125. name=process_name(node.name),
  126. op=node.type,
  127. input=inp_list,
  128. attr=attr,
  129. )
  130. )
  131. # summary
  132. extra_info = {
  133. "#ops": len(graph.all_oprs),
  134. "#params": len(params_list),
  135. }
  136. total_flops, total_param_dims, total_param_size = 0, 0, 0
  137. if log_params:
  138. total_param_dims, total_param_size = print_params_stats(
  139. params_list, bar_length_max
  140. )
  141. extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
  142. extra_info["total_param_size"] = sizeof_fmt(total_param_size)
  143. if log_flops:
  144. total_flops = print_flops_stats(flops_list, bar_length_max)
  145. extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
  146. if log_params and log_flops:
  147. extra_info["flops/param_size"] = "{:3.3f}".format(
  148. total_flops / total_param_size
  149. )
  150. if log_path:
  151. graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
  152. device = "/device:CPU:0"
  153. stepstats = RunMetadata(
  154. step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
  155. )
  156. writer = SummaryWriter(log_path)
  157. writer._get_file_writer().add_graph((graph_def, stepstats))
  158. print_summary(**extra_info)
  159. # FIXME: remove this after resolving "span dist too large" warning
  160. _imperative_rt_logger.set_log_level(old_level)
  161. return total_param_size, total_flops
  162. def main():
  163. parser = argparse.ArgumentParser(
  164. description="load a megengine dumped model and export log file for tensorboard visualization.",
  165. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  166. )
  167. parser.add_argument("model_path", help="dumped model path.")
  168. parser.add_argument("--log_path", help="tensorboard log path.")
  169. parser.add_argument(
  170. "--bar_length_max",
  171. type=int,
  172. default=20,
  173. help="size of bar indicating max flops or parameter size in net stats.",
  174. )
  175. parser.add_argument(
  176. "--log_params",
  177. action="store_true",
  178. help="whether print and record params size.",
  179. )
  180. parser.add_argument(
  181. "--log_flops", action="store_true", help="whether print and record op flops.",
  182. )
  183. parser.add_argument(
  184. "--all",
  185. action="store_true",
  186. help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
  187. )
  188. args = parser.parse_args()
  189. if args.all:
  190. args.log_params = True
  191. args.log_flops = True
  192. if not args.log_path:
  193. args.log_path = "./log"
  194. kwargs = vars(args)
  195. kwargs.pop("all")
  196. visualize(**kwargs)
  197. if __name__ == "__main__":
  198. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。