You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

train.py 4.4 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import os
  2. import pickle
  3. import matplotlib.pyplot as plt
  4. import matplotlib.ticker as ticker
  5. import nltk
  6. import numpy as np
  7. import torch
  8. from model import *
  9. class SampleIter:
  10. def __init__(self, dirname):
  11. self.dirname = dirname
  12. def __iter__(self):
  13. for f in os.listdir(self.dirname):
  14. for y, x in pickle.load(open(os.path.join(self.dirname, f), 'rb')):
  15. yield x, y
  16. class SentIter:
  17. def __init__(self, dirname, count):
  18. self.dirname = dirname
  19. self.count = int(count)
  20. def __iter__(self):
  21. for f in os.listdir(self.dirname)[:self.count]:
  22. for y, x in pickle.load(open(os.path.join(self.dirname, f), 'rb')):
  23. for sent in x:
  24. yield sent
  25. def train_word_vec():
  26. # load data
  27. dirname = 'reviews'
  28. sents = SentIter(dirname, 238)
  29. # define model and train
  30. model = models.Word2Vec(sentences=sents, size=200, sg=0, workers=4, min_count=5)
  31. model.save('yelp.word2vec')
  32. class Embedding_layer:
  33. def __init__(self, wv, vector_size):
  34. self.wv = wv
  35. self.vector_size = vector_size
  36. def get_vec(self, w):
  37. try:
  38. v = self.wv[w]
  39. except KeyError as e:
  40. v = np.zeros(self.vector_size)
  41. return v
  42. from torch.utils.data import DataLoader, Dataset
  43. class YelpDocSet(Dataset):
  44. def __init__(self, dirname, num_files, embedding):
  45. self.dirname = dirname
  46. self.num_files = num_files
  47. self._len = num_files*5000
  48. self._files = os.listdir(dirname)[:num_files]
  49. self.embedding = embedding
  50. def __len__(self):
  51. return self._len
  52. def __getitem__(self, n):
  53. file_id = n // 5000
  54. sample_list = pickle.load(open(
  55. os.path.join(self.dirname, self._files[file_id]), 'rb'))
  56. y, x = sample_list[n % 5000]
  57. return x, y-1
  58. def collate(iterable):
  59. y_list = []
  60. x_list = []
  61. for x, y in iterable:
  62. y_list.append(y)
  63. x_list.append(x)
  64. return x_list, torch.LongTensor(y_list)
  65. def train(net, num_epoch, batch_size, print_size=10, use_cuda=False):
  66. from gensim.models import Word2Vec
  67. import torch
  68. import gensim
  69. from gensim import models
  70. embed_model = Word2Vec.load('yelp.word2vec')
  71. embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
  72. del embed_model
  73. optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
  74. criterion = nn.NLLLoss()
  75. dirname = 'reviews'
  76. dataloader = DataLoader(YelpDocSet(dirname, 238, embedding),
  77. batch_size=batch_size,
  78. collate_fn=collate,
  79. num_workers=4)
  80. running_loss = 0.0
  81. if use_cuda:
  82. net.cuda()
  83. for epoch in range(num_epoch):
  84. for i, batch_samples in enumerate(dataloader):
  85. x, y = batch_samples
  86. doc_list = []
  87. for sample in x:
  88. doc = []
  89. for sent in sample:
  90. sent_vec = []
  91. for word in sent:
  92. vec = embedding.get_vec(word)
  93. vec = torch.Tensor(vec.reshape((1, -1)))
  94. if use_cuda:
  95. vec = vec.cuda()
  96. sent_vec.append(vec)
  97. sent_vec = torch.cat(sent_vec, dim=0)
  98. # print(sent_vec.size())
  99. doc.append(Variable(sent_vec))
  100. doc_list.append(doc)
  101. if use_cuda:
  102. y = y.cuda()
  103. y = Variable(y)
  104. predict = net(doc_list)
  105. loss = criterion(predict, y)
  106. optimizer.zero_grad()
  107. loss.backward()
  108. optimizer.step()
  109. running_loss += loss.data[0]
  110. if i % print_size == print_size-1:
  111. print(running_loss/print_size)
  112. running_loss = 0.0
  113. torch.save(net.state_dict(), 'model.dict')
  114. torch.save(net.state_dict(), 'model.dict')
if __name__ == '__main__':
    '''
    Train process
    '''
    # Build the network (HAN comes from `from model import *`): 200-dim
    # word vectors in, 5 rating classes out, one recurrent layer each at
    # the word and sentence level with 100-dim attention context vectors.
    # The input_size=200 must match the word2vec dimensionality used in
    # train_word_vec (size=200).
    net = HAN(input_size=200, output_size=5,
              word_hidden_size=50, word_num_layers=1, word_context_size=100,
              sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
    # NOTE(review): use_cuda=True assumes a GPU is available — confirm
    # before running on a CPU-only machine.
    train(net, num_epoch=1, batch_size=64, use_cuda=True)

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等