forked from mrq/DL-Art-School
SSG: offer option to use BN-based attention normalization
Not sure how this is going to work, lets try it.
This commit is contained in:
parent
c896939523
commit
5cd2b37591
|
@ -129,7 +129,7 @@ class ReferenceImageBranch(nn.Module):
|
|||
|
||||
|
||||
class SSGr1(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
|
||||
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10, use_bn_attention_norm=False):
|
||||
super(SSGr1, self).__init__()
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
self.nf = nf
|
||||
|
@ -144,14 +144,16 @@ class SSGr1(nn.Module):
|
|||
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.25),
|
||||
transformation_filters, kernel_size=3, depth=4,
|
||||
weight_init_factor=.1)
|
||||
use_attention_norm = not use_bn_attention_norm
|
||||
|
||||
# Feature branch
|
||||
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=None, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
attention_norm=use_attention_norm,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
|
||||
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
|
||||
attention_batchnorm=use_bn_attention_norm)
|
||||
|
||||
# Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
|
||||
self.get_g_nopadding = ImageGradientNoPadding()
|
||||
|
@ -159,9 +161,10 @@ class SSGr1(nn.Module):
|
|||
self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False, kernel_size=1, depth=2)
|
||||
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=None, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
attention_norm=use_attention_norm,
|
||||
transform_count=self.transformation_counts // 2, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
|
||||
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
|
||||
attention_batchnorm=use_bn_attention_norm)
|
||||
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
self.upsample_grad = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=False)
|
||||
self.grad_branch_output_conv = ConvGnLelu(nf // 2, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
|
||||
|
@ -170,9 +173,10 @@ class SSGr1(nn.Module):
|
|||
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, kernel_size=1, depth=2)
|
||||
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=None, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
attention_norm=use_attention_norm,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
|
||||
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
|
||||
attention_batchnorm=use_bn_attention_norm)
|
||||
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True)
|
||||
self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True)
|
||||
|
|
|
@ -11,10 +11,6 @@ from utils.util import checkpoint
|
|||
from models.archs.spinenet_arch import SpineNet
|
||||
|
||||
|
||||
# Set to true to relieve memory pressure by using utils.util in several memory-critical locations.
|
||||
memory_checkpointing_enabled = True
|
||||
|
||||
|
||||
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
|
||||
# Doubles the input filter count.
|
||||
class HalvingProcessingBlock(nn.Module):
|
||||
|
@ -81,8 +77,8 @@ def gather_2d(input, index):
|
|||
|
||||
|
||||
class ConfigurableSwitchComputer(nn.Module):
|
||||
def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
|
||||
init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False):
|
||||
def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm=None,
|
||||
init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False, attention_batchnorm=None):
|
||||
super(ConfigurableSwitchComputer, self).__init__()
|
||||
|
||||
tc = transform_count
|
||||
|
@ -105,6 +101,11 @@ class ConfigurableSwitchComputer(nn.Module):
|
|||
# depending on its needs.
|
||||
self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
|
||||
|
||||
if attention_batchnorm:
|
||||
self.att_bn = nn.BatchNorm2d(transform_count)
|
||||
self.att_relu = nn.ReLU()
|
||||
else:
|
||||
self.att_bn = None
|
||||
|
||||
# Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element
|
||||
# *must* be the actual parameter that gets fed through the network - it is assumed to be the identity.
|
||||
|
@ -133,21 +134,16 @@ class ConfigurableSwitchComputer(nn.Module):
|
|||
x = self.pre_transform(*x)
|
||||
if not isinstance(x, tuple):
|
||||
x = (x,)
|
||||
if memory_checkpointing_enabled:
|
||||
xformed = [checkpoint(t, *x) for t in self.transforms]
|
||||
else:
|
||||
xformed = [t(*x) for t in self.transforms]
|
||||
xformed = [checkpoint(t, *x) for t in self.transforms]
|
||||
|
||||
if not isinstance(att_in, tuple):
|
||||
att_in = (att_in,)
|
||||
if self.feed_transforms_into_multiplexer:
|
||||
att_in = att_in + (torch.stack(xformed, dim=1),)
|
||||
if memory_checkpointing_enabled:
|
||||
m = checkpoint(self.multiplexer, *att_in)
|
||||
else:
|
||||
m = self.multiplexer(*att_in)
|
||||
m = checkpoint(self.multiplexer, *att_in)
|
||||
if self.att_bn:
|
||||
m = self.att_relu(self.att_bn(m))
|
||||
|
||||
# It is assumed that [xformed] and [m] are collapsed into tensors at this point.
|
||||
outputs, attention = self.switch(xformed, m, True)
|
||||
outputs = identity + outputs * self.switch_scale * fixed_scale
|
||||
outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
|
||||
|
|
|
@ -81,7 +81,8 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
elif which_model == "ssgr1":
|
||||
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
||||
netG = ssg.SSGr1(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
||||
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
||||
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10,
|
||||
use_bn_attention_norm=opt_net['bn_attention_norm'] if 'bn_attention_norm' in opt_net.keys() else False)
|
||||
elif which_model == 'ssg_no_embedding':
|
||||
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
||||
netG = ssg.SSGNoEmbedding(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
||||
|
|
Loading…
Reference in New Issue
Block a user