diff --git a/codes/train.py b/codes/train.py
index 70c9d7e1..6ca9cdab 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -196,6 +196,8 @@ class Trainer:
         self.total_training_data_encountered = self.current_step * opt['datasets']['train']['batch_size']
         opt['current_step'] = self.current_step
 
+        self.epoch = self.start_epoch
+
         #### validation
         if 'val_freq' in opt['train'].keys():
             self.val_freq = opt['train']['val_freq'] * opt['datasets']['train']['batch_size']
@@ -205,6 +207,18 @@ class Trainer:
             self.next_eval_step = self.total_training_data_encountered + self.val_freq
         del resume_state # For whatever reason, this relieves a memory burden on the first GPU for some training sessions.
 
+    def save(self):
+        self.model.save(self.current_step)
+        state = {
+            'epoch': self.epoch,
+            'iter': self.current_step,
+            'total_data_processed': self.total_training_data_encountered
+        }
+        if self.dataset_debugger is not None:
+            state['dataset_debugger_state'] = self.dataset_debugger.get_state()
+        self.model.save_training_state(state)
+        self.logger.info('Saving models and training states.')
+
     def do_step(self, train_data):
         if self._profile:
             print("Data fetch: %f" % (time() - _t))
@@ -226,7 +240,7 @@ class Trainer:
             _t = time()
         self.model.feed_data(train_data, self.current_step)
         gradient_norms_dict = self.model.optimize_parameters(self.current_step, return_grad_norms=will_log)
-        self.iteration_rate = (time() - _t) / batch_size
+        self.iteration_rate = (time() - _t) # / batch_size
         if self._profile:
             print("Model feed + step: %f" % (time() - _t))
             _t = time()
@@ -299,16 +313,7 @@ class Trainer:
         if self.current_step > 0 and self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
             self.model.consolidate_state()
             if self.rank <= 0:
-                self.model.save(self.current_step)
-                state = {
-                    'epoch': self.epoch,
-                    'iter': self.current_step,
-                    'total_data_processed': self.total_training_data_encountered
-                }
-                if self.dataset_debugger is not None:
-                    state['dataset_debugger_state'] = self.dataset_debugger.get_state()
-                self.model.save_training_state(state)
-                self.logger.info('Saving models and training states.')
+                self.save()
 
         do_eval = self.total_training_data_encountered > self.next_eval_step
         if do_eval:
@@ -381,6 +386,7 @@ class Trainer:
                 self.logger.info(f'Training Metrics: {json.dumps(logs)}')
 
         if self.rank <= 0:
+            self.save()
             self.logger.info('Finished training!')
 
     def create_training_generator(self, index):
@@ -397,6 +403,7 @@ class Trainer:
             for train_data in tq_ldr:
                 yield self.model
                 metric = self.do_step(train_data)
+        self.save()
         self.logger.info('Finished training')
 
 if __name__ == '__main__':
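
Reviewer note: besides deduplicating the inline checkpoint logic into `Trainer.save()`, the first hunk initializes `self.epoch = self.start_epoch` in the setup path. That appears to matter because `save()` is now also invoked outside the epoch loop (at both end-of-training call sites added here), where `self.epoch` would otherwise only exist if the loop body had already assigned it. A minimal, self-contained sketch of the failure mode presumably being guarded against (the `Trainer` below is a simplified stand-in, not the real class from this file):

```python
# Sketch only: simplified stand-in for codes/train.py's Trainer, showing why
# self.epoch must be initialized up front once save() can run before the
# epoch loop has ever assigned it.
class Trainer:
    def __init__(self, start_epoch: int):
        self.start_epoch = start_epoch
        self.epoch = self.start_epoch  # without this, save() raises AttributeError
        self.current_step = 0

    def save(self) -> dict:
        # Mirrors (in miniature) the state dict assembled by the real save()
        return {'epoch': self.epoch, 'iter': self.current_step}

t = Trainer(start_epoch=3)
print(t.save())  # {'epoch': 3, 'iter': 0}: safe even before any training step
```

Separately, note that dropping `/ batch_size` changes `iteration_rate` from seconds-per-sample to seconds-per-iteration, so anything consuming that metric will shift by a factor of the batch size.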