Browse Source

修改tutorial中的typo; 增加DataBundle的add_collect_fn

tags/v0.5.5
yh_cc 5 years ago
parent
commit
bbb1dbb630
4 changed files with 42 additions and 3 deletions
  1. +2
    -0
      docs/source/tutorials/序列标注.rst
  2. +1
    -1
      docs/source/tutorials/文本分类.rst
  3. +38
    -1
      fastNLP/io/data_bundle.py
  4. +1
    -1
      fastNLP/modules/encoder/lstm.py

+ 2
- 0
docs/source/tutorials/序列标注.rst View File

@@ -147,6 +147,8 @@ fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的,您
.. code-block:: python

from fastNLP.io import WeiboNERPipe
from fastNLP.models import BiLSTMCRF

data_bundle = WeiboNERPipe().process_from_file()
data_bundle.rename_field('chars', 'words')



+ 1
- 1
docs/source/tutorials/文本分类.rst View File

@@ -301,7 +301,7 @@ fastNLP提供了Trainer对象来组织训练过程,包括完成loss计算(所
# 这里为了演示一下效果,所以默认Bert不更新权重
bert_embed = BertEmbedding(char_vocab, model_dir_or_name='cn', auto_truncate=True, requires_grad=False)
model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target')), )
model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target')))
import torch


+ 38
- 1
fastNLP/io/data_bundle.py View File

@@ -6,7 +6,7 @@ __all__ = [
'DataBundle',
]

from typing import Union
from typing import Union, List

from ..core.dataset import DataSet
from ..core.vocabulary import Vocabulary
@@ -274,6 +274,22 @@ class DataBundle:
for name, dataset in self.datasets.items():
yield name, dataset

def get_dataset_names(self) -> List[str]:
    """
    Return the names of the DataSets stored in this DataBundle.

    :return: a new list built from the keys of ``self.datasets``
    """
    dataset_names = [*self.datasets.keys()]
    return dataset_names

def get_vocab_names(self) -> List[str]:
    """
    Return the names of the Vocabulary objects stored in this DataBundle.

    :return: a new list built from the keys of ``self.vocabs``
    """
    vocab_names = [*self.vocabs.keys()]
    return vocab_names

def iter_vocabs(self) -> Union[str, Vocabulary]:
"""
迭代data_bundle中的DataSet
@@ -332,6 +348,27 @@ class DataBundle:
dataset.apply(func, new_field_name=new_field_name, **kwargs)
return self

def add_collect_fn(self, fn, name=None):
    """
    Add a collect_fn to every DataSet in this DataBundle. See the notes on
    :class:`~fastNLP.DataSet` for details about collect_fn.

    :param callable fn: forwarded as ``fn`` to each DataSet's ``add_collect_fn``
    :param name: forwarded unchanged as ``name`` to each DataSet (default ``None``)
    :return: self, so that calls can be chained
    """
    # Only the DataSet objects are needed, not their names.
    for dataset in self.datasets.values():
        dataset.add_collect_fn(fn=fn, name=name)
    return self

def delete_collect_fn(self, name=None):
    """
    Remove a collect_fn from every DataSet in this DataBundle.

    :param name: forwarded unchanged as ``name`` to each DataSet's
        ``delete_collect_fn`` (default ``None``)
    :return: self, so that calls can be chained
    """
    # Only the DataSet objects are needed, not their names.
    for dataset in self.datasets.values():
        dataset.delete_collect_fn(name=name)
    return self

def __repr__(self):
_str = ''
if len(self.datasets):


+ 1
- 1
fastNLP/modules/encoder/lstm.py View File

@@ -24,7 +24,7 @@ class LSTM(nn.Module):
"""
:param input_size: 输入 `x` 的特征维度
:param hidden_size: 隐状态 `h` 的特征维度.
:param hidden_size: 隐状态 `h` 的特征维度. 如果bidirectional为True,则输出的维度会是hidden_size*2
:param num_layers: rnn的层数. Default: 1
:param dropout: 层间dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False``


Loading…
Cancel
Save