diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index c9d13715..34106d1a 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -110,8 +110,8 @@ class BackboneMultiplexer(nn.Module):
 
 
 class ConfigurableSwitchComputer(nn.Module):
-    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, init_temp=20,
-                 add_scalable_noise_to_transforms=False):
+    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
+                 init_temp=20, add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchComputer, self).__init__()
 
         tc = transform_count
@@ -123,14 +123,14 @@ class ConfigurableSwitchComputer(nn.Module):
         self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))
 
         # And the switch itself, including learned scalars
-        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count))
+        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count) if attention_norm else None)
         self.switch_scale = nn.Parameter(torch.full((1,), float(1)))
         self.post_switch_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
         # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
         # depending on its needs.
         self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
 
-    def forward(self, x, output_attention_weights=False):
+    def forward(self, x, output_attention_weights=False, fixed_scale=1):
         identity = x
         if self.add_noise:
             rand_feature = torch.randn_like(x) * self.noise_scale
@@ -142,8 +142,8 @@ class ConfigurableSwitchComputer(nn.Module):
         m = self.multiplexer(identity)
 
         outputs, attention = self.switch(xformed, m, True)
-        outputs = identity + outputs * self.switch_scale
-        outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale
+        outputs = identity + outputs * self.switch_scale * fixed_scale
+        outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
         if output_attention_weights:
             return outputs, attention
         else:
@@ -155,7 +155,7 @@ class ConfigurableSwitchComputer(nn.Module):
 
 
 class ConfigurableSwitchedResidualGenerator2(nn.Module):
     def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
-                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 trans_layers, transformation_filters, attention_norm, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                  heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchedResidualGenerator2, self).__init__()
@@ -171,6 +171,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
             transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
                                              transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                        pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                                       attention_norm=attention_norm,
                                                        transform_count=trans_counts, init_temp=initial_temp,
                                                        add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
@@ -232,18 +233,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
             val["switch_%i_histogram" % (i,)] = hists[i]
         return val
 
-    def load_state_dict(self, state_dict, strict=True):
-        # Support backwards compatibility where accumulator_index and accumulator_filled are not in this state_dict
-        t_state = self.state_dict()
-        if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict.keys():
-            for i in range(4):
-                state_dict['switches.%i.switch.attention_norm.accumulator' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator' % (i,)]
-                state_dict['switches.%i.switch.attention_norm.accumulator_index' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_index' % (i,)]
-                state_dict['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_filled' % (i,)]
-        super(ConfigurableSwitchedResidualGenerator2, self).load_state_dict(state_dict, strict)
-
-
 class Interpolate(nn.Module):
     def __init__(self, factor):
         super(Interpolate, self).__init__()
diff --git a/codes/models/networks.py b/codes/models/networks.py
index c1732584..432cbe1c 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -63,7 +63,7 @@ def define_G(opt, net_key='network_G'):
                                                                       switch_reductions=opt_net['switch_reductions'],
                                                                       switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
                                                                       trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
-                                                                      transformation_filters=opt_net['transformation_filters'],
+                                                                      transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'],
                                                                       initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
                                                                       heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                                                                       upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
diff --git a/codes/train.py b/codes/train.py
index 89b36d91..99fc7c14 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_fdisc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_imgset_pixgan_srg2/train_imgset_pixgan_srg2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
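
Reviewer note on the attention_norm change: the flag is threaded with no default value from define_G through ConfigurableSwitchedResidualGenerator2 down to ConfigurableSwitchComputer, so option YAML files for this generator now need an attention_norm key under network_G. That the norm can now be absent is plausibly also why the load_state_dict backwards-compatibility shim was dropped in the same commit: the hardcoded AttentionNorm accumulator keys only exist in state_dict() when the submodule is actually constructed. Below is a minimal, self-contained sketch of the conditional-submodule pattern, using nn.BatchNorm1d as a stand-in for the repo's AttentionNorm; GatedNormSwitch is a hypothetical name for illustration, not the real BareConvSwitch.

import torch
import torch.nn as nn

class GatedNormSwitch(nn.Module):  # hypothetical stand-in, for illustration only
    def __init__(self, transform_count, attention_norm):
        super().__init__()
        # The submodule is only registered when the flag is set; when it is
        # None, its parameters and buffers never appear in state_dict().
        self.attention_norm = nn.BatchNorm1d(transform_count) if attention_norm else None

    def forward(self, attention_logits):  # attention_logits: (batch, transform_count)
        if self.attention_norm is not None:
            attention_logits = self.attention_norm(attention_logits)
        return torch.softmax(attention_logits, dim=1)

print(len(GatedNormSwitch(8, attention_norm=True).state_dict()))   # 5 (norm weights + buffers)
print(len(GatedNormSwitch(8, attention_norm=False).state_dict()))  # 0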
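
Reviewer note on fixed_scale: the new forward argument multiplies both learned residual scalars (switch_scale and psc_scale), so a caller can attenuate an entire switch's contribution without touching its learned parameters; at fixed_scale=0 the computer degenerates to an identity mapping. A self-contained sketch of that behavior follows, with plain convolutions standing in for the real switch machinery; FixedScaleResidual is illustrative only.

import torch
import torch.nn as nn

class FixedScaleResidual(nn.Module):  # illustrative stand-in for ConfigurableSwitchComputer
    def __init__(self, channels=8):
        super().__init__()
        self.transform = nn.Conv2d(channels, channels, 3, padding=1)         # stand-in for the switch output
        self.post_switch_conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.switch_scale = nn.Parameter(torch.full((1,), 1.0))
        self.psc_scale = nn.Parameter(torch.full((1,), .1))

    def forward(self, x, fixed_scale=1):
        identity = x
        # fixed_scale multiplies both residual branches, mirroring the diff.
        out = identity + self.transform(x) * self.switch_scale * fixed_scale
        out = out + self.post_switch_conv(out) * self.psc_scale * fixed_scale
        return out

block = FixedScaleResidual()
x = torch.randn(2, 8, 16, 16)
assert torch.allclose(block(x, fixed_scale=0), x)  # zero scale -> identity pass-through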