# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import json
import logging
from copy import deepcopy

import numpy as np
import torch

from pytorch.trainer import Trainer
from pytorch.utils import AverageMeterGroup

from .utils import accuracy, reduce_metrics

logger = logging.getLogger(__name__)


class CreamSupernetTrainer(Trainer):
    """
    This trainer trains a supernet and outputs prioritized architectures
    that can be used for other tasks.

    Parameters
    ----------
    selected_space : str
        Path of the JSON file to which the top prioritized architecture is dumped.
    model : nn.Module
        Model with mutables.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
    val_loss : callable
        Called with logits and targets for validation only. Returns a loss tensor.
    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of epochs of training.
    train_loader : iterable
        Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
    valid_loader : iterable
        Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
    mutator : Mutator
        A mutator object that has been initialized with the model.
    batch_size : int
        Batch size.
    log_frequency : int
        Number of mini-batches to log metrics.
    meta_sta_epoch : int
        Start epoch of using the meta matching network to pick the teacher architecture.
    update_iter : int
        Interval (in mini-batches) of updating the meta matching network.
    slices : int
        Batch size of the mini training data used to train the meta matching network.
    pool_size : int
        Size of the prioritized board.
    pick_method : str
        How to pick the teacher network ('top1' or 'meta').
    choice_num : int
        Number of operations in the supernet.
    sta_num : tuple of int
        Number of layers in each stage of the supernet (5 stages in total).
    acc_gap : int
        Accuracy improvement over the weakest candidate beyond which the FLOPs
        limitation is ignored.
    flops_dict : dict
        FLOPs of each candidate operation in the supernet, indexed by layer and operation.
    flops_fixed : int
        FLOPs of the fixed part of the supernet.
    local_rank : int
        Index of the current rank.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.
    result_path : str
        Path of the file to which per-epoch validation results are appended as JSON lines.
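
    Examples
    --------
    A minimal usage sketch; ``model``, ``mutator``, the two loaders,
    ``loss_fn``/``val_loss_fn``, ``optimizer`` and ``flops_dict`` are
    placeholders the caller must provide, and the file names are illustrative::

        trainer = CreamSupernetTrainer(
            "selected_space.json", model, loss_fn, val_loss_fn, optimizer,
            num_epochs=120, train_loader=train_loader,
            valid_loader=valid_loader, mutator=mutator, batch_size=64,
            flops_dict=flops_dict, result_path="result.json")
        for epoch in range(120):
            trainer.train_one_epoch(epoch)
            trainer.validate_one_epoch(epoch)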
""" def __init__(self, selected_space, model, loss, val_loss, optimizer, num_epochs, train_loader, valid_loader, mutator=None, batch_size=64, log_frequency=None, meta_sta_epoch=20, update_iter=200, slices=2, pool_size=10, pick_method='meta', choice_num=6, sta_num=(4, 4, 4, 4, 4), acc_gap=5, flops_dict=None, flops_fixed=0, local_rank=0, callbacks=None, result_path=None): assert torch.cuda.is_available() super(CreamSupernetTrainer, self).__init__(model, mutator, loss, None, optimizer, num_epochs, None, None, batch_size, None, None, log_frequency, callbacks) self.selected_space = selected_space self.model = model self.loss = loss self.val_loss = val_loss self.train_loader = train_loader self.valid_loader = valid_loader self.log_frequency = log_frequency self.batch_size = batch_size self.optimizer = optimizer self.model = model self.loss = loss self.num_epochs = num_epochs self.meta_sta_epoch = meta_sta_epoch self.update_iter = update_iter self.slices = slices self.pick_method = pick_method self.pool_size = pool_size self.local_rank = local_rank self.choice_num = choice_num self.sta_num = sta_num self.acc_gap = acc_gap self.flops_dict = flops_dict self.flops_fixed = flops_fixed self.current_student_arch = None self.current_teacher_arch = None self.main_proc = (local_rank == 0) self.current_epoch = 0 self.prioritized_board = [] self.result_path = result_path # size of prioritized board def _board_size(self): return len(self.prioritized_board) # select teacher architecture according to the logit difference def _select_teacher(self): self._replace_mutator_cand(self.current_student_arch) if self.pick_method == 'top1': meta_value, teacher_cand = 0.5, sorted( self.prioritized_board, reverse=True)[0][3] elif self.pick_method == 'meta': meta_value, cand_idx, teacher_cand = -1000000000, -1, None for now_idx, item in enumerate(self.prioritized_board): inputx = item[4] output = torch.nn.functional.softmax(self.model(inputx), dim=1) weight = self.model.forward_meta(output - item[5]) if weight > meta_value: meta_value = weight cand_idx = now_idx teacher_cand = self.prioritized_board[cand_idx][3] assert teacher_cand is not None meta_value = torch.nn.functional.sigmoid(-weight) else: raise ValueError('Method Not supported') return meta_value, teacher_cand # check whether to update prioritized board def _isUpdateBoard(self, prec1, flops): if self.current_epoch <= self.meta_sta_epoch: return False if len(self.prioritized_board) < self.pool_size: return True if prec1 > self.prioritized_board[-1][1] + self.acc_gap: return True if prec1 > self.prioritized_board[-1][1] and flops < self.prioritized_board[-1][2]: return True return False # update prioritized board def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flops): if self._isUpdateBoard(prec1, flops): val_prec1 = prec1 training_data = deepcopy(inputs[:self.slices].detach()) if len(self.prioritized_board) == 0: features = deepcopy(outputs[:self.slices].detach()) else: features = deepcopy(teacher_output[:self.slices].detach()) self.prioritized_board.append( (val_prec1, prec1, flops, self.current_student_arch, training_data, torch.nn.functional.softmax( features, dim=1))) self.prioritized_board = sorted( self.prioritized_board, reverse=True) if len(self.prioritized_board) > self.pool_size: self.prioritized_board = sorted( self.prioritized_board, reverse=True) del self.prioritized_board[-1] # only update student network weights def _update_student_weights_only(self, grad_1): for weight, grad_item in zip( 
    # only update meta networks weights
    def _update_meta_weights_only(self, teacher_cand, grad_teacher):
        for weight, grad_item in zip(self.model.module.rand_parameters(
                teacher_cand, self.pick_method == 'meta'), grad_teacher):
            weight.grad = grad_item

        # clip gradients
        torch.nn.utils.clip_grad_norm_(
            self.model.module.rand_parameters(
                self.current_student_arch, self.pick_method == 'meta'), 1)

        self.optimizer.step()

        for weight, grad_item in zip(self.model.module.rand_parameters(
                teacher_cand, self.pick_method == 'meta'), grad_teacher):
            del weight.grad

    # simulate sgd updating
    def _simulate_sgd_update(self, w, g, optimizer):
        return g * optimizer.param_groups[-1]['lr'] + w

    # split training images into several slices
    def _get_minibatch_input(self, input):
        slice_size = self.slices
        x = deepcopy(input[:slice_size].clone().detach())
        return x

    # calculate 1st gradient of student architectures
    def _calculate_1st_gradient(self, kd_loss):
        self.optimizer.zero_grad()
        grad = torch.autograd.grad(
            kd_loss,
            self.model.module.rand_parameters(self.current_student_arch),
            create_graph=True)
        return grad

    # calculate 2nd gradient of meta networks
    def _calculate_2nd_gradient(self, validation_loss, teacher_cand, students_weight):
        self.optimizer.zero_grad()
        grad_student_val = torch.autograd.grad(
            validation_loss,
            self.model.module.rand_parameters(self.current_student_arch),
            retain_graph=True)
        grad_teacher = torch.autograd.grad(
            students_weight[0],
            self.model.module.rand_parameters(
                teacher_cand, self.pick_method == 'meta'),
            grad_outputs=grad_student_val)
        return grad_teacher

    # forward training data
    def _forward_training(self, x, meta_value):
        self._replace_mutator_cand(self.current_student_arch)
        output = self.model(x)

        with torch.no_grad():
            self._replace_mutator_cand(self.current_teacher_arch)
            teacher_output = self.model(x)
            soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
        kd_loss = meta_value * \
            self._cross_entropy_loss_with_soft_target(output, soft_label)
        return kd_loss

    # calculate soft target loss
    def _cross_entropy_loss_with_soft_target(self, pred, soft_target):
        logsoftmax = torch.nn.LogSoftmax(dim=1)
        return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))

    # forward validation data
    def _forward_validation(self, input, target):
        slice_size = self.slices
        x = input[slice_size:slice_size * 2].clone()

        self._replace_mutator_cand(self.current_student_arch)
        output_2 = self.model(x)

        validation_loss = self.loss(output_2, target[slice_size:slice_size * 2])
        return validation_loss

    def _isUpdateMeta(self, batch_idx):
        isUpdate = True
        isUpdate &= (self.current_epoch > self.meta_sta_epoch)
        isUpdate &= (batch_idx > 0)
        isUpdate &= (batch_idx % self.update_iter == 0)
        isUpdate &= (self._board_size() > 0)
        return isUpdate

    def _replace_mutator_cand(self, cand):
        self.mutator._cache = cand
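    # The meta update in ``_run_update`` below chains two gradients. With the
    # simulated student step w' = w + lr * dL_kd/dw (as implemented in
    # ``_simulate_sgd_update``; note it adds, rather than subtracts, the
    # scaled gradient), the gradient reaching the meta/teacher parameters
    # theta is, by the chain rule,
    #
    #     dL_val/dtheta = (dL_val/dw') * (dw'/dtheta)
    #
    # which ``_calculate_2nd_gradient`` evaluates by passing the validation
    # gradients as ``grad_outputs`` to ``torch.autograd.grad``.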
    # update meta matching networks
    def _run_update(self, input, target, batch_idx):
        if self._isUpdateMeta(batch_idx):
            x = self._get_minibatch_input(input)

            meta_value, teacher_cand = self._select_teacher()

            kd_loss = self._forward_training(x, meta_value)

            # calculate 1st gradient
            grad_1st = self._calculate_1st_gradient(kd_loss)

            # simulate updated student weights
            students_weight = [
                self._simulate_sgd_update(
                    p, grad_item, self.optimizer) for p, grad_item in zip(
                    self.model.module.rand_parameters(self.current_student_arch), grad_1st)]

            # update student weights
            self._update_student_weights_only(grad_1st)

            validation_loss = self._forward_validation(input, target)

            # calculate 2nd gradient
            grad_teacher = self._calculate_2nd_gradient(validation_loss,
                                                        teacher_cand,
                                                        students_weight)

            # update meta matching networks
            self._update_meta_weights_only(teacher_cand, grad_teacher)

            # delete internal variants
            del grad_teacher, grad_1st, x, validation_loss, kd_loss, students_weight

    def _get_cand_flops(self, cand):
        flops = 0
        for block_id, block in enumerate(cand):
            if block == 'LayerChoice1' or block == 'LayerChoice23':
                continue
            for idx, choice in enumerate(cand[block]):
                flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
        return flops + self.flops_fixed

    def train_one_epoch(self, epoch):
        self.current_epoch = epoch
        meters = AverageMeterGroup()
        self.steps_per_epoch = len(self.train_loader)
        for step, (input_data, target) in enumerate(self.train_loader):
            self.mutator.reset()
            self.current_student_arch = self.mutator._cache

            input_data, target = input_data.cuda(), target.cuda()

            # calculate flops of current architecture
            cand_flops = self._get_cand_flops(self.mutator._cache)

            # update meta matching network
            self._run_update(input_data, target, step)

            if self._board_size() > 0:
                # select teacher architecture
                meta_value, teacher_cand = self._select_teacher()
                self.current_teacher_arch = teacher_cand

            # forward supernet
            if self._board_size() == 0 or epoch <= self.meta_sta_epoch:
                self._replace_mutator_cand(self.current_student_arch)
                output = self.model(input_data)

                loss = self.loss(output, target)
                kd_loss, teacher_output, teacher_cand = None, None, None
            else:
                self._replace_mutator_cand(self.current_student_arch)
                output = self.model(input_data)

                gt_loss = self.loss(output, target)

                with torch.no_grad():
                    self._replace_mutator_cand(self.current_teacher_arch)
                    teacher_output = self.model(input_data).detach()
                    soft_label = torch.nn.functional.softmax(teacher_output, dim=1)

                kd_loss = self._cross_entropy_loss_with_soft_target(output, soft_label)
                loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2

            # update network
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update metrics
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = reduce_metrics(metrics)
            meters.update(metrics)

            # update prioritized board
            self._update_prioritized_board(input_data, teacher_output, output,
                                           metrics['prec1'], cand_flops)

            if self.main_proc and self.log_frequency is not None and (
                    step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
                            self.num_epochs, step + 1, len(self.train_loader), meters)

        arch_list = []
        # if self.main_proc and self.num_epochs == epoch + 1:
        for idx, i in enumerate(self.prioritized_board):
            # logger.info("prioritized_board: No.%s %s", idx, i[:4])
            if idx == 0:
                for arch in list(i[3].values()):
                    _ = arch.numpy()
                    _ = np.where(_)[0].tolist()
                    arch_list.append(_)
        if len(arch_list) > 0:
            with open(self.selected_space, "w") as f:
                print("dump selected space.")
                json.dump({'selected_space': arch_list}, f)
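    # ``train_one_epoch`` mixes the two losses as
    # (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2, so the two
    # coefficients always sum to 1: e.g. meta_value = 0.4 yields
    # 0.2 * kd_loss + 0.8 * gt_loss.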
    def validate_one_epoch(self, epoch):
        self.model.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            for step, (x, y) in enumerate(self.valid_loader):
                self.mutator.reset()
                logits = self.model(x)
                loss = self.val_loss(logits, y)
                prec1, prec5 = accuracy(logits, y, topk=(1, 5))
                metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
                metrics = reduce_metrics(metrics)
                meters.update(metrics)

                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
                                self.num_epochs, step + 1, len(self.valid_loader), meters)

        # print({'type': 'Accuracy', 'result': {'sequence': epoch, 'category': 'epoch',
        #        'value': metrics["prec1"]}})
        if self.result_path is not None:
            with open(self.result_path, "a") as ss_file:
                ss_file.write(json.dumps(
                    {'type': 'Accuracy',
                     'result': {'sequence': epoch, 'category': 'epoch',
                                'value': metrics["prec1"]}}) + '\n')
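# When ``result_path`` is set, ``validate_one_epoch`` appends one JSON record
# per epoch, for example (the accuracy value here is illustrative):
#
#     {"type": "Accuracy", "result": {"sequence": 0, "category": "epoch", "value": 71.3}}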