diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index 1e7b65b9..753ae987 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -43,6 +43,8 @@ def create_dataset(dataset_opt):
         from data.paired_frame_dataset import PairedFrameDataset as D
     elif mode == 'stylegan2':
         from data.stylegan2_dataset import Stylegan2Dataset as D
+    elif mode == 'imagefolder':
+        from data.image_folder_dataset import ImageFolderDataset as D
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
     dataset = D(dataset_opt)
diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py
new file mode 100644
index 00000000..c2ae96ac
--- /dev/null
+++ b/codes/data/image_folder_dataset.py
@@ -0,0 +1,126 @@
+import glob
+import os
+import random
+
+import cv2
+import numpy as np
+import torch
+
+from data import util
+from data.image_corruptor import ImageCorruptor
+
+
+# Builds a dataset from a simple folder containing a list of training/test/validation images.
+class ImageFolderDataset:
+    def __init__(self, opt):
+        self.opt = opt
+        self.corruptor = ImageCorruptor(opt)
+        self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
+        self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
+        self.scale = opt['scale']
+        self.paths = opt['paths']
+        self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False
+        # Fail fast on a misconfigured target size; violating this constraint produces some really obscure errors downstream.
+        if self.target_hq_size is not None:
+            assert (self.target_hq_size // self.scale) % self.multiple == 0
+        if not isinstance(self.paths, list):
+            self.paths = [self.paths]
+            self.weights = [1]
+        else:
+            self.weights = opt['weights']
+
+        # Scan the given directories for images of standard types.
+        supported_types = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG', 'gif', 'GIF']
+        self.image_paths = []
+        for path, weight in zip(self.paths, self.weights):
+            cache_path = os.path.join(path, 'cache.pth')
+            if os.path.exists(cache_path):
+                imgs = torch.load(cache_path)
+            else:
+                print("Building image folder cache, this can take some time for large datasets..")
+                imgs = []
+                for ext in supported_types:
+                    imgs.extend(glob.glob(os.path.join(path, "*." + ext)))
+                torch.save(imgs, cache_path)
+            # Repeat the image list <weight> times to oversample this folder relative to the others.
+            for w in range(weight):
+                self.image_paths.extend(imgs)
+        self.len = len(self.image_paths)
+
+    def get_paths(self):
+        return self.image_paths
+
+    # Given an HQ square of arbitrary size, resizes it to specifications from opt.
+    def resize_hq(self, imgs_hq):
+        # Enforce size constraints.
+        h, w, _ = imgs_hq[0].shape
+        if self.target_hq_size is not None and self.target_hq_size != h:
+            hqs_adjusted = []
+            for hq in imgs_hq:
+                # It is assumed that the target size is a square.
+                target_size = (self.target_hq_size, self.target_hq_size)
+                hqs_adjusted.append(cv2.resize(hq, target_size, interpolation=cv2.INTER_AREA))
+            h, w = self.target_hq_size, self.target_hq_size
+        else:
+            hqs_adjusted = imgs_hq
+        hq_multiple = self.multiple * self.scale  # The size multiple must apply to the LQ image, hence the scale factor.
+        if h % hq_multiple != 0 or w % hq_multiple != 0:
+            hqs_conformed = []
+            for hq in hqs_adjusted:
+                h, w = (h - h % hq_multiple), (w - w % hq_multiple)
+                hqs_conformed.append(hq[:h, :w, :])
+            return hqs_conformed
+        return hqs_adjusted
+
+    def synthesize_lq(self, hs):
+        h, w, _ = hs[0].shape
+        ls = []
+        if self.corrupt_before_downsize:
+            hs = self.corruptor.corrupt_images(hs)
+        for hq in hs:
+            # cv2.resize takes dsize as (width, height).
+            ls.append(cv2.resize(hq, (w // self.scale, h // self.scale), interpolation=cv2.INTER_AREA))
+        # When not corrupting before the downsize, corrupt the LQ images after it.
+        if not self.corrupt_before_downsize:
+            ls = self.corruptor.corrupt_images(ls)
+        return ls
+
+    def __len__(self):
+        return self.len
+
+    def __getitem__(self, item):
+        hq = util.read_img(None, self.image_paths[item], rgb=True)
+
+        hs = self.resize_hq([hq])
+        ls = self.synthesize_lq(hs)
+
+        # Convert HWC numpy images to CHW torch tensors.
+        hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
+        lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
+
+        return {'LQ': lq, 'GT': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}
+
+
+if __name__ == '__main__':
+    import torchvision
+
+    opt = {
+        'name': 'amalgam',
+        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\random_100_1024px'],
+        'weights': [1],
+        'target_size': 128,
+        'force_multiple': 32,
+        'scale': 2,
+        'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
+        'random_corruptions': ['noise-5', 'none'],
+        'num_corrupts_per_image': 1,
+        'corrupt_before_downsize': True,
+    }
+
+    ds = ImageFolderDataset(opt)
+    os.makedirs("debug", exist_ok=True)
+    for i in range(len(ds)):
+        # random.randint is inclusive on both ends, so cap at len(ds) - 1.
+        o = ds[random.randint(0, len(ds) - 1)]
+        lq = o['LQ']
+        torchvision.utils.save_image(lq.unsqueeze(0), "debug/%i_LQ.png" % i)
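As a usage note, here is a minimal sketch of driving the new 'imagefolder' mode through create_dataset() with a standard PyTorch DataLoader. The option dict mirrors the one exercised in image_folder_dataset.py's __main__ block; the dataset path and batch size are placeholders, and the DataLoader wiring is an illustration rather than something this diff adds.

    from torch.utils.data import DataLoader

    from data import create_dataset

    # Mirrors the options used in image_folder_dataset.py's __main__ block.
    dataset_opt = {
        'mode': 'imagefolder',         # routes to ImageFolderDataset in create_dataset()
        'name': 'amalgam',
        'paths': ['/path/to/images'],  # placeholder; any folder of jpg/png/gif files
        'weights': [1],
        'target_size': 128,
        'force_multiple': 32,
        'scale': 2,
        'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
        'random_corruptions': ['noise-5', 'none'],
        'num_corrupts_per_image': 1,
        'corrupt_before_downsize': True,
    }

    ds = create_dataset(dataset_opt)
    # ImageFolderDataset exposes __len__/__getitem__, so it works as a map-style dataset.
    loader = DataLoader(ds, batch_size=16, shuffle=True)
    for batch in loader:
        lq, hq = batch['LQ'], batch['GT']  # (B, 3, 64, 64) and (B, 3, 128, 128) for scale=2, target_size=128
        break

Since create_dataset() lives in codes/data/__init__.py, this assumes the codes/ directory is on the import path, matching the repo's existing 'data.*' imports.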