
test_logger.py
import os
import tempfile
import datetime
from pathlib import Path
import logging
import re

from fastNLP.core.envs.env import FASTNLP_LAUNCH_TIME
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm


# Tests for TorchDDPDriver;
@magic_argv_env_context
def test_add_file_ddp_1():
    """
    Test the case where path points to a file whose parent directory exists;

    Deriving log filenames from the current time is badly broken for multi-GPU
    launches: the processes start at slightly different moments, so each one
    ends up writing to its own separate log file;
    """
    import torch
    import torch.distributed as dist

    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)

    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    path = Path.cwd()
    filepath = path.joinpath('log.txt')
    handler = logger.add_file(filepath, mode="w")
    logger.info(msg)
    logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")
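    # Flush every FileHandler so each rank's records are on disk, then
    # synchronize so no rank reads the shared file before all ranks have
    # finished writing to it.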
    for h in logger.handlers:
        if isinstance(h, logging.FileHandler):
            h.flush()
    dist.barrier()
    with open(filepath, 'r') as f:
        line = ''.join([l for l in f])

    assert msg in line
    assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line

    pattern = re.compile(msg)
    assert len(pattern.findall(line)) == 1

    synchronize_safe_rm(filepath)
    dist.barrier()
    dist.destroy_process_group()
    logger.removeHandler(handler)


@magic_argv_env_context
def test_add_file_ddp_2():
    """
    Test the case where path points to a file whose parent directory does not exist;
    """
    import torch
    import torch.distributed as dist

    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)

    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    origin_path = Path.cwd()
    try:
        path = origin_path.joinpath("not_existed")
        filepath = path.joinpath('log.txt')
        handler = logger.add_file(filepath)
        logger.info(msg)
        logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")
        for h in logger.handlers:
            if isinstance(h, logging.FileHandler):
                h.flush()
        dist.barrier()
        with open(filepath, 'r') as f:
            line = ''.join([l for l in f])

        assert msg in line
        assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line
        pattern = re.compile(msg)
        assert len(pattern.findall(line)) == 1
    finally:
        synchronize_safe_rm(path)
        logger.removeHandler(handler)

    dist.barrier()
    dist.destroy_process_group()


@magic_argv_env_context
def test_add_file_ddp_3():
    """
    Test the case where path is None;

    Deriving log filenames from the current time is badly broken for multi-GPU
    launches: the processes start at slightly different moments, so each one
    ends up writing to its own separate log file;
    """
    import torch
    import torch.distributed as dist

    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)

    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    handler = logger.add_file()
    logger.info(msg)
    logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")
    for h in logger.handlers:
        if isinstance(h, logging.FileHandler):
            h.flush()
    dist.barrier()

    file = Path.cwd().joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + ".log")
    with open(file, 'r') as f:
        line = ''.join([l for l in f])
    # print(f"\nrank: {driver.get_local_rank()} line, {line}\n")
    assert msg in line
    assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line
    pattern = re.compile(msg)
    assert len(pattern.findall(line)) == 1

    synchronize_safe_rm(file)
    dist.barrier()
    dist.destroy_process_group()
    logger.removeHandler(handler)


@magic_argv_env_context
def test_add_file_ddp_4():
    """
    Test the case where path is a directory;
    """
    import torch
    import torch.distributed as dist

    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)

    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    path = Path.cwd().joinpath("not_existed")
    try:
        handler = logger.add_file(path)
        logger.info(msg)
        logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")

        for h in logger.handlers:
            if isinstance(h, logging.FileHandler):
                h.flush()
        dist.barrier()

        file = path.joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + ".log")
        with open(file, 'r') as f:
            line = ''.join([l for l in f])
        assert msg in line
        assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line
        pattern = re.compile(msg)
        assert len(pattern.findall(line)) == 1
    finally:
        synchronize_safe_rm(path)
        logger.removeHandler(handler)

    dist.barrier()
    dist.destroy_process_group()


class TestLogger:
    msg = 'some test log msg'

    def test_add_file_1(self):
        """
        Test the case where path points to a file whose parent directory exists;
        """
        from fastNLP.core.log.logger import logger

        path = Path(tempfile.mkdtemp())
        try:
            filepath = path.joinpath('log.txt')
            handler = logger.add_file(filepath)
            logger.info(self.msg)
            with open(filepath, 'r') as f:
                line = ''.join([l for l in f])
            assert self.msg in line
        finally:
            synchronize_safe_rm(path)
            logger.removeHandler(handler)

    def test_add_file_2(self):
        """
        Test the case where path points to a file whose parent directory does not exist;
        """
        from fastNLP.core.log.logger import logger

        origin_path = Path(tempfile.mkdtemp())
        try:
            path = origin_path.joinpath("not_existed")
            path = path.joinpath('log.txt')
            handler = logger.add_file(path)
            logger.info(self.msg)
            with open(path, 'r') as f:
                line = ''.join([l for l in f])
            assert self.msg in line
        finally:
            synchronize_safe_rm(origin_path)
            logger.removeHandler(handler)

    def test_add_file_3(self):
        """
        Test the case where path is None;
        """
        from fastNLP.core.log.logger import logger

        handler = logger.add_file()
        logger.info(self.msg)

        path = Path.cwd()
        cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d'))
        for file in path.iterdir():
            if file.name.startswith(cur_datetime):
                with open(file, 'r') as f:
                    line = ''.join([l for l in f])
                assert self.msg in line
                file.unlink()

        logger.removeHandler(handler)

    def test_add_file_4(self):
        """
        Test the case where path is a directory;
        """
        from fastNLP.core.log.logger import logger

        path = Path(tempfile.mkdtemp())
        try:
            handler = logger.add_file(path)
            logger.info(self.msg)

            cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d'))
            for file in path.iterdir():
                if file.name.startswith(cur_datetime):
                    with open(file, 'r') as f:
                        line = ''.join([l for l in f])
                    assert self.msg in line
        finally:
            synchronize_safe_rm(path)
            logger.removeHandler(handler)

    def test_stdout(self, capsys):
        from fastNLP.core.log.logger import logger

        handler = logger.set_stdout(stdout="raw")
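        # With stdout="raw" the info record reaches stdout as the bare
        # message (per the equality assert below); the debug record is
        # expected to be filtered out by the logger's level, so it never
        # appears in captured.out.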
        logger.info(self.msg)
        logger.debug('aabbc')
        captured = capsys.readouterr()
        assert "some test log msg\n" == captured.out

        logger.removeHandler(handler)
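

# A sketch of how this file might be run (the path is an assumption from the
# repository layout, and the *_ddp_* tests need at least two CUDA devices):
#
#     pytest tests/core/log/test_logger.py -s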