diff --git a/codes/models/archs/FlatProcessorNet_arch.py b/codes/models/archs/FlatProcessorNet_arch.py
new file mode 100644
index 00000000..2ce1b978
--- /dev/null
+++ b/codes/models/archs/FlatProcessorNet_arch.py
@@ -0,0 +1,118 @@
+import functools
+import torch.nn as nn
+import torch.nn.functional as F
+import models.archs.arch_util as arch_util
+import torch
+
+class ReduceAnnealer(nn.Module):
+    '''
+    Reduces an image dimensionality by half and performs a specified number of residual blocks on it before
+    `annealing` the filter count to the same as the input filter count.
+
+    To reduce depth, accepts an interpolated "trunk" input which is summed with the output of the RA block before
+    returning.
+
+    Returns a tuple in the forward pass. The first return is the annealed output. The second is the output before
+    annealing (e.g. number_filters=input*4) which can be be used for upsampling.
+    '''
+
+    def __init__(self, number_filters, residual_blocks):
+        super(ReduceAnnealer, self).__init__()
+        self.reducer = nn.Conv2d(number_filters, number_filters*4, 3, stride=2, padding=1, bias=True)
+        self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
+        self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        arch_util.initialize_weights([self.reducer, self.annealer], .1)
+
+    def forward(self, x, interpolated_trunk):
+        out = self.lrelu(self.reducer(x))
+        out = self.lrelu(self.res_trunk(out))
+        annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
+        return annealed, out
+
+class Assembler(nn.Module):
+    '''
+    Upsamples a given input using PixelShuffle. Then upsamples this input further and adds in a residual raw input from
+    a corresponding upstream ReduceAnnealer. Finally performs processing using ResNet blocks.
+    '''
+    def __init__(self, number_filters, residual_blocks):
+        super(Assembler, self).__init__()
+        self.pixel_shuffle = nn.PixelShuffle(2)
+        self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
+        self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+    def forward(self, input, skip_raw):
+        out = self.pixel_shuffle(input)
+        out = self.upsampler(out) + skip_raw
+        out = self.lrelu(self.res_trunk(out))
+        return out
+
+class FlatProcessorNet(nn.Module):
+    '''
+    Specialized network that tries to perform a near-equal amount of processing on each of 5 downsampling steps. Image
+    is then upsampled to a specified size with a similarly flat amount of processing.
+
+    This network automatically applies a noise vector on the inputs to provide entropy for processing.
+     '''
+    def __init__(self, in_nc=3, out_nc=3, nf=64, reduce_anneal_blocks=4, assembler_blocks=2, downscale=4):
+        super(FlatProcessorNet, self).__init__()
+
+        assert downscale in [1, 2, 4], "Requested downscale not supported; %i" % (downscale, )
+        self.downscale = downscale
+
+        # We will always apply a noise channel to the inputs, account for that here.
+        in_nc += 1
+
+        # We need two layers to move the image into the filter space in which we will perform most of the work.
+        self.conv_first = nn.Conv2d(in_nc, nf, 3, stride=1, padding=1, bias=True)
+        self.conv_last = nn.Conv2d(nf*4, out_nc, 3, stride=1, padding=1, bias=True)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+        # Torch modules need to have all submodules as explicit class members. So make those, then add them into an
+        # array for easier logic in forward().
+        self.ra1 = ReduceAnnealer(nf, reduce_anneal_blocks)
+        self.ra2 = ReduceAnnealer(nf, reduce_anneal_blocks)
+        self.ra3 = ReduceAnnealer(nf, reduce_anneal_blocks)
+        self.ra4 = ReduceAnnealer(nf, reduce_anneal_blocks)
+        self.ra5 = ReduceAnnealer(nf, reduce_anneal_blocks)
+        self.reducers = [self.ra1, self.ra2, self.ra3, self.ra4, self.ra5]
+
+        # Produce assemblers for all possible downscale variants. Some may not be used.
+        self.assembler1 = Assembler(nf, assembler_blocks)
+        self.assembler2 = Assembler(nf, assembler_blocks)
+        self.assembler3 = Assembler(nf, assembler_blocks)
+        self.assembler4 = Assembler(nf, assembler_blocks)
+        self.assemblers = [self.assembler1, self.assembler2, self.assembler3, self.assembler4]
+
+        # Initialization
+        arch_util.initialize_weights([self.conv_first, self.conv_last], .1)
+
+    def forward(self, x):
+        # Noise has the same shape as the input with only one channel.
+        rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype)
+        out = torch.cat([x, rand_feature], dim=1)
+
+        out = self.lrelu(self.conv_first(out))
+        features_trunk = out
+        raw_values = []
+        downsamples = 1
+        for ra in self.reducers:
+            downsamples *= 2
+            interpolated = F.interpolate(features_trunk, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
+            out, raw = ra(out, interpolated)
+            raw_values.append(raw)
+
+        i = -1
+        out = raw_values[-1]
+        while downsamples != self.downscale:
+            out = self.assemblers[i](out, raw_values[i-1])
+            i -= 1
+            downsamples = int(downsamples / 2)
+
+        out = self.conv_last(out)
+
+        basis = x
+        if downsamples != 1:
+            basis = F.interpolate(x, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
+        return basis + out
diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index ca5d7fa9..e2b4a0b9 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -30,6 +30,28 @@ def make_layer(block, n_layers):
         layers.append(block())
     return nn.Sequential(*layers)
 
+class ResidualBlock(nn.Module):
+    '''Residual block with BN
+    ---Conv-BN-ReLU-Conv-+-
+     |________________|
+    '''
+
+    def __init__(self, nf=64):
+        super(ResidualBlock, self).__init__()
+        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.BN1 = nn.BatchNorm2d(nf)
+        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.BN2 = nn.BatchNorm2d(nf)
+
+        # initialization
+        initialize_weights([self.conv1, self.conv2], 0.1)
+
+    def forward(self, x):
+        identity = x
+        out = F.relu(self.BN1(self.conv1(x)), inplace=True)
+        out = self.BN2(self.conv2(out))
+        return identity + out
+
 
 class ResidualBlock_noBN(nn.Module):
     '''Residual block w/o BN
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 1b7563dc..dfb9ad32 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -4,6 +4,7 @@ import models.archs.discriminator_vgg_arch as SRGAN_arch
 import models.archs.RRDBNet_arch as RRDBNet_arch
 import models.archs.EDVR_arch as EDVR_arch
 import models.archs.HighToLowResNet as HighToLowResNet
+import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
 import math
 
 # Generator
@@ -25,6 +26,10 @@ def define_G(opt):
     elif which_model == 'HighToLowResNet':
         netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                 nf=opt_net['nf'], nb=opt_net['nb'], downscale=opt_net['scale'])
+    elif which_model == 'FlatProcessorNet':
+        netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
+                                nf=opt_net['nf'], downscale=opt_net['scale'], reduce_anneal_blocks=opt_net['ra_blocks'],
+                                assembler_blocks=opt_net['assembler_blocks'])
     # video restoration
     elif which_model == 'EDVR':
         netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],