
test_seq2seq_model.py

import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    from fastNLP.models.torch.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel
    from fastNLP import Vocabulary
    from fastNLP.embeddings.torch import StaticEmbedding
    import torch
    from torch import optim
    import torch.nn.functional as F
    from fastNLP import seq_len_to_mask


def prepare_env():
    # Build a toy vocabulary, a 5-dim static embedding and a tiny src/tgt batch.
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

    src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
    tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
    src_seq_len = torch.LongTensor([3, 2])
    tgt_seq_len = torch.LongTensor([4, 2])

    return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len


def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len):
    # Overfit the model on the toy batch and count correctly predicted target tokens
    # (padding positions are excluded from the loss and counted as correct).
    optimizer = optim.Adam(model.parameters(), lr=5e-3)
    mask = seq_len_to_mask(tgt_seq_len).eq(0)
    target = tgt_words_idx.masked_fill(mask, -100)

    for i in range(50):
        optimizer.zero_grad()
        pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred']  # bsz x max_len x vocab_size
        loss = F.cross_entropy(pred.transpose(1, 2), target)
        loss.backward()
        optimizer.step()

    right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum()
    return right_count


@pytest.mark.torch
class TestTransformerSeq2SeqModel:
    def test_run(self):
        # Check that the forward pass runs and produces the expected output shape.
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

        for pos_embed in ['learned', 'sin']:
            model = TransformerSeq2SeqModel.build_model(
                src_embed=embed, tgt_embed=None,
                pos_embed=pos_embed, max_position=20, num_layers=2, d_model=30,
                n_head=6, dim_ff=20, dropout=0.1,
                bind_encoder_decoder_embed=True,
                bind_decoder_input_output_embed=True)

            output = model(src_words_idx, tgt_words_idx, src_seq_len)
            assert output['pred'].size() == (2, 4, len(embed))

        for bind_encoder_decoder_embed in [True, False]:
            tgt_embed = None
            for bind_decoder_input_output_embed in [True, False]:
                if bind_encoder_decoder_embed is False:
                    tgt_embed = embed
                model = TransformerSeq2SeqModel.build_model(
                    src_embed=embed, tgt_embed=tgt_embed,
                    pos_embed='sin', max_position=20, num_layers=2,
                    d_model=30, n_head=6, dim_ff=20, dropout=0.1,
                    bind_encoder_decoder_embed=bind_encoder_decoder_embed,
                    bind_decoder_input_output_embed=bind_decoder_input_output_embed)

                output = model(src_words_idx, tgt_words_idx, src_seq_len)
                assert output['pred'].size() == (2, 4, len(embed))

    def test_train(self):
        # Check that the model can be trained to overfit the toy batch.
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

        model = TransformerSeq2SeqModel.build_model(
            src_embed=embed, tgt_embed=None,
            pos_embed='sin', max_position=20, num_layers=2, d_model=30,
            n_head=6, dim_ff=20, dropout=0.1,
            bind_encoder_decoder_embed=True,
            bind_decoder_input_output_embed=True)

        right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len)
        assert right_count == tgt_words_idx.nelement()


@pytest.mark.torch
class TestLSTMSeq2SeqModel:
    def test_run(self):
        # Check that the forward pass runs and produces the expected output shape.
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

        for bind_encoder_decoder_embed in [True, False]:
            tgt_embed = None
            for bind_decoder_input_output_embed in [True, False]:
                if bind_encoder_decoder_embed is False:
                    tgt_embed = embed
                model = LSTMSeq2SeqModel.build_model(
                    src_embed=embed, tgt_embed=tgt_embed,
                    num_layers=2, hidden_size=20, dropout=0.1,
                    bind_encoder_decoder_embed=bind_encoder_decoder_embed,
                    bind_decoder_input_output_embed=bind_decoder_input_output_embed)

                output = model(src_words_idx, tgt_words_idx, src_seq_len)
                assert output['pred'].size() == (2, 4, len(embed))

    def test_train(self):
        # Check that the model can be trained to overfit the toy batch.
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

        model = LSTMSeq2SeqModel.build_model(
            src_embed=embed, tgt_embed=None,
            num_layers=1, hidden_size=20, dropout=0.1,
            bind_encoder_decoder_embed=True,
            bind_decoder_input_output_embed=True)

        right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len)
        assert right_count == tgt_words_idx.nelement()
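Both test classes above are gated behind the @pytest.mark.torch marker, so a plain pytest run may select or deselect them depending on how the repository's pytest configuration registers that marker. As a minimal sketch for running just this file with the torch-marked tests selected (assuming pytest and a working PyTorch install; the helper file name is hypothetical and not part of the repository):

# run_seq2seq_tests.py -- hypothetical helper, not part of the repository
import sys
import pytest

if __name__ == "__main__":
    # Select only tests carrying the @pytest.mark.torch marker in this file
    # and propagate pytest's exit code to the shell.
    sys.exit(pytest.main(["-m", "torch", "-q", "test_seq2seq_model.py"]))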