diff --git a/codes/data_scripts/validate_data.py b/codes/data_scripts/validate_data.py new file mode 100644 index 00000000..ac7684c0 --- /dev/null +++ b/codes/data_scripts/validate_data.py @@ -0,0 +1,66 @@ +# This script iterates through all the data with no worker threads and performs whatever transformations are prescribed. +# The idea is to find bad/corrupt images. + +import math +import argparse +import random +import torch +import options.options as option +from utils import util +from data import create_dataloader, create_dataset +from time import time +from tqdm import tqdm +from skimage import io + +def main(): + #### options + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_mi1_spsr_switched2.yml') + parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + opt = option.parse(args.opt, is_train=True) + + #### distributed training settings + opt['dist'] = False + rank = -1 + + # convert to NoneDict, which returns None for missing keys + opt = option.dict_to_nonedict(opt) + + #### random seed + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + util.set_random_seed(seed) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + #### create train and val dataloader + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = create_dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) + total_iters = int(opt['train']['niter']) + total_epochs = int(math.ceil(total_iters / train_size)) + dataset_opt['n_workers'] = 0 # Force num_workers=0 to make dataloader work in process. + train_loader = create_dataloader(train_set, dataset_opt, opt, None) + if rank <= 0: + print('Number of train images: {:,d}, iters: {:,d}'.format( + len(train_set), train_size)) + assert train_loader is not None + + tq_ldr = tqdm(train_set.paths_GT) + for path in tq_ldr: + try: + _ = io.imread(path) + # Do stuff with img + except Exception as e: + print("Error with %s" % (path,)) + print(e) + + +if __name__ == '__main__': + main() diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index fef8c85a..4e2ec4c0 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -1,22 +1,18 @@ import logging -from collections import OrderedDict -import torch -import torch.nn as nn -from torch.nn.parallel import DataParallel, DistributedDataParallel -import models.networks as networks -from models.steps.steps import create_step -import models.lr_scheduler as lr_scheduler -from models.base_model import BaseModel -from models.loss import GANLoss, FDPLLoss -from apex import amp -from data.weight_scheduler import get_scheduler_for_opt -from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding -import torch.nn.functional as F -import glob -import random - -import torchvision.utils as utils import os +import random +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision.utils as utils +from apex import amp +from torch.nn.parallel import DataParallel, DistributedDataParallel + +import models.lr_scheduler as lr_scheduler +import models.networks as networks +from models.base_model import BaseModel +from models.steps.steps import ConfigurableStep logger = logging.getLogger('base') @@ -31,15 +27,20 @@ class ExtensibleTrainer(BaseModel): train_opt = opt['train'] self.mega_batch_factor = 1 + # env is used as a global state to store things that subcomponents might need. + env = {'device': self.device, + 'rank': self.rank, + 'opt': opt} + self.netsG = {} self.netsD = {} self.networks = [] for name, net in opt['networks'].items(): if net['type'] == 'generator': - new_net = networks.define_G(net) + new_net = networks.define_G(net, None, opt['scale']).to(self.device) self.netsG[name] = new_net elif net['type'] == 'discriminator': - new_net = networks.define_D(net) + new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device) self.netsD[name] = new_net else: raise NotImplementedError("Can only handle generators and discriminators") @@ -51,7 +52,7 @@ class ExtensibleTrainer(BaseModel): self.mega_batch_factor = 1 # Initialize amp. - amp_nets, amp_opts = amp.initialize(self.networks, self.optimizers, opt_level=opt['amp_level'], num_losses=len(self.optimizers)) + amp_nets, amp_opts = amp.initialize(self.networks, self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps'])) # self.networks is stored unwrapped. It should never be used for forward() or backward() passes, instead use # self.netG and self.netD for that. self.networks = amp_nets @@ -76,15 +77,18 @@ class ExtensibleTrainer(BaseModel): for dnet in dnets: for net_dict in [self.netsD, self.netsG]: for k, v in net_dict.items(): - if v == dnet: + if v == dnet.module: net_dict[k] = dnet found += 1 assert found == len(self.networks) + env['generators'] = self.netsG + env['discriminators'] = self.netsD + # Initialize the training steps self.steps = [] - for step in opt['steps']: - step = create_step(step, self.netsG, self.netsD) + for step_name, step in opt['steps'].items(): + step = ConfigurableStep(step, env) self.steps.append(step) self.optimizers.extend(step.get_optimizers()) @@ -113,8 +117,8 @@ class ExtensibleTrainer(BaseModel): net.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) # Iterate through the steps, performing them one at a time. - state = {'lr': self.var_L, 'hr': self.var_H, 'ref': self.var_ref} - for s in self.steps: + state = {'lq': self.var_L, 'hq': self.var_H, 'ref': self.var_ref} + for step_num, s in enumerate(self.steps): # Only set requires_grad=True for the network being trained. nets_to_train = s.get_networks_trained() for name, net in self.networks.items(): @@ -126,8 +130,20 @@ class ExtensibleTrainer(BaseModel): p.requires_grad = False # Now do a forward and backward pass for each gradient accumulation step. + new_states = {} for m in range(self.mega_batch_factor): - state = s.do_forward_backward(state, m) + ns = s.do_forward_backward(state, m, step_num) + for k, v in ns.items(): + if k not in new_states.keys(): + new_states[k] = [v.detach()] + else: + new_states[k].append(v.detach()) + + # Push the detached new state tensors into the state map for use with the next step. + for k, v in new_states.items(): + # Overwriting existing state keys is not supported. + assert k not in state.keys() + state[k] = v # And finally perform optimization. s.do_step() diff --git a/codes/models/__init__.py b/codes/models/__init__.py index 4cb3264a..26f0e1fa 100644 --- a/codes/models/__init__.py +++ b/codes/models/__init__.py @@ -13,6 +13,8 @@ def create_model(opt): from .feature_model import FeatureModel as M elif model == 'spsr': from .SPSR_model import SPSRModel as M + elif model == 'extensibletrainer': + from .ExtensibleTrainer import ExtensibleTrainer as M else: raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) m = M(opt) diff --git a/codes/models/networks.py b/codes/models/networks.py index 6cfd2089..dc761243 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -17,10 +17,14 @@ import functools from collections import OrderedDict # Generator -def define_G(opt, net_key='network_G'): - opt_net = opt[net_key] +def define_G(opt, net_key='network_G', scale=None): + if net_key is not None: + opt_net = opt[net_key] + else: + opt_net = opt + if scale is None: + scale = opt['scale'] which_model = opt_net['which_model_G'] - scale = opt['scale'] # image restoration if which_model == 'MSRResNet': diff --git a/codes/models/steps/__init__.py b/codes/models/steps/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py new file mode 100644 index 00000000..f5954087 --- /dev/null +++ b/codes/models/steps/injectors.py @@ -0,0 +1,32 @@ +import torch.nn +from models.archs.SPSR_arch import ImageGradientNoPadding + +# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions. +def create_injector(opt_inject, env): + type = opt_inject['type'] + if type == 'img_grad': + return ImageGradientInjector(opt_inject, env) + else: + raise NotImplementedError + + +class Injector(torch.nn.Module): + def __init__(self, opt, env): + super(self, Injector).__init__() + self.opt = opt + self.env = env + self.input = opt['in'] + self.output = opt['out'] + + # This should return a dict of new state variables. + def forward(self, state): + raise NotImplementedError + + +class ImageGradientInjector(Injector): + def __init__(self, opt, env): + super(self, ImageGradientInjector).__init__(opt, env) + self.img_grad_fn = ImageGradientNoPadding() + + def forward(self, state): + return {self.opt['out']: self.img_grad_fn(state[self.opt['in']])} \ No newline at end of file diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py new file mode 100644 index 00000000..3f80978d --- /dev/null +++ b/codes/models/steps/losses.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn +from models.networks import define_F +from models.loss import GANLoss + + +def create_generator_loss(opt_loss, env): + type = opt_loss['type'] + if type == 'pix': + return PixLoss(opt_loss, env) + elif type == 'feature': + return FeatureLoss(opt_loss, env) + elif type == 'generator_gan': + return GeneratorGanLoss(opt_loss, env) + elif type == 'discriminator_gan': + return DiscriminatorGanLoss(opt_loss, env) + else: + raise NotImplementedError + + +class ConfigurableLoss(nn.Module): + def __init__(self, opt, env): + super(self, ConfigurableLoss).__init__() + self.opt = opt + self.env = env + + def forward(self, net, state): + raise NotImplementedError + + +def get_basic_criterion_for_name(name, device): + if name == 'l1': + return nn.L1Loss(device=device) + elif name == 'l2': + return nn.MSELoss(device=device) + else: + raise NotImplementedError + + +class PixLoss(ConfigurableLoss): + def __init__(self, opt, env): + super(self, PixLoss).__init__(opt, env) + self.opt = opt + self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device']) + + def forward(self, net, state): + return self.criterion(state[self.opt['fake']], state[self.opt['real']]) + + +class FeatureLoss(ConfigurableLoss): + def __init__(self, opt, env): + super(self, FeatureLoss).__init__(opt, env) + self.opt = opt + self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device']) + self.netF = define_F(opt).to(self.env['device']) + + def forward(self, net, state): + with torch.no_grad(): + logits_real = self.netF(state[self.opt['real']]) + logits_fake = self.netF(state[self.opt['fake']]) + return self.criterion(logits_fake, logits_real) + + +class GeneratorGanLoss(ConfigurableLoss): + def __init__(self, opt, env): + super(self, GeneratorGanLoss).__init__(opt, env) + self.opt = opt + self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device']) + self.netD = env['discriminators'][opt['discriminator']] + + def forward(self, net, state): + if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: + if self.opt['gan_type'] == 'crossgan': + pred_g_fake = self.netD(state[self.opt['fake']], state['lq']) + else: + pred_g_fake = self.netD(state[self.opt['fake']]) + return self.criterion(pred_g_fake, True) + elif self.opt['gan_type'] == 'ragan': + pred_d_real = self.netD(state[self.opt['real']]).detach() + pred_g_fake = self.netD(state[self.opt['fake']]) + return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 + else: + raise NotImplementedError + + +class DiscriminatorGanLoss(ConfigurableLoss): + def __init__(self, opt, env): + super(self, DiscriminatorGanLoss).__init__(opt, env) + self.opt = opt + self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device']) + + def forward(self, net, state): + if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']: + if self.opt['gan_type'] == 'crossgan': + pred_g_fake = net(state[self.opt['fake']].detach(), state['lq']) + else: + pred_g_fake = net(state[self.opt['fake']].detach()) + return self.criterion(pred_g_fake, False) + elif self.opt['gan_type'] == 'ragan': + pred_d_real = self.netD(state[self.opt['real']]) + pred_g_fake = self.netD(state[self.opt['fake']].detach()) + return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), True) + + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), False)) / 2 + else: + raise NotImplementedError diff --git a/codes/models/steps/losses/generator_losses.py b/codes/models/steps/losses/generator_losses.py deleted file mode 100644 index 5ae088f7..00000000 --- a/codes/models/steps/losses/generator_losses.py +++ /dev/null @@ -1,9 +0,0 @@ -def create_generator_loss(opt_loss): - pass - - -class GeneratorLoss: - def __init__(self, opt): - self.opt = opt - - def get_loss(self, var_L, var_H, var_Gen, extras=None): \ No newline at end of file diff --git a/codes/models/steps/srgan_generator_step.py b/codes/models/steps/srgan_generator_step.py deleted file mode 100644 index 4d7b58ca..00000000 --- a/codes/models/steps/srgan_generator_step.py +++ /dev/null @@ -1,46 +0,0 @@ -# Defines the expected API for a step -class SrGanGeneratorStep: - - def __init__(self, opt_step, opt, netsG, netsD): - self.step_opt = opt_step - self.opt = opt - self.gen = netsG['base'] - self.disc = netsD['base'] - for loss in self.step_opt['losses']: - - # G pixel loss - if train_opt['pixel_weight'] > 0: - l_pix_type = train_opt['pixel_criterion'] - if l_pix_type == 'l1': - self.cri_pix = nn.L1Loss().to(self.device) - elif l_pix_type == 'l2': - self.cri_pix = nn.MSELoss().to(self.device) - else: - raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) - self.l_pix_w = train_opt['pixel_weight'] - else: - logger.info('Remove pixel loss.') - self.cri_pix = None - - - # Returns all optimizers used in this step. - def get_optimizers(self): - pass - - # Returns optimizers which are opting in for default LR scheduling. - def get_optimizers_with_default_scheduler(self): - pass - - # Returns the names of the networks this step will train. Other networks will be frozen. - def get_networks_trained(self): - pass - - # Performs all forward and backward passes for this step given an input state. All input states are lists or - # chunked tensors. Use grad_accum_step to derefernce these steps. Return the state with any variables the step - # exports (which may be used by subsequent steps) - def do_forward_backward(self, state, grad_accum_step): - return state - - # Performs the optimizer step after all gradient accumulation is completed. - def do_step(self): - pass \ No newline at end of file diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index bcd6f2a2..e71d3164 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -1,29 +1,117 @@ +from utils.loss_accumulator import LossAccumulator +from torch.nn import Module +import logging +from models.steps.losses import create_generator_loss +import torch +from apex import amp +from collections import OrderedDict +from .injectors import create_injector + +logger = logging.getLogger('base') -def create_step(opt, opt_step, netsG, netsD): - pass +# Defines the expected API for a single training step +class ConfigurableStep(Module): + def __init__(self, opt_step, env): + super(ConfigurableStep, self).__init__() + + self.step_opt = opt_step + self.env = env + self.opt = env['opt'] + self.gen = env['generators'][opt_step['generator']] + self.discs = env['discriminators'] + self.gen_outputs = opt_step['generator_outputs'] + self.training_net = env['generators'][opt_step['training']] if opt_step['training'] in env['generators'].keys() else env['discriminators'][opt_step['training']] + self.loss_accumulator = LossAccumulator() + + self.injectors = [] + if 'injectors' in self.step_opt.keys(): + for inj_name, injector in self.step_opt['injectors'].items(): + self.injectors.append(create_injector(injector, env)) + + losses = [] + self.weights = {} + for loss_name, loss in self.step_opt['losses'].items(): + losses.append((loss_name, create_generator_loss(loss, env))) + self.weights[loss_name] = loss['weight'] + self.losses = OrderedDict(losses) + + # Intentionally abstract so subclasses can have alternative optimizers. + self.define_optimizers() + + # Subclasses should override this to define individual optimizers. They should all go into self.optimizers. + # This default implementation defines a single optimizer for all Generator parameters. + def define_optimizers(self): + optim_params = [] + for k, v in self.training_net.named_parameters(): # can optimize for a part of the model + if v.requires_grad: + optim_params.append(v) + else: + if self.env['rank'] <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'], + weight_decay=self.step_opt['weight_decay'], + betas=(self.step_opt['beta1'], self.step_opt['beta2'])) + self.optimizers = [opt] -# Defines the expected API for a step -class base_step: # Returns all optimizers used in this step. def get_optimizers(self): - pass + assert self.optimizers is not None + return self.optimizers # Returns optimizers which are opting in for default LR scheduling. def get_optimizers_with_default_scheduler(self): - pass + assert self.optimizers is not None + return self.optimizers # Returns the names of the networks this step will train. Other networks will be frozen. def get_networks_trained(self): - pass + return [self.step_opt['training']] - # Performs all forward and backward passes for this step given an input state. All input states are lists or - # chunked tensors. Use grad_accum_step to derefernce these steps. Return the state with any variables the step - # exports (which may be used by subsequent steps) - def do_forward_backward(self, state, grad_accum_step): - return state + # Performs all forward and backward passes for this step given an input state. All input states are lists of + # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later + # steps might use. These tensors are automatically detached and accumulated into chunks. + def do_forward_backward(self, state, grad_accum_step, amp_loss_id): + # First, do a forward pass with the generator. + results = self.gen(state[self.step_opt['generator_input']][grad_accum_step]) + # Extract the resultants into a "new_state" dict per the configuration. + new_state = {} + for i, gen_out in enumerate(self.gen_outputs): + new_state[gen_out] = results[i] - # Performs the optimizer step after all gradient accumulation is completed. + # Prepare a de-chunked state dict which will be used for the injectors & losses. + local_state = {} + for k, v in state.items(): + local_state[k] = v[grad_accum_step] + local_state.update(new_state) + + # Inject in any extra dependencies. + for inj in self.injectors: + injected = inj(local_state) + local_state.update(injected) + new_state.update(injected) + + # Finally, compute the losses. + total_loss = 0 + for loss_name, loss in self.losses.items(): + l = loss(self.training_net, local_state) + self.loss_accumulator.add_loss(loss_name, l) + total_loss += l * self.weights[loss_name] + self.loss_accumulator.add_loss("total", total_loss) + + # Get dem grads! + with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss: + scaled_loss.backward() + + return new_state + + + # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps() + # all self.optimizers. def do_step(self): - pass \ No newline at end of file + for opt in self.optimizers: + opt.step() + + def get_metrics(self): + return self.loss_accumulator.as_dict() \ No newline at end of file diff --git a/codes/train2.py b/codes/train2.py new file mode 100644 index 00000000..5f24dc15 --- /dev/null +++ b/codes/train2.py @@ -0,0 +1,289 @@ +import os +import math +import argparse +import random +import logging +import shutil +from tqdm import tqdm + +import torch +from data.data_sampler import DistIterSampler + +import options.options as option +from utils import util +from data import create_dataloader, create_dataset +from models import create_model +from time import time + + +def init_dist(backend='nccl', **kwargs): + # These packages have globals that screw with Windows, so only import them if needed. + import torch.distributed as dist + import torch.multiprocessing as mp + + """initialization for distributed training""" + if mp.get_start_method(allow_none=True) != 'spawn': + mp.set_start_method('spawn') + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + +def main(): + #### options + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_nt_spsr_switched.yml') + parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + opt = option.parse(args.opt, is_train=True) + + colab_mode = False if 'colab_mode' not in opt.keys() else opt['colab_mode'] + if colab_mode: + # Check the configuration of the remote server. Expect models, resume_state, and val_images directories to be there. + # Each one should have a TEST file in it. + util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], + os.path.join(opt['remote_path'], 'training_state', "TEST")) + util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], + os.path.join(opt['remote_path'], 'models', "TEST")) + util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], + os.path.join(opt['remote_path'], 'val_images', "TEST")) + # Load the state and models needed from the remote server. + if opt['path']['resume_state']: + util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'training_state', opt['path']['resume_state'])) + if opt['path']['pretrain_model_G']: + util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_G'])) + if opt['path']['pretrain_model_D']: + util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_D'])) + + #### distributed training settings + if args.launcher == 'none': # disabled distributed training + opt['dist'] = False + rank = -1 + print('Disabled distributed training.') + else: + opt['dist'] = True + init_dist() + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + #### loading resume state if exists + if opt['path'].get('resume_state', None): + # distributed resuming: all load into default GPU + device_id = torch.cuda.current_device() + resume_state = torch.load(opt['path']['resume_state'], + map_location=lambda storage, loc: storage.cuda(device_id)) + option.check_resume(opt, resume_state['iter']) # check resume options + else: + resume_state = None + + #### mkdir and loggers + if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) + if resume_state is None: + util.mkdir_and_rename( + opt['path']['experiments_root']) # rename experiment folder if exists + util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' + and 'pretrain_model' not in key and 'resume' not in key)) + + # config loggers. Before it, the log will not work + util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) + logger = logging.getLogger('base') + logger.info(option.dict2str(opt)) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger') + version = float(torch.__version__[0:3]) + if version >= 1.1: # PyTorch 1.1 + from torch.utils.tensorboard import SummaryWriter + else: + logger.info( + 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version)) + from tensorboardX import SummaryWriter + tb_logger = SummaryWriter(log_dir=tb_logger_path) + else: + util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) + logger = logging.getLogger('base') + + # convert to NoneDict, which returns None for missing keys + opt = option.dict_to_nonedict(opt) + + #### random seed + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + if rank <= 0: + logger.info('Random seed: {}'.format(seed)) + util.set_random_seed(seed) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + #### create train and val dataloader + dataset_ratio = 200 # enlarge the size of each epoch + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = create_dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) + total_iters = int(opt['train']['niter']) + total_epochs = int(math.ceil(total_iters / train_size)) + if opt['dist']: + train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) + total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) + else: + train_sampler = None + train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) + if rank <= 0: + logger.info('Number of train images: {:,d}, iters: {:,d}'.format( + len(train_set), train_size)) + logger.info('Total epochs needed: {:d} for iters {:,d}'.format( + total_epochs, total_iters)) + elif phase == 'val': + val_set = create_dataset(dataset_opt) + val_loader = create_dataloader(val_set, dataset_opt, opt, None) + if rank <= 0: + logger.info('Number of val images in [{:s}]: {:d}'.format( + dataset_opt['name'], len(val_set))) + else: + raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) + assert train_loader is not None + + #### create model + model = create_model(opt) + + #### resume training + if resume_state: + logger.info('Resuming training from epoch: {}, iter: {}.'.format( + resume_state['epoch'], resume_state['iter'])) + + start_epoch = resume_state['epoch'] + current_step = resume_state['iter'] + model.resume_training(resume_state) # handle optimizers and schedulers + else: + current_step = -1 if 'start_step' not in opt.keys() else opt['start_step'] + start_epoch = 0 + + #### training + logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) + for epoch in range(start_epoch, total_epochs + 1): + if opt['dist']: + train_sampler.set_epoch(epoch) + tq_ldr = tqdm(train_loader) + + _t = time() + _profile = False + for _, train_data in enumerate(tq_ldr): + if _profile: + print("Data fetch: %f" % (time() - _t)) + _t = time() + + current_step += 1 + if current_step > total_iters: + break + #### update learning rate + model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) + + #### training + if _profile: + print("Update LR: %f" % (time() - _t)) + _t = time() + model.feed_data(train_data) + model.optimize_parameters(current_step) + if _profile: + print("Model feed + step: %f" % (time() - _t)) + _t = time() + + #### log + if current_step % opt['logger']['print_freq'] == 0: + logs = model.get_current_log(current_step) + message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step) + for v in model.get_current_learning_rate(): + message += '{:.3e},'.format(v) + message += ')] ' + for k, v in logs.items(): + if 'histogram' in k: + if rank <= 0: + tb_logger.add_histogram(k, v, current_step) + else: + message += '{:s}: {:.4e} '.format(k, v) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + if rank <= 0: + tb_logger.add_scalar(k, v, current_step) + if rank <= 0: + logger.info(message) + #### validation + if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: + if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0: # image restoration validation + model.force_restore_swapout() + val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size'] + # does not support multi-GPU validation + pbar = util.ProgressBar(len(val_loader) * val_batch_sz) + avg_psnr = 0. + avg_fea_loss = 0. + idx = 0 + colab_imgs_to_copy = [] + for val_data in val_loader: + idx += 1 + for b in range(len(val_data['LQ_path'])): + img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0] + img_dir = os.path.join(opt['path']['val_images'], img_name) + util.mkdir(img_dir) + + model.feed_data(val_data) + model.test() + + visuals = model.get_current_visuals() + if visuals is None: + continue + + sr_img = util.tensor2img(visuals['rlt'][b]) # uint8 + #gt_img = util.tensor2img(visuals['GT'][b]) # uint8 + + # Save SR images for reference + img_base_name = '{:s}_{:d}.png'.format(img_name, current_step) + save_img_path = os.path.join(img_dir, img_base_name) + util.save_img(sr_img, save_img_path) + if colab_mode: + colab_imgs_to_copy.append(save_img_path) + + # calculate PSNR (Naw - don't do that. PSNR sucks) + #sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) + #avg_psnr += util.calculate_psnr(sr_img, gt_img) + #pbar.update('Test {}'.format(img_name)) + + # calculate fea loss + avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b]) + + if colab_mode: + util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], + colab_imgs_to_copy, + os.path.join(opt['remote_path'], 'val_images', img_base_name)) + + avg_psnr = avg_psnr / idx + avg_fea_loss = avg_fea_loss / idx + + # log + logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss)) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + #tb_logger.add_scalar('val_psnr', avg_psnr, current_step) + tb_logger.add_scalar('val_fea', avg_fea_loss, current_step) + + #### save models and training states + if current_step % opt['logger']['save_checkpoint_freq'] == 0: + if rank <= 0: + logger.info('Saving models and training states.') + model.save(current_step) + model.save_training_state(epoch, current_step) + + if rank <= 0: + logger.info('Saving the final model.') + model.save('latest') + logger.info('End of training.') + tb_logger.close() + + +if __name__ == '__main__': + main() diff --git a/codes/utils/loss_accumulator.py b/codes/utils/loss_accumulator.py new file mode 100644 index 00000000..1f0e151a --- /dev/null +++ b/codes/utils/loss_accumulator.py @@ -0,0 +1,20 @@ +import torch + +# Utility class that stores detached, named losses in a rotating buffer for smooth metric outputting. +class LossAccumulator: + def __init__(self, buffer_sz=10): + self.buffer_sz = buffer_sz + self.buffers = {} + + def add_loss(self, name, tensor): + if name not in self.buffers.keys(): + self.buffers[name] = (0, torch.zeros(self.buffer_sz)) + i, buf = self.buffers[name] + buf[i] = tensor.detach().cpu() + self.buffers[name] = ((i+1) % self.buffer_sz, buf) + + def as_dict(self): + result = {} + for k, v in self.buffers: + result["loss_" + k] = torch.mean(v) + return result \ No newline at end of file