GitOrigin-RevId: f401ce8603
tags/v0.5.0
@@ -14,6 +14,7 @@ ExternalProject_add( | |||
) | |||
set(ZMQ_INC ${ZMQ_BUILD_DIR}/include) | |||
include_directories(${ZMQ_INC}) | |||
file(MAKE_DIRECTORY ${ZMQ_INC}) | |||
add_library(libzmq STATIC IMPORTED GLOBAL) | |||
@@ -12,14 +12,6 @@ set(SWIG_SRC src/swig/mgb.i) | |||
set(CMAKE_SWIG_FLAGS -Wall -threads -py3 -modern -DSWIGWORDSIZE64) | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") | |||
if(MGE_WITH_DISTRIBUTED) | |||
file(GLOB_RECURSE PROTO_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "src/proto/*.proto") | |||
PROTOBUF_GENERATE_CPP_WITH_ROOT(GRPC_SRCS GRPC_HDRS ${CMAKE_CURRENT_SOURCE_DIR} ${PROTO_FILES}) | |||
add_custom_target(mgb_proto_target DEPENDS ${GRPC_SRCS} ${GRPC_HDRS} ${PROTOBUF_PROTOC_EXECUTABLE}) | |||
endif() | |||
file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl") | |||
file(GLOB_RECURSE PYTHON_SRCS setup.py | |||
src/python/*.py | |||
@@ -55,11 +47,7 @@ add_custom_command( | |||
add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py) | |||
set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/mm_handler.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp) | |||
if(MGE_WITH_DISTRIBUTED) | |||
list(APPEND SRCS src/cpp/zmq_rpc.cpp) | |||
endif() | |||
set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp) | |||
include(UseSWIG) | |||
set_property(SOURCE ${SWIG_SRC} PROPERTY CPLUSPLUS ON) | |||
@@ -70,7 +58,7 @@ set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS -I${PROJECT_SOURCE_DIR}/src/ | |||
set(SWIG_OUTFILE_DIR ${CMAKE_CURRENT_BINARY_DIR}) | |||
set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal) | |||
swig_add_library(mgb LANGUAGE python SOURCES ${SWIG_SRC} ${GRPC_SRCS} ${SRCS}) | |||
swig_add_library(mgb LANGUAGE python SOURCES ${SWIG_SRC} ${SRCS}) | |||
set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld) | |||
add_custom_target(version_ld SOURCES ${VERSION_SCRIPT}) | |||
@@ -81,12 +69,6 @@ target_include_directories(_mgb PRIVATE ${PYTHON_INCLUDE_DIRS} src/cpp ${CMAKE_C | |||
target_link_libraries(_mgb ${PYTHON_LIBRARIES}) | |||
add_dependencies(_mgb mgb_opr_py version_ld) | |||
if(MGE_WITH_DISTRIBUTED) | |||
add_dependencies(_mgb mgb_proto_target) | |||
target_link_libraries (_mgb libprotobuf libzmq) | |||
set(CPPZMQ_INC ${PROJECT_SOURCE_DIR}/third_party/cppzmq) | |||
target_include_directories(_mgb PRIVATE ${CPPZMQ_INC}) | |||
endif() | |||
add_custom_command( | |||
TARGET _mgb POST_BUILD | |||
@@ -19,6 +19,10 @@ | |||
#include <dlfcn.h> | |||
#if MGB_ENABLE_OPR_MM | |||
#include "megbrain/opr/mm_handler.h" | |||
#endif | |||
#if MGB_CUDA | |||
#include <cuda.h> | |||
#endif | |||
@@ -276,4 +280,37 @@ std::vector<std::pair<uint64_t, std::string>> _config::dump_registered_oprs() { | |||
#endif | |||
} | |||
#if MGB_ENABLE_OPR_MM | |||
/*! see definition : src/cpp/megbrain_config.h. | |||
* Create mm server. port 0 is permitted, leave zmqrpc to decide which port | |||
* should be used. | |||
*/ | |||
int _config::create_mm_server(const std::string& server_addr, int port) { | |||
return create_zmqrpc_server(server_addr, port); | |||
} | |||
void _config::group_barrier(const std::string& server_addr, | |||
int port, uint32_t size, uint32_t rank) { | |||
mgb_assert(rank < size, "invalid rank %d", rank); | |||
auto group_mgr = std::make_shared<GroupClientProxy>( | |||
ssprintf("%s:%d", server_addr.c_str(), port)); | |||
uint32_t rsp = group_mgr->group_barrier(size, rank); | |||
mgb_assert(rsp != 0, "rank already registered: %d", rank); | |||
mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp); | |||
} | |||
#else | |||
int _config::create_mm_server(const std::string& server_addr, int port) { | |||
mgb_throw(mgb::MegBrainError, "OPR_MM suppport disable at compile time"); | |||
return 0; | |||
} | |||
void _config::group_barrier(const std::string& server_addr, | |||
int port, uint32_t size, uint32_t rank) { | |||
mgb_throw(mgb::MegBrainError, "OPR_MM suppport disable at compile time"); | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -12,7 +12,7 @@ | |||
#include "./python_helper.h" | |||
#if MGB_ENABLE_OPR_MM | |||
#include "mm_handler.h" | |||
#include "megbrain/opr/mm_handler.h" | |||
#endif | |||
#include "megbrain/opr/io.h" | |||
@@ -10,6 +10,10 @@ endif() | |||
if(MGE_WITH_DISTRIBUTED) | |||
file(GLOB_RECURSE SOURCES_ opr-mm/impl/*.cpp opr-mm/impl/*.inl) | |||
list(APPEND SOURCES ${SOURCES_}) | |||
file(GLOB_RECURSE PROTO_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../src/opr-mm/proto/*.proto") | |||
PROTOBUF_GENERATE_CPP_WITH_ROOT(GRPC_SRCS GRPC_HDRS ${CMAKE_CURRENT_SOURCE_DIR} ${PROTO_FILES}) | |||
add_custom_target(mgb_proto_target DEPENDS ${GRPC_SRCS} ${GRPC_HDRS} ${PROTOBUF_PROTOC_EXECUTABLE}) | |||
list(APPEND SOURCES ${GRPC_SRCS}) | |||
endif() | |||
set(MGB_INC ${PROJECT_BINARY_DIR}/genfiles core/include gopt/include opr/include plugin/include serialization/include) | |||
@@ -52,6 +56,11 @@ if(CXX_SUPPORT_WCLASS_MEMACCESS) | |||
endif() | |||
target_link_libraries(megbrain megdnn) | |||
if(MGE_WITH_DISTRIBUTED) | |||
add_dependencies(megbrain mgb_proto_target) | |||
target_link_libraries (megbrain libprotobuf libzmq) | |||
set(CPPZMQ_INC ${PROJECT_SOURCE_DIR}/third_party/cppzmq) | |||
# FIXME: add CMAKE_CURRENT_BINARY_DIR for including mm_handler.pb.h | |||
target_include_directories(megbrain PRIVATE ${CPPZMQ_INC} ${CMAKE_CURRENT_BINARY_DIR}) | |||
target_link_libraries (megbrain megray) | |||
endif() | |||
target_link_libraries(megbrain ${MGE_CUDA_LIBS}) | |||
@@ -7,13 +7,14 @@ | |||
* | |||
*/ | |||
#include "mm_handler.h" | |||
#include "megbrain/opr/mm_handler.h" | |||
#include "megbrain/exception.h" | |||
#include "megbrain_config.h" | |||
#include "megbrain_build_config.h" | |||
#if MGB_ENABLE_OPR_MM | |||
#include "zmq_rpc.h" | |||
#include "megbrain/opr/zmq_rpc.h" | |||
#include "mm_handler.pb.h" | |||
#include <future> | |||
/* ======================== GroupServerProxy ========================== */ | |||
@@ -128,17 +129,22 @@ void GroupServerProxy::group_barrier(void* input_ptr, size_t input_len, | |||
Request req; \ | |||
Response rsp; | |||
#define SOLVE_REQUEST(name, req, rsp) \ | |||
std::string req_str; \ | |||
mgb_assert(req.SerializeToString(&req_str)); \ | |||
zmq::message_t send(req_str.length() + name.length() + 1); \ | |||
zmq::message_t recv; \ | |||
memcpy(send.data(), name.data(), name.length() + 1); \ | |||
memcpy((char*)send.data() + name.length() + 1, req_str.data(), \ | |||
req_str.length()); \ | |||
m_stub->request(send, recv); \ | |||
#define SOLVE_REQUEST(name, req, rsp) \ | |||
std::string req_str; \ | |||
mgb_assert(req.SerializeToString(&req_str)); \ | |||
zmq::message_t send(req_str.length() + name.length() + 1); \ | |||
zmq::message_t recv; \ | |||
memcpy(send.data(), name.data(), name.length() + 1); \ | |||
memcpy((char*)send.data() + name.length() + 1, req_str.data(), \ | |||
req_str.length()); \ | |||
static_cast<ZmqRpc::ZmqRpcClient*>(m_stub)->request(send, recv); \ | |||
mgb_assert(rsp.ParseFromArray(recv.data(), recv.size())); | |||
GroupClientProxy::GroupClientProxy(const std::string& server_addr) | |||
: m_addr(server_addr), | |||
m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} { | |||
} | |||
uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices, | |||
uint32_t rank, uintptr_t stream) { | |||
INFO_INIT(mm_handler, opr_register, OprRegister) | |||
@@ -199,78 +205,26 @@ uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | |||
#undef INFO_INIT | |||
#undef SOLVE_REQUEST | |||
/* ======================== ZmqRpcServerMgr ========================== */ | |||
class ZmqRpcServerMgr { | |||
struct ServerInfo { | |||
std::unique_ptr<ZmqRpc::ZmqRpcServer> server; | |||
}; | |||
public: | |||
int create_zmqrpc_server(const std::string& server_addr, int port, | |||
std::unique_ptr<ZmqRpc::ZmqRpcServerImpl> service) { | |||
MGB_LOCK_GUARD(m_mtx); | |||
auto server = | |||
std::make_unique<ZmqRpc::ZmqRpcServer>("tcp://" + server_addr, port, | |||
std::move(service)); | |||
port = server->port(); | |||
if (port == -1) { | |||
return -1; | |||
} | |||
auto full_srv_addr = ssprintf("%s:%d", server_addr.c_str(), port); | |||
server->run(); | |||
auto ins = m_addr2server.emplace( | |||
full_srv_addr, ServerInfo{std::move(server)}); | |||
mgb_assert(ins.second); | |||
return port; | |||
} | |||
static ZmqRpcServerMgr* get_zmqrpc_server_mgr() { | |||
static ZmqRpcServerMgr mgr; | |||
return &mgr; | |||
} | |||
private: | |||
std::unordered_map<std::string, ServerInfo> m_addr2server; | |||
std::mutex m_mtx; | |||
struct ServerInfo { | |||
std::unique_ptr<ZmqRpc::ZmqRpcServer> server; | |||
}; | |||
/*! see definition : src/cpp/megbrain_config.h. | |||
* Create mm server. port 0 is permitted, leave zmqrpc to decide which port | |||
* should be used. | |||
*/ | |||
int _config::create_mm_server(const std::string& server_addr, int port) { | |||
return ZmqRpcServerMgr::get_zmqrpc_server_mgr()->create_zmqrpc_server( | |||
server_addr, port, std::make_unique<GroupServerProxy>()); | |||
} | |||
/* ======================== Group Barrier ========================== */ | |||
/*! see definition : src/cpp/megbrain_config.h. | |||
* Block until all ranks in the group reach this barrier | |||
*/ | |||
void _config::group_barrier(const std::string& server_addr, | |||
int port, uint32_t size, uint32_t rank) { | |||
mgb_assert(rank < size, "invalid rank %d", rank); | |||
auto group_mgr = std::make_shared<GroupClientProxy>( | |||
ssprintf("%s:%d", server_addr.c_str(), port)); | |||
uint32_t rsp = group_mgr->group_barrier(size, rank); | |||
mgb_assert(rsp != 0, "rank already registered: %d", rank); | |||
mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp); | |||
} | |||
#else | |||
int _config::create_mm_server(const std::string& server_addr, int port) { | |||
mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time"); | |||
return 0; | |||
} | |||
void _config::group_barrier(const std::string& server_addr, | |||
int port, uint32_t size, uint32_t rank) { | |||
mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time"); | |||
int create_zmqrpc_server(const std::string& server_addr, int port) { | |||
static std::unordered_map<std::string, ServerInfo> addr2server; | |||
static std::mutex mtx; | |||
MGB_LOCK_GUARD(mtx); | |||
auto service = std::make_unique<GroupServerProxy>(); | |||
auto server = | |||
std::make_unique<ZmqRpc::ZmqRpcServer>("tcp://" + server_addr, port, | |||
std::move(service)); | |||
port = server->port(); | |||
auto full_srv_addr = ssprintf("%s:%d", server_addr.c_str(), port); | |||
server->run(); | |||
auto ins = addr2server.emplace( | |||
full_srv_addr, ServerInfo{std::move(server)}); | |||
mgb_assert(ins.second); | |||
return port; | |||
} | |||
#endif |
@@ -1,6 +1,6 @@ | |||
#include "zmq_rpc.h" | |||
#include "megbrain/opr/zmq_rpc.h" | |||
#include "megbrain/exception.h" | |||
#include "megbrain_config.h" | |||
#include "megbrain_build_config.h" | |||
#if MGB_CUDA | |||
#include <unistd.h> |
@@ -13,10 +13,7 @@ | |||
#if MGB_ENABLE_OPR_MM | |||
#include "zmq_rpc.h" | |||
#include "megbrain/opr/collective_comm.h" | |||
#include "mm_handler.pb.h" | |||
using namespace mgb; | |||
using namespace opr; | |||
@@ -31,10 +28,7 @@ class GroupClientProxy | |||
public: | |||
virtual ~GroupClientProxy() = default; | |||
GroupClientProxy(const std::string& server_addr) | |||
: m_addr(server_addr), | |||
m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} { | |||
} | |||
GroupClientProxy(const std::string& server_addr); | |||
//! graph registration, assign graph_id to worker. | |||
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, | |||
@@ -50,33 +44,20 @@ public: | |||
uint32_t group_barrier(uint32_t size, uint32_t rank) override; | |||
//! thread safe to create handler with address | |||
static GroupClientProxy* get_handler(const std::string& addr) { | |||
static std::unordered_map<std::string, | |||
std::unique_ptr<GroupClientProxy>> | |||
addr2handler; | |||
static std::mutex mtx; | |||
MGB_LOCK_GUARD(mtx); | |||
auto it = addr2handler.emplace(addr, nullptr); | |||
if (!it.second) { | |||
mgb_assert(it.first->second->m_addr == addr); | |||
return it.first->second.get(); | |||
} else { | |||
auto handler = std::make_unique<GroupClientProxy>(addr); | |||
auto handler_ptr = handler.get(); | |||
it.first->second = std::move(handler); | |||
return handler_ptr; | |||
} | |||
} | |||
const std::string& get_addr() const { | |||
return m_addr; | |||
} | |||
private: | |||
const std::string m_addr; | |||
ZmqRpc::ZmqRpcClient* m_stub; | |||
void* m_stub; | |||
}; | |||
/* ======================== ZmqRpcServerMgr ========================== */ | |||
int create_zmqrpc_server(const std::string& server_addr, int port); | |||
#endif | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -101,4 +101,4 @@ private: | |||
std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets; | |||
}; | |||
} // namespace ZmqRpc | |||
#endif | |||
#endif |