2019-08-23 13:42:47 +00:00
|
|
|
import math
|
|
|
|
from collections import Counter
|
|
|
|
from collections import defaultdict
|
|
|
|
import torch
|
|
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
|
|
|
2021-08-05 11:57:04 +00:00
|
|
|
from utils.util import opt_get
|
|
|
|
|
2019-08-23 13:42:47 +00:00
|
|
|
|
2020-08-12 14:45:23 +00:00
|
|
|
def get_scheduler_for_name(name, optimizers, scheduler_opt):
    """Create one learning-rate scheduler per optimizer.

    Args:
        name: Scheduler type. Recognised values are 'MultiStepLR',
            'ProgressiveMultiStepLR' and 'CosineAnnealingLR_Restart'.
        optimizers: Iterable of optimizers. Any optimizer exposing an
            ``optim`` attribute (e.g. a LARC wrapper) is swapped for that
            wrapped optimizer before the scheduler is built.
        scheduler_opt: Options mapping. Which keys are read depends on
            ``name``; see the branches in ``_build`` below.

    Returns:
        A list with one scheduler per entry of ``optimizers``, in order.

    Raises:
        NotImplementedError: If ``name`` is not recognised. The name is only
            checked while building a scheduler, so an empty ``optimizers``
            yields ``[]`` regardless of ``name``.
    """

    def _build(target):
        # Builds the scheduler of the requested type for a single optimizer.
        if name == 'MultiStepLR':
            return MultiStepLR_Restart(
                target,
                scheduler_opt['gen_lr_steps'],
                restarts=scheduler_opt['restarts'],
                weights=scheduler_opt['restart_weights'],
                gamma=scheduler_opt['lr_gamma'],
                clear_state=scheduler_opt['clear_state'],
                force_lr=scheduler_opt['force_lr'],
                warmup_steps=opt_get(scheduler_opt, ['warmup_steps'], 0),
            )
        if name == 'ProgressiveMultiStepLR':
            return ProgressiveMultiStepLR(
                target,
                scheduler_opt['gen_lr_steps'],
                scheduler_opt['progressive_starts'],
                scheduler_opt['lr_gamma'],
            )
        if name == 'CosineAnnealingLR_Restart':
            return CosineAnnealingLR_Restart(
                target,
                scheduler_opt['T_period'],
                scheduler_opt['warmup'],
                eta_min=scheduler_opt['eta_min'],
                restarts=scheduler_opt['restarts'],
                weights=scheduler_opt['restart_weights'],
            )
        raise NotImplementedError('Scheduler not available')

    # Hack to support LARC, which wraps an underlying optimizer: schedule the
    # wrapped optimizer when an `optim` attribute is present.
    return [_build(getattr(candidate, 'optim', candidate)) for candidate in optimizers]
