Fix eval logic to not run immediately
This commit is contained in:
parent
305dc95e4b
commit
e6387c7613
|
@ -37,7 +37,6 @@ class Trainer:
|
|||
self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
|
||||
self.current_step = 0
|
||||
self.total_training_data_encountered = 0
|
||||
self.next_eval_step = 0
|
||||
|
||||
#### loading resume state if exists
|
||||
if opt['path'].get('resume_state', None):
|
||||
|
@ -169,6 +168,14 @@ class Trainer:
|
|||
self.total_training_data_encountered = self.current_step * opt['datasets']['train']['batch_size']
|
||||
opt['current_step'] = self.current_step
|
||||
|
||||
#### validation
|
||||
if 'val_freq' in opt['train'].keys():
|
||||
self.val_freq = opt['train']['val_freq'] * opt['datasets']['train']['batch_size']
|
||||
else:
|
||||
self.val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)
|
||||
|
||||
self.next_eval_step = self.total_training_data_encountered + self.val_freq
|
||||
|
||||
def do_step(self, train_data):
|
||||
if self._profile:
|
||||
print("Data fetch: %f" % (time() - _t))
|
||||
|
@ -249,14 +256,10 @@ class Trainer:
|
|||
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
||||
shutil.copytree(self.tb_logger_path, alt_tblogger)
|
||||
|
||||
#### validation
|
||||
if 'val_freq' in opt['train'].keys():
|
||||
val_freq = opt['train']['val_freq'] * batch_size
|
||||
else:
|
||||
val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)
|
||||
|
||||
if opt_get(opt, ['eval', 'pure'], False) and self.total_training_data_encountered > self.next_eval_step:
|
||||
self.next_eval_step = self.total_training_data_encountered + val_freq
|
||||
do_eval = self.total_training_data_encountered > self.next_eval_step
|
||||
if do_eval:
|
||||
self.next_eval_step = self.total_training_data_encountered + self.val_freq
|
||||
if opt_get(opt, ['eval', 'pure'], False) and do_eval:
|
||||
metrics = []
|
||||
for val_data in tqdm(self.val_loader):
|
||||
self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
|
||||
|
@ -278,7 +281,7 @@ class Trainer:
|
|||
import wandb
|
||||
wandb.log({f'eval_{k}': torch.stack(v).mean().item() for k,v in reduced_metrics.items()})
|
||||
|
||||
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
|
||||
if len(self.evaluators) != 0 and do_eval:
|
||||
eval_dict = {}
|
||||
for eval in self.evaluators:
|
||||
if eval.uses_all_ddp or self.rank <= 0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user