Revert "Enable skip-through connections from disc to gen"
This reverts commit b7857f35c3.

parent f027e888ed
commit 66e91a3d9e
@@ -150,16 +150,9 @@ class SRGANModel(BaseModel):
         for p in self.netD.parameters():
             p.requires_grad = False

-        disc_passthrough = None
         if step > self.D_init_iters:
             self.optimizer_G.zero_grad()
-            genOut = self.netG(self.var_L)
-            if type(genOut) is tuple:
-                self.fake_H = genOut[0]
-                disc_passthrough = genOut[1]
-            else:
-                self.fake_H = genOut
-                disc_passthrough = None
+            self.fake_H = self.netG(self.var_L)
         else:
             self.fake_H = self.pix

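For orientation, the "skip-through" wiring being reverted spanned three files: the generator returned its output image together with a dict of intermediate, downscaled projections, and the model forwarded that dict to the discriminator as a second argument. A rough sketch assembled from the lines deleted in this diff (the helper function and local names are illustrative, not code from this repository):

    def run_g_and_d(netG, netD, var_L):
        # Pre-revert interface: netG may return (image, passthrough_dict);
        # post-revert, both networks exchange plain tensors only.
        gen_out = netG(var_L)
        if isinstance(gen_out, tuple):
            fake_H, disc_passthrough = gen_out   # dict of downscaled maps keyed by spatial size
        else:
            fake_H, disc_passthrough = gen_out, None
        if disc_passthrough is not None:
            pred_fake = netD(fake_H, disc_passthrough)
        else:
            pred_fake = netD(fake_H)
        return fake_H, pred_fake
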
@@ -186,14 +179,12 @@ class SRGANModel(BaseModel):
                 if step % self.l_fea_w_decay_steps == 0:
                     self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)

-            if disc_passthrough is not None:
-                pred_g_fake = self.netD(self.fake_H, disc_passthrough)
-            else:
-                pred_g_fake = self.netD(self.fake_H)
             if self.opt['train']['gan_type'] == 'gan':
+                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
             elif self.opt['train']['gan_type'] == 'ragan':
                 pred_d_real = self.netD(self.var_ref).detach()
+                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * (
                     self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2

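The restored generator loss above is the standard relativistic average GAN (RaGAN) formulation. As a worked reference, and assuming self.cri_gan behaves like binary cross-entropy on logits (an assumption based on how it is called here), the same quantity can be written as:

    import torch
    import torch.nn.functional as F

    def ragan_generator_loss(pred_d_real, pred_g_fake, gan_weight):
        # Push real-vs-mean-fake toward the "fake" label and
        # fake-vs-mean-real toward the "real" label, then average.
        real_rel = pred_d_real - torch.mean(pred_g_fake)
        fake_rel = pred_g_fake - torch.mean(pred_d_real)
        l_real = F.binary_cross_entropy_with_logits(real_rel, torch.zeros_like(real_rel))
        l_fake = F.binary_cross_entropy_with_logits(fake_rel, torch.ones_like(fake_rel))
        return gan_weight * (l_real + l_fake) / 2
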
@@ -208,21 +199,15 @@ class SRGANModel(BaseModel):
             p.requires_grad = True

         self.optimizer_D.zero_grad()
-        if disc_passthrough is not None:
-            dp = {}
-            for k, v in disc_passthrough.items():
-                dp[k] = v.detach()
-            pred_d_fake = self.netD(self.fake_H.detach(), dp)
-        else:
-            pred_d_fake = self.netD(self.fake_H.detach())
         if self.opt['train']['gan_type'] == 'gan':
             # need to forward and backward separately, since batch norm statistics differ
-            # reald
+            # real
             pred_d_real = self.netD(self.var_ref)
             l_d_real = self.cri_gan(pred_d_real, True)
             with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                 l_d_real_scaled.backward()
             # fake
+            pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
             l_d_fake = self.cri_gan(pred_d_fake, False)
             with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                 l_d_fake_scaled.backward()

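The amp.scale_loss context manager used throughout the discriminator update comes from NVIDIA Apex mixed-precision training; each loss gets its own loss_id so it keeps an independent loss scale. A minimal sketch of the pattern, assuming the model and optimizer were wrapped with amp.initialize(..., num_losses=N) during setup (that setup is not shown in this diff):

    from apex import amp  # NVIDIA Apex mixed-precision utilities

    def amp_backward(loss, optimizer, loss_id=0):
        # Scale the loss to avoid fp16 underflow, then backprop through the
        # scaled value; unscaling happens before optimizer.step().
        with amp.scale_loss(loss, optimizer, loss_id=loss_id) as scaled_loss:
            scaled_loss.backward()
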
@@ -233,10 +218,12 @@ class SRGANModel(BaseModel):
             # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
             # l_d_total = (l_d_real + l_d_fake) / 2
             # l_d_total.backward()
+            pred_d_fake = self.netD(self.fake_H.detach()).detach()
             pred_d_real = self.netD(self.var_ref)
-            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake.detach()), True) * 0.5
+            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
             with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                 l_d_real_scaled.backward()
+            pred_d_fake = self.netD(self.fake_H.detach())
             l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
             with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                 l_d_fake_scaled.backward()

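In the restored RaGAN discriminator branch, the opposing branch's prediction is detached inside each term, so the real-term backpropagates only through pred_d_real and the fake-term only through pred_d_fake. A compressed sketch of the two halves, again treating the criterion as BCE-on-logits (an assumption, as above):

    import torch
    import torch.nn.functional as F

    def ragan_discriminator_losses(pred_d_real, pred_d_fake):
        # Each term detaches the other branch's mean so that gradient flows
        # only through the branch that term is meant to train.
        l_real = 0.5 * F.binary_cross_entropy_with_logits(
            pred_d_real - torch.mean(pred_d_fake.detach()),
            torch.ones_like(pred_d_real))
        l_fake = 0.5 * F.binary_cross_entropy_with_logits(
            pred_d_fake - torch.mean(pred_d_real.detach()),
            torch.zeros_like(pred_d_fake))
        return l_real, l_fake
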
@@ -258,11 +245,7 @@ class SRGANModel(BaseModel):
     def test(self):
         self.netG.eval()
         with torch.no_grad():
-            genOut = self.netG(self.var_L)
-            if type(genOut) is tuple:
-                self.fake_H = genOut[0]
-            else:
-                self.fake_H = genOut
+            self.fake_H = self.netG(self.var_L)
         self.netG.train()

     def get_current_log(self):

@@ -23,12 +23,10 @@ class ReduceAnnealer(nn.Module):
         self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
         arch_util.initialize_weights([self.reducer, self.annealer], .1)
-        self.bn_reduce = nn.BatchNorm2d(number_filters*4, affine=True)
-        self.bn_anneal = nn.BatchNorm2d(number_filters*4, affine=True)

     def forward(self, x, interpolated_trunk):
-        out = self.lrelu(self.bn_reduce(self.reducer(x)))
-        out = self.lrelu(self.bn_anneal(self.res_trunk(out)))
+        out = self.lrelu(self.reducer(x))
+        out = self.lrelu(self.res_trunk(out))
         annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
         return annealed, out

@@ -43,13 +41,11 @@ class Assembler(nn.Module):
         self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
         self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
-        self.bn = nn.BatchNorm2d(number_filters*4, affine=True)
-        self.bn_up = nn.BatchNorm2d(number_filters*4, affine=True)

     def forward(self, input, skip_raw):
         out = self.pixel_shuffle(input)
-        out = self.bn_up(self.upsampler(out)) + skip_raw
-        out = self.lrelu(self.bn(self.res_trunk(out)))
+        out = self.upsampler(out) + skip_raw
+        out = self.lrelu(self.res_trunk(out))
         return out


 class FlatProcessorNet(nn.Module):

@@ -84,15 +80,10 @@ class FlatProcessorNet(nn.Module):

         # Produce assemblers for all possible downscale variants. Some may not be used.
         self.assembler1 = Assembler(nf, assembler_blocks)
-        self.assemble1_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assembler2 = Assembler(nf, assembler_blocks)
-        self.assemble2_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assembler3 = Assembler(nf, assembler_blocks)
-        self.assemble3_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assembler4 = Assembler(nf, assembler_blocks)
-        self.assemble4_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assemblers = [self.assembler1, self.assembler2, self.assembler3, self.assembler4]
-        self.assemble_convs = [self.assemble1_conv, self.assemble2_conv, self.assemble3_conv, self.assemble4_conv]

         # Initialization
         arch_util.initialize_weights([self.conv_first, self.conv_last], .1)

@@ -113,10 +104,8 @@ class FlatProcessorNet(nn.Module):
             raw_values.append(raw)

         i = -1
-        scaled_outputs = {}
         out = raw_values[-1]
         while downsamples != self.downscale:
-            scaled_outputs[int(x.shape[-1] / downsamples)] = self.assemble_convs[i](out)
             out = self.assemblers[i](out, raw_values[i-1])
             i -= 1
             downsamples = int(downsamples / 2)

@@ -126,4 +115,4 @@ class FlatProcessorNet(nn.Module):
         basis = x
         if downsamples != 1:
             basis = F.interpolate(x, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
-        return basis + out, scaled_outputs
+        return basis + out

@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 import torchvision
-import torch.nn.functional as F


 class Discriminator_VGG_128(nn.Module):

@@ -12,17 +11,11 @@ class Discriminator_VGG_128(nn.Module):
         self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
         self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
         self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
-
-        self.skipconv0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
-
         # [64, 64, 64]
         self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
         self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
         self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
         self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
-
-        self.skipconv1 = nn.Conv2d(3, nf*2, 3, 1, 1, bias=True)
-
         # [128, 32, 32]
         self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
         self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)

@@ -45,22 +38,13 @@ class Discriminator_VGG_128(nn.Module):
         # activation function
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

-    def forward(self, x, gen_skips=None):
-        x_dim = x.size(-1)
-        if gen_skips is None:
-            gen_skips = {
-                int(x_dim/2): F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False),
-                int(x_dim/4): F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False),
-            }
-
+    def forward(self, x):
         fea = self.lrelu(self.conv0_0(x))
         fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))

-        fea = (fea + self.skipconv0(gen_skips[x_dim/2])) / 2
         fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
         fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))

-        fea = (fea + self.skipconv1(gen_skips[x_dim/4])) / 2
         fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
         fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))

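Before the revert, when the generator did not supply skip inputs the discriminator synthesized its own: bilinear downscales of the input image, keyed by spatial width, each later projected by a small conv (skipconv0/skipconv1) and averaged into the feature maps. A self-contained sketch of that default construction, reconstructed from the deleted lines (the helper name is illustrative):

    import torch.nn.functional as F

    def build_default_skips(x):
        # Map spatial width -> bilinearly downscaled copy of the input image,
        # matching the gen_skips dict the deleted forward() built on the fly.
        w = x.size(-1)
        return {
            int(w / 2): F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False),
            int(w / 4): F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=False),
        }
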
@@ -16,8 +16,8 @@ datasets:
     dataroot_LQ: E:\\4k6k\\datasets\\ultra_lowq\\for_training
     mismatched_Data_OK: true
     use_shuffle: true
-    n_workers: 8 # per GPU
-    batch_size: 32
+    n_workers: 0 # per GPU
+    batch_size: 16
     target_size: 64
     use_flip: false
     use_rot: false

@@ -34,17 +34,14 @@ network_G:
   which_model_G: FlatProcessorNet
   in_nc: 3
   out_nc: 3
-  nf: 32
-  ra_blocks: 6
-  assembler_blocks: 4
+  nf: 48
+  ra_blocks: 4
+  assembler_blocks: 3

 network_D:
-  which_model_D: discriminator_vgg_128
+  which_model_D: discriminator_resnet
   in_nc: 3
   nf: 64
-  #which_model_D: discriminator_resnet
-  #in_nc: 3
-  #nf: 32

 #### path
 path:

@@ -59,7 +56,7 @@ train:
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 2e-4
+  lr_D: !!float 1e-4
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99

@@ -74,11 +71,11 @@ train:
   pixel_weight: !!float 1e-2
   feature_criterion: l1
   feature_weight: 0
-  gan_type: gan # gan | ragan
+  gan_type: ragan # gan | ragan
   gan_weight: !!float 1e-1

-  D_update_ratio: 1
-  D_init_iters: -1
+  D_update_ratio: 2
+  D_init_iters: 1200

   manual_seed: 10
   val_freq: !!float 5e2
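The training options changed here are consumed by SRGANModel: gan_type: ragan selects the relativistic loss branches shown earlier, and D_init_iters: 1200 delays generator updates (the step > self.D_init_iters check in the first hunk). How D_update_ratio is applied is not visible in this diff, so the modulo gating below is an assumption based on the conventional BasicSR/MMSR pattern:

    D_UPDATE_RATIO = 2     # train: D_update_ratio
    D_INIT_ITERS = 1200    # train: D_init_iters

    def should_update_generator(step):
        # Warm up the discriminator for the first D_INIT_ITERS steps, then
        # update the generator only on every D_UPDATE_RATIO-th step (assumed).
        return step > D_INIT_ITERS and step % D_UPDATE_RATIO == 0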