From af6300e816393a8ba13e0897268f8706014c4906 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Fri, 3 Jun 2022 21:01:39 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9element,=E8=AE=A1=E7=AE=97?= =?UTF-8?q?=E5=80=BC=E8=BD=AC=E4=B8=BA=E9=9D=9Etensor=E5=BD=A2=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/metrics/element.py | 148 ++++++++++++------------- tests/core/metrics/test_element_cal_element.py | 35 ++++++ 2 files changed, 108 insertions(+), 75 deletions(-) create mode 100644 tests/core/metrics/test_element_cal_element.py diff --git a/fastNLP/core/metrics/element.py b/fastNLP/core/metrics/element.py index f6644602..749c5727 100644 --- a/fastNLP/core/metrics/element.py +++ b/fastNLP/core/metrics/element.py @@ -3,6 +3,7 @@ __all__ = [ ] import os +import functools from .backend import Backend, AutoBackend from fastNLP.core.log import logger @@ -10,6 +11,17 @@ from .utils import AggregateMethodError from fastNLP.envs.env import FASTNLP_GLOBAL_RANK +def _wrap_cal_value(func): + @functools.wraps(func) + def _wrap_cal(*args, **kwargs): + self = args[0] + value = func(*args, **kwargs) + value = self.backend.get_scalar(value) + return value + + return _wrap_cal + + class Element: def __init__(self, name, value: float, aggregate_method, backend: Backend): """ @@ -107,6 +119,7 @@ class Element: 对元素进行 fill_value, 会执行队友 backend 的 fill_value 方法 """ + self._check_value_initialized() self._value = self.backend.fill_value(self._value, value) def to(self, device): @@ -138,163 +151,148 @@ class Element: raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " "element, or use it after it being used by the `Metric.update()` method.") + @_wrap_cal_value def __add__(self, other): self._check_value_when_call() if isinstance(other, Element): other = other.value return self.value + other + @_wrap_cal_value def __radd__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value += other.value - else: - self.value += other - return self + other = other.value + return self.value + other + @_wrap_cal_value def __sub__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value -= other.value - else: - self.value -= other - return self + other = other.value + return self.value - other + @_wrap_cal_value def __rsub__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value -= other.value - else: - self.value -= other - return self + other = other.value + return self.value - other + @_wrap_cal_value def __mul__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value *= other.value - else: - self.value *= other - return self + other = other.value + return self.value * other + @_wrap_cal_value def __imul__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value *= other.value - else: - self.value *= other - return self + other = other.value + return self.value * other + @_wrap_cal_value def __floordiv__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value //= other.value - else: - self.value //= other - return self + other = other.value + return self.value // other + @_wrap_cal_value def __rfloordiv__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value //= other.value - else: - self.value //= other - return self + other = other.value + return self.value // other + @_wrap_cal_value def __truediv__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value /= other.value - else: - self.value /= other - return self + other = other.value + return self.value / other + @_wrap_cal_value def __rtruediv__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value /= other.value - else: - self.value /= other - return self + other = other.value + return self.value / other + @_wrap_cal_value def __mod__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value %= other.value - else: - self.value %= other - return self + other = other.value + return self.value % other + @_wrap_cal_value def __rmod__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value /= other.value - else: - self.value /= other - return self + other = other.value + return self.value % other + @_wrap_cal_value def __pow__(self, other, modulo=None): self._check_value_when_call() + if isinstance(other, Element): + other = other.value if modulo is None: - if isinstance(other, Element): - self.value **= other.value - else: - self.value **= other + return self.value ** other else: - if isinstance(other, Element): - self.value = pow(self.value, other.value, modulo) - else: - self.value = pow(self.value, other, modulo) - return self + return pow(self.value, other, modulo) + @_wrap_cal_value def __rpow__(self, other): self._check_value_when_call() if isinstance(other, Element): - self.value **= other.value - else: - self.value **= other - return self + other = other.value + return self.value ** other + @_wrap_cal_value def __lt__(self, other) -> bool: self._check_value_when_call() if isinstance(other, Element): - return self.value < other.value - else: - return self.value < other + other = other.value + return self.value < other + @_wrap_cal_value def __le__(self, other) -> bool: self._check_value_when_call() if isinstance(other, Element): - return self.value <= other.value - else: - return self.value <= other + other = other.value + return self.value <= other + @_wrap_cal_value def __eq__(self, other): self._check_value_when_call() if isinstance(other, Element): - return self.value == other.value - else: - return self.value == other + other = other.value + return self.value == other + @_wrap_cal_value def __ne__(self, other) -> bool: self._check_value_when_call() if isinstance(other, Element): - return self.value != other.value - else: - return self.value != other + other = other.value + return self.value != other + @_wrap_cal_value def __ge__(self, other) -> bool: self._check_value_when_call() if isinstance(other, Element): - return self.value >= other.value - else: - return self.value >= other + other = other.value + return self.value >= other + @_wrap_cal_value def __gt__(self, other) -> bool: self._check_value_when_call() if isinstance(other, Element): - return self.value > other.value - else: - return self.value > other + other = other.value + return self.value > other def __str__(self): return str(self.value) diff --git a/tests/core/metrics/test_element_cal_element.py b/tests/core/metrics/test_element_cal_element.py new file mode 100644 index 00000000..340e2a43 --- /dev/null +++ b/tests/core/metrics/test_element_cal_element.py @@ -0,0 +1,35 @@ +import pytest + +from fastNLP.core.metrics import Metric +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from .utils import find_free_network_port, setup_ddp +if _NEED_IMPORT_TORCH: + import torch + + +class MyMetric(Metric): + def __init__(self): + super(MyMetric, self).__init__() + self.register_element(name="t1", value=0) + self.register_element(name="t2", value=0) + self.register_element(name="t3", value=0) + + def update(self, pred): + self.t1 = len(pred) + self.t2 = len(pred) + temp = self.t1 + self.t2 + self.t3 = temp + self.t1 = self.t3 / self.t2 + + def get_metric(self) -> dict: + return {"t1": self.t1.get_scalar(), "t2": self.t2.get_scalar(), "t3": self.t3.get_scalar()} + + +class TestElemnt: + + def test_case_v1(self): + pred = torch.tensor([1, 1, 1, 1]) + metric = MyMetric() + metric.update(pred) + res = metric.get_metric() + print(res) \ No newline at end of file