import random
from time import time

import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset
from kornia import augmentation as augs, kornia, Resample
from kornia import filters
import torch.nn as nn
import torch.nn.functional as F

# Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
# inputs. These are then outputted as 'aug1' and 'aug2'.
from tqdm import tqdm

from data import create_dataset
from models.arch_util import PixelUnshuffle
from utils.util import opt_get


class RandomApply(nn.Module):
    """Applies `fn` to the input with probability `p`; otherwise returns the input unchanged."""

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)


class ByolDatasetWrapper(Dataset):
    """Wraps a dataset and adds two independently augmented views under the 'aug1' and 'aug2' keys.

    Options read from `opt`:
      dataset:   config passed to create_dataset() to build the wrapped dataset.
      crop_size: side length of the square produced by RandomResizedCrop.
      key1/key2: item keys the two views are built from (default 'hq' and 'lq').
      for_sr:    when set, color alterations and blurs are disabled (default False).
      normalize: when set, a Normalize() step is appended (default False).
    """

    def __init__(self, opt):
        super().__init__()
        self.wrapped_dataset = create_dataset(opt['dataset'])
        self.cropped_img_size = opt['crop_size']
        self.key1 = opt_get(opt, ['key1'], 'hq')
        self.key2 = opt_get(opt, ['key2'], 'lq')
        for_sr = opt_get(opt, ['for_sr'], False)  # When set, color alterations and blurs are disabled.

        augmentations = [
            augs.RandomHorizontalFlip(),
            augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
        if not for_sr:
            augmentations.extend([RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
                                  augs.RandomGrayscale(p=0.2),
                                  RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)])
        # Read with opt_get like the other optional settings so a config without 'normalize' does not raise KeyError.
        if opt_get(opt, ['normalize'], False):
            # The paper calls for normalization. Most datasets/models in this repo don't use this.
            # Recommend setting true if you want to train exactly like the paper.
            augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]),
                                                std=torch.tensor([0.229, 0.224, 0.225])))
        self.aug = nn.Sequential(*augmentations)

    def __getitem__(self, item):
        item = self.wrapped_dataset[item]
        # squeeze(dim=0) drops a leading dim of size 1 from the augmentation output, if there is one.
        item.update({'aug1': self.aug(item[self.key1]).squeeze(dim=0),
                     'aug2': self.aug(item[self.key2]).squeeze(dim=0)})
        return item

    def __len__(self):
        return len(self.wrapped_dataset)


# Similar to ByolDatasetWrapper except it only produces 1 augmentation (built from 'hq'), which is stored in both the
# 'lq' and 'hq' keys. Also applies the same flip/crop&resize to every 3-dim, single-channel tensor in the item dict
# whose key contains the word "label".
class DatasetRandomAugWrapper(Dataset):
    def __init__(self, opt):
        super().__init__()
        self.wrapped_dataset = create_dataset(opt['dataset'])
        self.cropped_img_size = opt['crop_size']
        self.includes_labels = opt['includes_labels']
        # Color/blur augmentations only touch the image; the geometric ones (self.rrc) are kept separate so they can be
        # applied jointly to the image and its label maps.
        augmentations = [
            RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
        self.aug = nn.Sequential(*augmentations)
        self.rrc = nn.Sequential(*[
            augs.RandomHorizontalFlip(),
            augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))])

    def __getitem__(self, item):
        item = self.wrapped_dataset[item]
        hq = self.aug(item['hq'].unsqueeze(0)).squeeze(0)

        # Stack every label map behind the image channels so one geometric transform hits all of them identically.
        labels = []
        dtypes = []
        for k in item.keys():
            if 'label' in k and isinstance(item[k], torch.Tensor) and len(item[k].shape) == 3:
                assert item[k].shape[0] == 1  # Only supports a channel dim of 1.
                labels.append(k)
                dtypes.append(item[k].dtype)
                hq = torch.cat([hq, item[k].type(torch.float)], dim=0)
        hq = self.rrc(hq.unsqueeze(0)).squeeze(0)

        # The label channels follow the first 3 channels, in the order they were appended.
        for i, k in enumerate(labels):
            # Strip out any label values that are not whole numbers.
            item[k] = hq[3+i:3+i+1, :, :]
            whole = (item[k].round() == item[k])
            item[k] = item[k] * whole
            item[k] = item[k].type(dtypes[i])
        item['lq'] = hq[:3, :, :]
        item['hq'] = item['lq']
        return item

    def __len__(self):
        return len(self.wrapped_dataset)


