import pytest from tests.helpers.utils import Capturing from fastNLP.core.callbacks.utils import _get_monitor_value from fastNLP.core.log.logger import logger def test_get_monitor_value(): logger.set_stdout(stdout='raw') # 测试完全匹配 res = {'f1': 0.2, 'acc#rec': 0.3} with Capturing() as output: monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) assert monitor == 'f1' and value==0.2 # 测试可以匹配,且选择更靠前的 res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} with Capturing() as output: monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) assert monitor=='acc#f1' and value==0.2 # 测试monitor匹配不上,使用real_monitor res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} with Capturing() as output: monitor, value = _get_monitor_value(monitor='acc', real_monitor='acc#rec', res=res) assert monitor=='acc#rec' and value==0.3 # 测试monitor/real_monitor匹配不上, 重新选择 res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} with Capturing() as output: monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res) assert monitor=='acc#f1' and value==0.2 # 测试partial的位置 res = {"acc#acc": 0.52, "loss#loss": 2} with Capturing() as output: monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res) assert monitor=='loss#loss' and value==2