Add cross-compare discriminator

This commit is contained in:
James Betker 2020-08-06 08:56:21 -06:00
parent be272248af
commit 1f21c02f8b
2 changed files with 55 additions and 167 deletions

View File

@ -137,6 +137,59 @@ class Discriminator_VGG_128_GN(nn.Module):
out = self.linear2(fea)
return out
class CrossCompareBlock(nn.Module):
def __init__(self, nf_in, nf_out):
self.conv_hr_merge = ConvGnLelu(nf_in * 2, nf_in, kernel_size=1, bias=False, activation=False, norm=True)
self.proc_hr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True)
self.proc_lr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True)
self.reduce_hr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True)
self.reduce_lr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True)
def forward(self, lr, hr):
hr = self.conv_hr_merge(torch.cat([hr, lr], dim=1))
hr = self.proc_hr(hr)
hr = self.reduce_hr(hr)
lr = self.proc_lr(lr)
lr = self.reduce_lr(lr)
return lr, hr
class CrossCompareDiscriminator(nn.Module):
def __init__(self, in_nc, nf, scale=4):
super(CrossCompareDiscriminator, self).__init__()
assert scale == 2 or scale == 4
self.init_conv_hr = ConvGnLelu(in_nc, nf, stride=2, norm=False, bias=True, activation=True)
self.init_conv_lr = ConvGnLelu(in_nc, nf, stride=1, norm=False, bias=True, activation=True)
if scale == 4:
strd_2 = 2
else:
strd_2 = 1
self.second_conv = ConvGnLelu(nf, nf, stride=strd_2, norm=True, bias=False, activation=True)
self.cross1 = CrossCompareBlock(nf, nf * 2)
self.cross2 = CrossCompareBlock(nf * 2, nf * 4)
self.cross3 = CrossCompareBlock(nf * 4, nf * 8)
self.cross4 = CrossCompareBlock(nf * 8, nf * 8)
self.fproc_conv = ConvGnLelu(nf * 8, nf, norm=True, bias=True, activation=True)
self.out_conv = ConvGnLelu(nf, 1, norm=False, bias=False, activation=False)
def forward(self, lr, hr):
hr = self.init_conv_hr(hr)
hr = self.second_conv(hr)
lr = self.init_conv_lr(lr)
lr, hr = self.cross1(lr, hr)
lr, hr = self.cross2(lr, hr)
lr, hr = self.cross3(lr, hr)
_, hr = self.cross4(lr, hr)
return self.out_conv(self.fproc_conv(hr))
class Discriminator_VGG_PixLoss(nn.Module):
def __init__(self, in_nc, nf):
super(Discriminator_VGG_PixLoss, self).__init__()
@ -297,173 +350,6 @@ class Discriminator_UNet(nn.Module):
return 3, 4
import functools
from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConfigurableSwitchComputer, BareConvSwitch
from switched_conv_util import save_attention_to_image
from switched_conv import compute_attention_specificity, AttentionNorm
class ReducingMultiplexer(nn.Module):
def __init__(self, nf, num_channels):
super(ReducingMultiplexer, self).__init__()
self.conv1_0 = ConvGnSilu(nf, nf * 2, kernel_size=3, bias=False)
self.conv1_1 = ConvGnSilu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
# [128, 32, 32]
self.conv2_0 = ConvGnSilu(nf * 2, nf * 4, kernel_size=3, bias=False)
self.conv2_1 = ConvGnSilu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
# [256, 16, 16]
self.conv3_0 = ConvGnSilu(nf * 4, nf * 8, kernel_size=3, bias=False)
self.conv3_1 = ConvGnSilu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
self.exp1 = ExpansionBlock(nf * 8, nf * 4)
self.exp2 = ExpansionBlock(nf * 4, nf * 2)
self.exp3 = ExpansionBlock(nf * 2, nf)
self.collapse = ConvGnSilu(nf, num_channels, norm=False, bias=True)
def forward(self, x):
fea1 = self.conv1_0(x)
fea1 = self.conv1_1(fea1)
fea2 = self.conv2_0(fea1)
fea2 = self.conv2_1(fea2)
fea3 = self.conv3_0(fea2)
fea3 = self.conv3_1(fea3)
up = self.exp1(fea3, fea2)
up = self.exp2(up, fea1)
up = self.exp3(up, x)
return self.collapse(up)
# Differs from ConfigurableSwitchComputer in that the connections are not residual and the multiplexer is fed directly in.
class ConfigurableLinearSwitchComputer(nn.Module):
def __init__(self, out_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
init_temp=20, add_scalable_noise_to_transforms=False):
super(ConfigurableLinearSwitchComputer, self).__init__()
self.multiplexer = multiplexer_net
self.pre_transform = pre_transform_block
self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
self.add_noise = add_scalable_noise_to_transforms
self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))
# And the switch itself, including learned scalars
self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count) if attention_norm else None)
self.post_switch_conv = ConvBnLelu(out_filters, out_filters, norm=False, bias=True)
# The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
# depending on its needs.
self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
def forward(self, x, output_attention_weights=False, extra_arg=None):
if self.add_noise:
rand_feature = torch.randn_like(x) * self.noise_scale
x = x + rand_feature
if self.pre_transform:
x = self.pre_transform(x)
xformed = [t.forward(x) for t in self.transforms]
m = self.multiplexer(x)
outputs, attention = self.switch(xformed, m, True)
outputs = self.post_switch_conv(outputs)
if output_attention_weights:
return outputs, attention
else:
return outputs
def set_temperature(self, temp):
self.switch.set_attention_temperature(temp)
def create_switched_downsampler(nf, nf_out, num_channels, initial_temp=10):
multiplx = ReducingMultiplexer(nf, num_channels)
pretransform = None
transform_fn = functools.partial(MultiConvBlock, nf, nf, nf_out, kernel_size=3, depth=2)
return ConfigurableLinearSwitchComputer(nf_out, multiplx,
pre_transform_block=pretransform, transform_block=transform_fn,
attention_norm=True,
transform_count=num_channels, init_temp=initial_temp,
add_scalable_noise_to_transforms=False)
class Discriminator_switched(nn.Module):
def __init__(self, in_nc, nf, initial_temp=10, final_temperature_step=50000):
super(Discriminator_switched, self).__init__()
# [64, 128, 128]
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
# [64, 64, 64]
self.sw = create_switched_downsampler(nf, nf, 8)
self.switches = [self.sw]
self.conv1_1 = ConvGnLelu(nf, nf * 2, kernel_size=3, stride=2, bias=False)
# [128, 32, 32]
self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
# [256, 16, 16]
self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
# [512, 8, 8]
self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
self.exp1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
self.exp2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
self.exp3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
self.init_temperature = initial_temp
self.final_temperature_step = final_temperature_step
self.attentions = None
def forward(self, x, flatten=True):
fea0 = self.conv0_0(x)
fea0 = self.conv0_1(fea0)
fea1, att = self.sw(fea0, True)
self.attentions = [att]
fea1 = self.conv1_1(fea1)
fea2 = self.conv2_0(fea1)
fea2 = self.conv2_1(fea2)
fea3 = self.conv3_0(fea2)
fea3 = self.conv3_1(fea3)
fea4 = self.conv4_0(fea3)
fea4 = self.conv4_1(fea4)
u1 = self.exp1(fea4, fea3)
u2 = self.exp2(u1, fea2)
u3 = self.exp3(u2, fea1)
loss3 = self.collapse3(self.proc3(u3))
return loss3.view(-1, 1)
def pixgan_parameters(self):
return 1, 4
def set_temperature(self, temp):
[sw.set_temperature(temp) for sw in self.switches]
def update_for_step(self, step, experiments_path='.'):
if self.attentions:
for i, sw in enumerate(self.switches):
temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
sw.set_temperature(min(self.init_temperature,
max(self.init_temperature - temp_loss_per_step * step, 1)))
if step % 50 == 0:
[save_attention_to_image(experiments_path, self.attentions[i], 8, step, "disc_a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))]
def get_debug_values(self, step):
temp = self.switches[0].switch.temperature
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
means = [i[0] for i in mean_hists]
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
val = {"disc_switch_temperature": temp}
for i in range(len(means)):
val["disc_switch_%i_specificity" % (i,)] = means[i]
val["disc_switch_%i_histogram" % (i,)] = hists[i]
return val
class Discriminator_UNet_FeaOut(nn.Module):
def __init__(self, in_nc, nf, feature_mode=False):
super(Discriminator_UNet_FeaOut, self).__init__()

View File

@ -150,6 +150,8 @@ def define_D_net(opt_net, img_sz=None):
elif which_model == "discriminator_switched":
netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
final_temperature_step=opt_net['final_temperature_step'])
elif which_model == "cross_compare_vgg128":
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], nf=opt_net['nf'], scale=opt_net['scale'])
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD