
dataset.py 4.5 kB

import codecs
import random
import re

import gensim
import numpy as np
from gensim import corpora
from torch.utils.data import Dataset


def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " 's", string)
    string = re.sub(r"\'ve", " 've", string)
    string = re.sub(r"n\'t", " n't", string)
    string = re.sub(r"\'re", " 're", string)
    string = re.sub(r"\'d", " 'd", string)
    string = re.sub(r"\'ll", " 'll", string)
    # Surround punctuation with spaces so it becomes separate tokens.
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " ( ", string)
    string = re.sub(r"\)", " ) ", string)
    string = re.sub(r"\?", " ? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip()


def pad_sentences(sentence, padding_word=" <PAD/>"):
    """Pad (or truncate) a sentence to a fixed length of 64 tokens."""
    sequence_length = 64
    words = sentence.split()
    if len(words) >= sequence_length:
        # The original code only padded; truncating long sentences too keeps
        # every example at exactly sequence_length tokens.
        return " ".join(words[:sequence_length])
    return sentence + padding_word * (sequence_length - len(words))


# Data loader for the MR (movie review) sentence polarity dataset.
class MRDataset(Dataset):
    def __init__(self):
        # Load positive and negative sentences from the raw files.
        with codecs.open("./rt-polaritydata/rt-polarity.pos", encoding='ISO-8859-1') as f:
            positive_examples = list(f.readlines())
        with codecs.open("./rt-polaritydata/rt-polarity.neg", encoding='ISO-8859-1') as f:
            negative_examples = list(f.readlines())
        # s.strip() removes the trailing "\n"; then clean and pad each sentence.
        positive_examples = [pad_sentences(clean_str(s.strip())) for s in positive_examples]
        negative_examples = [pad_sentences(clean_str(s.strip())) for s in negative_examples]
        self.examples = positive_examples + negative_examples
        self.sentences_texts = [sample.split() for sample in self.examples]

        # Build the word dictionary, e.g. {"human": 0, "a": 1, ...}.
        dictionary = corpora.Dictionary(self.sentences_texts)
        self.word2id_dict = dictionary.token2id

        # Set labels: positive is 1, negative is 0.
        positive_labels = [1 for _ in positive_examples]
        negative_labels = [0 for _ in negative_examples]
        self.labels = positive_labels + negative_labels
        examples_labels = list(zip(self.examples, self.labels))
        random.shuffle(examples_labels)
        self.MRDataset_frame = examples_labels

        # Map each word in each shuffled sentence to its id.
        self.MRDataset_wordid = [
            (
                np.array([self.word2id_dict[word] for word in sent[0].split()], dtype=np.int64),
                sent[1],
            )
            for sent in self.MRDataset_frame
        ]

    def word_embeddings(self, path="./GoogleNews-vectors-negative300.bin/GoogleNews-vectors-negative300.bin"):
        # Load the pre-trained Google News word2vec vectors.
        print('Please wait ... (it could take a while to load the file : {})'.format(path))
        model = gensim.models.KeyedVectors.load_word2vec_format(path, binary=True)
        # Words missing from word2vec keep a random initialization in [-0.25, 0.25).
        embedding_weights = np.random.uniform(-0.25, 0.25, (len(self.word2id_dict), 300))
        for word, word_id in self.word2id_dict.items():
            if word in model:
                embedding_weights[word_id, :] = model[word]
        return embedding_weights

    def __len__(self):
        return len(self.MRDataset_frame)

    def __getitem__(self, idx):
        return self.MRDataset_wordid[idx]

    def getsent(self, idx):
        return self.MRDataset_wordid[idx][0]

    def getlabel(self, idx):
        return self.MRDataset_wordid[idx][1]

    def word2id(self):
        return self.word2id_dict

    def id2word(self):
        return {val: key for key, val in self.word2id_dict.items()}


class train_set(Dataset):
    def __init__(self, samples):
        self.train_frame = samples

    def __len__(self):
        return len(self.train_frame)

    def __getitem__(self, idx):
        return self.train_frame[idx]


class test_set(Dataset):
    def __init__(self, samples):
        self.test_frame = samples

    def __len__(self):
        return len(self.test_frame)

    def __getitem__(self, idx):
        return self.test_frame[idx]
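
For reference, a minimal usage sketch (not part of dataset.py) showing how MRDataset, train_set, and test_set might be wired into PyTorch DataLoaders, with the embedding layer initialized from the pre-trained word2vec weights. The 80/20 split ratio and batch size of 32 are arbitrary illustrative choices, and it assumes the rt-polaritydata files and the GoogleNews vectors are present at the paths hardcoded above.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset import MRDataset, train_set, test_set

dataset = MRDataset()

# Split the already-shuffled samples into training and test subsets.
split = int(0.8 * len(dataset))
train_data = train_set([dataset[i] for i in range(split)])
test_data = test_set([dataset[i] for i in range(split, len(dataset))])

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32)

# Initialize an embedding layer from the pre-trained word2vec weights.
weights = torch.from_numpy(dataset.word_embeddings()).float()
embedding = nn.Embedding.from_pretrained(weights, freeze=False)

for sentences, labels in train_loader:
    embedded = embedding(sentences)  # shape: (batch, 64, 300)
    # ... feed `embedded` to a CNN/classifier here ...
    break

Because pad_sentences fixes every sentence at 64 tokens, the default collate function can stack the id arrays directly into a (batch, 64) LongTensor without any custom batching logic.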

A lightweight natural language processing (NLP) toolkit that aims to reduce the engineering boilerplate in user projects, such as data-processing loops, training loops, and multi-GPU execution.