Enable forced learning rates

This commit is contained in:
James Betker 2020-06-07 16:56:05 -06:00
parent cbedd6340a
commit 299d855b34
2 changed files with 6 additions and 2 deletions

View File

@ -142,7 +142,8 @@ class SRGANModel(BaseModel):
restarts=train_opt['restarts'], restarts=train_opt['restarts'],
weights=train_opt['restart_weights'], weights=train_opt['restart_weights'],
gamma=train_opt['lr_gamma'], gamma=train_opt['lr_gamma'],
clear_state=train_opt['clear_state'])) clear_state=train_opt['clear_state'],
force_lr=train_opt['force_lr']))
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
for optimizer in self.optimizers: for optimizer in self.optimizers:
self.schedulers.append( self.schedulers.append(

View File

@ -7,18 +7,21 @@ from torch.optim.lr_scheduler import _LRScheduler
class MultiStepLR_Restart(_LRScheduler): class MultiStepLR_Restart(_LRScheduler):
def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
clear_state=False, last_epoch=-1): clear_state=False, force_lr=False, last_epoch=-1):
self.milestones = Counter(milestones) self.milestones = Counter(milestones)
self.gamma = gamma self.gamma = gamma
self.clear_state = clear_state self.clear_state = clear_state
self.restarts = restarts if restarts else [0] self.restarts = restarts if restarts else [0]
self.restarts = [v + 1 for v in self.restarts] self.restarts = [v + 1 for v in self.restarts]
self.restart_weights = weights if weights else [1] self.restart_weights = weights if weights else [1]
self.force_lr = force_lr
assert len(self.restarts) == len( assert len(self.restarts) == len(
self.restart_weights), 'restarts and their weights do not match.' self.restart_weights), 'restarts and their weights do not match.'
super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
def get_lr(self): def get_lr(self):
if self.force_lr:
return [group['initial_lr'] for group in self.optimizer.param_groups]
if self.last_epoch in self.restarts: if self.last_epoch in self.restarts:
if self.clear_state: if self.clear_state:
self.optimizer.state = defaultdict(dict) self.optimizer.state = defaultdict(dict)