forked from mrq/DL-Art-School

More work on SSIM/PSNR approximators

- Add a network that accommodates this style of approximator while retaining structure
- Migrate to SSIM approximation
- Add a tool to visualize how these approximators are working
- Fix some issues that came up while doing this work

parent 85c545835c · commit 658a267bab
@@ -602,4 +602,57 @@ class PsnrApproximator(nn.Module):
         fea = self.lrelu(self.linear2(fea))
         fea = self.lrelu(self.linear3(fea))
         out = self.linear4(fea)
         return out.squeeze()
+
+
+class SingleImageQualityEstimator(nn.Module):
+    # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
+    def __init__(self, nf, input_img_factor=1):
+        super(SingleImageQualityEstimator, self).__init__()
+
+        # [64, 128, 128]
+        self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
+        self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
+        self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True)
+        # [64, 64, 64]
+        self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
+        self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
+        self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
+        self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
+        # [128, 32, 32]
+        self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
+        self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
+        self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
+        self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
+
+        # [512, 16, 16]
+        self.conv3_0 = nn.Conv2d(nf * 4, nf * 4, 3, 1, 1, bias=False)
+        self.bn3_0 = nn.BatchNorm2d(nf * 4, affine=True)
+        self.conv3_1 = nn.Conv2d(nf * 4, nf * 8, 4, 2, 1, bias=False)
+        self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
+        # [512, 8, 8]
+        self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=True)
+        self.conv4_1 = nn.Conv2d(nf * 8, nf * 2, 3, 1, 1, bias=True)
+        self.conv4_2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
+        self.conv4_3 = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
+        self.sigmoid = nn.Sigmoid()
+        self.lrelu = nn.LeakyReLU(negative_slope=.2, inplace=True)
+
+    def compute_body(self, fake):
+        fea = self.lrelu(self.fake_conv0_0(fake))
+        fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea)))
+        fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea)))
+        fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea)))
+        fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea)))
+        fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea)))
+        return fea
+
+    def forward(self, fake):
+        fea = checkpoint(self.compute_body, fake)
+        fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
+        fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
+        fea = self.lrelu(self.conv4_0(fea))
+        fea = self.lrelu(self.conv4_1(fea))
+        fea = self.lrelu(self.conv4_2(fea))
+        fea = self.sigmoid(self.conv4_3(fea))
+        return fea
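As a sanity check on the new network's shapes — a minimal sketch, not part of the commit, assuming `torch.nn as nn` and the `checkpoint` wrapper already imported by this file:

    # Hypothetical smoke test: the estimator halves resolution four times,
    # so a 128x128 input yields a [batch, 3, 8, 8] quality map in [0, 1].
    import torch

    net = SingleImageQualityEstimator(nf=64)
    fake = torch.rand(4, 3, 128, 128)   # [batch, channels, height, width]
    quality = net(fake)                 # sigmoid output, one value per region/channel
    print(quality.shape)                # expected: torch.Size([4, 3, 8, 8])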
codes/models/archs/pytorch_ssim.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from math import exp
+
+
+def gaussian(window_size, sigma):
+    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+    return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+    return window
+
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True, raw=False):
+    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1 * mu2
+
+    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+    C1 = 0.01 ** 2
+    C2 = 0.03 ** 2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+    if size_average:
+        return ssim_map.mean()
+    elif raw:
+        return ssim_map
+    else:
+        return ssim_map.mean(1).mean(1).mean(1)
+
+
+class SSIM(torch.nn.Module):
+    def __init__(self, window_size=11, size_average=True, raw=False):
+        super(SSIM, self).__init__()
+        self.window_size = window_size
+        self.size_average = size_average
+        self.raw = raw
+        self.channel = 1
+        self.window = create_window(window_size, self.channel)
+
+    def forward(self, img1, img2):
+        (_, channel, _, _) = img1.size()
+
+        if channel == self.channel and self.window.data.type() == img1.data.type():
+            window = self.window
+        else:
+            window = create_window(self.window_size, channel)
+
+            if img1.is_cuda:
+                window = window.cuda(img1.get_device())
+            window = window.type_as(img1)
+
+            self.window = window
+            self.channel = channel
+
+        return _ssim(img1, img2, window, self.window_size, channel, self.size_average, self.raw)
+
+
+def ssim(img1, img2, window_size=11, size_average=True):
+    (_, channel, _, _) = img1.size()
+    window = create_window(window_size, channel)
+
+    if img1.is_cuda:
+        window = window.cuda(img1.get_device())
+    window = window.type_as(img1)
+
+    return _ssim(img1, img2, window, window_size, channel, size_average)
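A quick usage sketch for the file above (my example, not part of the diff): `ssim` returns the scalar mean, while the `SSIM` module with `size_average=False, raw=True` returns the full per-pixel map that the new injector consumes.

    import torch
    from models.archs.pytorch_ssim import SSIM, ssim

    a = torch.rand(2, 3, 64, 64)
    b = a.clone()

    print(ssim(a, b))                               # ~1.0 for identical images
    raw_map = SSIM(size_average=False, raw=True)(a, b)
    print(raw_map.shape)                            # torch.Size([2, 3, 64, 64]): same size as the inputs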
@@ -4,6 +4,7 @@ import torch.nn
 from torch.cuda.amp import autocast
 
 from models.archs.SPSR_arch import ImageGradientNoPadding
+from models.archs.pytorch_ssim import SSIM
 from utils.weight_scheduler import get_scheduler_for_opt
 from models.steps.losses import extract_params_from_state
 
@@ -353,11 +354,16 @@ class RandomShiftInjector(Injector):
 class PsnrInjector(Injector):
     def __init__(self, opt, env):
         super(PsnrInjector, self).__init__(opt, env)
+        self.ssim = SSIM(size_average=False, raw=True)
+        self.scale = opt['output_scale_divisor']
+        self.exp = opt['exponent'] if 'exponent' in opt.keys() else 1
 
     def forward(self, state):
         img1, img2 = state[self.input[0]], state[self.input[1]]
-        mse = torch.mean((img1 - img2) ** 2, dim=[1,2,3])
-        return {self.output: mse}
+        ssim = self.ssim(img1, img2)
+        areal_se = torch.nn.functional.interpolate(ssim, scale_factor=1/self.scale,
+                                                   mode="area")
+        return {self.output: areal_se}
 
 
 class BatchRotateInjector(Injector):
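The intent of the rewritten injector, sketched standalone with a hypothetical `output_scale_divisor` of 16 (my example, not part of the diff): area interpolation average-pools the raw SSIM map so each cell of the result summarizes the quality of one image region, matching the coarse output grid of SingleImageQualityEstimator above.

    import torch
    import torch.nn.functional as F
    from models.archs.pytorch_ssim import SSIM

    ssim_fn = SSIM(size_average=False, raw=True)
    img1, img2 = torch.rand(4, 3, 128, 128), torch.rand(4, 3, 128, 128)

    ssim_map = ssim_fn(img1, img2)                               # [4, 3, 128, 128]
    areal = F.interpolate(ssim_map, scale_factor=1 / 16, mode="area")
    print(areal.shape)                                           # [4, 3, 8, 8]: one score per region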
@@ -16,6 +16,8 @@ def create_loss(opt_loss, env):
         return create_teco_loss(opt_loss, env)
     elif type == 'pix':
         return PixLoss(opt_loss, env)
+    elif type == 'direct':
+        return DirectLoss(opt_loss, env)
     elif type == 'feature':
         return FeatureLoss(opt_loss, env)
     elif type == 'interpreted_feature':
@@ -89,9 +91,34 @@ class PixLoss(ConfigurableLoss):
         super(PixLoss, self).__init__(opt, env)
         self.opt = opt
         self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
+        self.real_scale = opt['real_scale'] if 'real_scale' in opt.keys() else 1
+        self.real_offset = opt['real_offset'] if 'real_offset' in opt.keys() else 0
+        self.report_metrics = opt['report_metrics'] if 'report_metrics' in opt.keys() else False
 
     def forward(self, _, state):
-        return self.criterion(state[self.opt['fake']].float(), state[self.opt['real']].float())
+        real = state[self.opt['real']] * self.real_scale + float(self.real_offset)
+        fake = state[self.opt['fake']]
+        if self.report_metrics:
+            self.metrics.append(("real_pix_mean_histogram", torch.mean(real, dim=[1,2,3]).detach()))
+            self.metrics.append(("fake_pix_mean_histogram", torch.mean(fake, dim=[1,2,3]).detach()))
+            self.metrics.append(("real_pix_std", torch.std(real).detach()))
+            self.metrics.append(("fake_pix_std", torch.std(fake).detach()))
+        return self.criterion(fake.float(), real.float())
+
+
+# Loss defined by averaging the input tensor across all dimensions and optionally inverting it.
+class DirectLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(DirectLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.inverted = opt['inverted'] if 'inverted' in opt.keys() else False
+        self.key = opt['key']
+
+    def forward(self, _, state):
+        if self.inverted:
+            return -torch.mean(state[self.key])
+        else:
+            return torch.mean(state[self.key])
 
 
 class FeatureLoss(ConfigurableLoss):
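A hedged sketch of how DirectLoss might be configured (the option keys mirror those read in `__init__`; the 'quality' state key is hypothetical): it turns any tensor already present in the training state, such as the quality estimator's output map, directly into a scalar loss.

    # Hypothetical loss options; 'key' names a tensor in the trainer state.
    direct_opt = {
        'type': 'direct',
        'key': 'quality',    # e.g. the SingleImageQualityEstimator output
        'inverted': True,    # maximize estimated quality by minimizing its negation
    }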
@@ -149,7 +149,8 @@ class ConfigurableStep(Module):
         for loss_name, loss in self.losses.items():
             # Some losses only activate after a set number of steps. For example, proto-discriminator losses can
             # be very disruptive to a generator.
-            if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
+            if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step'] or \
+                    'before' in loss.opt.keys() and self.env['step'] > loss.opt['before']:
                 continue
             l = loss(self.training_net, local_state)
             total_loss += l * self.weights[loss_name]
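For context, a hedged sketch of the gating this change enables (step counts invented): 'after' already delayed a loss; the new 'before' key retires it.

    # Hypothetical: this loss contributes only while 10000 <= step <= 50000.
    loss_opt = {
        'after': 10000,    # skipped while env['step'] < 10000
        'before': 50000,   # skipped once env['step'] > 50000
    }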
codes/scripts/test_psnr_approximator.py (new file, 90 lines)
@@ -0,0 +1,90 @@
+import os.path as osp
+import logging
+import shutil
+import time
+import argparse
+from collections import OrderedDict
+
+import os
+
+import torchvision
+
+import utils
+import utils.options as option
+import utils.util as util
+from data.util import bgr2ycbcr
+import models.archs.SwitchedResidualGenerator_arch as srg
+from models.ExtensibleTrainer import ExtensibleTrainer
+from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
+from switched_conv.switched_conv import compute_attention_specificity
+from data import create_dataset, create_dataloader
+from tqdm import tqdm
+import torch
+import models.networks as networks
+
+if __name__ == "__main__":
+    #### options
+    torch.backends.cudnn.benchmark = True
+    srg_analyze = False
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_psnr_approximator.yml')
+    opt = option.parse(parser.parse_args().opt, is_train=False)
+    opt = option.dict_to_nonedict(opt)
+    utils.util.loaded_options = opt
+
+    util.mkdirs(
+        (path for key, path in opt['path'].items()
+         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
+    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
+                      screen=True, tofile=True)
+    logger = logging.getLogger('base')
+    logger.info(option.dict2str(opt))
+
+    #### Create test dataset and dataloader
+    test_loaders = []
+    for phase, dataset_opt in sorted(opt['datasets'].items()):
+        dataset_opt['n_workers'] = 0
+        test_set = create_dataset(dataset_opt)
+        test_loader = create_dataloader(test_set, dataset_opt, opt)
+        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
+        test_loaders.append(test_loader)
+
+    model = ExtensibleTrainer(opt)
+    for test_loader in test_loaders:
+        test_set_name = test_loader.dataset.opt['name']
+        logger.info('\nTesting [{:s}]...'.format(test_set_name))
+        test_start_time = time.time()
+        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
+        util.mkdir(dataset_dir)
+
+        dst_path = "F:\\playground"
+        [os.makedirs(osp.join(dst_path, str(i)), exist_ok=True) for i in range(10)]
+
+        corruptions = ['none', 'color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise',
+                       'jpeg-medium', 'jpeg-broad', 'jpeg-normal', 'saturation', 'lq_resampling',
+                       'lq_resampling4x']
+        c_counter = 0
+        test_set.corruptor.num_corrupts = 0
+        test_set.corruptor.random_corruptions = []
+        test_set.corruptor.fixed_corruptions = [corruptions[0]]
+        corruption_mse = [(0,0) for _ in corruptions]
+
+        tq = tqdm(test_loader)
+        batch_size = opt['datasets']['train']['batch_size']
+        for data in tq:
+            need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
+            model.feed_data(data, need_GT=need_GT)
+            model.test()
+            est_psnr = torch.mean(model.eval_state['psnr_approximate'][0], dim=[1,2,3])
+            for i in range(est_psnr.shape[0]):
+                im_path = data['GT_path'][i]
+                torchvision.utils.save_image(model.eval_state['lq'][0][i], osp.join(dst_path, str(int(est_psnr[i]*10)), osp.basename(im_path)))
+                #shutil.copy(im_path, osp.join(dst_path, str(int(est_psnr[i]*10))))
+
+            last_mse, last_ctr = corruption_mse[c_counter % len(corruptions)]
+            corruption_mse[c_counter % len(corruptions)] = (last_mse + torch.sum(est_psnr).item(), last_ctr + 1)
+            c_counter += 1
+            test_set.corruptor.fixed_corruptions = [corruptions[c_counter % len(corruptions)]]
+            if c_counter % 100 == 0:
+                for i, (mse, ctr) in enumerate(corruption_mse):
+                    print("%s: %f" % (corruptions[i], mse / (ctr * batch_size)))
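In short, this script appears to be the visualization tool the commit message mentions: it cycles the dataset's fixed corruption each batch, buckets every LQ image into F:\playground\<n> by its estimated quality decile (int(est_psnr * 10)), and prints a running per-corruption mean every 100 batches, so the approximator's scores can be eyeballed against known degradations.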
@@ -278,7 +278,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2classic_4x.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_artificial_quality.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
@@ -9,10 +9,16 @@ class LossAccumulator:
 
     def add_loss(self, name, tensor):
         if name not in self.buffers.keys():
-            self.buffers[name] = (0, torch.zeros(self.buffer_sz), False)
+            if "_histogram" in name:
+                tensor = torch.flatten(tensor.detach().cpu())
+                self.buffers[name] = (0, torch.zeros((self.buffer_sz, tensor.shape[0])), False)
+            else:
+                self.buffers[name] = (0, torch.zeros(self.buffer_sz), False)
         i, buf, filled = self.buffers[name]
         # Can take tensors or just plain python numbers.
-        if isinstance(tensor, torch.Tensor):
+        if '_histogram' in name:
+            buf[i] = torch.flatten(tensor.detach().cpu())
+        elif isinstance(tensor, torch.Tensor):
             buf[i] = tensor.detach().cpu()
         else:
             buf[i] = tensor
@@ -29,6 +35,8 @@ class LossAccumulator:
         result = {}
         for k, v in self.buffers.items():
             i, buf, filled = v
+            if '_histogram' in k:
+                result["loss_" + k] = torch.flatten(buf)
             if filled:
                 result["loss_" + k] = torch.mean(buf)
             else:
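A minimal usage sketch of the new histogram path (assuming a no-argument constructor, which is not shown in this diff): tensors logged under a name containing `_histogram` are buffered whole rather than averaged, so the full per-sample distribution survives into `as_dict()` for histogram-style reporting.

    import torch

    acc = LossAccumulator()                              # constructor signature assumed
    per_image = torch.rand(16)                           # e.g. per-image pixel means from PixLoss
    acc.add_loss("real_pix_mean_histogram", per_image)   # buffered as a full vector
    acc.add_loss("total", torch.tensor(0.25))            # scalar losses are still averaged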