Browse Source

fix(imperative): persistent cache write through on put

GitOrigin-RevId: f9408ae504
release-1.5
Megvii Engine Team 4 years ago
parent
commit
0ac642b5d5
1 changed files with 9 additions and 1 deletions
  1. +9
    -1
      imperative/python/src/utils.cpp

+ 9
- 1
imperative/python/src/utils.cpp View File

@@ -18,6 +18,7 @@
#include <pybind11/operators.h>
#include <atomic>
#include <cstdint>
#include <shared_mutex>
#include "./imperative_rt.h"
#include "megbrain/common.h"
#include "megbrain/comp_node.h"
@@ -236,12 +237,16 @@ void init_utils(py::module m) {
m.def("_timed_func_exec_cb", [](const std::string& user_data){
mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str());
});

using mgb::PersistentCache;
class PyPersistentCache: public mgb::PersistentCache {
private:
using KeyPair = std::pair<std::string, std::string>;
using BlobPtr = std::unique_ptr<Blob, void(*)(Blob*)>;

std::shared_mutex m_mutex;
std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache;

static size_t hash_key_pair(const KeyPair& kp) {
std::hash<std::string> hasher;
return hasher(kp.first) ^ hasher(kp.second);
@@ -275,11 +280,11 @@ void init_utils(py::module m) {
}
public:
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
thread_local std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache;
auto py_get = [this](const std::string& category, const Blob& key) -> mgb::Maybe<Blob> {
PYBIND11_OVERLOAD_PURE(mgb::Maybe<Blob>, PersistentCache, get, category, key);
};
KeyPair kp = { category, blob_to_str(key) };
std::shared_lock<decltype(m_mutex)> rlock;
auto iter = m_local_cache.find(kp);
if (iter == m_local_cache.end()) {
auto py_ret = py_get(category, key);
@@ -296,6 +301,9 @@ void init_utils(py::module m) {
}
}
void put(const std::string& category, const Blob& key, const Blob& value) override {
KeyPair kp = { category, blob_to_str(key) };
std::unique_lock<decltype(m_mutex)> wlock;
m_local_cache.insert_or_assign(kp, copy_blob(value));
PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value);
}
};


Loading…
Cancel
Save