@@ -84,7 +84,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||
from .serialization import load, save | |||
from .tensor import Parameter, Tensor, tensor | |||
from .utils import comp_graph_tools as cgtools | |||
from .utils import persistent_cache | |||
from .utils.persistent_cache import PersistentCacheOnServer as _PersistentCacheOnServer | |||
from .version import __version__ | |||
_set_fork_exec_path_for_timed_func( | |||
@@ -92,15 +92,13 @@ _set_fork_exec_path_for_timed_func( | |||
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | |||
) | |||
atexit.register(_close) | |||
del _set_fork_exec_path_for_timed_func | |||
# Registry of callables to run at exit; _atexit() appends to it.
_exit_handlers = [] | |||
# Call every registered handler, then empty the registry.
# NOTE(review): two loop headers appear back to back below -- presumably the
# pre-change (registration order) and post-change (reverse registration
# order) variants of the same line in this diff view; confirm which one the
# file keeps.
def _run_exit_handlers(): | |||
for handler in _exit_handlers: | |||
for handler in reversed(_exit_handlers): | |||
handler() | |||
_exit_handlers.clear() | |||
@@ -117,6 +115,13 @@ def _atexit(handler): | |||
_exit_handlers.append(handler) | |||
_atexit(_close) | |||
# Create the fastrun persistent cache, install it through reg(), and queue
# its flush() as an exit handler via _atexit().
_persistent_cache = _PersistentCacheOnServer() | |||
_persistent_cache.reg() | |||
_atexit(_persistent_cache.flush) | |||
# subpackages | |||
import megengine.amp | |||
import megengine.autodiff | |||
@@ -132,5 +137,3 @@ import megengine.quantization | |||
import megengine.random | |||
import megengine.utils | |||
import megengine.traced_module | |||
persistent_cache.get_manager() |
@@ -8,87 +8,114 @@ | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import argparse | |||
import contextlib | |||
import getpass | |||
import os | |||
import sys | |||
import urllib.parse | |||
from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager | |||
import filelock | |||
from ..core._imperative_rt import PersistentCache as _PersistentCache | |||
from ..logger import get_logger | |||
from ..version import __version__, git_version | |||
class PersistentCacheManager(_PersistentCacheManager): | |||
class PersistentCacheOnServer(_PersistentCache): | |||
# Queue the cache back-ends to try, most preferred first: redis (skipped when
# MGE_FASTRUN_CACHE_TYPE is "FILE" or "MEMORY"), in-file (skipped when it is
# "MEMORY"), and finally in-memory.
# NOTE(review): the env checks and open_*() calls directly below look like the
# previous implementation left interleaved by this diff view -- confirm.
def __init__(self): | |||
super().__init__() | |||
if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | |||
get_logger().info("fastrun use in-memory cache") | |||
self.open_memory() | |||
elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE": | |||
self.open_file() | |||
else: | |||
self.open_redis() | |||
def open_memory(self): | |||
pass | |||
cache_type = os.getenv("MGE_FASTRUN_CACHE_TYPE") | |||
if cache_type not in ("FILE", "MEMORY"): | |||
try: | |||
redis_config = self.get_redis_config() | |||
except Exception as exc: | |||
get_logger().error( | |||
"failed to connect to cache server {!r}; try fallback to " | |||
"in-file cache".format(exc) | |||
) | |||
else: | |||
# NOTE(review): get_redis_config() returns None when MGE_FASTRUN_CACHE_URL
# is unset, and that None is passed to add_config() here -- confirm the
# binding accepts None for its args mapping, or guard against it.
self.add_config( | |||
"redis", | |||
redis_config, | |||
"fastrun use redis cache", | |||
"failed to connect to cache server", | |||
) | |||
if cache_type != "MEMORY": | |||
path = self.get_cache_file(self.get_cache_dir()) | |||
self.add_config( | |||
"in-file", | |||
{"path": path}, | |||
"fastrun use in-file cache in {}".format(path), | |||
"failed to create cache file in {}".format(path), | |||
) | |||
# The in-memory back-end is always queued last.
self.add_config( | |||
"in-memory", | |||
{}, | |||
"fastrun use in-memory cache", | |||
"failed to create in-memory cache", | |||
) | |||
def open_file(self): | |||
# Return the cache directory, creating it when missing: the value of
# MGE_FASTRUN_CACHE_DIR when set, otherwise "persistent_cache" under the
# directory returned by _get_megengine_home().
# NOTE(review): this span interleaves lines of the older open_file() body
# (try/except, "persistent_cache.bin", try_open_file, open_memory fallback)
# with the new method -- presumably only the getenv / default-path /
# makedirs / return lines belong to get_cache_dir; confirm.
def get_cache_dir(self): | |||
cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") | |||
try: | |||
if not cache_dir: | |||
from ..hub.hub import _get_megengine_home | |||
if not cache_dir: | |||
from ..hub.hub import _get_megengine_home | |||
cache_dir = os.path.expanduser( | |||
os.path.join(_get_megengine_home(), "persistent_cache.bin") | |||
) | |||
os.makedirs(cache_dir, exist_ok=True) | |||
cache_file = os.path.join(cache_dir, "cache") | |||
with open(cache_file, "a"): | |||
pass | |||
assert self.try_open_file(cache_file), "cannot create file" | |||
get_logger().info("fastrun use in-file cache in {}".format(cache_dir)) | |||
except Exception as exc: | |||
get_logger().error( | |||
"failed to create cache file in {} {!r}; fallback to " | |||
"in-memory cache".format(cache_dir, exc) | |||
cache_dir = os.path.expanduser( | |||
os.path.join(_get_megengine_home(), "persistent_cache") | |||
) | |||
self.open_memory() | |||
def open_redis(self): | |||
os.makedirs(cache_dir, exist_ok=True) | |||
return cache_dir | |||
# Return the path of "cache.bin" inside cache_dir. Opening it in append mode
# creates the file when it is missing and leaves existing content untouched.
def get_cache_file(self, cache_dir): | |||
cache_file = os.path.join(cache_dir, "cache.bin") | |||
with open(cache_file, "a"): | |||
pass | |||
return cache_file | |||
# Context manager that holds a filelock.FileLock on "cache.lock" inside
# cache_dir for the duration of the caller's with-body.
@contextlib.contextmanager | |||
def lock_cache_file(self, cache_dir): | |||
lock_file = os.path.join(cache_dir, "cache.lock") | |||
with filelock.FileLock(lock_file): | |||
yield | |||
# Build the redis back-end arguments from MGE_FASTRUN_CACHE_URL.
# Returns None when the variable is unset. Otherwise returns a dict holding
# either "hostname"/"port" (scheme "redis") or "unixsocket" (scheme
# "redis+socket"), plus "prefix" and, when the URL carries one, "password".
# Every validation is an assert, so a rejected URL raises AssertionError.
# NOTE(review): asserts are stripped under "python -O", which would disable
# these checks -- confirm that is acceptable.
# NOTE(review): the second getenv / try / try_open_redis / open_file lines
# and "_manager = None" in the middle of this span look like the previous
# open_redis() body left interleaved by this diff view -- confirm.
def get_redis_config(self): | |||
url = os.getenv("MGE_FASTRUN_CACHE_URL") | |||
if url is None: | |||
return None | |||
assert sys.platform != "win32", "redis cache on windows not tested" | |||
prefix = "mgbcache:{}:MGB{}:GIT:{}".format( | |||
getpass.getuser(), __version__, git_version | |||
) | |||
url = os.getenv("MGE_FASTRUN_CACHE_URL") | |||
if url is None: | |||
self.open_file() | |||
try: | |||
assert sys.platform != "win32", "redis cache on windows not tested" | |||
parse_result = urllib.parse.urlparse(url, scheme="redis") | |||
assert parse_result.scheme == "redis", "unsupported scheme" | |||
assert not parse_result.username, "redis conn with username unsupported" | |||
assert self.try_open_redis( | |||
parse_result.hostname, parse_result.port, parse_result.password, prefix | |||
), "connect failed" | |||
except Exception as exc: | |||
get_logger().error( | |||
"failed to connect to cache server {!r}; try fallback to " | |||
"in-file cache".format(exc) | |||
) | |||
self.open_file() | |||
_manager = None | |||
parse_result = urllib.parse.urlparse(url) | |||
assert not parse_result.username, "redis conn with username unsupported" | |||
if parse_result.scheme == "redis": | |||
assert parse_result.hostname and parse_result.port, "invalid url" | |||
assert not parse_result.path | |||
config = { | |||
"hostname": parse_result.hostname, | |||
"port": str(parse_result.port), | |||
} | |||
elif parse_result.scheme == "redis+socket": | |||
assert not (parse_result.hostname or parse_result.port) | |||
assert parse_result.path | |||
config = { | |||
"unixsocket": parse_result.path, | |||
} | |||
else: | |||
assert False, "unsupported scheme" | |||
if parse_result.password is not None: | |||
config["password"] = parse_result.password | |||
config["prefix"] = prefix | |||
return config | |||
# Return the module-level manager, creating a PersistentCacheManager on
# first use.
# NOTE(review): presumably the removed accessor in this diff view -- confirm.
def get_manager(): | |||
global _manager | |||
if _manager is None: | |||
_manager = PersistentCacheManager() | |||
return _manager | |||
# Flush only when a back-end has been selected and it is the "in-file" one;
# the base-class flush then runs while lock_cache_file() holds the lock for
# the cache directory. For any other state this is a no-op.
def flush(self): | |||
if self.config is not None and self.config.type == "in-file": | |||
with self.lock_cache_file(self.get_cache_dir()): | |||
super().flush() | |||
# Clear the cache and, when clean() reports a count, print how many entries
# were deleted.
# NOTE(review): the two nr_del assignments are presumably the pre-change
# (get_manager) and post-change (PersistentCacheOnServer) variants of the
# same line in this diff view -- confirm.
def _clean(): | |||
nr_del = get_manager().clean() | |||
nr_del = PersistentCacheOnServer().clean() | |||
if nr_del is not None: | |||
print("{} cache entries deleted".format(nr_del)) | |||
@@ -4,8 +4,8 @@ pyarrow | |||
requests | |||
tabulate | |||
tqdm | |||
redispy | |||
deprecated | |||
mprop | |||
wheel | |||
megfile>=0.0.10 | |||
megfile>=0.0.10 | |||
filelock |
@@ -210,7 +210,7 @@ void init_utils(py::module m) { | |||
.def("disable", [](TensorSanityCheck& checker) { checker.disable(); }); | |||
#if MGB_ENABLE_OPR_MM | |||
m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"), | |||
m.def("create_mm_server", &mgb::opr::create_zmqrpc_server, py::arg("addr"), | |||
py::arg("port") = 0); | |||
#else | |||
m.def("create_mm_server", []() {}); | |||
@@ -234,51 +234,108 @@ void init_utils(py::module m) { | |||
using ExtendedPersistentCache = | |||
mgb::imperative::persistent_cache::ExtendedPersistentCache; | |||
struct PersistentCacheManager { | |||
std::shared_ptr<ExtendedPersistentCache> instance; | |||
struct ConfigurablePersistentCache : mgb::PersistentCache { | |||
struct Config { | |||
std::string type; | |||
std::unordered_map<std::string, std::string> args; | |||
std::string on_success; | |||
std::string on_fail; | |||
}; | |||
bool try_reg(std::shared_ptr<ExtendedPersistentCache> cache) { | |||
if (cache) { | |||
instance = cache; | |||
PersistentCache::set_impl(cache); | |||
return true; | |||
} | |||
return false; | |||
} | |||
bool open_redis( | |||
std::string ip, size_t port, std::string password, std::string prefix) { | |||
return try_reg(mgb::imperative::persistent_cache::make_redis( | |||
ip, port, password, prefix)); | |||
std::shared_ptr<ExtendedPersistentCache> impl; | |||
std::optional<Config> impl_config; | |||
std::vector<Config> configs; | |||
// Append a candidate back-end configuration to `configs`; load_config()
// later tries the candidates in insertion order.
void add_config( | |||
std::string type, std::unordered_map<std::string, std::string> args, | |||
std::string on_success, std::string on_fail) { | |||
configs.push_back({type, args, on_success, on_fail}); | |||
} | |||
bool open_file(std::string path) { | |||
return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); | |||
std::optional<size_t> clean() { return get_impl()->clear(); } | |||
// Try each queued config in order and keep the first one for which
// make_from_config() returns a cache; asserts that one of them succeeded.
// NOTE(review): config.on_success is stored but never logged in this
// function -- confirm whether a success message was intended.
void load_config() { | |||
// Empty only before the first attempt, so the "try fallback" warning is
// emitted for every attempt after the first.
std::optional<std::string> err_msg; | |||
for (size_t i = 0; i < configs.size(); ++i) { | |||
auto& config = configs[i]; | |||
if (err_msg) { | |||
mgb_log_warn("try fallback to %s cache", config.type.c_str()); | |||
} else { | |||
err_msg.emplace(); | |||
} | |||
auto cache = ExtendedPersistentCache::make_from_config( | |||
config.type, config.args, *err_msg); | |||
if (!cache) { | |||
mgb_log_warn("%s %s", config.on_fail.c_str(), err_msg->c_str()); | |||
} else { | |||
impl = cache; | |||
impl_config = config; | |||
break; | |||
} | |||
} | |||
mgb_assert(impl_config.has_value(), "not valid config"); | |||
} | |||
std::optional<size_t> clean() { | |||
if (instance) { | |||
return instance->clear(); | |||
// Return the active cache, running load_config() first when none has been
// selected yet.
// NOTE(review): the `return {};` line looks like a leftover of the previous
// clean() body interleaved by this diff view -- confirm.
std::shared_ptr<ExtendedPersistentCache> get_impl() { | |||
if (!impl) { | |||
load_config(); | |||
} | |||
return {}; | |||
return impl; | |||
} | |||
void put(std::string category, std::string key, std::string value) { | |||
PersistentCache::inst().put( | |||
category, {key.data(), key.size()}, {value.data(), value.size()}); | |||
virtual mgb::Maybe<Blob> get(const std::string& category, const Blob& key) { | |||
return get_impl()->get(category, key); | |||
} | |||
virtual void put( | |||
const std::string& category, const Blob& key, const Blob& value) { | |||
return get_impl()->put(category, key, value); | |||
} | |||
py::object get(std::string category, std::string key) { | |||
auto value = | |||
PersistentCache::inst().get(category, {key.data(), key.size()}); | |||
virtual bool support_dump_cache() { return get_impl()->support_dump_cache(); } | |||
// Python-facing lookup: returns the stored value as a bytes object, or None
// when the underlying get() yields no valid value.
py::object py_get(std::string category, std::string key) { | |||
auto value = get_impl()->get(category, {key.data(), key.size()}); | |||
if (value.valid()) { | |||
return py::bytes(std::string((const char*)value->ptr, value->size)); | |||
} else { | |||
return py::none(); | |||
} | |||
} | |||
// Python-facing store: wraps key and value as (pointer, size) blobs and
// forwards them to the active cache.
void py_put(std::string category, std::string key, std::string value) { | |||
get_impl()->put( | |||
category, {key.data(), key.size()}, {value.data(), value.size()}); | |||
} | |||
// Flush the active cache if one has been selected. This reads `impl`
// directly instead of calling get_impl(), so it never triggers load_config().
void flush() { | |||
if (impl) { | |||
impl->flush(); | |||
} | |||
} | |||
}; | |||
py::class_<PersistentCacheManager>(m, "PersistentCacheManager") | |||
.def(py::init<>()) | |||
.def("try_open_redis", &PersistentCacheManager::open_redis) | |||
.def("try_open_file", &PersistentCacheManager::open_file) | |||
.def("clean", &PersistentCacheManager::clean) | |||
.def("put", &PersistentCacheManager::put) | |||
.def("get", &PersistentCacheManager::get); | |||
// Expose ConfigurablePersistentCache to Python under the name
// "PersistentCache" (held by shared_ptr). reg() installs the instance via
// PersistentCache::set_impl; `config` is a read-only view of impl_config.
auto PyConfigurablePersistentCache = | |||
py::class_< | |||
ConfigurablePersistentCache, | |||
std::shared_ptr<ConfigurablePersistentCache>>(m, "PersistentCache") | |||
.def(py::init<>()) | |||
.def("add_config", &ConfigurablePersistentCache::add_config) | |||
.def("reg", | |||
[](std::shared_ptr<ConfigurablePersistentCache> inst) { | |||
PersistentCache::set_impl(inst); | |||
}) | |||
.def("clean", &ConfigurablePersistentCache::clean) | |||
.def("get", &ConfigurablePersistentCache::py_get) | |||
.def("put", &ConfigurablePersistentCache::py_put) | |||
.def_readonly("config", &ConfigurablePersistentCache::impl_config) | |||
.def("flush", &ConfigurablePersistentCache::flush); | |||
// Nested "Config" class with read/write access to all four fields.
py::class_<ConfigurablePersistentCache::Config>( | |||
PyConfigurablePersistentCache, "Config") | |||
.def_readwrite("type", &ConfigurablePersistentCache::Config::type) | |||
.def_readwrite("args", &ConfigurablePersistentCache::Config::args) | |||
.def_readwrite("on_fail", &ConfigurablePersistentCache::Config::on_fail) | |||
.def_readwrite( | |||
"on_success", &ConfigurablePersistentCache::Config::on_success); | |||
} |
@@ -27,7 +27,7 @@ namespace imperative { | |||
namespace { | |||
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& comm = def.cast_final_safe<CollectiveComm>(); | |||
auto group_client = std::make_shared<GroupClientProxy>( | |||
auto group_client = std::make_shared<opr::GroupClientProxy>( | |||
ssprintf("%s:%d", comm.addr.data(), comm.port)); | |||
SmallVector<std::shared_ptr<mgb::DeviceTensorND>> dev_buffer_arr(1, nullptr); | |||
auto disable = std::make_shared<DTypeScalar>(); | |||
@@ -28,7 +28,7 @@ namespace { | |||
cg::OperatorNodeBase* apply_on_var_node_remote_send( | |||
const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& send = def.cast_final_safe<RemoteSend>(); | |||
auto group_client = std::make_shared<GroupClientProxy>( | |||
auto group_client = std::make_shared<opr::GroupClientProxy>( | |||
ssprintf("%s:%d", send.addr.data(), send.port)); | |||
auto&& graph = inputs[0]->owner_graph(); | |||
@@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||
auto&& recv = def.cast_final_safe<RemoteRecv>(); | |||
OperatorNodeConfig config{recv.cn}; | |||
config.name(recv.make_name()); | |||
auto group_client = std::make_shared<GroupClientProxy>( | |||
auto group_client = std::make_shared<opr::GroupClientProxy>( | |||
ssprintf("%s:%d", recv.addr.data(), recv.port)); | |||
auto&& graph = inputs[0]->owner_graph(); | |||
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | |||
@@ -27,8 +27,10 @@ public: | |||
m_local = std::make_shared<mgb::InMemoryPersistentCache>(); | |||
} | |||
bool connect(std::string ip, size_t port, std::string password) { | |||
m_client.auth(password); | |||
void connect(std::string ip, size_t port, std::optional<std::string> password) { | |||
if (password) { | |||
m_client.auth(*password); | |||
} | |||
m_client.connect( | |||
ip, port, | |||
[](const std::string& host, std::size_t port, | |||
@@ -40,16 +42,32 @@ public: | |||
} | |||
}, | |||
std::uint32_t(200)); | |||
if (!m_client.is_connected()) { | |||
return false; | |||
} | |||
mgb_assert(m_client.is_connected(), "connect failed"); | |||
auto flag = m_client.get("mgb-cache-flag"); | |||
sync(); | |||
return flag.get().ok(); | |||
// Classify a redis reply: error and null replies are invalid; an integer is
// valid when non-zero; strings and arrays are valid when non-empty. Any
// other reply type trips the assertion.
auto is_valid = [](const cpp_redis::reply& reply) { | |||
switch (reply.get_type()) { | |||
case cpp_redis::reply::type::error: | |||
case cpp_redis::reply::type::null: | |||
return false; | |||
case cpp_redis::reply::type::integer: | |||
return reply.as_integer() != 0; | |||
case cpp_redis::reply::type::simple_string: | |||
case cpp_redis::reply::type::bulk_string: | |||
return !reply.as_string().empty(); | |||
case cpp_redis::reply::type::array: | |||
return !reply.as_array().empty(); | |||
default: | |||
mgb_assert(false, "unknown reply type %d", (int)reply.get_type()); | |||
} | |||
}; | |||
mgb_assert(is_valid(flag.get()), "invalid mgb-cache-flag"); | |||
} | |||
bool valid() const override { return m_client.is_connected(); } | |||
void flush() override {} | |||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
MGB_LOCK_GUARD(m_mtx); | |||
auto mem_result = m_local->get(category, key); | |||
@@ -75,7 +93,7 @@ public: | |||
MGB_LOCK_GUARD(m_mtx); | |||
std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||
std::string redis_key_str; | |||
encode(category + '@' + key_str, redis_key_str); | |||
encode(category + '@' + key_str, redis_key_str, 24); | |||
std::string value_str(static_cast<const char*>(value.ptr), value.size); | |||
std::string redis_value_str; | |||
encode(value_str, redis_value_str); | |||
@@ -118,18 +136,16 @@ private: | |||
class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { | |||
private: | |||
std::string m_path; | |||
std::optional<std::string> m_path; | |||
std::unique_ptr<mgb::InFilePersistentCache> m_impl; | |||
public: | |||
ExtendedInFilePersistentCache() = default; | |||
bool open(std::string path) { | |||
void open(std::string path) { | |||
std::fstream file; | |||
file.open(path, std::ios::in | std::ios::binary); | |||
if (!file.is_open()) { | |||
return false; | |||
} | |||
mgb_assert(file.is_open(), "can't open file in %s", path.c_str()); | |||
std::vector<char> bytes = { | |||
std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; | |||
if (bytes.size()) { | |||
@@ -139,14 +155,11 @@ public: | |||
m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||
} | |||
m_path = path; | |||
return true; | |||
} | |||
~ExtendedInFilePersistentCache() { | |||
if (m_impl) { | |||
m_impl->dump_cache(m_path.c_str()); | |||
} | |||
} | |||
void open() { m_impl = std::make_unique<mgb::InFilePersistentCache>(); } | |||
~ExtendedInFilePersistentCache() { flush(); } | |||
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
return m_impl->get(category, key); | |||
@@ -157,29 +170,64 @@ public: | |||
} | |||
std::optional<size_t> clear() override { | |||
m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||
m_impl->dump_cache(m_path.c_str()); | |||
if (m_impl) { | |||
m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||
if (m_path) { | |||
m_impl->dump_cache(m_path->c_str()); | |||
} | |||
} | |||
return {}; | |||
} | |||
bool valid() const override { return m_impl != nullptr; } | |||
}; | |||
std::shared_ptr<ExtendedPersistentCache> make_redis( | |||
std::string ip, size_t port, std::string password, std::string prefix) { | |||
auto cache = std::make_shared<RedisCache>(prefix, 100); | |||
if (!cache->connect(ip, port, password)) { | |||
return nullptr; | |||
// Dump the cache to m_path when both a cache and a path are set; an
// instance opened without a path has nothing to write.
void flush() override { | |||
if (m_impl && m_path) { | |||
m_impl->dump_cache(m_path->c_str()); | |||
} | |||
} | |||
return cache; | |||
} | |||
}; | |||
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) { | |||
auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||
if (!cache->open(path)) { | |||
return nullptr; | |||
// Construct the cache back-end named by `type` ("redis", "in-file" or
// "in-memory") from `args`. On any exception the message is stored in
// err_msg and nullptr is returned; a missing required key surfaces the same
// way because args.at() throws.
// NOTE(review): the `return cache;` just before the final `return nullptr;`
// looks like a leftover of the previous make_in_file() body interleaved by
// this diff view -- confirm.
std::shared_ptr<ExtendedPersistentCache> ExtendedPersistentCache::make_from_config( | |||
std::string type, std::unordered_map<std::string, std::string> args, | |||
std::string& err_msg) { | |||
try { | |||
if (type == "redis") { | |||
std::string prefix = args.at("prefix"); | |||
std::optional<std::string> password = args.count("password") | |||
? args.at("password") | |||
: std::optional<std::string>(); | |||
auto cache = std::make_shared<RedisCache>(prefix, 100); | |||
if (args.count("unixsocket")) { | |||
std::string unixsocket = args.at("unixsocket"); | |||
cache->connect(unixsocket, 0, password); | |||
} else { | |||
std::string ip = args.at("hostname"); | |||
// NOTE(review): atoi() reports no errors, so a non-numeric "port"
// becomes 0 -- confirm callers always pass a numeric string.
int port = atoi(args.at("port").c_str()); | |||
// NOTE(review): this inner `password` shadows the identical outer one
// and appears redundant.
std::optional<std::string> password = | |||
args.count("password") ? args.at("password") | |||
: std::optional<std::string>(); | |||
cache->connect(ip, port, password); | |||
} | |||
return cache; | |||
} else if (type == "in-file") { | |||
std::string path = args.at("path"); | |||
auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||
cache->open(path); | |||
return cache; | |||
} else if (type == "in-memory") { | |||
// Reuses the in-file class without a path.
auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||
cache->open(); | |||
return cache; | |||
} else { | |||
mgb_assert(false, "persistent cache type %s unsupported", type.c_str()); | |||
} | |||
} catch (const std::exception& exc) { | |||
err_msg = exc.what(); | |||
} catch (...) { | |||
err_msg = "unknown exception"; | |||
} | |||
return cache; | |||
return nullptr; | |||
} | |||
} // namespace mgb::imperative::persistent_cache | |||
@@ -20,12 +20,12 @@ class ExtendedPersistentCache : public mgb::PersistentCache { | |||
public: | |||
virtual bool valid() const = 0; | |||
virtual std::optional<size_t> clear() = 0; | |||
}; | |||
std::shared_ptr<ExtendedPersistentCache> make_redis( | |||
std::string ip, size_t port, std::string password, std::string prefix); | |||
virtual void flush() = 0; | |||
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path); | |||
static std::shared_ptr<ExtendedPersistentCache> make_from_config( | |||
std::string type, std::unordered_map<std::string, std::string> args, | |||
std::string& err_msg); | |||
}; | |||
} // namespace mgb::imperative::persistent_cache | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -20,7 +20,7 @@ TEST(TestImperative, AllReduceBasic) { | |||
REQUIRE_GPU(2); | |||
const char* server_addr = "127.0.0.1"; | |||
uint32_t port = 3456; | |||
mgb_assert(create_zmqrpc_server(server_addr, port) > 0); | |||
mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0); | |||
HostTensorGenerator<> gen; | |||
CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); | |||
@@ -20,7 +20,7 @@ TEST(TestImperative, IORemote) { | |||
REQUIRE_GPU(2); | |||
const char* server_addr = "127.0.0.1"; | |||
uint32_t port = 4567; | |||
mgb_assert(create_zmqrpc_server(server_addr, port) > 0); | |||
mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0); | |||
HostTensorGenerator<> gen; | |||
CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); | |||
@@ -17,6 +17,9 @@ | |||
#include "megbrain/opr/zmq_rpc.h" | |||
#include "mm_handler.pb.h" | |||
using namespace mgb; | |||
using namespace opr; | |||
/* ======================== GroupServerProxy ========================== */ | |||
/*! | |||
* A proxy that receives zmqrpc call, direct call to NCCL Manager | |||
@@ -213,7 +216,7 @@ struct ServerInfo { | |||
std::unique_ptr<ZmqRpc::ZmqRpcServer> server; | |||
}; | |||
int create_zmqrpc_server(const std::string& server_addr, int port) { | |||
int mgb::opr::create_zmqrpc_server(const std::string& server_addr, int port) { | |||
static std::unordered_map<std::string, ServerInfo> addr2server; | |||
static std::mutex mtx; | |||
MGB_LOCK_GUARD(mtx); | |||
@@ -16,8 +16,8 @@ | |||
#include "megbrain/opr/collective_comm.h" | |||
#include "megbrain/opr/group_manager.h" | |||
using namespace mgb; | |||
using namespace opr; | |||
namespace mgb { | |||
namespace opr { | |||
/*! | |||
* Comm MM Client Proxy. | |||
@@ -56,6 +56,9 @@ private: | |||
int create_zmqrpc_server(const std::string& server_addr, int port); | |||
} // namespace opr | |||
} // namespace mgb | |||
#endif | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -13,8 +13,6 @@ global: | |||
base_exceptions*; | |||
}; | |||
megcore*; | |||
*GroupClientProxy*; | |||
*create_zmqrpc_server*; | |||
*custom*; | |||