You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 7.0 kB

3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import os
  2. import sys
  3. import __main__
  4. from functools import wraps, partial
  5. from inspect import ismethod
  6. from copy import deepcopy
  7. from io import StringIO
  8. import time
  9. import signal
  10. import pytest
  11. import numpy as np
  12. from fastNLP.core.utils.utils import get_class_that_defined_method
  13. from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
  14. from fastNLP.core.drivers.utils import distributed_open_proc
  15. from fastNLP.core.log import logger
  16. def recover_logger(fn):
  17. @wraps(fn)
  18. def wrapper(*args, **kwargs):
  19. # 保存logger的状态
  20. handlers = [handler for handler in logger.handlers]
  21. level = logger.level
  22. res = fn(*args, **kwargs)
  23. logger.handlers = handlers
  24. logger.setLevel(level)
  25. return res
  26. return wrapper
  27. def magic_argv_env_context(fn=None, timeout=300):
  28. """
  29. 用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确;
  30. 会丢掉 pytest 中的 arg 参数。
  31. :param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 5 分钟,单位为秒;
  32. :return:
  33. """
  34. # 说明是通过 @magic_argv_env_context(timeout=600) 调用;
  35. if fn is None:
  36. return partial(magic_argv_env_context, timeout=timeout)
  37. @wraps(fn)
  38. def wrapper(*args, **kwargs):
  39. command = deepcopy(sys.argv)
  40. env = deepcopy(os.environ.copy())
  41. used_args = []
  42. # for each_arg in sys.argv[1:]:
  43. # # warning,否则 可能导致 pytest -s . 中的点混入其中,导致多卡启动的 collect tests items 不为 1
  44. # if each_arg.startswith('-'):
  45. # used_args.append(each_arg)
  46. pytest_current_test = os.environ.get('PYTEST_CURRENT_TEST')
  47. try:
  48. l_index = pytest_current_test.index("[")
  49. r_index = pytest_current_test.index("]")
  50. subtest = pytest_current_test[l_index: r_index + 1]
  51. except:
  52. subtest = ""
  53. if not ismethod(fn) and get_class_that_defined_method(fn) is None:
  54. sys.argv = [sys.argv[0], f"{os.path.abspath(sys.modules[fn.__module__].__file__)}::{fn.__name__}{subtest}"] + used_args
  55. else:
  56. sys.argv = [sys.argv[0], f"{os.path.abspath(sys.modules[fn.__module__].__file__)}::{get_class_that_defined_method(fn).__name__}::{fn.__name__}{subtest}"] + used_args
  57. def _handle_timeout(signum, frame):
  58. raise TimeoutError(f"\nYour test fn: {fn.__name__} has timed out.\n")
  59. # 恢复 logger
  60. handlers = [handler for handler in logger.handlers]
  61. formatters = [handler.formatter for handler in handlers]
  62. level = logger.level
  63. signal.signal(signal.SIGALRM, _handle_timeout)
  64. signal.alarm(timeout)
  65. res = fn(*args, **kwargs)
  66. signal.alarm(0)
  67. sys.argv = deepcopy(command)
  68. os.environ = env
  69. for formatter, handler in zip(formatters, handlers):
  70. handler.setFormatter(formatter)
  71. logger.handlers = handlers
  72. logger.setLevel(level)
  73. return res
  74. return wrapper
  75. class Capturing(list):
  76. # 用来捕获当前环境中的stdout和stderr,会将其中stderr的输出拼接在stdout的输出后面
  77. """
  78. 使用例子
  79. with Capturing() as output:
  80. do_something
  81. assert 'xxx' in output[0]
  82. """
  83. def __init__(self, no_del=False):
  84. # 如果no_del为True,则不会删除_stringio,和_stringioerr
  85. super().__init__()
  86. self.no_del = no_del
  87. def __enter__(self):
  88. self._stdout = sys.stdout
  89. self._stderr = sys.stderr
  90. sys.stdout = self._stringio = StringIO()
  91. sys.stderr = self._stringioerr = StringIO()
  92. return self
  93. def __exit__(self, *args):
  94. self.append(self._stringio.getvalue() + self._stringioerr.getvalue())
  95. if not self.no_del:
  96. del self._stringio, self._stringioerr # free up some memory
  97. sys.stdout = self._stdout
  98. sys.stderr = self._stderr
  99. def re_run_current_cmd_for_torch(num_procs, output_from_new_proc='ignore'):
  100. # Script called as `python a/b/c.py`
  101. if int(os.environ.get('LOCAL_RANK', '0')) == 0:
  102. if __main__.__spec__ is None: # pragma: no-cover
  103. # pull out the commands used to run the script and resolve the abs file path
  104. command = sys.argv
  105. command[0] = os.path.abspath(command[0])
  106. # use the same python interpreter and actually running
  107. command = [sys.executable] + command
  108. # Script called as `python -m a.b.c`
  109. else:
  110. command = [sys.executable, "-m", __main__.__spec__._name] + sys.argv[1:]
  111. for rank in range(1, num_procs+1):
  112. env_copy = os.environ.copy()
  113. env_copy["LOCAL_RANK"] = f"{rank}"
  114. env_copy['WOLRD_SIZE'] = f'{num_procs+1}'
  115. env_copy['RANK'] = f'{rank}'
  116. # 如果是多机,一定需要用户自己拉起,因此我们自己使用 open_subprocesses 开启的进程的 FASTNLP_GLOBAL_RANK 一定是 LOCAL_RANK;
  117. env_copy[FASTNLP_GLOBAL_RANK] = str(rank)
  118. proc = distributed_open_proc(output_from_new_proc, command, env_copy, None)
  119. delay = np.random.uniform(1, 5, 1)[0]
  120. time.sleep(delay)
  121. def re_run_current_cmd_for_oneflow(num_procs, output_from_new_proc='ignore'):
  122. # 实际上逻辑和 torch 一样,只是为了区分不同框架所以独立出来
  123. # Script called as `python a/b/c.py`
  124. if int(os.environ.get('LOCAL_RANK', '0')) == 0:
  125. if __main__.__spec__ is None: # pragma: no-cover
  126. # pull out the commands used to run the script and resolve the abs file path
  127. command = sys.argv
  128. command[0] = os.path.abspath(command[0])
  129. # use the same python interpreter and actually running
  130. command = [sys.executable] + command
  131. # Script called as `python -m a.b.c`
  132. else:
  133. command = [sys.executable, "-m", __main__.__spec__._name] + sys.argv[1:]
  134. for rank in range(1, num_procs+1):
  135. env_copy = os.environ.copy()
  136. env_copy["LOCAL_RANK"] = f"{rank}"
  137. env_copy['WOLRD_SIZE'] = f'{num_procs+1}'
  138. env_copy['RANK'] = f'{rank}'
  139. env_copy["GLOG_log_dir"] = os.path.join(
  140. os.getcwd(), f"oneflow_rank_{rank}"
  141. )
  142. os.makedirs(env_copy["GLOG_log_dir"], exist_ok=True)
  143. # 如果是多机,一定需要用户自己拉起,因此我们自己使用 open_subprocesses 开启的进程的 FASTNLP_GLOBAL_RANK 一定是 LOCAL_RANK;
  144. env_copy[FASTNLP_GLOBAL_RANK] = str(rank)
  145. proc = distributed_open_proc(output_from_new_proc, command, env_copy, rank)
  146. delay = np.random.uniform(1, 5, 1)[0]
  147. time.sleep(delay)
  148. def run_pytest(argv):
  149. cmd = argv[0]
  150. for i in range(1, len(argv)):
  151. cmd += "::" + argv[i]
  152. pytest.main([cmd])