You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

profiler.py 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # -*- coding: utf-8 -*-
  2. import json
  3. import os
  4. import re
  5. from contextlib import ContextDecorator, contextmanager
  6. from functools import wraps
  7. from typing import List
  8. from weakref import WeakSet
  9. from .. import _atexit
  10. from ..core._imperative_rt.core2 import (
  11. cupti_available,
  12. disable_cupti,
  13. enable_cupti,
  14. full_sync,
  15. pop_scope,
  16. push_scope,
  17. start_profile,
  18. stop_profile,
  19. sync,
  20. )
  21. from ..logger import get_logger
  22. _running_profiler = None
  23. _living_profilers = WeakSet()
  24. class Profiler(ContextDecorator):
  25. r"""Profile graph execution in imperative mode.
  26. Args:
  27. path: default path prefix for profiler to dump.
  28. Examples:
  29. .. code-block::
  30. import megengine as mge
  31. import megengine.module as M
  32. from megengine.utils.profiler import Profiler
  33. # With Learnable Parameters
  34. profiler = Profiler()
  35. for iter in range(0, 10):
  36. # Only profile record of last iter would be saved
  37. with profiler:
  38. # your code here
  39. # Then open the profile file in chrome timeline window
  40. """
  41. CHROME_TIMELINE = "chrome_timeline.json"
  42. valid_options = {
  43. "sample_rate": 0,
  44. "profile_device": 1,
  45. "num_tensor_watch": 10,
  46. "enable_cupti": 0,
  47. }
  48. valid_formats = {"chrome_timeline.json", "memory_flow.svg"}
  49. def __init__(
  50. self,
  51. path: str = "profile",
  52. format: str = "chrome_timeline.json",
  53. formats: List[str] = None,
  54. **kwargs
  55. ) -> None:
  56. if not formats:
  57. formats = [format]
  58. assert not isinstance(formats, str), "formats excepts list, got str"
  59. for format in formats:
  60. assert format in Profiler.valid_formats, "unsupported format {}".format(
  61. format
  62. )
  63. self._path = path
  64. self._formats = formats
  65. self._options = {}
  66. for opt, optval in Profiler.valid_options.items():
  67. self._options[opt] = int(kwargs.pop(opt, optval))
  68. self._pid = "<PID>"
  69. self._dump_callback = None
  70. if self._options.get("enable_cupti", 0):
  71. if cupti_available():
  72. enable_cupti()
  73. else:
  74. get_logger().warning("CuPTI unavailable")
  75. @property
  76. def path(self):
  77. if len(self._formats) == 0:
  78. format = "<FORMAT>"
  79. elif len(self._formats) == 1:
  80. format = self._formats[0]
  81. else:
  82. format = "{" + ",".join(self._formats) + "}"
  83. return self.format_path(self._path, self._pid, format)
  84. @property
  85. def directory(self):
  86. return self._path
  87. @property
  88. def formats(self):
  89. return list(self._formats)
  90. def start(self):
  91. global _running_profiler
  92. assert _running_profiler is None
  93. _running_profiler = self
  94. self._pid = os.getpid()
  95. start_profile(self._options)
  96. return self
  97. def stop(self):
  98. global _running_profiler
  99. assert _running_profiler is self
  100. _running_profiler = None
  101. full_sync()
  102. self._dump_callback = stop_profile()
  103. self._pid = os.getpid()
  104. _living_profilers.add(self)
  105. def dump(self):
  106. if self._dump_callback is not None:
  107. if not os.path.exists(self._path):
  108. os.makedirs(self._path)
  109. if not os.path.isdir(self._path):
  110. get_logger().warning(
  111. "{} is not a directory, cannot write profiling results".format(
  112. self._path
  113. )
  114. )
  115. return
  116. for format in self._formats:
  117. path = self.format_path(self._path, self._pid, format)
  118. get_logger().info("process {} generating {}".format(self._pid, format))
  119. self._dump_callback(path, format)
  120. get_logger().info("profiling results written to {}".format(path))
  121. if os.path.getsize(path) > 64 * 1024 * 1024:
  122. get_logger().warning(
  123. "profiling results too large, maybe you are profiling multi iters,"
  124. "consider attach profiler in each iter separately"
  125. )
  126. self._dump_callback = None
  127. _living_profilers.remove(self)
  128. def format_path(self, path, pid, format):
  129. return os.path.join(path, "{}.{}".format(pid, format))
  130. def __enter__(self):
  131. self.start()
  132. def __exit__(self, val, tp, trace):
  133. self.stop()
  134. def __call__(self, func):
  135. func = super().__call__(func)
  136. func.__profiler__ = self
  137. return func
  138. def __del__(self):
  139. if self._options.get("enable_cupti", 0):
  140. if cupti_available():
  141. disable_cupti()
  142. self.dump()
  143. @contextmanager
  144. def scope(name):
  145. push_scope(name)
  146. yield
  147. pop_scope(name)
  148. def profile(*args, **kwargs):
  149. if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
  150. return Profiler()(args[0])
  151. return Profiler(*args, **kwargs)
  152. def merge_trace_events(directory: str):
  153. names = filter(
  154. lambda x: re.match(r"\d+\.chrome_timeline\.json", x), os.listdir(directory)
  155. )
  156. def load_trace_events(name):
  157. with open(os.path.join(directory, name), "r", encoding="utf-8") as f:
  158. return json.load(f)
  159. def find_metadata(content):
  160. if isinstance(content, dict):
  161. assert "traceEvents" in content
  162. content = content["traceEvents"]
  163. if len(content) == 0:
  164. return None
  165. assert content[0]["name"] == "Metadata"
  166. return content[0]["args"]
  167. contents = list(map(load_trace_events, names))
  168. metadata_list = list(map(find_metadata, contents))
  169. min_local_time = min(
  170. map(lambda x: x["localTime"], filter(lambda x: x is not None, metadata_list))
  171. )
  172. events = []
  173. for content, metadata in zip(contents, metadata_list):
  174. local_events = content["traceEvents"]
  175. if len(local_events) == 0:
  176. continue
  177. local_time = metadata["localTime"]
  178. time_shift = local_time - min_local_time
  179. for event in local_events:
  180. if "ts" in event:
  181. event["ts"] = int(event["ts"] + time_shift)
  182. events.extend(filter(lambda x: x["name"] != "Metadata", local_events))
  183. result = {
  184. "traceEvents": events,
  185. }
  186. path = os.path.join(directory, "merge.chrome_timeline.json")
  187. with open(path, "w") as f:
  188. json.dump(result, f, ensure_ascii=False, separators=(",", ":"))
  189. get_logger().info("profiling results written to {}".format(path))
  190. def is_profiling():
  191. return _running_profiler is not None
  192. def _stop_current_profiler():
  193. global _running_profiler
  194. if _running_profiler is not None:
  195. _running_profiler.stop()
  196. living_profilers = [*_living_profilers]
  197. for profiler in living_profilers:
  198. profiler.dump()
  199. _atexit(_stop_current_profiler)