Add psnr approximator
parent 565517814e
commit 7303d8c932
@@ -71,8 +71,15 @@ class ImageCorruptor:
            # Large distortion blocks in part of an img, such as is used to mask out a face.
            pass
        elif 'lq_resampling' in aug:
            # Bicubic LR->HR
            pass
            # Random mode interpolation HR->LR->HR
            scale = 2
            if 'lq_resampling4x' == aug:
                scale = 4
            interpolation_modes = [cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
            mode = rand_int % len(interpolation_modes)
            # Downsample first, then upsample using the random mode.
            img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST)
            img = cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=mode)
        elif 'color_shift' in aug:
            # Color shift
            pass
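For reference, a minimal standalone sketch of the resampling corruption added above (the function name is hypothetical, and random.choice stands in for the class's rand_int bookkeeping):

import random

import cv2
import numpy as np


def lq_resample(img: np.ndarray, scale: int = 2) -> np.ndarray:
    # Downsample with nearest-neighbor, then upsample with a randomly chosen
    # interpolation mode, mirroring the HR->LR->HR path in the hunk above.
    modes = [cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_CUBIC,
             cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
    mode = random.choice(modes)
    small = cv2.resize(img, dsize=(img.shape[1] // scale, img.shape[0] // scale),
                       interpolation=cv2.INTER_NEAREST)
    return cv2.resize(small, dsize=(small.shape[1] * scale, small.shape[0] * scale),
                      interpolation=mode)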
@@ -696,6 +696,16 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
            torchvision.utils.save_image(self.lr[:, :3], os.path.join(experiments_path, "attention_maps",
                                                                      "amap_%i_base_image.png" % (step,)))

    def get_debug_values(self, step, net_name):
        temp = self.switches[0].switch.temperature
        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
        means = [i[0] for i in mean_hists]
        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
        val = {"switch_temperature": temp}
        for i in range(len(means)):
            val["switch_%i_specificity" % (i,)] = means[i]
            val["switch_%i_histogram" % (i,)] = hists[i]
        return val

    def get_debug_values(self, step, net_name):
        temp = self.switches[0].switch.temperature
        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
@@ -513,3 +513,93 @@ class RefDiscriminatorVgg128(nn.Module):
        out = self.output_linears(torch.cat([fea, ref_vector], dim=1))
        return out


class PsnrApproximator(nn.Module):
    # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
    def __init__(self, nf, input_img_factor=1):
        super(PsnrApproximator, self).__init__()

        # [64, 128, 128]
        self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
        self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
        self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True)
        # [64, 64, 64]
        self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
        self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
        self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
        self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
        # [128, 32, 32]
        self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
        self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
        self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
        self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)

        # [64, 128, 128]
        self.real_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
        self.real_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
        self.real_bn0_1 = nn.BatchNorm2d(nf, affine=True)
        # [64, 64, 64]
        self.real_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
        self.real_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
        self.real_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
        self.real_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
        # [128, 32, 32]
        self.real_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
        self.real_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
        self.real_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
        self.real_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)

        # [512, 16, 16]
        self.conv3_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
        self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
        # [512, 8, 8]
        self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
        self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
        self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
        self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
        final_nf = nf * 8

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 1024)
        self.linear2 = nn.Linear(1024, 512)
        self.linear3 = nn.Linear(512, 128)
        self.linear4 = nn.Linear(128, 1)

    def compute_body1(self, real):
        fea = self.lrelu(self.real_conv0_0(real))
        fea = self.lrelu(self.real_bn0_1(self.real_conv0_1(fea)))
        fea = self.lrelu(self.real_bn1_0(self.real_conv1_0(fea)))
        fea = self.lrelu(self.real_bn1_1(self.real_conv1_1(fea)))
        fea = self.lrelu(self.real_bn2_0(self.real_conv2_0(fea)))
        fea = self.lrelu(self.real_bn2_1(self.real_conv2_1(fea)))
        return fea

    def compute_body2(self, fake):
        fea = self.lrelu(self.fake_conv0_0(fake))
        fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea)))
        fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea)))
        fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea)))
        fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea)))
        fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea)))
        return fea

    def forward(self, real, fake):
        real_fea = checkpoint(self.compute_body1, real)
        fake_fea = checkpoint(self.compute_body2, fake)
        fea = torch.cat([real_fea, fake_fea], dim=1)

        fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
        fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
        fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
        fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))

        fea = fea.contiguous().view(fea.size(0), -1)
        fea = self.lrelu(self.linear1(fea))
        fea = self.lrelu(self.linear2(fea))
        fea = self.lrelu(self.linear3(fea))
        out = self.linear4(fea)
        return out.squeeze()
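A hedged usage sketch for the new module (batch size and nf are illustrative, checkpoint is assumed to be imported elsewhere in the file, and the idea that the scalar output is trained to regress the measured PSNR of the pair is an inference from the class name, not something stated in the diff):

import torch

approx = PsnrApproximator(nf=64, input_img_factor=1)  # 128x128 inputs at factor 1
real = torch.randn(4, 3, 128, 128)                    # ground-truth batch (illustrative)
fake = torch.randn(4, 3, 128, 128)                    # generated batch (illustrative)
score = approx(real, fake)                            # tensor of shape [4], one scalar per pair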
@@ -160,6 +160,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
        netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale'])
    elif which_model == "discriminator_refvgg":
        netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
    elif which_model == "psnr_approximator":
        netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
    return netD
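A note on input_img_factor, as a sketch only: the first linear layer sees an (nf*8) x (4*factor) x (4*factor) feature map after five stride-2 convolutions, so a 256 px discriminator input corresponds to factor 256 / 128 = 2 (the nf value below is assumed):

approx_256 = PsnrApproximator(nf=64, input_img_factor=256 / 128)
print(approx_256.linear1.in_features)  # 512 * (4*2) * (4*2) = 32768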
@@ -1,3 +1,5 @@
import random

import torch.nn
from torch.cuda.amp import autocast

@@ -45,6 +47,12 @@ def create_injector(opt_inject, env):
        return ImageFftInjector(opt_inject, env)
    elif type == 'extract_indices':
        return IndicesExtractor(opt_inject, env)
    elif type == 'random_shift':
        return RandomShiftInjector(opt_inject, env)
    elif type == 'psnr':
        return PsnrInjector(opt_inject, env)
    elif type == 'batch_rotate':
        return BatchRotateInjector(opt_inject, env)
    else:
        raise NotImplementedError
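A hypothetical wiring for the new 'psnr' injector; the 'in'/'out' key names and the state keys follow the usual Injector convention in this repo, which is an assumption here rather than something shown in the diff:

opt_inject = {'type': 'psnr', 'in': ['gen', 'hq'], 'out': 'psnr_mse'}
psnr_inj = create_injector(opt_inject, env)
mse_per_image = psnr_inj(state)['psnr_mse']  # one MSE value per image in the batch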
@@ -94,12 +102,13 @@ class DiscriminatorInjector(Injector):
        super(DiscriminatorInjector, self).__init__(opt, env)

    def forward(self, state):
        d = self.env['discriminators'][self.opt['discriminator']]
        if isinstance(self.input, list):
            params = [state[i] for i in self.input]
            results = d(*params)
        else:
            results = d(state[self.input])
        with autocast(enabled=self.env['opt']['fp16']):
            d = self.env['discriminators'][self.opt['discriminator']]
            if isinstance(self.input, list):
                params = [state[i] for i in self.input]
                results = d(*params)
            else:
                results = d(state[self.input])
        new_state = {}
        if isinstance(self.output, list):
            # Only dereference tuples or lists, not tensors.
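The change above wraps only the discriminator forward pass in autocast. A minimal sketch of the same pattern in isolation (names are illustrative): autocast(enabled=False) is a no-op, so a single fp16 flag cleanly toggles mixed precision for just that region.

import torch
from torch.cuda.amp import autocast


def run_discriminator(d: torch.nn.Module, x: torch.Tensor, fp16: bool) -> torch.Tensor:
    # Mirrors the injector change: only the wrapped forward runs under autocast.
    with autocast(enabled=fp16):
        return d(x)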
@@ -232,10 +241,25 @@ class MarginRemoval(Injector):
    def __init__(self, opt, env):
        super(MarginRemoval, self).__init__(opt, env)
        self.margin = opt['margin']
        self.random_shift_max = opt['random_shift_max'] if 'random_shift_max' in opt.keys() else 0

    def forward(self, state):
        input = state[self.input]
        return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
        if self.random_shift_max > 0:
            output = []
            # This is a really shitty way of doing this. If it works at all, I should reconsider using Resample2D, for example.
            for b in range(input.shape[0]):
                shiftleft = random.randint(-self.random_shift_max, self.random_shift_max)
                shifttop = random.randint(-self.random_shift_max, self.random_shift_max)
                output.append(input[b, :, self.margin+shiftleft:-(self.margin-shiftleft),
                              self.margin+shifttop:-(self.margin-shifttop)])
            output = torch.stack(output, dim=0)
        else:
            output = input[:, :, self.margin:-self.margin,
                   self.margin:-self.margin]

        return {self.opt['out']: output}


# Produces an injection which is composed of applying a single injector multiple times across a single dimension.
class ForEachInjector(Injector):
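One edge case in the shifted crop above may be worth noting: if random_shift_max is ever set as large as margin, a sampled shift equal to margin makes -(self.margin - shift) evaluate to 0 and the slice comes back empty. A small sketch of an equivalent crop with explicit end indices that stays valid in that case (function name and signature are hypothetical):

import torch


def shifted_center_crop(img: torch.Tensor, margin: int, shift_y: int, shift_x: int) -> torch.Tensor:
    # Same crop as margin+shift : -(margin-shift), but an explicit end index keeps
    # the slice non-empty when |shift| == margin.
    h, w = img.shape[-2], img.shape[-1]
    return img[..., margin + shift_y: h - margin + shift_y,
                    margin + shift_x: w - margin + shift_x]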
@@ -254,7 +278,7 @@ class ForEachInjector(Injector):
            for i in range(inputs.shape[1]):
                st['_in'] = inputs[:, i]
                injs.append(self.injector(st)['_out'])
            return {self.output: torch.stack(injs, dim=1)}
            return {self.output: torch.stack(injs, dim=1)}


class ConstantInjector(Injector):
@@ -316,3 +340,31 @@ class IndicesExtractor(Injector):
            results[o] = state[self.input][:, i]
        return results


class RandomShiftInjector(Injector):
    def __init__(self, opt, env):
        super(RandomShiftInjector, self).__init__(opt, env)

    def forward(self, state):
        img = state[self.input]
        return {self.output: img}


class PsnrInjector(Injector):
    def __init__(self, opt, env):
        super(PsnrInjector, self).__init__(opt, env)

    def forward(self, state):
        img1, img2 = state[self.input[0]], state[self.input[1]]
        mse = torch.mean((img1 - img2) ** 2, dim=[1,2,3])
        return {self.output: mse}


class BatchRotateInjector(Injector):
    def __init__(self, opt, env):
        super(BatchRotateInjector, self).__init__(opt, env)

    def forward(self, state):
        img = state[self.input]
        return {self.output: torch.roll(img, 1, 0)}
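Despite its name, PsnrInjector emits per-image MSE rather than PSNR proper. If a true PSNR value were needed downstream, the usual conversion would be (a minimal sketch, assuming image values in [0, 1]):

import torch


def psnr_from_mse(mse: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    # PSNR = 10 * log10(MAX^2 / MSE); the clamp guards against log10 of infinity
    # when the two images are identical.
    return 10.0 * torch.log10((max_val ** 2) / mse.clamp(min=1e-10))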
@@ -278,7 +278,7 @@ class Trainer:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_10bl_bypass.yml')
    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_justfeature.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)