# forked from mrq/DL-Art-School
# This will cause larger slices of an image to appear more frequently, increasing the difficulty of the generator.
# 364 lines, 16 KiB, Python
import random
|
|
import numpy as np
|
|
import cv2
|
|
import torch
|
|
import torch.utils.data as data
|
|
import data.util as util
|
|
from PIL import Image, ImageOps
|
|
from io import BytesIO
|
|
import torchvision.transforms.functional as F
|
|
|
|
|
|
# Reads full-quality images and pulls tiles from them. Also extracts LR renderings of the full image with cues as to
|
|
# where those tiles came from.
|
|
class FullImageDataset(data.Dataset):
    """
    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
    If only GT images are provided, generate LQ images on-the-fly.

    In the 'train' phase a square tile is pulled from each (squared) GT image and returned together with a
    downscaled rendering of the whole image, a mask marking where the tile came from, and the tile center.

    Options read from `opt`:
        dataroot_GT, dataroot_GT_weights: GT image source(s), passed to util.get_image_paths.
        dataroot_LQ (optional): one LQ source or a list of LQ sources.
        force_multiple (optional, default 1): LQ height/width are clipped down to a multiple of this.
        scale: GT-to-LQ downscale factor.
        phase: 'train' enables tiling and augmentation.
        min_tile_size, target_size, use_flip, use_rot, use_blurring, blur_magnitude,
        use_compression_artifacts, grayscale, lq_noise: tiling/augmentation knobs used by the methods below.
    """

    def get_lq_path(self, i):
        """Return the LQ path for index `i`, taken from a randomly chosen LQ source.

        Each source list is indexed modulo its own length, so sources may have different sizes.
        """
        which_lq = random.randint(0, len(self.paths_LQ) - 1)
        return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])]

    def __init__(self, opt):
        """Collect GT (and optional LQ) image paths described by the options dict `opt`."""
        super(FullImageDataset, self).__init__()
        self.opt = opt
        self.data_type = 'img'
        self.paths_LQ, self.paths_GT = None, None
        self.sizes_LQ, self.sizes_GT = None, None
        self.LQ_env, self.GT_env = None, None
        # LQ height/width are clipped down to a multiple of this in __getitem__.
        self.force_multiple = opt.get('force_multiple', 1)

        self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
        if 'dataroot_LQ' in opt:
            self.paths_LQ = []
            dataroots_LQ = opt['dataroot_LQ']
            if not isinstance(dataroots_LQ, list):
                dataroots_LQ = [dataroots_LQ]
            # Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and
            # we want the model to learn them all.
            for dr_lq in dataroots_LQ:
                lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq)
                self.paths_LQ.append(lq_path)

        assert self.paths_GT, 'Error: GT path is empty.'
        self.random_scale_list = [1]

    def motion_blur(self, image, size, angle):
        """Apply a linear motion blur of length `size` pixels at `angle` degrees to `image`.

        A horizontal line kernel is rotated about its center, normalized to sum to 1 and convolved with the image.
        """
        k = np.zeros((size, size), dtype=np.float32)
        k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
        k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
        k = k * (1.0 / np.sum(k))
        return cv2.filter2D(image, -1, k)

    def get_square_image(self, image):
        """Crop an HWC `image` along its longer axis so that it becomes square.

        The crop offset from center is drawn from a normal distribution (std .3, clamped to [-1, 1]) and scaled
        by (center - 2). The resulting start index is clamped to the valid range. Without the clamp, images whose
        sides differ by fewer than 4 px would get a negative start, and the slice would wrap around to an
        empty or short crop. Square images are returned unchanged (same object).
        """
        h, w, _ = image.shape
        if h == w:
            return image
        offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
        side = min(h, w)
        diff = abs(h - w)
        center = diff // 2
        start = int(center + offset * (center - 2))
        start = min(max(start, 0), diff)
        if h > w:
            return image[start:start + side, :, :]
        return image[:, start:start + side, :]

    def pick_along_range(self, sz, r, dev):
        """Pick a start offset for a span of length `r` on an axis of length `sz`.

        The offset is drawn from a normal distribution (std `dev`, in units of the free margin `sz - r`)
        centered on the middle of the margin, then clamped to [0, sz - r].
        """
        margin_sz = sz - r
        margin_center = margin_sz // 2
        return min(max(int(min(np.random.normal(scale=dev), 1.0) * margin_sz + margin_center), 0), margin_sz)

    def resize_point(self, point, orig_dim, new_dim):
        """Rescale a (row, col) `point` from an image of size `orig_dim` to one of size `new_dim`.

        `point` is modified in place and also returned; coordinates are truncated to ints.
        """
        oh, ow = orig_dim
        nh, nw = new_dim
        dh, dw = float(nh) / float(oh), float(nw) / float(ow)
        point[0] = int(dh * float(point[0]))
        point[1] = int(dw * float(point[1]))
        return point

    def pull_tile(self, image, lq=False):
        """Randomly extract a square tile from `image` and describe where it came from.

        - target_sz is opt['min_tile_size'], divided by opt['scale'] when `lq` is True.
        - The tile side lies in [target_sz, min(h, w)]. The distance above target_sz is the available range
          times |N(0, .2)| (clamped to 1), which biases toward target_sz.
        - The top/left offsets come from pick_along_range, which biases the tile toward the image center.

        Returns:
            patch: the tile resized to (target_sz, target_sz).
            image: the whole image resized to (target_sz, target_sz).
            mask: zeros with ones where the tile was taken, resized like `image`.
            center: torch.long tensor (row, col) of the tile center, in the resized image's coordinates.
        """
        if lq:
            target_sz = self.opt['min_tile_size'] // self.opt['scale']
        else:
            target_sz = self.opt['min_tile_size']
        h, w, _ = image.shape
        # Bound the tile by the shorter side so it fits on both axes (__getitem__ squares images first).
        possible_sizes_above_target = min(h, w) - target_sz
        square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.2)), 1.0))
        # Pick the left,top coords to draw the patch from; each offset is picked along its own axis.
        left = self.pick_along_range(w, square_size, .3)
        top = self.pick_along_range(h, square_size, .3)

        mask = np.zeros((h, w, 1), dtype=image.dtype)
        mask[top:top + square_size, left:left + square_size] = 1
        patch = image[top:top + square_size, left:left + square_size, :]
        center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)

        patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
        image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
        center = self.resize_point(center, (h, w), image.shape[:2])

        return patch, image, mask, center

    def augment_tile(self, img_GT, img_LQ, strength=1):
        """Resize a GT/LQ pair to the training sizes and optionally blur the LQ image.

        GT is resized to a square of side opt['target_size']; LQ to a square of side opt['target_size'] // opt['scale'].
        When opt['use_blurring'] is set, the LQ image gets a Gaussian blur (~40% of calls), a motion blur (~30%)
        or no blur. The magnitude is opt['blur_magnitude'] (default 3) scaled by `strength`, with a minimum of 1.

        Returns:
            (img_GT, img_LQ)
        """
        scale = self.opt['scale']
        GT_size = self.opt['target_size']

        H, W, _ = img_GT.shape
        assert H >= GT_size and W >= GT_size

        LQ_size = GT_size // scale
        img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
        img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)

        if self.opt['use_blurring']:
            # Pick randomly between gaussian, motion, or no blur.
            blur_det = random.randint(0, 100)
            blur_magnitude = self.opt.get('blur_magnitude', 3)
            blur_magnitude = max(1, int(blur_magnitude * strength))
            if blur_det < 40:
                blur_sig = int(random.randrange(0, int(blur_magnitude)))
                # cv2.GaussianBlur only accepts odd kernel sizes; bump even magnitudes (e.g. a configured
                # blur_magnitude scaled down by `strength`) to the next odd value instead of raising.
                ksize = blur_magnitude if blur_magnitude % 2 == 1 else blur_magnitude + 1
                img_LQ = cv2.GaussianBlur(img_LQ, (ksize, ksize), blur_sig)
            elif blur_det < 70:
                img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360))

        return img_GT, img_LQ

    def pil_augment(self, img_LQ, strength=1):
        """Convert `img_LQ` to a PIL image, optionally adding JPEG-compression artifacts and grayscale.

        Args:
            img_LQ: HWC numpy image, presumably with values in [0, 1]. Values are clipped to that range before
                the uint8 cast so out-of-range values cannot wrap around.
            strength: scales how low the JPEG quality factor may go. At 1 the quality is drawn from [10, 70),
                at .2 from [82, 94).
        Returns:
            A PIL image (converted to RGB when opt['grayscale'] is set).
        """
        img_LQ = np.clip(img_LQ * 255, 0, 255).astype(np.uint8)
        img_LQ = Image.fromarray(img_LQ)
        # Compression is applied on ~75% of calls.
        if self.opt['use_compression_artifacts'] and random.random() > .25:
            # randrange requires ints (floats raise TypeError on Python 3.12+); round so e.g. 90 * .2 -> 18.
            sub_lo = int(round(90 * strength))
            sub_hi = int(round(30 * strength))
            qf = random.randrange(100 - sub_lo, 100 - sub_hi)
            corruption_buffer = BytesIO()
            img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
            corruption_buffer.seek(0)
            img_LQ = Image.open(corruption_buffer)

        if self.opt.get('grayscale'):
            img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')

        return img_LQ

    def perform_random_hr_augment(self, image, aug_code=None, augmentations=1):
        """Apply a random (mostly placeholder) degradation to `image`.

        Args:
            image: numpy image passed to the OpenCV filters.
            aug_code: force a specific augmentation code; when given, `augmentations` must be 1.
            augmentations: how many codes to draw when `aug_code` is None.
        Returns:
            The possibly degraded image. Only codes 1-4 are implemented; the others are no-ops.
        """
        if aug_code is None:
            # NOTE(review): randint is inclusive of 10, so codes 11 and 12 are never drawn here.
            aug_code = [random.randint(0, 10) for _ in range(augmentations)]
        else:
            assert augmentations == 1
            aug_code = [aug_code]
        # NOTE(review): this is an if/elif chain, so even when several codes are drawn only the lowest matching
        # code is applied. Confirm whether stacking augmentations was intended.
        if 0 in aug_code:
            # Color quantization (not implemented)
            pass
        elif 1 in aug_code:
            # Gaussian Blur (point or motion)
            blur_magnitude = 3
            blur_sig = int(random.randrange(0, int(blur_magnitude)))
            image = cv2.GaussianBlur(image, (blur_magnitude, blur_magnitude), blur_sig)
        elif 2 in aug_code:
            # Median Blur
            image = cv2.medianBlur(image, 3)
        elif 3 in aug_code:
            # Motion blur
            image = self.motion_blur(image, random.randrange(1, 9), random.randint(0, 360))
        elif 4 in aug_code:
            # Smooth blur. cv2.blur takes the kernel size as a (width, height) pair, not a scalar.
            image = cv2.blur(image, ksize=(3, 3))
        elif 5 in aug_code:
            # Block noise (not implemented)
            pass
        elif 6 in aug_code:
            # Bicubic LR->HR (not implemented)
            pass
        elif 7 in aug_code:
            # Linear compression distortion (not implemented)
            pass
        elif 8 in aug_code:
            # Interlacing distortion (not implemented)
            pass
        elif 9 in aug_code:
            # Chromatic aberration (not implemented)
            pass
        elif 10 in aug_code:
            # Noise (not implemented)
            pass
        elif 11 in aug_code:
            # JPEG compression (not implemented)
            pass
        elif 12 in aug_code:
            # Lightening / saturation (not implemented)
            pass
        return image

    def __getitem__(self, index):
        """Load sample `index`.

        Returns a dict with:
            'LQ', 'GT': the LQ and GT images (tiles in the 'train' phase) as CHW tensors.
            'gt_fullsize_ref', 'lq_fullsize_ref': whole-image references with the tile mask appended as an
                extra channel.
            'gt_center', 'lq_center': (row, col) of the tile center in the respective reference's coordinates.
            'LQ_path', 'GT_path': source file paths.
        """
        scale = self.opt['scale']

        # get full size image
        full_path = self.paths_GT[index % len(self.paths_GT)]
        LQ_path = full_path
        img_full = util.read_img(None, full_path, None)
        img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0]
        if self.opt['phase'] == 'train':
            img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
            img_full = self.get_square_image(img_full)
            img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full)
        else:
            img_GT, gt_fullsize_ref = img_full, img_full
            gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype)
            gt_center = torch.tensor([img_full.shape[0] // 2, img_full.shape[1] // 2], dtype=torch.long)
        orig_gt_dim = gt_fullsize_ref.shape[:2]

        # get LQ image
        if self.paths_LQ:
            LQ_path = self.get_lq_path(index)
            img_lq_full = util.read_img(None, LQ_path, None)
            # NOTE(review): the squaring and tiling below (and presumably util.augment) draw their own random
            # parameters, independent of those used for the GT image above, and they run regardless of phase.
            # The LQ tile is therefore not spatially aligned with img_GT. Confirm this is intended.
            img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
            img_lq_full = self.get_square_image(img_lq_full)
            img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True)
        else:  # down-sampling on-the-fly
            # randomly scale during training
            if self.opt['phase'] == 'train':
                GT_size = self.opt['target_size']
                random_scale = random.choice(self.random_scale_list)
                if img_GT.ndim == 2:
                    raise ValueError('Expected an HWC image but got shape %s from %s' % (img_GT.shape, full_path))
                H_s, W_s, _ = img_GT.shape

                def _mod(n, random_scale, scale, thres):
                    # Scale n, round down to a multiple of `scale`, and never go below `thres`.
                    rlt = int(n * random_scale)
                    rlt = (rlt // scale) * scale
                    return thres if rlt < thres else rlt

                H_s = _mod(H_s, random_scale, scale, GT_size)
                W_s = _mod(W_s, random_scale, scale, GT_size)
                img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
                if img_GT.ndim == 2:
                    img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)

            # using matlab imresize
            img_LQ = util.imresize_np(img_GT, 1 / scale, True)
            lq_fullsize_ref = util.imresize_np(gt_fullsize_ref, 1 / scale, True)
            if img_LQ.ndim == 2:
                img_LQ = np.expand_dims(img_LQ, axis=2)
            lq_mask, lq_center = gt_mask, self.resize_point(gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2])
        # At this point lq_center is in lq_fullsize_ref's coordinates in both branches. Record that size so the
        # center can be rescaled after lq_fullsize_ref is clipped/resized below.
        orig_lq_dim = lq_fullsize_ref.shape[:2]

        # Enforce force_multiple constraints via clipping.
        h, w, _ = img_LQ.shape
        if h % self.force_multiple != 0 or w % self.force_multiple != 0:
            h, w = (h - h % self.force_multiple), (w - w % self.force_multiple)
            img_LQ = img_LQ[:h, :w, :]
            lq_fullsize_ref = lq_fullsize_ref[:h, :w, :]
            h *= scale
            w *= scale
            img_GT = img_GT[:h, :w]
            gt_fullsize_ref = gt_fullsize_ref[:h, :w, :]

        if self.opt['phase'] == 'train':
            img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
            gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)

        # Scale masks to match their (possibly resized) reference images.
        lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
        gt_mask = cv2.resize(gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)

        # Scale center coords
        lq_center = self.resize_point(lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2])
        gt_center = self.resize_point(gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2])

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
            lq_fullsize_ref = cv2.cvtColor(lq_fullsize_ref, cv2.COLOR_BGR2RGB)
            gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB)

        # LQ needs to go to a PIL image to perform the compression-artifact transformation.
        if self.opt['phase'] == 'train':
            img_LQ = self.pil_augment(img_LQ)
            lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)

        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
        img_LQ = F.to_tensor(img_LQ)
        lq_fullsize_ref = F.to_tensor(lq_fullsize_ref)
        lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0)
        gt_mask = torch.from_numpy(np.ascontiguousarray(gt_mask)).unsqueeze(dim=0)

        if 'lq_noise' in self.opt:
            # The same noise sample is added to the LQ image and to its full-size reference.
            lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
            img_LQ += lq_noise
            lq_fullsize_ref += lq_noise

        # Append each mask as an extra channel of its full-image reference.
        gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
        lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)

        d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
             'lq_center': lq_center, 'gt_center': gt_center,
             'LQ_path': LQ_path, 'GT_path': full_path}
        return d

    def __len__(self):
        """Number of GT images."""
        return len(self.paths_GT)
