Even more NSG improvements (r4)

This commit is contained in:
James Betker 2020-06-30 13:52:47 -06:00
parent 773753073f
commit 75f148022d
2 changed files with 14 additions and 13 deletions

View File

@@ -75,8 +75,8 @@ class Switch(nn.Module):
self.bias = nn.Parameter(torch.zeros(1))
if not self.pass_chain_forward:
self.c_constric = MultiConvBlock(32, 32, 16, 3, 3)
self.c_conjoin = ConvBnLelu(32, 16, kernel_size=1, bn=False)
self.parameterize = ConvBnLelu(16, 16, bn=False, lelu=False)
self.c_constric = MultiConvBlock(48, 32, 16, kernel_size=5, depth=3, bn=False)
# x is the input fed to the transform blocks.
# m is the output of the multiplexer which will be used to select from those transform blocks.
@@ -91,11 +91,9 @@ class Switch(nn.Module):
# out in a normal distribution.
context = (chain[-1] - 6) / 9.4
context = F.pixel_shuffle(context, 4)
context = self.c_constric(context)
context = F.interpolate(context, size=x.shape[2:], mode='nearest')
context = torch.cat([x, context], dim=1)
context = self.c_conjoin(context)
context = torch.cat([self.parameterize(x), context], dim=1)
context = self.c_constric(context) / 1.6
if self.add_noise:
rand_feature = torch.randn_like(x)
@@ -224,6 +222,7 @@ class NestedSwitchComputer(nn.Module):
nn.init.kaiming_normal_(self.multiplexer_init_conv.weight, nonlinearity="relu")
def forward(self, x):
feed_forward = x
trunk = []
trunk_input = self.multiplexer_init_conv(x)
for m in self.processing_trunk:
@@ -232,7 +231,8 @@ class NestedSwitchComputer(nn.Module):
self.trunk = (trunk[-1] - 6) / 9.4
x, att = self.switch.forward(x, trunk)
return self.anneal(x), att
x = x + feed_forward
return feed_forward + self.anneal(x) / .86, att
def set_temperature(self, temp):
self.switch.set_temperature(temp)
@@ -244,6 +244,7 @@ class NestedSwitchedGenerator(nn.Module):
heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
super(NestedSwitchedGenerator, self).__init__()
self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, lelu=False, bn=False)
self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False)
self.final_conv = ConvBnLelu(transformation_filters, 3, kernel_size=1, lelu=False, bn=False)
switches = []
@@ -271,12 +272,12 @@ class NestedSwitchedGenerator(nn.Module):
self.attentions = []
for i, sw in enumerate(self.switches):
sw_out, att = sw.forward(x)
x, att = sw.forward(x)
self.attentions.append(att)
x = x + sw_out
x = self.proc_conv(x) / .85
x = self.final_conv(x)
return x,
return x / 4.26,
def set_temperature(self, temp):
[sw.set_temperature(temp) for sw in self.switches]

View File

@@ -43,12 +43,12 @@ class ConvBnLelu(nn.Module):
class MultiConvBlock(nn.Module):
def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1):
def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False):
assert depth >= 2
super(MultiConvBlock, self).__init__()
self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=False)] +
[ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=False) for i in range(depth-2)] +
self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn)] +
[ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn) for i in range(depth-2)] +
[ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False)])
self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
self.bias = nn.Parameter(torch.zeros(1))