forked from mrq/DL-Art-School
Implement downsample GAN
This bad boy is for a workflow where you train a model on disjoint image sets to downsample a "good" set of images like a "bad" set of images looks. You then use that downsampler to generate a training set of paired images for supersampling.
This commit is contained in:
parent
ea54c7618a
commit
d95808f4ef
|
@ -7,14 +7,14 @@ import torch.utils.data as data
|
|||
import data.util as util
|
||||
|
||||
|
||||
class GTLQDataset(data.Dataset):
|
||||
class DownsampleDataset(data.Dataset):
|
||||
"""
|
||||
Reads unpaired high-resolution and low resolution images. Downsampled, LR images matching the provided high res
|
||||
images are produced and fed to the downstream model, which can be used in a pixel loss.
|
||||
Reads an unpaired HQ and LQ image. Clips both images to the expected input sizes of the model. Produces a
|
||||
downsampled LQ image from the HQ image and feeds that as well.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(GTLQDataset, self).__init__()
|
||||
super(DownsampleDataset, self).__init__()
|
||||
self.opt = opt
|
||||
self.data_type = self.opt['data_type']
|
||||
self.paths_LQ, self.paths_GT = None, None
|
||||
|
@ -23,8 +23,11 @@ class GTLQDataset(data.Dataset):
|
|||
|
||||
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
||||
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
|
||||
|
||||
self.data_sz_mismatch_ok = opt['mismatched_Data_OK']
|
||||
assert self.paths_GT, 'Error: GT path is empty.'
|
||||
if self.paths_LQ and self.paths_GT:
|
||||
assert self.paths_LQ, 'LQ is required for downsampling.'
|
||||
if not self.data_sz_mismatch_ok:
|
||||
assert len(self.paths_LQ) == len(
|
||||
self.paths_GT
|
||||
), 'GT and LQ datasets have different number of images - {}, {}.'.format(
|
||||
|
@ -41,9 +44,8 @@ class GTLQDataset(data.Dataset):
|
|||
def __getitem__(self, index):
|
||||
if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
|
||||
self._init_lmdb()
|
||||
GT_path, LQ_path = None, None
|
||||
scale = self.opt['scale']
|
||||
GT_size = self.opt['target_size']
|
||||
GT_size = self.opt['target_size'] * scale
|
||||
|
||||
# get GT image
|
||||
GT_path = self.paths_GT[index]
|
||||
|
@ -56,43 +58,19 @@ class GTLQDataset(data.Dataset):
|
|||
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
|
||||
|
||||
# get LQ image
|
||||
if self.paths_LQ:
|
||||
LQ_path = self.paths_LQ[index]
|
||||
lqind = index % len(self.paths_LQ)
|
||||
LQ_path = self.paths_LQ[index % len(self.paths_LQ)]
|
||||
resolution = [int(s) for s in self.sizes_LQ[index].split('_')
|
||||
] if self.data_type == 'lmdb' else None
|
||||
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
|
||||
else: # down-sampling on-the-fly
|
||||
# randomly scale during training
|
||||
if self.opt['phase'] == 'train':
|
||||
random_scale = random.choice(self.random_scale_list)
|
||||
H_s, W_s, _ = img_GT.shape
|
||||
|
||||
def _mod(n, random_scale, scale, thres):
|
||||
rlt = int(n * random_scale)
|
||||
rlt = (rlt // scale) * scale
|
||||
return thres if rlt < thres else rlt
|
||||
|
||||
H_s = _mod(H_s, random_scale, scale, GT_size)
|
||||
W_s = _mod(W_s, random_scale, scale, GT_size)
|
||||
img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
|
||||
if img_GT.ndim == 2:
|
||||
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
H, W, _ = img_GT.shape
|
||||
# using matlab imresize
|
||||
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
|
||||
if img_LQ.ndim == 2:
|
||||
img_LQ = np.expand_dims(img_LQ, axis=2)
|
||||
# Create a downsampled version of the HQ image using matlab imresize.
|
||||
img_Downsampled = util.imresize_np(img_GT, 1 / scale)
|
||||
assert img_Downsampled.ndim == 3
|
||||
|
||||
if self.opt['phase'] == 'train':
|
||||
# if the image size is too small
|
||||
H, W, _ = img_GT.shape
|
||||
if H < GT_size or W < GT_size:
|
||||
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||
# using matlab imresize
|
||||
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
|
||||
if img_LQ.ndim == 2:
|
||||
img_LQ = np.expand_dims(img_LQ, axis=2)
|
||||
assert H >= GT_size and W >= GT_size
|
||||
|
||||
H, W, C = img_LQ.shape
|
||||
LQ_size = GT_size // scale
|
||||
|
@ -101,27 +79,35 @@ class GTLQDataset(data.Dataset):
|
|||
rnd_h = random.randint(0, max(0, H - LQ_size))
|
||||
rnd_w = random.randint(0, max(0, W - LQ_size))
|
||||
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
|
||||
img_Downsampled = img_Downsampled[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
|
||||
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
|
||||
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
|
||||
|
||||
# augmentation - flip, rotate
|
||||
img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
|
||||
img_LQ, img_GT, img_Downsampled = util.augment([img_LQ, img_GT, img_Downsampled], self.opt['use_flip'],
|
||||
self.opt['use_rot'])
|
||||
|
||||
if self.opt['color']: # change color space if necessary
|
||||
img_LQ = util.channel_convert(C, self.opt['color'],
|
||||
[img_LQ])[0] # TODO during val no definition
|
||||
img_Downsampled = util.channel_convert(C, self.opt['color'],
|
||||
[img_Downsampled])[0] # TODO during val no definition
|
||||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
if img_GT.shape[2] == 3:
|
||||
img_GT = img_GT[:, :, [2, 1, 0]]
|
||||
img_LQ = img_LQ[:, :, [2, 1, 0]]
|
||||
img_Downsampled = img_Downsampled[:, :, [2, 1, 0]]
|
||||
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||
img_Downsampled = torch.from_numpy(np.ascontiguousarray(np.transpose(img_Downsampled, (2, 0, 1)))).float()
|
||||
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
|
||||
|
||||
if LQ_path is None:
|
||||
LQ_path = GT_path
|
||||
return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
|
||||
# This may seem really messed up, but let me explain:
|
||||
# The goal is to re-use existing code as much as possible. SRGAN_model was coded to supersample, not downsample,
|
||||
# but it can be retrofitted. To do so, we need to "trick" it. In this case the "input" is the HQ image and the
|
||||
# "output" is the LQ image. SRGAN_model will be using a Generator and a Discriminator which already know this,
|
||||
# we just need to trick its logic into following this rules.
|
||||
# Do this by setting LQ(which is the input into the models)=img_GT and GT(which is the expected outpuut)=img_LQ.
|
||||
# PIX is used as a reference for the pixel loss. Use the manually downsampled image for this.
|
||||
return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths_GT)
|
|
@ -55,6 +55,9 @@ class SRGANModel(BaseModel):
|
|||
else:
|
||||
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
|
||||
self.l_fea_w = train_opt['feature_weight']
|
||||
self.l_fea_w_decay = train_opt['feature_weight_decay']
|
||||
self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps']
|
||||
self.l_fea_w_minimum = train_opt['feature_weight_minimum']
|
||||
else:
|
||||
logger.info('Remove feature loss.')
|
||||
self.cri_fea = None
|
||||
|
@ -143,13 +146,6 @@ class SRGANModel(BaseModel):
|
|||
self.pix = data['PIX'].to(self.device)
|
||||
|
||||
def optimize_parameters(self, step):
|
||||
|
||||
if step % 50 == 0:
|
||||
for i in range(self.var_L.shape[0]):
|
||||
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i)))
|
||||
|
||||
# G
|
||||
for p in self.netD.parameters():
|
||||
p.requires_grad = False
|
||||
|
@ -157,6 +153,13 @@ class SRGANModel(BaseModel):
|
|||
self.optimizer_G.zero_grad()
|
||||
self.fake_H = self.netG(self.var_L)
|
||||
|
||||
if step % 50 == 0:
|
||||
for i in range(self.var_L.shape[0]):
|
||||
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\gen", "%05i_%02i.png" % (step, i)))
|
||||
|
||||
l_g_total = 0
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
if self.cri_pix: # pixel loss
|
||||
|
@ -168,6 +171,11 @@ class SRGANModel(BaseModel):
|
|||
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
|
||||
l_g_total += l_g_fea
|
||||
|
||||
# Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
|
||||
# in the resultant image.
|
||||
if step % self.l_fea_w_decay_steps == 0:
|
||||
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
|
||||
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
pred_g_fake = self.netD(self.fake_H)
|
||||
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
||||
|
@ -193,7 +201,8 @@ class SRGANModel(BaseModel):
|
|||
# real
|
||||
pred_d_real = self.netD(self.var_ref)
|
||||
l_d_real = self.cri_gan(pred_d_real, True)
|
||||
l_d_real.backward()
|
||||
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
|
||||
l_d_real_scaled.backward()
|
||||
# fake
|
||||
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
|
||||
l_d_fake = self.cri_gan(pred_d_fake, False)
|
||||
|
@ -222,6 +231,7 @@ class SRGANModel(BaseModel):
|
|||
if self.cri_pix:
|
||||
self.log_dict['l_g_pix'] = l_g_pix.item()
|
||||
if self.cri_fea:
|
||||
self.log_dict['feature_weight'] = self.l_fea_w
|
||||
self.log_dict['l_g_fea'] = l_g_fea.item()
|
||||
self.log_dict['l_g_gan'] = l_g_gan.item()
|
||||
self.log_dict['l_g_total'] = l_g_total.item()
|
||||
|
|
|
@ -7,7 +7,7 @@ def create_model(opt):
|
|||
# image restoration
|
||||
if model == 'sr': # PSNR-oriented super resolution
|
||||
from .SR_model import SRModel as M
|
||||
elif model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN
|
||||
elif model == 'srgan' or model == 'corruptgan': # GAN-based super resolution(SRGAN / ESRGAN), or corruption use same logic
|
||||
from .SRGAN_model import SRGANModel as M
|
||||
# video restoration
|
||||
elif model == 'video_base':
|
||||
|
|
63
codes/models/archs/HighToLowResNet.py
Normal file
63
codes/models/archs/HighToLowResNet.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
import functools
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import models.archs.arch_util as arch_util
|
||||
import torch
|
||||
|
||||
|
||||
class HighToLowResNet(nn.Module):
|
||||
''' ResNet that applies a noise channel to the input, then downsamples it. Currently only downscale=4 is supported. '''
|
||||
|
||||
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, downscale=4):
|
||||
super(HighToLowResNet, self).__init__()
|
||||
self.downscale = downscale
|
||||
|
||||
# We will always apply a noise channel to the inputs, account for that here.
|
||||
in_nc += 1
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
|
||||
basic_block2 = functools.partial(arch_util.ResidualBlock_noBN, nf=nf*2)
|
||||
# To keep the total model size down, the residual trunks will be applied across 3 downsampling stages.
|
||||
# The first will be applied against the hi-res inputs and will have only 4 layers.
|
||||
# The second will be applied after half of the downscaling and will also have only 6 layers.
|
||||
# The final will be applied against the final resolution and will have all of the remaining layers.
|
||||
self.trunk_hires = arch_util.make_layer(basic_block, 4)
|
||||
self.trunk_medres = arch_util.make_layer(basic_block, 6)
|
||||
self.trunk_lores = arch_util.make_layer(basic_block2, nb - 10)
|
||||
|
||||
# downsampling
|
||||
if self.downscale == 4:
|
||||
self.downconv1 = nn.Conv2d(nf, nf, 3, stride=2, padding=1, bias=True)
|
||||
self.downconv2 = nn.Conv2d(nf, nf*2, 3, stride=2, padding=1, bias=True)
|
||||
else:
|
||||
raise EnvironmentError("Requested downscale not supported: %i" % (downscale,))
|
||||
|
||||
self.HRconv = nn.Conv2d(nf*2, nf*2, 3, stride=1, padding=1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf*2, out_nc, 3, stride=1, padding=1, bias=True)
|
||||
|
||||
# activation function
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
|
||||
# initialization
|
||||
arch_util.initialize_weights([self.conv_first, self.HRconv, self.conv_last, self.downconv1, self.downconv2],
|
||||
0.1)
|
||||
|
||||
def forward(self, x):
|
||||
# Noise has the same shape as the input with only one channel.
|
||||
rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device)
|
||||
out = torch.cat([x, rand_feature], dim=1)
|
||||
|
||||
out = self.lrelu(self.conv_first(out))
|
||||
out = self.trunk_hires(out)
|
||||
|
||||
if self.downscale == 4:
|
||||
out = self.lrelu(self.downconv1(out))
|
||||
out = self.trunk_medres(out)
|
||||
out = self.lrelu(self.downconv2(out))
|
||||
out = self.trunk_lores(out)
|
||||
|
||||
out = self.conv_last(self.lrelu(self.HRconv(out)))
|
||||
base = F.interpolate(x, scale_factor=1/self.downscale, mode='bilinear', align_corners=False)
|
||||
out += base
|
||||
return out
|
|
@ -32,7 +32,7 @@ class Discriminator_VGG_128(nn.Module):
|
|||
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||
|
||||
self.linear1 = nn.Linear(int(512 * 4 * input_img_factor * 4 * input_img_factor), 100)
|
||||
self.linear1 = nn.Linear(int(nf * 8 * 4 * input_img_factor * 4 * input_img_factor), 100)
|
||||
self.linear2 = nn.Linear(100, 1)
|
||||
|
||||
# activation function
|
||||
|
|
|
@ -3,6 +3,7 @@ import models.archs.SRResNet_arch as SRResNet_arch
|
|||
import models.archs.discriminator_vgg_arch as SRGAN_arch
|
||||
import models.archs.RRDBNet_arch as RRDBNet_arch
|
||||
import models.archs.EDVR_arch as EDVR_arch
|
||||
import models.archs.HighToLowResNet as HighToLowResNet
|
||||
import math
|
||||
|
||||
# Generator
|
||||
|
@ -20,6 +21,10 @@ def define_G(opt):
|
|||
scale_per_step = math.sqrt(scale)
|
||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step)
|
||||
# image corruption
|
||||
elif which_model == 'HighToLowResNet':
|
||||
netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], downscale=opt_net['scale'])
|
||||
# video restoration
|
||||
elif which_model == 'EDVR':
|
||||
netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
|
||||
|
|
|
@ -15,14 +15,14 @@ def parse(opt_path, is_train=True):
|
|||
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
||||
|
||||
opt['is_train'] = is_train
|
||||
if opt['distortion'] == 'sr':
|
||||
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
|
||||
scale = opt['scale']
|
||||
|
||||
# datasets
|
||||
for phase, dataset in opt['datasets'].items():
|
||||
phase = phase.split('_')[0]
|
||||
dataset['phase'] = phase
|
||||
if opt['distortion'] == 'sr':
|
||||
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
|
||||
dataset['scale'] = scale
|
||||
is_lmdb = False
|
||||
if dataset.get('dataroot_GT', None) is not None:
|
||||
|
@ -62,7 +62,7 @@ def parse(opt_path, is_train=True):
|
|||
opt['path']['log'] = results_root
|
||||
|
||||
# network
|
||||
if opt['distortion'] == 'sr':
|
||||
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
|
||||
opt['network_G']['scale'] = scale
|
||||
|
||||
return opt
|
||||
|
|
|
@ -29,7 +29,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_ESRGAN_blacked.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_corruptGAN_adrianna.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
@ -176,7 +176,7 @@ def main():
|
|||
logger.info(message)
|
||||
#### validation
|
||||
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||
if opt['model'] in ['sr', 'srgan'] and rank <= 0: # image restoration validation
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0: # image restoration validation
|
||||
# does not support multi-GPU validation
|
||||
pbar = util.ProgressBar(len(val_loader))
|
||||
avg_psnr = 0.
|
||||
|
|
Loading…
Reference in New Issue
Block a user