Class-ify train.py and workon multi-modal trainer
This commit is contained in:
parent
15e00e9014
commit
76789a456f
|
@ -26,21 +26,19 @@ def create_teco_injector(opt, env):
|
|||
return FlowAdjustment(opt, env)
|
||||
return None
|
||||
|
||||
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin, fp16):
|
||||
triplet = input_list[:, index:index+3]
|
||||
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
|
||||
# Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
|
||||
with torch.no_grad() and autocast(enabled=fp16):
|
||||
first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float())
|
||||
#first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
|
||||
last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float())
|
||||
#last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic')
|
||||
flow_triplet = [resampler(triplet[:,0].float(), first_flow.float()),
|
||||
triplet[:,1],
|
||||
resampler(triplet[:,2].float(), last_flow.float())]
|
||||
flow_triplet = torch.stack(flow_triplet, dim=1)
|
||||
combined = torch.cat([triplet, flow_triplet], dim=1)
|
||||
b, f, c, h, w = combined.shape
|
||||
combined = combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here.
|
||||
with autocast(enabled=False):
|
||||
triplet = input_list[:, index:index+3].float()
|
||||
first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2))
|
||||
last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2))
|
||||
flow_triplet = [resampler(triplet[:,0], first_flow),
|
||||
triplet[:,1],
|
||||
resampler(triplet[:,2], last_flow)]
|
||||
flow_triplet = torch.stack(flow_triplet, dim=1)
|
||||
combined = torch.cat([triplet, flow_triplet], dim=1)
|
||||
b, f, c, h, w = combined.shape
|
||||
combined = combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here.
|
||||
# Apply margin
|
||||
return combined[:, :, margin:-margin, margin:-margin]
|
||||
|
||||
|
@ -98,13 +96,11 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
first_step = False
|
||||
else:
|
||||
input = extract_inputs_index(inputs, i)
|
||||
with torch.no_grad():
|
||||
with torch.no_grad() and autocast(enabled=False):
|
||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
|
||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||
with autocast(enabled=self.env['opt']['fp16']):
|
||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
||||
# Resample does not work in FP16.
|
||||
recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
|
||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2).float()
|
||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
||||
recurrent_input = self.resample(recurrent_input.float(), flowfield)
|
||||
input[self.recurrent_index] = recurrent_input
|
||||
if self.env['step'] % 50 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
||||
|
@ -125,13 +121,12 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
for i in it:
|
||||
input = extract_inputs_index(inputs, i)
|
||||
with torch.no_grad():
|
||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||
with autocast(enabled=self.env['opt']['fp16']):
|
||||
with autocast(enabled=False):
|
||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2).float()
|
||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
||||
recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
|
||||
input[self.recurrent_index
|
||||
] = recurrent_input
|
||||
recurrent_input = self.resample(recurrent_input.float(), flowfield)
|
||||
input[self.recurrent_index] = recurrent_input
|
||||
if self.env['step'] % 50 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
||||
debug_index += 1
|
||||
|
@ -167,12 +162,13 @@ class FlowAdjustment(Injector):
|
|||
self.flowed = opt['flowed']
|
||||
|
||||
def forward(self, state):
|
||||
flow = self.env['generators'][self.flow]
|
||||
flow_target = state[self.flow_target]
|
||||
flowed = F.interpolate(state[self.flowed], size=flow_target.shape[2:], mode='bicubic')
|
||||
flow_input = torch.stack([flow_target, flowed], dim=2)
|
||||
flowfield = F.interpolate(flow(flow_input), size=state[self.flowed].shape[2:], mode='bicubic')
|
||||
return {self.output: self.resample(state[self.flowed].float(), flowfield.float())}
|
||||
with autocast(enabled=False):
|
||||
flow = self.env['generators'][self.flow]
|
||||
flow_target = state[self.flow_target]
|
||||
flowed = F.interpolate(state[self.flowed], size=flow_target.shape[2:], mode='bicubic')
|
||||
flow_input = torch.stack([flow_target, flowed], dim=2).float()
|
||||
flowfield = F.interpolate(flow(flow_input), size=state[self.flowed].shape[2:], mode='bicubic')
|
||||
return {self.output: self.resample(state[self.flowed], flowfield)}
|
||||
|
||||
|
||||
# This is the temporal discriminator loss from TecoGAN.
|
||||
|
|
|
@ -18,8 +18,9 @@ def main(master_opt, launcher):
|
|||
shared_networks = []
|
||||
for i, sub_opt in enumerate(master_opt['trainer_options']):
|
||||
sub_opt_parsed = option.parse(sub_opt, is_train=True)
|
||||
# This creates trainers() as a list of generators.
|
||||
train_gen = train.yielding_main(sub_opt_parsed, launcher, i, all_networks)
|
||||
trainer = train.Trainer()
|
||||
trainer.init(sub_opt_parsed, launcher, all_networks)
|
||||
train_gen = trainer.create_training_generator(i)
|
||||
model = next(train_gen)
|
||||
for k, v in model.networks.items():
|
||||
if k in all_networks.keys() and k not in shared_networks:
|
||||
|
|
702
codes/train.py
702
codes/train.py
|
@ -13,491 +13,265 @@ from data import create_dataloader, create_dataset
|
|||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from time import time
|
||||
|
||||
class Trainer:
|
||||
def init_dist(self, backend, **kwargs):
|
||||
# These packages have globals that screw with Windows, so only import them if needed.
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
def init_dist(backend='nccl', **kwargs):
|
||||
# These packages have globals that screw with Windows, so only import them if needed.
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
"""initialization for distributed training"""
|
||||
if mp.get_start_method(allow_none=True) != 'spawn':
|
||||
mp.set_start_method('spawn')
|
||||
self.rank = int(os.environ['RANK'])
|
||||
num_gpus = torch.cuda.device_count()
|
||||
torch.cuda.set_device(self.rank % num_gpus)
|
||||
dist.init_process_group(backend=backend, **kwargs)
|
||||
|
||||
"""initialization for distributed training"""
|
||||
if mp.get_start_method(allow_none=True) != 'spawn':
|
||||
mp.set_start_method('spawn')
|
||||
rank = int(os.environ['RANK'])
|
||||
num_gpus = torch.cuda.device_count()
|
||||
torch.cuda.set_device(rank % num_gpus)
|
||||
dist.init_process_group(backend=backend, **kwargs)
|
||||
def init(self, opt, launcher, all_networks={}):
|
||||
self._profile = False
|
||||
|
||||
#### distributed training settings
|
||||
if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
|
||||
gpu = input(
|
||||
'I noticed you have multiple GPUs. Starting two jobs on the same GPU sucks. Please confirm which GPU'
|
||||
'you want to use. Press enter to use the specified one [%s]' % (opt['gpu_ids']))
|
||||
if gpu:
|
||||
opt['gpu_ids'] = [int(gpu)]
|
||||
if launcher == 'none': # disabled distributed training
|
||||
opt['dist'] = False
|
||||
self.rank = -1
|
||||
print('Disabled distributed training.')
|
||||
|
||||
def main(opt, launcher='none'):
|
||||
#### distributed training settings
|
||||
if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
|
||||
gpu = input('I noticed you have multiple GPUs. Starting two jobs on the same GPU sucks. Please confirm which GPU'
|
||||
'you want to use. Press enter to use the specified one [%s]' % (opt['gpu_ids']))
|
||||
if gpu:
|
||||
opt['gpu_ids'] = [int(gpu)]
|
||||
if launcher == 'none': # disabled distributed training
|
||||
opt['dist'] = False
|
||||
rank = -1
|
||||
print('Disabled distributed training.')
|
||||
|
||||
else:
|
||||
opt['dist'] = True
|
||||
init_dist()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
#### loading resume state if exists
|
||||
if opt['path'].get('resume_state', None):
|
||||
# distributed resuming: all load into default GPU
|
||||
device_id = torch.cuda.current_device()
|
||||
resume_state = torch.load(opt['path']['resume_state'],
|
||||
map_location=lambda storage, loc: storage.cuda(device_id))
|
||||
option.check_resume(opt, resume_state['iter']) # check resume options
|
||||
else:
|
||||
resume_state = None
|
||||
|
||||
#### mkdir and loggers
|
||||
if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
|
||||
if resume_state is None:
|
||||
util.mkdir_and_rename(
|
||||
opt['path']['experiments_root']) # rename experiment folder if exists
|
||||
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
|
||||
and 'pretrain_model' not in key and 'resume' not in key))
|
||||
|
||||
# config loggers. Before it, the log will not work
|
||||
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
||||
screen=True, tofile=True)
|
||||
logger = logging.getLogger('base')
|
||||
logger.info(option.dict2str(opt))
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
|
||||
version = float(torch.__version__[0:3])
|
||||
if version >= 1.1: # PyTorch 1.1
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
else:
|
||||
logger.info(
|
||||
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
||||
from tensorboardX import SummaryWriter
|
||||
tb_logger = SummaryWriter(log_dir=tb_logger_path)
|
||||
else:
|
||||
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
# convert to NoneDict, which returns None for missing keys
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
|
||||
#### random seed
|
||||
seed = opt['train']['manual_seed']
|
||||
if seed is None:
|
||||
seed = random.randint(1, 10000)
|
||||
if rank <= 0:
|
||||
logger.info('Random seed: {}'.format(seed))
|
||||
util.set_random_seed(seed)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
# Save the compiled opt dict to the global loaded_options variable.
|
||||
util.loaded_options = opt
|
||||
|
||||
#### create train and val dataloader
|
||||
dataset_ratio = 1 # enlarge the size of each epoch
|
||||
for phase, dataset_opt in opt['datasets'].items():
|
||||
if phase == 'train':
|
||||
train_set = create_dataset(dataset_opt)
|
||||
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
|
||||
total_iters = int(opt['train']['niter'])
|
||||
total_epochs = int(math.ceil(total_iters / train_size))
|
||||
if opt['dist']:
|
||||
train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
|
||||
total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
||||
else:
|
||||
train_sampler = None
|
||||
train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
|
||||
if rank <= 0:
|
||||
logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
||||
len(train_set), train_size))
|
||||
logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
||||
total_epochs, total_iters))
|
||||
elif phase == 'val':
|
||||
val_set = create_dataset(dataset_opt)
|
||||
val_loader = create_dataloader(val_set, dataset_opt, opt, None)
|
||||
if rank <= 0:
|
||||
logger.info('Number of val images in [{:s}]: {:d}'.format(
|
||||
dataset_opt['name'], len(val_set)))
|
||||
else:
|
||||
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
|
||||
assert train_loader is not None
|
||||
opt['dist'] = True
|
||||
self.init_dist()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
self.rank = torch.distributed.get_rank()
|
||||
|
||||
#### create model
|
||||
model = ExtensibleTrainer(opt)
|
||||
|
||||
#### resume training
|
||||
if resume_state:
|
||||
logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
||||
resume_state['epoch'], resume_state['iter']))
|
||||
|
||||
start_epoch = resume_state['epoch']
|
||||
current_step = resume_state['iter']
|
||||
model.resume_training(resume_state, 'amp_opt_level' in opt.keys()) # handle optimizers and schedulers
|
||||
else:
|
||||
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
|
||||
start_epoch = 0
|
||||
if 'force_start_step' in opt.keys():
|
||||
current_step = opt['force_start_step']
|
||||
|
||||
#### training
|
||||
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
|
||||
for epoch in range(start_epoch, total_epochs + 1):
|
||||
if opt['dist']:
|
||||
train_sampler.set_epoch(epoch)
|
||||
tq_ldr = tqdm(train_loader)
|
||||
|
||||
_t = time()
|
||||
_profile = False
|
||||
for train_data in tq_ldr:
|
||||
if _profile:
|
||||
print("Data fetch: %f" % (time() - _t))
|
||||
_t = time()
|
||||
|
||||
current_step += 1
|
||||
if current_step > total_iters:
|
||||
break
|
||||
#### update learning rate
|
||||
model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
|
||||
|
||||
#### training
|
||||
if _profile:
|
||||
print("Update LR: %f" % (time() - _t))
|
||||
_t = time()
|
||||
model.feed_data(train_data)
|
||||
model.optimize_parameters(current_step)
|
||||
if _profile:
|
||||
print("Model feed + step: %f" % (time() - _t))
|
||||
_t = time()
|
||||
|
||||
#### log
|
||||
if current_step % opt['logger']['print_freq'] == 0 and rank <= 0:
|
||||
logs = model.get_current_log(current_step)
|
||||
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
||||
for v in model.get_current_learning_rate():
|
||||
message += '{:.3e},'.format(v)
|
||||
message += ')] '
|
||||
for k, v in logs.items():
|
||||
if 'histogram' in k:
|
||||
tb_logger.add_histogram(k, v, current_step)
|
||||
elif isinstance(v, dict):
|
||||
tb_logger.add_scalars(k, v, current_step)
|
||||
else:
|
||||
message += '{:s}: {:.4e} '.format(k, v)
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
tb_logger.add_scalar(k, v, current_step)
|
||||
logger.info(message)
|
||||
|
||||
#### save models and training states
|
||||
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||
if rank <= 0:
|
||||
logger.info('Saving models and training states.')
|
||||
model.save(current_step)
|
||||
model.save_training_state(epoch, current_step)
|
||||
if 'alt_path' in opt['path'].keys():
|
||||
import shutil
|
||||
print("Synchronizing tb_logger to alt_path..")
|
||||
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
||||
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
||||
shutil.copytree(tb_logger_path, alt_tblogger)
|
||||
|
||||
#### validation
|
||||
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
|
||||
avg_psnr = 0.
|
||||
avg_fea_loss = 0.
|
||||
idx = 0
|
||||
colab_imgs_to_copy = []
|
||||
val_tqdm = tqdm(val_loader)
|
||||
for val_data in val_tqdm:
|
||||
idx += 1
|
||||
for b in range(len(val_data['LQ_path'])):
|
||||
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
|
||||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||
util.mkdir(img_dir)
|
||||
|
||||
model.feed_data(val_data)
|
||||
model.test()
|
||||
|
||||
visuals = model.get_current_visuals()
|
||||
if visuals is None:
|
||||
continue
|
||||
|
||||
# calculate PSNR
|
||||
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
|
||||
gt_img = util.tensor2img(visuals['GT'][b]) # uint8
|
||||
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||
|
||||
# calculate fea loss
|
||||
avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
|
||||
|
||||
# Save SR images for reference
|
||||
img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
|
||||
save_img_path = os.path.join(img_dir, img_base_name)
|
||||
util.save_img(sr_img, save_img_path)
|
||||
|
||||
avg_psnr = avg_psnr / idx
|
||||
avg_fea_loss = avg_fea_loss / idx
|
||||
|
||||
# log
|
||||
logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name'] and rank <= 0:
|
||||
tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
|
||||
tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
|
||||
|
||||
if rank <= 0:
|
||||
logger.info('Saving the final model.')
|
||||
model.save('latest')
|
||||
logger.info('End of training.')
|
||||
tb_logger.close()
|
||||
|
||||
# TODO: Integrate with above main by putting this into an object and splitting up business logic.
|
||||
def yielding_main(opt, launcher='none', trainer_id=0, all_networks={}):
|
||||
#### distributed training settings
|
||||
if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
|
||||
gpu = input('I noticed you have multiple GPUs. Starting two jobs on the same GPU sucks. Please confirm which GPU'
|
||||
'you want to use. Press enter to use the specified one [%s]' % (opt['gpu_ids']))
|
||||
if gpu:
|
||||
opt['gpu_ids'] = [int(gpu)]
|
||||
if launcher == 'none': # disabled distributed training
|
||||
opt['dist'] = False
|
||||
rank = -1
|
||||
print('Disabled distributed training.')
|
||||
|
||||
else:
|
||||
opt['dist'] = True
|
||||
init_dist()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
#### loading resume state if exists
|
||||
if opt['path'].get('resume_state', None):
|
||||
# distributed resuming: all load into default GPU
|
||||
device_id = torch.cuda.current_device()
|
||||
resume_state = torch.load(opt['path']['resume_state'],
|
||||
map_location=lambda storage, loc: storage.cuda(device_id))
|
||||
option.check_resume(opt, resume_state['iter']) # check resume options
|
||||
else:
|
||||
resume_state = None
|
||||
|
||||
#### mkdir and loggers
|
||||
if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
|
||||
if resume_state is None:
|
||||
util.mkdir_and_rename(
|
||||
opt['path']['experiments_root']) # rename experiment folder if exists
|
||||
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
|
||||
and 'pretrain_model' not in key and 'resume' not in key))
|
||||
|
||||
# config loggers. Before it, the log will not work
|
||||
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
||||
screen=True, tofile=True)
|
||||
logger = logging.getLogger('base')
|
||||
logger.info(option.dict2str(opt))
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
|
||||
version = float(torch.__version__[0:3])
|
||||
if version >= 1.1: # PyTorch 1.1
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
else:
|
||||
logger.info(
|
||||
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
||||
from tensorboardX import SummaryWriter
|
||||
tb_logger = SummaryWriter(log_dir=tb_logger_path)
|
||||
else:
|
||||
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
# convert to NoneDict, which returns None for missing keys
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
|
||||
#### random seed
|
||||
seed = opt['train']['manual_seed']
|
||||
if seed is None:
|
||||
seed = random.randint(1, 10000)
|
||||
if rank <= 0:
|
||||
logger.info('Random seed: {}'.format(seed))
|
||||
util.set_random_seed(seed)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
# Save the compiled opt dict to the global loaded_options variable.
|
||||
util.loaded_options = opt
|
||||
|
||||
#### create train and val dataloader
|
||||
dataset_ratio = 1 # enlarge the size of each epoch
|
||||
for phase, dataset_opt in opt['datasets'].items():
|
||||
if phase == 'train':
|
||||
train_set = create_dataset(dataset_opt)
|
||||
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
|
||||
total_iters = int(opt['train']['niter'])
|
||||
total_epochs = int(math.ceil(total_iters / train_size))
|
||||
if opt['dist']:
|
||||
train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
|
||||
total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
||||
else:
|
||||
train_sampler = None
|
||||
train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
|
||||
if rank <= 0:
|
||||
logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
||||
len(train_set), train_size))
|
||||
logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
||||
total_epochs, total_iters))
|
||||
elif phase == 'val':
|
||||
val_set = create_dataset(dataset_opt)
|
||||
val_loader = create_dataloader(val_set, dataset_opt, opt, None)
|
||||
if rank <= 0:
|
||||
logger.info('Number of val images in [{:s}]: {:d}'.format(
|
||||
dataset_opt['name'], len(val_set)))
|
||||
#### loading resume state if exists
|
||||
if opt['path'].get('resume_state', None):
|
||||
# distributed resuming: all load into default GPU
|
||||
device_id = torch.cuda.current_device()
|
||||
resume_state = torch.load(opt['path']['resume_state'],
|
||||
map_location=lambda storage, loc: storage.cuda(device_id))
|
||||
option.check_resume(opt, resume_state['iter']) # check resume options
|
||||
else:
|
||||
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
|
||||
assert train_loader is not None
|
||||
resume_state = None
|
||||
|
||||
#### create model
|
||||
model = ExtensibleTrainer(opt, all_networks)
|
||||
#### mkdir and loggers
|
||||
if self.rank <= 0: # normal training (self.rank -1) OR distributed training (self.rank 0)
|
||||
if resume_state is None:
|
||||
util.mkdir_and_rename(
|
||||
opt['path']['experiments_root']) # rename experiment folder if exists
|
||||
util.mkdirs(
|
||||
(path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
|
||||
and 'pretrain_model' not in key and 'resume' not in key))
|
||||
|
||||
#### resume training
|
||||
if resume_state:
|
||||
logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
||||
resume_state['epoch'], resume_state['iter']))
|
||||
# config loggers. Before it, the log will not work
|
||||
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
||||
screen=True, tofile=True)
|
||||
self.logger = logging.getLogger('base')
|
||||
self.logger.info(option.dict2str(opt))
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
|
||||
version = float(torch.__version__[0:3])
|
||||
if version >= 1.1: # PyTorch 1.1
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
else:
|
||||
self.self.logger.info(
|
||||
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
||||
from tensorboardX import SummaryWriter
|
||||
self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
|
||||
else:
|
||||
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
||||
self.logger = logging.getLogger('base')
|
||||
|
||||
start_epoch = resume_state['epoch']
|
||||
current_step = resume_state['iter']
|
||||
model.resume_training(resume_state, 'amp_opt_level' in opt.keys()) # handle optimizers and schedulers
|
||||
else:
|
||||
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
|
||||
start_epoch = 0
|
||||
if 'force_start_step' in opt.keys():
|
||||
current_step = opt['force_start_step']
|
||||
# convert to NoneDict, which returns None for missing keys
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
self.opt = opt
|
||||
|
||||
#### training
|
||||
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
|
||||
for epoch in range(start_epoch, total_epochs + 1):
|
||||
if opt['dist']:
|
||||
train_sampler.set_epoch(epoch)
|
||||
tq_ldr = tqdm(train_loader, position=trainer_id)
|
||||
#### random seed
|
||||
seed = opt['train']['manual_seed']
|
||||
if seed is None:
|
||||
seed = random.randint(1, 10000)
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Random seed: {}'.format(seed))
|
||||
util.set_random_seed(seed)
|
||||
|
||||
_t = time()
|
||||
_profile = False
|
||||
for train_data in tq_ldr:
|
||||
# Yielding supports multi-modal trainer which operates multiple train.py instances.
|
||||
yield model
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
if _profile:
|
||||
print("Data fetch: %f" % (time() - _t))
|
||||
_t = time()
|
||||
# Save the compiled opt dict to the global loaded_options variable.
|
||||
util.loaded_options = opt
|
||||
|
||||
current_step += 1
|
||||
if current_step > total_iters:
|
||||
break
|
||||
#### update learning rate
|
||||
model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
|
||||
#### create train and val dataloader
|
||||
dataset_ratio = 1 # enlarge the size of each epoch
|
||||
for phase, dataset_opt in opt['datasets'].items():
|
||||
if phase == 'train':
|
||||
self.train_set = create_dataset(dataset_opt)
|
||||
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
|
||||
total_iters = int(opt['train']['niter'])
|
||||
self.total_epochs = int(math.ceil(total_iters / train_size))
|
||||
if opt['dist']:
|
||||
train_sampler = DistIterSampler(self.train_set, world_size, self.rank, dataset_ratio)
|
||||
self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
||||
else:
|
||||
train_sampler = None
|
||||
self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, train_sampler)
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
||||
len(self.train_set), train_size))
|
||||
self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
||||
self.total_epochs, total_iters))
|
||||
elif phase == 'val':
|
||||
self.val_set = create_dataset(dataset_opt)
|
||||
self.val_loader = create_dataloader(self.val_set, dataset_opt, opt, None)
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Number of val images in [{:s}]: {:d}'.format(
|
||||
dataset_opt['name'], len(self.val_set)))
|
||||
else:
|
||||
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
|
||||
assert self.train_loader is not None
|
||||
|
||||
#### training
|
||||
if _profile:
|
||||
print("Update LR: %f" % (time() - _t))
|
||||
_t = time()
|
||||
model.feed_data(train_data)
|
||||
model.optimize_parameters(current_step)
|
||||
if _profile:
|
||||
print("Model feed + step: %f" % (time() - _t))
|
||||
_t = time()
|
||||
#### create model
|
||||
self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
|
||||
|
||||
#### log
|
||||
if current_step % opt['logger']['print_freq'] == 0 and rank <= 0:
|
||||
logs = model.get_current_log(current_step)
|
||||
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
||||
for v in model.get_current_learning_rate():
|
||||
message += '{:.3e},'.format(v)
|
||||
message += ')] '
|
||||
for k, v in logs.items():
|
||||
if 'histogram' in k:
|
||||
tb_logger.add_histogram(k, v, current_step)
|
||||
elif isinstance(v, dict):
|
||||
tb_logger.add_scalars(k, v, current_step)
|
||||
else:
|
||||
message += '{:s}: {:.4e} '.format(k, v)
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
tb_logger.add_scalar(k, v, current_step)
|
||||
logger.info(message)
|
||||
#### resume training
|
||||
if resume_state:
|
||||
self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
||||
resume_state['epoch'], resume_state['iter']))
|
||||
|
||||
#### save models and training states
|
||||
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||
if rank <= 0:
|
||||
logger.info('Saving models and training states.')
|
||||
model.save(current_step)
|
||||
model.save_training_state(epoch, current_step)
|
||||
if 'alt_path' in opt['path'].keys():
|
||||
import shutil
|
||||
print("Synchronizing tb_logger to alt_path..")
|
||||
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
||||
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
||||
shutil.copytree(tb_logger_path, alt_tblogger)
|
||||
self.start_epoch = resume_state['epoch']
|
||||
self.current_step = resume_state['iter']
|
||||
self.model.resume_training(resume_state, 'amp_opt_level' in opt.keys()) # handle optimizers and schedulers
|
||||
else:
|
||||
self.current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
|
||||
self.start_epoch = 0
|
||||
if 'force_start_step' in opt.keys():
|
||||
self.current_step = opt['force_start_step']
|
||||
|
||||
#### validation
|
||||
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
|
||||
avg_psnr = 0.
|
||||
avg_fea_loss = 0.
|
||||
idx = 0
|
||||
val_tqdm = tqdm(val_loader)
|
||||
for val_data in val_tqdm:
|
||||
idx += 1
|
||||
for b in range(len(val_data['LQ_path'])):
|
||||
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
|
||||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||
util.mkdir(img_dir)
|
||||
def do_step(self, train_data):
|
||||
if self._profile:
|
||||
print("Data fetch: %f" % (time() - _t))
|
||||
_t = time()
|
||||
|
||||
model.feed_data(val_data)
|
||||
model.test()
|
||||
opt = self.opt
|
||||
self.current_step += 1
|
||||
#### update learning rate
|
||||
self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])
|
||||
|
||||
visuals = model.get_current_visuals()
|
||||
if visuals is None:
|
||||
continue
|
||||
#### training
|
||||
if self._profile:
|
||||
print("Update LR: %f" % (time() - _t))
|
||||
_t = time()
|
||||
self.model.feed_data(train_data)
|
||||
self.model.optimize_parameters(self.current_step)
|
||||
if self._profile:
|
||||
print("Model feed + step: %f" % (time() - _t))
|
||||
_t = time()
|
||||
|
||||
# calculate PSNR
|
||||
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
|
||||
gt_img = util.tensor2img(visuals['GT'][b]) # uint8
|
||||
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||
|
||||
# calculate fea loss
|
||||
avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
|
||||
|
||||
# Save SR images for reference
|
||||
img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
|
||||
save_img_path = os.path.join(img_dir, img_base_name)
|
||||
util.save_img(sr_img, save_img_path)
|
||||
|
||||
avg_psnr = avg_psnr / idx
|
||||
avg_fea_loss = avg_fea_loss / idx
|
||||
|
||||
# log
|
||||
logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
|
||||
#### log
|
||||
if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
|
||||
logs = self.model.get_current_log(self.current_step)
|
||||
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, self.current_step)
|
||||
for v in self.model.get_current_learning_rate():
|
||||
message += '{:.3e},'.format(v)
|
||||
message += ')] '
|
||||
for k, v in logs.items():
|
||||
if 'histogram' in k:
|
||||
self.tb_logger.add_histogram(k, v, self.current_step)
|
||||
elif isinstance(v, dict):
|
||||
self.tb_logger.add_scalars(k, v, self.current_step)
|
||||
else:
|
||||
message += '{:s}: {:.4e} '.format(k, v)
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name'] and rank <= 0:
|
||||
tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
|
||||
tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
self.tb_logger.add_scalar(k, v, self.current_step)
|
||||
self.logger.info(message)
|
||||
|
||||
if rank <= 0:
|
||||
logger.info('Saving the final model.')
|
||||
model.save('latest')
|
||||
logger.info('End of training.')
|
||||
tb_logger.close()
|
||||
#### save models and training states
|
||||
if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Saving models and training states.')
|
||||
self.model.save(self.current_step)
|
||||
self.model.save_training_state(epoch, self.current_step)
|
||||
if 'alt_path' in opt['path'].keys():
|
||||
import shutil
|
||||
print("Synchronizing tb_logger to alt_path..")
|
||||
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
||||
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
||||
shutil.copytree(self.tb_logger_path, alt_tblogger)
|
||||
|
||||
#### validation
|
||||
if opt['datasets'].get('val', None) and self.current_step % opt['train']['val_freq'] == 0:
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan',
|
||||
'extensibletrainer'] and self.rank <= 0: # image restoration validation
|
||||
avg_psnr = 0.
|
||||
avg_fea_loss = 0.
|
||||
idx = 0
|
||||
val_tqdm = tqdm(self.val_loader)
|
||||
for val_data in val_tqdm:
|
||||
idx += 1
|
||||
for b in range(len(val_data['LQ_path'])):
|
||||
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
|
||||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||
util.mkdir(img_dir)
|
||||
|
||||
self.model.feed_data(val_data)
|
||||
self.model.test()
|
||||
|
||||
visuals = self.model.get_current_visuals()
|
||||
if visuals is None:
|
||||
continue
|
||||
|
||||
# calculate PSNR
|
||||
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
|
||||
gt_img = util.tensor2img(visuals['GT'][b]) # uint8
|
||||
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||
|
||||
# calculate fea loss
|
||||
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
|
||||
|
||||
# Save SR images for reference
|
||||
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
|
||||
save_img_path = os.path.join(img_dir, img_base_name)
|
||||
util.save_img(sr_img, save_img_path)
|
||||
|
||||
avg_psnr = avg_psnr / idx
|
||||
avg_fea_loss = avg_fea_loss / idx
|
||||
|
||||
# log
|
||||
self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
|
||||
self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
|
||||
self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
|
||||
|
||||
def do_training(self):
|
||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
||||
for epoch in range(self.start_epoch, self.total_epochs + 1):
|
||||
if opt['dist']:
|
||||
self.train_sampler.set_epoch(epoch)
|
||||
tq_ldr = tqdm(self.train_loader)
|
||||
|
||||
_t = time()
|
||||
for train_data in tq_ldr:
|
||||
self.do_step(train_data)
|
||||
|
||||
def create_training_generator(self, index):
|
||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
||||
for epoch in range(self.start_epoch, self.total_epochs + 1):
|
||||
if self.opt['dist']:
|
||||
self.train_sampler.set_epoch(epoch)
|
||||
tq_ldr = tqdm(self.train_loader, position=index)
|
||||
|
||||
_t = time()
|
||||
for train_data in tq_ldr:
|
||||
yield self.model
|
||||
self.do_step(train_data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -506,4 +280,6 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
main(opt, args.launcher)
|
||||
trainer = Trainer()
|
||||
trainer.init(opt, args.launcher)
|
||||
trainer.do_training()
|
||||
|
|
Loading…
Reference in New Issue
Block a user