|
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import collections
-
- from .. import Tensor
- from .. import functional as F
- from ..core.tensor.array_method import ArrayMethodMixin
- from ..module import Module
- from ..module.qat import QATModule
- from .checker import TracedModuleChecker
-
# Tracer returned by ``active_module_tracer``; ``None`` until
# ``set_active_module_tracer`` installs one.
_active_module_tracer = None

# Names of the ``ArrayMethodMixin`` methods that ``Patcher`` wraps.
BUILTIN_ARRAY_METHOD = (
    # comparison and unary operators
    [
        "__{}__".format(stem)
        for stem in (
            "lt", "le", "gt", "ge", "eq", "ne",
            "neg", "pos", "abs", "invert", "round", "floor", "ceil",
        )
    ]
    # binary operators, then their reflected (r*) and in-place (i*) variants
    + [
        "__{}{}__".format(prefix, stem)
        for prefix in ("", "r", "i")
        for stem in (
            "add", "sub", "mul", "matmul", "truediv", "floordiv", "mod",
            "pow", "lshift", "rshift", "and", "or", "xor",
        )
    ]
    # ordinary tensor methods and indexing
    + [
        "transpose", "astype", "reshape", "_broadcast", "flatten",
        "sum", "prod", "min", "max", "mean",
        "__getitem__", "__setitem__",
    ]
)

# Tensor attribute/method names that, followed by ``BUILTIN_ARRAY_METHOD``,
# form the result of ``get_tensor_wrapable_method``.
BUILTIN_TENSOR_WRAP_METHOD = """
    T to size shape detach device dtype grad item ndim numpy qparams
    set_value reset_zero requires_grad _reset _isscalar _tuple_shape
""".split()
-
-
def get_tensor_wrapable_method():
    """Return a new list: ``BUILTIN_TENSOR_WRAP_METHOD`` then ``BUILTIN_ARRAY_METHOD``."""
    names = list(BUILTIN_TENSOR_WRAP_METHOD)
    names.extend(BUILTIN_ARRAY_METHOD)
    return names
-
-
def active_module_tracer():
    """Return the tracer last installed by ``set_active_module_tracer``."""
    current = _active_module_tracer
    return current


def set_active_module_tracer(tracer):
    """Install ``tracer`` as the module-wide active tracer (``None`` clears it)."""
    global _active_module_tracer
    _active_module_tracer = tracer
-
-
class module_tracer:
    """State of one tracing run.

    Keeps a stack of active scopes (each push/pop is mirrored into
    ``self.checker``) and, per scope, a list of cached objects whose
    ``_NodeMixin__node`` attribute is removed when that scope is popped.
    Construction also builds a ``Patcher`` from ``wrap_fn``.
    """

    # builtin types: Module subclasses registered via ``register_as_builtin``.
    # This is a class-level set, so it is shared by every tracer instance.
    _opaque_types = set()

    _active_scopes = None

    def __init__(self, wrap_fn):
        """Create empty scope/cache stacks, a checker and a ``Patcher(wrap_fn)``."""
        self._active_scopes = []
        self.checker = TracedModuleChecker(self)
        self.patcher = Patcher(wrap_fn)
        self._activate_constant_cache = []

    @classmethod
    def register_as_builtin(cls, mod):
        """Register the ``Module`` subclass ``mod`` as opaque and return it.

        Because ``mod`` is returned unchanged, this can be used as a class
        decorator.
        """
        assert issubclass(mod, Module)
        cls._opaque_types.add(mod)
        return mod

    @classmethod
    def is_builtin(cls, mod):
        """Whether the exact type of ``mod`` was registered (subclasses do not match)."""
        registered = cls._opaque_types
        return type(mod) in registered

    def push_scope(self, scope):
        """Enter ``scope`` and open a fresh, empty constant cache for it."""
        self._active_scopes.append(scope)
        self.checker.push_scope()
        self._activate_constant_cache.append([])

    def pop_scope(self):
        """Leave the innermost scope.

        Every object in that scope's constant cache loses its
        ``_NodeMixin__node`` attribute (when it has one).
        """
        self._active_scopes.pop()
        self.checker.pop_scope()
        for const in self._activate_constant_cache.pop():
            if hasattr(const, "_NodeMixin__node"):
                delattr(const, "_NodeMixin__node")

    def current_scope(self):
        """Innermost active scope, or ``None`` when the stack is empty."""
        return self._active_scopes[-1] if self._active_scopes else None

    def current_constant_cache(self):
        """Constant cache of the innermost scope, or ``None`` when there is none."""
        return (
            self._activate_constant_cache[-1]
            if self._activate_constant_cache
            else None
        )

    def top_scope(self):
        """Outermost (first pushed) scope, or ``None`` when the stack is empty."""
        return self._active_scopes[0] if self._active_scopes else None
