GitOrigin-RevId: eb3d712704
HuaHua404-patch-1
@@ -1,6 +1,7 @@ | |||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
from mprop import mproperty | from mprop import mproperty | ||||
from ..core._imperative_rt.core2 import group_end, group_start | |||||
from . import group | from . import group | ||||
from .group import ( | from .group import ( | ||||
WORLD, | WORLD, | ||||
@@ -1,12 +1,11 @@ | |||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
from typing import Optional, Tuple | |||||
from typing import Optional | |||||
import numpy as np | import numpy as np | ||||
from ..core._imperative_rt.core2 import apply | from ..core._imperative_rt.core2 import apply | ||||
from ..core.autodiff.grad import Function, _grad_manager_dict | from ..core.autodiff.grad import Function, _grad_manager_dict | ||||
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||||
from ..core.tensor.utils import isscalar | |||||
from ..core.ops.builtin import CollectiveComm, RemoteRecv, RemoteSend | |||||
from ..device import get_default_device, what_is_xpu | from ..device import get_default_device, what_is_xpu | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from . import group | from . import group | ||||
@@ -843,16 +842,13 @@ def remote_send(inp: Tensor, dest_rank: int): | |||||
""" | """ | ||||
group = _SendRecvGroup(get_rank(), dest_rank) | group = _SendRecvGroup(get_rank(), dest_rank) | ||||
_bcast_shape_dtype(group, inp) | _bcast_shape_dtype(group, inp) | ||||
_bcast_tracer_state(group, inp) | _bcast_tracer_state(group, inp) | ||||
op = RemoteSend() | op = RemoteSend() | ||||
op.key = group.key | op.key = group.key | ||||
op.addr, op.port = get_mm_server_addr() | op.addr, op.port = get_mm_server_addr() | ||||
op.rank_to = dest_rank | op.rank_to = dest_rank | ||||
op.backend = _backend() | op.backend = _backend() | ||||
out = _RemoteSend(op)(inp) | out = _RemoteSend(op)(inp) | ||||
_save_output_for_autodiff(inp, out) | _save_output_for_autodiff(inp, out) | ||||
@@ -900,6 +896,34 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor | |||||
op.addr, op.port = get_mm_server_addr() | op.addr, op.port = get_mm_server_addr() | ||||
op.rank_from = src_rank | op.rank_from = src_rank | ||||
op.backend = _backend() | op.backend = _backend() | ||||
ret = _RemoteRecv(op)(inp) | ret = _RemoteRecv(op)(inp) | ||||
return ret | return ret | ||||
def _remote_send_nobackward(inp: Tensor, dest_rank: int): | |||||
op = RemoteSend() | |||||
op.key = "b{}->{}".format(get_rank(), dest_rank) | |||||
op.addr, op.port = get_mm_server_addr() | |||||
op.rank_to = dest_rank | |||||
op.backend = _backend() | |||||
apply(op, inp) | |||||
def _remote_recv_nobackward( | |||||
src_rank: int, device: Optional[str] = None, inp=None, shape=None, dtype=None, | |||||
): | |||||
op = RemoteRecv() | |||||
op.key = "b{}->{}".format(src_rank, get_rank()) | |||||
if device is None: | |||||
device = get_default_device() | |||||
op.cn = device | |||||
if inp is None: | |||||
inp = Tensor(0, device=device) | |||||
assert shape is not None and dtype is not None | |||||
op.shape = shape | |||||
op.dtype = dtype | |||||
op.addr, op.port = get_mm_server_addr() | |||||
op.rank_from = src_rank | |||||
op.backend = _backend() | |||||
ret = apply(op, inp)[0] | |||||
return ret |
@@ -160,6 +160,13 @@ def init_process_group( | |||||
set_default_device("{}{}".format(device_type, device)) | set_default_device("{}{}".format(device_type, device)) | ||||
seed(int(time.time()) + rank) | seed(int(time.time()) + rank) | ||||
if backend == "nccl": | |||||
# init nccl env | |||||
from ..core._imperative_rt.common import init_nccl_env | |||||
group_barrier() | |||||
init_nccl_env(master_ip, _sd.mm_server_port, world_size, rank, 0) | |||||
def _set_machine_ranks(ranks) -> None: | def _set_machine_ranks(ranks) -> None: | ||||
global _sd | global _sd | ||||
@@ -8,6 +8,9 @@ | |||||
#include "megbrain/comp_node.h" | #include "megbrain/comp_node.h" | ||||
#include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
#include "megbrain/imperative/physical_tensor.h" | #include "megbrain/imperative/physical_tensor.h" | ||||
#if MGB_ENABLE_OPR_MM | |||||
#include "megbrain/opr/mm_handler.h" | |||||
#endif | |||||
#if MEGDNN_WITH_CUDA | #if MEGDNN_WITH_CUDA | ||||
#include "cuda_sm_gen.h" | #include "cuda_sm_gen.h" | ||||
@@ -46,6 +49,18 @@ void set_default_device(const std::string& device) { | |||||
default_device = device; | default_device = device; | ||||
} | } | ||||
void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root) { | |||||
#if MGB_ENABLE_OPR_MM | |||||
auto&& help = mgb::opr::BatchSendRecvHelper::getInstance(); | |||||
bool res = help->init(nranks, rank, ip, port, root); | |||||
auto p = help->get(std::string("init_all_cards")); | |||||
#else | |||||
mgb_throw( | |||||
MegBrainError, | |||||
"MegEngine compiled without MM opr, doesn't support init_nccl_env"); | |||||
#endif | |||||
} | |||||
std::string get_default_device() { | std::string get_default_device() { | ||||
return default_device; | return default_device; | ||||
} | } | ||||
@@ -252,6 +267,8 @@ void init_common(py::module m) { | |||||
m.def("what_is_xpu", | m.def("what_is_xpu", | ||||
[] { return CompNode::Locator::parse("xpux").to_physical().type; }); | [] { return CompNode::Locator::parse("xpux").to_physical().type; }); | ||||
m.def("init_nccl_env", &init_nccl_env); | |||||
init_npy_num_bfloat16(m); | init_npy_num_bfloat16(m); | ||||
init_npy_num_intbx(m); | init_npy_num_intbx(m); | ||||
init_dtypes(m); | init_dtypes(m); | ||||
@@ -8,3 +8,4 @@ void set_default_device(const std::string& device); | |||||
std::string get_default_device(); | std::string get_default_device(); | ||||
extern pybind11::handle py_comp_node_type; | extern pybind11::handle py_comp_node_type; | ||||
void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root); |
@@ -9,6 +9,7 @@ | |||||
#include "megbrain/imperative/transformations/dtype_promote.h" | #include "megbrain/imperative/transformations/dtype_promote.h" | ||||
#include "megbrain/imperative/transformations/eval.h" | #include "megbrain/imperative/transformations/eval.h" | ||||
#include "megbrain/imperative/transformations/format.h" | #include "megbrain/imperative/transformations/format.h" | ||||
#include "megbrain/imperative/transformations/group_comm.h" | |||||
#include "megbrain/imperative/transformations/lazy.h" | #include "megbrain/imperative/transformations/lazy.h" | ||||
#include "megbrain/imperative/transformations/scalar.h" | #include "megbrain/imperative/transformations/scalar.h" | ||||
#include "megbrain/imperative/transformations/symbol.h" | #include "megbrain/imperative/transformations/symbol.h" | ||||
@@ -947,6 +948,13 @@ void init_tensor(py::module m) { | |||||
m.def("enable_cupti", &cupti::enable); | m.def("enable_cupti", &cupti::enable); | ||||
m.def("disable_cupti", &cupti::disable); | m.def("disable_cupti", &cupti::disable); | ||||
m.def("cupti_available", &cupti::available); | m.def("cupti_available", &cupti::available); | ||||
static std::unique_ptr<CleanupGuard<>> group_comm_guard; | |||||
m.def("group_start", []() { | |||||
auto commtrans = std::make_shared<GroupCommTransformation>(); | |||||
group_comm_guard = transformations.register_at<Segment::GroupComm>(commtrans); | |||||
}); | |||||
m.def("group_end", []() { group_comm_guard.reset(); }); | |||||
m.def("sync", [channel]() { | m.def("sync", [channel]() { | ||||
if (channel->check_available()) { | if (channel->check_available()) { | ||||
channel->sync(); | channel->sync(); | ||||
@@ -16,6 +16,7 @@ struct TransformationManager { | |||||
public: | public: | ||||
enum Segment { | enum Segment { | ||||
ModuleTrace, | ModuleTrace, | ||||
GroupComm, | |||||
DTypePromote, | DTypePromote, | ||||
DimExpansion, | DimExpansion, | ||||
Format, | Format, | ||||
@@ -26,7 +27,7 @@ public: | |||||
Eval, | Eval, | ||||
}; | }; | ||||
std::array<std::vector<std::shared_ptr<Transformation>>, 9> segments; | |||||
std::array<std::vector<std::shared_ptr<Transformation>>, 10> segments; | |||||
private: | private: | ||||
template <Segment segment> | template <Segment segment> | ||||
@@ -237,3 +237,32 @@ def test_get_cuda_compute_capability(): | |||||
assert mge.device.get_cuda_compute_capability(dist.get_rank()) > 0 | assert mge.device.get_cuda_compute_capability(dist.get_rank()) > 0 | ||||
worker() | worker() | ||||
@pytest.mark.require_ngpu(3) | |||||
@pytest.mark.isolated_distributed | |||||
def test_batch_send_recv(): | |||||
import megengine.distributed.functional as DF | |||||
@dist.launcher(n_gpus=3) | |||||
def worker(): | |||||
rank = dist.get_rank() | |||||
dist.group_start() | |||||
for i in range(3): | |||||
tensor = mge.tensor(np.ones(10000)) * rank | |||||
if i == 2: | |||||
tensor *= i | |||||
DF._remote_send_nobackward(tensor, (rank + 1) % 3) | |||||
DF._remote_recv_nobackward( | |||||
src_rank=(rank + 1) % 3, dtype="float32", shape=(10000,) | |||||
) | |||||
DF._remote_send_nobackward(tensor, (rank - 1) % 3) | |||||
recv = DF._remote_recv_nobackward( | |||||
src_rank=(rank - 1) % 3, dtype="float32", shape=(10000,) | |||||
) | |||||
if i == 2: | |||||
recv2 = recv | |||||
dist.group_end() | |||||
np.testing.assert_equal(recv2.numpy(), (rank - 1) % 3 * 2 * np.ones(10000)) | |||||
worker() |
@@ -1,14 +1,19 @@ | |||||
#include "megbrain/imperative/ops/io_remote.h" | |||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
#include <algorithm> | |||||
#include <functional> | |||||
#include <numeric> | |||||
#include "../blob_manager_impl.h" | |||||
#include "../op_trait.h" | #include "../op_trait.h" | ||||
#include "megbrain/imperative/proxy_graph_detail.h" | #include "megbrain/imperative/proxy_graph_detail.h" | ||||
#include "megbrain/opr/io_remote.h" | #include "megbrain/opr/io_remote.h" | ||||
#include "megbrain/opr/megray_helper.h" | |||||
#include "megbrain/opr/mm_handler.h" | #include "megbrain/opr/mm_handler.h" | ||||
#endif // MGB_ENABLE_OPR_MM | #endif // MGB_ENABLE_OPR_MM | ||||
#include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
#include "megbrain/imperative/proxy_graph_detail.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
@@ -46,15 +51,164 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||||
recv.backend)); | recv.backend)); | ||||
} | } | ||||
TensorPtr megray_recv_tensor( | |||||
std::shared_ptr<MegRay::Communicator> megray_comm, TensorLayout& layout, | |||||
CompNode cn, uint32_t rank_from) { | |||||
DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(cn, layout); | |||||
auto megray_ctx = mgb::opr::get_megray_context(cn); | |||||
size_t data_size = layout.total_nr_elems(); | |||||
auto status = megray_comm->recv( | |||||
out.raw_ptr(), data_size, mgb::opr::get_megray_dtype(layout.dtype), | |||||
rank_from, megray_ctx); | |||||
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed"); | |||||
return Tensor::make(out); | |||||
} | |||||
void megray_send_tensor( | |||||
std::shared_ptr<MegRay::Communicator> megray_comm, const TensorPtr& src, | |||||
uint32_t rank_to) { | |||||
auto&& tensor = src->dev_tensor(); | |||||
auto&& ishp = src->shape(); | |||||
size_t data_size = ishp.total_nr_elems(); | |||||
auto megray_ctx = mgb::opr::get_megray_context(src->comp_node()); | |||||
auto status = megray_comm->send( | |||||
src->dev_tensor().raw_ptr(), data_size, | |||||
mgb::opr::get_megray_dtype(src->layout().dtype), rank_to, megray_ctx); | |||||
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); | |||||
} | |||||
TensorLayout create_layout(const std::vector<int32_t>& shape, DType dtype) { | |||||
TensorShape tshape; | |||||
tshape.ndim = shape.size(); | |||||
mgb_assert(tshape.ndim <= TensorLayout::MAX_NDIM); | |||||
std::copy(shape.begin(), shape.end(), tshape.shape); | |||||
return TensorLayout(tshape, dtype); | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible_remote_send( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { | |||||
auto&& dtype = input_descs[0].layout.dtype; | |||||
auto&& cn = input_descs[0].comp_node; | |||||
return {{{TensorLayout({0}, dtype), cn}}, true}; | |||||
} | |||||
SmallVector<TensorPtr> apply_on_physical_tensor_remote_send( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
auto&& op = def.cast_final_safe<RemoteSend>(); | |||||
auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get( | |||||
std::string("init_all_cards")); | |||||
if (!megray_comm) { | |||||
return proxy_graph_detail::apply_on_physical_tensor( | |||||
def, inputs, output_descs, validated); | |||||
} | |||||
mgb_assert(megray_comm != nullptr); | |||||
megray_send_tensor(megray_comm, inputs[0], op.rank_to); | |||||
TensorLayout layout({0}, inputs[0]->dtype()); | |||||
DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag( | |||||
inputs[0]->comp_node(), layout); | |||||
return {Tensor::make(out)}; | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible_remote_recv( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { | |||||
auto& op = def.cast_final_safe<RemoteRecv>(); | |||||
return {{{create_layout(op.shape, op.dtype), op.cn}}, true}; | |||||
} | |||||
SmallVector<TensorPtr> apply_on_physical_tensor_remote_recv( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
auto&& op = def.cast_final_safe<RemoteRecv>(); | |||||
auto layout = create_layout(op.shape, op.dtype); | |||||
auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get( | |||||
std::string("init_all_cards")); | |||||
if (!megray_comm) { | |||||
return proxy_graph_detail::apply_on_physical_tensor( | |||||
def, inputs, output_descs, validated); | |||||
} | |||||
auto&& out = megray_recv_tensor(megray_comm, layout, op.cn, op.rank_from); | |||||
return {out}; | |||||
} | |||||
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||||
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size()); | |||||
for (size_t i; i < inputs.size(); i++) { | |||||
layout_checker[i] = [](const TensorLayout& layout) { | |||||
return layout.is_contiguous(); | |||||
}; | |||||
} | |||||
return layout_checker; | |||||
} | |||||
OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend) | OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend) | ||||
.apply_on_var_node(apply_on_var_node_remote_send) | .apply_on_var_node(apply_on_var_node_remote_send) | ||||
.apply_on_physical_tensor(apply_on_physical_tensor_remote_send) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible_remote_send) | |||||
.get_input_layout_constraint(get_input_layout_constraint) | |||||
.fallback(); | .fallback(); | ||||
OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) | OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) | ||||
.apply_on_var_node(apply_on_var_node_remote_recv) | .apply_on_var_node(apply_on_var_node_remote_recv) | ||||
.apply_on_physical_tensor(apply_on_physical_tensor_remote_recv) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible_remote_recv) | |||||
.get_input_layout_constraint(get_input_layout_constraint) | |||||
.fallback(); | .fallback(); | ||||
} // anonymous namespace | |||||
SmallVector<TensorPtr> apply_on_physical_tensor_batch_send_recv( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
auto&& op = def.cast_final_safe<BatchSendRecvOp>(); | |||||
auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get( | |||||
std::string("init_all_cards")); | |||||
mgb_assert(megray_comm != nullptr); | |||||
megray_comm->group_start(); | |||||
SmallVector<TensorPtr> outputs; | |||||
size_t ind = 0; | |||||
for (auto&& op_ : op.op_list) { | |||||
if (op_->same_type<RemoteSend>()) { | |||||
auto&& send_op = op_->cast_final_safe<RemoteSend>(); | |||||
auto&& tensor = inputs[ind]; | |||||
megray_send_tensor(megray_comm, tensor, send_op.rank_to); | |||||
ind++; | |||||
} else { | |||||
mgb_assert(op_->same_type<RemoteRecv>()); | |||||
auto&& recv_op = op_->cast_final_safe<RemoteRecv>(); | |||||
auto layout = create_layout(recv_op.shape, recv_op.dtype); | |||||
auto&& out = megray_recv_tensor( | |||||
megray_comm, layout, recv_op.cn, recv_op.rank_from); | |||||
outputs.push_back(out); | |||||
} | |||||
} | |||||
megray_comm->group_end(); | |||||
return outputs; | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> | |||||
infer_output_attrs_fallible_batch_send_recv( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { | |||||
auto& op = def.cast_final_safe<BatchSendRecvOp>(); | |||||
SmallVector<LogicalTensorDesc> output_descs; | |||||
for (auto&& op_ : op.op_list) { | |||||
if (op_->same_type<RemoteRecv>()) { | |||||
auto&& recv_op = op_->cast_final_safe<RemoteRecv>(); | |||||
output_descs.push_back( | |||||
{create_layout(recv_op.shape, recv_op.dtype), recv_op.cn}); | |||||
} | |||||
} | |||||
return {output_descs, true}; | |||||
} | |||||
OP_TRAIT_REG(BatchSendRecvOp, BatchSendRecvOp) | |||||
.apply_on_physical_tensor(apply_on_physical_tensor_batch_send_recv) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible_batch_send_recv) | |||||
.get_input_layout_constraint(get_input_layout_constraint) | |||||
.fallback(); | |||||
} // namespace | |||||
#endif // MGB_ENABLE_OPR_MM | #endif // MGB_ENABLE_OPR_MM | ||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchSendRecvOp); | |||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | } // namespace mgb |
@@ -0,0 +1,67 @@ | |||||
#include "megbrain/imperative/transformations/group_comm.h" | |||||
#include "megbrain/imperative/blob_manager.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/imperative/ops/io_remote.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
ValueRefList GroupCommTransformation::apply_transformation( | |||||
const Operator& op, Span<ValueRef> inputs) { | |||||
for (auto inp : inputs) { | |||||
mgb_assert( | |||||
!inp.is(m_value_type), "Can not use PlaceholderValue as apply input"); | |||||
} | |||||
if (auto* apply_op = op.as<ApplyOp>()) { | |||||
if (apply_op->op().same_type<RemoteSend>()) { | |||||
auto&& send_op = apply_op->op().cast_final_safe<RemoteSend>(); | |||||
if (send_op.key[0] == 'b') { | |||||
send_inputs.push_back(inputs[0]); | |||||
record_ops.push_back(send_op.shared_from_this()); | |||||
return {}; | |||||
} | |||||
} | |||||
if (apply_op->op().same_type<RemoteRecv>()) { | |||||
auto&& recv_op = apply_op->op().cast_final_safe<RemoteRecv>(); | |||||
if (recv_op.key[0] == 'b') { | |||||
record_ops.push_back(recv_op.shared_from_this()); | |||||
auto rst = m_value_type.make(); | |||||
recv_tensors.push_back(rst); | |||||
auto outputs = ValueRefList(1); | |||||
outputs[0] = rst; | |||||
return outputs; | |||||
} | |||||
} | |||||
return imperative::apply(op, inputs); | |||||
} else { | |||||
return imperative::apply(op, inputs); | |||||
} | |||||
} | |||||
ValueRefList GroupCommTransformation::execute_batch_op() { | |||||
auto batch_op = BatchSendRecvOp::make(record_ops); | |||||
auto outputs = imperative::apply(*batch_op, send_inputs); | |||||
return outputs; | |||||
} | |||||
void GroupCommTransformation::on_unregister() noexcept { | |||||
auto rst = execute_batch_op(); | |||||
mgb_assert(rst.size() == recv_tensors.size()); | |||||
for (size_t i = 0; i < rst.size(); i++) { | |||||
auto v = recv_tensors[i].lock(); | |||||
if (v != ValueRef::nil) { | |||||
v.reset(rst[i]); | |||||
} | |||||
} | |||||
} | |||||
GroupCommTransformation::~GroupCommTransformation() { | |||||
for (auto&& recv : recv_tensors) { | |||||
mgb_assert( | |||||
recv.lock() == ValueRef::nil, | |||||
"Some PlaceholderValues are not reset after GroupCommTransformation " | |||||
"destroyed!"); | |||||
}; | |||||
} | |||||
} // namespace imperative | |||||
} // namespace mgb |
@@ -0,0 +1,11 @@ | |||||
#pragma once | |||||
#include "megbrain/imperative/op_def.h" | |||||
namespace mgb::imperative { | |||||
struct BatchSendRecvOp final : OpDefImplBase<BatchSendRecvOp> { | |||||
SmallVector<std::shared_ptr<OpDef>> op_list; | |||||
BatchSendRecvOp() = default; | |||||
BatchSendRecvOp(SmallVector<std::shared_ptr<OpDef>> op_list) : op_list{op_list} {} | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
}; | |||||
} // namespace mgb::imperative |
@@ -0,0 +1,44 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/scalar.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/imperative/basic_operators.h" | |||||
#include "megbrain/imperative/dispatch.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
namespace mgb::imperative { | |||||
class PlaceholderValue final : public ObjectValue<PlaceholderValue> { | |||||
public: | |||||
std::string to_string() const override { return ssprintf("PlaceholderValue"); } | |||||
void clear() override {} | |||||
}; | |||||
class GroupCommTransformation final : public Transformation { | |||||
private: | |||||
SmallVector<ValueRef> send_inputs; | |||||
std::vector<PlaceholderValue::weak_ref_t> recv_tensors; | |||||
SmallVector<std::shared_ptr<OpDef>> record_ops; | |||||
ObjectType<PlaceholderValue> m_value_type{"PlaceholderValue"}; | |||||
public: | |||||
GroupCommTransformation() = default; | |||||
ValueRefList apply_transformation( | |||||
const Operator& op, Span<ValueRef> inputs) override; | |||||
ValueRefList execute_batch_op(); | |||||
ValueRef unwrap(ValueRef value) override { return value; } | |||||
std::string name() const override { return "GroupCommTransformation"; } | |||||
void on_unregister() noexcept override; | |||||
~GroupCommTransformation(); | |||||
}; | |||||
} // namespace mgb::imperative |
@@ -1,4 +1,5 @@ | |||||
#include "./helper.h" | #include "./helper.h" | ||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
#include "megbrain/opr/mm_handler.h" | #include "megbrain/opr/mm_handler.h" | ||||
@@ -47,7 +48,4 @@ TEST(TestImperative, IORemote) { | |||||
t0.join(); | t0.join(); | ||||
t1.join(); | t1.join(); | ||||
} | } | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
// ./imperative_test --gtest_filter TestIORemote |
@@ -151,6 +151,28 @@ void GroupManager::bcast_addr( | |||||
} | } | ||||
} | } | ||||
void GroupManager::bcast_nccluniqueid( | |||||
const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||||
uint32_t root) { | |||||
std::unique_lock<std::mutex> lk{m_key2nccl_id_mtx}; | |||||
if (rank == root) { | |||||
m_key2nccl_id[key] = id; | |||||
} | |||||
m_key2nccl_id_size[key]++; | |||||
if (m_key2nccl_id_size[key] == size) { | |||||
m_key2nccl_id_flag[key] = true; | |||||
m_bcast_cv.notify_all(); | |||||
} else { | |||||
m_bcast_cv.wait(lk, [&] { return m_key2nccl_id_flag.count(key) > 0; }); | |||||
} | |||||
id = m_key2nccl_id[key]; | |||||
m_key2nccl_id_size[key]--; | |||||
if (m_key2nccl_id_size[key] == 0) { | |||||
m_key2nccl_id.erase(key); | |||||
m_key2nccl_id_flag.erase(key); | |||||
} | |||||
} | |||||
void GroupManager::set_output_shape(const std::string& key, const TensorShape& shape) { | void GroupManager::set_output_shape(const std::string& key, const TensorShape& shape) { | ||||
auto&& group = get_group(key); | auto&& group = get_group(key); | ||||
group.set_output_shape(key, shape); | group.set_output_shape(key, shape); | ||||
@@ -67,6 +67,15 @@ void MegRayCommBuilder::emplace( | |||||
m_megray_comms.emplace(hash, comm); | m_megray_comms.emplace(hash, comm); | ||||
} | } | ||||
void MegRayCommBuilder::remove( | |||||
uint64_t hash, std::shared_ptr<MegRay::Communicator> comm) { | |||||
std::unique_lock<std::mutex> lk(m_map_mtx); | |||||
auto it = m_megray_comms.find(hash); | |||||
if (it != m_megray_comms.end()) { | |||||
m_megray_comms.erase(hash); | |||||
} | |||||
} | |||||
std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | ||||
uint64_t hash, std::string key, uint32_t size, uint32_t rank, | uint64_t hash, std::string key, uint32_t size, uint32_t rank, | ||||
MegRay::Backend backend, std::shared_ptr<mgb::opr::GroupClient> group_client) { | MegRay::Backend backend, std::shared_ptr<mgb::opr::GroupClient> group_client) { | ||||
@@ -104,5 +113,3 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | |||||
MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr; | MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr; | ||||
std::mutex MegRayCommBuilder::sm_instance_mtx; | std::mutex MegRayCommBuilder::sm_instance_mtx; | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -45,6 +45,7 @@ public: | |||||
RUNSERVER(get_output_shape); | RUNSERVER(get_output_shape); | ||||
RUNSERVER(bcast_addr); | RUNSERVER(bcast_addr); | ||||
RUNSERVER(group_barrier); | RUNSERVER(group_barrier); | ||||
RUNSERVER(bcast_nccluniqueid); | |||||
mgb_assert(false, "invalid rpc request"); | mgb_assert(false, "invalid rpc request"); | ||||
} | } | ||||
@@ -53,6 +54,7 @@ private: | |||||
void set_output_shape(void* input_ptr, size_t input_len, std::string* output); | void set_output_shape(void* input_ptr, size_t input_len, std::string* output); | ||||
void get_output_shape(void* input_ptr, size_t input_len, std::string* output); | void get_output_shape(void* input_ptr, size_t input_len, std::string* output); | ||||
void bcast_addr(void* input_ptr, size_t input_len, std::string* output); | void bcast_addr(void* input_ptr, size_t input_len, std::string* output); | ||||
void bcast_nccluniqueid(void* input_ptr, size_t input_len, std::string* output); | |||||
void group_barrier(void* input_ptr, size_t input_len, std::string* output); | void group_barrier(void* input_ptr, size_t input_len, std::string* output); | ||||
private: | private: | ||||
@@ -116,6 +118,15 @@ void GroupServerProxy::bcast_addr( | |||||
rsp.SerializeToString(output); | rsp.SerializeToString(output); | ||||
} | } | ||||
void GroupServerProxy::bcast_nccluniqueid( | |||||
void* input_ptr, size_t input_len, std::string* output) { | |||||
INFO_INIT(mm_handler, BcastNcclUniqueId); | |||||
std::string id = req.id(); | |||||
m_mgr.bcast_nccluniqueid(req.key(), id, req.size(), req.rank(), req.root()); | |||||
rsp.set_id(id); | |||||
rsp.SerializeToString(output); | |||||
} | |||||
void GroupServerProxy::group_barrier( | void GroupServerProxy::group_barrier( | ||||
void* input_ptr, size_t input_len, std::string* output) { | void* input_ptr, size_t input_len, std::string* output) { | ||||
INFO_INIT(mm_handler, GroupBarrier); | INFO_INIT(mm_handler, GroupBarrier); | ||||
@@ -201,6 +212,19 @@ void GroupClientProxy::bcast_addr( | |||||
port = rsp.port(); | port = rsp.port(); | ||||
} | } | ||||
void GroupClientProxy::bcast_nccluniqueid( | |||||
const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||||
uint32_t root) { | |||||
INFO_INIT(mm_handler, bcast_nccluniqueid, BcastNcclUniqueId); | |||||
req.set_id(id.data(), id.size()); | |||||
req.set_key(key.data(), key.size()); | |||||
req.set_size(size); | |||||
req.set_rank(rank); | |||||
req.set_root(root); | |||||
SOLVE_REQUEST(func_name, req, rsp); | |||||
id = rsp.id(); | |||||
} | |||||
uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | ||||
INFO_INIT(mm_handler, group_barrier, GroupBarrier); | INFO_INIT(mm_handler, group_barrier, GroupBarrier); | ||||
req.set_size(size); | req.set_size(size); | ||||
@@ -209,6 +233,40 @@ uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | |||||
return rsp.size(); | return rsp.size(); | ||||
} | } | ||||
std::shared_ptr<MegRay::Communicator> BatchSendRecvHelper::get(std::string&& key) { | |||||
auto ptr = megray_comm_cache.find(key); | |||||
if (ptr != megray_comm_cache.end()) { | |||||
return megray_comm_cache[key]; | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
std::unordered_map<std::string, std::shared_ptr<MegRay::Communicator>> | |||||
BatchSendRecvHelper::megray_comm_cache{}; | |||||
bool BatchSendRecvHelper::init( | |||||
int nranks, int rank, std::string ip, int port, int root) { | |||||
auto megray_comm = | |||||
MegRay::get_communicator(nranks, rank, MegRay::Backend::MEGRAY_NCCL); | |||||
auto group_client = | |||||
std::make_shared<opr::GroupClientProxy>(ssprintf("%s:%d", ip.data(), port)); | |||||
auto cb = [=](char* nccl_buffer, size_t len) { | |||||
std::string id; | |||||
id.resize(128); | |||||
if (rank == root) { | |||||
memcpy(id.data(), nccl_buffer, len); | |||||
} | |||||
group_client->bcast_nccluniqueid("init_all_cards", id, nranks, rank, root); | |||||
if (rank != root) { | |||||
memcpy(nccl_buffer, id.data(), len); | |||||
} | |||||
}; | |||||
megray_comm->init(cb); | |||||
return megray_comm_cache.insert({std::string("init_all_cards"), megray_comm}) | |||||
.second; | |||||
} | |||||
#undef INFO_INIT | #undef INFO_INIT | ||||
#undef SOLVE_REQUEST | #undef SOLVE_REQUEST | ||||
@@ -77,6 +77,11 @@ public: | |||||
std::string& master_ip, int& port, const std::string& key, uint32_t size, | std::string& master_ip, int& port, const std::string& key, uint32_t size, | ||||
uint32_t rank, uint32_t root); | uint32_t rank, uint32_t root); | ||||
//! bcast uid | |||||
void bcast_nccluniqueid( | |||||
const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||||
uint32_t root); | |||||
//! Set output shape of this key | //! Set output shape of this key | ||||
void set_output_shape(const std::string& key, const TensorShape& shape); | void set_output_shape(const std::string& key, const TensorShape& shape); | ||||
@@ -101,6 +106,12 @@ private: | |||||
std::mutex m_key2addr_mtx; | std::mutex m_key2addr_mtx; | ||||
std::condition_variable m_bcast_cv; | std::condition_variable m_bcast_cv; | ||||
//! key -> ncclid | |||||
std::unordered_map<std::string, std::string> m_key2nccl_id; | |||||
std::unordered_map<std::string, uint32_t> m_key2nccl_id_size; | |||||
std::unordered_map<std::string, bool> m_key2nccl_id_flag; | |||||
std::mutex m_key2nccl_id_mtx; | |||||
//! barrier | //! barrier | ||||
uint32_t m_barrier_size; | uint32_t m_barrier_size; | ||||
std::set<uint32_t> m_barrier_set; | std::set<uint32_t> m_barrier_set; | ||||
@@ -128,6 +139,10 @@ public: | |||||
std::string& master_ip, int& port, const std::string& key, uint32_t size, | std::string& master_ip, int& port, const std::string& key, uint32_t size, | ||||
uint32_t rank, uint32_t root) = 0; | uint32_t rank, uint32_t root) = 0; | ||||
virtual void bcast_nccluniqueid( | |||||
const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||||
uint32_t root) = 0; | |||||
virtual void set_output_shape(const std::string& key, const TensorShape& shape) = 0; | virtual void set_output_shape(const std::string& key, const TensorShape& shape) = 0; | ||||
virtual TensorShape get_output_shape(const std::string& key) = 0; | virtual TensorShape get_output_shape(const std::string& key) = 0; | ||||
@@ -23,6 +23,7 @@ class MegRayCommBuilder { | |||||
private: | private: | ||||
bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm); | bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm); | ||||
void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | ||||
void remove(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | |||||
std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; | std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; | ||||
std::mutex m_map_mtx; | std::mutex m_map_mtx; | ||||
@@ -39,6 +39,10 @@ public: | |||||
std::string& master_ip, int& port, const std::string& key, uint32_t size, | std::string& master_ip, int& port, const std::string& key, uint32_t size, | ||||
uint32_t rank, uint32_t root) override; | uint32_t rank, uint32_t root) override; | ||||
void bcast_nccluniqueid( | |||||
const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||||
uint32_t root) override; | |||||
void set_output_shape(const std::string& key, const TensorShape& shape) override; | void set_output_shape(const std::string& key, const TensorShape& shape) override; | ||||
TensorShape get_output_shape(const std::string& key) override; | TensorShape get_output_shape(const std::string& key) override; | ||||
@@ -52,6 +56,34 @@ private: | |||||
void* m_stub; | void* m_stub; | ||||
}; | }; | ||||
template <typename T> | |||||
class ProcessGlobal { // thread safe | |||||
public: | |||||
template <class... Args> | |||||
static std::shared_ptr<T>& getInstance(Args&&... args) { | |||||
static auto instance = std::make_shared<T>(std::forward<Args>(args)...); | |||||
return instance; | |||||
} | |||||
protected: | |||||
template <class... Args> | |||||
ProcessGlobal(Args&&... args); | |||||
ProcessGlobal() = default; | |||||
public: | |||||
ProcessGlobal(ProcessGlobal const&) = delete; | |||||
void operator=(ProcessGlobal const&) = delete; | |||||
}; | |||||
class BatchSendRecvHelper : public ProcessGlobal<BatchSendRecvHelper> { | |||||
static std::unordered_map<std::string, std::shared_ptr<MegRay::Communicator>> | |||||
megray_comm_cache; | |||||
public: | |||||
std::shared_ptr<MegRay::Communicator> get(std::string&&); | |||||
bool init(int nranks, int rank, std::string ip, int port, int root); | |||||
}; | |||||
/* ======================== ZmqRpcServerMgr ========================== */ | /* ======================== ZmqRpcServerMgr ========================== */ | ||||
int create_zmqrpc_server(const std::string& server_addr, int port); | int create_zmqrpc_server(const std::string& server_addr, int port); | ||||
@@ -60,5 +92,3 @@ int create_zmqrpc_server(const std::string& server_addr, int port); | |||||
} // namespace mgb | } // namespace mgb | ||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -30,6 +30,18 @@ message BcastAddrResponse { | |||||
int32 port = 2; | int32 port = 2; | ||||
} | } | ||||
message BcastNcclUniqueIdRequest{ | |||||
string key = 1; | |||||
bytes id = 2; | |||||
uint32 size =3 ; | |||||
uint32 rank = 4; | |||||
uint32 root =5; | |||||
} | |||||
message BcastNcclUniqueIdResponse{ | |||||
bytes id = 1; | |||||
} | |||||
message SetOutputShapeRequest { | message SetOutputShapeRequest { | ||||
string key = 1; | string key = 1; | ||||
TensorShape shape = 2; | TensorShape shape = 2; | ||||
@@ -26,6 +26,12 @@ public: | |||||
return m_mgr.bcast_addr(master_ip, port, key, size, rank, root); | return m_mgr.bcast_addr(master_ip, port, key, size, rank, root); | ||||
} | } | ||||
void bcast_nccluniqueid( | |||||
const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||||
uint32_t root) override { | |||||
return m_mgr.bcast_nccluniqueid(key, id, size, rank, root); | |||||
} | |||||
void set_output_shape(const std::string& key, const TensorShape& shape) override { | void set_output_shape(const std::string& key, const TensorShape& shape) override { | ||||
m_mgr.set_output_shape(key, shape); | m_mgr.set_output_shape(key, shape); | ||||
} | } | ||||