Browse Source

fix(fastrun): fix persistent_cache in redis

GitOrigin-RevId: ada5862b05
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
1657b8e881
13 changed files with 288 additions and 149 deletions
  1. +9
    -6
      imperative/python/megengine/__init__.py
  2. +86
    -59
      imperative/python/megengine/utils/persistent_cache.py
  3. +2
    -2
      imperative/python/requires.txt
  4. +91
    -34
      imperative/python/src/utils.cpp
  5. +1
    -1
      imperative/src/impl/ops/collective_comm.cpp
  6. +2
    -2
      imperative/src/impl/ops/io_remote.cpp
  7. +81
    -33
      imperative/src/impl/persistent_cache.cpp
  8. +5
    -5
      imperative/src/include/megbrain/imperative/persistent_cache.h
  9. +1
    -1
      imperative/src/test/collective_comm.cpp
  10. +1
    -1
      imperative/src/test/io_remote.cpp
  11. +4
    -1
      src/opr-mm/impl/mm_handler.cpp
  12. +5
    -2
      src/opr-mm/include/megbrain/opr/mm_handler.h
  13. +0
    -2
      src/version.ld

+ 9
- 6
imperative/python/megengine/__init__.py View File

@@ -84,7 +84,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save from .serialization import load, save
from .tensor import Parameter, Tensor, tensor from .tensor import Parameter, Tensor, tensor
from .utils import comp_graph_tools as cgtools from .utils import comp_graph_tools as cgtools
from .utils import persistent_cache
from .utils.persistent_cache import PersistentCacheOnServer as _PersistentCacheOnServer
from .version import __version__ from .version import __version__


