From 8e4b9f42fd335e04e31b081145568d077c8fb7f0 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 10 Dec 2020 09:57:18 -0700
Subject: [PATCH] New BYOL dataset which uses a form of RandomCrop that lends
 itself to structural guidance to the latents.

---
 codes/data/byol_attachment.py | 171 ++++++++++++++++++++++++++++++++++
 1 file changed, 171 insertions(+)

diff --git a/codes/data/byol_attachment.py b/codes/data/byol_attachment.py
index 75be36a6..ddf1c6ea 100644
--- a/codes/data/byol_attachment.py
+++ b/codes/data/byol_attachment.py
@@ -1,14 +1,18 @@
 import random
 
 import torch
+import torchvision
 from torch.utils.data import Dataset
 from kornia import augmentation as augs
 from kornia import filters
 import torch.nn as nn
+import torch.nn.functional as F
 
 # Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
 # inputs. These are then outputted as 'aug1' and 'aug2'.
 from data import create_dataset
+from models.archs.arch_util import PixelUnshuffle
+from utils.util import opt_get
 
 
 class RandomApply(nn.Module):
@@ -45,3 +49,170 @@ class ByolDatasetWrapper(Dataset):
 
     def __len__(self):
         return len(self.wrapped_dataset)
+
+
+def no_batch_interpolate(i, size, mode):
+    i = i.unsqueeze(0)
+    i = F.interpolate(i, size=size, mode=mode)
+    return i.squeeze(0)
+
+
+# Performs a 1-d translation of "other":
+# If other < base, returns 0; otherwise returns other-base, the offset of "other" within the crop that starts
+# at "base".
+def snap(base, other):
+    if other < base:
+        return 0
+    return other - base
+
+
+# A custom transform that crops two partially-overlapping regions out of a pair of images and returns, alongside
+# the crops, the coordinates needed to recover the shared region from latents computed on each crop. All crop
+# coordinates are drawn on a latent grid of dimension d = image_dim // multiple, so they map exactly onto feature
+# maps produced at that reduction; jitter_range (in pixels) optionally offsets the pixel-space crops so the
+# structural correspondence is only approximate. The numbered comments in forward() mark the steps of the
+# procedure.
+class RandomSharedRegionCrop(nn.Module):
+    def __init__(self, multiple, jitter_range=0):
+        super().__init__()
+        self.multiple = multiple
+        self.jitter_range = jitter_range
+
+    def forward(self, i1, i2):
+        # Step 1: Compute the latent dimension d. The image must be an exact multiple of the latent grid and
+        # at least 4 latent cells wide.
+        c, d, _ = i1.shape
+        assert d % self.multiple == 0 and d > (self.multiple*3)
+        d = d // self.multiple
+
+        # Step 2: Pick a random crop from the first image in latent coordinates: width in [d//2, d-1], height
+        # within one cell of the width.
+        base_w = random.randint(d//2, d-1)
+        base_l = random.randint(0, d-base_w)
+        base_h = random.randint(base_w-1, base_w+1)
+        base_t = random.randint(0, d-base_h)
+        base_r, base_b = base_l+base_w, base_t+base_h
+
+        # Step 3: Pick a second random crop from the second image in the same manner.
+        im2_w = random.randint(d//2, d-1)
+        im2_l = random.randint(0, d-im2_w)
+        im2_h = random.randint(im2_w-1, im2_w+1)
+        im2_t = random.randint(0, d-im2_h)
+        im2_r, im2_b = im2_l+im2_w, im2_t+im2_h
+
+        # Step 6: Extract both crops in pixel space, offset by a small random jitter, and resize them to the
+        # full (d*m, d*m) resolution.
+        m = self.multiple
+        jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
+        p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, base_l*m+jl:(base_l+base_w)*m+jl]
+        p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear")
+        jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
+        p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
+        p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
+
+        # Step 7: Compute the top-left corner of the shared region within each crop's frame of reference, plus
+        # its height and width, and package them for reconstructed_shared_regions() below.
+        i1_shared_t, i1_shared_l = snap(base_t, im2_t), snap(base_l, im2_l)
+        i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
+        ix_h = min(base_b, im2_b) - max(base_t, im2_t)
+        ix_w = min(base_r, im2_r) - max(base_l, im2_l)
+        recompute_package = (base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w)
+
+        # Step 8: Build masked views that highlight the shared region within each crop, plus a whole-image debug
+        # view showing both crop rectangles.
+        mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
+        mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
+        masked1 = p1 * mask1
+        mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5)
+        mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
+        masked2 = p2 * mask2
+        mask = torch.full((1, d*m, d*m), fill_value=.33)
+        mask[:, base_t*m:(base_t+base_h)*m, base_l*m:(base_l+base_w)*m] += .33
+        mask[:, im2_t*m:(im2_t+im2_h)*m, im2_l*m:(im2_l+im2_w)*m] += .33
+        masked_dbg = i1 * mask
+
+        return p1_resized, p2_resized, recompute_package, masked1, masked2, masked_dbg
+
+
+# Uses the recompute package returned from the above transform to extract matched-size "similar regions" from two
+# feature maps.
regions" from two feature +# maps. +def reconstructed_shared_regions(fea1, fea2, recompute_package): + f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = recompute_package + # Resize the input features to match + f1s = F.interpolate(fea1, (f1_h, f1_w), mode="bilinear") + f2s = F.interpolate(fea2, (f2_h, f2_w), mode="bilinear") + f1sh = f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w] + f2sh = f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w] + return f1sh, f2sh + + +# Follows the general template of BYOL dataset, with the following changes: +# 1. Flip() is not applied. +# 2. Instead of RandomResizedCrop, a custom Transform, RandomSharedRegionCrop is used. +# 3. The dataset injects two integer tensors alongside the augmentations, which are used to index image regions shared +# by the joint augmentations. +# 4. The dataset injects an aug_shared_view for debugging purposes. +class StructuredCropDatasetWrapper(Dataset): + def __init__(self, opt): + super().__init__() + self.wrapped_dataset = create_dataset(opt['dataset']) + augmentations = [RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), + augs.RandomGrayscale(p=0.2), + RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)] + self.aug = nn.Sequential(*augmentations) + self.rrc = RandomSharedRegionCrop(opt['latent_multiple'], opt_get(opt, ['jitter_range'], 0)) + + def __getitem__(self, item): + item = self.wrapped_dataset[item] + a1 = item['hq'] #self.aug(item['hq']).squeeze(dim=0) + a2 = item['hq'] #self.aug(item['lq']).squeeze(dim=0) + a1, a2, sr_dim, m1, m2, db = self.rrc(a1, a2) + item.update({'aug1': a1, 'aug2': a2, 'similar_region_dimensions': sr_dim, + 'masked1': m1, 'masked2': m2, 'aug_shared_view': db}) + return item + + def __len__(self): + return len(self.wrapped_dataset) + + +# For testing this dataset. +if __name__ == '__main__': + opt = { + 'dataset': + { + 'mode': 'imagefolder', + 'name': 'amalgam', + 'paths': ['F:\\4k6k\\datasets\\images\\flickr\\flickr-scrape\\filtered\carrot'], + 'weights': [1], + 'target_size': 256, + 'force_multiple': 32, + 'scale': 1, + 'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'], + 'random_corruptions': ['noise-5', 'none'], + 'num_corrupts_per_image': 1, + 'corrupt_before_downsize': True, + }, + 'latent_multiple': 8, + 'jitter_range': 0, + } + + ds = StructuredCropDatasetWrapper(opt) + import os + os.makedirs("debug", exist_ok=True) + for i in range(0, len(ds)): + o = ds[random.randint(0, len(ds))] + for k, v in o.items(): + # 'lq', 'hq', 'aug1', 'aug2', + if k in [ 'aug_shared_view', 'masked1', 'masked2']: + torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) + rcpkg = o['similar_region_dimensions'] + pixun = PixelUnshuffle(8) + pixsh = nn.PixelShuffle(8) + rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg) + torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,)) + torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))