GitOrigin-RevId: b7a7bbd0da
release-0.6
@@ -72,10 +72,10 @@ SymbolVar _Opr::remote_send( | |||
const std::string& key, SymbolVar var, | |||
const bool is_grad, | |||
const OperatorNodeConfig& config) { | |||
return RemoteSend::make({key, RemoteIOBase::Type::SEND, is_grad}, var, | |||
return RemoteSend::make(key, var, | |||
std::make_shared<GroupClientProxy>(ssprintf( | |||
"%s:%d", server_addr.c_str(), port)), | |||
config); | |||
is_grad, config); | |||
} | |||
SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, | |||
@@ -85,8 +85,7 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, | |||
const TensorShape ishape = npy::vec2shape(shape); | |||
const DType idtype = npy::dtype_np2mgb(dtype); | |||
return RemoteRecv::make({key, RemoteIOBase::Type::RECV, false}, | |||
graph.get(), | |||
return RemoteRecv::make(key, graph.get(), | |||
std::make_shared<GroupClientProxy>( | |||
ssprintf("%s:%d", server_addr.c_str(), port)), | |||
config, ishape, idtype); | |||
@@ -26,27 +26,28 @@ cudaStream_t get_stream(VarNode* var) { | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | |||
RemoteSend::RemoteSend(const PeerDesc& peer, VarNode* var, | |||
RemoteSend::RemoteSend(const std::string& key, VarNode* var, | |||
std::shared_ptr<GroupClient> group_client, | |||
const OperatorNodeConfig& config) : | |||
Super(var->owner_graph(), config, "remote_send", {var}) { | |||
m_peer = peer; | |||
bool is_grad, const OperatorNodeConfig& config) : | |||
Super(var->owner_graph(), config, "remote_send", {var}), | |||
m_is_grad(is_grad) { | |||
m_key = key; | |||
m_group_client = group_client; | |||
add_input({var}); | |||
auto ovar = add_output(None); | |||
if (!peer.is_grad) { | |||
if (!m_is_grad) { | |||
ovar->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) | |||
.add_flag(VarNode::Flag::VOLATILE_CONTENT); | |||
} | |||
add_equivalence_component<ScalarHash<void*>>(this); | |||
} | |||
SymbolVar RemoteSend::make(const PeerDesc& peer, SymbolVar var, | |||
SymbolVar RemoteSend::make(const std::string& key, SymbolVar var, | |||
std::shared_ptr<GroupClient> group_client, | |||
const OperatorNodeConfig& config) { | |||
return var.insert_single_output_opr<RemoteSend>(peer, var.node(), | |||
group_client, config); | |||
bool is_grad, const OperatorNodeConfig& config) { | |||
return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client, | |||
is_grad, config); | |||
} | |||
void RemoteSend::scn_do_execute() { | |||
@@ -54,11 +55,11 @@ void RemoteSend::scn_do_execute() { | |||
auto&& comp_node = output(0)->comp_node(); | |||
// rank 0 for RemoteSend | |||
auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false, | |||
auto reg_info = m_group_client->opr_register(m_key, 2, 0, false, | |||
comp_node.get_uid()); | |||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||
reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | |||
reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | |||
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); | |||
@@ -76,7 +77,7 @@ void RemoteSend::scn_do_execute() { | |||
auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx); | |||
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); | |||
if (m_peer.is_grad) { | |||
if (m_is_grad) { | |||
auto&& dest = output(0)->dev_tensor(); | |||
if (m_output_val.empty()) { | |||
m_output_val.comp_node(dest.comp_node()) | |||
@@ -92,7 +93,7 @@ void RemoteSend::init_output_static_infer_desc() { | |||
using namespace cg::static_infer; | |||
auto&& mgr = owner_graph()->static_infer_manager(); | |||
auto do_infer = [this](TensorShape& dest, const InpVal&) { | |||
if (peer_desc().is_grad) { | |||
if (m_is_grad) { | |||
dest = {1}; | |||
} else { | |||
dest = {0}; | |||
@@ -109,9 +110,8 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const { | |||
} | |||
MGB_IMPL_OPR_GRAD(RemoteSend) { | |||
mgb_assert(opr.peer_desc().is_grad); | |||
return RemoteRecv::make({opr.peer_desc().key + ":grad", | |||
RemoteIOBase::Type::RECV, false}, | |||
mgb_assert(opr.is_grad()); | |||
return RemoteRecv::make(opr.key() + ":grad", | |||
*opr.owner_graph(), opr.group_client(), | |||
OperatorNodeConfig{opr.comp_node()}.name( | |||
opr.name() + ":grad_recv"), | |||
@@ -123,13 +123,13 @@ MGB_IMPL_OPR_GRAD(RemoteSend) { | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); | |||
RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph, | |||
RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph, | |||
std::shared_ptr<GroupClient> group_client, | |||
const OperatorNodeConfig& config, | |||
const TensorShape& shape, DType dtype) : | |||
Super(&graph, config, "remote_recv", {}), | |||
m_shape(shape), m_dtype(dtype) { | |||
m_peer = peer; | |||
m_key = key; | |||
m_group_client = group_client; | |||
add_output(None) | |||
@@ -139,12 +139,12 @@ RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph, | |||
add_equivalence_component<ScalarHash<void*>>(this); | |||
} | |||
SymbolVar RemoteRecv::make(const PeerDesc& peer, cg::ComputingGraph& graph, | |||
SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | |||
std::shared_ptr<GroupClient> group_client, | |||
const OperatorNodeConfig& config, | |||
const TensorShape& shape, DType dtype) { | |||
auto opr = graph.insert_opr(std::make_unique<RemoteRecv>( | |||
peer, graph, group_client, config, shape, dtype)); | |||
key, graph, group_client, config, shape, dtype)); | |||
return opr->output(0); | |||
} | |||
@@ -154,11 +154,11 @@ void RemoteRecv::scn_do_execute() { | |||
// rank 1 for RemoteRecv | |||
auto reg_info = m_group_client->opr_register( | |||
m_peer.key, 2, false, 1, | |||
m_key, 2, false, 1, | |||
comp_node.get_uid()); | |||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||
reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | |||
reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | |||
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); | |||
@@ -206,8 +206,8 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_send( | |||
const OperatorNodeConfig& config) { | |||
mgb_assert(inputs.size() == 1); | |||
auto&& opr = opr_.cast_final_safe<RemoteSend>(); | |||
return RemoteSend::make(opr.peer_desc(), inputs[0], opr.group_client(), | |||
config) | |||
return RemoteSend::make(opr.key(), inputs[0], opr.group_client(), | |||
opr.is_grad(), config) | |||
.node() | |||
->owner_opr(); | |||
} | |||
@@ -218,7 +218,7 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv( | |||
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | |||
const OperatorNodeConfig& config) { | |||
auto&& opr = opr_.cast_final_safe<RemoteRecv>(); | |||
return RemoteRecv::make(opr.peer_desc(), *opr.owner_graph(), | |||
return RemoteRecv::make(opr.key(), *opr.owner_graph(), | |||
opr.group_client(), config, inputs[0]->shape(), | |||
inputs[0]->dtype()) | |||
.node() | |||
@@ -25,25 +25,14 @@ namespace opr { | |||
*/ | |||
MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // { | |||
public: | |||
enum Type { | |||
SEND, | |||
RECV | |||
}; | |||
struct PeerDesc { | |||
std::string key; | |||
Type type; | |||
bool is_grad; | |||
}; | |||
const PeerDesc& peer_desc() const { return m_peer; } | |||
const std::string& key() const { return m_key; } | |||
std::shared_ptr<GroupClient> group_client() const { | |||
return m_group_client; | |||
} | |||
protected: | |||
PeerDesc m_peer; | |||
std::string m_key; | |||
std::shared_ptr<GroupClient> m_group_client; | |||
std::shared_ptr<MegRay::Communicator> m_megray_comm; | |||
std::shared_ptr<MegRay::Context> m_megray_ctx; | |||
@@ -53,21 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // { | |||
/*! | |||
* \brief send a variable to remote address; a virtual output is produced | |||
* for expressing dependency | |||
* for expressing dependency | |||
*/ | |||
MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // { | |||
public: | |||
RemoteSend(const PeerDesc& peer, VarNode* var, | |||
RemoteSend(const std::string& key, VarNode* var, | |||
std::shared_ptr<GroupClient> group_client, | |||
const OperatorNodeConfig& config); | |||
bool is_grad, const OperatorNodeConfig& config); | |||
static SymbolVar make( | |||
const PeerDesc& peer, SymbolVar var, | |||
const std::string& key, SymbolVar var, | |||
std::shared_ptr<GroupClient> group_client, | |||
const OperatorNodeConfig& config = {}); | |||
bool is_grad, const OperatorNodeConfig& config = {}); | |||
bool is_grad() const { return m_is_grad; } | |||
private: | |||
HostTensorND m_output_val; | |||
bool m_is_grad; | |||
void scn_do_execute() override; | |||
void init_output_static_infer_desc() override; | |||
@@ -75,19 +67,18 @@ MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // { | |||
}; | |||
/*! | |||
* \brief receive from multiple remote addresses and write to a var | |||
* | |||
* Target computing node of the var must be specified in config | |||
* \brief receive a variable from remote address; target computing node | |||
* of the var must be specified in config | |||
*/ | |||
MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // { | |||
public: | |||
RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph, | |||
RemoteRecv(const std::string& key, cg::ComputingGraph& graph, | |||
std::shared_ptr<GroupClient> group_client, | |||
const OperatorNodeConfig& config, const TensorShape& shape, | |||
DType dtype); | |||
static SymbolVar make( | |||
const PeerDesc& peer, cg::ComputingGraph& graph, | |||
const std::string& key, cg::ComputingGraph& graph, | |||
std::shared_ptr<GroupClient> group_client, | |||
const OperatorNodeConfig& config, const TensorShape& shape, | |||
DType dtype); | |||
@@ -20,9 +20,6 @@ | |||
using namespace mgb; | |||
const auto send_tag = opr::RemoteIOBase::Type::SEND; | |||
const auto recv_tag = opr::RemoteIOBase::Type::RECV; | |||
TEST(TestOprIORemote, Identity) { | |||
REQUIRE_GPU(2); | |||
auto cn0 = CompNode::load("gpu0"); | |||
@@ -36,8 +33,8 @@ TEST(TestOprIORemote, Identity) { | |||
auto graph = ComputingGraph::make(); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0); | |||
auto xr = opr::RemoteSend::make({"x", send_tag, false}, x, client); | |||
auto y = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(), | |||
auto xr = opr::RemoteSend::make("x", x, client, false); | |||
auto y = opr::RemoteRecv::make("x", *graph.get(), | |||
client, {cn1}, host_x->shape(), | |||
host_x->dtype()); | |||
@@ -59,7 +56,7 @@ TEST(TestOprIORemote, IdentityMultiThread) { | |||
auto graph = ComputingGraph::make(); | |||
sys::set_thread_name("sender"); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x), | |||
xr = opr::RemoteSend::make({"x", send_tag, false}, x, client); | |||
xr = opr::RemoteSend::make("x", x, client, false); | |||
auto func = graph->compile({{xr, {}}}); | |||
func->execute(); | |||
}; | |||
@@ -67,7 +64,7 @@ TEST(TestOprIORemote, IdentityMultiThread) { | |||
auto receiver = [&]() { | |||
sys::set_thread_name("receiver"); | |||
auto graph = ComputingGraph::make(); | |||
auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(), | |||
auto x = opr::RemoteRecv::make("x", *graph.get(), | |||
client, {cns[0]}, host_x->shape(), | |||
host_x->dtype()); | |||
auto func = graph->compile({make_callback_copy(x, host_x_get)}); | |||
@@ -92,7 +89,7 @@ TEST(TestOprIORemote, IdentityWithGopt) { | |||
sys::set_thread_name("sender"); | |||
auto graph = ComputingGraph::make(); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x) * 2 + 1, | |||
xr = opr::RemoteSend::make({"x", send_tag, false}, x, client); | |||
xr = opr::RemoteSend::make("x", x, client, false); | |||
auto func = graph->compile({{xr, {}}}); | |||
func->execute(); | |||
}; | |||
@@ -100,7 +97,7 @@ TEST(TestOprIORemote, IdentityWithGopt) { | |||
auto receiver = [&]() { | |||
sys::set_thread_name("receiver"); | |||
auto graph = ComputingGraph::make(); | |||
auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(), | |||
auto x = opr::RemoteRecv::make("x", *graph.get(), | |||
client, {cns[0]}, host_x->shape(), | |||
host_x->dtype()); | |||
auto func = | |||
@@ -124,14 +121,14 @@ TEST(TestOprIORemote, APlusB) { | |||
auto sender = [&]() { | |||
auto graph = ComputingGraph::make(); | |||
auto z = opr::RemoteRecv::make({"z", recv_tag, false}, *graph.get(), | |||
auto z = opr::RemoteRecv::make("z", *graph.get(), | |||
client, {cns[0]}, host_x->shape(), | |||
host_x->dtype()); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"), | |||
y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"), | |||
xr = opr::RemoteSend::make({"x", send_tag, false}, x, client) | |||
xr = opr::RemoteSend::make("x", x, client, false) | |||
.rename("xr"), | |||
yr = opr::RemoteSend::make({"y", send_tag, false}, y, client) | |||
yr = opr::RemoteSend::make("y", y, client, false) | |||
.rename("yr"); | |||
auto func = graph->compile( | |||
{{xr, {}}, {yr, {}}, make_callback_copy(z, host_z)}); | |||
@@ -142,14 +139,14 @@ TEST(TestOprIORemote, APlusB) { | |||
auto receiver = [&]() { | |||
auto graph = ComputingGraph::make(); | |||
auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(), | |||
auto x = opr::RemoteRecv::make("x", *graph.get(), | |||
client, {cns[1]}, host_x->shape(), | |||
host_x->dtype()), | |||
y = opr::RemoteRecv::make({"y", recv_tag, false}, *graph.get(), | |||
y = opr::RemoteRecv::make("y", *graph.get(), | |||
client, {cns[1]}, host_y->shape(), | |||
host_y->dtype()), | |||
z = x + y, | |||
zr = opr::RemoteSend::make({"z", send_tag, false}, z, client); | |||
zr = opr::RemoteSend::make("z", z, client, false); | |||
auto func = graph->compile({{zr, {}}}); | |||
func->execute(); | |||
}; | |||
@@ -177,10 +174,10 @@ TEST(TestOprIORemote, SendGrad) { | |||
sys::set_thread_name("sender"); | |||
auto graph = ComputingGraph::make(); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x), | |||
loss = opr::RemoteSend::make({"loss", send_tag, false}, x, client); | |||
loss = opr::RemoteSend::make("loss", x, client, false); | |||
ASSERT_TRUE(!loss.shape().ndim && | |||
loss.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT)); | |||
loss = opr::RemoteSend::make({"loss", send_tag, true}, x, client); | |||
loss = opr::RemoteSend::make("loss", x, client, true); | |||
auto gx = cg::grad(loss, x); | |||
set_priority(loss, 0); | |||
set_priority(gx, 1); | |||
@@ -197,10 +194,10 @@ TEST(TestOprIORemote, SendGrad) { | |||
auto receiver = [&]() { | |||
sys::set_thread_name("receiver"); | |||
auto graph = ComputingGraph::make(); | |||
auto x = opr::RemoteRecv::make({"loss", recv_tag, false}, *graph.get(), | |||
auto x = opr::RemoteRecv::make("loss", *graph.get(), | |||
client, {cns[1]}, host_x->shape(), | |||
host_x->dtype()); | |||
auto y = opr::RemoteSend::make({"loss:grad", send_tag, false}, x + 1, client); | |||
auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false); | |||
auto func = graph->compile({{y, {}}}); | |||
func->execute(); | |||
}; | |||