_set_fork_exec_path_for_timed_func( _set_fork_exec_path_for_timed_func(
@@ -92,15 +92,13 @@ _set_fork_exec_path_for_timed_func(
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"),
) )


atexit.register(_close)

del _set_fork_exec_path_for_timed_func del _set_fork_exec_path_for_timed_func


_exit_handlers = [] _exit_handlers = []




def _run_exit_handlers(): def _run_exit_handlers():
for handler in _exit_handlers:
for handler in reversed(_exit_handlers):
handler() handler()
_exit_handlers.clear() _exit_handlers.clear()


@@ -117,6 +115,13 @@ def _atexit(handler):
_exit_handlers.append(handler) _exit_handlers.append(handler)




_atexit(_close)

_persistent_cache = _PersistentCacheOnServer()
_persistent_cache.reg()

_atexit(_persistent_cache.flush)

# subpackages # subpackages
import megengine.amp import megengine.amp
import megengine.autodiff import megengine.autodiff
@@ -132,5 +137,3 @@ import megengine.quantization
import megengine.random import megengine.random
import megengine.utils import megengine.utils
import megengine.traced_module import megengine.traced_module

persistent_cache.get_manager()

+ 86
- 59
imperative/python/megengine/utils/persistent_cache.py View File

@@ -8,87 +8,114 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


import argparse import argparse
import contextlib
import getpass import getpass
import os import os
import sys import sys
import urllib.parse import urllib.parse


from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager
import filelock

from ..core._imperative_rt import PersistentCache as _PersistentCache
from ..logger import get_logger from ..logger import get_logger
from ..version import __version__, git_version from ..version import __version__, git_version




class PersistentCacheManager(_PersistentCacheManager):
class PersistentCacheOnServer(_PersistentCache):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY":
get_logger().info("fastrun use in-memory cache")
self.open_memory()
elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE":
self.open_file()
else:
self.open_redis()

def open_memory(self):
pass
cache_type = os.getenv("MGE_FASTRUN_CACHE_TYPE")
if cache_type not in ("FILE", "MEMORY"):
try:
redis_config = self.get_redis_config()
except Exception as exc:
get_logger().error(
"failed to connect to cache server {!r}; try fallback to "
"in-file cache".format(exc)
)
else:
self.add_config(
"redis",
redis_config,
"fastrun use redis cache",
"failed to connect to cache server",
)
if cache_type != "MEMORY":
path = self.get_cache_file(self.get_cache_dir())
self.add_config(
"in-file",
{"path": path},
"fastrun use in-file cache in {}".format(path),
"failed to create cache file in {}".format(path),
)
self.add_config(
"in-memory",
{},
"fastrun use in-memory cache",
"failed to create in-memory cache",
)


def open_file(self):
def get_cache_dir(self):
cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR")
try:
if not cache_dir:
from ..hub.hub import _get_megengine_home
if not cache_dir:
from ..hub.hub import _get_megengine_home


cache_dir = os.path.expanduser(
os.path.join(_get_megengine_home(), "persistent_cache.bin")
)
os.makedirs(cache_dir, exist_ok=True)
cache_file = os.path.join(cache_dir, "cache")
with open(cache_file, "a"):
pass
assert self.try_open_file(cache_file), "cannot create file"
get_logger().info("fastrun use in-file cache in {}".format(cache_dir))
except Exception as exc:
get_logger().error(
"failed to create cache file in {} {!r}; fallback to "
"in-memory cache".format(cache_dir, exc)
cache_dir = os.path.expanduser(
os.path.join(_get_megengine_home(), "persistent_cache")
) )
self.open_memory()

def open_redis(self):
os.makedirs(cache_dir, exist_ok=True)
return cache_dir

def get_cache_file(self, cache_dir):
cache_file = os.path.join(cache_dir, "cache.bin")
with open(cache_file, "a"):
pass
return cache_file

@contextlib.contextmanager
def lock_cache_file(self, cache_dir):
lock_file = os.path.join(cache_dir, "cache.lock")
with filelock.FileLock(lock_file):
yield

def get_redis_config(self):
url = os.getenv("MGE_FASTRUN_CACHE_URL")
if url is None:
return None
assert sys.platform != "win32", "redis cache on windows not tested"
prefix = "mgbcache:{}:MGB{}:GIT:{}".format( prefix = "mgbcache:{}:MGB{}:GIT:{}".format(
getpass.getuser(), __version__, git_version getpass.getuser(), __version__, git_version
) )
url = os.getenv("MGE_FASTRUN_CACHE_URL")
if url is None:
self.open_file()
try:
assert sys.platform != "win32", "redis cache on windows not tested"
parse_result = urllib.parse.urlparse(url, scheme="redis")
assert parse_result.scheme == "redis", "unsupported scheme"
assert not parse_result.username, "redis conn with username unsupported"
assert self.try_open_redis(
parse_result.hostname, parse_result.port, parse_result.password, prefix
), "connect failed"
except Exception as exc:
get_logger().error(
"failed to connect to cache server {!r}; try fallback to "
"in-file cache".format(exc)
)
self.open_file()
_manager = None
parse_result = urllib.parse.urlparse(url)
assert not parse_result.username, "redis conn with username unsupported"
if parse_result.scheme == "redis":
assert parse_result.hostname and parse_result.port, "invalid url"
assert not parse_result.path
config = {
"hostname": parse_result.hostname,
"port": str(parse_result.port),
}
elif parse_result.scheme == "redis+socket":
assert not (parse_result.hostname or parse_result.port)
assert parse_result.path
config = {
"unixsocket": parse_result.path,
}
else:
assert False, "unsupported scheme"
if parse_result.password is not None:
config["password"] = parse_result.password
config["prefix"] = prefix
return config


def get_manager():
global _manager
if _manager is None:
_manager = PersistentCacheManager()
return _manager
def flush(self):
if self.config is not None and self.config.type == "in-file":
with self.lock_cache_file(self.get_cache_dir()):
super().flush()




def _clean(): def _clean():
nr_del = get_manager().clean()
nr_del = PersistentCacheOnServer().clean()
if nr_del is not None: if nr_del is not None:
print("{} cache entries deleted".format(nr_del)) print("{} cache entries deleted".format(nr_del))




+ 2
- 2
imperative/python/requires.txt View File

@@ -4,8 +4,8 @@ pyarrow
requests requests
tabulate tabulate
tqdm tqdm
redispy
deprecated deprecated
mprop mprop
wheel wheel
megfile>=0.0.10
megfile>=0.0.10
filelock

+ 91
- 34
imperative/python/src/utils.cpp View File

@@ -210,7 +210,7 @@ void init_utils(py::module m) {
.def("disable", [](TensorSanityCheck& checker) { checker.disable(); }); .def("disable", [](TensorSanityCheck& checker) { checker.disable(); });


#if MGB_ENABLE_OPR_MM #if MGB_ENABLE_OPR_MM
m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"),
m.def("create_mm_server", &mgb::opr::create_zmqrpc_server, py::arg("addr"),
py::arg("port") = 0); py::arg("port") = 0);
#else #else
m.def("create_mm_server", []() {}); m.def("create_mm_server", []() {});
@@ -234,51 +234,108 @@ void init_utils(py::module m) {
using ExtendedPersistentCache = using ExtendedPersistentCache =
mgb::imperative::persistent_cache::ExtendedPersistentCache; mgb::imperative::persistent_cache::ExtendedPersistentCache;


struct PersistentCacheManager {
std::shared_ptr<ExtendedPersistentCache> instance;
struct ConfigurablePersistentCache : mgb::PersistentCache {
struct Config {
std::string type;
std::unordered_map<std::string, std::string> args;
std::string on_success;
std::string on_fail;
};


bool try_reg(std::shared_ptr<ExtendedPersistentCache> cache) {
if (cache) {
instance = cache;
PersistentCache::set_impl(cache);
return true;
}
return false;
}
bool open_redis(
std::string ip, size_t port, std::string password, std::string prefix) {
return try_reg(mgb::imperative::persistent_cache::make_redis(
ip, port, password, prefix));
std::shared_ptr<ExtendedPersistentCache> impl;
std::optional<Config> impl_config;
std::vector<Config> configs;

void add_config(
std::string type, std::unordered_map<std::string, std::string> args,
std::string on_success, std::string on_fail) {
configs.push_back({type, args, on_success, on_fail});
} }
bool open_file(std::string path) {
return try_reg(mgb::imperative::persistent_cache::make_in_file(path));

std::optional<size_t> clean() { return get_impl()->clear(); }

void load_config() {
std::optional<std::string> err_msg;
for (size_t i = 0; i < configs.size(); ++i) {
auto& config = configs[i];
if (err_msg) {
mgb_log_warn("try fallback to %s cache", config.type.c_str());
} else {
err_msg.emplace();
}
auto cache = ExtendedPersistentCache::make_from_config(
config.type, config.args, *err_msg);
if (!cache) {
mgb_log_warn("%s %s", config.on_fail.c_str(), err_msg->c_str());
} else {
impl = cache;
impl_config = config;
break;
}
}
mgb_assert(impl_config.has_value(), "not valid config");
} }
std::optional<size_t> clean() {
if (instance) {
return instance->clear();

std::shared_ptr<ExtendedPersistentCache> get_impl() {
if (!impl) {
load_config();
} }
return {};
return impl;
} }
void put(std::string category, std::string key, std::string value) {
PersistentCache::inst().put(
category, {key.data(), key.size()}, {value.data(), value.size()});

virtual mgb::Maybe<Blob> get(const std::string& category, const Blob& key) {
return get_impl()->get(category, key);
}

virtual void put(
const std::string& category, const Blob& key, const Blob& value) {
return get_impl()->put(category, key, value);
} }
py::object get(std::string category, std::string key) {
auto value =
PersistentCache::inst().get(category, {key.data(), key.size()});

virtual bool support_dump_cache() { return get_impl()->support_dump_cache(); }

py::object py_get(std::string category, std::string key) {
auto value = get_impl()->get(category, {key.data(), key.size()});
if (value.valid()) { if (value.valid()) {
return py::bytes(std::string((const char*)value->ptr, value->size)); return py::bytes(std::string((const char*)value->ptr, value->size));
} else { } else {
return py::none(); return py::none();
} }
} }

void py_put(std::string category, std::string key, std::string value) {
get_impl()->put(
category, {key.data(), key.size()}, {value.data(), value.size()});
}

void flush() {
if (impl) {
impl->flush();
}
}
}; };


py::class_<PersistentCacheManager>(m, "PersistentCacheManager")
.def(py::init<>())
.def("try_open_redis", &PersistentCacheManager::open_redis)
.def("try_open_file", &PersistentCacheManager::open_file)
.def("clean", &PersistentCacheManager::clean)
.def("put", &PersistentCacheManager::put)
.def("get", &PersistentCacheManager::get);
auto PyConfigurablePersistentCache =
py::class_<
ConfigurablePersistentCache,
std::shared_ptr<ConfigurablePersistentCache>>(m, "PersistentCache")
.def(py::init<>())
.def("add_config", &ConfigurablePersistentCache::add_config)
.def("reg",
[](std::shared_ptr<ConfigurablePersistentCache> inst) {
PersistentCache::set_impl(inst);
})
.def("clean", &ConfigurablePersistentCache::clean)
.def("get", &ConfigurablePersistentCache::py_get)
.def("put", &ConfigurablePersistentCache::py_put)
.def_readonly("config", &ConfigurablePersistentCache::impl_config)
.def("flush", &ConfigurablePersistentCache::flush);

py::class_<ConfigurablePersistentCache::Config>(
PyConfigurablePersistentCache, "Config")
.def_readwrite("type", &ConfigurablePersistentCache::Config::type)
.def_readwrite("args", &ConfigurablePersistentCache::Config::args)
.def_readwrite("on_fail", &ConfigurablePersistentCache::Config::on_fail)
.def_readwrite(
"on_success", &ConfigurablePersistentCache::Config::on_success);
} }

