test_seq2seq_encoder.py

import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP import Vocabulary

if _NEED_IMPORT_TORCH:
    import torch

    from fastNLP.modules.torch.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
    from fastNLP.embeddings.torch import StaticEmbedding


@pytest.mark.torch
class TestTransformerSeq2SeqEncoder:
    def test_case(self):
        # Build a tiny vocabulary and embedding, then check that the Transformer
        # encoder maps a (batch=1, seq_len=3) index tensor to (1, 3, d_model).
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = StaticEmbedding(vocab, embedding_dim=5)
        encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2)

        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
        seq_len = torch.LongTensor([3])
        encoder_output, encoder_mask = encoder(words_idx, seq_len)

        assert encoder_output.size() == (1, 3, 10)


@pytest.mark.torch
class TestBiLSTMEncoder:
    def test_case(self):
        # Same setup with the LSTM encoder; here the check is on the returned
        # mask, which should have one entry per token: (batch, seq_len) = (1, 3).
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = StaticEmbedding(vocab, embedding_dim=5)
        encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1)

        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
        seq_len = torch.LongTensor([3])
        encoder_output, encoder_mask = encoder(words_idx, seq_len)

        assert encoder_mask.size() == (1, 3)
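
The tests above only check output shapes. As a follow-up, here is a minimal, self-contained sketch (not part of the test suite) of one way the returned encoder_output and encoder_mask might be consumed downstream: masked mean pooling over the non-padding positions. It assumes fastNLP's usual convention that the mask is True at valid positions; the names mask and pooled are introduced here purely for illustration.

# Sketch only: masked mean pooling over encoder states, assuming the mask
# flags valid (non-padding) positions with True.
import torch

from fastNLP import Vocabulary
from fastNLP.embeddings.torch import StaticEmbedding
from fastNLP.modules.torch.encoder.seq2seq_encoder import LSTMSeq2SeqEncoder

vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = StaticEmbedding(vocab, embedding_dim=5)
encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1)

words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)   # (batch=1, seq_len=3)
seq_len = torch.LongTensor([3])
encoder_output, encoder_mask = encoder(words_idx, seq_len)

# Zero out padded positions, then average over the valid ones.
mask = encoder_mask.unsqueeze(-1).float()               # (1, 3, 1)
pooled = (encoder_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
print(pooled.shape)                                      # (batch, encoder output dim)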