diff --git a/codes/data/__init__.py b/codes/data/__init__.py index c62c3700..ea883a8a 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -49,6 +49,8 @@ def create_dataset(dataset_opt): from data.torch_dataset import TorchDataset as D elif mode == 'byol_dataset': from data.byol_attachment import ByolDatasetWrapper as D + elif mode == 'byol_structured_dataset': + from data.byol_attachment import StructuredCropDatasetWrapper as D elif mode == 'random_dataset': from data.random_dataset import RandomDataset as D else: diff --git a/codes/data/byol_attachment.py b/codes/data/byol_attachment.py index ddf1c6ea..0015b41a 100644 --- a/codes/data/byol_attachment.py +++ b/codes/data/byol_attachment.py @@ -1,4 +1,5 @@ import random +from time import time import torch import torchvision @@ -10,6 +11,8 @@ import torch.nn.functional as F # Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq' # inputs. These are then outputted as 'aug1' and 'aug2'. +from tqdm import tqdm + from data import create_dataset from models.archs.arch_util import PixelUnshuffle from utils.util import opt_get @@ -66,6 +69,17 @@ def snap(ref, other): return other - ref +# Pads a tensor with zeros so that it fits in a dxd square. +def pad_to(im, d): + if len(im.shape) == 3: + pd = torch.zeros((im.shape[0],d,d)) + pd[:, :im.shape[1], :im.shape[2]] = im + else: + pd = torch.zeros((im.shape[0],im.shape[1],d,d), device=im.device) + pd[:, :, :im.shape[2], :im.shape[3]] = im + return pd + + # Variation of RandomResizedCrop, which picks a region of the image that the two augments must share. The augments # then propagate off random corners of the shared region, using the same scale. # @@ -111,9 +125,17 @@ class RandomSharedRegionCrop(nn.Module): # Step 6 m = self.multiple jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range) + jt = jt if base_t != 0 else abs(jt) # If the top of a patch is zero, a negative jitter will cause it to go negative. + jt = jt if (base_t+base_h)*m != i1.shape[1] else 0 # Likewise, jitter shouldn't allow the patch to go over-bounds. + jl = jl if base_l != 0 else abs(jl) + jl = jl if (base_l+base_w)*m != i1.shape[1] else 0 p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, base_l*m+jl:(base_l+base_w)*m+jl] p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear") jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range) + jt = jt if im2_t != 0 else abs(jt) + jt = jt if (im2_t+im2_h)*m != i2.shape[1] else 0 + jl = jl if im2_l != 0 else abs(jl) + jl = jl if (im2_l+im2_w)*m != i2.shape[1] else 0 p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl] p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear") @@ -122,15 +144,15 @@ class RandomSharedRegionCrop(nn.Module): i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l) ix_h = min(base_b, im2_b) - max(base_t, im2_t) ix_w = min(base_r, im2_r) - max(base_l, im2_l) - recompute_package = (base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w) + recompute_package = torch.tensor([base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w], dtype=torch.long) # Step 8 mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5) mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1 - masked1 = p1 * mask1 + masked1 = pad_to(p1 * mask1, d*m) mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5) mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1 - masked2 = p2 * mask2 + masked2 = pad_to(p2 * mask2, d*m) mask = torch.full((1, d*m, d*m), fill_value=.33) mask[:, base_t*m:(base_t+base_w)*m, base_l*m:(base_l+base_h)*m] += .33 mask[:, im2_t*m:(im2_t+im2_w)*m, im2_l*m:(im2_l+im2_h)*m] += .33 @@ -141,14 +163,22 @@ class RandomSharedRegionCrop(nn.Module): # Uses the recompute package returned from the above dataset to extract matched-size "similar regions" from two feature # maps. -def reconstructed_shared_regions(fea1, fea2, recompute_package): - f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = recompute_package - # Resize the input features to match - f1s = F.interpolate(fea1, (f1_h, f1_w), mode="bilinear") - f2s = F.interpolate(fea2, (f2_h, f2_w), mode="bilinear") - f1sh = f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w] - f2sh = f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w] - return f1sh, f2sh +def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor): + package = recompute_package.cpu() + res1 = [] + res2 = [] + pad_dim = torch.max(package[:, -2:]).item() + # It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside + # of conforming the recompute_package across the entire batch. + for b in range(package.shape[0]): + f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = tuple(package[b].tolist()) + # Resize the input features to match + f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="bilinear") + f2s = F.interpolate(fea2[b].unsqueeze(0), (f2_h, f2_w), mode="bilinear") + # Outputs must be padded so they can "get along" with each other. + res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim)) + res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim)) + return torch.cat(res1, dim=0), torch.cat(res2, dim=0) # Follows the general template of BYOL dataset, with the following changes: @@ -169,8 +199,8 @@ class StructuredCropDatasetWrapper(Dataset): def __getitem__(self, item): item = self.wrapped_dataset[item] - a1 = item['hq'] #self.aug(item['hq']).squeeze(dim=0) - a2 = item['hq'] #self.aug(item['lq']).squeeze(dim=0) + a1 = self.aug(item['hq']).squeeze(dim=0) + a2 = self.aug(item['lq']).squeeze(dim=0) a1, a2, sr_dim, m1, m2, db = self.rrc(a1, a2) item.update({'aug1': a1, 'aug2': a2, 'similar_region_dimensions': sr_dim, 'masked1': m1, 'masked2': m2, 'aug_shared_view': db}) @@ -187,7 +217,7 @@ if __name__ == '__main__': { 'mode': 'imagefolder', 'name': 'amalgam', - 'paths': ['F:\\4k6k\\datasets\\images\\flickr\\flickr-scrape\\filtered\carrot'], + 'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'], 'weights': [1], 'target_size': 256, 'force_multiple': 32, @@ -204,15 +234,15 @@ if __name__ == '__main__': ds = StructuredCropDatasetWrapper(opt) import os os.makedirs("debug", exist_ok=True) - for i in range(0, len(ds)): - o = ds[random.randint(0, len(ds))] - for k, v in o.items(): + for i in tqdm(range(0, len(ds))): + o = ds[random.randint(0, len(ds)-1)] + #for k, v in o.items(): # 'lq', 'hq', 'aug1', 'aug2', - if k in [ 'aug_shared_view', 'masked1', 'masked2']: - torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) + #if k in [ 'aug_shared_view', 'masked1', 'masked2']: + #torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) rcpkg = o['similar_region_dimensions'] pixun = PixelUnshuffle(8) pixsh = nn.PixelShuffle(8) rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg) - torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,)) - torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,)) + #torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,)) + #torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,)) diff --git a/codes/models/byol/byol_structural.py b/codes/models/byol/byol_structural.py new file mode 100644 index 00000000..9d423e60 --- /dev/null +++ b/codes/models/byol/byol_structural.py @@ -0,0 +1,178 @@ +import copy +import random +from functools import wraps +from time import time + +import torch +import torch.nn.functional as F +from torch import nn + +from data.byol_attachment import reconstructed_shared_regions +from models.byol.byol_model_wrapper import singleton, EMA, MLP, get_module_device, set_requires_grad, \ + update_moving_average +from utils.util import checkpoint + +# loss function +def structural_loss_fn(x, y): + # Combine the structural dimensions into the batch dimension, then compute the "normal" BYOL loss. + x = x.permute(0,2,3,1).flatten(0,2) + y = y.permute(0,2,3,1).flatten(0,2) + x = F.normalize(x, dim=-1, p=2) + y = F.normalize(y, dim=-1, p=2) + return 2 - 2 * (x * y).sum(dim=-1) + + +class StructuralTail(nn.Module): + def __init__(self, channels, projection_size, hidden_size=512): + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(channels, hidden_size, kernel_size=1), + nn.BatchNorm2d(hidden_size), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_size, projection_size, kernel_size=1), + ) + + def forward(self, x): + return self.net(x) + + +# a wrapper class for the base neural network +# will manage the interception of the hidden layer output +# and pipe it into the projecter and predictor nets +class NetWrapper(nn.Module): + def __init__(self, net, projection_size, projection_hidden_size, layer=-2): + super().__init__() + self.net = net + self.layer = layer + + self.projector = None + self.projection_size = projection_size + self.projection_hidden_size = projection_hidden_size + + self.hidden = None + self.hook_registered = False + + def _find_layer(self): + if type(self.layer) == str: + modules = dict([*self.net.named_modules()]) + return modules.get(self.layer, None) + elif type(self.layer) == int: + children = [*self.net.children()] + return children[self.layer] + return None + + def _hook(self, _, __, output): + self.hidden = output + + def _register_hook(self): + layer = self._find_layer() + assert layer is not None, f'hidden layer ({self.layer}) not found' + handle = layer.register_forward_hook(self._hook) + self.hook_registered = True + + @singleton('projector') + def _get_projector(self, hidden): + projector = StructuralTail(hidden.shape[1], self.projection_size, self.projection_hidden_size) + return projector.to(hidden) + + def get_representation(self, x): + if self.layer == -1: + return self.net(x) + + if not self.hook_registered: + self._register_hook() + + unused = self.net(x) + hidden = self.hidden + self.hidden = None + assert hidden is not None, f'hidden layer {self.layer} never emitted an output' + return hidden + + def forward(self, x): + representation = self.get_representation(x) + projector = self._get_projector(representation) + projection = checkpoint(projector, representation) + return projection + + +class StructuralBYOL(nn.Module): + def __init__( + self, + net, + image_size, + hidden_layer=-2, + projection_size=256, + projection_hidden_size=512, + moving_average_decay=0.99, + use_momentum=True, + pretrained_state_dict=None, + freeze_until=0 + ): + super().__init__() + + if pretrained_state_dict: + net.load_state_dict(torch.load(pretrained_state_dict), strict=True) + self.freeze_until = freeze_until + if self.freeze_until > 0: + for p in net.parameters(): + p.DO_NOT_TRAIN = True + self.frozen = True + self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer) + + self.use_momentum = use_momentum + self.target_encoder = None + self.target_ema_updater = EMA(moving_average_decay) + + self.online_predictor = StructuralTail(projection_size, projection_size, projection_hidden_size) + + # get device of network and make wrapper same device + device = get_module_device(net) + self.to(device) + + # send a mock image tensor to instantiate singleton parameters + self.forward(torch.randn(2, 3, image_size, image_size, device=device), + torch.randn(2, 3, image_size, image_size, device=device), None) + + @singleton('target_encoder') + def _get_target_encoder(self): + target_encoder = copy.deepcopy(self.online_encoder) + set_requires_grad(target_encoder, False) + return target_encoder + + def reset_moving_average(self): + del self.target_encoder + self.target_encoder = None + + def update_for_step(self, step, __): + assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder' + assert self.target_encoder is not None, 'target encoder has not been created yet' + update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) + if self.frozen and self.freeze_until < step: + print("Unfreezing model weights. Let the latent training commence..") + for p in self.online_encoder.net.parameters(): + del p.DO_NOT_TRAIN + self.frozen = False + + def forward(self, image_one, image_two, similar_region_params): + online_proj_one = self.online_encoder(image_one) + online_proj_two = self.online_encoder(image_two) + + online_pred_one = self.online_predictor(online_proj_one) + online_pred_two = self.online_predictor(online_proj_two) + + with torch.no_grad(): + target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder + target_proj_one = target_encoder(image_one).detach() + target_proj_two = target_encoder(image_two).detach() + + # In the structural BYOL, only the regions of the source image that are shared between the two augments are + # compared. These regions can be extracted from the latents using `reconstruct_shared_regions`. + if similar_region_params is not None: + online_pred_one, target_proj_two = reconstructed_shared_regions(online_pred_one, target_proj_two, similar_region_params) + loss_one = structural_loss_fn(online_pred_one, target_proj_two.detach()) + if similar_region_params is not None: + online_pred_two, target_proj_one = reconstructed_shared_regions(online_pred_two, target_proj_one, similar_region_params) + loss_two = structural_loss_fn(online_pred_two, target_proj_one.detach()) + + loss = loss_one + loss_two + return loss.mean() diff --git a/codes/models/networks.py b/codes/models/networks.py index da129007..0f504e23 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -153,6 +153,12 @@ def define_G(opt, opt_net, scale=None): subnet = define_G(opt, opt_net['subnet']) netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False)) + elif which_model == 'structural_byol': + from models.byol.byol_structural import StructuralBYOL + subnet = define_G(opt, opt_net['subnet']) + netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], + pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]), + freeze_until=opt_get(opt_net, ['freeze_until'], 0)) elif which_model == 'spinenet': from models.archs.spinenet_arch import SpineNet netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])