diff --git a/codes/data/GTLQ_dataset.py b/codes/data/Downsample_dataset.py
similarity index 56%
rename from codes/data/GTLQ_dataset.py
rename to codes/data/Downsample_dataset.py
index c9a47a4c..efe07242 100644
--- a/codes/data/GTLQ_dataset.py
+++ b/codes/data/Downsample_dataset.py
@@ -7,14 +7,14 @@ import torch.utils.data as data
 import data.util as util
 
 
-class GTLQDataset(data.Dataset):
+class DownsampleDataset(data.Dataset):
     """
-    Reads unpaired high-resolution and low resolution images. Downsampled, LR images matching the provided high res
-    images are produced and fed to the downstream model, which can be used in a pixel loss.
+    Reads an unpaired HQ and LQ image. Crops both images to the input sizes the model expects. Also produces a
+    downsampled version of the HQ image and feeds that alongside, for use in a pixel loss.
     """
 
     def __init__(self, opt):
-        super(GTLQDataset, self).__init__()
+        super(DownsampleDataset, self).__init__()
         self.opt = opt
         self.data_type = self.opt['data_type']
         self.paths_LQ, self.paths_GT = None, None
@@ -23,8 +23,11 @@ class GTLQDataset(data.Dataset):
         self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
         self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
 
+        self.data_sz_mismatch_ok = opt['mismatched_Data_OK']
         assert self.paths_GT, 'Error: GT path is empty.'
-        if self.paths_LQ and self.paths_GT:
+        assert self.paths_LQ, 'LQ is required for downsampling.'
+        if not self.data_sz_mismatch_ok:
             assert len(self.paths_LQ) == len(
                 self.paths_GT
             ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
@@ -41,9 +44,8 @@ class GTLQDataset(data.Dataset):
     def __getitem__(self, index):
         if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
             self._init_lmdb()
-        GT_path, LQ_path = None, None
         scale = self.opt['scale']
-        GT_size = self.opt['target_size']
+        GT_size = self.opt['target_size'] * scale
 
         # get GT image
         GT_path = self.paths_GT[index]
@@ -56,43 +58,19 @@ class GTLQDataset(data.Dataset):
             img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
 
         # get LQ image
-        if self.paths_LQ:
-            LQ_path = self.paths_LQ[index]
-            resolution = [int(s) for s in self.sizes_LQ[index].split('_')
-                          ] if self.data_type == 'lmdb' else None
-            img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
-        else:  # down-sampling on-the-fly
-            # randomly scale during training
-            if self.opt['phase'] == 'train':
-                random_scale = random.choice(self.random_scale_list)
-                H_s, W_s, _ = img_GT.shape
+        lqind = index % len(self.paths_LQ)
+        LQ_path = self.paths_LQ[lqind]
+        resolution = [int(s) for s in self.sizes_LQ[lqind].split('_')
+                      ] if self.data_type == 'lmdb' else None
+        img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
 
-                def _mod(n, random_scale, scale, thres):
-                    rlt = int(n * random_scale)
-                    rlt = (rlt // scale) * scale
-                    return thres if rlt < thres else rlt
-
-                H_s = _mod(H_s, random_scale, scale, GT_size)
-                W_s = _mod(W_s, random_scale, scale, GT_size)
-                img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
-                if img_GT.ndim == 2:
-                    img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
-
-            H, W, _ = img_GT.shape
-            # using matlab imresize
-            img_LQ = util.imresize_np(img_GT, 1 / scale, True)
-            if img_LQ.ndim == 2:
-                img_LQ = np.expand_dims(img_LQ, axis=2)
+        # Create a downsampled version of the HQ image using matlab imresize.
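+        # (imresize_np is assumed to be the MATLAB-style bicubic resize this codebase inherits from BasicSR;
+        # its optional third argument toggles antialiasing and defaults to True, so it is omitted here.)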
+        img_Downsampled = util.imresize_np(img_GT, 1 / scale)
+        assert img_Downsampled.ndim == 3
 
         if self.opt['phase'] == 'train':
-            # if the image size is too small
             H, W, _ = img_GT.shape
-            if H < GT_size or W < GT_size:
-                img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
-                # using matlab imresize
-                img_LQ = util.imresize_np(img_GT, 1 / scale, True)
-                if img_LQ.ndim == 2:
-                    img_LQ = np.expand_dims(img_LQ, axis=2)
+            assert H >= GT_size and W >= GT_size
 
             H, W, C = img_LQ.shape
             LQ_size = GT_size // scale
@@ -101,27 +79,35 @@ class GTLQDataset(data.Dataset):
             rnd_h = random.randint(0, max(0, H - LQ_size))
             rnd_w = random.randint(0, max(0, W - LQ_size))
             img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
+            img_Downsampled = img_Downsampled[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
             rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
             img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
 
             # augmentation - flip, rotate
-            img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
-                                          self.opt['use_rot'])
+            img_LQ, img_GT, img_Downsampled = util.augment([img_LQ, img_GT, img_Downsampled],
+                                                           self.opt['use_flip'], self.opt['use_rot'])
 
         if self.opt['color']:  # change color space if necessary
-            img_LQ = util.channel_convert(C, self.opt['color'],
-                                          [img_LQ])[0]  # TODO during val no definition
+            img_Downsampled = util.channel_convert(C, self.opt['color'],
+                                                   [img_Downsampled])[0]  # TODO during val no definition
 
         # BGR to RGB, HWC to CHW, numpy to tensor
         if img_GT.shape[2] == 3:
             img_GT = img_GT[:, :, [2, 1, 0]]
             img_LQ = img_LQ[:, :, [2, 1, 0]]
+            img_Downsampled = img_Downsampled[:, :, [2, 1, 0]]
         img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
+        img_Downsampled = torch.from_numpy(np.ascontiguousarray(np.transpose(img_Downsampled, (2, 0, 1)))).float()
         img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
 
-        if LQ_path is None:
-            LQ_path = GT_path
-        return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
+        # This may seem really messed up, but let me explain:
+        # The goal is to re-use existing code as much as possible. SRGAN_model was coded to upsample, not downsample,
+        # but it can be retrofitted. To do so, we need to "trick" it. In this case the "input" is the HQ image and the
+        # "output" is the LQ image. SRGAN_model will be using a Generator and a Discriminator which already know this;
+        # we just need to trick its logic into following these rules.
+        # Do this by setting LQ (which is the input into the models) = img_GT and GT (the expected output) = img_LQ.
+        # PIX is used as a reference for the pixel loss. Use the manually downsampled image for this.
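+        # To summarize the swap below (the key names are the ones SRGAN_model already consumes):
+        #   'LQ'  -> img_GT           (network input: the HQ image)
+        #   'GT'  -> img_LQ           (discriminator's "real" sample: an actual LQ image)
+        #   'PIX' -> img_Downsampled  (pixel-loss reference: the bicubic-downsampled HQ image)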
+        return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}
 
     def __len__(self):
         return len(self.paths_GT)
diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index a1419a89..30c4f7aa 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -55,6 +55,9 @@ class SRGANModel(BaseModel):
             else:
                 raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
             self.l_fea_w = train_opt['feature_weight']
+            self.l_fea_w_decay = train_opt['feature_weight_decay']
+            self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps']
+            self.l_fea_w_minimum = train_opt['feature_weight_minimum']
         else:
             logger.info('Remove feature loss.')
             self.cri_fea = None
@@ -143,13 +146,6 @@ class SRGANModel(BaseModel):
             self.pix = data['PIX'].to(self.device)
 
     def optimize_parameters(self, step):
-
-        if step % 50 == 0:
-            for i in range(self.var_L.shape[0]):
-                utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i)))
-
         # G
         for p in self.netD.parameters():
             p.requires_grad = False
@@ -157,6 +153,13 @@ class SRGANModel(BaseModel):
         self.optimizer_G.zero_grad()
         self.fake_H = self.netG(self.var_L)
 
+        if step % 50 == 0:
+            for i in range(self.var_L.shape[0]):
+                utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\hr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.fake_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\gen", "%05i_%02i.png" % (step, i)))
+
         l_g_total = 0
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
             if self.cri_pix:  # pixel loss
@@ -168,6 +171,11 @@ class SRGANModel(BaseModel):
                 l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                 l_g_total += l_g_fea
 
+                # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
+                # in the resultant image.
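+                # (Worked example with assumed config values: feature_weight=1.0, feature_weight_decay=0.9
+                # and feature_weight_decay_steps=500 give 1.0 * 0.9**10 ~= 0.35 by step 5000; the weight
+                # never drops below feature_weight_minimum.)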
+                if step % self.l_fea_w_decay_steps == 0:
+                    self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
+
             if self.opt['train']['gan_type'] == 'gan':
                 pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
@@ -193,7 +201,8 @@ class SRGANModel(BaseModel):
         # real
         pred_d_real = self.netD(self.var_ref)
         l_d_real = self.cri_gan(pred_d_real, True)
-        l_d_real.backward()
+        with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
+            l_d_real_scaled.backward()
         # fake
         pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
         l_d_fake = self.cri_gan(pred_d_fake, False)
@@ -222,12 +231,13 @@ class SRGANModel(BaseModel):
         if self.cri_pix:
             self.log_dict['l_g_pix'] = l_g_pix.item()
         if self.cri_fea:
+            self.log_dict['feature_weight'] = self.l_fea_w
             self.log_dict['l_g_fea'] = l_g_fea.item()
         self.log_dict['l_g_gan'] = l_g_gan.item()
-        self.log_dict['l_g_total'] = l_g_total.item()
-        self.log_dict['l_d_real'] = l_d_real.item()
-        self.log_dict['l_d_fake'] = l_d_fake.item()
-        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
+        self.log_dict['l_g_total'] = l_g_total.item()
+        self.log_dict['l_d_real'] = l_d_real.item()
+        self.log_dict['l_d_fake'] = l_d_fake.item()
+        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
 
     def test(self):
         self.netG.eval()
diff --git a/codes/models/__init__.py b/codes/models/__init__.py
index c95004c9..0767eeb3 100644
--- a/codes/models/__init__.py
+++ b/codes/models/__init__.py
@@ -7,7 +7,7 @@ def create_model(opt):
     # image restoration
     if model == 'sr':  # PSNR-oriented super resolution
        from .SR_model import SRModel as M
-    elif model == 'srgan':  # GAN-based super resolution, SRGAN / ESRGAN
+    elif model == 'srgan' or model == 'corruptgan':  # GAN-based super resolution (SRGAN / ESRGAN); corruption models use the same logic
        from .SRGAN_model import SRGANModel as M
     # video restoration
     elif model == 'video_base':
diff --git a/codes/models/archs/HighToLowResNet.py b/codes/models/archs/HighToLowResNet.py
new file mode 100644
index 00000000..470359f2
--- /dev/null
+++ b/codes/models/archs/HighToLowResNet.py
@@ -0,0 +1,63 @@
+import functools
+import torch.nn as nn
+import torch.nn.functional as F
+import models.archs.arch_util as arch_util
+import torch
+
+
+class HighToLowResNet(nn.Module):
+    ''' ResNet that applies a noise channel to the input, then downsamples it. Currently only downscale=4 is supported. '''
+
+    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, downscale=4):
+        super(HighToLowResNet, self).__init__()
+        self.downscale = downscale
+
+        # We will always apply a noise channel to the inputs; account for that here.
+        in_nc += 1
+
+        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+        basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
+        basic_block2 = functools.partial(arch_util.ResidualBlock_noBN, nf=nf*2)
+        # To keep the total model size down, the residual trunks will be applied across 3 downsampling stages.
+        # The first will be applied against the hi-res inputs and will have only 4 layers.
+        # The second will be applied after half of the downscaling and will have only 6 layers.
+        # The final will be applied at the fully downscaled resolution and will have all of the remaining layers.
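+        # (With the default nb=16, that split works out to 4 + 6 + (nb - 10) = 4 + 6 + 6 residual blocks.)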
+        self.trunk_hires = arch_util.make_layer(basic_block, 4)
+        self.trunk_medres = arch_util.make_layer(basic_block, 6)
+        self.trunk_lores = arch_util.make_layer(basic_block2, nb - 10)
+
+        # downsampling
+        if self.downscale == 4:
+            self.downconv1 = nn.Conv2d(nf, nf, 3, stride=2, padding=1, bias=True)
+            self.downconv2 = nn.Conv2d(nf, nf*2, 3, stride=2, padding=1, bias=True)
+        else:
+            raise ValueError("Requested downscale not supported: %i" % (downscale,))
+
+        self.HRconv = nn.Conv2d(nf*2, nf*2, 3, stride=1, padding=1, bias=True)
+        self.conv_last = nn.Conv2d(nf*2, out_nc, 3, stride=1, padding=1, bias=True)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+        # initialization
+        arch_util.initialize_weights([self.conv_first, self.HRconv, self.conv_last, self.downconv1, self.downconv2],
+                                     0.1)
+
+    def forward(self, x):
+        # Noise has the same shape as the input, with only one channel.
+        rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device)
+        out = torch.cat([x, rand_feature], dim=1)
+
+        out = self.lrelu(self.conv_first(out))
+        out = self.trunk_hires(out)
+
+        if self.downscale == 4:
+            out = self.lrelu(self.downconv1(out))
+            out = self.trunk_medres(out)
+            out = self.lrelu(self.downconv2(out))
+            out = self.trunk_lores(out)
+
+        out = self.conv_last(self.lrelu(self.HRconv(out)))
+        base = F.interpolate(x, scale_factor=1/self.downscale, mode='bilinear', align_corners=False)
+        out += base
+        return out
diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index ae51ba16..10a3ccdc 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -32,7 +32,7 @@ class Discriminator_VGG_128(nn.Module):
         self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
         self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
 
-        self.linear1 = nn.Linear(int(512 * 4 * input_img_factor * 4 * input_img_factor), 100)
+        self.linear1 = nn.Linear(int(nf * 8 * 4 * input_img_factor * 4 * input_img_factor), 100)
         self.linear2 = nn.Linear(100, 1)
 
         # activation function
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 2b79249b..1b7563dc 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -3,6 +3,7 @@ import models.archs.SRResNet_arch as SRResNet_arch
 import models.archs.discriminator_vgg_arch as SRGAN_arch
 import models.archs.RRDBNet_arch as RRDBNet_arch
 import models.archs.EDVR_arch as EDVR_arch
+import models.archs.HighToLowResNet as HighToLowResNet
 import math
 
 # Generator
@@ -20,6 +21,10 @@ def define_G(opt):
             scale_per_step = math.sqrt(scale)
             netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                         nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step)
+    # image corruption
+    elif which_model == 'HighToLowResNet':
+        netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
+                                               nf=opt_net['nf'], nb=opt_net['nb'], downscale=opt_net['scale'])
     # video restoration
     elif which_model == 'EDVR':
         netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
diff --git a/codes/options/options.py b/codes/options/options.py
index 99181b34..5dc34b11 100644
--- a/codes/options/options.py
+++ b/codes/options/options.py
@@ -15,14 +15,14 @@ def parse(opt_path, is_train=True):
     print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
 
     opt['is_train'] = is_train
-    if opt['distortion'] == 'sr':
+    if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
         scale = opt['scale']
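+    # Note: 'downsample' intentionally rides along with the existing 'sr' scale plumbing here and below;
+    # the dataset and generator are what invert the direction of the rescale.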
 
     # datasets
     for phase, dataset in opt['datasets'].items():
         phase = phase.split('_')[0]
         dataset['phase'] = phase
-        if opt['distortion'] == 'sr':
+        if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
             dataset['scale'] = scale
         is_lmdb = False
         if dataset.get('dataroot_GT', None) is not None:
@@ -62,7 +62,7 @@ def parse(opt_path, is_train=True):
         opt['path']['log'] = results_root
 
     # network
-    if opt['distortion'] == 'sr':
+    if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
         opt['network_G']['scale'] = scale
 
     return opt
diff --git a/codes/train.py b/codes/train.py
index 8721d33f..2906463a 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -29,7 +29,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_ESRGAN_blacked.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_corruptGAN_adrianna.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -176,7 +176,7 @@ def main():
                 logger.info(message)
             #### validation
             if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
-                if opt['model'] in ['sr', 'srgan'] and rank <= 0:  # image restoration validation
+                if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0:  # image restoration validation
                     # does not support multi-GPU validation
                     pbar = util.ProgressBar(len(val_loader))
                     avg_psnr = 0.