save when training completes

This commit is contained in:
mrq 2023-03-15 02:47:12 +00:00
parent 3fdf2a63aa
commit b253da6e35

View File

@ -196,6 +196,8 @@ class Trainer:
self.total_training_data_encountered = self.current_step * opt['datasets']['train']['batch_size']
opt['current_step'] = self.current_step
self.epoch = self.start_epoch
#### validation
if 'val_freq' in opt['train'].keys():
self.val_freq = opt['train']['val_freq'] * opt['datasets']['train']['batch_size']
@ -205,6 +207,18 @@ class Trainer:
self.next_eval_step = self.total_training_data_encountered + self.val_freq
del resume_state # For whatever reason, this relieves a memory burden on the first GPU for some training sessions.
def save(self):
self.model.save(self.current_step)
state = {
'epoch': self.epoch,
'iter': self.current_step,
'total_data_processed': self.total_training_data_encountered
}
if self.dataset_debugger is not None:
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
self.model.save_training_state(state)
self.logger.info('Saving models and training states.')
def do_step(self, train_data):
if self._profile:
print("Data fetch: %f" % (time() - _t))
@ -226,7 +240,7 @@ class Trainer:
_t = time()
self.model.feed_data(train_data, self.current_step)
gradient_norms_dict = self.model.optimize_parameters(self.current_step, return_grad_norms=will_log)
self.iteration_rate = (time() - _t) / batch_size
self.iteration_rate = (time() - _t) # / batch_size
if self._profile:
print("Model feed + step: %f" % (time() - _t))
_t = time()
@ -299,16 +313,7 @@ class Trainer:
if self.current_step > 0 and self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
self.model.consolidate_state()
if self.rank <= 0:
self.model.save(self.current_step)
state = {
'epoch': self.epoch,
'iter': self.current_step,
'total_data_processed': self.total_training_data_encountered
}
if self.dataset_debugger is not None:
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
self.model.save_training_state(state)
self.logger.info('Saving models and training states.')
self.save()
do_eval = self.total_training_data_encountered > self.next_eval_step
if do_eval:
@ -381,6 +386,7 @@ class Trainer:
self.logger.info(f'Training Metrics: {json.dumps(logs)}')
if self.rank <= 0:
self.save()
self.logger.info('Finished training!')
def create_training_generator(self, index):
@ -397,6 +403,7 @@ class Trainer:
for train_data in tq_ldr:
yield self.model
metric = self.do_step(train_data)
self.save()
self.logger.info('Finished training')
if __name__ == '__main__':