#include "zmq_rpc.h"

#include "megbrain/exception.h"
#include "megbrain_config.h"

#if MGB_CUDA
#include <unistd.h>

#include <cassert>
#include <cstring>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <utility>
#include <vector>

using namespace std;
using namespace zmq;
using namespace ZmqRpc;

// ---------------------------------------------------------------------------
// ZmqRpcWorker: a growable pool of threads, each owning one REQ socket
// connected to the server's "inproc://workers" ROUTER backend.
// ---------------------------------------------------------------------------

ZmqRpcWorker::ZmqRpcWorker(context_t* context, ZmqRpcServerImpl* impl)
        : m_ctx(context), m_runable(0), m_impl(impl) {}

// Start the pool with a single worker thread. add_worker() touches state that
// worker threads also modify (under m_mtx), so take the same lock here.
void ZmqRpcWorker::run() {
    lock_guard<mutex> lock(m_mtx);
    add_worker();
}

// Stop all worker threads and wait for them to exit.
void ZmqRpcWorker::close() {
    {
        // Raise the stop flag while holding m_mtx. work() only spawns a new
        // worker while holding m_mtx and after re-checking m_stop, so once this
        // scope ends m_worker_threads can no longer grow and may be iterated
        // below. The lock is not held while joining, because workers acquire
        // m_mtx at the end of each request they serve.
        lock_guard<mutex> lock(m_mtx);
        m_stop = true;
    }
    for (auto& thread : m_worker_threads) {
        // joinable() guard keeps a repeated close() from throwing
        if (thread.joinable())
            thread.join();
    }
}

// Body of one worker thread. `uid` becomes the socket identity, which the
// server uses to address this worker on its backend ROUTER socket.
void ZmqRpcWorker::work(string uid) {
    // REQ sockets must strictly alternate send, recv, send, recv, ...
    zmq::socket_t socket(*m_ctx, ZMQ_REQ);
    socket.setsockopt(ZMQ_IDENTITY, uid.data(), uid.size());
    socket.connect("inproc://workers");

    // Tell the server this worker is ready; the 6 bytes are "READY" plus NUL.
    zmq::message_t ready(6);
    memcpy(ready.data(), "READY", 6);
    socket.send(ready, send_flags::dontwait);

    while (!m_stop) {
        // A request arrives as [client address, empty delimiter, payload].
        // Poll non-blockingly for the first frame so that m_stop is noticed.
        message_t address;
        recv_result_t ret_code;
        while (!m_stop) {
            ret_code = socket.recv(address, recv_flags::dontwait);
            if (ret_code.has_value() && ret_code.value() > 0)
                break;
            usleep(10);  // retry after 10 usec
        }
        if (m_stop)
            break;
        message_t empty;
        socket.recv(empty);
        assert(empty.size() == 0);
        message_t request;
        socket.recv(request);

        {
            // This worker is now busy. If no other worker is idle, grow the
            // pool so further requests are not stuck behind this one -- but
            // not once close() has raised m_stop, since close() may already
            // be iterating m_worker_threads to join them.
            lock_guard<mutex> lock(m_mtx);
            if (--m_runable <= 0 && !m_stop) {
                add_worker();
            }
        }

        // Reply with the same envelope: [client address, empty, reply].
        zmq::message_t reply;
        m_impl->solve_request(request, reply);
        socket.send(address, send_flags::sndmore);
        socket.send(empty, send_flags::sndmore);
        socket.send(reply, send_flags::dontwait);

        {
            lock_guard<mutex> lock(m_mtx);
            ++m_runable;
        }
    }
    socket.close();
}

// Spawn one more worker thread, identified by its index in m_worker_threads.
// Must be called with m_mtx held.
void ZmqRpcWorker::add_worker() {
    auto id = m_worker_threads.size();
    m_worker_threads.emplace_back([this, id] { this->work(to_string(id)); });
    ++m_runable;
}

// ---------------------------------------------------------------------------
// ZmqRpcServer: ROUTER frontend for clients, ROUTER backend for workers.
// ---------------------------------------------------------------------------

// Bind the client-facing frontend to "<address>:<port>" and record the port
// actually bound in m_port (read back via ZMQ_LAST_ENDPOINT). On any failure
// m_port is set to -1, which turns run() and close() into no-ops.
ZmqRpcServer::ZmqRpcServer(
        string address, int port, unique_ptr<ZmqRpcServerImpl> impl)
        : m_ctx(1),
          m_impl(std::move(impl)),
          m_address(std::move(address)),
          m_port(port),
          m_frontend(m_ctx, ZMQ_ROUTER),
          m_backend(m_ctx, ZMQ_ROUTER),
          m_workers(&m_ctx, m_impl.get()) {
    try {
        // Build the endpoint as a std::string: a fixed-size buffer would
        // overflow for long addresses.
        m_frontend.bind(m_address + ":" + to_string(m_port));

        // libzmq returns ZMQ_LAST_ENDPOINT as a NUL-terminated string.
        char endpoint[256];
        size_t size = sizeof(endpoint);
        m_frontend.getsockopt(ZMQ_LAST_ENDPOINT, endpoint, &size);

        // The port is the text after the last ':'; stoi throws if it is not
        // a number, which the catch below maps to -1.
        string bound(endpoint);
        auto colon = bound.rfind(':');
        m_port = colon == string::npos ? -1 : stoi(bound.substr(colon + 1));
    } catch (...) {
        m_port = -1;
    }
    m_backend.bind("inproc://workers");
}

