You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

serialization.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. from importlib import import_module
  2. from typing import Dict, Tuple
  3. from ..core._imperative_rt import OpDef
  4. from ..core.ops import builtin
  5. from ..tensor import Tensor
  6. from ..version import __version__
  7. from .utils import _convert_kwargs_to_args
  8. OPDEF_LOADER = {}
  9. FUNCTIONAL_LOADER = {}
  10. TENSORMETHOD_LOADER = {}
  11. MODULE_LOADER = {}
  12. class _ModuleState:
  13. obj = None
  14. def __init__(self, module: Tuple, state: Dict, version: str):
  15. self.module = module
  16. self.state = state
  17. self.version = version
  18. @classmethod
  19. def get_module_state(cls, module):
  20. typem = (type(module).__module__, type(module).__qualname__)
  21. state = module.__dict__.copy()
  22. state.pop("_m_dump_modulestate", None)
  23. if hasattr(module, "_m_dump_modulestate"):
  24. assert isinstance(module._m_dump_modulestate, cls)
  25. module._m_dump_modulestate.__init__(typem, state, __version__)
  26. else:
  27. module.__dict__["_m_dump_modulestate"] = _ModuleState(
  28. typem, state, __version__
  29. )
  30. return module._m_dump_modulestate
  31. def __getstate__(self):
  32. return {"module": self.module, "state": self.state, "version": self.version}
  33. def to_module(self):
  34. if self.obj is None:
  35. typem = getattr(import_module(self.module[0]), self.module[1])
  36. m_obj = typem.__new__(typem)
  37. m_obj.__setstate__(self.state)
  38. self.obj = m_obj
  39. return self.obj
  40. def register_opdef_loader(*opdefs):
  41. def callback(loader):
  42. for opdef in opdefs:
  43. assert opdef not in OPDEF_LOADER
  44. OPDEF_LOADER[opdef] = loader
  45. return loader
  46. return callback
  47. def register_functional_loader(*funcs):
  48. def callback(loader):
  49. for func in funcs:
  50. assert func not in FUNCTIONAL_LOADER
  51. FUNCTIONAL_LOADER[func] = loader
  52. return loader
  53. return callback
  54. def register_module_loader(*module_types):
  55. def callback(loader):
  56. for module_type in module_types:
  57. assert module_type not in MODULE_LOADER
  58. MODULE_LOADER[module_type] = loader
  59. return loader
  60. return callback
  61. def register_tensor_method_loader(*methods):
  62. def callback(loader):
  63. for method in methods:
  64. assert method not in TENSORMETHOD_LOADER
  65. TENSORMETHOD_LOADER[method] = loader
  66. return loader
  67. return callback
  68. def _replace_args_kwargs(expr, new_args, new_kwargs):
  69. if len(new_args) != len(expr.args) or set(new_kwargs.keys()) != set(
  70. expr.kwargs.keys()
  71. ):
  72. expr.set_args_kwargs(*new_args, **new_kwargs)
  73. def load_functional(expr):
  74. func = (
  75. (expr.func.__module__, expr.func.__qualname__)
  76. if callable(expr.func)
  77. else expr.func
  78. )
  79. assert isinstance(func, tuple)
  80. if func in FUNCTIONAL_LOADER:
  81. loader = FUNCTIONAL_LOADER[func]
  82. loader(expr)
  83. mname, fname = func
  84. f = import_module(mname)
  85. for i in fname.split("."):
  86. f = getattr(f, i)
  87. expr.func = f
  88. assert callable(expr.func)
  89. if not hasattr(expr, "version") or expr.version != __version__:
  90. args, kwargs = _convert_kwargs_to_args(expr.func, expr.args, expr.kwargs)
  91. _replace_args_kwargs(expr, args, kwargs)
  92. def load_call_module_expr(expr):
  93. m_type = expr.inputs[0].module_type
  94. if isinstance(m_type, type):
  95. m_type = (m_type.__module__, m_type.__qualname__)
  96. if m_type in MODULE_LOADER:
  97. MODULE_LOADER[m_type](expr)
  98. if isinstance(expr.inputs[0].module_type, tuple):
  99. mname, classname = expr.inputs[0].module_type
  100. expr.inputs[0].module_type = getattr(import_module(mname), classname)
  101. if not hasattr(expr, "version") or expr.version != __version__:
  102. fwd_func = getattr(expr.inputs[0].module_type, "forward")
  103. args, kwargs = _convert_kwargs_to_args(fwd_func, expr.args, expr.kwargs)
  104. _replace_args_kwargs(expr, args, kwargs)
  105. def load_call_tensor_method_expr(expr):
  106. if expr.method in TENSORMETHOD_LOADER:
  107. loader = TENSORMETHOD_LOADER[expr.method]
  108. loader(expr)
  109. if not hasattr(expr, "version") or expr.version != __version__:
  110. tmethod = (
  111. getattr(expr.args[0], expr.method)
  112. if isinstance(expr.args[0], type)
  113. else getattr(Tensor, expr.method)
  114. )
  115. args, kwargs = _convert_kwargs_to_args(tmethod, expr.args, expr.kwargs)
  116. _replace_args_kwargs(expr, args, kwargs)
  117. def load_apply_expr(expr):
  118. opdef_type = type(expr.opdef)
  119. if opdef_type in OPDEF_LOADER:
  120. OPDEF_LOADER[opdef_type](expr)
  121. opdef_state = expr.opdef_state
  122. opdef_obj = opdef_state.pop("opdef_type")()
  123. opdef_obj.__setstate__(opdef_state)
  124. expr.opdef = opdef_obj