Add FlatProcessorNet

After doing some thinking and reading on the subject, it occurred to me that
I was treating the generator like a discriminator by focusing the network
complexity at the feature levels. It makes far more sense to process each conv
level equally for the generator, hence the FlatProcessorNet in this commit. This
network borrows some of the residual pass-through logic from RRDB which makes
the gradient path exceptionally short for pretty much all model parameters and
can be trained in O1 optimization mode without overflows again.
pull/9/head
James Betker 2020-04-28 11:48:05 +07:00
parent b8f67418d4
commit 8ab595e427
3 changed files with 145 additions and 0 deletions

@ -0,0 +1,118 @@
import functools
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util
import torch
class ReduceAnnealer(nn.Module):
'''
Reduces an image dimensionality by half and performs a specified number of residual blocks on it before
`annealing` the filter count to the same as the input filter count.
To reduce depth, accepts an interpolated "trunk" input which is summed with the output of the RA block before
returning.
Returns a tuple in the forward pass. The first return is the annealed output. The second is the output before
annealing (e.g. number_filters=input*4) which can be be used for upsampling.
'''
def __init__(self, number_filters, residual_blocks):
super(ReduceAnnealer, self).__init__()
self.reducer = nn.Conv2d(number_filters, number_filters*4, 3, stride=2, padding=1, bias=True)
self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
arch_util.initialize_weights([self.reducer, self.annealer], .1)
def forward(self, x, interpolated_trunk):
out = self.lrelu(self.reducer(x))
out = self.lrelu(self.res_trunk(out))
annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
return annealed, out
class Assembler(nn.Module):
'''
Upsamples a given input using PixelShuffle. Then upsamples this input further and adds in a residual raw input from
a corresponding upstream ReduceAnnealer. Finally performs processing using ResNet blocks.
'''
def __init__(self, number_filters, residual_blocks):
super(Assembler, self).__init__()
self.pixel_shuffle = nn.PixelShuffle(2)
self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, input, skip_raw):
out = self.pixel_shuffle(input)
out = self.upsampler(out) + skip_raw
out = self.lrelu(self.res_trunk(out))
return out
class FlatProcessorNet(nn.Module):
'''
Specialized network that tries to perform a near-equal amount of processing on each of 5 downsampling steps. Image
is then upsampled to a specified size with a similarly flat amount of processing.
This network automatically applies a noise vector on the inputs to provide entropy for processing.
'''
def __init__(self, in_nc=3, out_nc=3, nf=64, reduce_anneal_blocks=4, assembler_blocks=2, downscale=4):
super(FlatProcessorNet, self).__init__()
assert downscale in [1, 2, 4], "Requested downscale not supported; %i" % (downscale, )
self.downscale = downscale
# We will always apply a noise channel to the inputs, account for that here.
in_nc += 1
# We need two layers to move the image into the filter space in which we will perform most of the work.
self.conv_first = nn.Conv2d(in_nc, nf, 3, stride=1, padding=1, bias=True)
self.conv_last = nn.Conv2d(nf*4, out_nc, 3, stride=1, padding=1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
# Torch modules need to have all submodules as explicit class members. So make those, then add them into an
# array for easier logic in forward().
self.ra1 = ReduceAnnealer(nf, reduce_anneal_blocks)
self.ra2 = ReduceAnnealer(nf, reduce_anneal_blocks)
self.ra3 = ReduceAnnealer(nf, reduce_anneal_blocks)
self.ra4 = ReduceAnnealer(nf, reduce_anneal_blocks)
self.ra5 = ReduceAnnealer(nf, reduce_anneal_blocks)
self.reducers = [self.ra1, self.ra2, self.ra3, self.ra4, self.ra5]
# Produce assemblers for all possible downscale variants. Some may not be used.
self.assembler1 = Assembler(nf, assembler_blocks)
self.assembler2 = Assembler(nf, assembler_blocks)
self.assembler3 = Assembler(nf, assembler_blocks)
self.assembler4 = Assembler(nf, assembler_blocks)
self.assemblers = [self.assembler1, self.assembler2, self.assembler3, self.assembler4]
# Initialization
arch_util.initialize_weights([self.conv_first, self.conv_last], .1)
def forward(self, x):
# Noise has the same shape as the input with only one channel.
rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype)
out = torch.cat([x, rand_feature], dim=1)
out = self.lrelu(self.conv_first(out))
features_trunk = out
raw_values = []
downsamples = 1
for ra in self.reducers:
downsamples *= 2
interpolated = F.interpolate(features_trunk, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
out, raw = ra(out, interpolated)
raw_values.append(raw)
i = -1
out = raw_values[-1]
while downsamples != self.downscale:
out = self.assemblers[i](out, raw_values[i-1])
i -= 1
downsamples = int(downsamples / 2)
out = self.conv_last(out)
basis = x
if downsamples != 1:
basis = F.interpolate(x, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
return basis + out

@ -30,6 +30,28 @@ def make_layer(block, n_layers):
layers.append(block())
return nn.Sequential(*layers)
class ResidualBlock(nn.Module):
'''Residual block with BN
---Conv-BN-ReLU-Conv-+-
|________________|
'''
def __init__(self, nf=64):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.BN1 = nn.BatchNorm2d(nf)
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.BN2 = nn.BatchNorm2d(nf)
# initialization
initialize_weights([self.conv1, self.conv2], 0.1)
def forward(self, x):
identity = x
out = F.relu(self.BN1(self.conv1(x)), inplace=True)
out = self.BN2(self.conv2(out))
return identity + out
class ResidualBlock_noBN(nn.Module):
'''Residual block w/o BN

@ -4,6 +4,7 @@ import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.RRDBNet_arch as RRDBNet_arch
import models.archs.EDVR_arch as EDVR_arch
import models.archs.HighToLowResNet as HighToLowResNet
import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
import math
# Generator
@ -25,6 +26,10 @@ def define_G(opt):
elif which_model == 'HighToLowResNet':
netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], downscale=opt_net['scale'])
elif which_model == 'FlatProcessorNet':
netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], downscale=opt_net['scale'], reduce_anneal_blocks=opt_net['ra_blocks'],
assembler_blocks=opt_net['assembler_blocks'])
# video restoration
elif which_model == 'EDVR':
netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],