You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

zmq_rpc.h 2.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. #pragma once
  2. #include "megbrain_build_config.h"
  3. #include <unistd.h>
  4. #include <cassert>
  5. #include <iostream>
  6. #include <memory>
  7. #include <mutex>
  8. #include <queue>
  9. #include <string>
  10. #include <thread>
  11. #include <unordered_map>
  12. #include <vector>
  13. #include <zmq.hpp>
  14. namespace ZmqRpc {
  15. class ZmqRpcServerImpl {
  16. public:
  17. virtual void solve_request(zmq::message_t& request,
  18. zmq::message_t& reply) = 0;
  19. virtual ~ZmqRpcServerImpl() = default;
  20. };
  21. class ZmqRpcWorker {
  22. public:
  23. ZmqRpcWorker() = delete;
  24. ZmqRpcWorker(zmq::context_t* context, ZmqRpcServerImpl* impl);
  25. void run();
  26. void close();
  27. protected:
  28. void work(std::string uid);
  29. void add_worker();
  30. private:
  31. std::vector<std::thread> m_worker_threads;
  32. std::mutex m_mtx;
  33. zmq::context_t* m_ctx;
  34. int m_runable;
  35. ZmqRpcServerImpl* m_impl;
  36. bool m_stop = false;
  37. };
  38. class ZmqRpcServer {
  39. public:
  40. ZmqRpcServer() = delete;
  41. ZmqRpcServer(std::string address, int port,
  42. std::unique_ptr<ZmqRpcServerImpl> impl);
  43. ~ZmqRpcServer() { close(); }
  44. void run();
  45. void close();
  46. int port() { return m_port; }
  47. protected:
  48. void work();
  49. private:
  50. zmq::context_t m_ctx;
  51. std::unique_ptr<ZmqRpcServerImpl> m_impl;
  52. std::string m_address;
  53. int m_port;
  54. zmq::socket_t m_frontend, m_backend;
  55. ZmqRpcWorker m_workers;
  56. std::unique_ptr<std::thread> m_main_thread;
  57. bool m_stop = false;
  58. };
  59. class ZmqRpcClient {
  60. public:
  61. ZmqRpcClient() = delete;
  62. ZmqRpcClient(std::string address);
  63. void request(zmq::message_t& request, zmq::message_t& reply);
  64. static ZmqRpcClient* get_client(std::string addr) {
  65. static std::unordered_map<std::string, std::unique_ptr<ZmqRpcClient>>
  66. addr2handler;
  67. static std::mutex mtx;
  68. std::unique_lock<std::mutex> lk{mtx};
  69. auto it = addr2handler.emplace(addr, nullptr);
  70. if (!it.second) {
  71. assert(it.first->second->m_address == addr);
  72. return it.first->second.get();
  73. } else {
  74. auto handler = std::make_unique<ZmqRpcClient>(addr);
  75. auto handler_ptr = handler.get();
  76. it.first->second = std::move(handler);
  77. return handler_ptr;
  78. }
  79. }
  80. private:
  81. zmq::socket_t* new_socket();
  82. zmq::socket_t* get_socket();
  83. void add_socket(zmq::socket_t* socket);
  84. std::mutex m_queue_mtx;
  85. std::string m_address;
  86. zmq::context_t m_ctx;
  87. std::queue<zmq::socket_t*> m_avaliable_sockets;
  88. std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets;
  89. };
  90. } // namespace ZmqRpc

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。