New BYOL dataset which uses a form of RandomCrop that lends itself to
structural guidance to the latents.
This commit is contained in:
parent
66cbae8731
commit
8e4b9f42fd
|
@ -1,14 +1,18 @@
|
|||
import random
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch.utils.data import Dataset
|
||||
from kornia import augmentation as augs
|
||||
from kornia import filters
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
|
||||
# inputs. These are then outputted as 'aug1' and 'aug2'.
|
||||
from data import create_dataset
|
||||
from models.archs.arch_util import PixelUnshuffle
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class RandomApply(nn.Module):
|
||||
|
@ -45,3 +49,170 @@ class ByolDatasetWrapper(Dataset):
|
|||
|
||||
def __len__(self):
|
||||
return len(self.wrapped_dataset)
|
||||
|
||||
|
||||
def no_batch_interpolate(i, size, mode):
|
||||
i = i.unsqueeze(0)
|
||||
i = F.interpolate(i, size=size, mode=mode)
|
||||
return i.squeeze(0)
|
||||
|
||||
|
||||
# Performs a 1-d translation of "other":
|
||||
# If other<ref, returns 0.
|
||||
# Else: return other-ref
|
||||
def snap(ref, other):
|
||||
if other < ref:
|
||||
return 0
|
||||
return other - ref
|
||||
|
||||
|
||||
# Variation of RandomResizedCrop, which picks a region of the image that the two augments must share. The augments
|
||||
# then propagate off random corners of the shared region, using the same scale.
|
||||
#
|
||||
# Operates in units of "multiple". The intent is that this multiple is equivalent to the compression multiple of the
|
||||
# latent space being used so that each structural unit corresponds to a latent unit.
|
||||
class RandomSharedRegionCrop(nn.Module):
|
||||
def __init__(self, multiple, jitter_range=0):
|
||||
super().__init__()
|
||||
self.multiple = multiple
|
||||
self.jitter_range = jitter_range # When specified, images are shifted an additional random([-j,j]) pixels where j=jitter_range
|
||||
|
||||
def forward(self, i1, i2):
|
||||
assert i1.shape[-1] == i2.shape[-1]
|
||||
# Outline of the general algorithm:
|
||||
# 1. Assume the input is a square. Divide it by self.multiple to get working units.
|
||||
# 2. Pick a random width, height and top corner location for the first patch.
|
||||
# 3. Pick a random width, height and top corner location for the second patch.
|
||||
# Note: All dims from (2) and (3) must contain at least half of the image, guaranteeing overlap.
|
||||
# 6. Build patches from input images. Resize them appropriately. Apply translational jitter.
|
||||
# 7. Compute the metrics needed to extract overlapping regions from the resized patches: top, left,
|
||||
# original_height, original_width.
|
||||
# 8. Compute the "shared_view" from the above data.
|
||||
|
||||
# Step 1
|
||||
c, d, _ = i1.shape
|
||||
assert d % self.multiple == 0 and d > (self.multiple*3)
|
||||
d = d // self.multiple
|
||||
|
||||
# Step 2
|
||||
base_w = random.randint(d//2, d-1)
|
||||
base_l = random.randint(0, d-base_w)
|
||||
base_h = random.randint(base_w-1, base_w+1)
|
||||
base_t = random.randint(0, d-base_h)
|
||||
base_r, base_b = base_l+base_w, base_t+base_h
|
||||
|
||||
# Step 3
|
||||
im2_w = random.randint(d//2, d-1)
|
||||
im2_l = random.randint(0, d-im2_w)
|
||||
im2_h = random.randint(im2_w-1, im2_w+1)
|
||||
im2_t = random.randint(0, d-im2_h)
|
||||
im2_r, im2_b = im2_l+im2_w, im2_t+im2_h
|
||||
|
||||
# Step 6
|
||||
m = self.multiple
|
||||
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
|
||||
p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, base_l*m+jl:(base_l+base_w)*m+jl]
|
||||
p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear")
|
||||
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
|
||||
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
|
||||
p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
|
||||
|
||||
# Step 7
|
||||
i1_shared_t, i1_shared_l = snap(base_t, im2_t), snap(base_l, im2_l)
|
||||
i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
|
||||
ix_h = min(base_b, im2_b) - max(base_t, im2_t)
|
||||
ix_w = min(base_r, im2_r) - max(base_l, im2_l)
|
||||
recompute_package = (base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w)
|
||||
|
||||
# Step 8
|
||||
mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
|
||||
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
|
||||
masked1 = p1 * mask1
|
||||
mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5)
|
||||
mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
|
||||
masked2 = p2 * mask2
|
||||
mask = torch.full((1, d*m, d*m), fill_value=.33)
|
||||
mask[:, base_t*m:(base_t+base_w)*m, base_l*m:(base_l+base_h)*m] += .33
|
||||
mask[:, im2_t*m:(im2_t+im2_w)*m, im2_l*m:(im2_l+im2_h)*m] += .33
|
||||
masked_dbg = i1 * mask
|
||||
|
||||
return p1_resized, p2_resized, recompute_package, masked1, masked2, masked_dbg
|
||||
|
||||
|
||||
# Uses the recompute package returned from the above dataset to extract matched-size "similar regions" from two feature
|
||||
# maps.
|
||||
def reconstructed_shared_regions(fea1, fea2, recompute_package):
|
||||
f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = recompute_package
|
||||
# Resize the input features to match
|
||||
f1s = F.interpolate(fea1, (f1_h, f1_w), mode="bilinear")
|
||||
f2s = F.interpolate(fea2, (f2_h, f2_w), mode="bilinear")
|
||||
f1sh = f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w]
|
||||
f2sh = f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w]
|
||||
return f1sh, f2sh
|
||||
|
||||
|
||||
# Follows the general template of BYOL dataset, with the following changes:
|
||||
# 1. Flip() is not applied.
|
||||
# 2. Instead of RandomResizedCrop, a custom Transform, RandomSharedRegionCrop is used.
|
||||
# 3. The dataset injects two integer tensors alongside the augmentations, which are used to index image regions shared
|
||||
# by the joint augmentations.
|
||||
# 4. The dataset injects an aug_shared_view for debugging purposes.
|
||||
class StructuredCropDatasetWrapper(Dataset):
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.wrapped_dataset = create_dataset(opt['dataset'])
|
||||
augmentations = [RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
|
||||
augs.RandomGrayscale(p=0.2),
|
||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
|
||||
self.aug = nn.Sequential(*augmentations)
|
||||
self.rrc = RandomSharedRegionCrop(opt['latent_multiple'], opt_get(opt, ['jitter_range'], 0))
|
||||
|
||||
def __getitem__(self, item):
|
||||
item = self.wrapped_dataset[item]
|
||||
a1 = item['hq'] #self.aug(item['hq']).squeeze(dim=0)
|
||||
a2 = item['hq'] #self.aug(item['lq']).squeeze(dim=0)
|
||||
a1, a2, sr_dim, m1, m2, db = self.rrc(a1, a2)
|
||||
item.update({'aug1': a1, 'aug2': a2, 'similar_region_dimensions': sr_dim,
|
||||
'masked1': m1, 'masked2': m2, 'aug_shared_view': db})
|
||||
return item
|
||||
|
||||
def __len__(self):
|
||||
return len(self.wrapped_dataset)
|
||||
|
||||
|
||||
# For testing this dataset.
|
||||
if __name__ == '__main__':
|
||||
opt = {
|
||||
'dataset':
|
||||
{
|
||||
'mode': 'imagefolder',
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\images\\flickr\\flickr-scrape\\filtered\carrot'],
|
||||
'weights': [1],
|
||||
'target_size': 256,
|
||||
'force_multiple': 32,
|
||||
'scale': 1,
|
||||
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
|
||||
'random_corruptions': ['noise-5', 'none'],
|
||||
'num_corrupts_per_image': 1,
|
||||
'corrupt_before_downsize': True,
|
||||
},
|
||||
'latent_multiple': 8,
|
||||
'jitter_range': 0,
|
||||
}
|
||||
|
||||
ds = StructuredCropDatasetWrapper(opt)
|
||||
import os
|
||||
os.makedirs("debug", exist_ok=True)
|
||||
for i in range(0, len(ds)):
|
||||
o = ds[random.randint(0, len(ds))]
|
||||
for k, v in o.items():
|
||||
# 'lq', 'hq', 'aug1', 'aug2',
|
||||
if k in [ 'aug_shared_view', 'masked1', 'masked2']:
|
||||
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
|
||||
rcpkg = o['similar_region_dimensions']
|
||||
pixun = PixelUnshuffle(8)
|
||||
pixsh = nn.PixelShuffle(8)
|
||||
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg)
|
||||
torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
|
||||
torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
|
||||
|
|
Loading…
Reference in New Issue
Block a user