From ce6524184c82c2e32576e893ba60155c06aa1c7b Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 2 Jan 2021 22:24:12 -0700 Subject: [PATCH] Do the last commit but in a better way --- codes/train.py | 1 + codes/trainer/ExtensibleTrainer.py | 10 ++-------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/codes/train.py b/codes/train.py index 260b3874..9a8746fa 100644 --- a/codes/train.py +++ b/codes/train.py @@ -151,6 +151,7 @@ class Trainer: self.start_epoch = 0 if 'force_start_step' in opt.keys(): self.current_step = opt['force_start_step'] + opt['current_step'] = self.current_step def do_step(self, train_data): if self._profile: diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 7d3d1983..db25b686 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -98,14 +98,8 @@ class ExtensibleTrainer(BaseModel): self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt) # Set the starting step count for the scheduler. - start_step = 0 - if 'force_start_step' in opt.keys(): - start_step = opt['force_start_step'] - elif 'start_step' in opt.keys(): - start_step = opt['start_step'] - if start_step != 0: - for sched in self.schedulers: - sched.last_epoch = start_step + for sched in self.schedulers: + sched.last_epoch = opt['current_step'] else: self.schedulers = []