-
-
class NotExist:
    """Sentinel for an entry that did not exist before patching.

    ``PatchedFn`` uses it as the original value when the attribute is absent.
    Restoring it removes the entry again instead of storing the sentinel.
    """


class PatchedFn:
    """One patched entry ``name`` of ``frame_dict`` and how to undo the patch.

    ``frame_dict`` is either a mapping (e.g. a module ``__dict__``), whose item
    ``name`` is replaced, or any other object (typically a class), whose
    attribute ``name`` is replaced.
    """

    frame_dict = None
    name = None
    origin_fn = None

    def __init__(self, frame_dict, name):
        """Capture the current value of ``name`` in ``frame_dict``.

        Args:
            frame_dict: mapping or object that holds the entry.
            name: key / attribute name of the entry.

        Raises:
            KeyError: if ``frame_dict`` is a mapping without ``name``.
        """
        self.frame_dict = frame_dict
        self.name = name
        if isinstance(frame_dict, collections.abc.Mapping):
            self.origin_fn = frame_dict[name]
            self._restore_value = self.origin_fn
        else:
            # ``getattr`` resolves inheritance and descriptors, which gives the
            # callable the wrapper must delegate to ...
            self.origin_fn = getattr(frame_dict, name, NotExist)
            # ... but putting that resolved value back would strip
            # ``staticmethod``/``classmethod`` wrappers and turn an inherited
            # attribute into an own one.  Remember the raw own entry (or its
            # absence) so restoring reproduces the original state exactly.
            own_attrs = getattr(frame_dict, "__dict__", None)
            if own_attrs is None:
                self._restore_value = self.origin_fn
            else:
                self._restore_value = own_attrs.get(name, NotExist)

    def set_func(self, func):
        """Install ``func`` as the entry.

        Passing ``origin_fn`` restores the entry as it was before patching.
        A ``NotExist`` value (given directly, or recorded as the original)
        removes the entry.
        """
        if isinstance(self.frame_dict, collections.abc.Mapping):
            if func is NotExist:
                if self.name in self.frame_dict:
                    del self.frame_dict[self.name]
            else:
                self.frame_dict[self.name] = func
            return
        if func is self.origin_fn:
            # restoring: use the raw entry that frame_dict itself defined
            func = self._restore_value
        if func is not NotExist:
            setattr(self.frame_dict, self.name, func)
        else:
            own_attrs = getattr(self.frame_dict, "__dict__", None)
            # only delete what frame_dict owns; inherited entries stay visible
            if own_attrs is None or self.name in own_attrs:
                delattr(self.frame_dict, self.name)
