# DL-Art-School/dlas/train.py

import argparse
import json
import logging
import math
import os
import random
import shutil
import sys
from datetime import datetime
from time import time
import torch
from tqdm import tqdm
from dlas.data import create_dataloader, create_dataset, get_dataset_debugger
from dlas.data.data_sampler import DistIterSampler
from dlas.trainer.eval.evaluator import create_evaluator
from dlas.trainer.ExtensibleTrainer import ExtensibleTrainer
from dlas.utils import options as option
from dlas.utils import util
from dlas.utils.util import map_cuda_to_correct_device, opt_get
def try_json(data):
    reduced = {}
    for k, v in data.items():
        try:
            json.dumps(v)
        except Exception:
            continue
        reduced[k] = v
    return json.dumps(reduced)
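# Illustrative example of try_json()'s behavior: values that json.dumps() cannot
# serialize are silently dropped rather than raising, e.g.
#   try_json({'loss': 0.5, 'grad': torch.zeros(1)})  ->  '{"loss": 0.5}'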
def process_metrics(metrics):
    reduced = {}
    for metric in metrics:
        d = metric.as_dict() if hasattr(metric, 'as_dict') else metric
        for k, v in d.items():
            if isinstance(v, torch.Tensor) and len(v.shape) == 0:
                if k in reduced.keys():
                    reduced[k].append(v)
                else:
                    reduced[k] = [v]
    logs = {}
    for k, v in reduced.items():
        logs[k] = torch.stack(v).mean().item()
    return logs
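# Illustrative example of process_metrics(): scalar (0-dim) tensors gathered across steps
# are averaged per key, e.g.
#   process_metrics([{'loss': torch.tensor(1.0)}, {'loss': torch.tensor(3.0)}])  ->  {'loss': 2.0}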
def init_dist(backend, **kwargs):
    # These packages have globals that screw with Windows, so only import them if needed.
    import torch.distributed as dist

    rank = int(os.environ['LOCAL_RANK'])
    assert rank < torch.cuda.device_count()
    torch.cuda.set_device(rank)
    dist.init_process_group(backend=backend, **kwargs)
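# Note: init_dist() reads LOCAL_RANK from the environment, which torchrun
# (or torch.distributed.launch --use_env) sets for each spawned process.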
class Trainer:

    def init(self, opt_path, opt, launcher, mode):
        self._profile = False
        self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
        self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
        self.current_step = 0
        self.iteration_rate = 0
        self.total_training_data_encountered = 0
        self.use_tqdm = False  # self.rank <= 0

        # load resume state if it exists
        if opt['path'].get('resume_state', None):
            # distributed resuming: all load into default GPU
            resume_state = torch.load(
                opt['path']['resume_state'], map_location=map_cuda_to_correct_device)
        else:
            resume_state = None

        # mkdir and loggers
        if self.rank <= 0:  # normal training (self.rank -1) OR distributed training (self.rank 0)
            if resume_state is None:
                util.mkdir_and_rename(
                    opt['path']['experiments_root'])  # rename experiment folder if it exists
                util.mkdirs(
                    (path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
                     and 'pretrain_model' not in key and 'resume' not in key))
            shutil.copy(opt_path, os.path.join(
                opt['path']['experiments_root'],
                f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'))

            # configure loggers. Before this, logging will not work.
            util.setup_logger('base', opt['path']['log'], 'train_' +
                              opt['name'], level=logging.INFO, screen=True, tofile=True)
            self.logger = logging.getLogger('base')
            self.logger.info(option.dict2str(opt))
        else:
            util.setup_logger(
                'base', opt['path']['log'], 'train', level=logging.INFO, screen=False)
            self.logger = logging.getLogger('base')

        if resume_state is not None:
            # check resume options
            option.check_resume(opt, resume_state['iter'])

        # convert to NoneDict, which returns None for missing keys
        opt = option.dict_to_nonedict(opt)
        self.opt = opt

        # random seed
        seed = opt['train']['manual_seed']
        if seed is None:
            seed = random.randint(1, 10000)
        if self.rank <= 0:
            self.logger.info('Random seed: {}'.format(seed))
        # Different multiprocessing instances should behave differently.
        seed += self.rank
        util.set_random_seed(seed)

        torch.backends.cudnn.benchmark = opt_get(
            opt, ['cuda_benchmarking_enabled'], True)
        torch.backends.cuda.matmul.allow_tf32 = True
        # torch.backends.cudnn.deterministic = True
        if opt_get(opt, ['anomaly_detection'], False):
            torch.autograd.set_detect_anomaly(True)

        # Save the compiled opt dict to the global loaded_options variable.
        util.loaded_options = opt
        # create train and val dataloader
        dataset_ratio = 1  # enlarge the size of each epoch
        for phase, dataset_opt in opt['datasets'].items():
            if phase == 'train':
                self.train_set, collate_fn = create_dataset(
                    dataset_opt, return_collate=True)
                self.dataset_debugger = get_dataset_debugger(dataset_opt)
                if self.dataset_debugger is not None and resume_state is not None:
                    self.dataset_debugger.load_state(
                        opt_get(resume_state, ['dataset_debugger_state'], {}))
                train_size = int(
                    math.ceil(len(self.train_set) / dataset_opt['batch_size']))
                total_iters = int(opt['train']['niter'])
                self.total_epochs = int(math.ceil(total_iters / train_size))
                if opt['dist']:
                    self.train_sampler = DistIterSampler(
                        self.train_set, self.world_size, self.rank, dataset_ratio)
                    self.total_epochs = int(
                        math.ceil(total_iters / (train_size * dataset_ratio)))
                    shuffle = False
                else:
                    self.train_sampler = None
                    shuffle = True
                self.train_loader = create_dataloader(
                    self.train_set, dataset_opt, opt, self.train_sampler,
                    collate_fn=collate_fn, shuffle=shuffle)
                if self.rank <= 0:
                    self.logger.info('Number of training data elements: {:,d}, iters: {:,d}'.format(
                        len(self.train_set), train_size))
                    self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                        self.total_epochs, total_iters))
            elif phase == 'val':
                if not opt_get(opt, ['eval', 'pure'], False):
                    continue
                self.val_set, collate_fn = create_dataset(
                    dataset_opt, return_collate=True)
                self.val_loader = create_dataloader(
                    self.val_set, dataset_opt, opt, None, collate_fn=collate_fn)
                if self.rank <= 0:
                    self.logger.info('Number of val images in [{:s}]: {:d}'.format(
                        dataset_opt['name'], len(self.val_set)))
            else:
                raise NotImplementedError(
                    'Phase [{:s}] is not recognized.'.format(phase))
        assert self.train_loader is not None
        # create model
        self.model = ExtensibleTrainer(opt)

        # Evaluators
        self.evaluators = []
        if 'eval' in opt.keys() and 'evaluators' in opt['eval'].keys():
            # In "pure" mode, we propagate through the normal training steps, but use validation data instead and
            # average the total loss. A validation dataloader is required.
            if opt_get(opt, ['eval', 'pure'], False):
                assert hasattr(self, 'val_loader')
            for ev_key, ev_opt in opt['eval']['evaluators'].items():
                self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
                                                        ev_opt, self.model.env))
        # resume training
        if resume_state:
            self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
                resume_state['epoch'], resume_state['iter']))
            self.start_epoch = resume_state['epoch']
            self.current_step = resume_state['iter']
            self.total_training_data_encountered = opt_get(
                resume_state, ['total_data_processed'], 0)
            if opt_get(opt, ['path', 'optimizer_reset'], False):
                print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
                print('!! RESETTING OPTIMIZER STATES')
                print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
            else:
                # handle optimizers and schedulers
                self.model.resume_training(
                    resume_state, 'amp_opt_level' in opt.keys())
        else:
            self.current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
            self.total_training_data_encountered = 0 if 'training_data_encountered' not in opt.keys() \
                else opt['training_data_encountered']
            self.start_epoch = 0
        if 'force_start_step' in opt.keys():
            self.current_step = opt['force_start_step']
            self.total_training_data_encountered = self.current_step * \
                opt['datasets']['train']['batch_size']
        opt['current_step'] = self.current_step

        self.epoch = self.start_epoch
        # validation
        if 'val_freq' in opt['train'].keys():
            self.val_freq = opt['train']['val_freq'] * \
                opt['datasets']['train']['batch_size']
        else:
            self.val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)

        self.next_eval_step = self.total_training_data_encountered + self.val_freq
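        # e.g. with opt['train']['val_freq'] = 500 and a global batch_size of 32, evaluation is
        # triggered roughly every 500 * 32 = 16,000 training samples (illustrative numbers only).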
        # For whatever reason, this relieves a memory burden on the first GPU for some training sessions.
        del resume_state
    def save(self):
        self.model.save(self.current_step)
        state = {
            'epoch': self.epoch,
            'iter': self.current_step,
            'total_data_processed': self.total_training_data_encountered
        }
        if self.dataset_debugger is not None:
            state['dataset_debugger_state'] = self.dataset_debugger.get_state()
        self.model.save_training_state(state)
        self.logger.info('Saving models and training states.')
    def do_step(self, train_data):
        if self._profile:
            print("Data fetch: %f" % (time() - _t))
            _t = time()

        opt = self.opt
        # It may seem weird to derive this from opt, rather than train_data. The reason this is done is
        # because train_data is process-local while the opt variant represents all of the data fed across all GPUs.
        batch_size = self.opt['datasets']['train']['batch_size']
        self.current_step += 1
        self.total_training_data_encountered += batch_size
        will_log = False  # self.current_step % opt['logger']['print_freq'] == 0

        # update learning rate
        self.model.update_learning_rate(
            self.current_step, warmup_iter=opt['train']['warmup_iter'])

        # training
        if self._profile:
            print("Update LR: %f" % (time() - _t))
        _t = time()
        self.model.feed_data(train_data, self.current_step)
        gradient_norms_dict = self.model.optimize_parameters(
            self.current_step, return_grad_norms=will_log)
        self.iteration_rate = (time() - _t)  # / batch_size
        if self._profile:
            print("Model feed + step: %f" % (time() - _t))
            _t = time()

        metrics = {}
        for s in self.model.steps:
            metrics.update(s.get_metrics())

        # log
        if self.dataset_debugger is not None:
            self.dataset_debugger.update(train_data)
        if will_log:
            # Must be run by all instances to gather consensus.
            current_model_logs = self.model.get_current_log(self.current_step)
        if will_log and self.rank <= 0:
            logs = {
                'step': self.current_step,
                'samples': self.total_training_data_encountered,
                'megasamples': self.total_training_data_encountered / 1000000,
                'iteration_rate': self.iteration_rate,
                'lr': self.model.get_current_learning_rate(),
            }
            logs.update(current_model_logs)
            if self.dataset_debugger is not None:
                logs.update(self.dataset_debugger.get_debugging_map())
            logs.update(gradient_norms_dict)
            self.logger.info(f'Training Metrics: {try_json(logs)}')
"""
logs = {'step': self.current_step,
'samples': self.total_training_data_encountered,
'megasamples': self.total_training_data_encountered / 1000000,
'iteration_rate': iteration_rate}
logs.update(current_model_logs)
if self.dataset_debugger is not None:
logs.update(self.dataset_debugger.get_debugging_map())
logs.update(gradient_norms_dict)
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
for v in self.model.get_current_learning_rate():
message += '{:.3e},'.format(v)
message += ')] '
for k, v in logs.items():
if 'histogram' in k:
self.tb_logger.add_histogram(k, v, self.current_step)
elif isinstance(v, dict):
self.tb_logger.add_scalars(k, v, self.current_step)
else:
message += '{:s}: {:.4e} '.format(k, v)
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
self.tb_logger.add_scalar(k, v, self.current_step)
if opt['wandb'] and self.rank <= 0:
import wandb
wandb_logs = {}
for k, v in logs.items():
if 'histogram' in k:
wandb_logs[k] = wandb.Histogram(v)
else:
wandb_logs[k] = v
if opt_get(opt, ['wandb_progress_use_raw_steps'], False):
wandb.log(wandb_logs, step=self.current_step)
else:
wandb.log(wandb_logs, step=self.total_training_data_encountered)
self.logger.info(message)
"""
        # save models and training states
        if self.current_step > 0 and self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
            self.model.consolidate_state()
            if self.rank <= 0:
                self.save()

        do_eval = self.total_training_data_encountered > self.next_eval_step
        if do_eval:
            self.next_eval_step = self.total_training_data_encountered + self.val_freq
            if opt_get(opt, ['eval', 'pure'], False):
                self.do_validation()

            if len(self.evaluators) != 0:
                eval_dict = {}
                for eval in self.evaluators:
                    if eval.uses_all_ddp or self.rank <= 0:
                        eval_dict.update(eval.perform_eval())
                if self.rank <= 0:
                    print("Evaluator results: ", eval_dict)

            # Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
            for net in self.model.networks.values():
                net.zero_grad()

        return metrics
    def do_validation(self):
        if self.rank <= 0:
            self.logger.info('Beginning validation.')

        metrics = []
        tq_ldr = tqdm(self.val_loader,
                      desc="Validating") if self.use_tqdm else self.val_loader
        for val_data in tq_ldr:
            self.model.feed_data(val_data, self.current_step,
                                 perform_micro_batching=False)
            metric = self.model.test()
            metrics.append(metric)
            if self.rank <= 0 and self.use_tqdm:
                logs = process_metrics(metrics)
                tq_ldr.set_postfix(logs, refresh=True)
        if self.rank <= 0:
            logs = process_metrics(metrics)
            logs['it'] = self.current_step
            self.logger.info(f'Validation Metrics: {json.dumps(logs)}')
    def do_training(self):
        if self.rank <= 0:
            self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
                self.start_epoch, self.current_step))

        for epoch in range(self.start_epoch, self.total_epochs + 1):
            self.epoch = epoch
            if self.opt['dist']:
                self.train_sampler.set_epoch(epoch)

            metrics = []
            tq_ldr = tqdm(
                self.train_loader, desc="Training") if self.use_tqdm else self.train_loader

            _t = time()
            step = 0
            for train_data in tq_ldr:
                step = step + 1
                metric = self.do_step(train_data)
                metrics.append(metric)
                if self.rank <= 0:
                    logs = process_metrics(metrics)
                    logs['lr'] = self.model.get_current_learning_rate()[0]
                    if self.use_tqdm:
                        tq_ldr.set_postfix(logs, refresh=True)
                    logs['it'] = self.current_step
                    logs['step'] = step
                    logs['steps'] = len(self.train_loader)
                    logs['epoch'] = self.epoch
                    logs['iteration_rate'] = self.iteration_rate
                    self.logger.info(f'Training Metrics: {json.dumps(logs)}')

        if self.rank <= 0:
            self.save()
            self.logger.info('Finished training!')
    def create_training_generator(self, index):
        self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
            self.start_epoch, self.current_step))
        for epoch in range(self.start_epoch, self.total_epochs + 1):
            self.epoch = epoch
            if self.opt['dist']:
                self.train_sampler.set_epoch(epoch)
            tq_ldr = tqdm(self.train_loader, position=index)
            tq_ldr.set_description('Training')

            _t = time()
            for train_data in tq_ldr:
                yield self.model
                metric = self.do_step(train_data)
        self.save()
        self.logger.info('Finished training')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--mode', type=str, default='',
                        help='Handles printing info')
    parser.add_argument('-opt', type=str, help='Path to option YAML file.',
                        default='../options/train_vit_latent.yml')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)
    if args.launcher != 'none':
        # export CUDA_VISIBLE_DEVICES for running in distributed mode.
        if 'gpu_ids' in opt.keys():
            gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
            os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
            print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
    trainer = Trainer()

    # distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        trainer.rank = -1
        if len(opt['gpu_ids']) == 1:
            torch.cuda.set_device(opt['gpu_ids'][0])
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist('nccl')
        trainer.world_size = torch.distributed.get_world_size()
        trainer.rank = torch.distributed.get_rank()
        torch.cuda.set_device(torch.distributed.get_rank())
        if trainer.rank >= 1:
            # Silence non-primary ranks so only one copy of the console output is shown.
            f = open(os.devnull, 'w')
            sys.stdout = f
            sys.stderr = f

    trainer.init(args.opt, opt, args.launcher, args.mode)
    trainer.do_training()
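# Usage sketch (illustrative commands; substitute your own option YAML):
#   single process:   python dlas/train.py -opt ../options/train_vit_latent.yml
#   multi-GPU (DDP):  torchrun --nproc_per_node=<num_gpus> dlas/train.py -opt <config>.yml --launcher pytorch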