You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

serialization.py 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from importlib import import_module
  9. from typing import Dict, Tuple
  10. from ..core._imperative_rt import OpDef
  11. from ..core.ops import builtin
  12. from ..tensor import Tensor
  13. from ..version import __version__
  14. from .utils import _convert_kwargs_to_args
  15. OPDEF_LOADER = {}
  16. FUNCTIONAL_LOADER = {}
  17. TENSORMETHOD_LOADER = {}
  18. MODULE_LOADER = {}
  19. class _ModuleState:
  20. obj = None
  21. def __init__(self, module: Tuple, state: Dict, version: str):
  22. self.module = module
  23. self.state = state
  24. self.version = version
  25. @classmethod
  26. def get_module_state(cls, module):
  27. typem = (type(module).__module__, type(module).__qualname__)
  28. state = module.__dict__.copy()
  29. state.pop("_m_dump_modulestate", None)
  30. if hasattr(module, "_m_dump_modulestate"):
  31. assert isinstance(module._m_dump_modulestate, cls)
  32. module._m_dump_modulestate.__init__(typem, state, __version__)
  33. else:
  34. module.__dict__["_m_dump_modulestate"] = _ModuleState(
  35. typem, state, __version__
  36. )
  37. return module._m_dump_modulestate
  38. def __getstate__(self):
  39. return {"module": self.module, "state": self.state, "version": self.version}
  40. def to_module(self):
  41. if self.obj is None:
  42. typem = getattr(import_module(self.module[0]), self.module[1])
  43. m_obj = typem.__new__(typem)
  44. m_obj.__dict__.update(self.state)
  45. self.obj = m_obj
  46. return self.obj
  47. def register_opdef_loader(*opdefs):
  48. def callback(loader):
  49. for opdef in opdefs:
  50. assert opdef not in OPDEF_LOADER
  51. OPDEF_LOADER[opdef] = loader
  52. return loader
  53. return callback
  54. def register_functional_loader(*funcs):
  55. def callback(loader):
  56. for func in funcs:
  57. assert func not in FUNCTIONAL_LOADER
  58. FUNCTIONAL_LOADER[func] = loader
  59. return loader
  60. return callback
  61. def register_module_loader(*module_types):
  62. def callback(loader):
  63. for module_type in module_types:
  64. assert module_type not in MODULE_LOADER
  65. MODULE_LOADER[module_type] = loader
  66. return loader
  67. return callback
  68. def register_tensor_method_loader(*methods):
  69. def callback(loader):
  70. for method in methods:
  71. assert method not in TENSORMETHOD_LOADER
  72. TENSORMETHOD_LOADER[method] = loader
  73. return loader
  74. return callback
  75. def _replace_args_kwargs(expr, new_args, new_kwargs):
  76. if len(new_args) != len(expr.args) or set(new_kwargs.keys()) != set(
  77. expr.kwargs.keys()
  78. ):
  79. expr.set_args_kwargs(*new_args, **new_kwargs)
  80. def load_functional(expr):
  81. func = (
  82. (expr.func.__module__, expr.func.__qualname__)
  83. if callable(expr.func)
  84. else expr.func
  85. )
  86. assert isinstance(func, tuple)
  87. if func in FUNCTIONAL_LOADER:
  88. loader = FUNCTIONAL_LOADER[func]
  89. loader(expr)
  90. mname, fname = func
  91. f = import_module(mname)
  92. for i in fname.split("."):
  93. f = getattr(f, i)
  94. expr.func = f
  95. assert callable(expr.func)
  96. if not hasattr(expr, "version") or expr.version != __version__:
  97. args, kwargs = _convert_kwargs_to_args(expr.func, expr.args, expr.kwargs)
  98. _replace_args_kwargs(expr, args, kwargs)
  99. def load_call_module_expr(expr):
  100. m_type = expr.inputs[0].module_type
  101. if isinstance(m_type, type):
  102. m_type = (m_type.__module__, m_type.__qualname__)
  103. if m_type in MODULE_LOADER:
  104. MODULE_LOADER[m_type](expr)
  105. if isinstance(expr.inputs[0].module_type, tuple):
  106. mname, classname = expr.inputs[0].module_type
  107. expr.inputs[0].module_type = getattr(import_module(mname), classname)
  108. if not hasattr(expr, "version") or expr.version != __version__:
  109. fwd_func = getattr(expr.inputs[0].module_type, "forward")
  110. args, kwargs = _convert_kwargs_to_args(fwd_func, expr.args, expr.kwargs)
  111. _replace_args_kwargs(expr, args, kwargs)
  112. def load_call_tensor_method_expr(expr):
  113. if expr.method in TENSORMETHOD_LOADER:
  114. loader = TENSORMETHOD_LOADER[expr.method]
  115. loader(expr)
  116. if not hasattr(expr, "version") or expr.version != __version__:
  117. tmethod = (
  118. getattr(expr.args[0], expr.method)
  119. if isinstance(expr.args[0], type)
  120. else getattr(Tensor, expr.method)
  121. )
  122. args, kwargs = _convert_kwargs_to_args(tmethod, expr.args, expr.kwargs)
  123. _replace_args_kwargs(expr, args, kwargs)
  124. def load_apply_expr(expr):
  125. opdef_type = type(expr.opdef)
  126. if opdef_type in OPDEF_LOADER:
  127. OPDEF_LOADER[opdef_type](expr)
  128. opdef_state = expr.opdef_state
  129. opdef_obj = opdef_state.pop("opdef_type")()
  130. opdef_obj.__setstate__(opdef_state)
  131. expr.opdef = opdef_obj

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。