# Forked from mrq/DL-Art-School.
import random
from io import BytesIO

import cv2
import lmdb
import numpy as np
import pyarrow
import torch
import torch.utils.data as data
import torchvision.transforms.functional as F
from PIL import Image, ImageOps

import data.util as util
class LmdbDatasetWithRef(data.Dataset):
    """Pairs pre-extracted HQ tiles with the full-size image each tile was cut from.

    Entries are read from the LMDB database at ``opt['lmdb_path']`` and decoded with
    ``pyarrow.deserialize``. A tile entry unpacks to ``(img_GT, gt_center)``. The key of the tile's
    reference image is the part of the tile key before the first ``'_'``, and the first element of
    the reference entry is the full-size image. LQ versions of the tile and of the reference are
    synthesized by downsampling by ``opt['scale']``. In the ``'train'`` phase both are further
    degraded (blur, JPEG artifacts, grayscale, noise) according to ``opt``.

    NOTE(review): images are presumably HWC, BGR, float32 in [0, 1]. BGR2RGB is applied and
    ``pil_augment`` multiplies by 255, but this should be confirmed against the LMDB writer.

    ``opt`` keys read here: ``lmdb_path``, ``scale``, ``phase``, ``target_size`` (train only),
    ``force_multiple`` (default 1), ``random_scale_list`` (default ``[1]``), ``use_blurring``,
    ``blur_magnitude`` (default 3), ``use_compression_artifacts``, ``grayscale``, ``lq_noise``.
    """

    def __init__(self, opt):
        super(LmdbDatasetWithRef, self).__init__()
        self.opt = opt
        # NOTE(review): the env is opened eagerly. py-lmdb advises against using an Environment
        # across fork(), which matters for DataLoader workers. Consider opening lazily per worker.
        self.db = lmdb.open(self.opt['lmdb_path'], subdir=True, readonly=True, lock=False,
                            readahead=False, meminit=False)
        self.data_type = 'img'
        self.force_multiple = self.opt.get('force_multiple', 1)
        # Random rescale factors applied to GT tiles in the train phase. __getitem__ reads this,
        # so it must exist even when the option is absent.
        self.random_scale_list = self.opt.get('random_scale_list', [1])
        with self.db.begin(write=False) as txn:
            self.keys = pyarrow.deserialize(txn.get(b'__keys__'))
            self.len = pyarrow.deserialize(txn.get(b'__len__'))

    def motion_blur(self, image, size, angle):
        """Convolve ``image`` with a normalized line kernel of length ``size`` rotated by ``angle`` degrees."""
        k = np.zeros((size, size), dtype=np.float32)
        k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
        k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
        k = k * (1.0 / np.sum(k))
        return cv2.filter2D(image, -1, k)

    def resize_point(self, point, orig_dim, new_dim):
        """Rescale a ``(row, col)`` point from ``orig_dim`` to ``new_dim`` (both ``(h, w)``).

        ``point`` is modified in place (coordinates truncated to int) and also returned.
        """
        oh, ow = orig_dim
        nh, nw = new_dim
        dh, dw = float(nh) / float(oh), float(nw) / float(ow)
        point[0] = int(dh * float(point[0]))
        point[1] = int(dw * float(point[1]))
        return point

    @staticmethod
    def _copy_point(point):
        """Return a mutable copy of ``point`` (tensor, ndarray or sequence) for resize_point to write into."""
        if torch.is_tensor(point):
            return point.clone()
        if isinstance(point, np.ndarray):
            return point.copy()
        return list(point)

    @staticmethod
    def _to_3ch(img):
        """Return ``img`` unchanged unless it is 2-D grayscale, which is expanded with GRAY2BGR."""
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        return img

    @staticmethod
    def _mod(n, random_scale, scale, thres):
        """Scale ``n`` by ``random_scale``, floor to a multiple of ``scale``, and never go below ``thres``."""
        rlt = int(n * random_scale)
        rlt = (rlt // scale) * scale
        return thres if rlt < thres else rlt

    def _clip_pair(self, lq, gt, scale):
        """Trim ``lq`` to a multiple of ``force_multiple`` and ``gt`` to the corresponding ``scale``x region.

        Only the bottom/right edges are removed, so coordinates measured from the top-left are preserved.
        """
        m = self.force_multiple
        h, w = lq.shape[:2]
        if h % m == 0 and w % m == 0:
            return lq, gt
        h, w = h - h % m, w - w % m
        return lq[:h, :w], gt[:h * scale, :w * scale]

    def augment_tile(self, img_GT, img_LQ, strength=1):
        """Resize a GT/LQ pair to the training sizes and optionally blur the LQ image.

        GT is resized to ``target_size`` squared and LQ to ``target_size // scale`` squared.
        With ``use_blurring``, LQ receives Gaussian blur (~40%), motion blur (~30%) or no blur.
        ``strength`` scales the blur magnitude.
        """
        scale = self.opt['scale']
        GT_size = self.opt['target_size']
        H, W, _ = img_GT.shape
        assert H >= GT_size and W >= GT_size

        LQ_size = GT_size // scale
        img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
        img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)

        if self.opt.get('use_blurring', False):
            # Pick randomly between gaussian, motion, or no blur.
            blur_det = random.randint(0, 100)
            blur_magnitude = self.opt.get('blur_magnitude', 3)
            blur_magnitude = max(1, int(blur_magnitude * strength))
            if blur_det < 40:
                blur_sig = int(random.randrange(0, blur_magnitude))
                # cv2.GaussianBlur requires a positive odd kernel size, so round even magnitudes up.
                ksize = blur_magnitude if blur_magnitude % 2 == 1 else blur_magnitude + 1
                img_LQ = cv2.GaussianBlur(img_LQ, (ksize, ksize), blur_sig)
            elif blur_det < 70:
                img_LQ = self.motion_blur(img_LQ, random.randrange(1, blur_magnitude * 3), random.randint(0, 360))

        return img_GT, img_LQ

    def pil_augment(self, img_LQ, strength=1):
        """Convert a [0, 1] float image to PIL, then optionally apply JPEG artifacts and grayscale.

        ``strength`` widens the range of JPEG quality reductions. Returns a PIL image.
        """
        # Clip before the uint8 cast: bicubic downsampling can overshoot [0, 1], and an unclipped
        # cast wraps around (e.g. bright highlights turn dark).
        img_LQ = np.clip(img_LQ * 255, 0, 255).astype(np.uint8)
        img_LQ = Image.fromarray(img_LQ)
        if self.opt.get('use_compression_artifacts', False) and random.random() > .25:
            # randrange needs integers; fractional strengths would otherwise produce floats.
            sub_lo = int(round(90 * strength))
            sub_hi = int(round(30 * strength))
            # Very small strengths give an empty range; fall back to the highest quality.
            qf = random.randrange(100 - sub_lo, 100 - sub_hi) if sub_lo > sub_hi else 100 - sub_hi
            corruption_buffer = BytesIO()
            img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
            corruption_buffer.seek(0)
            img_LQ = Image.open(corruption_buffer)
        if self.opt.get('grayscale', False):
            img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
        return img_LQ

    def __getitem__(self, index):
        """Build one sample.

        Returns a dict with keys ``'LQ'``, ``'GT'``, ``'gt_fullsize_ref'``, ``'lq_fullsize_ref'``,
        ``'lq_center'``, ``'gt_center'``, ``'LQ_path'`` and ``'GT_path'``. The full-size refs are CHW
        tensors with the mask appended as an extra (last) channel.
        """
        scale = self.opt['scale']

        # Get the HQ tile and the full-size reference image it was cut from.
        key = self.keys[index]
        ref_key = key[:key.index('_')]
        with self.db.begin(write=False) as txn:
            bytes_ref = txn.get(ref_key.encode())
            bytes_tile = txn.get(key.encode())
        unpacked_ref = pyarrow.deserialize(bytes_ref)
        unpacked_tile = pyarrow.deserialize(bytes_tile)
        gt_fullsize_ref = self._to_3ch(unpacked_ref[0])
        img_GT, gt_center = unpacked_tile
        img_GT = self._to_3ch(img_GT)

        # TODO: synthesize gt_mask. For now the mask is all ones.
        # float32 so it concatenates cleanly with the float32 image tensors below.
        gt_mask = np.ones(img_GT.shape[:2], dtype=np.float32)
        orig_gt_dim = gt_fullsize_ref.shape[:2]

        if self.opt['phase'] == 'train':
            GT_size = self.opt['target_size']
            random_scale = random.choice(self.random_scale_list)
            H_s, W_s = img_GT.shape[:2]
            H_s = self._mod(H_s, random_scale, scale, GT_size)
            W_s = self._mod(W_s, random_scale, scale, GT_size)
            img_GT = self._to_3ch(cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR))

        # Synthesize LQ by downsampling (using matlab imresize).
        img_LQ = util.imresize_np(img_GT, 1 / scale, True)
        lq_fullsize_ref = util.imresize_np(gt_fullsize_ref, 1 / scale, True)
        if img_LQ.ndim == 2:
            img_LQ = np.expand_dims(img_LQ, axis=2)
        lq_mask = gt_mask
        # Map the tile center from GT-reference coordinates into LQ-reference coordinates.
        lq_center = self.resize_point(self._copy_point(gt_center), orig_gt_dim, lq_fullsize_ref.shape[:2])

        # Enforce force_multiple by clipping. The tile and the reference have different sizes,
        # so each is clipped to its own multiple rather than to the tile's dimensions.
        img_LQ, img_GT = self._clip_pair(img_LQ, img_GT, scale)
        lq_fullsize_ref, gt_fullsize_ref = self._clip_pair(lq_fullsize_ref, gt_fullsize_ref, scale)
        # Clipping only trims the bottom/right edges, so centers are still valid at these sizes.
        lq_ref_dim = lq_fullsize_ref.shape[:2]
        gt_ref_dim = gt_fullsize_ref.shape[:2]

        if self.opt['phase'] == 'train':
            img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
            gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)

        # Scale masks to the final reference sizes.
        lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]),
                             interpolation=cv2.INTER_LINEAR)
        gt_mask = cv2.resize(gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]),
                             interpolation=cv2.INTER_LINEAR)

        # Scale center coords to the final reference sizes.
        lq_center = self.resize_point(lq_center, lq_ref_dim, lq_fullsize_ref.shape[:2])
        gt_center = self.resize_point(gt_center, gt_ref_dim, gt_fullsize_ref.shape[:2])

        # BGR to RGB.
        if img_GT.shape[2] == 3:
            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
            lq_fullsize_ref = cv2.cvtColor(lq_fullsize_ref, cv2.COLOR_BGR2RGB)
            gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB)

        # LQ needs to go to a PIL image to perform the compression-artifact transformation.
        if self.opt['phase'] == 'train':
            img_LQ = self.pil_augment(img_LQ)
            lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)

        # HWC to CHW, numpy/PIL to float32 tensors.
        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
        img_LQ = F.to_tensor(img_LQ).float()
        lq_fullsize_ref = F.to_tensor(lq_fullsize_ref).float()
        lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0)
        gt_mask = torch.from_numpy(np.ascontiguousarray(gt_mask)).unsqueeze(dim=0)

        if 'lq_noise' in self.opt:
            # Draw independent noise for each tensor; the tile and the reference generally differ in shape.
            noise_std = self.opt['lq_noise'] / 255
            img_LQ += torch.randn_like(img_LQ) * noise_std
            lq_fullsize_ref += torch.randn_like(lq_fullsize_ref) * noise_std

        # Append each mask as an extra channel of its full-size reference.
        gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
        lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)

        d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
             'lq_center': lq_center, 'gt_center': gt_center,
             'LQ_path': key, 'GT_path': key}
        return d

    def __len__(self):
        """Return the entry count stored under ``__len__`` in the LMDB database."""
        return self.len
