Corrupted features in srgan

This commit is contained in:
James Betker 2020-08-23 17:32:03 -06:00
parent dffc15184d
commit 7713cb8df5

View File

@ -20,6 +20,48 @@ import os
logger = logging.getLogger('base') logger = logging.getLogger('base')
class GaussianBlur(nn.Module):
def __init__(self):
super(GaussianBlur, self).__init__()
# Set these to whatever you want for your gaussian filter
kernel_size = 3
sigma = 2
# Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
x_cord = torch.arange(kernel_size)
x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
y_grid = x_grid.t()
xy_grid = torch.stack([x_grid, y_grid], dim=-1)
mean = (kernel_size - 1) / 2.
variance = sigma ** 2.
# Calculate the 2-dimensional gaussian kernel which is
# the product of two gaussian distributions for two different
# variables (in this case called x and y)
gaussian_kernel = (1. / (2. * 3.1415926 * variance)) * \
torch.exp(
-torch.sum((xy_grid - mean) ** 2., dim=-1) / \
(2 * variance)
)
# Make sure sum of values in gaussian kernel equals 1.
gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
# Reshape to 2d depthwise convolutional weight
gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
gaussian_kernel = gaussian_kernel.repeat(3, 1, 1, 1)
self.gaussian_filter = nn.Conv2d(in_channels=3, out_channels=3,
kernel_size=kernel_size, groups=3, bias=False)
self.gaussian_filter.weight.data = gaussian_kernel
self.gaussian_filter.weight.requires_grad = False
def forward(self, x):
return self.gaussian_filter(x)
class SRGANModel(BaseModel): class SRGANModel(BaseModel):
def __init__(self, opt): def __init__(self, opt):
super(SRGANModel, self).__init__(opt) super(SRGANModel, self).__init__(opt)
@ -132,6 +174,9 @@ class SRGANModel(BaseModel):
logger.info('Remove feature loss.') logger.info('Remove feature loss.')
self.cri_fea = None self.cri_fea = None
if self.cri_fea: # load VGG perceptual loss if self.cri_fea: # load VGG perceptual loss
self.use_corrupted_feature_input = train_opt['corrupted_feature_input'] if 'corrupted_feature_input' in train_opt.keys() else False
if self.use_corrupted_feature_input:
self.feature_corruptor = GaussianBlur()
self.netF = networks.define_F(use_bn=False).to(self.device) self.netF = networks.define_F(use_bn=False).to(self.device)
self.lr_netF = None self.lr_netF = None
if 'lr_fea_path' in train_opt.keys(): if 'lr_fea_path' in train_opt.keys():
@ -385,7 +430,6 @@ class SRGANModel(BaseModel):
print("Misc setup %f" % (time() - _t,)) print("Misc setup %f" % (time() - _t,))
_t = time() _t = time()
if step >= self.init_iters:
self.optimizer_G.zero_grad() self.optimizer_G.zero_grad()
self.fake_GenOut = [] self.fake_GenOut = []
self.fea_GenOut = [] self.fea_GenOut = []
@ -451,8 +495,14 @@ class SRGANModel(BaseModel):
if self.cri_fea and not using_gan_img and fea_w > 0: # feature loss if self.cri_fea and not using_gan_img and fea_w > 0: # feature loss
if self.lr_netF is not None: if self.lr_netF is not None:
real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale']) real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale'])
elif self.use_corrupted_feature_input:
cor_Pix = F.interpolate(self.feature_corruptor(pix), size=var_L.shape[2:])
real_fea = self.netF(cor_Pix).detach()
else: else:
real_fea = self.netF(pix).detach() real_fea = self.netF(pix).detach()
if self.use_corrupted_feature_input:
fake_fea = self.netF(F.interpolate(self.feature_corruptor(fea_GenOut), size=var_L.shape[2:]))
else:
fake_fea = self.netF(fea_GenOut) fake_fea = self.netF(fea_GenOut)
l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea) l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea)
l_g_fea_log = l_g_fea / fea_w l_g_fea_log = l_g_fea / fea_w