From bbb1dbb63053f7195e9e660323c580fb02995637 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Tue, 24 Mar 2020 15:52:41 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9tutorial=E4=B8=AD=E7=9A=84typ?= =?UTF-8?q?o;=20=E5=A2=9E=E5=8A=A0DataBundle=E7=9A=84add=5Fcollect=5Ffn?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/tutorials/序列标注.rst | 2 ++ docs/source/tutorials/文本分类.rst | 2 +- fastNLP/io/data_bundle.py | 39 +++++++++++++++++++++++++++++++++- fastNLP/modules/encoder/lstm.py | 2 +- 4 files changed, 42 insertions(+), 3 deletions(-) diff --git a/docs/source/tutorials/序列标注.rst b/docs/source/tutorials/序列标注.rst index 02d3a503..17cbb298 100644 --- a/docs/source/tutorials/序列标注.rst +++ b/docs/source/tutorials/序列标注.rst @@ -147,6 +147,8 @@ fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的,您 .. code-block:: python from fastNLP.io import WeiboNERPipe + from fastNLP.models import BiLSTMCRF + data_bundle = WeiboNERPipe().process_from_file() data_bundle.rename_field('chars', 'words') diff --git a/docs/source/tutorials/文本分类.rst b/docs/source/tutorials/文本分类.rst index 2f675115..997e35c8 100644 --- a/docs/source/tutorials/文本分类.rst +++ b/docs/source/tutorials/文本分类.rst @@ -301,7 +301,7 @@ fastNLP提供了Trainer对象来组织训练过程,包括完成loss计算(所 # 这里为了演示一下效果,所以默认Bert不更新权重 bert_embed = BertEmbedding(char_vocab, model_dir_or_name='cn', auto_truncate=True, requires_grad=False) - model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target')), ) + model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target'))) import torch diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index ba275e61..553d9db8 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -6,7 +6,7 @@ __all__ = [ 'DataBundle', ] -from typing import Union +from typing import Union, List from ..core.dataset import DataSet from ..core.vocabulary import Vocabulary @@ -274,6 +274,22 @@ class DataBundle: for name, dataset in self.datasets.items(): 
 yield name, dataset + def get_dataset_names(self) -> List[str]: + """ + 返回DataBundle中DataSet的名称 + + :return: + """ + return list(self.datasets.keys()) + + def get_vocab_names(self)->List[str]: + """ + 返回DataBundle中Vocabulary的名称 + + :return: + """ + return list(self.vocabs.keys()) + def iter_vocabs(self) -> Union[str, Vocabulary]: """ 迭代data_bundle中的DataSet @@ -332,6 +348,27 @@ class DataBundle: dataset.apply(func, new_field_name=new_field_name, **kwargs) return self + def add_collect_fn(self, fn, name=None): + """ + 向所有DataSet增加collect_fn, collect_fn详见 :class:`~fastNLP.DataSet` 中相关说明. + + :param callable fn: + :param name: + :return: + """ + for _, dataset in self.datasets.items(): + dataset.add_collect_fn(fn=fn, name=name) + + def delete_collect_fn(self, name=None): + """ + 删除DataSet中的collect_fn + + :param name: + :return: + """ + for _, dataset in self.datasets.items(): + dataset.delete_collect_fn(name=name) + def __repr__(self): _str = '' if len(self.datasets): diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py index 1ae61ec0..e3af8bbd 100644 --- a/fastNLP/modules/encoder/lstm.py +++ b/fastNLP/modules/encoder/lstm.py @@ -24,7 +24,7 @@ class LSTM(nn.Module): """ :param input_size: 输入 `x` 的特征维度 - :param hidden_size: 隐状态 `h` 的特征维度. + :param hidden_size: 隐状态 `h` 的特征维度. 如果bidirectional为True,则输出的维度会是hidden_size*2 :param num_layers: rnn的层数. Default: 1 :param dropout: 层间dropout概率. Default: 0 :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False``