You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test.py 3.8 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import os
  2. from collections import namedtuple
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. from torch.autograd import Variable
  7. from utilities import *
  8. def to_var(x):
  9. if torch.cuda.is_available():
  10. x = x.cuda()
  11. return Variable(x)
  12. def test(net, data, opt):
  13. net.eval()
  14. test_input = torch.from_numpy(data.test_input)
  15. test_label = torch.from_numpy(data.test_label)
  16. num_seq = test_input.size()[0] // opt.lstm_seq_len
  17. test_input = test_input[:num_seq*opt.lstm_seq_len, :]
  18. # [num_seq, seq_len, max_word_len+2]
  19. test_input = test_input.view(-1, opt.lstm_seq_len, opt.max_word_len+2)
  20. criterion = nn.CrossEntropyLoss()
  21. loss_list = []
  22. num_hits = 0
  23. total = 0
  24. iterations = test_input.size()[0] // opt.lstm_batch_size
  25. test_generator = batch_generator(test_input, opt.lstm_batch_size)
  26. label_generator = batch_generator(test_label, opt.lstm_batch_size*opt.lstm_seq_len)
  27. hidden = (to_var(torch.zeros(2, opt.lstm_batch_size, opt.word_embed_dim)),
  28. to_var(torch.zeros(2, opt.lstm_batch_size, opt.word_embed_dim)))
  29. add_loss = 0.0
  30. for t in range(iterations):
  31. batch_input = test_generator.__next__ ()
  32. batch_label = label_generator.__next__()
  33. net.zero_grad()
  34. hidden = [state.detach() for state in hidden]
  35. test_output, hidden = net(to_var(batch_input), hidden)
  36. test_loss = criterion(test_output, to_var(batch_label)).data
  37. loss_list.append(test_loss)
  38. add_loss += test_loss
  39. print("Test Loss={0:.4f}".format(float(add_loss) / iterations))
  40. print("Test PPL={0:.4f}".format(float(np.exp(add_loss / iterations))))
  41. #############################################################
  42. if __name__ == "__main__":
  43. word_embed_dim = 300
  44. char_embedding_dim = 15
  45. if os.path.exists("cache/prep.pt") is False:
  46. print("Cannot find prep.pt")
  47. objetcs = torch.load("cache/prep.pt")
  48. word_dict = objetcs["word_dict"]
  49. char_dict = objetcs["char_dict"]
  50. reverse_word_dict = objetcs["reverse_word_dict"]
  51. max_word_len = objetcs["max_word_len"]
  52. num_words = len(word_dict)
  53. print("word/char dictionary built. Start making inputs.")
  54. if os.path.exists("cache/data_sets.pt") is False:
  55. test_text = read_data("./tests.txt")
  56. test_set = np.array(text2vec(test_text, char_dict, max_word_len))
  57. # Labels are next-word index in word_dict with the same length as inputs
  58. test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]])
  59. category = {"tests": test_set, "tlabel": test_label}
  60. torch.save(category, "cache/data_sets.pt")
  61. else:
  62. data_sets = torch.load("cache/data_sets.pt")
  63. test_set = data_sets["tests"]
  64. test_label = data_sets["tlabel"]
  65. train_set = data_sets["tdata"]
  66. train_label = data_sets["trlabel"]
  67. DataTuple = namedtuple("DataTuple", "test_input test_label train_input train_label ")
  68. data = DataTuple( test_input=test_set,
  69. test_label=test_label, train_label=train_label, train_input=train_set)
  70. print("Loaded data sets. Start building network.")
  71. USE_GPU = True
  72. cnn_batch_size = 700
  73. lstm_seq_len = 35
  74. lstm_batch_size = 20
  75. net = torch.load("cache/net.pkl")
  76. Options = namedtuple("Options", [ "cnn_batch_size", "lstm_seq_len",
  77. "max_word_len", "lstm_batch_size", "word_embed_dim"])
  78. opt = Options(cnn_batch_size=lstm_seq_len*lstm_batch_size,
  79. lstm_seq_len=lstm_seq_len,
  80. max_word_len=max_word_len,
  81. lstm_batch_size=lstm_batch_size,
  82. word_embed_dim=word_embed_dim)
  83. print("Network built. Start testing.")
  84. test(net, data, opt)

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等