Browse Source

修改element无法与element运算的问题

tags/v1.0.0alpha
yhcc 3 years ago
parent
commit
fecd82aadf
4 changed files with 21 additions and 21 deletions
  1. +8
    -5
      fastNLP/core/controllers/trainer.py
  2. +1
    -1
      fastNLP/core/dataset/dataset.py
  3. +4
    -6
      fastNLP/core/metrics/element.py
  4. +8
    -9
      fastNLP/core/metrics/metric.py

+ 8
- 5
fastNLP/core/controllers/trainer.py View File

@@ -919,16 +919,19 @@ class Trainer(TrainerEventTrigger):
_not_called_callback_fns.append(each_callback_fn)

if check_mode:
logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these "
if len(_not_called_callback_fns) != 0:
logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these "
f"callback_fns: {_not_called_callback_fns}, but it seems that"
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.")
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.",
once=True)
# 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass
# 函数;
self.check_batch_step_fn = lambda *args, **kwargs: ...
else:
logger.warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: "
elif len(_not_called_callback_fns)!=0:
logger.rank_zero_warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: "
f"{_not_called_callback_fns}, but it seems that"
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.")
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.",
once=True)

def _check_train_batch_loop_legality(self):
r"""


+ 1
- 1
fastNLP/core/dataset/dataset.py View File

@@ -405,7 +405,7 @@ class DataSet:
if isinstance(item, str) and item in self.field_arrays:
return self.field_arrays[item]
else:
raise AttributeError
raise AttributeError(f"Dataset has no attribute named:{item}.")

def __setstate__(self, state):
self.__dict__ = state


+ 4
- 6
fastNLP/core/metrics/element.py View File

@@ -136,15 +136,13 @@ class Element:
if self.value is None:
prefix = f'Element:`{self.name}`'
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
"element, or use it after it being used by the `Metric.compute()` method.")
"element, or use it after it being used by the `Metric.update()` method.")

def __add__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value += other.value
else:
self.value += other
return self
other = other.value
return self.value + other

def __radd__(self, other):
self._check_value_when_call()
@@ -314,7 +312,7 @@ class Element:
if self._value is None:
prefix = f'Element:`{self.name}`'
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
"element, or use it after it being used by the `Metric.compute()` method.")
"element, or use it after it being used by the `Metric.update()` method.")
return getattr(self._value, item)
except AttributeError as e:
logger.error(f"Element:{self.name} has no `{item}` attribute.")


+ 8
- 9
fastNLP/core/metrics/metric.py View File

@@ -99,8 +99,14 @@ class Metric:

def __setattr__(self, key, value):
if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True:
if key in self.elements and value is not self.elements[key]:
raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}")
if key in self.elements and isinstance(value, (float, int, bool)):
self.elements[key].fill_value(value)
return
elif key in self.elements:
raise TypeError(f"self.{key} is an Element, only float/int/bool type value can be assigned to it, "
f"instead of {type(value)}.")
if isinstance(value, Element) and key not in self.elements:
raise RuntimeError("Please use register_element() function to add Element.")
object.__setattr__(self, key, value)

def _wrap_update(self, update):
@@ -163,13 +169,6 @@ class Metric:
"""
self.aggregate_when_get_metric = flag

def __getattr__(self, name: str) -> Element:
if 'elements' in self.__dict__:
elements = self.__dict__['elements']
if name in elements:
return elements[name]
raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name))

def tensor2numpy(self, tensor) -> np.array:
"""
将tensor向量转为numpy类型变量


Loading…
Cancel
Save