You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cache_results.py 13 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. import os
  2. import pytest
  3. import subprocess
  4. from io import StringIO
  5. import sys
  6. sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))
  7. from fastNLP.core.utils.cache_results import cache_results
  8. from fastNLP.envs.distributed import rank_zero_rm
  9. def get_subprocess_results(cmd):
  10. output = subprocess.check_output(cmd, shell=True)
  11. return output.decode('utf8')
  12. class Capturing(list):
  13. # 用来捕获当前环境中的stdout和stderr,会将其中stderr的输出拼接在stdout的输出后面
  14. def __enter__(self):
  15. self._stdout = sys.stdout
  16. self._stderr = sys.stderr
  17. sys.stdout = self._stringio = StringIO()
  18. sys.stderr = self._stringioerr = StringIO()
  19. return self
  20. def __exit__(self, *args):
  21. self.append(self._stringio.getvalue() + self._stringioerr.getvalue())
  22. del self._stringio, self._stringioerr # free up some memory
  23. sys.stdout = self._stdout
  24. sys.stderr = self._stderr
  25. class TestCacheResults:
  26. def test_cache_save(self):
  27. cache_fp = 'demo.pkl'
  28. try:
  29. @cache_results(cache_fp)
  30. def demo():
  31. print("¥")
  32. return 1
  33. res = demo()
  34. with Capturing() as output:
  35. res = demo()
  36. assert '¥' not in output[0]
  37. finally:
  38. rank_zero_rm(cache_fp)
  39. def test_cache_save_refresh(self):
  40. cache_fp = 'demo.pkl'
  41. try:
  42. @cache_results(cache_fp, _refresh=True)
  43. def demo():
  44. print("¥")
  45. return 1
  46. res = demo()
  47. with Capturing() as output:
  48. res = demo()
  49. assert '¥' in output[0]
  50. finally:
  51. rank_zero_rm(cache_fp)
  52. def test_cache_no_func_change(self):
  53. cache_fp = os.path.abspath('demo.pkl')
  54. try:
  55. @cache_results(cache_fp)
  56. def demo():
  57. print('¥')
  58. return 1
  59. with Capturing() as output:
  60. res = demo()
  61. assert '¥' in output[0]
  62. @cache_results(cache_fp)
  63. def demo():
  64. print('¥')
  65. return 1
  66. with Capturing() as output:
  67. res = demo()
  68. assert '¥' not in output[0]
  69. finally:
  70. rank_zero_rm('demo.pkl')
  71. def test_cache_func_change(self, capsys):
  72. cache_fp = 'demo.pkl'
  73. try:
  74. @cache_results(cache_fp)
  75. def demo():
  76. print('¥')
  77. return 1
  78. with Capturing() as output:
  79. res = demo()
  80. assert '¥' in output[0]
  81. @cache_results(cache_fp)
  82. def demo():
  83. print('¥¥')
  84. return 1
  85. with Capturing() as output:
  86. res = demo()
  87. assert 'different' in output[0]
  88. assert '¥' not in output[0]
  89. # 关闭check_hash应该不warning的
  90. with Capturing() as output:
  91. res = demo(_check_hash=0)
  92. assert 'different' not in output[0]
  93. assert '¥' not in output[0]
  94. finally:
  95. rank_zero_rm('demo.pkl')
  96. def test_cache_check_hash(self):
  97. cache_fp = 'demo.pkl'
  98. try:
  99. @cache_results(cache_fp, _check_hash=False)
  100. def demo():
  101. print('¥')
  102. return 1
  103. with Capturing() as output:
  104. res = demo(_check_hash=0)
  105. assert '¥' in output[0]
  106. @cache_results(cache_fp, _check_hash=False)
  107. def demo():
  108. print('¥¥')
  109. return 1
  110. # 默认不会check
  111. with Capturing() as output:
  112. res = demo()
  113. assert 'different' not in output[0]
  114. assert '¥' not in output[0]
  115. # check也可以
  116. with Capturing() as output:
  117. res = demo(_check_hash=True)
  118. assert 'different' in output[0]
  119. assert '¥' not in output[0]
  120. finally:
  121. rank_zero_rm('demo.pkl')
  122. # 外部 function 改变也会 导致改变
  123. def test_refer_fun_change(self):
  124. cache_fp = 'demo.pkl'
  125. test_type = 'func_refer_fun_change'
  126. try:
  127. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  128. res = get_subprocess_results(cmd)
  129. assert "¥" in res
  130. # 引用的function没有变化
  131. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  132. res = get_subprocess_results(cmd)
  133. assert "¥" not in res
  134. assert 'Read' in res
  135. assert 'different' not in res
  136. # 引用的function有变化
  137. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
  138. res = get_subprocess_results(cmd)
  139. assert "¥" not in res
  140. assert 'different' in res
  141. finally:
  142. rank_zero_rm(cache_fp)
  143. # 外部 method 改变也会 导致改变
  144. def test_refer_class_method_change(self):
  145. cache_fp = 'demo.pkl'
  146. test_type = 'refer_class_method_change'
  147. try:
  148. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  149. res = get_subprocess_results(cmd)
  150. assert "¥" in res
  151. # 引用的class没有变化
  152. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  153. res = get_subprocess_results(cmd)
  154. assert 'Read' in res
  155. assert 'different' not in res
  156. assert "¥" not in res
  157. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
  158. res = get_subprocess_results(cmd)
  159. assert 'different' in res
  160. assert "¥" not in res
  161. finally:
  162. rank_zero_rm(cache_fp)
  163. def test_duplicate_keyword(self):
  164. with pytest.raises(RuntimeError):
  165. @cache_results(None)
  166. def func_verbose(a, _verbose):
  167. pass
  168. func_verbose(0, 1)
  169. with pytest.raises(RuntimeError):
  170. @cache_results(None)
  171. def func_cache(a, _cache_fp):
  172. pass
  173. func_cache(1, 2)
  174. with pytest.raises(RuntimeError):
  175. @cache_results(None)
  176. def func_refresh(a, _refresh):
  177. pass
  178. func_refresh(1, 2)
  179. with pytest.raises(RuntimeError):
  180. @cache_results(None)
  181. def func_refresh(a, _check_hash):
  182. pass
  183. func_refresh(1, 2)
  184. def test_create_cache_dir(self):
  185. @cache_results('demo/demo.pkl')
  186. def cache():
  187. return 1, 2
  188. try:
  189. results = cache()
  190. assert (1, 2) == results
  191. finally:
  192. rank_zero_rm('demo/')
  193. def test_result_none_error(self):
  194. @cache_results('demo.pkl')
  195. def cache():
  196. pass
  197. try:
  198. with pytest.raises(RuntimeError):
  199. results = cache()
  200. finally:
  201. rank_zero_rm('demo.pkl')
  202. def remove_postfix(folder='.', post_fix='.pkl'):
  203. import os
  204. for f in os.listdir(folder):
  205. if os.path.isfile(f) and f.endswith(post_fix):
  206. os.remove(os.path.join(folder, f))
  207. class TestCacheResultsWithParam:
  208. @pytest.mark.parametrize('_refresh', [True, False])
  209. @pytest.mark.parametrize('_hash_param', [True, False])
  210. @pytest.mark.parametrize('_verbose', [0, 1])
  211. @pytest.mark.parametrize('_check_hash', [True, False])
  212. def test_cache_save(self, _refresh, _hash_param, _verbose, _check_hash):
  213. cache_fp = 'demo.pkl'
  214. try:
  215. @cache_results(cache_fp, _refresh=_refresh, _hash_param=_hash_param, _verbose=_verbose,
  216. _check_hash=_check_hash)
  217. def demo(a=1):
  218. print("¥")
  219. return 1
  220. res = demo()
  221. with Capturing() as output:
  222. res = demo(a=1)
  223. if _refresh is False:
  224. assert '¥' not in output[0]
  225. if _verbose is 0:
  226. assert 'read' not in output[0]
  227. with Capturing() as output:
  228. res = demo(1)
  229. if _refresh is False:
  230. assert '¥' not in output[0]
  231. with Capturing() as output:
  232. res = demo(a=2)
  233. if _hash_param is True: # 一定对不上,需要重新生成
  234. assert '¥' in output[0]
  235. finally:
  236. remove_postfix('.')
  237. def test_cache_complex_param(self):
  238. cache_fp = 'demo.pkl'
  239. try:
  240. @cache_results(cache_fp, _refresh=False)
  241. def demo(*args, s=1, **kwargs):
  242. print("¥")
  243. return 1
  244. res = demo(1,2,3, s=4, d=4)
  245. with Capturing() as output:
  246. res = demo(1,2,3,d=4, s=4)
  247. assert '¥' not in output[0]
  248. finally:
  249. remove_postfix('.')
  250. def test_wrapper_change(self):
  251. cache_fp = 'demo.pkl'
  252. test_type = 'wrapper_change'
  253. try:
  254. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  255. res = get_subprocess_results(cmd)
  256. assert "¥" in res
  257. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
  258. res = get_subprocess_results(cmd)
  259. assert "¥" not in res
  260. assert 'Read' in res
  261. assert 'different' not in res
  262. finally:
  263. remove_postfix('.')
  264. def test_param_change(self):
  265. cache_fp = 'demo.pkl'
  266. test_type = 'param_change'
  267. try:
  268. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  269. res = get_subprocess_results(cmd)
  270. assert "¥" in res
  271. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
  272. res = get_subprocess_results(cmd)
  273. assert "¥" in res
  274. assert 'Read' not in res
  275. finally:
  276. remove_postfix('.')
  277. def test_create_cache_dir(self):
  278. @cache_results('demo/demo.pkl')
  279. def cache(s):
  280. return 1, 2
  281. try:
  282. results = cache(s=1)
  283. assert (1, 2) == results
  284. finally:
  285. import shutil
  286. shutil.rmtree('demo/')
