@@ -78,13 +78,18 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||||
from .serialization import load, save | from .serialization import load, save | ||||
from .tensor import Parameter, Tensor, tensor | from .tensor import Parameter, Tensor, tensor | ||||
from .version import __version__ | from .version import __version__ | ||||
from .utils import persistent_cache, comp_graph_tools as cgtools | |||||
_set_fork_exec_path_for_timed_func( | _set_fork_exec_path_for_timed_func( | ||||
sys.executable, | sys.executable, | ||||
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | ||||
) | ) | ||||
_persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer() | |||||
_persistent_cache_impl_ins.reg() | |||||
atexit.register(sync) | atexit.register(sync) | ||||
del sync | del sync | ||||
del _set_fork_exec_path_for_timed_func | del _set_fork_exec_path_for_timed_func | ||||
del _persistent_cache_impl_ins |
@@ -0,0 +1,90 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import argparse | |||||
import getpass | |||||
import json | |||||
import os | |||||
import shelve | |||||
from ..core._imperative_rt import PersistentCache as _PersistentCache | |||||
from ..logger import get_logger | |||||
from ..version import __version__ | |||||
class _FakeRedisConn: | |||||
def __init__(self): | |||||
try: | |||||
from ..hub.hub import _get_megengine_home | |||||
cache_dir = os.path.expanduser( | |||||
os.path.join(_get_megengine_home(), "persistent_cache") | |||||
) | |||||
os.makedirs(cache_dir, exist_ok=True) | |||||
cache_file = os.path.join(cache_dir, "cache") | |||||
self._dict = shelve.open(cache_file) | |||||
self._is_shelve = True | |||||
except: | |||||
self._dict = {} | |||||
self._is_shelve = False | |||||
def get(self, key): | |||||
if self._is_shelve and isinstance(key, bytes): | |||||
key = key.decode("utf-8") | |||||
return self._dict.get(key) | |||||
def set(self, key, val): | |||||
if self._is_shelve and isinstance(key, bytes): | |||||
key = key.decode("utf-8") | |||||
self._dict[key] = val | |||||
def __del__(self): | |||||
if self._is_shelve: | |||||
self._dict.close() | |||||
class PersistentCacheOnServer(_PersistentCache): | |||||
_cached_conn = None | |||||
_prefix = None | |||||
_prev_get_refkeep = None | |||||
@property | |||||
def _conn(self): | |||||
"""get redis connection""" | |||||
if self._cached_conn is None: | |||||
self._cached_conn = _FakeRedisConn() | |||||
self._prefix = self.make_user_prefix() | |||||
return self._cached_conn | |||||
@classmethod | |||||
def make_user_prefix(cls): | |||||
return "mgbcache:{}".format(getpass.getuser()) | |||||
def _make_key(self, category, key): | |||||
prefix_with_version = "{}:MGB{}".format(self._prefix, __version__) | |||||
return b"@".join( | |||||
(prefix_with_version.encode("ascii"), category.encode("ascii"), key) | |||||
) | |||||
def put(self, category, key, value): | |||||
conn = self._conn | |||||
key = self._make_key(category, key) | |||||
conn.set(key, value) | |||||
def get(self, category, key): | |||||
conn = self._conn | |||||
key = self._make_key(category, key) | |||||
self._prev_get_refkeep = conn.get(key) | |||||
return self._prev_get_refkeep | |||||
@@ -12,6 +12,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
#include "megbrain/utils/persistent_cache.h" | |||||
#include <Python.h> | #include <Python.h> | ||||
#include <string> | #include <string> | ||||
@@ -328,6 +329,49 @@ namespace detail { | |||||
template<> struct type_caster<mgb::CompNode> : public from_none_caster<mgb::CompNode> {}; | template<> struct type_caster<mgb::CompNode> : public from_none_caster<mgb::CompNode> {}; | ||||
template <> struct type_caster<mgb::PersistentCache::Blob> { | |||||
PYBIND11_TYPE_CASTER(mgb::PersistentCache::Blob, _("Blob")); | |||||
public: | |||||
bool load(handle src, bool convert) { | |||||
if (!isinstance<bytes>(src)) { | |||||
return false; | |||||
} | |||||
value.ptr = PYBIND11_BYTES_AS_STRING(src.ptr()); | |||||
value.size = PYBIND11_BYTES_SIZE(src.ptr()); | |||||
return true; | |||||
} | |||||
static handle cast(mgb::PersistentCache::Blob blob, return_value_policy /* policy */, handle /* parent */) { | |||||
return bytes((const char*)blob.ptr, blob.size); | |||||
} | |||||
}; | |||||
template <typename T> struct type_caster<mgb::Maybe<T>> { | |||||
using value_conv = make_caster<T>; | |||||
PYBIND11_TYPE_CASTER(mgb::Maybe<T>, _("Optional[") + value_conv::name + _("]")); | |||||
public: | |||||
bool load(handle src, bool convert) { | |||||
if(!src) { | |||||
return false; | |||||
} | |||||
if (src.is_none()) { | |||||
return true; | |||||
} | |||||
value_conv inner_caster; | |||||
if (!inner_caster.load(src, convert)) { | |||||
return false; | |||||
} | |||||
value.emplace(cast_op<T&&>(std::move(inner_caster))); | |||||
return true; | |||||
} | |||||
static handle cast(mgb::Maybe<T> src, return_value_policy policy, handle parent) { | |||||
if(!src.valid()) { | |||||
return none().inc_ref(); | |||||
} | |||||
return pybind11::cast(src.val(), policy, parent); | |||||
} | |||||
}; | |||||
} // detail | } // detail | ||||
} // PYBIND11_NAMESPACE | } // PYBIND11_NAMESPACE | ||||
@@ -25,6 +25,7 @@ | |||||
#include "megbrain/imperative/profiler.h" | #include "megbrain/imperative/profiler.h" | ||||
#include "megbrain/imperative/tensor_sanity_check.h" | #include "megbrain/imperative/tensor_sanity_check.h" | ||||
#include "megbrain/serialization/helper.h" | #include "megbrain/serialization/helper.h" | ||||
#include "megbrain/utils/persistent_cache.h" | |||||
#if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
#include "megbrain/opr/mm_handler.h" | #include "megbrain/opr/mm_handler.h" | ||||
@@ -262,4 +263,20 @@ void init_utils(py::module m) { | |||||
m.def("_timed_func_exec_cb", [](const std::string& user_data){ | m.def("_timed_func_exec_cb", [](const std::string& user_data){ | ||||
mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); | mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); | ||||
}); | }); | ||||
using mgb::PersistentCache; | |||||
class PyPersistentCache: public mgb::PersistentCache{ | |||||
public: | |||||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||||
PYBIND11_OVERLOAD_PURE(mgb::Maybe<Blob>, PersistentCache, get, category, key); | |||||
} | |||||
void put(const std::string& category, const Blob& key, const Blob& value) override { | |||||
PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); | |||||
} | |||||
}; | |||||
py::class_<PersistentCache, PyPersistentCache, std::shared_ptr<PersistentCache>>(m, "PersistentCache") | |||||
.def(py::init<>()) | |||||
.def("get", &PersistentCache::get) | |||||
.def("put", &PersistentCache::put) | |||||
.def("reg", &PersistentCache::set_impl); | |||||
} | } |
@@ -0,0 +1,16 @@ | |||||
import pytest | |||||
import megengine | |||||
from megengine.utils.persistent_cache import PersistentCacheOnServer | |||||
def test_persistent_cache(): | |||||
pc = PersistentCacheOnServer() | |||||
k0 = b"\x00\x00" | |||||
k1 = b"\x00\x01" | |||||
cat = "test" | |||||
pc.put(cat, k0, k1) | |||||
pc.put(cat, k1, k0) | |||||
assert k1 == pc.get(cat, k0) | |||||
assert k0 == pc.get(cat, k1) | |||||
assert pc.get("test1", k0) == None |