Browse Source

refactor(mgb): move mm_handler from python module into opr-mm

GitOrigin-RevId: f401ce8603
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
d7bb62cfa1
10 changed files with 97 additions and 133 deletions
  1. +1
    -0
      cmake/zmq.cmake
  2. +2
    -20
      python_module/CMakeLists.txt
  3. +37
    -0
      python_module/src/cpp/megbrain_config.cpp
  4. +1
    -1
      python_module/src/cpp/opr_defs.cpp
  5. +9
    -0
      src/CMakeLists.txt
  6. +36
    -82
      src/opr-mm/impl/mm_handler.cpp
  7. +2
    -2
      src/opr-mm/impl/zmq_rpc.cpp
  8. +8
    -27
      src/opr-mm/include/megbrain/opr/mm_handler.h
  9. +1
    -1
      src/opr-mm/include/megbrain/opr/zmq_rpc.h
  10. +0
    -0
      src/opr-mm/proto/mm_handler.proto

+ 1
- 0
cmake/zmq.cmake View File

@@ -14,6 +14,7 @@ ExternalProject_add(
) )


set(ZMQ_INC ${ZMQ_BUILD_DIR}/include) set(ZMQ_INC ${ZMQ_BUILD_DIR}/include)
include_directories(${ZMQ_INC})
file(MAKE_DIRECTORY ${ZMQ_INC}) file(MAKE_DIRECTORY ${ZMQ_INC})


add_library(libzmq STATIC IMPORTED GLOBAL) add_library(libzmq STATIC IMPORTED GLOBAL)


+ 2
- 20
python_module/CMakeLists.txt View File

@@ -12,14 +12,6 @@ set(SWIG_SRC src/swig/mgb.i)
set(CMAKE_SWIG_FLAGS -Wall -threads -py3 -modern -DSWIGWORDSIZE64) set(CMAKE_SWIG_FLAGS -Wall -threads -py3 -modern -DSWIGWORDSIZE64)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter")


if(MGE_WITH_DISTRIBUTED)
file(GLOB_RECURSE PROTO_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "src/proto/*.proto")

PROTOBUF_GENERATE_CPP_WITH_ROOT(GRPC_SRCS GRPC_HDRS ${CMAKE_CURRENT_SOURCE_DIR} ${PROTO_FILES})
add_custom_target(mgb_proto_target DEPENDS ${GRPC_SRCS} ${GRPC_HDRS} ${PROTOBUF_PROTOC_EXECUTABLE})
endif()


