allow opt states to be reset

This commit is contained in:
James Betker 2022-05-23 10:54:37 -06:00
parent f4a97ca0a7
commit 5d13d38119

View File

@ -158,7 +158,12 @@ class Trainer:
self.start_epoch = resume_state['epoch']
self.current_step = resume_state['iter']
self.total_training_data_encountered = opt_get(resume_state, ['total_data_processed'], 0)
self.model.resume_training(resume_state, 'amp_opt_level' in opt.keys()) # handle optimizers and schedulers
if opt_get(opt, ['path', 'optimizer_reset'], False):
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
print('!! RESETTING OPTIMIZER STATES')
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
else:
self.model.resume_training(resume_state, 'amp_opt_level' in opt.keys()) # handle optimizers and schedulers
else:
self.current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
self.total_training_data_encountered = 0 if 'training_data_encountered' not in opt.keys() else opt['training_data_encountered']