New BYOL dataset which uses a form of RandomCrop that lends itself to providing structural guidance to the latents.
James Betker 2020-12-10 09:57:18 -07:00
parent 66cbae8731
commit 8e4b9f42fd

@@ -1,14 +1,18 @@
import random
import torch
import torchvision
from torch.utils.data import Dataset
from kornia import augmentation as augs
from kornia import filters
import torch.nn as nn
import torch.nn.functional as F
# Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
# inputs. These are then outputted as 'aug1' and 'aug2'.
from data import create_dataset
from models.archs.arch_util import PixelUnshuffle
from utils.util import opt_get
class RandomApply(nn.Module):
@@ -45,3 +49,170 @@ class ByolDatasetWrapper(Dataset):
def __len__(self):
return len(self.wrapped_dataset)
def no_batch_interpolate(i, size, mode):
i = i.unsqueeze(0)
i = F.interpolate(i, size=size, mode=mode)
return i.squeeze(0)
# Returns the 1-d offset of "other" relative to "ref", clamped at zero:
# If other < ref, returns 0.
# Else, returns other - ref.
def snap(ref, other):
if other < ref:
return 0
return other - ref
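# Worked example (illustrative, not part of the original file): snap(3, 5) == 2 while
# snap(5, 3) == 0. This is how the shared region's top/left is expressed in each patch's
# own coordinate frame further below.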
# Variation of RandomResizedCrop, which picks a region of the image that the two augments must share. The augments
# then propagate off random corners of the shared region, using the same scale.
#
# Operates in units of "multiple". The intent is that this multiple is equivalent to the compression multiple of the
# latent space being used so that each structural unit corresponds to a latent unit.
class RandomSharedRegionCrop(nn.Module):
def __init__(self, multiple, jitter_range=0):
super().__init__()
self.multiple = multiple
self.jitter_range = jitter_range # When specified, images are shifted an additional random([-j,j]) pixels where j=jitter_range
def forward(self, i1, i2):
assert i1.shape[-1] == i2.shape[-1]
# Outline of the general algorithm:
# 1. Assume the input is a square. Divide it by self.multiple to get working units.
# 2. Pick a random width, height and top corner location for the first patch.
# 3. Pick a random width, height and top corner location for the second patch.
# Note: All dims from (2) and (3) must contain at least half of the image, guaranteeing overlap.
# 6. Build patches from input images. Resize them appropriately. Apply translational jitter.
# 7. Compute the metrics needed to extract overlapping regions from the resized patches: top, left,
# original_height, original_width.
# 8. Compute the "shared_view" from the above data.
# Step 1
c, d, _ = i1.shape
assert d % self.multiple == 0 and d > (self.multiple*3)
d = d // self.multiple
# Step 2
base_w = random.randint(d//2, d-1)
base_l = random.randint(0, d-base_w)
base_h = random.randint(base_w-1, base_w+1)
base_t = random.randint(0, d-base_h)
base_r, base_b = base_l+base_w, base_t+base_h
# Step 3
im2_w = random.randint(d//2, d-1)
im2_l = random.randint(0, d-im2_w)
im2_h = random.randint(im2_w-1, im2_w+1)
im2_t = random.randint(0, d-im2_h)
im2_r, im2_b = im2_l+im2_w, im2_t+im2_h
# Step 6
m = self.multiple
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, base_l*m+jl:(base_l+base_w)*m+jl]
p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear")
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
# Step 7
i1_shared_t, i1_shared_l = snap(base_t, im2_t), snap(base_l, im2_l)
i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
ix_h = min(base_b, im2_b) - max(base_t, im2_t)
ix_w = min(base_r, im2_r) - max(base_l, im2_l)
recompute_package = (base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w)
# Step 8
mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
masked1 = p1 * mask1
mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5)
mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
masked2 = p2 * mask2
mask = torch.full((1, d*m, d*m), fill_value=.33)
mask[:, base_t*m:(base_t+base_h)*m, base_l*m:(base_l+base_w)*m] += .33
mask[:, im2_t*m:(im2_t+im2_h)*m, im2_l*m:(im2_l+im2_w)*m] += .33
masked_dbg = i1 * mask
return p1_resized, p2_resized, recompute_package, masked1, masked2, masked_dbg
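# Hedged usage sketch: with multiple=8 and a 256px square input, the crop works on a 32x32 grid
# of latent-sized units; both returned patches are resized back to 256x256 and the recompute
# package records where the shared region sits inside each patch, in latent units. The helper
# name below is illustrative and not part of the dataset code.
def _example_random_shared_region_crop():
    rrc = RandomSharedRegionCrop(multiple=8, jitter_range=0)
    img = torch.rand(3, 256, 256)
    aug1, aug2, package, m1, m2, dbg = rrc(img, img.clone())
    # aug1/aug2 are (3, 256, 256) views sharing a common region; m1/m2/dbg are debug visualizations.
    return aug1, aug2, package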
# Uses the recompute package produced by RandomSharedRegionCrop above to extract matched-size
# "similar regions" from two feature maps.
def reconstructed_shared_regions(fea1, fea2, recompute_package):
f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = recompute_package
# Resize the input features to match
f1s = F.interpolate(fea1, (f1_h, f1_w), mode="bilinear")
f2s = F.interpolate(fea2, (f2_h, f2_w), mode="bilinear")
f1sh = f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w]
f2sh = f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w]
return f1sh, f2sh
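# Hedged sketch of how the recompute package could drive a structural loss on latents. The
# helper name is illustrative; PixelUnshuffle(8) merely stands in for whatever encoder produces
# feature maps at 1/8 resolution (matching latent_multiple=8).
def _example_structural_latent_loss(aug1, aug2, package):
    pixun = PixelUnshuffle(8)
    f1 = pixun(aug1.unsqueeze(0))  # (1, 3*64, 32, 32): latent-resolution map of view 1
    f2 = pixun(aug2.unsqueeze(0))
    r1, r2 = reconstructed_shared_regions(f1, f2, package)
    # r1 and r2 cover the same image content at the same size, so they can be compared directly.
    return F.mse_loss(r1, r2)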
# Follows the general template of the BYOL dataset wrapper above, with the following changes:
# 1. Flip() is not applied.
# 2. Instead of RandomResizedCrop, a custom transform, RandomSharedRegionCrop, is used.
# 3. The dataset injects a 'similar_region_dimensions' package alongside the augmentations, which is used to index
#    the image regions shared by the joint augmentations.
# 4. The dataset injects an aug_shared_view for debugging purposes.
class StructuredCropDatasetWrapper(Dataset):
def __init__(self, opt):
super().__init__()
self.wrapped_dataset = create_dataset(opt['dataset'])
augmentations = [RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
self.aug = nn.Sequential(*augmentations)
self.rrc = RandomSharedRegionCrop(opt['latent_multiple'], opt_get(opt, ['jitter_range'], 0))
def __getitem__(self, item):
item = self.wrapped_dataset[item]
a1 = item['hq'] #self.aug(item['hq']).squeeze(dim=0)
a2 = item['hq'] #self.aug(item['lq']).squeeze(dim=0)
a1, a2, sr_dim, m1, m2, db = self.rrc(a1, a2)
item.update({'aug1': a1, 'aug2': a2, 'similar_region_dimensions': sr_dim,
'masked1': m1, 'masked2': m2, 'aug_shared_view': db})
return item
def __len__(self):
return len(self.wrapped_dataset)
# For testing this dataset.
if __name__ == '__main__':
opt = {
'dataset':
{
'mode': 'imagefolder',
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\images\\flickr\\flickr-scrape\\filtered\\carrot'],
'weights': [1],
'target_size': 256,
'force_multiple': 32,
'scale': 1,
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
'random_corruptions': ['noise-5', 'none'],
'num_corrupts_per_image': 1,
'corrupt_before_downsize': True,
},
'latent_multiple': 8,
'jitter_range': 0,
}
ds = StructuredCropDatasetWrapper(opt)
import os
os.makedirs("debug", exist_ok=True)
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds) - 1)]
for k, v in o.items():
# 'lq', 'hq', 'aug1', 'aug2',
if k in [ 'aug_shared_view', 'masked1', 'masked2']:
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
rcpkg = o['similar_region_dimensions']
pixun = PixelUnshuffle(8)
pixsh = nn.PixelShuffle(8)
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg)
torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))