From b06e1784e1ed2384d103aafef3d086940cb1270d Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 25 Jul 2020 17:16:54 -0600
Subject: [PATCH] Fix SRG4 & switch disc "fix". hehe.

---
 .../archs/SwitchedResidualGenerator_arch.py  | 23 ++++--
 codes/models/archs/discriminator_vgg_arch.py | 74 ++++++++++++-------
 2 files changed, 64 insertions(+), 33 deletions(-)

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 9fb8b2c4..3a139dc7 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -139,7 +139,8 @@ class ConfigurableSwitchComputer(nn.Module):
             rand_feature = torch.randn_like(x) * self.noise_scale
             x = x + rand_feature
 
-        x = self.pre_transform(x)
+        if self.pre_transform:
+            x = self.pre_transform(x)
         xformed = [t.forward(x) for t in self.transforms]
 
         m = self.multiplexer(identity)
@@ -255,6 +256,8 @@ class ConfigurableSwitchedResidualGenerator4(nn.Module):
         multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
                                         switch_processing_layers, trans_counts)
+        half_multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
+                                             switch_processing_layers, trans_counts // 2)
         transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
                                          transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers,
                                          weight_init_factor=.1)
 
@@ -265,12 +268,19 @@ class ConfigurableSwitchedResidualGenerator4(nn.Module):
                                               transform_count=trans_counts, init_temp=initial_temp,
                                               add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
         self.rdb2 = RRDB(transformation_filters)
-        self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+        self.sw2 = ConfigurableSwitchComputer(transformation_filters, half_multiplx_fn,
+                                              pre_transform_block=None, transform_block=transform_fn,
+                                              attention_norm=attention_norm,
+                                              transform_count=trans_counts // 2, init_temp=initial_temp,
+                                              add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
+        self.rdb3 = RRDB(transformation_filters)
+        self.sw3 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=None, transform_block=transform_fn,
                                               attention_norm=attention_norm,
                                               transform_count=trans_counts, init_temp=initial_temp,
                                               add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
-        self.rdb3 = RRDB(transformation_filters)
+        self.rdb4 = RRDB(transformation_filters)
+        self.switches = [self.sw1, self.sw2, self.sw3]
         self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
         self.transformation_counts = trans_counts
 
@@ -290,10 +300,13 @@ class ConfigurableSwitchedResidualGenerator4(nn.Module):
         x = self.initial_conv(x)
 
         x = self.rdb1(x)
-        x = self.sw1(x, True)
+        x, a1 = self.sw1(x, True)
         x = self.rdb2(x)
-        x = self.sw2(x, True)
+        x, a2 = self.sw2(x, True)
         x = self.rdb3(x)
+        x, a3 = self.sw3(x, True)
+        x = self.rdb4(x)
+        self.attentions = [a1, a2, a3]
 
         x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
         if self.upsample_factor > 2:
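Why the forward() change above matters: ConfigurableSwitchComputer returns a (features, attention) tuple when its second argument (output_attention_weights) is true, so the old `x = self.sw1(x, True)` bound the whole tuple to x and the following RRDB call would fail; the old init also overwrote `self.sw1` with a second constructor call while forward() referenced a `self.sw2` that never existed. A minimal sketch of the call pattern the patch establishes; the stub below is invented for illustration and only mirrors the switch's return convention:

    import torch
    import torch.nn as nn

    class StubSwitch(nn.Module):
        # Stand-in for ConfigurableSwitchComputer: returns (features, attention)
        # when output_attention_weights=True, like the patched call sites expect.
        def __init__(self, nf, transform_count=8):
            super().__init__()
            self.conv = nn.Conv2d(nf, nf, 3, padding=1)
            self.transform_count = transform_count

        def forward(self, x, output_attention_weights=False):
            out = self.conv(x)
            if output_attention_weights:
                # Dummy per-pixel attention over the transform bank.
                att = torch.softmax(
                    torch.randn(x.shape[0], self.transform_count, *x.shape[2:]), dim=1)
                return out, att
            return out

    sw = StubSwitch(64)
    x = torch.randn(1, 64, 32, 32)
    x, a1 = sw(x, True)   # unpack: features flow onward, attention is kept
    attentions = [a1]     # as SRG4 now stashes [a1, a2, a3] for logging
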
diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index 04c9fc4d..dd6bd88c 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torchvision
-from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock
+from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu
 import torch.nn.functional as F
 
 
@@ -244,15 +244,33 @@
 from switched_conv_util import save_attention_to_image
 from switched_conv import compute_attention_specificity, AttentionNorm
 
-class ExpandAndCollapse(nn.Module):
-    def __init__(self, nf, nf_out, num_channels):
-        super(ExpandAndCollapse, self).__init__()
-        self.expand = ExpansionBlock(nf, nf_out, block=ConvGnLelu)
-        self.collapse = ConvGnLelu(nf_out, num_channels, norm=False, bias=False, activation=False)
+class ReducingMultiplexer(nn.Module):
+    def __init__(self, nf, num_channels):
+        super(ReducingMultiplexer, self).__init__()
+        self.conv1_0 = ConvGnSilu(nf, nf * 2, kernel_size=3, bias=False)
+        self.conv1_1 = ConvGnSilu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
+        # [128, 32, 32]
+        self.conv2_0 = ConvGnSilu(nf * 2, nf * 4, kernel_size=3, bias=False)
+        self.conv2_1 = ConvGnSilu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
+        # [256, 16, 16]
+        self.conv3_0 = ConvGnSilu(nf * 4, nf * 8, kernel_size=3, bias=False)
+        self.conv3_1 = ConvGnSilu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+        self.exp1 = ExpansionBlock(nf * 8, nf * 4)
+        self.exp2 = ExpansionBlock(nf * 4, nf * 2)
+        self.exp3 = ExpansionBlock(nf * 2, nf)
+        self.collapse = ConvGnSilu(nf, num_channels, norm=False, bias=True)
 
-    def forward(self, x, passthrough):
-        x = self.expand(x, passthrough)
-        return self.collapse(x)
+    def forward(self, x):
+        fea1 = self.conv1_0(x)
+        fea1 = self.conv1_1(fea1)
+        fea2 = self.conv2_0(fea1)
+        fea2 = self.conv2_1(fea2)
+        fea3 = self.conv3_0(fea2)
+        fea3 = self.conv3_1(fea3)
+        up = self.exp1(fea3, fea2)
+        up = self.exp2(up, fea1)
+        up = self.exp3(up, x)
+        return self.collapse(up)
 
 
 # Differs from ConfigurableSwitchComputer in that the connections are not residual and the multiplexer is fed directly in.
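For orientation, the new ReducingMultiplexer is a small U-Net: three stride-2 contractions from nf up to 8*nf channels, three skip-connected ExpansionBlock expansions back down to nf, and a final collapse to num_channels per-pixel selection logits, one channel per switchable transform (the [128, 32, 32] / [256, 16, 16] comments in the hunk assume a 64x64, nf=64 input). A shape-only sketch of that flow; `conv` and `Expand` below are hypothetical plain-Conv2d stand-ins, since the real ConvGnSilu/ExpansionBlock add GroupNorm, SiLU and a learned merge:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def conv(cin, cout, stride=1):
        # Stand-in for ConvGnSilu: 3x3 conv; padding keeps size at stride 1.
        return nn.Conv2d(cin, cout, 3, stride=stride, padding=1)

    class Expand(nn.Module):
        # Stand-in for ExpansionBlock: 2x upsample, then merge with the skip
        # ("passthrough") tensor from the contracting path.
        def __init__(self, cin, cout):
            super().__init__()
            self.reduce = nn.Conv2d(cin, cout, 1)
            self.merge = nn.Conv2d(cout * 2, cout, 3, padding=1)

        def forward(self, x, passthrough):
            x = self.reduce(F.interpolate(x, scale_factor=2, mode="nearest"))
            return self.merge(torch.cat([x, passthrough], dim=1))

    nf = 64
    x = torch.randn(1, nf, 64, 64)
    f1 = conv(nf * 2, nf * 2, 2)(conv(nf, nf * 2)(x))       # [1, 128, 32, 32]
    f2 = conv(nf * 4, nf * 4, 2)(conv(nf * 2, nf * 4)(f1))  # [1, 256, 16, 16]
    f3 = conv(nf * 8, nf * 8, 2)(conv(nf * 4, nf * 8)(f2))  # [1, 512, 8, 8]
    u = Expand(nf * 8, nf * 4)(f3, f2)                      # [1, 256, 16, 16]
    u = Expand(nf * 4, nf * 2)(u, f1)                       # [1, 128, 32, 32]
    u = Expand(nf * 2, nf)(u, x)                            # [1, 64, 64, 64]
    logits = nn.Conv2d(nf, 8, 3, padding=1)(u)              # one map per transform
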
@@ -274,16 +292,15 @@ class ConfigurableLinearSwitchComputer(nn.Module):
         # depending on its needs.
         self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
 
-    def forward(self, x, passthrough, output_attention_weights=False, extra_arg=None):
-        identity = x
+    def forward(self, x, output_attention_weights=False, extra_arg=None):
         if self.add_noise:
             rand_feature = torch.randn_like(x) * self.noise_scale
             x = x + rand_feature
 
-        x = self.pre_transform(x)
-        xformed = [t.forward(x, passthrough) for t in self.transforms]
-        m = self.multiplexer(identity, passthrough)
-
+        if self.pre_transform:
+            x = self.pre_transform(x)
+        xformed = [t.forward(x) for t in self.transforms]
+        m = self.multiplexer(x)
         outputs, attention = self.switch(xformed, m, True)
         outputs = self.post_switch_conv(outputs)
 
@@ -296,10 +313,10 @@ class ConfigurableLinearSwitchComputer(nn.Module):
         self.switch.set_attention_temperature(temp)
 
 
-def create_switched_upsampler(nf, nf_out, num_channels, initial_temp=10):
-    multiplx = ExpandAndCollapse(nf, nf_out, num_channels)
-    pretransform = ConvGnLelu(nf, nf, norm=True, bias=False)
-    transform_fn = functools.partial(ExpansionBlock, nf, nf_out, block=ConvGnLelu)
+def create_switched_downsampler(nf, nf_out, num_channels, initial_temp=10):
+    multiplx = ReducingMultiplexer(nf, num_channels)
+    pretransform = None
+    transform_fn = functools.partial(MultiConvBlock, nf, nf, nf_out, kernel_size=3, depth=2)
     return ConfigurableLinearSwitchComputer(nf_out, multiplx,
                                             pre_transform_block=pretransform, transform_block=transform_fn,
                                             attention_norm=True,
@@ -314,8 +331,9 @@ class Discriminator_switched(nn.Module):
         self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
         self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
         # [64, 64, 64]
-        self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
-        self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
+        self.sw = create_switched_downsampler(nf, nf, 8)
+        self.switches = [self.sw]
+        self.conv1_1 = ConvGnLelu(nf, nf * 2, kernel_size=3, stride=2, bias=False)
         # [128, 32, 32]
         self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
         self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
@@ -327,9 +345,8 @@ class Discriminator_switched(nn.Module):
         self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
 
         self.exp1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
-        self.upsw2 = create_switched_upsampler(nf * 8, nf * 4, 8)
-        self.upsw3 = create_switched_upsampler(nf * 4, nf * 2, 8)
-        self.switches = [self.upsw2, self.upsw3]
+        self.exp2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
+        self.exp3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
         self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
         self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
 
@@ -341,7 +358,8 @@ class Discriminator_switched(nn.Module):
         fea0 = self.conv0_0(x)
         fea0 = self.conv0_1(fea0)
 
-        fea1 = self.conv1_0(fea0)
+        fea1, att = self.sw(fea0, True)
+        self.attentions = [att]
         fea1 = self.conv1_1(fea1)
 
         fea2 = self.conv2_0(fea1)
@@ -354,9 +372,9 @@ class Discriminator_switched(nn.Module):
         fea4 = self.conv4_1(fea4)
 
         u1 = self.exp1(fea4, fea3)
-        u2, a1 = self.upsw2(u1, fea2, output_attention_weights=True)
-        u3, a2 = self.upsw3(u2, fea1, output_attention_weights=True)
-        self.attentions = [a1, a2]
+        u2 = self.exp2(u1, fea2)
+        u3 = self.exp3(u2, fea1)
+
         loss3 = self.collapse3(self.proc3(u3))
         return loss3.view(-1, 1)
 
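After this patch the discriminator carries exactly one switch, placed at the front via create_switched_downsampler, exposed through self.switches, and every forward pass stashes its attention map in self.attentions. A hedged usage sketch for training code; it assumes the repo's codes/ directory is on PYTHONPATH and that the constructor keywords are in_nc and nf (read off the hunks above, since the __init__ signature is not shown). The direct sw.switch.set_attention_temperature call mirrors the context line visible in the ConfigurableLinearSwitchComputer hunk; the temperature value is arbitrary:

    import torch
    from models.archs.discriminator_vgg_arch import Discriminator_switched

    disc = Discriminator_switched(in_nc=3, nf=64)     # keyword names assumed
    pred = disc(torch.randn(1, 3, 128, 128))          # flattened patch logits, [N, 1]
    for i, (sw, att) in enumerate(zip(disc.switches, disc.attentions)):
        print(f"switch {i}: attention shape {tuple(att.shape)}")
        # Anneal the inner switch's softmax temperature during training.
        sw.switch.set_attention_temperature(5)

One caveat: create_switched_downsampler references MultiConvBlock, which is defined in SwitchedResidualGenerator_arch.py, while the import hunk above only touches the arch_util line; if discriminator_vgg_arch.py does not already import MultiConvBlock elsewhere, construction will fail until it is pulled in.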