
dataloader.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import pickle
from collections import Counter

import numpy as np
import torch
from torch.utils import data

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

class PTBTree:
    WORD_TO_WORD_MAPPING = {
        "{": "-LCB-",
        "}": "-RCB-"
    }

    def __init__(self):
        self.subtrees = []
        self.word = None
        self.label = ""
        self.parent = None
        self.span = (-1, -1)
        self.word_vector = None  # HOS, store dx1 RNN word vector
        self.prediction = None  # HOS, store Kx1 prediction vector

    def is_leaf(self):
        return len(self.subtrees) == 0

    def set_by_text(self, text, pos=0, left=0):
        depth = 0
        right = left
        for i in range(pos + 1, len(text)):
            char = text[i]
            # update the depth
            if char == "(":
                depth += 1
                if depth == 1:
                    subtree = PTBTree()
                    subtree.parent = self
                    subtree.set_by_text(text, i, right)
                    right = subtree.span[1]
                    self.span = (left, right)
                    self.subtrees.append(subtree)
            elif char == ")":
                depth -= 1
                if len(self.subtrees) == 0:
                    pos = i
                    for j in range(i, 0, -1):
                        if text[j] == " ":
                            pos = j
                            break
                    self.word = text[pos + 1:i]
                    self.span = (left, left + 1)

            # we've reached the end of the category that is the root of this subtree
            if depth == 0 and char == " " and self.label == "":
                self.label = text[pos + 1:i]

            # we've reached the end of the scope for this bracket
            if depth < 0:
                break

        # Fix some issues with variation in output, and one error in the treebank
        # for a word with a punctuation POS
        self.standardise_node()

    def standardise_node(self):
        if self.word in self.WORD_TO_WORD_MAPPING:
            self.word = self.WORD_TO_WORD_MAPPING[self.word]

    def __repr__(self, single_line=True, depth=0):
        ans = ""
        if not single_line and depth > 0:
            ans = "\n" + depth * "\t"
        ans += "(" + self.label
        if self.word is not None:
            ans += " " + self.word
        for subtree in self.subtrees:
            if single_line:
                ans += " "
            ans += subtree.__repr__(single_line, depth + 1)
        ans += ")"
        return ans

def read_tree(source):
    cur_text = []
    depth = 0
    while True:
        line = source.readline()
        # Check if we are out of input
        if line == "":
            return None
        # strip whitespace and only use if this contains something
        line = line.strip()
        if line == "":
            continue
        cur_text.append(line)
        # Update depth
        for char in line:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
        # At depth 0 we have a complete tree
        if depth == 0:
            tree = PTBTree()
            tree.set_by_text(" ".join(cur_text))
            return tree

def read_trees(source, max_sents=-1):
    with open(source) as fp:
        trees = []
        while True:
            tree = read_tree(fp)
            if tree is None:
                break
            trees.append(tree)
            if len(trees) >= max_sents > 0:
                break
        return trees

class SSTDataset(data.Dataset):
    def __init__(self, sents, mask, labels):
        self.sents = sents
        self.labels = labels
        self.mask = mask

    def __getitem__(self, index):
        return (self.sents[index], self.mask[index]), self.labels[index]

    def __len__(self):
        return len(self.sents)

def sst_get_id_input(content, word_id_dict, max_input_length):
    # Map a whitespace-tokenized sentence to padded word ids plus a 0/1 mask.
    words = content.split(" ")
    sentence = [word_id_dict["<pad>"]] * max_input_length
    mask = [0] * max_input_length
    unknown = word_id_dict["<unknown>"]
    for i, word in enumerate(words[:max_input_length]):
        sentence[i] = word_id_dict.get(word, unknown)
        mask[i] = 1
    return sentence, mask

def sst_get_phrases(trees, sample_ratio=1.0, is_binary=False, only_sentence=False):
    # Collect (phrase, label) pairs from the parse trees. Optionally keep only
    # full sentences, collapse 5-way labels to binary (dropping neutral), and
    # subsample the result by sample_ratio.
    all_phrases = []
    for tree in trees:
        if only_sentence:
            sentence = get_sentence_by_tree(tree)
            label = int(tree.label)
            pair = (sentence, label)
            all_phrases.append(pair)
        else:
            phrases = get_phrases_by_tree(tree)
            sentence = get_sentence_by_tree(tree)
            pair = (sentence, int(tree.label))
            all_phrases.append(pair)
            all_phrases += phrases
    if sample_ratio < 1.:
        np.random.shuffle(all_phrases)
    result_phrases = []
    for pair in all_phrases:
        if is_binary:
            phrase, label = pair
            if label <= 1:
                pair = (phrase, 0)
            elif label >= 3:
                pair = (phrase, 1)
            else:
                continue
        if sample_ratio == 1.:
            result_phrases.append(pair)
        else:
            rand_portion = np.random.random()
            if rand_portion < sample_ratio:
                result_phrases.append(pair)
    return result_phrases

def get_phrases_by_tree(tree):
    phrases = []
    if tree is None:
        return phrases
    if tree.is_leaf():
        pair = (tree.word, int(tree.label))
        phrases.append(pair)
        return phrases
    left_child_phrases = get_phrases_by_tree(tree.subtrees[0])
    right_child_phrases = get_phrases_by_tree(tree.subtrees[1])
    phrases.extend(left_child_phrases)
    phrases.extend(right_child_phrases)
    sentence = get_sentence_by_tree(tree)
    pair = (sentence, int(tree.label))
    phrases.append(pair)
    return phrases

def get_sentence_by_tree(tree):
    if tree is None:
        return ""
    if tree.is_leaf():
        return tree.word
    left_sentence = get_sentence_by_tree(tree.subtrees[0])
    right_sentence = get_sentence_by_tree(tree.subtrees[1])
    sentence = left_sentence + " " + right_sentence
    return sentence.strip()

def get_word_id_dict(word_num_dict, word_id_dict, min_count):
    # Assign ids to words that occur at least min_count times, in sorted order.
    for word in sorted(word_num_dict.keys()):
        count = word_num_dict[word]
        if count >= min_count:
            index = len(word_id_dict)
            if word not in word_id_dict:
                word_id_dict[word] = index
    return word_id_dict

def load_word_num_dict(phrases, word_num_dict):
    for sentence, _ in phrases:
        words = sentence.split(" ")
        for cur_word in words:
            word = cur_word.strip()
            word_num_dict[word] += 1
    return word_num_dict

def init_trainable_embedding(embedding_path, word_id_dict, embed_dim=300):
    word_embed_model = load_glove_model(embedding_path, embed_dim)
    assert word_embed_model["pool"].shape[1] == embed_dim
    # Initialize all rows uniformly in [-0.25, 0.25); known words are then
    # overwritten with their GloVe vectors.
    embedding = np.random.random([len(word_id_dict), embed_dim]).astype(np.float32) / 2.0 - 0.25
    embedding[0] = np.zeros(embed_dim)  # PAD
    embedding[1] = (np.random.rand(embed_dim) - 0.5) / 2  # UNK
    for word in sorted(word_id_dict.keys()):
        idx = word_id_dict[word]
        if idx == 0 or idx == 1:
            continue
        if word in word_embed_model["mapping"]:
            embedding[idx] = word_embed_model["pool"][word_embed_model["mapping"][word]]
        else:
            embedding[idx] = np.random.rand(embed_dim) / 2.0 - 0.25
    return embedding

def sst_get_trainable_data(phrases, word_id_dict, max_input_length):
    texts, labels, mask = [], [], []
    for phrase, label in phrases:
        if not phrase.split():
            continue
        phrase_split, mask_split = sst_get_id_input(phrase, word_id_dict, max_input_length)
        texts.append(phrase_split)
        labels.append(int(label))
        mask.append(mask_split)  # field_input is mask
    labels = np.array(labels, dtype=np.int64)
    texts = np.reshape(texts, [-1, max_input_length]).astype(np.int32)
    mask = np.reshape(mask, [-1, max_input_length]).astype(np.int32)
    return SSTDataset(texts, mask, labels)

def load_glove_model(filename, embed_dim):
    # Load GloVe vectors, caching the parsed result with pickle for reuse.
    if os.path.exists(filename + ".cache"):
        logger.info("Found cache. Loading...")
        with open(filename + ".cache", "rb") as fp:
            return pickle.load(fp)
    embedding = {"mapping": dict(), "pool": []}
    with open(filename) as f:
        for i, line in enumerate(f):
            line = line.rstrip("\n")
            vocab_word, *vec = line.rsplit(" ", maxsplit=embed_dim)
            assert len(vec) == embed_dim, "Unexpected line: '%s'" % line
            embedding["pool"].append(np.array(list(map(float, vec)), dtype=np.float32))
            embedding["mapping"][vocab_word] = i
    embedding["pool"] = np.stack(embedding["pool"])
    with open(filename + ".cache", "wb") as fp:
        pickle.dump(embedding, fp)
    return embedding

def read_data_sst(data_path, max_input_length=64, min_count=1, train_with_valid=False,
                  train_ratio=1., valid_ratio=1., is_binary=False, only_sentence=False):
    logger.info("data path: %s", os.getcwd())
    word_id_dict = dict()
    word_num_dict = Counter()
    sst_path = os.path.join(data_path, "sst")
    logger.info("Reading SST data...")
    train_file_name = os.path.join(sst_path, "trees", "train.txt")
    valid_file_name = os.path.join(sst_path, "trees", "dev.txt")
    test_file_name = os.path.join(sst_path, "trees", "test.txt")
    train_trees = read_trees(train_file_name)
    train_phrases = sst_get_phrases(train_trees, train_ratio, is_binary, only_sentence)
    logger.info("Finished loading train phrases.")
    valid_trees = read_trees(valid_file_name)
    valid_phrases = sst_get_phrases(valid_trees, valid_ratio, is_binary, only_sentence)
    logger.info("Finished loading valid phrases.")
    if train_with_valid:
        train_phrases += valid_phrases
    test_trees = read_trees(test_file_name)
    test_phrases = sst_get_phrases(test_trees, valid_ratio, is_binary, only_sentence=True)
    logger.info("Finished loading test phrases.")
    # get word_id_dict
    word_id_dict["<pad>"] = 0
    word_id_dict["<unknown>"] = 1
    load_word_num_dict(train_phrases, word_num_dict)
    logger.info("Finished loading train words: %d.", len(word_num_dict))
    load_word_num_dict(valid_phrases, word_num_dict)
    load_word_num_dict(test_phrases, word_num_dict)
    logger.info("Finished loading valid+test words: %d.", len(word_num_dict))
    word_id_dict = get_word_id_dict(word_num_dict, word_id_dict, min_count)
    logger.info("Vocabulary size after trimming: %d.", len(word_id_dict))
    logger.info("Loading embedding...")
    embedding = init_trainable_embedding(os.path.join(data_path, "glove.840B.300d.txt"), word_id_dict)
    logger.info("Finished initializing word embedding.")
    dataset_train = sst_get_trainable_data(train_phrases, word_id_dict, max_input_length)
    logger.info("Loaded %d training samples.", len(dataset_train))
    dataset_valid = sst_get_trainable_data(valid_phrases, word_id_dict, max_input_length)
    logger.info("Loaded %d validation samples.", len(dataset_valid))
    dataset_test = sst_get_trainable_data(test_phrases, word_id_dict, max_input_length)
    logger.info("Loaded %d test samples.", len(dataset_test))
    return dataset_train, dataset_valid, dataset_test, torch.from_numpy(embedding)
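
A minimal usage sketch (not part of the original file), assuming the directory layout read_data_sst expects: SST trees under ./data/sst/trees/{train,dev,test}.txt and GloVe vectors at ./data/glove.840B.300d.txt. The batch size and is_binary flag are illustrative:

    from torch.utils.data import DataLoader

    train_set, valid_set, test_set, embedding = read_data_sst("./data", is_binary=True)

    # The default collate_fn converts the numpy arrays stored in SSTDataset to tensors.
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

    # Seed a trainable embedding layer from the returned GloVe matrix.
    embed_layer = torch.nn.Embedding.from_pretrained(embedding, freeze=False)

    for (sents, mask), labels in train_loader:
        # sents, mask: [batch, max_input_length] int tensors; labels: [batch] int64
        vectors = embed_layer(sents.long())  # [batch, max_input_length, 300]
        break

With is_binary=True the 5-way SST labels are collapsed to {0, 1} and neutral phrases are dropped; with the default is_binary=False, labels range over 0-4.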
