@@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) | |||
set(MODULE_NAME _imperative_rt) | |||
set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) | |||
file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) | |||
set(SRCS ${SRCS} ${CPP_REDIS_SRCS}) | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | |||
@@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3) | |||
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) | |||
target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) | |||
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) | |||
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | |||
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | |||
if(CXX_SUPPORT_WCLASS_MEMACCESS) | |||
@@ -11,6 +11,7 @@ import argparse | |||
import getpass | |||
import os | |||
import sys | |||
import urllib.parse | |||
from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager | |||
from ..logger import get_logger | |||
@@ -23,8 +24,10 @@ class PersistentCacheManager(_PersistentCacheManager): | |||
if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | |||
get_logger().info("fastrun use in-memory cache") | |||
self.open_memory() | |||
else: | |||
elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE": | |||
self.open_file() | |||
else: | |||
self.open_redis() | |||
def open_memory(self): | |||
pass | |||
@@ -51,6 +54,28 @@ class PersistentCacheManager(_PersistentCacheManager): | |||
) | |||
self.open_memory() | |||
def open_redis(self): | |||
prefix = "mgbcache:{}:MGB{}:GIT:{}".format( | |||
getpass.getuser(), __version__, git_version | |||
) | |||
url = os.getenv("MGE_FASTRUN_CACHE_URL") | |||
if url is None: | |||
self.open_file() | |||
try: | |||
assert sys.platform != "win32", "redis cache on windows not tested" | |||
parse_result = urllib.parse.urlparse(url, scheme="redis") | |||
assert parse_result.scheme == "redis", "unsupported scheme" | |||
assert not parse_result.username, "redis conn with username unsupported" | |||
assert self.try_open_redis( | |||
parse_result.hostname, parse_result.port, parse_result.password, prefix | |||
), "connect failed" | |||
except Exception as exc: | |||
get_logger().error( | |||
"failed to connect to cache server {!r}; try fallback to " | |||
"in-file cache".format(exc) | |||
) | |||
self.open_file() | |||
_manager = None | |||
@@ -60,3 +85,23 @@ def get_manager(): | |||
if _manager is None: | |||
_manager = PersistentCacheManager() | |||
return _manager | |||
def _clean(): | |||
nr_del = get_manager().clean() | |||
if nr_del is not None: | |||
print("{} cache entries deleted".format(nr_del)) | |||
def main(): | |||
parser = argparse.ArgumentParser(description="manage persistent cache") | |||
subp = parser.add_subparsers(description="action to be performed", dest="cmd") | |||
subp.required = True | |||
subp_clean = subp.add_parser("clean", help="clean all the cache of current user") | |||
subp_clean.set_defaults(action=_clean) | |||
args = parser.parse_args() | |||
args.action() | |||
if __name__ == "__main__": | |||
main() |
@@ -245,6 +245,11 @@ void init_utils(py::module m) { | |||
} | |||
return false; | |||
} | |||
bool open_redis( | |||
std::string ip, size_t port, std::string password, std::string prefix) { | |||
return try_reg(mgb::imperative::persistent_cache::make_redis( | |||
ip, port, password, prefix)); | |||
} | |||
bool open_file(std::string path) { | |||
return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); | |||
} | |||
@@ -271,6 +276,7 @@ void init_utils(py::module m) { | |||
py::class_<PersistentCacheManager>(m, "PersistentCacheManager") | |||
.def(py::init<>()) | |||
.def("try_open_redis", &PersistentCacheManager::open_redis) | |||
.def("try_open_file", &PersistentCacheManager::open_file) | |||
.def("clean", &PersistentCacheManager::clean) | |||
.def("put", &PersistentCacheManager::put) | |||
@@ -13,12 +13,109 @@ | |||
#include <string> | |||
#include <vector> | |||
#include "cpp_redis/cpp_redis" | |||
#include "megbrain/imperative/persistent_cache.h" | |||
#include "megbrain/imperative/utils/base64.h" | |||
#include "megbrain/utils/infile_persistent_cache.h" | |||
namespace mgb::imperative::persistent_cache { | |||
class RedisCache final : public ExtendedPersistentCache { | |||
public: | |||
RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) { | |||
m_local = std::make_shared<mgb::InMemoryPersistentCache>(); | |||
} | |||
bool connect(std::string ip, size_t port, std::string password) { | |||
m_client.auth(password); | |||
m_client.connect( | |||
ip, port, | |||
[](const std::string& host, std::size_t port, | |||
cpp_redis::connect_state status) { | |||
if (status == cpp_redis::connect_state::dropped) { | |||
mgb_log("client disconnected from %s.", host.c_str()); | |||
mgb_log("Redis server connect to %s :%zu failed.", host.c_str(), | |||
port); | |||
} | |||
}, | |||
std::uint32_t(200)); | |||
if (!m_client.is_connected()) { | |||
return false; | |||
} | |||
auto flag = m_client.get("mgb-cache-flag"); | |||
sync(); | |||
return flag.get().ok(); | |||
} | |||
bool valid() const override { return m_client.is_connected(); } | |||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
MGB_LOCK_GUARD(m_mtx); | |||
auto mem_result = m_local->get(category, key); | |||
if (mem_result.valid()) | |||
return mem_result; | |||
std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||
std::string redis_key_str; | |||
encode(category + '@' + key_str, redis_key_str, 24); | |||
auto result = m_client.get(redis_key_str); | |||
sync(); | |||
auto content = result.get(); | |||
if (content.is_null()) | |||
return mgb::None; | |||
std::string decode_content; | |||
decode(content.as_string(), decode_content); | |||
m_local->put(category, key, {decode_content.data(), decode_content.length()}); | |||
return m_local->get(category, key); | |||
} | |||
void put(const std::string& category, const Blob& key, const Blob& value) override { | |||
MGB_LOCK_GUARD(m_mtx); | |||
std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||
std::string redis_key_str; | |||
encode(category + '@' + key_str, redis_key_str); | |||
std::string value_str(static_cast<const char*>(value.ptr), value.size); | |||
std::string redis_value_str; | |||
encode(value_str, redis_value_str); | |||
auto result = m_client.set(redis_key_str, redis_value_str); | |||
m_local->put(category, key, value); | |||
sync(); | |||
} | |||
std::optional<size_t> clear() override { | |||
size_t cursor = 0, nr_deleted = 0; | |||
std::string pattern = m_prefix + "@*"; | |||
do { | |||
auto reply = m_client.scan(cursor, pattern).share(); | |||
sync(); | |||
auto keys = reply.get().as_array(); | |||
std::vector<std::string> string_keys; | |||
for (auto&& key : keys) { | |||
string_keys.push_back(key.as_string()); | |||
} | |||
m_client.del(string_keys); | |||
nr_deleted += string_keys.size(); | |||
cursor = reply.get().as_array()[0].as_integer(); | |||
} while (cursor != 0); | |||
return nr_deleted; | |||
} | |||
private: | |||
std::shared_ptr<mgb::PersistentCache> m_local; | |||
std::mutex m_mtx; | |||
cpp_redis::client m_client; | |||
std::string m_prefix; | |||
uint64_t m_timeout; | |||
void sync() { | |||
m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout)); | |||
mgb_assert(valid()); | |||
} | |||
}; | |||
class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { | |||
private: | |||
std::string m_path; | |||
@@ -68,6 +165,15 @@ public: | |||
bool valid() const override { return m_impl != nullptr; } | |||
}; | |||
std::shared_ptr<ExtendedPersistentCache> make_redis( | |||
std::string ip, size_t port, std::string password, std::string prefix) { | |||
auto cache = std::make_shared<RedisCache>(prefix, 100); | |||
if (!cache->connect(ip, port, password)) { | |||
return nullptr; | |||
} | |||
return cache; | |||
} | |||
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) { | |||
auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||
if (!cache->open(path)) { | |||
@@ -22,6 +22,9 @@ public: | |||
virtual std::optional<size_t> clear() = 0; | |||
}; | |||
std::shared_ptr<ExtendedPersistentCache> make_redis( | |||
std::string ip, size_t port, std::string password, std::string prefix); | |||
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path); | |||
} // namespace mgb::imperative::persistent_cache | |||
@@ -12,7 +12,7 @@ endif() | |||
# TODO: turn python binding into a static/object library | |||
add_executable(imperative_test ${SOURCES} ${SRCS}) | |||
add_dependencies(imperative_test mgb_opdef) | |||
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||
# Python binding | |||
target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | |||
@@ -35,6 +35,10 @@ configure_file(src/lite_build_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/genfiles/l | |||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/genfiles/lite_build_config.h DESTINATION ${CMAKE_INSTALL_PREFIX}/lite/include) | |||
# begin config lite | |||
if(LITE_BUILD_WITH_MGE AND LITE_WITH_CUDA AND NOT WIN32) | |||
# FXIME third_party cpp redis do not support build with clang-cl | |||
list(APPEND SOURCES_LITE ${CPP_REDIS_SRCS}) | |||
endif() | |||
add_library(lite_static STATIC ${SOURCES_LITE}) | |||
add_dependencies(lite_static lite_fbs_generate) | |||
include_directories($<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/genfiles>) | |||
@@ -106,6 +110,14 @@ endif() | |||
if(LITE_BUILD_WITH_MGE) | |||
target_link_libraries(lite_static_all_in_one PRIVATE megbrain megdnn ${MGE_CUDA_LIBS}) | |||
endif() | |||
if(LITE_BUILD_WITH_MGE AND LITE_WITH_CUDA AND NOT WIN32) | |||
# FXIME third_party cpp redis do not support build with clang-cl | |||
target_include_directories(lite_static PRIVATE ${CPP_REDIS_INCLUDES}) | |||
target_include_directories(lite_shared PRIVATE ${CPP_REDIS_INCLUDES}) | |||
target_include_directories(lite_shared_whl PRIVATE ${CPP_REDIS_INCLUDES}) | |||
target_include_directories(lite_static_all_in_one PRIVATE ${CPP_REDIS_INCLUDES}) | |||
endif() | |||
set(LITE_VERSION_SCRIPT ${PROJECT_SOURCE_DIR}/lite/src/version_lite.ld CACHE INTERNAL "Path to linker version script") | |||
add_custom_target(_lite_version_ld SOURCES ${LITE_VERSION_SCRIPT}) | |||
if(NOT MSVC AND NOT WIN32) | |||