gutted bloat loggers, now all my useful metrics update per step

mrq 2023-03-10 22:34:37 +07:00
parent bf94744514
commit b5c6acec9e
1 changed file with 141 additions and 63 deletions

@@ -1,9 +1,11 @@
 import os
+import sys
 import math
 import argparse
 import random
 import logging
 import shutil
+import json
 from tqdm import tqdm
@@ -21,6 +23,33 @@ from utils.util import opt_get, map_cuda_to_correct_device
 import tortoise.utils.torch_intermediary as ml
 
+def try_json( data ):
+    reduced = {}
+    for k, v in data.items():
+        try:
+            json.dumps(v)
+        except Exception as e:
+            continue
+        reduced[k] = v
+    return json.dumps(reduced)
+
+def process_metrics( metrics ):
+    reduced = {}
+    for metric in metrics:
+        d = metric.as_dict() if hasattr(metric, 'as_dict') else metric
+        for k, v in d.items():
+            if isinstance(v, torch.Tensor) and len(v.shape) == 0:
+                if k in reduced.keys():
+                    reduced[k].append(v)
+                else:
+                    reduced[k] = [v]
+    logs = {}
+    for k, v in reduced.items():
+        logs[k] = torch.stack(v).mean().item()
+    return logs
+
 def init_dist(backend, **kwargs):
     # These packages have globals that screw with Windows, so only import them if needed.
     import torch.distributed as dist
@@ -32,13 +61,16 @@ def init_dist(backend, **kwargs):
 class Trainer:
-    def init(self, opt_path, opt, launcher):
+    def init(self, opt_path, opt, launcher, mode):
         self._profile = False
         self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
         self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
         self.current_step = 0
+        self.iteration_rate = 0
         self.total_training_data_encountered = 0
+        self.use_tqdm = False # self.rank <= 0
 
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
             # distributed resuming: all load into default GPU
@@ -57,17 +89,11 @@ class Trainer:
             shutil.copy(opt_path, os.path.join(opt['path']['experiments_root'], f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'))
 
             # config loggers. Before it, the log will not work
-            util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
-                              screen=True, tofile=True)
+            util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True)
             self.logger = logging.getLogger('base')
             self.logger.info(option.dict2str(opt))
-            # tensorboard logger
-            if opt['use_tb_logger'] and 'debug' not in opt['name']:
-                self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
-                from torch.utils.tensorboard import SummaryWriter
-                self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
         else:
-            util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
+            util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=False)
             self.logger = logging.getLogger('base')
 
         if resume_state is not None:
@@ -77,14 +103,6 @@ class Trainer:
         opt = option.dict_to_nonedict(opt)
         self.opt = opt
 
-        #### wandb init
-        if opt['wandb'] and self.rank <= 0:
-            import wandb
-            os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
-            project_name = opt_get(opt, ['wandb_project_name'], opt['name'])
-            run_name = opt_get(opt, ['wandb_run_name'], None)
-            wandb.init(project=project_name, dir=opt['path']['log'], config=opt, name=run_name)
-
         #### random seed
         seed = opt['train']['manual_seed']
         if seed is None:
@@ -197,7 +215,7 @@ class Trainer:
         # because train_data is process-local while the opt variant represents all of the data fed across all GPUs.
         self.current_step += 1
         self.total_training_data_encountered += batch_size
-        will_log = self.current_step % opt['logger']['print_freq'] == 0
+        will_log = False # self.current_step % opt['logger']['print_freq'] == 0
 
         #### update learning rate
         self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])
@@ -208,11 +226,15 @@ class Trainer:
             _t = time()
         self.model.feed_data(train_data, self.current_step)
         gradient_norms_dict = self.model.optimize_parameters(self.current_step, return_grad_norms=will_log)
-        iteration_rate = (time() - _t) / batch_size
+        self.iteration_rate = (time() - _t) / batch_size
         if self._profile:
             print("Model feed + step: %f" % (time() - _t))
             _t = time()
 
+        metrics = {}
+        for s in self.model.steps:
+            metrics.update(s.get_metrics())
+
         #### log
         if self.dataset_debugger is not None:
             self.dataset_debugger.update(train_data)
@@ -220,6 +242,22 @@ class Trainer:
             # Must be run by all instances to gather consensus.
             current_model_logs = self.model.get_current_log(self.current_step)
         if will_log and self.rank <= 0:
+            logs = {
+                'step': self.current_step,
+                'samples': self.total_training_data_encountered,
+                'megasamples': self.total_training_data_encountered / 1000000,
+                'iteration_rate': self.iteration_rate,
+                'lr': self.model.get_current_learning_rate(),
+            }
+            logs.update(current_model_logs)
+            if self.dataset_debugger is not None:
+                logs.update(self.dataset_debugger.get_debugging_map())
+            logs.update(gradient_norms_dict)
+            self.logger.info(f'Training Metrics: {try_json(logs)}')
+            """
             logs = {'step': self.current_step,
                     'samples': self.total_training_data_encountered,
                     'megasamples': self.total_training_data_encountered / 1000000,
@@ -255,78 +293,108 @@ class Trainer:
                 else:
                     wandb.log(wandb_logs, step=self.total_training_data_encountered)
             self.logger.info(message)
+            """
 
         #### save models and training states
         if self.current_step > 0 and self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
             self.model.consolidate_state()
             if self.rank <= 0:
-                self.logger.info('Saving models and training states.')
                 self.model.save(self.current_step)
-                state = {'epoch': self.epoch, 'iter': self.current_step, 'total_data_processed': self.total_training_data_encountered}
+                state = {
+                    'epoch': self.epoch,
+                    'iter': self.current_step,
+                    'total_data_processed': self.total_training_data_encountered
+                }
                 if self.dataset_debugger is not None:
                     state['dataset_debugger_state'] = self.dataset_debugger.get_state()
                 self.model.save_training_state(state)
+                self.logger.info('Saving models and training states.')
+                """
                 if 'alt_path' in opt['path'].keys():
                     import shutil
                     print("Synchronizing tb_logger to alt_path..")
                     alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
                     shutil.rmtree(alt_tblogger, ignore_errors=True)
                     shutil.copytree(self.tb_logger_path, alt_tblogger)
+                """
 
         do_eval = self.total_training_data_encountered > self.next_eval_step
         if do_eval:
             self.next_eval_step = self.total_training_data_encountered + self.val_freq
-        if opt_get(opt, ['eval', 'pure'], False) and do_eval:
-            metrics = []
-            for val_data in tqdm(self.val_loader):
-                self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
-                metrics.append(self.model.test())
-            reduced_metrics = {}
-            for metric in metrics:
-                for k, v in metric.as_dict().items():
-                    if isinstance(v, torch.Tensor) and len(v.shape) == 0:
-                        if k in reduced_metrics.keys():
-                            reduced_metrics[k].append(v)
-                        else:
-                            reduced_metrics[k] = [v]
-            if self.rank <= 0:
-                for k, v in reduced_metrics.items():
-                    val = torch.stack(v).mean().item()
-                    self.tb_logger.add_scalar(f'val_{k}', val, self.current_step)
-                    print(f">>Eval {k}: {val}")
-                if opt['wandb']:
-                    import wandb
-                    wandb.log({f'eval_{k}': torch.stack(v).mean().item() for k,v in reduced_metrics.items()})
+            if opt_get(opt, ['eval', 'pure'], False):
+                self.do_validation()
 
-        if len(self.evaluators) != 0 and do_eval:
+        if len(self.evaluators) != 0:
             eval_dict = {}
             for eval in self.evaluators:
                 if eval.uses_all_ddp or self.rank <= 0:
                     eval_dict.update(eval.perform_eval())
             if self.rank <= 0:
                 print("Evaluator results: ", eval_dict)
+                """
                 for ek, ev in eval_dict.items():
                     self.tb_logger.add_scalar(ek, ev, self.current_step)
                 if opt['wandb']:
                     import wandb
                     wandb.log(eval_dict)
+                """
 
         # Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
         for net in self.model.networks.values():
             net.zero_grad()
+
+        return metrics
+
+    def do_validation(self):
+        if self.rank <= 0:
+            self.logger.info('Beginning validation.')
+
+        metrics = []
+        tq_ldr = tqdm(self.train_loader, desc="Validating") if self.use_tqdm else self.train_loader
+        for val_data in tq_ldr:
+            self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
+            metric = self.model.test()
+            metrics.append(metric)
+            if self.rank <= 0:
+                logs = process_metrics( metrics )
+                if self.use_tqdm:
+                    tq_ldr.set_postfix( logs, refresh=True )
+                self.logger.info(f'Validation Metrics: {json.dumps(logs)}')
 
     def do_training(self):
-        self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
+        if self.rank <= 0:
+            self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
+
         for epoch in range(self.start_epoch, self.total_epochs + 1):
             self.epoch = epoch
             if self.opt['dist']:
                 self.train_sampler.set_epoch(epoch)
-            tq_ldr = tqdm(self.train_loader) if self.rank <= 0 else self.train_loader
+
+            metrics = []
+            tq_ldr = tqdm(self.train_loader, desc="Training") if self.use_tqdm else self.train_loader
 
             _t = time()
+            step = 0
             for train_data in tq_ldr:
-                self.do_step(train_data)
+                step = step + 1
+                metric = self.do_step(train_data)
+                metrics.append(metric)
+                if self.rank <= 0:
+                    logs = process_metrics( metrics )
+                    logs['lr'] = self.model.get_current_learning_rate()[0]
+                    if self.use_tqdm:
+                        tq_ldr.set_postfix( logs, refresh=True )
+                    logs['it'] = self.current_step
+                    logs['step'] = step
+                    logs['steps'] = len(self.train_loader)
+                    logs['epoch'] = self.epoch
+                    logs['iteration_rate'] = self.iteration_rate
+                    self.logger.info(f'Training Metrics: {json.dumps(logs)}')
+
+        if self.rank <= 0:
+            self.logger.info('Finished training!')
 
     def create_training_generator(self, index):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
@@ -334,26 +402,31 @@ class Trainer:
             self.epoch = epoch
             if self.opt['dist']:
                 self.train_sampler.set_epoch(epoch)
             tq_ldr = tqdm(self.train_loader, position=index)
+            tq_ldr.set_description('Training')
 
             _t = time()
             for train_data in tq_ldr:
                 yield self.model
-                self.do_step(train_data)
+                metric = self.do_step(train_data)
+
+        self.logger.info('Finished training')
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
+    parser.add_argument('--mode', type=str, default='', help='Handles printing info')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
     if args.launcher != 'none':
         # export CUDA_VISIBLE_DEVICES for running in distributed mode.
         if 'gpu_ids' in opt.keys():
             gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
             os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
             print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
     trainer = Trainer()
 
     #### distributed training settings
@@ -370,5 +443,10 @@ if __name__ == '__main__':
         trainer.rank = torch.distributed.get_rank()
         torch.cuda.set_device(torch.distributed.get_rank())
 
-    trainer.init(args.opt, opt, args.launcher)
+    if trainer.rank >= 1:
+        f = open(os.devnull, 'w')
+        sys.stdout = f
+        sys.stderr = f
+
+    trainer.init(args.opt, opt, args.launcher, args.mode)
     trainer.do_training()
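
For reference, here is a minimal standalone sketch of the per-step metric aggregation this commit introduces. The process_metrics helper mirrors the one added in the diff above (condensed slightly); the sample metric dicts and their keys are invented for illustration, since in the trainer the dicts come from self.model.test() and step.get_metrics().

# Standalone sketch of the per-step metric aggregation added in this commit.
# The fake_steps dicts below are made up; only process_metrics reflects the diff.
import json
import torch

def process_metrics(metrics):
    # Gather every scalar (0-dim) tensor under its key, then average across steps.
    reduced = {}
    for metric in metrics:
        d = metric.as_dict() if hasattr(metric, 'as_dict') else metric
        for k, v in d.items():
            if isinstance(v, torch.Tensor) and len(v.shape) == 0:
                reduced.setdefault(k, []).append(v)
    return {k: torch.stack(v).mean().item() for k, v in reduced.items()}

fake_steps = [
    {'loss_text_ce': torch.tensor(4.2), 'loss_mel_ce': torch.tensor(2.1)},
    {'loss_text_ce': torch.tensor(4.0), 'loss_mel_ce': torch.tensor(1.9)},
]
print(json.dumps(process_metrics(fake_steps)))
# prints approximately {"loss_text_ce": 4.1, "loss_mel_ce": 2.0}

Because the running list of step metrics is re-reduced every iteration, each "Training Metrics:" line logs the epoch-so-far average rather than a single noisy step value.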