From e3294939b0932dd0a7b1ad984f16f946fc3f4df3 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 3 Oct 2020 17:54:53 -0600
Subject: [PATCH] Revert "SSG: offer option to use BN-based attention normalization"

Didn't work. Oh well.

This reverts commit 5cd2b37591b6b29c8a3acb04f80713b228ab3894.
---
 .../archs/StructuredSwitchedGenerator.py      | 18 +++++--------
 .../archs/SwitchedResidualGenerator_arch.py   | 26 +++++++++++--------
 codes/models/networks.py                      |  3 +--
 3 files changed, 23 insertions(+), 24 deletions(-)

diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py
index 69991846..d92e6fd7 100644
--- a/codes/models/archs/StructuredSwitchedGenerator.py
+++ b/codes/models/archs/StructuredSwitchedGenerator.py
@@ -129,7 +129,7 @@ class ReferenceImageBranch(nn.Module):
 
 
 class SSGr1(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10, use_bn_attention_norm=False):
+    def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
         super(SSGr1, self).__init__()
         n_upscale = int(math.log(upscale, 2))
         self.nf = nf
@@ -144,16 +144,14 @@ class SSGr1(nn.Module):
         transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.25),
                                          transformation_filters, kernel_size=3, depth=4,
                                          weight_init_factor=.1)
-        use_attention_norm = not use_bn_attention_norm
 
         # Feature branch
         self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                    pre_transform_block=None, transform_block=transform_fn,
-                                                   attention_norm=use_attention_norm,
+                                                   attention_norm=True,
                                                    transform_count=self.transformation_counts, init_temp=init_temperature,
-                                                   add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
-                                                   attention_batchnorm=use_bn_attention_norm)
+                                                   add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
 
         # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
@@ -161,10 +159,9 @@ class SSGr1(nn.Module):
         self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False, kernel_size=1, depth=2)
         self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                    pre_transform_block=None, transform_block=transform_fn,
-                                                   attention_norm=use_attention_norm,
+                                                   attention_norm=True,
                                                    transform_count=self.transformation_counts // 2, init_temp=init_temperature,
-                                                   add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
-                                                   attention_batchnorm=use_bn_attention_norm)
+                                                   add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
         self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
         self.upsample_grad = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=False)
         self.grad_branch_output_conv = ConvGnLelu(nf // 2, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
@@ -173,10 +170,9 @@ class SSGr1(nn.Module):
         self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, kernel_size=1, depth=2)
         self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                    pre_transform_block=None, transform_block=transform_fn,
-                                                   attention_norm=use_attention_norm,
+                                                   attention_norm=True,
                                                    transform_count=self.transformation_counts, init_temp=init_temperature,
-                                                   add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
-                                                   attention_batchnorm=use_bn_attention_norm)
+                                                   add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
         self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
         self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True)
         self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True)
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 545e2604..7b6edfc8 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -11,6 +11,10 @@ from utils.util import checkpoint
 from models.archs.spinenet_arch import SpineNet
 
 
+# Set to true to relieve memory pressure by using utils.util in several memory-critical locations.
+memory_checkpointing_enabled = True
+
+
 # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
 # Doubles the input filter count.
 class HalvingProcessingBlock(nn.Module):
@@ -77,8 +81,8 @@ def gather_2d(input, index):
 
 
 class ConfigurableSwitchComputer(nn.Module):
-    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm=None,
-                 init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False, attention_batchnorm=None):
+    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
+                 init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False):
         super(ConfigurableSwitchComputer, self).__init__()
 
         tc = transform_count
@@ -101,11 +105,6 @@ class ConfigurableSwitchComputer(nn.Module):
         # depending on its needs.
         self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
-        if attention_batchnorm:
-            self.att_bn = nn.BatchNorm2d(transform_count)
-            self.att_relu = nn.ReLU()
-        else:
-            self.att_bn = None
 
     # Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element
     # *must* be the actual parameter that gets fed through the network - it is assumed to be the identity.
@@ -134,16 +133,21 @@ class ConfigurableSwitchComputer(nn.Module):
             x = self.pre_transform(*x)
         if not isinstance(x, tuple):
             x = (x,)
-        xformed = [checkpoint(t, *x) for t in self.transforms]
+        if memory_checkpointing_enabled:
+            xformed = [checkpoint(t, *x) for t in self.transforms]
+        else:
+            xformed = [t(*x) for t in self.transforms]
 
         if not isinstance(att_in, tuple):
             att_in = (att_in,)
         if self.feed_transforms_into_multiplexer:
             att_in = att_in + (torch.stack(xformed, dim=1),)
-        m = checkpoint(self.multiplexer, *att_in)
-        if self.att_bn:
-            m = self.att_relu(self.att_bn(m))
+        if memory_checkpointing_enabled:
+            m = checkpoint(self.multiplexer, *att_in)
+        else:
+            m = self.multiplexer(*att_in)
+
         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True)
         outputs = identity + outputs * self.switch_scale * fixed_scale
         outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
diff --git a/codes/models/networks.py b/codes/models/networks.py
index c01239db..8e309a6f 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -81,8 +81,7 @@ def define_G(opt, net_key='network_G', scale=None):
     elif which_model == "ssgr1":
         xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
         netG = ssg.SSGr1(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
-                         init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10,
-                         use_bn_attention_norm=opt_net['bn_attention_norm'] if 'bn_attention_norm' in opt_net.keys() else False)
+                         init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
     elif which_model == 'ssg_no_embedding':
         xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
         netG = ssg.SSGNoEmbedding(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
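
Editor's note (not part of the patch): the revert keeps the memory_checkpointing_enabled gate in
ConfigurableSwitchComputer.forward, which trades recompute time for lower activation memory. Below is a
minimal, self-contained sketch of that pattern, assuming torch.utils.checkpoint in place of the repo's
utils.util.checkpoint wrapper; TinySwitch, its channel counts, and the mean-reduction are hypothetical
stand-ins, not the repo's actual switch logic.

# Illustrative sketch only; assumes torch.utils.checkpoint, not the repo's wrapper.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

memory_checkpointing_enabled = True  # trade extra forward recompute for lower activation memory


class TinySwitch(nn.Module):
    def __init__(self, channels=8, transform_count=4):
        super().__init__()
        # Stand-in for the switch's transform bank.
        self.transforms = nn.ModuleList(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1) for _ in range(transform_count))

    def forward(self, x):
        if memory_checkpointing_enabled:
            # Activations inside each transform are dropped on the forward pass
            # and recomputed during backward, mirroring the restored code path.
            xformed = [checkpoint(t, x) for t in self.transforms]
        else:
            xformed = [t(x) for t in self.transforms]
        # Collapse the per-transform outputs, loosely analogous to the switch step.
        return torch.stack(xformed, dim=1).mean(dim=1)


if __name__ == '__main__':
    inp = torch.randn(2, 8, 16, 16, requires_grad=True)
    TinySwitch()(inp).mean().backward()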