GitOrigin-RevId: 9af7fa5c97
tags/v1.7.2.m1
@@ -877,6 +877,8 @@ if(MGE_WITH_JIT AND MGE_WITH_HALIDE) | |||||
include(cmake/Halide.cmake) | include(cmake/Halide.cmake) | ||||
endif() | endif() | ||||
include(cmake/cpp_redis.cmake) | |||||
# Thread | # Thread | ||||
IF(APPLE) | IF(APPLE) | ||||
set(CMAKE_THREAD_LIBS_INIT "-lpthread") | set(CMAKE_THREAD_LIBS_INIT "-lpthread") | ||||
@@ -0,0 +1,2 @@ | |||||
file(GLOB_RECURSE CPP_REDIS_SRCS ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/sources/*.cpp ${PROJECT_SOURCE_DIR}/third_party/tacopie/sources/*.cpp) | |||||
set(CPP_REDIS_INCLUDES ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/includes ${PROJECT_SOURCE_DIR}/third_party/tacopie/includes) |
@@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) | |||||
set(MODULE_NAME _imperative_rt) | set(MODULE_NAME _imperative_rt) | ||||
set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) | set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) | ||||
file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) | file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) | ||||
set(SRCS ${SRCS} ${CPP_REDIS_SRCS}) | |||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | ||||
@@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3) | |||||
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) | add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) | ||||
target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) | target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) | ||||
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) | |||||
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||||
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | ||||
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | ||||
if(CXX_SUPPORT_WCLASS_MEMACCESS) | if(CXX_SUPPORT_WCLASS_MEMACCESS) | ||||
@@ -92,9 +92,6 @@ _set_fork_exec_path_for_timed_func( | |||||
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | ||||
) | ) | ||||
_persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer() | |||||
_persistent_cache_impl_ins.reg() | |||||
atexit.register(_close) | atexit.register(_close) | ||||
del _set_fork_exec_path_for_timed_func | del _set_fork_exec_path_for_timed_func | ||||
@@ -135,3 +132,5 @@ import megengine.quantization | |||||
import megengine.random | import megengine.random | ||||
import megengine.utils | import megengine.utils | ||||
import megengine.traced_module | import megengine.traced_module | ||||
persistent_cache.get_manager() |
@@ -9,108 +9,54 @@ | |||||
import argparse | import argparse | ||||
import getpass | import getpass | ||||
import json | |||||
import os | import os | ||||
import shelve | |||||
import sys | |||||
from ..core._imperative_rt import PersistentCache as _PersistentCache | |||||
from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager | |||||
from ..logger import get_logger | from ..logger import get_logger | ||||
from ..version import __version__, git_version | from ..version import __version__, git_version | ||||
class _FakeRedisConn: | |||||
_cache_dir = None | |||||
_is_shelve = False | |||||
_dict = {} | |||||
class PersistentCacheManager(_PersistentCacheManager): | |||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | |||||
if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | ||||
self._dict = {} | |||||
self._is_shelve = False | |||||
get_logger().info("fastrun use in-memory cache") | get_logger().info("fastrun use in-memory cache") | ||||
self.open_memory() | |||||
else: | else: | ||||
try: | |||||
self._cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") | |||||
if not self._cache_dir: | |||||
from ..hub.hub import _get_megengine_home | |||||
self._cache_dir = os.path.expanduser( | |||||
os.path.join(_get_megengine_home(), "persistent_cache") | |||||
) | |||||
os.makedirs(self._cache_dir, exist_ok=True) | |||||
cache_file = os.path.join(self._cache_dir, "cache") | |||||
self._dict = shelve.open(cache_file) | |||||
self._is_shelve = True | |||||
get_logger().info( | |||||
"fastrun use in-file cache in {}".format(self._cache_dir) | |||||
) | |||||
except Exception as exc: | |||||
self._dict = {} | |||||
self._is_shelve = False | |||||
get_logger().error( | |||||
"failed to create cache file in {} {!r}; fallback to " | |||||
"in-memory cache".format(self._cache_dir, exc) | |||||
) | |||||
def get(self, key): | |||||
if self._is_shelve and isinstance(key, bytes): | |||||
key = key.decode("utf-8") | |||||
return self._dict.get(key) | |||||
def set(self, key, val): | |||||
if self._is_shelve and isinstance(key, bytes): | |||||
key = key.decode("utf-8") | |||||
self._dict[key] = val | |||||
def clear(self): | |||||
print("{} cache item deleted in {}".format(len(self._dict), self._cache_dir)) | |||||
self._dict.clear() | |||||
self.open_file() | |||||
def __del__(self): | |||||
if self._is_shelve: | |||||
self._dict.close() | |||||
def open_memory(self): | |||||
pass | |||||
def open_file(self): | |||||
cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") | |||||
try: | |||||
if not cache_dir: | |||||
from ..hub.hub import _get_megengine_home | |||||
class PersistentCacheOnServer(_PersistentCache): | |||||
_cached_conn = None | |||||
_prefix = None | |||||
_prev_get_refkeep = None | |||||
@property | |||||
def _conn(self): | |||||
"""get redis connection""" | |||||
if self._cached_conn is None: | |||||
self._cached_conn = _FakeRedisConn() | |||||
self._prefix = self.make_user_prefix() | |||||
return self._cached_conn | |||||
@classmethod | |||||
def make_user_prefix(cls): | |||||
return "mgbcache:{}".format(getpass.getuser()) | |||||
def _make_key(self, category, key): | |||||
prefix_with_version = "{}:MGB{}:GIT:{}".format( | |||||
self._prefix, __version__, git_version | |||||
) | |||||
return b"@".join( | |||||
(prefix_with_version.encode("ascii"), category.encode("ascii"), key) | |||||
) | |||||
def put(self, category, key, value): | |||||
conn = self._conn | |||||
key = self._make_key(category, key) | |||||
conn.set(key, value) | |||||
def get(self, category, key): | |||||
conn = self._conn | |||||
key = self._make_key(category, key) | |||||
self._prev_get_refkeep = conn.get(key) | |||||
return self._prev_get_refkeep | |||||
def clean(self): | |||||
conn = self._conn | |||||
if isinstance(conn, _FakeRedisConn): | |||||
conn.clear() | |||||
cache_dir = os.path.expanduser( | |||||
os.path.join(_get_megengine_home(), "persistent_cache.bin") | |||||
) | |||||
os.makedirs(cache_dir, exist_ok=True) | |||||
cache_file = os.path.join(cache_dir, "cache") | |||||
with open(cache_file, "a"): | |||||
pass | |||||
assert self.try_open_file(cache_file), "cannot create file" | |||||
get_logger().info("fastrun use in-file cache in {}".format(cache_dir)) | |||||
except Exception as exc: | |||||
get_logger().error( | |||||
"failed to create cache file in {} {!r}; fallback to " | |||||
"in-memory cache".format(cache_dir, exc) | |||||
) | |||||
self.open_memory() | |||||
_manager = None | |||||
def get_manager(): | |||||
global _manager | |||||
if _manager is None: | |||||
_manager = PersistentCacheManager() | |||||
return _manager |
@@ -23,6 +23,7 @@ | |||||
#include "megbrain/common.h" | #include "megbrain/common.h" | ||||
#include "megbrain/comp_node.h" | #include "megbrain/comp_node.h" | ||||
#include "megbrain/imperative/blob_manager.h" | #include "megbrain/imperative/blob_manager.h" | ||||
#include "megbrain/imperative/persistent_cache.h" | |||||
#include "megbrain/imperative/profiler.h" | #include "megbrain/imperative/profiler.h" | ||||
#include "megbrain/imperative/tensor_sanity_check.h" | #include "megbrain/imperative/tensor_sanity_check.h" | ||||
#include "megbrain/serialization/helper.h" | #include "megbrain/serialization/helper.h" | ||||
@@ -229,83 +230,55 @@ void init_utils(py::module m) { | |||||
mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); | mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); | ||||
}); | }); | ||||
using mgb::PersistentCache; | |||||
class PyPersistentCache : public mgb::PersistentCache { | |||||
private: | |||||
using KeyPair = std::pair<std::string, std::string>; | |||||
using BlobPtr = std::unique_ptr<Blob, void (*)(Blob*)>; | |||||
using PersistentCache = mgb::PersistentCache; | |||||
using ExtendedPersistentCache = | |||||
mgb::imperative::persistent_cache::ExtendedPersistentCache; | |||||
std::shared_mutex m_mutex; | |||||
std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache; | |||||
struct PersistentCacheManager { | |||||
std::shared_ptr<ExtendedPersistentCache> instance; | |||||
static size_t hash_key_pair(const KeyPair& kp) { | |||||
std::hash<std::string> hasher; | |||||
return hasher(kp.first) ^ hasher(kp.second); | |||||
bool try_reg(std::shared_ptr<ExtendedPersistentCache> cache) { | |||||
if (cache) { | |||||
instance = cache; | |||||
PersistentCache::set_impl(cache); | |||||
return true; | |||||
} | |||||
return false; | |||||
} | } | ||||
std::string blob_to_str(const Blob& key) { | |||||
return std::string(reinterpret_cast<const char*>(key.ptr), key.size); | |||||
bool open_redis( | |||||
std::string ip, size_t port, std::string password, std::string prefix) { | |||||
return try_reg(mgb::imperative::persistent_cache::make_redis( | |||||
ip, port, password, prefix)); | |||||
} | } | ||||
BlobPtr copy_blob(const Blob& blob) { | |||||
auto blob_deleter = [](Blob* blob) { | |||||
if (blob) { | |||||
std::free(const_cast<void*>(blob->ptr)); | |||||
delete blob; | |||||
} | |||||
}; | |||||
auto blob_ptr = BlobPtr{new Blob(), blob_deleter}; | |||||
blob_ptr->ptr = std::malloc(blob.size); | |||||
std::memcpy(const_cast<void*>(blob_ptr->ptr), blob.ptr, blob.size); | |||||
blob_ptr->size = blob.size; | |||||
return blob_ptr; | |||||
bool open_file(std::string path) { | |||||
return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); | |||||
} | } | ||||
BlobPtr str_to_blob(const std::string& str) { | |||||
auto blob = Blob{str.data(), str.size()}; | |||||
return copy_blob(blob); | |||||
std::optional<size_t> clean() { | |||||
if (instance) { | |||||
return instance->clear(); | |||||
} | |||||
return {}; | |||||
} | } | ||||
std::unique_ptr<Blob, void (*)(Blob*)> empty_blob() { | |||||
return BlobPtr{nullptr, [](Blob* blob) {}}; | |||||
void put(std::string category, std::string key, std::string value) { | |||||
PersistentCache::inst().put( | |||||
category, {key.data(), key.size()}, {value.data(), value.size()}); | |||||
} | } | ||||
public: | |||||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||||
auto py_get = [this](const std::string& category, | |||||
const Blob& key) -> mgb::Maybe<Blob> { | |||||
PYBIND11_OVERLOAD_PURE( | |||||
mgb::Maybe<Blob>, PersistentCache, get, category, key); | |||||
}; | |||||
KeyPair kp = {category, blob_to_str(key)}; | |||||
std::shared_lock<decltype(m_mutex)> rlock; | |||||
auto iter = m_local_cache.find(kp); | |||||
if (iter == m_local_cache.end()) { | |||||
auto py_ret = py_get(category, key); | |||||
if (!py_ret.valid()) { | |||||
iter = m_local_cache.insert({kp, empty_blob()}).first; | |||||
} else { | |||||
iter = m_local_cache.insert({kp, copy_blob(py_ret.val())}).first; | |||||
} | |||||
} | |||||
if (iter->second) { | |||||
return *iter->second; | |||||
py::object get(std::string category, std::string key) { | |||||
auto value = | |||||
PersistentCache::inst().get(category, {key.data(), key.size()}); | |||||
if (value.valid()) { | |||||
return py::bytes(std::string((const char*)value->ptr, value->size)); | |||||
} else { | } else { | ||||
return {}; | |||||
return py::none(); | |||||
} | } | ||||
} | } | ||||
void put(const std::string& category, const Blob& key, const Blob& value) | |||||
override { | |||||
KeyPair kp = {category, blob_to_str(key)}; | |||||
std::unique_lock<decltype(m_mutex)> wlock; | |||||
m_local_cache.insert_or_assign(kp, copy_blob(value)); | |||||
PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); | |||||
} | |||||
}; | }; | ||||
py::class_<PersistentCache, PyPersistentCache, std::shared_ptr<PersistentCache>>( | |||||
m, "PersistentCache") | |||||
py::class_<PersistentCacheManager>(m, "PersistentCacheManager") | |||||
.def(py::init<>()) | .def(py::init<>()) | ||||
.def("get", &PersistentCache::get) | |||||
.def("put", &PersistentCache::put) | |||||
.def("reg", &PersistentCache::set_impl); | |||||
.def("try_open_redis", &PersistentCacheManager::open_redis) | |||||
.def("try_open_file", &PersistentCacheManager::open_file) | |||||
.def("clean", &PersistentCacheManager::clean) | |||||
.def("put", &PersistentCacheManager::put) | |||||
.def("get", &PersistentCacheManager::get); | |||||
} | } |
@@ -1,12 +1,11 @@ | |||||
import pytest | import pytest | ||||
import megengine | |||||
from megengine.utils.persistent_cache import PersistentCacheOnServer | |||||
from megengine.utils.persistent_cache import _manager | |||||
@pytest.mark.skip(reason="fixme: github ci failed") | @pytest.mark.skip(reason="fixme: github ci failed") | ||||
def test_persistent_cache(): | def test_persistent_cache(): | ||||
pc = PersistentCacheOnServer() | |||||
pc = _manager | |||||
k0 = b"\x00\x00" | k0 = b"\x00\x00" | ||||
k1 = b"\x00\x01" | k1 = b"\x00\x01" | ||||
cat = "test" | cat = "test" | ||||
@@ -0,0 +1,186 @@ | |||||
/** | |||||
* \file imperative/src/impl/persistent_cache.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include <fstream> | |||||
#include <string> | |||||
#include <vector> | |||||
#include "cpp_redis/cpp_redis" | |||||
#include "megbrain/imperative/persistent_cache.h" | |||||
#include "megbrain/imperative/utils/base64.h" | |||||
#include "megbrain/utils/infile_persistent_cache.h" | |||||
namespace mgb::imperative::persistent_cache { | |||||
class RedisCache final : public ExtendedPersistentCache { | |||||
public: | |||||
RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) { | |||||
m_local = std::make_shared<mgb::InMemoryPersistentCache>(); | |||||
} | |||||
bool connect(std::string ip, size_t port, std::string password) { | |||||
m_client.auth(password); | |||||
m_client.connect( | |||||
ip, port, | |||||
[](const std::string& host, std::size_t port, | |||||
cpp_redis::connect_state status) { | |||||
if (status == cpp_redis::connect_state::dropped) { | |||||
mgb_log("client disconnected from %s.", host.c_str()); | |||||
mgb_log("Redis server connect to %s :%zu failed.", host.c_str(), | |||||
port); | |||||
} | |||||
}, | |||||
std::uint32_t(200)); | |||||
if (!m_client.is_connected()) { | |||||
return false; | |||||
} | |||||
auto flag = m_client.get("mgb-cache-flag"); | |||||
sync(); | |||||
return flag.get().ok(); | |||||
} | |||||
bool valid() const override { return m_client.is_connected(); } | |||||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||||
MGB_LOCK_GUARD(m_mtx); | |||||
auto mem_result = m_local->get(category, key); | |||||
if (mem_result.valid()) | |||||
return mem_result; | |||||
std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||||
std::string redis_key_str; | |||||
encode(category + '@' + key_str, redis_key_str, 24); | |||||
auto result = m_client.get(redis_key_str); | |||||
sync(); | |||||
auto content = result.get(); | |||||
if (content.is_null()) | |||||
return mgb::None; | |||||
std::string decode_content; | |||||
decode(content.as_string(), decode_content); | |||||
m_local->put(category, key, {decode_content.data(), decode_content.length()}); | |||||
return m_local->get(category, key); | |||||
} | |||||
void put(const std::string& category, const Blob& key, const Blob& value) override { | |||||
MGB_LOCK_GUARD(m_mtx); | |||||
std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||||
std::string redis_key_str; | |||||
encode(category + '@' + key_str, redis_key_str); | |||||
std::string value_str(static_cast<const char*>(value.ptr), value.size); | |||||
std::string redis_value_str; | |||||
encode(value_str, redis_value_str); | |||||
auto result = m_client.set(redis_key_str, redis_value_str); | |||||
m_local->put(category, key, value); | |||||
sync(); | |||||
} | |||||
std::optional<size_t> clear() override { | |||||
size_t cursor = 0, nr_deleted = 0; | |||||
std::string pattern = m_prefix + "@*"; | |||||
do { | |||||
auto reply = m_client.scan(cursor, pattern).share(); | |||||
sync(); | |||||
auto keys = reply.get().as_array(); | |||||
std::vector<std::string> string_keys; | |||||
for (auto&& key : keys) { | |||||
string_keys.push_back(key.as_string()); | |||||
} | |||||
m_client.del(string_keys); | |||||
nr_deleted += string_keys.size(); | |||||
cursor = reply.get().as_array()[0].as_integer(); | |||||
} while (cursor != 0); | |||||
return nr_deleted; | |||||
} | |||||
private: | |||||
std::shared_ptr<mgb::PersistentCache> m_local; | |||||
std::mutex m_mtx; | |||||
cpp_redis::client m_client; | |||||
std::string m_prefix; | |||||
uint64_t m_timeout; | |||||
void sync() { | |||||
m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout)); | |||||
mgb_assert(valid()); | |||||
} | |||||
}; | |||||
class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { | |||||
private: | |||||
std::string m_path; | |||||
std::unique_ptr<mgb::InFilePersistentCache> m_impl; | |||||
public: | |||||
ExtendedInFilePersistentCache() = default; | |||||
bool open(std::string path) { | |||||
std::fstream file; | |||||
file.open(path, std::ios::in | std::ios::binary); | |||||
if (!file.is_open()) { | |||||
return false; | |||||
} | |||||
std::vector<char> bytes = { | |||||
std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; | |||||
if (bytes.size()) { | |||||
m_impl = std::make_unique<mgb::InFilePersistentCache>( | |||||
(const uint8_t*)bytes.data(), bytes.size()); | |||||
} else { | |||||
m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||||
} | |||||
m_path = path; | |||||
return true; | |||||
} | |||||
~ExtendedInFilePersistentCache() { | |||||
if (m_impl) { | |||||
m_impl->dump_cache(m_path.c_str()); | |||||
} | |||||
} | |||||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||||
return m_impl->get(category, key); | |||||
} | |||||
void put(const std::string& category, const Blob& key, const Blob& value) override { | |||||
return m_impl->put(category, key, value); | |||||
} | |||||
std::optional<size_t> clear() override { | |||||
m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||||
m_impl->dump_cache(m_path.c_str()); | |||||
return {}; | |||||
} | |||||
bool valid() const override { return m_impl != nullptr; } | |||||
}; | |||||
std::shared_ptr<ExtendedPersistentCache> make_redis( | |||||
std::string ip, size_t port, std::string password, std::string prefix) { | |||||
auto cache = std::make_shared<RedisCache>(prefix, 100); | |||||
if (!cache->connect(ip, port, password)) { | |||||
return nullptr; | |||||
} | |||||
return cache; | |||||
} | |||||
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) { | |||||
auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||||
if (!cache->open(path)) { | |||||
return nullptr; | |||||
} | |||||
return cache; | |||||
} | |||||
} // namespace mgb::imperative::persistent_cache | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,172 @@ | |||||
/** | |||||
* \file imperative/src/impl/base64.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/imperative/utils/base64.h" | |||||
namespace mgb::imperative { | |||||
namespace { | |||||
/* | |||||
** Translation Table as described in RFC1113 | |||||
*/ | |||||
const char cb64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; | |||||
/* | |||||
** Translation Table to decode: | |||||
*https://github.com/dgiardini/imgcalkap/blob/master/base64.c | |||||
*/ | |||||
const char cd64[] = | |||||
"|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`" | |||||
"abcdefghijklmnopq"; | |||||
/* | |||||
** encodeblock | |||||
** | |||||
** encode 3 8-bit binary bytes as 4 '6-bit' characters | |||||
*/ | |||||
void encodeblock(unsigned char in[3], unsigned char out[4], int len) { | |||||
out[0] = cb64[in[0] >> 2]; | |||||
out[1] = cb64[((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4)]; | |||||
out[2] = | |||||
(unsigned char)(len > 1 ? cb64[((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6)] : '='); | |||||
out[3] = (unsigned char)(len > 2 ? cb64[in[2] & 0x3f] : '='); | |||||
} | |||||
/* | |||||
** decodeblock | |||||
** | |||||
** decode 4 '6-bit' characters into 3 8-bit binary bytes | |||||
*/ | |||||
void decodeblock(unsigned char in[4], unsigned char out[3]) { | |||||
out[0] = (unsigned char)(in[0] << 2 | in[1] >> 4); | |||||
out[1] = (unsigned char)(in[1] << 4 | in[2] >> 2); | |||||
out[2] = (unsigned char)(((in[2] << 6) & 0xc0) | in[3]); | |||||
} | |||||
} // namespace | |||||
/** | |||||
* Encode string to base64 string | |||||
* @param input - source string | |||||
* @param outdata - target base64 string | |||||
* @param linesize - max size of line | |||||
*/ | |||||
void encode( | |||||
const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata, | |||||
int linesize) { | |||||
outdata.clear(); | |||||
unsigned char in[3], out[4]; | |||||
int i, len, blocksout = 0; | |||||
size_t j = 0; | |||||
auto* indata = reinterpret_cast<const unsigned char*>(input.data()); | |||||
unsigned int insize = input.size(); | |||||
while (j <= insize) { | |||||
len = 0; | |||||
for (i = 0; i < 3; i++) { | |||||
in[i] = (unsigned char)indata[j]; | |||||
j++; | |||||
if (j <= insize) { | |||||
len++; | |||||
} else { | |||||
in[i] = 0; | |||||
} | |||||
} | |||||
if (len) { | |||||
encodeblock(in, out, len); | |||||
for (i = 0; i < 4; i++) { | |||||
outdata.push_back(out[i]); | |||||
} | |||||
blocksout++; | |||||
} | |||||
if (blocksout >= (linesize / 4) || (j == insize)) { | |||||
if (blocksout) { | |||||
outdata.push_back('\r'); | |||||
outdata.push_back('\n'); | |||||
} | |||||
blocksout = 0; | |||||
} | |||||
} | |||||
} | |||||
/** | |||||
* Decode base64 string ot source | |||||
* @param input - base64 string | |||||
* @param outdata - source string | |||||
*/ | |||||
void decode( | |||||
const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata) { | |||||
outdata.clear(); | |||||
unsigned char in[4], out[3], v; | |||||
int i, len; | |||||
size_t j = 0; | |||||
auto* indata = reinterpret_cast<const unsigned char*>(input.data()); | |||||
unsigned int insize = input.size(); | |||||
while (j <= insize) { | |||||
for (len = 0, i = 0; i < 4 && (j <= insize); i++) { | |||||
v = 0; | |||||
while ((j <= insize) && v == 0) { | |||||
v = (unsigned char)indata[j++]; | |||||
v = (unsigned char)((v < 43 || v > 122) ? 0 : cd64[v - 43]); | |||||
if (v) { | |||||
v = (unsigned char)((v == '$') ? 0 : v - 61); | |||||
} | |||||
} | |||||
if (j <= insize) { | |||||
len++; | |||||
if (v) { | |||||
in[i] = (unsigned char)(v - 1); | |||||
} | |||||
} else { | |||||
in[i] = 0; | |||||
} | |||||
} | |||||
if (len) { | |||||
decodeblock(in, out); | |||||
for (i = 0; i < len - 1; i++) { | |||||
outdata.push_back(out[i]); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
/** | |||||
* Encode binary data to base64 buffer | |||||
* @param input - source data | |||||
* @param outdata - target base64 buffer | |||||
* @param linesize | |||||
*/ | |||||
void encode(const std::string& input, std::string& outdata, int linesize) { | |||||
std::vector<std::uint8_t> out; | |||||
std::vector<std::uint8_t> in(input.begin(), input.end()); | |||||
encode(in, out, linesize); | |||||
outdata = std::string(out.begin(), out.end()); | |||||
} | |||||
/** | |||||
* Decode base64 buffer to source binary data | |||||
* @param input - base64 buffer | |||||
* @param outdata - source binary data | |||||
*/ | |||||
void decode(const std::string& input, std::string& outdata) { | |||||
std::vector<std::uint8_t> in(input.begin(), input.end()); | |||||
std::vector<std::uint8_t> out; | |||||
decode(in, out); | |||||
outdata = std::string(out.begin(), out.end()); | |||||
} | |||||
} // namespace mgb::imperative |
@@ -0,0 +1,31 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/persistent_cache.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <memory> | |||||
#include "megbrain/utils/persistent_cache.h" | |||||
namespace mgb::imperative::persistent_cache { | |||||
class ExtendedPersistentCache : public mgb::PersistentCache { | |||||
public: | |||||
virtual bool valid() const = 0; | |||||
virtual std::optional<size_t> clear() = 0; | |||||
}; | |||||
std::shared_ptr<ExtendedPersistentCache> make_redis( | |||||
std::string ip, size_t port, std::string password, std::string prefix); | |||||
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path); | |||||
} // namespace mgb::imperative::persistent_cache | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,50 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/base64.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/common.h" | |||||
namespace mgb::imperative { | |||||
/** | |||||
* Encode string to base64 string | |||||
* @param input - source string | |||||
* @param outdata - target base64 string | |||||
* @param linesize - max size of line | |||||
*/ | |||||
void encode( | |||||
const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata, | |||||
int linesize = 76); | |||||
/** | |||||
* Decode base64 string ot source | |||||
* @param input - base64 string | |||||
* @param outdata - source string | |||||
*/ | |||||
void decode(const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata); | |||||
/** | |||||
* Encode binary data to base64 buffer | |||||
* @param input - source data | |||||
* @param outdata - target base64 buffer | |||||
* @param linesize | |||||
*/ | |||||
void encode(const std::string& input, std::string& outdata, int linesize = 76); | |||||
/** | |||||
* Decode base64 buffer to source binary data | |||||
* @param input - base64 buffer | |||||
* @param outdata - source binary data | |||||
*/ | |||||
void decode(const std::string& input, std::string& outdata); | |||||
} // namespace mgb::imperative |
@@ -12,7 +12,7 @@ endif() | |||||
# TODO: turn python binding into a static/object library | # TODO: turn python binding into a static/object library | ||||
add_executable(imperative_test ${SOURCES} ${SRCS}) | add_executable(imperative_test ${SOURCES} ${SRCS}) | ||||
add_dependencies(imperative_test mgb_opdef) | add_dependencies(imperative_test mgb_opdef) | ||||
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR}) | |||||
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||||
# Python binding | # Python binding | ||||
target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | ||||
@@ -74,14 +74,19 @@ class InMemoryPersistentCache final : public PersistentCache { | |||||
}; | }; | ||||
}; | }; | ||||
Maybe<Blob> get(const std::string& category, const Blob& key) override; | |||||
void put(const std::string& category, const Blob& key, const Blob& value) override; | |||||
MGE_WIN_DECLSPEC_FUC Maybe<Blob> get( | |||||
const std::string& category, const Blob& key) override; | |||||
MGE_WIN_DECLSPEC_FUC void put( | |||||
const std::string& category, const Blob& key, const Blob& value) override; | |||||
std::unordered_map< | std::unordered_map< | ||||
std::string, | std::string, | ||||
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | ||||
m_cache; | m_cache; | ||||
MGB_MUTEX m_mtx; | MGB_MUTEX m_mtx; | ||||
public: | |||||
MGE_WIN_DECLSPEC_FUC InMemoryPersistentCache() = default; | |||||
}; | }; | ||||
/*! | /*! | ||||