# Manual smoke test: writes augmented samples (and label-masked variants) into ./debug.
def test_dataset_random_aug_wrapper():
    opt = {
        'dataset': {
            'mode': 'imagefolder',
            'name': 'amalgam',
            'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
            'weights': [1],
            'target_size': 512,
            'force_multiple': 1,
            'scale': 1,
            'fixed_corruptions': ['jpeg-broad'],
            'random_corruptions': ['noise-5', 'none'],
            'num_corrupts_per_image': 1,
            'corrupt_before_downsize': False,
            'labeler': {
                'type': 'patch_labels',
                'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories.json'
            }
        },
        'crop_size': 512,
        'includes_labels': True,
    }

    ds = DatasetRandomAugWrapper(opt)
    import os
    os.makedirs("debug", exist_ok=True)
    for i in tqdm(range(0, len(ds))):
        o = ds[random.randint(0, len(ds)-1)]
        for k, v in o.items():
            # 'lq', 'hq', 'aug1', 'aug2',
            if k in ['hq']:
                torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
                masked = v * (o['labels_mask'] * .5 + .5)
                #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
                # Pick a random (non-zero) label and spit it out with the textual label.
                if len(o['labels'].unique()) > 1:
                    randlbl = np.random.choice(o['labels'].unique()[1:])
                    moremask = v * ((1*(o['labels'] == randlbl))*.5+.5)
                    torchvision.utils.save_image(moremask.unsqueeze(0),
                                                 "debug/%i_%s_%s.png" % (i, k, o['label_strings'][randlbl]))


# F.interpolate() for a single, un-batched tensor: adds a batch dim, resizes, then removes it again.
def no_batch_interpolate(i, size, mode):
    i = i.unsqueeze(0)
    i = F.interpolate(i, size=size, mode=mode)
    return i.squeeze(0)


# Performs a 1-d translation of "other":
# If other<ref, returns 0.
# Else returns other-ref.
# NOTE(review): the body of this function was lost from this copy of the file. It is reconstructed from the surviving
# comment fragment and from its call sites in RandomSharedRegionCrop.forward() - confirm against version control.
def snap(ref, other):
    if other < ref:
        return 0
    return other - ref


# Pads the last two dims of a tensor with zeros so that it fits in a dxd square. The input is placed in the top-left
# corner.
# NOTE(review): this function was lost from this copy of the file. It is reconstructed from its call sites, which pass
# both 3-dim (C,H,W) and 4-dim (B,C,H,W) tensors - confirm against version control.
def pad_to(im, d):
    padded = im.new_zeros((*im.shape[:-2], d, d))
    padded[..., :im.shape[-2], :im.shape[-1]] = im
    return padded


