diff --git a/python_module/megengine/distributed/functional.py b/python_module/megengine/distributed/functional.py
index ee404d12..5a2e85f6 100644
--- a/python_module/megengine/distributed/functional.py
+++ b/python_module/megengine/distributed/functional.py
@@ -40,6 +40,32 @@ def reduce_sum(
     )
 
 
+def gather(
+    tensor: Tensor,
+    key: str,
+    nr_ranks: Optional[int] = None,
+    is_root: Optional[bool] = None,
+    rank: Optional[int] = None,
+) -> Tensor:
+    """Create gather operator for collective communication
+
+    :param tensor: input tensor
+    :param key: unique identifier for collective communication
+    :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param is_root: whether this is a root node
+    :param rank: rank of this node
+    """
+    return _collective_comm(
+        tensor,
+        key,
+        CollParam.Mode.GATHER,
+        nr_ranks,
+        is_root,
+        rank,
+        device=tensor.device,
+    )
+
+
 def broadcast(
     tensor: Tensor,
     key: str,
@@ -74,6 +100,56 @@ def broadcast(
     )
 
 
+def scatter(
+    tensor: Tensor,
+    key: str,
+    nr_ranks: Optional[int] = None,
+    is_root: Optional[bool] = None,
+    rank: Optional[int] = None,
+) -> Tensor:
+    """Create scatter operator for collective communication
+
+    :param tensor: input tensor
+    :param key: unique identifier for collective communication
+    :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param is_root: whether this is a root node
+    :param rank: rank of this node
+    """
+    if key is None:
+        key = tensor._symvar.name
+    if is_root is None:
+        is_root = get_rank() == 0
+
+    if is_root:
+        inp = tensor
+    else:
+        inp = tensor._symvar.owner_graph
+
+    return _collective_comm(
+        inp,
+        key,
+        CollParam.Mode.SCATTER,
+        nr_ranks,
+        is_root,
+        rank,
+        dtype=tensor.dtype,
+        device=tensor.device,
+    )
+
+
+def all_to_all(
+    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
+) -> Tensor:
+    """Create all_to_all operator for collective communication
+
+    :param tensor: input tensor
+    :param key: unique identifier for collective communication
+    :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param rank: rank of this node
+    """
+    return _collective_comm(tensor, key, CollParam.Mode.ALL_TO_ALL, nr_ranks, rank=rank)
+
+
 def all_gather(
     tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
 ) -> Tensor:
diff --git a/python_module/test/unit/distributed/test_functional.py b/python_module/test/unit/distributed/test_functional.py
index d2bee6bd..a70484ba 100644
--- a/python_module/test/unit/distributed/test_functional.py
+++ b/python_module/test/unit/distributed/test_functional.py
@@ -62,6 +62,42 @@ def test_reduce_sum():
 
 
 @pytest.mark.isolated_distributed
+def test_gather():
+    world_size = 2
+
+    def worker(rank, data, backend, expect, port_queue):
+        if not mge.is_cuda_available():
+            return
+        _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
+        inp = tensor(data)
+        output = dist.functional.gather(inp, "x", is_root=(rank == 0), rank=rank)
+        if rank == 0:
+            assert np.allclose(output.numpy(), expect)
+        else:
+            assert np.allclose(output.numpy(), 0)
+
+    def check(shape, backend):
+        port_queue = mp.Queue()
+        x = np.random.rand(*shape).astype("float32")
+        y = np.random.rand(*shape).astype("float32")
+        z = np.concatenate((x, y))
+        p0 = mp.Process(target=worker, args=(0, x, backend, z, port_queue))
+        p1 = mp.Process(target=worker, args=(1, y, backend, None, port_queue))
+
+        p0.start()
+        p1.start()
+
+        p0.join(10)
+        p1.join(10)
+
+        assert p0.exitcode == 0 and p1.exitcode == 0
+
+    for shape in [(2, 3), (8, 10), (99, 77)]:
+        for backend in ["nccl", "ucx"]:
+            check(shape, backend)
+
+
+@pytest.mark.isolated_distributed
 def test_broadcast():
     world_size = 2
 
@@ -94,6 +130,76 @@ def test_broadcast():
 
 
 @pytest.mark.isolated_distributed
+def test_scatter():
+    world_size = 2
+
+    def worker(rank, data, backend, expect, port_queue):
+        if not mge.is_cuda_available():
+            return
+        _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
+        inp = tensor(data)
+        output = dist.functional.scatter(inp, "x", is_root=(rank == 0), rank=rank)
+        assert np.allclose(output.numpy(), expect)
+
+    def check(shape, backend):
+        port_queue = mp.Queue()
+        x = np.random.rand(*shape).astype("float32")
+        y = x + 1
+        p0 = mp.Process(
+            target=worker, args=(0, x, backend, x[: shape[0] // 2], port_queue)
+        )
+        p1 = mp.Process(
+            target=worker, args=(1, y, backend, x[shape[0] // 2 :], port_queue)
+        )
+
+        p0.start()
+        p1.start()
+
+        p0.join(10)
+        p1.join(10)
+
+        assert p0.exitcode == 0 and p1.exitcode == 0
+
+    for shape in [(2, 3), (8, 10), (100, 77)]:
+        for backend in ["nccl", "ucx"]:
+            check(shape, backend)
+
+
+@pytest.mark.isolated_distributed
+def test_all_to_all():
+    world_size = 2
+
+    def worker(rank, data, backend, expect, port_queue):
+        if not mge.is_cuda_available():
+            return
+        _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
+        inp = tensor(data)
+        output = dist.functional.all_to_all(inp, "x", rank=rank)
+        assert np.allclose(output.numpy(), expect)
+
+    def check(shape, backend):
+        port_queue = mp.Queue()
+        x = np.random.rand(*shape).astype("float32")
+        y = np.random.rand(*shape).astype("float32")
+        a = np.concatenate((x[: shape[0] // 2], y[: shape[0] // 2]))
+        b = np.concatenate((x[shape[0] // 2 :], y[shape[0] // 2 :]))
+        p0 = mp.Process(target=worker, args=(0, x, backend, a, port_queue))
+        p1 = mp.Process(target=worker, args=(1, y, backend, b, port_queue))
+
+        p0.start()
+        p1.start()
+
+        p0.join(10)
+        p1.join(10)
+
+        assert p0.exitcode == 0 and p1.exitcode == 0
+
+    for shape in [(2, 3), (8, 10), (100, 77)]:
+        for backend in ["nccl", "ucx"]:
+            check(shape, backend)
+
+
+@pytest.mark.isolated_distributed
 def test_all_gather():
     world_size = 2
diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp
index d4db9f56..c04c813a 100644
--- a/src/opr-mm/impl/collective_comm.cpp
+++ b/src/opr-mm/impl/collective_comm.cpp
@@ -25,9 +25,10 @@ using namespace opr;
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm);
 
-#define FOREACH_MODE(cb) \
-    cb(ALL_REDUCE_SUM) cb(ALL_REDUCE_MAX) cb(ALL_REDUCE_MIN) cb(BROADCAST) \
-            cb(REDUCE_SUM) cb(ALL_GATHER) cb(REDUCE_SCATTER_SUM)
+#define FOREACH_MODE(cb)                                                    \
+    cb(ALL_REDUCE_SUM) cb(ALL_REDUCE_MAX) cb(ALL_REDUCE_MIN) cb(BROADCAST)  \
+            cb(REDUCE_SUM) cb(ALL_GATHER) cb(REDUCE_SCATTER_SUM) cb(GATHER) \
+            cb(SCATTER) cb(ALL_TO_ALL)
 
 namespace {
@@ -84,6 +85,9 @@ class CollectiveComm::ModeTrait {
     class ALL_REDUCE_SUM;
     class ALL_REDUCE_MAX;
     class ALL_REDUCE_MIN;
+    class GATHER;
+    class SCATTER;
+    class ALL_TO_ALL;
 
     class ReducedBasedTrait;
     class AllReduceBase;
@@ -350,6 +354,102 @@ class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait {
     Mode grad_mode() override { return Mode::REDUCE_SUM; }
 };
 
+class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
+    void add_output_var(CollectiveComm* opr,
+                        const CompNode::UnorderedSet&) override {
+        add_output_var_all2all(opr);
+    }
+
+    void get_output_var_shape(const CollectiveComm* opr,
+                              const TensorShapeArray& ishp,
+                              TensorShapeArray& oshp) override {
+        chk_shape_equal(ishp);
+        if (opr->is_root()) {
+            oshp[0] = ishp[0];
+            oshp[0][0] *= opr->nr_devices();
+        } else {
+            oshp[0] = TensorShape{1};
+        }
+    }
+
+    void exec(CollectiveComm* opr) override {
+        auto&& iv = opr->input(0)->dev_tensor();
+        void* recvbuf = nullptr;
+        if (opr->is_root()) {
+            recvbuf = opr->output(0)->dev_tensor().raw_ptr();
+        }
+        auto status = opr->m_megray_comm->gather(
+                (void*)iv.raw_ptr(), recvbuf, iv.shape().total_nr_elems(),
+                get_megray_dtype(iv.dtype()), opr->m_root, opr->megray_ctx());
+        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay gather failed");
+    }
+
+    Mode grad_mode() override { return Mode::SCATTER; }
+};
+
+class CollectiveComm::ModeTrait::SCATTER : public ModeTrait {
+    void add_output_var(CollectiveComm* opr,
+                        const CompNode::UnorderedSet&) override {
+        if (opr->input().size() > 0) {
+            add_output_var_all2all(opr);
+            return;
+        }
+
+        const auto& cns = opr->config().comp_node();
+        mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu", cns.size());
+        auto pname = get_param_name(opr->param());
+        opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))->comp_node(cns[0]);
+    }
+
+    void get_output_var_shape(const CollectiveComm* opr,
+                              const TensorShapeArray& ishp,
+                              TensorShapeArray& oshp) override {
+        MGB_MARK_USED_VAR(opr);
+        MGB_MARK_USED_VAR(ishp);
+        MGB_MARK_USED_VAR(oshp);
+        mgb_throw(MegBrainError, "SCATTER should not use get_output_var_shape");
+    }
+
+    void exec(CollectiveComm* opr) override {
+        auto&& ov = opr->output(0)->dev_tensor();
+        void* sendbuf = nullptr;
+        void* recvbuf = ov.raw_ptr();
+        if (opr->is_root()) {
+            sendbuf = opr->input(0)->dev_tensor().raw_ptr();
+        }
+        auto status = opr->m_megray_comm->scatter(
+                sendbuf, recvbuf, ov.shape().total_nr_elems(),
+                get_megray_dtype(ov.dtype()), opr->m_root, opr->megray_ctx());
+        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay scatter failed");
+    }
+
+    Mode grad_mode() override { return Mode::GATHER; }
+};
+
+class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait {
+    void add_output_var(CollectiveComm* opr,
+                        const CompNode::UnorderedSet&) override {
+        add_output_var_all2all(opr);
+    }
+
+    void get_output_var_shape(const CollectiveComm* opr,
+                              const TensorShapeArray& ishp,
+                              TensorShapeArray& oshp) override {
+        MGB_MARK_USED_VAR(opr);
+        chk_shape_equal(ishp);
+        oshp = ishp;
+    }
+
+    void exec(CollectiveComm* opr) override {
+        auto&& iv = opr->input(0)->dev_tensor();
+        auto&& ov = opr->output(0)->dev_tensor();
+        auto status = opr->m_megray_comm->all_to_all(
+                (void*)iv.raw_ptr(), (void*)ov.raw_ptr(),
+                iv.shape().total_nr_elems() / opr->nr_devices(),
+                get_megray_dtype(iv.dtype()), opr->megray_ctx());
+        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay all_to_all failed");
+    }
+
+    Mode grad_mode() override { return Mode::ALL_TO_ALL; }
+};
+
 CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {
     switch (mode) {
 #define c(_m) \
@@ -651,41 +751,20 @@ void CollectiveComm::init_output_dtype() {
 }
 
 void CollectiveComm::init_output_static_infer_desc() {
-    if (m_param.mode == Param::Mode::REDUCE_SUM) {
-        using namespace cg::static_infer;
-        auto&& mgr = owner_graph()->static_infer_manager();
-
-        auto infer_shape_from_input = [](TensorShape& dest, const InpVal& inp_val) {
-            dest = inp_val.val[0].shape();
-            return true;
-        };
-
-        auto infer_shape_constant = [](TensorShape& dest, const InpVal&) {
-            dest = TensorShape{1};
-            return true;
-        };
-
-        mgb_assert(input().size() == 1);
-        mgb_assert(output().size() == 1);
-
-        if (is_root()) {
-            mgr.register_shape_infer(output(0),
-                {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape_from_input});
-        }
-        else {
-            mgr.register_shape_infer(output(0),
-                {SourceType::CONSTANT, {}, infer_shape_constant});
-        }
-
-    } else if (m_param.mode == Param::Mode::BROADCAST) {
+    if (m_param.mode == Param::Mode::BROADCAST ||
+        m_param.mode == Param::Mode::SCATTER) {
         using namespace cg::static_infer;
         auto&& mgr = owner_graph()->static_infer_manager();
 
         auto infer_shape_from_input = [this](TensorShape& dest, const InpVal& inp_val) {
-            if (!m_broadcast_output_shape.valid()) {
-                m_broadcast_output_shape = inp_val.val[0].shape();
-                m_group_client->set_output_shape(m_key, m_broadcast_output_shape.val());
-            }
             dest = inp_val.val[0].shape();
+            if (m_param.mode == Param::Mode::SCATTER) {
+                dest[0] /= nr_devices();
+            }
+            if (!m_output_shape.valid()) {
+                m_output_shape = dest;
+                m_group_client->set_output_shape(m_key, dest);
+            }
             return true;
         };
@@ -694,10 +773,11 @@ void CollectiveComm::init_output_static_infer_desc() {
             return false;
         }
 
-        if (!m_broadcast_output_shape.valid()) {
-            m_broadcast_output_shape = m_group_client->get_output_shape(m_key);
+        if (!m_output_shape.valid()) {
+            m_output_shape = m_group_client->get_output_shape(m_key);
         }
-        dest = m_broadcast_output_shape.val();
+
+        dest = m_output_shape.val();
         return true;
     };
diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp
index 5cc58849..a4bb714e 100644
--- a/src/opr-mm/impl/io_remote.cpp
+++ b/src/opr-mm/impl/io_remote.cpp
@@ -18,6 +18,10 @@
 using namespace mgb;
 using namespace opr;
 
+cudaStream_t get_stream(VarNode* var) {
+    return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
+}
+
 /* ===================== RemoteSend ===================== */
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
@@ -35,7 +39,6 @@ RemoteSend::RemoteSend(const PeerDesc& peer, VarNode* var,
         ovar->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                 .add_flag(VarNode::Flag::VOLATILE_CONTENT);
     }
-    m_megray_ctx = MegRay::Context::make();
     add_equivalence_component<ScalarHash<void*>>(this);
 }
 
@@ -56,6 +59,9 @@ void RemoteSend::scn_do_execute() {
         m_megray_comm = MegRayCommBuilder::get_megray_comm(
                 reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX,
                 m_group_client);
+
+        m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
+
         m_init = true;
     }
 
@@ -130,7 +136,6 @@ RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph,
             ->dtype(dtype)
             .add_flag(VarNode::Flag::NO_MEM_RECLAIM)
             .add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC);
-    m_megray_ctx = MegRay::Context::make();
     add_equivalence_component<ScalarHash<void*>>(this);
 }
 
@@ -154,6 +159,9 @@ void RemoteRecv::scn_do_execute() {
         m_megray_comm = MegRayCommBuilder::get_megray_comm(
                 reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX,
                 m_group_client);
+
+        m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
+
         m_init = true;
     }
diff --git a/src/opr-mm/include/megbrain/opr/collective_comm.h b/src/opr-mm/include/megbrain/opr/collective_comm.h
index 9ec9a5b4..791691ca 100644
--- a/src/opr-mm/include/megbrain/opr/collective_comm.h
+++ b/src/opr-mm/include/megbrain/opr/collective_comm.h
@@ -122,8 +122,9 @@ private:
     //! root of BROADCAST and REDUCE operation
     int m_root;
     //! rank of root of BROADCAST and REDUCE operation
-    Maybe<TensorShape> m_broadcast_output_shape = None;
-    // Whether shape infer is enabled. This is only used by BROADCAST operation,
+    Maybe<TensorShape> m_output_shape = None;
+    // Whether shape infer is enabled.
+    // This is only used by BROADCAST and SCATTER operations,
     // whose shape infer should be disabled *during* static infer phase.
     bool m_enable_shape_infer = false;
diff --git a/src/opr-mm/test/collective_comm.cpp b/src/opr-mm/test/collective_comm.cpp
index b24744c1..41747d03 100644
--- a/src/opr-mm/test/collective_comm.cpp
+++ b/src/opr-mm/test/collective_comm.cpp
@@ -719,6 +719,164 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) {
     MGB_ASSERT_TENSOR_EQ(*host_grad, host_out_grad1);
 }
 
+TEST(TestOprCollectiveComm, Gather) {
+    REQUIRE_GPU(2);
+    auto cn0 = CompNode::load("gpu0");
+    auto cn1 = CompNode::load("gpu1");
+
+    HostTensorGenerator<> gen;
+    auto host_x0 = gen({28, 28});
+    auto host_x1 = gen({28, 28});
+    HostTensorND host_y0, host_y1, host_y_expect;
+
+    auto client = std::make_shared<test::MockGroupClient>();
+    auto graph = ComputingGraph::make();
+
+    auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0);
+    auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1, cn0);
+    auto x1c = opr::Copy::make(x1, cn1);
+
+    auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "gather",
+            2, true, 0, client, {Mode::GATHER}, dtype::Float32(), "nccl")[0];
+    auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "gather",
+            2, false, 1, client, {Mode::GATHER}, dtype::Float32(), "nccl")[0];
+    auto y_expect = opr::Concat::make({x0, x1}, 0);
+
+    auto func = graph->compile({make_callback_copy(y0, host_y0),
+                                make_callback_copy(y1, host_y1),
+                                make_callback_copy(y_expect, host_y_expect)});
+    func->execute();
+
+    MGB_ASSERT_TENSOR_EQ(host_y_expect, host_y0);
+}
+
+TEST(TestOprCollectiveComm, GatherMultiThread) {
+    REQUIRE_GPU(2);
+    auto cn0 = CompNode::load("gpu0");
+    auto cn1 = CompNode::load("gpu1");
+
+    HostTensorGenerator<> gen;
+    auto host_x0 = gen({28, 28});
+    auto host_x1 = gen({28, 28});
+    HostTensorND host_y0, host_y_expect;
+
+    auto client = std::make_shared<test::MockGroupClient>();
+
+    auto run_0 = [&]() { // rank 0
+        auto graph0 = ComputingGraph::make();
+        auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
+        auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "gather", 2, true, 0, client,
+                {Mode::GATHER}, dtype::Float32(), "nccl")[0];
+        auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
+        func0->execute();
+    };
+
+    auto run_1 = [&]() { // rank 1
+        auto graph1 = ComputingGraph::make();
+        auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
+        auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "gather", 2, false, 1, client,
+                {Mode::GATHER}, dtype::Float32(), "nccl")[0];
+        auto func1 = graph1->compile({{y1, nullptr}});
+        func1->execute();
+    };
+
+    auto run_2 = [&]() { // check
+        auto graph2 = ComputingGraph::make();
+        auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0);
+        auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0);
+        auto y_expect = opr::Concat::make({x0, x1}, 0);
+        auto func2 = graph2->compile({make_callback_copy(y_expect, host_y_expect)});
+        func2->execute();
+    };
+
+    std::thread t0(run_0);
+    std::thread t1(run_1);
+    std::thread t2(run_2);
+
+    t0.join();
+    t1.join();
+    t2.join();
+
+    MGB_ASSERT_TENSOR_EQ(host_y_expect, host_y0);
+}
+
+TEST(TestOprCollectiveComm, GatherWithGrad) {
+    REQUIRE_GPU(2);
+    auto cn0 = CompNode::load("gpu0");
+    auto cn1 = CompNode::load("gpu1");
+
+    HostTensorGenerator<> gen;
+    TensorShape shape({28, 28});
+    auto host_x0 = gen(shape);
+    auto host_x1 = gen(shape);
+    auto host_grad0 = gen(shape);
+    auto host_grad1 = gen(shape);
+
+    HostTensorND host_y0, host_y0_expect, host_out_grad0, host_out_grad1;
+
+    auto client = std::make_shared<test::MockGroupClient>();
+
+    auto run_0 = [&]() { // rank 0
+        auto graph0 = ComputingGraph::make();
+        graph0->options().graph_opt_level = 0;
+
+        auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
+        auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "gather", 2, true, 0, client,
+                {Mode::GATHER}, dtype::Float32(), "nccl")[0];
+        y0.node()->owner_opr()->node_prop().attribute().priority = -1;
+
+        auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0);
+        auto grad1 = opr::Host2DeviceCopy::make(*graph0, host_grad1, cn0);
+        auto grad = opr::Concat::make({grad0, grad1}, 0);
+        auto loss = opr::Dot::make(y0, grad);
+        auto g = opr::VirtualGrad::make(loss, x0);
+
+        auto func0 = graph0->compile(
+                {make_callback_copy(y0, host_y0),
+                 make_callback_copy(g, host_out_grad0)});
+        func0->execute();
+    };
+
+    auto run_1 = [&]() { // rank 1
+        auto graph1 = ComputingGraph::make();
+        graph1->options().graph_opt_level = 0;
+
+        auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
+        auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "gather", 2, false, 1, client,
+                {Mode::GATHER}, dtype::Float32(), "nccl")[0];
+        y1.node()->owner_opr()->node_prop().attribute().priority = -1;
+
+        auto grad = opr::Host2DeviceCopy::make(*graph1, gen({1}), cn1);
+        auto loss = opr::Dot::make(y1, grad);
+        auto g = opr::VirtualGrad::make(loss, x1);
+
+        auto func1 = graph1->compile({{y1, nullptr}, make_callback_copy(g, host_out_grad1)});
+        func1->execute();
+    };
+
+    auto run_2 = [&]() { // check
+        auto graph2 = ComputingGraph::make();
+        auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0);
+        auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0);
+        auto y0_expect = opr::Concat::make({x0, x1}, 0);
+        auto func2 = graph2->compile({
+                make_callback_copy(y0_expect, host_y0_expect)});
+        func2->execute();
+    };
+
+    std::thread t0(run_0);
+    std::thread t1(run_1);
+    std::thread t2(run_2);
+
+    t0.join();
+    t1.join();
+    t2.join();
+
+    MGB_ASSERT_TENSOR_EQ(host_y0_expect, host_y0);
+    MGB_ASSERT_TENSOR_EQ(*host_grad0, host_out_grad0);
+    MGB_ASSERT_TENSOR_EQ(*host_grad1, host_out_grad1);
+}
+
 TEST(TestOprCollectiveComm, Broadcast) {
     REQUIRE_GPU(2);
     auto cn0 = CompNode::load("gpu0");
@@ -863,3 +1021,349 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) {
     MGB_ASSERT_TENSOR_EQ(*host_x0, host_y1);
     MGB_ASSERT_TENSOR_EQ(host_out_grad_expect, host_out_grad);
 }
+
+TEST(TestOprCollectiveComm, Scatter) {
+    REQUIRE_GPU(2);
+    auto cn0 = CompNode::load("gpu0");
+    auto cn1 = CompNode::load("gpu1");
+
+    HostTensorGenerator<> gen;
+    auto host_x0 = gen({28, 28});
+    auto host_x1 = gen({28, 28});
+    HostTensorND host_y0, host_y1;
+
+    auto client = std::make_shared<test::MockGroupClient>();
+    auto graph = ComputingGraph::make();
+
+    auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0);
+    auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1, cn0);
+    auto x = opr::Concat::make({x0, x1}, 0);
+    auto y0 = opr::CollectiveComm::make({x}, graph.get(), "scatter",
+            2, true, 0, client, {Mode::SCATTER}, dtype::Float32(), "nccl")[0];
+    auto y1 = opr::CollectiveComm::make({}, graph.get(), "scatter", 2, false, 1,
+            client, {Mode::SCATTER}, dtype::Float32(), "nccl", {cn1})[0];
+
+    auto func = graph->compile({make_callback_copy(y0, host_y0),
+                                make_callback_copy(y1, host_y1)});
+    func->execute();
+
+    MGB_ASSERT_TENSOR_EQ(*host_x0, host_y0);
+    MGB_ASSERT_TENSOR_EQ(*host_x1, host_y1);
+}
+
+TEST(TestOprCollectiveComm, ScatterMultiThread) {
+    REQUIRE_GPU(2);
+    auto cn0 = CompNode::load("gpu0");
+    auto cn1 = CompNode::load("gpu1");
+
+    HostTensorGenerator<> gen;
+    auto host_x0 = gen({28, 28});
+    auto host_x1 = gen({28, 28});
+    HostTensorND host_y0, host_y1;
+
+    auto client = std::make_shared<test::MockGroupClient>();
+
+    auto run_0 = [&]() { // rank 0
+        auto graph0 = ComputingGraph::make();
+        auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
+        auto x1 = opr::Host2DeviceCopy::make(*graph0, host_x1, cn0);
+        auto x = opr::Concat::make({x0, x1}, 0);
+        auto y0 = opr::CollectiveComm::make({x}, graph0.get(), "scatter", 2, true, 0, client,
+                {Mode::SCATTER}, dtype::Float32(), "nccl")[0];
+        auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
+        func0->execute();
+    };
+
+    auto run_1 = [&]() { // rank 1
+        auto graph1 = ComputingGraph::make();
+        auto y1 = opr::CollectiveComm::make({}, graph1.get(), "scatter", 2, false, 1, client,
+                {Mode::SCATTER}, dtype::Float32(), "nccl", {cn1})[0];
+        auto func1 = graph1->compile({make_callback_copy(y1, host_y1)});
+        func1->execute();
+    };
+
+    std::thread t0(run_0);
+    std::thread t1(run_1);
+
+    t0.join();
+    t1.join();
+
+    MGB_ASSERT_TENSOR_EQ(*host_x0, host_y0);
+    MGB_ASSERT_TENSOR_EQ(*host_x1, host_y1);
+}
+
+TEST(TestOprCollectiveComm, ScatterWithGrad) {
+    REQUIRE_GPU(2);
+    auto cn0 = CompNode::load("gpu0");
+    auto cn1 = CompNode::load("gpu1");
+
+    HostTensorGenerator<> gen;
+    TensorShape shape({28, 28});
+    auto host_x0 = gen(shape);
+    auto host_x1 = gen(shape);
+    auto host_grad0 = gen(shape);
+    auto host_grad1 = gen(shape);
+
+    HostTensorND host_y0, host_y1, host_out_grad, host_out_grad_expect;
+
+    auto client = std::make_shared<test::MockGroupClient>();
+
+    auto run_0 = [&]() { // rank 0
+        auto graph0 = ComputingGraph::make();
+        graph0->options().graph_opt_level = 0;
+
+        auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
+        auto x1 = opr::Host2DeviceCopy::make(*graph0, host_x1, cn0);
+        auto x = opr::Concat::make({x0, x1}, 0);
+        auto y0 = opr::CollectiveComm::make({x}, graph0.get(), "scatter", 2, true, 0, client,
+                {Mode::SCATTER}, dtype::Float32(), "nccl")[0];
+        y0.node()->owner_opr()->node_prop().attribute().priority = -1;
+
+        auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0);
+        auto loss = opr::Dot::make(y0, grad0);
+        auto g = opr::VirtualGrad::make(loss, x);
+
+        auto func0 = graph0->compile(
+                {make_callback_copy(y0, host_y0),
+                 make_callback_copy(g, host_out_grad)});
+        func0->execute();
+    };
+
+    auto run_1 = [&]() { // rank 1
+        auto graph1 = ComputingGraph::make();
+        graph1->options().graph_opt_level = 0;
+
+        auto y1 = opr::CollectiveComm::make({}, graph1.get(), "scatter", 2, false, 1, client,
+                {Mode::SCATTER}, dtype::Float32(), "nccl", {cn1})[0];
+
+        auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1);
+        auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "scatter:grad", 2, false, 1, client,
+                {Mode::GATHER}, dtype::Float32(), "nccl")[0];
+        g.node()->owner_opr()->node_prop().attribute().priority = 1;
+
+        auto func1 = graph1->compile({make_callback_copy(y1, host_y1), {g, nullptr}});
+        func1->execute();
+    };
+
+    auto run_2 = [&]() { // check
+        auto graph2 = ComputingGraph::make();
+        auto grad0 = opr::Host2DeviceCopy::make(*graph2, host_grad0, cn0);
+        auto grad1 = opr::Host2DeviceCopy::make(*graph2, host_grad1, cn0);
+        auto out_grad_expect = opr::Concat::make({grad0, grad1}, 0);
+        auto func2 = graph2->compile({
+                make_callback_copy(out_grad_expect, host_out_grad_expect)});
+        func2->execute();
+    };
+
+    std::thread t0(run_0);
+    std::thread t1(run_1);
+    std::thread t2(run_2);
+
+    t0.join();
+    t1.join();
+    t2.join();
+
+    MGB_ASSERT_TENSOR_EQ(*host_x0, host_y0);
+    MGB_ASSERT_TENSOR_EQ(*host_x1, host_y1);
+    MGB_ASSERT_TENSOR_EQ(host_out_grad_expect, host_out_grad);
+}
+
+TEST(TestOprCollectiveComm, AllToAll) {
+    REQUIRE_GPU(2);
+    auto cn0 = CompNode::load("gpu0");
+    auto cn1 = CompNode::load("gpu1");
+
+    HostTensorGenerator<> gen;
+    TensorShape shape({10});
+    auto host_x00 = gen(shape);
+    auto host_x01 = gen(shape);
+    auto host_x10 = gen(shape);
+    auto host_x11 = gen(shape);
+    HostTensorND host_y0, host_y1, host_expect_y0, host_expect_y1;
+
+    auto client = std::make_shared<test::MockGroupClient>();
+    auto graph = ComputingGraph::make();
+
+    auto x00 = opr::Host2DeviceCopy::make(*graph, host_x00, cn0);
+    auto x01 = opr::Host2DeviceCopy::make(*graph, host_x01, cn0);
+    auto x0 = opr::Concat::make({x00, x01}, 0);
+    auto x10 = opr::Host2DeviceCopy::make(*graph, host_x10, cn1);
+    auto x11 = opr::Host2DeviceCopy::make(*graph, host_x11, cn1);
+    auto x1 = opr::Concat::make({x10, x11}, 0);
+
+    auto x01c = opr::Copy::make(x01, {cn1});
+    auto x10c = opr::Copy::make(x10, {cn0});
+
+    auto expect_y0 = opr::Concat::make({x00, x10c}, 0);
+    auto expect_y1 = opr::Concat::make({x01c, x11}, 0);
+
+    auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "alltoall",
+            2, false, 0, client, {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0];
+    auto y1 = opr::CollectiveComm::make({x1}, graph.get(), "alltoall", 2, false, 1,
+            client, {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0];
+
+    auto func = graph->compile({make_callback_copy(y0, host_y0),
+                                make_callback_copy(y1, host_y1),
+                                make_callback_copy(expect_y0, host_expect_y0),
+                                make_callback_copy(expect_y1, host_expect_y1)});
+    func->execute();
+
+    MGB_ASSERT_TENSOR_EQ(host_expect_y0, host_y0);
+    MGB_ASSERT_TENSOR_EQ(host_expect_y1, host_y1);
+}
+
+TEST(TestOprCollectiveComm, AllToAllMultiThread) {
+    REQUIRE_GPU(2);
+    auto cn0 = CompNode::load("gpu0");
+    auto cn1 = CompNode::load("gpu1");
+
+    HostTensorGenerator<> gen;
+    TensorShape shape({10});
+    auto host_x00 = gen(shape);
+    auto host_x01 = gen(shape);
+    auto host_x10 = gen(shape);
+    auto host_x11 = gen(shape);
+    HostTensorND host_y0, host_y1, host_expect_y0, host_expect_y1;
+
+    auto client = std::make_shared<test::MockGroupClient>();
+
+    auto run_0 = [&]() { // rank 0
+        auto graph0 = ComputingGraph::make();
+        auto x00 = opr::Host2DeviceCopy::make(*graph0, host_x00, cn0);
+        auto x01 = opr::Host2DeviceCopy::make(*graph0, host_x01, cn0);
+        auto x10 = opr::Host2DeviceCopy::make(*graph0, host_x10, cn0);
+        auto x0 = opr::Concat::make({x00, x01}, 0);
+        auto expect_y0 = opr::Concat::make({x00, x10}, 0);
+        auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "alltoall", 2, false, 0, client,
+                {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0];
+        auto func0 = graph0->compile(
+                {make_callback_copy(y0, host_y0),
+                 make_callback_copy(expect_y0, host_expect_y0)});
+        func0->execute();
+    };
+
+    auto run_1 = [&]() { // rank 1
+        auto graph1 = ComputingGraph::make();
+        auto x10 = opr::Host2DeviceCopy::make(*graph1, host_x10, cn1);
+        auto x11 = opr::Host2DeviceCopy::make(*graph1, host_x11, cn1);
+        auto x01 = opr::Host2DeviceCopy::make(*graph1, host_x01, cn1);
+        auto x1 = opr::Concat::make({x10, x11}, 0);
+        auto expect_y1 = opr::Concat::make({x01, x11}, 0);
+        auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "alltoall", 2, false, 1, client,
+                {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0];
+        auto func1 = graph1->compile(
+                {make_callback_copy(y1, host_y1),
+                 make_callback_copy(expect_y1, host_expect_y1)});
+        func1->execute();
+    };
+
+    std::thread t0(run_0);
+    std::thread t1(run_1);
+
+    t0.join();
+    t1.join();
+
+    MGB_ASSERT_TENSOR_EQ(host_expect_y0, host_y0);
+    MGB_ASSERT_TENSOR_EQ(host_expect_y1, host_y1);
+}
+
+TEST(TestOprCollectiveComm, AllToAllWithGrad) {
+    REQUIRE_GPU(2);
+    auto cn0 = CompNode::load("gpu0");
CompNode::load("gpu0"); + auto cn1 = CompNode::load("gpu1"); + + HostTensorGenerator<> gen; + TensorShape shape({10}); + auto host_x00 = gen(shape); + auto host_x01 = gen(shape); + auto host_x10 = gen(shape); + auto host_x11 = gen(shape); + auto host_grad00 = gen(shape); + auto host_grad01 = gen(shape); + auto host_grad10 = gen(shape); + auto host_grad11 = gen(shape); + + HostTensorND host_y0, host_y1, host_expect_y0, host_expect_y1, host_grad0, + host_grad1, host_expect_grad0, host_expect_grad1; + + auto client = std::make_shared(); + + auto run_0 = [&]() { // rank 0 + auto graph0 = ComputingGraph::make(); + graph0->options().graph_opt_level = 0; + + auto x00 = opr::Host2DeviceCopy::make(*graph0, host_x00, cn0); + auto x01 = opr::Host2DeviceCopy::make(*graph0, host_x01, cn0); + auto x10 = opr::Host2DeviceCopy::make(*graph0, host_x10, cn0); + auto x0 = opr::Concat::make({x00, x01}, 0); + auto expect_y0 = opr::Concat::make({x00, x10}, 0); + auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "alltoall", 2, false, 0, client, + {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0]; + y0.node()->owner_opr()->node_prop().attribute().priority = -1; + + auto grad00 = opr::Host2DeviceCopy::make(*graph0, host_grad00, cn0); + auto grad10 = opr::Host2DeviceCopy::make(*graph0, host_grad10, cn0); + auto grad_y0 = opr::Concat::make({grad00, grad10}, 0); + auto loss = opr::Dot::make(y0, grad_y0); + auto g = opr::VirtualGrad::make(loss, x0); + + auto func0 = graph0->compile( + {make_callback_copy(y0, host_y0), + make_callback_copy(g, host_grad0), + make_callback_copy(expect_y0, host_expect_y0)}); + func0->execute(); + }; + + auto run_1 = [&]() { // rank 1 + auto graph1 = ComputingGraph::make(); + graph1->options().graph_opt_level = 0; + + auto x10 = opr::Host2DeviceCopy::make(*graph1, host_x10, cn1); + auto x11 = opr::Host2DeviceCopy::make(*graph1, host_x11, cn1); + auto x01 = opr::Host2DeviceCopy::make(*graph1, host_x01, cn1); + auto x1 = opr::Concat::make({x10, x11}, 0); + auto expect_y1 = opr::Concat::make({x01, x11}, 0); + auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "alltoall", 2, false, 1, client, + {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0]; + y1.node()->owner_opr()->node_prop().attribute().priority = -1; + + auto grad01 = opr::Host2DeviceCopy::make(*graph1, host_grad01, cn1); + auto grad11 = opr::Host2DeviceCopy::make(*graph1, host_grad11, cn1); + auto grad_y1 = opr::Concat::make({grad01, grad11}, 0); + auto loss = opr::Dot::make(y1, grad_y1); + auto g = opr::VirtualGrad::make(loss, x1); + + auto func0 = graph1->compile( + {make_callback_copy(y1, host_y1), + make_callback_copy(g, host_grad1), + make_callback_copy(expect_y1, host_expect_y1)}); + func0->execute(); + }; + + auto run_2 = [&]() { // check + auto graph2 = ComputingGraph::make(); + auto grad00 = opr::Host2DeviceCopy::make(*graph2, host_grad00, cn0); + auto grad01 = opr::Host2DeviceCopy::make(*graph2, host_grad01, cn0); + auto grad10 = opr::Host2DeviceCopy::make(*graph2, host_grad10, cn0); + auto grad11 = opr::Host2DeviceCopy::make(*graph2, host_grad11, cn0); + auto out_grad0_expect = opr::Concat::make({grad00, grad01}, 0); + auto out_grad1_expect = opr::Concat::make({grad10, grad11}, 0); + auto func2 = graph2->compile({ + make_callback_copy(out_grad0_expect, host_expect_grad0), + make_callback_copy(out_grad1_expect, host_expect_grad1)}); + func2->execute(); + }; + + std::thread t0(run_0); + std::thread t1(run_1); + std::thread t2(run_2); + + t0.join(); + t1.join(); + t2.join(); + + MGB_ASSERT_TENSOR_EQ(host_expect_y0, 
host_y0); + MGB_ASSERT_TENSOR_EQ(host_expect_y1, host_y1); + MGB_ASSERT_TENSOR_EQ(host_expect_grad0, host_grad0); + MGB_ASSERT_TENSOR_EQ(host_expect_grad1, host_grad1); +} diff --git a/tools/param_defs/mgb_opr_param_defs.py b/tools/param_defs/mgb_opr_param_defs.py index 21ccb751..2399178b 100644 --- a/tools/param_defs/mgb_opr_param_defs.py +++ b/tools/param_defs/mgb_opr_param_defs.py @@ -56,7 +56,10 @@ pdef('PersistentOutputStorage').add_fields( Doc('ALL_REDUCE_SUM', 'every output gets the sum of all inputs'), Doc('ALL_REDUCE_MAX', 'every output gets the max of all inputs'), Doc('ALL_REDUCE_MIN', 'every output gets the min of all inputs'), - Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'))) + Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'), + Doc('GATHER', 'concat inputs to one node'), + Doc('SCATTER', 'scatter input to each output computing node'), + Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'))) (pdef('FakeSerializedDType', 'HACK: The tag of this param def is actually used for another '
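
Usage sketch (editor's note, not part of the patch): the example below drives the three new Python wrappers end to end with two single-GPU processes, mirroring the unit tests in test_functional.py above. The init_process_group/get_master_port calls follow the pattern of the tests' _init_process_group_wrapper helper; their exact signatures at this revision are assumptions, as are the "demo_*" key names.

    import multiprocessing as mp

    import numpy as np

    import megengine as mge
    import megengine.distributed as dist
    from megengine import tensor


    def worker(rank, world_size, port_queue, data):
        if not mge.is_cuda_available():
            return
        # rank 0 binds a free master port and publishes it; the other rank
        # connects to it (assumed signature: master_ip, master_port,
        # world_size, rank, dev)
        if rank == 0:
            dist.init_process_group("localhost", 0, world_size, rank, rank)
            port_queue.put(dist.get_master_port())
        else:
            port = port_queue.get()
            dist.init_process_group("localhost", port, world_size, rank, rank)

        inp = tensor(data)
        # gather: the root receives every rank's tensor concatenated along
        # axis 0; non-root ranks only get a placeholder output
        gathered = dist.functional.gather(
            inp, "demo_gather", is_root=(rank == 0), rank=rank
        )
        # scatter: each rank receives its slice of the root tensor; on
        # non-root ranks `inp` merely supplies dtype/device information
        scattered = dist.functional.scatter(
            inp, "demo_scatter", is_root=(rank == 0), rank=rank
        )
        # all_to_all: slice i of every rank's tensor ends up on rank i
        exchanged = dist.functional.all_to_all(inp, "demo_all_to_all", rank=rank)
        print(rank, gathered.shape, scattered.shape, exchanged.shape)


    if __name__ == "__main__":
        world_size = 2
        port_queue = mp.Queue()
        data = np.random.rand(4, 6).astype("float32")
        procs = [
            mp.Process(target=worker, args=(rank, world_size, port_queue, data))
            for rank in range(world_size)
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join()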