diff --git a/codes/models/archs/NestedSwitchGenerator.py b/codes/models/archs/NestedSwitchGenerator.py
index 45adb8a3..72b7f2bb 100644
--- a/codes/models/archs/NestedSwitchGenerator.py
+++ b/codes/models/archs/NestedSwitchGenerator.py
@@ -74,11 +74,6 @@ class Switch(nn.Module):
         self.scale = nn.Parameter(torch.ones(1))
         self.bias = nn.Parameter(torch.zeros(1))
 
-        if not self.pass_chain_forward:
-            self.parameterize = ConvBnLelu(64, 64, bn=False, lelu=False)
-            self.c_constric = MultiConvBlock(576, 256, 64, kernel_size=1, depth=3, bn=False)
-            self.c_process = ConvBnLelu(64, 64, kernel_size=1, lelu=False, bn=False)
-
     # x is the input fed to the transform blocks.
     # m is the output of the multiplexer which will be used to select from those transform blocks.
     # chain is a chain of shared processing outputs used by the individual transforms.
@@ -88,20 +83,11 @@ class Switch(nn.Module):
             xformed = [o[0] for o in pcf]
             atts = [o[1] for o in pcf]
         else:
-            # These adjustments were determined statistically from numeric_stability.py and should start this context
-            # out in a normal distribution.
-            context = chain[-1]
-            context = F.interpolate(context, size=x.shape[2:], mode='nearest')
-            context = torch.cat([self.parameterize(x), context], dim=1)
-            context = self.c_constric(context) / 3
-            context = self.c_process(context)
-            context = x * context
-
             if self.add_noise:
                 rand_feature = torch.randn_like(x)
-                xformed = [t(context, rand_feature) for t in self.transforms]
+                xformed = [t(x, rand_feature) for t in self.transforms]
             else:
-                xformed = [t(context) for t in self.transforms]
+                xformed = [t(x) for t in self.transforms]
 
         # Interpolate the multiplexer across the entire shape of the image.
         m = F.interpolate(m, size=x.shape[2:], mode='nearest')
@@ -128,16 +114,18 @@ class Switch(nn.Module):
 class Processor(nn.Module):
     def __init__(self, base_filters, processing_depth, reduce=False):
         super(Processor, self).__init__()
-        self.output_filter_count = base_filters * 2
+        self.output_filter_count = base_filters * (2 if reduce else 1)
 
         # Downsample block used for bottleneck.
-        downsample = nn.Sequential(
-            nn.Conv2d(base_filters, self.output_filter_count, kernel_size=1, stride=2, bias=False),
-            nn.BatchNorm2d(self.output_filter_count),
-        )
+        if reduce:
+            downsample = nn.Sequential(
+                nn.Conv2d(base_filters, self.output_filter_count, kernel_size=1, stride=2, bias=False),
+                nn.BatchNorm2d(self.output_filter_count),
+            )
+        else:
+            downsample = None
 
         # Bottleneck block outputs the requested filter sizex4, but we only want x2.
-        self.initial = FixupBottleneck(base_filters, base_filters // 2, stride=2 if reduce else 1, downsample=downsample)
-
+        self.initial = FixupBottleneck(base_filters, self.output_filter_count // 4, stride=2 if reduce else 1, downsample=downsample)
         self.res_blocks = nn.ModuleList([FixupBottleneck(self.output_filter_count, self.output_filter_count // 4) for _ in range(processing_depth)])
 
     def forward(self, x):
@@ -200,10 +188,12 @@ class NestedSwitchComputer(nn.Module):
         processing_trunk = []
         filters = []
         current_filters = switch_base_filters
+        reduce = False  # Don't reduce the first layer, but reduce after that.
         for _ in range(nesting_depth):
-            processing_trunk.append(Processor(current_filters, num_switch_processing_layers, reduce=True))
+            processing_trunk.append(Processor(current_filters, num_switch_processing_layers, reduce=reduce))
             current_filters = processing_trunk[-1].output_filter_count
             filters.append(current_filters)
+            reduce = True
 
         self.multiplexer_init_conv = ConvBnLelu(transform_filters, switch_base_filters, kernel_size=7, lelu=False, bn=False)
         self.processing_trunk = nn.ModuleList(processing_trunk)
@@ -253,7 +243,7 @@ class NestedSwitchedGenerator(nn.Module):
         for sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
             switches.append(NestedSwitchComputer(transform_filters=transformation_filters, switch_base_filters=switch_filters, num_switch_processing_layers=sw_proc,
                                                  nesting_depth=sw_reduce, transforms_at_leaf=trans_count, trans_kernel_size=kernel, trans_num_layers=layers,
-                                                 trans_scale_init=.2/len(switch_reductions), initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
+                                                 trans_scale_init=.2, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
 
         self.switches = nn.ModuleList(switches)
         self.transformation_counts = trans_counts
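
Taken together, the Processor and NestedSwitchComputer changes mean the first trunk layer now keeps its spatial resolution and filter count, and downsampling (with filter doubling) only starts at the second layer. Below is a minimal standalone sketch of the resulting filter/stride schedule; trunk_schedule is a hypothetical helper, and only the bookkeeping visible in the diff is reproduced.

# Sketch: how the staggered `reduce` flag from the diff schedules
# output filters and strides across the trunk's nesting depth.
def trunk_schedule(switch_base_filters, nesting_depth):
    schedule = []
    current_filters = switch_base_filters
    reduce = False  # Don't reduce the first layer, but reduce after that.
    for _ in range(nesting_depth):
        # Mirrors Processor: filters double and stride is 2 only when reducing.
        output_filter_count = current_filters * (2 if reduce else 1)
        stride = 2 if reduce else 1
        schedule.append((output_filter_count, stride))
        current_filters = output_filter_count
        reduce = True
    return schedule

# e.g. trunk_schedule(64, 3) -> [(64, 1), (128, 2), (256, 2)]
# (pre-patch, every layer would have been (128, 2), (256, 2), (512, 2))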