You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

serialization.py 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # -*- coding: utf-8 -*-
import os
import pickle

from .device import _valid_device, get_default_device
from .tensor import Tensor
from .utils.max_recursion_limit import max_recursion_limit
  6. def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.DEFAULT_PROTOCOL):
  7. r"""Save an object to disk file.
  8. Args:
  9. obj: object to save. Only ``module`` or ``state_dict`` are allowed.
  10. f: a string of file name or a text file object to which ``obj`` is saved to.
  11. pickle_module: Default: ``pickle``.
  12. pickle_protocol: Default: ``pickle.DEFAULT_PROTOCOL``.
  13. """
  14. if isinstance(f, str):
  15. with open(f, "wb") as fout:
  16. save(
  17. obj, fout, pickle_module=pickle_module, pickle_protocol=pickle_protocol
  18. )
  19. return
  20. with max_recursion_limit():
  21. assert hasattr(f, "write"), "{} does not support write".format(f)
  22. pickle_module.dump(obj, f, pickle_protocol)
  23. class dmap:
  24. def __init__(self, map_location):
  25. self.map_location = map_location
  26. def __enter__(self):
  27. Tensor.dmap_callback = staticmethod(self.map_location)
  28. return self
  29. def __exit__(self, type, value, traceback):
  30. Tensor.dmap_callback = None
  31. def _get_callable_map_location(map_location):
  32. if map_location is None:
  33. def callable_map_location(state):
  34. return state
  35. elif isinstance(map_location, str):
  36. def callable_map_location(state):
  37. return map_location
  38. elif isinstance(map_location, dict):
  39. for key, value in map_location.items():
  40. # dict key and values can only be "xpux", "cpux", "gpu0", etc.
  41. assert _valid_device(key), "Invalid locator_map key value {}".format(key)
  42. assert _valid_device(value), "Invalid locator_map key value {}".format(
  43. value
  44. )
  45. def callable_map_location(state):
  46. if state[:4] in map_location.keys():
  47. state = map_location[state[:4]]
  48. return state
  49. else:
  50. assert callable(map_location), "map_location should be str, dict or function"
  51. callable_map_location = map_location
  52. return callable_map_location
  53. def load(f, map_location=None, pickle_module=pickle):
  54. r"""Load an object saved with :func:~.megengine.save` from a file.
  55. Args:
  56. f: a string of file name or a text file object from which to load.
  57. map_location: Default: ``None``.
  58. pickle_module: Default: ``pickle``.
  59. Note:
  60. * ``map_location`` defines device mapping. See examples for usage.
  61. * If you will call :func:`~.megengine.set_default_device()`, please do it
  62. before :func:`~.megengine.load()`.
  63. Examples:
  64. .. code-block::
  65. import megengine as mge
  66. # Load tensors to the same device as defined in model.pkl
  67. mge.load('model.pkl')
  68. # Load all tensors to gpu0.
  69. mge.load('model.pkl', map_location='gpu0')
  70. # Load all tensors originally on gpu0 to cpu0
  71. mge.load('model.pkl', map_location={'gpu0':'cpu0'})
  72. # Load all tensors to cpu0
  73. mge.load('model.pkl', map_location=lambda dev: 'cpu0')
  74. """
  75. if isinstance(f, str):
  76. with open(f, "rb") as fin:
  77. return load(fin, map_location=map_location, pickle_module=pickle_module)
  78. map_location = _get_callable_map_location(map_location) # callable map_location
  79. with dmap(map_location) as dm:
  80. return pickle_module.load(f)