GitOrigin-RevId: 9af7fa5c97
tags/v1.7.2.m1
@@ -877,6 +877,8 @@ if(MGE_WITH_JIT AND MGE_WITH_HALIDE) | |||
include(cmake/Halide.cmake) | |||
endif() | |||
include(cmake/cpp_redis.cmake) | |||
# Thread | |||
IF(APPLE) | |||
set(CMAKE_THREAD_LIBS_INIT "-lpthread") | |||
@@ -0,0 +1,2 @@ | |||
file(GLOB_RECURSE CPP_REDIS_SRCS ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/sources/*.cpp ${PROJECT_SOURCE_DIR}/third_party/tacopie/sources/*.cpp) | |||
set(CPP_REDIS_INCLUDES ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/includes ${PROJECT_SOURCE_DIR}/third_party/tacopie/includes) |
@@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) | |||
set(MODULE_NAME _imperative_rt) | |||
set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) | |||
file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) | |||
set(SRCS ${SRCS} ${CPP_REDIS_SRCS}) | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | |||
@@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3) | |||
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) | |||
target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) | |||
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) | |||
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | |||
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | |||
if(CXX_SUPPORT_WCLASS_MEMACCESS) | |||
@@ -92,9 +92,6 @@ _set_fork_exec_path_for_timed_func( | |||
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | |||
) | |||
_persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer() | |||
_persistent_cache_impl_ins.reg() | |||
atexit.register(_close) | |||
del _set_fork_exec_path_for_timed_func | |||
@@ -135,3 +132,5 @@ import megengine.quantization | |||
import megengine.random | |||
import megengine.utils | |||
import megengine.traced_module | |||
persistent_cache.get_manager() |
@@ -9,108 +9,54 @@ | |||
import argparse | |||
import getpass | |||
import json | |||
import os | |||
import shelve | |||
import sys | |||
from ..core._imperative_rt import PersistentCache as _PersistentCache | |||
from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager | |||
from ..logger import get_logger | |||
from ..version import __version__, git_version | |||
class _FakeRedisConn: | |||
_cache_dir = None | |||
_is_shelve = False | |||
_dict = {} | |||
class PersistentCacheManager(_PersistentCacheManager): | |||
def __init__(self): | |||
super().__init__() | |||
if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | |||
self._dict = {} | |||
self._is_shelve = False | |||
get_logger().info("fastrun use in-memory cache") | |||
self.open_memory() | |||
else: | |||
try: | |||
self._cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") | |||
if not self._cache_dir: | |||
from ..hub.hub import _get_megengine_home | |||
self._cache_dir = os.path.expanduser( | |||
os.path.join(_get_megengine_home(), "persistent_cache") | |||
) | |||
os.makedirs(self._cache_dir, exist_ok=True) | |||
cache_file = os.path.join(self._cache_dir, "cache") | |||
self._dict = shelve.open(cache_file) | |||
self._is_shelve = True | |||
get_logger().info( | |||
"fastrun use in-file cache in {}".format(self._cache_dir) | |||
) | |||
except Exception as exc: | |||
self._dict = {} | |||
self._is_shelve = False | |||
get_logger().error( | |||
"failed to create cache file in {} {!r}; fallback to " | |||
"in-memory cache".format(self._cache_dir, exc) | |||
) | |||
def get(self, key): | |||
if self._is_shelve and isinstance(key, bytes): | |||
key = key.decode("utf-8") | |||
return self._dict.get(key) | |||
def set(self, key, val): | |||
if self._is_shelve and isinstance(key, bytes): | |||
key = key.decode("utf-8") | |||
self._dict[key] = val | |||
def clear(self): | |||
print("{} cache item deleted in {}".format(len(self._dict), self._cache_dir)) | |||
self._dict.clear() | |||
self.open_file() | |||
def __del__(self): | |||
if self._is_shelve: | |||
self._dict.close() | |||
def open_memory(self): | |||
pass | |||
def open_file(self): | |||
cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") | |||
try: | |||
if not cache_dir: | |||
from ..hub.hub import _get_megengine_home | |||
class PersistentCacheOnServer(_PersistentCache): | |||
_cached_conn = None | |||
_prefix = None | |||
_prev_get_refkeep = None | |||
@property | |||
def _conn(self): | |||
"""get redis connection""" | |||
if self._cached_conn is None: | |||
self._cached_conn = _FakeRedisConn() | |||
self._prefix = self.make_user_prefix() | |||
return self._cached_conn | |||
@classmethod | |||
def make_user_prefix(cls): | |||
return "mgbcache:{}".format(getpass.getuser()) | |||
def _make_key(self, category, key): | |||
prefix_with_version = "{}:MGB{}:GIT:{}".format( | |||
self._prefix, __version__, git_version | |||
) | |||
return b"@".join( | |||
(prefix_with_version.encode("ascii"), category.encode("ascii"), key) | |||
) | |||
def put(self, category, key, value): | |||
conn = self._conn | |||
key = self._make_key(category, key) | |||
conn.set(key, value) | |||
def get(self, category, key): | |||
conn = self._conn | |||
key = self._make_key(category, key) | |||
self._prev_get_refkeep = conn.get(key) | |||
return self._prev_get_refkeep | |||
def clean(self): | |||
conn = self._conn | |||
if isinstance(conn, _FakeRedisConn): | |||
conn.clear() | |||
cache_dir = os.path.expanduser( | |||
os.path.join(_get_megengine_home(), "persistent_cache.bin") | |||
) | |||
os.makedirs(cache_dir, exist_ok=True) | |||
cache_file = os.path.join(cache_dir, "cache") | |||
with open(cache_file, "a"): | |||
pass | |||
assert self.try_open_file(cache_file), "cannot create file" | |||
get_logger().info("fastrun use in-file cache in {}".format(cache_dir)) | |||
except Exception as exc: | |||
get_logger().error( | |||
"failed to create cache file in {} {!r}; fallback to " | |||
"in-memory cache".format(cache_dir, exc) | |||
) | |||
self.open_memory() | |||
_manager = None | |||
def get_manager(): | |||
global _manager | |||
if _manager is None: | |||
_manager = PersistentCacheManager() | |||
return _manager |
@@ -23,6 +23,7 @@ | |||
#include "megbrain/common.h" | |||
#include "megbrain/comp_node.h" | |||
#include "megbrain/imperative/blob_manager.h" | |||
#include "megbrain/imperative/persistent_cache.h" | |||
#include "megbrain/imperative/profiler.h" | |||
#include "megbrain/imperative/tensor_sanity_check.h" | |||
#include "megbrain/serialization/helper.h" | |||
@@ -229,83 +230,55 @@ void init_utils(py::module m) { | |||
mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); | |||
}); | |||
using mgb::PersistentCache; | |||
class PyPersistentCache : public mgb::PersistentCache { | |||
private: | |||
using KeyPair = std::pair<std::string, std::string>; | |||
using BlobPtr = std::unique_ptr<Blob, void (*)(Blob*)>; | |||
using PersistentCache = mgb::PersistentCache; | |||
using ExtendedPersistentCache = | |||
mgb::imperative::persistent_cache::ExtendedPersistentCache; | |||
std::shared_mutex m_mutex; | |||
std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache; | |||
struct PersistentCacheManager { | |||
std::shared_ptr<ExtendedPersistentCache> instance; | |||
static size_t hash_key_pair(const KeyPair& kp) { | |||
std::hash<std::string> hasher; | |||
return hasher(kp.first) ^ hasher(kp.second); | |||
bool try_reg(std::shared_ptr<ExtendedPersistentCache> cache) { | |||
if (cache) { | |||
instance = cache; | |||
PersistentCache::set_impl(cache); | |||
return true; | |||
} | |||
return false; | |||
} | |||
std::string blob_to_str(const Blob& key) { | |||
return std::string(reinterpret_cast<const char*>(key.ptr), key.size); | |||
bool open_redis( | |||
std::string ip, size_t port, std::string password, std::string prefix) { | |||
return try_reg(mgb::imperative::persistent_cache::make_redis( | |||
ip, port, password, prefix)); | |||
} | |||
BlobPtr copy_blob(const Blob& blob) { | |||
auto blob_deleter = [](Blob* blob) { | |||
if (blob) { | |||
std::free(const_cast<void*>(blob->ptr)); | |||
delete blob; | |||
} | |||
}; | |||
auto blob_ptr = BlobPtr{new Blob(), blob_deleter}; | |||
blob_ptr->ptr = std::malloc(blob.size); | |||
std::memcpy(const_cast<void*>(blob_ptr->ptr), blob.ptr, blob.size); | |||
blob_ptr->size = blob.size; | |||
return blob_ptr; | |||
bool open_file(std::string path) { | |||
return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); | |||
} | |||
BlobPtr str_to_blob(const std::string& str) { | |||
auto blob = Blob{str.data(), str.size()}; | |||
return copy_blob(blob); | |||
std::optional<size_t> clean() { | |||
if (instance) { | |||
return instance->clear(); | |||
} | |||
return {}; | |||
} | |||
std::unique_ptr<Blob, void (*)(Blob*)> empty_blob() { | |||
return BlobPtr{nullptr, [](Blob* blob) {}}; | |||
void put(std::string category, std::string key, std::string value) { | |||
PersistentCache::inst().put( | |||
category, {key.data(), key.size()}, {value.data(), value.size()}); | |||
} | |||
public: | |||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
auto py_get = [this](const std::string& category, | |||
const Blob& key) -> mgb::Maybe<Blob> { | |||
PYBIND11_OVERLOAD_PURE( | |||
mgb::Maybe<Blob>, PersistentCache, get, category, key); | |||
}; | |||
KeyPair kp = {category, blob_to_str(key)}; | |||
std::shared_lock<decltype(m_mutex)> rlock; | |||
auto iter = m_local_cache.find(kp); | |||
if (iter == m_local_cache.end()) { | |||
auto py_ret = py_get(category, key); | |||
if (!py_ret.valid()) { | |||
iter = m_local_cache.insert({kp, empty_blob()}).first; | |||
} else { | |||
iter = m_local_cache.insert({kp, copy_blob(py_ret.val())}).first; | |||
} | |||
} | |||
if (iter->second) { | |||
return *iter->second; | |||
py::object get(std::string category, std::string key) { | |||
auto value = | |||
PersistentCache::inst().get(category, {key.data(), key.size()}); | |||
if (value.valid()) { | |||
return py::bytes(std::string((const char*)value->ptr, value->size)); | |||
} else { | |||
return {}; | |||
return py::none(); | |||
} | |||
} | |||
void put(const std::string& category, const Blob& key, const Blob& value) | |||
override { | |||
KeyPair kp = {category, blob_to_str(key)}; | |||
std::unique_lock<decltype(m_mutex)> wlock; | |||
m_local_cache.insert_or_assign(kp, copy_blob(value)); | |||
PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); | |||
} | |||
}; | |||
py::class_<PersistentCache, PyPersistentCache, std::shared_ptr<PersistentCache>>( | |||
m, "PersistentCache") | |||
py::class_<PersistentCacheManager>(m, "PersistentCacheManager") | |||
.def(py::init<>()) | |||
.def("get", &PersistentCache::get) | |||
.def("put", &PersistentCache::put) | |||
.def("reg", &PersistentCache::set_impl); | |||
.def("try_open_redis", &PersistentCacheManager::open_redis) | |||
.def("try_open_file", &PersistentCacheManager::open_file) | |||
.def("clean", &PersistentCacheManager::clean) | |||
.def("put", &PersistentCacheManager::put) | |||
.def("get", &PersistentCacheManager::get); | |||
} |
@@ -1,12 +1,11 @@ | |||
import pytest | |||
import megengine | |||
from megengine.utils.persistent_cache import PersistentCacheOnServer | |||
from megengine.utils.persistent_cache import _manager | |||
@pytest.mark.skip(reason="fixme: github ci failed") | |||
def test_persistent_cache(): | |||
pc = PersistentCacheOnServer() | |||
pc = _manager | |||
k0 = b"\x00\x00" | |||
k1 = b"\x00\x01" | |||
cat = "test" | |||
@@ -0,0 +1,186 @@ | |||
/** | |||
* \file imperative/src/impl/persistent_cache.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <fstream> | |||
#include <string> | |||
#include <vector> | |||
#include "cpp_redis/cpp_redis" | |||
#include "megbrain/imperative/persistent_cache.h" | |||
#include "megbrain/imperative/utils/base64.h" | |||
#include "megbrain/utils/infile_persistent_cache.h" | |||
namespace mgb::imperative::persistent_cache { | |||
class RedisCache final : public ExtendedPersistentCache { | |||
public: | |||
RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) { | |||
m_local = std::make_shared<mgb::InMemoryPersistentCache>(); | |||
} | |||
bool connect(std::string ip, size_t port, std::string password) { | |||
m_client.auth(password); | |||
m_client.connect( | |||
ip, port, | |||
[](const std::string& host, std::size_t port, | |||
cpp_redis::connect_state status) { | |||
if (status == cpp_redis::connect_state::dropped) { | |||
mgb_log("client disconnected from %s.", host.c_str()); | |||
mgb_log("Redis server connect to %s :%zu failed.", host.c_str(), | |||
port); | |||
} | |||
}, | |||
std::uint32_t(200)); | |||
if (!m_client.is_connected()) { | |||
return false; | |||
} | |||
auto flag = m_client.get("mgb-cache-flag"); | |||
sync(); | |||
return flag.get().ok(); | |||
} | |||
bool valid() const override { return m_client.is_connected(); } | |||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
MGB_LOCK_GUARD(m_mtx); | |||
auto mem_result = m_local->get(category, key); | |||
if (mem_result.valid()) | |||
return mem_result; | |||
std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||
std::string redis_key_str; | |||
encode(category + '@' + key_str, redis_key_str, 24); | |||
auto result = m_client.get(redis_key_str); | |||
sync(); | |||
auto content = result.get(); | |||
if (content.is_null()) | |||
return mgb::None; | |||
std::string decode_content; | |||
decode(content.as_string(), decode_content); | |||
m_local->put(category, key, {decode_content.data(), decode_content.length()}); | |||
return m_local->get(category, key); | |||
} | |||
void put(const std::string& category, const Blob& key, const Blob& value) override { | |||
MGB_LOCK_GUARD(m_mtx); | |||
std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||
std::string redis_key_str; | |||
encode(category + '@' + key_str, redis_key_str); | |||
std::string value_str(static_cast<const char*>(value.ptr), value.size); | |||
std::string redis_value_str; | |||
encode(value_str, redis_value_str); | |||
auto result = m_client.set(redis_key_str, redis_value_str); | |||
m_local->put(category, key, value); | |||
sync(); | |||
} | |||
std::optional<size_t> clear() override { | |||
size_t cursor = 0, nr_deleted = 0; | |||
std::string pattern = m_prefix + "@*"; | |||
do { | |||
auto reply = m_client.scan(cursor, pattern).share(); | |||
sync(); | |||
auto keys = reply.get().as_array(); | |||
std::vector<std::string> string_keys; | |||
for (auto&& key : keys) { | |||
string_keys.push_back(key.as_string()); | |||
} | |||
m_client.del(string_keys); | |||
nr_deleted += string_keys.size(); | |||
cursor = reply.get().as_array()[0].as_integer(); | |||
} while (cursor != 0); | |||
return nr_deleted; | |||
} | |||
private: | |||
std::shared_ptr<mgb::PersistentCache> m_local; | |||
std::mutex m_mtx; | |||
cpp_redis::client m_client; | |||
std::string m_prefix; | |||
uint64_t m_timeout; | |||
void sync() { | |||
m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout)); | |||
mgb_assert(valid()); | |||
} | |||
}; | |||
class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { | |||
private: | |||
std::string m_path; | |||
std::unique_ptr<mgb::InFilePersistentCache> m_impl; | |||
public: | |||
ExtendedInFilePersistentCache() = default; | |||
bool open(std::string path) { | |||
std::fstream file; | |||
file.open(path, std::ios::in | std::ios::binary); | |||
if (!file.is_open()) { | |||
return false; | |||
} | |||
std::vector<char> bytes = { | |||
std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; | |||
if (bytes.size()) { | |||
m_impl = std::make_unique<mgb::InFilePersistentCache>( | |||
(const uint8_t*)bytes.data(), bytes.size()); | |||
} else { | |||
m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||
} | |||
m_path = path; | |||
return true; | |||
} | |||
~ExtendedInFilePersistentCache() { | |||
if (m_impl) { | |||
m_impl->dump_cache(m_path.c_str()); | |||
} | |||
} | |||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
return m_impl->get(category, key); | |||
} | |||
void put(const std::string& category, const Blob& key, const Blob& value) override { | |||
return m_impl->put(category, key, value); | |||
} | |||
std::optional<size_t> clear() override { | |||
m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||
m_impl->dump_cache(m_path.c_str()); | |||
return {}; | |||
} | |||
bool valid() const override { return m_impl != nullptr; } | |||
}; | |||
std::shared_ptr<ExtendedPersistentCache> make_redis( | |||
std::string ip, size_t port, std::string password, std::string prefix) { | |||
auto cache = std::make_shared<RedisCache>(prefix, 100); | |||
if (!cache->connect(ip, port, password)) { | |||
return nullptr; | |||
} | |||
return cache; | |||
} | |||
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) { | |||
auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||
if (!cache->open(path)) { | |||
return nullptr; | |||
} | |||
return cache; | |||
} | |||
} // namespace mgb::imperative::persistent_cache | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,172 @@ | |||
/** | |||
* \file imperative/src/impl/base64.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/utils/base64.h" | |||
namespace mgb::imperative { | |||
namespace { | |||
/* | |||
** Translation Table as described in RFC1113 | |||
*/ | |||
const char cb64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; | |||
/* | |||
** Translation Table to decode: | |||
*https://github.com/dgiardini/imgcalkap/blob/master/base64.c | |||
*/ | |||
const char cd64[] = | |||
"|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`" | |||
"abcdefghijklmnopq"; | |||
/* | |||
** encodeblock | |||
** | |||
** encode 3 8-bit binary bytes as 4 '6-bit' characters | |||
*/ | |||
void encodeblock(unsigned char in[3], unsigned char out[4], int len) { | |||
out[0] = cb64[in[0] >> 2]; | |||
out[1] = cb64[((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4)]; | |||
out[2] = | |||
(unsigned char)(len > 1 ? cb64[((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6)] : '='); | |||
out[3] = (unsigned char)(len > 2 ? cb64[in[2] & 0x3f] : '='); | |||
} | |||
/* | |||
** decodeblock | |||
** | |||
** decode 4 '6-bit' characters into 3 8-bit binary bytes | |||
*/ | |||
void decodeblock(unsigned char in[4], unsigned char out[3]) { | |||
out[0] = (unsigned char)(in[0] << 2 | in[1] >> 4); | |||
out[1] = (unsigned char)(in[1] << 4 | in[2] >> 2); | |||
out[2] = (unsigned char)(((in[2] << 6) & 0xc0) | in[3]); | |||
} | |||
} // namespace | |||
/** | |||
* Encode string to base64 string | |||
* @param input - source string | |||
* @param outdata - target base64 string | |||
* @param linesize - max size of line | |||
*/ | |||
void encode( | |||
const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata, | |||
int linesize) { | |||
outdata.clear(); | |||
unsigned char in[3], out[4]; | |||
int i, len, blocksout = 0; | |||
size_t j = 0; | |||
auto* indata = reinterpret_cast<const unsigned char*>(input.data()); | |||
unsigned int insize = input.size(); | |||
while (j <= insize) { | |||
len = 0; | |||
for (i = 0; i < 3; i++) { | |||
in[i] = (unsigned char)indata[j]; | |||
j++; | |||
if (j <= insize) { | |||
len++; | |||
} else { | |||
in[i] = 0; | |||
} | |||
} | |||
if (len) { | |||
encodeblock(in, out, len); | |||
for (i = 0; i < 4; i++) { | |||
outdata.push_back(out[i]); | |||
} | |||
blocksout++; | |||
} | |||
if (blocksout >= (linesize / 4) || (j == insize)) { | |||
if (blocksout) { | |||
outdata.push_back('\r'); | |||
outdata.push_back('\n'); | |||
} | |||
blocksout = 0; | |||
} | |||
} | |||
} | |||
/** | |||
* Decode base64 string ot source | |||
* @param input - base64 string | |||
* @param outdata - source string | |||
*/ | |||
void decode( | |||
const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata) { | |||
outdata.clear(); | |||
unsigned char in[4], out[3], v; | |||
int i, len; | |||
size_t j = 0; | |||
auto* indata = reinterpret_cast<const unsigned char*>(input.data()); | |||
unsigned int insize = input.size(); | |||
while (j <= insize) { | |||
for (len = 0, i = 0; i < 4 && (j <= insize); i++) { | |||
v = 0; | |||
while ((j <= insize) && v == 0) { | |||
v = (unsigned char)indata[j++]; | |||
v = (unsigned char)((v < 43 || v > 122) ? 0 : cd64[v - 43]); | |||
if (v) { | |||
v = (unsigned char)((v == '$') ? 0 : v - 61); | |||
} | |||
} | |||
if (j <= insize) { | |||
len++; | |||
if (v) { | |||
in[i] = (unsigned char)(v - 1); | |||
} | |||
} else { | |||
in[i] = 0; | |||
} | |||
} | |||
if (len) { | |||
decodeblock(in, out); | |||
for (i = 0; i < len - 1; i++) { | |||
outdata.push_back(out[i]); | |||
} | |||
} | |||
} | |||
} | |||
/** | |||
* Encode binary data to base64 buffer | |||
* @param input - source data | |||
* @param outdata - target base64 buffer | |||
* @param linesize | |||
*/ | |||
void encode(const std::string& input, std::string& outdata, int linesize) { | |||
std::vector<std::uint8_t> out; | |||
std::vector<std::uint8_t> in(input.begin(), input.end()); | |||
encode(in, out, linesize); | |||
outdata = std::string(out.begin(), out.end()); | |||
} | |||
/** | |||
* Decode base64 buffer to source binary data | |||
* @param input - base64 buffer | |||
* @param outdata - source binary data | |||
*/ | |||
void decode(const std::string& input, std::string& outdata) { | |||
std::vector<std::uint8_t> in(input.begin(), input.end()); | |||
std::vector<std::uint8_t> out; | |||
decode(in, out); | |||
outdata = std::string(out.begin(), out.end()); | |||
} | |||
} // namespace mgb::imperative |
@@ -0,0 +1,31 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/persistent_cache.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <memory> | |||
#include "megbrain/utils/persistent_cache.h" | |||
namespace mgb::imperative::persistent_cache { | |||
class ExtendedPersistentCache : public mgb::PersistentCache { | |||
public: | |||
virtual bool valid() const = 0; | |||
virtual std::optional<size_t> clear() = 0; | |||
}; | |||
std::shared_ptr<ExtendedPersistentCache> make_redis( | |||
std::string ip, size_t port, std::string password, std::string prefix); | |||
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path); | |||
} // namespace mgb::imperative::persistent_cache | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,50 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/base64.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megbrain/common.h" | |||
namespace mgb::imperative { | |||
/** | |||
* Encode string to base64 string | |||
* @param input - source string | |||
* @param outdata - target base64 string | |||
* @param linesize - max size of line | |||
*/ | |||
void encode( | |||
const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata, | |||
int linesize = 76); | |||
/** | |||
* Decode base64 string ot source | |||
* @param input - base64 string | |||
* @param outdata - source string | |||
*/ | |||
void decode(const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata); | |||
/** | |||
* Encode binary data to base64 buffer | |||
* @param input - source data | |||
* @param outdata - target base64 buffer | |||
* @param linesize | |||
*/ | |||
void encode(const std::string& input, std::string& outdata, int linesize = 76); | |||
/** | |||
* Decode base64 buffer to source binary data | |||
* @param input - base64 buffer | |||
* @param outdata - source binary data | |||
*/ | |||
void decode(const std::string& input, std::string& outdata); | |||
} // namespace mgb::imperative |
@@ -12,7 +12,7 @@ endif() | |||
# TODO: turn python binding into a static/object library | |||
add_executable(imperative_test ${SOURCES} ${SRCS}) | |||
add_dependencies(imperative_test mgb_opdef) | |||
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR}) | |||
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||
# Python binding | |||
target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | |||
@@ -74,14 +74,19 @@ class InMemoryPersistentCache final : public PersistentCache { | |||
}; | |||
}; | |||
Maybe<Blob> get(const std::string& category, const Blob& key) override; | |||
void put(const std::string& category, const Blob& key, const Blob& value) override; | |||
MGE_WIN_DECLSPEC_FUC Maybe<Blob> get( | |||
const std::string& category, const Blob& key) override; | |||
MGE_WIN_DECLSPEC_FUC void put( | |||
const std::string& category, const Blob& key, const Blob& value) override; | |||
std::unordered_map< | |||
std::string, | |||
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | |||
m_cache; | |||
MGB_MUTEX m_mtx; | |||
public: | |||
MGE_WIN_DECLSPEC_FUC InMemoryPersistentCache() = default; | |||
}; | |||
/*! | |||