|
|
@ -309,14 +309,6 @@ class Trainer:
|
|
|
|
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
|
|
|
|
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
|
|
|
|
self.model.save_training_state(state)
|
|
|
|
self.model.save_training_state(state)
|
|
|
|
self.logger.info('Saving models and training states.')
|
|
|
|
self.logger.info('Saving models and training states.')
|
|
|
|
"""
|
|
|
|
|
|
|
|
if 'alt_path' in opt['path'].keys():
|
|
|
|
|
|
|
|
import shutil
|
|
|
|
|
|
|
|
print("Synchronizing tb_logger to alt_path..")
|
|
|
|
|
|
|
|
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
|
|
|
|
|
|
|
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
|
|
|
|
|
|
|
shutil.copytree(self.tb_logger_path, alt_tblogger)
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
do_eval = self.total_training_data_encountered > self.next_eval_step
|
|
|
|
do_eval = self.total_training_data_encountered > self.next_eval_step
|
|
|
|
if do_eval:
|
|
|
|
if do_eval:
|
|
|
@ -331,14 +323,6 @@ class Trainer:
|
|
|
|
eval_dict.update(eval.perform_eval())
|
|
|
|
eval_dict.update(eval.perform_eval())
|
|
|
|
if self.rank <= 0:
|
|
|
|
if self.rank <= 0:
|
|
|
|
print("Evaluator results: ", eval_dict)
|
|
|
|
print("Evaluator results: ", eval_dict)
|
|
|
|
"""
|
|
|
|
|
|
|
|
for ek, ev in eval_dict.items():
|
|
|
|
|
|
|
|
self.tb_logger.add_scalar(ek, ev, self.current_step)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if opt['wandb']:
|
|
|
|
|
|
|
|
import wandb
|
|
|
|
|
|
|
|
wandb.log(eval_dict)
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
|
|
|
|
# Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
|
|
|
|
for net in self.model.networks.values():
|
|
|
|
for net in self.model.networks.values():
|
|
|
@ -351,17 +335,20 @@ class Trainer:
|
|
|
|
self.logger.info('Beginning validation.')
|
|
|
|
self.logger.info('Beginning validation.')
|
|
|
|
|
|
|
|
|
|
|
|
metrics = []
|
|
|
|
metrics = []
|
|
|
|
tq_ldr = tqdm(self.train_loader, desc="Validating") if self.use_tqdm else self.train_loader
|
|
|
|
tq_ldr = tqdm(self.val_loader, desc="Validating") if self.use_tqdm else self.val_loader
|
|
|
|
|
|
|
|
|
|
|
|
for val_data in tq_ldr:
|
|
|
|
for val_data in tq_ldr:
|
|
|
|
self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
|
|
|
|
self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
|
|
|
|
metric = self.model.test()
|
|
|
|
metric = self.model.test()
|
|
|
|
metrics.append(metric)
|
|
|
|
metrics.append(metric)
|
|
|
|
if self.rank <= 0:
|
|
|
|
if self.rank <= 0 and self.use_tqdm:
|
|
|
|
logs = process_metrics( metrics )
|
|
|
|
logs = process_metrics( metrics )
|
|
|
|
if self.use_tqdm:
|
|
|
|
tq_ldr.set_postfix( logs, refresh=True )
|
|
|
|
tq_ldr.set_postfix( logs, refresh=True )
|
|
|
|
|
|
|
|
self.logger.info(f'Validation Metrics: {json.dumps(logs)}')
|
|
|
|
if self.rank <= 0:
|
|
|
|
|
|
|
|
logs = process_metrics( metrics )
|
|
|
|
|
|
|
|
logs['it'] = self.current_step
|
|
|
|
|
|
|
|
self.logger.info(f'Validation Metrics: {json.dumps(logs)}')
|
|
|
|
|
|
|
|
|
|
|
|
def do_training(self):
|
|
|
|
def do_training(self):
|
|
|
|
if self.rank <= 0:
|
|
|
|
if self.rank <= 0:
|
|
|
|