
train.py

import os

import torch
import torch.nn as nn
import torch.utils.data  # needed for DataLoader below
from torch.autograd import Variable

from . import dataset as dst
from .model import CNN_text

# NOTE: Variable and loss.data[0] follow the pre-0.4 PyTorch API used throughout this script.

# Hyper Parameters
batch_size = 50
learning_rate = 0.0001
num_epochs = 20
cuda = True

# Split the dataset into a 90/10 train/test split
dataset = dst.MRDataset()
length = len(dataset)
train_dataset = dataset[:int(0.9 * length)]
test_dataset = dataset[int(0.9 * length):]
train_dataset = dst.train_set(train_dataset)
test_dataset = dst.test_set(test_dataset)

# Data Loaders
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# CNN text classifier initialised with the dataset's vocabulary and pretrained embeddings
cnn = CNN_text(embed_num=len(dataset.word2id()),
               pretrained_embeddings=dataset.word_embeddings())
if cuda:
    cnn.cuda()

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)

# Train and test
best_acc = None
for epoch in range(num_epochs):
    # Train the model
    cnn.train()
    for i, (sents, labels) in enumerate(train_loader):
        sents = Variable(sents)
        labels = Variable(labels)
        if cuda:
            sents = sents.cuda()
            labels = labels.cuda()
        optimizer.zero_grad()
        outputs = cnn(sents)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if (i + 1) % 100 == 0:
            print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, len(train_dataset) // batch_size, loss.data[0]))

    # Test the model
    cnn.eval()
    correct = 0
    total = 0
    for sents, labels in test_loader:
        sents = Variable(sents)
        if cuda:
            sents = sents.cuda()
            labels = labels.cuda()
        outputs = cnn(sents)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
    acc = 100. * correct / total
    print('Test Accuracy: %f %%' % acc)

    # Keep the best checkpoint; otherwise decay the learning rate
    if best_acc is None or acc > best_acc:
        best_acc = acc
        if not os.path.exists("models"):
            os.makedirs("models")
        torch.save(cnn.state_dict(), 'models/cnn.pkl')
    else:
        learning_rate = learning_rate * 0.8
        # Push the decayed rate into the optimizer (reassigning the variable alone has no effect)
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

print("Best Accuracy: %f %%" % best_acc)
print("Best Model: models/cnn.pkl")

A lightweight Natural Language Processing (NLP) toolkit that aims to reduce the engineering boilerplate in user projects, such as data-processing loops, training loops, and multi-GPU execution.