
train.py

import gensim
from gensim import models
import os
import pickle


class SampleIter:
    """Iterates over (document, label) pairs from every pickle file in a directory."""
    def __init__(self, dirname):
        self.dirname = dirname

    def __iter__(self):
        for f in os.listdir(self.dirname):
            for y, x in pickle.load(open(os.path.join(self.dirname, f), 'rb')):
                yield x, y


class SentIter:
    """Iterates over individual sentences from the first `count` pickle files."""
    def __init__(self, dirname, count):
        self.dirname = dirname
        self.count = int(count)

    def __iter__(self):
        for f in os.listdir(self.dirname)[:self.count]:
            for y, x in pickle.load(open(os.path.join(self.dirname, f), 'rb')):
                for sent in x:
                    yield sent


def train_word_vec():
    # load data
    dirname = 'reviews'
    sents = SentIter(dirname, 238)
    # define model and train (CBOW, 200-dim vectors; `size` is the gensim 3.x
    # argument name, renamed to `vector_size` in gensim 4)
    model = models.Word2Vec(sentences=sents, size=200, sg=0, workers=4, min_count=5)
    model.save('yelp.word2vec')
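# ---------------------------------------------------------------------------
# Data-layout note (a sketch, not part of the original pipeline): SampleIter
# and SentIter above assume that each file under `reviews/` is a pickled list
# of (label, document) pairs, where `label` is the 1-5 star rating and
# `document` is a list of sentences, each sentence a list of word tokens.
# The helper name below and the use of nltk tokenizers are illustrative
# assumptions only; the function is defined but never called.
def build_review_file_sketch(labeled_texts, out_path):
    import nltk
    data = []
    for stars, text in labeled_texts:
        sents = [nltk.word_tokenize(s.lower()) for s in nltk.sent_tokenize(text)]
        data.append((stars, sents))
    with open(out_path, 'wb') as f:
        pickle.dump(data, f)
# Hypothetical usage:
#   build_review_file_sketch([(5, "Great food. Friendly staff.")], 'reviews/0')
# ---------------------------------------------------------------------------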
'''
Train process
'''
import math
import os
import copy
import pickle
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import json
import nltk
from gensim.models import Word2Vec
import torch
import torch.nn as nn                 # needed for nn.NLLLoss below
from torch.autograd import Variable   # pre-0.4 PyTorch wrapper used in the loop
from torch.utils.data import DataLoader, Dataset
from model import *

# Hierarchical Attention Network: 200-dim word vectors in, 5 rating classes out
net = HAN(input_size=200, output_size=5,
          word_hidden_size=50, word_num_layers=1, word_context_size=100,
          sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)

optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.NLLLoss()   # expects log-probabilities from the network

num_epoch = 1
batch_size = 64
class Embedding_layer:
    """Looks up pre-trained word vectors; unknown words map to a zero vector."""
    def __init__(self, wv, vector_size):
        self.wv = wv
        self.vector_size = vector_size

    def get_vec(self, w):
        try:
            v = self.wv[w]
        except KeyError:
            v = np.zeros(self.vector_size)
        return v


embed_model = Word2Vec.load('yelp.word2vec')
embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
del embed_model   # only the KeyedVectors are needed from here on
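# Lookup behaviour (hypothetical tokens, for illustration only): a word in the
# word2vec vocabulary returns its 200-dim vector, anything unseen returns zeros.
#   embedding.get_vec('pizza')            -> ndarray of shape (200,)
#   embedding.get_vec('zzz_unseen_token') -> np.zeros(200)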
class YelpDocSet(Dataset):
    """Dataset over `num_files` pickle files, each assumed to hold 5000 reviews."""
    def __init__(self, dirname, num_files, embedding):
        self.dirname = dirname
        self.num_files = num_files
        self._len = num_files * 5000
        self._files = os.listdir(dirname)[:num_files]
        self.embedding = embedding

    def __len__(self):
        return self._len

    def __getitem__(self, n):
        file_id = n // 5000
        sample_list = pickle.load(open(
            os.path.join(self.dirname, self._files[file_id]), 'rb'))
        y, x = sample_list[n % 5000]
        # shift the 1-5 star label to the 0-4 range expected by NLLLoss
        return x, y - 1


def collate(iterable):
    y_list = []
    x_list = []
    for x, y in iterable:
        y_list.append(y)
        x_list.append(x)
    return x_list, torch.LongTensor(y_list)
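# The custom collate_fn keeps the nested Python lists intact (documents have a
# variable number of sentences and words, so they cannot be stacked into one
# tensor); only the 0-4 labels are packed into a LongTensor for nn.NLLLoss.
# Hypothetical example:
#   collate([([['good', 'food']], 4), ([['slow'], ['rude', 'staff']], 0)])
#   -> ([[['good', 'food']], [['slow'], ['rude', 'staff']]], tensor([4, 0]))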
if __name__ == '__main__':
    dirname = 'reviews'
    dataloader = DataLoader(YelpDocSet(dirname, 238, embedding),
                            batch_size=batch_size, collate_fn=collate)
    running_loss = 0.0
    print_size = 10
    for epoch in range(num_epoch):
        for i, batch_samples in enumerate(dataloader):
            x, y = batch_samples
            # turn every document into a list of (num_words x 200) sentence matrices
            doc_list = []
            for sample in x:
                doc = []
                for sent in sample:
                    sent_vec = []
                    for word in sent:
                        vec = embedding.get_vec(word)
                        sent_vec.append(torch.Tensor(vec.reshape((1, -1))))
                    sent_vec = torch.cat(sent_vec, dim=0)
                    # print(sent_vec.size())
                    doc.append(Variable(sent_vec))
                doc_list.append(doc)
            y = Variable(y)
            predict = net(doc_list)
            loss = criterion(predict, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # loss.data[0] is the pre-0.4 PyTorch way to read a scalar loss
            # (use loss.item() on newer versions)
            running_loss += loss.data[0]
            print(loss.data[0])
            if i % print_size == print_size - 1:
                print(running_loss / print_size)
                running_loss = 0.0
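The HAN class instantiated above comes from model.py, which is not shown on this page. From the constructor arguments and the call net(doc_list), the training script assumes roughly the interface sketched below. The body is only a shape-level placeholder (mean pooling instead of the GRU-and-attention layers of a real hierarchical attention network), so every name and detail in this sketch is an assumption rather than the repository's implementation.

import torch
import torch.nn as nn

class HANInterfaceSketch(nn.Module):
    """Shape-level stand-in for model.HAN (an assumption, not the real model)."""
    def __init__(self, input_size, output_size,
                 word_hidden_size, word_num_layers, word_context_size,
                 sent_hidden_size, sent_num_layers, sent_context_size):
        super(HANInterfaceSketch, self).__init__()
        # A real HAN would build word/sentence encoders and attention here;
        # this placeholder keeps only the classifier so the shapes line up.
        self.fc = nn.Linear(input_size, output_size)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, doc_list):
        # doc_list: a list (batch) of documents; each document is a list of
        # (num_words x input_size) sentence tensors of varying length.
        doc_vecs = []
        for doc in doc_list:
            sent_vecs = [sent.mean(dim=0) for sent in doc]       # pool words
            doc_vecs.append(torch.stack(sent_vecs).mean(dim=0))  # pool sentences
        out = self.fc(torch.stack(doc_vecs))                     # (batch, output_size)
        return self.log_softmax(out)   # log-probabilities, as nn.NLLLoss expects

With a model of this shape, predict in the loop above is a (batch_size, 5) tensor of log-probabilities, which is exactly what criterion = nn.NLLLoss() consumes.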

A lightweight natural language processing (NLP) toolkit that aims to reduce the engineering code in users' projects, such as data-processing loops, training loops, and multi-GPU execution.