
profile_analyzer.py 13 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
import functools
from typing import Callable, List, Optional, Union

import numpy as np

class NonExistNum:
    r"""An object that behaves like a number but means a field does not exist; it is
    always greater than any real number.
    """

    def __truediv__(self, _):
        return self

    def __add__(self, rhs):
        return rhs

    def __radd__(self, lhs):
        return lhs

    def __neg__(self):
        return self

    def __gt__(self, rhs):
        if isinstance(rhs, NonExistNum):
            return id(self) > id(rhs)
        return True

    def __ge__(self, rhs):
        return self > rhs or self == rhs

    def __lt__(self, rhs):
        if isinstance(rhs, NonExistNum):
            return id(self) < id(rhs)
        return False

    def __le__(self, rhs):
        return self < rhs or self == rhs

    def __eq__(self, rhs):
        return self is rhs

    def __format__(self, spec):
        return "N/A"

    def __repr__(self):
        return "N/A"
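
# A small illustrative sketch (not part of the original file): NonExistNum
# absorbs division, vanishes in addition, and compares greater than any real
# number, so records whose footprint is missing sort after all real values.
#
#     n = NonExistNum()
#     n / 3.0         # -> n (still N/A)
#     1.5 + n         # -> 1.5 (__radd__ returns the other operand)
#     n > 1e100       # -> True
#     "{}".format(n)  # -> "N/A"
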
class OprProfRst:
    r"""Opr profiling result dumped from the megengine profiler.

    Args:
        entry: profiling json exec_graph item. Initialization sets up the
            name, type and id fields of ``opr_info``.
    """

    opr_info = None
    r"""A dict containing operator info: name, id and type."""

    time_dict = None
    r"""A mapping from ``"host"`` or ``"device"`` to a list of profiling results."""

    footprint = None
    r"""A mapping from ``"memory"`` or ``"computation"`` to the actual number
    of corresponding operations."""

    def __init__(self, entry: dict):
        assert isinstance(entry, dict)
        self.opr_info = collections.OrderedDict()
        for key in ["name", "type", "id"]:
            self.opr_info[key] = entry[key]
        self.time_dict = collections.defaultdict(list)
        self.footprint = collections.defaultdict(NonExistNum)

    def update_device_prof_info(self, dev_time: dict):
        r"""Updates device profiling info.

        Args:
            dev_time: device time for a single opr; an attribute of the
                profiling result.
        """
        assert isinstance(dev_time, dict)
        self.time_dict["device"].append(copy.deepcopy(dev_time))

    def update_host_prof_info(self, host_time: dict):
        r"""Updates host profiling info.

        Args:
            host_time: host time for a single opr; an attribute of the
                profiling result.
        """
        assert isinstance(host_time, dict)
        self.time_dict["host"].append(copy.deepcopy(host_time))

    def update_footprint(self, footprint: dict):
        r"""Updates opr footprint.

        Args:
            footprint: footprint for a single opr; an attribute of the
                profiling result.
        """
        assert isinstance(footprint, dict)
        self.footprint.update(footprint)
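
# Illustrative sketch with hypothetical data (not from the original file):
# an OprProfRst is built from one "operator" entry of the dumped json and is
# then fed per-comp-node timing dicts and a footprint:
#
#     rst = OprProfRst({"name": "conv0", "type": "ConvolutionForward", "id": "3"})
#     rst.update_device_prof_info({"start": 0.0, "kern": 0.1, "end": 0.4})
#     rst.update_footprint({"computation": 1000, "memory": 4096})
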
class Record:
    r"""A record of the analyzing result.

    Args:
        time: opr running time, evaluated by applying the user-provided
            function to an OprProfRst.
        info: opr information; either the original opr information or the
            aggregated information if aggregation is enabled.
        footprint: footprint information; currently ``"computation"``,
            ``"memory"``, ``"in_shapes"`` and ``"out_shapes"``.
    """

    __slots__ = [
        "time",
        "info",
        "computation",
        "memory",
        "in_shapes",
        "in_layouts",
        "out_shapes",
        "flops",
        "bandwidth",
        "opr_id",
    ]

    def __init__(self, time: float, info: dict, footprint: dict):
        assert isinstance(footprint, dict)
        self.time = time
        self.info = collections.OrderedDict(copy.deepcopy(info))
        self.computation = footprint["computation"] or NonExistNum()
        self.memory = footprint["memory"]
        self.in_shapes = footprint["in_shapes"]
        self.in_layouts = footprint.get("in_layouts")
        self.out_shapes = footprint["out_shapes"]
        self.flops = self.computation / self.time
        self.bandwidth = self.memory / self.time
        self.opr_id = info.get("id")
        if isinstance(self.opr_id, str) and self.opr_id != "N/A":
            self.opr_id = int(self.opr_id)

    def get_column_by_name(self, name: str = None):
        r"""Extracts a column value by its column name.

        Args:
            name: column name; ``None`` means time.
        """
        if name is None:
            name = "time"
        return getattr(self, name)
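
# Illustrative sketch with hypothetical numbers: flops and bandwidth are
# derived in __init__, so they can be fetched by name like stored columns.
#
#     rec = Record(
#         0.5,
#         {"type": "Elemwise", "id": "7"},
#         {"computation": 100, "memory": 400, "in_shapes": None, "out_shapes": None},
#     )
#     rec.get_column_by_name("flops")  # -> 200.0 (computation / time)
#     rec.get_column_by_name()         # -> 0.5 (defaults to time)
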
class ProfileAnalyzer:
    r"""Analyzer of operator profiling results.

    Args:
        obj: dict dumped from the profiling json str.
        opr_filter: function that filters oprs.
    """

    def __init__(self, obj: dict, opr_filter: Callable = lambda opr, inp, out: True):
        self._opr_set = dict()  # type: dict
        assert isinstance(obj, dict), type(obj)
        varz = obj["graph_exec"]["var"]
        for opr_id, entry in obj["graph_exec"]["operator"].items():
            inp = [varz[i] for i in entry["input"]]
            out = [varz[i] for i in entry["output"]]
            if opr_filter(entry, inp, out):
                self._opr_set[opr_id] = OprProfRst(entry)

        for opr_id, entry in obj["profiler"]["device"].items():
            if opr_id not in self._opr_set:
                continue
            opr = self._opr_set[opr_id]
            for _, time in entry.items():
                opr.update_device_prof_info(time)

        for opr_id, entry in obj["profiler"]["host"].items():
            if opr_id not in self._opr_set:
                continue
            opr = self._opr_set[opr_id]
            for _, time in entry.items():
                opr.update_host_prof_info(time)

        for opr_id, entry in obj["profiler"].get("opr_footprint", {}).items():
            if opr_id not in self._opr_set:
                continue
            opr = self._opr_set[opr_id]
            opr.update_footprint(entry)

    def _aggregate(
        self, records: List[Record], aop: Union[str, Callable], atype: Optional[str]
    ) -> List[Record]:
        r"""Aggregate operation.

        Args:
            records: selected records.
            aop: aggregate operation; if aop is a str, it is replaced by the
                numpy function with that name.
            atype: the type to aggregate by; ``None`` aggregates all records
                into a single one.
        """
        if aop is None:
            assert atype is None, "must specify aggregate op"
            return records
        if isinstance(aop, str):
            aop = getattr(np, aop)
        type2stat = collections.defaultdict(lambda: [[], [], []])  # type: dict
        for item in records:
            if atype == "type":
                d = type2stat[item.info["type"]]
            else:
                d = type2stat["all"]
            d[0].append(item.time)
            d[1].append(item.computation)
            d[2].append(item.memory)

        rst = []
        for opr_type in type2stat.keys():
            time, computation, memory = type2stat[opr_type]
            nr_oprs = len(time)
            time_rst = aop(time)
            comp_rst = aop(computation)
            mem_rst = aop(memory)
            item = Record(
                time_rst,
                {"type": opr_type, "count": nr_oprs, "id": "N/A"},
                {
                    "computation": comp_rst,
                    "memory": mem_rst,
                    "in_shapes": None,
                    "out_shapes": None,
                },
            )
            rst.append(item)
        return rst

    def _sort(self, records: List[Record], sort_by: str) -> List[Record]:
        r"""Sort operation.

        Args:
            records: the records after the aggregate operation.
            sort_by: keyword for sorting the list; a leading ``"+"`` sorts in
                ascending order, otherwise descending.
        """
        if sort_by is None:
            return records
        if sort_by.startswith("+"):
            sort_by = sort_by[1:]
            key = lambda record: record.get_column_by_name(sort_by)
        else:
            key = lambda record: -record.get_column_by_name(sort_by)
        records.sort(key=key)
        return records

    def select(
        self,
        time_func: Callable,
        opr_filter: Callable = lambda opr: True,
        aggregate: Callable = None,
        aggregate_by: str = None,
        sort_by: str = None,
        top_k: int = 0,
    ) -> List[Record]:
        r"""Select operation.

        Args:
            time_func: time function provided by the user, applied to every
                OprProfRst.
            opr_filter: filter that keeps satisfying operators.
            aggregate: function applied to the list of records that are
                aggregated by ``aggregate_by``.
            aggregate_by: the type to aggregate by.
            sort_by: keyword for sorting all records.
            top_k: maximum number of records to return; 0 means all.

        Returns:
            the records that went through select, aggregate and sort.
        """
        records = []
        for opr in self._opr_set.values():
            if opr_filter(opr):
                time = time_func(opr)
                if time is None:
                    continue
                item = Record(time, opr.opr_info, opr.footprint)
                records.append(item)

        records = self._aggregate(records, aggregate, aggregate_by)
        if not records:
            return records
        return self._sort(records, sort_by)[0 : len(records) if top_k == 0 else top_k]
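
# Illustrative usage (hypothetical arguments): sum the device time per
# operator type and keep the three most expensive types; see the runnable
# sketch at the bottom of the file for the surrounding setup.
#
#     analyzer.select(
#         TimeFuncHelper.eval_time_func("device", "end", np.max),
#         aggregate="sum",
#         aggregate_by="type",
#         sort_by="time",
#         top_k=3,
#     )
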
class TimeFuncHelper:
    r"""Time function helpers for users."""

    @staticmethod
    def _eval_time(prof_type, end_key, func, opr_prof):
        r"""Eval time.

        Args:
            prof_type: 'host' or 'device'.
            end_key: 'kern' or 'end'.
            func: applied to the list of all ``thread`` or ``gpu`` times.
            opr_prof: operator profiling result.

        Returns:
            time.
        """
        if prof_type not in opr_prof.time_dict:
            return None
        time = [time[end_key] - time["start"] for time in opr_prof.time_dict[prof_type]]
        return func(time)

    @staticmethod
    def eval_time_func(prof_type: str, end_key: str, func: Callable) -> Callable:
        r"""Eval operator profile time.

        Args:
            prof_type: 'host' or 'device'.
            end_key: 'kern' or 'end'.
            func: applied to the list of all ``thread`` or ``gpu`` times.

        Returns:
            a function that evaluates the time of a single operator.
        """
        return functools.partial(TimeFuncHelper._eval_time, prof_type, end_key, func)

    @staticmethod
    def _min_start(
        prof_type, end_key, func, opr_prof
    ):  # pylint: disable=unused-argument
        r"""Eval minimum start time.

        Args:
            prof_type: 'host' or 'device'.
            end_key: 'kern' or 'end'.
            func: applied to the list of all ``thread`` or ``gpu`` times.
            opr_prof: operator profiling result.

        Returns:
            time.
        """
        if prof_type not in opr_prof.time_dict:
            return None
        time = [time["start"] for time in opr_prof.time_dict[prof_type]]
        return np.min(time)

    @staticmethod
    def min_start_func(
        prof_type: str, end_key: str, func: Callable
    ) -> Callable:  # pylint: disable=unused-argument
        r"""Eval operator profile min start time.

        Args:
            prof_type: 'host' or 'device'.
            end_key: 'kern' or 'end'.
            func: applied to the list of all ``thread`` or ``gpu`` times.

        Returns:
            a function that evaluates the minimum start time of a single operator.
        """
        return functools.partial(TimeFuncHelper._min_start, prof_type, end_key, func)

    @staticmethod
    def _max_end(prof_type, end_key, func, opr_prof):  # pylint: disable=unused-argument
        r"""Eval maximum end time.

        Args:
            prof_type: 'host' or 'device'.
            end_key: 'kern' or 'end'.
            func: applied to the list of all ``thread`` or ``gpu`` times.
            opr_prof: operator profiling result.

        Returns:
            time.
        """
        if prof_type not in opr_prof.time_dict:
            return None
        time = [time["end"] for time in opr_prof.time_dict[prof_type]]
        return np.max(time)

    @staticmethod
    def max_end_func(prof_type: str, end_key: str, func: Callable) -> Callable:
        r"""Eval operator profile max end time.

        Args:
            prof_type: 'host' or 'device'.
            end_key: 'kern' or 'end'.
            func: applied to the list of all ``thread`` or ``gpu`` times.

        Returns:
            a function that evaluates the maximum end time of a single operator.
        """
        return functools.partial(TimeFuncHelper._max_end, prof_type, end_key, func)
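
if __name__ == "__main__":
    # A minimal, self-contained usage sketch. The dict below only mimics the
    # shape of a dumped profile json; every id, name and number is made up for
    # illustration, and a real dump would come from the megengine profiler.
    fake_profile = {
        "graph_exec": {
            "var": {"v0": {}, "v1": {}, "v2": {}},
            "operator": {
                "0": {
                    "name": "conv0",
                    "type": "ConvolutionForward",
                    "id": "0",
                    "input": ["v0"],
                    "output": ["v1"],
                },
                "1": {
                    "name": "relu0",
                    "type": "Elemwise",
                    "id": "1",
                    "input": ["v1"],
                    "output": ["v2"],
                },
            },
        },
        "profiler": {
            "device": {
                "0": {"gpu0": {"start": 0.0, "kern": 0.1, "end": 0.9}},
                "1": {"gpu0": {"start": 0.9, "kern": 1.0, "end": 1.1}},
            },
            "host": {},
        },
    }
    analyzer = ProfileAnalyzer(fake_profile)
    # Rank operators by their maximum device time, slowest first.
    for rec in analyzer.select(
        TimeFuncHelper.eval_time_func("device", "end", np.max), sort_by="time"
    ):
        print("{} ({}): {:.3f}".format(rec.info["name"], rec.info["type"], rec.time))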