
mnist.py
  1. """
  2. A deep MNIST classifier using convolutional layers.
  3. This file is a modification of the official pytorch mnist example:
  4. https://github.com/pytorch/examples/blob/master/mnist/main.py
  5. """
  6. import os
  7. import argparse
  8. import logging
  9. import sys
  10. sys.path.append('..'+ '/' + '..')
  11. from collections import OrderedDict
  12. import torch
  13. import torch.nn as nn
  14. import torch.nn.functional as F
  15. import torch.optim as optim
  16. from torchvision import datasets, transforms
  17. from pytorch.mutables import LayerChoice, InputChoice
  18. from mutator import ClassicMutator
  19. import numpy as np
  20. import time
  21. import json
  22. logger = logging.getLogger('mnist_AutoML')


class Net(nn.Module):
    def __init__(self, hidden_size):
        super(Net, self).__init__()
        # two candidate kernel sizes for the first convolution
        self.conv1 = LayerChoice(OrderedDict([
            ("conv5x5", nn.Conv2d(1, 20, 5, 1)),
            ("conv3x3", nn.Conv2d(1, 20, 3, 1))
        ]), key='first_conv')
        # two candidate kernel sizes for the middle convolution
        self.mid_conv = LayerChoice([
            nn.Conv2d(20, 20, 3, 1, padding=1),
            nn.Conv2d(20, 20, 5, 1, padding=2)
        ], key='mid_conv')
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 10)
        # optional skip connection over mid_conv
        self.input_switch = InputChoice(n_candidates=2,
                                        n_chosen=1,
                                        key='skip')

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        old_x = x
        x = F.relu(self.mid_conv(x))
        # pick either a zero tensor (no skip) or the pre-mid_conv
        # activation (skip connection) to add back in
        zero_x = torch.zeros_like(old_x)
        skip_x = self.input_switch([zero_x, old_x])
        x = torch.add(x, skip_x)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args['log_interval'] == 0:
            logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), accuracy))
    return accuracy


def main(args):
    global_result = {'accuracy': []}
    use_cuda = not args['no_cuda'] and torch.cuda.is_available()
    torch.manual_seed(args['seed'])
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    data_dir = args['data_dir']
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args['batch_size'], shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True, **kwargs)
    hidden_size = args['hidden_size']
    model = Net(hidden_size=hidden_size).to(device)
    # fix the architecture for this trial: the mutator resolves each
    # LayerChoice/InputChoice to one concrete candidate, reading the
    # search space and recording its selection at the given paths
    ClassicMutator(model, trial_id=args['trial_id'],
                   selected_path=args["selected_space_path"],
                   search_space_path=args["search_space_path"])
    # debug smoke test of the sampled architecture (left disabled):
    # np.random.seed(42)
    # x = torch.from_numpy(np.random.rand(2, 1, 28, 28).astype(np.float32)).to(device)
    # print(model(x))
    optimizer = optim.SGD(model.parameters(), lr=args['lr'],
                          momentum=args['momentum'])
    for epoch in range(1, args['epochs'] + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test_acc = test(args, model, device, test_loader)
        print({"type": "accuracy",
               "result": {"sequence": epoch, "category": "epoch", "value": test_acc}})
        global_result['accuracy'].append(test_acc)
    return global_result


def dump_global_result(args, global_result):
    with open(args['result_path'], "w") as ss_file:
        json.dump(global_result, ss_file, sort_keys=True, indent=2)


def get_params():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument("--data_dir", type=str,
                        default='./data', help="data directory")
    parser.add_argument("--selected_space_path", type=str,
                        default='./selected_space.json', help="selected_space_path")
    parser.add_argument("--search_space_path", type=str,
                        default='./search_space.json', help="search_space_path")
    parser.add_argument("--result_path", type=str,
                        default='./result.json', help="result_path")
    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
                        help='hidden layer size (default: 512)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial id, starting from 0')
    args, _ = parser.parse_known_args()
    return args


if __name__ == '__main__':
    try:
        start = time.time()
        params = vars(get_params())
        global_result = main(params)
        global_result['cost_time'] = str(time.time() - start) + 's'
        dump_global_result(params, global_result)
    except Exception as exception:
        logger.exception(exception)
        raise
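
A minimal single-trial invocation, for reference. This is a sketch: the flag names come from get_params() above, and the search-space/selected-space JSON files are assumed to be produced and consumed by the surrounding AutoML tooling (via ClassicMutator) rather than written by hand:

    python mnist.py --data_dir ./data --trial_id 0 --epochs 2 \
        --search_space_path ./search_space.json \
        --selected_space_path ./selected_space.json \
        --result_path ./result.json

On completion, dump_global_result writes result.json with the per-epoch test accuracies and the wall-clock cost, shaped like the following (values illustrative only):

    {
      "accuracy": [97.1, 98.3],
      "cost_time": "95.2s"
    }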
