forked from mrq/DL-Art-School
fixes
This commit is contained in:
parent
b5c6acec9e
commit
3fdf2a63aa
|
@ -309,14 +309,6 @@ class Trainer:
|
|||
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
|
||||
self.model.save_training_state(state)
|
||||
self.logger.info('Saving models and training states.')
|
||||
"""
|
||||
if 'alt_path' in opt['path'].keys():
|
||||
import shutil
|
||||
print("Synchronizing tb_logger to alt_path..")
|
||||
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
||||
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
||||
shutil.copytree(self.tb_logger_path, alt_tblogger)
|
||||
"""
|
||||
|
||||
do_eval = self.total_training_data_encountered > self.next_eval_step
|
||||
if do_eval:
|
||||
|
@ -331,14 +323,6 @@ class Trainer:
|
|||
eval_dict.update(eval.perform_eval())
|
||||
if self.rank <= 0:
|
||||
print("Evaluator results: ", eval_dict)
|
||||
"""
|
||||
for ek, ev in eval_dict.items():
|
||||
self.tb_logger.add_scalar(ek, ev, self.current_step)
|
||||
|
||||
if opt['wandb']:
|
||||
import wandb
|
||||
wandb.log(eval_dict)
|
||||
"""
|
||||
|
||||
# Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
|
||||
for net in self.model.networks.values():
|
||||
|
@ -351,17 +335,20 @@ class Trainer:
|
|||
self.logger.info('Beginning validation.')
|
||||
|
||||
metrics = []
|
||||
tq_ldr = tqdm(self.train_loader, desc="Validating") if self.use_tqdm else self.train_loader
|
||||
tq_ldr = tqdm(self.val_loader, desc="Validating") if self.use_tqdm else self.val_loader
|
||||
|
||||
for val_data in tq_ldr:
|
||||
self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
|
||||
metric = self.model.test()
|
||||
metrics.append(metric)
|
||||
if self.rank <= 0:
|
||||
if self.rank <= 0 and self.use_tqdm:
|
||||
logs = process_metrics( metrics )
|
||||
if self.use_tqdm:
|
||||
tq_ldr.set_postfix( logs, refresh=True )
|
||||
self.logger.info(f'Validation Metrics: {json.dumps(logs)}')
|
||||
tq_ldr.set_postfix( logs, refresh=True )
|
||||
|
||||
if self.rank <= 0:
|
||||
logs = process_metrics( metrics )
|
||||
logs['it'] = self.current_step
|
||||
self.logger.info(f'Validation Metrics: {json.dumps(logs)}')
|
||||
|
||||
def do_training(self):
|
||||
if self.rank <= 0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user