NSG rev 6

- Disable style passthrough
- Process multiplexers starting at base resolution
This commit is contained in:
parent 3ce1a1878d
commit e07d8abafb
@@ -74,11 +74,6 @@ class Switch(nn.Module):
         self.scale = nn.Parameter(torch.ones(1))
         self.bias = nn.Parameter(torch.zeros(1))
 
-        if not self.pass_chain_forward:
-            self.parameterize = ConvBnLelu(64, 64, bn=False, lelu=False)
-            self.c_constric = MultiConvBlock(576, 256, 64, kernel_size=1, depth=3, bn=False)
-            self.c_process = ConvBnLelu(64, 64, kernel_size=1, lelu=False, bn=False)
-
     # x is the input fed to the transform blocks.
     # m is the output of the multiplexer which will be used to select from those transform blocks.
     # chain is a chain of shared processing outputs used by the individual transforms.
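A minimal sketch of the "style passthrough" path this commit removes, for reference. It is not the repository's code: ConvBnLelu and MultiConvBlock are stood in by plain 1x1 convolutions, and the 512-channel chain tensor is inferred from the 576-channel concatenation in the diff (576 - 64 = 512).

import torch
import torch.nn as nn
import torch.nn.functional as F

parameterize = nn.Conv2d(64, 64, 1)    # stand-in for ConvBnLelu(64, 64, bn=False, lelu=False)
c_constric = nn.Conv2d(576, 64, 1)     # stand-in for MultiConvBlock(576, 256, 64, ...)
c_process = nn.Conv2d(64, 64, 1)       # stand-in for ConvBnLelu(64, 64, kernel_size=1, ...)

x = torch.randn(1, 64, 16, 16)         # input to the transform blocks
chain = [torch.randn(1, 512, 8, 8)]    # shared processing outputs from earlier stages (channel count assumed)

context = chain[-1]
context = F.interpolate(context, size=x.shape[2:], mode='nearest')
context = torch.cat([parameterize(x), context], dim=1)  # 64 + 512 = 576 channels
context = c_constric(context) / 3                       # the statistically determined normalization from the diff
context = c_process(context)
out = x * context                                       # per-pixel "style" modulation of the input
print(out.shape)                                        # torch.Size([1, 64, 16, 16])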
@@ -88,20 +83,11 @@ class Switch(nn.Module):
             xformed = [o[0] for o in pcf]
             atts = [o[1] for o in pcf]
         else:
-            # These adjustments were determined statistically from numeric_stability.py and should start this context
-            # out in a normal distribution.
-            context = chain[-1]
-            context = F.interpolate(context, size=x.shape[2:], mode='nearest')
-            context = torch.cat([self.parameterize(x), context], dim=1)
-            context = self.c_constric(context) / 3
-            context = self.c_process(context)
-            context = x * context
-
             if self.add_noise:
                 rand_feature = torch.randn_like(x)
-                xformed = [t(context, rand_feature) for t in self.transforms]
+                xformed = [t(x, rand_feature) for t in self.transforms]
             else:
-                xformed = [t(context) for t in self.transforms]
+                xformed = [t(x) for t in self.transforms]
 
         # Interpolate the multiplexer across the entire shape of the image.
         m = F.interpolate(m, size=x.shape[2:], mode='nearest')
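With the passthrough gone, the transforms consume the raw input x directly. A hypothetical toy version of the remaining switch logic (not the repository's Switch, which also carries temperature and attention machinery): run every transform on x, then blend the results with softmax weights derived from the interpolated multiplexer output m.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySwitch(nn.Module):
    def __init__(self, channels=64, num_transforms=4):
        super().__init__()
        self.transforms = nn.ModuleList(
            [nn.Conv2d(channels, channels, 3, padding=1) for _ in range(num_transforms)])

    def forward(self, x, m):
        # m: (B, num_transforms, H', W') multiplexer logits, possibly at a coarser resolution.
        xformed = [t(x) for t in self.transforms]           # transforms act on x itself
        m = F.interpolate(m, size=x.shape[2:], mode='nearest')
        weights = torch.softmax(m, dim=1)                   # one selection map per transform
        stacked = torch.stack(xformed, dim=1)               # (B, T, C, H, W)
        return (stacked * weights.unsqueeze(2)).sum(dim=1)  # weighted blend -> (B, C, H, W)

x = torch.randn(1, 64, 16, 16)
m = torch.randn(1, 4, 8, 8)
print(ToySwitch()(x, m).shape)  # torch.Size([1, 64, 16, 16])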
@@ -128,16 +114,18 @@ class Switch(nn.Module):
 class Processor(nn.Module):
     def __init__(self, base_filters, processing_depth, reduce=False):
         super(Processor, self).__init__()
-        self.output_filter_count = base_filters * 2
+        self.output_filter_count = base_filters * (2 if reduce else 1)
 
         # Downsample block used for bottleneck.
+        if reduce:
             downsample = nn.Sequential(
                 nn.Conv2d(base_filters, self.output_filter_count, kernel_size=1, stride=2, bias=False),
                 nn.BatchNorm2d(self.output_filter_count),
             )
-        # Bottleneck block outputs the requested filter sizex4, but we only want x2.
-        self.initial = FixupBottleneck(base_filters, base_filters // 2, stride=2 if reduce else 1, downsample=downsample)
+        else:
+            downsample = None
+        self.initial = FixupBottleneck(base_filters, self.output_filter_count // 4, stride=2 if reduce else 1, downsample=downsample)
 
         self.res_blocks = nn.ModuleList([FixupBottleneck(self.output_filter_count, self.output_filter_count // 4) for _ in range(processing_depth)])
 
     def forward(self, x):
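The filter arithmetic here leans on the usual ResNet bottleneck convention: the block expands its "planes" argument by 4x on output. A hypothetical check of the width bookkeeping under that assumption:

def processor_output_filters(base_filters, reduce):
    output_filter_count = base_filters * (2 if reduce else 1)
    planes = output_filter_count // 4         # value passed to FixupBottleneck
    assert planes * 4 == output_filter_count  # 4x bottleneck expansion restores the target width
    return output_filter_count

print(processor_output_filters(64, reduce=False))  # 64: width (and resolution) preserved
print(processor_output_filters(64, reduce=True))   # 128: width doubled while downsampling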
@@ -200,10 +188,12 @@ class NestedSwitchComputer(nn.Module):
         processing_trunk = []
         filters = []
         current_filters = switch_base_filters
+        reduce = False  # Don't reduce the first layer, but reduce after that.
         for _ in range(nesting_depth):
-            processing_trunk.append(Processor(current_filters, num_switch_processing_layers, reduce=True))
+            processing_trunk.append(Processor(current_filters, num_switch_processing_layers, reduce=reduce))
             current_filters = processing_trunk[-1].output_filter_count
             filters.append(current_filters)
+            reduce = True
 
         self.multiplexer_init_conv = ConvBnLelu(transform_filters, switch_base_filters, kernel_size=7, lelu=False, bn=False)
         self.processing_trunk = nn.ModuleList(processing_trunk)
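This is the "process multiplexers starting at base resolution" half of the commit: the first trunk Processor runs with reduce=False, so the multiplexer sees the full-resolution features before any downsampling. A hypothetical trace of the resulting trunk widths (tracking only filter counts, not spatial sizes):

def trunk_filters(switch_base_filters, nesting_depth):
    filters, current, reduce = [], switch_base_filters, False
    for _ in range(nesting_depth):
        current = current * (2 if reduce else 1)  # mirrors Processor.output_filter_count
        filters.append(current)
        reduce = True  # every Processor after the first downsamples and doubles width
    return filters

print(trunk_filters(64, 4))  # [64, 128, 256, 512]: the first stage keeps the base width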
@@ -253,7 +243,7 @@ class NestedSwitchedGenerator(nn.Module):
         for sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
             switches.append(NestedSwitchComputer(transform_filters=transformation_filters, switch_base_filters=switch_filters, num_switch_processing_layers=sw_proc,
                                                  nesting_depth=sw_reduce, transforms_at_leaf=trans_count, trans_kernel_size=kernel, trans_num_layers=layers,
-                                                 trans_scale_init=.2/len(switch_reductions), initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
+                                                 trans_scale_init=.2, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
         self.switches = nn.ModuleList(switches)
 
         self.transformation_counts = trans_counts
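A quick before/after comparison of the initial transform scale set here. Previously the value shrank with the number of switch stages; it is now a constant (example stage list is hypothetical):

switch_reductions = [4, 3, 2]             # example configuration, three switch stages
old_scale = .2 / len(switch_reductions)   # 0.0667: shrank as stages were added
new_scale = .2                            # constant regardless of stage count
print(old_scale, new_scale)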