You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

profiler.py 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import json
  10. import os
  11. import re
  12. from contextlib import ContextDecorator, contextmanager
  13. from functools import wraps
  14. from typing import List
  15. from weakref import WeakSet
  16. from .. import _atexit
  17. from ..core._imperative_rt.core2 import (
  18. pop_scope,
  19. push_scope,
  20. start_profile,
  21. stop_profile,
  22. sync,
  23. )
  24. from ..logger import get_logger
# The Profiler currently between start() and stop(), or None when no
# profiler is running. Only one profiler may run at a time.
_running_profiler = None
# Profilers that have been stopped but whose results have not been dumped
# yet. Held weakly, so membership alone does not keep a profiler alive.
_living_profilers = WeakSet()
  27. class Profiler(ContextDecorator):
  28. r"""Profile graph execution in imperative mode.
  29. Args:
  30. path: default path prefix for profiler to dump.
  31. Examples:
  32. .. code-block::
  33. import megengine as mge
  34. import megengine.module as M
  35. from megengine.utils.profiler import Profiler
  36. # With Learnable Parameters
  37. profiler = Profiler()
  38. for iter in range(0, 10):
  39. # Only profile record of last iter would be saved
  40. with profiler:
  41. # your code here
  42. # Then open the profile file in chrome timeline window
  43. """
  44. CHROME_TIMELINE = "chrome_timeline.json"
  45. valid_options = {"sample_rate": 0, "profile_device": 1, "num_tensor_watch": 10}
  46. valid_formats = {"chrome_timeline.json", "memory_flow.svg"}
  47. def __init__(
  48. self,
  49. path: str = "profile",
  50. format: str = "chrome_timeline.json",
  51. formats: List[str] = None,
  52. **kwargs
  53. ) -> None:
  54. if not formats:
  55. formats = [format]
  56. assert not isinstance(formats, str), "formats excepts list, got str"
  57. for format in formats:
  58. assert format in Profiler.valid_formats, "unsupported format {}".format(
  59. format
  60. )
  61. self._path = path
  62. self._formats = formats
  63. self._options = {}
  64. for opt, optval in Profiler.valid_options.items():
  65. self._options[opt] = int(kwargs.pop(opt, optval))
  66. self._pid = "<PID>"
  67. self._dump_callback = None
  68. @property
  69. def path(self):
  70. if len(self._formats) == 0:
  71. format = "<FORMAT>"
  72. elif len(self._formats) == 1:
  73. format = self._formats[0]
  74. else:
  75. format = "{" + ",".join(self._formats) + "}"
  76. return self.format_path(self._path, self._pid, format)
  77. @property
  78. def directory(self):
  79. return self._path
  80. @property
  81. def formats(self):
  82. return list(self._formats)
  83. def start(self):
  84. global _running_profiler
  85. assert _running_profiler is None
  86. _running_profiler = self
  87. self._pid = os.getpid()
  88. start_profile(self._options)
  89. return self
  90. def stop(self):
  91. global _running_profiler
  92. assert _running_profiler is self
  93. _running_profiler = None
  94. sync()
  95. self._dump_callback = stop_profile()
  96. self._pid = os.getpid()
  97. _living_profilers.add(self)
  98. def dump(self):
  99. if self._dump_callback is not None:
  100. if not os.path.exists(self._path):
  101. os.makedirs(self._path)
  102. if not os.path.isdir(self._path):
  103. get_logger().warning(
  104. "{} is not a directory, cannot write profiling results".format(
  105. self._path
  106. )
  107. )
  108. return
  109. for format in self._formats:
  110. path = self.format_path(self._path, self._pid, format)
  111. get_logger().info("process {} generating {}".format(self._pid, format))
  112. self._dump_callback(path, format)
  113. get_logger().info("profiling results written to {}".format(path))
  114. if os.path.getsize(path) > 64 * 1024 * 1024:
  115. get_logger().warning(
  116. "profiling results too large, maybe you are profiling multi iters,"
  117. "consider attach profiler in each iter separately"
  118. )
  119. self._dump_callback = None
  120. _living_profilers.remove(self)
  121. def format_path(self, path, pid, format):
  122. return os.path.join(path, "{}.{}".format(pid, format))
  123. def __enter__(self):
  124. self.start()
  125. def __exit__(self, val, tp, trace):
  126. self.stop()
  127. def __call__(self, func):
  128. func = super().__call__(func)
  129. func.__profiler__ = self
  130. return func
  131. def __del__(self):
  132. self.dump()
  133. @contextmanager
  134. def scope(name):
  135. push_scope(name)
  136. yield
  137. pop_scope(name)
  138. def profile(*args, **kwargs):
  139. if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
  140. return Profiler()(args[0])
  141. return Profiler(*args, **kwargs)
  142. def merge_trace_events(directory: str):
  143. names = filter(
  144. lambda x: re.match(r"\d+\.chrome_timeline\.json", x), os.listdir(directory)
  145. )
  146. def load_trace_events(name):
  147. with open(os.path.join(directory, name), "r", encoding="utf-8") as f:
  148. return json.load(f)
  149. def find_metadata(content):
  150. if isinstance(content, dict):
  151. assert "traceEvents" in content
  152. content = content["traceEvents"]
  153. if len(content) == 0:
  154. return None
  155. assert content[0]["name"] == "Metadata"
  156. return content[0]["args"]
  157. contents = list(map(load_trace_events, names))
  158. metadata_list = list(map(find_metadata, contents))
  159. min_local_time = min(
  160. map(lambda x: x["localTime"], filter(lambda x: x is not None, metadata_list))
  161. )
  162. events = []
  163. for content, metadata in zip(contents, metadata_list):
  164. local_events = content["traceEvents"]
  165. if len(local_events) == 0:
  166. continue
  167. local_time = metadata["localTime"]
  168. time_shift = local_time - min_local_time
  169. for event in local_events:
  170. if "ts" in event:
  171. event["ts"] = int(event["ts"] + time_shift)
  172. events.extend(filter(lambda x: x["name"] != "Metadata", local_events))
  173. result = {
  174. "traceEvents": events,
  175. }
  176. path = os.path.join(directory, "merge.chrome_timeline.json")
  177. with open(path, "w") as f:
  178. json.dump(result, f, ensure_ascii=False, separators=(",", ":"))
  179. get_logger().info("profiling results written to {}".format(path))
  180. def is_profiling():
  181. return _running_profiler is not None
  182. def _stop_current_profiler():
  183. global _running_profiler
  184. if _running_profiler is not None:
  185. _running_profiler.stop()
  186. living_profilers = [*_living_profilers]
  187. for profiler in living_profilers:
  188. profiler.dump()
# Register the cleanup hook so a still-running profiler is stopped and all
# pending results are dumped. NOTE(review): presumably invoked at interpreter
# exit; _atexit is defined outside this file, so confirm.
_atexit(_stop_current_profiler)