DL-Art-School/codes/data/byol_attachment.py

262 lines
11 KiB
Python
Raw Normal View History

import random
2020-12-10 22:07:35 +00:00
from time import time
import torch
import torchvision
from torch.utils.data import Dataset
2020-12-11 19:01:09 +00:00
from kornia import augmentation as augs, kornia
from kornia import filters
import torch.nn as nn
import torch.nn.functional as F
# Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
# inputs. These are then outputted as 'aug1' and 'aug2'.
2020-12-10 22:07:35 +00:00
from tqdm import tqdm
from data import create_dataset
from models.archs.arch_util import PixelUnshuffle
from utils.util import opt_get
class RandomApply(nn.Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)
class ByolDatasetWrapper(Dataset):
def __init__(self, opt):
super().__init__()
self.wrapped_dataset = create_dataset(opt['dataset'])
self.cropped_img_size = opt['crop_size']
augmentations = [ \
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
augs.RandomHorizontalFlip(),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
if opt['normalize']:
# The paper calls for normalization. Recommend setting true if you want exactly like the paper.
augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))
self.aug = nn.Sequential(*augmentations)
def __getitem__(self, item):
item = self.wrapped_dataset[item]
item.update({'aug1': self.aug(item['hq']).squeeze(dim=0), 'aug2': self.aug(item['lq']).squeeze(dim=0)})
return item
def __len__(self):
return len(self.wrapped_dataset)
def no_batch_interpolate(i, size, mode):
i = i.unsqueeze(0)
i = F.interpolate(i, size=size, mode=mode)
return i.squeeze(0)
# Performs a 1-d translation of "other":
# If other<ref, returns 0.
# Else: return other-ref
def snap(ref, other):
if other < ref:
return 0
return other - ref
2020-12-10 22:07:35 +00:00
# Pads a tensor with zeros so that it fits in a dxd square.
def pad_to(im, d):
if len(im.shape) == 3:
pd = torch.zeros((im.shape[0],d,d))
pd[:, :im.shape[1], :im.shape[2]] = im
else:
pd = torch.zeros((im.shape[0],im.shape[1],d,d), device=im.device)
pd[:, :, :im.shape[2], :im.shape[3]] = im
return pd
# Variation of RandomResizedCrop, which picks a region of the image that the two augments must share. The augments
# then propagate off random corners of the shared region, using the same scale.
#
# Operates in units of "multiple". The intent is that this multiple is equivalent to the compression multiple of the
# latent space being used so that each structural unit corresponds to a latent unit.
class RandomSharedRegionCrop(nn.Module):
def __init__(self, multiple, jitter_range=0):
super().__init__()
self.multiple = multiple
self.jitter_range = jitter_range # When specified, images are shifted an additional random([-j,j]) pixels where j=jitter_range
def forward(self, i1, i2):
assert i1.shape[-1] == i2.shape[-1]
# Outline of the general algorithm:
# 1. Assume the input is a square. Divide it by self.multiple to get working units.
# 2. Pick a random width, height and top corner location for the first patch.
# 3. Pick a random width, height and top corner location for the second patch.
# Note: All dims from (2) and (3) must contain at least half of the image, guaranteeing overlap.
2020-12-11 19:01:09 +00:00
# 4. Build patches from input images. Resize them appropriately. Apply translational jitter.\
# 5. Randomly flip image 2 if needed.
# 5. Compute the metrics needed to extract overlapping regions from the resized patches: top, left,
# original_height, original_width.
2020-12-11 19:01:09 +00:00
# 6. Compute the "shared_view" from the above data.
# Step 1
c, d, _ = i1.shape
assert d % self.multiple == 0 and d > (self.multiple*3)
d = d // self.multiple
# Step 2
base_w = random.randint(d//2, d-1)
base_l = random.randint(0, d-base_w)
base_h = random.randint(base_w-1, base_w+1)
base_t = random.randint(0, d-base_h)
base_r, base_b = base_l+base_w, base_t+base_h
# Step 3
im2_w = random.randint(d//2, d-1)
im2_l = random.randint(0, d-im2_w)
im2_h = random.randint(im2_w-1, im2_w+1)
im2_t = random.randint(0, d-im2_h)
im2_r, im2_b = im2_l+im2_w, im2_t+im2_h
2020-12-11 19:01:09 +00:00
# Step 4
m = self.multiple
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
2020-12-10 22:07:35 +00:00
jt = jt if base_t != 0 else abs(jt) # If the top of a patch is zero, a negative jitter will cause it to go negative.
jt = jt if (base_t+base_h)*m != i1.shape[1] else 0 # Likewise, jitter shouldn't allow the patch to go over-bounds.
jl = jl if base_l != 0 else abs(jl)
jl = jl if (base_l+base_w)*m != i1.shape[1] else 0
p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, base_l*m+jl:(base_l+base_w)*m+jl]
p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear")
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
2020-12-10 22:07:35 +00:00
jt = jt if im2_t != 0 else abs(jt)
jt = jt if (im2_t+im2_h)*m != i2.shape[1] else 0
jl = jl if im2_l != 0 else abs(jl)
jl = jl if (im2_l+im2_w)*m != i2.shape[1] else 0
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
2020-12-11 19:01:09 +00:00
# Step 5
should_flip = random.random() < .5
if should_flip:
should_flip = 1
p2_resized = kornia.geometry.transform.hflip(p2_resized)
else:
should_flip = 0
# Step 6
i1_shared_t, i1_shared_l = snap(base_t, im2_t), snap(base_l, im2_l)
i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
ix_h = min(base_b, im2_b) - max(base_t, im2_t)
ix_w = min(base_r, im2_r) - max(base_l, im2_l)
2020-12-11 19:01:09 +00:00
recompute_package = torch.tensor([base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, should_flip, ix_h, ix_w], dtype=torch.long)
2020-12-11 19:01:09 +00:00
# Step 7
mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
2020-12-10 22:07:35 +00:00
masked1 = pad_to(p1 * mask1, d*m)
mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5)
mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
2020-12-10 22:07:35 +00:00
masked2 = pad_to(p2 * mask2, d*m)
mask = torch.full((1, d*m, d*m), fill_value=.33)
mask[:, base_t*m:(base_t+base_w)*m, base_l*m:(base_l+base_h)*m] += .33
mask[:, im2_t*m:(im2_t+im2_w)*m, im2_l*m:(im2_l+im2_h)*m] += .33
masked_dbg = i1 * mask
return p1_resized, p2_resized, recompute_package, masked1, masked2, masked_dbg
# Uses the recompute package returned from the above dataset to extract matched-size "similar regions" from two feature
# maps.
2020-12-10 22:07:35 +00:00
def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor):
package = recompute_package.cpu()
res1 = []
res2 = []
pad_dim = torch.max(package[:, -2:]).item()
# It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside
# of conforming the recompute_package across the entire batch.
for b in range(package.shape[0]):
2020-12-11 19:01:09 +00:00
f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, should_flip, s_h, s_w = tuple(package[b].tolist())
# Unflip 2 if needed.
f2 = fea2[b]
if should_flip == 1:
f2 = kornia.geometry.transform.hflip(f2)
2020-12-10 22:07:35 +00:00
# Resize the input features to match
f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="bilinear")
2020-12-11 19:01:09 +00:00
f2s = F.interpolate(f2.unsqueeze(0), (f2_h, f2_w), mode="bilinear")
2020-12-10 22:07:35 +00:00
# Outputs must be padded so they can "get along" with each other.
res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
return torch.cat(res1, dim=0), torch.cat(res2, dim=0)
# Follows the general template of BYOL dataset, with the following changes:
# 1. Flip() is not applied.
# 2. Instead of RandomResizedCrop, a custom Transform, RandomSharedRegionCrop is used.
# 3. The dataset injects two integer tensors alongside the augmentations, which are used to index image regions shared
# by the joint augmentations.
# 4. The dataset injects an aug_shared_view for debugging purposes.
class StructuredCropDatasetWrapper(Dataset):
def __init__(self, opt):
super().__init__()
self.wrapped_dataset = create_dataset(opt['dataset'])
augmentations = [RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
self.aug = nn.Sequential(*augmentations)
self.rrc = RandomSharedRegionCrop(opt['latent_multiple'], opt_get(opt, ['jitter_range'], 0))
def __getitem__(self, item):
item = self.wrapped_dataset[item]
2020-12-10 22:07:35 +00:00
a1 = self.aug(item['hq']).squeeze(dim=0)
a2 = self.aug(item['lq']).squeeze(dim=0)
a1, a2, sr_dim, m1, m2, db = self.rrc(a1, a2)
item.update({'aug1': a1, 'aug2': a2, 'similar_region_dimensions': sr_dim,
'masked1': m1, 'masked2': m2, 'aug_shared_view': db})
return item
def __len__(self):
return len(self.wrapped_dataset)
# For testing this dataset.
if __name__ == '__main__':
opt = {
'dataset':
{
'mode': 'imagefolder',
'name': 'amalgam',
2020-12-10 22:07:35 +00:00
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
'weights': [1],
'target_size': 256,
'force_multiple': 32,
'scale': 1,
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
'random_corruptions': ['noise-5', 'none'],
'num_corrupts_per_image': 1,
'corrupt_before_downsize': True,
},
'latent_multiple': 8,
'jitter_range': 0,
}
ds = StructuredCropDatasetWrapper(opt)
import os
os.makedirs("debug", exist_ok=True)
2020-12-10 22:07:35 +00:00
for i in tqdm(range(0, len(ds))):
o = ds[random.randint(0, len(ds)-1)]
#for k, v in o.items():
# 'lq', 'hq', 'aug1', 'aug2',
2020-12-10 22:07:35 +00:00
#if k in [ 'aug_shared_view', 'masked1', 'masked2']:
#torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
rcpkg = o['similar_region_dimensions']
pixun = PixelUnshuffle(8)
pixsh = nn.PixelShuffle(8)
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg)
2020-12-10 22:07:35 +00:00
#torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
#torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))