|
|
|
|
if __name__ == '__main__':
    import os

    # Alternate configuration that exercises the training path (tiling, blur and JPEG augmentation):
    # opt = {
    #     'name': 'amalgam',
    #     'dataroot_GT': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images'],
    #     'dataroot_GT_weights': [1],
    #     'use_flip': True,
    #     'use_compression_artifacts': True,
    #     'use_blurring': True,
    #     'use_rot': True,
    #     'lq_noise': 5,
    #     'target_size': 128,
    #     'min_tile_size': 256,
    #     'scale': 2,
    #     'phase': 'train'
    # }

    # Test-phase configuration: whole images, with LQ sides clipped to multiples of 32.
    opt = dict(
        name='amalgam',
        dataroot_GT=['F:\\4k6k\\datasets\\ns_images\\imagesets\\images'],
        dataroot_GT_weights=[1],
        force_multiple=32,
        scale=2,
        phase='test',
    )

    dataset = FullImageDataset(opt)
    os.makedirs("debug", exist_ok=True)
    # Smoke test: load every sample from index 300 onward.
    for idx in range(300, len(dataset)):
        print(idx)
        sample = dataset[idx]
        for key, value in sample.items():
            if 'path' in key:
                continue
            # Debug dump (requires torchvision); uncomment to write each tensor into ./debug:
            # import torchvision
            # if 'full' in key:
            #     masked = value[:3, :, :] * value[3]
            #     torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (idx, key))
            #     value = value[:3, :, :]
            # torchvision.utils.save_image(value.unsqueeze(0), "debug/%i_%s.png" % (idx, key))