You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_seq2seq_decoder.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import pytest
  2. from fastNLP.envs.imports import _NEED_IMPORT_TORCH
  3. if _NEED_IMPORT_TORCH:
  4. import torch
  5. from fastNLP import Vocabulary
  6. from fastNLP.embeddings.torch import StaticEmbedding
  7. from fastNLP.modules.torch import TransformerSeq2SeqDecoder
  8. from fastNLP.modules.torch import LSTMSeq2SeqDecoder
  9. from fastNLP import seq_len_to_mask
  10. @pytest.mark.torch
  11. class TestTransformerSeq2SeqDecoder:
  12. @pytest.mark.parametrize("flag", [True, False])
  13. def test_case(self, flag):
  14. vocab = Vocabulary().add_word_lst("This is a test .".split())
  15. vocab.add_word_lst("Another test !".split())
  16. embed = StaticEmbedding(vocab, embedding_dim=10)
  17. encoder_output = torch.randn(2, 3, 10)
  18. src_seq_len = torch.LongTensor([3, 2])
  19. encoder_mask = seq_len_to_mask(src_seq_len)
  20. decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed = None,
  21. d_model = 10, num_layers=2, n_head = 5, dim_ff = 20, dropout = 0.1,
  22. bind_decoder_input_output_embed = True)
  23. state = decoder.init_state(encoder_output, encoder_mask)
  24. output = decoder(tokens=torch.randint(0, len(vocab), size=(2, 4)), state=state)
  25. assert (output.size() == (2, 4, len(vocab)))
  26. @pytest.mark.torch
  27. class TestLSTMDecoder:
  28. @pytest.mark.parametrize("flag", [True, False])
  29. @pytest.mark.parametrize("attention", [True, False])
  30. def test_case(self, flag, attention):
  31. vocab = Vocabulary().add_word_lst("This is a test .".split())
  32. vocab.add_word_lst("Another test !".split())
  33. embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10)
  34. encoder_output = torch.randn(2, 3, 10)
  35. tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
  36. src_seq_len = torch.LongTensor([3, 2])
  37. encoder_mask = seq_len_to_mask(src_seq_len)
  38. decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers = 2, hidden_size = 10,
  39. dropout = 0.3, bind_decoder_input_output_embed=flag, attention=attention)
  40. state = decoder.init_state(encoder_output, encoder_mask)
  41. output = decoder(tgt_words_idx, state)
  42. assert tuple(output.size()) == (2, 4, len(vocab))