GitOrigin-RevId: eb3d712704
HuaHua404-patch-1
@@ -1,6 +1,7 @@ | |||
# -*- coding: utf-8 -*- | |||
from mprop import mproperty | |||
from ..core._imperative_rt.core2 import group_end, group_start | |||
from . import group | |||
from .group import ( | |||
WORLD, | |||
@@ -1,12 +1,11 @@ | |||
# -*- coding: utf-8 -*- | |||
from typing import Optional, Tuple | |||
from typing import Optional | |||
import numpy as np | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core.autodiff.grad import Function, _grad_manager_dict | |||
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||
from ..core.tensor.utils import isscalar | |||
from ..core.ops.builtin import CollectiveComm, RemoteRecv, RemoteSend | |||
from ..device import get_default_device, what_is_xpu | |||
from ..tensor import Tensor | |||
from . import group | |||
@@ -843,16 +842,13 @@ def remote_send(inp: Tensor, dest_rank: int): | |||
""" | |||
group = _SendRecvGroup(get_rank(), dest_rank) | |||
_bcast_shape_dtype(group, inp) | |||
_bcast_tracer_state(group, inp) | |||
op = RemoteSend() | |||
op.key = group.key | |||
op.addr, op.port = get_mm_server_addr() | |||
op.rank_to = dest_rank | |||
op.backend = _backend() | |||
out = _RemoteSend(op)(inp) | |||
_save_output_for_autodiff(inp, out) | |||
@@ -900,6 +896,34 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor | |||
op.addr, op.port = get_mm_server_addr() | |||
op.rank_from = src_rank | |||
op.backend = _backend() | |||
ret = _RemoteRecv(op)(inp) | |||
return ret | |||
def _remote_send_nobackward(inp: Tensor, dest_rank: int):
    """Send ``inp`` to ``dest_rank`` through a bare ``RemoteSend`` op.

    The op is applied directly with ``apply``; nothing is returned to the
    caller. The key is ``"b<src>-><dst>"``.

    Args:
        inp: tensor to send.
        dest_rank: rank of the receiving process.
    """
    send_op = RemoteSend()
    send_op.key = "b{}->{}".format(get_rank(), dest_rank)
    addr, port = get_mm_server_addr()
    send_op.addr = addr
    send_op.port = port
    send_op.rank_to = dest_rank
    send_op.backend = _backend()
    apply(send_op, inp)
def _remote_recv_nobackward(
    src_rank: int, device: Optional[str] = None, inp=None, shape=None, dtype=None,
):
    """Receive a tensor from ``src_rank`` through a bare ``RemoteRecv`` op.

    The key is ``"b<src>-><dst>"``.

    Args:
        src_rank: rank of the sending process.
        device: device for the received tensor; the default device is used
            when ``None``.
        inp: input passed to the op; a scalar ``Tensor(0)`` on ``device`` is
            created when ``None``.
        shape: shape of the tensor to receive (required).
        dtype: dtype of the tensor to receive (required).

    Returns:
        The first output of the applied ``RemoteRecv`` op.
    """
    recv_op = RemoteRecv()
    recv_op.key = "b{}->{}".format(src_rank, get_rank())
    device = get_default_device() if device is None else device
    recv_op.cn = device
    if inp is None:
        inp = Tensor(0, device=device)
    # The op carries the output layout itself, so both must be supplied.
    assert shape is not None and dtype is not None
    recv_op.shape = shape
    recv_op.dtype = dtype
    addr, port = get_mm_server_addr()
    recv_op.addr = addr
    recv_op.port = port
    recv_op.rank_from = src_rank
    recv_op.backend = _backend()
    return apply(recv_op, inp)[0]
@@ -160,6 +160,13 @@ def init_process_group( | |||
set_default_device("{}{}".format(device_type, device)) | |||
seed(int(time.time()) + rank) | |||
if backend == "nccl": | |||
# init nccl env | |||
from ..core._imperative_rt.common import init_nccl_env | |||
group_barrier() | |||
init_nccl_env(master_ip, _sd.mm_server_port, world_size, rank, 0) | |||
def _set_machine_ranks(ranks) -> None: | |||
global _sd | |||
@@ -8,6 +8,9 @@ | |||
#include "megbrain/comp_node.h" | |||
#include "megbrain/graph.h" | |||
#include "megbrain/imperative/physical_tensor.h" | |||
#if MGB_ENABLE_OPR_MM | |||
#include "megbrain/opr/mm_handler.h" | |||
#endif | |||
#if MEGDNN_WITH_CUDA | |||
#include "cuda_sm_gen.h" | |||
@@ -46,6 +49,18 @@ void set_default_device(const std::string& device) { | |||
default_device = device; | |||
} | |||
// Initialise the process-wide communicator used by batched send/recv.
//
// ip/port : address passed on to BatchSendRecvHelper::init
// nranks  : number of participating ranks
// rank    : rank of this process
// root    : rank passed on as the root of the initialisation
//
// Throws MegBrainError when MegEngine was built without MM oprs.
void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root) {
#if MGB_ENABLE_OPR_MM
    auto&& helper = mgb::opr::BatchSendRecvHelper::getInstance();
    // NOTE(review): init() returns false when a communicator is already cached;
    // the result is intentionally not treated as an error here.
    helper->init(nranks, rank, ip, port, root);
#else
    mgb_throw(
            MegBrainError,
            "MegEngine compiled without MM opr, doesn't support init_nccl_env");
#endif
}
// Return a copy of the current default device string.
std::string get_default_device() { | |||
return default_device; | |||
} | |||
@@ -252,6 +267,8 @@ void init_common(py::module m) { | |||
m.def("what_is_xpu", | |||
[] { return CompNode::Locator::parse("xpux").to_physical().type; }); | |||
m.def("init_nccl_env", &init_nccl_env); | |||
init_npy_num_bfloat16(m); | |||
init_npy_num_intbx(m); | |||
init_dtypes(m); | |||
@@ -8,3 +8,4 @@ void set_default_device(const std::string& device); | |||
std::string get_default_device(); | |||
extern pybind11::handle py_comp_node_type; | |||
void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root); |
@@ -9,6 +9,7 @@ | |||
#include "megbrain/imperative/transformations/dtype_promote.h" | |||
#include "megbrain/imperative/transformations/eval.h" | |||
#include "megbrain/imperative/transformations/format.h" | |||
#include "megbrain/imperative/transformations/group_comm.h" | |||
#include "megbrain/imperative/transformations/lazy.h" | |||
#include "megbrain/imperative/transformations/scalar.h" | |||
#include "megbrain/imperative/transformations/symbol.h" | |||
@@ -947,6 +948,13 @@ void init_tensor(py::module m) { | |||
m.def("enable_cupti", &cupti::enable); | |||
m.def("disable_cupti", &cupti::disable); | |||
m.def("cupti_available", &cupti::available); | |||
static std::unique_ptr<CleanupGuard<>> group_comm_guard; | |||
m.def("group_start", []() { | |||
auto commtrans = std::make_shared<GroupCommTransformation>(); | |||
group_comm_guard = transformations.register_at<Segment::GroupComm>(commtrans); | |||
}); | |||
m.def("group_end", []() { group_comm_guard.reset(); }); | |||
m.def("sync", [channel]() { | |||
if (channel->check_available()) { | |||
channel->sync(); | |||
@@ -16,6 +16,7 @@ struct TransformationManager { | |||
public: | |||
enum Segment { | |||
ModuleTrace, | |||
GroupComm, | |||
DTypePromote, | |||
DimExpansion, | |||
Format, | |||
@@ -26,7 +27,7 @@ public: | |||
Eval, | |||
}; | |||
std::array<std::vector<std::shared_ptr<Transformation>>, 9> segments; | |||
std::array<std::vector<std::shared_ptr<Transformation>>, 10> segments; | |||
private: | |||
template <Segment segment> | |||
@@ -237,3 +237,32 @@ def test_get_cuda_compute_capability(): | |||
assert mge.device.get_cuda_compute_capability(dist.get_rank()) > 0 | |||
worker() | |||
@pytest.mark.require_ngpu(3)
@pytest.mark.isolated_distributed
def test_batch_send_recv():
    """Exchange tensors with both ring neighbours inside one group_start/group_end.

    Each of the 3 ranks sends ``ones * rank`` to its next and previous rank
    three times; on the last round the payload is doubled, and the tensor
    received from the previous rank in that round is checked after the group
    has ended.
    """
    import megengine.distributed.functional as DF

    @dist.launcher(n_gpus=3)
    def worker():
        rank = dist.get_rank()
        next_rank = (rank + 1) % 3
        prev_rank = (rank - 1) % 3

        dist.group_start()
        for round_idx in range(3):
            payload = mge.tensor(np.ones(10000)) * rank
            if round_idx == 2:
                payload *= round_idx

            DF._remote_send_nobackward(payload, next_rank)
            DF._remote_recv_nobackward(
                src_rank=next_rank, dtype="float32", shape=(10000,)
            )
            DF._remote_send_nobackward(payload, prev_rank)
            received = DF._remote_recv_nobackward(
                src_rank=prev_rank, dtype="float32", shape=(10000,)
            )
            if round_idx == 2:
                recv2 = received
        dist.group_end()

        expected = prev_rank * 2 * np.ones(10000)
        np.testing.assert_equal(recv2.numpy(), expected)

    worker()
@@ -1,14 +1,19 @@ | |||
#include "megbrain/imperative/ops/io_remote.h" | |||
#include "megbrain_build_config.h" | |||
#if MGB_ENABLE_OPR_MM | |||
#include <algorithm> | |||
#include <functional> | |||
#include <numeric> | |||
#include "../blob_manager_impl.h" | |||
#include "../op_trait.h" | |||
#include "megbrain/imperative/proxy_graph_detail.h" | |||
#include "megbrain/opr/io_remote.h" | |||
#include "megbrain/opr/megray_helper.h" | |||
#include "megbrain/opr/mm_handler.h" | |||
#endif // MGB_ENABLE_OPR_MM | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/imperative/proxy_graph_detail.h" | |||
namespace mgb { | |||
namespace imperative { | |||
@@ -46,15 +51,164 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||
recv.backend)); | |||
} | |||
// Receive a tensor with the given layout from `rank_from` into a freshly
// allocated buffer on `cn` and wrap it as a Tensor.
// Asserts that MegRay reports MEGRAY_OK.
TensorPtr megray_recv_tensor(
        std::shared_ptr<MegRay::Communicator> megray_comm, TensorLayout& layout,
        CompNode cn, uint32_t rank_from) {
    DeviceTensorND recv_buf =
            BlobManager::inst()->alloc_workspace_with_defrag(cn, layout);
    auto ctx = mgb::opr::get_megray_context(cn);
    // The length handed to MegRay is the element count of the layout.
    const size_t nr_elems = layout.total_nr_elems();
    const auto elem_dtype = mgb::opr::get_megray_dtype(layout.dtype);
    auto status =
            megray_comm->recv(recv_buf.raw_ptr(), nr_elems, elem_dtype, rank_from, ctx);
    mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed");
    return Tensor::make(recv_buf);
}
void megray_send_tensor( | |||
std::shared_ptr<MegRay::Communicator> megray_comm, const TensorPtr& src, | |||
uint32_t rank_to) { | |||
auto&& tensor = src->dev_tensor(); | |||
auto&& ishp = src->shape(); | |||
size_t data_size = ishp.total_nr_elems(); | |||
auto megray_ctx = mgb::opr::get_megray_context(src->comp_node()); | |||
auto status = megray_comm->send( | |||
src->dev_tensor().raw_ptr(), data_size, | |||
mgb::opr::get_megray_dtype(src->layout().dtype), rank_to, megray_ctx); | |||
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); | |||
} | |||
// Build a TensorLayout from a list of dimensions and a dtype.
// Asserts that the number of dimensions does not exceed TensorLayout::MAX_NDIM.
TensorLayout create_layout(const std::vector<int32_t>& shape, DType dtype) {
    TensorShape dims;
    dims.ndim = shape.size();
    mgb_assert(dims.ndim <= TensorLayout::MAX_NDIM);
    for (size_t axis = 0; axis < shape.size(); ++axis) {
        dims.shape[axis] = shape[axis];
    }
    return TensorLayout(dims, dtype);
}
// Output description of RemoteSend: a single tensor of shape {0} that takes
// dtype and comp node from the first input. Always reported as validated.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible_remote_send(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    const auto& first_input = input_descs[0];
    SmallVector<LogicalTensorDesc> output_descs;
    output_descs.push_back(
            {TensorLayout({0}, first_input.layout.dtype), first_input.comp_node});
    return {output_descs, true};
}
// Eager execution of RemoteSend.
// When no "init_all_cards" communicator is cached, execution is delegated to
// proxy_graph_detail; otherwise inputs[0] is sent directly through MegRay and
// a tensor of shape {0} is returned as the op's output.
SmallVector<TensorPtr> apply_on_physical_tensor_remote_send(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& send_op = def.cast_final_safe<RemoteSend>();
    auto comm = mgb::opr::BatchSendRecvHelper::getInstance()->get(
            std::string("init_all_cards"));
    if (comm == nullptr) {
        return proxy_graph_detail::apply_on_physical_tensor(
                def, inputs, output_descs, validated);
    }

    const TensorPtr& src = inputs[0];
    megray_send_tensor(comm, src, send_op.rank_to);

    TensorLayout empty_layout({0}, src->dtype());
    DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(
            src->comp_node(), empty_layout);
    return {Tensor::make(out)};
}
// Output description of RemoteRecv: one tensor whose layout comes from the
// op's shape/dtype attributes, placed on the op's comp node.
// Always reported as validated.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible_remote_recv(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    auto& recv_op = def.cast_final_safe<RemoteRecv>();
    SmallVector<LogicalTensorDesc> output_descs;
    output_descs.push_back({create_layout(recv_op.shape, recv_op.dtype), recv_op.cn});
    return {output_descs, true};
}
// Eager execution of RemoteRecv.
// When no "init_all_cards" communicator is cached, execution is delegated to
// proxy_graph_detail; otherwise the tensor is received directly through MegRay
// using the layout described by the op's shape/dtype attributes.
SmallVector<TensorPtr> apply_on_physical_tensor_remote_recv(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& recv_op = def.cast_final_safe<RemoteRecv>();
    auto recv_layout = create_layout(recv_op.shape, recv_op.dtype);
    auto comm = mgb::opr::BatchSendRecvHelper::getInstance()->get(
            std::string("init_all_cards"));
    if (comm == nullptr) {
        return proxy_graph_detail::apply_on_physical_tensor(
                def, inputs, output_descs, validated);
    }
    TensorPtr received =
            megray_recv_tensor(comm, recv_layout, recv_op.cn, recv_op.rank_from);
    return {received};
}
// Layout constraints for the remote/batch ops: every input must be contiguous.
// Returns one callback per input; each callback accepts only contiguous layouts.
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    // The index must start at 0: the original left it uninitialised, which is
    // undefined behaviour and could leave callbacks unset.
    for (size_t i = 0; i < inputs.size(); i++) {
        layout_checker[i] = [](const TensorLayout& layout) {
            return layout.is_contiguous();
        };
    }
    return layout_checker;
}
// Register the RemoteSend traits: graph-mode lowering, eager execution,
// output inference and the contiguous-input constraint defined above.
OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend) | |||
.apply_on_var_node(apply_on_var_node_remote_send) | |||
.apply_on_physical_tensor(apply_on_physical_tensor_remote_send) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible_remote_send) | |||
.get_input_layout_constraint(get_input_layout_constraint) | |||
.fallback(); | |||
// Register the same set of traits for RemoteRecv.
OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) | |||
.apply_on_var_node(apply_on_var_node_remote_recv) | |||
.apply_on_physical_tensor(apply_on_physical_tensor_remote_recv) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible_remote_recv) | |||
.get_input_layout_constraint(get_input_layout_constraint) | |||
.fallback(); | |||
} // anonymous namespace | |||
// Execute a recorded batch of RemoteSend / RemoteRecv ops between one
// group_start()/group_end() pair on the "init_all_cards" communicator.
//
// inputs  : one tensor per RemoteSend in op_list, in recording order
// returns : one tensor per RemoteRecv in op_list, in recording order
//
// Asserts that the communicator exists, that every entry of op_list is a
// RemoteSend or RemoteRecv, and that there is an input for every send.
// NOTE(review): if an assertion fires inside the loop, group_end() is not
// reached -- confirm whether the communicator needs cleanup in that case.
SmallVector<TensorPtr> apply_on_physical_tensor_batch_send_recv(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<BatchSendRecvOp>();
    auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get(
            std::string("init_all_cards"));
    mgb_assert(megray_comm != nullptr);
    megray_comm->group_start();
    SmallVector<TensorPtr> outputs;
    size_t ind = 0;  // index of the next unconsumed send input
    for (auto&& op_ : op.op_list) {
        if (op_->same_type<RemoteSend>()) {
            auto&& send_op = op_->cast_final_safe<RemoteSend>();
            // Guard against an op list that records more sends than inputs.
            mgb_assert(
                    ind < inputs.size(),
                    "BatchSendRecvOp has more RemoteSend ops than inputs");
            megray_send_tensor(megray_comm, inputs[ind], send_op.rank_to);
            ind++;
        } else {
            mgb_assert(op_->same_type<RemoteRecv>());
            auto&& recv_op = op_->cast_final_safe<RemoteRecv>();
            auto layout = create_layout(recv_op.shape, recv_op.dtype);
            auto&& out = megray_recv_tensor(
                    megray_comm, layout, recv_op.cn, recv_op.rank_from);
            outputs.push_back(out);
        }
    }
    megray_comm->group_end();
    return outputs;
}
// Output descriptions of BatchSendRecvOp: one entry per RemoteRecv in op_list
// (in list order), built from that recv's shape/dtype/comp node. Entries that
// are not RemoteRecv contribute no output. Always reported as validated.
std::tuple<SmallVector<LogicalTensorDesc>, bool>
infer_output_attrs_fallible_batch_send_recv(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    auto& batch_op = def.cast_final_safe<BatchSendRecvOp>();
    SmallVector<LogicalTensorDesc> recv_descs;
    for (auto&& sub_op : batch_op.op_list) {
        if (!sub_op->same_type<RemoteRecv>()) {
            continue;
        }
        auto&& recv_op = sub_op->cast_final_safe<RemoteRecv>();
        recv_descs.push_back(
                {create_layout(recv_op.shape, recv_op.dtype), recv_op.cn});
    }
    return {recv_descs, true};
}
// Register BatchSendRecvOp traits: eager execution, output inference and the
// contiguous-input constraint. No apply_on_var_node is registered here.
OP_TRAIT_REG(BatchSendRecvOp, BatchSendRecvOp) | |||
.apply_on_physical_tensor(apply_on_physical_tensor_batch_send_recv) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible_batch_send_recv) | |||
.get_input_layout_constraint(get_input_layout_constraint) | |||
.fallback(); | |||
} // namespace | |||
#endif // MGB_ENABLE_OPR_MM | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchSendRecvOp); | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,67 @@ | |||
#include "megbrain/imperative/transformations/group_comm.h" | |||
#include "megbrain/imperative/blob_manager.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/imperative/ops/io_remote.h" | |||
namespace mgb { | |||
namespace imperative { | |||
// Intercept RemoteSend / RemoteRecv ops whose key starts with 'b' and record
// them instead of executing them:
//   * RemoteSend: remember the input and the op, return no outputs;
//   * RemoteRecv: remember the op, return a fresh PlaceholderValue.
// Every other operator is forwarded unchanged.
// Asserts that no input is itself a PlaceholderValue.
ValueRefList GroupCommTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    for (auto inp : inputs) {
        mgb_assert(
                !inp.is(m_value_type), "Can not use PlaceholderValue as apply input");
    }

    auto* apply_op = op.as<ApplyOp>();
    if (!apply_op) {
        return imperative::apply(op, inputs);
    }

    auto&& op_def = apply_op->op();
    if (op_def.same_type<RemoteSend>()) {
        auto&& send_op = op_def.cast_final_safe<RemoteSend>();
        if (send_op.key[0] == 'b') {
            send_inputs.push_back(inputs[0]);
            record_ops.push_back(send_op.shared_from_this());
            return {};
        }
    }
    if (op_def.same_type<RemoteRecv>()) {
        auto&& recv_op = op_def.cast_final_safe<RemoteRecv>();
        if (recv_op.key[0] == 'b') {
            record_ops.push_back(recv_op.shared_from_this());
            auto placeholder = m_value_type.make();
            recv_tensors.push_back(placeholder);
            auto outputs = ValueRefList(1);
            outputs[0] = placeholder;
            return outputs;
        }
    }
    return imperative::apply(op, inputs);
}
// Wrap all recorded ops in one BatchSendRecvOp and apply it to the recorded
// send inputs; returns the outputs of that apply.
ValueRefList GroupCommTransformation::execute_batch_op() {
    auto batched = BatchSendRecvOp::make(record_ops);
    return imperative::apply(*batched, send_inputs);
}
void GroupCommTransformation::on_unregister() noexcept { | |||
auto rst = execute_batch_op(); | |||
mgb_assert(rst.size() == recv_tensors.size()); | |||
for (size_t i = 0; i < rst.size(); i++) { | |||
auto v = recv_tensors[i].lock(); | |||
if (v != ValueRef::nil) { | |||
v.reset(rst[i]); | |||
} | |||
} | |||
} | |||
// Sanity check on destruction: every recorded placeholder must lock to nil.
GroupCommTransformation::~GroupCommTransformation() {
    for (auto&& placeholder : recv_tensors) {
        mgb_assert(
                placeholder.lock() == ValueRef::nil,
                "Some PlaceholderValues are not reset after GroupCommTransformation "
                "destroyed!");
    }
}
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,11 @@ | |||
#pragma once | |||
#include "megbrain/imperative/op_def.h" | |||
namespace mgb::imperative { | |||
// OpDef that bundles a list of ops to be executed as a single batch.
struct BatchSendRecvOp final : OpDefImplBase<BatchSendRecvOp> { | |||
// The bundled ops, in the order they should be executed.
SmallVector<std::shared_ptr<OpDef>> op_list; | |||
BatchSendRecvOp() = default; | |||
// Stores a copy of the given op list.
BatchSendRecvOp(SmallVector<std::shared_ptr<OpDef>> op_list) : op_list{op_list} {} | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
}; | |||
} // namespace mgb::imperative |
@@ -0,0 +1,44 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/transformations/group_comm.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megbrain/imperative/basic_operators.h" | |||
#include "megbrain/imperative/dispatch.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
namespace mgb::imperative { | |||
// Value type with no payload; GroupCommTransformation hands these out as the
// outputs of recorded RemoteRecv ops.
class PlaceholderValue final : public ObjectValue<PlaceholderValue> { | |||
public: | |||
std::string to_string() const override { return ssprintf("PlaceholderValue"); } | |||
// Nothing to release: the value holds no data.
void clear() override {} | |||
}; | |||
// Transformation that records RemoteSend / RemoteRecv ops and executes them
// as one BatchSendRecvOp when it is unregistered.
class GroupCommTransformation final : public Transformation { | |||
private: | |||
// Inputs of the recorded RemoteSend ops, in recording order.
SmallVector<ValueRef> send_inputs; | |||
// Weak references to the placeholders returned for recorded RemoteRecv ops.
std::vector<PlaceholderValue::weak_ref_t> recv_tensors; | |||
// All recorded send/recv ops, in recording order.
SmallVector<std::shared_ptr<OpDef>> record_ops; | |||
ObjectType<PlaceholderValue> m_value_type{"PlaceholderValue"}; | |||
public: | |||
GroupCommTransformation() = default; | |||
ValueRefList apply_transformation( | |||
const Operator& op, Span<ValueRef> inputs) override; | |||
// Apply one BatchSendRecvOp built from record_ops to send_inputs.
ValueRefList execute_batch_op(); | |||
// Values pass through unchanged.
ValueRef unwrap(ValueRef value) override { return value; } | |||
std::string name() const override { return "GroupCommTransformation"; } | |||
void on_unregister() noexcept override; | |||
~GroupCommTransformation(); | |||
}; | |||
} // namespace mgb::imperative |
@@ -1,4 +1,5 @@ | |||
#include "./helper.h" | |||
#include "megbrain/comp_node_env.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/mm_handler.h" | |||
@@ -47,7 +48,4 @@ TEST(TestImperative, IORemote) { | |||
t0.join(); | |||
t1.join(); | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
// ./imperative_test --gtest_filter TestIORemote |
@@ -151,6 +151,28 @@ void GroupManager::bcast_addr( | |||
} | |||
} | |||
void GroupManager::bcast_nccluniqueid( | |||
const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||
uint32_t root) { | |||
std::unique_lock<std::mutex> lk{m_key2nccl_id_mtx}; | |||
if (rank == root) { | |||
m_key2nccl_id[key] = id; | |||
} | |||
m_key2nccl_id_size[key]++; | |||
if (m_key2nccl_id_size[key] == size) { | |||
m_key2nccl_id_flag[key] = true; | |||
m_bcast_cv.notify_all(); | |||
} else { | |||
m_bcast_cv.wait(lk, [&] { return m_key2nccl_id_flag.count(key) > 0; }); | |||
} | |||
id = m_key2nccl_id[key]; | |||
m_key2nccl_id_size[key]--; | |||
if (m_key2nccl_id_size[key] == 0) { | |||
m_key2nccl_id.erase(key); | |||
m_key2nccl_id_flag.erase(key); | |||
} | |||
} | |||
void GroupManager::set_output_shape(const std::string& key, const TensorShape& shape) { | |||
auto&& group = get_group(key); | |||
group.set_output_shape(key, shape); | |||
@@ -67,6 +67,15 @@ void MegRayCommBuilder::emplace( | |||
m_megray_comms.emplace(hash, comm); | |||
} | |||
void MegRayCommBuilder::remove( | |||
uint64_t hash, std::shared_ptr<MegRay::Communicator> comm) { | |||
std::unique_lock<std::mutex> lk(m_map_mtx); | |||
auto it = m_megray_comms.find(hash); | |||
if (it != m_megray_comms.end()) { | |||
m_megray_comms.erase(hash); | |||
} | |||
} | |||
std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | |||
uint64_t hash, std::string key, uint32_t size, uint32_t rank, | |||
MegRay::Backend backend, std::shared_ptr<mgb::opr::GroupClient> group_client) { | |||
@@ -104,5 +113,3 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | |||
MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr; | |||
std::mutex MegRayCommBuilder::sm_instance_mtx; | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -45,6 +45,7 @@ public: | |||
RUNSERVER(get_output_shape); | |||
RUNSERVER(bcast_addr); | |||
RUNSERVER(group_barrier); | |||
RUNSERVER(bcast_nccluniqueid); | |||
mgb_assert(false, "invalid rpc request"); | |||
} | |||
@@ -53,6 +54,7 @@ private: | |||
void set_output_shape(void* input_ptr, size_t input_len, std::string* output); | |||
void get_output_shape(void* input_ptr, size_t input_len, std::string* output); | |||
void bcast_addr(void* input_ptr, size_t input_len, std::string* output); | |||
void bcast_nccluniqueid(void* input_ptr, size_t input_len, std::string* output); | |||
void group_barrier(void* input_ptr, size_t input_len, std::string* output); | |||
private: | |||
@@ -116,6 +118,15 @@ void GroupServerProxy::bcast_addr( | |||
rsp.SerializeToString(output); | |||
} | |||
void GroupServerProxy::bcast_nccluniqueid( | |||
void* input_ptr, size_t input_len, std::string* output) { | |||
INFO_INIT(mm_handler, BcastNcclUniqueId); | |||
std::string id = req.id(); | |||
m_mgr.bcast_nccluniqueid(req.key(), id, req.size(), req.rank(), req.root()); | |||
rsp.set_id(id); | |||
rsp.SerializeToString(output); | |||
} | |||
void GroupServerProxy::group_barrier( | |||
void* input_ptr, size_t input_len, std::string* output) { | |||
INFO_INIT(mm_handler, GroupBarrier); | |||
@@ -201,6 +212,19 @@ void GroupClientProxy::bcast_addr( | |||
port = rsp.port(); | |||
} | |||
void GroupClientProxy::bcast_nccluniqueid( | |||
const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||
uint32_t root) { | |||
INFO_INIT(mm_handler, bcast_nccluniqueid, BcastNcclUniqueId); | |||
req.set_id(id.data(), id.size()); | |||
req.set_key(key.data(), key.size()); | |||
req.set_size(size); | |||
req.set_rank(rank); | |||
req.set_root(root); | |||
SOLVE_REQUEST(func_name, req, rsp); | |||
id = rsp.id(); | |||
} | |||
uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | |||
INFO_INIT(mm_handler, group_barrier, GroupBarrier); | |||
req.set_size(size); | |||
@@ -209,6 +233,40 @@ uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | |||
return rsp.size(); | |||
} | |||
std::shared_ptr<MegRay::Communicator> BatchSendRecvHelper::get(std::string&& key) { | |||
auto ptr = megray_comm_cache.find(key); | |||
if (ptr != megray_comm_cache.end()) { | |||
return megray_comm_cache[key]; | |||
} else { | |||
return nullptr; | |||
} | |||
} | |||
std::unordered_map<std::string, std::shared_ptr<MegRay::Communicator>> | |||
BatchSendRecvHelper::megray_comm_cache{}; | |||
bool BatchSendRecvHelper::init( | |||
int nranks, int rank, std::string ip, int port, int root) { | |||
auto megray_comm = | |||
MegRay::get_communicator(nranks, rank, MegRay::Backend::MEGRAY_NCCL); | |||
auto group_client = | |||
std::make_shared<opr::GroupClientProxy>(ssprintf("%s:%d", ip.data(), port)); | |||
auto cb = [=](char* nccl_buffer, size_t len) { | |||
std::string id; | |||
id.resize(128); | |||
if (rank == root) { | |||
memcpy(id.data(), nccl_buffer, len); | |||
} | |||
group_client->bcast_nccluniqueid("init_all_cards", id, nranks, rank, root); | |||
if (rank != root) { | |||
memcpy(nccl_buffer, id.data(), len); | |||
} | |||
}; | |||
megray_comm->init(cb); | |||
return megray_comm_cache.insert({std::string("init_all_cards"), megray_comm}) | |||
.second; | |||
} | |||
#undef INFO_INIT | |||
#undef SOLVE_REQUEST | |||
@@ -77,6 +77,11 @@ public: | |||
std::string& master_ip, int& port, const std::string& key, uint32_t size, | |||
uint32_t rank, uint32_t root); | |||
//! bcast uid | |||
void bcast_nccluniqueid( | |||
const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||
uint32_t root); | |||
//! Set output shape of this key | |||
void set_output_shape(const std::string& key, const TensorShape& shape); | |||
@@ -101,6 +106,12 @@ private: | |||
std::mutex m_key2addr_mtx; | |||
std::condition_variable m_bcast_cv; | |||
//! key -> ncclid | |||
std::unordered_map<std::string, std::string> m_key2nccl_id; | |||
std::unordered_map<std::string, uint32_t> m_key2nccl_id_size; | |||
std::unordered_map<std::string, bool> m_key2nccl_id_flag; | |||
std::mutex m_key2nccl_id_mtx; | |||
//! barrier | |||
uint32_t m_barrier_size; | |||
std::set<uint32_t> m_barrier_set; | |||
@@ -128,6 +139,10 @@ public: | |||
std::string& master_ip, int& port, const std::string& key, uint32_t size, | |||
uint32_t rank, uint32_t root) = 0; | |||
virtual void bcast_nccluniqueid( | |||
const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||
uint32_t root) = 0; | |||
virtual void set_output_shape(const std::string& key, const TensorShape& shape) = 0; | |||
virtual TensorShape get_output_shape(const std::string& key) = 0; | |||
@@ -23,6 +23,7 @@ class MegRayCommBuilder { | |||
private: | |||
bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm); | |||
void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | |||
void remove(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | |||
std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; | |||
std::mutex m_map_mtx; | |||
@@ -39,6 +39,10 @@ public: | |||
std::string& master_ip, int& port, const std::string& key, uint32_t size, | |||
uint32_t rank, uint32_t root) override; | |||
void bcast_nccluniqueid( | |||
const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||
uint32_t root) override; | |||
void set_output_shape(const std::string& key, const TensorShape& shape) override; | |||
TensorShape get_output_shape(const std::string& key) override; | |||
@@ -52,6 +56,34 @@ private: | |||
void* m_stub; | |||
}; | |||
// CRTP-style base giving T a lazily created, process-wide shared instance.
// Copying is disabled.
template <typename T> | |||
class ProcessGlobal { // thread safe | |||
public: | |||
// Returns a reference to a function-local static shared_ptr<T>, constructed
// from `args` on first use. Because this is a template over Args..., each
// distinct argument-type list has its own static instance.
// NOTE(review): the static's initialisation is guarded by the language, but
// that says nothing about concurrent use of T's own members.
template <class... Args> | |||
static std::shared_ptr<T>& getInstance(Args&&... args) { | |||
static auto instance = std::make_shared<T>(std::forward<Args>(args)...); | |||
return instance; | |||
} | |||
protected: | |||
// Declared only; no definition is visible in this header.
template <class... Args> | |||
ProcessGlobal(Args&&... args); | |||
ProcessGlobal() = default; | |||
public: | |||
ProcessGlobal(ProcessGlobal const&) = delete; | |||
void operator=(ProcessGlobal const&) = delete; | |||
}; | |||
// Process-wide holder of MegRay communicators, keyed by name.
class BatchSendRecvHelper : public ProcessGlobal<BatchSendRecvHelper> { | |||
// Static cache shared by all instances: key -> communicator.
static std::unordered_map<std::string, std::shared_ptr<MegRay::Communicator>> | |||
megray_comm_cache; | |||
public: | |||
// Look up a cached communicator by key.
std::shared_ptr<MegRay::Communicator> get(std::string&&); | |||
// Create a communicator and add it to the cache.
bool init(int nranks, int rank, std::string ip, int port, int root); | |||
}; | |||
/* ======================== ZmqRpcServerMgr ========================== */ | |||
int create_zmqrpc_server(const std::string& server_addr, int port); | |||
@@ -60,5 +92,3 @@ int create_zmqrpc_server(const std::string& server_addr, int port); | |||
} // namespace mgb | |||
#endif | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -30,6 +30,18 @@ message BcastAddrResponse { | |||
int32 port = 2; | |||
} | |||
// Request sent by each rank taking part in an id broadcast.
message BcastNcclUniqueIdRequest{ | |||
// Name of the broadcast the caller is joining.
string key = 1; | |||
// The caller's id bytes.
bytes id = 2; | |||
// Number of participants.
uint32 size =3 ; | |||
// Rank of the caller.
uint32 rank = 4; | |||
// Rank whose id is broadcast.
uint32 root =5; | |||
} | |||
// Response carrying the id returned by the broadcast.
message BcastNcclUniqueIdResponse{ | |||
bytes id = 1; | |||
} | |||
message SetOutputShapeRequest { | |||
string key = 1; | |||
TensorShape shape = 2; | |||
@@ -26,6 +26,12 @@ public: | |||
return m_mgr.bcast_addr(master_ip, port, key, size, rank, root); | |||
} | |||
void bcast_nccluniqueid( | |||
const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||
uint32_t root) override { | |||
return m_mgr.bcast_nccluniqueid(key, id, size, rank, root); | |||
} | |||
// Forward the output-shape registration straight to the wrapped group manager.
void set_output_shape(const std::string& key, const TensorShape& shape) override { | |||
m_mgr.set_output_shape(key, shape); | |||
} | |||