@@ -41,6 +41,11 @@ def is_jittor_tensor(dtype): | |||||
def is_jittor_dtype_str(dtype): | def is_jittor_dtype_str(dtype): | ||||
""" | |||||
判断数据类型是否为 jittor 使用的字符串类型 | |||||
:param: dtype 数据类型 | |||||
""" | |||||
try: | try: | ||||
if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8', | if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8', | ||||
'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128', | 'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128', | ||||
@@ -53,6 +58,13 @@ def is_jittor_dtype_str(dtype): | |||||
def _get_dtype(ele_dtype, dtype, class_name): | def _get_dtype(ele_dtype, dtype, class_name): | ||||
""" | |||||
用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | |||||
:param ele_dtype 内部数据的类型 | |||||
:param dtype 数据外部类型 | |||||
:param class_name 类的名称 | |||||
""" | |||||
if not (ele_dtype is None or ( | if not (ele_dtype is None or ( | ||||
is_number_or_numpy_number(ele_dtype) or is_jittor_tensor(ele_dtype) or is_jittor_dtype_str(dtype))): | is_number_or_numpy_number(ele_dtype) or is_jittor_tensor(ele_dtype) or is_jittor_dtype_str(dtype))): | ||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
@@ -62,13 +74,7 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
if not (is_jittor_tensor(dtype) or is_number(dtype) or is_jittor_dtype_str(dtype)): | if not (is_jittor_tensor(dtype) or is_number(dtype) or is_jittor_dtype_str(dtype)): | ||||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " | raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " | ||||
f"or jittor.dtype but get `{dtype}`.") | f"or jittor.dtype but get `{dtype}`.") | ||||
# dtype = number_to_jittor_dtype_dict.get(dtype, dtype) | |||||
else: | else: | ||||
# if (is_number(ele_dtype) or is_jittor_tensor(ele_dtype)): | |||||
# # ele_dtype = number_to_jittor_dtype_dict.get(ele_dtype, ele_dtype) | |||||
# dtype = ele_dtype | |||||
# elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 | |||||
# dtype = numpy_to_jittor_dtype_dict.get(ele_dtype.type) | |||||
if is_numpy_generic_class(ele_dtype): | if is_numpy_generic_class(ele_dtype): | ||||
dtype = numpy_to_jittor_dtype_dict.get(ele_dtype) | dtype = numpy_to_jittor_dtype_dict.get(ele_dtype) | ||||
else: | else: | ||||
@@ -91,6 +97,11 @@ class JittorNumberPadder(Padder): | |||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | |||||
:param batch_field 输入的某个 field 的 batch 数据。 | |||||
:param pad_val 需要填充的值 | |||||
:dtype 数据的类型 | |||||
""" | |||||
return jittor.Var(np.array(batch_field, dtype=dtype)) | return jittor.Var(np.array(batch_field, dtype=dtype)) | ||||
@@ -108,6 +119,11 @@ class JittorSequencePadder(Padder): | |||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | |||||
:param batch_field 输入的某个 field 的 batch 数据。 | |||||
:param pad_val 需要填充的值 | |||||
:dtype 数据的类型 | |||||
""" | |||||
tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val) | tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val) | ||||
return tensor | return tensor | ||||
@@ -126,6 +142,13 @@ class JittorTensorPadder(Padder): | |||||
@staticmethod | @staticmethod | ||||
def pad(batch_field, pad_val=0, dtype=None): | def pad(batch_field, pad_val=0, dtype=None): | ||||
""" | |||||
将 batch_field 数据 转为 jittor.Var 并 pad 到相同长度。 | |||||
:param batch_field 输入的某个 field 的 batch 数据。 | |||||
:param pad_val 需要填充的值 | |||||
:dtype 数据的类型 | |||||
""" | |||||
try: | try: | ||||
if not isinstance(batch_field[0], jittor.Var): | if not isinstance(batch_field[0], jittor.Var): | ||||
batch_field = [jittor.Var(np.array(field.tolist(), dtype=dtype)) for field in batch_field] | batch_field = [jittor.Var(np.array(field.tolist(), dtype=dtype)) for field in batch_field] | ||||
@@ -139,9 +162,6 @@ class JittorTensorPadder(Padder): | |||||
else: | else: | ||||
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | ||||
# if dtype is not None: | |||||
# tensor = jittor.full(max_shape, pad_val, dtype=dtype) | |||||
# else: | |||||
tensor = jittor.full(max_shape, pad_val, dtype=dtype) | tensor = jittor.full(max_shape, pad_val, dtype=dtype) | ||||
for i, field in enumerate(batch_field): | for i, field in enumerate(batch_field): | ||||
slices = (i,) + tuple(slice(0, s) for s in shapes[i]) | slices = (i,) + tuple(slice(0, s) for s in shapes[i]) | ||||
@@ -15,6 +15,13 @@ from .exceptions import * | |||||
def _get_dtype(ele_dtype, dtype, class_name): | def _get_dtype(ele_dtype, dtype, class_name): | ||||
""" | |||||
用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | |||||
:param ele_dtype 内部数据的类型 | |||||
:param dtype 数据外部类型 | |||||
:param class_name 类的名称 | |||||
""" | |||||
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | ||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
f"or numpy numbers but get `{ele_dtype}`.") | f"or numpy numbers but get `{ele_dtype}`.") | ||||
@@ -36,6 +36,11 @@ from .exceptions import * | |||||
def is_paddle_tensor(dtype): | def is_paddle_tensor(dtype): | ||||
""" | |||||
判断 dtype 是否为 paddle 的 tensor | |||||
:param dtype 数据的 dtype 类型 | |||||
""" | |||||
if not isclass(dtype) and isinstance(dtype, paddle.dtype): | if not isclass(dtype) and isinstance(dtype, paddle.dtype): | ||||
return True | return True | ||||
@@ -43,6 +48,12 @@ def is_paddle_tensor(dtype): | |||||
def is_paddle_dtype_str(dtype): | def is_paddle_dtype_str(dtype): | ||||
""" | |||||
判断 dtype 是 str 类型 且属于 paddle 支持的 str 类型 | |||||
:param dtype 数据的 dtype 类型 | |||||
""" | |||||
try: | try: | ||||
if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8', | if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8', | ||||
'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128', | 'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128', | ||||
@@ -56,6 +67,13 @@ def is_paddle_dtype_str(dtype): | |||||
def _get_dtype(ele_dtype, dtype, class_name): | def _get_dtype(ele_dtype, dtype, class_name): | ||||
""" | |||||
用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | |||||
:param ele_dtype 内部数据的类型 | |||||
:param dtype 数据外部类型 | |||||
:param class_name 类的名称 | |||||
""" | |||||
if not (ele_dtype is None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)): | if not (ele_dtype is None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)): | ||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.") | f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.") | ||||
@@ -10,6 +10,13 @@ from .exceptions import * | |||||
def _get_dtype(ele_dtype, dtype, class_name): | def _get_dtype(ele_dtype, dtype, class_name): | ||||
""" | |||||
用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | |||||
:param ele_dtype 内部数据的类型 | |||||
:param dtype 数据外部类型 | |||||
:param class_name 类的名称 | |||||
""" | |||||
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | ||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
f"or numpy numbers but get `{ele_dtype}`.") | f"or numpy numbers but get `{ele_dtype}`.") | ||||
@@ -35,12 +35,24 @@ from .exceptions import * | |||||
def is_torch_tensor(dtype): | def is_torch_tensor(dtype): | ||||
""" | |||||
判断是否为 torch 的 tensor | |||||
:param dtype 数据的 dtype 类型 | |||||
""" | |||||
if not isclass(dtype) and isinstance(dtype, torch.dtype): | if not isclass(dtype) and isinstance(dtype, torch.dtype): | ||||
return True | return True | ||||
return False | return False | ||||
def _get_dtype(ele_dtype, dtype, class_name): | def _get_dtype(ele_dtype, dtype, class_name): | ||||
""" | |||||
用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | |||||
:param ele_dtype 内部数据的类型 | |||||
:param dtype 数据外部类型 | |||||
:param class_name 类的名称 | |||||
""" | |||||
if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): | if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): | ||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") | f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") | ||||
@@ -1,4 +1,150 @@ | |||||
r""" | r""" | ||||
:class:`~fastNLP.core.dataset.DataSet` 是 fastNLP 中用于承载数据的容器。可以将 DataSet 看做是一个表格, | |||||
每一行是一个 sample (在 fastNLP 中被称为 :mod:`~fastNLP.core.instance` ), | |||||
每一列是一个 feature (在 fastNLP 中称为 :mod:`~fastNLP.core.field` )。 | |||||
.. csv-table:: Following is a demo layout of DataSet | |||||
:header: "sentence", "words", "seq_len" | |||||
"This is the first instance .", "[This, is, the, first, instance, .]", 6 | |||||
"Second instance .", "[Second, instance, .]", 3 | |||||
"Third instance .", "[Third, instance, .]", 3 | |||||
"...", "[...]", "..." | |||||
在 fastNLP 内部每一行是一个 :class:`~fastNLP.Instance` 对象; 每一列是一个 :class:`~fastNLP.FieldArray` 对象。 | |||||
---------------------------- | |||||
1.DataSet的创建 | |||||
---------------------------- | |||||
创建DataSet主要有以下的3种方式 | |||||
1.1 传入dict | |||||
---------------------------- | |||||
.. code-block:: | |||||
from fastNLP import DataSet | |||||
data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."], | |||||
'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.'], | |||||
'seq_len': [6, 3, 3]} | |||||
dataset = DataSet(data) | |||||
# 传入的 dict 的每个 key 的 value 应该为具有相同长度的l ist | |||||
1.2 通过 Instance 构建 | |||||
---------------------------- | |||||
.. code-block:: | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
dataset = DataSet() | |||||
instance = Instance(sentence="This is the first instance", | |||||
words=['this', 'is', 'the', 'first', 'instance', '.'], | |||||
seq_len=6) | |||||
dataset.append(instance) | |||||
# 可以继续 append 更多内容,但是 append 的 instance 应该和第一个 instance 拥有完全相同的 field | |||||
1.3 通过 List[Instance] 构建 | |||||
-------------------------------------- | |||||
.. code-block:: | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
instances = [] | |||||
winstances.append(Instance(sentence="This is the first instance", | |||||
ords=['this', 'is', 'the', 'first', 'instance', '.'], | |||||
seq_len=6)) | |||||
instances.append(Instance(sentence="Second instance .", | |||||
words=['Second', 'instance', '.'], | |||||
seq_len=3)) | |||||
dataset = DataSet(instances) | |||||
-------------------------------------- | |||||
2.DataSet 与预处理 | |||||
-------------------------------------- | |||||
常见的预处理有如下几种 | |||||
2.1 从某个文本文件读取内容 | |||||
-------------------------------------- | |||||
.. code-block:: | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
dataset = DataSet() | |||||
filepath = 'some/text/file' | |||||
# 假设文件中每行内容如下(sentence label): | |||||
# This is a fantastic day positive | |||||
# The bad weather negative | |||||
# ..... | |||||
with open(filepath, 'r') as f: | |||||
for line in f: | |||||
sent, label = line.strip().split('\t') | |||||
dataset.append(Instance(sentence=sent, label=label)) | |||||
2.2 对 DataSet 中的内容处理 | |||||
-------------------------------------- | |||||
.. code-block:: | |||||
from fastNLP import DataSet | |||||
data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."]} | |||||
dataset = DataSet(data) | |||||
# 将句子分成单词形式, 详见DataSet.apply()方法, 可以开启多进程来加快处理, 也可以更改展示的bar,目前支持 ``['rich', 'tqdm', None]``, | |||||
# 详细内容可以见 :class: `~fastNLP.core.dataset.DataSet`, 需要注意的时匿名函数不支持多进程 | |||||
dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words', | |||||
progress_des='Main',progress_bar='rich') | |||||
# 或使用DataSet.apply_field() | |||||
dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words', | |||||
progress_des='Main',progress_bar='rich') | |||||
# 除了匿名函数,也可以定义函数传递进去 | |||||
def get_words(instance): | |||||
sentence = instance['sentence'] | |||||
words = sentence.split() | |||||
return words | |||||
dataset.apply(get_words, new_field_name='words', num_proc=2, progress_des='Main',progress_bar='rich') | |||||
2.3 删除DataSet的内容 | |||||
-------------------------------------- | |||||
.. code-block:: | |||||
from fastNLP import DataSet | |||||
dataset = DataSet({'a': list(range(-5, 5))}) | |||||
# 返回满足条件的 instance,并放入 DataSet 中 | |||||
dropped_dataset = dataset.drop(lambda ins:ins['a']<0, inplace=False) | |||||
# 在 dataset 中删除满足条件的i nstance | |||||
dataset.drop(lambda ins:ins['a']<0) # dataset 的 instance数量减少 | |||||
# 删除第 3 个 instance | |||||
dataset.delete_instance(2) | |||||
# 删除名为 'a' 的 field | |||||
dataset.delete_field('a') | |||||
2.4 遍历DataSet的内容 | |||||
-------------------------------------- | |||||
.. code-block:: | |||||
for instance in dataset: | |||||
# do something | |||||
2.5 一些其它操作 | |||||
-------------------------------------- | |||||
.. code-block:: | |||||
# 检查是否存在名为 'a' 的 field | |||||
dataset.has_field('a') # 或 ('a' in dataset) | |||||
# 将名为 'a' 的 field 改名为 'b' | |||||
dataset.rename_field('a', 'b') | |||||
# DataSet 的长度 | |||||
len(dataset) | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
@@ -42,9 +188,9 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, p | |||||
""" | """ | ||||
对数据集进行处理封装函数,以便多进程使用 | 对数据集进行处理封装函数,以便多进程使用 | ||||
:param ds: 数据集 | |||||
:param _apply_field: 需要处理数据集的field_name | |||||
:param func: 用户自定义的func | |||||
:param ds: 实现了 __getitem__() 和 __len__() 的对象 | |||||
:param _apply_field: 需要处理数据集的 field_name | |||||
:param func: 用户自定义的 func | |||||
:param desc: 进度条的描述字符 | :param desc: 进度条的描述字符 | ||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | ||||
:return: | :return: | ||||
@@ -76,9 +222,9 @@ def _multi_proc(ds, _apply_field, func, counter, queue): | |||||
""" | """ | ||||
对数据集进行处理封装函数,以便多进程使用 | 对数据集进行处理封装函数,以便多进程使用 | ||||
:param ds: 数据集 | |||||
:param _apply_field: 需要处理数据集的field_name | |||||
:param func: 用户自定义的func | |||||
:param ds: 实现了 __getitem__() 和 __len__() 的对象 | |||||
:param _apply_field: 需要处理数据集的 field_name | |||||
:param func: 用户自定义的 func | |||||
:param counter: 计数器 | :param counter: 计数器 | ||||
:param queue: 多进程时,将结果输入到这个 queue 中 | :param queue: 多进程时,将结果输入到这个 queue 中 | ||||
:return: | :return: | ||||
@@ -111,9 +257,28 @@ class DataSet: | |||||
def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None): | def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None): | ||||
r""" | r""" | ||||
初始化 ``DataSet``, fastNLP的 DataSet 是 key-value 存储形式, 目前支持两种初始化方式,输入 data 分别为 ``List[:class: `~fastNLP.core.dataset.Instance`]`` 和 | |||||
``Dict[str, List[Any]]``。 | |||||
* 当 data 为 ``List[:class: `~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 | |||||
Instance 详见 :class: `~fastNLP.core.dataset.Instance` 。 | |||||
* 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 | |||||
:param data: 初始化的内容, 其只能为两种类型,分别为 ``List[:class: `~fastNLP.core.dataset.Instance`]`` 和 | |||||
``Dict[str, List[Any]]``。 | |||||
* 当 data 为 ``List[:class: `~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 | |||||
Instance 详见 :class: `~fastNLP.core.dataset.Instance` 。 | |||||
* 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 | |||||
Example:: | |||||
from fastNLP.core.dataset import DataSet, Instance | |||||
data = {'x': [[1, 0, 1], [0, 1, 1], 'y': [0, 1]} | |||||
data1 = [Instance(x=[1,0,1],y=0), Instance(x=[0,1,1],y=1)] | |||||
ds = DataSet(data) | |||||
ds = DataSet(data1) | |||||
:param data: 如果为dict类型,则每个key的value应该为等长的list; 如果为list, | |||||
每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 | |||||
""" | """ | ||||
self.field_arrays = {} | self.field_arrays = {} | ||||
self._collator = Collator() | self._collator = Collator() | ||||
@@ -168,11 +333,27 @@ class DataSet: | |||||
return inner_iter_func() | return inner_iter_func() | ||||
def __getitem__(self, idx: Union[int, slice, str, list]): | def __getitem__(self, idx: Union[int, slice, str, list]): | ||||
r"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 | |||||
r""" | |||||
去 DataSet 的内容, 根据 idx 类型不同有不同的返回值。 包括四种类型 ``[int, slice, str, list]`` | |||||
:param idx: can be int or slice. | |||||
:return: If `idx` is int, return an Instance object. | |||||
If `idx` is slice, return a DataSet object. | |||||
* 当 idx 为 ``int`` 时, idx 的值不能超过 ``DataSet`` 的长度, 会返回一个 ``Instance``, 详见 | |||||
:class: `~fastNLP.core.dataset.Instance` | |||||
* 当 idx 为 ``slice`` 时, 会根据 slice 的内容创建一个新的 DataSet,其包含 slice 所有内容并返回。 | |||||
* 当 idx 为 ``str`` 时, 该 idx 为 DataSet 的 field_name, 其会返回该 field_name 的所有内容, 为 list 类型。 | |||||
* 当 idx 为 ``list`` 时, 该 idx 的 list 内全为 int 数字, 其会取出所有内容组成一个新的 DataSet 返回。 | |||||
Example:: | |||||
from fastNLP.core.dataset import DataSet | |||||
ds = DataSet({'x': [[1, 0, 1], [0, 1, 1] * 100, 'y': [0, 1] * 100}) | |||||
ins = ds[0] | |||||
sub_ds = ds[0:100] | |||||
sub_ds= ds[[1, 0, 3, 2, 1, 4]] | |||||
field = ds['x'] | |||||
:param idx: 用户传入参数 | |||||
:return: | |||||
""" | """ | ||||
if isinstance(idx, int): | if isinstance(idx, int): | ||||
return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays}) | return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays}) | ||||
@@ -230,9 +411,10 @@ class DataSet: | |||||
return self.__dict__ | return self.__dict__ | ||||
def __len__(self): | def __len__(self): | ||||
r"""Fetch the length of the dataset. | |||||
r""" | |||||
获取 DataSet 的长度 | |||||
:return length: | |||||
:return | |||||
""" | """ | ||||
if len(self.field_arrays) == 0: | if len(self.field_arrays) == 0: | ||||
return 0 | return 0 | ||||
@@ -244,9 +426,9 @@ class DataSet: | |||||
def append(self, instance: Instance) -> None: | def append(self, instance: Instance) -> None: | ||||
r""" | r""" | ||||
将一个instance对象append到DataSet后面。 | |||||
将一个 instance 对象 append 到 DataSet 后面。详见 :class: `~fastNLP.Instance` | |||||
:param ~fastNLP.Instance instance: 若DataSet不为空,则instance应该拥有和DataSet完全一样的field。 | |||||
:param instance: 若 DataSet 不为空,则 instance 应该拥有和 DataSet 完全一样的 field。 | |||||
""" | """ | ||||
if len(self.field_arrays) == 0: | if len(self.field_arrays) == 0: | ||||
@@ -269,10 +451,10 @@ class DataSet: | |||||
def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None: | def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None: | ||||
r""" | r""" | ||||
将fieldarray添加到DataSet中. | |||||
将 fieldarray 添加到 DataSet 中. | |||||
:param str field_name: 新加入的field的名称 | |||||
:param ~fastNLP.core.FieldArray fieldarray: 需要加入DataSet的field的内容 | |||||
:param field_name: 新加入的 field 的名称 | |||||
:param fieldarray: 需要加入 DataSet 的 field 的内容, 详见 :class: `~fastNLP.core.dataset.FieldArray` | |||||
:return: | :return: | ||||
""" | """ | ||||
if not isinstance(fieldarray, FieldArray): | if not isinstance(fieldarray, FieldArray): | ||||
@@ -285,10 +467,10 @@ class DataSet: | |||||
def add_field(self, field_name: str, fields: list) -> None: | def add_field(self, field_name: str, fields: list) -> None: | ||||
r""" | r""" | ||||
新增一个field, 需要注意的是fields的长度跟dataset长度一致 | |||||
新增一个 field, 需要注意的是 fields 的长度跟 DataSet 长度一致 | |||||
:param str field_name: 新增的field的名称 | |||||
:param list fields: 需要新增的field的内容 | |||||
:param field_name: 新增的 field 的名称 | |||||
:param fields: 需要新增的 field 的内容 | |||||
""" | """ | ||||
if len(self.field_arrays) != 0: | if len(self.field_arrays) != 0: | ||||
@@ -299,9 +481,9 @@ class DataSet: | |||||
def delete_instance(self, index: int): | def delete_instance(self, index: int): | ||||
r""" | r""" | ||||
删除第index个instance | |||||
删除第 ``index `` 个 Instance | |||||
:param int index: 需要删除的instance的index,序号从0开始。 | |||||
:param index: 需要删除的 instanc e的 index,序号从 `0` 开始。 | |||||
""" | """ | ||||
assert isinstance(index, int), "Only integer supported." | assert isinstance(index, int), "Only integer supported." | ||||
if len(self) <= index: | if len(self) <= index: | ||||
@@ -315,9 +497,9 @@ class DataSet: | |||||
def delete_field(self, field_name: str): | def delete_field(self, field_name: str): | ||||
r""" | r""" | ||||
删除名为field_name的field | |||||
删除名为 field_name 的 field | |||||
:param str field_name: 需要删除的field的名称. | |||||
:param field_name: 需要删除的 field 的名称. | |||||
""" | """ | ||||
if self.has_field(field_name): | if self.has_field(field_name): | ||||
self.field_arrays.pop(field_name) | self.field_arrays.pop(field_name) | ||||
@@ -327,10 +509,10 @@ class DataSet: | |||||
def copy_field(self, field_name: str, new_field_name: str): | def copy_field(self, field_name: str, new_field_name: str): | ||||
r""" | r""" | ||||
深度copy名为field_name的field到new_field_name | |||||
深度 copy 名为 field_name 的 field 到 new_field_name | |||||
:param str field_name: 需要copy的field。 | |||||
:param str new_field_name: copy生成的field名称 | |||||
:param field_name: 需要 copy 的 field。 | |||||
:param new_field_name: copy 生成的 field 名称 | |||||
:return: self | :return: self | ||||
""" | """ | ||||
if not self.has_field(field_name): | if not self.has_field(field_name): | ||||
@@ -342,10 +524,10 @@ class DataSet: | |||||
def has_field(self, field_name: str) -> bool: | def has_field(self, field_name: str) -> bool: | ||||
r""" | r""" | ||||
判断DataSet中是否有名为field_name这个field | |||||
判断 DataSet 中是否有名为 field_name 这个 field | |||||
:param str field_name: field的名称 | |||||
:return bool: 表示是否有名为field_name这个field | |||||
:param field_name: field 的名称 | |||||
:return: 表示是否有名为 field_name 这个 field | |||||
""" | """ | ||||
if isinstance(field_name, str): | if isinstance(field_name, str): | ||||
return field_name in self.field_arrays | return field_name in self.field_arrays | ||||
@@ -353,9 +535,9 @@ class DataSet: | |||||
def get_field(self, field_name: str) -> FieldArray: | def get_field(self, field_name: str) -> FieldArray: | ||||
r""" | r""" | ||||
获取field_name这个field | |||||
获取 field_name 这个 field | |||||
:param str field_name: field的名称 | |||||
:param field_name: field 的名称 | |||||
:return: :class:`~fastNLP.FieldArray` | :return: :class:`~fastNLP.FieldArray` | ||||
""" | """ | ||||
if field_name not in self.field_arrays: | if field_name not in self.field_arrays: | ||||
@@ -364,34 +546,34 @@ class DataSet: | |||||
def get_all_fields(self) -> dict: | def get_all_fields(self) -> dict: | ||||
r""" | r""" | ||||
返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray` | |||||
返回一个 dict,key 为 field_name, value为对应的 :class:`~fastNLP.FieldArray` | |||||
:return dict: 返回如上所述的字典 | |||||
:return: 返回如上所述的字典 | |||||
""" | """ | ||||
return self.field_arrays | return self.field_arrays | ||||
def get_field_names(self) -> list: | def get_field_names(self) -> list: | ||||
r""" | r""" | ||||
返回一个list,包含所有 field 的名字 | |||||
返回一个 list,包含所有 field 的名字 | |||||
:return list: 返回如上所述的列表 | |||||
:return: 返回如上所述的列表 | |||||
""" | """ | ||||
return sorted(self.field_arrays.keys()) | return sorted(self.field_arrays.keys()) | ||||
def get_length(self) -> int: | def get_length(self) -> int: | ||||
r""" | r""" | ||||
获取DataSet的元素数量 | |||||
获取 DataSet 的元素数量 | |||||
:return: int: DataSet中Instance的个数。 | |||||
:return: DataSet 中 Instance 的个数。 | |||||
""" | """ | ||||
return len(self) | return len(self) | ||||
def rename_field(self, field_name: str, new_field_name: str): | def rename_field(self, field_name: str, new_field_name: str): | ||||
r""" | r""" | ||||
将某个field重新命名. | |||||
将某个 field 重新命名. | |||||
:param str field_name: 原来的field名称。 | |||||
:param str new_field_name: 修改为new_name。 | |||||
:param field_name: 原来的 field 名称。 | |||||
:param new_field_name: 修改为 new_name。 | |||||
""" | """ | ||||
if field_name in self.field_arrays: | if field_name in self.field_arrays: | ||||
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) | self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) | ||||
@@ -627,10 +809,10 @@ class DataSet: | |||||
def add_seq_len(self, field_name: str, new_field_name='seq_len'): | def add_seq_len(self, field_name: str, new_field_name='seq_len'): | ||||
r""" | r""" | ||||
将使用len()直接对field_name中每个元素作用,将其结果作为sequence length, 并放入seq_len这个field。 | |||||
将使用 len() 直接对 field_name 中每个元素作用,将其结果作为 sequence length, 并放入 seq_len 这个 field。 | |||||
:param field_name: str. | |||||
:param new_field_name: str. 新的field_name | |||||
:param field_name: 需要处理的 field_name | |||||
:param new_field_name: str. 新的 field_name | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.has_field(field_name=field_name): | if self.has_field(field_name=field_name): | ||||
@@ -641,10 +823,11 @@ class DataSet: | |||||
def drop(self, func: Callable, inplace=True): | def drop(self, func: Callable, inplace=True): | ||||
r""" | r""" | ||||
func接受一个Instance,返回bool值。返回值为True时,该Instance会被移除或者不会包含在返回的DataSet中。 | |||||
删除某些 Instance。 需要注意的时func 接受一个 Instance ,返回 bool 值。返回值为 True 时, | |||||
该 Instance 会被移除或者不会包含在返回的 DataSet 中。 | |||||
:param callable func: 接受一个Instance作为参数,返回bool值。为True时删除该instance | |||||
:param bool inplace: 是否在当前DataSet中直接删除instance;如果为False,将返回一个新的DataSet。 | |||||
:param func: 接受一个 Instance 作为参数,返回 bool 值。为 True 时删除该 instance | |||||
:param inplace: 是否在当前 DataSet 中直接删除 instance;如果为 False,将返回一个新的 DataSet。 | |||||
:return: DataSet | :return: DataSet | ||||
""" | """ | ||||
@@ -663,10 +846,10 @@ class DataSet: | |||||
def split(self, ratio: float, shuffle=True): | def split(self, ratio: float, shuffle=True): | ||||
r""" | r""" | ||||
将DataSet按照ratio的比例拆分,返回两个DataSet | |||||
将 DataSet 按照 ratio 的比例拆分,返回两个 DataSet | |||||
:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `ratio` 这么多数据,第二个DataSet拥有`(1-ratio)`这么多数据 | |||||
:param bool shuffle: 在split前是否shuffle一下。为False,返回的第一个dataset就是当前dataset中前`ratio`比例的数据, | |||||
:param ratio: 0<ratio<1, 返回的第一个 DataSet 拥有 `ratio` 这么多数据,第二个 DataSet 拥有 `(1-ratio)` 这么多数据 | |||||
:param shuffle: 在 split 前是否 shuffle 一下。为 False,返回的第一个 dataset 就是当前 dataset 中前 `ratio` 比例的数据, | |||||
:return: [ :class:`~fastNLP.读取后的DataSet` , :class:`~fastNLP.读取后的DataSet` ] | :return: [ :class:`~fastNLP.读取后的DataSet` , :class:`~fastNLP.读取后的DataSet` ] | ||||
""" | """ | ||||
assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.' | assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.' | ||||
@@ -696,7 +879,7 @@ class DataSet: | |||||
r""" | r""" | ||||
保存DataSet. | 保存DataSet. | ||||
:param str path: 将DataSet存在哪个路径 | |||||
:param path: 将DataSet存在哪个路径 | |||||
""" | """ | ||||
with open(path, 'wb') as f: | with open(path, 'wb') as f: | ||||
pickle.dump(self, f) | pickle.dump(self, f) | ||||
@@ -704,9 +887,9 @@ class DataSet: | |||||
@staticmethod | @staticmethod | ||||
def load(path: str): | def load(path: str): | ||||
r""" | r""" | ||||
从保存的DataSet pickle文件的路径中读取DataSet | |||||
从保存的 DataSet pickle文件的路径中读取DataSet | |||||
:param str path: 从哪里读取DataSet | |||||
:param path: 从哪里读取 DataSet | |||||
:return: 读取后的 :class:`~fastNLP.读取后的DataSet`。 | :return: 读取后的 :class:`~fastNLP.读取后的DataSet`。 | ||||
""" | """ | ||||
with open(path, 'rb') as f: | with open(path, 'rb') as f: | ||||
@@ -716,16 +899,16 @@ class DataSet: | |||||
def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': | def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': | ||||
""" | """ | ||||
将当前dataset与输入的dataset结合成一个更大的dataset,需要保证两个dataset都包含了相同的field。结合后的dataset的input,target | |||||
以及collate_fn以当前dataset为准。当dataset中包含的field多于当前的dataset,则多余的field会被忽略;若dataset中未包含所有 | |||||
当前dataset含有field,则会报错。 | |||||
将当前 dataset 与输入的 dataset 结合成一个更大的 dataset,需要保证两个 dataset 都包含了相同的 field。结合后的 dataset | |||||
的 field_name 和 _collator 以当前 dataset 为准。当 dataset 中包含的 field 多于当前的 dataset,则多余的 field 会被忽略; | |||||
若 dataset 中未包含所有当前 dataset 含有 field,则会报错。 | |||||
:param DataSet, dataset: 需要和当前dataset concat的dataset | |||||
:param bool, inplace: 是否直接将dataset组合到当前dataset中 | |||||
:param dict, field_mapping: 当传入的dataset中的field名称和当前dataset不一致时,需要通过field_mapping把输入的dataset中的 | |||||
field名称映射到当前field. field_mapping为dict类型,key为dataset中的field名称,value是需要映射成的名称 | |||||
:param dataset: 需要和当前 dataset concat的 dataset | |||||
:param inplace: 是否直接将 dataset 组合到当前 dataset 中 | |||||
:param field_mapping: 当传入的 dataset 中的 field 名称和当前 dataset 不一致时,需要通过 field_mapping 把输入的 dataset 中的 | |||||
field 名称映射到当前 field. field_mapping 为 dict 类型,key 为 dataset 中的 field 名称,value 是需要映射成的名称 | |||||
:return: DataSet | |||||
:return: :class: `~fastNLP.core.dataset.DataSet`` | |||||
""" | """ | ||||
assert isinstance(dataset, DataSet), "Can only concat two datasets." | assert isinstance(dataset, DataSet), "Can only concat two datasets." | ||||
@@ -754,8 +937,8 @@ class DataSet: | |||||
@classmethod | @classmethod | ||||
def from_pandas(cls, df): | def from_pandas(cls, df): | ||||
""" | """ | ||||
从pandas.DataFrame中读取数据转为Dataset | |||||
:param df: | |||||
从 ``pandas.DataFrame`` 中读取数据转为 DataSet | |||||
:param df: 使用 pandas 读取的数据 | |||||
:return: | :return: | ||||
""" | """ | ||||
df_dict = df.to_dict(orient='list') | df_dict = df.to_dict(orient='list') | ||||
@@ -763,7 +946,7 @@ class DataSet: | |||||
def to_pandas(self): | def to_pandas(self): | ||||
""" | """ | ||||
将dataset转为pandas.DataFrame类型的数据 | |||||
将 DataSet 数据转为 ``pandas.DataFrame`` 类型的数据 | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -773,9 +956,9 @@ class DataSet: | |||||
def to_csv(self, path: str): | def to_csv(self, path: str): | ||||
""" | """ | ||||
将dataset保存为csv文件 | |||||
将 DataSet 保存为 csv 文件 | |||||
:param path: | |||||
:param path: 保存到路径 | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -16,6 +16,13 @@ import numpy as np | |||||
class FieldArray: | class FieldArray: | ||||
def __init__(self, name: str, content): | def __init__(self, name: str, content): | ||||
""" | |||||
初始化 FieldArray | |||||
:param name: 字符串的名称 | |||||
:param content: 任意类型的数据 | |||||
""" | |||||
if len(content) == 0: | if len(content) == 0: | ||||
raise RuntimeError("Empty fieldarray is not allowed.") | raise RuntimeError("Empty fieldarray is not allowed.") | ||||
_content = content | _content = content | ||||
@@ -29,15 +36,17 @@ class FieldArray: | |||||
def append(self, val: Any) -> None: | def append(self, val: Any) -> None: | ||||
r""" | r""" | ||||
:param val: 把该val append到fieldarray。 | |||||
:param val: 把该 val append 到 fieldarray。 | |||||
:return: | :return: | ||||
""" | """ | ||||
self.content.append(val) | self.content.append(val) | ||||
def pop(self, index: int) -> None: | def pop(self, index: int) -> None: | ||||
r""" | r""" | ||||
删除该field中index处的元素 | |||||
:param int index: 从0开始的数据下标。 | |||||
删除该 field 中 index 处的元素 | |||||
:param index: 从 ``0`` 开始的数据下标。 | |||||
:return: | :return: | ||||
""" | """ | ||||
self.content.pop(index) | self.content.pop(index) | ||||
@@ -51,10 +60,10 @@ class FieldArray: | |||||
def get(self, indices: Union[int, List[int]]): | def get(self, indices: Union[int, List[int]]): | ||||
r""" | r""" | ||||
根据给定的indices返回内容。 | |||||
根据给定的 indices 返回内容。 | |||||
:param int,List[int] indices: 获取indices对应的内容。 | |||||
:return: 根据给定的indices返回的内容,可能是单个值或ndarray | |||||
:param indices: 获取 indices 对应的内容。 | |||||
:return: 根据给定的 indices 返回的内容,可能是单个值 或 ``ndarray`` | |||||
""" | """ | ||||
if isinstance(indices, int): | if isinstance(indices, int): | ||||
if indices == -1: | if indices == -1: | ||||
@@ -69,18 +78,18 @@ class FieldArray: | |||||
def __len__(self): | def __len__(self): | ||||
r""" | r""" | ||||
Returns the size of FieldArray. | |||||
返回长度 | |||||
:return int length: | |||||
:return length: | |||||
""" | """ | ||||
return len(self.content) | return len(self.content) | ||||
def split(self, sep: str = None, inplace: bool = True): | def split(self, sep: str = None, inplace: bool = True): | ||||
r""" | r""" | ||||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。 | |||||
依次对自身的元素使用 ``.split()`` 方法,应该只有当本 field 的元素为 ``str`` 时,该方法才有用。 | |||||
:param sep: 分割符,如果为None则直接调用str.split()。 | |||||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||||
:param sep: 分割符,如果为 ``None`` 则直接调用 ``str.split()``。 | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 | |||||
:return: List[List[str]] or self | :return: List[List[str]] or self | ||||
""" | """ | ||||
new_contents = [] | new_contents = [] | ||||
@@ -94,10 +103,11 @@ class FieldArray: | |||||
def int(self, inplace: bool = True): | def int(self, inplace: bool = True): | ||||
r""" | r""" | ||||
将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||||
将本 field 中的值调用 ``int(cell)``. 支持 field 中内容为以下两种情况: | |||||
* ['1', '2', ...](即 field 中每个值为 ``str`` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list`` ,``list`` 中的值会被依次转换。) | |||||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 | |||||
:return: List[int], List[List[int]], self | :return: List[int], List[List[int]], self | ||||
""" | """ | ||||
new_contents = [] | new_contents = [] | ||||
@@ -114,10 +124,12 @@ class FieldArray: | |||||
def float(self, inplace=True): | def float(self, inplace=True): | ||||
r""" | r""" | ||||
将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||||
将本 field 中的值调用 ``float(cell)``. 支持 field 中内容为以下两种情况: | |||||
* ['1', '2', ...](即 field 中每个值为 ``str`` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list`` 中的值会被依次转换。) | |||||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 ``field``。否则返回 ``list``。 | |||||
:return: | :return: | ||||
""" | """ | ||||
new_contents = [] | new_contents = [] | ||||
@@ -134,10 +146,12 @@ class FieldArray: | |||||
def bool(self, inplace=True): | def bool(self, inplace=True): | ||||
r""" | r""" | ||||
将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||||
将本field中的值调用 ``bool(cell)``. 支持 field 中内容为以下两种情况 | |||||
* ['1', '2', ...](即 field 中每个值为 ``str`` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list`` 中的值会被依次转换。) | |||||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 ``field``。否则返回 ``list``。 | |||||
:return: | :return: | ||||
""" | """ | ||||
new_contents = [] | new_contents = [] | ||||
@@ -155,10 +169,12 @@ class FieldArray: | |||||
def lower(self, inplace=True): | def lower(self, inplace=True): | ||||
r""" | r""" | ||||
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||||
将本 field 中的值调用 ``cell.lower()``. 支持 field 中内容为以下两种情况 | |||||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||||
* ['1', '2', ...](即 ``field`` 中每个值为 ``str`` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list``中的值会被依次转换。) | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 | |||||
:return: List[int], List[List[int]], self | :return: List[int], List[List[int]], self | ||||
""" | """ | ||||
new_contents = [] | new_contents = [] | ||||
@@ -175,10 +191,12 @@ class FieldArray: | |||||
def upper(self, inplace=True): | def upper(self, inplace=True): | ||||
r""" | r""" | ||||
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||||
将本 field 中的值调用 ``cell.lower()``. 支持 field 中内容为以下两种情况 | |||||
* ['1', '2', ...](即 field 中每个值为 ``str`` 的), | |||||
* [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 ``list``,``list`` 中的值会被依次转换。) | |||||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||||
:param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 ``list``。 | |||||
:return: List[int], List[List[int]], self | :return: List[int], List[List[int]], self | ||||
""" | """ | ||||
new_contents = [] | new_contents = [] | ||||
@@ -195,9 +213,9 @@ class FieldArray: | |||||
def value_count(self): | def value_count(self): | ||||
r""" | r""" | ||||
返回该field下不同value的数量。多用于统计label数量 | |||||
返回该 field 下不同 value的 数量。多用于统计 label 数量 | |||||
:return: Counter, key是label,value是出现次数 | |||||
:return: Counter, key 是 label,value 是出现次数 | |||||
""" | """ | ||||
count = Counter() | count = Counter() | ||||
@@ -214,7 +232,7 @@ class FieldArray: | |||||
def _after_process(self, new_contents: list, inplace: bool): | def _after_process(self, new_contents: list, inplace: bool): | ||||
r""" | r""" | ||||
当调用处理函数之后,决定是否要替换field。 | |||||
当调用处理函数之后,决定是否要替换 field。 | |||||
:param new_contents: | :param new_contents: | ||||
:param inplace: | :param inplace: | ||||
@@ -1,5 +1,5 @@ | |||||
r""" | r""" | ||||
instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可以认为是一个Instance类型的对象。 | |||||
instance 模块实现了 Instance 类在 fastNLP 中对应 sample。一个 sample 可以认为是一个 Instance 类型的对象。 | |||||
便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 。 | 便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 。 | ||||
""" | """ | ||||
@@ -27,16 +27,16 @@ class Instance(Mapping): | |||||
def add_field(self, field_name: str, field: any): | def add_field(self, field_name: str, field: any): | ||||
r""" | r""" | ||||
向Instance中增加一个field | |||||
向 Instance 中增加一个 field | |||||
:param str field_name: 新增field的名称 | |||||
:param Any field: 新增field的内容 | |||||
:param field_name: 新增 field 的名称 | |||||
:param field: 新增 field 的内容 | |||||
""" | """ | ||||
self.fields[field_name] = field | self.fields[field_name] = field | ||||
def items(self): | def items(self): | ||||
r""" | r""" | ||||
返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value | |||||
返回一个迭代器,迭代器返回两个内容,第一个内容是 field_name, 第二个内容是 field_value | |||||
:return: 一个迭代器 | :return: 一个迭代器 | ||||
""" | """ | ||||
@@ -12,18 +12,35 @@ from .jittor_backend.backend import JittorBackend | |||||
class AutoBackend(Backend): | class AutoBackend(Backend): | ||||
""" | """ | ||||
不需要初始化backend的AutoBackend,能够根据get_metric时候判断输入数据类型来选择backend是什么类型的 | |||||
不需要初始化 backend 的 AutoBackend,能够根据 get_metric 时候判断输入数据类型来选择 backend 是什么类型的 | |||||
""" | """ | ||||
def __init__(self, backend: Union[str, Backend, None]): | def __init__(self, backend: Union[str, Backend, None]): | ||||
""" | |||||
初始化 backend. | |||||
:param backend: 目前支持三种值,为 ``[str, Backend, None]``。 | |||||
* 当 backend 为 `str` 时, 其只能为 'auto' | |||||
* 当 backend 为 ``Backend`` 对象时, 其直接使用该对象方法覆盖 AutoBackend | |||||
* 当 backend 为 ``None`` 时, 根据 get_metric 时候判断输入数据类型来选择 backend 是什么类型的 | |||||
""" | |||||
super(AutoBackend, self).__init__() | super(AutoBackend, self).__init__() | ||||
if backend != 'auto': | if backend != 'auto': | ||||
self._convert_backend(backend) | self._convert_backend(backend) | ||||
def _convert_backend(self, backend): | def _convert_backend(self, backend): | ||||
""" | """ | ||||
将AutoBackend转换为合适的Backend对象 | |||||
将 AutoBackend 转换为合适的 Backend 对象 | |||||
:param backend: 传入的 backend 值。 | |||||
* 当 backend 为 `torch` 时, 选择 :class: `~fastNLP.core.metric.TorchBackend` | |||||
* 当 backend 为 `paddle` 时, 选择 :class: `~fastNLP.core.metric.PaddleBackend` | |||||
* 当 backend 为 `jittor` 时, 选择 :class: `~fastNLP.core.metric.JittorBackend` | |||||
* 当 backend 为 ``None`` 时, 直接初始化 | |||||
""" | """ | ||||
if isinstance(backend, Backend): | if isinstance(backend, Backend): | ||||
@@ -43,6 +60,12 @@ class AutoBackend(Backend): | |||||
self._specified = True | self._specified = True | ||||
def choose_real_backend(self, args): | def choose_real_backend(self, args): | ||||
""" | |||||
根据 args 参数类型来选择需要真正初始化的 backend | |||||
:param args: args 参数, 可能为 ``jittor``, ``torch``, ``paddle``, ``numpy`` 类型, 能够检测并选择真正的 backend。 | |||||
""" | |||||
assert not self.is_specified(), "This method should not be called after backend has been specified. " \ | assert not self.is_specified(), "This method should not be called after backend has been specified. " \ | ||||
"This must be a bug, please report." | "This must be a bug, please report." | ||||
types = [] | types = [] | ||||
@@ -12,7 +12,9 @@ class Backend: | |||||
def aggregate(self, tensor, method: str): | def aggregate(self, tensor, method: str): | ||||
""" | """ | ||||
聚集结果,并根据method计算后,返回结果 | |||||
聚集结果,并根据 method 计算后,返回结果 | |||||
:param tensor: 传入的张量 | |||||
:param method: 聚合的方法 | |||||
""" | """ | ||||
if method is not None: | if method is not None: | ||||
return AggregateMethodError(should_have_aggregate_method=False, only_warn=True) | return AggregateMethodError(should_have_aggregate_method=False, only_warn=True) | ||||
@@ -22,6 +24,8 @@ class Backend: | |||||
def create_tensor(self, value: float): | def create_tensor(self, value: float): | ||||
""" | """ | ||||
创建tensor,并且填入value作为值 | 创建tensor,并且填入value作为值 | ||||
:param value: 需要初始化的 value 值 | |||||
""" | """ | ||||
return value | return value | ||||
@@ -29,6 +33,8 @@ class Backend: | |||||
""" | """ | ||||
将tensor的值设置为value | 将tensor的值设置为value | ||||
:param tensor: 传进来的张量 | |||||
:param value: 需要填充的值 | |||||
""" | """ | ||||
return value | return value | ||||
@@ -36,14 +42,14 @@ class Backend: | |||||
""" | """ | ||||
tensor的saclar值 | tensor的saclar值 | ||||
:param tensor: | |||||
:param tensor: 传入的张量 | |||||
:return: | :return: | ||||
""" | """ | ||||
return tensor | return tensor | ||||
def is_specified(self) -> bool: | def is_specified(self) -> bool: | ||||
""" | """ | ||||
判断是否是某种框架的backend | |||||
判断是否是某种框架的 backend | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -51,15 +57,19 @@ class Backend: | |||||
def tensor2numpy(self, tensor): | def tensor2numpy(self, tensor): | ||||
""" | """ | ||||
将tensor转为numpy | |||||
将 tensor 转为 numpy | |||||
:param tensor: | |||||
:param tensor: 传入的张量 | |||||
:return: | :return: | ||||
""" | """ | ||||
return tensor | return tensor | ||||
def move_tensor_to_device(self, tensor, device): | def move_tensor_to_device(self, tensor, device): | ||||
""" | """ | ||||
将张量移动到某个设备上 | |||||
:param tensor: 传入的张量 | |||||
:param device: 设备号, 一般为 ``'cpu'``, ``'cuda:0'`` 等。 | |||||
""" | """ | ||||
return tensor | return tensor | ||||
@@ -16,20 +16,20 @@ class JittorBackend(Backend): | |||||
def aggregate(self, tensor, method: str): | def aggregate(self, tensor, method: str): | ||||
""" | """ | ||||
聚集结果,并根据method计算后,返回结果 | |||||
聚集结果,并根据 method 计算后,返回结果 | |||||
""" | """ | ||||
return tensor | return tensor | ||||
def create_tensor(self, value: float): | def create_tensor(self, value: float): | ||||
""" | """ | ||||
创建tensor,并且填入value作为值 | |||||
创建 tensor,并且填入 value 作为值 | |||||
""" | """ | ||||
value = jittor.Var(value) | value = jittor.Var(value) | ||||
return value | return value | ||||
def fill_value(self, tensor, value: float): | def fill_value(self, tensor, value: float): | ||||
""" | """ | ||||
将tensor的值设置为value | |||||
将 tensor 的值设置为 value | |||||
""" | """ | ||||
value = jittor.full_like(tensor, value) | value = jittor.full_like(tensor, value) | ||||
@@ -37,7 +37,7 @@ class JittorBackend(Backend): | |||||
def get_scalar(self, tensor) -> float: | def get_scalar(self, tensor) -> float: | ||||
""" | """ | ||||
tensor的saclar值 | |||||
tensor 的 saclar 值 | |||||
:param tensor: | :param tensor: | ||||
:return: | :return: | ||||
@@ -46,7 +46,7 @@ class JittorBackend(Backend): | |||||
def is_specified(self) -> bool: | def is_specified(self) -> bool: | ||||
""" | """ | ||||
判断是否是某种框架的backend | |||||
判断是否是某种框架的 backend | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -54,7 +54,7 @@ class JittorBackend(Backend): | |||||
def tensor2numpy(self, tensor): | def tensor2numpy(self, tensor): | ||||
""" | """ | ||||
将tensor转为numpy | |||||
将 tensor 转为 numpy | |||||
:param tensor: | :param tensor: | ||||
:return: | :return: | ||||
@@ -68,6 +68,6 @@ class JittorBackend(Backend): | |||||
def move_tensor_to_device(self, tensor, device): | def move_tensor_to_device(self, tensor, device): | ||||
""" | """ | ||||
jittor的没有转移设备的函数,因此该函数实际上无效 | |||||
jittor 的没有转移设备的函数,因此该函数实际上无效 | |||||
""" | """ | ||||
return tensor | return tensor |
@@ -23,7 +23,16 @@ class PaddleBackend(Backend): | |||||
def aggregate(self, tensor, method: str): | def aggregate(self, tensor, method: str): | ||||
""" | """ | ||||
聚集结果,并根据method计算后,返回结果 | |||||
聚集结果,并根据 method 计算后,返回结果 | |||||
:param tensor: 需要聚合的张量 | |||||
:param method: 聚合的方法, 目前支持 ``['sum', 'mean', 'max', 'mix']``: | |||||
* method 为 ``'sum'`` 时, 会将多张卡上聚合结果在维度为 `0` 上 累加起来。 | |||||
* method 为 ``'mean'`` 时,会将多张卡上聚合结果在维度为 `0` 上取平均值。 | |||||
* method 为 ``'max'`` 时,会将多张卡上聚合结果在维度为 `0` 上取最大值。 | |||||
* method 为 ``'mix'`` 时,会将多张卡上聚合结果在维度为 `0` 上取最小值。 | |||||
""" | """ | ||||
if isinstance(tensor, paddle.Tensor): | if isinstance(tensor, paddle.Tensor): | ||||
if parallel_helper._is_parallel_ctx_initialized(): | if parallel_helper._is_parallel_ctx_initialized(): | ||||
@@ -48,23 +57,37 @@ class PaddleBackend(Backend): | |||||
def create_tensor(self, value: float): | def create_tensor(self, value: float): | ||||
""" | """ | ||||
创建tensor,并且填入value作为值 | |||||
创建 tensor,并且填入 value 作为值 | |||||
:param value: 创建张量的初始值 | |||||
""" | """ | ||||
tensor = paddle.ones((1,)).fill_(value) | tensor = paddle.ones((1,)).fill_(value) | ||||
return tensor | return tensor | ||||
def fill_value(self, tensor, value: float): | def fill_value(self, tensor, value: float): | ||||
""" | """ | ||||
将tensor的值设置为value | |||||
将 tensor 的值设置为 value | |||||
:param tensor: 传入的张量 | |||||
:param value: 需要 fill 的值。 | |||||
""" | """ | ||||
tensor.fill_(value) | tensor.fill_(value) | ||||
return tensor | return tensor | ||||
def get_scalar(self, tensor) -> float: | def get_scalar(self, tensor) -> float: | ||||
""" | |||||
获取 tensor 的 scalar 值 | |||||
:param tensor: 传入的张量 | |||||
""" | |||||
return tensor.item() | return tensor.item() | ||||
def tensor2numpy(self, tensor) -> np.array: | def tensor2numpy(self, tensor) -> np.array: | ||||
""" | |||||
将 tensor 转为 numpy 值, 主要是在 metric 计算中使用 | |||||
:param tensor: 传入的张量 | |||||
""" | |||||
if isinstance(tensor, paddle.Tensor): | if isinstance(tensor, paddle.Tensor): | ||||
return tensor.cpu().detach().numpy() | return tensor.cpu().detach().numpy() | ||||
elif isinstance(tensor, np.array): | elif isinstance(tensor, np.array): | ||||
@@ -77,15 +100,29 @@ class PaddleBackend(Backend): | |||||
@staticmethod | @staticmethod | ||||
def is_distributed() -> bool: | def is_distributed() -> bool: | ||||
""" | """ | ||||
判断是否为 ddp 状态 | |||||
:return: | :return: | ||||
""" | """ | ||||
return is_in_paddle_dist() | return is_in_paddle_dist() | ||||
def move_tensor_to_device(self, tensor, device): | def move_tensor_to_device(self, tensor, device): | ||||
""" | |||||
将张量移到设备上 | |||||
:param tensor: 需要移动的张量 | |||||
:param device: 设备名, 一般为 "cpu", "cuda:0"等字符串 | |||||
""" | |||||
device = _convert_data_device(device) | device = _convert_data_device(device) | ||||
return paddle_to(tensor, device) | return paddle_to(tensor, device) | ||||
def all_gather_object(self, obj, group=None) -> List: | def all_gather_object(self, obj, group=None) -> List: | ||||
""" | |||||
给定 obj 将各个 rank 上的 obj 汇总到每个 obj 上。返回一个 list 对象,里面依次为各个 rank 对应的 obj 。 | |||||
:param obj: | |||||
:param group: | |||||
""" | |||||
if self.is_distributed(): | if self.is_distributed(): | ||||
obj_list = fastnlp_paddle_all_gather(obj, group=group) | obj_list = fastnlp_paddle_all_gather(obj, group=group) | ||||
return obj_list | return obj_list | ||||
@@ -21,7 +21,16 @@ class TorchBackend(Backend): | |||||
def aggregate(self, tensor, method: str): | def aggregate(self, tensor, method: str): | ||||
""" | """ | ||||
聚集结果,并根据method计算后,返回结果。 | |||||
聚集结果,并根据 method 计算后,返回结果 | |||||
:param tensor: 需要聚合的张量 | |||||
:param method: 聚合的方法, 目前支持 ``['sum', 'mean', 'max', 'mix']``: | |||||
* method 为 ``'sum'`` 时, 会将多张卡上聚合结果在维度为 `0` 上 累加起来。 | |||||
* method 为 ``'mean'`` 时,会将多张卡上聚合结果在维度为 `0` 上取平均值。 | |||||
* method 为 ``'max'`` 时,会将多张卡上聚合结果在维度为 `0` 上取最大值。 | |||||
* method 为 ``'mix'`` 时,会将多张卡上聚合结果在维度为 `0` 上取最小值。 | |||||
""" | """ | ||||
if isinstance(tensor, torch.Tensor): | if isinstance(tensor, torch.Tensor): | ||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
@@ -46,26 +55,36 @@ class TorchBackend(Backend): | |||||
def create_tensor(self, value: float): | def create_tensor(self, value: float): | ||||
""" | """ | ||||
创建tensor,并且填入value作为值 | |||||
创建 tensor,并且填入 value 作为值 | |||||
:param value: 创建张量的初始值 | |||||
""" | """ | ||||
tensor = torch.ones(1).fill_(value) | tensor = torch.ones(1).fill_(value) | ||||
return tensor | return tensor | ||||
def fill_value(self, tensor, value: float): | def fill_value(self, tensor, value: float): | ||||
""" | """ | ||||
将tensor的值设置为value | |||||
将 tensor 的值设置为 value | |||||
:param tensor: 传入的张量 | |||||
:param value: 需要 fill 的值。 | |||||
""" | """ | ||||
tensor.fill_(value) | tensor.fill_(value) | ||||
return tensor | return tensor | ||||
def get_scalar(self, tensor) -> float: | def get_scalar(self, tensor) -> float: | ||||
""" | |||||
获取 tensor 的 scalar 值 | |||||
:param tensor: 传入的张量 | |||||
""" | |||||
return tensor.item() | return tensor.item() | ||||
def tensor2numpy(self, tensor) -> np.array: | def tensor2numpy(self, tensor) -> np.array: | ||||
""" | """ | ||||
将对应的tensor转为numpy对象 | |||||
将 tensor 转为 numpy 值, 主要是在 metric 计算中使用 | |||||
:param tensor: 传入的张量 | |||||
""" | """ | ||||
if isinstance(tensor, torch.Tensor): | if isinstance(tensor, torch.Tensor): | ||||
@@ -80,14 +99,28 @@ class TorchBackend(Backend): | |||||
@staticmethod | @staticmethod | ||||
def is_distributed() -> bool: | def is_distributed() -> bool: | ||||
""" | """ | ||||
判断是否为 ddp 状态 | |||||
:return: | :return: | ||||
""" | """ | ||||
return dist.is_available() and dist.is_initialized() | return dist.is_available() and dist.is_initialized() | ||||
def move_tensor_to_device(self, tensor, device): | def move_tensor_to_device(self, tensor, device): | ||||
""" | |||||
将张量移到设备上 | |||||
:param tensor: 需要移动的张量 | |||||
:param device: 设备名, 一般为 "cpu", "cuda:0"等字符串 | |||||
""" | |||||
return tensor.to(device) | return tensor.to(device) | ||||
def all_gather_object(self, obj, group=None) -> List: | def all_gather_object(self, obj, group=None) -> List: | ||||
""" | |||||
给定 obj 将各个 rank 上的 obj 汇总到每个 obj 上。返回一个 list 对象,里面依次为各个 rank 对应的 obj 。 | |||||
:param obj: | |||||
:param group: | |||||
""" | |||||
if self.is_distributed(): | if self.is_distributed(): | ||||
obj_list = fastnlp_torch_all_gather(obj, group=group) | obj_list = fastnlp_torch_all_gather(obj, group=group) | ||||
return obj_list | return obj_list | ||||
@@ -20,15 +20,21 @@ class ClassifyFPreRecMetric(Metric): | |||||
aggregate_when_get_metric: bool = None) -> None: | aggregate_when_get_metric: bool = None) -> None: | ||||
""" | """ | ||||
:param tag_vocab: | |||||
:param ignore_labels: | |||||
:param only_gross: | |||||
:param f_type: | |||||
:param beta: | |||||
:param str backend: 目前支持四种类型的backend, [torch, paddle, jittor, auto]。其中 auto 表示根据实际调用 Metric.update() | |||||
函数时传入的参数决定具体的 backend ,大部分情况下直接使用 auto 即可。 | |||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||||
当 backend 不支持分布式时,该参数无意义。如果为 None ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` . 默认值为 ``None``。若为 ``None`` 则使用数字来作为标签内容, | |||||
否则使用 vocab 来作为标签内容。 | |||||
:param ignore_labels: ``str`` 组成的 ``list``. 这个 ``list``中的 class 不会被用于计算。例如在 POS tagging 时传入 ``['NN']``, | |||||
则不会计算 'NN' 个 label | |||||
:param only_gross: 是否只计算总的 ``f1``, ``precision``, ``recall``的值;如果为 ``False``,不仅返回总的 ``f1``, ``pre``, | |||||
``rec``, 还会返回每个 label 的 ``f1``, ``pre``, ``rec`` | |||||
:param f_type: `micro` 或 `macro` . | |||||
* `micro` : 通过先计算总体的 TP,FN 和 FP 的数量,再计算 f, precision, recall; | |||||
* `macro` : 分布计算每个类别的 f, precision, recall,然后做平均(各类别 f 的权重相同) | |||||
:param beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . | |||||
:param backend: 目前支持四种类型的 backend, ``[torch, paddle, jittor, 'auto']``。其中 ``'auto'`` 表示根据实际调用 Metric.update() | |||||
函数时传入的参数决定具体的 backend ,大部分情况下直接使用 ``'auto'`` 即可。 | |||||
:param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||||
当 backend 不支持分布式时,该参数无意义。如果为 ``None`` ,将在 Evaluator 中根据 sampler 是否使用分布式进行自动设置。 | |||||
""" | """ | ||||
super(ClassifyFPreRecMetric, self).__init__(backend=backend, | super(ClassifyFPreRecMetric, self).__init__(backend=backend, | ||||
aggregate_when_get_metric=aggregate_when_get_metric) | aggregate_when_get_metric=aggregate_when_get_metric) | ||||
@@ -50,6 +56,10 @@ class ClassifyFPreRecMetric(Metric): | |||||
self._fn = Counter() | self._fn = Counter() | ||||
def reset(self): | def reset(self): | ||||
""" | |||||
重置 tp, fp, fn 的值 | |||||
""" | |||||
# 由于不是 element 了,需要自己手动清零一下 | # 由于不是 element 了,需要自己手动清零一下 | ||||
self._tp.clear() | self._tp.clear() | ||||
self._fp.clear() | self._fp.clear() | ||||
@@ -57,9 +67,9 @@ class ClassifyFPreRecMetric(Metric): | |||||
def get_metric(self) -> dict: | def get_metric(self) -> dict: | ||||
r""" | r""" | ||||
get_metric函数将根据update函数累计的评价指标统计量来计算最终的评价结果. | |||||
get_metric 函数将根据 update 函数累计的评价指标统计量来计算最终的评价结果. | |||||
:return dict evaluate_result: {"acc": float} | |||||
:return evaluate_result: {"acc": float} | |||||
""" | """ | ||||
evaluate_result = {} | evaluate_result = {} | ||||
@@ -120,12 +130,12 @@ class ClassifyFPreRecMetric(Metric): | |||||
r""" | r""" | ||||
update 函数将针对一个批次的预测结果做评价指标的累计 | update 函数将针对一个批次的预测结果做评价指标的累计 | ||||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||||
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) | |||||
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), | |||||
torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) | |||||
:param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). | |||||
如果mask也被传进来的话seq_len会被忽略. | |||||
:param pred: 预测的 tensor, tensor 的形状可以是 [B,], [B, n_classes]) | |||||
[B, max_len], 或者 [B, max_len, n_classes] | |||||
:param target: 真实值的 tensor, tensor 的形状可以是 [B,], | |||||
[B,], [B, max_len], 或者 [B, max_len] | |||||
:param seq_len: 序列长度标记, 标记的形状可以是 None, [B]. | |||||
""" | """ | ||||
pred = self.tensor2numpy(pred) | pred = self.tensor2numpy(pred) | ||||
target = self.tensor2numpy(target) | target = self.tensor2numpy(target) | ||||
@@ -12,6 +12,26 @@ from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||||
class Element: | class Element: | ||||
def __init__(self, name, value: float, aggregate_method, backend: Backend): | def __init__(self, name, value: float, aggregate_method, backend: Backend): | ||||
""" | |||||
保存 Metric 中计算的元素值的对象 | |||||
:param name: 名称 | |||||
:param value: 元素的值 | |||||
:param aggregate_method: 聚合的方法, 目前支持 ``['sum', 'mean', 'max', 'mix']``: | |||||
* method 为 ``'sum'`` 时, 会将多张卡上聚合结果在维度为 `0` 上 累加起来。 | |||||
* method 为 ``'mean'`` 时,会将多张卡上聚合结果在维度为 `0` 上取平均值。 | |||||
* method 为 ``'max'`` 时,会将多张卡上聚合结果在维度为 `0` 上取最大值。 | |||||
* method 为 ``'mix'`` 时,会将多张卡上聚合结果在维度为 `0` 上取最小值。 | |||||
:param backend: 使用的 backend 。Element 的类型会根据 backend 进行实际的初始化。例如 backend 为 torch 则该对象为 | |||||
Torch.tensor ; 如果backend 为 paddle 则该对象为 paddle.tensor ;如果 backend 为 jittor , 则该对象为 jittor.Var 。 | |||||
一般情况下直接默认为 auto 就行了,fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化,例如当传入 | |||||
的参数中只包含 torch.Tensor 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 torch ;只包含 | |||||
jittor.Var 则认为 backend 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 jittor 。如果没有检测 | |||||
到任何一种 tensor ,就默认使用 float 类型作为 element 。 | |||||
""" | |||||
self.name = name | self.name = name | ||||
self.init_value = value | self.init_value = value | ||||
self.aggregate_method = aggregate_method | self.aggregate_method = aggregate_method | ||||
@@ -31,7 +51,7 @@ class Element: | |||||
def aggregate(self): | def aggregate(self): | ||||
""" | """ | ||||
自动aggregate对应的元素 | |||||
自动 aggregate 对应的元素 | |||||
""" | """ | ||||
self._check_value_initialized() | self._check_value_initialized() | ||||
@@ -54,6 +74,9 @@ class Element: | |||||
raise RuntimeError(msg) | raise RuntimeError(msg) | ||||
def reset(self): | def reset(self): | ||||
""" | |||||
重置 value | |||||
""" | |||||
if self.backend.is_specified(): | if self.backend.is_specified(): | ||||
self._value = self.backend.fill_value(self._value, self.init_value) | self._value = self.backend.fill_value(self._value, self.init_value) | ||||
@@ -72,19 +95,36 @@ class Element: | |||||
return self._value | return self._value | ||||
def get_scalar(self) -> float: | def get_scalar(self) -> float: | ||||
""" | |||||
获取元素的 scalar 值 | |||||
""" | |||||
self._check_value_initialized() | self._check_value_initialized() | ||||
return self.backend.get_scalar(self._value) | return self.backend.get_scalar(self._value) | ||||
def fill_value(self, value): | def fill_value(self, value): | ||||
""" | |||||
对元素进行 fill_value, 会执行队友 backend 的 fill_value 方法 | |||||
""" | |||||
self._value = self.backend.fill_value(self._value, value) | self._value = self.backend.fill_value(self._value, value) | ||||
def to(self, device): | def to(self, device): | ||||
""" | |||||
将元素移到某个设备上 | |||||
:param device: 设备名, 一般为 ``"cpu"``, ``"cuda:0"`` 等 | |||||
""" | |||||
# device这里如何处理呢? | # device这里如何处理呢? | ||||
if self._value is not None: | if self._value is not None: | ||||
self._value = self.backend.move_tensor_to_device(self._value, device) | self._value = self.backend.move_tensor_to_device(self._value, device) | ||||
self.device = device | self.device = device | ||||
def _check_value_initialized(self): | def _check_value_initialized(self): | ||||
""" | |||||
检查 Element 的 value 是否初始化了 | |||||
""" | |||||
if self._value is None: | if self._value is None: | ||||
assert self.backend.is_specified(), f"Backend is not specified, please specify backend in the Metric " \ | assert self.backend.is_specified(), f"Backend is not specified, please specify backend in the Metric " \ | ||||
f"initialization." | f"initialization." | ||||
@@ -114,6 +114,9 @@ class Metric: | |||||
return _wrap_update | return _wrap_update | ||||
def check_backend(self, *args, **kwargs): | def check_backend(self, *args, **kwargs): | ||||
""" | |||||
根据传入的参数的类型选择当前需要的 backend | |||||
""" | |||||
if not self.backend.is_specified(): | if not self.backend.is_specified(): | ||||
_args = [] | _args = [] | ||||
for arg in args: | for arg in args: | ||||
@@ -45,9 +45,9 @@ def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encod | |||||
def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str: | def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str: | ||||
r""" | r""" | ||||
给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio | |||||
给定 Vocabular y自动判断是哪种类型的 encoding, 支持判断 bmes, bioes, bmeso, bio | |||||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||||
:param tag_vocab: 支持传入 tag Vocabulary; 或者传入形如 {0:"O", 1:"B-tag1"},即 index 在前,tag 在后的 dict。 | |||||
:return: | :return: | ||||
""" | """ | ||||
tag_set = set() | tag_set = set() | ||||
@@ -81,9 +81,9 @@ def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str | |||||
def _bmes_tag_to_spans(tags, ignore_labels=None): | def _bmes_tag_to_spans(tags, ignore_labels=None): | ||||
r""" | r""" | ||||
给定一个tags的lis,比如['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。 | |||||
返回[('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间) | |||||
也可以是单纯的['S', 'B', 'M', 'E', 'B', 'M', 'M',...]序列 | |||||
给定一个 tags 的 lis,比如 ['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。 | |||||
返回 [('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间) | |||||
也可以是单纯的 ['S', 'B', 'M', 'E', 'B', 'M', 'M',...]序列 | |||||
:param tags: List[str], | :param tags: List[str], | ||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | :param ignore_labels: List[str], 在该list中的label将被忽略 | ||||
@@ -111,8 +111,8 @@ def _bmes_tag_to_spans(tags, ignore_labels=None): | |||||
def _bmeso_tag_to_spans(tags, ignore_labels=None): | def _bmeso_tag_to_spans(tags, ignore_labels=None): | ||||
r""" | r""" | ||||
给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 | |||||
返回[('singer', (1, 4))] (左闭右开区间) | |||||
给定一个 tag s的 lis,比如 ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 | |||||
返回 [('singer', (1, 4))] (左闭右开区间) | |||||
:param tags: List[str], | :param tags: List[str], | ||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | :param ignore_labels: List[str], 在该list中的label将被忽略 | ||||
@@ -142,8 +142,8 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None): | |||||
def _bioes_tag_to_spans(tags, ignore_labels=None): | def _bioes_tag_to_spans(tags, ignore_labels=None): | ||||
r""" | r""" | ||||
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。 | |||||
返回[('singer', (1, 4))] (左闭右开区间) | |||||
给定一个 tags 的 lis,比如 ['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。 | |||||
返回 [('singer', (1, 4))] (左闭右开区间) | |||||
:param tags: List[str], | :param tags: List[str], | ||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | :param ignore_labels: List[str], 在该list中的label将被忽略 | ||||
@@ -173,8 +173,8 @@ def _bioes_tag_to_spans(tags, ignore_labels=None): | |||||
def _bio_tag_to_spans(tags, ignore_labels=None): | def _bio_tag_to_spans(tags, ignore_labels=None): | ||||
r""" | r""" | ||||
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 | |||||
返回[('singer', (1, 4))] (左闭右开区间) | |||||
给定一个 tags 的 lis,比如 ['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 | |||||
返回 [('singer', (1, 4))] (左闭右开区间) | |||||
:param tags: List[str], | :param tags: List[str], | ||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | :param ignore_labels: List[str], 在该list中的label将被忽略 | ||||
@@ -204,9 +204,6 @@ class SpanFPreRecMetric(Metric): | |||||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | :param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | ||||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | 在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | ||||
:param pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 | |||||
:param target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 | |||||
:param seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 | |||||
:param encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. | :param encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. | ||||
:param ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label | :param ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label | ||||
:param only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec | :param only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec | ||||
@@ -256,11 +253,17 @@ class SpanFPreRecMetric(Metric): | |||||
self._fn = Counter() | self._fn = Counter() | ||||
def reset(self): | def reset(self): | ||||
""" | |||||
重置所有元素 | |||||
""" | |||||
self._tp.clear() | self._tp.clear() | ||||
self._fp.clear() | self._fp.clear() | ||||
self._fn.clear() | self._fn.clear() | ||||
def get_metric(self) -> dict: | def get_metric(self) -> dict: | ||||
""" | |||||
get_metric 函数将根据 update 函数累计的评价指标统计量来计算最终的评价结果. | |||||
""" | |||||
evaluate_result = {} | evaluate_result = {} | ||||
# 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 | # 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 | ||||
@@ -314,7 +317,8 @@ class SpanFPreRecMetric(Metric): | |||||
return evaluate_result | return evaluate_result | ||||
def update(self, pred, target, seq_len: Optional[List] = None) -> None: | def update(self, pred, target, seq_len: Optional[List] = None) -> None: | ||||
r"""update函数将针对一个批次的预测结果做评价指标的累计 | |||||
r"""u | |||||
pdate函数将针对一个批次的预测结果做评价指标的累计 | |||||
:param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果 | :param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果 | ||||
:param target: [batch, seq_len], 真实值 | :param target: [batch, seq_len], 真实值 | ||||