+ 1
- 1
imperative/src/impl/ops/collective_comm.cpp View File

@@ -27,7 +27,7 @@ namespace imperative {
namespace { namespace {
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& comm = def.cast_final_safe<CollectiveComm>(); auto&& comm = def.cast_final_safe<CollectiveComm>();
auto group_client = std::make_shared<GroupClientProxy>(
auto group_client = std::make_shared<opr::GroupClientProxy>(
ssprintf("%s:%d", comm.addr.data(), comm.port)); ssprintf("%s:%d", comm.addr.data(), comm.port));
SmallVector<std::shared_ptr<mgb::DeviceTensorND>> dev_buffer_arr(1, nullptr); SmallVector<std::shared_ptr<mgb::DeviceTensorND>> dev_buffer_arr(1, nullptr);
auto disable = std::make_shared<DTypeScalar>(); auto disable = std::make_shared<DTypeScalar>();


+ 2
- 2
imperative/src/impl/ops/io_remote.cpp View File

@@ -28,7 +28,7 @@ namespace {
cg::OperatorNodeBase* apply_on_var_node_remote_send( cg::OperatorNodeBase* apply_on_var_node_remote_send(
const OpDef& def, const VarNodeArray& inputs) { const OpDef& def, const VarNodeArray& inputs) {
auto&& send = def.cast_final_safe<RemoteSend>(); auto&& send = def.cast_final_safe<RemoteSend>();
auto group_client = std::make_shared<GroupClientProxy>(
auto group_client = std::make_shared<opr::GroupClientProxy>(
ssprintf("%s:%d", send.addr.data(), send.port)); ssprintf("%s:%d", send.addr.data(), send.port));
auto&& graph = inputs[0]->owner_graph(); auto&& graph = inputs[0]->owner_graph();


@@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
auto&& recv = def.cast_final_safe<RemoteRecv>(); auto&& recv = def.cast_final_safe<RemoteRecv>();
OperatorNodeConfig config{recv.cn}; OperatorNodeConfig config{recv.cn};
config.name(recv.make_name()); config.name(recv.make_name());
auto group_client = std::make_shared<GroupClientProxy>(
auto group_client = std::make_shared<opr::GroupClientProxy>(
ssprintf("%s:%d", recv.addr.data(), recv.port)); ssprintf("%s:%d", recv.addr.data(), recv.port));
auto&& graph = inputs[0]->owner_graph(); auto&& graph = inputs[0]->owner_graph();
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(


+ 81
- 33
imperative/src/impl/persistent_cache.cpp View File

@@ -27,8 +27,10 @@ public:
m_local = std::make_shared<mgb::InMemoryPersistentCache>(); m_local = std::make_shared<mgb::InMemoryPersistentCache>();
} }


bool connect(std::string ip, size_t port, std::string password) {
m_client.auth(password);
void connect(std::string ip, size_t port, std::optional<std::string> password) {
if (password) {
m_client.auth(*password);
}
m_client.connect( m_client.connect(
ip, port, ip, port,
[](const std::string& host, std::size_t port, [](const std::string& host, std::size_t port,
@@ -40,16 +42,32 @@ public:
} }
}, },
std::uint32_t(200)); std::uint32_t(200));
if (!m_client.is_connected()) {
return false;
}
mgb_assert(m_client.is_connected(), "connect failed");
auto flag = m_client.get("mgb-cache-flag"); auto flag = m_client.get("mgb-cache-flag");
sync(); sync();
return flag.get().ok();
auto is_valid = [](const cpp_redis::reply& reply) {
switch (reply.get_type()) {
case cpp_redis::reply::type::error:
case cpp_redis::reply::type::null:
return false;
case cpp_redis::reply::type::integer:
return reply.as_integer() != 0;
case cpp_redis::reply::type::simple_string:
case cpp_redis::reply::type::bulk_string:
return !reply.as_string().empty();
case cpp_redis::reply::type::array:
return !reply.as_array().empty();
default:
mgb_assert(false, "unknown reply type %d", (int)reply.get_type());
}
};
mgb_assert(is_valid(flag.get()), "invalid mgb-cache-flag");
} }


bool valid() const override { return m_client.is_connected(); } bool valid() const override { return m_client.is_connected(); }


void flush() override {}

mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
MGB_LOCK_GUARD(m_mtx); MGB_LOCK_GUARD(m_mtx);
auto mem_result = m_local->get(category, key); auto mem_result = m_local->get(category, key);
@@ -75,7 +93,7 @@ public:
MGB_LOCK_GUARD(m_mtx); MGB_LOCK_GUARD(m_mtx);
std::string key_str(static_cast<const char*>(key.ptr), key.size); std::string key_str(static_cast<const char*>(key.ptr), key.size);
std::string redis_key_str; std::string redis_key_str;
encode(category + '@' + key_str, redis_key_str);
encode(category + '@' + key_str, redis_key_str, 24);
std::string value_str(static_cast<const char*>(value.ptr), value.size); std::string value_str(static_cast<const char*>(value.ptr), value.size);
std::string redis_value_str; std::string redis_value_str;
encode(value_str, redis_value_str); encode(value_str, redis_value_str);
@@ -118,18 +136,16 @@ private:


class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { class ExtendedInFilePersistentCache final : public ExtendedPersistentCache {
private: private:
std::string m_path;
std::optional<std::string> m_path;
std::unique_ptr<mgb::InFilePersistentCache> m_impl; std::unique_ptr<mgb::InFilePersistentCache> m_impl;


public: public:
ExtendedInFilePersistentCache() = default; ExtendedInFilePersistentCache() = default;


bool open(std::string path) {
void open(std::string path) {
std::fstream file; std::fstream file;
file.open(path, std::ios::in | std::ios::binary); file.open(path, std::ios::in | std::ios::binary);
if (!file.is_open()) {
return false;
}
mgb_assert(file.is_open(), "can't open file in %s", path.c_str());
std::vector<char> bytes = { std::vector<char> bytes = {
std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()};
if (bytes.size()) { if (bytes.size()) {
@@ -139,14 +155,11 @@ public:
m_impl = std::make_unique<mgb::InFilePersistentCache>(); m_impl = std::make_unique<mgb::InFilePersistentCache>();
} }
m_path = path; m_path = path;
return true;
} }


~ExtendedInFilePersistentCache() {
if (m_impl) {
m_impl->dump_cache(m_path.c_str());
}
}
void open() { m_impl = std::make_unique<mgb::InFilePersistentCache>(); }

~ExtendedInFilePersistentCache() { flush(); }


mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
return m_impl->get(category, key); return m_impl->get(category, key);
@@ -157,29 +170,64 @@ public:
} }


std::optional<size_t> clear() override { std::optional<size_t> clear() override {
m_impl = std::make_unique<mgb::InFilePersistentCache>();
m_impl->dump_cache(m_path.c_str());
if (m_impl) {
m_impl = std::make_unique<mgb::InFilePersistentCache>();
if (m_path) {
m_impl->dump_cache(m_path->c_str());
}
}
return {}; return {};
} }


bool valid() const override { return m_impl != nullptr; } bool valid() const override { return m_impl != nullptr; }
};


std::shared_ptr<ExtendedPersistentCache> make_redis(
std::string ip, size_t port, std::string password, std::string prefix) {
auto cache = std::make_shared<RedisCache>(prefix, 100);
if (!cache->connect(ip, port, password)) {
return nullptr;
void flush() override {
if (m_impl && m_path) {
m_impl->dump_cache(m_path->c_str());
}
} }
return cache;
}
};


std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) {
auto cache = std::make_shared<ExtendedInFilePersistentCache>();
if (!cache->open(path)) {
return nullptr;
std::shared_ptr<ExtendedPersistentCache> ExtendedPersistentCache::make_from_config(
std::string type, std::unordered_map<std::string, std::string> args,
std::string& err_msg) {
try {
if (type == "redis") {
std::string prefix = args.at("prefix");
std::optional<std::string> password = args.count("password")
? args.at("password")
: std::optional<std::string>();
auto cache = std::make_shared<RedisCache>(prefix, 100);
if (args.count("unixsocket")) {
std::string unixsocket = args.at("unixsocket");
cache->connect(unixsocket, 0, password);
} else {
std::string ip = args.at("hostname");
int port = atoi(args.at("port").c_str());
std::optional<std::string> password =
args.count("password") ? args.at("password")
: std::optional<std::string>();
cache->connect(ip, port, password);
}
return cache;
} else if (type == "in-file") {
std::string path = args.at("path");
auto cache = std::make_shared<ExtendedInFilePersistentCache>();
cache->open(path);
return cache;
} else if (type == "in-memory") {
auto cache = std::make_shared<ExtendedInFilePersistentCache>();
cache->open();
return cache;
} else {
mgb_assert(false, "persistent cache type %s unsupported", type.c_str());
}
} catch (const std::exception& exc) {
err_msg = exc.what();
} catch (...) {
err_msg = "unknown exception";
} }
return cache;
return nullptr;
} }


} // namespace mgb::imperative::persistent_cache } // namespace mgb::imperative::persistent_cache


