diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index ea883a8a..5d33a206 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -51,6 +51,8 @@ def create_dataset(dataset_opt):
         from data.byol_attachment import ByolDatasetWrapper as D
     elif mode == 'byol_structured_dataset':
         from data.byol_attachment import StructuredCropDatasetWrapper as D
+    elif mode == 'random_aug_wrapper':
+        from data.byol_attachment import DatasetRandomAugWrapper as D
     elif mode == 'random_dataset':
         from data.random_dataset import RandomDataset as D
     else:
diff --git a/codes/data/byol_attachment.py b/codes/data/byol_attachment.py
index bf1dc2a6..ace0b2c1 100644
--- a/codes/data/byol_attachment.py
+++ b/codes/data/byol_attachment.py
@@ -1,10 +1,11 @@
 import random
 from time import time
 
+import numpy as np
 import torch
 import torchvision
 from torch.utils.data import Dataset
-from kornia import augmentation as augs, kornia
+from kornia import augmentation as augs, kornia, Resample
 from kornia import filters
 import torch.nn as nn
 import torch.nn.functional as F
@@ -41,7 +42,8 @@ class ByolDatasetWrapper(Dataset):
                          RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
                          augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
         if opt['normalize']:
-            # The paper calls for normalization. Recommend setting true if you want exactly like the paper.
+            # The paper calls for normalization, but most datasets/models in this repo don't use it.
+            # Recommend setting this to true if you want to train exactly like the paper.
             augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))
         self.aug = nn.Sequential(*augmentations)
 
@@ -54,6 +56,90 @@ class ByolDatasetWrapper(Dataset):
         return len(self.wrapped_dataset)
 
 
+# Basically the same as ByolDatasetWrapper, except it produces only a single augmentation, stored in the 'lq' key.
+# Also applies the crop & resize to any (1,H,W) tensors in the output dict whose key contains the word "label".
+class DatasetRandomAugWrapper(Dataset):
+    def __init__(self, opt):
+        super().__init__()
+        self.wrapped_dataset = create_dataset(opt['dataset'])
+        self.cropped_img_size = opt['crop_size']
+        self.includes_labels = opt['includes_labels']
+        augmentations = [
+            RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8),
+            augs.RandomGrayscale(p=0.2),
+            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
+        self.aug = nn.Sequential(*augmentations)
+        self.rrc = nn.Sequential(*[
+            augs.RandomHorizontalFlip(),
+            augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))])
+
+    def __getitem__(self, item):
+        item = self.wrapped_dataset[item]
+        hq = self.aug(item['hq'].unsqueeze(0)).squeeze(0)
+        labels = []
+        dtypes = []
+        for k in item.keys():
+            if 'label' in k and isinstance(item[k], torch.Tensor) and len(item[k].shape) == 3:
+                assert item[k].shape[0] == 1  # Only supports a channel dim of 1.
+                labels.append(k)
+                dtypes.append(item[k].dtype)
+                hq = torch.cat([hq, item[k].type(torch.float)], dim=0)
+        hq = self.rrc(hq.unsqueeze(0)).squeeze(0)
+        for i, k in enumerate(labels):
+            # Strip out any label values that are not whole numbers.
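+            # (Only self.rrc touches the label channels, and within it only RandomResizedCrop resamples,
+            # bilinearly by default; non-integer label values are blending artifacts at crop borders.)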
+            item[k] = hq[3+i:3+i+1,:,:]
+            whole = (item[k].round() == item[k])
+            item[k] = item[k] * whole
+            item[k] = item[k].type(dtypes[i])
+        item['lq'] = hq[:3,:,:]
+        item['hq'] = item['lq']
+        return item
+
+    def __len__(self):
+        return len(self.wrapped_dataset)
+
+
+def test_dataset_random_aug_wrapper():
+    opt = {
+        'dataset': {
+            'mode': 'imagefolder',
+            'name': 'amalgam',
+            'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
+            'weights': [1],
+            'target_size': 512,
+            'force_multiple': 1,
+            'scale': 1,
+            'fixed_corruptions': ['jpeg-broad'],
+            'random_corruptions': ['noise-5', 'none'],
+            'num_corrupts_per_image': 1,
+            'corrupt_before_downsize': False,
+            'labeler': {
+                'type': 'patch_labels',
+                'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories.json'
+            }
+        },
+        'crop_size': 512,
+        'includes_labels': True,
+    }
+
+    ds = DatasetRandomAugWrapper(opt)
+    import os
+    os.makedirs("debug", exist_ok=True)
+    for i in tqdm(range(0, len(ds))):
+        o = ds[random.randint(0, len(ds)-1)]
+        for k, v in o.items():
+            # 'lq', 'hq', 'aug1', 'aug2',
+            if k in ['hq']:
+                torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
+                masked = v * (o['labels_mask'] * .5 + .5)
+                #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
+                # Pick a random (non-zero) label and spit it out with the textual label.
+                if len(o['labels'].unique()) > 1:
+                    randlbl = np.random.choice(o['labels'].unique()[1:])
+                    moremask = v * ((1*(o['labels'] == randlbl))*.5+.5)
+                    torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s_%s.png" % (i, k, o['label_strings'][randlbl]))
+
+
 def no_batch_interpolate(i, size, mode):
     i = i.unsqueeze(0)
     i = F.interpolate(i, size=size, mode=mode)
@@ -235,7 +321,7 @@ class StructuredCropDatasetWrapper(Dataset):
 
 
 # For testing this dataset.
-if __name__ == '__main__':
+def test_structured_crop_dataset_wrapper():
 
     opt = {
         'dataset': {
@@ -270,3 +356,7 @@ if __name__ == '__main__':
         rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg.unsqueeze(0))
         #torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
         #torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
+
+
+if __name__ == '__main__':
+    test_dataset_random_aug_wrapper()
diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py
index 1c21c920..4279842d 100644
--- a/codes/data/image_folder_dataset.py
+++ b/codes/data/image_folder_dataset.py
@@ -111,10 +111,13 @@ class ImageFolderDataset:
         out_dict = {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
         if self.labeler:
             base_file = self.image_paths[item].replace(self.paths[0], "")
+            while base_file.startswith("\\"):
+                base_file = base_file[1:]
             assert dim % hq.shape[1] == 0
-            lbls, lbl_masks = self.labeler.get_labels_as_tensor(hq, base_file, dim // hq.shape[1])
+            lbls, lbl_masks, lblstrings = self.labeler.get_labels_as_tensor(hq, base_file, dim // hq.shape[1])
             out_dict['labels'] = lbls
             out_dict['labels_mask'] = lbl_masks
+            out_dict['label_strings'] = lblstrings
         return out_dict
 
 if __name__ == '__main__':
@@ -122,7 +125,7 @@
         'name': 'amalgam',
         'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\'],
         'weights': [1],
-        'target_size': 128,
+        'target_size': 512,
         'force_multiple': 32,
         'scale': 2,
         'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
@@ -144,4 +147,8 @@
         masked = (o['labels_mask'] * .5 + .5) * hq
         import torchvision
         torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_hq.png" % (i,))
-        torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,))
\ No newline at end of file
+        #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,))
+        if len(o['labels'].unique()) > 1:
+            randlbl = np.random.choice(o['labels'].unique()[1:])
+            moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
+            torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))
\ No newline at end of file
diff --git a/codes/data/image_label_parser.py b/codes/data/image_label_parser.py
index 241746c6..f8d55c84 100644
--- a/codes/data/image_label_parser.py
+++ b/codes/data/image_label_parser.py
@@ -33,7 +33,7 @@ class VsNetImageLabeler:
-        # Build the label values, from [1,inf]
+        # Build the label values, from [0,inf)
        label_value_dict = {}
        for i, l in enumerate(available_labels):
-            label_value_dict[l] = i+1
+            label_value_dict[l] = i
 
        # Insert "labelValue" for each entry.
        for v in labeled_images.values():
@@ -41,13 +41,15 @@ class VsNetImageLabeler:
                 l['labelValue'] = label_value_dict[l['label']]
 
         self.labeled_images = labeled_images
+        self.str_labels = available_labels
 
     def get_labeled_paths(self, base_path):
         return [os.path.join(base_path, pth) for pth in self.labeled_images]
 
     def get_labels_as_tensor(self, hq, img_key, resize_factor):
-        labels = torch.zeros(hq.shape, dtype=torch.long)
-        mask = torch.zeros_like(hq)
+        _, h, w = hq.shape
+        labels = torch.zeros((1,h,w), dtype=torch.long)
+        mask = torch.zeros((1,h,w), dtype=torch.float)
         lbl_list = self.labeled_images[img_key]
         for patch_lbl in lbl_list:
             t, l, h, w = patch_lbl['patch_top'] // resize_factor, patch_lbl['patch_left'] // resize_factor, \
@@ -55,4 +57,4 @@
             val = patch_lbl['labelValue']
             labels[:,t:t+h,l:l+w] = val
             mask[:,t:t+h,l:l+w] = 1.0
-        return labels, mask
\ No newline at end of file
+        return labels, mask, self.str_labels
\ No newline at end of file
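
Implementation note: DatasetRandomAugWrapper handles labels by concatenating each (1,H,W) label map onto the RGB image as an extra channel, so RandomHorizontalFlip and RandomResizedCrop apply identical geometry to image and labels; it then masks out label pixels that bilinear resampling turned into non-integers. A minimal standalone sketch of that idea (the tensors and sizes below are toy values for illustration; only the kornia usage mirrors the patch):

    import torch
    import torch.nn as nn
    from kornia import augmentation as augs

    # Labels ride along as a 4th channel so the random flip + crop transform
    # the image and the label map identically.
    rrc = nn.Sequential(augs.RandomHorizontalFlip(),
                        augs.RandomResizedCrop((64, 64)))

    img = torch.rand(3, 128, 128)                     # toy image
    lbl = torch.randint(0, 5, (1, 128, 128)).float()  # toy patch-label map

    out = rrc(torch.cat([img, lbl], dim=0).unsqueeze(0)).squeeze(0)
    aug_img, aug_lbl = out[:3], out[3:4]

    # Bilinear resampling blends label IDs at region borders; keep only values
    # that survived as whole numbers and zero out the rest, as the patch does.
    aug_lbl = (aug_lbl * (aug_lbl.round() == aug_lbl)).long()

An alternative would be to resample the label channel with nearest-neighbor interpolation (the patch imports kornia's Resample, which supports this), but that requires a second transform pass with shared parameters; filtering non-integer values after a single joint pass is the simpler route taken here.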