Browse Source

feat(mge/serialization): add map location

GitOrigin-RevId: 4b6d83365b
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
1d7fcecab2
5 changed files with 115 additions and 3 deletions
  1. +10
    -0
      python_module/megengine/_internal/__init__.py
  2. +11
    -0
      python_module/megengine/_internal/config.py
  3. +80
    -3
      python_module/megengine/core/serialization.py
  4. +13
    -0
      python_module/src/swig/comp_node.i
  5. +1
    -0
      python_module/src/swig/mgb.i

+ 10
- 0
python_module/megengine/_internal/__init__.py View File

@@ -291,6 +291,16 @@ def current_grad_target(comp_graph):
return _detail._current_grad_target(comp_graph)


def add_device_map(map_location):
    """Install *map_location* for use while loading models.

    The value is stored as the ``map_location`` attribute of the
    thread-local object ``_detail.CompNode.cn_thread_local``.
    """
    thread_store = _detail.CompNode.cn_thread_local
    thread_store.map_location = map_location


def del_device_map():
    """Remove the ``map_location`` attribute from the thread-local store.

    Raises :class:`AttributeError` if no map location is currently set on
    ``_detail.CompNode.cn_thread_local``.
    """
    thread_store = _detail.CompNode.cn_thread_local
    del thread_store.map_location


def inter_graph_trans_var(dest_graph, src):
"""get the corresponding var of *src* in *dest_graph*; assuming
*dest_graph* is a copy of owner graph of *src*; usually used in callback of


+ 11
- 0
python_module/megengine/_internal/config.py View File

@@ -107,6 +107,17 @@ def get_device_count(device_type="xpu", warn=True):
return _mgb.CompNode._get_device_count(device_type.upper(), warn)


def parse_locator(device_name: str) -> tuple:
    """Parse a device name into its locator components.

    Thin wrapper that forwards *device_name* to
    ``_mgb.CompNode._parse_locator``.

    :param device_name: device name, like 'cpu0', 'gpu1' and 'xpux'
    :type device_name: str

    :return: (device_type, dev_num, stream_num)
    """
    locator = _mgb.CompNode._parse_locator(device_name)
    return locator


def set_mem_reserve_size(size):
"""set memory reserve size:



+ 80
- 3
python_module/megengine/core/serialization.py View File

@@ -8,7 +8,10 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle

import megengine._internal as mgb

from ..utils.max_recursion_limit import max_recursion_limit
from .device import get_default_device


def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.HIGHEST_PROTOCOL):
@@ -36,16 +39,90 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.HIGHEST_PROTOCOL):
pickle_module.dump(obj, f, pickle_protocol)


def load(f, pickle_module=pickle):
class dmap:
    """Context manager that installs a device map around a ``with`` block.

    On entry *map_location* is registered through ``mgb.add_device_map``;
    on exit it is removed again through ``mgb.del_device_map``, whether or
    not the body raised.

    :param map_location: value handed to ``mgb.add_device_map`` on entry.
    """

    def __init__(self, map_location):
        self.map_location = map_location

    def __enter__(self):
        mgb.add_device_map(self.map_location)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None (falsy), so an exception raised in the body propagates.
        mgb.del_device_map()


def _get_callable_map_location(map_location):
    """Normalize *map_location* into a callable mapping a device to a device.

    :param map_location: one of

        * ``None`` -- every device maps to ``str(get_default_device())``;
        * ``str`` -- every device maps to this string;
        * ``dict`` -- keys are parsed with ``mgb.config.parse_locator`` and
          reduced to their first two components; a device whose first two
          locator components match a key is replaced by that key's value,
          any other device is returned unchanged;
        * callable -- returned as is.
    :return: a callable taking the saved device state and returning the
        device to use instead.
    :raises TypeError: if *map_location* is none of the supported kinds.
    """
    if map_location is None:

        def callable_map_location(state):
            # Looked up on every call rather than once up front.
            return str(get_default_device())

    elif isinstance(map_location, str):

        def callable_map_location(state):
            return map_location

    elif isinstance(map_location, dict):
        # Key the table on the first two locator components only, so the
        # third component (stream_num) is ignored when matching.
        locator_map = {
            mgb.config.parse_locator(key)[:2]: value
            for key, value in map_location.items()
        }

        def callable_map_location(state):
            orig = mgb.config.parse_locator(state)[:2]
            return locator_map.get(orig, state)

    else:
        # Explicit raise instead of ``assert``: asserts are stripped under -O.
        if not callable(map_location):
            raise TypeError("map_location should be str, dict or function")
        callable_map_location = map_location
    return callable_map_location


