import types

import numpy as np
import torch
from torch.optim import SGD, lr_scheduler


class _LRMomentumScheduler(lr_scheduler._LRScheduler):
    def __init__(self, optimizer, last_epoch=-1):
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_momentum', group['momentum'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_momentum' not in group:
                    raise KeyError("param 'initial_momentum' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_momentums = list(map(lambda group: group['initial_momentum'], optimizer.param_groups))
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        raise NotImplementedError

    def get_momentum(self):
        raise NotImplementedError

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        for param_group, lr, momentum in zip(
                self.optimizer.param_groups, self.get_lr(), self.get_momentum()):
            param_group['lr'] = lr
            param_group['momentum'] = momentum


class ParameterUpdate(object):
    """A callable class used to define an arbitrary schedule defined by a list.

    This object is designed to be passed to the LambdaLR or LambdaScheduler scheduler to apply
    the given schedule. If base_param is zero, no updates are applied.

    Arguments:
        params {list or numpy.array} -- List or numpy array defining parameter schedule.
        base_param {float} -- Parameter value used to initialize the optimizer.
    """

    def __init__(self, params, base_param):
        self.params = np.hstack([params, 0])
        self.base_param = base_param
        if base_param < 1e-12:
            # Guard against division by zero: fall back to a constant multiplier of 1 so the
            # parameter is left unchanged.
            self.base_param = 1
            self.params = self.params * 0.0 + 1.0

    def __call__(self, epoch):
        return self.params[epoch] / self.base_param


def apply_lambda(last_epoch, bases, lambdas):
    return [base * lmbda(last_epoch) for lmbda, base in zip(lambdas, bases)]


class LambdaScheduler(_LRMomentumScheduler):
    """Sets the learning rate and momentum of each parameter group to the initial lr and
    momentum times a given function. When last_epoch=-1, sets initial lr and momentum to the
    optimizer values.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative factor given
            an integer parameter epoch, or a list of such functions, one for each group in
            optimizer.param_groups. Default: lambda x: x.
        momentum_lambda (function or list): As for lr_lambda but applied to momentum.
            Default: lambda x: x.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lr_lambda = [
        ...     lambda epoch: epoch // 30,
        ...     lambda epoch: 0.95 ** epoch
        ... ]
        >>> mom_lambda = [
        ...     lambda epoch: max(0, (50 - epoch) // 50),
        ...     lambda epoch: 0.99 ** epoch
        ... ]
        >>> scheduler = LambdaScheduler(optimizer, lr_lambda, mom_lambda)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda=lambda x: x, momentum_lambda=lambda x: x, last_epoch=-1):
        self.optimizer = optimizer

        if not isinstance(lr_lambda, (list, tuple)):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)

        if not isinstance(momentum_lambda, (list, tuple)):
            self.momentum_lambdas = [momentum_lambda] * len(optimizer.param_groups)
        else:
            if len(momentum_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} momentum_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(momentum_lambda)))
            self.momentum_lambdas = list(momentum_lambda)

        self.last_epoch = last_epoch
        super().__init__(optimizer, last_epoch)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which is not the optimizer.
        The learning rate and momentum lambda functions will only be saved if they are
        callable objects and not if they are functions or lambdas.
        """
        state_dict = {key: value for key, value in self.__dict__.items()
                      if key not in ('optimizer', 'lr_lambdas', 'momentum_lambdas')}
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
        state_dict['momentum_lambdas'] = [None] * len(self.momentum_lambdas)

        for idx, (lr_fn, mom_fn) in enumerate(zip(self.lr_lambdas, self.momentum_lambdas)):
            if not isinstance(lr_fn, types.FunctionType):
                state_dict['lr_lambdas'][idx] = lr_fn.__dict__.copy()
            if not isinstance(mom_fn, types.FunctionType):
                state_dict['momentum_lambdas'][idx] = mom_fn.__dict__.copy()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned from a call to
                :meth:`state_dict`.
        """
        lr_lambdas = state_dict.pop('lr_lambdas')
        momentum_lambdas = state_dict.pop('momentum_lambdas')
        self.__dict__.update(state_dict)

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                self.lr_lambdas[idx].__dict__.update(fn)
        for idx, fn in enumerate(momentum_lambdas):
            if fn is not None:
                self.momentum_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        return apply_lambda(self.last_epoch, self.base_lrs, self.lr_lambdas)

    def get_momentum(self):
        return apply_lambda(self.last_epoch, self.base_momentums, self.momentum_lambdas)


class ListScheduler(LambdaScheduler):
    """Sets the learning rate and momentum of each parameter group to values defined by lists.
    When last_epoch=-1, sets initial lr and momentum to the optimizer values. One or both of
    the lr and momentum schedules may be specified.

    Note that the parameters used to initialize the optimizer are overridden by those defined
    by this scheduler.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lrs (list or numpy.ndarray): A list of learning rates, or a list of lists, one for
            each parameter group. One- or two-dimensional numpy arrays may also be passed.
        momentums (list or numpy.ndarray): A list of momentums, or a list of lists, one for
            each parameter group. One- or two-dimensional numpy arrays may also be passed.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lrs = [
        ...     np.linspace(0.01, 0.1, 100),
        ...     np.logspace(-2, 0, 100)
        ... ]
        >>> momentums = [
        ...     np.linspace(0.85, 0.95, 100),
        ...     np.linspace(0.8, 0.99, 100)
        ... ]
        >>> scheduler = ListScheduler(optimizer, lrs, momentums)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lrs=None, momentums=None, last_epoch=-1):
        groups = optimizer.param_groups
        if lrs is None:
            lr_lambda = lambda x: x
        else:
            lrs = np.array(lrs) if isinstance(lrs, (list, tuple)) else lrs
            if len(lrs.shape) == 1:
                # A single schedule shared by all parameter groups.
                lr_lambda = [ParameterUpdate(lrs, g['lr']) for g in groups]
            else:
                # One schedule per parameter group.
                lr_lambda = [ParameterUpdate(l, g['lr']) for l, g in zip(lrs, groups)]

        if momentums is None:
            momentum_lambda = lambda x: x
        else:
            momentums = np.array(momentums) if isinstance(momentums, (list, tuple)) else momentums
            if len(momentums.shape) == 1:
                momentum_lambda = [ParameterUpdate(momentums, g['momentum']) for g in groups]
            else:
                momentum_lambda = [ParameterUpdate(m, g['momentum']) for m, g in zip(momentums, groups)]

        super().__init__(optimizer, lr_lambda, momentum_lambda)


