Corrupted features in srgan
This commit is contained in:
parent
dffc15184d
commit
7713cb8df5
|
@ -20,6 +20,48 @@ import os
|
|||
logger = logging.getLogger('base')
|
||||
|
||||
|
||||
class GaussianBlur(nn.Module):
|
||||
def __init__(self):
|
||||
super(GaussianBlur, self).__init__()
|
||||
|
||||
# Set these to whatever you want for your gaussian filter
|
||||
kernel_size = 3
|
||||
sigma = 2
|
||||
|
||||
# Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
|
||||
x_cord = torch.arange(kernel_size)
|
||||
x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
|
||||
y_grid = x_grid.t()
|
||||
xy_grid = torch.stack([x_grid, y_grid], dim=-1)
|
||||
|
||||
mean = (kernel_size - 1) / 2.
|
||||
variance = sigma ** 2.
|
||||
|
||||
# Calculate the 2-dimensional gaussian kernel which is
|
||||
# the product of two gaussian distributions for two different
|
||||
# variables (in this case called x and y)
|
||||
gaussian_kernel = (1. / (2. * 3.1415926 * variance)) * \
|
||||
torch.exp(
|
||||
-torch.sum((xy_grid - mean) ** 2., dim=-1) / \
|
||||
(2 * variance)
|
||||
)
|
||||
# Make sure sum of values in gaussian kernel equals 1.
|
||||
gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
|
||||
|
||||
# Reshape to 2d depthwise convolutional weight
|
||||
gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
|
||||
gaussian_kernel = gaussian_kernel.repeat(3, 1, 1, 1)
|
||||
|
||||
self.gaussian_filter = nn.Conv2d(in_channels=3, out_channels=3,
|
||||
kernel_size=kernel_size, groups=3, bias=False)
|
||||
|
||||
self.gaussian_filter.weight.data = gaussian_kernel
|
||||
self.gaussian_filter.weight.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
return self.gaussian_filter(x)
|
||||
|
||||
|
||||
class SRGANModel(BaseModel):
|
||||
def __init__(self, opt):
|
||||
super(SRGANModel, self).__init__(opt)
|
||||
|
@ -132,6 +174,9 @@ class SRGANModel(BaseModel):
|
|||
logger.info('Remove feature loss.')
|
||||
self.cri_fea = None
|
||||
if self.cri_fea: # load VGG perceptual loss
|
||||
self.use_corrupted_feature_input = train_opt['corrupted_feature_input'] if 'corrupted_feature_input' in train_opt.keys() else False
|
||||
if self.use_corrupted_feature_input:
|
||||
self.feature_corruptor = GaussianBlur()
|
||||
self.netF = networks.define_F(use_bn=False).to(self.device)
|
||||
self.lr_netF = None
|
||||
if 'lr_fea_path' in train_opt.keys():
|
||||
|
@ -385,7 +430,6 @@ class SRGANModel(BaseModel):
|
|||
print("Misc setup %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
if step >= self.init_iters:
|
||||
self.optimizer_G.zero_grad()
|
||||
self.fake_GenOut = []
|
||||
self.fea_GenOut = []
|
||||
|
@ -451,8 +495,14 @@ class SRGANModel(BaseModel):
|
|||
if self.cri_fea and not using_gan_img and fea_w > 0: # feature loss
|
||||
if self.lr_netF is not None:
|
||||
real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale'])
|
||||
elif self.use_corrupted_feature_input:
|
||||
cor_Pix = F.interpolate(self.feature_corruptor(pix), size=var_L.shape[2:])
|
||||
real_fea = self.netF(cor_Pix).detach()
|
||||
else:
|
||||
real_fea = self.netF(pix).detach()
|
||||
if self.use_corrupted_feature_input:
|
||||
fake_fea = self.netF(F.interpolate(self.feature_corruptor(fea_GenOut), size=var_L.shape[2:]))
|
||||
else:
|
||||
fake_fea = self.netF(fea_GenOut)
|
||||
l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea)
|
||||
l_g_fea_log = l_g_fea / fea_w
|
||||
|
|
Loading…
Reference in New Issue
Block a user