forked from mrq/DL-Art-School
Enable forced learning rates
parent cbedd6340a
commit 299d855b34
@@ -142,7 +142,8 @@ class SRGANModel(BaseModel):
                         restarts=train_opt['restarts'],
                         weights=train_opt['restart_weights'],
                         gamma=train_opt['lr_gamma'],
-                        clear_state=train_opt['clear_state']))
+                        clear_state=train_opt['clear_state'],
+                        force_lr=train_opt['force_lr']))
         elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
             for optimizer in self.optimizers:
                 self.schedulers.append(
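For reference, a hedged sketch of the train_opt entries this branch reads is below. Only the keys visible in the hunk above are taken from the commit; the 'lr_scheme' value, the 'lr_steps' milestones key, and every concrete value are illustrative assumptions.

# Hypothetical train_opt fragment for the scheduler branch patched above.
# 'force_lr' is the new flag; 'lr_scheme'/'lr_steps' names and all values
# are assumptions for illustration, not read from this diff.
train_opt = {
    'lr_scheme': 'MultiStepLR',            # assumed scheme name for this branch
    'lr_steps': [50000, 100000, 200000],   # milestones (assumed key name)
    'restarts': None,
    'restart_weights': None,
    'lr_gamma': 0.5,
    'clear_state': False,
    'force_lr': True,   # pin every param group to its initial learning rate
}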
@@ -7,18 +7,21 @@ from torch.optim.lr_scheduler import _LRScheduler
 
 
 class MultiStepLR_Restart(_LRScheduler):
     def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
-                 clear_state=False, last_epoch=-1):
+                 clear_state=False, force_lr=False, last_epoch=-1):
         self.milestones = Counter(milestones)
         self.gamma = gamma
         self.clear_state = clear_state
         self.restarts = restarts if restarts else [0]
         self.restarts = [v + 1 for v in self.restarts]
         self.restart_weights = weights if weights else [1]
+        self.force_lr = force_lr
         assert len(self.restarts) == len(
             self.restart_weights), 'restarts and their weights do not match.'
         super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
 
     def get_lr(self):
+        if self.force_lr:
+            return [group['initial_lr'] for group in self.optimizer.param_groups]
         if self.last_epoch in self.restarts:
             if self.clear_state:
                 self.optimizer.state = defaultdict(dict)
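With force_lr set, get_lr() short-circuits before any milestone, gamma, or restart logic and simply returns each param group's 'initial_lr', so the configured learning rate is held constant. A minimal, self-contained sketch of that behaviour follows; it uses a simplified stand-in class rather than the repo's full MultiStepLR_Restart, and the demo model and optimizer are assumptions.

from collections import Counter

import torch
from torch.optim.lr_scheduler import _LRScheduler

# Simplified stand-in illustrating the patch: force_lr short-circuits get_lr()
# so every param group keeps its initial learning rate, ignoring milestones
# and gamma decay entirely.
class ForcedLRDemo(_LRScheduler):
    def __init__(self, optimizer, milestones, gamma=0.1, force_lr=False, last_epoch=-1):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.force_lr = force_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.force_lr:
            # Same override as the patch: pin the LR to its configured value.
            return [group['initial_lr'] for group in self.optimizer.param_groups]
        if self.last_epoch not in self.milestones:
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
                for group in self.optimizer.param_groups]

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
sched = ForcedLRDemo(opt, milestones=[2, 4], gamma=0.5, force_lr=True)
for step in range(6):
    opt.step()
    sched.step()
    print(step, opt.param_groups[0]['lr'])  # stays at 1e-4 because force_lr=True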