diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py index e76391aa..66d2eee2 100644 --- a/fastNLP/core/collators/padders/get_padder.py +++ b/fastNLP/core/collators/padders/get_padder.py @@ -138,6 +138,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \ "information please set logger's level to DEBUG." if must_pad: + logger.error(msg) raise type(e)(msg=msg) logger.debug(msg) return NullPadder() diff --git a/fastNLP/embeddings/torch/char_embedding.py b/fastNLP/embeddings/torch/char_embedding.py index 6af0a7ff..69706281 100644 --- a/fastNLP/embeddings/torch/char_embedding.py +++ b/fastNLP/embeddings/torch/char_embedding.py @@ -16,6 +16,7 @@ if _NEED_IMPORT_TORCH: import torch import torch.nn as nn import torch.nn.functional as F + from torch.nn import LSTM from .embedding import TokenEmbedding from .static_embedding import StaticEmbedding @@ -23,7 +24,6 @@ from .utils import _construct_char_vocab_from_vocab from .utils import get_embeddings from ...core import logger from ...core.vocabulary import Vocabulary -from ...modules.torch.encoder.lstm import LSTM class CNNCharEmbedding(TokenEmbedding): diff --git a/fastNLP/models/__init__.py b/fastNLP/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastNLP/models/torch/__init__.py b/fastNLP/models/torch/__init__.py new file mode 100644 index 00000000..25d469ad --- /dev/null +++ b/fastNLP/models/torch/__init__.py @@ -0,0 +1,21 @@ +__all__ = [ + 'BiaffineParser', + + "CNNText", + + "SequenceGeneratorModel", + + "Seq2SeqModel", + 'TransformerSeq2SeqModel', + 'LSTMSeq2SeqModel', + + "SeqLabeling", + "AdvSeqLabel", + "BiLSTMCRF", +] + +from .biaffine_parser import BiaffineParser +from .cnn_text_classification import CNNText +from .seq2seq_generator import SequenceGeneratorModel +from .seq2seq_model import * +from .sequence_labeling import * diff --git a/fastNLP/models/torch/biaffine_parser.py b/fastNLP/models/torch/biaffine_parser.py new file mode 100755 index 00000000..574774bd --- /dev/null +++ b/fastNLP/models/torch/biaffine_parser.py @@ -0,0 +1,475 @@ +r""" +Biaffine Dependency Parser 的 Pytorch 实现. 
+""" +__all__ = [ + "BiaffineParser", + "GraphParser" +] + +from collections import defaultdict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...core.utils import seq_len_to_mask +from ...embeddings.torch.utils import get_embeddings +from ...modules.torch.dropout import TimestepDropout +from ...modules.torch.encoder.transformer import TransformerEncoder +from ...modules.torch.encoder.variational_rnn import VarLSTM + + +def _mst(scores): + r""" + with some modification to support parser output for MST decoding + https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 + """ + length = scores.shape[0] + min_score = scores.min() - 1 + eye = np.eye(length) + scores = scores * (1 - eye) + min_score * eye + heads = np.argmax(scores, axis=1) + heads[0] = 0 + tokens = np.arange(1, length) + roots = np.where(heads[tokens] == 0)[0] + 1 + if len(roots) < 1: + root_scores = scores[tokens, 0] + head_scores = scores[tokens, heads[tokens]] + new_root = tokens[np.argmax(root_scores / head_scores)] + heads[new_root] = 0 + elif len(roots) > 1: + root_scores = scores[roots, 0] + scores[roots, 0] = 0 + new_heads = np.argmax(scores[roots][:, tokens], axis=1) + 1 + new_root = roots[np.argmin( + scores[roots, new_heads] / root_scores)] + heads[roots] = new_heads + heads[new_root] = 0 + + edges = defaultdict(set) + vertices = set((0,)) + for dep, head in enumerate(heads[tokens]): + vertices.add(dep + 1) + edges[head].add(dep + 1) + for cycle in _find_cycle(vertices, edges): + dependents = set() + to_visit = set(cycle) + while len(to_visit) > 0: + node = to_visit.pop() + if node not in dependents: + dependents.add(node) + to_visit.update(edges[node]) + cycle = np.array(list(cycle)) + old_heads = heads[cycle] + old_scores = scores[cycle, old_heads] + non_heads = np.array(list(dependents)) + scores[np.repeat(cycle, len(non_heads)), + np.repeat([non_heads], len(cycle), axis=0).flatten()] = min_score + new_heads = np.argmax(scores[cycle][:, tokens], axis=1) + 1 + new_scores = scores[cycle, new_heads] / old_scores + change = np.argmax(new_scores) + changed_cycle = cycle[change] + old_head = old_heads[change] + new_head = new_heads[change] + heads[changed_cycle] = new_head + edges[new_head].add(changed_cycle) + edges[old_head].remove(changed_cycle) + + return heads + + +def _find_cycle(vertices, edges): + r""" + https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm + https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py + """ + _index = 0 + _stack = [] + _indices = {} + _lowlinks = {} + _onstack = defaultdict(lambda: False) + _SCCs = [] + + def _strongconnect(v): + nonlocal _index + _indices[v] = _index + _lowlinks[v] = _index + _index += 1 + _stack.append(v) + _onstack[v] = True + + for w in edges[v]: + if w not in _indices: + _strongconnect(w) + _lowlinks[v] = min(_lowlinks[v], _lowlinks[w]) + elif _onstack[w]: + _lowlinks[v] = min(_lowlinks[v], _indices[w]) + + if _lowlinks[v] == _indices[v]: + SCC = set() + while True: + w = _stack.pop() + _onstack[w] = False + SCC.add(w) + if not (w != v): + break + _SCCs.append(SCC) + + for v in vertices: + if v not in _indices: + _strongconnect(v) + + return [SCC for SCC in _SCCs if len(SCC) > 1] + + +class GraphParser(nn.Module): + r""" + 基于图的parser base class, 支持贪婪解码和最大生成树解码 + """ + + def __init__(self): + super(GraphParser, self).__init__() + + @staticmethod + def greedy_decoder(arc_matrix, mask=None): + r""" + 
贪心解码方式, 输入图, 输出贪心解码的parsing结果, 不保证合法的构成树 + + :param arc_matrix: [batch, seq_len, seq_len] 输入图矩阵 + :param mask: [batch, seq_len] 输入图的padding mask, 有内容的部分为 1, 否则为 0. + 若为 ``None`` 时, 默认为全1向量. Default: ``None`` + :return heads: [batch, seq_len] 每个元素在树中对应的head(parent)预测结果 + """ + _, seq_len, _ = arc_matrix.shape + matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) + flip_mask = mask.eq(False) + matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf) + _, heads = torch.max(matrix, dim=2) + if mask is not None: + heads *= mask.long() + return heads + + @staticmethod + def mst_decoder(arc_matrix, mask=None): + r""" + 用最大生成树算法, 计算parsing结果, 保证输出合法的树结构 + + :param arc_matrix: [batch, seq_len, seq_len] 输入图矩阵 + :param mask: [batch, seq_len] 输入图的padding mask, 有内容的部分为 1, 否则为 0. + 若为 ``None`` 时, 默认为全1向量. Default: ``None`` + :return heads: [batch, seq_len] 每个元素在树中对应的head(parent)预测结果 + """ + batch_size, seq_len, _ = arc_matrix.shape + matrix = arc_matrix.clone() + ans = matrix.new_zeros(batch_size, seq_len).long() + lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len + for i, graph in enumerate(matrix): + len_i = lens[i] + ans[i, :len_i] = torch.as_tensor(_mst(graph.detach()[:len_i, :len_i].cpu().numpy()), device=ans.device) + if mask is not None: + ans *= mask.long() + return ans + + +class ArcBiaffine(nn.Module): + r""" + Biaffine Dependency Parser 的子模块, 用于构建预测边的图 + + """ + + def __init__(self, hidden_size, bias=True): + r""" + + :param hidden_size: 输入的特征维度 + :param bias: 是否使用bias. Default: ``True`` + """ + super(ArcBiaffine, self).__init__() + self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True) + self.has_bias = bias + if self.has_bias: + self.bias = nn.Parameter(torch.Tensor(hidden_size), requires_grad=True) + else: + self.register_parameter("bias", None) + + def forward(self, head, dep): + r""" + + :param head: arc-head tensor [batch, length, hidden] + :param dep: arc-dependent tensor [batch, length, hidden] + :return output: tensor [bacth, length, length] + """ + output = dep.matmul(self.U) + output = output.bmm(head.transpose(-1, -2)) + if self.has_bias: + output = output + head.matmul(self.bias).unsqueeze(1) + return output + + +class LabelBilinear(nn.Module): + r""" + Biaffine Dependency Parser 的子模块, 用于构建预测边类别的图 + + """ + + def __init__(self, in1_features, in2_features, num_label, bias=True): + r""" + + :param in1_features: 输入的特征1维度 + :param in2_features: 输入的特征2维度 + :param num_label: 边类别的个数 + :param bias: 是否使用bias. Default: ``True`` + """ + super(LabelBilinear, self).__init__() + self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) + self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) + + def forward(self, x1, x2): + r""" + + :param x1: [batch, seq_len, hidden] 输入特征1, 即label-head + :param x2: [batch, seq_len, hidden] 输入特征2, 即label-dep + :return output: [batch, seq_len, num_cls] 每个元素对应类别的概率图 + """ + output = self.bilinear(x1, x2) + output = output + self.lin(torch.cat([x1, x2], dim=2)) + return output + + +class BiaffineParser(GraphParser): + r""" + Biaffine Dependency Parser 实现. + 论文参考 `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) `_ . 
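+
+    Example (illustrative; every size below is an arbitrary made-up value)::
+
+        import torch
+        parser = BiaffineParser(embed=(1000, 100), pos_vocab_size=30, pos_emb_dim=30, num_label=20)
+        words1 = torch.randint(1, 1000, (2, 7))   # word ids,  [batch_size, seq_len]
+        words2 = torch.randint(1, 30, (2, 7))     # POS ids,   [batch_size, seq_len]
+        seq_len = torch.tensor([7, 5])
+        res = parser.evaluate_step(words1, words2, seq_len)
+        # res['pred1']: predicted heads [2, 7];  res['pred2']: predicted label ids [2, 7]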
+ + """ + + def __init__(self, + embed, + pos_vocab_size, + pos_emb_dim, + num_label, + rnn_layers=1, + rnn_hidden_size=200, + arc_mlp_size=100, + label_mlp_size=100, + dropout=0.3, + encoder='lstm', + use_greedy_infer=False): + r""" + + :param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 + embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, + 此时就以传入的对象作为embedding + :param pos_vocab_size: part-of-speech 词典大小 + :param pos_emb_dim: part-of-speech 向量维度 + :param num_label: 边的类别个数 + :param rnn_layers: rnn encoder的层数 + :param rnn_hidden_size: rnn encoder 的隐状态维度 + :param arc_mlp_size: 边预测的MLP维度 + :param label_mlp_size: 类别预测的MLP维度 + :param dropout: dropout概率. + :param encoder: encoder类别, 可选 ('lstm', 'var-lstm', 'transformer'). Default: lstm + :param use_greedy_infer: 是否在inference时使用贪心算法. + 若 ``False`` , 使用更加精确但相对缓慢的MST算法. Default: ``False`` + """ + super(BiaffineParser, self).__init__() + rnn_out_size = 2 * rnn_hidden_size + word_hid_dim = pos_hid_dim = rnn_hidden_size + self.word_embedding = get_embeddings(embed) + word_emb_dim = self.word_embedding.embedding_dim + self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) + self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) + self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) + self.word_norm = nn.LayerNorm(word_hid_dim) + self.pos_norm = nn.LayerNorm(pos_hid_dim) + self.encoder_name = encoder + self.max_len = 512 + if encoder == 'var-lstm': + self.encoder = VarLSTM(input_size=word_hid_dim + pos_hid_dim, + hidden_size=rnn_hidden_size, + num_layers=rnn_layers, + bias=True, + batch_first=True, + input_dropout=dropout, + hidden_dropout=dropout, + bidirectional=True) + elif encoder == 'lstm': + self.encoder = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, + hidden_size=rnn_hidden_size, + num_layers=rnn_layers, + bias=True, + batch_first=True, + dropout=dropout, + bidirectional=True) + elif encoder == 'transformer': + n_head = 16 + d_k = d_v = int(rnn_out_size / n_head) + if (d_k * n_head) != rnn_out_size: + raise ValueError('Unsupported rnn_out_size: {} for transformer'.format(rnn_out_size)) + self.position_emb = nn.Embedding(num_embeddings=self.max_len, + embedding_dim=rnn_out_size, ) + self.encoder = TransformerEncoder( num_layers=rnn_layers, d_model=rnn_out_size, + n_head=n_head, dim_ff=1024, dropout=dropout) + else: + raise ValueError('Unsupported encoder type: {}'.format(encoder)) + + self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2), + nn.ELU(), + TimestepDropout(p=dropout), ) + self.arc_mlp_size = arc_mlp_size + self.label_mlp_size = label_mlp_size + self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) + self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) + self.use_greedy_infer = use_greedy_infer + self.reset_parameters() + self.dropout = dropout + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Embedding): + continue + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.weight, 0.1) + nn.init.constant_(m.bias, 0) + else: + for p in m.parameters(): + nn.init.normal_(p, 0, 0.1) + + def forward(self, words1, words2, seq_len, target1=None): + r"""模型forward阶段 + + :param words1: [batch_size, seq_len] 输入word序列 + :param words2: [batch_size, seq_len] 输入pos序列 + :param seq_len: [batch_size, seq_len] 输入序列长度 + :param target1: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效, + 用于训练label分类器. 
若为 ``None`` , 使用预测的heads输入到label分类器 + Default: ``None`` + :return dict: parsing + 结果:: + + pred1: [batch_size, seq_len, seq_len] 边预测logits + pred2: [batch_size, seq_len, num_label] label预测logits + pred3: [batch_size, seq_len] heads的预测结果, 在 ``target1=None`` 时预测 + + """ + # prepare embeddings + batch_size, length = words1.shape + # print('forward {} {}'.format(batch_size, seq_len)) + + # get sequence mask + mask = seq_len_to_mask(seq_len, max_len=length).long() + + word = self.word_embedding(words1) # [N,L] -> [N,L,C_0] + pos = self.pos_embedding(words2) # [N,L] -> [N,L,C_1] + + word, pos = self.word_fc(word), self.pos_fc(pos) + word, pos = self.word_norm(word), self.pos_norm(pos) + x = torch.cat([word, pos], dim=2) # -> [N,L,C] + + # encoder, extract features + if self.encoder_name.endswith('lstm'): + sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) + x = x[sort_idx] + x = nn.utils.rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=True) + feat, _ = self.encoder(x) # -> [N,L,C] + feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) + _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) + feat = feat[unsort_idx] + else: + seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None, :] + x = x + self.position_emb(seq_range) + feat = self.encoder(x, mask.float()) + + # for arc biaffine + # mlp, reduce dim + feat = self.mlp(feat) + arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size + arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz] + label_dep, label_head = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz], feat[:, :, 2 * arc_sz + label_sz:] + + # biaffine arc classifier + arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] + + # use gold or predicted arc to predict label + if target1 is None or not self.training: + # use greedy decoding in training + if self.training or self.use_greedy_infer: + heads = self.greedy_decoder(arc_pred, mask) + else: + heads = self.mst_decoder(arc_pred, mask) + head_pred = heads + else: + assert self.training # must be training mode + if target1 is None: + heads = self.greedy_decoder(arc_pred, mask) + head_pred = heads + else: + head_pred = None + heads = target1 + + batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1) + label_head = label_head[batch_range, heads].contiguous() + label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] + res_dict = {'pred1': arc_pred, 'pred2': label_pred} + if head_pred is not None: + res_dict['pred3'] = head_pred + return res_dict + + def train_step(self, words1, words2, seq_len, target1, target2): + res = self(words1, words2, seq_len, target1) + arc_pred = res['pred1'] + label_pred = res['pred2'] + loss = self.loss(pred1=arc_pred, pred2=label_pred, target1=target1, target2=target2, seq_len=seq_len) + return {'loss': loss} + + @staticmethod + def loss(pred1, pred2, target1, target2, seq_len): + r""" + 计算parser的loss + + :param pred1: [batch_size, seq_len, seq_len] 边预测logits + :param pred2: [batch_size, seq_len, num_label] label预测logits + :param target1: [batch_size, seq_len] 真实边的标注 + :param target2: [batch_size, seq_len] 真实类别的标注 + :param seq_len: [batch_size, seq_len] 真实目标的长度 + :return loss: scalar + """ + + batch_size, length, _ = pred1.shape + mask = seq_len_to_mask(seq_len, max_len=length) + flip_mask = (mask.eq(False)) + _arc_pred = pred1.clone() + _arc_pred = _arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf')) + arc_logits = F.log_softmax(_arc_pred, dim=2) + 
label_logits = F.log_softmax(pred2, dim=2) + batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) + child_index = torch.arange(length, device=arc_logits.device, dtype=torch.long).unsqueeze(0) + arc_loss = arc_logits[batch_index, child_index, target1] + label_loss = label_logits[batch_index, child_index, target2] + + arc_loss = arc_loss.masked_fill(flip_mask, 0) + label_loss = label_loss.masked_fill(flip_mask, 0) + arc_nll = -arc_loss.mean() + label_nll = -label_loss.mean() + return arc_nll + label_nll + + def evaluate_step(self, words1, words2, seq_len): + r"""模型预测API + + :param words1: [batch_size, seq_len] 输入word序列 + :param words2: [batch_size, seq_len] 输入pos序列 + :param seq_len: [batch_size, seq_len] 输入序列长度 + :return dict: parsing + 结果:: + + pred1: [batch_size, seq_len] heads的预测结果 + pred2: [batch_size, seq_len, num_label] label预测logits + + """ + res = self(words1, words2, seq_len) + output = {} + output['pred1'] = res.pop('pred3') + _, label_pred = res.pop('pred2').max(2) + output['pred2'] = label_pred + return output + diff --git a/fastNLP/models/torch/cnn_text_classification.py b/fastNLP/models/torch/cnn_text_classification.py new file mode 100755 index 00000000..34fe7454 --- /dev/null +++ b/fastNLP/models/torch/cnn_text_classification.py @@ -0,0 +1,92 @@ +r""" +.. todo:: + doc +""" + +__all__ = [ + "CNNText" +] + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...core.utils import seq_len_to_mask +from ...embeddings.torch import embedding +from ...modules.torch import encoder + + +class CNNText(torch.nn.Module): + r""" + 使用CNN进行文本分类的模型 + 'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.' + + """ + + def __init__(self, embed, + num_classes, + kernel_nums=(30, 40, 50), + kernel_sizes=(1, 3, 5), + dropout=0.5): + r""" + + :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), + 第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding + :param int num_classes: 一共有多少类 + :param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 + :param float dropout: Dropout的大小 + """ + super(CNNText, self).__init__() + + # no support for pre-trained embedding currently + self.embed = embedding.Embedding(embed) + self.conv_pool = encoder.ConvMaxpool( + in_channels=self.embed.embedding_dim, + out_channels=kernel_nums, + kernel_sizes=kernel_sizes) + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(sum(kernel_nums), num_classes) + + def forward(self, words, seq_len=None): + r""" + + :param torch.LongTensor words: [batch_size, seq_len],句子中word的index + :param torch.LongTensor seq_len: [batch,] 每个句子的长度 + :param target: 每个 sample 的目标值。 + + :return output: + """ + x = self.embed(words) # [N,L] -> [N,L,C] + if seq_len is not None: + mask = seq_len_to_mask(seq_len) + x = self.conv_pool(x, mask) + else: + x = self.conv_pool(x) # [N,L,C] -> [N,C] + x = self.dropout(x) + x = self.fc(x) # [N,C] -> [N, N_class] + res = {'pred': x} + return res + + def train_step(self, words, target, seq_len=None): + """ + + :param words: + :param target: + :param seq_len: + :return: + """ + res = self(words, seq_len) + x = res['pred'] + loss = F.cross_entropy(x, target) + return {'loss': loss} + + def evaluate_step(self, words, seq_len=None): + r""" + :param torch.LongTensor words: [batch_size, seq_len],句子中word的index + :param torch.LongTensor seq_len: [batch,] 每个句子的长度 + + :return predict: dict of torch.LongTensor, [batch_size, ] + """ + output = self(words, 
seq_len) + _, predict = output['pred'].max(dim=1) + return {'pred': predict} diff --git a/fastNLP/models/torch/seq2seq_generator.py b/fastNLP/models/torch/seq2seq_generator.py new file mode 100755 index 00000000..9ee723e5 --- /dev/null +++ b/fastNLP/models/torch/seq2seq_generator.py @@ -0,0 +1,81 @@ +r"""undocumented""" + +import torch +from torch import nn +import torch.nn.functional as F +from fastNLP import seq_len_to_mask +from .seq2seq_model import Seq2SeqModel +from ...modules.torch.generator.seq2seq_generator import SequenceGenerator + + +__all__ = ['SequenceGeneratorModel'] + + +class SequenceGeneratorModel(nn.Module): + """ + 通过使用本模型封装seq2seq_model使得其既可以用于训练也可以用于生成。训练的时候,本模型的forward函数会被调用,生成的时候本模型的predict + 函数会被调用。 + + """ + + def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, max_len_a=0.0, + num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0, + repetition_penalty=1, length_penalty=1.0, pad_token_id=0): + """ + + :param Seq2SeqModel seq2seq_model: 序列到序列模型 + :param int,None bos_token_id: 句子开头的token id + :param int,None eos_token_id: 句子结束的token id + :param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len + :param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask + :param int num_beams: beam search的大小 + :param bool do_sample: 是否通过采样的方式生成 + :param float temperature: 只有在do_sample为True才有意义 + :param int top_k: 只从top_k中采样 + :param float top_p: 只从top_p的token中采样,nucles sample + :param float repetition_penalty: 多大程度上惩罚重复的token + :param float length_penalty: 对长度的惩罚,小于1鼓励长句,大于1鼓励短剧 + :param int pad_token_id: 当某句话生成结束之后,之后生成的内容用pad_token_id补充 + """ + super().__init__() + self.seq2seq_model = seq2seq_model + self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, max_len_a=max_len_a, + num_beams=num_beams, + do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + repetition_penalty=repetition_penalty, length_penalty=length_penalty, + pad_token_id=pad_token_id) + + def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): + """ + 透传调用seq2seq_model的forward。 + + :param torch.LongTensor src_tokens: bsz x max_len + :param torch.LongTensor tgt_tokens: bsz x max_len' + :param torch.LongTensor src_seq_len: bsz + :param torch.LongTensor tgt_seq_len: bsz + :return: + """ + return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) + + def train_step(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): + res = self(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) + pred = res['pred'] + if tgt_seq_len is not None: + mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1)) + tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100) + loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens) + return {'loss': loss} + + def evaluate_step(self, src_tokens, src_seq_len=None): + """ + 给定source的内容,输出generate的内容。 + + :param torch.LongTensor src_tokens: bsz x max_len + :param torch.LongTensor src_seq_len: bsz + :return: + """ + state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len) + result = self.generator.generate(state) + return {'pred': result} diff --git a/fastNLP/models/torch/seq2seq_model.py b/fastNLP/models/torch/seq2seq_model.py new file mode 100755 index 00000000..057fb93b --- /dev/null +++ b/fastNLP/models/torch/seq2seq_model.py @@ -0,0 +1,196 @@ +r""" +主要包含组成Sequence-to-Sequence的model + +""" + +import torch +from 
torch import nn +import torch.nn.functional as F + +from fastNLP import seq_len_to_mask +from ...embeddings.torch.utils import get_embeddings +from ...embeddings.torch.utils import get_sinusoid_encoding_table +from ...modules.torch.decoder.seq2seq_decoder import Seq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder +from ...modules.torch.encoder.seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder + + +__all__ = ['Seq2SeqModel', 'TransformerSeq2SeqModel', 'LSTMSeq2SeqModel'] + + +class Seq2SeqModel(nn.Module): + def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder): + """ + 可以用于在Trainer中训练的Seq2Seq模型。正常情况下,继承了该函数之后,只需要实现classmethod build_model即可。如果需要使用该模型 + 进行生成,需要把该模型输入到 :class:`~fastNLP.models.SequenceGeneratorModel` 中。在本模型中,forward()会把encoder后的 + 结果传入到decoder中,并将decoder的输出output出来。 + + :param encoder: Seq2SeqEncoder 对象,需要实现对应的forward()函数,接受两个参数,第一个为bsz x max_len的source tokens, 第二个为 + bsz的source的长度;需要返回两个tensor: encoder_outputs: bsz x max_len x hidden_size, encoder_mask: bsz x max_len + 为1的地方需要被attend。如果encoder的输出或者输入有变化,可以重载本模型的prepare_state()函数或者forward()函数 + :param decoder: Seq2SeqDecoder 对象,需要实现init_state()函数,输出为两个参数,第一个为bsz x max_len x hidden_size是 + encoder的输出; 第二个为bsz x max_len,为encoder输出的mask,为0的地方为pad。若decoder需要更多输入,请重载当前模型的 + prepare_state()或forward()函数 + """ + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): + """ + + :param torch.LongTensor src_tokens: source的token + :param torch.LongTensor tgt_tokens: target的token + :param torch.LongTensor src_seq_len: src的长度 + :param torch.LongTensor tgt_seq_len: target的长度,默认用不上 + :return: {'pred': torch.Tensor}, 其中pred的shape为bsz x max_len x vocab_size + """ + state = self.prepare_state(src_tokens, src_seq_len) + decoder_output = self.decoder(tgt_tokens, state) + if isinstance(decoder_output, torch.Tensor): + return {'pred': decoder_output} + elif isinstance(decoder_output, (tuple, list)): + return {'pred': decoder_output[0]} + else: + raise TypeError(f"Unsupported return type from Decoder:{type(self.decoder)}") + + def train_step(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): + res = self(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) + pred = res['pred'] + if tgt_seq_len is not None: + mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1)) + tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100) + loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens) + return {'loss': loss} + + def prepare_state(self, src_tokens, src_seq_len=None): + """ + 调用encoder获取state,会把encoder的encoder_output, encoder_mask直接传入到decoder.init_state中初始化一个state + + :param src_tokens: + :param src_seq_len: + :return: + """ + encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len) + state = self.decoder.init_state(encoder_output, encoder_mask) + return state + + @classmethod + def build_model(cls, *args, **kwargs): + """ + 需要实现本方法来进行Seq2SeqModel的初始化 + + :return: + """ + raise NotImplemented + + +class TransformerSeq2SeqModel(Seq2SeqModel): + """ + Encoder为TransformerSeq2SeqEncoder, decoder为TransformerSeq2SeqDecoder,通过build_model方法初始化 + + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @classmethod + def build_model(cls, src_embed, tgt_embed=None, + pos_embed='sin', max_position=1024, num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1, + bind_encoder_decoder_embed=False, + bind_decoder_input_output_embed=True): + """ + 
初始化一个TransformerSeq2SeqModel + + :param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: source的embedding + :param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: target的embedding,如果bind_encoder_decoder_embed为 + True,则不要输入该值 + :param str pos_embed: 支持sin, learned两种 + :param int max_position: 最大支持长度 + :param int num_layers: encoder和decoder的层数 + :param int d_model: encoder和decoder输入输出的大小 + :param int n_head: encoder和decoder的head的数量 + :param int dim_ff: encoder和decoder中FFN中间映射的维度 + :param float dropout: Attention和FFN dropout的大小 + :param bool bind_encoder_decoder_embed: 是否对encoder和decoder使用相同的embedding + :param bool bind_decoder_input_output_embed: decoder的输出embedding是否与其输入embedding是一样的权重 + :return: TransformerSeq2SeqModel + """ + if bind_encoder_decoder_embed and tgt_embed is not None: + raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") + + src_embed = get_embeddings(src_embed) + + if bind_encoder_decoder_embed: + tgt_embed = src_embed + else: + assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" + tgt_embed = get_embeddings(tgt_embed) + + if pos_embed == 'sin': + encoder_pos_embed = nn.Embedding.from_pretrained( + get_sinusoid_encoding_table(max_position + 1, src_embed.embedding_dim, padding_idx=0), + freeze=True) # 这里规定0是padding + deocder_pos_embed = nn.Embedding.from_pretrained( + get_sinusoid_encoding_table(max_position + 1, tgt_embed.embedding_dim, padding_idx=0), + freeze=True) # 这里规定0是padding + elif pos_embed == 'learned': + encoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=0) + deocder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=1) + else: + raise ValueError("pos_embed only supports sin or learned.") + + encoder = TransformerSeq2SeqEncoder(embed=src_embed, pos_embed=encoder_pos_embed, + num_layers=num_layers, d_model=d_model, n_head=n_head, dim_ff=dim_ff, + dropout=dropout) + decoder = TransformerSeq2SeqDecoder(embed=tgt_embed, pos_embed=deocder_pos_embed, + d_model=d_model, num_layers=num_layers, n_head=n_head, dim_ff=dim_ff, + dropout=dropout, + bind_decoder_input_output_embed=bind_decoder_input_output_embed) + + return cls(encoder, decoder) + + +class LSTMSeq2SeqModel(Seq2SeqModel): + """ + 使用LSTMSeq2SeqEncoder和LSTMSeq2SeqDecoder的model + + """ + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @classmethod + def build_model(cls, src_embed, tgt_embed=None, + num_layers = 3, hidden_size = 400, dropout = 0.3, bidirectional=True, + attention=True, bind_encoder_decoder_embed=False, + bind_decoder_input_output_embed=True): + """ + + :param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: source的embedding + :param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: target的embedding,如果bind_encoder_decoder_embed为 + True,则不要输入该值 + :param int num_layers: Encoder和Decoder的层数 + :param int hidden_size: encoder和decoder的隐藏层大小 + :param float dropout: 每层之间的Dropout的大小 + :param bool bidirectional: encoder是否使用双向LSTM + :param bool attention: decoder是否使用attention attend encoder在所有时刻的状态 + :param bool bind_encoder_decoder_embed: 是否对encoder和decoder使用相同的embedding + :param bool bind_decoder_input_output_embed: decoder的输出embedding是否与其输入embedding是一样的权重 + :return: LSTMSeq2SeqModel + """ + if bind_encoder_decoder_embed and tgt_embed is not None: + raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") + + src_embed = 
get_embeddings(src_embed) + + if bind_encoder_decoder_embed: + tgt_embed = src_embed + else: + assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" + tgt_embed = get_embeddings(tgt_embed) + + encoder = LSTMSeq2SeqEncoder(embed=src_embed, num_layers = num_layers, + hidden_size = hidden_size, dropout = dropout, bidirectional=bidirectional) + decoder = LSTMSeq2SeqDecoder(embed=tgt_embed, num_layers = num_layers, hidden_size = hidden_size, + dropout = dropout, bind_decoder_input_output_embed = bind_decoder_input_output_embed, + attention=attention) + return cls(encoder, decoder) diff --git a/fastNLP/models/torch/sequence_labeling.py b/fastNLP/models/torch/sequence_labeling.py new file mode 100755 index 00000000..48c3519b --- /dev/null +++ b/fastNLP/models/torch/sequence_labeling.py @@ -0,0 +1,271 @@ +r""" +本模块实现了几种序列标注模型 +""" +__all__ = [ + "SeqLabeling", + "AdvSeqLabel", + "BiLSTMCRF" +] + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...core.utils import seq_len_to_mask +from ...embeddings.torch.utils import get_embeddings +from ...modules.torch.decoder import ConditionalRandomField +from ...modules.torch.encoder import LSTM +from ...modules.torch import decoder, encoder +from ...modules.torch.decoder.crf import allowed_transitions + + +class BiLSTMCRF(nn.Module): + r""" + 结构为embedding + BiLSTM + FC + Dropout + CRF. + + """ + def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5, + target_vocab=None): + r""" + + :param embed: 支持(1)fastNLP的各种Embedding, (2) tuple, 指明num_embedding, dimension, 如(1000, 100) + :param num_classes: 一共多少个类 + :param num_layers: BiLSTM的层数 + :param hidden_size: BiLSTM的hidden_size,实际hidden size为该值的两倍(前向、后向) + :param dropout: dropout的概率,0为不dropout + :param target_vocab: Vocabulary对象,target与index的对应关系。如果传入该值,将自动避免非法的解码序列。 + """ + super().__init__() + self.embed = get_embeddings(embed) + + if num_layers>1: + self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True, + batch_first=True, dropout=dropout) + else: + self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True, + batch_first=True) + + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(hidden_size*2, num_classes) + + trans = None + if target_vocab is not None: + assert len(target_vocab)==num_classes, "The number of classes should be same with the length of target vocabulary." 
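+            # derive the set of legal (from_tag, to_tag) transitions from the target vocabulary,
+            # so that Viterbi decoding cannot produce a label sequence that violates the tagging scheme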
+ trans = allowed_transitions(target_vocab.idx2word, include_start_end=True) + + self.crf = ConditionalRandomField(num_classes, include_start_end_trans=True, allowed_transitions=trans) + + def forward(self, words, seq_len=None, target=None): + words = self.embed(words) + feats, _ = self.lstm(words, seq_len=seq_len) + feats = self.fc(feats) + feats = self.dropout(feats) + logits = F.log_softmax(feats, dim=-1) + mask = seq_len_to_mask(seq_len) + if target is None: + pred, _ = self.crf.viterbi_decode(logits, mask) + return {'pred':pred} + else: + loss = self.crf(logits, target, mask).mean() + return {'loss':loss} + + def train_step(self, words, seq_len, target): + return self(words, seq_len, target) + + def evaluate_step(self, words, seq_len): + return self(words, seq_len) + + +class SeqLabeling(nn.Module): + r""" + 一个基础的Sequence labeling的模型。 + 用于做sequence labeling的基础类。结构包含一层Embedding,一层LSTM(单向,一层),一层FC,以及一层CRF。 + + """ + + def __init__(self, embed, hidden_size, num_classes): + r""" + + :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), + 第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, embedding, ndarray等则直接使用该值初始化Embedding + :param int hidden_size: LSTM隐藏层的大小 + :param int num_classes: 一共有多少类 + """ + super(SeqLabeling, self).__init__() + + self.embedding = get_embeddings(embed) + self.rnn = encoder.LSTM(self.embedding.embedding_dim, hidden_size) + self.fc = nn.Linear(hidden_size, num_classes) + self.crf = decoder.ConditionalRandomField(num_classes) + + def forward(self, words, seq_len): + r""" + :param torch.LongTensor words: [batch_size, max_len],序列的index + :param torch.LongTensor seq_len: [batch_size,], 这个序列的长度 + :return + """ + x = self.embedding(words) + # [batch_size, max_len, word_emb_dim] + x, _ = self.rnn(x, seq_len) + # [batch_size, max_len, hidden_size * direction] + x = self.fc(x) + return {'pred': x} + # [batch_size, max_len, num_classes] + + def train_step(self, words, seq_len, target): + res = self(words, seq_len) + pred = res['pred'] + mask = seq_len_to_mask(seq_len, max_len=target.size(1)) + return {'loss': self._internal_loss(pred, target, mask)} + + def evaluate_step(self, words, seq_len): + r""" + 用于在预测时使用 + + :param torch.LongTensor words: [batch_size, max_len] + :param torch.LongTensor seq_len: [batch_size,] + :return: {'pred': xx}, [batch_size, max_len] + """ + mask = seq_len_to_mask(seq_len, max_len=words.size(1)) + + res = self(words, seq_len) + pred = res['pred'] + # [batch_size, max_len, num_classes] + pred = self._decode(pred, mask) + return {'pred': pred} + + def _internal_loss(self, x, y, mask): + r""" + Negative log likelihood loss. 
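+        (the CRF negative log-likelihood, averaged over the batch; padded positions are excluded via the mask)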
+ :param x: Tensor, [batch_size, max_len, tag_size] + :param y: Tensor, [batch_size, max_len] + :return loss: a scalar Tensor + + """ + x = x.float() + y = y.long() + total_loss = self.crf(x, y, mask) + return torch.mean(total_loss) + + def _decode(self, x, mask): + r""" + :param torch.FloatTensor x: [batch_size, max_len, tag_size] + :return prediction: [batch_size, max_len] + """ + tag_seq, _ = self.crf.viterbi_decode(x, mask) + return tag_seq + + +class AdvSeqLabel(nn.Module): + r""" + 更复杂的Sequence Labelling模型。结构为Embedding, LayerNorm, 双向LSTM(两层),FC,LayerNorm,DropOut,FC,CRF。 + """ + + def __init__(self, embed, hidden_size, num_classes, dropout=0.3, id2words=None, encoding_type='bmes'): + r""" + + :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), + 第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding + :param int hidden_size: LSTM的隐层大小 + :param int num_classes: 有多少个类 + :param float dropout: LSTM中以及DropOut层的drop概率 + :param dict id2words: tag id转为其tag word的表。用于在CRF解码时防止解出非法的顺序,比如'BMES'这个标签规范中,'S' + 不能出现在'B'之后。这里也支持类似与'B-NN',即'-'前为标签类型的指示,后面为具体的tag的情况。这里不但会保证 + 'B-NN'后面不为'S-NN'还会保证'B-NN'后面不会出现'M-xx'(任何非'M-NN'和'E-NN'的情况。) + :param str encoding_type: 支持"BIO", "BMES", "BEMSO", 只有在id2words不为None的情况有用。 + """ + super().__init__() + + self.Embedding = get_embeddings(embed) + self.norm1 = torch.nn.LayerNorm(self.Embedding.embedding_dim) + self.Rnn = encoder.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size, num_layers=2, + dropout=dropout, + bidirectional=True, batch_first=True) + self.Linear1 = nn.Linear(hidden_size * 2, hidden_size * 2 // 3) + self.norm2 = torch.nn.LayerNorm(hidden_size * 2 // 3) + self.relu = torch.nn.LeakyReLU() + self.drop = torch.nn.Dropout(dropout) + self.Linear2 = nn.Linear(hidden_size * 2 // 3, num_classes) + + if id2words is None: + self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False) + else: + self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False, + allowed_transitions=allowed_transitions(id2words, + encoding_type=encoding_type)) + + def _decode(self, x, mask): + r""" + :param torch.FloatTensor x: [batch_size, max_len, tag_size] + :param torch.ByteTensor mask: [batch_size, max_len] + :return torch.LongTensor, [batch_size, max_len] + """ + tag_seq, _ = self.Crf.viterbi_decode(x, mask) + return tag_seq + + def _internal_loss(self, x, y, mask): + r""" + Negative log likelihood loss. + :param x: Tensor, [batch_size, max_len, tag_size] + :param y: Tensor, [batch_size, max_len] + :param mask: Tensor, [batch_size, max_len] + :return loss: a scalar Tensor + + """ + x = x.float() + y = y.long() + total_loss = self.Crf(x, y, mask) + return torch.mean(total_loss) + + def forward(self, words, seq_len, target=None): + r""" + :param torch.LongTensor words: [batch_size, mex_len] + :param torch.LongTensor seq_len:[batch_size, ] + :param torch.LongTensor target: [batch_size, max_len] + :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. + If truth is not None, return loss, a scalar. Used in training. 
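+            (the actual return value is a dict: ``{'pred': ...}`` when target is None, otherwise ``{'loss': ...}``)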
+ """ + + words = words.long() + seq_len = seq_len.long() + mask = seq_len_to_mask(seq_len, max_len=words.size(1)) + + target = target.long() if target is not None else None + + if next(self.parameters()).is_cuda: + words = words.cuda() + + x = self.Embedding(words) + x = self.norm1(x) + # [batch_size, max_len, word_emb_dim] + + x, _ = self.Rnn(x, seq_len=seq_len) + + x = self.Linear1(x) + x = self.norm2(x) + x = self.relu(x) + x = self.drop(x) + x = self.Linear2(x) + if target is not None: + return {"loss": self._internal_loss(x, target, mask)} + else: + return {"pred": self._decode(x, mask)} + + def train_step(self, words, seq_len, target): + r""" + + :param torch.LongTensor words: [batch_size, mex_len] + :param torch.LongTensor seq_len: [batch_size, ] + :param torch.LongTensor target: [batch_size, max_len], 目标 + :return torch.Tensor: a scalar loss + """ + return self(words, seq_len, target) + + def evaluate_step(self, words, seq_len): + r""" + + :param torch.LongTensor words: [batch_size, mex_len] + :param torch.LongTensor seq_len: [batch_size, ] + :return torch.LongTensor: [batch_size, max_len] + """ + return self(words, seq_len) diff --git a/fastNLP/modules/torch/__init__.py b/fastNLP/modules/torch/__init__.py new file mode 100755 index 00000000..da92ab9c --- /dev/null +++ b/fastNLP/modules/torch/__init__.py @@ -0,0 +1,26 @@ +__all__ = [ + 'ConditionalRandomField', + 'allowed_transitions', + "State", + "Seq2SeqDecoder", + "LSTMSeq2SeqDecoder", + "TransformerSeq2SeqDecoder", + + "LSTM", + "Seq2SeqEncoder", + "TransformerSeq2SeqEncoder", + "LSTMSeq2SeqEncoder", + "StarTransformer", + "VarRNN", + "VarLSTM", + "VarGRU", + + 'SequenceGenerator', + + "TimestepDropout", +] + +from .decoder import * +from .encoder import * +from .generator import * +from .dropout import TimestepDropout diff --git a/fastNLP/modules/torch/attention.py b/fastNLP/modules/torch/attention.py new file mode 100755 index 00000000..2e194b0b --- /dev/null +++ b/fastNLP/modules/torch/attention.py @@ -0,0 +1,321 @@ +r"""undocumented""" + +__all__ = [ + "MultiHeadAttention", + "BiAttention", + "SelfAttention", +] + +import math + +import torch +import torch.nn.functional as F +from torch import nn + +from .decoder.seq2seq_state import TransformerState + + +class DotAttention(nn.Module): + r""" + Transformer当中的DotAttention + """ + + def __init__(self, key_size, value_size, dropout=0.0): + super(DotAttention, self).__init__() + self.key_size = key_size + self.value_size = value_size + self.scale = math.sqrt(key_size) + self.drop = nn.Dropout(dropout) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, Q, K, V, mask_out=None): + r""" + + :param Q: [..., seq_len_q, key_size] + :param K: [..., seq_len_k, key_size] + :param V: [..., seq_len_k, value_size] + :param mask_out: [..., 1, seq_len] or [..., seq_len_q, seq_len_k] + """ + output = torch.matmul(Q, K.transpose(-1, -2)) / self.scale + if mask_out is not None: + output.masked_fill_(mask_out, -1e9) + output = self.softmax(output) + output = self.drop(output) + return torch.matmul(output, V) + + +class MultiHeadAttention(nn.Module): + """ + Attention is all you need中提到的多头注意力 + + """ + def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None): + super(MultiHeadAttention, self).__init__() + self.d_model = d_model + self.n_head = n_head + self.dropout = dropout + self.head_dim = d_model // n_head + self.layer_idx = layer_idx + assert d_model % n_head == 0, "d_model should be divisible by n_head" + self.scaling = self.head_dim ** 
-0.5 + + self.q_proj = nn.Linear(d_model, d_model) + self.k_proj = nn.Linear(d_model, d_model) + self.v_proj = nn.Linear(d_model, d_model) + self.out_proj = nn.Linear(d_model, d_model) + + self.reset_parameters() + + def forward(self, query, key, value, key_mask=None, attn_mask=None, state=None): + """ + + :param query: batch x seq x dim + :param key: batch x seq x dim + :param value: batch x seq x dim + :param key_mask: batch x seq 用于指示哪些key不要attend到;注意到mask为1的地方是要attend到的 + :param attn_mask: seq x seq, 用于mask掉attention map。 主要是用在训练时decoder端的self attention,下三角为1 + :param state: 过去的信息,在inference的时候会用到,比如encoder output、decoder的prev kv。这样可以减少计算。 + :return: + """ + assert key.size() == value.size() + if state is not None: + assert self.layer_idx is not None + qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() + + q = self.q_proj(query) # batch x seq x dim + q *= self.scaling + k = v = None + prev_k = prev_v = None + + # 从state中取kv + if isinstance(state, TransformerState): # 说明此时在inference阶段 + if qkv_same: # 此时在decoder self attention + prev_k = state.decoder_prev_key[self.layer_idx] + prev_v = state.decoder_prev_value[self.layer_idx] + else: # 此时在decoder-encoder attention,直接将保存下来的key装载起来即可 + k = state.encoder_key[self.layer_idx] + v = state.encoder_value[self.layer_idx] + + if k is None: + k = self.k_proj(key) + v = self.v_proj(value) + + if prev_k is not None: + k = torch.cat((prev_k, k), dim=1) + v = torch.cat((prev_v, v), dim=1) + + # 更新state + if isinstance(state, TransformerState): + if qkv_same: + state.decoder_prev_key[self.layer_idx] = k + state.decoder_prev_value[self.layer_idx] = v + else: + state.encoder_key[self.layer_idx] = k + state.encoder_value[self.layer_idx] = v + + # 开始计算attention + batch_size, q_len, d_model = query.size() + k_len, v_len = k.size(1), v.size(1) + q = q.reshape(batch_size, q_len, self.n_head, self.head_dim) + k = k.reshape(batch_size, k_len, self.n_head, self.head_dim) + v = v.reshape(batch_size, v_len, self.n_head, self.head_dim) + + attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k) # bs,q_len,k_len,n_head + if key_mask is not None: + _key_mask = ~key_mask[:, None, :, None].bool() # batch,1,k_len,1 + attn_weights = attn_weights.masked_fill(_key_mask, -float('inf')) + + if attn_mask is not None: + _attn_mask = attn_mask[None, :, :, None].eq(0) # 1,q_len,k_len,n_head + attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf')) + + attn_weights = F.softmax(attn_weights, dim=2) + attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) + + output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v) # batch,q_len,n_head,head_dim + output = output.reshape(batch_size, q_len, -1) + output = self.out_proj(output) # batch,q_len,dim + + return output, attn_weights + + def reset_parameters(self): + nn.init.xavier_uniform_(self.q_proj.weight) + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.out_proj.weight) + + def set_layer_idx(self, layer_idx): + self.layer_idx = layer_idx + + +class AttentionLayer(nn.Module): + def __init__(selfu, input_size, key_dim, value_dim, bias=False): + """ + 可用于LSTM2LSTM的序列到序列模型的decode过程中,该attention是在decode过程中根据上一个step的hidden计算对encoder结果的attention + + :param int input_size: 输入的大小 + :param int key_dim: 一般就是encoder_output输出的维度 + :param int value_dim: 输出的大小维度, 一般就是decoder hidden的大小 + :param bias: + """ + super().__init__() + + selfu.input_proj = nn.Linear(input_size, key_dim, bias=bias) + selfu.output_proj = nn.Linear(input_size 
+ key_dim, value_dim, bias=bias) + + def forward(self, input, encode_outputs, encode_mask): + """ + + :param input: batch_size x input_size + :param encode_outputs: batch_size x max_len x key_dim + :param encode_mask: batch_size x max_len, 为0的地方为padding + :return: hidden: batch_size x value_dim, scores: batch_size x max_len, normalized过的 + """ + + # x: bsz x encode_hidden_size + x = self.input_proj(input) + + # compute attention + attn_scores = torch.matmul(encode_outputs, x.unsqueeze(-1)).squeeze(-1) # b x max_len + + # don't attend over padding + if encode_mask is not None: + attn_scores = attn_scores.float().masked_fill_( + encode_mask.eq(0), + float('-inf') + ).type_as(attn_scores) # FP16 support: cast to float and back + + attn_scores = F.softmax(attn_scores, dim=-1) # srclen x bsz + + # sum weighted sources + x = torch.matmul(attn_scores.unsqueeze(1), encode_outputs).squeeze(1) # b x encode_hidden_size + + x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1))) + return x, attn_scores + + +def _masked_softmax(tensor, mask): + tensor_shape = tensor.size() + reshaped_tensor = tensor.view(-1, tensor_shape[-1]) + + # Reshape the mask so it matches the size of the input tensor. + while mask.dim() < tensor.dim(): + mask = mask.unsqueeze(1) + mask = mask.expand_as(tensor).contiguous().float() + reshaped_mask = mask.view(-1, mask.size()[-1]) + result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1) + result = result * reshaped_mask + # 1e-13 is added to avoid divisions by zero. + result = result / (result.sum(dim=-1, keepdim=True) + 1e-13) + return result.view(*tensor_shape) + + +def _weighted_sum(tensor, weights, mask): + w_sum = weights.bmm(tensor) + while mask.dim() < w_sum.dim(): + mask = mask.unsqueeze(1) + mask = mask.transpose(-1, -2) + mask = mask.expand_as(w_sum).contiguous().float() + return w_sum * mask + + +class BiAttention(nn.Module): + r""" + Bi Attention module + + 对于给定的两个向量序列 :math:`a_i` 和 :math:`b_j` , BiAttention模块将通过以下的公式来计算attention结果 + + .. math:: + + \begin{array}{ll} \\ + e_{ij} = {a}^{\mathrm{T}}_{i}{b}_{j} \\ + {\hat{a}}_{i} = \sum_{j=1}^{\mathcal{l}_{b}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{b}}{\mathrm{exp}(e_{ik})}}}{b}_{j} \\ + {\hat{b}}_{j} = \sum_{i=1}^{\mathcal{l}_{a}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{a}}{\mathrm{exp}(e_{ik})}}}{a}_{i} \\ + \end{array} + + """ + + def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): + r""" + :param torch.Tensor premise_batch: [batch_size, a_seq_len, hidden_size] + :param torch.Tensor premise_mask: [batch_size, a_seq_len] + :param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size] + :param torch.Tensor hypothesis_mask: [batch_size, b_seq_len] + :return: torch.Tensor attended_premises: [batch_size, a_seq_len, hidden_size] torch.Tensor attended_hypotheses: [batch_size, b_seq_len, hidden_size] + """ + similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1) + .contiguous()) + + prem_hyp_attn = _masked_softmax(similarity_matrix, hypothesis_mask) + hyp_prem_attn = _masked_softmax(similarity_matrix.transpose(1, 2) + .contiguous(), + premise_mask) + + attended_premises = _weighted_sum(hypothesis_batch, + prem_hyp_attn, + premise_mask) + attended_hypotheses = _weighted_sum(premise_batch, + hyp_prem_attn, + hypothesis_mask) + + return attended_premises, attended_hypotheses + + +class SelfAttention(nn.Module): + r""" + 这是一个基于论文 `A structured self-attentive sentence embedding `_ + 的Self Attention Module. 
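+
+    Example (illustrative; all sizes below are arbitrary)::
+
+        import torch
+        attn = SelfAttention(input_size=256, attention_unit=128, attention_hops=4)
+        x = torch.randn(2, 10, 256)               # [batch_size, seq_len, hidden_size]
+        tokens = torch.randint(1, 100, (2, 10))   # original token ids, 0 means padding
+        out, penalty = attn(x, tokens)            # out: [2, 4, 256], penalty: a scalar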
+ """ + + def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5): + r""" + + :param int input_size: 输入tensor的hidden维度 + :param int attention_unit: 输出tensor的hidden维度 + :param int attention_hops: + :param float drop: dropout概率,默认值为0.5 + """ + super(SelfAttention, self).__init__() + + self.attention_hops = attention_hops + self.ws1 = nn.Linear(input_size, attention_unit, bias=False) + self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False) + self.I = torch.eye(attention_hops, requires_grad=False) + self.I_origin = self.I + self.drop = nn.Dropout(drop) + self.tanh = nn.Tanh() + + def _penalization(self, attention): + r""" + compute the penalization term for attention module + """ + baz = attention.size(0) + size = self.I.size() + if len(size) != 3 or size[0] != baz: + self.I = self.I_origin.expand(baz, -1, -1) + self.I = self.I.to(device=attention.device) + attention_t = torch.transpose(attention, 1, 2).contiguous() + mat = torch.bmm(attention, attention_t) - self.I[:attention.size(0)] + ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5 + return torch.sum(ret) / size[0] + + def forward(self, input, input_origin): + r""" + :param torch.Tensor input: [batch_size, seq_len, hidden_size] 要做attention的矩阵 + :param torch.Tensor input_origin: [batch_size, seq_len] 原始token的index组成的矩阵,含有pad部分内容 + :return torch.Tensor output1: [batch_size, multi-head, hidden_size] 经过attention操作后输入矩阵的结果 + :return torch.Tensor output2: [1] attention惩罚项,是一个标量 + """ + input = input.contiguous() + size = input.size() # [bsz, len, nhid] + + input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops,baz, len] + input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len] + + y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit] + attention = self.ws2(y1).transpose(1, 2).contiguous() + # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len] + + attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token. 
+ attention = F.softmax(attention, 2) # [baz ,hop, len] + return torch.bmm(attention, input), self._penalization(attention) # output1 --> [baz ,hop ,nhid] diff --git a/fastNLP/modules/torch/decoder/__init__.py b/fastNLP/modules/torch/decoder/__init__.py new file mode 100755 index 00000000..8181d271 --- /dev/null +++ b/fastNLP/modules/torch/decoder/__init__.py @@ -0,0 +1,15 @@ + +__all__ = [ + 'ConditionalRandomField', + 'allowed_transitions', + + "State", + + "Seq2SeqDecoder", + "LSTMSeq2SeqDecoder", + "TransformerSeq2SeqDecoder" +] + +from .crf import ConditionalRandomField, allowed_transitions +from .seq2seq_state import State +from .seq2seq_decoder import LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder, Seq2SeqDecoder \ No newline at end of file diff --git a/fastNLP/modules/torch/decoder/crf.py b/fastNLP/modules/torch/decoder/crf.py new file mode 100755 index 00000000..8c6b8858 --- /dev/null +++ b/fastNLP/modules/torch/decoder/crf.py @@ -0,0 +1,354 @@ +r"""undocumented""" + +__all__ = [ + "ConditionalRandomField", + "allowed_transitions" +] + +from typing import Union, List + +import torch +from torch import nn + +from ....core.metrics.span_f1_pre_rec_metric import _get_encoding_type_from_tag_vocab, _check_tag_vocab_and_encoding_type +from ....core.vocabulary import Vocabulary + + +def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type:str=None, include_start_end:bool=False): + r""" + 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 + + :param tag_vocab: 支持类型为tag或tag-label。只有tag的,比如"B", "M"; 也可以是"B-NN", "M-NN", + tag和label之间一定要用"-"隔开。如果传入dict,格式需要形如{0:"O", 1:"B-tag1"},即index在前,tag在后。 + :param encoding_type: 支持``["bio", "bmes", "bmeso", "bioes"]``。默认为None,通过vocab自动推断 + :param include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; + 为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); + start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 + :return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 + """ + if encoding_type is None: + encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) + else: + encoding_type = encoding_type.lower() + _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) + + pad_token = '' + unk_token = '' + + if isinstance(tag_vocab, Vocabulary): + id_label_lst = list(tag_vocab.idx2word.items()) + pad_token = tag_vocab.padding + unk_token = tag_vocab.unknown + else: + id_label_lst = list(tag_vocab.items()) + + num_tags = len(tag_vocab) + start_idx = num_tags + end_idx = num_tags + 1 + allowed_trans = [] + if include_start_end: + id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] + def split_tag_label(from_label): + from_label = from_label.lower() + if from_label in ['start', 'end']: + from_tag = from_label + from_label = '' + else: + from_tag = from_label[:1] + from_label = from_label[2:] + return from_tag, from_label + + for from_id, from_label in id_label_lst: + if from_label in [pad_token, unk_token]: + continue + from_tag, from_label = split_tag_label(from_label) + for to_id, to_label in id_label_lst: + if to_label in [pad_token, unk_token]: + continue + to_tag, to_label = split_tag_label(to_label) + if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): + allowed_trans.append((from_id, to_id)) + return allowed_trans + + +def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): + r""" + + :param str encoding_type: 支持"BIO", "BMES", "BEMSO", 'bioes'。 + :param str from_tag: 比如"B", "M"之类的标注tag. 
还包括start, end等两种特殊tag + :param str from_label: 比如"PER", "LOC"等label + :param str to_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag + :param str to_label: 比如"PER", "LOC"等label + :return: bool,能否跃迁 + """ + if to_tag == 'start' or from_tag == 'end': + return False + encoding_type = encoding_type.lower() + if encoding_type == 'bio': + r""" + 第一行是to_tag, 第一列是from_tag. y任意条件下可转,-只有在label相同时可转,n不可转 + +-------+---+---+---+-------+-----+ + | | B | I | O | start | end | + +-------+---+---+---+-------+-----+ + | B | y | - | y | n | y | + +-------+---+---+---+-------+-----+ + | I | y | - | y | n | y | + +-------+---+---+---+-------+-----+ + | O | y | n | y | n | y | + +-------+---+---+---+-------+-----+ + | start | y | n | y | n | n | + +-------+---+---+---+-------+-----+ + | end | n | n | n | n | n | + +-------+---+---+---+-------+-----+ + """ + if from_tag == 'start': + return to_tag in ('b', 'o') + elif from_tag in ['b', 'i']: + return any([to_tag in ['end', 'b', 'o'], to_tag == 'i' and from_label == to_label]) + elif from_tag == 'o': + return to_tag in ['end', 'b', 'o'] + else: + raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) + + elif encoding_type == 'bmes': + r""" + 第一行是to_tag, 第一列是from_tag,y任意条件下可转,-只有在label相同时可转,n不可转 + +-------+---+---+---+---+-------+-----+ + | | B | M | E | S | start | end | + +-------+---+---+---+---+-------+-----+ + | B | n | - | - | n | n | n | + +-------+---+---+---+---+-------+-----+ + | M | n | - | - | n | n | n | + +-------+---+---+---+---+-------+-----+ + | E | y | n | n | y | n | y | + +-------+---+---+---+---+-------+-----+ + | S | y | n | n | y | n | y | + +-------+---+---+---+---+-------+-----+ + | start | y | n | n | y | n | n | + +-------+---+---+---+---+-------+-----+ + | end | n | n | n | n | n | n | + +-------+---+---+---+---+-------+-----+ + """ + if from_tag == 'start': + return to_tag in ['b', 's'] + elif from_tag == 'b': + return to_tag in ['m', 'e'] and from_label == to_label + elif from_tag == 'm': + return to_tag in ['m', 'e'] and from_label == to_label + elif from_tag in ['e', 's']: + return to_tag in ['b', 's', 'end'] + else: + raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) + elif encoding_type == 'bmeso': + if from_tag == 'start': + return to_tag in ['b', 's', 'o'] + elif from_tag == 'b': + return to_tag in ['m', 'e'] and from_label == to_label + elif from_tag == 'm': + return to_tag in ['m', 'e'] and from_label == to_label + elif from_tag in ['e', 's', 'o']: + return to_tag in ['b', 's', 'end', 'o'] + else: + raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) + elif encoding_type == 'bioes': + if from_tag == 'start': + return to_tag in ['b', 's', 'o'] + elif from_tag == 'b': + return to_tag in ['i', 'e'] and from_label == to_label + elif from_tag == 'i': + return to_tag in ['i', 'e'] and from_label == to_label + elif from_tag in ['e', 's', 'o']: + return to_tag in ['b', 's', 'end', 'o'] + else: + raise ValueError("Unexpect tag type {}. 
Expect only 'B', 'I', 'E', 'S', 'O'.".format(from_tag)) + else: + raise ValueError("Only support BIO, BMES, BMESO, BIOES encoding type, got {}.".format(encoding_type)) + + +class ConditionalRandomField(nn.Module): + r""" + 条件随机场。提供 forward() 以及 viterbi_decode() 两个方法,分别用于训练与inference。 + + """ + + def __init__(self, num_tags:int, include_start_end_trans:bool=False, allowed_transitions:List=None): + r""" + + :param num_tags: 标签的数量 + :param include_start_end_trans: 是否考虑各个tag作为开始以及结尾的分数。 + :param allowed_transitions: 内部的Tuple[from_tag_id(int), + to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 + allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 + """ + super(ConditionalRandomField, self).__init__() + + self.include_start_end_trans = include_start_end_trans + self.num_tags = num_tags + + # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score + self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags)) + if self.include_start_end_trans: + self.start_scores = nn.Parameter(torch.randn(num_tags)) + self.end_scores = nn.Parameter(torch.randn(num_tags)) + + if allowed_transitions is None: + constrain = torch.zeros(num_tags + 2, num_tags + 2) + else: + constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float) + has_start = False + has_end = False + for from_tag_id, to_tag_id in allowed_transitions: + constrain[from_tag_id, to_tag_id] = 0 + if from_tag_id==num_tags: + has_start = True + if to_tag_id==num_tags+1: + has_end = True + if not has_start: + constrain[num_tags, :].fill_(0) + if not has_end: + constrain[:, num_tags+1].fill_(0) + self._constrain = nn.Parameter(constrain, requires_grad=False) + + def _normalizer_likelihood(self, logits, mask): + r"""Computes the (batch_size,) denominator term for the log-likelihood, which is the + sum of the likelihoods across all possible state sequences. + + :param logits:FloatTensor, max_len x batch_size x num_tags + :param mask:ByteTensor, max_len x batch_size + :return:FloatTensor, batch_size + """ + seq_len, batch_size, n_tags = logits.size() + alpha = logits[0] + if self.include_start_end_trans: + alpha = alpha + self.start_scores.view(1, -1) + + flip_mask = mask.eq(False) + + for i in range(1, seq_len): + emit_score = logits[i].view(batch_size, 1, n_tags) + trans_score = self.trans_m.view(1, n_tags, n_tags) + tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score + alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ + alpha.masked_fill(mask[i].eq(True).view(batch_size, 1), 0) + + if self.include_start_end_trans: + alpha = alpha + self.end_scores.view(1, -1) + + return torch.logsumexp(alpha, 1) + + def _gold_score(self, logits, tags, mask): + r""" + Compute the score for the gold path. 
+ :param logits: FloatTensor, max_len x batch_size x num_tags + :param tags: LongTensor, max_len x batch_size + :param mask: ByteTensor, max_len x batch_size + :return:FloatTensor, batch_size + """ + seq_len, batch_size, _ = logits.size() + batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) + seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) + + # trans_socre [L-1, B] + mask = mask.eq(True) + flip_mask = mask.eq(False) + trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0) + # emit_score [L, B] + emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0) + # score [L-1, B] + score = trans_score + emit_score[:seq_len - 1, :] + score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0) + if self.include_start_end_trans: + st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] + last_idx = mask.long().sum(0) - 1 + ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] + score = score + st_scores + ed_scores + # return [B,] + return score + + def forward(self, feats, tags, mask): + r""" + 用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 + + :param torch.FloatTensor feats: batch_size x max_len x num_tags,特征矩阵。 + :param torch.LongTensor tags: batch_size x max_len,标签矩阵。 + :param torch.ByteTensor mask: batch_size x max_len,为0的位置认为是padding。 + :return: torch.FloatTensor, (batch_size,) + """ + feats = feats.transpose(0, 1) + tags = tags.transpose(0, 1).long() + mask = mask.transpose(0, 1).float() + all_path_score = self._normalizer_likelihood(feats, mask) + gold_path_score = self._gold_score(feats, tags, mask) + + return all_path_score - gold_path_score + + def viterbi_decode(self, logits, mask, unpad=False): + r"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 + + :param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 + :param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 + :param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 + List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 + 个sample的有效长度。 + :return: 返回 (paths, scores)。 + paths: 是解码后的路径, 其值参照unpad参数. 
+ scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 + + """ + batch_size, max_len, n_tags = logits.size() + seq_len = mask.long().sum(1) + logits = logits.transpose(0, 1).data # L, B, H + mask = mask.transpose(0, 1).data.eq(True) # L, B + flip_mask = mask.eq(False) + + # dp + vpath = logits.new_zeros((max_len, batch_size, n_tags), dtype=torch.long) + vscore = logits[0] # bsz x n_tags + transitions = self._constrain.data.clone() + transitions[:n_tags, :n_tags] += self.trans_m.data + if self.include_start_end_trans: + transitions[n_tags, :n_tags] += self.start_scores.data + transitions[:n_tags, n_tags + 1] += self.end_scores.data + + vscore += transitions[n_tags, :n_tags] + + trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data + end_trans_score = transitions[:n_tags, n_tags+1].view(1, 1, n_tags).repeat(batch_size, 1, 1) # bsz, 1, n_tags + + # 针对长度为1的句子 + vscore += transitions[:n_tags, n_tags+1].view(1, n_tags).repeat(batch_size, 1) \ + .masked_fill(seq_len.ne(1).view(-1, 1), 0) + for i in range(1, max_len): + prev_score = vscore.view(batch_size, n_tags, 1) + cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score + score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0) # bsz x n_tag x n_tag + # 需要考虑当前位置是该序列的最后一个 + score += end_trans_score.masked_fill(seq_len.ne(i+1).view(-1, 1, 1), 0) + + best_score, best_dst = score.max(1) + vpath[i] = best_dst + # 由于最终是通过last_tags回溯,需要保持每个位置的vscore情况 + vscore = best_score.masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ + vscore.masked_fill(mask[i].view(batch_size, 1), 0) + + # backtrace + batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) + seq_idx = torch.arange(max_len, dtype=torch.long, device=logits.device) + lens = (seq_len - 1) + # idxes [L, B], batched idx from seq_len-1 to 0 + idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % max_len + + ans = logits.new_empty((max_len, batch_size), dtype=torch.long) + ans_score, last_tags = vscore.max(1) + ans[idxes[0], batch_idx] = last_tags + for i in range(max_len - 1): + last_tags = vpath[idxes[i], batch_idx, last_tags] + ans[idxes[i + 1], batch_idx] = last_tags + ans = ans.transpose(0, 1) + if unpad: + paths = [] + for idx, max_len in enumerate(lens): + paths.append(ans[idx, :max_len + 1].tolist()) + else: + paths = ans + return paths, ans_score diff --git a/fastNLP/modules/torch/decoder/seq2seq_decoder.py b/fastNLP/modules/torch/decoder/seq2seq_decoder.py new file mode 100755 index 00000000..2c1dfe36 --- /dev/null +++ b/fastNLP/modules/torch/decoder/seq2seq_decoder.py @@ -0,0 +1,416 @@ +r"""undocumented""" +from typing import Union, Tuple +import math + +import torch +from torch import nn +import torch.nn.functional as F +from ..attention import AttentionLayer, MultiHeadAttention +from ....embeddings.torch.utils import get_embeddings +from ....embeddings.torch.static_embedding import StaticEmbedding +from .seq2seq_state import State, LSTMState, TransformerState + + +__all__ = ['Seq2SeqDecoder', 'TransformerSeq2SeqDecoder', 'LSTMSeq2SeqDecoder'] + + +class Seq2SeqDecoder(nn.Module): + """ + Sequence-to-Sequence Decoder的基类。一定需要实现forward、decode函数,剩下的函数根据需要实现。每个Seq2SeqDecoder都应该有相应的State对象 + 用来承载该Decoder所需要的Encoder输出、Decoder需要记录的历史信息(例如LSTM的hidden信息)。 + + """ + def __init__(self): + super().__init__() + + def forward(self, tokens, state, **kwargs): + """ + + :param torch.LongTensor tokens: bsz x max_len + :param State state: state包含了encoder的输出以及decode之前的内容 + :return: 返回值可以为bsz x max_len x 
vocab_size的Tensor,也可以是一个list,但是第一个元素必须是词的预测分布 + """ + raise NotImplemented + + def reorder_states(self, indices, states): + """ + 根据indices重新排列states中的状态,在beam search进行生成时,会用到该函数。 + + :param torch.LongTensor indices: + :param State states: + :return: + """ + assert isinstance(states, State), f"`states` should be of type State instead of {type(states)}" + states.reorder_state(indices) + + def init_state(self, encoder_output, encoder_mask): + """ + 初始化一个state对象,用来记录了encoder的输出以及decode已经完成的部分。 + + :param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch + 维度 + :param Union[torch.Tensor, list, tuple] encoder_mask: 如果部位None,内部元素需要torch.Tensor, 默认其中第一维是batch + 维度 + :param kwargs: + :return: State, 返回一个State对象,记录了encoder的输出 + """ + state = State(encoder_output, encoder_mask) + return state + + def decode(self, tokens, state): + """ + 根据states中的内容,以及tokens中的内容进行之后的生成。 + + :param torch.LongTensor tokens: bsz x max_len, 截止到上一个时刻所有的token输出。 + :param State state: 记录了encoder输出与decoder过去状态 + :return: torch.FloatTensor: bsz x vocab_size, 输出的是下一个时刻的分布 + """ + outputs = self(state=state, tokens=tokens) + if isinstance(outputs, torch.Tensor): + return outputs[:, -1] + else: + raise RuntimeError("Unrecognized output from the `forward()` function. Please override the `decode()` function.") + + +class TiedEmbedding(nn.Module): + """ + 用于将weight和原始weight绑定 + + """ + def __init__(self, weight): + super().__init__() + self.weight = weight # vocab_size x embed_size + + def forward(self, x): + """ + + :param torch.FloatTensor x: bsz x * x embed_size + :return: torch.FloatTensor bsz x * x vocab_size + """ + return torch.matmul(x, self.weight.t()) + + +def get_bind_decoder_output_embed(embed): + """ + 给定一个embedding,输出对应的绑定的embedding,输出对象为TiedEmbedding + + :param embed: + :return: + """ + if isinstance(embed, StaticEmbedding): + for idx, map2idx in enumerate(embed.words_to_words): + assert idx == map2idx, "Invalid StaticEmbedding for Decoder, please check:(1) whether the vocabulary " \ + "include `no_create_entry=True` word; (2) StaticEmbedding should not initialize with " \ + "`lower=True` or `min_freq!=1`." + elif not isinstance(embed, nn.Embedding): + raise TypeError("Only nn.Embedding or StaticEmbedding is allowed for binding.") + + return TiedEmbedding(embed.weight) + + +class LSTMSeq2SeqDecoder(Seq2SeqDecoder): + """ + LSTM的Decoder + + :param nn.Module,tuple embed: decoder输入的embedding. + :param int num_layers: 多少层LSTM + :param int hidden_size: 隐藏层大小, 该值也被认为是encoder的输出维度大小 + :param dropout: Dropout的大小 + :param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, + 则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. 
+ :param bool attention: 是否使用attention + """ + def __init__(self, embed: Union[nn.Module, Tuple[int, int]], num_layers = 3, hidden_size = 300, + dropout = 0.3, bind_decoder_input_output_embed = True, attention=True): + super().__init__() + self.embed = get_embeddings(init_embed=embed) + self.embed_dim = embed.embedding_dim + + if bind_decoder_input_output_embed: + self.output_layer = get_bind_decoder_output_embed(self.embed) + else: # 不需要bind + self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) + self.output_layer = TiedEmbedding(self.output_embed.weight) + + self.hidden_size = hidden_size + self.num_layers = num_layers + self.lstm = nn.LSTM(input_size=self.embed_dim + hidden_size, hidden_size=hidden_size, num_layers=num_layers, + batch_first=True, bidirectional=False, dropout=dropout if num_layers>1 else 0) + + self.attention_layer = AttentionLayer(hidden_size, hidden_size, hidden_size) if attention else None + self.output_proj = nn.Linear(hidden_size, self.embed_dim) + self.dropout_layer = nn.Dropout(dropout) + + def forward(self, tokens, state, return_attention=False): + """ + + :param torch.LongTensor tokens: batch x max_len + :param LSTMState state: 保存encoder输出和decode状态的State对象 + :param bool return_attention: 是否返回attention的的score + :return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length + """ + src_output = state.encoder_output + encoder_mask = state.encoder_mask + + assert tokens.size(1)>state.decode_length, "The state does not match the tokens." + tokens = tokens[:, state.decode_length:] + x = self.embed(tokens) + + attn_weights = [] if self.attention_layer is not None else None # 保存attention weight, batch,tgt_seq,src_seq + input_feed = state.input_feed + decoder_out = [] + + cur_hidden = state.hidden + cur_cell = state.cell + + # 开始计算 + for i in range(tokens.size(1)): + input = torch.cat( + (x[:, i:i + 1, :], + input_feed[:, None, :] + ), + dim=2 + ) # batch,1,2*dim + _, (cur_hidden, cur_cell) = self.lstm(input, hx=(cur_hidden, cur_cell)) # hidden/cell保持原来的size + if self.attention_layer is not None: + input_feed, attn_weight = self.attention_layer(cur_hidden[-1], src_output, encoder_mask) + attn_weights.append(attn_weight) + else: + input_feed = cur_hidden[-1] + + state.input_feed = input_feed # batch, hidden + state.hidden = cur_hidden + state.cell = cur_cell + state.decode_length += 1 + decoder_out.append(input_feed) + + decoder_out = torch.stack(decoder_out, dim=1) # batch,seq_len,hidden + decoder_out = self.dropout_layer(decoder_out) + if attn_weights is not None: + attn_weights = torch.cat(attn_weights, dim=1) # batch, tgt_len, src_len + + decoder_out = self.output_proj(decoder_out) + feats = self.output_layer(decoder_out) + + if return_attention: + return feats, attn_weights + return feats + + def init_state(self, encoder_output, encoder_mask) -> LSTMState: + """ + + :param encoder_output: 输入可以有两种情况(1) 输入为一个tuple,包含三个内容(encoder_output, (hidden, cell)),其中encoder_output: + bsz x max_len x hidden_size, hidden: bsz x hidden_size, cell:bsz x hidden_size,一般使用LSTMEncoder的最后一层的 + hidden state和cell state来赋值这两个值 + (2) 只有encoder_output: bsz x max_len x hidden_size, 这种情况下hidden和cell使用0初始化 + :param torch.ByteTensor encoder_mask: bsz x max_len, 为0的位置是padding, 用来指示source中哪些不需要attend + :return: + """ + if not isinstance(encoder_output, torch.Tensor): + encoder_output, (hidden, cell) = encoder_output + else: + hidden = cell = None + assert encoder_output.ndim==3 + assert 
encoder_mask.size()==encoder_output.size()[:2] + assert encoder_output.size(-1)==self.hidden_size, "The dimension of encoder outputs should be the same with " \ + "the hidden_size." + + t = [hidden, cell] + for idx in range(2): + v = t[idx] + if v is None: + v = encoder_output.new_zeros(self.num_layers, encoder_output.size(0), self.hidden_size) + else: + assert v.dim()==2 + assert v.size(-1)==self.hidden_size + v = v[None].repeat(self.num_layers, 1, 1) # num_layers x bsz x hidden_size + t[idx] = v + + state = LSTMState(encoder_output, encoder_mask, t[0], t[1]) + + return state + + +class TransformerSeq2SeqDecoderLayer(nn.Module): + """ + + :param int d_model: 输入、输出的维度 + :param int n_head: 多少个head,需要能被d_model整除 + :param int dim_ff: + :param float dropout: + :param int layer_idx: layer的编号 + """ + def __init__(self, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1, layer_idx = None): + super().__init__() + self.d_model = d_model + self.n_head = n_head + self.dim_ff = dim_ff + self.dropout = dropout + self.layer_idx = layer_idx # 记录layer的层索引,以方便获取state的信息 + + self.self_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) + self.self_attn_layer_norm = nn.LayerNorm(d_model) + + self.encoder_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) + self.encoder_attn_layer_norm = nn.LayerNorm(d_model) + + self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(self.dim_ff, self.d_model), + nn.Dropout(dropout)) + + self.final_layer_norm = nn.LayerNorm(self.d_model) + + def forward(self, x, encoder_output, encoder_mask=None, self_attn_mask=None, state=None): + """ + + :param x: (batch, seq_len, dim), decoder端的输入 + :param encoder_output: (batch,src_seq_len,dim), encoder的输出 + :param encoder_mask: batch,src_seq_len, 为1的地方需要attend + :param self_attn_mask: seq_len, seq_len,下三角的mask矩阵,只在训练时传入 + :param TransformerState state: 只在inference阶段传入 + :return: + """ + + # self attention part + residual = x + x = self.self_attn_layer_norm(x) + x, _ = self.self_attn(query=x, + key=x, + value=x, + attn_mask=self_attn_mask, + state=state) + + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + + # encoder attention part + residual = x + x = self.encoder_attn_layer_norm(x) + x, attn_weight = self.encoder_attn(query=x, + key=encoder_output, + value=encoder_output, + key_mask=encoder_mask, + state=state) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + + # ffn + residual = x + x = self.final_layer_norm(x) + x = self.ffn(x) + x = residual + x + + return x, attn_weight + + +class TransformerSeq2SeqDecoder(Seq2SeqDecoder): + """ + + :param embed: 输入token的embedding + :param nn.Module pos_embed: 位置embedding + :param int d_model: 输出、输出的大小 + :param int num_layers: 多少层 + :param int n_head: 多少个head + :param int dim_ff: FFN 的中间大小 + :param float dropout: Self-Attention和FFN中的dropout的大小 + :param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, + 则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. 
+ """ + def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed: nn.Module = None, + d_model = 512, num_layers=6, n_head = 8, dim_ff = 2048, dropout = 0.1, + bind_decoder_input_output_embed = True): + super().__init__() + + self.embed = get_embeddings(embed) + self.pos_embed = pos_embed + + if bind_decoder_input_output_embed: + self.output_layer = get_bind_decoder_output_embed(self.embed) + else: # 不需要bind + self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) + self.output_layer = TiedEmbedding(self.output_embed.weight) + + self.num_layers = num_layers + self.d_model = d_model + self.n_head = n_head + self.dim_ff = dim_ff + self.dropout = dropout + + self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) + self.layer_stacks = nn.ModuleList([TransformerSeq2SeqDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx) + for layer_idx in range(num_layers)]) + + self.embed_scale = math.sqrt(d_model) + self.layer_norm = nn.LayerNorm(d_model) + self.output_fc = nn.Linear(self.d_model, self.embed.embedding_dim) + + def forward(self, tokens, state, return_attention=False): + """ + + :param torch.LongTensor tokens: batch x tgt_len,decode的词 + :param TransformerState state: 用于记录encoder的输出以及decode状态的对象,可以通过init_state()获取 + :param bool return_attention: 是否返回对encoder结果的attention score + :return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length + """ + + encoder_output = state.encoder_output + encoder_mask = state.encoder_mask + + assert state.decode_length1: + triangle_mask = self._get_triangle_mask(tokens) + else: + triangle_mask = None + + for layer in self.layer_stacks: + x, attn_weight = layer(x=x, + encoder_output=encoder_output, + encoder_mask=encoder_mask, + self_attn_mask=triangle_mask, + state=state + ) + + x = self.layer_norm(x) # batch, tgt_len, dim + x = self.output_fc(x) + feats = self.output_layer(x) + + if return_attention: + return feats, attn_weight + return feats + + def init_state(self, encoder_output, encoder_mask): + """ + 初始化一个TransformerState用于forward + + :param torch.FloatTensor encoder_output: bsz x max_len x d_model, encoder的输出 + :param torch.ByteTensor encoder_mask: bsz x max_len, 为1的位置需要attend。 + :return: TransformerState + """ + if isinstance(encoder_output, torch.Tensor): + encoder_output = encoder_output + elif isinstance(encoder_output, (list, tuple)): + encoder_output = encoder_output[0] # 防止是LSTMEncoder的输出结果 + else: + raise TypeError("Unsupported `encoder_output` for TransformerSeq2SeqDecoder") + state = TransformerState(encoder_output, encoder_mask, num_decoder_layer=self.num_layers) + return state + + @staticmethod + def _get_triangle_mask(tokens): + tensor = tokens.new_ones(tokens.size(1), tokens.size(1)) + return torch.tril(tensor).byte() + + diff --git a/fastNLP/modules/torch/decoder/seq2seq_state.py b/fastNLP/modules/torch/decoder/seq2seq_state.py new file mode 100755 index 00000000..de200f86 --- /dev/null +++ b/fastNLP/modules/torch/decoder/seq2seq_state.py @@ -0,0 +1,145 @@ +r""" +每个Decoder都有对应的State用来记录encoder的输出以及Decode的历史记录 + +""" + +__all__ = [ + 'State', + "LSTMState", + "TransformerState" +] + +from typing import Union +import torch + + +class State: + def __init__(self, encoder_output=None, encoder_mask=None, **kwargs): + """ + 每个Decoder都有对应的State对象用来承载encoder的输出以及当前时刻之前的decode状态。 + + :param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch + 维度 + :param Union[torch.Tensor, list, tuple] encoder_mask: 
如果部位None,内部元素需要torch.Tensor, 默认其中第一维是batch + 维度 + :param kwargs: + """ + self.encoder_output = encoder_output + self.encoder_mask = encoder_mask + self._decode_length = 0 + + @property + def num_samples(self): + """ + 返回的State中包含的是多少个sample的encoder状态,主要用于Generate的时候确定batch的大小。 + + :return: + """ + if self.encoder_output is not None: + return self.encoder_output.size(0) + else: + return None + + @property + def decode_length(self): + """ + 当前Decode到哪个token了,decoder只会从decode_length之后的token开始decode, 为0说明还没开始decode。 + + :return: + """ + return self._decode_length + + @decode_length.setter + def decode_length(self, value): + self._decode_length = value + + def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0): + if isinstance(state, torch.Tensor): + state = state.index_select(index=indices, dim=dim) + elif isinstance(state, list): + for i in range(len(state)): + assert state[i] is not None + state[i] = self._reorder_state(state[i], indices, dim) + elif isinstance(state, tuple): + tmp_list = [] + for i in range(len(state)): + assert state[i] is not None + tmp_list.append(self._reorder_state(state[i], indices, dim)) + state = tuple(tmp_list) + else: + raise TypeError(f"Cannot reorder data of type:{type(state)}") + + return state + + def reorder_state(self, indices: torch.LongTensor): + if self.encoder_mask is not None: + self.encoder_mask = self._reorder_state(self.encoder_mask, indices) + if self.encoder_output is not None: + self.encoder_output = self._reorder_state(self.encoder_output, indices) + + +class LSTMState(State): + def __init__(self, encoder_output, encoder_mask, hidden, cell): + """ + LSTMDecoder对应的State,保存encoder的输出以及LSTM解码过程中的一些中间状态 + + :param torch.FloatTensor encoder_output: bsz x src_seq_len x encode_output_size,encoder的输出 + :param torch.BoolTensor encoder_mask: bsz x src_seq_len, 为0的地方是padding + :param torch.FloatTensor hidden: num_layers x bsz x hidden_size, 上个时刻的hidden状态 + :param torch.FloatTensor cell: num_layers x bsz x hidden_size, 上个时刻的cell状态 + """ + super().__init__(encoder_output, encoder_mask) + self.hidden = hidden + self.cell = cell + self._input_feed = hidden[0] # 默认是上一个时刻的输出 + + @property + def input_feed(self): + """ + LSTMDecoder中每个时刻的输入会把上个token的embedding和input_feed拼接起来输入到下个时刻,在LSTMDecoder不使用attention时, + input_feed即上个时刻的hidden state, 否则是attention layer的输出。 + :return: torch.FloatTensor, bsz x hidden_size + """ + return self._input_feed + + @input_feed.setter + def input_feed(self, value): + self._input_feed = value + + def reorder_state(self, indices: torch.LongTensor): + super().reorder_state(indices) + self.hidden = self._reorder_state(self.hidden, indices, dim=1) + self.cell = self._reorder_state(self.cell, indices, dim=1) + if self.input_feed is not None: + self.input_feed = self._reorder_state(self.input_feed, indices, dim=0) + + +class TransformerState(State): + def __init__(self, encoder_output, encoder_mask, num_decoder_layer): + """ + 与TransformerSeq2SeqDecoder对应的State, + + :param torch.FloatTensor encoder_output: bsz x encode_max_len x encoder_output_size, encoder的输出 + :param torch.ByteTensor encoder_mask: bsz x encode_max_len 为1的地方需要attend + :param int num_decoder_layer: decode有多少层 + """ + super().__init__(encoder_output, encoder_mask) + self.encoder_key = [None] * num_decoder_layer # 每一个元素 bsz x encoder_max_len x key_dim + self.encoder_value = [None] * num_decoder_layer # 每一个元素 bsz x encoder_max_len x value_dim + self.decoder_prev_key = [None] * num_decoder_layer # 每一个元素 bsz x decode_length x 
key_dim + self.decoder_prev_value = [None] * num_decoder_layer # 每一个元素 bsz x decode_length x key_dim + + def reorder_state(self, indices: torch.LongTensor): + super().reorder_state(indices) + self.encoder_key = self._reorder_state(self.encoder_key, indices) + self.encoder_value = self._reorder_state(self.encoder_value, indices) + self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) + self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) + + @property + def decode_length(self): + if self.decoder_prev_key[0] is not None: + return self.decoder_prev_key[0].size(1) + return 0 + + diff --git a/fastNLP/modules/torch/dropout.py b/fastNLP/modules/torch/dropout.py new file mode 100755 index 00000000..62b039b4 --- /dev/null +++ b/fastNLP/modules/torch/dropout.py @@ -0,0 +1,24 @@ +r"""undocumented""" + +__all__ = [ + "TimestepDropout" +] + +import torch + + +class TimestepDropout(torch.nn.Dropout): + r""" + 传入参数的shape为 ``(batch_size, num_timesteps, embedding_dim)`` + 使用同一个shape为 ``(batch_size, embedding_dim)`` 的mask在每个timestamp上做dropout。 + """ + + def forward(self, x): + dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) + torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) + dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim] + if self.inplace: + x *= dropout_mask + return + else: + return x * dropout_mask diff --git a/fastNLP/modules/torch/encoder/__init__.py b/fastNLP/modules/torch/encoder/__init__.py index d893305f..fc986f0b 100644 --- a/fastNLP/modules/torch/encoder/__init__.py +++ b/fastNLP/modules/torch/encoder/__init__.py @@ -1,5 +1,21 @@ __all__ = [ + "ConvMaxpool", + "LSTM", + + "Seq2SeqEncoder", + "TransformerSeq2SeqEncoder", + "LSTMSeq2SeqEncoder", + + "StarTransformer", + + "VarRNN", + "VarLSTM", + "VarGRU" ] -from .lstm import LSTM \ No newline at end of file +from .conv_maxpool import ConvMaxpool +from .lstm import LSTM +from .seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder +from .star_transformer import StarTransformer +from .variational_rnn import VarRNN, VarLSTM, VarGRU \ No newline at end of file diff --git a/fastNLP/modules/torch/encoder/conv_maxpool.py b/fastNLP/modules/torch/encoder/conv_maxpool.py new file mode 100755 index 00000000..7373005a --- /dev/null +++ b/fastNLP/modules/torch/encoder/conv_maxpool.py @@ -0,0 +1,87 @@ +r"""undocumented""" + +__all__ = [ + "ConvMaxpool" +] +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvMaxpool(nn.Module): + r""" + 集合了Convolution和Max-Pooling于一体的层。给定一个batch_size x max_len x input_size的输入,返回batch_size x + sum(output_channels) 大小的matrix。在内部,是先使用CNN给输入做卷积,然后经过activation激活层,在通过在长度(max_len) + 这一维进行max_pooling。最后得到每个sample的一个向量表示。 + + """ + + def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"): + r""" + + :param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度 + :param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 + :param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 + :param str activation: Convolution后的结果将通过该activation后再经过max-pooling。支持relu, sigmoid, tanh + """ + super(ConvMaxpool, self).__init__() + + for kernel_size in kernel_sizes: + assert kernel_size % 2 == 1, "kernel size has to be odd numbers." 
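+        # Odd kernel sizes ensure that padding=ks//2 (set below) keeps the convolution output
+        # the same length as the input, so the pad mask can still be applied position-wise
+        # before max-pooling.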
+ + # convolution + if isinstance(kernel_sizes, (list, tuple, int)): + if isinstance(kernel_sizes, int) and isinstance(out_channels, int): + out_channels = [out_channels] + kernel_sizes = [kernel_sizes] + elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)): + assert len(out_channels) == len( + kernel_sizes), "The number of out_channels should be equal to the number" \ + " of kernel_sizes." + else: + raise ValueError("The type of out_channels and kernel_sizes should be the same.") + + self.convs = nn.ModuleList([nn.Conv1d( + in_channels=in_channels, + out_channels=oc, + kernel_size=ks, + stride=1, + padding=ks // 2, + dilation=1, + groups=1, + bias=False) + for oc, ks in zip(out_channels, kernel_sizes)]) + + else: + raise Exception( + 'Incorrect kernel sizes: should be list, tuple or int') + + # activation function + if activation == 'relu': + self.activation = F.relu + elif activation == 'sigmoid': + self.activation = F.sigmoid + elif activation == 'tanh': + self.activation = F.tanh + else: + raise Exception( + "Undefined activation function: choose from: relu, tanh, sigmoid") + + def forward(self, x, mask=None): + r""" + + :param torch.FloatTensor x: batch_size x max_len x input_size, 一般是经过embedding后的值 + :param mask: batch_size x max_len, pad的地方为0。不影响卷积运算,max-pool一定不会pool到pad为0的位置 + :return: + """ + # [N,L,C] -> [N,C,L] + x = torch.transpose(x, 1, 2) + # convolution + xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] + if mask is not None: + mask = mask.unsqueeze(1) # B x 1 x L + xs = [x.masked_fill(mask.eq(False), float('-inf')) for x in xs] + # max-pooling + xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) + for i in xs] # [[N, C], ...] + return torch.cat(xs, dim=-1) # [N, C] diff --git a/fastNLP/modules/torch/encoder/seq2seq_encoder.py b/fastNLP/modules/torch/encoder/seq2seq_encoder.py new file mode 100755 index 00000000..6a42c9d6 --- /dev/null +++ b/fastNLP/modules/torch/encoder/seq2seq_encoder.py @@ -0,0 +1,193 @@ +r"""undocumented""" +import torch.nn as nn +import torch +from torch.nn import LayerNorm +import torch.nn.functional as F +from typing import Union, Tuple +from ....core.utils import seq_len_to_mask +import math +from .lstm import LSTM +from ..attention import MultiHeadAttention +from ....embeddings.torch import StaticEmbedding +from ....embeddings.torch.utils import get_embeddings + + +__all__ = ['Seq2SeqEncoder', 'TransformerSeq2SeqEncoder', 'LSTMSeq2SeqEncoder'] + + +class Seq2SeqEncoder(nn.Module): + """ + 所有Sequence2Sequence Encoder的基类。需要实现forward函数 + + """ + def __init__(self): + super().__init__() + + def forward(self, tokens, seq_len): + """ + + :param torch.LongTensor tokens: bsz x max_len, encoder的输入 + :param torch.LongTensor seq_len: bsz + :return: + """ + raise NotImplementedError + + +class TransformerSeq2SeqEncoderLayer(nn.Module): + """ + Self-Attention的Layer, + + :param int d_model: input和output的输出维度 + :param int n_head: 多少个head,每个head的维度为d_model/n_head + :param int dim_ff: FFN的维度大小 + :param float dropout: Self-attention和FFN的dropout大小,0表示不drop + """ + def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, + dropout: float = 0.1): + super(TransformerSeq2SeqEncoderLayer, self).__init__() + self.d_model = d_model + self.n_head = n_head + self.dim_ff = dim_ff + self.dropout = dropout + + self.self_attn = MultiHeadAttention(d_model, n_head, dropout) + self.attn_layer_norm = LayerNorm(d_model) + self.ffn_layer_norm = LayerNorm(d_model) + + self.ffn = 
nn.Sequential(nn.Linear(self.d_model, self.dim_ff), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(self.dim_ff, self.d_model), + nn.Dropout(dropout)) + + def forward(self, x, mask): + """ + + :param x: batch x src_seq x d_model + :param mask: batch x src_seq,为0的地方为padding + :return: + """ + # attention + residual = x + x = self.attn_layer_norm(x) + x, _ = self.self_attn(query=x, + key=x, + value=x, + key_mask=mask) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + + # ffn + residual = x + x = self.ffn_layer_norm(x) + x = self.ffn(x) + x = residual + x + + return x + + +class TransformerSeq2SeqEncoder(Seq2SeqEncoder): + """ + 基于Transformer的Encoder + + :param embed: encoder输入token的embedding + :param nn.Module pos_embed: position embedding + :param int num_layers: 多少层的encoder + :param int d_model: 输入输出的维度 + :param int n_head: 多少个head + :param int dim_ff: FFN中间的维度大小 + :param float dropout: Attention和FFN的dropout大小 + """ + def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed = None, + num_layers = 6, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1): + super(TransformerSeq2SeqEncoder, self).__init__() + self.embed = get_embeddings(embed) + self.embed_scale = math.sqrt(d_model) + self.pos_embed = pos_embed + self.num_layers = num_layers + self.d_model = d_model + self.n_head = n_head + self.dim_ff = dim_ff + self.dropout = dropout + + self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) + self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout) + for _ in range(num_layers)]) + self.layer_norm = LayerNorm(d_model) + + def forward(self, tokens, seq_len): + """ + + :param tokens: batch x max_len + :param seq_len: [batch] + :return: bsz x max_len x d_model, bsz x max_len(为0的地方为padding) + """ + x = self.embed(tokens) * self.embed_scale # batch, seq, dim + batch_size, max_src_len, _ = x.size() + device = x.device + if self.pos_embed is not None: + position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device) + x += self.pos_embed(position) + + x = self.input_fc(x) + x = F.dropout(x, p=self.dropout, training=self.training) + + encoder_mask = seq_len_to_mask(seq_len, max_len=max_src_len) + encoder_mask = encoder_mask.to(device) + + for layer in self.layer_stacks: + x = layer(x, encoder_mask) + + x = self.layer_norm(x) + + return x, encoder_mask + + +class LSTMSeq2SeqEncoder(Seq2SeqEncoder): + """ + LSTM的Encoder + + :param embed: encoder的token embed + :param int num_layers: 多少层 + :param int hidden_size: LSTM隐藏层、输出的大小 + :param float dropout: LSTM层之间的Dropout是多少 + :param bool bidirectional: 是否使用双向 + """ + def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, + hidden_size = 400, dropout = 0.3, bidirectional=True): + super().__init__() + self.embed = get_embeddings(embed) + self.num_layers = num_layers + self.dropout = dropout + self.hidden_size = hidden_size + self.bidirectional = bidirectional + hidden_size = hidden_size//2 if bidirectional else hidden_size + self.lstm = LSTM(input_size=embed.embedding_dim, hidden_size=hidden_size, bidirectional=bidirectional, + batch_first=True, dropout=dropout if num_layers>1 else 0, num_layers=num_layers) + + def forward(self, tokens, seq_len): + """ + + :param torch.LongTensor tokens: bsz x max_len + :param torch.LongTensor seq_len: bsz + :return: (output, (hidden, cell)), encoder_mask + output: bsz x max_len x hidden_size, + hidden,cell: batch_size x hidden_size, 最后一层的隐藏状态或cell状态 + encoder_mask: 
bsz x max_len, 为0的地方是padding + """ + x = self.embed(tokens) + device = x.device + x, (final_hidden, final_cell) = self.lstm(x, seq_len) + encoder_mask = seq_len_to_mask(seq_len).to(device) + + # x: batch,seq_len,dim; h/c: num_layers*2,batch,dim + + if self.bidirectional: + final_hidden = self.concat_bidir(final_hidden) # 将双向的hidden state拼接起来,用于接下来的decoder的input + final_cell = self.concat_bidir(final_cell) + + return (x, (final_hidden[-1], final_cell[-1])), encoder_mask # 为了配合Seq2SeqBaseModel的forward,这边需要分为两个return + + def concat_bidir(self, input): + output = input.view(self.num_layers, 2, input.size(1), -1).transpose(1, 2) + return output.reshape(self.num_layers, input.size(1), -1) diff --git a/fastNLP/modules/torch/encoder/star_transformer.py b/fastNLP/modules/torch/encoder/star_transformer.py new file mode 100755 index 00000000..d7e2a1c1 --- /dev/null +++ b/fastNLP/modules/torch/encoder/star_transformer.py @@ -0,0 +1,166 @@ +r"""undocumented +Star-Transformer 的encoder部分的 Pytorch 实现 +""" + +__all__ = [ + "StarTransformer" +] + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + + +class StarTransformer(nn.Module): + r""" + Star-Transformer 的encoder部分。 输入3d的文本输入, 返回相同长度的文本编码 + + paper: https://arxiv.org/abs/1902.09113 + + """ + + def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): + r""" + + :param int hidden_size: 输入维度的大小。同时也是输出维度的大小。 + :param int num_layers: star-transformer的层数 + :param int num_head: head的数量。 + :param int head_dim: 每个head的维度大小。 + :param float dropout: dropout 概率. Default: 0.1 + :param int max_len: int or None, 如果为int,输入序列的最大长度, + 模型会为输入序列加上position embedding。 + 若为`None`,忽略加上position embedding的步骤. Default: `None` + """ + super(StarTransformer, self).__init__() + self.iters = num_layers + + self.norm = nn.ModuleList([nn.LayerNorm(hidden_size, eps=1e-6) for _ in range(self.iters)]) + # self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1) + self.emb_drop = nn.Dropout(dropout) + self.ring_att = nn.ModuleList( + [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) + for _ in range(self.iters)]) + self.star_att = nn.ModuleList( + [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) + for _ in range(self.iters)]) + + if max_len is not None: + self.pos_emb = nn.Embedding(max_len, hidden_size) + else: + self.pos_emb = None + + def forward(self, data, mask): + r""" + :param FloatTensor data: [batch, length, hidden] 输入的序列 + :param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, + 否则为 1 + :return: [batch, length, hidden] 编码后的输出序列 + + [batch, hidden] 全局 relay 节点, 详见论文 + """ + + def norm_func(f, x): + # B, H, L, 1 + return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + B, L, H = data.size() + mask = (mask.eq(False)) # flip the mask for masked_fill_ + smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) + + embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 + if self.pos_emb: + P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ + .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 + embs = embs + P + embs = norm_func(self.emb_drop, embs) + nodes = embs + relay = embs.mean(2, keepdim=True) + ex_mask = mask[:, None, :, None].expand(B, H, L, 1) + r_embs = embs.view(B, H, 1, L) + for i in range(self.iters): + ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) + nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) + # nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax)) 
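+            # The relay node attends over itself and every satellite node (smask marks the
+            # padding positions), so global information is exchanged once per iteration.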
+ relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) + + nodes = nodes.masked_fill_(ex_mask, 0) + + nodes = nodes.view(B, H, L).permute(0, 2, 1) + + return nodes, relay.view(B, H) + + +class _MSA1(nn.Module): + def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): + super(_MSA1, self).__init__() + # Multi-head Self Attention Case 1, doing self-attention for small regions + # Due to the architecture of GPU, using hadamard production and summation are faster than dot production when unfold_size is very small + self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) + self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) + self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) + self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) + + self.drop = nn.Dropout(dropout) + + # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) + self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 + + def forward(self, x, ax=None): + # x: B, H, L, 1, ax : B, H, X, L append features + nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size + B, H, L, _ = x.shape + + q, k, v = self.WQ(x), self.WK(x), self.WV(x) # x: (B,H,L,1) + + if ax is not None: + aL = ax.shape[2] + ak = self.WK(ax).view(B, nhead, head_dim, aL, L) + av = self.WV(ax).view(B, nhead, head_dim, aL, L) + q = q.view(B, nhead, head_dim, 1, L) + k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \ + .view(B, nhead, head_dim, unfold_size, L) + v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \ + .view(B, nhead, head_dim, unfold_size, L) + if ax is not None: + k = torch.cat([k, ak], 3) + v = torch.cat([v, av], 3) + + alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / np.sqrt(head_dim), 3)) # B N L 1 U + att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1) + + ret = self.WO(att) + + return ret + + +class _MSA2(nn.Module): + def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): + # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value + super(_MSA2, self).__init__() + self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) + self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) + self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) + self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) + + self.drop = nn.Dropout(dropout) + + # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) + self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 + + def forward(self, x, y, mask=None): + # x: B, H, 1, 1, 1 y: B H L 1 + nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size + B, H, L, _ = y.shape + + q, k, v = self.WQ(x), self.WK(y), self.WV(y) + + q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h + k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L + v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h + pre_a = torch.matmul(q, k) / np.sqrt(head_dim) + if mask is not None: + pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf')) + alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L + att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1 + return self.WO(att) diff --git a/fastNLP/modules/torch/encoder/transformer.py b/fastNLP/modules/torch/encoder/transformer.py new file mode 100755 index 00000000..54884ff1 --- /dev/null +++ b/fastNLP/modules/torch/encoder/transformer.py @@ -0,0 +1,43 @@ +r"""undocumented""" + +__all__ = [ + "TransformerEncoder" 
+] + +from torch import nn + +from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer + + +class TransformerEncoder(nn.Module): + r""" + transformer的encoder模块,不包含embedding层 + + """ + def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1): + """ + + :param int num_layers: 多少层Transformer + :param int d_model: input和output的大小 + :param int n_head: 多少个head + :param int dim_ff: FFN中间hidden大小 + :param float dropout: 多大概率drop attention和ffn中间的表示 + """ + super(TransformerEncoder, self).__init__() + self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff, + dropout = dropout) for _ in range(num_layers)]) + self.norm = nn.LayerNorm(d_model, eps=1e-6) + + def forward(self, x, seq_mask=None): + r""" + :param x: [batch, seq_len, model_size] 输入序列 + :param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. 为1的地方需要attend + Default: ``None`` + :return: [batch, seq_len, model_size] 输出序列 + """ + output = x + if seq_mask is None: + seq_mask = x.new_ones(x.size(0), x.size(1)).bool() + for layer in self.layers: + output = layer(output, seq_mask) + return self.norm(output) diff --git a/fastNLP/modules/torch/encoder/variational_rnn.py b/fastNLP/modules/torch/encoder/variational_rnn.py new file mode 100755 index 00000000..32c48b13 --- /dev/null +++ b/fastNLP/modules/torch/encoder/variational_rnn.py @@ -0,0 +1,303 @@ +r"""undocumented +Variational RNN 及相关模型的 fastNLP实现,相关论文参考: +`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) `_ +""" + +__all__ = [ + "VarRNN", + "VarLSTM", + "VarGRU" +] + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence + +try: + from torch import flip +except ImportError: + def flip(x, dims): + indices = [slice(None)] * x.dim() + for dim in dims: + indices[dim] = torch.arange( + x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) + return x[tuple(indices)] + + +class VarRnnCellWrapper(nn.Module): + r""" + Wrapper for normal RNN Cells, make it support variational dropout + """ + + def __init__(self, cell, hidden_size, input_p, hidden_p): + super(VarRnnCellWrapper, self).__init__() + self.cell = cell + self.hidden_size = hidden_size + self.input_p = input_p + self.hidden_p = hidden_p + + def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): + r""" + :param PackedSequence input_x: [seq_len, batch_size, input_size] + :param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] + for other RNN, h_0, [batch_size, hidden_size] + :param mask_x: [batch_size, input_size] dropout mask for input + :param mask_h: [batch_size, hidden_size] dropout mask for hidden + :return PackedSequence output: [seq_len, bacth_size, hidden_size] + hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] + for other RNN, h_n, [batch_size, hidden_size] + """ + + def get_hi(hi, h0, size): + h0_size = size - hi.size(0) + if h0_size > 0: + return torch.cat([hi, h0[:h0_size]], dim=0) + return hi[:size] + + is_lstm = isinstance(hidden, tuple) + input, batch_sizes = input_x.data, input_x.batch_sizes + output = [] + cell = self.cell + if is_reversed: + batch_iter = flip(batch_sizes, [0]) + idx = input.size(0) + else: + batch_iter = batch_sizes + idx = 0 + + if is_lstm: + hn = (hidden[0].clone(), hidden[1].clone()) + else: + hn = hidden.clone() + hi = hidden + for size in batch_iter: + if is_reversed: + input_i = input[idx - size: idx] * mask_x[:size] 
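+                # The same mask_x is reused at every timestep (variational dropout); only the
+                # first `size` rows are needed because the packed batch shrinks along the sequence.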
+ idx -= size + else: + input_i = input[idx: idx + size] * mask_x[:size] + idx += size + mask_hi = mask_h[:size] + if is_lstm: + hx, cx = hi + hi = (get_hi(hx, hidden[0], size) * + mask_hi, get_hi(cx, hidden[1], size)) + hi = cell(input_i, hi) + hn[0][:size] = hi[0] + hn[1][:size] = hi[1] + output.append(hi[0]) + else: + hi = get_hi(hi, hidden, size) * mask_hi + hi = cell(input_i, hi) + hn[:size] = hi + output.append(hi) + + if is_reversed: + output = list(reversed(output)) + output = torch.cat(output, dim=0) + return PackedSequence(output, batch_sizes), hn + + +class VarRNNBase(nn.Module): + r""" + Variational Dropout RNN 实现. + + 论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) + https://arxiv.org/abs/1512.05287`. + + """ + + def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, + bias=True, batch_first=False, + input_dropout=0, hidden_dropout=0, bidirectional=False): + r""" + + :param mode: rnn 模式, (lstm or not) + :param Cell: rnn cell 类型, (lstm, gru, etc) + :param input_size: 输入 `x` 的特征维度 + :param hidden_size: 隐状态 `h` 的特征维度 + :param num_layers: rnn的层数. Default: 1 + :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` + :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 + (batch, seq, feature). Default: ``False`` + :param input_dropout: 对输入的dropout概率. Default: 0 + :param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 + :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` + """ + super(VarRNNBase, self).__init__() + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.input_dropout = input_dropout + self.hidden_dropout = hidden_dropout + self.bidirectional = bidirectional + self.num_directions = 2 if bidirectional else 1 + self._all_cells = nn.ModuleList() + for layer in range(self.num_layers): + for direction in range(self.num_directions): + input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions + cell = Cell(input_size, self.hidden_size, bias) + self._all_cells.append(VarRnnCellWrapper( + cell, self.hidden_size, input_dropout, hidden_dropout)) + self.is_lstm = (self.mode == "LSTM") + + def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h): + is_lstm = self.is_lstm + idx = self.num_directions * n_layer + n_direction + cell = self._all_cells[idx] + hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] + output_x, hidden_x = cell( + input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) + return output_x, hidden_x + + def forward(self, x, hx=None): + r""" + + :param x: [batch, seq_len, input_size] 输入序列 + :param hx: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全1向量. 
Default: ``None`` + :return (output, ht): [batch, seq_len, hidden_size*num_direction] 输出序列 + 和 [batch, hidden_size*num_direction] 最后时刻隐状态 + """ + is_lstm = self.is_lstm + is_packed = isinstance(x, PackedSequence) + if not is_packed: + seq_len = x.size(1) if self.batch_first else x.size(0) + max_batch_size = x.size(0) if self.batch_first else x.size(1) + seq_lens = torch.LongTensor( + [seq_len for _ in range(max_batch_size)]) + x = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first) + else: + max_batch_size = int(x.batch_sizes[0]) + x, batch_sizes = x.data, x.batch_sizes + + if hx is None: + hx = x.new_zeros(self.num_layers * self.num_directions, + max_batch_size, self.hidden_size, requires_grad=True) + if is_lstm: + hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) + + mask_x = x.new_ones((max_batch_size, self.input_size)) + mask_out = x.new_ones( + (max_batch_size, self.hidden_size * self.num_directions)) + mask_h_ones = x.new_ones((max_batch_size, self.hidden_size)) + nn.functional.dropout(mask_x, p=self.input_dropout, + training=self.training, inplace=True) + nn.functional.dropout(mask_out, p=self.hidden_dropout, + training=self.training, inplace=True) + + hidden = x.new_zeros( + (self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) + if is_lstm: + cellstate = x.new_zeros( + (self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) + for layer in range(self.num_layers): + output_list = [] + input_seq = PackedSequence(x, batch_sizes) + mask_h = nn.functional.dropout( + mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) + for direction in range(self.num_directions): + output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx, + mask_x if layer == 0 else mask_out, mask_h) + output_list.append(output_x.data) + idx = self.num_directions * layer + direction + if is_lstm: + hidden[idx] = hidden_x[0] + cellstate[idx] = hidden_x[1] + else: + hidden[idx] = hidden_x + x = torch.cat(output_list, dim=-1) + + if is_lstm: + hidden = (hidden, cellstate) + + if is_packed: + output = PackedSequence(x, batch_sizes) + else: + x = PackedSequence(x, batch_sizes) + output, _ = pad_packed_sequence(x, batch_first=self.batch_first) + + return output, hidden + + +class VarLSTM(VarRNNBase): + r""" + Variational Dropout LSTM. + 相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) `_ + + """ + + def __init__(self, *args, **kwargs): + r""" + + :param input_size: 输入 `x` 的特征维度 + :param hidden_size: 隐状态 `h` 的特征维度 + :param num_layers: rnn的层数. Default: 1 + :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` + :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 + (batch, seq, feature). Default: ``False`` + :param input_dropout: 对输入的dropout概率. Default: 0 + :param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 + :param bidirectional: 若为 ``True``, 使用双向的LSTM. Default: ``False`` + """ + super(VarLSTM, self).__init__( + mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) + + def forward(self, x, hx=None): + return super(VarLSTM, self).forward(x, hx) + + +class VarRNN(VarRNNBase): + r""" + Variational Dropout RNN. + 相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) `_ + + """ + + def __init__(self, *args, **kwargs): + r""" + + :param input_size: 输入 `x` 的特征维度 + :param hidden_size: 隐状态 `h` 的特征维度 + :param num_layers: rnn的层数. Default: 1 + :param bias: 如果为 ``False``, 模型将不会使用bias. 
Default: ``True`` + :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 + (batch, seq, feature). Default: ``False`` + :param input_dropout: 对输入的dropout概率. Default: 0 + :param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 + :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` + """ + super(VarRNN, self).__init__( + mode="RNN", Cell=nn.RNNCell, *args, **kwargs) + + def forward(self, x, hx=None): + return super(VarRNN, self).forward(x, hx) + + +class VarGRU(VarRNNBase): + r""" + Variational Dropout GRU. + 相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) `_ + + """ + + def __init__(self, *args, **kwargs): + r""" + + :param input_size: 输入 `x` 的特征维度 + :param hidden_size: 隐状态 `h` 的特征维度 + :param num_layers: rnn的层数. Default: 1 + :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` + :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 + (batch, seq, feature). Default: ``False`` + :param input_dropout: 对输入的dropout概率. Default: 0 + :param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 + :param bidirectional: 若为 ``True``, 使用双向的GRU. Default: ``False`` + """ + super(VarGRU, self).__init__( + mode="GRU", Cell=nn.GRUCell, *args, **kwargs) + + def forward(self, x, hx=None): + return super(VarGRU, self).forward(x, hx) diff --git a/fastNLP/modules/torch/generator/__init__.py b/fastNLP/modules/torch/generator/__init__.py new file mode 100755 index 00000000..2dfc4000 --- /dev/null +++ b/fastNLP/modules/torch/generator/__init__.py @@ -0,0 +1,6 @@ +__all__ = [ + 'SequenceGenerator' +] + + +from .seq2seq_generator import SequenceGenerator \ No newline at end of file diff --git a/fastNLP/modules/torch/generator/seq2seq_generator.py b/fastNLP/modules/torch/generator/seq2seq_generator.py new file mode 100755 index 00000000..cf9c5306 --- /dev/null +++ b/fastNLP/modules/torch/generator/seq2seq_generator.py @@ -0,0 +1,536 @@ +r""" + +""" + +__all__ = [ + 'SequenceGenerator' +] + +import torch +from torch import nn +import torch.nn.functional as F +from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State +from functools import partial + + +def _get_model_device(model): + r""" + 传入一个nn.Module的模型,获取它所在的device + + :param model: nn.Module + :return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。 + """ + assert isinstance(model, nn.Module) + + parameters = list(model.parameters()) + if len(parameters) == 0: + return None + else: + return parameters[0].device + + + +class SequenceGenerator: + """ + 给定一个Seq2SeqDecoder,decode出句子。输入的decoder对象需要有decode()函数, 接受的第一个参数为decode的到目前位置的所有输出, + 第二个参数为state。SequenceGenerator不会对state进行任何操作。 + + """ + def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1, + do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, + repetition_penalty=1, length_penalty=1.0, pad_token_id=0): + """ + + :param Seq2SeqDecoder decoder: Decoder对象 + :param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len + :param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask + :param int num_beams: beam search的大小 + :param bool do_sample: 是否通过采样的方式生成 + :param float temperature: 只有在do_sample为True才有意义 + :param int top_k: 只从top_k中采样 + :param float top_p: 只从top_p的token中采样,nucles sample + :param int,None bos_token_id: 句子开头的token id + :param int,None eos_token_id: 句子结束的token id + :param float repetition_penalty: 多大程度上惩罚重复的token + :param float length_penalty: 对长度的惩罚,小于1鼓励长句,大于1鼓励短剧 + :param int pad_token_id: 
当某句话生成结束之后,之后生成的内容用pad_token_id补充 + """ + if do_sample: + self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a, + num_beams=num_beams, + temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, + eos_token_id=eos_token_id, repetition_penalty=repetition_penalty, + length_penalty=length_penalty, pad_token_id=pad_token_id) + else: + self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a, + num_beams=num_beams, + bos_token_id=bos_token_id, eos_token_id=eos_token_id, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, pad_token_id=pad_token_id) + self.do_sample = do_sample + self.max_length = max_length + self.num_beams = num_beams + self.temperature = temperature + self.top_k = top_k + self.top_p = top_p + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.repetition_penalty = repetition_penalty + self.length_penalty = length_penalty + self.decoder = decoder + + @torch.no_grad() + def generate(self, state, tokens=None): + """ + + :param State state: encoder结果的State, 是与Decoder配套是用的 + :param torch.LongTensor,None tokens: batch_size x length, 开始的token。如果为None,则默认添加bos_token作为开头的token + 进行生成。 + :return: bsz x max_length' 生成的token序列。如果eos_token_id不为None, 每个sequence的结尾一定是eos_token_id + """ + + return self.generate_func(tokens=tokens, state=state) + + +@torch.no_grad() +def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1, + bos_token_id=None, eos_token_id=None, pad_token_id=0, + repetition_penalty=1, length_penalty=1.0): + """ + 贪婪地搜索句子 + + :param Decoder decoder: Decoder对象 + :param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 + :param State state: 应该包含encoder的一些输出。 + :param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len + :param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask + :param int num_beams: 使用多大的beam进行解码。 + :param int bos_token_id: 如果tokens传入为None,则使用bos_token_id开始往后解码。 + :param int eos_token_id: 结束的token,如果为None,则一定会解码到max_length这么长。 + :param int pad_token_id: pad的token id + :param float repetition_penalty: 对重复出现的token多大的惩罚。 + :param float length_penalty: 对每个token(除了eos)按照长度进行一定的惩罚。 + :return: + """ + if num_beams == 1: + token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, + temperature=1, top_k=50, top_p=1, + bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, + repetition_penalty=repetition_penalty, length_penalty=length_penalty, + pad_token_id=pad_token_id) + else: + token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, + num_beams=num_beams, temperature=1, top_k=50, top_p=1, + bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, + repetition_penalty=repetition_penalty, length_penalty=length_penalty, + pad_token_id=pad_token_id) + + return token_ids + + +@torch.no_grad() +def sample_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1, temperature=1.0, top_k=50, + top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0, + length_penalty=1.0): + """ + 使用采样的方法生成句子 + + :param Decoder decoder: Decoder对象 + :param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 + :param State state: 应该包含encoder的一些输出。 + :param int 
max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len + :param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask + :param int num_beam: 使用多大的beam进行解码。 + :param float temperature: 采样时的退火大小 + :param int top_k: 只在top_k的sample里面采样 + :param float top_p: 介于0,1的值。 + :param int bos_token_id: 如果tokens传入为None,则使用bos_token_id开始往后解码。 + :param int eos_token_id: 结束的token,如果为None,则一定会解码到max_length这么长。 + :param int pad_token_id: pad的token id + :param float repetition_penalty: 对重复出现的token多大的惩罚。 + :param float length_penalty: 对每个token(除了eos)按照长度进行一定的惩罚。 + :return: + """ + # 每个位置在生成的时候会sample生成 + if num_beams == 1: + token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, + temperature=temperature, top_k=top_k, top_p=top_p, + bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, + repetition_penalty=repetition_penalty, length_penalty=length_penalty, + pad_token_id=pad_token_id) + else: + token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, + num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p, + bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, + repetition_penalty=repetition_penalty, length_penalty=length_penalty, + pad_token_id=pad_token_id) + return token_ids + + +def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, max_len_a=0.0, temperature=1.0, top_k=50, + top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, + repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0): + device = _get_model_device(decoder) + if tokens is None: + if bos_token_id is None: + raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") + batch_size = state.num_samples + if batch_size is None: + raise RuntimeError("Cannot infer the number of samples from `state`.") + tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) + batch_size = tokens.size(0) + if state.num_samples: + assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." 
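+    # When eos_token_id is None, _eos_token_id falls back to -1 so the later
+    # `.eq(_eos_token_id)` checks can never match a real token and decoding
+    # only stops at the length limit. The first decode() call below warms up
+    # the State and picks the first generated token by argmax; when
+    # max_len_a != 0, per-sample length budgets are then derived from
+    # state.encoder_mask as max_length + max_len_a * src_len.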
+ + if eos_token_id is None: + _eos_token_id = -1 + else: + _eos_token_id = eos_token_id + + scores = decoder.decode(tokens=tokens, state=state) # 主要是为了update state + if _eos_token_id!=-1: # 防止第一个位置为结束 + scores[:, _eos_token_id] = -1e12 + next_tokens = scores.argmax(dim=-1, keepdim=True) + token_ids = torch.cat([tokens, next_tokens], dim=1) + cur_len = token_ids.size(1) + dones = token_ids.new_zeros(batch_size).eq(1) + # tokens = tokens[:, -1:] + + if max_len_a!=0: + # (bsz x num_beams, ) + if state.encoder_mask is not None: + max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length + else: + max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long) + real_max_length = max_lengths.max().item() + else: + real_max_length = max_length + if state.encoder_mask is not None: + max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length + else: + max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) + + while cur_len < real_max_length: + scores = decoder.decode(tokens=token_ids, state=state) # batch_size x vocab_size + + if repetition_penalty != 1.0: + token_scores = scores.gather(dim=1, index=token_ids) + lt_zero_mask = token_scores.lt(0).float() + ge_zero_mask = lt_zero_mask.eq(0).float() + token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores + scores.scatter_(dim=1, index=token_ids, src=token_scores) + + if eos_token_id is not None and length_penalty != 1.0: + token_scores = scores / cur_len ** length_penalty # batch_size x vocab_size + eos_mask = scores.new_ones(scores.size(1)) + eos_mask[eos_token_id] = 0 + eos_mask = eos_mask.unsqueeze(0).eq(1) + scores = scores.masked_scatter(eos_mask, token_scores) # 也即除了eos,其他词的分数经过了放大/缩小 + + if do_sample: + if temperature > 0 and temperature != 1: + scores = scores / temperature + + scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2) + # 加上1e-12是为了避免https://github.com/pytorch/pytorch/pull/27523 + probs = F.softmax(scores, dim=-1) + 1e-12 + + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size + else: + next_tokens = torch.argmax(scores, dim=-1) # batch_size + + # 如果已经达到对应的sequence长度了,就直接填为eos了 + if _eos_token_id!=-1: + next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id) + next_tokens = next_tokens.masked_fill(dones, pad_token_id) # 对已经搜索完成的sample做padding + tokens = next_tokens.unsqueeze(1) + + token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len + + end_mask = next_tokens.eq(_eos_token_id) + dones = dones.__or__(end_mask) + cur_len += 1 + + if dones.min() == 1: + break + + # if eos_token_id is not None: + # tokens.scatter(index=max_lengths[:, None], dim=1, value=eos_token_id) # 将最大长度位置设置为eos + # if cur_len == max_length: + # token_ids[:, -1].masked_fill_(~dones, eos_token_id) # 若到最长长度仍未到EOS,则强制将最后一个词替换成eos + return token_ids + + +def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=4, temperature=1.0, + top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, + repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor: + # 进行beam search + device = _get_model_device(decoder) + if tokens is None: + if bos_token_id is None: + raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") + batch_size = state.num_samples + if batch_size is None: + raise 
RuntimeError("Cannot infer the number of samples from `state`.") + tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) + batch_size = tokens.size(0) + if state.num_samples: + assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." + + if eos_token_id is None: + _eos_token_id = -1 + else: + _eos_token_id = eos_token_id + + scores = decoder.decode(tokens=tokens, state=state) # 这里要传入的是整个句子的长度 + if _eos_token_id!=-1: # 防止第一个位置为结束 + scores[:, _eos_token_id] = -1e12 + vocab_size = scores.size(1) + assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size." + + if do_sample: + probs = F.softmax(scores, dim=-1) + 1e-12 + next_tokens = torch.multinomial(probs, num_samples=num_beams) # (batch_size, num_beams) + logits = probs.log() + next_scores = logits.gather(dim=1, index=next_tokens) # (batch_size, num_beams) + else: + scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size) + # 得到(batch_size, num_beams), (batch_size, num_beams) + next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) + + # 根据index来做顺序的调转 + indices = torch.arange(batch_size, dtype=torch.long).to(device) + indices = indices.repeat_interleave(num_beams) + state.reorder_state(indices) + + tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length + # 记录生成好的token (batch_size', cur_len) + token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1) + dones = [False] * batch_size + + beam_scores = next_scores.view(-1) # batch_size * num_beams + + # 用来记录已经生成好的token的长度 + cur_len = token_ids.size(1) + + if max_len_a!=0: + # (bsz x num_beams, ) + if state.encoder_mask is not None: + max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length + else: + max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long) + real_max_length = max_lengths.max().item() + else: + real_max_length = max_length + if state.encoder_mask is not None: + max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length + else: + max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) + hypos = [ + BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size) + ] + # 0, num_beams, 2*num_beams, ... 
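+    # i.e. the row offset of every sample inside the flattened
+    # (batch_size * num_beams) dimension; adding _from_which_beam to these
+    # offsets later gives the indices used to reorder `state` and `token_ids`
+    # after each search step.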
+ batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) + + while cur_len < real_max_length: + scores = decoder.decode(token_ids, state) # (bsz x num_beams, vocab_size) + if repetition_penalty != 1.0: + token_scores = scores.gather(dim=1, index=token_ids) + lt_zero_mask = token_scores.lt(0).float() + ge_zero_mask = lt_zero_mask.eq(0).float() + token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores + scores.scatter_(dim=1, index=token_ids, src=token_scores) + + if _eos_token_id!=-1: + max_len_eos_mask = max_lengths.eq(cur_len+1) + eos_scores = scores[:, _eos_token_id] + # 如果已经达到最大长度,就把eos的分数加大 + scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e32, eos_scores) + + if do_sample: + if temperature > 0 and temperature != 1: + scores = scores / temperature + + # 多召回一个防止eos + scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1) + # 加上1e-12是为了避免https://github.com/pytorch/pytorch/pull/27523 + probs = F.softmax(scores, dim=-1) + 1e-12 + + # 保证至少有一个不是eos的值 + _tokens = torch.multinomial(probs, num_samples=num_beams + 1) # batch_size' x (num_beams+1) + + logits = probs.log() + # 防止全是这个beam的被选中了,且需要考虑eos被选择的情况 + _scores = logits.gather(dim=1, index=_tokens) # batch_size' x (num_beams+1) + _scores = _scores + beam_scores[:, None] # batch_size' x (num_beams+1) + # 从这里面再选择top的2*num_beam个 + _scores = _scores.view(batch_size, num_beams * (num_beams + 1)) + next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True) + _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1)) + next_tokens = _tokens.gather(dim=1, index=ids) # (batch_size, 2*num_beams) + from_which_beam = ids // (num_beams + 1) # (batch_size, 2*num_beams) + else: + scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) + _scores = scores + beam_scores[:, None] # (batch_size * num_beams, vocab_size) + _scores = _scores.view(batch_size, -1) # (batch_size, num_beams*vocab_size) + next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) # (bsz, 2*num_beams) + from_which_beam = ids // vocab_size # (batch_size, 2*num_beams) + next_tokens = ids % vocab_size # (batch_size, 2*num_beams) + + # 接下来需要组装下一个batch的结果。 + # 需要选定哪些留下来 + # next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True) + # next_tokens = next_tokens.gather(dim=1, index=sorted_inds) + # from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds) + + not_eos_mask = next_tokens.ne(_eos_token_id) # 为1的地方不是eos + keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams) # 为1的地方需要保留 + keep_mask = not_eos_mask.__and__(keep_mask) # 为1的地方是需要进行下一步search的 + + _next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1) + _from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams) # 上面的token是来自哪个beam + _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams) + beam_scores = _next_scores.view(-1) + + flag = True + if cur_len+1 == real_max_length: + eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0) + eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size) # 表示的是indice + eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1) # 表示的是从哪个beam获取得到的 + else: + # 将每个batch中在num_beam内的序列添加到结束中, 为1的地方需要结束了 + effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams + if effective_eos_mask.sum().gt(0): + 
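+                    # nonzero(as_tuple=True) below yields the (batch, beam) coordinates
+                    # of the beams that just emitted eos.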
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True) + # 是由于from_which_beam是 (batch_size, 2*num_beams)的,所以需要2*num_beams + eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind + eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx] # 获取真实的从哪个beam获取的eos + else: + flag = False + + if flag: + _token_ids = torch.cat([token_ids, _next_tokens], dim=-1) + for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), + eos_beam_idx.tolist()): + if not dones[batch_idx]: + score = next_scores[batch_idx, beam_ind].item() + # 之后需要在结尾新增一个eos + if _eos_token_id!=-1: + hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score) + else: + hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score) + + # 更改state状态, 重组token_ids + reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维 + state.reorder_state(reorder_inds) + # 重新组织token_ids的状态 + token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1) + + for batch_idx in range(batch_size): + dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \ + max_lengths[batch_idx*num_beams]==cur_len+1 + + cur_len += 1 + + if all(dones): + break + + # select the best hypotheses + tgt_len = token_ids.new_zeros(batch_size) + best = [] + + for i, hypotheses in enumerate(hypos): + best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] + # 把上面替换为非eos的词替换回eos + if _eos_token_id!=-1: + best_hyp = torch.cat([best_hyp, best_hyp.new_ones(1)*_eos_token_id]) + tgt_len[i] = len(best_hyp) + best.append(best_hyp) + + # generate target batch + decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id) + for i, hypo in enumerate(best): + decoded[i, :tgt_len[i]] = hypo + + return decoded + + +class BeamHypotheses(object): + def __init__(self, num_beams, max_length, length_penalty, early_stopping): + """ + Initialize n-best list of hypotheses. + """ + self.max_length = max_length - 1 # ignoring bos_token + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.num_beams = num_beams + self.hyp = [] + self.worst_score = 1e9 + + def __len__(self): + """ + Number of hypotheses in the list. + """ + return len(self.hyp) + + def add(self, hyp, sum_logprobs): + """ + Add a new hypothesis to the list. + """ + score = sum_logprobs / len(hyp) ** self.length_penalty + if len(self) < self.num_beams or score > self.worst_score: + self.hyp.append((score, hyp)) + if len(self) > self.num_beams: + sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)]) + del self.hyp[sorted_scores[0][1]] + self.worst_score = sorted_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs): + """ + If there are enough hypotheses and that none of the hypotheses being generated + can become better than the worst one in the heap, then we are done with this sentence. 
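+        Concretely, when early_stopping is False this returns True only once
+        worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty,
+        i.e. even a maximally long continuation could no longer beat the worst
+        hypothesis currently kept.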
+ """ + if len(self) < self.num_beams: + return False + elif self.early_stopping: + return True + else: + return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty + + +def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1): + """ + 根据top_k, top_p的值,将不满足的值置为filter_value的值 + + :param torch.Tensor logits: bsz x vocab_size + :param int top_k: 如果大于0,则只保留最top_k的词汇的概率,剩下的位置被置为filter_value + :param int top_p: 根据(http://arxiv.org/abs/1904.09751)设置的筛选方式 + :param float filter_value: + :param int min_tokens_to_keep: 每个sample返回的分布中有概率的词不会低于这个值 + :return: + """ + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits diff --git a/tests/helpers/data/modules/decoder/crf.json b/tests/helpers/data/modules/decoder/crf.json new file mode 100755 index 00000000..ff2d6689 --- /dev/null +++ b/tests/helpers/data/modules/decoder/crf.json @@ -0,0 +1 @@ +{"bio_logits": [[[-1.8154915571212769, -1.3753865957260132, -10001.513671875, -1.619813084602356, -10001.79296875], [-1.742034673690796, -1.5048011541366577, -2.042131185531616, -1.2594754695892334, -1.6648437976837158], [-1.5522804260253906, -1.2926381826400757, -1.8607124090194702, -1.6692707538604736, -1.7734650373458862], [-1.6101375818252563, -1.3285458087921143, -1.7735439538955688, -1.5734118223190308, -1.8438279628753662], [-1.6522153615951538, -1.2640260457992554, -1.9092718362808228, -1.6192445755004883, -1.7168875932693481], [-1.4932769536972046, -1.4628725051879883, -1.9623159170150757, -1.497014045715332, -1.7177777290344238], [-1.8419824838638306, -2.1428799629211426, -1.4285861253738403, -1.2972710132598877, -1.5546820163726807], [-1.671349048614502, -1.4115079641342163, -1.624293565750122, -1.537371277809143, -1.8563929796218872], [-1.5080815553665161, -1.3281997442245483, -1.7912147045135498, -1.5656323432922363, -1.980512022972107], [-2.0562098026275635, -1.4711416959762573, -1.5297126770019531, -1.7554184198379517, -1.3744999170303345]], [[-1.3193378448486328, -1.997290849685669, -10002.0751953125, -1.3334847688674927, -10001.5712890625], [-1.229069471359253, -1.2702847719192505, -2.0717740058898926, -1.9828989505767822, -1.8136863708496094], [-1.8161871433258057, -1.4339262247085571, -1.4476666450500488, -1.8693819046020508, -1.562330722808838], [-1.897119402885437, -1.5767627954483032, -1.54145348072052, -1.6185026168823242, 
-1.4649395942687988], [-1.8498220443725586, -1.264282464981079, -1.7192784547805786, -1.8041315078735352, -1.530255913734436], [-1.1517643928527832, -1.6473538875579834, -1.5833101272583008, -1.9973593950271606, -1.894622802734375], [-1.7796387672424316, -1.8036197423934937, -1.2666513919830322, -1.4641741514205933, -1.8736846446990967], [-1.555580496788025, -1.5448863506317139, -1.609066128730774, -1.5487936735153198, -1.8138916492462158], [-1.8701002597808838, -2.0567376613616943, -1.6318782567977905, -1.2336504459381104, -1.4643338918685913], [-1.6615228652954102, -1.9764257669448853, -1.277781367301941, -1.3614437580108643, -1.990394949913025]], [[-1.74202299118042, -1.659791111946106, -10001.9951171875, -1.0417697429656982, -10001.9248046875], [-1.2423228025436401, -1.7404581308364868, -1.7569608688354492, -1.5077661275863647, -1.9528108835220337], [-1.7840592861175537, -1.50230872631073, -1.4460601806640625, -1.9473626613616943, -1.4641118049621582], [-1.6109998226165771, -2.0336639881134033, -1.3807575702667236, -1.221280574798584, -2.0938124656677246], [-1.8956525325775146, -1.6966334581375122, -1.8089725971221924, -1.9510140419006348, -1.020185947418213], [-1.7131900787353516, -1.7260419130325317, -2.161870241165161, -1.2767468690872192, -1.3956587314605713], [-1.7567639350891113, -1.1352611780166626, -1.7109652757644653, -1.8825695514678955, -1.7534843683242798], [-1.826012372970581, -1.9964908361434937, -1.7898284196853638, -1.2279980182647705, -1.413594365119934], [-1.522060513496399, -1.56121826171875, -1.5711766481399536, -1.4620665311813354, -2.0226776599884033], [-1.3122025728225708, -2.0931777954101562, -1.8858696222305298, -1.831908106803894, -1.2184979915618896]], [[-1.3956559896469116, -1.8315693140029907, -10001.48046875, -1.844576358795166, -10001.5771484375], [-1.562046766281128, -1.7216087579727173, -1.5044764280319214, -1.4362742900848389, -1.8867106437683105], [-1.5304349660873413, -1.5527287721633911, -1.5590341091156006, -1.6369349956512451, -1.7899152040481567], [-1.6007282733917236, -2.054649829864502, -1.9757367372512817, -1.4219664335250854, -1.2371348142623901], [-1.841418981552124, -1.8178046941757202, -1.5939710140228271, -1.2179311513900757, -1.7144266366958618], [-1.6715152263641357, -1.5060933828353882, -1.6629694700241089, -1.633326530456543, -1.5827515125274658], [-1.9413940906524658, -1.853175163269043, -1.6390701532363892, -1.2217824459075928, -1.5564061403274536], [-1.746218204498291, -1.7089520692825317, -1.6738371849060059, -1.627657175064087, -1.344780445098877], [-1.1776174306869507, -1.629957675933838, -1.79096519947052, -1.7566864490509033, -1.853833556175232], [-1.4880272150039673, -1.4722591638565063, -1.631064534187317, -1.9562634229660034, -1.5718109607696533]]], "bio_scores": [-1.3754, -4.5403, -8.7047, -12.8693], "bio_path": [[1], [3, 0, 1, 1], [3, 0, 1, 3, 4, 3, 1, 3], [0, 1, 1, 0, 3, 0, 3, 0, 1, 0]], "bio_trans_m": [[-0.095858134329319, 0.01011368352919817, -0.33539193868637085, -0.20200660824775696, 0.136741504073143], [0.5436117649078369, 0.37222158908843994, -0.15174923837184906, 0.10455792397260666, -0.35702475905418396], [0.3681447505950928, -0.6996435523033142, -0.002348324516788125, 0.5087339282035828, -0.08750446885824203], [0.6505969762802124, 0.0064192176796495914, -0.10901711881160736, -0.24849674105644226, -0.1375938355922699], [-0.019853945821523666, -0.9098508954048157, 0.06740495562553406, 0.2244909256696701, -0.29204151034355164]], "bio_seq_lens": [1, 4, 8, 10], "bmes_logits": [[[-10002.5830078125, -20002.54296875, 
-10001.9765625, -2.033155679702759, -10001.712890625, -20001.68359375, -10002.4130859375, -2.1159744262695312], [-1.870416283607483, -2.2075278759002686, -1.9922529458999634, -2.1696650981903076, -2.4956214427948, -2.1040704250335693, -2.065218925476074, -1.869700312614441], [-1.8947919607162476, -2.398089647293091, -2.1316606998443604, -1.6458176374435425, -2.001098871231079, -2.362668514251709, -2.513232707977295, -1.9884836673736572], [-1.5058399438858032, -2.3359181880950928, -2.382275342941284, -2.4573683738708496, -1.7870502471923828, -2.342841148376465, -2.1982951164245605, -2.0483522415161133], [-2.0845396518707275, -2.0447516441345215, -1.7635326385498047, -1.9375617504119873, -2.530120611190796, -1.8380637168884277, -2.099860906600952, -2.666682481765747], [-2.299673557281494, -2.3165550231933594, -1.9403637647628784, -1.8729832172393799, -1.8798956871032715, -1.8799573183059692, -2.2314014434814453, -2.39471173286438], [-1.9613308906555176, -2.136000633239746, -2.1178860664367676, -2.1553683280944824, -1.7840471267700195, -2.4148807525634766, -2.4621479511260986, -1.817263126373291], [-2.056917428970337, -2.5026133060455322, -1.9233015775680542, -2.0078444480895996, -2.064028024673462, -1.776533842086792, -2.3748488426208496, -2.114560127258301], [-2.3671767711639404, -1.7896978855133057, -2.416537284851074, -2.26574444770813, -2.2460145950317383, -1.7739624977111816, -1.9555294513702393, -2.045677661895752], [-2.3571174144744873, -1.820650577545166, -2.2781612873077393, -1.9325084686279297, -1.863953948020935, -2.2260994911193848, -2.5020244121551514, -1.8891260623931885]], [[-2.0461926460266113, -10002.0625, -10001.712890625, -2.251368761062622, -2.2985825538635254, -10002.146484375, -10002.0185546875, -2.225799560546875], [-1.9879356622695923, -2.4706358909606934, -2.3151662349700928, -1.5818747282028198, -2.329188346862793, -2.1170380115509033, -2.159011125564575, -1.9593485593795776], [-2.2397706508636475, -2.2388737201690674, -1.826286792755127, -2.444268226623535, -1.7793290615081787, -2.402519941329956, -1.8540253639221191, -2.09319806098938], [-1.7938345670700073, -2.525993585586548, -1.9962739944458008, -1.9414381980895996, -2.5183513164520264, -2.5057737827301025, -1.7933388948440552, -1.925837755203247], [-2.2330663204193115, -2.098536491394043, -1.9872602224349976, -1.7660422325134277, -2.5269722938537598, -1.9648237228393555, -1.80750572681427, -2.551790475845337], [-1.802718162536621, -2.4936702251434326, -1.846991777420044, -2.6299049854278564, -1.8180453777313232, -2.010246992111206, -1.9285591840744019, -2.5121750831604004], [-1.7665618658065796, -2.2445054054260254, -1.822519063949585, -2.5471863746643066, -2.719733715057373, -1.9708809852600098, -1.7871110439300537, -2.2026400566101074], [-2.2046854496002197, -2.375577926635742, -1.9162014722824097, -2.397550344467163, -1.9547137022018433, -1.759222149848938, -1.818831443786621, -2.4931435585021973], [-1.9187703132629395, -2.5046753883361816, -1.871201515197754, -2.3421711921691895, -2.372368335723877, -1.883248209953308, -1.8868682384490967, -2.0830271244049072], [-2.406679630279541, -1.7564219236373901, -2.340674877166748, -1.8392919301986694, -2.3711328506469727, -1.913435935974121, -2.221808433532715, -2.019878625869751]], [[-10001.7607421875, -20002.30078125, -10001.9677734375, -1.7931804656982422, -10002.2451171875, -20002.15234375, -10002.208984375, -2.4127495288848877], [-2.162931442260742, -2.121459484100342, -2.4020097255706787, -2.5620131492614746, -1.7713403701782227, -2.1945695877075195, 
-1.8392865657806396, -1.8513271808624268], [-2.2151875495910645, -1.9279260635375977, -2.24403977394104, -2.1955597400665283, -2.2283377647399902, -1.7366830110549927, -2.634793519973755, -1.757084608078003], [-1.813708782196045, -1.93169105052948, -2.2419192790985107, -2.307635545730591, -2.19914174079895, -2.070988178253174, -2.0030927658081055, -2.1678688526153564], [-2.118651866912842, -1.867727518081665, -2.312565326690674, -2.274792194366455, -1.9973562955856323, -2.000102996826172, -1.8425841331481934, -2.3635623455047607], [-2.435579538345337, -1.7167878150939941, -2.3040761947631836, -1.657408595085144, -2.462364912033081, -2.2767324447631836, -1.7957141399383545, -2.425132989883423], [-1.806656837463379, -1.7759110927581787, -2.5295629501342773, -1.9216285943984985, -2.2615668773651123, -1.8556532859802246, -2.4842538833618164, -2.3384106159210205], [-1.9859262704849243, -1.6575560569763184, -2.2854154109954834, -1.9267034530639648, -2.5214226245880127, -2.0166244506835938, -2.479127883911133, -2.0595011711120605], [-2.0371243953704834, -2.2420313358306885, -2.0946967601776123, -2.2463889122009277, -1.8954271078109741, -1.942257285118103, -2.0445871353149414, -2.1946396827697754], [-2.0210611820220947, -2.362877130508423, -1.9862446784973145, -1.8275481462478638, -2.140009880065918, -1.869648814201355, -2.6818318367004395, -2.0021097660064697]], [[-1.986312985420227, -10002.50390625, -10002.0361328125, -1.908732295036316, -2.21740984916687, -10002.1318359375, -10002.1044921875, -1.87873113155365], [-1.9292036294937134, -2.163956880569458, -2.3703503608703613, -1.939669132232666, -1.8776776790618896, -2.4469380378723145, -2.423905611038208, -1.7453217506408691], [-2.0289347171783447, -2.520860195159912, -2.5013701915740967, -2.078547477722168, -1.9699862003326416, -1.8206181526184082, -1.7796630859375, -2.1984922885894775], [-1.8523262739181519, -1.978093147277832, -2.558772087097168, -2.498471260070801, -1.9756053686141968, -1.8080697059631348, -1.9115748405456543, -2.357147216796875], [-2.314960479736328, -2.2433876991271973, -1.6113512516021729, -2.19716477394104, -1.78402578830719, -2.343987226486206, -2.3425848484039307, -2.084155797958374], [-2.002289056777954, -2.2630276679992676, -1.887984275817871, -2.044983386993408, -2.217646360397339, -1.9103771448135376, -2.154231548309326, -2.2321436405181885], [-2.199540853500366, -2.063075065612793, -1.813851237297058, -2.3199379444122314, -1.7984188795089722, -2.4952447414398193, -2.4516515731811523, -1.7922154664993286], [-2.509786367416382, -1.79443359375, -1.8561275005340576, -2.2977330684661865, -2.2080044746398926, -1.7294546365737915, -2.4617154598236084, -2.0944302082061768], [-2.491340160369873, -2.403804063796997, -1.8452543020248413, -1.6882175207138062, -2.5513625144958496, -2.294516086578369, -1.9522627592086792, -1.8124374151229858], [-2.1524035930633545, -2.2049806118011475, -2.3353655338287354, -2.317572832107544, -2.2914233207702637, -1.8211665153503418, -1.69517982006073, -2.0270023345947266]]], "bmes_scores": [-2.0332, -6.1623, -1.7932, -16.7561], "bmes_path": [[3], [7, 3, 4, 6], [3], [3, 4, 5, 6, 7, 3, 4, 5, 6, 7]], "bmes_trans_m": [[0.47934335470199585, -0.2151593416929245, -0.12467780709266663, -0.44244644045829773, 0.16480575501918793, -0.006573359947651625, -1.187401294708252, -0.17424514889717102], [-0.03494556248188019, -0.8173441290855408, -0.2682552933692932, 0.18933893740177155, 0.2203899323940277, 0.3905894160270691, -0.007638207171112299, 0.19527725875377655], [-0.2779119908809662, 
-0.37053248286247253, 0.34394705295562744, -0.26433902978897095, -0.0001995275670196861, -0.39156094193458557, -0.035449881106615067, 0.02454843744635582], [-0.01391045656055212, 0.3419516384601593, -0.48559853434562683, -0.5893992781639099, 0.9119477272033691, 0.1731061041355133, -0.15039317309856415, 0.1523006409406662], [0.4866299033164978, 0.28264448046684265, -0.25895795226097107, 0.0404033362865448, -0.060920555144548416, 0.12364576756954193, 0.1294233351945877, 0.2434755265712738], [-0.04159824922680855, 0.25353407859802246, 0.12913571298122406, -0.036356933414936066, -0.18522876501083374, -0.5329958200454712, 0.2505933344364166, 0.26512718200683594], [-0.2509276270866394, 0.3572998046875, 0.01873799040913582, -0.30620086193084717, -0.09893298894166946, -0.37399813532829285, -0.6530448198318481, -0.17514197528362274], [-0.29702028632164, 0.680363118648529, -0.6010262370109558, 0.17669369280338287, 0.45010149478912354, -0.1026386097073555, 0.34120017290115356, -0.04910941794514656]], "bmes_seq_lens": [1, 4, 1, 10]} \ No newline at end of file diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/torch/__init__.py b/tests/models/torch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/torch/model_runner.py b/tests/models/torch/model_runner.py new file mode 100755 index 00000000..e6bb56d2 --- /dev/null +++ b/tests/models/torch/model_runner.py @@ -0,0 +1,142 @@ +""" +此模块可以非常方便的测试模型。 +若你的模型属于:文本分类,序列标注,自然语言推理(NLI),可以直接使用此模块测试 +若模型不属于上述类别,也可以自己准备假数据,设定loss和metric进行测试 + +此模块的测试仅保证模型能使用fastNLP进行训练和测试,不测试模型实际性能 + +Example:: + + # import 全大写变量... + from model_runner import * + + # 测试一个文本分类模型 + init_emb = (VOCAB_SIZE, 50) + model = SomeModel(init_emb, num_cls=NUM_CLS) + RUNNER.run_model_with_task(TEXT_CLS, model) + + # 序列标注模型 + RUNNER.run_model_with_task(POS_TAGGING, model) + + # NLI模型 + RUNNER.run_model_with_task(NLI, model) + + # 自定义模型 + RUNNER.run_model(model, data=get_mydata(), + loss=Myloss(), metrics=Mymetric()) +""" +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch import optim +from fastNLP import Trainer, Evaluator, DataSet, Callback +from fastNLP import Accuracy +from random import randrange +from fastNLP import TorchDataLoader + +VOCAB_SIZE = 100 +NUM_CLS = 100 +MAX_LEN = 10 +N_SAMPLES = 100 +N_EPOCHS = 1 +BATCH_SIZE = 5 + +TEXT_CLS = 'text_cls' +POS_TAGGING = 'pos_tagging' +NLI = 'nli' + +class ModelRunner(): + class Checker(Callback): + def on_backward_begin(self, trainer, outputs): + assert outputs['loss'].to('cpu').numpy().isfinate() + + def gen_seq(self, length, vocab_size): + """generate fake sequence indexes with given length""" + # reserve 0 for padding + return [randrange(1, vocab_size) for _ in range(length)] + + def gen_var_seq(self, max_len, vocab_size): + """generate fake sequence indexes in variant length""" + length = randrange(3, max_len) # at least 3 words in a seq + return self.gen_seq(length, vocab_size) + + def prepare_text_classification_data(self): + index = 'index' + ds = DataSet({index: list(range(N_SAMPLES))}) + ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), + field_name=index, new_field_name='words') + ds.apply_field(lambda x: randrange(NUM_CLS), + field_name=index, new_field_name='target') + ds.apply_field(len, 'words', 'seq_len') + dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) + return dl + + def prepare_pos_tagging_data(self): + index = 'index' + ds = DataSet({index: 
list(range(N_SAMPLES))}) + ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), + field_name=index, new_field_name='words') + ds.apply_field(lambda x: self.gen_seq(len(x), NUM_CLS), + field_name='words', new_field_name='target') + ds.apply_field(len, 'words', 'seq_len') + dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) + return dl + + def prepare_nli_data(self): + index = 'index' + ds = DataSet({index: list(range(N_SAMPLES))}) + ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), + field_name=index, new_field_name='words1') + ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), + field_name=index, new_field_name='words2') + ds.apply_field(lambda x: randrange(NUM_CLS), + field_name=index, new_field_name='target') + ds.apply_field(len, 'words1', 'seq_len1') + ds.apply_field(len, 'words2', 'seq_len2') + dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) + return dl + + def run_text_classification(self, model, data=None): + if data is None: + data = self.prepare_text_classification_data() + metric = Accuracy() + self.run_model(model, data, metric) + + def run_pos_tagging(self, model, data=None): + if data is None: + data = self.prepare_pos_tagging_data() + metric = Accuracy() + self.run_model(model, data, metric) + + def run_nli(self, model, data=None): + if data is None: + data = self.prepare_nli_data() + metric = Accuracy() + self.run_model(model, data, metric) + + def run_model(self, model, data, metrics): + """run a model, test if it can run with fastNLP""" + print('testing model:', model.__class__.__name__) + tester = Evaluator(model, data, metrics={'metric': metrics}, driver='torch') + before_train = tester.run() + optimizer = optim.SGD(model.parameters(), lr=1e-3) + trainer = Trainer(model, driver='torch', train_dataloader=data, + n_epochs=N_EPOCHS, optimizers=optimizer) + trainer.run() + after_train = tester.run() + for metric_name, v1 in before_train.items(): + assert metric_name in after_train + # # at least we can sure model params changed, even if we don't know performance + # v2 = after_train[metric_name] + # assert v1 != v2 + + def run_model_with_task(self, task, model): + """run a model with certain task""" + TASKS = { + TEXT_CLS: self.run_text_classification, + POS_TAGGING: self.run_pos_tagging, + NLI: self.run_nli, + } + assert task in TASKS + TASKS[task](model) + +RUNNER = ModelRunner() diff --git a/tests/models/torch/test_biaffine_parser.py b/tests/models/torch/test_biaffine_parser.py new file mode 100755 index 00000000..dc72360d --- /dev/null +++ b/tests/models/torch/test_biaffine_parser.py @@ -0,0 +1,91 @@ +import pytest +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from fastNLP.models.torch.biaffine_parser import BiaffineParser +from fastNLP import Metric, seq_len_to_mask +from .model_runner import * + + +class ParserMetric(Metric): + r""" + 评估parser的性能 + + """ + + def __init__(self): + super().__init__() + self.num_arc = 0 + self.num_label = 0 + self.num_sample = 0 + + def get_metric(self, reset=True): + res = {'UAS': self.num_arc * 1.0 / self.num_sample, 'LAS': self.num_label * 1.0 / self.num_sample} + if reset: + self.num_sample = self.num_label = self.num_arc = 0 + return res + + def update(self, pred1, pred2, target1, target2, seq_len=None): + r""" + + :param pred1: 边预测logits + :param pred2: label预测logits + :param target1: 真实边的标注 + :param target2: 真实类别的标注 + :param seq_len: 序列长度 + :return dict: 评估结果:: + + UAS: 不带label时, 边预测的准确率 + LAS: 同时预测边和label的准确率 + """ + if seq_len is None: + seq_mask = 
pred1.new_ones(pred1.size(), dtype=torch.long) + else: + seq_mask = seq_len_to_mask(seq_len.long()).long() + # mask out tag + seq_mask[:, 0] = 0 + head_pred_correct = (pred1 == target1).long() * seq_mask + label_pred_correct = (pred2 == target2).long() * head_pred_correct + self.num_arc += head_pred_correct.sum().item() + self.num_label += label_pred_correct.sum().item() + self.num_sample += seq_mask.sum().item() + + +def prepare_parser_data(): + index = 'index' + ds = DataSet({index: list(range(N_SAMPLES))}) + ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE), + field_name=index, new_field_name='words1') + ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS), + field_name='words1', new_field_name='words2') + # target1 is heads, should in range(0, len(words)) + ds.apply_field(lambda x: RUNNER.gen_seq(len(x), len(x)), + field_name='words1', new_field_name='target1') + ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS), + field_name='words1', new_field_name='target2') + ds.apply_field(len, field_name='words1', new_field_name='seq_len') + dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) + return dl + + +@pytest.mark.torch +class TestBiaffineParser: + def test_train(self): + model = BiaffineParser(embed=(VOCAB_SIZE, 10), + pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10, + rnn_hidden_size=10, + arc_mlp_size=10, + label_mlp_size=10, + num_label=NUM_CLS, encoder='var-lstm') + ds = prepare_parser_data() + RUNNER.run_model(model, ds, metrics=ParserMetric()) + + def test_train2(self): + model = BiaffineParser(embed=(VOCAB_SIZE, 10), + pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10, + rnn_hidden_size=16, + arc_mlp_size=10, + label_mlp_size=10, + num_label=NUM_CLS, encoder='transformer') + ds = prepare_parser_data() + RUNNER.run_model(model, ds, metrics=ParserMetric()) diff --git a/tests/models/torch/test_cnn_text_classification.py b/tests/models/torch/test_cnn_text_classification.py new file mode 100755 index 00000000..bbe1b718 --- /dev/null +++ b/tests/models/torch/test_cnn_text_classification.py @@ -0,0 +1,33 @@ +import pytest + +from .model_runner import * +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from fastNLP.models.torch.cnn_text_classification import CNNText + + +@pytest.mark.torch +class TestCNNText: + def init_model(self, kernel_sizes, kernel_nums=(1,3,5)): + model = CNNText((VOCAB_SIZE, 30), + NUM_CLS, + kernel_nums=kernel_nums, + kernel_sizes=kernel_sizes) + return model + + def test_case1(self): + # 测试能否正常运行CNN + model = self.init_model((1,3,5)) + RUNNER.run_model_with_task(TEXT_CLS, model) + + def test_init_model(self): + with pytest.raises(Exception): + self.init_model(2, 4) + with pytest.raises(Exception): + self.init_model((2,)) + + def test_output(self): + model = self.init_model((3,), (1,)) + global MAX_LEN + MAX_LEN = 2 + RUNNER.run_model_with_task(TEXT_CLS, model) diff --git a/tests/models/torch/test_seq2seq_generator.py b/tests/models/torch/test_seq2seq_generator.py new file mode 100755 index 00000000..33c3f85a --- /dev/null +++ b/tests/models/torch/test_seq2seq_generator.py @@ -0,0 +1,73 @@ +import pytest + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + from fastNLP.models.torch import LSTMSeq2SeqModel, TransformerSeq2SeqModel + import torch + from fastNLP.embeddings.torch import StaticEmbedding + +from fastNLP import Vocabulary, DataSet +from fastNLP import Trainer, Accuracy +from fastNLP import Callback, TorchDataLoader + + +def prepare_env(): + vocab = Vocabulary().add_word_lst("This is a test 
.".split()) + vocab.add_word_lst("Another test !".split()) + embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) + + src_words_idx = [[3, 1, 2], [1, 2]] + # tgt_words_idx = [[1, 2, 3, 4], [2, 3]] + src_seq_len = [3, 2] + # tgt_seq_len = [4, 2] + + ds = DataSet({'src_tokens': src_words_idx, 'src_seq_len': src_seq_len, 'tgt_tokens': src_words_idx, + 'tgt_seq_len':src_seq_len}) + + dl = TorchDataLoader(ds, batch_size=32) + return embed, dl + + +class ExitCallback(Callback): + def __init__(self): + super().__init__() + + def on_valid_end(self, trainer, results): + if results['acc#acc'] == 1: + raise KeyboardInterrupt() + + +@pytest.mark.torch +class TestSeq2SeqGeneratorModel: + def test_run(self): + # 检测是否能够使用SequenceGeneratorModel训练, 透传预测 + embed, dl = prepare_env() + model1 = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, + pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, + dim_ff=20, dropout=0.1, + bind_encoder_decoder_embed=True, + bind_decoder_input_output_embed=True) + optimizer = torch.optim.Adam(model1.parameters(), lr=1e-3) + trainer = Trainer(model1, driver='torch', optimizers=optimizer, train_dataloader=dl, + n_epochs=100, evaluate_dataloaders=dl, metrics={'acc': Accuracy()}, + evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'], + 'seq_len': x['tgt_seq_len'], + **x}, + callbacks=ExitCallback()) + + trainer.run() + + embed, dl = prepare_env() + model2 = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, + num_layers=1, hidden_size=20, dropout=0.1, + bind_encoder_decoder_embed=True, + bind_decoder_input_output_embed=True, attention=True) + optimizer = torch.optim.Adam(model2.parameters(), lr=0.01) + trainer = Trainer(model2, driver='torch', optimizers=optimizer, train_dataloader=dl, + n_epochs=100, evaluate_dataloaders=dl, metrics={'acc': Accuracy()}, + evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'], + 'seq_len': x['tgt_seq_len'], + **x}, + callbacks=ExitCallback()) + trainer.run() diff --git a/tests/models/torch/test_seq2seq_model.py b/tests/models/torch/test_seq2seq_model.py new file mode 100755 index 00000000..ee1775d3 --- /dev/null +++ b/tests/models/torch/test_seq2seq_model.py @@ -0,0 +1,113 @@ + +import pytest +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from fastNLP.models.torch.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel + from fastNLP import Vocabulary + from fastNLP.embeddings.torch import StaticEmbedding + import torch + from torch import optim + import torch.nn.functional as F +from fastNLP import seq_len_to_mask + + +def prepare_env(): + vocab = Vocabulary().add_word_lst("This is a test .".split()) + vocab.add_word_lst("Another test !".split()) + embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) + + src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) + tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) + src_seq_len = torch.LongTensor([3, 2]) + tgt_seq_len = torch.LongTensor([4, 2]) + + return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len + + +def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len): + optimizer = optim.Adam(model.parameters(), lr=1e-2) + mask = seq_len_to_mask(tgt_seq_len).eq(0) + target = tgt_words_idx.masked_fill(mask, -100) + + for i in range(100): + optimizer.zero_grad() + pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred'] # bsz x max_len x vocab_size + loss = F.cross_entropy(pred.transpose(1, 2), target) + 
loss.backward() + optimizer.step() + + right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum() + return right_count + + +@pytest.mark.torch +class TestTransformerSeq2SeqModel: + def test_run(self): + # 测试能否跑通 + embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() + for pos_embed in ['learned', 'sin']: + model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, + pos_embed=pos_embed, max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, + bind_encoder_decoder_embed=True, + bind_decoder_input_output_embed=True) + + output = model(src_words_idx, tgt_words_idx, src_seq_len) + assert (output['pred'].size() == (2, 4, len(embed))) + + for bind_encoder_decoder_embed in [True, False]: + tgt_embed = None + for bind_decoder_input_output_embed in [True, False]: + if bind_encoder_decoder_embed == False: + tgt_embed = embed + + model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, + pos_embed='sin', max_position=20, num_layers=2, + d_model=30, n_head=6, dim_ff=20, dropout=0.1, + bind_encoder_decoder_embed=bind_encoder_decoder_embed, + bind_decoder_input_output_embed=bind_decoder_input_output_embed) + + output = model(src_words_idx, tgt_words_idx, src_seq_len) + assert (output['pred'].size() == (2, 4, len(embed))) + + def test_train(self): + # 测试能否train到overfit + embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() + + model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, + pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, + bind_encoder_decoder_embed=True, + bind_decoder_input_output_embed=True) + + right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) + assert(right_count == tgt_words_idx.nelement()) + + +@pytest.mark.torch +class TestLSTMSeq2SeqModel: + def test_run(self): + # 测试能否跑通 + embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() + + for bind_encoder_decoder_embed in [True, False]: + tgt_embed = None + for bind_decoder_input_output_embed in [True, False]: + if bind_encoder_decoder_embed == False: + tgt_embed = embed + model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, + num_layers=2, hidden_size=20, dropout=0.1, + bind_encoder_decoder_embed=bind_encoder_decoder_embed, + bind_decoder_input_output_embed=bind_decoder_input_output_embed) + output = model(src_words_idx, tgt_words_idx, src_seq_len) + assert (output['pred'].size() == (2, 4, len(embed))) + + def test_train(self): + embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() + + model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, + num_layers=1, hidden_size=20, dropout=0.1, + bind_encoder_decoder_embed=True, + bind_decoder_input_output_embed=True) + + right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) + assert (right_count == tgt_words_idx.nelement()) + diff --git a/tests/models/torch/test_sequence_labeling.py b/tests/models/torch/test_sequence_labeling.py new file mode 100755 index 00000000..b05b07fc --- /dev/null +++ b/tests/models/torch/test_sequence_labeling.py @@ -0,0 +1,47 @@ +import pytest +from .model_runner import * +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from fastNLP.models.torch.sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF + + +@pytest.mark.torch +class TestBiLSTM: + def test_case1(self): + # 测试能否正常运行CNN + init_emb = 
(VOCAB_SIZE, 30) + model = BiLSTMCRF(init_emb, + hidden_size=30, + num_classes=NUM_CLS) + + dl = RUNNER.prepare_pos_tagging_data() + metric = Accuracy() + RUNNER.run_model(model, dl, metric) + + +@pytest.mark.torch +class TestSeqLabel: + def test_case1(self): + # 测试能否正常运行CNN + init_emb = (VOCAB_SIZE, 30) + model = SeqLabeling(init_emb, + hidden_size=30, + num_classes=NUM_CLS) + + dl = RUNNER.prepare_pos_tagging_data() + metric = Accuracy() + RUNNER.run_model(model, dl, metric) + + +@pytest.mark.torch +class TestAdvSeqLabel: + def test_case1(self): + # 测试能否正常运行CNN + init_emb = (VOCAB_SIZE, 30) + model = AdvSeqLabel(init_emb, + hidden_size=30, + num_classes=NUM_CLS) + + dl = RUNNER.prepare_pos_tagging_data() + metric = Accuracy() + RUNNER.run_model(model, dl, metric) \ No newline at end of file diff --git a/tests/modules/torch/__init__.py b/tests/modules/torch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/modules/torch/decoder/__init__.py b/tests/modules/torch/decoder/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/tests/modules/torch/decoder/test_CRF.py b/tests/modules/torch/decoder/test_CRF.py new file mode 100755 index 00000000..9cdf240a --- /dev/null +++ b/tests/modules/torch/decoder/test_CRF.py @@ -0,0 +1,327 @@ +import pytest +import os +from fastNLP import Vocabulary + + +@pytest.mark.torch +class TestCRF: + def test_case1(self): + from fastNLP.modules.torch.decoder.crf import allowed_transitions + # 检查allowed_transitions()能否正确使用 + + id2label = {0: 'B', 1: 'I', 2:'O'} + expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), + (2, 4), (3, 0), (3, 2)} + assert expected_res == set(allowed_transitions(id2label, include_start_end=True)) + + id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} + expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} + assert (expected_res == set( + allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) + + id2label = {0: 'B', 1: 'I', 2:'O', 3: '', 4:""} + allowed_transitions(id2label, include_start_end=True) + + labels = ['O'] + for label in ['X', 'Y']: + for tag in 'BI': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx:label for idx, label in enumerate(labels)} + expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), + (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), + (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} + assert (expected_res == set(allowed_transitions(id2label, include_start_end=True))) + + labels = [] + for label in ['X', 'Y']: + for tag in 'BMES': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), + (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), + (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} + assert (expected_res == set( + allowed_transitions(id2label, include_start_end=True))) + + def test_case11(self): + # 测试自动推断encoding类型 + from fastNLP.modules.torch.decoder.crf import allowed_transitions + + id2label = {0: 'B', 1: 'I', 2: 'O'} + expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), + (2, 4), (3, 0), (3, 2)} + assert (expected_res == set(allowed_transitions(id2label, include_start_end=True))) + + id2label = {0: 
'B', 1: 'M', 2: 'E', 3: 'S'} + expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} + assert (expected_res == set( + allowed_transitions(id2label, include_start_end=True))) + + id2label = {0: 'B', 1: 'I', 2: 'O', 3: '', 4: ""} + allowed_transitions(id2label, include_start_end=True) + + labels = ['O'] + for label in ['X', 'Y']: + for tag in 'BI': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), + (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), + (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} + assert (expected_res == set(allowed_transitions(id2label, include_start_end=True))) + + labels = [] + for label in ['X', 'Y']: + for tag in 'BMES': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), + (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), + (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} + assert (expected_res == set( + allowed_transitions(id2label, include_start_end=True))) + + def test_case12(self): + # 测试能否通过vocab生成转移矩阵 + from fastNLP.modules.torch.decoder.crf import allowed_transitions + + id2label = {0: 'B', 1: 'I', 2: 'O'} + vocab = Vocabulary(unknown=None, padding=None) + for idx, tag in id2label.items(): + vocab.add_word(tag) + expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), + (2, 4), (3, 0), (3, 2)} + assert (expected_res == set(allowed_transitions(vocab, include_start_end=True))) + + id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'} + vocab = Vocabulary(unknown=None, padding=None) + for idx, tag in id2label.items(): + vocab.add_word(tag) + expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} + assert (expected_res == set( + allowed_transitions(vocab, include_start_end=True))) + + id2label = {0: 'B', 1: 'I', 2: 'O', 3: '', 4: ""} + vocab = Vocabulary() + for idx, tag in id2label.items(): + vocab.add_word(tag) + allowed_transitions(vocab, include_start_end=True) + + labels = ['O'] + for label in ['X', 'Y']: + for tag in 'BI': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), + (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), + (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} + vocab = Vocabulary(unknown=None, padding=None) + for idx, tag in id2label.items(): + vocab.add_word(tag) + assert (expected_res == set(allowed_transitions(vocab, include_start_end=True))) + + labels = [] + for label in ['X', 'Y']: + for tag in 'BMES': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + vocab = Vocabulary(unknown=None, padding=None) + for idx, tag in id2label.items(): + vocab.add_word(tag) + expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), + (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), + (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} + assert 
(expected_res == set( + allowed_transitions(vocab, include_start_end=True))) + + # def test_case2(self): + # # 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。 + # pass + # import torch + # from fastNLP import seq_len_to_mask + # + # labels = ['O'] + # for label in ['X', 'Y']: + # for tag in 'BI': + # labels.append('{}-{}'.format(tag, label)) + # id2label = {idx: label for idx, label in enumerate(labels)} + # num_tags = len(id2label) + # max_len = 10 + # batch_size = 4 + # bio_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log() + # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions + # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label), + # include_start_end_transitions=False) + # bio_trans_m = allen_CRF.transitions + # bio_seq_lens = torch.randint(1, max_len, size=(batch_size,)) + # bio_seq_lens[0] = 1 + # bio_seq_lens[-1] = max_len + # mask = seq_len_to_mask(bio_seq_lens) + # allen_res = allen_CRF.viterbi_tags(bio_logits, mask) + # + # from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions + # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, + # include_start_end=True)) + # fast_CRF.trans_m = bio_trans_m + # fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True) + # bio_scores = [round(score, 4) for _, score in allen_res] + # # score equal + # self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()]) + # # seq equal + # bio_path = [_ for _, score in allen_res] + # self.assertListEqual(bio_path, fast_res[0]) + # + # labels = [] + # for label in ['X', 'Y']: + # for tag in 'BMES': + # labels.append('{}-{}'.format(tag, label)) + # id2label = {idx: label for idx, label in enumerate(labels)} + # num_tags = len(id2label) + # + # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions + # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label), + # include_start_end_transitions=False) + # bmes_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log() + # bmes_trans_m = allen_CRF.transitions + # bmes_seq_lens = torch.randint(1, max_len, size=(batch_size,)) + # bmes_seq_lens[0] = 1 + # bmes_seq_lens[-1] = max_len + # mask = seq_len_to_mask(bmes_seq_lens) + # allen_res = allen_CRF.viterbi_tags(bmes_logits, mask) + # + # from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions + # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, + # encoding_type='BMES', + # include_start_end=True)) + # fast_CRF.trans_m = bmes_trans_m + # fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True) + # # score equal + # bmes_scores = [round(score, 4) for _, score in allen_res] + # self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()]) + # # seq equal + # bmes_path = [_ for _, score in allen_res] + # self.assertListEqual(bmes_path, fast_res[0]) + # + # data = { + # 'bio_logits': bio_logits.tolist(), + # 'bio_scores': bio_scores, + # 'bio_path': bio_path, + # 'bio_trans_m': bio_trans_m.tolist(), + # 'bio_seq_lens': bio_seq_lens.tolist(), + # 'bmes_logits': bmes_logits.tolist(), + # 'bmes_scores': bmes_scores, + # 'bmes_path': bmes_path, + # 'bmes_trans_m': bmes_trans_m.tolist(), + # 'bmes_seq_lens': bmes_seq_lens.tolist(), + # } + # + # 
with open('weights.json', 'w') as f: + # import json + # json.dump(data, f) + + def test_case2(self): + # Test that the CRF works correctly. + import json + import torch + from fastNLP import seq_len_to_mask + folder = os.path.dirname(os.path.abspath(__file__)) + path = os.path.join(folder, '../../../', 'helpers/data/modules/decoder/crf.json') + + with open(os.path.abspath(path), 'r') as f: + data = json.load(f) + + bio_logits = torch.FloatTensor(data['bio_logits']) + bio_scores = data['bio_scores'] + bio_path = data['bio_path'] + bio_trans_m = torch.FloatTensor(data['bio_trans_m']) + bio_seq_lens = torch.LongTensor(data['bio_seq_lens']) + + bmes_logits = torch.FloatTensor(data['bmes_logits']) + bmes_scores = data['bmes_scores'] + bmes_path = data['bmes_path'] + bmes_trans_m = torch.FloatTensor(data['bmes_trans_m']) + bmes_seq_lens = torch.LongTensor(data['bmes_seq_lens']) + + labels = ['O'] + for label in ['X', 'Y']: + for tag in 'BI': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + num_tags = len(id2label) + + mask = seq_len_to_mask(bio_seq_lens) + + from fastNLP.modules.torch.decoder.crf import ConditionalRandomField, allowed_transitions + fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, + include_start_end=True)) + fast_CRF.trans_m.data = bio_trans_m + fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True) + # score equal + assert (bio_scores == [round(s, 4) for s in fast_res[1].tolist()]) + # seq equal + assert (bio_path == fast_res[0]) + + labels = [] + for label in ['X', 'Y']: + for tag in 'BMES': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + num_tags = len(id2label) + + mask = seq_len_to_mask(bmes_seq_lens) + + from fastNLP.modules.torch.decoder.crf import ConditionalRandomField, allowed_transitions + fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, + encoding_type='BMES', + include_start_end=True)) + fast_CRF.trans_m.data = bmes_trans_m + fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True) + # score equal + assert (bmes_scores == [round(s, 4) for s in fast_res[1].tolist()]) + # seq equal + assert (bmes_path == fast_res[0]) + + def test_case3(self): + # Test that the CRF loss never goes negative + import torch + from fastNLP.modules.torch.decoder.crf import ConditionalRandomField + from fastNLP.core.utils import seq_len_to_mask + from torch import optim + from torch import nn + + num_tags, include_start_end_trans = 4, True + num_samples = 4 + lengths = torch.randint(3, 50, size=(num_samples, )).long() + max_len = lengths.max() + tags = torch.randint(num_tags, size=(num_samples, max_len)) + masks = seq_len_to_mask(lengths) + feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags)) + crf = ConditionalRandomField(num_tags, include_start_end_trans) + optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1) + for _ in range(10): + loss = crf(feats, tags, masks).mean() + optimizer.zero_grad() + loss.backward() + optimizer.step() + if _%1000==0: + print(loss) + assert (loss.item()> 0) + + def test_masking(self): + # Test that the CRF's pad masking works correctly + import torch + from fastNLP.modules.torch.decoder.crf import ConditionalRandomField + max_len = 5 + n_tags = 5 + pad_len = 5 + + torch.manual_seed(4) + logit = torch.rand(1, max_len+pad_len, n_tags) + # logit[0, -1, :] = 0.0 + mask = torch.ones(1, max_len+pad_len) + mask[0,-pad_len] = 0 + model = 
ConditionalRandomField(n_tags) + pred, score = model.viterbi_decode(logit[:,:-pad_len], mask[:,:-pad_len]) + mask_pred, mask_score = model.viterbi_decode(logit, mask) + assert (pred[0].tolist() == mask_pred[0,:-pad_len].tolist()) + diff --git a/tests/modules/torch/decoder/test_seq2seq_decoder.py b/tests/modules/torch/decoder/test_seq2seq_decoder.py new file mode 100755 index 00000000..10af0bfb --- /dev/null +++ b/tests/modules/torch/decoder/test_seq2seq_decoder.py @@ -0,0 +1,49 @@ +import pytest +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + + from fastNLP import Vocabulary + from fastNLP.embeddings.torch import StaticEmbedding + from fastNLP.modules.torch import TransformerSeq2SeqDecoder + from fastNLP.modules.torch import LSTMSeq2SeqDecoder + from fastNLP import seq_len_to_mask + +@pytest.mark.torch +class TestTransformerSeq2SeqDecoder: + @pytest.mark.parametrize("flag", [True, False]) + def test_case(self, flag): + vocab = Vocabulary().add_word_lst("This is a test .".split()) + vocab.add_word_lst("Another test !".split()) + embed = StaticEmbedding(vocab, embedding_dim=10) + + encoder_output = torch.randn(2, 3, 10) + src_seq_len = torch.LongTensor([3, 2]) + encoder_mask = seq_len_to_mask(src_seq_len) + decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed = None, + d_model = 10, num_layers=2, n_head = 5, dim_ff = 20, dropout = 0.1, + bind_decoder_input_output_embed = True) + state = decoder.init_state(encoder_output, encoder_mask) + output = decoder(tokens=torch.randint(0, len(vocab), size=(2, 4)), state=state) + assert (output.size() == (2, 4, len(vocab))) + + +@pytest.mark.torch +class TestLSTMDecoder: + @pytest.mark.parametrize("flag", [True, False]) + @pytest.mark.parametrize("attention", [True, False]) + def test_case(self, flag, attention): + vocab = Vocabulary().add_word_lst("This is a test .".split()) + vocab.add_word_lst("Another test !".split()) + embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10) + + encoder_output = torch.randn(2, 3, 10) + tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) + src_seq_len = torch.LongTensor([3, 2]) + encoder_mask = seq_len_to_mask(src_seq_len) + decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers = 2, hidden_size = 10, + dropout = 0.3, bind_decoder_input_output_embed=flag, attention=attention) + state = decoder.init_state(encoder_output, encoder_mask) + output = decoder(tgt_words_idx, state) + assert tuple(output.size()) == (2, 4, len(vocab)) diff --git a/tests/modules/torch/encoder/__init__.py b/tests/modules/torch/encoder/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/tests/modules/torch/encoder/test_seq2seq_encoder.py b/tests/modules/torch/encoder/test_seq2seq_encoder.py new file mode 100755 index 00000000..4c68de8d --- /dev/null +++ b/tests/modules/torch/encoder/test_seq2seq_encoder.py @@ -0,0 +1,33 @@ +import pytest + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + + from fastNLP.modules.torch.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder + from fastNLP import Vocabulary + from fastNLP.embeddings.torch import StaticEmbedding + + +class TestTransformerSeq2SeqEncoder: + def test_case(self): + vocab = Vocabulary().add_word_lst("This is a test .".split()) + embed = StaticEmbedding(vocab, embedding_dim=5) + encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2) + words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) + seq_len = 
torch.LongTensor([3]) + encoder_output, encoder_mask = encoder(words_idx, seq_len) + assert (encoder_output.size() == (1, 3, 10)) + + +class TestBiLSTMEncoder: + def test_case(self): + vocab = Vocabulary().add_word_lst("This is a test .".split()) + embed = StaticEmbedding(vocab, embedding_dim=5) + encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1) + words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) + seq_len = torch.LongTensor([3]) + + encoder_output, encoder_mask = encoder(words_idx, seq_len) + assert (encoder_mask.size() == (1, 3)) diff --git a/tests/modules/torch/encoder/test_star_transformer.py b/tests/modules/torch/encoder/test_star_transformer.py new file mode 100755 index 00000000..fcb07ce3 --- /dev/null +++ b/tests/modules/torch/encoder/test_star_transformer.py @@ -0,0 +1,18 @@ +import pytest + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + from fastNLP.modules.torch.encoder.star_transformer import StarTransformer + + +@pytest.mark.torch +class TestStarTransformer: + def test_1(self): + model = StarTransformer(num_layers=6, hidden_size=100, num_head=8, head_dim=20, max_len=100) + x = torch.rand(16, 45, 100) + mask = torch.ones(16, 45).byte() + y, yn = model(x, mask) + assert (tuple(y.size()) == (16, 45, 100)) + assert (tuple(yn.size()) == (16, 100)) diff --git a/tests/modules/torch/encoder/test_variational_rnn.py b/tests/modules/torch/encoder/test_variational_rnn.py new file mode 100755 index 00000000..0b70f3ae --- /dev/null +++ b/tests/modules/torch/encoder/test_variational_rnn.py @@ -0,0 +1,27 @@ +import pytest + +import numpy as np +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + from fastNLP.modules.torch.encoder.variational_rnn import VarLSTM + + +@pytest.mark.torch +class TestMaskedRnn: + def test_case_1(self): + masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) + x = torch.tensor([[[1.0], [2.0]]]) + print(x.size()) + y = masked_rnn(x) + + def test_case_2(self): + input_size = 12 + batch = 16 + hidden = 10 + masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True) + + xx = torch.randn((batch, 32, input_size)) + y, _ = masked_rnn(xx) + assert(tuple(y.shape) == (batch, 32, hidden)) diff --git a/tests/modules/torch/generator/__init__.py b/tests/modules/torch/generator/__init__.py new file mode 100755 index 00000000..8b137891 --- /dev/null +++ b/tests/modules/torch/generator/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/modules/torch/generator/test_seq2seq_generator.py b/tests/modules/torch/generator/test_seq2seq_generator.py new file mode 100755 index 00000000..bb4d27a3 --- /dev/null +++ b/tests/modules/torch/generator/test_seq2seq_generator.py @@ -0,0 +1,146 @@ +import pytest + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from fastNLP.modules.torch.generator import SequenceGenerator + from fastNLP.modules.torch import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State + from fastNLP import Vocabulary + from fastNLP.embeddings.torch import StaticEmbedding + from torch import nn + from fastNLP import seq_len_to_mask +else: + from fastNLP.core.utils.dummy_class import DummyClass as State + from fastNLP.core.utils.dummy_class import DummyClass as Seq2SeqDecoder + + +def prepare_env(): + vocab = Vocabulary().add_word_lst("This is a test .".split()) + vocab.add_word_lst("Another test !".split()) + embed = 
StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) + + encoder_output = torch.randn(2, 3, 10) + src_seq_len = torch.LongTensor([3, 2]) + encoder_mask = seq_len_to_mask(src_seq_len) + + return embed, encoder_output, encoder_mask + + +class GreedyDummyDecoder(Seq2SeqDecoder): + def __init__(self, decoder_output): + super().__init__() + self.cur_length = 0 + self.decoder_output = decoder_output + + def decode(self, tokens, state): + self.cur_length += 1 + scores = self.decoder_output[:, self.cur_length] + return scores + + +class DummyState(State): + def __init__(self, decoder): + super().__init__() + self.decoder = decoder + + def reorder_state(self, indices: torch.LongTensor): + self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0) + + +@pytest.mark.torch +class TestSequenceGenerator: + def test_run(self): + # Test that it runs: (1) initialize the decoder, (2) run one decoding pass + embed, encoder_output, encoder_mask = prepare_env() + + for do_sample in [True, False]: + for num_beams in [1, 3, 5]: + decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10, + dropout=0.3, bind_decoder_input_output_embed=True, attention=True) + state = decoder.init_state(encoder_output, encoder_mask) + generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams, + do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, + repetition_penalty=1, length_penalty=1.0, pad_token_id=0) + generator.generate(state=state, tokens=None) + + decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), + d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, dropout=0.1, + bind_decoder_input_output_embed=True) + state = decoder.init_state(encoder_output, encoder_mask) + generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, + do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, + repetition_penalty=1, length_penalty=1.0, pad_token_id=0) + generator.generate(state=state, tokens=None) + + # Also exercise other argument values + decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), + d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, + dropout=0.1, + bind_decoder_input_output_embed=True) + state = decoder.init_state(encoder_output, encoder_mask) + generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, + do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5, bos_token_id=1, + eos_token_id=3, repetition_penalty=2, length_penalty=1.5, pad_token_id=0) + generator.generate(state=state, tokens=None) + + def test_greedy_decode(self): + # Test that generation produces the expected output + # greedy + for beam_search in [1, 3]: + decoder_output = torch.randn(2, 10, 5) + path = decoder_output.argmax(dim=-1) # 2 x 10 + decoder = GreedyDummyDecoder(decoder_output) + generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, + do_sample=False, temperature=1, top_k=50, top_p=1, bos_token_id=1, + eos_token_id=None, repetition_penalty=1, length_penalty=1, pad_token_id=0) + decode_path = generator.generate(DummyState(decoder), tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) + + assert (decode_path.eq(path).sum() == path.numel()) + + # greedy check eos_token_id + for beam_search in [1, 3]: + decoder_output = torch.randn(2, 10, 5) + decoder_output[:, :7, 4].fill_(-100) + decoder_output[0, 7, 4] = 1000 # ends at the 8th step + decoder_output[1, 5, 4] = 1000 + path = 
decoder_output.argmax(dim=-1) # 2 x 10 + decoder = GreedyDummyDecoder(decoder_output) + generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, + do_sample=False, temperature=1, top_k=50, top_p=0.5, bos_token_id=1, + eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0) + decode_path = generator.generate(DummyState(decoder), + tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) + assert (decode_path.size(1) == 8) # length is 8 + assert (decode_path[0].eq(path[0, :8]).sum() == 8) + assert (decode_path[1, :6].eq(path[1, :6]).sum() == 6) + + def test_sample_decoder(self): + # sample check eos_token_id + for beam_search in [1, 3]: + decode_paths = [] + # Since sampling is random, run the test several times; as long as at least one run matches, it should be fine + num_tests = 10 + for i in range(num_tests): + decoder_output = torch.randn(2, 10, 5) * 10 + decoder_output[:, :7, 4].fill_(-100) + decoder_output[0, 7, 4] = 10000 # ends at the 8th step + decoder_output[1, 5, 4] = 10000 + path = decoder_output.argmax(dim=-1) # 2 x 10 + decoder = GreedyDummyDecoder(decoder_output) + generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, + do_sample=True, temperature=1, top_k=50, top_p=0.5, bos_token_id=1, + eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0) + decode_path = generator.generate(DummyState(decoder), + tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) + decode_paths.append([decode_path, path]) + sizes = [] + eqs = [] + eq2s = [] + for i in range(num_tests): + decode_path, path = decode_paths[i] + sizes.append(decode_path.size(1)==8) + eqs.append(decode_path[0].eq(path[0, :8]).sum()==8) + eq2s.append(decode_path[1, :6].eq(path[1, :6]).sum()==6) + assert any(sizes) + assert any(eqs) + assert any(eq2s) \ No newline at end of file
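A minimal usage sketch of the pattern the CRF tests above exercise — deriving transition constraints from a tag vocabulary, computing the loss, and Viterbi-decoding with a padding mask — assuming fastNLP's torch backend is available; the BIO tag set, batch size, and sequence lengths below are illustrative, not taken from the tests:

    # Sketch only: tag set and tensor shapes are illustrative assumptions.
    import torch
    from fastNLP import Vocabulary, seq_len_to_mask
    from fastNLP.modules.torch.decoder.crf import ConditionalRandomField, allowed_transitions

    # Derive the allowed transitions (including start/end) from a tag vocabulary.
    vocab = Vocabulary(unknown=None, padding=None)
    for tag in ['B', 'I', 'O']:
        vocab.add_word(tag)
    crf = ConditionalRandomField(num_tags=len(vocab),
                                 allowed_transitions=allowed_transitions(vocab, include_start_end=True))

    # Random emission scores stand in for a real encoder: (batch, seq_len, num_tags).
    feats = torch.randn(2, 6, len(vocab))
    tags = torch.randint(len(vocab), size=(2, 6))
    mask = seq_len_to_mask(torch.LongTensor([6, 4]))

    loss = crf(feats, tags, mask).mean()                         # training objective, non-negative
    paths, scores = crf.viterbi_decode(feats, mask, unpad=True)  # constrained decoding

The same masked `viterbi_decode` call is what `test_masking` compares against its padded variant, and the vocabulary-derived constraints are what `test_case12` checks explicitly.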