file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl") file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl")
file(GLOB_RECURSE PYTHON_SRCS setup.py file(GLOB_RECURSE PYTHON_SRCS setup.py
src/python/*.py src/python/*.py
@@ -55,11 +47,7 @@ add_custom_command(


add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py) add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py)


set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/mm_handler.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)

if(MGE_WITH_DISTRIBUTED)
list(APPEND SRCS src/cpp/zmq_rpc.cpp)
endif()
set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)


include(UseSWIG) include(UseSWIG)
set_property(SOURCE ${SWIG_SRC} PROPERTY CPLUSPLUS ON) set_property(SOURCE ${SWIG_SRC} PROPERTY CPLUSPLUS ON)
@@ -70,7 +58,7 @@ set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS -I${PROJECT_SOURCE_DIR}/src/


set(SWIG_OUTFILE_DIR ${CMAKE_CURRENT_BINARY_DIR}) set(SWIG_OUTFILE_DIR ${CMAKE_CURRENT_BINARY_DIR})
set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal) set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal)
swig_add_library(mgb LANGUAGE python SOURCES ${SWIG_SRC} ${GRPC_SRCS} ${SRCS})
swig_add_library(mgb LANGUAGE python SOURCES ${SWIG_SRC} ${SRCS})


set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld) set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld)
add_custom_target(version_ld SOURCES ${VERSION_SCRIPT}) add_custom_target(version_ld SOURCES ${VERSION_SCRIPT})
@@ -81,12 +69,6 @@ target_include_directories(_mgb PRIVATE ${PYTHON_INCLUDE_DIRS} src/cpp ${CMAKE_C
target_link_libraries(_mgb ${PYTHON_LIBRARIES}) target_link_libraries(_mgb ${PYTHON_LIBRARIES})


add_dependencies(_mgb mgb_opr_py version_ld) add_dependencies(_mgb mgb_opr_py version_ld)
if(MGE_WITH_DISTRIBUTED)
add_dependencies(_mgb mgb_proto_target)
target_link_libraries (_mgb libprotobuf libzmq)
set(CPPZMQ_INC ${PROJECT_SOURCE_DIR}/third_party/cppzmq)
target_include_directories(_mgb PRIVATE ${CPPZMQ_INC})
endif()


add_custom_command( add_custom_command(
TARGET _mgb POST_BUILD TARGET _mgb POST_BUILD


+ 37
- 0
python_module/src/cpp/megbrain_config.cpp View File

@@ -19,6 +19,10 @@


#include <dlfcn.h> #include <dlfcn.h>


#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/mm_handler.h"
#endif

#if MGB_CUDA #if MGB_CUDA
#include <cuda.h> #include <cuda.h>
#endif #endif
@@ -276,4 +280,37 @@ std::vector<std::pair<uint64_t, std::string>> _config::dump_registered_oprs() {
#endif #endif
} }


#if MGB_ENABLE_OPR_MM
/*! see definition : src/cpp/megbrain_config.h.
* Create mm server. port 0 is permitted, leave zmqrpc to decide which port
* should be used.
*/
int _config::create_mm_server(const std::string& server_addr, int port) {
return create_zmqrpc_server(server_addr, port);
}

void _config::group_barrier(const std::string& server_addr,
int port, uint32_t size, uint32_t rank) {
mgb_assert(rank < size, "invalid rank %d", rank);
auto group_mgr = std::make_shared<GroupClientProxy>(
ssprintf("%s:%d", server_addr.c_str(), port));
uint32_t rsp = group_mgr->group_barrier(size, rank);
mgb_assert(rsp != 0, "rank already registered: %d", rank);
mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp);
}

#else

int _config::create_mm_server(const std::string& server_addr, int port) {
mgb_throw(mgb::MegBrainError, "OPR_MM suppport disable at compile time");
return 0;
}

void _config::group_barrier(const std::string& server_addr,
int port, uint32_t size, uint32_t rank) {
mgb_throw(mgb::MegBrainError, "OPR_MM suppport disable at compile time");
}

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 1
- 1
python_module/src/cpp/opr_defs.cpp View File

@@ -12,7 +12,7 @@
#include "./python_helper.h" #include "./python_helper.h"


#if MGB_ENABLE_OPR_MM #if MGB_ENABLE_OPR_MM
#include "mm_handler.h"
#include "megbrain/opr/mm_handler.h"
#endif #endif


#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"


+ 9
- 0
src/CMakeLists.txt View File

@@ -10,6 +10,10 @@ endif()
if(MGE_WITH_DISTRIBUTED) if(MGE_WITH_DISTRIBUTED)
file(GLOB_RECURSE SOURCES_ opr-mm/impl/*.cpp opr-mm/impl/*.inl) file(GLOB_RECURSE SOURCES_ opr-mm/impl/*.cpp opr-mm/impl/*.inl)
list(APPEND SOURCES ${SOURCES_}) list(APPEND SOURCES ${SOURCES_})
file(GLOB_RECURSE PROTO_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../src/opr-mm/proto/*.proto")
PROTOBUF_GENERATE_CPP_WITH_ROOT(GRPC_SRCS GRPC_HDRS ${CMAKE_CURRENT_SOURCE_DIR} ${PROTO_FILES})
add_custom_target(mgb_proto_target DEPENDS ${GRPC_SRCS} ${GRPC_HDRS} ${PROTOBUF_PROTOC_EXECUTABLE})
list(APPEND SOURCES ${GRPC_SRCS})
endif() endif()


set(MGB_INC ${PROJECT_BINARY_DIR}/genfiles core/include gopt/include opr/include plugin/include serialization/include) set(MGB_INC ${PROJECT_BINARY_DIR}/genfiles core/include gopt/include opr/include plugin/include serialization/include)
@@ -52,6 +56,11 @@ if(CXX_SUPPORT_WCLASS_MEMACCESS)
endif() endif()
target_link_libraries(megbrain megdnn) target_link_libraries(megbrain megdnn)
if(MGE_WITH_DISTRIBUTED) if(MGE_WITH_DISTRIBUTED)
add_dependencies(megbrain mgb_proto_target)
target_link_libraries (megbrain libprotobuf libzmq)
set(CPPZMQ_INC ${PROJECT_SOURCE_DIR}/third_party/cppzmq)
# FIXME: add CMAKE_CURRENT_BINARY_DIR for including mm_handler.pb.h
target_include_directories(megbrain PRIVATE ${CPPZMQ_INC} ${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries (megbrain megray) target_link_libraries (megbrain megray)
endif() endif()
target_link_libraries(megbrain ${MGE_CUDA_LIBS}) target_link_libraries(megbrain ${MGE_CUDA_LIBS})


python_module/src/cpp/mm_handler.cpp → src/opr-mm/impl/mm_handler.cpp View File

@@ -7,13 +7,14 @@
* *
*/ */


#include "mm_handler.h"
#include "megbrain/opr/mm_handler.h"


#include "megbrain/exception.h" #include "megbrain/exception.h"
#include "megbrain_config.h"
#include "megbrain_build_config.h"


#if MGB_ENABLE_OPR_MM #if MGB_ENABLE_OPR_MM
#include "zmq_rpc.h"
#include "megbrain/opr/zmq_rpc.h"
#include "mm_handler.pb.h"
#include <future> #include <future>


/* ======================== GroupServerProxy ========================== */ /* ======================== GroupServerProxy ========================== */
@@ -128,17 +129,22 @@ void GroupServerProxy::group_barrier(void* input_ptr, size_t input_len,
Request req; \ Request req; \
Response rsp; Response rsp;


#define SOLVE_REQUEST(name, req, rsp) \
std::string req_str; \
mgb_assert(req.SerializeToString(&req_str)); \
zmq::message_t send(req_str.length() + name.length() + 1); \
zmq::message_t recv; \
memcpy(send.data(), name.data(), name.length() + 1); \
memcpy((char*)send.data() + name.length() + 1, req_str.data(), \
req_str.length()); \
m_stub->request(send, recv); \
#define SOLVE_REQUEST(name, req, rsp) \
std::string req_str; \
mgb_assert(req.SerializeToString(&req_str)); \
zmq::message_t send(req_str.length() + name.length() + 1); \
zmq::message_t recv; \
memcpy(send.data(), name.data(), name.length() + 1); \
memcpy((char*)send.data() + name.length() + 1, req_str.data(), \
req_str.length()); \
static_cast<ZmqRpc::ZmqRpcClient*>(m_stub)->request(send, recv); \
mgb_assert(rsp.ParseFromArray(recv.data(), recv.size())); mgb_assert(rsp.ParseFromArray(recv.data(), recv.size()));


GroupClientProxy::GroupClientProxy(const std::string& server_addr)
: m_addr(server_addr),
m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} {
}

uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices, uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices,
uint32_t rank, uintptr_t stream) { uint32_t rank, uintptr_t stream) {
INFO_INIT(mm_handler, opr_register, OprRegister) INFO_INIT(mm_handler, opr_register, OprRegister)
@@ -199,78 +205,26 @@ uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) {
#undef INFO_INIT #undef INFO_INIT
#undef SOLVE_REQUEST #undef SOLVE_REQUEST


/* ======================== ZmqRpcServerMgr ========================== */

class ZmqRpcServerMgr {
struct ServerInfo {
std::unique_ptr<ZmqRpc::ZmqRpcServer> server;
};

public:
int create_zmqrpc_server(const std::string& server_addr, int port,
std::unique_ptr<ZmqRpc::ZmqRpcServerImpl> service) {
MGB_LOCK_GUARD(m_mtx);
auto server =
std::make_unique<ZmqRpc::ZmqRpcServer>("tcp://" + server_addr, port,
std::move(service));
port = server->port();
if (port == -1) {
return -1;
}

auto full_srv_addr = ssprintf("%s:%d", server_addr.c_str(), port);
server->run();
auto ins = m_addr2server.emplace(
full_srv_addr, ServerInfo{std::move(server)});
mgb_assert(ins.second);

return port;
}

static ZmqRpcServerMgr* get_zmqrpc_server_mgr() {
static ZmqRpcServerMgr mgr;
return &mgr;
}

private:
std::unordered_map<std::string, ServerInfo> m_addr2server;
std::mutex m_mtx;
struct ServerInfo {
std::unique_ptr<ZmqRpc::ZmqRpcServer> server;
}; };


/*! see definition : src/cpp/megbrain_config.h.
* Create mm server. port 0 is permitted, leave zmqrpc to decide which port
* should be used.
*/
int _config::create_mm_server(const std::string& server_addr, int port) {
return ZmqRpcServerMgr::get_zmqrpc_server_mgr()->create_zmqrpc_server(
server_addr, port, std::make_unique<GroupServerProxy>());
}

/* ======================== Group Barrier ========================== */

/*! see definition : src/cpp/megbrain_config.h.
* Block until all ranks in the group reach this barrier
*/
void _config::group_barrier(const std::string& server_addr,
int port, uint32_t size, uint32_t rank) {
mgb_assert(rank < size, "invalid rank %d", rank);
auto group_mgr = std::make_shared<GroupClientProxy>(
ssprintf("%s:%d", server_addr.c_str(), port));
uint32_t rsp = group_mgr->group_barrier(size, rank);
mgb_assert(rsp != 0, "rank already registered: %d", rank);
mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp);
}

#else

int _config::create_mm_server(const std::string& server_addr, int port) {
mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
return 0;
}

void _config::group_barrier(const std::string& server_addr,
int port, uint32_t size, uint32_t rank) {
mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
int create_zmqrpc_server(const std::string& server_addr, int port) {
static std::unordered_map<std::string, ServerInfo> addr2server;
static std::mutex mtx;
MGB_LOCK_GUARD(mtx);
auto service = std::make_unique<GroupServerProxy>();
auto server =
std::make_unique<ZmqRpc::ZmqRpcServer>("tcp://" + server_addr, port,
std::move(service));
port = server->port();
auto full_srv_addr = ssprintf("%s:%d", server_addr.c_str(), port);
server->run();
auto ins = addr2server.emplace(
full_srv_addr, ServerInfo{std::move(server)});
mgb_assert(ins.second);

return port;
} }


#endif #endif

python_module/src/cpp/zmq_rpc.cpp → src/opr-mm/impl/zmq_rpc.cpp View File

@@ -1,6 +1,6 @@
#include "zmq_rpc.h"
#include "megbrain/opr/zmq_rpc.h"
#include "megbrain/exception.h" #include "megbrain/exception.h"
#include "megbrain_config.h"
#include "megbrain_build_config.h"


#if MGB_CUDA #if MGB_CUDA
#include <unistd.h> #include <unistd.h>

python_module/src/cpp/mm_handler.h → src/opr-mm/include/megbrain/opr/mm_handler.h View File

@@ -13,10 +13,7 @@


#if MGB_ENABLE_OPR_MM #if MGB_ENABLE_OPR_MM


#include "zmq_rpc.h"

#include "megbrain/opr/collective_comm.h" #include "megbrain/opr/collective_comm.h"
#include "mm_handler.pb.h"


using namespace mgb; using namespace mgb;
using namespace opr; using namespace opr;
@@ -31,10 +28,7 @@ class GroupClientProxy
public: public:
virtual ~GroupClientProxy() = default; virtual ~GroupClientProxy() = default;


GroupClientProxy(const std::string& server_addr)
: m_addr(server_addr),
m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} {
}
GroupClientProxy(const std::string& server_addr);


