You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cache_results.py 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. import os
  2. import pytest
  3. import subprocess
  4. from io import StringIO
  5. import sys
  6. from fastNLP.core.utils.cache_results import cache_results
  7. from fastNLP.core import rank_zero_rm
  def get_subprocess_results(cmd):
      """Run the shell command string *cmd* and return its stdout as UTF-8 text.

      Only stdout is captured; stderr is inherited from the calling process.

      Raises:
          subprocess.CalledProcessError: if the command exits with a non-zero status.
      """
      completed = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, check=True)
      raw_bytes = completed.stdout
      return raw_bytes.decode('utf8')
  class Capturing(list):
      """Context manager that captures everything written to stdout and stderr.

      On exit, one string is appended to this list: the captured stdout text
      followed by the captured stderr text. The original streams are restored.
      """

      def __enter__(self):
          # Keep the real streams so __exit__ can put them back.
          self._saved_out, self._saved_err = sys.stdout, sys.stderr
          self._out_buf, self._err_buf = StringIO(), StringIO()
          sys.stdout, sys.stderr = self._out_buf, self._err_buf
          return self

      def __exit__(self, *exc_info):
          combined = self._out_buf.getvalue() + self._err_buf.getvalue()
          self.append(combined)
          # The text now lives in the list; release the buffers.
          del self._out_buf, self._err_buf
          sys.stdout, sys.stderr = self._saved_out, self._saved_err
  24. class TestCacheResults:
  25. def test_cache_save(self):
  26. cache_fp = 'demo.pkl'
  27. try:
  28. @cache_results(cache_fp)
  29. def demo():
  30. print("¥")
  31. return 1
  32. res = demo()
  33. with Capturing() as output:
  34. res = demo()
  35. assert '¥' not in output[0]
  36. finally:
  37. rank_zero_rm(cache_fp)
  38. def test_cache_save_refresh(self):
  39. cache_fp = 'demo.pkl'
  40. try:
  41. @cache_results(cache_fp, _refresh=True)
  42. def demo():
  43. print("¥")
  44. return 1
  45. res = demo()
  46. with Capturing() as output:
  47. res = demo()
  48. assert '¥' in output[0]
  49. finally:
  50. rank_zero_rm(cache_fp)
  51. def test_cache_no_func_change(self):
  52. cache_fp = os.path.abspath('demo.pkl')
  53. try:
  54. @cache_results(cache_fp)
  55. def demo():
  56. print('¥')
  57. return 1
  58. with Capturing() as output:
  59. res = demo()
  60. assert '¥' in output[0]
  61. @cache_results(cache_fp)
  62. def demo():
  63. print('¥')
  64. return 1
  65. with Capturing() as output:
  66. res = demo()
  67. assert '¥' not in output[0]
  68. finally:
  69. rank_zero_rm('demo.pkl')
  70. def test_cache_func_change(self, capsys):
  71. cache_fp = 'demo.pkl'
  72. try:
  73. @cache_results(cache_fp)
  74. def demo():
  75. print('¥')
  76. return 1
  77. with Capturing() as output:
  78. res = demo()
  79. assert '¥' in output[0]
  80. @cache_results(cache_fp)
  81. def demo():
  82. print('¥¥')
  83. return 1
  84. with Capturing() as output:
  85. res = demo()
  86. assert 'different' in output[0]
  87. assert '¥' not in output[0]
  88. # 关闭check_hash应该不warning的
  89. with Capturing() as output:
  90. res = demo(_check_hash=0)
  91. assert 'different' not in output[0]
  92. assert '¥' not in output[0]
  93. finally:
  94. rank_zero_rm('demo.pkl')
  95. def test_cache_check_hash(self):
  96. cache_fp = 'demo.pkl'
  97. try:
  98. @cache_results(cache_fp, _check_hash=False)
  99. def demo():
  100. print('¥')
  101. return 1
  102. with Capturing() as output:
  103. res = demo(_check_hash=0)
  104. assert '¥' in output[0]
  105. @cache_results(cache_fp, _check_hash=False)
  106. def demo():
  107. print('¥¥')
  108. return 1
  109. # 默认不会check
  110. with Capturing() as output:
  111. res = demo()
  112. assert 'different' not in output[0]
  113. assert '¥' not in output[0]
  114. # check也可以
  115. with Capturing() as output:
  116. res = demo(_check_hash=True)
  117. assert 'different' in output[0]
  118. assert '¥' not in output[0]
  119. finally:
  120. rank_zero_rm('demo.pkl')
  121. # 外部 function 改变也会 导致改变
  122. def test_refer_fun_change(self):
  123. cache_fp = 'demo.pkl'
  124. test_type = 'func_refer_fun_change'
  125. try:
  126. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  127. res = get_subprocess_results(cmd)
  128. assert "¥" in res
  129. # 引用的function没有变化
  130. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  131. res = get_subprocess_results(cmd)
  132. assert "¥" not in res
  133. assert 'Read' in res
  134. assert 'different' not in res
  135. # 引用的function有变化
  136. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
  137. res = get_subprocess_results(cmd)
  138. assert "¥" not in res
  139. assert 'different' in res
  140. finally:
  141. rank_zero_rm(cache_fp)
  142. # 外部 method 改变也会 导致改变
  143. def test_refer_class_method_change(self):
  144. cache_fp = 'demo.pkl'
  145. test_type = 'refer_class_method_change'
  146. try:
  147. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  148. res = get_subprocess_results(cmd)
  149. assert "¥" in res
  150. # 引用的class没有变化
  151. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  152. res = get_subprocess_results(cmd)
  153. assert 'Read' in res
  154. assert 'different' not in res
  155. assert "¥" not in res
  156. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
  157. res = get_subprocess_results(cmd)
  158. assert 'different' in res
  159. assert "¥" not in res
  160. finally:
  161. rank_zero_rm(cache_fp)
  162. def test_duplicate_keyword(self):
  163. with pytest.raises(RuntimeError):
  164. @cache_results(None)
  165. def func_verbose(a, _verbose):
  166. pass
  167. func_verbose(0, 1)
  168. with pytest.raises(RuntimeError):
  169. @cache_results(None)
  170. def func_cache(a, _cache_fp):
  171. pass
  172. func_cache(1, 2)
  173. with pytest.raises(RuntimeError):
  174. @cache_results(None)
  175. def func_refresh(a, _refresh):
  176. pass
  177. func_refresh(1, 2)
  178. with pytest.raises(RuntimeError):
  179. @cache_results(None)
  180. def func_refresh(a, _check_hash):
  181. pass
  182. func_refresh(1, 2)
  183. def test_create_cache_dir(self):
  184. @cache_results('demo/demo.pkl')
  185. def cache():
  186. return 1, 2
  187. try:
  188. results = cache()
  189. assert (1, 2) == results
  190. finally:
  191. rank_zero_rm('demo/')
  192. def test_result_none_error(self):
  193. @cache_results('demo.pkl')
  194. def cache():
  195. pass
  196. try:
  197. with pytest.raises(RuntimeError):
  198. results = cache()
  199. finally:
  200. rank_zero_rm('demo.pkl')
  if __name__ == '__main__':
      # Script entry point driven by the subprocess-based tests in TestCacheResults.
      # Each run defines one scenario/turn and calls its cached function once.
      # The definitions deliberately stay at module level (not inside a main()
      # function) so the cached functions reference ``demo`` / ``Demo`` as globals.
      # NOTE(review): assumed to matter for cache_results' hash check; the tests
      # expect a 'different' warning when only the referenced code changes.
      import argparse

      parser = argparse.ArgumentParser()
      # choices: an unknown scenario fails loudly instead of silently doing nothing.
      parser.add_argument('--test_type', type=str, default='refer_class_method_change',
                          choices=['func_refer_fun_change', 'refer_class_method_change'])
      parser.add_argument('--turn', type=int, default=1)
      parser.add_argument('--cache_fp', type=str, default='demo.pkl')
      args = parser.parse_args()
      test_type = args.test_type
      cache_fp = args.cache_fp
      turn = args.turn

      if test_type == 'func_refer_fun_change':
          # The cached function calls ``demo``; turn 0 and turn 1 give ``demo`` a different body.
          if turn == 0:
              def demo():
                  b = 1
                  return b
          else:
              def demo():
                  b = 2
                  return b

          @cache_results(cache_fp)
          def demo_refer_other_func():
              b = demo()
              print("¥")
              return b

          demo_refer_other_func()
      elif test_type == 'refer_class_method_change':
          print(f"Turn:{turn}")
          # The two helper modules presumably define ``Demo`` with differing ``demo`` methods.
          if turn == 0:
              from helper_for_cache_results_1 import Demo
          else:
              from helper_for_cache_results_2 import Demo
          demo = Demo()

          @cache_results(cache_fp)
          def demo_func():
              print("¥")
              b = demo.demo()
              return b

          demo_func()