
profile_analyze.py

#! /usr/bin/env python3
import argparse
import collections
import json
import re
import textwrap

import numpy as np
from tabulate import tabulate

from megengine.utils.profile_analyzer import (
    NonExistNum,
    ProfileAnalyzer,
    TimeFuncHelper,
)
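

# Commentary (added; not in the original file):
# - _tabulate_ml lets one logical row span several physical lines: cells
#   containing "\n" are split, padded to equal height, and re-zipped into
#   extra table rows; with the fancy_grid format used below, the horizontal
#   rule above each continuation row is then blanked so the pieces read as
#   a single row.
# - _tabulate_confluence renders via tabulate's "orgtbl" format and rewrites
#   the header separator from "|---+---|" to "|---|---|", the pipe-only form
#   that Confluence-flavored markdown tables accept.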
def _tabulate_ml(tab, **kwargs):
    r"""Tabulate profile output with multi-line support."""
    new_tab = []
    new_tab_is_row = []
    for row in tab:
        col_lines = [str(i).split("\n") for i in row]
        max_nr_line = max(map(len, col_lines))
        new_tab_is_row.append(True)
        if max_nr_line > 1:
            new_tab_is_row.extend([False] * (max_nr_line - 1))
            for i in col_lines:
                if len(i) < max_nr_line:
                    i.extend([""] * (max_nr_line - len(i)))
            new_tab.extend(zip(*col_lines))
        else:
            new_tab.append(row)
    assert len(new_tab_is_row) == len(new_tab)
    ret = [i + "\n" for i in tabulate(new_tab, **kwargs).split("\n")]
    for idx, val in enumerate(new_tab_is_row):
        if not val:
            # blank the grid rule above this continuation row
            ret[idx * 2 + 2] = ""
    return "".join(ret)[:-1]


def _tabulate_confluence(tab, **kwargs):
    r"""Tabulate profile output as a Confluence-compatible table."""
    kwargs.pop("tablefmt", None)
    s = tabulate(tab, tablefmt="orgtbl", **kwargs)
    lines = s.split("\n")
    lines[1] = lines[1].replace("+", "|")
    return "\n".join(lines)
def main(passed_args=None):  # pylint: disable=too-many-statements
    r"""Analyzes profile info from :mod:`~.utils.profile_analyzer`.

    Run this file with ``--help`` to get more usage.
    """
    parser = argparse.ArgumentParser(
        description="analyze profiler result",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("dump")
    parser.add_argument(
        "-t",
        "--top",
        type=int,
        default=3,
        help="number of most time-consuming operators to print",
    )
    parser.add_argument(
        "--type", action="append", help="filter oprs in the top list by type"
    )
    parser.add_argument(
        "--aggregate-by",
        default=None,
        choices=["type"],
        help="aggregate profiling result by",
    )
    parser.add_argument(
        "--opr-name", help="filter oprs in the top list by regex of name"
    )
    parser.add_argument(
        "--input-dtype", type=str, help="filter oprs in the top list by input dtype"
    )
    parser.add_argument(
        "--top-end-key",
        default="end",
        choices=["end", "kern"],
        help="how time in top is calculated; end corresponds "
        "to total device time, and kern corresponds to only "
        "wait time",
    )
    parser.add_argument(
        "--aggregate",
        default=None,
        help="aggregate operations",
        choices=["max", "min", "sum", "mean"],
    )
    parser.add_argument(
        "--order-by",
        default="time",
        help="sort result according to given column; the param can be "
        "<col_name> or +<col_name>, meaning sorting in descending or "
        "ascending order respectively",
    )
    parser.add_argument(
        "--copy-time", action="store_true", help="show copy time related result"
    )
    parser.add_argument(
        "--min-time",
        type=float,
        default=float("-inf"),
        help="minimal time of a result to be printed",
    )
    parser.add_argument(
        "--max-time",
        type=float,
        default=float("inf"),
        help="maximal time of a result to be printed",
    )
    parser.add_argument(
        "--show-host", action="store_true", help="show host profiling info"
    )
    parser.add_argument(
        "--dump-only-opr",
        action="store_true",
        help="only dump operator info as plaintext; useful "
        "for diff between two filtered profile results",
    )
    parser.add_argument(
        "--confluence",
        "--wiki",
        action="store_true",
        help="output confluence-markdown-compatible table",
    )
    parser.add_argument(
        "--print-only",
        choices=["summary", "device", "host"],
        help="print only chosen info",
    )
    args = parser.parse_args(passed_args)
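
    # Added note: operator filters composed from the CLI options below; each
    # filter receives the operator's info dict together with its input and
    # output records from the dump (signature: (opr, inputs, outputs)).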
    opr_filters = []
    if args.type:
        opr_filters.append(lambda o, a, b: o["type"] in args.type)
    if args.opr_name:
        opr_filters.append(
            lambda o, a, b, r=re.compile(args.opr_name): r.match(o["name"])
        )
    if args.input_dtype:
        opr_filters.append(
            lambda o, a, b: any(
                i["mem_plan"]["layout"]["dtype"] == args.input_dtype for i in a
            )
        )
    if not opr_filters:

        def opr_filter(o, a, b):  # pylint: disable=unused-argument
            return True

    else:

        def opr_filter(o, a, b):
            return all(i(o, a, b) for i in opr_filters)

    with open(args.dump) as fin:
        dump = json.load(fin)
    analyzer = ProfileAnalyzer(dump, opr_filter)
    analyzer_tot = ProfileAnalyzer(dump, lambda _, __, ___: True)
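
    # Added note: summary() reports total device/host time (plus optional
    # Copy-operator stats), while prof_details() renders the per-operator top
    # list. The top list uses the filtered analyzer, but totals come from the
    # unfiltered analyzer_tot, so ratios are relative to the whole run.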
    def summary():
        device_end_func = TimeFuncHelper.eval_time_func("device", "end", np.max)
        device_kern_func = TimeFuncHelper.eval_time_func("device", "kern", np.max)
        host_end_func = TimeFuncHelper.eval_time_func("host", "end", np.max)

        def get_tot_time(func):
            rec = analyzer_tot.select(func, aggregate=np.sum)
            if not rec:
                return "N/A"
            rec = rec[0]
            return rec.time

        tab = []
        tot_dev_time = get_tot_time(device_end_func)
        tot_host_time = get_tot_time(host_end_func)
        tab.append(("total device time", tot_dev_time))
        tab.append(("total host time", tot_host_time))
        if args.copy_time:

            def fmt(a, b):
                a = a[0]
                b = b[0]
                return "tot={:.4f} avg={:.4f}".format(a.time, b.time)

            tab.append(
                (
                    "copy time",
                    fmt(
                        analyzer.select(
                            device_end_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.sum,
                        ),
                        analyzer.select(
                            device_end_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.mean,
                        ),
                    ),
                )
            )
            tab.append(
                (
                    "copy wait time",
                    fmt(
                        analyzer.select(
                            device_kern_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.sum,
                        ),
                        analyzer.select(
                            device_kern_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.mean,
                        ),
                    ),
                )
            )
        if args.confluence:
            tab_str = _tabulate_confluence(tab, headers=["name", "value"])
        else:
            tab_str = tabulate(tab)
        return tab_str, tot_dev_time, tot_host_time

    def prof_details(prof_type, tot_time):
        tab = []

        def func(
            opr,
            *,
            f0=TimeFuncHelper.eval_time_func(prof_type, args.top_end_key, np.max)
        ):
            t = f0(opr)
            if t is not None and (t < args.min_time or t > args.max_time):
                return None
            return t

        records = analyzer.select(
            func,
            aggregate=args.aggregate,
            aggregate_by=args.aggregate_by,
            top_k=args.top,
            sort_by=args.order_by,
        )
        if args.dump_only_opr:
            ret = []
            for i in records:
                ret.append(" ".join(i.info.values()))
            return "\n".join(ret)

        def format_shapes(shapes, layouts=None, sep="\n"):
            if isinstance(shapes, NonExistNum) or shapes is None:
                return repr(shapes)
            if layouts is None:
                layouts = [None] * len(shapes)
            comp = []
            for i, j in zip(shapes, layouts):
                i = "{" + ",".join(map(str, i)) + "}"
                if j:
                    i += "\n -[" + ",".join(map(str, j)) + "]"
                comp.append(i)
            return sep.join(comp)

        def fix_num_and_find_unit(x, base):
            if isinstance(x, NonExistNum) or (
                isinstance(x, float) and not np.isfinite(x)
            ):
                return x, ""
            unit = iter(["", "K", "M", "G", "T", "P"])
            while x >= base:
                x /= base
                next(unit)
            return x, next(unit)
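
        # Worked example (added illustration): fix_num_and_find_unit(1.5e6, 1000)
        # divides by 1000 twice (1.5e6 -> 1500.0 -> 1.5), consuming "" and "K"
        # from the iterator, and returns (1.5, "M"); the helper below then
        # renders it as "1.50\nMFLO" when called with unit="FLO".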
        def get_number_with_unit(num, unit, base, sep="\n"):
            num, unit_prefix = fix_num_and_find_unit(num, base)
            if isinstance(unit, list):
                unit = unit[int(unit_prefix != "")]
            return ("{:.2f}" + sep + "{}{}").format(num, unit_prefix, unit)

        if args.confluence:
            rows = []
            cum_time = 0
            max_time = max(r.time for r in records)
            max_bandwidth = max(r.bandwidth for r in records)
            max_flops = max(
                r.flops for r in records if not isinstance(r.flops, NonExistNum)
            )
            bar_length = 15
            for idx, record in enumerate(records):
                cum_time += record.time
                opr_info = [("opr " + k, v) for k, v in record.info.items()]
                row = collections.OrderedDict(
                    [
                        ("#", idx),
                        ("time", "{:.3}".format(record.time)),
                        ("ratio", "{:.1f}%".format(record.time / tot_time * 100)),
                        ("time bar", "#" * int(record.time / max_time * bar_length)),
                        ("cum-time", cum_time),
                        ("cum-time ratio", cum_time / tot_time),
                    ]
                    + opr_info
                    + [
                        (
                            "computation (MFLO)",
                            "{:.1f}".format(record.computation / 1000 ** 2),
                        ),
                        ("MFLOPS", "{:.1f}".format(record.flops / 1000 ** 2)),
                        (
                            "MFLOPS-bar",
                            ""
                            if isinstance(record.flops, NonExistNum)
                            else ("#" * int(record.flops / max_flops * bar_length)),
                        ),
                        ("memory (MB)", "{:.1f}".format(record.memory / 1024 ** 2)),
                        (
                            "bandwidth (MiB/s)",
                            "{:.1f}".format(record.bandwidth / 1024 ** 2),
                        ),
                        (
                            "bandwidth bar",
                            "#" * int(record.bandwidth / max_bandwidth * bar_length),
                        ),
                        (
                            "in_shapes",
                            format_shapes(
                                record.in_shapes, record.in_layouts, sep=", "
                            ),
                        ),
                        ("out_shapes", format_shapes(record.out_shapes, sep=", ")),
                    ]
                )
                rows.append(row)
            headers = list(rows[0].keys())
            tab = [[row[i] for i in headers] for row in rows]
            return _tabulate_confluence(tab, headers=headers)
        else:
            cum_time = 0
            for idx, record in enumerate(records):
                cum_time += record.time
                tab.append(
                    (
                        "#{}\n{:.3}\n{:.1f}%".format(
                            idx, record.time, record.time / tot_time * 100
                        ),
                        "{:.3}\n{:.1f}%".format(cum_time, cum_time / tot_time * 100),
                        "\n".join(
                            "\n- ".join(textwrap.wrap(str(i), width=30))
                            for i in record.info.values()
                        ),
                        get_number_with_unit(record.computation, "FLO", 1000),
                        get_number_with_unit(record.flops, "FLOPS", 1000),
                        get_number_with_unit(record.memory, ["byte", "iB"], 1024),
                        get_number_with_unit(
                            record.bandwidth, ["byte/s", "iB/s"], 1024
                        ),
                        format_shapes(record.in_shapes, record.in_layouts),
                        format_shapes(record.out_shapes),
                    )
                )
            return _tabulate_ml(
                tab,
                headers=[
                    "{} self time".format(prof_type),
                    "cumulative",
                    "operator info",
                    "computation",
                    "FLOPS",
                    "memory",
                    "bandwidth",
                    "in_shapes",
                    "out_shapes",
                ],
                tablefmt="fancy_grid",
            )

    summary_tab, tot_dev_time, tot_host_time = summary()
    if args.print_only:
        print(
            {
                "summary": lambda: summary_tab,
                "device": lambda: prof_details("device", tot_dev_time),
                "host": lambda: prof_details("host", tot_host_time),
            }[args.print_only]()
        )
    else:
        print(summary_tab)
        print()
        print(prof_details("device", tot_dev_time))
        if args.show_host:
            print()
            print(prof_details("host", tot_host_time))


if __name__ == "__main__":
    main()
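
# Usage sketch (added; not part of the original file — "prof.json" is a
# hypothetical profiler dump path). The analysis can be driven from the
# shell:
#
#   python3 profile_analyze.py prof.json --top 10 --order-by +time
#   python3 profile_analyze.py prof.json --aggregate-by type --aggregate sum
#
# or programmatically, since main() accepts an argv-style list:
#
#   main(["prof.json", "--print-only", "device"])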