Add SrPixLoss, which focuses pixel-based losses on high-frequency regions
of the image.
This commit is contained in:
parent
2cdac6bd09
commit
97d895aebe
|
@ -7,6 +7,8 @@ import random
|
|||
import functools
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
def create_loss(opt_loss, env):
|
||||
type = opt_loss['type']
|
||||
|
@ -23,6 +25,8 @@ def create_loss(opt_loss, env):
|
|||
return CrossEntropy(opt_loss, env)
|
||||
elif type == 'pix':
|
||||
return PixLoss(opt_loss, env)
|
||||
elif type == 'sr_pix':
|
||||
return SrPixLoss(opt_loss, env)
|
||||
elif type == 'direct':
|
||||
return DirectLoss(opt_loss, env)
|
||||
elif type == 'feature':
|
||||
|
@ -143,6 +147,29 @@ class PixLoss(ConfigurableLoss):
|
|||
return self.criterion(fake.float(), real.float())
|
||||
|
||||
|
||||
class SrPixLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.base_loss = opt_get(opt, ['base_loss'], .2)
|
||||
self.exp = opt_get(opt, ['exp'], 2)
|
||||
self.scale = opt['scale']
|
||||
|
||||
def forward(self, _, state):
|
||||
real = state[self.opt['real']]
|
||||
fake = state[self.opt['fake']]
|
||||
l2 = (fake - real) ** 2
|
||||
self.metrics.append(("l2_loss", l2.mean()))
|
||||
# Adjust loss by prioritizing reconstruction of HF details.
|
||||
no_hf = F.interpolate(F.interpolate(real, scale_factor=1/self.scale, mode="area"), scale_factor=self.scale, mode="nearest")
|
||||
weights = (torch.abs(real - no_hf) + self.base_loss) ** self.exp
|
||||
weights = weights / weights.mean()
|
||||
loss = l2*weights
|
||||
# Preserve the intensity of the loss, just adjust the weighting.
|
||||
loss = loss*l2.mean()/loss.mean()
|
||||
return loss.mean()
|
||||
|
||||
|
||||
# Loss defined by averaging the input tensor across all dimensions an optionally inverting it.
|
||||
class DirectLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
|
|
Loading…
Reference in New Issue
Block a user