forked from mrq/DL-Art-School
Even more NSG improvements (r4)
parent 773753073f
commit 75f148022d
@@ -75,8 +75,8 @@ class Switch(nn.Module):
         self.bias = nn.Parameter(torch.zeros(1))
 
         if not self.pass_chain_forward:
-            self.c_constric = MultiConvBlock(32, 32, 16, 3, 3)
             self.c_conjoin = ConvBnLelu(32, 16, kernel_size=1, bn=False)
             self.parameterize = ConvBnLelu(16, 16, bn=False, lelu=False)
+            self.c_constric = MultiConvBlock(48, 32, 16, kernel_size=5, depth=3, bn=False)
 
     # x is the input fed to the transform blocks.
     # m is the output of the multiplexer which will be used to select from those transform blocks.
@@ -91,11 +91,9 @@ class Switch(nn.Module):
             # out in a normal distribution.
             context = (chain[-1] - 6) / 9.4
             context = F.pixel_shuffle(context, 4)
-            context = self.c_constric(context)
-
             context = F.interpolate(context, size=x.shape[2:], mode='nearest')
-            context = torch.cat([x, context], dim=1)
-            context = self.c_conjoin(context)
+            context = torch.cat([self.parameterize(x), context], dim=1)
+            context = self.c_constric(context) / 1.6
 
         if self.add_noise:
             rand_feature = torch.randn_like(x)
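For orientation, the sketch below traces the tensor shapes through the reworked context path, assuming chain[-1] carries 512 channels and x carries 16 (neither count appears in the diff); plain Conv2d layers stand in for self.parameterize and for the new 48-input c_constric MultiConvBlock, so this is illustrative rather than the repository's code.

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 16, 64, 64)            # stand-in for the transform input x
chain_last = torch.randn(1, 512, 16, 16)  # stand-in for chain[-1] from the parent switch

parameterize = nn.Conv2d(16, 16, kernel_size=1)           # stands in for self.parameterize
c_constric = nn.Conv2d(48, 16, kernel_size=5, padding=2)  # stands in for the 48-input MultiConvBlock

context = (chain_last - 6) / 9.4                                    # rough re-normalization
context = F.pixel_shuffle(context, 4)                               # 512ch @ 16x16 -> 32ch @ 64x64
context = F.interpolate(context, size=x.shape[2:], mode='nearest')  # force a spatial match with x
context = torch.cat([parameterize(x), context], dim=1)              # 16 + 32 = 48 channels
context = c_constric(context) / 1.6
print(context.shape)  # torch.Size([1, 16, 64, 64])

Under those assumed channel counts the concatenation produces 16 + 32 = 48 channels, which is what the MultiConvBlock(48, 32, 16, ...) declared in the first hunk expects.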
@@ -224,6 +222,7 @@ class NestedSwitchComputer(nn.Module):
         nn.init.kaiming_normal_(self.multiplexer_init_conv.weight, nonlinearity="relu")
 
     def forward(self, x):
+        feed_forward = x
         trunk = []
         trunk_input = self.multiplexer_init_conv(x)
         for m in self.processing_trunk:
@@ -232,7 +231,8 @@ class NestedSwitchComputer(nn.Module):
 
         self.trunk = (trunk[-1] - 6) / 9.4
         x, att = self.switch.forward(x, trunk)
-        return self.anneal(x), att
+        x = x + feed_forward
+        return feed_forward + self.anneal(x) / .86, att
 
     def set_temperature(self, temp):
         self.switch.set_temperature(temp)
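The two hunks above wrap the switch computation in a residual connection around the untouched input. A minimal sketch of that pattern, with dummy convolutions standing in for the nested switch and the anneal block; only the feed_forward bookkeeping and the .86 divisor come from the diff, and the attention output is omitted for brevity.

import torch
import torch.nn as nn

class ResidualComputer(nn.Module):
    def __init__(self, channels=16):
        super().__init__()
        self.body = nn.Conv2d(channels, channels, 3, padding=1)  # stands in for the nested switch
        self.anneal = nn.Conv2d(channels, channels, 1)           # stands in for self.anneal

    def forward(self, x):
        feed_forward = x                              # keep the untouched input around
        x = self.body(x)
        x = x + feed_forward                          # inner residual add, as in the diff
        return feed_forward + self.anneal(x) / .86    # outer residual around the anneal conv

out = ResidualComputer()(torch.randn(1, 16, 32, 32))
print(out.shape)  # torch.Size([1, 16, 32, 32])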
@@ -244,6 +244,7 @@ class NestedSwitchedGenerator(nn.Module):
                  heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
         super(NestedSwitchedGenerator, self).__init__()
         self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, lelu=False, bn=False)
+        self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False)
         self.final_conv = ConvBnLelu(transformation_filters, 3, kernel_size=1, lelu=False, bn=False)
 
         switches = []
@@ -271,12 +272,12 @@ class NestedSwitchedGenerator(nn.Module):
 
         self.attentions = []
         for i, sw in enumerate(self.switches):
-            sw_out, att = sw.forward(x)
+            x, att = sw.forward(x)
             self.attentions.append(att)
-            x = x + sw_out
 
+        x = self.proc_conv(x) / .85
         x = self.final_conv(x)
-        return x,
+        return x / 4.26,
 
     def set_temperature(self, temp):
         [sw.set_temperature(temp) for sw in self.switches]
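With the residual now applied inside each switch computer, the generator loop above simplifies to consuming (output, attention) pairs directly. A self-contained sketch of that loop with dummy switches; the /.85 and /4.26 scalings are copied from the diff, everything else is illustrative.

import torch
import torch.nn as nn

class DummySwitch(nn.Module):
    def __init__(self, channels=16):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        att = torch.softmax(x.mean(dim=1, keepdim=True), dim=-1)  # placeholder attention map
        return x + self.conv(x), att                              # residual applied inside the switch

switches = nn.ModuleList([DummySwitch() for _ in range(3)])
proc_conv = nn.Conv2d(16, 16, 3, padding=1)   # stands in for the new self.proc_conv
final_conv = nn.Conv2d(16, 3, 1)              # stands in for self.final_conv

x = torch.randn(1, 16, 32, 32)
attentions = []
for sw in switches:
    x, att = sw(x)          # each switch already returns its residual-added output
    attentions.append(att)
x = proc_conv(x) / .85
x = final_conv(x)
out = x / 4.26
print(out.shape, len(attentions))  # torch.Size([1, 3, 32, 32]) 3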
@@ -43,12 +43,12 @@ class ConvBnLelu(nn.Module):
 
 
 class MultiConvBlock(nn.Module):
-    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1):
+    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False):
         assert depth >= 2
         super(MultiConvBlock, self).__init__()
         self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
-        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=False)] +
-                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=False) for i in range(depth-2)] +
+        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn)] +
+                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn) for i in range(depth-2)] +
                                      [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False)])
         self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
         self.bias = nn.Parameter(torch.zeros(1))
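The last hunk threads a new bn flag from MultiConvBlock into its inner blocks so batch normalization can be toggled per block. The stand-ins below show how that flag propagates; the ConvBnLelu internals and the forward pass are assumptions for illustration, not the repository's implementation, and the noise_scale parameter is omitted.

import torch
import torch.nn as nn

class ConvBnLelu(nn.Module):
    def __init__(self, filters_in, filters_out, kernel_size=3, bn=True, lelu=True):
        super().__init__()
        layers = [nn.Conv2d(filters_in, filters_out, kernel_size, padding=kernel_size // 2)]
        if bn:
            layers.append(nn.BatchNorm2d(filters_out))  # toggled by the bn flag
        if lelu:
            layers.append(nn.LeakyReLU(.2))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class MultiConvBlock(nn.Module):
    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False):
        assert depth >= 2
        super().__init__()
        # The new bn kwarg is forwarded to every block except the final, un-normalized one.
        self.bnconvs = nn.ModuleList(
            [ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn)] +
            [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn) for _ in range(depth - 2)] +
            [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False)])
        self.scale = nn.Parameter(torch.full((1,), float(scale_init)))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        for m in self.bnconvs:
            x = m(x)
        return x * self.scale + self.bias  # assumed scale-and-shift; not shown in the diff

out = MultiConvBlock(48, 32, 16, kernel_size=5, depth=3, bn=False)(torch.randn(1, 48, 16, 16))
print(out.shape)  # torch.Size([1, 16, 16, 16])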