Enable skip-through connections from the generator to the discriminator

This commit is contained in:
James Betker 2020-04-30 11:30:11 -06:00
parent bf634fc9fa
commit b7857f35c3
4 changed files with 72 additions and 25 deletions

View File

@ -150,9 +150,16 @@ class SRGANModel(BaseModel):
for p in self.netD.parameters():
p.requires_grad = False
disc_passthrough = None
if step > self.D_init_iters:
self.optimizer_G.zero_grad()
self.fake_H = self.netG(self.var_L)
genOut = self.netG(self.var_L)
if type(genOut) is tuple:
self.fake_H = genOut[0]
disc_passthrough = genOut[1]
else:
self.fake_H = genOut
disc_passthrough = None
else:
self.fake_H = self.pix
@ -179,12 +186,14 @@ class SRGANModel(BaseModel):
if step % self.l_fea_w_decay_steps == 0:
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
if self.opt['train']['gan_type'] == 'gan':
if disc_passthrough is not None:
pred_g_fake = self.netD(self.fake_H, disc_passthrough)
else:
pred_g_fake = self.netD(self.fake_H)
if self.opt['train']['gan_type'] == 'gan':
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
elif self.opt['train']['gan_type'] == 'ragan':
pred_d_real = self.netD(self.var_ref).detach()
pred_g_fake = self.netD(self.fake_H)
l_g_gan = self.l_gan_w * (
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
@ -199,15 +208,21 @@ class SRGANModel(BaseModel):
p.requires_grad = True
self.optimizer_D.zero_grad()
if disc_passthrough is not None:
dp = {}
for k, v in disc_passthrough.items():
dp[k] = v.detach()
pred_d_fake = self.netD(self.fake_H.detach(), dp)
else:
pred_d_fake = self.netD(self.fake_H.detach())
if self.opt['train']['gan_type'] == 'gan':
# need to forward and backward separately, since batch norm statistics differ
# real
# real
pred_d_real = self.netD(self.var_ref)
l_d_real = self.cri_gan(pred_d_real, True)
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward()
# fake
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
l_d_fake = self.cri_gan(pred_d_fake, False)
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward()
@ -218,12 +233,10 @@ class SRGANModel(BaseModel):
# l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
# l_d_total = (l_d_real + l_d_fake) / 2
# l_d_total.backward()
pred_d_fake = self.netD(self.fake_H.detach()).detach()
pred_d_real = self.netD(self.var_ref)
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake.detach()), True) * 0.5
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward()
pred_d_fake = self.netD(self.fake_H.detach())
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward()
@ -245,7 +258,11 @@ class SRGANModel(BaseModel):
def test(self):
self.netG.eval()
with torch.no_grad():
self.fake_H = self.netG(self.var_L)
genOut = self.netG(self.var_L)
if type(genOut) is tuple:
self.fake_H = genOut[0]
else:
self.fake_H = genOut
self.netG.train()
def get_current_log(self):

View File

