@@ -84,7 +84,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||||
from .serialization import load, save | from .serialization import load, save | ||||
from .tensor import Parameter, Tensor, tensor | from .tensor import Parameter, Tensor, tensor | ||||
from .utils import comp_graph_tools as cgtools | from .utils import comp_graph_tools as cgtools | ||||
from .utils import persistent_cache | |||||
from .utils.persistent_cache import PersistentCacheOnServer as _PersistentCacheOnServer | |||||
from .version import __version__ | from .version import __version__ | ||||
_set_fork_exec_path_for_timed_func( | _set_fork_exec_path_for_timed_func( | ||||
@@ -92,15 +92,13 @@ _set_fork_exec_path_for_timed_func( | |||||
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | ||||
) | ) | ||||
atexit.register(_close) | |||||
del _set_fork_exec_path_for_timed_func | del _set_fork_exec_path_for_timed_func | ||||
_exit_handlers = [] | _exit_handlers = [] | ||||
def _run_exit_handlers(): | def _run_exit_handlers(): | ||||
for handler in _exit_handlers: | |||||
for handler in reversed(_exit_handlers): | |||||
handler() | handler() | ||||
_exit_handlers.clear() | _exit_handlers.clear() | ||||
@@ -117,6 +115,13 @@ def _atexit(handler): | |||||
_exit_handlers.append(handler) | _exit_handlers.append(handler) | ||||
_atexit(_close) | |||||
_persistent_cache = _PersistentCacheOnServer() | |||||
_persistent_cache.reg() | |||||
_atexit(_persistent_cache.flush) | |||||
# subpackages | # subpackages | ||||
import megengine.amp | import megengine.amp | ||||
import megengine.autodiff | import megengine.autodiff | ||||
@@ -132,5 +137,3 @@ import megengine.quantization | |||||
import megengine.random | import megengine.random | ||||
import megengine.utils | import megengine.utils | ||||
import megengine.traced_module | import megengine.traced_module | ||||
persistent_cache.get_manager() |
@@ -8,87 +8,114 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import argparse | import argparse | ||||
import contextlib | |||||
import getpass | import getpass | ||||
import os | import os | ||||
import sys | import sys | ||||
import urllib.parse | import urllib.parse | ||||
from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager | |||||
import filelock | |||||
from ..core._imperative_rt import PersistentCache as _PersistentCache | |||||
from ..logger import get_logger | from ..logger import get_logger | ||||
from ..version import __version__, git_version | from ..version import __version__, git_version | ||||
class PersistentCacheManager(_PersistentCacheManager): | |||||
class PersistentCacheOnServer(_PersistentCache): | |||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | |||||
get_logger().info("fastrun use in-memory cache") | |||||
self.open_memory() | |||||
elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE": | |||||
self.open_file() | |||||
else: | |||||
self.open_redis() | |||||
def open_memory(self): | |||||
pass | |||||
cache_type = os.getenv("MGE_FASTRUN_CACHE_TYPE") | |||||
if cache_type not in ("FILE", "MEMORY"): | |||||
try: | |||||
redis_config = self.get_redis_config() | |||||
except Exception as exc: | |||||
get_logger().error( | |||||
"failed to connect to cache server {!r}; try fallback to " | |||||
"in-file cache".format(exc) | |||||
) | |||||
else: | |||||
self.add_config( | |||||
"redis", | |||||
redis_config, | |||||
"fastrun use redis cache", | |||||
"failed to connect to cache server", | |||||
) | |||||
if cache_type != "MEMORY": | |||||
path = self.get_cache_file(self.get_cache_dir()) | |||||
self.add_config( | |||||
"in-file", | |||||
{"path": path}, | |||||
"fastrun use in-file cache in {}".format(path), | |||||
"failed to create cache file in {}".format(path), | |||||
) | |||||
self.add_config( | |||||
"in-memory", | |||||
{}, | |||||
"fastrun use in-memory cache", | |||||
"failed to create in-memory cache", | |||||
) | |||||
def open_file(self): | |||||
def get_cache_dir(self): | |||||
cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") | cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") | ||||
try: | |||||
if not cache_dir: | |||||
from ..hub.hub import _get_megengine_home | |||||
if not cache_dir: | |||||
from ..hub.hub import _get_megengine_home | |||||
cache_dir = os.path.expanduser( | |||||
os.path.join(_get_megengine_home(), "persistent_cache.bin") | |||||
) | |||||
os.makedirs(cache_dir, exist_ok=True) | |||||
cache_file = os.path.join(cache_dir, "cache") | |||||
with open(cache_file, "a"): | |||||
pass | |||||
assert self.try_open_file(cache_file), "cannot create file" | |||||
get_logger().info("fastrun use in-file cache in {}".format(cache_dir)) | |||||
except Exception as exc: | |||||
get_logger().error( | |||||
"failed to create cache file in {} {!r}; fallback to " | |||||
"in-memory cache".format(cache_dir, exc) | |||||
cache_dir = os.path.expanduser( | |||||
os.path.join(_get_megengine_home(), "persistent_cache") | |||||
) | ) | ||||
self.open_memory() | |||||
def open_redis(self): | |||||
os.makedirs(cache_dir, exist_ok=True) | |||||
return cache_dir | |||||
def get_cache_file(self, cache_dir): | |||||
cache_file = os.path.join(cache_dir, "cache.bin") | |||||
with open(cache_file, "a"): | |||||
pass | |||||
return cache_file | |||||
@contextlib.contextmanager | |||||
def lock_cache_file(self, cache_dir): | |||||
lock_file = os.path.join(cache_dir, "cache.lock") | |||||
with filelock.FileLock(lock_file): | |||||
yield | |||||
def get_redis_config(self): | |||||
url = os.getenv("MGE_FASTRUN_CACHE_URL") | |||||
if url is None: | |||||
return None | |||||
assert sys.platform != "win32", "redis cache on windows not tested" | |||||
prefix = "mgbcache:{}:MGB{}:GIT:{}".format( | prefix = "mgbcache:{}:MGB{}:GIT:{}".format( | ||||
getpass.getuser(), __version__, git_version | getpass.getuser(), __version__, git_version | ||||
) | ) | ||||
url = os.getenv("MGE_FASTRUN_CACHE_URL") | |||||
if url is None: | |||||
self.open_file() | |||||
try: | |||||
assert sys.platform != "win32", "redis cache on windows not tested" | |||||
parse_result = urllib.parse.urlparse(url, scheme="redis") | |||||
assert parse_result.scheme == "redis", "unsupported scheme" | |||||
assert not parse_result.username, "redis conn with username unsupported" | |||||
assert self.try_open_redis( | |||||
parse_result.hostname, parse_result.port, parse_result.password, prefix | |||||
), "connect failed" | |||||
except Exception as exc: | |||||
get_logger().error( | |||||
"failed to connect to cache server {!r}; try fallback to " | |||||
"in-file cache".format(exc) | |||||
) | |||||
self.open_file() | |||||
_manager = None | |||||
parse_result = urllib.parse.urlparse(url) | |||||
assert not parse_result.username, "redis conn with username unsupported" | |||||
if parse_result.scheme == "redis": | |||||
assert parse_result.hostname and parse_result.port, "invalid url" | |||||
assert not parse_result.path | |||||
config = { | |||||
"hostname": parse_result.hostname, | |||||
"port": str(parse_result.port), | |||||
} | |||||
elif parse_result.scheme == "redis+socket": | |||||
assert not (parse_result.hostname or parse_result.port) | |||||
assert parse_result.path | |||||
config = { | |||||
"unixsocket": parse_result.path, | |||||
} | |||||
else: | |||||
assert False, "unsupported scheme" | |||||
if parse_result.password is not None: | |||||
config["password"] = parse_result.password | |||||
config["prefix"] = prefix | |||||
return config | |||||
def get_manager(): | |||||
global _manager | |||||
if _manager is None: | |||||
_manager = PersistentCacheManager() | |||||
return _manager | |||||
def flush(self): | |||||
if self.config is not None and self.config.type == "in-file": | |||||
with self.lock_cache_file(self.get_cache_dir()): | |||||
super().flush() | |||||
def _clean(): | def _clean(): | ||||
nr_del = get_manager().clean() | |||||
nr_del = PersistentCacheOnServer().clean() | |||||
if nr_del is not None: | if nr_del is not None: | ||||
print("{} cache entries deleted".format(nr_del)) | print("{} cache entries deleted".format(nr_del)) | ||||
@@ -4,8 +4,8 @@ pyarrow | |||||
requests | requests | ||||
tabulate | tabulate | ||||
tqdm | tqdm | ||||
redispy | |||||
deprecated | deprecated | ||||
mprop | mprop | ||||
wheel | wheel | ||||
megfile>=0.0.10 | |||||
megfile>=0.0.10 | |||||
filelock |
@@ -210,7 +210,7 @@ void init_utils(py::module m) { | |||||
.def("disable", [](TensorSanityCheck& checker) { checker.disable(); }); | .def("disable", [](TensorSanityCheck& checker) { checker.disable(); }); | ||||
#if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"), | |||||
m.def("create_mm_server", &mgb::opr::create_zmqrpc_server, py::arg("addr"), | |||||
py::arg("port") = 0); | py::arg("port") = 0); | ||||
#else | #else | ||||
m.def("create_mm_server", []() {}); | m.def("create_mm_server", []() {}); | ||||
@@ -234,51 +234,108 @@ void init_utils(py::module m) { | |||||
using ExtendedPersistentCache = | using ExtendedPersistentCache = | ||||
mgb::imperative::persistent_cache::ExtendedPersistentCache; | mgb::imperative::persistent_cache::ExtendedPersistentCache; | ||||
struct PersistentCacheManager { | |||||
std::shared_ptr<ExtendedPersistentCache> instance; | |||||
struct ConfigurablePersistentCache : mgb::PersistentCache { | |||||
struct Config { | |||||
std::string type; | |||||
std::unordered_map<std::string, std::string> args; | |||||
std::string on_success; | |||||
std::string on_fail; | |||||
}; | |||||
bool try_reg(std::shared_ptr<ExtendedPersistentCache> cache) { | |||||
if (cache) { | |||||
instance = cache; | |||||
PersistentCache::set_impl(cache); | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
bool open_redis( | |||||
std::string ip, size_t port, std::string password, std::string prefix) { | |||||
return try_reg(mgb::imperative::persistent_cache::make_redis( | |||||
ip, port, password, prefix)); | |||||
std::shared_ptr<ExtendedPersistentCache> impl; | |||||
std::optional<Config> impl_config; | |||||
std::vector<Config> configs; | |||||
void add_config( | |||||
std::string type, std::unordered_map<std::string, std::string> args, | |||||
std::string on_success, std::string on_fail) { | |||||
configs.push_back({type, args, on_success, on_fail}); | |||||
} | } | ||||
bool open_file(std::string path) { | |||||
return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); | |||||
std::optional<size_t> clean() { return get_impl()->clear(); } | |||||
void load_config() { | |||||
std::optional<std::string> err_msg; | |||||
for (size_t i = 0; i < configs.size(); ++i) { | |||||
auto& config = configs[i]; | |||||
if (err_msg) { | |||||
mgb_log_warn("try fallback to %s cache", config.type.c_str()); | |||||
} else { | |||||
err_msg.emplace(); | |||||
} | |||||
auto cache = ExtendedPersistentCache::make_from_config( | |||||
config.type, config.args, *err_msg); | |||||
if (!cache) { | |||||
mgb_log_warn("%s %s", config.on_fail.c_str(), err_msg->c_str()); | |||||
} else { | |||||
impl = cache; | |||||
impl_config = config; | |||||
break; | |||||
} | |||||
} | |||||
mgb_assert(impl_config.has_value(), "not valid config"); | |||||
} | } | ||||
std::optional<size_t> clean() { | |||||
if (instance) { | |||||
return instance->clear(); | |||||
std::shared_ptr<ExtendedPersistentCache> get_impl() { | |||||
if (!impl) { | |||||
load_config(); | |||||
} | } | ||||
return {}; | |||||
return impl; | |||||
} | } | ||||
void put(std::string category, std::string key, std::string value) { | |||||
PersistentCache::inst().put( | |||||
category, {key.data(), key.size()}, {value.data(), value.size()}); | |||||
virtual mgb::Maybe<Blob> get(const std::string& category, const Blob& key) { | |||||
return get_impl()->get(category, key); | |||||
} | |||||
virtual void put( | |||||
const std::string& category, const Blob& key, const Blob& value) { | |||||
return get_impl()->put(category, key, value); | |||||
} | } | ||||
py::object get(std::string category, std::string key) { | |||||
auto value = | |||||
PersistentCache::inst().get(category, {key.data(), key.size()}); | |||||
virtual bool support_dump_cache() { return get_impl()->support_dump_cache(); } | |||||
py::object py_get(std::string category, std::string key) { | |||||
auto value = get_impl()->get(category, {key.data(), key.size()}); | |||||
if (value.valid()) { | if (value.valid()) { | ||||
return py::bytes(std::string((const char*)value->ptr, value->size)); | return py::bytes(std::string((const char*)value->ptr, value->size)); | ||||
} else { | } else { | ||||
return py::none(); | return py::none(); | ||||
} | } | ||||
} | } | ||||
void py_put(std::string category, std::string key, std::string value) { | |||||
get_impl()->put( | |||||
category, {key.data(), key.size()}, {value.data(), value.size()}); | |||||
} | |||||
void flush() { | |||||
if (impl) { | |||||
impl->flush(); | |||||
} | |||||
} | |||||
}; | }; | ||||
py::class_<PersistentCacheManager>(m, "PersistentCacheManager") | |||||
.def(py::init<>()) | |||||
.def("try_open_redis", &PersistentCacheManager::open_redis) | |||||
.def("try_open_file", &PersistentCacheManager::open_file) | |||||
.def("clean", &PersistentCacheManager::clean) | |||||
.def("put", &PersistentCacheManager::put) | |||||
.def("get", &PersistentCacheManager::get); | |||||
auto PyConfigurablePersistentCache = | |||||
py::class_< | |||||
ConfigurablePersistentCache, | |||||
std::shared_ptr<ConfigurablePersistentCache>>(m, "PersistentCache") | |||||
.def(py::init<>()) | |||||
.def("add_config", &ConfigurablePersistentCache::add_config) | |||||
.def("reg", | |||||
[](std::shared_ptr<ConfigurablePersistentCache> inst) { | |||||
PersistentCache::set_impl(inst); | |||||
}) | |||||
.def("clean", &ConfigurablePersistentCache::clean) | |||||
.def("get", &ConfigurablePersistentCache::py_get) | |||||
.def("put", &ConfigurablePersistentCache::py_put) | |||||
.def_readonly("config", &ConfigurablePersistentCache::impl_config) | |||||
.def("flush", &ConfigurablePersistentCache::flush); | |||||
py::class_<ConfigurablePersistentCache::Config>( | |||||
PyConfigurablePersistentCache, "Config") | |||||
.def_readwrite("type", &ConfigurablePersistentCache::Config::type) | |||||
.def_readwrite("args", &ConfigurablePersistentCache::Config::args) | |||||
.def_readwrite("on_fail", &ConfigurablePersistentCache::Config::on_fail) | |||||
.def_readwrite( | |||||
"on_success", &ConfigurablePersistentCache::Config::on_success); | |||||
} | } |
@@ -27,7 +27,7 @@ namespace imperative { | |||||
namespace { | namespace { | ||||
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | ||||
auto&& comm = def.cast_final_safe<CollectiveComm>(); | auto&& comm = def.cast_final_safe<CollectiveComm>(); | ||||
auto group_client = std::make_shared<GroupClientProxy>( | |||||
auto group_client = std::make_shared<opr::GroupClientProxy>( | |||||
ssprintf("%s:%d", comm.addr.data(), comm.port)); | ssprintf("%s:%d", comm.addr.data(), comm.port)); | ||||
SmallVector<std::shared_ptr<mgb::DeviceTensorND>> dev_buffer_arr(1, nullptr); | SmallVector<std::shared_ptr<mgb::DeviceTensorND>> dev_buffer_arr(1, nullptr); | ||||
auto disable = std::make_shared<DTypeScalar>(); | auto disable = std::make_shared<DTypeScalar>(); | ||||
@@ -28,7 +28,7 @@ namespace { | |||||
cg::OperatorNodeBase* apply_on_var_node_remote_send( | cg::OperatorNodeBase* apply_on_var_node_remote_send( | ||||
const OpDef& def, const VarNodeArray& inputs) { | const OpDef& def, const VarNodeArray& inputs) { | ||||
auto&& send = def.cast_final_safe<RemoteSend>(); | auto&& send = def.cast_final_safe<RemoteSend>(); | ||||
auto group_client = std::make_shared<GroupClientProxy>( | |||||
auto group_client = std::make_shared<opr::GroupClientProxy>( | |||||
ssprintf("%s:%d", send.addr.data(), send.port)); | ssprintf("%s:%d", send.addr.data(), send.port)); | ||||
auto&& graph = inputs[0]->owner_graph(); | auto&& graph = inputs[0]->owner_graph(); | ||||
@@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||||
auto&& recv = def.cast_final_safe<RemoteRecv>(); | auto&& recv = def.cast_final_safe<RemoteRecv>(); | ||||
OperatorNodeConfig config{recv.cn}; | OperatorNodeConfig config{recv.cn}; | ||||
config.name(recv.make_name()); | config.name(recv.make_name()); | ||||
auto group_client = std::make_shared<GroupClientProxy>( | |||||
auto group_client = std::make_shared<opr::GroupClientProxy>( | |||||
ssprintf("%s:%d", recv.addr.data(), recv.port)); | ssprintf("%s:%d", recv.addr.data(), recv.port)); | ||||
auto&& graph = inputs[0]->owner_graph(); | auto&& graph = inputs[0]->owner_graph(); | ||||
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | ||||
@@ -27,8 +27,10 @@ public: | |||||
m_local = std::make_shared<mgb::InMemoryPersistentCache>(); | m_local = std::make_shared<mgb::InMemoryPersistentCache>(); | ||||
} | } | ||||
bool connect(std::string ip, size_t port, std::string password) { | |||||
m_client.auth(password); | |||||
void connect(std::string ip, size_t port, std::optional<std::string> password) { | |||||
if (password) { | |||||
m_client.auth(*password); | |||||
} | |||||
m_client.connect( | m_client.connect( | ||||
ip, port, | ip, port, | ||||
[](const std::string& host, std::size_t port, | [](const std::string& host, std::size_t port, | ||||
@@ -40,16 +42,32 @@ public: | |||||
} | } | ||||
}, | }, | ||||
std::uint32_t(200)); | std::uint32_t(200)); | ||||
if (!m_client.is_connected()) { | |||||
return false; | |||||
} | |||||
mgb_assert(m_client.is_connected(), "connect failed"); | |||||
auto flag = m_client.get("mgb-cache-flag"); | auto flag = m_client.get("mgb-cache-flag"); | ||||
sync(); | sync(); | ||||
return flag.get().ok(); | |||||
auto is_valid = [](const cpp_redis::reply& reply) { | |||||
switch (reply.get_type()) { | |||||
case cpp_redis::reply::type::error: | |||||
case cpp_redis::reply::type::null: | |||||
return false; | |||||
case cpp_redis::reply::type::integer: | |||||
return reply.as_integer() != 0; | |||||
case cpp_redis::reply::type::simple_string: | |||||
case cpp_redis::reply::type::bulk_string: | |||||
return !reply.as_string().empty(); | |||||
case cpp_redis::reply::type::array: | |||||
return !reply.as_array().empty(); | |||||
default: | |||||
mgb_assert(false, "unknown reply type %d", (int)reply.get_type()); | |||||
} | |||||
}; | |||||
mgb_assert(is_valid(flag.get()), "invalid mgb-cache-flag"); | |||||
} | } | ||||
bool valid() const override { return m_client.is_connected(); } | bool valid() const override { return m_client.is_connected(); } | ||||
void flush() override {} | |||||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | ||||
MGB_LOCK_GUARD(m_mtx); | MGB_LOCK_GUARD(m_mtx); | ||||
auto mem_result = m_local->get(category, key); | auto mem_result = m_local->get(category, key); | ||||
@@ -75,7 +93,7 @@ public: | |||||
MGB_LOCK_GUARD(m_mtx); | MGB_LOCK_GUARD(m_mtx); | ||||
std::string key_str(static_cast<const char*>(key.ptr), key.size); | std::string key_str(static_cast<const char*>(key.ptr), key.size); | ||||
std::string redis_key_str; | std::string redis_key_str; | ||||
encode(category + '@' + key_str, redis_key_str); | |||||
encode(category + '@' + key_str, redis_key_str, 24); | |||||
std::string value_str(static_cast<const char*>(value.ptr), value.size); | std::string value_str(static_cast<const char*>(value.ptr), value.size); | ||||
std::string redis_value_str; | std::string redis_value_str; | ||||
encode(value_str, redis_value_str); | encode(value_str, redis_value_str); | ||||
@@ -118,18 +136,16 @@ private: | |||||
class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { | class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { | ||||
private: | private: | ||||
std::string m_path; | |||||
std::optional<std::string> m_path; | |||||
std::unique_ptr<mgb::InFilePersistentCache> m_impl; | std::unique_ptr<mgb::InFilePersistentCache> m_impl; | ||||
public: | public: | ||||
ExtendedInFilePersistentCache() = default; | ExtendedInFilePersistentCache() = default; | ||||
bool open(std::string path) { | |||||
void open(std::string path) { | |||||
std::fstream file; | std::fstream file; | ||||
file.open(path, std::ios::in | std::ios::binary); | file.open(path, std::ios::in | std::ios::binary); | ||||
if (!file.is_open()) { | |||||
return false; | |||||
} | |||||
mgb_assert(file.is_open(), "can't open file in %s", path.c_str()); | |||||
std::vector<char> bytes = { | std::vector<char> bytes = { | ||||
std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; | std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; | ||||
if (bytes.size()) { | if (bytes.size()) { | ||||
@@ -139,14 +155,11 @@ public: | |||||
m_impl = std::make_unique<mgb::InFilePersistentCache>(); | m_impl = std::make_unique<mgb::InFilePersistentCache>(); | ||||
} | } | ||||
m_path = path; | m_path = path; | ||||
return true; | |||||
} | } | ||||
~ExtendedInFilePersistentCache() { | |||||
if (m_impl) { | |||||
m_impl->dump_cache(m_path.c_str()); | |||||
} | |||||
} | |||||
void open() { m_impl = std::make_unique<mgb::InFilePersistentCache>(); } | |||||
~ExtendedInFilePersistentCache() { flush(); } | |||||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | ||||
return m_impl->get(category, key); | return m_impl->get(category, key); | ||||
@@ -157,29 +170,64 @@ public: | |||||
} | } | ||||
std::optional<size_t> clear() override { | std::optional<size_t> clear() override { | ||||
m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||||
m_impl->dump_cache(m_path.c_str()); | |||||
if (m_impl) { | |||||
m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||||
if (m_path) { | |||||
m_impl->dump_cache(m_path->c_str()); | |||||
} | |||||
} | |||||
return {}; | return {}; | ||||
} | } | ||||
bool valid() const override { return m_impl != nullptr; } | bool valid() const override { return m_impl != nullptr; } | ||||
}; | |||||
std::shared_ptr<ExtendedPersistentCache> make_redis( | |||||
std::string ip, size_t port, std::string password, std::string prefix) { | |||||
auto cache = std::make_shared<RedisCache>(prefix, 100); | |||||
if (!cache->connect(ip, port, password)) { | |||||
return nullptr; | |||||
void flush() override { | |||||
if (m_impl && m_path) { | |||||
m_impl->dump_cache(m_path->c_str()); | |||||
} | |||||
} | } | ||||
return cache; | |||||
} | |||||
}; | |||||
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) { | |||||
auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||||
if (!cache->open(path)) { | |||||
return nullptr; | |||||
std::shared_ptr<ExtendedPersistentCache> ExtendedPersistentCache::make_from_config( | |||||
std::string type, std::unordered_map<std::string, std::string> args, | |||||
std::string& err_msg) { | |||||
try { | |||||
if (type == "redis") { | |||||
std::string prefix = args.at("prefix"); | |||||
std::optional<std::string> password = args.count("password") | |||||
? args.at("password") | |||||
: std::optional<std::string>(); | |||||
auto cache = std::make_shared<RedisCache>(prefix, 100); | |||||
if (args.count("unixsocket")) { | |||||
std::string unixsocket = args.at("unixsocket"); | |||||
cache->connect(unixsocket, 0, password); | |||||
} else { | |||||
std::string ip = args.at("hostname"); | |||||
int port = atoi(args.at("port").c_str()); | |||||
std::optional<std::string> password = | |||||
args.count("password") ? args.at("password") | |||||
: std::optional<std::string>(); | |||||
cache->connect(ip, port, password); | |||||
} | |||||
return cache; | |||||
} else if (type == "in-file") { | |||||
std::string path = args.at("path"); | |||||
auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||||
cache->open(path); | |||||
return cache; | |||||
} else if (type == "in-memory") { | |||||
auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||||
cache->open(); | |||||
return cache; | |||||
} else { | |||||
mgb_assert(false, "persistent cache type %s unsupported", type.c_str()); | |||||
} | |||||
} catch (const std::exception& exc) { | |||||
err_msg = exc.what(); | |||||
} catch (...) { | |||||
err_msg = "unknown exception"; | |||||
} | } | ||||
return cache; | |||||
return nullptr; | |||||
} | } | ||||
} // namespace mgb::imperative::persistent_cache | } // namespace mgb::imperative::persistent_cache | ||||
@@ -20,12 +20,12 @@ class ExtendedPersistentCache : public mgb::PersistentCache { | |||||
public: | public: | ||||
virtual bool valid() const = 0; | virtual bool valid() const = 0; | ||||
virtual std::optional<size_t> clear() = 0; | virtual std::optional<size_t> clear() = 0; | ||||
}; | |||||
std::shared_ptr<ExtendedPersistentCache> make_redis( | |||||
std::string ip, size_t port, std::string password, std::string prefix); | |||||
virtual void flush() = 0; | |||||
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path); | |||||
static std::shared_ptr<ExtendedPersistentCache> make_from_config( | |||||
std::string type, std::unordered_map<std::string, std::string> args, | |||||
std::string& err_msg); | |||||
}; | |||||
} // namespace mgb::imperative::persistent_cache | } // namespace mgb::imperative::persistent_cache | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -20,7 +20,7 @@ TEST(TestImperative, AllReduceBasic) { | |||||
REQUIRE_GPU(2); | REQUIRE_GPU(2); | ||||
const char* server_addr = "127.0.0.1"; | const char* server_addr = "127.0.0.1"; | ||||
uint32_t port = 3456; | uint32_t port = 3456; | ||||
mgb_assert(create_zmqrpc_server(server_addr, port) > 0); | |||||
mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0); | |||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); | CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); | ||||
@@ -20,7 +20,7 @@ TEST(TestImperative, IORemote) { | |||||
REQUIRE_GPU(2); | REQUIRE_GPU(2); | ||||
const char* server_addr = "127.0.0.1"; | const char* server_addr = "127.0.0.1"; | ||||
uint32_t port = 4567; | uint32_t port = 4567; | ||||
mgb_assert(create_zmqrpc_server(server_addr, port) > 0); | |||||
mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0); | |||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); | CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); | ||||
@@ -17,6 +17,9 @@ | |||||
#include "megbrain/opr/zmq_rpc.h" | #include "megbrain/opr/zmq_rpc.h" | ||||
#include "mm_handler.pb.h" | #include "mm_handler.pb.h" | ||||
using namespace mgb; | |||||
using namespace opr; | |||||
/* ======================== GroupServerProxy ========================== */ | /* ======================== GroupServerProxy ========================== */ | ||||
/*! | /*! | ||||
* A proxy that receives zmqrpc call, direct call to NCCL Manager | * A proxy that receives zmqrpc call, direct call to NCCL Manager | ||||
@@ -213,7 +216,7 @@ struct ServerInfo { | |||||
std::unique_ptr<ZmqRpc::ZmqRpcServer> server; | std::unique_ptr<ZmqRpc::ZmqRpcServer> server; | ||||
}; | }; | ||||
int create_zmqrpc_server(const std::string& server_addr, int port) { | |||||
int mgb::opr::create_zmqrpc_server(const std::string& server_addr, int port) { | |||||
static std::unordered_map<std::string, ServerInfo> addr2server; | static std::unordered_map<std::string, ServerInfo> addr2server; | ||||
static std::mutex mtx; | static std::mutex mtx; | ||||
MGB_LOCK_GUARD(mtx); | MGB_LOCK_GUARD(mtx); | ||||
@@ -16,8 +16,8 @@ | |||||
#include "megbrain/opr/collective_comm.h" | #include "megbrain/opr/collective_comm.h" | ||||
#include "megbrain/opr/group_manager.h" | #include "megbrain/opr/group_manager.h" | ||||
using namespace mgb; | |||||
using namespace opr; | |||||
namespace mgb { | |||||
namespace opr { | |||||
/*! | /*! | ||||
* Comm MM Client Proxy. | * Comm MM Client Proxy. | ||||
@@ -56,6 +56,9 @@ private: | |||||
int create_zmqrpc_server(const std::string& server_addr, int port); | int create_zmqrpc_server(const std::string& server_addr, int port); | ||||
} // namespace opr | |||||
} // namespace mgb | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -13,8 +13,6 @@ global: | |||||
base_exceptions*; | base_exceptions*; | ||||
}; | }; | ||||
megcore*; | megcore*; | ||||
*GroupClientProxy*; | |||||
*create_zmqrpc_server*; | |||||
*custom*; | *custom*; | ||||