You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 1.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import pytest
  2. from tests.helpers.utils import Capturing
  3. from fastNLP.core.callbacks.utils import _get_monitor_value
  4. from fastNLP.core.log.logger import logger
  5. def test_get_monitor_value():
  6. logger.set_stdout(stdout='raw')
  7. # 测试完全匹配
  8. res = {'f1': 0.2, 'acc#rec': 0.3}
  9. with Capturing() as output:
  10. monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res)
  11. assert monitor == 'f1' and value==0.2
  12. assert 'We can not find' not in output[0]
  13. # 测试可以匹配,且选择更靠前的
  14. res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4}
  15. with Capturing() as output:
  16. monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res)
  17. assert monitor=='acc#f1' and value==0.2
  18. assert 'We can not find' in output[0]
  19. # 测试monitor匹配不上,使用real_monitor
  20. res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4}
  21. with Capturing() as output:
  22. monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#rec', res=res)
  23. assert monitor=='acc#rec' and value==0.3
  24. assert 'We can not find' not in output[0]
  25. # 测试monitor/real_monitor匹配不上, 重新选择
  26. res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4}
  27. with Capturing() as output:
  28. monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res)
  29. assert monitor=='acc#f1' and value==0.2
  30. assert 'We can not find' in output[0]
  31. # 测试partial的位置
  32. res = {"acc#acc": 0.52, "loss#loss": 2}
  33. with Capturing() as output:
  34. monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res)
  35. assert monitor=='loss#loss' and value==2
  36. assert 'We can not find' in output[0]