diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 4f6c0dc4..7c6bba53 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -919,16 +919,19 @@ class Trainer(TrainerEventTrigger): _not_called_callback_fns.append(each_callback_fn) if check_mode: - logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " + if len(_not_called_callback_fns) != 0: + logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " f"callback_fns: {_not_called_callback_fns}, but it seems that" - "you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.") + "you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.", + once=True) # 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass # 函数; self.check_batch_step_fn = lambda *args, **kwargs: ... - else: - logger.warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: " + elif len(_not_called_callback_fns)!=0: + logger.rank_zero_warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: " f"{_not_called_callback_fns}, but it seems that" - "you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.") + "you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.", + once=True) def _check_train_batch_loop_legality(self): r""" diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index d5b45eeb..fff8b5c2 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -405,7 +405,7 @@ class DataSet: if isinstance(item, str) and item in self.field_arrays: return self.field_arrays[item] else: - raise AttributeError + raise AttributeError(f"Dataset has no attribute named:{item}.") def __setstate__(self, state): self.__dict__ = state diff --git a/fastNLP/core/metrics/element.py b/fastNLP/core/metrics/element.py index e20bc90f..f6644602 100644 --- a/fastNLP/core/metrics/element.py +++ b/fastNLP/core/metrics/element.py @@ -136,15 +136,13 @@ class Element: if self.value is None: prefix = f'Element:`{self.name}`' raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " - "element, or use it after it being used by the `Metric.compute()` method.") + "element, or use it after it being used by the `Metric.update()` method.") def __add__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value += other.value - else: - self.value += other - return self + other = other.value + return self.value + other def __radd__(self, other): self._check_value_when_call() @@ -314,7 +312,7 @@ class Element: if self._value is None: prefix = f'Element:`{self.name}`' raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " - "element, or use it after it being used by the `Metric.compute()` method.") + "element, or use it after it being used by the `Metric.update()` method.") return getattr(self._value, item) except AttributeError as e: logger.error(f"Element:{self.name} has no `{item}` attribute.") diff --git a/fastNLP/core/metrics/metric.py b/fastNLP/core/metrics/metric.py index 178a598f..1a69e80c 100644 --- a/fastNLP/core/metrics/metric.py +++ b/fastNLP/core/metrics/metric.py @@ -99,8 +99,14 @@ class Metric: def __setattr__(self, key, value): if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: - if key in self.elements and value is not self.elements[key]: - raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}") + if key in self.elements and isinstance(value, (float, int, bool)): + self.elements[key].fill_value(value) + return + elif key in self.elements: + raise TypeError(f"self.{key} is an Element, only float/int/bool type value can be assigned to it, " + f"instead of {type(value)}.") + if isinstance(value, Element) and key not in self.elements: + raise RuntimeError("Please use register_element() function to add Element.") object.__setattr__(self, key, value) def _wrap_update(self, update): @@ -163,13 +169,6 @@ class Metric: """ self.aggregate_when_get_metric = flag - def __getattr__(self, name: str) -> Element: - if 'elements' in self.__dict__: - elements = self.__dict__['elements'] - if name in elements: - return elements[name] - raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name)) - def tensor2numpy(self, tensor) -> np.array: """ 将tensor向量转为numpy类型变量