From 74bb0fad332ae46f3ef7a7617a158883a788b45c Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 23 May 2020 21:04:24 -0600 Subject: [PATCH] Allow dataset classes to add noise internally --- codes/data/Downsample_dataset.py | 38 ++++++++++++++++------------- codes/data/LQGT_dataset.py | 41 ++++++++++++++++++++++---------- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/codes/data/Downsample_dataset.py b/codes/data/Downsample_dataset.py index 5e950d05..e6861ef5 100644 --- a/codes/data/Downsample_dataset.py +++ b/codes/data/Downsample_dataset.py @@ -5,6 +5,9 @@ import lmdb import torch import torch.utils.data as data import data.util as util +from PIL import Image +from io import BytesIO +import torchvision.transforms.functional as F class DownsampleDataset(data.Dataset): @@ -64,10 +67,6 @@ class DownsampleDataset(data.Dataset): ] if self.data_type == 'lmdb' else None img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) - # Create a downsampled version of the HQ image using matlab imresize. - img_Downsampled = util.imresize_np(img_GT, 1 / scale) - assert img_Downsampled.ndim == 3 - if self.opt['phase'] == 'train': H, W, _ = img_GT.shape assert H >= GT_size and W >= GT_size @@ -80,29 +79,36 @@ class DownsampleDataset(data.Dataset): rnd_h = random.randint(0, max(0, H - LQ_size)) rnd_w = random.randint(0, max(0, W - LQ_size)) img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] - img_Downsampled = img_Downsampled[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] else: img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) - img_Downsampled = cv2.resize(img_Downsampled, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) # augmentation - flip, rotate - img_LQ, img_GT, img_Downsampled = util.augment([img_LQ, img_GT, img_Downsampled], self.opt['use_flip'], + img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], self.opt['use_rot']) - if self.opt['color']: # change color space if necessary - img_Downsampled = util.channel_convert(C, self.opt['color'], - [img_Downsampled])[0] # TODO during val no definition - # BGR to RGB, HWC to CHW, numpy to tensor if img_GT.shape[2] == 3: - img_GT = img_GT[:, :, [2, 1, 0]] - img_LQ = img_LQ[:, :, [2, 1, 0]] - img_Downsampled = img_Downsampled[:, :, [2, 1, 0]] - img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() - img_Downsampled = torch.from_numpy(np.ascontiguousarray(np.transpose(img_Downsampled, (2, 0, 1)))).float() + img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB) + img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB) + + # HQ needs to go to a PIL image to perform the compression-artifact transformation. + H, W, _ = img_GT.shape + img_GT = (img_GT * 255).astype(np.uint8) + img_GT = Image.fromarray(img_GT) + if self.opt['use_compression_artifacts']: + qf = random.randrange(15, 100) + corruption_buffer = BytesIO() + img_GT.save(corruption_buffer, "JPEG", quality=qf, optimice=True) + corruption_buffer.seek(0) + img_GT = Image.open(corruption_buffer) + # Generate a downsampled image from HQ for feature and PIX losses. + img_Downsampled = F.resize(img_GT, (int(H / scale), int(W / scale))) + + img_GT = F.to_tensor(img_GT) + img_Downsampled = F.to_tensor(img_Downsampled) img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() # This may seem really messed up, but let me explain: diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index 02bfdd74..2f9820ef 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -5,6 +5,9 @@ import lmdb import torch import torch.utils.data as data import data.util as util +from PIL import Image +from io import BytesIO +import torchvision.transforms.functional as F class LQGTDataset(data.Dataset): @@ -27,16 +30,17 @@ class LQGTDataset(data.Dataset): self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) - self.paths_LQ = [] - if isinstance(opt['dataroot_LQ'], list): - # Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and - # we want the model to learn them all. - for dr_lq in opt['dataroot_LQ']: - lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq) + if 'dataroot_LQ' in opt.keys(): + self.paths_LQ = [] + if isinstance(opt['dataroot_LQ'], list): + # Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and + # we want the model to learn them all. + for dr_lq in opt['dataroot_LQ']: + lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq) + self.paths_LQ.append(lq_path) + else: + lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) self.paths_LQ.append(lq_path) - else: - lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) - self.paths_LQ.append(lq_path) self.doCrop = opt['doCrop'] if 'dataroot_PIX' in opt.keys(): self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX']) @@ -144,12 +148,23 @@ class LQGTDataset(data.Dataset): # BGR to RGB, HWC to CHW, numpy to tensor if img_GT.shape[2] == 3: - img_GT = img_GT[:, :, [2, 1, 0]] - img_LQ = img_LQ[:, :, [2, 1, 0]] - img_PIX = img_PIX[:, :, [2, 1, 0]] + img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB) + img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB) + img_PIX = cv2.cvtColor(img_PIX, cv2.COLOR_BGR2RGB) + + # LQ needs to go to a PIL image to perform the compression-artifact transformation. + img_LQ = (img_LQ * 255).astype(np.uint8) + img_LQ = Image.fromarray(img_LQ) + if self.opt['use_compression_artifacts']: + qf = random.randrange(15, 100) + corruption_buffer = BytesIO() + img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True) + corruption_buffer.seek(0) + img_LQ = Image.open(corruption_buffer) + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float() - img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() + img_LQ = F.to_tensor(img_LQ) if LQ_path is None: LQ_path = GT_path