diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index c62c3700..ea883a8a 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -49,6 +49,8 @@ def create_dataset(dataset_opt):
         from data.torch_dataset import TorchDataset as D
     elif mode == 'byol_dataset':
         from data.byol_attachment import ByolDatasetWrapper as D
+    elif mode == 'byol_structured_dataset':
+        from data.byol_attachment import StructuredCropDatasetWrapper as D
     elif mode == 'random_dataset':
         from data.random_dataset import RandomDataset as D
     else:
diff --git a/codes/data/byol_attachment.py b/codes/data/byol_attachment.py
index ddf1c6ea..0015b41a 100644
--- a/codes/data/byol_attachment.py
+++ b/codes/data/byol_attachment.py
@@ -1,4 +1,5 @@
 import random
+from time import time
 
 import torch
 import torchvision
@@ -10,6 +11,8 @@ import torch.nn.functional as F
 
 # Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
 # inputs. These are then outputted as 'aug1' and 'aug2'.
+from tqdm import tqdm
+
 from data import create_dataset
 from models.archs.arch_util import PixelUnshuffle
 from utils.util import opt_get
@@ -66,6 +69,17 @@ def snap(ref, other):
     return other - ref
 
 
+# Pads a tensor with zeros so that it fits in a dxd square.
+def pad_to(im, d):
+    if len(im.shape) == 3:
+        pd = torch.zeros((im.shape[0],d,d))
+        pd[:, :im.shape[1], :im.shape[2]] = im
+    else:
+        pd = torch.zeros((im.shape[0],im.shape[1],d,d), device=im.device)
+        pd[:, :, :im.shape[2], :im.shape[3]] = im
+    return pd
+
+
 # Variation of RandomResizedCrop, which picks a region of the image that the two augments must share. The augments
 # then propagate off random corners of the shared region, using the same scale.
 #
@@ -111,9 +125,17 @@ class RandomSharedRegionCrop(nn.Module):
         # Step 6
         m = self.multiple
         jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
+        jt = jt if base_t != 0 else abs(jt)  # If the top of a patch is zero, a negative jitter will cause it to go negative.
+        jt = jt if (base_t+base_h)*m != i1.shape[1] else 0 # Likewise, jitter shouldn't allow the patch to go over-bounds.
+        jl = jl if base_l != 0 else abs(jl)
+        jl = jl if (base_l+base_w)*m != i1.shape[1] else 0
         p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, base_l*m+jl:(base_l+base_w)*m+jl]
         p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear")
         jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
+        jt = jt if im2_t != 0 else abs(jt)
+        jt = jt if (im2_t+im2_h)*m != i2.shape[1] else 0
+        jl = jl if im2_l != 0 else abs(jl)
+        jl = jl if (im2_l+im2_w)*m != i2.shape[1] else 0
         p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
         p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
 
@@ -122,15 +144,15 @@ class RandomSharedRegionCrop(nn.Module):
         i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
         ix_h = min(base_b, im2_b) - max(base_t, im2_t)
         ix_w = min(base_r, im2_r) - max(base_l, im2_l)
-        recompute_package = (base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w)
+        recompute_package = torch.tensor([base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w], dtype=torch.long)
 
         # Step 8
         mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
         mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
-        masked1 = p1 * mask1
+        masked1 = pad_to(p1 * mask1, d*m)
         mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5)
         mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
-        masked2 = p2 * mask2
+        masked2 = pad_to(p2 * mask2, d*m)
         mask = torch.full((1, d*m, d*m), fill_value=.33)
         mask[:, base_t*m:(base_t+base_w)*m, base_l*m:(base_l+base_h)*m] += .33
         mask[:, im2_t*m:(im2_t+im2_w)*m, im2_l*m:(im2_l+im2_h)*m] += .33
@@ -141,14 +163,22 @@ class RandomSharedRegionCrop(nn.Module):
 
 # Uses the recompute package returned from the above dataset to extract matched-size "similar regions" from two feature
 # maps.
-def reconstructed_shared_regions(fea1, fea2, recompute_package):
-    f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = recompute_package
-    # Resize the input features to match
-    f1s = F.interpolate(fea1, (f1_h, f1_w), mode="bilinear")
-    f2s = F.interpolate(fea2, (f2_h, f2_w), mode="bilinear")
-    f1sh = f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w]
-    f2sh = f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w]
-    return f1sh, f2sh
+def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor):
+    package = recompute_package.cpu()
+    res1 = []
+    res2 = []
+    pad_dim = torch.max(package[:, -2:]).item()
+    # It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside
+    # of conforming the recompute_package across the entire batch.
+    for b in range(package.shape[0]):
+        f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = tuple(package[b].tolist())
+        # Resize the input features to match
+        f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="bilinear")
+        f2s = F.interpolate(fea2[b].unsqueeze(0), (f2_h, f2_w), mode="bilinear")
+        # Outputs must be padded so they can "get along" with each other.
+        res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
+        res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
+    return torch.cat(res1, dim=0), torch.cat(res2, dim=0)
 
 
 # Follows the general template of BYOL dataset, with the following changes:
