
profile_analyzer.py 13 kB

# -*- coding: utf-8 -*-
import collections
import copy
import functools
from typing import Callable, List, Optional, Union

import numpy as np


class NonExistNum:
    r"""An object that behaves like a number but means a field does not exist; it is
    always greater than any real number.
    """

    def __truediv__(self, _):
        return self

    def __add__(self, rhs):
        return rhs

    def __radd__(self, lhs):
        return lhs

    def __neg__(self):
        return self

    def __gt__(self, rhs):
        if isinstance(rhs, NonExistNum):
            return id(self) > id(rhs)
        return True

    def __ge__(self, rhs):
        return self > rhs or self == rhs

    def __lt__(self, rhs):
        if isinstance(rhs, NonExistNum):
            return id(self) < id(rhs)
        return False

    def __le__(self, rhs):
        return self < rhs or self == rhs

    def __eq__(self, rhs):
        return self is rhs

    def __format__(self, spec):
        return "N/A"

    def __repr__(self):
        return "N/A"


class OprProfRst:
    r"""Opr profiling result dumped from the MegEngine profiler.

    Args:
        entry: a ``graph_exec`` operator item from the profiling json.
            Initialization sets up the name, type and id of ``opr_info``.
    """

    opr_info = None
    r"""A dict containing operator info: name, id and type."""

    time_dict = None
    r"""A mapping from ``"host"`` or ``"device"`` to a list of profiling results."""

    footprint = None
    r"""A mapping from ``"memory"`` or ``"computation"`` to the actual number
    of corresponding operations."""

    def __init__(self, entry: dict):
        assert isinstance(entry, dict)
        self.opr_info = collections.OrderedDict()
        for key in ["name", "type", "id"]:
            self.opr_info[key] = entry[key]
        self.time_dict = collections.defaultdict(list)
        self.footprint = collections.defaultdict(NonExistNum)

    def update_device_prof_info(self, dev_time: dict):
        r"""Updates device profiling info.

        Args:
            dev_time: device time for a single opr; an attribute of the
                profiling result.
        """
        assert isinstance(dev_time, dict)
        self.time_dict["device"].append(copy.deepcopy(dev_time))

    def update_host_prof_info(self, host_time: dict):
        r"""Updates host profiling info.

        Args:
            host_time: host time for a single opr; an attribute of the
                profiling result.
        """
        assert isinstance(host_time, dict)
        self.time_dict["host"].append(copy.deepcopy(host_time))

    def update_footprint(self, footprint: dict):
        r"""Updates opr footprint.

        Args:
            footprint: footprint for a single opr; an attribute of the
                profiling result.
        """
        assert isinstance(footprint, dict)
        self.footprint.update(footprint)
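

# --- Hedged example (not part of the original module) -----------------------
# A minimal sketch of how an OprProfRst gets filled in. The literal values are
# made up for illustration, but the keys ("start"/"kern"/"end", and the
# footprint entries) mirror what the rest of this file reads back.
def _demo_opr_prof_rst():
    opr = OprProfRst({"name": "conv0", "type": "ConvolutionForward", "id": "1"})
    opr.update_device_prof_info({"start": 0.0, "kern": 0.1, "end": 0.3})
    opr.update_footprint(
        {"computation": 1000, "memory": 4096, "in_shapes": None, "out_shapes": None}
    )
    return opr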


class Record:
    r"""A record of the analyzing result.

    Args:
        time: opr running time, evaluated by applying the user-provided
            function to an OprProfRst.
        info: opr information; either the original opr information, or the
            aggregated information if aggregation is enabled.
        footprint: footprint information; currently ``"computation"``,
            ``"memory"``, ``"in_shapes"`` and ``"out_shapes"``.
    """

    __slots__ = [
        "time",
        "info",
        "computation",
        "memory",
        "in_shapes",
        "in_layouts",
        "out_shapes",
        "flops",
        "bandwidth",
        "opr_id",
    ]

    def __init__(self, time: float, info: dict, footprint: dict):
        assert isinstance(footprint, dict)
        self.time = time
        self.info = collections.OrderedDict(copy.deepcopy(info))
        self.computation = footprint["computation"] or NonExistNum()
        self.memory = footprint["memory"]
        self.in_shapes = footprint["in_shapes"]
        self.in_layouts = footprint.get("in_layouts")
        self.out_shapes = footprint["out_shapes"]
        self.flops = self.computation / self.time
        self.bandwidth = self.memory / self.time
        self.opr_id = info.get("id")
        if isinstance(self.opr_id, str) and self.opr_id != "N/A":
            self.opr_id = int(self.opr_id)

    def get_column_by_name(self, name: str = None):
        r"""Extracts a column value by its column name.

        Args:
            name: column name; None for time.
        """
        if name is None:
            name = "time"
        return getattr(self, name)
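

# --- Hedged example (not part of the original module) -----------------------
# A minimal sketch of building a Record from the OprProfRst sketched above and
# reading a derived column; the time and footprint values are illustrative only.
def _demo_record():
    opr = _demo_opr_prof_rst()
    rec = Record(time=0.3, info=opr.opr_info, footprint=opr.footprint)
    return rec.get_column_by_name("flops")  # computation / time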


class ProfileAnalyzer:
    r"""Analyzer for profiling results dumped from the MegEngine profiler.

    Args:
        obj: dict dumped from the profiling json string.
        opr_filter: function that filters oprs.

    An illustrative usage sketch appears at the bottom of this file.
    """

    def __init__(self, obj: dict, opr_filter: Callable = lambda opr, inp, out: True):
        self._opr_set = dict()  # type: dict
        assert isinstance(obj, dict), type(obj)
        varz = obj["graph_exec"]["var"]
        # collect the operators that pass the filter
        for opr_id, entry in obj["graph_exec"]["operator"].items():
            inp = [varz[i] for i in entry["input"]]
            out = [varz[i] for i in entry["output"]]
            if opr_filter(entry, inp, out):
                self._opr_set[opr_id] = OprProfRst(entry)

        # attach device timing info
        for opr_id, entry in obj["profiler"]["device"].items():
            if opr_id not in self._opr_set:
                continue
            opr = self._opr_set[opr_id]
            for _, time in entry.items():
                opr.update_device_prof_info(time)

        # attach host timing info
        for opr_id, entry in obj["profiler"]["host"].items():
            if opr_id not in self._opr_set:
                continue
            opr = self._opr_set[opr_id]
            for _, time in entry.items():
                opr.update_host_prof_info(time)

        # attach per-opr footprint (memory / computation) info, if present
        for opr_id, entry in obj["profiler"].get("opr_footprint", {}).items():
            if opr_id not in self._opr_set:
                continue
            opr = self._opr_set[opr_id]
            opr.update_footprint(entry)

    def _aggregate(
        self, records: List[Record], aop: Union[str, Callable], atype: Optional[str]
    ) -> List[Record]:
        r"""Aggregate operation.

        Args:
            records: selected records.
            aop: aggregate operation; if aop is a str, it is replaced by the
                numpy function of that name.
            atype: the type to aggregate by; None for aggregating all records
                into a single one.
        """
        if aop is None:
            assert atype is None, "must specify aggregate op"
            return records
        if isinstance(aop, str):
            aop = getattr(np, aop)
        # bucket time / computation / memory, either per opr type or all together
        type2stat = collections.defaultdict(lambda: [[], [], []])  # type: dict
        for item in records:
            if atype == "type":
                d = type2stat[item.info["type"]]
            else:
                d = type2stat["all"]
            d[0].append(item.time)
            d[1].append(item.computation)
            d[2].append(item.memory)

        rst = []
        for opr_type in type2stat.keys():
            time, computation, memory = type2stat[opr_type]
            nr_oprs = len(time)
            time_rst = aop(time)
            comp_rst = aop(computation)
            mem_rst = aop(memory)
            item = Record(
                time_rst,
                {"type": opr_type, "count": nr_oprs, "id": "N/A"},
                {
                    "computation": comp_rst,
                    "memory": mem_rst,
                    "in_shapes": None,
                    "out_shapes": None,
                },
            )
            rst.append(item)
        return rst

    def _sort(self, records: List[Record], sort_by: str) -> List[Record]:
        r"""Sort operation.

        Args:
            records: the records after the aggregate operation.
            sort_by: keyword for sorting the list; a leading ``"+"`` sorts in
                ascending order, otherwise the order is descending.
        """
        if sort_by is None:
            return records
        if sort_by.startswith("+"):
            sort_by = sort_by[1:]
            key = lambda record: record.get_column_by_name(sort_by)
        else:
            key = lambda record: -record.get_column_by_name(sort_by)
        records.sort(key=key)
        return records

    def select(
        self,
        time_func: Callable,
        opr_filter: Callable = lambda opr: True,
        aggregate: Callable = None,
        aggregate_by: str = None,
        sort_by: str = None,
        top_k: int = 0,
    ) -> List[Record]:
        r"""Select operation.

        Args:
            time_func: user-provided time function, applied to every
                OprProfRst.
            opr_filter: filter that selects the satisfying operators.
            aggregate: function applied to the list of records aggregated
                by aggregate_by.
            aggregate_by: the type to aggregate by.
            sort_by: keyword for sorting all records.
            top_k: maximum number of records to return; 0 means no limit.

        Returns:
            the records that went through select, aggregate and sort.
        """
        records = []
        for opr in self._opr_set.values():
            if opr_filter(opr):
                time = time_func(opr)
                if time is None:
                    continue
                item = Record(time, opr.opr_info, opr.footprint)
                records.append(item)

        records = self._aggregate(records, aggregate, aggregate_by)
        if not records:
            return records
        return self._sort(records, sort_by)[0 : len(records) if top_k == 0 else top_k]


class TimeFuncHelper:
    r"""Time function helpers for users."""

    @staticmethod
    def _eval_time(prof_type, end_key, func, opr_prof):
        r"""Evaluates time.

        Args:
            prof_type: 'host' or 'device'.
            end_key: 'kern' or 'end'.
            func: function applied to the list of all ``thread`` or ``gpu`` times.
            opr_prof: operator profiling result.

        Returns:
            time.
        """
        if prof_type not in opr_prof.time_dict:
            return None
        time = [time[end_key] - time["start"] for time in opr_prof.time_dict[prof_type]]
        return func(time)

    @staticmethod
    def eval_time_func(prof_type: str, end_key: str, func: Callable) -> Callable:
        r"""Evaluates operator profiling time.

        Args:
            prof_type: 'host' or 'device'.
            end_key: 'kern' or 'end'.
            func: function applied to the list of all ``thread`` or ``gpu`` times.

        Returns:
            a time function (to be applied to an OprProfRst) bound to the
            given arguments.
        """
        return functools.partial(TimeFuncHelper._eval_time, prof_type, end_key, func)

    @staticmethod
    def _min_start(
        prof_type, end_key, func, opr_prof
    ):  # pylint: disable=unused-argument
        r"""Evaluates the minimum start time.

        Args:
            prof_type(str): 'host' or 'device'.
            end_key(str): 'kern' or 'end'.
            func(function): function applied to the list of all ``thread`` or ``gpu`` times.
            opr_prof(OprProfRst): operator profiling result.

        Returns:
            time.
        """
        if prof_type not in opr_prof.time_dict:
            return None
        time = [time["start"] for time in opr_prof.time_dict[prof_type]]
        return np.min(time)

    @staticmethod
    def min_start_func(
        prof_type: str, end_key: str, func: Callable
    ) -> Callable:  # pylint: disable=unused-argument
        r"""Evaluates the operator profiling minimum start time.

        Args:
            prof_type(str): 'host' or 'device'.
            end_key(str): 'kern' or 'end'.
            func(function): function applied to the list of all ``thread`` or ``gpu`` times.

        Returns:
            a time function (to be applied to an OprProfRst) bound to the
            given arguments.
        """
        return functools.partial(TimeFuncHelper._min_start, prof_type, end_key, func)

    @staticmethod
    def _max_end(prof_type, end_key, func, opr_prof):  # pylint: disable=unused-argument
        r"""Evaluates the maximum end time.

        Args:
            prof_type(str): 'host' or 'device'.
            end_key(str): 'kern' or 'end'.
            func(function): function applied to the list of all ``thread`` or ``gpu`` times.
            opr_prof(OprProfRst): operator profiling result.

        Returns:
            time.
        """
        if prof_type not in opr_prof.time_dict:
            return None
        time = [time["end"] for time in opr_prof.time_dict[prof_type]]
        return np.max(time)

    @staticmethod
    def max_end_func(prof_type: str, end_key: str, func: Callable) -> Callable:
        r"""Evaluates the operator profiling maximum end time.

        Args:
            prof_type(str): 'host' or 'device'.
            end_key(str): 'kern' or 'end'.
            func(function): function applied to the list of all ``thread`` or ``gpu`` times.

        Returns:
            a time function (to be applied to an OprProfRst) bound to the
            given arguments.
        """
        return functools.partial(TimeFuncHelper._max_end, prof_type, end_key, func)
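

# --- Hedged usage sketch (not part of the original module) ------------------
# How the pieces above are typically combined: load a dumped profiling json,
# build a ProfileAnalyzer, and list the slowest operators by device kernel
# time. The file path and the printed columns are illustrative assumptions.
if __name__ == "__main__":
    import json

    with open("profile.json") as fin:  # hypothetical dump path
        prof = json.load(fin)

    analyzer = ProfileAnalyzer(prof)
    time_func = TimeFuncHelper.eval_time_func("device", "kern", np.max)
    records = analyzer.select(time_func, sort_by="time", top_k=5)
    for rec in records:
        print("{} {}: {:.3f}s".format(rec.info["type"], rec.info["name"], rec.time))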