From 97d895aebe38e1c614dfcd2bfed1ab5cd383ac47 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 25 Jan 2021 08:26:14 -0700 Subject: [PATCH] Add SrPixLoss, which focuses pixel-based losses on high-frequency regions of the image. --- codes/trainer/losses.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/codes/trainer/losses.py b/codes/trainer/losses.py index e1ea1c1d..c2d81147 100644 --- a/codes/trainer/losses.py +++ b/codes/trainer/losses.py @@ -7,6 +7,8 @@ import random import functools import torch.nn.functional as F +from utils.util import opt_get + def create_loss(opt_loss, env): type = opt_loss['type'] @@ -23,6 +25,8 @@ def create_loss(opt_loss, env): return CrossEntropy(opt_loss, env) elif type == 'pix': return PixLoss(opt_loss, env) + elif type == 'sr_pix': + return SrPixLoss(opt_loss, env) elif type == 'direct': return DirectLoss(opt_loss, env) elif type == 'feature': @@ -143,6 +147,29 @@ class PixLoss(ConfigurableLoss): return self.criterion(fake.float(), real.float()) +class SrPixLoss(ConfigurableLoss): + def __init__(self, opt, env): + super().__init__(opt, env) + self.opt = opt + self.base_loss = opt_get(opt, ['base_loss'], .2) + self.exp = opt_get(opt, ['exp'], 2) + self.scale = opt['scale'] + + def forward(self, _, state): + real = state[self.opt['real']] + fake = state[self.opt['fake']] + l2 = (fake - real) ** 2 + self.metrics.append(("l2_loss", l2.mean())) + # Adjust loss by prioritizing reconstruction of HF details. + no_hf = F.interpolate(F.interpolate(real, scale_factor=1/self.scale, mode="area"), scale_factor=self.scale, mode="nearest") + weights = (torch.abs(real - no_hf) + self.base_loss) ** self.exp + weights = weights / weights.mean() + loss = l2*weights + # Preserve the intensity of the loss, just adjust the weighting. + loss = loss*l2.mean()/loss.mean() + return loss.mean() + + # Loss defined by averaging the input tensor across all dimensions an optionally inverting it. class DirectLoss(ConfigurableLoss): def __init__(self, opt, env):