@@ -169,8 +199,8 @@ class StructuredCropDatasetWrapper(Dataset):
 
     def __getitem__(self, item):
         item = self.wrapped_dataset[item]
-        a1 = item['hq'] #self.aug(item['hq']).squeeze(dim=0)
-        a2 = item['hq'] #self.aug(item['lq']).squeeze(dim=0)
+        a1 = self.aug(item['hq']).squeeze(dim=0)
+        a2 = self.aug(item['lq']).squeeze(dim=0)
         a1, a2, sr_dim, m1, m2, db = self.rrc(a1, a2)
         item.update({'aug1': a1, 'aug2': a2, 'similar_region_dimensions': sr_dim,
                      'masked1': m1, 'masked2': m2, 'aug_shared_view': db})
@@ -187,7 +217,7 @@ if __name__ == '__main__':
             {
             'mode': 'imagefolder',
             'name': 'amalgam',
-            'paths': ['F:\\4k6k\\datasets\\images\\flickr\\flickr-scrape\\filtered\carrot'],
+            'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
             'weights': [1],
             'target_size': 256,
             'force_multiple': 32,
@@ -204,15 +234,15 @@ if __name__ == '__main__':
     ds = StructuredCropDatasetWrapper(opt)
     import os
     os.makedirs("debug", exist_ok=True)
-    for i in range(0, len(ds)):
-        o = ds[random.randint(0, len(ds))]
-        for k, v in o.items():
+    for i in tqdm(range(0, len(ds))):
+        o = ds[random.randint(0, len(ds)-1)]
+        #for k, v in o.items():
             # 'lq', 'hq', 'aug1', 'aug2',
-            if k in [ 'aug_shared_view', 'masked1', 'masked2']:
-                torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
+            #if k in [ 'aug_shared_view', 'masked1', 'masked2']:
+                #torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
         rcpkg = o['similar_region_dimensions']
         pixun = PixelUnshuffle(8)
         pixsh = nn.PixelShuffle(8)
         rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg)
