@@ -138,6 +138,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||
msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \ | |||
"information please set logger's level to DEBUG." | |||
if must_pad: | |||
logger.error(msg) | |||
raise type(e)(msg=msg) | |||
logger.debug(msg) | |||
return NullPadder() | |||
@@ -16,6 +16,7 @@ if _NEED_IMPORT_TORCH: | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from torch.nn import LSTM | |||
from .embedding import TokenEmbedding | |||
from .static_embedding import StaticEmbedding | |||
@@ -23,7 +24,6 @@ from .utils import _construct_char_vocab_from_vocab | |||
from .utils import get_embeddings | |||
from ...core import logger | |||
from ...core.vocabulary import Vocabulary | |||
from ...modules.torch.encoder.lstm import LSTM | |||
class CNNCharEmbedding(TokenEmbedding): | |||
@@ -0,0 +1,21 @@ | |||
__all__ = [ | |||
'BiaffineParser', | |||
"CNNText", | |||
"SequenceGeneratorModel", | |||
"Seq2SeqModel", | |||
'TransformerSeq2SeqModel', | |||
'LSTMSeq2SeqModel', | |||
"SeqLabeling", | |||
"AdvSeqLabel", | |||
"BiLSTMCRF", | |||
] | |||
from .biaffine_parser import BiaffineParser | |||
from .cnn_text_classification import CNNText | |||
from .seq2seq_generator import SequenceGeneratorModel | |||
from .seq2seq_model import * | |||
from .sequence_labeling import * |
@@ -0,0 +1,475 @@ | |||
r""" | |||
A PyTorch implementation of the Biaffine Dependency Parser.
""" | |||
__all__ = [ | |||
"BiaffineParser", | |||
"GraphParser" | |||
] | |||
from collections import defaultdict | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from ...core.utils import seq_len_to_mask | |||
from ...embeddings.torch.utils import get_embeddings | |||
from ...modules.torch.dropout import TimestepDropout | |||
from ...modules.torch.encoder.transformer import TransformerEncoder | |||
from ...modules.torch.encoder.variational_rnn import VarLSTM | |||
def _mst(scores): | |||
r""" | |||
with some modification to support parser output for MST decoding | |||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | |||
""" | |||
length = scores.shape[0] | |||
min_score = scores.min() - 1 | |||
eye = np.eye(length) | |||
scores = scores * (1 - eye) + min_score * eye | |||
heads = np.argmax(scores, axis=1) | |||
heads[0] = 0 | |||
tokens = np.arange(1, length) | |||
roots = np.where(heads[tokens] == 0)[0] + 1 | |||
if len(roots) < 1: | |||
root_scores = scores[tokens, 0] | |||
head_scores = scores[tokens, heads[tokens]] | |||
new_root = tokens[np.argmax(root_scores / head_scores)] | |||
heads[new_root] = 0 | |||
elif len(roots) > 1: | |||
root_scores = scores[roots, 0] | |||
scores[roots, 0] = 0 | |||
new_heads = np.argmax(scores[roots][:, tokens], axis=1) + 1 | |||
new_root = roots[np.argmin( | |||
scores[roots, new_heads] / root_scores)] | |||
heads[roots] = new_heads | |||
heads[new_root] = 0 | |||
edges = defaultdict(set) | |||
vertices = set((0,)) | |||
for dep, head in enumerate(heads[tokens]): | |||
vertices.add(dep + 1) | |||
edges[head].add(dep + 1) | |||
for cycle in _find_cycle(vertices, edges): | |||
dependents = set() | |||
to_visit = set(cycle) | |||
while len(to_visit) > 0: | |||
node = to_visit.pop() | |||
if node not in dependents: | |||
dependents.add(node) | |||
to_visit.update(edges[node]) | |||
cycle = np.array(list(cycle)) | |||
old_heads = heads[cycle] | |||
old_scores = scores[cycle, old_heads] | |||
non_heads = np.array(list(dependents)) | |||
scores[np.repeat(cycle, len(non_heads)), | |||
np.repeat([non_heads], len(cycle), axis=0).flatten()] = min_score | |||
new_heads = np.argmax(scores[cycle][:, tokens], axis=1) + 1 | |||
new_scores = scores[cycle, new_heads] / old_scores | |||
change = np.argmax(new_scores) | |||
changed_cycle = cycle[change] | |||
old_head = old_heads[change] | |||
new_head = new_heads[change] | |||
heads[changed_cycle] = new_head | |||
edges[new_head].add(changed_cycle) | |||
edges[old_head].remove(changed_cycle) | |||
return heads | |||
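# A minimal sketch of how ``_mst`` is typically invoked (toy scores, shapes only; row/column 0
# plays the role of the ROOT position; this is illustrative and not taken from the library's tests):
def _example_mst():
    scores = np.random.rand(5, 5)  # [seq_len, seq_len] arc scores, including the ROOT position
    heads = _mst(scores)           # heads[i] is the predicted parent index of token i
    return heads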
def _find_cycle(vertices, edges): | |||
r""" | |||
https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm | |||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py | |||
""" | |||
_index = 0 | |||
_stack = [] | |||
_indices = {} | |||
_lowlinks = {} | |||
_onstack = defaultdict(lambda: False) | |||
_SCCs = [] | |||
def _strongconnect(v): | |||
nonlocal _index | |||
_indices[v] = _index | |||
_lowlinks[v] = _index | |||
_index += 1 | |||
_stack.append(v) | |||
_onstack[v] = True | |||
for w in edges[v]: | |||
if w not in _indices: | |||
_strongconnect(w) | |||
_lowlinks[v] = min(_lowlinks[v], _lowlinks[w]) | |||
elif _onstack[w]: | |||
_lowlinks[v] = min(_lowlinks[v], _indices[w]) | |||
if _lowlinks[v] == _indices[v]: | |||
SCC = set() | |||
while True: | |||
w = _stack.pop() | |||
_onstack[w] = False | |||
SCC.add(w) | |||
if w == v:
break | |||
_SCCs.append(SCC) | |||
for v in vertices: | |||
if v not in _indices: | |||
_strongconnect(v) | |||
return [SCC for SCC in _SCCs if len(SCC) > 1] | |||
class GraphParser(nn.Module): | |||
r""" | |||
A graph-based parser base class that supports greedy decoding and maximum spanning tree (MST) decoding.
""" | |||
def __init__(self): | |||
super(GraphParser, self).__init__() | |||
@staticmethod | |||
def greedy_decoder(arc_matrix, mask=None): | |||
r""" | |||
Greedy decoding: given the arc score matrix, return the greedily decoded parse; the result is not guaranteed to form a valid tree.
:param arc_matrix: [batch, seq_len, seq_len] input arc score matrix
:param mask: [batch, seq_len] padding mask of the input, 1 for real content and 0 otherwise.
If ``None``, an all-ones mask is assumed. Default: ``None``
:return heads: [batch, seq_len] predicted head (parent) of each element in the tree
""" | |||
_, seq_len, _ = arc_matrix.shape | |||
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) | |||
flip_mask = mask.eq(False) | |||
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||
_, heads = torch.max(matrix, dim=2) | |||
if mask is not None: | |||
heads *= mask.long() | |||
return heads | |||
@staticmethod | |||
def mst_decoder(arc_matrix, mask=None): | |||
r""" | |||
Decode with the maximum spanning tree algorithm; the output is guaranteed to be a valid tree.
:param arc_matrix: [batch, seq_len, seq_len] input arc score matrix
:param mask: [batch, seq_len] padding mask of the input, 1 for real content and 0 otherwise.
If ``None``, an all-ones mask is assumed. Default: ``None``
:return heads: [batch, seq_len] predicted head (parent) of each element in the tree
""" | |||
batch_size, seq_len, _ = arc_matrix.shape | |||
matrix = arc_matrix.clone() | |||
ans = matrix.new_zeros(batch_size, seq_len).long() | |||
lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len | |||
for i, graph in enumerate(matrix): | |||
len_i = lens[i] | |||
ans[i, :len_i] = torch.as_tensor(_mst(graph.detach()[:len_i, :len_i].cpu().numpy()), device=ans.device) | |||
if mask is not None: | |||
ans *= mask.long() | |||
return ans | |||
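# A hedged, shape-level sketch of the two decoding strategies on a random arc matrix
# (illustrative only; the mask marks valid positions with 1):
def _example_graph_parser_decoding():
    arc_matrix = torch.randn(2, 6, 6)                            # [batch, seq_len, seq_len]
    mask = torch.ones(2, 6, dtype=torch.long)                    # every position is valid here
    greedy_heads = GraphParser.greedy_decoder(arc_matrix, mask)  # [batch, seq_len]
    mst_heads = GraphParser.mst_decoder(arc_matrix, mask)        # [batch, seq_len], guaranteed trees
    return greedy_heads, mst_heads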
class ArcBiaffine(nn.Module): | |||
r""" | |||
Submodule of the Biaffine Dependency Parser that builds the graph of arc (edge) scores.
""" | |||
def __init__(self, hidden_size, bias=True): | |||
r""" | |||
:param hidden_size: dimension of the input features
:param bias: whether to use a bias term. Default: ``True``
""" | |||
super(ArcBiaffine, self).__init__() | |||
self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True) | |||
self.has_bias = bias | |||
if self.has_bias: | |||
self.bias = nn.Parameter(torch.Tensor(hidden_size), requires_grad=True) | |||
else: | |||
self.register_parameter("bias", None) | |||
def forward(self, head, dep): | |||
r""" | |||
:param head: arc-head tensor [batch, length, hidden] | |||
:param dep: arc-dependent tensor [batch, length, hidden] | |||
:return output: tensor [batch, length, length]
""" | |||
output = dep.matmul(self.U) | |||
output = output.bmm(head.transpose(-1, -2)) | |||
if self.has_bias: | |||
output = output + head.matmul(self.bias).unsqueeze(1) | |||
return output | |||
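# A shape sketch for ArcBiaffine on random features (standalone use needs the parameters to be
# initialized explicitly, which BiaffineParser normally does in reset_parameters):
def _example_arc_biaffine():
    arc = ArcBiaffine(hidden_size=8)
    nn.init.xavier_uniform_(arc.U)
    nn.init.zeros_(arc.bias)
    head = torch.randn(2, 5, 8)  # [batch, length, hidden]
    dep = torch.randn(2, 5, 8)   # [batch, length, hidden]
    return arc(head, dep)        # [batch, length, length] arc scores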
class LabelBilinear(nn.Module): | |||
r""" | |||
Submodule of the Biaffine Dependency Parser that builds the graph of arc label scores.
""" | |||
def __init__(self, in1_features, in2_features, num_label, bias=True): | |||
r""" | |||
:param in1_features: dimension of input feature 1
:param in2_features: dimension of input feature 2
:param num_label: number of arc label classes
:param bias: whether to use a bias term. Default: ``True``
""" | |||
super(LabelBilinear, self).__init__() | |||
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | |||
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) | |||
def forward(self, x1, x2): | |||
r""" | |||
:param x1: [batch, seq_len, hidden] input feature 1, i.e. label-head
:param x2: [batch, seq_len, hidden] input feature 2, i.e. label-dep
:return output: [batch, seq_len, num_cls] label scores for each element
""" | |||
output = self.bilinear(x1, x2) | |||
output = output + self.lin(torch.cat([x1, x2], dim=2)) | |||
return output | |||
class BiaffineParser(GraphParser): | |||
r""" | |||
Implementation of the Biaffine Dependency Parser.
Reference: `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ .
""" | |||
def __init__(self, | |||
embed, | |||
pos_vocab_size, | |||
pos_emb_dim, | |||
num_label, | |||
rnn_layers=1, | |||
rnn_hidden_size=200, | |||
arc_mlp_size=100, | |||
label_mlp_size=100, | |||
dropout=0.3, | |||
encoder='lstm', | |||
use_greedy_infer=False): | |||
r""" | |||
:param embed: the word embedding. Either a tuple (num_embeddings, embedding_dim), i.e. the vocabulary
size and the dimension of each word vector, or an nn.Embedding object, in which case the passed
object is used as the embedding directly
:param pos_vocab_size: size of the part-of-speech vocabulary
:param pos_emb_dim: dimension of the part-of-speech embeddings
:param num_label: number of arc label classes
:param rnn_layers: number of layers in the rnn encoder
:param rnn_hidden_size: hidden size of the rnn encoder
:param arc_mlp_size: MLP dimension for arc prediction
:param label_mlp_size: MLP dimension for label prediction
:param dropout: dropout probability.
:param encoder: encoder type, one of ('lstm', 'var-lstm', 'transformer'). Default: lstm
:param use_greedy_infer: whether to use the greedy algorithm at inference time.
If ``False``, the more accurate but slower MST algorithm is used. Default: ``False``
""" | |||
super(BiaffineParser, self).__init__() | |||
rnn_out_size = 2 * rnn_hidden_size | |||
word_hid_dim = pos_hid_dim = rnn_hidden_size | |||
self.word_embedding = get_embeddings(embed) | |||
word_emb_dim = self.word_embedding.embedding_dim | |||
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | |||
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) | |||
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | |||
self.word_norm = nn.LayerNorm(word_hid_dim) | |||
self.pos_norm = nn.LayerNorm(pos_hid_dim) | |||
self.encoder_name = encoder | |||
self.max_len = 512 | |||
if encoder == 'var-lstm': | |||
self.encoder = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | |||
hidden_size=rnn_hidden_size, | |||
num_layers=rnn_layers, | |||
bias=True, | |||
batch_first=True, | |||
input_dropout=dropout, | |||
hidden_dropout=dropout, | |||
bidirectional=True) | |||
elif encoder == 'lstm': | |||
self.encoder = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, | |||
hidden_size=rnn_hidden_size, | |||
num_layers=rnn_layers, | |||
bias=True, | |||
batch_first=True, | |||
dropout=dropout, | |||
bidirectional=True) | |||
elif encoder == 'transformer': | |||
n_head = 16 | |||
d_k = d_v = int(rnn_out_size / n_head) | |||
if (d_k * n_head) != rnn_out_size: | |||
raise ValueError('Unsupported rnn_out_size: {} for transformer'.format(rnn_out_size)) | |||
self.position_emb = nn.Embedding(num_embeddings=self.max_len, | |||
embedding_dim=rnn_out_size, ) | |||
self.encoder = TransformerEncoder( num_layers=rnn_layers, d_model=rnn_out_size, | |||
n_head=n_head, dim_ff=1024, dropout=dropout) | |||
else: | |||
raise ValueError('Unsupported encoder type: {}'.format(encoder)) | |||
self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2), | |||
nn.ELU(), | |||
TimestepDropout(p=dropout), ) | |||
self.arc_mlp_size = arc_mlp_size | |||
self.label_mlp_size = label_mlp_size | |||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | |||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | |||
self.use_greedy_infer = use_greedy_infer | |||
self.reset_parameters() | |||
self.dropout = dropout | |||
def reset_parameters(self): | |||
for m in self.modules(): | |||
if isinstance(m, nn.Embedding): | |||
continue | |||
elif isinstance(m, nn.LayerNorm): | |||
nn.init.constant_(m.weight, 0.1) | |||
nn.init.constant_(m.bias, 0) | |||
else: | |||
for p in m.parameters(): | |||
nn.init.normal_(p, 0, 0.1) | |||
def forward(self, words1, words2, seq_len, target1=None): | |||
r"""模型forward阶段 | |||
:param words1: [batch_size, seq_len] 输入word序列 | |||
:param words2: [batch_size, seq_len] 输入pos序列 | |||
:param seq_len: [batch_size, seq_len] 输入序列长度 | |||
:param target1: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效, | |||
用于训练label分类器. 若为 ``None`` , 使用预测的heads输入到label分类器 | |||
Default: ``None`` | |||
:return dict: parsing | |||
结果:: | |||
pred1: [batch_size, seq_len, seq_len] 边预测logits | |||
pred2: [batch_size, seq_len, num_label] label预测logits | |||
pred3: [batch_size, seq_len] heads的预测结果, 在 ``target1=None`` 时预测 | |||
""" | |||
# prepare embeddings | |||
batch_size, length = words1.shape | |||
# print('forward {} {}'.format(batch_size, seq_len)) | |||
# get sequence mask | |||
mask = seq_len_to_mask(seq_len, max_len=length).long() | |||
word = self.word_embedding(words1) # [N,L] -> [N,L,C_0] | |||
pos = self.pos_embedding(words2) # [N,L] -> [N,L,C_1] | |||
word, pos = self.word_fc(word), self.pos_fc(pos) | |||
word, pos = self.word_norm(word), self.pos_norm(pos) | |||
x = torch.cat([word, pos], dim=2) # -> [N,L,C] | |||
# encoder, extract features | |||
if self.encoder_name.endswith('lstm'): | |||
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | |||
x = x[sort_idx] | |||
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=True) | |||
feat, _ = self.encoder(x) # -> [N,L,C] | |||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | |||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||
feat = feat[unsort_idx] | |||
else: | |||
seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None, :] | |||
x = x + self.position_emb(seq_range) | |||
feat = self.encoder(x, mask.float()) | |||
# for arc biaffine | |||
# mlp, reduce dim | |||
feat = self.mlp(feat) | |||
arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size | |||
arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz] | |||
label_dep, label_head = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz], feat[:, :, 2 * arc_sz + label_sz:] | |||
# biaffine arc classifier | |||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | |||
# use gold or predicted arc to predict label | |||
if target1 is None or not self.training: | |||
# use greedy decoding in training | |||
if self.training or self.use_greedy_infer: | |||
heads = self.greedy_decoder(arc_pred, mask) | |||
else: | |||
heads = self.mst_decoder(arc_pred, mask) | |||
head_pred = heads | |||
else: | |||
assert self.training # must be training mode | |||
if target1 is None: | |||
heads = self.greedy_decoder(arc_pred, mask) | |||
head_pred = heads | |||
else: | |||
head_pred = None | |||
heads = target1 | |||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1) | |||
label_head = label_head[batch_range, heads].contiguous() | |||
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | |||
res_dict = {'pred1': arc_pred, 'pred2': label_pred} | |||
if head_pred is not None: | |||
res_dict['pred3'] = head_pred | |||
return res_dict | |||
def train_step(self, words1, words2, seq_len, target1, target2): | |||
res = self(words1, words2, seq_len, target1) | |||
arc_pred = res['pred1'] | |||
label_pred = res['pred2'] | |||
loss = self.loss(pred1=arc_pred, pred2=label_pred, target1=target1, target2=target2, seq_len=seq_len) | |||
return {'loss': loss} | |||
@staticmethod | |||
def loss(pred1, pred2, target1, target2, seq_len): | |||
r""" | |||
Compute the parser loss.
:param pred1: [batch_size, seq_len, seq_len] arc prediction logits
:param pred2: [batch_size, seq_len, num_label] label prediction logits
:param target1: [batch_size, seq_len] gold arc (head) annotations
:param target2: [batch_size, seq_len] gold label annotations
:param seq_len: [batch_size, ] true lengths of the targets
:return loss: scalar
""" | |||
batch_size, length, _ = pred1.shape | |||
mask = seq_len_to_mask(seq_len, max_len=length) | |||
flip_mask = (mask.eq(False)) | |||
_arc_pred = pred1.clone() | |||
_arc_pred = _arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf')) | |||
arc_logits = F.log_softmax(_arc_pred, dim=2) | |||
label_logits = F.log_softmax(pred2, dim=2) | |||
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) | |||
child_index = torch.arange(length, device=arc_logits.device, dtype=torch.long).unsqueeze(0) | |||
arc_loss = arc_logits[batch_index, child_index, target1] | |||
label_loss = label_logits[batch_index, child_index, target2] | |||
arc_loss = arc_loss.masked_fill(flip_mask, 0) | |||
label_loss = label_loss.masked_fill(flip_mask, 0) | |||
arc_nll = -arc_loss.mean() | |||
label_nll = -label_loss.mean() | |||
return arc_nll + label_nll | |||
def evaluate_step(self, words1, words2, seq_len): | |||
r"""模型预测API | |||
:param words1: [batch_size, seq_len] 输入word序列 | |||
:param words2: [batch_size, seq_len] 输入pos序列 | |||
:param seq_len: [batch_size, seq_len] 输入序列长度 | |||
:return dict: parsing | |||
结果:: | |||
pred1: [batch_size, seq_len] heads的预测结果 | |||
pred2: [batch_size, seq_len, num_label] label预测logits | |||
""" | |||
res = self(words1, words2, seq_len) | |||
output = {} | |||
output['pred1'] = res.pop('pred3') | |||
_, label_pred = res.pop('pred2').max(2) | |||
output['pred2'] = label_pred | |||
return output | |||
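# A minimal end-to-end sketch with toy sizes and random data (embed passed as (vocab_size, dim);
# the values are illustrative, not recommended hyper-parameters):
if __name__ == '__main__':
    parser = BiaffineParser(embed=(1000, 50), pos_vocab_size=30, pos_emb_dim=30,
                            num_label=10, rnn_layers=1, rnn_hidden_size=100)
    words1 = torch.randint(0, 1000, (2, 7))  # word ids
    words2 = torch.randint(0, 30, (2, 7))    # POS ids
    seq_len = torch.tensor([7, 5])
    target1 = torch.randint(0, 7, (2, 7))    # gold heads
    target2 = torch.randint(0, 10, (2, 7))   # gold labels
    print(parser.train_step(words1, words2, seq_len, target1, target2)['loss'])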
@@ -0,0 +1,92 @@ | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"CNNText" | |||
] | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from ...core.utils import seq_len_to_mask | |||
from ...embeddings.torch import embedding | |||
from ...modules.torch import encoder | |||
class CNNText(torch.nn.Module): | |||
r""" | |||
A CNN model for text classification.
'Yoon Kim. 2014. Convolutional Neural Networks for Sentence Classification.'
""" | |||
def __init__(self, embed, | |||
num_classes, | |||
kernel_nums=(30, 40, 50), | |||
kernel_sizes=(1, 3, 5), | |||
dropout=0.5): | |||
r""" | |||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: the embedding. Pass a tuple(int, int)
of (vocab_size, embed_dim) to create one; a Tensor, Embedding or ndarray is used directly to initialize the Embedding
:param int num_classes: number of target classes
:param tuple(int) kernel_nums: number of output channels for each kernel size
:param tuple(int) kernel_sizes: sizes of the convolution kernels
:param float dropout: dropout probability
""" | |||
super(CNNText, self).__init__() | |||
# no support for pre-trained embedding currently | |||
self.embed = embedding.Embedding(embed) | |||
self.conv_pool = encoder.ConvMaxpool( | |||
in_channels=self.embed.embedding_dim, | |||
out_channels=kernel_nums, | |||
kernel_sizes=kernel_sizes) | |||
self.dropout = nn.Dropout(dropout) | |||
self.fc = nn.Linear(sum(kernel_nums), num_classes) | |||
def forward(self, words, seq_len=None): | |||
r""" | |||
:param torch.LongTensor words: [batch_size, seq_len], word indices of the sentences
:param torch.LongTensor seq_len: [batch,] length of each sentence
:return output: dict with the classification logits under the key ``pred``
""" | |||
x = self.embed(words) # [N,L] -> [N,L,C] | |||
if seq_len is not None: | |||
mask = seq_len_to_mask(seq_len) | |||
x = self.conv_pool(x, mask) | |||
else: | |||
x = self.conv_pool(x) # [N,L,C] -> [N,C] | |||
x = self.dropout(x) | |||
x = self.fc(x) # [N,C] -> [N, N_class] | |||
res = {'pred': x} | |||
return res | |||
def train_step(self, words, target, seq_len=None): | |||
""" | |||
:param words: [batch_size, seq_len], word indices of the sentences
:param target: [batch_size,], target class of each sample
:param seq_len: [batch,], length of each sentence
:return: dict with the scalar cross-entropy loss under the key ``loss``
""" | |||
res = self(words, seq_len) | |||
x = res['pred'] | |||
loss = F.cross_entropy(x, target) | |||
return {'loss': loss} | |||
def evaluate_step(self, words, seq_len=None): | |||
r""" | |||
:param torch.LongTensor words: [batch_size, seq_len], word indices of the sentences
:param torch.LongTensor seq_len: [batch,] length of each sentence
:return predict: dict of torch.LongTensor, [batch_size, ] | |||
""" | |||
output = self(words, seq_len) | |||
_, predict = output['pred'].max(dim=1) | |||
return {'pred': predict} |
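# A minimal usage sketch with random data (embed passed as (vocab_size, embed_dim); toy sizes only):
if __name__ == '__main__':
    model = CNNText(embed=(1000, 50), num_classes=5)
    words = torch.randint(0, 1000, (4, 10))
    seq_len = torch.tensor([10, 8, 7, 10])
    target = torch.randint(0, 5, (4,))
    print(model.train_step(words, target, seq_len)['loss'])
    print(model.evaluate_step(words, seq_len)['pred'])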
@@ -0,0 +1,81 @@ | |||
r"""undocumented""" | |||
import torch | |||
from torch import nn | |||
import torch.nn.functional as F | |||
from fastNLP import seq_len_to_mask | |||
from .seq2seq_model import Seq2SeqModel | |||
from ...modules.torch.generator.seq2seq_generator import SequenceGenerator | |||
__all__ = ['SequenceGeneratorModel'] | |||
class SequenceGeneratorModel(nn.Module): | |||
""" | |||
Wraps a seq2seq_model so that it can be used both for training and for generation. During training this
model's forward function is called; during generation its evaluate_step function is called.
""" | |||
def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, max_len_a=0.0, | |||
num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0, | |||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | |||
""" | |||
:param Seq2SeqModel seq2seq_model: the sequence-to-sequence model
:param int,None bos_token_id: token id that starts a sentence
:param int,None eos_token_id: token id that ends a sentence
:param int max_length: maximum length of generated sentences; each sentence is decoded for at most max_length + max_len_a*src_len steps
:param float max_len_a: each sentence is decoded for at most max_length + max_len_a*src_len steps. If non-zero, the State must contain encoder_mask
:param int num_beams: beam size for beam search
:param bool do_sample: whether to generate by sampling
:param float temperature: only meaningful when do_sample is True
:param int top_k: sample only from the top_k tokens
:param float top_p: sample only from tokens within cumulative probability top_p (nucleus sampling)
:param float repetition_penalty: how strongly repeated tokens are penalized
:param float length_penalty: penalty on length; values smaller than 1 favor longer sentences, values larger than 1 favor shorter ones
:param int pad_token_id: once a sentence has finished, the remaining positions are filled with pad_token_id
""" | |||
super().__init__() | |||
self.seq2seq_model = seq2seq_model | |||
self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, max_len_a=max_len_a, | |||
num_beams=num_beams, | |||
do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p, | |||
bos_token_id=bos_token_id, | |||
eos_token_id=eos_token_id, | |||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
pad_token_id=pad_token_id) | |||
def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): | |||
""" | |||
Passes straight through to seq2seq_model's forward.
:param torch.LongTensor src_tokens: bsz x max_len | |||
:param torch.LongTensor tgt_tokens: bsz x max_len' | |||
:param torch.LongTensor src_seq_len: bsz | |||
:param torch.LongTensor tgt_seq_len: bsz | |||
:return: | |||
""" | |||
return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) | |||
def train_step(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): | |||
res = self(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) | |||
pred = res['pred'] | |||
if tgt_seq_len is not None: | |||
mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1)) | |||
tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100) | |||
loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens) | |||
return {'loss': loss} | |||
def evaluate_step(self, src_tokens, src_seq_len=None): | |||
""" | |||
Given the source content, output the generated content.
:param torch.LongTensor src_tokens: bsz x max_len | |||
:param torch.LongTensor src_seq_len: bsz | |||
:return: | |||
""" | |||
state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len) | |||
result = self.generator.generate(state) | |||
return {'pred': result} |
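# A hedged sketch of wrapping an already-built Seq2SeqModel for generation (assumes the model's
# source vocabulary has at least 100 tokens; the bos/eos ids are illustrative):
def _example_sequence_generator(seq2seq_model: Seq2SeqModel):
    model = SequenceGeneratorModel(seq2seq_model, bos_token_id=1, eos_token_id=2,
                                   max_length=20, num_beams=1, do_sample=False)
    src_tokens = torch.randint(0, 100, (2, 6))
    src_seq_len = torch.tensor([6, 4])
    return model.evaluate_step(src_tokens, src_seq_len)['pred']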
@@ -0,0 +1,196 @@ | |||
r""" | |||
Contains the models that make up Sequence-to-Sequence modeling.
""" | |||
import torch | |||
from torch import nn | |||
import torch.nn.functional as F | |||
from fastNLP import seq_len_to_mask | |||
from ...embeddings.torch.utils import get_embeddings | |||
from ...embeddings.torch.utils import get_sinusoid_encoding_table | |||
from ...modules.torch.decoder.seq2seq_decoder import Seq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder | |||
from ...modules.torch.encoder.seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||
__all__ = ['Seq2SeqModel', 'TransformerSeq2SeqModel', 'LSTMSeq2SeqModel'] | |||
class Seq2SeqModel(nn.Module): | |||
def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder): | |||
""" | |||
A Seq2Seq model that can be trained in the Trainer. Normally a subclass only needs to implement the classmethod
build_model. To use this model for generation, wrap it in :class:`~fastNLP.models.SequenceGeneratorModel`.
In this model, forward() feeds the encoder result into the decoder and returns the decoder output.
:param encoder: a Seq2SeqEncoder object whose forward() takes two arguments, a bsz x max_len tensor of source
tokens and a bsz tensor of source lengths, and returns two tensors: encoder_outputs of shape bsz x max_len x hidden_size
and encoder_mask of shape bsz x max_len where positions equal to 1 should be attended to. If the encoder's
inputs or outputs differ, override this model's prepare_state() or forward()
:param decoder: a Seq2SeqDecoder object that implements init_state(), taking two arguments: the encoder output of
shape bsz x max_len x hidden_size and the encoder mask of shape bsz x max_len where padding positions are 0.
If the decoder needs more inputs, override this model's prepare_state() or forward()
""" | |||
super().__init__() | |||
self.encoder = encoder | |||
self.decoder = decoder | |||
def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): | |||
""" | |||
:param torch.LongTensor src_tokens: source tokens
:param torch.LongTensor tgt_tokens: target tokens
:param torch.LongTensor src_seq_len: lengths of the source sequences
:param torch.LongTensor tgt_seq_len: lengths of the target sequences, unused by default
:return: {'pred': torch.Tensor}, where pred has shape bsz x max_len x vocab_size
""" | |||
state = self.prepare_state(src_tokens, src_seq_len) | |||
decoder_output = self.decoder(tgt_tokens, state) | |||
if isinstance(decoder_output, torch.Tensor): | |||
return {'pred': decoder_output} | |||
elif isinstance(decoder_output, (tuple, list)): | |||
return {'pred': decoder_output[0]} | |||
else: | |||
raise TypeError(f"Unsupported return type from Decoder:{type(self.decoder)}") | |||
def train_step(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): | |||
res = self(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) | |||
pred = res['pred'] | |||
if tgt_seq_len is not None: | |||
mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1)) | |||
tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100) | |||
loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens) | |||
return {'loss': loss} | |||
def prepare_state(self, src_tokens, src_seq_len=None): | |||
""" | |||
Call the encoder to obtain a state: the encoder's encoder_output and encoder_mask are passed directly to decoder.init_state to initialize a state.
:param src_tokens: | |||
:param src_seq_len: | |||
:return: | |||
""" | |||
encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len) | |||
state = self.decoder.init_state(encoder_output, encoder_mask) | |||
return state | |||
@classmethod | |||
def build_model(cls, *args, **kwargs): | |||
""" | |||
Subclasses must implement this method to construct the Seq2SeqModel.
:return: | |||
""" | |||
raise NotImplementedError
class TransformerSeq2SeqModel(Seq2SeqModel): | |||
""" | |||
A Seq2SeqModel whose encoder is a TransformerSeq2SeqEncoder and whose decoder is a TransformerSeq2SeqDecoder; initialize it via the build_model classmethod.
""" | |||
def __init__(self, encoder, decoder): | |||
super().__init__(encoder, decoder) | |||
@classmethod | |||
def build_model(cls, src_embed, tgt_embed=None, | |||
pos_embed='sin', max_position=1024, num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1, | |||
bind_encoder_decoder_embed=False, | |||
bind_decoder_input_output_embed=True): | |||
""" | |||
Build a TransformerSeq2SeqModel.
:param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: the source embedding
:param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: the target embedding; do not pass this when
bind_encoder_decoder_embed is True
:param str pos_embed: positional embedding type, either sin or learned
:param int max_position: maximum supported sequence length
:param int num_layers: number of encoder and decoder layers
:param int d_model: input/output size of the encoder and decoder
:param int n_head: number of attention heads in the encoder and decoder
:param int dim_ff: hidden dimension of the FFN inside the encoder and decoder
:param float dropout: dropout applied in attention and the FFN
:param bool bind_encoder_decoder_embed: whether the encoder and decoder share the same embedding
:param bool bind_decoder_input_output_embed: whether the decoder's output embedding shares weights with its input embedding
:return: TransformerSeq2SeqModel
""" | |||
if bind_encoder_decoder_embed and tgt_embed is not None: | |||
raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") | |||
src_embed = get_embeddings(src_embed) | |||
if bind_encoder_decoder_embed: | |||
tgt_embed = src_embed | |||
else: | |||
assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" | |||
tgt_embed = get_embeddings(tgt_embed) | |||
if pos_embed == 'sin':
encoder_pos_embed = nn.Embedding.from_pretrained(
get_sinusoid_encoding_table(max_position + 1, src_embed.embedding_dim, padding_idx=0),
freeze=True)  # index 0 is reserved for padding
decoder_pos_embed = nn.Embedding.from_pretrained(
get_sinusoid_encoding_table(max_position + 1, tgt_embed.embedding_dim, padding_idx=0),
freeze=True)  # index 0 is reserved for padding
elif pos_embed == 'learned':
encoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=0)
decoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=1)
else:
raise ValueError("pos_embed only supports sin or learned.")
encoder = TransformerSeq2SeqEncoder(embed=src_embed, pos_embed=encoder_pos_embed, | |||
num_layers=num_layers, d_model=d_model, n_head=n_head, dim_ff=dim_ff, | |||
dropout=dropout) | |||
decoder = TransformerSeq2SeqDecoder(embed=tgt_embed, pos_embed=decoder_pos_embed,
d_model=d_model, num_layers=num_layers, n_head=n_head, dim_ff=dim_ff, | |||
dropout=dropout, | |||
bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||
return cls(encoder, decoder) | |||
class LSTMSeq2SeqModel(Seq2SeqModel): | |||
""" | |||
A model using LSTMSeq2SeqEncoder and LSTMSeq2SeqDecoder.
""" | |||
def __init__(self, encoder, decoder): | |||
super().__init__(encoder, decoder) | |||
@classmethod | |||
def build_model(cls, src_embed, tgt_embed=None, | |||
num_layers = 3, hidden_size = 400, dropout = 0.3, bidirectional=True, | |||
attention=True, bind_encoder_decoder_embed=False, | |||
bind_decoder_input_output_embed=True): | |||
""" | |||
:param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: the source embedding
:param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: the target embedding; do not pass this when
bind_encoder_decoder_embed is True
:param int num_layers: number of encoder and decoder layers
:param int hidden_size: hidden size of the encoder and decoder
:param float dropout: dropout applied between layers
:param bool bidirectional: whether the encoder uses a bidirectional LSTM
:param bool attention: whether the decoder attends over all encoder states
:param bool bind_encoder_decoder_embed: whether the encoder and decoder share the same embedding
:param bool bind_decoder_input_output_embed: whether the decoder's output embedding shares weights with its input embedding
:return: LSTMSeq2SeqModel
""" | |||
if bind_encoder_decoder_embed and tgt_embed is not None: | |||
raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") | |||
src_embed = get_embeddings(src_embed) | |||
if bind_encoder_decoder_embed: | |||
tgt_embed = src_embed | |||
else: | |||
assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" | |||
tgt_embed = get_embeddings(tgt_embed) | |||
encoder = LSTMSeq2SeqEncoder(embed=src_embed, num_layers = num_layers, | |||
hidden_size = hidden_size, dropout = dropout, bidirectional=bidirectional) | |||
decoder = LSTMSeq2SeqDecoder(embed=tgt_embed, num_layers = num_layers, hidden_size = hidden_size, | |||
dropout = dropout, bind_decoder_input_output_embed = bind_decoder_input_output_embed, | |||
attention=attention) | |||
return cls(encoder, decoder) |
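# A minimal build-and-train sketch with toy sizes and random data (embeddings passed as
# (vocab_size, dim); the hyper-parameters are illustrative only):
if __name__ == '__main__':
    model = LSTMSeq2SeqModel.build_model(src_embed=(100, 32), tgt_embed=(90, 32),
                                         num_layers=1, hidden_size=64, dropout=0.1,
                                         bidirectional=True, attention=True)
    src_tokens = torch.randint(0, 100, (2, 7))
    tgt_tokens = torch.randint(0, 90, (2, 5))
    src_seq_len = torch.tensor([7, 6])
    tgt_seq_len = torch.tensor([5, 4])
    print(model.train_step(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)['loss'])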
@@ -0,0 +1,271 @@ | |||
r""" | |||
This module implements several sequence labeling models.
""" | |||
__all__ = [ | |||
"SeqLabeling", | |||
"AdvSeqLabel", | |||
"BiLSTMCRF" | |||
] | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from ...core.utils import seq_len_to_mask | |||
from ...embeddings.torch.utils import get_embeddings | |||
from ...modules.torch.decoder import ConditionalRandomField | |||
from ...modules.torch.encoder import LSTM | |||
from ...modules.torch import decoder, encoder | |||
from ...modules.torch.decoder.crf import allowed_transitions | |||
class BiLSTMCRF(nn.Module): | |||
r""" | |||
The architecture is embedding + BiLSTM + FC + Dropout + CRF.
""" | |||
def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5, | |||
target_vocab=None): | |||
r""" | |||
:param embed: either (1) any fastNLP Embedding, or (2) a tuple (num_embeddings, dimension), e.g. (1000, 100)
:param num_classes: number of target classes
:param num_layers: number of BiLSTM layers
:param hidden_size: hidden size of the BiLSTM; the effective hidden size is twice this value (forward plus backward)
:param dropout: dropout probability; 0 disables dropout
:param target_vocab: a Vocabulary object mapping targets to indices. If given, illegal decoded sequences are avoided automatically.
""" | |||
super().__init__() | |||
self.embed = get_embeddings(embed) | |||
if num_layers>1: | |||
self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True, | |||
batch_first=True, dropout=dropout) | |||
else: | |||
self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True, | |||
batch_first=True) | |||
self.dropout = nn.Dropout(dropout) | |||
self.fc = nn.Linear(hidden_size*2, num_classes) | |||
trans = None | |||
if target_vocab is not None: | |||
assert len(target_vocab)==num_classes, "The number of classes should be same with the length of target vocabulary." | |||
trans = allowed_transitions(target_vocab.idx2word, include_start_end=True) | |||
self.crf = ConditionalRandomField(num_classes, include_start_end_trans=True, allowed_transitions=trans) | |||
def forward(self, words, seq_len=None, target=None): | |||
words = self.embed(words) | |||
feats, _ = self.lstm(words, seq_len=seq_len) | |||
feats = self.fc(feats) | |||
feats = self.dropout(feats) | |||
logits = F.log_softmax(feats, dim=-1) | |||
mask = seq_len_to_mask(seq_len) | |||
if target is None: | |||
pred, _ = self.crf.viterbi_decode(logits, mask) | |||
return {'pred':pred} | |||
else: | |||
loss = self.crf(logits, target, mask).mean() | |||
return {'loss':loss} | |||
def train_step(self, words, seq_len, target): | |||
return self(words, seq_len, target) | |||
def evaluate_step(self, words, seq_len): | |||
return self(words, seq_len) | |||
class SeqLabeling(nn.Module): | |||
r""" | |||
A basic sequence labeling model.
Base class for sequence labeling. It consists of one Embedding layer, one LSTM layer (unidirectional, single layer), one FC layer, and a CRF.
""" | |||
def __init__(self, embed, hidden_size, num_classes): | |||
r""" | |||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: the embedding. Pass a tuple(int, int)
of (vocab_size, embed_dim) to create one; a Tensor, Embedding or ndarray is used directly to initialize the Embedding
:param int hidden_size: hidden size of the LSTM
:param int num_classes: number of target classes
""" | |||
super(SeqLabeling, self).__init__() | |||
self.embedding = get_embeddings(embed) | |||
self.rnn = encoder.LSTM(self.embedding.embedding_dim, hidden_size) | |||
self.fc = nn.Linear(hidden_size, num_classes) | |||
self.crf = decoder.ConditionalRandomField(num_classes) | |||
def forward(self, words, seq_len): | |||
r""" | |||
:param torch.LongTensor words: [batch_size, max_len], token indices of the sequences
:param torch.LongTensor seq_len: [batch_size,], lengths of the sequences
:return: {'pred': xx}, [batch_size, max_len, num_classes]
""" | |||
x = self.embedding(words) | |||
# [batch_size, max_len, word_emb_dim] | |||
x, _ = self.rnn(x, seq_len) | |||
# [batch_size, max_len, hidden_size * direction] | |||
x = self.fc(x) | |||
return {'pred': x} | |||
# [batch_size, max_len, num_classes] | |||
def train_step(self, words, seq_len, target): | |||
res = self(words, seq_len) | |||
pred = res['pred'] | |||
mask = seq_len_to_mask(seq_len, max_len=target.size(1)) | |||
return {'loss': self._internal_loss(pred, target, mask)} | |||
def evaluate_step(self, words, seq_len): | |||
r""" | |||
Used at prediction time.
:param torch.LongTensor words: [batch_size, max_len] | |||
:param torch.LongTensor seq_len: [batch_size,] | |||
:return: {'pred': xx}, [batch_size, max_len] | |||
""" | |||
mask = seq_len_to_mask(seq_len, max_len=words.size(1)) | |||
res = self(words, seq_len) | |||
pred = res['pred'] | |||
# [batch_size, max_len, num_classes] | |||
pred = self._decode(pred, mask) | |||
return {'pred': pred} | |||
def _internal_loss(self, x, y, mask): | |||
r""" | |||
Negative log likelihood loss. | |||
:param x: Tensor, [batch_size, max_len, tag_size] | |||
:param y: Tensor, [batch_size, max_len] | |||
:return loss: a scalar Tensor | |||
""" | |||
x = x.float() | |||
y = y.long() | |||
total_loss = self.crf(x, y, mask) | |||
return torch.mean(total_loss) | |||
def _decode(self, x, mask): | |||
r""" | |||
:param torch.FloatTensor x: [batch_size, max_len, tag_size] | |||
:return prediction: [batch_size, max_len] | |||
""" | |||
tag_seq, _ = self.crf.viterbi_decode(x, mask) | |||
return tag_seq | |||
class AdvSeqLabel(nn.Module): | |||
r""" | |||
A more sophisticated sequence labeling model. The architecture is Embedding, LayerNorm, bidirectional LSTM (two layers), FC, LayerNorm, Dropout, FC, CRF.
""" | |||
def __init__(self, embed, hidden_size, num_classes, dropout=0.3, id2words=None, encoding_type='bmes'): | |||
r""" | |||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: the embedding. Pass a tuple(int, int)
of (vocab_size, embed_dim) to create one; a Tensor, Embedding or ndarray is used directly to initialize the Embedding
:param int hidden_size: hidden size of the LSTM
:param int num_classes: number of target classes
:param float dropout: dropout probability used in the LSTM and the Dropout layer
:param dict id2words: mapping from tag id to tag word, used to forbid illegal tag sequences during CRF decoding;
for example, in the 'BMES' scheme 'S' cannot follow 'B'. Tags like 'B-NN', where the part before '-' indicates the
position and the part after it is the concrete tag, are also supported: besides forbidding 'S-NN' after 'B-NN',
this also forbids any 'M-xx' other than 'M-NN' or 'E-NN' after 'B-NN'.
:param str encoding_type: one of "BIO", "BMES", "BMESO"; only used when id2words is not None.
""" | |||
super().__init__() | |||
self.Embedding = get_embeddings(embed) | |||
self.norm1 = torch.nn.LayerNorm(self.Embedding.embedding_dim) | |||
self.Rnn = encoder.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size, num_layers=2, | |||
dropout=dropout, | |||
bidirectional=True, batch_first=True) | |||
self.Linear1 = nn.Linear(hidden_size * 2, hidden_size * 2 // 3) | |||
self.norm2 = torch.nn.LayerNorm(hidden_size * 2 // 3) | |||
self.relu = torch.nn.LeakyReLU() | |||
self.drop = torch.nn.Dropout(dropout) | |||
self.Linear2 = nn.Linear(hidden_size * 2 // 3, num_classes) | |||
if id2words is None: | |||
self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False) | |||
else: | |||
self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False, | |||
allowed_transitions=allowed_transitions(id2words, | |||
encoding_type=encoding_type)) | |||
def _decode(self, x, mask): | |||
r""" | |||
:param torch.FloatTensor x: [batch_size, max_len, tag_size] | |||
:param torch.ByteTensor mask: [batch_size, max_len] | |||
:return torch.LongTensor, [batch_size, max_len] | |||
""" | |||
tag_seq, _ = self.Crf.viterbi_decode(x, mask) | |||
return tag_seq | |||
def _internal_loss(self, x, y, mask): | |||
r""" | |||
Negative log likelihood loss. | |||
:param x: Tensor, [batch_size, max_len, tag_size] | |||
:param y: Tensor, [batch_size, max_len] | |||
:param mask: Tensor, [batch_size, max_len] | |||
:return loss: a scalar Tensor | |||
""" | |||
x = x.float() | |||
y = y.long() | |||
total_loss = self.Crf(x, y, mask) | |||
return torch.mean(total_loss) | |||
def forward(self, words, seq_len, target=None): | |||
r""" | |||
:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size, ]
:param torch.LongTensor target: [batch_size, max_len]
:return y: If target is None, return the decoded tag paths; used in testing and prediction.
If target is not None, return the loss, a scalar; used in training.
""" | |||
words = words.long() | |||
seq_len = seq_len.long() | |||
mask = seq_len_to_mask(seq_len, max_len=words.size(1)) | |||
target = target.long() if target is not None else None | |||
if next(self.parameters()).is_cuda: | |||
words = words.cuda() | |||
x = self.Embedding(words) | |||
x = self.norm1(x) | |||
# [batch_size, max_len, word_emb_dim] | |||
x, _ = self.Rnn(x, seq_len=seq_len) | |||
x = self.Linear1(x) | |||
x = self.norm2(x) | |||
x = self.relu(x) | |||
x = self.drop(x) | |||
x = self.Linear2(x) | |||
if target is not None: | |||
return {"loss": self._internal_loss(x, target, mask)} | |||
else: | |||
return {"pred": self._decode(x, mask)} | |||
def train_step(self, words, seq_len, target): | |||
r""" | |||
:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size, ]
:param torch.LongTensor target: [batch_size, max_len], the gold labels
:return torch.Tensor: a scalar loss | |||
""" | |||
return self(words, seq_len, target) | |||
def evaluate_step(self, words, seq_len): | |||
r""" | |||
:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size, ] | |||
:return torch.LongTensor: [batch_size, max_len] | |||
""" | |||
return self(words, seq_len) |
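# A minimal usage sketch for BiLSTMCRF with random data (embed passed as (vocab_size, embed_dim);
# no target_vocab is given, so all label transitions are allowed):
if __name__ == '__main__':
    model = BiLSTMCRF(embed=(1000, 50), num_classes=5, num_layers=1, hidden_size=100)
    words = torch.randint(0, 1000, (2, 8))
    seq_len = torch.tensor([8, 6])
    target = torch.randint(0, 5, (2, 8))
    print(model.train_step(words, seq_len, target)['loss'])
    print(model.evaluate_step(words, seq_len)['pred'])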
@@ -0,0 +1,26 @@ | |||
__all__ = [ | |||
'ConditionalRandomField', | |||
'allowed_transitions', | |||
"State", | |||
"Seq2SeqDecoder", | |||
"LSTMSeq2SeqDecoder", | |||
"TransformerSeq2SeqDecoder", | |||
"LSTM", | |||
"Seq2SeqEncoder", | |||
"TransformerSeq2SeqEncoder", | |||
"LSTMSeq2SeqEncoder", | |||
"StarTransformer", | |||
"VarRNN", | |||
"VarLSTM", | |||
"VarGRU", | |||
'SequenceGenerator', | |||
"TimestepDropout", | |||
] | |||
from .decoder import * | |||
from .encoder import * | |||
from .generator import * | |||
from .dropout import TimestepDropout |
@@ -0,0 +1,321 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"MultiHeadAttention", | |||
"BiAttention", | |||
"SelfAttention", | |||
] | |||
import math | |||
import torch | |||
import torch.nn.functional as F | |||
from torch import nn | |||
from .decoder.seq2seq_state import TransformerState | |||
class DotAttention(nn.Module): | |||
r""" | |||
The scaled dot-product attention used in the Transformer.
""" | |||
def __init__(self, key_size, value_size, dropout=0.0): | |||
super(DotAttention, self).__init__() | |||
self.key_size = key_size | |||
self.value_size = value_size | |||
self.scale = math.sqrt(key_size) | |||
self.drop = nn.Dropout(dropout) | |||
self.softmax = nn.Softmax(dim=-1) | |||
def forward(self, Q, K, V, mask_out=None): | |||
r""" | |||
:param Q: [..., seq_len_q, key_size] | |||
:param K: [..., seq_len_k, key_size] | |||
:param V: [..., seq_len_k, value_size] | |||
:param mask_out: [..., 1, seq_len] or [..., seq_len_q, seq_len_k] | |||
""" | |||
output = torch.matmul(Q, K.transpose(-1, -2)) / self.scale | |||
if mask_out is not None: | |||
output.masked_fill_(mask_out, -1e9) | |||
output = self.softmax(output) | |||
output = self.drop(output) | |||
return torch.matmul(output, V) | |||
class MultiHeadAttention(nn.Module): | |||
""" | |||
The multi-head attention described in 'Attention Is All You Need'.
""" | |||
def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None): | |||
super(MultiHeadAttention, self).__init__() | |||
self.d_model = d_model | |||
self.n_head = n_head | |||
self.dropout = dropout | |||
self.head_dim = d_model // n_head | |||
self.layer_idx = layer_idx | |||
assert d_model % n_head == 0, "d_model should be divisible by n_head" | |||
self.scaling = self.head_dim ** -0.5 | |||
self.q_proj = nn.Linear(d_model, d_model) | |||
self.k_proj = nn.Linear(d_model, d_model) | |||
self.v_proj = nn.Linear(d_model, d_model) | |||
self.out_proj = nn.Linear(d_model, d_model) | |||
self.reset_parameters() | |||
def forward(self, query, key, value, key_mask=None, attn_mask=None, state=None): | |||
""" | |||
:param query: batch x seq x dim | |||
:param key: batch x seq x dim | |||
:param value: batch x seq x dim | |||
:param key_mask: batch x seq, indicates which keys should not be attended to; note that positions where the mask is 1 are the ones that ARE attended to
:param attn_mask: seq x seq, used to mask out the attention map; mainly used for decoder self-attention during training, with the lower triangle set to 1
:param state: past information used at inference time, e.g. the encoder output or the decoder's previous key/value, to reduce computation
:return: | |||
""" | |||
assert key.size() == value.size() | |||
if state is not None: | |||
assert self.layer_idx is not None | |||
qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() | |||
q = self.q_proj(query) # batch x seq x dim | |||
q *= self.scaling | |||
k = v = None | |||
prev_k = prev_v = None | |||
# fetch cached key/value from the state
if isinstance(state, TransformerState):  # we are in the inference stage
if qkv_same:  # decoder self-attention
prev_k = state.decoder_prev_key[self.layer_idx]
prev_v = state.decoder_prev_value[self.layer_idx]
else:  # decoder-encoder attention: just load the cached encoder keys/values
k = state.encoder_key[self.layer_idx] | |||
v = state.encoder_value[self.layer_idx] | |||
if k is None: | |||
k = self.k_proj(key) | |||
v = self.v_proj(value) | |||
if prev_k is not None: | |||
k = torch.cat((prev_k, k), dim=1) | |||
v = torch.cat((prev_v, v), dim=1) | |||
# update the state
if isinstance(state, TransformerState): | |||
if qkv_same: | |||
state.decoder_prev_key[self.layer_idx] = k | |||
state.decoder_prev_value[self.layer_idx] = v | |||
else: | |||
state.encoder_key[self.layer_idx] = k | |||
state.encoder_value[self.layer_idx] = v | |||
# compute the attention
batch_size, q_len, d_model = query.size() | |||
k_len, v_len = k.size(1), v.size(1) | |||
q = q.reshape(batch_size, q_len, self.n_head, self.head_dim) | |||
k = k.reshape(batch_size, k_len, self.n_head, self.head_dim) | |||
v = v.reshape(batch_size, v_len, self.n_head, self.head_dim) | |||
attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k) # bs,q_len,k_len,n_head | |||
if key_mask is not None: | |||
_key_mask = ~key_mask[:, None, :, None].bool() # batch,1,k_len,1 | |||
attn_weights = attn_weights.masked_fill(_key_mask, -float('inf')) | |||
if attn_mask is not None: | |||
_attn_mask = attn_mask[None, :, :, None].eq(0) # 1,q_len,k_len,n_head | |||
attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf')) | |||
attn_weights = F.softmax(attn_weights, dim=2) | |||
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) | |||
output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v) # batch,q_len,n_head,head_dim | |||
output = output.reshape(batch_size, q_len, -1) | |||
output = self.out_proj(output) # batch,q_len,dim | |||
return output, attn_weights | |||
def reset_parameters(self): | |||
nn.init.xavier_uniform_(self.q_proj.weight) | |||
nn.init.xavier_uniform_(self.k_proj.weight) | |||
nn.init.xavier_uniform_(self.v_proj.weight) | |||
nn.init.xavier_uniform_(self.out_proj.weight) | |||
def set_layer_idx(self, layer_idx): | |||
self.layer_idx = layer_idx | |||
class AttentionLayer(nn.Module): | |||
def __init__(self, input_size, key_dim, value_dim, bias=False):
"""
Attention used in the decoding step of LSTM-to-LSTM sequence-to-sequence models; at each decode step it
attends over the encoder outputs based on the previous step's hidden state.
:param int input_size: size of the input
:param int key_dim: usually the dimension of encoder_output
:param int value_dim: size of the output, usually the decoder hidden size
:param bias:
"""
super().__init__()
self.input_proj = nn.Linear(input_size, key_dim, bias=bias)
self.output_proj = nn.Linear(input_size + key_dim, value_dim, bias=bias)
def forward(self, input, encode_outputs, encode_mask): | |||
""" | |||
:param input: batch_size x input_size | |||
:param encode_outputs: batch_size x max_len x key_dim | |||
:param encode_mask: batch_size x max_len, padding positions are 0
:return: hidden: batch_size x value_dim, scores: batch_size x max_len (normalized)
""" | |||
# x: bsz x encode_hidden_size | |||
x = self.input_proj(input) | |||
# compute attention | |||
attn_scores = torch.matmul(encode_outputs, x.unsqueeze(-1)).squeeze(-1) # b x max_len | |||
# don't attend over padding | |||
if encode_mask is not None: | |||
attn_scores = attn_scores.float().masked_fill_( | |||
encode_mask.eq(0), | |||
float('-inf') | |||
).type_as(attn_scores) # FP16 support: cast to float and back | |||
attn_scores = F.softmax(attn_scores, dim=-1)  # bsz x max_len
# sum weighted sources | |||
x = torch.matmul(attn_scores.unsqueeze(1), encode_outputs).squeeze(1) # b x encode_hidden_size | |||
x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1))) | |||
return x, attn_scores | |||
def _masked_softmax(tensor, mask): | |||
tensor_shape = tensor.size() | |||
reshaped_tensor = tensor.view(-1, tensor_shape[-1]) | |||
# Reshape the mask so it matches the size of the input tensor. | |||
while mask.dim() < tensor.dim(): | |||
mask = mask.unsqueeze(1) | |||
mask = mask.expand_as(tensor).contiguous().float() | |||
reshaped_mask = mask.view(-1, mask.size()[-1]) | |||
result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1) | |||
result = result * reshaped_mask | |||
# 1e-13 is added to avoid divisions by zero. | |||
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13) | |||
return result.view(*tensor_shape) | |||
def _weighted_sum(tensor, weights, mask): | |||
w_sum = weights.bmm(tensor) | |||
while mask.dim() < w_sum.dim(): | |||
mask = mask.unsqueeze(1) | |||
mask = mask.transpose(-1, -2) | |||
mask = mask.expand_as(w_sum).contiguous().float() | |||
return w_sum * mask | |||
class BiAttention(nn.Module): | |||
r""" | |||
Bi Attention module | |||
For two given vector sequences :math:`a_i` and :math:`b_j`, the BiAttention module computes the attention result with the following formulas
.. math:: | |||
\begin{array}{ll} \\ | |||
e_{ij} = {a}^{\mathrm{T}}_{i}{b}_{j} \\ | |||
{\hat{a}}_{i} = \sum_{j=1}^{\mathcal{l}_{b}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{b}}{\mathrm{exp}(e_{ik})}}}{b}_{j} \\ | |||
{\hat{b}}_{j} = \sum_{i=1}^{\mathcal{l}_{a}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{a}}{\mathrm{exp}(e_{ik})}}}{a}_{i} \\ | |||
\end{array} | |||
""" | |||
def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): | |||
r""" | |||
:param torch.Tensor premise_batch: [batch_size, a_seq_len, hidden_size] | |||
:param torch.Tensor premise_mask: [batch_size, a_seq_len] | |||
:param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size] | |||
:param torch.Tensor hypothesis_mask: [batch_size, b_seq_len] | |||
:return: torch.Tensor attended_premises: [batch_size, a_seq_len, hidden_size] torch.Tensor attended_hypotheses: [batch_size, b_seq_len, hidden_size] | |||
""" | |||
similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1) | |||
.contiguous()) | |||
prem_hyp_attn = _masked_softmax(similarity_matrix, hypothesis_mask) | |||
hyp_prem_attn = _masked_softmax(similarity_matrix.transpose(1, 2) | |||
.contiguous(), | |||
premise_mask) | |||
attended_premises = _weighted_sum(hypothesis_batch, | |||
prem_hyp_attn, | |||
premise_mask) | |||
attended_hypotheses = _weighted_sum(premise_batch, | |||
hyp_prem_attn, | |||
hypothesis_mask) | |||
return attended_premises, attended_hypotheses | |||
class SelfAttention(nn.Module): | |||
r""" | |||
A Self-Attention module based on the paper
`A structured self-attentive sentence embedding <https://arxiv.org/pdf/1703.03130.pdf>`_ .
""" | |||
def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5): | |||
r""" | |||
:param int input_size: hidden dimension of the input tensor
:param int attention_unit: hidden dimension of the internal attention projection
:param int attention_hops: number of attention hops (rows of the attention matrix)
:param float drop: dropout probability, default 0.5
""" | |||
super(SelfAttention, self).__init__() | |||
self.attention_hops = attention_hops | |||
self.ws1 = nn.Linear(input_size, attention_unit, bias=False) | |||
self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False) | |||
self.I = torch.eye(attention_hops, requires_grad=False) | |||
self.I_origin = self.I | |||
self.drop = nn.Dropout(drop) | |||
self.tanh = nn.Tanh() | |||
def _penalization(self, attention): | |||
r""" | |||
compute the penalization term for attention module | |||
""" | |||
baz = attention.size(0) | |||
size = self.I.size() | |||
if len(size) != 3 or size[0] != baz: | |||
self.I = self.I_origin.expand(baz, -1, -1) | |||
self.I = self.I.to(device=attention.device) | |||
attention_t = torch.transpose(attention, 1, 2).contiguous() | |||
mat = torch.bmm(attention, attention_t) - self.I[:attention.size(0)] | |||
ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5 | |||
return torch.sum(ret) / size[0] | |||
def forward(self, input, input_origin): | |||
r""" | |||
:param torch.Tensor input: [batch_size, seq_len, hidden_size] the matrix to attend over
:param torch.Tensor input_origin: [batch_size, seq_len] the original token indices, including the padded part
:return torch.Tensor output1: [batch_size, attention_hops, hidden_size] the input after the attention operation
:return torch.Tensor output2: [1] the attention penalty term, a scalar
""" | |||
input = input.contiguous() | |||
size = input.size() # [bsz, len, nhid] | |||
input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops,baz, len] | |||
input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len] | |||
y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit] | |||
attention = self.ws2(y1).transpose(1, 2).contiguous() | |||
# [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len] | |||
attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token. | |||
attention = F.softmax(attention, 2) # [baz ,hop, len] | |||
return torch.bmm(attention, input), self._penalization(attention) # output1 --> [baz ,hop ,nhid] |
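# A shape-level sketch of MultiHeadAttention on random tensors (key_mask marks valid positions with 1):
if __name__ == '__main__':
    mha = MultiHeadAttention(d_model=64, n_head=8, dropout=0.1)
    x = torch.randn(2, 5, 64)
    key_mask = torch.ones(2, 5, dtype=torch.bool)
    out, attn = mha(x, x, x, key_mask=key_mask)
    print(out.shape, attn.shape)  # torch.Size([2, 5, 64]) torch.Size([2, 5, 5, 8])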
@@ -0,0 +1,15 @@ | |||
__all__ = [ | |||
'ConditionalRandomField', | |||
'allowed_transitions', | |||
"State", | |||
"Seq2SeqDecoder", | |||
"LSTMSeq2SeqDecoder", | |||
"TransformerSeq2SeqDecoder" | |||
] | |||
from .crf import ConditionalRandomField, allowed_transitions | |||
from .seq2seq_state import State | |||
from .seq2seq_decoder import LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder, Seq2SeqDecoder |
@@ -0,0 +1,354 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"ConditionalRandomField", | |||
"allowed_transitions" | |||
] | |||
from typing import Union, List | |||
import torch | |||
from torch import nn | |||
from ....core.metrics.span_f1_pre_rec_metric import _get_encoding_type_from_tag_vocab, _check_tag_vocab_and_encoding_type | |||
from ....core.vocabulary import Vocabulary | |||
def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type:str=None, include_start_end:bool=False): | |||
r""" | |||
Given a mapping from id to label, return the list of all allowed transitions (from_tag_id, to_tag_id).
:param tag_vocab: supports labels of the form tag or tag-label, e.g. just "B", "M", or "B-NN", "M-NN";
tag and label must be separated by "-". If a dict is passed it must map index to tag, e.g. {0:"O", 1:"B-tag1"}.
:param encoding_type: one of ``["bio", "bmes", "bmeso", "bioes"]``. Defaults to None, in which case it is inferred from the vocab
:param include_start_end: whether to include transitions involving the start and end tags. For example, in bio,
b/o may appear at the beginning but i may not; if True, the result contains (start_idx, b_idx) and (start_idx, o_idx)
but not (start_idx, i_idx), with start_idx=len(id2label) and end_idx=len(id2label)+1. If False, transitions
involving start and end are not included.
:return: List[Tuple(int, int)], where each tuple is an allowed transition (from_tag_id, to_tag_id).
""" | |||
if encoding_type is None: | |||
encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) | |||
else: | |||
encoding_type = encoding_type.lower() | |||
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) | |||
pad_token = '<pad>' | |||
unk_token = '<unk>' | |||
if isinstance(tag_vocab, Vocabulary): | |||
id_label_lst = list(tag_vocab.idx2word.items()) | |||
pad_token = tag_vocab.padding | |||
unk_token = tag_vocab.unknown | |||
else: | |||
id_label_lst = list(tag_vocab.items()) | |||
num_tags = len(tag_vocab) | |||
start_idx = num_tags | |||
end_idx = num_tags + 1 | |||
allowed_trans = [] | |||
if include_start_end: | |||
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] | |||
def split_tag_label(from_label): | |||
from_label = from_label.lower() | |||
if from_label in ['start', 'end']: | |||
from_tag = from_label | |||
from_label = '' | |||
else: | |||
from_tag = from_label[:1] | |||
from_label = from_label[2:] | |||
return from_tag, from_label | |||
for from_id, from_label in id_label_lst: | |||
if from_label in [pad_token, unk_token]: | |||
continue | |||
from_tag, from_label = split_tag_label(from_label) | |||
for to_id, to_label in id_label_lst: | |||
if to_label in [pad_token, unk_token]: | |||
continue | |||
to_tag, to_label = split_tag_label(to_label) | |||
if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
allowed_trans.append((from_id, to_id)) | |||
return allowed_trans | |||
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
r""" | |||
:param str encoding_type: 支持"BIO", "BMES", "BEMSO", 'bioes'。 | |||
:param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
:param str from_label: 比如"PER", "LOC"等label | |||
:param str to_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
:param str to_label: 比如"PER", "LOC"等label | |||
:return: bool,能否跃迁 | |||
""" | |||
if to_tag == 'start' or from_tag == 'end': | |||
return False | |||
encoding_type = encoding_type.lower() | |||
if encoding_type == 'bio': | |||
r""" | |||
Columns are to_tag, rows are from_tag. "y": always allowed; "-": allowed only when the labels match; "n": not allowed.
+-------+---+---+---+-------+-----+ | |||
| | B | I | O | start | end | | |||
+-------+---+---+---+-------+-----+ | |||
| B | y | - | y | n | y | | |||
+-------+---+---+---+-------+-----+ | |||
| I | y | - | y | n | y | | |||
+-------+---+---+---+-------+-----+ | |||
| O | y | n | y | n | y | | |||
+-------+---+---+---+-------+-----+ | |||
| start | y | n | y | n | n | | |||
+-------+---+---+---+-------+-----+ | |||
| end | n | n | n | n | n | | |||
+-------+---+---+---+-------+-----+ | |||
""" | |||
if from_tag == 'start': | |||
return to_tag in ('b', 'o') | |||
elif from_tag in ['b', 'i']: | |||
return any([to_tag in ['end', 'b', 'o'], to_tag == 'i' and from_label == to_label]) | |||
elif from_tag == 'o': | |||
return to_tag in ['end', 'b', 'o'] | |||
else: | |||
raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) | |||
elif encoding_type == 'bmes': | |||
r""" | |||
Columns are to_tag, rows are from_tag. "y": always allowed; "-": allowed only when the labels match; "n": not allowed.
+-------+---+---+---+---+-------+-----+ | |||
| | B | M | E | S | start | end | | |||
+-------+---+---+---+---+-------+-----+ | |||
| B | n | - | - | n | n | n | | |||
+-------+---+---+---+---+-------+-----+ | |||
| M | n | - | - | n | n | n | | |||
+-------+---+---+---+---+-------+-----+ | |||
| E | y | n | n | y | n | y | | |||
+-------+---+---+---+---+-------+-----+ | |||
| S | y | n | n | y | n | y | | |||
+-------+---+---+---+---+-------+-----+ | |||
| start | y | n | n | y | n | n | | |||
+-------+---+---+---+---+-------+-----+ | |||
| end | n | n | n | n | n | n | | |||
+-------+---+---+---+---+-------+-----+ | |||
""" | |||
if from_tag == 'start': | |||
return to_tag in ['b', 's'] | |||
elif from_tag == 'b': | |||
return to_tag in ['m', 'e'] and from_label == to_label | |||
elif from_tag == 'm': | |||
return to_tag in ['m', 'e'] and from_label == to_label | |||
elif from_tag in ['e', 's']: | |||
return to_tag in ['b', 's', 'end'] | |||
else: | |||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) | |||
elif encoding_type == 'bmeso': | |||
if from_tag == 'start': | |||
return to_tag in ['b', 's', 'o'] | |||
elif from_tag == 'b': | |||
return to_tag in ['m', 'e'] and from_label == to_label | |||
elif from_tag == 'm': | |||
return to_tag in ['m', 'e'] and from_label == to_label | |||
elif from_tag in ['e', 's', 'o']: | |||
return to_tag in ['b', 's', 'end', 'o'] | |||
else: | |||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | |||
elif encoding_type == 'bioes': | |||
if from_tag == 'start': | |||
return to_tag in ['b', 's', 'o'] | |||
elif from_tag == 'b': | |||
return to_tag in ['i', 'e'] and from_label == to_label | |||
elif from_tag == 'i': | |||
return to_tag in ['i', 'e'] and from_label == to_label | |||
elif from_tag in ['e', 's', 'o']: | |||
return to_tag in ['b', 's', 'end', 'o'] | |||
else: | |||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'I', 'E', 'S', 'O'.".format(from_tag)) | |||
else: | |||
raise ValueError("Only support BIO, BMES, BMESO, BIOES encoding type, got {}.".format(encoding_type)) | |||
class ConditionalRandomField(nn.Module): | |||
r""" | |||
条件随机场。提供 forward() 以及 viterbi_decode() 两个方法,分别用于训练与inference。 | |||
""" | |||
def __init__(self, num_tags:int, include_start_end_trans:bool=False, allowed_transitions:List=None): | |||
r""" | |||
:param num_tags: 标签的数量 | |||
:param include_start_end_trans: 是否考虑各个tag作为开始以及结尾的分数。 | |||
:param allowed_transitions: 内部的Tuple[from_tag_id(int), | |||
to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 | |||
allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 | |||
""" | |||
super(ConditionalRandomField, self).__init__() | |||
self.include_start_end_trans = include_start_end_trans | |||
self.num_tags = num_tags | |||
# the meaning of entry in this matrix is (from_tag_id, to_tag_id) score | |||
self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags)) | |||
if self.include_start_end_trans: | |||
self.start_scores = nn.Parameter(torch.randn(num_tags)) | |||
self.end_scores = nn.Parameter(torch.randn(num_tags)) | |||
if allowed_transitions is None: | |||
constrain = torch.zeros(num_tags + 2, num_tags + 2) | |||
else: | |||
constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float) | |||
has_start = False | |||
has_end = False | |||
for from_tag_id, to_tag_id in allowed_transitions: | |||
constrain[from_tag_id, to_tag_id] = 0 | |||
if from_tag_id==num_tags: | |||
has_start = True | |||
if to_tag_id==num_tags+1: | |||
has_end = True | |||
if not has_start: | |||
constrain[num_tags, :].fill_(0) | |||
if not has_end: | |||
constrain[:, num_tags+1].fill_(0) | |||
self._constrain = nn.Parameter(constrain, requires_grad=False) | |||
def _normalizer_likelihood(self, logits, mask): | |||
r"""Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||
sum of the likelihoods across all possible state sequences. | |||
:param logits:FloatTensor, max_len x batch_size x num_tags | |||
:param mask:ByteTensor, max_len x batch_size | |||
:return:FloatTensor, batch_size | |||
""" | |||
seq_len, batch_size, n_tags = logits.size() | |||
alpha = logits[0] | |||
if self.include_start_end_trans: | |||
alpha = alpha + self.start_scores.view(1, -1) | |||
flip_mask = mask.eq(False) | |||
for i in range(1, seq_len): | |||
emit_score = logits[i].view(batch_size, 1, n_tags) | |||
trans_score = self.trans_m.view(1, n_tags, n_tags) | |||
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score | |||
alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ | |||
alpha.masked_fill(mask[i].eq(True).view(batch_size, 1), 0) | |||
if self.include_start_end_trans: | |||
alpha = alpha + self.end_scores.view(1, -1) | |||
return torch.logsumexp(alpha, 1) | |||
def _gold_score(self, logits, tags, mask): | |||
r""" | |||
Compute the score for the gold path. | |||
:param logits: FloatTensor, max_len x batch_size x num_tags | |||
:param tags: LongTensor, max_len x batch_size | |||
:param mask: ByteTensor, max_len x batch_size | |||
:return:FloatTensor, batch_size | |||
""" | |||
seq_len, batch_size, _ = logits.size() | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | |||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | |||
# trans_score [L-1, B]
mask = mask.eq(True) | |||
flip_mask = mask.eq(False) | |||
trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0) | |||
# emit_score [L, B] | |||
emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0) | |||
# score [L-1, B] | |||
score = trans_score + emit_score[:seq_len - 1, :] | |||
score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0) | |||
if self.include_start_end_trans: | |||
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | |||
last_idx = mask.long().sum(0) - 1 | |||
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | |||
score = score + st_scores + ed_scores | |||
# return [B,] | |||
return score | |||
def forward(self, feats, tags, mask): | |||
r""" | |||
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 | |||
:param torch.FloatTensor feats: batch_size x max_len x num_tags,特征矩阵。 | |||
:param torch.LongTensor tags: batch_size x max_len,标签矩阵。 | |||
:param torch.ByteTensor mask: batch_size x max_len,为0的位置认为是padding。 | |||
:return: torch.FloatTensor, (batch_size,) | |||
""" | |||
feats = feats.transpose(0, 1) | |||
tags = tags.transpose(0, 1).long() | |||
mask = mask.transpose(0, 1).float() | |||
all_path_score = self._normalizer_likelihood(feats, mask) | |||
gold_path_score = self._gold_score(feats, tags, mask) | |||
return all_path_score - gold_path_score | |||
def viterbi_decode(self, logits, mask, unpad=False): | |||
r"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | |||
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 | |||
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 | |||
个sample的有效长度。 | |||
:return: 返回 (paths, scores)。 | |||
paths: 是解码后的路径, 其值参照unpad参数. | |||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||
""" | |||
batch_size, max_len, n_tags = logits.size() | |||
seq_len = mask.long().sum(1) | |||
logits = logits.transpose(0, 1).data # L, B, H | |||
mask = mask.transpose(0, 1).data.eq(True) # L, B | |||
flip_mask = mask.eq(False) | |||
# dp | |||
vpath = logits.new_zeros((max_len, batch_size, n_tags), dtype=torch.long) | |||
vscore = logits[0] # bsz x n_tags | |||
transitions = self._constrain.data.clone() | |||
transitions[:n_tags, :n_tags] += self.trans_m.data | |||
if self.include_start_end_trans: | |||
transitions[n_tags, :n_tags] += self.start_scores.data | |||
transitions[:n_tags, n_tags + 1] += self.end_scores.data | |||
vscore += transitions[n_tags, :n_tags] | |||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||
end_trans_score = transitions[:n_tags, n_tags+1].view(1, 1, n_tags).repeat(batch_size, 1, 1) # bsz, 1, n_tags | |||
# handle sequences of length 1
vscore += transitions[:n_tags, n_tags+1].view(1, n_tags).repeat(batch_size, 1) \ | |||
.masked_fill(seq_len.ne(1).view(-1, 1), 0) | |||
for i in range(1, max_len): | |||
prev_score = vscore.view(batch_size, n_tags, 1) | |||
cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score | |||
score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0) # bsz x n_tag x n_tag | |||
# account for the case where the current position is the last token of its sequence
score += end_trans_score.masked_fill(seq_len.ne(i+1).view(-1, 1, 1), 0) | |||
best_score, best_dst = score.max(1) | |||
vpath[i] = best_dst | |||
# the path is recovered by backtracking through last_tags, so vscore has to stay consistent at every position
vscore = best_score.masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ | |||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||
# backtrace | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | |||
seq_idx = torch.arange(max_len, dtype=torch.long, device=logits.device) | |||
lens = (seq_len - 1) | |||
# idxes [L, B], batched idx from seq_len-1 to 0 | |||
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % max_len | |||
ans = logits.new_empty((max_len, batch_size), dtype=torch.long) | |||
ans_score, last_tags = vscore.max(1) | |||
ans[idxes[0], batch_idx] = last_tags | |||
for i in range(max_len - 1): | |||
last_tags = vpath[idxes[i], batch_idx, last_tags] | |||
ans[idxes[i + 1], batch_idx] = last_tags | |||
ans = ans.transpose(0, 1) | |||
if unpad: | |||
paths = [] | |||
for idx, seq_len_i in enumerate(lens):
paths.append(ans[idx, :seq_len_i + 1].tolist())
else: | |||
paths = ans | |||
return paths, ans_score |
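# ---------------------------------------------------------------------------
# Minimal usage sketch (hedged): the tiny BIO vocabulary, tag ids and shapes below are
# illustrative only and are not part of the original module.
if __name__ == '__main__':
    _vocab = {0: 'O', 1: 'B-PER', 2: 'I-PER'}  # index -> label, BIO scheme
    _trans = allowed_transitions(_vocab, encoding_type='bio', include_start_end=True)
    _crf = ConditionalRandomField(num_tags=3, include_start_end_trans=True,
                                  allowed_transitions=_trans)
    _feats = torch.randn(2, 7, 3)                  # batch_size x max_len x num_tags
    _tags = torch.randint(0, 3, (2, 7))            # batch_size x max_len
    _mask = torch.ones(2, 7, dtype=torch.bool)     # 0 would mark padding
    _loss = _crf(_feats, _tags, _mask).mean()      # per-sample NLL, averaged into a scalar
    _paths, _scores = _crf.viterbi_decode(_feats, _mask, unpad=True)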
@@ -0,0 +1,416 @@ | |||
r"""undocumented""" | |||
from typing import Union, Tuple | |||
import math | |||
import torch | |||
from torch import nn | |||
import torch.nn.functional as F | |||
from ..attention import AttentionLayer, MultiHeadAttention | |||
from ....embeddings.torch.utils import get_embeddings | |||
from ....embeddings.torch.static_embedding import StaticEmbedding | |||
from .seq2seq_state import State, LSTMState, TransformerState | |||
__all__ = ['Seq2SeqDecoder', 'TransformerSeq2SeqDecoder', 'LSTMSeq2SeqDecoder'] | |||
class Seq2SeqDecoder(nn.Module): | |||
""" | |||
Sequence-to-Sequence Decoder的基类。一定需要实现forward、decode函数,剩下的函数根据需要实现。每个Seq2SeqDecoder都应该有相应的State对象 | |||
用来承载该Decoder所需要的Encoder输出、Decoder需要记录的历史信息(例如LSTM的hidden信息)。 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def forward(self, tokens, state, **kwargs): | |||
""" | |||
:param torch.LongTensor tokens: bsz x max_len | |||
:param State state: state包含了encoder的输出以及decode之前的内容 | |||
:return: 返回值可以为bsz x max_len x vocab_size的Tensor,也可以是一个list,但是第一个元素必须是词的预测分布 | |||
""" | |||
raise NotImplementedError
def reorder_states(self, indices, states): | |||
""" | |||
根据indices重新排列states中的状态,在beam search进行生成时,会用到该函数。 | |||
:param torch.LongTensor indices: | |||
:param State states: | |||
:return: | |||
""" | |||
assert isinstance(states, State), f"`states` should be of type State instead of {type(states)}" | |||
states.reorder_state(indices) | |||
def init_state(self, encoder_output, encoder_mask): | |||
""" | |||
初始化一个state对象,用来记录了encoder的输出以及decode已经完成的部分。 | |||
:param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch | |||
维度 | |||
:param Union[torch.Tensor, list, tuple] encoder_mask: 如果部位None,内部元素需要torch.Tensor, 默认其中第一维是batch | |||
维度 | |||
:param kwargs: | |||
:return: State, 返回一个State对象,记录了encoder的输出 | |||
""" | |||
state = State(encoder_output, encoder_mask) | |||
return state | |||
def decode(self, tokens, state): | |||
""" | |||
根据states中的内容,以及tokens中的内容进行之后的生成。 | |||
:param torch.LongTensor tokens: bsz x max_len, 截止到上一个时刻所有的token输出。 | |||
:param State state: 记录了encoder输出与decoder过去状态 | |||
:return: torch.FloatTensor: bsz x vocab_size, 输出的是下一个时刻的分布 | |||
""" | |||
outputs = self(state=state, tokens=tokens) | |||
if isinstance(outputs, torch.Tensor): | |||
return outputs[:, -1] | |||
else: | |||
raise RuntimeError("Unrecognized output from the `forward()` function. Please override the `decode()` function.") | |||
class TiedEmbedding(nn.Module): | |||
""" | |||
用于将weight和原始weight绑定 | |||
""" | |||
def __init__(self, weight): | |||
super().__init__() | |||
self.weight = weight # vocab_size x embed_size | |||
def forward(self, x): | |||
""" | |||
:param torch.FloatTensor x: bsz x * x embed_size | |||
:return: torch.FloatTensor bsz x * x vocab_size | |||
""" | |||
return torch.matmul(x, self.weight.t()) | |||
def get_bind_decoder_output_embed(embed): | |||
""" | |||
给定一个embedding,输出对应的绑定的embedding,输出对象为TiedEmbedding | |||
:param embed: | |||
:return: | |||
""" | |||
if isinstance(embed, StaticEmbedding): | |||
for idx, map2idx in enumerate(embed.words_to_words): | |||
assert idx == map2idx, "Invalid StaticEmbedding for Decoder, please check:(1) whether the vocabulary " \ | |||
"include `no_create_entry=True` word; (2) StaticEmbedding should not initialize with " \ | |||
"`lower=True` or `min_freq!=1`." | |||
elif not isinstance(embed, nn.Embedding): | |||
raise TypeError("Only nn.Embedding or StaticEmbedding is allowed for binding.") | |||
return TiedEmbedding(embed.weight) | |||
class LSTMSeq2SeqDecoder(Seq2SeqDecoder): | |||
""" | |||
LSTM的Decoder | |||
:param nn.Module,tuple embed: decoder输入的embedding. | |||
:param int num_layers: 多少层LSTM | |||
:param int hidden_size: 隐藏层大小, 该值也被认为是encoder的输出维度大小 | |||
:param dropout: Dropout的大小 | |||
:param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, | |||
则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. | |||
:param bool attention: 是否使用attention | |||
""" | |||
def __init__(self, embed: Union[nn.Module, Tuple[int, int]], num_layers = 3, hidden_size = 300, | |||
dropout = 0.3, bind_decoder_input_output_embed = True, attention=True): | |||
super().__init__() | |||
self.embed = get_embeddings(init_embed=embed) | |||
self.embed_dim = self.embed.embedding_dim  # read the dim from the constructed module so tuple inputs also work
if bind_decoder_input_output_embed: | |||
self.output_layer = get_bind_decoder_output_embed(self.embed) | |||
else:  # no weight tying
self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) | |||
self.output_layer = TiedEmbedding(self.output_embed.weight) | |||
self.hidden_size = hidden_size | |||
self.num_layers = num_layers | |||
self.lstm = nn.LSTM(input_size=self.embed_dim + hidden_size, hidden_size=hidden_size, num_layers=num_layers, | |||
batch_first=True, bidirectional=False, dropout=dropout if num_layers>1 else 0) | |||
self.attention_layer = AttentionLayer(hidden_size, hidden_size, hidden_size) if attention else None | |||
self.output_proj = nn.Linear(hidden_size, self.embed_dim) | |||
self.dropout_layer = nn.Dropout(dropout) | |||
def forward(self, tokens, state, return_attention=False): | |||
""" | |||
:param torch.LongTensor tokens: batch x max_len | |||
:param LSTMState state: 保存encoder输出和decode状态的State对象 | |||
:param bool return_attention: 是否返回attention的的score | |||
:return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length | |||
""" | |||
src_output = state.encoder_output | |||
encoder_mask = state.encoder_mask | |||
assert tokens.size(1)>state.decode_length, "The state does not match the tokens." | |||
tokens = tokens[:, state.decode_length:] | |||
x = self.embed(tokens) | |||
attn_weights = [] if self.attention_layer is not None else None  # collected attention weights: batch x tgt_seq x src_seq
input_feed = state.input_feed | |||
decoder_out = [] | |||
cur_hidden = state.hidden | |||
cur_cell = state.cell | |||
# decode step by step
for i in range(tokens.size(1)): | |||
input = torch.cat( | |||
(x[:, i:i + 1, :], | |||
input_feed[:, None, :] | |||
), | |||
dim=2 | |||
) # batch,1,2*dim | |||
_, (cur_hidden, cur_cell) = self.lstm(input, hx=(cur_hidden, cur_cell))  # hidden/cell keep their original size
if self.attention_layer is not None: | |||
input_feed, attn_weight = self.attention_layer(cur_hidden[-1], src_output, encoder_mask) | |||
attn_weights.append(attn_weight) | |||
else: | |||
input_feed = cur_hidden[-1] | |||
state.input_feed = input_feed # batch, hidden | |||
state.hidden = cur_hidden | |||
state.cell = cur_cell | |||
state.decode_length += 1 | |||
decoder_out.append(input_feed) | |||
decoder_out = torch.stack(decoder_out, dim=1) # batch,seq_len,hidden | |||
decoder_out = self.dropout_layer(decoder_out) | |||
if attn_weights is not None: | |||
attn_weights = torch.cat(attn_weights, dim=1) # batch, tgt_len, src_len | |||
decoder_out = self.output_proj(decoder_out) | |||
feats = self.output_layer(decoder_out) | |||
if return_attention: | |||
return feats, attn_weights | |||
return feats | |||
def init_state(self, encoder_output, encoder_mask) -> LSTMState: | |||
""" | |||
:param encoder_output: 输入可以有两种情况(1) 输入为一个tuple,包含三个内容(encoder_output, (hidden, cell)),其中encoder_output: | |||
bsz x max_len x hidden_size, hidden: bsz x hidden_size, cell:bsz x hidden_size,一般使用LSTMEncoder的最后一层的 | |||
hidden state和cell state来赋值这两个值 | |||
(2) 只有encoder_output: bsz x max_len x hidden_size, 这种情况下hidden和cell使用0初始化 | |||
:param torch.ByteTensor encoder_mask: bsz x max_len, 为0的位置是padding, 用来指示source中哪些不需要attend | |||
:return: | |||
""" | |||
if not isinstance(encoder_output, torch.Tensor): | |||
encoder_output, (hidden, cell) = encoder_output | |||
else: | |||
hidden = cell = None | |||
assert encoder_output.ndim==3 | |||
assert encoder_mask.size()==encoder_output.size()[:2] | |||
assert encoder_output.size(-1)==self.hidden_size, "The dimension of encoder outputs should be the same with " \ | |||
"the hidden_size." | |||
t = [hidden, cell] | |||
for idx in range(2): | |||
v = t[idx] | |||
if v is None: | |||
v = encoder_output.new_zeros(self.num_layers, encoder_output.size(0), self.hidden_size) | |||
else: | |||
assert v.dim()==2 | |||
assert v.size(-1)==self.hidden_size | |||
v = v[None].repeat(self.num_layers, 1, 1) # num_layers x bsz x hidden_size | |||
t[idx] = v | |||
state = LSTMState(encoder_output, encoder_mask, t[0], t[1]) | |||
return state | |||
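# Minimal usage sketch (hedged; the (10, 8) vocab/embed sizes and other shapes below are
# illustrative, not part of the original module):
#
#   >>> decoder = LSTMSeq2SeqDecoder(embed=(10, 8), num_layers=2, hidden_size=16,
#   ...                              bind_decoder_input_output_embed=False)
#   >>> encoder_output = torch.randn(4, 5, 16)             # bsz x src_max_len x hidden_size
#   >>> encoder_mask = torch.ones(4, 5, dtype=torch.bool)
#   >>> state = decoder.init_state(encoder_output, encoder_mask)
#   >>> tokens = torch.randint(0, 10, (4, 3))              # bsz x tgt_len
#   >>> feats = decoder(tokens, state)                     # bsz x tgt_len x vocab_size (4 x 3 x 10)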
class TransformerSeq2SeqDecoderLayer(nn.Module): | |||
""" | |||
:param int d_model: 输入、输出的维度 | |||
:param int n_head: 多少个head,需要能被d_model整除 | |||
:param int dim_ff: | |||
:param float dropout: | |||
:param int layer_idx: layer的编号 | |||
""" | |||
def __init__(self, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1, layer_idx = None): | |||
super().__init__() | |||
self.d_model = d_model | |||
self.n_head = n_head | |||
self.dim_ff = dim_ff | |||
self.dropout = dropout | |||
self.layer_idx = layer_idx  # record the layer index so the matching cache in `state` can be looked up
self.self_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) | |||
self.self_attn_layer_norm = nn.LayerNorm(d_model) | |||
self.encoder_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) | |||
self.encoder_attn_layer_norm = nn.LayerNorm(d_model) | |||
self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), | |||
nn.ReLU(), | |||
nn.Dropout(dropout), | |||
nn.Linear(self.dim_ff, self.d_model), | |||
nn.Dropout(dropout)) | |||
self.final_layer_norm = nn.LayerNorm(self.d_model) | |||
def forward(self, x, encoder_output, encoder_mask=None, self_attn_mask=None, state=None): | |||
""" | |||
:param x: (batch, seq_len, dim), decoder端的输入 | |||
:param encoder_output: (batch,src_seq_len,dim), encoder的输出 | |||
:param encoder_mask: batch,src_seq_len, 为1的地方需要attend | |||
:param self_attn_mask: seq_len, seq_len,下三角的mask矩阵,只在训练时传入 | |||
:param TransformerState state: 只在inference阶段传入 | |||
:return: | |||
""" | |||
# self attention part | |||
residual = x | |||
x = self.self_attn_layer_norm(x) | |||
x, _ = self.self_attn(query=x, | |||
key=x, | |||
value=x, | |||
attn_mask=self_attn_mask, | |||
state=state) | |||
x = F.dropout(x, p=self.dropout, training=self.training) | |||
x = residual + x | |||
# encoder attention part | |||
residual = x | |||
x = self.encoder_attn_layer_norm(x) | |||
x, attn_weight = self.encoder_attn(query=x, | |||
key=encoder_output, | |||
value=encoder_output, | |||
key_mask=encoder_mask, | |||
state=state) | |||
x = F.dropout(x, p=self.dropout, training=self.training) | |||
x = residual + x | |||
# ffn | |||
residual = x | |||
x = self.final_layer_norm(x) | |||
x = self.ffn(x) | |||
x = residual + x | |||
return x, attn_weight | |||
class TransformerSeq2SeqDecoder(Seq2SeqDecoder): | |||
""" | |||
:param embed: 输入token的embedding | |||
:param nn.Module pos_embed: 位置embedding | |||
:param int d_model: 输出、输出的大小 | |||
:param int num_layers: 多少层 | |||
:param int n_head: 多少个head | |||
:param int dim_ff: FFN 的中间大小 | |||
:param float dropout: Self-Attention和FFN中的dropout的大小 | |||
:param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, | |||
则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. | |||
""" | |||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed: nn.Module = None, | |||
d_model = 512, num_layers=6, n_head = 8, dim_ff = 2048, dropout = 0.1, | |||
bind_decoder_input_output_embed = True): | |||
super().__init__() | |||
self.embed = get_embeddings(embed) | |||
self.pos_embed = pos_embed | |||
if bind_decoder_input_output_embed: | |||
self.output_layer = get_bind_decoder_output_embed(self.embed) | |||
else:  # no weight tying
self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) | |||
self.output_layer = TiedEmbedding(self.output_embed.weight) | |||
self.num_layers = num_layers | |||
self.d_model = d_model | |||
self.n_head = n_head | |||
self.dim_ff = dim_ff | |||
self.dropout = dropout | |||
self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) | |||
self.layer_stacks = nn.ModuleList([TransformerSeq2SeqDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx) | |||
for layer_idx in range(num_layers)]) | |||
self.embed_scale = math.sqrt(d_model) | |||
self.layer_norm = nn.LayerNorm(d_model) | |||
self.output_fc = nn.Linear(self.d_model, self.embed.embedding_dim) | |||
def forward(self, tokens, state, return_attention=False): | |||
""" | |||
:param torch.LongTensor tokens: batch x tgt_len,decode的词 | |||
:param TransformerState state: 用于记录encoder的输出以及decode状态的对象,可以通过init_state()获取 | |||
:param bool return_attention: 是否返回对encoder结果的attention score | |||
:return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length | |||
""" | |||
encoder_output = state.encoder_output | |||
encoder_mask = state.encoder_mask | |||
assert state.decode_length<tokens.size(1), "The decoded tokens in State should be less than tokens." | |||
tokens = tokens[:, state.decode_length:] | |||
device = tokens.device | |||
x = self.embed_scale * self.embed(tokens) | |||
if self.pos_embed is not None: | |||
position = torch.arange(state.decode_length, state.decode_length+tokens.size(1)).long().to(device)[None] | |||
x += self.pos_embed(position) | |||
x = self.input_fc(x) | |||
x = F.dropout(x, p=self.dropout, training=self.training) | |||
batch_size, max_tgt_len = tokens.size() | |||
if max_tgt_len>1: | |||
triangle_mask = self._get_triangle_mask(tokens) | |||
else: | |||
triangle_mask = None | |||
for layer in self.layer_stacks: | |||
x, attn_weight = layer(x=x, | |||
encoder_output=encoder_output, | |||
encoder_mask=encoder_mask, | |||
self_attn_mask=triangle_mask, | |||
state=state | |||
) | |||
x = self.layer_norm(x) # batch, tgt_len, dim | |||
x = self.output_fc(x) | |||
feats = self.output_layer(x) | |||
if return_attention: | |||
return feats, attn_weight | |||
return feats | |||
def init_state(self, encoder_output, encoder_mask): | |||
""" | |||
初始化一个TransformerState用于forward | |||
:param torch.FloatTensor encoder_output: bsz x max_len x d_model, encoder的输出 | |||
:param torch.ByteTensor encoder_mask: bsz x max_len, 为1的位置需要attend。 | |||
:return: TransformerState | |||
""" | |||
if isinstance(encoder_output, (list, tuple)):
encoder_output = encoder_output[0]  # in case this is the output of an LSTM encoder
elif not isinstance(encoder_output, torch.Tensor):
raise TypeError("Unsupported `encoder_output` for TransformerSeq2SeqDecoder")
state = TransformerState(encoder_output, encoder_mask, num_decoder_layer=self.num_layers) | |||
return state | |||
@staticmethod | |||
def _get_triangle_mask(tokens): | |||
tensor = tokens.new_ones(tokens.size(1), tokens.size(1)) | |||
return torch.tril(tensor).byte() | |||
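# ---------------------------------------------------------------------------
# Minimal usage sketch (hedged): this mirrors a single training-style forward pass with a
# fresh state; the (20, 512) vocab/embed sizes and other shapes are illustrative only.
if __name__ == '__main__':
    _decoder = TransformerSeq2SeqDecoder(embed=(20, 512), d_model=512, num_layers=2,
                                         n_head=8, dim_ff=1024, dropout=0.1,
                                         bind_decoder_input_output_embed=False)
    _encoder_output = torch.randn(2, 6, 512)               # bsz x src_max_len x d_model
    _encoder_mask = torch.ones(2, 6, dtype=torch.bool)
    _state = _decoder.init_state(_encoder_output, _encoder_mask)
    _tokens = torch.randint(0, 20, (2, 4))                 # bsz x tgt_len
    _feats = _decoder(_tokens, _state)                     # bsz x tgt_len x vocab_size (2 x 4 x 20)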
@@ -0,0 +1,145 @@ | |||
r""" | |||
每个Decoder都有对应的State用来记录encoder的输出以及Decode的历史记录 | |||
""" | |||
__all__ = [ | |||
'State', | |||
"LSTMState", | |||
"TransformerState" | |||
] | |||
from typing import Union | |||
import torch | |||
class State: | |||
def __init__(self, encoder_output=None, encoder_mask=None, **kwargs): | |||
""" | |||
每个Decoder都有对应的State对象用来承载encoder的输出以及当前时刻之前的decode状态。 | |||
:param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch | |||
维度 | |||
:param Union[torch.Tensor, list, tuple] encoder_mask: 如果部位None,内部元素需要torch.Tensor, 默认其中第一维是batch | |||
维度 | |||
:param kwargs: | |||
""" | |||
self.encoder_output = encoder_output | |||
self.encoder_mask = encoder_mask | |||
self._decode_length = 0 | |||
@property | |||
def num_samples(self): | |||
""" | |||
返回的State中包含的是多少个sample的encoder状态,主要用于Generate的时候确定batch的大小。 | |||
:return: | |||
""" | |||
if self.encoder_output is not None: | |||
return self.encoder_output.size(0) | |||
else: | |||
return None | |||
@property | |||
def decode_length(self): | |||
""" | |||
当前Decode到哪个token了,decoder只会从decode_length之后的token开始decode, 为0说明还没开始decode。 | |||
:return: | |||
""" | |||
return self._decode_length | |||
@decode_length.setter | |||
def decode_length(self, value): | |||
self._decode_length = value | |||
def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0): | |||
if isinstance(state, torch.Tensor): | |||
state = state.index_select(index=indices, dim=dim) | |||
elif isinstance(state, list): | |||
for i in range(len(state)): | |||
assert state[i] is not None | |||
state[i] = self._reorder_state(state[i], indices, dim) | |||
elif isinstance(state, tuple): | |||
tmp_list = [] | |||
for i in range(len(state)): | |||
assert state[i] is not None | |||
tmp_list.append(self._reorder_state(state[i], indices, dim)) | |||
state = tuple(tmp_list) | |||
else: | |||
raise TypeError(f"Cannot reorder data of type:{type(state)}") | |||
return state | |||
def reorder_state(self, indices: torch.LongTensor): | |||
if self.encoder_mask is not None: | |||
self.encoder_mask = self._reorder_state(self.encoder_mask, indices) | |||
if self.encoder_output is not None: | |||
self.encoder_output = self._reorder_state(self.encoder_output, indices) | |||
class LSTMState(State): | |||
def __init__(self, encoder_output, encoder_mask, hidden, cell): | |||
""" | |||
LSTMDecoder对应的State,保存encoder的输出以及LSTM解码过程中的一些中间状态 | |||
:param torch.FloatTensor encoder_output: bsz x src_seq_len x encode_output_size,encoder的输出 | |||
:param torch.BoolTensor encoder_mask: bsz x src_seq_len, 为0的地方是padding | |||
:param torch.FloatTensor hidden: num_layers x bsz x hidden_size, 上个时刻的hidden状态 | |||
:param torch.FloatTensor cell: num_layers x bsz x hidden_size, 上个时刻的cell状态 | |||
""" | |||
super().__init__(encoder_output, encoder_mask) | |||
self.hidden = hidden | |||
self.cell = cell | |||
self._input_feed = hidden[0]  # defaults to the output of the previous time step
@property | |||
def input_feed(self): | |||
""" | |||
LSTMDecoder中每个时刻的输入会把上个token的embedding和input_feed拼接起来输入到下个时刻,在LSTMDecoder不使用attention时, | |||
input_feed即上个时刻的hidden state, 否则是attention layer的输出。 | |||
:return: torch.FloatTensor, bsz x hidden_size | |||
""" | |||
return self._input_feed | |||
@input_feed.setter | |||
def input_feed(self, value): | |||
self._input_feed = value | |||
def reorder_state(self, indices: torch.LongTensor): | |||
super().reorder_state(indices) | |||
self.hidden = self._reorder_state(self.hidden, indices, dim=1) | |||
self.cell = self._reorder_state(self.cell, indices, dim=1) | |||
if self.input_feed is not None: | |||
self.input_feed = self._reorder_state(self.input_feed, indices, dim=0) | |||
class TransformerState(State): | |||
def __init__(self, encoder_output, encoder_mask, num_decoder_layer): | |||
""" | |||
与TransformerSeq2SeqDecoder对应的State, | |||
:param torch.FloatTensor encoder_output: bsz x encode_max_len x encoder_output_size, encoder的输出 | |||
:param torch.ByteTensor encoder_mask: bsz x encode_max_len 为1的地方需要attend | |||
:param int num_decoder_layer: decode有多少层 | |||
""" | |||
super().__init__(encoder_output, encoder_mask) | |||
self.encoder_key = [None] * num_decoder_layer  # each element: bsz x encoder_max_len x key_dim
self.encoder_value = [None] * num_decoder_layer  # each element: bsz x encoder_max_len x value_dim
self.decoder_prev_key = [None] * num_decoder_layer  # each element: bsz x decode_length x key_dim
self.decoder_prev_value = [None] * num_decoder_layer  # each element: bsz x decode_length x value_dim
def reorder_state(self, indices: torch.LongTensor): | |||
super().reorder_state(indices) | |||
self.encoder_key = self._reorder_state(self.encoder_key, indices) | |||
self.encoder_value = self._reorder_state(self.encoder_value, indices) | |||
self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) | |||
self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) | |||
@property | |||
def decode_length(self): | |||
if self.decoder_prev_key[0] is not None: | |||
return self.decoder_prev_key[0].size(1) | |||
return 0 | |||
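# ---------------------------------------------------------------------------
# Minimal usage sketch (hedged): reorder_state is what beam search uses to shuffle the cached
# encoder tensors when hypotheses are re-ranked; the shapes below are illustrative only.
if __name__ == '__main__':
    _state = State(encoder_output=torch.randn(3, 5, 8),
                   encoder_mask=torch.ones(3, 5, dtype=torch.bool))
    assert _state.num_samples == 3 and _state.decode_length == 0
    _state.reorder_state(torch.LongTensor([2, 0, 1]))   # row i now holds what used to be row indices[i]
    assert _state.encoder_output.shape == (3, 5, 8)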
@@ -0,0 +1,24 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"TimestepDropout" | |||
] | |||
import torch | |||
class TimestepDropout(torch.nn.Dropout): | |||
r""" | |||
传入参数的shape为 ``(batch_size, num_timesteps, embedding_dim)`` | |||
使用同一个shape为 ``(batch_size, embedding_dim)`` 的mask在每个timestamp上做dropout。 | |||
""" | |||
def forward(self, x): | |||
dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) | |||
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) | |||
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim] | |||
if self.inplace: | |||
x *= dropout_mask | |||
return | |||
else: | |||
return x * dropout_mask |
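# ---------------------------------------------------------------------------
# Minimal usage sketch (hedged; shapes are illustrative): the same (batch_size, embedding_dim)
# mask is reused at every timestep, so a dropped feature is zero across the whole sequence of a sample.
if __name__ == '__main__':
    _drop = TimestepDropout(p=0.5)
    _x = torch.randn(2, 4, 6)     # batch_size x num_timesteps x embedding_dim
    _y = _drop(_x)                # in training mode, dropped feature columns are zero at all 4 timesteps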
@@ -1,5 +1,21 @@ | |||
__all__ = [ | |||
"ConvMaxpool", | |||
"LSTM", | |||
"Seq2SeqEncoder", | |||
"TransformerSeq2SeqEncoder", | |||
"LSTMSeq2SeqEncoder", | |||
"StarTransformer", | |||
"VarRNN", | |||
"VarLSTM", | |||
"VarGRU" | |||
] | |||
from .conv_maxpool import ConvMaxpool
from .lstm import LSTM
from .seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||
from .star_transformer import StarTransformer | |||
from .variational_rnn import VarRNN, VarLSTM, VarGRU |
@@ -0,0 +1,87 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"ConvMaxpool" | |||
] | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
class ConvMaxpool(nn.Module): | |||
r""" | |||
集合了Convolution和Max-Pooling于一体的层。给定一个batch_size x max_len x input_size的输入,返回batch_size x | |||
sum(output_channels) 大小的matrix。在内部,是先使用CNN给输入做卷积,然后经过activation激活层,在通过在长度(max_len) | |||
这一维进行max_pooling。最后得到每个sample的一个向量表示。 | |||
""" | |||
def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"): | |||
r""" | |||
:param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度 | |||
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 | |||
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 | |||
:param str activation: Convolution后的结果将通过该activation后再经过max-pooling。支持relu, sigmoid, tanh | |||
""" | |||
super(ConvMaxpool, self).__init__() | |||
for kernel_size in (kernel_sizes if isinstance(kernel_sizes, (list, tuple)) else [kernel_sizes]):
assert kernel_size % 2 == 1, "kernel size has to be an odd number."
# convolution | |||
if isinstance(kernel_sizes, (list, tuple, int)): | |||
if isinstance(kernel_sizes, int) and isinstance(out_channels, int): | |||
out_channels = [out_channels] | |||
kernel_sizes = [kernel_sizes] | |||
elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)): | |||
assert len(out_channels) == len( | |||
kernel_sizes), "The number of out_channels should be equal to the number" \ | |||
" of kernel_sizes." | |||
else: | |||
raise ValueError("The type of out_channels and kernel_sizes should be the same.") | |||
self.convs = nn.ModuleList([nn.Conv1d( | |||
in_channels=in_channels, | |||
out_channels=oc, | |||
kernel_size=ks, | |||
stride=1, | |||
padding=ks // 2, | |||
dilation=1, | |||
groups=1, | |||
bias=False) | |||
for oc, ks in zip(out_channels, kernel_sizes)]) | |||
else: | |||
raise Exception( | |||
'Incorrect kernel sizes: should be list, tuple or int') | |||
# activation function | |||
if activation == 'relu': | |||
self.activation = F.relu | |||
elif activation == 'sigmoid':
self.activation = torch.sigmoid
elif activation == 'tanh':
self.activation = torch.tanh
else: | |||
raise Exception( | |||
"Undefined activation function: choose from: relu, tanh, sigmoid") | |||
def forward(self, x, mask=None): | |||
r""" | |||
:param torch.FloatTensor x: batch_size x max_len x input_size, 一般是经过embedding后的值 | |||
:param mask: batch_size x max_len, pad的地方为0。不影响卷积运算,max-pool一定不会pool到pad为0的位置 | |||
:return: | |||
""" | |||
# [N,L,C] -> [N,C,L] | |||
x = torch.transpose(x, 1, 2) | |||
# convolution | |||
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] | |||
if mask is not None: | |||
mask = mask.unsqueeze(1) # B x 1 x L | |||
xs = [x.masked_fill(mask.eq(False), float('-inf')) for x in xs] | |||
# max-pooling | |||
xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) | |||
for i in xs] # [[N, C], ...] | |||
return torch.cat(xs, dim=-1) # [N, C] |
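# ---------------------------------------------------------------------------
# Minimal usage sketch (hedged; the channel counts and shapes are illustrative only).
if __name__ == '__main__':
    _conv = ConvMaxpool(in_channels=8, out_channels=[4, 4], kernel_sizes=[1, 3])
    _x = torch.randn(2, 5, 8)                        # batch_size x max_len x input_size
    _mask = torch.ones(2, 5, dtype=torch.bool)       # 0 would mark padding
    _out = _conv(_x, _mask)                          # batch_size x sum(out_channels) = 2 x 8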
@@ -0,0 +1,193 @@ | |||
r"""undocumented""" | |||
import torch.nn as nn | |||
import torch | |||
from torch.nn import LayerNorm | |||
import torch.nn.functional as F | |||
from typing import Union, Tuple | |||
from ....core.utils import seq_len_to_mask | |||
import math | |||
from .lstm import LSTM | |||
from ..attention import MultiHeadAttention | |||
from ....embeddings.torch import StaticEmbedding | |||
from ....embeddings.torch.utils import get_embeddings | |||
__all__ = ['Seq2SeqEncoder', 'TransformerSeq2SeqEncoder', 'LSTMSeq2SeqEncoder'] | |||
class Seq2SeqEncoder(nn.Module): | |||
""" | |||
所有Sequence2Sequence Encoder的基类。需要实现forward函数 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def forward(self, tokens, seq_len): | |||
""" | |||
:param torch.LongTensor tokens: bsz x max_len, encoder的输入 | |||
:param torch.LongTensor seq_len: bsz | |||
:return: | |||
""" | |||
raise NotImplementedError | |||
class TransformerSeq2SeqEncoderLayer(nn.Module): | |||
""" | |||
Self-Attention的Layer, | |||
:param int d_model: input和output的输出维度 | |||
:param int n_head: 多少个head,每个head的维度为d_model/n_head | |||
:param int dim_ff: FFN的维度大小 | |||
:param float dropout: Self-attention和FFN的dropout大小,0表示不drop | |||
""" | |||
def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, | |||
dropout: float = 0.1): | |||
super(TransformerSeq2SeqEncoderLayer, self).__init__() | |||
self.d_model = d_model | |||
self.n_head = n_head | |||
self.dim_ff = dim_ff | |||
self.dropout = dropout | |||
self.self_attn = MultiHeadAttention(d_model, n_head, dropout) | |||
self.attn_layer_norm = LayerNorm(d_model) | |||
self.ffn_layer_norm = LayerNorm(d_model) | |||
self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), | |||
nn.ReLU(), | |||
nn.Dropout(dropout), | |||
nn.Linear(self.dim_ff, self.d_model), | |||
nn.Dropout(dropout)) | |||
def forward(self, x, mask): | |||
""" | |||
:param x: batch x src_seq x d_model | |||
:param mask: batch x src_seq,为0的地方为padding | |||
:return: | |||
""" | |||
# attention | |||
residual = x | |||
x = self.attn_layer_norm(x) | |||
x, _ = self.self_attn(query=x, | |||
key=x, | |||
value=x, | |||
key_mask=mask) | |||
x = F.dropout(x, p=self.dropout, training=self.training) | |||
x = residual + x | |||
# ffn | |||
residual = x | |||
x = self.ffn_layer_norm(x) | |||
x = self.ffn(x) | |||
x = residual + x | |||
return x | |||
class TransformerSeq2SeqEncoder(Seq2SeqEncoder): | |||
""" | |||
基于Transformer的Encoder | |||
:param embed: encoder输入token的embedding | |||
:param nn.Module pos_embed: position embedding | |||
:param int num_layers: 多少层的encoder | |||
:param int d_model: 输入输出的维度 | |||
:param int n_head: 多少个head | |||
:param int dim_ff: FFN中间的维度大小 | |||
:param float dropout: Attention和FFN的dropout大小 | |||
""" | |||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed = None, | |||
num_layers = 6, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1): | |||
super(TransformerSeq2SeqEncoder, self).__init__() | |||
self.embed = get_embeddings(embed) | |||
self.embed_scale = math.sqrt(d_model) | |||
self.pos_embed = pos_embed | |||
self.num_layers = num_layers | |||
self.d_model = d_model | |||
self.n_head = n_head | |||
self.dim_ff = dim_ff | |||
self.dropout = dropout | |||
self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) | |||
self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout) | |||
for _ in range(num_layers)]) | |||
self.layer_norm = LayerNorm(d_model) | |||
def forward(self, tokens, seq_len): | |||
""" | |||
:param tokens: batch x max_len | |||
:param seq_len: [batch] | |||
:return: bsz x max_len x d_model, bsz x max_len(为0的地方为padding) | |||
""" | |||
x = self.embed(tokens) * self.embed_scale # batch, seq, dim | |||
batch_size, max_src_len, _ = x.size() | |||
device = x.device | |||
if self.pos_embed is not None: | |||
position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device) | |||
x += self.pos_embed(position) | |||
x = self.input_fc(x) | |||
x = F.dropout(x, p=self.dropout, training=self.training) | |||
encoder_mask = seq_len_to_mask(seq_len, max_len=max_src_len) | |||
encoder_mask = encoder_mask.to(device) | |||
for layer in self.layer_stacks: | |||
x = layer(x, encoder_mask) | |||
x = self.layer_norm(x) | |||
return x, encoder_mask | |||
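# Minimal usage sketch (hedged; the (30, 512) vocab/embed sizes and other shapes are illustrative):
#
#   >>> encoder = TransformerSeq2SeqEncoder(embed=(30, 512), num_layers=2, d_model=512, n_head=8)
#   >>> tokens = torch.randint(0, 30, (2, 7))        # bsz x max_len
#   >>> seq_len = torch.LongTensor([7, 5])
#   >>> x, encoder_mask = encoder(tokens, seq_len)   # 2 x 7 x 512 and 2 x 7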
class LSTMSeq2SeqEncoder(Seq2SeqEncoder): | |||
""" | |||
LSTM的Encoder | |||
:param embed: encoder的token embed | |||
:param int num_layers: 多少层 | |||
:param int hidden_size: LSTM隐藏层、输出的大小 | |||
:param float dropout: LSTM层之间的Dropout是多少 | |||
:param bool bidirectional: 是否使用双向 | |||
""" | |||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, | |||
hidden_size = 400, dropout = 0.3, bidirectional=True): | |||
super().__init__() | |||
self.embed = get_embeddings(embed) | |||
self.num_layers = num_layers | |||
self.dropout = dropout | |||
self.hidden_size = hidden_size | |||
self.bidirectional = bidirectional | |||
hidden_size = hidden_size//2 if bidirectional else hidden_size | |||
self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size, bidirectional=bidirectional,
batch_first=True, dropout=dropout if num_layers>1 else 0, num_layers=num_layers) | |||
def forward(self, tokens, seq_len): | |||
""" | |||
:param torch.LongTensor tokens: bsz x max_len | |||
:param torch.LongTensor seq_len: bsz | |||
:return: (output, (hidden, cell)), encoder_mask | |||
output: bsz x max_len x hidden_size, | |||
hidden,cell: batch_size x hidden_size, 最后一层的隐藏状态或cell状态 | |||
encoder_mask: bsz x max_len, 为0的地方是padding | |||
""" | |||
x = self.embed(tokens) | |||
device = x.device | |||
x, (final_hidden, final_cell) = self.lstm(x, seq_len) | |||
encoder_mask = seq_len_to_mask(seq_len).to(device) | |||
# x: batch,seq_len,dim; h/c: num_layers*2,batch,dim | |||
if self.bidirectional: | |||
final_hidden = self.concat_bidir(final_hidden)  # concatenate the bidirectional hidden states, used as the decoder's initial input
final_cell = self.concat_bidir(final_cell)
return (x, (final_hidden[-1], final_cell[-1])), encoder_mask  # the return is split into two parts to match the forward of Seq2SeqBaseModel
def concat_bidir(self, input): | |||
output = input.view(self.num_layers, 2, input.size(1), -1).transpose(1, 2) | |||
return output.reshape(self.num_layers, input.size(1), -1) |
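# ---------------------------------------------------------------------------
# Minimal usage sketch (hedged; the vocab/embedding sizes and shapes are illustrative only).
if __name__ == '__main__':
    _encoder = LSTMSeq2SeqEncoder(embed=nn.Embedding(30, 16), num_layers=2,
                                  hidden_size=20, bidirectional=True)
    _tokens = torch.randint(0, 30, (2, 7))           # bsz x max_len
    _seq_len = torch.LongTensor([7, 5])
    (_output, (_h, _c)), _mask = _encoder(_tokens, _seq_len)
    # _output: 2 x 7 x 20; _h/_c: 2 x 20 (last layer, both directions concatenated); _mask: 2 x 7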
@@ -0,0 +1,166 @@ | |||
r"""undocumented | |||
Star-Transformer 的encoder部分的 Pytorch 实现 | |||
""" | |||
__all__ = [ | |||
"StarTransformer" | |||
] | |||
import numpy as np | |||
import torch | |||
from torch import nn | |||
from torch.nn import functional as F | |||
class StarTransformer(nn.Module): | |||
r""" | |||
Star-Transformer 的encoder部分。 输入3d的文本输入, 返回相同长度的文本编码 | |||
paper: https://arxiv.org/abs/1902.09113 | |||
""" | |||
def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | |||
r""" | |||
:param int hidden_size: 输入维度的大小。同时也是输出维度的大小。 | |||
:param int num_layers: star-transformer的层数 | |||
:param int num_head: head的数量。 | |||
:param int head_dim: 每个head的维度大小。 | |||
:param float dropout: dropout 概率. Default: 0.1 | |||
:param int max_len: int or None, 如果为int,输入序列的最大长度, | |||
模型会为输入序列加上position embedding。 | |||
若为`None`,忽略加上position embedding的步骤. Default: `None` | |||
""" | |||
super(StarTransformer, self).__init__() | |||
self.iters = num_layers | |||
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size, eps=1e-6) for _ in range(self.iters)]) | |||
# self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1) | |||
self.emb_drop = nn.Dropout(dropout) | |||
self.ring_att = nn.ModuleList( | |||
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) | |||
for _ in range(self.iters)]) | |||
self.star_att = nn.ModuleList( | |||
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) | |||
for _ in range(self.iters)]) | |||
if max_len is not None: | |||
self.pos_emb = nn.Embedding(max_len, hidden_size) | |||
else: | |||
self.pos_emb = None | |||
def forward(self, data, mask): | |||
r""" | |||
:param FloatTensor data: [batch, length, hidden] 输入的序列 | |||
:param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, | |||
否则为 1 | |||
:return: [batch, length, hidden] 编码后的输出序列 | |||
[batch, hidden] 全局 relay 节点, 详见论文 | |||
""" | |||
def norm_func(f, x): | |||
# B, H, L, 1 | |||
return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) | |||
B, L, H = data.size() | |||
mask = (mask.eq(False)) # flip the mask for masked_fill_ | |||
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) | |||
embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 | |||
if self.pos_emb: | |||
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ | |||
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | |||
embs = embs + P | |||
embs = norm_func(self.emb_drop, embs) | |||
nodes = embs | |||
relay = embs.mean(2, keepdim=True) | |||
ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | |||
r_embs = embs.view(B, H, 1, L) | |||
for i in range(self.iters): | |||
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | |||
nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | |||
# nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax)) | |||
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | |||
nodes = nodes.masked_fill_(ex_mask, 0) | |||
nodes = nodes.view(B, H, L).permute(0, 2, 1) | |||
return nodes, relay.view(B, H) | |||
class _MSA1(nn.Module): | |||
def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): | |||
super(_MSA1, self).__init__() | |||
# Multi-head Self Attention Case 1, doing self-attention for small regions | |||
# Due to the GPU architecture, a Hadamard product followed by a summation is faster than a dot product when unfold_size is very small
self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) | |||
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | |||
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | |||
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | |||
self.drop = nn.Dropout(dropout) | |||
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | |||
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | |||
def forward(self, x, ax=None): | |||
# x: B, H, L, 1, ax : B, H, X, L append features | |||
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | |||
B, H, L, _ = x.shape | |||
q, k, v = self.WQ(x), self.WK(x), self.WV(x) # x: (B,H,L,1) | |||
if ax is not None: | |||
aL = ax.shape[2] | |||
ak = self.WK(ax).view(B, nhead, head_dim, aL, L) | |||
av = self.WV(ax).view(B, nhead, head_dim, aL, L) | |||
q = q.view(B, nhead, head_dim, 1, L) | |||
k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \ | |||
.view(B, nhead, head_dim, unfold_size, L) | |||
v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \ | |||
.view(B, nhead, head_dim, unfold_size, L) | |||
if ax is not None: | |||
k = torch.cat([k, ak], 3) | |||
v = torch.cat([v, av], 3) | |||
alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / np.sqrt(head_dim), 3)) # B N L 1 U | |||
att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1) | |||
ret = self.WO(att) | |||
return ret | |||
class _MSA2(nn.Module): | |||
def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): | |||
# Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value | |||
super(_MSA2, self).__init__() | |||
self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) | |||
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | |||
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | |||
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | |||
self.drop = nn.Dropout(dropout) | |||
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | |||
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | |||
def forward(self, x, y, mask=None): | |||
# x: B, H, 1, 1, 1 y: B H L 1 | |||
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | |||
B, H, L, _ = y.shape | |||
q, k, v = self.WQ(x), self.WK(y), self.WV(y) | |||
q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h | |||
k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L | |||
v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h | |||
pre_a = torch.matmul(q, k) / np.sqrt(head_dim) | |||
if mask is not None: | |||
pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf')) | |||
alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L | |||
att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1 | |||
return self.WO(att) |
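# ---------------------------------------------------------------------------
# Minimal usage sketch (hedged; the sizes below are illustrative only).
if __name__ == '__main__':
    _st = StarTransformer(hidden_size=20, num_layers=2, num_head=4, head_dim=5, max_len=50)
    _data = torch.randn(2, 7, 20)                    # batch x length x hidden
    _mask = torch.ones(2, 7, dtype=torch.uint8)      # 1 where there is content, 0 on padding
    _nodes, _relay = _st(_data, _mask)               # 2 x 7 x 20 and 2 x 20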
@@ -0,0 +1,43 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"TransformerEncoder" | |||
] | |||
from torch import nn | |||
from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer | |||
class TransformerEncoder(nn.Module): | |||
r""" | |||
transformer的encoder模块,不包含embedding层 | |||
""" | |||
def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1): | |||
""" | |||
:param int num_layers: 多少层Transformer | |||
:param int d_model: input和output的大小 | |||
:param int n_head: 多少个head | |||
:param int dim_ff: FFN中间hidden大小 | |||
:param float dropout: 多大概率drop attention和ffn中间的表示 | |||
""" | |||
super(TransformerEncoder, self).__init__() | |||
self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff, | |||
dropout = dropout) for _ in range(num_layers)]) | |||
self.norm = nn.LayerNorm(d_model, eps=1e-6) | |||
def forward(self, x, seq_mask=None): | |||
r""" | |||
:param x: [batch, seq_len, model_size] 输入序列 | |||
:param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. 为1的地方需要attend | |||
Default: ``None`` | |||
:return: [batch, seq_len, model_size] 输出序列 | |||
""" | |||
output = x | |||
if seq_mask is None: | |||
seq_mask = x.new_ones(x.size(0), x.size(1)).bool() | |||
for layer in self.layers: | |||
output = layer(output, seq_mask) | |||
return self.norm(output) |
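# ---------------------------------------------------------------------------
# Minimal usage sketch (hedged; the sizes below are illustrative only).
if __name__ == '__main__':
    import torch
    _encoder = TransformerEncoder(num_layers=2, d_model=64, n_head=4, dim_ff=128, dropout=0.1)
    _x = torch.randn(2, 9, 64)                       # batch x seq_len x model_size
    _mask = torch.ones(2, 9, dtype=torch.bool)       # 1 where attention is allowed
    _out = _encoder(_x, _mask)                       # 2 x 9 x 64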
@@ -0,0 +1,303 @@ | |||
r"""undocumented | |||
Variational RNN 及相关模型的 fastNLP实现,相关论文参考: | |||
`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
""" | |||
__all__ = [ | |||
"VarRNN", | |||
"VarLSTM", | |||
"VarGRU" | |||
] | |||
import torch | |||
import torch.nn as nn | |||
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | |||
try: | |||
from torch import flip | |||
except ImportError: | |||
def flip(x, dims): | |||
indices = [slice(None)] * x.dim() | |||
for dim in dims: | |||
indices[dim] = torch.arange( | |||
x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) | |||
return x[tuple(indices)] | |||
class VarRnnCellWrapper(nn.Module): | |||
r""" | |||
Wrapper around a standard RNN cell that adds support for variational dropout
""" | |||
def __init__(self, cell, hidden_size, input_p, hidden_p): | |||
super(VarRnnCellWrapper, self).__init__() | |||
self.cell = cell | |||
self.hidden_size = hidden_size | |||
self.input_p = input_p | |||
self.hidden_p = hidden_p | |||
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): | |||
r""" | |||
:param PackedSequence input_x: [seq_len, batch_size, input_size] | |||
:param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | |||
for other RNN, h_0, [batch_size, hidden_size] | |||
:param mask_x: [batch_size, input_size] dropout mask for input | |||
:param mask_h: [batch_size, hidden_size] dropout mask for hidden | |||
:return PackedSequence output: [seq_len, batch_size, hidden_size]
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | |||
for other RNN, h_n, [batch_size, hidden_size] | |||
""" | |||
def get_hi(hi, h0, size): | |||
h0_size = size - hi.size(0) | |||
if h0_size > 0: | |||
return torch.cat([hi, h0[:h0_size]], dim=0) | |||
return hi[:size] | |||
is_lstm = isinstance(hidden, tuple) | |||
input, batch_sizes = input_x.data, input_x.batch_sizes | |||
output = [] | |||
cell = self.cell | |||
if is_reversed: | |||
batch_iter = flip(batch_sizes, [0]) | |||
idx = input.size(0) | |||
else: | |||
batch_iter = batch_sizes | |||
idx = 0 | |||
if is_lstm: | |||
hn = (hidden[0].clone(), hidden[1].clone()) | |||
else: | |||
hn = hidden.clone() | |||
hi = hidden | |||
for size in batch_iter: | |||
if is_reversed: | |||
input_i = input[idx - size: idx] * mask_x[:size] | |||
idx -= size | |||
else: | |||
input_i = input[idx: idx + size] * mask_x[:size] | |||
idx += size | |||
mask_hi = mask_h[:size] | |||
if is_lstm: | |||
hx, cx = hi | |||
hi = (get_hi(hx, hidden[0], size) * | |||
mask_hi, get_hi(cx, hidden[1], size)) | |||
hi = cell(input_i, hi) | |||
hn[0][:size] = hi[0] | |||
hn[1][:size] = hi[1] | |||
output.append(hi[0]) | |||
else: | |||
hi = get_hi(hi, hidden, size) * mask_hi | |||
hi = cell(input_i, hi) | |||
hn[:size] = hi | |||
output.append(hi) | |||
if is_reversed: | |||
output = list(reversed(output)) | |||
output = torch.cat(output, dim=0) | |||
return PackedSequence(output, batch_sizes), hn | |||
class VarRNNBase(nn.Module): | |||
r""" | |||
Variational Dropout RNN 实现. | |||
论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) | |||
https://arxiv.org/abs/1512.05287`. | |||
""" | |||
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | |||
bias=True, batch_first=False, | |||
input_dropout=0, hidden_dropout=0, bidirectional=False): | |||
r""" | |||
:param mode: rnn 模式, (lstm or not) | |||
:param Cell: rnn cell 类型, (lstm, gru, etc) | |||
:param input_size: 输入 `x` 的特征维度 | |||
:param hidden_size: 隐状态 `h` 的特征维度 | |||
:param num_layers: rnn的层数. Default: 1 | |||
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||
(batch, seq, feature). Default: ``False`` | |||
:param input_dropout: 对输入的dropout概率. Default: 0 | |||
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 | |||
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` | |||
""" | |||
super(VarRNNBase, self).__init__() | |||
self.mode = mode | |||
self.input_size = input_size | |||
self.hidden_size = hidden_size | |||
self.num_layers = num_layers | |||
self.bias = bias | |||
self.batch_first = batch_first | |||
self.input_dropout = input_dropout | |||
self.hidden_dropout = hidden_dropout | |||
self.bidirectional = bidirectional | |||
self.num_directions = 2 if bidirectional else 1 | |||
self._all_cells = nn.ModuleList() | |||
for layer in range(self.num_layers): | |||
for direction in range(self.num_directions): | |||
input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions | |||
cell = Cell(input_size, self.hidden_size, bias) | |||
self._all_cells.append(VarRnnCellWrapper( | |||
cell, self.hidden_size, input_dropout, hidden_dropout)) | |||
self.is_lstm = (self.mode == "LSTM") | |||
def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h): | |||
is_lstm = self.is_lstm | |||
idx = self.num_directions * n_layer + n_direction | |||
cell = self._all_cells[idx] | |||
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | |||
output_x, hidden_x = cell( | |||
input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) | |||
return output_x, hidden_x | |||
def forward(self, x, hx=None): | |||
r""" | |||
:param x: [batch, seq_len, input_size] the input sequence
:param hx: [batch, hidden_size] initial hidden state; if ``None``, initialized to zeros. Default: ``None``
:return (output, ht): [batch, seq_len, hidden_size*num_direction] the output sequence
and [batch, hidden_size*num_direction] the hidden state at the last time step
""" | |||
is_lstm = self.is_lstm | |||
is_packed = isinstance(x, PackedSequence) | |||
if not is_packed: | |||
seq_len = x.size(1) if self.batch_first else x.size(0) | |||
max_batch_size = x.size(0) if self.batch_first else x.size(1) | |||
seq_lens = torch.LongTensor( | |||
[seq_len for _ in range(max_batch_size)]) | |||
x = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first) | |||
else: | |||
max_batch_size = int(x.batch_sizes[0]) | |||
x, batch_sizes = x.data, x.batch_sizes | |||
if hx is None: | |||
hx = x.new_zeros(self.num_layers * self.num_directions, | |||
max_batch_size, self.hidden_size, requires_grad=True) | |||
if is_lstm: | |||
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) | |||
mask_x = x.new_ones((max_batch_size, self.input_size)) | |||
mask_out = x.new_ones( | |||
(max_batch_size, self.hidden_size * self.num_directions)) | |||
mask_h_ones = x.new_ones((max_batch_size, self.hidden_size)) | |||
nn.functional.dropout(mask_x, p=self.input_dropout, | |||
training=self.training, inplace=True) | |||
nn.functional.dropout(mask_out, p=self.hidden_dropout, | |||
training=self.training, inplace=True) | |||
hidden = x.new_zeros( | |||
(self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | |||
if is_lstm: | |||
cellstate = x.new_zeros( | |||
(self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | |||
for layer in range(self.num_layers): | |||
output_list = [] | |||
input_seq = PackedSequence(x, batch_sizes) | |||
mask_h = nn.functional.dropout( | |||
mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | |||
for direction in range(self.num_directions): | |||
output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx, | |||
mask_x if layer == 0 else mask_out, mask_h) | |||
output_list.append(output_x.data) | |||
idx = self.num_directions * layer + direction | |||
if is_lstm: | |||
hidden[idx] = hidden_x[0] | |||
cellstate[idx] = hidden_x[1] | |||
else: | |||
hidden[idx] = hidden_x | |||
x = torch.cat(output_list, dim=-1) | |||
if is_lstm: | |||
hidden = (hidden, cellstate) | |||
if is_packed: | |||
output = PackedSequence(x, batch_sizes) | |||
else: | |||
x = PackedSequence(x, batch_sizes) | |||
output, _ = pad_packed_sequence(x, batch_first=self.batch_first) | |||
return output, hidden | |||
class VarLSTM(VarRNNBase): | |||
r""" | |||
Variational Dropout LSTM. | |||
Reference: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_
""" | |||
def __init__(self, *args, **kwargs): | |||
r""" | |||
:param input_size: feature dimension of the input `x`
:param hidden_size: feature dimension of the hidden state `h`
:param num_layers: number of LSTM layers. Default: 1
:param bias: if ``False``, the layers do not use bias weights. Default: ``True``
:param batch_first: if ``True``, input and output ``Tensor`` are of shape
(batch, seq, feature). Default: ``False``
:param input_dropout: dropout probability applied to the input. Default: 0
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0
:param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False``
""" | |||
super(VarLSTM, self).__init__( | |||
mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) | |||
def forward(self, x, hx=None): | |||
return super(VarLSTM, self).forward(x, hx) | |||
class VarRNN(VarRNNBase): | |||
r""" | |||
Variational Dropout RNN. | |||
Reference: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_
""" | |||
def __init__(self, *args, **kwargs): | |||
r""" | |||
:param input_size: feature dimension of the input `x`
:param hidden_size: feature dimension of the hidden state `h`
:param num_layers: number of RNN layers. Default: 1
:param bias: if ``False``, the layers do not use bias weights. Default: ``True``
:param batch_first: if ``True``, input and output ``Tensor`` are of shape
(batch, seq, feature). Default: ``False``
:param input_dropout: dropout probability applied to the input. Default: 0
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
""" | |||
super(VarRNN, self).__init__( | |||
mode="RNN", Cell=nn.RNNCell, *args, **kwargs) | |||
def forward(self, x, hx=None): | |||
return super(VarRNN, self).forward(x, hx) | |||
class VarGRU(VarRNNBase): | |||
r""" | |||
Variational Dropout GRU. | |||
Reference: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_
""" | |||
def __init__(self, *args, **kwargs): | |||
r""" | |||
:param input_size: feature dimension of the input `x`
:param hidden_size: feature dimension of the hidden state `h`
:param num_layers: number of GRU layers. Default: 1
:param bias: if ``False``, the layers do not use bias weights. Default: ``True``
:param batch_first: if ``True``, input and output ``Tensor`` are of shape
(batch, seq, feature). Default: ``False``
:param input_dropout: dropout probability applied to the input. Default: 0
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0
:param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False``
""" | |||
super(VarGRU, self).__init__( | |||
mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | |||
def forward(self, x, hx=None): | |||
return super(VarGRU, self).forward(x, hx) |
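A minimal usage sketch for the variational RNNs above; the import path is an assumption based on this PR's module layout, and the shapes follow the ``forward`` contract shown in the hunk (the same dropout masks are reused at every time step, unlike ``nn.LSTM``'s per-step dropout):

import torch
from fastNLP.modules.torch.encoder.variational_rnn import VarLSTM  # assumed path

# batch_first=True -> input is (batch, seq_len, input_size)
lstm = VarLSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True,
               input_dropout=0.3, hidden_dropout=0.3, bidirectional=True)
x = torch.randn(4, 5, 8)
output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([4, 5, 32])  -> hidden_size * num_directions
print(h_n.shape)     # torch.Size([4, 4, 16])  -> (num_layers * num_directions, batch, hidden_size)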
@@ -0,0 +1,6 @@ | |||
__all__ = [ | |||
'SequenceGenerator' | |||
] | |||
from .seq2seq_generator import SequenceGenerator |
@@ -0,0 +1,536 @@ | |||
r""" | |||
""" | |||
__all__ = [ | |||
'SequenceGenerator' | |||
] | |||
import torch | |||
from torch import nn | |||
import torch.nn.functional as F | |||
from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State | |||
from functools import partial | |||
def _get_model_device(model): | |||
r""" | |||
Given an nn.Module, return the device its parameters live on.
:param model: nn.Module
:return: torch.device or None; ``None`` means the model has no parameters.
""" | |||
assert isinstance(model, nn.Module) | |||
parameters = list(model.parameters()) | |||
if len(parameters) == 0: | |||
return None | |||
else: | |||
return parameters[0].device | |||
class SequenceGenerator: | |||
""" | |||
Decode sentences from a given Seq2SeqDecoder. The decoder object must provide a decode() function whose first argument
is all tokens decoded so far and whose second argument is the state. SequenceGenerator never modifies the state itself.
""" | |||
def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1, | |||
do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, | |||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | |||
""" | |||
:param Seq2SeqDecoder decoder: the decoder object
:param int max_length: maximum length of generated sentences; each sentence is decoded for at most max_length + max_len_a*src_len steps
:param float max_len_a: each sentence is decoded for at most max_length + max_len_a*src_len steps; if non-zero, the State must contain encoder_mask
:param int num_beams: beam size for beam search
:param bool do_sample: whether to generate by sampling
:param float temperature: only meaningful when do_sample is True
:param int top_k: sample only from the top_k tokens
:param float top_p: sample only from tokens within cumulative probability top_p (nucleus sampling)
:param int,None bos_token_id: token id that starts a sentence
:param int,None eos_token_id: token id that ends a sentence
:param float repetition_penalty: how strongly repeated tokens are penalized
:param float length_penalty: length penalty; values below 1 favour longer sentences, values above 1 favour shorter ones
:param int pad_token_id: once a sentence has finished, subsequent positions are filled with pad_token_id
""" | |||
if do_sample: | |||
self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a, | |||
num_beams=num_beams, | |||
temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, | |||
eos_token_id=eos_token_id, repetition_penalty=repetition_penalty, | |||
length_penalty=length_penalty, pad_token_id=pad_token_id) | |||
else: | |||
self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a, | |||
num_beams=num_beams, | |||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, | |||
repetition_penalty=repetition_penalty, | |||
length_penalty=length_penalty, pad_token_id=pad_token_id) | |||
self.do_sample = do_sample | |||
self.max_length = max_length | |||
self.num_beams = num_beams | |||
self.temperature = temperature | |||
self.top_k = top_k | |||
self.top_p = top_p | |||
self.bos_token_id = bos_token_id | |||
self.eos_token_id = eos_token_id | |||
self.repetition_penalty = repetition_penalty | |||
self.length_penalty = length_penalty | |||
self.decoder = decoder | |||
@torch.no_grad() | |||
def generate(self, state, tokens=None): | |||
""" | |||
:param State state: the State produced by the encoder, matching the decoder in use
:param torch.LongTensor,None tokens: batch_size x length, the tokens to start generation from. If None, bos_token
is used as the starting token.
:return: bsz x max_length' generated token sequence. If eos_token_id is not None, every sequence ends with eos_token_id.
""" | |||
return self.generate_func(tokens=tokens, state=state) | |||
@torch.no_grad() | |||
def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1, | |||
bos_token_id=None, eos_token_id=None, pad_token_id=0, | |||
repetition_penalty=1, length_penalty=1.0): | |||
""" | |||
Greedily decode sentences.
:param Decoder decoder: the decoder object
:param torch.LongTensor tokens: batch_size x len, input tokens for decoding; if None, generation starts from bos_token_id
:param State state: should contain the encoder outputs.
:param int max_length: maximum length of generated sentences; each sentence is decoded for at most max_length + max_len_a*src_len steps
:param float max_len_a: each sentence is decoded for at most max_length + max_len_a*src_len steps; if non-zero, the State must contain encoder_mask
:param int num_beams: beam size used for decoding.
:param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
:param int eos_token_id: end token; if None, decoding always runs up to max_length.
:param int pad_token_id: pad token id
:param float repetition_penalty: how strongly repeated tokens are penalized.
:param float length_penalty: length penalty applied to every token except eos.
:return: | |||
""" | |||
if num_beams == 1: | |||
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, | |||
temperature=1, top_k=50, top_p=1, | |||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | |||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
pad_token_id=pad_token_id) | |||
else: | |||
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, | |||
num_beams=num_beams, temperature=1, top_k=50, top_p=1, | |||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | |||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
pad_token_id=pad_token_id) | |||
return token_ids | |||
@torch.no_grad() | |||
def sample_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1, temperature=1.0, top_k=50, | |||
top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0, | |||
length_penalty=1.0): | |||
""" | |||
Generate sentences by sampling.
:param Decoder decoder: the decoder object
:param torch.LongTensor tokens: batch_size x len, input tokens for decoding; if None, generation starts from bos_token_id
:param State state: should contain the encoder outputs.
:param int max_length: maximum length of generated sentences; each sentence is decoded for at most max_length + max_len_a*src_len steps
:param float max_len_a: each sentence is decoded for at most max_length + max_len_a*src_len steps; if non-zero, the State must contain encoder_mask
:param int num_beams: beam size used for decoding.
:param float temperature: sampling temperature
:param int top_k: sample only from the top_k tokens
:param float top_p: a value in (0, 1]; nucleus sampling threshold.
:param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
:param int eos_token_id: end token; if None, decoding always runs up to max_length.
:param int pad_token_id: pad token id
:param float repetition_penalty: how strongly repeated tokens are penalized.
:param float length_penalty: length penalty applied to every token except eos.
:return: | |||
""" | |||
# each position is generated by sampling
if num_beams == 1: | |||
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, | |||
temperature=temperature, top_k=top_k, top_p=top_p, | |||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | |||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
pad_token_id=pad_token_id) | |||
else: | |||
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, | |||
num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p, | |||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | |||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
pad_token_id=pad_token_id) | |||
return token_ids | |||
def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, max_len_a=0.0, temperature=1.0, top_k=50, | |||
top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | |||
repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0): | |||
device = _get_model_device(decoder) | |||
if tokens is None: | |||
if bos_token_id is None: | |||
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | |||
batch_size = state.num_samples | |||
if batch_size is None: | |||
raise RuntimeError("Cannot infer the number of samples from `state`.") | |||
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | |||
batch_size = tokens.size(0) | |||
if state.num_samples: | |||
assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." | |||
if eos_token_id is None: | |||
_eos_token_id = -1 | |||
else: | |||
_eos_token_id = eos_token_id | |||
scores = decoder.decode(tokens=tokens, state=state)  # mainly to update the state
if _eos_token_id!=-1:  # prevent ending at the very first position
scores[:, _eos_token_id] = -1e12 | |||
next_tokens = scores.argmax(dim=-1, keepdim=True) | |||
token_ids = torch.cat([tokens, next_tokens], dim=1) | |||
cur_len = token_ids.size(1) | |||
dones = token_ids.new_zeros(batch_size).eq(1) | |||
# tokens = tokens[:, -1:] | |||
if max_len_a!=0: | |||
# (bsz x num_beams, ) | |||
if state.encoder_mask is not None: | |||
max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length | |||
else: | |||
max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long) | |||
real_max_length = max_lengths.max().item() | |||
else: | |||
real_max_length = max_length | |||
if state.encoder_mask is not None: | |||
max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length | |||
else: | |||
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) | |||
while cur_len < real_max_length: | |||
scores = decoder.decode(tokens=token_ids, state=state) # batch_size x vocab_size | |||
if repetition_penalty != 1.0: | |||
token_scores = scores.gather(dim=1, index=token_ids) | |||
lt_zero_mask = token_scores.lt(0).float() | |||
ge_zero_mask = lt_zero_mask.eq(0).float() | |||
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores | |||
scores.scatter_(dim=1, index=token_ids, src=token_scores) | |||
if eos_token_id is not None and length_penalty != 1.0: | |||
token_scores = scores / cur_len ** length_penalty # batch_size x vocab_size | |||
eos_mask = scores.new_ones(scores.size(1)) | |||
eos_mask[eos_token_id] = 0 | |||
eos_mask = eos_mask.unsqueeze(0).eq(1) | |||
scores = scores.masked_scatter(eos_mask, token_scores) # 也即除了eos,其他词的分数经过了放大/缩小 | |||
if do_sample: | |||
if temperature > 0 and temperature != 1: | |||
scores = scores / temperature | |||
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2) | |||
# add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
probs = F.softmax(scores, dim=-1) + 1e-12 | |||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size | |||
else: | |||
next_tokens = torch.argmax(scores, dim=-1) # batch_size | |||
# if a sequence has reached its target length, fill the next position with eos directly
if _eos_token_id!=-1:
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id)
next_tokens = next_tokens.masked_fill(dones, pad_token_id)  # pad samples that have already finished
tokens = next_tokens.unsqueeze(1) | |||
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len | |||
end_mask = next_tokens.eq(_eos_token_id) | |||
dones = dones.__or__(end_mask) | |||
cur_len += 1 | |||
if dones.min() == 1: | |||
break | |||
# if eos_token_id is not None: | |||
# tokens.scatter(index=max_lengths[:, None], dim=1, value=eos_token_id) # 将最大长度位置设置为eos | |||
# if cur_len == max_length: | |||
# token_ids[:, -1].masked_fill_(~dones, eos_token_id) # 若到最长长度仍未到EOS,则强制将最后一个词替换成eos | |||
return token_ids | |||
def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=4, temperature=1.0, | |||
top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | |||
repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor: | |||
# run beam search
device = _get_model_device(decoder) | |||
if tokens is None: | |||
if bos_token_id is None: | |||
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | |||
batch_size = state.num_samples | |||
if batch_size is None: | |||
raise RuntimeError("Cannot infer the number of samples from `state`.") | |||
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | |||
batch_size = tokens.size(0) | |||
if state.num_samples: | |||
assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." | |||
if eos_token_id is None: | |||
_eos_token_id = -1 | |||
else: | |||
_eos_token_id = eos_token_id | |||
scores = decoder.decode(tokens=tokens, state=state)  # the whole sentence decoded so far is passed in here
if _eos_token_id!=-1:  # prevent ending at the very first position
scores[:, _eos_token_id] = -1e12 | |||
vocab_size = scores.size(1) | |||
assert vocab_size >= num_beams, "num_beams should not be larger than the vocabulary size."
if do_sample: | |||
probs = F.softmax(scores, dim=-1) + 1e-12 | |||
next_tokens = torch.multinomial(probs, num_samples=num_beams) # (batch_size, num_beams) | |||
logits = probs.log() | |||
next_scores = logits.gather(dim=1, index=next_tokens) # (batch_size, num_beams) | |||
else: | |||
scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size) | |||
# obtain (batch_size, num_beams), (batch_size, num_beams)
next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) | |||
# reorder according to the indices
indices = torch.arange(batch_size, dtype=torch.long).to(device) | |||
indices = indices.repeat_interleave(num_beams) | |||
state.reorder_state(indices) | |||
tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length | |||
# record the generated tokens (batch_size', cur_len)
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1) | |||
dones = [False] * batch_size | |||
beam_scores = next_scores.view(-1) # batch_size * num_beams | |||
# track the length of the tokens generated so far
cur_len = token_ids.size(1) | |||
if max_len_a!=0: | |||
# (bsz x num_beams, ) | |||
if state.encoder_mask is not None: | |||
max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length | |||
else: | |||
max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long) | |||
real_max_length = max_lengths.max().item() | |||
else: | |||
real_max_length = max_length | |||
if state.encoder_mask is not None: | |||
max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length | |||
else: | |||
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) | |||
hypos = [ | |||
BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size) | |||
] | |||
# 0, num_beams, 2*num_beams, ... | |||
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) | |||
while cur_len < real_max_length: | |||
scores = decoder.decode(token_ids, state) # (bsz x num_beams, vocab_size) | |||
if repetition_penalty != 1.0: | |||
token_scores = scores.gather(dim=1, index=token_ids) | |||
lt_zero_mask = token_scores.lt(0).float() | |||
ge_zero_mask = lt_zero_mask.eq(0).float() | |||
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores | |||
scores.scatter_(dim=1, index=token_ids, src=token_scores) | |||
if _eos_token_id!=-1: | |||
max_len_eos_mask = max_lengths.eq(cur_len+1) | |||
eos_scores = scores[:, _eos_token_id] | |||
# if the maximum length has been reached, boost the eos score
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e32, eos_scores) | |||
if do_sample: | |||
if temperature > 0 and temperature != 1: | |||
scores = scores / temperature | |||
# recall one extra token in case eos is drawn
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1) | |||
# add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
probs = F.softmax(scores, dim=-1) + 1e-12 | |||
# guarantee at least one non-eos candidate
_tokens = torch.multinomial(probs, num_samples=num_beams + 1) # batch_size' x (num_beams+1) | |||
logits = probs.log() | |||
# avoid all candidates coming from the same beam, and account for eos being drawn
_scores = logits.gather(dim=1, index=_tokens) # batch_size' x (num_beams+1) | |||
_scores = _scores + beam_scores[:, None] # batch_size' x (num_beams+1) | |||
# from these, pick the top 2*num_beams
_scores = _scores.view(batch_size, num_beams * (num_beams + 1)) | |||
next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True) | |||
_tokens = _tokens.view(batch_size, num_beams * (num_beams + 1)) | |||
next_tokens = _tokens.gather(dim=1, index=ids) # (batch_size, 2*num_beams) | |||
from_which_beam = ids // (num_beams + 1) # (batch_size, 2*num_beams) | |||
else: | |||
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) | |||
_scores = scores + beam_scores[:, None] # (batch_size * num_beams, vocab_size) | |||
_scores = _scores.view(batch_size, -1) # (batch_size, num_beams*vocab_size) | |||
next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) # (bsz, 2*num_beams) | |||
from_which_beam = ids // vocab_size # (batch_size, 2*num_beams) | |||
next_tokens = ids % vocab_size # (batch_size, 2*num_beams) | |||
# next, assemble the candidates for the following step
# and decide which ones to keep
# next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True) | |||
# next_tokens = next_tokens.gather(dim=1, index=sorted_inds) | |||
# from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds) | |||
not_eos_mask = next_tokens.ne(_eos_token_id)  # positions that are 1 are not eos
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)  # positions that are 1 should be kept
keep_mask = not_eos_mask.__and__(keep_mask)  # positions that are 1 go on to the next search step
_next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1) | |||
_from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams) # 上面的token是来自哪个beam | |||
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams) | |||
beam_scores = _next_scores.view(-1) | |||
flag = True | |||
if cur_len+1 == real_max_length: | |||
eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0) | |||
eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size) # 表示的是indice | |||
eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1) # 表示的是从哪个beam获取得到的 | |||
else: | |||
# add sequences within the top num_beams of each batch to the finished set; positions that are 1 should finish
effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams | |||
if effective_eos_mask.sum().gt(0): | |||
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True) | |||
# from_which_beam is (batch_size, 2*num_beams), hence the factor 2*num_beams
eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind | |||
eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx] # 获取真实的从哪个beam获取的eos | |||
else: | |||
flag = False | |||
if flag: | |||
_token_ids = torch.cat([token_ids, _next_tokens], dim=-1) | |||
for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), | |||
eos_beam_idx.tolist()): | |||
if not dones[batch_idx]: | |||
score = next_scores[batch_idx, beam_ind].item() | |||
# an eos will be appended at the end later
if _eos_token_id!=-1: | |||
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score) | |||
else: | |||
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score) | |||
# update the state and regroup token_ids
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维 | |||
state.reorder_state(reorder_inds) | |||
# rebuild token_ids accordingly
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1) | |||
for batch_idx in range(batch_size): | |||
dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \ | |||
max_lengths[batch_idx*num_beams]==cur_len+1 | |||
cur_len += 1 | |||
if all(dones): | |||
break | |||
# select the best hypotheses | |||
tgt_len = token_ids.new_zeros(batch_size) | |||
best = [] | |||
for i, hypotheses in enumerate(hypos): | |||
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] | |||
# append the eos back for hypotheses that were stored without it above
if _eos_token_id!=-1: | |||
best_hyp = torch.cat([best_hyp, best_hyp.new_ones(1)*_eos_token_id]) | |||
tgt_len[i] = len(best_hyp) | |||
best.append(best_hyp) | |||
# generate target batch | |||
decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id) | |||
for i, hypo in enumerate(best): | |||
decoded[i, :tgt_len[i]] = hypo | |||
return decoded | |||
class BeamHypotheses(object): | |||
def __init__(self, num_beams, max_length, length_penalty, early_stopping): | |||
""" | |||
Initialize n-best list of hypotheses. | |||
""" | |||
self.max_length = max_length - 1 # ignoring bos_token | |||
self.length_penalty = length_penalty | |||
self.early_stopping = early_stopping | |||
self.num_beams = num_beams | |||
self.hyp = [] | |||
self.worst_score = 1e9 | |||
def __len__(self): | |||
""" | |||
Number of hypotheses in the list. | |||
""" | |||
return len(self.hyp) | |||
def add(self, hyp, sum_logprobs): | |||
""" | |||
Add a new hypothesis to the list. | |||
""" | |||
score = sum_logprobs / len(hyp) ** self.length_penalty | |||
if len(self) < self.num_beams or score > self.worst_score: | |||
self.hyp.append((score, hyp)) | |||
if len(self) > self.num_beams: | |||
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)]) | |||
del self.hyp[sorted_scores[0][1]] | |||
self.worst_score = sorted_scores[1][0] | |||
else: | |||
self.worst_score = min(score, self.worst_score) | |||
def is_done(self, best_sum_logprobs): | |||
""" | |||
If there are enough hypotheses and that none of the hypotheses being generated | |||
can become better than the worst one in the heap, then we are done with this sentence. | |||
""" | |||
if len(self) < self.num_beams: | |||
return False | |||
elif self.early_stopping: | |||
return True | |||
else: | |||
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty | |||
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1): | |||
""" | |||
Set the logits that fail the top_k / top_p constraints to filter_value.
:param torch.Tensor logits: bsz x vocab_size
:param int top_k: if greater than 0, keep only the top_k highest-scoring tokens and set the rest to filter_value
:param float top_p: nucleus filtering threshold following (http://arxiv.org/abs/1904.09751)
:param float filter_value: value assigned to the filtered-out positions
:param int min_tokens_to_keep: each sample keeps at least this many tokens with unfiltered probability
:return: | |||
""" | |||
if top_k > 0: | |||
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check | |||
# Remove all tokens with a probability less than the last token of the top-k | |||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] | |||
logits[indices_to_remove] = filter_value | |||
if top_p < 1.0: | |||
sorted_logits, sorted_indices = torch.sort(logits, descending=True) | |||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | |||
# Remove tokens with cumulative probability above the threshold (token with 0 are kept) | |||
sorted_indices_to_remove = cumulative_probs > top_p | |||
if min_tokens_to_keep > 1: | |||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) | |||
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 | |||
# Shift the indices to the right to keep also the first token above the threshold | |||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | |||
sorted_indices_to_remove[..., 0] = 0 | |||
# scatter sorted tensors to original indexing | |||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) | |||
logits[indices_to_remove] = filter_value | |||
return logits |
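A small sanity-check sketch for ``top_k_top_p_filtering`` (with the function above in scope; the tensor values are illustrative, and a clone is passed because the function modifies its input in place):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=1.0)
# only the two largest logits survive; the rest are set to filter_value (-inf),
# so a subsequent softmax puts all probability mass on those two tokens
print(filtered)  # tensor([[2., 1., -inf, -inf, -inf]])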
@@ -0,0 +1,142 @@ | |||
""" | |||
This module makes it convenient to smoke-test a model.
If your model does text classification, sequence labeling or natural language inference (NLI), it can be tested with this module directly.
For other kinds of models, you can prepare fake data yourself and supply the metric for testing.
The tests here only check that a model can be trained and evaluated with fastNLP; they do not measure actual model performance.
Example::
# import the ALL-CAPS module-level constants...
from model_runner import *
# test a text classification model
init_emb = (VOCAB_SIZE, 50)
model = SomeModel(init_emb, num_cls=NUM_CLS)
RUNNER.run_model_with_task(TEXT_CLS, model)
# a sequence labeling model
RUNNER.run_model_with_task(POS_TAGGING, model)
# an NLI model
RUNNER.run_model_with_task(NLI, model)
# a custom model
RUNNER.run_model(model, data=get_mydata(), metrics=Mymetric())
""" | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
from torch import optim | |||
from fastNLP import Trainer, Evaluator, DataSet, Callback | |||
from fastNLP import Accuracy | |||
from random import randrange | |||
from fastNLP import TorchDataLoader | |||
VOCAB_SIZE = 100 | |||
NUM_CLS = 100 | |||
MAX_LEN = 10 | |||
N_SAMPLES = 100 | |||
N_EPOCHS = 1 | |||
BATCH_SIZE = 5 | |||
TEXT_CLS = 'text_cls' | |||
POS_TAGGING = 'pos_tagging' | |||
NLI = 'nli' | |||
class ModelRunner(): | |||
class Checker(Callback): | |||
def on_backward_begin(self, trainer, outputs): | |||
assert outputs['loss'].detach().cpu().isfinite().all()
def gen_seq(self, length, vocab_size): | |||
"""generate fake sequence indexes with given length""" | |||
# reserve 0 for padding | |||
return [randrange(1, vocab_size) for _ in range(length)] | |||
def gen_var_seq(self, max_len, vocab_size): | |||
"""generate fake sequence indexes in variant length""" | |||
length = randrange(3, max_len) # at least 3 words in a seq | |||
return self.gen_seq(length, vocab_size) | |||
def prepare_text_classification_data(self): | |||
index = 'index' | |||
ds = DataSet({index: list(range(N_SAMPLES))}) | |||
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
field_name=index, new_field_name='words') | |||
ds.apply_field(lambda x: randrange(NUM_CLS), | |||
field_name=index, new_field_name='target') | |||
ds.apply_field(len, 'words', 'seq_len') | |||
dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) | |||
return dl | |||
def prepare_pos_tagging_data(self): | |||
index = 'index' | |||
ds = DataSet({index: list(range(N_SAMPLES))}) | |||
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
field_name=index, new_field_name='words') | |||
ds.apply_field(lambda x: self.gen_seq(len(x), NUM_CLS), | |||
field_name='words', new_field_name='target') | |||
ds.apply_field(len, 'words', 'seq_len') | |||
dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) | |||
return dl | |||
def prepare_nli_data(self): | |||
index = 'index' | |||
ds = DataSet({index: list(range(N_SAMPLES))}) | |||
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
field_name=index, new_field_name='words1') | |||
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
field_name=index, new_field_name='words2') | |||
ds.apply_field(lambda x: randrange(NUM_CLS), | |||
field_name=index, new_field_name='target') | |||
ds.apply_field(len, 'words1', 'seq_len1') | |||
ds.apply_field(len, 'words2', 'seq_len2') | |||
dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) | |||
return dl | |||
def run_text_classification(self, model, data=None): | |||
if data is None: | |||
data = self.prepare_text_classification_data() | |||
metric = Accuracy() | |||
self.run_model(model, data, metric) | |||
def run_pos_tagging(self, model, data=None): | |||
if data is None: | |||
data = self.prepare_pos_tagging_data() | |||
metric = Accuracy() | |||
self.run_model(model, data, metric) | |||
def run_nli(self, model, data=None): | |||
if data is None: | |||
data = self.prepare_nli_data() | |||
metric = Accuracy() | |||
self.run_model(model, data, metric) | |||
def run_model(self, model, data, metrics): | |||
"""run a model, test if it can run with fastNLP""" | |||
print('testing model:', model.__class__.__name__) | |||
tester = Evaluator(model, data, metrics={'metric': metrics}, driver='torch') | |||
before_train = tester.run() | |||
optimizer = optim.SGD(model.parameters(), lr=1e-3) | |||
trainer = Trainer(model, driver='torch', train_dataloader=data, | |||
n_epochs=N_EPOCHS, optimizers=optimizer) | |||
trainer.run() | |||
after_train = tester.run() | |||
for metric_name, v1 in before_train.items(): | |||
assert metric_name in after_train | |||
# # at least we can sure model params changed, even if we don't know performance | |||
# v2 = after_train[metric_name] | |||
# assert v1 != v2 | |||
def run_model_with_task(self, task, model): | |||
"""run a model with certain task""" | |||
TASKS = { | |||
TEXT_CLS: self.run_text_classification, | |||
POS_TAGGING: self.run_pos_tagging, | |||
NLI: self.run_nli, | |||
} | |||
assert task in TASKS | |||
TASKS[task](model) | |||
RUNNER = ModelRunner() |
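A quick sketch of the fake-data helpers, with the module above in scope; the bounds follow ``gen_seq``/``gen_var_seq`` (index 0 is reserved for padding):

runner = ModelRunner()
seq = runner.gen_var_seq(MAX_LEN, VOCAB_SIZE)
assert 3 <= len(seq) < MAX_LEN                    # lengths drawn from randrange(3, MAX_LEN)
assert all(1 <= idx < VOCAB_SIZE for idx in seq)  # indexes drawn from randrange(1, VOCAB_SIZE)
dl = runner.prepare_nli_data()  # TorchDataLoader over words1/words2/target/seq_len1/seq_len2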
@@ -0,0 +1,91 @@ | |||
import pytest | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from fastNLP.models.torch.biaffine_parser import BiaffineParser | |||
from fastNLP import Metric, seq_len_to_mask | |||
from .model_runner import * | |||
class ParserMetric(Metric): | |||
r""" | |||
Evaluate the performance of a parser.
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.num_arc = 0 | |||
self.num_label = 0 | |||
self.num_sample = 0 | |||
def get_metric(self, reset=True): | |||
res = {'UAS': self.num_arc * 1.0 / self.num_sample, 'LAS': self.num_label * 1.0 / self.num_sample} | |||
if reset: | |||
self.num_sample = self.num_label = self.num_arc = 0 | |||
return res | |||
def update(self, pred1, pred2, target1, target2, seq_len=None): | |||
r""" | |||
:param pred1: predicted heads (arcs)
:param pred2: predicted labels
:param target1: gold heads
:param target2: gold labels
:param seq_len: sequence lengths
:return dict: evaluation results::
UAS: accuracy of arc prediction, ignoring labels
LAS: accuracy of predicting both the arc and its label
""" | |||
if seq_len is None: | |||
seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long) | |||
else: | |||
seq_mask = seq_len_to_mask(seq_len.long()).long() | |||
# mask out <root> tag | |||
seq_mask[:, 0] = 0 | |||
head_pred_correct = (pred1 == target1).long() * seq_mask | |||
label_pred_correct = (pred2 == target2).long() * head_pred_correct | |||
self.num_arc += head_pred_correct.sum().item() | |||
self.num_label += label_pred_correct.sum().item() | |||
self.num_sample += seq_mask.sum().item() | |||
def prepare_parser_data(): | |||
index = 'index' | |||
ds = DataSet({index: list(range(N_SAMPLES))}) | |||
ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
field_name=index, new_field_name='words1') | |||
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS), | |||
field_name='words1', new_field_name='words2') | |||
# target1 is heads, should be in range(0, len(words))
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), len(x)), | |||
field_name='words1', new_field_name='target1') | |||
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS), | |||
field_name='words1', new_field_name='target2') | |||
ds.apply_field(len, field_name='words1', new_field_name='seq_len') | |||
dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) | |||
return dl | |||
@pytest.mark.torch | |||
class TestBiaffineParser: | |||
def test_train(self): | |||
model = BiaffineParser(embed=(VOCAB_SIZE, 10), | |||
pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10, | |||
rnn_hidden_size=10, | |||
arc_mlp_size=10, | |||
label_mlp_size=10, | |||
num_label=NUM_CLS, encoder='var-lstm') | |||
ds = prepare_parser_data() | |||
RUNNER.run_model(model, ds, metrics=ParserMetric()) | |||
def test_train2(self): | |||
model = BiaffineParser(embed=(VOCAB_SIZE, 10), | |||
pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10, | |||
rnn_hidden_size=16, | |||
arc_mlp_size=10, | |||
label_mlp_size=10, | |||
num_label=NUM_CLS, encoder='transformer') | |||
ds = prepare_parser_data() | |||
RUNNER.run_model(model, ds, metrics=ParserMetric()) |
@@ -0,0 +1,33 @@ | |||
import pytest | |||
from .model_runner import * | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
from fastNLP.models.torch.cnn_text_classification import CNNText | |||
@pytest.mark.torch | |||
class TestCNNText: | |||
def init_model(self, kernel_sizes, kernel_nums=(1,3,5)): | |||
model = CNNText((VOCAB_SIZE, 30), | |||
NUM_CLS, | |||
kernel_nums=kernel_nums, | |||
kernel_sizes=kernel_sizes) | |||
return model | |||
def test_case1(self): | |||
# test that the CNN model runs end to end
model = self.init_model((1,3,5)) | |||
RUNNER.run_model_with_task(TEXT_CLS, model) | |||
def test_init_model(self): | |||
with pytest.raises(Exception): | |||
self.init_model(2, 4) | |||
with pytest.raises(Exception): | |||
self.init_model((2,)) | |||
def test_output(self): | |||
model = self.init_model((3,), (1,)) | |||
# rebinding MAX_LEN here only changes this module's copy; patch the shared constant in model_runner
from . import model_runner
model_runner.MAX_LEN = 2
RUNNER.run_model_with_task(TEXT_CLS, model) |
@@ -0,0 +1,73 @@ | |||
import pytest | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
from fastNLP.models.torch import LSTMSeq2SeqModel, TransformerSeq2SeqModel | |||
import torch | |||
from fastNLP.embeddings.torch import StaticEmbedding | |||
from fastNLP import Vocabulary, DataSet | |||
from fastNLP import Trainer, Accuracy | |||
from fastNLP import Callback, TorchDataLoader | |||
def prepare_env(): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||
src_words_idx = [[3, 1, 2], [1, 2]] | |||
# tgt_words_idx = [[1, 2, 3, 4], [2, 3]] | |||
src_seq_len = [3, 2] | |||
# tgt_seq_len = [4, 2] | |||
ds = DataSet({'src_tokens': src_words_idx, 'src_seq_len': src_seq_len, 'tgt_tokens': src_words_idx, | |||
'tgt_seq_len':src_seq_len}) | |||
dl = TorchDataLoader(ds, batch_size=32) | |||
return embed, dl | |||
class ExitCallback(Callback): | |||
def __init__(self): | |||
super().__init__() | |||
def on_valid_end(self, trainer, results): | |||
if results['acc#acc'] == 1: | |||
raise KeyboardInterrupt() | |||
@pytest.mark.torch | |||
class TestSeq2SeqGeneratorModel: | |||
def test_run(self): | |||
# check that SequenceGeneratorModel can be trained, with predictions passed straight through
embed, dl = prepare_env() | |||
model1 = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, | |||
dim_ff=20, dropout=0.1, | |||
bind_encoder_decoder_embed=True, | |||
bind_decoder_input_output_embed=True) | |||
optimizer = torch.optim.Adam(model1.parameters(), lr=1e-3) | |||
trainer = Trainer(model1, driver='torch', optimizers=optimizer, train_dataloader=dl, | |||
n_epochs=100, evaluate_dataloaders=dl, metrics={'acc': Accuracy()}, | |||
evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'], | |||
'seq_len': x['tgt_seq_len'], | |||
**x}, | |||
callbacks=ExitCallback()) | |||
trainer.run() | |||
embed, dl = prepare_env() | |||
model2 = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
num_layers=1, hidden_size=20, dropout=0.1, | |||
bind_encoder_decoder_embed=True, | |||
bind_decoder_input_output_embed=True, attention=True) | |||
optimizer = torch.optim.Adam(model2.parameters(), lr=0.01) | |||
trainer = Trainer(model2, driver='torch', optimizers=optimizer, train_dataloader=dl, | |||
n_epochs=100, evaluate_dataloaders=dl, metrics={'acc': Accuracy()}, | |||
evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'], | |||
'seq_len': x['tgt_seq_len'], | |||
**x}, | |||
callbacks=ExitCallback()) | |||
trainer.run() |
@@ -0,0 +1,113 @@ | |||
import pytest | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
from fastNLP.models.torch.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel | |||
from fastNLP import Vocabulary | |||
from fastNLP.embeddings.torch import StaticEmbedding | |||
import torch | |||
from torch import optim | |||
import torch.nn.functional as F | |||
from fastNLP import seq_len_to_mask | |||
def prepare_env(): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||
src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
src_seq_len = torch.LongTensor([3, 2]) | |||
tgt_seq_len = torch.LongTensor([4, 2]) | |||
return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len | |||
def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len): | |||
optimizer = optim.Adam(model.parameters(), lr=1e-2) | |||
mask = seq_len_to_mask(tgt_seq_len).eq(0) | |||
target = tgt_words_idx.masked_fill(mask, -100) | |||
for i in range(100): | |||
optimizer.zero_grad() | |||
pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred'] # bsz x max_len x vocab_size | |||
loss = F.cross_entropy(pred.transpose(1, 2), target) | |||
loss.backward() | |||
optimizer.step() | |||
right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum() | |||
return right_count | |||
@pytest.mark.torch | |||
class TestTransformerSeq2SeqModel: | |||
def test_run(self): | |||
# test that a forward pass runs
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
for pos_embed in ['learned', 'sin']: | |||
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
pos_embed=pos_embed, max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||
bind_encoder_decoder_embed=True, | |||
bind_decoder_input_output_embed=True) | |||
output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||
assert (output['pred'].size() == (2, 4, len(embed))) | |||
for bind_encoder_decoder_embed in [True, False]: | |||
tgt_embed = None | |||
for bind_decoder_input_output_embed in [True, False]: | |||
if bind_encoder_decoder_embed == False: | |||
tgt_embed = embed | |||
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, | |||
pos_embed='sin', max_position=20, num_layers=2, | |||
d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||
bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||
bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||
output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||
assert (output['pred'].size() == (2, 4, len(embed))) | |||
def test_train(self): | |||
# test that the model can be trained to overfit the toy data
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||
bind_encoder_decoder_embed=True, | |||
bind_decoder_input_output_embed=True) | |||
right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) | |||
assert(right_count == tgt_words_idx.nelement()) | |||
@pytest.mark.torch | |||
class TestLSTMSeq2SeqModel: | |||
def test_run(self): | |||
# test that a forward pass runs
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
for bind_encoder_decoder_embed in [True, False]: | |||
tgt_embed = None | |||
for bind_decoder_input_output_embed in [True, False]: | |||
if bind_encoder_decoder_embed == False: | |||
tgt_embed = embed | |||
model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, | |||
num_layers=2, hidden_size=20, dropout=0.1, | |||
bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||
bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||
output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||
assert (output['pred'].size() == (2, 4, len(embed))) | |||
def test_train(self): | |||
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
num_layers=1, hidden_size=20, dropout=0.1, | |||
bind_encoder_decoder_embed=True, | |||
bind_decoder_input_output_embed=True) | |||
right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) | |||
assert (right_count == tgt_words_idx.nelement()) | |||
@@ -0,0 +1,47 @@ | |||
import pytest | |||
from .model_runner import * | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
from fastNLP.models.torch.sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF | |||
@pytest.mark.torch | |||
class TestBiLSTM: | |||
def test_case1(self): | |||
# test that BiLSTMCRF runs end to end
init_emb = (VOCAB_SIZE, 30) | |||
model = BiLSTMCRF(init_emb, | |||
hidden_size=30, | |||
num_classes=NUM_CLS) | |||
dl = RUNNER.prepare_pos_tagging_data() | |||
metric = Accuracy() | |||
RUNNER.run_model(model, dl, metric) | |||
@pytest.mark.torch | |||
class TestSeqLabel: | |||
def test_case1(self): | |||
# test that SeqLabeling runs end to end
init_emb = (VOCAB_SIZE, 30) | |||
model = SeqLabeling(init_emb, | |||
hidden_size=30, | |||
num_classes=NUM_CLS) | |||
dl = RUNNER.prepare_pos_tagging_data() | |||
metric = Accuracy() | |||
RUNNER.run_model(model, dl, metric) | |||
@pytest.mark.torch | |||
class TestAdvSeqLabel: | |||
def test_case1(self): | |||
# test that AdvSeqLabel runs end to end
init_emb = (VOCAB_SIZE, 30) | |||
model = AdvSeqLabel(init_emb, | |||
hidden_size=30, | |||
num_classes=NUM_CLS) | |||
dl = RUNNER.prepare_pos_tagging_data() | |||
metric = Accuracy() | |||
RUNNER.run_model(model, dl, metric) |
@@ -0,0 +1,327 @@ | |||
import pytest | |||
import os | |||
from fastNLP import Vocabulary | |||
@pytest.mark.torch | |||
class TestCRF: | |||
def test_case1(self): | |||
from fastNLP.modules.torch.decoder.crf import allowed_transitions | |||
# check that allowed_transitions() can be used correctly
id2label = {0: 'B', 1: 'I', 2:'O'} | |||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||
(2, 4), (3, 0), (3, 2)} | |||
assert expected_res == set(allowed_transitions(id2label, include_start_end=True)) | |||
id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
assert (expected_res == set( | |||
allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||
id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | |||
allowed_transitions(id2label, include_start_end=True) | |||
labels = ['O'] | |||
for label in ['X', 'Y']: | |||
for tag in 'BI': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx:label for idx, label in enumerate(labels)} | |||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||
assert (expected_res == set(allowed_transitions(id2label, include_start_end=True))) | |||
labels = [] | |||
for label in ['X', 'Y']: | |||
for tag in 'BMES': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx: label for idx, label in enumerate(labels)} | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
assert (expected_res == set( | |||
allowed_transitions(id2label, include_start_end=True))) | |||
def test_case11(self): | |||
# test automatic inference of the encoding type
from fastNLP.modules.torch.decoder.crf import allowed_transitions | |||
id2label = {0: 'B', 1: 'I', 2: 'O'} | |||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||
(2, 4), (3, 0), (3, 2)} | |||
assert (expected_res == set(allowed_transitions(id2label, include_start_end=True))) | |||
id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'} | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
assert (expected_res == set( | |||
allowed_transitions(id2label, include_start_end=True))) | |||
id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"} | |||
allowed_transitions(id2label, include_start_end=True) | |||
labels = ['O'] | |||
for label in ['X', 'Y']: | |||
for tag in 'BI': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx: label for idx, label in enumerate(labels)} | |||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||
assert (expected_res == set(allowed_transitions(id2label, include_start_end=True))) | |||
labels = [] | |||
for label in ['X', 'Y']: | |||
for tag in 'BMES': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx: label for idx, label in enumerate(labels)} | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
assert (expected_res == set( | |||
allowed_transitions(id2label, include_start_end=True))) | |||
def test_case12(self): | |||
# test that the transition constraints can be generated from a vocab
from fastNLP.modules.torch.decoder.crf import allowed_transitions | |||
id2label = {0: 'B', 1: 'I', 2: 'O'} | |||
vocab = Vocabulary(unknown=None, padding=None) | |||
for idx, tag in id2label.items(): | |||
vocab.add_word(tag) | |||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||
(2, 4), (3, 0), (3, 2)} | |||
assert (expected_res == set(allowed_transitions(vocab, include_start_end=True))) | |||
id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'} | |||
vocab = Vocabulary(unknown=None, padding=None) | |||
for idx, tag in id2label.items(): | |||
vocab.add_word(tag) | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
assert (expected_res == set( | |||
allowed_transitions(vocab, include_start_end=True))) | |||
id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"} | |||
vocab = Vocabulary() | |||
for idx, tag in id2label.items(): | |||
vocab.add_word(tag) | |||
allowed_transitions(vocab, include_start_end=True) | |||
labels = ['O'] | |||
for label in ['X', 'Y']: | |||
for tag in 'BI': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx: label for idx, label in enumerate(labels)} | |||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||
vocab = Vocabulary(unknown=None, padding=None) | |||
for idx, tag in id2label.items(): | |||
vocab.add_word(tag) | |||
assert (expected_res == set(allowed_transitions(vocab, include_start_end=True))) | |||
labels = [] | |||
for label in ['X', 'Y']: | |||
for tag in 'BMES': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx: label for idx, label in enumerate(labels)} | |||
vocab = Vocabulary(unknown=None, padding=None) | |||
for idx, tag in id2label.items(): | |||
vocab.add_word(tag) | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
assert (expected_res == set( | |||
allowed_transitions(vocab, include_start_end=True))) | |||
# def test_case2(self): | |||
# # 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。 | |||
# pass | |||
# import torch | |||
# from fastNLP import seq_len_to_mask | |||
# | |||
# labels = ['O'] | |||
# for label in ['X', 'Y']: | |||
# for tag in 'BI': | |||
# labels.append('{}-{}'.format(tag, label)) | |||
# id2label = {idx: label for idx, label in enumerate(labels)} | |||
# num_tags = len(id2label) | |||
# max_len = 10 | |||
# batch_size = 4 | |||
# bio_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log() | |||
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label), | |||
# include_start_end_transitions=False) | |||
# bio_trans_m = allen_CRF.transitions | |||
# bio_seq_lens = torch.randint(1, max_len, size=(batch_size,)) | |||
# bio_seq_lens[0] = 1 | |||
# bio_seq_lens[-1] = max_len | |||
# mask = seq_len_to_mask(bio_seq_lens) | |||
# allen_res = allen_CRF.viterbi_tags(bio_logits, mask) | |||
# | |||
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions | |||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||
# include_start_end=True)) | |||
# fast_CRF.trans_m = bio_trans_m | |||
# fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True) | |||
# bio_scores = [round(score, 4) for _, score in allen_res] | |||
# # score equal | |||
# self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()]) | |||
# # seq equal | |||
# bio_path = [_ for _, score in allen_res] | |||
# self.assertListEqual(bio_path, fast_res[0]) | |||
# | |||
# labels = [] | |||
# for label in ['X', 'Y']: | |||
# for tag in 'BMES': | |||
# labels.append('{}-{}'.format(tag, label)) | |||
# id2label = {idx: label for idx, label in enumerate(labels)} | |||
# num_tags = len(id2label) | |||
# | |||
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label), | |||
# include_start_end_transitions=False) | |||
# bmes_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log() | |||
# bmes_trans_m = allen_CRF.transitions | |||
# bmes_seq_lens = torch.randint(1, max_len, size=(batch_size,)) | |||
# bmes_seq_lens[0] = 1 | |||
# bmes_seq_lens[-1] = max_len | |||
# mask = seq_len_to_mask(bmes_seq_lens) | |||
# allen_res = allen_CRF.viterbi_tags(bmes_logits, mask) | |||
# | |||
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions | |||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||
# encoding_type='BMES', | |||
# include_start_end=True)) | |||
# fast_CRF.trans_m = bmes_trans_m | |||
# fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True) | |||
# # score equal | |||
# bmes_scores = [round(score, 4) for _, score in allen_res] | |||
# self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()]) | |||
# # seq equal | |||
# bmes_path = [_ for _, score in allen_res] | |||
# self.assertListEqual(bmes_path, fast_res[0]) | |||
# | |||
# data = { | |||
# 'bio_logits': bio_logits.tolist(), | |||
# 'bio_scores': bio_scores, | |||
# 'bio_path': bio_path, | |||
# 'bio_trans_m': bio_trans_m.tolist(), | |||
# 'bio_seq_lens': bio_seq_lens.tolist(), | |||
# 'bmes_logits': bmes_logits.tolist(), | |||
# 'bmes_scores': bmes_scores, | |||
# 'bmes_path': bmes_path, | |||
# 'bmes_trans_m': bmes_trans_m.tolist(), | |||
# 'bmes_seq_lens': bmes_seq_lens.tolist(), | |||
# } | |||
# | |||
# with open('weights.json', 'w') as f: | |||
# import json | |||
# json.dump(data, f) | |||
def test_case2(self): | |||
# test that the CRF works as expected
import json | |||
import torch | |||
from fastNLP import seq_len_to_mask | |||
folder = os.path.dirname(os.path.abspath(__file__)) | |||
path = os.path.join(folder, '../../../', 'helpers/data/modules/decoder/crf.json') | |||
with open(os.path.abspath(path), 'r') as f: | |||
data = json.load(f) | |||
bio_logits = torch.FloatTensor(data['bio_logits']) | |||
bio_scores = data['bio_scores'] | |||
bio_path = data['bio_path'] | |||
bio_trans_m = torch.FloatTensor(data['bio_trans_m']) | |||
bio_seq_lens = torch.LongTensor(data['bio_seq_lens']) | |||
bmes_logits = torch.FloatTensor(data['bmes_logits']) | |||
bmes_scores = data['bmes_scores'] | |||
bmes_path = data['bmes_path'] | |||
bmes_trans_m = torch.FloatTensor(data['bmes_trans_m']) | |||
bmes_seq_lens = torch.LongTensor(data['bmes_seq_lens']) | |||
labels = ['O'] | |||
for label in ['X', 'Y']: | |||
for tag in 'BI': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx: label for idx, label in enumerate(labels)} | |||
num_tags = len(id2label) | |||
mask = seq_len_to_mask(bio_seq_lens) | |||
from fastNLP.modules.torch.decoder.crf import ConditionalRandomField, allowed_transitions | |||
fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||
include_start_end=True)) | |||
fast_CRF.trans_m.data = bio_trans_m | |||
fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True) | |||
# score equal | |||
assert (bio_scores == [round(s, 4) for s in fast_res[1].tolist()]) | |||
# seq equal | |||
assert (bio_path == fast_res[0]) | |||
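# Second check: repeat the comparison with BMES-constrained transitions. | |||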
labels = [] | |||
for label in ['X', 'Y']: | |||
for tag in 'BMES': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx: label for idx, label in enumerate(labels)} | |||
num_tags = len(id2label) | |||
mask = seq_len_to_mask(bmes_seq_lens) | |||
from fastNLP.modules.torch.decoder.crf import ConditionalRandomField, allowed_transitions | |||
fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||
encoding_type='BMES', | |||
include_start_end=True)) | |||
fast_CRF.trans_m.data = bmes_trans_m | |||
fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True) | |||
# score equal | |||
assert (bmes_scores == [round(s, 4) for s in fast_res[1].tolist()]) | |||
# seq equal | |||
assert (bmes_path == fast_res[0]) | |||
def test_case3(self): | |||
# Check that the CRF loss never becomes negative. | |||
import torch | |||
from fastNLP.modules.torch.decoder.crf import ConditionalRandomField | |||
from fastNLP.core.utils import seq_len_to_mask | |||
from torch import optim | |||
from torch import nn | |||
num_tags, include_start_end_trans = 4, True | |||
num_samples = 4 | |||
lengths = torch.randint(3, 50, size=(num_samples, )).long() | |||
max_len = lengths.max() | |||
tags = torch.randint(num_tags, size=(num_samples, max_len)) | |||
masks = seq_len_to_mask(lengths) | |||
feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags)) | |||
crf = ConditionalRandomField(num_tags, include_start_end_trans) | |||
optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1) | |||
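# The CRF forward returns the negative log-likelihood of the gold tag sequence; since the gold path is one of the paths summed in the partition function, this loss stays positive even while the emissions and CRF parameters are optimized. | |||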
for _ in range(10): | |||
loss = crf(feats, tags, masks).mean() | |||
optimizer.zero_grad() | |||
loss.backward() | |||
optimizer.step() | |||
if _ % 1000 == 0: | |||
print(loss) | |||
assert loss.item() > 0 | |||
def test_masking(self): | |||
# Check that pad masking works correctly in Viterbi decoding. | |||
import torch | |||
from fastNLP.modules.torch.decoder.crf import ConditionalRandomField | |||
max_len = 5 | |||
n_tags = 5 | |||
pad_len = 5 | |||
torch.manual_seed(4) | |||
logit = torch.rand(1, max_len+pad_len, n_tags) | |||
# logit[0, -1, :] = 0.0 | |||
mask = torch.ones(1, max_len+pad_len) | |||
mask[0, -pad_len:] = 0 # mask out all padded tail positions | |||
model = ConditionalRandomField(n_tags) | |||
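# Decoding the unpadded prefix and decoding the padded input with its mask must yield the same path on the valid positions. | |||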
pred, score = model.viterbi_decode(logit[:,:-pad_len], mask[:,:-pad_len]) | |||
mask_pred, mask_score = model.viterbi_decode(logit, mask) | |||
assert (pred[0].tolist() == mask_pred[0,:-pad_len].tolist()) | |||
@@ -0,0 +1,49 @@ | |||
import pytest | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from fastNLP import Vocabulary | |||
from fastNLP.embeddings.torch import StaticEmbedding | |||
from fastNLP.modules.torch import TransformerSeq2SeqDecoder | |||
from fastNLP.modules.torch import LSTMSeq2SeqDecoder | |||
from fastNLP import seq_len_to_mask | |||
@pytest.mark.torch | |||
class TestTransformerSeq2SeqDecoder: | |||
@pytest.mark.parametrize("flag", [True, False]) | |||
def test_case(self, flag): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, embedding_dim=10) | |||
encoder_output = torch.randn(2, 3, 10) | |||
src_seq_len = torch.LongTensor([3, 2]) | |||
encoder_mask = seq_len_to_mask(src_seq_len) | |||
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=None, | |||
d_model=10, num_layers=2, n_head=5, dim_ff=20, dropout=0.1, | |||
bind_decoder_input_output_embed=flag) | |||
state = decoder.init_state(encoder_output, encoder_mask) | |||
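# Each decoder forward consumes the target token ids plus the state and returns logits of shape (batch, tgt_len, vocab_size). | |||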
output = decoder(tokens=torch.randint(0, len(vocab), size=(2, 4)), state=state) | |||
assert (output.size() == (2, 4, len(vocab))) | |||
@pytest.mark.torch | |||
class TestLSTMDecoder: | |||
@pytest.mark.parametrize("flag", [True, False]) | |||
@pytest.mark.parametrize("attention", [True, False]) | |||
def test_case(self, flag, attention): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10) | |||
encoder_output = torch.randn(2, 3, 10) | |||
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
src_seq_len = torch.LongTensor([3, 2]) | |||
encoder_mask = seq_len_to_mask(src_seq_len) | |||
decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=2, hidden_size=10, | |||
dropout=0.3, bind_decoder_input_output_embed=flag, attention=attention) | |||
state = decoder.init_state(encoder_output, encoder_mask) | |||
output = decoder(tgt_words_idx, state) | |||
assert tuple(output.size()) == (2, 4, len(vocab)) |
@@ -0,0 +1,33 @@ | |||
import pytest | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from fastNLP.modules.torch.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||
from fastNLP import Vocabulary | |||
from fastNLP.embeddings.torch import StaticEmbedding | |||
@pytest.mark.torch | |||
class TestTransformerSeq2SeqEncoder: | |||
def test_case(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
embed = StaticEmbedding(vocab, embedding_dim=5) | |||
encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2) | |||
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||
seq_len = torch.LongTensor([3]) | |||
encoder_output, encoder_mask = encoder(words_idx, seq_len) | |||
assert (encoder_output.size() == (1, 3, 10)) | |||
@pytest.mark.torch | |||
class TestBiLSTMEncoder: | |||
def test_case(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
embed = StaticEmbedding(vocab, embedding_dim=5) | |||
encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1) | |||
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||
seq_len = torch.LongTensor([3]) | |||
encoder_output, encoder_mask = encoder(words_idx, seq_len) | |||
assert (encoder_mask.size() == (1, 3)) |
@@ -0,0 +1,18 @@ | |||
import pytest | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from fastNLP.modules.torch.encoder.star_transformer import StarTransformer | |||
@pytest.mark.torch | |||
class TestStarTransformer: | |||
def test_1(self): | |||
model = StarTransformer(num_layers=6, hidden_size=100, num_head=8, head_dim=20, max_len=100) | |||
x = torch.rand(16, 45, 100) | |||
mask = torch.ones(16, 45).byte() | |||
y, yn = model(x, mask) | |||
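# y holds the per-token node states; yn should be the pooled relay-node summary for each sequence. | |||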
assert (tuple(y.size()) == (16, 45, 100)) | |||
assert (tuple(yn.size()) == (16, 100)) |
@@ -0,0 +1,27 @@ | |||
import pytest | |||
import numpy as np | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from fastNLP.modules.torch.encoder.variational_rnn import VarLSTM | |||
@pytest.mark.torch | |||
class TestMaskedRnn: | |||
def test_case_1(self): | |||
masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) | |||
x = torch.tensor([[[1.0], [2.0]]]) | |||
print(x.size()) | |||
y = masked_rnn(x) | |||
def test_case_2(self): | |||
input_size = 12 | |||
batch = 16 | |||
hidden = 10 | |||
masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True) | |||
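# For a unidirectional VarLSTM the output feature size equals hidden_size, so the output shape should be (batch, 32, hidden). | |||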
xx = torch.randn((batch, 32, input_size)) | |||
y, _ = masked_rnn(xx) | |||
assert tuple(y.shape) == (batch, 32, hidden) |
@@ -0,0 +1 @@ | |||
@@ -0,0 +1,146 @@ | |||
import pytest | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from fastNLP.modules.torch.generator import SequenceGenerator | |||
from fastNLP.modules.torch import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State | |||
from fastNLP import Vocabulary | |||
from fastNLP.embeddings.torch import StaticEmbedding | |||
from torch import nn | |||
from fastNLP import seq_len_to_mask | |||
else: | |||
from fastNLP.core.utils.dummy_class import DummyClass as State | |||
from fastNLP.core.utils.dummy_class import DummyClass as Seq2SeqDecoder | |||
def prepare_env(): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||
encoder_output = torch.randn(2, 3, 10) | |||
src_seq_len = torch.LongTensor([3, 2]) | |||
encoder_mask = seq_len_to_mask(src_seq_len) | |||
return embed, encoder_output, encoder_mask | |||
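# GreedyDummyDecoder replays a fixed tensor of scores step by step, so the path generation should recover is known in advance. | |||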
class GreedyDummyDecoder(Seq2SeqDecoder): | |||
def __init__(self, decoder_output): | |||
super().__init__() | |||
self.cur_length = 0 | |||
self.decoder_output = decoder_output | |||
def decode(self, tokens, state): | |||
self.cur_length += 1 | |||
scores = self.decoder_output[:, self.cur_length] | |||
return scores | |||
class DummyState(State): | |||
def __init__(self, decoder): | |||
super().__init__() | |||
self.decoder = decoder | |||
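# reorder_state keeps the dummy decoder's pre-computed scores aligned with hypothesis reordering during beam search. | |||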
def reorder_state(self, indices: torch.LongTensor): | |||
self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0) | |||
@pytest.mark.torch | |||
class TestSequenceGenerator: | |||
def test_run(self): | |||
# Smoke test: (1) build each decoder, (2) run one round of generation. | |||
embed, encoder_output, encoder_mask = prepare_env() | |||
for do_sample in [True, False]: | |||
for num_beams in [1, 3, 5]: | |||
decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10, | |||
dropout=0.3, bind_decoder_input_output_embed=True, attention=True) | |||
state = decoder.init_state(encoder_output, encoder_mask) | |||
generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams, | |||
do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, | |||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0) | |||
generator.generate(state=state, tokens=None) | |||
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), | |||
d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, dropout=0.1, | |||
bind_decoder_input_output_embed=True) | |||
state = decoder.init_state(encoder_output, encoder_mask) | |||
generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, | |||
do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, | |||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0) | |||
generator.generate(state=state, tokens=None) | |||
# Also exercise non-default sampling and penalty settings. | |||
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), | |||
d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, | |||
dropout=0.1, | |||
bind_decoder_input_output_embed=True) | |||
state = decoder.init_state(encoder_output, encoder_mask) | |||
generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, | |||
do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5, bos_token_id=1, | |||
eos_token_id=3, repetition_penalty=2, length_penalty=1.5, pad_token_id=0) | |||
generator.generate(state=state, tokens=None) | |||
def test_greedy_decode(self): | |||
# Check that generation recovers the expected sequences. | |||
# greedy | |||
for beam_search in [1, 3]: | |||
decoder_output = torch.randn(2, 10, 5) | |||
path = decoder_output.argmax(dim=-1) # 2 x 10 | |||
decoder = GreedyDummyDecoder(decoder_output) | |||
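# With deterministic scores and do_sample=False, generation (beam size 1 or 3) is expected to recover the per-step argmax path exactly. | |||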
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||
do_sample=False, temperature=1, top_k=50, top_p=1, bos_token_id=1, | |||
eos_token_id=None, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||
decode_path = generator.generate(DummyState(decoder), tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||
assert (decode_path.eq(path).sum() == path.numel()) | |||
# greedy check eos_token_id | |||
for beam_search in [1, 3]: | |||
decoder_output = torch.randn(2, 10, 5) | |||
decoder_output[:, :7, 4].fill_(-100) | |||
decoder_output[0, 7, 4] = 1000 # sequence 0 emits EOS (tag 4) at step 8 | |||
decoder_output[1, 5, 4] = 1000 # sequence 1 emits EOS at step 6 | |||
path = decoder_output.argmax(dim=-1) # 2 x 10 | |||
decoder = GreedyDummyDecoder(decoder_output) | |||
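# Generation should stop once every sequence has emitted EOS, so the decoded length is bounded by the longest sequence (8 steps). | |||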
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||
do_sample=False, temperature=1, top_k=50, top_p=0.5, bos_token_id=1, | |||
eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||
decode_path = generator.generate(DummyState(decoder), | |||
tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||
assert (decode_path.size(1) == 8) # decoded length is 8 | |||
assert (decode_path[0].eq(path[0, :8]).sum() == 8) | |||
assert (decode_path[1, :6].eq(path[1, :6]).sum() == 6) | |||
def test_sample_decoder(self): | |||
# sampling counterpart of the eos_token_id check above | |||
for beam_search in [1, 3]: | |||
decode_paths = [] | |||
# Sampling is stochastic, so repeat the test several times; if at least one run matches the argmax path, the sampler is most likely working. | |||
num_tests = 10 | |||
for i in range(num_tests): | |||
decoder_output = torch.randn(2, 10, 5) * 10 | |||
decoder_output[:, :7, 4].fill_(-100) | |||
decoder_output[0, 7, 4] = 10000 # sequence 0 emits EOS at step 8 | |||
decoder_output[1, 5, 4] = 10000 # sequence 1 emits EOS at step 6 | |||
path = decoder_output.argmax(dim=-1) # 2 x 10 | |||
decoder = GreedyDummyDecoder(decoder_output) | |||
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||
do_sample=True, temperature=1, top_k=50, top_p=0.5, bos_token_id=1, | |||
eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||
decode_path = generator.generate(DummyState(decoder), | |||
tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||
decode_paths.append([decode_path, path]) | |||
sizes = [] | |||
eqs = [] | |||
eq2s = [] | |||
for i in range(num_tests): | |||
decode_path, path = decode_paths[i] | |||
sizes.append(decode_path.size(1)==8) | |||
eqs.append(decode_path[0].eq(path[0, :8]).sum()==8) | |||
eq2s.append(decode_path[1, :6].eq(path[1, :6]).sum()==6) | |||
assert any(sizes) | |||
assert any(eqs) | |||
assert any(eq2s) |