import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    from fastNLP.models.torch import LSTMSeq2SeqModel, TransformerSeq2SeqModel
    import torch
    from fastNLP.embeddings.torch import StaticEmbedding

from fastNLP import Vocabulary, DataSet
from fastNLP import Trainer, Accuracy
from fastNLP import Callback, TorchDataLoader


def prepare_env():
    """Build a tiny vocabulary, embedding and dataloader for the tests."""
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

    src_words_idx = [[3, 1, 2], [1, 2]]
    # tgt_words_idx = [[1, 2, 3, 4], [2, 3]]
    src_seq_len = [3, 2]
    # tgt_seq_len = [4, 2]

    # The targets reuse the source tokens, so the model only has to learn an
    # identity mapping and evaluation accuracy can quickly reach 1.
    ds = DataSet({'src_tokens': src_words_idx, 'src_seq_len': src_seq_len,
                  'tgt_tokens': src_words_idx, 'tgt_seq_len': src_seq_len})
    dl = TorchDataLoader(ds, batch_size=32)
    return embed, dl


class ExitCallback(Callback):
    """Stops training early once evaluation accuracy reaches 1."""

    def on_evaluate_end(self, trainer, results):
        # Called by the fastNLP 1.x Trainer after each evaluation round;
        # raising KeyboardInterrupt ends training early.
        if results['acc#acc'] == 1:
            raise KeyboardInterrupt()


@pytest.mark.torch
class TestSeq2SeqGeneratorModel:
    def test_run(self):
        # Check that a SequenceGeneratorModel can be trained, with the
        # evaluation inputs passed straight through for prediction.
        embed, dl = prepare_env()
        model1 = TransformerSeq2SeqModel.build_model(
            src_embed=embed, tgt_embed=None,
            pos_embed='sin', max_position=20, num_layers=2,
            d_model=30, n_head=6, dim_ff=20, dropout=0.1,
            bind_encoder_decoder_embed=True,
            bind_decoder_input_output_embed=True)

        optimizer = torch.optim.Adam(model1.parameters(), lr=1e-3)
        trainer = Trainer(model1, driver='torch', optimizers=optimizer,
                          train_dataloader=dl, n_epochs=100,
                          evaluate_dataloaders=dl, metrics={'acc': Accuracy()},
                          evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'],
                                                            'seq_len': x['tgt_seq_len'], **x},
                          callbacks=ExitCallback())
        trainer.run()

        embed, dl = prepare_env()
        model2 = LSTMSeq2SeqModel.build_model(
            src_embed=embed, tgt_embed=None,
            num_layers=1, hidden_size=20, dropout=0.1,
            bind_encoder_decoder_embed=True,
            bind_decoder_input_output_embed=True, attention=True)

        optimizer = torch.optim.Adam(model2.parameters(), lr=0.01)
        trainer = Trainer(model2, driver='torch', optimizers=optimizer,
                          train_dataloader=dl, n_epochs=100,
                          evaluate_dataloaders=dl, metrics={'acc': Accuracy()},
                          evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'],
                                                            'seq_len': x['tgt_seq_len'], **x},
                          callbacks=ExitCallback())
        trainer.run()