From ab6328c563a8bfe7c2e9c56fc1e012aaf83381bb Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 29 Oct 2020 13:39:08 +0800 Subject: [PATCH] feat(imperative): port persistent cache GitOrigin-RevId: 8ca24a37cc28a0be3f659e0e8863fee1beac3a38 --- imperative/python/megengine/__init__.py | 5 ++ .../python/megengine/utils/persistent_cache.py | 90 ++++++++++++++++++++++ imperative/python/src/helper.h | 44 +++++++++++ imperative/python/src/utils.cpp | 17 ++++ imperative/python/test/unit/test_utils.py | 16 ++++ 5 files changed, 172 insertions(+) create mode 100644 imperative/python/megengine/utils/persistent_cache.py create mode 100644 imperative/python/test/unit/test_utils.py diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py index 32aee56b..5d86550b 100644 --- a/imperative/python/megengine/__init__.py +++ b/imperative/python/megengine/__init__.py @@ -78,13 +78,18 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level from .serialization import load, save from .tensor import Parameter, Tensor, tensor from .version import __version__ +from .utils import persistent_cache, comp_graph_tools as cgtools _set_fork_exec_path_for_timed_func( sys.executable, os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), ) +_persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer() +_persistent_cache_impl_ins.reg() + atexit.register(sync) del sync del _set_fork_exec_path_for_timed_func +del _persistent_cache_impl_ins diff --git a/imperative/python/megengine/utils/persistent_cache.py b/imperative/python/megengine/utils/persistent_cache.py new file mode 100644 index 00000000..0f142b76 --- /dev/null +++ b/imperative/python/megengine/utils/persistent_cache.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +import argparse +import getpass +import json +import os +import shelve + +from ..core._imperative_rt import PersistentCache as _PersistentCache +from ..logger import get_logger +from ..version import __version__ + + +class _FakeRedisConn: + def __init__(self): + try: + from ..hub.hub import _get_megengine_home + + cache_dir = os.path.expanduser( + os.path.join(_get_megengine_home(), "persistent_cache") + ) + os.makedirs(cache_dir, exist_ok=True) + cache_file = os.path.join(cache_dir, "cache") + self._dict = shelve.open(cache_file) + self._is_shelve = True + except: + self._dict = {} + self._is_shelve = False + + def get(self, key): + if self._is_shelve and isinstance(key, bytes): + key = key.decode("utf-8") + + return self._dict.get(key) + + def set(self, key, val): + if self._is_shelve and isinstance(key, bytes): + key = key.decode("utf-8") + + self._dict[key] = val + + def __del__(self): + if self._is_shelve: + self._dict.close() + + +class PersistentCacheOnServer(_PersistentCache): + _cached_conn = None + _prefix = None + _prev_get_refkeep = None + + @property + def _conn(self): + """get redis connection""" + if self._cached_conn is None: + self._cached_conn = _FakeRedisConn() + self._prefix = self.make_user_prefix() + + return self._cached_conn + + @classmethod + def make_user_prefix(cls): + return "mgbcache:{}".format(getpass.getuser()) + + + def _make_key(self, category, key): + prefix_with_version = "{}:MGB{}".format(self._prefix, __version__) + return b"@".join( + (prefix_with_version.encode("ascii"), category.encode("ascii"), key) + ) + + def put(self, category, key, value): + conn = self._conn + key = self._make_key(category, key) + conn.set(key, value) + + def get(self, category, key): + conn = self._conn + key = self._make_key(category, key) + self._prev_get_refkeep = conn.get(key) + return self._prev_get_refkeep + + diff --git a/imperative/python/src/helper.h b/imperative/python/src/helper.h index ec0942c1..05650a3d 100644 --- a/imperative/python/src/helper.h +++ b/imperative/python/src/helper.h @@ -12,6 +12,7 @@ #pragma once #include "megbrain/graph.h" +#include "megbrain/utils/persistent_cache.h" #include #include @@ -328,6 +329,49 @@ namespace detail { template<> struct type_caster : public from_none_caster {}; + template <> struct type_caster { + PYBIND11_TYPE_CASTER(mgb::PersistentCache::Blob, _("Blob")); + public: + bool load(handle src, bool convert) { + if (!isinstance(src)) { + return false; + } + value.ptr = PYBIND11_BYTES_AS_STRING(src.ptr()); + value.size = PYBIND11_BYTES_SIZE(src.ptr()); + return true; + } + static handle cast(mgb::PersistentCache::Blob blob, return_value_policy /* policy */, handle /* parent */) { + return bytes((const char*)blob.ptr, blob.size); + } + }; + + template struct type_caster> { + using value_conv = make_caster; + PYBIND11_TYPE_CASTER(mgb::Maybe, _("Optional[") + value_conv::name + _("]")); + public: + bool load(handle src, bool convert) { + if(!src) { + return false; + } + if (src.is_none()) { + return true; + } + value_conv inner_caster; + if (!inner_caster.load(src, convert)) { + return false; + } + value.emplace(cast_op(std::move(inner_caster))); + return true; + } + + static handle cast(mgb::Maybe src, return_value_policy policy, handle parent) { + if(!src.valid()) { + return none().inc_ref(); + } + return pybind11::cast(src.val(), policy, parent); + } + }; + } // detail } // PYBIND11_NAMESPACE diff --git a/imperative/python/src/utils.cpp b/imperative/python/src/utils.cpp index 3169f272..0e7274d9 100644 --- a/imperative/python/src/utils.cpp +++ b/imperative/python/src/utils.cpp @@ -25,6 +25,7 @@ #include "megbrain/imperative/profiler.h" #include "megbrain/imperative/tensor_sanity_check.h" #include "megbrain/serialization/helper.h" +#include "megbrain/utils/persistent_cache.h" #if MGB_ENABLE_OPR_MM #include "megbrain/opr/mm_handler.h" @@ -262,4 +263,20 @@ void init_utils(py::module m) { m.def("_timed_func_exec_cb", [](const std::string& user_data){ mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); }); + using mgb::PersistentCache; + class PyPersistentCache: public mgb::PersistentCache{ + public: + mgb::Maybe get(const std::string& category, const Blob& key) override { + PYBIND11_OVERLOAD_PURE(mgb::Maybe, PersistentCache, get, category, key); + } + void put(const std::string& category, const Blob& key, const Blob& value) override { + PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); + } + + }; + py::class_>(m, "PersistentCache") + .def(py::init<>()) + .def("get", &PersistentCache::get) + .def("put", &PersistentCache::put) + .def("reg", &PersistentCache::set_impl); } diff --git a/imperative/python/test/unit/test_utils.py b/imperative/python/test/unit/test_utils.py new file mode 100644 index 00000000..88660b19 --- /dev/null +++ b/imperative/python/test/unit/test_utils.py @@ -0,0 +1,16 @@ +import pytest + +import megengine +from megengine.utils.persistent_cache import PersistentCacheOnServer + + +def test_persistent_cache(): + pc = PersistentCacheOnServer() + k0 = b"\x00\x00" + k1 = b"\x00\x01" + cat = "test" + pc.put(cat, k0, k1) + pc.put(cat, k1, k0) + assert k1 == pc.get(cat, k0) + assert k0 == pc.get(cat, k1) + assert pc.get("test1", k0) == None