You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

network_visualize.py 12 kB

(line numbers 1–349)
  1. #! /usr/bin/env python3
  2. import argparse
  3. import logging
  4. import re
  5. from collections import namedtuple
  6. import numpy as np
  7. from tqdm import tqdm
  8. import megengine as mge
  9. from megengine.core.tensor.dtype import is_quantize
  10. from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
  11. from megengine.utils.module_stats import (
  12. enable_receptive_field,
  13. get_activation_stats,
  14. get_op_stats,
  15. get_param_stats,
  16. print_activations_stats,
  17. print_op_stats,
  18. print_param_stats,
  19. print_summary,
  20. sizeof_fmt,
  21. sum_activations_stats,
  22. sum_op_stats,
  23. sum_param_stats,
  24. )
  25. from megengine.utils.network import Network
  26. logger = get_logger(__name__)
  27. def visualize(
  28. model_path: str,
  29. log_path: str,
  30. input: np.ndarray = None,
  31. inp_dict: dict = None,
  32. cal_params: bool = True,
  33. cal_flops: bool = True,
  34. cal_activations: bool = True,
  35. logging_to_stdout: bool = True,
  36. bar_length_max: int = 20,
  37. ):
  38. r"""Load megengine dumped model and visualize graph structure with tensorboard log files.
  39. Can also record and print model's statistics like :func:`~.module_stats`
  40. Args:
  41. model_path: dir path for megengine dumped model.
  42. log_path: dir path for tensorboard graph log.
  43. input: user defined input data for running model and calculating stats,
  44. alternative with inp_dict, used when the model has only one input.
  45. inp_dict: input dict for running model and calculating stats, alternative with
  46. input, used when the model has more than one input.
  47. When both input and inp_dict are None, a random input will be used.
  48. cal_params: whether calculate and record params size.
  49. cal_flops: whether calculate and record op flops.
  50. cal_activations: whether calculate and record op activations.
  51. logging_to_stdout: whether print all calculated statistic details.
  52. bar_length_max: size of bar indicating max flops or parameter size in net stats.
  53. model_path: str:
  54. log_path: str:
  55. input: np.ndarray:
  56. inp_dict: dict:
  57. cal_params: bool:
  58. cal_flops: bool:
  59. cal_activations: bool:
  60. logging_to_stdout: bool:
  61. bar_length_max: int:
  62. """
  63. if log_path:
  64. try:
  65. from tensorboard.compat.proto.attr_value_pb2 import AttrValue
  66. from tensorboard.compat.proto.config_pb2 import RunMetadata
  67. from tensorboard.compat.proto.graph_pb2 import GraphDef
  68. from tensorboard.compat.proto.node_def_pb2 import NodeDef
  69. from tensorboard.compat.proto.step_stats_pb2 import (
  70. AllocatorMemoryUsed,
  71. DeviceStepStats,
  72. NodeExecStats,
  73. StepStats,
  74. )
  75. from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
  76. from tensorboard.compat.proto.versions_pb2 import VersionDef
  77. from tensorboardX import SummaryWriter
  78. except ImportError:
  79. logger.error(
  80. "TensorBoard and TensorboardX are required for visualize.",
  81. exc_info=True,
  82. )
  83. return
  84. enable_receptive_field()
  85. graph = Network.load(model_path)
  86. graph.reset_batch_size(1)
  87. has_input = False
  88. if input is not None or inp_dict is not None:
  89. has_input = True
  90. repl_dict = {}
  91. inp_vars = graph.input_vars
  92. if inp_dict is not None:
  93. assert len(inp_dict) == len(
  94. inp_vars
  95. ), "Inputs are not sufficient for calculation."
  96. for v in inp_vars:
  97. new_input = graph.make_const(inp_dict[v.name], name=v.name)
  98. repl_dict[v] = new_input
  99. else:
  100. assert len(inp_vars) == 1, "The graph needs more than one input."
  101. inp_var = inp_vars[0]
  102. repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
  103. graph.replace_vars(repl_dict=repl_dict)
  104. graph._compile()
  105. def process_name(name):
  106. # nodes that start with point or contain float const will lead to display bug
  107. if not re.match(r"^[+-]?\d*\.\d*", name):
  108. name = name.replace(".", "/")
  109. return name.encode(encoding="utf-8")
  110. summary = [["item", "value"]]
  111. node_list = []
  112. flops_list = []
  113. params_list = []
  114. activations_list = []
  115. total_stats = namedtuple(
  116. "total_stats", ["param_size", "param_dims", "flops", "act_size", "act_dims"]
  117. )
  118. stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
  119. disable_stats = False
  120. for node in tqdm(graph.all_oprs):
  121. if hasattr(node, "output_idx"):
  122. node_oup = node.outputs[node.output_idx]
  123. else:
  124. if len(node.outputs) != 1:
  125. logger.warning(
  126. "OpNode {} has more than one output and not has 'output_idx' attr.".format(
  127. node
  128. )
  129. )
  130. node_oup = node.outputs[0]
  131. inp_list = [process_name(var.owner.name) for var in node.inputs]
  132. if log_path:
  133. # detail format see tensorboard/compat/proto/attr_value.proto
  134. attr = {
  135. "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
  136. "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
  137. }
  138. if node_oup.shape:
  139. attr["_output_shapes"] = AttrValue(
  140. list=AttrValue.ListValue(
  141. shape=[
  142. TensorShapeProto(
  143. dim=[
  144. TensorShapeProto.Dim(size=d) for d in node_oup.shape
  145. ]
  146. )
  147. ]
  148. )
  149. )
  150. else:
  151. disable_stats = True
  152. logger.warning(
  153. f"OpNode {node.name} do not has shape attr, would not calculate flops/params/activations for this net."
  154. )
  155. if not disable_stats:
  156. if cal_flops:
  157. flops_stats = get_op_stats(node, node.inputs, node.outputs)
  158. if flops_stats is not None:
  159. # add op flops attr
  160. if log_path and hasattr(flops_stats, "flops_num"):
  161. attr["flops"] = AttrValue(
  162. s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
  163. )
  164. flops_stats["name"] = node.name
  165. flops_stats["class_name"] = node.type
  166. flops_list.append(flops_stats)
  167. if cal_activations:
  168. acts = get_activation_stats(node_oup, has_input=has_input)
  169. acts["name"] = node.name
  170. acts["class_name"] = node.type
  171. activations_list.append(acts)
  172. if cal_params:
  173. if node.type == "ImmutableTensor":
  174. param_stats = get_param_stats(node_oup)
  175. # add tensor size attr
  176. if log_path:
  177. attr["size"] = AttrValue(
  178. s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
  179. )
  180. param_stats["name"] = node.name
  181. params_list.append(param_stats)
  182. if log_path:
  183. node_list.append(
  184. NodeDef(
  185. name=process_name(node.name),
  186. op=node.type,
  187. input=inp_list,
  188. attr=attr,
  189. )
  190. )
  191. # summary
  192. extra_info = {
  193. "#ops": len(graph.all_oprs),
  194. "#params": len(params_list),
  195. }
  196. (
  197. total_flops,
  198. total_param_dims,
  199. total_param_size,
  200. total_act_dims,
  201. total_act_size,
  202. ) = (0, 0, 0, 0, 0)
  203. if not disable_stats:
  204. if cal_params:
  205. total_param_dims, total_param_size, params_list = sum_param_stats(
  206. params_list, bar_length_max
  207. )
  208. extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
  209. extra_info["total_param_size"] = sizeof_fmt(total_param_size)
  210. if logging_to_stdout:
  211. print_param_stats(params_list)
  212. if cal_flops:
  213. total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
  214. extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
  215. if logging_to_stdout:
  216. print_op_stats(flops_list)
  217. if cal_activations:
  218. total_act_dims, total_act_size, activations_list = sum_activations_stats(
  219. activations_list, bar_length_max
  220. )
  221. extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
  222. extra_info["total_act_size"] = sizeof_fmt(total_act_size)
  223. if logging_to_stdout:
  224. print_activations_stats(activations_list, has_input=has_input)
  225. if cal_flops and cal_params:
  226. extra_info["flops/param_size"] = "{:3.3f}".format(
  227. total_flops / total_param_size
  228. )
  229. print_summary(**extra_info)
  230. if log_path:
  231. graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
  232. device = "/device:CPU:0"
  233. stepstats = RunMetadata(
  234. step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
  235. )
  236. writer = SummaryWriter(log_path)
  237. writer._get_file_writer().add_graph((graph_def, stepstats))
  238. return (
  239. total_stats(
  240. param_size=total_param_size,
  241. param_dims=total_param_dims,
  242. flops=total_flops,
  243. act_size=total_act_size,
  244. act_dims=total_act_dims,
  245. ),
  246. stats_details(
  247. params=params_list, flops=flops_list, activations=activations_list
  248. ),
  249. )
  250. def main():
  251. parser = argparse.ArgumentParser(
  252. description="load a megengine dumped model and export log file for tensorboard visualization.",
  253. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  254. )
  255. parser.add_argument("model_path", help="dumped model path.")
  256. parser.add_argument("--log_path", help="tensorboard log path.")
  257. parser.add_argument(
  258. "--load_input_data",
  259. help="load input data from pickle file; it should be a numpy array or a dict of numpy array",
  260. )
  261. parser.add_argument(
  262. "--bar_length_max",
  263. type=int,
  264. default=20,
  265. help="size of bar indicating max flops or parameter size in net stats.",
  266. )
  267. parser.add_argument(
  268. "--cal_params",
  269. action="store_true",
  270. help="whether calculate and record params size.",
  271. )
  272. parser.add_argument(
  273. "--cal_flops",
  274. action="store_true",
  275. help="whether calculate and record op flops.",
  276. )
  277. parser.add_argument(
  278. "--cal_activations",
  279. action="store_true",
  280. help="whether calculate and record op activations.",
  281. )
  282. parser.add_argument(
  283. "--logging_to_stdout",
  284. action="store_true",
  285. help="whether print all calculated statistic details.",
  286. )
  287. parser.add_argument(
  288. "--all",
  289. action="store_true",
  290. help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
  291. )
  292. args = parser.parse_args()
  293. if args.load_input_data:
  294. logger.info("load data from {}".format(args.load_input_data))
  295. data = mge.load(args.load_input_data)
  296. if isinstance(data, dict):
  297. for v in data.values():
  298. assert isinstance(
  299. v, np.ndarray
  300. ), "data should provide ndarray; got {} instead".format(v)
  301. args.inp_dict = data
  302. elif isinstance(data, np.ndarray):
  303. args.input = data
  304. else:
  305. logger.error("input data should be a numpy array or a dict of numpy array")
  306. if args.all:
  307. args.cal_params = True
  308. args.cal_flops = True
  309. args.cal_activations = True
  310. args.logging_to_stdout = True
  311. if not args.log_path:
  312. args.log_path = "./log"
  313. kwargs = vars(args)
  314. kwargs.pop("all")
  315. kwargs.pop("load_input_data")
  316. visualize(**kwargs)
  317. if __name__ == "__main__":
  318. main()