SSG: offer option to use BN-based attention normalization

Not sure how this is going to work; let's try it.
James Betker 2020-10-03 16:16:19 -06:00
parent c896939523
commit 5cd2b37591
3 changed files with 24 additions and 23 deletions

View File

@@ -129,7 +129,7 @@ class ReferenceImageBranch(nn.Module):
 class SSGr1(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
+    def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10, use_bn_attention_norm=False):
         super(SSGr1, self).__init__()
         n_upscale = int(math.log(upscale, 2))
         self.nf = nf
@@ -144,14 +144,16 @@ class SSGr1(nn.Module):
         transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.25),
                                          transformation_filters, kernel_size=3, depth=4,
                                          weight_init_factor=.1)
+        use_attention_norm = not use_bn_attention_norm

         # Feature branch
         self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=None, transform_block=transform_fn,
-                                              attention_norm=True,
+                                              attention_norm=use_attention_norm,
                                               transform_count=self.transformation_counts, init_temp=init_temperature,
-                                              add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
+                                              add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
+                                              attention_batchnorm=use_bn_attention_norm)

         # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
@@ -159,9 +161,10 @@ class SSGr1(nn.Module):
         self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False, kernel_size=1, depth=2)
         self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                   pre_transform_block=None, transform_block=transform_fn,
-                                                  attention_norm=True,
+                                                  attention_norm=use_attention_norm,
                                                   transform_count=self.transformation_counts // 2, init_temp=init_temperature,
-                                                  add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
+                                                  add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
+                                                  attention_batchnorm=use_bn_attention_norm)
         self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
         self.upsample_grad = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=False)
         self.grad_branch_output_conv = ConvGnLelu(nf // 2, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
@@ -170,9 +173,10 @@ class SSGr1(nn.Module):
         self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, kernel_size=1, depth=2)
         self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                      pre_transform_block=None, transform_block=transform_fn,
-                                                     attention_norm=True,
+                                                     attention_norm=use_attention_norm,
                                                      transform_count=self.transformation_counts, init_temp=init_temperature,
-                                                     add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
+                                                     add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
+                                                     attention_batchnorm=use_bn_attention_norm)
         self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
         self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True)
         self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True)
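Net effect of this file: use_bn_attention_norm switches sw1, sw_grad and conjoin_sw from the existing attention_norm path to the new attention_batchnorm path. A minimal usage sketch follows; the import path and hyperparameter values are illustrative assumptions, only the keyword names come from this diff.

# Sketch (assumption): constructing SSGr1 with BN-based attention normalization enabled.
import models.archs.ssg as ssg  # assumed import path, mirroring how networks.py refers to ssg.SSGr1

netG = ssg.SSGr1(in_nc=3, out_nc=3, nf=64, xforms=8, upscale=4,
                 init_temperature=10,
                 use_bn_attention_norm=True)  # routes attention_batchnorm=True into every switch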

View File

@@ -11,10 +11,6 @@ from utils.util import checkpoint
 from models.archs.spinenet_arch import SpineNet

-# Set to true to relieve memory pressure by using utils.util in several memory-critical locations.
-memory_checkpointing_enabled = True
-
-
 # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
 # Doubles the input filter count.
 class HalvingProcessingBlock(nn.Module):
@@ -81,8 +77,8 @@ def gather_2d(input, index):
 class ConfigurableSwitchComputer(nn.Module):
-    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
-                 init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False):
+    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm=None,
+                 init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False, attention_batchnorm=None):
         super(ConfigurableSwitchComputer, self).__init__()

         tc = transform_count
@@ -105,6 +101,11 @@ class ConfigurableSwitchComputer(nn.Module):
         # depending on its needs.
         self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
+        if attention_batchnorm:
+            self.att_bn = nn.BatchNorm2d(transform_count)
+            self.att_relu = nn.ReLU()
+        else:
+            self.att_bn = None

     # Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element
     # *must* be the actual parameter that gets fed through the network - it is assumed to be the identity.
@@ -133,21 +134,16 @@ class ConfigurableSwitchComputer(nn.Module):
             x = self.pre_transform(*x)
         if not isinstance(x, tuple):
             x = (x,)
-        if memory_checkpointing_enabled:
-            xformed = [checkpoint(t, *x) for t in self.transforms]
-        else:
-            xformed = [t(*x) for t in self.transforms]
+        xformed = [checkpoint(t, *x) for t in self.transforms]

         if not isinstance(att_in, tuple):
            att_in = (att_in,)
         if self.feed_transforms_into_multiplexer:
             att_in = att_in + (torch.stack(xformed, dim=1),)
-        if memory_checkpointing_enabled:
-            m = checkpoint(self.multiplexer, *att_in)
-        else:
-            m = self.multiplexer(*att_in)
-        # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
+        m = checkpoint(self.multiplexer, *att_in)
+        if self.att_bn:
+            m = self.att_relu(self.att_bn(m))

         outputs, attention = self.switch(xformed, m, True)
         outputs = identity + outputs * self.switch_scale * fixed_scale
         outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
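Taken on its own, the new path is just BatchNorm2d over the transform dimension followed by ReLU, applied to the multiplexer output before it reaches the switch. A standalone sketch with assumed shapes; the multiplexer and switch themselves are not reproduced here.

# Standalone sketch (assumption): effect of attention_batchnorm on the multiplexer logits.
import torch
import torch.nn as nn

transform_count = 8
att_bn = nn.BatchNorm2d(transform_count)     # normalizes the per-transform attention logits
att_relu = nn.ReLU()

m = torch.randn(4, transform_count, 16, 16)  # assumed shape: (batch, transform_count, H, W)
m = att_relu(att_bn(m))                      # applied in place of the usual attention_norm behavior when enabled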

View File

@@ -81,7 +81,8 @@ def define_G(opt, net_key='network_G', scale=None):
     elif which_model == "ssgr1":
         xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
         netG = ssg.SSGr1(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
-                         init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
+                         init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10,
+                         use_bn_attention_norm=opt_net['bn_attention_norm'] if 'bn_attention_norm' in opt_net.keys() else False)
     elif which_model == 'ssg_no_embedding':
         xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
         netG = ssg.SSGNoEmbedding(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
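Finally, define_G() exposes the option through a new bn_attention_norm key in the network options, defaulting to False. A sketch of the keys it consults for the ssgr1 model; the values, and any keys not shown in the hunk above, are illustrative assumptions.

# Sketch (assumption): opt_net keys read by define_G() for which_model == "ssgr1".
opt_net = {
    'nf': 64,
    'scale': 4,
    'num_transforms': 8,
    'temperature': 10,
    'bn_attention_norm': True,  # new: enables BN-based attention normalization in SSGr1
}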