|
|
|
|
|
|
|
|
|
2020-07-18 20:18:48 +00:00
|
|
|
# This scheduler is specifically designed to modulate the learning rate of several different param groups configured
|
|
|
|
# by a generator or discriminator that slowly adds new stages one at a time, e.g. like progressive growing of GANs.
|
|
|
|
class ProgressiveMultiStepLR(_LRScheduler):
    """Multi-step LR decay where every param group has its own start offset.

    Intended for optimizers whose param groups are brought into training one
    at a time (e.g. progressive growing of GANs): each group's milestones are
    counted relative to that group's entry in ``group_starts`` instead of
    relative to step 0.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Iterable of step offsets (relative to a group's start) at
            which that group's learning rate is multiplied by ``gamma``. An
            offset listed k times decays the rate by ``gamma ** k``.
        group_starts: One start step per optimizer param group, in the same
            order as ``optimizer.param_groups``.
        gamma: Multiplicative decay factor.
        last_epoch: Forwarded unchanged to the base scheduler constructor.
    """

    def __init__(self, optimizer, milestones, group_starts, gamma=0.1, last_epoch=-1):
        # Counter keeps the multiplicity of repeated milestones.
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.group_starts = group_starts
        super(ProgressiveMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the next learning rate for every param group.

        Raises:
            AssertionError: If ``group_starts`` does not have exactly one
                entry per optimizer param group.
        """
        assert len(self.optimizer.param_groups) == len(self.group_starts)
        group_lrs = []
        for group, group_start in zip(self.optimizer.param_groups, self.group_starts):
            # Counter returns 0 for offsets that are not milestones (and does
            # not insert them), so this is both the membership test and the
            # number of decays to apply.
            hits = self.milestones[self.last_epoch - group_start]
            if hits:
                group_lrs.append(group['lr'] * self.gamma ** hits)
            else:
                group_lrs.append(group['lr'])
        return group_lrs
|
|
|
|
|
|
|
|
|
2019-08-23 13:42:47 +00:00
|
|
|
class MultiStepLR_Restart(_LRScheduler):
    """Multi-step LR decay with optional restarts, linear warmup and LR pinning.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Iterable of steps at which the current learning rate is
            multiplied by ``gamma`` (``gamma ** k`` for a step listed k times).
        restarts: Optional list of restart steps. Each value is stored
            shifted by +1. When omitted, a single placeholder restart
            (stored as step 1) is used.
        weights: Optional list of multipliers, one per restart, applied to
            each group's ``initial_lr`` at the restart. Defaults to ``[1]``.
        gamma: Multiplicative decay factor used at milestones.
        clear_state: If true, the optimizer state is reset at every restart.
        force_lr: If true, every group is always given its ``initial_lr``
            and all other scheduling is bypassed.
        last_epoch: Forwarded unchanged to the base scheduler constructor.
        warmup_steps: Number of steps over which the learning rate ramps
            linearly from 0 up to ``initial_lr``. 0 disables warmup.

    Raises:
        AssertionError: If ``restarts`` and ``weights`` differ in length.
    """

    def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
                 clear_state=False, force_lr=False, last_epoch=-1, warmup_steps=0):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.clear_state = clear_state
        self.restarts = restarts if restarts else [0]
        self.restarts = [v + 1 for v in self.restarts]
        self.restart_weights = weights if weights else [1]
        self.force_lr = force_lr
        self.warmup_steps = warmup_steps
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the next learning rate for every param group.

        Precedence: ``force_lr``, then restarts, then warmup, then milestones.
        """
        # Note: for the purposes of this trainer, "last_epoch" should read "last_step".
        groups = self.optimizer.param_groups
        if self.force_lr:
            return [group['initial_lr'] for group in groups]
        # NOTE(review): restarts are checked before warmup, so the default
        # placeholder restart (step 1, weight 1) yields the full initial_lr at
        # step 1 even while warming up. Kept as-is; confirm whether intended.
        if self.last_epoch in self.restarts:
            if self.clear_state:
                self.optimizer.state = defaultdict(dict)
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [group['initial_lr'] * weight for group in groups]
        # The boundary step (last_epoch == warmup_steps) is part of the ramp so
        # that the factor actually reaches 1. With a strict `<` the schedule
        # stayed at (warmup_steps - 1) / warmup_steps of initial_lr afterwards.
        # The `> 0` guard keeps warmup_steps == 0 away from the division.
        if self.warmup_steps > 0 and self.last_epoch <= self.warmup_steps:
            factor = 1 - (self.warmup_steps - self.last_epoch) / self.warmup_steps
            if self.last_epoch == self.warmup_steps:
                # A milestone that falls exactly on the end of warmup still decays.
                hits = self.milestones[self.last_epoch]
                if hits:
                    factor = factor * self.gamma ** hits
            return [group['initial_lr'] * factor for group in groups]
        if self.last_epoch not in self.milestones:
            return [group['lr'] for group in groups]
        return [
            group['lr'] * self.gamma ** self.milestones[self.last_epoch]
            for group in groups
        ]

    def load_state_dict(self, s):
        """Load scheduler state but keep the milestones given at construction.

        This allows the scheduler to pick up newly configured milestones
        partway through a training run instead of the checkpointed ones.
        """
        milestones_cache = self.milestones
        super(MultiStepLR_Restart, self).load_state_dict(s)
        self.milestones = milestones_cache
|
|
|
|
|
2019-08-23 13:42:47 +00:00
|
|
|
|
|
|
|
class CosineAnnealingLR_Restart(_LRScheduler):
    """Cosine-annealing LR schedule with restarts and an initial hold period.

    Args:
        optimizer: Wrapped optimizer.
        T_period: Sequence of cosine period lengths. ``T_period[0]`` is used
            until the first restart; restart number i switches to
            ``T_period[i + 1]`` when that entry exists, otherwise the current
            period is kept.
        warmup: Number of initial steps during which the base learning rates
            are returned unchanged; the schedule is offset by this amount.
        restarts: Optional list of restart steps (counted after ``warmup``).
            Each value is stored shifted by +1. When omitted, a single
            placeholder restart (stored as step 1) is used.
        weights: Optional list of multipliers, one per restart, applied to
            each group's ``initial_lr`` at the restart. Defaults to ``[1]``.
        eta_min: Lower bound the cosine curve anneals towards.
        last_epoch: Forwarded unchanged to the base scheduler constructor.

    Raises:
        AssertionError: If ``restarts`` and ``weights`` differ in length.
    """

    def __init__(self, optimizer, T_period, warmup=0, restarts=None, weights=None, eta_min=0, last_epoch=-1):
        self.warmup = warmup
        self.T_period = T_period
        self.T_max = self.T_period[0]  # current T period
        self.eta_min = eta_min
        self.restarts = restarts if restarts else [0]
        self.restarts = [v + 1 for v in self.restarts]
        self.restart_weights = weights if weights else [1]
        self.last_restart = 0
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the next learning rate for every param group."""
        step = self.last_epoch - self.warmup
        if step <= 0:
            # Still inside the hold period: keep the base learning rates.
            return self.base_lrs
        if step in self.restarts:
            restart_idx = self.restarts.index(step)
            self.last_restart = step
            # Only advance to the next period if one is configured. Without
            # this guard the default placeholder restart combined with a
            # single-entry T_period raised IndexError on the first step.
            if restart_idx + 1 < len(self.T_period):
                self.T_max = self.T_period[restart_idx + 1]
            weight = self.restart_weights[restart_idx]
            return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
        since_restart = step - self.last_restart
        if (since_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
            # The recursive update below would divide by 1 + cos(pi) == 0 at
            # this step, so the increment is computed directly.
            return [
                group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # Recursive cosine update: scale the distance to eta_min by the ratio
        # of the cosine terms at the current and the previous step.
        ratio = ((1 + math.cos(math.pi * since_restart / self.T_max)) /
                 (1 + math.cos(math.pi * (since_restart - 1) / self.T_max)))
        return [ratio * (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual visual check: step a scheduler N_iter times on a dummy optimizer
    # and plot the resulting learning-rate curve.
    optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0,
                                 betas=(0.9, 0.99))

    ##############################
    # MultiStepLR_Restart
    ##############################
    # The presets below are assigned one after another, so only the last one
    # ("four") is actually used when the scheduler is built.

    ## Original
    lr_steps = [200000, 400000, 600000, 800000]
    restarts = None
    restart_weights = None

    ## two
    lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
    restarts = [500000]
    restart_weights = [1]

    ## four
    lr_steps = [
        50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
        600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
    ]
    restarts = [250000, 500000, 750000]
    restart_weights = [1, 1, 1]

    scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
                                    clear_state=False, warmup_steps=20000)

    # Disabled alternative demo (cosine annealing with restarts):
    #
    # ##############################
    # # Cosine Annealing Restart
    # ##############################
    # ## two
    # T_period = [500000, 500000]
    # restarts = [500000]
    # restart_weights = [1]
    #
    # ## four
    # T_period = [200000, 100000, 200000]
    # restarts = [200000, 300000]
    # restart_weights = [.5, .25]
    #
    # scheduler = CosineAnnealingLR_Restart(optimizer, T_period, warmup=10000, eta_min=1e-8, restarts=restarts,
    #                                       weights=restart_weights)

    ##############################
    # Draw figure
    ##############################
    N_iter = 100000
    lr_l = []
    for _ in range(N_iter):
        scheduler.step()
        lr_l.append(optimizer.param_groups[0]['lr'])

    # Plotting dependencies are imported here so they are only needed when
    # this file is run as a script.
    import matplotlib as mpl
    from matplotlib import pyplot as plt
    import matplotlib.ticker as mtick
    mpl.style.use('default')
    import seaborn
    seaborn.set(style='whitegrid')
    seaborn.set_context('paper')

    plt.figure(1)
    plt.subplot(111)
    plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    plt.title('Title', fontsize=16, color='k')
    plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
    legend = plt.legend(loc='upper right', shadow=False)
    ax = plt.gca()
    # Relabel the x ticks in thousands of iterations, e.g. 20000 -> '20K'.
    labels = [str(int(tick / 1000)) + 'K' for tick in ax.get_xticks().tolist()]
    ax.set_xticklabels(labels)
    ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))

    ax.set_ylabel('Learning rate')
    ax.set_xlabel('Iteration')
    fig = plt.gcf()
    plt.show()
|