forked from mrq/DL-Art-School

save when training completes

This commit is contained in:
parent 3fdf2a63aa · commit b253da6e35
@@ -196,6 +196,8 @@ class Trainer:
             self.total_training_data_encountered = self.current_step * opt['datasets']['train']['batch_size']
             opt['current_step'] = self.current_step
 
+        self.epoch = self.start_epoch
+
         #### validation
         if 'val_freq' in opt['train'].keys():
             self.val_freq = opt['train']['val_freq'] * opt['datasets']['train']['batch_size']
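Note on this hunk: self.epoch is now initialized unconditionally from self.start_epoch, so the save() helper added below can read it even when training stops before the loop ever advances the epoch counter. A minimal sketch of the failure this avoids (TrainerSketch is a hypothetical stand-in, not the repo's class):

# Hypothetical sketch, not the repo's code.
class TrainerSketch:
    def __init__(self, start_epoch=0):
        self.start_epoch = start_epoch
        self.epoch = self.start_epoch  # mirrors the line added above

    def save(self):
        # Without the assignment above, reading self.epoch here would raise
        # AttributeError if save() ran before the first epoch ever finished.
        return {'epoch': self.epoch}

print(TrainerSketch().save())  # -> {'epoch': 0}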
@@ -205,6 +207,18 @@ class Trainer:
             self.next_eval_step = self.total_training_data_encountered + self.val_freq
         del resume_state # For whatever reason, this relieves a memory burden on the first GPU for some training sessions.
 
+    def save(self):
+        self.model.save(self.current_step)
+        state = {
+            'epoch': self.epoch,
+            'iter': self.current_step,
+            'total_data_processed': self.total_training_data_encountered
+        }
+        if self.dataset_debugger is not None:
+            state['dataset_debugger_state'] = self.dataset_debugger.get_state()
+        self.model.save_training_state(state)
+        self.logger.info('Saving models and training states.')
+
     def do_step(self, train_data):
         if self._profile:
             print("Data fetch: %f" % (time() - _t))
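The new save() method consolidates checkpoint logic that was previously inlined at the periodic-checkpoint site (removed in a later hunk), so the same code can also run once when training completes. A rough sketch of what such a helper persists, assuming torch-style serialization; TinyModel, save_checkpoint, and the file names are illustrative stand-ins, not the repo's API:

# Hypothetical sketch of a consolidated checkpoint helper.
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

def save_checkpoint(model, epoch, step, total_data, prefix='checkpoint'):
    # Weights and training bookkeeping are written separately, mirroring
    # model.save(...) followed by model.save_training_state(...).
    torch.save(model.state_dict(), f'{prefix}_{step}_model.pth')
    torch.save({'epoch': epoch,
                'iter': step,
                'total_data_processed': total_data},
               f'{prefix}_{step}_state.pth')

save_checkpoint(TinyModel(), epoch=3, step=5000, total_data=640000)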
@@ -226,7 +240,7 @@ class Trainer:
         _t = time()
         self.model.feed_data(train_data, self.current_step)
         gradient_norms_dict = self.model.optimize_parameters(self.current_step, return_grad_norms=will_log)
-        self.iteration_rate = (time() - _t) / batch_size
+        self.iteration_rate = (time() - _t) # / batch_size
         if self._profile:
             print("Model feed + step: %f" % (time() - _t))
             _t = time()
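This small change alters the metric's meaning: iteration_rate now reports seconds per optimizer step rather than seconds per sample, with the old divisor left commented out for reference. Illustrative arithmetic (the numbers are made up):

# Illustrative only; values are made up.
step_seconds = 0.8            # (time() - _t): wall time for one step
batch_size = 32

old_iteration_rate = step_seconds / batch_size  # 0.025 s per sample
new_iteration_rate = step_seconds               # 0.8 s per step
samples_per_second = batch_size / step_seconds  # 40.0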
@@ -299,16 +313,7 @@ class Trainer:
         if self.current_step > 0 and self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
             self.model.consolidate_state()
             if self.rank <= 0:
-                self.model.save(self.current_step)
-                state = {
-                    'epoch': self.epoch,
-                    'iter': self.current_step,
-                    'total_data_processed': self.total_training_data_encountered
-                }
-                if self.dataset_debugger is not None:
-                    state['dataset_debugger_state'] = self.dataset_debugger.get_state()
-                self.model.save_training_state(state)
-                self.logger.info('Saving models and training states.')
+                self.save()
 
         do_eval = self.total_training_data_encountered > self.next_eval_step
         if do_eval:
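The periodic checkpoint path now delegates to save() behind the same rank guard as before: in a multi-GPU (DDP-style) run, only the primary process writes files, since all replicas hold the same consolidated state and concurrent writers could clobber each other. A minimal sketch of the guard pattern (save_if_primary is hypothetical):

# Hypothetical sketch of the rank-0 writer pattern.
import os

def save_if_primary(rank, save_fn):
    # rank is -1 for single-process runs and 0 for the DDP primary, so
    # "<= 0" covers both; the other replicas skip the write.
    if rank <= 0:
        save_fn()

save_if_primary(int(os.environ.get('RANK', '-1')), lambda: print('saving'))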
@@ -381,6 +386,7 @@ class Trainer:
             self.logger.info(f'Training Metrics: {json.dumps(logs)}')
 
         if self.rank <= 0:
+            self.save()
             self.logger.info('Finished training!')
 
     def create_training_generator(self, index):
@@ -397,6 +403,7 @@ class Trainer:
         for train_data in tq_ldr:
             yield self.model
             metric = self.do_step(train_data)
+        self.save()
         self.logger.info('Finished training')
 
 if __name__ == '__main__':
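With the two new self.save() calls, a final checkpoint is written when training finishes normally, not only at save_checkpoint_freq boundaries. Assuming the training state is serialized as a plain dict (the file name below is hypothetical), a resume would read back exactly the keys that save() records:

# Hypothetical round-trip sketch; the file name is made up.
import torch

torch.save({'epoch': 7, 'iter': 12000, 'total_data_processed': 1536000},
           'training_state_12000.pth')

resume_state = torch.load('training_state_12000.pth')
start_epoch = resume_state['epoch']                # 7
current_step = resume_state['iter']                # 12000
total_data = resume_state['total_data_processed']  # 1536000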