gutted bloat loggers, now all my useful metrics update per step
This commit is contained in:
parent
bf94744514
commit
b5c6acec9e
202
codes/train.py
202
codes/train.py
|
@ -1,9 +1,11 @@
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import math
|
import math
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
|
import json
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
@ -21,6 +23,33 @@ from utils.util import opt_get, map_cuda_to_correct_device
|
||||||
|
|
||||||
import tortoise.utils.torch_intermediary as ml
|
import tortoise.utils.torch_intermediary as ml
|
||||||
|
|
||||||
|
def try_json( data ):
|
||||||
|
reduced = {}
|
||||||
|
for k, v in data.items():
|
||||||
|
try:
|
||||||
|
json.dumps(v)
|
||||||
|
except Exception as e:
|
||||||
|
continue
|
||||||
|
reduced[k] = v
|
||||||
|
return json.dumps(reduced)
|
||||||
|
|
||||||
|
def process_metrics( metrics ):
|
||||||
|
reduced = {}
|
||||||
|
for metric in metrics:
|
||||||
|
d = metric.as_dict() if hasattr(metric, 'as_dict') else metric
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, torch.Tensor) and len(v.shape) == 0:
|
||||||
|
if k in reduced.keys():
|
||||||
|
reduced[k].append(v)
|
||||||
|
else:
|
||||||
|
reduced[k] = [v]
|
||||||
|
logs = {}
|
||||||
|
|
||||||
|
for k, v in reduced.items():
|
||||||
|
logs[k] = torch.stack(v).mean().item()
|
||||||
|
|
||||||
|
return logs
|
||||||
|
|
||||||
def init_dist(backend, **kwargs):
|
def init_dist(backend, **kwargs):
|
||||||
# These packages have globals that screw with Windows, so only import them if needed.
|
# These packages have globals that screw with Windows, so only import them if needed.
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -32,13 +61,16 @@ def init_dist(backend, **kwargs):
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
|
|
||||||
def init(self, opt_path, opt, launcher):
|
def init(self, opt_path, opt, launcher, mode):
|
||||||
self._profile = False
|
self._profile = False
|
||||||
self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
|
self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
|
||||||
self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
|
self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
|
||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
|
self.iteration_rate = 0
|
||||||
self.total_training_data_encountered = 0
|
self.total_training_data_encountered = 0
|
||||||
|
|
||||||
|
self.use_tqdm = False # self.rank <= 0
|
||||||
|
|
||||||
#### loading resume state if exists
|
#### loading resume state if exists
|
||||||
if opt['path'].get('resume_state', None):
|
if opt['path'].get('resume_state', None):
|
||||||
# distributed resuming: all load into default GPU
|
# distributed resuming: all load into default GPU
|
||||||
|
@ -57,17 +89,11 @@ class Trainer:
|
||||||
shutil.copy(opt_path, os.path.join(opt['path']['experiments_root'], f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'))
|
shutil.copy(opt_path, os.path.join(opt['path']['experiments_root'], f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'))
|
||||||
|
|
||||||
# config loggers. Before it, the log will not work
|
# config loggers. Before it, the log will not work
|
||||||
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True)
|
||||||
screen=True, tofile=True)
|
|
||||||
self.logger = logging.getLogger('base')
|
self.logger = logging.getLogger('base')
|
||||||
self.logger.info(option.dict2str(opt))
|
self.logger.info(option.dict2str(opt))
|
||||||
# tensorboard logger
|
|
||||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
|
||||||
self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
|
|
||||||
else:
|
else:
|
||||||
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=False)
|
||||||
self.logger = logging.getLogger('base')
|
self.logger = logging.getLogger('base')
|
||||||
|
|
||||||
if resume_state is not None:
|
if resume_state is not None:
|
||||||
|
@ -77,14 +103,6 @@ class Trainer:
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
|
|
||||||
#### wandb init
|
|
||||||
if opt['wandb'] and self.rank <= 0:
|
|
||||||
import wandb
|
|
||||||
os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
|
|
||||||
project_name = opt_get(opt, ['wandb_project_name'], opt['name'])
|
|
||||||
run_name = opt_get(opt, ['wandb_run_name'], None)
|
|
||||||
wandb.init(project=project_name, dir=opt['path']['log'], config=opt, name=run_name)
|
|
||||||
|
|
||||||
#### random seed
|
#### random seed
|
||||||
seed = opt['train']['manual_seed']
|
seed = opt['train']['manual_seed']
|
||||||
if seed is None:
|
if seed is None:
|
||||||
|
@ -197,7 +215,7 @@ class Trainer:
|
||||||
# because train_data is process-local while the opt variant represents all of the data fed across all GPUs.
|
# because train_data is process-local while the opt variant represents all of the data fed across all GPUs.
|
||||||
self.current_step += 1
|
self.current_step += 1
|
||||||
self.total_training_data_encountered += batch_size
|
self.total_training_data_encountered += batch_size
|
||||||
will_log = self.current_step % opt['logger']['print_freq'] == 0
|
will_log = False # self.current_step % opt['logger']['print_freq'] == 0
|
||||||
|
|
||||||
#### update learning rate
|
#### update learning rate
|
||||||
self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])
|
self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])
|
||||||
|
@ -208,11 +226,15 @@ class Trainer:
|
||||||
_t = time()
|
_t = time()
|
||||||
self.model.feed_data(train_data, self.current_step)
|
self.model.feed_data(train_data, self.current_step)
|
||||||
gradient_norms_dict = self.model.optimize_parameters(self.current_step, return_grad_norms=will_log)
|
gradient_norms_dict = self.model.optimize_parameters(self.current_step, return_grad_norms=will_log)
|
||||||
iteration_rate = (time() - _t) / batch_size
|
self.iteration_rate = (time() - _t) / batch_size
|
||||||
if self._profile:
|
if self._profile:
|
||||||
print("Model feed + step: %f" % (time() - _t))
|
print("Model feed + step: %f" % (time() - _t))
|
||||||
_t = time()
|
_t = time()
|
||||||
|
|
||||||
|
metrics = {}
|
||||||
|
for s in self.model.steps:
|
||||||
|
metrics.update(s.get_metrics())
|
||||||
|
|
||||||
#### log
|
#### log
|
||||||
if self.dataset_debugger is not None:
|
if self.dataset_debugger is not None:
|
||||||
self.dataset_debugger.update(train_data)
|
self.dataset_debugger.update(train_data)
|
||||||
|
@ -220,6 +242,22 @@ class Trainer:
|
||||||
# Must be run by all instances to gather consensus.
|
# Must be run by all instances to gather consensus.
|
||||||
current_model_logs = self.model.get_current_log(self.current_step)
|
current_model_logs = self.model.get_current_log(self.current_step)
|
||||||
if will_log and self.rank <= 0:
|
if will_log and self.rank <= 0:
|
||||||
|
logs = {
|
||||||
|
'step': self.current_step,
|
||||||
|
'samples': self.total_training_data_encountered,
|
||||||
|
'megasamples': self.total_training_data_encountered / 1000000,
|
||||||
|
'iteration_rate': self.iteration_rate,
|
||||||
|
'lr': self.model.get_current_learning_rate(),
|
||||||
|
}
|
||||||
|
logs.update(current_model_logs)
|
||||||
|
|
||||||
|
if self.dataset_debugger is not None:
|
||||||
|
logs.update(self.dataset_debugger.get_debugging_map())
|
||||||
|
|
||||||
|
logs.update(gradient_norms_dict)
|
||||||
|
self.logger.info(f'Training Metrics: {try_json(logs)}')
|
||||||
|
|
||||||
|
"""
|
||||||
logs = {'step': self.current_step,
|
logs = {'step': self.current_step,
|
||||||
'samples': self.total_training_data_encountered,
|
'samples': self.total_training_data_encountered,
|
||||||
'megasamples': self.total_training_data_encountered / 1000000,
|
'megasamples': self.total_training_data_encountered / 1000000,
|
||||||
|
@ -255,78 +293,108 @@ class Trainer:
|
||||||
else:
|
else:
|
||||||
wandb.log(wandb_logs, step=self.total_training_data_encountered)
|
wandb.log(wandb_logs, step=self.total_training_data_encountered)
|
||||||
self.logger.info(message)
|
self.logger.info(message)
|
||||||
|
"""
|
||||||
|
|
||||||
#### save models and training states
|
#### save models and training states
|
||||||
if self.current_step > 0 and self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
if self.current_step > 0 and self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||||
self.model.consolidate_state()
|
self.model.consolidate_state()
|
||||||
if self.rank <= 0:
|
if self.rank <= 0:
|
||||||
self.logger.info('Saving models and training states.')
|
|
||||||
self.model.save(self.current_step)
|
self.model.save(self.current_step)
|
||||||
state = {'epoch': self.epoch, 'iter': self.current_step, 'total_data_processed': self.total_training_data_encountered}
|
state = {
|
||||||
|
'epoch': self.epoch,
|
||||||
|
'iter': self.current_step,
|
||||||
|
'total_data_processed': self.total_training_data_encountered
|
||||||
|
}
|
||||||
if self.dataset_debugger is not None:
|
if self.dataset_debugger is not None:
|
||||||
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
|
state['dataset_debugger_state'] = self.dataset_debugger.get_state()
|
||||||
self.model.save_training_state(state)
|
self.model.save_training_state(state)
|
||||||
|
self.logger.info('Saving models and training states.')
|
||||||
|
"""
|
||||||
if 'alt_path' in opt['path'].keys():
|
if 'alt_path' in opt['path'].keys():
|
||||||
import shutil
|
import shutil
|
||||||
print("Synchronizing tb_logger to alt_path..")
|
print("Synchronizing tb_logger to alt_path..")
|
||||||
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
||||||
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
||||||
shutil.copytree(self.tb_logger_path, alt_tblogger)
|
shutil.copytree(self.tb_logger_path, alt_tblogger)
|
||||||
|
"""
|
||||||
|
|
||||||
do_eval = self.total_training_data_encountered > self.next_eval_step
|
do_eval = self.total_training_data_encountered > self.next_eval_step
|
||||||
if do_eval:
|
if do_eval:
|
||||||
self.next_eval_step = self.total_training_data_encountered + self.val_freq
|
self.next_eval_step = self.total_training_data_encountered + self.val_freq
|
||||||
if opt_get(opt, ['eval', 'pure'], False) and do_eval:
|
|
||||||
metrics = []
|
|
||||||
for val_data in tqdm(self.val_loader):
|
|
||||||
self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
|
|
||||||
metrics.append(self.model.test())
|
|
||||||
reduced_metrics = {}
|
|
||||||
for metric in metrics:
|
|
||||||
for k, v in metric.as_dict().items():
|
|
||||||
if isinstance(v, torch.Tensor) and len(v.shape) == 0:
|
|
||||||
if k in reduced_metrics.keys():
|
|
||||||
reduced_metrics[k].append(v)
|
|
||||||
else:
|
|
||||||
reduced_metrics[k] = [v]
|
|
||||||
if self.rank <= 0:
|
|
||||||
for k, v in reduced_metrics.items():
|
|
||||||
val = torch.stack(v).mean().item()
|
|
||||||
self.tb_logger.add_scalar(f'val_{k}', val, self.current_step)
|
|
||||||
print(f">>Eval {k}: {val}")
|
|
||||||
if opt['wandb']:
|
|
||||||
import wandb
|
|
||||||
wandb.log({f'eval_{k}': torch.stack(v).mean().item() for k,v in reduced_metrics.items()})
|
|
||||||
|
|
||||||
if len(self.evaluators) != 0 and do_eval:
|
if opt_get(opt, ['eval', 'pure'], False):
|
||||||
eval_dict = {}
|
self.do_validation()
|
||||||
for eval in self.evaluators:
|
if len(self.evaluators) != 0:
|
||||||
if eval.uses_all_ddp or self.rank <= 0:
|
eval_dict = {}
|
||||||
eval_dict.update(eval.perform_eval())
|
for eval in self.evaluators:
|
||||||
if self.rank <= 0:
|
if eval.uses_all_ddp or self.rank <= 0:
|
||||||
print("Evaluator results: ", eval_dict)
|
eval_dict.update(eval.perform_eval())
|
||||||
for ek, ev in eval_dict.items():
|
if self.rank <= 0:
|
||||||
self.tb_logger.add_scalar(ek, ev, self.current_step)
|
print("Evaluator results: ", eval_dict)
|
||||||
if opt['wandb']:
|
"""
|
||||||
import wandb
|
for ek, ev in eval_dict.items():
|
||||||
wandb.log(eval_dict)
|
self.tb_logger.add_scalar(ek, ev, self.current_step)
|
||||||
|
|
||||||
|
if opt['wandb']:
|
||||||
|
import wandb
|
||||||
|
wandb.log(eval_dict)
|
||||||
|
"""
|
||||||
|
|
||||||
# Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
|
# Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
|
||||||
for net in self.model.networks.values():
|
for net in self.model.networks.values():
|
||||||
net.zero_grad()
|
net.zero_grad()
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def do_validation(self):
|
||||||
|
if self.rank <= 0:
|
||||||
|
self.logger.info('Beginning validation.')
|
||||||
|
|
||||||
|
metrics = []
|
||||||
|
tq_ldr = tqdm(self.train_loader, desc="Validating") if self.use_tqdm else self.train_loader
|
||||||
|
|
||||||
|
for val_data in tq_ldr:
|
||||||
|
self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
|
||||||
|
metric = self.model.test()
|
||||||
|
metrics.append(metric)
|
||||||
|
if self.rank <= 0:
|
||||||
|
logs = process_metrics( metrics )
|
||||||
|
if self.use_tqdm:
|
||||||
|
tq_ldr.set_postfix( logs, refresh=True )
|
||||||
|
self.logger.info(f'Validation Metrics: {json.dumps(logs)}')
|
||||||
|
|
||||||
def do_training(self):
|
def do_training(self):
|
||||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
if self.rank <= 0:
|
||||||
|
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
||||||
|
|
||||||
for epoch in range(self.start_epoch, self.total_epochs + 1):
|
for epoch in range(self.start_epoch, self.total_epochs + 1):
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
if self.opt['dist']:
|
if self.opt['dist']:
|
||||||
self.train_sampler.set_epoch(epoch)
|
self.train_sampler.set_epoch(epoch)
|
||||||
|
|
||||||
tq_ldr = tqdm(self.train_loader) if self.rank <= 0 else self.train_loader
|
metrics = []
|
||||||
|
tq_ldr = tqdm(self.train_loader, desc="Training") if self.use_tqdm else self.train_loader
|
||||||
|
|
||||||
_t = time()
|
_t = time()
|
||||||
|
step = 0
|
||||||
for train_data in tq_ldr:
|
for train_data in tq_ldr:
|
||||||
self.do_step(train_data)
|
step = step + 1
|
||||||
|
metric = self.do_step(train_data)
|
||||||
|
metrics.append(metric)
|
||||||
|
if self.rank <= 0:
|
||||||
|
logs = process_metrics( metrics )
|
||||||
|
logs['lr'] = self.model.get_current_learning_rate()[0]
|
||||||
|
if self.use_tqdm:
|
||||||
|
tq_ldr.set_postfix( logs, refresh=True )
|
||||||
|
logs['it'] = self.current_step
|
||||||
|
logs['step'] = step
|
||||||
|
logs['steps'] = len(self.train_loader)
|
||||||
|
logs['epoch'] = self.epoch
|
||||||
|
logs['iteration_rate'] = self.iteration_rate
|
||||||
|
self.logger.info(f'Training Metrics: {json.dumps(logs)}')
|
||||||
|
|
||||||
|
if self.rank <= 0:
|
||||||
|
self.logger.info('Finished training!')
|
||||||
|
|
||||||
def create_training_generator(self, index):
|
def create_training_generator(self, index):
|
||||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
||||||
|
@ -334,26 +402,31 @@ class Trainer:
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
if self.opt['dist']:
|
if self.opt['dist']:
|
||||||
self.train_sampler.set_epoch(epoch)
|
self.train_sampler.set_epoch(epoch)
|
||||||
|
|
||||||
tq_ldr = tqdm(self.train_loader, position=index)
|
tq_ldr = tqdm(self.train_loader, position=index)
|
||||||
|
tq_ldr.set_description('Training')
|
||||||
|
|
||||||
_t = time()
|
_t = time()
|
||||||
for train_data in tq_ldr:
|
for train_data in tq_ldr:
|
||||||
yield self.model
|
yield self.model
|
||||||
self.do_step(train_data)
|
metric = self.do_step(train_data)
|
||||||
|
self.logger.info('Finished training')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
|
parser.add_argument('--mode', type=str, default='', help='Handles printing info')
|
||||||
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
opt = option.parse(args.opt, is_train=True)
|
opt = option.parse(args.opt, is_train=True)
|
||||||
|
|
||||||
if args.launcher != 'none':
|
if args.launcher != 'none':
|
||||||
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
|
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
|
||||||
if 'gpu_ids' in opt.keys():
|
if 'gpu_ids' in opt.keys():
|
||||||
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
||||||
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
||||||
|
|
||||||
trainer = Trainer()
|
trainer = Trainer()
|
||||||
|
|
||||||
#### distributed training settings
|
#### distributed training settings
|
||||||
|
@ -370,5 +443,10 @@ if __name__ == '__main__':
|
||||||
trainer.rank = torch.distributed.get_rank()
|
trainer.rank = torch.distributed.get_rank()
|
||||||
torch.cuda.set_device(torch.distributed.get_rank())
|
torch.cuda.set_device(torch.distributed.get_rank())
|
||||||
|
|
||||||
trainer.init(args.opt, opt, args.launcher)
|
if trainer.rank >= 1:
|
||||||
trainer.do_training()
|
f = open(os.devnull, 'w')
|
||||||
|
sys.stdout = f
|
||||||
|
sys.stderr = f
|
||||||
|
|
||||||
|
trainer.init(args.opt, opt, args.launcher, args.mode)
|
||||||
|
trainer.do_training()
|
Loading…
Reference in New Issue
Block a user