# Crops two patches out of a pair of images such that the patches overlap, and records where the overlap ("shared
# region") sits inside each patch. Operates in units of `multiple` pixels.
# NOTE(review): the class header, __init__ and the first lines of forward() were lost from this copy of the file. They
# are reconstructed from StructuredCropDatasetWrapper (which calls RandomSharedRegionCrop(latent_multiple, jitter_range)
# and then self.rrc(a1, a2)) and from the attributes the surviving code reads - confirm against version control.
class RandomSharedRegionCrop(nn.Module):
    def __init__(self, multiple, jitter_range=0):
        super().__init__()
        self.multiple = multiple
        # Patches are shifted by an additional random offset in [-jitter_range, jitter_range] pixels.
        self.jitter_range = jitter_range

    # Cuts a (height x width) patch, given in working units, out of `img` with a random pixel jitter applied.
    def _jittered_crop(self, img, top, left, height, width):
        m = self.multiple
        jl = random.randint(-self.jitter_range, self.jitter_range)
        jt = random.randint(-self.jitter_range, self.jitter_range)
        jt = jt if top != 0 else abs(jt)  # If the top of a patch is zero, a negative jitter will cause it to go negative.
        jt = jt if (top+height)*m != img.shape[1] else 0  # Likewise, jitter shouldn't allow the patch to go over-bounds.
        jl = jl if left != 0 else abs(jl)
        jl = jl if (left+width)*m != img.shape[2] else 0  # Width is dim 2 of a (C,H,W) tensor.
        return img[:, top*m+jt:(top+height)*m+jt, left*m+jl:(left+width)*m+jl]

    # Returns (p1_resized, p2_resized, recompute_package, masked1, masked2, masked_dbg, i1_shared, i2_shared):
    #   p1_resized, p2_resized: the two patches, each resized to the full working square; p2 is randomly h-flipped.
    #   recompute_package:      long tensor consumed by reconstructed_shared_regions().
    #   masked1, masked2:       un-resized patches with everything outside the shared region dimmed, zero padded.
    #   masked_dbg:             i1 with both patch footprints highlighted.
    #   i1_shared, i2_shared:   shared regions rebuilt via reconstructed_shared_regions(), for testing purposes.
    def forward(self, i1, i2):
        # Step 1 - Convert the image size into working units of self.multiple.
        # NOTE(review): these checks are reconstructed; only the "> (self.multiple*3)" fragment survived. Assumes both
        # inputs are square (C,H,W) tensors of the same size - TODO confirm.
        assert i1.shape[-1] == i2.shape[-1]
        c, d, _ = i1.shape
        assert d % self.multiple == 0 and d > (self.multiple*3)
        d = d // self.multiple

        # Step 2 - Pick a random size and position for the first patch.
        base_w = random.randint(d//2+1, d-1)
        base_l = random.randint(0, d-base_w)
        base_h = random.randint(base_w-1, base_w+1)
        base_t = random.randint(0, d-base_h)
        base_r, base_b = base_l+base_w, base_t+base_h

        # Step 3 - Same for the second patch.
        im2_w = random.randint(d//2+1, d-1)
        im2_l = random.randint(0, d-im2_w)
        im2_h = random.randint(im2_w-1, im2_w+1)
        im2_t = random.randint(0, d-im2_h)
        im2_r, im2_b = im2_l+im2_w, im2_t+im2_h

        # Step 4 - Cut both patches (with jitter) and resize them to the full working square.
        m = self.multiple
        p1 = self._jittered_crop(i1, base_t, base_l, base_h, base_w)
        p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear")
        p2 = self._jittered_crop(i2, im2_t, im2_l, im2_h, im2_w)
        p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")

        # Step 5 - Randomly flip the second patch. The flag is stored as an int so it fits in the recompute package.
        should_flip = random.random() < .5
        if should_flip:
            should_flip = 1
            p2_resized = kornia.geometry.transform.hflip(p2_resized)
        else:
            should_flip = 0

        # Step 6 - Locate the shared region inside each patch.
        i1_shared_t, i1_shared_l = snap(base_t, im2_t), snap(base_l, im2_l)
        i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
        # Patch heights can be as small as d//2, so two patches may not overlap vertically. Clamp to zero so a negative
        # size is never turned into a "everything but the last rows" slice downstream.
        ix_h = max(0, min(base_b, im2_b) - max(base_t, im2_t))
        ix_w = max(0, min(base_r, im2_r) - max(base_l, im2_l))
        recompute_package = torch.tensor([d, base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t,
                                          i2_shared_l, should_flip, ix_h, ix_w], dtype=torch.long)

        # Step 7 - Build debug views: each patch with the non-shared area dimmed, and i1 with both footprints marked.
        mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
        mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
        masked1 = pad_to(p1 * mask1, d*m)
        mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5)
        mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
        masked2 = pad_to(p2 * mask2, d*m)
        mask = torch.full((1, d*m, d*m), fill_value=.33)
        # Rows span the patch height and columns span the patch width.
        mask[:, base_t*m:(base_t+base_h)*m, base_l*m:(base_l+base_w)*m] += .33
        mask[:, im2_t*m:(im2_t+im2_h)*m, im2_l*m:(im2_l+im2_w)*m] += .33
        masked_dbg = i1 * mask

        # Step 8 - Rebuild shared regions for testing purposes.
        p1_shuf, p2_shuf = PixelUnshuffle(self.multiple)(p1_resized.unsqueeze(0)), \
                           PixelUnshuffle(self.multiple)(p2_resized.unsqueeze(0))
        i1_shared, i2_shared = reconstructed_shared_regions(p1_shuf, p2_shuf, recompute_package.unsqueeze(0))
        i1_shared = pad_to(nn.PixelShuffle(self.multiple)(i1_shared).squeeze(0), d*m)
        i2_shared = pad_to(nn.PixelShuffle(self.multiple)(i2_shared).squeeze(0), d*m)

        return p1_resized, p2_resized, recompute_package, masked1, masked2, masked_dbg, i1_shared, i2_shared


# Uses the recompute package returned from the above dataset to extract matched-size "similar regions" from two feature
# maps.
def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor):
    package = recompute_package.cpu()
    res1 = []
    res2 = []
    # The last two package entries are the shared region's height and width; pad every result to the largest one.
    pad_dim = torch.max(package[:, -2:]).item()
    # It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside
    # of conforming the recompute_package across the entire batch.
    for b in range(package.shape[0]):
        expected_dim, f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, should_flip, s_h, s_w = \
            tuple(package[b].tolist())
        # If you are hitting this assert, you specified `latent_multiple` in your dataset config wrong.
        assert expected_dim == fea1.shape[2] and expected_dim == fea2.shape[2]

        # Unflip 2 if needed.
        f2 = fea2[b]
        if should_flip == 1:
            f2 = kornia.geometry.transform.hflip(f2)
        # Resize the input features to match
        f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="nearest")
        f2s = F.interpolate(f2.unsqueeze(0), (f2_h, f2_w), mode="nearest")
        # Outputs must be padded so they can "get along" with each other.
        res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
        res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
    return torch.cat(res1, dim=0), torch.cat(res2, dim=0)


