You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

serialization.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # -*- coding: utf-8 -*-
  2. import pickle
  3. from .device import _valid_device, get_default_device
  4. from .tensor import Tensor
  5. from .utils.max_recursion_limit import max_recursion_limit
  def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.DEFAULT_PROTOCOL):
      r"""Save an object to disk file.

      The saved object must be a :class:`~.module.Module`,
      :attr:`.Module.state_dict` or :attr:`.Optimizer.state_dict`.
      See :ref:`serialization-guide` for more details.

      Args:
          obj: object to be saved.
          f: a file name (``str`` or :class:`os.PathLike`, e.g. :class:`pathlib.Path`)
              or a binary file object (anything with a ``write`` method) to which
              ``obj`` is saved. A file name is opened in ``"wb"`` mode.
          pickle_module: the module to use for pickling.
          pickle_protocol: the protocol to use for pickling.

      .. admonition:: If you are using MegEngine with different Python versions
         :class: warning

         Different Python versions may use different DEFAULT/HIGHEST pickle protocols.
         If you want to :func:`~megengine.load` the saved object in another Python version,
         please make sure you have used the same protocol.

      .. admonition:: You can select to use ``pickle`` module directly

         This interface is a wrapper of :func:`pickle.dump`. If you want to use ``pickle``,
         see :py:mod:`pickle` for more information about how to set ``pickle_protocol``:

         * :py:data:`pickle.HIGHEST_PROTOCOL` - the highest protocol version available.
         * :py:data:`pickle.DEFAULT_PROTOCOL` - the default protocol version used for pickling.

      Examples:
          If you want to save an object with a higher protocol version than the current
          Python version supports, you can install another pickle module instead of the
          built-in one. Take ``pickle5`` as an example:

          >>> import pickle5 as pickle  # doctest: +SKIP

          It's a backport of the pickle 5 protocol (PEP 574) and other pickle changes,
          so you can use it to save objects in pickle 5 protocol and load them in Python 3.8+.

          Or you can use ``pickle5`` in this way (only used with this interface):

          .. code-block:: python

              import pickle5
              import megengine

              megengine.save(obj, f, pickle_module=pickle5, pickle_protocol=5)
      """
      import os

      # A path (str or os.PathLike such as pathlib.Path) is opened here in
      # binary mode; the call then recurses with the resulting file object.
      if isinstance(f, (str, os.PathLike)):
          with open(f, "wb") as fout:
              save(
                  obj, fout, pickle_module=pickle_module, pickle_protocol=pickle_protocol
              )
          return
      # Pickling happens inside max_recursion_limit(); presumably this raises
      # the interpreter recursion limit for deeply nested objects — confirm in
      # .utils.max_recursion_limit.
      with max_recursion_limit():
          assert hasattr(f, "write"), "{} does not support write".format(f)
          pickle_module.dump(obj, f, pickle_protocol)
  class _dmap:
      """Context manager that installs ``map_location`` as ``Tensor.dmap_callback``.

      While the context is active, ``Tensor.dmap_callback`` holds
      ``staticmethod(map_location)``. On exit, the attribute value that was in
      place before entry is restored. Unconditionally resetting it to ``None``
      would wipe out the callback of an enclosing ``_dmap`` when uses are nested.

      Args:
          map_location: the callable to install for the duration of the context.
      """

      def __init__(self, map_location):
          self.map_location = map_location
          # Raw class attribute found on entry; reinstated by __exit__.
          self._prev_callback = None

      def __enter__(self):
          import inspect

          # getattr_static returns the stored attribute (e.g. a staticmethod
          # wrapper itself) without running the descriptor protocol, so it can be
          # put back verbatim on exit.
          self._prev_callback = inspect.getattr_static(Tensor, "dmap_callback", None)
          # staticmethod prevents the callback from being bound as a method if it
          # is read through a Tensor instance.
          Tensor.dmap_callback = staticmethod(self.map_location)
          return self

      def __exit__(self, exc_type, exc_value, traceback):
          # Restore rather than clear, so an outer mapping stays in effect after a
          # nested load() finishes. Returns None, so exceptions still propagate.
          Tensor.dmap_callback = self._prev_callback
  56. def _get_callable_map_location(map_location):
  57. if map_location is None:
  58. def callable_map_location(state):
  59. return state
  60. elif isinstance(map_location, str):
  61. def callable_map_location(state):
  62. return map_location
  63. elif isinstance(map_location, dict):
  64. for key, value in map_location.items():
  65. # dict key and values can only be "xpux", "cpux", "gpu0", etc.
  66. assert _valid_device(key), "Invalid locator_map key value {}".format(key)
  67. assert _valid_device(value), "Invalid locator_map key value {}".format(
  68. value
  69. )
  70. def callable_map_location(state):
  71. if state[:4] in map_location.keys():
  72. state = map_location[state[:4]]
  73. return state
  74. else:
  75. assert callable(map_location), "map_location should be str, dict or function"
  76. callable_map_location = map_location
  77. return callable_map_location
  def load(f, map_location=None, pickle_module=pickle):
      r"""Load an object saved with :func:`~.megengine.save` from a file.

      Args:
          f: a file name (``str`` or :class:`os.PathLike`, e.g. :class:`pathlib.Path`)
              or a binary file object from which to load. A file name is opened
              in ``"rb"`` mode.
          map_location: defines device mapping. See examples for usage.
          pickle_module: the module to use for unpickling.

      Returns:
          The object produced by ``pickle_module.load``.

      Note:
          If you will call :func:`~.megengine.set_default_device()`, please do it
          before :func:`~.megengine.load()`.

      .. warning::

          Loading calls ``pickle_module.load``, and unpickling can execute arbitrary
          code. Only load files that come from a trusted source.

      .. admonition:: If you are using MegEngine with different Python versions
         :class: warning

         Different Python versions may use different DEFAULT/HIGHEST pickle protocols.
         If you want to :func:`~megengine.load` the saved object in another Python version,
         please make sure you have used the same protocol.

      .. admonition:: You can select to use ``pickle`` module directly

         This interface is a wrapper of :func:`pickle.load`. The protocol version of
         the file is detected automatically when loading; see :py:mod:`pickle` for the
         protocols that can be chosen when saving:

         * :py:data:`pickle.HIGHEST_PROTOCOL` - the highest protocol version available.
         * :py:data:`pickle.DEFAULT_PROTOCOL` - the default protocol version used for pickling.

      Examples:
          This example shows how to load tensors to different devices:

          .. code-block::

              import megengine as mge

              # Load tensors to the same device as defined in model.pkl
              mge.load('model.pkl')

              # Load all tensors to gpu0.
              mge.load('model.pkl', map_location='gpu0')

              # Load all tensors originally on gpu0 to cpu0
              mge.load('model.pkl', map_location={'gpu0':'cpu0'})

              # Load all tensors to cpu0
              mge.load('model.pkl', map_location=lambda dev: 'cpu0')

          If you are using a lower version of Python (<3.8), you can use another
          pickle module like ``pickle5`` to load objects saved in pickle 5 protocol:

          >>> import pickle5 as pickle  # doctest: +SKIP

          Or you can use ``pickle5`` in this way (only used with this interface):

          .. code-block:: python

              import pickle5
              import megengine

              megengine.load(f, pickle_module=pickle5)
      """
      import os

      # A path (str or os.PathLike) is opened here in binary mode; the call then
      # recurses with the resulting file object.
      if isinstance(f, (str, os.PathLike)):
          with open(f, "rb") as fin:
              return load(fin, map_location=map_location, pickle_module=pickle_module)
      # Normalize None / str / dict / callable into a single callable.
      map_location = _get_callable_map_location(map_location)
      # _dmap installs the callable as Tensor.dmap_callback while unpickling runs.
      with _dmap(map_location):
          return pickle_module.load(f)