# Forked from mrq/DL-Art-School (486 lines, 20 KiB, Python).
import argparse
|
|
import json
|
|
import logging
|
|
import math
|
|
import os
|
|
import random
|
|
import shutil
|
|
import sys
|
|
from datetime import datetime
|
|
from time import time
|
|
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from dlas.data import create_dataloader, create_dataset, get_dataset_debugger
|
|
from dlas.data.data_sampler import DistIterSampler
|
|
from dlas.trainer.eval.evaluator import create_evaluator
|
|
from dlas.trainer.ExtensibleTrainer import ExtensibleTrainer
|
|
from dlas.utils import options as option
|
|
from dlas.utils import util
|
|
from dlas.utils.util import map_cuda_to_correct_device, opt_get
|
|
|
|
|
|
def try_json(data):
    """Serialize the JSON-compatible subset of ``data`` to a JSON string.

    Every value is trial-encoded on its own; an entry whose value makes
    ``json.dumps`` raise (any ``Exception``) is dropped instead of aborting the
    whole dump, so a log line can always be produced.

    Args:
        data: mapping of log keys to values.

    Returns:
        str: JSON object text containing only the encodable entries, in their
        original iteration order.
    """
    def _encodable(value):
        try:
            json.dumps(value)
        except Exception:
            return False
        return True

    encodable_items = {key: value for key, value in data.items() if _encodable(value)}
    return json.dumps(encodable_items)
|
|
|
|
|
|
def process_metrics(metrics):
    """Average scalar tensor metrics over a sequence of per-step results.

    Args:
        metrics: iterable of per-step metric containers. Each item is either a
            mapping, or an object exposing ``as_dict()`` that returns one.

    Returns:
        dict: ``{name: float}`` with the mean of every 0-dim ``torch.Tensor``
        reported under ``name``. Non-tensor and non-scalar values are ignored.
        Keys keep first-seen order; an empty input yields ``{}``.
    """
    collected = {}
    for metric in metrics:
        d = metric.as_dict() if hasattr(metric, 'as_dict') else metric
        for k, v in d.items():
            if isinstance(v, torch.Tensor) and v.dim() == 0:
                # detach(): only the value is needed; never build autograd graphs for logging.
                collected.setdefault(k, []).append(v.detach())

    logs = {}
    for k, values in collected.items():
        stacked = torch.stack(values)
        # torch.mean() rejects integer/bool tensors, so a step reporting e.g. a
        # count as a LongTensor would otherwise crash logging. Floating and
        # complex stacks keep their dtype (and precision).
        if not (stacked.is_floating_point() or stacked.is_complex()):
            stacked = stacked.float()
        logs[k] = stacked.mean().item()
    return logs
|
|
|
|
|
|
def init_dist(backend, **kwargs):
    """Bind this process to its local CUDA device and join the default process group.

    The device index is read from the ``LOCAL_RANK`` environment variable and
    must be smaller than the number of visible CUDA devices.

    Args:
        backend: ``torch.distributed`` backend name (e.g. ``'nccl'``).
        **kwargs: forwarded unchanged to ``init_process_group``.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set.
        AssertionError: if ``LOCAL_RANK`` is not a visible CUDA device index.
    """
    # These packages have globals that screw with Windows, so only import them if needed.
    import torch.distributed as torch_dist

    local_rank = int(os.environ['LOCAL_RANK'])
    assert local_rank < torch.cuda.device_count()

    torch.cuda.set_device(local_rank)
    torch_dist.init_process_group(backend=backend, **kwargs)
