
train.py

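# Trains gensim word2vec embeddings on pickled Yelp reviews, then trains a
# HAN classifier imported from model.py (presumably a hierarchical attention
# network, given the word- and sentence-level hidden/context sizes used below).
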
import os
import pickle

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import nltk
import numpy as np
import torch

from model import *

UNK_token = '/unk'

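# Streams (document, label) pairs from a directory of pickle files.
# Assumed layout: each pickle holds a list of (label, document) pairs,
# where a document is a list of tokenized sentences.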
class SampleIter:
    def __init__(self, dirname):
        self.dirname = dirname

    def __iter__(self):
        for fname in os.listdir(self.dirname):
            with open(os.path.join(self.dirname, fname), 'rb') as f:
                for y, x in pickle.load(f):
                    yield x, y

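# Yields individual sentences from the first `count` pickle files; when a
# vocabulary is set, out-of-vocabulary words are replaced by UNK_token.
# Used as the corpus iterator for word2vec training.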
class SentIter:
    def __init__(self, dirname, count, vocab=None):
        self.dirname = dirname
        self.count = int(count)
        self.vocab = vocab

    def __iter__(self):
        for fname in os.listdir(self.dirname)[:self.count]:
            with open(os.path.join(self.dirname, fname), 'rb') as f:
                for y, x in pickle.load(f):
                    for sent in x:
                        if self.vocab is not None:
                            _sent = []
                            for w in sent:
                                if w in self.vocab:
                                    _sent.append(w)
                                else:
                                    _sent.append(UNK_token)
                            sent = _sent
                        yield sent

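# Trains a 200-dimensional CBOW word2vec model on the first 238 review files
# and saves it as 'yelp.word2vec'. Note that `models` here refers to
# gensim.models, which is imported in the __main__ block below.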
def train_word_vec():
    # load data
    dirname = 'reviews'
    sents = SentIter(dirname, 238)
    # define model and train
    model = models.Word2Vec(size=200, sg=0, workers=4, min_count=5)
    model.build_vocab(sents)
    sents.vocab = model.wv.vocab
    model.train(sents, total_examples=model.corpus_count, epochs=10)
    model.save('yelp.word2vec')
    print(model.wv.similarity('woman', 'man'))
    print(model.wv.similarity('nice', 'awful'))

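# Thin lookup wrapper around gensim word vectors; words missing from the
# vocabulary fall back to a random vector of the same dimensionality.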
class Embedding_layer:
    def __init__(self, wv, vector_size):
        self.wv = wv
        self.vector_size = vector_size

    def get_vec(self, w):
        try:
            v = self.wv[w]
        except KeyError:
            v = np.random.randn(self.vector_size)
        return v

from torch.utils.data import DataLoader, Dataset

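# Map-style Dataset over the pickled review files. Each file is assumed to
# contain 5000 reviews; a small 5-slot cache keeps recently loaded files in
# memory. __getitem__ returns (list of per-sentence embedding tensors,
# zero-based label).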
class YelpDocSet(Dataset):
    def __init__(self, dirname, start_file, num_files, embedding):
        self.dirname = dirname
        self.num_files = num_files
        self._files = os.listdir(dirname)[start_file:start_file + num_files]
        self.embedding = embedding
        self._cache = [(-1, None) for i in range(5)]

    def __len__(self):
        return len(self._files) * 5000

    def __getitem__(self, n):
        file_id = n // 5000
        idx = file_id % 5
        if self._cache[idx][0] != file_id:
            print('load {} to {}'.format(file_id, idx))
            with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f:
                self._cache[idx] = (file_id, pickle.load(f))
        y, x = self._cache[idx][1][n % 5000]
        doc = []
        for sent in x:
            if len(sent) == 0:
                continue
            sent_vec = []
            for word in sent:
                vec = self.embedding.get_vec(word)
                sent_vec.append(vec.tolist())
            sent_vec = torch.Tensor(sent_vec)
            # print(sent_vec.size())
            doc.append(sent_vec)
        if len(doc) == 0:
            doc = [torch.zeros(1, 200)]
        return doc, y - 1

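# Custom collate function: documents are kept as a plain Python list (they
# have variable numbers of sentences and words); only the labels are stacked
# into a LongTensor.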
def collate(iterable):
    y_list = []
    x_list = []
    for x, y in iterable:
        y_list.append(y)
        x_list.append(x)
    return x_list, torch.LongTensor(y_list)

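# Training loop: SGD + NLLLoss. Each document's sentence tensors are packed
# with pack_sequence before being fed to the network. `nn`, `Variable` and
# `pack_sequence` are expected to come from `from model import *` above
# (e.g. torch.nn, torch.autograd.Variable, torch.nn.utils.rnn.pack_sequence).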
def train(net, dataset, num_epoch, batch_size, print_size=10, use_cuda=False):
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.NLLLoss()
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            collate_fn=collate,
                            num_workers=0)
    running_loss = 0.0
    if use_cuda:
        net.cuda()
    print('start training')
    for epoch in range(num_epoch):
        for i, batch_samples in enumerate(dataloader):
            x, y = batch_samples
            doc_list = []
            for sample in x:
                doc = []
                for sent_vec in sample:
                    # print(sent_vec.size())
                    if use_cuda:
                        sent_vec = sent_vec.cuda()
                    doc.append(Variable(sent_vec))
                doc_list.append(pack_sequence(doc))
            if use_cuda:
                y = y.cuda()
            y = Variable(y)
            predict = net(doc_list)
            loss = criterion(predict, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.data[0]
            if i % print_size == print_size - 1:
                print('{}, {}'.format(i + 1, running_loss / print_size))
                running_loss = 0.0
                torch.save(net.state_dict(), 'model.dict')
    torch.save(net.state_dict(), 'model.dict')

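# Entry point: loads the pre-trained word2vec embeddings, builds the dataset
# over 120 review files, restores HAN weights from 'model.dict' (the file must
# already exist), and trains for one epoch on the GPU.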
if __name__ == '__main__':
    '''
        Train process
    '''
    from gensim.models import Word2Vec
    import gensim
    from gensim import models

    # train_word_vec()

    embed_model = Word2Vec.load('yelp.word2vec')
    embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
    del embed_model
    # for start_file in range(11, 24):
    start_file = 0
    dataset = YelpDocSet('reviews', start_file, 120 - start_file, embedding)
    print('start_file %d' % start_file)
    print(len(dataset))
    net = HAN(input_size=200, output_size=5,
              word_hidden_size=50, word_num_layers=1, word_context_size=100,
              sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
    net.load_state_dict(torch.load('model.dict'))
    train(net, dataset, num_epoch=1, batch_size=64, use_cuda=True)

A lightweight natural language processing (NLP) toolkit that aims to reduce the amount of engineering code in user projects, such as data-processing loops, training loops, and multi-GPU execution.