|
|
@@ -1,4 +1,150 @@ |
|
|
|
r""" |
|
|
|
:class:`~fastNLP.core.dataset.DataSet` 是 fastNLP 中用于承载数据的容器。可以将 DataSet 看做是一个表格, |
|
|
|
每一行是一个 sample (在 fastNLP 中被称为 :mod:`~fastNLP.core.instance` ), |
|
|
|
每一列是一个 feature (在 fastNLP 中称为 :mod:`~fastNLP.core.field` )。 |
|
|
|
|
|
|
|
.. csv-table:: Following is a demo layout of DataSet |
|
|
|
:header: "sentence", "words", "seq_len" |
|
|
|
|
|
|
|
"This is the first instance .", "[This, is, the, first, instance, .]", 6 |
|
|
|
"Second instance .", "[Second, instance, .]", 3 |
|
|
|
"Third instance .", "[Third, instance, .]", 3 |
|
|
|
"...", "[...]", "..." |
|
|
|
|
|
|
|
在 fastNLP 内部每一行是一个 :class:`~fastNLP.Instance` 对象; 每一列是一个 :class:`~fastNLP.FieldArray` 对象。 |
|
|
|
|
|
|
|
---------------------------- |
|
|
|
1.DataSet的创建 |
|
|
|
---------------------------- |
|
|
|
|
|
|
|
创建DataSet主要有以下的3种方式 |
|
|
|
|
|
|
|
1.1 传入dict |
|
|
|
---------------------------- |
|
|
|
|
|
|
|
.. code-block:: |
|
|
|
|
|
|
|
from fastNLP import DataSet |
|
|
|
data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."], |
|
|
|
'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.'], |
|
|
|
'seq_len': [6, 3, 3]} |
|
|
|
dataset = DataSet(data) |
|
|
|
# 传入的 dict 的每个 key 的 value 应该为具有相同长度的l ist |
|
|
|
|
|
|
|
1.2 通过 Instance 构建 |
|
|
|
---------------------------- |
|
|
|
|
|
|
|
.. code-block:: |
|
|
|
|
|
|
|
from fastNLP import DataSet |
|
|
|
from fastNLP import Instance |
|
|
|
dataset = DataSet() |
|
|
|
instance = Instance(sentence="This is the first instance", |
|
|
|
words=['this', 'is', 'the', 'first', 'instance', '.'], |
|
|
|
seq_len=6) |
|
|
|
dataset.append(instance) |
|
|
|
# 可以继续 append 更多内容,但是 append 的 instance 应该和第一个 instance 拥有完全相同的 field |
|
|
|
|
|
|
|
1.3 通过 List[Instance] 构建 |
|
|
|
-------------------------------------- |
|
|
|
|
|
|
|
.. code-block:: |
|
|
|
|
|
|
|
from fastNLP import DataSet |
|
|
|
from fastNLP import Instance |
|
|
|
instances = [] |
|
|
|
winstances.append(Instance(sentence="This is the first instance", |
|
|
|
ords=['this', 'is', 'the', 'first', 'instance', '.'], |
|
|
|
seq_len=6)) |
|
|
|
instances.append(Instance(sentence="Second instance .", |
|
|
|
words=['Second', 'instance', '.'], |
|
|
|
seq_len=3)) |
|
|
|
dataset = DataSet(instances) |
|
|
|
|
|
|
|
-------------------------------------- |
|
|
|
2.DataSet 与预处理 |
|
|
|
-------------------------------------- |
|
|
|
|
|
|
|
常见的预处理有如下几种 |
|
|
|
|
|
|
|
2.1 从某个文本文件读取内容 |
|
|
|
-------------------------------------- |
|
|
|
|
|
|
|
.. code-block:: |
|
|
|
|
|
|
|
from fastNLP import DataSet |
|
|
|
from fastNLP import Instance |
|
|
|
dataset = DataSet() |
|
|
|
filepath = 'some/text/file' |
|
|
|
# 假设文件中每行内容如下(sentence label): |
|
|
|
# This is a fantastic day positive |
|
|
|
# The bad weather negative |
|
|
|
# ..... |
|
|
|
with open(filepath, 'r') as f: |
|
|
|
for line in f: |
|
|
|
sent, label = line.strip().split('\t') |
|
|
|
dataset.append(Instance(sentence=sent, label=label)) |
|
|
|
|
|
|
|
|
|
|
|
2.2 对 DataSet 中的内容处理 |
|
|
|
-------------------------------------- |
|
|
|
|
|
|
|
.. code-block:: |
|
|
|
|
|
|
|
from fastNLP import DataSet |
|
|
|
data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."]} |
|
|
|
dataset = DataSet(data) |
|
|
|
# 将句子分成单词形式, 详见DataSet.apply()方法, 可以开启多进程来加快处理, 也可以更改展示的bar,目前支持 ``['rich', 'tqdm', None]``, |
|
|
|
# 详细内容可以见 :class: `~fastNLP.core.dataset.DataSet`, 需要注意的时匿名函数不支持多进程 |
|
|
|
dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words', |
|
|
|
progress_des='Main',progress_bar='rich') |
|
|
|
# 或使用DataSet.apply_field() |
|
|
|
dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words', |
|
|
|
progress_des='Main',progress_bar='rich') |
|
|
|
# 除了匿名函数,也可以定义函数传递进去 |
|
|
|
def get_words(instance): |
|
|
|
sentence = instance['sentence'] |
|
|
|
words = sentence.split() |
|
|
|
return words |
|
|
|
dataset.apply(get_words, new_field_name='words', num_proc=2, progress_des='Main',progress_bar='rich') |
|
|
|
|
|
|
|
2.3 删除DataSet的内容 |
|
|
|
-------------------------------------- |
|
|
|
|
|
|
|
.. code-block:: |
|
|
|
|
|
|
|
from fastNLP import DataSet |
|
|
|
dataset = DataSet({'a': list(range(-5, 5))}) |
|
|
|
# 返回满足条件的 instance,并放入 DataSet 中 |
|
|
|
dropped_dataset = dataset.drop(lambda ins:ins['a']<0, inplace=False) |
|
|
|
# 在 dataset 中删除满足条件的i nstance |
|
|
|
dataset.drop(lambda ins:ins['a']<0) # dataset 的 instance数量减少 |
|
|
|
# 删除第 3 个 instance |
|
|
|
dataset.delete_instance(2) |
|
|
|
# 删除名为 'a' 的 field |
|
|
|
dataset.delete_field('a') |
|
|
|
|
|
|
|
|
|
|
|
2.4 遍历DataSet的内容 |
|
|
|
-------------------------------------- |
|
|
|
|
|
|
|
.. code-block:: |
|
|
|
|
|
|
|
for instance in dataset: |
|
|
|
# do something |
|
|
|
|
|
|
|
2.5 一些其它操作 |
|
|
|
-------------------------------------- |
|
|
|
|
|
|
|
.. code-block:: |
|
|
|
|
|
|
|
# 检查是否存在名为 'a' 的 field |
|
|
|
dataset.has_field('a') # 或 ('a' in dataset) |
|
|
|
# 将名为 'a' 的 field 改名为 'b' |
|
|
|
dataset.rename_field('a', 'b') |
|
|
|
# DataSet 的长度 |
|
|
|
len(dataset) |
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
|
__all__ = [ |
|
|
@@ -42,9 +188,9 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, p |
|
|
|
""" |
|
|
|
对数据集进行处理封装函数,以便多进程使用 |
|
|
|
|
|
|
|
:param ds: 数据集 |
|
|
|
:param _apply_field: 需要处理数据集的field_name |
|
|
|
:param func: 用户自定义的func |
|
|
|
:param ds: 实现了 __getitem__() 和 __len__() 的对象 |
|
|
|
:param _apply_field: 需要处理数据集的 field_name |
|
|
|
:param func: 用户自定义的 func |
|
|
|
:param desc: 进度条的描述字符 |
|
|
|
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 |
|
|
|
:return: |
|
|
@@ -76,9 +222,9 @@ def _multi_proc(ds, _apply_field, func, counter, queue): |
|
|
|
""" |
|
|
|
对数据集进行处理封装函数,以便多进程使用 |
|
|
|
|
|
|
|
:param ds: 数据集 |
|
|
|
:param _apply_field: 需要处理数据集的field_name |
|
|
|
:param func: 用户自定义的func |
|
|
|
:param ds: 实现了 __getitem__() 和 __len__() 的对象 |
|
|
|
:param _apply_field: 需要处理数据集的 field_name |
|
|
|
:param func: 用户自定义的 func |
|
|
|
:param counter: 计数器 |
|
|
|
:param queue: 多进程时,将结果输入到这个 queue 中 |
|
|
|
:return: |
|
|
@@ -111,9 +257,28 @@ class DataSet: |
|
|
|
|
|
|
|
def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None): |
|
|
|
r""" |
|
|
|
初始化 ``DataSet``, fastNLP的 DataSet 是 key-value 存储形式, 目前支持两种初始化方式,输入 data 分别为 ``List[:class: `~fastNLP.core.dataset.Instance`]`` 和 |
|
|
|
``Dict[str, List[Any]]``。 |
|
|
|
|
|
|
|
* 当 data 为 ``List[:class: `~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 |
|
|
|
Instance 详见 :class: `~fastNLP.core.dataset.Instance` 。 |
|
|
|
* 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 |
|
|
|
|
|
|
|
:param data: 初始化的内容, 其只能为两种类型,分别为 ``List[:class: `~fastNLP.core.dataset.Instance`]`` 和 |
|
|
|
``Dict[str, List[Any]]``。 |
|
|
|
|
|
|
|
* 当 data 为 ``List[:class: `~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。 |
|
|
|
Instance 详见 :class: `~fastNLP.core.dataset.Instance` 。 |
|
|
|
* 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。 |
|
|
|
|
|
|
|
Example:: |
|
|
|
|
|
|
|
from fastNLP.core.dataset import DataSet, Instance |
|
|
|
data = {'x': [[1, 0, 1], [0, 1, 1], 'y': [0, 1]} |
|
|
|
data1 = [Instance(x=[1,0,1],y=0), Instance(x=[0,1,1],y=1)] |
|
|
|
ds = DataSet(data) |
|
|
|
ds = DataSet(data1) |
|
|
|
|
|
|
|
:param data: 如果为dict类型,则每个key的value应该为等长的list; 如果为list, |
|
|
|
每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 |
|
|
|
""" |
|
|
|
self.field_arrays = {} |
|
|
|
self._collator = Collator() |
|
|
@@ -168,11 +333,27 @@ class DataSet: |
|
|
|
return inner_iter_func() |
|
|
|
|
|
|
|
def __getitem__(self, idx: Union[int, slice, str, list]): |
|
|
|
r"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 |
|
|
|
r""" |
|
|
|
去 DataSet 的内容, 根据 idx 类型不同有不同的返回值。 包括四种类型 ``[int, slice, str, list]`` |
|
|
|
|
|
|
|
:param idx: can be int or slice. |
|
|
|
:return: If `idx` is int, return an Instance object. |
|
|
|
If `idx` is slice, return a DataSet object. |
|
|
|
* 当 idx 为 ``int`` 时, idx 的值不能超过 ``DataSet`` 的长度, 会返回一个 ``Instance``, 详见 |
|
|
|
:class: `~fastNLP.core.dataset.Instance` |
|
|
|
* 当 idx 为 ``slice`` 时, 会根据 slice 的内容创建一个新的 DataSet,其包含 slice 所有内容并返回。 |
|
|
|
* 当 idx 为 ``str`` 时, 该 idx 为 DataSet 的 field_name, 其会返回该 field_name 的所有内容, 为 list 类型。 |
|
|
|
* 当 idx 为 ``list`` 时, 该 idx 的 list 内全为 int 数字, 其会取出所有内容组成一个新的 DataSet 返回。 |
|
|
|
|
|
|
|
Example:: |
|
|
|
|
|
|
|
from fastNLP.core.dataset import DataSet |
|
|
|
|
|
|
|
ds = DataSet({'x': [[1, 0, 1], [0, 1, 1] * 100, 'y': [0, 1] * 100}) |
|
|
|
ins = ds[0] |
|
|
|
sub_ds = ds[0:100] |
|
|
|
sub_ds= ds[[1, 0, 3, 2, 1, 4]] |
|
|
|
field = ds['x'] |
|
|
|
|
|
|
|
:param idx: 用户传入参数 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if isinstance(idx, int): |
|
|
|
return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays}) |
|
|
@@ -230,9 +411,10 @@ class DataSet: |
|
|
|
return self.__dict__ |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
r"""Fetch the length of the dataset. |
|
|
|
r""" |
|
|
|
获取 DataSet 的长度 |
|
|
|
|
|
|
|
:return length: |
|
|
|
:return |
|
|
|
""" |
|
|
|
if len(self.field_arrays) == 0: |
|
|
|
return 0 |
|
|
@@ -244,9 +426,9 @@ class DataSet: |
|
|
|
|
|
|
|
def append(self, instance: Instance) -> None: |
|
|
|
r""" |
|
|
|
将一个instance对象append到DataSet后面。 |
|
|
|
将一个 instance 对象 append 到 DataSet 后面。详见 :class: `~fastNLP.Instance` |
|
|
|
|
|
|
|
:param ~fastNLP.Instance instance: 若DataSet不为空,则instance应该拥有和DataSet完全一样的field。 |
|
|
|
:param instance: 若 DataSet 不为空,则 instance 应该拥有和 DataSet 完全一样的 field。 |
|
|
|
|
|
|
|
""" |
|
|
|
if len(self.field_arrays) == 0: |
|
|
@@ -269,10 +451,10 @@ class DataSet: |
|
|
|
|
|
|
|
def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None: |
|
|
|
r""" |
|
|
|
将fieldarray添加到DataSet中. |
|
|
|
将 fieldarray 添加到 DataSet 中. |
|
|
|
|
|
|
|
:param str field_name: 新加入的field的名称 |
|
|
|
:param ~fastNLP.core.FieldArray fieldarray: 需要加入DataSet的field的内容 |
|
|
|
:param field_name: 新加入的 field 的名称 |
|
|
|
:param fieldarray: 需要加入 DataSet 的 field 的内容, 详见 :class: `~fastNLP.core.dataset.FieldArray` |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if not isinstance(fieldarray, FieldArray): |
|
|
@@ -285,10 +467,10 @@ class DataSet: |
|
|
|
|
|
|
|
def add_field(self, field_name: str, fields: list) -> None: |
|
|
|
r""" |
|
|
|
新增一个field, 需要注意的是fields的长度跟dataset长度一致 |
|
|
|
新增一个 field, 需要注意的是 fields 的长度跟 DataSet 长度一致 |
|
|
|
|
|
|
|
:param str field_name: 新增的field的名称 |
|
|
|
:param list fields: 需要新增的field的内容 |
|
|
|
:param field_name: 新增的 field 的名称 |
|
|
|
:param fields: 需要新增的 field 的内容 |
|
|
|
""" |
|
|
|
|
|
|
|
if len(self.field_arrays) != 0: |
|
|
@@ -299,9 +481,9 @@ class DataSet: |
|
|
|
|
|
|
|
def delete_instance(self, index: int): |
|
|
|
r""" |
|
|
|
删除第index个instance |
|
|
|
删除第 ``index `` 个 Instance |
|
|
|
|
|
|
|
:param int index: 需要删除的instance的index,序号从0开始。 |
|
|
|
:param index: 需要删除的 instanc e的 index,序号从 `0` 开始。 |
|
|
|
""" |
|
|
|
assert isinstance(index, int), "Only integer supported." |
|
|
|
if len(self) <= index: |
|
|
@@ -315,9 +497,9 @@ class DataSet: |
|
|
|
|
|
|
|
def delete_field(self, field_name: str): |
|
|
|
r""" |
|
|
|
删除名为field_name的field |
|
|
|
删除名为 field_name 的 field |
|
|
|
|
|
|
|
:param str field_name: 需要删除的field的名称. |
|
|
|
:param field_name: 需要删除的 field 的名称. |
|
|
|
""" |
|
|
|
if self.has_field(field_name): |
|
|
|
self.field_arrays.pop(field_name) |
|
|
@@ -327,10 +509,10 @@ class DataSet: |
|
|
|
|
|
|
|
def copy_field(self, field_name: str, new_field_name: str): |
|
|
|
r""" |
|
|
|
深度copy名为field_name的field到new_field_name |
|
|
|
深度 copy 名为 field_name 的 field 到 new_field_name |
|
|
|
|
|
|
|
:param str field_name: 需要copy的field。 |
|
|
|
:param str new_field_name: copy生成的field名称 |
|
|
|
:param field_name: 需要 copy 的 field。 |
|
|
|
:param new_field_name: copy 生成的 field 名称 |
|
|
|
:return: self |
|
|
|
""" |
|
|
|
if not self.has_field(field_name): |
|
|
@@ -342,10 +524,10 @@ class DataSet: |
|
|
|
|
|
|
|
def has_field(self, field_name: str) -> bool: |
|
|
|
r""" |
|
|
|
判断DataSet中是否有名为field_name这个field |
|
|
|
判断 DataSet 中是否有名为 field_name 这个 field |
|
|
|
|
|
|
|
:param str field_name: field的名称 |
|
|
|
:return bool: 表示是否有名为field_name这个field |
|
|
|
:param field_name: field 的名称 |
|
|
|
:return: 表示是否有名为 field_name 这个 field |
|
|
|
""" |
|
|
|
if isinstance(field_name, str): |
|
|
|
return field_name in self.field_arrays |
|
|
@@ -353,9 +535,9 @@ class DataSet: |
|
|
|
|
|
|
|
def get_field(self, field_name: str) -> FieldArray: |
|
|
|
r""" |
|
|
|
获取field_name这个field |
|
|
|
获取 field_name 这个 field |
|
|
|
|
|
|
|
:param str field_name: field的名称 |
|
|
|
:param field_name: field 的名称 |
|
|
|
:return: :class:`~fastNLP.FieldArray` |
|
|
|
""" |
|
|
|
if field_name not in self.field_arrays: |
|
|
@@ -364,34 +546,34 @@ class DataSet: |
|
|
|
|
|
|
|
def get_all_fields(self) -> dict: |
|
|
|
r""" |
|
|
|
返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray` |
|
|
|
返回一个 dict,key 为 field_name, value为对应的 :class:`~fastNLP.FieldArray` |
|
|
|
|
|
|
|
:return dict: 返回如上所述的字典 |
|
|
|
:return: 返回如上所述的字典 |
|
|
|
""" |
|
|
|
return self.field_arrays |
|
|
|
|
|
|
|
def get_field_names(self) -> list: |
|
|
|
r""" |
|
|
|
返回一个list,包含所有 field 的名字 |
|
|
|
返回一个 list,包含所有 field 的名字 |
|
|
|
|
|
|
|
:return list: 返回如上所述的列表 |
|
|
|
:return: 返回如上所述的列表 |
|
|
|
""" |
|
|
|
return sorted(self.field_arrays.keys()) |
|
|
|
|
|
|
|
def get_length(self) -> int: |
|
|
|
r""" |
|
|
|
获取DataSet的元素数量 |
|
|
|
获取 DataSet 的元素数量 |
|
|
|
|
|
|
|
:return: int: DataSet中Instance的个数。 |
|
|
|
:return: DataSet 中 Instance 的个数。 |
|
|
|
""" |
|
|
|
return len(self) |
|
|
|
|
|
|
|
def rename_field(self, field_name: str, new_field_name: str): |
|
|
|
r""" |
|
|
|
将某个field重新命名. |
|
|
|
将某个 field 重新命名. |
|
|
|
|
|
|
|
:param str field_name: 原来的field名称。 |
|
|
|
:param str new_field_name: 修改为new_name。 |
|
|
|
:param field_name: 原来的 field 名称。 |
|
|
|
:param new_field_name: 修改为 new_name。 |
|
|
|
""" |
|
|
|
if field_name in self.field_arrays: |
|
|
|
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) |
|
|
@@ -627,10 +809,10 @@ class DataSet: |
|
|
|
|
|
|
|
def add_seq_len(self, field_name: str, new_field_name='seq_len'): |
|
|
|
r""" |
|
|
|
将使用len()直接对field_name中每个元素作用,将其结果作为sequence length, 并放入seq_len这个field。 |
|
|
|
将使用 len() 直接对 field_name 中每个元素作用,将其结果作为 sequence length, 并放入 seq_len 这个 field。 |
|
|
|
|
|
|
|
:param field_name: str. |
|
|
|
:param new_field_name: str. 新的field_name |
|
|
|
:param field_name: 需要处理的 field_name |
|
|
|
:param new_field_name: str. 新的 field_name |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if self.has_field(field_name=field_name): |
|
|
@@ -641,10 +823,11 @@ class DataSet: |
|
|
|
|
|
|
|
def drop(self, func: Callable, inplace=True): |
|
|
|
r""" |
|
|
|
func接受一个Instance,返回bool值。返回值为True时,该Instance会被移除或者不会包含在返回的DataSet中。 |
|
|
|
删除某些 Instance。 需要注意的时func 接受一个 Instance ,返回 bool 值。返回值为 True 时, |
|
|
|
该 Instance 会被移除或者不会包含在返回的 DataSet 中。 |
|
|
|
|
|
|
|
:param callable func: 接受一个Instance作为参数,返回bool值。为True时删除该instance |
|
|
|
:param bool inplace: 是否在当前DataSet中直接删除instance;如果为False,将返回一个新的DataSet。 |
|
|
|
:param func: 接受一个 Instance 作为参数,返回 bool 值。为 True 时删除该 instance |
|
|
|
:param inplace: 是否在当前 DataSet 中直接删除 instance;如果为 False,将返回一个新的 DataSet。 |
|
|
|
|
|
|
|
:return: DataSet |
|
|
|
""" |
|
|
@@ -663,10 +846,10 @@ class DataSet: |
|
|
|
|
|
|
|
def split(self, ratio: float, shuffle=True): |
|
|
|
r""" |
|
|
|
将DataSet按照ratio的比例拆分,返回两个DataSet |
|
|
|
将 DataSet 按照 ratio 的比例拆分,返回两个 DataSet |
|
|
|
|
|
|
|
:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `ratio` 这么多数据,第二个DataSet拥有`(1-ratio)`这么多数据 |
|
|
|
:param bool shuffle: 在split前是否shuffle一下。为False,返回的第一个dataset就是当前dataset中前`ratio`比例的数据, |
|
|
|
:param ratio: 0<ratio<1, 返回的第一个 DataSet 拥有 `ratio` 这么多数据,第二个 DataSet 拥有 `(1-ratio)` 这么多数据 |
|
|
|
:param shuffle: 在 split 前是否 shuffle 一下。为 False,返回的第一个 dataset 就是当前 dataset 中前 `ratio` 比例的数据, |
|
|
|
:return: [ :class:`~fastNLP.读取后的DataSet` , :class:`~fastNLP.读取后的DataSet` ] |
|
|
|
""" |
|
|
|
assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.' |
|
|
@@ -696,7 +879,7 @@ class DataSet: |
|
|
|
r""" |
|
|
|
保存DataSet. |
|
|
|
|
|
|
|
:param str path: 将DataSet存在哪个路径 |
|
|
|
:param path: 将DataSet存在哪个路径 |
|
|
|
""" |
|
|
|
with open(path, 'wb') as f: |
|
|
|
pickle.dump(self, f) |
|
|
@@ -704,9 +887,9 @@ class DataSet: |
|
|
|
@staticmethod |
|
|
|
def load(path: str): |
|
|
|
r""" |
|
|
|
从保存的DataSet pickle文件的路径中读取DataSet |
|
|
|
从保存的 DataSet pickle文件的路径中读取DataSet |
|
|
|
|
|
|
|
:param str path: 从哪里读取DataSet |
|
|
|
:param path: 从哪里读取 DataSet |
|
|
|
:return: 读取后的 :class:`~fastNLP.读取后的DataSet`。 |
|
|
|
""" |
|
|
|
with open(path, 'rb') as f: |
|
|
@@ -716,16 +899,16 @@ class DataSet: |
|
|
|
|
|
|
|
def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': |
|
|
|
""" |
|
|
|
将当前dataset与输入的dataset结合成一个更大的dataset,需要保证两个dataset都包含了相同的field。结合后的dataset的input,target |
|
|
|
以及collate_fn以当前dataset为准。当dataset中包含的field多于当前的dataset,则多余的field会被忽略;若dataset中未包含所有 |
|
|
|
当前dataset含有field,则会报错。 |
|
|
|
将当前 dataset 与输入的 dataset 结合成一个更大的 dataset,需要保证两个 dataset 都包含了相同的 field。结合后的 dataset |
|
|
|
的 field_name 和 _collator 以当前 dataset 为准。当 dataset 中包含的 field 多于当前的 dataset,则多余的 field 会被忽略; |
|
|
|
若 dataset 中未包含所有当前 dataset 含有 field,则会报错。 |
|
|
|
|
|
|
|
:param DataSet, dataset: 需要和当前dataset concat的dataset |
|
|
|
:param bool, inplace: 是否直接将dataset组合到当前dataset中 |
|
|
|
:param dict, field_mapping: 当传入的dataset中的field名称和当前dataset不一致时,需要通过field_mapping把输入的dataset中的 |
|
|
|
field名称映射到当前field. field_mapping为dict类型,key为dataset中的field名称,value是需要映射成的名称 |
|
|
|
:param dataset: 需要和当前 dataset concat的 dataset |
|
|
|
:param inplace: 是否直接将 dataset 组合到当前 dataset 中 |
|
|
|
:param field_mapping: 当传入的 dataset 中的 field 名称和当前 dataset 不一致时,需要通过 field_mapping 把输入的 dataset 中的 |
|
|
|
field 名称映射到当前 field. field_mapping 为 dict 类型,key 为 dataset 中的 field 名称,value 是需要映射成的名称 |
|
|
|
|
|
|
|
:return: DataSet |
|
|
|
:return: :class: `~fastNLP.core.dataset.DataSet`` |
|
|
|
""" |
|
|
|
assert isinstance(dataset, DataSet), "Can only concat two datasets." |
|
|
|
|
|
|
@@ -754,8 +937,8 @@ class DataSet: |
|
|
|
@classmethod |
|
|
|
def from_pandas(cls, df): |
|
|
|
""" |
|
|
|
从pandas.DataFrame中读取数据转为Dataset |
|
|
|
:param df: |
|
|
|
从 ``pandas.DataFrame`` 中读取数据转为 DataSet |
|
|
|
:param df: 使用 pandas 读取的数据 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
df_dict = df.to_dict(orient='list') |
|
|
@@ -763,7 +946,7 @@ class DataSet: |
|
|
|
|
|
|
|
def to_pandas(self): |
|
|
|
""" |
|
|
|
将dataset转为pandas.DataFrame类型的数据 |
|
|
|
将 DataSet 数据转为 ``pandas.DataFrame`` 类型的数据 |
|
|
|
|
|
|
|
:return: |
|
|
|
""" |
|
|
@@ -773,9 +956,9 @@ class DataSet: |
|
|
|
|
|
|
|
def to_csv(self, path: str): |
|
|
|
""" |
|
|
|
将dataset保存为csv文件 |
|
|
|
将 DataSet 保存为 csv 文件 |
|
|
|
|
|
|
|
:param path: |
|
|
|
:param path: 保存到路径 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
|
|
|
|