forked from mrq/DL-Art-School
save when training completes
This commit is contained in:
parent
3fdf2a63aa
commit
b253da6e35
|
@ -196,6 +196,8 @@ class Trainer:
|
|||
self.total_training_data_encountered = self.current_step * opt['datasets']['train']['batch_size']
|
||||
opt['current_step'] = self.current_step
|
||||
|
||||
self.epoch = self.start_epoch
|
||||
|
||||
#### validation
|
||||
if 'val_freq' in opt['train'].keys():
|
||||
self.val_freq = opt['train']['val_freq'] * opt['datasets']['train']['batch_size']
|
||||
|
@ -205,6 +207,18 @@ class Trainer:
|
|||
self.next_eval_step = self.total_training_data_encountered + self.val_freq
|
||||
del resume_state # For whatever reason, this relieves a memory burden on the first GPU for some training sessions.
|
||||
|
||||
def save(self):
|
||||
self.model.save(self.current_step)
|
||||
state = {
|
||||
'epoch': self.epoch,
|
||||
'iter': self.current_step,
|
||||
'total_data_processed': self.total_training_data_encountered
|
||||
}
|
||||
if self.dataset_debugger is not None:
|
||||
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
|
||||
self.model.save_training_state(state)
|
||||
self.logger.info('Saving models and training states.')
|
||||
|
||||
def do_step(self, train_data):
|
||||
if self._profile:
|
||||
print("Data fetch: %f" % (time() - _t))
|
||||
|
@ -226,7 +240,7 @@ class Trainer:
|
|||
_t = time()
|
||||
self.model.feed_data(train_data, self.current_step)
|
||||
gradient_norms_dict = self.model.optimize_parameters(self.current_step, return_grad_norms=will_log)
|
||||
self.iteration_rate = (time() - _t) / batch_size
|
||||
self.iteration_rate = (time() - _t) # / batch_size
|
||||
if self._profile:
|
||||
print("Model feed + step: %f" % (time() - _t))
|
||||
_t = time()
|
||||
|
@ -299,16 +313,7 @@ class Trainer:
|
|||
if self.current_step > 0 and self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||
self.model.consolidate_state()
|
||||
if self.rank <= 0:
|
||||
self.model.save(self.current_step)
|
||||
state = {
|
||||
'epoch': self.epoch,
|
||||
'iter': self.current_step,
|
||||
'total_data_processed': self.total_training_data_encountered
|
||||
}
|
||||
if self.dataset_debugger is not None:
|
||||
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
|
||||
self.model.save_training_state(state)
|
||||
self.logger.info('Saving models and training states.')
|
||||
self.save()
|
||||
|
||||
do_eval = self.total_training_data_encountered > self.next_eval_step
|
||||
if do_eval:
|
||||
|
@ -381,6 +386,7 @@ class Trainer:
|
|||
self.logger.info(f'Training Metrics: {json.dumps(logs)}')
|
||||
|
||||
if self.rank <= 0:
|
||||
self.save()
|
||||
self.logger.info('Finished training!')
|
||||
|
||||
def create_training_generator(self, index):
|
||||
|
@ -397,6 +403,7 @@ class Trainer:
|
|||
for train_data in tq_ldr:
|
||||
yield self.model
|
||||
metric = self.do_step(train_data)
|
||||
self.save()
|
||||
self.logger.info('Finished training')
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue
Block a user