|
|
@@ -18,6 +18,7 @@ |
|
|
|
#include <pybind11/operators.h> |
|
|
|
#include <atomic> |
|
|
|
#include <cstdint> |
|
|
|
#include <shared_mutex> |
|
|
|
#include "./imperative_rt.h" |
|
|
|
#include "megbrain/common.h" |
|
|
|
#include "megbrain/comp_node.h" |
|
|
@@ -236,12 +237,16 @@ void init_utils(py::module m) { |
|
|
|
m.def("_timed_func_exec_cb", [](const std::string& user_data){ |
|
|
|
mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); |
|
|
|
}); |
|
|
|
|
|
|
|
using mgb::PersistentCache; |
|
|
|
class PyPersistentCache: public mgb::PersistentCache { |
|
|
|
private: |
|
|
|
using KeyPair = std::pair<std::string, std::string>; |
|
|
|
using BlobPtr = std::unique_ptr<Blob, void(*)(Blob*)>; |
|
|
|
|
|
|
|
std::shared_mutex m_mutex; |
|
|
|
std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache; |
|
|
|
|
|
|
|
static size_t hash_key_pair(const KeyPair& kp) { |
|
|
|
std::hash<std::string> hasher; |
|
|
|
return hasher(kp.first) ^ hasher(kp.second); |
|
|
@@ -275,11 +280,11 @@ void init_utils(py::module m) { |
|
|
|
} |
|
|
|
public: |
|
|
|
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { |
|
|
|
thread_local std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache; |
|
|
|
auto py_get = [this](const std::string& category, const Blob& key) -> mgb::Maybe<Blob> { |
|
|
|
PYBIND11_OVERLOAD_PURE(mgb::Maybe<Blob>, PersistentCache, get, category, key); |
|
|
|
}; |
|
|
|
KeyPair kp = { category, blob_to_str(key) }; |
|
|
|
std::shared_lock<decltype(m_mutex)> rlock; |
|
|
|
auto iter = m_local_cache.find(kp); |
|
|
|
if (iter == m_local_cache.end()) { |
|
|
|
auto py_ret = py_get(category, key); |
|
|
@@ -296,6 +301,9 @@ void init_utils(py::module m) { |
|
|
|
} |
|
|
|
} |
|
|
|
void put(const std::string& category, const Blob& key, const Blob& value) override { |
|
|
|
KeyPair kp = { category, blob_to_str(key) }; |
|
|
|
std::unique_lock<decltype(m_mutex)> wlock; |
|
|
|
m_local_cache.insert_or_assign(kp, copy_blob(value)); |
|
|
|
PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); |
|
|
|
} |
|
|
|
}; |
|
|
|