You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

profiler.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import base64
  10. import json
  11. import os
  12. from typing import List, Optional
  13. from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry
  14. from ..core._imperative_rt import ProfilerImpl as _Profiler
  15. from ..core._imperative_rt.imperative import sync
  16. from ..core._imperative_rt.ops import CollectiveCommMode
  17. from ..core.ops.builtin import GetVarShape
  18. class Profiler:
  19. r"""
  20. Profile graph execution in imperative mode.
  21. :type path: Optional[str]
  22. :param path: default path for profiler to dump.
  23. Examples:
  24. .. testcode::
  25. import megengine as mge
  26. import megengine.module as M
  27. import megengine.utils.profiler.Profiler
  28. # With Learnable Parameters
  29. for iter in range(0, 10):
  30. # Only profile record of last iter would be saved
  31. with Profiler("profile.json"):
  32. # your code here
  33. # Then open the profile file in chrome timeline window
  34. """
  35. # see https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html
  36. GOOD = "good"
  37. BAD = "bad"
  38. TERRIBLE = "terrible"
  39. BLACK = "black"
  40. GREY = "grey"
  41. WHITE = "white"
  42. YELLOW = "yellow"
  43. OLIVE = "olive"
  44. def __init__(self, path: str = "profile.json"):
  45. self._impl = _Profiler()
  46. self._path = path
  47. self._color_map = {}
  48. self._type_map = {
  49. OperatorNodeConfig: lambda x: self.print_opnode_config(x),
  50. bytes: lambda x: base64.encodebytes(x).decode("ascii"),
  51. CollectiveCommMode: lambda x: str(x),
  52. }
  53. def __enter__(self):
  54. sync()
  55. self._impl.start()
  56. return self
  57. def __exit__(self, val, type, trace):
  58. sync()
  59. self._impl.stop()
  60. if self._path is not None:
  61. self.dump()
  62. def recolor(self, target: str, color: str):
  63. self._color_map[target] = color
  64. return self
  65. def print_opnode_config(self, config):
  66. return self.make_dict(
  67. name=config.name, dtype=config.dtype, comp_node_arr=config.comp_node_arr,
  68. )
  69. def fetch_attrs(self, op):
  70. attrs = dir(op)
  71. results = {}
  72. for attr in attrs:
  73. if attr.startswith("_"):
  74. continue
  75. value = op.__getattribute__(attr)
  76. if callable(value):
  77. continue
  78. value_type = type(value)
  79. if value_type in self._type_map:
  80. value = self._type_map[value_type](value)
  81. results[attr] = value
  82. return results
  83. def make_dict(self, **kwargs):
  84. unused_keys = []
  85. for k, v in kwargs.items():
  86. if v is None:
  87. unused_keys.append(k)
  88. for k in unused_keys:
  89. del kwargs[k]
  90. return kwargs
  91. def dump(self, path: Optional[str] = None):
  92. pid = os.getpid()
  93. if path is None:
  94. path = self._path
  95. trace_events = []
  96. def append_event(**kwargs):
  97. trace_events.append(self.make_dict(**kwargs))
  98. entries: List[ProfileEntry] = self._impl.dump()
  99. for id, entry in enumerate(entries):
  100. op = entry.op
  101. name = type(op).__name__
  102. host_begin, host_end = entry.host
  103. device_list = entry.device_list
  104. args = self.fetch_attrs(op)
  105. args["__id__"] = "[{}]".format(id)
  106. cname = self._color_map[name] if name in self._color_map else None
  107. cat = name
  108. for ts, ph in [(host_begin, "B"), (host_end, "E")]:
  109. append_event(
  110. name=name,
  111. ph=ph,
  112. ts=ts * 1000,
  113. pid=pid,
  114. tid="host",
  115. args=args,
  116. cname=cname,
  117. cat=cat,
  118. )
  119. for device, device_begin, device_end in device_list:
  120. for ts, ph in [(device_begin(), "B"), (device_end(), "E")]:
  121. append_event(
  122. name=name,
  123. ph=ph,
  124. ts=ts * 1000,
  125. pid=pid,
  126. tid=str(device),
  127. args=args,
  128. cname=cname,
  129. )
  130. with open(path, "w") as f:
  131. json.dump(trace_events, f, indent=2)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。