
refactor(mge/distribute): use is_root (and rank) instead of rank and root at collective comm

GitOrigin-RevId: dccdb71553
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent commit c7e6c658fd
17 changed files with 286 additions and 228 deletions
  1. +22 -45  python_module/megengine/distributed/functional.py
  2. +4  -5   python_module/megengine/distributed/helper.py
  3. +12 -2   python_module/megengine/optimizer/optimizer.py
  4. +14 -14  python_module/src/cpp/opr_defs.cpp
  5. +8  -8   python_module/src/cpp/opr_defs.h
  6. +2  -2   python_module/test/unit/distributed/test_functional.py
  7. +26 -23  src/opr-mm/impl/collective_comm.cpp
  8. +4  -4   src/opr-mm/impl/collective_comm.oprdecl
  9. +61 -25  src/opr-mm/impl/group_manager.cpp
  10. +9  -10  src/opr-mm/impl/io_remote.cpp
  11. +15 -10  src/opr-mm/impl/mm_handler.cpp
  12. +26 -23  src/opr-mm/include/megbrain/opr/collective_comm.h
  13. +29 -13  src/opr-mm/include/megbrain/opr/group_manager.h
  14. +5  -2   src/opr-mm/include/megbrain/opr/mm_handler.h
  15. +7  -4   src/opr-mm/proto/mm_handler.proto
  16. +36 -34  src/opr-mm/test/collective_comm.cpp
  17. +6  -4   src/opr-mm/test/io_remote.cpp

python_module/megengine/distributed/functional.py (+22 -45)

@@ -26,25 +26,17 @@ def reduce_sum(
tensor: Tensor,
key: str,
nr_ranks: Optional[int] = None,
rank: Optional[int] = None,
root: Optional[int] = 0,
is_root: Optional[bool] = None,
) -> Tensor:
"""Create reduce_sum operator for collective communication

:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param root: rank of root node, use 0 as default
:param is_root: whether this is a root node
"""
return _collective_comm(
tensor,
key,
CollParam.Mode.REDUCE_SUM,
nr_ranks,
rank,
root,
device=tensor.device,
tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
)


@@ -52,24 +44,21 @@ def broadcast(
tensor: Tensor,
key: str,
nr_ranks: Optional[int] = None,
rank: Optional[int] = None,
root: Optional[int] = 0,
is_root: Optional[bool] = None,
) -> Tensor:
"""Create broadcast operator for collective communication

:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param root: rank of root node, use 0 as default
:param is_root: whether this is a root node
"""
if key is None:
key = tensor._symvar.name
if is_root is None:
is_root = get_rank() == 0

if rank is None:
rank = get_rank()

if rank == root:
if is_root:
inp = tensor
else:
inp = tensor._symvar.owner_graph
@@ -79,8 +68,7 @@ def broadcast(
key,
CollParam.Mode.BROADCAST,
nr_ranks,
rank,
root,
is_root,
dtype=tensor.dtype,
device=tensor.device,
)
@@ -94,9 +82,9 @@ def all_gather(
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param rank: rank of this node
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank, 0)
return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank)


def reduce_scatter_sum(
@@ -107,69 +95,58 @@ def reduce_scatter_sum(
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param rank: rank of this node
"""
return _collective_comm(
tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank
tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank,
)


def all_reduce_sum(
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
"""Create all_reduce_sum operator for collective communication

:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks, rank)
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks)


def all_reduce_max(
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
"""Create all_reduce_max operator for collective communication

:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks, rank)
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks)


def all_reduce_min(
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
"""Create all_reduce_min operator for collective communication

:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks, rank)
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks)


def bcast_param(
inp: Union[Buffer, Parameter],
key: str,
nr_ranks: Optional[int] = None,
rank: Optional[int] = None,
root: Optional[int] = 0,
is_root: Optional[bool] = None,
) -> None:
"""Broadcast parameters among devices

:param inp: input Buffer or Parameter to be synchronized
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param root: rank of root node, use 0 as default
:param is_root: whether this is a root node
"""
if not is_distributed():
return
assert isinstance(inp, (Buffer, Parameter))
bcast_res = broadcast(inp, key, nr_ranks, rank, root)
bcast_res = broadcast(inp, key, nr_ranks, is_root)
add_update(inp, bcast_res, alpha=0)
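
A minimal usage sketch of the refactored Python API, assuming a process group has already been initialized on every worker (setup is outside this commit; the tests below use an _init_process_group_wrapper helper for it):

    import megengine.distributed as dist
    from megengine import tensor

    x = tensor([1.0, 2.0])

    # Root is now a per-process boolean rather than a root rank.
    y = dist.functional.broadcast(x, "bcast_key", is_root=(dist.get_rank() == 0))

    # The all_reduce_* family no longer takes a rank at all; ranks are
    # settled by the group server when the -1 sentinel is registered.
    s = dist.functional.all_reduce_sum(x, "sum_key")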

python_module/megengine/distributed/helper.py (+4 -5)

@@ -19,8 +19,8 @@ def collective_comm_symvar(
key: str,
op: CollParam.Mode,
nr_ranks: Optional[int] = None,
is_root: Optional[bool] = None,
rank: Optional[int] = None,
root: Optional[int] = 0,
dtype: Optional[type] = None,
device: Optional[mgb.CompNode] = None,
comp_graph: Optional[mgb.CompGraph] = None,
@@ -31,8 +31,7 @@ def collective_comm_symvar(
:param key: unique identifier for collective communication
:param op: mode of collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param root: rank of root node, use 0 as default
:param is_root: whether this node is the root node
:param dtype: output data type, use dtype of inp as default
:param device: output comp node, use comp node of inp as default
:param comp_graph: output comp graph, use comp graph of inp as default
@@ -41,8 +40,8 @@ def collective_comm_symvar(
inp,
key=str(key),
nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
rank=rank if rank is not None else get_rank(),
root=root,
is_root=is_root if is_root is not None else (get_rank() == 0),
rank=rank if rank is not None else -1,
server_addr=get_master_ip(),
port=get_master_port(),
param=CollParam(mode=op),
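
The defaulting above is the crux of the refactor; restated as a tiny runnable sketch (resolve_defaults is a hypothetical helper, not part of the module):

    def resolve_defaults(is_root, rank, my_rank):
        # is_root falls back to "am I rank 0?", decided locally
        if is_root is None:
            is_root = my_rank == 0
        # rank falls back to the -1 sentinel, which asks the group
        # server to assign one (see group_manager.cpp below)
        if rank is None:
            rank = -1
        return is_root, rank

    assert resolve_defaults(None, None, my_rank=0) == (True, -1)
    assert resolve_defaults(None, 3, my_rank=1) == (False, 3)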


python_module/megengine/optimizer/optimizer.py (+12 -2)

@@ -17,7 +17,13 @@ import numpy as np
from .._internal.config import opr_priority_scope
from ..core import Buffer, Parameter, Tensor, TensorDict
from ..core.graph import get_default_graph
from ..distributed import all_reduce_sum, bcast_param, get_world_size, is_distributed
from ..distributed import (
all_reduce_sum,
bcast_param,
get_rank,
get_world_size,
is_distributed,
)
from ..distributed.util import get_group_id
from ..functional import add_update
from ..functional import grad as grad_func
@@ -222,7 +228,11 @@ class Optimizer(metaclass=ABCMeta):
def bcast_param(self):
for group in self.param_groups:
for param in group["params"]:
bcast_param(param, "bcast_param_" + str(get_group_id()))
bcast_param(
param,
"bcast_param_" + str(get_group_id()),
is_root=(get_rank() == 0),
)

def state_dict(self) -> Dict:
r"""Export the optimizer state.


python_module/src/cpp/opr_defs.cpp (+14 -14)

@@ -93,10 +93,9 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
}

SymbolVar _Opr::collective_comm_with_input(
SymbolVar inpvar, const std::string& key,
const size_t nr_devices, const uint32_t rank, const uint32_t root,
const std::string& server_addr, const int port,
PyObject* params, PyObject* dtype,
SymbolVar inpvar, const std::string& key, const size_t nr_devices,
const bool is_root, const int rank, const std::string& server_addr,
const int port, PyObject* params, PyObject* dtype,
const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable) {
SymbolVarArray inputs(1, inpvar);
@@ -111,15 +110,15 @@ SymbolVar _Opr::collective_comm_with_input(
if (dtype != Py_None) {
_dtype = npy::dtype_np2mgb(dtype);
}
return CollectiveComm::make(inputs, graph, key, nr_devices, rank, root, group_mgr,
dev_buffer_arr, param, _dtype, backend, config, disable.get_val())[0];
return CollectiveComm::make(inputs, graph, key, nr_devices, is_root, rank,
group_mgr, dev_buffer_arr, param, _dtype,
backend, config, disable.get_val())[0];
}

SymbolVar _Opr::collective_comm_without_input(
CompGraph& cg, const std::string& key,
const size_t nr_devices, const uint32_t rank, const uint32_t root,
const std::string& server_addr, const int port,
PyObject* params, PyObject* dtype,
CompGraph& cg, const std::string& key, const size_t nr_devices,
const bool is_root, const int rank, const std::string& server_addr,
const int port, PyObject* params, PyObject* dtype,
const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable) {
SymbolVarArray inputs;
@@ -134,8 +133,9 @@ SymbolVar _Opr::collective_comm_without_input(
if (dtype != Py_None) {
_dtype = npy::dtype_np2mgb(dtype);
}
return CollectiveComm::make(inputs, &graph, key, nr_devices, rank, root, group_mgr,
dev_buffer_arr, param, _dtype, backend, config, disable.get_val())[0];
return CollectiveComm::make(inputs, &graph, key, nr_devices, is_root, rank,
group_mgr, dev_buffer_arr, param, _dtype,
backend, config, disable.get_val())[0];
}

#else
@@ -172,7 +172,7 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,

SymbolVar _Opr::collective_comm_with_input(
SymbolVar inpvar, const std::string& key,
const size_t nr_devices, const uint32_t rank, const uint32_t root,
const size_t nr_devices, const bool is_root, const int rank,
const std::string& server_addr, const int port, PyObject* params,
PyObject* dtype, const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable) {
@@ -181,7 +181,7 @@ SymbolVar _Opr::collective_comm_with_input(

SymbolVar _Opr::collective_comm_without_input(
CompGraph& cg, const std::string& key,
const size_t nr_devices, const uint32_t rank, const uint32_t root,
const size_t nr_devices, const bool is_root, const int rank,
const std::string& server_addr, const int port, PyObject* params,
PyObject* dtype, const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable) {


python_module/src/cpp/opr_defs.h (+8 -8)

@@ -94,17 +94,17 @@ static SymbolVar remote_recv(const std::string& server_addr, const int port,

static SymbolVar collective_comm_with_input(
SymbolVar inpvar, const std::string& key, const size_t nr_devices,
const uint32_t rank, const uint32_t root, const std::string& server_addr,
const int port, PyObject* params, PyObject* dtype,
const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable);
const bool is_root, const int rank, const std::string& server_addr, const int port,
PyObject* params, PyObject* dtype, const std::string& backend,
SharedND* output_buf, const OperatorNodeConfig& config,
const SharedScalar& disable);

static SymbolVar collective_comm_without_input(
CompGraph& graph, const std::string& key, const size_t nr_devices,
const uint32_t rank, const uint32_t root, const std::string& server_addr,
const int port, PyObject* params, PyObject* dtype,
const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable);
const bool is_root, const int rank, const std::string& server_addr, const int port,
PyObject* params, PyObject* dtype, const std::string& backend,
SharedND* output_buf, const OperatorNodeConfig& config,
const SharedScalar& disable);

// misc
static SymbolVarArray extern_c_opr_placeholder(


python_module/test/unit/distributed/test_functional.py (+2 -2)

@@ -102,7 +102,7 @@ def test_all_gather():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.all_gather(inp, "x")
output = dist.functional.all_gather(inp, "x", rank=rank)
assert np.allclose(output.numpy(), expect)

def check(shape, backend):
@@ -135,7 +135,7 @@ def test_reduce_scatter_sum():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.reduce_scatter_sum(inp, "x")
output = dist.functional.reduce_scatter_sum(inp, "x", rank=rank)
assert np.allclose(output.numpy(), expect)

def check(shape, backend):


src/opr-mm/impl/collective_comm.cpp (+26 -23)

@@ -368,8 +368,8 @@ CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {

CollectiveComm::CollectiveComm(
VarNodeArray inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const uint32_t rank,
const uint32_t root, std::shared_ptr<GroupClient> group_client,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const OperatorNodeConfig& config,
@@ -380,9 +380,9 @@ CollectiveComm::CollectiveComm(
m_backend(backend),
m_group_client{std::move(group_client)},
m_nr_devices(nr_devices),
m_is_root(is_root),
m_rank(rank),
m_key(key),
m_root(root),
m_dev_buffers(dev_buffer_arr),
m_disable{disable} {
for (auto i : inputs) {
@@ -422,28 +422,28 @@ CollectiveComm::CollectiveComm(

SymbolVarArray CollectiveComm::make(
const SymbolVarArray& inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const uint32_t rank,
const uint32_t root, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const OperatorNodeConfig& config,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const OperatorNodeConfig& config,
const std::shared_ptr<DTypeScalar>& disable) {
SmallVector<std::shared_ptr<DeviceTensorND>> dev_buffer_arr(nr_devices,
nullptr);
return make(inputs, graph, key, nr_devices, rank, root, group_client,
return make(inputs, graph, key, nr_devices, is_root, rank, group_client,
dev_buffer_arr, param, dtype, backend, config);
}

SymbolVarArray CollectiveComm::make(
const SymbolVarArray& inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const uint32_t rank,
const uint32_t root, std::shared_ptr<GroupClient> group_client,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const Param& param, const DType& dtype, const std::string& backend,
const OperatorNodeConfig& config,
const std::shared_ptr<DTypeScalar>& disable) {
auto inpvars = cg::to_var_node_array(inputs);
auto opr = graph->insert_opr(std::make_unique<CollectiveComm>(
inpvars, graph, key, nr_devices, rank, root, std::move(group_client),
inpvars, graph, key, nr_devices, is_root, rank, std::move(group_client),
param, dtype, backend, dev_buffer_arr, config, disable));
mgb_assert(!opr->output().empty());
return cg::to_symbol_var_array(opr->output());
@@ -452,11 +452,14 @@ SymbolVarArray CollectiveComm::make(
void CollectiveComm::opr_register() {
if (m_init)
return;
auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node())
.cuda_env();
auto&& comp_node = output(0)->comp_node();

auto reg_info = m_group_client->opr_register(
m_key, m_nr_devices, m_is_root, m_rank,
comp_node.get_uid());

auto hash = m_group_client->opr_register(m_key, m_nr_devices, m_rank,
reinterpret_cast<uintptr_t>(cuda_env.stream));
m_rank = reg_info.rank;
m_root = reg_info.root_rank;

MegRayCommunicatorBuilder* builder;

@@ -468,7 +471,7 @@ void CollectiveComm::opr_register() {
}

m_megray_comm = builder->get_megray_comm(
hash, m_key, m_nr_devices, m_rank,
reg_info.hash, m_key, m_nr_devices, m_rank,
get_megray_backend(m_backend), m_group_client);

m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
@@ -606,8 +609,8 @@ VarNodeArray CollectiveComm::grad(const VarNodeArray& out_grads) const {
}

auto gvar = CollectiveComm::make(
og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_rank, m_root,
m_group_client, mode, m_dtype, m_backend,
og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_is_root,
m_rank, m_group_client, mode, m_dtype, m_backend,
OperatorNodeConfig{}.comp_node_arr(cn_arr));

if (m_param.mode == Param::Mode::ALL_REDUCE_MAX) {
@@ -733,11 +736,11 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm(
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>();
return opr::CollectiveComm::make(to_symbol_var_array(inputs),
ctx.owner_graph(opr_, inputs), opr.key(),
opr.nr_devices(), opr.rank(), opr.root(),
opr.group_client(), opr.dev_buffers(),
opr.param(), opr.dtype(), opr.backend(), config)[0]
return opr::CollectiveComm::make(
to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs),
opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(),
opr.group_client(), opr.dev_buffers(), opr.param(),
opr.dtype(), opr.backend(), config)[0]
.node()
->owner_opr();
}
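
The registration round trip above, mocked in runnable Python to show how the server-assigned rank and root rank flow back into the operator (FakeGroupClient is a hypothetical stand-in for GroupClientProxy):

    from dataclasses import dataclass

    @dataclass
    class RegisterInfo:  # mirrors GroupManager::RegisterInfo
        hash: int
        rank: int
        root_rank: int

    class FakeGroupClient:
        def opr_register(self, key, nr_devices, is_root, rank, comp_node_hash):
            # pretend the server assigned rank 0 and saw rank 0 as root
            assigned = 0 if rank < 0 else rank
            return RegisterInfo(hash=123, rank=assigned, root_rank=0)

    class Comm:
        def __init__(self, key, nr_devices, is_root, rank, client):
            self.key, self.nr_devices = key, nr_devices
            self.is_root, self.rank = is_root, rank
            self.client, self.init, self.root = client, False, -1

        def opr_register(self, comp_node_hash):
            if self.init:
                return
            info = self.client.opr_register(self.key, self.nr_devices,
                                            self.is_root, self.rank,
                                            comp_node_hash)
            # write back the settled rank (we may have sent -1) and the root
            self.rank, self.root = info.rank, info.root_rank
            self.init = True

    c = Comm("all_reduce", 2, is_root=True, rank=-1, client=FakeGroupClient())
    c.opr_register(comp_node_hash=0xBEEF)
    assert (c.rank, c.root) == (0, 0)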


src/opr-mm/impl/collective_comm.oprdecl (+4 -4)

@@ -6,8 +6,8 @@ decl_raw_opr(
'to the same NCCL operation.', 'str'),
Doc('nr_devices', 'Total number of devices involved in the NCCL '
'operation to which this operator belongs.', 'int'),
Doc('rank', 'Rank of this operator', 'int'),
Doc('root', 'root rank of broadcast or reduce operation'),
Doc('is_root', 'whether this node is the root node', 'bool'),
Doc('rank', 'rank of this node; if -1, the server assigns one', 'int'),
Doc('server_addr', 'rpc server ip address'),
Doc('port', 'server rpc listening port'),
Doc('param', 'The only component of *param* is *mode*, which refers to '
@@ -28,12 +28,12 @@ decl_raw_opr(
body = [
'if isinstance(input, _mgb.SymbolVar):',
(' output = _mgb._Opr.collective_comm_with_input(input, key, '
'nr_devices, rank, root, server_addr, port, '
'nr_devices, is_root, rank, server_addr, port, '
'[param.serialize()], dtype, backend, output_buffer, config, disable)'),
'else:',
' assert isinstance(input, _mgb.CompGraph)',
(' output = _mgb._Opr.collective_comm_without_input(input, key, '
'nr_devices, rank, root, server_addr, port, '
'nr_devices, is_root, rank, server_addr, port, '
'[param.serialize()], dtype, backend, output_buffer, config, disable)')
],
desc = ('collective communication between multiple CompNodes on multiple '


src/opr-mm/impl/group_manager.cpp (+61 -25)

@@ -16,16 +16,60 @@ using namespace opr;

/* ================= GroupInfo ================= */

void GroupInfo::sort_opr_infos() {
auto cmp = [](const GroupInfo::OprInfo& a, const GroupInfo::OprInfo& b) {
return a.comp_node_hash < b.comp_node_hash;
};
std::sort(m_opr_infos.begin(), m_opr_infos.end(), cmp);
}

void GroupInfo::gen_infos_from_opr_infos() {
// generate rank
bool rank_assgined = true;
for (auto& opr_info:m_opr_infos) {
if(opr_info.rank < 0) {
rank_assgined = false;
break;
}
}
if (!rank_assgined) {
for (size_t i = 0; i < m_opr_infos.size(); i++) {
m_opr_infos[i].rank = i;
m_rank_map.insert({m_opr_infos[i].comp_node_hash, i});
}
} else {
for (size_t i = 0; i < m_opr_infos.size(); i++) {
m_rank_map.insert(
{m_opr_infos[i].comp_node_hash, m_opr_infos[i].rank});
}
}

// generate root rank
for (auto& opr_info:m_opr_infos) {
if (opr_info.is_root) {
m_root_rank = opr_info.rank;
break;
}
}

// generate group hash
auto xxhash = XXHash{};
for (auto&& opr_info : m_opr_infos) {
xxhash.update(&opr_info.comp_node_hash, sizeof(uint64_t))
.update(&opr_info.rank, sizeof(int));
}
m_hash = xxhash.digest();
}

void GroupInfo::add_opr(const std::string& key, size_t nr_expected_devices,
uint32_t rank, uintptr_t stream) {
bool is_root, int rank, uint64_t comp_node_hash) {
std::unique_lock<std::mutex> lk{m_group_mtx};
if (m_nr_expected_devs == 0) {
m_nr_expected_devs = nr_expected_devices;
} else {
mgb_assert(m_nr_expected_devs == nr_expected_devices);
}
OprInfo opr_info = {rank, stream};
m_opr_infos.push_back(std::move(opr_info));
m_opr_infos.push_back({comp_node_hash, is_root, rank});
m_nr_registered_devs++;
m_count++;
if (m_nr_registered_devs > nr_expected_devices) {
@@ -38,6 +82,8 @@ void GroupInfo::add_opr(const std::string& key, size_t nr_expected_devices,
key.c_str(), nr_expected_devices, m_nr_registered_devs);
}
if (m_nr_expected_devs == m_nr_registered_devs) {
sort_opr_infos();
gen_infos_from_opr_infos();
m_register_cv.notify_all();
} else {
m_register_cv.wait(lk,
@@ -66,6 +112,8 @@ void GroupInfo::clear() {
m_count--;
if (m_count == 0) {
m_opr_infos.clear();
m_rank_map.clear();
m_root_rank = -1;
m_nr_expected_devs = 0;
m_nr_registered_devs = 0;
m_output_shape.invalidate();
@@ -77,14 +125,18 @@ void GroupInfo::clear() {

/* ================= GroupManager ================= */

uint64_t GroupManager::opr_register(const std::string& key, size_t nr_devices,
uint32_t rank, uintptr_t stream) {
GroupManager::RegisterInfo GroupManager::opr_register(const std::string& key,
size_t nr_devices,
bool is_root, int rank,
uint64_t comp_node_hash) {
GroupManager::RegisterInfo ret{0, 0, 0};
auto&& group = get_group(key);
group.add_opr(key, nr_devices, rank, stream);
auto&& opr_infos = group.opr_infos();
uint64_t hash = get_hash_key(opr_infos, rank);
group.add_opr(key, nr_devices, is_root, rank, comp_node_hash);
ret.rank = group.get_rank(comp_node_hash);
ret.root_rank = group.get_root_rank();
ret.hash = group.get_group_hash() + ret.rank;
group.clear();
return hash;
return ret;
}

std::vector<std::string> GroupManager::gather_uid(const std::string& uid,
@@ -126,22 +178,6 @@ GroupInfo& GroupManager::get_group(const std::string& key) {
return m_key2group_info[key];
}

uint64_t GroupManager::get_hash_key(const std::vector<GroupInfo::OprInfo>& _infos,
uint32_t rank) {
auto cmp = [](const GroupInfo::OprInfo& lhs, const GroupInfo::OprInfo& rhs) {
return lhs.rank < rhs.rank;
};
auto infos = _infos;
std::sort(infos.begin(), infos.end(), cmp);
auto xxhash = XXHash{};
for (auto&& opr_info : infos) {
xxhash.update(&opr_info.rank, sizeof(uint32_t))
.update(&opr_info.stream, sizeof(uintptr_t));
}
xxhash.update(&rank, sizeof(uint32_t));
return xxhash.digest();
};

uint32_t GroupManager::group_barrier(uint32_t size, uint32_t rank) {
std::unique_lock<std::mutex> lk{m_barrier_mtx};
if (m_barrier_set.empty()) {
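
The server-side assignment just introduced, condensed into a runnable Python sketch (Python's built-in hash stands in for mgb::XXHash, so the actual digests differ):

    def gen_infos(opr_infos):
        """opr_infos: list of (comp_node_hash, is_root, rank) tuples."""
        # sort by comp node hash so every registrant derives the same order
        infos = sorted(opr_infos, key=lambda o: o[0])

        # if any registrant sent the -1 sentinel, assign ranks by position
        if any(rank < 0 for _, _, rank in infos):
            infos = [(h, root, i) for i, (h, root, _) in enumerate(infos)]

        rank_map = {h: rank for h, _, rank in infos}
        # the root rank is whichever opr registered with is_root=True
        root_rank = next((rank for _, is_root, rank in infos if is_root), -1)
        # the group hash covers every (comp_node_hash, rank) pair
        group_hash = hash(tuple((h, rank) for h, _, rank in infos))
        return rank_map, root_rank, group_hash

    rank_map, root, ghash = gen_infos([(0xB, False, -1), (0xA, True, -1)])
    assert rank_map == {0xA: 0, 0xB: 1} and root == 0

GroupManager::opr_register then hands each caller group_hash plus its own rank, so every rank in the same group receives a distinct hash.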


src/opr-mm/impl/io_remote.cpp (+9 -10)

@@ -48,12 +48,11 @@ SymbolVar RemoteSend::make(const PeerDesc& peer, SymbolVar var,

void RemoteSend::scn_do_execute() {
if (!m_init) {
auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node())
.cuda_env();
auto&& comp_node = output(0)->comp_node();

// rank 0 for RemoteSend
auto hash = m_group_client->opr_register(m_peer.key, 2, 0,
reinterpret_cast<uintptr_t>(cuda_env.stream));
auto reg_info = m_group_client->opr_register(m_peer.key, 2, false, 0,
comp_node.get_uid());

auto megray_comm_builder =
owner_graph()
@@ -62,7 +61,7 @@ void RemoteSend::scn_do_execute() {
.get_user_data_or_create<MegRayCommunicatorBuilder>();

m_megray_comm = megray_comm_builder->get_megray_comm(
hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
m_init = true;
}

@@ -152,12 +151,12 @@ SymbolVar RemoteRecv::make(const PeerDesc& peer, cg::ComputingGraph& graph,

void RemoteRecv::scn_do_execute() {
if (!m_init) {
auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node())
.cuda_env();
auto&& comp_node = output(0)->comp_node();

// rank 1 for RemoteRecv
auto hash = m_group_client->opr_register(m_peer.key, 2, 1,
reinterpret_cast<uintptr_t>(cuda_env.stream));
auto reg_info = m_group_client->opr_register(
m_peer.key, 2, false, 1,
comp_node.get_uid());

auto megray_comm_builder =
owner_graph()
@@ -166,7 +165,7 @@ void RemoteRecv::scn_do_execute() {
.get_user_data_or_create<MegRayCommunicatorBuilder>();

m_megray_comm = megray_comm_builder->get_megray_comm(
hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
m_init = true;
}
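
The point-to-point convention after this change, sketched in plain Python (client is any object with the opr_register signature mocked earlier; neither endpoint is a collective root):

    NR_DEVICES = 2        # each send/recv pair forms its own group of two
    SEND_RANK, RECV_RANK = 0, 1

    def register_send(client, key, comp_node_hash):
        return client.opr_register(key, NR_DEVICES, False, SEND_RANK,
                                   comp_node_hash)

    def register_recv(client, key, comp_node_hash):
        return client.opr_register(key, NR_DEVICES, False, RECV_RANK,
                                   comp_node_hash)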



src/opr-mm/impl/mm_handler.cpp (+15 -10)

@@ -68,9 +68,11 @@ private:
void GroupServerProxy::opr_register(void* input_ptr, size_t input_len,
std::string *output) {
INFO_INIT(mm_handler, OprRegister);
uint64_t hash = m_mgr.opr_register(req.key(), req.nr_expected_devices(),
req.rank(), req.stream());
rsp.set_hash(hash);
auto ret = m_mgr.opr_register(req.key(), req.nr_expected_devices(),
req.is_root(), req.rank(), req.comp_node_hash());
rsp.set_hash(ret.hash);
rsp.set_rank(ret.rank);
rsp.set_root_rank(ret.root_rank);
rsp.SerializeToString(output);
}

@@ -122,11 +124,11 @@ void GroupServerProxy::group_barrier(void* input_ptr, size_t input_len,

/* ======================== GroupClientProxy ========================== */

#define INFO_INIT(space, f_name, name) \
#define INFO_INIT(space, f_name, name) \
using Request = space::name##Request; \
using Response = space::name##Response; \
std::string func_name = #f_name; \
Request req; \
std::string func_name = #f_name; \
Request req; \
Response rsp;

#define SOLVE_REQUEST(name, req, rsp) \
@@ -145,15 +147,18 @@ GroupClientProxy::GroupClientProxy(const std::string& server_addr)
m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} {
}

uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices,
uint32_t rank, uintptr_t stream) {
GroupManager::RegisterInfo GroupClientProxy::opr_register(
const std::string& key, size_t nr_devices, bool is_root, int rank,
uint64_t comp_node_hash) {
INFO_INIT(mm_handler, opr_register, OprRegister)
req.set_key(key);
req.set_is_root(is_root);
req.set_rank(rank);
req.set_stream(stream);
req.set_comp_node_hash(comp_node_hash);
req.set_nr_expected_devices(nr_devices);
SOLVE_REQUEST(func_name, req, rsp);
return rsp.hash();
GroupManager::RegisterInfo ret{rsp.hash(), rsp.rank(), rsp.root_rank()};
return ret;
}

void GroupClientProxy::set_output_shape(const std::string& key,


src/opr-mm/include/megbrain/opr/collective_comm.h (+26 -23)

@@ -26,18 +26,19 @@ public:

using Param = megdnn::param::CollectiveComm;

CollectiveComm(VarNodeArray inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const uint32_t rank,
const uint32_t root, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const OperatorNodeConfig& config,
const std::shared_ptr<DTypeScalar>& disable);
CollectiveComm(
VarNodeArray inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const OperatorNodeConfig& config,
const std::shared_ptr<DTypeScalar>& disable);

static SymbolVarArray make(
const SymbolVarArray& inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const uint32_t rank,
const uint32_t root, std::shared_ptr<GroupClient> group_client,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const Param& param, const DType& dtype = {},
const std::string& backend = "nccl",
@@ -45,15 +46,16 @@ public:
const std::shared_ptr<DTypeScalar>& disable =
std::make_shared<DTypeScalar>(0));

static SymbolVarArray make(
const SymbolVarArray& inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const uint32_t rank,
const uint32_t root, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype = {},
const std::string& backend = "nccl",
const OperatorNodeConfig& config = {},
const std::shared_ptr<DTypeScalar>& disable =
std::make_shared<DTypeScalar>(0));
static SymbolVarArray make(const SymbolVarArray& inputs,
ComputingGraph* const graph,
const std::string& key, const size_t nr_devices,
const bool is_root, const int rank,
std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype = {},
const std::string& backend = "nccl",
const OperatorNodeConfig& config = {},
const std::shared_ptr<DTypeScalar>& disable =
std::make_shared<DTypeScalar>(0));

const Param& param() const { return m_param; }
const DType& dtype() const { return m_dtype; }
@@ -67,9 +69,9 @@ public:
return m_dev_buffers;
}

uint32_t rank() const { return m_rank; }
uint32_t root() const { return m_root; }
bool is_root() const { return m_rank == m_root; }
int rank() const { return m_rank; }
int root() const { return m_root; }
bool is_root() const { return m_is_root; }

//! The key that identifies an NCCL clique.
//! Operators with same keys belong to the same clique.
@@ -108,12 +110,13 @@ private:

std::shared_ptr<GroupClient> m_group_client;
size_t m_nr_devices = 0;
uint32_t m_rank;
bool m_is_root;
int m_rank;
std::string m_key;
//! XXHash generated from m_key
size_t m_hash;
//! root of BROADCAST and REDUCE operation
uint32_t m_root;
int m_root;
//! rank of root of BROADCAST and REDUCE operation
Maybe<TensorShape> m_broadcast_output_shape = None;
// Whether shape infer is enabled. This is only used by BROADCAST operation,


src/opr-mm/include/megbrain/opr/group_manager.h (+29 -13)

@@ -24,12 +24,13 @@ namespace opr {
class GroupInfo {
public:
struct OprInfo {
uint32_t rank;
uintptr_t stream;
uint64_t comp_node_hash;
bool is_root;
int rank;
};

void add_opr(const std::string& key, size_t nr_expected_devices,
uint32_t graph_id, uintptr_t stream);
bool is_root, int rank, uint64_t comp_node_hash);

void set_output_shape(const std::string& key, const TensorShape& shape);

@@ -37,15 +38,25 @@ class GroupInfo {

void clear();

const std::vector<OprInfo>& opr_infos() const {return m_opr_infos; }
const std::vector<OprInfo>& opr_infos() const { return m_opr_infos; }

int get_root_rank() const { return m_root_rank; }
int get_rank(uint64_t hash) const { return m_rank_map.at(hash); }
uint64_t get_group_hash() const { return m_hash; }

private:
void sort_opr_infos();
void gen_infos_from_opr_infos();

std::vector<OprInfo> m_opr_infos;
std::unordered_map<uint64_t, int> m_rank_map;
uint64_t m_hash;
uint32_t m_nr_registered_devs;
uint32_t m_nr_expected_devs;
Maybe<TensorShape> m_output_shape;

uint32_t m_count = 0;
int m_root_rank = -1;
std::mutex m_group_mtx;
std::condition_variable m_register_cv;
std::condition_variable m_clear_cv;
@@ -61,10 +72,16 @@ class GroupManager {
public:
~GroupManager() = default;

struct RegisterInfo
{
uint64_t hash;
int rank, root_rank;
};

//! register oprs' info to server, return deduplicated hash
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
uintptr_t stream);
RegisterInfo opr_register(const std::string& key, size_t nr_devices,
bool is_root, int rank, uint64_t comp_node_hash);
//! gather uids from all ranks
std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank);
@@ -80,9 +97,6 @@ class GroupManager {

private:
GroupInfo& get_group(const std::string& key);

uint64_t get_hash_key(const std::vector<GroupInfo::OprInfo>& _infos,
uint32_t rank);
//! key -> group info.
std::unordered_map<std::string, GroupInfo> m_key2group_info;
@@ -112,9 +126,11 @@ class GroupClient {
virtual ~GroupClient() = default;
public:
virtual uint64_t opr_register(const std::string& key, size_t nr_devices,
uint32_t rank, uintptr_t stream) = 0;
virtual GroupManager::RegisterInfo opr_register(const std::string& key,
size_t nr_devices,
bool is_root, int rank,
uint64_t comp_node_hash) = 0;

virtual std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) = 0;


src/opr-mm/include/megbrain/opr/mm_handler.h (+5 -2)

@@ -14,6 +14,7 @@
#if MGB_ENABLE_OPR_MM

#include "megbrain/opr/collective_comm.h"
#include "megbrain/opr/group_manager.h"

using namespace mgb;
using namespace opr;
@@ -31,8 +32,10 @@ public:
GroupClientProxy(const std::string& server_addr);

//! graph registration, assign graph_id to worker.
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
uintptr_t stream) override;
GroupManager::RegisterInfo opr_register(const std::string& key,
size_t nr_devices, bool is_root,
int rank,
uint64_t comp_node_hash) override;

std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) override;


src/opr-mm/proto/mm_handler.proto (+7 -4)

@@ -4,13 +4,16 @@ package mm_handler;

message OprRegisterRequest {
string key = 1;
uint32 rank = 2;
uint64 stream = 3;
uint32 nr_expected_devices = 4;
bool is_root = 2;
int32 rank = 3;
uint64 comp_node_hash = 4;
uint32 nr_expected_devices = 5;
}

message OprRegisterResponse {
uint64 hash = 1;
uint64 hash = 1;
int32 rank = 2;
int32 root_rank = 3;
}

message GatherUidRequest {
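
For reference, plain-Python mirrors of the updated messages (illustrative only; the real classes are generated from mm_handler.proto):

    from dataclasses import dataclass

    @dataclass
    class OprRegisterRequest:
        key: str
        is_root: bool             # replaces the old implicit root=0 convention
        rank: int                 # -1 asks the server to assign a rank
        comp_node_hash: int       # replaces the raw CUDA stream pointer
        nr_expected_devices: int

    @dataclass
    class OprRegisterResponse:
        hash: int                 # deduplicated group hash (plus caller's rank)
        rank: int                 # the rank the server settled on
        root_rank: int            # rank of the is_root registrant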


src/opr-mm/test/collective_comm.cpp (+36 -34)

@@ -45,9 +45,11 @@ class MockGroupClient final : public opr::GroupClient {
public:
~MockGroupClient() override = default;

uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
uintptr_t stream) {
return m_mgr.opr_register(key, nr_devices, rank, stream);
opr::GroupManager::RegisterInfo opr_register(const std::string& key,
size_t nr_devices,
bool is_root, int rank,
uint64_t comp_node_hash) {
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash);
}

std::vector<std::string> gather_uid(const std::string& uid,
@@ -94,9 +96,9 @@ TEST(TestOprCollectiveComm, AllReduce) {
auto x1c = opr::Copy::make(x1, cn1);
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_reduce",
2, 0, 0, client, {mode}, dtype::Float32(), "nccl")[0];
2, false, 0, client, {mode}, dtype::Float32(), "nccl")[0];
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_reduce",
2, 1, 0, client, {mode}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {mode}, dtype::Float32(), "nccl")[0];
auto y_expect = make_all_reduce_output(mode, {x0, x1});
auto func = graph->compile({make_callback_copy(y0, host_y0),
@@ -130,7 +132,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) {
auto graph0 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce",
2, 0, 0, client, {mode}, dtype::Float32(), "nccl")[0];
2, false, 0, client, {mode}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
func0->execute();
};
@@ -139,7 +141,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) {
auto graph1 = ComputingGraph::make();
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce",
2, 1, 0, client, {mode}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {mode}, dtype::Float32(), "nccl")[0];
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)});
func1->execute();
};
@@ -192,7 +194,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) {
graph0->options().graph_opt_level = 0;

auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", 2, false, 0, client,
{Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;

@@ -211,7 +213,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) {
graph1->options().graph_opt_level = 0;

auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", 2, false, 1, client,
{Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0];
y1.node()->owner_opr()->node_prop().attribute().priority = -1;

@@ -274,9 +276,9 @@ TEST(TestOprCollectiveComm, AllGather) {
auto x1c = opr::Copy::make(x1, cn1);

auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_gather",
2, 0, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
2, false, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_gather",
2, 1, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
auto y_expect = opr::Concat::make({x0, x1}, 0);

auto func = graph->compile({make_callback_copy(y0, host_y0),
@@ -303,7 +305,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) {
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, false, 0, client,
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
func0->execute();
@@ -312,7 +314,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) {
auto run_1 = [&]() { // rank 1
auto graph1 = ComputingGraph::make();
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, false, 1, client,
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)});
func1->execute();
@@ -361,7 +363,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) {
graph0->options().graph_opt_level = 0;

auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, false, 0, client,
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;

@@ -380,7 +382,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) {
graph1->options().graph_opt_level = 0;

auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, false, 1, client,
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
y1.node()->owner_opr()->node_prop().attribute().priority = -1;

@@ -444,9 +446,9 @@ TEST(TestOprCollectiveComm, ReduceScatterSum) {
auto x1c = opr::Copy::make(x1, cn1);

auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_scatter_sum",
2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_scatter_sum",
2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
auto y_expect = make_reduce_scatter_sum_output({x0, x1});

auto func = graph->compile({make_callback_copy(y0, host_y0),
@@ -475,7 +477,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) {
auto graph0 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum",
2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
func0->execute();
};
@@ -484,7 +486,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) {
auto graph1 = ComputingGraph::make();
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum",
2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)});
func1->execute();
};
@@ -534,7 +536,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) {

auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum",
2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;

auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0);
@@ -553,7 +555,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) {

auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum",
2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
y1.node()->owner_opr()->node_prop().attribute().priority = -1;

auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1);
@@ -616,9 +618,9 @@ TEST(TestOprCollectiveComm, ReduceSum) {
auto x1c = opr::Copy::make(x1, cn1);

auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_sum",
2, 0, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
2, true, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_sum",
2, 1, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
auto y_expect = x0 + x1;

auto func = graph->compile({make_callback_copy(y0, host_y0),
@@ -644,7 +646,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) {
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, true, 0, client,
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
func0->execute();
@@ -653,7 +655,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) {
auto run_1 = [&]() { // rank 1
auto graph1 = ComputingGraph::make();
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, false, 1, client,
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
auto func1 = graph1->compile({{y1, nullptr}});
func1->execute();
@@ -699,7 +701,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) {
graph0->options().graph_opt_level = 0;

auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, true, 0, client,
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;

@@ -718,7 +720,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) {
graph1->options().graph_opt_level = 0;

auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, false, 1, client,
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
y1.node()->owner_opr()->node_prop().attribute().priority = -1;

@@ -767,12 +769,12 @@ TEST(TestOprCollectiveComm, Broadcast) {

auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "broadcast",
2, 0, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0];
2, true, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0];
auto y_dev = std::make_shared<DeviceTensorND>(DeviceTensorND()
.comp_node(cn1)
.dtype(dtype::Float32())
.resize(host_x0->shape()));
auto y1 = opr::CollectiveComm::make({}, graph.get(), "broadcast", 2, 1, 0,
auto y1 = opr::CollectiveComm::make({}, graph.get(), "broadcast", 2, false, 1,
client, {y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0];

auto func = graph->compile({make_callback_copy(y0, host_y0),
@@ -797,7 +799,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) {
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, true, 0, client,
{Mode::BROADCAST}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
func0->execute();
@@ -809,7 +811,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) {
.comp_node(cn1)
.dtype(dtype::Float32())
.resize(host_x0->shape()));
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, false, 1, client,
{y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0];
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)});
func1->execute();
@@ -845,7 +847,7 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) {
graph0->options().graph_opt_level = 0;

auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, true, 0, client,
{Mode::BROADCAST}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;

@@ -863,11 +865,11 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) {
auto graph1 = ComputingGraph::make();
graph1->options().graph_opt_level = 0;

auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, false, 1, client,
{Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0];

auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1);
auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "broadcast:grad", 2, 1, 0, client,
auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "broadcast:grad", 2, false, 1, client,
Mode::REDUCE_SUM, dtype::Float32(), "nccl")[0];
g.node()->owner_opr()->node_prop().attribute().priority = 1;



src/opr-mm/test/io_remote.cpp (+6 -4)

@@ -26,11 +26,13 @@ class MockGroupClient final : public opr::GroupClient {
public:
~MockGroupClient() override = default;

uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
uintptr_t stream) {
return m_mgr.opr_register(key, nr_devices, rank, stream);
opr::GroupManager::RegisterInfo opr_register(const std::string& key,
size_t nr_devices,
bool is_root, int rank,
uint64_t comp_node_hash) {
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash);
}
std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) {
return m_mgr.gather_uid(uid, key, size, rank);

