diff --git a/codes/train.py b/codes/train.py index 260b3874..9a8746fa 100644 --- a/codes/train.py +++ b/codes/train.py @@ -151,6 +151,7 @@ class Trainer: self.start_epoch = 0 if 'force_start_step' in opt.keys(): self.current_step = opt['force_start_step'] + opt['current_step'] = self.current_step def do_step(self, train_data): if self._profile: diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 7d3d1983..db25b686 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -98,14 +98,8 @@ class ExtensibleTrainer(BaseModel): self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt) # Set the starting step count for the scheduler. - start_step = 0 - if 'force_start_step' in opt.keys(): - start_step = opt['force_start_step'] - elif 'start_step' in opt.keys(): - start_step = opt['start_step'] - if start_step != 0: - for sched in self.schedulers: - sched.last_epoch = start_step + for sched in self.schedulers: + sched.last_epoch = opt['current_step'] else: self.schedulers = []