You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 2.7 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import os
  2. import torch
  3. import torch.nn as nn
  4. import torchvision.datasets as dsets
  5. import torchvision.transforms as transforms
  6. import dataset as dst
  7. from model import CNN_text
  8. from torch.autograd import Variable
  9. from sklearn import cross_validation
  10. from sklearn import datasets
  11. # Hyper Parameters
  12. batch_size = 50
  13. learning_rate = 0.0001
  14. num_epochs = 20
  15. cuda = True
  16. #split Dataset
  17. dataset = dst.MRDataset()
  18. length = len(dataset)
  19. train_dataset = dataset[:int(0.9*length)]
  20. test_dataset = dataset[int(0.9*length):]
  21. train_dataset = dst.train_set(train_dataset)
  22. test_dataset = dst.test_set(test_dataset)
  23. # Data Loader
  24. train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
  25. batch_size=batch_size,
  26. shuffle=True)
  27. test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
  28. batch_size=batch_size,
  29. shuffle=False)
  30. #cnn
  31. cnn = CNN_text(embed_num=len(dataset.word2id()), pretrained_embeddings=dataset.word_embeddings())
  32. if cuda:
  33. cnn.cuda()
  34. # Loss and Optimizer
  35. criterion = nn.CrossEntropyLoss()
  36. optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
  37. #train and test
  38. best_acc = None
  39. for epoch in range(num_epochs):
  40. # Train the Model
  41. cnn.train()
  42. for i, (sents,labels) in enumerate(train_loader):
  43. sents = Variable(sents)
  44. labels = Variable(labels)
  45. if cuda:
  46. sents = sents.cuda()
  47. labels = labels.cuda()
  48. optimizer.zero_grad()
  49. outputs = cnn(sents)
  50. loss = criterion(outputs, labels)
  51. loss.backward()
  52. optimizer.step()
  53. if (i+1) % 100 == 0:
  54. print ('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
  55. %(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.data[0]))
  56. # Test the Model
  57. cnn.eval()
  58. correct = 0
  59. total = 0
  60. for sents, labels in test_loader:
  61. sents = Variable(sents)
  62. if cuda:
  63. sents = sents.cuda()
  64. labels = labels.cuda()
  65. outputs = cnn(sents)
  66. _, predicted = torch.max(outputs.data, 1)
  67. total += labels.size(0)
  68. correct += (predicted == labels).sum()
  69. acc = 100. * correct / total
  70. print('Test Accuracy: %f %%' % (acc))
  71. if best_acc is None or acc > best_acc:
  72. best_acc = acc
  73. if os.path.exists("models") is False:
  74. os.makedirs("models")
  75. torch.save(cnn.state_dict(), 'models/cnn.pkl')
  76. else:
  77. learning_rate = learning_rate * 0.8
  78. print("Best Accuracy: %f %%" % best_acc)
  79. print("Best Model: models/cnn.pkl")

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等