# Script entry point. The tests above launch this file in a subprocess as
# ``python <this file> --cache_fp ... --test_type ... --turn ...`` and inspect
# what it prints. ``--test_type`` selects the scenario and ``--turn`` selects
# which variant of the code is defined for that run.
#
# NOTE: no comments are placed inside the decorated functions below, because
# the tests check whether the cache reports the function as 'different'.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_type', type=str, default='refer_class_method_change')
    parser.add_argument('--turn', type=int, default=1)
    parser.add_argument('--cache_fp', type=str, default='demo.pkl')
    args = parser.parse_args()

    test_type = args.test_type
    cache_fp = args.cache_fp
    turn = args.turn

    # Scenario: the cached function calls a plain function whose body differs
    # between turn 0 and any other turn.
    if test_type == 'func_refer_fun_change':
        if turn == 0:
            def demo():
                b = 1
                return b
        else:
            def demo():
                b = 2
                return b

        @cache_results(cache_fp)
        def demo_refer_other_func():
            b = demo()
            print("¥")
            return b

        res = demo_refer_other_func()

    # Scenario: the cached function calls a method of ``Demo``, which is
    # imported from a different helper module depending on the turn.
    if test_type == 'refer_class_method_change':
        print(f"Turn:{turn}")
        if turn == 0:
            from helper_for_cache_results_1 import Demo
        else:
            from helper_for_cache_results_2 import Demo

        demo = Demo()
        # import pdb
        # pdb.set_trace()
        @cache_results(cache_fp)
        def demo_func():
            print("¥")
            b = demo.demo()
            return b

        res = demo_func()

    # Scenario: the function body is the same in both turns; only the
    # decorator's ``_refresh`` argument differs.
    if test_type == 'wrapper_change':
        if turn == 0:
            @cache_results(cache_fp, _refresh=True)
            def demo_wrapper_change():
                print("¥")
                return 1
        else:
            @cache_results(cache_fp, _refresh=False)
            def demo_wrapper_change():
                print("¥")
                return 1

        res = demo_wrapper_change()

    # Scenario: the decorator arguments are the same in both turns; the
    # function gains a defaulted parameter ``a`` when turn is not 0.
    if test_type == 'param_change':
        if turn == 0:
            @cache_results(cache_fp, _refresh=False)
            def demo_param_change():
                print("¥")
                return 1
        else:
            @cache_results(cache_fp, _refresh=False)
            def demo_param_change(a=1):
                print("¥")
                return 1

        res = demo_param_change()