From 66e91a3d9e313036994288ae22c83124d68ac778 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 30 Apr 2020 11:45:07 -0600
Subject: [PATCH] Revert "Enable skip-through connections from disc to gen"

This reverts commit b7857f35c34f9f2f7c176c2a1f3666554b5b13f9.
---
 codes/models/SRGAN_model.py                  | 35 +++++--------------
 codes/models/archs/FlatProcessorNet_arch.py  | 21 +++-------
 codes/models/archs/discriminator_vgg_arch.py | 18 +---------
 .../train/train_GAN_blacked_corrupt.yml      | 23 ++++++------
 4 files changed, 25 insertions(+), 72 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 6b6c2855..ed9129c1 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -150,16 +150,9 @@ class SRGANModel(BaseModel):
         for p in self.netD.parameters():
             p.requires_grad = False
 
-        disc_passthrough = None
         if step > self.D_init_iters:
             self.optimizer_G.zero_grad()
-            genOut = self.netG(self.var_L)
-            if type(genOut) is tuple:
-                self.fake_H = genOut[0]
-                disc_passthrough = genOut[1]
-            else:
-                self.fake_H = genOut
-                disc_passthrough = None
+            self.fake_H = self.netG(self.var_L)
         else:
             self.fake_H = self.pix
 
@@ -186,14 +179,12 @@ class SRGANModel(BaseModel):
                 if step % self.l_fea_w_decay_steps == 0:
                     self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
 
-            if disc_passthrough is not None:
-                pred_g_fake = self.netD(self.fake_H, disc_passthrough)
-            else:
-                pred_g_fake = self.netD(self.fake_H)
             if self.opt['train']['gan_type'] == 'gan':
+                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
             elif self.opt['train']['gan_type'] == 'ragan':
                 pred_d_real = self.netD(self.var_ref).detach()
+                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * (
                     self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
@@ -208,21 +199,15 @@ class SRGANModel(BaseModel):
             p.requires_grad = True
 
         self.optimizer_D.zero_grad()
-        if disc_passthrough is not None:
-            dp = {}
-            for k, v in disc_passthrough.items():
-                dp[k] = v.detach()
-            pred_d_fake = self.netD(self.fake_H.detach(), dp)
-        else:
-            pred_d_fake = self.netD(self.fake_H.detach())
        if self.opt['train']['gan_type'] == 'gan':
             # need to forward and backward separately, since batch norm statistics differ
-            # reald
+            # real
             pred_d_real = self.netD(self.var_ref)
             l_d_real = self.cri_gan(pred_d_real, True)
             with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                 l_d_real_scaled.backward()
             # fake
+            pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
             l_d_fake = self.cri_gan(pred_d_fake, False)
             with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                 l_d_fake_scaled.backward()
@@ -233,10 +218,12 @@ class SRGANModel(BaseModel):
             # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
             # l_d_total = (l_d_real + l_d_fake) / 2
             # l_d_total.backward()
+            pred_d_fake = self.netD(self.fake_H.detach()).detach()
             pred_d_real = self.netD(self.var_ref)
-            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake.detach()), True) * 0.5
+            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
             with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                 l_d_real_scaled.backward()
+            pred_d_fake = self.netD(self.fake_H.detach())
             l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
             with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                 l_d_fake_scaled.backward()
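For context on the `ragan` branch the hunk above restores: the relativistic average GAN loss scores each logit against the mean logit of the opposite class, and the discriminator runs the fake batch twice so each backward pass sees its own batch-norm statistics. A minimal sketch of the restored D update, assuming `cri_gan` is a BCE-with-logits-style criterion called as `(logits, target_is_real)` and leaving out the Apex `amp.scale_loss` wrappers used in the patch:

```python
import torch

def ragan_d_loss(netD, cri_gan, var_ref, fake_H):
    # fake_H.detach() stops gradients flowing into the generator; the extra
    # .detach() on the output keeps the real-side backward out of this graph.
    pred_d_fake = netD(fake_H.detach()).detach()
    pred_d_real = netD(var_ref)
    l_d_real = cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5

    # Second fake forward pass, as in the patch: its backward is taken
    # separately because batch-norm statistics differ between passes.
    pred_d_fake = netD(fake_H.detach())
    l_d_fake = cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
    return l_d_real, l_d_fake
```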
@@ -258,11 +245,7 @@ class SRGANModel(BaseModel):
     def test(self):
         self.netG.eval()
         with torch.no_grad():
-            genOut = self.netG(self.var_L)
-            if type(genOut) is tuple:
-                self.fake_H = genOut[0]
-            else:
-                self.fake_H = genOut
+            self.fake_H = self.netG(self.var_L)
         self.netG.train()
 
     def get_current_log(self):
diff --git a/codes/models/archs/FlatProcessorNet_arch.py b/codes/models/archs/FlatProcessorNet_arch.py
index 374b221b..2ce1b978 100644
--- a/codes/models/archs/FlatProcessorNet_arch.py
+++ b/codes/models/archs/FlatProcessorNet_arch.py
@@ -23,12 +23,10 @@ class ReduceAnnealer(nn.Module):
         self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
         arch_util.initialize_weights([self.reducer, self.annealer], .1)
-        self.bn_reduce = nn.BatchNorm2d(number_filters*4, affine=True)
-        self.bn_anneal = nn.BatchNorm2d(number_filters*4, affine=True)
 
     def forward(self, x, interpolated_trunk):
-        out = self.lrelu(self.bn_reduce(self.reducer(x)))
-        out = self.lrelu(self.bn_anneal(self.res_trunk(out)))
+        out = self.lrelu(self.reducer(x))
+        out = self.lrelu(self.res_trunk(out))
         annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
         return annealed, out
 
@@ -43,13 +41,11 @@ class Assembler(nn.Module):
         self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
         self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
-        self.bn = nn.BatchNorm2d(number_filters*4, affine=True)
-        self.bn_up = nn.BatchNorm2d(number_filters*4, affine=True)
 
     def forward(self, input, skip_raw):
         out = self.pixel_shuffle(input)
-        out = self.bn_up(self.upsampler(out)) + skip_raw
-        out = self.lrelu(self.bn(self.res_trunk(out)))
+        out = self.upsampler(out) + skip_raw
+        out = self.lrelu(self.res_trunk(out))
         return out
 
 class FlatProcessorNet(nn.Module):
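The `Assembler.forward` restored above now adds the raw skip without batch normalization: pixel-shuffle trades channels for resolution, and `upsampler` restores the channel count before the residual trunk. A self-contained shape check, with a plain-conv stand-in for the trunk since `arch_util.ResidualBlock` is not available here (an assumption, not the repo's exact block):

```python
import torch
import torch.nn as nn

nf = 48  # number_filters; matches the nf restored in the yml below
pixel_shuffle = nn.PixelShuffle(2)          # (B, nf*4, H, W) -> (B, nf, 2H, 2W)
upsampler = nn.Conv2d(nf, nf * 4, 3, stride=1, padding=1, bias=True)
trunk = nn.Sequential(                      # stand-in for the residual trunk
    nn.Conv2d(nf * 4, nf * 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.1))

x = torch.randn(1, nf * 4, 16, 16)            # features entering the Assembler
skip_raw = torch.randn(1, nf * 4, 32, 32)     # raw tensor from a ReduceAnnealer
out = upsampler(pixel_shuffle(x)) + skip_raw  # skip added directly, no BN
out = trunk(out)
print(out.shape)  # torch.Size([1, 192, 32, 32])
```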
@@ -84,15 +80,10 @@ class FlatProcessorNet(nn.Module):
 
         # Produce assemblers for all possible downscale variants. Some may not be used.
         self.assembler1 = Assembler(nf, assembler_blocks)
-        self.assemble1_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assembler2 = Assembler(nf, assembler_blocks)
-        self.assemble2_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assembler3 = Assembler(nf, assembler_blocks)
-        self.assemble3_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assembler4 = Assembler(nf, assembler_blocks)
-        self.assemble4_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assemblers = [self.assembler1, self.assembler2, self.assembler3, self.assembler4]
-        self.assemble_convs = [self.assemble1_conv, self.assemble2_conv, self.assemble3_conv, self.assemble4_conv]
 
         # Initialization
         arch_util.initialize_weights([self.conv_first, self.conv_last], .1)
@@ -113,10 +104,8 @@ class FlatProcessorNet(nn.Module):
             raw_values.append(raw)
 
         i = -1
-        scaled_outputs = {}
         out = raw_values[-1]
         while downsamples != self.downscale:
-            scaled_outputs[int(x.shape[-1] / downsamples)] = self.assemble_convs[i](out)
             out = self.assemblers[i](out, raw_values[i-1])
             i -= 1
             downsamples = int(downsamples / 2)
@@ -126,4 +115,4 @@ class FlatProcessorNet(nn.Module):
         basis = x
         if downsamples != 1:
             basis = F.interpolate(x, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
-        return basis + out, scaled_outputs
+        return basis + out
diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index 42db4f6b..10a3ccdc 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 import torchvision
-import torch.nn.functional as F
 
 
 class Discriminator_VGG_128(nn.Module):
@@ -12,17 +11,11 @@ class Discriminator_VGG_128(nn.Module):
         self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
         self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
         self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
-
-        self.skipconv0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
-
         # [64, 64, 64]
         self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
         self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
         self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
         self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
-
-        self.skipconv1 = nn.Conv2d(3, nf*2, 3, 1, 1, bias=True)
-
         # [128, 32, 32]
         self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
         self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
@@ -45,22 +38,13 @@ class Discriminator_VGG_128(nn.Module):
         # activation function
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
 
-    def forward(self, x, gen_skips=None):
-        x_dim = x.size(-1)
-        if gen_skips is None:
-            gen_skips = {
-                int(x_dim/2): F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False),
-                int(x_dim/4): F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False),
-            }
-
+    def forward(self, x):
         fea = self.lrelu(self.conv0_0(x))
         fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
-        fea = (fea + self.skipconv0(gen_skips[x_dim/2])) / 2
 
         fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
         fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
-        fea = (fea + self.skipconv1(gen_skips[x_dim/4])) / 2
 
         fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
         fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
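With the skip inputs gone, `Discriminator_VGG_128` is again a plain VGG-style stack: each stage pairs a 3x3 stride-1 conv with a 4x4 stride-2 conv, both batch-normalized, halving the spatial size. A minimal sketch of one such stage, mirroring `conv1_0`/`conv1_1` above (input shape assumed for a 128px image after the first stage):

```python
import torch
import torch.nn as nn

nf = 64
stage = nn.Sequential(
    nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False),      # conv1_0: widen channels
    nn.BatchNorm2d(nf * 2, affine=True),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),
    nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False),  # conv1_1: halve resolution
    nn.BatchNorm2d(nf * 2, affine=True),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),
)

fea = torch.randn(1, nf, 64, 64)  # features after the first stride-2 block
print(stage(fea).shape)           # torch.Size([1, 128, 32, 32])
```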
diff --git a/codes/options/train/train_GAN_blacked_corrupt.yml b/codes/options/train/train_GAN_blacked_corrupt.yml
index c884ab70..f2f0edd6 100644
--- a/codes/options/train/train_GAN_blacked_corrupt.yml
+++ b/codes/options/train/train_GAN_blacked_corrupt.yml
@@ -16,8 +16,8 @@ datasets:
     dataroot_LQ: E:\\4k6k\\datasets\\ultra_lowq\\for_training
     mismatched_Data_OK: true
     use_shuffle: true
-    n_workers: 8  # per GPU
-    batch_size: 32
+    n_workers: 0  # per GPU
+    batch_size: 16
     target_size: 64
     use_flip: false
     use_rot: false
@@ -34,17 +34,14 @@ network_G:
   which_model_G: FlatProcessorNet
   in_nc: 3
   out_nc: 3
-  nf: 32
-  ra_blocks: 6
-  assembler_blocks: 4
+  nf: 48
+  ra_blocks: 4
+  assembler_blocks: 3
 
 network_D:
-  which_model_D: discriminator_vgg_128
+  which_model_D: discriminator_resnet
   in_nc: 3
   nf: 64
-  #which_model_D: discriminator_resnet
-  #in_nc: 3
-  #nf: 32
 
 #### path
 path:
@@ -59,7 +56,7 @@ train:
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 2e-4
+  lr_D: !!float 1e-4
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99
@@ -74,11 +71,11 @@ train:
   pixel_weight: !!float 1e-2
   feature_criterion: l1
   feature_weight: 0
-  gan_type: gan  # gan | ragan
+  gan_type: ragan  # gan | ragan
   gan_weight: !!float 1e-1
 
-  D_update_ratio: 1
-  D_init_iters: -1
+  D_update_ratio: 2
+  D_init_iters: 1200
 
   manual_seed: 10
   val_freq: !!float 5e2
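The restored options tie back to the `SRGAN_model.py` hunks above: `gan_type: ragan` selects the relativistic branch, and with `D_init_iters: 1200` the generator output is replaced by `self.pix` until step 1200. A hedged sketch of reading these values with plain PyYAML (the repo presumably has its own options loader; this only shows that the `!!float` tags parse as numbers rather than strings):

```python
import yaml

with open('codes/options/train/train_GAN_blacked_corrupt.yml') as f:
    opt = yaml.safe_load(f)

# '!!float' forces numeric parsing: a bare '1e-4' would be a string in YAML 1.1.
assert isinstance(opt['train']['lr_D'], float)
print(opt['train']['gan_type'])      # 'ragan'
print(opt['train']['D_init_iters'])  # 1200
```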