|
|
|
@ -23,12 +23,10 @@ class ReduceAnnealer(nn.Module):
|
|
|
|
|
self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
|
|
|
|
|
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
|
|
|
|
arch_util.initialize_weights([self.reducer, self.annealer], .1)
|
|
|
|
|
self.bn_reduce = nn.BatchNorm2d(number_filters*4, affine=True)
|
|
|
|
|
self.bn_anneal = nn.BatchNorm2d(number_filters*4, affine=True)
|
|
|
|
|
|
|
|
|
|
def forward(self, x, interpolated_trunk):
|
|
|
|
|
out = self.lrelu(self.bn_reduce(self.reducer(x)))
|
|
|
|
|
out = self.lrelu(self.bn_anneal(self.res_trunk(out)))
|
|
|
|
|
out = self.lrelu(self.reducer(x))
|
|
|
|
|
out = self.lrelu(self.res_trunk(out))
|
|
|
|
|
annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
|
|
|
|
|
return annealed, out
|
|
|
|
|
|
|
|
|
@ -43,13 +41,11 @@ class Assembler(nn.Module):
|
|
|
|
|
self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
|
|
|
|
|
self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
|
|
|
|
|
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
|
|
|
|
self.bn = nn.BatchNorm2d(number_filters*4, affine=True)
|
|
|
|
|
self.bn_up = nn.BatchNorm2d(number_filters*4, affine=True)
|
|
|
|
|
|
|
|
|
|
def forward(self, input, skip_raw):
|
|
|
|
|
out = self.pixel_shuffle(input)
|
|
|
|
|
out = self.bn_up(self.upsampler(out)) + skip_raw
|
|
|
|
|
out = self.lrelu(self.bn(self.res_trunk(out)))
|
|
|
|
|
out = self.upsampler(out) + skip_raw
|
|
|
|
|
out = self.lrelu(self.res_trunk(out))
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
class FlatProcessorNet(nn.Module):
|
|
|
|
@ -84,15 +80,10 @@ class FlatProcessorNet(nn.Module):
|
|
|
|
|
|
|
|
|
|
# Produce assemblers for all possible downscale variants. Some may not be used.
|
|
|
|
|
self.assembler1 = Assembler(nf, assembler_blocks)
|
|
|
|
|
self.assemble1_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
|
|
|
|
|
self.assembler2 = Assembler(nf, assembler_blocks)
|
|
|
|
|
self.assemble2_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
|
|
|
|
|
self.assembler3 = Assembler(nf, assembler_blocks)
|
|
|
|
|
self.assemble3_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
|
|
|
|
|
self.assembler4 = Assembler(nf, assembler_blocks)
|
|
|
|
|
self.assemble4_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
|
|
|
|
|
self.assemblers = [self.assembler1, self.assembler2, self.assembler3, self.assembler4]
|
|
|
|
|
self.assemble_convs = [self.assemble1_conv, self.assemble2_conv, self.assemble3_conv, self.assemble4_conv]
|
|
|
|
|
|
|
|
|
|
# Initialization
|
|
|
|
|
arch_util.initialize_weights([self.conv_first, self.conv_last], .1)
|
|
|
|
@ -113,10 +104,8 @@ class FlatProcessorNet(nn.Module):
|
|
|
|
|
raw_values.append(raw)
|
|
|
|
|
|
|
|
|
|
i = -1
|
|
|
|
|
scaled_outputs = {}
|
|
|
|
|
out = raw_values[-1]
|
|
|
|
|
while downsamples != self.downscale:
|
|
|
|
|
scaled_outputs[int(x.shape[-1] / downsamples)] = self.assemble_convs[i](out)
|
|
|
|
|
out = self.assemblers[i](out, raw_values[i-1])
|
|
|
|
|
i -= 1
|
|
|
|
|
downsamples = int(downsamples / 2)
|
|
|
|
@ -126,4 +115,4 @@ class FlatProcessorNet(nn.Module):
|
|
|
|
|
basis = x
|
|
|
|
|
if downsamples != 1:
|
|
|
|
|
basis = F.interpolate(x, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
|
|
|
|
|
return basis + out, scaled_outputs
|
|
|
|
|
return basis + out
|
|
|
|
|