From 7713cb8df5fb2ee6137ad83202622063a013bd9b Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 23 Aug 2020 17:32:03 -0600 Subject: [PATCH] Corrupted features in srgan --- codes/models/SRGAN_model.py | 56 +++++++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 406b8850..58718f98 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -20,6 +20,48 @@ import os logger = logging.getLogger('base') +class GaussianBlur(nn.Module): + def __init__(self): + super(GaussianBlur, self).__init__() + + # Set these to whatever you want for your gaussian filter + kernel_size = 3 + sigma = 2 + + # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) + x_cord = torch.arange(kernel_size) + x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size) + y_grid = x_grid.t() + xy_grid = torch.stack([x_grid, y_grid], dim=-1) + + mean = (kernel_size - 1) / 2. + variance = sigma ** 2. + + # Calculate the 2-dimensional gaussian kernel which is + # the product of two gaussian distributions for two different + # variables (in this case called x and y) + gaussian_kernel = (1. / (2. * 3.1415926 * variance)) * \ + torch.exp( + -torch.sum((xy_grid - mean) ** 2., dim=-1) / \ + (2 * variance) + ) + # Make sure sum of values in gaussian kernel equals 1. + gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) + + # Reshape to 2d depthwise convolutional weight + gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) + gaussian_kernel = gaussian_kernel.repeat(3, 1, 1, 1) + + self.gaussian_filter = nn.Conv2d(in_channels=3, out_channels=3, + kernel_size=kernel_size, groups=3, bias=False) + + self.gaussian_filter.weight.data = gaussian_kernel + self.gaussian_filter.weight.requires_grad = False + + def forward(self, x): + return self.gaussian_filter(x) + + class SRGANModel(BaseModel): def __init__(self, opt): super(SRGANModel, self).__init__(opt) @@ -132,6 +174,9 @@ class SRGANModel(BaseModel): logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss + self.use_corrupted_feature_input = train_opt['corrupted_feature_input'] if 'corrupted_feature_input' in train_opt.keys() else False + if self.use_corrupted_feature_input: + self.feature_corruptor = GaussianBlur() self.netF = networks.define_F(use_bn=False).to(self.device) self.lr_netF = None if 'lr_fea_path' in train_opt.keys(): @@ -385,8 +430,7 @@ class SRGANModel(BaseModel): print("Misc setup %f" % (time() - _t,)) _t = time() - if step >= self.init_iters: - self.optimizer_G.zero_grad() + self.optimizer_G.zero_grad() self.fake_GenOut = [] self.fea_GenOut = [] self.fake_H = [] @@ -451,9 +495,15 @@ class SRGANModel(BaseModel): if self.cri_fea and not using_gan_img and fea_w > 0: # feature loss if self.lr_netF is not None: real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale']) + elif self.use_corrupted_feature_input: + cor_Pix = F.interpolate(self.feature_corruptor(pix), size=var_L.shape[2:]) + real_fea = self.netF(cor_Pix).detach() else: real_fea = self.netF(pix).detach() - fake_fea = self.netF(fea_GenOut) + if self.use_corrupted_feature_input: + fake_fea = self.netF(F.interpolate(self.feature_corruptor(fea_GenOut), size=var_L.shape[2:])) + else: + fake_fea = self.netF(fea_GenOut) l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea) l_g_fea_log = l_g_fea / fea_w l_g_total += l_g_fea