
test_seq2seq_generator.py

import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    from fastNLP.models.torch import LSTMSeq2SeqModel, TransformerSeq2SeqModel
    import torch
    from fastNLP.embeddings.torch import StaticEmbedding

from fastNLP import Vocabulary, DataSet
from fastNLP import Trainer, Accuracy
from fastNLP import Callback, TorchDataLoader
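

# prepare_env() builds a tiny "copy" task: the same index lists serve as both the
# source and the target tokens, so the models only have to learn to reproduce
# their input for accuracy to reach 1.0.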
def prepare_env():
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

    src_words_idx = [[3, 1, 2], [1, 2]]
    # tgt_words_idx = [[1, 2, 3, 4], [2, 3]]
    src_seq_len = [3, 2]
    # tgt_seq_len = [4, 2]

    ds = DataSet({'src_tokens': src_words_idx, 'src_seq_len': src_seq_len,
                  'tgt_tokens': src_words_idx, 'tgt_seq_len': src_seq_len})
    dl = TorchDataLoader(ds, batch_size=32)
    return embed, dl
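

# Early-stop hook: raising KeyboardInterrupt once evaluation accuracy hits 1.0
# makes the 100-epoch budget used below an upper bound rather than the expected
# runtime.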
class ExitCallback(Callback):
    def __init__(self):
        super().__init__()

    def on_valid_end(self, trainer, results):
        if results['acc#acc'] == 1:
            raise KeyboardInterrupt()


@pytest.mark.torch
class TestSeq2SeqGeneratorModel:
    def test_run(self):
        # Check that a SequenceGeneratorModel can be trained, with extra inputs
        # passed through to prediction.
        embed, dl = prepare_env()
        model1 = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
                                                     pos_embed='sin', max_position=20,
                                                     num_layers=2, d_model=30, n_head=6,
                                                     dim_ff=20, dropout=0.1,
                                                     bind_encoder_decoder_embed=True,
                                                     bind_decoder_input_output_embed=True)

        optimizer = torch.optim.Adam(model1.parameters(), lr=1e-3)
        trainer = Trainer(model1, driver='torch', optimizers=optimizer, train_dataloader=dl,
                          n_epochs=100, evaluate_dataloaders=dl, metrics={'acc': Accuracy()},
                          evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'],
                                                            'seq_len': x['tgt_seq_len'],
                                                            **x},
                          callbacks=ExitCallback())
        trainer.run()
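
        # Same training run again, now with the LSTM-based model, so that both
        # seq2seq generator variants are exercised.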
        embed, dl = prepare_env()
        model2 = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
                                              num_layers=1, hidden_size=20, dropout=0.1,
                                              bind_encoder_decoder_embed=True,
                                              bind_decoder_input_output_embed=True,
                                              attention=True)

        optimizer = torch.optim.Adam(model2.parameters(), lr=0.01)
        trainer = Trainer(model2, driver='torch', optimizers=optimizer, train_dataloader=dl,
                          n_epochs=100, evaluate_dataloaders=dl, metrics={'acc': Accuracy()},
                          evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'],
                                                            'seq_len': x['tgt_seq_len'],
                                                            **x},
                          callbacks=ExitCallback())
        trainer.run()
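

# Note: to run only this file (the path is assumed from the fastNLP repo layout;
# adjust it to your checkout):
#     pytest -m torch tests/models/torch/test_seq2seq_generator.py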