You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

zmq_rpc.cpp 6.6 kB


  1. #include "zmq_rpc.h"
  2. #include "megbrain/exception.h"
  3. #include "megbrain_config.h"
  4. #if MGB_CUDA
  5. #include <unistd.h>
  6. #include <cassert>
  7. #include <cstdio>
  8. #include <iostream>
  9. #include <mutex>
  10. #include <queue>
  11. #include <string>
  12. #include <thread>
  13. #include <vector>
  14. #include <zmq.hpp>
  15. using namespace std;
  16. using namespace zmq;
  17. using namespace ZmqRpc;
  18. ZmqRpcWorker::ZmqRpcWorker(context_t* context, ZmqRpcServerImpl* impl)
  19. : m_ctx(context), m_runable(0), m_impl(impl) {}
  20. void ZmqRpcWorker::run() {
  21. add_worker();
  22. }
  23. void ZmqRpcWorker::close() {
  24. m_stop = true;
  25. for (auto& thread : m_worker_threads) {
  26. thread.join();
  27. }
  28. }
  29. void ZmqRpcWorker::work(string uid) {
  30. // req work pattern: send recv send recv ...
  31. zmq::socket_t socket(*m_ctx, ZMQ_REQ);
  32. socket.setsockopt(ZMQ_IDENTITY, uid.data(), uid.size());
  33. socket.connect("inproc://workers");
  34. // send READY to notify server that worker is ready
  35. zmq::message_t ready(6);
  36. memcpy(ready.data(), "READY", 6);
  37. socket.send(ready, send_flags::dontwait);
  38. while (!m_stop) {
  39. // Wait for next request from client
  40. // request should be like [address, empty, msg]
  41. message_t address;
  42. recv_result_t ret_code;
  43. while (!m_stop) {
  44. ret_code = socket.recv(address, recv_flags::dontwait);
  45. if (ret_code.has_value() && ret_code.value() > 0)
  46. break;
  47. // retry after 10 usec
  48. usleep(10);
  49. }
  50. if (m_stop)
  51. break;
  52. message_t empty;
  53. socket.recv(empty);
  54. assert(empty.size() == 0);
  55. message_t request;
  56. socket.recv(request);
  57. m_mtx.lock();
  58. if (--m_runable <= 0) {
  59. add_worker();
  60. }
  61. m_mtx.unlock();
  62. // Send reply back to client
  63. // reply should be like [address, empty, msg]
  64. zmq::message_t reply;
  65. m_impl->solve_request(request, reply);
  66. socket.send(address, send_flags::sndmore);
  67. socket.send(empty, send_flags::sndmore);
  68. socket.send(reply, send_flags::dontwait);
  69. m_mtx.lock();
  70. ++m_runable;
  71. m_mtx.unlock();
  72. }
  73. socket.close();
  74. }
  75. void ZmqRpcWorker::add_worker() {
  76. int size = m_worker_threads.size();
  77. m_worker_threads.emplace_back(
  78. [this, size] { this->work(to_string(size)); });
  79. ++m_runable;
  80. }
  81. ZmqRpcServer::ZmqRpcServer(string address, int port,
  82. unique_ptr<ZmqRpcServerImpl> impl)
  83. : m_ctx(1),
  84. m_impl(std::move(impl)),
  85. m_address(address),
  86. m_port(port),
  87. m_frontend(m_ctx, ZMQ_ROUTER),
  88. m_backend(m_ctx, ZMQ_ROUTER),
  89. m_workers(&m_ctx, m_impl.get()) {
  90. try {
  91. char full_addr[100];
  92. size_t size = sizeof(full_addr);
  93. sprintf(full_addr, "%s:%d", m_address.c_str(), m_port);
  94. m_frontend.bind(full_addr);
  95. m_frontend.getsockopt(ZMQ_LAST_ENDPOINT, &full_addr, &size);
  96. m_port = 0;
  97. int pow = 1, len = strlen(full_addr);
  98. for (int i = len - 1; i >= 0; i--) {
  99. if (full_addr[i] == ':') break;
  100. m_port += (full_addr[i] - '0') * pow;
  101. pow *= 10;
  102. }
  103. } catch(...) {
  104. m_port = -1;
  105. }
  106. m_backend.bind("inproc://workers");
  107. }
  108. void ZmqRpcServer::run() {
  109. if(m_port == -1) return;
  110. m_main_thread = make_unique<thread>([this] { this->work(); });
  111. }
  112. void ZmqRpcServer::close() {
  113. if(m_port == -1) return;
  114. m_stop = true;
  115. if (m_main_thread->joinable())
  116. m_main_thread->join();
  117. m_ctx.close();
  118. }
  119. void ZmqRpcServer::work() {
  120. m_workers.run();
  121. queue<string> worker_queue;
  122. while (!m_stop) {
  123. zmq_pollitem_t items[] = {{m_backend, 0, ZMQ_POLLIN, 0},
  124. {m_frontend, 0, ZMQ_POLLIN, 0}};
  125. int ret_code = zmq_poll(items, !worker_queue.empty() ? 2 : 1, 10);
  126. if (ret_code == -1)
  127. continue;
  128. if (items[0].revents & ZMQ_POLLIN) {
  129. message_t address;
  130. m_backend.recv(address);
  131. worker_queue.push({(char*)address.data(), address.size()});
  132. message_t empty;
  133. m_backend.recv(empty);
  134. assert(empty.size() == 0);
  135. // the third frame is READY or a client address
  136. message_t client_address;
  137. m_backend.recv(client_address);
  138. string tmp((char*)client_address.data(), client_address.size());
  139. if (strcmp(tmp.c_str(), "READY") != 0) {
  140. empty.rebuild();
  141. m_backend.recv(empty);
  142. assert(empty.size() == 0);
  143. message_t respones;
  144. m_backend.recv(respones);
  145. m_frontend.send(client_address, send_flags::sndmore);
  146. m_frontend.send(empty, send_flags::sndmore);
  147. m_frontend.send(respones, send_flags::dontwait);
  148. }
  149. }
  150. if (items[1].revents & ZMQ_POLLIN) {
  151. message_t address;
  152. m_frontend.recv(address);
  153. message_t empty;
  154. m_frontend.recv(empty);
  155. assert(empty.size() == 0);
  156. message_t request;
  157. m_frontend.recv(request);
  158. string worker_uid = worker_queue.front();
  159. worker_queue.pop();
  160. message_t uid(worker_uid.data(), worker_uid.length());
  161. m_backend.send(uid, send_flags::sndmore);
  162. m_backend.send(empty, send_flags::sndmore);
  163. m_backend.send(address, send_flags::sndmore);
  164. m_backend.send(empty, send_flags::sndmore);
  165. m_backend.send(request, send_flags::dontwait);
  166. }
  167. }
  168. m_workers.close();
  169. m_frontend.close();
  170. m_backend.close();
  171. }
  172. ZmqRpcClient::ZmqRpcClient(string address) : m_address(address), m_ctx(1) {}
  173. socket_t* ZmqRpcClient::new_socket() {
  174. m_own_sockets.emplace_back(make_unique<socket_t>(m_ctx, ZMQ_REQ));
  175. socket_t* ptr = m_own_sockets.back().get();
  176. ptr->connect(m_address);
  177. return ptr;
  178. }
  179. socket_t* ZmqRpcClient::get_socket() {
  180. unique_lock<mutex> lk{m_queue_mtx};
  181. if (m_avaliable_sockets.empty()) {
  182. return new_socket();
  183. }
  184. socket_t* ptr = m_avaliable_sockets.front();
  185. m_avaliable_sockets.pop();
  186. return ptr;
  187. }
  188. void ZmqRpcClient::add_socket(socket_t* socket) {
  189. unique_lock<mutex> lk{m_queue_mtx};
  190. m_avaliable_sockets.push(socket);
  191. }
  192. void ZmqRpcClient::request(message_t& request, message_t& reply) {
  193. socket_t* client = get_socket();
  194. client->send(request, send_flags::dontwait);
  195. client->recv(reply);
  196. add_socket(client);
  197. }
  198. #endif

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。