Browse Source

feat(imperative): port persistent cache

GitOrigin-RevId: 8ca24a37cc
release-1.1
Megvii Engine Team 4 years ago
parent
commit
ab6328c563
5 changed files with 172 additions and 0 deletions
  1. +5
    -0
      imperative/python/megengine/__init__.py
  2. +90
    -0
      imperative/python/megengine/utils/persistent_cache.py
  3. +44
    -0
      imperative/python/src/helper.h
  4. +17
    -0
      imperative/python/src/utils.cpp
  5. +16
    -0
      imperative/python/test/unit/test_utils.py

+ 5
- 0
imperative/python/megengine/__init__.py View File

@@ -78,13 +78,18 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save from .serialization import load, save
from .tensor import Parameter, Tensor, tensor from .tensor import Parameter, Tensor, tensor
from .version import __version__ from .version import __version__
from .utils import persistent_cache, comp_graph_tools as cgtools


_set_fork_exec_path_for_timed_func( _set_fork_exec_path_for_timed_func(
sys.executable, sys.executable,
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"),
) )


_persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
_persistent_cache_impl_ins.reg()

atexit.register(sync) atexit.register(sync)


del sync del sync
del _set_fork_exec_path_for_timed_func del _set_fork_exec_path_for_timed_func
del _persistent_cache_impl_ins

+ 90
- 0
imperative/python/megengine/utils/persistent_cache.py View File

@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import argparse
import getpass
import json
import os
import shelve

from ..core._imperative_rt import PersistentCache as _PersistentCache
from ..logger import get_logger
from ..version import __version__


class _FakeRedisConn:
def __init__(self):
try:
from ..hub.hub import _get_megengine_home

cache_dir = os.path.expanduser(
os.path.join(_get_megengine_home(), "persistent_cache")
)
os.makedirs(cache_dir, exist_ok=True)
cache_file = os.path.join(cache_dir, "cache")
self._dict = shelve.open(cache_file)
self._is_shelve = True
except:
self._dict = {}
self._is_shelve = False

def get(self, key):
if self._is_shelve and isinstance(key, bytes):
key = key.decode("utf-8")

return self._dict.get(key)

def set(self, key, val):
if self._is_shelve and isinstance(key, bytes):
key = key.decode("utf-8")

self._dict[key] = val

def __del__(self):
if self._is_shelve:
self._dict.close()


class PersistentCacheOnServer(_PersistentCache):
_cached_conn = None
_prefix = None
_prev_get_refkeep = None

@property
def _conn(self):
"""get redis connection"""
if self._cached_conn is None:
self._cached_conn = _FakeRedisConn()
self._prefix = self.make_user_prefix()

return self._cached_conn

@classmethod
def make_user_prefix(cls):
return "mgbcache:{}".format(getpass.getuser())


def _make_key(self, category, key):
prefix_with_version = "{}:MGB{}".format(self._prefix, __version__)
return b"@".join(
(prefix_with_version.encode("ascii"), category.encode("ascii"), key)
)

def put(self, category, key, value):
conn = self._conn
key = self._make_key(category, key)
conn.set(key, value)

def get(self, category, key):
conn = self._conn
key = self._make_key(category, key)
self._prev_get_refkeep = conn.get(key)
return self._prev_get_refkeep



+ 44
- 0
imperative/python/src/helper.h View File

@@ -12,6 +12,7 @@
#pragma once #pragma once


#include "megbrain/graph.h" #include "megbrain/graph.h"
#include "megbrain/utils/persistent_cache.h"


#include <Python.h> #include <Python.h>
#include <string> #include <string>
@@ -328,6 +329,49 @@ namespace detail {


template<> struct type_caster<mgb::CompNode> : public from_none_caster<mgb::CompNode> {}; template<> struct type_caster<mgb::CompNode> : public from_none_caster<mgb::CompNode> {};


template <> struct type_caster<mgb::PersistentCache::Blob> {
PYBIND11_TYPE_CASTER(mgb::PersistentCache::Blob, _("Blob"));
public:
bool load(handle src, bool convert) {
if (!isinstance<bytes>(src)) {
return false;
}
value.ptr = PYBIND11_BYTES_AS_STRING(src.ptr());
value.size = PYBIND11_BYTES_SIZE(src.ptr());
return true;
}
static handle cast(mgb::PersistentCache::Blob blob, return_value_policy /* policy */, handle /* parent */) {
return bytes((const char*)blob.ptr, blob.size);
}
};

template <typename T> struct type_caster<mgb::Maybe<T>> {
using value_conv = make_caster<T>;
PYBIND11_TYPE_CASTER(mgb::Maybe<T>, _("Optional[") + value_conv::name + _("]"));
public:
bool load(handle src, bool convert) {
if(!src) {
return false;
}
if (src.is_none()) {
return true;
}
value_conv inner_caster;
if (!inner_caster.load(src, convert)) {
return false;
}
value.emplace(cast_op<T&&>(std::move(inner_caster)));
return true;
}

static handle cast(mgb::Maybe<T> src, return_value_policy policy, handle parent) {
if(!src.valid()) {
return none().inc_ref();
}
return pybind11::cast(src.val(), policy, parent);
}
};

} // detail } // detail
} // PYBIND11_NAMESPACE } // PYBIND11_NAMESPACE




+ 17
- 0
imperative/python/src/utils.cpp View File

@@ -25,6 +25,7 @@
#include "megbrain/imperative/profiler.h" #include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/tensor_sanity_check.h" #include "megbrain/imperative/tensor_sanity_check.h"
#include "megbrain/serialization/helper.h" #include "megbrain/serialization/helper.h"
#include "megbrain/utils/persistent_cache.h"


#if MGB_ENABLE_OPR_MM #if MGB_ENABLE_OPR_MM
#include "megbrain/opr/mm_handler.h" #include "megbrain/opr/mm_handler.h"
@@ -262,4 +263,20 @@ void init_utils(py::module m) {
m.def("_timed_func_exec_cb", [](const std::string& user_data){ m.def("_timed_func_exec_cb", [](const std::string& user_data){
mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str());
}); });
using mgb::PersistentCache;
class PyPersistentCache: public mgb::PersistentCache{
public:
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
PYBIND11_OVERLOAD_PURE(mgb::Maybe<Blob>, PersistentCache, get, category, key);
}
void put(const std::string& category, const Blob& key, const Blob& value) override {
PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value);
}

};
py::class_<PersistentCache, PyPersistentCache, std::shared_ptr<PersistentCache>>(m, "PersistentCache")
.def(py::init<>())
.def("get", &PersistentCache::get)
.def("put", &PersistentCache::put)
.def("reg", &PersistentCache::set_impl);
} }

+ 16
- 0
imperative/python/test/unit/test_utils.py View File

@@ -0,0 +1,16 @@
import pytest

import megengine
from megengine.utils.persistent_cache import PersistentCacheOnServer


def test_persistent_cache():
pc = PersistentCacheOnServer()
k0 = b"\x00\x00"
k1 = b"\x00\x01"
cat = "test"
pc.put(cat, k0, k1)
pc.put(cat, k1, k0)
assert k1 == pc.get(cat, k0)
assert k0 == pc.get(cat, k1)
assert pc.get("test1", k0) == None

Loading…
Cancel
Save