GitOrigin-RevId: b7a7bbd0da
release-0.6
@@ -72,10 +72,10 @@ SymbolVar _Opr::remote_send( | |||||
const std::string& key, SymbolVar var, | const std::string& key, SymbolVar var, | ||||
const bool is_grad, | const bool is_grad, | ||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
return RemoteSend::make({key, RemoteIOBase::Type::SEND, is_grad}, var, | |||||
return RemoteSend::make(key, var, | |||||
std::make_shared<GroupClientProxy>(ssprintf( | std::make_shared<GroupClientProxy>(ssprintf( | ||||
"%s:%d", server_addr.c_str(), port)), | "%s:%d", server_addr.c_str(), port)), | ||||
config); | |||||
is_grad, config); | |||||
} | } | ||||
SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, | SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, | ||||
@@ -85,8 +85,7 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, | |||||
const TensorShape ishape = npy::vec2shape(shape); | const TensorShape ishape = npy::vec2shape(shape); | ||||
const DType idtype = npy::dtype_np2mgb(dtype); | const DType idtype = npy::dtype_np2mgb(dtype); | ||||
return RemoteRecv::make({key, RemoteIOBase::Type::RECV, false}, | |||||
graph.get(), | |||||
return RemoteRecv::make(key, graph.get(), | |||||
std::make_shared<GroupClientProxy>( | std::make_shared<GroupClientProxy>( | ||||
ssprintf("%s:%d", server_addr.c_str(), port)), | ssprintf("%s:%d", server_addr.c_str(), port)), | ||||
config, ishape, idtype); | config, ishape, idtype); | ||||
@@ -26,27 +26,28 @@ cudaStream_t get_stream(VarNode* var) { | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | ||||
RemoteSend::RemoteSend(const PeerDesc& peer, VarNode* var, | |||||
RemoteSend::RemoteSend(const std::string& key, VarNode* var, | |||||
std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
const OperatorNodeConfig& config) : | |||||
Super(var->owner_graph(), config, "remote_send", {var}) { | |||||
m_peer = peer; | |||||
bool is_grad, const OperatorNodeConfig& config) : | |||||
Super(var->owner_graph(), config, "remote_send", {var}), | |||||
m_is_grad(is_grad) { | |||||
m_key = key; | |||||
m_group_client = group_client; | m_group_client = group_client; | ||||
add_input({var}); | add_input({var}); | ||||
auto ovar = add_output(None); | auto ovar = add_output(None); | ||||
if (!peer.is_grad) { | |||||
if (!m_is_grad) { | |||||
ovar->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) | ovar->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) | ||||
.add_flag(VarNode::Flag::VOLATILE_CONTENT); | .add_flag(VarNode::Flag::VOLATILE_CONTENT); | ||||
} | } | ||||
add_equivalence_component<ScalarHash<void*>>(this); | add_equivalence_component<ScalarHash<void*>>(this); | ||||
} | } | ||||
SymbolVar RemoteSend::make(const PeerDesc& peer, SymbolVar var, | |||||
SymbolVar RemoteSend::make(const std::string& key, SymbolVar var, | |||||
std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
const OperatorNodeConfig& config) { | |||||
return var.insert_single_output_opr<RemoteSend>(peer, var.node(), | |||||
group_client, config); | |||||
bool is_grad, const OperatorNodeConfig& config) { | |||||
return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client, | |||||
is_grad, config); | |||||
} | } | ||||
void RemoteSend::scn_do_execute() { | void RemoteSend::scn_do_execute() { | ||||
@@ -54,11 +55,11 @@ void RemoteSend::scn_do_execute() { | |||||
auto&& comp_node = output(0)->comp_node(); | auto&& comp_node = output(0)->comp_node(); | ||||
// rank 0 for RemoteSend | // rank 0 for RemoteSend | ||||
auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false, | |||||
auto reg_info = m_group_client->opr_register(m_key, 2, 0, false, | |||||
comp_node.get_uid()); | comp_node.get_uid()); | ||||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | m_megray_comm = MegRayCommBuilder::get_megray_comm( | ||||
reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | |||||
reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | |||||
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); | m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); | ||||
@@ -76,7 +77,7 @@ void RemoteSend::scn_do_execute() { | |||||
auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx); | auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx); | ||||
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); | mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); | ||||
if (m_peer.is_grad) { | |||||
if (m_is_grad) { | |||||
auto&& dest = output(0)->dev_tensor(); | auto&& dest = output(0)->dev_tensor(); | ||||
if (m_output_val.empty()) { | if (m_output_val.empty()) { | ||||
m_output_val.comp_node(dest.comp_node()) | m_output_val.comp_node(dest.comp_node()) | ||||
@@ -92,7 +93,7 @@ void RemoteSend::init_output_static_infer_desc() { | |||||
using namespace cg::static_infer; | using namespace cg::static_infer; | ||||
auto&& mgr = owner_graph()->static_infer_manager(); | auto&& mgr = owner_graph()->static_infer_manager(); | ||||
auto do_infer = [this](TensorShape& dest, const InpVal&) { | auto do_infer = [this](TensorShape& dest, const InpVal&) { | ||||
if (peer_desc().is_grad) { | |||||
if (m_is_grad) { | |||||
dest = {1}; | dest = {1}; | ||||
} else { | } else { | ||||
dest = {0}; | dest = {0}; | ||||
@@ -109,9 +110,8 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const { | |||||
} | } | ||||
MGB_IMPL_OPR_GRAD(RemoteSend) { | MGB_IMPL_OPR_GRAD(RemoteSend) { | ||||
mgb_assert(opr.peer_desc().is_grad); | |||||
return RemoteRecv::make({opr.peer_desc().key + ":grad", | |||||
RemoteIOBase::Type::RECV, false}, | |||||
mgb_assert(opr.is_grad()); | |||||
return RemoteRecv::make(opr.key() + ":grad", | |||||
*opr.owner_graph(), opr.group_client(), | *opr.owner_graph(), opr.group_client(), | ||||
OperatorNodeConfig{opr.comp_node()}.name( | OperatorNodeConfig{opr.comp_node()}.name( | ||||
opr.name() + ":grad_recv"), | opr.name() + ":grad_recv"), | ||||
@@ -123,13 +123,13 @@ MGB_IMPL_OPR_GRAD(RemoteSend) { | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); | ||||
RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph, | |||||
RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph, | |||||
std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
const OperatorNodeConfig& config, | const OperatorNodeConfig& config, | ||||
const TensorShape& shape, DType dtype) : | const TensorShape& shape, DType dtype) : | ||||
Super(&graph, config, "remote_recv", {}), | Super(&graph, config, "remote_recv", {}), | ||||
m_shape(shape), m_dtype(dtype) { | m_shape(shape), m_dtype(dtype) { | ||||
m_peer = peer; | |||||
m_key = key; | |||||
m_group_client = group_client; | m_group_client = group_client; | ||||
add_output(None) | add_output(None) | ||||
@@ -139,12 +139,12 @@ RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph, | |||||
add_equivalence_component<ScalarHash<void*>>(this); | add_equivalence_component<ScalarHash<void*>>(this); | ||||
} | } | ||||
SymbolVar RemoteRecv::make(const PeerDesc& peer, cg::ComputingGraph& graph, | |||||
SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | |||||
std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
const OperatorNodeConfig& config, | const OperatorNodeConfig& config, | ||||
const TensorShape& shape, DType dtype) { | const TensorShape& shape, DType dtype) { | ||||
auto opr = graph.insert_opr(std::make_unique<RemoteRecv>( | auto opr = graph.insert_opr(std::make_unique<RemoteRecv>( | ||||
peer, graph, group_client, config, shape, dtype)); | |||||
key, graph, group_client, config, shape, dtype)); | |||||
return opr->output(0); | return opr->output(0); | ||||
} | } | ||||
@@ -154,11 +154,11 @@ void RemoteRecv::scn_do_execute() { | |||||
// rank 1 for RemoteRecv | // rank 1 for RemoteRecv | ||||
auto reg_info = m_group_client->opr_register( | auto reg_info = m_group_client->opr_register( | ||||
m_peer.key, 2, false, 1, | |||||
m_key, 2, false, 1, | |||||
comp_node.get_uid()); | comp_node.get_uid()); | ||||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | m_megray_comm = MegRayCommBuilder::get_megray_comm( | ||||
reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | |||||
reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | |||||
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); | m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); | ||||
@@ -206,8 +206,8 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_send( | |||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
mgb_assert(inputs.size() == 1); | mgb_assert(inputs.size() == 1); | ||||
auto&& opr = opr_.cast_final_safe<RemoteSend>(); | auto&& opr = opr_.cast_final_safe<RemoteSend>(); | ||||
return RemoteSend::make(opr.peer_desc(), inputs[0], opr.group_client(), | |||||
config) | |||||
return RemoteSend::make(opr.key(), inputs[0], opr.group_client(), | |||||
opr.is_grad(), config) | |||||
.node() | .node() | ||||
->owner_opr(); | ->owner_opr(); | ||||
} | } | ||||
@@ -218,7 +218,7 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv( | |||||
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | ||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
auto&& opr = opr_.cast_final_safe<RemoteRecv>(); | auto&& opr = opr_.cast_final_safe<RemoteRecv>(); | ||||
return RemoteRecv::make(opr.peer_desc(), *opr.owner_graph(), | |||||
return RemoteRecv::make(opr.key(), *opr.owner_graph(), | |||||
opr.group_client(), config, inputs[0]->shape(), | opr.group_client(), config, inputs[0]->shape(), | ||||
inputs[0]->dtype()) | inputs[0]->dtype()) | ||||
.node() | .node() | ||||
@@ -25,25 +25,14 @@ namespace opr { | |||||
*/ | */ | ||||
MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // { | MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // { | ||||
public: | public: | ||||
enum Type { | |||||
SEND, | |||||
RECV | |||||
}; | |||||
struct PeerDesc { | |||||
std::string key; | |||||
Type type; | |||||
bool is_grad; | |||||
}; | |||||
const PeerDesc& peer_desc() const { return m_peer; } | |||||
const std::string& key() const { return m_key; } | |||||
std::shared_ptr<GroupClient> group_client() const { | std::shared_ptr<GroupClient> group_client() const { | ||||
return m_group_client; | return m_group_client; | ||||
} | } | ||||
protected: | protected: | ||||
PeerDesc m_peer; | |||||
std::string m_key; | |||||
std::shared_ptr<GroupClient> m_group_client; | std::shared_ptr<GroupClient> m_group_client; | ||||
std::shared_ptr<MegRay::Communicator> m_megray_comm; | std::shared_ptr<MegRay::Communicator> m_megray_comm; | ||||
std::shared_ptr<MegRay::Context> m_megray_ctx; | std::shared_ptr<MegRay::Context> m_megray_ctx; | ||||
@@ -53,21 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // { | |||||
/*! | /*! | ||||
* \brief send a variable to remote address; a virtual output is produced | * \brief send a variable to remote address; a virtual output is produced | ||||
* for expressing dependency | |||||
* for expressing dependency | |||||
*/ | */ | ||||
MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // { | MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // { | ||||
public: | public: | ||||
RemoteSend(const PeerDesc& peer, VarNode* var, | |||||
RemoteSend(const std::string& key, VarNode* var, | |||||
std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
const OperatorNodeConfig& config); | |||||
bool is_grad, const OperatorNodeConfig& config); | |||||
static SymbolVar make( | static SymbolVar make( | ||||
const PeerDesc& peer, SymbolVar var, | |||||
const std::string& key, SymbolVar var, | |||||
std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
const OperatorNodeConfig& config = {}); | |||||
bool is_grad, const OperatorNodeConfig& config = {}); | |||||
bool is_grad() const { return m_is_grad; } | |||||
private: | private: | ||||
HostTensorND m_output_val; | HostTensorND m_output_val; | ||||
bool m_is_grad; | |||||
void scn_do_execute() override; | void scn_do_execute() override; | ||||
void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
@@ -75,19 +67,18 @@ MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // { | |||||
}; | }; | ||||
/*! | /*! | ||||
* \brief receive from multiple remote addresses and write to a var | |||||
* | |||||
* Target computing node of the var must be specified in config | |||||
* \brief receive a variable from remote address; target computing node | |||||
* of the var must be specified in config | |||||
*/ | */ | ||||
MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // { | MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // { | ||||
public: | public: | ||||
RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph, | |||||
RemoteRecv(const std::string& key, cg::ComputingGraph& graph, | |||||
std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
const OperatorNodeConfig& config, const TensorShape& shape, | const OperatorNodeConfig& config, const TensorShape& shape, | ||||
DType dtype); | DType dtype); | ||||
static SymbolVar make( | static SymbolVar make( | ||||
const PeerDesc& peer, cg::ComputingGraph& graph, | |||||
const std::string& key, cg::ComputingGraph& graph, | |||||
std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
const OperatorNodeConfig& config, const TensorShape& shape, | const OperatorNodeConfig& config, const TensorShape& shape, | ||||
DType dtype); | DType dtype); | ||||
@@ -20,9 +20,6 @@ | |||||
using namespace mgb; | using namespace mgb; | ||||
const auto send_tag = opr::RemoteIOBase::Type::SEND; | |||||
const auto recv_tag = opr::RemoteIOBase::Type::RECV; | |||||
TEST(TestOprIORemote, Identity) { | TEST(TestOprIORemote, Identity) { | ||||
REQUIRE_GPU(2); | REQUIRE_GPU(2); | ||||
auto cn0 = CompNode::load("gpu0"); | auto cn0 = CompNode::load("gpu0"); | ||||
@@ -36,8 +33,8 @@ TEST(TestOprIORemote, Identity) { | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0); | auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0); | ||||
auto xr = opr::RemoteSend::make({"x", send_tag, false}, x, client); | |||||
auto y = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(), | |||||
auto xr = opr::RemoteSend::make("x", x, client, false); | |||||
auto y = opr::RemoteRecv::make("x", *graph.get(), | |||||
client, {cn1}, host_x->shape(), | client, {cn1}, host_x->shape(), | ||||
host_x->dtype()); | host_x->dtype()); | ||||
@@ -59,7 +56,7 @@ TEST(TestOprIORemote, IdentityMultiThread) { | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
sys::set_thread_name("sender"); | sys::set_thread_name("sender"); | ||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x), | auto x = opr::Host2DeviceCopy::make(*graph, host_x), | ||||
xr = opr::RemoteSend::make({"x", send_tag, false}, x, client); | |||||
xr = opr::RemoteSend::make("x", x, client, false); | |||||
auto func = graph->compile({{xr, {}}}); | auto func = graph->compile({{xr, {}}}); | ||||
func->execute(); | func->execute(); | ||||
}; | }; | ||||
@@ -67,7 +64,7 @@ TEST(TestOprIORemote, IdentityMultiThread) { | |||||
auto receiver = [&]() { | auto receiver = [&]() { | ||||
sys::set_thread_name("receiver"); | sys::set_thread_name("receiver"); | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(), | |||||
auto x = opr::RemoteRecv::make("x", *graph.get(), | |||||
client, {cns[0]}, host_x->shape(), | client, {cns[0]}, host_x->shape(), | ||||
host_x->dtype()); | host_x->dtype()); | ||||
auto func = graph->compile({make_callback_copy(x, host_x_get)}); | auto func = graph->compile({make_callback_copy(x, host_x_get)}); | ||||
@@ -92,7 +89,7 @@ TEST(TestOprIORemote, IdentityWithGopt) { | |||||
sys::set_thread_name("sender"); | sys::set_thread_name("sender"); | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x) * 2 + 1, | auto x = opr::Host2DeviceCopy::make(*graph, host_x) * 2 + 1, | ||||
xr = opr::RemoteSend::make({"x", send_tag, false}, x, client); | |||||
xr = opr::RemoteSend::make("x", x, client, false); | |||||
auto func = graph->compile({{xr, {}}}); | auto func = graph->compile({{xr, {}}}); | ||||
func->execute(); | func->execute(); | ||||
}; | }; | ||||
@@ -100,7 +97,7 @@ TEST(TestOprIORemote, IdentityWithGopt) { | |||||
auto receiver = [&]() { | auto receiver = [&]() { | ||||
sys::set_thread_name("receiver"); | sys::set_thread_name("receiver"); | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(), | |||||
auto x = opr::RemoteRecv::make("x", *graph.get(), | |||||
client, {cns[0]}, host_x->shape(), | client, {cns[0]}, host_x->shape(), | ||||
host_x->dtype()); | host_x->dtype()); | ||||
auto func = | auto func = | ||||
@@ -124,14 +121,14 @@ TEST(TestOprIORemote, APlusB) { | |||||
auto sender = [&]() { | auto sender = [&]() { | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto z = opr::RemoteRecv::make({"z", recv_tag, false}, *graph.get(), | |||||
auto z = opr::RemoteRecv::make("z", *graph.get(), | |||||
client, {cns[0]}, host_x->shape(), | client, {cns[0]}, host_x->shape(), | ||||
host_x->dtype()); | host_x->dtype()); | ||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"), | auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"), | ||||
y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"), | y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"), | ||||
xr = opr::RemoteSend::make({"x", send_tag, false}, x, client) | |||||
xr = opr::RemoteSend::make("x", x, client, false) | |||||
.rename("xr"), | .rename("xr"), | ||||
yr = opr::RemoteSend::make({"y", send_tag, false}, y, client) | |||||
yr = opr::RemoteSend::make("y", y, client, false) | |||||
.rename("yr"); | .rename("yr"); | ||||
auto func = graph->compile( | auto func = graph->compile( | ||||
{{xr, {}}, {yr, {}}, make_callback_copy(z, host_z)}); | {{xr, {}}, {yr, {}}, make_callback_copy(z, host_z)}); | ||||
@@ -142,14 +139,14 @@ TEST(TestOprIORemote, APlusB) { | |||||
auto receiver = [&]() { | auto receiver = [&]() { | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(), | |||||
auto x = opr::RemoteRecv::make("x", *graph.get(), | |||||
client, {cns[1]}, host_x->shape(), | client, {cns[1]}, host_x->shape(), | ||||
host_x->dtype()), | host_x->dtype()), | ||||
y = opr::RemoteRecv::make({"y", recv_tag, false}, *graph.get(), | |||||
y = opr::RemoteRecv::make("y", *graph.get(), | |||||
client, {cns[1]}, host_y->shape(), | client, {cns[1]}, host_y->shape(), | ||||
host_y->dtype()), | host_y->dtype()), | ||||
z = x + y, | z = x + y, | ||||
zr = opr::RemoteSend::make({"z", send_tag, false}, z, client); | |||||
zr = opr::RemoteSend::make("z", z, client, false); | |||||
auto func = graph->compile({{zr, {}}}); | auto func = graph->compile({{zr, {}}}); | ||||
func->execute(); | func->execute(); | ||||
}; | }; | ||||
@@ -177,10 +174,10 @@ TEST(TestOprIORemote, SendGrad) { | |||||
sys::set_thread_name("sender"); | sys::set_thread_name("sender"); | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x), | auto x = opr::Host2DeviceCopy::make(*graph, host_x), | ||||
loss = opr::RemoteSend::make({"loss", send_tag, false}, x, client); | |||||
loss = opr::RemoteSend::make("loss", x, client, false); | |||||
ASSERT_TRUE(!loss.shape().ndim && | ASSERT_TRUE(!loss.shape().ndim && | ||||
loss.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT)); | loss.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT)); | ||||
loss = opr::RemoteSend::make({"loss", send_tag, true}, x, client); | |||||
loss = opr::RemoteSend::make("loss", x, client, true); | |||||
auto gx = cg::grad(loss, x); | auto gx = cg::grad(loss, x); | ||||
set_priority(loss, 0); | set_priority(loss, 0); | ||||
set_priority(gx, 1); | set_priority(gx, 1); | ||||
@@ -197,10 +194,10 @@ TEST(TestOprIORemote, SendGrad) { | |||||
auto receiver = [&]() { | auto receiver = [&]() { | ||||
sys::set_thread_name("receiver"); | sys::set_thread_name("receiver"); | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x = opr::RemoteRecv::make({"loss", recv_tag, false}, *graph.get(), | |||||
auto x = opr::RemoteRecv::make("loss", *graph.get(), | |||||
client, {cns[1]}, host_x->shape(), | client, {cns[1]}, host_x->shape(), | ||||
host_x->dtype()); | host_x->dtype()); | ||||
auto y = opr::RemoteSend::make({"loss:grad", send_tag, false}, x + 1, client); | |||||
auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false); | |||||
auto func = graph->compile({{y, {}}}); | auto func = graph->compile({{y, {}}}); | ||||
func->execute(); | func->execute(); | ||||
}; | }; | ||||