You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

network_visualize.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. #! /usr/bin/env python3
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import logging
  11. import numpy as np
  12. from megengine.core.tensor.dtype import is_quantize
  13. from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
  14. from megengine.utils.module_stats import (
  15. get_param_stats,
  16. print_flops_stats,
  17. print_params_stats,
  18. print_summary,
  19. sizeof_fmt,
  20. )
  21. from megengine.utils.network import Network
# Module-level logger; visualize() reports missing TensorBoard packages and
# ambiguous multi-output operators through it.
logger = get_logger(__name__)
  23. def visualize(
  24. model_path: str,
  25. log_path: str,
  26. bar_length_max: int = 20,
  27. log_params: bool = True,
  28. log_flops: bool = True,
  29. ):
  30. r"""
  31. Load megengine dumped model and visualize graph structure with tensorboard log files.
  32. Can also record and print model's statistics like :func:`~.module_stats`
  33. :param model_path: dir path for megengine dumped model.
  34. :param log_path: dir path for tensorboard graph log.
  35. :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
  36. :param log_params: whether print and record params size.
  37. :param log_flops: whether print and record op flops.
  38. """
  39. if log_path:
  40. try:
  41. from tensorboard.compat.proto.attr_value_pb2 import AttrValue
  42. from tensorboard.compat.proto.config_pb2 import RunMetadata
  43. from tensorboard.compat.proto.graph_pb2 import GraphDef
  44. from tensorboard.compat.proto.node_def_pb2 import NodeDef
  45. from tensorboard.compat.proto.step_stats_pb2 import (
  46. AllocatorMemoryUsed,
  47. DeviceStepStats,
  48. NodeExecStats,
  49. StepStats,
  50. )
  51. from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
  52. from tensorboard.compat.proto.versions_pb2 import VersionDef
  53. from tensorboardX import SummaryWriter
  54. except ImportError:
  55. logger.error(
  56. "TensorBoard and TensorboardX are required for visualize.",
  57. exc_info=True,
  58. )
  59. return
  60. # FIXME: remove this after resolving "span dist too large" warning
  61. old_level = set_mgb_log_level(logging.ERROR)
  62. graph = Network.load(model_path)
  63. def process_name(name):
  64. return name.replace(".", "/").encode(encoding="utf-8")
  65. summary = [["item", "value"]]
  66. node_list = []
  67. flops_list = []
  68. params_list = []
  69. for node in graph.all_oprs:
  70. if hasattr(node, "output_idx"):
  71. node_oup = node.outputs[node.output_idx]
  72. else:
  73. if len(node.outputs) != 1:
  74. logger.warning(
  75. "OpNode {} has more than one output and not has 'output_idx' attr.".format(
  76. node
  77. )
  78. )
  79. node_oup = node.outputs[0]
  80. inp_list = [process_name(var.owner.name) for var in node.inputs]
  81. if log_path:
  82. attr = {
  83. "_output_shapes": AttrValue(
  84. list=AttrValue.ListValue(
  85. shape=[
  86. TensorShapeProto(
  87. dim=[
  88. TensorShapeProto.Dim(size=d) for d in node_oup.shape
  89. ]
  90. )
  91. ]
  92. )
  93. ),
  94. }
  95. if hasattr(node, "calc_flops"):
  96. flops_num = node.calc_flops()
  97. # add op flops attr
  98. if log_path:
  99. attr["flops"] = AttrValue(
  100. s=sizeof_fmt(flops_num).encode(encoding="utf-8")
  101. )
  102. flops_list.append(
  103. dict(
  104. name=node.name,
  105. class_name=node.type,
  106. input_shapes=[i.shape for i in node.inputs],
  107. output_shapes=[o.shape for o in node.outputs],
  108. flops_num=flops_num,
  109. flops_cum=0,
  110. )
  111. )
  112. if node.type == "ImmutableTensor":
  113. param_stats = get_param_stats(node.numpy())
  114. # add tensor size attr
  115. if log_path:
  116. attr["size"] = AttrValue(
  117. s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
  118. )
  119. param_stats["name"] = node.name
  120. params_list.append(param_stats)
  121. # FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
  122. if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
  123. continue
  124. if log_path:
  125. node_list.append(
  126. NodeDef(
  127. name=process_name(node.name),
  128. op=node.type,
  129. input=inp_list,
  130. attr=attr,
  131. )
  132. )
  133. total_flops, total_params = None, None
  134. if log_params:
  135. total_param_dims, total_param_size = print_params_stats(
  136. params_list, bar_length_max
  137. )
  138. if log_flops:
  139. total_flops = print_flops_stats(flops_list, bar_length_max)
  140. if log_path:
  141. graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
  142. device = "/device:CPU:0"
  143. stepstats = RunMetadata(
  144. step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
  145. )
  146. writer = SummaryWriter(log_path)
  147. writer._get_file_writer().add_graph((graph_def, stepstats))
  148. # summary
  149. extra_info = {
  150. "#ops": len(graph.all_oprs),
  151. "#params": len(params_list),
  152. "total_param_dims": sizeof_fmt(total_param_dims),
  153. "total_param_size": sizeof_fmt(total_param_size),
  154. "total_flops": sizeof_fmt(total_flops, suffix="OPs"),
  155. "flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
  156. }
  157. print_summary(**extra_info)
  158. # FIXME: remove this after resolving "span dist too large" warning
  159. _imperative_rt_logger.set_log_level(old_level)
  160. return total_params, total_flops
  161. def main():
  162. parser = argparse.ArgumentParser(
  163. description="load a megengine dumped model and export log file for tensorboard visualization.",
  164. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  165. )
  166. parser.add_argument("model_path", help="dumped model path.")
  167. parser.add_argument("--log_path", help="tensorboard log path.")
  168. parser.add_argument(
  169. "--bar_length_max",
  170. type=int,
  171. default=20,
  172. help="size of bar indicating max flops or parameter size in net stats.",
  173. )
  174. parser.add_argument(
  175. "--log_params",
  176. action="store_true",
  177. help="whether print and record params size.",
  178. )
  179. parser.add_argument(
  180. "--log_flops", action="store_true", help="whether print and record op flops.",
  181. )
  182. parser.add_argument(
  183. "--all",
  184. action="store_true",
  185. help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
  186. )
  187. args = parser.parse_args()
  188. if args.all:
  189. args.log_params = True
  190. args.log_flops = True
  191. if not args.log_path:
  192. args.log_path = "./log"
  193. kwargs = vars(args)
  194. kwargs.pop("all")
  195. visualize(**kwargs)
  196. if __name__ == "__main__":
  197. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。