Merge branch 'gan_lab' of https://github.com/neonbjb/mmsr into gan_lab
This commit is contained in:
commit
93528ff8df
|
@ -142,7 +142,8 @@ class SRGANModel(BaseModel):
|
|||
restarts=train_opt['restarts'],
|
||||
weights=train_opt['restart_weights'],
|
||||
gamma=train_opt['lr_gamma'],
|
||||
clear_state=train_opt['clear_state']))
|
||||
clear_state=train_opt['clear_state'],
|
||||
force_lr=train_opt['force_lr']))
|
||||
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
|
||||
for optimizer in self.optimizers:
|
||||
self.schedulers.append(
|
||||
|
|
|
@ -7,18 +7,21 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||
|
||||
class MultiStepLR_Restart(_LRScheduler):
|
||||
def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
|
||||
clear_state=False, last_epoch=-1):
|
||||
clear_state=False, force_lr=False, last_epoch=-1):
|
||||
self.milestones = Counter(milestones)
|
||||
self.gamma = gamma
|
||||
self.clear_state = clear_state
|
||||
self.restarts = restarts if restarts else [0]
|
||||
self.restarts = [v + 1 for v in self.restarts]
|
||||
self.restart_weights = weights if weights else [1]
|
||||
self.force_lr = force_lr
|
||||
assert len(self.restarts) == len(
|
||||
self.restart_weights), 'restarts and their weights do not match.'
|
||||
super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if self.force_lr:
|
||||
return [group['initial_lr'] for group in self.optimizer.param_groups]
|
||||
if self.last_epoch in self.restarts:
|
||||
if self.clear_state:
|
||||
self.optimizer.state = defaultdict(dict)
|
||||
|
|
Loading…
Reference in New Issue
Block a user