You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

profile_analyze.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import collections
  11. import json
  12. import re
  13. import textwrap
  14. import numpy as np
  15. from tabulate import tabulate
  16. from megengine.utils.profile_analyzer import (
  17. NonExistNum,
  18. ProfileAnalyzer,
  19. TimeFuncHelper,
  20. )
def _tabulate_ml(tab, **kwargs):
    """Tabulate profile output with multi-line support.

    Cells in ``tab`` may contain embedded newlines; each logical row is
    expanded into one physical row per line before being handed to
    :func:`tabulate`, and the extra grid separators that tabulate draws
    between those continuation rows are blanked out afterwards.

    :param tab: sequence of rows, each a sequence of cell values.
    :param kwargs: forwarded verbatim to :func:`tabulate` (e.g. ``headers``,
        ``tablefmt``).
    :return: the rendered table as a single string.
    """
    new_tab = []
    # new_tab_is_row[i] is True for the first physical row of each logical
    # row, False for its continuation rows.
    new_tab_is_row = []
    for row in tab:
        # Split every cell on newlines; each cell becomes a list of lines.
        col_lines = [str(i).split("\n") for i in row]
        max_nr_line = max(map(len, col_lines))
        new_tab_is_row.append(True)
        if max_nr_line > 1:
            new_tab_is_row.extend([False] * (max_nr_line - 1))
            # Pad shorter cells with empty strings so all columns have the
            # same number of lines, then transpose into physical rows.
            for i in col_lines:
                if len(i) < max_nr_line:
                    i.extend([""] * (max_nr_line - len(i)))
            new_tab.extend(zip(*col_lines))
        else:
            new_tab.append(row)
    assert len(new_tab_is_row) == len(new_tab)
    ret = [i + "\n" for i in tabulate(new_tab, **kwargs).split("\n")]
    for idx, val in enumerate(new_tab_is_row):
        if not val:
            # Blank the horizontal separator line preceding a continuation
            # row. The index math assumes one separator line per data row
            # plus a two-line header prefix, as produced by grid-style
            # table formats (the caller passes tablefmt="fancy_grid").
            ret[idx * 2 + 2] = ""
    # Joined lines already end with "\n"; drop the final trailing newline.
    return "".join(ret)[:-1]
  43. def _tabulate_confluence(tab, **kwargs):
  44. """Tabulate profile output."""
  45. kwargs.pop("tablefmt", None)
  46. s = tabulate(tab, tablefmt="orgtbl", **kwargs)
  47. lines = s.split("\n")
  48. lines[1] = lines[1].replace("+", "|")
  49. return "\n".join(lines)
