
test_static_embedding.py

import pytest
import os

from fastNLP.embeddings.torch import StaticEmbedding
from fastNLP import Vocabulary
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
    import torch

import numpy as np

tests_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
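# Note: tests_folder resolves two directory levels above this file; the tests below
# assume the small GloVe fixture lives at
# <tests_folder>/helpers/data/embedding/small_static_embedding/glove.6B.50d_test.txt.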

@pytest.mark.torch
class TestLoad:
    def test_norm1(self):
        # only vectors that are actually found in the pretrained file get normalized
        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
        embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder + '/helpers/data/embedding/small_static_embedding/'
                                                                        'glove.6B.50d_test.txt',
                                only_norm_found_vector=True)
        # with the default padding/unk entries at indices 0/1, index 2 is 'the' (found in
        # the file, so unit norm) and index 4 is 'notinfile' (not found, left unnormalized)
        assert round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4) == 1
        assert torch.norm(embed(torch.LongTensor([[4]]))).item() != 1

    def test_norm2(self):
        # with normalize=True every vector is normalized, whether or not it is in the file
        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
        embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder + '/helpers/data/embedding/small_static_embedding/'
                                                                        'glove.6B.50d_test.txt',
                                normalize=True)
        assert round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4) == 1
        assert round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4) == 1

    def test_dropword(self):
        # check that the forward pass runs with word_dropout enabled
        vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)])
        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10, dropout=0.1, word_dropout=0.4)
        for i in range(10):
            length = torch.randint(1, 50, (1,)).item()
            batch = torch.randint(1, 4, (1,)).item()
            words = torch.randint(1, 200, (batch, length)).long()
            embed(words)

    def test_only_use_pretrain_word(self):
        def check_word_unk(words, vocab, embed):
            # each word should share the unknown token's vector (index 1)
            for word in words:
                assert embed(torch.LongTensor([vocab.to_index(word)])).tolist()[0] == embed(torch.LongTensor([1])).tolist()[0]

        def check_vector_equal(words, vocab, embed, embed_dict, lower=False):
            # each word's vector should match the one read directly from the file
            for word in words:
                index = vocab.to_index(word)
                v1 = embed(torch.LongTensor([index])).tolist()[0]
                if lower:
                    word = word.lower()
                v2 = embed_dict[word]
                for v1i, v2i in zip(v1, v2):
                    assert np.allclose(v1i, v2i)

        embed_dict = read_static_embed(tests_folder + '/helpers/data/embedding/small_static_embedding/'
                                                      'glove.6B.50d_test.txt')

        # only words present in the pretrained file should get their own vectors
        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
        vocab.add_word('of', no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder + '/helpers/data/embedding/small_static_embedding/'
                                                                        'glove.6B.50d_test.txt',
                                only_use_pretrain_word=True)
        # notinfile should be mapped to unk
        check_vector_equal(['the', 'a', 'of'], vocab, embed, embed_dict)
        check_word_unk(['notinfile'], vocab, embed)

        # behavior with mixed-case words
        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile'])
        vocab.add_word('Of', no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder + '/helpers/data/embedding/small_static_embedding/'
                                                                        'glove.6B.50d_test.txt',
                                only_use_pretrain_word=True)
        check_word_unk(['The', 'Of', 'notinfile'], vocab, embed)  # these words should not be found
        check_vector_equal(['a'], vocab, embed, embed_dict)

        embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder + '/helpers/data/embedding/small_static_embedding/'
                                                                        'glove.6B.50d_test.txt',
                                only_use_pretrain_word=True, lower=True)
        check_vector_equal(['The', 'Of', 'a'], vocab, embed, embed_dict, lower=True)
        check_word_unk(['notinfile'], vocab, embed)

        # min_freq
        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
        vocab.add_word('Of', no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder + '/helpers/data/embedding/small_static_embedding/'
                                                                        'glove.6B.50d_test.txt',
                                only_use_pretrain_word=True, lower=True, min_freq=2, only_train_min_freq=True)
        check_vector_equal(['Of', 'a'], vocab, embed, embed_dict, lower=True)
        check_word_unk(['notinfile1', 'The', 'notinfile2'], vocab, embed)

    def test_sequential_index(self):
        # when there are no no_create_entry words, words_to_words should be the identity mapping
        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
        embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder + '/helpers/data/embedding/small_static_embedding/'
                                                                        'glove.6B.50d_test.txt')
        for index, i in enumerate(embed.words_to_words):
            assert index == i

        embed_dict = read_static_embed(tests_folder + '/helpers/data/embedding/small_static_embedding/'
                                                      'glove.6B.50d_test.txt')
        for word, index in vocab:
            if word in embed_dict:
                index = vocab.to_index(word)
                v1 = embed(torch.LongTensor([index])).tolist()[0]
                v2 = embed_dict[word]
                for v1i, v2i in zip(v1, v2):
                    assert np.allclose(v1i, v2i)

    def test_save_load_static_embed(self):
        static_test_folder = 'static_save_test'
        try:
            # with no_create_entry words
            os.makedirs(static_test_folder, exist_ok=True)
            vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A'])
            vocab.add_word_lst(['notinfile2', 'notinfile2'], no_create_entry=True)
            embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder + '/helpers/data/embedding/small_static_embedding/'
                                                                            'glove.6B.50d_test.txt')
            embed.save(static_test_folder)
            load_embed = StaticEmbedding.load(static_test_folder)
            words = torch.randint(len(vocab), size=(2, 20))
            assert (embed(words) - load_embed(words)).sum() == 0

            # without no_create_entry words
            vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A'])
            embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder + '/helpers/data/embedding/small_static_embedding/'
                                                                            'glove.6B.50d_test.txt')
            embed.save(static_test_folder)
            load_embed = StaticEmbedding.load(static_test_folder)
            words = torch.randint(len(vocab), size=(2, 20))
            assert (embed(words) - load_embed(words)).sum() == 0

            # lower, min_freq
            vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B'])
            embed = StaticEmbedding(vocab, model_dir_or_name=tests_folder + '/helpers/data/embedding/small_static_embedding/'
                                                                            'glove.6B.50d_test.txt', min_freq=2, lower=True)
            embed.save(static_test_folder)
            load_embed = StaticEmbedding.load(static_test_folder)
            words = torch.randint(len(vocab), size=(2, 20))
            assert (embed(words) - load_embed(words)).sum() == 0

            # randomly initialized embedding
            vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B'])
            vocab = vocab.add_word_lst(['b'], no_create_entry=True)
            embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=4, min_freq=2, lower=True,
                                    normalize=True)
            embed.weight.data += 0.2  # make the weights no longer normalized
            embed.save(static_test_folder)
            load_embed = StaticEmbedding.load(static_test_folder)
            words = torch.randint(len(vocab), size=(2, 20))
            assert (embed(words) - load_embed(words)).sum() == 0
        finally:
            if os.path.isdir(static_test_folder):
                import shutil
                shutil.rmtree(static_test_folder)

def read_static_embed(fp):
    """
    :param str fp: path to the embedding file
    :return: dict, keys are words, values are their vectors
    """
    embed = {}
    with open(fp, 'r') as f:
        for line in f:
            line = line.strip()
            if line:
                parts = line.split()
                vector = list(map(float, parts[1:]))
                word = parts[0]
                embed[word] = vector
    return embed
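# read_static_embed assumes a plain GloVe-style text format, one entry per line:
# the word first, then its vector components separated by whitespace, e.g.
# (illustrative values only, not the fixture's actual contents):
#   the 0.418 0.24968 -0.41242 ...
#   a   0.217 -0.00061 0.41705 ...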

@pytest.mark.torch
class TestRandomSameEntry:
    def test_same_vector(self):
        # with lower=True, different casings of the same word should share one vector
        vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]])
        words = embed(words)
        embed_0 = words[0, 0]
        for i in range(1, 3):
            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
        embed_0 = words[0, 3]
        for i in range(3, 5):
            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))

    def test_dropout_close(self):
        # in eval mode, dropout and word_dropout must not change the shared vectors
        vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True,
                                dropout=0.5, word_dropout=0.9)
        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]])
        embed.eval()
        words = embed(words)
        embed_0 = words[0, 0]
        for i in range(1, 3):
            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
        embed_0 = words[0, 3]
        for i in range(3, 5):
            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
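
# These tests are gated by the pytest "torch" marker; a typical invocation might be
#   pytest -m torch tests/embeddings/torch/test_static_embedding.py
# (the file path is an assumption; adjust it to where this module lives in the repo).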