-
-
class Patcher:
    """Temporarily replace MegEngine functions and tensor methods with wrappers.

    All patching is done in ``__init__``: each target is replaced by
    ``wrap_fn(original)``.  ``__exit__`` (e.g. leaving a ``with`` block)
    restores every patched entry, newest first.
    """

    # Extra ``(frame_dict, name)`` pairs to patch.  It is empty here and
    # presumably extended elsewhere (TODO confirm).
    _builtin_functions = []
    # Modules whose public functions are patched via ``patch_module``.
    _builtin_modules = [
        F,
        F.distributed,
        F.elemwise,
        F.inplace,
        F.loss,
        F.math,
        F.metric,
        F.nn,
        F.quantized,
        F.tensor,
        F.utils,
        F.vision,
    ]
    # Not referenced inside this class.
    _builtin_methods = [
        Tensor,
        ArrayMethodMixin,
    ]

    def __init__(self, wrap_fn):
        """Apply every builtin patch.

        Args:
            wrap_fn: callable that receives an original function and returns
                the replacement to install in its place.
        """
        self.patched_fn_ids = set()  # ids of original functions already patched
        self.patched_fn = []  # PatchedFn records, in patch order
        self.visited_frames_ids = set()  # ids of namespaces already scanned
        self.wrap_fn = wrap_fn
        for module in self._builtin_modules:
            self.patch_module(module)
        # Some functions in F.nn are imported from other modules and are not
        # in its __all__.  Rescan its namespace (even though patch_module
        # already marked it visited) for aliases of patched functions.
        self.auto_patch(F.nn.__dict__, False)
        for meth in BUILTIN_ARRAY_METHOD:
            self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
        self.patch_method(Tensor, "detach", self.wrap_fn)
        self.patch_method(Tensor, "__new__", self.wrap_fn)
        self.patch_method(QATModule, "_apply_fakequant_with_observer", self.wrap_fn)
        for frame_dict, name in self._builtin_functions:
            if id(frame_dict) not in self.visited_frames_ids:
                self.patch_function(frame_dict, name, self.wrap_fn)

        for m in module_tracer._opaque_types:
            frame_globals = getattr(getattr(m, "forward", m), "__globals__", None)
            # Skip types without Python-level ``__globals__``.  Scanning a
            # temporary empty dict patches nothing, but it would record that
            # temporary's id in visited_frames_ids.  Once the temporary is
            # freed, an unrelated dict reusing the id would be skipped.
            if frame_globals is not None:
                self.auto_patch(frame_globals)

    def patch_function(self, frame_dict, fn, wrap_fn):
        """Replace entry ``fn`` (a name) of ``frame_dict`` with ``wrap_fn(original)``.

        The original's id is recorded so ``auto_patch`` can later patch other
        names bound to the same function.
        """
        patched_fn = PatchedFn(frame_dict, fn)
        self.patched_fn_ids.add(id(patched_fn.origin_fn))
        patched_fn.set_func(wrap_fn(patched_fn.origin_fn))
        self.patched_fn.append(patched_fn)

    def patch_method(self, cls, name, wrap_fn):
        """Replace attribute ``name`` of ``cls`` with ``wrap_fn(original)``."""
        self.patch_function(cls, name, wrap_fn)

    def patch_cls(self, cls):
        """Patch every public plain function defined directly on ``cls`` (once per class)."""
        import inspect

        if id(cls) not in self.visited_frames_ids:
            for k, v in cls.__dict__.items():
                if inspect.isfunction(v) and not k.startswith("_"):
                    self.patch_function(cls, k, self.wrap_fn)
            self.visited_frames_ids.add(id(cls))

    def patch_module(self, module):
        """Patch the public Python functions of ``module`` (once per module).

        Only names in ``module.__all__`` are considered when it is defined;
        otherwise every name in the module namespace is considered.
        """
        import inspect

        if id(module.__dict__) not in self.visited_frames_ids:
            keys = getattr(module, "__all__", module.__dict__.keys())
            for k in keys:
                v = getattr(module, k)
                if inspect.isfunction(v) and not k.startswith("_"):
                    self.patch_function(module.__dict__, k, self.wrap_fn)
            self.visited_frames_ids.add(id(module.__dict__))

    def auto_patch(self, frame_dict, check_frame_id=True):
        """Patch entries of ``frame_dict`` that alias an already-patched original.

        Args:
            frame_dict: namespace mapping to scan, e.g. a function's
                ``__globals__``.  Its id is remembered, so it should be a
                long-lived dict rather than a temporary.
            check_frame_id: when False, rescan even if ``frame_dict`` was
                visited before.
        """
        if id(frame_dict) not in self.visited_frames_ids or not check_frame_id:
            for k, v in frame_dict.items():
                if id(v) in self.patched_fn_ids:
                    self.patch_function(frame_dict, k, self.wrap_fn)
            self.visited_frames_ids.add(id(frame_dict))

    def __enter__(self):
        return self

    def __exit__(self, type, vlaue, trace):
        """Undo all patches (newest first) and forget the scan bookkeeping."""
        while self.patched_fn:
            pf = self.patched_fn.pop()
            pf.set_func(pf.origin_fn)
        self.visited_frames_ids.clear()
        # The recorded originals are no longer patched.  Dropping their ids
        # keeps a stray auto_patch after exit from re-installing wrappers.
        self.patched_fn_ids.clear()
|