Browse Source

fix(imperative): persistent cache write through on put

GitOrigin-RevId: f9408ae504
release-1.5
Megvii Engine Team 4 years ago
parent
commit
0ac642b5d5
1 changed files with 9 additions and 1 deletions
  1. +9
    -1
      imperative/python/src/utils.cpp

+ 9
- 1
imperative/python/src/utils.cpp View File

@@ -18,6 +18,7 @@
#include <pybind11/operators.h> #include <pybind11/operators.h>
#include <atomic> #include <atomic>
#include <cstdint> #include <cstdint>
#include <shared_mutex>
#include "./imperative_rt.h" #include "./imperative_rt.h"
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/comp_node.h" #include "megbrain/comp_node.h"
@@ -236,12 +237,16 @@ void init_utils(py::module m) {
m.def("_timed_func_exec_cb", [](const std::string& user_data){ m.def("_timed_func_exec_cb", [](const std::string& user_data){
mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str());
}); });

using mgb::PersistentCache; using mgb::PersistentCache;
class PyPersistentCache: public mgb::PersistentCache { class PyPersistentCache: public mgb::PersistentCache {
private: private:
using KeyPair = std::pair<std::string, std::string>; using KeyPair = std::pair<std::string, std::string>;
using BlobPtr = std::unique_ptr<Blob, void(*)(Blob*)>; using BlobPtr = std::unique_ptr<Blob, void(*)(Blob*)>;


std::shared_mutex m_mutex;
std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache;

static size_t hash_key_pair(const KeyPair& kp) { static size_t hash_key_pair(const KeyPair& kp) {
std::hash<std::string> hasher; std::hash<std::string> hasher;
return hasher(kp.first) ^ hasher(kp.second); return hasher(kp.first) ^ hasher(kp.second);
@@ -275,11 +280,11 @@ void init_utils(py::module m) {
} }
public: public:
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
thread_local std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache;
auto py_get = [this](const std::string& category, const Blob& key) -> mgb::Maybe<Blob> { auto py_get = [this](const std::string& category, const Blob& key) -> mgb::Maybe<Blob> {
PYBIND11_OVERLOAD_PURE(mgb::Maybe<Blob>, PersistentCache, get, category, key); PYBIND11_OVERLOAD_PURE(mgb::Maybe<Blob>, PersistentCache, get, category, key);
}; };
KeyPair kp = { category, blob_to_str(key) }; KeyPair kp = { category, blob_to_str(key) };
std::shared_lock<decltype(m_mutex)> rlock;
auto iter = m_local_cache.find(kp); auto iter = m_local_cache.find(kp);
if (iter == m_local_cache.end()) { if (iter == m_local_cache.end()) {
auto py_ret = py_get(category, key); auto py_ret = py_get(category, key);
@@ -296,6 +301,9 @@ void init_utils(py::module m) {
} }
} }
void put(const std::string& category, const Blob& key, const Blob& value) override { void put(const std::string& category, const Blob& key, const Blob& value) override {
KeyPair kp = { category, blob_to_str(key) };
std::unique_lock<decltype(m_mutex)> wlock;
m_local_cache.insert_or_assign(kp, copy_blob(value));
PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value);
} }
}; };


Loading…
Cancel
Save