@ -23,10 +23,12 @@ class ReduceAnnealer(nn.Module):
self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
arch_util.initialize_weights([self.reducer, self.annealer], .1)
self.bn_reduce = nn.BatchNorm2d(number_filters*4, affine=True)
self.bn_anneal = nn.BatchNorm2d(number_filters*4, affine=True)
def forward(self, x, interpolated_trunk):
out = self.lrelu(self.reducer(x))
out = self.lrelu(self.res_trunk(out))
out = self.lrelu(self.bn_reduce(self.reducer(x)))
out = self.lrelu(self.bn_anneal(self.res_trunk(out)))
annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
return annealed, out
@ -41,11 +43,13 @@ class Assembler(nn.Module):
self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.bn = nn.BatchNorm2d(number_filters*4, affine=True)
self.bn_up = nn.BatchNorm2d(number_filters*4, affine=True)
def forward(self, input, skip_raw):
out = self.pixel_shuffle(input)
out = self.upsampler(out) + skip_raw
out = self.lrelu(self.res_trunk(out))
out = self.bn_up(self.upsampler(out)) + skip_raw
out = self.lrelu(self.bn(self.res_trunk(out)))
return out
class FlatProcessorNet(nn.Module):
@ -80,10 +84,15 @@ class FlatProcessorNet(nn.Module):
# Produce assemblers for all possible downscale variants. Some may not be used.
self.assembler1 = Assembler(nf, assembler_blocks)
self.assemble1_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
self.assembler2 = Assembler(nf, assembler_blocks)
self.assemble2_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
self.assembler3 = Assembler(nf, assembler_blocks)
self.assemble3_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
self.assembler4 = Assembler(nf, assembler_blocks)
self.assemble4_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
self.assemblers = [self.assembler1, self.assembler2, self.assembler3, self.assembler4]
self.assemble_convs = [self.assemble1_conv, self.assemble2_conv, self.assemble3_conv, self.assemble4_conv]
# Initialization
arch_util.initialize_weights([self.conv_first, self.conv_last], .1)
@ -104,8 +113,10 @@ class FlatProcessorNet(nn.Module):
raw_values.append(raw)
i = -1
scaled_outputs = {}
out = raw_values[-1]
while downsamples != self.downscale:
scaled_outputs[int(x.shape[-1] / downsamples)] = self.assemble_convs[i](out)
out = self.assemblers[i](out, raw_values[i-1])
i -= 1
downsamples = int(downsamples / 2)
@ -115,4 +126,4 @@ class FlatProcessorNet(nn.Module):
basis = x
if downsamples != 1:
basis = F.interpolate(x, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
return basis + out
return basis + out, scaled_outputs

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
class Discriminator_VGG_128(nn.Module):
@ -11,11 +12,17 @@ class Discriminator_VGG_128(nn.Module):
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
self.skipconv0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
# [64, 64, 64]
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
self.skipconv1 = nn.Conv2d(3, nf*2, 3, 1, 1, bias=True)
# [128, 32, 32]
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
@ -38,13 +45,22 @@ class Discriminator_VGG_128(nn.Module):
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
def forward(self, x, gen_skips=None):
x_dim = x.size(-1)
if gen_skips is None:
gen_skips = {
int(x_dim/2): F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False),
int(x_dim/4): F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False),
}
fea = self.lrelu(self.conv0_0(x))
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
fea = (fea + self.skipconv0(gen_skips[x_dim/2])) / 2
fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
fea = (fea + self.skipconv1(gen_skips[x_dim/4])) / 2
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))

View File

@ -16,8 +16,8 @@ datasets:
dataroot_LQ: E:\\4k6k\\datasets\\ultra_lowq\\for_training
mismatched_Data_OK: true
use_shuffle: true
n_workers: 0 # per GPU
batch_size: 16
n_workers: 8 # per GPU
batch_size: 32
target_size: 64
use_flip: false
use_rot: false
@ -34,14 +34,17 @@ network_G:
which_model_G: FlatProcessorNet
in_nc: 3
out_nc: 3
nf: 48
ra_blocks: 4
assembler_blocks: 3
nf: 32
ra_blocks: 6
assembler_blocks: 4
network_D:
which_model_D: discriminator_resnet
which_model_D: discriminator_vgg_128
in_nc: 3
nf: 64
#which_model_D: discriminator_resnet
#in_nc: 3
#nf: 32
#### path
path:
@ -56,7 +59,7 @@ train:
weight_decay_G: 0
beta1_G: 0.9
beta2_G: 0.99
lr_D: !!float 1e-4
lr_D: !!float 2e-4
weight_decay_D: 0
beta1_D: 0.9
beta2_D: 0.99
@ -71,11 +74,11 @@ train:
pixel_weight: !!float 1e-2
feature_criterion: l1
feature_weight: 0
gan_type: ragan # gan | ragan
gan_type: gan # gan | ragan
gan_weight: !!float 1e-1
D_update_ratio: 2
D_init_iters: 1200
D_update_ratio: 1
D_init_iters: -1
manual_seed: 10
val_freq: !!float 5e2