You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import os
  2. import sys
  3. import __main__
  4. from functools import wraps
  5. import inspect
  6. from inspect import ismethod
  7. import functools
  8. from copy import deepcopy
  9. from io import StringIO
  10. import time
  11. import numpy as np
  12. from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
  13. from fastNLP.core.drivers.utils import distributed_open_proc
  14. def get_class_that_defined_method(meth):
  15. if isinstance(meth, functools.partial):
  16. return get_class_that_defined_method(meth.func)
  17. if inspect.ismethod(meth) or (inspect.isbuiltin(meth) and getattr(meth, '__self__', None) is not None and getattr(meth.__self__, '__class__', None)):
  18. for cls in inspect.getmro(meth.__self__.__class__):
  19. if meth.__name__ in cls.__dict__:
  20. return cls
  21. meth = getattr(meth, '__func__', meth) # fallback to __qualname__ parsing
  22. if inspect.isfunction(meth):
  23. cls = getattr(inspect.getmodule(meth),
  24. meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0],
  25. None)
  26. if isinstance(cls, type):
  27. return cls
  28. return getattr(meth, '__objclass__', None) # handle special descriptor objects
  29. def magic_argv_env_context(fn):
  30. @wraps(fn)
  31. def wrapper(*args, **kwargs):
  32. command = deepcopy(sys.argv)
  33. env = deepcopy(os.environ.copy())
  34. used_args = []
  35. for each_arg in sys.argv[1:]:
  36. if "test" not in each_arg:
  37. used_args.append(each_arg)
  38. pytest_current_test = os.environ.get('PYTEST_CURRENT_TEST')
  39. try:
  40. l_index = pytest_current_test.index("[")
  41. r_index = pytest_current_test.index("]")
  42. subtest = pytest_current_test[l_index: r_index + 1]
  43. except:
  44. subtest = ""
  45. if not ismethod(fn) and get_class_that_defined_method(fn) is None:
  46. sys.argv = [sys.argv[0], f"{os.path.abspath(sys.modules[fn.__module__].__file__)}::{fn.__name__}{subtest}"] + used_args
  47. else:
  48. sys.argv = [sys.argv[0], f"{os.path.abspath(sys.modules[fn.__module__].__file__)}::{get_class_that_defined_method(fn).__name__}::{fn.__name__}{subtest}"] + used_args
  49. res = fn(*args, **kwargs)
  50. sys.argv = deepcopy(command)
  51. os.environ = env
  52. return res
  53. return wrapper
  54. class Capturing(list):
  55. # 用来捕获当前环境中的stdout和stderr,会将其中stderr的输出拼接在stdout的输出后面
  56. """
  57. 使用例子
  58. with Capturing() as output:
  59. do_something
  60. assert 'xxx' in output[0]
  61. """
  62. def __init__(self, no_del=False):
  63. # 如果no_del为True,则不会删除_stringio,和_stringioerr
  64. super().__init__()
  65. self.no_del = no_del
  66. def __enter__(self):
  67. self._stdout = sys.stdout
  68. self._stderr = sys.stderr
  69. sys.stdout = self._stringio = StringIO()
  70. sys.stderr = self._stringioerr = StringIO()
  71. return self
  72. def __exit__(self, *args):
  73. self.append(self._stringio.getvalue() + self._stringioerr.getvalue())
  74. if not self.no_del:
  75. del self._stringio, self._stringioerr # free up some memory
  76. sys.stdout = self._stdout
  77. sys.stderr = self._stderr
  78. def re_run_current_cmd_for_torch(num_procs, output_from_new_proc='ignore'):
  79. # Script called as `python a/b/c.py`
  80. if int(os.environ.get('LOCAL_RANK', '0')) == 0:
  81. if __main__.__spec__ is None: # pragma: no-cover
  82. # pull out the commands used to run the script and resolve the abs file path
  83. command = sys.argv
  84. command[0] = os.path.abspath(command[0])
  85. # use the same python interpreter and actually running
  86. command = [sys.executable] + command
  87. # Script called as `python -m a.b.c`
  88. else:
  89. command = [sys.executable, "-m", __main__.__spec__._name] + sys.argv[1:]
  90. for rank in range(1, num_procs+1):
  91. env_copy = os.environ.copy()
  92. env_copy["LOCAL_RANK"] = f"{rank}"
  93. env_copy['WOLRD_SIZE'] = f'{num_procs+1}'
  94. env_copy['RANK'] = f'{rank}'
  95. # 如果是多机,一定需要用户自己拉起,因此我们自己使用 open_subprocesses 开启的进程的 FASTNLP_GLOBAL_RANK 一定是 LOCAL_RANK;
  96. env_copy[FASTNLP_GLOBAL_RANK] = str(rank)
  97. proc = distributed_open_proc(output_from_new_proc, command, env_copy, None)
  98. delay = np.random.uniform(1, 5, 1)[0]
  99. time.sleep(delay)