diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py index d92e6fd7..69991846 100644 --- a/codes/models/archs/StructuredSwitchedGenerator.py +++ b/codes/models/archs/StructuredSwitchedGenerator.py @@ -129,7 +129,7 @@ class ReferenceImageBranch(nn.Module): class SSGr1(nn.Module): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10): + def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10, use_bn_attention_norm=False): super(SSGr1, self).__init__() n_upscale = int(math.log(upscale, 2)) self.nf = nf @@ -144,14 +144,16 @@ class SSGr1(nn.Module): transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.25), transformation_filters, kernel_size=3, depth=4, weight_init_factor=.1) + use_attention_norm = not use_bn_attention_norm # Feature branch self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, + attention_norm=use_attention_norm, transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) + add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True, + attention_batchnorm=use_bn_attention_norm) # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague. self.get_g_nopadding = ImageGradientNoPadding() @@ -159,9 +161,10 @@ class SSGr1(nn.Module): self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False, kernel_size=1, depth=2) self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, + attention_norm=use_attention_norm, transform_count=self.transformation_counts // 2, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) + add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True, + attention_batchnorm=use_bn_attention_norm) self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) self.upsample_grad = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=False) self.grad_branch_output_conv = ConvGnLelu(nf // 2, out_nc, kernel_size=1, norm=False, activation=False, bias=True) @@ -170,9 +173,10 @@ class SSGr1(nn.Module): self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, kernel_size=1, depth=2) self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, + attention_norm=use_attention_norm, transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) + add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True, + attention_batchnorm=use_bn_attention_norm) self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 7b6edfc8..545e2604 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -11,10 +11,6 @@ from utils.util import checkpoint from models.archs.spinenet_arch import SpineNet -# Set to true to relieve memory pressure by using utils.util in several memory-critical locations. -memory_checkpointing_enabled = True - - # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation # Doubles the input filter count. class HalvingProcessingBlock(nn.Module): @@ -81,8 +77,8 @@ def gather_2d(input, index): class ConfigurableSwitchComputer(nn.Module): - def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm, - init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False): + def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm=None, + init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False, attention_batchnorm=None): super(ConfigurableSwitchComputer, self).__init__() tc = transform_count @@ -105,6 +101,11 @@ class ConfigurableSwitchComputer(nn.Module): # depending on its needs. self.psc_scale = nn.Parameter(torch.full((1,), float(.1))) + if attention_batchnorm: + self.att_bn = nn.BatchNorm2d(transform_count) + self.att_relu = nn.ReLU() + else: + self.att_bn = None # Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element # *must* be the actual parameter that gets fed through the network - it is assumed to be the identity. @@ -133,21 +134,16 @@ class ConfigurableSwitchComputer(nn.Module): x = self.pre_transform(*x) if not isinstance(x, tuple): x = (x,) - if memory_checkpointing_enabled: - xformed = [checkpoint(t, *x) for t in self.transforms] - else: - xformed = [t(*x) for t in self.transforms] + xformed = [checkpoint(t, *x) for t in self.transforms] if not isinstance(att_in, tuple): att_in = (att_in,) if self.feed_transforms_into_multiplexer: att_in = att_in + (torch.stack(xformed, dim=1),) - if memory_checkpointing_enabled: - m = checkpoint(self.multiplexer, *att_in) - else: - m = self.multiplexer(*att_in) + m = checkpoint(self.multiplexer, *att_in) + if self.att_bn: + m = self.att_relu(self.att_bn(m)) - # It is assumed that [xformed] and [m] are collapsed into tensors at this point. outputs, attention = self.switch(xformed, m, True) outputs = identity + outputs * self.switch_scale * fixed_scale outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale diff --git a/codes/models/networks.py b/codes/models/networks.py index 8e309a6f..c01239db 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -81,7 +81,8 @@ def define_G(opt, net_key='network_G', scale=None): elif which_model == "ssgr1": xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 netG = ssg.SSGr1(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], - init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) + init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, + use_bn_attention_norm=opt_net['bn_attention_norm'] if 'bn_attention_norm' in opt_net.keys() else False) elif which_model == 'ssg_no_embedding': xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 netG = ssg.SSGNoEmbedding(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],