
util.py 4.5 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import socket
from typing import Callable, List, Optional

import megengine._internal as mgb

from ..core import set_default_device

_master_ip = None
_master_port = 0
_world_size = 0
_rank = 0
_backend = None
_group_id = 0


def init_process_group(
    master_ip: str,
    master_port: int,
    world_size: int,
    rank: int,
    dev: int,
    backend: Optional[str] = "nccl",
) -> None:
    """Initialize the distributed process group and specify the device used in the current process.

    :param master_ip: IP address of the master node.
    :param master_port: port available for all processes to communicate.
    :param world_size: total number of processes participating in the job.
    :param rank: rank of the current process.
    :param dev: the GPU device id to bind this process to.
    :param backend: communicator backend; currently 'nccl' and 'ucx' are supported.
    """
    global _master_ip  # pylint: disable=global-statement
    global _master_port  # pylint: disable=global-statement
    global _world_size  # pylint: disable=global-statement
    global _rank  # pylint: disable=global-statement
    global _backend  # pylint: disable=global-statement
    global _group_id  # pylint: disable=global-statement

    if not isinstance(master_ip, str):
        raise TypeError("Expect type str but got {}".format(type(master_ip)))
    if not isinstance(master_port, int):
        raise TypeError("Expect type int but got {}".format(type(master_port)))
    if not isinstance(world_size, int):
        raise TypeError("Expect type int but got {}".format(type(world_size)))
    if not isinstance(rank, int):
        raise TypeError("Expect type int but got {}".format(type(rank)))
    if not isinstance(backend, str):
        raise TypeError("Expect type str but got {}".format(type(backend)))

    _master_ip = master_ip
    _master_port = master_port
    _world_size = world_size
    _rank = rank
    _backend = backend
    _group_id = 0

    set_default_device(mgb.comp_node("gpu" + str(dev)))

    if rank == 0:
        # Rank 0 hosts the message-passing server; create_mm_server returns
        # the port actually bound, or -1 on failure.
        _master_port = mgb.config.create_mm_server("0.0.0.0", master_port)
        if _master_port == -1:
            raise Exception("Failed to start server on port {}".format(master_port))
    else:
        assert master_port > 0, "master_port must be specified for non-zero rank"


def is_distributed() -> bool:
    """Return True if the distributed process group has been initialized."""
    return _world_size is not None and _world_size > 1


def get_master_ip() -> str:
    """Get the IP address of the master node."""
    return str(_master_ip)


def get_master_port() -> int:
    """Get the port of the rpc server on the master node."""
    return _master_port


def get_world_size() -> int:
    """Get the total number of processes participating in the job."""
    return _world_size


def get_rank() -> int:
    """Get the rank of the current process."""
    return _rank


def get_backend() -> str:
    """Get the backend string."""
    return str(_backend)


def get_group_id() -> int:
    """Get a fresh group id for collective communication."""
    global _group_id  # pylint: disable=global-statement
    _group_id += 1
    return _group_id


def group_barrier() -> None:
    """Block until all ranks in the group reach this barrier."""
    mgb.config.group_barrier(_master_ip, _master_port, _world_size, _rank)


def synchronized(func: Callable):
    """Decorator. The decorated function synchronizes all ranks when it finishes.

    Specifically, this is used to prevent data races during hub.load.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_distributed():
            return func(*args, **kwargs)

        ret = func(*args, **kwargs)
        group_barrier()
        return ret

    return wrapper


def get_free_ports(num: int) -> List[int]:
    """Get one or more free ports."""
    socks, ports = [], []
    for i in range(num):
        # Binding to port 0 lets the OS pick a free port; keep the socket
        # open until all ports are collected so none is handed out twice.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.bind(("", 0))
        socks.append(sock)
        ports.append(sock.getsockname()[1])
    for sock in socks:
        sock.close()
    return ports
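Taken together, these helpers cover the typical single-machine launch flow: pick a free port, fork one process per GPU, and have each process join the group. Below is a minimal sketch of that flow, assuming the utilities above are re-exported under megengine.distributed and that one GPU is available per rank; the worker body and the save_state helper are hypothetical, for illustration only.

import multiprocessing as mp

import megengine.distributed as dist


@dist.synchronized
def save_state(rank):
    # Hypothetical helper: because of @synchronized, every rank blocks on
    # group_barrier() after this returns, so no rank races ahead.
    print("rank {} saved".format(rank))


def worker(master_ip, port, world_size, rank):
    # Bind this process to GPU `rank` and join the group; rank 0 also
    # starts the message-passing server on `port`.
    dist.init_process_group(master_ip, port, world_size, rank, dev=rank)
    assert dist.is_distributed() and dist.get_rank() == rank
    # ... build and train a model here ...
    save_state(rank)


if __name__ == "__main__":
    world_size = 2
    port = dist.get_free_ports(1)[0]  # rank 0 re-binds this port for its server
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=("localhost", port, world_size, rank))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

Note that the same port must be handed to every rank before the processes start: rank 0 uses it to create the server, while the other ranks only assert that it is positive and use it to reach rank 0.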

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.