GitOrigin-RevId: 841a0e45ab
release-1.4
@@ -265,6 +265,7 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
     op.key = key
     op.addr, op.port = get_mm_server_addr()
     op.rank_to = dest_rank
+    op.backend = get_backend()
     (dummy,) = apply(_RemoteSend(op), inp)
     for g in grad_keys.values():
@@ -313,6 +314,7 @@ def remote_recv(
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
     op.rank_from = src_rank
+    op.backend = get_backend()
     (ret,) = apply(_RemoteRecv(op), inp)
     if _isscalar:
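For orientation, a minimal sketch of the user-facing flow these two hunks feed: the caller never passes `backend` explicitly; it is filled in from the process-group state via `get_backend()`. Module paths and the `remote_recv` signature are assumed from the 1.4-era `megengine.distributed` API and may differ slightly.

    import numpy as np
    import megengine.distributed as dist
    from megengine import tensor

    @dist.launcher(n_gpus=2)
    def worker():
        if dist.get_rank() == 0:
            x = tensor(np.arange(3, dtype="float32"))
            # internally sets op.backend = get_backend(), per the hunk above
            dist.functional.remote_send(x, dest_rank=1)
        else:
            y = dist.functional.remote_recv(src_rank=0, shape=(3,), dtype="float32")
            print(y.numpy())

    worker()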
@@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
     OperatorNodeConfig config{send.make_name()};
     cg::OperatorNodeBase* opr =
             graph->insert_opr(std::make_unique<mgb::opr::RemoteSend>(
-                    send.key, inputs[0], group_client, true, config));
+                    send.key, inputs[0], group_client, true, send.backend, config));
     return opr;
 }
@@ -49,7 +49,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
     auto&& graph = inputs[0]->owner_graph();
     return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
             recv.key, inputs[0], *graph, group_client, config,
-            recv.shape, recv.dtype));
+            recv.shape, recv.dtype, recv.backend));
 }
 OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend)
@@ -34,7 +34,7 @@ TEST(TestImperative, IORemote) {
     auto run_send = [&](std::shared_ptr<HostTensorND> hnd) {
         auto def = imperative::RemoteSend::make(
-                "io_remote_test", server_addr, port, 1);
+                "io_remote_test", server_addr, port, 1, "nccl");
         auto inp = Tensor::make(*hnd);
         auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
     };
@@ -43,7 +43,7 @@ TEST(TestImperative, IORemote) {
         auto def = imperative::RemoteRecv::make(
                 "io_remote_test", server_addr, port, 0,
                 CompNode::load("gpu1"), TensorShape{vector_size},
-                dtype::Float32());
+                dtype::Float32(), "nccl");
         auto inp = Tensor::make(*hnd);
         auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
         HostTensorND host_v;
@@ -169,7 +169,8 @@ def RemoteSend : MgbHashableOp<"RemoteSend"> {
     MgbStringAttr:$key,
     MgbStringAttr:$addr,
     MgbUI32Attr:$port,
-    MgbUI32Attr:$rank_to
+    MgbUI32Attr:$rank_to,
+    MgbStringAttr:$backend
   );
 }
@@ -181,7 +182,8 @@ def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
     MgbUI32Attr:$rank_from,
     MgbCompNodeAttr:$cn,
     MgbTensorShapeAttr:$shape,
-    MgbDTypeAttr:$dtype
+    MgbDTypeAttr:$dtype,
+    MgbStringAttr:$backend
   );
 }
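These tablegen changes regenerate the Python-visible op descriptors, so `RemoteSend`/`RemoteRecv` from `megengine.core.ops.builtin` gain a `backend` attribute next to the existing ones. A hedged sketch of constructing the generated op directly, assuming the usual keyword-argument constructor that the generated bindings expose (the address/port values are placeholders):

    from megengine.core.ops.builtin import RemoteSend

    # attribute names follow the MgbHashableOp definitions above
    op = RemoteSend(
        key="io_remote_test",
        addr="127.0.0.1",
        port=3333,
        rank_to=1,
        backend="nccl",  # the new attribute
    )
    assert op.backend == "nccl"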
@@ -24,8 +24,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
 RemoteSend::RemoteSend(const std::string& key, VarNode* var,
                        std::shared_ptr<GroupClient> group_client,
-                       bool is_grad, const OperatorNodeConfig& config) :
+                       bool is_grad, std::string backend, const OperatorNodeConfig& config) :
         Super(var->owner_graph(), config, "remote_send", {var}),
+        m_backend(backend),
         m_is_grad(is_grad) {
     m_key = key;
     m_group_client = group_client;
@@ -41,9 +42,9 @@ RemoteSend::RemoteSend(const std::string& key, VarNode* var,
 SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
                            std::shared_ptr<GroupClient> group_client,
-                           bool is_grad, const OperatorNodeConfig& config) {
+                           bool is_grad, std::string backend, const OperatorNodeConfig& config) {
     return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client,
-                                                    is_grad, config);
+                                                    is_grad, backend, config);
 }
 void RemoteSend::scn_do_execute() {
@@ -64,7 +65,7 @@ void RemoteSend::scn_do_execute() {
     }
     m_megray_comm = MegRayCommBuilder::get_megray_comm(
-            reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
+            reg_info.hash, m_key, 2, 0, get_megray_backend(m_backend), m_group_client);
     m_megray_ctx = get_megray_context(output(0)->comp_node());
@@ -122,7 +123,7 @@ MGB_IMPL_OPR_GRAD(RemoteSend) {
                    *opr.owner_graph(), opr.group_client(),
                    OperatorNodeConfig{opr.comp_node()}.name(
                            opr.name() + ":grad_recv"),
-                   opr.input(0)->shape(), opr.input(0)->dtype())
+                   opr.input(0)->shape(), opr.input(0)->dtype(), opr.backend())
                    .node();
 }
 #endif
@@ -134,9 +135,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);
 RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                        std::shared_ptr<GroupClient> group_client,
                        const OperatorNodeConfig& config,
-                       const TensorShape& shape, DType dtype) :
+                       const TensorShape& shape, DType dtype, std::string backend) :
         Super(&graph, config, "remote_recv", {}),
-        m_shape(shape), m_dtype(dtype) {
+        m_shape(shape), m_dtype(dtype), m_backend(backend) {
     m_key = key;
     m_group_client = group_client;
@@ -150,9 +151,9 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
 RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
                        std::shared_ptr<GroupClient> group_client,
                        const OperatorNodeConfig& config,
-                       const TensorShape& shape, DType dtype) :
+                       const TensorShape& shape, DType dtype, std::string backend) :
         Super(&graph, config, "remote_recv", {}),
-        m_shape(shape), m_dtype(dtype) {
+        m_shape(shape), m_dtype(dtype), m_backend(backend) {
     m_key = key;
     m_group_client = group_client;
@@ -167,18 +168,18 @@ RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph&
 SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
                            std::shared_ptr<GroupClient> group_client,
                            const OperatorNodeConfig& config,
-                           const TensorShape& shape, DType dtype) {
+                           const TensorShape& shape, DType dtype, std::string backend) {
     auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
-            key, graph, group_client, config, shape, dtype));
+            key, graph, group_client, config, shape, dtype, backend));
     return opr->output(0);
 }
 SymbolVar RemoteRecv::make(const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
                            std::shared_ptr<GroupClient> group_client,
                            const OperatorNodeConfig& config,
-                           const TensorShape& shape, DType dtype) {
+                           const TensorShape& shape, DType dtype, std::string backend) {
     auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
-            key, var.node(), graph, group_client, config, shape, dtype));
+            key, var.node(), graph, group_client, config, shape, dtype, backend));
     return opr->output(0);
 }
@@ -201,7 +202,7 @@ void RemoteRecv::scn_do_execute() {
     }
     m_megray_comm = MegRayCommBuilder::get_megray_comm(
-            reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);
+            reg_info.hash, m_key, 2, 1, get_megray_backend(m_backend), m_group_client);
     m_megray_ctx = get_megray_context(output(0)->comp_node());
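With this, both `RemoteSend::scn_do_execute` and `RemoteRecv::scn_do_execute` resolve the MegRay backend from the stored string instead of hard-coding `MegRay::MEGRAY_NCCL`. The `get_megray_backend` helper itself is not shown in this diff; roughly, the mapping it presumably performs, rendered in Python (`MEGRAY_UCX` is an assumption, mirrored from the "nccl or ucx" docstrings below):

    # rough stand-in for the C++ helper; only "nccl" is confirmed by the diff
    _MEGRAY_BACKEND = {
        "nccl": "MegRay::MEGRAY_NCCL",
        "ucx": "MegRay::MEGRAY_UCX",  # assumed ucx counterpart
    }

    def get_megray_backend(name: str) -> str:
        try:
            return _MEGRAY_BACKEND[name]
        except KeyError:
            raise ValueError(f"unsupported collective communication backend: {name!r}")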
@@ -251,7 +252,7 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_send(
     mgb_assert(inputs.size() == 1);
     auto&& opr = opr_.cast_final_safe<RemoteSend>();
     return RemoteSend::make(opr.key(), inputs[0], opr.group_client(),
-                            opr.is_grad(), config)
+                            opr.is_grad(), opr.backend(), config)
             .node()
             ->owner_opr();
 }
@@ -265,14 +266,14 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
     if (inputs.size() == 1) {
         return RemoteRecv::make(opr.key(), inputs[0], *opr.owner_graph(),
                                 opr.group_client(), config, opr.shape(),
-                                opr.dtype())
+                                opr.dtype(), opr.backend())
                 .node()
                 ->owner_opr();
     } else {
         mgb_assert(inputs.size() == 0, "recv should have 1 or 0 input");
         return RemoteRecv::make(opr.key(), *opr.owner_graph(),
                                 opr.group_client(), config, opr.shape(),
-                                opr.dtype())
+                                opr.dtype(), opr.backend())
                 .node()
                 ->owner_opr();
     }
@@ -9,6 +9,8 @@ decl_raw_opr(
         Doc('key', 'key to bind send-recv pair', 'str'),
         Doc('var', 'variable to be sent', ':class:`.SymbolVar`'),
         Doc('is_grad', 'whether the send', 'bool'),
+        Doc('backend', 'Backend for collective communication, nccl or ucx',
+            'str', '\'nccl\''),
     ]
 )
@@ -24,7 +26,9 @@ decl_raw_opr(
             ':class:`.CompGraph`'),
         Doc('shape', 'output var shape'),
         Doc('dtype', 'data type of the output var; must match dtype at sender',
-            ':class:`numpy.dtype` compatible')
+            ':class:`numpy.dtype` compatible'),
+        Doc('backend', 'Backend for collective communication, nccl or ucx',
+            'str', '\'nccl\''),
     ]
 )
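Because the raw-opr declarations give `backend` a default of `'nccl'`, existing call sites of the old graph API stay source-compatible. A hypothetical pair of call sites (module path and wrapper names are illustrative, not taken from this diff):

    # existing code keeps working: backend falls back to 'nccl'
    xr = opr.remote_send("pair0", x, is_grad=False)
    # new code can opt in to ucx explicitly
    xu = opr.remote_send("pair1", x, is_grad=False, backend="ucx")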
@@ -48,17 +48,19 @@ MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
     public:
         RemoteSend(const std::string& key, VarNode* var,
                    std::shared_ptr<GroupClient> group_client,
-                   bool is_grad, const OperatorNodeConfig& config);
+                   bool is_grad, std::string backend, const OperatorNodeConfig& config);
         static SymbolVar make(
                 const std::string& key, SymbolVar var,
                 std::shared_ptr<GroupClient> group_client,
-                bool is_grad, const OperatorNodeConfig& config = {});
+                bool is_grad, std::string backend, const OperatorNodeConfig& config = {});
+        const std::string& backend() const { return m_backend; }
         bool is_grad() const { return m_is_grad; }
     private:
         HostTensorND m_output_val;
+        std::string m_backend;
         bool m_is_grad;
         void scn_do_execute() override;
@@ -75,31 +77,33 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
         RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                    std::shared_ptr<GroupClient> group_client,
                    const OperatorNodeConfig& config, const TensorShape& shape,
-                   DType dtype);
+                   DType dtype, std::string backend);
         RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
                    std::shared_ptr<GroupClient> group_client,
                    const OperatorNodeConfig& config, const TensorShape& shape,
-                   DType dtype);
+                   DType dtype, std::string backend);
         static SymbolVar make(
                 const std::string& key, cg::ComputingGraph& graph,
                 std::shared_ptr<GroupClient> group_client,
                 const OperatorNodeConfig& config, const TensorShape& shape,
-                DType dtype);
+                DType dtype, std::string backend);
         static SymbolVar make(
                 const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
                 std::shared_ptr<GroupClient> group_client,
                 const OperatorNodeConfig& config, const TensorShape& shape,
-                DType dtype);
+                DType dtype, std::string backend);
         const TensorShape& shape() const { return m_shape; }
         const DType& dtype() const { return m_dtype; }
+        const std::string& backend() const { return m_backend; }
     private:
         const TensorShape m_shape;
         const DType m_dtype;
+        const std::string m_backend;
         const CompNode m_comp_node;
         DeviceTensorND m_dev_buffer;
@@ -33,10 +33,10 @@ TEST(TestOprIORemote, Identity) {
     auto graph = ComputingGraph::make();
     auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
-    auto xr = opr::RemoteSend::make("x", x, client, false);
+    auto xr = opr::RemoteSend::make("x", x, client, false, "nccl");
     auto y = opr::RemoteRecv::make("x", *graph.get(),
                                    client, {cn1}, host_x->shape(),
-                                   host_x->dtype());
+                                   host_x->dtype(), "nccl");
     auto func = graph->compile({{xr, {}}, make_callback_copy(y, host_y)});
@@ -57,7 +57,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
         auto graph = ComputingGraph::make();
         sys::set_thread_name("sender");
         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
-             xr = opr::RemoteSend::make("x", x, client, false);
+             xr = opr::RemoteSend::make("x", x, client, false, "nccl");
         auto func = graph->compile({{xr, {}}});
         func->execute();
     };
@@ -67,7 +67,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
         auto graph = ComputingGraph::make();
         auto x = opr::RemoteRecv::make("x", *graph.get(),
                                        client, {cns[0]}, host_x->shape(),
-                                       host_x->dtype());
+                                       host_x->dtype(), "nccl");
         auto func = graph->compile({make_callback_copy(x, host_x_get)});
         func->execute();
     };
@@ -91,7 +91,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
         sys::set_thread_name("sender");
         auto graph = ComputingGraph::make();
         auto x = opr::Host2DeviceCopy::make(*graph, host_x) * 2 + 1,
-             xr = opr::RemoteSend::make("x", x, client, false);
+             xr = opr::RemoteSend::make("x", x, client, false, "nccl");
         auto func = graph->compile({{xr, {}}});
         func->execute();
     };
@@ -101,7 +101,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
         auto graph = ComputingGraph::make();
         auto x = opr::RemoteRecv::make("x", *graph.get(),
                                        client, {cns[0]}, host_x->shape(),
-                                       host_x->dtype());
+                                       host_x->dtype(), "nccl");
         auto func =
                 graph->compile({make_callback_copy((x - 1) / 2, host_x_get)});
         func->execute();
@@ -126,12 +126,12 @@ TEST(TestOprIORemote, APlusB) {
         auto graph = ComputingGraph::make();
         auto z = opr::RemoteRecv::make("z", *graph.get(),
                                        client, {cns[0]}, host_x->shape(),
-                                       host_x->dtype());
+                                       host_x->dtype(), "nccl");
         auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"),
              y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"),
-             xr = opr::RemoteSend::make("x", x, client, false)
+             xr = opr::RemoteSend::make("x", x, client, false, "nccl")
                           .rename("xr"),
-             yr = opr::RemoteSend::make("y", y, client, false)
+             yr = opr::RemoteSend::make("y", y, client, false, "nccl")
                           .rename("yr");
         auto func = graph->compile(
                 {{xr, {}}, {yr, {}}, make_callback_copy(z, host_z)});
@@ -144,12 +144,12 @@ TEST(TestOprIORemote, APlusB) {
         auto graph = ComputingGraph::make();
         auto x = opr::RemoteRecv::make("x", *graph.get(),
                                        client, {cns[1]}, host_x->shape(),
-                                       host_x->dtype()),
+                                       host_x->dtype(), "nccl"),
             y = opr::RemoteRecv::make("y", *graph.get(),
                                        client, {cns[1]}, host_y->shape(),
-                                       host_y->dtype()),
+                                       host_y->dtype(), "nccl"),
             z = x + y,
-             zr = opr::RemoteSend::make("z", z, client, false);
+             zr = opr::RemoteSend::make("z", z, client, false, "nccl");
         auto func = graph->compile({{zr, {}}});
         func->execute();
     };
@@ -178,10 +178,10 @@ TEST(TestOprIORemote, SendGrad) {
         sys::set_thread_name("sender");
         auto graph = ComputingGraph::make();
         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
-             loss = opr::RemoteSend::make("loss", x, client, false);
+             loss = opr::RemoteSend::make("loss", x, client, false, "nccl");
         ASSERT_TRUE(!loss.shape().ndim &&
                     loss.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
-        loss = opr::RemoteSend::make("loss", x, client, true);
+        loss = opr::RemoteSend::make("loss", x, client, true, "nccl");
         auto gx = cg::grad(loss, x);
         set_priority(loss, 0);
         set_priority(gx, 1);
@@ -200,8 +200,8 @@ TEST(TestOprIORemote, SendGrad) {
         auto graph = ComputingGraph::make();
         auto x = opr::RemoteRecv::make("loss", *graph.get(),
                                        client, {cns[1]}, host_x->shape(),
-                                       host_x->dtype());
-        auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false);
+                                       host_x->dtype(), "nccl");
+        auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false, "nccl");
        auto func = graph->compile({{y, {}}});
         func->execute();
     };