You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

profiler.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import json
  10. import os
  11. import re
  12. from contextlib import ContextDecorator, contextmanager
  13. from functools import wraps
  14. from typing import List
  15. from weakref import WeakSet
  16. from .. import _atexit
  17. from ..core._imperative_rt.core2 import (
  18. pop_scope,
  19. push_scope,
  20. start_profile,
  21. stop_profile,
  22. sync,
  23. )
  24. from ..logger import get_logger
  25. _running_profiler = None
  26. _living_profilers = WeakSet()
  27. class Profiler(ContextDecorator):
  28. r"""Profile graph execution in imperative mode.
  29. Args:
  30. path: default path prefix for profiler to dump.
  31. Examples:
  32. .. code-block::
  33. import megengine as mge
  34. import megengine.module as M
  35. from megengine.utils.profiler import Profiler
  36. # With Learnable Parameters
  37. profiler = Profiler()
  38. for iter in range(0, 10):
  39. # Only profile record of last iter would be saved
  40. with profiler:
  41. # your code here
  42. # Then open the profile file in chrome timeline window
  43. """
  44. CHROME_TIMELINE = "chrome_timeline.json"
  45. valid_options = {"sample_rate": 0, "profile_device": 1, "num_tensor_watch": 10}
  46. valid_formats = {"chrome_timeline.json", "memory_flow.svg"}
  47. def __init__(
  48. self,
  49. path: str = "profile",
  50. format: str = "chrome_timeline.json",
  51. formats: List[str] = None,
  52. **kwargs
  53. ) -> None:
  54. if not formats:
  55. formats = [format]
  56. assert not isinstance(formats, str), "formats excepts list, got str"
  57. for format in formats:
  58. assert format in Profiler.valid_formats, "unsupported format {}".format(
  59. format
  60. )
  61. self._path = path
  62. self._formats = formats
  63. self._options = {}
  64. for opt, optval in Profiler.valid_options.items():
  65. self._options[opt] = int(kwargs.pop(opt, optval))
  66. self._pid = "<PID>"
  67. self._dump_callback = None
  68. @property
  69. def path(self):
  70. if len(self._formats) == 0:
  71. format = "<FORMAT>"
  72. elif len(self._formats) == 1:
  73. format = self._formats[0]
  74. else:
  75. format = "{" + ",".join(self._formats) + "}"
  76. return self.format_path(self._path, self._pid, format)
  77. @property
  78. def directory(self):
  79. return self._path
  80. @property
  81. def formats(self):
  82. return list(self._formats)
  83. def start(self):
  84. global _running_profiler
  85. assert _running_profiler is None
  86. _running_profiler = self
  87. self._pid = os.getpid()
  88. start_profile(self._options)
  89. return self
  90. def stop(self):
  91. global _running_profiler
  92. assert _running_profiler is self
  93. _running_profiler = None
  94. sync()
  95. self._dump_callback = stop_profile()
  96. self._pid = os.getpid()
  97. _living_profilers.add(self)
  98. def dump(self):
  99. if self._dump_callback is not None:
  100. if not os.path.exists(self._path):
  101. os.makedirs(self._path)
  102. if not os.path.isdir(self._path):
  103. get_logger().warning(
  104. "{} is not a directory, cannot write profiling results".format(
  105. self._path
  106. )
  107. )
  108. return
  109. for format in self._formats:
  110. path = self.format_path(self._path, self._pid, format)
  111. get_logger().info("process {} generating {}".format(self._pid, format))
  112. self._dump_callback(path, format)
  113. get_logger().info("profiling results written to {}".format(path))
  114. self._dump_callback = None
  115. _living_profilers.remove(self)
  116. def format_path(self, path, pid, format):
  117. return os.path.join(path, "{}.{}".format(pid, format))
  118. def __enter__(self):
  119. self.start()
  120. def __exit__(self, val, tp, trace):
  121. self.stop()
  122. def __call__(self, func):
  123. func = super().__call__(func)
  124. func.__profiler__ = self
  125. return func
  126. def __del__(self):
  127. self.dump()
  128. @contextmanager
  129. def scope(name):
  130. push_scope(name)
  131. yield
  132. pop_scope(name)
  133. def profile(*args, **kwargs):
  134. if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
  135. return Profiler()(args[0])
  136. return Profiler(*args, **kwargs)
  137. def merge_trace_events(directory: str):
  138. names = filter(
  139. lambda x: re.match(r"\d+\.chrome_timeline\.json", x), os.listdir(directory)
  140. )
  141. def load_trace_events(name):
  142. with open(os.path.join(directory, name), "r", encoding="utf-8") as f:
  143. return json.load(f)
  144. def find_metadata(content):
  145. if isinstance(content, dict):
  146. assert "traceEvents" in content
  147. content = content["traceEvents"]
  148. if len(content) == 0:
  149. return None
  150. assert content[0]["name"] == "Metadata"
  151. return content[0]["args"]
  152. contents = list(map(load_trace_events, names))
  153. metadata_list = list(map(find_metadata, contents))
  154. min_local_time = min(
  155. map(lambda x: x["localTime"], filter(lambda x: x is not None, metadata_list))
  156. )
  157. events = []
  158. for content, metadata in zip(contents, metadata_list):
  159. local_events = content["traceEvents"]
  160. if len(local_events) == 0:
  161. continue
  162. local_time = metadata["localTime"]
  163. time_shift = local_time - min_local_time
  164. for event in local_events:
  165. if "ts" in event:
  166. event["ts"] = int(event["ts"] + time_shift)
  167. events.extend(filter(lambda x: x["name"] != "Metadata", local_events))
  168. result = {
  169. "traceEvents": events,
  170. }
  171. path = os.path.join(directory, "merge.chrome_timeline.json")
  172. with open(path, "w") as f:
  173. json.dump(result, f, ensure_ascii=False, separators=(",", ":"))
  174. get_logger().info("profiling results written to {}".format(path))
  175. def is_profiling():
  176. return _running_profiler is not None
  177. def _stop_current_profiler():
  178. global _running_profiler
  179. if _running_profiler is not None:
  180. _running_profiler.stop()
  181. living_profilers = [*_living_profilers]
  182. for profiler in living_profilers:
  183. profiler.dump()
  184. _atexit(_stop_current_profiler)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。