From 658a267babec8262939c6b70e2c36b48c24d3e88 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 3 Nov 2020 08:09:58 -0700
Subject: [PATCH] More work on SSIM/PSNR approximators

- Add a network that accommodates this style of approximator while retaining
  structure
- Migrate to SSIM approximation
- Add a tool to visualize how these approximators are working
- Fix some issues that came up while doing this work
---
 codes/models/archs/discriminator_vgg_arch.py | 55 +++++++++++-
 codes/models/archs/pytorch_ssim.py           | 80 +++++++++++++++++
 codes/models/steps/injectors.py              | 10 ++-
 codes/models/steps/losses.py                 | 29 ++++++-
 codes/models/steps/steps.py                  |  3 +-
 codes/scripts/test_psnr_approximator.py      | 90 ++++++++++++++++++++
 codes/train2.py                              |  2 +-
 codes/utils/loss_accumulator.py              | 12 ++-
 8 files changed, 273 insertions(+), 8 deletions(-)
 create mode 100644 codes/models/archs/pytorch_ssim.py
 create mode 100644 codes/scripts/test_psnr_approximator.py

diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index 66d64a8e..15539355 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -602,4 +602,57 @@ class PsnrApproximator(nn.Module):
         fea = self.lrelu(self.linear2(fea))
         fea = self.lrelu(self.linear3(fea))
         out = self.linear4(fea)
-        return out.squeeze()
\ No newline at end of file
+        return out.squeeze()
+
+
+class SingleImageQualityEstimator(nn.Module):
+    # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
+    def __init__(self, nf, input_img_factor=1):
+        super(SingleImageQualityEstimator, self).__init__()
+
+        # [64, 128, 128]
+        self.fake_conv0_0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
+        self.fake_conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
+        self.fake_bn0_1 = nn.BatchNorm2d(nf, affine=True)
+        # [64, 64, 64]
+        self.fake_conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
+        self.fake_bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
+        self.fake_conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
+        self.fake_bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
+        # [128, 32, 32]
+        self.fake_conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
+        self.fake_bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
+        self.fake_conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
+        self.fake_bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
+
+        # [512, 16, 16]
+        self.conv3_0 = nn.Conv2d(nf * 4, nf * 4, 3, 1, 1, bias=False)
+        self.bn3_0 = nn.BatchNorm2d(nf * 4, affine=True)
+        self.conv3_1 = nn.Conv2d(nf * 4, nf * 8, 4, 2, 1, bias=False)
+        self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
+        # [512, 8, 8]
+        self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=True)
+        self.conv4_1 = nn.Conv2d(nf * 8, nf * 2, 3, 1, 1, bias=True)
+        self.conv4_2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
+        self.conv4_3 = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
+        self.sigmoid = nn.Sigmoid()
+        self.lrelu = nn.LeakyReLU(negative_slope=.2, inplace=True)
+
+    def compute_body(self, fake):
+        fea = self.lrelu(self.fake_conv0_0(fake))
+        fea = self.lrelu(self.fake_bn0_1(self.fake_conv0_1(fea)))
+        fea = self.lrelu(self.fake_bn1_0(self.fake_conv1_0(fea)))
+        fea = self.lrelu(self.fake_bn1_1(self.fake_conv1_1(fea)))
+        fea = self.lrelu(self.fake_bn2_0(self.fake_conv2_0(fea)))
+        fea = self.lrelu(self.fake_bn2_1(self.fake_conv2_1(fea)))
+        return fea
+
+    def forward(self, fake):
+        fea = checkpoint(self.compute_body, fake)
+        fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
+        fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
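
(A usage sketch, not part of the patch: it assumes the class above is importable from codes/models/archs/discriminator_vgg_arch.py, that nf=64, and that the checkpoint() call inside forward() behaves like torch.utils.checkpoint. The four stride-2 convolutions downsample by 16x, so a 128x128 input should yield an 8x8, sigmoid-bounded quality map.)

import torch
from models.archs.discriminator_vgg_arch import SingleImageQualityEstimator

estimator = SingleImageQualityEstimator(nf=64)  # nf=64 is an assumed setting
fake = torch.rand(4, 3, 128, 128)               # candidate images in [0, 1]
quality = estimator(fake)                       # per-pixel quality estimate in [0, 1]
print(quality.shape)                            # expected: torch.Size([4, 3, 8, 8])
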
+        fea = self.lrelu(self.conv4_0(fea))
+        fea = self.lrelu(self.conv4_1(fea))
+        fea = self.lrelu(self.conv4_2(fea))
+        fea = self.sigmoid(self.conv4_3(fea))
+        return fea
diff --git a/codes/models/archs/pytorch_ssim.py b/codes/models/archs/pytorch_ssim.py
new file mode 100644
index 00000000..5bdadb79
--- /dev/null
+++ b/codes/models/archs/pytorch_ssim.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from math import exp
+
+
+def gaussian(window_size, sigma):
+    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+    return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+    return window
+
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True, raw=False):
+    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1 * mu2
+
+    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+    C1 = 0.01 ** 2
+    C2 = 0.03 ** 2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+    if size_average:
+        return ssim_map.mean()
+    elif raw:
+        return ssim_map
+    else:
+        return ssim_map.mean(1).mean(1).mean(1)
+
+
+class SSIM(torch.nn.Module):
+    def __init__(self, window_size=11, size_average=True, raw=False):
+        super(SSIM, self).__init__()
+        self.window_size = window_size
+        self.size_average = size_average
+        self.raw = raw
+        self.channel = 1
+        self.window = create_window(window_size, self.channel)
+
+    def forward(self, img1, img2):
+        (_, channel, _, _) = img1.size()
+
+        if channel == self.channel and self.window.data.type() == img1.data.type():
+            window = self.window
+        else:
+            window = create_window(self.window_size, channel)
+
+            if img1.is_cuda:
+                window = window.cuda(img1.get_device())
+            window = window.type_as(img1)
+
+            self.window = window
+            self.channel = channel
+
+        return _ssim(img1, img2, window, self.window_size, channel, self.size_average, self.raw)
+
+
+def ssim(img1, img2, window_size=11, size_average=True):
+    (_, channel, _, _) = img1.size()
+    window = create_window(window_size, channel)
+
+    if img1.is_cuda:
+        window = window.cuda(img1.get_device())
+    window = window.type_as(img1)
+
+    return _ssim(img1, img2, window, window_size, channel, size_average)
\ No newline at end of file
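
(A usage sketch, not part of the patch: both entry points above expect NCHW tensors with pixel values in [0, 1], which is what the C1/C2 constants assume. ssim() returns the scalar mean; the SSIM module, constructed the same way PsnrInjector does below (size_average=False, raw=True), returns the full per-pixel SSIM map.)

import torch
from models.archs.pytorch_ssim import SSIM, ssim

a = torch.rand(2, 3, 64, 64)
b = (a + 0.05 * torch.randn_like(a)).clamp(0, 1)
print(ssim(a, b))                                    # scalar mean SSIM
ssim_map = SSIM(size_average=False, raw=True)(a, b)  # per-pixel SSIM map
print(ssim_map.shape)                                # torch.Size([2, 3, 64, 64])
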
diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py
index 2e277e66..c5e592fa 100644
--- a/codes/models/steps/injectors.py
+++ b/codes/models/steps/injectors.py
@@ -4,6 +4,7 @@ import torch.nn
 from torch.cuda.amp import autocast
 
 from models.archs.SPSR_arch import ImageGradientNoPadding
+from models.archs.pytorch_ssim import SSIM
 from utils.weight_scheduler import get_scheduler_for_opt
 from models.steps.losses import extract_params_from_state
 
@@ -353,11 +354,16 @@ class RandomShiftInjector(Injector):
 class PsnrInjector(Injector):
     def __init__(self, opt, env):
         super(PsnrInjector, self).__init__(opt, env)
+        self.ssim = SSIM(size_average=False, raw=True)
+        self.scale = opt['output_scale_divisor']
+        self.exp = opt['exponent'] if 'exponent' in opt.keys() else 1
 
     def forward(self, state):
         img1, img2 = state[self.input[0]], state[self.input[1]]
-        mse = torch.mean((img1 - img2) ** 2, dim=[1,2,3])
-        return {self.output: mse}
+        ssim = self.ssim(img1, img2)
+        areal_se = torch.nn.functional.interpolate(ssim, scale_factor=1/self.scale,
+                                                    mode="area")
+        return {self.output: areal_se}
 
 
 class BatchRotateInjector(Injector):
diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index d7efdfb6..e8cb5743 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -16,6 +16,8 @@ def create_loss(opt_loss, env):
         return create_teco_loss(opt_loss, env)
     elif type == 'pix':
         return PixLoss(opt_loss, env)
+    elif type == 'direct':
+        return DirectLoss(opt_loss, env)
     elif type == 'feature':
         return FeatureLoss(opt_loss, env)
     elif type == 'interpreted_feature':
@@ -89,9 +91,34 @@ class PixLoss(ConfigurableLoss):
         super(PixLoss, self).__init__(opt, env)
         self.opt = opt
         self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
+        self.real_scale = opt['real_scale'] if 'real_scale' in opt.keys() else 1
+        self.real_offset = opt['real_offset'] if 'real_offset' in opt.keys() else 0
+        self.report_metrics = opt['report_metrics'] if 'report_metrics' in opt.keys() else False
 
     def forward(self, _, state):
-        return self.criterion(state[self.opt['fake']].float(), state[self.opt['real']].float())
+        real = state[self.opt['real']] * self.real_scale + float(self.real_offset)
+        fake = state[self.opt['fake']]
+        if self.report_metrics:
+            self.metrics.append(("real_pix_mean_histogram", torch.mean(real, dim=[1,2,3]).detach()))
+            self.metrics.append(("fake_pix_mean_histogram", torch.mean(fake, dim=[1,2,3]).detach()))
+            self.metrics.append(("real_pix_std", torch.std(real).detach()))
+            self.metrics.append(("fake_pix_std", torch.std(fake).detach()))
+        return self.criterion(fake.float(), real.float())
+
+
+# Loss defined by averaging the input tensor across all dimensions and optionally inverting it.
+class DirectLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(DirectLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.inverted = opt['inverted'] if 'inverted' in opt.keys() else False
+        self.key = opt['key']
+
+    def forward(self, _, state):
+        if self.inverted:
+            return -torch.mean(state[self.key])
+        else:
+            return torch.mean(state[self.key])
 
 
 class FeatureLoss(ConfigurableLoss):
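
(A sketch of how the two additions above compose, with assumed option values rather than a real config: PsnrInjector now emits an area-pooled SSIM map, and a DirectLoss pointed at that output reduces it to a scalar, negated when 'inverted' is set so that maximizing SSIM minimizes the loss.)

import torch
import torch.nn.functional as F
from models.archs.pytorch_ssim import SSIM

fake = torch.rand(2, 3, 128, 128)
real = torch.rand(2, 3, 128, 128)

# What PsnrInjector.forward() computes, assuming output_scale_divisor=2:
ssim_map = SSIM(size_average=False, raw=True)(fake, real)
pooled = F.interpolate(ssim_map, scale_factor=1 / 2, mode="area")

# What DirectLoss returns for opt = {'key': <injector output>, 'inverted': True}:
loss = -torch.mean(pooled)
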
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index cb92749a..e3faadbe 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -149,7 +149,8 @@ class ConfigurableStep(Module):
             for loss_name, loss in self.losses.items():
                 # Some losses only activate after a set number of steps. For example, proto-discriminator losses can
                 # be very disruptive to a generator.
-                if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
+                if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step'] or \
+                        'before' in loss.opt.keys() and self.env['step'] > loss.opt['before']:
                     continue
                 l = loss(self.training_net, local_state)
                 total_loss += l * self.weights[loss_name]
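
(The widened condition above relies on 'and' binding tighter than 'or' in Python, so it skips a loss either before its 'after' step or once past its 'before' step. A standalone, explicitly parenthesized restatement of the same gating logic, using a hypothetical helper name:)

def loss_is_gated(loss_opt, step):
    # Skip a loss before its 'after' step or once past its 'before' step.
    return ('after' in loss_opt and loss_opt['after'] > step) or \
           ('before' in loss_opt and step > loss_opt['before'])

assert loss_is_gated({'after': 1000}, step=500)                       # not active yet
assert loss_is_gated({'before': 2000}, step=2500)                     # already switched off
assert not loss_is_gated({'after': 1000, 'before': 2000}, step=1500)  # inside the active window
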
diff --git a/codes/scripts/test_psnr_approximator.py b/codes/scripts/test_psnr_approximator.py
new file mode 100644
index 00000000..9e274792
--- /dev/null
+++ b/codes/scripts/test_psnr_approximator.py
@@ -0,0 +1,90 @@
+import os.path as osp
+import logging
+import shutil
+import time
+import argparse
+from collections import OrderedDict
+
+import os
+
+import torchvision
+
+import utils
+import utils.options as option
+import utils.util as util
+from data.util import bgr2ycbcr
+import models.archs.SwitchedResidualGenerator_arch as srg
+from models.ExtensibleTrainer import ExtensibleTrainer
+from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
+from switched_conv.switched_conv import compute_attention_specificity
+from data import create_dataset, create_dataloader
+from tqdm import tqdm
+import torch
+import models.networks as networks
+
+if __name__ == "__main__":
+    #### options
+    torch.backends.cudnn.benchmark = True
+    srg_analyze = False
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_psnr_approximator.yml')
+    opt = option.parse(parser.parse_args().opt, is_train=False)
+    opt = option.dict_to_nonedict(opt)
+    utils.util.loaded_options = opt
+
+    util.mkdirs(
+        (path for key, path in opt['path'].items()
+         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
+    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
+                      screen=True, tofile=True)
+    logger = logging.getLogger('base')
+    logger.info(option.dict2str(opt))
+
+    #### Create test dataset and dataloader
+    test_loaders = []
+    for phase, dataset_opt in sorted(opt['datasets'].items()):
+        dataset_opt['n_workers'] = 0
+        test_set = create_dataset(dataset_opt)
+        test_loader = create_dataloader(test_set, dataset_opt, opt)
+        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
+        test_loaders.append(test_loader)
+
+    model = ExtensibleTrainer(opt)
+    for test_loader in test_loaders:
+        test_set_name = test_loader.dataset.opt['name']
+        logger.info('\nTesting [{:s}]...'.format(test_set_name))
+        test_start_time = time.time()
+        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
+        util.mkdir(dataset_dir)
+
+        dst_path = "F:\\playground"
+        [os.makedirs(osp.join(dst_path, str(i)), exist_ok=True) for i in range(10)]
+
+        corruptions = ['none', 'color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise',
+                       'jpeg-medium', 'jpeg-broad', 'jpeg-normal', 'saturation', 'lq_resampling',
+                       'lq_resampling4x']
+        c_counter = 0
+        test_set.corruptor.num_corrupts = 0
+        test_set.corruptor.random_corruptions = []
+        test_set.corruptor.fixed_corruptions = [corruptions[0]]
+        corruption_mse = [(0,0) for _ in corruptions]
+
+        tq = tqdm(test_loader)
+        batch_size = opt['datasets']['train']['batch_size']
+        for data in tq:
+            need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
+            model.feed_data(data, need_GT=need_GT)
+            model.test()
+            est_psnr = torch.mean(model.eval_state['psnr_approximate'][0], dim=[1,2,3])
+            for i in range(est_psnr.shape[0]):
+                im_path = data['GT_path'][i]
+                torchvision.utils.save_image(model.eval_state['lq'][0][i], osp.join(dst_path, str(int(est_psnr[i]*10)), osp.basename(im_path)))
+                #shutil.copy(im_path, osp.join(dst_path, str(int(est_psnr[i]*10))))
+
+            last_mse, last_ctr = corruption_mse[c_counter % len(corruptions)]
+            corruption_mse[c_counter % len(corruptions)] = (last_mse + torch.sum(est_psnr).item(), last_ctr + 1)
+            c_counter += 1
+            test_set.corruptor.fixed_corruptions = [corruptions[c_counter % len(corruptions)]]
+            if c_counter % 100 == 0:
+                for i, (mse, ctr) in enumerate(corruption_mse):
+                    print("%s: %f" % (corruptions[i], mse / (ctr * batch_size)))
\ No newline at end of file
diff --git a/codes/train2.py b/codes/train2.py
index 2f282f46..c67a20d5 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -278,7 +278,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2classic_4x.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_artificial_quality.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
diff --git a/codes/utils/loss_accumulator.py b/codes/utils/loss_accumulator.py
index bfc7e772..896138e9 100644
--- a/codes/utils/loss_accumulator.py
+++ b/codes/utils/loss_accumulator.py
@@ -9,10 +9,16 @@ class LossAccumulator:
 
     def add_loss(self, name, tensor):
         if name not in self.buffers.keys():
-            self.buffers[name] = (0, torch.zeros(self.buffer_sz), False)
+            if "_histogram" in name:
+                tensor = torch.flatten(tensor.detach().cpu())
+                self.buffers[name] = (0, torch.zeros((self.buffer_sz, tensor.shape[0])), False)
+            else:
+                self.buffers[name] = (0, torch.zeros(self.buffer_sz), False)
         i, buf, filled = self.buffers[name]
         # Can take tensors or just plain python numbers.
-        if isinstance(tensor, torch.Tensor):
+        if '_histogram' in name:
+            buf[i] = torch.flatten(tensor.detach().cpu())
+        elif isinstance(tensor, torch.Tensor):
             buf[i] = tensor.detach().cpu()
         else:
             buf[i] = tensor
@@ -29,6 +35,8 @@ class LossAccumulator:
         result = {}
         for k, v in self.buffers.items():
             i, buf, filled = v
+            if '_histogram' in k:
+                result["loss_" + k] = torch.flatten(buf)
             if filled:
                 result["loss_" + k] = torch.mean(buf)
             else:
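
(A sketch of how the histogram support above is fed, not part of the patch; the LossAccumulator constructor arguments and buffer size are assumed, as they are not shown here. Metric names containing "_histogram" buffer one flattened vector per step, such as the per-image means PixLoss now reports, while scalar losses keep the original path.)

import torch
from utils.loss_accumulator import LossAccumulator

acc = LossAccumulator()                                    # default buffer size assumed
acc.add_loss("fake_pix_mean_histogram", torch.rand(16))    # one value per image in the batch
acc.add_loss("l_pix", torch.tensor(0.02))                  # ordinary scalar loss, unchanged path
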