+ 5
- 5
imperative/src/include/megbrain/imperative/persistent_cache.h View File

@@ -20,12 +20,12 @@ class ExtendedPersistentCache : public mgb::PersistentCache {
public: public:
virtual bool valid() const = 0; virtual bool valid() const = 0;
virtual std::optional<size_t> clear() = 0; virtual std::optional<size_t> clear() = 0;
};

std::shared_ptr<ExtendedPersistentCache> make_redis(
std::string ip, size_t port, std::string password, std::string prefix);
virtual void flush() = 0;


std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path);
static std::shared_ptr<ExtendedPersistentCache> make_from_config(
std::string type, std::unordered_map<std::string, std::string> args,
std::string& err_msg);
};


} // namespace mgb::imperative::persistent_cache } // namespace mgb::imperative::persistent_cache
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 1
- 1
imperative/src/test/collective_comm.cpp View File

@@ -20,7 +20,7 @@ TEST(TestImperative, AllReduceBasic) {
REQUIRE_GPU(2); REQUIRE_GPU(2);
const char* server_addr = "127.0.0.1"; const char* server_addr = "127.0.0.1";
uint32_t port = 3456; uint32_t port = 3456;
mgb_assert(create_zmqrpc_server(server_addr, port) > 0);
mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0);
HostTensorGenerator<> gen; HostTensorGenerator<> gen;
CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1");




