
mm_handler.cpp 8.3 kB

/**
 * \file python_module/src/cpp/mm_handler.cpp
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 */

#include "megbrain/opr/mm_handler.h"

#include "megbrain/exception.h"
#include "megbrain_build_config.h"

#if MGB_ENABLE_OPR_MM

#include "megbrain/opr/zmq_rpc.h"
#include "mm_handler.pb.h"

#include <future>
/* ======================== GroupServerProxy ========================== */

/*!
 * A proxy that receives zmqrpc calls and forwards each request directly to
 * the group (NCCL) manager.
 */
#define RUNSERVER(rpc_name)                                    \
    if (std::strcmp(describe, #rpc_name) == 0) {               \
        std::string output;                                    \
        rpc_name(input_ptr, input_len, &output);               \
        reply.rebuild(output.length());                        \
        memcpy(reply.data(), output.data(), output.length());  \
        return;                                                \
    }
class GroupServerProxy final : public ZmqRpc::ZmqRpcServerImpl {
public:
    void solve_request(zmq::message_t& request,
                       zmq::message_t& reply) override {
        // Wire format: a NUL-terminated method name followed by the
        // serialized protobuf request payload.
        char* describe = (char*)request.data();
        void* input_ptr = (char*)request.data() + strlen(describe) + 1;
        size_t input_len = request.size() - strlen(describe) - 1;
        RUNSERVER(opr_register);
        RUNSERVER(set_output_shape);
        RUNSERVER(get_output_shape);
        RUNSERVER(gather_uid);
        RUNSERVER(group_barrier);
        mgb_assert(false, "invalid rpc request");
    }
private:
    void opr_register(void* input_ptr, size_t input_len, std::string* output);
    void set_output_shape(void* input_ptr, size_t input_len, std::string* output);
    void get_output_shape(void* input_ptr, size_t input_len, std::string* output);
    void gather_uid(void* input_ptr, size_t input_len, std::string* output);
    void group_barrier(void* input_ptr, size_t input_len, std::string* output);

private:
    GroupManager m_mgr;
};

#undef RUNSERVER
#define INFO_INIT(space, name)               \
    using Request = space::name##Request;    \
    using Response = space::name##Response;  \
    Request req;                             \
    Response rsp;                            \
    req.ParseFromArray(input_ptr, input_len);
void GroupServerProxy::opr_register(void* input_ptr, size_t input_len,
                                    std::string* output) {
    INFO_INIT(mm_handler, OprRegister);
    uint64_t hash = m_mgr.opr_register(req.key(), req.nr_expected_devices(),
                                       req.rank(), req.stream());
    rsp.set_hash(hash);
    rsp.SerializeToString(output);
}

void GroupServerProxy::set_output_shape(void* input_ptr, size_t input_len,
                                        std::string* output) {
    INFO_INIT(mm_handler, SetOutputShape);
    auto&& shape_proto = req.shape();
    TensorShape shape;
    shape.ndim = shape_proto.ndim();
    for (size_t i = 0; i < shape.ndim; ++i) {
        shape.shape[i] = shape_proto.shape(i);
    }
    m_mgr.set_output_shape(req.key(), shape);
    rsp.SerializeToString(output);
}

void GroupServerProxy::get_output_shape(void* input_ptr, size_t input_len,
                                        std::string* output) {
    INFO_INIT(mm_handler, GetOutputShape);
    auto shape = m_mgr.get_output_shape(req.key());
    auto&& shape_proto = *rsp.mutable_shape();
    shape_proto.set_ndim(shape.ndim);
    for (size_t i = 0; i < shape.ndim; ++i) {
        shape_proto.add_shape(shape[i]);
    }
    rsp.SerializeToString(output);
}

void GroupServerProxy::gather_uid(void* input_ptr, size_t input_len,
                                  std::string* output) {
    INFO_INIT(mm_handler, GatherUid);
    auto uid = req.uid();
    auto uids = m_mgr.gather_uid(uid, req.key(), req.size(), req.rank());
    for (size_t i = 0; i < uids.size(); i++) {
        rsp.add_uids();
        rsp.set_uids(i, uids[i].data(), uids[i].size());
    }
    rsp.SerializeToString(output);
}

void GroupServerProxy::group_barrier(void* input_ptr, size_t input_len,
                                     std::string* output) {
    INFO_INIT(mm_handler, GroupBarrier);
    uint32_t rsp_size = m_mgr.group_barrier(req.size(), req.rank());
    rsp.set_size(rsp_size);
    rsp.SerializeToString(output);
}
#undef INFO_INIT
/* ======================== GroupClientProxy ========================== */

#define INFO_INIT(space, f_name, name)       \
    using Request = space::name##Request;    \
    using Response = space::name##Response;  \
    std::string func_name = #f_name;         \
    Request req;                             \
    Response rsp;

// Pack the method name (NUL-terminated) and the serialized request into one
// zmq message, issue the RPC, and parse the reply into rsp.
#define SOLVE_REQUEST(name, req, rsp)                                \
    std::string req_str;                                             \
    mgb_assert(req.SerializeToString(&req_str));                     \
    zmq::message_t send(req_str.length() + name.length() + 1);       \
    zmq::message_t recv;                                             \
    memcpy(send.data(), name.data(), name.length() + 1);             \
    memcpy((char*)send.data() + name.length() + 1, req_str.data(),   \
           req_str.length());                                        \
    static_cast<ZmqRpc::ZmqRpcClient*>(m_stub)->request(send, recv); \
    mgb_assert(rsp.ParseFromArray(recv.data(), recv.size()));
GroupClientProxy::GroupClientProxy(const std::string& server_addr)
        : m_addr(server_addr),
          m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} {
}

uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices,
                                        uint32_t rank, uintptr_t stream) {
    INFO_INIT(mm_handler, opr_register, OprRegister)
    req.set_key(key);
    req.set_rank(rank);
    req.set_stream(stream);
    req.set_nr_expected_devices(nr_devices);
    SOLVE_REQUEST(func_name, req, rsp);
    return rsp.hash();
}

void GroupClientProxy::set_output_shape(const std::string& key,
                                        const TensorShape& shape) {
    INFO_INIT(mm_handler, set_output_shape, SetOutputShape)
    req.set_key(key);
    auto&& shape_proto = *req.mutable_shape();
    shape_proto.set_ndim(shape.ndim);
    for (size_t i = 0; i < shape.ndim; ++i) {
        shape_proto.add_shape(shape[i]);
    }
    SOLVE_REQUEST(func_name, req, rsp);
}

TensorShape GroupClientProxy::get_output_shape(const std::string& key) {
    INFO_INIT(mm_handler, get_output_shape, GetOutputShape)
    req.set_key(key);
    SOLVE_REQUEST(func_name, req, rsp);
    TensorShape shape;
    shape.ndim = rsp.shape().ndim();
    for (size_t i = 0; i < shape.ndim; ++i) {
        shape[i] = rsp.shape().shape(i);
    }
    return shape;
}

std::vector<std::string> GroupClientProxy::gather_uid(const std::string& uid,
        const std::string& key, uint32_t size, uint32_t rank) {
    INFO_INIT(mm_handler, gather_uid, GatherUid);
    req.set_uid(uid.data(), uid.size());
    req.set_key(key.data(), key.size());
    req.set_size(size);
    req.set_rank(rank);
    SOLVE_REQUEST(func_name, req, rsp);
    std::vector<std::string> rst;
    for (size_t i = 0; i < size; i++) {
        rst.push_back(rsp.uids(i));
    }
    return rst;
}

uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) {
    INFO_INIT(mm_handler, group_barrier, GroupBarrier);
    req.set_size(size);
    req.set_rank(rank);
    SOLVE_REQUEST(func_name, req, rsp);
    return rsp.size();
}

#undef INFO_INIT
#undef SOLVE_REQUEST
struct ServerInfo {
    std::unique_ptr<ZmqRpc::ZmqRpcServer> server;
};

int create_zmqrpc_server(const std::string& server_addr, int port) {
    static std::unordered_map<std::string, ServerInfo> addr2server;
    static std::mutex mtx;
    MGB_LOCK_GUARD(mtx);
    auto service = std::make_unique<GroupServerProxy>();
    auto server = std::make_unique<ZmqRpc::ZmqRpcServer>(
            "tcp://" + server_addr, port, std::move(service));
    port = server->port();
    auto full_srv_addr = ssprintf("%s:%d", server_addr.c_str(), port);
    server->run();
    auto ins = addr2server.emplace(full_srv_addr, ServerInfo{std::move(server)});
    mgb_assert(ins.second);
    return port;
}

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
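
For orientation, below is a minimal usage sketch that is not part of the file above: it assumes create_zmqrpc_server() and GroupClientProxy are visible via megbrain/opr/mm_handler.h as declared here (namespace qualifiers omitted), and the address, port, key, and single-rank group size are made-up values for illustration only. A server process publishes the GroupServerProxy over zmqrpc, and each worker constructs a GroupClientProxy against "ip:port" to register collective operators and exchange shapes.

// Illustration only: a hedged sketch of wiring the server and client proxy
// together in one process; real deployments put them in separate processes.
#include "megbrain/opr/mm_handler.h"
#include <cstdio>
#include <string>

int main() {
    // Start a zmqrpc server backed by GroupServerProxy; the port actually
    // bound by the server is returned (see create_zmqrpc_server above).
    // The loopback address and port 9000 are hypothetical.
    int port = create_zmqrpc_server("127.0.0.1", 9000);

    // A worker connects a client proxy to "ip:port" (the constructor prepends
    // "tcp://") and registers a collective operator with the group manager.
    // A single-rank group is used here so the call does not wait on peers.
    GroupClientProxy client("127.0.0.1:" + std::to_string(port));
    uint64_t hash = client.opr_register("allreduce:grad0", /*nr_devices=*/1,
                                        /*rank=*/0, /*stream=*/0);
    std::printf("registered with hash %llu\n",
                static_cast<unsigned long long>(hash));
    return 0;
}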

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on cloud GPU resources, you are welcome to visit the MegStudio platform.