forked from mrq/DL-Art-School
Begin a migration to specifying training rate on megasamples instead of arbitrary "steps"
This should help me greatly in tuning models. It's also necessary now that batch size isn't really respected; we simply step once the gradient direction becomes unstable.
This commit is contained in:
parent
93ca619267
commit
a930f2576e
|
@ -39,6 +39,8 @@ class Trainer:
|
|||
self._profile = False
|
||||
self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
|
||||
self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
|
||||
self.current_step = 0
|
||||
self.total_training_data_encountered = 0
|
||||
|
||||
#### loading resume state if exists
|
||||
if opt['path'].get('resume_state', None):
|
||||
|
@ -159,6 +161,7 @@ class Trainer:
|
|||
|
||||
self.start_epoch = resume_state['epoch']
|
||||
self.current_step = resume_state['iter']
|
||||
self.total_training_data_encountered = opt_get(resume_state, ['total_data_processed'], 0)
|
||||
self.model.resume_training(resume_state, 'amp_opt_level' in opt.keys()) # handle optimizers and schedulers
|
||||
else:
|
||||
self.current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
|
||||
|
@ -173,7 +176,11 @@ class Trainer:
|
|||
_t = time()
|
||||
|
||||
opt = self.opt
|
||||
batch_size = self.opt['datasets']['train']['batch_size'] # It may seem weird to derive this from opt, rather than train_data. The reason this is done is
|
||||
# because train_data is process-local while the opt variant represents all of the data fed across all GPUs.
|
||||
self.current_step += 1
|
||||
self.total_training_data_encountered += batch_size
|
||||
|
||||
#### update learning rate
|
||||
self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])
|
||||
|
||||
|
@ -191,7 +198,10 @@ class Trainer:
|
|||
if self.dataset_debugger is not None:
|
||||
self.dataset_debugger.update(train_data)
|
||||
if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
|
||||
logs = self.model.get_current_log(self.current_step)
|
||||
logs = {'step': self.current_step,
|
||||
'samples': self.total_training_data_encountered,
|
||||
'megasamples': self.total_training_data_encountered / 1000000}
|
||||
logs.update(self.model.get_current_log(self.current_step))
|
||||
if self.dataset_debugger is not None:
|
||||
logs.update(self.dataset_debugger.get_debugging_map())
|
||||
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
|
||||
|
@ -210,7 +220,10 @@ class Trainer:
|
|||
self.tb_logger.add_scalar(k, v, self.current_step)
|
||||
if opt['wandb'] and self.rank <= 0:
|
||||
import wandb
|
||||
wandb.log(logs, step=int(self.current_step * opt_get(opt, ['wandb_step_factor'], 1)))
|
||||
if opt_get(opt, ['wandb_progress_use_raw_steps'], False):
|
||||
wandb.log(logs, step=self.current_step)
|
||||
else:
|
||||
wandb.log(logs, step=self.total_training_data_encountered)
|
||||
self.logger.info(message)
|
||||
|
||||
#### save models and training states
|
||||
|
@ -219,7 +232,7 @@ class Trainer:
|
|||
if self.rank <= 0:
|
||||
self.logger.info('Saving models and training states.')
|
||||
self.model.save(self.current_step)
|
||||
state = {'epoch': self.epoch, 'iter': self.current_step}
|
||||
state = {'epoch': self.epoch, 'iter': self.current_step, 'total_data_processed': self.total_training_data_encountered}
|
||||
if self.dataset_debugger is not None:
|
||||
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
|
||||
self.model.save_training_state(state)
|
||||
|
@ -231,7 +244,11 @@ class Trainer:
|
|||
shutil.copytree(self.tb_logger_path, alt_tblogger)
|
||||
|
||||
#### validation
|
||||
if opt_get(opt, ['eval', 'pure'], False) and self.current_step % opt['train']['val_freq'] == 0:
|
||||
if 'val_freq' in opt['train'].keys():
|
||||
val_freq = opt['train']['val_freq'] * batch_size
|
||||
else:
|
||||
val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)
|
||||
if opt_get(opt, ['eval', 'pure'], False) and self.total_training_data_encountered % val_freq == 0:
|
||||
metrics = []
|
||||
for val_data in tqdm(self.val_loader):
|
||||
self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
|
||||
|
|
Loading…
Reference in New Issue
Block a user