Begin a migration to specifying training rate on megasamples instead of arbitrary "steps"

This should help me greatly in tuning models.  It's also necessary now that batch size isn't really
respected; we simply step once the gradient direction becomes unstable.
James Betker 2022-02-09 17:25:05 -07:00
parent 93ca619267
commit a930f2576e
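At its core, the change swaps a bare step counter for a running tally of samples consumed, so progress can be read in megasamples regardless of batch size. A minimal sketch of that bookkeeping, using illustrative names rather than the repository's actual classes:

# Sketch: step-based vs. sample-based progress accounting (illustrative names only).
class ProgressTracker:
    def __init__(self, global_batch_size):
        # Samples consumed per optimizer step, summed across all GPUs.
        self.global_batch_size = global_batch_size
        self.current_step = 0
        self.total_samples = 0

    def record_step(self):
        # One optimizer step consumes one global batch worth of data.
        self.current_step += 1
        self.total_samples += self.global_batch_size

    @property
    def megasamples(self):
        # Progress in millions of samples, comparable across runs with different batch sizes.
        return self.total_samples / 1_000_000

tracker = ProgressTracker(global_batch_size=256)
for _ in range(2000):
    tracker.record_step()
print(tracker.current_step, tracker.megasamples)  # 2000 steps -> 0.512 megasamples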

@@ -39,6 +39,8 @@ class Trainer:
         self._profile = False
         self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
         self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
+        self.current_step = 0
+        self.total_training_data_encountered = 0

         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
@@ -159,6 +161,7 @@ class Trainer:
             self.start_epoch = resume_state['epoch']
             self.current_step = resume_state['iter']
+            self.total_training_data_encountered = opt_get(resume_state, ['total_data_processed'], 0)
             self.model.resume_training(resume_state, 'amp_opt_level' in opt.keys())  # handle optimizers and schedulers
         else:
             self.current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
@@ -173,7 +176,11 @@ class Trainer:
             _t = time()

         opt = self.opt
+        batch_size = self.opt['datasets']['train']['batch_size']  # It may seem weird to derive this from opt, rather than train_data. The reason this is done is
+                                                                  # because train_data is process-local while the opt variant represents all of the data fed across all GPUs.
         self.current_step += 1
+        self.total_training_data_encountered += batch_size

         #### update learning rate
         self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])
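The comment in this hunk is the crux: under multi-GPU data parallelism each process's train_data batch covers only its own shard, so counting samples from it would undercount the run. A toy illustration of the difference (hypothetical numbers, not tied to any particular launcher):

# Per-process vs. global sample accounting under data parallelism (hypothetical values).
world_size = 4                                  # number of GPUs / processes
per_process_batch = 16                          # what each rank's DataLoader yields per step
global_batch = per_process_batch * world_size   # what the run as a whole consumes per step

steps = 1000
local_samples = steps * per_process_batch       # 16,000 -- only this rank's share
global_samples = steps * global_batch           # 64,000 -- what total_training_data_encountered should track
print(local_samples, global_samples)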
@@ -191,7 +198,10 @@ class Trainer:
         if self.dataset_debugger is not None:
             self.dataset_debugger.update(train_data)
         if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
-            logs = self.model.get_current_log(self.current_step)
+            logs = {'step': self.current_step,
+                    'samples': self.total_training_data_encountered,
+                    'megasamples': self.total_training_data_encountered / 1000000}
+            logs.update(self.model.get_current_log(self.current_step))
             if self.dataset_debugger is not None:
                 logs.update(self.dataset_debugger.get_debugging_map())
             message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
@@ -210,7 +220,10 @@ class Trainer:
                     self.tb_logger.add_scalar(k, v, self.current_step)
             if opt['wandb'] and self.rank <= 0:
                 import wandb
-                wandb.log(logs, step=int(self.current_step * opt_get(opt, ['wandb_step_factor'], 1)))
+                if opt_get(opt, ['wandb_progress_use_raw_steps'], False):
+                    wandb.log(logs, step=self.current_step)
+                else:
+                    wandb.log(logs, step=self.total_training_data_encountered)
             self.logger.info(message)

         #### save models and training states
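With this hunk the W&B x-axis defaults to samples seen rather than optimizer steps, and wandb_progress_use_raw_steps restores the old behaviour. Since the sample count only ever grows, it satisfies W&B's expectation that step values passed to wandb.log be non-decreasing. An isolated sketch of the toggle (opt here is a plain dict standing in for the parsed options; wandb.init and wandb.log are the real W&B calls):

import wandb

opt = {'wandb_progress_use_raw_steps': False}   # stand-in for the trainer's parsed options
current_step = 500
total_training_data_encountered = 500 * 256

wandb.init(project='demo', mode='offline')      # offline so the sketch runs without an account
logs = {'loss': 0.123}
if opt.get('wandb_progress_use_raw_steps', False):
    wandb.log(logs, step=current_step)                      # x-axis: optimizer steps
else:
    wandb.log(logs, step=total_training_data_encountered)   # x-axis: samples seen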
@@ -219,7 +232,7 @@ class Trainer:
             if self.rank <= 0:
                 self.logger.info('Saving models and training states.')
                 self.model.save(self.current_step)
-                state = {'epoch': self.epoch, 'iter': self.current_step}
+                state = {'epoch': self.epoch, 'iter': self.current_step, 'total_data_processed': self.total_training_data_encountered}
                 if self.dataset_debugger is not None:
                     state['dataset_debugger_state'] = self.dataset_debugger.get_state()
                 self.model.save_training_state(state)
@@ -231,7 +244,11 @@ class Trainer:
                 shutil.copytree(self.tb_logger_path, alt_tblogger)

         #### validation
-        if opt_get(opt, ['eval', 'pure'], False) and self.current_step % opt['train']['val_freq'] == 0:
+        if 'val_freq' in opt['train'].keys():
+            val_freq = opt['train']['val_freq'] * batch_size
+        else:
+            val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)
+        if opt_get(opt, ['eval', 'pure'], False) and self.total_training_data_encountered % val_freq == 0:
             metrics = []
             for val_data in tqdm(self.val_loader):
                 self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
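Validation frequency can now come from either config key; both are normalised to a sample count before the modulo check. A sketch of that resolution with stand-in dicts (key names mirror the diff above):

# Resolve a validation interval, in samples, from either config style.
def resolve_val_freq_in_samples(opt, batch_size):
    train_opt = opt['train']
    if 'val_freq' in train_opt:
        # Legacy style: an interval in optimizer steps, converted to samples.
        return train_opt['val_freq'] * batch_size
    # New style: an interval expressed directly in megasamples.
    return int(train_opt['val_freq_megasamples'] * 1_000_000)

print(resolve_val_freq_in_samples({'train': {'val_freq': 500}}, batch_size=256))              # 128000
print(resolve_val_freq_in_samples({'train': {'val_freq_megasamples': 0.5}}, batch_size=256))  # 500000

Worth noting: because the trigger is an exact modulo on the running sample count, an interval that is not a multiple of the global batch size fires less often than intended, so val_freq_megasamples is easiest to reason about when it resolves to a multiple of the batch size.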