From 7303d8c93279c7d74150a89fbb848f51fee72920 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 31 Oct 2020 11:08:55 -0600 Subject: [PATCH] Add psnr approximator --- codes/data/image_corruptor.py | 11 ++- .../archs/SwitchedResidualGenerator_arch.py | 10 +++ codes/models/archs/discriminator_vgg_arch.py | 90 +++++++++++++++++++ codes/models/networks.py | 2 + codes/models/steps/injectors.py | 68 ++++++++++++-- codes/train2.py | 2 +- 6 files changed, 172 insertions(+), 11 deletions(-) diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py index 03fcda78..dba9f638 100644 --- a/codes/data/image_corruptor.py +++ b/codes/data/image_corruptor.py @@ -71,8 +71,15 @@ class ImageCorruptor: # Large distortion blocks in part of an img, such as is used to mask out a face. pass elif 'lq_resampling' in aug: - # Bicubic LR->HR - pass + # Random mode interpolation HR->LR->HR + scale = 2 + if 'lq_resampling4x' == aug: + scale = 4 + interpolation_modes = [cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4] + mode = rand_int % len(interpolation_modes) + # Downsample first, then upsample using the random mode. + img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST) + img = cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=mode) elif 'color_shift' in aug: # Color shift pass diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 6fb49c13..d09f9a64 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -696,6 +696,16 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): torchvision.utils.save_image(self.lr[:, :3], os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) + def get_debug_values(self, step, net_name): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"switch_temperature": temp} + for i in range(len(means)): + val["switch_%i_specificity" % (i,)] = means[i] + val["switch_%i_histogram" % (i,)] = hists[i] + return val def get_debug_values(self, step, net_name): temp = self.switches[0].switch.temperature mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index 7eab9987..66d64a8e 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -513,3 +513,93 @@ class RefDiscriminatorVgg128(nn.Module): out = self.output_linears(torch.cat([fea, ref_vector], dim=1)) return out + + +class PsnrApproximator(nn.Module): + # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported. + def __init__(self, nf, input_img_factor=1): + super(PsnrApproximator, self).__init__() + + # [64, 128, 128] + self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True) + self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) + self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True) + # [64, 64, 64] + self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) + self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) + self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) + self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) + # [128, 32, 32] + self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) + self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) + self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) + self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) + + # [64, 128, 128] + self.real_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True) + self.real_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) + self.real_bn0_1 = nn.BatchNorm2d(nf, affine=True) + # [64, 64, 64] + self.real_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) + self.real_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) + self.real_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) + self.real_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) + # [128, 32, 32] + self.real_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) + self.real_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) + self.real_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) + self.real_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) + + # [512, 16, 16] + self.conv3_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) + self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) + # [512, 8, 8] + self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) + self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) + final_nf = nf * 8 + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 1024) + self.linear2 = nn.Linear(1024, 512) + self.linear3 = nn.Linear(512, 128) + self.linear4 = nn.Linear(128, 1) + + def compute_body1(self, real): + fea = self.lrelu(self.real_conv0_0(real)) + fea = self.lrelu(self.real_bn0_1(self.real_conv0_1(fea))) + fea = self.lrelu(self.real_bn1_0(self.real_conv1_0(fea))) + fea = self.lrelu(self.real_bn1_1(self.real_conv1_1(fea))) + fea = self.lrelu(self.real_bn2_0(self.real_conv2_0(fea))) + fea = self.lrelu(self.real_bn2_1(self.real_conv2_1(fea))) + return fea + + def compute_body2(self, fake): + fea = self.lrelu(self.fake_conv0_0(fake)) + fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea))) + fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea))) + fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea))) + fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea))) + fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea))) + return fea + + def forward(self, real, fake): + real_fea = checkpoint(self.compute_body1, real) + fake_fea = checkpoint(self.compute_body2, fake) + fea = torch.cat([real_fea, fake_fea], dim=1) + + fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) + fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) + fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) + fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) + + fea = fea.contiguous().view(fea.size(0), -1) + fea = self.lrelu(self.linear1(fea)) + fea = self.lrelu(self.linear2(fea)) + fea = self.lrelu(self.linear3(fea)) + out = self.linear4(fea) + return out.squeeze() \ No newline at end of file diff --git a/codes/models/networks.py b/codes/models/networks.py index 132d6ed1..a5d2d073 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -160,6 +160,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False): netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale']) elif which_model == "discriminator_refvgg": netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128) + elif which_model == "psnr_approximator": + netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 51801827..2e277e66 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -1,3 +1,5 @@ +import random + import torch.nn from torch.cuda.amp import autocast @@ -45,6 +47,12 @@ def create_injector(opt_inject, env): return ImageFftInjector(opt_inject, env) elif type == 'extract_indices': return IndicesExtractor(opt_inject, env) + elif type == 'random_shift': + return RandomShiftInjector(opt_inject, env) + elif type == 'psnr': + return PsnrInjector(opt_inject, env) + elif type == 'batch_rotate': + return BatchRotateInjector(opt_inject, env) else: raise NotImplementedError @@ -94,12 +102,13 @@ class DiscriminatorInjector(Injector): super(DiscriminatorInjector, self).__init__(opt, env) def forward(self, state): - d = self.env['discriminators'][self.opt['discriminator']] - if isinstance(self.input, list): - params = [state[i] for i in self.input] - results = d(*params) - else: - results = d(state[self.input]) + with autocast(enabled=self.env['opt']['fp16']): + d = self.env['discriminators'][self.opt['discriminator']] + if isinstance(self.input, list): + params = [state[i] for i in self.input] + results = d(*params) + else: + results = d(state[self.input]) new_state = {} if isinstance(self.output, list): # Only dereference tuples or lists, not tensors. @@ -232,10 +241,25 @@ class MarginRemoval(Injector): def __init__(self, opt, env): super(MarginRemoval, self).__init__(opt, env) self.margin = opt['margin'] + self.random_shift_max = opt['random_shift_max'] if 'random_shift_max' in opt.keys() else 0 def forward(self, state): input = state[self.input] - return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]} + if self.random_shift_max > 0: + output = [] + # This is a really shitty way of doing this. If it works at all, I should reconsider using Resample2D, for example. + for b in range(input.shape[0]): + shiftleft = random.randint(-self.random_shift_max, self.random_shift_max) + shifttop = random.randint(-self.random_shift_max, self.random_shift_max) + output.append(input[b, :, self.margin+shiftleft:-(self.margin-shiftleft), + self.margin+shifttop:-(self.margin-shifttop)]) + output = torch.stack(output, dim=0) + else: + output = input[:, :, self.margin:-self.margin, + self.margin:-self.margin] + + return {self.opt['out']: output} + # Produces an injection which is composed of applying a single injector multiple times across a single dimension. class ForEachInjector(Injector): @@ -254,7 +278,7 @@ class ForEachInjector(Injector): for i in range(inputs.shape[1]): st['_in'] = inputs[:, i] injs.append(self.injector(st)['_out']) - return {self.output: torch.stack(injs, dim=1)} + return {self.output: torch.stack(injs, dim=1)} class ConstantInjector(Injector): @@ -316,3 +340,31 @@ class IndicesExtractor(Injector): results[o] = state[self.input][:, i] return results + +class RandomShiftInjector(Injector): + def __init__(self, opt, env): + super(RandomShiftInjector, self).__init__(opt, env) + + def forward(self, state): + img = state[self.input] + return {self.output: img} + + +class PsnrInjector(Injector): + def __init__(self, opt, env): + super(PsnrInjector, self).__init__(opt, env) + + def forward(self, state): + img1, img2 = state[self.input[0]], state[self.input[1]] + mse = torch.mean((img1 - img2) ** 2, dim=[1,2,3]) + return {self.output: mse} + + +class BatchRotateInjector(Injector): + def __init__(self, opt, env): + super(BatchRotateInjector, self).__init__(opt, env) + + def forward(self, state): + img = state[self.input] + return {self.output: torch.roll(img, 1, 0)} + diff --git a/codes/train2.py b/codes/train2.py index 194a6257..c987fb64 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -278,7 +278,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_10bl_bypass.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_justfeature.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True)