
network_visualize.py 7.9 kB

#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import logging
import re

import numpy as np

from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
    enable_receptive_field,
    get_op_stats,
    get_param_stats,
    print_op_stats,
    print_param_stats,
    print_summary,
    sizeof_fmt,
)
from megengine.utils.network import Network

logger = get_logger(__name__)


def visualize(
    model_path: str,
    log_path: str,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :param log_params: whether print and record params size.
    :param log_flops: whether print and record op flops.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    # FIXME: remove this after resolving "span dist too large" warning
    old_level = set_mgb_log_level(logging.ERROR)

    enable_receptive_field()

    graph = Network.load(model_path)

    def process_name(name):
        # nodes that start with point or contain float const will lead to display bug
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    summary = [["item", "value"]]
    node_list = []
    flops_list = []
    params_list = []
    for node in graph.all_oprs:
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and not has 'output_idx' attr.".format(
                        node
                    )
                )
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        if log_path:
            # detail format see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(
                    list=AttrValue.ListValue(
                        shape=[
                            TensorShapeProto(
                                dim=[
                                    TensorShapeProto.Dim(size=d) for d in node_oup.shape
                                ]
                            )
                        ]
                    )
                ),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }

        flops_stats = get_op_stats(node, node.inputs, node.outputs)
        if flops_stats is not None:
            # add op flops attr
            if log_path and hasattr(flops_stats, "flops_num"):
                attr["flops"] = AttrValue(
                    s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
                )
            flops_stats["name"] = node.name
            flops_stats["class_name"] = node.type
            flops_list.append(flops_stats)

        if node.type == "ImmutableTensor":
            param_stats = get_param_stats(node.numpy())
            # add tensor size attr
            if log_path:
                attr["size"] = AttrValue(
                    s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
                )
            param_stats["name"] = node.name
            params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                )
            )

    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }

    total_flops, total_param_dims, total_param_size = 0, 0, 0
    if log_params:
        total_param_dims, total_param_size = print_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
    if log_flops:
        total_flops = print_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_params and log_flops:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
        device = "/device:CPU:0"
        stepstats = RunMetadata(
            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
        )
        writer = SummaryWriter(log_path)
        writer._get_file_writer().add_graph((graph_def, stepstats))

    print_summary(**extra_info)

    # FIXME: remove this after resolving "span dist too large" warning
    _imperative_rt_logger.set_log_level(old_level)

    return total_param_size, total_flops


def main():
    parser = argparse.ArgumentParser(
        description="load a megengine dumped model and export log file for tensorboard visualization.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("model_path", help="dumped model path.")
    parser.add_argument("--log_path", help="tensorboard log path.")
    parser.add_argument(
        "--bar_length_max",
        type=int,
        default=20,
        help="size of bar indicating max flops or parameter size in net stats.",
    )
    parser.add_argument(
        "--log_params",
        action="store_true",
        help="whether print and record params size.",
    )
    parser.add_argument(
        "--log_flops", action="store_true", help="whether print and record op flops.",
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
    )
    args = parser.parse_args()

    if args.all:
        args.log_params = True
        args.log_flops = True
    if not args.log_path:
        args.log_path = "./log"

    kwargs = vars(args)
    kwargs.pop("all")
    visualize(**kwargs)


if __name__ == "__main__":
    main()
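
For reference, a minimal usage sketch (not part of the file above), assuming the script is importable as megengine.tools.network_visualize, as it is laid out in the MegEngine source tree, and that a model has already been dumped to the hypothetical file resnet18.mge:

from megengine.tools.network_visualize import visualize

# Write TensorBoard event files to ./log and print parameter/FLOPs statistics.
# "resnet18.mge" is a placeholder for a model dumped with MegEngine.
total_param_size, total_flops = visualize(
    "resnet18.mge", log_path="./log", log_params=True, log_flops=True,
)

# Roughly equivalent command-line form (then inspect with `tensorboard --logdir ./log`):
#     python3 network_visualize.py resnet18.mge --log_path ./log --all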

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
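
As a quick sanity check, GPU availability can be queried from Python before running a GPU program. This is a minimal sketch assuming the megengine.is_cuda_available() and megengine.set_default_device() helpers and MegEngine's "gpu0"/"cpu0" device naming:

import megengine

# Select a GPU device when the bundled CUDA runtime can find one, otherwise fall back to CPU.
if megengine.is_cuda_available():
    megengine.set_default_device("gpu0")
    print("CUDA is available, running on gpu0")
else:
    megengine.set_default_device("cpu0")
    print("no usable GPU found, running on cpu0")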