|
|
@@ -99,8 +99,14 @@ class Metric: |
|
|
|
|
|
|
|
def __setattr__(self, key, value): |
|
|
|
if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: |
|
|
|
if key in self.elements and value is not self.elements[key]: |
|
|
|
raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}") |
|
|
|
if key in self.elements and isinstance(value, (float, int, bool)): |
|
|
|
self.elements[key].fill_value(value) |
|
|
|
return |
|
|
|
elif key in self.elements: |
|
|
|
raise TypeError(f"self.{key} is an Element, only float/int/bool type value can be assigned to it, " |
|
|
|
f"instead of {type(value)}.") |
|
|
|
if isinstance(value, Element) and key not in self.elements: |
|
|
|
raise RuntimeError("Please use register_element() function to add Element.") |
|
|
|
object.__setattr__(self, key, value) |
|
|
|
|
|
|
|
def _wrap_update(self, update): |
|
|
@@ -163,13 +169,6 @@ class Metric: |
|
|
|
""" |
|
|
|
self.aggregate_when_get_metric = flag |
|
|
|
|
|
|
|
def __getattr__(self, name: str) -> Element: |
|
|
|
if 'elements' in self.__dict__: |
|
|
|
elements = self.__dict__['elements'] |
|
|
|
if name in elements: |
|
|
|
return elements[name] |
|
|
|
raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name)) |
|
|
|
|
|
|
|
def tensor2numpy(self, tensor) -> np.array: |
|
|
|
""" |
|
|
|
将tensor向量转为numpy类型变量 |
|
|
|