
train.py 8.9 kB

import os
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

from .model import charLM
from .test import test
from .utilities import *


def preprocess():
    word_dict, char_dict = create_word_char_dict("charlm.txt", "train.txt", "tests.txt")
    num_words = len(word_dict)
    num_char = len(char_dict)
    char_dict["BOW"] = num_char + 1
    char_dict["EOW"] = num_char + 2
    char_dict["PAD"] = 0
    # dict of (int, string)
    reverse_word_dict = {value: key for key, value in word_dict.items()}
    max_word_len = max([len(word) for word in word_dict])
    objects = {
        "word_dict": word_dict,
        "char_dict": char_dict,
        "reverse_word_dict": reverse_word_dict,
        "max_word_len": max_word_len
    }
    torch.save(objects, "cache/prep.pt")
    print("Preprocess done.")


def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)


def train(net, data, opt):
    """
    :param net: the PyTorch model
    :param data: named tuple of numpy arrays
    :param opt: named tuple of training options

    Steps:
    1. set the random seed
    2. build local input tensors
    3. set up training: learning rate, loss, etc.
    4. main loop over epochs
    5. batchify
    6. validation
    7. save the model
    """
    torch.manual_seed(1024)

    train_input = torch.from_numpy(data.train_input)
    train_label = torch.from_numpy(data.train_label)
    valid_input = torch.from_numpy(data.valid_input)
    valid_label = torch.from_numpy(data.valid_label)

    # [num_seq, seq_len, max_word_len+2]
    num_seq = train_input.size()[0] // opt.lstm_seq_len
    train_input = train_input[:num_seq * opt.lstm_seq_len, :]
    train_input = train_input.view(-1, opt.lstm_seq_len, opt.max_word_len + 2)

    num_seq = valid_input.size()[0] // opt.lstm_seq_len
    valid_input = valid_input[:num_seq * opt.lstm_seq_len, :]
    valid_input = valid_input.view(-1, opt.lstm_seq_len, opt.max_word_len + 2)

    num_epoch = opt.epochs
    num_iter_per_epoch = train_input.size()[0] // opt.lstm_batch_size

    learning_rate = opt.init_lr
    old_PPL = 100000
    best_PPL = 100000

    # Log-SoftMax
    criterion = nn.CrossEntropyLoss()

    # word_embed_dim == hidden_size (number of hidden units)
    hidden = (to_var(torch.zeros(2, opt.lstm_batch_size, opt.word_embed_dim)),
              to_var(torch.zeros(2, opt.lstm_batch_size, opt.word_embed_dim)))

    for epoch in range(num_epoch):

        ################ Validation ####################
        net.eval()
        loss_batch = []
        PPL_batch = []
        iterations = valid_input.size()[0] // opt.lstm_batch_size

        valid_generator = batch_generator(valid_input, opt.lstm_batch_size)
        vlabel_generator = batch_generator(valid_label, opt.lstm_batch_size * opt.lstm_seq_len)

        for t in range(iterations):
            batch_input = valid_generator.__next__()
            batch_label = vlabel_generator.__next__()

            hidden = [state.detach() for state in hidden]
            valid_output, hidden = net(to_var(batch_input), hidden)

            length = valid_output.size()[0]

            # [num_sample-1, len(word_dict)] vs [num_sample-1]
            valid_loss = criterion(valid_output, to_var(batch_label))

            PPL = torch.exp(valid_loss.data)

            loss_batch.append(float(valid_loss))
            PPL_batch.append(float(PPL))

        PPL = np.mean(PPL_batch)
        print("[epoch {}] valid PPL={}".format(epoch, PPL))
        print("valid loss={}".format(np.mean(loss_batch)))
        print("PPL decrease={}".format(float(old_PPL - PPL)))

        # Preserve the best model
        if best_PPL > PPL:
            best_PPL = PPL
            torch.save(net.state_dict(), "cache/model.pt")
            torch.save(net, "cache/net.pkl")

        # Adjust the learning rate
        if float(old_PPL - PPL) <= 1.0:
            learning_rate /= 2
            print("halved lr: {}".format(learning_rate))

        old_PPL = PPL

        ##################################################
        #################### Training ####################
        net.train()
        optimizer = optim.SGD(net.parameters(),
                              lr=learning_rate,
                              momentum=0.85)

        # split along the first dim
        input_generator = batch_generator(train_input, opt.lstm_batch_size)
        label_generator = batch_generator(train_label, opt.lstm_batch_size * opt.lstm_seq_len)

        for t in range(num_iter_per_epoch):
            batch_input = input_generator.__next__()
            batch_label = label_generator.__next__()

            # detach hidden state of LSTM from the last batch
            hidden = [state.detach() for state in hidden]

            output, hidden = net(to_var(batch_input), hidden)
            # [num_word, vocab_size]

            loss = criterion(output, to_var(batch_label))

            net.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(net.parameters(), 5, norm_type=2)
            optimizer.step()

            if (t + 1) % 100 == 0:
                print("[epoch {} step {}] train loss={}, Perplexity={}".format(
                    epoch + 1, t + 1, float(loss.data), float(np.exp(loss.data))))

    torch.save(net.state_dict(), "cache/model.pt")
    print("Training finished.")


################################################################
if __name__ == "__main__":

    word_embed_dim = 300
    char_embedding_dim = 15

    if os.path.exists("cache/prep.pt") is False:
        preprocess()

    objects = torch.load("cache/prep.pt")
    word_dict = objects["word_dict"]
    char_dict = objects["char_dict"]
    reverse_word_dict = objects["reverse_word_dict"]
    max_word_len = objects["max_word_len"]
    num_words = len(word_dict)

    print("word/char dictionary built. Start making inputs.")

    if os.path.exists("cache/data_sets.pt") is False:
        train_text = read_data("./train.txt")
        valid_text = read_data("./charlm.txt")
        test_text = read_data("./tests.txt")

        train_set = np.array(text2vec(train_text, char_dict, max_word_len))
        valid_set = np.array(text2vec(valid_text, char_dict, max_word_len))
        test_set = np.array(text2vec(test_text, char_dict, max_word_len))

        # Labels are next-word indices in word_dict, with the same length as the inputs
        train_label = np.array([word_dict[w] for w in train_text[1:]] + [word_dict[train_text[-1]]])
        valid_label = np.array([word_dict[w] for w in valid_text[1:]] + [word_dict[valid_text[-1]]])
        test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]])

        category = {"tdata": train_set, "vdata": valid_set, "tests": test_set,
                    "trlabel": train_label, "vlabel": valid_label, "tlabel": test_label}
        torch.save(category, "cache/data_sets.pt")
    else:
        data_sets = torch.load("cache/data_sets.pt")
        train_set = data_sets["tdata"]
        valid_set = data_sets["vdata"]
        test_set = data_sets["tests"]
        train_label = data_sets["trlabel"]
        valid_label = data_sets["vlabel"]
        test_label = data_sets["tlabel"]

    DataTuple = namedtuple("DataTuple",
                           "train_input train_label valid_input valid_label test_input test_label")
    data = DataTuple(train_input=train_set,
                     train_label=train_label,
                     valid_input=valid_set,
                     valid_label=valid_label,
                     test_input=test_set,
                     test_label=test_label)

    print("Loaded data sets. Start building network.")

    USE_GPU = True
    cnn_batch_size = 700
    lstm_seq_len = 35
    lstm_batch_size = 20
    # cnn_batch_size == lstm_seq_len * lstm_batch_size

    net = charLM(char_embedding_dim,
                 word_embed_dim,
                 num_words,
                 len(char_dict),
                 use_gpu=USE_GPU)

    for param in net.parameters():
        nn.init.uniform(param.data, -0.05, 0.05)

    Options = namedtuple("Options", [
        "cnn_batch_size", "init_lr", "lstm_seq_len",
        "max_word_len", "lstm_batch_size", "epochs",
        "word_embed_dim"])
    opt = Options(cnn_batch_size=lstm_seq_len * lstm_batch_size,
                  init_lr=1.0,
                  lstm_seq_len=lstm_seq_len,
                  max_word_len=max_word_len,
                  lstm_batch_size=lstm_batch_size,
                  epochs=35,
                  word_embed_dim=word_embed_dim)

    print("Network built. Start training.")

    # Training can be stopped at any time with Ctrl+C
    try:
        train(net, data, opt)
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

    torch.save(net, "cache/net.pkl")
    print("saved net")

    test(net, data, opt)
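The helpers used above (create_word_char_dict, read_data, text2vec, batch_generator) are imported from .utilities and are not shown in this file. As a rough illustration only, batch_generator appears to yield consecutive slices of its input along the first dimension; the following is a minimal sketch consistent with how it is called above, not the actual .utilities implementation:

    # Hypothetical sketch of batch_generator (assumption, not the real .utilities code):
    # yields consecutive chunks of `tensor` along dim 0, batch_size rows at a time.
    def batch_generator(tensor, batch_size):
        num_batches = tensor.size()[0] // batch_size
        for i in range(num_batches):
            yield tensor[i * batch_size:(i + 1) * batch_size]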

A lightweight natural language processing (NLP) toolkit that aims to reduce the amount of engineering code in user projects, such as data-processing loops, training loops, and multi-GPU execution.