You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

network_visualize.py 7.6 kB

[line numbers 1–215]
  1. #! /usr/bin/env python3
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import logging
  11. import numpy as np
  12. from megengine.core.tensor.dtype import is_quantize
  13. from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
  14. from megengine.utils.module_stats import (
  15. print_flops_stats,
  16. print_params_stats,
  17. sizeof_fmt,
  18. )
  19. from megengine.utils.network import Network
  20. logger = get_logger(__name__)
  21. def visualize(
  22. model_path: str,
  23. log_path: str,
  24. bar_length_max: int = 20,
  25. log_params: bool = True,
  26. log_flops: bool = True,
  27. ):
  28. r"""
  29. Load megengine dumped model and visualize graph structure with tensorboard log files.
  30. Can also record and print model's statistics like :func:`~.module_stats`
  31. :param model_path: dir path for megengine dumped model.
  32. :param log_path: dir path for tensorboard graph log.
  33. :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
  34. :param log_params: whether print and record params size.
  35. :param log_flops: whether print and record op flops.
  36. """
  37. if log_path:
  38. try:
  39. from tensorboard.compat.proto.attr_value_pb2 import AttrValue
  40. from tensorboard.compat.proto.config_pb2 import RunMetadata
  41. from tensorboard.compat.proto.graph_pb2 import GraphDef
  42. from tensorboard.compat.proto.node_def_pb2 import NodeDef
  43. from tensorboard.compat.proto.step_stats_pb2 import (
  44. AllocatorMemoryUsed,
  45. DeviceStepStats,
  46. NodeExecStats,
  47. StepStats,
  48. )
  49. from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
  50. from tensorboard.compat.proto.versions_pb2 import VersionDef
  51. from tensorboardX import SummaryWriter
  52. except ImportError:
  53. logger.error(
  54. "TensorBoard and TensorboardX are required for visualize.",
  55. exc_info=True,
  56. )
  57. return
  58. # FIXME: remove this after resolving "span dist too large" warning
  59. old_level = set_mgb_log_level(logging.ERROR)
  60. graph = Network.load(model_path)
  61. def process_name(name):
  62. return name.replace(".", "/").encode(encoding="utf-8")
  63. node_list = []
  64. flops_list = []
  65. params_list = []
  66. for node in graph.all_oprs:
  67. if hasattr(node, "output_idx"):
  68. node_oup = node.outputs[node.output_idx]
  69. else:
  70. if len(node.outputs) != 1:
  71. logger.warning(
  72. "OpNode {} has more than one output and not has 'output_idx' attr.".format(
  73. node
  74. )
  75. )
  76. node_oup = node.outputs[0]
  77. inp_list = [process_name(var.owner.name) for var in node.inputs]
  78. if log_path:
  79. attr = {
  80. "_output_shapes": AttrValue(
  81. list=AttrValue.ListValue(
  82. shape=[
  83. TensorShapeProto(
  84. dim=[
  85. TensorShapeProto.Dim(size=d) for d in node_oup.shape
  86. ]
  87. )
  88. ]
  89. )
  90. ),
  91. }
  92. if hasattr(node, "calc_flops"):
  93. flops_num = node.calc_flops()
  94. # add op flops attr
  95. if log_path:
  96. attr["flops"] = AttrValue(
  97. s=sizeof_fmt(flops_num).encode(encoding="utf-8")
  98. )
  99. flops_list.append(
  100. dict(
  101. name=node.name,
  102. class_name=node.type,
  103. input_shapes=[i.shape for i in node.inputs],
  104. output_shapes=[o.shape for o in node.outputs],
  105. flops_num=flops_num,
  106. flops_cum=0,
  107. )
  108. )
  109. if node.type == "ImmutableTensor":
  110. param_dim = np.prod(node_oup.shape)
  111. # TODO: consider other quantize dtypes
  112. param_bytes = 1 if is_quantize(node_oup.dtype) else 4
  113. # add tensor size attr
  114. if log_path:
  115. attr["size"] = AttrValue(
  116. s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8")
  117. )
  118. params_list.append(
  119. dict(
  120. name=node.name,
  121. shape=node_oup.shape,
  122. param_dim=param_dim,
  123. bits=param_bytes * 8,
  124. size=param_dim * param_bytes,
  125. size_cum=0,
  126. mean="{:.2g}".format(node.numpy().mean()),
  127. std="{:.2g}".format(node.numpy().std()),
  128. )
  129. )
  130. # FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
  131. if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
  132. continue
  133. if log_path:
  134. node_list.append(
  135. NodeDef(
  136. name=process_name(node.name),
  137. op=node.type,
  138. input=inp_list,
  139. attr=attr,
  140. )
  141. )
  142. total_flops, total_params = None, None
  143. if log_params:
  144. total_params = print_params_stats(params_list, bar_length_max)
  145. if log_flops:
  146. total_flops = print_flops_stats(flops_list, bar_length_max)
  147. if log_path:
  148. graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
  149. device = "/device:CPU:0"
  150. stepstats = RunMetadata(
  151. step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
  152. )
  153. writer = SummaryWriter(log_path)
  154. writer._get_file_writer().add_graph((graph_def, stepstats))
  155. # summary
  156. # FIXME: remove this after resolving "span dist too large" warning
  157. _imperative_rt_logger.set_log_level(old_level)
  158. return total_params, total_flops
  159. def main():
  160. parser = argparse.ArgumentParser(
  161. description="load a megengine dumped model and export log file for tensorboard visualization.",
  162. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  163. )
  164. parser.add_argument("model_path", help="dumped model path.")
  165. parser.add_argument("--log_path", help="tensorboard log path.")
  166. parser.add_argument(
  167. "--bar_length_max",
  168. type=int,
  169. default=20,
  170. help="size of bar indicating max flops or parameter size in net stats.",
  171. )
  172. parser.add_argument(
  173. "--log_params",
  174. action="store_true",
  175. help="whether print and record params size.",
  176. )
  177. parser.add_argument(
  178. "--log_flops", action="store_true", help="whether print and record op flops.",
  179. )
  180. parser.add_argument(
  181. "--all",
  182. action="store_true",
  183. help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
  184. )
  185. args = parser.parse_args()
  186. if args.all:
  187. args.log_params = True
  188. args.log_flops = True
  189. if not args.log_path:
  190. args.log_path = "./log"
  191. kwargs = vars(args)
  192. kwargs.pop("all")
  193. visualize(**kwargs)
  194. if __name__ == "__main__":
  195. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。