forked from mrq/DL-Art-School
125 lines
5.8 KiB
Python
125 lines
5.8 KiB
Python
import random
|
|
import numpy as np
|
|
import cv2
|
|
import lmdb
|
|
import torch
|
|
import torch.utils.data as data
|
|
import data.util as util
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
import torchvision.transforms.functional as F
|
|
|
|
|
|
class DownsampleDataset(data.Dataset):
    """
    Reads an unpaired HQ and LQ image. Crops (or resizes) both images to the expected input sizes of the model.
    Also produces a downsampled LQ image from the HQ image and feeds that as well.

    Options read from ``opt``: data_type, dataroot_GT, dataroot_LQ, doCrop, mismatched_Data_OK, scale,
    target_size, phase, color, use_flip, use_rot, use_compression_artifacts.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.data_type = self.opt['data_type']
        self.paths_LQ, self.paths_GT = None, None
        self.sizes_LQ, self.sizes_GT = None, None
        self.LQ_env, self.GT_env = None, None  # environments for lmdb; opened lazily in __getitem__
        self.doCrop = self.opt['doCrop']

        self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
        self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])

        # When True, the GT and LQ folders may hold different numbers of images; indexing then wraps
        # around the shorter list (see __getitem__ / __len__).
        self.data_sz_mismatch_ok = opt['mismatched_Data_OK']
        assert self.paths_GT, 'Error: GT path is empty.'
        assert self.paths_LQ, 'LQ is required for downsampling.'
        if not self.data_sz_mismatch_ok:
            assert len(self.paths_LQ) == len(self.paths_GT), \
                'GT and LQ datasets have different number of images - {}, {}.'.format(
                    len(self.paths_LQ), len(self.paths_GT))
        self.random_scale_list = [1]

    def _init_lmdb(self):
        # Environments are opened on first access rather than in __init__, presumably so that each
        # DataLoader worker process creates its own handles (see linked issue).
        # https://github.com/chainer/chainermn/issues/129
        self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
                                meminit=False)
        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
                                meminit=False)

    def _lmdb_resolution(self, sizes, index):
        """Parse the '_'-separated size record for ``index`` into ints when reading lmdb, else return None.

        ``index`` wraps modulo ``len(sizes)`` so it selects the same entry as the matching path list,
        which matters when GT and LQ lists have different lengths.
        """
        if self.data_type != 'lmdb':
            return None
        return [int(s) for s in sizes[index % len(sizes)].split('_')]

    def _read_image(self, env, paths, sizes, index):
        """Return ``(path, image)`` for ``index`` (wrapped modulo ``len(paths)``) via ``util.read_img``."""
        path = paths[index % len(paths)]
        return path, util.read_img(env, path, self._lmdb_resolution(sizes, index))

    def _train_crop(self, img_LQ, img_GT, scale, GT_size):
        """Produce training patches: LQ of side ``GT_size // scale`` and HQ of side ``GT_size``.

        With ``doCrop`` the LQ patch is taken at a random offset and the HQ patch at that offset times
        ``scale``; otherwise both images are resized (bilinear) to the target sizes.
        """
        H_GT, W_GT = img_GT.shape[:2]
        assert H_GT >= GT_size and W_GT >= GT_size
        LQ_size = GT_size // scale

        if not self.doCrop:
            img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
            img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
            return img_LQ, img_GT

        # Randomly crop LQ. NOTE(review): an LQ image smaller than LQ_size yields a smaller patch here.
        H_LQ, W_LQ = img_LQ.shape[:2]
        rnd_h = random.randint(0, max(0, H_LQ - LQ_size))
        rnd_w = random.randint(0, max(0, W_LQ - LQ_size))
        img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
        # Crop HQ at the scaled LQ offset so paired data stays aligned. Clamp the offset so that
        # unpaired data (HQ smaller than LQ * scale) still yields a full GT_size patch instead of a
        # truncated slice that would break batch collation.
        rnd_h_GT = min(int(rnd_h * scale), H_GT - GT_size)
        rnd_w_GT = min(int(rnd_w * scale), W_GT - GT_size)
        img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
        return img_LQ, img_GT

    @staticmethod
    def _jpeg_corrupt(img):
        """Round-trip a PIL image through in-memory JPEG at a random quality in [15, 99]."""
        qf = random.randrange(15, 100)
        corruption_buffer = BytesIO()
        # No `optimize=True`: Huffman-table optimisation only shrinks the encoded bytes and does not
        # change the decoded pixels, so it would only cost encode time for this throwaway buffer.
        img.save(corruption_buffer, "JPEG", quality=qf)
        corruption_buffer.seek(0)
        return Image.open(corruption_buffer)

    def __getitem__(self, index):
        """Build one sample.

        Returns a dict with:
            'LQ':  the HQ image tensor (the model input; optionally JPEG-corrupted),
            'GT':  the real LQ image tensor (the model target),
            'PIX': the HQ image resized down by ``scale`` (reference for pixel/feature losses),
            'LQ_path', 'GT_path': source paths of the LQ and HQ images.
        """
        if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
            self._init_lmdb()
        scale = self.opt['scale']
        GT_size = self.opt['target_size'] * scale

        # get GT image
        GT_path, img_GT = self._read_image(self.GT_env, self.paths_GT, self.sizes_GT, index)
        if self.opt['phase'] != 'train':  # modcrop in the validation / test phase
            img_GT = util.modcrop(img_GT, scale)
        if self.opt['color']:  # change color space if necessary
            img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]

        # get LQ image
        # NOTE(review): unlike img_GT, img_LQ is not passed through util.channel_convert when
        # opt['color'] is set — confirm this is intended.
        LQ_path, img_LQ = self._read_image(self.LQ_env, self.paths_LQ, self.sizes_LQ, index)

        if self.opt['phase'] == 'train':
            img_LQ, img_GT = self._train_crop(img_LQ, img_GT, scale, GT_size)
            # augmentation - flip, rotate
            img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
                                          self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)

        # HQ needs to go to a PIL image to perform the compression-artifact transformation.
        # NOTE(review): a single-channel (H, W, 1) array may be rejected by Image.fromarray.
        H, W = img_GT.shape[:2]
        # Clip before the cast so out-of-range values cannot wrap around in uint8.
        img_GT = (np.clip(img_GT, 0, 1) * 255).astype(np.uint8)
        img_GT = Image.fromarray(img_GT)
        if self.opt['use_compression_artifacts']:
            img_GT = self._jpeg_corrupt(img_GT)
        # Generate a downsampled image from HQ for feature and PIX losses.
        img_Downsampled = F.resize(img_GT, (int(H / scale), int(W / scale)))

        img_GT = F.to_tensor(img_GT)
        img_Downsampled = F.to_tensor(img_Downsampled)
        img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()

        # This may seem really messed up, but let me explain:
        # The goal is to re-use existing code as much as possible. SRGAN_model was coded to supersample, not
        # downsample, but it can be retrofitted. To do so, we need to "trick" it. In this case the "input" is
        # the HQ image and the "output" is the LQ image. SRGAN_model will be using a Generator and a
        # Discriminator which already know this; we just need to trick its logic into following these rules.
        # Do this by setting LQ (which is the input into the models) = img_GT and GT (which is the expected
        # output) = img_LQ.
        # PIX is used as a reference for the pixel loss. Use the manually downsampled image for this.
        return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}

    def __len__(self):
        # The longer of the two lists; __getitem__ wraps indices into the shorter one.
        return max(len(self.paths_GT), len(self.paths_LQ))
|