@@ -1,5 +1,3 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"MultiHeadAttention", | |||
"BiAttention", | |||
@@ -17,7 +15,11 @@ from .decoder.seq2seq_state import TransformerState | |||
class DotAttention(nn.Module): | |||
r""" | |||
Transformer当中的DotAttention | |||
**Transformer** 当中的 **DotAttention** | |||
:param key_size: | |||
:param value_size: | |||
:param dropout: | |||
""" | |||
def __init__(self, key_size, value_size, dropout=0.0): | |||
@@ -31,10 +33,10 @@ class DotAttention(nn.Module): | |||
def forward(self, Q, K, V, mask_out=None): | |||
r""" | |||
:param Q: [..., seq_len_q, key_size] | |||
:param K: [..., seq_len_k, key_size] | |||
:param V: [..., seq_len_k, value_size] | |||
:param mask_out: [..., 1, seq_len] or [..., seq_len_q, seq_len_k] | |||
:param Q: ``[..., seq_len_q, key_size]`` | |||
:param K: ``[..., seq_len_k, key_size]`` | |||
:param V: ``[..., seq_len_k, value_size]`` | |||
:param mask_out: ``[..., 1, seq_len]`` or ``[..., seq_len_q, seq_len_k]`` | |||
""" | |||
output = torch.matmul(Q, K.transpose(-1, -2)) / self.scale | |||
if mask_out is not None: | |||
@@ -46,8 +48,12 @@ class DotAttention(nn.Module): | |||
class MultiHeadAttention(nn.Module): | |||
""" | |||
Attention is all you need中提到的多头注意力 | |||
`Attention is all you need <https://arxiv.org/abs/1706.03762>`_ 中提到的多头注意力 | |||
:param d_model: | |||
:param n_head: | |||
:param dropout: | |||
:param layer_idx: | |||
""" | |||
def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None): | |||
super(MultiHeadAttention, self).__init__() | |||
@@ -69,12 +75,13 @@ class MultiHeadAttention(nn.Module): | |||
def forward(self, query, key, value, key_mask=None, attn_mask=None, state=None): | |||
""" | |||
:param query: batch x seq x dim | |||
:param key: batch x seq x dim | |||
:param value: batch x seq x dim | |||
:param key_mask: batch x seq 用于指示哪些key不要attend到;注意到mask为1的地方是要attend到的 | |||
:param attn_mask: seq x seq, 用于mask掉attention map。 主要是用在训练时decoder端的self attention,下三角为1 | |||
:param state: 过去的信息,在inference的时候会用到,比如encoder output、decoder的prev kv。这样可以减少计算。 | |||
:param query: ``[batch, seq, dim]`` | |||
:param key: ``[batch, seq, dim]`` | |||
:param value: ``[batch, seq, dim]`` | |||
:param key_mask: ``[batch, seq]`` 用于指示哪些 ``key`` 不要 attend 到;注意到 mask 为 **1** 的地方是要attend到的 | |||
:param attn_mask: ``[seq, seq]``, 用于 mask 掉 attention map。 主要是用在训练时 decoder 端的 :class:`SelfAttention` , | |||
下三角为 1。 | |||
:param state: 过去的信息,在 inference 的时候会用到,比如 encoder output、decoder 的 prev kv。这样可以减少计算。 | |||
:return: | |||
""" | |||
assert key.size() == value.size() | |||
@@ -149,15 +156,15 @@ class MultiHeadAttention(nn.Module): | |||
class AttentionLayer(nn.Module): | |||
def __init__(selfu, input_size, key_dim, value_dim, bias=False): | |||
""" | |||
可用于LSTM2LSTM的序列到序列模型的decode过程中,该attention是在decode过程中根据上一个step的hidden计算对encoder结果的attention | |||
""" | |||
可用于 LSTM2LSTM 的序列到序列模型的 decode 过程中,该 attention 是在 decode 过程中根据上一个 step 的 hidden 计算对 encoder 结果的 attention | |||
:param int input_size: 输入的大小 | |||
:param int key_dim: 一般就是encoder_output输出的维度 | |||
:param int value_dim: 输出的大小维度, 一般就是decoder hidden的大小 | |||
:param bias: | |||
""" | |||
:param int input_size: 输入的大小 | |||
:param int key_dim: 一般就是 encoder_output 输出的维度 | |||
:param int value_dim: 输出的大小维度, 一般就是 decoder hidden 的大小 | |||
:param bias: | |||
""" | |||
def __init__(selfu, input_size, key_dim, value_dim, bias=False): | |||
super().__init__() | |||
selfu.input_proj = nn.Linear(input_size, key_dim, bias=bias) | |||
@@ -166,10 +173,10 @@ class AttentionLayer(nn.Module): | |||
def forward(self, input, encode_outputs, encode_mask): | |||
""" | |||
:param input: batch_size x input_size | |||
:param encode_outputs: batch_size x max_len x key_dim | |||
:param encode_mask: batch_size x max_len, 为0的地方为padding | |||
:return: hidden: batch_size x value_dim, scores: batch_size x max_len, normalized过的 | |||
:param input: ``[batch_size, input_size]`` | |||
:param encode_outputs: ``[batch_size, max_len, key_dim]`` | |||
:param encode_mask: ``[batch_size, max_len]``, 为0的地方为padding | |||
:return: hidden: ``[batch_size, value_dim]``, scores: ``[batch_size, max_len]``, normalized 过的 | |||
""" | |||
# x: bsz x encode_hidden_size | |||
@@ -221,9 +228,9 @@ def _weighted_sum(tensor, weights, mask): | |||
class BiAttention(nn.Module): | |||
r""" | |||
Bi Attention module | |||
**Bi Attention module** | |||
对于给定的两个向量序列 :math:`a_i` 和 :math:`b_j` , BiAttention模块将通过以下的公式来计算attention结果 | |||
对于给定的两个向量序列 :math:`a_i` 和 :math:`b_j` , :class:`BiAttention` 模块将通过以下的公式来计算 attention 结果 | |||
.. math:: | |||
@@ -237,11 +244,14 @@ class BiAttention(nn.Module): | |||
def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): | |||
r""" | |||
:param torch.Tensor premise_batch: [batch_size, a_seq_len, hidden_size] | |||
:param torch.Tensor premise_mask: [batch_size, a_seq_len] | |||
:param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size] | |||
:param torch.Tensor hypothesis_mask: [batch_size, b_seq_len] | |||
:return: torch.Tensor attended_premises: [batch_size, a_seq_len, hidden_size] torch.Tensor attended_hypotheses: [batch_size, b_seq_len, hidden_size] | |||
:param premise_batch: ``[batch_size, a_seq_len, hidden_size]`` | |||
:param premise_mask: ``[batch_size, a_seq_len]`` | |||
:param hypothesis_batch: ``[batch_size, b_seq_len, hidden_size]`` | |||
:param hypothesis_mask: ``[batch_size, b_seq_len]`` | |||
:return: 一个包含两个张量的元组,分别为: | |||
- ``attended_premises`` : ``[batch_size, a_seq_len, hidden_size]`` | |||
- ``attended_hypotheses`` : ``[batch_size, b_seq_len, hidden_size]`` | |||
""" | |||
similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1) | |||
.contiguous()) | |||
@@ -264,17 +274,15 @@ class BiAttention(nn.Module): | |||
class SelfAttention(nn.Module): | |||
r""" | |||
这是一个基于论文 `A structured self-attentive sentence embedding <https://arxiv.org/pdf/1703.03130.pdf>`_ | |||
的Self Attention Module. | |||
的 **Self Attention Module** 。 | |||
:param input_size: 输入 tensor 的 hidden 维度 | |||
:param attention_unit: 输出 tensor 的 hidden 维度 | |||
:param attention_hops: | |||
:param drop: dropout 概率 | |||
""" | |||
def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5): | |||
r""" | |||
:param int input_size: 输入tensor的hidden维度 | |||
:param int attention_unit: 输出tensor的hidden维度 | |||
:param int attention_hops: | |||
:param float drop: dropout概率,默认值为0.5 | |||
""" | |||
super(SelfAttention, self).__init__() | |||
self.attention_hops = attention_hops | |||
@@ -301,10 +309,12 @@ class SelfAttention(nn.Module): | |||
def forward(self, input, input_origin): | |||
r""" | |||
:param torch.Tensor input: [batch_size, seq_len, hidden_size] 要做attention的矩阵 | |||
:param torch.Tensor input_origin: [batch_size, seq_len] 原始token的index组成的矩阵,含有pad部分内容 | |||
:return torch.Tensor output1: [batch_size, multi-head, hidden_size] 经过attention操作后输入矩阵的结果 | |||
:return torch.Tensor output2: [1] attention惩罚项,是一个标量 | |||
:param input: 要做 **attention** 的矩阵,形状为 ``[batch_size, seq_len, hidden_size]`` | |||
:param input_origin: 原始 token 的 index 组成的矩阵,含有 pad 部分内容,形状为 ``[batch_size, seq_len]`` | |||
:return: 一个元组,分别是: | |||
- 经过 **attention** 操作后输入矩阵的结果,形状为 ``[batch_size, multi-head, hidden_size]`` | |||
- **attention** 惩罚项,是一个标量 | |||
""" | |||
input = input.contiguous() | |||
size = input.size() # [bsz, len, nhid] | |||
@@ -1,11 +1,9 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"ConditionalRandomField", | |||
"allowed_transitions" | |||
] | |||
from typing import Union, List | |||
from typing import Union, List, Tuple | |||
import torch | |||
from torch import nn | |||
@@ -14,17 +12,19 @@ from ....core.metrics.span_f1_pre_rec_metric import _get_encoding_type_from_tag_ | |||
from ....core.vocabulary import Vocabulary | |||
def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type:str=None, include_start_end:bool=False): | |||
def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type:str=None, include_start_end:bool=False) -> List[Tuple[int, int]]: | |||
r""" | |||
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | |||
:param tag_vocab: 支持类型为tag或tag-label。只有tag的,比如"B", "M"; 也可以是"B-NN", "M-NN", | |||
tag和label之间一定要用"-"隔开。如果传入dict,格式需要形如{0:"O", 1:"B-tag1"},即index在前,tag在后。 | |||
:param encoding_type: 支持``["bio", "bmes", "bmeso", "bioes"]``。默认为None,通过vocab自动推断 | |||
:param include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | |||
为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | |||
start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 | |||
:return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 | |||
给定一个 ``id`` 到 ``label`` 的映射表,返回所有可以跳转的 ``(from_tag_id, to_tag_id)`` 列表。 | |||
:param tag_vocab: 支持类型为 tag 或 tag-label 。只有 tag 的,比如 ``"B"`` 、 ``"M"``,也可以是 ``"B-NN"`` 、 ``"M-NN"``, | |||
tag 和 label之间一定要用 ``"-"`` 隔开。如果传入 :class:`dict` ,格式需要形如 ``{0:"O", 1:"B-tag1"}`` ,即 **index 在前,tag 在后** 。 | |||
:param encoding_type: 支持 ``["bio", "bmes", "bmeso", "bioes", None]``。默认为 ``None``,通过 ``tag_vocab`` 自动推断 | |||
:param include_start_end: 是否包含开始与结尾的转换。比如在 ``bio`` 中, ``b/o`` 可以在开头,但是 ``i`` 不能在开头; | |||
- 为 ``True`` -- 返回的结果中会包含 ``(start_idx, b_idx), (start_idx, o_idx)`` ,但是不包含 ``(start_idx, i_idx)`` ; | |||
其中 ``start_idx=len(id2label)``,``end_idx=len(id2label)+1`` ; | |||
- 为 ``False`` , 返回的结果中不含与开始结尾相关的内容; | |||
:return: 一系列元组构成的列表,内部的 :class:`Tuple` 是可以进行跳转的 ``(from_tag_id, to_tag_id)``。 | |||
""" | |||
if encoding_type is None: | |||
encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) | |||
@@ -167,19 +167,15 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label | |||
class ConditionalRandomField(nn.Module): | |||
r""" | |||
条件随机场。提供 forward() 以及 viterbi_decode() 两个方法,分别用于训练与inference。 | |||
条件随机场。提供 :meth:`forward` 以及 :meth:`viterbi_decode` 两个方法,分别用于 **训练** 与 **inference** 。 | |||
:param num_tags: 标签的数量 | |||
:param include_start_end_trans: 是否考虑各个 tag 作为开始以及结尾的分数。 | |||
:param allowed_transitions: 内部的 ``Tuple[from_tag_id(int), to_tag_id(int)]`` 视为允许发生的跃迁,其他没 | |||
有包含的跃迁认为是禁止跃迁,可以通过 :func:`allowed_transitions` 函数得到;如果为 ``None`` ,则所有跃迁均为合法。 | |||
""" | |||
def __init__(self, num_tags:int, include_start_end_trans:bool=False, allowed_transitions:List=None): | |||
r""" | |||
:param num_tags: 标签的数量 | |||
:param include_start_end_trans: 是否考虑各个tag作为开始以及结尾的分数。 | |||
:param allowed_transitions: 内部的Tuple[from_tag_id(int), | |||
to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 | |||
allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 | |||
""" | |||
super(ConditionalRandomField, self).__init__() | |||
self.include_start_end_trans = include_start_end_trans | |||
@@ -213,9 +209,9 @@ class ConditionalRandomField(nn.Module): | |||
r"""Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||
sum of the likelihoods across all possible state sequences. | |||
:param logits:FloatTensor, max_len x batch_size x num_tags | |||
:param mask:ByteTensor, max_len x batch_size | |||
:return:FloatTensor, batch_size | |||
:param logits:FloatTensor, ``[max_len, batch_size, num_tags]`` | |||
:param mask:ByteTensor, ``[max_len, batch_size]`` | |||
:return:FloatTensor, ``[batch_size,]`` | |||
""" | |||
seq_len, batch_size, n_tags = logits.size() | |||
alpha = logits[0] | |||
@@ -239,10 +235,10 @@ class ConditionalRandomField(nn.Module): | |||
def _gold_score(self, logits, tags, mask): | |||
r""" | |||
Compute the score for the gold path. | |||
:param logits: FloatTensor, max_len x batch_size x num_tags | |||
:param tags: LongTensor, max_len x batch_size | |||
:param mask: ByteTensor, max_len x batch_size | |||
:return:FloatTensor, batch_size | |||
:param logits: FloatTensor, ``[max_len, batch_size, num_tags]`` | |||
:param tags: LongTensor, ``[max_len, batch_size]`` | |||
:param mask: ByteTensor, ``[max_len, batch_size]`` | |||
:return:FloatTensor, ``[batch_size.]`` | |||
""" | |||
seq_len, batch_size, _ = logits.size() | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | |||
@@ -265,14 +261,14 @@ class ConditionalRandomField(nn.Module): | |||
# return [B,] | |||
return score | |||
def forward(self, feats, tags, mask): | |||
def forward(self, feats: "torch.FloatTensor", tags: "torch.LongTensor", mask: "torch.ByteTensor") -> "torch.FloatTensor": | |||
r""" | |||
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 | |||
用于计算 ``CRF`` 的前向 loss,返回值为一个形状为 ``[batch_size,]`` 的 :class:`torch.FloatTensor` ,可能需要 :func:`mean` 求得 loss 。 | |||
:param torch.FloatTensor feats: batch_size x max_len x num_tags,特征矩阵。 | |||
:param torch.LongTensor tags: batch_size x max_len,标签矩阵。 | |||
:param torch.ByteTensor mask: batch_size x max_len,为0的位置认为是padding。 | |||
:return: torch.FloatTensor, (batch_size,) | |||
:param feats: 特征矩阵,形状为 ``[batch_size, max_len, num_tags]`` | |||
:param tags: 标签矩阵,形状为 ``[batch_size, max_len]`` | |||
:param mask: 形状为 ``[batch_size, max_len]`` ,为 **0** 的位置认为是 padding。 | |||
:return: ``[batch_size,]`` | |||
""" | |||
feats = feats.transpose(0, 1) | |||
tags = tags.transpose(0, 1).long() | |||
@@ -282,17 +278,20 @@ class ConditionalRandomField(nn.Module): | |||
return all_path_score - gold_path_score | |||
def viterbi_decode(self, logits, mask, unpad=False): | |||
r"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
def viterbi_decode(self, logits: "torch.FloatTensor", mask: "torch.ByteTensor", unpad=False): | |||
r"""给定一个 **特征矩阵** 以及 **转移分数矩阵** ,计算出最佳的路径以及对应的分数 | |||
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | |||
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 | |||
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 | |||
个sample的有效长度。 | |||
:param logits: 特征矩阵,形状为 ``[batch_size, max_len, num_tags]`` | |||
:param mask: 标签矩阵,形状为 ``[batch_size, max_len]`` ,为 **0** 的位置认为是 padding。如果为 ``None`` ,则认为没有 padding。 | |||
:param unpad: 是否将结果删去 padding: | |||
- 为 ``False`` 时,返回的是 ``[batch_size, max_len]`` 的张量 | |||
- 为 ``True`` 时,返回的是 :class:`List` [:class:`List` [ :class:`int` ]], 内部的 :class:`List` [:class:`int` ] 为每个 | |||
sequence 的 label ,已经除去 pad 部分,即每个 :class:`List` [ :class:`int` ] 的长度是这个 sample 的有效长度。 | |||
:return: (paths, scores)。 | |||
paths: 是解码后的路径, 其值参照unpad参数. | |||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||
- ``paths`` -- 解码后的路径, 其值参照 ``unpad`` 参数. | |||
- ``scores`` -- :class:`torch.FloatTensor` ,形状为 ``[batch_size,]`` ,对应每个最优路径的分数。 | |||
""" | |||
batch_size, max_len, n_tags = logits.size() | |||
@@ -1,23 +1,15 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"MLP" | |||
] | |||
from typing import List, Callable, Union | |||
import torch | |||
import torch.nn as nn | |||
class MLP(nn.Module): | |||
r""" | |||
多层感知器 | |||
.. note:: | |||
隐藏层的激活函数通过activation定义。一个str/function或者一个str/function的list可以被传入activation。 | |||
如果只传入了一个str/function,那么所有隐藏层的激活函数都由这个str/function定义; | |||
如果传入了一个str/function的list,那么每一个隐藏层的激活函数由这个list中对应的元素定义,其中list的长度为隐藏层数。 | |||
输出层的激活函数由output_activation定义,默认值为None,此时输出层没有激活函数。 | |||
多层感知器。 | |||
Examples:: | |||
@@ -31,18 +23,20 @@ class MLP(nn.Module): | |||
>>> y = net(x) | |||
>>> print(x) | |||
>>> print(y) | |||
:param size_layer: 一个 int 的列表,用来定义 :class:`MLP` 的层数,列表中的数字为每一层是 hidden 数目。 :class:`MLP` 的层数为 ``len(size_layer) - 1`` | |||
:param activation: 隐藏层的激活函数,可以支持多种类型: | |||
- 一个 :class:`str` 或函数 -- 所有隐藏层的激活函数都为 ``activation`` 代表的函数; | |||
- :class:`str` 或函数的列表 -- 每一个隐藏层的激活函数都为列表中对应的函数,其中列表长度为隐藏层数; | |||
对于字符串类型的输入,支持 ``['relu', 'tanh', 'sigmoid']`` 三种激活函数。 | |||
:param output_activation: 输出层的激活函数。默认值为 ``None``,表示输出层没有激活函数 | |||
:param dropout: dropout 概率 | |||
""" | |||
def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0): | |||
r""" | |||
:param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层是hidden数目。MLP的层数为 len(size_layer) - 1 | |||
:param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和 | |||
sigmoid,默认值为relu | |||
:param Union[str,func] output_activation: 字符串或者函数,用来定义输出层的激活函数,默认值为None,表示输出层没有激活函数 | |||
:param str initial_method: 参数初始化方式 | |||
:param float dropout: dropout概率,默认值为0 | |||
""" | |||
def __init__(self, size_layer: List[int], activation: Union[str, Callable, List[str]]='relu', | |||
output_activation: Union[str, Callable]=None, dropout: float=0.0): | |||
super(MLP, self).__init__() | |||
self.hiddens = nn.ModuleList() | |||
self.output = None | |||
@@ -85,8 +79,8 @@ class MLP(nn.Module): | |||
def forward(self, x): | |||
r""" | |||
:param torch.Tensor x: MLP接受的输入 | |||
:return: torch.Tensor : MLP的输出结果 | |||
:param x: | |||
:return: | |||
""" | |||
for layer, func in zip(self.hiddens, self.hidden_active): | |||
x = self.dropout(func(layer(x))) | |||
@@ -1,4 +1,3 @@ | |||
r"""undocumented""" | |||
from typing import Union, Tuple | |||
import math | |||
@@ -16,54 +15,52 @@ __all__ = ['Seq2SeqDecoder', 'TransformerSeq2SeqDecoder', 'LSTMSeq2SeqDecoder'] | |||
class Seq2SeqDecoder(nn.Module): | |||
""" | |||
Sequence-to-Sequence Decoder的基类。一定需要实现forward、decode函数,剩下的函数根据需要实现。每个Seq2SeqDecoder都应该有相应的State对象 | |||
用来承载该Decoder所需要的Encoder输出、Decoder需要记录的历史信息(例如LSTM的hidden信息)。 | |||
**Sequence-to-Sequence Decoder** 的基类。一定需要实现 :meth:`forward` 和 :meth:`decode` 函数,剩下的函数根据需要实现。每个 ``Seq2SeqDecoder`` 都应该有相应的 | |||
:class:`~fastNLP.modules.torch.decoder.State` 对象用来承载该 ``Decoder`` 所需要的 ``Encoder`` 输出、``Decoder`` 需要记录的历史信(例如 :class:`~fastNLP.modules.torch.encoder.LSTM` | |||
的 hidden 信息)。 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def forward(self, tokens, state, **kwargs): | |||
def forward(self, tokens: "torch.LongTensor", state: State, **kwargs): | |||
""" | |||
:param torch.LongTensor tokens: bsz x max_len | |||
:param State state: state包含了encoder的输出以及decode之前的内容 | |||
:return: 返回值可以为bsz x max_len x vocab_size的Tensor,也可以是一个list,但是第一个元素必须是词的预测分布 | |||
:param tokens: ``[batch_size, max_len]`` | |||
:param state: ``state`` 包含了 ``encoder`` 的输出以及 ``decode`` 之前的内容 | |||
:return: 返回值可以为 ``[batch_size, max_len, vocab_size]`` 的张量,也可以是一个 :class:`list`,但是第一个元素必须是词的预测分布 | |||
""" | |||
raise NotImplemented | |||
def reorder_states(self, indices, states): | |||
def reorder_states(self, indices: torch.LongTensor, states): | |||
""" | |||
根据indices重新排列states中的状态,在beam search进行生成时,会用到该函数。 | |||
根据 ``indices`` 重新排列 ``states`` 中的状态,在 ``beam search`` 进行生成时,会用到该函数。 | |||
:param torch.LongTensor indices: | |||
:param State states: | |||
:return: | |||
:param indices: | |||
:param states: | |||
""" | |||
assert isinstance(states, State), f"`states` should be of type State instead of {type(states)}" | |||
states.reorder_state(indices) | |||
def init_state(self, encoder_output, encoder_mask): | |||
def init_state(self, encoder_output: Union[torch.Tensor, list, tuple], encoder_mask: Union[torch.Tensor, list, tuple]): | |||
""" | |||
初始化一个state对象,用来记录了encoder的输出以及decode已经完成的部分。 | |||
初始化一个 :class:`~fastNLP.modules.torch.decoder.State` 对象,用来记录 ``encoder`` 的输出以及 ``decode`` 已经完成的部分。 | |||
:param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch | |||
:param encoder_output: 如果不为 ``None`` ,内部元素需要为 :class:`torch.Tensor`,默认其中第一维是 batch_size | |||
维度 | |||
:param Union[torch.Tensor, list, tuple] encoder_mask: 如果部位None,内部元素需要torch.Tensor, 默认其中第一维是batch | |||
:param encoder_mask: 如果不为 ``None``,内部元素需要为 :class:`torch.Tensor`,默认其中第一维是 batch_size | |||
维度 | |||
:param kwargs: | |||
:return: State, 返回一个State对象,记录了encoder的输出 | |||
:return: 一个 :class:`~fastNLP.modules.torch.decoder.State` 对象,记录了 ``encoder`` 的输出 | |||
""" | |||
state = State(encoder_output, encoder_mask) | |||
return state | |||
def decode(self, tokens, state): | |||
def decode(self, tokens: torch.LongTensor, state) -> torch.FloatTensor: | |||
""" | |||
根据states中的内容,以及tokens中的内容进行之后的生成。 | |||
根据 ``states`` 中的内容,以及 ``tokens`` 中的内容进行之后的生成。 | |||
:param torch.LongTensor tokens: bsz x max_len, 截止到上一个时刻所有的token输出。 | |||
:param State state: 记录了encoder输出与decoder过去状态 | |||
:return: torch.FloatTensor: bsz x vocab_size, 输出的是下一个时刻的分布 | |||
:param tokens: ``[batch_size, max_len]``,截止到上一个时刻所有的 token 输出。 | |||
:param state: 记录了 ``encoder`` 输出与 ``decoder`` 过去状态 | |||
:return: `下一个时刻的分布,形状为 ``[batch_size, vocab_size]`` | |||
""" | |||
outputs = self(state=state, tokens=tokens) | |||
if isinstance(outputs, torch.Tensor): | |||
@@ -84,8 +81,8 @@ class TiedEmbedding(nn.Module): | |||
def forward(self, x): | |||
""" | |||
:param torch.FloatTensor x: bsz x * x embed_size | |||
:return: torch.FloatTensor bsz x * x vocab_size | |||
:param torch.FloatTensor x: batch_size x * x embed_size | |||
:return: torch.FloatTensor batch_size x * x vocab_size | |||
""" | |||
return torch.matmul(x, self.weight.t()) | |||
@@ -110,18 +107,24 @@ def get_bind_decoder_output_embed(embed): | |||
class LSTMSeq2SeqDecoder(Seq2SeqDecoder): | |||
""" | |||
LSTM的Decoder | |||
:param nn.Module,tuple embed: decoder输入的embedding. | |||
:param int num_layers: 多少层LSTM | |||
:param int hidden_size: 隐藏层大小, 该值也被认为是encoder的输出维度大小 | |||
:param dropout: Dropout的大小 | |||
:param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, | |||
则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. | |||
:param bool attention: 是否使用attention | |||
**LSTM** 的 Decoder | |||
:param embed: ``decoder`` 输入的 embedding,支持以下几种输入类型: | |||
- ``tuple(num_embedings, embedding_dim)``,即 embedding 的大小和每个词的维度,此时将随机初始化一个 :class:`torch.nn.Embedding` 实例; | |||
- :class:`torch.nn.Embedding` 或 **fastNLP** 的 ``Embedding`` 对象,此时就以传入的对象作为 embedding; | |||
- :class:`numpy.ndarray` ,将使用传入的 ndarray 作为 Embedding 初始化; | |||
- :class:`torch.Tensor`,此时将使用传入的值作为 Embedding 初始化; | |||
:param num_layers: LSTM 的层数 | |||
:param hidden_size: 隐藏层大小, 该值也被认为是 ``encoder`` 的输出维度大小 | |||
:param dropout: Dropout 的大小 | |||
:param bind_decoder_input_output_embed: ``decoder`` 的输出 embedding 是否与其输入 embedding 是一样的权重(即为同一个),若 ``embed`` 为 | |||
:class:`~fastNLP.embeddings.StaticEmbedding`,则 ``StaticEmbedding`` 的 ``vocab`` 不能包含 ``no_create_entry`` 的 token ,同时 | |||
``StaticEmbedding`` 初始化时 ``lower`` 为 ``False``,``min_freq=1``。 | |||
:param attention: 是否使用attention | |||
""" | |||
def __init__(self, embed: Union[nn.Module, Tuple[int, int]], num_layers = 3, hidden_size = 300, | |||
dropout = 0.3, bind_decoder_input_output_embed = True, attention=True): | |||
def __init__(self, embed: Union[nn.Module, Tuple[int, int]], num_layers: int = 3, hidden_size: int = 300, | |||
dropout: float = 0.3, bind_decoder_input_output_embed: bool = True, attention: bool = True): | |||
super().__init__() | |||
self.embed = get_embeddings(init_embed=embed) | |||
self.embed_dim = embed.embedding_dim | |||
@@ -141,13 +144,14 @@ class LSTMSeq2SeqDecoder(Seq2SeqDecoder): | |||
self.output_proj = nn.Linear(hidden_size, self.embed_dim) | |||
self.dropout_layer = nn.Dropout(dropout) | |||
def forward(self, tokens, state, return_attention=False): | |||
def forward(self, tokens: torch.LongTensor, state: LSTMState, return_attention: bool=False): | |||
""" | |||
:param torch.LongTensor tokens: batch x max_len | |||
:param LSTMState state: 保存encoder输出和decode状态的State对象 | |||
:param bool return_attention: 是否返回attention的的score | |||
:return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length | |||
:param tokens: ``[batch_size, max_len]`` | |||
:param state: 保存 ``encoder`` 输出和 ``decode`` 状态的 :class:`~fastNLP.modules.torch.decoder.LSTMState` 对象 | |||
:param return_attention: 是否返回 attention 的 score | |||
:return: 形状为 ``[batch_size, max_len, vocab_size]`` 的结果。如果 ``return_attention=True`` 则返回一个元组,一个元素为结果,第二个结果为 | |||
注意力权重,形状为 ``[batch_size, max_len, encode_length]`` | |||
""" | |||
src_output = state.encoder_output | |||
encoder_mask = state.encoder_mask | |||
@@ -196,14 +200,18 @@ class LSTMSeq2SeqDecoder(Seq2SeqDecoder): | |||
return feats, attn_weights | |||
return feats | |||
def init_state(self, encoder_output, encoder_mask) -> LSTMState: | |||
def init_state(self, encoder_output, encoder_mask: torch.ByteTensor) -> LSTMState: | |||
""" | |||
:param encoder_output: 输入可以有两种情况(1) 输入为一个tuple,包含三个内容(encoder_output, (hidden, cell)),其中encoder_output: | |||
bsz x max_len x hidden_size, hidden: bsz x hidden_size, cell:bsz x hidden_size,一般使用LSTMEncoder的最后一层的 | |||
hidden state和cell state来赋值这两个值 | |||
(2) 只有encoder_output: bsz x max_len x hidden_size, 这种情况下hidden和cell使用0初始化 | |||
:param torch.ByteTensor encoder_mask: bsz x max_len, 为0的位置是padding, 用来指示source中哪些不需要attend | |||
:param encoder_output: ``encoder`` 的输出,可以有两种情况: | |||
- 输入一个 :class:`tuple`,包含三个内容 ``(encoder_output, (hidden, cell))``,其中 ``encoder_output`` 形状为 | |||
``[batch_size, max_len, hidden_size]``, ``hidden`` 形状为 ``[batch_size, hidden_size]``, ``cell`` 形状为 | |||
``[batch_size, hidden_size]`` ,一般使用 :class:`~fastNLP.modules.torch.encoder.LSTMSeq2SeqEncoder` 最后一层的 | |||
``hidden state`` 和 ``cell state`` 来赋值这两个值。 | |||
- 只有形状为 ``[batch_size, max_len, hidden_size]`` 的 ``encoder_output``, 这种情况下 ``hidden`` 和 ``cell`` | |||
使用 **0** 初始化。 | |||
:param encoder_mask: ``[batch_size, max_len]]``,为 **0** 的位置是 padding, 用来指示输入中哪些不需要 attend 。 | |||
:return: | |||
""" | |||
if not isinstance(encoder_output, torch.Tensor): | |||
@@ -233,14 +241,15 @@ class LSTMSeq2SeqDecoder(Seq2SeqDecoder): | |||
class TransformerSeq2SeqDecoderLayer(nn.Module): | |||
""" | |||
**Transformer** 的 Decoder 层 | |||
:param int d_model: 输入、输出的维度 | |||
:param int n_head: 多少个head,需要能被d_model整除 | |||
:param int dim_ff: | |||
:param float dropout: | |||
:param int layer_idx: layer的编号 | |||
:param d_model: 输入、输出的维度 | |||
:param n_head: **多头注意力** head 的数目,需要能被 ``d_model`` 整除 | |||
:param dim_ff: FFN 中间映射的维度 | |||
:param dropout: Dropout 的大小 | |||
:param layer_idx: layer的编号 | |||
""" | |||
def __init__(self, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1, layer_idx = None): | |||
def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1, layer_idx: int = None): | |||
super().__init__() | |||
self.d_model = d_model | |||
self.n_head = n_head | |||
@@ -262,14 +271,14 @@ class TransformerSeq2SeqDecoderLayer(nn.Module): | |||
self.final_layer_norm = nn.LayerNorm(self.d_model) | |||
def forward(self, x, encoder_output, encoder_mask=None, self_attn_mask=None, state=None): | |||
def forward(self, x, encoder_output, encoder_mask=None, self_attn_mask=None, state: TransformerState=None): | |||
""" | |||
:param x: (batch, seq_len, dim), decoder端的输入 | |||
:param encoder_output: (batch,src_seq_len,dim), encoder的输出 | |||
:param encoder_mask: batch,src_seq_len, 为1的地方需要attend | |||
:param self_attn_mask: seq_len, seq_len,下三角的mask矩阵,只在训练时传入 | |||
:param TransformerState state: 只在inference阶段传入 | |||
:param x: ``decoder`` 端的输入,形状为 ``[batch_size, seq_len, dim]`` | |||
:param encoder_output: ``encoder`` 的输出,形状为 ``[batch_size, src_seq_len, dim]`` | |||
:param encoder_mask: 掩码,形状为 ``[batch_size, src_seq_len]``,为 **1** 的地方表示需要 attend | |||
:param self_attn_mask: 下三角的mask矩阵,只在训练时传入。形状为 ``[seq_len, seq_len]`` | |||
:param state: 只在 inference 阶段传入,记录了 ``encoder`` 和 ``decoder`` 的状态 | |||
:return: | |||
""" | |||
@@ -307,16 +316,23 @@ class TransformerSeq2SeqDecoderLayer(nn.Module): | |||
class TransformerSeq2SeqDecoder(Seq2SeqDecoder): | |||
""" | |||
:param embed: 输入token的embedding | |||
:param nn.Module pos_embed: 位置embedding | |||
:param int d_model: 输出、输出的大小 | |||
:param int num_layers: 多少层 | |||
:param int n_head: 多少个head | |||
:param int dim_ff: FFN 的中间大小 | |||
:param float dropout: Self-Attention和FFN中的dropout的大小 | |||
:param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, | |||
则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. | |||
**Transformer** 的 Decoder | |||
:param embed: ``decoder`` 输入的 embedding,支持以下几种输入类型: | |||
- ``tuple(num_embedings, embedding_dim)``,即 embedding 的大小和每个词的维度,此时将随机初始化一个 :class:`torch.nn.Embedding` 实例; | |||
- :class:`torch.nn.Embedding` 或 **fastNLP** 的 ``Embedding`` 对象,此时就以传入的对象作为 embedding; | |||
- :class:`numpy.ndarray` ,将使用传入的 ndarray 作为 Embedding 初始化; | |||
- :class:`torch.Tensor`,此时将使用传入的值作为 Embedding 初始化; | |||
:param pos_embed: 位置 embedding | |||
:param d_model: 输入、输出的维度 | |||
:param num_layers: :class:`TransformerSeq2SeqDecoderLayer` 的层数 | |||
:param n_head: **多头注意力** head 的数目,需要能被 ``d_model`` 整除 | |||
:param dim_ff: FFN 中间映射的维度 | |||
:param dropout: :class:`~fastNLP.modules.torch.decoder.SelfAttention` 和 FFN 中的 dropout 的大小 | |||
:param bind_decoder_input_output_embed: ``decoder`` 的输出 embedding 是否与其输入 embedding 是一样的权重(即为同一个),若 ``embed`` 为 | |||
:class:`~fastNLP.embeddings.StaticEmbedding`,则 ``StaticEmbedding`` 的 ``vocab`` 不能包含 ``no_create_entry`` 的 token ,同时 | |||
``StaticEmbedding`` 初始化时 ``lower`` 为 ``False``,``min_freq=1``。 | |||
""" | |||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed: nn.Module = None, | |||
d_model = 512, num_layers=6, n_head = 8, dim_ff = 2048, dropout = 0.1, | |||
@@ -346,13 +362,14 @@ class TransformerSeq2SeqDecoder(Seq2SeqDecoder): | |||
self.layer_norm = nn.LayerNorm(d_model) | |||
self.output_fc = nn.Linear(self.d_model, self.embed.embedding_dim) | |||
def forward(self, tokens, state, return_attention=False): | |||
def forward(self, tokens: torch.LongTensor, state: TransformerState, return_attention: bool=False): | |||
""" | |||
:param torch.LongTensor tokens: batch x tgt_len,decode的词 | |||
:param TransformerState state: 用于记录encoder的输出以及decode状态的对象,可以通过init_state()获取 | |||
:param bool return_attention: 是否返回对encoder结果的attention score | |||
:return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length | |||
:param tokens: 用于解码的词,形状为 ``[batch_size, tgt_len]`` | |||
:param state: 用于记录 ``encoder`` 的输出以及 ``decode`` 状态的对象,可以通过 :meth:`init_state` 获取 | |||
:param return_attention: 是否返回对 ``encoder`` 结果的 attention score | |||
:return: 形状为 ``[batch_size, max_len, vocab_size]`` 的结果。如果 ``return_attention=True`` 则返回一个元组,一个元素为结果,第二个结果为 | |||
注意力权重,形状为 ``[batch_size, max_len, encode_length]`` | |||
""" | |||
encoder_output = state.encoder_output | |||
@@ -391,13 +408,13 @@ class TransformerSeq2SeqDecoder(Seq2SeqDecoder): | |||
return feats, attn_weight | |||
return feats | |||
def init_state(self, encoder_output, encoder_mask): | |||
def init_state(self, encoder_output: torch.FloatTensor, encoder_mask: torch.ByteTensor) -> TransformerState: | |||
""" | |||
初始化一个TransformerState用于forward | |||
初始化一个 :class:`~fastNLP.modules.torch.decoder.TransformerState`` 用于 :meth:`forward` | |||
:param torch.FloatTensor encoder_output: bsz x max_len x d_model, encoder的输出 | |||
:param torch.ByteTensor encoder_mask: bsz x max_len, 为1的位置需要attend。 | |||
:return: TransformerState | |||
:param encoder_output: ``encoder`` 的输出,形状为 ``[batch_size, max_len, d_model]`` | |||
:param encoder_mask: ``[batch_size, max_len]]``,为 **0** 的位置是 padding, 用来指示输入中哪些不需要 attend 。 | |||
:return: | |||
""" | |||
if isinstance(encoder_output, torch.Tensor): | |||
encoder_output = encoder_output | |||
@@ -9,21 +9,22 @@ __all__ = [ | |||
"TransformerState" | |||
] | |||
from typing import Union | |||
from typing import Union, List, Tuple | |||
import torch | |||
class State: | |||
def __init__(self, encoder_output=None, encoder_mask=None, **kwargs): | |||
""" | |||
每个Decoder都有对应的State对象用来承载encoder的输出以及当前时刻之前的decode状态。 | |||
:param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch | |||
维度 | |||
:param Union[torch.Tensor, list, tuple] encoder_mask: 如果部位None,内部元素需要torch.Tensor, 默认其中第一维是batch | |||
维度 | |||
:param kwargs: | |||
""" | |||
""" | |||
每个 ``Decoder`` 都有对应的 :class:`State` 对象用来承载 ``encoder`` 的输出以及当前时刻之前的 ``decode`` 状态。 | |||
:param encoder_output: 如果不为 ``None`` ,内部元素需要为 :class:`torch.Tensor`,默认其中第一维是 ``batch_size`` | |||
维度 | |||
:param encoder_mask: 如果部位 ``None``,内部元素需要为 :class:`torch.Tensor`,默认其中第一维是 ``batch_size`` | |||
维度 | |||
:param kwargs: | |||
""" | |||
def __init__(self, encoder_output: Union[torch.Tensor, List, Tuple]=None, | |||
encoder_mask: Union[torch.Tensor, List, Tuple]=None, **kwargs): | |||
self.encoder_output = encoder_output | |||
self.encoder_mask = encoder_mask | |||
self._decode_length = 0 | |||
@@ -31,9 +32,7 @@ class State: | |||
@property | |||
def num_samples(self): | |||
""" | |||
返回的State中包含的是多少个sample的encoder状态,主要用于Generate的时候确定batch的大小。 | |||
:return: | |||
返回的 State 中包含的是多少个 sample 的 encoder 状态,主要用于 Generate 的时候确定 batch_size 的大小。 | |||
""" | |||
if self.encoder_output is not None: | |||
return self.encoder_output.size(0) | |||
@@ -43,9 +42,7 @@ class State: | |||
@property | |||
def decode_length(self): | |||
""" | |||
当前Decode到哪个token了,decoder只会从decode_length之后的token开始decode, 为0说明还没开始decode。 | |||
:return: | |||
当前 Decode 到哪个 token 了,decoder 只会从 decode_length 之后的 token 开始 decode, 为 **0** 说明还没开始 decode。 | |||
""" | |||
return self._decode_length | |||
@@ -79,26 +76,27 @@ class State: | |||
class LSTMState(State): | |||
def __init__(self, encoder_output, encoder_mask, hidden, cell): | |||
""" | |||
LSTMDecoder对应的State,保存encoder的输出以及LSTM解码过程中的一些中间状态 | |||
:param torch.FloatTensor encoder_output: bsz x src_seq_len x encode_output_size,encoder的输出 | |||
:param torch.BoolTensor encoder_mask: bsz x src_seq_len, 为0的地方是padding | |||
:param torch.FloatTensor hidden: num_layers x bsz x hidden_size, 上个时刻的hidden状态 | |||
:param torch.FloatTensor cell: num_layers x bsz x hidden_size, 上个时刻的cell状态 | |||
""" | |||
""" | |||
:class:`~fastNLP.modules.torch.decoder.LSTMSeq2SeqDecoder` 对应的 :class:`State`,保存 ``encoder`` 的输出以及 ``LSTM`` 解码过程中的一些中间状态 | |||
:param encoder_output: ``encoder`` 的输出,形状为 ``[batch_size, src_seq_len, encode_output_size]`` | |||
:param encoder_mask: 掩码,形状为 ``[batch_size, src_seq_len]``,为 **1** 的地方表示需要 attend | |||
:param hidden: 上个时刻的 ``hidden`` 状态,形状为 ``[num_layers, batch_size, hidden_size]`` | |||
:param cell: 上个时刻的 ``cell`` 状态,形状为 ``[num_layers, batch_size, hidden_size]`` | |||
""" | |||
def __init__(self, encoder_output: torch.FloatTensor, encoder_mask: torch.BoolTensor, hidden: torch.FloatTensor, cell: torch.FloatTensor): | |||
super().__init__(encoder_output, encoder_mask) | |||
self.hidden = hidden | |||
self.cell = cell | |||
self._input_feed = hidden[0] # 默认是上一个时刻的输出 | |||
@property | |||
def input_feed(self): | |||
def input_feed(self) -> torch.FloatTensor: | |||
""" | |||
LSTMDecoder中每个时刻的输入会把上个token的embedding和input_feed拼接起来输入到下个时刻,在LSTMDecoder不使用attention时, | |||
input_feed即上个时刻的hidden state, 否则是attention layer的输出。 | |||
:return: torch.FloatTensor, bsz x hidden_size | |||
:class:`~fastNLP.modules.torch.decoder.LSTMSeq2SeqDecoder` 中每个时刻的输入会把上个 token 的 embedding 和 ``input_feed`` 拼接起来输入到下个时刻,在 | |||
:class:`~fastNLP.modules.torch.decoder.LSTMSeq2SeqDecoder` 不使用 ``attention`` 时,``input_feed`` 即上个时刻的 ``hidden state``,否则是 ``attention layer`` 的输出。 | |||
:return: ``[batch_size, hidden_size]`` | |||
""" | |||
return self._input_feed | |||
@@ -115,14 +113,14 @@ class LSTMState(State): | |||
class TransformerState(State): | |||
def __init__(self, encoder_output, encoder_mask, num_decoder_layer): | |||
""" | |||
与TransformerSeq2SeqDecoder对应的State, | |||
:param torch.FloatTensor encoder_output: bsz x encode_max_len x encoder_output_size, encoder的输出 | |||
:param torch.ByteTensor encoder_mask: bsz x encode_max_len 为1的地方需要attend | |||
:param int num_decoder_layer: decode有多少层 | |||
""" | |||
""" | |||
与 :class:`~fastNLP.modules.torch.decoder.TransformerSeq2SeqDecoder` 对应的 :class:`State`。 | |||
:param encoder_output: ``encoder`` 的输出,形状为 ``[batch_size, encode_max_len, encode_output_size]``, | |||
:param encoder_mask: 掩码,形状为 ``[batch_size, encode_max_len]``,为 **1** 的地方表示需要 attend | |||
:param num_decoder_layer: decoder 层数 | |||
""" | |||
def __init__(self, encoder_output: torch.FloatTensor, encoder_mask: torch.FloatTensor, num_decoder_layer: int): | |||
super().__init__(encoder_output, encoder_mask) | |||
self.encoder_key = [None] * num_decoder_layer # 每一个元素 bsz x encoder_max_len x key_dim | |||
self.encoder_value = [None] * num_decoder_layer # 每一个元素 bsz x encoder_max_len x value_dim | |||
@@ -1,5 +1,3 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"TimestepDropout" | |||
] | |||
@@ -9,8 +7,8 @@ import torch | |||
class TimestepDropout(torch.nn.Dropout): | |||
r""" | |||
传入参数的shape为 ``(batch_size, num_timesteps, embedding_dim)`` | |||
使用同一个shape为 ``(batch_size, embedding_dim)`` 的mask在每个timestamp上做dropout。 | |||
传入参数的 shape 为 ``(batch_size, num_timesteps, embedding_dim)`` | |||
使用同一个 shape 为 ``(batch_size, embedding_dim)`` 的 mask 在每个 timestamp 上做 dropout。 | |||
""" | |||
def forward(self, x): | |||
@@ -1,8 +1,7 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"ConvMaxpool" | |||
] | |||
from typing import Union, List | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
@@ -10,20 +9,17 @@ import torch.nn.functional as F | |||
class ConvMaxpool(nn.Module): | |||
r""" | |||
集合了Convolution和Max-Pooling于一体的层。给定一个batch_size x max_len x input_size的输入,返回batch_size x | |||
sum(output_channels) 大小的matrix。在内部,是先使用CNN给输入做卷积,然后经过activation激活层,在通过在长度(max_len) | |||
这一维进行max_pooling。最后得到每个sample的一个向量表示。 | |||
集合了 **Convolution** 和 **Max-Pooling** 于一体的层。给定一个 ``[batch_size, max_len, input_size]`` 的输入,返回 | |||
``[batch_size, sum(output_channels)]`` 大小的 matrix。在内部,是先使用 ``CNN`` 给输入做卷积,然后经过 activation | |||
激活层,在通过在长度(max_len)这一维进行 ``max_pooling`` 。最后得到每个 sample 的一个向量表示。 | |||
:param in_channels: 输入 channel 的大小,一般是 embedding 的维度,或 ``encoder``的 output 维度 | |||
:param out_channels: 输出 channel 的数量。如果为 :class:`list`,则需要与 ``kernel_sizes`` 的数量保持一致 | |||
:param kernel_sizes: 输出 channel 的 kernel 大小。 | |||
:param activation: **卷积** 后的结果将通过该 ``activation`` 后再经过 ``max-pooling``。支持 ``['relu', 'sigmoid', 'tanh']``。 | |||
""" | |||
def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"): | |||
r""" | |||
:param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度 | |||
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 | |||
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 | |||
:param str activation: Convolution后的结果将通过该activation后再经过max-pooling。支持relu, sigmoid, tanh | |||
""" | |||
def __init__(self, in_channels: int, out_channels: Union[int, List[int]], kernel_sizes: Union[int, List[int]], activation: str="relu"): | |||
super(ConvMaxpool, self).__init__() | |||
for kernel_size in kernel_sizes: | |||
@@ -67,11 +63,11 @@ class ConvMaxpool(nn.Module): | |||
raise Exception( | |||
"Undefined activation function: choose from: relu, tanh, sigmoid") | |||
def forward(self, x, mask=None): | |||
def forward(self, x: torch.FloatTensor, mask=None): | |||
r""" | |||
:param torch.FloatTensor x: batch_size x max_len x input_size, 一般是经过embedding后的值 | |||
:param mask: batch_size x max_len, pad的地方为0。不影响卷积运算,max-pool一定不会pool到pad为0的位置 | |||
:param x: ``[batch_size, max_len, input_size]``,一般是经过 ``embedding`` 后的值 | |||
:param mask: ``[batch_size, max_len]``,**0** 的位置表示 padding,不影响卷积运算,``max-pooling`` 一定不会 pool 到 padding 为 0 的位置 | |||
:return: | |||
""" | |||
# [N,L,C] -> [N,C,L] | |||
@@ -1,6 +1,6 @@ | |||
r"""undocumented | |||
轻量封装的 Pytorch LSTM 模块. | |||
可在 forward 时传入序列的长度, 自动对padding做合适的处理. | |||
r""" | |||
轻量封装的 **Pytorch LSTM** 模块. | |||
可在 :meth:`forward` 时传入序列的长度, 自动对 padding 做合适的处理. | |||
""" | |||
__all__ = [ | |||
@@ -20,23 +20,21 @@ else: | |||
class LSTM(Module): | |||
r""" | |||
LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | |||
为1; 且可以应对DataParallel中LSTM的使用问题。 | |||
**LSTM** 模块,轻量封装的 **Pytorch LSTM** 。在提供 ``seq_len`` 的情况下,将自动使用 ``pack_padded_sequence``;同时默认将 ``forget gate`` | |||
的 bias 初始化为 **1**,且可以应对 :class:`DataParallel` 中 LSTM 的使用问题。 | |||
:param input_size: 输入 `x` 的特征维度 | |||
:param hidden_size: 隐状态 `h` 的特征维度. 如果 ``bidirectional`` 为 ``True``,则输出的维度会是 ``hidde_size*2`` | |||
:param num_layers: rnn 的层数 | |||
:param dropout: 层间 dropout 概率 | |||
:param bidirectional: 若为 ``True``, 使用双向的 RNN | |||
:param batch_first: 若为 ``True``, 输入和输出 :class:`torch.Tensor` 形状为 ``[batch_size, seq_len, feature]``,否则为 | |||
``[seq_len, batch_size, features]`` | |||
:param bias: 如果为 ``False``, 模型将不会使用 bias | |||
""" | |||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | |||
bidirectional=False, bias=True): | |||
r""" | |||
:param input_size: 输入 `x` 的特征维度 | |||
:param hidden_size: 隐状态 `h` 的特征维度. 如果bidirectional为True,则输出的维度会是hidde_size*2 | |||
:param num_layers: rnn的层数. Default: 1 | |||
:param dropout: 层间dropout概率. Default: 0 | |||
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` | |||
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||
:(batch, seq, feature). Default: ``False`` | |||
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||
""" | |||
super(LSTM, self).__init__() | |||
self.batch_first = batch_first | |||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | |||
@@ -56,12 +54,12 @@ class LSTM(Module): | |||
def forward(self, x, seq_len=None, h0=None, c0=None): | |||
r""" | |||
:param x: [batch, seq_len, input_size] 输入序列 | |||
:param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` | |||
:param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` | |||
:param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` | |||
:return (output, (ht, ct)): output: [batch, seq_len, hidden_size*num_direction] 输出序列 | |||
和 ht,ct: [num_layers*num_direction, batch, hidden_size] 最后时刻隐状态. | |||
:param x: 输入序列,形状为 ``[batch_size, seq_len, input_size]`` | |||
:param seq_len: 序列长度,形状为 ``[batch_size, ]``,若为 ``None``,表示所有输入看做一样长 | |||
:param h0: 初始隐状态,形状为 ``[batch_size, hidden_size]``,若为 ``None`` ,设为全 **0** 向量 | |||
:param c0: 初始 ``Cell`` 状态,形状为 ``[batch_size, hidden_size]``,若为 ``None`` ,设为全 **0** 向量 | |||
:return: 返回 ``(output, (ht, ct))`` 格式的结果。``output`` 形状为 ``[batch_size, seq_len, hidden_size*num_direction]``,表示输出序列; | |||
``ht`` 和 ``ct`` 形状为 ``[num_layers*num_direction, batch_size, hidden_size]``,表示最后时刻隐状态。 | |||
""" | |||
batch_size, max_len, _ = x.size() | |||
if h0 is not None and c0 is not None: | |||
@@ -1,4 +1,3 @@ | |||
r"""undocumented""" | |||
import torch.nn as nn | |||
import torch | |||
from torch.nn import LayerNorm | |||
@@ -17,17 +16,17 @@ __all__ = ['Seq2SeqEncoder', 'TransformerSeq2SeqEncoder', 'LSTMSeq2SeqEncoder'] | |||
class Seq2SeqEncoder(nn.Module): | |||
""" | |||
所有Sequence2Sequence Encoder的基类。需要实现forward函数 | |||
所有 **Sequence2Sequence Encoder** 的基类。需要实现 :meth:`forward` 函数 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def forward(self, tokens, seq_len): | |||
def forward(self, tokens: torch.LongTensor, seq_len: torch.LongTensor): | |||
""" | |||
:param torch.LongTensor tokens: bsz x max_len, encoder的输入 | |||
:param torch.LongTensor seq_len: bsz | |||
:param tokens: ``[batch_size, max_len]]``,encoder 的输入 | |||
:param seq_len: ``[batch_size,]`` | |||
:return: | |||
""" | |||
raise NotImplementedError | |||
@@ -35,7 +34,7 @@ class Seq2SeqEncoder(nn.Module): | |||
class TransformerSeq2SeqEncoderLayer(nn.Module): | |||
""" | |||
Self-Attention的Layer, | |||
**Self-Attention** 的 Layer, | |||
:param int d_model: input和output的输出维度 | |||
:param int n_head: 多少个head,每个head的维度为d_model/n_head | |||
@@ -63,8 +62,8 @@ class TransformerSeq2SeqEncoderLayer(nn.Module): | |||
def forward(self, x, mask): | |||
""" | |||
:param x: batch x src_seq x d_model | |||
:param mask: batch x src_seq,为0的地方为padding | |||
:param x: batch_size, src_seq, d_model | |||
:param mask: batch_size, src_seq,为0的地方为padding | |||
:return: | |||
""" | |||
# attention | |||
@@ -88,18 +87,23 @@ class TransformerSeq2SeqEncoderLayer(nn.Module): | |||
class TransformerSeq2SeqEncoder(Seq2SeqEncoder): | |||
""" | |||
基于Transformer的Encoder | |||
:param embed: encoder输入token的embedding | |||
:param nn.Module pos_embed: position embedding | |||
:param int num_layers: 多少层的encoder | |||
:param int d_model: 输入输出的维度 | |||
:param int n_head: 多少个head | |||
:param int dim_ff: FFN中间的维度大小 | |||
:param float dropout: Attention和FFN的dropout大小 | |||
基于 **Transformer** 的 :class:`Encoder` | |||
:param embed: ``decoder`` 输入的 embedding,支持以下几种输入类型: | |||
- ``tuple(num_embedings, embedding_dim)``,即 embedding 的大小和每个词的维度,此时将随机初始化一个 :class:`torch.nn.Embedding` 实例; | |||
- :class:`torch.nn.Embedding` 或 **fastNLP** 的 ``Embedding`` 对象,此时就以传入的对象作为 embedding; | |||
- :class:`numpy.ndarray` ,将使用传入的 ndarray 作为 Embedding 初始化; | |||
- :class:`torch.Tensor`,此时将使用传入的值作为 Embedding 初始化; | |||
:param pos_embed: 位置 embedding | |||
:param d_model: 输入、输出的维度 | |||
:param num_layers: :class:`TransformerSeq2SeqDecoderLayer` 的层数 | |||
:param n_head: **多头注意力** head 的数目,需要能被 ``d_model`` 整除 | |||
:param dim_ff: FFN 中间映射的维度 | |||
:param dropout: :class:`~fastNLP.modules.torch.decoder.SelfAttention` 和 FFN 中的 dropout 的大小 | |||
""" | |||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed = None, | |||
num_layers = 6, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1): | |||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed: nn.Module = None, | |||
d_model: int = 512, num_layers: int = 6, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1): | |||
super(TransformerSeq2SeqEncoder, self).__init__() | |||
self.embed = get_embeddings(embed) | |||
self.embed_scale = math.sqrt(d_model) | |||
@@ -118,9 +122,10 @@ class TransformerSeq2SeqEncoder(Seq2SeqEncoder): | |||
def forward(self, tokens, seq_len): | |||
""" | |||
:param tokens: batch x max_len | |||
:param seq_len: [batch] | |||
:return: bsz x max_len x d_model, bsz x max_len(为0的地方为padding) | |||
:param tokens: 输入序列,形状为 ``[batch_size, max_len]`` | |||
:param seq_len: 序列长度,形状为 ``[batch_size, ]``,若为 ``None``,表示所有输入看做一样长 | |||
:return: 一个元组,第一个元素形状为 ``[batch_size, max_len, d_model]`` 表示前向传播的结果,第二个元素形状为 | |||
``[batch_size, max_len]``, 表示产生的掩码 ``encoder_mask``,为 **0** 的地方为 padding | |||
""" | |||
x = self.embed(tokens) * self.embed_scale # batch, seq, dim | |||
batch_size, max_src_len, _ = x.size() | |||
@@ -145,16 +150,21 @@ class TransformerSeq2SeqEncoder(Seq2SeqEncoder): | |||
class LSTMSeq2SeqEncoder(Seq2SeqEncoder): | |||
""" | |||
LSTM的Encoder | |||
:param embed: encoder的token embed | |||
:param int num_layers: 多少层 | |||
:param int hidden_size: LSTM隐藏层、输出的大小 | |||
:param float dropout: LSTM层之间的Dropout是多少 | |||
:param bool bidirectional: 是否使用双向 | |||
**LSTM** 的 Encoder | |||
:param embed: ``decoder`` 输入的 embedding,支持以下几种输入类型: | |||
- ``tuple(num_embedings, embedding_dim)``,即 embedding 的大小和每个词的维度,此时将随机初始化一个 :class:`torch.nn.Embedding` 实例; | |||
- :class:`torch.nn.Embedding` 或 **fastNLP** 的 ``Embedding`` 对象,此时就以传入的对象作为 embedding; | |||
- :class:`numpy.ndarray` ,将使用传入的 ndarray 作为 Embedding 初始化; | |||
- :class:`torch.Tensor`,此时将使用传入的值作为 Embedding 初始化; | |||
:param num_layers: LSTM 的层数 | |||
:param hidden_size: 隐藏层大小, 该值也被认为是 ``encoder`` 的输出维度大小 | |||
:param dropout: Dropout 的大小 | |||
:param bidirectional: 是否使用双向 | |||
""" | |||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, | |||
hidden_size = 400, dropout = 0.3, bidirectional=True): | |||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers: int = 3, | |||
hidden_size: int = 400, dropout: float = 0.3, bidirectional: bool=True): | |||
super().__init__() | |||
self.embed = get_embeddings(embed) | |||
self.num_layers = num_layers | |||
@@ -165,15 +175,17 @@ class LSTMSeq2SeqEncoder(Seq2SeqEncoder): | |||
self.lstm = LSTM(input_size=embed.embedding_dim, hidden_size=hidden_size, bidirectional=bidirectional, | |||
batch_first=True, dropout=dropout if num_layers>1 else 0, num_layers=num_layers) | |||
def forward(self, tokens, seq_len): | |||
def forward(self, tokens: torch.LongTensor, seq_len: torch.LongTensor): | |||
""" | |||
:param torch.LongTensor tokens: bsz x max_len | |||
:param torch.LongTensor seq_len: bsz | |||
:return: (output, (hidden, cell)), encoder_mask | |||
output: bsz x max_len x hidden_size, | |||
hidden,cell: batch_size x hidden_size, 最后一层的隐藏状态或cell状态 | |||
encoder_mask: bsz x max_len, 为0的地方是padding | |||
:param tokens: 输入序列,形状为 ``[batch_size, max_len]`` | |||
:param seq_len: 序列长度,形状为 ``[batch_size, ]``,若为 ``None``,表示所有输入看做一样长 | |||
:return: 返回 ``((output, (ht, ct)), encoder_mask)`` 格式的结果。 | |||
- ``output`` 形状为 ``[batch_size, seq_len, hidden_size*num_direction]``,表示输出序列; | |||
- ``ht`` 和 ``ct`` 形状为 ``[num_layers*num_direction, batch_size, hidden_size]``,表示最后时刻隐状态; | |||
- ``encoder_mask`` 形状为 ``[batch_size, max_len]``, 表示产生的掩码 ``encoder_mask``,为 **0** 的地方为 padding | |||
""" | |||
x = self.embed(tokens) | |||
device = x.device | |||
@@ -1,5 +1,5 @@ | |||
r"""undocumented | |||
Star-Transformer 的encoder部分的 Pytorch 实现 | |||
r""" | |||
**Star-Transformer** 的 encoder 部分的 Pytorch 实现 | |||
""" | |||
__all__ = [ | |||
@@ -14,24 +14,19 @@ from torch.nn import functional as F | |||
class StarTransformer(nn.Module): | |||
r""" | |||
Star-Transformer 的encoder部分。 输入3d的文本输入, 返回相同长度的文本编码 | |||
paper: https://arxiv.org/abs/1902.09113 | |||
**Star-Transformer** 的 encoder 部分。输入 3d 的文本输入,返回相同长度的文本编码。 | |||
基于论文 `Star-Transformer <https://arxiv.org/abs/1902.09113>`_ | |||
:param hidden_size: 输入维度的大小,同时也是输出维度的大小。 | |||
:param num_layers: **Star-Transformer** 的层数 | |||
:param num_head: **多头注意力** head 的数目,需要能被 ``d_model`` 整除 | |||
:param head_dim: 每个 ``head`` 的维度大小。 | |||
:param dropout: dropout 概率 | |||
:param max_len: 如果为 :class:`int` 表示输入序列的最大长度,模型会为输入序列加上 ``position embedding``; | |||
若为 ``None`` 则会跳过此步骤。 | |||
""" | |||
def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | |||
r""" | |||
:param int hidden_size: 输入维度的大小。同时也是输出维度的大小。 | |||
:param int num_layers: star-transformer的层数 | |||
:param int num_head: head的数量。 | |||
:param int head_dim: 每个head的维度大小。 | |||
:param float dropout: dropout 概率. Default: 0.1 | |||
:param int max_len: int or None, 如果为int,输入序列的最大长度, | |||
模型会为输入序列加上position embedding。 | |||
若为`None`,忽略加上position embedding的步骤. Default: `None` | |||
""" | |||
def __init__(self, hidden_size: int, num_layers: int, num_head: int, head_dim: int, dropout: float=0.1, max_len: int=None): | |||
super(StarTransformer, self).__init__() | |||
self.iters = num_layers | |||
@@ -50,14 +45,12 @@ class StarTransformer(nn.Module): | |||
else: | |||
self.pos_emb = None | |||
def forward(self, data, mask): | |||
def forward(self, data: torch.FloatTensor, mask: torch.ByteTensor): | |||
r""" | |||
:param FloatTensor data: [batch, length, hidden] 输入的序列 | |||
:param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, | |||
否则为 1 | |||
:return: [batch, length, hidden] 编码后的输出序列 | |||
[batch, hidden] 全局 relay 节点, 详见论文 | |||
:param data: 输入序列,形状为 ``[batch_size, length, hidden]`` | |||
:param mask: 输入序列的 padding mask, 形状为 ``[batch_size, length]`` , 为 **0** 的地方为 padding | |||
:return: 返回一个元组,第一个元素形状为 ``[batch_size, length, hidden]`` ,代表编码后的输出序列; | |||
第二个元素形状为 ``[batch_size, hidden]``,表示全局 relay 节点, 详见论文。 | |||
""" | |||
def norm_func(f, x): | |||
@@ -1,5 +1,3 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"TransformerEncoder" | |||
] | |||
@@ -11,18 +9,15 @@ from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer | |||
class TransformerEncoder(nn.Module): | |||
r""" | |||
transformer的encoder模块,不包含embedding层 | |||
**Transformer** 的 encoder 模块,不包含 embedding 层。 | |||
:param num_layers: **TransformerEncoder** 的层数。 | |||
:param d_model: 输入维度的大小,同时也是输出维度的大小。 | |||
:param n_head: **多头注意力** head 的数目,需要能被 ``d_model`` 整除 | |||
:param dim_ff: FFN 中间映射的维度 | |||
:param dropout: :class:`~fastNLP.modules.torch.decoder.SelfAttention` 和 FFN 中的 dropout 的大小 | |||
""" | |||
def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1): | |||
""" | |||
:param int num_layers: 多少层Transformer | |||
:param int d_model: input和output的大小 | |||
:param int n_head: 多少个head | |||
:param int dim_ff: FFN中间hidden大小 | |||
:param float dropout: 多大概率drop attention和ffn中间的表示 | |||
""" | |||
def __init__(self, num_layers: int, d_model: int=512, n_head: int=8, dim_ff: int=2048, dropout: float=0.1): | |||
super(TransformerEncoder, self).__init__() | |||
self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff, | |||
dropout = dropout) for _ in range(num_layers)]) | |||
@@ -30,10 +25,10 @@ class TransformerEncoder(nn.Module): | |||
def forward(self, x, seq_mask=None): | |||
r""" | |||
:param x: [batch, seq_len, model_size] 输入序列 | |||
:param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. 为1的地方需要attend | |||
Default: ``None`` | |||
:return: [batch, seq_len, model_size] 输出序列 | |||
:param x: 输入序列,形状为 ``[batch_size, seq_len, d_model]`` | |||
:param seq_mask: 输入序列的 padding mask ,形状为 ``[batch, seq_len]``,若为 ``None``,则生成全 **1** 向量;为 **1** | |||
的地方表示需要 attend 。 | |||
:return: 输出序列,形状为 ``[batch, seq_len, d_model]`` | |||
""" | |||
output = x | |||
if seq_mask is None: | |||
@@ -1,5 +1,5 @@ | |||
r"""undocumented | |||
Variational RNN 及相关模型的 fastNLP实现,相关论文参考: | |||
r""" | |||
**Variational RNN** 及相关模型的 **fastNLP** 实现,相关论文参考: | |||
`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
""" | |||
@@ -227,77 +227,86 @@ class VarLSTM(VarRNNBase): | |||
Variational Dropout LSTM. | |||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
:param input_size: 输入 `x` 的特征维度。 | |||
:param hidden_size: 隐状态 `h` 的特征维度。 | |||
:param num_layers: rnn的层数,默认为 1。 | |||
:param bias: 如果为 ``False``,模型将不会使用bias。默认为 ``True``。 | |||
:param batch_first: 若为 ``True``,输入和输出 ``Tensor`` 形状为 | |||
``[batch_size, seq, feature]``,否则为 ``[seq, batch_size, feature]``。 | |||
:param input_dropout: 对输入的 dropout 概率。默认为 0。 | |||
:param hidden_dropout: 对每个隐状态的 dropout 概率。默认为 0。 | |||
:param bidirectional: 若为 ``True``,用双向的 LSTM。默认为 ``False``。 | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
r""" | |||
:param input_size: 输入 `x` 的特征维度 | |||
:param hidden_size: 隐状态 `h` 的特征维度 | |||
:param num_layers: rnn的层数. Default: 1 | |||
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||
(batch, seq, feature). Default: ``False`` | |||
:param input_dropout: 对输入的dropout概率. Default: 0 | |||
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 | |||
:param bidirectional: 若为 ``True``, 使用双向的LSTM. Default: ``False`` | |||
""" | |||
super(VarLSTM, self).__init__( | |||
mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) | |||
def forward(self, x, hx=None): | |||
""" | |||
:param x: 输入序列 ``[batch_size, seq_len, input_size]``。 | |||
:param hx: 初始隐状态 ``[batch_size, hidden_size]``,若为 ``None`` 则初始化为全 **1** 向量。 | |||
:return: ``(output, ht)`` 格式的结果: ``output`` 形状为 ``[batch_size, seq_len, hidden_size*num_direction]``, | |||
表示输出序列,``ht`` 形状为 ``[batch_size, hidden_size*num_direction]``,表示最后时刻隐状态。 | |||
""" | |||
return super(VarLSTM, self).forward(x, hx) | |||
class VarRNN(VarRNNBase): | |||
r""" | |||
Variational Dropout RNN. | |||
**Variational Dropout RNN**。 | |||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
:param input_size: 输入 `x` 的特征维度。 | |||
:param hidden_size: 隐状态 `h` 的特征维度。 | |||
:param num_layers: rnn的层数,默认为 1。 | |||
:param bias: 如果为 ``False``,模型将不会使用bias。默认为 ``True``。 | |||
:param batch_first: 若为 ``True``,输入和输出 ``Tensor`` 形状为 | |||
``[batch_size, seq, feature]``,否则为 ``[seq, batch_size, feature]``。 | |||
:param input_dropout: 对输入的 dropout 概率。默认为 0。 | |||
:param hidden_dropout: 对每个隐状态的 dropout 概率。默认为 0。 | |||
:param bidirectional: 若为 ``True``,用双向的 RNN。默认为 ``False``。 | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
r""" | |||
:param input_size: 输入 `x` 的特征维度 | |||
:param hidden_size: 隐状态 `h` 的特征维度 | |||
:param num_layers: rnn的层数. Default: 1 | |||
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||
(batch, seq, feature). Default: ``False`` | |||
:param input_dropout: 对输入的dropout概率. Default: 0 | |||
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 | |||
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` | |||
""" | |||
super(VarRNN, self).__init__( | |||
mode="RNN", Cell=nn.RNNCell, *args, **kwargs) | |||
def forward(self, x, hx=None): | |||
""" | |||
:param x: 输入序列 ``[batch_size, seq_len, input_size]``。 | |||
:param hx: 初始隐状态 ``[batch_size, hidden_size]``,若为 ``None`` 则初始化为全 **1** 向量。 | |||
:return: ``(output, ht)`` 格式的结果: ``output`` 形状为 ``[batch_size, seq_len, hidden_size*num_direction]``, | |||
表示输出序列,``ht`` 形状为 ``[batch_size, hidden_size*num_direction]``,表示最后时刻隐状态。 | |||
""" | |||
return super(VarRNN, self).forward(x, hx) | |||
class VarGRU(VarRNNBase): | |||
r""" | |||
Variational Dropout GRU. | |||
**Variational Dropout GRU**。 | |||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
:param input_size: 输入 `x` 的特征维度。 | |||
:param hidden_size: 隐状态 `h` 的特征维度。 | |||
:param num_layers: rnn的层数,默认为 1。 | |||
:param bias: 如果为 ``False``,模型将不会使用bias。默认为 ``True``。 | |||
:param batch_first: 若为 ``True``,输入和输出 ``Tensor`` 形状为 | |||
``[batch_size, seq, feature]``,否则为 ``[seq, batch_size, feature]``。 | |||
:param input_dropout: 对输入的 dropout 概率。默认为 0。 | |||
:param hidden_dropout: 对每个隐状态的 dropout 概率。默认为 0。 | |||
:param bidirectional: 若为 ``True``,用双向的 GRU 。默认为 ``False``。 | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
r""" | |||
:param input_size: 输入 `x` 的特征维度 | |||
:param hidden_size: 隐状态 `h` 的特征维度 | |||
:param num_layers: rnn的层数. Default: 1 | |||
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||
(batch, seq, feature). Default: ``False`` | |||
:param input_dropout: 对输入的dropout概率. Default: 0 | |||
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 | |||
:param bidirectional: 若为 ``True``, 使用双向的GRU. Default: ``False`` | |||
""" | |||
super(VarGRU, self).__init__( | |||
mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | |||
def forward(self, x, hx=None): | |||
""" | |||
:param x: 输入序列 ``[batch_size, seq_len, input_size]``。 | |||
:param hx: 初始隐状态 ``[batch_size, hidden_size]``,若为 ``None`` 则初始化为全 **1** 向量。 | |||
:return: ``(output, ht)`` 格式的结果: ``output`` 形状为 ``[batch_size, seq_len, hidden_size*num_direction]``, | |||
表示输出序列,``ht`` 形状为 ``[batch_size, hidden_size*num_direction]``,表示最后时刻隐状态。 | |||
""" | |||
return super(VarGRU, self).forward(x, hx) |
@@ -32,29 +32,26 @@ def _get_model_device(model): | |||
class SequenceGenerator: | |||
""" | |||
给定一个Seq2SeqDecoder,decode出句子。输入的decoder对象需要有decode()函数, 接受的第一个参数为decode的到目前位置的所有输出, | |||
第二个参数为state。SequenceGenerator不会对state进行任何操作。 | |||
给定一个 :class:`~fastNLP.modules.torch.decoder.Seq2SeqDecoder` ,decode出句子。输入的 decoder 对象需要有 :meth:`decode` 函数,接受的第一个参数为 decode 的到目前位置的所有输出, | |||
第二个参数为 state 。:class:`SequenceGenerator` 不会对 state 进行任何操作。 | |||
:param decoder: Decoder对象 | |||
:param max_length: 生成句子的最大长度, 每句话的 decode 长度为 ``max_length + max_len_a * src_len`` | |||
:param max_len_a: 每句话的 decode 长度为 ``max_length + max_len_a*src_len``。如果不为 0,需要保证 State 中包含 encoder_mask | |||
:param num_beams: **beam search** 的大小 | |||
:param do_sample: 是否通过采样的方式生成 | |||
:param temperature: 只有在 do_sample 为 ``True`` 才有意义 | |||
:param top_k: 只从 ``top_k`` 中采样 | |||
:param top_p: 只从 ``top_p`` 的 token 中采样( **nucleus sampling** ) | |||
:param bos_token_id: 句子开头的 token id | |||
:param eos_token_id: 句子结束的 token id | |||
:param repetition_penalty: 多大程度上惩罚重复的 token | |||
:param length_penalty: 对长度的惩罚,**小于 1** 鼓励长句,**大于 1** 鼓励短句 | |||
:param pad_token_id: 当某句话生成结束之后,之后生成的内容用 ``pad_token_id`` 补充 | |||
""" | |||
def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1, | |||
do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, | |||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | |||
""" | |||
:param Seq2SeqDecoder decoder: Decoder对象 | |||
:param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len | |||
:param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask | |||
:param int num_beams: beam search的大小 | |||
:param bool do_sample: 是否通过采样的方式生成 | |||
:param float temperature: 只有在do_sample为True才有意义 | |||
:param int top_k: 只从top_k中采样 | |||
:param float top_p: 只从top_p的token中采样,nucles sample | |||
:param int,None bos_token_id: 句子开头的token id | |||
:param int,None eos_token_id: 句子结束的token id | |||
:param float repetition_penalty: 多大程度上惩罚重复的token | |||
:param float length_penalty: 对长度的惩罚,小于1鼓励长句,大于1鼓励短剧 | |||
:param int pad_token_id: 当某句话生成结束之后,之后生成的内容用pad_token_id补充 | |||
""" | |||
def __init__(self, decoder: Seq2SeqDecoder, max_length: int=20, max_len_a: float=0.0, num_beams: int=1, | |||
do_sample: bool=True, temperature: float=1.0, top_k: int=50, top_p: float=1.0, bos_token_id: int=None, eos_token_id: int=None, | |||
repetition_penalty: float=1, length_penalty: float=1.0, pad_token_id: int=0): | |||
if do_sample: | |||
self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a, | |||
num_beams=num_beams, | |||
@@ -80,13 +77,13 @@ class SequenceGenerator: | |||
self.decoder = decoder | |||
@torch.no_grad() | |||
def generate(self, state, tokens=None): | |||
def generate(self, state: State, tokens: "torch.LongTensor"=None): | |||
""" | |||
:param State state: encoder结果的State, 是与Decoder配套是用的 | |||
:param torch.LongTensor,None tokens: batch_size x length, 开始的token。如果为None,则默认添加bos_token作为开头的token | |||
:param state: ``encoder`` 结果的 :class:`~fastNLP.modules.torch.decoder.State` ,是与 ``Decoder`` 配套使用的 | |||
:param tokens: 开始的 token,形状为 ``[batch_size, length]``。如果为 ``None`` ,则默认添加 ``bos_token`` 作为开头的 token | |||
进行生成。 | |||
:return: bsz x max_length' 生成的token序列。如果eos_token_id不为None, 每个sequence的结尾一定是eos_token_id | |||
:return: 生成的 token 序列,形状为 ``[bsz, max_length]`` 。如果 ``eos_token_id`` 不为 ``None`` ,,每个 sequence 的结尾一定是 ``eos_token_id`` | |||
""" | |||
return self.generate_func(tokens=tokens, state=state) | |||
@@ -100,7 +97,7 @@ def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0 | |||
贪婪地搜索句子 | |||
:param Decoder decoder: Decoder对象 | |||
:param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 | |||
:param torch.LongTensor tokens: batch_size, len, decode的输入值,如果为None,则自动从bos_token_id开始生成 | |||
:param State state: 应该包含encoder的一些输出。 | |||
:param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len | |||
:param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask | |||
@@ -136,7 +133,7 @@ def sample_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0 | |||
使用采样的方法生成句子 | |||
:param Decoder decoder: Decoder对象 | |||
:param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 | |||
:param torch.LongTensor tokens: batch_size, len, decode的输入值,如果为None,则自动从bos_token_id开始生成 | |||
:param State state: 应该包含encoder的一些输出。 | |||
:param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len | |||
:param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask | |||
@@ -504,7 +501,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf") | |||
""" | |||
根据top_k, top_p的值,将不满足的值置为filter_value的值 | |||
:param torch.Tensor logits: bsz x vocab_size | |||
:param torch.Tensor logits: bsz, vocab_size | |||
:param int top_k: 如果大于0,则只保留最top_k的词汇的概率,剩下的位置被置为filter_value | |||
:param int top_p: 根据(http://arxiv.org/abs/1904.09751)设置的筛选方式 | |||
:param float filter_value: | |||