|
|
@@ -6,7 +6,7 @@ __all__ = [ |
|
|
|
'DataBundle', |
|
|
|
] |
|
|
|
|
|
|
|
from typing import Union |
|
|
|
from typing import Union, List |
|
|
|
|
|
|
|
from ..core.dataset import DataSet |
|
|
|
from ..core.vocabulary import Vocabulary |
|
|
@@ -274,6 +274,22 @@ class DataBundle: |
|
|
|
for name, dataset in self.datasets.items(): |
|
|
|
yield name, dataset |
|
|
|
|
|
|
|
def get_dataset_names(self) -> List[str]: |
|
|
|
""" |
|
|
|
返回DataBundle中DataSet的名称 |
|
|
|
|
|
|
|
:return: |
|
|
|
""" |
|
|
|
return list(self.datasets.keys()) |
|
|
|
|
|
|
|
def get_vocab_names(self)->List[str]: |
|
|
|
""" |
|
|
|
返回DataBundle中Vocabulary的名称 |
|
|
|
|
|
|
|
:return: |
|
|
|
""" |
|
|
|
return list(self.vocabs.keys()) |
|
|
|
|
|
|
|
def iter_vocabs(self) -> Union[str, Vocabulary]: |
|
|
|
""" |
|
|
|
迭代data_bundle中的DataSet |
|
|
@@ -332,6 +348,27 @@ class DataBundle: |
|
|
|
dataset.apply(func, new_field_name=new_field_name, **kwargs) |
|
|
|
return self |
|
|
|
|
|
|
|
def add_collect_fn(self, fn, name=None): |
|
|
|
""" |
|
|
|
向所有DataSet增加collect_fn, collect_fn详见 :class:`~fastNLP.DataSet` 中相关说明. |
|
|
|
|
|
|
|
:param callable fn: |
|
|
|
:param name: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
for _, dataset in self.datasets.items(): |
|
|
|
dataset.add_collect_fn(fn=fn, name=name) |
|
|
|
|
|
|
|
def delete_collect_fn(self, name=None): |
|
|
|
""" |
|
|
|
删除DataSet中的collect_fn |
|
|
|
|
|
|
|
:param name: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
for _, dataset in self.datasets.items(): |
|
|
|
dataset.delete_collect_fn(name=name) |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
_str = '' |
|
|
|
if len(self.datasets): |
|
|
|