You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 2.6 kB

7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import os
  2. import torch
  3. import torch.nn as nn
  4. import torchvision.datasets as dsets
  5. import torchvision.transforms as transforms
  6. import dataset as dst
  7. from model import CNN_text
  8. from torch.autograd import Variable
  9. from sklearn import cross_validation
  10. from sklearn import datasets
  11. # Hyper Parameters
  12. batch_size = 50
  13. learning_rate = 0.0001
  14. num_epochs = 20
  15. cuda = True
  16. #split Dataset
  17. dataset = dst.MRDataset()
  18. length = len(dataset)
  19. train_dataset = dataset[:int(0.9*length)]
  20. test_dataset = dataset[int(0.9*length):]
  21. train_dataset = dst.train_set(train_dataset)
  22. test_dataset = dst.test_set(test_dataset)
  23. # Data Loader
  24. train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
  25. batch_size=batch_size,
  26. shuffle=True)
  27. test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
  28. batch_size=batch_size,
  29. shuffle=False)
  30. cnn = CNN_text(embed_num=len(dataset.word2id()), pretrained_embeddings=dataset.word_embeddings())
  31. if cuda:
  32. cnn.cuda()
  33. # Loss and Optimizer
  34. criterion = nn.CrossEntropyLoss()
  35. optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
  36. best_acc = None
  37. for epoch in range(num_epochs):
  38. # Train the Model
  39. cnn.train()
  40. for i, (sents,labels) in enumerate(train_loader):
  41. sents = Variable(sents)
  42. labels = Variable(labels)
  43. if cuda:
  44. sents = sents.cuda()
  45. labels = labels.cuda()
  46. optimizer.zero_grad()
  47. outputs = cnn(sents)
  48. loss = criterion(outputs, labels)
  49. loss.backward()
  50. optimizer.step()
  51. if (i+1) % 100 == 0:
  52. print ('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
  53. %(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.data[0]))
  54. # Test the Model
  55. cnn.eval()
  56. correct = 0
  57. total = 0
  58. for sents, labels in test_loader:
  59. sents = Variable(sents)
  60. if cuda:
  61. sents = sents.cuda()
  62. labels = labels.cuda()
  63. outputs = cnn(sents)
  64. _, predicted = torch.max(outputs.data, 1)
  65. total += labels.size(0)
  66. correct += (predicted == labels).sum()
  67. acc = 100. * correct / total
  68. print('Test Accuracy: %f %%' % (acc))
  69. if best_acc is None or acc > best_acc:
  70. best_acc = acc
  71. if os.path.exists("models") is False:
  72. os.makedirs("models")
  73. torch.save(cnn.state_dict(), 'models/cnn.pkl')
  74. else:
  75. learning_rate = learning_rate * 0.8
  76. print("Best Accuracy: %f %%" % best_acc)
  77. print("Best Model: models/cnn.pkl")

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。