You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module_tracer.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. from ... import Tensor
  11. from ... import functional as F
  12. from ...core.tensor.array_method import ArrayMethodMixin
  13. from ...module import Module
  14. _active_module_tracer = None
  15. def active_module_tracer():
  16. return _active_module_tracer
  17. def set_active_module_tracer(tracer):
  18. global _active_module_tracer
  19. _active_module_tracer = tracer
  20. class module_tracer:
  21. # builtin types
  22. _opaque_types = set()
  23. _active_scopes = None
  24. def __init__(self, wrap_fn):
  25. self._active_scopes = []
  26. self.patcher = Patcher(wrap_fn)
  27. @classmethod
  28. def register_as_builtin(cls, mod):
  29. assert issubclass(mod, Module)
  30. cls._opaque_types.add(mod)
  31. return mod
  32. @classmethod
  33. def is_builtin(cls, mod):
  34. return type(mod) in cls._opaque_types
  35. def push_scope(self, scope):
  36. self._active_scopes.append(scope)
  37. def pop_scope(self):
  38. self._active_scopes.pop()
  39. def current_scope(self):
  40. if self._active_scopes:
  41. return self._active_scopes[-1]
  42. return None
  43. class PatchedFn:
  44. frame_dict = None
  45. name = None
  46. origin_fn = None
  47. def __init__(self, frame_dict, name):
  48. self.frame_dict = frame_dict
  49. self.name = name
  50. self.origin_fn = (
  51. self.frame_dict[name]
  52. if isinstance(frame_dict, collections.abc.Mapping)
  53. else getattr(frame_dict, name)
  54. )
  55. def set_func(self, func):
  56. if isinstance(self.frame_dict, collections.abc.Mapping):
  57. self.frame_dict[self.name] = func
  58. else:
  59. setattr(self.frame_dict, self.name, func)
  60. class Patcher:
  61. patched_fn_ids = set()
  62. _builtin_functions = []
  63. _builtin_modules = [
  64. F,
  65. F.distributed,
  66. F.elemwise,
  67. F.inplace,
  68. F.loss,
  69. F.math,
  70. F.metric,
  71. F.nn,
  72. F.quantized,
  73. F.tensor,
  74. F.utils,
  75. F.vision,
  76. ]
  77. _builtin_methods = [
  78. Tensor,
  79. ArrayMethodMixin,
  80. ]
  81. def __init__(self, wrap_fn):
  82. self.patched_fn = []
  83. self.visited_frames_ids = set()
  84. self.wrap_fn = wrap_fn
  85. for module in self._builtin_modules:
  86. self.patch_module(module)
  87. for cls in self._builtin_methods:
  88. self.patch_cls(cls)
  89. for i, j in self._builtin_functions:
  90. if id(i) not in self.visited_frames_ids:
  91. self.patch_function(i, j, self.wrap_fn)
  92. def patch_function(self, frame_dict, fn, wrap_fn):
  93. patched_fn = PatchedFn(frame_dict, fn)
  94. self.patched_fn_ids.add(id(patched_fn.origin_fn))
  95. patched_fn.set_func(wrap_fn(patched_fn.origin_fn))
  96. self.patched_fn.append(patched_fn)
  97. def patch_method(self, cls, name, wrap_fn):
  98. self.patch_function(cls, name, wrap_fn)
  99. def patch_cls(self, cls):
  100. import inspect
  101. if id(cls) not in self.visited_frames_ids:
  102. for k, v in cls.__dict__.items():
  103. if inspect.isfunction(v) and not k.startswith("_"):
  104. self.patch_function(cls, k, self.wrap_fn)
  105. self.visited_frames_ids.add(id(cls))
  106. def patch_module(self, module):
  107. import inspect
  108. if id(module.__dict__) not in self.visited_frames_ids:
  109. for k, v in module.__dict__.items():
  110. if inspect.isfunction(v) and not k.startswith("_"):
  111. self.patch_function(module.__dict__, k, self.wrap_fn)
  112. self.visited_frames_ids.add(id(module.__dict__))
  113. def auto_patch(self, frame_dict):
  114. if id(frame_dict) not in self.visited_frames_ids:
  115. for k, v in frame_dict.items():
  116. if id(v) in self.patched_fn_ids:
  117. self.patch_function(frame_dict, k, self.wrap_fn)
  118. self.visited_frames_ids.add(id(frame_dict))
  119. def __enter__(self):
  120. return self
  121. def __exit__(self, type, vlaue, trace):
  122. while self.patched_fn:
  123. pf = self.patched_fn.pop()
  124. pf.set_func(pf.origin_fn)
  125. self.visited_frames_ids.clear()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。