GitOrigin-RevId: f401ce8603
tags/v0.5.0
@@ -14,6 +14,7 @@ ExternalProject_add( | |||||
) | ) | ||||
set(ZMQ_INC ${ZMQ_BUILD_DIR}/include) | set(ZMQ_INC ${ZMQ_BUILD_DIR}/include) | ||||
include_directories(${ZMQ_INC}) | |||||
file(MAKE_DIRECTORY ${ZMQ_INC}) | file(MAKE_DIRECTORY ${ZMQ_INC}) | ||||
add_library(libzmq STATIC IMPORTED GLOBAL) | add_library(libzmq STATIC IMPORTED GLOBAL) | ||||
@@ -12,14 +12,6 @@ set(SWIG_SRC src/swig/mgb.i) | |||||
set(CMAKE_SWIG_FLAGS -Wall -threads -py3 -modern -DSWIGWORDSIZE64) | set(CMAKE_SWIG_FLAGS -Wall -threads -py3 -modern -DSWIGWORDSIZE64) | ||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") | ||||
if(MGE_WITH_DISTRIBUTED) | |||||
file(GLOB_RECURSE PROTO_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "src/proto/*.proto") | |||||
PROTOBUF_GENERATE_CPP_WITH_ROOT(GRPC_SRCS GRPC_HDRS ${CMAKE_CURRENT_SOURCE_DIR} ${PROTO_FILES}) | |||||
add_custom_target(mgb_proto_target DEPENDS ${GRPC_SRCS} ${GRPC_HDRS} ${PROTOBUF_PROTOC_EXECUTABLE}) | |||||
endif() | |||||
file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl") | file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl") | ||||
file(GLOB_RECURSE PYTHON_SRCS setup.py | file(GLOB_RECURSE PYTHON_SRCS setup.py | ||||
src/python/*.py | src/python/*.py | ||||
@@ -55,11 +47,7 @@ add_custom_command( | |||||
add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py) | add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py) | ||||
set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/mm_handler.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp) | |||||
if(MGE_WITH_DISTRIBUTED) | |||||
list(APPEND SRCS src/cpp/zmq_rpc.cpp) | |||||
endif() | |||||
set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp) | |||||
include(UseSWIG) | include(UseSWIG) | ||||
set_property(SOURCE ${SWIG_SRC} PROPERTY CPLUSPLUS ON) | set_property(SOURCE ${SWIG_SRC} PROPERTY CPLUSPLUS ON) | ||||
@@ -70,7 +58,7 @@ set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS -I${PROJECT_SOURCE_DIR}/src/ | |||||
set(SWIG_OUTFILE_DIR ${CMAKE_CURRENT_BINARY_DIR}) | set(SWIG_OUTFILE_DIR ${CMAKE_CURRENT_BINARY_DIR}) | ||||
set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal) | set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal) | ||||
swig_add_library(mgb LANGUAGE python SOURCES ${SWIG_SRC} ${GRPC_SRCS} ${SRCS}) | |||||
swig_add_library(mgb LANGUAGE python SOURCES ${SWIG_SRC} ${SRCS}) | |||||
set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld) | set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld) | ||||
add_custom_target(version_ld SOURCES ${VERSION_SCRIPT}) | add_custom_target(version_ld SOURCES ${VERSION_SCRIPT}) | ||||
@@ -81,12 +69,6 @@ target_include_directories(_mgb PRIVATE ${PYTHON_INCLUDE_DIRS} src/cpp ${CMAKE_C | |||||
target_link_libraries(_mgb ${PYTHON_LIBRARIES}) | target_link_libraries(_mgb ${PYTHON_LIBRARIES}) | ||||
add_dependencies(_mgb mgb_opr_py version_ld) | add_dependencies(_mgb mgb_opr_py version_ld) | ||||
if(MGE_WITH_DISTRIBUTED) | |||||
add_dependencies(_mgb mgb_proto_target) | |||||
target_link_libraries (_mgb libprotobuf libzmq) | |||||
set(CPPZMQ_INC ${PROJECT_SOURCE_DIR}/third_party/cppzmq) | |||||
target_include_directories(_mgb PRIVATE ${CPPZMQ_INC}) | |||||
endif() | |||||
add_custom_command( | add_custom_command( | ||||
TARGET _mgb POST_BUILD | TARGET _mgb POST_BUILD | ||||
@@ -19,6 +19,10 @@ | |||||
#include <dlfcn.h> | #include <dlfcn.h> | ||||
#if MGB_ENABLE_OPR_MM | |||||
#include "megbrain/opr/mm_handler.h" | |||||
#endif | |||||
#if MGB_CUDA | #if MGB_CUDA | ||||
#include <cuda.h> | #include <cuda.h> | ||||
#endif | #endif | ||||
@@ -276,4 +280,37 @@ std::vector<std::pair<uint64_t, std::string>> _config::dump_registered_oprs() { | |||||
#endif | #endif | ||||
} | } | ||||
#if MGB_ENABLE_OPR_MM | |||||
/*! see definition : src/cpp/megbrain_config.h. | |||||
* Create mm server. port 0 is permitted, leave zmqrpc to decide which port | |||||
* should be used. | |||||
*/ | |||||
int _config::create_mm_server(const std::string& server_addr, int port) { | |||||
return create_zmqrpc_server(server_addr, port); | |||||
} | |||||
void _config::group_barrier(const std::string& server_addr, | |||||
int port, uint32_t size, uint32_t rank) { | |||||
mgb_assert(rank < size, "invalid rank %d", rank); | |||||
auto group_mgr = std::make_shared<GroupClientProxy>( | |||||
ssprintf("%s:%d", server_addr.c_str(), port)); | |||||
uint32_t rsp = group_mgr->group_barrier(size, rank); | |||||
mgb_assert(rsp != 0, "rank already registered: %d", rank); | |||||
mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp); | |||||
} | |||||
#else | |||||
int _config::create_mm_server(const std::string& server_addr, int port) { | |||||
mgb_throw(mgb::MegBrainError, "OPR_MM suppport disable at compile time"); | |||||
return 0; | |||||
} | |||||
void _config::group_barrier(const std::string& server_addr, | |||||
int port, uint32_t size, uint32_t rank) { | |||||
mgb_throw(mgb::MegBrainError, "OPR_MM suppport disable at compile time"); | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -12,7 +12,7 @@ | |||||
#include "./python_helper.h" | #include "./python_helper.h" | ||||
#if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
#include "mm_handler.h" | |||||
#include "megbrain/opr/mm_handler.h" | |||||
#endif | #endif | ||||
#include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
@@ -10,6 +10,10 @@ endif() | |||||
if(MGE_WITH_DISTRIBUTED) | if(MGE_WITH_DISTRIBUTED) | ||||
file(GLOB_RECURSE SOURCES_ opr-mm/impl/*.cpp opr-mm/impl/*.inl) | file(GLOB_RECURSE SOURCES_ opr-mm/impl/*.cpp opr-mm/impl/*.inl) | ||||
list(APPEND SOURCES ${SOURCES_}) | list(APPEND SOURCES ${SOURCES_}) | ||||
file(GLOB_RECURSE PROTO_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../src/opr-mm/proto/*.proto") | |||||
PROTOBUF_GENERATE_CPP_WITH_ROOT(GRPC_SRCS GRPC_HDRS ${CMAKE_CURRENT_SOURCE_DIR} ${PROTO_FILES}) | |||||
add_custom_target(mgb_proto_target DEPENDS ${GRPC_SRCS} ${GRPC_HDRS} ${PROTOBUF_PROTOC_EXECUTABLE}) | |||||
list(APPEND SOURCES ${GRPC_SRCS}) | |||||
endif() | endif() | ||||
set(MGB_INC ${PROJECT_BINARY_DIR}/genfiles core/include gopt/include opr/include plugin/include serialization/include) | set(MGB_INC ${PROJECT_BINARY_DIR}/genfiles core/include gopt/include opr/include plugin/include serialization/include) | ||||
@@ -52,6 +56,11 @@ if(CXX_SUPPORT_WCLASS_MEMACCESS) | |||||
endif() | endif() | ||||
target_link_libraries(megbrain megdnn) | target_link_libraries(megbrain megdnn) | ||||
if(MGE_WITH_DISTRIBUTED) | if(MGE_WITH_DISTRIBUTED) | ||||
add_dependencies(megbrain mgb_proto_target) | |||||
target_link_libraries (megbrain libprotobuf libzmq) | |||||
set(CPPZMQ_INC ${PROJECT_SOURCE_DIR}/third_party/cppzmq) | |||||
# FIXME: add CMAKE_CURRENT_BINARY_DIR for including mm_handler.pb.h | |||||
target_include_directories(megbrain PRIVATE ${CPPZMQ_INC} ${CMAKE_CURRENT_BINARY_DIR}) | |||||
target_link_libraries (megbrain megray) | target_link_libraries (megbrain megray) | ||||
endif() | endif() | ||||
target_link_libraries(megbrain ${MGE_CUDA_LIBS}) | target_link_libraries(megbrain ${MGE_CUDA_LIBS}) | ||||
@@ -7,13 +7,14 @@ | |||||
* | * | ||||
*/ | */ | ||||
#include "mm_handler.h" | |||||
#include "megbrain/opr/mm_handler.h" | |||||
#include "megbrain/exception.h" | #include "megbrain/exception.h" | ||||
#include "megbrain_config.h" | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
#include "zmq_rpc.h" | |||||
#include "megbrain/opr/zmq_rpc.h" | |||||
#include "mm_handler.pb.h" | |||||
#include <future> | #include <future> | ||||
/* ======================== GroupServerProxy ========================== */ | /* ======================== GroupServerProxy ========================== */ | ||||
@@ -128,17 +129,22 @@ void GroupServerProxy::group_barrier(void* input_ptr, size_t input_len, | |||||
Request req; \ | Request req; \ | ||||
Response rsp; | Response rsp; | ||||
#define SOLVE_REQUEST(name, req, rsp) \ | |||||
std::string req_str; \ | |||||
mgb_assert(req.SerializeToString(&req_str)); \ | |||||
zmq::message_t send(req_str.length() + name.length() + 1); \ | |||||
zmq::message_t recv; \ | |||||
memcpy(send.data(), name.data(), name.length() + 1); \ | |||||
memcpy((char*)send.data() + name.length() + 1, req_str.data(), \ | |||||
req_str.length()); \ | |||||
m_stub->request(send, recv); \ | |||||
#define SOLVE_REQUEST(name, req, rsp) \ | |||||
std::string req_str; \ | |||||
mgb_assert(req.SerializeToString(&req_str)); \ | |||||
zmq::message_t send(req_str.length() + name.length() + 1); \ | |||||
zmq::message_t recv; \ | |||||
memcpy(send.data(), name.data(), name.length() + 1); \ | |||||
memcpy((char*)send.data() + name.length() + 1, req_str.data(), \ | |||||
req_str.length()); \ | |||||
static_cast<ZmqRpc::ZmqRpcClient*>(m_stub)->request(send, recv); \ | |||||
mgb_assert(rsp.ParseFromArray(recv.data(), recv.size())); | mgb_assert(rsp.ParseFromArray(recv.data(), recv.size())); | ||||
GroupClientProxy::GroupClientProxy(const std::string& server_addr) | |||||
: m_addr(server_addr), | |||||
m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} { | |||||
} | |||||
uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices, | uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices, | ||||
uint32_t rank, uintptr_t stream) { | uint32_t rank, uintptr_t stream) { | ||||
INFO_INIT(mm_handler, opr_register, OprRegister) | INFO_INIT(mm_handler, opr_register, OprRegister) | ||||
@@ -199,78 +205,26 @@ uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | |||||
#undef INFO_INIT | #undef INFO_INIT | ||||
#undef SOLVE_REQUEST | #undef SOLVE_REQUEST | ||||
/* ======================== ZmqRpcServerMgr ========================== */ | |||||
class ZmqRpcServerMgr { | |||||
struct ServerInfo { | |||||
std::unique_ptr<ZmqRpc::ZmqRpcServer> server; | |||||
}; | |||||
public: | |||||
int create_zmqrpc_server(const std::string& server_addr, int port, | |||||
std::unique_ptr<ZmqRpc::ZmqRpcServerImpl> service) { | |||||
MGB_LOCK_GUARD(m_mtx); | |||||
auto server = | |||||
std::make_unique<ZmqRpc::ZmqRpcServer>("tcp://" + server_addr, port, | |||||
std::move(service)); | |||||
port = server->port(); | |||||
if (port == -1) { | |||||
return -1; | |||||
} | |||||
auto full_srv_addr = ssprintf("%s:%d", server_addr.c_str(), port); | |||||
server->run(); | |||||
auto ins = m_addr2server.emplace( | |||||
full_srv_addr, ServerInfo{std::move(server)}); | |||||
mgb_assert(ins.second); | |||||
return port; | |||||
} | |||||
static ZmqRpcServerMgr* get_zmqrpc_server_mgr() { | |||||
static ZmqRpcServerMgr mgr; | |||||
return &mgr; | |||||
} | |||||
private: | |||||
std::unordered_map<std::string, ServerInfo> m_addr2server; | |||||
std::mutex m_mtx; | |||||
struct ServerInfo { | |||||
std::unique_ptr<ZmqRpc::ZmqRpcServer> server; | |||||
}; | }; | ||||
/*! see definition : src/cpp/megbrain_config.h. | |||||
* Create mm server. port 0 is permitted, leave zmqrpc to decide which port | |||||
* should be used. | |||||
*/ | |||||
int _config::create_mm_server(const std::string& server_addr, int port) { | |||||
return ZmqRpcServerMgr::get_zmqrpc_server_mgr()->create_zmqrpc_server( | |||||
server_addr, port, std::make_unique<GroupServerProxy>()); | |||||
} | |||||
/* ======================== Group Barrier ========================== */ | |||||
/*! see definition : src/cpp/megbrain_config.h. | |||||
* Block until all ranks in the group reach this barrier | |||||
*/ | |||||
void _config::group_barrier(const std::string& server_addr, | |||||
int port, uint32_t size, uint32_t rank) { | |||||
mgb_assert(rank < size, "invalid rank %d", rank); | |||||
auto group_mgr = std::make_shared<GroupClientProxy>( | |||||
ssprintf("%s:%d", server_addr.c_str(), port)); | |||||
uint32_t rsp = group_mgr->group_barrier(size, rank); | |||||
mgb_assert(rsp != 0, "rank already registered: %d", rank); | |||||
mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp); | |||||
} | |||||
#else | |||||
int _config::create_mm_server(const std::string& server_addr, int port) { | |||||
mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time"); | |||||
return 0; | |||||
} | |||||
void _config::group_barrier(const std::string& server_addr, | |||||
int port, uint32_t size, uint32_t rank) { | |||||
mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time"); | |||||
int create_zmqrpc_server(const std::string& server_addr, int port) { | |||||
static std::unordered_map<std::string, ServerInfo> addr2server; | |||||
static std::mutex mtx; | |||||
MGB_LOCK_GUARD(mtx); | |||||
auto service = std::make_unique<GroupServerProxy>(); | |||||
auto server = | |||||
std::make_unique<ZmqRpc::ZmqRpcServer>("tcp://" + server_addr, port, | |||||
std::move(service)); | |||||
port = server->port(); | |||||
auto full_srv_addr = ssprintf("%s:%d", server_addr.c_str(), port); | |||||
server->run(); | |||||
auto ins = addr2server.emplace( | |||||
full_srv_addr, ServerInfo{std::move(server)}); | |||||
mgb_assert(ins.second); | |||||
return port; | |||||
} | } | ||||
#endif | #endif |
@@ -1,6 +1,6 @@ | |||||
#include "zmq_rpc.h" | |||||
#include "megbrain/opr/zmq_rpc.h" | |||||
#include "megbrain/exception.h" | #include "megbrain/exception.h" | ||||
#include "megbrain_config.h" | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_CUDA | #if MGB_CUDA | ||||
#include <unistd.h> | #include <unistd.h> |
@@ -13,10 +13,7 @@ | |||||
#if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
#include "zmq_rpc.h" | |||||
#include "megbrain/opr/collective_comm.h" | #include "megbrain/opr/collective_comm.h" | ||||
#include "mm_handler.pb.h" | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace opr; | using namespace opr; | ||||
@@ -31,10 +28,7 @@ class GroupClientProxy | |||||
public: | public: | ||||
virtual ~GroupClientProxy() = default; | virtual ~GroupClientProxy() = default; | ||||
GroupClientProxy(const std::string& server_addr) | |||||
: m_addr(server_addr), | |||||
m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} { | |||||
} | |||||
GroupClientProxy(const std::string& server_addr); | |||||
//! graph registration, assign graph_id to worker. | //! graph registration, assign graph_id to worker. | ||||
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, | uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, | ||||
@@ -50,33 +44,20 @@ public: | |||||
uint32_t group_barrier(uint32_t size, uint32_t rank) override; | uint32_t group_barrier(uint32_t size, uint32_t rank) override; | ||||
//! thread safe to create handler with address | |||||
static GroupClientProxy* get_handler(const std::string& addr) { | |||||
static std::unordered_map<std::string, | |||||
std::unique_ptr<GroupClientProxy>> | |||||
addr2handler; | |||||
static std::mutex mtx; | |||||
MGB_LOCK_GUARD(mtx); | |||||
auto it = addr2handler.emplace(addr, nullptr); | |||||
if (!it.second) { | |||||
mgb_assert(it.first->second->m_addr == addr); | |||||
return it.first->second.get(); | |||||
} else { | |||||
auto handler = std::make_unique<GroupClientProxy>(addr); | |||||
auto handler_ptr = handler.get(); | |||||
it.first->second = std::move(handler); | |||||
return handler_ptr; | |||||
} | |||||
} | |||||
const std::string& get_addr() const { | const std::string& get_addr() const { | ||||
return m_addr; | return m_addr; | ||||
} | } | ||||
private: | private: | ||||
const std::string m_addr; | const std::string m_addr; | ||||
ZmqRpc::ZmqRpcClient* m_stub; | |||||
void* m_stub; | |||||
}; | }; | ||||
/* ======================== ZmqRpcServerMgr ========================== */ | |||||
int create_zmqrpc_server(const std::string& server_addr, int port); | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -101,4 +101,4 @@ private: | |||||
std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets; | std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets; | ||||
}; | }; | ||||
} // namespace ZmqRpc | } // namespace ZmqRpc | ||||
#endif | |||||
#endif |