You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

network_visualize.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. #! /usr/bin/env python3
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import logging
  11. import re
  12. from collections import namedtuple
  13. import numpy as np
  14. from tqdm import tqdm
  15. import megengine as mge
  16. from megengine.core.tensor.dtype import is_quantize
  17. from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
  18. from megengine.utils.module_stats import (
  19. enable_receptive_field,
  20. get_activation_stats,
  21. get_op_stats,
  22. get_param_stats,
  23. print_activations_stats,
  24. print_op_stats,
  25. print_param_stats,
  26. print_summary,
  27. sizeof_fmt,
  28. sum_activations_stats,
  29. sum_op_stats,
  30. sum_param_stats,
  31. )
  32. from megengine.utils.network import Network
# Module-level logger from megengine's logging helpers; used for warnings,
# errors and info messages throughout this script.
logger = get_logger(__name__)
  34. def visualize(
  35. model_path: str,
  36. log_path: str,
  37. input: np.ndarray = None,
  38. inp_dict: dict = None,
  39. cal_params: bool = True,
  40. cal_flops: bool = True,
  41. cal_activations: bool = True,
  42. logging_to_stdout: bool = True,
  43. bar_length_max: int = 20,
  44. ):
  45. r"""Load megengine dumped model and visualize graph structure with tensorboard log files.
  46. Can also record and print model's statistics like :func:`~.module_stats`
  47. Args:
  48. model_path: dir path for megengine dumped model.
  49. log_path: dir path for tensorboard graph log.
  50. input: user defined input data for running model and calculating stats, alternative with inp_dict, used when the model has only one input.
  51. inp_dict: input dict for running model and calculating stats, alternative with input, used when the model has more than one input. When both input and inp_dict are None, a random input will be used.
  52. cal_params: whether calculate and record params size.
  53. cal_flops: whether calculate and record op flops.
  54. cal_activations: whether calculate and record op activations.
  55. logging_to_stdout: whether print all calculated statistic details.
  56. bar_length_max: size of bar indicating max flops or parameter size in net stats.
  57. model_path: str:
  58. log_path: str:
  59. input: np.ndarray:
  60. inp_dict: dict:
  61. cal_params: bool:
  62. cal_flops: bool:
  63. cal_activations: bool:
  64. logging_to_stdout: bool:
  65. bar_length_max: int:
  66. """
  67. if log_path:
  68. try:
  69. from tensorboard.compat.proto.attr_value_pb2 import AttrValue
  70. from tensorboard.compat.proto.config_pb2 import RunMetadata
  71. from tensorboard.compat.proto.graph_pb2 import GraphDef
  72. from tensorboard.compat.proto.node_def_pb2 import NodeDef
  73. from tensorboard.compat.proto.step_stats_pb2 import (
  74. AllocatorMemoryUsed,
  75. DeviceStepStats,
  76. NodeExecStats,
  77. StepStats,
  78. )
  79. from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
  80. from tensorboard.compat.proto.versions_pb2 import VersionDef
  81. from tensorboardX import SummaryWriter
  82. except ImportError:
  83. logger.error(
  84. "TensorBoard and TensorboardX are required for visualize.",
  85. exc_info=True,
  86. )
  87. return
  88. enable_receptive_field()
  89. graph = Network.load(model_path)
  90. graph.reset_batch_size(1)
  91. has_input = False
  92. if input is not None or inp_dict is not None:
  93. has_input = True
  94. repl_dict = {}
  95. inp_vars = graph.input_vars
  96. if inp_dict is not None:
  97. assert len(inp_dict) == len(
  98. inp_vars
  99. ), "Inputs are not sufficient for calculation."
  100. for v in inp_vars:
  101. new_input = graph.make_const(inp_dict[v.name], name=v.name)
  102. repl_dict[v] = new_input
  103. else:
  104. assert len(inp_vars) == 1, "The graph needs more than one input."
  105. inp_var = inp_vars[0]
  106. repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
  107. graph.replace_vars(repl_dict=repl_dict)
  108. graph._compile()
  109. def process_name(name):
  110. # nodes that start with point or contain float const will lead to display bug
  111. if not re.match(r"^[+-]?\d*\.\d*", name):
  112. name = name.replace(".", "/")
  113. return name.encode(encoding="utf-8")
  114. summary = [["item", "value"]]
  115. node_list = []
  116. flops_list = []
  117. params_list = []
  118. activations_list = []
  119. total_stats = namedtuple(
  120. "total_stats", ["param_size", "param_dims", "flops", "act_size", "act_dims"]
  121. )
  122. stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
  123. for node in tqdm(graph.all_oprs):
  124. if hasattr(node, "output_idx"):
  125. node_oup = node.outputs[node.output_idx]
  126. else:
  127. if len(node.outputs) != 1:
  128. logger.warning(
  129. "OpNode {} has more than one output and not has 'output_idx' attr.".format(
  130. node
  131. )
  132. )
  133. node_oup = node.outputs[0]
  134. inp_list = [process_name(var.owner.name) for var in node.inputs]
  135. if log_path:
  136. # detail format see tensorboard/compat/proto/attr_value.proto
  137. attr = {
  138. "_output_shapes": AttrValue(
  139. list=AttrValue.ListValue(
  140. shape=[
  141. TensorShapeProto(
  142. dim=[
  143. TensorShapeProto.Dim(size=d) for d in node_oup.shape
  144. ]
  145. )
  146. ]
  147. )
  148. ),
  149. "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
  150. "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
  151. }
  152. if cal_flops:
  153. flops_stats = get_op_stats(node, node.inputs, node.outputs)
  154. if flops_stats is not None:
  155. # add op flops attr
  156. if log_path and hasattr(flops_stats, "flops_num"):
  157. attr["flops"] = AttrValue(
  158. s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
  159. )
  160. flops_stats["name"] = node.name
  161. flops_stats["class_name"] = node.type
  162. flops_list.append(flops_stats)
  163. if cal_activations:
  164. acts = get_activation_stats(node_oup, has_input=has_input)
  165. acts["name"] = node.name
  166. acts["class_name"] = node.type
  167. activations_list.append(acts)
  168. if cal_params:
  169. if node.type == "ImmutableTensor":
  170. param_stats = get_param_stats(node_oup)
  171. # add tensor size attr
  172. if log_path:
  173. attr["size"] = AttrValue(
  174. s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
  175. )
  176. param_stats["name"] = node.name
  177. params_list.append(param_stats)
  178. if log_path:
  179. node_list.append(
  180. NodeDef(
  181. name=process_name(node.name),
  182. op=node.type,
  183. input=inp_list,
  184. attr=attr,
  185. )
  186. )
  187. # summary
  188. extra_info = {
  189. "#ops": len(graph.all_oprs),
  190. "#params": len(params_list),
  191. }
  192. (
  193. total_flops,
  194. total_param_dims,
  195. total_param_size,
  196. total_act_dims,
  197. total_act_size,
  198. ) = (0, 0, 0, 0, 0)
  199. if cal_params:
  200. total_param_dims, total_param_size, params_list = sum_param_stats(
  201. params_list, bar_length_max
  202. )
  203. extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
  204. extra_info["total_param_size"] = sizeof_fmt(total_param_size)
  205. if logging_to_stdout:
  206. print_param_stats(params_list)
  207. if cal_flops:
  208. total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
  209. extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
  210. if logging_to_stdout:
  211. print_op_stats(flops_list)
  212. if cal_activations:
  213. total_act_dims, total_act_size, activations_list = sum_activations_stats(
  214. activations_list, bar_length_max
  215. )
  216. extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
  217. extra_info["total_act_size"] = sizeof_fmt(total_act_size)
  218. if logging_to_stdout:
  219. print_activations_stats(activations_list, has_input=has_input)
  220. if cal_flops and cal_params:
  221. extra_info["flops/param_size"] = "{:3.3f}".format(
  222. total_flops / total_param_size
  223. )
  224. if log_path:
  225. graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
  226. device = "/device:CPU:0"
  227. stepstats = RunMetadata(
  228. step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
  229. )
  230. writer = SummaryWriter(log_path)
  231. writer._get_file_writer().add_graph((graph_def, stepstats))
  232. print_summary(**extra_info)
  233. return (
  234. total_stats(
  235. param_size=total_param_size,
  236. param_dims=total_param_dims,
  237. flops=total_flops,
  238. act_size=total_act_size,
  239. act_dims=total_act_dims,
  240. ),
  241. stats_details(
  242. params=params_list, flops=flops_list, activations=activations_list
  243. ),
  244. )
  245. def main():
  246. parser = argparse.ArgumentParser(
  247. description="load a megengine dumped model and export log file for tensorboard visualization.",
  248. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  249. )
  250. parser.add_argument("model_path", help="dumped model path.")
  251. parser.add_argument("--log_path", help="tensorboard log path.")
  252. parser.add_argument(
  253. "--load_input_data",
  254. help="load input data from pickle file; it should be a numpy array or a dict of numpy array",
  255. )
  256. parser.add_argument(
  257. "--bar_length_max",
  258. type=int,
  259. default=20,
  260. help="size of bar indicating max flops or parameter size in net stats.",
  261. )
  262. parser.add_argument(
  263. "--cal_params",
  264. action="store_true",
  265. help="whether calculate and record params size.",
  266. )
  267. parser.add_argument(
  268. "--cal_flops",
  269. action="store_true",
  270. help="whether calculate and record op flops.",
  271. )
  272. parser.add_argument(
  273. "--cal_activations",
  274. action="store_true",
  275. help="whether calculate and record op activations.",
  276. )
  277. parser.add_argument(
  278. "--logging_to_stdout",
  279. action="store_true",
  280. help="whether print all calculated statistic details.",
  281. )
  282. parser.add_argument(
  283. "--all",
  284. action="store_true",
  285. help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
  286. )
  287. args = parser.parse_args()
  288. if args.load_input_data:
  289. logger.info("load data from {}".format(args.load_input_data))
  290. data = mge.load(args.load_input_data)
  291. if isinstance(data, dict):
  292. for v in data.values():
  293. assert isinstance(
  294. v, np.ndarray
  295. ), "data should provide ndarray; got {} instead".format(v)
  296. args.inp_dict = data
  297. elif isinstance(data, np.ndarray):
  298. args.input = data
  299. else:
  300. logger.error("input data should be a numpy array or a dict of numpy array")
  301. if args.all:
  302. args.cal_params = True
  303. args.cal_flops = True
  304. args.cal_activations = True
  305. args.logging_to_stdout = True
  306. if not args.log_path:
  307. args.log_path = "./log"
  308. kwargs = vars(args)
  309. kwargs.pop("all")
  310. kwargs.pop("load_input_data")
  311. visualize(**kwargs)
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。