
test_embed_loader.py

import numpy as np

from fastNLP import Vocabulary
from fastNLP.io import EmbedLoader


class TestEmbedLoader:
    def test_load_with_vocab(self):
        vocab = Vocabulary()
        glove = "tests/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt"
        word2vec = "tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt"
        vocab.add_word('the')
        vocab.add_word('none')
        # the vocabulary (added words plus special tokens) maps to a (4, 50) matrix
        g_m = EmbedLoader.load_with_vocab(glove, vocab)
        assert g_m.shape == (4, 50)
        w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True)
        assert w_m.shape == (4, 50)
        # after normalization each row has unit L2 norm, so the norms sum to the row count
        assert np.allclose(np.linalg.norm(w_m, axis=1).sum(), 4)

    def test_load_without_vocab(self):
        words = ['the', 'of', 'in', 'a', 'to', 'and']
        glove = "tests/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt"
        word2vec = "tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt"
        # build the vocabulary from the embedding file itself
        g_m, vocab = EmbedLoader.load_without_vocab(glove)
        assert g_m.shape == (8, 50)
        for word in words:
            assert word in vocab
        w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True)
        assert w_m.shape == (8, 50)
        assert np.allclose(np.linalg.norm(w_m, axis=1).sum(), 8)
        for word in words:
            assert word in vocab
        # no unk: dropping the unknown token removes one row from the matrix
        w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True, unknown=None)
        assert w_m.shape == (7, 50)
        assert np.allclose(np.linalg.norm(w_m, axis=1).sum(), 7)
        for word in words:
            assert word in vocab