From b7857f35c34f9f2f7c176c2a1f3666554b5b13f9 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 30 Apr 2020 11:30:11 -0600
Subject: [PATCH] Enable skip-through connections from disc to gen

---
 codes/models/SRGAN_model.py                  | 35 ++++++++++++++-----
 codes/models/archs/FlatProcessorNet_arch.py  | 21 ++++++++---
 codes/models/archs/discriminator_vgg_arch.py | 18 +++++++++-
 .../train/train_GAN_blacked_corrupt.yml      | 23 ++++++------
 4 files changed, 72 insertions(+), 25 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index ed9129c1..6b6c2855 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -150,9 +150,16 @@ class SRGANModel(BaseModel):
         for p in self.netD.parameters():
             p.requires_grad = False
 
+        disc_passthrough = None
         if step > self.D_init_iters:
             self.optimizer_G.zero_grad()
-            self.fake_H = self.netG(self.var_L)
+            genOut = self.netG(self.var_L)
+            if type(genOut) is tuple:
+                self.fake_H = genOut[0]
+                disc_passthrough = genOut[1]
+            else:
+                self.fake_H = genOut
+                disc_passthrough = None
         else:
             self.fake_H = self.pix
 
@@ -179,12 +186,14 @@ class SRGANModel(BaseModel):
                     if step % self.l_fea_w_decay_steps == 0:
                         self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
 
-            if self.opt['train']['gan_type'] == 'gan':
+            if disc_passthrough is not None:
+                pred_g_fake = self.netD(self.fake_H, disc_passthrough)
+            else:
                 pred_g_fake = self.netD(self.fake_H)
+            if self.opt['train']['gan_type'] == 'gan':
                 l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
             elif self.opt['train']['gan_type'] == 'ragan':
                 pred_d_real = self.netD(self.var_ref).detach()
-                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * (
                     self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
@@ -199,15 +208,21 @@ class SRGANModel(BaseModel):
             p.requires_grad = True
 
         self.optimizer_D.zero_grad()
+        if disc_passthrough is not None:
+            dp = {}
+            for k, v in disc_passthrough.items():
+                dp[k] = v.detach()
+            pred_d_fake = self.netD(self.fake_H.detach(), dp)
+        else:
+            pred_d_fake = self.netD(self.fake_H.detach())
         if self.opt['train']['gan_type'] == 'gan':
             # need to forward and backward separately, since batch norm statistics differ
             # real
             pred_d_real = self.netD(self.var_ref)
             l_d_real = self.cri_gan(pred_d_real, True)
             with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                 l_d_real_scaled.backward()
             # fake
-            pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
             l_d_fake = self.cri_gan(pred_d_fake, False)
             with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                 l_d_fake_scaled.backward()
@@ -218,12 +233,10 @@ class SRGANModel(BaseModel):
             # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
             # l_d_total = (l_d_real + l_d_fake) / 2
             # l_d_total.backward()
-            pred_d_fake = self.netD(self.fake_H.detach()).detach()
             pred_d_real = self.netD(self.var_ref)
-            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
+            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake.detach()), True) * 0.5
             with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                 l_d_real_scaled.backward()
-            pred_d_fake = self.netD(self.fake_H.detach())
             l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
             with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                 l_d_fake_scaled.backward()
@@ -245,7 +258,11 @@ class SRGANModel(BaseModel):
     def test(self):
         self.netG.eval()
         with torch.no_grad():
-            self.fake_H = self.netG(self.var_L)
+            genOut = self.netG(self.var_L)
+            if type(genOut) is tuple:
+                self.fake_H = genOut[0]
+            else:
+                self.fake_H = genOut
         self.netG.train()
 
     def get_current_log(self):
diff --git a/codes/models/archs/FlatProcessorNet_arch.py b/codes/models/archs/FlatProcessorNet_arch.py
index 2ce1b978..374b221b 100644
--- a/codes/models/archs/FlatProcessorNet_arch.py
+++ b/codes/models/archs/FlatProcessorNet_arch.py
@@ -23,10 +23,12 @@ class ReduceAnnealer(nn.Module):
         self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
         arch_util.initialize_weights([self.reducer, self.annealer], .1)
+        self.bn_reduce = nn.BatchNorm2d(number_filters*4, affine=True)
+        self.bn_anneal = nn.BatchNorm2d(number_filters*4, affine=True)
 
     def forward(self, x, interpolated_trunk):
-        out = self.lrelu(self.reducer(x))
-        out = self.lrelu(self.res_trunk(out))
+        out = self.lrelu(self.bn_reduce(self.reducer(x)))
+        out = self.lrelu(self.bn_anneal(self.res_trunk(out)))
         annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
         return annealed, out
 
@@ -41,11 +43,13 @@ class Assembler(nn.Module):
         self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
         self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.bn = nn.BatchNorm2d(number_filters*4, affine=True)
+        self.bn_up = nn.BatchNorm2d(number_filters*4, affine=True)
 
     def forward(self, input, skip_raw):
         out = self.pixel_shuffle(input)
-        out = self.upsampler(out) + skip_raw
-        out = self.lrelu(self.res_trunk(out))
+        out = self.bn_up(self.upsampler(out)) + skip_raw
+        out = self.lrelu(self.bn(self.res_trunk(out)))
         return out
 
 class FlatProcessorNet(nn.Module):
@@ -80,10 +84,15 @@ class FlatProcessorNet(nn.Module):
 
         # Produce assemblers for all possible downscale variants. Some may not be used.
         self.assembler1 = Assembler(nf, assembler_blocks)
+        self.assemble1_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assembler2 = Assembler(nf, assembler_blocks)
+        self.assemble2_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assembler3 = Assembler(nf, assembler_blocks)
+        self.assemble3_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assembler4 = Assembler(nf, assembler_blocks)
+        self.assemble4_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assemblers = [self.assembler1, self.assembler2, self.assembler3, self.assembler4]
+        self.assemble_convs = [self.assemble1_conv, self.assemble2_conv, self.assemble3_conv, self.assemble4_conv]
 
         # Initialization
         arch_util.initialize_weights([self.conv_first, self.conv_last], .1)
@@ -104,8 +113,10 @@ class FlatProcessorNet(nn.Module):
             raw_values.append(raw)
 
         i = -1
+        scaled_outputs = {}
         out = raw_values[-1]
         while downsamples != self.downscale:
+            scaled_outputs[int(x.shape[-1] / downsamples)] = self.assemble_convs[i](out)
             out = self.assemblers[i](out, raw_values[i-1])
             i -= 1
             downsamples = int(downsamples / 2)
@@ -115,4 +126,4 @@ class FlatProcessorNet(nn.Module):
         basis = x
         if downsamples != 1:
             basis = F.interpolate(x, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
-        return basis + out
+        return basis + out, scaled_outputs
diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index 10a3ccdc..42db4f6b 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torchvision
+import torch.nn.functional as F
 
 
 class Discriminator_VGG_128(nn.Module):
@@ -11,11 +12,17 @@ class Discriminator_VGG_128(nn.Module):
         self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
         self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
         self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
+
+        self.skipconv0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
+
         # [64, 64, 64]
         self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
         self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
         self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
         self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
+
+        self.skipconv1 = nn.Conv2d(3, nf*2, 3, 1, 1, bias=True)
+
         # [128, 32, 32]
         self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
         self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
@@ -38,13 +45,22 @@ class Discriminator_VGG_128(nn.Module):
         # activation function
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
 
-    def forward(self, x):
+    def forward(self, x, gen_skips=None):
+        x_dim = x.size(-1)
+        if gen_skips is None:
+            gen_skips = {
+                int(x_dim/2): F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False),
+                int(x_dim/4): F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False),
+            }
+
         fea = self.lrelu(self.conv0_0(x))
         fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
+        fea = (fea + self.skipconv0(gen_skips[x_dim/2])) / 2
 
         fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
         fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
+        fea = (fea + self.skipconv1(gen_skips[x_dim/4])) / 2
 
         fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
         fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
 
diff --git a/codes/options/train/train_GAN_blacked_corrupt.yml b/codes/options/train/train_GAN_blacked_corrupt.yml
index f2f0edd6..c884ab70 100644
--- a/codes/options/train/train_GAN_blacked_corrupt.yml
+++ b/codes/options/train/train_GAN_blacked_corrupt.yml
@@ -16,8 +16,8 @@ datasets:
     dataroot_LQ: E:\\4k6k\\datasets\\ultra_lowq\\for_training
     mismatched_Data_OK: true
     use_shuffle: true
-    n_workers: 0  # per GPU
-    batch_size: 16
+    n_workers: 8  # per GPU
+    batch_size: 32
     target_size: 64
     use_flip: false
     use_rot: false
@@ -34,14 +34,17 @@ network_G:
   which_model_G: FlatProcessorNet
   in_nc: 3
   out_nc: 3
-  nf: 48
-  ra_blocks: 4
-  assembler_blocks: 3
+  nf: 32
+  ra_blocks: 6
+  assembler_blocks: 4
 
 network_D:
-  which_model_D: discriminator_resnet
+  which_model_D: discriminator_vgg_128
   in_nc: 3
   nf: 64
+  #which_model_D: discriminator_resnet
+  #in_nc: 3
+  #nf: 32
 
 #### path
 path:
@@ -56,7 +59,7 @@ train:
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 1e-4
+  lr_D: !!float 2e-4
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99
@@ -71,11 +74,11 @@ train:
   pixel_weight: !!float 1e-2
   feature_criterion: l1
   feature_weight: 0
-  gan_type: ragan  # gan | ragan
+  gan_type: gan  # gan | ragan
  gan_weight: !!float 1e-1
 
-  D_update_ratio: 2
-  D_init_iters: 1200
+  D_update_ratio: 1
+  D_init_iters: -1
 
   manual_seed: 10
   val_freq: !!float 5e2
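
A rough usage sketch of the new discriminator interface follows (not part of the patch). It assumes the standard Discriminator_VGG_128(in_nc, nf) constructor, a repo-root import path, and a 128x128 discriminator input; the random tensors stand in for FlatProcessorNet's new scaled_outputs dict, which is keyed by spatial resolution.

# Rough sketch, not part of the patch; constructor and import path assumed.
import torch
from codes.models.archs.discriminator_vgg_arch import Discriminator_VGG_128

netD = Discriminator_VGG_128(in_nc=3, nf=64)
x = torch.randn(4, 3, 128, 128)

# Without generator skips, the discriminator falls back to bilinear downsamples of x.
pred_plain = netD(x)

# With generator skips: a dict keyed by spatial size (half and quarter of the input),
# mixed into the conv features through skipconv0/skipconv1.
gen_skips = {64: torch.randn(4, 3, 64, 64), 32: torch.randn(4, 3, 32, 32)}
pred_skip = netD(x, gen_skips)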