fixes
This commit is contained in:
parent
b5c6acec9e
commit
3fdf2a63aa
|
@ -309,14 +309,6 @@ class Trainer:
|
||||||
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
|
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
|
||||||
self.model.save_training_state(state)
|
self.model.save_training_state(state)
|
||||||
self.logger.info('Saving models and training states.')
|
self.logger.info('Saving models and training states.')
|
||||||
"""
|
|
||||||
if 'alt_path' in opt['path'].keys():
|
|
||||||
import shutil
|
|
||||||
print("Synchronizing tb_logger to alt_path..")
|
|
||||||
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
|
||||||
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
|
||||||
shutil.copytree(self.tb_logger_path, alt_tblogger)
|
|
||||||
"""
|
|
||||||
|
|
||||||
do_eval = self.total_training_data_encountered > self.next_eval_step
|
do_eval = self.total_training_data_encountered > self.next_eval_step
|
||||||
if do_eval:
|
if do_eval:
|
||||||
|
@ -331,14 +323,6 @@ class Trainer:
|
||||||
eval_dict.update(eval.perform_eval())
|
eval_dict.update(eval.perform_eval())
|
||||||
if self.rank <= 0:
|
if self.rank <= 0:
|
||||||
print("Evaluator results: ", eval_dict)
|
print("Evaluator results: ", eval_dict)
|
||||||
"""
|
|
||||||
for ek, ev in eval_dict.items():
|
|
||||||
self.tb_logger.add_scalar(ek, ev, self.current_step)
|
|
||||||
|
|
||||||
if opt['wandb']:
|
|
||||||
import wandb
|
|
||||||
wandb.log(eval_dict)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
|
# Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
|
||||||
for net in self.model.networks.values():
|
for net in self.model.networks.values():
|
||||||
|
@ -351,16 +335,19 @@ class Trainer:
|
||||||
self.logger.info('Beginning validation.')
|
self.logger.info('Beginning validation.')
|
||||||
|
|
||||||
metrics = []
|
metrics = []
|
||||||
tq_ldr = tqdm(self.train_loader, desc="Validating") if self.use_tqdm else self.train_loader
|
tq_ldr = tqdm(self.val_loader, desc="Validating") if self.use_tqdm else self.val_loader
|
||||||
|
|
||||||
for val_data in tq_ldr:
|
for val_data in tq_ldr:
|
||||||
self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
|
self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
|
||||||
metric = self.model.test()
|
metric = self.model.test()
|
||||||
metrics.append(metric)
|
metrics.append(metric)
|
||||||
|
if self.rank <= 0 and self.use_tqdm:
|
||||||
|
logs = process_metrics( metrics )
|
||||||
|
tq_ldr.set_postfix( logs, refresh=True )
|
||||||
|
|
||||||
if self.rank <= 0:
|
if self.rank <= 0:
|
||||||
logs = process_metrics( metrics )
|
logs = process_metrics( metrics )
|
||||||
if self.use_tqdm:
|
logs['it'] = self.current_step
|
||||||
tq_ldr.set_postfix( logs, refresh=True )
|
|
||||||
self.logger.info(f'Validation Metrics: {json.dumps(logs)}')
|
self.logger.info(f'Validation Metrics: {json.dumps(logs)}')
|
||||||
|
|
||||||
def do_training(self):
|
def do_training(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user