class RangeFinder(ListScheduler):
    """Scheduler class that implements the LR range search specified in:

    A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate,
    batch size, momentum, and weight decay. Leslie N. Smith, 2018, arXiv:1803.09820.

    Logarithmically spaced learning rates from 1e-7 to 1 are searched. The number of
    increments in that range is determined by 'epochs'.

    Note that the parameters used to initialize the optimizer are overridden by those defined
    by this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        epochs (int): Number of epochs over which to run the test.

    Example:
        >>> scheduler = RangeFinder(optimizer, 100)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, epochs):
        lrs = np.logspace(-7, 0, epochs)
        super().__init__(optimizer, lrs)


class OneCyclePolicy(ListScheduler):
    """Scheduler class that implements the 1cycle policy search specified in:

    A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate,
    batch size, momentum, and weight decay. Leslie N. Smith, 2018, arXiv:1803.09820.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr (float or list): Maximum learning rate in range. If a list of values is passed,
            they should correspond to parameter groups.
        epochs (int): The number of epochs to use during search.
        momentum_rng (list): Optional upper and lower momentum values (may both be equal).
            Set to None to run without momentum. Default: [0.85, 0.95]. If a list of lists is
            passed, they should correspond to parameter groups.
        phase_ratio (float): Fraction of epochs used for each of the increasing and decreasing
            phases of the schedule. For example, if phase_ratio=0.45 and epochs=100, the
            learning rate will increase from lr/10 to lr over 45 epochs, then decrease back
            to lr/10 over 45 epochs, then decrease to lr/100 over the remaining 10 epochs.
            Default: 0.45.
""" def __init__(self, optimizer, lr, epochs, momentum_rng=[0.85, 0.95], phase_ratio=0.45): phase_epochs = int(phase_ratio * epochs) if isinstance(lr, (list, tuple)): lrs = [ np.hstack([ np.linspace(l * 1e-1, l, phase_epochs), np.linspace(l, l * 1e-1, phase_epochs), np.linspace(l * 1e-1, l * 1e-2, epochs - 2 * phase_epochs), ]) for l in lr ] else: lrs = np.hstack([ np.linspace(lr * 1e-1, lr, phase_epochs), np.linspace(lr, lr * 1e-1, phase_epochs), np.linspace(lr * 1e-1, lr * 1e-2, epochs - 2 * phase_epochs), ]) if momentum_rng is not None: momentum_rng = np.array(momentum_rng) if len(momentum_rng.shape) == 2: for i, g in enumerate(optimizer.param_groups): g['momentum'] = momentum_rng[i][1] momentums = [ np.hstack([ np.linspace(m[1], m[0], phase_epochs), np.linspace(m[0], m[1], phase_epochs), np.linspace(m[1], m[1], epochs - 2 * phase_epochs), ]) for m in momentum_rng ] else: for i, g in enumerate(optimizer.param_groups): g['momentum'] = momentum_rng[1] momentums = np.hstack([ np.linspace(momentum_rng[1], momentum_rng[0], phase_epochs), np.linspace(momentum_rng[0], momentum_rng[1], phase_epochs), np.linspace(momentum_rng[1], momentum_rng[1], epochs - 2 * phase_epochs), ]) else: momentums = None super().__init__(optimizer, lrs, momentums)