Fix SRG4 & switch disc

"fix". hehe.
This commit is contained in:
James Betker 2020-07-25 17:16:54 -06:00
parent e6e91a1d75
commit b06e1784e1
2 changed files with 64 additions and 33 deletions

View File

@ -139,6 +139,7 @@ class ConfigurableSwitchComputer(nn.Module):
rand_feature = torch.randn_like(x) * self.noise_scale rand_feature = torch.randn_like(x) * self.noise_scale
x = x + rand_feature x = x + rand_feature
if self.pre_transform:
x = self.pre_transform(x) x = self.pre_transform(x)
xformed = [t.forward(x) for t in self.transforms] xformed = [t.forward(x) for t in self.transforms]
m = self.multiplexer(identity) m = self.multiplexer(identity)
@ -255,6 +256,8 @@ class ConfigurableSwitchedResidualGenerator4(nn.Module):
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
switch_processing_layers, trans_counts) switch_processing_layers, trans_counts)
half_multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
switch_processing_layers, trans_counts // 2)
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers,
weight_init_factor=.1) weight_init_factor=.1)
@ -265,12 +268,19 @@ class ConfigurableSwitchedResidualGenerator4(nn.Module):
transform_count=trans_counts, init_temp=initial_temp, transform_count=trans_counts, init_temp=initial_temp,
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms) add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
self.rdb2 = RRDB(transformation_filters) self.rdb2 = RRDB(transformation_filters)
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, self.sw2 = ConfigurableSwitchComputer(transformation_filters, half_multiplx_fn,
pre_transform_block=None, transform_block=transform_fn,
attention_norm=attention_norm,
transform_count=trans_counts // 2, init_temp=initial_temp,
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
self.rdb3 = RRDB(transformation_filters)
self.sw3 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=None, transform_block=transform_fn, pre_transform_block=None, transform_block=transform_fn,
attention_norm=attention_norm, attention_norm=attention_norm,
transform_count=trans_counts, init_temp=initial_temp, transform_count=trans_counts, init_temp=initial_temp,
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms) add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
self.rdb3 = RRDB(transformation_filters) self.rdb4 = RRDB(transformation_filters)
self.switches = [self.sw1, self.sw2, self.sw3]
self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True) self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
self.transformation_counts = trans_counts self.transformation_counts = trans_counts
@ -290,10 +300,13 @@ class ConfigurableSwitchedResidualGenerator4(nn.Module):
x = self.initial_conv(x) x = self.initial_conv(x)
x = self.rdb1(x) x = self.rdb1(x)
x = self.sw1(x, True) x, a1 = self.sw1(x, True)
x = self.rdb2(x) x = self.rdb2(x)
x = self.sw2(x, True) x, a2 = self.sw2(x, True)
x = self.rdb3(x) x = self.rdb3(x)
x, a3 = self.sw3(x, True)
x = self.rdb4(x)
self.attentions = [a1, a2, a3]
x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest")) x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
if self.upsample_factor > 2: if self.upsample_factor > 2:

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu
import torch.nn.functional as F import torch.nn.functional as F
@ -244,15 +244,33 @@ from switched_conv_util import save_attention_to_image
from switched_conv import compute_attention_specificity, AttentionNorm from switched_conv import compute_attention_specificity, AttentionNorm
class ExpandAndCollapse(nn.Module): class ReducingMultiplexer(nn.Module):
def __init__(self, nf, nf_out, num_channels): def __init__(self, nf, num_channels):
super(ExpandAndCollapse, self).__init__() super(ReducingMultiplexer, self).__init__()
self.expand = ExpansionBlock(nf, nf_out, block=ConvGnLelu) self.conv1_0 = ConvGnSilu(nf, nf * 2, kernel_size=3, bias=False)
self.collapse = ConvGnLelu(nf_out, num_channels, norm=False, bias=False, activation=False) self.conv1_1 = ConvGnSilu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
# [128, 32, 32]
self.conv2_0 = ConvGnSilu(nf * 2, nf * 4, kernel_size=3, bias=False)
self.conv2_1 = ConvGnSilu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
# [256, 16, 16]
self.conv3_0 = ConvGnSilu(nf * 4, nf * 8, kernel_size=3, bias=False)
self.conv3_1 = ConvGnSilu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
self.exp1 = ExpansionBlock(nf * 8, nf * 4)
self.exp2 = ExpansionBlock(nf * 4, nf * 2)
self.exp3 = ExpansionBlock(nf * 2, nf)
self.collapse = ConvGnSilu(nf, num_channels, norm=False, bias=True)
def forward(self, x, passthrough): def forward(self, x):
x = self.expand(x, passthrough) fea1 = self.conv1_0(x)
return self.collapse(x) fea1 = self.conv1_1(fea1)
fea2 = self.conv2_0(fea1)
fea2 = self.conv2_1(fea2)
fea3 = self.conv3_0(fea2)
fea3 = self.conv3_1(fea3)
up = self.exp1(fea3, fea2)
up = self.exp2(up, fea1)
up = self.exp3(up, x)
return self.collapse(up)
# Differs from ConfigurableSwitchComputer in that the connections are not residual and the multiplexer is fed directly in. # Differs from ConfigurableSwitchComputer in that the connections are not residual and the multiplexer is fed directly in.
@ -274,16 +292,15 @@ class ConfigurableLinearSwitchComputer(nn.Module):
# depending on its needs. # depending on its needs.
self.psc_scale = nn.Parameter(torch.full((1,), float(.1))) self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
def forward(self, x, passthrough, output_attention_weights=False, extra_arg=None): def forward(self, x, output_attention_weights=False, extra_arg=None):
identity = x
if self.add_noise: if self.add_noise:
rand_feature = torch.randn_like(x) * self.noise_scale rand_feature = torch.randn_like(x) * self.noise_scale
x = x + rand_feature x = x + rand_feature
if self.pre_transform:
x = self.pre_transform(x) x = self.pre_transform(x)
xformed = [t.forward(x, passthrough) for t in self.transforms] xformed = [t.forward(x) for t in self.transforms]
m = self.multiplexer(identity, passthrough) m = self.multiplexer(x)
outputs, attention = self.switch(xformed, m, True) outputs, attention = self.switch(xformed, m, True)
outputs = self.post_switch_conv(outputs) outputs = self.post_switch_conv(outputs)
@ -296,10 +313,10 @@ class ConfigurableLinearSwitchComputer(nn.Module):
self.switch.set_attention_temperature(temp) self.switch.set_attention_temperature(temp)
def create_switched_upsampler(nf, nf_out, num_channels, initial_temp=10): def create_switched_downsampler(nf, nf_out, num_channels, initial_temp=10):
multiplx = ExpandAndCollapse(nf, nf_out, num_channels) multiplx = ReducingMultiplexer(nf, num_channels)
pretransform = ConvGnLelu(nf, nf, norm=True, bias=False) pretransform = None
transform_fn = functools.partial(ExpansionBlock, nf, nf_out, block=ConvGnLelu) transform_fn = functools.partial(MultiConvBlock, nf, nf, nf_out, kernel_size=3, depth=2)
return ConfigurableLinearSwitchComputer(nf_out, multiplx, return ConfigurableLinearSwitchComputer(nf_out, multiplx,
pre_transform_block=pretransform, transform_block=transform_fn, pre_transform_block=pretransform, transform_block=transform_fn,
attention_norm=True, attention_norm=True,
@ -314,8 +331,9 @@ class Discriminator_switched(nn.Module):
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False) self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False) self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
# [64, 64, 64] # [64, 64, 64]
self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False) self.sw = create_switched_downsampler(nf, nf, 8)
self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False) self.switches = [self.sw]
self.conv1_1 = ConvGnLelu(nf, nf * 2, kernel_size=3, stride=2, bias=False)
# [128, 32, 32] # [128, 32, 32]
self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False) self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False) self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
@ -327,9 +345,8 @@ class Discriminator_switched(nn.Module):
self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False) self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
self.exp1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu) self.exp1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
self.upsw2 = create_switched_upsampler(nf * 8, nf * 4, 8) self.exp2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
self.upsw3 = create_switched_upsampler(nf * 4, nf * 2, 8) self.exp3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
self.switches = [self.upsw2, self.upsw3]
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False) self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False) self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
@ -341,7 +358,8 @@ class Discriminator_switched(nn.Module):
fea0 = self.conv0_0(x) fea0 = self.conv0_0(x)
fea0 = self.conv0_1(fea0) fea0 = self.conv0_1(fea0)
fea1 = self.conv1_0(fea0) fea1, att = self.sw(fea0, True)
self.attentions = [att]
fea1 = self.conv1_1(fea1) fea1 = self.conv1_1(fea1)
fea2 = self.conv2_0(fea1) fea2 = self.conv2_0(fea1)
@ -354,9 +372,9 @@ class Discriminator_switched(nn.Module):
fea4 = self.conv4_1(fea4) fea4 = self.conv4_1(fea4)
u1 = self.exp1(fea4, fea3) u1 = self.exp1(fea4, fea3)
u2, a1 = self.upsw2(u1, fea2, output_attention_weights=True) u2 = self.exp2(u1, fea2)
u3, a2 = self.upsw3(u2, fea1, output_attention_weights=True) u3 = self.exp3(u2, fea1)
self.attentions = [a1, a2]
loss3 = self.collapse3(self.proc3(u3)) loss3 = self.collapse3(self.proc3(u3))
return loss3.view(-1, 1) return loss3.view(-1, 1)