# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import pickle
from collections import Counter

import numpy as np
import torch
from torch.utils import data

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class PTBTree:
    WORD_TO_WORD_MAPPING = {
        "{": "-LCB-",
        "}": "-RCB-"
    }

    def __init__(self):
        self.subtrees = []
        self.word = None
        self.label = ""
        self.parent = None
        self.span = (-1, -1)
        self.word_vector = None  # HOS, store dx1 RNN word vector
        self.prediction = None  # HOS, store Kx1 prediction vector

    def is_leaf(self):
        return len(self.subtrees) == 0

    def set_by_text(self, text, pos=0, left=0):
        depth = 0
        right = left
        for i in range(pos + 1, len(text)):
            char = text[i]
            # update the depth
            if char == "(":
                depth += 1
                if depth == 1:
                    subtree = PTBTree()
                    subtree.parent = self
                    subtree.set_by_text(text, i, right)
                    right = subtree.span[1]
                    self.span = (left, right)
                    self.subtrees.append(subtree)
            elif char == ")":
                depth -= 1
                if len(self.subtrees) == 0:
                    pos = i
                    for j in range(i, 0, -1):
                        if text[j] == " ":
                            pos = j
                            break
                    self.word = text[pos + 1:i]
                    self.span = (left, left + 1)

            # we've reached the end of the category that is the root of this subtree
            if depth == 0 and char == " " and self.label == "":
                self.label = text[pos + 1:i]
            # we've reached the end of the scope for this bracket
            if depth < 0:
                break

        # Fix some issues with variation in output, and one error in the treebank
        # for a word with a punctuation POS
        self.standardise_node()

    def standardise_node(self):
        if self.word in self.WORD_TO_WORD_MAPPING:
            self.word = self.WORD_TO_WORD_MAPPING[self.word]

    def __repr__(self, single_line=True, depth=0):
        ans = ""
        if not single_line and depth > 0:
            ans = "\n" + depth * "\t"
        ans += "(" + self.label
        if self.word is not None:
            ans += " " + self.word
        for subtree in self.subtrees:
            if single_line:
                ans += " "
            ans += subtree.__repr__(single_line, depth + 1)
        ans += ")"
        return ans


def read_tree(source):
    cur_text = []
    depth = 0
    while True:
        line = source.readline()
        # Check if we are out of input
        if line == "":
            return None
        # strip whitespace and only use if this contains something
        line = line.strip()
        if line == "":
            continue
        cur_text.append(line)
        # Update depth
        for char in line:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
        # At depth 0 we have a complete tree
        if depth == 0:
            tree = PTBTree()
            tree.set_by_text(" ".join(cur_text))
            return tree


def read_trees(source, max_sents=-1):
    with open(source) as fp:
        trees = []
        while True:
            tree = read_tree(fp)
            if tree is None:
                break
            trees.append(tree)
            if len(trees) >= max_sents > 0:
                break
        return trees


class SSTDataset(data.Dataset):
    def __init__(self, sents, mask, labels):
        self.sents = sents
        self.labels = labels
        self.mask = mask

    def __getitem__(self, index):
        return (self.sents[index], self.mask[index]), self.labels[index]

    def __len__(self):
        return len(self.sents)


def sst_get_id_input(content, word_id_dict, max_input_length):
    words = content.split(" ")
    sentence = [word_id_dict["<pad>"]] * max_input_length
    mask = [0] * max_input_length
    unknown = word_id_dict["<unknown>"]
    for i, word in enumerate(words[:max_input_length]):
        sentence[i] = word_id_dict.get(word, unknown)
        mask[i] = 1
    return sentence, mask
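
# A minimal usage sketch for sst_get_id_input (not part of the original module).
# The vocabulary below is hypothetical; the function only requires that the
# "<pad>" and "<unknown>" entries exist:
#
#   word_id_dict = {"<pad>": 0, "<unknown>": 1, "a": 2, "good": 3, "movie": 4}
#   sentence, mask = sst_get_id_input("a great movie", word_id_dict, max_input_length=5)
#   # sentence == [2, 1, 4, 0, 0]  ("great" is out of vocabulary -> unknown id)
#   # mask == [1, 1, 1, 0, 0]      (1 marks real tokens, 0 marks padding)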


def sst_get_phrases(trees, sample_ratio=1.0, is_binary=False, only_sentence=False):
    all_phrases = []
    for tree in trees:
        if only_sentence:
            sentence = get_sentence_by_tree(tree)
            label = int(tree.label)
            pair = (sentence, label)
            all_phrases.append(pair)
        else:
            phrases = get_phrases_by_tree(tree)
            sentence = get_sentence_by_tree(tree)
            pair = (sentence, int(tree.label))
            all_phrases.append(pair)
            all_phrases += phrases
    if sample_ratio < 1.:
        np.random.shuffle(all_phrases)
    result_phrases = []
    for pair in all_phrases:
        if is_binary:
            phrase, label = pair
            if label <= 1:
                pair = (phrase, 0)
            elif label >= 3:
                pair = (phrase, 1)
            else:
                continue
        if sample_ratio == 1.:
            result_phrases.append(pair)
        else:
            rand_portion = np.random.random()
            if rand_portion < sample_ratio:
                result_phrases.append(pair)
    return result_phrases


def get_phrases_by_tree(tree):
    phrases = []
    if tree is None:
        return phrases
    if tree.is_leaf():
        pair = (tree.word, int(tree.label))
        phrases.append(pair)
        return phrases
    left_child_phrases = get_phrases_by_tree(tree.subtrees[0])
    right_child_phrases = get_phrases_by_tree(tree.subtrees[1])
    phrases.extend(left_child_phrases)
    phrases.extend(right_child_phrases)
    sentence = get_sentence_by_tree(tree)
    pair = (sentence, int(tree.label))
    phrases.append(pair)
    return phrases


def get_sentence_by_tree(tree):
    if tree is None:
        return ""
    if tree.is_leaf():
        return tree.word
    left_sentence = get_sentence_by_tree(tree.subtrees[0])
    right_sentence = get_sentence_by_tree(tree.subtrees[1])
    sentence = left_sentence + " " + right_sentence
    return sentence.strip()


def get_word_id_dict(word_num_dict, word_id_dict, min_count):
    for word in sorted(word_num_dict.keys()):
        count = word_num_dict[word]
        if count >= min_count and word not in word_id_dict:
            word_id_dict[word] = len(word_id_dict)
    return word_id_dict


def load_word_num_dict(phrases, word_num_dict):
    for sentence, _ in phrases:
        words = sentence.split(" ")
        for cur_word in words:
            word = cur_word.strip()
            word_num_dict[word] += 1
    return word_num_dict


def init_trainable_embedding(embedding_path, word_id_dict, embed_dim=300):
    word_embed_model = load_glove_model(embedding_path, embed_dim)
    assert word_embed_model["pool"].shape[1] == embed_dim
    embedding = np.random.random([len(word_id_dict), embed_dim]).astype(np.float32) / 2.0 - 0.25
    embedding[0] = np.zeros(embed_dim)  # PAD
    embedding[1] = (np.random.rand(embed_dim) - 0.5) / 2  # UNK
    for word in sorted(word_id_dict.keys()):
        idx = word_id_dict[word]
        if idx == 0 or idx == 1:
            continue
        if word in word_embed_model["mapping"]:
            embedding[idx] = word_embed_model["pool"][word_embed_model["mapping"][word]]
        else:
            embedding[idx] = np.random.rand(embed_dim) / 2.0 - 0.25
    return embedding


def sst_get_trainable_data(phrases, word_id_dict, max_input_length):
    texts, labels, mask = [], [], []
    for phrase, label in phrases:
        if not phrase.split():
            continue
        phrase_split, mask_split = sst_get_id_input(phrase, word_id_dict, max_input_length)
        texts.append(phrase_split)
        labels.append(int(label))
        mask.append(mask_split)
    # field_input is mask
    labels = np.array(labels, dtype=np.int64)
    texts = np.reshape(texts, [-1, max_input_length]).astype(np.int32)
    mask = np.reshape(mask, [-1, max_input_length]).astype(np.int32)
    return SSTDataset(texts, mask, labels)
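
# A sketch of the phrase extraction above on a toy binarized SST-style tree
# (hypothetical input; real trees come from the trees/*.txt files):
#
#   tree = PTBTree()
#   tree.set_by_text("(3 (2 A) (4 (3 good) (2 movie)))")
#   get_phrases_by_tree(tree)
#   # -> [("A", 2), ("good", 3), ("movie", 2), ("good movie", 4), ("A good movie", 3)]
#
# sst_get_phrases collects these (phrase, label) pairs over all trees and, with
# is_binary=True, maps labels {0, 1} -> 0 and {3, 4} -> 1, dropping neutral 2s.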
Loading...") with open(filename + ".cache", "rb") as fp: return pickle.load(fp) embedding = {"mapping": dict(), "pool": []} with open(filename) as f: for i, line in enumerate(f): line = line.rstrip("\n") vocab_word, *vec = line.rsplit(" ", maxsplit=embed_dim) assert len(vec) == 300, "Unexpected line: '%s'" % line embedding["pool"].append(np.array(list(map(float, vec)), dtype=np.float32)) embedding["mapping"][vocab_word] = i embedding["pool"] = np.stack(embedding["pool"]) with open(filename + ".cache", "wb") as fp: pickle.dump(embedding, fp) return embedding def read_data_sst(data_path, max_input_length=64, min_count=1, train_with_valid=False, train_ratio=1., valid_ratio=1., is_binary=False, only_sentence=False): logger.info("data path: {}".format(os.getcwd())) word_id_dict = dict() word_num_dict = Counter() sst_path = os.path.join(data_path, "sst") logger.info("Reading SST data...") train_file_name = os.path.join(sst_path, "trees", "train.txt") valid_file_name = os.path.join(sst_path, "trees", "dev.txt") test_file_name = os.path.join(sst_path, "trees", "test.txt") train_trees = read_trees(train_file_name) train_phrases = sst_get_phrases(train_trees, train_ratio, is_binary, only_sentence) logger.info("Finish load train phrases.") valid_trees = read_trees(valid_file_name) valid_phrases = sst_get_phrases(valid_trees, valid_ratio, is_binary, only_sentence) logger.info("Finish load valid phrases.") if train_with_valid: train_phrases += valid_phrases test_trees = read_trees(test_file_name) test_phrases = sst_get_phrases(test_trees, valid_ratio, is_binary, only_sentence=True) logger.info("Finish load test phrases.") # get word_id_dict word_id_dict[""] = 0 word_id_dict[""] = 1 load_word_num_dict(train_phrases, word_num_dict) logger.info("Finish load train words: %d.", len(word_num_dict)) load_word_num_dict(valid_phrases, word_num_dict) load_word_num_dict(test_phrases, word_num_dict) logger.info("Finish load valid+test words: %d.", len(word_num_dict)) word_id_dict = get_word_id_dict(word_num_dict, word_id_dict, min_count) logger.info("After trim vocab length: %d.", len(word_id_dict)) logger.info("Loading embedding...") embedding = init_trainable_embedding(os.path.join(data_path, "glove.840B.300d.txt"), word_id_dict) logger.info("Finish initialize word embedding.") dataset_train = sst_get_trainable_data(train_phrases, word_id_dict, max_input_length) logger.info("Loaded %d training samples.", len(dataset_train)) dataset_valid = sst_get_trainable_data(valid_phrases, word_id_dict, max_input_length) logger.info("Loaded %d validation samples.", len(dataset_valid)) dataset_test = sst_get_trainable_data(test_phrases, word_id_dict, max_input_length) logger.info("Loaded %d test samples.", len(dataset_test)) return dataset_train, dataset_valid, dataset_test, torch.from_numpy(embedding)