diff --git a/codes/trainer/lr_scheduler.py b/codes/trainer/lr_scheduler.py index 53051a2c..a79124c6 100644 --- a/codes/trainer/lr_scheduler.py +++ b/codes/trainer/lr_scheduler.py @@ -66,6 +66,8 @@ class MultiStepLR_Restart(_LRScheduler): self.restarts = [v + 1 for v in self.restarts] self.restart_weights = weights if weights else [1] self.force_lr = force_lr + if force_lr: + print(f"!!Forcing the learning rate to: {force_lr}") self.warmup_steps = warmup_steps assert len(self.restarts) == len( self.restart_weights), 'restarts and their weights do not match.' @@ -73,8 +75,8 @@ class MultiStepLR_Restart(_LRScheduler): def get_lr(self): # Note to self: for the purposes of this trainer, "last_epoch" should read "last_step" - if self.force_lr: - return [group['initial_lr'] for group in self.optimizer.param_groups] + if self.force_lr is not None: + return [self.force_lr for group in self.optimizer.param_groups] if self.last_epoch in self.restarts: if self.clear_state: self.optimizer.state = defaultdict(dict)