Mods to support labeled datasets & random augs for those datasets

James Betker 2020-12-15 17:15:56 -07:00
parent e5a3e6b9b5
commit 8e0e883050
4 changed files with 111 additions and 10 deletions

View File

@@ -51,6 +51,8 @@ def create_dataset(dataset_opt):
         from data.byol_attachment import ByolDatasetWrapper as D
     elif mode == 'byol_structured_dataset':
         from data.byol_attachment import StructuredCropDatasetWrapper as D
+    elif mode == 'random_aug_wrapper':
+        from data.byol_attachment import DatasetRandomAugWrapper as D
     elif mode == 'random_dataset':
         from data.random_dataset import RandomDataset as D
     else:
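Reviewer note: the new branch makes the wrapper reachable through the usual dataset factory. A minimal usage sketch follows; only 'mode', 'dataset', 'crop_size', and 'includes_labels' are taken from this commit, while the inner dataset options and the path are illustrative placeholders.

# Hedged sketch, not from the commit: constructing the new wrapper via create_dataset().
from data import create_dataset

opt = {
    'mode': 'random_aug_wrapper',      # new mode added above
    'dataset': {                       # any wrapped dataset that yields an 'hq' image tensor
        'mode': 'imagefolder',
        'paths': ['/path/to/images'],  # placeholder
        'target_size': 512,
        'scale': 1,
        'weights': [1],
    },
    'crop_size': 512,                  # side length of the shared random resized crop
    'includes_labels': True,           # wrapped dataset attaches 'label*' tensors
}
ds = create_dataset(opt)
sample = ds[0]                         # dict with 'lq'/'hq' plus cropped label tensors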

View File

@@ -1,10 +1,11 @@
 import random
 from time import time
+import numpy as np
 import torch
 import torchvision
 from torch.utils.data import Dataset
-from kornia import augmentation as augs, kornia
+from kornia import augmentation as augs, kornia, Resample
 from kornia import filters
 import torch.nn as nn
 import torch.nn.functional as F
@@ -41,7 +42,8 @@ class ByolDatasetWrapper(Dataset):
             RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
             augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
         if opt['normalize']:
-            # The paper calls for normalization. Recommend setting true if you want exactly like the paper.
+            # The paper calls for normalization. Most datasets/models in this repo don't use this.
+            # Recommend setting true if you want to train exactly like the paper.
             augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))
         self.aug = nn.Sequential(*augmentations)
@@ -54,6 +56,90 @@ class ByolDatasetWrapper(Dataset):
         return len(self.wrapped_dataset)
 
 
+# Basically the same as ByolDatasetWrapper, except it only produces one augmentation and stores it in the 'lq' key.
+# Also applies crop&resize to 2D tensors in the state dict with the word "label" in them.
+class DatasetRandomAugWrapper(Dataset):
+    def __init__(self, opt):
+        super().__init__()
+        self.wrapped_dataset = create_dataset(opt['dataset'])
+        self.cropped_img_size = opt['crop_size']
+        self.includes_labels = opt['includes_labels']
+        augmentations = [RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8),
+                         augs.RandomGrayscale(p=0.2),
+                         RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
+        self.aug = nn.Sequential(*augmentations)
+        self.rrc = nn.Sequential(*[
+            augs.RandomHorizontalFlip(),
+            augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))])
+
+    def __getitem__(self, item):
+        item = self.wrapped_dataset[item]
+        hq = self.aug(item['hq'].unsqueeze(0)).squeeze(0)
+        labels = []
+        dtypes = []
+        for k in item.keys():
+            if 'label' in k and isinstance(item[k], torch.Tensor) and len(item[k].shape) == 3:
+                assert item[k].shape[0] == 1  # Only supports a channel dim of 1.
+                labels.append(k)
+                dtypes.append(item[k].dtype)
+                hq = torch.cat([hq, item[k].type(torch.float)], dim=0)
+        hq = self.rrc(hq.unsqueeze(0)).squeeze(0)
+        for i, k in enumerate(labels):
+            # Strip out any label values that are not whole numbers.
+            item[k] = hq[3+i:3+i+1, :, :]
+            whole = (item[k].round() == item[k])
+            item[k] = item[k] * whole
+            item[k] = item[k].type(dtypes[i])
+        item['lq'] = hq[:3, :, :]
+        item['hq'] = item['lq']
+        return item
+
+    def __len__(self):
+        return len(self.wrapped_dataset)
+
+
+def test_dataset_random_aug_wrapper():
+    opt = {
+        'dataset': {
+            'mode': 'imagefolder',
+            'name': 'amalgam',
+            'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
+            'weights': [1],
+            'target_size': 512,
+            'force_multiple': 1,
+            'scale': 1,
+            'fixed_corruptions': ['jpeg-broad'],
+            'random_corruptions': ['noise-5', 'none'],
+            'num_corrupts_per_image': 1,
+            'corrupt_before_downsize': False,
+            'labeler': {
+                'type': 'patch_labels',
+                'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories.json'
+            }
+        },
+        'crop_size': 512,
+        'includes_labels': True,
+    }
+    ds = DatasetRandomAugWrapper(opt)
+    import os
+    os.makedirs("debug", exist_ok=True)
+    for i in tqdm(range(0, len(ds))):
+        o = ds[random.randint(0, len(ds) - 1)]
+        for k, v in o.items():
+            # 'lq', 'hq', 'aug1', 'aug2',
+            if k in ['hq']:
+                torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
+                masked = v * (o['labels_mask'] * .5 + .5)
+                #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
+                # Pick a random (non-zero) label and spit it out with the textual label.
+                if len(o['labels'].unique()) > 1:
+                    randlbl = np.random.choice(o['labels'].unique()[1:])
+                    moremask = v * ((1 * (o['labels'] == randlbl)) * .5 + .5)
+                    torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s_%s.png" % (i, k, o['label_strings'][randlbl]))
+
+
 def no_batch_interpolate(i, size, mode):
     i = i.unsqueeze(0)
     i = F.interpolate(i, size=size, mode=mode)
@@ -235,7 +321,7 @@ class StructuredCropDatasetWrapper(Dataset):
 # For testing this dataset.
-if __name__ == '__main__':
+def test_structured_crop_dataset_wrapper():
     opt = {
         'dataset':
         {
@@ -270,3 +356,7 @@ if __name__ == '__main__':
         rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg.unsqueeze(0))
         #torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
         #torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
+
+
+if __name__ == '__main__':
+    test_dataset_random_aug_wrapper()
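Reviewer note: the key idea in DatasetRandomAugWrapper above is that the geometric augmentations must hit the image and its label maps identically, so the labels are concatenated as extra channels before the shared flip/crop and split back out afterwards; because the resized crop interpolates, any blended (non-integer) label values are zeroed. Below is a self-contained sketch of that idea in plain torch, with a hypothetical helper name and a fixed half-size crop box standing in for kornia's RandomResizedCrop.

import torch
import torch.nn.functional as F

def joint_random_crop(img, label, out_size):
    # img: (3,H,W) float image; label: (1,H,W) integer-valued class map.
    stacked = torch.cat([img, label.float()], dim=0)   # label rides along as a 4th channel
    _, h, w = stacked.shape
    ch, cw = h // 2, w // 2                            # fixed crop box; kornia also randomizes scale/ratio
    top = torch.randint(0, h - ch + 1, (1,)).item()
    left = torch.randint(0, w - cw + 1, (1,)).item()
    crop = stacked[:, top:top + ch, left:left + cw]
    crop = F.interpolate(crop.unsqueeze(0), size=(out_size, out_size),
                         mode='bilinear', align_corners=False).squeeze(0)
    img_out, lbl = crop[:3], crop[3:4]
    whole = (lbl.round() == lbl)                       # bilinear resizing blends class ids at
    lbl = (lbl * whole).long()                         # patch borders; strip those pixels to 0
    return img_out, lbl

The whole-number test mirrors the loop in __getitem__ above: any pixel interpolation touched falls back to background (0), so label edges erode slightly instead of acquiring bogus class ids.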

View File

@@ -111,10 +111,13 @@ class ImageFolderDataset:
         out_dict = {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
         if self.labeler:
             base_file = self.image_paths[item].replace(self.paths[0], "")
+            while base_file.startswith("\\"):
+                base_file = base_file[1:]
             assert dim % hq.shape[1] == 0
-            lbls, lbl_masks = self.labeler.get_labels_as_tensor(hq, base_file, dim // hq.shape[1])
+            lbls, lbl_masks, lblstrings = self.labeler.get_labels_as_tensor(hq, base_file, dim // hq.shape[1])
             out_dict['labels'] = lbls
             out_dict['labels_mask'] = lbl_masks
+            out_dict['label_strings'] = lblstrings
         return out_dict
 
 if __name__ == '__main__':
@@ -122,7 +125,7 @@ if __name__ == '__main__':
         'name': 'amalgam',
         'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\'],
         'weights': [1],
-        'target_size': 128,
+        'target_size': 512,
         'force_multiple': 32,
         'scale': 2,
         'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
@@ -144,4 +147,8 @@ if __name__ == '__main__':
         masked = (o['labels_mask'] * .5 + .5) * hq
         import torchvision
         torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_hq.png" % (i,))
-        torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,))
+        #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,))
+        if len(o['labels'].unique()) > 1:
+            randlbl = np.random.choice(o['labels'].unique()[1:])
+            moremask = hq * ((1 * (o['labels'] == randlbl)) * .5 + .5)
+            torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))
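Reviewer note: the new while loop above guards the labeler lookup. str.replace() leaves a leading separator behind whenever the configured path lacks a trailing slash, and the resulting key would not match the relative filenames keyed in the labeler's JSON. A small illustration, using a hypothetical filename under the dataset path from the test config:

base = 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised'   # no trailing backslash
path = base + '\\img001.jpg'                               # hypothetical image path
key = path.replace(base, '')                               # -> '\\img001.jpg'
while key.startswith('\\'):
    key = key[1:]                                          # -> 'img001.jpg'

Either way the key ends up in the same relative form, whether or not paths[0] was configured with a trailing separator.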

View File

@@ -33,7 +33,7 @@ class VsNetImageLabeler:
         # Build the label values, from [1,inf]
         label_value_dict = {}
         for i, l in enumerate(available_labels):
-            label_value_dict[l] = i+1
+            label_value_dict[l] = i
 
         # Insert "labelValue" for each entry.
         for v in labeled_images.values():
@@ -41,13 +41,15 @@ class VsNetImageLabeler:
             l['labelValue'] = label_value_dict[l['label']]
 
         self.labeled_images = labeled_images
+        self.str_labels = available_labels
 
     def get_labeled_paths(self, base_path):
         return [os.path.join(base_path, pth) for pth in self.labeled_images]
 
     def get_labels_as_tensor(self, hq, img_key, resize_factor):
-        labels = torch.zeros(hq.shape, dtype=torch.long)
-        mask = torch.zeros_like(hq)
+        _, h, w = hq.shape
+        labels = torch.zeros((1, h, w), dtype=torch.long)
+        mask = torch.zeros((1, h, w), dtype=torch.float)
         lbl_list = self.labeled_images[img_key]
         for patch_lbl in lbl_list:
             t, l, h, w = patch_lbl['patch_top'] // resize_factor, patch_lbl['patch_left'] // resize_factor, \
@@ -55,4 +57,4 @@ class VsNetImageLabeler:
             val = patch_lbl['labelValue']
             labels[:, t:t+h, l:l+w] = val
             mask[:, t:t+h, l:l+w] = 1.0
-        return labels, mask
+        return labels, mask, self.str_labels
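Reviewer note: after this change the labeler hands back three things per image: a (1,H,W) long tensor of class ids, a float mask marking which pixels actually carry a label, and the class-id-indexed label strings. A toy walkthrough of the rectangle fill; the 'patch_height'/'patch_width' field names are assumed by analogy with the 'patch_top'/'patch_left' keys visible above.

import torch

# Hypothetical 8x8 image with one labeled 4x4 patch at (top=2, left=2).
labels = torch.zeros((1, 8, 8), dtype=torch.long)
mask = torch.zeros((1, 8, 8), dtype=torch.float)
patch = {'patch_top': 2, 'patch_left': 2, 'patch_height': 4, 'patch_width': 4, 'labelValue': 3}

t, l = patch['patch_top'], patch['patch_left']
h, w = patch['patch_height'], patch['patch_width']
labels[:, t:t + h, l:l + w] = patch['labelValue']
mask[:, t:t + h, l:l + w] = 1.0

assert labels[0, 3, 3].item() == 3 and mask[0, 0, 0].item() == 0.0
# Downstream, label_strings[3] names the class painted into that rectangle,
# which is why the commit switched label values to be 0-indexed.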