Add SrPixLoss, which focuses pixel-based losses on high-frequency regions
of the image.
This commit is contained in:
parent
2cdac6bd09
commit
97d895aebe
|
@ -7,6 +7,8 @@ import random
|
||||||
import functools
|
import functools
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
def create_loss(opt_loss, env):
|
def create_loss(opt_loss, env):
|
||||||
type = opt_loss['type']
|
type = opt_loss['type']
|
||||||
|
@ -23,6 +25,8 @@ def create_loss(opt_loss, env):
|
||||||
return CrossEntropy(opt_loss, env)
|
return CrossEntropy(opt_loss, env)
|
||||||
elif type == 'pix':
|
elif type == 'pix':
|
||||||
return PixLoss(opt_loss, env)
|
return PixLoss(opt_loss, env)
|
||||||
|
elif type == 'sr_pix':
|
||||||
|
return SrPixLoss(opt_loss, env)
|
||||||
elif type == 'direct':
|
elif type == 'direct':
|
||||||
return DirectLoss(opt_loss, env)
|
return DirectLoss(opt_loss, env)
|
||||||
elif type == 'feature':
|
elif type == 'feature':
|
||||||
|
@ -143,6 +147,29 @@ class PixLoss(ConfigurableLoss):
|
||||||
return self.criterion(fake.float(), real.float())
|
return self.criterion(fake.float(), real.float())
|
||||||
|
|
||||||
|
|
||||||
|
class SrPixLoss(ConfigurableLoss):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super().__init__(opt, env)
|
||||||
|
self.opt = opt
|
||||||
|
self.base_loss = opt_get(opt, ['base_loss'], .2)
|
||||||
|
self.exp = opt_get(opt, ['exp'], 2)
|
||||||
|
self.scale = opt['scale']
|
||||||
|
|
||||||
|
def forward(self, _, state):
|
||||||
|
real = state[self.opt['real']]
|
||||||
|
fake = state[self.opt['fake']]
|
||||||
|
l2 = (fake - real) ** 2
|
||||||
|
self.metrics.append(("l2_loss", l2.mean()))
|
||||||
|
# Adjust loss by prioritizing reconstruction of HF details.
|
||||||
|
no_hf = F.interpolate(F.interpolate(real, scale_factor=1/self.scale, mode="area"), scale_factor=self.scale, mode="nearest")
|
||||||
|
weights = (torch.abs(real - no_hf) + self.base_loss) ** self.exp
|
||||||
|
weights = weights / weights.mean()
|
||||||
|
loss = l2*weights
|
||||||
|
# Preserve the intensity of the loss, just adjust the weighting.
|
||||||
|
loss = loss*l2.mean()/loss.mean()
|
||||||
|
return loss.mean()
|
||||||
|
|
||||||
|
|
||||||
# Loss defined by averaging the input tensor across all dimensions an optionally inverting it.
|
# Loss defined by averaging the input tensor across all dimensions an optionally inverting it.
|
||||||
class DirectLoss(ConfigurableLoss):
|
class DirectLoss(ConfigurableLoss):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user