- r"""undocumented"""
import torch.nn as nn
import torch
from torch.nn import LayerNorm
import torch.nn.functional as F
from typing import Union, Tuple
from ....core.utils import seq_len_to_mask
import math
from .lstm import LSTM
from ..attention import MultiHeadAttention
from ....embeddings.torch import StaticEmbedding
from ....embeddings.torch.utils import get_embeddings


__all__ = ['Seq2SeqEncoder', 'TransformerSeq2SeqEncoder', 'LSTMSeq2SeqEncoder']


class Seq2SeqEncoder(nn.Module):
    """
    Base class for all sequence-to-sequence encoders. Subclasses must implement the forward method.
    """
    def __init__(self):
        super().__init__()

    def forward(self, tokens, seq_len):
        """
        :param torch.LongTensor tokens: bsz x max_len, input to the encoder
        :param torch.LongTensor seq_len: bsz
        :return:
        """
        raise NotImplementedError
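
# A minimal sketch of the contract a subclass is expected to satisfy; the toy encoder
# below is purely illustrative and not part of this module:
#
#     class DummyEncoder(Seq2SeqEncoder):
#         def __init__(self, vocab_size, d_model):
#             super().__init__()
#             self.embed = nn.Embedding(vocab_size, d_model)
#
#         def forward(self, tokens, seq_len):
#             x = self.embed(tokens)                                    # bsz x max_len x d_model
#             mask = seq_len_to_mask(seq_len, max_len=tokens.size(1))   # bsz x max_len
#             return x, mask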


class TransformerSeq2SeqEncoderLayer(nn.Module):
    """
    A single Transformer encoder layer: self-attention followed by a position-wise feed-forward network.

    :param int d_model: dimension of the layer's input and output
    :param int n_head: number of attention heads; each head has dimension d_model/n_head
    :param int dim_ff: hidden size of the feed-forward network
    :param float dropout: dropout applied in self-attention and the FFN; 0 disables dropout
    """
    def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048,
                 dropout: float = 0.1):
        super(TransformerSeq2SeqEncoderLayer, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.dim_ff = dim_ff
        self.dropout = dropout

        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.attn_layer_norm = LayerNorm(d_model)
        self.ffn_layer_norm = LayerNorm(d_model)

        self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff),
                                 nn.ReLU(),
                                 nn.Dropout(dropout),
                                 nn.Linear(self.dim_ff, self.d_model),
                                 nn.Dropout(dropout))

    def forward(self, x, mask):
        """
        :param x: batch x src_seq x d_model
        :param mask: batch x src_seq, positions with value 0 are padding
        :return: batch x src_seq x d_model
        """
        # self-attention sub-layer (pre-norm, residual connection)
        residual = x
        x = self.attn_layer_norm(x)
        x, _ = self.self_attn(query=x,
                              key=x,
                              value=x,
                              key_mask=mask)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x

        # feed-forward sub-layer (pre-norm, residual connection)
        residual = x
        x = self.ffn_layer_norm(x)
        x = self.ffn(x)
        x = residual + x

        return x
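
# A minimal shape-check sketch for TransformerSeq2SeqEncoderLayer (batch size, sequence
# length and mask values below are hypothetical, chosen only for illustration):
#
#     layer = TransformerSeq2SeqEncoderLayer(d_model=512, n_head=8, dim_ff=2048, dropout=0.1)
#     x = torch.rand(4, 10, 512)                    # batch x src_seq x d_model
#     mask = torch.ones(4, 10, dtype=torch.bool)    # 0/False marks padding positions
#     out = layer(x, mask)                          # out: 4 x 10 x 512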


class TransformerSeq2SeqEncoder(Seq2SeqEncoder):
    """
    Transformer-based encoder.

    :param embed: embedding for the encoder's input tokens
    :param nn.Module pos_embed: position embedding
    :param int num_layers: number of encoder layers
    :param int d_model: input and output dimension
    :param int n_head: number of attention heads
    :param int dim_ff: hidden size of the feed-forward network
    :param float dropout: dropout applied in attention and the FFN
    """
    def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed=None,
                 num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1):
        super(TransformerSeq2SeqEncoder, self).__init__()
        self.embed = get_embeddings(embed)
        self.embed_scale = math.sqrt(d_model)
        self.pos_embed = pos_embed
        self.num_layers = num_layers
        self.d_model = d_model
        self.n_head = n_head
        self.dim_ff = dim_ff
        self.dropout = dropout

        self.input_fc = nn.Linear(self.embed.embedding_dim, d_model)
        self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout)
                                           for _ in range(num_layers)])
        self.layer_norm = LayerNorm(d_model)

    def forward(self, tokens, seq_len):
        """
        :param tokens: batch x max_len
        :param seq_len: [batch]
        :return: bsz x max_len x d_model, bsz x max_len (positions with value 0 are padding)
        """
        x = self.embed(tokens) * self.embed_scale  # batch x seq x dim
        batch_size, max_src_len, _ = x.size()
        device = x.device
        if self.pos_embed is not None:
            position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device)
            x += self.pos_embed(position)

        x = self.input_fc(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        encoder_mask = seq_len_to_mask(seq_len, max_len=max_src_len)
        encoder_mask = encoder_mask.to(device)

        for layer in self.layer_stacks:
            x = layer(x, encoder_mask)

        x = self.layer_norm(x)

        return x, encoder_mask
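
# A minimal usage sketch for TransformerSeq2SeqEncoder; the vocabulary size, batch size and
# sequence lengths below are hypothetical, and the (vocab_size, embed_dim) tuple form of
# `embed` is assumed to be resolved to a plain embedding by get_embeddings:
#
#     encoder = TransformerSeq2SeqEncoder(embed=(1000, 512), num_layers=2, d_model=512, n_head=8)
#     tokens = torch.randint(0, 1000, (4, 10))      # bsz x max_len
#     seq_len = torch.tensor([10, 8, 7, 5])         # bsz
#     out, mask = encoder(tokens, seq_len)          # out: 4 x 10 x 512, mask: 4 x 10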


class LSTMSeq2SeqEncoder(Seq2SeqEncoder):
    """
    LSTM-based encoder.

    :param embed: embedding for the encoder's input tokens
    :param int num_layers: number of LSTM layers
    :param int hidden_size: size of the LSTM hidden state and of the encoder output
    :param float dropout: dropout applied between LSTM layers
    :param bool bidirectional: whether to use a bidirectional LSTM
    """
    def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers=3,
                 hidden_size=400, dropout=0.3, bidirectional=True):
        super().__init__()
        self.embed = get_embeddings(embed)
        self.num_layers = num_layers
        self.dropout = dropout
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        hidden_size = hidden_size // 2 if bidirectional else hidden_size
        self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size, bidirectional=bidirectional,
                         batch_first=True, dropout=dropout if num_layers > 1 else 0, num_layers=num_layers)

    def forward(self, tokens, seq_len):
        """
        :param torch.LongTensor tokens: bsz x max_len
        :param torch.LongTensor seq_len: bsz
        :return: (output, (hidden, cell)), encoder_mask
            output: bsz x max_len x hidden_size
            hidden, cell: batch_size x hidden_size, hidden/cell state of the last layer
            encoder_mask: bsz x max_len, positions with value 0 are padding
        """
        x = self.embed(tokens)
        device = x.device
        x, (final_hidden, final_cell) = self.lstm(x, seq_len)
        encoder_mask = seq_len_to_mask(seq_len).to(device)

        # x: batch x seq_len x dim; final_hidden/final_cell: (num_layers * 2) x batch x dim when bidirectional

        if self.bidirectional:
            # concatenate the forward and backward states so the decoder can consume them as its initial state
            final_hidden = self.concat_bidir(final_hidden)
            final_cell = self.concat_bidir(final_cell)

        # split into two parts to match the forward signature expected by Seq2SeqBaseModel
        return (x, (final_hidden[-1], final_cell[-1])), encoder_mask

    def concat_bidir(self, input):
        # (num_layers * 2) x batch x dim -> num_layers x batch x (2 * dim)
        output = input.view(self.num_layers, 2, input.size(1), -1).transpose(1, 2)
        return output.reshape(self.num_layers, input.size(1), -1)
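
# A minimal usage sketch for LSTMSeq2SeqEncoder (vocabulary size, batch size and lengths
# below are hypothetical, chosen only for illustration):
#
#     encoder = LSTMSeq2SeqEncoder(embed=(1000, 300), num_layers=2, hidden_size=400, bidirectional=True)
#     tokens = torch.randint(0, 1000, (4, 10))          # bsz x max_len
#     seq_len = torch.tensor([10, 9, 6, 3])             # bsz
#     (output, (hidden, cell)), mask = encoder(tokens, seq_len)
#     # output: 4 x 10 x 400, hidden/cell: 4 x 400, mask: 4 x 10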