From 3ce1a1878dcf5a47b9739596806e524316ca1817 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 30 Jun 2020 16:59:57 -0600
Subject: [PATCH] NSG improvements (r5)

- Get rid of explicit forward() calls; they make numeric_stability.py not
  work properly.
- Do stability auditing across layers.
- Upsample last instead of first; work in much higher dimensionality for
  the transforms.
---
 codes/models/archs/NestedSwitchGenerator.py  | 50 +++++++++----------
 .../archs/SwitchedResidualGenerator_arch.py  |  4 --
 2 files changed, 25 insertions(+), 29 deletions(-)

diff --git a/codes/models/archs/NestedSwitchGenerator.py b/codes/models/archs/NestedSwitchGenerator.py
index 056ae513..45adb8a3 100644
--- a/codes/models/archs/NestedSwitchGenerator.py
+++ b/codes/models/archs/NestedSwitchGenerator.py
@@ -75,31 +75,33 @@ class Switch(nn.Module):
         self.bias = nn.Parameter(torch.zeros(1))

         if not self.pass_chain_forward:
-            self.parameterize = ConvBnLelu(16, 16, bn=False, lelu=False)
-            self.c_constric = MultiConvBlock(48, 32, 16, kernel_size=5, depth=3, bn=False)
+            self.parameterize = ConvBnLelu(64, 64, bn=False, lelu=False)
+            self.c_constric = MultiConvBlock(576, 256, 64, kernel_size=1, depth=3, bn=False)
+            self.c_process = ConvBnLelu(64, 64, kernel_size=1, lelu=False, bn=False)

     # x is the input fed to the transform blocks.
     # m is the output of the multiplexer which will be used to select from those transform blocks.
     # chain is a chain of shared processing outputs used by the individual transforms.
     def forward(self, x, m, chain):
         if self.pass_chain_forward:
-            pcf = [t.forward(x, chain) for t in self.transforms]
+            pcf = [t(x, chain) for t in self.transforms]
             xformed = [o[0] for o in pcf]
             atts = [o[1] for o in pcf]
         else:
             # These adjustments were determined statistically from numeric_stability.py and should start this context
             # out in a normal distribution.
-            context = (chain[-1] - 6) / 9.4
-            context = F.pixel_shuffle(context, 4)
+            context = chain[-1]
             context = F.interpolate(context, size=x.shape[2:], mode='nearest')
             context = torch.cat([self.parameterize(x), context], dim=1)
-            context = self.c_constric(context) / 1.6
+            context = self.c_constric(context) / 3
+            context = self.c_process(context)
+            context = x * context

             if self.add_noise:
                 rand_feature = torch.randn_like(x)
-                xformed = [t.forward(context, rand_feature) for t in self.transforms]
+                xformed = [t(context, rand_feature) for t in self.transforms]
             else:
-                xformed = [t.forward(context) for t in self.transforms]
+                xformed = [t(context) for t in self.transforms]

         # Interpolate the multiplexer across the entire shape of the image.
         m = F.interpolate(m, size=x.shape[2:], mode='nearest')
@@ -139,9 +141,10 @@ class Processor(nn.Module):
         self.res_blocks = nn.ModuleList([FixupBottleneck(self.output_filter_count, self.output_filter_count // 4) for _ in range(processing_depth)])

     def forward(self, x):
-        x = self.initial(x)
+        x = (self.initial(x) - .4) / .6
         for b in self.res_blocks:
-            x = b(x) + x
+            r = (b(x) - .4) / .6
+            x = r + x
         return x


@@ -160,7 +163,7 @@ class Constrictor(nn.Module):
     def forward(self, x):
         x = self.cbl1(x)
         x = self.cbl2(x)
-        x = self.cbl3(x)
+        x = self.cbl3(x) / 4
         return x


@@ -202,7 +205,7 @@ class NestedSwitchComputer(nn.Module):
             current_filters = processing_trunk[-1].output_filter_count
             filters.append(current_filters)

-        self.multiplexer_init_conv = nn.Conv2d(transform_filters, switch_base_filters, kernel_size=7, padding=3)
+        self.multiplexer_init_conv = ConvBnLelu(transform_filters, switch_base_filters, kernel_size=7, lelu=False, bn=False)
         self.processing_trunk = nn.ModuleList(processing_trunk)
         self.switch = RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
         self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, bn=False)
@@ -219,18 +222,17 @@ class NestedSwitchComputer(nn.Module):
             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
-        nn.init.kaiming_normal_(self.multiplexer_init_conv.weight, nonlinearity="relu")

     def forward(self, x):
         feed_forward = x
         trunk = []
         trunk_input = self.multiplexer_init_conv(x)
         for m in self.processing_trunk:
-            trunk_input = m.forward(trunk_input)
+            trunk_input = (m(trunk_input) - 3.3) / 12.5
             trunk.append(trunk_input)

-        self.trunk = (trunk[-1] - 6) / 9.4
-        x, att = self.switch.forward(x, trunk)
+        self.trunk = trunk[-1]
+        x, att = self.switch(x, trunk)
         x = x + feed_forward
         return feed_forward + self.anneal(x) / .86, att

@@ -263,21 +265,19 @@ class NestedSwitchedGenerator(nn.Module):
         self.upsample_factor = upsample_factor

     def forward(self, x):
-        # This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
-        # is called for, then repair.
-        if self.upsample_factor > 1:
-            x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
-
-        x = self.initial_conv(x)
+        x = self.initial_conv(x) / .2

         self.attentions = []
         for i, sw in enumerate(self.switches):
-            x, att = sw.forward(x)
+            x, att = sw(x)
             self.attentions.append(att)

+        if self.upsample_factor > 1:
+            x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
+
         x = self.proc_conv(x) / .85
-        x = self.final_conv(x)
-        return x / 4.26,
+        x = self.final_conv(x) / 4.6
+        return x / 16,

     def set_temperature(self, temp):
         [sw.set_temperature(temp) for sw in self.switches]
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 63f36aac..6214aa5b 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -336,10 +336,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         self.upsample_factor = upsample_factor

     def forward(self, x):
-        # This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
-        # is called for, then repair.
-        if self.upsample_factor > 1:
-            x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
         x = self.initial_conv(x)
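
Reviewer note on the stability constants: numeric_stability.py is referenced
by this patch but not included in it. As a minimal sketch of how per-layer
shift/scale constants such as "(m(trunk_input) - 3.3) / 12.5" can be measured,
the helper below (audit_layer_stats is a hypothetical name, not a function in
this repo) records the mean and standard deviation of every submodule's output
over one forward pass using stock PyTorch forward hooks:

import torch
import torch.nn as nn

def audit_layer_stats(model: nn.Module, sample_input: torch.Tensor):
    """Record the mean/std of every submodule's output for one forward pass."""
    stats, hooks = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            # Only audit plain tensor outputs; switches return (x, att) tuples.
            if isinstance(output, torch.Tensor) and output.is_floating_point():
                stats[name] = (output.mean().item(), output.std().item())
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            hooks.append(module.register_forward_hook(make_hook(name)))
    with torch.no_grad():
        model(sample_input)
    for h in hooks:
        h.remove()

    for name, (mean, std) in sorted(stats.items()):
        # A layer whose output should start out roughly N(0, 1) would then
        # be rescaled in the network as: (out - mean) / std.
        print(f"{name}: mean={mean:+.3f} std={std:.3f}")
    return stats

Feeding a random batch through the generator, e.g.
audit_layer_stats(gen, torch.randn(1, 3, 64, 64)), prints the statistics from
which constants like the "(... - .4) / .6" in Processor.forward() would be
derived; the input shape here is illustrative, not prescribed by the patch.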