diff --git a/codes/train.py b/codes/train.py
index be9cabac..a92e7b35 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -39,6 +39,8 @@ class Trainer:
         self._profile = False
         self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
         self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
+        self.current_step = 0
+        self.total_training_data_encountered = 0

         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
@@ -159,6 +161,7 @@ class Trainer:

             self.start_epoch = resume_state['epoch']
             self.current_step = resume_state['iter']
+            self.total_training_data_encountered = opt_get(resume_state, ['total_data_processed'], 0)
             self.model.resume_training(resume_state, 'amp_opt_level' in opt.keys())  # handle optimizers and schedulers
         else:
             self.current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
@@ -173,7 +176,11 @@ class Trainer:
             _t = time()

         opt = self.opt
+        batch_size = self.opt['datasets']['train']['batch_size']  # It may seem weird to derive this from opt, rather than train_data. The reason this is done is
+                                                                  # because train_data is process-local while the opt variant represents all of the data fed across all GPUs.
         self.current_step += 1
+        self.total_training_data_encountered += batch_size
+
         #### update learning rate
         self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])

@@ -191,7 +198,10 @@ class Trainer:
         if self.dataset_debugger is not None:
             self.dataset_debugger.update(train_data)
         if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
-            logs = self.model.get_current_log(self.current_step)
+            logs = {'step': self.current_step,
+                    'samples': self.total_training_data_encountered,
+                    'megasamples': self.total_training_data_encountered / 1000000}
+            logs.update(self.model.get_current_log(self.current_step))
             if self.dataset_debugger is not None:
                 logs.update(self.dataset_debugger.get_debugging_map())
             message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
@@ -210,7 +220,10 @@ class Trainer:
                         self.tb_logger.add_scalar(k, v, self.current_step)
                 if opt['wandb'] and self.rank <= 0:
                     import wandb
-                    wandb.log(logs, step=int(self.current_step * opt_get(opt, ['wandb_step_factor'], 1)))
+                    if opt_get(opt, ['wandb_progress_use_raw_steps'], False):
+                        wandb.log(logs, step=self.current_step)
+                    else:
+                        wandb.log(logs, step=self.total_training_data_encountered)
             self.logger.info(message)

         #### save models and training states
@@ -219,7 +232,7 @@ class Trainer:
             if self.rank <= 0:
                 self.logger.info('Saving models and training states.')
                 self.model.save(self.current_step)
-                state = {'epoch': self.epoch, 'iter': self.current_step}
+                state = {'epoch': self.epoch, 'iter': self.current_step, 'total_data_processed': self.total_training_data_encountered}
                 if self.dataset_debugger is not None:
                     state['dataset_debugger_state'] = self.dataset_debugger.get_state()
                 self.model.save_training_state(state)
@@ -231,7 +244,11 @@ class Trainer:
                 shutil.copytree(self.tb_logger_path, alt_tblogger)

         #### validation
-        if opt_get(opt, ['eval', 'pure'], False) and self.current_step % opt['train']['val_freq'] == 0:
+        if 'val_freq' in opt['train'].keys():
+            val_freq = opt['train']['val_freq'] * batch_size
+        else:
+            val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)
+        if opt_get(opt, ['eval', 'pure'], False) and self.total_training_data_encountered % val_freq == 0:
             metrics = []
             for val_data in tqdm(self.val_loader):
                 self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
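
For context, here is a minimal standalone sketch of the sample-based accounting this patch introduces: progress is tracked in samples seen (`current_step * batch_size`), and the validation interval is derived either from a step-based `val_freq` scaled by the batch size or from `val_freq_megasamples`. The config values, the bare loop, and the `run_validation` stub are hypothetical stand-ins for illustration, not part of the actual `Trainer`.

```python
# Hypothetical illustration of the sample-based progress accounting added in the patch.
# `opt` mimics only the relevant slice of a training config; values are made up.
opt = {
    'datasets': {'train': {'batch_size': 32}},   # global batch size across all GPUs
    'train': {'val_freq_megasamples': 0.5},      # validate roughly every 500k samples
}

def run_validation(step):
    # Stand-in for the real validation loop.
    print(f"validating at step {step}")

batch_size = opt['datasets']['train']['batch_size']

# Mirror the patch: a step-based 'val_freq' is converted to samples by scaling with
# the batch size; otherwise 'val_freq_megasamples' is used directly.
if 'val_freq' in opt['train']:
    val_freq = opt['train']['val_freq'] * batch_size
else:
    val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)

current_step = 0
total_training_data_encountered = 0

for _ in range(100_000):                         # stand-in for iterating over the train loader
    current_step += 1
    total_training_data_encountered += batch_size

    # Cadence keys off samples seen rather than optimizer steps, so runs with
    # different batch sizes line up on the same x-axis when logged (e.g. to wandb).
    if total_training_data_encountered % val_freq == 0:
        run_validation(current_step)
```

With `batch_size = 32` and `val_freq_megasamples = 0.5`, validation fires every 15,625 steps (500,000 samples), which is the same x-axis position regardless of how the batch size changes between runs.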