From 75f148022deff3969988a9eec43602fff412290a Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 30 Jun 2020 13:52:47 -0600 Subject: [PATCH] Even more NSG improvements (r4) --- codes/models/archs/NestedSwitchGenerator.py | 21 ++++++++++--------- .../archs/SwitchedResidualGenerator_arch.py | 6 +++--- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/codes/models/archs/NestedSwitchGenerator.py b/codes/models/archs/NestedSwitchGenerator.py index 4724e421..056ae513 100644 --- a/codes/models/archs/NestedSwitchGenerator.py +++ b/codes/models/archs/NestedSwitchGenerator.py @@ -75,8 +75,8 @@ class Switch(nn.Module): self.bias = nn.Parameter(torch.zeros(1)) if not self.pass_chain_forward: - self.c_constric = MultiConvBlock(32, 32, 16, 3, 3) - self.c_conjoin = ConvBnLelu(32, 16, kernel_size=1, bn=False) + self.parameterize = ConvBnLelu(16, 16, bn=False, lelu=False) + self.c_constric = MultiConvBlock(48, 32, 16, kernel_size=5, depth=3, bn=False) # x is the input fed to the transform blocks. # m is the output of the multiplexer which will be used to select from those transform blocks. @@ -91,11 +91,9 @@ class Switch(nn.Module): # out in a normal distribution. context = (chain[-1] - 6) / 9.4 context = F.pixel_shuffle(context, 4) - context = self.c_constric(context) - context = F.interpolate(context, size=x.shape[2:], mode='nearest') - context = torch.cat([x, context], dim=1) - context = self.c_conjoin(context) + context = torch.cat([self.parameterize(x), context], dim=1) + context = self.c_constric(context) / 1.6 if self.add_noise: rand_feature = torch.randn_like(x) @@ -224,6 +222,7 @@ class NestedSwitchComputer(nn.Module): nn.init.kaiming_normal_(self.multiplexer_init_conv.weight, nonlinearity="relu") def forward(self, x): + feed_forward = x trunk = [] trunk_input = self.multiplexer_init_conv(x) for m in self.processing_trunk: @@ -232,7 +231,8 @@ class NestedSwitchComputer(nn.Module): self.trunk = (trunk[-1] - 6) / 9.4 x, att = self.switch.forward(x, trunk) - return self.anneal(x), att + x = x + feed_forward + return feed_forward + self.anneal(x) / .86, att def set_temperature(self, temp): self.switch.set_temperature(temp) @@ -244,6 +244,7 @@ class NestedSwitchedGenerator(nn.Module): heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False): super(NestedSwitchedGenerator, self).__init__() self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, lelu=False, bn=False) + self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False) self.final_conv = ConvBnLelu(transformation_filters, 3, kernel_size=1, lelu=False, bn=False) switches = [] @@ -271,12 +272,12 @@ class NestedSwitchedGenerator(nn.Module): self.attentions = [] for i, sw in enumerate(self.switches): - sw_out, att = sw.forward(x) + x, att = sw.forward(x) self.attentions.append(att) - x = x + sw_out + x = self.proc_conv(x) / .85 x = self.final_conv(x) - return x, + return x / 4.26, def set_temperature(self, temp): [sw.set_temperature(temp) for sw in self.switches] diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index fbf52dad..63f36aac 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -43,12 +43,12 @@ class ConvBnLelu(nn.Module): class MultiConvBlock(nn.Module): - def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1): + def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False): assert depth >= 2 super(MultiConvBlock, self).__init__() self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01)) - self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=False)] + - [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=False) for i in range(depth-2)] + + self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn)] + + [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn) for i in range(depth-2)] + [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False)]) self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init)) self.bias = nn.Parameter(torch.zeros(1))