DL-Art-School/codes/data/LQGT_dataset.py
James Betker 318a604405 Allow weighting of input data
This essentially allows you to give some datasets more
importance than others for the purposes of reaching a more
refined network.
2020-06-04 10:05:21 -06:00

200 lines
9.0 KiB
Python

import random
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image, ImageOps
from io import BytesIO
import torchvision.transforms.functional as F
class LQGTDataset(data.Dataset):
"""
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
If only GT images are provided, generate LQ images on-the-fly.
"""
def get_lq_path(self, i):
which_lq = random.randint(0, len(self.paths_LQ)-1)
return self.paths_LQ[which_lq][i]
def __init__(self, opt):
super(LQGTDataset, self).__init__()
self.opt = opt
self.data_type = self.opt['data_type']
self.paths_LQ, self.paths_GT = None, None
self.sizes_LQ, self.sizes_GT = None, None
self.paths_PIX, self.sizes_PIX = None, None
self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
if 'dataroot_LQ' in opt.keys():
self.paths_LQ = []
if isinstance(opt['dataroot_LQ'], list):
# Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and
# we want the model to learn them all.
for dr_lq in opt['dataroot_LQ']:
lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq)
self.paths_LQ.append(lq_path)
else:
lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
self.paths_LQ.append(lq_path)
self.doCrop = opt['doCrop']
if 'dataroot_PIX' in opt.keys():
self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX'])
assert self.paths_GT, 'Error: GT path is empty.'
if self.paths_LQ and self.paths_GT:
assert len(self.paths_LQ[0]) == len(
self.paths_GT
), 'GT and LQ datasets have different number of images - {}, {}.'.format(
len(self.paths_LQ[0]), len(self.paths_GT))
self.random_scale_list = [1]
def _init_lmdb(self):
# https://github.com/chainer/chainermn/issues/129
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
meminit=False)
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
if 'dataroot_PIX' in self.opt.keys():
self.PIX_env = lmdb.open(self.opt['dataroot_PIX'], readonly=True, lock=False, readahead=False,
meminit=False)
def motion_blur(self, image, size, angle):
k = np.zeros((size, size), dtype=np.float32)
k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
k = k * (1.0 / np.sum(k))
return cv2.filter2D(image, -1, k)
def __getitem__(self, index):
if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
self._init_lmdb()
GT_path, LQ_path = None, None
scale = self.opt['scale']
GT_size = self.opt['target_size']
# get GT image
GT_path = self.paths_GT[index]
resolution = [int(s) for s in self.sizes_GT[index].split('_')
] if self.data_type == 'lmdb' else None
img_GT = util.read_img(self.GT_env, GT_path, resolution)
if self.opt['phase'] != 'train': # modcrop in the validation / test phase
img_GT = util.modcrop(img_GT, scale)
if self.opt['color']: # change color space if necessary
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
# get the pix image
if self.paths_PIX is not None:
PIX_path = self.paths_PIX[index]
img_PIX = util.read_img(self.PIX_env, PIX_path, resolution)
if self.opt['color']: # change color space if necessary
img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0]
else:
img_PIX = img_GT
# get LQ image
if self.paths_LQ:
LQ_path = self.get_lq_path(index)
resolution = [int(s) for s in self.sizes_LQ[index].split('_')
] if self.data_type == 'lmdb' else None
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
else: # down-sampling on-the-fly
# randomly scale during training
if self.opt['phase'] == 'train':
random_scale = random.choice(self.random_scale_list)
H_s, W_s, _ = img_GT.shape
def _mod(n, random_scale, scale, thres):
rlt = int(n * random_scale)
rlt = (rlt // scale) * scale
return thres if rlt < thres else rlt
H_s = _mod(H_s, random_scale, scale, GT_size)
W_s = _mod(W_s, random_scale, scale, GT_size)
img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
if img_GT.ndim == 2:
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
H, W, _ = img_GT.shape
# using matlab imresize
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)
if self.opt['phase'] == 'train':
H, W, _ = img_GT.shape
assert H >= GT_size and W >= GT_size
H, W, C = img_LQ.shape
LQ_size = GT_size // scale
if self.doCrop:
# randomly crop
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
else:
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
# augmentation - flip, rotate
img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'],
self.opt['use_rot'])
if self.opt['use_blurring']:
# Pick randomly between gaussian, motion, or no blur.
blur_det = random.randint(0, 100)
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
if blur_det < 40:
blur_sig = int(random.randrange(0, blur_magnitude))
img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
elif blur_det < 70:
img_LQ = self.motion_blur(img_LQ, random.randrange(1, blur_magnitude * 3), random.randint(0, 360))
if self.opt['color']: # change color space if necessary
img_LQ = util.channel_convert(C, self.opt['color'],
[img_LQ])[0] # TODO during val no definition
# BGR to RGB, HWC to CHW, numpy to tensor
if img_GT.shape[2] == 3:
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
img_PIX = cv2.cvtColor(img_PIX, cv2.COLOR_BGR2RGB)
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
img_LQ = (img_LQ * 255).astype(np.uint8)
img_LQ = Image.fromarray(img_LQ)
if self.opt['use_compression_artifacts']:
qf = random.randrange(10, 70)
corruption_buffer = BytesIO()
img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
corruption_buffer.seek(0)
img_LQ = Image.open(corruption_buffer)
if self.opt['grayscale']:
img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float()
img_LQ = F.to_tensor(img_LQ)
lq_noise = torch.randn_like(img_LQ) * 5 / 255
img_LQ += lq_noise
if LQ_path is None:
LQ_path = GT_path
return {'LQ': img_LQ, 'GT': img_GT, 'PIX': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path}
def __len__(self):
return len(self.paths_GT)