forked from mrq/DL-Art-School
Corrupted features in srgan
This commit is contained in:
parent
dffc15184d
commit
7713cb8df5
|
@ -20,6 +20,48 @@ import os
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianBlur(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(GaussianBlur, self).__init__()
|
||||||
|
|
||||||
|
# Set these to whatever you want for your gaussian filter
|
||||||
|
kernel_size = 3
|
||||||
|
sigma = 2
|
||||||
|
|
||||||
|
# Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
|
||||||
|
x_cord = torch.arange(kernel_size)
|
||||||
|
x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
|
||||||
|
y_grid = x_grid.t()
|
||||||
|
xy_grid = torch.stack([x_grid, y_grid], dim=-1)
|
||||||
|
|
||||||
|
mean = (kernel_size - 1) / 2.
|
||||||
|
variance = sigma ** 2.
|
||||||
|
|
||||||
|
# Calculate the 2-dimensional gaussian kernel which is
|
||||||
|
# the product of two gaussian distributions for two different
|
||||||
|
# variables (in this case called x and y)
|
||||||
|
gaussian_kernel = (1. / (2. * 3.1415926 * variance)) * \
|
||||||
|
torch.exp(
|
||||||
|
-torch.sum((xy_grid - mean) ** 2., dim=-1) / \
|
||||||
|
(2 * variance)
|
||||||
|
)
|
||||||
|
# Make sure sum of values in gaussian kernel equals 1.
|
||||||
|
gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
|
||||||
|
|
||||||
|
# Reshape to 2d depthwise convolutional weight
|
||||||
|
gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
|
||||||
|
gaussian_kernel = gaussian_kernel.repeat(3, 1, 1, 1)
|
||||||
|
|
||||||
|
self.gaussian_filter = nn.Conv2d(in_channels=3, out_channels=3,
|
||||||
|
kernel_size=kernel_size, groups=3, bias=False)
|
||||||
|
|
||||||
|
self.gaussian_filter.weight.data = gaussian_kernel
|
||||||
|
self.gaussian_filter.weight.requires_grad = False
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.gaussian_filter(x)
|
||||||
|
|
||||||
|
|
||||||
class SRGANModel(BaseModel):
|
class SRGANModel(BaseModel):
|
||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
super(SRGANModel, self).__init__(opt)
|
super(SRGANModel, self).__init__(opt)
|
||||||
|
@ -132,6 +174,9 @@ class SRGANModel(BaseModel):
|
||||||
logger.info('Remove feature loss.')
|
logger.info('Remove feature loss.')
|
||||||
self.cri_fea = None
|
self.cri_fea = None
|
||||||
if self.cri_fea: # load VGG perceptual loss
|
if self.cri_fea: # load VGG perceptual loss
|
||||||
|
self.use_corrupted_feature_input = train_opt['corrupted_feature_input'] if 'corrupted_feature_input' in train_opt.keys() else False
|
||||||
|
if self.use_corrupted_feature_input:
|
||||||
|
self.feature_corruptor = GaussianBlur()
|
||||||
self.netF = networks.define_F(use_bn=False).to(self.device)
|
self.netF = networks.define_F(use_bn=False).to(self.device)
|
||||||
self.lr_netF = None
|
self.lr_netF = None
|
||||||
if 'lr_fea_path' in train_opt.keys():
|
if 'lr_fea_path' in train_opt.keys():
|
||||||
|
@ -385,7 +430,6 @@ class SRGANModel(BaseModel):
|
||||||
print("Misc setup %f" % (time() - _t,))
|
print("Misc setup %f" % (time() - _t,))
|
||||||
_t = time()
|
_t = time()
|
||||||
|
|
||||||
if step >= self.init_iters:
|
|
||||||
self.optimizer_G.zero_grad()
|
self.optimizer_G.zero_grad()
|
||||||
self.fake_GenOut = []
|
self.fake_GenOut = []
|
||||||
self.fea_GenOut = []
|
self.fea_GenOut = []
|
||||||
|
@ -451,8 +495,14 @@ class SRGANModel(BaseModel):
|
||||||
if self.cri_fea and not using_gan_img and fea_w > 0: # feature loss
|
if self.cri_fea and not using_gan_img and fea_w > 0: # feature loss
|
||||||
if self.lr_netF is not None:
|
if self.lr_netF is not None:
|
||||||
real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale'])
|
real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale'])
|
||||||
|
elif self.use_corrupted_feature_input:
|
||||||
|
cor_Pix = F.interpolate(self.feature_corruptor(pix), size=var_L.shape[2:])
|
||||||
|
real_fea = self.netF(cor_Pix).detach()
|
||||||
else:
|
else:
|
||||||
real_fea = self.netF(pix).detach()
|
real_fea = self.netF(pix).detach()
|
||||||
|
if self.use_corrupted_feature_input:
|
||||||
|
fake_fea = self.netF(F.interpolate(self.feature_corruptor(fea_GenOut), size=var_L.shape[2:]))
|
||||||
|
else:
|
||||||
fake_fea = self.netF(fea_GenOut)
|
fake_fea = self.netF(fea_GenOut)
|
||||||
l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea)
|
l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea)
|
||||||
l_g_fea_log = l_g_fea / fea_w
|
l_g_fea_log = l_g_fea / fea_w
|
||||||
|
|
Loading…
Reference in New Issue
Block a user