From 0ac642b5d573b41fe71b90ea85e14b50e3b1711d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 4 Jun 2021 16:47:40 +0800 Subject: [PATCH] fix(imperative): persistent cache write through on put GitOrigin-RevId: f9408ae5046a28b06b6de914f430a15c78ab82d9 --- imperative/python/src/utils.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/imperative/python/src/utils.cpp b/imperative/python/src/utils.cpp index e43a8ac5..9c09b1f8 100644 --- a/imperative/python/src/utils.cpp +++ b/imperative/python/src/utils.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include "./imperative_rt.h" #include "megbrain/common.h" #include "megbrain/comp_node.h" @@ -236,12 +237,16 @@ void init_utils(py::module m) { m.def("_timed_func_exec_cb", [](const std::string& user_data){ mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); }); + using mgb::PersistentCache; class PyPersistentCache: public mgb::PersistentCache { private: using KeyPair = std::pair; using BlobPtr = std::unique_ptr; + std::shared_mutex m_mutex; + std::unordered_map m_local_cache; + static size_t hash_key_pair(const KeyPair& kp) { std::hash hasher; return hasher(kp.first) ^ hasher(kp.second); @@ -275,11 +280,11 @@ void init_utils(py::module m) { } public: mgb::Maybe get(const std::string& category, const Blob& key) override { - thread_local std::unordered_map m_local_cache; auto py_get = [this](const std::string& category, const Blob& key) -> mgb::Maybe { PYBIND11_OVERLOAD_PURE(mgb::Maybe, PersistentCache, get, category, key); }; KeyPair kp = { category, blob_to_str(key) }; + std::shared_lock rlock; auto iter = m_local_cache.find(kp); if (iter == m_local_cache.end()) { auto py_ret = py_get(category, key); @@ -296,6 +301,9 @@ void init_utils(py::module m) { } } void put(const std::string& category, const Blob& key, const Blob& value) override { + KeyPair kp = { category, blob_to_str(key) }; + std::unique_lock wlock; + m_local_cache.insert_or_assign(kp, copy_blob(value)); PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); } };