// Start the broker loop on a background thread (no-op if binding failed).
void ZmqRpcServer::run() {
    if (m_port == -1)
        return;
    m_main_thread = make_unique<thread>([this] { this->work(); });
}

// Stop the broker, join its thread and terminate the context.
void ZmqRpcServer::close() {
    if (m_port == -1)
        return;
    m_stop = true;
    if (m_main_thread) {
        if (m_main_thread->joinable())
            m_main_thread->join();
    } else {
        // run() was never called, so work() never closed the sockets; close
        // them here so that terminating the context does not wait on them.
        m_frontend.close();
        m_backend.close();
    }
    m_ctx.close();
}

// Broker loop: workers announce themselves on the backend, and each client
// request from the frontend is routed to the longest-waiting idle worker;
// replies travel back the same way.
void ZmqRpcServer::work() {
    m_workers.run();
    // Identities of idle workers, in the order they became idle.
    queue<string> worker_queue;
    while (!m_stop) {
        zmq_pollitem_t items[] = {
                {m_backend, 0, ZMQ_POLLIN, 0}, {m_frontend, 0, ZMQ_POLLIN, 0}};
        // Only poll the frontend while some worker is idle; the 10 ms timeout
        // makes the loop re-check m_stop regularly.
        int ret_code = zmq_poll(items, !worker_queue.empty() ? 2 : 1, 10);
        if (ret_code == -1)
            continue;

        if (items[0].revents & ZMQ_POLLIN) {
            // Backend frames: [worker id, empty, "READY"] or
            // [worker id, empty, client address, empty, reply].
            // Either way the sending worker is idle again.
            message_t address;
            m_backend.recv(address);
            worker_queue.emplace(
                    static_cast<const char*>(address.data()), address.size());
            message_t empty;
            m_backend.recv(empty);
            assert(empty.size() == 0);
            // the third frame is READY or a client address
            message_t client_address;
            m_backend.recv(client_address);
            string tmp(
                    static_cast<const char*>(client_address.data()),
                    client_address.size());
            if (strcmp(tmp.c_str(), "READY") != 0) {
                // A reply: forward it as [client address, empty, reply].
                empty.rebuild();
                m_backend.recv(empty);
                assert(empty.size() == 0);
                message_t response;
                m_backend.recv(response);
                m_frontend.send(client_address, send_flags::sndmore);
                m_frontend.send(empty, send_flags::sndmore);
                m_frontend.send(response, send_flags::dontwait);
            }
        }

        if (items[1].revents & ZMQ_POLLIN) {
            // Client frames: [client address, empty, request]. Hand them to the
            // next idle worker as
            // [worker id, empty, client address, empty, request].
            message_t address;
            m_frontend.recv(address);
            message_t empty;
            m_frontend.recv(empty);
            assert(empty.size() == 0);
            message_t request;
            m_frontend.recv(request);
            string worker_uid = worker_queue.front();
            worker_queue.pop();
            message_t uid(worker_uid.data(), worker_uid.length());
            m_backend.send(uid, send_flags::sndmore);
            m_backend.send(empty, send_flags::sndmore);
            m_backend.send(address, send_flags::sndmore);
            m_backend.send(empty, send_flags::sndmore);
            m_backend.send(request, send_flags::dontwait);
        }
    }
    m_workers.close();
    m_frontend.close();
    m_backend.close();
}

// ---------------------------------------------------------------------------
// ZmqRpcClient: REQ sockets are pooled under m_queue_mtx so that concurrent
// request() calls each use a socket of their own.
// ---------------------------------------------------------------------------

ZmqRpcClient::ZmqRpcClient(string address) : m_address(address), m_ctx(1) {}

// Create a new REQ socket connected to m_address; it stays owned by
// m_own_sockets. Called with m_queue_mtx held (from get_socket()).
socket_t* ZmqRpcClient::new_socket() {
    m_own_sockets.emplace_back(make_unique<socket_t>(m_ctx, ZMQ_REQ));
    socket_t* ptr = m_own_sockets.back().get();
    ptr->connect(m_address);
    return ptr;
}

// Take an idle socket from the pool, or create one if none is available.
socket_t* ZmqRpcClient::get_socket() {
    unique_lock<mutex> lk{m_queue_mtx};
    if (m_avaliable_sockets.empty()) {
        return new_socket();
    }
    socket_t* ptr = m_avaliable_sockets.front();
    m_avaliable_sockets.pop();
    return ptr;
}

// Return a socket to the idle pool.
void ZmqRpcClient::add_socket(socket_t* socket) {
    unique_lock<mutex> lk{m_queue_mtx};
    m_avaliable_sockets.push(socket);
}

// Send `request` and block until the reply is received into `reply`.
// If send/recv throws, the socket is not returned to the pool.
void ZmqRpcClient::request(message_t& request, message_t& reply) {
    socket_t* client = get_socket();
    // NOTE(review): the result of this non-blocking send is not checked; if it
    // returned without sending (EAGAIN), the recv below would be issued in the
    // wrong REQ state -- confirm whether that can happen here.
    client->send(request, send_flags::dontwait);
    client->recv(reply);
    add_socket(client);
}

#endif