|
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- from typing import Optional, Union
-
- import megengine._internal as mgb
- from megengine._internal.opr_param_defs import CollectiveComm as CollParam
-
- from .util import get_backend, get_master_ip, get_master_port, get_rank, get_world_size
-
-
- def collective_comm_symvar(
- inp: Union[mgb.SymbolVar, mgb.CompGraph],
- key: str,
- op: CollParam.Mode,
- nr_ranks: Optional[int] = None,
- is_root: Optional[bool] = None,
- rank: Optional[int] = None,
- dtype: Optional[type] = None,
- device: Optional[mgb.CompNode] = None,
- comp_graph: Optional[mgb.CompGraph] = None,
- ) -> mgb.SymbolVar:
- """Helper function for creating collective_comm operators
-
- :param inp: tensor or comp_graph
- :param key: unique identifier for collective communication
- :param op: mode of collective communication
- :param nr_ranks: number of ranks, use util.get_world_size() as default
- :param is_root: whether this node is root node
- :param dtype: output data type, use dtype of inp as default
- :param device: output comp node, use comp node of inp as default
- :param comp_graph: output comp graph, use comp graph of inp as default
- """
- return mgb.opr.collective_comm(
- inp,
- key=str(key),
- nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
- is_root=is_root if is_root is not None else (get_rank() == 0),
- rank=rank if rank is not None else -1,
- server_addr=get_master_ip(),
- port=get_master_port(),
- param=CollParam(mode=op),
- dtype=dtype,
- backend=get_backend(),
- comp_node=device,
- comp_graph=comp_graph,
- )
|