retrainer.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import logging
import os
import argparse
import sys
sys.path.append(os.path.join('..', '..'))
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from model import Net
from pytorch.trainer import Trainer
from pytorch.utils import AverageMeterGroup
from pytorch.utils import mkdirs
from pytorch.mutables import LayerChoice, InputChoice
from fixed import apply_fixed_architecture
from mutator import ClassicMutator
from abc import ABC, abstractmethod
from pytorch.retrainer import Retrainer
import numpy as np
import time
import json

logger = logging.getLogger(__name__)
# logger.setLevel(logging.INFO)
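
# This script retrains a fixed architecture on MNIST: it loads the choices
# selected by a classic NAS search phase from a JSON file, applies them with
# ``apply_fixed_architecture``, trains the resulting child network, and
# appends per-epoch accuracy to a result file.
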
class ClassicnasRetrainer(Retrainer):
    """
    Classic NAS retrainer: trains a child network whose architecture has been
    fixed by a prior search phase.

    ``loss``, ``metrics``, ``mutator``, ``callbacks``, ``arc_learning_rate``
    and ``unrolled`` are accepted for interface compatibility and are unused
    by this retrainer.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground truth label, returns a loss tensor.
    metrics : callable
        Receives logits and ground truth label, returns a dict of metrics.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    epochs : int
        Number of epochs planned for training.
    dataset_train : str
        Root directory of the MNIST training set.
    dataset_valid : str
        Root directory of the MNIST test set.
    search_space_path : str
        Path to the search space file.
    selected_space_path : str
        Path to the architecture selected by the search phase.
    checkpoint_dir : str
        Directory to save checkpoints to.
    trial_id : int
        Trial ID, starting from 0.
    result_path : str
        File to append per-epoch results to.
    mutator : ClassicMutator
        Use in case of customizing your own ClassicMutator. By default a
        ClassicMutator is instantiated.
    batch_size : int
        Batch size.
    workers : int
        Number of workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Number of steps between logging.
    callbacks : list of Callback
        List of callbacks to trigger at events.
    arc_learning_rate : float
        Learning rate of architecture parameters.
    unrolled : bool
        ``True`` to use second-order optimization, else first-order.
    """

    def __init__(self, model, loss, metrics,
                 optimizer, epochs, dataset_train, dataset_valid, search_space_path,
                 selected_space_path, checkpoint_dir, trial_id, result_path,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        self.epochs = epochs
        self.device = device
        self.batch_size = batch_size
        self.checkpoint_dir = checkpoint_dir
        self.log_frequency = log_frequency
        # build loader kwargs locally instead of relying on a module-level ``kwargs``
        loader_kwargs = ({'num_workers': workers, 'pin_memory': True}
                         if device is not None and device.type == 'cuda' else {})
        self.train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(dataset_train, train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=batch_size, shuffle=True, **loader_kwargs)
        self.test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(dataset_valid, train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=1000, shuffle=True, **loader_kwargs)
        self.search_space_path = search_space_path
        self.selected_space_path = selected_space_path
        self.result_path = result_path
        self.trial_id = trial_id
        self.result = {"accuracy": [], "cost_time": 0.}

    def train(self):
        # phase 1: fix the architecture selected by the search phase
        apply_fixed_architecture(self.model, self.selected_space_path)
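        # The selected-space file contains the choice made for each mutable,
        # keyed by its name. A hypothetical example (actual keys depend on the
        # LayerChoice/InputChoice names defined in ``Net``):
        #   {"first_conv": "conv5x5", "hidden": "linear_512"}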
        # phase 2: train the child network with the fixed architecture
        for child_epoch in range(1, self.epochs + 1):
            self.model.train()
            for batch_idx, (data, target) in enumerate(self.train_loader):
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                self.optimizer.step()
                if self.log_frequency and batch_idx % self.log_frequency == 0:
                    logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        child_epoch, batch_idx * len(data), len(self.train_loader.dataset),
                        100. * batch_idx / len(self.train_loader), loss.item()))
            test_acc = self.validate()
            epoch_result = {"type": "accuracy",
                            "result": {"sequence": child_epoch, "category": "epoch", "value": test_acc}}
            print(epoch_result)
            # append one JSON line per epoch to the result file
            with open(self.result_path, "a") as ss_file:
                ss_file.write(json.dumps(epoch_result) + '\n')
            self.result['accuracy'].append(test_acc)

    def validate(self):
        self.model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                # sum up batch loss
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                # get the index of the max log-probability
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(self.test_loader.dataset)
        accuracy = 100. * correct / len(self.test_loader.dataset)
        logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(self.test_loader.dataset), accuracy))
        return accuracy

    # def export(self, file):
    #     """
    #     Override the method to export to file.
    #
    #     Parameters
    #     ----------
    #     file : str
    #         File path to export to.
    #     """
    #     raise NotImplementedError

    def checkpoint(self):
        """
        Override to dump a checkpoint.
        """
        if isinstance(self.model, nn.DataParallel):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        dest_path = os.path.join(self.checkpoint_dir, f"best_checkpoint_epoch{self.epochs}.pth")
        logger.info("Saving model to %s", dest_path)
        torch.save(state_dict, dest_path)
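
# A minimal sketch of restoring a retrained checkpoint later, assuming the
# same ``Net`` and the selected-space file used during retraining:
#   model = Net(hidden_size=512)
#   apply_fixed_architecture(model, './selected_space.json')
#   model.load_state_dict(torch.load('best_checkpoint_epoch10.pth'))
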
def dump_global_result(args, global_result):
    with open(args['result_path'], "w") as ss_file:
        json.dump(global_result, ss_file, sort_keys=True, indent=2)
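
# With illustrative values, the file written by ``dump_global_result`` would
# look like:
#   {"accuracy": [97.1, 98.0], "cost_time": "12.3s"}
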
def get_params():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument("--data_dir", type=str,
                        default='./data', help="data directory")
    parser.add_argument("--best_selected_space_path", type=str,
                        default='./selected_space.json', help="selected_space_path")
    parser.add_argument("--search_space_path", type=str,
                        default='./search_space.json', help="search_space_path")
    parser.add_argument("--result_path", type=str,
                        default='./result.json', help="result_path")
    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
                        help='hidden layer size (default: 512)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument("--best_checkpoint_dir", type=str, default="path/to/",
                        help="Path for saved checkpoints. (default: %(default)s)")
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial ID, starting from 0')
    args, _ = parser.parse_known_args()
    return args

if __name__ == '__main__':
    try:
        start = time.time()
        args = vars(get_params())
        use_cuda = not args['no_cuda'] and torch.cuda.is_available()
        torch.manual_seed(args['seed'])
        device = torch.device("cuda" if use_cuda else "cpu")
        data_dir = args['data_dir']
        hidden_size = args['hidden_size']
        model = Net(hidden_size=hidden_size).to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'],
                                    momentum=args['momentum'])
        # mkdirs(args['search_space_path'])
        mkdirs(args['best_selected_space_path'])
        mkdirs(args['result_path'])
        trainer = ClassicnasRetrainer(model,
                                      loss=None,
                                      metrics=None,
                                      optimizer=optimizer,
                                      epochs=args['epochs'],
                                      dataset_train=data_dir,
                                      dataset_valid=data_dir,
                                      search_space_path=args['search_space_path'],
                                      selected_space_path=args['best_selected_space_path'],
                                      checkpoint_dir=args['best_checkpoint_dir'],
                                      trial_id=args['trial_id'],
                                      result_path=args['result_path'],
                                      batch_size=args['batch_size'],
                                      log_frequency=args['log_interval'],
                                      device=device,
                                      unrolled=False,
                                      callbacks=None)
        # truncate any previous result file before per-epoch lines are appended
        with open(args['result_path'], "w") as ss_file:
            ss_file.write('')
        trainer.train()
        trainer.checkpoint()
        global_result = trainer.result
        global_result['cost_time'] = str(time.time() - start) + 's'
        # dump_global_result(args, global_result)
    except Exception as exception:
        logger.exception(exception)
        raise
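
# Example invocation, assuming MNIST data under ./data and a
# selected_space.json exported by a prior search phase:
#   python retrainer.py --data_dir ./data \
#       --best_selected_space_path ./selected_space.json \
#       --best_checkpoint_dir ./checkpoints --epochs 10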
