Enable skip-through connections from disc to gen
This commit is contained in:
parent
bf634fc9fa
commit
b7857f35c3
|
@ -150,9 +150,16 @@ class SRGANModel(BaseModel):
|
|||
for p in self.netD.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
disc_passthrough = None
|
||||
if step > self.D_init_iters:
|
||||
self.optimizer_G.zero_grad()
|
||||
self.fake_H = self.netG(self.var_L)
|
||||
genOut = self.netG(self.var_L)
|
||||
if type(genOut) is tuple:
|
||||
self.fake_H = genOut[0]
|
||||
disc_passthrough = genOut[1]
|
||||
else:
|
||||
self.fake_H = genOut
|
||||
disc_passthrough = None
|
||||
else:
|
||||
self.fake_H = self.pix
|
||||
|
||||
|
@ -179,12 +186,14 @@ class SRGANModel(BaseModel):
|
|||
if step % self.l_fea_w_decay_steps == 0:
|
||||
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
|
||||
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
if disc_passthrough is not None:
|
||||
pred_g_fake = self.netD(self.fake_H, disc_passthrough)
|
||||
else:
|
||||
pred_g_fake = self.netD(self.fake_H)
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
pred_d_real = self.netD(self.var_ref).detach()
|
||||
pred_g_fake = self.netD(self.fake_H)
|
||||
l_g_gan = self.l_gan_w * (
|
||||
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||
|
@ -199,15 +208,21 @@ class SRGANModel(BaseModel):
|
|||
p.requires_grad = True
|
||||
|
||||
self.optimizer_D.zero_grad()
|
||||
if disc_passthrough is not None:
|
||||
dp = {}
|
||||
for k, v in disc_passthrough.items():
|
||||
dp[k] = v.detach()
|
||||
pred_d_fake = self.netD(self.fake_H.detach(), dp)
|
||||
else:
|
||||
pred_d_fake = self.netD(self.fake_H.detach())
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
# need to forward and backward separately, since batch norm statistics differ
|
||||
# real
|
||||
# reald
|
||||
pred_d_real = self.netD(self.var_ref)
|
||||
l_d_real = self.cri_gan(pred_d_real, True)
|
||||
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
|
||||
l_d_real_scaled.backward()
|
||||
# fake
|
||||
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
|
||||
l_d_fake = self.cri_gan(pred_d_fake, False)
|
||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
|
@ -218,12 +233,10 @@ class SRGANModel(BaseModel):
|
|||
# l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
|
||||
# l_d_total = (l_d_real + l_d_fake) / 2
|
||||
# l_d_total.backward()
|
||||
pred_d_fake = self.netD(self.fake_H.detach()).detach()
|
||||
pred_d_real = self.netD(self.var_ref)
|
||||
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
|
||||
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake.detach()), True) * 0.5
|
||||
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
|
||||
l_d_real_scaled.backward()
|
||||
pred_d_fake = self.netD(self.fake_H.detach())
|
||||
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
|
||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
|
@ -245,7 +258,11 @@ class SRGANModel(BaseModel):
|
|||
def test(self):
|
||||
self.netG.eval()
|
||||
with torch.no_grad():
|
||||
self.fake_H = self.netG(self.var_L)
|
||||
genOut = self.netG(self.var_L)
|
||||
if type(genOut) is tuple:
|
||||
self.fake_H = genOut[0]
|
||||
else:
|
||||
self.fake_H = genOut
|
||||
self.netG.train()
|
||||
|
||||
def get_current_log(self):
|
||||
|
|
|
@ -23,10 +23,12 @@ class ReduceAnnealer(nn.Module):
|
|||
self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
arch_util.initialize_weights([self.reducer, self.annealer], .1)
|
||||
self.bn_reduce = nn.BatchNorm2d(number_filters*4, affine=True)
|
||||
self.bn_anneal = nn.BatchNorm2d(number_filters*4, affine=True)
|
||||
|
||||
def forward(self, x, interpolated_trunk):
|
||||
out = self.lrelu(self.reducer(x))
|
||||
out = self.lrelu(self.res_trunk(out))
|
||||
out = self.lrelu(self.bn_reduce(self.reducer(x)))
|
||||
out = self.lrelu(self.bn_anneal(self.res_trunk(out)))
|
||||
annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
|
||||
return annealed, out
|
||||
|
||||
|
@ -41,11 +43,13 @@ class Assembler(nn.Module):
|
|||
self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
|
||||
self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.bn = nn.BatchNorm2d(number_filters*4, affine=True)
|
||||
self.bn_up = nn.BatchNorm2d(number_filters*4, affine=True)
|
||||
|
||||
def forward(self, input, skip_raw):
|
||||
out = self.pixel_shuffle(input)
|
||||
out = self.upsampler(out) + skip_raw
|
||||
out = self.lrelu(self.res_trunk(out))
|
||||
out = self.bn_up(self.upsampler(out)) + skip_raw
|
||||
out = self.lrelu(self.bn(self.res_trunk(out)))
|
||||
return out
|
||||
|
||||
class FlatProcessorNet(nn.Module):
|
||||
|
@ -80,10 +84,15 @@ class FlatProcessorNet(nn.Module):
|
|||
|
||||
# Produce assemblers for all possible downscale variants. Some may not be used.
|
||||
self.assembler1 = Assembler(nf, assembler_blocks)
|
||||
self.assemble1_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
|
||||
self.assembler2 = Assembler(nf, assembler_blocks)
|
||||
self.assemble2_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
|
||||
self.assembler3 = Assembler(nf, assembler_blocks)
|
||||
self.assemble3_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
|
||||
self.assembler4 = Assembler(nf, assembler_blocks)
|
||||
self.assemble4_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
|
||||
self.assemblers = [self.assembler1, self.assembler2, self.assembler3, self.assembler4]
|
||||
self.assemble_convs = [self.assemble1_conv, self.assemble2_conv, self.assemble3_conv, self.assemble4_conv]
|
||||
|
||||
# Initialization
|
||||
arch_util.initialize_weights([self.conv_first, self.conv_last], .1)
|
||||
|
@ -104,8 +113,10 @@ class FlatProcessorNet(nn.Module):
|
|||
raw_values.append(raw)
|
||||
|
||||
i = -1
|
||||
scaled_outputs = {}
|
||||
out = raw_values[-1]
|
||||
while downsamples != self.downscale:
|
||||
scaled_outputs[int(x.shape[-1] / downsamples)] = self.assemble_convs[i](out)
|
||||
out = self.assemblers[i](out, raw_values[i-1])
|
||||
i -= 1
|
||||
downsamples = int(downsamples / 2)
|
||||
|
@ -115,4 +126,4 @@ class FlatProcessorNet(nn.Module):
|
|||
basis = x
|
||||
if downsamples != 1:
|
||||
basis = F.interpolate(x, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
|
||||
return basis + out
|
||||
return basis + out, scaled_outputs
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Discriminator_VGG_128(nn.Module):
|
||||
|
@ -11,11 +12,17 @@ class Discriminator_VGG_128(nn.Module):
|
|||
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
||||
self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
|
||||
|
||||
self.skipconv0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
|
||||
|
||||
# [64, 64, 64]
|
||||
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
|
||||
self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
|
||||
self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||
|
||||
self.skipconv1 = nn.Conv2d(3, nf*2, 3, 1, 1, bias=True)
|
||||
|
||||
# [128, 32, 32]
|
||||
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
|
||||
self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||
|
@ -38,13 +45,22 @@ class Discriminator_VGG_128(nn.Module):
|
|||
# activation function
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, gen_skips=None):
|
||||
x_dim = x.size(-1)
|
||||
if gen_skips is None:
|
||||
gen_skips = {
|
||||
int(x_dim/2): F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False),
|
||||
int(x_dim/4): F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False),
|
||||
}
|
||||
|
||||
fea = self.lrelu(self.conv0_0(x))
|
||||
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
|
||||
|
||||
fea = (fea + self.skipconv0(gen_skips[x_dim/2])) / 2
|
||||
fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
|
||||
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
|
||||
|
||||
fea = (fea + self.skipconv1(gen_skips[x_dim/4])) / 2
|
||||
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
|
||||
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
|
||||
|
||||
|
|
|
@ -16,8 +16,8 @@ datasets:
|
|||
dataroot_LQ: E:\\4k6k\\datasets\\ultra_lowq\\for_training
|
||||
mismatched_Data_OK: true
|
||||
use_shuffle: true
|
||||
n_workers: 0 # per GPU
|
||||
batch_size: 16
|
||||
n_workers: 8 # per GPU
|
||||
batch_size: 32
|
||||
target_size: 64
|
||||
use_flip: false
|
||||
use_rot: false
|
||||
|
@ -34,14 +34,17 @@ network_G:
|
|||
which_model_G: FlatProcessorNet
|
||||
in_nc: 3
|
||||
out_nc: 3
|
||||
nf: 48
|
||||
ra_blocks: 4
|
||||
assembler_blocks: 3
|
||||
nf: 32
|
||||
ra_blocks: 6
|
||||
assembler_blocks: 4
|
||||
|
||||
network_D:
|
||||
which_model_D: discriminator_resnet
|
||||
which_model_D: discriminator_vgg_128
|
||||
in_nc: 3
|
||||
nf: 64
|
||||
#which_model_D: discriminator_resnet
|
||||
#in_nc: 3
|
||||
#nf: 32
|
||||
|
||||
#### path
|
||||
path:
|
||||
|
@ -56,7 +59,7 @@ train:
|
|||
weight_decay_G: 0
|
||||
beta1_G: 0.9
|
||||
beta2_G: 0.99
|
||||
lr_D: !!float 1e-4
|
||||
lr_D: !!float 2e-4
|
||||
weight_decay_D: 0
|
||||
beta1_D: 0.9
|
||||
beta2_D: 0.99
|
||||
|
@ -71,11 +74,11 @@ train:
|
|||
pixel_weight: !!float 1e-2
|
||||
feature_criterion: l1
|
||||
feature_weight: 0
|
||||
gan_type: ragan # gan | ragan
|
||||
gan_type: gan # gan | ragan
|
||||
gan_weight: !!float 1e-1
|
||||
|
||||
D_update_ratio: 2
|
||||
D_init_iters: 1200
|
||||
D_update_ratio: 1
|
||||
D_init_iters: -1
|
||||
|
||||
manual_seed: 10
|
||||
val_freq: !!float 5e2
|
||||
|
|
Loading…
Reference in New Issue
Block a user