Merge branch 'gan_lab' of https://github.com/neonbjb/mmsr into gan_lab

This commit is contained in:
James Betker 2020-06-07 16:59:31 -06:00
commit 93528ff8df
2 changed files with 6 additions and 2 deletions

View File

@ -142,7 +142,8 @@ class SRGANModel(BaseModel):
restarts=train_opt['restarts'],
weights=train_opt['restart_weights'],
gamma=train_opt['lr_gamma'],
clear_state=train_opt['clear_state']))
clear_state=train_opt['clear_state'],
force_lr=train_opt['force_lr']))
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
for optimizer in self.optimizers:
self.schedulers.append(

View File

@ -7,18 +7,21 @@ from torch.optim.lr_scheduler import _LRScheduler
class MultiStepLR_Restart(_LRScheduler):
def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
clear_state=False, last_epoch=-1):
clear_state=False, force_lr=False, last_epoch=-1):
self.milestones = Counter(milestones)
self.gamma = gamma
self.clear_state = clear_state
self.restarts = restarts if restarts else [0]
self.restarts = [v + 1 for v in self.restarts]
self.restart_weights = weights if weights else [1]
self.force_lr = force_lr
assert len(self.restarts) == len(
self.restart_weights), 'restarts and their weights do not match.'
super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
def get_lr(self):
if self.force_lr:
return [group['initial_lr'] for group in self.optimizer.param_groups]
if self.last_epoch in self.restarts:
if self.clear_state:
self.optimizer.state = defaultdict(dict)