|
|
|
|
|
|
class Trainer:
    """Drive a training run: logging, seeding, dataloaders, model, evaluation and checkpoints.

    Instances are created empty. The caller assigns ``rank`` (and ``world_size``
    when ``opt['dist']`` is set), then calls :meth:`init`, then :meth:`do_training`
    or :meth:`create_training_generator`.
    """

    def init(self, opt_path, opt, launcher, mode):
        """Set up logging, seeding, dataloaders, the model, evaluators and resume state.

        Args:
            opt_path: path of the options file. On the primary rank a timestamped
                copy is written into ``opt['path']['experiments_root']``.
            opt: parsed options dict. Converted with ``option.dict_to_nonedict``
                and stored as ``self.opt``.
            launcher: launcher name (kept for interface compatibility; unused here).
            mode: mode string (kept for interface compatibility; unused here).

        Raises:
            ValueError: if the training dataset is empty.
            NotImplementedError: for a dataset phase other than ``train``/``val``.
        """
        self._profile = False
        # Wall-clock end of the previous do_step(); used to time data fetching when profiling.
        self._last_step_end = None
        self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
        self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
        self.current_step = 0
        self.iteration_rate = 0
        self.total_training_data_encountered = 0

        self.use_tqdm = False  # self.rank <= 0

        # loading resume state if exists
        if opt['path'].get('resume_state', None):
            # distributed resuming: all load into default GPU
            resume_state = torch.load(
                opt['path']['resume_state'], map_location=map_cuda_to_correct_device)
        else:
            resume_state = None

        # mkdir and loggers
        # normal training (self.rank -1) OR distributed training (self.rank 0)
        if self.rank <= 0:
            if resume_state is None:
                util.mkdir_and_rename(
                    opt['path']['experiments_root'])  # rename experiment folder if exists
                util.mkdirs(
                    (path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
                     and 'pretrain_model' not in key and 'resume' not in key))
            shutil.copy(opt_path, os.path.join(
                opt['path']['experiments_root'], f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'))

            # config loggers. Before it, the log will not work
            util.setup_logger('base', opt['path']['log'], 'train_' +
                              opt['name'], level=logging.INFO, screen=True, tofile=True)
            self.logger = logging.getLogger('base')
            self.logger.info(option.dict2str(opt))
        else:
            util.setup_logger(
                'base', opt['path']['log'], 'train', level=logging.INFO, screen=False)
            self.logger = logging.getLogger('base')

        if resume_state is not None:
            # check resume options
            option.check_resume(opt, resume_state['iter'])

        # convert to NoneDict, which returns None for missing keys
        opt = option.dict_to_nonedict(opt)
        self.opt = opt

        # random seed
        seed = opt['train']['manual_seed']
        if seed is None:
            seed = random.randint(1, 10000)
        if self.rank <= 0:
            self.logger.info('Random seed: {}'.format(seed))
        # Different multiprocessing instances should behave differently.
        seed += self.rank
        util.set_random_seed(seed)

        torch.backends.cudnn.benchmark = opt_get(
            opt, ['cuda_benchmarking_enabled'], True)
        torch.backends.cuda.matmul.allow_tf32 = True
        # torch.backends.cudnn.deterministic = True
        if opt_get(opt, ['anomaly_detection'], False):
            torch.autograd.set_detect_anomaly(True)

        # Save the compiled opt dict to the global loaded_options variable.
        util.loaded_options = opt

        # create train and val dataloader
        dataset_ratio = 1  # enlarge the size of each epoch
        for phase, dataset_opt in opt['datasets'].items():
            if phase == 'train':
                self.train_set, collate_fn = create_dataset(
                    dataset_opt, return_collate=True)
                self.dataset_debugger = get_dataset_debugger(dataset_opt)
                if self.dataset_debugger is not None and resume_state is not None:
                    self.dataset_debugger.load_state(
                        opt_get(resume_state, ['dataset_debugger_state'], {}))
                train_size = int(
                    math.ceil(len(self.train_set) / dataset_opt['batch_size']))
                if train_size == 0:
                    # Would otherwise surface as an opaque ZeroDivisionError below.
                    raise ValueError('Training dataset is empty.')
                total_iters = int(opt['train']['niter'])
                self.total_epochs = int(math.ceil(total_iters / train_size))
                if opt['dist']:
                    self.train_sampler = DistIterSampler(
                        self.train_set, self.world_size, self.rank, dataset_ratio)
                    self.total_epochs = int(
                        math.ceil(total_iters / (train_size * dataset_ratio)))
                    # The sampler owns ordering in distributed mode.
                    shuffle = False
                else:
                    self.train_sampler = None
                    shuffle = True
                self.train_loader = create_dataloader(
                    self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn, shuffle=shuffle)
                if self.rank <= 0:
                    self.logger.info('Number of training data elements: {:,d}, iters: {:,d}'.format(
                        len(self.train_set), train_size))
                    self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                        self.total_epochs, total_iters))
            elif phase == 'val':
                # A validation loader is only built for "pure" evaluation mode.
                if not opt_get(opt, ['eval', 'pure'], False):
                    continue

                self.val_set, collate_fn = create_dataset(
                    dataset_opt, return_collate=True)
                self.val_loader = create_dataloader(
                    self.val_set, dataset_opt, opt, None, collate_fn=collate_fn)
                if self.rank <= 0:
                    self.logger.info('Number of val images in [{:s}]: {:d}'.format(
                        dataset_opt['name'], len(self.val_set)))
            else:
                raise NotImplementedError(
                    'Phase [{:s}] is not recognized.'.format(phase))
        assert getattr(self, 'train_loader', None) is not None, \
            "opt['datasets'] must define a 'train' phase"

        # create model
        self.model = ExtensibleTrainer(opt)

        # Evaluators
        self.evaluators = []
        if 'eval' in opt.keys() and 'evaluators' in opt['eval'].keys():
            # In "pure" mode, we propagate through the normal training steps, but use validation data instead and average
            # the total loss. A validation dataloader is required.
            if opt_get(opt, ['eval', 'pure'], False):
                assert hasattr(self, 'val_loader')

            for ev_opt in opt['eval']['evaluators'].values():
                self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
                                                        ev_opt, self.model.env))

        # resume training
        if resume_state:
            self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
                resume_state['epoch'], resume_state['iter']))

            self.start_epoch = resume_state['epoch']
            self.current_step = resume_state['iter']
            self.total_training_data_encountered = opt_get(
                resume_state, ['total_data_processed'], 0)
            if opt_get(opt, ['path', 'optimizer_reset'], False):
                print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
                print('!! RESETTING OPTIMIZER STATES')
                print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
            else:
                # handle optimizers and schedulers
                self.model.resume_training(
                    resume_state, 'amp_opt_level' in opt.keys())
        else:
            self.current_step = opt['start_step'] if 'start_step' in opt.keys() else -1
            self.total_training_data_encountered = (
                opt['training_data_encountered'] if 'training_data_encountered' in opt.keys() else 0)
            self.start_epoch = 0
        # force_start_step overrides both fresh starts and resumed state.
        if 'force_start_step' in opt.keys():
            self.current_step = opt['force_start_step']
            self.total_training_data_encountered = self.current_step * \
                opt['datasets']['train']['batch_size']
        opt['current_step'] = self.current_step

        self.epoch = self.start_epoch

        # validation: val_freq is measured in samples (steps * batch_size).
        if 'val_freq' in opt['train'].keys():
            self.val_freq = opt['train']['val_freq'] * \
                opt['datasets']['train']['batch_size']
        else:
            self.val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)

        self.next_eval_step = self.total_training_data_encountered + self.val_freq
        # For whatever reason, this relieves a memory burden on the first GPU for some training sessions.
        del resume_state

    def save(self):
        """Persist the model for the current step plus the resumable training state."""
        self.model.save(self.current_step)
        state = {
            'epoch': self.epoch,
            'iter': self.current_step,
            'total_data_processed': self.total_training_data_encountered,
        }
        if self.dataset_debugger is not None:
            state['dataset_debugger_state'] = self.dataset_debugger.get_state()
        self.model.save_training_state(state)
        self.logger.info('Saving models and training states.')

    def do_step(self, train_data):
        """Run one optimisation step on ``train_data`` and any bookkeeping that is due.

        Advances ``current_step`` by one and ``total_training_data_encountered``
        by the configured global batch size. Then it saves a checkpoint when
        ``current_step`` is a multiple of ``opt['logger']['save_checkpoint_freq']``.
        Finally it runs validation/evaluators once the sample counter reaches
        ``next_eval_step``.

        Returns:
            dict: the merged ``get_metrics()`` of every step in ``self.model.steps``.
        """
        opt = self.opt
        if self._profile:
            last_step_end = getattr(self, '_last_step_end', None)
            if last_step_end is not None:
                print("Data fetch: %f" % (time() - last_step_end))
        _t = time()

        # It may seem weird to derive this from opt, rather than train_data. The reason this is done is
        # because train_data is process-local while the opt variant represents all of the data fed across all GPUs.
        batch_size = opt['datasets']['train']['batch_size']
        self.current_step += 1
        self.total_training_data_encountered += batch_size
        # Per-step model logging is disabled (was: self.current_step % opt['logger']['print_freq'] == 0).
        will_log = False

        # update learning rate
        self.model.update_learning_rate(
            self.current_step, warmup_iter=opt['train']['warmup_iter'])

        # training
        if self._profile:
            print("Update LR: %f" % (time() - _t))
        _t = time()
        self.model.feed_data(train_data, self.current_step)
        gradient_norms_dict = self.model.optimize_parameters(
            self.current_step, return_grad_norms=will_log)
        # Seconds spent in feed + optimize for this step (despite the name, not a per-sample rate).
        self.iteration_rate = (time() - _t)  # / batch_size
        if self._profile:
            print("Model feed + step: %f" % (time() - _t))

        metrics = {}
        for s in self.model.steps:
            metrics.update(s.get_metrics())

        # log
        if self.dataset_debugger is not None:
            self.dataset_debugger.update(train_data)
        if will_log:
            # Must be run by all instances to gather consensus.
            current_model_logs = self.model.get_current_log(self.current_step)
            if self.rank <= 0:
                logs = {
                    'step': self.current_step,
                    'samples': self.total_training_data_encountered,
                    'megasamples': self.total_training_data_encountered / 1000000,
                    'iteration_rate': self.iteration_rate,
                    'lr': self.model.get_current_learning_rate(),
                }
                logs.update(current_model_logs)
                if self.dataset_debugger is not None:
                    logs.update(self.dataset_debugger.get_debugging_map())
                logs.update(gradient_norms_dict)
                self.logger.info(f'Training Metrics: {try_json(logs)}')

        # save models and training states
        if self.current_step > 0 and self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
            self.model.consolidate_state()
            if self.rank <= 0:
                self.save()

        # '>=' so evaluation fires every val_freq samples; '>' delayed each run by one step.
        if self.total_training_data_encountered >= self.next_eval_step:
            self.next_eval_step = self.total_training_data_encountered + self.val_freq

            if opt_get(opt, ['eval', 'pure'], False):
                self.do_validation()
            if len(self.evaluators) != 0:
                eval_dict = {}
                for evaluator in self.evaluators:
                    if evaluator.uses_all_ddp or self.rank <= 0:
                        eval_dict.update(evaluator.perform_eval())
                if self.rank <= 0:
                    print("Evaluator results: ", eval_dict)

            # Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
            for net in self.model.networks.values():
                net.zero_grad()

        self._last_step_end = time()
        return metrics

    def do_validation(self):
        """Run ``self.model.test()`` over the validation loader and log averaged metrics (primary rank)."""
        if self.rank <= 0:
            self.logger.info('Beginning validation.')

        metrics = []
        tq_ldr = tqdm(self.val_loader,
                      desc="Validating") if self.use_tqdm else self.val_loader

        for val_data in tq_ldr:
            self.model.feed_data(val_data, self.current_step,
                                 perform_micro_batching=False)
            metrics.append(self.model.test())
            if self.rank <= 0 and self.use_tqdm:
                tq_ldr.set_postfix(process_metrics(metrics), refresh=True)

        if self.rank <= 0:
            logs = process_metrics(metrics)
            logs['it'] = self.current_step
            self.logger.info(f'Validation Metrics: {json.dumps(logs)}')

    def do_training(self):
        """Train for epochs ``start_epoch`` .. ``total_epochs`` (inclusive), then save.

        After every step the primary rank logs the running mean of this epoch's
        scalar step metrics together with progress counters.
        """
        if self.rank <= 0:
            self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
                self.start_epoch, self.current_step))

        for epoch in range(self.start_epoch, self.total_epochs + 1):
            self.epoch = epoch
            if self.opt['dist']:
                self.train_sampler.set_epoch(epoch)

            metrics = []
            tq_ldr = tqdm(
                self.train_loader, desc="Training") if self.use_tqdm else self.train_loader

            self._last_step_end = time()
            for step, train_data in enumerate(tq_ldr, start=1):
                metrics.append(self.do_step(train_data))
                if self.rank > 0:
                    continue
                # NOTE(review): re-averages every step of the epoch so far on each
                # iteration (quadratic per epoch); consider running sums for long epochs.
                logs = process_metrics(metrics)
                logs['lr'] = self.model.get_current_learning_rate()[0]
                if self.use_tqdm:
                    tq_ldr.set_postfix(logs, refresh=True)
                logs['it'] = self.current_step
                logs['step'] = step
                logs['steps'] = len(self.train_loader)
                logs['epoch'] = self.epoch
                logs['iteration_rate'] = self.iteration_rate
                self.logger.info(f'Training Metrics: {json.dumps(logs)}')

        if self.rank <= 0:
            self.save()
            self.logger.info('Finished training!')

    def create_training_generator(self, index):
        """Yield ``self.model`` before each training step, then run that step.

        Args:
            index: tqdm ``position`` for this run's progress bar.
        """
        self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
            self.start_epoch, self.current_step))
        for epoch in range(self.start_epoch, self.total_epochs + 1):
            self.epoch = epoch
            if self.opt['dist']:
                self.train_sampler.set_epoch(epoch)

            tq_ldr = tqdm(self.train_loader, position=index)
            tq_ldr.set_description('Training')

            self._last_step_end = time()
            for train_data in tq_ldr:
                yield self.model
                self.do_step(train_data)
        self.save()
        self.logger.info('Finished training')
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse options, configure (optionally distributed) devices, then train.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--mode', type=str, default='',
                        help='Handles printing info')
    parser.add_argument('-opt', type=str, help='Path to option YAML file.',
                        default='../options/train_vit_latent.yml')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    if args.launcher != 'none':
        # export CUDA_VISIBLE_DEVICES for running in distributed mode.
        if 'gpu_ids' in opt.keys():
            gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
            os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
            print('export CUDA_VISIBLE_DEVICES=' + gpu_list)

    trainer = Trainer()

    # distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        trainer.rank = -1
        if len(opt['gpu_ids']) == 1:
            torch.cuda.set_device(opt['gpu_ids'][0])
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        # init_dist() already binds this process to the GPU given by LOCAL_RANK.
        # Do not re-bind to the *global* rank afterwards: on multi-node jobs the
        # global rank exceeds the per-node device count and set_device() fails.
        init_dist('nccl')
        trainer.world_size = torch.distributed.get_world_size()
        trainer.rank = torch.distributed.get_rank()

    if trainer.rank >= 1:
        # Silence non-primary ranks; the handle intentionally stays open for the process lifetime.
        devnull = open(os.devnull, 'w')
        sys.stdout = devnull
        sys.stderr = devnull

    trainer.init(args.opt, opt, args.launcher, args.mode)
    trainer.do_training()
|