gutted bloat loggers, now all my useful metrics update per step
This commit is contained in:
parent
bf94744514
commit
b5c6acec9e
202
codes/train.py
202
codes/train.py
|
@ -1,9 +1,11 @@
|
|||
import os
|
||||
import sys
|
||||
import math
|
||||
import argparse
|
||||
import random
|
||||
import logging
|
||||
import shutil
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@ -21,6 +23,33 @@ from utils.util import opt_get, map_cuda_to_correct_device
|
|||
|
||||
import tortoise.utils.torch_intermediary as ml
|
||||
|
||||
def try_json( data ):
|
||||
reduced = {}
|
||||
for k, v in data.items():
|
||||
try:
|
||||
json.dumps(v)
|
||||
except Exception as e:
|
||||
continue
|
||||
reduced[k] = v
|
||||
return json.dumps(reduced)
|
||||
|
||||
def process_metrics( metrics ):
|
||||
reduced = {}
|
||||
for metric in metrics:
|
||||
d = metric.as_dict() if hasattr(metric, 'as_dict') else metric
|
||||
for k, v in d.items():
|
||||
if isinstance(v, torch.Tensor) and len(v.shape) == 0:
|
||||
if k in reduced.keys():
|
||||
reduced[k].append(v)
|
||||
else:
|
||||
reduced[k] = [v]
|
||||
logs = {}
|
||||
|
||||
for k, v in reduced.items():
|
||||
logs[k] = torch.stack(v).mean().item()
|
||||
|
||||
return logs
|
||||
|
||||
def init_dist(backend, **kwargs):
|
||||
# These packages have globals that screw with Windows, so only import them if needed.
|
||||
import torch.distributed as dist
|
||||
|
@ -32,13 +61,16 @@ def init_dist(backend, **kwargs):
|
|||
|
||||
class Trainer:
|
||||
|
||||
def init(self, opt_path, opt, launcher):
|
||||
def init(self, opt_path, opt, launcher, mode):
|
||||
self._profile = False
|
||||
self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
|
||||
self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
|
||||
self.current_step = 0
|
||||
self.iteration_rate = 0
|
||||
self.total_training_data_encountered = 0
|
||||
|
||||
self.use_tqdm = False # self.rank <= 0
|
||||
|
||||
#### loading resume state if exists
|
||||
if opt['path'].get('resume_state', None):
|
||||
# distributed resuming: all load into default GPU
|
||||
|
@ -57,17 +89,11 @@ class Trainer:
|
|||
shutil.copy(opt_path, os.path.join(opt['path']['experiments_root'], f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'))
|
||||
|
||||
# config loggers. Before it, the log will not work
|
||||
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
||||
screen=True, tofile=True)
|
||||
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True)
|
||||
self.logger = logging.getLogger('base')
|
||||
self.logger.info(option.dict2str(opt))
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
|
||||
else:
|
||||
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
||||
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=False)
|
||||
self.logger = logging.getLogger('base')
|
||||
|
||||
if resume_state is not None:
|
||||
|
@ -77,14 +103,6 @@ class Trainer:
|
|||
opt = option.dict_to_nonedict(opt)
|
||||
self.opt = opt
|
||||
|
||||
#### wandb init
|
||||
if opt['wandb'] and self.rank <= 0:
|
||||
import wandb
|
||||
os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
|
||||
project_name = opt_get(opt, ['wandb_project_name'], opt['name'])
|
||||
run_name = opt_get(opt, ['wandb_run_name'], None)
|
||||
wandb.init(project=project_name, dir=opt['path']['log'], config=opt, name=run_name)
|
||||
|
||||
#### random seed
|
||||
seed = opt['train']['manual_seed']
|
||||
if seed is None:
|
||||
|
@ -197,7 +215,7 @@ class Trainer:
|
|||
# because train_data is process-local while the opt variant represents all of the data fed across all GPUs.
|
||||
self.current_step += 1
|
||||
self.total_training_data_encountered += batch_size
|
||||
will_log = self.current_step % opt['logger']['print_freq'] == 0
|
||||
will_log = False # self.current_step % opt['logger']['print_freq'] == 0
|
||||
|
||||
#### update learning rate
|
||||
self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])
|
||||
|
@ -208,11 +226,15 @@ class Trainer:
|
|||
_t = time()
|
||||
self.model.feed_data(train_data, self.current_step)
|
||||
gradient_norms_dict = self.model.optimize_parameters(self.current_step, return_grad_norms=will_log)
|
||||
iteration_rate = (time() - _t) / batch_size
|
||||
self.iteration_rate = (time() - _t) / batch_size
|
||||
if self._profile:
|
||||
print("Model feed + step: %f" % (time() - _t))
|
||||
_t = time()
|
||||
|
||||
metrics = {}
|
||||
for s in self.model.steps:
|
||||
metrics.update(s.get_metrics())
|
||||
|
||||
#### log
|
||||
if self.dataset_debugger is not None:
|
||||
self.dataset_debugger.update(train_data)
|
||||
|
@ -220,6 +242,22 @@ class Trainer:
|
|||
# Must be run by all instances to gather consensus.
|
||||
current_model_logs = self.model.get_current_log(self.current_step)
|
||||
if will_log and self.rank <= 0:
|
||||
logs = {
|
||||
'step': self.current_step,
|
||||
'samples': self.total_training_data_encountered,
|
||||
'megasamples': self.total_training_data_encountered / 1000000,
|
||||
'iteration_rate': self.iteration_rate,
|
||||
'lr': self.model.get_current_learning_rate(),
|
||||
}
|
||||
logs.update(current_model_logs)
|
||||
|
||||
if self.dataset_debugger is not None:
|
||||
logs.update(self.dataset_debugger.get_debugging_map())
|
||||
|
||||
logs.update(gradient_norms_dict)
|
||||
self.logger.info(f'Training Metrics: {try_json(logs)}')
|
||||
|
||||
"""
|
||||
logs = {'step': self.current_step,
|
||||
'samples': self.total_training_data_encountered,
|
||||
'megasamples': self.total_training_data_encountered / 1000000,
|
||||
|
@ -255,78 +293,108 @@ class Trainer:
|
|||
else:
|
||||
wandb.log(wandb_logs, step=self.total_training_data_encountered)
|
||||
self.logger.info(message)
|
||||
"""
|
||||
|
||||
#### save models and training states
|
||||
if self.current_step > 0 and self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||
self.model.consolidate_state()
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Saving models and training states.')
|
||||
self.model.save(self.current_step)
|
||||
state = {'epoch': self.epoch, 'iter': self.current_step, 'total_data_processed': self.total_training_data_encountered}
|
||||
state = {
|
||||
'epoch': self.epoch,
|
||||
'iter': self.current_step,
|
||||
'total_data_processed': self.total_training_data_encountered
|
||||
}
|
||||
if self.dataset_debugger is not None:
|
||||
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
|
||||
self.model.save_training_state(state)
|
||||
self.logger.info('Saving models and training states.')
|
||||
"""
|
||||
if 'alt_path' in opt['path'].keys():
|
||||
import shutil
|
||||
print("Synchronizing tb_logger to alt_path..")
|
||||
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
||||
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
||||
shutil.copytree(self.tb_logger_path, alt_tblogger)
|
||||
"""
|
||||
|
||||
do_eval = self.total_training_data_encountered > self.next_eval_step
|
||||
if do_eval:
|
||||
self.next_eval_step = self.total_training_data_encountered + self.val_freq
|
||||
if opt_get(opt, ['eval', 'pure'], False) and do_eval:
|
||||
metrics = []
|
||||
for val_data in tqdm(self.val_loader):
|
||||
self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
|
||||
metrics.append(self.model.test())
|
||||
reduced_metrics = {}
|
||||
for metric in metrics:
|
||||
for k, v in metric.as_dict().items():
|
||||
if isinstance(v, torch.Tensor) and len(v.shape) == 0:
|
||||
if k in reduced_metrics.keys():
|
||||
reduced_metrics[k].append(v)
|
||||
else:
|
||||
reduced_metrics[k] = [v]
|
||||
if self.rank <= 0:
|
||||
for k, v in reduced_metrics.items():
|
||||
val = torch.stack(v).mean().item()
|
||||
self.tb_logger.add_scalar(f'val_{k}', val, self.current_step)
|
||||
print(f">>Eval {k}: {val}")
|
||||
if opt['wandb']:
|
||||
import wandb
|
||||
wandb.log({f'eval_{k}': torch.stack(v).mean().item() for k,v in reduced_metrics.items()})
|
||||
|
||||
if len(self.evaluators) != 0 and do_eval:
|
||||
eval_dict = {}
|
||||
for eval in self.evaluators:
|
||||
if eval.uses_all_ddp or self.rank <= 0:
|
||||
eval_dict.update(eval.perform_eval())
|
||||
if self.rank <= 0:
|
||||
print("Evaluator results: ", eval_dict)
|
||||
for ek, ev in eval_dict.items():
|
||||
self.tb_logger.add_scalar(ek, ev, self.current_step)
|
||||
if opt['wandb']:
|
||||
import wandb
|
||||
wandb.log(eval_dict)
|
||||
if opt_get(opt, ['eval', 'pure'], False):
|
||||
self.do_validation()
|
||||
if len(self.evaluators) != 0:
|
||||
eval_dict = {}
|
||||
for eval in self.evaluators:
|
||||
if eval.uses_all_ddp or self.rank <= 0:
|
||||
eval_dict.update(eval.perform_eval())
|
||||
if self.rank <= 0:
|
||||
print("Evaluator results: ", eval_dict)
|
||||
"""
|
||||
for ek, ev in eval_dict.items():
|
||||
self.tb_logger.add_scalar(ek, ev, self.current_step)
|
||||
|
||||
if opt['wandb']:
|
||||
import wandb
|
||||
wandb.log(eval_dict)
|
||||
"""
|
||||
|
||||
# Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
|
||||
for net in self.model.networks.values():
|
||||
net.zero_grad()
|
||||
|
||||
return metrics
|
||||
|
||||
def do_validation(self):
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Beginning validation.')
|
||||
|
||||
metrics = []
|
||||
tq_ldr = tqdm(self.train_loader, desc="Validating") if self.use_tqdm else self.train_loader
|
||||
|
||||
for val_data in tq_ldr:
|
||||
self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
|
||||
metric = self.model.test()
|
||||
metrics.append(metric)
|
||||
if self.rank <= 0:
|
||||
logs = process_metrics( metrics )
|
||||
if self.use_tqdm:
|
||||
tq_ldr.set_postfix( logs, refresh=True )
|
||||
self.logger.info(f'Validation Metrics: {json.dumps(logs)}')
|
||||
|
||||
def do_training(self):
|
||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
||||
|
||||
for epoch in range(self.start_epoch, self.total_epochs + 1):
|
||||
self.epoch = epoch
|
||||
if self.opt['dist']:
|
||||
self.train_sampler.set_epoch(epoch)
|
||||
|
||||
tq_ldr = tqdm(self.train_loader) if self.rank <= 0 else self.train_loader
|
||||
metrics = []
|
||||
tq_ldr = tqdm(self.train_loader, desc="Training") if self.use_tqdm else self.train_loader
|
||||
|
||||
_t = time()
|
||||
step = 0
|
||||
for train_data in tq_ldr:
|
||||
self.do_step(train_data)
|
||||
step = step + 1
|
||||
metric = self.do_step(train_data)
|
||||
metrics.append(metric)
|
||||
if self.rank <= 0:
|
||||
logs = process_metrics( metrics )
|
||||
logs['lr'] = self.model.get_current_learning_rate()[0]
|
||||
if self.use_tqdm:
|
||||
tq_ldr.set_postfix( logs, refresh=True )
|
||||
logs['it'] = self.current_step
|
||||
logs['step'] = step
|
||||
logs['steps'] = len(self.train_loader)
|
||||
logs['epoch'] = self.epoch
|
||||
logs['iteration_rate'] = self.iteration_rate
|
||||
self.logger.info(f'Training Metrics: {json.dumps(logs)}')
|
||||
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Finished training!')
|
||||
|
||||
def create_training_generator(self, index):
|
||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
||||
|
@ -334,26 +402,31 @@ class Trainer:
|
|||
self.epoch = epoch
|
||||
if self.opt['dist']:
|
||||
self.train_sampler.set_epoch(epoch)
|
||||
|
||||
tq_ldr = tqdm(self.train_loader, position=index)
|
||||
tq_ldr.set_description('Training')
|
||||
|
||||
_t = time()
|
||||
for train_data in tq_ldr:
|
||||
yield self.model
|
||||
self.do_step(train_data)
|
||||
|
||||
metric = self.do_step(train_data)
|
||||
self.logger.info('Finished training')
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--mode', type=str, default='', help='Handles printing info')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
||||
if args.launcher != 'none':
|
||||
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
|
||||
if 'gpu_ids' in opt.keys():
|
||||
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
||||
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
||||
|
||||
trainer = Trainer()
|
||||
|
||||
#### distributed training settings
|
||||
|
@ -370,5 +443,10 @@ if __name__ == '__main__':
|
|||
trainer.rank = torch.distributed.get_rank()
|
||||
torch.cuda.set_device(torch.distributed.get_rank())
|
||||
|
||||
trainer.init(args.opt, opt, args.launcher)
|
||||
trainer.do_training()
|
||||
if trainer.rank >= 1:
|
||||
f = open(os.devnull, 'w')
|
||||
sys.stdout = f
|
||||
sys.stderr = f
|
||||
|
||||
trainer.init(args.opt, opt, args.launcher, args.mode)
|
||||
trainer.do_training()
|
Loading…
Reference in New Issue
Block a user