Revert "SSG: offer option to use BN-based attention normalization"

Didn't work. Oh well.

This reverts commit 5cd2b37591.
This commit is contained in:
James Betker 2020-10-03 17:54:53 -06:00
parent 43c6c67fd1
commit e3294939b0
3 changed files with 23 additions and 24 deletions

View File

@ -129,7 +129,7 @@ class ReferenceImageBranch(nn.Module):
class SSGr1(nn.Module):
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10, use_bn_attention_norm=False):
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
super(SSGr1, self).__init__()
n_upscale = int(math.log(upscale, 2))
self.nf = nf
@ -144,16 +144,14 @@ class SSGr1(nn.Module):
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.25),
transformation_filters, kernel_size=3, depth=4,
weight_init_factor=.1)
use_attention_norm = not use_bn_attention_norm
# Feature branch
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=None, transform_block=transform_fn,
attention_norm=use_attention_norm,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
attention_batchnorm=use_bn_attention_norm)
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
# Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
self.get_g_nopadding = ImageGradientNoPadding()
@ -161,10 +159,9 @@ class SSGr1(nn.Module):
self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False, kernel_size=1, depth=2)
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=None, transform_block=transform_fn,
attention_norm=use_attention_norm,
attention_norm=True,
transform_count=self.transformation_counts // 2, init_temp=init_temperature,
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
attention_batchnorm=use_bn_attention_norm)
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.upsample_grad = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=False)
self.grad_branch_output_conv = ConvGnLelu(nf // 2, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
@ -173,10 +170,9 @@ class SSGr1(nn.Module):
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, kernel_size=1, depth=2)
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=None, transform_block=transform_fn,
attention_norm=use_attention_norm,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
attention_batchnorm=use_bn_attention_norm)
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True)
self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True)

View File

@ -11,6 +11,10 @@ from utils.util import checkpoint
from models.archs.spinenet_arch import SpineNet
# Set to true to relieve memory pressure by using utils.util in several memory-critical locations.
memory_checkpointing_enabled = True
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
# Doubles the input filter count.
class HalvingProcessingBlock(nn.Module):
@ -77,8 +81,8 @@ def gather_2d(input, index):
class ConfigurableSwitchComputer(nn.Module):
def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm=None,
init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False, attention_batchnorm=None):
def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False):
super(ConfigurableSwitchComputer, self).__init__()
tc = transform_count
@ -101,11 +105,6 @@ class ConfigurableSwitchComputer(nn.Module):
# depending on its needs.
self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
if attention_batchnorm:
self.att_bn = nn.BatchNorm2d(transform_count)
self.att_relu = nn.ReLU()
else:
self.att_bn = None
# Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element
# *must* be the actual parameter that gets fed through the network - it is assumed to be the identity.
@ -134,16 +133,21 @@ class ConfigurableSwitchComputer(nn.Module):
x = self.pre_transform(*x)
if not isinstance(x, tuple):
x = (x,)
if memory_checkpointing_enabled:
xformed = [checkpoint(t, *x) for t in self.transforms]
else:
xformed = [t(*x) for t in self.transforms]
if not isinstance(att_in, tuple):
att_in = (att_in,)
if self.feed_transforms_into_multiplexer:
att_in = att_in + (torch.stack(xformed, dim=1),)
if memory_checkpointing_enabled:
m = checkpoint(self.multiplexer, *att_in)
if self.att_bn:
m = self.att_relu(self.att_bn(m))
else:
m = self.multiplexer(*att_in)
# It is assumed that [xformed] and [m] are collapsed into tensors at this point.
outputs, attention = self.switch(xformed, m, True)
outputs = identity + outputs * self.switch_scale * fixed_scale
outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale

View File

@ -81,8 +81,7 @@ def define_G(opt, net_key='network_G', scale=None):
elif which_model == "ssgr1":
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = ssg.SSGr1(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10,
use_bn_attention_norm=opt_net['bn_attention_norm'] if 'bn_attention_norm' in opt_net.keys() else False)
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
elif which_model == 'ssg_no_embedding':
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = ssg.SSGNoEmbedding(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],