Fix SRG4 & switch disc
"fix". hehe.
This commit is contained in:
parent
e6e91a1d75
commit
b06e1784e1
|
@ -139,7 +139,8 @@ class ConfigurableSwitchComputer(nn.Module):
|
|||
rand_feature = torch.randn_like(x) * self.noise_scale
|
||||
x = x + rand_feature
|
||||
|
||||
x = self.pre_transform(x)
|
||||
if self.pre_transform:
|
||||
x = self.pre_transform(x)
|
||||
xformed = [t.forward(x) for t in self.transforms]
|
||||
m = self.multiplexer(identity)
|
||||
|
||||
|
@ -255,6 +256,8 @@ class ConfigurableSwitchedResidualGenerator4(nn.Module):
|
|||
|
||||
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
|
||||
switch_processing_layers, trans_counts)
|
||||
half_multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
|
||||
switch_processing_layers, trans_counts // 2)
|
||||
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
|
||||
transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers,
|
||||
weight_init_factor=.1)
|
||||
|
@ -265,12 +268,19 @@ class ConfigurableSwitchedResidualGenerator4(nn.Module):
|
|||
transform_count=trans_counts, init_temp=initial_temp,
|
||||
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
|
||||
self.rdb2 = RRDB(transformation_filters)
|
||||
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
self.sw2 = ConfigurableSwitchComputer(transformation_filters, half_multiplx_fn,
|
||||
pre_transform_block=None, transform_block=transform_fn,
|
||||
attention_norm=attention_norm,
|
||||
transform_count=trans_counts // 2, init_temp=initial_temp,
|
||||
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
|
||||
self.rdb3 = RRDB(transformation_filters)
|
||||
self.sw3 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=None, transform_block=transform_fn,
|
||||
attention_norm=attention_norm,
|
||||
transform_count=trans_counts, init_temp=initial_temp,
|
||||
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
|
||||
self.rdb3 = RRDB(transformation_filters)
|
||||
self.rdb4 = RRDB(transformation_filters)
|
||||
self.switches = [self.sw1, self.sw2, self.sw3]
|
||||
|
||||
self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
|
||||
self.transformation_counts = trans_counts
|
||||
|
@ -290,10 +300,13 @@ class ConfigurableSwitchedResidualGenerator4(nn.Module):
|
|||
x = self.initial_conv(x)
|
||||
|
||||
x = self.rdb1(x)
|
||||
x = self.sw1(x, True)
|
||||
x, a1 = self.sw1(x, True)
|
||||
x = self.rdb2(x)
|
||||
x = self.sw2(x, True)
|
||||
x, a2 = self.sw2(x, True)
|
||||
x = self.rdb3(x)
|
||||
x, a3 = self.sw3(x, True)
|
||||
x = self.rdb4(x)
|
||||
self.attentions = [a1, a2, a3]
|
||||
|
||||
x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
|
||||
if self.upsample_factor > 2:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock
|
||||
from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
|
@ -244,15 +244,33 @@ from switched_conv_util import save_attention_to_image
|
|||
from switched_conv import compute_attention_specificity, AttentionNorm
|
||||
|
||||
|
||||
class ExpandAndCollapse(nn.Module):
|
||||
def __init__(self, nf, nf_out, num_channels):
|
||||
super(ExpandAndCollapse, self).__init__()
|
||||
self.expand = ExpansionBlock(nf, nf_out, block=ConvGnLelu)
|
||||
self.collapse = ConvGnLelu(nf_out, num_channels, norm=False, bias=False, activation=False)
|
||||
class ReducingMultiplexer(nn.Module):
|
||||
def __init__(self, nf, num_channels):
|
||||
super(ReducingMultiplexer, self).__init__()
|
||||
self.conv1_0 = ConvGnSilu(nf, nf * 2, kernel_size=3, bias=False)
|
||||
self.conv1_1 = ConvGnSilu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
|
||||
# [128, 32, 32]
|
||||
self.conv2_0 = ConvGnSilu(nf * 2, nf * 4, kernel_size=3, bias=False)
|
||||
self.conv2_1 = ConvGnSilu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
|
||||
# [256, 16, 16]
|
||||
self.conv3_0 = ConvGnSilu(nf * 4, nf * 8, kernel_size=3, bias=False)
|
||||
self.conv3_1 = ConvGnSilu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
|
||||
self.exp1 = ExpansionBlock(nf * 8, nf * 4)
|
||||
self.exp2 = ExpansionBlock(nf * 4, nf * 2)
|
||||
self.exp3 = ExpansionBlock(nf * 2, nf)
|
||||
self.collapse = ConvGnSilu(nf, num_channels, norm=False, bias=True)
|
||||
|
||||
def forward(self, x, passthrough):
|
||||
x = self.expand(x, passthrough)
|
||||
return self.collapse(x)
|
||||
def forward(self, x):
|
||||
fea1 = self.conv1_0(x)
|
||||
fea1 = self.conv1_1(fea1)
|
||||
fea2 = self.conv2_0(fea1)
|
||||
fea2 = self.conv2_1(fea2)
|
||||
fea3 = self.conv3_0(fea2)
|
||||
fea3 = self.conv3_1(fea3)
|
||||
up = self.exp1(fea3, fea2)
|
||||
up = self.exp2(up, fea1)
|
||||
up = self.exp3(up, x)
|
||||
return self.collapse(up)
|
||||
|
||||
|
||||
# Differs from ConfigurableSwitchComputer in that the connections are not residual and the multiplexer is fed directly in.
|
||||
|
@ -274,16 +292,15 @@ class ConfigurableLinearSwitchComputer(nn.Module):
|
|||
# depending on its needs.
|
||||
self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
|
||||
|
||||
def forward(self, x, passthrough, output_attention_weights=False, extra_arg=None):
|
||||
identity = x
|
||||
def forward(self, x, output_attention_weights=False, extra_arg=None):
|
||||
if self.add_noise:
|
||||
rand_feature = torch.randn_like(x) * self.noise_scale
|
||||
x = x + rand_feature
|
||||
|
||||
x = self.pre_transform(x)
|
||||
xformed = [t.forward(x, passthrough) for t in self.transforms]
|
||||
m = self.multiplexer(identity, passthrough)
|
||||
|
||||
if self.pre_transform:
|
||||
x = self.pre_transform(x)
|
||||
xformed = [t.forward(x) for t in self.transforms]
|
||||
m = self.multiplexer(x)
|
||||
|
||||
outputs, attention = self.switch(xformed, m, True)
|
||||
outputs = self.post_switch_conv(outputs)
|
||||
|
@ -296,10 +313,10 @@ class ConfigurableLinearSwitchComputer(nn.Module):
|
|||
self.switch.set_attention_temperature(temp)
|
||||
|
||||
|
||||
def create_switched_upsampler(nf, nf_out, num_channels, initial_temp=10):
|
||||
multiplx = ExpandAndCollapse(nf, nf_out, num_channels)
|
||||
pretransform = ConvGnLelu(nf, nf, norm=True, bias=False)
|
||||
transform_fn = functools.partial(ExpansionBlock, nf, nf_out, block=ConvGnLelu)
|
||||
def create_switched_downsampler(nf, nf_out, num_channels, initial_temp=10):
|
||||
multiplx = ReducingMultiplexer(nf, num_channels)
|
||||
pretransform = None
|
||||
transform_fn = functools.partial(MultiConvBlock, nf, nf, nf_out, kernel_size=3, depth=2)
|
||||
return ConfigurableLinearSwitchComputer(nf_out, multiplx,
|
||||
pre_transform_block=pretransform, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
|
@ -314,8 +331,9 @@ class Discriminator_switched(nn.Module):
|
|||
self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
|
||||
self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
|
||||
# [64, 64, 64]
|
||||
self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
|
||||
self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
|
||||
self.sw = create_switched_downsampler(nf, nf, 8)
|
||||
self.switches = [self.sw]
|
||||
self.conv1_1 = ConvGnLelu(nf, nf * 2, kernel_size=3, stride=2, bias=False)
|
||||
# [128, 32, 32]
|
||||
self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
|
||||
self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
|
||||
|
@ -327,9 +345,8 @@ class Discriminator_switched(nn.Module):
|
|||
self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
|
||||
|
||||
self.exp1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
|
||||
self.upsw2 = create_switched_upsampler(nf * 8, nf * 4, 8)
|
||||
self.upsw3 = create_switched_upsampler(nf * 4, nf * 2, 8)
|
||||
self.switches = [self.upsw2, self.upsw3]
|
||||
self.exp2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
|
||||
self.exp3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
|
||||
self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
|
||||
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
|
@ -341,7 +358,8 @@ class Discriminator_switched(nn.Module):
|
|||
fea0 = self.conv0_0(x)
|
||||
fea0 = self.conv0_1(fea0)
|
||||
|
||||
fea1 = self.conv1_0(fea0)
|
||||
fea1, att = self.sw(fea0, True)
|
||||
self.attentions = [att]
|
||||
fea1 = self.conv1_1(fea1)
|
||||
|
||||
fea2 = self.conv2_0(fea1)
|
||||
|
@ -354,9 +372,9 @@ class Discriminator_switched(nn.Module):
|
|||
fea4 = self.conv4_1(fea4)
|
||||
|
||||
u1 = self.exp1(fea4, fea3)
|
||||
u2, a1 = self.upsw2(u1, fea2, output_attention_weights=True)
|
||||
u3, a2 = self.upsw3(u2, fea1, output_attention_weights=True)
|
||||
self.attentions = [a1, a2]
|
||||
u2 = self.exp2(u1, fea2)
|
||||
u3 = self.exp3(u2, fea1)
|
||||
|
||||
loss3 = self.collapse3(self.proc3(u3))
|
||||
return loss3.view(-1, 1)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user