def main(passed_args=None):  # pylint: disable=too-many-statements
    """Analyses profile info from :mod:`~.utils.profile_analyzer` .

    Run this file with ``--help`` to get more usage.

    :param passed_args: optional argument list forwarded to
        ``argparse.ArgumentParser.parse_args``; ``None`` means use
        ``sys.argv``.
    """
    # ---- command-line interface ------------------------------------------
    parser = argparse.ArgumentParser(
        description="analyze analyzer result",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("dump")
    parser.add_argument(
        "-t",
        "--top",
        type=int,
        default=3,
        help="number of most time-consuming operators to print",
    )
    parser.add_argument(
        "--type", action="append", help="filter oprs in the top list by type"
    )
    parser.add_argument(
        "--aggregate-by",
        default=None,
        choices=["type"],
        help="aggragate profiling result by",
    )
    parser.add_argument(
        "--opr-name", help="filter oprs in the top list by regex of name"
    )
    parser.add_argument(
        "--input-dtype", type=str, help="filter oprs in the top list by input dtype"
    )
    parser.add_argument(
        "--top-end-key",
        default="end",
        choices=["end", "kern"],
        help="how time in top is calculated; end corresponds "
        "to total device time, and kern corresponds to only "
        "wait time",
    )
    parser.add_argument(
        "--aggregate",
        default=None,
        help="aggregate operations",
        choices=["max", "min", "sum", "mean"],
    )
    parser.add_argument(
        "--order-by",
        default="time",
        help="sort result according to given column; the param can be "
        "<col_name> or +<col_name>, meaning sorting in descending or "
        "ascending order respectively",
    )
    parser.add_argument(
        "--copy-time", action="store_true", help="show copy time related result"
    )
    parser.add_argument(
        "--min-time",
        type=float,
        default=float("-inf"),
        help="minimal time of a result to be printed",
    )
    parser.add_argument(
        "--max-time",
        type=float,
        default=float("inf"),
        help="maximal time of a result to be printed",
    )
    parser.add_argument(
        "--show-host", action="store_true", help="show host profiling info"
    )
    parser.add_argument(
        "--dump-only-opr",
        action="store_true",
        help="only dump operator info as plaintext; useful "
        "for diff between two filtered profile results",
    )
    parser.add_argument(
        "--confluence",
        "--wiki",
        action="store_true",
        help="output confluence-markdown-compatible table",
    )
    parser.add_argument(
        "--print-only",
        choices={"summary", "device", "host"},
        help="print only chosen info",
    )
    args = parser.parse_args(passed_args)

    # ---- assemble the operator filter from CLI options -------------------
    # Each entry is a predicate taking (opr_info_dict, inputs, outputs);
    # an operator is kept only if every predicate accepts it.
    opr_filters = []
    if args.type:
        opr_filters.append(lambda o, a, b: o["type"] in args.type)
    if args.opr_name:
        # The compiled regex is bound as a default argument so it is
        # compiled once, at lambda definition time.
        opr_filters.append(
            lambda o, a, b, r=re.compile(args.opr_name): r.match(o["name"])
        )
    if args.input_dtype:
        opr_filters.append(
            lambda o, a, b: any(
                [i["mem_plan"]["layout"]["dtype"] == args.input_dtype for i in a]
            )
        )
    if not opr_filters:

        def opr_filter(o, a, b):  # pylint: disable=unused-argument
            # No filters requested: accept every operator.
            return True

    else:

        def opr_filter(o, a, b):
            return all(i(o, a, b) for i in opr_filters)

    with open(args.dump) as fin:
        dump = json.load(fin)

    # analyzer sees only filtered operators; analyzer_tot sees everything
    # and is used to compute total device/host times for ratio columns.
    analyzer = ProfileAnalyzer(dump, opr_filter)
    analyzer_tot = ProfileAnalyzer(dump, lambda _, __, ___: True)

    def summary():
        """Build the summary table.

        :return: ``(table_str, tot_dev_time, tot_host_time)``; the totals
            are floats, or the string ``"N/A"`` when no record exists.
        """
        device_end_func = TimeFuncHelper.eval_time_func("device", "end", np.max)
        device_kern_func = TimeFuncHelper.eval_time_func("device", "kern", np.max)
        host_end_func = TimeFuncHelper.eval_time_func("host", "end", np.max)

        def get_tot_time(func):
            # Sum the given time metric over all (unfiltered) operators.
            rec = analyzer_tot.select(func, aggregate=np.sum)
            if not rec:
                return "N/A"
            rec = rec[0]
            return rec.time

        tab = []
        tot_dev_time = get_tot_time(device_end_func)
        tot_host_time = get_tot_time(host_end_func)
        tab.append(("total device time", tot_dev_time))
        tab.append(("total host time", tot_host_time))
        if args.copy_time:

            def fmt(a, b):
                # a: sum-aggregated records, b: mean-aggregated records;
                # each select() returns a list with one aggregate record.
                a = a[0]
                b = b[0]
                return "tot={:.4f} avg={:.4f}".format(a.time, b.time)

            tab.append(
                (
                    "copy time",
                    fmt(
                        analyzer.select(
                            device_end_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.sum,
                        ),
                        analyzer.select(
                            device_end_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.mean,
                        ),
                    ),
                )
            )
            tab.append(
                (
                    "copy wait time",
                    fmt(
                        analyzer.select(
                            device_kern_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.sum,
                        ),
                        analyzer.select(
                            device_kern_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.mean,
                        ),
                    ),
                )
            )

        if args.confluence:
            tab_str = _tabulate_confluence(tab, headers=["name", "value"])
        else:
            tab_str = tabulate(tab)
        return tab_str, tot_dev_time, tot_host_time

    def prof_details(prof_type, tot_time):
        """Render the top-k operator table for ``prof_type``.

        :param prof_type: ``"device"`` or ``"host"``.
        :param tot_time: total time of that kind, used for ratio columns.
        :return: rendered table string (or plain operator-info lines when
            ``--dump-only-opr`` is given).
        """
        tab = []

        # f0 is bound as a default argument so the time function is built
        # once per prof_details call, not once per operator.
        def func(
            opr,
            *,
            f0=TimeFuncHelper.eval_time_func(prof_type, args.top_end_key, np.max)
        ):
            t = f0(opr)
            # Returning None drops the operator from the selection;
            # applies the --min-time / --max-time window.
            if t is not None and (t < args.min_time or t > args.max_time):
                return None
            return t

        records = analyzer.select(
            func,
            aggregate=args.aggregate,
            aggregate_by=args.aggregate_by,
            top_k=args.top,
            sort_by=args.order_by,
        )

        if args.dump_only_opr:
            # Plaintext dump, one operator per line; meant for diffing two
            # filtered profiles.
            ret = []
            for i in records:
                ret.append(" ".join(i.info.values()))
            return "\n".join(ret)

        def format_shapes(shapes, layouts=None, sep="\n"):
            # Render shapes as "{d0,d1,...}"; when a layout is present it
            # is appended as " -[s0,s1,...]" on the next line.
            if isinstance(shapes, NonExistNum) or shapes is None:
                return repr(shapes)
            if layouts is None:
                layouts = [None] * len(shapes)
            comp = []
            for i, j in zip(shapes, layouts):
                i = "{" + ",".join(map(str, i)) + "}"
                if j:
                    i += "\n -[" + ",".join(map(str, j)) + "]"
                comp.append(i)
            return sep.join(comp)

        def fix_num_and_find_unit(x, base):
            # Scale x down by `base` until it fits, returning the scaled
            # value and the matching SI prefix. Non-numbers and non-finite
            # floats pass through unscaled with an empty prefix.
            if isinstance(x, NonExistNum) or (
                isinstance(x, float) and not np.isfinite(x)
            ):
                return x, ""
            unit = iter(["", "K", "M", "G", "T", "P"])
            while x >= base:
                x /= base
                next(unit)
            return x, next(unit)

        def get_number_with_unit(num, unit, base, sep="\n"):
            num, unit_prefix = fix_num_and_find_unit(num, base)
            if isinstance(unit, list):
                # Two-element unit: [unprefixed form, prefixed form],
                # e.g. ["byte", "iB"] -> "byte" or "KiB"/"MiB"/...
                unit = unit[int(unit_prefix != "")]
            return ("{:.2f}" + sep + "{}{}").format(num, unit_prefix, unit)

        if args.confluence:
            # Confluence output: flat single-line cells, plus '#' bar
            # columns scaled against the per-column maxima.
            rows = []
            cum_time = 0
            max_time = max([r.time for r in records])
            max_bandwidth = max([r.bandwidth for r in records])
            max_flops = max(
                [r.flops for r in records if not isinstance(r.flops, NonExistNum)]
            )
            bar_length = 15
            for idx, record in enumerate(records):
                cum_time += record.time
                opr_info = [("opr " + k, v) for k, v in record.info.items()]
                row = collections.OrderedDict(
                    [
                        ("#", idx),
                        ("time", "{:.3}".format(record.time)),
                        ("ratio", "{:.1f}%".format(record.time / tot_time * 100)),
                        ("time bar", "#" * int(record.time / max_time * bar_length)),
                        ("cum-time", cum_time),
                        ("cum-time ratio", cum_time / tot_time),
                    ]
                    + opr_info
                    + [
                        (
                            "computation (MFLO)",
                            "{:.1f}".format(record.computation / 1000 ** 2),
                        ),
                        ("MFLOPS", "{:.1f}".format(record.flops / 1000 ** 2)),
                        (
                            "MFLOPS-bar",
                            ""
                            if isinstance(record.flops, NonExistNum)
                            else ("#" * int(record.flops / max_flops * bar_length)),
                        ),
                        ("memory (MB)", "{:.1f}".format(record.memory / 1024 ** 2)),
                        (
                            "bandwidth (MiB/s)",
                            "{:.1f}".format(record.bandwidth / 1024 ** 2),
                        ),
                        (
                            "bandwidth bar",
                            "#" * int(record.bandwidth / max_bandwidth * bar_length),
                        ),
                        (
                            "in_shapes",
                            format_shapes(
                                record.in_shapes, record.in_layouts, sep=", "
                            ),
                        ),
                        ("out_shapes", format_shapes(record.out_shapes, sep=", ")),
                    ]
                )
                rows.append(row)
            # NOTE(review): rows[0] (and the max() calls above) assume at
            # least one record survived the filters — verify for empty input.
            headers = list(rows[0].keys())
            tab = [[row[i] for i in headers] for row in rows]
            return _tabulate_confluence(tab, headers=headers)
        else:
            # Terminal output: multi-line cells rendered via _tabulate_ml.
            cum_time = 0
            for idx, record in enumerate(records):
                cum_time += record.time
                tab.append(
                    (
                        "#{}\n{:.3}\n{:.1f}%".format(
                            idx, record.time, record.time / tot_time * 100
                        ),
                        "{:.3}\n{:.1f}%".format(cum_time, cum_time / tot_time * 100),
                        "\n".join(
                            "\n- ".join(textwrap.wrap(str(i), width=30))
                            for i in record.info.values()
                        ),
                        get_number_with_unit(record.computation, "FLO", 1000),
                        get_number_with_unit(record.flops, "FLOPS", 1000),
                        get_number_with_unit(record.memory, ["byte", "iB"], 1024),
                        get_number_with_unit(
                            record.bandwidth, ["byte/s", "iB/s"], 1024
                        ),
                        format_shapes(record.in_shapes, record.in_layouts),
                        format_shapes(record.out_shapes),
                    )
                )
            return _tabulate_ml(
                tab,
                headers=[
                    "{} self time".format(prof_type),
                    "cumulative",
                    "operator info",
                    "computation",
                    "FLOPS",
                    "memory",
                    "bandwidth",
                    "in_shapes",
                    "out_shapes",
                ],
                tablefmt="fancy_grid",
            )

    summary_tab, tot_dev_time, tot_host_time = summary()
    if args.print_only:
        # Lambdas defer the (potentially expensive) detail rendering until
        # the chosen section is actually selected.
        print(
            {
                "summary": lambda: summary_tab,
                "device": lambda: prof_details("device", tot_dev_time),
                "host": lambda: prof_details("host", tot_host_time),
            }[args.print_only]()
        )
    else:
        print(summary_tab)
        print()
        print(prof_details("device", tot_dev_time))
        if args.show_host:
            print()
            print(prof_details("host", tot_host_time))
# Script entry point: parse CLI args and print the profile analysis.
if __name__ == "__main__":
    main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台