|
|
@@ -8,7 +8,10 @@ |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import pickle |
|
|
|
|
|
|
|
import megengine._internal as mgb |
|
|
|
|
|
|
|
from ..utils.max_recursion_limit import max_recursion_limit |
|
|
|
from .device import get_default_device |
|
|
|
|
|
|
|
|
|
|
|
def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.HIGHEST_PROTOCOL): |
|
|
@@ -36,16 +39,90 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.HIGHEST_PROTOCOL): |
|
|
|
pickle_module.dump(obj, f, pickle_protocol) |
|
|
|
|
|
|
|
|
|
|
|
def load(f, pickle_module=pickle): |
|
|
|
class dmap: |
|
|
|
def __init__(self, map_location): |
|
|
|
self.map_location = map_location |
|
|
|
|
|
|
|
def __enter__(self): |
|
|
|
mgb.add_device_map(self.map_location) |
|
|
|
return self |
|
|
|
|
|
|
|
def __exit__(self, type, value, traceback): |
|
|
|
mgb.del_device_map() |
|
|
|
|
|
|
|
|
|
|
|
def _get_callable_map_location(map_location): |
|
|
|
if map_location is None: |
|
|
|
|
|
|
|
def callable_map_location(state): |
|
|
|
return str(get_default_device()) |
|
|
|
|
|
|
|
elif isinstance(map_location, str): |
|
|
|
|
|
|
|
def callable_map_location(state): |
|
|
|
return map_location |
|
|
|
|
|
|
|
elif isinstance(map_location, dict): |
|
|
|
locator_map = {} |
|
|
|
for key, value in map_location.items(): |
|
|
|
locator_key = mgb.config.parse_locator(key)[:2] |
|
|
|
locator_map[locator_key] = value |
|
|
|
|
|
|
|
def callable_map_location(state): |
|
|
|
orig = mgb.config.parse_locator(state)[:2] |
|
|
|
if orig in locator_map.keys(): |
|
|
|
state = locator_map[orig] |
|
|
|
return state |
|
|
|
|
|
|
|
else: |
|
|
|
assert callable(map_location), "map_location should be str, dict or function" |
|
|
|
callable_map_location = map_location |
|
|
|
return callable_map_location |
|
|
|
|
|
|
|
|
|
|
|
def load(f, map_location=None, pickle_module=pickle): |
|
|
|
r"""Load an object saved with save() from a file. |
|
|
|
|
|
|
|
:type f: text file object |
|
|
|
:param f: a string of file name or a text file object from which to load. |
|
|
|
:type map_location: str, dict or a function specifying the map rules |
|
|
|
:param map_location: Default: ``None``. |
|
|
|
|
|
|
|
.. note:: |
|
|
|
|
|
|
|
map_location will change the logical locator when loading models, |
|
|
|
avoiding tensors be loading on non-existent device. If you want to |
|
|
|
add the mapping relationship between logical locator and physical |
|
|
|
locator in runtime, please call :func:`mge.set_device_map()` |
|
|
|
|
|
|
|
:type pickle_module: |
|
|
|
:param pickle_module: Default: ``pickle``. |
|
|
|
|
|
|
|
.. note:: |
|
|
|
|
|
|
|
If you will call :func:`mge.set_default_device()`, please do it |
|
|
|
before :func:`mge.load()`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
|
|
.. testcode: |
|
|
|
|
|
|
|
import megengine as mge |
|
|
|
mge.load('model.mge') |
|
|
|
# Load all tensors based on logical location. |
|
|
|
mge.load('model.mge', map_location='gpu0') |
|
|
|
# Load all tensors onto the device: GPU0 |
|
|
|
mge.load('model.mge', map_location={'gpu0':'cpu0'}) |
|
|
|
# Load all tensors based on logical location, but 'GPU0' will be renamed to 'CPU0' |
|
|
|
mge.load('model.mge', map_location=lambda dev: 'cpu0') |
|
|
|
# Load all tensors onto the device" CPU0 |
|
|
|
|
|
|
|
""" |
|
|
|
if isinstance(f, str): |
|
|
|
with open(f, "rb") as fin: |
|
|
|
return load(fin, pickle_module=pickle_module) |
|
|
|
return pickle_module.load(f) |
|
|
|
return load(fin, map_location=map_location, pickle_module=pickle_module) |
|
|
|
|
|
|
|
map_location = _get_callable_map_location(map_location) # callable map_location |
|
|
|
|
|
|
|
with dmap(map_location): |
|
|
|
return pickle_module.load(f) |