Make attention norm optional

James Betker 2020-07-18 07:24:02 -06:00
parent ad97a6a18a
commit 47a525241f
3 changed files with 10 additions and 21 deletions
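In short: ConfigurableSwitchComputer and ConfigurableSwitchedResidualGenerator2 now take an attention_norm flag, forward() gains a fixed_scale multiplier, and the backwards-compatibility load_state_dict() shim for the old accumulator buffers is removed. A rough sketch of the new constructor surface (a hypothetical call site; the import path and all argument values are illustrative assumptions, not taken from this commit):

    # Hypothetical usage sketch -- module path and numeric values are assumptions.
    from models.archs.SwitchedResidualGenerator_arch import ConfigurableSwitchedResidualGenerator2

    netG = ConfigurableSwitchedResidualGenerator2(
        switch_depth=4, switch_filters=64, switch_reductions=3,
        switch_processing_layers=2, trans_counts=8, trans_kernel_sizes=3,
        trans_layers=3, transformation_filters=64,
        attention_norm=False,   # new flag: skip building the AttentionNorm accumulator
        initial_temp=10, upsample_factor=4)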

View File

@@ -110,8 +110,8 @@ class BackboneMultiplexer(nn.Module):
 class ConfigurableSwitchComputer(nn.Module):
-    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, init_temp=20,
-                 add_scalable_noise_to_transforms=False):
+    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
+                 init_temp=20, add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchComputer, self).__init__()
         tc = transform_count
@@ -123,14 +123,14 @@ class ConfigurableSwitchComputer(nn.Module):
         self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))
 
         # And the switch itself, including learned scalars
-        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count))
+        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count) if attention_norm else None)
         self.switch_scale = nn.Parameter(torch.full((1,), float(1)))
         self.post_switch_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
         # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
         # depending on its needs.
         self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
 
-    def forward(self, x, output_attention_weights=False):
+    def forward(self, x, output_attention_weights=False, fixed_scale=1):
         identity = x
         if self.add_noise:
             rand_feature = torch.randn_like(x) * self.noise_scale
@@ -142,8 +142,8 @@ class ConfigurableSwitchComputer(nn.Module):
         m = self.multiplexer(identity)
         outputs, attention = self.switch(xformed, m, True)
-        outputs = identity + outputs * self.switch_scale
-        outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale
+        outputs = identity + outputs * self.switch_scale * fixed_scale
+        outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
         if output_attention_weights:
             return outputs, attention
         else:
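The new fixed_scale argument simply scales both residual contributions inside forward() and defaults to 1, so existing callers behave as before. A hedged usage sketch (the switch_computer object and the .1 value are placeholders, not taken from this commit):

    # Hypothetical: damp one switch's residual contribution by a constant factor.
    outputs, attention = switch_computer(x, output_attention_weights=True, fixed_scale=.1)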
@@ -155,7 +155,7 @@ class ConfigurableSwitchComputer(nn.Module):
 class ConfigurableSwitchedResidualGenerator2(nn.Module):
     def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
-                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 trans_layers, transformation_filters, attention_norm, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                  heightened_final_step=50000, upsample_factor=1,
                  add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchedResidualGenerator2, self).__init__()
@@ -171,6 +171,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
             transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                        pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                                       attention_norm=attention_norm,
                                                        transform_count=trans_counts, init_temp=initial_temp,
                                                        add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
@@ -232,18 +233,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
             val["switch_%i_histogram" % (i,)] = hists[i]
         return val
 
-    def load_state_dict(self, state_dict, strict=True):
-        # Support backwards compatibility where accumulator_index and accumulator_filled are not in this state_dict
-        t_state = self.state_dict()
-        if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict.keys():
-            for i in range(4):
-                state_dict['switches.%i.switch.attention_norm.accumulator' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator' % (i,)]
-                state_dict['switches.%i.switch.attention_norm.accumulator_index' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_index' % (i,)]
-                state_dict['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_filled' % (i,)]
-        super(ConfigurableSwitchedResidualGenerator2, self).load_state_dict(state_dict, strict)
-
 class Interpolate(nn.Module):
     def __init__(self, factor):
         super(Interpolate, self).__init__()
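The removed load_state_dict() override was the shim that back-filled missing attention_norm accumulator buffers when loading older checkpoints. A hedged sketch of loading such a checkpoint after this change, relying only on standard torch.nn.Module behavior (netG and the file path are placeholders):

    # Hypothetical: tolerate missing or extra attention_norm accumulator keys
    # now that the compatibility shim is gone.
    import torch
    state = torch.load('pretrained_srg2.pth', map_location='cpu')  # placeholder path
    netG.load_state_dict(state, strict=False)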

View File

@@ -63,7 +63,7 @@ def define_G(opt, net_key='network_G'):
                                                   switch_reductions=opt_net['switch_reductions'],
                                                   switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
                                                   trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
-                                                  transformation_filters=opt_net['transformation_filters'],
+                                                  transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'],
                                                   initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
                                                   heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                                                   upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
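Since define_G now forwards opt_net['attention_norm'], option files for this generator need the new key. A hedged sketch of the relevant options as a Python dict (only attention_norm is new in this commit; the other keys appear in the call above, and all values are placeholders):

    # Hypothetical network_G options fragment; in practice this is parsed from the YAML option file.
    opt_net = {
        'transformation_filters': 64,   # placeholder value
        'attention_norm': True,         # new key read by define_G in this commit
        'temperature': 10,              # placeholder value
        'add_noise': False,             # placeholder value
    }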

View File

@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_fdisc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_imgset_pixgan_srg2/train_imgset_pixgan_srg2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)