Integrate RDB into SRG
The switched residual generator (SRG) now builds its transforms from ResidualDenseBlock_5C (RDB); the last RDB in each cluster is the switched one.
parent 6ac6c95177
commit e9ee67ff10
@@ -5,6 +5,7 @@ import torch.nn.functional as F
 import functools
 from collections import OrderedDict
 from models.archs.arch_util import initialize_weights
+from models.archs.RRDBNet_arch import ResidualDenseBlock_5C
 from switched_conv_util import save_attention_to_image
 
 ''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
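For reference, ResidualDenseBlock_5C is ESRGAN's five-convolution residual dense block. A minimal sketch of the standard implementation, assuming this repo's RRDBNet_arch keeps the usual nf/gc signature (the call sites below only pass nf):

import torch
import torch.nn as nn

class ResidualDenseBlock_5C(nn.Module):
    # Five 3x3 convs with dense connections; the output is a residual
    # scaled by 0.2, as in ESRGAN's RRDBNet.
    def __init__(self, nf=64, gc=32):
        super(ResidualDenseBlock_5C, self).__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x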
@@ -177,7 +178,7 @@ class SwitchComputer(nn.Module):
 
 
 class ConfigurableSwitchComputer(nn.Module):
-    def __init__(self, base_filters, multiplexer_net, transform_block, transform_count, init_temp=20,
+    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, init_temp=20,
                  enable_negative_transforms=False, add_scalable_noise_to_transforms=False, init_scalar=1):
         super(ConfigurableSwitchComputer, self).__init__()
         self.enable_negative_transforms = enable_negative_transforms
@@ -187,8 +188,10 @@ class ConfigurableSwitchComputer(nn.Module):
             tc = transform_count * 2
         self.multiplexer = multiplexer_net(tc)
 
+        self.pre_transform = pre_transform_block()
         self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
         self.add_noise = add_scalable_noise_to_transforms
+        self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))
 
         # And the switch itself, including learned scalars
         self.switch = BareConvSwitch(initial_temperature=init_temp)
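Both new constructor inputs are zero-argument factories that the computer invokes itself: pre_transform_block once, transform_block once per switched path. A hypothetical instantiation just to show the calling convention; the stand-in modules below are placeholders, not the repo's real multiplexer or transforms, and ConfigurableSwitchComputer is assumed importable from this file:

import functools
import torch.nn as nn

# Placeholder blocks purely to illustrate the factory convention.
mux_fn = lambda tc: nn.Conv2d(64, tc, kernel_size=1)           # produces tc attention channels
pre_fn = functools.partial(nn.Identity)                        # zero-arg factory
xform_fn = functools.partial(nn.Conv2d, 64, 64, 3, padding=1)  # one fresh module per path

switch = ConfigurableSwitchComputer(64, mux_fn,
                                    pre_transform_block=pre_fn,
                                    transform_block=xform_fn,
                                    transform_count=8)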
@@ -201,14 +204,15 @@ class ConfigurableSwitchComputer(nn.Module):
     def forward(self, x, output_attention_weights=False):
         identity = x
         if self.add_noise:
-            rand_feature = torch.randn_like(x)
-            xformed = [t.forward(x, rand_feature) for t in self.transforms]
-        else:
-            xformed = [t.forward(x) for t in self.transforms]
+            rand_feature = torch.randn_like(x) * self.noise_scale
+            x = x + rand_feature
+
+        x = self.pre_transform(x)
+        xformed = [t.forward(x) for t in self.transforms]
         if self.enable_negative_transforms:
             xformed.extend([-t for t in xformed])
 
-        m = self.multiplexer(x)
+        m = self.multiplexer(identity)
         # Interpolate the multiplexer across the entire shape of the image.
         m = F.interpolate(m, size=x.shape[2:], mode='nearest')
 
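Reassembled, the post-commit forward pass reads roughly as below: the noise is now a learned-magnitude perturbation of the transform input (instead of a separate argument to each transform), the shared pre_transform runs once before the per-path transforms, and the multiplexer attends over the clean identity tensor. The final switch application sits below this hunk and is paraphrased as a comment:

def forward(self, x, output_attention_weights=False):
    identity = x
    if self.add_noise:
        # Learned noise magnitude; self.noise_scale starts at 1e-3.
        x = x + torch.randn_like(x) * self.noise_scale

    x = self.pre_transform(x)                       # shared pre-transform stack
    xformed = [t.forward(x) for t in self.transforms]
    if self.enable_negative_transforms:
        xformed.extend([-t for t in xformed])

    m = self.multiplexer(identity)                  # attention from the un-noised input
    m = F.interpolate(m, size=x.shape[2:], mode='nearest')
    # ... BareConvSwitch then blends `xformed` using the attention map `m`
    # (outside this hunk, unchanged by the commit).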
@@ -361,8 +365,10 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
             multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
-                                                       functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers),
-                                                       trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms,
+                                                       pre_transform_block=functools.partial(nn.Sequential, ResidualDenseBlock_5C(transformation_filters),
+                                                                                             ResidualDenseBlock_5C(transformation_filters)),
+                                                       transform_block=functools.partial(ResidualDenseBlock_5C, transformation_filters),
+                                                       transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms,
                                                        add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01))
 
         self.switches = nn.ModuleList(switches)
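This makes each "cluster" from the commit message three RDBs deep: two shared blocks in pre_transform plus one per switched path. A standalone sketch of the resulting structure, with nf=64 standing in for transformation_filters (an assumed value):

import torch.nn as nn
from models.archs.RRDBNet_arch import ResidualDenseBlock_5C

nf = 64  # assumed stand-in for transformation_filters
pre = nn.Sequential(ResidualDenseBlock_5C(nf), ResidualDenseBlock_5C(nf))  # shared by all paths
paths = nn.ModuleList([ResidualDenseBlock_5C(nf) for _ in range(8)])       # one per switched path
# The switch's attention map selects among `paths`, so only the last RDB
# of each three-RDB cluster is switched.

Note the asymmetry in the diff: pre_transform_block captures two concrete ResidualDenseBlock_5C instances inside an nn.Sequential partial, while transform_block is a partial over the class itself, so each of the trans_count switched paths constructs its own RDB.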
@@ -375,7 +381,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         self.upsample_factor = upsample_factor
 
     def forward(self, x):
-
         x = self.initial_conv(x)
 
         self.attentions = []