# Provenance: recovered from the patch "New dataset that reads from lmdb"
# (commit 17aa205e96703849c3c47c99aa5b4a42b3091cac, James Betker, Fri, 4 Sep 2020 17:32:57 -0600),
# which created codes/data/lmdb_dataset_with_ref.py.
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image, ImageOps
from io import BytesIO
import torchvision.transforms.functional as F
import lmdb
import pyarrow


# Reads full-quality images and pulls tiles from them. Also extracts LR renderings of the full image with cues as to
# where those tiles came from.
class LmdbDatasetWithRef(data.Dataset):
    """Dataset that reads tiles and their full-image reference from an LMDB store.

    Every entry of the `__keys__` list names one tile record; the part of the key before the first '_' names
    the record holding the reference image that tile is paired with. Each item yields the tile (GT), a
    downsampled and optionally corrupted version of it (LQ), GT/LQ renderings of the reference image with a
    mask channel appended, and the tile's center point expressed in both reference resolutions.

    Options read from `opt`:
        lmdb_path: directory of the LMDB environment.
        scale: integer downsampling factor between GT and LQ.
        phase: 'train' enables resizing to `target_size` and the blur/JPEG augmentations.
        target_size: side length GT images are resized to in the train phase.
        use_blurring, use_compression_artifacts: augmentation switches (train phase).
        force_multiple (optional, default 1): LQ height/width are clipped to a multiple of this value.
        blur_magnitude (optional, default 3): upper bound for the blur kernel size.
        grayscale (optional): when truthy, LQ images are converted to grayscale (kept as 3 channels).
        lq_noise (optional): gaussian noise magnitude (in 1/255 units) added to the LQ tensors.
    """

    def __init__(self, opt):
        super(LmdbDatasetWithRef, self).__init__()
        self.opt = opt
        self.db = lmdb.open(self.opt['lmdb_path'], subdir=True, readonly=True, lock=False, readahead=False,
                            meminit=False)
        self.data_type = 'img'
        self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt else 1
        # __getitem__ draws a rescale factor from this list in the train phase. It was never initialised
        # before, which made every train-phase fetch fail with AttributeError; [1] means "no rescaling".
        self.random_scale_list = [1]
        # NOTE(review): pyarrow.serialize/deserialize are deprecated in recent pyarrow releases - confirm the
        # pinned pyarrow version still provides them.
        with self.db.begin(write=False) as txn:
            self.keys = pyarrow.deserialize(txn.get(b'__keys__'))
            self.len = pyarrow.deserialize(txn.get(b'__len__'))

    def motion_blur(self, image, size, angle):
        """Return `image` filtered with a normalized line kernel of `size` x `size` rotated by `angle` degrees."""
        k = np.zeros((size, size), dtype=np.float32)
        # A horizontal line through the middle row, rotated about the kernel center below.
        k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
        k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
        # Normalize so the kernel sums to 1.
        k = k * (1.0 / np.sum(k))
        return cv2.filter2D(image, -1, k)

    def resize_point(self, point, orig_dim, new_dim):
        """Rescale a (row, col) point from an image of `orig_dim` (h, w) to one of `new_dim` (h, w).

        The point is modified IN PLACE (results truncated to int) and the same object is returned, so pass a
        copy when the original coordinates are still needed.
        """
        oh, ow = orig_dim
        nh, nw = new_dim
        dh, dw = float(nh) / float(oh), float(nw) / float(ow)
        point[0] = int(dh * float(point[0]))
        point[1] = int(dw * float(point[1]))
        return point

    def augment_tile(self, img_GT, img_LQ, strength=1):
        """Resize a GT/LQ pair to the configured sizes and optionally blur the LQ image.

        Args:
            img_GT: HWC image; must be at least `target_size` in both spatial dimensions.
            img_LQ: HWC image matching `img_GT`.
            strength: multiplier applied to `blur_magnitude`.

        Returns:
            (img_GT, img_LQ) resized to `target_size` and `target_size // scale` squares respectively.
        """
        scale = self.opt['scale']
        GT_size = self.opt['target_size']

        H, W, _ = img_GT.shape
        assert H >= GT_size and W >= GT_size, \
            'GT image (%ix%i) is smaller than target_size (%i)' % (H, W, GT_size)

        LQ_size = GT_size // scale
        img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
        img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)

        if self.opt['use_blurring']:
            # Pick randomly between gaussian, motion, or no blur.
            blur_det = random.randint(0, 100)
            blur_magnitude = self.opt['blur_magnitude'] if 'blur_magnitude' in self.opt else 3
            blur_magnitude = max(1, int(blur_magnitude * strength))
            if blur_det < 40:
                blur_sig = random.randrange(0, blur_magnitude)
                # OpenCV only accepts odd, positive gaussian kernel sizes; round even magnitudes up.
                ksize = blur_magnitude | 1
                img_LQ = cv2.GaussianBlur(img_LQ, (ksize, ksize), blur_sig)
            elif blur_det < 70:
                img_LQ = self.motion_blur(img_LQ, random.randrange(1, blur_magnitude * 3), random.randint(0, 360))

        return img_GT, img_LQ

    def pil_augment(self, img_LQ, strength=1):
        """Convert `img_LQ` to a PIL image and apply JPEG-compression and grayscale corruptions.

        Args:
            img_LQ: HWC numpy image; it is multiplied by 255 and stored as uint8.
            strength: scales how far below 100 the random JPEG quality may drop.

        Returns:
            A PIL image.
        """
        # Clip before the uint8 cast so out-of-range values saturate instead of wrapping around.
        img_LQ = np.clip(img_LQ * 255, 0, 255).astype(np.uint8)
        img_LQ = Image.fromarray(img_LQ)
        if self.opt['use_compression_artifacts'] and random.random() > .25:
            # randrange needs true ints; a fractional `strength` (e.g. .2) would otherwise pass floats.
            qf_lo = int(100 - 90 * strength)
            qf_hi = int(100 - 30 * strength)
            qf = random.randrange(qf_lo, qf_hi)
            corruption_buffer = BytesIO()
            img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
            corruption_buffer.seek(0)
            img_LQ = Image.open(corruption_buffer)

        if 'grayscale' in self.opt and self.opt['grayscale']:
            img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')

        return img_LQ

    @staticmethod
    def _scaled_dim(n, random_scale, scale, thres):
        """Scale `n` by `random_scale`, floor it to a multiple of `scale` and never return less than `thres`."""
        rlt = int(n * random_scale)
        rlt = (rlt // scale) * scale
        return thres if rlt < thres else rlt

    def __getitem__(self, index):
        """Build one training/testing sample.

        Returns:
            dict with
              'GT', 'LQ': the tile and its degraded low-resolution version (CHW tensors),
              'gt_fullsize_ref', 'lq_fullsize_ref': the reference image renderings with a mask appended as an
                  extra channel,
              'gt_center', 'lq_center': the tile center rescaled to each reference rendering,
              'GT_path', 'LQ_path': the LMDB key of the tile.
        """
        scale = self.opt['scale']

        # get the hq image and the ref image
        key = self.keys[index]
        ref_key = key[:key.index('_')]
        with self.db.begin(write=False) as txn:
            bytes_ref = txn.get(ref_key.encode())
            bytes_tile = txn.get(key.encode())
            unpacked_ref = pyarrow.deserialize(bytes_ref)
            unpacked_tile = pyarrow.deserialize(bytes_tile)
        gt_fullsize_ref = unpacked_ref[0]
        img_GT, gt_center = unpacked_tile

        # TODO: synthesize gt_mask.
        # NOTE(review): np.ones defaults to float64, so the mask tensors below are float64 while the image
        # tensors are float32 - confirm downstream code expects that before changing it.
        gt_mask = np.ones(img_GT.shape[:2])
        orig_gt_dim = gt_fullsize_ref.shape[:2]

        # Synthesize LQ by downsampling.
        if self.opt['phase'] == 'train':
            GT_size = self.opt['target_size']
            random_scale = random.choice(self.random_scale_list)
            # Promote grayscale tiles to 3 channels before the shape is unpacked as (H, W, C).
            if img_GT.ndim == 2:
                img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
            H_s, W_s, _ = img_GT.shape
            H_s = self._scaled_dim(H_s, random_scale, scale, GT_size)
            W_s = self._scaled_dim(W_s, random_scale, scale, GT_size)
            img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)

        # using matlab imresize
        img_LQ = util.imresize_np(img_GT, 1 / scale, True)
        lq_fullsize_ref = util.imresize_np(gt_fullsize_ref, 1 / scale, True)
        if img_LQ.ndim == 2:
            img_LQ = np.expand_dims(img_LQ, axis=2)
        # resize_point mutates its argument, hence the copy of gt_center.
        # NOTE(review): .clone() assumes gt_center deserializes to a torch tensor - confirm.
        lq_mask, lq_center = gt_mask, self.resize_point(gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2])
        orig_lq_dim = lq_fullsize_ref.shape[:2]

        # Enforce force_resize constraints via clipping.
        h, w, _ = img_LQ.shape
        if h % self.force_multiple != 0 or w % self.force_multiple != 0:
            h, w = (h - h % self.force_multiple), (w - w % self.force_multiple)
            img_LQ = img_LQ[:h, :w, :]
            lq_fullsize_ref = lq_fullsize_ref[:h, :w, :]
            h *= scale
            w *= scale
            img_GT = img_GT[:h, :w]
            gt_fullsize_ref = gt_fullsize_ref[:h, :w, :]

        if self.opt['phase'] == 'train':
            img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
            gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)

        # Scale masks.
        lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]),
                             interpolation=cv2.INTER_LINEAR)
        gt_mask = cv2.resize(gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]),
                             interpolation=cv2.INTER_LINEAR)

        # Scale center coords
        lq_center = self.resize_point(lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2])
        gt_center = self.resize_point(gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2])

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
            lq_fullsize_ref = cv2.cvtColor(lq_fullsize_ref, cv2.COLOR_BGR2RGB)
            gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB)

        # LQ needs to go to a PIL image to perform the compression-artifact transformation.
        if self.opt['phase'] == 'train':
            img_LQ = self.pil_augment(img_LQ)
            lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)

        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
        img_LQ = F.to_tensor(img_LQ)
        lq_fullsize_ref = F.to_tensor(lq_fullsize_ref)
        lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0)
        gt_mask = torch.from_numpy(np.ascontiguousarray(gt_mask)).unsqueeze(dim=0)

        if 'lq_noise' in self.opt:
            # NOTE(review): the same noise tensor is added to both images, which requires the LQ tile and the
            # LQ reference to share a shape (true after augment_tile in the train phase) - confirm intent.
            lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
            img_LQ += lq_noise
            lq_fullsize_ref += lq_noise

        # Append the masks to the full images as an extra channel.
        gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
        lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)

        d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
             'lq_center': lq_center, 'gt_center': gt_center,
             'LQ_path': key, 'GT_path': key}
        return d

    def __len__(self):
        """Return the entry count stored under the `__len__` key of the LMDB store."""
        return self.len


if __name__ == '__main__':
    # Debug driver: dumps every tensor of each sample (from index 300 on) to ./debug as PNG files.
    import os
    import torchvision

    opt = {
        'name': 'amalgam',
        'lmdb_path': 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref',
        'use_flip': True,
        'use_compression_artifacts': True,
        'use_blurring': True,
        'use_rot': True,
        'lq_noise': 5,
        'target_size': 128,
        'min_tile_size': 256,
        'scale': 2,
        'phase': 'train'
    }
    # Alternative (test phase) configuration kept from the original script.
    # NOTE(review): it has no 'lmdb_path', which __init__ requires - confirm before using it.
    # opt = {
    #     'name': 'amalgam',
    #     'dataroot_GT': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref'],
    #     'dataroot_GT_weights': [1],
    #     'force_multiple': 32,
    #     'scale': 2,
    #     'phase': 'test'
    # }

    ds = LmdbDatasetWithRef(opt)
    os.makedirs("debug", exist_ok=True)
    for i in range(300, len(ds)):
        print(i)
        o = ds[i]
        for k, v in o.items():
            if 'path' not in k:
                # To inspect a masked reference instead, multiply v[:3] by the mask channel v[3] and save
                # that, then keep only v[:3] for the plain image.
                torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))