Add psnr approximator

James Betker 2020-10-31 11:08:55 -06:00
parent 565517814e
commit 7303d8c932
6 changed files with 172 additions and 11 deletions

View File

@@ -71,8 +71,15 @@ class ImageCorruptor:
             # Large distortion blocks in part of an img, such as is used to mask out a face.
             pass
         elif 'lq_resampling' in aug:
-            # Bicubic LR->HR
-            pass
+            # Random mode interpolation HR->LR->HR
+            scale = 2
+            if 'lq_resampling4x' == aug:
+                scale = 4
+            interpolation_modes = [cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
+            mode = rand_int % len(interpolation_modes)
+            # Downsample first, then upsample using the random mode.
+            img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST)
+            img = cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=mode)
         elif 'color_shift' in aug:
             # Color shift
             pass
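
A note on the mode selection above: OpenCV's interpolation flags happen to be the small integers 0-4 (INTER_NEAREST=0, INTER_LINEAR=1, INTER_CUBIC=2, INTER_AREA=3, INTER_LANCZOS4=4), so passing the raw index rand_int % len(interpolation_modes) is still a legal flag, but it never actually reads from the interpolation_modes list. If picking from the list is the intent, a minimal sketch of that variant (img, rand_int and scale as in the hunk above):

    interpolation_modes = [cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
    # Pick the flag out of the list instead of reusing the index as the flag.
    mode = interpolation_modes[rand_int % len(interpolation_modes)]
    lr = cv2.resize(img, dsize=(img.shape[1] // scale, img.shape[0] // scale), interpolation=cv2.INTER_NEAREST)
    img = cv2.resize(lr, dsize=(lr.shape[1] * scale, lr.shape[0] * scale), interpolation=mode)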

View File

@@ -696,6 +696,16 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         torchvision.utils.save_image(self.lr[:, :3], os.path.join(experiments_path, "attention_maps",
                                                                    "amap_%i_base_image.png" % (step,)))
 
+    def get_debug_values(self, step, net_name):
+        temp = self.switches[0].switch.temperature
+        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
+        means = [i[0] for i in mean_hists]
+        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
+        val = {"switch_temperature": temp}
+        for i in range(len(means)):
+            val["switch_%i_specificity" % (i,)] = means[i]
+            val["switch_%i_histogram" % (i,)] = hists[i]
+        return val
     def get_debug_values(self, step, net_name):
         temp = self.switches[0].switch.temperature
         mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]

View File

@@ -513,3 +513,93 @@ class RefDiscriminatorVgg128(nn.Module):
         out = self.output_linears(torch.cat([fea, ref_vector], dim=1))
         return out
+
+
+class PsnrApproximator(nn.Module):
+    # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
+    def __init__(self, nf, input_img_factor=1):
+        super(PsnrApproximator, self).__init__()
+
+        # [64, 128, 128]
+        self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
+        self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
+        self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True)
+        # [64, 64, 64]
+        self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
+        self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
+        self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
+        self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
+        # [128, 32, 32]
+        self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
+        self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
+        self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
+        self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
+
+        # [64, 128, 128]
+        self.real_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
+        self.real_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
+        self.real_bn0_1 = nn.BatchNorm2d(nf, affine=True)
+        # [64, 64, 64]
+        self.real_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
+        self.real_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
+        self.real_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
+        self.real_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
+        # [128, 32, 32]
+        self.real_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
+        self.real_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
+        self.real_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
+        self.real_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
+
+        # [512, 16, 16]
+        self.conv3_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
+        self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
+        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
+        self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
+        # [512, 8, 8]
+        self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
+        self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
+        self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
+        self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
+        final_nf = nf * 8
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 1024)
+        self.linear2 = nn.Linear(1024, 512)
+        self.linear3 = nn.Linear(512, 128)
+        self.linear4 = nn.Linear(128, 1)
+
+    def compute_body1(self, real):
+        fea = self.lrelu(self.real_conv0_0(real))
+        fea = self.lrelu(self.real_bn0_1(self.real_conv0_1(fea)))
+        fea = self.lrelu(self.real_bn1_0(self.real_conv1_0(fea)))
+        fea = self.lrelu(self.real_bn1_1(self.real_conv1_1(fea)))
+        fea = self.lrelu(self.real_bn2_0(self.real_conv2_0(fea)))
+        fea = self.lrelu(self.real_bn2_1(self.real_conv2_1(fea)))
+        return fea
+
+    def compute_body2(self, fake):
+        fea = self.lrelu(self.fake_conv0_0(fake))
+        fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea)))
+        fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea)))
+        fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea)))
+        fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea)))
+        fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea)))
+        return fea
+
+    def forward(self, real, fake):
+        real_fea = checkpoint(self.compute_body1, real)
+        fake_fea = checkpoint(self.compute_body2, fake)
+        fea = torch.cat([real_fea, fake_fea], dim=1)
+        fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
+        fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
+        fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
+        fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
+        fea = fea.contiguous().view(fea.size(0), -1)
+        fea = self.lrelu(self.linear1(fea))
+        fea = self.lrelu(self.linear2(fea))
+        fea = self.lrelu(self.linear3(fea))
+        out = self.linear4(fea)
+        return out.squeeze()
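
A quick sanity check on the size of the feature vector feeding linear1: each branch halves the spatial resolution three times, the shared trunk halves it twice more, and concatenating the two branches brings the channel count to nf * 8. A minimal sketch, assuming nf=64 and 128x128 inputs (input_img_factor=1):

    # Hypothetical shape walk-through; this mirrors the strides above rather than running the module.
    nf, side, input_img_factor = 64, 128, 1
    side = side // 2 // 2 // 2   # three stride-2 convs in each branch: 128 -> 16
    side = side // 2 // 2        # two more stride-2 convs after the concat: 16 -> 4
    flat_features = (nf * 8) * side * side
    assert flat_features == int(nf * 8 * 4 * input_img_factor * 4 * input_img_factor)   # 8192 = linear1's in_features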

View File

@@ -160,6 +160,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
         netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale'])
     elif which_model == "discriminator_refvgg":
         netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
+    elif which_model == "psnr_approximator":
+        netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
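
For reference, a hedged sketch of the opt_net block that would route through the new branch; 'which_model_D' is assumed to be the key define_D_net reads (as in the other BasicSR-derived discriminators here), and nf=64 is only an example value:

    # Hypothetical configuration; key names and values are illustrative, not taken from a real options file.
    opt_net = {'which_model_D': 'psnr_approximator', 'nf': 64}
    netD = define_D_net(opt_net, img_sz=128)   # input_img_factor = 128 / 128 = 1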

View File

@@ -1,3 +1,5 @@
+import random
 import torch.nn
+from torch.cuda.amp import autocast
@@ -45,6 +47,12 @@ def create_injector(opt_inject, env):
         return ImageFftInjector(opt_inject, env)
     elif type == 'extract_indices':
         return IndicesExtractor(opt_inject, env)
+    elif type == 'random_shift':
+        return RandomShiftInjector(opt_inject, env)
+    elif type == 'psnr':
+        return PsnrInjector(opt_inject, env)
+    elif type == 'batch_rotate':
+        return BatchRotateInjector(opt_inject, env)
     else:
         raise NotImplementedError
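
Like the existing injectors, the new ones are selected by the 'type' field of an injector block in the step options. A minimal dispatch sketch, assuming the base Injector reads its tensor names from opt['in'] and writes to opt['out'], and that env and state are the usual dicts passed around this file; the state keys used here are made up:

    # Hypothetical injector opt block and manual dispatch.
    opt_inject = {'type': 'psnr', 'in': ['gen', 'hq'], 'out': 'gen_psnr'}
    injector = create_injector(opt_inject, env)   # -> PsnrInjector
    new_state = injector.forward(state)           # {'gen_psnr': per-image MSE tensor}
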
@@ -94,12 +102,13 @@ class DiscriminatorInjector(Injector):
         super(DiscriminatorInjector, self).__init__(opt, env)
 
     def forward(self, state):
-        d = self.env['discriminators'][self.opt['discriminator']]
-        if isinstance(self.input, list):
-            params = [state[i] for i in self.input]
-            results = d(*params)
-        else:
-            results = d(state[self.input])
+        with autocast(enabled=self.env['opt']['fp16']):
+            d = self.env['discriminators'][self.opt['discriminator']]
+            if isinstance(self.input, list):
+                params = [state[i] for i in self.input]
+                results = d(*params)
+            else:
+                results = d(state[self.input])
         new_state = {}
         if isinstance(self.output, list):
             # Only dereference tuples or lists, not tensors.
@@ -232,10 +241,25 @@ class MarginRemoval(Injector):
     def __init__(self, opt, env):
         super(MarginRemoval, self).__init__(opt, env)
         self.margin = opt['margin']
+        self.random_shift_max = opt['random_shift_max'] if 'random_shift_max' in opt.keys() else 0
 
     def forward(self, state):
         input = state[self.input]
-        return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
+        if self.random_shift_max > 0:
+            output = []
+            # This is a really shitty way of doing this. If it works at all, I should reconsider using Resample2D, for example.
+            for b in range(input.shape[0]):
+                shiftleft = random.randint(-self.random_shift_max, self.random_shift_max)
+                shifttop = random.randint(-self.random_shift_max, self.random_shift_max)
+                output.append(input[b, :, self.margin+shiftleft:-(self.margin-shiftleft),
+                                    self.margin+shifttop:-(self.margin-shifttop)])
+            output = torch.stack(output, dim=0)
+        else:
+            output = input[:, :, self.margin:-self.margin,
+                           self.margin:-self.margin]
+
+        return {self.opt['out']: output}
 
 # Produces an injection which is composed of applying a single injector multiple times across a single dimension.
 class ForEachInjector(Injector):
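
One caveat on the shifted path in MarginRemoval.forward above: when a sampled shift equals the margin, the stop index -(self.margin - shift) collapses to 0 and the slice comes back empty, so random_shift_max should stay strictly below margin. Computing an explicit fixed-size window avoids negative stop indices altogether; a minimal per-sample sketch with the same output size (shift_h/shift_w stand in for the two random shifts, everything else as above):

    # Hypothetical equivalent crop: start at margin + shift and take exactly H - 2*margin rows
    # (and likewise for columns), so a shift equal to the margin still yields a full-size patch.
    h, w = input.shape[2], input.shape[3]
    start_h = self.margin + shift_h
    start_w = self.margin + shift_w
    patch = input[b, :, start_h:start_h + h - 2 * self.margin, start_w:start_w + w - 2 * self.margin]
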
@@ -254,7 +278,7 @@ class ForEachInjector(Injector):
         for i in range(inputs.shape[1]):
             st['_in'] = inputs[:, i]
             injs.append(self.injector(st)['_out'])
-        return {self.output: torch.stack(injs, dim=1)}
+        return {self.output: torch.stack(injs, dim=1)}
 
 
 class ConstantInjector(Injector):
@@ -316,3 +340,31 @@ class IndicesExtractor(Injector):
             results[o] = state[self.input][:, i]
         return results
+
+
+class RandomShiftInjector(Injector):
+    def __init__(self, opt, env):
+        super(RandomShiftInjector, self).__init__(opt, env)
+
+    def forward(self, state):
+        img = state[self.input]
+        return {self.output: img}
+
+
+class PsnrInjector(Injector):
+    def __init__(self, opt, env):
+        super(PsnrInjector, self).__init__(opt, env)
+
+    def forward(self, state):
+        img1, img2 = state[self.input[0]], state[self.input[1]]
+        mse = torch.mean((img1 - img2) ** 2, dim=[1,2,3])
+        return {self.output: mse}
+
+
+class BatchRotateInjector(Injector):
+    def __init__(self, opt, env):
+        super(BatchRotateInjector, self).__init__(opt, env)
+
+    def forward(self, state):
+        img = state[self.input]
+        return {self.output: torch.roll(img, 1, 0)}
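
Despite its name, the PsnrInjector above emits per-image MSE rather than PSNR itself, which is the easier quantity to regress. If a true PSNR value is wanted downstream (e.g. for logging), the standard conversion for images scaled to [0, 1] is PSNR = -10 * log10(MSE); a minimal sketch, assuming mse is the tensor the injector produces:

    import torch

    # Hypothetical helper converting per-image MSE to PSNR in dB; pass max_val=255.0 for 8-bit images.
    def mse_to_psnr(mse, max_val=1.0, eps=1e-10):
        return 10.0 * torch.log10((max_val ** 2) / (mse + eps))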

View File

@@ -278,7 +278,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_10bl_bypass.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_justfeature.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)