//! graph registration, assign graph_id to worker. //! graph registration, assign graph_id to worker.
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
@@ -50,33 +44,20 @@ public:


uint32_t group_barrier(uint32_t size, uint32_t rank) override; uint32_t group_barrier(uint32_t size, uint32_t rank) override;


//! thread safe to create handler with address
static GroupClientProxy* get_handler(const std::string& addr) {
static std::unordered_map<std::string,
std::unique_ptr<GroupClientProxy>>
addr2handler;
static std::mutex mtx;
MGB_LOCK_GUARD(mtx);
auto it = addr2handler.emplace(addr, nullptr);
if (!it.second) {
mgb_assert(it.first->second->m_addr == addr);
return it.first->second.get();
} else {
auto handler = std::make_unique<GroupClientProxy>(addr);
auto handler_ptr = handler.get();
it.first->second = std::move(handler);
return handler_ptr;
}
}

const std::string& get_addr() const { const std::string& get_addr() const {
return m_addr; return m_addr;
} }


private: private:
const std::string m_addr; const std::string m_addr;
ZmqRpc::ZmqRpcClient* m_stub;
void* m_stub;
}; };


/* ======================== ZmqRpcServerMgr ========================== */

int create_zmqrpc_server(const std::string& server_addr, int port);

#endif #endif


// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

python_module/src/cpp/zmq_rpc.h → src/opr-mm/include/megbrain/opr/zmq_rpc.h View File

@@ -101,4 +101,4 @@ private:
std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets; std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets;
}; };
} // namespace ZmqRpc } // namespace ZmqRpc
#endif
#endif

python_module/src/proto/mm_handler.proto → src/opr-mm/proto/mm_handler.proto View File


Loading…
Cancel
Save