mrq 2023-03-11 01:18:25 +00:00
parent b5c6acec9e
commit 3fdf2a63aa

@@ -309,14 +309,6 @@ class Trainer:
                 state['dataset_debugger_state'] = self.dataset_debugger.get_state()
             self.model.save_training_state(state)
             self.logger.info('Saving models and training states.')
-            """
-            if 'alt_path' in opt['path'].keys():
-                import shutil
-                print("Synchronizing tb_logger to alt_path..")
-                alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
-                shutil.rmtree(alt_tblogger, ignore_errors=True)
-                shutil.copytree(self.tb_logger_path, alt_tblogger)
-            """
         do_eval = self.total_training_data_encountered > self.next_eval_step
         if do_eval:
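Note: the block deleted above had already been disabled by wrapping it in a docstring; this commit removes it outright. For reference, a standalone sketch of what that synchronization did, assuming opt['path']['alt_path'] and self.tb_logger_path are configured as in the original (the helper name sync_tb_logger is hypothetical; the original inlined this logic in the save path):

    import os
    import shutil

    def sync_tb_logger(opt, tb_logger_path):
        # Mirror the TensorBoard event files to an optional alternate output
        # directory named in the training config.
        if 'alt_path' in opt['path']:
            alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
            shutil.rmtree(alt_tblogger, ignore_errors=True)  # drop the stale copy first
            shutil.copytree(tb_logger_path, alt_tblogger)    # then re-copy everything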
@@ -331,14 +323,6 @@ class Trainer:
                     eval_dict.update(eval.perform_eval())
                 if self.rank <= 0:
                     print("Evaluator results: ", eval_dict)
-                    """
-                    for ek, ev in eval_dict.items():
-                        self.tb_logger.add_scalar(ek, ev, self.current_step)
-                    if opt['wandb']:
-                        import wandb
-                        wandb.log(eval_dict)
-                    """
         # Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
         for net in self.model.networks.values():
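Note: the deleted block here was likewise already commented out. It forwarded each evaluator scalar to TensorBoard and, when enabled, mirrored the same dict to Weights & Biases. A minimal sketch of that behavior, assuming tb_logger is a torch.utils.tensorboard SummaryWriter (the helper name log_eval_metrics is hypothetical; the original inlined this in the eval path):

    def log_eval_metrics(tb_logger, eval_dict, current_step, use_wandb=False):
        # Push each evaluator result to TensorBoard under its own scalar tag.
        for ek, ev in eval_dict.items():
            tb_logger.add_scalar(ek, ev, current_step)
        if use_wandb:
            import wandb
            wandb.log(eval_dict)  # mirror the full dict to W&B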
@@ -351,17 +335,20 @@ class Trainer:
             self.logger.info('Beginning validation.')
             metrics = []
-            tq_ldr = tqdm(self.train_loader, desc="Validating") if self.use_tqdm else self.train_loader
+            tq_ldr = tqdm(self.val_loader, desc="Validating") if self.use_tqdm else self.val_loader
             for val_data in tq_ldr:
                 self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
                 metric = self.model.test()
                 metrics.append(metric)
-                if self.rank <= 0:
+                if self.rank <= 0 and self.use_tqdm:
                     logs = process_metrics( metrics )
-                    if self.use_tqdm:
-                        tq_ldr.set_postfix( logs, refresh=True )
-                    self.logger.info(f'Validation Metrics: {json.dumps(logs)}')
+                    tq_ldr.set_postfix( logs, refresh=True )
+            if self.rank <= 0:
+                logs = process_metrics( metrics )
+                logs['it'] = self.current_step
+                self.logger.info(f'Validation Metrics: {json.dumps(logs)}')

     def do_training(self):
         if self.rank <= 0:
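Note: this hunk fixes two things. Validation now iterates self.val_loader instead of self.train_loader (the old code was validating against training data), and the aggregated metrics are logged once after the loop rather than on every batch; the per-batch work on rank 0 is reduced to updating the tqdm postfix. The aggregation is delegated to process_metrics, whose implementation is not shown in this diff; a plausible, purely hypothetical reduction that averages each key across batches, included only to make the loop self-explanatory:

    def process_metrics(metrics):
        # Hypothetical stand-in for the repository's real helper: average
        # every metric key across the per-batch dicts accumulated so far.
        return {k: sum(m[k] for m in metrics) / len(metrics) for k in metrics[0]}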