diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index faae11c4..3eb647ed 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -5,6 +5,7 @@ import torch.nn.functional as F
 import functools
 from collections import OrderedDict
 from models.archs.arch_util import initialize_weights
+from models.archs.RRDBNet_arch import ResidualDenseBlock_5C
 from switched_conv_util import save_attention_to_image
 '''
 Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
@@ -177,7 +178,7 @@ class SwitchComputer(nn.Module):
 
 
 class ConfigurableSwitchComputer(nn.Module):
-    def __init__(self, base_filters, multiplexer_net, transform_block, transform_count, init_temp=20,
+    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, init_temp=20,
                  enable_negative_transforms=False, add_scalable_noise_to_transforms=False, init_scalar=1):
         super(ConfigurableSwitchComputer, self).__init__()
         self.enable_negative_transforms = enable_negative_transforms
@@ -187,8 +188,10 @@ class ConfigurableSwitchComputer(nn.Module):
             tc = transform_count * 2
         self.multiplexer = multiplexer_net(tc)
+        self.pre_transform = pre_transform_block()
         self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
         self.add_noise = add_scalable_noise_to_transforms
+        self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))
 
         # And the switch itself, including learned scalars
         self.switch = BareConvSwitch(initial_temperature=init_temp)
 
@@ -201,14 +204,15 @@ class ConfigurableSwitchComputer(nn.Module):
     def forward(self, x, output_attention_weights=False):
         identity = x
         if self.add_noise:
-            rand_feature = torch.randn_like(x)
-            xformed = [t.forward(x, rand_feature) for t in self.transforms]
-        else:
-            xformed = [t.forward(x) for t in self.transforms]
+            rand_feature = torch.randn_like(x) * self.noise_scale
+            x = x + rand_feature
+
+        x = self.pre_transform(x)
+        xformed = [t.forward(x) for t in self.transforms]
         if self.enable_negative_transforms:
             xformed.extend([-t for t in xformed])
 
-        m = self.multiplexer(x)
+        m = self.multiplexer(identity)
 
         # Interpolate the multiplexer across the entire shape of the image.
         m = F.interpolate(m, size=x.shape[2:], mode='nearest')
@@ -361,8 +365,10 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
             multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
-                                                       functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers),
-                                                       trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms,
+                                                       pre_transform_block=functools.partial(nn.Sequential, ResidualDenseBlock_5C(transformation_filters),
+                                                                                             ResidualDenseBlock_5C(transformation_filters)),
+                                                       transform_block=functools.partial(ResidualDenseBlock_5C, transformation_filters),
+                                                       transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms,
                                                        add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01))
         self.switches = nn.ModuleList(switches)
 
@@ -375,7 +381,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         self.upsample_factor = upsample_factor
 
     def forward(self, x):
-        x = self.initial_conv(x)
         self.attentions = []
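
Taken together, the `ConfigurableSwitchComputer` hunks change the forward pass in three ways: noise is injected once into the input and scaled by a learned `noise_scale` parameter (rather than being handed to every transform as an extra argument), a shared `pre_transform` trunk runs before the per-transform branches, and the multiplexer now sees the untouched `identity` input instead of the noised, pre-transformed features. A minimal standalone sketch of that data flow, under stated assumptions: `TinyBlock`, `SwitchSketch`, and the plain softmax are illustrative stand-ins for `ResidualDenseBlock_5C` and `BareConvSwitch`, not the repo's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyBlock(nn.Module):
    """Hypothetical stand-in for ResidualDenseBlock_5C (not the repo's class)."""
    def __init__(self, filters):
        super().__init__()
        self.conv = nn.Conv2d(filters, filters, 3, padding=1)

    def forward(self, x):
        return x + F.relu(self.conv(x))

class SwitchSketch(nn.Module):
    def __init__(self, filters=16, transform_count=4):
        super().__init__()
        # Learned noise amplitude, initialized small like the diff's 1e-3.
        self.noise_scale = nn.Parameter(torch.full((1,), 1e-3))
        # Shared trunk applied once, before the per-transform branches.
        self.pre_transform = nn.Sequential(TinyBlock(filters), TinyBlock(filters))
        self.transforms = nn.ModuleList(TinyBlock(filters) for _ in range(transform_count))
        # Stand-in for BareConvSwitch: per-pixel logits over the transforms.
        self.multiplexer = nn.Conv2d(filters, transform_count, 1)

    def forward(self, x):
        identity = x
        # Noise is injected once into the input, scaled by a learned parameter,
        # instead of being passed separately to every transform.
        x = x + torch.randn_like(x) * self.noise_scale
        x = self.pre_transform(x)
        xformed = [t(x) for t in self.transforms]
        # The multiplexer sees the clean identity input, not the noised features.
        m = torch.softmax(self.multiplexer(identity), dim=1)   # (B, T, H, W)
        stacked = torch.stack(xformed, dim=1)                  # (B, T, C, H, W)
        return (stacked * m.unsqueeze(2)).sum(dim=1)           # (B, C, H, W)

out = SwitchSketch()(torch.randn(2, 16, 8, 8))  # torch.Size([2, 16, 8, 8])
```

One plausible reading of the `m = self.multiplexer(identity)` change is that it decouples attention computation from the noise injection: the selection logits are computed from the clean input, so they cannot chase the random perturbation added to the transform path.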
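A wiring detail in the `ConfigurableSwitchedResidualGenerator2` hunk is easy to misread: `pre_transform_block` is a `functools.partial` closed over two already-constructed `ResidualDenseBlock_5C` instances, so every call to it returns an `nn.Sequential` wrapping those same module objects (shared weights), whereas `transform_block` partials the constructor and builds a fresh block per call. That is safe here because each `ConfigurableSwitchComputer` invokes `pre_transform_block()` exactly once, but the asymmetry matters if the factory is ever called again. A small demonstration of the two behaviors, using a plain `nn.Conv2d` as a stand-in:

```python
import functools
import torch.nn as nn

block = nn.Conv2d(8, 8, 3, padding=1)

# Partial over an *instance*: every call wraps the same module (shared weights).
shared = functools.partial(nn.Sequential, block)
assert shared()[0] is shared()[0]

# Partial over the *constructor*: every call builds a fresh module.
fresh = functools.partial(nn.Conv2d, 8, 8, 3, padding=1)
assert fresh() is not fresh()
```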