
pdartstrainer.py
import os
import logging
import json
from collections import OrderedDict

import torch
import torch.nn as nn
import numpy as np

from pytorch.callbacks import LRSchedulerCallback
from pytorch.trainer import BaseTrainer, TorchTensorEncoder
from pytorch.utils import dump_global_result
from model import CNN
from pdartsmutator import PdartsMutator
from pytorch.darts.utils import accuracy
from pytorch.darts import datasets
from pytorch.darts.dartstrainer import DartsTrainer

logger = logging.getLogger(__name__)


class PdartsTrainer(BaseTrainer):
  18. """
  19. This trainer implements the PDARTS algorithm.
  20. PDARTS bases on DARTS algorithm, and provides a network growth approach to find deeper and better network.
  21. This class relies on pdarts_num_layers and pdarts_num_to_drop parameters to control how network grows.
  22. pdarts_num_layers means how many layers more than first epoch.
  23. pdarts_num_to_drop means how many candidate operations should be dropped in each epoch.
  24. So that the grew network can in similar size.
  25. """
    def __init__(self, init_layers, pdarts_num_layers, pdarts_num_to_drop, pdarts_dropout_rates,
                 num_epochs, num_pre_epochs, model_lr, class_num, arch_lr, channels, batch_size,
                 result_path, log_frequency, unrolled, data_dir, search_space_path,
                 best_selected_space_path, device=None, workers=4):
        super(PdartsTrainer, self).__init__()
        self.init_layers = init_layers
        self.class_num = class_num
        self.channels = channels
        self.model_lr = model_lr
        self.num_epochs = num_epochs
        self.pdarts_num_layers = pdarts_num_layers
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.pdarts_dropout_rates = pdarts_dropout_rates
        self.pdarts_epoches = len(pdarts_num_to_drop)
        self.search_space_path = search_space_path
        self.best_selected_space_path = best_selected_space_path

        logger.info("loading data")
        dataset_train, dataset_valid = datasets.get_dataset("cifar10", root=data_dir)
        # Arguments forwarded verbatim to DartsTrainer for each search stage.
        self.darts_parameters = {
            "metrics": lambda output, target: accuracy(output, target, topk=(1,)),
            "arch_lr": arch_lr,
            "num_epochs": num_epochs,
            "num_pre_epochs": num_pre_epochs,
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
            "batch_size": batch_size,
            "result_path": result_path,
            "workers": workers,
            "device": device,
            "log_frequency": log_frequency,
            "unrolled": unrolled,
            "search_space_path": None,
        }
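
    # Each PDARTS stage below runs a full DARTS search on a progressively deeper
    # network, then asks the mutator to drop the weakest candidate operations
    # before the next stage begins.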
    def train(self, validate=False):
        # `switches` records which candidate operations remain active; it is
        # None in the first stage and updated by the mutator after each stage.
        switches = None
        last = False
        for epoch in range(self.pdarts_epoches):
            if epoch == self.pdarts_epoches - 1:
                last = True
            # Create a fresh, deeper network for this stage.
            layers = self.init_layers + self.pdarts_num_layers[epoch]
            init_dropout_rate = float(self.pdarts_dropout_rates[epoch])
            model = CNN(32, 3, self.channels, self.class_num, layers,
                        init_dropout_rate, n_nodes=4, search=True)
            criterion = nn.CrossEntropyLoss()
            optim = torch.optim.SGD(model.parameters(), self.model_lr,
                                    momentum=0.9, weight_decay=3.0E-4)
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optim, self.num_epochs, eta_min=0.001)

            logger.info("############Start PDARTS training epoch %s############", epoch)
            self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)
            if epoch == 0:
                # Only write the original search space in the first stage.
                search_space = self.mutator._generate_search_space()
                dump_global_result(self.search_space_path, search_space)

            darts_callbacks = []
            if lr_scheduler is not None:
                darts_callbacks.append(LRSchedulerCallback(lr_scheduler))
            # darts_callbacks.append(ArchitectureCheckpoint(
            #     os.path.join(self.selected_space_path, "stage_{}".format(epoch))))
            self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion,
                                        optimizer=optim, callbacks=darts_callbacks,
                                        **self.darts_parameters)
            for train_epoch in range(self.darts_parameters["num_epochs"]):
                for callback in darts_callbacks:
                    callback.on_epoch_begin(train_epoch)

                # training
                logger.info("Epoch %d Training", train_epoch)
                if train_epoch < self.darts_parameters["num_pre_epochs"]:
                    # Linear dropout decay during the warm-up (pre-training) epochs.
                    dropout_rate = (init_dropout_rate
                                    * (self.darts_parameters["num_epochs"] - train_epoch - 1)
                                    / self.darts_parameters["num_epochs"])
                else:
                    # Exponential decay (scale factor 0.2) once architecture search
                    # starts; note this must use `train_epoch`, not the outer stage
                    # counter `epoch`.
                    dropout_rate = init_dropout_rate * np.exp(
                        -(train_epoch - self.darts_parameters["num_pre_epochs"]) * 0.2)
                model.drop_path_prob(search=True, p=dropout_rate)
                self.trainer.train_one_epoch(train_epoch)

                if validate:
                    # validation
                    logger.info("Epoch %d Validating", train_epoch + 1)
                    self.trainer.validate_one_epoch(train_epoch, log_print=last)

                for callback in darts_callbacks:
                    callback.on_epoch_end(train_epoch)
            switches = self.mutator.drop_paths()

            # In the last stage, restrict skip-connections and save the best structure.
            if last:
                res = OrderedDict()
                op_value = [value for value in search_space["op_list"]["_value"]
                            if value != 'none']
                res["op_list"] = search_space["op_list"]
                res["op_list"]["_value"] = op_value
                res["best_selected_space"] = self.mutator.export(last, switches)
                logger.info(res)
                dump_global_result(self.best_selected_space_path, res)
    def validate(self):
        self.trainer.validate()

    def export(self, file, last, switches):
        mutator_export = self.mutator.export(last, switches)
        with open(file, "w") as f:
            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
    def checkpoint(self, file_path, epoch):
        # Assumes `self.model` and `self.optimizer` have been set by the caller;
        # this class does not set them itself as written.
        if isinstance(self.model, nn.DataParallel):
            child_model_state_dict = self.model.module.state_dict()
        else:
            child_model_state_dict = self.model.state_dict()
        save_state = {'child_model_state_dict': child_model_state_dict,
                      'optimizer_state_dict': self.optimizer.state_dict(),
                      'epoch': epoch}
        dest_path = os.path.join(
            file_path, "best_checkpoint_epoch_{}.pth.tar".format(epoch))
        logger.info("Saving model to %s", dest_path)
        torch.save(save_state, dest_path)
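

if __name__ == "__main__":
    # Minimal usage sketch. All hyper-parameter values and paths below are
    # hypothetical (loosely following PDARTS paper defaults), not taken from
    # this repository's configuration.
    logging.basicConfig(level=logging.INFO)
    trainer = PdartsTrainer(init_layers=5,
                            pdarts_num_layers=[0, 6, 12],
                            pdarts_num_to_drop=[3, 2, 2],
                            pdarts_dropout_rates=[0.0, 0.4, 0.7],
                            num_epochs=25, num_pre_epochs=10,
                            model_lr=0.025, class_num=10,
                            arch_lr=3e-4, channels=16, batch_size=96,
                            result_path="./result.json", log_frequency=10,
                            unrolled=False, data_dir="./data",
                            search_space_path="./search_space.json",
                            best_selected_space_path="./best_selected_space.json")
    trainer.train(validate=True)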
