Mods to support labeled datasets & random augs for those datasets

This commit is contained in:
James Betker 2020-12-15 17:15:56 -07:00
parent e5a3e6b9b5
commit 8e0e883050
4 changed files with 111 additions and 10 deletions

View File

@ -51,6 +51,8 @@ def create_dataset(dataset_opt):
from data.byol_attachment import ByolDatasetWrapper as D
elif mode == 'byol_structured_dataset':
from data.byol_attachment import StructuredCropDatasetWrapper as D
elif mode == 'random_aug_wrapper':
from data.byol_attachment import DatasetRandomAugWrapper as D
elif mode == 'random_dataset':
from data.random_dataset import RandomDataset as D
else:

View File

@ -1,10 +1,11 @@
import random
from time import time
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset
from kornia import augmentation as augs, kornia
from kornia import augmentation as augs, kornia, Resample
from kornia import filters
import torch.nn as nn
import torch.nn.functional as F
@ -41,7 +42,8 @@ class ByolDatasetWrapper(Dataset):
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
if opt['normalize']:
# The paper calls for normalization. Recommend setting true if you want exactly like the paper.
# The paper calls for normalization. Most datasets/models in this repo don't use this.
# Recommend setting true if you want to train exactly like the paper.
augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))
self.aug = nn.Sequential(*augmentations)
@ -54,6 +56,90 @@ class ByolDatasetWrapper(Dataset):
return len(self.wrapped_dataset)
# Basically the same as ByolDatasetWrapper except only produces 1 augmentation and stores in the 'lr' key. Also applies
# crop&resize to 2D tensors in the state dict with the word "label" in them.
class DatasetRandomAugWrapper(Dataset):
def __init__(self, opt):
super().__init__()
self.wrapped_dataset = create_dataset(opt['dataset'])
self.cropped_img_size = opt['crop_size']
self.includes_labels = opt['includes_labels']
augmentations = [ \
RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
self.aug = nn.Sequential(*augmentations)
self.rrc = nn.Sequential(*[
augs.RandomHorizontalFlip(),
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))])
def __getitem__(self, item):
item = self.wrapped_dataset[item]
hq = self.aug(item['hq'].unsqueeze(0)).squeeze(0)
labels = []
dtypes = []
for k in item.keys():
if 'label' in k and isinstance(item[k], torch.Tensor) and len(item[k].shape) == 3:
assert item[k].shape[0] == 1 # Only supports a channel dim of 1.
labels.append(k)
dtypes.append(item[k].dtype)
hq = torch.cat([hq, item[k].type(torch.float)], dim=0)
hq = self.rrc(hq.unsqueeze(0)).squeeze(0)
for i, k in enumerate(labels):
# Strip out any label values that are not whole numbers.
item[k] = hq[3+i:3+i+1,:,:]
whole = (item[k].round() == item[k])
item[k] = item[k] * whole
item[k] = item[k].type(dtypes[i])
item['lq'] = hq[:3,:,:]
item['hq'] = item['lq']
return item
def __len__(self):
return len(self.wrapped_dataset)
def test_dataset_random_aug_wrapper():
opt = {
'dataset': {
'mode': 'imagefolder',
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
'weights': [1],
'target_size': 512,
'force_multiple': 1,
'scale': 1,
'fixed_corruptions': ['jpeg-broad'],
'random_corruptions': ['noise-5', 'none'],
'num_corrupts_per_image': 1,
'corrupt_before_downsize': False,
'labeler': {
'type': 'patch_labels',
'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories.json'
}
},
'crop_size': 512,
'includes_labels': True,
}
ds = DatasetRandomAugWrapper(opt)
import os
os.makedirs("debug", exist_ok=True)
for i in tqdm(range(0, len(ds))):
o = ds[random.randint(0, len(ds)-1)]
for k, v in o.items():
# 'lq', 'hq', 'aug1', 'aug2',
if k in ['hq']:
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
masked = v * (o['labels_mask'] * .5 + .5)
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
# Pick a random (non-zero) label and spit it out with the textual label.
if len(o['labels'].unique()) > 1:
randlbl = np.random.choice(o['labels'].unique()[1:])
moremask = v * ((1*(o['labels'] == randlbl))*.5+.5)
torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s_%s.png" % (i, k, o['label_strings'][randlbl]))
def no_batch_interpolate(i, size, mode):
i = i.unsqueeze(0)
i = F.interpolate(i, size=size, mode=mode)
@ -235,7 +321,7 @@ class StructuredCropDatasetWrapper(Dataset):
# For testing this dataset.
if __name__ == '__main__':
def test_structured_crop_dataset_wrapper():
opt = {
'dataset':
{
@ -270,3 +356,7 @@ if __name__ == '__main__':
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg.unsqueeze(0))
#torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
#torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
if __name__ == '__main__':
test_dataset_random_aug_wrapper()

View File

@ -111,10 +111,13 @@ class ImageFolderDataset:
out_dict = {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
if self.labeler:
base_file = self.image_paths[item].replace(self.paths[0], "")
while base_file.startswith("\\"):
base_file = base_file[1:]
assert dim % hq.shape[1] == 0
lbls, lbl_masks = self.labeler.get_labels_as_tensor(hq, base_file, dim // hq.shape[1])
lbls, lbl_masks, lblstrings = self.labeler.get_labels_as_tensor(hq, base_file, dim // hq.shape[1])
out_dict['labels'] = lbls
out_dict['labels_mask'] = lbl_masks
out_dict['label_strings'] = lblstrings
return out_dict
if __name__ == '__main__':
@ -122,7 +125,7 @@ if __name__ == '__main__':
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\'],
'weights': [1],
'target_size': 128,
'target_size': 512,
'force_multiple': 32,
'scale': 2,
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
@ -144,4 +147,8 @@ if __name__ == '__main__':
masked = (o['labels_mask'] * .5 + .5) * hq
import torchvision
torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_hq.png" % (i,))
torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,))
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,))
if len(o['labels'].unique()) > 1:
randlbl = np.random.choice(o['labels'].unique()[1:])
moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))

View File

@ -33,7 +33,7 @@ class VsNetImageLabeler:
# Build the label values, from [1,inf]
label_value_dict = {}
for i, l in enumerate(available_labels):
label_value_dict[l] = i+1
label_value_dict[l] = i
# Insert "labelValue" for each entry.
for v in labeled_images.values():
@ -41,13 +41,15 @@ class VsNetImageLabeler:
l['labelValue'] = label_value_dict[l['label']]
self.labeled_images = labeled_images
self.str_labels = available_labels
def get_labeled_paths(self, base_path):
return [os.path.join(base_path, pth) for pth in self.labeled_images]
def get_labels_as_tensor(self, hq, img_key, resize_factor):
labels = torch.zeros(hq.shape, dtype=torch.long)
mask = torch.zeros_like(hq)
_, h, w = hq.shape
labels = torch.zeros((1,h,w), dtype=torch.long)
mask = torch.zeros((1,h,w), dtype=torch.float)
lbl_list = self.labeled_images[img_key]
for patch_lbl in lbl_list:
t, l, h, w = patch_lbl['patch_top'] // resize_factor, patch_lbl['patch_left'] // resize_factor, \
@ -55,4 +57,4 @@ class VsNetImageLabeler:
val = patch_lbl['labelValue']
labels[:,t:t+h,l:l+w] = val
mask[:,t:t+h,l:l+w] = 1.0
return labels, mask
return labels, mask, self.str_labels