You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

trainer.py 10 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import copy
  4. import logging
  5. import os
  6. import argparse
  7. import logging
  8. import sys
  9. sys.path.append('..'+ '/' + '..')
  10. from collections import OrderedDict
  11. import torch
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. from torchvision import datasets, transforms
  15. from model import Net
  16. from pytorch.trainer import Trainer
  17. from pytorch.utils import AverageMeterGroup
  18. from pytorch.utils import mkdirs
  19. from pytorch.mutables import LayerChoice, InputChoice
  20. from mutator import ClassicMutator
  21. import numpy as np
  22. import time
  23. import json
  24. logger = logging.getLogger(__name__)
  25. #logger.setLevel(logging.INFO)
  26. class ClassicnasTrainer(Trainer):
  27. """
  28. Classicnas trainer.
  29. Parameters
  30. ----------
  31. model : nn.Module
  32. PyTorch model to be trained.
  33. loss : callable
  34. Receives logits and ground truth label, return a loss tensor.
  35. metrics : callable
  36. Receives logits and ground truth label, return a dict of metrics.
  37. optimizer : Optimizer
  38. The optimizer used for optimizing the model.
  39. num_epochs : int
  40. Number of epochs planned for training.
  41. dataset_train : Dataset
  42. Dataset for training. Will be split for training weights and architecture weights.
  43. dataset_valid : Dataset
  44. Dataset for testing.
  45. mutator : ClassicMutator
  46. Use in case of customizing your own ClassicMutator. By default will instantiate a ClassicMutator.
  47. batch_size : int
  48. Batch size.
  49. workers : int
  50. Workers for data loading.
  51. device : torch.device
  52. ``torch.device("cpu")`` or ``torch.device("cuda")``.
  53. log_frequency : int
  54. Step count per logging.
  55. callbacks : list of Callback
  56. list of callbacks to trigger at events.
  57. arc_learning_rate : float
  58. Learning rate of architecture parameters.
  59. unrolled : float
  60. ``True`` if using second order optimization, else first order optimization.
  61. """
  62. def __init__(self, model, loss, metrics,
  63. optimizer, epochs, dataset_train, dataset_valid, search_space_path,selected_space_path,trial_id,
  64. mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
  65. callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
  66. self.model = model
  67. self.loss = loss
  68. self.metrics = metrics
  69. self.optimizer = optimizer
  70. self.epochs = epochs
  71. self.device = device
  72. self.batch_size = batch_size
  73. self.train_loader = torch.utils.data.DataLoader(
  74. datasets.MNIST(dataset_train, train=True, download=False,
  75. transform=transforms.Compose([
  76. transforms.ToTensor(),
  77. transforms.Normalize((0.1307,), (0.3081,))
  78. ])),
  79. batch_size=batch_size, shuffle=True, **kwargs)
  80. self.test_loader = torch.utils.data.DataLoader(
  81. datasets.MNIST(dataset_valid, train=False, transform=transforms.Compose([
  82. transforms.ToTensor(),
  83. transforms.Normalize((0.1307,), (0.3081,))
  84. ])),
  85. batch_size=1000, shuffle=True, **kwargs)
  86. self.search_space_path = search_space_path
  87. self.selected_space_path =selected_space_path
  88. self.trial_id = trial_id
  89. self.num_epochs = 10
  90. self.classicmutator=ClassicMutator(self.model,trial_id=self.trial_id,selected_path=self.selected_space_path,search_space_path=self.search_space_path)
  91. self.result = {"accuracy": [],"cost_time": 0.}
  92. def train_one_epoch(self, epoch):
  93. # t1 = time()
  94. # phase 1. architecture step
  95. self.classicmutator.trial_id = epoch
  96. self.classicmutator._chosen_arch=self.classicmutator.random_generate_chosen()
  97. #print('epoch:',epoch,'\n',self.classicmutator._chosen_arch)
  98. # phase 2: child network step
  99. for child_epoch in range(1, self.epochs + 1):
  100. self.model.train()
  101. for batch_idx, (data, target) in enumerate(self.train_loader):
  102. data, target = data.to(self.device), target.to(self.device)
  103. optimizer.zero_grad()
  104. output = self.model(data)
  105. loss = F.nll_loss(output, target)
  106. loss.backward()
  107. optimizer.step()
  108. if batch_idx % args['log_interval'] == 0:
  109. logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
  110. child_epoch, batch_idx * len(data), len(self.train_loader.dataset),
  111. 100. * batch_idx / len(self.train_loader), loss.item()))
  112. test_acc = self.validate_one_epoch(epoch)
  113. print({"type":"accuracy","result":{"sequence":child_epoch,"category":"epoch","value":test_acc}} )
  114. with open(args['result_path'], "a") as ss_file:
  115. ss_file.write(json.dumps({"type":"accuracy","result":{"sequence":child_epoch,"category":"epoch","value":test_acc}} ) + '\n')
  116. self.result['accuracy'].append(test_acc)
  117. def validate_one_epoch(self, epoch):
  118. self.model.eval()
  119. test_loss = 0
  120. correct = 0
  121. with torch.no_grad():
  122. for data, target in self.test_loader:
  123. data, target = data.to(self.device), target.to(self.device)
  124. output = self.model(data)
  125. # sum up batch loss
  126. test_loss += F.nll_loss(output, target, reduction='sum').item()
  127. # get the index of the max log-probability
  128. pred = output.argmax(dim=1, keepdim=True)
  129. correct += pred.eq(target.view_as(pred)).sum().item()
  130. test_loss /= len(self.test_loader.dataset)
  131. accuracy = 100. * correct / len(self.test_loader.dataset)
  132. logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
  133. test_loss, correct, len(self.test_loader.dataset), accuracy))
  134. return accuracy
  135. def train(self):
  136. """
  137. Train ``num_epochs``.
  138. Trigger callbacks at the start and the end of each epoch.
  139. Parameters
  140. ----------
  141. validate : bool
  142. If ``true``, will do validation every epoch.
  143. """
  144. for epoch in range(self.num_epochs):
  145. # training
  146. self.train_one_epoch(epoch)
  147. def dump_global_result(args,global_result):
  148. with open(args['result_path'], "w") as ss_file:
  149. json.dump(global_result, ss_file, sort_keys=True, indent=2)
  150. def get_params():
  151. # Training settings
  152. parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  153. parser.add_argument("--data_dir", type=str,
  154. default='./data', help="data directory")
  155. parser.add_argument("--model_selected_space_path", type=str,
  156. default='./selected_space.json', help="selected_space_path")
  157. parser.add_argument("--search_space_path", type=str,
  158. default='./selected_space.json', help="search_space_path")
  159. parser.add_argument("--result_path", type=str,
  160. default='./model_result.json', help="result_path")
  161. parser.add_argument('--batch_size', type=int, default=64, metavar='N',
  162. help='input batch size for training (default: 64)')
  163. parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
  164. help='hidden layer size (default: 512)')
  165. parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
  166. help='learning rate (default: 0.01)')
  167. parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
  168. help='SGD momentum (default: 0.5)')
  169. parser.add_argument('--epochs', type=int, default=10, metavar='N',
  170. help='number of epochs to train (default: 10)')
  171. parser.add_argument('--seed', type=int, default=1, metavar='S',
  172. help='random seed (default: 1)')
  173. parser.add_argument('--no_cuda', default=False,
  174. help='disables CUDA training')
  175. parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
  176. help='how many batches to wait before logging training status')
  177. parser.add_argument('--trial_id', type=int, default=0, metavar='N',
  178. help='trial_id,start from 0')
  179. args, _ = parser.parse_known_args()
  180. return args
