diff --git a/codes/data_scripts/validate_data.py b/codes/data_scripts/validate_data.py
new file mode 100644
index 00000000..ac7684c0
--- /dev/null
+++ b/codes/data_scripts/validate_data.py
@@ -0,0 +1,66 @@
+# This script iterates through all the data with no worker threads, applying whatever transformations the dataset config prescribes.
+# Its purpose is to surface bad/corrupt images.
+
+import math
+import argparse
+import random
+import torch
+import options.options as option
+from utils import util
+from data import create_dataloader, create_dataset
+from time import time
+from tqdm import tqdm
+from skimage import io
+
+def main():
+    #### options
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_mi1_spsr_switched2.yml')
+    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
+                        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
+    args = parser.parse_args()
+    opt = option.parse(args.opt, is_train=True)
+
+    #### distributed training settings
+    opt['dist'] = False
+    rank = -1
+
+    # convert to NoneDict, which returns None for missing keys
+    opt = option.dict_to_nonedict(opt)
+
+    #### random seed
+    seed = opt['train']['manual_seed']
+    if seed is None:
+        seed = random.randint(1, 10000)
+    util.set_random_seed(seed)
+
+    torch.backends.cudnn.benchmark = True
+    # torch.backends.cudnn.deterministic = True
+
+    #### create train and val dataloader
+    for phase, dataset_opt in opt['datasets'].items():
+        if phase == 'train':
+            train_set = create_dataset(dataset_opt)
+            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
+            total_iters = int(opt['train']['niter'])
+            total_epochs = int(math.ceil(total_iters / train_size))
+            dataset_opt['n_workers'] = 0  # Force num_workers=0 so the dataloader runs in-process.
+            train_loader = create_dataloader(train_set, dataset_opt, opt, None)
+            if rank <= 0:
+                print('Number of train images: {:,d}, iters: {:,d}'.format(
+                    len(train_set), train_size))
+    assert train_loader is not None
+
+    tq_ldr = tqdm(train_set.paths_GT)
+    for path in tq_ldr:
+        try:
+            img = io.imread(path)
+            # Do stuff with img
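+            # A minimal sanity check could go here (illustrative only; assumes the
+            # dataset expects 3-channel images):
+            # if img.ndim != 3 or img.shape[2] != 3:
+            #     print("Unexpected shape for %s: %s" % (path, img.shape))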
+        except Exception as e:
+            print("Error with %s" % (path,))
+            print(e)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index fef8c85a..4e2ec4c0 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -1,22 +1,18 @@
 import logging
-from collections import OrderedDict
-import torch
-import torch.nn as nn
-from torch.nn.parallel import DataParallel, DistributedDataParallel
-import models.networks as networks
-from models.steps.steps import create_step
-import models.lr_scheduler as lr_scheduler
-from models.base_model import BaseModel
-from models.loss import GANLoss, FDPLLoss
-from apex import amp
-from data.weight_scheduler import get_scheduler_for_opt
-from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding
-import torch.nn.functional as F
-import glob
-import random
-
-import torchvision.utils as utils
 import os
+import random
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+import torchvision.utils as utils
+from apex import amp
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+import models.lr_scheduler as lr_scheduler
+import models.networks as networks
+from models.base_model import BaseModel
+from models.steps.steps import ConfigurableStep
 
 logger = logging.getLogger('base')
 
@@ -31,15 +27,20 @@ class ExtensibleTrainer(BaseModel):
         train_opt = opt['train']
         self.mega_batch_factor = 1
 
+        # env acts as shared global state for anything subcomponents might need
+        # (device, rank, options; the networks are added to it below once they are built).
+        env = {'device': self.device,
+               'rank': self.rank,
+               'opt': opt}
+
         self.netsG = {}
         self.netsD = {}
         self.networks = []
         for name, net in opt['networks'].items():
             if net['type'] == 'generator':
-                new_net = networks.define_G(net)
+                new_net = networks.define_G(net, None, opt['scale']).to(self.device)
                 self.netsG[name] = new_net
             elif net['type'] == 'discriminator':
-                new_net = networks.define_D(net)
+                new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device)
                 self.netsD[name] = new_net
             else:
                 raise NotImplementedError("Can only handle generators and discriminators")
@@ -51,7 +52,7 @@ class ExtensibleTrainer(BaseModel):
                 self.mega_batch_factor = 1
 
             # Initialize amp.
-            amp_nets, amp_opts = amp.initialize(self.networks, self.optimizers, opt_level=opt['amp_level'], num_losses=len(self.optimizers))
+            amp_nets, amp_opts = amp.initialize(self.networks, self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
             # self.networks is stored unwrapped. It should never be used for forward() or backward() passes, instead use
             # self.netG and self.netD for that.
             self.networks = amp_nets
@@ -76,15 +77,18 @@ class ExtensibleTrainer(BaseModel):
             for dnet in dnets:
                 for net_dict in [self.netsD, self.netsG]:
                     for k, v in net_dict.items():
-                        if v == dnet:
+                        if v == dnet.module:
                             net_dict[k] = dnet
                             found += 1
             assert found == len(self.networks)
 
+            env['generators'] = self.netsG
+            env['discriminators'] = self.netsD
+
             # Initialize the training steps
             self.steps = []
-            for step in opt['steps']:
-                step = create_step(step, self.netsG, self.netsD)
+            for step_name, step in opt['steps'].items():
+                step = ConfigurableStep(step, env)
                 self.steps.append(step)
                 self.optimizers.extend(step.get_optimizers())
 
@@ -113,8 +117,8 @@ class ExtensibleTrainer(BaseModel):
                 net.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
 
         # Iterate through the steps, performing them one at a time.
-        state = {'lr': self.var_L, 'hr': self.var_H, 'ref': self.var_ref}
-        for s in self.steps:
+        state = {'lq': self.var_L, 'hq': self.var_H, 'ref': self.var_ref}
+        for step_num, s in enumerate(self.steps):
             # Only set requires_grad=True for the network being trained.
             nets_to_train = s.get_networks_trained()
             for name, net in self.networks.items():
@@ -126,8 +130,20 @@ class ExtensibleTrainer(BaseModel):
                         p.requires_grad = False
 
             # Now do a forward and backward pass for each gradient accumulation step.
+            new_states = {}
             for m in range(self.mega_batch_factor):
-                state = s.do_forward_backward(state, m)
+                ns = s.do_forward_backward(state, m, step_num)
+                for k, v in ns.items():
+                    if k not in new_states.keys():
+                        new_states[k] = [v.detach()]
+                    else:
+                        new_states[k].append(v.detach())
+
+            # Push the detached new state tensors into the state map for use with the next step.
+            for k, v in new_states.items():
+                # Overwriting existing state keys is not supported.
+                assert k not in state.keys()
+                state[k] = v
 
             # And finally perform optimization.
             s.do_step()
diff --git a/codes/models/__init__.py b/codes/models/__init__.py
index 4cb3264a..26f0e1fa 100644
--- a/codes/models/__init__.py
+++ b/codes/models/__init__.py
@@ -13,6 +13,8 @@ def create_model(opt):
         from .feature_model import FeatureModel as M
     elif model == 'spsr':
         from .SPSR_model import SPSRModel as M
+    elif model == 'extensibletrainer':
+        from .ExtensibleTrainer import ExtensibleTrainer as M
     else:
         raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
     m = M(opt)
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 6cfd2089..dc761243 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -17,10 +17,14 @@ import functools
 from collections import OrderedDict
 
 # Generator
-def define_G(opt, net_key='network_G'):
-    opt_net = opt[net_key]
+def define_G(opt, net_key='network_G', scale=None):
+    if net_key is not None:
+        opt_net = opt[net_key]
+    else:
+        opt_net = opt
+    if scale is None:
+        scale = opt['scale']
     which_model = opt_net['which_model_G']
-    scale = opt['scale']
 
     # image restoration
     if which_model == 'MSRResNet':
diff --git a/codes/models/steps/__init__.py b/codes/models/steps/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py
new file mode 100644
index 00000000..f5954087
--- /dev/null
+++ b/codes/models/steps/injectors.py
@@ -0,0 +1,32 @@
+import torch.nn
+from models.archs.SPSR_arch import ImageGradientNoPadding
+
+# Injectors are a way to synthesize data within a step that can then be used (and reused) by loss functions.
+def create_injector(opt_inject, env):
+    type = opt_inject['type']
+    if type == 'img_grad':
+        return ImageGradientInjector(opt_inject, env)
+    else:
+        raise NotImplementedError
+
+
+class Injector(torch.nn.Module):
+    def __init__(self, opt, env):
+        super(Injector, self).__init__()
+        self.opt = opt
+        self.env = env
+        self.input = opt['in']
+        self.output = opt['out']
+
+    # This should return a dict of new state variables.
+    def forward(self, state):
+        raise NotImplementedError
+
+
+class ImageGradientInjector(Injector):
+    def __init__(self, opt, env):
+        super(ImageGradientInjector, self).__init__(opt, env)
+        self.img_grad_fn = ImageGradientNoPadding()
+
+    def forward(self, state):
+        return {self.opt['out']: self.img_grad_fn(state[self.opt['in']])}
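+
+
+# For reference, a hypothetical 'img_grad' entry under a step's 'injectors' section in
+# the options YAML (key names follow the Injector base class above; 'hq'/'hq_grad' are
+# illustrative state keys, not prescribed names):
+#   injectors:
+#     hq_grad_inj:
+#       type: img_grad
+#       in: hq
+#       out: hq_grad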
\ No newline at end of file
diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
new file mode 100644
index 00000000..3f80978d
--- /dev/null
+++ b/codes/models/steps/losses.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+from models.networks import define_F
+from models.loss import GANLoss
+
+
+def create_generator_loss(opt_loss, env):
+    type = opt_loss['type']
+    if type == 'pix':
+        return PixLoss(opt_loss, env)
+    elif type == 'feature':
+        return FeatureLoss(opt_loss, env)
+    elif type == 'generator_gan':
+        return GeneratorGanLoss(opt_loss, env)
+    elif type == 'discriminator_gan':
+        return DiscriminatorGanLoss(opt_loss, env)
+    else:
+        raise NotImplementedError
+
+
+class ConfigurableLoss(nn.Module):
+    def __init__(self, opt, env):
+        super(ConfigurableLoss, self).__init__()
+        self.opt = opt
+        self.env = env
+
+    def forward(self, net, state):
+        raise NotImplementedError
+
+
+def get_basic_criterion_for_name(name, device):
+    if name == 'l1':
+        return nn.L1Loss().to(device)
+    elif name == 'l2':
+        return nn.MSELoss().to(device)
+    else:
+        raise NotImplementedError
+
+
+class PixLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(PixLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
+
+    def forward(self, net, state):
+        return self.criterion(state[self.opt['fake']], state[self.opt['real']])
+
+
+class FeatureLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(FeatureLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
+        self.netF = define_F(opt).to(self.env['device'])
+
+    def forward(self, net, state):
+        with torch.no_grad():
+            logits_real = self.netF(state[self.opt['real']])
+        # The fake pass stays differentiable so this loss can backprop into the generator.
+        logits_fake = self.netF(state[self.opt['fake']])
+        return self.criterion(logits_fake, logits_real)
+
+
+class GeneratorGanLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(GeneratorGanLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
+        self.netD = env['discriminators'][opt['discriminator']]
+
+    def forward(self, net, state):
+        if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
+            if self.opt['gan_type'] == 'crossgan':
+                pred_g_fake = self.netD(state[self.opt['fake']], state['lq'])
+            else:
+                pred_g_fake = self.netD(state[self.opt['fake']])
+            return self.criterion(pred_g_fake, True)
+        elif self.opt['gan_type'] == 'ragan':
+            pred_d_real = self.netD(state[self.opt['real']]).detach()
+            pred_g_fake = self.netD(state[self.opt['fake']])
+            return (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
+                    self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+        else:
+            raise NotImplementedError
+
+
+class DiscriminatorGanLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(DiscriminatorGanLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
+
+    def forward(self, net, state):
+        if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
+            if self.opt['gan_type'] == 'crossgan':
+                pred_g_fake = net(state[self.opt['fake']].detach(), state['lq'])
+            else:
+                pred_g_fake = net(state[self.opt['fake']].detach())
+            return self.criterion(pred_g_fake, False)
+        elif self.opt['gan_type'] == 'ragan':
+            pred_d_real = net(state[self.opt['real']])
+            pred_g_fake = net(state[self.opt['fake']].detach())
+            return (self.criterion(pred_d_real - torch.mean(pred_g_fake), True) +
+                    self.criterion(pred_g_fake - torch.mean(pred_d_real), False)) / 2
+        else:
+            raise NotImplementedError
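+
+
+# For reference, a hypothetical 'losses' section of a step in the options YAML. 'fake'
+# and 'real' name keys in the step's state dict ('gen_img' is an illustrative generator
+# output; 'hq' comes from the trainer), and 'weight' is applied by ConfigurableStep:
+#   losses:
+#     pix_loss:
+#       type: pix
+#       criterion: l1
+#       weight: 1.0
+#       fake: gen_img
+#       real: hq
+#     gen_gan_loss:
+#       type: generator_gan
+#       gan_type: gan
+#       weight: !!float 5e-3
+#       discriminator: base   # illustrative discriminator name from opt['networks']
+#       fake: gen_img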
diff --git a/codes/models/steps/losses/generator_losses.py b/codes/models/steps/losses/generator_losses.py
deleted file mode 100644
index 5ae088f7..00000000
--- a/codes/models/steps/losses/generator_losses.py
+++ /dev/null
@@ -1,9 +0,0 @@
-def create_generator_loss(opt_loss):
-    pass
-
-
-class GeneratorLoss:
-    def __init__(self, opt):
-        self.opt = opt
-
-    def get_loss(self, var_L, var_H, var_Gen, extras=None):
\ No newline at end of file
diff --git a/codes/models/steps/srgan_generator_step.py b/codes/models/steps/srgan_generator_step.py
deleted file mode 100644
index 4d7b58ca..00000000
--- a/codes/models/steps/srgan_generator_step.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Defines the expected API for a step
-class SrGanGeneratorStep:
-
-    def __init__(self, opt_step, opt, netsG, netsD):
-        self.step_opt = opt_step
-        self.opt = opt
-        self.gen = netsG['base']
-        self.disc = netsD['base']
-        for loss in self.step_opt['losses']:
-
-        # G pixel loss
-        if train_opt['pixel_weight'] > 0:
-            l_pix_type = train_opt['pixel_criterion']
-            if l_pix_type == 'l1':
-                self.cri_pix = nn.L1Loss().to(self.device)
-            elif l_pix_type == 'l2':
-                self.cri_pix = nn.MSELoss().to(self.device)
-            else:
-                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
-            self.l_pix_w = train_opt['pixel_weight']
-        else:
-            logger.info('Remove pixel loss.')
-            self.cri_pix = None
-
-
-    # Returns all optimizers used in this step.
-    def get_optimizers(self):
-        pass
-
-    # Returns optimizers which are opting in for default LR scheduling.
-    def get_optimizers_with_default_scheduler(self):
-        pass
-
-    # Returns the names of the networks this step will train. Other networks will be frozen.
-    def get_networks_trained(self):
-        pass
-
-    # Performs all forward and backward passes for this step given an input state. All input states are lists or
-    # chunked tensors. Use grad_accum_step to derefernce these steps. Return the state with any variables the step
-    # exports (which may be used by subsequent steps)
-    def do_forward_backward(self, state, grad_accum_step):
-        return state
-
-    # Performs the optimizer step after all gradient accumulation is completed.
-    def do_step(self):
-        pass
\ No newline at end of file
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index bcd6f2a2..e71d3164 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -1,29 +1,117 @@
+from utils.loss_accumulator import LossAccumulator
+from torch.nn import Module
+import logging
+from models.steps.losses import create_generator_loss
+import torch
+from apex import amp
+from collections import OrderedDict
+from .injectors import create_injector
+
+logger = logging.getLogger('base')
 
 
-def create_step(opt, opt_step, netsG, netsD):
-    pass
+# Defines the expected API for a single training step
+class ConfigurableStep(Module):
 
+    def __init__(self, opt_step, env):
+        super(ConfigurableStep, self).__init__()
+
+        self.step_opt = opt_step
+        self.env = env
+        self.opt = env['opt']
+        self.gen = env['generators'][opt_step['generator']]
+        self.discs = env['discriminators']
+        self.gen_outputs = opt_step['generator_outputs']
+        if opt_step['training'] in env['generators'].keys():
+            self.training_net = env['generators'][opt_step['training']]
+        else:
+            self.training_net = env['discriminators'][opt_step['training']]
+        self.loss_accumulator = LossAccumulator()
+
+        self.injectors = []
+        if 'injectors' in self.step_opt.keys():
+            for inj_name, injector in self.step_opt['injectors'].items():
+                self.injectors.append(create_injector(injector, env))
+
+        losses = []
+        self.weights = {}
+        for loss_name, loss in self.step_opt['losses'].items():
+            losses.append((loss_name, create_generator_loss(loss, env)))
+            self.weights[loss_name] = loss['weight']
+        self.losses = OrderedDict(losses)
+
+        # Intentionally abstract so subclasses can have alternative optimizers.
+        self.define_optimizers()
+
+    # Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
+    #  This default implementation defines a single Adam optimizer over all trainable parameters of the network being trained.
+    def define_optimizers(self):
+        optim_params = []
+        for k, v in self.training_net.named_parameters():  # can optimize for a part of the model
+            if v.requires_grad:
+                optim_params.append(v)
+            else:
+                if self.env['rank'] <= 0:
+                    logger.warning('Params [{:s}] will not optimize.'.format(k))
+        opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'],
+                               weight_decay=self.step_opt['weight_decay'],
+                               betas=(self.step_opt['beta1'], self.step_opt['beta2']))
+        self.optimizers = [opt]
 
-# Defines the expected API for a step
-class base_step:
     # Returns all optimizers used in this step.
     def get_optimizers(self):
-        pass
+        assert self.optimizers is not None
+        return self.optimizers
 
     # Returns optimizers which are opting in for default LR scheduling.
     def get_optimizers_with_default_scheduler(self):
-        pass
+        assert self.optimizers is not None
+        return self.optimizers
 
     # Returns the names of the networks this step will train. Other networks will be frozen.
     def get_networks_trained(self):
-        pass
+        return [self.step_opt['training']]
 
-    # Performs all forward and backward passes for this step given an input state. All input states are lists or
-    # chunked tensors. Use grad_accum_step to derefernce these steps. Return the state with any variables the step
-    # exports (which may be used by subsequent steps)
-    def do_forward_backward(self, state, grad_accum_step):
-        return state
+    # Performs all forward and backward passes for this step given an input state. All input states are lists of
+    # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
+    # steps might use. These tensors are automatically detached and accumulated into chunks.
+    def do_forward_backward(self, state, grad_accum_step, amp_loss_id):
+        # First, do a forward pass with the generator.
+        results = self.gen(state[self.step_opt['generator_input']][grad_accum_step])
+        # Extract the results into a "new_state" dict per the configuration.
+        new_state = {}
+        for i, gen_out in enumerate(self.gen_outputs):
+            new_state[gen_out] = results[i]
 
-    # Performs the optimizer step after all gradient accumulation is completed.
+        # Prepare a de-chunked state dict which will be used for the injectors & losses.
+        local_state = {}
+        for k, v in state.items():
+            local_state[k] = v[grad_accum_step]
+        local_state.update(new_state)
+
+        # Inject in any extra dependencies.
+        for inj in self.injectors:
+            injected = inj(local_state)
+            local_state.update(injected)
+            new_state.update(injected)
+
+        # Finally, compute the losses.
+        total_loss = 0
+        for loss_name, loss in self.losses.items():
+            l = loss(self.training_net, local_state)
+            self.loss_accumulator.add_loss(loss_name, l)
+            total_loss += l * self.weights[loss_name]
+        self.loss_accumulator.add_loss("total", total_loss)
+
+        # Get dem grads!
+        with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
+            scaled_loss.backward()
+
+        return new_state
+
+
+    # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
+    # all self.optimizers.
     def do_step(self):
-        pass
\ No newline at end of file
+        for opt in self.optimizers:
+            opt.step()
+
+    def get_metrics(self):
+        return self.loss_accumulator.as_dict()
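+
+
+# For reference, a hypothetical step definition in the options YAML covering the keys
+# ConfigurableStep reads above (network names, state keys and hyperparameters are
+# illustrative, not prescriptive):
+#   steps:
+#     generator:
+#       training: net_g
+#       generator: net_g
+#       generator_input: lq
+#       generator_outputs: [gen_img]
+#       lr: !!float 1e-4
+#       weight_decay: 0
+#       beta1: 0.9
+#       beta2: 0.99
+#       injectors: { ... }   # see injectors.py
+#       losses: { ... }      # see losses.py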
\ No newline at end of file
diff --git a/codes/train2.py b/codes/train2.py
new file mode 100644
index 00000000..5f24dc15
--- /dev/null
+++ b/codes/train2.py
@@ -0,0 +1,289 @@
+import os
+import math
+import argparse
+import random
+import logging
+import shutil
+from tqdm import tqdm
+
+import torch
+from data.data_sampler import DistIterSampler
+
+import options.options as option
+from utils import util
+from data import create_dataloader, create_dataset
+from models import create_model
+from time import time
+
+
+def init_dist(backend='nccl', **kwargs):
+    # These packages have globals that screw with Windows, so only import them if needed.
+    import torch.distributed as dist
+    import torch.multiprocessing as mp
+
+    """initialization for distributed training"""
+    if mp.get_start_method(allow_none=True) != 'spawn':
+        mp.set_start_method('spawn')
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
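+# Typical multi-GPU usage would go through the PyTorch launcher (illustrative):
+#   python -m torch.distributed.launch --nproc_per_node=2 train2.py -opt <config>.yml --launcher pytorch
+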
+def main():
+    #### options
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_nt_spsr_switched.yml')
+    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
+                        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
+    args = parser.parse_args()
+    opt = option.parse(args.opt, is_train=True)
+
+    colab_mode = False if 'colab_mode' not in opt.keys() else opt['colab_mode']
+    if colab_mode:
+        # Check the configuration of the remote server. Expect models, resume_state, and val_images directories to be there.
+        # Each one should have a TEST file in it.
+        util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
+                                   os.path.join(opt['remote_path'], 'training_state', "TEST"))
+        util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
+                                   os.path.join(opt['remote_path'], 'models', "TEST"))
+        util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
+                                   os.path.join(opt['remote_path'], 'val_images', "TEST"))
+        # Load the state and models needed from the remote server.
+        if opt['path']['resume_state']:
+            util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'training_state', opt['path']['resume_state']))
+        if opt['path']['pretrain_model_G']:
+            util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_G']))
+        if opt['path']['pretrain_model_D']:
+            util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_D']))
+
+    #### distributed training settings
+    if args.launcher == 'none':  # disabled distributed training
+        opt['dist'] = False
+        rank = -1
+        print('Disabled distributed training.')
+    else:
+        opt['dist'] = True
+        init_dist()
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+
+    #### loading resume state if exists
+    if opt['path'].get('resume_state', None):
+        # distributed resuming: all load into default GPU
+        device_id = torch.cuda.current_device()
+        resume_state = torch.load(opt['path']['resume_state'],
+                                  map_location=lambda storage, loc: storage.cuda(device_id))
+        option.check_resume(opt, resume_state['iter'])  # check resume options
+    else:
+        resume_state = None
+
+    #### mkdir and loggers
+    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
+        if resume_state is None:
+            util.mkdir_and_rename(
+                opt['path']['experiments_root'])  # rename experiment folder if exists
+            util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
+                         and 'pretrain_model' not in key and 'resume' not in key))
+
+        # config loggers. Before it, the log will not work
+        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
+                          screen=True, tofile=True)
+        logger = logging.getLogger('base')
+        logger.info(option.dict2str(opt))
+        # tensorboard logger
+        if opt['use_tb_logger'] and 'debug' not in opt['name']:
+            tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
+            version = float(torch.__version__[0:3])
+            if version >= 1.1:  # PyTorch 1.1
+                from torch.utils.tensorboard import SummaryWriter
+            else:
+                logger.info(
+                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
+                from tensorboardX import SummaryWriter
+            tb_logger = SummaryWriter(log_dir=tb_logger_path)
+    else:
+        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
+        logger = logging.getLogger('base')
+
+    # convert to NoneDict, which returns None for missing keys
+    opt = option.dict_to_nonedict(opt)
+
+    #### random seed
+    seed = opt['train']['manual_seed']
+    if seed is None:
+        seed = random.randint(1, 10000)
+    if rank <= 0:
+        logger.info('Random seed: {}'.format(seed))
+    util.set_random_seed(seed)
+
+    torch.backends.cudnn.benchmark = True
+    # torch.backends.cudnn.deterministic = True
+
+    #### create train and val dataloader
+    dataset_ratio = 200  # enlarge the size of each epoch
+    for phase, dataset_opt in opt['datasets'].items():
+        if phase == 'train':
+            train_set = create_dataset(dataset_opt)
+            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
+            total_iters = int(opt['train']['niter'])
+            total_epochs = int(math.ceil(total_iters / train_size))
+            if opt['dist']:
+                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
+                total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
+            else:
+                train_sampler = None
+            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
+            if rank <= 0:
+                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
+                    len(train_set), train_size))
+                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
+                    total_epochs, total_iters))
+        elif phase == 'val':
+            val_set = create_dataset(dataset_opt)
+            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
+            if rank <= 0:
+                logger.info('Number of val images in [{:s}]: {:d}'.format(
+                    dataset_opt['name'], len(val_set)))
+        else:
+            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
+    assert train_loader is not None
+
+    #### create model
+    model = create_model(opt)
+
+    #### resume training
+    if resume_state:
+        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
+            resume_state['epoch'], resume_state['iter']))
+
+        start_epoch = resume_state['epoch']
+        current_step = resume_state['iter']
+        model.resume_training(resume_state)  # handle optimizers and schedulers
+    else:
+        current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
+        start_epoch = 0
+
+    #### training
+    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
+    for epoch in range(start_epoch, total_epochs + 1):
+        if opt['dist']:
+            train_sampler.set_epoch(epoch)
+        tq_ldr = tqdm(train_loader)
+
+        _t = time()
+        _profile = False
+        for _, train_data in enumerate(tq_ldr):
+            if _profile:
+                print("Data fetch: %f" % (time() - _t))
+                _t = time()
+
+            current_step += 1
+            if current_step > total_iters:
+                break
+            #### update learning rate
+            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
+
+            #### training
+            if _profile:
+                print("Update LR: %f" % (time() - _t))
+                _t = time()
+            model.feed_data(train_data)
+            model.optimize_parameters(current_step)
+            if _profile:
+                print("Model feed + step: %f" % (time() - _t))
+                _t = time()
+
+            #### log
+            if current_step % opt['logger']['print_freq'] == 0:
+                logs = model.get_current_log(current_step)
+                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
+                for v in model.get_current_learning_rate():
+                    message += '{:.3e},'.format(v)
+                message += ')] '
+                for k, v in logs.items():
+                    if 'histogram' in k:
+                        if rank <= 0:
+                            tb_logger.add_histogram(k, v, current_step)
+                    else:
+                        message += '{:s}: {:.4e} '.format(k, v)
+                        # tensorboard logger
+                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
+                            if rank <= 0:
+                                tb_logger.add_scalar(k, v, current_step)
+                if rank <= 0:
+                    logger.info(message)
+            #### validation
+            if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
+                if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0:  # image restoration validation
+                    model.force_restore_swapout()
+                    val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
+                    # does not support multi-GPU validation
+                    pbar = util.ProgressBar(len(val_loader) * val_batch_sz)
+                    avg_psnr = 0.
+                    avg_fea_loss = 0.
+                    idx = 0
+                    colab_imgs_to_copy = []
+                    for val_data in val_loader:
+                        idx += 1
+                        for b in range(len(val_data['LQ_path'])):
+                            img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
+                            img_dir = os.path.join(opt['path']['val_images'], img_name)
+                            util.mkdir(img_dir)
+
+                            model.feed_data(val_data)
+                            model.test()
+
+                            visuals = model.get_current_visuals()
+                            if visuals is None:
+                                continue
+
+                            sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
+                            #gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+
+                            # Save SR images for reference
+                            img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
+                            save_img_path = os.path.join(img_dir, img_base_name)
+                            util.save_img(sr_img, save_img_path)
+                            if colab_mode:
+                                colab_imgs_to_copy.append(save_img_path)
+
+                            # calculate PSNR (Naw - don't do that. PSNR sucks)
+                            #sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+                            #avg_psnr += util.calculate_psnr(sr_img, gt_img)
+                            #pbar.update('Test {}'.format(img_name))
+
+                            # calculate fea loss
+                            avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+
+                    if colab_mode:
+                        util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
+                                                  colab_imgs_to_copy,
+                                                  os.path.join(opt['remote_path'], 'val_images', img_base_name))
+
+                    avg_psnr = avg_psnr / idx
+                    avg_fea_loss = avg_fea_loss / idx
+
+                    # log
+                    logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
+                    # tensorboard logger
+                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
+                        #tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
+                        tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
+
+            #### save models and training states
+            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
+                if rank <= 0:
+                    logger.info('Saving models and training states.')
+                    model.save(current_step)
+                    model.save_training_state(epoch, current_step)
+
+    if rank <= 0:
+        logger.info('Saving the final model.')
+        model.save('latest')
+        logger.info('End of training.')
+        tb_logger.close()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/codes/utils/loss_accumulator.py b/codes/utils/loss_accumulator.py
new file mode 100644
index 00000000..1f0e151a
--- /dev/null
+++ b/codes/utils/loss_accumulator.py
@@ -0,0 +1,20 @@
+import torch
+
+# Utility class that stores detached, named losses in a rotating buffer for smoothed metric reporting.
+class LossAccumulator:
+    def __init__(self, buffer_sz=10):
+        self.buffer_sz = buffer_sz
+        self.buffers = {}
+
+    def add_loss(self, name, tensor):
+        if name not in self.buffers.keys():
+            self.buffers[name] = (0, torch.zeros(self.buffer_sz))
+        i, buf = self.buffers[name]
+        buf[i] = tensor.detach().cpu()
+        self.buffers[name] = ((i+1) % self.buffer_sz, buf)
+
+    def as_dict(self):
+        result = {}
+        for k, (_, buf) in self.buffers.items():
+            result["loss_" + k] = torch.mean(buf)
+        return result
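+
+
+# Minimal usage sketch (values illustrative). Reported values are means over a rotating
+# buffer of the last `buffer_sz` entries (zero-filled until the buffer wraps):
+#   acc = LossAccumulator()
+#   acc.add_loss('pix', torch.tensor(0.123))
+#   metrics = acc.as_dict()  # {'loss_pix': tensor(...)}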
\ No newline at end of file