You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module_tracer.py 6.1 kB

(source lines 1–239)
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. from ... import Tensor
  11. from ... import functional as F
  12. from ...core.tensor.array_method import ArrayMethodMixin
  13. from ...module import Module
  14. from ...module.qat import QATModule
  15. _active_module_tracer = None
  16. BUILTIN_ARRAY_METHOD = [
  17. "__lt__",
  18. "__le__",
  19. "__gt__",
  20. "__ge__",
  21. "__eq__",
  22. "__ne__",
  23. "__neg__",
  24. "__pos__",
  25. "__abs__",
  26. "__invert__",
  27. "__round__",
  28. "__floor__",
  29. "__ceil__",
  30. "__add__",
  31. "__sub__",
  32. "__mul__",
  33. "__matmul__",
  34. "__truediv__",
  35. "__floordiv__",
  36. "__mod__",
  37. "__pow__",
  38. "__lshift__",
  39. "__rshift__",
  40. "__and__",
  41. "__or__",
  42. "__xor__",
  43. "__radd__",
  44. "__rsub__",
  45. "__rmul__",
  46. "__rmatmul__",
  47. "__rtruediv__",
  48. "__rfloordiv__",
  49. "__rmod__",
  50. "__rpow__",
  51. "__rlshift__",
  52. "__rrshift__",
  53. "__rand__",
  54. "__ror__",
  55. "__rxor__",
  56. "__iadd__",
  57. "__isub__",
  58. "__imul__",
  59. "__imatmul__",
  60. "__itruediv__",
  61. "__ifloordiv__",
  62. "__imod__",
  63. "__ipow__",
  64. "__ilshift__",
  65. "__irshift__",
  66. "__iand__",
  67. "__ior__",
  68. "__ixor__",
  69. "transpose",
  70. "astype",
  71. "reshape",
  72. "_broadcast",
  73. "transpose",
  74. "flatten",
  75. "sum",
  76. "prod",
  77. "min",
  78. "max",
  79. "mean",
  80. "__getitem__",
  81. "__setitem__",
  82. ]
  83. def active_module_tracer():
  84. return _active_module_tracer
  85. def set_active_module_tracer(tracer):
  86. global _active_module_tracer
  87. _active_module_tracer = tracer
  88. class module_tracer:
  89. # builtin types
  90. _opaque_types = set()
  91. _active_scopes = None
  92. def __init__(self, wrap_fn):
  93. self._active_scopes = []
  94. self.patcher = Patcher(wrap_fn)
  95. @classmethod
  96. def register_as_builtin(cls, mod):
  97. assert issubclass(mod, Module)
  98. cls._opaque_types.add(mod)
  99. return mod
  100. @classmethod
  101. def is_builtin(cls, mod):
  102. return type(mod) in cls._opaque_types
  103. def push_scope(self, scope):
  104. self._active_scopes.append(scope)
  105. def pop_scope(self):
  106. self._active_scopes.pop()
  107. def current_scope(self):
  108. if self._active_scopes:
  109. return self._active_scopes[-1]
  110. return None
  111. class PatchedFn:
  112. frame_dict = None
  113. name = None
  114. origin_fn = None
  115. def __init__(self, frame_dict, name):
  116. self.frame_dict = frame_dict
  117. self.name = name
  118. self.origin_fn = (
  119. self.frame_dict[name]
  120. if isinstance(frame_dict, collections.abc.Mapping)
  121. else getattr(frame_dict, name)
  122. )
  123. def set_func(self, func):
  124. if isinstance(self.frame_dict, collections.abc.Mapping):
  125. self.frame_dict[self.name] = func
  126. else:
  127. setattr(self.frame_dict, self.name, func)
  128. class Patcher:
  129. patched_fn_ids = set()
  130. _builtin_functions = []
  131. _builtin_modules = [
  132. F,
  133. F.distributed,
  134. F.elemwise,
  135. F.inplace,
  136. F.loss,
  137. F.math,
  138. F.metric,
  139. F.nn,
  140. F.quantized,
  141. F.tensor,
  142. F.utils,
  143. F.vision,
  144. ]
  145. _builtin_methods = [
  146. Tensor,
  147. ArrayMethodMixin,
  148. ]
  149. def __init__(self, wrap_fn):
  150. self.patched_fn = []
  151. self.visited_frames_ids = set()
  152. self.wrap_fn = wrap_fn
  153. for module in self._builtin_modules:
  154. self.patch_module(module)
  155. for meth in BUILTIN_ARRAY_METHOD:
  156. self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
  157. self.patch_method(Tensor, "detach", self.wrap_fn)
  158. self.patch_method(Tensor, "__new__", self.wrap_fn)
  159. self.patch_method(QATModule, "_apply_fakequant_with_observer", self.wrap_fn)
  160. for i, j in self._builtin_functions:
  161. if id(i) not in self.visited_frames_ids:
  162. self.patch_function(i, j, self.wrap_fn)
  163. for m in module_tracer._opaque_types:
  164. self.auto_patch(getattr(getattr(m, "forward", m), "__globals__", {}))
  165. def patch_function(self, frame_dict, fn, wrap_fn):
  166. patched_fn = PatchedFn(frame_dict, fn)
  167. self.patched_fn_ids.add(id(patched_fn.origin_fn))
  168. patched_fn.set_func(wrap_fn(patched_fn.origin_fn))
  169. self.patched_fn.append(patched_fn)
  170. def patch_method(self, cls, name, wrap_fn):
  171. self.patch_function(cls, name, wrap_fn)
  172. def patch_cls(self, cls):
  173. import inspect
  174. if id(cls) not in self.visited_frames_ids:
  175. for k, v in cls.__dict__.items():
  176. if inspect.isfunction(v) and not k.startswith("_"):
  177. self.patch_function(cls, k, self.wrap_fn)
  178. self.visited_frames_ids.add(id(cls))
  179. def patch_module(self, module):
  180. import inspect
  181. if id(module.__dict__) not in self.visited_frames_ids:
  182. keys = (
  183. getattr(module, "__all__")
  184. if hasattr(module, "__all__")
  185. else module.__dict__.keys()
  186. )
  187. for k in keys:
  188. v = getattr(module, k)
  189. if inspect.isfunction(v) and not k.startswith("_"):
  190. self.patch_function(module.__dict__, k, self.wrap_fn)
  191. self.visited_frames_ids.add(id(module.__dict__))
  192. def auto_patch(self, frame_dict):
  193. if id(frame_dict) not in self.visited_frames_ids:
  194. for k, v in frame_dict.items():
  195. if id(v) in self.patched_fn_ids:
  196. self.patch_function(frame_dict, k, self.wrap_fn)
  197. self.visited_frames_ids.add(id(frame_dict))
  198. def __enter__(self):
  199. return self
  200. def __exit__(self, type, vlaue, trace):
  201. while self.patched_fn:
  202. pf = self.patched_fn.pop()
  203. pf.set_func(pf.origin_fn)
  204. self.visited_frames_ids.clear()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。