+ 1
- 1
imperative/src/test/io_remote.cpp View File

@@ -20,7 +20,7 @@ TEST(TestImperative, IORemote) {
REQUIRE_GPU(2); REQUIRE_GPU(2);
const char* server_addr = "127.0.0.1"; const char* server_addr = "127.0.0.1";
uint32_t port = 4567; uint32_t port = 4567;
mgb_assert(create_zmqrpc_server(server_addr, port) > 0);
mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0);
HostTensorGenerator<> gen; HostTensorGenerator<> gen;
CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1");




+ 4
- 1
src/opr-mm/impl/mm_handler.cpp View File

@@ -17,6 +17,9 @@
#include "megbrain/opr/zmq_rpc.h" #include "megbrain/opr/zmq_rpc.h"
#include "mm_handler.pb.h" #include "mm_handler.pb.h"


using namespace mgb;
using namespace opr;

/* ======================== GroupServerProxy ========================== */ /* ======================== GroupServerProxy ========================== */
/*! /*!
* A proxy that receives zmqrpc call, direct call to NCCL Manager * A proxy that receives zmqrpc call, direct call to NCCL Manager
@@ -213,7 +216,7 @@ struct ServerInfo {
std::unique_ptr<ZmqRpc::ZmqRpcServer> server; std::unique_ptr<ZmqRpc::ZmqRpcServer> server;
}; };


int create_zmqrpc_server(const std::string& server_addr, int port) {
int mgb::opr::create_zmqrpc_server(const std::string& server_addr, int port) {
static std::unordered_map<std::string, ServerInfo> addr2server; static std::unordered_map<std::string, ServerInfo> addr2server;
static std::mutex mtx; static std::mutex mtx;
MGB_LOCK_GUARD(mtx); MGB_LOCK_GUARD(mtx);


+ 5
- 2
src/opr-mm/include/megbrain/opr/mm_handler.h View File

@@ -16,8 +16,8 @@
#include "megbrain/opr/collective_comm.h" #include "megbrain/opr/collective_comm.h"
#include "megbrain/opr/group_manager.h" #include "megbrain/opr/group_manager.h"


using namespace mgb;
using namespace opr;
namespace mgb {
namespace opr {


/*! /*!
* Comm MM Client Proxy. * Comm MM Client Proxy.
@@ -56,6 +56,9 @@ private:


int create_zmqrpc_server(const std::string& server_addr, int port); int create_zmqrpc_server(const std::string& server_addr, int port);


} // namespace opr
} // namespace mgb

#endif #endif


// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 0
- 2
src/version.ld View File

@@ -13,8 +13,6 @@ global:
base_exceptions*; base_exceptions*;
}; };
megcore*; megcore*;
*GroupClientProxy*;
*create_zmqrpc_server*;
*custom*; *custom*;






Loading…
Cancel
Save