diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index 60bfc90a..61f478ae 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -3,18 +3,20 @@ import logging
 import torch
 import torch.utils.data
 
+from utils.util import opt_get
+
 
 def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
     phase = dataset_opt['phase']
     if phase == 'train':
-        if opt['dist']:
+        if opt_get(opt, ['dist'], False):
             world_size = torch.distributed.get_world_size()
             num_workers = dataset_opt['n_workers']
             assert dataset_opt['batch_size'] % world_size == 0
             batch_size = dataset_opt['batch_size'] // world_size
             shuffle = False
         else:
-            num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
+            num_workers = dataset_opt['n_workers']
             batch_size = dataset_opt['batch_size']
             shuffle = True
         return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py
index 4c077eed..a02d63b1 100644
--- a/codes/data/image_corruptor.py
+++ b/codes/data/image_corruptor.py
@@ -37,7 +37,10 @@ class ImageCorruptor:
 
     def corrupt_images(self, imgs, return_entropy=False):
         if self.num_corrupts == 0 and not self.fixed_corruptions:
-            return imgs
+            if return_entropy:
+                return imgs, []
+            else:
+                return imgs
 
         if self.num_corrupts == 0:
             augmentations = []
diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py
index c970c7ef..8b7415c9 100644
--- a/codes/data/image_folder_dataset.py
+++ b/codes/data/image_folder_dataset.py
@@ -35,6 +35,8 @@ class ImageFolderDataset:
         self.skip_lq = opt_get(opt, ['skip_lq'], False)
         self.disable_flip = opt_get(opt, ['disable_flip'], False)
         self.rgb_n1_to_1 = opt_get(opt, ['rgb_n1_to_1'], False)
+        self.force_square = opt_get(opt, ['force_square'], True)
+        self.fixed_parameters = {k: torch.tensor(v) for k, v in opt_get(opt, ['fixed_parameters'], {}).items()}
         if 'normalize' in opt.keys():
             if opt['normalize'] == 'stylegan2_norm':
                 self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
@@ -44,7 +46,8 @@ class ImageFolderDataset:
                 raise Exception('Unsupported normalize')
         else:
             self.normalize = None
