Mods to support labeled datasets & random augs for those datasets
This commit is contained in:
parent
e5a3e6b9b5
commit
8e0e883050
|
@ -51,6 +51,8 @@ def create_dataset(dataset_opt):
|
||||||
from data.byol_attachment import ByolDatasetWrapper as D
|
from data.byol_attachment import ByolDatasetWrapper as D
|
||||||
elif mode == 'byol_structured_dataset':
|
elif mode == 'byol_structured_dataset':
|
||||||
from data.byol_attachment import StructuredCropDatasetWrapper as D
|
from data.byol_attachment import StructuredCropDatasetWrapper as D
|
||||||
|
elif mode == 'random_aug_wrapper':
|
||||||
|
from data.byol_attachment import DatasetRandomAugWrapper as D
|
||||||
elif mode == 'random_dataset':
|
elif mode == 'random_dataset':
|
||||||
from data.random_dataset import RandomDataset as D
|
from data.random_dataset import RandomDataset as D
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
import random
|
import random
|
||||||
from time import time
|
from time import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from kornia import augmentation as augs, kornia
|
from kornia import augmentation as augs, kornia, Resample
|
||||||
from kornia import filters
|
from kornia import filters
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -41,7 +42,8 @@ class ByolDatasetWrapper(Dataset):
|
||||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
|
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
|
||||||
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
|
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
|
||||||
if opt['normalize']:
|
if opt['normalize']:
|
||||||
# The paper calls for normalization. Recommend setting true if you want exactly like the paper.
|
# The paper calls for normalization. Most datasets/models in this repo don't use this.
|
||||||
|
# Recommend setting true if you want to train exactly like the paper.
|
||||||
augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))
|
augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))
|
||||||
self.aug = nn.Sequential(*augmentations)
|
self.aug = nn.Sequential(*augmentations)
|
||||||
|
|
||||||
|
@ -54,6 +56,90 @@ class ByolDatasetWrapper(Dataset):
|
||||||
return len(self.wrapped_dataset)
|
return len(self.wrapped_dataset)
|
||||||
|
|
||||||
|
|
||||||
|
# Basically the same as ByolDatasetWrapper except only produces 1 augmentation and stores in the 'lr' key. Also applies
|
||||||
|
# crop&resize to 2D tensors in the state dict with the word "label" in them.
|
||||||
|
class DatasetRandomAugWrapper(Dataset):
|
||||||
|
def __init__(self, opt):
|
||||||
|
super().__init__()
|
||||||
|
self.wrapped_dataset = create_dataset(opt['dataset'])
|
||||||
|
self.cropped_img_size = opt['crop_size']
|
||||||
|
self.includes_labels = opt['includes_labels']
|
||||||
|
augmentations = [ \
|
||||||
|
RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8),
|
||||||
|
augs.RandomGrayscale(p=0.2),
|
||||||
|
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
|
||||||
|
self.aug = nn.Sequential(*augmentations)
|
||||||
|
self.rrc = nn.Sequential(*[
|
||||||
|
augs.RandomHorizontalFlip(),
|
||||||
|
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))])
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
item = self.wrapped_dataset[item]
|
||||||
|
hq = self.aug(item['hq'].unsqueeze(0)).squeeze(0)
|
||||||
|
labels = []
|
||||||
|
dtypes = []
|
||||||
|
for k in item.keys():
|
||||||
|
if 'label' in k and isinstance(item[k], torch.Tensor) and len(item[k].shape) == 3:
|
||||||
|
assert item[k].shape[0] == 1 # Only supports a channel dim of 1.
|
||||||
|
labels.append(k)
|
||||||
|
dtypes.append(item[k].dtype)
|
||||||
|
hq = torch.cat([hq, item[k].type(torch.float)], dim=0)
|
||||||
|
hq = self.rrc(hq.unsqueeze(0)).squeeze(0)
|
||||||
|
for i, k in enumerate(labels):
|
||||||
|
# Strip out any label values that are not whole numbers.
|
||||||
|
item[k] = hq[3+i:3+i+1,:,:]
|
||||||
|
whole = (item[k].round() == item[k])
|
||||||
|
item[k] = item[k] * whole
|
||||||
|
item[k] = item[k].type(dtypes[i])
|
||||||
|
item['lq'] = hq[:3,:,:]
|
||||||
|
item['hq'] = item['lq']
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.wrapped_dataset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset_random_aug_wrapper():
|
||||||
|
opt = {
|
||||||
|
'dataset': {
|
||||||
|
'mode': 'imagefolder',
|
||||||
|
'name': 'amalgam',
|
||||||
|
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
|
||||||
|
'weights': [1],
|
||||||
|
'target_size': 512,
|
||||||
|
'force_multiple': 1,
|
||||||
|
'scale': 1,
|
||||||
|
'fixed_corruptions': ['jpeg-broad'],
|
||||||
|
'random_corruptions': ['noise-5', 'none'],
|
||||||
|
'num_corrupts_per_image': 1,
|
||||||
|
'corrupt_before_downsize': False,
|
||||||
|
'labeler': {
|
||||||
|
'type': 'patch_labels',
|
||||||
|
'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories.json'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'crop_size': 512,
|
||||||
|
'includes_labels': True,
|
||||||
|
}
|
||||||
|
|
||||||
|
ds = DatasetRandomAugWrapper(opt)
|
||||||
|
import os
|
||||||
|
os.makedirs("debug", exist_ok=True)
|
||||||
|
for i in tqdm(range(0, len(ds))):
|
||||||
|
o = ds[random.randint(0, len(ds)-1)]
|
||||||
|
for k, v in o.items():
|
||||||
|
# 'lq', 'hq', 'aug1', 'aug2',
|
||||||
|
if k in ['hq']:
|
||||||
|
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
|
||||||
|
masked = v * (o['labels_mask'] * .5 + .5)
|
||||||
|
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
|
||||||
|
# Pick a random (non-zero) label and spit it out with the textual label.
|
||||||
|
if len(o['labels'].unique()) > 1:
|
||||||
|
randlbl = np.random.choice(o['labels'].unique()[1:])
|
||||||
|
moremask = v * ((1*(o['labels'] == randlbl))*.5+.5)
|
||||||
|
torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s_%s.png" % (i, k, o['label_strings'][randlbl]))
|
||||||
|
|
||||||
|
|
||||||
def no_batch_interpolate(i, size, mode):
|
def no_batch_interpolate(i, size, mode):
|
||||||
i = i.unsqueeze(0)
|
i = i.unsqueeze(0)
|
||||||
i = F.interpolate(i, size=size, mode=mode)
|
i = F.interpolate(i, size=size, mode=mode)
|
||||||
|
@ -235,7 +321,7 @@ class StructuredCropDatasetWrapper(Dataset):
|
||||||
|
|
||||||
|
|
||||||
# For testing this dataset.
|
# For testing this dataset.
|
||||||
if __name__ == '__main__':
|
def test_structured_crop_dataset_wrapper():
|
||||||
opt = {
|
opt = {
|
||||||
'dataset':
|
'dataset':
|
||||||
{
|
{
|
||||||
|
@ -270,3 +356,7 @@ if __name__ == '__main__':
|
||||||
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg.unsqueeze(0))
|
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg.unsqueeze(0))
|
||||||
#torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
|
#torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
|
||||||
#torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
|
#torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_dataset_random_aug_wrapper()
|
||||||
|
|
|
@ -111,10 +111,13 @@ class ImageFolderDataset:
|
||||||
out_dict = {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
|
out_dict = {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
|
||||||
if self.labeler:
|
if self.labeler:
|
||||||
base_file = self.image_paths[item].replace(self.paths[0], "")
|
base_file = self.image_paths[item].replace(self.paths[0], "")
|
||||||
|
while base_file.startswith("\\"):
|
||||||
|
base_file = base_file[1:]
|
||||||
assert dim % hq.shape[1] == 0
|
assert dim % hq.shape[1] == 0
|
||||||
lbls, lbl_masks = self.labeler.get_labels_as_tensor(hq, base_file, dim // hq.shape[1])
|
lbls, lbl_masks, lblstrings = self.labeler.get_labels_as_tensor(hq, base_file, dim // hq.shape[1])
|
||||||
out_dict['labels'] = lbls
|
out_dict['labels'] = lbls
|
||||||
out_dict['labels_mask'] = lbl_masks
|
out_dict['labels_mask'] = lbl_masks
|
||||||
|
out_dict['label_strings'] = lblstrings
|
||||||
return out_dict
|
return out_dict
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -122,7 +125,7 @@ if __name__ == '__main__':
|
||||||
'name': 'amalgam',
|
'name': 'amalgam',
|
||||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\'],
|
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\'],
|
||||||
'weights': [1],
|
'weights': [1],
|
||||||
'target_size': 128,
|
'target_size': 512,
|
||||||
'force_multiple': 32,
|
'force_multiple': 32,
|
||||||
'scale': 2,
|
'scale': 2,
|
||||||
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
|
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
|
||||||
|
@ -144,4 +147,8 @@ if __name__ == '__main__':
|
||||||
masked = (o['labels_mask'] * .5 + .5) * hq
|
masked = (o['labels_mask'] * .5 + .5) * hq
|
||||||
import torchvision
|
import torchvision
|
||||||
torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_hq.png" % (i,))
|
torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_hq.png" % (i,))
|
||||||
torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,))
|
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,))
|
||||||
|
if len(o['labels'].unique()) > 1:
|
||||||
|
randlbl = np.random.choice(o['labels'].unique()[1:])
|
||||||
|
moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
|
||||||
|
torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))
|
|
@ -33,7 +33,7 @@ class VsNetImageLabeler:
|
||||||
# Build the label values, from [1,inf]
|
# Build the label values, from [1,inf]
|
||||||
label_value_dict = {}
|
label_value_dict = {}
|
||||||
for i, l in enumerate(available_labels):
|
for i, l in enumerate(available_labels):
|
||||||
label_value_dict[l] = i+1
|
label_value_dict[l] = i
|
||||||
|
|
||||||
# Insert "labelValue" for each entry.
|
# Insert "labelValue" for each entry.
|
||||||
for v in labeled_images.values():
|
for v in labeled_images.values():
|
||||||
|
@ -41,13 +41,15 @@ class VsNetImageLabeler:
|
||||||
l['labelValue'] = label_value_dict[l['label']]
|
l['labelValue'] = label_value_dict[l['label']]
|
||||||
|
|
||||||
self.labeled_images = labeled_images
|
self.labeled_images = labeled_images
|
||||||
|
self.str_labels = available_labels
|
||||||
|
|
||||||
def get_labeled_paths(self, base_path):
|
def get_labeled_paths(self, base_path):
|
||||||
return [os.path.join(base_path, pth) for pth in self.labeled_images]
|
return [os.path.join(base_path, pth) for pth in self.labeled_images]
|
||||||
|
|
||||||
def get_labels_as_tensor(self, hq, img_key, resize_factor):
|
def get_labels_as_tensor(self, hq, img_key, resize_factor):
|
||||||
labels = torch.zeros(hq.shape, dtype=torch.long)
|
_, h, w = hq.shape
|
||||||
mask = torch.zeros_like(hq)
|
labels = torch.zeros((1,h,w), dtype=torch.long)
|
||||||
|
mask = torch.zeros((1,h,w), dtype=torch.float)
|
||||||
lbl_list = self.labeled_images[img_key]
|
lbl_list = self.labeled_images[img_key]
|
||||||
for patch_lbl in lbl_list:
|
for patch_lbl in lbl_list:
|
||||||
t, l, h, w = patch_lbl['patch_top'] // resize_factor, patch_lbl['patch_left'] // resize_factor, \
|
t, l, h, w = patch_lbl['patch_top'] // resize_factor, patch_lbl['patch_left'] // resize_factor, \
|
||||||
|
@ -55,4 +57,4 @@ class VsNetImageLabeler:
|
||||||
val = patch_lbl['labelValue']
|
val = patch_lbl['labelValue']
|
||||||
labels[:,t:t+h,l:l+w] = val
|
labels[:,t:t+h,l:l+w] = val
|
||||||
mask[:,t:t+h,l:l+w] = 1.0
|
mask[:,t:t+h,l:l+w] = 1.0
|
||||||
return labels, mask
|
return labels, mask, self.str_labels
|
Loading…
Reference in New Issue
Block a user