
module_tracer.py 7.1 kB

import collections

from .. import Tensor
from .. import functional as F
from ..core.tensor.array_method import ArrayMethodMixin
from ..module import Module
from ..module.qat import QATModule
from .checker import TracedModuleChecker

_active_module_tracer = None

BUILTIN_ARRAY_METHOD = [
    "__lt__",
    "__le__",
    "__gt__",
    "__ge__",
    "__eq__",
    "__ne__",
    "__neg__",
    "__pos__",
    "__abs__",
    "__invert__",
    "__round__",
    "__floor__",
    "__ceil__",
    "__add__",
    "__sub__",
    "__mul__",
    "__matmul__",
    "__truediv__",
    "__floordiv__",
    "__mod__",
    "__pow__",
    "__lshift__",
    "__rshift__",
    "__and__",
    "__or__",
    "__xor__",
    "__radd__",
    "__rsub__",
    "__rmul__",
    "__rmatmul__",
    "__rtruediv__",
    "__rfloordiv__",
    "__rmod__",
    "__rpow__",
    "__rlshift__",
    "__rrshift__",
    "__rand__",
    "__ror__",
    "__rxor__",
    "__iadd__",
    "__isub__",
    "__imul__",
    "__imatmul__",
    "__itruediv__",
    "__ifloordiv__",
    "__imod__",
    "__ipow__",
    "__ilshift__",
    "__irshift__",
    "__iand__",
    "__ior__",
    "__ixor__",
    "transpose",
    "astype",
    "reshape",
    "_broadcast",
    "flatten",
    "sum",
    "prod",
    "min",
    "max",
    "mean",
    "__getitem__",
    "__setitem__",
]

BUILTIN_TENSOR_WRAP_METHOD = [
    "T",
    "to",
    "size",
    "shape",
    "detach",
    "device",
    "dtype",
    "grad",
    "item",
    "ndim",
    "numpy",
    "qparams",
    "set_value",
    "reset_zero",
    "requires_grad",
    "_reset",
    "_isscalar",
    "_tuple_shape",
]


def get_tensor_wrapable_method():
    return BUILTIN_TENSOR_WRAP_METHOD + BUILTIN_ARRAY_METHOD


def active_module_tracer():
    return _active_module_tracer


def set_active_module_tracer(tracer):
    global _active_module_tracer
    _active_module_tracer = tracer
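

# Usage sketch (illustrative, not part of the original file): the active
# tracer is a process-global handle that a tracing pass installs before
# running a module's forward and clears afterwards. `identity_wrap` is a
# hypothetical no-op wrapper used only for this example.
#
#     identity_wrap = lambda fn: fn
#     set_active_module_tracer(module_tracer(identity_wrap))
#     assert active_module_tracer() is not None
#     ...  # run the forward being traced here
#     set_active_module_tracer(None)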


class module_tracer:
    # builtin types
    _opaque_types = set()

    _active_scopes = None

    def __init__(self, wrap_fn):
        self._active_scopes = []
        self.checker = TracedModuleChecker(self)
        self.patcher = Patcher(wrap_fn)
        self._activate_constant_cache = []

    @classmethod
    def register_as_builtin(cls, mod):
        assert issubclass(mod, Module)
        cls._opaque_types.add(mod)
        return mod

    @classmethod
    def is_builtin(cls, mod):
        return type(mod) in cls._opaque_types

    def push_scope(self, scope):
        self._active_scopes.append(scope)
        self.checker.push_scope()
        self._activate_constant_cache.append([])

    def pop_scope(self):
        self._active_scopes.pop()
        self.checker.pop_scope()
        cache = self._activate_constant_cache.pop()
        for obj in cache:
            if hasattr(obj, "_NodeMixin__node"):
                delattr(obj, "_NodeMixin__node")

    def current_scope(self):
        if self._active_scopes:
            return self._active_scopes[-1]
        return None

    def current_constant_cache(self):
        if self._activate_constant_cache:
            return self._activate_constant_cache[-1]
        return None

    def top_scope(self):
        if self._active_scopes:
            return self._active_scopes[0]
        return None
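

# Usage sketch (illustrative, not part of the original file): scopes are
# pushed when tracing enters a (sub-)module and popped when it leaves, so
# current_scope() always refers to the innermost graph being built. The
# plain strings below stand in for the real scope objects, and the
# identity wrap_fn is hypothetical.
#
#     tracer = module_tracer(wrap_fn=lambda fn: fn)
#     tracer.push_scope("root_graph")     # entering the top-level forward
#     tracer.push_scope("child_graph")    # entering a sub-module's forward
#     assert tracer.current_scope() == "child_graph"
#     assert tracer.top_scope() == "root_graph"
#     tracer.pop_scope()                  # leaving the sub-module
#     tracer.pop_scope()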


class NotExist:
    # sentinel marking "attribute was absent", so absence can be
    # distinguished from an attribute whose value is None
    pass


class PatchedFn:
    frame_dict = None
    name = None
    origin_fn = None

    def __init__(self, frame_dict, name):
        self.frame_dict = frame_dict
        self.name = name
        self.origin_fn = (
            self.frame_dict[name]
            if isinstance(frame_dict, collections.abc.Mapping)
            else getattr(frame_dict, name, NotExist)
        )

    def set_func(self, func):
        if isinstance(self.frame_dict, collections.abc.Mapping):
            self.frame_dict[self.name] = func
        else:
            if func is not NotExist:
                setattr(self.frame_dict, self.name, func)
            else:
                delattr(self.frame_dict, self.name)
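

# Usage sketch (illustrative, not part of the original file): PatchedFn
# records the original callable found in a mapping (e.g. a module's
# __dict__) or as a class/object attribute, installs a replacement, and
# can later restore what it saved.
#
#     import math
#     p = PatchedFn(math.__dict__, "sqrt")        # remembers math.sqrt
#     p.set_func(lambda x: p.origin_fn(x) * 1.0)  # install a wrapper
#     p.set_func(p.origin_fn)                     # put the original back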


class Patcher:
    _builtin_functions = []
    _builtin_modules = [
        F,
        F.distributed,
        F.elemwise,
        F.inplace,
        F.loss,
        F.math,
        F.metric,
        F.nn,
        F.quantized,
        F.tensor,
        F.utils,
        F.vision,
    ]
    _builtin_methods = [
        Tensor,
        ArrayMethodMixin,
    ]

    def __init__(self, wrap_fn):
        self.patched_fn_ids = set()
        self.patched_fn = []
        self.visited_frames_ids = set()
        self.wrap_fn = wrap_fn
        for module in self._builtin_modules:
            self.patch_module(module)
        # some functions in F.nn are imported from other modules and do not
        # appear in __all__, so patch its __dict__ directly
        self.auto_patch(F.nn.__dict__, False)
        for meth in BUILTIN_ARRAY_METHOD:
            self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
        self.patch_method(Tensor, "detach", self.wrap_fn)
        self.patch_method(Tensor, "__new__", self.wrap_fn)
        self.patch_method(QATModule, "_apply_fakequant_with_observer", self.wrap_fn)
        for i, j in self._builtin_functions:
            if id(i) not in self.visited_frames_ids:
                self.patch_function(i, j, self.wrap_fn)
        for m in module_tracer._opaque_types:
            self.auto_patch(getattr(getattr(m, "forward", m), "__globals__", {}))

    def patch_function(self, frame_dict, fn, wrap_fn):
        patched_fn = PatchedFn(frame_dict, fn)
        self.patched_fn_ids.add(id(patched_fn.origin_fn))
        patched_fn.set_func(wrap_fn(patched_fn.origin_fn))
        self.patched_fn.append(patched_fn)

    def patch_method(self, cls, name, wrap_fn):
        self.patch_function(cls, name, wrap_fn)

    def patch_cls(self, cls):
        import inspect

        if id(cls) not in self.visited_frames_ids:
            for k, v in cls.__dict__.items():
                if inspect.isfunction(v) and not k.startswith("_"):
                    self.patch_function(cls, k, self.wrap_fn)
            self.visited_frames_ids.add(id(cls))

    def patch_module(self, module):
        import inspect

        if id(module.__dict__) not in self.visited_frames_ids:
            keys = (
                getattr(module, "__all__")
                if hasattr(module, "__all__")
                else module.__dict__.keys()
            )
            for k in keys:
                v = getattr(module, k)
                if inspect.isfunction(v) and not k.startswith("_"):
                    self.patch_function(module.__dict__, k, self.wrap_fn)
            self.visited_frames_ids.add(id(module.__dict__))

    def auto_patch(self, frame_dict, check_frame_id=True):
        if id(frame_dict) not in self.visited_frames_ids or not check_frame_id:
            for k, v in frame_dict.items():
                if id(v) in self.patched_fn_ids:
                    self.patch_function(frame_dict, k, self.wrap_fn)
            self.visited_frames_ids.add(id(frame_dict))

    def __enter__(self):
        return self

    def __exit__(self, type, value, trace):
        while self.patched_fn:
            pf = self.patched_fn.pop()
            pf.set_func(pf.origin_fn)
        self.visited_frames_ids.clear()
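

# Usage sketch (illustrative, not part of the original file): Patcher is a
# context manager; on exit it restores every patched function in reverse
# order. `logging_wrap` is a hypothetical wrapper that logs each call
# before delegating to the original.
#
#     def logging_wrap(fn):
#         def wrapped(*args, **kwargs):
#             print("calling", getattr(fn, "__name__", fn))
#             return fn(*args, **kwargs)
#         return wrapped
#
#     with Patcher(logging_wrap):
#         ...  # megengine functional calls made here go through `wrapped`
#     # originals are restored automatically on exit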