if __name__ == '__main__':
    # Debug driver: dumps the image tensors of a range of samples to ./debug as PNGs.
    import os

    import torchvision

    # Example training configuration, kept for reference; the test configuration below is the one used.
    train_opt = {
        'name': 'amalgam',
        'lmdb_path': 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref',
        'use_flip': True,
        'use_compression_artifacts': True,
        'use_blurring': True,
        'use_rot': True,
        'lq_noise': 5,
        'target_size': 128,
        'min_tile_size': 256,
        'scale': 2,
        'phase': 'train',
    }
    opt = {
        'name': 'amalgam',
        # LmdbDatasetWithRef reads 'lmdb_path'; without it the constructor raises KeyError.
        'lmdb_path': 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref',
        'dataroot_GT': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref'],
        'dataroot_GT_weights': [1],
        'force_multiple': 32,
        'scale': 2,
        'phase': 'test',
    }

    ds = LmdbDatasetWithRef(opt)
    os.makedirs("debug", exist_ok=True)
    for i in range(300, len(ds)):
        o = ds[i]
        for k, v in o.items():
            if 'path' in k:
                continue
            # Only CHW image tensors can be written as PNGs; skip the center coordinates.
            if not torch.is_tensor(v) or v.dim() != 3:
                continue
            torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))