GitOrigin-RevId: dccdb71553
tags/v0.5.0
@@ -26,25 +26,17 @@ def reduce_sum(
     tensor: Tensor,
     key: str,
     nr_ranks: Optional[int] = None,
-    rank: Optional[int] = None,
-    root: Optional[int] = 0,
+    is_root: Optional[bool] = None,
 ) -> Tensor:
     """Create reduce_sum operator for collective communication
     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
-    :param rank: rank of the current process, use util.get_rank() as default
-    :param root: rank of root node, use 0 as default
+    :param is_root: whether this is a root node
     """
     return _collective_comm(
-        tensor,
-        key,
-        CollParam.Mode.REDUCE_SUM,
-        nr_ranks,
-        rank,
-        root,
-        device=tensor.device,
+        tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
     )
@@ -52,24 +44,21 @@ def broadcast(
     tensor: Tensor,
     key: str,
     nr_ranks: Optional[int] = None,
-    rank: Optional[int] = None,
-    root: Optional[int] = 0,
+    is_root: Optional[bool] = None,
 ) -> Tensor:
     """Create broadcast operator for collective communication
     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
-    :param rank: rank of the current process, use util.get_rank() as default
-    :param root: rank of root node, use 0 as default
+    :param is_root: whether this is a root node
     """
     if key is None:
         key = tensor._symvar.name
-    if rank is None:
-        rank = get_rank()
-    if rank == root:
+    if is_root is None:
+        is_root = get_rank() == 0
+    if is_root:
         inp = tensor
     else:
         inp = tensor._symvar.owner_graph
@@ -79,8 +68,7 @@ def broadcast(
         key,
         CollParam.Mode.BROADCAST,
         nr_ranks,
-        rank,
-        root,
+        is_root,
         dtype=tensor.dtype,
         device=tensor.device,
     )
@@ -94,9 +82,9 @@ def all_gather(
     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
-    :param rank: rank of the current process, use util.get_rank() as default
+    :param rank: rank of this node
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank, 0)
+    return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank)

 def reduce_scatter_sum(
@@ -107,69 +95,58 @@ def reduce_scatter_sum(
     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
-    :param rank: rank of the current process, use util.get_rank() as default
+    :param rank: rank of this node
     """
     return _collective_comm(
-        tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank
+        tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank,
     )

-def all_reduce_sum(
-    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
-) -> Tensor:
+def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
     """Create all_reduce_sum operator for collective communication
     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
-    :param rank: rank of the current process, use util.get_rank() as default
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks, rank)
+    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks)

-def all_reduce_max(
-    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
-) -> Tensor:
+def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
     """Create all_reduce_max operator for collective communication
     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
-    :param rank: rank of the current process, use util.get_rank() as default
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks, rank)
+    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks)

-def all_reduce_min(
-    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
-) -> Tensor:
+def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
     """Create all_reduce_min operator for collective communication
     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
-    :param rank: rank of the current process, use util.get_rank() as default
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks, rank)
+    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks)

 def bcast_param(
     inp: Union[Buffer, Parameter],
     key: str,
     nr_ranks: Optional[int] = None,
-    rank: Optional[int] = None,
-    root: Optional[int] = 0,
+    is_root: Optional[bool] = None,
 ) -> None:
     """Broadcast parameters among devices
     :param inp: input Buffer or Parameter to be synchronized
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
-    :param rank: rank of the current process, use util.get_rank() as default
-    :param root: rank of root node, use 0 as default
+    :param is_root: whether this is a root node
     """
     if not is_distributed():
         return
     assert isinstance(inp, (Buffer, Parameter))
-    bcast_res = broadcast(inp, key, nr_ranks, rank, root)
+    bcast_res = broadcast(inp, key, nr_ranks, is_root)
     add_update(inp, bcast_res, alpha=0)
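Note on the hunks above: the functional wrappers now accept an is_root flag in place of the old rank/root pair, and broadcast (hence bcast_param) treats the rank-0 process as the root when the flag is left as None. A minimal, runnable sketch of that default resolution in plain Python; get_rank here is only a stand-in callable for megengine.distributed.util.get_rank, not the real API:

from typing import Callable, Optional


def resolve_is_root(is_root: Optional[bool], get_rank: Callable[[], int]) -> bool:
    # Mirrors the new broadcast() fallback: when the caller does not say
    # whether it is the root, the process with rank 0 assumes that role.
    if is_root is None:
        return get_rank() == 0
    return is_root


if __name__ == "__main__":
    assert resolve_is_root(None, lambda: 0) is True   # rank 0 becomes the root
    assert resolve_is_root(None, lambda: 3) is False  # other ranks do not
    assert resolve_is_root(True, lambda: 3) is True   # an explicit flag wins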
@@ -19,8 +19,8 @@ def collective_comm_symvar(
     key: str,
     op: CollParam.Mode,
     nr_ranks: Optional[int] = None,
+    is_root: Optional[bool] = None,
     rank: Optional[int] = None,
-    root: Optional[int] = 0,
     dtype: Optional[type] = None,
     device: Optional[mgb.CompNode] = None,
     comp_graph: Optional[mgb.CompGraph] = None,
@@ -31,8 +31,7 @@ def collective_comm_symvar(
     :param key: unique identifier for collective communication
     :param op: mode of collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
-    :param rank: rank of the current process, use util.get_rank() as default
-    :param root: rank of root node, use 0 as default
+    :param is_root: whether this node is root node
     :param dtype: output data type, use dtype of inp as default
     :param device: output comp node, use comp node of inp as default
     :param comp_graph: output comp graph, use comp graph of inp as default
@@ -41,8 +40,8 @@ def collective_comm_symvar(
         inp,
         key=str(key),
         nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
-        rank=rank if rank is not None else get_rank(),
-        root=root,
+        is_root=is_root if is_root is not None else (get_rank() == 0),
+        rank=rank if rank is not None else -1,
         server_addr=get_master_ip(),
         port=get_master_port(),
         param=CollParam(mode=op),
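The helper above now forwards two defaults to the underlying operator: is_root falls back to "this process has rank 0", and a missing rank is sent as -1, the sentinel that asks the group server to assign one (see the GroupManager changes further down). A small illustrative sketch of the same fallback logic, with get_rank again a stand-in callable rather than the real utility:

from typing import Callable, Optional, Tuple


def resolve_group_args(
    is_root: Optional[bool], rank: Optional[int], get_rank: Callable[[], int]
) -> Tuple[bool, int]:
    # is_root defaults to "am I rank 0?", as in collective_comm_symvar above.
    resolved_is_root = is_root if is_root is not None else (get_rank() == 0)
    # -1 asks the group server to pick a rank during registration.
    resolved_rank = rank if rank is not None else -1
    return resolved_is_root, resolved_rank


if __name__ == "__main__":
    assert resolve_group_args(None, None, lambda: 0) == (True, -1)
    assert resolve_group_args(None, 2, lambda: 1) == (False, 2)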
@@ -17,7 +17,13 @@ import numpy as np
 from .._internal.config import opr_priority_scope
 from ..core import Buffer, Parameter, Tensor, TensorDict
 from ..core.graph import get_default_graph
-from ..distributed import all_reduce_sum, bcast_param, get_world_size, is_distributed
+from ..distributed import (
+    all_reduce_sum,
+    bcast_param,
+    get_rank,
+    get_world_size,
+    is_distributed,
+)
 from ..distributed.util import get_group_id
 from ..functional import add_update
 from ..functional import grad as grad_func
@@ -222,7 +228,11 @@ class Optimizer(metaclass=ABCMeta):
     def bcast_param(self):
         for group in self.param_groups:
             for param in group["params"]:
-                bcast_param(param, "bcast_param_" + str(get_group_id()))
+                bcast_param(
+                    param,
+                    "bcast_param_" + str(get_group_id()),
+                    is_root=(get_rank() == 0),
+                )

     def state_dict(self) -> Dict:
         r"""Export the optimizer state.
@@ -93,10 +93,9 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
 }

 SymbolVar _Opr::collective_comm_with_input(
-        SymbolVar inpvar, const std::string& key,
-        const size_t nr_devices, const uint32_t rank, const uint32_t root,
-        const std::string& server_addr, const int port,
-        PyObject* params, PyObject* dtype,
+        SymbolVar inpvar, const std::string& key, const size_t nr_devices,
+        const bool is_root, const int rank, const std::string& server_addr,
+        const int port, PyObject* params, PyObject* dtype,
         const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
     SymbolVarArray inputs(1, inpvar);
@@ -111,15 +110,15 @@ SymbolVar _Opr::collective_comm_with_input(
     if (dtype != Py_None) {
         _dtype = npy::dtype_np2mgb(dtype);
     }
-    return CollectiveComm::make(inputs, graph, key, nr_devices, rank, root, group_mgr,
-            dev_buffer_arr, param, _dtype, backend, config, disable.get_val())[0];
+    return CollectiveComm::make(inputs, graph, key, nr_devices, is_root, rank,
+                                group_mgr, dev_buffer_arr, param, _dtype,
+                                backend, config, disable.get_val())[0];
 }

 SymbolVar _Opr::collective_comm_without_input(
-        CompGraph& cg, const std::string& key,
-        const size_t nr_devices, const uint32_t rank, const uint32_t root,
-        const std::string& server_addr, const int port,
-        PyObject* params, PyObject* dtype,
+        CompGraph& cg, const std::string& key, const size_t nr_devices,
+        const bool is_root, const int rank, const std::string& server_addr,
+        const int port, PyObject* params, PyObject* dtype,
         const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
     SymbolVarArray inputs;
@@ -134,8 +133,9 @@ SymbolVar _Opr::collective_comm_without_input(
     if (dtype != Py_None) {
         _dtype = npy::dtype_np2mgb(dtype);
     }
-    return CollectiveComm::make(inputs, &graph, key, nr_devices, rank, root, group_mgr,
-            dev_buffer_arr, param, _dtype, backend, config, disable.get_val())[0];
+    return CollectiveComm::make(inputs, &graph, key, nr_devices, is_root, rank,
+                                group_mgr, dev_buffer_arr, param, _dtype,
+                                backend, config, disable.get_val())[0];
 }

 #else
@@ -172,7 +172,7 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,

 SymbolVar _Opr::collective_comm_with_input(
         SymbolVar inpvar, const std::string& key,
-        const size_t nr_devices, const uint32_t rank, const uint32_t root,
+        const size_t nr_devices, const bool is_root, const int rank,
         const std::string& server_addr, const int port, PyObject* params,
         PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
@@ -181,7 +181,7 @@ SymbolVar _Opr::collective_comm_with_input(

 SymbolVar _Opr::collective_comm_without_input(
         CompGraph& cg, const std::string& key,
-        const size_t nr_devices, const uint32_t rank, const uint32_t root,
+        const size_t nr_devices, const bool is_root, const int rank,
         const std::string& server_addr, const int port, PyObject* params,
         PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
@@ -94,17 +94,17 @@ static SymbolVar remote_recv(const std::string& server_addr, const int port,

 static SymbolVar collective_comm_with_input(
         SymbolVar inpvar, const std::string& key, const size_t nr_devices,
-        const uint32_t rank, const uint32_t root, const std::string& server_addr,
-        const int port, PyObject* params, PyObject* dtype,
-        const std::string& backend, SharedND* output_buf,
-        const OperatorNodeConfig& config, const SharedScalar& disable);
+        const bool is_root, const int rank, const std::string& server_addr, const int port,
+        PyObject* params, PyObject* dtype, const std::string& backend,
+        SharedND* output_buf, const OperatorNodeConfig& config,
+        const SharedScalar& disable);

 static SymbolVar collective_comm_without_input(
         CompGraph& graph, const std::string& key, const size_t nr_devices,
-        const uint32_t rank, const uint32_t root, const std::string& server_addr,
-        const int port, PyObject* params, PyObject* dtype,
-        const std::string& backend, SharedND* output_buf,
-        const OperatorNodeConfig& config, const SharedScalar& disable);
+        const bool is_root, const int rank, const std::string& server_addr, const int port,
+        PyObject* params, PyObject* dtype, const std::string& backend,
+        SharedND* output_buf, const OperatorNodeConfig& config,
+        const SharedScalar& disable);

 // misc
 static SymbolVarArray extern_c_opr_placeholder(
@@ -102,7 +102,7 @@ def test_all_gather():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_gather(inp, "x")
+        output = dist.functional.all_gather(inp, "x", rank=rank)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):
@@ -135,7 +135,7 @@ def test_reduce_scatter_sum():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.reduce_scatter_sum(inp, "x")
+        output = dist.functional.reduce_scatter_sum(inp, "x", rank=rank)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):
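The tests above now pass the worker's rank explicitly to all_gather and reduce_scatter_sum. As a reminder of what the expect arrays encode (shapes below are illustrative, not taken from the test file): all_gather concatenates every rank's input, while reduce_scatter_sum adds the inputs element-wise and splits the sum evenly across ranks.

import numpy as np

world_size = 2
inputs = [np.arange(4, dtype="float32") + r for r in range(world_size)]

# all_gather: every rank receives the concatenation of all inputs.
all_gather_expect = np.concatenate(inputs)

# reduce_scatter_sum: the element-wise sum is split evenly across ranks.
reduce_scatter_parts = np.split(np.sum(inputs, axis=0), world_size)

assert all_gather_expect.shape == (8,)
assert [part.shape for part in reduce_scatter_parts] == [(2,), (2,)]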
@@ -368,8 +368,8 @@ CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {

 CollectiveComm::CollectiveComm(
         VarNodeArray inputs, ComputingGraph* const graph,
-        const std::string& key, const size_t nr_devices, const uint32_t rank,
-        const uint32_t root, std::shared_ptr<GroupClient> group_client,
+        const std::string& key, const size_t nr_devices, const bool is_root,
+        const int rank, std::shared_ptr<GroupClient> group_client,
         const Param& param, const DType& dtype, const std::string& backend,
         const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
         const OperatorNodeConfig& config,
@@ -380,9 +380,9 @@ CollectiveComm::CollectiveComm(
           m_backend(backend),
           m_group_client{std::move(group_client)},
           m_nr_devices(nr_devices),
+          m_is_root(is_root),
           m_rank(rank),
           m_key(key),
-          m_root(root),
           m_dev_buffers(dev_buffer_arr),
           m_disable{disable} {
     for (auto i : inputs) {
@@ -422,28 +422,28 @@ CollectiveComm::CollectiveComm(
 SymbolVarArray CollectiveComm::make(
         const SymbolVarArray& inputs, ComputingGraph* const graph,
-        const std::string& key, const size_t nr_devices, const uint32_t rank,
-        const uint32_t root, std::shared_ptr<GroupClient> group_client,
-        const Param& param, const DType& dtype, const std::string& backend,
-        const OperatorNodeConfig& config,
+        const std::string& key, const size_t nr_devices, const bool is_root,
+        const int rank, std::shared_ptr<GroupClient> group_client,
+        const Param& param, const DType& dtype, const std::string& backend,
+        const OperatorNodeConfig& config,
         const std::shared_ptr<DTypeScalar>& disable) {
     SmallVector<std::shared_ptr<DeviceTensorND>> dev_buffer_arr(nr_devices,
                                                                 nullptr);
-    return make(inputs, graph, key, nr_devices, rank, root, group_client,
+    return make(inputs, graph, key, nr_devices, is_root, rank, group_client,
                 dev_buffer_arr, param, dtype, backend, config);
 }

 SymbolVarArray CollectiveComm::make(
         const SymbolVarArray& inputs, ComputingGraph* const graph,
-        const std::string& key, const size_t nr_devices, const uint32_t rank,
-        const uint32_t root, std::shared_ptr<GroupClient> group_client,
+        const std::string& key, const size_t nr_devices, const bool is_root,
+        const int rank, std::shared_ptr<GroupClient> group_client,
         const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
         const Param& param, const DType& dtype, const std::string& backend,
         const OperatorNodeConfig& config,
         const std::shared_ptr<DTypeScalar>& disable) {
     auto inpvars = cg::to_var_node_array(inputs);
     auto opr = graph->insert_opr(std::make_unique<CollectiveComm>(
-            inpvars, graph, key, nr_devices, rank, root, std::move(group_client),
+            inpvars, graph, key, nr_devices, is_root, rank, std::move(group_client),
             param, dtype, backend, dev_buffer_arr, config, disable));
     mgb_assert(!opr->output().empty());
     return cg::to_symbol_var_array(opr->output());
@@ -452,11 +452,14 @@ SymbolVarArray CollectiveComm::make(
 void CollectiveComm::opr_register() {
     if (m_init)
         return;
-    auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node())
-                              .cuda_env();
-    auto hash = m_group_client->opr_register(m_key, m_nr_devices, m_rank,
-            reinterpret_cast<uintptr_t>(cuda_env.stream));
+    auto&& comp_node = output(0)->comp_node();
+    auto reg_info = m_group_client->opr_register(
+            m_key, m_nr_devices, m_is_root, m_rank,
+            comp_node.get_uid());
+    m_rank = reg_info.rank;
+    m_root = reg_info.root_rank;

     MegRayCommunicatorBuilder* builder;
@@ -468,7 +471,7 @@ void CollectiveComm::opr_register() {
     }

     m_megray_comm = builder->get_megray_comm(
-            hash, m_key, m_nr_devices, m_rank,
+            reg_info.hash, m_key, m_nr_devices, m_rank,
             get_megray_backend(m_backend), m_group_client);

     m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
@@ -606,8 +609,8 @@ VarNodeArray CollectiveComm::grad(const VarNodeArray& out_grads) const {
     }

     auto gvar = CollectiveComm::make(
-            og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_rank, m_root,
-            m_group_client, mode, m_dtype, m_backend,
+            og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_is_root,
+            m_rank, m_group_client, mode, m_dtype, m_backend,
             OperatorNodeConfig{}.comp_node_arr(cn_arr));

     if (m_param.mode == Param::Mode::ALL_REDUCE_MAX) {
@@ -733,11 +736,11 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm(
         const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
         const OperatorNodeConfig& config) {
     auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>();
-    return opr::CollectiveComm::make(to_symbol_var_array(inputs),
-                                     ctx.owner_graph(opr_, inputs), opr.key(),
-                                     opr.nr_devices(), opr.rank(), opr.root(),
-                                     opr.group_client(), opr.dev_buffers(),
-                                     opr.param(), opr.dtype(), opr.backend(), config)[0]
+    return opr::CollectiveComm::make(
+                   to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs),
+                   opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(),
+                   opr.group_client(), opr.dev_buffers(), opr.param(),
+                   opr.dtype(), opr.backend(), config)[0]
             .node()
             ->owner_opr();
 }
@@ -6,8 +6,8 @@ decl_raw_opr(
              'to the same NCCL operation.', 'str'),
          Doc('nr_devices', 'Total number of devices involved in the NCCL '
              'operation to which this operator belongs.', 'int'),
-         Doc('rank', 'Rank of this operator', 'int'),
-         Doc('root', 'root rank of broadcast or reduce operation'),
+         Doc('is_root', 'whether this node is root node', 'bool'),
+         Doc('rank', 'rank of this node, if is -1, generate one', 'int'),
          Doc('server_addr', 'rpc server ip address'),
          Doc('port', 'server rpc listening port'),
          Doc('param', 'The only component of *param* is *mode*, which refers to '
@@ -28,12 +28,12 @@ decl_raw_opr(
     body = [
         'if isinstance(input, _mgb.SymbolVar):',
         ('    output = _mgb._Opr.collective_comm_with_input(input, key, '
-         'nr_devices, rank, root, server_addr, port, '
+         'nr_devices, is_root, rank, server_addr, port, '
         '[param.serialize()], dtype, backend, output_buffer, config, disable)'),
         'else:',
         '    assert isinstance(input, _mgb.CompGraph)',
        ('    output = _mgb._Opr.collective_comm_without_input(input, key, '
-         'nr_devices, rank, root, server_addr, port, '
+         'nr_devices, is_root, rank, server_addr, port, '
         '[param.serialize()], dtype, backend, output_buffer, config, disable)')
     ],
     desc = ('collective communication between multiple CompNodes on multiple '
@@ -16,16 +16,60 @@ using namespace opr;

 /* ================= GroupInfo ================= */

+void GroupInfo::sort_opr_infos() {
+    auto cmp = [](const GroupInfo::OprInfo& a, const GroupInfo::OprInfo& b) {
+        return a.comp_node_hash < b.comp_node_hash;
+    };
+    std::sort(m_opr_infos.begin(), m_opr_infos.end(), cmp);
+}
+
+void GroupInfo::gen_infos_from_opr_infos() {
+    // generate rank
+    bool rank_assigned = true;
+    for (auto& opr_info : m_opr_infos) {
+        if (opr_info.rank < 0) {
+            rank_assigned = false;
+            break;
+        }
+    }
+    if (!rank_assigned) {
+        for (size_t i = 0; i < m_opr_infos.size(); i++) {
+            m_opr_infos[i].rank = i;
+            m_rank_map.insert({m_opr_infos[i].comp_node_hash, i});
+        }
+    } else {
+        for (size_t i = 0; i < m_opr_infos.size(); i++) {
+            m_rank_map.insert(
+                    {m_opr_infos[i].comp_node_hash, m_opr_infos[i].rank});
+        }
+    }
+
+    // generate root rank
+    for (auto& opr_info : m_opr_infos) {
+        if (opr_info.is_root) {
+            m_root_rank = opr_info.rank;
+            break;
+        }
+    }
+
+    // generate group hash
+    auto xxhash = XXHash{};
+    for (auto&& opr_info : m_opr_infos) {
+        xxhash.update(&opr_info.comp_node_hash, sizeof(uint64_t))
+                .update(&opr_info.rank, sizeof(int));
+    }
+    m_hash = xxhash.digest();
+}
+
 void GroupInfo::add_opr(const std::string& key, size_t nr_expected_devices,
-                        uint32_t rank, uintptr_t stream) {
+                        bool is_root, int rank, uint64_t comp_node_hash) {
     std::unique_lock<std::mutex> lk{m_group_mtx};
     if (m_nr_expected_devs == 0) {
         m_nr_expected_devs = nr_expected_devices;
     } else {
         mgb_assert(m_nr_expected_devs == nr_expected_devices);
     }
-    OprInfo opr_info = {rank, stream};
-    m_opr_infos.push_back(std::move(opr_info));
+    m_opr_infos.push_back({comp_node_hash, is_root, rank});
     m_nr_registered_devs++;
     m_count++;
     if (m_nr_registered_devs > nr_expected_devices) {
@@ -38,6 +82,8 @@ void GroupInfo::add_opr(const std::string& key, size_t nr_expected_devices,
                 key.c_str(), nr_expected_devices, m_nr_registered_devs);
     }
     if (m_nr_expected_devs == m_nr_registered_devs) {
+        sort_opr_infos();
+        gen_infos_from_opr_infos();
         m_register_cv.notify_all();
     } else {
         m_register_cv.wait(lk,
@@ -66,6 +112,8 @@ void GroupInfo::clear() {
     m_count--;
     if (m_count == 0) {
         m_opr_infos.clear();
+        m_rank_map.clear();
+        m_root_rank = -1;
         m_nr_expected_devs = 0;
         m_nr_registered_devs = 0;
         m_output_shape.invalidate();
@@ -77,14 +125,18 @@ void GroupInfo::clear() {

 /* ================= GroupManager ================= */

-uint64_t GroupManager::opr_register(const std::string& key, size_t nr_devices,
-                                    uint32_t rank, uintptr_t stream) {
+GroupManager::RegisterInfo GroupManager::opr_register(const std::string& key,
+                                                      size_t nr_devices,
+                                                      bool is_root, int rank,
+                                                      uint64_t comp_node_hash) {
+    GroupManager::RegisterInfo ret{0, 0, 0};
     auto&& group = get_group(key);
-    group.add_opr(key, nr_devices, rank, stream);
-    auto&& opr_infos = group.opr_infos();
-    uint64_t hash = get_hash_key(opr_infos, rank);
+    group.add_opr(key, nr_devices, is_root, rank, comp_node_hash);
+    ret.rank = group.get_rank(comp_node_hash);
+    ret.root_rank = group.get_root_rank();
+    ret.hash = group.get_group_hash() + ret.rank;
     group.clear();
-    return hash;
+    return ret;
 }

 std::vector<std::string> GroupManager::gather_uid(const std::string& uid,
@@ -126,22 +178,6 @@ GroupInfo& GroupManager::get_group(const std::string& key) {
     return m_key2group_info[key];
 }

-uint64_t GroupManager::get_hash_key(const std::vector<GroupInfo::OprInfo>& _infos,
-                                    uint32_t rank) {
-    auto cmp = [](const GroupInfo::OprInfo& lhs, const GroupInfo::OprInfo& rhs) {
-        return lhs.rank < rhs.rank;
-    };
-    auto infos = _infos;
-    std::sort(infos.begin(), infos.end(), cmp);
-    auto xxhash = XXHash{};
-    for (auto&& opr_info : infos) {
-        xxhash.update(&opr_info.rank, sizeof(uint32_t))
-                .update(&opr_info.stream, sizeof(uintptr_t));
-    }
-    xxhash.update(&rank, sizeof(uint32_t));
-    return xxhash.digest();
-};
-
 uint32_t GroupManager::group_barrier(uint32_t size, uint32_t rank) {
     std::unique_lock<std::mutex> lk{m_barrier_mtx};
     if (m_barrier_set.empty()) {
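The GroupInfo code above is where ranks now come from: once all participants have registered, the opr infos are sorted by comp-node hash, ranks are assigned by position unless every caller supplied one, the root rank is taken from whichever op registered with is_root set, and a group hash is derived from the (comp_node_hash, rank) pairs. A plain-Python model of that flow; the hash below is only illustrative, not the XXHash used here:

from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class OprInfo:
    comp_node_hash: int
    is_root: bool
    rank: int  # -1 means "not assigned yet"


def gen_group_info(opr_infos: List[OprInfo]) -> Tuple[Dict[int, int], int, int]:
    # Deterministic order: sort the registered ops by their comp-node hash.
    infos = sorted(opr_infos, key=lambda i: i.comp_node_hash)
    # If any rank is still missing, assign ranks by sorted position instead.
    if any(i.rank < 0 for i in infos):
        for pos, info in enumerate(infos):
            info.rank = pos
    rank_map = {i.comp_node_hash: i.rank for i in infos}
    # The root rank comes from whichever op registered with is_root=True.
    root_rank = next((i.rank for i in infos if i.is_root), -1)
    # Stand-in for the XXHash over (comp_node_hash, rank) pairs.
    group_hash = hash(tuple((i.comp_node_hash, i.rank) for i in infos))
    return rank_map, root_rank, group_hash


if __name__ == "__main__":
    infos = [OprInfo(0xBEEF, is_root=False, rank=-1),
             OprInfo(0xCAFE, is_root=True, rank=-1)]
    rank_map, root_rank, _ = gen_group_info(infos)
    assert rank_map == {0xBEEF: 0, 0xCAFE: 1}
    assert root_rank == 1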
@@ -48,12 +48,11 @@ SymbolVar RemoteSend::make(const PeerDesc& peer, SymbolVar var,

 void RemoteSend::scn_do_execute() {
     if (!m_init) {
-        auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node())
-                                  .cuda_env();
+        auto&& comp_node = output(0)->comp_node();

         // rank 0 for RemoteSend
-        auto hash = m_group_client->opr_register(m_peer.key, 2, 0,
-                reinterpret_cast<uintptr_t>(cuda_env.stream));
+        auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false,
+                                                     comp_node.get_uid());

         auto megray_comm_builder =
                 owner_graph()
@@ -62,7 +61,7 @@ void RemoteSend::scn_do_execute() {
                         .get_user_data_or_create<MegRayCommunicatorBuilder>();

         m_megray_comm = megray_comm_builder->get_megray_comm(
-                hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
+                reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);

         m_init = true;
     }
@@ -152,12 +151,12 @@ SymbolVar RemoteRecv::make(const PeerDesc& peer, cg::ComputingGraph& graph,

 void RemoteRecv::scn_do_execute() {
     if (!m_init) {
-        auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node())
-                                  .cuda_env();
+        auto&& comp_node = output(0)->comp_node();

         // rank 1 for RemoteRecv
-        auto hash = m_group_client->opr_register(m_peer.key, 2, 1,
-                reinterpret_cast<uintptr_t>(cuda_env.stream));
+        auto reg_info = m_group_client->opr_register(
+                m_peer.key, 2, false, 1,
+                comp_node.get_uid());

         auto megray_comm_builder =
                 owner_graph()
@@ -166,7 +165,7 @@ void RemoteRecv::scn_do_execute() {
                         .get_user_data_or_create<MegRayCommunicatorBuilder>();

         m_megray_comm = megray_comm_builder->get_megray_comm(
-                hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
+                reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);

         m_init = true;
     }
@@ -68,9 +68,11 @@ private:
 void GroupServerProxy::opr_register(void* input_ptr, size_t input_len,
                                     std::string *output) {
     INFO_INIT(mm_handler, OprRegister);
-    uint64_t hash = m_mgr.opr_register(req.key(), req.nr_expected_devices(),
-                                       req.rank(), req.stream());
-    rsp.set_hash(hash);
+    auto ret = m_mgr.opr_register(req.key(), req.nr_expected_devices(),
+                                  req.is_root(), req.rank(), req.comp_node_hash());
+    rsp.set_hash(ret.hash);
+    rsp.set_rank(ret.rank);
+    rsp.set_root_rank(ret.root_rank);
     rsp.SerializeToString(output);
 }
@@ -122,11 +124,11 @@ void GroupServerProxy::group_barrier(void* input_ptr, size_t input_len,

 /* ======================== GroupClientProxy ========================== */
-#define INFO_INIT(space, f_name, name) \
+#define INFO_INIT(space, f_name, name)                 \
     using Request = space::name##Request;              \
     using Response = space::name##Response;            \
-    std::string func_name = #f_name; \
-    Request req; \
+    std::string func_name = #f_name;                   \
+    Request req;                                       \
     Response rsp;

 #define SOLVE_REQUEST(name, req, rsp) \
@@ -145,15 +147,18 @@ GroupClientProxy::GroupClientProxy(const std::string& server_addr)
           m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} {
 }

-uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices,
-                                        uint32_t rank, uintptr_t stream) {
+GroupManager::RegisterInfo GroupClientProxy::opr_register(
+        const std::string& key, size_t nr_devices, bool is_root, int rank,
+        uint64_t comp_node_hash) {
     INFO_INIT(mm_handler, opr_register, OprRegister)
     req.set_key(key);
+    req.set_is_root(is_root);
     req.set_rank(rank);
-    req.set_stream(stream);
+    req.set_comp_node_hash(comp_node_hash);
     req.set_nr_expected_devices(nr_devices);
     SOLVE_REQUEST(func_name, req, rsp);
-    return rsp.hash();
+    GroupManager::RegisterInfo ret{rsp.hash(), rsp.rank(), rsp.root_rank()};
+    return ret;
 }

 void GroupClientProxy::set_output_shape(const std::string& key,
@@ -26,18 +26,19 @@ public:
     using Param = megdnn::param::CollectiveComm;

-    CollectiveComm(VarNodeArray inputs, ComputingGraph* const graph,
-            const std::string& key, const size_t nr_devices, const uint32_t rank,
-            const uint32_t root, std::shared_ptr<GroupClient> group_client,
-            const Param& param, const DType& dtype, const std::string& backend,
-            const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
-            const OperatorNodeConfig& config,
-            const std::shared_ptr<DTypeScalar>& disable);
+    CollectiveComm(
+            VarNodeArray inputs, ComputingGraph* const graph,
+            const std::string& key, const size_t nr_devices, const bool is_root,
+            const int rank, std::shared_ptr<GroupClient> group_client,
+            const Param& param, const DType& dtype, const std::string& backend,
+            const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
+            const OperatorNodeConfig& config,
+            const std::shared_ptr<DTypeScalar>& disable);

     static SymbolVarArray make(
             const SymbolVarArray& inputs, ComputingGraph* const graph,
-            const std::string& key, const size_t nr_devices, const uint32_t rank,
-            const uint32_t root, std::shared_ptr<GroupClient> group_client,
+            const std::string& key, const size_t nr_devices, const bool is_root,
+            const int rank, std::shared_ptr<GroupClient> group_client,
             const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
             const Param& param, const DType& dtype = {},
             const std::string& backend = "nccl",
@@ -45,15 +46,16 @@ public:
             const std::shared_ptr<DTypeScalar>& disable =
                     std::make_shared<DTypeScalar>(0));

-    static SymbolVarArray make(
-            const SymbolVarArray& inputs, ComputingGraph* const graph,
-            const std::string& key, const size_t nr_devices, const uint32_t rank,
-            const uint32_t root, std::shared_ptr<GroupClient> group_client,
-            const Param& param, const DType& dtype = {},
-            const std::string& backend = "nccl",
-            const OperatorNodeConfig& config = {},
-            const std::shared_ptr<DTypeScalar>& disable =
-                    std::make_shared<DTypeScalar>(0));
+    static SymbolVarArray make(const SymbolVarArray& inputs,
+                               ComputingGraph* const graph,
+                               const std::string& key, const size_t nr_devices,
+                               const bool is_root, const int rank,
+                               std::shared_ptr<GroupClient> group_client,
+                               const Param& param, const DType& dtype = {},
+                               const std::string& backend = "nccl",
+                               const OperatorNodeConfig& config = {},
+                               const std::shared_ptr<DTypeScalar>& disable =
+                                       std::make_shared<DTypeScalar>(0));

     const Param& param() const { return m_param; }
     const DType& dtype() const { return m_dtype; }
@@ -67,9 +69,9 @@ public:
         return m_dev_buffers;
     }

-    uint32_t rank() const { return m_rank; }
-    uint32_t root() const { return m_root; }
-    bool is_root() const { return m_rank == m_root; }
+    int rank() const { return m_rank; }
+    int root() const { return m_root; }
+    bool is_root() const { return m_is_root; }

     //! The key that identifies an NCCL clique.
     //! Operators with same keys belong to the same clique.
@@ -108,12 +110,13 @@ private:
     std::shared_ptr<GroupClient> m_group_client;

     size_t m_nr_devices = 0;
-    uint32_t m_rank;
+    bool m_is_root;
+    int m_rank;
     std::string m_key;
     //! XXHash generated from m_key
     size_t m_hash;
     //! root of BROADCAST and REDUCE operation
-    uint32_t m_root;
+    int m_root;
     //! rank of root of BROADCAST and REDUCE operation
     Maybe<TensorShape> m_broadcast_output_shape = None;
     // Whether shape infer is enabled. This is only used by BROADCAST operation,
@@ -24,12 +24,13 @@ namespace opr {
 class GroupInfo {
 public:
     struct OprInfo {
-        uint32_t rank;
-        uintptr_t stream;
+        uint64_t comp_node_hash;
+        bool is_root;
+        int rank;
     };

     void add_opr(const std::string& key, size_t nr_expected_devices,
-                 uint32_t graph_id, uintptr_t stream);
+                 bool is_root, int rank, uint64_t comp_node_hash);

     void set_output_shape(const std::string& key, const TensorShape& shape);
@@ -37,15 +38,25 @@ class GroupInfo {

     void clear();

-    const std::vector<OprInfo>& opr_infos() const {return m_opr_infos; }
+    const std::vector<OprInfo>& opr_infos() const { return m_opr_infos; }
+
+    int get_root_rank() const { return m_root_rank; }
+    int get_rank(uint64_t hash) const { return m_rank_map.at(hash); }
+    uint64_t get_group_hash() const { return m_hash; }

 private:
+    void sort_opr_infos();
+    void gen_infos_from_opr_infos();
+
     std::vector<OprInfo> m_opr_infos;
+    std::unordered_map<uint64_t, int> m_rank_map;
+    uint64_t m_hash;
     uint32_t m_nr_registered_devs;
     uint32_t m_nr_expected_devs;
     Maybe<TensorShape> m_output_shape;
     uint32_t m_count = 0;
+    int m_root_rank = -1;
     std::mutex m_group_mtx;
     std::condition_variable m_register_cv;
     std::condition_variable m_clear_cv;
@@ -61,10 +72,16 @@ class GroupManager {
 public:
     ~GroupManager() = default;

+    struct RegisterInfo
+    {
+        uint64_t hash;
+        int rank, root_rank;
+    };
+
     //! register oprs' info to server, return deduplicated hash
-    uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
-                          uintptr_t stream);
+    RegisterInfo opr_register(const std::string& key, size_t nr_devices,
+                              bool is_root, int rank, uint64_t comp_node_hash);

     //! gather uids from all ranks
     std::vector<std::string> gather_uid(const std::string& uid,
             const std::string& key, uint32_t size, uint32_t rank);
@@ -80,9 +97,6 @@ class GroupManager {
 private:
     GroupInfo& get_group(const std::string& key);

-    uint64_t get_hash_key(const std::vector<GroupInfo::OprInfo>& _infos,
-                          uint32_t rank);
-
     //! key -> group info.
     std::unordered_map<std::string, GroupInfo> m_key2group_info;
@@ -112,9 +126,11 @@ class GroupClient {
     virtual ~GroupClient() = default;

 public:
-    virtual uint64_t opr_register(const std::string& key, size_t nr_devices,
-                                  uint32_t rank, uintptr_t stream) = 0;
+    virtual GroupManager::RegisterInfo opr_register(const std::string& key,
+                                                    size_t nr_devices,
+                                                    bool is_root, int rank,
+                                                    uint64_t comp_node_hash) = 0;

     virtual std::vector<std::string> gather_uid(const std::string& uid,
             const std::string& key, uint32_t size, uint32_t rank) = 0;
@@ -14,6 +14,7 @@
 #if MGB_ENABLE_OPR_MM

 #include "megbrain/opr/collective_comm.h"
+#include "megbrain/opr/group_manager.h"

 using namespace mgb;
 using namespace opr;
@@ -31,8 +32,10 @@ public:
     GroupClientProxy(const std::string& server_addr);

     //! graph registration, assign graph_id to worker.
-    uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
-                          uintptr_t stream) override;
+    GroupManager::RegisterInfo opr_register(const std::string& key,
+                                            size_t nr_devices, bool is_root,
+                                            int rank,
+                                            uint64_t comp_node_hash) override;

     std::vector<std::string> gather_uid(const std::string& uid,
             const std::string& key, uint32_t size, uint32_t rank) override;
@@ -4,13 +4,16 @@ package mm_handler;

 message OprRegisterRequest {
     string key = 1;
-    uint32 rank = 2;
-    uint64 stream = 3;
-    uint32 nr_expected_devices = 4;
+    bool is_root = 2;
+    int32 rank = 3;
+    uint64 comp_node_hash = 4;
+    uint32 nr_expected_devices = 5;
 }

 message OprRegisterResponse {
-    uint64 hash = 1;
+    uint64 hash = 1;
+    int32 rank = 2;
+    int32 root_rank = 3;
 }

 message GatherUidRequest {
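The protobuf change above replaces the old rank/stream pair in OprRegisterRequest with is_root, a possibly -1 rank, and a comp-node hash, and the response now reports the assigned rank and the root rank next to the group hash. A dict-based sketch of that round trip as GroupClientProxy drives it, with plain dicts standing in for the generated protobuf classes:

from typing import Dict, Tuple


def make_opr_register_request(key: str, nr_devices: int, is_root: bool,
                              rank: int, comp_node_hash: int) -> Dict:
    return {
        "key": key,
        "is_root": is_root,
        "rank": rank,                      # -1 lets the server choose a rank
        "comp_node_hash": comp_node_hash,  # replaces the old stream field
        "nr_expected_devices": nr_devices,
    }


def parse_opr_register_response(rsp: Dict) -> Tuple[int, int, int]:
    # The server now returns the assigned rank and the root rank
    # alongside the deduplicated group hash.
    return rsp["hash"], rsp["rank"], rsp["root_rank"]


if __name__ == "__main__":
    req = make_opr_register_request("all_reduce", 2, True, -1, 0xCAFE)
    assert req["rank"] == -1 and req["is_root"] is True
    hash_, rank, root = parse_opr_register_response(
        {"hash": 42, "rank": 0, "root_rank": 0})
    assert (hash_, rank, root) == (42, 0, 0)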
@@ -45,9 +45,11 @@ class MockGroupClient final : public opr::GroupClient {
 public:
     ~MockGroupClient() override = default;

-    uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
-                          uintptr_t stream) {
-        return m_mgr.opr_register(key, nr_devices, rank, stream);
+    opr::GroupManager::RegisterInfo opr_register(const std::string& key,
+                                                 size_t nr_devices,
+                                                 bool is_root, int rank,
+                                                 uintptr_t stream) {
+        return m_mgr.opr_register(key, nr_devices, is_root, rank, stream);
     }

     std::vector<std::string> gather_uid(const std::string& uid,
@@ -94,9 +96,9 @@ TEST(TestOprCollectiveComm, AllReduce) {
         auto x1c = opr::Copy::make(x1, cn1);

         auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_reduce",
-                2, 0, 0, client, {mode}, dtype::Float32(), "nccl")[0];
+                2, false, 0, client, {mode}, dtype::Float32(), "nccl")[0];
         auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_reduce",
-                2, 1, 0, client, {mode}, dtype::Float32(), "nccl")[0];
+                2, false, 1, client, {mode}, dtype::Float32(), "nccl")[0];
         auto y_expect = make_all_reduce_output(mode, {x0, x1});

         auto func = graph->compile({make_callback_copy(y0, host_y0),
@@ -130,7 +132,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) {
         auto graph0 = ComputingGraph::make();
         auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0);
         auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce",
-                2, 0, 0, client, {mode}, dtype::Float32(), "nccl")[0];
+                2, false, 0, client, {mode}, dtype::Float32(), "nccl")[0];
         auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
         func0->execute();
     };
@@ -139,7 +141,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) {
         auto graph1 = ComputingGraph::make();
         auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
         auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce",
-                2, 1, 0, client, {mode}, dtype::Float32(), "nccl")[0];
+                2, false, 1, client, {mode}, dtype::Float32(), "nccl")[0];
         auto func1 = graph1->compile({make_callback_copy(y1, host_y1)});
         func1->execute();
     };
@@ -192,7 +194,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) {
         graph0->options().graph_opt_level = 0;

         auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
-        auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", 2, 0, 0, client,
+        auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", 2, false, 0, client,
                 {Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0];
         y0.node()->owner_opr()->node_prop().attribute().priority = -1;
@@ -211,7 +213,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) {
         graph1->options().graph_opt_level = 0;

         auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
-        auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", 2, 1, 0, client,
+        auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", 2, false, 1, client,
                 {Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0];
         y1.node()->owner_opr()->node_prop().attribute().priority = -1;
@@ -274,9 +276,9 @@ TEST(TestOprCollectiveComm, AllGather) {
         auto x1c = opr::Copy::make(x1, cn1);

         auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_gather",
-                2, 0, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
+                2, false, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
         auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_gather",
-                2, 1, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
+                2, false, 1, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
         auto y_expect = opr::Concat::make({x0, x1}, 0);

         auto func = graph->compile({make_callback_copy(y0, host_y0),
@@ -303,7 +305,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) {
     auto run_0 = [&]() { // rank 0
         auto graph0 = ComputingGraph::make();
         auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
-        auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, 0, 0, client,
+        auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, false, 0, client,
                 {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
         auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
         func0->execute();
@@ -312,7 +314,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) {
     auto run_1 = [&]() { // rank 1
         auto graph1 = ComputingGraph::make();
         auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
-        auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, 1, 0, client,
+        auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, false, 1, client,
                 {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
         auto func1 = graph1->compile({make_callback_copy(y1, host_y1)});
         func1->execute();
@@ -361,7 +363,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) {
         graph0->options().graph_opt_level = 0;

         auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
-        auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, 0, 0, client,
+        auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, false, 0, client,
                 {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
         y0.node()->owner_opr()->node_prop().attribute().priority = -1;
@@ -380,7 +382,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) {
         graph1->options().graph_opt_level = 0;

         auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, 1, 0, client, | |||||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, false, 1, client, | |||||
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; | {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; | ||||
y1.node()->owner_opr()->node_prop().attribute().priority = -1; | y1.node()->owner_opr()->node_prop().attribute().priority = -1; | ||||
@@ -444,9 +446,9 @@ TEST(TestOprCollectiveComm, ReduceScatterSum) { | |||||
auto x1c = opr::Copy::make(x1, cn1); | auto x1c = opr::Copy::make(x1, cn1); | ||||
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_scatter_sum", | auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_scatter_sum", | ||||
2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||||
2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||||
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_scatter_sum", | auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_scatter_sum", | ||||
2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||||
2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||||
auto y_expect = make_reduce_scatter_sum_output({x0, x1}); | auto y_expect = make_reduce_scatter_sum_output({x0, x1}); | ||||
auto func = graph->compile({make_callback_copy(y0, host_y0), | auto func = graph->compile({make_callback_copy(y0, host_y0), | ||||
@@ -475,7 +477,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) { | |||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | ||||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum", | auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum", | ||||
2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||||
2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||||
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); | auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); | ||||
func0->execute(); | func0->execute(); | ||||
}; | }; | ||||
@@ -484,7 +486,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) { | |||||
auto graph1 = ComputingGraph::make(); | auto graph1 = ComputingGraph::make(); | ||||
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | ||||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum", | auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum", | ||||
2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||||
2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||||
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); | auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); | ||||
func1->execute(); | func1->execute(); | ||||
}; | }; | ||||
@@ -534,7 +536,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { | |||||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | ||||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum", | auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum", | ||||
2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||||
2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||||
y0.node()->owner_opr()->node_prop().attribute().priority = -1; | y0.node()->owner_opr()->node_prop().attribute().priority = -1; | ||||
auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0); | auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0); | ||||
@@ -553,7 +555,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { | |||||
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | ||||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum", | auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum", | ||||
2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||||
2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; | |||||
y1.node()->owner_opr()->node_prop().attribute().priority = -1; | y1.node()->owner_opr()->node_prop().attribute().priority = -1; | ||||
auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); | auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); | ||||
@@ -616,9 +618,9 @@ TEST(TestOprCollectiveComm, ReduceSum) { | |||||
auto x1c = opr::Copy::make(x1, cn1); | auto x1c = opr::Copy::make(x1, cn1); | ||||
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_sum", | auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_sum", | ||||
2, 0, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||||
2, true, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||||
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_sum", | auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_sum", | ||||
2, 1, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||||
2, false, 1, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | |||||
auto y_expect = x0 + x1; | auto y_expect = x0 + x1; | ||||
auto func = graph->compile({make_callback_copy(y0, host_y0), | auto func = graph->compile({make_callback_copy(y0, host_y0), | ||||
@@ -644,7 +646,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) { | |||||
auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | ||||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, 0, 0, client, | |||||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, true, 0, client, | |||||
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | ||||
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); | auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); | ||||
func0->execute(); | func0->execute(); | ||||
@@ -653,7 +655,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) { | |||||
auto run_1 = [&]() { // rank 1 | auto run_1 = [&]() { // rank 1 | ||||
auto graph1 = ComputingGraph::make(); | auto graph1 = ComputingGraph::make(); | ||||
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | ||||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, 1, 0, client, | |||||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, false, 1, client, | |||||
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | ||||
auto func1 = graph1->compile({{y1, nullptr}}); | auto func1 = graph1->compile({{y1, nullptr}}); | ||||
func1->execute(); | func1->execute(); | ||||
@@ -699,7 +701,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { | |||||
graph0->options().graph_opt_level = 0; | graph0->options().graph_opt_level = 0; | ||||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | ||||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, 0, 0, client, | |||||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, true, 0, client, | |||||
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | ||||
y0.node()->owner_opr()->node_prop().attribute().priority = -1; | y0.node()->owner_opr()->node_prop().attribute().priority = -1; | ||||
@@ -718,7 +720,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { | |||||
graph1->options().graph_opt_level = 0; | graph1->options().graph_opt_level = 0; | ||||
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); | ||||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, 1, 0, client, | |||||
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, false, 1, client, | |||||
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; | ||||
y1.node()->owner_opr()->node_prop().attribute().priority = -1; | y1.node()->owner_opr()->node_prop().attribute().priority = -1; | ||||
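REDUCE_SUM is a rooted collective, so in the reduce_sum hunks above only rank 0 flips the new flag to true while rank 1 passes false; under the old signature both sides expressed this through the separate root argument (the trailing 0). A condensed sketch of the non-root side, taken from the multithread test above, where the output on rank 1 is discarded by compiling with a null callback (comment names are illustrative):

    // Rank 1 is not the root: it passes false and does not read back the
    // reduced result (compiled with a null callback, as in the hunk above).
    auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce",
            /*nr_devices=*/2, /*is_root=*/false, /*rank=*/1,
            client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
    auto func1 = graph1->compile({{y1, nullptr}});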
@@ -767,12 +769,12 @@ TEST(TestOprCollectiveComm, Broadcast) { | |||||
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | ||||
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "broadcast", | auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "broadcast", | ||||
2, 0, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; | |||||
2, true, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; | |||||
auto y_dev = std::make_shared<DeviceTensorND>(DeviceTensorND() | auto y_dev = std::make_shared<DeviceTensorND>(DeviceTensorND() | ||||
.comp_node(cn1) | .comp_node(cn1) | ||||
.dtype(dtype::Float32()) | .dtype(dtype::Float32()) | ||||
.resize(host_x0->shape())); | .resize(host_x0->shape())); | ||||
auto y1 = opr::CollectiveComm::make({}, graph.get(), "broadcast", 2, 1, 0, | |||||
auto y1 = opr::CollectiveComm::make({}, graph.get(), "broadcast", 2, false, 1, | |||||
client, {y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; | client, {y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; | ||||
auto func = graph->compile({make_callback_copy(y0, host_y0), | auto func = graph->compile({make_callback_copy(y0, host_y0), | ||||
@@ -797,7 +799,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) { | |||||
auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | ||||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, 0, 0, client, | |||||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, true, 0, client, | |||||
{Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; | {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; | ||||
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); | auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); | ||||
func0->execute(); | func0->execute(); | ||||
@@ -809,7 +811,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) { | |||||
.comp_node(cn1) | .comp_node(cn1) | ||||
.dtype(dtype::Float32()) | .dtype(dtype::Float32()) | ||||
.resize(host_x0->shape())); | .resize(host_x0->shape())); | ||||
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, 1, 0, client, | |||||
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, false, 1, client, | |||||
{y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; | {y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; | ||||
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); | auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); | ||||
func1->execute(); | func1->execute(); | ||||
@@ -845,7 +847,7 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) { | |||||
graph0->options().graph_opt_level = 0; | graph0->options().graph_opt_level = 0; | ||||
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); | ||||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, 0, 0, client, | |||||
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, true, 0, client, | |||||
{Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; | {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; | ||||
y0.node()->owner_opr()->node_prop().attribute().priority = -1; | y0.node()->owner_opr()->node_prop().attribute().priority = -1; | ||||
@@ -863,11 +865,11 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) { | |||||
auto graph1 = ComputingGraph::make(); | auto graph1 = ComputingGraph::make(); | ||||
graph1->options().graph_opt_level = 0; | graph1->options().graph_opt_level = 0; | ||||
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, 1, 0, client, | |||||
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, false, 1, client, | |||||
{Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; | {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; | ||||
auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); | auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); | ||||
auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "broadcast:grad", 2, 1, 0, client, | |||||
auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "broadcast:grad", 2, false, 1, client, | |||||
Mode::REDUCE_SUM, dtype::Float32(), "nccl")[0]; | Mode::REDUCE_SUM, dtype::Float32(), "nccl")[0]; | ||||
g.node()->owner_opr()->node_prop().attribute().priority = 1; | g.node()->owner_opr()->node_prop().attribute().priority = 1; | ||||
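BROADCAST is the mirror case: the root (rank 0) supplies the input tensor and passes true, while the non-root rank passes an empty input list, a preallocated output DeviceTensorND, and an explicit comp node, as the broadcast hunks above show. A condensed sketch of the non-root side under the new signature, reusing the test's cn1, host_x0, and client fixtures (comment names are illustrative):

    // Non-root side of BROADCAST: no input var, is_root=false, rank=1, plus a
    // preallocated device tensor on cn1 that receives the broadcast result.
    auto y_dev = std::make_shared<DeviceTensorND>(DeviceTensorND()
            .comp_node(cn1)
            .dtype(dtype::Float32())
            .resize(host_x0->shape()));
    auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast",
            /*nr_devices=*/2, /*is_root=*/false, /*rank=*/1,
            client, {y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0];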
@@ -26,11 +26,13 @@ class MockGroupClient final : public opr::GroupClient { | |||||
public: | public: | ||||
~MockGroupClient() override = default; | ~MockGroupClient() override = default; | ||||
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, | |||||
uintptr_t stream) { | |||||
return m_mgr.opr_register(key, nr_devices, rank, stream); | |||||
opr::GroupManager::RegisterInfo opr_register(const std::string& key, | |||||
size_t nr_devices, | |||||
bool is_root, int rank, | |||||
uint64_t comp_node_hash) { | |||||
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | |||||
} | } | ||||
std::vector<std::string> gather_uid(const std::string& uid, | std::vector<std::string> gather_uid(const std::string& uid, | ||||
const std::string& key, uint32_t size, uint32_t rank) { | const std::string& key, uint32_t size, uint32_t rank) { | ||||
return m_mgr.gather_uid(uid, key, size, rank); | return m_mgr.gather_uid(uid, key, size, rank); | ||||