Add SrPixLoss, which focuses pixel-based losses on high-frequency regions

of the image.
This commit is contained in:
James Betker 2021-01-25 08:26:14 -07:00
parent 2cdac6bd09
commit 97d895aebe

View File

@ -7,6 +7,8 @@ import random
import functools import functools
import torch.nn.functional as F import torch.nn.functional as F
from utils.util import opt_get
def create_loss(opt_loss, env): def create_loss(opt_loss, env):
type = opt_loss['type'] type = opt_loss['type']
@ -23,6 +25,8 @@ def create_loss(opt_loss, env):
return CrossEntropy(opt_loss, env) return CrossEntropy(opt_loss, env)
elif type == 'pix': elif type == 'pix':
return PixLoss(opt_loss, env) return PixLoss(opt_loss, env)
elif type == 'sr_pix':
return SrPixLoss(opt_loss, env)
elif type == 'direct': elif type == 'direct':
return DirectLoss(opt_loss, env) return DirectLoss(opt_loss, env)
elif type == 'feature': elif type == 'feature':
@ -143,6 +147,29 @@ class PixLoss(ConfigurableLoss):
return self.criterion(fake.float(), real.float()) return self.criterion(fake.float(), real.float())
class SrPixLoss(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.opt = opt
self.base_loss = opt_get(opt, ['base_loss'], .2)
self.exp = opt_get(opt, ['exp'], 2)
self.scale = opt['scale']
def forward(self, _, state):
real = state[self.opt['real']]
fake = state[self.opt['fake']]
l2 = (fake - real) ** 2
self.metrics.append(("l2_loss", l2.mean()))
# Adjust loss by prioritizing reconstruction of HF details.
no_hf = F.interpolate(F.interpolate(real, scale_factor=1/self.scale, mode="area"), scale_factor=self.scale, mode="nearest")
weights = (torch.abs(real - no_hf) + self.base_loss) ** self.exp
weights = weights / weights.mean()
loss = l2*weights
# Preserve the intensity of the loss, just adjust the weighting.
loss = loss*l2.mean()/loss.mean()
return loss.mean()
# Loss defined by averaging the input tensor across all dimensions an optionally inverting it. # Loss defined by averaging the input tensor across all dimensions an optionally inverting it.
class DirectLoss(ConfigurableLoss): class DirectLoss(ConfigurableLoss):
def __init__(self, opt, env): def __init__(self, opt, env):