# Follows the general template of BYOL dataset, with the following changes:
# 1. Flip() is not applied.
# 2.
#    Instead of RandomResizedCrop, a custom Transform, RandomSharedRegionCrop is used.
# 3. The dataset injects two integer tensors alongside the augmentations, which are used to index image regions shared
#    by the joint augmentations.
# 4. The dataset injects an aug_shared_view for debugging purposes.
class StructuredCropDatasetWrapper(Dataset):
    def __init__(self, opt):
        super().__init__()
        self.wrapped_dataset = create_dataset(opt['dataset'])

        # Photometric augmentations, applied independently to each view before cropping.
        color_jitter = RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8)
        grayscale = augs.RandomGrayscale(p=0.2)
        blur = RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)
        self.aug = nn.Sequential(color_jitter, grayscale, blur)

        # Joint crop of both views; 'jitter_range' is optional and defaults to 0.
        jitter = opt_get(opt, ['jitter_range'], 0)
        self.rrc = RandomSharedRegionCrop(opt['latent_multiple'], jitter)

    def __getitem__(self, item):
        sample = self.wrapped_dataset[item]

        # 'hq' is augmented first, then 'lq'.
        view_hq = self.aug(sample['hq']).squeeze(dim=0)
        view_lq = self.aug(sample['lq']).squeeze(dim=0)

        (crop1, crop2, region_dims, masked_a, masked_b,
         debug_view, shared_a, shared_b) = self.rrc(view_hq, view_lq)

        # Add the crop outputs to the wrapped dataset's item.
        sample['aug1'] = crop1
        sample['aug2'] = crop2
        sample['similar_region_dimensions'] = region_dims
        sample['masked1'] = masked_a
        sample['masked2'] = masked_b
        sample['aug_shared_view'] = debug_view
        sample['i1_shared'] = shared_a
        sample['i2_shared'] = shared_b
        return sample

    def __len__(self):
        return len(self.wrapped_dataset)


# For testing this dataset.
def test_structured_crop_dataset_wrapper():
    # Manual smoke test: pulls random samples and rebuilds their shared regions from the recompute package.
    wrapped_cfg = {
        'mode': 'imagefolder',
        'name': 'amalgam',
        'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
        'weights': [1],
        'target_size': 256,
        'force_multiple': 32,
        'scale': 1,
        'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
        'random_corruptions': ['noise-5', 'none'],
        'num_corrupts_per_image': 1,
        'corrupt_before_downsize': True,
    }
    opt = {
        'dataset': wrapped_cfg,
        'latent_multiple': 16,
        'jitter_range': 0,
    }

    ds = StructuredCropDatasetWrapper(opt)
    import os
    os.makedirs("debug", exist_ok=True)
    for i in tqdm(range(len(ds))):
        o = ds[random.randint(0, len(ds) - 1)]
        #for k, v in o.items():
            # 'lq', 'hq', 'aug1', 'aug2',
            #if k in [ 'aug_shared_view', 'masked1', 'masked2']:
                #torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
        rcpkg = o['similar_region_dimensions']
        pixun = PixelUnshuffle(16)
        pixsh = nn.PixelShuffle(16)
        # Unshuffle both views so they match the size reconstructed_shared_regions() expects for latent_multiple=16.
        fea1 = pixun(o['aug1'].unsqueeze(0))
        fea2 = pixun(o['aug2'].unsqueeze(0))
        rc1, rc2 = reconstructed_shared_regions(fea1, fea2, rcpkg.unsqueeze(0))
        #torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
        #torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))


if __name__ == '__main__':
    test_dataset_random_aug_wrapper()