-        assert (self.target_hq_size // self.scale) % self.multiple == 0  # If we dont throw here, we get some really obscure errors.
+        if self.target_hq_size is not None:
+            assert (self.target_hq_size // self.scale) % self.multiple == 0  # If we dont throw here, we get some really obscure errors.
         if not isinstance(self.paths, list):
             self.paths = [self.paths]
             self.weights = [1]
@@ -129,10 +132,10 @@ class ImageFolderDataset:
         if not self.disable_flip and random.random() < .5:
             hq = hq[:, ::-1, :]
 
-        # We must convert the image into a square.
-        h, w, _ = hq.shape
-        dim = min(h, w)
-        hq = hq[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
+        if self.force_square:
+            h, w, _ = hq.shape
+            dim = min(h, w)
+            hq = hq[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
 
         if self.labeler:
             assert hq.shape[0] == hq.shape[1]  # This just has not been accomodated yet.
@@ -211,6 +214,7 @@ class ImageFolderDataset:
                     v = v * 2 - 1
                 out_dict[k] = v
 
+        out_dict.update(self.fixed_parameters)
         return out_dict
 
 if __name__ == '__main__':
diff --git a/codes/models/byol/byol_model_wrapper.py b/codes/models/byol/byol_model_wrapper.py
index e15d2321..e843e356 100644
--- a/codes/models/byol/byol_model_wrapper.py
+++ b/codes/models/byol/byol_model_wrapper.py
@@ -241,8 +241,8 @@ class BYOL(nn.Module):
         torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
 
     def forward(self, image_one, image_two):
-        image_one = self.aug(image_one)
-        image_two = self.aug(image_two)
+        image_one = self.aug(image_one.clone())
+        image_two = self.aug(image_two.clone())
 
         # Keep copies on hand for visual_dbg.
         self.im1 = image_one.detach().clone()
diff --git a/codes/scripts/diffusion/diffusion_sampler.py b/codes/scripts/diffusion/diffusion_sampler.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/test.py b/codes/test.py
index 6a2529f8..f8f2c0a4 100644
--- a/codes/test.py
+++ b/codes/test.py
@@ -1,5 +1,6 @@
 import os.path as osp
 import logging
+import random
 import time
 import argparse
 from collections import OrderedDict
@@ -11,9 +12,10 @@ from trainer.ExtensibleTrainer import ExtensibleTrainer
 from data import create_dataset, create_dataloader
 from tqdm import tqdm
 import torch
+import numpy as np
 
 
-def forward_pass(model, output_dir, opt):
+def forward_pass(model, data, output_dir, opt):
     alteration_suffix = util.opt_get(opt, ['name'], '')
     denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
     model.feed_data(data, 0, need_GT=need_GT)
@@ -47,11 +49,16 @@ def forward_pass(model, output_dir, opt):
 
 
 if __name__ == "__main__":
+    # Set seeds
+    torch.manual_seed(5555)
+    random.seed(5555)
+    np.random.seed(5555)
+
     #### options
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_cats_stylegan2_rosinality.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_unet.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -93,7 +100,7 @@ if __name__ == "__main__":
             need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
             need_GT = need_GT and want_metrics
 
-            fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt)
+            fea_loss, psnr_loss = forward_pass(model, data, dataset_dir, opt)
             fea_loss += fea_loss
             psnr_loss += psnr_loss
 
diff --git a/codes/train.py b/codes/train.py
index a9aa6ca9..58b0ba3c 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -302,7 +302,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_cifar.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index e5e25ae6..263271c0 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -159,7 +159,8 @@ class ExtensibleTrainer(BaseModel):
         self.batch_factor = self.mega_batch_factor
         self.opt['checkpointing_enabled'] = self.checkpointing_cache
         # The batch factor can be adjusted on a period to allow known high-memory steps to fit in GPU memory.
-        if 'mod_batch_factor' in self.opt['train'].keys() and \
+        if 'train' in self.opt.keys() and \
+                'mod_batch_factor' in self.opt['train'].keys() and \
                 self.env['step'] % self.opt['train']['mod_batch_factor_every'] == 0:
             self.batch_factor = self.opt['train']['mod_batch_factor']
             if self.opt['train']['mod_batch_factor_also_disable_checkpointing']:
@@ -350,8 +351,7 @@ class ExtensibleTrainer(BaseModel):
 
     def get_current_visuals(self, need_GT=True):
         # Conforms to an archaic format from MMSR.
-        res = {'lq': self.eval_state['lq'][0].float().cpu(),
-               'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
+        res = {'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
         if 'hq' in self.eval_state.keys():
             res['hq'] = self.eval_state['hq'][0].float().cpu(),
         return res
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index 576dc52b..5705a984 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -40,7 +40,9 @@ class GaussianDiffusionInferenceInjector(Injector):
     def __init__(self, opt, env):
         super().__init__(opt, env)
         self.generator = opt['generator']
-        self.output_shape = opt['output_shape']
+        self.output_batch_size = opt['output_batch_size']
+        self.output_scale_factor = opt['output_scale_factor']
+        self.undo_n1_to_1 = opt_get(opt, ['undo_n1_to_1'], False)  # Explanation: when specified, will shift the output of this injector from [-1,1] to [0,1]
         opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
         opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
                                                                  [opt_get(opt, ['respaced_timestep_spacing'], opt['beta_schedule']['num_diffusion_timesteps'])])
@@ -49,9 +51,12 @@ class GaussianDiffusionInferenceInjector(Injector):
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
-        batch_size = self.output_shape[0]
-        model_inputs = {k: state[v][:batch_size] for k, v in self.model_input_keys.items()}
+        model_inputs = {k: state[v][:self.output_batch_size] for k, v in self.model_input_keys.items()}
         gen.eval()
         with torch.no_grad():
-            gen = self.diffusion.p_sample_loop(gen, self.output_shape, model_kwargs=model_inputs)
+            output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
+                            model_inputs['low_res'].shape[-1] * self.output_scale_factor)
+            gen = self.diffusion.p_sample_loop(gen, output_shape, model_kwargs=model_inputs)
+            if self.undo_n1_to_1:
+                gen = (gen + 1) / 2
             return {self.output: gen}
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index f18b1bab..c4808054 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -7,7 +7,7 @@ from trainer.losses import create_loss
 import torch
 from collections import OrderedDict
 from trainer.inject import create_injector
-from utils.util import recursively_detach
+from utils.util import recursively_detach, opt_get
 
 logger = logging.getLogger('base')
 
@@ -53,21 +53,19 @@ class ConfigurableStep(Module):
     #  This default implementation defines a single optimizer for all Generator parameters.
     #  Must be called after networks are initialized and wrapped.
     def define_optimizers(self):
+        opt_configs = opt_get(self.step_opt, ['optimizer_params'], None)
+        self.optimizers = []
+        if opt_configs is None:
+            return
         training = self.step_opt['training']
         training_net = self.get_network_for_name(training)
-        # When only training one network, optimizer params can just embedded in the step params.
-        if 'optimizer_params' not in self.step_opt.keys():
-            opt_configs = [self.step_opt]
-        else:
-            opt_configs = [self.step_opt['optimizer_params']]
         nets = [training_net]
         training = [training]
-        self.optimizers = []
         for net_name, net, opt_config in zip(training, nets, opt_configs):
             # Configs can organize parameters by-group and specify different learning rates for each group. This only
             # works in the model specifically annotates which parameters belong in which group using PARAM_GROUP.
             optim_params = {'default': {'params': [], 'lr': opt_config['lr']}}
-            if 'param_groups' in opt_config.keys():
+            if opt_config is not None and 'param_groups' in opt_config.keys():
                 for k, pg in opt_config['param_groups'].items():
                     optim_params[k] = {'params': [], 'lr': pg['lr']}