GitOrigin-RevId: dccdb71553
tags/v0.5.0
@@ -26,25 +26,17 @@ def reduce_sum( | |||
tensor: Tensor, | |||
key: str, | |||
nr_ranks: Optional[int] = None, | |||
rank: Optional[int] = None, | |||
root: Optional[int] = 0, | |||
is_root: Optional[bool] = None, | |||
) -> Tensor: | |||
"""Create reduce_sum operator for collective communication | |||
:param tensor: input tensor | |||
:param key: unique identifier for collective communication | |||
:param nr_ranks: number of ranks, use util.get_world_size() as default | |||
:param rank: rank of the current process, use util.get_rank() as default | |||
:param root: rank of root node, use 0 as default | |||
:param is_root: whether this is a root node | |||
""" | |||
return _collective_comm( | |||
tensor, | |||
key, | |||
CollParam.Mode.REDUCE_SUM, | |||
nr_ranks, | |||
rank, | |||
root, | |||
device=tensor.device, | |||
tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device, | |||
) | |||
@@ -52,24 +44,21 @@ def broadcast( | |||
tensor: Tensor, | |||
key: str, | |||
nr_ranks: Optional[int] = None, | |||
rank: Optional[int] = None, | |||
root: Optional[int] = 0, | |||
is_root: Optional[bool] = None, | |||
) -> Tensor: | |||
"""Create broadcast operator for collective communication | |||
:param tensor: input tensor | |||
:param key: unique identifier for collective communication | |||
:param nr_ranks: number of ranks, use util.get_world_size() as default | |||
:param rank: rank of the current process, use util.get_rank() as default | |||
:param root: rank of root node, use 0 as default | |||
:param is_root: whether this is a root node | |||
""" | |||
if key is None: | |||
key = tensor._symvar.name | |||
if is_root is None: | |||
is_root = get_rank() == 0 | |||
if rank is None: | |||
rank = get_rank() | |||
if rank == root: | |||
if is_root: | |||
inp = tensor | |||
else: | |||
inp = tensor._symvar.owner_graph | |||
@@ -79,8 +68,7 @@ def broadcast( | |||
key, | |||
CollParam.Mode.BROADCAST, | |||
nr_ranks, | |||
rank, | |||
root, | |||
is_root, | |||
dtype=tensor.dtype, | |||
device=tensor.device, | |||
) | |||
@@ -94,9 +82,9 @@ def all_gather( | |||
:param tensor: input tensor | |||
:param key: unique identifier for collective communication | |||
:param nr_ranks: number of ranks, use util.get_world_size() as default | |||
:param rank: rank of the current process, use util.get_rank() as default | |||
:param rank: rank of this node | |||
""" | |||
return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank, 0) | |||
return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank) | |||
def reduce_scatter_sum( | |||
@@ -107,69 +95,58 @@ def reduce_scatter_sum( | |||
:param tensor: input tensor | |||
:param key: unique identifier for collective communication | |||
:param nr_ranks: number of ranks, use util.get_world_size() as default | |||
:param rank: rank of the current process, use util.get_rank() as default | |||
:param rank: rank of this node | |||
""" | |||
return _collective_comm( | |||
tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank | |||
tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank, | |||
) | |||
def all_reduce_sum( | |||
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None | |||
) -> Tensor: | |||
def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor: | |||
"""Create all_reduce_sum operator for collective communication | |||
:param tensor: input tensor | |||
:param key: unique identifier for collective communication | |||
:param nr_ranks: number of ranks, use util.get_world_size() as default | |||
:param rank: rank of the current process, use util.get_rank() as default | |||
""" | |||
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks, rank) | |||
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks) | |||
def all_reduce_max( | |||
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None | |||
) -> Tensor: | |||
def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor: | |||
"""Create all_reduce_max operator for collective communication | |||
:param tensor: input tensor | |||
:param key: unique identifier for collective communication | |||
:param nr_ranks: number of ranks, use util.get_world_size() as default | |||
:param rank: rank of the current process, use util.get_rank() as default | |||
""" | |||
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks, rank) | |||
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks) | |||
def all_reduce_min( | |||
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None | |||
) -> Tensor: | |||
def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor: | |||
"""Create all_reduce_min operator for collective communication | |||
:param tensor: input tensor | |||
:param key: unique identifier for collective communication | |||
:param nr_ranks: number of ranks, use util.get_world_size() as default | |||
:param rank: rank of the current process, use util.get_rank() as default | |||
""" | |||
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks, rank) | |||
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks) | |||
def bcast_param( | |||
inp: Union[Buffer, Parameter], | |||
key: str, | |||
nr_ranks: Optional[int] = None, | |||
rank: Optional[int] = None, | |||
root: Optional[int] = 0, | |||
is_root: Optional[bool] = None, | |||
) -> None: | |||
"""Broadcast parameters among devices | |||
:param inp: input Buffer or Parameter to be synchronized | |||
:param key: unique identifier for collective communication | |||
:param nr_ranks: number of ranks, use util.get_world_size() as default | |||
:param rank: rank of the current process, use util.get_rank() as default | |||
:param root: rank of root node, use 0 as default | |||
:param is_root: whether this is a root node | |||
""" | |||
if not is_distributed(): | |||
return | |||
assert isinstance(inp, (Buffer, Parameter)) | |||
bcast_res = broadcast(inp, key, nr_ranks, rank, root) | |||
bcast_res = broadcast(inp, key, nr_ranks, is_root) | |||
add_update(inp, bcast_res, alpha=0) |
@@ -19,8 +19,8 @@ def collective_comm_symvar( | |||
key: str, | |||
op: CollParam.Mode, | |||
nr_ranks: Optional[int] = None, | |||
is_root: Optional[bool] = None, | |||
rank: Optional[int] = None, | |||
root: Optional[int] = 0, | |||
dtype: Optional[type] = None, | |||
device: Optional[mgb.CompNode] = None, | |||
comp_graph: Optional[mgb.CompGraph] = None, | |||
@@ -31,8 +31,7 @@ def collective_comm_symvar( | |||
:param key: unique identifier for collective communication | |||
:param op: mode of collective communication | |||
:param nr_ranks: number of ranks, use util.get_world_size() as default | |||
:param rank: rank of the current process, use util.get_rank() as default | |||
:param root: rank of root node, use 0 as default | |||
:param is_root: whether this node is root node | |||
:param dtype: output data type, use dtype of inp as default | |||
:param device: output comp node, use comp node of inp as default | |||
:param comp_graph: output comp graph, use comp graph of inp as default | |||
@@ -41,8 +40,8 @@ def collective_comm_symvar( | |||
inp, | |||
key=str(key), | |||
nr_devices=nr_ranks if nr_ranks is not None else get_world_size(), | |||
rank=rank if rank is not None else get_rank(), | |||
root=root, | |||
is_root=is_root if is_root is not None else (get_rank() == 0), | |||
rank=rank if rank is not None else -1, | |||
server_addr=get_master_ip(), | |||
port=get_master_port(), | |||
param=CollParam(mode=op), | |||
@@ -17,7 +17,13 @@ import numpy as np | |||
from .._internal.config import opr_priority_scope | |||
from ..core import Buffer, Parameter, Tensor, TensorDict | |||
from ..core.graph import get_default_graph | |||
from ..distributed import all_reduce_sum, bcast_param, get_world_size, is_distributed | |||
from ..distributed import ( | |||
all_reduce_sum, | |||
bcast_param, | |||
get_rank, | |||
get_world_size, | |||
is_distributed, | |||
) | |||
from ..distributed.util import get_group_id | |||
from ..functional import add_update | |||
from ..functional import grad as grad_func | |||
@@ -222,7 +228,11 @@ class Optimizer(metaclass=ABCMeta): | |||
def bcast_param(self): | |||
for group in self.param_groups: | |||
for param in group["params"]: | |||
bcast_param(param, "bcast_param_" + str(get_group_id())) | |||
bcast_param( | |||
param, | |||
"bcast_param_" + str(get_group_id()), | |||
is_root=(get_rank() == 0), | |||
) | |||
def state_dict(self) -> Dict: | |||
r"""Export the optimizer state. | |||
@@ -93,10 +93,9 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, | |||
} | |||
SymbolVar _Opr::collective_comm_with_input( | |||
SymbolVar inpvar, const std::string& key, | |||
const size_t nr_devices, const uint32_t rank, const uint32_t root, | |||
const std::string& server_addr, const int port, | |||
PyObject* params, PyObject* dtype, | |||
SymbolVar inpvar, const std::string& key, const size_t nr_devices, | |||
const bool is_root, const int rank, const std::string& server_addr, | |||
const int port, PyObject* params, PyObject* dtype, | |||
const std::string& backend, SharedND* output_buf, | |||
const OperatorNodeConfig& config, const SharedScalar& disable) { | |||
SymbolVarArray inputs(1, inpvar); | |||
@@ -111,15 +110,15 @@ SymbolVar _Opr::collective_comm_with_input( | |||
if (dtype != Py_None) { | |||
_dtype = npy::dtype_np2mgb(dtype); | |||
} | |||
return CollectiveComm::make(inputs, graph, key, nr_devices, rank, root, group_mgr, | |||
dev_buffer_arr, param, _dtype, backend, config, disable.get_val())[0]; | |||
return CollectiveComm::make(inputs, graph, key, nr_devices, is_root, rank, | |||
group_mgr, dev_buffer_arr, param, _dtype, | |||
backend, config, disable.get_val())[0]; | |||
} | |||
SymbolVar _Opr::collective_comm_without_input( | |||
CompGraph& cg, const std::string& key, | |||
const size_t nr_devices, const uint32_t rank, const uint32_t root, | |||
const std::string& server_addr, const int port, | |||
PyObject* params, PyObject* dtype, | |||
CompGraph& cg, const std::string& key, const size_t nr_devices, | |||
const bool is_root, const int rank, const std::string& server_addr, | |||
const int port, PyObject* params, PyObject* dtype, | |||
const std::string& backend, SharedND* output_buf, | |||
const OperatorNodeConfig& config, const SharedScalar& disable) { | |||
SymbolVarArray inputs; | |||
@@ -134,8 +133,9 @@ SymbolVar _Opr::collective_comm_without_input( | |||
if (dtype != Py_None) { | |||
_dtype = npy::dtype_np2mgb(dtype); | |||
} | |||
return CollectiveComm::make(inputs, &graph, key, nr_devices, rank, root, group_mgr, | |||
dev_buffer_arr, param, _dtype, backend, config, disable.get_val())[0]; | |||
return CollectiveComm::make(inputs, &graph, key, nr_devices, is_root, rank, | |||
group_mgr, dev_buffer_arr, param, _dtype, | |||
backend, config, disable.get_val())[0]; | |||
} | |||
#else | |||
@@ -172,7 +172,7 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, | |||
SymbolVar _Opr::collective_comm_with_input( | |||
SymbolVar inpvar, const std::string& key, | |||
const size_t nr_devices, const uint32_t rank, const uint32_t root, | |||
const size_t nr_devices, const bool is_root, const int rank, | |||
const std::string& server_addr, const int port, PyObject* params, | |||
PyObject* dtype, const std::string& backend, SharedND* output_buf, | |||
const OperatorNodeConfig& config, const SharedScalar& disable) { | |||
@@ -181,7 +181,7 @@ SymbolVar _Opr::collective_comm_with_input( | |||
SymbolVar _Opr::collective_comm_without_input( | |||
CompGraph& cg, const std::string& key, | |||
const size_t nr_devices, const uint32_t rank, const uint32_t root, | |||
const size_t nr_devices, const bool is_root, const int rank, | |||
const std::string& server_addr, const int port, PyObject* params, | |||
PyObject* dtype, const std::string& backend, SharedND* output_buf, | |||
const OperatorNodeConfig& config, const SharedScalar& disable) { | |||
@@ -94,17 +94,17 @@ static SymbolVar remote_recv(const std::string& server_addr, const int port, | |||
static SymbolVar collective_comm_with_input( | |||
SymbolVar inpvar, const std::string& key, const size_t nr_devices, | |||
const uint32_t rank, const uint32_t root, const std::string& server_addr, | |||
const int port, PyObject* params, PyObject* dtype, | |||
const std::string& backend, SharedND* output_buf, | |||
const OperatorNodeConfig& config, const SharedScalar& disable); | |||
const bool is_root, const int rank, const std::string& server_addr, const int port, | |||
PyObject* params, PyObject* dtype, const std::string& backend, | |||
SharedND* output_buf, const OperatorNodeConfig& config, | |||
const SharedScalar& disable); | |||
static SymbolVar collective_comm_without_input( | |||
CompGraph& graph, const std::string& key, const size_t nr_devices, | |||
const uint32_t rank, const uint32_t root, const std::string& server_addr, | |||
const int port, PyObject* params, PyObject* dtype, | |||
const std::string& backend, SharedND* output_buf, | |||
const OperatorNodeConfig& config, const SharedScalar& disable); | |||
const bool is_root, const int rank, const std::string& server_addr, const int port, | |||
PyObject* params, PyObject* dtype, const std::string& backend, | |||
SharedND* output_buf, const OperatorNodeConfig& config, | |||
const SharedScalar& disable); | |||
// misc | |||
static SymbolVarArray extern_c_opr_placeholder( | |||
@@ -102,7 +102,7 @@ def test_all_gather(): | |||
return | |||
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
inp = tensor(data) | |||
output = dist.functional.all_gather(inp, "x") | |||
output = dist.functional.all_gather(inp, "x", rank=rank) | |||
assert np.allclose(output.numpy(), expect) | |||
def check(shape, backend): | |||
@@ -135,7 +135,7 @@ def test_reduce_scatter_sum(): | |||
return | |||
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
inp = tensor(data) | |||
output = dist.functional.reduce_scatter_sum(inp, "x") | |||
output = dist.functional.reduce_scatter_sum(inp, "x", rank=rank) | |||
assert np.allclose(output.numpy(), expect) | |||
def check(shape, backend): | |||
@@ -368,8 +368,8 @@ CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) { | |||
CollectiveComm::CollectiveComm( | |||
VarNodeArray inputs, ComputingGraph* const graph, | |||
const std::string& key, const size_t nr_devices, const uint32_t rank, | |||
const uint32_t root, std::shared_ptr<GroupClient> group_client, | |||
const std::string& key, const size_t nr_devices, const bool is_root, | |||
const int rank, std::shared_ptr<GroupClient> group_client, | |||
const Param& param, const DType& dtype, const std::string& backend, | |||
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | |||
const OperatorNodeConfig& config, | |||
@@ -380,9 +380,9 @@ CollectiveComm::CollectiveComm( | |||
m_backend(backend), | |||
m_group_client{std::move(group_client)}, | |||
m_nr_devices(nr_devices), | |||
m_is_root(is_root), | |||
m_rank(rank), | |||
m_key(key), | |||
m_root(root), | |||
m_dev_buffers(dev_buffer_arr), | |||
m_disable{disable} { | |||
for (auto i : inputs) { | |||
@@ -422,28 +422,28 @@ CollectiveComm::CollectiveComm( | |||
SymbolVarArray CollectiveComm::make( | |||
const SymbolVarArray& inputs, ComputingGraph* const graph, | |||
const std::string& key, const size_t nr_devices, const uint32_t rank, | |||
const uint32_t root, std::shared_ptr<GroupClient> group_client, | |||
const Param& param, const DType& dtype, const std::string& backend, | |||
const OperatorNodeConfig& config, | |||
const std::string& key, const size_t nr_devices, const bool is_root, | |||
const int rank, std::shared_ptr<GroupClient> group_client, | |||
const Param& param, const DType& dtype, const std::string& backend, | |||
const OperatorNodeConfig& config, | |||
const std::shared_ptr<DTypeScalar>& disable) { | |||
SmallVector<std::shared_ptr<DeviceTensorND>> dev_buffer_arr(nr_devices, | |||
nullptr); | |||
return make(inputs, graph, key, nr_devices, rank, root, group_client, | |||
return make(inputs, graph, key, nr_devices, is_root, rank, group_client, | |||
dev_buffer_arr, param, dtype, backend, config); | |||
} | |||
SymbolVarArray CollectiveComm::make( | |||
const SymbolVarArray& inputs, ComputingGraph* const graph, | |||
const std::string& key, const size_t nr_devices, const uint32_t rank, | |||
const uint32_t root, std::shared_ptr<GroupClient> group_client, | |||
const std::string& key, const size_t nr_devices, const bool is_root, | |||
const int rank, std::shared_ptr<GroupClient> group_client, | |||
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | |||
const Param& param, const DType& dtype, const std::string& backend, | |||
const OperatorNodeConfig& config, | |||
const std::shared_ptr<DTypeScalar>& disable) { | |||
auto inpvars = cg::to_var_node_array(inputs); | |||
auto opr = graph->insert_opr(std::make_unique<CollectiveComm>( | |||
inpvars, graph, key, nr_devices, rank, root, std::move(group_client), | |||
inpvars, graph, key, nr_devices, is_root, rank, std::move(group_client), | |||
param, dtype, backend, dev_buffer_arr, config, disable)); | |||
mgb_assert(!opr->output().empty()); | |||
return cg::to_symbol_var_array(opr->output()); | |||
@@ -452,11 +452,14 @@ SymbolVarArray CollectiveComm::make( | |||
void CollectiveComm::opr_register() { | |||
if (m_init) | |||
return; | |||
auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node()) | |||
.cuda_env(); | |||
auto&& comp_node = output(0)->comp_node(); | |||
auto reg_info = m_group_client->opr_register( | |||
m_key, m_nr_devices, m_is_root, m_rank, | |||
comp_node.get_uid()); | |||
auto hash = m_group_client->opr_register(m_key, m_nr_devices, m_rank, | |||
reinterpret_cast<uintptr_t>(cuda_env.stream)); | |||
m_rank = reg_info.rank; | |||
m_root = reg_info.root_rank; | |||
MegRayCommunicatorBuilder* builder; | |||
@@ -468,7 +471,7 @@ void CollectiveComm::opr_register() { | |||
} | |||
m_megray_comm = builder->get_megray_comm( | |||
hash, m_key, m_nr_devices, m_rank, | |||
reg_info.hash, m_key, m_nr_devices, m_rank, | |||
get_megray_backend(m_backend), m_group_client); | |||
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); | |||
@@ -606,8 +609,8 @@ VarNodeArray CollectiveComm::grad(const VarNodeArray& out_grads) const { | |||
} | |||
auto gvar = CollectiveComm::make( | |||
og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_rank, m_root, | |||
m_group_client, mode, m_dtype, m_backend, | |||
og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_is_root, | |||
m_rank, m_group_client, mode, m_dtype, m_backend, | |||
OperatorNodeConfig{}.comp_node_arr(cn_arr)); | |||
if (m_param.mode == Param::Mode::ALL_REDUCE_MAX) { | |||
@@ -733,11 +736,11 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm( | |||
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | |||
const OperatorNodeConfig& config) { | |||
auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>(); | |||
return opr::CollectiveComm::make(to_symbol_var_array(inputs), | |||
ctx.owner_graph(opr_, inputs), opr.key(), | |||
opr.nr_devices(), opr.rank(), opr.root(), | |||
opr.group_client(), opr.dev_buffers(), | |||
opr.param(), opr.dtype(), opr.backend(), config)[0] | |||
return opr::CollectiveComm::make( | |||
to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), | |||
opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), | |||
opr.group_client(), opr.dev_buffers(), opr.param(), | |||
opr.dtype(), opr.backend(), config)[0] | |||
.node() | |||
->owner_opr(); | |||
} | |||
@@ -6,8 +6,8 @@ decl_raw_opr( | |||
'to the same NCCL operation.', 'str'), | |||
Doc('nr_devices', 'Total number of devices involved in the NCCL ' | |||
'operation to which this operator belongs.', 'int'), | |||
Doc('rank', 'Rank of this operator', 'int'), | |||
Doc('root', 'root rank of broadcast or reduce operation'), | |||
Doc('is_root', 'whether this node is root node', 'bool'), | |||
Doc('rank', 'rank of this node, if is -1, generate one', 'int'), | |||
Doc('server_addr', 'rpc server ip address'), | |||
Doc('port', 'server rpc listening port'), | |||
Doc('param', 'The only component of *param* is *mode*, which refers to ' | |||
@@ -28,12 +28,12 @@ decl_raw_opr( | |||
body = [ | |||
'if isinstance(input, _mgb.SymbolVar):', | |||
(' output = _mgb._Opr.collective_comm_with_input(input, key, ' | |||
'nr_devices, rank, root, server_addr, port, ' | |||
'nr_devices, is_root, rank, server_addr, port, ' | |||
'[param.serialize()], dtype, backend, output_buffer, config, disable)'), | |||
'else:', | |||
' assert isinstance(input, _mgb.CompGraph)', | |||
(' output = _mgb._Opr.collective_comm_without_input(input, key, ' | |||
'nr_devices, rank, root, server_addr, port, ' | |||
'nr_devices, is_root, rank, server_addr, port, ' | |||
'[param.serialize()], dtype, backend, output_buffer, config, disable)') | |||
], | |||
desc = ('collective communication between multiple CompNodes on multiple ' | |||
@@ -16,16 +16,60 @@ using namespace opr; | |||
/* ================= GroupInfo ================= */ | |||
void GroupInfo::sort_opr_infos() { | |||
auto cmp = [](const GroupInfo::OprInfo& a, const GroupInfo::OprInfo& b) { | |||
return a.comp_node_hash < b.comp_node_hash; | |||
}; | |||
std::sort(m_opr_infos.begin(), m_opr_infos.end(), cmp); | |||
} | |||
void GroupInfo::gen_infos_from_opr_infos() { | |||
// generate rank | |||
bool rank_assgined = true; | |||
for (auto& opr_info:m_opr_infos) { | |||
if(opr_info.rank < 0) { | |||
rank_assgined = false; | |||
break; | |||
} | |||
} | |||
if (!rank_assgined) { | |||
for (size_t i = 0; i < m_opr_infos.size(); i++) { | |||
m_opr_infos[i].rank = i; | |||
m_rank_map.insert({m_opr_infos[i].comp_node_hash, i}); | |||
} | |||
} else { | |||
for (size_t i = 0; i < m_opr_infos.size(); i++) { | |||
m_rank_map.insert( | |||
{m_opr_infos[i].comp_node_hash, m_opr_infos[i].rank}); | |||
} | |||
} | |||
// generate root rank | |||
for (auto& opr_info:m_opr_infos) { | |||
if (opr_info.is_root) { | |||
m_root_rank = opr_info.rank; | |||
break; | |||
} | |||
} | |||
// generate group hash | |||
auto xxhash = XXHash{}; | |||
for (auto&& opr_info : m_opr_infos) { | |||
xxhash.update(&opr_info.comp_node_hash, sizeof(uint64_t)) | |||
.update(&opr_info.rank, sizeof(int)); | |||
} | |||
m_hash = xxhash.digest(); | |||
} | |||
void GroupInfo::add_opr(const std::string& key, size_t nr_expected_devices, | |||
uint32_t rank, uintptr_t stream) { | |||
bool is_root, int rank, uint64_t comp_node_hash) { | |||
std::unique_lock<std::mutex> lk{m_group_mtx}; | |||
if (m_nr_expected_devs == 0) { | |||
m_nr_expected_devs = nr_expected_devices; | |||
} else { | |||
mgb_assert(m_nr_expected_devs == nr_expected_devices); | |||
} | |||
OprInfo opr_info = {rank, stream}; | |||
m_opr_infos.push_back(std::move(opr_info)); | |||
m_opr_infos.push_back({comp_node_hash, is_root, rank}); | |||
m_nr_registered_devs++; | |||
m_count++; | |||
if (m_nr_registered_devs > nr_expected_devices) { | |||
@@ -38,6 +82,8 @@ void GroupInfo::add_opr(const std::string& key, size_t nr_expected_devices, | |||
key.c_str(), nr_expected_devices, m_nr_registered_devs); | |||
} | |||
if (m_nr_expected_devs == m_nr_registered_devs) { | |||
sort_opr_infos(); | |||
gen_infos_from_opr_infos(); | |||
m_register_cv.notify_all(); | |||
} else { | |||
m_register_cv.wait(lk, | |||
@@ -66,6 +112,8 @@ void GroupInfo::clear() { | |||
m_count--; | |||
if (m_count == 0) { | |||
m_opr_infos.clear(); | |||
m_rank_map.clear(); | |||
m_root_rank = -1; | |||
m_nr_expected_devs = 0; | |||
m_nr_registered_devs = 0; | |||
m_output_shape.invalidate(); | |||
@@ -77,14 +125,18 @@ void GroupInfo::clear() { | |||
/* ================= GroupManager ================= */ | |||
uint64_t GroupManager::opr_register(const std::string& key, size_t nr_devices, | |||
uint32_t rank, uintptr_t stream) { | |||
GroupManager::RegisterInfo GroupManager::opr_register(const std::string& key, | |||
size_t nr_devices, | |||
bool is_root, int rank, | |||
uint64_t comp_node_hash) { | |||
GroupManager::RegisterInfo ret{0, 0, 0}; | |||
auto&& group = get_group(key); | |||
group.add_opr(key, nr_devices, rank, stream); | |||
auto&& opr_infos = group.opr_infos(); | |||
uint64_t hash = get_hash_key(opr_infos, rank); | |||
group.add_opr(key, nr_devices, is_root, rank, comp_node_hash); | |||
ret.rank = group.get_rank(comp_node_hash); | |||
ret.root_rank = group.get_root_rank(); | |||
ret.hash = group.get_group_hash() + ret.rank; | |||
group.clear(); | |||
return hash; | |||
return ret; | |||
} | |||
std::vector<std::string> GroupManager::gather_uid(const std::string& uid, | |||
@@ -126,22 +178,6 @@ GroupInfo& GroupManager::get_group(const std::string& key) { | |||
return m_key2group_info[key]; | |||
} | |||
uint64_t GroupManager::get_hash_key(const std::vector<GroupInfo::OprInfo>& _infos, | |||
uint32_t rank) { | |||
auto cmp = [](const GroupInfo::OprInfo& lhs, const GroupInfo::OprInfo& rhs) { | |||
return lhs.rank < rhs.rank; | |||
}; | |||
auto infos = _infos; | |||
std::sort(infos.begin(), infos.end(), cmp); | |||
auto xxhash = XXHash{}; | |||
for (auto&& opr_info : infos) { | |||
xxhash.update(&opr_info.rank, sizeof(uint32_t)) | |||
.update(&opr_info.stream, sizeof(uintptr_t)); | |||
} | |||
xxhash.update(&rank, sizeof(uint32_t)); | |||
return xxhash.digest(); | |||
}; | |||
uint32_t GroupManager::group_barrier(uint32_t size, uint32_t rank) { | |||
std::unique_lock<std::mutex> lk{m_barrier_mtx}; | |||
if (m_barrier_set.empty()) { | |||
@@ -48,12 +48,11 @@ SymbolVar RemoteSend::make(const PeerDesc& peer, SymbolVar var, | |||
void RemoteSend::scn_do_execute() { | |||
if (!m_init) { | |||
auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node()) | |||
.cuda_env(); | |||
auto&& comp_node = output(0)->comp_node(); | |||
// rank 0 for RemoteSend | |||
auto hash = m_group_client->opr_register(m_peer.key, 2, 0, | |||
reinterpret_cast<uintptr_t>(cuda_env.stream)); | |||
auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false, | |||
comp_node.get_uid()); | |||
auto megray_comm_builder = | |||
owner_graph() | |||
@@ -62,7 +61,7 @@ void RemoteSend::scn_do_execute() { | |||
.get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||
m_megray_comm = megray_comm_builder->get_megray_comm( | |||
hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | |||
reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | |||
m_init = true; | |||
} | |||
@@ -152,12 +151,12 @@ SymbolVar RemoteRecv::make(const PeerDesc& peer, cg::ComputingGraph& graph, | |||
void RemoteRecv::scn_do_execute() { | |||
if (!m_init) { | |||
auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node()) | |||
.cuda_env(); | |||
auto&& comp_node = output(0)->comp_node(); | |||
// rank 1 for RemoteRecv | |||
auto hash = m_group_client->opr_register(m_peer.key, 2, 1, | |||
reinterpret_cast<uintptr_t>(cuda_env.stream)); | |||
auto reg_info = m_group_client->opr_register( | |||
m_peer.key, 2, false, 1, | |||
comp_node.get_uid()); | |||
auto megray_comm_builder = | |||
owner_graph() | |||
@@ -166,7 +165,7 @@ void RemoteRecv::scn_do_execute() { | |||
.get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||
m_megray_comm = megray_comm_builder->get_megray_comm( | |||
hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | |||
reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | |||
m_init = true; | |||
} | |||
@@ -68,9 +68,11 @@ private: | |||
void GroupServerProxy::opr_register(void* input_ptr, size_t input_len, | |||
std::string *output) { | |||
INFO_INIT(mm_handler, OprRegister); | |||
uint64_t hash = m_mgr.opr_register(req.key(), req.nr_expected_devices(), | |||
req.rank(), req.stream()); | |||
rsp.set_hash(hash); | |||
auto ret = m_mgr.opr_register(req.key(), req.nr_expected_devices(), | |||
req.is_root(), req.rank(), req.comp_node_hash()); | |||
rsp.set_hash(ret.hash); | |||
rsp.set_rank(ret.rank); | |||
rsp.set_root_rank(ret.root_rank); | |||
rsp.SerializeToString(output); | |||
} | |||
@@ -122,11 +124,11 @@ void GroupServerProxy::group_barrier(void* input_ptr, size_t input_len, | |||
/* ======================== GroupClientProxy ========================== */ | |||
#define INFO_INIT(space, f_name, name) \ | |||
#define INFO_INIT(space, f_name, name) \ | |||
using Request = space::name##Request; \ | |||
using Response = space::name##Response; \ | |||
std::string func_name = #f_name; \ | |||
Request req; \ | |||
std::string func_name = #f_name; \ | |||
Request req; \ | |||
Response rsp; | |||
#define SOLVE_REQUEST(name, req, rsp) \ | |||
@@ -145,15 +147,18 @@ GroupClientProxy::GroupClientProxy(const std::string& server_addr) | |||
m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} { | |||
} | |||
uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices, | |||
uint32_t rank, uintptr_t stream) { | |||
GroupManager::RegisterInfo GroupClientProxy::opr_register( | |||
const std::string& key, size_t nr_devices, bool is_root, int rank, | |||
uint64_t comp_node_hash) { | |||
INFO_INIT(mm_handler, opr_register, OprRegister) | |||
req.set_key(key); | |||
req.set_is_root(is_root); | |||
req.set_rank(rank); | |||
req.set_stream(stream); | |||
req.set_comp_node_hash(comp_node_hash); | |||
req.set_nr_expected_devices(nr_devices); | |||
SOLVE_REQUEST(func_name, req, rsp); | |||
return rsp.hash(); | |||
GroupManager::RegisterInfo ret{rsp.hash(), rsp.rank(), rsp.root_rank()}; | |||
return ret; | |||
} | |||
void GroupClientProxy::set_output_shape(const std::string& key, | |||
@@ -26,18 +26,19 @@ public: | |||
using Param = megdnn::param::CollectiveComm; | |||
CollectiveComm(VarNodeArray inputs, ComputingGraph* const graph, | |||
const std::string& key, const size_t nr_devices, const uint32_t rank, | |||
const uint32_t root, std::shared_ptr<GroupClient> group_client, | |||
const Param& param, const DType& dtype, const std::string& backend, | |||
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | |||
const OperatorNodeConfig& config, | |||
const std::shared_ptr<DTypeScalar>& disable); | |||
CollectiveComm( | |||
VarNodeArray inputs, ComputingGraph* const graph, | |||
const std::string& key, const size_t nr_devices, const bool is_root, | |||
const int rank, std::shared_ptr<GroupClient> group_client, | |||
const Param& param, const DType& dtype, const std::string& backend, | |||
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | |||
const OperatorNodeConfig& config, | |||
const std::shared_ptr<DTypeScalar>& disable); | |||
static SymbolVarArray make( | |||
const SymbolVarArray& inputs, ComputingGraph* const graph, | |||
const std::string& key, const size_t nr_devices, const uint32_t rank, | |||
const uint32_t root, std::shared_ptr<GroupClient> group_client, | |||
const std::string& key, const size_t nr_devices, const bool is_root, | |||
const int rank, std::shared_ptr<GroupClient> group_client, | |||
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | |||
const Param& param, const DType& dtype = {}, | |||
const std::string& backend = "nccl", | |||
@@ -45,15 +46,16 @@ public: | |||
const std::shared_ptr<DTypeScalar>& disable = | |||
std::make_shared<DTypeScalar>(0)); | |||
static SymbolVarArray make( | |||
const SymbolVarArray& inputs, ComputingGraph* const graph, | |||
const std::string& key, const size_t nr_devices, const uint32_t rank, | |||
const uint32_t root, std::shared_ptr<GroupClient> group_client, | |||
const Param& param, const DType& dtype = {}, | |||
const std::string& backend = "nccl", | |||
const OperatorNodeConfig& config = {}, | |||
const std::shared_ptr<DTypeScalar>& disable = | |||
std::make_shared<DTypeScalar>(0)); | |||
//! Create a CollectiveComm operator in \p graph and return its output vars. | | |
//! This overload takes no device buffer array. | | |
//! \param is_root whether this operator is the root of a BROADCAST / REDUCE operation | | |
//! \param rank rank of this operator | | |
//! Defaults: empty dtype, "nccl" backend, default config, and a DTypeScalar(0) for \p disable. | | |
static SymbolVarArray make(const SymbolVarArray& inputs, | | |
ComputingGraph* const graph, | | |
const std::string& key, const size_t nr_devices, | | |
const bool is_root, const int rank, | | |
std::shared_ptr<GroupClient> group_client, | | |
const Param& param, const DType& dtype = {}, | | |
const std::string& backend = "nccl", | | |
const OperatorNodeConfig& config = {}, | | |
const std::shared_ptr<DTypeScalar>& disable = | | |
std::make_shared<DTypeScalar>(0)); | | |
const Param& param() const { return m_param; } | |||
const DType& dtype() const { return m_dtype; } | |||
@@ -67,9 +69,9 @@ public: | |||
return m_dev_buffers; | |||
} | |||
uint32_t rank() const { return m_rank; } | |||
uint32_t root() const { return m_root; } | |||
bool is_root() const { return m_rank == m_root; } | |||
//! rank of this operator (m_rank) | | |
int rank() const { return m_rank; } | | |
//! rank of the root of BROADCAST and REDUCE operation (m_root) | | |
int root() const { return m_root; } | | |
//! stored root flag (m_is_root); it is returned as-is, not derived from rank/root here | | |
bool is_root() const { return m_is_root; } | | |
//! The key that identifies an NCCL clique. | |||
//! Operators with same keys belong to the same clique. | |||
@@ -108,12 +110,13 @@ private: | |||
std::shared_ptr<GroupClient> m_group_client; | |||
size_t m_nr_devices = 0; | |||
uint32_t m_rank; | |||
bool m_is_root; | |||
int m_rank; | |||
std::string m_key; | |||
//! XXHash generated from m_key | |||
size_t m_hash; | |||
//! root of BROADCAST and REDUCE operation | |||
uint32_t m_root; | |||
int m_root; | |||
//! rank of root of BROADCAST and REDUCE operation | |||
Maybe<TensorShape> m_broadcast_output_shape = None; | |||
// Whether shape infer is enabled. This is only used by BROADCAST operation, | |||
@@ -24,12 +24,13 @@ namespace opr { | |||
class GroupInfo { | |||
public: | |||
//! Per-operator registration record; GroupInfo keeps these in m_opr_infos. | | |
//! NOTE(review): two field layouts appear interleaved in this listing (rank is declared twice) -- confirm which members are current | | |
struct OprInfo { | | |
uint32_t rank; | | |
uintptr_t stream; | | |
uint64_t comp_node_hash; | | |
bool is_root; | | |
int rank; | | |
}; | | |
void add_opr(const std::string& key, size_t nr_expected_devices, | |||
uint32_t graph_id, uintptr_t stream); | |||
bool is_root, int rank, uint64_t comp_node_hash); | |||
void set_output_shape(const std::string& key, const TensorShape& shape); | |||
@@ -37,15 +38,25 @@ class GroupInfo { | |||
void clear(); | |||
const std::vector<OprInfo>& opr_infos() const {return m_opr_infos; } | |||
const std::vector<OprInfo>& opr_infos() const { return m_opr_infos; } | |||
//! rank of the root operator (m_root_rank, which is initialised to -1) | | |
int get_root_rank() const { return m_root_rank; } | | |
//! rank stored for \p hash in m_rank_map; std::unordered_map::at throws std::out_of_range when \p hash is absent | | |
//! NOTE(review): the key is presumably an operator's comp_node_hash -- confirm against the code filling m_rank_map | | |
int get_rank(uint64_t hash) const { return m_rank_map.at(hash); } | | |
//! hash value stored for this group (m_hash) | | |
uint64_t get_group_hash() const { return m_hash; } | | |
private: | |||
void sort_opr_infos(); | |||
void gen_infos_from_opr_infos(); | |||
std::vector<OprInfo> m_opr_infos; | |||
std::unordered_map<uint64_t, int> m_rank_map; | |||
uint64_t m_hash; | |||
uint32_t m_nr_registered_devs; | |||
uint32_t m_nr_expected_devs; | |||
Maybe<TensorShape> m_output_shape; | |||
uint32_t m_count = 0; | |||
int m_root_rank = -1; | |||
std::mutex m_group_mtx; | |||
std::condition_variable m_register_cv; | |||
std::condition_variable m_clear_cv; | |||
@@ -61,10 +72,16 @@ class GroupManager { | |||
public: | |||
~GroupManager() = default; | |||
//! Value returned by opr_register(). | | |
struct RegisterInfo | | |
{ | | |
//! deduplicated hash produced by registration | | |
uint64_t hash; | | |
//! rank of the registered operator, and rank of the root | | |
int rank, root_rank; | | |
}; | | |
//! register oprs' info to server, return the deduplicated hash together with rank and root_rank (see RegisterInfo) | | |
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, | |||
uintptr_t stream); | |||
RegisterInfo opr_register(const std::string& key, size_t nr_devices, | |||
bool is_root, int rank, uint64_t comp_node_hash); | |||
//! gather uids from all ranks | |||
std::vector<std::string> gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank); | |||
@@ -80,9 +97,6 @@ class GroupManager { | |||
private: | |||
GroupInfo& get_group(const std::string& key); | |||
uint64_t get_hash_key(const std::vector<GroupInfo::OprInfo>& _infos, | |||
uint32_t rank); | |||
//! key -> group info. | |||
std::unordered_map<std::string, GroupInfo> m_key2group_info; | |||
@@ -112,9 +126,11 @@ class GroupClient { | |||
virtual ~GroupClient() = default; | |||
public: | |||
virtual uint64_t opr_register(const std::string& key, size_t nr_devices, | |||
uint32_t rank, uintptr_t stream) = 0; | |||
//! Register one collective-comm operator and return its RegisterInfo (hash, rank, root_rank). | | |
//! \param key identifier of the group the operator belongs to | | |
//! \param nr_devices number of devices expected in the group | | |
//! \param is_root whether the operator is the root of the group | | |
//! \param rank rank passed by the caller | | |
//! \param comp_node_hash hash value identifying the operator's comp node -- NOTE(review): inferred from the name, confirm | | |
virtual GroupManager::RegisterInfo opr_register(const std::string& key, | | |
size_t nr_devices, | | |
bool is_root, int rank, | | |
uint64_t comp_node_hash) = 0; | | |
virtual std::vector<std::string> gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank) = 0; | |||
@@ -14,6 +14,7 @@ | |||
#if MGB_ENABLE_OPR_MM | |||
#include "megbrain/opr/collective_comm.h" | |||
#include "megbrain/opr/group_manager.h" | |||
using namespace mgb; | |||
using namespace opr; | |||
@@ -31,8 +32,10 @@ public: | |||
GroupClientProxy(const std::string& server_addr); | |||
//! operator registration: returns the RegisterInfo (hash, rank, root_rank) for this operator. | | |
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, | |||
uintptr_t stream) override; | |||
GroupManager::RegisterInfo opr_register(const std::string& key, | |||
size_t nr_devices, bool is_root, | |||
int rank, | |||
uint64_t comp_node_hash) override; | |||
std::vector<std::string> gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank) override; | |||
@@ -4,13 +4,16 @@ package mm_handler; | |||
// Request carrying the arguments of opr_register for one collective-comm operator. | | |
// NOTE(review): this listing shows two field layouts at once (tags 2, 3 and 4 each appear twice); | | |
// confirm which set is current before relying on the tag numbers. | | |
message OprRegisterRequest { | | |
string key = 1; | | |
uint32 rank = 2; | | |
uint64 stream = 3; | | |
uint32 nr_expected_devices = 4; | | |
bool is_root = 2; | | |
int32 rank = 3; | | |
uint64 comp_node_hash = 4; | | |
uint32 nr_expected_devices = 5; | | |
} | | |
// Reply to OprRegisterRequest: the hash plus the rank and root_rank values, | | |
// i.e. the same three fields as GroupManager::RegisterInfo. | | |
message OprRegisterResponse { | | |
uint64 hash = 1; | | |
uint64 hash = 1; | | |
int32 rank = 2; | | |
int32 root_rank = 3; | | |
} | | |
message GatherUidRequest { | |||
@@ -45,9 +45,11 @@ class MockGroupClient final : public opr::GroupClient { | |||
public: | |||
~MockGroupClient() override = default; | |||
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, | |||
uintptr_t stream) { | |||
return m_mgr.opr_register(key, nr_devices, rank, stream); | |||
//! Forward registration to the in-process GroupManager (m_mgr). | | |
//! The last parameter is uint64_t comp_node_hash, exactly as in GroupClient::opr_register, | | |
//! so this really overrides the pure virtual even where uintptr_t and uint64_t are distinct types; | | |
//! `override` makes the compiler reject any future signature drift. | | |
opr::GroupManager::RegisterInfo opr_register(const std::string& key, | | |
size_t nr_devices, | | |
bool is_root, int rank, | | |
uint64_t comp_node_hash) override { | | |
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | | |
} | | |
std::vector<std::string> gather_uid(const std::string& uid, | |||
@@ -94,9 +96,9 @@ TEST(TestOprCollectiveComm, AllReduce) { | |||
auto x1c = opr::Copy::make(x1, cn1); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_reduce", | |||
2, 0, 0, client, {mode}, dtype::Float32(), "nccl")[0]; | |||
2, false, 0, client, {mode}, dtype::Float32(), "nccl")[0]; | |||
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_reduce", | |||
2, 1, 0, client, {mode}, dtype::Float32(), "nccl")[0]; | |||
2, false, 1, client, {mode}, dtype::Float32(), "nccl")[0]; | |||
auto y_expect = make_all_reduce_output(mode, {x0, x1}); | |||
auto func = graph->compile({make_callback_copy(y0, host_y0), | |||
@@ -130,7 +132,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) { | |||
auto graph0 = ComputingGraph::make(); | |||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", | |||
2, 0, 0, client, {mode}, dtype::Float32(), "nccl")[0]; | |||
2, false, 0, client, {mode}, dtype::Float32(), "nccl")[0]; | |||
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); | |||
func0->execute(); | |||
}; | |||
@@ -139,7 +141,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) { | |||
auto graph1 = ComputingGraph::make(); | |||
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", | |||
2, 1, 0, client, {mode}, dtype::Float32(), "nccl")[0]; | |||
2, false, 1, client, {mode}, dtype::Float32(), "nccl")[0]; | |||
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); | |||
func1->execute(); | |||
}; | |||
@@ -192,7 +194,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { | |||
graph0->options().graph_opt_level = 0; | |||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", 2, 0, 0, client, | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", 2, false, 0, client, | |||
{Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||
y0.node()->owner_opr()->node_prop().attribute().priority = -1; | |||
@@ -211,7 +213,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { | |||
graph1->options().graph_opt_level = 0; | |||
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", 2, 1, 0, client, | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", 2, false, 1, client, | |||
{Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||
y1.node()->owner_opr()->node_prop().attribute().priority = -1; | |||
@@ -274,9 +276,9 @@ TEST(TestOprCollectiveComm, AllGather) { | |||
auto x1c = opr::Copy::make(x1, cn1); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_gather", | |||
2, 0, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; | |||
2, false, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; | |||
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_gather", | |||
2, 1, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; | |||
2, false, 1, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; | |||
auto y_expect = opr::Concat::make({x0, x1}, 0); | |||
auto func = graph->compile({make_callback_copy(y0, host_y0), | |||
@@ -303,7 +305,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) { | |||
auto run_0 = [&]() { // rank 0 | |||
auto graph0 = ComputingGraph::make(); | |||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, 0, 0, client, | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, false, 0, client, | |||
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; | |||
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); | |||
func0->execute(); | |||
@@ -312,7 +314,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) { | |||
auto run_1 = [&]() { // rank 1 | |||
auto graph1 = ComputingGraph::make(); | |||
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, 1, 0, client, | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, false, 1, client, | |||
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; | |||
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); | |||
func1->execute(); | |||
@@ -361,7 +363,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { | |||
graph0->options().graph_opt_level = 0; | |||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, 0, 0, client, | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, false, 0, client, | |||
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; | |||
y0.node()->owner_opr()->node_prop().attribute().priority = -1; | |||
@@ -380,7 +382,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { | |||
graph1->options().graph_opt_level = 0; | |||
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, 1, 0, client, | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, false, 1, client, | |||
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; | |||
y1.node()->owner_opr()->node_prop().attribute().priority = -1; | |||
@@ -444,9 +446,9 @@ TEST(TestOprCollectiveComm, ReduceScatterSum) { | |||
auto x1c = opr::Copy::make(x1, cn1); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_scatter_sum", | |||
2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||
2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_scatter_sum", | |||
2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||
2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||
auto y_expect = make_reduce_scatter_sum_output({x0, x1}); | |||
auto func = graph->compile({make_callback_copy(y0, host_y0), | |||
@@ -475,7 +477,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) { | |||
auto graph0 = ComputingGraph::make(); | |||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum", | |||
2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||
2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); | |||
func0->execute(); | |||
}; | |||
@@ -484,7 +486,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) { | |||
auto graph1 = ComputingGraph::make(); | |||
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum", | |||
2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||
2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); | |||
func1->execute(); | |||
}; | |||
@@ -534,7 +536,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { | |||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum", | |||
2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||
2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||
y0.node()->owner_opr()->node_prop().attribute().priority = -1; | |||
auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0); | |||
@@ -553,7 +555,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { | |||
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum", | |||
2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||
2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||
y1.node()->owner_opr()->node_prop().attribute().priority = -1; | |||
auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); | |||
@@ -616,9 +618,9 @@ TEST(TestOprCollectiveComm, ReduceSum) { | |||
auto x1c = opr::Copy::make(x1, cn1); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_sum", | |||
2, 0, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||
2, true, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_sum", | |||
2, 1, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||
2, false, 1, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||
auto y_expect = x0 + x1; | |||
auto func = graph->compile({make_callback_copy(y0, host_y0), | |||
@@ -644,7 +646,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) { | |||
auto run_0 = [&]() { // rank 0 | |||
auto graph0 = ComputingGraph::make(); | |||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, 0, 0, client, | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, true, 0, client, | |||
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); | |||
func0->execute(); | |||
@@ -653,7 +655,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) { | |||
auto run_1 = [&]() { // rank 1 | |||
auto graph1 = ComputingGraph::make(); | |||
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, 1, 0, client, | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, false, 1, client, | |||
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||
auto func1 = graph1->compile({{y1, nullptr}}); | |||
func1->execute(); | |||
@@ -699,7 +701,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { | |||
graph0->options().graph_opt_level = 0; | |||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, 0, 0, client, | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, true, 0, client, | |||
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||
y0.node()->owner_opr()->node_prop().attribute().priority = -1; | |||
@@ -718,7 +720,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { | |||
graph1->options().graph_opt_level = 0; | |||
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, 1, 0, client, | |||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, false, 1, client, | |||
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||
y1.node()->owner_opr()->node_prop().attribute().priority = -1; | |||
@@ -767,12 +769,12 @@ TEST(TestOprCollectiveComm, Broadcast) { | |||
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "broadcast", | |||
2, 0, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; | |||
2, true, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; | |||
auto y_dev = std::make_shared<DeviceTensorND>(DeviceTensorND() | |||
.comp_node(cn1) | |||
.dtype(dtype::Float32()) | |||
.resize(host_x0->shape())); | |||
auto y1 = opr::CollectiveComm::make({}, graph.get(), "broadcast", 2, 1, 0, | |||
auto y1 = opr::CollectiveComm::make({}, graph.get(), "broadcast", 2, false, 1, | |||
client, {y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; | |||
auto func = graph->compile({make_callback_copy(y0, host_y0), | |||
@@ -797,7 +799,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) { | |||
auto run_0 = [&]() { // rank 0 | |||
auto graph0 = ComputingGraph::make(); | |||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, 0, 0, client, | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, true, 0, client, | |||
{Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; | |||
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); | |||
func0->execute(); | |||
@@ -809,7 +811,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) { | |||
.comp_node(cn1) | |||
.dtype(dtype::Float32()) | |||
.resize(host_x0->shape())); | |||
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, 1, 0, client, | |||
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, false, 1, client, | |||
{y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; | |||
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); | |||
func1->execute(); | |||
@@ -845,7 +847,7 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) { | |||
graph0->options().graph_opt_level = 0; | |||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, 0, 0, client, | |||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, true, 0, client, | |||
{Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; | |||
y0.node()->owner_opr()->node_prop().attribute().priority = -1; | |||
@@ -863,11 +865,11 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) { | |||
auto graph1 = ComputingGraph::make(); | |||
graph1->options().graph_opt_level = 0; | |||
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, 1, 0, client, | |||
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, false, 1, client, | |||
{Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; | |||
auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); | |||
auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "broadcast:grad", 2, 1, 0, client, | |||
auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "broadcast:grad", 2, false, 1, client, | |||
Mode::REDUCE_SUM, dtype::Float32(), "nccl")[0]; | |||
g.node()->owner_opr()->node_prop().attribute().priority = 1; | |||
@@ -26,11 +26,13 @@ class MockGroupClient final : public opr::GroupClient { | |||
public: | |||
~MockGroupClient() override = default; | |||
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, | |||
uintptr_t stream) { | |||
return m_mgr.opr_register(key, nr_devices, rank, stream); | |||
//! Forward registration to the in-process GroupManager (m_mgr) and return its RegisterInfo. | | |
//! Marked `override` like the destructor above, so a mismatch with GroupClient::opr_register | | |
//! becomes a compile error instead of silently declaring an unrelated overload. | | |
opr::GroupManager::RegisterInfo opr_register(const std::string& key, | | |
size_t nr_devices, | | |
bool is_root, int rank, | | |
uint64_t comp_node_hash) override { | | |
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | | |
} | | |
std::vector<std::string> gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank) { | |||
return m_mgr.gather_uid(uid, key, size, rank); | |||