You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738
  1. import pytest
  2. from tests.helpers.utils import Capturing
  3. from fastNLP.core.callbacks.utils import _get_monitor_value
  4. from fastNLP.core.log.logger import logger
  5. def test_get_monitor_value():
  6. logger.set_stdout(stdout='raw')
  7. # 测试完全匹配
  8. res = {'f1': 0.2, 'acc#rec': 0.3}
  9. with Capturing() as output:
  10. monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res)
  11. assert monitor == 'f1' and value==0.2
  12. # 测试可以匹配,且选择更靠前的
  13. res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4}
  14. with Capturing() as output:
  15. monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res)
  16. assert monitor=='acc#f1' and value==0.2
  17. # 测试monitor匹配不上,使用real_monitor
  18. res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4}
  19. with Capturing() as output:
  20. monitor, value = _get_monitor_value(monitor='acc', real_monitor='acc#rec', res=res)
  21. assert monitor=='acc#rec' and value==0.3
  22. # 测试monitor/real_monitor匹配不上, 重新选择
  23. res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4}
  24. with Capturing() as output:
  25. monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res)
  26. assert monitor=='acc#f1' and value==0.2
  27. # 测试partial的位置
  28. res = {"acc#acc": 0.52, "loss#loss": 2}
  29. with Capturing() as output:
  30. monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res)
  31. assert monitor=='loss#loss' and value==2