More work on SSIM/PSNR approximators

- Add a network that accommodates this style of approximator while retaining structure
- Migrate to SSIM approximation
- Add a tool to visualize how these approximators are working
- Fix some issues that came up while doing this work
James Betker 2020-11-03 08:09:58 -07:00
parent 85c545835c
commit 658a267bab
8 changed files with 273 additions and 8 deletions

View File

@@ -602,4 +602,57 @@ class PsnrApproximator(nn.Module):
        fea = self.lrelu(self.linear2(fea))
        fea = self.lrelu(self.linear3(fea))
        out = self.linear4(fea)
        return out.squeeze()

class SingleImageQualityEstimator(nn.Module):
    # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
    def __init__(self, nf, input_img_factor=1):
        super(SingleImageQualityEstimator, self).__init__()
        # [64, 128, 128]
        self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
        self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
        self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True)
        # [64, 64, 64]
        self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
        self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
        self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
        self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
        # [128, 32, 32]
        self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
        self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
        self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
        self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
        # [256, 16, 16]
        self.conv3_0 = nn.Conv2d(nf * 4, nf * 4, 3, 1, 1, bias=False)
        self.bn3_0 = nn.BatchNorm2d(nf * 4, affine=True)
        self.conv3_1 = nn.Conv2d(nf * 4, nf * 8, 4, 2, 1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
        # [512, 8, 8]
        self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=True)
        self.conv4_1 = nn.Conv2d(nf * 8, nf * 2, 3, 1, 1, bias=True)
        self.conv4_2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
        self.conv4_3 = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
        self.sigmoid = nn.Sigmoid()
        self.lrelu = nn.LeakyReLU(negative_slope=.2, inplace=True)

    def compute_body(self, fake):
        fea = self.lrelu(self.fake_conv0_0(fake))
        fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea)))
        fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea)))
        fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea)))
        fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea)))
        fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea)))
        return fea

    def forward(self, fake):
        # Gradient-checkpoint the early convolutions to save memory.
        fea = checkpoint(self.compute_body, fake)
        fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
        fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
        fea = self.lrelu(self.conv4_0(fea))
        fea = self.lrelu(self.conv4_1(fea))
        fea = self.lrelu(self.conv4_2(fea))
        # Per-region quality map in [0, 1] at 1/16 of the input resolution.
        fea = self.sigmoid(self.conv4_3(fea))
        return fea
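
For orientation, a minimal shape check of the estimator above: the fake_* trunk halves the spatial dimensions three times and the conv3 stage once more, so a 128x128 input yields a 3-channel sigmoid quality map at 8x8. This is a hypothetical sketch only; it assumes nf=64 and that `checkpoint` resolves to torch.utils.checkpoint.checkpoint (the actual import lies outside this hunk).

import torch
from torch.utils.checkpoint import checkpoint  # assumed binding for the call in forward()

est = SingleImageQualityEstimator(nf=64)
out = est(torch.rand(2, 3, 128, 128, requires_grad=True))
print(out.shape)  # torch.Size([2, 3, 8, 8]); sigmoid keeps values in [0, 1]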

View File

@@ -0,0 +1,80 @@
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True, raw=False):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    if size_average:
        return ssim_map.mean()
    elif raw:
        return ssim_map
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, raw=False):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.raw = raw
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            # Rebuild the window when the channel count or dtype changes.
            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            self.window = window
            self.channel = channel
        return _ssim(img1, img2, window, self.window_size, channel, self.size_average, self.raw)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    return _ssim(img1, img2, window, window_size, channel, size_average)
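
A minimal usage sketch of this module (hypothetical tensors; it assumes the file lands at models/archs/pytorch_ssim.py, which is how the injector below imports it). Note that SSIM(size_average=False, raw=True) returns the full per-pixel map rather than a scalar:

import torch
from models.archs.pytorch_ssim import SSIM, ssim

img1 = torch.rand(4, 3, 64, 64)
img2 = (img1 + 0.05 * torch.randn_like(img1)).clamp(0, 1)

print(ssim(img1, img2))                        # scalar mean SSIM, near 1.0 for a light perturbation
raw_map = SSIM(size_average=False, raw=True)(img1, img2)
print(raw_map.shape)                           # torch.Size([4, 3, 64, 64])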

View File

@@ -4,6 +4,7 @@ import torch.nn
from torch.cuda.amp import autocast
from models.archs.SPSR_arch import ImageGradientNoPadding
from models.archs.pytorch_ssim import SSIM
from utils.weight_scheduler import get_scheduler_for_opt
from models.steps.losses import extract_params_from_state
@@ -353,11 +354,16 @@ class RandomShiftInjector(Injector):

# Despite its name, this injector now emits an area-averaged SSIM map (per the
# commit message, it has been migrated from MSE/PSNR to SSIM approximation).
class PsnrInjector(Injector):
    def __init__(self, opt, env):
        super(PsnrInjector, self).__init__(opt, env)
        self.ssim = SSIM(size_average=False, raw=True)
        self.scale = opt['output_scale_divisor']
        self.exp = opt['exponent'] if 'exponent' in opt.keys() else 1

    def forward(self, state):
        img1, img2 = state[self.input[0]], state[self.input[1]]
        ssim = self.ssim(img1, img2)
        areal_se = torch.nn.functional.interpolate(ssim, scale_factor=1 / self.scale,
                                                   mode="area")
        return {self.output: areal_se}

class BatchRotateInjector(Injector):
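
The mode="area" interpolation above is equivalent to average pooling, so each cell of the injector's output holds the mean SSIM of the patch it covers. A small sketch (hypothetical shapes, with output_scale_divisor=2):

import torch
import torch.nn.functional as F

ssim_map = torch.rand(4, 3, 128, 128)                     # stand-in for the raw SSIM map
pooled = F.interpolate(ssim_map, scale_factor=1 / 2, mode="area")
print(pooled.shape)                                       # torch.Size([4, 3, 64, 64])
print(torch.allclose(pooled, F.avg_pool2d(ssim_map, 2)))  # True: area mode == average pooling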

View File

@@ -16,6 +16,8 @@ def create_loss(opt_loss, env):
        return create_teco_loss(opt_loss, env)
    elif type == 'pix':
        return PixLoss(opt_loss, env)
    elif type == 'direct':
        return DirectLoss(opt_loss, env)
    elif type == 'feature':
        return FeatureLoss(opt_loss, env)
    elif type == 'interpreted_feature':
@@ -89,9 +91,34 @@ class PixLoss(ConfigurableLoss):
        super(PixLoss, self).__init__(opt, env)
        self.opt = opt
        self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
        self.real_scale = opt['real_scale'] if 'real_scale' in opt.keys() else 1
        self.real_offset = opt['real_offset'] if 'real_offset' in opt.keys() else 0
        self.report_metrics = opt['report_metrics'] if 'report_metrics' in opt.keys() else False

    def forward(self, _, state):
        real = state[self.opt['real']] * self.real_scale + float(self.real_offset)
        fake = state[self.opt['fake']]
        if self.report_metrics:
            self.metrics.append(("real_pix_mean_histogram", torch.mean(real, dim=[1,2,3]).detach()))
            self.metrics.append(("fake_pix_mean_histogram", torch.mean(fake, dim=[1,2,3]).detach()))
            self.metrics.append(("real_pix_std", torch.std(real).detach()))
            self.metrics.append(("fake_pix_std", torch.std(fake).detach()))
        return self.criterion(fake.float(), real.float())


# Loss defined by averaging the input tensor across all dimensions and optionally inverting it.
class DirectLoss(ConfigurableLoss):
    def __init__(self, opt, env):
        super(DirectLoss, self).__init__(opt, env)
        self.opt = opt
        self.inverted = opt['inverted'] if 'inverted' in opt.keys() else False
        self.key = opt['key']

    def forward(self, _, state):
        if self.inverted:
            return -torch.mean(state[self.key])
        else:
            return torch.mean(state[self.key])


class FeatureLoss(ConfigurableLoss):
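
Together with PsnrInjector, DirectLoss closes the loop: the injector writes an SSIM map into the step's state under some key, and a 'direct' loss reduces that tensor to a scalar; with inverted=True the sign is flipped so that maximizing SSIM becomes a minimizable objective. A sketch of the reduction (hypothetical state key):

import torch

state = {'ssim_quality': torch.rand(4, 3, 64, 64)}  # e.g. produced by PsnrInjector
inverted = True                                     # opt['inverted']
loss = -torch.mean(state['ssim_quality']) if inverted else torch.mean(state['ssim_quality'])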

View File

@@ -149,7 +149,8 @@ class ConfigurableStep(Module):
        for loss_name, loss in self.losses.items():
            # Some losses only activate after a set number of steps. For example, proto-discriminator losses can
            # be very disruptive to a generator. Losses can also be disabled past a given step with 'before'.
            if ('after' in loss.opt.keys() and loss.opt['after'] > self.env['step']) or \
                    ('before' in loss.opt.keys() and self.env['step'] > loss.opt['before']):
                continue
            l = loss(self.training_net, local_state)
            total_loss += l * self.weights[loss_name]
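
The gate reads naturally as a window: a loss is skipped while the step is before its 'after' bound or past its 'before' bound. A standalone sketch of the predicate (hypothetical values):

step = 12000
loss_opt = {'after': 10000, 'before': 50000}   # loss active only inside this window
skip = ('after' in loss_opt and loss_opt['after'] > step) or \
       ('before' in loss_opt and step > loss_opt['before'])
print(skip)  # False: the loss applies at step 12000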

View File

@@ -0,0 +1,90 @@
import os.path as osp
import logging
import shutil
import time
import argparse
from collections import OrderedDict
import os
import torchvision
import utils
import utils.options as option
import utils.util as util
from data.util import bgr2ycbcr
import models.archs.SwitchedResidualGenerator_arch as srg
from models.ExtensibleTrainer import ExtensibleTrainer
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
from switched_conv.switched_conv import compute_attention_specificity
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import models.networks as networks
if __name__ == "__main__":
    #### options
    torch.backends.cudnn.benchmark = True
    srg_analyze = False
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_psnr_approximator.yml')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)
    utils.util.loaded_options = opt

    util.mkdirs(
        (path for key, path in opt['path'].items()
         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    #### Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        dataset_opt['n_workers'] = 0
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt, opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    model = ExtensibleTrainer(opt)
    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        # Bucket output directories: one folder per estimated-quality decile.
        dst_path = "F:\\playground"
        [os.makedirs(osp.join(dst_path, str(i)), exist_ok=True) for i in range(10)]
        corruptions = ['none', 'color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise',
                       'jpeg-medium', 'jpeg-broad', 'jpeg-normal', 'saturation', 'lq_resampling',
                       'lq_resampling4x']
        c_counter = 0
        test_set.corruptor.num_corrupts = 0
        test_set.corruptor.random_corruptions = []
        test_set.corruptor.fixed_corruptions = [corruptions[0]]
        corruption_mse = [(0, 0) for _ in corruptions]

        tq = tqdm(test_loader)
        batch_size = opt['datasets']['train']['batch_size']
        for data in tq:
            need_GT = test_loader.dataset.opt['dataroot_GT'] is not None
            model.feed_data(data, need_GT=need_GT)
            model.test()
            est_psnr = torch.mean(model.eval_state['psnr_approximate'][0], dim=[1, 2, 3])
            for i in range(est_psnr.shape[0]):
                im_path = data['GT_path'][i]
                # Save each LQ image into the folder matching its estimated quality bucket.
                torchvision.utils.save_image(model.eval_state['lq'][0][i],
                                             osp.join(dst_path, str(int(est_psnr[i] * 10)), osp.basename(im_path)))
                #shutil.copy(im_path, osp.join(dst_path, str(int(est_psnr[i]*10))))
            # Accumulate the batch's summed estimate against the corruption that produced it.
            last_mse, last_ctr = corruption_mse[c_counter % len(corruptions)]
            corruption_mse[c_counter % len(corruptions)] = (last_mse + torch.sum(est_psnr).item(), last_ctr + 1)
            c_counter += 1
            # Rotate to the next fixed corruption so every type gets sampled.
            test_set.corruptor.fixed_corruptions = [corruptions[c_counter % len(corruptions)]]
            if c_counter % 100 == 0:
                for i, (mse, ctr) in enumerate(corruption_mse):
                    print("%s: %f" % (corruptions[i], mse / (ctr * batch_size)))

View File

@@ -278,7 +278,7 @@ class Trainer:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_artificial_quality.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

View File

@@ -9,10 +9,16 @@ class LossAccumulator:

    def add_loss(self, name, tensor):
        if name not in self.buffers.keys():
            if "_histogram" in name:
                # Histogram metrics store a full vector per step rather than a scalar.
                tensor = torch.flatten(tensor.detach().cpu())
                self.buffers[name] = (0, torch.zeros((self.buffer_sz, tensor.shape[0])), False)
            else:
                self.buffers[name] = (0, torch.zeros(self.buffer_sz), False)
        i, buf, filled = self.buffers[name]
        # Can take tensors or just plain python numbers.
        if '_histogram' in name:
            buf[i] = torch.flatten(tensor.detach().cpu())
        elif isinstance(tensor, torch.Tensor):
            buf[i] = tensor.detach().cpu()
        else:
            buf[i] = tensor
@@ -29,6 +35,8 @@ class LossAccumulator:
        result = {}
        for k, v in self.buffers.items():
            i, buf, filled = v
            if '_histogram' in k:
                # Report the whole (steps x batch) buffer so it can be logged as a distribution.
                result["loss_" + k] = torch.flatten(buf)
            elif filled:
                result["loss_" + k] = torch.mean(buf)
            else:
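
A small sketch of the new histogram path (hypothetical sizes): '_histogram' keys store one vector per step, and as_dict flattens the whole (buffer_sz x batch) matrix so the consumer can log a distribution instead of a scalar.

import torch

buffer_sz = 3
per_sample = torch.rand(8)                         # e.g. per-image pixel means for a batch of 8
buf = torch.zeros((buffer_sz, per_sample.shape[0]))
buf[0] = torch.flatten(per_sample.detach().cpu())  # add_loss path for '_histogram' keys
print(torch.flatten(buf).shape)                    # as_dict path: torch.Size([24])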