forked from mrq/DL-Art-School

Force LR fix

This commit is contained in:
parent 40cb25292a
commit 9a3e89ec53
@@ -76,7 +76,7 @@ class MultiStepLR_Restart(_LRScheduler):
     def get_lr(self):
         # Note to self: for the purposes of this trainer, "last_epoch" should read "last_step"
         if self.force_lr is not None:
-            return [self.force_lr for group in self.optimizer.param_groups]
+            return [self.force_lr for _ in self.optimizer.param_groups]
         if self.last_epoch in self.restarts:
             if self.clear_state:
                 self.optimizer.state = defaultdict(dict)
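The changed line above short-circuits the milestone schedule: when force_lr is set, get_lr() hands the forced value to every param group, and renaming the unused loop variable to _ makes that explicit. A minimal sketch of what the new line evaluates to (the param-group values here are illustrative, not from the repo):

    # Stand-in for optimizer.param_groups; the lr values are made up.
    param_groups = [{'lr': 0.1}, {'lr': 0.01}]
    force_lr = 1e-4
    # One forced entry per param group, regardless of milestones or restarts:
    print([force_lr for _ in param_groups])  # -> [0.0001, 0.0001]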
@@ -95,8 +95,10 @@ class MultiStepLR_Restart(_LRScheduler):
     # Allow this scheduler to use newly appointed milestones partially through a training run..
     def load_state_dict(self, s):
         milestones_cache = self.milestones
+        force_lr_cache = self.force_lr
         super(MultiStepLR_Restart, self).load_state_dict(s)
         self.milestones = milestones_cache
+        self.force_lr = force_lr_cache


 class CosineAnnealingLR_Restart(_LRScheduler):
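The new cache/restore of force_lr mirrors what the method already did for milestones. Assuming this scheduler subclasses torch.optim.lr_scheduler._LRScheduler (as the hunk headers indicate), the base class's load_state_dict() simply does self.__dict__.update(state_dict), so an override configured for the current run would be clobbered by whatever value was saved in the checkpoint. The Toy class below is hypothetical, written only to reproduce that behavior:

    import torch
    from torch.optim.lr_scheduler import _LRScheduler

    # Hypothetical minimal scheduler showing the clobbering this commit guards against.
    class Toy(_LRScheduler):
        def __init__(self, optimizer, force_lr=None):
            self.force_lr = force_lr
            super().__init__(optimizer)

        def get_lr(self):
            if self.force_lr is not None:
                return [self.force_lr for _ in self.optimizer.param_groups]
            return [g['initial_lr'] for g in self.optimizer.param_groups]

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    saved = Toy(opt).state_dict()    # checkpoint written without an override (force_lr=None)
    sched = Toy(opt, force_lr=1e-4)  # the current run configures a forced LR
    sched.load_state_dict(saved)     # base class does self.__dict__.update(saved)
    print(sched.force_lr)            # -> None: the override is lost without the cache/restore

With the two cached assignments in the diff, the values passed to the constructor survive a checkpoint load instead of being replaced by the saved ones.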