Implement downsample GAN

This bad boy is for a workflow where you train a model on disjoint image sets to
downsample a "good" set of images like a "bad" set of images looks. You then
use that downsampler to generate a training set of paired images for supersampling.
This commit is contained in:
James Betker 2020-04-24 00:00:46 -06:00
parent ea54c7618a
commit d95808f4ef
8 changed files with 129 additions and 65 deletions

View File

@ -7,14 +7,14 @@ import torch.utils.data as data
import data.util as util
class GTLQDataset(data.Dataset):
class DownsampleDataset(data.Dataset):
"""
Reads unpaired high-resolution and low resolution images. Downsampled, LR images matching the provided high res
images are produced and fed to the downstream model, which can be used in a pixel loss.
Reads an unpaired HQ and LQ image. Clips both images to the expected input sizes of the model. Produces a
downsampled LQ image from the HQ image and feeds that as well.
"""
def __init__(self, opt):
super(GTLQDataset, self).__init__()
super(DownsampleDataset, self).__init__()
self.opt = opt
self.data_type = self.opt['data_type']
self.paths_LQ, self.paths_GT = None, None
@ -23,8 +23,11 @@ class GTLQDataset(data.Dataset):
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
self.data_sz_mismatch_ok = opt['mismatched_Data_OK']
assert self.paths_GT, 'Error: GT path is empty.'
if self.paths_LQ and self.paths_GT:
assert self.paths_LQ, 'LQ is required for downsampling.'
if not self.data_sz_mismatch_ok:
assert len(self.paths_LQ) == len(
self.paths_GT
), 'GT and LQ datasets have different number of images - {}, {}.'.format(
@ -41,9 +44,8 @@ class GTLQDataset(data.Dataset):
def __getitem__(self, index):
if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
self._init_lmdb()
GT_path, LQ_path = None, None
scale = self.opt['scale']
GT_size = self.opt['target_size']
GT_size = self.opt['target_size'] * scale
# get GT image
GT_path = self.paths_GT[index]
@ -56,43 +58,19 @@ class GTLQDataset(data.Dataset):
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
# get LQ image
if self.paths_LQ:
LQ_path = self.paths_LQ[index]
lqind = index % len(self.paths_LQ)
LQ_path = self.paths_LQ[index % len(self.paths_LQ)]
resolution = [int(s) for s in self.sizes_LQ[index].split('_')
] if self.data_type == 'lmdb' else None
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
else: # down-sampling on-the-fly
# randomly scale during training
if self.opt['phase'] == 'train':
random_scale = random.choice(self.random_scale_list)
H_s, W_s, _ = img_GT.shape
def _mod(n, random_scale, scale, thres):
rlt = int(n * random_scale)
rlt = (rlt // scale) * scale
return thres if rlt < thres else rlt
H_s = _mod(H_s, random_scale, scale, GT_size)
W_s = _mod(W_s, random_scale, scale, GT_size)
img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
if img_GT.ndim == 2:
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
H, W, _ = img_GT.shape
# using matlab imresize
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)
# Create a downsampled version of the HQ image using matlab imresize.
img_Downsampled = util.imresize_np(img_GT, 1 / scale)
assert img_Downsampled.ndim == 3
if self.opt['phase'] == 'train':
# if the image size is too small
H, W, _ = img_GT.shape
if H < GT_size or W < GT_size:
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
# using matlab imresize
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)
assert H >= GT_size and W >= GT_size
H, W, C = img_LQ.shape
LQ_size = GT_size // scale
@ -101,27 +79,35 @@ class GTLQDataset(data.Dataset):
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
img_Downsampled = img_Downsampled[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
# augmentation - flip, rotate
img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
img_LQ, img_GT, img_Downsampled = util.augment([img_LQ, img_GT, img_Downsampled], self.opt['use_flip'],
self.opt['use_rot'])
if self.opt['color']: # change color space if necessary
img_LQ = util.channel_convert(C, self.opt['color'],
[img_LQ])[0] # TODO during val no definition
img_Downsampled = util.channel_convert(C, self.opt['color'],
[img_Downsampled])[0] # TODO during val no definition
# BGR to RGB, HWC to CHW, numpy to tensor
if img_GT.shape[2] == 3:
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQ = img_LQ[:, :, [2, 1, 0]]
img_Downsampled = img_Downsampled[:, :, [2, 1, 0]]
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
img_Downsampled = torch.from_numpy(np.ascontiguousarray(np.transpose(img_Downsampled, (2, 0, 1)))).float()
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
if LQ_path is None:
LQ_path = GT_path
return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
# This may seem really messed up, but let me explain:
# The goal is to re-use existing code as much as possible. SRGAN_model was coded to supersample, not downsample,
# but it can be retrofitted. To do so, we need to "trick" it. In this case the "input" is the HQ image and the
# "output" is the LQ image. SRGAN_model will be using a Generator and a Discriminator which already know this,
# we just need to trick its logic into following this rules.
# Do this by setting LQ(which is the input into the models)=img_GT and GT(which is the expected outpuut)=img_LQ.
# PIX is used as a reference for the pixel loss. Use the manually downsampled image for this.
return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}
def __len__(self):
return len(self.paths_GT)

