46f550e42b
I'm preprocessing the images myself now. There's no need to have the dataset do this processing as well.
120 lines
5.9 KiB
Python
120 lines
5.9 KiB
Python
import random
|
|
import numpy as np
|
|
import cv2
|
|
import lmdb
|
|
import torch
|
|
import torch.utils.data as data
|
|
import data.util as util
|
|
|
|
|
|
class DownsampleDataset(data.Dataset):
    """
    Reads an unpaired HQ and LQ image. Clips both images to the expected input sizes of the model. Produces a
    downsampled LQ image from the HQ image and feeds that as well.

    The 'LQ' and 'GT' keys of each sample are deliberately swapped relative to their names so that the
    super-resolution training code can be reused for downsampling (see the comment at the end of
    ``__getitem__``).
    """

    def __init__(self, opt):
        """
        Args:
            opt: dataset options. Keys read here: 'data_type', 'doCrop', 'dataroot_GT', 'dataroot_LQ' and
                'mismatched_Data_OK'. ``__getitem__`` additionally reads 'scale', 'target_size', 'phase',
                'color', 'use_flip' and 'use_rot'.
        """
        super(DownsampleDataset, self).__init__()
        self.opt = opt
        self.data_type = self.opt['data_type']
        self.LQ_env, self.GT_env = None, None  # environments for lmdb; opened on the first __getitem__ call
        self.doCrop = self.opt['doCrop']

        self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
        self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])

        # When set, the GT and LQ sets may differ in length; __getitem__ then wraps the LQ index.
        self.data_sz_mismatch_ok = opt['mismatched_Data_OK']
        assert self.paths_GT, 'Error: GT path is empty.'
        assert self.paths_LQ, 'LQ is required for downsampling.'
        if not self.data_sz_mismatch_ok:
            assert len(self.paths_LQ) == len(
                self.paths_GT
            ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
                len(self.paths_GT), len(self.paths_LQ))
        # NOTE(review): not read anywhere in this class; kept in case external code inspects it.
        self.random_scale_list = [1]

    def _init_lmdb(self):
        """Open read-only lmdb environments for the GT and LQ data roots.

        Called from ``__getitem__`` rather than ``__init__``; see
        https://github.com/chainer/chainermn/issues/129
        """
        self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
                                meminit=False)
        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
                                meminit=False)

    @staticmethod
    def _max_crop_offset(lq_len, gt_len, lq_size, scale):
        """Return the largest valid random-crop offset (in LQ pixels) along one axis.

        The offset must keep an ``lq_size`` window inside the LQ image (length ``lq_len``). Once multiplied by
        ``scale``, it must also keep the ``lq_size * scale`` GT window inside the GT image (length ``gt_len``).
        ``gt_len // scale`` is also used as the extent of the downsampled image. This assumes
        ``util.imresize_np`` never produces fewer than ``gt_len // scale`` pixels; TODO confirm.
        Returns 0 when either image is smaller than the window.
        """
        return int(max(0, min(lq_len, gt_len // scale) - lq_size))

    @staticmethod
    def _resize_square(img, size):
        """Bilinearly resize an HWC image to ``size`` x ``size``, keeping the channel axis.

        ``cv2.resize`` returns a 2-D array for single-channel input. The channel axis is restored here because
        the rest of ``__getitem__`` indexes images as ``[:, :, c]``.
        """
        out = cv2.resize(img, (size, size), interpolation=cv2.INTER_LINEAR)
        if out.ndim == 2:
            out = np.expand_dims(out, axis=2)
        return out

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            dict with
              'LQ': the HQ image read from ``dataroot_GT``; this is the model *input*.
              'GT': the LQ image read from ``dataroot_LQ``; this is the model *target*.
              'PIX': the HQ image downsampled by ``scale``; this is the reference for the pixel loss.
              'LQ_path', 'GT_path': source paths of the LQ and HQ images.
            Images are float tensors in CHW layout. 3-channel images are reordered from BGR to RGB.
        """
        if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
            self._init_lmdb()
        scale = self.opt['scale']
        GT_size = self.opt['target_size'] * scale

        # get GT image
        GT_path = self.paths_GT[index]
        resolution = [int(s) for s in self.sizes_GT[index].split('_')
                      ] if self.data_type == 'lmdb' else None
        img_GT = util.read_img(self.GT_env, GT_path, resolution)
        if self.opt['phase'] != 'train':  # modcrop in the validation / test phase
            img_GT = util.modcrop(img_GT, scale)
        if self.opt['color']:  # change color space if necessary
            img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]

        # Get LQ image. With mismatched_Data_OK the LQ set can be shorter than the GT set. The wrapped index
        # must therefore be used for the lmdb size lookup as well as for the path.
        lq_index = index % len(self.paths_LQ)
        LQ_path = self.paths_LQ[lq_index]
        resolution = [int(s) for s in self.sizes_LQ[lq_index].split('_')
                      ] if self.data_type == 'lmdb' else None
        img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)

        # Create a downsampled version of the HQ image using matlab imresize. It is derived from img_GT after
        # the colour conversion above, so it must not be colour-converted again.
        img_Downsampled = util.imresize_np(img_GT, 1 / scale)
        assert img_Downsampled.ndim == 3

        if self.opt['phase'] == 'train':
            H_GT, W_GT, _ = img_GT.shape
            assert H_GT >= GT_size and W_GT >= GT_size
            H_LQ, W_LQ, _ = img_LQ.shape
            LQ_size = GT_size // scale

            if self.doCrop:
                # Randomly crop. The same offset (in LQ pixels) is used for img_LQ and img_Downsampled, and is
                # scaled up for img_GT. LQ and GT are unpaired and may differ in size, so the offset is bounded
                # by both extents. Bounding by the LQ size alone lets the Downsampled/GT slices run past the
                # image edge and come back smaller than the requested crop.
                rnd_h = random.randint(0, self._max_crop_offset(H_LQ, H_GT, LQ_size, scale))
                rnd_w = random.randint(0, self._max_crop_offset(W_LQ, W_GT, LQ_size, scale))
                img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
                img_Downsampled = img_Downsampled[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
                rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
                img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
            else:
                img_LQ = self._resize_square(img_LQ, LQ_size)
                img_Downsampled = self._resize_square(img_Downsampled, LQ_size)
                img_GT = self._resize_square(img_GT, GT_size)

            # augmentation - flip, rotate
            img_LQ, img_GT, img_Downsampled = util.augment([img_LQ, img_GT, img_Downsampled],
                                                           self.opt['use_flip'], self.opt['use_rot'])

        if self.opt['color']:  # change color space if necessary
            # img_GT, and therefore img_Downsampled, was converted above. The LQ image is read separately and
            # still needs converting. Its own channel count is used, so this also works outside the train phase.
            img_LQ = util.channel_convert(img_LQ.shape[2], self.opt['color'], [img_LQ])[0]

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
            img_GT = img_GT[:, :, [2, 1, 0]]
            img_LQ = img_LQ[:, :, [2, 1, 0]]
            img_Downsampled = img_Downsampled[:, :, [2, 1, 0]]
        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        img_Downsampled = torch.from_numpy(np.ascontiguousarray(np.transpose(img_Downsampled, (2, 0, 1)))).float()
        img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()

        # This may seem really messed up, but let me explain:
        # The goal is to re-use existing code as much as possible. SRGAN_model was coded to supersample, not
        # downsample, but it can be retrofitted. To do so, we need to "trick" it. In this case the "input" is the
        # HQ image and the "output" is the LQ image. SRGAN_model will be using a Generator and a Discriminator
        # which already know this; we just need to trick its logic into following these rules.
        # Do this by setting LQ (which is the input into the models) = img_GT and GT (which is the expected
        # output) = img_LQ.
        # PIX is used as a reference for the pixel loss. Use the manually downsampled image for this.
        return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}

    def __len__(self):
        """Number of samples; one per GT (HQ) image."""
        return len(self.paths_GT)
|