GitOrigin-RevId: 841a0e45ab
release-1.4
@@ -265,6 +265,7 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
op.key = key
op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank
+op.backend = get_backend()
(dummy,) = apply(_RemoteSend(op), inp)
for g in grad_keys.values():
@@ -313,6 +314,7 @@ def remote_recv(
op.dtype = dtype
op.addr, op.port = get_mm_server_addr()
op.rank_from = src_rank
+op.backend = get_backend()
(ret,) = apply(_RemoteRecv(op), inp)
if _isscalar:
@@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
OperatorNodeConfig config{send.make_name()};
cg::OperatorNodeBase* opr =
graph->insert_opr(std::make_unique<mgb::opr::RemoteSend>(
-send.key, inputs[0], group_client, true, config));
+send.key, inputs[0], group_client, true, send.backend, config));
return opr;
}
@@ -49,7 +49,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
auto&& graph = inputs[0]->owner_graph();
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
recv.key, inputs[0], *graph, group_client, config,
-recv.shape, recv.dtype));
+recv.shape, recv.dtype, recv.backend));
}
OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend)
@@ -34,7 +34,7 @@ TEST(TestImperative, IORemote) {
auto run_send = [&](std::shared_ptr<HostTensorND> hnd) {
auto def = imperative::RemoteSend::make(
-"io_remote_test", server_addr, port, 1);
+"io_remote_test", server_addr, port, 1, "nccl");
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
};
@@ -43,7 +43,7 @@ TEST(TestImperative, IORemote) {
auto def = imperative::RemoteRecv::make(
"io_remote_test", server_addr, port, 0,
CompNode::load("gpu1"), TensorShape{vector_size},
-dtype::Float32());
+dtype::Float32(), "nccl");
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
HostTensorND host_v;
@@ -169,7 +169,8 @@ def RemoteSend : MgbHashableOp<"RemoteSend"> {
MgbStringAttr:$key,
MgbStringAttr:$addr,
MgbUI32Attr:$port,
-MgbUI32Attr:$rank_to
+MgbUI32Attr:$rank_to,
+MgbStringAttr:$backend
);
}
@@ -181,7 +182,8 @@ def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
MgbUI32Attr:$rank_from,
MgbCompNodeAttr:$cn,
MgbTensorShapeAttr:$shape,
-MgbDTypeAttr:$dtype
+MgbDTypeAttr:$dtype,
+MgbStringAttr:$backend
);
}
@@ -24,8 +24,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
RemoteSend::RemoteSend(const std::string& key, VarNode* var,
std::shared_ptr<GroupClient> group_client,
-bool is_grad, const OperatorNodeConfig& config) :
+bool is_grad, std::string backend, const OperatorNodeConfig& config) :
Super(var->owner_graph(), config, "remote_send", {var}),
+m_backend(backend),
m_is_grad(is_grad) {
m_key = key;
m_group_client = group_client;
@@ -41,9 +42,9 @@ RemoteSend::RemoteSend(const std::string& key, VarNode* var,
SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
std::shared_ptr<GroupClient> group_client,
-bool is_grad, const OperatorNodeConfig& config) {
+bool is_grad, std::string backend, const OperatorNodeConfig& config) {
return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client,
-is_grad, config);
+is_grad, backend, config);
}
void RemoteSend::scn_do_execute() {
@@ -64,7 +65,7 @@ void RemoteSend::scn_do_execute() {
}
m_megray_comm = MegRayCommBuilder::get_megray_comm(
-reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
+reg_info.hash, m_key, 2, 0, get_megray_backend(m_backend), m_group_client);
m_megray_ctx = get_megray_context(output(0)->comp_node());
@@ -122,7 +123,7 @@ MGB_IMPL_OPR_GRAD(RemoteSend) {
*opr.owner_graph(), opr.group_client(),
OperatorNodeConfig{opr.comp_node()}.name(
opr.name() + ":grad_recv"),
-opr.input(0)->shape(), opr.input(0)->dtype())
+opr.input(0)->shape(), opr.input(0)->dtype(), opr.backend())
.node();
}
#endif
@@ -134,9 +135,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);
RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config,
-const TensorShape& shape, DType dtype) :
+const TensorShape& shape, DType dtype, std::string backend) :
Super(&graph, config, "remote_recv", {}),
-m_shape(shape), m_dtype(dtype) {
+m_shape(shape), m_dtype(dtype), m_backend(backend) {
m_key = key;
m_group_client = group_client;
@@ -150,9 +151,9 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config,
-const TensorShape& shape, DType dtype) :
+const TensorShape& shape, DType dtype, std::string backend) :
Super(&graph, config, "remote_recv", {}),
-m_shape(shape), m_dtype(dtype) {
+m_shape(shape), m_dtype(dtype), m_backend(backend) {
m_key = key;
m_group_client = group_client;
@@ -167,18 +168,18 @@ RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph&
SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config,
-const TensorShape& shape, DType dtype) {
+const TensorShape& shape, DType dtype, std::string backend) {
auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
-key, graph, group_client, config, shape, dtype));
+key, graph, group_client, config, shape, dtype, backend));
return opr->output(0);
}
SymbolVar RemoteRecv::make(const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config,
-const TensorShape& shape, DType dtype) {
+const TensorShape& shape, DType dtype, std::string backend) {
auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
-key, var.node(), graph, group_client, config, shape, dtype));
+key, var.node(), graph, group_client, config, shape, dtype, backend));
return opr->output(0);
}
@@ -201,7 +202,7 @@ void RemoteRecv::scn_do_execute() {
}
m_megray_comm = MegRayCommBuilder::get_megray_comm(
-reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);
+reg_info.hash, m_key, 2, 1, get_megray_backend(m_backend), m_group_client);
m_megray_ctx = get_megray_context(output(0)->comp_node());
@@ -251,7 +252,7 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_send(
mgb_assert(inputs.size() == 1);
auto&& opr = opr_.cast_final_safe<RemoteSend>();
return RemoteSend::make(opr.key(), inputs[0], opr.group_client(),
-opr.is_grad(), config)
+opr.is_grad(), opr.backend(), config)
.node()
->owner_opr();
}
@@ -265,14 +266,14 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
if (inputs.size() == 1) {
return RemoteRecv::make(opr.key(), inputs[0], *opr.owner_graph(),
opr.group_client(), config, opr.shape(),
-opr.dtype())
+opr.dtype(), opr.backend())
.node()
->owner_opr();
} else {
mgb_assert(inputs.size() == 0, "recv should have 1 or 0 input");
return RemoteRecv::make(opr.key(), *opr.owner_graph(),
opr.group_client(), config, opr.shape(),
-opr.dtype())
+opr.dtype(), opr.backend())
.node()
->owner_opr();
}
@@ -9,6 +9,8 @@ decl_raw_opr(
Doc('key', 'key to bind send-recv pair', 'str'),
Doc('var', 'variable to be sent', ':class:`.SymbolVar`'),
Doc('is_grad', 'whether the send', 'bool'),
+Doc('backend', 'Backend for collective communication, nccl or ucx',
+'str', '\'nccl\''),
]
)
@@ -24,7 +26,9 @@ decl_raw_opr(
':class:`.CompGraph`'),
Doc('shape', 'output var shape'),
Doc('dtype', 'data type of the output var; must match dtype at sender',
-':class:`numpy.dtype` compatible')
+':class:`numpy.dtype` compatible'),
+Doc('backend', 'Backend for collective communication, nccl or ucx',
+'str', '\'nccl\''),
]
)
@@ -48,17 +48,19 @@ MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
public:
RemoteSend(const std::string& key, VarNode* var,
std::shared_ptr<GroupClient> group_client,
-bool is_grad, const OperatorNodeConfig& config);
+bool is_grad, std::string backend, const OperatorNodeConfig& config);
static SymbolVar make(
const std::string& key, SymbolVar var,
std::shared_ptr<GroupClient> group_client,
-bool is_grad, const OperatorNodeConfig& config = {});
+bool is_grad, std::string backend, const OperatorNodeConfig& config = {});
+const std::string& backend() const { return m_backend; }
bool is_grad() const { return m_is_grad; }
private:
HostTensorND m_output_val;
+std::string m_backend;
bool m_is_grad;
void scn_do_execute() override;
@@ -75,31 +77,33 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
-DType dtype);
+DType dtype, std::string backend);
RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
-DType dtype);
+DType dtype, std::string backend);
static SymbolVar make(
const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
-DType dtype);
+DType dtype, std::string backend);
static SymbolVar make(
const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
-DType dtype);
+DType dtype, std::string backend);
const TensorShape& shape() const { return m_shape; }
const DType& dtype() const { return m_dtype; }
+const std::string& backend() const { return m_backend; }
private:
const TensorShape m_shape;
const DType m_dtype;
+const std::string m_backend;
const CompNode m_comp_node;
DeviceTensorND m_dev_buffer;
@@ -33,10 +33,10 @@ TEST(TestOprIORemote, Identity) {
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
-auto xr = opr::RemoteSend::make("x", x, client, false);
+auto xr = opr::RemoteSend::make("x", x, client, false, "nccl");
auto y = opr::RemoteRecv::make("x", *graph.get(),
client, {cn1}, host_x->shape(),
-host_x->dtype());
+host_x->dtype(), "nccl");
auto func = graph->compile({{xr, {}}, make_callback_copy(y, host_y)});
@@ -57,7 +57,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
auto graph = ComputingGraph::make();
sys::set_thread_name("sender");
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
-xr = opr::RemoteSend::make("x", x, client, false);
+xr = opr::RemoteSend::make("x", x, client, false, "nccl");
auto func = graph->compile({{xr, {}}});
func->execute();
};
@@ -67,7 +67,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
auto graph = ComputingGraph::make();
auto x = opr::RemoteRecv::make("x", *graph.get(),
client, {cns[0]}, host_x->shape(),
-host_x->dtype());
+host_x->dtype(), "nccl");
auto func = graph->compile({make_callback_copy(x, host_x_get)});
func->execute();
};
@@ -91,7 +91,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
sys::set_thread_name("sender");
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x) * 2 + 1,
-xr = opr::RemoteSend::make("x", x, client, false);
+xr = opr::RemoteSend::make("x", x, client, false, "nccl");
auto func = graph->compile({{xr, {}}});
func->execute();
};
@@ -101,7 +101,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
auto graph = ComputingGraph::make();
auto x = opr::RemoteRecv::make("x", *graph.get(),
client, {cns[0]}, host_x->shape(),
-host_x->dtype());
+host_x->dtype(), "nccl");
auto func =
graph->compile({make_callback_copy((x - 1) / 2, host_x_get)});
func->execute();
@@ -126,12 +126,12 @@ TEST(TestOprIORemote, APlusB) {
auto graph = ComputingGraph::make();
auto z = opr::RemoteRecv::make("z", *graph.get(),
client, {cns[0]}, host_x->shape(),
-host_x->dtype());
+host_x->dtype(), "nccl");
auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"),
y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"),
-xr = opr::RemoteSend::make("x", x, client, false)
+xr = opr::RemoteSend::make("x", x, client, false, "nccl")
.rename("xr"),
-yr = opr::RemoteSend::make("y", y, client, false)
+yr = opr::RemoteSend::make("y", y, client, false, "nccl")
.rename("yr");
auto func = graph->compile(
{{xr, {}}, {yr, {}}, make_callback_copy(z, host_z)});
@@ -144,12 +144,12 @@ TEST(TestOprIORemote, APlusB) {
auto graph = ComputingGraph::make();
auto x = opr::RemoteRecv::make("x", *graph.get(),
client, {cns[1]}, host_x->shape(),
-host_x->dtype()),
+host_x->dtype(), "nccl"),
y = opr::RemoteRecv::make("y", *graph.get(),
client, {cns[1]}, host_y->shape(),
-host_y->dtype()),
+host_y->dtype(), "nccl"),
z = x + y,
-zr = opr::RemoteSend::make("z", z, client, false);
+zr = opr::RemoteSend::make("z", z, client, false, "nccl");
auto func = graph->compile({{zr, {}}});
func->execute();
};
@@ -178,10 +178,10 @@ TEST(TestOprIORemote, SendGrad) {
sys::set_thread_name("sender");
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
-loss = opr::RemoteSend::make("loss", x, client, false);
+loss = opr::RemoteSend::make("loss", x, client, false, "nccl");
ASSERT_TRUE(!loss.shape().ndim &&
loss.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
-loss = opr::RemoteSend::make("loss", x, client, true);
+loss = opr::RemoteSend::make("loss", x, client, true, "nccl");
auto gx = cg::grad(loss, x);
set_priority(loss, 0);
set_priority(gx, 1);
@@ -200,8 +200,8 @@ TEST(TestOprIORemote, SendGrad) {
auto graph = ComputingGraph::make();
auto x = opr::RemoteRecv::make("loss", *graph.get(),
client, {cns[1]}, host_x->shape(),
-host_x->dtype());
-auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false);
+host_x->dtype(), "nccl");
+auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false, "nccl");
auto func = graph->compile({{y, {}}});
func->execute();
};
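
Usage note: the patch threads a collective-communication backend string ("nccl" or "ucx") from the Python ops down to the RemoteSend/RemoteRecv graph operators, which previously hard-coded MegRay::MEGRAY_NCCL when building their MegRay communicator. Below is a minimal, hedged sketch of how this would typically be exercised from Python; it assumes the MegEngine 1.x distributed API, and the init_process_group keyword names, its backend argument, and the exact remote_recv signature are assumptions, not taken from this diff.

# Hedged sketch, assuming the MegEngine 1.x distributed API; the
# init_process_group keywords and remote_recv signature are assumptions.
import multiprocessing as mp

import megengine as mge
import megengine.distributed as dist
from megengine.distributed.functional import remote_recv, remote_send


def worker(rank, world_size, port):
    # The backend chosen here is what get_backend() returns inside
    # remote_send/remote_recv after this change ("nccl" or "ucx").
    dist.init_process_group(
        master_ip="localhost", port=port, world_size=world_size,
        rank=rank, device=rank, backend="nccl",
    )
    if rank == 0:
        x = mge.tensor([1.0, 2.0, 3.0])
        remote_send(x, dest_rank=1)   # op.backend = get_backend()
    else:
        y = remote_recv(src_rank=0, shape=(3,), dtype="float32")
        print(y.numpy())


if __name__ == "__main__":
    procs = [mp.Process(target=worker, args=(r, 2, 23333)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()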