From 9a3e89ec53eac13482b4b14521ea32bce1ecbadf Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 21 Oct 2021 12:01:01 -0600
Subject: [PATCH] Force LR fix

---
 codes/trainer/lr_scheduler.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/codes/trainer/lr_scheduler.py b/codes/trainer/lr_scheduler.py
index a79124c6..f437ff3c 100644
--- a/codes/trainer/lr_scheduler.py
+++ b/codes/trainer/lr_scheduler.py
@@ -76,7 +76,7 @@ class MultiStepLR_Restart(_LRScheduler):
     def get_lr(self):
         # Note to self: for the purposes of this trainer, "last_epoch" should read "last_step"
         if self.force_lr is not None:
-            return [self.force_lr for group in self.optimizer.param_groups]
+            return [self.force_lr for _ in self.optimizer.param_groups]
         if self.last_epoch in self.restarts:
             if self.clear_state:
                 self.optimizer.state = defaultdict(dict)
@@ -95,8 +95,10 @@ class MultiStepLR_Restart(_LRScheduler):
     # Allow this scheduler to use newly appointed milestones partially through a training run..
     def load_state_dict(self, s):
         milestones_cache = self.milestones
+        force_lr_cache = self.force_lr
         super(MultiStepLR_Restart, self).load_state_dict(s)
         self.milestones = milestones_cache
+        self.force_lr = force_lr_cache


 class CosineAnnealingLR_Restart(_LRScheduler):
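
For readers outside the repo, here is a minimal runnable sketch of the pattern this patch relies on. It assumes only PyTorch's _LRScheduler base class; the restart/decay logic of the real MultiStepLR_Restart in codes/trainer/lr_scheduler.py is elided, and the usage at the bottom is hypothetical. The point is that _LRScheduler.load_state_dict updates the scheduler's __dict__ from the checkpoint, so a force_lr (or milestones) value set for the current run would be overwritten on resume; caching the attribute before the super() call and restoring it afterwards keeps the runtime override alive.

# Sketch only: mirrors the names in the diff, omits the restart/decay logic.
from collections import Counter

import torch
from torch.optim.lr_scheduler import _LRScheduler


class MultiStepLR_Restart(_LRScheduler):
    def __init__(self, optimizer, milestones, force_lr=None, last_epoch=-1):
        self.milestones = Counter(milestones)
        self.force_lr = force_lr
        super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.force_lr is not None:
            # One LR per param group; `_` because the group itself is unused
            # (the variable-name fix from the first hunk).
            return [self.force_lr for _ in self.optimizer.param_groups]
        # Real milestone/restart handling elided in this sketch.
        return [group['lr'] for group in self.optimizer.param_groups]

    def load_state_dict(self, s):
        # The base-class load_state_dict restores every saved attribute,
        # clobbering values chosen for *this* run. Cache and re-apply them.
        milestones_cache = self.milestones
        force_lr_cache = self.force_lr
        super(MultiStepLR_Restart, self).load_state_dict(s)
        self.milestones = milestones_cache
        self.force_lr = force_lr_cache


# Hypothetical resume scenario: the checkpoint predates the forced LR.
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = MultiStepLR_Restart(opt, milestones=[100])
old_state = sched.state_dict()    # checkpoint written with force_lr=None

sched.force_lr = 1e-5             # operator forces a new LR mid-run
sched.load_state_dict(old_state)  # without the cache, force_lr would revert
assert sched.force_lr == 1e-5
assert sched.get_lr() == [1e-5]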