if __name__ == '__main__':
    try:
        start = time.time()
        # Flatten the parsed Namespace into a plain dict. NOTE: the trainer's
        # methods read the module-level ``args`` / ``kwargs`` / ``optimizer``
        # globals defined below, so these names must stay bound here.
        params = vars(get_params())
        args = params
        use_cuda = not args['no_cuda'] and torch.cuda.is_available()
        torch.manual_seed(args['seed'])
        device = torch.device("cuda" if use_cuda else "cpu")
        # DataLoader keyword arguments; consumed inside
        # ClassicnasTrainer.__init__ via the global name ``kwargs``.
        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
        data_dir = args['data_dir']
        hidden_size = args['hidden_size']
        model = Net(hidden_size=hidden_size).to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'],
                                    momentum=args['momentum'])
        # mkdirs is a project helper (pytorch.utils) — presumably creates the
        # directory component of each output path; TODO confirm its semantics.
        mkdirs(args['search_space_path'])
        mkdirs(args['model_selected_space_path'])
        mkdirs(args['result_path'])
        trainer = ClassicnasTrainer(model,
                                    loss=None,
                                    metrics=None,
                                    optimizer=optimizer,
                                    epochs=args['epochs'],
                                    dataset_train=data_dir,
                                    dataset_valid=data_dir,
                                    search_space_path=args['search_space_path'],
                                    selected_space_path=args['model_selected_space_path'],
                                    trial_id=args['trial_id'],
                                    batch_size=args['batch_size'],
                                    log_frequency=args['log_interval'],
                                    device=device,
                                    unrolled=None,
                                    callbacks=None)
        # Truncate the result file so train_one_epoch appends to a clean log.
        with open(args['result_path'], "w") as ss_file:
            ss_file.write('')
        # Run a single architecture-sampling round keyed by the trial id.
        trainer.train_one_epoch(args['trial_id'])
        #trainer.train()
        global_result = trainer.result
        #global_result['cost_time'] = str(time.time() - start) +'s'
        #dump_global_result(params,global_result)
    except Exception as exception:
        # Log the full traceback, then re-raise so the trial fails visibly.
        logger.exception(exception)
        raise

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能