NSG rev 6

- Disable style passthrough
- Process multiplexers starting at base resolution
This commit is contained in:
James Betker 2020-06-30 20:47:26 -06:00
parent 3ce1a1878d
commit e07d8abafb

View File

@ -74,11 +74,6 @@ class Switch(nn.Module):
self.scale = nn.Parameter(torch.ones(1))
self.bias = nn.Parameter(torch.zeros(1))
if not self.pass_chain_forward:
self.parameterize = ConvBnLelu(64, 64, bn=False, lelu=False)
self.c_constric = MultiConvBlock(576, 256, 64, kernel_size=1, depth=3, bn=False)
self.c_process = ConvBnLelu(64, 64, kernel_size=1, lelu=False, bn=False)
# x is the input fed to the transform blocks.
# m is the output of the multiplexer which will be used to select from those transform blocks.
# chain is a chain of shared processing outputs used by the individual transforms.
@ -88,20 +83,11 @@ class Switch(nn.Module):
xformed = [o[0] for o in pcf]
atts = [o[1] for o in pcf]
else:
# These adjustments were determined statistically from numeric_stability.py and should start this context
# out in a normal distribution.
context = chain[-1]
context = F.interpolate(context, size=x.shape[2:], mode='nearest')
context = torch.cat([self.parameterize(x), context], dim=1)
context = self.c_constric(context) / 3
context = self.c_process(context)
context = x * context
if self.add_noise:
rand_feature = torch.randn_like(x)
xformed = [t(context, rand_feature) for t in self.transforms]
xformed = [t(x, rand_feature) for t in self.transforms]
else:
xformed = [t(context) for t in self.transforms]
xformed = [t(x) for t in self.transforms]
# Interpolate the multiplexer across the entire shape of the image.
m = F.interpolate(m, size=x.shape[2:], mode='nearest')
@ -128,16 +114,18 @@ class Switch(nn.Module):
class Processor(nn.Module):
def __init__(self, base_filters, processing_depth, reduce=False):
super(Processor, self).__init__()
self.output_filter_count = base_filters * 2
self.output_filter_count = base_filters * (2 if reduce else 1)
# Downsample block used for bottleneck.
if reduce:
downsample = nn.Sequential(
nn.Conv2d(base_filters, self.output_filter_count, kernel_size=1, stride=2, bias=False),
nn.BatchNorm2d(self.output_filter_count),
)
else:
downsample = None
# Bottleneck block outputs the requested filter sizex4, but we only want x2.
self.initial = FixupBottleneck(base_filters, base_filters // 2, stride=2 if reduce else 1, downsample=downsample)
self.initial = FixupBottleneck(base_filters, self.output_filter_count // 4, stride=2 if reduce else 1, downsample=downsample)
self.res_blocks = nn.ModuleList([FixupBottleneck(self.output_filter_count, self.output_filter_count // 4) for _ in range(processing_depth)])
def forward(self, x):
@ -200,10 +188,12 @@ class NestedSwitchComputer(nn.Module):
processing_trunk = []
filters = []
current_filters = switch_base_filters
reduce = False # Don't reduce the first layer, but reduce after that.
for _ in range(nesting_depth):
processing_trunk.append(Processor(current_filters, num_switch_processing_layers, reduce=True))
processing_trunk.append(Processor(current_filters, num_switch_processing_layers, reduce=reduce))
current_filters = processing_trunk[-1].output_filter_count
filters.append(current_filters)
reduce = True
self.multiplexer_init_conv = ConvBnLelu(transform_filters, switch_base_filters, kernel_size=7, lelu=False, bn=False)
self.processing_trunk = nn.ModuleList(processing_trunk)
@ -253,7 +243,7 @@ class NestedSwitchedGenerator(nn.Module):
for sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
switches.append(NestedSwitchComputer(transform_filters=transformation_filters, switch_base_filters=switch_filters, num_switch_processing_layers=sw_proc,
nesting_depth=sw_reduce, transforms_at_leaf=trans_count, trans_kernel_size=kernel, trans_num_layers=layers,
trans_scale_init=.2/len(switch_reductions), initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
trans_scale_init=.2, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
self.switches = nn.ModuleList(switches)
self.transformation_counts = trans_counts