Make attention norm optional
commit 47a525241f (parent ad97a6a18a)
@@ -110,8 +110,8 @@ class BackboneMultiplexer(nn.Module):
 
 
 class ConfigurableSwitchComputer(nn.Module):
-    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, init_temp=20,
-                 add_scalable_noise_to_transforms=False):
+    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
+                 init_temp=20, add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchComputer, self).__init__()
 
         tc = transform_count
@@ -123,14 +123,14 @@ class ConfigurableSwitchComputer(nn.Module):
             self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))
 
         # And the switch itself, including learned scalars
-        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count))
+        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count) if attention_norm else None)
         self.switch_scale = nn.Parameter(torch.full((1,), float(1)))
         self.post_switch_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
         # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
         # depending on its needs.
         self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
 
-    def forward(self, x, output_attention_weights=False):
+    def forward(self, x, output_attention_weights=False, fixed_scale=1):
         identity = x
         if self.add_noise:
             rand_feature = torch.randn_like(x) * self.noise_scale
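For reference, the change above guards the norm with a plain Python conditional: the AttentionNorm is only constructed when the new attention_norm flag is truthy, and None is passed to BareConvSwitch otherwise. A minimal, self-contained sketch of that optional-submodule pattern follows; the name OptionalNormBlock and the use of BatchNorm2d as a stand-in are illustrative only and are not part of this repository.

import torch
import torch.nn as nn

class OptionalNormBlock(nn.Module):
    # Illustrative stand-in for a module whose norm submodule may be disabled.
    def __init__(self, channels, use_norm=True):
        super(OptionalNormBlock, self).__init__()
        # Build the norm only when requested, mirroring
        # "AttentionNorm(...) if attention_norm else None" in the diff above.
        self.norm = nn.BatchNorm2d(channels) if use_norm else None

    def forward(self, x):
        # Guard the optional submodule at call time.
        if self.norm is not None:
            x = self.norm(x)
        return x

# Both variants construct and run; only the first owns norm parameters.
x = torch.randn(2, 8, 16, 16)
y1 = OptionalNormBlock(8, use_norm=True)(x)
y2 = OptionalNormBlock(8, use_norm=False)(x)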
@@ -142,8 +142,8 @@ class ConfigurableSwitchComputer(nn.Module):
         m = self.multiplexer(identity)
 
         outputs, attention = self.switch(xformed, m, True)
-        outputs = identity + outputs * self.switch_scale
-        outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale
+        outputs = identity + outputs * self.switch_scale * fixed_scale
+        outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
         if output_attention_weights:
             return outputs, attention
         else:
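The new fixed_scale argument (defaulting to 1, so existing callers are unaffected) simply multiplies both residual contributions before they are added back onto the identity path. A tiny arithmetic sketch with plain tensors; the values of switch_scale and fixed_scale here are arbitrary stand-ins for the learned scalar and the caller-supplied factor.

import torch

identity = torch.ones(4)
switch_out = torch.full((4,), 2.0)   # stand-in for the switch output
switch_scale = 0.5                   # stand-in for the learned self.switch_scale
fixed_scale = 0.25                   # caller-supplied damping factor

# Mirrors: outputs = identity + outputs * self.switch_scale * fixed_scale
out = identity + switch_out * switch_scale * fixed_scale
print(out)  # tensor of 1.25s; with fixed_scale=1 it would be 2.0, as before this commit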
@@ -155,7 +155,7 @@ class ConfigurableSwitchComputer(nn.Module):
 
 class ConfigurableSwitchedResidualGenerator2(nn.Module):
     def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
-                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 trans_layers, transformation_filters, attention_norm, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                  heightened_final_step=50000, upsample_factor=1,
                  add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchedResidualGenerator2, self).__init__()
@@ -171,6 +171,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
             transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                        pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                                       attention_norm=attention_norm,
                                                        transform_count=trans_counts, init_temp=initial_temp,
                                                        add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
 
@@ -232,18 +233,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
             val["switch_%i_histogram" % (i,)] = hists[i]
         return val
 
-
-    def load_state_dict(self, state_dict, strict=True):
-        # Support backwards compatibility where accumulator_index and accumulator_filled are not in this state_dict
-        t_state = self.state_dict()
-        if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict.keys():
-            for i in range(4):
-                state_dict['switches.%i.switch.attention_norm.accumulator' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator' % (i,)]
-                state_dict['switches.%i.switch.attention_norm.accumulator_index' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_index' % (i,)]
-                state_dict['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_filled' % (i,)]
-        super(ConfigurableSwitchedResidualGenerator2, self).load_state_dict(state_dict, strict)
-
-
 class Interpolate(nn.Module):
     def __init__(self, factor):
         super(Interpolate, self).__init__()
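The deleted load_state_dict override existed to keep older checkpoints loadable after the attention-norm accumulator buffers were introduced: it copied any missing accumulator, accumulator_index and accumulator_filled entries from the freshly initialized model before delegating to the parent implementation. Its hard-coded range(4) and attention_norm key paths no longer fit a model where the norm may be absent. A generic sketch of the same back-fill idea, detached from this repository's module names and purely illustrative:

import torch
import torch.nn as nn

def load_with_backfill(model, state_dict, strict=True):
    # Hypothetical helper: copy any keys the checkpoint lacks from the model's
    # current state so that newly added buffers do not break loading.
    current = model.state_dict()
    for key, value in current.items():
        state_dict.setdefault(key, value)
    model.load_state_dict(state_dict, strict=strict)

# Usage: a checkpoint saved before the 'extra' buffer existed still loads.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.lin = nn.Linear(4, 4)
        self.register_buffer('extra', torch.zeros(1))  # buffer added after the checkpoint was saved

old_ckpt = {k: v for k, v in Net().state_dict().items() if k != 'extra'}
load_with_backfill(Net(), old_ckpt)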
@@ -63,7 +63,7 @@ def define_G(opt, net_key='network_G'):
                                                           switch_reductions=opt_net['switch_reductions'],
                                                           switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
                                                           trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
-                                                          transformation_filters=opt_net['transformation_filters'],
+                                                          transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'],
                                                           initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
                                                           heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                                                           upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
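The hunk above (from the define_G factory, judging by the hunk header) forwards a new attention_norm entry read from the network options, so configurations for this generator are expected to carry that key; the diff does not show the configuration side. Below is an illustrative sketch of the relevant slice of opt_net, with placeholder values and only the key names taken from the call above. Depending on how the option dictionary is constructed, a config file missing the key may raise a KeyError or silently yield None, so older configs would want it added explicitly.

# Illustrative subset of the options consumed by the call above; values are placeholders.
opt_net = {
    'transformation_filters': 64,
    'attention_norm': True,   # new flag from this commit; False builds switches without AttentionNorm
    'temperature': 10,
    'temperature_final_step': 50000,
}

# Direct lookup, as in the diff above.
use_attention_norm = opt_net['attention_norm']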
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_fdisc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_imgset_pixgan_srg2/train_imgset_pixgan_srg2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)