
Improved the log messages printed by monitor-related callbacks

yh committed 3 years ago · commit 0be2357f46 · tags/v1.0.0alpha
4 changed files with 37 additions and 9 deletions:
  1. fastNLP/core/callbacks/checkpoint_callback.py (+1, -0)
  2. fastNLP/core/callbacks/fitlog_callback.py (+1, -1)
  3. fastNLP/core/callbacks/has_monitor_callback.py (+19, -4)
  4. fastNLP/core/collators/collator.py (+16, -4)

fastNLP/core/callbacks/checkpoint_callback.py (+1, -0)

@@ -89,6 +89,7 @@ class CheckpointCallback(Callback):
         self.topk_saver = TopkSaver(topk=topk, monitor=monitor, larger_better=larger_better, folder=folder,
                                     save_object=save_object, only_state_dict=only_state_dict, model_save_fn=model_save_fn,
                                     save_evaluate_results=save_evaluate_results, **kwargs)
+        self.topk_saver.log_name = self.__class__.__name__
 
         self.topk = topk
         self.save_object = save_object
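
For context, the effect of this one-line addition is that warnings emitted by the internal TopkSaver are attributed to the owning CheckpointCallback rather than to TopkSaver itself. A hypothetical sketch (the top-level import path and any constructor arguments beyond those visible in this hunk are assumptions):

# Hypothetical usage sketch; not part of the commit.
from fastNLP import CheckpointCallback  # top-level import path is an assumption

cb = CheckpointCallback(topk=1, monitor='acc', folder='checkpoints')
# The internal saver now reports under the callback's name:
assert cb.topk_saver.log_name == 'CheckpointCallback'
# so a monitor mismatch is logged as, e.g.,
#   "We can not find monitor:`acc` for `CheckpointCallback` in the evaluation result ..."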


fastNLP/core/callbacks/fitlog_callback.py (+1, -1)

@@ -49,7 +49,7 @@ class FitlogCallback(HasMonitorCallback):
     def on_sanity_check_end(self, trainer, sanity_check_res):
         super(FitlogCallback, self).on_sanity_check_end(trainer, sanity_check_res)
         if self.monitor is None:
-            logger.rank_zero_warning(f"No monitor set for {self.__class__.__name__}. Therefore, no best metric will "
+            logger.rank_zero_warning(f"No monitor set for {self.log_name}. Therefore, no best metric will "
                                      f"be logged.")
 
     def on_evaluate_end(self, trainer, results):


fastNLP/core/callbacks/has_monitor_callback.py (+19, -4)

@@ -42,6 +42,7 @@ class ResultsMonitor:
     """
     def __init__(self, monitor:Union[Callback, str], larger_better:bool=True):
         self.set_monitor(monitor, larger_better)
+        self._log_name = self.__class__.__name__
 
     def set_monitor(self, monitor, larger_better):
         if callable(monitor):  # check whether it can accept a single argument
@@ -84,11 +85,12 @@ class ResultsMonitor:
             return monitor_value
         # first run
         if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
-            logger.rank_zero_warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as "
+            logger.rank_zero_warning(f"We can not find monitor:`{self.monitor}` for `{self.log_name}` in the "
+                                     f"evaluation result (with keys as "
                                      f"{list(results.keys())}), we use the `{use_monitor}` as the monitor.", once=True)
         # a monitor different from the previous run was detected
         elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
-            logger.rank_zero_warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
+            logger.rank_zero_warning(f"Change of monitor detected for `{self.log_name}`. "
                                      f"The expected monitor is:`{self.monitor}`, last used monitor is:"
                                      f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
                                      f"customized monitor function when the evaluation results are varying between validation.")
@@ -166,6 +168,19 @@ class ResultsMonitor:
             monitor_name = str(self.monitor)
         return monitor_name
 
+    @property
+    def log_name(self) -> str:
+        """
+        Used internally when printing log messages.
+
+        :return:
+        """
+        return self._log_name
+
+    @log_name.setter
+    def log_name(self, value):
+        self._log_name = value
+
 
 class HasMonitorCallback(ResultsMonitor, Callback):
     """
@@ -201,10 +216,10 @@ class HasMonitorCallback(ResultsMonitor, Callback):
         if self.monitor is None and trainer.monitor is not None:
             self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
         if self.must_have_monitor and self.monitor is None:
-            raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
+            raise RuntimeError(f"No `monitor` is set for {self.log_name}. "
                                f"You can set it in the initialization or through Trainer.")
         if self.must_have_monitor and self.monitor is not None and trainer.evaluator is None:
-            raise RuntimeError(f"No `evaluate_dataloaders` is set for Trainer. But Callback: {self.__class__.__name__}"
+            raise RuntimeError(f"No `evaluate_dataloaders` is set for Trainer. But Callback: {self.log_name}"
                                f" need to watch the monitor:`{self.monitor_name}`.")
 
     def on_sanity_check_end(self, trainer, sanity_check_res):
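
A minimal sketch of what the new property enables (the subclass below is invented for illustration): log_name defaults to the class name, and an owning object can overwrite it so that shared monitor machinery reports under the owner's label, exactly as CheckpointCallback does above with its TopkSaver.

# Minimal sketch, not from the commit; assumes the patched ResultsMonitor above.
from fastNLP.core.callbacks.has_monitor_callback import ResultsMonitor

class InnerSaver(ResultsMonitor):
    """Stand-in for a helper object owned by a user-facing callback."""
    pass

saver = InnerSaver(monitor='acc')
print(saver.log_name)          # 'InnerSaver' -- defaults to the class name

saver.log_name = 'MyCallback'  # the owning callback re-labels it
# Subsequent warnings (e.g. "Change of monitor detected for `MyCallback`. ...")
# now point the user at the callback they actually configured.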


fastNLP/core/collators/collator.py (+16, -4)

@@ -11,6 +11,7 @@ import re
 from fastNLP.core.log import logger
 from .padders.get_padder import get_padder
 from ...envs import SUPPORT_BACKENDS
+from .padders import Padder
 
 
 from .packer_unpacker import SequencePackerUnpacker, SinglePackerUnpacker, MappingPackerUnpacker, \
@@ -89,6 +90,11 @@ class Collator:
     after conversion to batch form the data becomes [1, [1, 2]], which is judged un-paddable because the first
     sample and the second sample differ in depth); (3) whether the type of this field is paddable at all (str
     data, for example, is not). You can set logger.setLevel('debug') to print the reason a field was judged un-paddable.
 
+    .. note::
+
+        ``Collator`` works by using the data of the first ``batch`` to infer which type of ``Padder`` each ``field`` should use;
+        if the first ``batch`` happens to be unrepresentative, padding may fail on a later batch. In that case, please configure the padder manually via ``set_pad()``.
+
     todo: add a code example.
 
     If you need a field that would otherwise be paddable to not be padded, set pad_val to None via :meth:`~fastNLP.Collator.set_pad`.
@@ -168,10 +174,16 @@ class Collator:
 
         if self.batch_data_type == 'l':
             self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:])))  # sort so that _0, _1 keep their order
-
-        for key, padder in self.padders.items():
-            batch = unpack_batch.get(key)
-            pad_batch[key] = padder(batch)
+        try:
+            for key, padder in self.padders.items():
+                batch = unpack_batch.get(key)
+                pad_batch[key] = padder(batch)
+        except BaseException as e:
+            try:
+                logger.error(f"The following exception happens when try to pad the `{key}` field with padder:{padder}:")
+            except:
+                pass
+            raise e
 
         return self.packer_unpacker.pack_batch(pad_batch)  # restore the type to be consistent with the input
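
The docstring above still carries a "todo: add a code example"; a hedged sketch of the scenario the new note describes might look like the following (field names and values are invented; only Collator, set_pad, and the backend/pad_val arguments are taken from the surrounding docs):

# Hedged sketch, not from the commit: the first batch misleads padder inference.
from fastNLP import Collator  # top-level import path is an assumption

collator = Collator(backend='numpy')

first_batch = [{'word': 1}, {'word': 2}]         # plain ints -> a number padder is inferred
later_batch = [{'word': [1, 2]}, {'word': [3]}]  # nested lists -> that padder may now fail

collator(first_batch)   # padders for each field are inferred from this batch
# collator(later_batch) would then raise, and with this commit the log first shows:
#   "The following exception happens when try to pad the `word` field with padder:..."

# The remedy suggested by the note: configure the padder explicitly up front
# instead of relying on inference from the first batch.
collator = Collator(backend='numpy')
collator.set_pad('word', pad_val=0)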


