Add FlatProcessorNet
After doing some thinking and reading on the subject, it occurred to me that I was treating the generator like a discriminator by focusing the network complexity at the feature levels. It makes far more sense to process each conv level equally for the generator, hence the FlatProcessorNet in this commit. This network borrows some of the residual pass-through logic from RRDB which makes the gradient path exceptionally short for pretty much all model parameters and can be trained in O1 optimization mode without overflows again.
This commit is contained in:
parent
b8f67418d4
commit
8ab595e427
118
codes/models/archs/FlatProcessorNet_arch.py
Normal file
118
codes/models/archs/FlatProcessorNet_arch.py
Normal file
|
@ -0,0 +1,118 @@
|
|||
import functools
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import models.archs.arch_util as arch_util
|
||||
import torch
|
||||
|
||||
class ReduceAnnealer(nn.Module):
|
||||
'''
|
||||
Reduces an image dimensionality by half and performs a specified number of residual blocks on it before
|
||||
`annealing` the filter count to the same as the input filter count.
|
||||
|
||||
To reduce depth, accepts an interpolated "trunk" input which is summed with the output of the RA block before
|
||||
returning.
|
||||
|
||||
Returns a tuple in the forward pass. The first return is the annealed output. The second is the output before
|
||||
annealing (e.g. number_filters=input*4) which can be be used for upsampling.
|
||||
'''
|
||||
|
||||
def __init__(self, number_filters, residual_blocks):
|
||||
super(ReduceAnnealer, self).__init__()
|
||||
self.reducer = nn.Conv2d(number_filters, number_filters*4, 3, stride=2, padding=1, bias=True)
|
||||
self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
|
||||
self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
arch_util.initialize_weights([self.reducer, self.annealer], .1)
|
||||
|
||||
def forward(self, x, interpolated_trunk):
|
||||
out = self.lrelu(self.reducer(x))
|
||||
out = self.lrelu(self.res_trunk(out))
|
||||
annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
|
||||
return annealed, out
|
||||
|
||||
class Assembler(nn.Module):
|
||||
'''
|
||||
Upsamples a given input using PixelShuffle. Then upsamples this input further and adds in a residual raw input from
|
||||
a corresponding upstream ReduceAnnealer. Finally performs processing using ResNet blocks.
|
||||
'''
|
||||
def __init__(self, number_filters, residual_blocks):
|
||||
super(Assembler, self).__init__()
|
||||
self.pixel_shuffle = nn.PixelShuffle(2)
|
||||
self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
|
||||
self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
|
||||
def forward(self, input, skip_raw):
|
||||
out = self.pixel_shuffle(input)
|
||||
out = self.upsampler(out) + skip_raw
|
||||
out = self.lrelu(self.res_trunk(out))
|
||||
return out
|
||||
|
||||
class FlatProcessorNet(nn.Module):
|
||||
'''
|
||||
Specialized network that tries to perform a near-equal amount of processing on each of 5 downsampling steps. Image
|
||||
is then upsampled to a specified size with a similarly flat amount of processing.
|
||||
|
||||
This network automatically applies a noise vector on the inputs to provide entropy for processing.
|
||||
'''
|
||||
def __init__(self, in_nc=3, out_nc=3, nf=64, reduce_anneal_blocks=4, assembler_blocks=2, downscale=4):
|
||||
super(FlatProcessorNet, self).__init__()
|
||||
|
||||
assert downscale in [1, 2, 4], "Requested downscale not supported; %i" % (downscale, )
|
||||
self.downscale = downscale
|
||||
|
||||
# We will always apply a noise channel to the inputs, account for that here.
|
||||
in_nc += 1
|
||||
|
||||
# We need two layers to move the image into the filter space in which we will perform most of the work.
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, stride=1, padding=1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf*4, out_nc, 3, stride=1, padding=1, bias=True)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
|
||||
# Torch modules need to have all submodules as explicit class members. So make those, then add them into an
|
||||
# array for easier logic in forward().
|
||||
self.ra1 = ReduceAnnealer(nf, reduce_anneal_blocks)
|
||||
self.ra2 = ReduceAnnealer(nf, reduce_anneal_blocks)
|
||||
self.ra3 = ReduceAnnealer(nf, reduce_anneal_blocks)
|
||||
self.ra4 = ReduceAnnealer(nf, reduce_anneal_blocks)
|
||||
self.ra5 = ReduceAnnealer(nf, reduce_anneal_blocks)
|
||||
self.reducers = [self.ra1, self.ra2, self.ra3, self.ra4, self.ra5]
|
||||
|
||||
# Produce assemblers for all possible downscale variants. Some may not be used.
|
||||
self.assembler1 = Assembler(nf, assembler_blocks)
|
||||
self.assembler2 = Assembler(nf, assembler_blocks)
|
||||
self.assembler3 = Assembler(nf, assembler_blocks)
|
||||
self.assembler4 = Assembler(nf, assembler_blocks)
|
||||
self.assemblers = [self.assembler1, self.assembler2, self.assembler3, self.assembler4]
|
||||
|
||||
# Initialization
|
||||
arch_util.initialize_weights([self.conv_first, self.conv_last], .1)
|
||||
|
||||
def forward(self, x):
|
||||
# Noise has the same shape as the input with only one channel.
|
||||
rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype)
|
||||
out = torch.cat([x, rand_feature], dim=1)
|
||||
|
||||
out = self.lrelu(self.conv_first(out))
|
||||
features_trunk = out
|
||||
raw_values = []
|
||||
downsamples = 1
|
||||
for ra in self.reducers:
|
||||
downsamples *= 2
|
||||
interpolated = F.interpolate(features_trunk, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
|
||||
out, raw = ra(out, interpolated)
|
||||
raw_values.append(raw)
|
||||
|
||||
i = -1
|
||||
out = raw_values[-1]
|
||||
while downsamples != self.downscale:
|
||||
out = self.assemblers[i](out, raw_values[i-1])
|
||||
i -= 1
|
||||
downsamples = int(downsamples / 2)
|
||||
|
||||
out = self.conv_last(out)
|
||||
|
||||
basis = x
|
||||
if downsamples != 1:
|
||||
basis = F.interpolate(x, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
|
||||
return basis + out
|
|
@ -30,6 +30,28 @@ def make_layer(block, n_layers):
|
|||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
'''Residual block with BN
|
||||
---Conv-BN-ReLU-Conv-+-
|
||||
|________________|
|
||||
'''
|
||||
|
||||
def __init__(self, nf=64):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.BN1 = nn.BatchNorm2d(nf)
|
||||
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.BN2 = nn.BatchNorm2d(nf)
|
||||
|
||||
# initialization
|
||||
initialize_weights([self.conv1, self.conv2], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
out = F.relu(self.BN1(self.conv1(x)), inplace=True)
|
||||
out = self.BN2(self.conv2(out))
|
||||
return identity + out
|
||||
|
||||
|
||||
class ResidualBlock_noBN(nn.Module):
|
||||
'''Residual block w/o BN
|
||||
|
|
|
@ -4,6 +4,7 @@ import models.archs.discriminator_vgg_arch as SRGAN_arch
|
|||
import models.archs.RRDBNet_arch as RRDBNet_arch
|
||||
import models.archs.EDVR_arch as EDVR_arch
|
||||
import models.archs.HighToLowResNet as HighToLowResNet
|
||||
import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
|
||||
import math
|
||||
|
||||
# Generator
|
||||
|
@ -25,6 +26,10 @@ def define_G(opt):
|
|||
elif which_model == 'HighToLowResNet':
|
||||
netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], downscale=opt_net['scale'])
|
||||
elif which_model == 'FlatProcessorNet':
|
||||
netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], downscale=opt_net['scale'], reduce_anneal_blocks=opt_net['ra_blocks'],
|
||||
assembler_blocks=opt_net['assembler_blocks'])
|
||||
# video restoration
|
||||
elif which_model == 'EDVR':
|
||||
netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
|
||||
|
|
Loading…
Reference in New Issue
Block a user