DL-Art-School/codes/data/byol_attachment.py


import random
from time import time
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset
import kornia
from kornia import augmentation as augs
from kornia import filters
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from data import create_dataset
from models.arch_util import PixelUnshuffle
from utils.util import opt_get

class RandomApply(nn.Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)
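
# A minimal sketch of how RandomApply composes with a kornia op. The input shape is an
# illustrative assumption, not a requirement of this repo.
def _example_random_apply():
    blur = RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)
    x = torch.randn(1, 3, 64, 64)  # (batch, channels, height, width)
    return blur(x)  # Blurred with probability .1; returned unchanged otherwise.
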
# Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
# inputs. The results are output as 'aug1' and 'aug2'.
class ByolDatasetWrapper(Dataset):
def __init__(self, opt):
super().__init__()
self.wrapped_dataset = create_dataset(opt['dataset'])
self.cropped_img_size = opt['crop_size']
        augmentations = [
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
        if opt_get(opt, ['normalize'], False):
# The paper calls for normalization. Most datasets/models in this repo don't use this.
# Recommend setting true if you want to train exactly like the paper.
augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))
self.aug = nn.Sequential(*augmentations)
def __getitem__(self, item):
item = self.wrapped_dataset[item]
item.update({'aug1': self.aug(item['hq']).squeeze(dim=0), 'aug2': self.aug(item['lq']).squeeze(dim=0)})
return item
def __len__(self):
return len(self.wrapped_dataset)
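
# A hypothetical config for driving ByolDatasetWrapper, mirroring the test configs further down in this file.
# The wrapped-dataset options here are assumptions and would need to point at a real DLAS dataset.
def _example_byol_dataset_wrapper():
    opt = {
        'dataset': {
            'mode': 'imagefolder',
            'name': 'example',
            'paths': ['/path/to/images'],  # Hypothetical path.
            'weights': [1],
            'target_size': 256,
            'scale': 1,
        },
        'crop_size': 224,
        'normalize': True,
    }
    ds = ByolDatasetWrapper(opt)
    sample = ds[0]
    return sample['aug1'], sample['aug2']  # Two independently augmented views of the same item.
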
# Basically the same as ByolDatasetWrapper, except it only produces one augmentation, stored in the 'lq' key. Also
# applies crop & resize to any 2D tensors in the state dict whose key contains the word "label".
class DatasetRandomAugWrapper(Dataset):
def __init__(self, opt):
super().__init__()
self.wrapped_dataset = create_dataset(opt['dataset'])
self.cropped_img_size = opt['crop_size']
self.includes_labels = opt['includes_labels']
        augmentations = [
            RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
self.aug = nn.Sequential(*augmentations)
self.rrc = nn.Sequential(*[
augs.RandomHorizontalFlip(),
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))])
def __getitem__(self, item):
item = self.wrapped_dataset[item]
hq = self.aug(item['hq'].unsqueeze(0)).squeeze(0)
labels = []
dtypes = []
for k in item.keys():
if 'label' in k and isinstance(item[k], torch.Tensor) and len(item[k].shape) == 3:
assert item[k].shape[0] == 1 # Only supports a channel dim of 1.
labels.append(k)
dtypes.append(item[k].dtype)
hq = torch.cat([hq, item[k].type(torch.float)], dim=0)
hq = self.rrc(hq.unsqueeze(0)).squeeze(0)
for i, k in enumerate(labels):
            # Interpolation in RandomResizedCrop can leave fractional label values along region
            # boundaries; zero out any label values that are no longer whole numbers.
item[k] = hq[3+i:3+i+1,:,:]
whole = (item[k].round() == item[k])
item[k] = item[k] * whole
item[k] = item[k].type(dtypes[i])
item['lq'] = hq[:3,:,:]
item['hq'] = item['lq']
return item
def __len__(self):
return len(self.wrapped_dataset)
def test_dataset_random_aug_wrapper():
opt = {
'dataset': {
'mode': 'imagefolder',
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
'weights': [1],
'target_size': 512,
'force_multiple': 1,
'scale': 1,
'fixed_corruptions': ['jpeg-broad'],
'random_corruptions': ['noise-5', 'none'],
'num_corrupts_per_image': 1,
'corrupt_before_downsize': False,
'labeler': {
'type': 'patch_labels',
'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories.json'
}
},
'crop_size': 512,
'includes_labels': True,
}
ds = DatasetRandomAugWrapper(opt)
import os
os.makedirs("debug", exist_ok=True)
for i in tqdm(range(0, len(ds))):
o = ds[random.randint(0, len(ds)-1)]
for k, v in o.items():
# 'lq', 'hq', 'aug1', 'aug2',
if k in ['hq']:
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
masked = v * (o['labels_mask'] * .5 + .5)
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
# Pick a random (non-zero) label and spit it out with the textual label.
if len(o['labels'].unique()) > 1:
randlbl = np.random.choice(o['labels'].unique()[1:])
moremask = v * ((1*(o['labels'] == randlbl))*.5+.5)
torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s_%s.png" % (i, k, o['label_strings'][randlbl]))
def no_batch_interpolate(i, size, mode):
i = i.unsqueeze(0)
i = F.interpolate(i, size=size, mode=mode)
return i.squeeze(0)
# Computes the 1-d offset of "other" relative to "ref", clamped at zero:
# if other < ref, returns 0; otherwise returns other - ref.
def snap(ref, other):
if other < ref:
return 0
return other - ref
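
# Worked examples for snap(); these just illustrate the clamped-offset behavior.
def _example_snap():
    assert snap(2, 5) == 3  # other is past ref: offset is the difference.
    assert snap(5, 2) == 0  # other is before ref: clamped to zero.
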
# Pads a tensor with zeros so that it fits in a dxd square.
def pad_to(im, d):
if len(im.shape) == 3:
        pd = torch.zeros((im.shape[0], d, d), device=im.device)
pd[:, :im.shape[1], :im.shape[2]] = im
else:
pd = torch.zeros((im.shape[0],im.shape[1],d,d), device=im.device)
pd[:, :, :im.shape[2], :im.shape[3]] = im
return pd
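
# A small sketch of pad_to(): content is anchored at the top-left and zero-padded out to a dxd square.
def _example_pad_to():
    im = torch.ones(3, 4, 6)
    padded = pad_to(im, 8)  # -> shape (3, 8, 8)
    assert padded.shape == (3, 8, 8)
    assert padded[:, :4, :6].eq(1).all() and padded[:, 4:, :].eq(0).all()
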
# Variation of RandomResizedCrop, which picks a region of the image that the two augments must share. The augments
# then propagate off random corners of the shared region, using the same scale.
#
# Operates in units of "multiple". The intent is that this multiple is equivalent to the compression multiple of the
# latent space being used so that each structural unit corresponds to a latent unit.
class RandomSharedRegionCrop(nn.Module):
def __init__(self, multiple, jitter_range=0):
super().__init__()
self.multiple = multiple
self.jitter_range = jitter_range # When specified, images are shifted an additional random([-j,j]) pixels where j=jitter_range
def forward(self, i1, i2):
assert i1.shape[-1] == i2.shape[-1]
# Outline of the general algorithm:
# 1. Assume the input is a square. Divide it by self.multiple to get working units.
# 2. Pick a random width, height and top corner location for the first patch.
# 3. Pick a random width, height and top corner location for the second patch.
# Note: All dims from (2) and (3) must contain at least half of the image, guaranteeing overlap.
        # 4. Build patches from the input images, resize them appropriately, and apply translational jitter.
        # 5. Randomly flip image 2 if needed.
        # 6. Compute the metrics needed to extract overlapping regions from the resized patches (top, left,
        #    original_height, original_width) and pack them into a recompute package.
        # 7. Build the masked "shared_view" debug images from the above data.
        # 8. Rebuild the shared regions for testing purposes.
# Step 1
c, d, _ = i1.shape
assert d % self.multiple == 0 and d > (self.multiple*3)
d = d // self.multiple
# Step 2
base_w = random.randint(d//2+1, d-1)
base_l = random.randint(0, d-base_w)
base_h = random.randint(base_w-1, base_w+1)
base_t = random.randint(0, d-base_h)
base_r, base_b = base_l+base_w, base_t+base_h
# Step 3
im2_w = random.randint(d//2+1, d-1)
im2_l = random.randint(0, d-im2_w)
im2_h = random.randint(im2_w-1, im2_w+1)
im2_t = random.randint(0, d-im2_h)
im2_r, im2_b = im2_l+im2_w, im2_t+im2_h
# Step 4
m = self.multiple
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
jt = jt if base_t != 0 else abs(jt) # If the top of a patch is zero, a negative jitter will cause it to go negative.
jt = jt if (base_t+base_h)*m != i1.shape[1] else 0 # Likewise, jitter shouldn't allow the patch to go over-bounds.
        jl = jl if base_l != 0 else abs(jl)
        jl = jl if (base_l+base_w)*m != i1.shape[2] else 0  # The width bound lives on the last dim.
p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, base_l*m+jl:(base_l+base_w)*m+jl]
p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear")
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
jt = jt if im2_t != 0 else abs(jt)
jt = jt if (im2_t+im2_h)*m != i2.shape[1] else 0
        jl = jl if im2_l != 0 else abs(jl)
        jl = jl if (im2_l+im2_w)*m != i2.shape[2] else 0
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
# Step 5
        should_flip = 1 if random.random() < .5 else 0  # Stored as an int so it can ride in the recompute package.
        if should_flip:
            p2_resized = kornia.geometry.transform.hflip(p2_resized)
# Step 6
i1_shared_t, i1_shared_l = snap(base_t, im2_t), snap(base_l, im2_l)
i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
ix_h = min(base_b, im2_b) - max(base_t, im2_t)
ix_w = min(base_r, im2_r) - max(base_l, im2_l)
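        # Worked example (illustrative numbers): with d=8, patch 1 at t,l=(1,2) with h,w=(5,5) and
        # patch 2 at t,l=(3,1) with h,w=(5,5), the shared region starts inside patch 1 at
        # (snap(1,3), snap(2,1)) = (2, 0), inside patch 2 at (snap(3,1), snap(1,2)) = (0, 1),
        # and spans ix_h = min(6,8)-max(1,3) = 3 by ix_w = min(7,6)-max(2,1) = 4 working units.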
recompute_package = torch.tensor([d, base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, should_flip, ix_h, ix_w], dtype=torch.long)
# Step 7
mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
masked1 = pad_to(p1 * mask1, d*m)
mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5)
mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
masked2 = pad_to(p2 * mask2, d*m)
mask = torch.full((1, d*m, d*m), fill_value=.33)
        mask[:, base_t*m:(base_t+base_h)*m, base_l*m:(base_l+base_w)*m] += .33
        mask[:, im2_t*m:(im2_t+im2_h)*m, im2_l*m:(im2_l+im2_w)*m] += .33
masked_dbg = i1 * mask
# Step 8 - Rebuild shared regions for testing purposes.
p1_shuf, p2_shuf = PixelUnshuffle(self.multiple)(p1_resized.unsqueeze(0)), \
PixelUnshuffle(self.multiple)(p2_resized.unsqueeze(0))
i1_shared, i2_shared = reconstructed_shared_regions(p1_shuf, p2_shuf, recompute_package.unsqueeze(0))
i1_shared = pad_to(nn.PixelShuffle(self.multiple)(i1_shared).squeeze(0), d * m)
i2_shared = pad_to(nn.PixelShuffle(self.multiple)(i2_shared).squeeze(0), d*m)
return p1_resized, p2_resized, recompute_package, masked1, masked2, masked_dbg, i1_shared, i2_shared
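
# A minimal sketch of driving RandomSharedRegionCrop directly. The 64x64 inputs and multiple=8
# are assumptions chosen to satisfy the squareness/divisibility asserts above.
def _example_random_shared_region_crop():
    rrc = RandomSharedRegionCrop(multiple=8)
    i1, i2 = torch.randn(3, 64, 64), torch.randn(3, 64, 64)
    a1, a2, package, m1, m2, dbg, s1, s2 = rrc(i1, i2)
    return a1.shape, a2.shape  # Both patches are resized back to (3, 64, 64).
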
# Uses the recompute package returned by RandomSharedRegionCrop to extract matched-size "similar regions" from two
# feature maps.
def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor):
package = recompute_package.cpu()
res1 = []
res2 = []
pad_dim = torch.max(package[:, -2:]).item()
# It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside
# of conforming the recompute_package across the entire batch.
for b in range(package.shape[0]):
expected_dim, f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, should_flip, s_h, s_w = tuple(package[b].tolist())
# If you are hitting this assert, you specified `latent_multiple` in your dataset config wrong.
assert expected_dim == fea1.shape[2] and expected_dim == fea2.shape[2]
# Unflip 2 if needed.
f2 = fea2[b]
if should_flip == 1:
f2 = kornia.geometry.transform.hflip(f2)
# Resize the input features to match
f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="nearest")
f2s = F.interpolate(f2.unsqueeze(0), (f2_h, f2_w), mode="nearest")
# Outputs must be padded so they can "get along" with each other.
res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
return torch.cat(res1, dim=0), torch.cat(res2, dim=0)
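
# A shape-level sketch of reconstructed_shared_regions(). The feature maps here are random
# stand-ins for latents whose spatial dim matches the `d` recorded in the recompute package.
def _example_reconstructed_shared_regions():
    rrc = RandomSharedRegionCrop(multiple=8)
    a1, a2, package = rrc(torch.randn(3, 64, 64), torch.randn(3, 64, 64))[:3]
    f1 = PixelUnshuffle(8)(a1.unsqueeze(0))  # (1, 192, 8, 8) latent-like features.
    f2 = PixelUnshuffle(8)(a2.unsqueeze(0))
    r1, r2 = reconstructed_shared_regions(f1, f2, package.unsqueeze(0))
    assert r1.shape == r2.shape  # Shared regions are cropped and padded to a common size.
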
# Follows the general template of ByolDatasetWrapper, with the following changes:
# 1. Flip() is not applied.
# 2. Instead of RandomResizedCrop, a custom Transform, RandomSharedRegionCrop is used.
# 3. The dataset injects two integer tensors alongside the augmentations, which are used to index image regions shared
# by the joint augmentations.
# 4. The dataset injects an aug_shared_view for debugging purposes.
class StructuredCropDatasetWrapper(Dataset):
def __init__(self, opt):
super().__init__()
self.wrapped_dataset = create_dataset(opt['dataset'])
augmentations = [RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
self.aug = nn.Sequential(*augmentations)
self.rrc = RandomSharedRegionCrop(opt['latent_multiple'], opt_get(opt, ['jitter_range'], 0))
def __getitem__(self, item):
item = self.wrapped_dataset[item]
a1 = self.aug(item['hq']).squeeze(dim=0)
a2 = self.aug(item['lq']).squeeze(dim=0)
a1, a2, sr_dim, m1, m2, db, i1s, i2s = self.rrc(a1, a2)
item.update({'aug1': a1, 'aug2': a2, 'similar_region_dimensions': sr_dim,
'masked1': m1, 'masked2': m2, 'aug_shared_view': db,
'i1_shared': i1s, 'i2_shared': i2s})
return item
def __len__(self):
return len(self.wrapped_dataset)
# For testing this dataset.
def test_structured_crop_dataset_wrapper():
opt = {
'dataset':
{
'mode': 'imagefolder',
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
'weights': [1],
'target_size': 256,
'force_multiple': 32,
'scale': 1,
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
'random_corruptions': ['noise-5', 'none'],
'num_corrupts_per_image': 1,
'corrupt_before_downsize': True,
},
'latent_multiple': 16,
'jitter_range': 0,
}
ds = StructuredCropDatasetWrapper(opt)
import os
os.makedirs("debug", exist_ok=True)
for i in tqdm(range(0, len(ds))):
o = ds[random.randint(0, len(ds)-1)]
#for k, v in o.items():
# 'lq', 'hq', 'aug1', 'aug2',
#if k in [ 'aug_shared_view', 'masked1', 'masked2']:
#torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
rcpkg = o['similar_region_dimensions']
pixun = PixelUnshuffle(16)
pixsh = nn.PixelShuffle(16)
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg.unsqueeze(0))
#torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
#torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
if __name__ == '__main__':
test_dataset_random_aug_wrapper()