From d95808f4ef5c0bd243707e8528e339211d07c980 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 24 Apr 2020 00:00:46 -0600 Subject: [PATCH] Implement downsample GAN This bad boy is for a workflow where you train a model on disjoint image sets to downsample a "good" set of images so that they look like a "bad" set of images. You then use that downsampler to generate a training set of paired images for supersampling. --- ...{GTLQ_dataset.py => Downsample_dataset.py} | 78 ++++++++----------- codes/models/SRGAN_model.py | 34 +++++--- codes/models/__init__.py | 2 +- codes/models/archs/HighToLowResNet.py | 63 +++++++++++++++ codes/models/archs/discriminator_vgg_arch.py | 2 +- codes/models/networks.py | 5 ++ codes/options/options.py | 6 +- codes/train.py | 4 +- 8 files changed, 129 insertions(+), 65 deletions(-) rename codes/data/{GTLQ_dataset.py => Downsample_dataset.py} (56%) create mode 100644 codes/models/archs/HighToLowResNet.py diff --git a/codes/data/GTLQ_dataset.py b/codes/data/Downsample_dataset.py similarity index 56% rename from codes/data/GTLQ_dataset.py rename to codes/data/Downsample_dataset.py index c9a47a4c..efe07242 100644 --- a/codes/data/GTLQ_dataset.py +++ b/codes/data/Downsample_dataset.py @@ -7,14 +7,14 @@ import torch.utils.data as data import data.util as util -class GTLQDataset(data.Dataset): +class DownsampleDataset(data.Dataset): """ - Reads unpaired high-resolution and low resolution images. Downsampled, LR images matching the provided high res - images are produced and fed to the downstream model, which can be used in a pixel loss. + Reads an unpaired HQ and LQ image. Clips both images to the expected input sizes of the model. Produces a + downsampled LQ image from the HQ image and feeds that as well. 
""" def __init__(self, opt): - super(GTLQDataset, self).__init__() + super(DownsampleDataset, self).__init__() self.opt = opt self.data_type = self.opt['data_type'] self.paths_LQ, self.paths_GT = None, None @@ -23,8 +23,11 @@ class GTLQDataset(data.Dataset): self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) + + self.data_sz_mismatch_ok = opt['mismatched_Data_OK'] assert self.paths_GT, 'Error: GT path is empty.' - if self.paths_LQ and self.paths_GT: + assert self.paths_LQ, 'LQ is required for downsampling.' + if not self.data_sz_mismatch_ok: assert len(self.paths_LQ) == len( self.paths_GT ), 'GT and LQ datasets have different number of images - {}, {}.'.format( @@ -41,9 +44,8 @@ class GTLQDataset(data.Dataset): def __getitem__(self, index): if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): self._init_lmdb() - GT_path, LQ_path = None, None scale = self.opt['scale'] - GT_size = self.opt['target_size'] + GT_size = self.opt['target_size'] * scale # get GT image GT_path = self.paths_GT[index] @@ -56,43 +58,19 @@ class GTLQDataset(data.Dataset): img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] # get LQ image - if self.paths_LQ: - LQ_path = self.paths_LQ[index] - resolution = [int(s) for s in self.sizes_LQ[index].split('_') - ] if self.data_type == 'lmdb' else None - img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) - else: # down-sampling on-the-fly - # randomly scale during training - if self.opt['phase'] == 'train': - random_scale = random.choice(self.random_scale_list) - H_s, W_s, _ = img_GT.shape + lqind = index % len(self.paths_LQ) + LQ_path = self.paths_LQ[index % len(self.paths_LQ)] + resolution = [int(s) for s in self.sizes_LQ[index].split('_') + ] if self.data_type == 'lmdb' else None + img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) - def _mod(n, random_scale, 
scale, thres): - rlt = int(n * random_scale) - rlt = (rlt // scale) * scale - return thres if rlt < thres else rlt - - H_s = _mod(H_s, random_scale, scale, GT_size) - W_s = _mod(W_s, random_scale, scale, GT_size) - img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR) - if img_GT.ndim == 2: - img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) - - H, W, _ = img_GT.shape - # using matlab imresize - img_LQ = util.imresize_np(img_GT, 1 / scale, True) - if img_LQ.ndim == 2: - img_LQ = np.expand_dims(img_LQ, axis=2) + # Create a downsampled version of the HQ image using matlab imresize. + img_Downsampled = util.imresize_np(img_GT, 1 / scale) + assert img_Downsampled.ndim == 3 if self.opt['phase'] == 'train': - # if the image size is too small H, W, _ = img_GT.shape - if H < GT_size or W < GT_size: - img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) - # using matlab imresize - img_LQ = util.imresize_np(img_GT, 1 / scale, True) - if img_LQ.ndim == 2: - img_LQ = np.expand_dims(img_LQ, axis=2) + assert H >= GT_size and W >= GT_size H, W, C = img_LQ.shape LQ_size = GT_size // scale @@ -101,27 +79,35 @@ class GTLQDataset(data.Dataset): rnd_h = random.randint(0, max(0, H - LQ_size)) rnd_w = random.randint(0, max(0, W - LQ_size)) img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] + img_Downsampled = img_Downsampled[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] # augmentation - flip, rotate - img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], + img_LQ, img_GT, img_Downsampled = util.augment([img_LQ, img_GT, img_Downsampled], self.opt['use_flip'], self.opt['use_rot']) if self.opt['color']: # change color space if necessary - img_LQ = util.channel_convert(C, self.opt['color'], - [img_LQ])[0] # TODO during val no definition + img_Downsampled = util.channel_convert(C, 
self.opt['color'], + [img_Downsampled])[0] # TODO during val no definition # BGR to RGB, HWC to CHW, numpy to tensor if img_GT.shape[2] == 3: img_GT = img_GT[:, :, [2, 1, 0]] img_LQ = img_LQ[:, :, [2, 1, 0]] + img_Downsampled = img_Downsampled[:, :, [2, 1, 0]] img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() + img_Downsampled = torch.from_numpy(np.ascontiguousarray(np.transpose(img_Downsampled, (2, 0, 1)))).float() img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() - if LQ_path is None: - LQ_path = GT_path - return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} + # This may seem really messed up, but let me explain: + # The goal is to re-use existing code as much as possible. SRGAN_model was coded to supersample, not downsample, + # but it can be retrofitted. To do so, we need to "trick" it. In this case the "input" is the HQ image and the + # "output" is the LQ image. SRGAN_model will be using a Generator and a Discriminator which already know this, + # we just need to trick its logic into following these rules. + # Do this by setting LQ(which is the input into the models)=img_GT and GT(which is the expected output)=img_LQ. + # PIX is used as a reference for the pixel loss. Use the manually downsampled image for this. 
+ return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path} def __len__(self): return len(self.paths_GT) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index a1419a89..30c4f7aa 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -55,6 +55,9 @@ class SRGANModel(BaseModel): else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] + self.l_fea_w_decay = train_opt['feature_weight_decay'] + self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps'] + self.l_fea_w_minimum = train_opt['feature_weight_minimum'] else: logger.info('Remove feature loss.') self.cri_fea = None @@ -143,13 +146,6 @@ class SRGANModel(BaseModel): self.pix = data['PIX'].to(self.device) def optimize_parameters(self, step): - - if step % 50 == 0: - for i in range(self.var_L.shape[0]): - utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i))) - utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i))) - utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i))) - # G for p in self.netD.parameters(): p.requires_grad = False @@ -157,6 +153,13 @@ class SRGANModel(BaseModel): self.optimizer_G.zero_grad() self.fake_H = self.netG(self.var_L) + if step % 50 == 0: + for i in range(self.var_L.shape[0]): + utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i))) + utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i))) + utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i))) + utils.save_image(self.fake_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\gen", "%05i_%02i.png" % (step, i))) + l_g_total = 0 if step % 
self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: # pixel loss @@ -168,6 +171,11 @@ class SRGANModel(BaseModel): l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_total += l_g_fea + # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role + # in the resultant image. + if step % self.l_fea_w_decay_steps == 0: + self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay) + if self.opt['train']['gan_type'] == 'gan': pred_g_fake = self.netD(self.fake_H) l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) @@ -193,7 +201,8 @@ class SRGANModel(BaseModel): # real pred_d_real = self.netD(self.var_ref) l_d_real = self.cri_gan(pred_d_real, True) - l_d_real.backward() + with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: + l_d_real_scaled.backward() # fake pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G l_d_fake = self.cri_gan(pred_d_fake, False) @@ -222,12 +231,13 @@ class SRGANModel(BaseModel): if self.cri_pix: self.log_dict['l_g_pix'] = l_g_pix.item() if self.cri_fea: + self.log_dict['feature_weight'] = self.l_fea_w self.log_dict['l_g_fea'] = l_g_fea.item() self.log_dict['l_g_gan'] = l_g_gan.item() - self.log_dict['l_g_total'] = l_g_total.item() - self.log_dict['l_d_real'] = l_d_real.item() - self.log_dict['l_d_fake'] = l_d_fake.item() - self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) + self.log_dict['l_g_total'] = l_g_total.item() + self.log_dict['l_d_real'] = l_d_real.item() + self.log_dict['l_d_fake'] = l_d_fake.item() + self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) def test(self): self.netG.eval() diff --git a/codes/models/__init__.py b/codes/models/__init__.py index c95004c9..0767eeb3 100644 --- a/codes/models/__init__.py +++ b/codes/models/__init__.py @@ -7,7 +7,7 @@ def create_model(opt): # image restoration if model == 'sr': # PSNR-oriented super resolution from .SR_model import SRModel as M - elif 
model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN + elif model == 'srgan' or model == 'corruptgan': # GAN-based super resolution (SRGAN / ESRGAN) or corruption; both use the same logic from .SRGAN_model import SRGANModel as M # video restoration elif model == 'video_base': diff --git a/codes/models/archs/HighToLowResNet.py b/codes/models/archs/HighToLowResNet.py new file mode 100644 index 00000000..470359f2 --- /dev/null +++ b/codes/models/archs/HighToLowResNet.py @@ -0,0 +1,63 @@ +import functools +import torch.nn as nn +import torch.nn.functional as F +import models.archs.arch_util as arch_util +import torch + + +class HighToLowResNet(nn.Module): + ''' ResNet that applies a noise channel to the input, then downsamples it. Currently only downscale=4 is supported. ''' + + def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, downscale=4): + super(HighToLowResNet, self).__init__() + self.downscale = downscale + + # We will always apply a noise channel to the inputs; account for that here. + in_nc += 1 + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) + basic_block2 = functools.partial(arch_util.ResidualBlock_noBN, nf=nf*2) + # To keep the total model size down, the residual trunks will be applied across 3 downsampling stages. + # The first will be applied against the hi-res inputs and will have only 4 layers. + # The second will be applied after half of the downscaling and will have only 6 layers. + # The final will be applied against the final resolution and will have all of the remaining layers. 
+ self.trunk_hires = arch_util.make_layer(basic_block, 4) + self.trunk_medres = arch_util.make_layer(basic_block, 6) + self.trunk_lores = arch_util.make_layer(basic_block2, nb - 10) + + # downsampling + if self.downscale == 4: + self.downconv1 = nn.Conv2d(nf, nf, 3, stride=2, padding=1, bias=True) + self.downconv2 = nn.Conv2d(nf, nf*2, 3, stride=2, padding=1, bias=True) + else: + raise EnvironmentError("Requested downscale not supported: %i" % (downscale,)) + + self.HRconv = nn.Conv2d(nf*2, nf*2, 3, stride=1, padding=1, bias=True) + self.conv_last = nn.Conv2d(nf*2, out_nc, 3, stride=1, padding=1, bias=True) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + # initialization + arch_util.initialize_weights([self.conv_first, self.HRconv, self.conv_last, self.downconv1, self.downconv2], + 0.1) + + def forward(self, x): + # Noise has the same shape as the input with only one channel. + rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device) + out = torch.cat([x, rand_feature], dim=1) + + out = self.lrelu(self.conv_first(out)) + out = self.trunk_hires(out) + + if self.downscale == 4: + out = self.lrelu(self.downconv1(out)) + out = self.trunk_medres(out) + out = self.lrelu(self.downconv2(out)) + out = self.trunk_lores(out) + + out = self.conv_last(self.lrelu(self.HRconv(out))) + base = F.interpolate(x, scale_factor=1/self.downscale, mode='bilinear', align_corners=False) + out += base + return out diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index ae51ba16..10a3ccdc 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -32,7 +32,7 @@ class Discriminator_VGG_128(nn.Module): self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) - self.linear1 = nn.Linear(int(512 * 4 * input_img_factor * 4 * input_img_factor), 100) + self.linear1 = 
nn.Linear(int(nf * 8 * 4 * input_img_factor * 4 * input_img_factor), 100) self.linear2 = nn.Linear(100, 1) # activation function diff --git a/codes/models/networks.py b/codes/models/networks.py index 2b79249b..1b7563dc 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -3,6 +3,7 @@ import models.archs.SRResNet_arch as SRResNet_arch import models.archs.discriminator_vgg_arch as SRGAN_arch import models.archs.RRDBNet_arch as RRDBNet_arch import models.archs.EDVR_arch as EDVR_arch +import models.archs.HighToLowResNet as HighToLowResNet import math # Generator @@ -20,6 +21,10 @@ def define_G(opt): scale_per_step = math.sqrt(scale) netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step) + # image corruption + elif which_model == 'HighToLowResNet': + netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], + nf=opt_net['nf'], nb=opt_net['nb'], downscale=opt_net['scale']) # video restoration elif which_model == 'EDVR': netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'], diff --git a/codes/options/options.py b/codes/options/options.py index 99181b34..5dc34b11 100644 --- a/codes/options/options.py +++ b/codes/options/options.py @@ -15,14 +15,14 @@ def parse(opt_path, is_train=True): print('export CUDA_VISIBLE_DEVICES=' + gpu_list) opt['is_train'] = is_train - if opt['distortion'] == 'sr': + if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample': scale = opt['scale'] # datasets for phase, dataset in opt['datasets'].items(): phase = phase.split('_')[0] dataset['phase'] = phase - if opt['distortion'] == 'sr': + if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample': dataset['scale'] = scale is_lmdb = False if dataset.get('dataroot_GT', None) is not None: @@ -62,7 +62,7 @@ def parse(opt_path, is_train=True): opt['path']['log'] = results_root # network - if opt['distortion'] == 'sr': + if 
opt['distortion'] == 'sr' or opt['distortion'] == 'downsample': opt['network_G']['scale'] = scale return opt diff --git a/codes/train.py b/codes/train.py index 8721d33f..2906463a 100644 --- a/codes/train.py +++ b/codes/train.py @@ -29,7 +29,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_ESRGAN_blacked.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_corruptGAN_adrianna.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -176,7 +176,7 @@ def main(): logger.info(message) #### validation if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: - if opt['model'] in ['sr', 'srgan'] and rank <= 0: # image restoration validation + if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0: # image restoration validation # does not support multi-GPU validation pbar = util.ProgressBar(len(val_loader)) avg_psnr = 0.