From 299d855b342f01bf3a56c4416fb48102742d69e4 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 7 Jun 2020 16:56:05 -0600
Subject: [PATCH] Enable forced learning rates

---
 codes/models/SRGAN_model.py  | 3 ++-
 codes/models/lr_scheduler.py | 5 ++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 58f34aa9..88f705b8 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -142,7 +142,8 @@ class SRGANModel(BaseModel):
                                                      restarts=train_opt['restarts'],
                                                      weights=train_opt['restart_weights'],
                                                      gamma=train_opt['lr_gamma'],
-                                                     clear_state=train_opt['clear_state']))
+                                                     clear_state=train_opt['clear_state'],
+                                                     force_lr=train_opt['force_lr']))
         elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
             for optimizer in self.optimizers:
                 self.schedulers.append(
diff --git a/codes/models/lr_scheduler.py b/codes/models/lr_scheduler.py
index be7a92f0..6ef0aef3 100644
--- a/codes/models/lr_scheduler.py
+++ b/codes/models/lr_scheduler.py
@@ -7,18 +7,21 @@ from torch.optim.lr_scheduler import _LRScheduler
 
 class MultiStepLR_Restart(_LRScheduler):
     def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
-                 clear_state=False, last_epoch=-1):
+                 clear_state=False, force_lr=False, last_epoch=-1):
         self.milestones = Counter(milestones)
         self.gamma = gamma
         self.clear_state = clear_state
         self.restarts = restarts if restarts else [0]
         self.restarts = [v + 1 for v in self.restarts]
         self.restart_weights = weights if weights else [1]
+        self.force_lr = force_lr
         assert len(self.restarts) == len(
             self.restart_weights), 'restarts and their weights do not match.'
         super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
 
     def get_lr(self):
+        if self.force_lr:
+            return [group['initial_lr'] for group in self.optimizer.param_groups]
         if self.last_epoch in self.restarts:
             if self.clear_state:
                 self.optimizer.state = defaultdict(dict)
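
Notes (not part of the patch): the new force_lr flag short-circuits get_lr() so the
scheduler always returns each param group's 'initial_lr' (the key _LRScheduler
records at construction time), overriding milestones, restarts, and gamma for the
whole run. Below is a minimal, self-contained sketch of that behavior; the ForcedLR
class, the Linear stand-in model, and the 1e-4 learning rate are illustrative
inventions for this note, not code from this repository.

    import torch
    from torch.optim.lr_scheduler import _LRScheduler


    class ForcedLR(_LRScheduler):
        """Reduced version of MultiStepLR_Restart keeping only the force_lr path."""

        def __init__(self, optimizer, force_lr=False, last_epoch=-1):
            self.force_lr = force_lr
            super(ForcedLR, self).__init__(optimizer, last_epoch)

        def get_lr(self):
            if self.force_lr:
                # _LRScheduler stores each group's starting LR under 'initial_lr';
                # returning it unconditionally pins the LR for the whole run.
                return [group['initial_lr'] for group in self.optimizer.param_groups]
            # Otherwise leave the LR as-is (the full scheduler applies its
            # milestone decay and restart weights here instead).
            return [group['lr'] for group in self.optimizer.param_groups]


    model = torch.nn.Linear(4, 4)  # illustrative stand-in for the SRGAN networks
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    sched = ForcedLR(opt, force_lr=True)

    for step in range(3):
        opt.step()
        sched.step()
        # With force_lr=True the LR never moves off 1e-4, whatever the step count.
        print(step, opt.param_groups[0]['lr'])  # -> 0.0001 every iteration

Since SRGAN_model.py reads the flag from train_opt['force_lr'], enabling it should
just be a matter of adding a force_lr entry under the train section of the options
file; the key name comes straight from the patch, while the exact config layout is
whatever the repo already uses.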