Fix force_lr logic

James Betker 2021-10-21 11:51:30 -06:00
parent 0dee15f875
commit 40cb25292a


@@ -66,6 +66,8 @@ class MultiStepLR_Restart(_LRScheduler):
         self.restarts = [v + 1 for v in self.restarts]
         self.restart_weights = weights if weights else [1]
         self.force_lr = force_lr
+        if force_lr:
+            print(f"!!Forcing the learning rate to: {force_lr}")
         self.warmup_steps = warmup_steps
         assert len(self.restarts) == len(
             self.restart_weights), 'restarts and their weights do not match.'
@@ -73,8 +75,8 @@ class MultiStepLR_Restart(_LRScheduler):
 
     def get_lr(self):
         # Note to self: for the purposes of this trainer, "last_epoch" should read "last_step"
-        if self.force_lr:
-            return [group['initial_lr'] for group in self.optimizer.param_groups]
+        if self.force_lr is not None:
+            return [self.force_lr for group in self.optimizer.param_groups]
         if self.last_epoch in self.restarts:
             if self.clear_state:
                 self.optimizer.state = defaultdict(dict)
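
For reference, a minimal standalone sketch of the behavior change. The old branch fired on truthiness and returned each group's initial_lr, so the forced value was never actually applied; the new branch tests against None explicitly and returns force_lr itself. The stub optimizer and the before/after helper names below are illustrative only, not code from this repo.

    # Sketch of the force_lr fix, runnable outside the trainer.
    from collections import namedtuple

    StubOptimizer = namedtuple('StubOptimizer', ['param_groups'])
    opt = StubOptimizer(param_groups=[{'initial_lr': 1e-4}, {'initial_lr': 1e-4}])

    def get_lr_before(optimizer, force_lr):
        # Old logic: truthiness test, and it returned each group's initial_lr,
        # so the requested forced rate was silently ignored.
        if force_lr:
            return [group['initial_lr'] for group in optimizer.param_groups]

    def get_lr_after(optimizer, force_lr):
        # New logic: explicit None test, and the forced rate itself is
        # returned for every param group.
        if force_lr is not None:
            return [force_lr for group in optimizer.param_groups]

    print(get_lr_before(opt, 1e-5))  # [0.0001, 0.0001] -- ignores the forced rate
    print(get_lr_after(opt, 1e-5))   # [1e-05, 1e-05]   -- forces it as intended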