
Modify element: convert computed values to non-tensor form

tags/v1.0.0alpha
MorningForest, 3 years ago
parent commit af6300e816
2 changed files with 108 additions and 75 deletions
  1. fastNLP/core/metrics/element.py (+73, -75)
  2. tests/core/metrics/test_element_cal_element.py (+35, -0)

fastNLP/core/metrics/element.py (+73, -75)

@@ -3,6 +3,7 @@ __all__ = [
]

import os
+import functools

from .backend import Backend, AutoBackend
from fastNLP.core.log import logger
@@ -10,6 +11,17 @@ from .utils import AggregateMethodError
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK


+def _wrap_cal_value(func):
+    @functools.wraps(func)
+    def _wrap_cal(*args, **kwargs):
+        self = args[0]
+        value = func(*args, **kwargs)
+        value = self.backend.get_scalar(value)
+        return value
+
+    return _wrap_cal
+
+
class Element:
    def __init__(self, name, value: float, aggregate_method, backend: Backend):
        """
@@ -107,6 +119,7 @@ class Element:
        Perform fill_value on this element; it calls the fill_value method of the corresponding backend.

        """
+        self._check_value_initialized()
        self._value = self.backend.fill_value(self._value, value)

    def to(self, device):
@@ -138,163 +151,148 @@
            raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
                                        "element, or use it after it being used by the `Metric.update()` method.")

+    @_wrap_cal_value
    def __add__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value + other

+    @_wrap_cal_value
    def __radd__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            self.value += other.value
-        else:
-            self.value += other
-        return self
+            other = other.value
+        return self.value + other

+    @_wrap_cal_value
    def __sub__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            self.value -= other.value
-        else:
-            self.value -= other
-        return self
+            other = other.value
+        return self.value - other

+    @_wrap_cal_value
    def __rsub__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            self.value -= other.value
-        else:
-            self.value -= other
-        return self
+            other = other.value
+        return self.value - other

+    @_wrap_cal_value
    def __mul__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            self.value *= other.value
-        else:
-            self.value *= other
-        return self
+            other = other.value
+        return self.value * other

+    @_wrap_cal_value
    def __imul__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            self.value *= other.value
-        else:
-            self.value *= other
-        return self
+            other = other.value
+        return self.value * other

+    @_wrap_cal_value
    def __floordiv__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            self.value //= other.value
-        else:
-            self.value //= other
-        return self
+            other = other.value
+        return self.value // other

+    @_wrap_cal_value
    def __rfloordiv__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            self.value //= other.value
-        else:
-            self.value //= other
-        return self
+            other = other.value
+        return self.value // other

+    @_wrap_cal_value
    def __truediv__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            self.value /= other.value
-        else:
-            self.value /= other
-        return self
+            other = other.value
+        return self.value / other

+    @_wrap_cal_value
    def __rtruediv__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            self.value /= other.value
-        else:
-            self.value /= other
-        return self
+            other = other.value
+        return self.value / other

+    @_wrap_cal_value
    def __mod__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            self.value %= other.value
-        else:
-            self.value %= other
-        return self
+            other = other.value
+        return self.value % other

+    @_wrap_cal_value
    def __rmod__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            self.value /= other.value
-        else:
-            self.value /= other
-        return self
+            other = other.value
+        return self.value % other

+    @_wrap_cal_value
    def __pow__(self, other, modulo=None):
        self._check_value_when_call()
+        if isinstance(other, Element):
+            other = other.value
        if modulo is None:
-            if isinstance(other, Element):
-                self.value **= other.value
-            else:
-                self.value **= other
+            return self.value ** other
        else:
-            if isinstance(other, Element):
-                self.value = pow(self.value, other.value, modulo)
-            else:
-                self.value = pow(self.value, other, modulo)
-        return self
+            return pow(self.value, other, modulo)

+    @_wrap_cal_value
    def __rpow__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            self.value **= other.value
-        else:
-            self.value **= other
-        return self
+            other = other.value
+        return self.value ** other

+    @_wrap_cal_value
    def __lt__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
-            return self.value < other.value
-        else:
-            return self.value < other
+            other = other.value
+        return self.value < other

+    @_wrap_cal_value
    def __le__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
-            return self.value <= other.value
-        else:
-            return self.value <= other
+            other = other.value
+        return self.value <= other

+    @_wrap_cal_value
    def __eq__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
-            return self.value == other.value
-        else:
-            return self.value == other
+            other = other.value
+        return self.value == other

+    @_wrap_cal_value
    def __ne__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
-            return self.value != other.value
-        else:
-            return self.value != other
+            other = other.value
+        return self.value != other

+    @_wrap_cal_value
    def __ge__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
-            return self.value >= other.value
-        else:
-            return self.value >= other
+            other = other.value
+        return self.value >= other

+    @_wrap_cal_value
    def __gt__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
-            return self.value > other.value
-        else:
-            return self.value > other
+            other = other.value
+        return self.value > other

    def __str__(self):
        return str(self.value)


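The pattern in the hunk above is uniform: each operator now unwraps an Element operand to its raw value, computes a plain expression instead of mutating the element in place, and the new _wrap_cal_value decorator passes the result through the backend's get_scalar, so arithmetic on metric elements returns Python scalars rather than tensors. A minimal, self-contained sketch of the same idea, using a hypothetical _ToyBackend and ToyElement in place of the real fastNLP Backend and Element (these names are illustrative, not part of the commit):

import functools

class _ToyBackend:
    # hypothetical stand-in for fastNLP's Backend; get_scalar unwraps 0-d tensors
    def get_scalar(self, value):
        return value.item() if hasattr(value, "item") else value

def _wrap_cal_value(func):
    @functools.wraps(func)
    def _wrap_cal(*args, **kwargs):
        self = args[0]
        value = func(*args, **kwargs)           # run the original operator
        return self.backend.get_scalar(value)   # convert the result to a plain scalar
    return _wrap_cal

class ToyElement:
    def __init__(self, value):
        self.value = value
        self.backend = _ToyBackend()

    @_wrap_cal_value
    def __add__(self, other):
        if isinstance(other, ToyElement):
            other = other.value
        return self.value + other

print(ToyElement(3) + ToyElement(4))   # 7 -- a plain int, not an element or tensor
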
tests/core/metrics/test_element_cal_element.py (+35, -0)

@@ -0,0 +1,35 @@
+import pytest
+
+from fastNLP.core.metrics import Metric
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from .utils import find_free_network_port, setup_ddp
+if _NEED_IMPORT_TORCH:
+    import torch
+
+
+class MyMetric(Metric):
+    def __init__(self):
+        super(MyMetric, self).__init__()
+        self.register_element(name="t1", value=0)
+        self.register_element(name="t2", value=0)
+        self.register_element(name="t3", value=0)
+
+    def update(self, pred):
+        self.t1 = len(pred)
+        self.t2 = len(pred)
+        temp = self.t1 + self.t2
+        self.t3 = temp
+        self.t1 = self.t3 / self.t2
+
+    def get_metric(self) -> dict:
+        return {"t1": self.t1.get_scalar(), "t2": self.t2.get_scalar(), "t3": self.t3.get_scalar()}
+
+
+class TestElemnt:
+
+    def test_case_v1(self):
+        pred = torch.tensor([1, 1, 1, 1])
+        metric = MyMetric()
+        metric.update(pred)
+        res = metric.get_metric()
+        print(res)
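The new test exercises exactly the path changed above: temp = self.t1 + self.t2 goes through the decorated __add__, so temp (and hence t3) is a plain number rather than a tensor, and self.t1 = self.t3 / self.t2 likewise assigns a scalar. Assuming register_element stores the counts and get_scalar returns plain Python numbers, res for the four-element pred should come out numerically as t1 == 2.0, t2 == 4, t3 == 8; a possible tightening of the test (not part of this commit) would be an assertion appended at the end of test_case_v1:

        assert res["t1"] == 2.0 and res["t2"] == 4 and res["t3"] == 8   # hypothetical check; exact dtypes depend on the backend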
