|
|
@@ -3,6 +3,7 @@ __all__ = [ |
|
|
|
] |
|
|
|
|
|
|
|
import os |
|
|
|
import functools |
|
|
|
|
|
|
|
from .backend import Backend, AutoBackend |
|
|
|
from fastNLP.core.log import logger |
|
|
@@ -10,6 +11,17 @@ from .utils import AggregateMethodError |
|
|
|
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK |
|
|
|
|
|
|
|
|
|
|
|
def _wrap_cal_value(func): |
|
|
|
@functools.wraps(func) |
|
|
|
def _wrap_cal(*args, **kwargs): |
|
|
|
self = args[0] |
|
|
|
value = func(*args, **kwargs) |
|
|
|
value = self.backend.get_scalar(value) |
|
|
|
return value |
|
|
|
|
|
|
|
return _wrap_cal |
|
|
|
|
|
|
|
|
|
|
|
class Element: |
|
|
|
def __init__(self, name, value: float, aggregate_method, backend: Backend): |
|
|
|
""" |
|
|
@@ -107,6 +119,7 @@ class Element: |
|
|
|
对元素进行 fill_value, 会执行队友 backend 的 fill_value 方法 |
|
|
|
|
|
|
|
""" |
|
|
|
self._check_value_initialized() |
|
|
|
self._value = self.backend.fill_value(self._value, value) |
|
|
|
|
|
|
|
def to(self, device): |
|
|
@@ -138,163 +151,148 @@ class Element: |
|
|
|
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " |
|
|
|
"element, or use it after it being used by the `Metric.update()` method.") |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __add__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
other = other.value |
|
|
|
return self.value + other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __radd__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value += other.value |
|
|
|
else: |
|
|
|
self.value += other |
|
|
|
return self |
|
|
|
other = other.value |
|
|
|
return self.value + other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __sub__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value -= other.value |
|
|
|
else: |
|
|
|
self.value -= other |
|
|
|
return self |
|
|
|
other = other.value |
|
|
|
return self.value - other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __rsub__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value -= other.value |
|
|
|
else: |
|
|
|
self.value -= other |
|
|
|
return self |
|
|
|
other = other.value |
|
|
|
return self.value - other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __mul__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value *= other.value |
|
|
|
else: |
|
|
|
self.value *= other |
|
|
|
return self |
|
|
|
other = other.value |
|
|
|
return self.value * other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __imul__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value *= other.value |
|
|
|
else: |
|
|
|
self.value *= other |
|
|
|
return self |
|
|
|
other = other.value |
|
|
|
return self.value * other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __floordiv__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value //= other.value |
|
|
|
else: |
|
|
|
self.value //= other |
|
|
|
return self |
|
|
|
other = other.value |
|
|
|
return self.value // other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __rfloordiv__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value //= other.value |
|
|
|
else: |
|
|
|
self.value //= other |
|
|
|
return self |
|
|
|
other = other.value |
|
|
|
return self.value // other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __truediv__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value /= other.value |
|
|
|
else: |
|
|
|
self.value /= other |
|
|
|
return self |
|
|
|
other = other.value |
|
|
|
return self.value / other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __rtruediv__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value /= other.value |
|
|
|
else: |
|
|
|
self.value /= other |
|
|
|
return self |
|
|
|
other = other.value |
|
|
|
return self.value / other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __mod__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value %= other.value |
|
|
|
else: |
|
|
|
self.value %= other |
|
|
|
return self |
|
|
|
other = other.value |
|
|
|
return self.value % other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __rmod__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value /= other.value |
|
|
|
else: |
|
|
|
self.value /= other |
|
|
|
return self |
|
|
|
other = other.value |
|
|
|
return self.value % other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __pow__(self, other, modulo=None): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
other = other.value |
|
|
|
if modulo is None: |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value **= other.value |
|
|
|
else: |
|
|
|
self.value **= other |
|
|
|
return self.value ** other |
|
|
|
else: |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value = pow(self.value, other.value, modulo) |
|
|
|
else: |
|
|
|
self.value = pow(self.value, other, modulo) |
|
|
|
return self |
|
|
|
return pow(self.value, other, modulo) |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __rpow__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
self.value **= other.value |
|
|
|
else: |
|
|
|
self.value **= other |
|
|
|
return self |
|
|
|
other = other.value |
|
|
|
return self.value ** other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __lt__(self, other) -> bool: |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
return self.value < other.value |
|
|
|
else: |
|
|
|
return self.value < other |
|
|
|
other = other.value |
|
|
|
return self.value < other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __le__(self, other) -> bool: |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
return self.value <= other.value |
|
|
|
else: |
|
|
|
return self.value <= other |
|
|
|
other = other.value |
|
|
|
return self.value <= other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __eq__(self, other): |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
return self.value == other.value |
|
|
|
else: |
|
|
|
return self.value == other |
|
|
|
other = other.value |
|
|
|
return self.value == other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __ne__(self, other) -> bool: |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
return self.value != other.value |
|
|
|
else: |
|
|
|
return self.value != other |
|
|
|
other = other.value |
|
|
|
return self.value != other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __ge__(self, other) -> bool: |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
return self.value >= other.value |
|
|
|
else: |
|
|
|
return self.value >= other |
|
|
|
other = other.value |
|
|
|
return self.value >= other |
|
|
|
|
|
|
|
@_wrap_cal_value |
|
|
|
def __gt__(self, other) -> bool: |
|
|
|
self._check_value_when_call() |
|
|
|
if isinstance(other, Element): |
|
|
|
return self.value > other.value |
|
|
|
else: |
|
|
|
return self.value > other |
|
|
|
other = other.value |
|
|
|
return self.value > other |
|
|
|
|
|
|
|
def __str__(self): |
|
|
|
return str(self.value) |
|
|
|