Make attention norm optional

James Betker 2020-07-18 07:24:02 -06:00
parent ad97a6a18a
commit 47a525241f
3 changed files with 10 additions and 21 deletions

View File

@@ -110,8 +110,8 @@ class BackboneMultiplexer(nn.Module):
 class ConfigurableSwitchComputer(nn.Module):
-    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, init_temp=20,
-                 add_scalable_noise_to_transforms=False):
+    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
+                 init_temp=20, add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchComputer, self).__init__()
         tc = transform_count
@@ -123,14 +123,14 @@ class ConfigurableSwitchComputer(nn.Module):
         self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))
         # And the switch itself, including learned scalars
-        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count))
+        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count) if attention_norm else None)
         self.switch_scale = nn.Parameter(torch.full((1,), float(1)))
         self.post_switch_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
         # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
         # depending on its needs.
         self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))

-    def forward(self, x, output_attention_weights=False):
+    def forward(self, x, output_attention_weights=False, fixed_scale=1):
         identity = x
         if self.add_noise:
             rand_feature = torch.randn_like(x) * self.noise_scale
@@ -142,8 +142,8 @@ class ConfigurableSwitchComputer(nn.Module):
         m = self.multiplexer(identity)
         outputs, attention = self.switch(xformed, m, True)
-        outputs = identity + outputs * self.switch_scale
-        outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale
+        outputs = identity + outputs * self.switch_scale * fixed_scale
+        outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
         if output_attention_weights:
             return outputs, attention
         else:
@@ -155,7 +155,7 @@ class ConfigurableSwitchComputer(nn.Module):
 class ConfigurableSwitchedResidualGenerator2(nn.Module):
     def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
-                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 trans_layers, transformation_filters, attention_norm, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                  heightened_final_step=50000, upsample_factor=1,
                  add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchedResidualGenerator2, self).__init__()
@@ -171,6 +171,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
             transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                        pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                                       attention_norm=attention_norm,
                                                        transform_count=trans_counts, init_temp=initial_temp,
                                                        add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
@@ -232,18 +233,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
             val["switch_%i_histogram" % (i,)] = hists[i]
         return val
-
-    def load_state_dict(self, state_dict, strict=True):
-        # Support backwards compatibility where accumulator_index and accumulator_filled are not in this state_dict
-        t_state = self.state_dict()
-        if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict.keys():
-            for i in range(4):
-                state_dict['switches.%i.switch.attention_norm.accumulator' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator' % (i,)]
-                state_dict['switches.%i.switch.attention_norm.accumulator_index' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_index' % (i,)]
-                state_dict['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_filled' % (i,)]
-        super(ConfigurableSwitchedResidualGenerator2, self).load_state_dict(state_dict, strict)
-
-
 class Interpolate(nn.Module):
     def __init__(self, factor):
         super(Interpolate, self).__init__()
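
Note for readers without BareConvSwitch in front of them: with the option off, the constructor above now hands the switch attention_norm=None, so the switch has to tolerate a missing norm module. A minimal sketch of that guard, assuming the switch only needs to skip normalization when the module is absent; the class name, tensor shapes, and softmax selection below are illustrative, not the repo's implementation:

import torch
import torch.nn as nn

class OptionalNormSwitch(nn.Module):
    """Illustrative stand-in for BareConvSwitch; only the None-guard is the point."""
    def __init__(self, initial_temperature=20, attention_norm=None):
        super(OptionalNormSwitch, self).__init__()
        self.temperature = initial_temperature
        # Either an AttentionNorm-like module or None when the norm is disabled.
        self.attention_norm = attention_norm

    def forward(self, xformed, mux, output_attention=False):
        # xformed: (B, transforms, C, H, W); mux: (B, transforms, H, W) selector logits.
        attention = torch.softmax(mux / self.temperature, dim=1)
        if self.attention_norm is not None:
            attention = self.attention_norm(attention)  # skipped entirely when the norm is off
        out = (xformed * attention.unsqueeze(2)).sum(dim=1)
        return (out, attention) if output_attention else out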

View File

@@ -63,7 +63,7 @@ def define_G(opt, net_key='network_G'):
                switch_reductions=opt_net['switch_reductions'],
                switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
                trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
-               transformation_filters=opt_net['transformation_filters'],
+               transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'],
                initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
                heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
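
Since define_G now reads opt_net['attention_norm'], the network_G section of the options file has to supply that key. A hedged sketch of the dict define_G would see once the YAML is parsed; values are placeholders, only 'attention_norm' is new in this commit, and keys not visible in the hunk above are omitted:

# Placeholder values; only 'attention_norm' is introduced by this commit.
opt_net = {
    'switch_reductions': 3,
    'switch_processing_layers': 2,
    'trans_counts': 8,
    'trans_kernel_sizes': 3,
    'trans_layers': 4,
    'transformation_filters': 64,
    'attention_norm': True,   # set False to build every switch without an AttentionNorm
    'temperature': 10,
    'temperature_final_step': 50000,
    'heightened_temp_min': 1,
    'heightened_final_step': 50000,
    'add_noise': False,
}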

View File

@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_fdisc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_imgset_pixgan_srg2/train_imgset_pixgan_srg2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)