@@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) | |||||
set(MODULE_NAME _imperative_rt) | set(MODULE_NAME _imperative_rt) | ||||
set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) | set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) | ||||
file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) | file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) | ||||
set(SRCS ${SRCS} ${CPP_REDIS_SRCS}) | |||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | ||||
@@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3) | |||||
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) | add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) | ||||
target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) | target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) | ||||
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) | |||||
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||||
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | ||||
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | ||||
if(CXX_SUPPORT_WCLASS_MEMACCESS) | if(CXX_SUPPORT_WCLASS_MEMACCESS) | ||||
@@ -11,6 +11,7 @@ import argparse | |||||
import getpass | import getpass | ||||
import os | import os | ||||
import sys | import sys | ||||
import urllib.parse | |||||
from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager | from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager | ||||
from ..logger import get_logger | from ..logger import get_logger | ||||
@@ -23,8 +24,10 @@ class PersistentCacheManager(_PersistentCacheManager): | |||||
if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | ||||
get_logger().info("fastrun use in-memory cache") | get_logger().info("fastrun use in-memory cache") | ||||
self.open_memory() | self.open_memory() | ||||
else: | |||||
elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE": | |||||
self.open_file() | self.open_file() | ||||
else: | |||||
self.open_redis() | |||||
def open_memory(self): | def open_memory(self): | ||||
pass | pass | ||||
@@ -51,6 +54,28 @@ class PersistentCacheManager(_PersistentCacheManager): | |||||
) | ) | ||||
self.open_memory() | self.open_memory() | ||||
def open_redis(self): | |||||
prefix = "mgbcache:{}:MGB{}:GIT:{}".format( | |||||
getpass.getuser(), __version__, git_version | |||||
) | |||||
url = os.getenv("MGE_FASTRUN_CACHE_URL") | |||||
if url is None: | |||||
self.open_file() | |||||
try: | |||||
assert sys.platform != "win32", "redis cache on windows not tested" | |||||
parse_result = urllib.parse.urlparse(url, scheme="redis") | |||||
assert parse_result.scheme == "redis", "unsupported scheme" | |||||
assert not parse_result.username, "redis conn with username unsupported" | |||||
assert self.try_open_redis( | |||||
parse_result.hostname, parse_result.port, parse_result.password, prefix | |||||
), "connect failed" | |||||
except Exception as exc: | |||||
get_logger().error( | |||||
"failed to connect to cache server {!r}; try fallback to " | |||||
"in-file cache".format(exc) | |||||
) | |||||
self.open_file() | |||||
_manager = None | _manager = None | ||||
@@ -60,3 +85,23 @@ def get_manager(): | |||||
if _manager is None: | if _manager is None: | ||||
_manager = PersistentCacheManager() | _manager = PersistentCacheManager() | ||||
return _manager | return _manager | ||||
def _clean(): | |||||
nr_del = get_manager().clean() | |||||
if nr_del is not None: | |||||
print("{} cache entries deleted".format(nr_del)) | |||||
def main(): | |||||
parser = argparse.ArgumentParser(description="manage persistent cache") | |||||
subp = parser.add_subparsers(description="action to be performed", dest="cmd") | |||||
subp.required = True | |||||
subp_clean = subp.add_parser("clean", help="clean all the cache of current user") | |||||
subp_clean.set_defaults(action=_clean) | |||||
args = parser.parse_args() | |||||
args.action() | |||||
if __name__ == "__main__": | |||||
main() |
@@ -245,6 +245,11 @@ void init_utils(py::module m) { | |||||
} | } | ||||
return false; | return false; | ||||
} | } | ||||
bool open_redis( | |||||
std::string ip, size_t port, std::string password, std::string prefix) { | |||||
return try_reg(mgb::imperative::persistent_cache::make_redis( | |||||
ip, port, password, prefix)); | |||||
} | |||||
bool open_file(std::string path) { | bool open_file(std::string path) { | ||||
return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); | return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); | ||||
} | } | ||||
@@ -271,6 +276,7 @@ void init_utils(py::module m) { | |||||
py::class_<PersistentCacheManager>(m, "PersistentCacheManager") | py::class_<PersistentCacheManager>(m, "PersistentCacheManager") | ||||
.def(py::init<>()) | .def(py::init<>()) | ||||
.def("try_open_redis", &PersistentCacheManager::open_redis) | |||||
.def("try_open_file", &PersistentCacheManager::open_file) | .def("try_open_file", &PersistentCacheManager::open_file) | ||||
.def("clean", &PersistentCacheManager::clean) | .def("clean", &PersistentCacheManager::clean) | ||||
.def("put", &PersistentCacheManager::put) | .def("put", &PersistentCacheManager::put) | ||||
@@ -13,12 +13,109 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "cpp_redis/cpp_redis" | |||||
#include "megbrain/imperative/persistent_cache.h" | #include "megbrain/imperative/persistent_cache.h" | ||||
#include "megbrain/imperative/utils/base64.h" | #include "megbrain/imperative/utils/base64.h" | ||||
#include "megbrain/utils/infile_persistent_cache.h" | #include "megbrain/utils/infile_persistent_cache.h" | ||||
namespace mgb::imperative::persistent_cache { | namespace mgb::imperative::persistent_cache { | ||||
class RedisCache final : public ExtendedPersistentCache { | |||||
public: | |||||
RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) { | |||||
m_local = std::make_shared<mgb::InMemoryPersistentCache>(); | |||||
} | |||||
bool connect(std::string ip, size_t port, std::string password) { | |||||
m_client.auth(password); | |||||
m_client.connect( | |||||
ip, port, | |||||
[](const std::string& host, std::size_t port, | |||||
cpp_redis::connect_state status) { | |||||
if (status == cpp_redis::connect_state::dropped) { | |||||
mgb_log("client disconnected from %s.", host.c_str()); | |||||
mgb_log("Redis server connect to %s :%zu failed.", host.c_str(), | |||||
port); | |||||
} | |||||
}, | |||||
std::uint32_t(200)); | |||||
if (!m_client.is_connected()) { | |||||
return false; | |||||
} | |||||
auto flag = m_client.get("mgb-cache-flag"); | |||||
sync(); | |||||
return flag.get().ok(); | |||||
} | |||||
bool valid() const override { return m_client.is_connected(); } | |||||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||||
MGB_LOCK_GUARD(m_mtx); | |||||
auto mem_result = m_local->get(category, key); | |||||
if (mem_result.valid()) | |||||
return mem_result; | |||||
std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||||
std::string redis_key_str; | |||||
encode(category + '@' + key_str, redis_key_str, 24); | |||||
auto result = m_client.get(redis_key_str); | |||||
sync(); | |||||
auto content = result.get(); | |||||
if (content.is_null()) | |||||
return mgb::None; | |||||
std::string decode_content; | |||||
decode(content.as_string(), decode_content); | |||||
m_local->put(category, key, {decode_content.data(), decode_content.length()}); | |||||
return m_local->get(category, key); | |||||
} | |||||
void put(const std::string& category, const Blob& key, const Blob& value) override { | |||||
MGB_LOCK_GUARD(m_mtx); | |||||
std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||||
std::string redis_key_str; | |||||
encode(category + '@' + key_str, redis_key_str); | |||||
std::string value_str(static_cast<const char*>(value.ptr), value.size); | |||||
std::string redis_value_str; | |||||
encode(value_str, redis_value_str); | |||||
auto result = m_client.set(redis_key_str, redis_value_str); | |||||
m_local->put(category, key, value); | |||||
sync(); | |||||
} | |||||
std::optional<size_t> clear() override { | |||||
size_t cursor = 0, nr_deleted = 0; | |||||
std::string pattern = m_prefix + "@*"; | |||||
do { | |||||
auto reply = m_client.scan(cursor, pattern).share(); | |||||
sync(); | |||||
auto keys = reply.get().as_array(); | |||||
std::vector<std::string> string_keys; | |||||
for (auto&& key : keys) { | |||||
string_keys.push_back(key.as_string()); | |||||
} | |||||
m_client.del(string_keys); | |||||
nr_deleted += string_keys.size(); | |||||
cursor = reply.get().as_array()[0].as_integer(); | |||||
} while (cursor != 0); | |||||
return nr_deleted; | |||||
} | |||||
private: | |||||
std::shared_ptr<mgb::PersistentCache> m_local; | |||||
std::mutex m_mtx; | |||||
cpp_redis::client m_client; | |||||
std::string m_prefix; | |||||
uint64_t m_timeout; | |||||
void sync() { | |||||
m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout)); | |||||
mgb_assert(valid()); | |||||
} | |||||
}; | |||||
class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { | class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { | ||||
private: | private: | ||||
std::string m_path; | std::string m_path; | ||||
@@ -68,6 +165,15 @@ public: | |||||
bool valid() const override { return m_impl != nullptr; } | bool valid() const override { return m_impl != nullptr; } | ||||
}; | }; | ||||
std::shared_ptr<ExtendedPersistentCache> make_redis( | |||||
std::string ip, size_t port, std::string password, std::string prefix) { | |||||
auto cache = std::make_shared<RedisCache>(prefix, 100); | |||||
if (!cache->connect(ip, port, password)) { | |||||
return nullptr; | |||||
} | |||||
return cache; | |||||
} | |||||
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) { | std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) { | ||||
auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | ||||
if (!cache->open(path)) { | if (!cache->open(path)) { | ||||
@@ -22,6 +22,9 @@ public: | |||||
virtual std::optional<size_t> clear() = 0; | virtual std::optional<size_t> clear() = 0; | ||||
}; | }; | ||||
std::shared_ptr<ExtendedPersistentCache> make_redis( | |||||
std::string ip, size_t port, std::string password, std::string prefix); | |||||
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path); | std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path); | ||||
} // namespace mgb::imperative::persistent_cache | } // namespace mgb::imperative::persistent_cache | ||||
@@ -12,7 +12,7 @@ endif() | |||||
# TODO: turn python binding into a static/object library | # TODO: turn python binding into a static/object library | ||||
add_executable(imperative_test ${SOURCES} ${SRCS}) | add_executable(imperative_test ${SOURCES} ${SRCS}) | ||||
add_dependencies(imperative_test mgb_opdef) | add_dependencies(imperative_test mgb_opdef) | ||||
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||||
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||||
# Python binding | # Python binding | ||||
target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | ||||
@@ -35,6 +35,10 @@ configure_file(src/lite_build_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/genfiles/l | |||||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/genfiles/lite_build_config.h DESTINATION ${CMAKE_INSTALL_PREFIX}/lite/include) | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/genfiles/lite_build_config.h DESTINATION ${CMAKE_INSTALL_PREFIX}/lite/include) | ||||
# begin config lite | # begin config lite | ||||
if(LITE_BUILD_WITH_MGE AND LITE_WITH_CUDA AND NOT WIN32) | |||||
# FXIME third_party cpp redis do not support build with clang-cl | |||||
list(APPEND SOURCES_LITE ${CPP_REDIS_SRCS}) | |||||
endif() | |||||
add_library(lite_static STATIC ${SOURCES_LITE}) | add_library(lite_static STATIC ${SOURCES_LITE}) | ||||
add_dependencies(lite_static lite_fbs_generate) | add_dependencies(lite_static lite_fbs_generate) | ||||
include_directories($<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/genfiles>) | include_directories($<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/genfiles>) | ||||
@@ -106,6 +110,14 @@ endif() | |||||
if(LITE_BUILD_WITH_MGE) | if(LITE_BUILD_WITH_MGE) | ||||
target_link_libraries(lite_static_all_in_one PRIVATE megbrain megdnn ${MGE_CUDA_LIBS}) | target_link_libraries(lite_static_all_in_one PRIVATE megbrain megdnn ${MGE_CUDA_LIBS}) | ||||
endif() | endif() | ||||
if(LITE_BUILD_WITH_MGE AND LITE_WITH_CUDA AND NOT WIN32) | |||||
# FXIME third_party cpp redis do not support build with clang-cl | |||||
target_include_directories(lite_static PRIVATE ${CPP_REDIS_INCLUDES}) | |||||
target_include_directories(lite_shared PRIVATE ${CPP_REDIS_INCLUDES}) | |||||
target_include_directories(lite_shared_whl PRIVATE ${CPP_REDIS_INCLUDES}) | |||||
target_include_directories(lite_static_all_in_one PRIVATE ${CPP_REDIS_INCLUDES}) | |||||
endif() | |||||
set(LITE_VERSION_SCRIPT ${PROJECT_SOURCE_DIR}/lite/src/version_lite.ld CACHE INTERNAL "Path to linker version script") | set(LITE_VERSION_SCRIPT ${PROJECT_SOURCE_DIR}/lite/src/version_lite.ld CACHE INTERNAL "Path to linker version script") | ||||
add_custom_target(_lite_version_ld SOURCES ${LITE_VERSION_SCRIPT}) | add_custom_target(_lite_version_ld SOURCES ${LITE_VERSION_SCRIPT}) | ||||
if(NOT MSVC AND NOT WIN32) | if(NOT MSVC AND NOT WIN32) | ||||