
fix(mgb/opr-mm): remove PeerDesc from RemoteSend and RemoteRecv

GitOrigin-RevId: b7a7bbd0da
Branch: release-0.6
Megvii Engine Team · 5 years ago
commit cf3a55ce17
4 changed files with 58 additions and 71 deletions:
  1. python_module/src/cpp/opr_defs.cpp (+3, -4)
  2. src/opr-mm/impl/io_remote.cpp (+25, -25)
  3. src/opr-mm/include/megbrain/opr/io_remote.h (+14, -23)
  4. src/opr-mm/test/io_remote.cpp (+16, -19)
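
In short, the PeerDesc aggregate (key + SEND/RECV tag + is_grad flag) is gone: operators are now identified by the bare string key, and RemoteSend takes is_grad as an explicit parameter. A before/after sketch of the factory calls, distilled from the hunks below (var, graph, group_client, comp_node, shape and dtype are placeholders, not code from this patch):

    // Before this commit:
    //   RemoteSend::make({key, RemoteIOBase::Type::SEND, is_grad}, var, group_client, config);
    //   RemoteRecv::make({key, RemoteIOBase::Type::RECV, false}, graph, group_client, config, shape, dtype);
    // After this commit: plain key, explicit is_grad, no PeerDesc / Type enum.
    auto send = opr::RemoteSend::make(key, var, group_client, is_grad, config);
    auto recv = opr::RemoteRecv::make(key, graph, group_client,
                                      OperatorNodeConfig{comp_node}, shape, dtype);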

python_module/src/cpp/opr_defs.cpp (+3, -4)

@@ -72,10 +72,10 @@ SymbolVar _Opr::remote_send(
         const std::string& key, SymbolVar var,
         const bool is_grad,
         const OperatorNodeConfig& config) {
-    return RemoteSend::make({key, RemoteIOBase::Type::SEND, is_grad}, var,
+    return RemoteSend::make(key, var,
             std::make_shared<GroupClientProxy>(ssprintf(
                     "%s:%d", server_addr.c_str(), port)),
-            config);
+            is_grad, config);
 }


@@ -85,8 +85,7 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
     const TensorShape ishape = npy::vec2shape(shape);
     const DType idtype = npy::dtype_np2mgb(dtype);
 
-    return RemoteRecv::make({key, RemoteIOBase::Type::RECV, false},
-                            graph.get(),
+    return RemoteRecv::make(key, graph.get(),
                             std::make_shared<GroupClientProxy>(
                                     ssprintf("%s:%d", server_addr.c_str(), port)),
                             config, ishape, idtype);


src/opr-mm/impl/io_remote.cpp (+25, -25)

@@ -26,27 +26,28 @@ cudaStream_t get_stream(VarNode* var) {
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
 
-RemoteSend::RemoteSend(const PeerDesc& peer, VarNode* var,
+RemoteSend::RemoteSend(const std::string& key, VarNode* var,
                        std::shared_ptr<GroupClient> group_client,
-                       const OperatorNodeConfig& config) :
-        Super(var->owner_graph(), config, "remote_send", {var}) {
-    m_peer = peer;
+                       bool is_grad, const OperatorNodeConfig& config) :
+        Super(var->owner_graph(), config, "remote_send", {var}),
+        m_is_grad(is_grad) {
+    m_key = key;
     m_group_client = group_client;
 
     add_input({var});
     auto ovar = add_output(None);
-    if (!peer.is_grad) {
+    if (!m_is_grad) {
         ovar->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                 .add_flag(VarNode::Flag::VOLATILE_CONTENT);
     }
     add_equivalence_component<ScalarHash<void*>>(this);
 }
 
-SymbolVar RemoteSend::make(const PeerDesc& peer, SymbolVar var,
+SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
                            std::shared_ptr<GroupClient> group_client,
-                           const OperatorNodeConfig& config) {
-    return var.insert_single_output_opr<RemoteSend>(peer, var.node(),
-                                                    group_client, config);
+                           bool is_grad, const OperatorNodeConfig& config) {
+    return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client,
+                                                    is_grad, config);
 }
 
 void RemoteSend::scn_do_execute() {
@@ -54,11 +55,11 @@ void RemoteSend::scn_do_execute() {
     auto&& comp_node = output(0)->comp_node();
 
     // rank 0 for RemoteSend
-    auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false,
+    auto reg_info = m_group_client->opr_register(m_key, 2, 0, false,
                                                  comp_node.get_uid());
 
     m_megray_comm = MegRayCommBuilder::get_megray_comm(
-            reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
+            reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
 
     m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));


@@ -76,7 +77,7 @@
     auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx);
     mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed");
 
-    if (m_peer.is_grad) {
+    if (m_is_grad) {
         auto&& dest = output(0)->dev_tensor();
         if (m_output_val.empty()) {
             m_output_val.comp_node(dest.comp_node())
@@ -92,7 +93,7 @@ void RemoteSend::init_output_static_infer_desc() {
     using namespace cg::static_infer;
     auto&& mgr = owner_graph()->static_infer_manager();
     auto do_infer = [this](TensorShape& dest, const InpVal&) {
-        if (peer_desc().is_grad) {
+        if (m_is_grad) {
            dest = {1};
        } else {
            dest = {0};
@@ -109,9 +110,8 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const {
 }
 
 MGB_IMPL_OPR_GRAD(RemoteSend) {
-    mgb_assert(opr.peer_desc().is_grad);
-    return RemoteRecv::make({opr.peer_desc().key + ":grad",
-                             RemoteIOBase::Type::RECV, false},
+    mgb_assert(opr.is_grad());
+    return RemoteRecv::make(opr.key() + ":grad",
                             *opr.owner_graph(), opr.group_client(),
                             OperatorNodeConfig{opr.comp_node()}.name(
                                     opr.name() + ":grad_recv"),
@@ -123,13 +123,13 @@ MGB_IMPL_OPR_GRAD(RemoteSend) {
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);
 
-RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph,
+RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                        std::shared_ptr<GroupClient> group_client,
                        const OperatorNodeConfig& config,
                        const TensorShape& shape, DType dtype) :
         Super(&graph, config, "remote_recv", {}),
         m_shape(shape), m_dtype(dtype) {
-    m_peer = peer;
+    m_key = key;
     m_group_client = group_client;
 
     add_output(None)
@@ -139,12 +139,12 @@ RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph,
     add_equivalence_component<ScalarHash<void*>>(this);
 }
 
-SymbolVar RemoteRecv::make(const PeerDesc& peer, cg::ComputingGraph& graph,
+SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
                            std::shared_ptr<GroupClient> group_client,
                            const OperatorNodeConfig& config,
                            const TensorShape& shape, DType dtype) {
     auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
-            peer, graph, group_client, config, shape, dtype));
+            key, graph, group_client, config, shape, dtype));
     return opr->output(0);
 }


@@ -154,11 +154,11 @@ void RemoteRecv::scn_do_execute() {
 
     // rank 1 for RemoteRecv
     auto reg_info = m_group_client->opr_register(
-            m_peer.key, 2, false, 1,
+            m_key, 2, false, 1,
             comp_node.get_uid());
 
     m_megray_comm = MegRayCommBuilder::get_megray_comm(
-            reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
+            reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
 
     m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));


@@ -206,8 +206,8 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_send(
         const OperatorNodeConfig& config) {
     mgb_assert(inputs.size() == 1);
     auto&& opr = opr_.cast_final_safe<RemoteSend>();
-    return RemoteSend::make(opr.peer_desc(), inputs[0], opr.group_client(),
-                            config)
+    return RemoteSend::make(opr.key(), inputs[0], opr.group_client(),
+                            opr.is_grad(), config)
             .node()
             ->owner_opr();
 }
@@ -218,7 +218,7 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
         const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
         const OperatorNodeConfig& config) {
     auto&& opr = opr_.cast_final_safe<RemoteRecv>();
-    return RemoteRecv::make(opr.peer_desc(), *opr.owner_graph(),
+    return RemoteRecv::make(opr.key(), *opr.owner_graph(),
                             opr.group_client(), config, inputs[0]->shape(),
                             inputs[0]->dtype())
             .node()
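
One behavioural detail worth noting, illustrated rather than introduced here (adapted from the MGB_IMPL_OPR_GRAD hunk above and the SendGrad test below; x, x_recv and client are placeholders): taking the gradient of a RemoteSend built with is_grad=true inserts a RemoteRecv registered under the derived key key() + ":grad", so the peer process must answer on that key.

    // Sender process: grad of RemoteSend("loss", ..., is_grad=true) becomes a
    // RemoteRecv under "loss:grad" (built internally by MGB_IMPL_OPR_GRAD).
    auto loss = opr::RemoteSend::make("loss", x, client, /*is_grad=*/true);
    auto gx = cg::grad(loss, x);

    // Peer process: reply on the matching derived key.
    auto y = opr::RemoteSend::make("loss:grad", x_recv + 1, client, /*is_grad=*/false);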


src/opr-mm/include/megbrain/opr/io_remote.h (+14, -23)

@@ -25,25 +25,14 @@ namespace opr {
  */
 MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // {
     public:
-        enum Type {
-            SEND,
-            RECV
-        };
-
-        struct PeerDesc {
-            std::string key;
-            Type type;
-            bool is_grad;
-        };
-
-        const PeerDesc& peer_desc() const { return m_peer; }
+        const std::string& key() const { return m_key; }
 
         std::shared_ptr<GroupClient> group_client() const {
             return m_group_client;
         }
 
     protected:
-        PeerDesc m_peer;
+        std::string m_key;
         std::shared_ptr<GroupClient> m_group_client;
         std::shared_ptr<MegRay::Communicator> m_megray_comm;
         std::shared_ptr<MegRay::Context> m_megray_ctx;
@@ -53,21 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // {
 
 /*!
  * \brief send a variable to remote address; a virtual output is produced
- * for expressing dependency
+ * for expressing dependency
  */
 MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
     public:
-        RemoteSend(const PeerDesc& peer, VarNode* var,
+        RemoteSend(const std::string& key, VarNode* var,
                    std::shared_ptr<GroupClient> group_client,
-                   const OperatorNodeConfig& config);
+                   bool is_grad, const OperatorNodeConfig& config);
 
         static SymbolVar make(
-                const PeerDesc& peer, SymbolVar var,
+                const std::string& key, SymbolVar var,
                 std::shared_ptr<GroupClient> group_client,
-                const OperatorNodeConfig& config = {});
+                bool is_grad, const OperatorNodeConfig& config = {});
+
+        bool is_grad() const { return m_is_grad; }
 
     private:
         HostTensorND m_output_val;
+        bool m_is_grad;
 
         void scn_do_execute() override;
         void init_output_static_infer_desc() override;
@@ -75,19 +67,18 @@ MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
 };
 
 /*!
- * \brief receive from multiple remote addresses and write to a var
- *
- * Target computing node of the var must be specified in config
+ * \brief receive a variable from remote address; target computing node
+ * of the var must be specified in config
  */
 MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
     public:
-        RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph,
+        RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                    std::shared_ptr<GroupClient> group_client,
                    const OperatorNodeConfig& config, const TensorShape& shape,
                    DType dtype);
 
         static SymbolVar make(
-                const PeerDesc& peer, cg::ComputingGraph& graph,
+                const std::string& key, cg::ComputingGraph& graph,
                 std::shared_ptr<GroupClient> group_client,
                 const OperatorNodeConfig& config, const TensorShape& shape,
                 DType dtype);
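
For downstream code, the migration is mechanical: opr.peer_desc().key becomes opr.key(), opr.peer_desc().is_grad becomes opr.is_grad(), and the SEND/RECV enum disappears because the concrete operator type already encodes the direction. A hypothetical helper using the new accessors declared above (log_remote_send is illustration only, not part of the patch):

    #include <cstdio>

    void log_remote_send(const opr::RemoteSend& opr) {
        // key() replaces peer_desc().key; is_grad() replaces peer_desc().is_grad.
        std::printf("RemoteSend key=%s is_grad=%d\n",
                    opr.key().c_str(), static_cast<int>(opr.is_grad()));
    }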


src/opr-mm/test/io_remote.cpp (+16, -19)

@@ -20,9 +20,6 @@
 
 using namespace mgb;
 
-const auto send_tag = opr::RemoteIOBase::Type::SEND;
-const auto recv_tag = opr::RemoteIOBase::Type::RECV;
-
 TEST(TestOprIORemote, Identity) {
     REQUIRE_GPU(2);
     auto cn0 = CompNode::load("gpu0");
@@ -36,8 +33,8 @@ TEST(TestOprIORemote, Identity) {
     auto graph = ComputingGraph::make();
 
     auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
-    auto xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
-    auto y = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
+    auto xr = opr::RemoteSend::make("x", x, client, false);
+    auto y = opr::RemoteRecv::make("x", *graph.get(),
                                    client, {cn1}, host_x->shape(),
                                    host_x->dtype());


@@ -59,7 +56,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
         auto graph = ComputingGraph::make();
         sys::set_thread_name("sender");
         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
-             xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
+             xr = opr::RemoteSend::make("x", x, client, false);
         auto func = graph->compile({{xr, {}}});
         func->execute();
     };
@@ -67,7 +64,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
     auto receiver = [&]() {
         sys::set_thread_name("receiver");
         auto graph = ComputingGraph::make();
-        auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
+        auto x = opr::RemoteRecv::make("x", *graph.get(),
                                        client, {cns[0]}, host_x->shape(),
                                        host_x->dtype());
         auto func = graph->compile({make_callback_copy(x, host_x_get)});
@@ -92,7 +89,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
         sys::set_thread_name("sender");
         auto graph = ComputingGraph::make();
         auto x = opr::Host2DeviceCopy::make(*graph, host_x) * 2 + 1,
-             xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
+             xr = opr::RemoteSend::make("x", x, client, false);
         auto func = graph->compile({{xr, {}}});
         func->execute();
     };
@@ -100,7 +97,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
     auto receiver = [&]() {
         sys::set_thread_name("receiver");
         auto graph = ComputingGraph::make();
-        auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
+        auto x = opr::RemoteRecv::make("x", *graph.get(),
                                        client, {cns[0]}, host_x->shape(),
                                        host_x->dtype());
         auto func =
@@ -124,14 +121,14 @@ TEST(TestOprIORemote, APlusB) {
 
     auto sender = [&]() {
         auto graph = ComputingGraph::make();
-        auto z = opr::RemoteRecv::make({"z", recv_tag, false}, *graph.get(),
+        auto z = opr::RemoteRecv::make("z", *graph.get(),
                                        client, {cns[0]}, host_x->shape(),
                                        host_x->dtype());
         auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"),
             y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"),
-            xr = opr::RemoteSend::make({"x", send_tag, false}, x, client)
+            xr = opr::RemoteSend::make("x", x, client, false)
                          .rename("xr"),
-            yr = opr::RemoteSend::make({"y", send_tag, false}, y, client)
+            yr = opr::RemoteSend::make("y", y, client, false)
                          .rename("yr");
         auto func = graph->compile(
                 {{xr, {}}, {yr, {}}, make_callback_copy(z, host_z)});
@@ -142,14 +139,14 @@ TEST(TestOprIORemote, APlusB) {
 
     auto receiver = [&]() {
         auto graph = ComputingGraph::make();
-        auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
+        auto x = opr::RemoteRecv::make("x", *graph.get(),
                                        client, {cns[1]}, host_x->shape(),
                                        host_x->dtype()),
-             y = opr::RemoteRecv::make({"y", recv_tag, false}, *graph.get(),
+             y = opr::RemoteRecv::make("y", *graph.get(),
                                        client, {cns[1]}, host_y->shape(),
                                        host_y->dtype()),
             z = x + y,
-            zr = opr::RemoteSend::make({"z", send_tag, false}, z, client);
+            zr = opr::RemoteSend::make("z", z, client, false);
         auto func = graph->compile({{zr, {}}});
         func->execute();
     };
@@ -177,10 +174,10 @@ TEST(TestOprIORemote, SendGrad) {
         sys::set_thread_name("sender");
         auto graph = ComputingGraph::make();
         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
-             loss = opr::RemoteSend::make({"loss", send_tag, false}, x, client);
+             loss = opr::RemoteSend::make("loss", x, client, false);
         ASSERT_TRUE(!loss.shape().ndim &&
                     loss.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
-        loss = opr::RemoteSend::make({"loss", send_tag, true}, x, client);
+        loss = opr::RemoteSend::make("loss", x, client, true);
         auto gx = cg::grad(loss, x);
         set_priority(loss, 0);
         set_priority(gx, 1);
@@ -197,10 +194,10 @@ TEST(TestOprIORemote, SendGrad) {
     auto receiver = [&]() {
         sys::set_thread_name("receiver");
         auto graph = ComputingGraph::make();
-        auto x = opr::RemoteRecv::make({"loss", recv_tag, false}, *graph.get(),
+        auto x = opr::RemoteRecv::make("loss", *graph.get(),
                                        client, {cns[1]}, host_x->shape(),
                                        host_x->dtype());
-        auto y = opr::RemoteSend::make({"loss:grad", send_tag, false}, x + 1, client);
+        auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false);
         auto func = graph->compile({{y, {}}});
         func->execute();
     };
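
As these tests exercise, the rendezvous convention itself is unchanged by the patch (summarized from the opr_register calls in io_remote.cpp above): both processes register a two-member group under the same string key, with the sender as rank 0 and the receiver as rank 1, so the only thing callers must keep consistent across processes is the key. A minimal sketch (client, cn, shape and dtype are placeholders):

    // Process A: sender, rank 0 under key "x".
    auto xr = opr::RemoteSend::make("x", x, client, /*is_grad=*/false);
    // Process B: receiver, rank 1 under the same key "x".
    auto xin = opr::RemoteRecv::make("x", *graph, client, {cn}, shape, dtype);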

