Fix force_lr logic
This commit is contained in:
parent
0dee15f875
commit
40cb25292a
|
@ -66,6 +66,8 @@ class MultiStepLR_Restart(_LRScheduler):
|
|||
self.restarts = [v + 1 for v in self.restarts]
|
||||
self.restart_weights = weights if weights else [1]
|
||||
self.force_lr = force_lr
|
||||
if force_lr:
|
||||
print(f"!!Forcing the learning rate to: {force_lr}")
|
||||
self.warmup_steps = warmup_steps
|
||||
assert len(self.restarts) == len(
|
||||
self.restart_weights), 'restarts and their weights do not match.'
|
||||
|
@ -73,8 +75,8 @@ class MultiStepLR_Restart(_LRScheduler):
|
|||
|
||||
def get_lr(self):
|
||||
# Note to self: for the purposes of this trainer, "last_epoch" should read "last_step"
|
||||
if self.force_lr:
|
||||
return [group['initial_lr'] for group in self.optimizer.param_groups]
|
||||
if self.force_lr is not None:
|
||||
return [self.force_lr for group in self.optimizer.param_groups]
|
||||
if self.last_epoch in self.restarts:
|
||||
if self.clear_state:
|
||||
self.optimizer.state = defaultdict(dict)
|
||||
|
|
Loading…
Reference in New Issue
Block a user