You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

evaluate.py 1.8 kB

7 years ago
7 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. from model import *
  2. from train import *
  3. def evaluate(net, dataset, bactch_size=64, use_cuda=False):
  4. dataloader = DataLoader(dataset, batch_size=bactch_size, collate_fn=collate, num_workers=0)
  5. count = 0
  6. if use_cuda:
  7. net.cuda()
  8. for i, batch_samples in enumerate(dataloader):
  9. x, y = batch_samples
  10. doc_list = []
  11. for sample in x:
  12. doc = []
  13. for sent_vec in sample:
  14. # print(sent_vec.size())
  15. if use_cuda:
  16. sent_vec = sent_vec.cuda()
  17. doc.append(Variable(sent_vec, volatile=True))
  18. doc_list.append(pack_sequence(doc))
  19. if use_cuda:
  20. y = y.cuda()
  21. predicts = net(doc_list)
  22. # idx = []
  23. # for p in predicts.data:
  24. # idx.append(np.random.choice(5, p=torch.exp(p).numpy()))
  25. # idx = torch.LongTensor(idx)
  26. p, idx = torch.max(predicts, dim=1)
  27. idx = idx.data
  28. count += torch.sum(torch.eq(idx, y))
  29. return count
  30. if __name__ == '__main__':
  31. '''
  32. Evaluate the performance of model
  33. '''
  34. from gensim.models import Word2Vec
  35. import gensim
  36. from gensim import models
  37. embed_model = Word2Vec.load('yelp.word2vec')
  38. embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
  39. del embed_model
  40. net = HAN(input_size=200, output_size=5,
  41. word_hidden_size=50, word_num_layers=1, word_context_size=100,
  42. sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
  43. net.load_state_dict(torch.load('model.dict'))
  44. test_dataset = YelpDocSet('reviews', 199, 4, embedding)
  45. correct = evaluate(net, test_dataset, True)
  46. print('accuracy {}'.format(correct/len(test_dataset)))

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等