-        torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
-        torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
+        #torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
+        #torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
diff --git a/codes/models/byol/byol_structural.py b/codes/models/byol/byol_structural.py
new file mode 100644
index 00000000..9d423e60
--- /dev/null
+++ b/codes/models/byol/byol_structural.py
@@ -0,0 +1,178 @@
+import copy
+import random
+from functools import wraps
+from time import time
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from data.byol_attachment import reconstructed_shared_regions
+from models.byol.byol_model_wrapper import singleton, EMA, MLP, get_module_device, set_requires_grad, \
+    update_moving_average
+from utils.util import checkpoint
+
+# loss function
+def structural_loss_fn(x, y):
+    # Combine the structural dimensions into the batch dimension, then compute the "normal" BYOL loss.
+    x = x.permute(0,2,3,1).flatten(0,2)
+    y = y.permute(0,2,3,1).flatten(0,2)
+    x = F.normalize(x, dim=-1, p=2)
+    y = F.normalize(y, dim=-1, p=2)
+    return 2 - 2 * (x * y).sum(dim=-1)
+
+
+class StructuralTail(nn.Module):
+    def __init__(self, channels, projection_size, hidden_size=512):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(channels, hidden_size, kernel_size=1),
+            nn.BatchNorm2d(hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(hidden_size, projection_size, kernel_size=1),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+# a wrapper class for the base neural network
+# will manage the interception of the hidden layer output
+# and pipe it into the projecter and predictor nets
+class NetWrapper(nn.Module):
+    def __init__(self, net, projection_size, projection_hidden_size, layer=-2):
+        super().__init__()
+        self.net = net
+        self.layer = layer
+
+        self.projector = None
+        self.projection_size = projection_size
+        self.projection_hidden_size = projection_hidden_size
+
+        self.hidden = None
+        self.hook_registered = False
+
+    def _find_layer(self):
+        if type(self.layer) == str:
+            modules = dict([*self.net.named_modules()])
+            return modules.get(self.layer, None)
+        elif type(self.layer) == int:
+            children = [*self.net.children()]
+            return children[self.layer]
+        return None
+
+    def _hook(self, _, __, output):
+        self.hidden = output
+
+    def _register_hook(self):
+        layer = self._find_layer()
+        assert layer is not None, f'hidden layer ({self.layer}) not found'
+        handle = layer.register_forward_hook(self._hook)
+        self.hook_registered = True
+
+    @singleton('projector')
+    def _get_projector(self, hidden):
+        projector = StructuralTail(hidden.shape[1], self.projection_size, self.projection_hidden_size)
+        return projector.to(hidden)
+
+    def get_representation(self, x):
+        if self.layer == -1:
+            return self.net(x)
+
+        if not self.hook_registered:
+            self._register_hook()
+
+        unused = self.net(x)
+        hidden = self.hidden
+        self.hidden = None
+        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
+        return hidden
+
+    def forward(self, x):
+        representation = self.get_representation(x)
+        projector = self._get_projector(representation)
+        projection = checkpoint(projector, representation)
+        return projection
+
+
+class StructuralBYOL(nn.Module):
+    def __init__(
+            self,
+            net,
+            image_size,
+            hidden_layer=-2,
+            projection_size=256,
+            projection_hidden_size=512,
+            moving_average_decay=0.99,
+            use_momentum=True,
+            pretrained_state_dict=None,
+            freeze_until=0
+    ):
+        super().__init__()
+
+        if pretrained_state_dict:
+            net.load_state_dict(torch.load(pretrained_state_dict), strict=True)
+        self.freeze_until = freeze_until
+        if self.freeze_until > 0:
+            for p in net.parameters():
+                p.DO_NOT_TRAIN = True
+            self.frozen = True
+        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
+
+        self.use_momentum = use_momentum
+        self.target_encoder = None
+        self.target_ema_updater = EMA(moving_average_decay)
+
+        self.online_predictor = StructuralTail(projection_size, projection_size, projection_hidden_size)
+
+        # get device of network and make wrapper same device
+        device = get_module_device(net)
+        self.to(device)
+
+        # send a mock image tensor to instantiate singleton parameters
+        self.forward(torch.randn(2, 3, image_size, image_size, device=device),
+                     torch.randn(2, 3, image_size, image_size, device=device), None)
+
+    @singleton('target_encoder')
+    def _get_target_encoder(self):
+        target_encoder = copy.deepcopy(self.online_encoder)
+        set_requires_grad(target_encoder, False)
+        return target_encoder
+
+    def reset_moving_average(self):
+        del self.target_encoder
+        self.target_encoder = None
+
+    def update_for_step(self, step, __):
+        assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
+        assert self.target_encoder is not None, 'target encoder has not been created yet'
+        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
+        if self.frozen and self.freeze_until < step:
+            print("Unfreezing model weights. Let the latent training commence..")
+            for p in self.online_encoder.net.parameters():
+                del p.DO_NOT_TRAIN
+            self.frozen = False
+
+    def forward(self, image_one, image_two, similar_region_params):
+        online_proj_one = self.online_encoder(image_one)
+        online_proj_two = self.online_encoder(image_two)
+
+        online_pred_one = self.online_predictor(online_proj_one)
+        online_pred_two = self.online_predictor(online_proj_two)
+
+        with torch.no_grad():
+            target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
+            target_proj_one = target_encoder(image_one).detach()
+            target_proj_two = target_encoder(image_two).detach()
+
+        # In the structural BYOL, only the regions of the source image that are shared between the two augments are
+        # compared. These regions can be extracted from the latents using `reconstruct_shared_regions`.
+        if similar_region_params is not None:
+            online_pred_one, target_proj_two = reconstructed_shared_regions(online_pred_one, target_proj_two, similar_region_params)
+        loss_one = structural_loss_fn(online_pred_one, target_proj_two.detach())
+        if similar_region_params is not None:
+            online_pred_two, target_proj_one = reconstructed_shared_regions(online_pred_two, target_proj_one, similar_region_params)
+        loss_two = structural_loss_fn(online_pred_two, target_proj_one.detach())
+
+        loss = loss_one + loss_two
+        return loss.mean()
diff --git a/codes/models/networks.py b/codes/models/networks.py
index da129007..0f504e23 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -153,6 +153,12 @@ def define_G(opt, opt_net, scale=None):
         subnet = define_G(opt, opt_net['subnet'])
         netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
                     structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
+    elif which_model == 'structural_byol':
+        from models.byol.byol_structural import StructuralBYOL
+        subnet = define_G(opt, opt_net['subnet'])
+        netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
+                              pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
+                              freeze_until=opt_get(opt_net, ['freeze_until'], 0))
     elif which_model == 'spinenet':
         from models.archs.spinenet_arch import SpineNet
         netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])