NSG improvements (r5)

- Get rid of direct `.forward()` calls (invoke the modules themselves instead); calling `.forward()` directly makes numeric_stability.py not work properly.
- Do stability auditing across layers.
- Upsample last instead of first, and work in much higher dimensionality for transforms.
This commit is contained in:
James Betker 2020-06-30 16:59:57 -06:00
parent 75f148022d
commit 3ce1a1878d
2 changed files with 25 additions and 29 deletions

View File

@ -75,31 +75,33 @@ class Switch(nn.Module):
self.bias = nn.Parameter(torch.zeros(1)) self.bias = nn.Parameter(torch.zeros(1))
if not self.pass_chain_forward: if not self.pass_chain_forward:
self.parameterize = ConvBnLelu(16, 16, bn=False, lelu=False) self.parameterize = ConvBnLelu(64, 64, bn=False, lelu=False)
self.c_constric = MultiConvBlock(48, 32, 16, kernel_size=5, depth=3, bn=False) self.c_constric = MultiConvBlock(576, 256, 64, kernel_size=1, depth=3, bn=False)
self.c_process = ConvBnLelu(64, 64, kernel_size=1, lelu=False, bn=False)
# x is the input fed to the transform blocks. # x is the input fed to the transform blocks.
# m is the output of the multiplexer which will be used to select from those transform blocks. # m is the output of the multiplexer which will be used to select from those transform blocks.
# chain is a chain of shared processing outputs used by the individual transforms. # chain is a chain of shared processing outputs used by the individual transforms.
def forward(self, x, m, chain): def forward(self, x, m, chain):
if self.pass_chain_forward: if self.pass_chain_forward:
pcf = [t.forward(x, chain) for t in self.transforms] pcf = [t(x, chain) for t in self.transforms]
xformed = [o[0] for o in pcf] xformed = [o[0] for o in pcf]
atts = [o[1] for o in pcf] atts = [o[1] for o in pcf]
else: else:
# These adjustments were determined statistically from numeric_stability.py and should start this context # These adjustments were determined statistically from numeric_stability.py and should start this context
# out in a normal distribution. # out in a normal distribution.
context = (chain[-1] - 6) / 9.4 context = chain[-1]
context = F.pixel_shuffle(context, 4)
context = F.interpolate(context, size=x.shape[2:], mode='nearest') context = F.interpolate(context, size=x.shape[2:], mode='nearest')
context = torch.cat([self.parameterize(x), context], dim=1) context = torch.cat([self.parameterize(x), context], dim=1)
context = self.c_constric(context) / 1.6 context = self.c_constric(context) / 3
context = self.c_process(context)
context = x * context
if self.add_noise: if self.add_noise:
rand_feature = torch.randn_like(x) rand_feature = torch.randn_like(x)
xformed = [t.forward(context, rand_feature) for t in self.transforms] xformed = [t(context, rand_feature) for t in self.transforms]
else: else:
xformed = [t.forward(context) for t in self.transforms] xformed = [t(context) for t in self.transforms]
# Interpolate the multiplexer across the entire shape of the image. # Interpolate the multiplexer across the entire shape of the image.
m = F.interpolate(m, size=x.shape[2:], mode='nearest') m = F.interpolate(m, size=x.shape[2:], mode='nearest')
@ -139,9 +141,10 @@ class Processor(nn.Module):
self.res_blocks = nn.ModuleList([FixupBottleneck(self.output_filter_count, self.output_filter_count // 4) for _ in range(processing_depth)]) self.res_blocks = nn.ModuleList([FixupBottleneck(self.output_filter_count, self.output_filter_count // 4) for _ in range(processing_depth)])
def forward(self, x): def forward(self, x):
x = self.initial(x) x = (self.initial(x) - .4) / .6
for b in self.res_blocks: for b in self.res_blocks:
x = b(x) + x r = (b(x) - .4) / .6
x = r + x
return x return x
@ -160,7 +163,7 @@ class Constrictor(nn.Module):
def forward(self, x): def forward(self, x):
x = self.cbl1(x) x = self.cbl1(x)
x = self.cbl2(x) x = self.cbl2(x)
x = self.cbl3(x) x = self.cbl3(x) / 4
return x return x
@ -202,7 +205,7 @@ class NestedSwitchComputer(nn.Module):
current_filters = processing_trunk[-1].output_filter_count current_filters = processing_trunk[-1].output_filter_count
filters.append(current_filters) filters.append(current_filters)
self.multiplexer_init_conv = nn.Conv2d(transform_filters, switch_base_filters, kernel_size=7, padding=3) self.multiplexer_init_conv = ConvBnLelu(transform_filters, switch_base_filters, kernel_size=7, lelu=False, bn=False)
self.processing_trunk = nn.ModuleList(processing_trunk) self.processing_trunk = nn.ModuleList(processing_trunk)
self.switch = RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms) self.switch = RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, bn=False) self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, bn=False)
@ -219,18 +222,17 @@ class NestedSwitchComputer(nn.Module):
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1) nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
nn.init.kaiming_normal_(self.multiplexer_init_conv.weight, nonlinearity="relu")
def forward(self, x): def forward(self, x):
feed_forward = x feed_forward = x
trunk = [] trunk = []
trunk_input = self.multiplexer_init_conv(x) trunk_input = self.multiplexer_init_conv(x)
for m in self.processing_trunk: for m in self.processing_trunk:
trunk_input = m.forward(trunk_input) trunk_input = (m(trunk_input) - 3.3) / 12.5
trunk.append(trunk_input) trunk.append(trunk_input)
self.trunk = (trunk[-1] - 6) / 9.4 self.trunk = trunk[-1]
x, att = self.switch.forward(x, trunk) x, att = self.switch(x, trunk)
x = x + feed_forward x = x + feed_forward
return feed_forward + self.anneal(x) / .86, att return feed_forward + self.anneal(x) / .86, att
@ -263,21 +265,19 @@ class NestedSwitchedGenerator(nn.Module):
self.upsample_factor = upsample_factor self.upsample_factor = upsample_factor
def forward(self, x): def forward(self, x):
# This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that x = self.initial_conv(x) / .2
# is called for, then repair.
if self.upsample_factor > 1:
x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
x = self.initial_conv(x)
self.attentions = [] self.attentions = []
for i, sw in enumerate(self.switches): for i, sw in enumerate(self.switches):
x, att = sw.forward(x) x, att = sw(x)
self.attentions.append(att) self.attentions.append(att)
if self.upsample_factor > 1:
x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
x = self.proc_conv(x) / .85 x = self.proc_conv(x) / .85
x = self.final_conv(x) x = self.final_conv(x) / 4.6
return x / 4.26, return x / 16,
def set_temperature(self, temp): def set_temperature(self, temp):
[sw.set_temperature(temp) for sw in self.switches] [sw.set_temperature(temp) for sw in self.switches]

View File

@ -336,10 +336,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
self.upsample_factor = upsample_factor self.upsample_factor = upsample_factor
def forward(self, x): def forward(self, x):
# This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
# is called for, then repair.
if self.upsample_factor > 1:
x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
x = self.initial_conv(x) x = self.initial_conv(x)