forked from mrq/DL-Art-School
Add cross-compare discriminator
This commit is contained in:
parent
be272248af
commit
1f21c02f8b
|
@ -137,6 +137,59 @@ class Discriminator_VGG_128_GN(nn.Module):
|
|||
out = self.linear2(fea)
|
||||
return out
|
||||
|
||||
|
||||
class CrossCompareBlock(nn.Module):
|
||||
def __init__(self, nf_in, nf_out):
|
||||
self.conv_hr_merge = ConvGnLelu(nf_in * 2, nf_in, kernel_size=1, bias=False, activation=False, norm=True)
|
||||
self.proc_hr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
self.proc_lr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
self.reduce_hr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True)
|
||||
self.reduce_lr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True)
|
||||
|
||||
def forward(self, lr, hr):
|
||||
hr = self.conv_hr_merge(torch.cat([hr, lr], dim=1))
|
||||
hr = self.proc_hr(hr)
|
||||
hr = self.reduce_hr(hr)
|
||||
|
||||
lr = self.proc_lr(lr)
|
||||
lr = self.reduce_lr(lr)
|
||||
|
||||
return lr, hr
|
||||
|
||||
|
||||
class CrossCompareDiscriminator(nn.Module):
|
||||
def __init__(self, in_nc, nf, scale=4):
|
||||
super(CrossCompareDiscriminator, self).__init__()
|
||||
assert scale == 2 or scale == 4
|
||||
|
||||
self.init_conv_hr = ConvGnLelu(in_nc, nf, stride=2, norm=False, bias=True, activation=True)
|
||||
self.init_conv_lr = ConvGnLelu(in_nc, nf, stride=1, norm=False, bias=True, activation=True)
|
||||
if scale == 4:
|
||||
strd_2 = 2
|
||||
else:
|
||||
strd_2 = 1
|
||||
self.second_conv = ConvGnLelu(nf, nf, stride=strd_2, norm=True, bias=False, activation=True)
|
||||
|
||||
self.cross1 = CrossCompareBlock(nf, nf * 2)
|
||||
self.cross2 = CrossCompareBlock(nf * 2, nf * 4)
|
||||
self.cross3 = CrossCompareBlock(nf * 4, nf * 8)
|
||||
self.cross4 = CrossCompareBlock(nf * 8, nf * 8)
|
||||
self.fproc_conv = ConvGnLelu(nf * 8, nf, norm=True, bias=True, activation=True)
|
||||
self.out_conv = ConvGnLelu(nf, 1, norm=False, bias=False, activation=False)
|
||||
|
||||
def forward(self, lr, hr):
|
||||
hr = self.init_conv_hr(hr)
|
||||
hr = self.second_conv(hr)
|
||||
lr = self.init_conv_lr(lr)
|
||||
|
||||
lr, hr = self.cross1(lr, hr)
|
||||
lr, hr = self.cross2(lr, hr)
|
||||
lr, hr = self.cross3(lr, hr)
|
||||
_, hr = self.cross4(lr, hr)
|
||||
|
||||
return self.out_conv(self.fproc_conv(hr))
|
||||
|
||||
|
||||
class Discriminator_VGG_PixLoss(nn.Module):
|
||||
def __init__(self, in_nc, nf):
|
||||
super(Discriminator_VGG_PixLoss, self).__init__()
|
||||
|
@ -297,173 +350,6 @@ class Discriminator_UNet(nn.Module):
|
|||
return 3, 4
|
||||
|
||||
|
||||
import functools
|
||||
from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConfigurableSwitchComputer, BareConvSwitch
|
||||
from switched_conv_util import save_attention_to_image
|
||||
from switched_conv import compute_attention_specificity, AttentionNorm
|
||||
|
||||
|
||||
class ReducingMultiplexer(nn.Module):
|
||||
def __init__(self, nf, num_channels):
|
||||
super(ReducingMultiplexer, self).__init__()
|
||||
self.conv1_0 = ConvGnSilu(nf, nf * 2, kernel_size=3, bias=False)
|
||||
self.conv1_1 = ConvGnSilu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
|
||||
# [128, 32, 32]
|
||||
self.conv2_0 = ConvGnSilu(nf * 2, nf * 4, kernel_size=3, bias=False)
|
||||
self.conv2_1 = ConvGnSilu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
|
||||
# [256, 16, 16]
|
||||
self.conv3_0 = ConvGnSilu(nf * 4, nf * 8, kernel_size=3, bias=False)
|
||||
self.conv3_1 = ConvGnSilu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
|
||||
self.exp1 = ExpansionBlock(nf * 8, nf * 4)
|
||||
self.exp2 = ExpansionBlock(nf * 4, nf * 2)
|
||||
self.exp3 = ExpansionBlock(nf * 2, nf)
|
||||
self.collapse = ConvGnSilu(nf, num_channels, norm=False, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea1 = self.conv1_0(x)
|
||||
fea1 = self.conv1_1(fea1)
|
||||
fea2 = self.conv2_0(fea1)
|
||||
fea2 = self.conv2_1(fea2)
|
||||
fea3 = self.conv3_0(fea2)
|
||||
fea3 = self.conv3_1(fea3)
|
||||
up = self.exp1(fea3, fea2)
|
||||
up = self.exp2(up, fea1)
|
||||
up = self.exp3(up, x)
|
||||
return self.collapse(up)
|
||||
|
||||
|
||||
# Differs from ConfigurableSwitchComputer in that the connections are not residual and the multiplexer is fed directly in.
|
||||
class ConfigurableLinearSwitchComputer(nn.Module):
|
||||
def __init__(self, out_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
|
||||
init_temp=20, add_scalable_noise_to_transforms=False):
|
||||
super(ConfigurableLinearSwitchComputer, self).__init__()
|
||||
|
||||
self.multiplexer = multiplexer_net
|
||||
self.pre_transform = pre_transform_block
|
||||
self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
|
||||
self.add_noise = add_scalable_noise_to_transforms
|
||||
self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))
|
||||
|
||||
# And the switch itself, including learned scalars
|
||||
self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count) if attention_norm else None)
|
||||
self.post_switch_conv = ConvBnLelu(out_filters, out_filters, norm=False, bias=True)
|
||||
# The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
|
||||
# depending on its needs.
|
||||
self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
|
||||
|
||||
def forward(self, x, output_attention_weights=False, extra_arg=None):
|
||||
if self.add_noise:
|
||||
rand_feature = torch.randn_like(x) * self.noise_scale
|
||||
x = x + rand_feature
|
||||
|
||||
if self.pre_transform:
|
||||
x = self.pre_transform(x)
|
||||
xformed = [t.forward(x) for t in self.transforms]
|
||||
m = self.multiplexer(x)
|
||||
|
||||
outputs, attention = self.switch(xformed, m, True)
|
||||
outputs = self.post_switch_conv(outputs)
|
||||
if output_attention_weights:
|
||||
return outputs, attention
|
||||
else:
|
||||
return outputs
|
||||
|
||||
def set_temperature(self, temp):
|
||||
self.switch.set_attention_temperature(temp)
|
||||
|
||||
|
||||
def create_switched_downsampler(nf, nf_out, num_channels, initial_temp=10):
|
||||
multiplx = ReducingMultiplexer(nf, num_channels)
|
||||
pretransform = None
|
||||
transform_fn = functools.partial(MultiConvBlock, nf, nf, nf_out, kernel_size=3, depth=2)
|
||||
return ConfigurableLinearSwitchComputer(nf_out, multiplx,
|
||||
pre_transform_block=pretransform, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
transform_count=num_channels, init_temp=initial_temp,
|
||||
add_scalable_noise_to_transforms=False)
|
||||
|
||||
|
||||
class Discriminator_switched(nn.Module):
|
||||
def __init__(self, in_nc, nf, initial_temp=10, final_temperature_step=50000):
|
||||
super(Discriminator_switched, self).__init__()
|
||||
# [64, 128, 128]
|
||||
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
|
||||
self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
|
||||
# [64, 64, 64]
|
||||
self.sw = create_switched_downsampler(nf, nf, 8)
|
||||
self.switches = [self.sw]
|
||||
self.conv1_1 = ConvGnLelu(nf, nf * 2, kernel_size=3, stride=2, bias=False)
|
||||
# [128, 32, 32]
|
||||
self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
|
||||
self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
|
||||
# [256, 16, 16]
|
||||
self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
|
||||
self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
|
||||
# [512, 8, 8]
|
||||
self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
|
||||
self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
|
||||
|
||||
self.exp1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
|
||||
self.exp2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
|
||||
self.exp3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
|
||||
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
|
||||
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
self.init_temperature = initial_temp
|
||||
self.final_temperature_step = final_temperature_step
|
||||
self.attentions = None
|
||||
|
||||
def forward(self, x, flatten=True):
|
||||
fea0 = self.conv0_0(x)
|
||||
fea0 = self.conv0_1(fea0)
|
||||
|
||||
fea1, att = self.sw(fea0, True)
|
||||
self.attentions = [att]
|
||||
fea1 = self.conv1_1(fea1)
|
||||
|
||||
fea2 = self.conv2_0(fea1)
|
||||
fea2 = self.conv2_1(fea2)
|
||||
|
||||
fea3 = self.conv3_0(fea2)
|
||||
fea3 = self.conv3_1(fea3)
|
||||
|
||||
fea4 = self.conv4_0(fea3)
|
||||
fea4 = self.conv4_1(fea4)
|
||||
|
||||
u1 = self.exp1(fea4, fea3)
|
||||
u2 = self.exp2(u1, fea2)
|
||||
u3 = self.exp3(u2, fea1)
|
||||
|
||||
loss3 = self.collapse3(self.proc3(u3))
|
||||
return loss3.view(-1, 1)
|
||||
|
||||
def pixgan_parameters(self):
|
||||
return 1, 4
|
||||
|
||||
def set_temperature(self, temp):
|
||||
[sw.set_temperature(temp) for sw in self.switches]
|
||||
|
||||
def update_for_step(self, step, experiments_path='.'):
|
||||
if self.attentions:
|
||||
for i, sw in enumerate(self.switches):
|
||||
temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
|
||||
sw.set_temperature(min(self.init_temperature,
|
||||
max(self.init_temperature - temp_loss_per_step * step, 1)))
|
||||
if step % 50 == 0:
|
||||
[save_attention_to_image(experiments_path, self.attentions[i], 8, step, "disc_a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))]
|
||||
|
||||
def get_debug_values(self, step):
|
||||
temp = self.switches[0].switch.temperature
|
||||
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
|
||||
means = [i[0] for i in mean_hists]
|
||||
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
|
||||
val = {"disc_switch_temperature": temp}
|
||||
for i in range(len(means)):
|
||||
val["disc_switch_%i_specificity" % (i,)] = means[i]
|
||||
val["disc_switch_%i_histogram" % (i,)] = hists[i]
|
||||
return val
|
||||
|
||||
|
||||
class Discriminator_UNet_FeaOut(nn.Module):
|
||||
def __init__(self, in_nc, nf, feature_mode=False):
|
||||
super(Discriminator_UNet_FeaOut, self).__init__()
|
||||
|
|
|
@ -150,6 +150,8 @@ def define_D_net(opt_net, img_sz=None):
|
|||
elif which_model == "discriminator_switched":
|
||||
netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
|
||||
final_temperature_step=opt_net['final_temperature_step'])
|
||||
elif which_model == "cross_compare_vgg128":
|
||||
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], nf=opt_net['nf'], scale=opt_net['scale'])
|
||||
else:
|
||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||
return netD
|
||||
|
|
Loading…
Reference in New Issue
Block a user