NSG improvements (r5)
- Get rid of direct .forward() calls; they make numeric_stability.py not work properly.
- Do stability auditing across layers.
- Upsample last instead of first, and work in much higher dimensionality for the transforms.
This commit is contained in:
parent 75f148022d
commit 3ce1a1878d
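Editor's note on the first bullet: this is why every t.forward(...) in the diff below becomes t(...). Calling a module directly goes through nn.Module.__call__, which fires registered forward hooks; a bare .forward() call skips them, so a hook-based auditor never sees those layers. A minimal sketch of the kind of hook-based stat collection numeric_stability.py presumably does (helper names here are illustrative, not from the repo):

import torch
import torch.nn as nn

# Hypothetical auditor in the spirit of numeric_stability.py: record each
# module's output mean/std via forward hooks.
def attach_stat_hooks(model, stats):
    for name, mod in model.named_modules():
        def hook(m, inp, out, name=name):
            stats[name] = (out.mean().item(), out.std().item())
        mod.register_forward_hook(hook)

stats = {}
net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
attach_stat_hooks(net, stats)
x = torch.randn(1, 3, 32, 32)
net(x)          # __call__ path: hooks fire and stats is populated
net.forward(x)  # bypasses __call__: no hooks run, nothing is recorded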
@@ -75,31 +75,33 @@ class Switch(nn.Module):
         self.bias = nn.Parameter(torch.zeros(1))
 
         if not self.pass_chain_forward:
-            self.parameterize = ConvBnLelu(16, 16, bn=False, lelu=False)
-            self.c_constric = MultiConvBlock(48, 32, 16, kernel_size=5, depth=3, bn=False)
+            self.parameterize = ConvBnLelu(64, 64, bn=False, lelu=False)
+            self.c_constric = MultiConvBlock(576, 256, 64, kernel_size=1, depth=3, bn=False)
+            self.c_process = ConvBnLelu(64, 64, kernel_size=1, lelu=False, bn=False)
 
     # x is the input fed to the transform blocks.
     # m is the output of the multiplexer which will be used to select from those transform blocks.
     # chain is a chain of shared processing outputs used by the individual transforms.
     def forward(self, x, m, chain):
         if self.pass_chain_forward:
-            pcf = [t.forward(x, chain) for t in self.transforms]
+            pcf = [t(x, chain) for t in self.transforms]
             xformed = [o[0] for o in pcf]
             atts = [o[1] for o in pcf]
         else:
             # These adjustments were determined statistically from numeric_stability.py and should start this context
             # out in a normal distribution.
-            context = (chain[-1] - 6) / 9.4
-            context = F.pixel_shuffle(context, 4)
+            context = chain[-1]
+            context = F.interpolate(context, size=x.shape[2:], mode='nearest')
             context = torch.cat([self.parameterize(x), context], dim=1)
-            context = self.c_constric(context) / 1.6
+            context = self.c_constric(context) / 3
+            context = self.c_process(context)
             context = x * context
 
             if self.add_noise:
                 rand_feature = torch.randn_like(x)
-                xformed = [t.forward(context, rand_feature) for t in self.transforms]
+                xformed = [t(context, rand_feature) for t in self.transforms]
             else:
-                xformed = [t.forward(context) for t in self.transforms]
+                xformed = [t(context) for t in self.transforms]
 
         # Interpolate the multiplexer across the entire shape of the image.
         m = F.interpolate(m, size=x.shape[2:], mode='nearest')
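The hard-coded rescalings in this hunk and the ones below (/ 3, (x - .4) / .6, (x - 3.3) / 12.5) follow the pattern the comment describes: measure a layer's output statistics empirically, then bake the measured mean/std in as fixed constants so each stage starts out roughly normally distributed. A hedged sketch of that derivation (helper name assumed, not from the repo):

import torch

# Hypothetical helper: estimate (mean, std) of a layer's output on random
# input so it can be re-centered with fixed constants, in the spirit of the
# numeric_stability.py adjustments this commit spreads across layers.
@torch.no_grad()
def derive_norm_constants(layer, input_shape, n_batches=64):
    means, stds = [], []
    for _ in range(n_batches):
        out = layer(torch.randn(*input_shape))
        means.append(out.mean())
        stds.append(out.std())
    return torch.stack(means).mean().item(), torch.stack(stds).mean().item()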
@@ -139,9 +141,10 @@ class Processor(nn.Module):
         self.res_blocks = nn.ModuleList([FixupBottleneck(self.output_filter_count, self.output_filter_count // 4) for _ in range(processing_depth)])
 
     def forward(self, x):
-        x = self.initial(x)
+        x = (self.initial(x) - .4) / .6
         for b in self.res_blocks:
-            x = b(x) + x
+            r = (b(x) - .4) / .6
+            x = r + x
         return x
 
 
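Normalizing the residual branch before the skip-sum, as above, keeps variance growth across the block stack predictable; an unnormalized branch compounds it. A quick illustration (not repo code):

import torch

x = torch.randn(4, 64, 16, 16)
for depth in range(5):
    r = torch.randn_like(x)       # stand-in for a unit-variance residual branch
    x = r + x                     # variances add: Var(x) grows by ~1 per block
    print(depth, x.std().item())  # std climbs toward sqrt(depth + 2)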
@@ -160,7 +163,7 @@ class Constrictor(nn.Module):
     def forward(self, x):
         x = self.cbl1(x)
         x = self.cbl2(x)
-        x = self.cbl3(x)
+        x = self.cbl3(x) / 4
         return x
 
 
@@ -202,7 +205,7 @@ class NestedSwitchComputer(nn.Module):
             current_filters = processing_trunk[-1].output_filter_count
             filters.append(current_filters)
 
-        self.multiplexer_init_conv = nn.Conv2d(transform_filters, switch_base_filters, kernel_size=7, padding=3)
+        self.multiplexer_init_conv = ConvBnLelu(transform_filters, switch_base_filters, kernel_size=7, lelu=False, bn=False)
         self.processing_trunk = nn.ModuleList(processing_trunk)
         self.switch = RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
         self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, bn=False)
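ConvBnLelu is the repo's own wrapper; swapping the raw nn.Conv2d for it is presumably what lets the next hunk drop the external kaiming_normal_ call, since a block like this would own its initialization. A hypothetical sketch of such a wrapper (structure and init choice assumed, not taken from the repo):

import torch.nn as nn

# Hypothetical stand-in for the repo's ConvBnLelu: conv + optional BN +
# optional LeakyReLU that performs its own kaiming init, so callers need
# no extra init pass.
class ConvBnLeluSketch(nn.Module):
    def __init__(self, cin, cout, kernel_size=3, bn=True, lelu=True):
        super().__init__()
        self.conv = nn.Conv2d(cin, cout, kernel_size, padding=kernel_size // 2)
        self.bn = nn.BatchNorm2d(cout) if bn else None
        self.lelu = nn.LeakyReLU(0.2, inplace=True) if lelu else None
        nn.init.kaiming_normal_(self.conv.weight, nonlinearity='leaky_relu')

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        return self.lelu(x) if self.lelu is not None else x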
@@ -219,18 +222,17 @@ class NestedSwitchComputer(nn.Module):
             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
-        nn.init.kaiming_normal_(self.multiplexer_init_conv.weight, nonlinearity="relu")
 
     def forward(self, x):
         feed_forward = x
         trunk = []
         trunk_input = self.multiplexer_init_conv(x)
         for m in self.processing_trunk:
-            trunk_input = m.forward(trunk_input)
+            trunk_input = (m(trunk_input) - 3.3) / 12.5
             trunk.append(trunk_input)
 
-        self.trunk = (trunk[-1] - 6) / 9.4
-        x, att = self.switch.forward(x, trunk)
+        self.trunk = trunk[-1]
+        x, att = self.switch(x, trunk)
         x = x + feed_forward
         return feed_forward + self.anneal(x) / .86, att
@@ -263,21 +265,19 @@ class NestedSwitchedGenerator(nn.Module):
         self.upsample_factor = upsample_factor
 
     def forward(self, x):
-        # This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
-        # is called for, then repair.
-        if self.upsample_factor > 1:
-            x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
-
-        x = self.initial_conv(x)
+        x = self.initial_conv(x) / .2
 
         self.attentions = []
         for i, sw in enumerate(self.switches):
-            x, att = sw.forward(x)
+            x, att = sw(x)
             self.attentions.append(att)
 
+        if self.upsample_factor > 1:
+            x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
+
         x = self.proc_conv(x) / .85
-        x = self.final_conv(x)
-        return x / 4.26,
+        x = self.final_conv(x) / 4.6
+        return x / 16,
 
     def set_temperature(self, temp):
         [sw.set_temperature(temp) for sw in self.switches]
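The payoff of upsampling last, per the third commit bullet: the switch stack now runs at input resolution, and the spatial savings pay for the wider transforms (16 -> 64 filters earlier in this diff). Rough MAC arithmetic with illustrative sizes (not from the repo):

# A 1x1 conv costs roughly C^2 * H * W multiply-accumulates. Deferring a 2x
# upsample past the transforms shrinks H*W by 4x, which funds doubling C at
# equal cost.
def conv1x1_macs(c, h, w):
    return c * c * h * w

print(conv1x1_macs(16, 128, 128))  # upsample first: 16 ch at 128x128 -> 4194304
print(conv1x1_macs(32, 64, 64))    # upsample last:  32 ch at 64x64   -> 4194304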
@@ -336,10 +336,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         self.upsample_factor = upsample_factor
 
     def forward(self, x):
-        # This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
-        # is called for, then repair.
-        if self.upsample_factor > 1:
-            x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
 
         x = self.initial_conv(x)
 