Implement downsample GAN

This bad boy supports a workflow where you train a model on disjoint image sets to
downsample a "good" set of images so that it looks like a "bad" set of images. You then
use that downsampler to generate a training set of paired images for supersampling.
James Betker 2020-04-24 00:00:46 -06:00
parent ea54c7618a
commit d95808f4ef
8 changed files with 129 additions and 65 deletions
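For concreteness, here is a rough sketch (hypothetical, not part of this commit; the checkpoint name and the loader are stand-ins) of how the trained downsampler would be swept over the "good" set to synthesize the paired training data described above:

import os
import torch
import torchvision.utils as utils
from models.archs.HighToLowResNet import HighToLowResNet  # class added by this commit

# Assumed checkpoint name; substitute whatever the training run produced.
netG = HighToLowResNet(in_nc=3, out_nc=3, nf=64, nb=16, downscale=4).cuda().eval()
netG.load_state_dict(torch.load('downsampler_G.pth'))

with torch.no_grad():
    for name, hq in good_image_loader:  # assumed loader yielding (filename, 1x3xHxW tensor)
        fake_lq = netG(hq.cuda())       # degraded to mimic the "bad" set
        utils.save_image(hq.squeeze(0), os.path.join('pairs/hr', name))
        utils.save_image(fake_lq.squeeze(0), os.path.join('pairs/lr', name))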


@@ -7,14 +7,14 @@ import torch.utils.data as data
 import data.util as util

-class GTLQDataset(data.Dataset):
+class DownsampleDataset(data.Dataset):
     """
-    Reads unpaired high-resolution and low resolution images. Downsampled, LR images matching the provided high res
-    images are produced and fed to the downstream model, which can be used in a pixel loss.
+    Reads an unpaired HQ and LQ image. Clips both images to the expected input sizes of the model. Produces a
+    downsampled LQ image from the HQ image and feeds that as well.
     """

     def __init__(self, opt):
-        super(GTLQDataset, self).__init__()
+        super(DownsampleDataset, self).__init__()
         self.opt = opt
         self.data_type = self.opt['data_type']
         self.paths_LQ, self.paths_GT = None, None
@@ -23,8 +23,11 @@ class GTLQDataset(data.Dataset):
         self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
         self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
+        self.data_sz_mismatch_ok = opt['mismatched_Data_OK']

         assert self.paths_GT, 'Error: GT path is empty.'
-        if self.paths_LQ and self.paths_GT:
+        assert self.paths_LQ, 'LQ is required for downsampling.'
+        if not self.data_sz_mismatch_ok:
             assert len(self.paths_LQ) == len(
                 self.paths_GT
             ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
@@ -41,9 +44,8 @@ class GTLQDataset(data.Dataset):
     def __getitem__(self, index):
         if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
             self._init_lmdb()
-        GT_path, LQ_path = None, None
         scale = self.opt['scale']
-        GT_size = self.opt['target_size']
+        GT_size = self.opt['target_size'] * scale

         # get GT image
         GT_path = self.paths_GT[index]
@@ -56,43 +58,19 @@ class GTLQDataset(data.Dataset):
             img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]

         # get LQ image
-        if self.paths_LQ:
-            LQ_path = self.paths_LQ[index]
-            resolution = [int(s) for s in self.sizes_LQ[index].split('_')
-                          ] if self.data_type == 'lmdb' else None
-            img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
-        else:  # down-sampling on-the-fly
-            # randomly scale during training
-            if self.opt['phase'] == 'train':
-                random_scale = random.choice(self.random_scale_list)
-                H_s, W_s, _ = img_GT.shape
-
-                def _mod(n, random_scale, scale, thres):
-                    rlt = int(n * random_scale)
-                    rlt = (rlt // scale) * scale
-                    return thres if rlt < thres else rlt
-
-                H_s = _mod(H_s, random_scale, scale, GT_size)
-                W_s = _mod(W_s, random_scale, scale, GT_size)
-                img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
-                if img_GT.ndim == 2:
-                    img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
-
-            H, W, _ = img_GT.shape
-            # using matlab imresize
-            img_LQ = util.imresize_np(img_GT, 1 / scale, True)
-            if img_LQ.ndim == 2:
-                img_LQ = np.expand_dims(img_LQ, axis=2)
+        lqind = index % len(self.paths_LQ)
+        LQ_path = self.paths_LQ[index % len(self.paths_LQ)]
+        resolution = [int(s) for s in self.sizes_LQ[index].split('_')
+                      ] if self.data_type == 'lmdb' else None
+        img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
+
+        # Create a downsampled version of the HQ image using matlab imresize.
+        img_Downsampled = util.imresize_np(img_GT, 1 / scale)
+        assert img_Downsampled.ndim == 3

         if self.opt['phase'] == 'train':
-            # if the image size is too small
             H, W, _ = img_GT.shape
-            if H < GT_size or W < GT_size:
-                img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
-                # using matlab imresize
-                img_LQ = util.imresize_np(img_GT, 1 / scale, True)
-                if img_LQ.ndim == 2:
-                    img_LQ = np.expand_dims(img_LQ, axis=2)
+            assert H >= GT_size and W >= GT_size

             H, W, C = img_LQ.shape
             LQ_size = GT_size // scale
@@ -101,27 +79,35 @@ class GTLQDataset(data.Dataset):
             rnd_h = random.randint(0, max(0, H - LQ_size))
             rnd_w = random.randint(0, max(0, W - LQ_size))
             img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
+            img_Downsampled = img_Downsampled[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
             rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
             img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]

             # augmentation - flip, rotate
-            img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
-                                          self.opt['use_rot'])
+            img_LQ, img_GT, img_Downsampled = util.augment([img_LQ, img_GT, img_Downsampled], self.opt['use_flip'],
+                                                           self.opt['use_rot'])

         if self.opt['color']:  # change color space if necessary
-            img_LQ = util.channel_convert(C, self.opt['color'],
-                                          [img_LQ])[0]  # TODO during val no definition
+            img_Downsampled = util.channel_convert(C, self.opt['color'],
+                                                   [img_Downsampled])[0]  # TODO during val no definition

         # BGR to RGB, HWC to CHW, numpy to tensor
         if img_GT.shape[2] == 3:
             img_GT = img_GT[:, :, [2, 1, 0]]
             img_LQ = img_LQ[:, :, [2, 1, 0]]
+            img_Downsampled = img_Downsampled[:, :, [2, 1, 0]]
         img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
+        img_Downsampled = torch.from_numpy(np.ascontiguousarray(np.transpose(img_Downsampled, (2, 0, 1)))).float()
         img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()

-        if LQ_path is None:
-            LQ_path = GT_path
-        return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
+        # This may seem really messed up, but let me explain:
+        # The goal is to re-use existing code as much as possible. SRGAN_model was coded to supersample, not downsample,
+        # but it can be retrofitted. To do so, we need to "trick" it. In this case the "input" is the HQ image and the
+        # "output" is the LQ image. SRGAN_model will be using a Generator and a Discriminator which already know this,
+        # we just need to trick its logic into following these rules.
+        # Do this by setting LQ (which is the input into the models) = img_GT and GT (which is the expected output) = img_LQ.
+        # PIX is used as a reference for the pixel loss. Use the manually downsampled image for this.
+        return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}

     def __len__(self):
         return len(self.paths_GT)
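To make the LQ/GT swap above concrete, here is an illustrative use of the dataset (not from the commit; the opt values are made up and the paths are placeholders). With target_size=32 and scale=4, GT_size = 32 * 4 = 128 and LQ_size = 128 // 4 = 32, so the generator's "input" ('LQ') is the 128px HQ crop while the discriminator's "real" ('GT') is a 32px crop from the genuinely bad set:

opt = {
    'data_type': 'img', 'phase': 'train', 'scale': 4, 'target_size': 32,
    'color': None, 'use_flip': True, 'use_rot': True, 'mismatched_Data_OK': True,
    'dataroot_GT': '/data/good',  # placeholder paths
    'dataroot_LQ': '/data/bad',
}
dataset = DownsampleDataset(opt)
sample = dataset[0]
print(sample['LQ'].shape)   # torch.Size([3, 128, 128]) -- HQ crop, fed to the model as input
print(sample['GT'].shape)   # torch.Size([3, 32, 32])   -- real LQ crop, the "real" example for D
print(sample['PIX'].shape)  # torch.Size([3, 32, 32])   -- matlab-style downsample, pixel-loss reference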


@@ -55,6 +55,9 @@ class SRGANModel(BaseModel):
             else:
                 raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
             self.l_fea_w = train_opt['feature_weight']
+            self.l_fea_w_decay = train_opt['feature_weight_decay']
+            self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps']
+            self.l_fea_w_minimum = train_opt['feature_weight_minimum']
         else:
             logger.info('Remove feature loss.')
             self.cri_fea = None
@@ -143,13 +146,6 @@ class SRGANModel(BaseModel):
             self.pix = data['PIX'].to(self.device)

     def optimize_parameters(self, step):
-        if step % 50 == 0:
-            for i in range(self.var_L.shape[0]):
-                utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\hr", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i)))
-
         # G
         for p in self.netD.parameters():
             p.requires_grad = False
@@ -157,6 +153,13 @@ class SRGANModel(BaseModel):
         self.optimizer_G.zero_grad()
         self.fake_H = self.netG(self.var_L)

+        if step % 50 == 0:
+            for i in range(self.var_L.shape[0]):
+                utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\hr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.fake_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\gen", "%05i_%02i.png" % (step, i)))

         l_g_total = 0
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
             if self.cri_pix:  # pixel loss
@@ -168,6 +171,11 @@ class SRGANModel(BaseModel):
                 l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                 l_g_total += l_g_fea

+                # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
+                # in the resultant image.
+                if step % self.l_fea_w_decay_steps == 0:
+                    self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
+
             if self.opt['train']['gan_type'] == 'gan':
                 pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
@@ -193,7 +201,8 @@ class SRGANModel(BaseModel):
             # real
             pred_d_real = self.netD(self.var_ref)
             l_d_real = self.cri_gan(pred_d_real, True)
-            l_d_real.backward()
+            with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
+                l_d_real_scaled.backward()
             # fake
             pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
             l_d_fake = self.cri_gan(pred_d_fake, False)
@@ -222,6 +231,7 @@ class SRGANModel(BaseModel):
         if self.cri_pix:
             self.log_dict['l_g_pix'] = l_g_pix.item()
         if self.cri_fea:
+            self.log_dict['feature_weight'] = self.l_fea_w
             self.log_dict['l_g_fea'] = l_g_fea.item()
         self.log_dict['l_g_gan'] = l_g_gan.item()
         self.log_dict['l_g_total'] = l_g_total.item()
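The feature-weight decay added above is a stepped exponential schedule. An illustrative trace (the config values are made up, not from the commit):

l_fea_w, decay, steps, minimum = 1.0, 0.9, 500, 0.1  # hypothetical feature_weight* config
for step in range(1, 20001):
    if step % steps == 0:
        l_fea_w = max(minimum, l_fea_w * decay)
# After k decay events the weight is max(minimum, 1.0 * 0.9**k);
# 0.9**22 ~= 0.098, so with these numbers it floors at 0.1 around step 11000.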


@@ -7,7 +7,7 @@ def create_model(opt):
     # image restoration
     if model == 'sr':  # PSNR-oriented super resolution
         from .SR_model import SRModel as M
-    elif model == 'srgan':  # GAN-based super resolution, SRGAN / ESRGAN
+    elif model == 'srgan' or model == 'corruptgan':  # GAN-based super resolution (SRGAN / ESRGAN); corruption GANs use the same logic
         from .SRGAN_model import SRGANModel as M
     # video restoration
     elif model == 'video_base':


@@ -0,0 +1,63 @@
+import functools
+import torch.nn as nn
+import torch.nn.functional as F
+import models.archs.arch_util as arch_util
+import torch
+
+
+class HighToLowResNet(nn.Module):
+    ''' ResNet that applies a noise channel to the input, then downsamples it. Currently only downscale=4 is supported. '''
+
+    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, downscale=4):
+        super(HighToLowResNet, self).__init__()
+        self.downscale = downscale
+
+        # We will always apply a noise channel to the inputs, account for that here.
+        in_nc += 1
+
+        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+        basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
+        basic_block2 = functools.partial(arch_util.ResidualBlock_noBN, nf=nf*2)
+
+        # To keep the total model size down, the residual trunks will be applied across 3 downsampling stages.
+        # The first will be applied against the hi-res inputs and will have only 4 layers.
+        # The second will be applied after half of the downscaling and will have only 6 layers.
+        # The final will be applied against the final resolution and will have all of the remaining layers.
+        self.trunk_hires = arch_util.make_layer(basic_block, 4)
+        self.trunk_medres = arch_util.make_layer(basic_block, 6)
+        self.trunk_lores = arch_util.make_layer(basic_block2, nb - 10)
+
+        # downsampling
+        if self.downscale == 4:
+            self.downconv1 = nn.Conv2d(nf, nf, 3, stride=2, padding=1, bias=True)
+            self.downconv2 = nn.Conv2d(nf, nf*2, 3, stride=2, padding=1, bias=True)
+        else:
+            raise EnvironmentError("Requested downscale not supported: %i" % (downscale,))
+
+        self.HRconv = nn.Conv2d(nf*2, nf*2, 3, stride=1, padding=1, bias=True)
+        self.conv_last = nn.Conv2d(nf*2, out_nc, 3, stride=1, padding=1, bias=True)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+        # initialization
+        arch_util.initialize_weights([self.conv_first, self.HRconv, self.conv_last, self.downconv1, self.downconv2],
+                                     0.1)
+
+    def forward(self, x):
+        # Noise has the same shape as the input with only one channel.
+        rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device)
+        out = torch.cat([x, rand_feature], dim=1)
+
+        out = self.lrelu(self.conv_first(out))
+        out = self.trunk_hires(out)
+
+        if self.downscale == 4:
+            out = self.lrelu(self.downconv1(out))
+            out = self.trunk_medres(out)
+            out = self.lrelu(self.downconv2(out))
+
+        out = self.trunk_lores(out)
+        out = self.conv_last(self.lrelu(self.HRconv(out)))
+
+        base = F.interpolate(x, scale_factor=1/self.downscale, mode='bilinear', align_corners=False)
+        out += base
+        return out
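A quick smoke test of the new architecture (illustrative, not part of the commit) confirms the 4x spatial reduction; the output values are stochastic because of the injected noise channel, but the shape is deterministic:

import torch
net = HighToLowResNet(in_nc=3, out_nc=3, nf=64, nb=16, downscale=4)
x = torch.randn(2, 3, 128, 128)
print(net(x).shape)  # torch.Size([2, 3, 32, 32])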


@@ -32,7 +32,7 @@ class Discriminator_VGG_128(nn.Module):
         self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
         self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)

-        self.linear1 = nn.Linear(int(512 * 4 * input_img_factor * 4 * input_img_factor), 100)
+        self.linear1 = nn.Linear(int(nf * 8 * 4 * input_img_factor * 4 * input_img_factor), 100)
         self.linear2 = nn.Linear(100, 1)

         # activation function
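For reference, this fix unties the flattened feature size from the default nf=64: the conv stack emits nf * 8 channels on a (4 * input_img_factor)-sized grid, and the hard-coded 512 only equals nf * 8 when nf == 64. A hypothetical nf=32 shows the mismatch the old code would have caused:

nf, input_img_factor = 32, 1  # made-up values for illustration
old_in_features = int(512 * 4 * input_img_factor * 4 * input_img_factor)     # 8192, declared
new_in_features = int(nf * 8 * 4 * input_img_factor * 4 * input_img_factor)  # 4096, what the conv stack actually emits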


@@ -3,6 +3,7 @@ import models.archs.SRResNet_arch as SRResNet_arch
 import models.archs.discriminator_vgg_arch as SRGAN_arch
 import models.archs.RRDBNet_arch as RRDBNet_arch
 import models.archs.EDVR_arch as EDVR_arch
+import models.archs.HighToLowResNet as HighToLowResNet
 import math

 # Generator
@@ -20,6 +21,10 @@ def define_G(opt):
         scale_per_step = math.sqrt(scale)
         netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                     nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step)
+    # image corruption
+    elif which_model == 'HighToLowResNet':
+        netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
+                                               nf=opt_net['nf'], nb=opt_net['nb'], downscale=opt_net['scale'])
     # video restoration
     elif which_model == 'EDVR':
         netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],


@@ -15,14 +15,14 @@ def parse(opt_path, is_train=True):
     print('export CUDA_VISIBLE_DEVICES=' + gpu_list)

     opt['is_train'] = is_train
-    if opt['distortion'] == 'sr':
+    if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
         scale = opt['scale']

     # datasets
     for phase, dataset in opt['datasets'].items():
         phase = phase.split('_')[0]
         dataset['phase'] = phase
-        if opt['distortion'] == 'sr':
+        if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
             dataset['scale'] = scale
         is_lmdb = False
         if dataset.get('dataroot_GT', None) is not None:
@@ -62,7 +62,7 @@ def parse(opt_path, is_train=True):
         opt['path']['log'] = results_root

     # network
-    if opt['distortion'] == 'sr':
+    if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
         opt['network_G']['scale'] = scale

     return opt


@@ -29,7 +29,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_ESRGAN_blacked.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_corruptGAN_adrianna.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -176,7 +176,7 @@ def main():
                 logger.info(message)

             #### validation
             if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
-                if opt['model'] in ['sr', 'srgan'] and rank <= 0:  # image restoration validation
+                if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0:  # image restoration validation
                     # does not support multi-GPU validation
                     pbar = util.ProgressBar(len(val_loader))
                     avg_psnr = 0.