import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    from fastNLP.models.torch.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel
    from fastNLP import Vocabulary
    from fastNLP.embeddings.torch import StaticEmbedding
    import torch
    from torch import optim
    import torch.nn.functional as F
    from fastNLP import seq_len_to_mask


def prepare_env():
    # build a tiny vocabulary, embedding, and toy src/tgt batches of size 2
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

    src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
    tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
    src_seq_len = torch.LongTensor([3, 2])
    tgt_seq_len = torch.LongTensor([4, 2])

    return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len


def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len):
    optimizer = optim.Adam(model.parameters(), lr=5e-3)
    # fill padded target positions with -100, the default ignore_index of
    # F.cross_entropy, so they do not contribute to the loss (see the
    # illustrative sketch at the bottom of this file)
    mask = seq_len_to_mask(tgt_seq_len).eq(0)
    target = tgt_words_idx.masked_fill(mask, -100)

    for _ in range(50):
        optimizer.zero_grad()
        pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred']  # bsz x max_len x vocab_size
        loss = F.cross_entropy(pred.transpose(1, 2), target)
        loss.backward()
        optimizer.step()

    # padded positions are force-counted as correct so the caller can compare
    # right_count against tgt_words_idx.nelement()
    right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum()
    return right_count


@pytest.mark.torch
class TestTransformerSeq2SeqModel:
    def test_run(self):
        # check that a forward pass runs for both position-embedding types
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

        for pos_embed in ['learned', 'sin']:
            model = TransformerSeq2SeqModel.build_model(
                src_embed=embed, tgt_embed=None,
                pos_embed=pos_embed, max_position=20, num_layers=2, d_model=30,
                n_head=6, dim_ff=20, dropout=0.1,
                bind_encoder_decoder_embed=True,
                bind_decoder_input_output_embed=True)

            output = model(src_words_idx, tgt_words_idx, src_seq_len)
            assert output['pred'].size() == (2, 4, len(embed))

        # check all combinations of the embedding-binding flags
        for bind_encoder_decoder_embed in [True, False]:
            tgt_embed = None
            for bind_decoder_input_output_embed in [True, False]:
                if bind_encoder_decoder_embed is False:
                    tgt_embed = embed
                model = TransformerSeq2SeqModel.build_model(
                    src_embed=embed, tgt_embed=tgt_embed,
                    pos_embed='sin', max_position=20, num_layers=2, d_model=30,
                    n_head=6, dim_ff=20, dropout=0.1,
                    bind_encoder_decoder_embed=bind_encoder_decoder_embed,
                    bind_decoder_input_output_embed=bind_decoder_input_output_embed)

                output = model(src_words_idx, tgt_words_idx, src_seq_len)
                assert output['pred'].size() == (2, 4, len(embed))

    def test_train(self):
        # check that the model can overfit the toy batch
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()
        model = TransformerSeq2SeqModel.build_model(
            src_embed=embed, tgt_embed=None,
            pos_embed='sin', max_position=20, num_layers=2, d_model=30,
            n_head=6, dim_ff=20, dropout=0.1,
            bind_encoder_decoder_embed=True,
            bind_decoder_input_output_embed=True)

        right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len)
        assert right_count == tgt_words_idx.nelement()


@pytest.mark.torch
class TestLSTMSeq2SeqModel:
    def test_run(self):
        # check that a forward pass runs for all embedding-binding combinations
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()
        for bind_encoder_decoder_embed in [True, False]:
            tgt_embed = None
            for bind_decoder_input_output_embed in [True, False]:
                if bind_encoder_decoder_embed is False:
                    tgt_embed = embed
                model = LSTMSeq2SeqModel.build_model(
                    src_embed=embed, tgt_embed=tgt_embed,
                    num_layers=2, hidden_size=20, dropout=0.1,
                    bind_encoder_decoder_embed=bind_encoder_decoder_embed,
                    bind_decoder_input_output_embed=bind_decoder_input_output_embed)

                output = model(src_words_idx, tgt_words_idx, src_seq_len)
                assert output['pred'].size() == (2, 4, len(embed))

    def test_train(self):
        # check that the model can overfit the toy batch
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()
        model = LSTMSeq2SeqModel.build_model(
            src_embed=embed, tgt_embed=None,
            num_layers=1, hidden_size=20, dropout=0.1,
            bind_encoder_decoder_embed=True,
            bind_decoder_input_output_embed=True)

        right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len)
        assert right_count == tgt_words_idx.nelement()
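

# A minimal sketch of the masking logic used in ``train_model`` above,
# assuming ``seq_len_to_mask`` returns True at valid positions (the behaviour
# the tests rely on). ``F.cross_entropy`` skips targets equal to its default
# ``ignore_index`` of -100, so padded positions never affect the loss.
# ``_mask_sketch`` is a hypothetical helper for illustration only; the
# leading underscore keeps pytest from collecting it.
def _mask_sketch():
    tgt_seq_len = torch.LongTensor([4, 2])
    mask = seq_len_to_mask(tgt_seq_len).eq(0)
    # mask == [[False, False, False, False],
    #          [False, False,  True,  True]]   (True marks padding)
    tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
    target = tgt_words_idx.masked_fill(mask, -100)
    # target == [[1, 2,    3,    4],
    #            [2, 3, -100, -100]]
    return target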