DL-Art-School/codes/data/Downsample_dataset.py
2020-05-23 21:04:24 -06:00

125 lines
5.8 KiB
Python

import random
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image
from io import BytesIO
import torchvision.transforms.functional as F
class DownsampleDataset(data.Dataset):
    """
    Reads an unpaired HQ and LQ image. Clips both images to the expected input sizes of the model. Produces a
    downsampled LQ image from the HQ image and feeds that as well.

    Options read from ``opt``: 'data_type', 'doCrop', 'dataroot_GT', 'dataroot_LQ', 'mismatched_Data_OK',
    'scale', 'target_size', 'phase', 'color', 'use_flip', 'use_rot' and 'use_compression_artifacts'.
    """

    def __init__(self, opt):
        """Collect the GT/LQ image paths described by ``opt`` and validate that both sets are usable.

        Raises:
            AssertionError: if either path list is empty, or if the two lists differ in length while
                opt['mismatched_Data_OK'] is falsy.
        """
        super(DownsampleDataset, self).__init__()
        self.opt = opt
        self.data_type = self.opt['data_type']
        self.paths_LQ, self.paths_GT = None, None
        self.sizes_LQ, self.sizes_GT = None, None
        self.LQ_env, self.GT_env = None, None  # environments for lmdb
        self.doCrop = self.opt['doCrop']

        self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
        self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
        self.data_sz_mismatch_ok = opt['mismatched_Data_OK']

        assert self.paths_GT, 'Error: GT path is empty.'
        assert self.paths_LQ, 'LQ is required for downsampling.'
        if not self.data_sz_mismatch_ok:
            assert len(self.paths_LQ) == len(
                self.paths_GT
            ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
                len(self.paths_LQ), len(self.paths_GT))
        self.random_scale_list = [1]

    def _init_lmdb(self):
        """Open the GT and LQ lmdb environments read-only."""
        # https://github.com/chainer/chainermn/issues/129
        self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
                                meminit=False)
        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
                                meminit=False)

    def _lmdb_resolution(self, sizes, index):
        """Return the '_'-separated size entry for ``index`` as a list of ints, or None when not using lmdb.

        The index wraps around ``sizes`` exactly like the path lookup does, so datasets of different lengths
        (allowed via 'mismatched_Data_OK') never index past the end of the shorter list.
        """
        if self.data_type != 'lmdb':
            return None
        return [int(s) for s in sizes[index % len(sizes)].split('_')]

    @staticmethod
    def _gt_crop_offset(lq_offset, scale, gt_extent, gt_size):
        """Scale an LQ crop offset into GT space, clamped so a full ``gt_size`` window fits in ``gt_extent``."""
        return min(int(lq_offset * scale), gt_extent - gt_size)

    def __getitem__(self, index):
        """Load, crop/resize, augment and convert one HQ/LQ sample.

        Returns:
            dict with keys:
                'LQ': tensor of the HQ image (JPEG-corrupted when 'use_compression_artifacts' is set),
                'GT': float CHW tensor of the LQ image,
                'PIX': tensor of the HQ image resized down by ``scale``,
                'LQ_path', 'GT_path': the source paths of the two images.
        """
        if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
            self._init_lmdb()
        scale = self.opt['scale']
        GT_size = self.opt['target_size'] * scale

        # get GT image
        GT_path = self.paths_GT[index % len(self.paths_GT)]
        resolution = self._lmdb_resolution(self.sizes_GT, index)
        img_GT = util.read_img(self.GT_env, GT_path, resolution)
        if self.opt['phase'] != 'train':  # modcrop in the validation / test phase
            img_GT = util.modcrop(img_GT, scale)
        if self.opt['color']:  # change color space if necessary
            img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]

        # get LQ image
        LQ_path = self.paths_LQ[index % len(self.paths_LQ)]
        resolution = self._lmdb_resolution(self.sizes_LQ, index)
        img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)

        if self.opt['phase'] == 'train':
            H_GT, W_GT, _ = img_GT.shape
            assert H_GT >= GT_size and W_GT >= GT_size

            H, W, _ = img_LQ.shape
            LQ_size = GT_size // scale

            if self.doCrop:
                # randomly crop
                rnd_h = random.randint(0, max(0, H - LQ_size))
                rnd_w = random.randint(0, max(0, W - LQ_size))
                img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
                # The two images are unpaired, so the scaled LQ offset may not fit inside the GT image;
                # clamp it so the GT crop is always a full GT_size x GT_size window.
                rnd_h_GT = self._gt_crop_offset(rnd_h, scale, H_GT, GT_size)
                rnd_w_GT = self._gt_crop_offset(rnd_w, scale, W_GT, GT_size)
                img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
            else:
                img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
                img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)

            # augmentation - flip, rotate
            img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
                                          self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)

        # HQ needs to go to a PIL image to perform the compression-artifact transformation.
        H, W, _ = img_GT.shape
        img_GT = (img_GT * 255).astype(np.uint8)
        img_GT = Image.fromarray(img_GT)
        if self.opt['use_compression_artifacts']:
            qf = random.randrange(15, 100)
            corruption_buffer = BytesIO()
            # 'optimize' was previously misspelled; Pillow ignores unknown save options, so it never applied.
            img_GT.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
            corruption_buffer.seek(0)
            img_GT = Image.open(corruption_buffer)

        # Generate a downsampled image from HQ for feature and PIX losses.
        img_Downsampled = F.resize(img_GT, (int(H / scale), int(W / scale)))

        img_GT = F.to_tensor(img_GT)
        img_Downsampled = F.to_tensor(img_Downsampled)
        img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()

        # This may seem really messed up, but let me explain:
        # The goal is to re-use existing code as much as possible. SRGAN_model was coded to supersample, not
        # downsample, but it can be retrofitted. To do so, we need to "trick" it. In this case the "input" is the
        # HQ image and the "output" is the LQ image. SRGAN_model will be using a Generator and a Discriminator
        # which already know this; we just need to trick its logic into following these rules.
        # Do this by setting LQ (which is the input into the models) = img_GT and GT (which is the expected
        # output) = img_LQ.
        # PIX is used as a reference for the pixel loss. Use the manually downsampled image for this.
        return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}

    def __len__(self):
        """Return the size of the larger of the two image sets; the smaller one is cycled via modulo."""
        return max(len(self.paths_GT), len(self.paths_LQ))