You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

evaluate.py 2.2 kB

7 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. from model import *
  2. from train import *
  3. def evaluate(net, dataset, bactch_size=64, use_cuda=False):
  4. dataloader = DataLoader(dataset, batch_size=bactch_size, collate_fn=collate, num_workers=0)
  5. count = 0
  6. if use_cuda:
  7. net.cuda()
  8. for i, batch_samples in enumerate(dataloader):
  9. x, y = batch_samples
  10. doc_list = []
  11. for sample in x:
  12. doc = []
  13. for sent_vec in sample:
  14. # print(sent_vec.size())
  15. if use_cuda:
  16. sent_vec = sent_vec.cuda()
  17. doc.append(Variable(sent_vec, volatile=True))
  18. doc_list.append(pack_sequence(doc))
  19. if use_cuda:
  20. y = y.cuda()
  21. predicts = net(doc_list)
  22. # idx = []
  23. # for p in predicts.data:
  24. # idx.append(np.random.choice(5, p=torch.exp(p).numpy()))
  25. # idx = torch.LongTensor(idx)
  26. p, idx = torch.max(predicts, dim=1)
  27. idx = idx.data
  28. count += torch.sum(torch.eq(idx, y))
  29. return count
  30. def visualize_attention(doc, alpha_vec):
  31. pass
  32. if __name__ == '__main__':
  33. from gensim.models import Word2Vec
  34. import gensim
  35. from gensim import models
  36. embed_model = Word2Vec.load('yelp.word2vec')
  37. embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
  38. del embed_model
  39. net = HAN(input_size=200, output_size=5,
  40. word_hidden_size=50, word_num_layers=1, word_context_size=100,
  41. sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
  42. net.load_state_dict(torch.load('model.dict'))
  43. test_dataset = YelpDocSet('reviews', 199, 4, embedding)
  44. correct = evaluate(net, test_dataset, True)
  45. print('accuracy {}'.format(correct/len(test_dataset)))
  46. # data_idx = 121
  47. # x, y = test_dataset[data_idx]
  48. # doc = []
  49. # for sent_vec in x:
  50. # doc.append(Variable(sent_vec, volatile=True))
  51. # input_vec = [pack_sequence(doc)]
  52. # predict = net(input_vec)
  53. # p, idx = torch.max(predict, dim=1)
  54. # print(net.word_layer.last_alpha.squeeze())
  55. # print(net.sent_layer.last_alpha)
  56. # print(test_dataset.get_doc(data_idx)[0])
  57. # print('predict: {}, true: {}'.format(int(idx), y))

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等