View File

@ -55,6 +55,9 @@ class SRGANModel(BaseModel):
else:
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
self.l_fea_w = train_opt['feature_weight']
self.l_fea_w_decay = train_opt['feature_weight_decay']
self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps']
self.l_fea_w_minimum = train_opt['feature_weight_minimum']
else:
logger.info('Remove feature loss.')
self.cri_fea = None
@ -143,13 +146,6 @@ class SRGANModel(BaseModel):
self.pix = data['PIX'].to(self.device)
def optimize_parameters(self, step):
if step % 50 == 0:
for i in range(self.var_L.shape[0]):
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i)))
# G
for p in self.netD.parameters():
p.requires_grad = False
@ -157,6 +153,13 @@ class SRGANModel(BaseModel):
self.optimizer_G.zero_grad()
self.fake_H = self.netG(self.var_L)
if step % 50 == 0:
for i in range(self.var_L.shape[0]):
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\gen", "%05i_%02i.png" % (step, i)))
l_g_total = 0
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix: # pixel loss
@ -168,6 +171,11 @@ class SRGANModel(BaseModel):
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
l_g_total += l_g_fea
# Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
# in the resultant image.
if step % self.l_fea_w_decay_steps == 0:
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
if self.opt['train']['gan_type'] == 'gan':
pred_g_fake = self.netD(self.fake_H)
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
@ -193,7 +201,8 @@ class SRGANModel(BaseModel):
# real
pred_d_real = self.netD(self.var_ref)
l_d_real = self.cri_gan(pred_d_real, True)
l_d_real.backward()
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward()
# fake
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
l_d_fake = self.cri_gan(pred_d_fake, False)
@ -222,6 +231,7 @@ class SRGANModel(BaseModel):
if self.cri_pix:
self.log_dict['l_g_pix'] = l_g_pix.item()
if self.cri_fea:
self.log_dict['feature_weight'] = self.l_fea_w
self.log_dict['l_g_fea'] = l_g_fea.item()
self.log_dict['l_g_gan'] = l_g_gan.item()
self.log_dict['l_g_total'] = l_g_total.item()

View File

@ -7,7 +7,7 @@ def create_model(opt):
# image restoration
if model == 'sr': # PSNR-oriented super resolution
from .SR_model import SRModel as M
elif model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN
elif model == 'srgan' or model == 'corruptgan': # GAN-based super resolution(SRGAN / ESRGAN), or corruption use same logic
from .SRGAN_model import SRGANModel as M
# video restoration
elif model == 'video_base':

View File

