Do the last commit but in a better way
This commit is contained in:
parent
edf9c38198
commit
ce6524184c
|
@ -151,6 +151,7 @@ class Trainer:
|
||||||
self.start_epoch = 0
|
self.start_epoch = 0
|
||||||
if 'force_start_step' in opt.keys():
|
if 'force_start_step' in opt.keys():
|
||||||
self.current_step = opt['force_start_step']
|
self.current_step = opt['force_start_step']
|
||||||
|
opt['current_step'] = self.current_step
|
||||||
|
|
||||||
def do_step(self, train_data):
|
def do_step(self, train_data):
|
||||||
if self._profile:
|
if self._profile:
|
||||||
|
|
|
@ -98,14 +98,8 @@ class ExtensibleTrainer(BaseModel):
|
||||||
self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
|
self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
|
||||||
|
|
||||||
# Set the starting step count for the scheduler.
|
# Set the starting step count for the scheduler.
|
||||||
start_step = 0
|
for sched in self.schedulers:
|
||||||
if 'force_start_step' in opt.keys():
|
sched.last_epoch = opt['current_step']
|
||||||
start_step = opt['force_start_step']
|
|
||||||
elif 'start_step' in opt.keys():
|
|
||||||
start_step = opt['start_step']
|
|
||||||
if start_step != 0:
|
|
||||||
for sched in self.schedulers:
|
|
||||||
sched.last_epoch = start_step
|
|
||||||
else:
|
else:
|
||||||
self.schedulers = []
|
self.schedulers = []
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user