From cf3a55ce1750822e658c66c9cdb5a17079865f79 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Mon, 6 Jul 2020 18:25:32 +0800
Subject: [PATCH] fix(mgb/opr-mm): remove PeerDesc from RemoteSend and
 RemoteRecv

GitOrigin-RevId: b7a7bbd0dad4ab27d9c51c59c8011e518e79e097
---
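PeerDesc bundled three fields (key, type, is_grad), of which `type` was
redundant: SEND vs. RECV is already fixed by the operator class itself.
Call sites therefore change from the brace-initializer form to plain
arguments, as in the opr_defs.cpp hunk below:

    // before
    RemoteSend::make({key, RemoteIOBase::Type::SEND, is_grad}, var,
                     group_client, config);
    // after
    RemoteSend::make(key, var, group_client, is_grad, config);

`is_grad` survives as an explicit RemoteSend parameter (exposed through the
new RemoteSend::is_grad() accessor), since the gradient rule and static
shape inference still depend on it; RemoteRecv needs no such flag.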
 python_module/src/cpp/opr_defs.cpp          |  7 ++--
 src/opr-mm/impl/io_remote.cpp               | 50 ++++++++++++++---------------
 src/opr-mm/include/megbrain/opr/io_remote.h | 37 ++++++++-------------
 src/opr-mm/test/io_remote.cpp               | 35 +++++++++----------
 4 files changed, 58 insertions(+), 71 deletions(-)

diff --git a/python_module/src/cpp/opr_defs.cpp b/python_module/src/cpp/opr_defs.cpp
index d17ef972..db1e8feb 100644
--- a/python_module/src/cpp/opr_defs.cpp
+++ b/python_module/src/cpp/opr_defs.cpp
@@ -72,10 +72,10 @@ SymbolVar _Opr::remote_send(
         const std::string& key, SymbolVar var, const bool is_grad,
         const OperatorNodeConfig& config) {
-    return RemoteSend::make({key, RemoteIOBase::Type::SEND, is_grad}, var,
+    return RemoteSend::make(key, var,
                             std::make_shared<GroupClientProxy>(ssprintf(
                                     "%s:%d", server_addr.c_str(), port)),
-                            config);
+                            is_grad, config);
 }
 
 SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
@@ -85,8 +85,7 @@
     const TensorShape ishape = npy::vec2shape(shape);
     const DType idtype = npy::dtype_np2mgb(dtype);
 
-    return RemoteRecv::make({key, RemoteIOBase::Type::RECV, false},
-                            graph.get(),
+    return RemoteRecv::make(key, graph.get(),
                             std::make_shared<GroupClientProxy>(
                                     ssprintf("%s:%d", server_addr.c_str(), port)),
                             config, ishape, idtype);
diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp
index a4bb714e..938f816a 100644
--- a/src/opr-mm/impl/io_remote.cpp
+++ b/src/opr-mm/impl/io_remote.cpp
@@ -26,27 +26,28 @@ cudaStream_t get_stream(VarNode* var) {
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
 
-RemoteSend::RemoteSend(const PeerDesc& peer, VarNode* var,
+RemoteSend::RemoteSend(const std::string& key, VarNode* var,
                        std::shared_ptr<GroupClient> group_client,
-                       const OperatorNodeConfig& config) :
-        Super(var->owner_graph(), config, "remote_send", {var}) {
-    m_peer = peer;
+                       bool is_grad, const OperatorNodeConfig& config) :
+        Super(var->owner_graph(), config, "remote_send", {var}),
+        m_is_grad(is_grad) {
+    m_key = key;
     m_group_client = group_client;
 
     add_input({var});
     auto ovar = add_output(None);
-    if (!peer.is_grad) {
+    if (!m_is_grad) {
         ovar->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                 .add_flag(VarNode::Flag::VOLATILE_CONTENT);
     }
     add_equivalence_component<ScalarHash<void*>>(this);
 }
 
-SymbolVar RemoteSend::make(const PeerDesc& peer, SymbolVar var,
+SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
                            std::shared_ptr<GroupClient> group_client,
-                           const OperatorNodeConfig& config) {
-    return var.insert_single_output_opr<RemoteSend>(peer, var.node(),
-                                                    group_client, config);
+                           bool is_grad, const OperatorNodeConfig& config) {
+    return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client,
+                                                    is_grad, config);
 }
 
 void RemoteSend::scn_do_execute() {
@@ -54,11 +55,11 @@ void RemoteSend::scn_do_execute() {
     auto&& comp_node = output(0)->comp_node();
 
     // rank 0 for RemoteSend
-    auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false,
+    auto reg_info = m_group_client->opr_register(m_key, 2, 0, false,
                                                  comp_node.get_uid());
 
     m_megray_comm = MegRayCommBuilder::get_megray_comm(
-            reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
+            reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
 
     m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
 
@@ -76,7 +77,7 @@ void RemoteSend::scn_do_execute() {
     auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1,
                                       m_megray_ctx);
     mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed");
-    if (m_peer.is_grad) {
+    if (m_is_grad) {
         auto&& dest = output(0)->dev_tensor();
         if (m_output_val.empty()) {
             m_output_val.comp_node(dest.comp_node())
@@ -92,7 +93,7 @@ void RemoteSend::init_output_static_infer_desc() {
     using namespace cg::static_infer;
     auto&& mgr = owner_graph()->static_infer_manager();
     auto do_infer = [this](TensorShape& dest, const InpVal&) {
-        if (peer_desc().is_grad) {
+        if (m_is_grad) {
             dest = {1};
         } else {
             dest = {0};
@@ -109,9 +110,8 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const {
 }
 
 MGB_IMPL_OPR_GRAD(RemoteSend) {
-    mgb_assert(opr.peer_desc().is_grad);
-    return RemoteRecv::make({opr.peer_desc().key + ":grad",
-                             RemoteIOBase::Type::RECV, false},
+    mgb_assert(opr.is_grad());
+    return RemoteRecv::make(opr.key() + ":grad",
                             *opr.owner_graph(), opr.group_client(),
                             OperatorNodeConfig{opr.comp_node()}.name(
                                     opr.name() + ":grad_recv"),
@@ -123,13 +123,13 @@ MGB_IMPL_OPR_GRAD(RemoteSend) {
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);
 
-RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph,
+RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                        std::shared_ptr<GroupClient> group_client,
                        const OperatorNodeConfig& config,
                        const TensorShape& shape, DType dtype) :
         Super(&graph, config, "remote_recv", {}),
         m_shape(shape),
         m_dtype(dtype) {
-    m_peer = peer;
+    m_key = key;
     m_group_client = group_client;
 
     add_output(None)
@@ -139,12 +139,12 @@ RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph,
     add_equivalence_component<ScalarHash<void*>>(this);
 }
 
-SymbolVar RemoteRecv::make(const PeerDesc& peer, cg::ComputingGraph& graph,
+SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
                            std::shared_ptr<GroupClient> group_client,
                            const OperatorNodeConfig& config,
                            const TensorShape& shape, DType dtype) {
     auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
-            peer, graph, group_client, config, shape, dtype));
+            key, graph, group_client, config, shape, dtype));
     return opr->output(0);
 }
 
@@ -154,11 +154,11 @@ void RemoteRecv::scn_do_execute() {
 
     // rank 1 for RemoteRecv
     auto reg_info = m_group_client->opr_register(
-            m_peer.key, 2, false, 1,
+            m_key, 2, false, 1,
            comp_node.get_uid());
 
     m_megray_comm = MegRayCommBuilder::get_megray_comm(
-            reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
+            reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
 
     m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
 
@@ -206,8 +206,8 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_send(
         const OperatorNodeConfig& config) {
     mgb_assert(inputs.size() == 1);
     auto&& opr = opr_.cast_final_safe<RemoteSend>();
-    return RemoteSend::make(opr.peer_desc(), inputs[0], opr.group_client(),
-                            config)
+    return RemoteSend::make(opr.key(), inputs[0], opr.group_client(),
+                            opr.is_grad(), config)
             .node()
             ->owner_opr();
 }
@@ -218,7 +218,7 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
         const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
         const OperatorNodeConfig& config) {
     auto&& opr = opr_.cast_final_safe<RemoteRecv>();
-    return RemoteRecv::make(opr.peer_desc(), *opr.owner_graph(),
+    return RemoteRecv::make(opr.key(), *opr.owner_graph(),
                             opr.group_client(), config, inputs[0]->shape(),
                             inputs[0]->dtype())
             .node()
diff --git a/src/opr-mm/include/megbrain/opr/io_remote.h b/src/opr-mm/include/megbrain/opr/io_remote.h
index 2788e44e..b365e476 100644
--- a/src/opr-mm/include/megbrain/opr/io_remote.h
+++ b/src/opr-mm/include/megbrain/opr/io_remote.h
@@ -25,25 +25,14 @@ namespace opr {
  */
 MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // {
 public:
-    enum Type {
-        SEND,
-        RECV
-    };
-
-    struct PeerDesc {
-        std::string key;
-        Type type;
-        bool is_grad;
-    };
-
-    const PeerDesc& peer_desc() const { return m_peer; }
+    const std::string& key() const { return m_key; }
 
     std::shared_ptr<GroupClient> group_client() const {
         return m_group_client;
     }
 
 protected:
-    PeerDesc m_peer;
+    std::string m_key;
     std::shared_ptr<GroupClient> m_group_client;
     std::shared_ptr<MegRay::Communicator> m_megray_comm;
     std::shared_ptr<MegRay::Context> m_megray_ctx;
@@ -53,21 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // {
 
 /*!
  * \brief send a variable to remote address; a virtual output is produced
- *      for expressing dependency
+ * for expressing dependency
  */
 MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
 public:
-    RemoteSend(const PeerDesc& peer, VarNode* var,
+    RemoteSend(const std::string& key, VarNode* var,
                std::shared_ptr<GroupClient> group_client,
-               const OperatorNodeConfig& config);
+               bool is_grad, const OperatorNodeConfig& config);
 
     static SymbolVar make(
-            const PeerDesc& peer, SymbolVar var,
+            const std::string& key, SymbolVar var,
             std::shared_ptr<GroupClient> group_client,
-            const OperatorNodeConfig& config = {});
+            bool is_grad, const OperatorNodeConfig& config = {});
+
+    bool is_grad() const { return m_is_grad; }
 
 private:
     HostTensorND m_output_val;
+    bool m_is_grad;
 
     void scn_do_execute() override;
     void init_output_static_infer_desc() override;
@@ -75,19 +67,18 @@ MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
 };
 
 /*!
- * \brief receive from multiple remote addresses and write to a var
- *
- * Target computing node of the var must be specified in config
+ * \brief receive a variable from remote address; target computing node
+ * of the var must be specified in config
  */
 MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
 public:
-    RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph,
+    RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                std::shared_ptr<GroupClient> group_client,
                const OperatorNodeConfig& config,
                const TensorShape& shape, DType dtype);
 
     static SymbolVar make(
-            const PeerDesc& peer, cg::ComputingGraph& graph,
+            const std::string& key, cg::ComputingGraph& graph,
             std::shared_ptr<GroupClient> group_client,
             const OperatorNodeConfig& config,
             const TensorShape& shape, DType dtype);
diff --git a/src/opr-mm/test/io_remote.cpp b/src/opr-mm/test/io_remote.cpp
index f4302dd2..a68ffd9a 100644
--- a/src/opr-mm/test/io_remote.cpp
+++ b/src/opr-mm/test/io_remote.cpp
@@ -20,9 +20,6 @@
 
 using namespace mgb;
 
-const auto send_tag = opr::RemoteIOBase::Type::SEND;
-const auto recv_tag = opr::RemoteIOBase::Type::RECV;
-
 TEST(TestOprIORemote, Identity) {
     REQUIRE_GPU(2);
     auto cn0 = CompNode::load("gpu0");
@@ -36,8 +33,8 @@
 
     auto graph = ComputingGraph::make();
     auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
-    auto xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
-    auto y = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
+    auto xr = opr::RemoteSend::make("x", x, client, false);
+    auto y = opr::RemoteRecv::make("x", *graph.get(),
                                    client, {cn1}, host_x->shape(),
                                    host_x->dtype());
 
@@ -59,7 +56,7 @@
         auto graph = ComputingGraph::make();
         sys::set_thread_name("sender");
         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
-             xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
+             xr = opr::RemoteSend::make("x", x, client, false);
         auto func = graph->compile({{xr, {}}});
         func->execute();
     };
@@ -67,7 +64,7 @@
     auto receiver = [&]() {
         sys::set_thread_name("receiver");
         auto graph = ComputingGraph::make();
-        auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
+        auto x = opr::RemoteRecv::make("x", *graph.get(),
                                        client, {cns[0]}, host_x->shape(),
                                        host_x->dtype());
         auto func = graph->compile({make_callback_copy(x, host_x_get)});
@@ -92,7 +89,7 @@
         sys::set_thread_name("sender");
         auto graph = ComputingGraph::make();
         auto x = opr::Host2DeviceCopy::make(*graph, host_x) * 2 + 1,
-             xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
+             xr = opr::RemoteSend::make("x", x, client, false);
         auto func = graph->compile({{xr, {}}});
         func->execute();
     };
@@ -100,7 +97,7 @@
     auto receiver = [&]() {
         sys::set_thread_name("receiver");
         auto graph = ComputingGraph::make();
-        auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
+        auto x = opr::RemoteRecv::make("x", *graph.get(),
                                        client, {cns[0]}, host_x->shape(),
                                        host_x->dtype());
         auto func =
@@ -124,14 +121,14 @@
 
     auto sender = [&]() {
         auto graph = ComputingGraph::make();
-        auto z = opr::RemoteRecv::make({"z", recv_tag, false}, *graph.get(),
+        auto z = opr::RemoteRecv::make("z", *graph.get(),
                                        client, {cns[0]}, host_x->shape(),
                                        host_x->dtype());
         auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"),
              y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"),
-             xr = opr::RemoteSend::make({"x", send_tag, false}, x, client)
+             xr = opr::RemoteSend::make("x", x, client, false)
                           .rename("xr"),
-             yr = opr::RemoteSend::make({"y", send_tag, false}, y, client)
+             yr = opr::RemoteSend::make("y", y, client, false)
                           .rename("yr");
         auto func = graph->compile(
                 {{xr, {}}, {yr, {}}, make_callback_copy(z, host_z)});
@@ -142,14 +139,14 @@
 
     auto receiver = [&]() {
         auto graph = ComputingGraph::make();
-        auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
+        auto x = opr::RemoteRecv::make("x", *graph.get(),
                                        client, {cns[1]}, host_x->shape(),
                                        host_x->dtype()),
-             y = opr::RemoteRecv::make({"y", recv_tag, false}, *graph.get(),
+             y = opr::RemoteRecv::make("y", *graph.get(),
                                        client, {cns[1]}, host_y->shape(),
                                        host_y->dtype()),
              z = x + y,
-             zr = opr::RemoteSend::make({"z", send_tag, false}, z, client);
+             zr = opr::RemoteSend::make("z", z, client, false);
         auto func = graph->compile({{zr, {}}});
         func->execute();
     };
@@ -177,10 +174,10 @@
         sys::set_thread_name("sender");
         auto graph = ComputingGraph::make();
         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
-             loss = opr::RemoteSend::make({"loss", send_tag, false}, x, client);
+             loss = opr::RemoteSend::make("loss", x, client, false);
         ASSERT_TRUE(!loss.shape().ndim &&
                     loss.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
-        loss = opr::RemoteSend::make({"loss", send_tag, true}, x, client);
+        loss = opr::RemoteSend::make("loss", x, client, true);
         auto gx = cg::grad(loss, x);
         set_priority(loss, 0);
         set_priority(gx, 1);
@@ -197,10 +194,10 @@
     auto receiver = [&]() {
         sys::set_thread_name("receiver");
         auto graph = ComputingGraph::make();
-        auto x = opr::RemoteRecv::make({"loss", recv_tag, false}, *graph.get(),
+        auto x = opr::RemoteRecv::make("loss", *graph.get(),
                                        client, {cns[1]}, host_x->shape(),
                                        host_x->dtype());
-        auto y = opr::RemoteSend::make({"loss:grad", send_tag, false}, x + 1, client);
+        auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false);
         auto func = graph->compile({{y, {}}});
         func->execute();
     };
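--
Usage sketch under the new signatures, adapted from TestOprIORemote.SendGrad
above (`client`, `x`, `graph`, `cns`, `shape` and `dtype` are as in that
test; this is illustrative, not an additional test):

    // sender graph: is_grad=true gives the send a real scalar output and
    // makes it differentiable; cg::grad() then wires in a RemoteRecv on
    // key "loss:grad" via MGB_IMPL_OPR_GRAD(RemoteSend)
    auto loss = opr::RemoteSend::make("loss", x, client, /* is_grad = */ true);
    auto gx = cg::grad(loss, x);

    // receiver graph: consume "loss" and send the gradient back on the
    // matching "loss:grad" key
    auto x = opr::RemoteRecv::make("loss", *graph.get(), client, {cns[1]},
                                   shape, dtype);
    auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false);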