Force LR fix

This commit is contained in:
James Betker 2021-10-21 12:01:01 -06:00
parent 40cb25292a
commit 9a3e89ec53

View File

@ -76,7 +76,7 @@ class MultiStepLR_Restart(_LRScheduler):
def get_lr(self):
# Note to self: for the purposes of this trainer, "last_epoch" should read "last_step"
if self.force_lr is not None:
return [self.force_lr for group in self.optimizer.param_groups]
return [self.force_lr for _ in self.optimizer.param_groups]
if self.last_epoch in self.restarts:
if self.clear_state:
self.optimizer.state = defaultdict(dict)
@ -95,8 +95,10 @@ class MultiStepLR_Restart(_LRScheduler):
# Allow this scheduler to use newly appointed milestones partially through a training run..
def load_state_dict(self, s):
milestones_cache = self.milestones
force_lr_cache = self.force_lr
super(MultiStepLR_Restart, self).load_state_dict(s)
self.milestones = milestones_cache
self.force_lr = force_lr_cache
class CosineAnnealingLR_Restart(_LRScheduler):