
profile_analyzer.py 14 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
import functools
from typing import Callable, List, Optional, Union

import numpy as np


class NonExistNum:
    """
    An object that behaves like a number but means a field does not exist; it is
    always greater than any real number.
    """

    def __truediv__(self, _):
        return self

    def __add__(self, rhs):
        return rhs

    def __radd__(self, lhs):
        return lhs

    def __neg__(self):
        return self

    def __gt__(self, rhs):
        if isinstance(rhs, NonExistNum):
            return id(self) > id(rhs)
        return True

    def __ge__(self, rhs):
        return self > rhs or self == rhs

    def __lt__(self, rhs):
        if isinstance(rhs, NonExistNum):
            return id(self) < id(rhs)
        return False

    def __le__(self, rhs):
        return self < rhs or self == rhs

    def __eq__(self, rhs):
        return self is rhs

    def __format__(self, spec):
        return "N/A"

    def __repr__(self):
        return "N/A"


class OprProfRst:
    """Opr profiling result dumped from the MegEngine profiler."""

    opr_info = None
    """A dict containing operator info: name, id and type."""

    time_dict = None
    """
    A mapping from ``"host"`` or ``"device"`` to a list of profiling
    results."""

    footprint = None
    """
    A mapping from ``"memory"`` or ``"computation"`` to the actual number
    of corresponding operations."""

    def __init__(self, entry: dict):
        """
        Opr profiling initialization, which sets up the name, type and id of opr_info.

        :param entry: profiling json exec_graph items.
        """
        assert isinstance(entry, dict)
        self.opr_info = collections.OrderedDict()
        for key in ["name", "type", "id"]:
            self.opr_info[key] = entry[key]
        self.time_dict = collections.defaultdict(list)
        self.footprint = collections.defaultdict(NonExistNum)

    def update_device_prof_info(self, dev_time: dict):
        """
        Updates device profiling info.

        :param dev_time: device time for a single opr,
            an attribute of the profiling result.
        """
        assert isinstance(dev_time, dict)
        self.time_dict["device"].append(copy.deepcopy(dev_time))

    def update_host_prof_info(self, host_time: dict):
        """
        Updates host profiling info.

        :param host_time: host time for a single opr,
            an attribute of the profiling result.
        """
        assert isinstance(host_time, dict)
        self.time_dict["host"].append(copy.deepcopy(host_time))

    def update_footprint(self, footprint: dict):
        """
        Updates the opr footprint.

        :param footprint: footprint for a single opr,
            an attribute of the profiling result.
        """
        assert isinstance(footprint, dict)
        self.footprint.update(footprint)


class Record:
    """A record of an analysis result."""

    __slots__ = [
        "time",
        "info",
        "computation",
        "memory",
        "in_shapes",
        "in_layouts",
        "out_shapes",
        "flops",
        "bandwidth",
        "opr_id",
    ]

    def __init__(self, time: float, info: dict, footprint: dict):
        """
        Initializes a single record.

        :param time: opr running time, evaluated by applying a user-provided
            function to an OprProfRst.
        :param info: opr information, either the original opr information or
            aggregate information if aggregation is enabled.
        :param footprint: contains footprint information; for now we have
            ``"computation"``, ``"memory"``, ``"in_shapes"``, ``"out_shapes"``.
        """
        assert isinstance(footprint, dict)
        self.time = time
        self.info = collections.OrderedDict(copy.deepcopy(info))
        self.computation = footprint["computation"] or NonExistNum()
        self.memory = footprint["memory"]
        self.in_shapes = footprint["in_shapes"]
        self.in_layouts = footprint.get("in_layouts")
        self.out_shapes = footprint["out_shapes"]
        self.flops = self.computation / self.time
        self.bandwidth = self.memory / self.time
        self.opr_id = info.get("id")
        if isinstance(self.opr_id, str) and self.opr_id != "N/A":
            self.opr_id = int(self.opr_id)

    def get_column_by_name(self, name: str = None):
        """
        Extracts a column value by its column name.

        :param name: column name, None for time.
        """
        if name is None:
            name = "time"
        return getattr(self, name)


class ProfileAnalyzer:
    def __init__(self, obj: dict, opr_filter: Callable = lambda opr, inp, out: True):
        """
        Initializes the ProfileAnalyzer.

        :param obj: dict dumped from the profiling json str.
        :param opr_filter: function that filters oprs.
        """
        self._opr_set = dict()  # type: dict
        assert isinstance(obj, dict), type(obj)
        varz = obj["graph_exec"]["var"]
        for opr_id, entry in obj["graph_exec"]["operator"].items():
            inp = [varz[i] for i in entry["input"]]
            out = [varz[i] for i in entry["output"]]
            if opr_filter(entry, inp, out):
                self._opr_set[opr_id] = OprProfRst(entry)

        for opr_id, entry in obj["profiler"]["device"].items():
            if opr_id not in self._opr_set:
                continue
            opr = self._opr_set[opr_id]
            for _, time in entry.items():
                opr.update_device_prof_info(time)

        for opr_id, entry in obj["profiler"]["host"].items():
            if opr_id not in self._opr_set:
                continue
            opr = self._opr_set[opr_id]
            for _, time in entry.items():
                opr.update_host_prof_info(time)

        for opr_id, entry in obj["profiler"].get("opr_footprint", {}).items():
            if opr_id not in self._opr_set:
                continue
            opr = self._opr_set[opr_id]
            opr.update_footprint(entry)

    def _aggregate(
        self, records: List[Record], aop: Union[str, Callable], atype: Optional[str]
    ) -> List[Record]:
        """
        Aggregate operation.

        :param records: selected records.
        :param aop: aggregate operation; if aop is a str, it is replaced
            with the numpy function of the same name.
        :param atype: the type to aggregate by, None for aggregating all into a
            single record.
        """
        if aop is None:
            assert atype is None, "must specify aggregate op"
            return records
        if isinstance(aop, str):
            aop = getattr(np, aop)
        type2stat = collections.defaultdict(lambda: [[], [], []])  # type: dict
        for item in records:
            if atype == "type":
                d = type2stat[item.info["type"]]
            else:
                d = type2stat["all"]
            d[0].append(item.time)
            d[1].append(item.computation)
            d[2].append(item.memory)

        rst = []
        for opr_type in type2stat.keys():
            time, computation, memory = type2stat[opr_type]
            nr_oprs = len(time)
            time_rst = aop(time)
            comp_rst = aop(computation)
            mem_rst = aop(memory)
            item = Record(
                time_rst,
                {"type": opr_type, "count": nr_oprs, "id": "N/A"},
                {
                    "computation": comp_rst,
                    "memory": mem_rst,
                    "in_shapes": None,
                    "out_shapes": None,
                },
            )
            rst.append(item)
        return rst

    def _sort(self, records: List[Record], sort_by: str) -> List[Record]:
        """
        Sort operation.

        :param records: the records after the aggregate operation.
        :param sort_by: keyword for sorting the list.
        """
        if sort_by is None:
            return records
        if sort_by.startswith("+"):
            sort_by = sort_by[1:]
            key = lambda record: record.get_column_by_name(sort_by)
        else:
            key = lambda record: -record.get_column_by_name(sort_by)
        records.sort(key=key)
        return records

    def select(
        self,
        time_func: Callable,
        opr_filter: Callable = lambda opr: True,
        aggregate: Callable = None,
        aggregate_by: str = None,
        sort_by: str = None,
        top_k: int = 0,
    ) -> List[Record]:
        """
        Select operation.

        :param time_func: time_func provided by the user, applied to every
            OprProfRst.
        :param opr_filter: filter that keeps satisfied operators.
        :param aggregate: function applied to the list of records that are
            aggregated by atype.
        :param aggregate_by: the type to aggregate by.
        :param sort_by: keyword for sorting all records.
        :param top_k: specify the maximum number of records.
        :return: the records that go through select, aggregate and sort.
        """
        records = []
        for opr in self._opr_set.values():
            if opr_filter(opr):
                time = time_func(opr)
                if time is None:
                    continue
                item = Record(time, opr.opr_info, opr.footprint)
                records.append(item)

        records = self._aggregate(records, aggregate, aggregate_by)
        if not records:
            return records
        return self._sort(records, sort_by)[0 : len(records) if top_k == 0 else top_k]


class TimeFuncHelper:
    """Time function helpers for users."""

    @staticmethod
    def _eval_time(prof_type, end_key, func, opr_prof):
        """
        Eval time.

        :type prof_type: str
        :param prof_type: 'host' or 'device'.
        :type end_key: str
        :param end_key: 'kern' or 'end'.
        :type func: function
        :param func: applied to the list of all ``thread`` or ``gpu`` times.
        :type opr_prof: `class OprProfRst`
        :param opr_prof: operator profiling result.
        :rtype: float
        :return: time.
        """
        if prof_type not in opr_prof.time_dict:
            return None
        time = [time[end_key] - time["start"] for time in opr_prof.time_dict[prof_type]]
        return func(time)

    @staticmethod
    def eval_time_func(prof_type: str, end_key: str, func: Callable) -> float:
        """
        Eval operator profile time.

        :param prof_type: 'host' or 'device'.
        :param end_key: 'kern' or 'end'.
        :param func: applied to the list of all ``thread`` or ``gpu`` times.
        :return: eval time results.
        """
        return functools.partial(TimeFuncHelper._eval_time, prof_type, end_key, func)

    @staticmethod
    def _min_start(
        prof_type, end_key, func, opr_prof
    ):  # pylint: disable=unused-argument
        """
        Eval minimum start time.

        :type prof_type: str
        :param prof_type: 'host' or 'device'.
        :type end_key: str
        :param end_key: 'kern' or 'end'.
        :type func: function
        :param func: applied to the list of all ``thread`` or ``gpu`` times.
        :type opr_prof: `class OprProfRst`
        :param opr_prof: operator profiling result.
        :rtype: float
        :return: time.
        """
        if prof_type not in opr_prof.time_dict:
            return None
        time = [time["start"] for time in opr_prof.time_dict[prof_type]]
        return np.min(time)

    @staticmethod
    def min_start_func(
        prof_type: str, end_key: str, func: Callable
    ) -> float:  # pylint: disable=unused-argument
        """
        Eval operator profile min start time.

        :param prof_type: 'host' or 'device'.
        :param end_key: 'kern' or 'end'.
        :param func: applied to the list of all ``thread`` or ``gpu`` times.
        :return: eval time results.
        """
        return functools.partial(TimeFuncHelper._min_start, prof_type, end_key, func)

    @staticmethod
    def _max_end(prof_type, end_key, func, opr_prof):  # pylint: disable=unused-argument
        """
        Eval maximum end time.

        :type prof_type: str
        :param prof_type: 'host' or 'device'.
        :type end_key: str
        :param end_key: 'kern' or 'end'.
        :type func: function
        :param func: applied to the list of all ``thread`` or ``gpu`` times.
        :type opr_prof: `class OprProfRst`
        :param opr_prof: operator profiling result.
        :rtype: float
        :return: time.
        """
        if prof_type not in opr_prof.time_dict:
            return None
        time = [time["end"] for time in opr_prof.time_dict[prof_type]]
        return np.max(time)

    @staticmethod
    def max_end_func(prof_type: str, end_key: str, func: Callable) -> float:
        """
        Eval operator profile max end time.

        :param prof_type: 'host' or 'device'.
        :param end_key: 'kern' or 'end'.
        :param func: applied to the list of all ``thread`` or ``gpu`` times.
        :return: eval time results.
        """
        return functools.partial(TimeFuncHelper._max_end, prof_type, end_key, func)
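
Below is a minimal usage sketch for the analyzer above, not part of profile_analyzer.py itself. It assumes the module is importable as megengine.utils.profile_analyzer and that a profiling result has already been dumped to a file named profile.json; the file name, the aggregation choices and the top_k value are illustrative only.

import json

import numpy as np

from megengine.utils.profile_analyzer import ProfileAnalyzer, TimeFuncHelper

# Load a profiling result previously dumped as JSON (hypothetical file name).
with open("profile.json") as fin:
    profile_data = json.load(fin)

analyzer = ProfileAnalyzer(profile_data)

# Per-operator device kernel time: max over the recorded entries of each opr.
time_func = TimeFuncHelper.eval_time_func("device", "kern", np.max)

# Aggregate by operator type, sort by descending time, keep the 5 slowest types.
records = analyzer.select(
    time_func,
    aggregate="sum",
    aggregate_by="type",
    sort_by="time",
    top_k=5,
)
for rec in records:
    print("{}: time={:.6f}s computation={} memory={}".format(
        rec.info["type"], rec.time, rec.computation, rec.memory
    ))

Records whose footprint fields are missing print "N/A" thanks to NonExistNum, so the report stays readable even for operators without computation or memory statistics.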

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is properly installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
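
As a quick sanity check before running GPU programs, you can ask MegEngine whether a usable GPU is visible. This is a sketch that assumes megengine.is_cuda_available() and megengine.get_device_count() are exposed by the installed MegEngine version:

import megengine

# Report whether MegEngine can see a usable CUDA GPU (API assumed as noted above).
if megengine.is_cuda_available():
    print("GPU devices visible to MegEngine:", megengine.get_device_count("gpu"))
else:
    print("No usable GPU found; programs will run on CPU.")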