def load(f, map_location=None, pickle_module=pickle):
    r"""Load an object saved with :func:`save` from a file.

    :type f: str or binary file object
    :param f: a string of file name or a file object opened in binary mode
        from which to load.
    :type map_location: str, dict or a function specifying the map rules
    :param map_location: Default: ``None``.

    .. note::

        map_location will change the logical locator when loading models,
        avoiding tensors being loaded on non-existent devices. If you want to
        add the mapping relationship between logical locator and physical
        locator in runtime, please call :func:`mge.set_device_map()`

    :type pickle_module:
    :param pickle_module: Default: ``pickle``.

    .. note::

        If you will call :func:`mge.set_default_device()`, please do it
        before :func:`mge.load()`.

    Examples:

    .. testcode::

        import megengine as mge
        mge.load('model.mge')
        # Load all tensors based on logical location.
        mge.load('model.mge', map_location='gpu0')
        # Load all tensors onto the device: GPU0
        mge.load('model.mge', map_location={'gpu0':'cpu0'})
        # Load all tensors based on logical location, but 'GPU0' will be renamed to 'CPU0'
        mge.load('model.mge', map_location=lambda dev: 'cpu0')
        # Load all tensors onto the device: CPU0

    """
    if isinstance(f, str):
        # Open the named file and recurse with the file object, forwarding
        # map_location so the caller's mapping is not lost.
        with open(f, "rb") as fin:
            return load(fin, map_location=map_location, pickle_module=pickle_module)

    map_location = _get_callable_map_location(map_location)  # callable map_location

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The device map is installed only for the duration of the unpickling.
    with dmap(map_location):
        return pickle_module.load(f)

+ 13
- 0
python_module/src/swig/comp_node.i View File

@@ -28,6 +28,12 @@ class CompNode {
static CompNode load(const char* id);

%extend {
// Parse a comp node id string and return its locator fields as
// {type, device, stream}. A static member function cannot be
// const-qualified, so no trailing `const` here.
static std::vector<int> _parse_locator(const std::string &id) {
    auto logi = CompNode::Locator::parse(id);
    return {
        static_cast<int>(logi.type), logi.device, logi.stream,
    };
}
static void _set_device_map(const std::string &type,
int from, int to) {
CompNode::Locator::set_device_map(
@@ -86,7 +92,14 @@ class CompNode {
2: 'CPU'
}

cn_thread_local = threading.local()
"""used to save map location when calling :func:`mge.load()`"""

def __setstate__(self, state):
    """Restore this comp node from its pickled *state*.

    Called by :func:`mge.load()` and :func:`deepcopy()`; the latter will
    not produce the ``map_location`` attribute. When that attribute is
    present on the thread-local store, it is applied to *state* before
    the node is loaded."""
    if "map_location" in vars(CompNode.cn_thread_local):
        remap = CompNode.cn_thread_local.map_location
        state = remap(state)
    loaded = CompNode_load(state)
    self.this = loaded.this

def __eq__(self, rhs):


+ 1
- 0
python_module/src/swig/mgb.i View File

@@ -35,6 +35,7 @@ void _init_bfloat16_types(PyObject *m); // implemented in bfloat16.cpp
%pythoncode %{
import numpy as np
import os
import threading
intb1 = _mgb.intb1
intb2 = _mgb.intb2
intb4 = _mgb.intb4


Loading…
Cancel
Save