# forked from mrq/DL-Art-School
# 231 lines, 9.5 KiB, Python
|
import random
|
||
|
import numpy as np
|
||
|
import cv2
|
||
|
import torch
|
||
|
import torch.utils.data as data
|
||
|
import data.util as util
|
||
|
from PIL import Image, ImageOps
|
||
|
from io import BytesIO
|
||
|
import torchvision.transforms.functional as F
|
||
|
import lmdb
|
||
|
import pyarrow
|
||
|
|
||
|
|
||
|
# Reads full-quality images and pulls tiles from them. Also extracts LR renderings of the full image with cues as to
|
||
|
# where those tiles came from.
|
||
|
class LmdbDatasetWithRef(data.Dataset):
    """Dataset that serves image tiles together with the full image they were cut from.

    Records are read from the LMDB database at ``opt['lmdb_path']``. Each tile record is looked up by a key of
    the form ``<ref_key>_<suffix>``; the full-size reference image is stored under ``<ref_key>``. The low-quality
    (LQ) counterparts of both the tile and the reference are synthesized here by downsampling and, in the
    ``'train'`` phase, by optional blurring, JPEG compression, grayscale conversion and additive noise.

    Options read from ``opt``:
        lmdb_path: directory of the LMDB database.
        scale: integer downsampling factor between GT and LQ.
        phase: augmentation is only applied when this equals ``'train'``.
        target_size: side length the GT tile/reference are resized to in the train phase.
        use_blurring / use_compression_artifacts: enable the respective corruptions.
        force_multiple (optional, default 1): LQ height/width are clipped down to a multiple of this.
        blur_magnitude (optional, default 3): upper bound for blur kernel sizes.
        grayscale (optional): when truthy the LQ images are converted to grayscale.
        lq_noise (optional): standard deviation (on a 0-255 scale) of gaussian noise added to the LQ images.
        random_scale_list (optional, default [1]): scales to pick from before cropping in the train phase.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.db = lmdb.open(self.opt['lmdb_path'], subdir=True, readonly=True, lock=False, readahead=False,
                            meminit=False)
        self.data_type = 'img'
        self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
        # __getitem__ samples a scale from this list in the train phase. It used to be left undefined, which
        # raised AttributeError on the first training sample; [1] means "no random rescaling".
        self.random_scale_list = self.opt['random_scale_list'] if 'random_scale_list' in self.opt.keys() else [1]
        with self.db.begin(write=False) as txn:
            # NOTE(review): pyarrow.deserialize is deprecated in recent pyarrow releases - confirm the pinned
            # version still provides it.
            self.keys = pyarrow.deserialize(txn.get(b'__keys__'))
            self.len = pyarrow.deserialize(txn.get(b'__len__'))

    def motion_blur(self, image, size, angle):
        """Blur `image` with a normalized line kernel of side `size` rotated by `angle` degrees."""
        k = np.zeros((size, size), dtype=np.float32)
        # Horizontal line through the middle row, rotated about the kernel center below.
        k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
        k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
        # Normalize so the kernel sums to 1.
        k = k * (1.0 / np.sum(k))
        return cv2.filter2D(image, -1, k)

    def resize_point(self, point, orig_dim, new_dim):
        """Rescale `point` from an image of size `orig_dim` to one of size `new_dim`.

        `orig_dim` and `new_dim` are (height, width) pairs; `point[0]` is scaled by the height ratio and
        `point[1]` by the width ratio, both truncated to int. `point` is modified in place and also returned.
        """
        oh, ow = orig_dim
        nh, nw = new_dim
        dh, dw = float(nh) / float(oh), float(nw) / float(ow)
        point[0] = int(dh * float(point[0]))
        point[1] = int(dw * float(point[1]))
        return point

    def augment_tile(self, img_GT, img_LQ, strength=1):
        """Resize a GT/LQ pair to the configured target size and optionally blur the LQ image.

        `img_GT` is resized to (target_size, target_size) and `img_LQ` to (target_size // scale) squared. When
        `opt['use_blurring']` is set, the LQ image randomly receives a gaussian blur, a motion blur, or no
        blur. `strength` scales the configured blur magnitude. Returns `(img_GT, img_LQ)`.
        """
        scale = self.opt['scale']
        GT_size = self.opt['target_size']

        H, W, _ = img_GT.shape
        assert H >= GT_size and W >= GT_size

        LQ_size = GT_size // scale
        img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
        img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)

        if self.opt['use_blurring']:
            # Pick randomly between gaussian, motion, or no blur.
            blur_det = random.randint(0, 100)
            blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
            blur_magnitude = max(1, int(blur_magnitude * strength))
            if blur_det < 40:
                blur_sig = random.randrange(0, blur_magnitude)
                # cv2.GaussianBlur rejects even kernel sizes, so round an even magnitude up to the next odd.
                ksize = blur_magnitude | 1
                img_LQ = cv2.GaussianBlur(img_LQ, (ksize, ksize), blur_sig)
            elif blur_det < 70:
                img_LQ = self.motion_blur(img_LQ, random.randrange(1, blur_magnitude * 3), random.randint(0, 360))

        return img_GT, img_LQ

    @staticmethod
    def _pick_jpeg_quality(strength):
        """Return a random integer JPEG quality in [100 - 90*strength, 100 - 30*strength)."""
        # random.randrange needs ints: a fractional strength used to hand it floats, which newer Python
        # versions reject.
        lo = max(1, 100 - int(round(90 * strength)))
        hi = 100 - int(round(30 * strength))
        # Keep the range non-empty (e.g. strength == 0 yields quality 100).
        return random.randrange(lo, max(hi, lo + 1))

    # Converts img_LQ to PIL and performs JPG compression corruptions and grayscale on the image, then returns it.
    def pil_augment(self, img_LQ, strength=1):
        """Convert `img_LQ` to a PIL image, applying JPEG artifacts and grayscale as configured.

        `img_LQ` is scaled by 255 and cast to uint8, so it is expected to hold values in [0, 1]; values outside
        that range are clipped rather than left to wrap around in the cast. With
        `opt['use_compression_artifacts']` set, the image goes through a JPEG encode/decode round trip 75% of
        the time at a random quality controlled by `strength`. Returns a PIL image.
        """
        img_LQ = (np.clip(img_LQ, 0, 1) * 255).astype(np.uint8)
        img_LQ = Image.fromarray(img_LQ)
        if self.opt['use_compression_artifacts'] and random.random() > .25:
            qf = self._pick_jpeg_quality(strength)
            corruption_buffer = BytesIO()
            # 'optimize' was misspelled 'optimice', which Pillow silently ignored.
            img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
            corruption_buffer.seek(0)
            img_LQ = Image.open(corruption_buffer)

        if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
            img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')

        return img_LQ

    def __getitem__(self, index):
        """Load tile `index` and its reference image and return them with synthesized LQ versions.

        Returns a dict with:
            'LQ', 'GT': CHW tensors of the low-quality and ground-truth tile.
            'gt_fullsize_ref', 'lq_fullsize_ref': CHW reference images with a mask appended as a last channel.
            'lq_center', 'gt_center': the tile center rescaled to the LQ / GT reference image.
            'LQ_path', 'GT_path': the LMDB key of the tile.
        """
        scale = self.opt['scale']
        is_train = self.opt['phase'] == 'train'

        # get the hq image and the ref image
        key = self.keys[index]
        ref_key = key[:key.index('_')]
        with self.db.begin(write=False) as txn:
            bytes_ref = txn.get(ref_key.encode())
            bytes_tile = txn.get(key.encode())
            unpacked_ref = pyarrow.deserialize(bytes_ref)
            unpacked_tile = pyarrow.deserialize(bytes_tile)
        gt_fullsize_ref = unpacked_ref[0]
        img_GT, gt_center = unpacked_tile

        # A 2-D (single channel) tile cannot be unpacked as (H, W, C) below, so promote it to 3 channels up
        # front. This replaces a debug print that referenced an undefined name and was followed by a crash.
        if img_GT.ndim == 2:
            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)

        # TODO: synthesize gt_mask.
        # float32 so that concatenating the mask onto the float32 image tensors does not change their dtype.
        gt_mask = np.ones(img_GT.shape[:2], dtype=np.float32)
        orig_gt_dim = gt_fullsize_ref.shape[:2]

        # Synthesize LQ by downsampling.
        if is_train:
            GT_size = self.opt['target_size']
            random_scale = random.choice(self.random_scale_list)
            H_s, W_s, _ = img_GT.shape

            def _mod(n, random_scale, scale, thres):
                # Scale n, round down to a multiple of `scale`, and never go below `thres`.
                rlt = int(n * random_scale)
                rlt = (rlt // scale) * scale
                return thres if rlt < thres else rlt

            H_s = _mod(H_s, random_scale, scale, GT_size)
            W_s = _mod(W_s, random_scale, scale, GT_size)
            img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
            # cv2.resize drops the channel axis of an (H, W, 1) image.
            if img_GT.ndim == 2:
                img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)

        # using matlab imresize
        img_LQ = util.imresize_np(img_GT, 1 / scale, True)
        lq_fullsize_ref = util.imresize_np(gt_fullsize_ref, 1 / scale, True)
        if img_LQ.ndim == 2:
            img_LQ = np.expand_dims(img_LQ, axis=2)
        # resize_point works in place, so hand it a copy to keep gt_center intact.
        # NOTE(review): .clone() implies gt_center is a torch tensor - confirm against the LMDB writer.
        lq_mask, lq_center = gt_mask, self.resize_point(gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2])
        orig_lq_dim = lq_fullsize_ref.shape[:2]

        # Enforce force_resize constraints via clipping.
        h, w, _ = img_LQ.shape
        if h % self.force_multiple != 0 or w % self.force_multiple != 0:
            h, w = (h - h % self.force_multiple), (w - w % self.force_multiple)
            img_LQ = img_LQ[:h, :w, :]
            lq_fullsize_ref = lq_fullsize_ref[:h, :w, :]
            h *= scale
            w *= scale
            img_GT = img_GT[:h, :w]
            gt_fullsize_ref = gt_fullsize_ref[:h, :w, :]

        if is_train:
            img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
            gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)

        # Scale masks.
        lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]),
                             interpolation=cv2.INTER_LINEAR)
        gt_mask = cv2.resize(gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]),
                             interpolation=cv2.INTER_LINEAR)

        # Scale center coords
        lq_center = self.resize_point(lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2])
        gt_center = self.resize_point(gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2])

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
            lq_fullsize_ref = cv2.cvtColor(lq_fullsize_ref, cv2.COLOR_BGR2RGB)
            gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB)

        # LQ needs to go to a PIL image to perform the compression-artifact transformation.
        if is_train:
            img_LQ = self.pil_augment(img_LQ)
            lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)

        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
        img_LQ = F.to_tensor(img_LQ)
        lq_fullsize_ref = F.to_tensor(lq_fullsize_ref)
        lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).float().unsqueeze(dim=0)
        gt_mask = torch.from_numpy(np.ascontiguousarray(gt_mask)).float().unsqueeze(dim=0)

        if 'lq_noise' in self.opt.keys():
            lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
            img_LQ += lq_noise
            if lq_fullsize_ref.shape == lq_noise.shape:
                lq_fullsize_ref += lq_noise
            else:
                # Outside the train phase the reference is not resized to the tile's shape, so the tile's
                # noise cannot be reused; draw matching noise for the reference instead.
                lq_fullsize_ref += torch.randn_like(lq_fullsize_ref) * self.opt['lq_noise'] / 255

        # Apply the masks to the full images.
        gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
        lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)

        d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
             'lq_center': lq_center, 'gt_center': gt_center,
             'LQ_path': key, 'GT_path': key}
        return d

    def __len__(self):
        """Return the number of samples as recorded under the database's `__len__` key."""
        return self.len
|
||
|
|
||
|
if __name__ == '__main__':
    # Debug driver: builds the dataset from a hard-coded training configuration and dumps every tensor of
    # each sample (from index 300 onwards) as a PNG into ./debug.
    import os
    import torchvision

    opt = {
        'name': 'amalgam',
        'lmdb_path': 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref',
        'use_flip': True,
        'use_compression_artifacts': True,
        'use_blurring': True,
        'use_rot': True,
        'lq_noise': 5,
        'target_size': 128,
        'min_tile_size': 256,
        'scale': 2,
        'phase': 'train'
    }
    # Alternative test-phase configuration, kept disabled inside a string literal.
    '''
    opt = {
        'name': 'amalgam',
        'dataroot_GT': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref'],
        'dataroot_GT_weights': [1],
        'force_multiple': 32,
        'scale': 2,
        'phase': 'test'
    }
    '''

    dataset = LmdbDatasetWithRef(opt)
    os.makedirs("debug", exist_ok=True)
    for i in range(300, len(dataset)):
        print(i)
        sample = dataset[i]
        for name, value in sample.items():
            # The *_path entries are key strings, not tensors; nothing to save for them.
            if 'path' in name:
                continue
            # Optional masked dump of the 4-channel reference images:
            #if 'full' in name:
                #masked = value[:3, :, :] * value[3]
                #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, name))
                #value = value[:3, :, :]
            torchvision.utils.save_image(value.unsqueeze(0), "debug/%i_%s.png" % (i, name))
|