Fix eval logic so that evaluation does not run immediately

This commit is contained in:
James Betker 2022-04-07 11:29:57 -06:00
parent 305dc95e4b
commit e6387c7613

View File

@@ -37,7 +37,6 @@ class Trainer:
self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
self.current_step = 0
self.total_training_data_encountered = 0
self.next_eval_step = 0
#### loading resume state if exists
if opt['path'].get('resume_state', None):
@@ -169,6 +168,14 @@ class Trainer:
self.total_training_data_encountered = self.current_step * opt['datasets']['train']['batch_size']
opt['current_step'] = self.current_step
#### validation
if 'val_freq' in opt['train'].keys():
self.val_freq = opt['train']['val_freq'] * opt['datasets']['train']['batch_size']
else:
self.val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)
self.next_eval_step = self.total_training_data_encountered + self.val_freq
def do_step(self, train_data):
if self._profile:
print("Data fetch: %f" % (time() - _t))
@@ -249,14 +256,10 @@ class Trainer:
shutil.rmtree(alt_tblogger, ignore_errors=True)
shutil.copytree(self.tb_logger_path, alt_tblogger)
#### validation
if 'val_freq' in opt['train'].keys():
val_freq = opt['train']['val_freq'] * batch_size
else:
val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)
if opt_get(opt, ['eval', 'pure'], False) and self.total_training_data_encountered > self.next_eval_step:
self.next_eval_step = self.total_training_data_encountered + val_freq
do_eval = self.total_training_data_encountered > self.next_eval_step
if do_eval:
self.next_eval_step = self.total_training_data_encountered + self.val_freq
if opt_get(opt, ['eval', 'pure'], False) and do_eval:
metrics = []
for val_data in tqdm(self.val_loader):
self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
@@ -278,7 +281,7 @@ class Trainer:
import wandb
wandb.log({f'eval_{k}': torch.stack(v).mean().item() for k,v in reduced_metrics.items()})
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
if len(self.evaluators) != 0 and do_eval:
eval_dict = {}
for eval in self.evaluators:
if eval.uses_all_ddp or self.rank <= 0: