
1. Add torch modules and torch models

tags/v1.0.0alpha
yhcc 3 years ago
parent
commit
004c344e4c
43 changed files with 4891 additions and 2 deletions
  1. +1
    -0
      fastNLP/core/collators/padders/get_padder.py
  2. +1
    -1
      fastNLP/embeddings/torch/char_embedding.py
  3. +0
    -0
      fastNLP/models/__init__.py
  4. +21
    -0
      fastNLP/models/torch/__init__.py
  5. +475
    -0
      fastNLP/models/torch/biaffine_parser.py
  6. +92
    -0
      fastNLP/models/torch/cnn_text_classification.py
  7. +81
    -0
      fastNLP/models/torch/seq2seq_generator.py
  8. +196
    -0
      fastNLP/models/torch/seq2seq_model.py
  9. +271
    -0
      fastNLP/models/torch/sequence_labeling.py
  10. +26
    -0
      fastNLP/modules/torch/__init__.py
  11. +321
    -0
      fastNLP/modules/torch/attention.py
  12. +15
    -0
      fastNLP/modules/torch/decoder/__init__.py
  13. +354
    -0
      fastNLP/modules/torch/decoder/crf.py
  14. +416
    -0
      fastNLP/modules/torch/decoder/seq2seq_decoder.py
  15. +145
    -0
      fastNLP/modules/torch/decoder/seq2seq_state.py
  16. +24
    -0
      fastNLP/modules/torch/dropout.py
  17. +17
    -1
      fastNLP/modules/torch/encoder/__init__.py
  18. +87
    -0
      fastNLP/modules/torch/encoder/conv_maxpool.py
  19. +193
    -0
      fastNLP/modules/torch/encoder/seq2seq_encoder.py
  20. +166
    -0
      fastNLP/modules/torch/encoder/star_transformer.py
  21. +43
    -0
      fastNLP/modules/torch/encoder/transformer.py
  22. +303
    -0
      fastNLP/modules/torch/encoder/variational_rnn.py
  23. +6
    -0
      fastNLP/modules/torch/generator/__init__.py
  24. +536
    -0
      fastNLP/modules/torch/generator/seq2seq_generator.py
  25. +1
    -0
      tests/helpers/data/modules/decoder/crf.json
  26. +0
    -0
      tests/models/__init__.py
  27. +0
    -0
      tests/models/torch/__init__.py
  28. +142
    -0
      tests/models/torch/model_runner.py
  29. +91
    -0
      tests/models/torch/test_biaffine_parser.py
  30. +33
    -0
      tests/models/torch/test_cnn_text_classification.py
  31. +73
    -0
      tests/models/torch/test_seq2seq_generator.py
  32. +113
    -0
      tests/models/torch/test_seq2seq_model.py
  33. +47
    -0
      tests/models/torch/test_sequence_labeling.py
  34. +0
    -0
      tests/modules/torch/__init__.py
  35. +0
    -0
      tests/modules/torch/decoder/__init__.py
  36. +327
    -0
      tests/modules/torch/decoder/test_CRF.py
  37. +49
    -0
      tests/modules/torch/decoder/test_seq2seq_decoder.py
  38. +0
    -0
      tests/modules/torch/encoder/__init__.py
  39. +33
    -0
      tests/modules/torch/encoder/test_seq2seq_encoder.py
  40. +18
    -0
      tests/modules/torch/encoder/test_star_transformer.py
  41. +27
    -0
      tests/modules/torch/encoder/test_variational_rnn.py
  42. +1
    -0
      tests/modules/torch/generator/__init__.py
  43. +146
    -0
      tests/modules/torch/generator/test_seq2seq_generator.py

+ 1
- 0
fastNLP/core/collators/padders/get_padder.py View File

@@ -138,6 +138,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \
"information please set logger's level to DEBUG."
if must_pad:
logger.error(msg)
raise type(e)(msg=msg)
logger.debug(msg)
return NullPadder()


+ 1
- 1
fastNLP/embeddings/torch/char_embedding.py View File

@@ -16,6 +16,7 @@ if _NEED_IMPORT_TORCH:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LSTM

from .embedding import TokenEmbedding
from .static_embedding import StaticEmbedding
@@ -23,7 +24,6 @@ from .utils import _construct_char_vocab_from_vocab
from .utils import get_embeddings
from ...core import logger
from ...core.vocabulary import Vocabulary
from ...modules.torch.encoder.lstm import LSTM


class CNNCharEmbedding(TokenEmbedding):


+ 0
- 0
fastNLP/models/__init__.py View File


+ 21
- 0
fastNLP/models/torch/__init__.py View File

@@ -0,0 +1,21 @@
__all__ = [
'BiaffineParser',

"CNNText",

"SequenceGeneratorModel",

"Seq2SeqModel",
'TransformerSeq2SeqModel',
'LSTMSeq2SeqModel',

"SeqLabeling",
"AdvSeqLabel",
"BiLSTMCRF",
]

from .biaffine_parser import BiaffineParser
from .cnn_text_classification import CNNText
from .seq2seq_generator import SequenceGeneratorModel
from .seq2seq_model import *
from .sequence_labeling import *
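As a quick, illustrative sanity check (not part of the diff), every name listed in __all__ above should be importable from the new namespace:

# illustrative import check for the new fastNLP.models.torch namespace
from fastNLP.models.torch import (BiaffineParser, CNNText, SequenceGeneratorModel,
                                  Seq2SeqModel, TransformerSeq2SeqModel, LSTMSeq2SeqModel,
                                  SeqLabeling, AdvSeqLabel, BiLSTMCRF)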

+ 475
- 0
fastNLP/models/torch/biaffine_parser.py View File

@@ -0,0 +1,475 @@
r"""
PyTorch implementation of the Biaffine Dependency Parser.
"""
__all__ = [
"BiaffineParser",
"GraphParser"
]

from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core.utils import seq_len_to_mask
from ...embeddings.torch.utils import get_embeddings
from ...modules.torch.dropout import TimestepDropout
from ...modules.torch.encoder.transformer import TransformerEncoder
from ...modules.torch.encoder.variational_rnn import VarLSTM


def _mst(scores):
r"""
with some modification to support parser output for MST decoding
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692
"""
length = scores.shape[0]
min_score = scores.min() - 1
eye = np.eye(length)
scores = scores * (1 - eye) + min_score * eye
heads = np.argmax(scores, axis=1)
heads[0] = 0
tokens = np.arange(1, length)
roots = np.where(heads[tokens] == 0)[0] + 1
if len(roots) < 1:
root_scores = scores[tokens, 0]
head_scores = scores[tokens, heads[tokens]]
new_root = tokens[np.argmax(root_scores / head_scores)]
heads[new_root] = 0
elif len(roots) > 1:
root_scores = scores[roots, 0]
scores[roots, 0] = 0
new_heads = np.argmax(scores[roots][:, tokens], axis=1) + 1
new_root = roots[np.argmin(
scores[roots, new_heads] / root_scores)]
heads[roots] = new_heads
heads[new_root] = 0
edges = defaultdict(set)
vertices = set((0,))
for dep, head in enumerate(heads[tokens]):
vertices.add(dep + 1)
edges[head].add(dep + 1)
for cycle in _find_cycle(vertices, edges):
dependents = set()
to_visit = set(cycle)
while len(to_visit) > 0:
node = to_visit.pop()
if node not in dependents:
dependents.add(node)
to_visit.update(edges[node])
cycle = np.array(list(cycle))
old_heads = heads[cycle]
old_scores = scores[cycle, old_heads]
non_heads = np.array(list(dependents))
scores[np.repeat(cycle, len(non_heads)),
np.repeat([non_heads], len(cycle), axis=0).flatten()] = min_score
new_heads = np.argmax(scores[cycle][:, tokens], axis=1) + 1
new_scores = scores[cycle, new_heads] / old_scores
change = np.argmax(new_scores)
changed_cycle = cycle[change]
old_head = old_heads[change]
new_head = new_heads[change]
heads[changed_cycle] = new_head
edges[new_head].add(changed_cycle)
edges[old_head].remove(changed_cycle)
return heads


def _find_cycle(vertices, edges):
r"""
https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py
"""
_index = 0
_stack = []
_indices = {}
_lowlinks = {}
_onstack = defaultdict(lambda: False)
_SCCs = []
def _strongconnect(v):
nonlocal _index
_indices[v] = _index
_lowlinks[v] = _index
_index += 1
_stack.append(v)
_onstack[v] = True
for w in edges[v]:
if w not in _indices:
_strongconnect(w)
_lowlinks[v] = min(_lowlinks[v], _lowlinks[w])
elif _onstack[w]:
_lowlinks[v] = min(_lowlinks[v], _indices[w])
if _lowlinks[v] == _indices[v]:
SCC = set()
while True:
w = _stack.pop()
_onstack[w] = False
SCC.add(w)
if not (w != v):
break
_SCCs.append(SCC)
for v in vertices:
if v not in _indices:
_strongconnect(v)
return [SCC for SCC in _SCCs if len(SCC) > 1]


class GraphParser(nn.Module):
r"""
Base class for graph-based parsers; supports greedy decoding and maximum spanning tree (MST) decoding
"""
def __init__(self):
super(GraphParser, self).__init__()
@staticmethod
def greedy_decoder(arc_matrix, mask=None):
r"""
Greedy decoding: takes the score graph and returns a greedily decoded parse; the result is not guaranteed to form a valid tree.

:param arc_matrix: [batch, seq_len, seq_len] input score matrix
:param mask: [batch, seq_len] padding mask of the input, 1 for real tokens and 0 otherwise.
If ``None``, an all-ones mask is assumed. Default: ``None``
:return heads: [batch, seq_len] predicted head (parent) of each element in the tree
"""
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
flip_mask = mask.eq(False)
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
_, heads = torch.max(matrix, dim=2)
if mask is not None:
heads *= mask.long()
return heads
@staticmethod
def mst_decoder(arc_matrix, mask=None):
r"""
MST decoding: uses the maximum spanning tree algorithm to compute the parse, guaranteeing a valid tree structure.

:param arc_matrix: [batch, seq_len, seq_len] input score matrix
:param mask: [batch, seq_len] padding mask of the input, 1 for real tokens and 0 otherwise.
If ``None``, an all-ones mask is assumed. Default: ``None``
:return heads: [batch, seq_len] predicted head (parent) of each element in the tree
"""
batch_size, seq_len, _ = arc_matrix.shape
matrix = arc_matrix.clone()
ans = matrix.new_zeros(batch_size, seq_len).long()
lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len
for i, graph in enumerate(matrix):
len_i = lens[i]
ans[i, :len_i] = torch.as_tensor(_mst(graph.detach()[:len_i, :len_i].cpu().numpy()), device=ans.device)
if mask is not None:
ans *= mask.long()
return ans


class ArcBiaffine(nn.Module):
r"""
Submodule of the Biaffine Dependency Parser; builds the graph of arc (edge) scores

"""
def __init__(self, hidden_size, bias=True):
r"""
:param hidden_size: dimension of the input features
:param bias: whether to use a bias term. Default: ``True``
"""
super(ArcBiaffine, self).__init__()
self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True)
self.has_bias = bias
if self.has_bias:
self.bias = nn.Parameter(torch.Tensor(hidden_size), requires_grad=True)
else:
self.register_parameter("bias", None)

def forward(self, head, dep):
r"""

:param head: arc-head tensor [batch, length, hidden]
:param dep: arc-dependent tensor [batch, length, hidden]
:return output: tensor [batch, length, length]
"""
output = dep.matmul(self.U)
output = output.bmm(head.transpose(-1, -2))
if self.has_bias:
output = output + head.matmul(self.bias).unsqueeze(1)
return output


class LabelBilinear(nn.Module):
r"""
Submodule of the Biaffine Dependency Parser; builds the graph of arc-label scores

"""
def __init__(self, in1_features, in2_features, num_label, bias=True):
r"""
:param in1_features: dimension of input feature 1
:param in2_features: dimension of input feature 2
:param num_label: number of arc label classes
:param bias: whether to use a bias term. Default: ``True``
"""
super(LabelBilinear, self).__init__()
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias)
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False)
def forward(self, x1, x2):
r"""

:param x1: [batch, seq_len, hidden] input feature 1, i.e. label-head
:param x2: [batch, seq_len, hidden] input feature 2, i.e. label-dep
:return output: [batch, seq_len, num_cls] class scores for each element
"""
output = self.bilinear(x1, x2)
output = output + self.lin(torch.cat([x1, x2], dim=2))
return output


class BiaffineParser(GraphParser):
r"""
Implementation of the Biaffine Dependency Parser.
Reference: `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ .

"""
def __init__(self,
embed,
pos_vocab_size,
pos_emb_dim,
num_label,
rnn_layers=1,
rnn_hidden_size=200,
arc_mlp_size=100,
label_mlp_size=100,
dropout=0.3,
encoder='lstm',
use_greedy_infer=False):
r"""
:param embed: the word embedding. Can be a tuple of (num_embeddings, embedding_dim), i.e.
the vocabulary size and the dimension of each word vector; an nn.Embedding object may also
be passed, in which case it is used directly as the embedding
:param pos_vocab_size: size of the part-of-speech vocabulary
:param pos_emb_dim: dimension of the part-of-speech embeddings
:param num_label: number of arc label classes
:param rnn_layers: number of layers of the RNN encoder
:param rnn_hidden_size: hidden size of the RNN encoder
:param arc_mlp_size: MLP dimension for arc prediction
:param label_mlp_size: MLP dimension for label prediction
:param dropout: dropout probability.
:param encoder: encoder type, one of ('lstm', 'var-lstm', 'transformer'). Default: lstm
:param use_greedy_infer: whether to use the greedy algorithm at inference time.
If ``False``, the slower but more accurate MST algorithm is used. Default: ``False``
"""
super(BiaffineParser, self).__init__()
rnn_out_size = 2 * rnn_hidden_size
word_hid_dim = pos_hid_dim = rnn_hidden_size
self.word_embedding = get_embeddings(embed)
word_emb_dim = self.word_embedding.embedding_dim
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
self.word_norm = nn.LayerNorm(word_hid_dim)
self.pos_norm = nn.LayerNorm(pos_hid_dim)
self.encoder_name = encoder
self.max_len = 512
if encoder == 'var-lstm':
self.encoder = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
batch_first=True,
input_dropout=dropout,
hidden_dropout=dropout,
bidirectional=True)
elif encoder == 'lstm':
self.encoder = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
batch_first=True,
dropout=dropout,
bidirectional=True)
elif encoder == 'transformer':
n_head = 16
d_k = d_v = int(rnn_out_size / n_head)
if (d_k * n_head) != rnn_out_size:
raise ValueError('Unsupported rnn_out_size: {} for transformer'.format(rnn_out_size))
self.position_emb = nn.Embedding(num_embeddings=self.max_len,
embedding_dim=rnn_out_size, )
self.encoder = TransformerEncoder( num_layers=rnn_layers, d_model=rnn_out_size,
n_head=n_head, dim_ff=1024, dropout=dropout)
else:
raise ValueError('Unsupported encoder type: {}'.format(encoder))
self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2),
nn.ELU(),
TimestepDropout(p=dropout), )
self.arc_mlp_size = arc_mlp_size
self.label_mlp_size = label_mlp_size
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
self.use_greedy_infer = use_greedy_infer
self.reset_parameters()
self.dropout = dropout
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Embedding):
continue
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 0.1)
nn.init.constant_(m.bias, 0)
else:
for p in m.parameters():
nn.init.normal_(p, 0, 0.1)
def forward(self, words1, words2, seq_len, target1=None):
r"""模型forward阶段

:param words1: [batch_size, seq_len] 输入word序列
:param words2: [batch_size, seq_len] 输入pos序列
:param seq_len: [batch_size, seq_len] 输入序列长度
:param target1: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效,
用于训练label分类器. 若为 ``None`` , 使用预测的heads输入到label分类器
Default: ``None``
:return dict: parsing
结果::

pred1: [batch_size, seq_len, seq_len] 边预测logits
pred2: [batch_size, seq_len, num_label] label预测logits
pred3: [batch_size, seq_len] heads的预测结果, 在 ``target1=None`` 时预测

"""
# prepare embeddings
batch_size, length = words1.shape
# print('forward {} {}'.format(batch_size, seq_len))
# get sequence mask
mask = seq_len_to_mask(seq_len, max_len=length).long()
word = self.word_embedding(words1) # [N,L] -> [N,L,C_0]
pos = self.pos_embedding(words2) # [N,L] -> [N,L,C_1]
word, pos = self.word_fc(word), self.pos_fc(pos)
word, pos = self.word_norm(word), self.pos_norm(pos)
x = torch.cat([word, pos], dim=2) # -> [N,L,C]
# encoder, extract features
if self.encoder_name.endswith('lstm'):
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
x = x[sort_idx]
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=True)
feat, _ = self.encoder(x) # -> [N,L,C]
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
feat = feat[unsort_idx]
else:
seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None, :]
x = x + self.position_emb(seq_range)
feat = self.encoder(x, mask.float())
# for arc biaffine
# mlp, reduce dim
feat = self.mlp(feat)
arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz]
label_dep, label_head = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz], feat[:, :, 2 * arc_sz + label_sz:]
# biaffine arc classifier
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
# use gold or predicted arc to predict label
if target1 is None or not self.training:
# use greedy decoding in training
if self.training or self.use_greedy_infer:
heads = self.greedy_decoder(arc_pred, mask)
else:
heads = self.mst_decoder(arc_pred, mask)
head_pred = heads
else:
assert self.training # must be training mode
if target1 is None:
heads = self.greedy_decoder(arc_pred, mask)
head_pred = heads
else:
head_pred = None
heads = target1
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1)
label_head = label_head[batch_range, heads].contiguous()
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
res_dict = {'pred1': arc_pred, 'pred2': label_pred}
if head_pred is not None:
res_dict['pred3'] = head_pred
return res_dict

def train_step(self, words1, words2, seq_len, target1, target2):
res = self(words1, words2, seq_len, target1)
arc_pred = res['pred1']
label_pred = res['pred2']
loss = self.loss(pred1=arc_pred, pred2=label_pred, target1=target1, target2=target2, seq_len=seq_len)
return {'loss': loss}
@staticmethod
def loss(pred1, pred2, target1, target2, seq_len):
r"""
Compute the parser loss.

:param pred1: [batch_size, seq_len, seq_len] arc prediction logits
:param pred2: [batch_size, seq_len, num_label] label prediction logits
:param target1: [batch_size, seq_len] gold heads
:param target2: [batch_size, seq_len] gold labels
:param seq_len: [batch_size, ] lengths of the target sequences
:return loss: scalar
"""
batch_size, length, _ = pred1.shape
mask = seq_len_to_mask(seq_len, max_len=length)
flip_mask = (mask.eq(False))
_arc_pred = pred1.clone()
_arc_pred = _arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
arc_logits = F.log_softmax(_arc_pred, dim=2)
label_logits = F.log_softmax(pred2, dim=2)
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
child_index = torch.arange(length, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
arc_loss = arc_logits[batch_index, child_index, target1]
label_loss = label_logits[batch_index, child_index, target2]
arc_loss = arc_loss.masked_fill(flip_mask, 0)
label_loss = label_loss.masked_fill(flip_mask, 0)
arc_nll = -arc_loss.mean()
label_nll = -label_loss.mean()
return arc_nll + label_nll
def evaluate_step(self, words1, words2, seq_len):
r"""模型预测API

:param words1: [batch_size, seq_len] 输入word序列
:param words2: [batch_size, seq_len] 输入pos序列
:param seq_len: [batch_size, seq_len] 输入序列长度
:return dict: parsing
结果::

pred1: [batch_size, seq_len] heads的预测结果
pred2: [batch_size, seq_len, num_label] label预测logits

"""
res = self(words1, words2, seq_len)
output = {}
output['pred1'] = res.pop('pred3')
_, label_pred = res.pop('pred2').max(2)
output['pred2'] = label_pred
return output
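
A minimal, illustrative usage sketch for the parser added above (not part of the diff; the vocabulary sizes and tensors are made up, and the tuple form of `embed` is assumed to be accepted by `get_embeddings`):

import torch
from fastNLP.models.torch import BiaffineParser

# toy sizes: 100 words, 20 POS tags, 10 arc labels
parser = BiaffineParser(embed=(100, 50), pos_vocab_size=20, pos_emb_dim=16,
                        num_label=10, rnn_hidden_size=64)
words1 = torch.randint(1, 100, (2, 7))          # word ids
words2 = torch.randint(1, 20, (2, 7))           # POS ids
seq_len = torch.LongTensor([7, 5])
target1 = torch.randint(0, 7, (2, 7))           # gold heads
target2 = torch.randint(0, 10, (2, 7))          # gold arc labels
loss = parser.train_step(words1, words2, seq_len, target1, target2)['loss']
heads = parser.evaluate_step(words1, words2, seq_len)['pred1']   # [2, 7] predicted heads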


+ 92
- 0
fastNLP/models/torch/cnn_text_classification.py View File

@@ -0,0 +1,92 @@
r"""
.. todo::
doc
"""

__all__ = [
"CNNText"
]

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core.utils import seq_len_to_mask
from ...embeddings.torch import embedding
from ...modules.torch import encoder


class CNNText(torch.nn.Module):
r"""
A CNN-based text classification model.
'Yoon Kim. 2014. Convolutional Neural Networks for Sentence Classification.'
"""

def __init__(self, embed,
num_classes,
kernel_nums=(30, 40, 50),
kernel_sizes=(1, 3, 5),
dropout=0.5):
r"""
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: the embedding. If a tuple(int, int) is passed,
the first int is the vocabulary size and the second is the embedding dimension; if a Tensor, Embedding or ndarray is passed, it is used directly to initialize the Embedding
:param int num_classes: number of target classes
:param int,tuple(int) kernel_nums: number of output channels for each kernel size
:param int,tuple(int) kernel_sizes: sizes of the convolution kernels
:param float dropout: dropout probability
"""
super(CNNText, self).__init__()

# no support for pre-trained embedding currently
self.embed = embedding.Embedding(embed)
self.conv_pool = encoder.ConvMaxpool(
in_channels=self.embed.embedding_dim,
out_channels=kernel_nums,
kernel_sizes=kernel_sizes)
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(sum(kernel_nums), num_classes)

def forward(self, words, seq_len=None):
r"""

:param torch.LongTensor words: [batch_size, seq_len], word indices of the sentences
:param torch.LongTensor seq_len: [batch,] length of each sentence

:return output: {'pred': [batch_size, num_classes]} classification logits
"""
x = self.embed(words) # [N,L] -> [N,L,C]
if seq_len is not None:
mask = seq_len_to_mask(seq_len)
x = self.conv_pool(x, mask)
else:
x = self.conv_pool(x) # [N,L,C] -> [N,C]
x = self.dropout(x)
x = self.fc(x) # [N,C] -> [N, N_class]
res = {'pred': x}
return res

def train_step(self, words, target, seq_len=None):
"""

:param words:
:param target:
:param seq_len:
:return:
"""
res = self(words, seq_len)
x = res['pred']
loss = F.cross_entropy(x, target)
return {'loss': loss}

def evaluate_step(self, words, seq_len=None):
r"""
:param torch.LongTensor words: [batch_size, seq_len], word indices of the sentences
:param torch.LongTensor seq_len: [batch,] length of each sentence

:return predict: dict of torch.LongTensor, [batch_size, ]
"""
output = self(words, seq_len)
_, predict = output['pred'].max(dim=1)
return {'pred': predict}
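
A small, illustrative usage sketch for CNNText (not part of the diff; the (vocab_size, embed_dim) tuple form of `embed` and all sizes are made up):

import torch
from fastNLP.models.torch import CNNText

model = CNNText(embed=(100, 50), num_classes=3)
words = torch.randint(1, 100, (4, 20))          # 4 sentences, 20 tokens each
seq_len = torch.LongTensor([20, 18, 15, 12])
target = torch.LongTensor([0, 2, 1, 0])
loss = model.train_step(words, target, seq_len)['loss']
pred = model.evaluate_step(words, seq_len)['pred']   # [4] predicted class ids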

+ 81
- 0
fastNLP/models/torch/seq2seq_generator.py View File

@@ -0,0 +1,81 @@
r"""undocumented"""

import torch
from torch import nn
import torch.nn.functional as F
from fastNLP import seq_len_to_mask
from .seq2seq_model import Seq2SeqModel
from ...modules.torch.generator.seq2seq_generator import SequenceGenerator


__all__ = ['SequenceGeneratorModel']


class SequenceGeneratorModel(nn.Module):
"""
Wraps a seq2seq_model so that it can be used both for training and for generation. During training this model's forward
function is called; during generation this model's generation function is called.

"""

def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, max_len_a=0.0,
num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0,
repetition_penalty=1, length_penalty=1.0, pad_token_id=0):
"""

:param Seq2SeqModel seq2seq_model: the sequence-to-sequence model
:param int,None bos_token_id: token id marking the start of a sentence
:param int,None eos_token_id: token id marking the end of a sentence
:param int max_length: maximum length of a generated sentence; each sentence is decoded for at most max_length + max_len_a*src_len steps
:param float max_len_a: each sentence is decoded for at most max_length + max_len_a*src_len steps. If non-zero, the State must contain encoder_mask
:param int num_beams: beam size for beam search
:param bool do_sample: whether to generate by sampling
:param float temperature: only meaningful when do_sample is True
:param int top_k: sample only from the top_k tokens
:param float top_p: sample only from the tokens within probability mass top_p (nucleus sampling)
:param float repetition_penalty: how strongly repeated tokens are penalized
:param float length_penalty: length penalty; values below 1 encourage longer sentences, values above 1 encourage shorter ones
:param int pad_token_id: once a sentence has finished generating, subsequent positions are filled with pad_token_id
"""
super().__init__()
self.seq2seq_model = seq2seq_model
self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, max_len_a=max_len_a,
num_beams=num_beams,
do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty, length_penalty=length_penalty,
pad_token_id=pad_token_id)

def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
"""
Delegates directly to the forward of seq2seq_model.

:param torch.LongTensor src_tokens: bsz x max_len
:param torch.LongTensor tgt_tokens: bsz x max_len'
:param torch.LongTensor src_seq_len: bsz
:param torch.LongTensor tgt_seq_len: bsz
:return:
"""
return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)

def train_step(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
res = self(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)
pred = res['pred']
if tgt_seq_len is not None:
mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1))
tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100)
loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens)
return {'loss': loss}

def evaluate_step(self, src_tokens, src_seq_len=None):
"""
Given the source tokens, output the generated content.

:param torch.LongTensor src_tokens: bsz x max_len
:param torch.LongTensor src_seq_len: bsz
:return:
"""
state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len)
result = self.generator.generate(state)
return {'pred': result}
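
An illustrative sketch of how the wrapper is meant to be used (not part of the diff; the underlying model is built via TransformerSeq2SeqModel.build_model from seq2seq_model.py below, and all sizes/ids are made up):

import torch
from fastNLP.models.torch import SequenceGeneratorModel, TransformerSeq2SeqModel

seq2seq = TransformerSeq2SeqModel.build_model(src_embed=(100, 64), tgt_embed=(100, 64),
                                              num_layers=2, d_model=64, n_head=4, dim_ff=128)
gen_model = SequenceGeneratorModel(seq2seq, bos_token_id=1, eos_token_id=2,
                                   max_length=20, num_beams=4, do_sample=False)
src_tokens = torch.randint(3, 100, (2, 8))      # made-up source ids
src_seq_len = torch.LongTensor([8, 6])
generated = gen_model.evaluate_step(src_tokens, src_seq_len)['pred']  # bsz x generated_len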

+ 196
- 0
fastNLP/models/torch/seq2seq_model.py View File

@@ -0,0 +1,196 @@
r"""
Contains the models that make up the Sequence-to-Sequence framework.

"""

import torch
from torch import nn
import torch.nn.functional as F

from fastNLP import seq_len_to_mask
from ...embeddings.torch.utils import get_embeddings
from ...embeddings.torch.utils import get_sinusoid_encoding_table
from ...modules.torch.decoder.seq2seq_decoder import Seq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder
from ...modules.torch.encoder.seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder


__all__ = ['Seq2SeqModel', 'TransformerSeq2SeqModel', 'LSTMSeq2SeqModel']


class Seq2SeqModel(nn.Module):
def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder):
"""
A Seq2Seq model that can be trained in a Trainer. Normally a subclass only needs to implement the classmethod build_model.
To use the model for generation, wrap it in a :class:`~fastNLP.models.SequenceGeneratorModel`. In this model, forward()
feeds the encoder output into the decoder and returns the decoder output.

:param encoder: a Seq2SeqEncoder object. Its forward() takes two arguments, the bsz x max_len source tokens and the
bsz source lengths, and must return two tensors: encoder_outputs of shape bsz x max_len x hidden_size and encoder_mask of shape bsz x max_len,
where positions marked 1 should be attended to. If the encoder's inputs or outputs differ, override this model's prepare_state() or forward()
:param decoder: a Seq2SeqDecoder object. It must implement init_state(), whose two inputs are the bsz x max_len x hidden_size
encoder output and the bsz x max_len encoder mask where 0 marks padding. If the decoder needs more inputs, override this model's
prepare_state() or forward()
"""
super().__init__()
self.encoder = encoder
self.decoder = decoder

def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
"""

:param torch.LongTensor src_tokens: source tokens
:param torch.LongTensor tgt_tokens: target tokens
:param torch.LongTensor src_seq_len: lengths of the source sequences
:param torch.LongTensor tgt_seq_len: lengths of the target sequences; unused by default
:return: {'pred': torch.Tensor}, where pred has shape bsz x max_len x vocab_size
"""
state = self.prepare_state(src_tokens, src_seq_len)
decoder_output = self.decoder(tgt_tokens, state)
if isinstance(decoder_output, torch.Tensor):
return {'pred': decoder_output}
elif isinstance(decoder_output, (tuple, list)):
return {'pred': decoder_output[0]}
else:
raise TypeError(f"Unsupported return type from Decoder:{type(self.decoder)}")

def train_step(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
res = self(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)
pred = res['pred']
if tgt_seq_len is not None:
mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1))
tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100)
loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens)
return {'loss': loss}

def prepare_state(self, src_tokens, src_seq_len=None):
"""
Call the encoder to obtain a state; the encoder's encoder_output and encoder_mask are passed directly to decoder.init_state to initialize a state.

:param src_tokens:
:param src_seq_len:
:return:
"""
encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len)
state = self.decoder.init_state(encoder_output, encoder_mask)
return state

@classmethod
def build_model(cls, *args, **kwargs):
"""
This method must be implemented to construct and initialize a Seq2SeqModel.

:return:
"""
raise NotImplementedError


class TransformerSeq2SeqModel(Seq2SeqModel):
"""
Uses TransformerSeq2SeqEncoder as the encoder and TransformerSeq2SeqDecoder as the decoder; initialized through the build_model method.

"""

def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)

@classmethod
def build_model(cls, src_embed, tgt_embed=None,
pos_embed='sin', max_position=1024, num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1,
bind_encoder_decoder_embed=False,
bind_decoder_input_output_embed=True):
"""
Initialize a TransformerSeq2SeqModel.

:param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: embedding of the source
:param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: embedding of the target; do not pass this value if
bind_encoder_decoder_embed is True
:param str pos_embed: either sin or learned
:param int max_position: maximum supported length
:param int num_layers: number of layers in the encoder and the decoder
:param int d_model: input/output size of the encoder and the decoder
:param int n_head: number of attention heads in the encoder and the decoder
:param int dim_ff: dimension of the intermediate FFN projection in the encoder and the decoder
:param float dropout: dropout applied in the attention and the FFN
:param bool bind_encoder_decoder_embed: whether the encoder and the decoder share the same embedding
:param bool bind_decoder_input_output_embed: whether the decoder output embedding shares weights with its input embedding
:return: TransformerSeq2SeqModel
"""
if bind_encoder_decoder_embed and tgt_embed is not None:
raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.")

src_embed = get_embeddings(src_embed)

if bind_encoder_decoder_embed:
tgt_embed = src_embed
else:
assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`"
tgt_embed = get_embeddings(tgt_embed)

if pos_embed == 'sin':
encoder_pos_embed = nn.Embedding.from_pretrained(
get_sinusoid_encoding_table(max_position + 1, src_embed.embedding_dim, padding_idx=0),
freeze=True) # index 0 is reserved for padding here
decoder_pos_embed = nn.Embedding.from_pretrained(
get_sinusoid_encoding_table(max_position + 1, tgt_embed.embedding_dim, padding_idx=0),
freeze=True) # index 0 is reserved for padding here
elif pos_embed == 'learned':
encoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=0)
decoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=1)
else:
raise ValueError("pos_embed only supports sin or learned.")

encoder = TransformerSeq2SeqEncoder(embed=src_embed, pos_embed=encoder_pos_embed,
num_layers=num_layers, d_model=d_model, n_head=n_head, dim_ff=dim_ff,
dropout=dropout)
decoder = TransformerSeq2SeqDecoder(embed=tgt_embed, pos_embed=decoder_pos_embed,
d_model=d_model, num_layers=num_layers, n_head=n_head, dim_ff=dim_ff,
dropout=dropout,
bind_decoder_input_output_embed=bind_decoder_input_output_embed)

return cls(encoder, decoder)


class LSTMSeq2SeqModel(Seq2SeqModel):
"""
Model using LSTMSeq2SeqEncoder and LSTMSeq2SeqDecoder.

"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)

@classmethod
def build_model(cls, src_embed, tgt_embed=None,
num_layers = 3, hidden_size = 400, dropout = 0.3, bidirectional=True,
attention=True, bind_encoder_decoder_embed=False,
bind_decoder_input_output_embed=True):
"""

:param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: embedding of the source
:param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: embedding of the target; do not pass this value if
bind_encoder_decoder_embed is True
:param int num_layers: number of layers in the Encoder and the Decoder
:param int hidden_size: hidden size of the encoder and the decoder
:param float dropout: dropout applied between layers
:param bool bidirectional: whether the encoder uses a bidirectional LSTM
:param bool attention: whether the decoder attends over the encoder states at all time steps
:param bool bind_encoder_decoder_embed: whether the encoder and the decoder share the same embedding
:param bool bind_decoder_input_output_embed: whether the decoder output embedding shares weights with its input embedding
:return: LSTMSeq2SeqModel
"""
if bind_encoder_decoder_embed and tgt_embed is not None:
raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.")

src_embed = get_embeddings(src_embed)

if bind_encoder_decoder_embed:
tgt_embed = src_embed
else:
assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`"
tgt_embed = get_embeddings(tgt_embed)

encoder = LSTMSeq2SeqEncoder(embed=src_embed, num_layers = num_layers,
hidden_size = hidden_size, dropout = dropout, bidirectional=bidirectional)
decoder = LSTMSeq2SeqDecoder(embed=tgt_embed, num_layers = num_layers, hidden_size = hidden_size,
dropout = dropout, bind_decoder_input_output_embed = bind_decoder_input_output_embed,
attention=attention)
return cls(encoder, decoder)
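
An illustrative build_model sketch (not part of the diff; tuple embeddings and toy sizes are assumed, with the embedding dimension kept equal to d_model):

import torch
from fastNLP.models.torch import TransformerSeq2SeqModel

model = TransformerSeq2SeqModel.build_model(src_embed=(100, 64), tgt_embed=(120, 64),
                                            pos_embed='sin', num_layers=2, d_model=64,
                                            n_head=4, dim_ff=128)
src_tokens = torch.randint(3, 100, (2, 8))
tgt_tokens = torch.randint(3, 120, (2, 6))
src_seq_len = torch.LongTensor([8, 6])
tgt_seq_len = torch.LongTensor([6, 5])
loss = model.train_step(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)['loss']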

+ 271
- 0
fastNLP/models/torch/sequence_labeling.py View File

@@ -0,0 +1,271 @@
r"""
This module implements several sequence labeling models.
"""
__all__ = [
"SeqLabeling",
"AdvSeqLabel",
"BiLSTMCRF"
]

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core.utils import seq_len_to_mask
from ...embeddings.torch.utils import get_embeddings
from ...modules.torch.decoder import ConditionalRandomField
from ...modules.torch.encoder import LSTM
from ...modules.torch import decoder, encoder
from ...modules.torch.decoder.crf import allowed_transitions


class BiLSTMCRF(nn.Module):
r"""
The architecture is embedding + BiLSTM + FC + Dropout + CRF.

"""
def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5,
target_vocab=None):
r"""
:param embed: supports (1) any fastNLP Embedding, or (2) a tuple specifying num_embedding and dimension, e.g. (1000, 100)
:param num_classes: number of classes
:param num_layers: number of BiLSTM layers
:param hidden_size: hidden_size of the BiLSTM; the effective hidden size is twice this value (forward and backward)
:param dropout: dropout probability; 0 disables dropout
:param target_vocab: a Vocabulary object mapping targets to indices. If provided, illegal decoding sequences are automatically avoided.
"""
super().__init__()
self.embed = get_embeddings(embed)

if num_layers>1:
self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True,
batch_first=True, dropout=dropout)
else:
self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True,
batch_first=True)

self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(hidden_size*2, num_classes)

trans = None
if target_vocab is not None:
assert len(target_vocab)==num_classes, "The number of classes should be same with the length of target vocabulary."
trans = allowed_transitions(target_vocab.idx2word, include_start_end=True)

self.crf = ConditionalRandomField(num_classes, include_start_end_trans=True, allowed_transitions=trans)

def forward(self, words, seq_len=None, target=None):
words = self.embed(words)
feats, _ = self.lstm(words, seq_len=seq_len)
feats = self.fc(feats)
feats = self.dropout(feats)
logits = F.log_softmax(feats, dim=-1)
mask = seq_len_to_mask(seq_len)
if target is None:
pred, _ = self.crf.viterbi_decode(logits, mask)
return {'pred':pred}
else:
loss = self.crf(logits, target, mask).mean()
return {'loss':loss}

def train_step(self, words, seq_len, target):
return self(words, seq_len, target)

def evaluate_step(self, words, seq_len):
return self(words, seq_len)


class SeqLabeling(nn.Module):
r"""
A basic sequence labeling model.
Base class for sequence labeling. The architecture consists of one Embedding layer, one LSTM layer (unidirectional, single layer), one FC layer, and one CRF layer.
"""
def __init__(self, embed, hidden_size, num_classes):
r"""
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: the embedding. If a tuple(int, int) is passed,
the first int is the vocabulary size and the second is the embedding dimension; if a Tensor, embedding or ndarray is passed, it is used directly to initialize the Embedding
:param int hidden_size: hidden size of the LSTM
:param int num_classes: number of classes
"""
super(SeqLabeling, self).__init__()
self.embedding = get_embeddings(embed)
self.rnn = encoder.LSTM(self.embedding.embedding_dim, hidden_size)
self.fc = nn.Linear(hidden_size, num_classes)
self.crf = decoder.ConditionalRandomField(num_classes)

def forward(self, words, seq_len):
r"""
:param torch.LongTensor words: [batch_size, max_len], token indices of the sequences
:param torch.LongTensor seq_len: [batch_size,], lengths of the sequences
:return: {'pred': [batch_size, max_len, num_classes]} prediction logits
"""
x = self.embedding(words)
# [batch_size, max_len, word_emb_dim]
x, _ = self.rnn(x, seq_len)
# [batch_size, max_len, hidden_size * direction]
x = self.fc(x)
return {'pred': x}
# [batch_size, max_len, num_classes]

def train_step(self, words, seq_len, target):
res = self(words, seq_len)
pred = res['pred']
mask = seq_len_to_mask(seq_len, max_len=target.size(1))
return {'loss': self._internal_loss(pred, target, mask)}

def evaluate_step(self, words, seq_len):
r"""
Used at prediction time.

:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size,]
:return: {'pred': xx}, [batch_size, max_len]
"""
mask = seq_len_to_mask(seq_len, max_len=words.size(1))

res = self(words, seq_len)
pred = res['pred']
# [batch_size, max_len, num_classes]
pred = self._decode(pred, mask)
return {'pred': pred}

def _internal_loss(self, x, y, mask):
r"""
Negative log likelihood loss.
:param x: Tensor, [batch_size, max_len, tag_size]
:param y: Tensor, [batch_size, max_len]
:return loss: a scalar Tensor

"""
x = x.float()
y = y.long()
total_loss = self.crf(x, y, mask)
return torch.mean(total_loss)
def _decode(self, x, mask):
r"""
:param torch.FloatTensor x: [batch_size, max_len, tag_size]
:return prediction: [batch_size, max_len]
"""
tag_seq, _ = self.crf.viterbi_decode(x, mask)
return tag_seq


class AdvSeqLabel(nn.Module):
r"""
A more elaborate sequence labelling model. The architecture is Embedding, LayerNorm, bidirectional LSTM (two layers), FC, LayerNorm, DropOut, FC, CRF.
"""
def __init__(self, embed, hidden_size, num_classes, dropout=0.3, id2words=None, encoding_type='bmes'):
r"""
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: the embedding. If a tuple(int, int) is passed,
the first int is the vocabulary size and the second is the embedding dimension; if a Tensor, Embedding or ndarray is passed, it is used directly to initialize the Embedding
:param int hidden_size: hidden size of the LSTM
:param int num_classes: number of classes
:param float dropout: dropout probability used in the LSTM and in the DropOut layer
:param dict id2words: mapping from tag id to tag word, used during CRF decoding to prevent illegal tag sequences; for example,
in the 'BMES' scheme 'S' may not follow 'B'. Tags of the form 'B-NN', where the part before '-' indicates the tag type, are also
supported; in that case 'B-NN' is not only prevented from being followed by 'S-NN', but also from being followed by any 'M-xx' other than 'M-NN' or 'E-NN'.
:param str encoding_type: one of "BIO", "BMES", "BEMSO"; only used when id2words is not None.
"""
super().__init__()
self.Embedding = get_embeddings(embed)
self.norm1 = torch.nn.LayerNorm(self.Embedding.embedding_dim)
self.Rnn = encoder.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size, num_layers=2,
dropout=dropout,
bidirectional=True, batch_first=True)
self.Linear1 = nn.Linear(hidden_size * 2, hidden_size * 2 // 3)
self.norm2 = torch.nn.LayerNorm(hidden_size * 2 // 3)
self.relu = torch.nn.LeakyReLU()
self.drop = torch.nn.Dropout(dropout)
self.Linear2 = nn.Linear(hidden_size * 2 // 3, num_classes)
if id2words is None:
self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False)
else:
self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False,
allowed_transitions=allowed_transitions(id2words,
encoding_type=encoding_type))
def _decode(self, x, mask):
r"""
:param torch.FloatTensor x: [batch_size, max_len, tag_size]
:param torch.ByteTensor mask: [batch_size, max_len]
:return torch.LongTensor, [batch_size, max_len]
"""
tag_seq, _ = self.Crf.viterbi_decode(x, mask)
return tag_seq
def _internal_loss(self, x, y, mask):
r"""
Negative log likelihood loss.
:param x: Tensor, [batch_size, max_len, tag_size]
:param y: Tensor, [batch_size, max_len]
:param mask: Tensor, [batch_size, max_len]
:return loss: a scalar Tensor

"""
x = x.float()
y = y.long()
total_loss = self.Crf(x, y, mask)
return torch.mean(total_loss)
def forward(self, words, seq_len, target=None):
r"""
:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size, ]
:param torch.LongTensor target: [batch_size, max_len]
:return y: If target is None, return the decoded tag paths; used in testing and prediction.
If target is not None, return the loss, a scalar; used in training.
"""
words = words.long()
seq_len = seq_len.long()
mask = seq_len_to_mask(seq_len, max_len=words.size(1))

target = target.long() if target is not None else None
if next(self.parameters()).is_cuda:
words = words.cuda()

x = self.Embedding(words)
x = self.norm1(x)
# [batch_size, max_len, word_emb_dim]
x, _ = self.Rnn(x, seq_len=seq_len)
x = self.Linear1(x)
x = self.norm2(x)
x = self.relu(x)
x = self.drop(x)
x = self.Linear2(x)
if target is not None:
return {"loss": self._internal_loss(x, target, mask)}
else:
return {"pred": self._decode(x, mask)}
def train_step(self, words, seq_len, target):
r"""
:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size, ]
:param torch.LongTensor target: [batch_size, max_len], the gold tags
:return torch.Tensor: a scalar loss
"""
return self(words, seq_len, target)
def evaluate_step(self, words, seq_len):
r"""
:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size, ]
:return torch.LongTensor: [batch_size, max_len]
"""
return self(words, seq_len)
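
A minimal, illustrative BiLSTMCRF usage sketch (not part of the diff; toy vocabulary/tag sizes and the tuple embedding form are assumed):

import torch
from fastNLP.models.torch import BiLSTMCRF

tagger = BiLSTMCRF(embed=(100, 30), num_classes=5, hidden_size=50)
words = torch.randint(1, 100, (2, 9))
seq_len = torch.LongTensor([9, 7])
target = torch.randint(0, 5, (2, 9))
loss = tagger.train_step(words, seq_len, target)['loss']
tags = tagger.evaluate_step(words, seq_len)['pred']   # [2, 9] Viterbi-decoded tag ids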

+ 26
- 0
fastNLP/modules/torch/__init__.py View File

@@ -0,0 +1,26 @@
__all__ = [
'ConditionalRandomField',
'allowed_transitions',
"State",
"Seq2SeqDecoder",
"LSTMSeq2SeqDecoder",
"TransformerSeq2SeqDecoder",

"LSTM",
"Seq2SeqEncoder",
"TransformerSeq2SeqEncoder",
"LSTMSeq2SeqEncoder",
"StarTransformer",
"VarRNN",
"VarLSTM",
"VarGRU",

'SequenceGenerator',

"TimestepDropout",
]

from .decoder import *
from .encoder import *
from .generator import *
from .dropout import TimestepDropout
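
An illustrative import check mirroring __all__ above (not part of the diff):

from fastNLP.modules.torch import (ConditionalRandomField, allowed_transitions, LSTM,
                                   TransformerSeq2SeqEncoder, LSTMSeq2SeqDecoder,
                                   SequenceGenerator, TimestepDropout)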

+ 321
- 0
fastNLP/modules/torch/attention.py View File

@@ -0,0 +1,321 @@
r"""undocumented"""

__all__ = [
"MultiHeadAttention",
"BiAttention",
"SelfAttention",
]

import math

import torch
import torch.nn.functional as F
from torch import nn

from .decoder.seq2seq_state import TransformerState


class DotAttention(nn.Module):
r"""
The scaled dot-product attention used in the Transformer
"""

def __init__(self, key_size, value_size, dropout=0.0):
super(DotAttention, self).__init__()
self.key_size = key_size
self.value_size = value_size
self.scale = math.sqrt(key_size)
self.drop = nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1)

def forward(self, Q, K, V, mask_out=None):
r"""

:param Q: [..., seq_len_q, key_size]
:param K: [..., seq_len_k, key_size]
:param V: [..., seq_len_k, value_size]
:param mask_out: [..., 1, seq_len] or [..., seq_len_q, seq_len_k]
"""
output = torch.matmul(Q, K.transpose(-1, -2)) / self.scale
if mask_out is not None:
output.masked_fill_(mask_out, -1e9)
output = self.softmax(output)
output = self.drop(output)
return torch.matmul(output, V)


class MultiHeadAttention(nn.Module):
"""
The multi-head attention described in 'Attention Is All You Need'

"""
def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None):
super(MultiHeadAttention, self).__init__()
self.d_model = d_model
self.n_head = n_head
self.dropout = dropout
self.head_dim = d_model // n_head
self.layer_idx = layer_idx
assert d_model % n_head == 0, "d_model should be divisible by n_head"
self.scaling = self.head_dim ** -0.5

self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)

self.reset_parameters()

def forward(self, query, key, value, key_mask=None, attn_mask=None, state=None):
"""

:param query: batch x seq x dim
:param key: batch x seq x dim
:param value: batch x seq x dim
:param key_mask: batch x seq, indicates which keys should not be attended to; note that positions where the mask is 1 are the ones to attend to
:param attn_mask: seq x seq, used to mask out the attention map. Mainly used for the decoder-side self attention during training, with the lower triangle set to 1
:param state: past information used at inference time, e.g. the encoder output and the decoder's previous key/value, to reduce computation.
:return:
"""
assert key.size() == value.size()
if state is not None:
assert self.layer_idx is not None
qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()

q = self.q_proj(query) # batch x seq x dim
q *= self.scaling
k = v = None
prev_k = prev_v = None

# fetch cached key/value from the state
if isinstance(state, TransformerState): # this means we are in the inference stage
if qkv_same: # decoder self-attention
prev_k = state.decoder_prev_key[self.layer_idx]
prev_v = state.decoder_prev_value[self.layer_idx]
else: # decoder-encoder attention: simply load the cached encoder keys/values
k = state.encoder_key[self.layer_idx]
v = state.encoder_value[self.layer_idx]

if k is None:
k = self.k_proj(key)
v = self.v_proj(value)

if prev_k is not None:
k = torch.cat((prev_k, k), dim=1)
v = torch.cat((prev_v, v), dim=1)

# update the state
if isinstance(state, TransformerState):
if qkv_same:
state.decoder_prev_key[self.layer_idx] = k
state.decoder_prev_value[self.layer_idx] = v
else:
state.encoder_key[self.layer_idx] = k
state.encoder_value[self.layer_idx] = v

# compute the attention
batch_size, q_len, d_model = query.size()
k_len, v_len = k.size(1), v.size(1)
q = q.reshape(batch_size, q_len, self.n_head, self.head_dim)
k = k.reshape(batch_size, k_len, self.n_head, self.head_dim)
v = v.reshape(batch_size, v_len, self.n_head, self.head_dim)

attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k) # bs,q_len,k_len,n_head
if key_mask is not None:
_key_mask = ~key_mask[:, None, :, None].bool() # batch,1,k_len,1
attn_weights = attn_weights.masked_fill(_key_mask, -float('inf'))

if attn_mask is not None:
_attn_mask = attn_mask[None, :, :, None].eq(0) # 1,q_len,k_len,n_head
attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf'))

attn_weights = F.softmax(attn_weights, dim=2)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v) # batch,q_len,n_head,head_dim
output = output.reshape(batch_size, q_len, -1)
output = self.out_proj(output) # batch,q_len,dim

return output, attn_weights

def reset_parameters(self):
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)

def set_layer_idx(self, layer_idx):
self.layer_idx = layer_idx


class AttentionLayer(nn.Module):
def __init__(self, input_size, key_dim, value_dim, bias=False):
"""
Attention used during the decoding step of an LSTM-to-LSTM sequence-to-sequence model; at each decoding step it computes
attention over the encoder outputs based on the hidden state of the previous step.

:param int input_size: size of the input
:param int key_dim: usually the dimension of encoder_output
:param int value_dim: size of the output, usually the decoder hidden size
:param bias: whether the projections use a bias term
"""
super().__init__()

self.input_proj = nn.Linear(input_size, key_dim, bias=bias)
self.output_proj = nn.Linear(input_size + key_dim, value_dim, bias=bias)

def forward(self, input, encode_outputs, encode_mask):
"""

:param input: batch_size x input_size
:param encode_outputs: batch_size x max_len x key_dim
:param encode_mask: batch_size x max_len, positions with value 0 are padding
:return: hidden: batch_size x value_dim, scores: batch_size x max_len, normalized
"""

# x: bsz x encode_hidden_size
x = self.input_proj(input)

# compute attention
attn_scores = torch.matmul(encode_outputs, x.unsqueeze(-1)).squeeze(-1) # b x max_len

# don't attend over padding
if encode_mask is not None:
attn_scores = attn_scores.float().masked_fill_(
encode_mask.eq(0),
float('-inf')
).type_as(attn_scores) # FP16 support: cast to float and back

attn_scores = F.softmax(attn_scores, dim=-1) # bsz x max_len

# sum weighted sources
x = torch.matmul(attn_scores.unsqueeze(1), encode_outputs).squeeze(1) # b x encode_hidden_size

x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1)))
return x, attn_scores


def _masked_softmax(tensor, mask):
tensor_shape = tensor.size()
reshaped_tensor = tensor.view(-1, tensor_shape[-1])

# Reshape the mask so it matches the size of the input tensor.
while mask.dim() < tensor.dim():
mask = mask.unsqueeze(1)
mask = mask.expand_as(tensor).contiguous().float()
reshaped_mask = mask.view(-1, mask.size()[-1])
result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
result = result * reshaped_mask
# 1e-13 is added to avoid divisions by zero.
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
return result.view(*tensor_shape)


def _weighted_sum(tensor, weights, mask):
w_sum = weights.bmm(tensor)
while mask.dim() < w_sum.dim():
mask = mask.unsqueeze(1)
mask = mask.transpose(-1, -2)
mask = mask.expand_as(w_sum).contiguous().float()
return w_sum * mask


class BiAttention(nn.Module):
r"""
Bi Attention module

Given two vector sequences :math:`a_i` and :math:`b_j`, the BiAttention module computes the attention results with the following formulas

.. math::

\begin{array}{ll} \\
e_{ij} = {a}^{\mathrm{T}}_{i}{b}_{j} \\
{\hat{a}}_{i} = \sum_{j=1}^{\mathcal{l}_{b}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{b}}{\mathrm{exp}(e_{ik})}}}{b}_{j} \\
{\hat{b}}_{j} = \sum_{i=1}^{\mathcal{l}_{a}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{a}}{\mathrm{exp}(e_{ik})}}}{a}_{i} \\
\end{array}

"""

def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
r"""
:param torch.Tensor premise_batch: [batch_size, a_seq_len, hidden_size]
:param torch.Tensor premise_mask: [batch_size, a_seq_len]
:param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size]
:param torch.Tensor hypothesis_mask: [batch_size, b_seq_len]
:return: torch.Tensor attended_premises: [batch_size, a_seq_len, hidden_size] torch.Tensor attended_hypotheses: [batch_size, b_seq_len, hidden_size]
"""
similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
.contiguous())

prem_hyp_attn = _masked_softmax(similarity_matrix, hypothesis_mask)
hyp_prem_attn = _masked_softmax(similarity_matrix.transpose(1, 2)
.contiguous(),
premise_mask)

attended_premises = _weighted_sum(hypothesis_batch,
prem_hyp_attn,
premise_mask)
attended_hypotheses = _weighted_sum(premise_batch,
hyp_prem_attn,
hypothesis_mask)

return attended_premises, attended_hypotheses


class SelfAttention(nn.Module):
r"""
A Self-Attention module based on the paper `A structured self-attentive sentence embedding <https://arxiv.org/pdf/1703.03130.pdf>`_ .
"""

def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5):
r"""
:param int input_size: hidden dimension of the input tensor
:param int attention_unit: dimension of the intermediate attention projection
:param int attention_hops: number of attention hops (heads)
:param float drop: dropout probability, default 0.5
"""
super(SelfAttention, self).__init__()

self.attention_hops = attention_hops
self.ws1 = nn.Linear(input_size, attention_unit, bias=False)
self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False)
self.I = torch.eye(attention_hops, requires_grad=False)
self.I_origin = self.I
self.drop = nn.Dropout(drop)
self.tanh = nn.Tanh()

def _penalization(self, attention):
r"""
compute the penalization term for attention module
"""
baz = attention.size(0)
size = self.I.size()
if len(size) != 3 or size[0] != baz:
self.I = self.I_origin.expand(baz, -1, -1)
self.I = self.I.to(device=attention.device)
attention_t = torch.transpose(attention, 1, 2).contiguous()
mat = torch.bmm(attention, attention_t) - self.I[:attention.size(0)]
ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5
return torch.sum(ret) / size[0]

def forward(self, input, input_origin):
r"""
:param torch.Tensor input: [batch_size, seq_len, hidden_size] the matrix to attend over
:param torch.Tensor input_origin: [batch_size, seq_len] matrix of the original token indices, including padding
:return torch.Tensor output1: [batch_size, multi-head, hidden_size] the input matrix after the attention operation
:return torch.Tensor output2: [1] the attention penalization term, a scalar
"""
input = input.contiguous()
size = input.size() # [bsz, len, nhid]

input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops,baz, len]
input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len]

y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit]
attention = self.ws2(y1).transpose(1, 2).contiguous()
# [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]

attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token.
attention = F.softmax(attention, 2) # [baz ,hop, len]
return torch.bmm(attention, input), self._penalization(attention) # output1 --> [baz ,hop ,nhid]
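
An illustrative shape check for the MultiHeadAttention added above (not part of the diff; no State/incremental decoding is involved and all sizes are made up):

import torch
from fastNLP.modules.torch.attention import MultiHeadAttention

attn = MultiHeadAttention(d_model=64, n_head=4, dropout=0.1)
x = torch.randn(2, 5, 64)                        # batch x seq x dim
key_mask = torch.ones(2, 5, dtype=torch.long)    # 1 = attend to this position
out, weights = attn(x, x, x, key_mask=key_mask)  # out: [2, 5, 64], weights: [2, 5, 5, 4]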

+ 15
- 0
fastNLP/modules/torch/decoder/__init__.py View File

@@ -0,0 +1,15 @@

__all__ = [
'ConditionalRandomField',
'allowed_transitions',

"State",

"Seq2SeqDecoder",
"LSTMSeq2SeqDecoder",
"TransformerSeq2SeqDecoder"
]

from .crf import ConditionalRandomField, allowed_transitions
from .seq2seq_state import State
from .seq2seq_decoder import LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder, Seq2SeqDecoder

+ 354
- 0
fastNLP/modules/torch/decoder/crf.py View File

@@ -0,0 +1,354 @@
r"""undocumented"""

__all__ = [
"ConditionalRandomField",
"allowed_transitions"
]

from typing import Union, List

import torch
from torch import nn

from ....core.metrics.span_f1_pre_rec_metric import _get_encoding_type_from_tag_vocab, _check_tag_vocab_and_encoding_type
from ....core.vocabulary import Vocabulary


def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type:str=None, include_start_end:bool=False):
r"""
Given a mapping from id to label, return the list of all allowed transitions (from_tag_id, to_tag_id).

:param tag_vocab: supports plain tags such as "B", "M", or tag-label pairs such as "B-NN", "M-NN",
where tag and label must be separated by "-". If a dict is passed, it must look like {0:"O", 1:"B-tag1"}, i.e. index first, tag second.
:param encoding_type: one of ``["bio", "bmes", "bmeso", "bioes"]``. Defaults to None, in which case it is inferred from the vocab
:param include_start_end: whether to include transitions from the start and to the end. For example, in bio, b/o may appear at the
beginning but i may not; if True, the result contains (start_idx, b_idx) and (start_idx, o_idx) but not (start_idx, i_idx), where
start_idx=len(id2label) and end_idx=len(id2label)+1. If False, no start/end-related transitions are included
:return: List[Tuple(int, int)], where each inner Tuple is an allowed transition (from_tag_id, to_tag_id).
"""
if encoding_type is None:
encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab)
else:
encoding_type = encoding_type.lower()
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)

pad_token = '<pad>'
unk_token = '<unk>'

if isinstance(tag_vocab, Vocabulary):
id_label_lst = list(tag_vocab.idx2word.items())
pad_token = tag_vocab.padding
unk_token = tag_vocab.unknown
else:
id_label_lst = list(tag_vocab.items())

num_tags = len(tag_vocab)
start_idx = num_tags
end_idx = num_tags + 1
allowed_trans = []
if include_start_end:
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')]
def split_tag_label(from_label):
from_label = from_label.lower()
if from_label in ['start', 'end']:
from_tag = from_label
from_label = ''
else:
from_tag = from_label[:1]
from_label = from_label[2:]
return from_tag, from_label

for from_id, from_label in id_label_lst:
if from_label in [pad_token, unk_token]:
continue
from_tag, from_label = split_tag_label(from_label)
for to_id, to_label in id_label_lst:
if to_label in [pad_token, unk_token]:
continue
to_tag, to_label = split_tag_label(to_label)
if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
allowed_trans.append((from_id, to_id))
return allowed_trans


def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
r"""

:param str encoding_type: one of "BIO", "BMES", "BEMSO", 'bioes'.
:param str from_tag: a tag such as "B" or "M"; also includes the two special tags start and end
:param str from_label: a label such as "PER" or "LOC"
:param str to_tag: a tag such as "B" or "M"; also includes the two special tags start and end
:param str to_label: a label such as "PER" or "LOC"
:return: bool, whether the transition is allowed
"""
if to_tag == 'start' or from_tag == 'end':
return False
encoding_type = encoding_type.lower()
if encoding_type == 'bio':
r"""
The first row is to_tag and the first column is from_tag. y: allowed under any condition, -: allowed only when the labels match, n: not allowed
+-------+---+---+---+-------+-----+
| | B | I | O | start | end |
+-------+---+---+---+-------+-----+
| B | y | - | y | n | y |
+-------+---+---+---+-------+-----+
| I | y | - | y | n | y |
+-------+---+---+---+-------+-----+
| O | y | n | y | n | y |
+-------+---+---+---+-------+-----+
| start | y | n | y | n | n |
+-------+---+---+---+-------+-----+
| end | n | n | n | n | n |
+-------+---+---+---+-------+-----+
"""
if from_tag == 'start':
return to_tag in ('b', 'o')
elif from_tag in ['b', 'i']:
return any([to_tag in ['end', 'b', 'o'], to_tag == 'i' and from_label == to_label])
elif from_tag == 'o':
return to_tag in ['end', 'b', 'o']
else:
raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag))

elif encoding_type == 'bmes':
r"""
The first row is to_tag and the first column is from_tag. y: allowed under any condition, -: allowed only when the labels match, n: not allowed
+-------+---+---+---+---+-------+-----+
| | B | M | E | S | start | end |
+-------+---+---+---+---+-------+-----+
| B | n | - | - | n | n | n |
+-------+---+---+---+---+-------+-----+
| M | n | - | - | n | n | n |
+-------+---+---+---+---+-------+-----+
| E | y | n | n | y | n | y |
+-------+---+---+---+---+-------+-----+
| S | y | n | n | y | n | y |
+-------+---+---+---+---+-------+-----+
| start | y | n | n | y | n | n |
+-------+---+---+---+---+-------+-----+
| end | n | n | n | n | n | n |
+-------+---+---+---+---+-------+-----+
"""
if from_tag == 'start':
return to_tag in ['b', 's']
elif from_tag == 'b':
return to_tag in ['m', 'e'] and from_label == to_label
elif from_tag == 'm':
return to_tag in ['m', 'e'] and from_label == to_label
elif from_tag in ['e', 's']:
return to_tag in ['b', 's', 'end']
else:
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag))
elif encoding_type == 'bmeso':
if from_tag == 'start':
return to_tag in ['b', 's', 'o']
elif from_tag == 'b':
return to_tag in ['m', 'e'] and from_label == to_label
elif from_tag == 'm':
return to_tag in ['m', 'e'] and from_label == to_label
elif from_tag in ['e', 's', 'o']:
return to_tag in ['b', 's', 'end', 'o']
else:
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag))
elif encoding_type == 'bioes':
if from_tag == 'start':
return to_tag in ['b', 's', 'o']
elif from_tag == 'b':
return to_tag in ['i', 'e'] and from_label == to_label
elif from_tag == 'i':
return to_tag in ['i', 'e'] and from_label == to_label
elif from_tag in ['e', 's', 'o']:
return to_tag in ['b', 's', 'end', 'o']
else:
raise ValueError("Unexpect tag type {}. Expect only 'B', 'I', 'E', 'S', 'O'.".format(from_tag))
else:
raise ValueError("Only support BIO, BMES, BMESO, BIOES encoding type, got {}.".format(encoding_type))


class ConditionalRandomField(nn.Module):
r"""
Conditional random field. Provides two methods, forward() and viterbi_decode(), used for training and inference respectively.

"""

def __init__(self, num_tags:int, include_start_end_trans:bool=False, allowed_transitions:List=None):
r"""
:param num_tags: number of tags
:param include_start_end_trans: whether to include scores for each tag appearing at the start or at the end.
:param allowed_transitions: each inner Tuple[from_tag_id(int), to_tag_id(int)] is treated as an allowed transition,
and all transitions not included are treated as forbidden; the list can be obtained from the
allowed_transitions() function. If None, all transitions are allowed
"""
super(ConditionalRandomField, self).__init__()

self.include_start_end_trans = include_start_end_trans
self.num_tags = num_tags

# the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags))
if self.include_start_end_trans:
self.start_scores = nn.Parameter(torch.randn(num_tags))
self.end_scores = nn.Parameter(torch.randn(num_tags))

if allowed_transitions is None:
constrain = torch.zeros(num_tags + 2, num_tags + 2)
else:
constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float)
has_start = False
has_end = False
for from_tag_id, to_tag_id in allowed_transitions:
constrain[from_tag_id, to_tag_id] = 0
if from_tag_id==num_tags:
has_start = True
if to_tag_id==num_tags+1:
has_end = True
if not has_start:
constrain[num_tags, :].fill_(0)
if not has_end:
constrain[:, num_tags+1].fill_(0)
self._constrain = nn.Parameter(constrain, requires_grad=False)

def _normalizer_likelihood(self, logits, mask):
r"""Computes the (batch_size,) denominator term for the log-likelihood, which is the
sum of the likelihoods across all possible state sequences.

:param logits:FloatTensor, max_len x batch_size x num_tags
:param mask:ByteTensor, max_len x batch_size
:return:FloatTensor, batch_size
"""
seq_len, batch_size, n_tags = logits.size()
alpha = logits[0]
if self.include_start_end_trans:
alpha = alpha + self.start_scores.view(1, -1)

flip_mask = mask.eq(False)

for i in range(1, seq_len):
emit_score = logits[i].view(batch_size, 1, n_tags)
trans_score = self.trans_m.view(1, n_tags, n_tags)
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
alpha.masked_fill(mask[i].eq(True).view(batch_size, 1), 0)

if self.include_start_end_trans:
alpha = alpha + self.end_scores.view(1, -1)

return torch.logsumexp(alpha, 1)

def _gold_score(self, logits, tags, mask):
r"""
Compute the score for the gold path.
:param logits: FloatTensor, max_len x batch_size x num_tags
:param tags: LongTensor, max_len x batch_size
:param mask: ByteTensor, max_len x batch_size
:return:FloatTensor, batch_size
"""
seq_len, batch_size, _ = logits.size()
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

# trans_score [L-1, B]
mask = mask.eq(True)
flip_mask = mask.eq(False)
trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
# emit_score [L, B]
emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
# score [L-1, B]
score = trans_score + emit_score[:seq_len - 1, :]
score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0)
if self.include_start_end_trans:
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
last_idx = mask.long().sum(0) - 1
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
score = score + st_scores + ed_scores
# return [B,]
return score

def forward(self, feats, tags, mask):
r"""
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。

:param torch.FloatTensor feats: batch_size x max_len x num_tags,特征矩阵。
:param torch.LongTensor tags: batch_size x max_len,标签矩阵。
:param torch.ByteTensor mask: batch_size x max_len,为0的位置认为是padding。
:return: torch.FloatTensor, (batch_size,)
"""
feats = feats.transpose(0, 1)
tags = tags.transpose(0, 1).long()
mask = mask.transpose(0, 1).float()
all_path_score = self._normalizer_likelihood(feats, mask)
gold_path_score = self._gold_score(feats, tags, mask)

return all_path_score - gold_path_score

def viterbi_decode(self, logits, mask, unpad=False):
r"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数

:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这
个sample的有效长度。
:return: 返回 (paths, scores)。
paths: 是解码后的路径, 其值参照unpad参数.
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。

"""
batch_size, max_len, n_tags = logits.size()
seq_len = mask.long().sum(1)
logits = logits.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.eq(True) # L, B
flip_mask = mask.eq(False)

# dp
vpath = logits.new_zeros((max_len, batch_size, n_tags), dtype=torch.long)
vscore = logits[0] # bsz x n_tags
transitions = self._constrain.data.clone()
transitions[:n_tags, :n_tags] += self.trans_m.data
if self.include_start_end_trans:
transitions[n_tags, :n_tags] += self.start_scores.data
transitions[:n_tags, n_tags + 1] += self.end_scores.data

vscore += transitions[n_tags, :n_tags]

trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
end_trans_score = transitions[:n_tags, n_tags+1].view(1, 1, n_tags).repeat(batch_size, 1, 1) # bsz, 1, n_tags

# 针对长度为1的句子
vscore += transitions[:n_tags, n_tags+1].view(1, n_tags).repeat(batch_size, 1) \
.masked_fill(seq_len.ne(1).view(-1, 1), 0)
for i in range(1, max_len):
prev_score = vscore.view(batch_size, n_tags, 1)
cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score
score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0) # bsz x n_tag x n_tag
# 需要考虑当前位置是该序列的最后一个
score += end_trans_score.masked_fill(seq_len.ne(i+1).view(-1, 1, 1), 0)

best_score, best_dst = score.max(1)
vpath[i] = best_dst
# 由于最终是通过last_tags回溯,需要保持每个位置的vscore情况
vscore = best_score.masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)

# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(max_len, dtype=torch.long, device=logits.device)
lens = (seq_len - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % max_len

ans = logits.new_empty((max_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
for i in range(max_len - 1):
last_tags = vpath[idxes[i], batch_idx, last_tags]
ans[idxes[i + 1], batch_idx] = last_tags
ans = ans.transpose(0, 1)
if unpad:
paths = []
for idx, max_len in enumerate(lens):
paths.append(ans[idx, :max_len + 1].tolist())
else:
paths = ans
return paths, ans_score
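

# Usage sketch (added for illustration; not part of the original commit). It exercises the
# CRF exactly as documented above: ``forward`` returns a per-sample loss and
# ``viterbi_decode`` returns (paths, scores). All sizes below are arbitrary.
def _example_crf_usage():
    crf = ConditionalRandomField(num_tags=5, include_start_end_trans=True)
    feats = torch.randn(2, 7, 5)                     # batch_size x max_len x num_tags
    tags = torch.randint(0, 5, (2, 7))               # batch_size x max_len
    mask = torch.ones(2, 7, dtype=torch.bool)
    mask[1, 5:] = False                              # the second sample only has length 5
    loss = crf(feats, tags, mask).mean()             # scalar, ready for loss.backward()
    paths, scores = crf.viterbi_decode(feats, mask, unpad=True)
    return loss, paths, scores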

+ 416
- 0
fastNLP/modules/torch/decoder/seq2seq_decoder.py View File

@@ -0,0 +1,416 @@
r"""undocumented"""
from typing import Union, Tuple
import math

import torch
from torch import nn
import torch.nn.functional as F
from ..attention import AttentionLayer, MultiHeadAttention
from ....embeddings.torch.utils import get_embeddings
from ....embeddings.torch.static_embedding import StaticEmbedding
from .seq2seq_state import State, LSTMState, TransformerState


__all__ = ['Seq2SeqDecoder', 'TransformerSeq2SeqDecoder', 'LSTMSeq2SeqDecoder']


class Seq2SeqDecoder(nn.Module):
"""
Sequence-to-Sequence Decoder的基类。一定需要实现forward、decode函数,剩下的函数根据需要实现。每个Seq2SeqDecoder都应该有相应的State对象
用来承载该Decoder所需要的Encoder输出、Decoder需要记录的历史信息(例如LSTM的hidden信息)。

"""
def __init__(self):
super().__init__()

def forward(self, tokens, state, **kwargs):
"""

:param torch.LongTensor tokens: bsz x max_len
:param State state: state包含了encoder的输出以及decode之前的内容
:return: 返回值可以为bsz x max_len x vocab_size的Tensor,也可以是一个list,但是第一个元素必须是词的预测分布
"""
raise NotImplementedError

def reorder_states(self, indices, states):
"""
根据indices重新排列states中的状态,在beam search进行生成时,会用到该函数。

:param torch.LongTensor indices:
:param State states:
:return:
"""
assert isinstance(states, State), f"`states` should be of type State instead of {type(states)}"
states.reorder_state(indices)

def init_state(self, encoder_output, encoder_mask):
"""
初始化一个state对象,用来记录了encoder的输出以及decode已经完成的部分。

:param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch
维度
:param Union[torch.Tensor, list, tuple] encoder_mask: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch
维度
:param kwargs:
:return: State, 返回一个State对象,记录了encoder的输出
"""
state = State(encoder_output, encoder_mask)
return state

def decode(self, tokens, state):
"""
根据states中的内容,以及tokens中的内容进行之后的生成。

:param torch.LongTensor tokens: bsz x max_len, 截止到上一个时刻所有的token输出。
:param State state: 记录了encoder输出与decoder过去状态
:return: torch.FloatTensor: bsz x vocab_size, 输出的是下一个时刻的分布
"""
outputs = self(state=state, tokens=tokens)
if isinstance(outputs, torch.Tensor):
return outputs[:, -1]
else:
raise RuntimeError("Unrecognized output from the `forward()` function. Please override the `decode()` function.")


class TiedEmbedding(nn.Module):
"""
用于将输出层的投影权重与输入embedding的权重绑定(共享同一份参数)

"""
def __init__(self, weight):
super().__init__()
self.weight = weight # vocab_size x embed_size

def forward(self, x):
"""

:param torch.FloatTensor x: bsz x * x embed_size
:return: torch.FloatTensor bsz x * x vocab_size
"""
return torch.matmul(x, self.weight.t())


def get_bind_decoder_output_embed(embed):
"""
给定一个embedding,输出对应的绑定的embedding,输出对象为TiedEmbedding

:param embed:
:return:
"""
if isinstance(embed, StaticEmbedding):
for idx, map2idx in enumerate(embed.words_to_words):
assert idx == map2idx, "Invalid StaticEmbedding for Decoder, please check:(1) whether the vocabulary " \
"include `no_create_entry=True` word; (2) StaticEmbedding should not initialize with " \
"`lower=True` or `min_freq!=1`."
elif not isinstance(embed, nn.Embedding):
raise TypeError("Only nn.Embedding or StaticEmbedding is allowed for binding.")

return TiedEmbedding(embed.weight)
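

# Illustration only (not part of the original commit): how the weight tying above is used.
# An nn.Embedding maps ids to vectors, and the returned TiedEmbedding reuses the same
# weight matrix to project hidden states back onto the vocabulary.
def _example_weight_tying():
    embed = nn.Embedding(num_embeddings=100, embedding_dim=16)
    output_layer = get_bind_decoder_output_embed(embed)   # shares embed.weight
    hidden = torch.randn(2, 5, 16)                         # bsz x seq_len x embed_size
    logits = output_layer(hidden)                          # bsz x seq_len x vocab_size
    assert logits.size() == (2, 5, 100)
    return logits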


class LSTMSeq2SeqDecoder(Seq2SeqDecoder):
"""
LSTM的Decoder

:param nn.Module,tuple embed: decoder输入的embedding.
:param int num_layers: 多少层LSTM
:param int hidden_size: 隐藏层大小, 该值也被认为是encoder的输出维度大小
:param dropout: Dropout的大小
:param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding,
则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1.
:param bool attention: 是否使用attention
"""
def __init__(self, embed: Union[nn.Module, Tuple[int, int]], num_layers = 3, hidden_size = 300,
dropout = 0.3, bind_decoder_input_output_embed = True, attention=True):
super().__init__()
self.embed = get_embeddings(init_embed=embed)
self.embed_dim = self.embed.embedding_dim

if bind_decoder_input_output_embed:
self.output_layer = get_bind_decoder_output_embed(self.embed)
else: # 不需要bind
self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim))
self.output_layer = TiedEmbedding(self.output_embed.weight)

self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size=self.embed_dim + hidden_size, hidden_size=hidden_size, num_layers=num_layers,
batch_first=True, bidirectional=False, dropout=dropout if num_layers>1 else 0)

self.attention_layer = AttentionLayer(hidden_size, hidden_size, hidden_size) if attention else None
self.output_proj = nn.Linear(hidden_size, self.embed_dim)
self.dropout_layer = nn.Dropout(dropout)

def forward(self, tokens, state, return_attention=False):
"""

:param torch.LongTensor tokens: batch x max_len
:param LSTMState state: 保存encoder输出和decode状态的State对象
:param bool return_attention: 是否返回attention的score
:return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length
"""
src_output = state.encoder_output
encoder_mask = state.encoder_mask

assert tokens.size(1)>state.decode_length, "The state does not match the tokens."
tokens = tokens[:, state.decode_length:]
x = self.embed(tokens)

attn_weights = [] if self.attention_layer is not None else None # 保存attention weight, batch,tgt_seq,src_seq
input_feed = state.input_feed
decoder_out = []

cur_hidden = state.hidden
cur_cell = state.cell

# 开始计算
for i in range(tokens.size(1)):
input = torch.cat(
(x[:, i:i + 1, :],
input_feed[:, None, :]
),
dim=2
) # batch,1,2*dim
_, (cur_hidden, cur_cell) = self.lstm(input, hx=(cur_hidden, cur_cell)) # hidden/cell保持原来的size
if self.attention_layer is not None:
input_feed, attn_weight = self.attention_layer(cur_hidden[-1], src_output, encoder_mask)
attn_weights.append(attn_weight)
else:
input_feed = cur_hidden[-1]

state.input_feed = input_feed # batch, hidden
state.hidden = cur_hidden
state.cell = cur_cell
state.decode_length += 1
decoder_out.append(input_feed)

decoder_out = torch.stack(decoder_out, dim=1) # batch,seq_len,hidden
decoder_out = self.dropout_layer(decoder_out)
if attn_weights is not None:
attn_weights = torch.cat(attn_weights, dim=1) # batch, tgt_len, src_len

decoder_out = self.output_proj(decoder_out)
feats = self.output_layer(decoder_out)

if return_attention:
return feats, attn_weights
return feats

def init_state(self, encoder_output, encoder_mask) -> LSTMState:
"""

:param encoder_output: 输入可以有两种情况(1) 输入为一个tuple,包含三个内容(encoder_output, (hidden, cell)),其中encoder_output:
bsz x max_len x hidden_size, hidden: bsz x hidden_size, cell:bsz x hidden_size,一般使用LSTMEncoder的最后一层的
hidden state和cell state来赋值这两个值
(2) 只有encoder_output: bsz x max_len x hidden_size, 这种情况下hidden和cell使用0初始化
:param torch.ByteTensor encoder_mask: bsz x max_len, 为0的位置是padding, 用来指示source中哪些不需要attend
:return:
"""
if not isinstance(encoder_output, torch.Tensor):
encoder_output, (hidden, cell) = encoder_output
else:
hidden = cell = None
assert encoder_output.ndim==3
assert encoder_mask.size()==encoder_output.size()[:2]
assert encoder_output.size(-1)==self.hidden_size, "The dimension of encoder outputs should be the same with " \
"the hidden_size."

t = [hidden, cell]
for idx in range(2):
v = t[idx]
if v is None:
v = encoder_output.new_zeros(self.num_layers, encoder_output.size(0), self.hidden_size)
else:
assert v.dim()==2
assert v.size(-1)==self.hidden_size
v = v[None].repeat(self.num_layers, 1, 1) # num_layers x bsz x hidden_size
t[idx] = v

state = LSTMState(encoder_output, encoder_mask, t[0], t[1])

return state
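

# Illustration only (not part of the original commit): one forward pass of the LSTM decoder.
# The encoder output and mask are random stand-ins; their last dimension must equal
# hidden_size, as asserted in init_state() above.
def _example_lstm_decoder():
    decoder = LSTMSeq2SeqDecoder(embed=nn.Embedding(100, 32), num_layers=2, hidden_size=64)
    encoder_output = torch.randn(2, 10, 64)               # bsz x src_len x hidden_size
    encoder_mask = torch.ones(2, 10, dtype=torch.bool)
    state = decoder.init_state(encoder_output, encoder_mask)
    tokens = torch.zeros(2, 1, dtype=torch.long)           # a single <bos>-like id per sample
    feats = decoder(tokens, state)                          # bsz x 1 x vocab_size
    assert feats.size() == (2, 1, 100)
    return feats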


class TransformerSeq2SeqDecoderLayer(nn.Module):
"""

:param int d_model: 输入、输出的维度
:param int n_head: 多少个head,需要能被d_model整除
:param int dim_ff:
:param float dropout:
:param int layer_idx: layer的编号
"""
def __init__(self, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1, layer_idx = None):
super().__init__()
self.d_model = d_model
self.n_head = n_head
self.dim_ff = dim_ff
self.dropout = dropout
self.layer_idx = layer_idx # 记录layer的层索引,以方便获取state的信息

self.self_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx)
self.self_attn_layer_norm = nn.LayerNorm(d_model)

self.encoder_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx)
self.encoder_attn_layer_norm = nn.LayerNorm(d_model)

self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(self.dim_ff, self.d_model),
nn.Dropout(dropout))

self.final_layer_norm = nn.LayerNorm(self.d_model)

def forward(self, x, encoder_output, encoder_mask=None, self_attn_mask=None, state=None):
"""

:param x: (batch, seq_len, dim), decoder端的输入
:param encoder_output: (batch,src_seq_len,dim), encoder的输出
:param encoder_mask: batch,src_seq_len, 为1的地方需要attend
:param self_attn_mask: seq_len, seq_len,下三角的mask矩阵,只在训练时传入
:param TransformerState state: 只在inference阶段传入
:return:
"""

# self attention part
residual = x
x = self.self_attn_layer_norm(x)
x, _ = self.self_attn(query=x,
key=x,
value=x,
attn_mask=self_attn_mask,
state=state)

x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x

# encoder attention part
residual = x
x = self.encoder_attn_layer_norm(x)
x, attn_weight = self.encoder_attn(query=x,
key=encoder_output,
value=encoder_output,
key_mask=encoder_mask,
state=state)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x

# ffn
residual = x
x = self.final_layer_norm(x)
x = self.ffn(x)
x = residual + x

return x, attn_weight


class TransformerSeq2SeqDecoder(Seq2SeqDecoder):
"""

:param embed: 输入token的embedding
:param nn.Module pos_embed: 位置embedding
:param int d_model: 输入、输出的大小
:param int num_layers: 多少层
:param int n_head: 多少个head
:param int dim_ff: FFN 的中间大小
:param float dropout: Self-Attention和FFN中的dropout的大小
:param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding,
则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1.
"""
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed: nn.Module = None,
d_model = 512, num_layers=6, n_head = 8, dim_ff = 2048, dropout = 0.1,
bind_decoder_input_output_embed = True):
super().__init__()

self.embed = get_embeddings(embed)
self.pos_embed = pos_embed

if bind_decoder_input_output_embed:
self.output_layer = get_bind_decoder_output_embed(self.embed)
else: # 不需要bind
self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim))
self.output_layer = TiedEmbedding(self.output_embed.weight)

self.num_layers = num_layers
self.d_model = d_model
self.n_head = n_head
self.dim_ff = dim_ff
self.dropout = dropout

self.input_fc = nn.Linear(self.embed.embedding_dim, d_model)
self.layer_stacks = nn.ModuleList([TransformerSeq2SeqDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx)
for layer_idx in range(num_layers)])

self.embed_scale = math.sqrt(d_model)
self.layer_norm = nn.LayerNorm(d_model)
self.output_fc = nn.Linear(self.d_model, self.embed.embedding_dim)

def forward(self, tokens, state, return_attention=False):
"""

:param torch.LongTensor tokens: batch x tgt_len,decode的词
:param TransformerState state: 用于记录encoder的输出以及decode状态的对象,可以通过init_state()获取
:param bool return_attention: 是否返回对encoder结果的attention score
:return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length
"""

encoder_output = state.encoder_output
encoder_mask = state.encoder_mask

assert state.decode_length<tokens.size(1), "The decoded tokens in State should be less than tokens."
tokens = tokens[:, state.decode_length:]
device = tokens.device

x = self.embed_scale * self.embed(tokens)
if self.pos_embed is not None:
position = torch.arange(state.decode_length, state.decode_length+tokens.size(1)).long().to(device)[None]
x += self.pos_embed(position)
x = self.input_fc(x)
x = F.dropout(x, p=self.dropout, training=self.training)
batch_size, max_tgt_len = tokens.size()

if max_tgt_len>1:
triangle_mask = self._get_triangle_mask(tokens)
else:
triangle_mask = None

for layer in self.layer_stacks:
x, attn_weight = layer(x=x,
encoder_output=encoder_output,
encoder_mask=encoder_mask,
self_attn_mask=triangle_mask,
state=state
)

x = self.layer_norm(x) # batch, tgt_len, dim
x = self.output_fc(x)
feats = self.output_layer(x)

if return_attention:
return feats, attn_weight
return feats

def init_state(self, encoder_output, encoder_mask):
"""
初始化一个TransformerState用于forward

:param torch.FloatTensor encoder_output: bsz x max_len x d_model, encoder的输出
:param torch.ByteTensor encoder_mask: bsz x max_len, 为1的位置需要attend。
:return: TransformerState
"""
if isinstance(encoder_output, (list, tuple)):
    encoder_output = encoder_output[0]  # 防止是LSTMEncoder的输出结果
elif not isinstance(encoder_output, torch.Tensor):
    raise TypeError("Unsupported `encoder_output` for TransformerSeq2SeqDecoder")
state = TransformerState(encoder_output, encoder_mask, num_decoder_layer=self.num_layers)
return state

@staticmethod
def _get_triangle_mask(tokens):
tensor = tokens.new_ones(tokens.size(1), tokens.size(1))
return torch.tril(tensor).byte()
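

# Illustration only (not part of the original commit): a single decoding step with the
# Transformer decoder, using init_state() and the decode() helper inherited from
# Seq2SeqDecoder. Sizes are arbitrary; d_model must match the encoder output dimension.
def _example_transformer_decoder():
    decoder = TransformerSeq2SeqDecoder(embed=nn.Embedding(100, 512), d_model=512,
                                        num_layers=2, n_head=8)
    encoder_output = torch.randn(2, 10, 512)              # bsz x src_len x d_model
    encoder_mask = torch.ones(2, 10, dtype=torch.bool)
    state = decoder.init_state(encoder_output, encoder_mask)
    tokens = torch.zeros(2, 1, dtype=torch.long)           # <bos>-like ids
    next_scores = decoder.decode(tokens, state)            # bsz x vocab_size (last position)
    assert next_scores.size() == (2, 100)
    return next_scores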



+ 145
- 0
fastNLP/modules/torch/decoder/seq2seq_state.py View File

@@ -0,0 +1,145 @@
r"""
每个Decoder都有对应的State用来记录encoder的输出以及Decode的历史记录

"""

__all__ = [
'State',
"LSTMState",
"TransformerState"
]

from typing import Union
import torch


class State:
def __init__(self, encoder_output=None, encoder_mask=None, **kwargs):
"""
每个Decoder都有对应的State对象用来承载encoder的输出以及当前时刻之前的decode状态。

:param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch
维度
:param Union[torch.Tensor, list, tuple] encoder_mask: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch
维度
:param kwargs:
"""
self.encoder_output = encoder_output
self.encoder_mask = encoder_mask
self._decode_length = 0

@property
def num_samples(self):
"""
返回的State中包含的是多少个sample的encoder状态,主要用于Generate的时候确定batch的大小。

:return:
"""
if self.encoder_output is not None:
return self.encoder_output.size(0)
else:
return None

@property
def decode_length(self):
"""
当前Decode到哪个token了,decoder只会从decode_length之后的token开始decode, 为0说明还没开始decode。

:return:
"""
return self._decode_length

@decode_length.setter
def decode_length(self, value):
self._decode_length = value

def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0):
if isinstance(state, torch.Tensor):
state = state.index_select(index=indices, dim=dim)
elif isinstance(state, list):
for i in range(len(state)):
assert state[i] is not None
state[i] = self._reorder_state(state[i], indices, dim)
elif isinstance(state, tuple):
tmp_list = []
for i in range(len(state)):
assert state[i] is not None
tmp_list.append(self._reorder_state(state[i], indices, dim))
state = tuple(tmp_list)
else:
raise TypeError(f"Cannot reorder data of type:{type(state)}")

return state

def reorder_state(self, indices: torch.LongTensor):
if self.encoder_mask is not None:
self.encoder_mask = self._reorder_state(self.encoder_mask, indices)
if self.encoder_output is not None:
self.encoder_output = self._reorder_state(self.encoder_output, indices)


class LSTMState(State):
def __init__(self, encoder_output, encoder_mask, hidden, cell):
"""
LSTMDecoder对应的State,保存encoder的输出以及LSTM解码过程中的一些中间状态

:param torch.FloatTensor encoder_output: bsz x src_seq_len x encode_output_size,encoder的输出
:param torch.BoolTensor encoder_mask: bsz x src_seq_len, 为0的地方是padding
:param torch.FloatTensor hidden: num_layers x bsz x hidden_size, 上个时刻的hidden状态
:param torch.FloatTensor cell: num_layers x bsz x hidden_size, 上个时刻的cell状态
"""
super().__init__(encoder_output, encoder_mask)
self.hidden = hidden
self.cell = cell
self._input_feed = hidden[0] # 默认是上一个时刻的输出

@property
def input_feed(self):
"""
LSTMDecoder中每个时刻的输入会把上个token的embedding和input_feed拼接起来输入到下个时刻,在LSTMDecoder不使用attention时,
input_feed即上个时刻的hidden state, 否则是attention layer的输出。
:return: torch.FloatTensor, bsz x hidden_size
"""
return self._input_feed

@input_feed.setter
def input_feed(self, value):
self._input_feed = value

def reorder_state(self, indices: torch.LongTensor):
super().reorder_state(indices)
self.hidden = self._reorder_state(self.hidden, indices, dim=1)
self.cell = self._reorder_state(self.cell, indices, dim=1)
if self.input_feed is not None:
self.input_feed = self._reorder_state(self.input_feed, indices, dim=0)


class TransformerState(State):
def __init__(self, encoder_output, encoder_mask, num_decoder_layer):
"""
与TransformerSeq2SeqDecoder对应的State,

:param torch.FloatTensor encoder_output: bsz x encode_max_len x encoder_output_size, encoder的输出
:param torch.ByteTensor encoder_mask: bsz x encode_max_len 为1的地方需要attend
:param int num_decoder_layer: decode有多少层
"""
super().__init__(encoder_output, encoder_mask)
self.encoder_key = [None] * num_decoder_layer # 每一个元素 bsz x encoder_max_len x key_dim
self.encoder_value = [None] * num_decoder_layer # 每一个元素 bsz x encoder_max_len x value_dim
self.decoder_prev_key = [None] * num_decoder_layer # 每一个元素 bsz x decode_length x key_dim
self.decoder_prev_value = [None] * num_decoder_layer  # 每一个元素 bsz x decode_length x value_dim

def reorder_state(self, indices: torch.LongTensor):
super().reorder_state(indices)
self.encoder_key = self._reorder_state(self.encoder_key, indices)
self.encoder_value = self._reorder_state(self.encoder_value, indices)
self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices)
self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices)

@property
def decode_length(self):
if self.decoder_prev_key[0] is not None:
return self.decoder_prev_key[0].size(1)
return 0
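

# Illustration only (not part of the original commit): how a generator reorders cached
# state during beam search -- index_select is applied to every stored tensor, here used
# to expand two samples into four beams.
def _example_reorder_state():
    state = State(encoder_output=torch.randn(2, 10, 8),
                  encoder_mask=torch.ones(2, 10, dtype=torch.bool))
    indices = torch.tensor([0, 0, 1, 1])       # beam indices pointing back into the batch
    state.reorder_state(indices)
    assert state.num_samples == 4
    return state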



+ 24
- 0
fastNLP/modules/torch/dropout.py View File

@@ -0,0 +1,24 @@
r"""undocumented"""

__all__ = [
"TimestepDropout"
]

import torch


class TimestepDropout(torch.nn.Dropout):
r"""
传入参数的shape为 ``(batch_size, num_timesteps, embedding_dim)``
使用同一个shape为 ``(batch_size, embedding_dim)`` 的mask在每个timestamp上做dropout。
"""
def forward(self, x):
dropout_mask = x.new_ones(x.shape[0], x.shape[-1])
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True)
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim]
if self.inplace:
x *= dropout_mask
return
else:
return x * dropout_mask
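

# Illustration only (not part of the original commit): the same dropout mask is shared by
# every timestep of a sample, unlike standard nn.Dropout which samples a new mask per element.
def _example_timestep_dropout():
    drop = TimestepDropout(p=0.5)
    x = torch.ones(4, 6, 8)                    # batch_size x num_timesteps x embedding_dim
    y = drop(x)                                # module is in training mode by default
    assert torch.equal(y[0, 0], y[0, -1])      # all timesteps of one sample share a mask
    return y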

+ 17
- 1
fastNLP/modules/torch/encoder/__init__.py View File

@@ -1,5 +1,21 @@
__all__ = [
"ConvMaxpool",

"LSTM",

"Seq2SeqEncoder",
"TransformerSeq2SeqEncoder",
"LSTMSeq2SeqEncoder",

"StarTransformer",

"VarRNN",
"VarLSTM",
"VarGRU"
]

from .lstm import LSTM
from .conv_maxpool import ConvMaxpool
from .seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
from .star_transformer import StarTransformer
from .variational_rnn import VarRNN, VarLSTM, VarGRU

+ 87
- 0
fastNLP/modules/torch/encoder/conv_maxpool.py View File

@@ -0,0 +1,87 @@
r"""undocumented"""

__all__ = [
"ConvMaxpool"
]
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvMaxpool(nn.Module):
r"""
集合了Convolution和Max-Pooling于一体的层。给定一个batch_size x max_len x input_size的输入,返回batch_size x
sum(output_channels) 大小的matrix。在内部,是先使用CNN给输入做卷积,然后经过activation激活层,再通过在长度(max_len)
这一维进行max_pooling。最后得到每个sample的一个向量表示。

"""

def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"):
r"""
:param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。
:param str activation: Convolution后的结果将通过该activation后再经过max-pooling。支持relu, sigmoid, tanh
"""
super(ConvMaxpool, self).__init__()

for kernel_size in (kernel_sizes if isinstance(kernel_sizes, (list, tuple)) else (kernel_sizes,)):
    assert kernel_size % 2 == 1, "kernel size has to be odd numbers."

# convolution
if isinstance(kernel_sizes, (list, tuple, int)):
if isinstance(kernel_sizes, int) and isinstance(out_channels, int):
out_channels = [out_channels]
kernel_sizes = [kernel_sizes]
elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)):
assert len(out_channels) == len(
kernel_sizes), "The number of out_channels should be equal to the number" \
" of kernel_sizes."
else:
raise ValueError("The type of out_channels and kernel_sizes should be the same.")

self.convs = nn.ModuleList([nn.Conv1d(
in_channels=in_channels,
out_channels=oc,
kernel_size=ks,
stride=1,
padding=ks // 2,
dilation=1,
groups=1,
bias=False)
for oc, ks in zip(out_channels, kernel_sizes)])

else:
raise Exception(
'Incorrect kernel sizes: should be list, tuple or int')

# activation function
if activation == 'relu':
self.activation = F.relu
elif activation == 'sigmoid':
self.activation = F.sigmoid
elif activation == 'tanh':
self.activation = F.tanh
else:
raise Exception(
"Undefined activation function: choose from: relu, tanh, sigmoid")

def forward(self, x, mask=None):
r"""

:param torch.FloatTensor x: batch_size x max_len x input_size, 一般是经过embedding后的值
:param mask: batch_size x max_len, pad的地方为0。不影响卷积运算,max-pool一定不会pool到pad为0的位置
:return:
"""
# [N,L,C] -> [N,C,L]
x = torch.transpose(x, 1, 2)
# convolution
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...]
if mask is not None:
mask = mask.unsqueeze(1) # B x 1 x L
xs = [x.masked_fill(mask.eq(False), float('-inf')) for x in xs]
# max-pooling
xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2)
for i in xs] # [[N, C], ...]
return torch.cat(xs, dim=-1) # [N, C]
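

# Illustration only (not part of the original commit): three kernel sizes produce
# sum(out_channels) features per sample after max-pooling over the length dimension.
def _example_conv_maxpool():
    conv_pool = ConvMaxpool(in_channels=32, out_channels=[16, 16, 16], kernel_sizes=[1, 3, 5])
    x = torch.randn(4, 20, 32)                 # batch_size x max_len x input_size
    mask = torch.ones(4, 20, dtype=torch.bool)
    mask[2, 15:] = False                       # the third sample is padded after position 15
    out = conv_pool(x, mask)
    assert out.size() == (4, 48)               # 16 + 16 + 16
    return out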

+ 193
- 0
fastNLP/modules/torch/encoder/seq2seq_encoder.py View File

@@ -0,0 +1,193 @@
r"""undocumented"""
import torch.nn as nn
import torch
from torch.nn import LayerNorm
import torch.nn.functional as F
from typing import Union, Tuple
from ....core.utils import seq_len_to_mask
import math
from .lstm import LSTM
from ..attention import MultiHeadAttention
from ....embeddings.torch import StaticEmbedding
from ....embeddings.torch.utils import get_embeddings


__all__ = ['Seq2SeqEncoder', 'TransformerSeq2SeqEncoder', 'LSTMSeq2SeqEncoder']


class Seq2SeqEncoder(nn.Module):
"""
所有Sequence2Sequence Encoder的基类。需要实现forward函数

"""
def __init__(self):
super().__init__()

def forward(self, tokens, seq_len):
"""

:param torch.LongTensor tokens: bsz x max_len, encoder的输入
:param torch.LongTensor seq_len: bsz
:return:
"""
raise NotImplementedError


class TransformerSeq2SeqEncoderLayer(nn.Module):
"""
Self-Attention的Layer,

:param int d_model: input和output的输出维度
:param int n_head: 多少个head,每个head的维度为d_model/n_head
:param int dim_ff: FFN的维度大小
:param float dropout: Self-attention和FFN的dropout大小,0表示不drop
"""
def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048,
dropout: float = 0.1):
super(TransformerSeq2SeqEncoderLayer, self).__init__()
self.d_model = d_model
self.n_head = n_head
self.dim_ff = dim_ff
self.dropout = dropout

self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
self.attn_layer_norm = LayerNorm(d_model)
self.ffn_layer_norm = LayerNorm(d_model)

self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(self.dim_ff, self.d_model),
nn.Dropout(dropout))

def forward(self, x, mask):
"""

:param x: batch x src_seq x d_model
:param mask: batch x src_seq,为0的地方为padding
:return:
"""
# attention
residual = x
x = self.attn_layer_norm(x)
x, _ = self.self_attn(query=x,
key=x,
value=x,
key_mask=mask)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x

# ffn
residual = x
x = self.ffn_layer_norm(x)
x = self.ffn(x)
x = residual + x

return x


class TransformerSeq2SeqEncoder(Seq2SeqEncoder):
"""
基于Transformer的Encoder

:param embed: encoder输入token的embedding
:param nn.Module pos_embed: position embedding
:param int num_layers: 多少层的encoder
:param int d_model: 输入输出的维度
:param int n_head: 多少个head
:param int dim_ff: FFN中间的维度大小
:param float dropout: Attention和FFN的dropout大小
"""
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed = None,
num_layers = 6, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1):
super(TransformerSeq2SeqEncoder, self).__init__()
self.embed = get_embeddings(embed)
self.embed_scale = math.sqrt(d_model)
self.pos_embed = pos_embed
self.num_layers = num_layers
self.d_model = d_model
self.n_head = n_head
self.dim_ff = dim_ff
self.dropout = dropout

self.input_fc = nn.Linear(self.embed.embedding_dim, d_model)
self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout)
for _ in range(num_layers)])
self.layer_norm = LayerNorm(d_model)

def forward(self, tokens, seq_len):
"""

:param tokens: batch x max_len
:param seq_len: [batch]
:return: bsz x max_len x d_model, bsz x max_len(为0的地方为padding)
"""
x = self.embed(tokens) * self.embed_scale # batch, seq, dim
batch_size, max_src_len, _ = x.size()
device = x.device
if self.pos_embed is not None:
position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device)
x += self.pos_embed(position)

x = self.input_fc(x)
x = F.dropout(x, p=self.dropout, training=self.training)

encoder_mask = seq_len_to_mask(seq_len, max_len=max_src_len)
encoder_mask = encoder_mask.to(device)

for layer in self.layer_stacks:
x = layer(x, encoder_mask)

x = self.layer_norm(x)

return x, encoder_mask


class LSTMSeq2SeqEncoder(Seq2SeqEncoder):
"""
LSTM的Encoder

:param embed: encoder的token embed
:param int num_layers: 多少层
:param int hidden_size: LSTM隐藏层、输出的大小
:param float dropout: LSTM层之间的Dropout是多少
:param bool bidirectional: 是否使用双向
"""
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3,
hidden_size = 400, dropout = 0.3, bidirectional=True):
super().__init__()
self.embed = get_embeddings(embed)
self.num_layers = num_layers
self.dropout = dropout
self.hidden_size = hidden_size
self.bidirectional = bidirectional
hidden_size = hidden_size//2 if bidirectional else hidden_size
self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size, bidirectional=bidirectional,
batch_first=True, dropout=dropout if num_layers>1 else 0, num_layers=num_layers)

def forward(self, tokens, seq_len):
"""

:param torch.LongTensor tokens: bsz x max_len
:param torch.LongTensor seq_len: bsz
:return: (output, (hidden, cell)), encoder_mask
output: bsz x max_len x hidden_size,
hidden,cell: batch_size x hidden_size, 最后一层的隐藏状态或cell状态
encoder_mask: bsz x max_len, 为0的地方是padding
"""
x = self.embed(tokens)
device = x.device
x, (final_hidden, final_cell) = self.lstm(x, seq_len)
encoder_mask = seq_len_to_mask(seq_len).to(device)

# x: batch,seq_len,dim; h/c: num_layers*2,batch,dim

if self.bidirectional:
final_hidden = self.concat_bidir(final_hidden) # 将双向的hidden state拼接起来,用于接下来的decoder的input
final_cell = self.concat_bidir(final_cell)

return (x, (final_hidden[-1], final_cell[-1])), encoder_mask # 为了配合Seq2SeqBaseModel的forward,这边需要分为两个return

def concat_bidir(self, input):
output = input.view(self.num_layers, 2, input.size(1), -1).transpose(1, 2)
return output.reshape(self.num_layers, input.size(1), -1)
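

# Illustration only (not part of the original commit): encode a padded batch; the returned
# tuple has exactly the layout expected by LSTMSeq2SeqDecoder.init_state() in seq2seq_decoder.py.
def _example_lstm_encoder():
    encoder = LSTMSeq2SeqEncoder(embed=nn.Embedding(100, 32), num_layers=2,
                                 hidden_size=64, bidirectional=True)
    tokens = torch.randint(0, 100, (2, 10))              # bsz x max_len
    seq_len = torch.tensor([10, 7])                      # true lengths, the rest is padding
    (output, (hidden, cell)), encoder_mask = encoder(tokens, seq_len)
    assert output.size() == (2, 10, 64)                  # both directions concatenated
    assert hidden.size() == (2, 64) and cell.size() == (2, 64)
    return output, encoder_mask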

+ 166
- 0
fastNLP/modules/torch/encoder/star_transformer.py View File

@@ -0,0 +1,166 @@
r"""undocumented
Star-Transformer 的encoder部分的 Pytorch 实现
"""

__all__ = [
"StarTransformer"
]

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


class StarTransformer(nn.Module):
r"""
Star-Transformer 的encoder部分。 输入3d的文本输入, 返回相同长度的文本编码

paper: https://arxiv.org/abs/1902.09113

"""

def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None):
r"""
:param int hidden_size: 输入维度的大小。同时也是输出维度的大小。
:param int num_layers: star-transformer的层数
:param int num_head: head的数量。
:param int head_dim: 每个head的维度大小。
:param float dropout: dropout 概率. Default: 0.1
:param int max_len: int or None, 如果为int,输入序列的最大长度,
模型会为输入序列加上position embedding。
若为`None`,忽略加上position embedding的步骤. Default: `None`
"""
super(StarTransformer, self).__init__()
self.iters = num_layers

self.norm = nn.ModuleList([nn.LayerNorm(hidden_size, eps=1e-6) for _ in range(self.iters)])
# self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1)
self.emb_drop = nn.Dropout(dropout)
self.ring_att = nn.ModuleList(
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
for _ in range(self.iters)])
self.star_att = nn.ModuleList(
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
for _ in range(self.iters)])

if max_len is not None:
self.pos_emb = nn.Embedding(max_len, hidden_size)
else:
self.pos_emb = None

def forward(self, data, mask):
r"""
:param FloatTensor data: [batch, length, hidden] 输入的序列
:param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0,
否则为 1
:return: [batch, length, hidden] 编码后的输出序列

[batch, hidden] 全局 relay 节点, 详见论文
"""

def norm_func(f, x):
# B, H, L, 1
return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

B, L, H = data.size()
mask = (mask.eq(False)) # flip the mask for masked_fill_
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)

embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1
if self.pos_emb:
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1
embs = embs + P
embs = norm_func(self.emb_drop, embs)
nodes = embs
relay = embs.mean(2, keepdim=True)
ex_mask = mask[:, None, :, None].expand(B, H, L, 1)
r_embs = embs.view(B, H, 1, L)
for i in range(self.iters):
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax))
# nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax))
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask))

nodes = nodes.masked_fill_(ex_mask, 0)

nodes = nodes.view(B, H, L).permute(0, 2, 1)

return nodes, relay.view(B, H)


class _MSA1(nn.Module):
def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
super(_MSA1, self).__init__()
# Multi-head Self Attention Case 1, doing self-attention for small regions
# Due to the architecture of GPU, using hadamard production and summation are faster than dot production when unfold_size is very small
self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)

self.drop = nn.Dropout(dropout)

# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3

def forward(self, x, ax=None):
# x: B, H, L, 1, ax : B, H, X, L append features
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
B, H, L, _ = x.shape

q, k, v = self.WQ(x), self.WK(x), self.WV(x) # x: (B,H,L,1)

if ax is not None:
aL = ax.shape[2]
ak = self.WK(ax).view(B, nhead, head_dim, aL, L)
av = self.WV(ax).view(B, nhead, head_dim, aL, L)
q = q.view(B, nhead, head_dim, 1, L)
k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \
.view(B, nhead, head_dim, unfold_size, L)
v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \
.view(B, nhead, head_dim, unfold_size, L)
if ax is not None:
k = torch.cat([k, ak], 3)
v = torch.cat([v, av], 3)

alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / np.sqrt(head_dim), 3)) # B N L 1 U
att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1)

ret = self.WO(att)

return ret


class _MSA2(nn.Module):
def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
# Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value
super(_MSA2, self).__init__()
self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)

self.drop = nn.Dropout(dropout)

# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3

def forward(self, x, y, mask=None):
# x: B, H, 1, 1, 1 y: B H L 1
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
B, H, L, _ = y.shape

q, k, v = self.WQ(x), self.WK(y), self.WV(y)

q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h
k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L
v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h
pre_a = torch.matmul(q, k) / np.sqrt(head_dim)
if mask is not None:
pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf'))
alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L
att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1
return self.WO(att)
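

# Illustration only (not part of the original commit): encode a padded batch with the
# Star-Transformer; it returns per-token representations plus the global relay node.
def _example_star_transformer():
    encoder = StarTransformer(hidden_size=64, num_layers=2, num_head=8, head_dim=8, max_len=50)
    data = torch.randn(4, 20, 64)              # batch x length x hidden
    mask = torch.ones(4, 20, dtype=torch.bool)
    mask[1, 12:] = False                       # the second sample is padded after position 12
    nodes, relay = encoder(data, mask)
    assert nodes.size() == (4, 20, 64) and relay.size() == (4, 64)
    return nodes, relay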

+ 43
- 0
fastNLP/modules/torch/encoder/transformer.py View File

@@ -0,0 +1,43 @@
r"""undocumented"""

__all__ = [
"TransformerEncoder"
]

from torch import nn

from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer


class TransformerEncoder(nn.Module):
r"""
transformer的encoder模块,不包含embedding层

"""
def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1):
"""

:param int num_layers: 多少层Transformer
:param int d_model: input和output的大小
:param int n_head: 多少个head
:param int dim_ff: FFN中间hidden大小
:param float dropout: 多大概率drop attention和ffn中间的表示
"""
super(TransformerEncoder, self).__init__()
self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff,
dropout = dropout) for _ in range(num_layers)])
self.norm = nn.LayerNorm(d_model, eps=1e-6)

def forward(self, x, seq_mask=None):
r"""
:param x: [batch, seq_len, model_size] 输入序列
:param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. 为1的地方需要attend
Default: ``None``
:return: [batch, seq_len, model_size] 输出序列
"""
output = x
if seq_mask is None:
seq_mask = x.new_ones(x.size(0), x.size(1)).bool()
for layer in self.layers:
output = layer(output, seq_mask)
return self.norm(output)
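

# Illustration only (not part of the original commit): run the embedding-free encoder on a
# batch of already-embedded inputs; d_model here is an arbitrary small value.
def _example_transformer_encoder():
    import torch
    encoder = TransformerEncoder(num_layers=2, d_model=64, n_head=8, dim_ff=256, dropout=0.1)
    x = torch.randn(4, 20, 64)                 # batch x seq_len x model_size
    seq_mask = torch.ones(4, 20, dtype=torch.bool)
    seq_mask[0, 15:] = False                   # mask out the padding of the first sample
    out = encoder(x, seq_mask)
    assert out.size() == (4, 20, 64)
    return out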

+ 303
- 0
fastNLP/modules/torch/encoder/variational_rnn.py View File

@@ -0,0 +1,303 @@
r"""undocumented
Variational RNN 及相关模型的 fastNLP实现,相关论文参考:
`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_
"""

__all__ = [
"VarRNN",
"VarLSTM",
"VarGRU"
]

import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

try:
from torch import flip
except ImportError:
def flip(x, dims):
indices = [slice(None)] * x.dim()
for dim in dims:
indices[dim] = torch.arange(
x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
return x[tuple(indices)]


class VarRnnCellWrapper(nn.Module):
r"""
Wrapper for normal RNN Cells, make it support variational dropout
"""

def __init__(self, cell, hidden_size, input_p, hidden_p):
super(VarRnnCellWrapper, self).__init__()
self.cell = cell
self.hidden_size = hidden_size
self.input_p = input_p
self.hidden_p = hidden_p

def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False):
r"""
:param PackedSequence input_x: [seq_len, batch_size, input_size]
:param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size]
for other RNN, h_0, [batch_size, hidden_size]
:param mask_x: [batch_size, input_size] dropout mask for input
:param mask_h: [batch_size, hidden_size] dropout mask for hidden
:return PackedSequence output: [seq_len, batch_size, hidden_size]
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
for other RNN, h_n, [batch_size, hidden_size]
"""

def get_hi(hi, h0, size):
h0_size = size - hi.size(0)
if h0_size > 0:
return torch.cat([hi, h0[:h0_size]], dim=0)
return hi[:size]

is_lstm = isinstance(hidden, tuple)
input, batch_sizes = input_x.data, input_x.batch_sizes
output = []
cell = self.cell
if is_reversed:
batch_iter = flip(batch_sizes, [0])
idx = input.size(0)
else:
batch_iter = batch_sizes
idx = 0

if is_lstm:
hn = (hidden[0].clone(), hidden[1].clone())
else:
hn = hidden.clone()
hi = hidden
for size in batch_iter:
if is_reversed:
input_i = input[idx - size: idx] * mask_x[:size]
idx -= size
else:
input_i = input[idx: idx + size] * mask_x[:size]
idx += size
mask_hi = mask_h[:size]
if is_lstm:
hx, cx = hi
hi = (get_hi(hx, hidden[0], size) *
mask_hi, get_hi(cx, hidden[1], size))
hi = cell(input_i, hi)
hn[0][:size] = hi[0]
hn[1][:size] = hi[1]
output.append(hi[0])
else:
hi = get_hi(hi, hidden, size) * mask_hi
hi = cell(input_i, hi)
hn[:size] = hi
output.append(hi)

if is_reversed:
output = list(reversed(output))
output = torch.cat(output, dim=0)
return PackedSequence(output, batch_sizes), hn


class VarRNNBase(nn.Module):
r"""
Variational Dropout RNN 实现.

论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016)
https://arxiv.org/abs/1512.05287`.

"""

def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
bias=True, batch_first=False,
input_dropout=0, hidden_dropout=0, bidirectional=False):
r"""
:param mode: rnn 模式, (lstm or not)
:param Cell: rnn cell 类型, (lstm, gru, etc)
:param input_size: 输入 `x` 的特征维度
:param hidden_size: 隐状态 `h` 的特征维度
:param num_layers: rnn的层数. Default: 1
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True``
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为
(batch, seq, feature). Default: ``False``
:param input_dropout: 对输入的dropout概率. Default: 0
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False``
"""
super(VarRNNBase, self).__init__()
self.mode = mode
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.input_dropout = input_dropout
self.hidden_dropout = hidden_dropout
self.bidirectional = bidirectional
self.num_directions = 2 if bidirectional else 1
self._all_cells = nn.ModuleList()
for layer in range(self.num_layers):
for direction in range(self.num_directions):
input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
cell = Cell(input_size, self.hidden_size, bias)
self._all_cells.append(VarRnnCellWrapper(
cell, self.hidden_size, input_dropout, hidden_dropout))
self.is_lstm = (self.mode == "LSTM")

def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h):
is_lstm = self.is_lstm
idx = self.num_directions * n_layer + n_direction
cell = self._all_cells[idx]
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx]
output_x, hidden_x = cell(
input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
return output_x, hidden_x

def forward(self, x, hx=None):
r"""

:param x: [batch, seq_len, input_size] 输入序列
:param hx: [num_layers*num_directions, batch, hidden_size] 初始隐状态, 若为 ``None``, 设为全0向量. Default: ``None``
:return (output, ht): [batch, seq_len, hidden_size*num_directions] 输出序列
和 [num_layers*num_directions, batch, hidden_size] 最后时刻隐状态
"""
is_lstm = self.is_lstm
is_packed = isinstance(x, PackedSequence)
if not is_packed:
seq_len = x.size(1) if self.batch_first else x.size(0)
max_batch_size = x.size(0) if self.batch_first else x.size(1)
seq_lens = torch.LongTensor(
[seq_len for _ in range(max_batch_size)])
x = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first)
else:
max_batch_size = int(x.batch_sizes[0])
x, batch_sizes = x.data, x.batch_sizes

if hx is None:
hx = x.new_zeros(self.num_layers * self.num_directions,
max_batch_size, self.hidden_size, requires_grad=True)
if is_lstm:
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True))

mask_x = x.new_ones((max_batch_size, self.input_size))
mask_out = x.new_ones(
(max_batch_size, self.hidden_size * self.num_directions))
mask_h_ones = x.new_ones((max_batch_size, self.hidden_size))
nn.functional.dropout(mask_x, p=self.input_dropout,
training=self.training, inplace=True)
nn.functional.dropout(mask_out, p=self.hidden_dropout,
training=self.training, inplace=True)

hidden = x.new_zeros(
(self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
if is_lstm:
cellstate = x.new_zeros(
(self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
for layer in range(self.num_layers):
output_list = []
input_seq = PackedSequence(x, batch_sizes)
mask_h = nn.functional.dropout(
mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False)
for direction in range(self.num_directions):
output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx,
mask_x if layer == 0 else mask_out, mask_h)
output_list.append(output_x.data)
idx = self.num_directions * layer + direction
if is_lstm:
hidden[idx] = hidden_x[0]
cellstate[idx] = hidden_x[1]
else:
hidden[idx] = hidden_x
x = torch.cat(output_list, dim=-1)

if is_lstm:
hidden = (hidden, cellstate)

if is_packed:
output = PackedSequence(x, batch_sizes)
else:
x = PackedSequence(x, batch_sizes)
output, _ = pad_packed_sequence(x, batch_first=self.batch_first)

return output, hidden


class VarLSTM(VarRNNBase):
r"""
Variational Dropout LSTM.
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_

"""

def __init__(self, *args, **kwargs):
r"""
:param input_size: 输入 `x` 的特征维度
:param hidden_size: 隐状态 `h` 的特征维度
:param num_layers: rnn的层数. Default: 1
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True``
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为
(batch, seq, feature). Default: ``False``
:param input_dropout: 对输入的dropout概率. Default: 0
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的LSTM. Default: ``False``
"""
super(VarLSTM, self).__init__(
mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)

def forward(self, x, hx=None):
return super(VarLSTM, self).forward(x, hx)


class VarRNN(VarRNNBase):
r"""
Variational Dropout RNN.
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_
"""

def __init__(self, *args, **kwargs):
r"""
:param input_size: 输入 `x` 的特征维度
:param hidden_size: 隐状态 `h` 的特征维度
:param num_layers: rnn的层数. Default: 1
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True``
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为
(batch, seq, feature). Default: ``False``
:param input_dropout: 对输入的dropout概率. Default: 0
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False``
"""
super(VarRNN, self).__init__(
mode="RNN", Cell=nn.RNNCell, *args, **kwargs)

def forward(self, x, hx=None):
return super(VarRNN, self).forward(x, hx)


class VarGRU(VarRNNBase):
r"""
Variational Dropout GRU.
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_
"""

def __init__(self, *args, **kwargs):
r"""
:param input_size: 输入 `x` 的特征维度
:param hidden_size: 隐状态 `h` 的特征维度
:param num_layers: rnn的层数. Default: 1
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True``
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为
(batch, seq, feature). Default: ``False``
:param input_dropout: 对输入的dropout概率. Default: 0
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的GRU. Default: ``False``
"""
super(VarGRU, self).__init__(
mode="GRU", Cell=nn.GRUCell, *args, **kwargs)

def forward(self, x, hx=None):
return super(VarGRU, self).forward(x, hx)
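

# Illustration only (not part of the original commit): VarLSTM keeps one dropout mask per
# sequence (variational dropout) instead of resampling it at every timestep.
def _example_var_lstm():
    lstm = VarLSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True,
                   input_dropout=0.3, hidden_dropout=0.3, bidirectional=True)
    x = torch.randn(4, 20, 32)                 # batch x seq_len x input_size
    output, (h_n, c_n) = lstm(x)
    assert output.size() == (4, 20, 128)       # hidden_size * num_directions
    assert h_n.size() == (4, 4, 64)            # num_layers*num_directions x batch x hidden_size
    return output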

+ 6
- 0
fastNLP/modules/torch/generator/__init__.py View File

@@ -0,0 +1,6 @@
__all__ = [
'SequenceGenerator'
]


from .seq2seq_generator import SequenceGenerator

+ 536
- 0
fastNLP/modules/torch/generator/seq2seq_generator.py View File

@@ -0,0 +1,536 @@
r"""

"""

__all__ = [
'SequenceGenerator'
]

import torch
from torch import nn
import torch.nn.functional as F
from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State
from functools import partial


def _get_model_device(model):
r"""
传入一个nn.Module的模型,获取它所在的device

:param model: nn.Module
:return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。
"""
assert isinstance(model, nn.Module)

parameters = list(model.parameters())
if len(parameters) == 0:
return None
else:
return parameters[0].device



class SequenceGenerator:
"""
给定一个Seq2SeqDecoder,decode出句子。输入的decoder对象需要有decode()函数, 接受的第一个参数为到目前为止已经decode出的所有输出,
第二个参数为state。SequenceGenerator不会对state进行任何操作。

"""
def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1,
do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None,
repetition_penalty=1, length_penalty=1.0, pad_token_id=0):
"""

:param Seq2SeqDecoder decoder: Decoder对象
:param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len
:param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask
:param int num_beams: beam search的大小
:param bool do_sample: 是否通过采样的方式生成
:param float temperature: 只有在do_sample为True才有意义
:param int top_k: 只从top_k中采样
:param float top_p: 只从累积概率不超过top_p的token中采样,即nucleus sampling
:param int,None bos_token_id: 句子开头的token id
:param int,None eos_token_id: 句子结束的token id
:param float repetition_penalty: 多大程度上惩罚重复的token
:param float length_penalty: 对长度的惩罚,小于1鼓励长句,大于1鼓励短句
:param int pad_token_id: 当某句话生成结束之后,之后生成的内容用pad_token_id补充
"""
if do_sample:
self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a,
num_beams=num_beams,
temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id,
eos_token_id=eos_token_id, repetition_penalty=repetition_penalty,
length_penalty=length_penalty, pad_token_id=pad_token_id)
else:
self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a,
num_beams=num_beams,
bos_token_id=bos_token_id, eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty, pad_token_id=pad_token_id)
self.do_sample = do_sample
self.max_length = max_length
self.num_beams = num_beams
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.repetition_penalty = repetition_penalty
self.length_penalty = length_penalty
self.decoder = decoder

@torch.no_grad()
def generate(self, state, tokens=None):
"""

:param State state: encoder结果的State, 是与Decoder配套使用的
:param torch.LongTensor,None tokens: batch_size x length, 开始的token。如果为None,则默认添加bos_token作为开头的token
进行生成。
:return: bsz x max_length' 生成的token序列。如果eos_token_id不为None, 每个sequence的结尾一定是eos_token_id
"""

return self.generate_func(tokens=tokens, state=state)
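

# Usage sketch (added for illustration; not part of the original commit): wire the generator
# to an LSTM decoder with random weights. The bos/eos/pad ids below are arbitrary assumptions.
def _example_sequence_generator():
    from ..decoder.seq2seq_decoder import LSTMSeq2SeqDecoder
    decoder = LSTMSeq2SeqDecoder(embed=nn.Embedding(100, 32), num_layers=1, hidden_size=64)
    encoder_output = torch.randn(2, 10, 64)               # bsz x src_len x hidden_size
    encoder_mask = torch.ones(2, 10, dtype=torch.bool)
    state = decoder.init_state(encoder_output, encoder_mask)
    generator = SequenceGenerator(decoder, max_length=15, num_beams=1, do_sample=False,
                                  bos_token_id=1, eos_token_id=2, pad_token_id=0)
    token_ids = generator.generate(state)                  # bsz x generated_len, starts with bos
    return token_ids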


@torch.no_grad()
def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1,
bos_token_id=None, eos_token_id=None, pad_token_id=0,
repetition_penalty=1, length_penalty=1.0):
"""
贪婪地搜索句子

:param Decoder decoder: Decoder对象
:param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成
:param State state: 应该包含encoder的一些输出。
:param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len
:param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask
:param int num_beams: 使用多大的beam进行解码。
:param int bos_token_id: 如果tokens传入为None,则使用bos_token_id开始往后解码。
:param int eos_token_id: 结束的token,如果为None,则一定会解码到max_length这么长。
:param int pad_token_id: pad的token id
:param float repetition_penalty: 对重复出现的token多大的惩罚。
:param float length_penalty: 对每个token(除了eos)按照长度进行一定的惩罚。
:return:
"""
if num_beams == 1:
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
temperature=1, top_k=50, top_p=1,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
repetition_penalty=repetition_penalty, length_penalty=length_penalty,
pad_token_id=pad_token_id)
else:
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
num_beams=num_beams, temperature=1, top_k=50, top_p=1,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
repetition_penalty=repetition_penalty, length_penalty=length_penalty,
pad_token_id=pad_token_id)

return token_ids


@torch.no_grad()
def sample_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1, temperature=1.0, top_k=50,
top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0,
length_penalty=1.0):
"""
使用采样的方法生成句子

:param Decoder decoder: Decoder对象
:param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成
:param State state: 应该包含encoder的一些输出。
:param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len
:param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask
:param int num_beams: 使用多大的beam进行解码。
:param float temperature: 采样时的退火大小
:param int top_k: 只在top_k的sample里面采样
:param float top_p: 介于0,1的值。
:param int bos_token_id: 如果tokens传入为None,则使用bos_token_id开始往后解码。
:param int eos_token_id: 结束的token,如果为None,则一定会解码到max_length这么长。
:param int pad_token_id: pad的token id
:param float repetition_penalty: 对重复出现的token多大的惩罚。
:param float length_penalty: 对每个token(除了eos)按照长度进行一定的惩罚。
:return:
"""
# 每个位置在生成的时候会sample生成
if num_beams == 1:
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
temperature=temperature, top_k=top_k, top_p=top_p,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
repetition_penalty=repetition_penalty, length_penalty=length_penalty,
pad_token_id=pad_token_id)
else:
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
repetition_penalty=repetition_penalty, length_penalty=length_penalty,
pad_token_id=pad_token_id)
return token_ids


def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, max_len_a=0.0, temperature=1.0, top_k=50,
top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0):
device = _get_model_device(decoder)
if tokens is None:
if bos_token_id is None:
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
batch_size = state.num_samples
if batch_size is None:
raise RuntimeError("Cannot infer the number of samples from `state`.")
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
batch_size = tokens.size(0)
if state.num_samples:
assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match."

if eos_token_id is None:
_eos_token_id = -1
else:
_eos_token_id = eos_token_id

    scores = decoder.decode(tokens=tokens, state=state)  # mainly to update the state
    if _eos_token_id != -1:  # make sure the first generated token is not eos
scores[:, _eos_token_id] = -1e12
next_tokens = scores.argmax(dim=-1, keepdim=True)
token_ids = torch.cat([tokens, next_tokens], dim=1)
cur_len = token_ids.size(1)
dones = token_ids.new_zeros(batch_size).eq(1)
# tokens = tokens[:, -1:]

if max_len_a!=0:
# (bsz x num_beams, )
if state.encoder_mask is not None:
max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
else:
max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long)
real_max_length = max_lengths.max().item()
else:
real_max_length = max_length
if state.encoder_mask is not None:
max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
else:
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)

while cur_len < real_max_length:
scores = decoder.decode(tokens=token_ids, state=state) # batch_size x vocab_size

if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
ge_zero_mask = lt_zero_mask.eq(0).float()
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
scores.scatter_(dim=1, index=token_ids, src=token_scores)
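            # Hedged numeric example (made-up values): with repetition_penalty=1.2, a
            # previously generated token currently scored 2.0 drops to 2.0/1.2 ~= 1.67,
            # while one scored -2.0 drops further to -2.0*1.2 = -2.4, so repeated
            # tokens become less likely in both cases.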

if eos_token_id is not None and length_penalty != 1.0:
token_scores = scores / cur_len ** length_penalty # batch_size x vocab_size
eos_mask = scores.new_ones(scores.size(1))
eos_mask[eos_token_id] = 0
eos_mask = eos_mask.unsqueeze(0).eq(1)
            scores = scores.masked_scatter(eos_mask, token_scores)  # i.e. all tokens except eos have their scores scaled up/down

if do_sample:
if temperature > 0 and temperature != 1:
scores = scores / temperature

scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
            # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
probs = F.softmax(scores, dim=-1) + 1e-12

next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size
else:
next_tokens = torch.argmax(scores, dim=-1) # batch_size

        # if a sample has reached its maximum length, fill this position with eos directly
        if _eos_token_id != -1:
            next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id)
        next_tokens = next_tokens.masked_fill(dones, pad_token_id)  # pad the samples whose search has already finished
tokens = next_tokens.unsqueeze(1)

token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len

end_mask = next_tokens.eq(_eos_token_id)
dones = dones.__or__(end_mask)
cur_len += 1

if dones.min() == 1:
break

    # if eos_token_id is not None:
    #     tokens.scatter(index=max_lengths[:, None], dim=1, value=eos_token_id)  # set the position at the maximum length to eos
    # if cur_len == max_length:
    #     token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # if the maximum length is reached without an EOS, force the last token to be eos
return token_ids


def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=4, temperature=1.0,
top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor:
    # run beam search
device = _get_model_device(decoder)
if tokens is None:
if bos_token_id is None:
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
batch_size = state.num_samples
if batch_size is None:
raise RuntimeError("Cannot infer the number of samples from `state`.")
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
batch_size = tokens.size(0)
if state.num_samples:
assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match."

if eos_token_id is None:
_eos_token_id = -1
else:
_eos_token_id = eos_token_id

    scores = decoder.decode(tokens=tokens, state=state)  # the whole sequence so far has to be passed in here
    if _eos_token_id != -1:  # make sure the first generated token is not eos
        scores[:, _eos_token_id] = -1e12
    vocab_size = scores.size(1)
    assert vocab_size >= num_beams, "num_beams should be no larger than the vocabulary size."

if do_sample:
probs = F.softmax(scores, dim=-1) + 1e-12
next_tokens = torch.multinomial(probs, num_samples=num_beams) # (batch_size, num_beams)
logits = probs.log()
next_scores = logits.gather(dim=1, index=next_tokens) # (batch_size, num_beams)
else:
scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size)
        # obtain (batch_size, num_beams), (batch_size, num_beams)
next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)

    # reorder according to the indices
indices = torch.arange(batch_size, dtype=torch.long).to(device)
indices = indices.repeat_interleave(num_beams)
state.reorder_state(indices)

tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length
    # record the generated tokens (batch_size', cur_len)
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
dones = [False] * batch_size

beam_scores = next_scores.view(-1) # batch_size * num_beams

    # tracks the length of the tokens generated so far
cur_len = token_ids.size(1)

if max_len_a!=0:
# (bsz x num_beams, )
if state.encoder_mask is not None:
max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
else:
max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long)
real_max_length = max_lengths.max().item()
else:
real_max_length = max_length
if state.encoder_mask is not None:
max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
else:
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
hypos = [
BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
]
# 0, num_beams, 2*num_beams, ...
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)

while cur_len < real_max_length:
scores = decoder.decode(token_ids, state) # (bsz x num_beams, vocab_size)
if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
ge_zero_mask = lt_zero_mask.eq(0).float()
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
scores.scatter_(dim=1, index=token_ids, src=token_scores)

if _eos_token_id!=-1:
max_len_eos_mask = max_lengths.eq(cur_len+1)
eos_scores = scores[:, _eos_token_id]
            # if the maximum length has been reached, boost the eos score
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e32, eos_scores)

if do_sample:
if temperature > 0 and temperature != 1:
scores = scores / temperature

            # recall one extra candidate to guard against eos
            scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1)
            # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
            probs = F.softmax(scores, dim=-1) + 1e-12

            # make sure at least one sampled token is not eos
            _tokens = torch.multinomial(probs, num_samples=num_beams + 1)  # batch_size' x (num_beams+1)

            logits = probs.log()
            # guard against all selected candidates coming from the same beam, and account for eos being sampled
            _scores = logits.gather(dim=1, index=_tokens)  # batch_size' x (num_beams+1)
            _scores = _scores + beam_scores[:, None]  # batch_size' x (num_beams+1)
            # select the top 2*num_beams among these
            _scores = _scores.view(batch_size, num_beams * (num_beams + 1))
next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
_tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
next_tokens = _tokens.gather(dim=1, index=ids) # (batch_size, 2*num_beams)
from_which_beam = ids // (num_beams + 1) # (batch_size, 2*num_beams)
else:
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None] # (batch_size * num_beams, vocab_size)
_scores = _scores.view(batch_size, -1) # (batch_size, num_beams*vocab_size)
next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) # (bsz, 2*num_beams)
from_which_beam = ids // vocab_size # (batch_size, 2*num_beams)
next_tokens = ids % vocab_size # (batch_size, 2*num_beams)

        # next, assemble the candidates for the following step.
        # decide which candidates to keep
        # next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
        # next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
        # from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)

        not_eos_mask = next_tokens.ne(_eos_token_id)  # positions with 1 are not eos
        keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)  # positions with 1 should be kept
        keep_mask = not_eos_mask.__and__(keep_mask)  # positions with 1 continue to the next search step
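        # Hedged illustration (made-up numbers, num_beams=2): if one sample's
        # next_tokens are [eos, 7, 5, eos], then not_eos_mask=[0,1,1,0], its
        # cumsum=[0,1,2,2] and keep_mask=[0,1,1,0], so tokens 7 and 5 continue the
        # search while the eos candidates are handed to BeamHypotheses below.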

_next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
        _from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)  # which beam each kept token came from
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
beam_scores = _next_scores.view(-1)

flag = True
if cur_len+1 == real_max_length:
eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
            eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)  # the beam indices
            eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)  # which beam each candidate was taken from
else:
            # mark sequences within the first num_beams of each batch that end here; positions with 1 are finished
effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams
if effective_eos_mask.sum().gt(0):
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True)
                # from_which_beam has shape (batch_size, 2*num_beams), hence the factor num_beams * 2
eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind
                eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]  # get the actual beam the eos came from
else:
flag = False

if flag:
_token_ids = torch.cat([token_ids, _next_tokens], dim=-1)
for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(),
eos_beam_idx.tolist()):
if not dones[batch_idx]:
score = next_scores[batch_idx, beam_ind].item()
                    # an eos will be appended at the end later
                    if _eos_token_id != -1:
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
else:
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score)

        # update the state and reorder token_ids
        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to one dimension
        state.reorder_state(reorder_inds)
        # rebuild token_ids in the new order
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1)

for batch_idx in range(batch_size):
dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \
max_lengths[batch_idx*num_beams]==cur_len+1

cur_len += 1

if all(dones):
break

# select the best hypotheses
tgt_len = token_ids.new_zeros(batch_size)
best = []

for i, hypotheses in enumerate(hypos):
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
        # append eos back, since the stored hypothesis above does not include it
if _eos_token_id!=-1:
best_hyp = torch.cat([best_hyp, best_hyp.new_ones(1)*_eos_token_id])
tgt_len[i] = len(best_hyp)
best.append(best_hyp)

# generate target batch
decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id)
for i, hypo in enumerate(best):
decoded[i, :tgt_len[i]] = hypo

return decoded


class BeamHypotheses(object):
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.hyp = []
self.worst_score = 1e9

def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.hyp)

def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.hyp.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
del self.hyp[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)

def is_done(self, best_sum_logprobs):
"""
If there are enough hypotheses and that none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
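
# Hedged worked example of the score used in `add` above (made-up numbers):
# a 5-token hypothesis whose log-probabilities sum to -6.0 scores
# -6.0 / 5**1.0 = -1.2 with length_penalty=1.0 and -6.0 / 5**2.0 = -0.24 with
# length_penalty=2.0, so a larger length_penalty favors longer hypotheses.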


def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
"""
    Set the entries of the logits that do not satisfy the top_k / top_p criteria to filter_value.

    :param torch.Tensor logits: bsz x vocab_size
    :param int top_k: if greater than 0, keep only the logits of the top_k tokens and set the rest to filter_value
    :param float top_p: nucleus filtering as described in (http://arxiv.org/abs/1904.09751)
    :param float filter_value:
    :param int min_tokens_to_keep: each sample keeps at least this many tokens with unfiltered probability
    :return:
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value

if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0

# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
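
if __name__ == "__main__":
    # Hedged self-check sketch, not part of the library API: shows the effect of
    # top_k_top_p_filtering on a toy batch of logits (all numbers are made up).
    example_logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
    # keep only the 2 largest logits; everything else becomes -inf
    print(top_k_top_p_filtering(example_logits.clone(), top_k=2, top_p=1.0))
    # nucleus filtering: drop tokens once the cumulative probability passes 0.8
    print(top_k_top_p_filtering(example_logits.clone(), top_k=0, top_p=0.8))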

+ 1
- 0
tests/helpers/data/modules/decoder/crf.json
File diff suppressed because it is too large
View File


+ 0
- 0
tests/models/__init__.py View File


+ 0
- 0
tests/models/torch/__init__.py View File


+ 142
- 0
tests/models/torch/model_runner.py View File

@@ -0,0 +1,142 @@
"""
This module makes it very easy to test models.
If your model does text classification, sequence labeling, or natural language inference (NLI), it can be tested with this module directly.
If your model does not fall into those categories, you can also prepare fake data yourself and set a loss and metric for testing.

The tests in this module only guarantee that a model can be trained and evaluated with fastNLP; they do not test actual model performance.

Example::

    # import the ALL-CAPS variables ...
    from model_runner import *

    # test a text classification model
    init_emb = (VOCAB_SIZE, 50)
    model = SomeModel(init_emb, num_cls=NUM_CLS)
    RUNNER.run_model_with_task(TEXT_CLS, model)

    # a sequence labeling model
    RUNNER.run_model_with_task(POS_TAGGING, model)

    # an NLI model
    RUNNER.run_model_with_task(NLI, model)

    # a custom model
    RUNNER.run_model(model, data=get_mydata(),
                     loss=Myloss(), metrics=Mymetric())
"""
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch import optim
from fastNLP import Trainer, Evaluator, DataSet, Callback
from fastNLP import Accuracy
from random import randrange
from fastNLP import TorchDataLoader

VOCAB_SIZE = 100
NUM_CLS = 100
MAX_LEN = 10
N_SAMPLES = 100
N_EPOCHS = 1
BATCH_SIZE = 5

TEXT_CLS = 'text_cls'
POS_TAGGING = 'pos_tagging'
NLI = 'nli'

class ModelRunner():
class Checker(Callback):
def on_backward_begin(self, trainer, outputs):
            assert outputs['loss'].to('cpu').isfinite().all()

def gen_seq(self, length, vocab_size):
"""generate fake sequence indexes with given length"""
# reserve 0 for padding
return [randrange(1, vocab_size) for _ in range(length)]

def gen_var_seq(self, max_len, vocab_size):
"""generate fake sequence indexes in variant length"""
length = randrange(3, max_len) # at least 3 words in a seq
return self.gen_seq(length, vocab_size)

def prepare_text_classification_data(self):
index = 'index'
ds = DataSet({index: list(range(N_SAMPLES))})
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
field_name=index, new_field_name='words')
ds.apply_field(lambda x: randrange(NUM_CLS),
field_name=index, new_field_name='target')
ds.apply_field(len, 'words', 'seq_len')
dl = TorchDataLoader(ds, batch_size=BATCH_SIZE)
return dl

def prepare_pos_tagging_data(self):
index = 'index'
ds = DataSet({index: list(range(N_SAMPLES))})
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
field_name=index, new_field_name='words')
ds.apply_field(lambda x: self.gen_seq(len(x), NUM_CLS),
field_name='words', new_field_name='target')
ds.apply_field(len, 'words', 'seq_len')
dl = TorchDataLoader(ds, batch_size=BATCH_SIZE)
return dl

def prepare_nli_data(self):
index = 'index'
ds = DataSet({index: list(range(N_SAMPLES))})
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
field_name=index, new_field_name='words1')
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
field_name=index, new_field_name='words2')
ds.apply_field(lambda x: randrange(NUM_CLS),
field_name=index, new_field_name='target')
ds.apply_field(len, 'words1', 'seq_len1')
ds.apply_field(len, 'words2', 'seq_len2')
dl = TorchDataLoader(ds, batch_size=BATCH_SIZE)
return dl

def run_text_classification(self, model, data=None):
if data is None:
data = self.prepare_text_classification_data()
metric = Accuracy()
self.run_model(model, data, metric)

def run_pos_tagging(self, model, data=None):
if data is None:
data = self.prepare_pos_tagging_data()
metric = Accuracy()
self.run_model(model, data, metric)

def run_nli(self, model, data=None):
if data is None:
data = self.prepare_nli_data()
metric = Accuracy()
self.run_model(model, data, metric)

def run_model(self, model, data, metrics):
"""run a model, test if it can run with fastNLP"""
print('testing model:', model.__class__.__name__)
tester = Evaluator(model, data, metrics={'metric': metrics}, driver='torch')
before_train = tester.run()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
trainer = Trainer(model, driver='torch', train_dataloader=data,
n_epochs=N_EPOCHS, optimizers=optimizer)
trainer.run()
after_train = tester.run()
for metric_name, v1 in before_train.items():
assert metric_name in after_train
            # # at least we can be sure the model params changed, even if we don't know the performance
# v2 = after_train[metric_name]
# assert v1 != v2

def run_model_with_task(self, task, model):
"""run a model with certain task"""
TASKS = {
TEXT_CLS: self.run_text_classification,
POS_TAGGING: self.run_pos_tagging,
NLI: self.run_nli,
}
assert task in TASKS
TASKS[task](model)

RUNNER = ModelRunner()

+ 91
- 0
tests/models/torch/test_biaffine_parser.py View File

@@ -0,0 +1,91 @@
import pytest
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from fastNLP.models.torch.biaffine_parser import BiaffineParser
from fastNLP import Metric, seq_len_to_mask
from .model_runner import *


class ParserMetric(Metric):
r"""
    Evaluate the performance of a parser.

"""

def __init__(self):
super().__init__()
self.num_arc = 0
self.num_label = 0
self.num_sample = 0

def get_metric(self, reset=True):
res = {'UAS': self.num_arc * 1.0 / self.num_sample, 'LAS': self.num_label * 1.0 / self.num_sample}
if reset:
self.num_sample = self.num_label = self.num_arc = 0
return res

def update(self, pred1, pred2, target1, target2, seq_len=None):
r"""

        :param pred1: arc prediction logits
        :param pred2: label prediction logits
        :param target1: gold arcs
        :param target2: gold labels
        :param seq_len: sequence lengths
        :return dict: evaluation results::

            UAS: accuracy of arc prediction, ignoring labels
            LAS: accuracy of predicting both the arc and its label
"""
if seq_len is None:
seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long)
else:
seq_mask = seq_len_to_mask(seq_len.long()).long()
# mask out <root> tag
seq_mask[:, 0] = 0
head_pred_correct = (pred1 == target1).long() * seq_mask
label_pred_correct = (pred2 == target2).long() * head_pred_correct
self.num_arc += head_pred_correct.sum().item()
self.num_label += label_pred_correct.sum().item()
self.num_sample += seq_mask.sum().item()
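
# Hedged illustration of the bookkeeping above (toy tensors, not real data):
#
#     metric = ParserMetric()
#     pred_heads  = torch.LongTensor([[0, 2, 0, 2]])
#     pred_labels = torch.LongTensor([[0, 1, 0, 1]])
#     gold_heads  = torch.LongTensor([[0, 2, 0, 3]])
#     gold_labels = torch.LongTensor([[0, 1, 0, 2]])
#     metric.update(pred_heads, pred_labels, gold_heads, gold_labels, torch.LongTensor([4]))
#     # position 0 (<root>) is masked out; 2 of 3 arcs and 2 of 3 labels match,
#     # so metric.get_metric() -> {'UAS': 2/3, 'LAS': 2/3}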


def prepare_parser_data():
index = 'index'
ds = DataSet({index: list(range(N_SAMPLES))})
ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE),
field_name=index, new_field_name='words1')
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS),
field_name='words1', new_field_name='words2')
# target1 is heads, should in range(0, len(words))
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), len(x)),
field_name='words1', new_field_name='target1')
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS),
field_name='words1', new_field_name='target2')
ds.apply_field(len, field_name='words1', new_field_name='seq_len')
dl = TorchDataLoader(ds, batch_size=BATCH_SIZE)
return dl


@pytest.mark.torch
class TestBiaffineParser:
def test_train(self):
model = BiaffineParser(embed=(VOCAB_SIZE, 10),
pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10,
rnn_hidden_size=10,
arc_mlp_size=10,
label_mlp_size=10,
num_label=NUM_CLS, encoder='var-lstm')
ds = prepare_parser_data()
RUNNER.run_model(model, ds, metrics=ParserMetric())

def test_train2(self):
model = BiaffineParser(embed=(VOCAB_SIZE, 10),
pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10,
rnn_hidden_size=16,
arc_mlp_size=10,
label_mlp_size=10,
num_label=NUM_CLS, encoder='transformer')
ds = prepare_parser_data()
RUNNER.run_model(model, ds, metrics=ParserMetric())

+ 33
- 0
tests/models/torch/test_cnn_text_classification.py View File

@@ -0,0 +1,33 @@
import pytest

from .model_runner import *
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from fastNLP.models.torch.cnn_text_classification import CNNText


@pytest.mark.torch
class TestCNNText:
def init_model(self, kernel_sizes, kernel_nums=(1,3,5)):
model = CNNText((VOCAB_SIZE, 30),
NUM_CLS,
kernel_nums=kernel_nums,
kernel_sizes=kernel_sizes)
return model

def test_case1(self):
        # test that the CNN runs correctly
model = self.init_model((1,3,5))
RUNNER.run_model_with_task(TEXT_CLS, model)

def test_init_model(self):
with pytest.raises(Exception):
self.init_model(2, 4)
with pytest.raises(Exception):
self.init_model((2,))

def test_output(self):
model = self.init_model((3,), (1,))
global MAX_LEN
MAX_LEN = 2
RUNNER.run_model_with_task(TEXT_CLS, model)

+ 73
- 0
tests/models/torch/test_seq2seq_generator.py View File

@@ -0,0 +1,73 @@
import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
from fastNLP.models.torch import LSTMSeq2SeqModel, TransformerSeq2SeqModel
import torch
from fastNLP.embeddings.torch import StaticEmbedding

from fastNLP import Vocabulary, DataSet
from fastNLP import Trainer, Accuracy
from fastNLP import Callback, TorchDataLoader


def prepare_env():
vocab = Vocabulary().add_word_lst("This is a test .".split())
vocab.add_word_lst("Another test !".split())
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

src_words_idx = [[3, 1, 2], [1, 2]]
# tgt_words_idx = [[1, 2, 3, 4], [2, 3]]
src_seq_len = [3, 2]
# tgt_seq_len = [4, 2]

ds = DataSet({'src_tokens': src_words_idx, 'src_seq_len': src_seq_len, 'tgt_tokens': src_words_idx,
'tgt_seq_len':src_seq_len})

dl = TorchDataLoader(ds, batch_size=32)
return embed, dl


class ExitCallback(Callback):
def __init__(self):
super().__init__()

def on_valid_end(self, trainer, results):
if results['acc#acc'] == 1:
raise KeyboardInterrupt()


@pytest.mark.torch
class TestSeq2SeqGeneratorModel:
def test_run(self):
        # check that training with SequenceGeneratorModel works and that extra inputs are passed through for prediction
embed, dl = prepare_env()
model1 = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6,
dim_ff=20, dropout=0.1,
bind_encoder_decoder_embed=True,
bind_decoder_input_output_embed=True)
optimizer = torch.optim.Adam(model1.parameters(), lr=1e-3)
trainer = Trainer(model1, driver='torch', optimizers=optimizer, train_dataloader=dl,
n_epochs=100, evaluate_dataloaders=dl, metrics={'acc': Accuracy()},
evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'],
'seq_len': x['tgt_seq_len'],
**x},
callbacks=ExitCallback())

trainer.run()

embed, dl = prepare_env()
model2 = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
num_layers=1, hidden_size=20, dropout=0.1,
bind_encoder_decoder_embed=True,
bind_decoder_input_output_embed=True, attention=True)
optimizer = torch.optim.Adam(model2.parameters(), lr=0.01)
trainer = Trainer(model2, driver='torch', optimizers=optimizer, train_dataloader=dl,
n_epochs=100, evaluate_dataloaders=dl, metrics={'acc': Accuracy()},
evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'],
'seq_len': x['tgt_seq_len'],
**x},
callbacks=ExitCallback())
trainer.run()

+ 113
- 0
tests/models/torch/test_seq2seq_model.py View File

@@ -0,0 +1,113 @@

import pytest
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from fastNLP.models.torch.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel
from fastNLP import Vocabulary
from fastNLP.embeddings.torch import StaticEmbedding
import torch
from torch import optim
import torch.nn.functional as F
from fastNLP import seq_len_to_mask


def prepare_env():
vocab = Vocabulary().add_word_lst("This is a test .".split())
vocab.add_word_lst("Another test !".split())
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
src_seq_len = torch.LongTensor([3, 2])
tgt_seq_len = torch.LongTensor([4, 2])

return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len


def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len):
optimizer = optim.Adam(model.parameters(), lr=1e-2)
mask = seq_len_to_mask(tgt_seq_len).eq(0)
target = tgt_words_idx.masked_fill(mask, -100)

for i in range(100):
optimizer.zero_grad()
pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred'] # bsz x max_len x vocab_size
loss = F.cross_entropy(pred.transpose(1, 2), target)
loss.backward()
optimizer.step()

right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum()
return right_count


@pytest.mark.torch
class TestTransformerSeq2SeqModel:
def test_run(self):
        # test that the model runs
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()
for pos_embed in ['learned', 'sin']:
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
pos_embed=pos_embed, max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1,
bind_encoder_decoder_embed=True,
bind_decoder_input_output_embed=True)

output = model(src_words_idx, tgt_words_idx, src_seq_len)
assert (output['pred'].size() == (2, 4, len(embed)))

for bind_encoder_decoder_embed in [True, False]:
tgt_embed = None
for bind_decoder_input_output_embed in [True, False]:
if bind_encoder_decoder_embed == False:
tgt_embed = embed

model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed,
pos_embed='sin', max_position=20, num_layers=2,
d_model=30, n_head=6, dim_ff=20, dropout=0.1,
bind_encoder_decoder_embed=bind_encoder_decoder_embed,
bind_decoder_input_output_embed=bind_decoder_input_output_embed)

output = model(src_words_idx, tgt_words_idx, src_seq_len)
assert (output['pred'].size() == (2, 4, len(embed)))

def test_train(self):
        # test that the model can be trained to overfit
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1,
bind_encoder_decoder_embed=True,
bind_decoder_input_output_embed=True)

right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len)
assert(right_count == tgt_words_idx.nelement())


@pytest.mark.torch
class TestLSTMSeq2SeqModel:
def test_run(self):
        # test that the model runs
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

for bind_encoder_decoder_embed in [True, False]:
tgt_embed = None
for bind_decoder_input_output_embed in [True, False]:
if bind_encoder_decoder_embed == False:
tgt_embed = embed
model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed,
num_layers=2, hidden_size=20, dropout=0.1,
bind_encoder_decoder_embed=bind_encoder_decoder_embed,
bind_decoder_input_output_embed=bind_decoder_input_output_embed)
output = model(src_words_idx, tgt_words_idx, src_seq_len)
assert (output['pred'].size() == (2, 4, len(embed)))

def test_train(self):
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
num_layers=1, hidden_size=20, dropout=0.1,
bind_encoder_decoder_embed=True,
bind_decoder_input_output_embed=True)

right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len)
assert (right_count == tgt_words_idx.nelement())


+ 47
- 0
tests/models/torch/test_sequence_labeling.py View File

@@ -0,0 +1,47 @@
import pytest
from .model_runner import *
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from fastNLP.models.torch.sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF


@pytest.mark.torch
class TestBiLSTM:
def test_case1(self):
        # test that the model runs correctly
init_emb = (VOCAB_SIZE, 30)
model = BiLSTMCRF(init_emb,
hidden_size=30,
num_classes=NUM_CLS)

dl = RUNNER.prepare_pos_tagging_data()
metric = Accuracy()
RUNNER.run_model(model, dl, metric)


@pytest.mark.torch
class TestSeqLabel:
def test_case1(self):
        # test that the model runs correctly
init_emb = (VOCAB_SIZE, 30)
model = SeqLabeling(init_emb,
hidden_size=30,
num_classes=NUM_CLS)

dl = RUNNER.prepare_pos_tagging_data()
metric = Accuracy()
RUNNER.run_model(model, dl, metric)


@pytest.mark.torch
class TestAdvSeqLabel:
def test_case1(self):
        # test that the model runs correctly
init_emb = (VOCAB_SIZE, 30)
model = AdvSeqLabel(init_emb,
hidden_size=30,
num_classes=NUM_CLS)

dl = RUNNER.prepare_pos_tagging_data()
metric = Accuracy()
RUNNER.run_model(model, dl, metric)

+ 0
- 0
tests/modules/torch/__init__.py View File


+ 0
- 0
tests/modules/torch/decoder/__init__.py View File


+ 327
- 0
tests/modules/torch/decoder/test_CRF.py View File

@@ -0,0 +1,327 @@
import pytest
import os
from fastNLP import Vocabulary


@pytest.mark.torch
class TestCRF:
def test_case1(self):
from fastNLP.modules.torch.decoder.crf import allowed_transitions
        # check that allowed_transitions() works as expected

id2label = {0: 'B', 1: 'I', 2:'O'}
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
(2, 4), (3, 0), (3, 2)}
assert expected_res == set(allowed_transitions(id2label, include_start_end=True))

id2label = {0: 'B', 1:'M', 2:'E', 3:'S'}
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
assert (expected_res == set(
allowed_transitions(id2label, encoding_type='BMES', include_start_end=True)))

id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"}
allowed_transitions(id2label, include_start_end=True)

labels = ['O']
for label in ['X', 'Y']:
for tag in 'BI':
labels.append('{}-{}'.format(tag, label))
id2label = {idx:label for idx, label in enumerate(labels)}
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
assert (expected_res == set(allowed_transitions(id2label, include_start_end=True)))

labels = []
for label in ['X', 'Y']:
for tag in 'BMES':
labels.append('{}-{}'.format(tag, label))
id2label = {idx: label for idx, label in enumerate(labels)}
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
assert (expected_res == set(
allowed_transitions(id2label, include_start_end=True)))

def test_case11(self):
        # test automatic inference of the encoding type
from fastNLP.modules.torch.decoder.crf import allowed_transitions

id2label = {0: 'B', 1: 'I', 2: 'O'}
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
(2, 4), (3, 0), (3, 2)}
assert (expected_res == set(allowed_transitions(id2label, include_start_end=True)))

id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
assert (expected_res == set(
allowed_transitions(id2label, include_start_end=True)))

id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"}
allowed_transitions(id2label, include_start_end=True)

labels = ['O']
for label in ['X', 'Y']:
for tag in 'BI':
labels.append('{}-{}'.format(tag, label))
id2label = {idx: label for idx, label in enumerate(labels)}
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
assert (expected_res == set(allowed_transitions(id2label, include_start_end=True)))

labels = []
for label in ['X', 'Y']:
for tag in 'BMES':
labels.append('{}-{}'.format(tag, label))
id2label = {idx: label for idx, label in enumerate(labels)}
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
assert (expected_res == set(
allowed_transitions(id2label, include_start_end=True)))

def test_case12(self):
        # test that allowed transitions can be generated from a vocabulary
from fastNLP.modules.torch.decoder.crf import allowed_transitions

id2label = {0: 'B', 1: 'I', 2: 'O'}
vocab = Vocabulary(unknown=None, padding=None)
for idx, tag in id2label.items():
vocab.add_word(tag)
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
(2, 4), (3, 0), (3, 2)}
assert (expected_res == set(allowed_transitions(vocab, include_start_end=True)))

id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
vocab = Vocabulary(unknown=None, padding=None)
for idx, tag in id2label.items():
vocab.add_word(tag)
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
assert (expected_res == set(
allowed_transitions(vocab, include_start_end=True)))

id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"}
vocab = Vocabulary()
for idx, tag in id2label.items():
vocab.add_word(tag)
allowed_transitions(vocab, include_start_end=True)

labels = ['O']
for label in ['X', 'Y']:
for tag in 'BI':
labels.append('{}-{}'.format(tag, label))
id2label = {idx: label for idx, label in enumerate(labels)}
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
vocab = Vocabulary(unknown=None, padding=None)
for idx, tag in id2label.items():
vocab.add_word(tag)
assert (expected_res == set(allowed_transitions(vocab, include_start_end=True)))

labels = []
for label in ['X', 'Y']:
for tag in 'BMES':
labels.append('{}-{}'.format(tag, label))
id2label = {idx: label for idx, label in enumerate(labels)}
vocab = Vocabulary(unknown=None, padding=None)
for idx, tag in id2label.items():
vocab.add_word(tag)
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
assert (expected_res == set(
allowed_transitions(vocab, include_start_end=True)))

# def test_case2(self):
    #     # test that the CRF avoids decoding illegal transitions; verified against allennlp.
# pass
# import torch
# from fastNLP import seq_len_to_mask
#
# labels = ['O']
# for label in ['X', 'Y']:
# for tag in 'BI':
# labels.append('{}-{}'.format(tag, label))
# id2label = {idx: label for idx, label in enumerate(labels)}
# num_tags = len(id2label)
# max_len = 10
# batch_size = 4
# bio_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log()
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label),
# include_start_end_transitions=False)
# bio_trans_m = allen_CRF.transitions
# bio_seq_lens = torch.randint(1, max_len, size=(batch_size,))
# bio_seq_lens[0] = 1
# bio_seq_lens[-1] = max_len
# mask = seq_len_to_mask(bio_seq_lens)
# allen_res = allen_CRF.viterbi_tags(bio_logits, mask)
#
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
# include_start_end=True))
# fast_CRF.trans_m = bio_trans_m
# fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True)
# bio_scores = [round(score, 4) for _, score in allen_res]
# # score equal
# self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()])
# # seq equal
# bio_path = [_ for _, score in allen_res]
# self.assertListEqual(bio_path, fast_res[0])
#
# labels = []
# for label in ['X', 'Y']:
# for tag in 'BMES':
# labels.append('{}-{}'.format(tag, label))
# id2label = {idx: label for idx, label in enumerate(labels)}
# num_tags = len(id2label)
#
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label),
# include_start_end_transitions=False)
# bmes_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log()
# bmes_trans_m = allen_CRF.transitions
# bmes_seq_lens = torch.randint(1, max_len, size=(batch_size,))
# bmes_seq_lens[0] = 1
# bmes_seq_lens[-1] = max_len
# mask = seq_len_to_mask(bmes_seq_lens)
# allen_res = allen_CRF.viterbi_tags(bmes_logits, mask)
#
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
# encoding_type='BMES',
# include_start_end=True))
# fast_CRF.trans_m = bmes_trans_m
# fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True)
# # score equal
# bmes_scores = [round(score, 4) for _, score in allen_res]
# self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()])
# # seq equal
# bmes_path = [_ for _, score in allen_res]
# self.assertListEqual(bmes_path, fast_res[0])
#
# data = {
# 'bio_logits': bio_logits.tolist(),
# 'bio_scores': bio_scores,
# 'bio_path': bio_path,
# 'bio_trans_m': bio_trans_m.tolist(),
# 'bio_seq_lens': bio_seq_lens.tolist(),
# 'bmes_logits': bmes_logits.tolist(),
# 'bmes_scores': bmes_scores,
# 'bmes_path': bmes_path,
# 'bmes_trans_m': bmes_trans_m.tolist(),
# 'bmes_seq_lens': bmes_seq_lens.tolist(),
# }
#
# with open('weights.json', 'w') as f:
# import json
# json.dump(data, f)

def test_case2(self):
        # test that the CRF works correctly.
import json
import torch
from fastNLP import seq_len_to_mask
folder = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(folder, '../../../', 'helpers/data/modules/decoder/crf.json')

with open(os.path.abspath(path), 'r') as f:
data = json.load(f)

bio_logits = torch.FloatTensor(data['bio_logits'])
bio_scores = data['bio_scores']
bio_path = data['bio_path']
bio_trans_m = torch.FloatTensor(data['bio_trans_m'])
bio_seq_lens = torch.LongTensor(data['bio_seq_lens'])

bmes_logits = torch.FloatTensor(data['bmes_logits'])
bmes_scores = data['bmes_scores']
bmes_path = data['bmes_path']
bmes_trans_m = torch.FloatTensor(data['bmes_trans_m'])
bmes_seq_lens = torch.LongTensor(data['bmes_seq_lens'])

labels = ['O']
for label in ['X', 'Y']:
for tag in 'BI':
labels.append('{}-{}'.format(tag, label))
id2label = {idx: label for idx, label in enumerate(labels)}
num_tags = len(id2label)

mask = seq_len_to_mask(bio_seq_lens)

from fastNLP.modules.torch.decoder.crf import ConditionalRandomField, allowed_transitions
fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
include_start_end=True))
fast_CRF.trans_m.data = bio_trans_m
fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True)
# score equal
assert (bio_scores == [round(s, 4) for s in fast_res[1].tolist()])
# seq equal
assert (bio_path == fast_res[0])

labels = []
for label in ['X', 'Y']:
for tag in 'BMES':
labels.append('{}-{}'.format(tag, label))
id2label = {idx: label for idx, label in enumerate(labels)}
num_tags = len(id2label)

mask = seq_len_to_mask(bmes_seq_lens)

from fastNLP.modules.torch.decoder.crf import ConditionalRandomField, allowed_transitions
fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
encoding_type='BMES',
include_start_end=True))
fast_CRF.trans_m.data = bmes_trans_m
fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True)
# score equal
assert (bmes_scores == [round(s, 4) for s in fast_res[1].tolist()])
# seq equal
assert (bmes_path == fast_res[0])

def test_case3(self):
        # test that the CRF loss never becomes negative
import torch
from fastNLP.modules.torch.decoder.crf import ConditionalRandomField
from fastNLP.core.utils import seq_len_to_mask
from torch import optim
from torch import nn

num_tags, include_start_end_trans = 4, True
num_samples = 4
lengths = torch.randint(3, 50, size=(num_samples, )).long()
max_len = lengths.max()
tags = torch.randint(num_tags, size=(num_samples, max_len))
masks = seq_len_to_mask(lengths)
feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
crf = ConditionalRandomField(num_tags, include_start_end_trans)
optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1)
for _ in range(10):
loss = crf(feats, tags, masks).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
if _%1000==0:
print(loss)
assert (loss.item()> 0)

def test_masking(self):
        # test that CRF pad masking works correctly
import torch
from fastNLP.modules.torch.decoder.crf import ConditionalRandomField
max_len = 5
n_tags = 5
pad_len = 5

torch.manual_seed(4)
logit = torch.rand(1, max_len+pad_len, n_tags)
# logit[0, -1, :] = 0.0
mask = torch.ones(1, max_len+pad_len)
mask[0,-pad_len] = 0
model = ConditionalRandomField(n_tags)
pred, score = model.viterbi_decode(logit[:,:-pad_len], mask[:,:-pad_len])
mask_pred, mask_score = model.viterbi_decode(logit, mask)
assert (pred[0].tolist() == mask_pred[0,:-pad_len].tolist())


+ 49
- 0
tests/modules/torch/decoder/test_seq2seq_decoder.py View File

@@ -0,0 +1,49 @@
import pytest
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
import torch

from fastNLP import Vocabulary
from fastNLP.embeddings.torch import StaticEmbedding
from fastNLP.modules.torch import TransformerSeq2SeqDecoder
from fastNLP.modules.torch import LSTMSeq2SeqDecoder
from fastNLP import seq_len_to_mask

@pytest.mark.torch
class TestTransformerSeq2SeqDecoder:
@pytest.mark.parametrize("flag", [True, False])
def test_case(self, flag):
vocab = Vocabulary().add_word_lst("This is a test .".split())
vocab.add_word_lst("Another test !".split())
embed = StaticEmbedding(vocab, embedding_dim=10)

encoder_output = torch.randn(2, 3, 10)
src_seq_len = torch.LongTensor([3, 2])
encoder_mask = seq_len_to_mask(src_seq_len)
        decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=None,
                    d_model=10, num_layers=2, n_head=5, dim_ff=20, dropout=0.1,
                    bind_decoder_input_output_embed=flag)
state = decoder.init_state(encoder_output, encoder_mask)
output = decoder(tokens=torch.randint(0, len(vocab), size=(2, 4)), state=state)
assert (output.size() == (2, 4, len(vocab)))


@pytest.mark.torch
class TestLSTMDecoder:
@pytest.mark.parametrize("flag", [True, False])
@pytest.mark.parametrize("attention", [True, False])
def test_case(self, flag, attention):
vocab = Vocabulary().add_word_lst("This is a test .".split())
vocab.add_word_lst("Another test !".split())
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10)

encoder_output = torch.randn(2, 3, 10)
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
src_seq_len = torch.LongTensor([3, 2])
encoder_mask = seq_len_to_mask(src_seq_len)
decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers = 2, hidden_size = 10,
dropout = 0.3, bind_decoder_input_output_embed=flag, attention=attention)
state = decoder.init_state(encoder_output, encoder_mask)
output = decoder(tgt_words_idx, state)
assert tuple(output.size()) == (2, 4, len(vocab))

+ 0
- 0
tests/modules/torch/encoder/__init__.py View File


+ 33
- 0
tests/modules/torch/encoder/test_seq2seq_encoder.py View File

@@ -0,0 +1,33 @@
import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
import torch

from fastNLP.modules.torch.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
from fastNLP import Vocabulary
from fastNLP.embeddings.torch import StaticEmbedding


@pytest.mark.torch
class TestTransformerSeq2SeqEncoder:
def test_case(self):
vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = StaticEmbedding(vocab, embedding_dim=5)
encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2)
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
seq_len = torch.LongTensor([3])
encoder_output, encoder_mask = encoder(words_idx, seq_len)
assert (encoder_output.size() == (1, 3, 10))


@pytest.mark.torch
class TestBiLSTMEncoder:
def test_case(self):
vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = StaticEmbedding(vocab, embedding_dim=5)
encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1)
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
seq_len = torch.LongTensor([3])

encoder_output, encoder_mask = encoder(words_idx, seq_len)
assert (encoder_mask.size() == (1, 3))

+ 18
- 0
tests/modules/torch/encoder/test_star_transformer.py View File

@@ -0,0 +1,18 @@
import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
import torch
from fastNLP.modules.torch.encoder.star_transformer import StarTransformer


@pytest.mark.torch
class TestStarTransformer:
def test_1(self):
model = StarTransformer(num_layers=6, hidden_size=100, num_head=8, head_dim=20, max_len=100)
x = torch.rand(16, 45, 100)
mask = torch.ones(16, 45).byte()
y, yn = model(x, mask)
assert (tuple(y.size()) == (16, 45, 100))
assert (tuple(yn.size()) == (16, 100))

+ 27
- 0
tests/modules/torch/encoder/test_variational_rnn.py View File

@@ -0,0 +1,27 @@
import pytest

import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
import torch
from fastNLP.modules.torch.encoder.variational_rnn import VarLSTM


@pytest.mark.torch
class TestMaskedRnn:
def test_case_1(self):
masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
x = torch.tensor([[[1.0], [2.0]]])
print(x.size())
y = masked_rnn(x)

def test_case_2(self):
input_size = 12
batch = 16
hidden = 10
masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True)

xx = torch.randn((batch, 32, input_size))
y, _ = masked_rnn(xx)
assert(tuple(y.shape) == (batch, 32, hidden))

+ 1
- 0
tests/modules/torch/generator/__init__.py View File

@@ -0,0 +1 @@


+ 146
- 0
tests/modules/torch/generator/test_seq2seq_generator.py View File

@@ -0,0 +1,146 @@
import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from fastNLP.modules.torch.generator import SequenceGenerator
from fastNLP.modules.torch import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State
from fastNLP import Vocabulary
from fastNLP.embeddings.torch import StaticEmbedding
from torch import nn
from fastNLP import seq_len_to_mask
else:
from fastNLP.core.utils.dummy_class import DummyClass as State
from fastNLP.core.utils.dummy_class import DummyClass as Seq2SeqDecoder


def prepare_env():
vocab = Vocabulary().add_word_lst("This is a test .".split())
vocab.add_word_lst("Another test !".split())
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

encoder_output = torch.randn(2, 3, 10)
src_seq_len = torch.LongTensor([3, 2])
encoder_mask = seq_len_to_mask(src_seq_len)

return embed, encoder_output, encoder_mask


class GreedyDummyDecoder(Seq2SeqDecoder):
def __init__(self, decoder_output):
super().__init__()
self.cur_length = 0
self.decoder_output = decoder_output

def decode(self, tokens, state):
self.cur_length += 1
scores = self.decoder_output[:, self.cur_length]
return scores


class DummyState(State):
def __init__(self, decoder):
super().__init__()
self.decoder = decoder

def reorder_state(self, indices: torch.LongTensor):
self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0)


@pytest.mark.torch
class TestSequenceGenerator:
def test_run(self):
        # test that it runs: (1) initialize the decoder, (2) run one round of decoding
embed, encoder_output, encoder_mask = prepare_env()

for do_sample in [True, False]:
for num_beams in [1, 3, 5]:
decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10,
dropout=0.3, bind_decoder_input_output_embed=True, attention=True)
state = decoder.init_state(encoder_output, encoder_mask)
generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams,
do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None,
repetition_penalty=1, length_penalty=1.0, pad_token_id=0)
generator.generate(state=state, tokens=None)

decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim),
d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, dropout=0.1,
bind_decoder_input_output_embed=True)
state = decoder.init_state(encoder_output, encoder_mask)
generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams,
do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None,
repetition_penalty=1, length_penalty=1.0, pad_token_id=0)
generator.generate(state=state, tokens=None)

                # also test with other parameter values
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim),
d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10,
dropout=0.1,
bind_decoder_input_output_embed=True)
state = decoder.init_state(encoder_output, encoder_mask)
generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams,
do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5, bos_token_id=1,
eos_token_id=3, repetition_penalty=2, length_penalty=1.5, pad_token_id=0)
generator.generate(state=state, tokens=None)

def test_greedy_decode(self):
        # test that generation produces the expected output
        # greedy
for beam_search in [1, 3]:
decoder_output = torch.randn(2, 10, 5)
path = decoder_output.argmax(dim=-1) # 2 x 10
decoder = GreedyDummyDecoder(decoder_output)
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
do_sample=False, temperature=1, top_k=50, top_p=1, bos_token_id=1,
eos_token_id=None, repetition_penalty=1, length_penalty=1, pad_token_id=0)
decode_path = generator.generate(DummyState(decoder), tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))

assert (decode_path.eq(path).sum() == path.numel())

# greedy check eos_token_id
for beam_search in [1, 3]:
decoder_output = torch.randn(2, 10, 5)
decoder_output[:, :7, 4].fill_(-100)
            decoder_output[0, 7, 4] = 1000  # ends at the 8th position
decoder_output[1, 5, 4] = 1000
            path = decoder_output.argmax(dim=-1)  # 2 x 10
decoder = GreedyDummyDecoder(decoder_output)
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
do_sample=False, temperature=1, top_k=50, top_p=0.5, bos_token_id=1,
eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0)
decode_path = generator.generate(DummyState(decoder),
tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
            assert (decode_path.size(1) == 8)  # length should be 8
assert (decode_path[0].eq(path[0, :8]).sum() == 8)
assert (decode_path[1, :6].eq(path[1, :6]).sum() == 6)

def test_sample_decoder(self):
# greedy check eos_token_id
for beam_search in [1, 3]:
decode_paths = []
            # because sampling is random, run the test num_tests times; if at least one run is correct, it should be fine
num_tests = 10
for i in range(num_tests):
decoder_output = torch.randn(2, 10, 5) * 10
decoder_output[:, :7, 4].fill_(-100)
                decoder_output[0, 7, 4] = 10000  # ends at the 8th position
decoder_output[1, 5, 4] = 10000
                path = decoder_output.argmax(dim=-1)  # 2 x 10
decoder = GreedyDummyDecoder(decoder_output)
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
do_sample=True, temperature=1, top_k=50, top_p=0.5, bos_token_id=1,
eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0)
decode_path = generator.generate(DummyState(decoder),
tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
decode_paths.append([decode_path, path])
sizes = []
eqs = []
eq2s = []
for i in range(num_tests):
decode_path, path = decode_paths[i]
sizes.append(decode_path.size(1)==8)
eqs.append(decode_path[0].eq(path[0, :8]).sum()==8)
eq2s.append(decode_path[1, :6].eq(path[1, :6]).sum()==6)
assert any(sizes)
assert any(eqs)
assert any(eq2s)
