Begin a migration to specifying training rate in megasamples instead of arbitrary "steps"
This should help me greatly in tuning models. It's also necessary now that batch size isn't really respected; we simply step once the gradient direction becomes unstable.
parent 93ca619267 · commit a930f2576e
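For intuition about the unit change, here is a rough sketch of the conversion involved; the helper names below are illustrative and not part of this commit. With a global batch size B, a schedule expressed in optimizer steps corresponds to steps * B samples, and a megasample budget maps back to roughly megasamples * 1e6 / B steps.

# Illustrative helpers only; not part of this commit.
def steps_to_megasamples(steps: int, global_batch_size: int) -> float:
    """Samples seen after `steps` optimizer steps, in millions."""
    return steps * global_batch_size / 1_000_000

def megasamples_to_steps(megasamples: float, global_batch_size: int) -> int:
    """Approximate step count needed to consume a megasample budget."""
    return int(megasamples * 1_000_000 / global_batch_size)

# Example: 50,000 steps at a global batch size of 256 is 12.8 megasamples.
print(steps_to_megasamples(50_000, 256))  # 12.8
print(megasamples_to_steps(12.8, 256))    # 50000

Unlike a raw step count, the megasample figure stays meaningful when the effective batch size changes, which is exactly the situation described in the commit message.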
@@ -39,6 +39,8 @@ class Trainer:
         self._profile = False
         self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
         self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
+        self.current_step = 0
+        self.total_training_data_encountered = 0

         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
@@ -159,6 +161,7 @@ class Trainer:

             self.start_epoch = resume_state['epoch']
             self.current_step = resume_state['iter']
+            self.total_training_data_encountered = opt_get(resume_state, ['total_data_processed'], 0)
             self.model.resume_training(resume_state, 'amp_opt_level' in opt.keys())  # handle optimizers and schedulers
         else:
             self.current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
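Resuming reads the new counter through opt_get with a default of 0, so training states written before this commit still load. A minimal sketch of that lookup-with-default pattern (this opt_get is a stand-in for the repo helper, which may differ in detail):

# Stand-in for the repo's opt_get helper; illustrative only.
def opt_get(d, keys, default=None):
    """Walk a nested dict along `keys`, returning `default` if any key is missing."""
    for k in keys:
        if not isinstance(d, dict) or k not in d:
            return default
        d = d[k]
    return d

old_state = {'epoch': 3, 'iter': 120000}  # checkpoint written before this commit
new_state = {'epoch': 3, 'iter': 120000, 'total_data_processed': 30_720_000}
print(opt_get(old_state, ['total_data_processed'], 0))  # 0 -> old checkpoints resume cleanly
print(opt_get(new_state, ['total_data_processed'], 0))  # 30720000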
@@ -173,7 +176,11 @@ class Trainer:
             _t = time()

         opt = self.opt
+        batch_size = self.opt['datasets']['train']['batch_size']  # It may seem weird to derive this from opt, rather than train_data. The reason this is done is
+                                                                   # because train_data is process-local while the opt variant represents all of the data fed across all GPUs.
         self.current_step += 1
+        self.total_training_data_encountered += batch_size
+
         #### update learning rate
         self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])

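The inline comment in this hunk is the key subtlety: in multi-GPU training each process's train_data only covers its local share of the batch, while opt['datasets']['train']['batch_size'] describes the full batch fed across all GPUs. A sketch of the distinction, with illustrative numbers:

# Illustrative numbers only.
world_size = 4                              # number of GPUs / processes
per_gpu_batch = 8                           # what each process's train_data actually holds
global_batch = per_gpu_batch * world_size   # what opt['datasets']['train']['batch_size'] describes

steps = 1000
print(steps * per_gpu_batch)   # 8000   -- a process-local count would understate progress
print(steps * global_batch)    # 32000  -- what total_training_data_encountered is meant to track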
@@ -191,7 +198,10 @@ class Trainer:
         if self.dataset_debugger is not None:
             self.dataset_debugger.update(train_data)
         if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
-            logs = self.model.get_current_log(self.current_step)
+            logs = {'step': self.current_step,
+                    'samples': self.total_training_data_encountered,
+                    'megasamples': self.total_training_data_encountered / 1000000}
+            logs.update(self.model.get_current_log(self.current_step))
             if self.dataset_debugger is not None:
                 logs.update(self.dataset_debugger.get_debugging_map())
             message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
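With the new keys, a log entry pins the step counter to an absolute amount of data. A worked example with illustrative numbers, at step 40,000 and a global batch size of 256:

# Illustrative numbers only.
current_step = 40_000
global_batch_size = 256
total_training_data_encountered = current_step * global_batch_size

logs = {'step': current_step,
        'samples': total_training_data_encountered,                  # 10240000
        'megasamples': total_training_data_encountered / 1_000_000}  # 10.24

Two runs with different batch sizes land at very different sample counts after the same number of steps, so the 'samples'/'megasamples' keys are what make their logs comparable.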
@@ -210,7 +220,10 @@ class Trainer:
                     self.tb_logger.add_scalar(k, v, self.current_step)
             if opt['wandb'] and self.rank <= 0:
                 import wandb
-                wandb.log(logs, step=int(self.current_step * opt_get(opt, ['wandb_step_factor'], 1)))
+                if opt_get(opt, ['wandb_progress_use_raw_steps'], False):
+                    wandb.log(logs, step=self.current_step)
+                else:
+                    wandb.log(logs, step=self.total_training_data_encountered)
             self.logger.info(message)

         #### save models and training states
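The wandb branch makes the progress axis configurable: by default the sample counter drives the x-axis, and wandb_progress_use_raw_steps falls back to the old step axis. A minimal sketch of that choice in isolation; wandb.log accepts an explicit integer step, and wandb generally expects that value to be non-decreasing, which both counters satisfy:

import wandb  # assumes a wandb.init(...) call has already happened

def log_progress(logs: dict, current_step: int, total_samples: int, use_raw_steps: bool = False):
    # Mirrors the branch above: pick which counter drives the wandb x-axis.
    if use_raw_steps:
        wandb.log(logs, step=current_step)
    else:
        wandb.log(logs, step=total_samples)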
@@ -219,7 +232,7 @@ class Trainer:
         if self.rank <= 0:
             self.logger.info('Saving models and training states.')
             self.model.save(self.current_step)
-            state = {'epoch': self.epoch, 'iter': self.current_step}
+            state = {'epoch': self.epoch, 'iter': self.current_step, 'total_data_processed': self.total_training_data_encountered}
             if self.dataset_debugger is not None:
                 state['dataset_debugger_state'] = self.dataset_debugger.get_state()
             self.model.save_training_state(state)
@@ -231,7 +244,11 @@ class Trainer:
                 shutil.copytree(self.tb_logger_path, alt_tblogger)

         #### validation
-        if opt_get(opt, ['eval', 'pure'], False) and self.current_step % opt['train']['val_freq'] == 0:
+        if 'val_freq' in opt['train'].keys():
+            val_freq = opt['train']['val_freq'] * batch_size
+        else:
+            val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)
+        if opt_get(opt, ['eval', 'pure'], False) and self.total_training_data_encountered % val_freq == 0:
             metrics = []
             for val_data in tqdm(self.val_loader):
                 self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
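Validation cadence now also works in samples: a legacy step-based val_freq is scaled by the global batch size, while the new val_freq_megasamples option converts directly to a sample count. A sketch of the derivation (the helper name is mine, not the repo's):

# Hypothetical helper mirroring the logic above; not part of the repo.
def resolve_val_freq(train_opt: dict, global_batch_size: int) -> int:
    """Return the validation interval expressed in samples."""
    if 'val_freq' in train_opt:
        return train_opt['val_freq'] * global_batch_size       # legacy: steps -> samples
    return int(train_opt['val_freq_megasamples'] * 1_000_000)  # new: megasamples -> samples

print(resolve_val_freq({'val_freq': 500}, 256))              # 128000 samples
print(resolve_val_freq({'val_freq_megasamples': 0.5}, 256))  # 500000 samples

One thing to watch: the trigger is total_training_data_encountered % val_freq == 0, and that counter advances in increments of the global batch size, so a megasample-derived val_freq that is not a multiple of the batch size may never be hit exactly.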