@ -0,0 +1,63 @@
import functools
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util
import torch
class HighToLowResNet(nn.Module):
''' ResNet that applies a noise channel to the input, then downsamples it. Currently only downscale=4 is supported. '''
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, downscale=4):
super(HighToLowResNet, self).__init__()
self.downscale = downscale
# We will always apply a noise channel to the inputs, account for that here.
in_nc += 1
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
basic_block2 = functools.partial(arch_util.ResidualBlock_noBN, nf=nf*2)
# To keep the total model size down, the residual trunks will be applied across 3 downsampling stages.
# The first will be applied against the hi-res inputs and will have only 4 layers.
# The second will be applied after half of the downscaling and will also have only 6 layers.
# The final will be applied against the final resolution and will have all of the remaining layers.
self.trunk_hires = arch_util.make_layer(basic_block, 4)
self.trunk_medres = arch_util.make_layer(basic_block, 6)
self.trunk_lores = arch_util.make_layer(basic_block2, nb - 10)
# downsampling
if self.downscale == 4:
self.downconv1 = nn.Conv2d(nf, nf, 3, stride=2, padding=1, bias=True)
self.downconv2 = nn.Conv2d(nf, nf*2, 3, stride=2, padding=1, bias=True)
else:
raise EnvironmentError("Requested downscale not supported: %i" % (downscale,))
self.HRconv = nn.Conv2d(nf*2, nf*2, 3, stride=1, padding=1, bias=True)
self.conv_last = nn.Conv2d(nf*2, out_nc, 3, stride=1, padding=1, bias=True)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
# initialization
arch_util.initialize_weights([self.conv_first, self.HRconv, self.conv_last, self.downconv1, self.downconv2],
0.1)
def forward(self, x):
# Noise has the same shape as the input with only one channel.
rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device)
out = torch.cat([x, rand_feature], dim=1)
out = self.lrelu(self.conv_first(out))
out = self.trunk_hires(out)
if self.downscale == 4:
out = self.lrelu(self.downconv1(out))
out = self.trunk_medres(out)
out = self.lrelu(self.downconv2(out))
out = self.trunk_lores(out)
out = self.conv_last(self.lrelu(self.HRconv(out)))
base = F.interpolate(x, scale_factor=1/self.downscale, mode='bilinear', align_corners=False)
out += base
return out

View File

@ -32,7 +32,7 @@ class Discriminator_VGG_128(nn.Module):
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
self.linear1 = nn.Linear(int(512 * 4 * input_img_factor * 4 * input_img_factor), 100)
self.linear1 = nn.Linear(int(nf * 8 * 4 * input_img_factor * 4 * input_img_factor), 100)
self.linear2 = nn.Linear(100, 1)
# activation function

View File

@ -3,6 +3,7 @@ import models.archs.SRResNet_arch as SRResNet_arch
import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.RRDBNet_arch as RRDBNet_arch
import models.archs.EDVR_arch as EDVR_arch
import models.archs.HighToLowResNet as HighToLowResNet
import math
# Generator
@ -20,6 +21,10 @@ def define_G(opt):
scale_per_step = math.sqrt(scale)
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step)
# image corruption
elif which_model == 'HighToLowResNet':
netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], downscale=opt_net['scale'])
# video restoration
elif which_model == 'EDVR':
netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],

View File

@ -15,14 +15,14 @@ def parse(opt_path, is_train=True):
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
opt['is_train'] = is_train
if opt['distortion'] == 'sr':
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
scale = opt['scale']
# datasets
for phase, dataset in opt['datasets'].items():
phase = phase.split('_')[0]
dataset['phase'] = phase
if opt['distortion'] == 'sr':
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
dataset['scale'] = scale
is_lmdb = False
if dataset.get('dataroot_GT', None) is not None:
@ -62,7 +62,7 @@ def parse(opt_path, is_train=True):
opt['path']['log'] = results_root
# network
if opt['distortion'] == 'sr':
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
opt['network_G']['scale'] = scale
return opt

View File

@ -29,7 +29,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_ESRGAN_blacked.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_corruptGAN_adrianna.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
@ -176,7 +176,7 @@ def main():
logger.info(message)
#### validation
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
if opt['model'] in ['sr', 'srgan'] and rank <= 0: # image restoration validation
if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0: # image restoration validation
# does not support multi-GPU validation
pbar = util.ProgressBar(len(val_loader))
avg_psnr = 0.