|
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- from importlib import import_module
- from typing import Dict, Tuple
-
- from ..core._imperative_rt import OpDef
- from ..core.ops import builtin
- from ..tensor import Tensor
- from ..version import __version__
- from .utils import _convert_kwargs_to_args
-
- OPDEF_LOADER = {}
- FUNCTIONAL_LOADER = {}
- TENSORMETHOD_LOADER = {}
- MODULE_LOADER = {}
-
-
- class _ModuleState:
- obj = None
-
- def __init__(self, module: Tuple, state: Dict, version: str):
- self.module = module
- self.state = state
- self.version = version
-
- @classmethod
- def get_module_state(cls, module):
- typem = (type(module).__module__, type(module).__qualname__)
- state = module.__dict__.copy()
- state.pop("_m_dump_modulestate", None)
- if hasattr(module, "_m_dump_modulestate"):
- assert isinstance(module._m_dump_modulestate, cls)
- module._m_dump_modulestate.__init__(typem, state, __version__)
- else:
- module.__dict__["_m_dump_modulestate"] = _ModuleState(
- typem, state, __version__
- )
-
- return module._m_dump_modulestate
-
- def __getstate__(self):
- return {"module": self.module, "state": self.state, "version": self.version}
-
- def to_module(self):
- if self.obj is None:
- typem = getattr(import_module(self.module[0]), self.module[1])
- m_obj = typem.__new__(typem)
- m_obj.__setstate__(self.state)
- self.obj = m_obj
- return self.obj
-
-
- def register_opdef_loader(*opdefs):
- def callback(loader):
- for opdef in opdefs:
- assert opdef not in OPDEF_LOADER
- OPDEF_LOADER[opdef] = loader
- return loader
-
- return callback
-
-
- def register_functional_loader(*funcs):
- def callback(loader):
- for func in funcs:
- assert func not in FUNCTIONAL_LOADER
- FUNCTIONAL_LOADER[func] = loader
- return loader
-
- return callback
-
-
- def register_module_loader(*module_types):
- def callback(loader):
- for module_type in module_types:
- assert module_type not in MODULE_LOADER
- MODULE_LOADER[module_type] = loader
- return loader
-
- return callback
-
-
- def register_tensor_method_loader(*methods):
- def callback(loader):
- for method in methods:
- assert method not in TENSORMETHOD_LOADER
- TENSORMETHOD_LOADER[method] = loader
- return loader
-
- return callback
-
-
- def _replace_args_kwargs(expr, new_args, new_kwargs):
- if len(new_args) != len(expr.args) or set(new_kwargs.keys()) != set(
- expr.kwargs.keys()
- ):
- expr.set_args_kwargs(*new_args, **new_kwargs)
-
-
- def load_functional(expr):
- func = (
- (expr.func.__module__, expr.func.__qualname__)
- if callable(expr.func)
- else expr.func
- )
- assert isinstance(func, tuple)
- if func in FUNCTIONAL_LOADER:
- loader = FUNCTIONAL_LOADER[func]
- loader(expr)
- mname, fname = func
- f = import_module(mname)
- for i in fname.split("."):
- f = getattr(f, i)
- expr.func = f
- assert callable(expr.func)
- if not hasattr(expr, "version") or expr.version != __version__:
- args, kwargs = _convert_kwargs_to_args(expr.func, expr.args, expr.kwargs)
- _replace_args_kwargs(expr, args, kwargs)
-
-
- def load_call_module_expr(expr):
- m_type = expr.inputs[0].module_type
- if isinstance(m_type, type):
- m_type = (m_type.__module__, m_type.__qualname__)
- if m_type in MODULE_LOADER:
- MODULE_LOADER[m_type](expr)
- if isinstance(expr.inputs[0].module_type, tuple):
- mname, classname = expr.inputs[0].module_type
- expr.inputs[0].module_type = getattr(import_module(mname), classname)
- if not hasattr(expr, "version") or expr.version != __version__:
- fwd_func = getattr(expr.inputs[0].module_type, "forward")
- args, kwargs = _convert_kwargs_to_args(fwd_func, expr.args, expr.kwargs)
- _replace_args_kwargs(expr, args, kwargs)
-
-
- def load_call_tensor_method_expr(expr):
- if expr.method in TENSORMETHOD_LOADER:
- loader = TENSORMETHOD_LOADER[expr.method]
- loader(expr)
- if not hasattr(expr, "version") or expr.version != __version__:
- tmethod = (
- getattr(expr.args[0], expr.method)
- if isinstance(expr.args[0], type)
- else getattr(Tensor, expr.method)
- )
- args, kwargs = _convert_kwargs_to_args(tmethod, expr.args, expr.kwargs)
- _replace_args_kwargs(expr, args, kwargs)
-
-
- def load_apply_expr(expr):
- opdef_type = type(expr.opdef)
- if opdef_type in OPDEF_LOADER:
- OPDEF_LOADER[opdef_type](expr)
- opdef_state = expr.opdef_state
- opdef_obj = opdef_state.pop("opdef_type")()
- opdef_obj.__setstate__(opdef_state)
- expr.opdef = opdef_obj
|