Finish up single_image_dataset work

Sweet!
This commit is contained in:
James Betker 2020-09-25 16:37:54 -06:00
parent 1cf73c2cce
commit ce4613ecb9
9 changed files with 325 additions and 542 deletions

View File

@ -1,210 +0,0 @@
'''
REDS dataset
support reading images from lmdb, image folder and memcached
'''
import os.path as osp
import random
import pickle
import logging
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
try:
import mc # import memcached
except ImportError:
pass
logger = logging.getLogger('base')
class REDSDataset(data.Dataset):
'''
Reading the training REDS dataset
key example: 000_00000000
GT: Ground-Truth;
LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames
support reading N LQ frames, N = 1, 3, 5, 7
'''
def __init__(self, opt):
super(REDSDataset, self).__init__()
self.opt = opt
# temporal augmentation
self.interval_list = opt['interval_list']
self.random_reverse = opt['random_reverse']
logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
','.join(str(x) for x in opt['interval_list']), self.random_reverse))
self.half_N_frames = opt['N_frames'] // 2
self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
self.data_type = self.opt['data_type']
self.LR_input = False if opt['target_size'] == opt['LQ_size'] else True # low resolution inputs
#### directly load image keys
if self.data_type == 'lmdb':
self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
logger.info('Using lmdb meta info for cache keys.')
elif opt['cache_keys']:
logger.info('Using cache keys: {}'.format(opt['cache_keys']))
self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
else:
raise ValueError(
'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
# remove the REDS4 for testing
self.paths_GT = [
v for v in self.paths_GT if v.split('_')[0] not in ['000', '011', '015', '020']
]
assert self.paths_GT, 'Error: GT path is empty.'
if self.data_type == 'lmdb':
self.GT_env, self.LQ_env = None, None
elif self.data_type == 'mc': # memcached
self.mclient = None
elif self.data_type == 'img':
pass
else:
raise ValueError('Wrong data type: {}'.format(self.data_type))
def _init_lmdb(self):
# https://github.com/chainer/chainermn/issues/129
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
meminit=False)
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
def _ensure_memcached(self):
if self.mclient is None:
# specify the config files
server_list_config_file = None
client_config_file = None
self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
client_config_file)
def _read_img_mc(self, path):
''' Return BGR, HWC, [0, 255], uint8'''
value = mc.pyvector()
self.mclient.Get(path, value)
value_buf = mc.ConvertBuffer(value)
img_array = np.frombuffer(value_buf, np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
return img
def _read_img_mc_BGR(self, path, name_a, name_b):
''' Read BGR channels separately and then combine for 1M limits in cluster'''
img_B = self._read_img_mc(osp.join(path + '_B', name_a, name_b + '.png'))
img_G = self._read_img_mc(osp.join(path + '_G', name_a, name_b + '.png'))
img_R = self._read_img_mc(osp.join(path + '_R', name_a, name_b + '.png'))
img = cv2.merge((img_B, img_G, img_R))
return img
def __getitem__(self, index):
if self.data_type == 'mc':
self._ensure_memcached()
elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
self._init_lmdb()
scale = self.opt['scale']
GT_size = self.opt['target_size']
key = self.paths_GT[index]
name_a, name_b = key.split('_')
center_frame_idx = int(name_b)
#### determine the neighbor frames
interval = random.choice(self.interval_list)
if self.opt['border_mode']:
direction = 1 # 1: forward; 0: backward
N_frames = self.opt['N_frames']
if self.random_reverse and random.random() < 0.5:
direction = random.choice([0, 1])
if center_frame_idx + interval * (N_frames - 1) > 99:
direction = 0
elif center_frame_idx - interval * (N_frames - 1) < 0:
direction = 1
# get the neighbor list
if direction == 1:
neighbor_list = list(
range(center_frame_idx, center_frame_idx + interval * N_frames, interval))
else:
neighbor_list = list(
range(center_frame_idx, center_frame_idx - interval * N_frames, -interval))
name_b = '{:08d}'.format(neighbor_list[0])
else:
# ensure not exceeding the borders
while (center_frame_idx + self.half_N_frames * interval >
99) or (center_frame_idx - self.half_N_frames * interval < 0):
center_frame_idx = random.randint(0, 99)
# get the neighbor list
neighbor_list = list(
range(center_frame_idx - self.half_N_frames * interval,
center_frame_idx + self.half_N_frames * interval + 1, interval))
if self.random_reverse and random.random() < 0.5:
neighbor_list.reverse()
name_b = '{:08d}'.format(neighbor_list[self.half_N_frames])
assert len(
neighbor_list) == self.opt['N_frames'], 'Wrong length of neighbor list: {}'.format(
len(neighbor_list))
#### get the GT image (as the center frame)
if self.data_type == 'mc':
img_GT = self._read_img_mc_BGR(self.GT_root, name_a, name_b)
img_GT = img_GT.astype(np.float32) / 255.
elif self.data_type == 'lmdb':
img_GT = util.read_img(self.GT_env, key, (3, 720, 1280))
else:
img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b + '.png'))
#### get LQ images
LQ_size_tuple = (3, 180, 320) if self.LR_input else (3, 720, 1280)
img_LQ_l = []
for v in neighbor_list:
img_LQ_path = osp.join(self.LQ_root, name_a, '{:08d}.png'.format(v))
if self.data_type == 'mc':
if self.LR_input:
img_LQ = self._read_img_mc(img_LQ_path)
else:
img_LQ = self._read_img_mc_BGR(self.LQ_root, name_a, '{:08d}'.format(v))
img_LQ = img_LQ.astype(np.float32) / 255.
elif self.data_type == 'lmdb':
img_LQ = util.read_img(self.LQ_env, '{}_{:08d}'.format(name_a, v), LQ_size_tuple)
else:
img_LQ = util.read_img(None, img_LQ_path)
img_LQ_l.append(img_LQ)
if self.opt['phase'] == 'train':
C, H, W = LQ_size_tuple # LQ size
# randomly crop
if self.LR_input:
LQ_size = GT_size // scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
else:
rnd_h = random.randint(0, max(0, H - GT_size))
rnd_w = random.randint(0, max(0, W - GT_size))
img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
# augmentation - flip, rotate
img_LQ_l.append(img_GT)
rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
img_LQ_l = rlt[0:-1]
img_GT = rlt[-1]
# stack LQ images to NHWC, N is the frame number
img_LQs = np.stack(img_LQ_l, axis=0)
# BGR to RGB, HWC to CHW, numpy to tensor
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
(0, 3, 1, 2)))).float()
return {'LQs': img_LQs, 'GT': img_GT, 'key': key}
def __len__(self):
return len(self.paths_GT)

View File

@ -1,28 +1,36 @@
import os.path as osp
from data import util
import torch
import numpy as np
# Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
class ChunkWithReference:
def __init__(self, opt, path):
self.opt = opt
self.path = path
self.path = path.path
self.ref = None # This is loaded on the fly.
self.cache_ref = opt['cache_ref'] if 'cache_ref' in opt.keys() else True
self.tiles = util.get_image_paths('img', path)
self.tiles, _ = util.get_image_paths('img', path)
self.centers = None
def __getitem__(self, item):
# Load centers on the fly and always cache.
if self.centers is None:
self.centers = torch.load(osp.join(self.path, "centers.pt"))
if self.cache_ref:
if self.ref is None:
self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"))
self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
self.centers = torch.load(osp.join(self.path, "centers.pt"))
ref = self.ref
centers = self.centers
else:
self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"))
self.centers = torch.load(osp.join(self.path, "centers.pt"))
self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
tile = util.read_img(None, self.tiles[item], rgb=True)
tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
center, tile_width = self.centers[tile_id]
mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
return self.tiles[item], ref, centers[item], path
return tile, ref, center, mask, self.tiles[item]
def __len__(self):
return len(self.tiles)
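A minimal usage sketch for the updated ChunkWithReference, assuming the class is in scope and `entry` is an os.DirEntry pointing at a chunk folder holding ref.jpg, centers.pt and numbered tile jpgs (the layout produced by the tile extractor later in this commit):
# Sketch only; `entry` and the folder contents are assumptions based on the diff above.
chunk = ChunkWithReference({'cache_ref': True}, entry)
tile, ref, center, mask, tile_path = chunk[0]  # the five-tuple returned by the new __getitem__
# tile and ref are float32 HWC RGB images in [0, 1]; mask marks the tile's footprint within
# the reference image (1.0 inside the tile region, 0.1 elsewhere).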

View File

@ -1,72 +1,92 @@
import random
import cv2
import numpy as np
from data.util import read_img
from PIL import Image
from io import BytesIO
# Performs image corruption on a list of images, using a configurable set of corruption
# options.
class ImageCorruptor:
def __init__(self, opt):
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 2
self.corruptions_enabled = opt['corruptions']
self.fixed_corruptions = opt['fixed_corruptions']
self.random_corruptions = opt['random_corruptions']
def corrupt_images(self, imgs):
augmentations = random.choice(self.corruptions_enabled, k=self.num_corrupts)
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
# Source of entropy, which should be used across all images.
rand_int = random.randint(1, 999999)
rand_int_f = random.randint(1, 999999)
rand_int_a = random.randint(1, 999999)
corrupted_imgs = []
for img in imgs:
for aug in self.fixed_corruptions:
img = self.apply_corruption(img, aug, rand_int_f)
for aug in augmentations:
if 'color_quantization' in aug:
# Color quantization
quant_div = 2 ** random.randint(1, 4)
augmentation_tensor[AUG_TENSOR_COLOR_QUANT] = float(quant_div) / 5.0
img = self.apply_corruption(img, aug, rand_int_a)
corrupted_imgs.append(img)
pass
elif 'gaussian_blur' in aug:
# Gaussian Blur
kernel = random.randint(1, 3) * 3
image = cv2.GaussianBlur(image, (kernel, kernel), 3)
augmentation_tensor[AUG_TENSOR_BLUR] = float(kernel) / 9
elif 'median_blur' in aug:
# Median Blur
kernel = random.randint(1, 3) * 3
image = cv2.medianBlur(image, kernel)
augmentation_tensor[AUG_TENSOR_BLUR] = float(kernel) / 9
elif 'motion_blur' in aug:
# Motion blur
intensity = random.randrange(1, 9)
image = self.motion_blur(image, intensity, random.randint(0, 360))
augmentation_tensor[AUG_TENSOR_BLUR] = intensity / 9
elif 'smooth_blur' in aug:
# Smooth blur
kernel = random.randint(1, 3) * 3
image = cv2.blur(image, ksize=kernel)
augmentation_tensor[AUG_TENSOR_BLUR] = kernel / 9
elif 'block_noise' in aug:
# Block noise
noise_intensity = random.randint(3, 10)
image += np.random.randn()
pass
elif 'lq_resampling' in aug:
# Bicubic LR->HR
pass
elif 'color_shift' in aug:
# Color shift
pass
elif 'interlacing' in aug:
# Interlacing distortion
pass
elif 'chromatic_aberration' in aug:
# Chromatic aberration
pass
elif 'noise' in aug:
# Noise
pass
elif 'jpeg' in aug:
# JPEG compression
pass
elif 'saturation' in aug:
# Lightening / saturation
pass
return corrupted_imgs
def apply_corruption(self, img, aug, rand_int):
if 'color_quantization' in aug:
# Color quantization
quant_div = 2 ** ((rand_int % 3) + 2)
img = img * 255
img = (img // quant_div) * quant_div
img = img / 255
elif 'gaussian_blur' in aug:
# Gaussian Blur
kernel = 2 * (rand_int % 3) + 1
img = cv2.GaussianBlur(img, (kernel, kernel), 3)
elif 'motion_blur' in aug:
# Motion blur
intensity = 2 * (rand_int % 3) + 1
angle = (rand_int // 3) % 360
k = np.zeros((intensity, intensity), dtype=np.float32)
k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
k = cv2.warpAffine(k, cv2.getRotationMatrix2D((intensity / 2 - 0.5, intensity / 2 - 0.5), angle, 1.0),
(intensity, intensity))
k = k * (1.0 / np.sum(k))
img = cv2.filter2D(img, -1, k)
elif 'smooth_blur' in aug:
# Smooth blur
kernel = 2 * (rand_int % 3) + 1
img = cv2.blur(img, ksize=(kernel, kernel))
elif 'block_noise' in aug:
# Large distortion blocks in part of an img, such as is used to mask out a face.
pass
elif 'lq_resampling' in aug:
# Bicubic LR->HR
pass
elif 'color_shift' in aug:
# Color shift
pass
elif 'interlacing' in aug:
# Interlacing distortion
pass
elif 'chromatic_aberration' in aug:
# Chromatic aberration
pass
elif 'noise' in aug:
# Gaussian noise
noise_intensity = (rand_int % 4 + 2) / 255.0  # Between 2 and 5 pixel levels
img += np.random.randn() * noise_intensity
elif 'jpeg' in aug:
# JPEG compression
qf = (rand_int % 20 + 10) # Between 10-30
# cv2's jpeg compression is "odd". It introduces artifacts. Use PIL instead.
img = (img * 255).astype(np.uint8)
img = Image.fromarray(img)
buffer = BytesIO()
img.save(buffer, "JPEG", quality=qf, optimize=True)
buffer.seek(0)
jpeg_img_bytes = np.asarray(bytearray(buffer.read()), dtype="uint8")
img = read_img("buffer", jpeg_img_bytes, rgb=True)
elif 'saturation' in aug:
# Lightening / saturation
saturation = float(rand_int % 10) * .03
img = np.clip(img + saturation, a_max=1, a_min=0)
return img
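A hedged sketch of driving the refactored corruptor directly, mirroring the option names used by the SingleImageDataset test harness later in this commit; `img` is assumed to be a float32 HWC image in [0, 1], and if the 'corruptions' key is still read in __init__, an empty list can be passed for it as well.
# Sketch only; option values are illustrative.
corruptor = ImageCorruptor({
    'fixed_corruptions': ['jpeg'],                # always applied, in order
    'random_corruptions': ['gaussian_blur', 'motion_blur', 'noise', 'saturation'],
    'num_corrupts_per_image': 1                   # random corruptions sampled per call
})
corrupted = corruptor.corrupt_images([img])[0]    # corrupt_images returns a list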

View File

@ -1,231 +0,0 @@
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image, ImageOps
from io import BytesIO
import torchvision.transforms.functional as F
import lmdb
import pyarrow
# Reads full-quality images and pulls tiles from them. Also extracts LR renderings of the full image with cues as to
# where those tiles came from.
class LmdbDatasetWithRef(data.Dataset):
def __init__(self, opt):
super(LmdbDatasetWithRef, self).__init__()
self.opt = opt
self.db = lmdb.open(self.opt['lmdb_path'], subdir=True, readonly=True, lock=False, readahead=False, meminit=False)
self.data_type = 'img'
self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
with self.db.begin(write=False) as txn:
self.keys = pyarrow.deserialize(txn.get(b'__keys__'))
self.len = pyarrow.deserialize(txn.get(b'__len__'))
def motion_blur(self, image, size, angle):
k = np.zeros((size, size), dtype=np.float32)
k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
k = k * (1.0 / np.sum(k))
return cv2.filter2D(image, -1, k)
def resize_point(self, point, orig_dim, new_dim):
oh, ow = orig_dim
nh, nw = new_dim
dh, dw = float(nh) / float(oh), float(nw) / float(ow)
point[0] = int(dh * float(point[0]))
point[1] = int(dw * float(point[1]))
return point
def augment_tile(self, img_GT, img_LQ, strength=1):
scale = self.opt['scale']
GT_size = self.opt['target_size']
H, W, _ = img_GT.shape
assert H >= GT_size and W >= GT_size
LQ_size = GT_size // scale
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
if self.opt['use_blurring']:
# Pick randomly between gaussian, motion, or no blur.
blur_det = random.randint(0, 100)
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
blur_magnitude = max(1, int(blur_magnitude*strength))
if blur_det < 40:
blur_sig = int(random.randrange(0, int(blur_magnitude)))
img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
elif blur_det < 70:
img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360))
return img_GT, img_LQ
# Converts img_LQ to PIL and performs JPG compression corruptions and grayscale on the image, then returns it.
def pil_augment(self, img_LQ, strength=1):
img_LQ = (img_LQ * 255).astype(np.uint8)
img_LQ = Image.fromarray(img_LQ)
if self.opt['use_compression_artifacts'] and random.random() > .25:
sub_lo = 90 * strength
sub_hi = 30 * strength
qf = random.randrange(100 - sub_lo, 100 - sub_hi)
corruption_buffer = BytesIO()
img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
corruption_buffer.seek(0)
img_LQ = Image.open(corruption_buffer)
if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
return img_LQ
def __getitem__(self, index):
scale = self.opt['scale']
# get the hq image and the ref image
key = self.keys[index]
ref_key = key[:key.index('_')]
with self.db.begin(write=False) as txn:
bytes_ref = txn.get(ref_key.encode())
bytes_tile = txn.get(key.encode())
unpacked_ref = pyarrow.deserialize(bytes_ref)
unpacked_tile = pyarrow.deserialize(bytes_tile)
gt_fullsize_ref = unpacked_ref[0]
img_GT, gt_center = unpacked_tile
# TODO: synthesize gt_mask.
gt_mask = np.ones(img_GT.shape[:2])
orig_gt_dim = gt_fullsize_ref.shape[:2]
# Synthesize LQ by downsampling.
if self.opt['phase'] == 'train':
GT_size = self.opt['target_size']
random_scale = random.choice(self.random_scale_list)
if len(img_GT.shape) == 2:
print("ERRAR:")
print(img_GT.shape)
print(full_path)
H_s, W_s, _ = img_GT.shape
def _mod(n, random_scale, scale, thres):
rlt = int(n * random_scale)
rlt = (rlt // scale) * scale
return thres if rlt < thres else rlt
H_s = _mod(H_s, random_scale, scale, GT_size)
W_s = _mod(W_s, random_scale, scale, GT_size)
img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
if img_GT.ndim == 2:
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
H, W, _ = img_GT.shape
# using matlab imresize
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
lq_fullsize_ref = util.imresize_np(gt_fullsize_ref, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)
lq_mask, lq_center = gt_mask, self.resize_point(gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2])
orig_lq_dim = lq_fullsize_ref.shape[:2]
# Enforce force_resize constraints via clipping.
h, w, _ = img_LQ.shape
if h % self.force_multiple != 0 or w % self.force_multiple != 0:
h, w = (h - h % self.force_multiple), (w - w % self.force_multiple)
img_LQ = img_LQ[:h, :w, :]
lq_fullsize_ref = lq_fullsize_ref[:h, :w, :]
h *= scale
w *= scale
img_GT = img_GT[:h, :w]
gt_fullsize_ref = gt_fullsize_ref[:h, :w, :]
if self.opt['phase'] == 'train':
img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)
# Scale masks.
lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
gt_mask = cv2.resize(gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
# Scale center coords
lq_center = self.resize_point(lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2])
gt_center = self.resize_point(gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2])
# BGR to RGB, HWC to CHW, numpy to tensor
if img_GT.shape[2] == 3:
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
lq_fullsize_ref = cv2.cvtColor(lq_fullsize_ref, cv2.COLOR_BGR2RGB)
gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB)
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
if self.opt['phase'] == 'train':
img_LQ = self.pil_augment(img_LQ)
lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
img_LQ = F.to_tensor(img_LQ)
lq_fullsize_ref = F.to_tensor(lq_fullsize_ref)
lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0)
gt_mask = torch.from_numpy(np.ascontiguousarray(gt_mask)).unsqueeze(dim=0)
if 'lq_noise' in self.opt.keys():
lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
img_LQ += lq_noise
lq_fullsize_ref += lq_noise
# Apply the masks to the full images.
gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)
d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
'lq_center': lq_center, 'gt_center': gt_center,
'LQ_path': key, 'GT_path': key}
return d
def __len__(self):
return self.len
if __name__ == '__main__':
opt = {
'name': 'amalgam',
'lmdb_path': 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref',
'use_flip': True,
'use_compression_artifacts': True,
'use_blurring': True,
'use_rot': True,
'lq_noise': 5,
'target_size': 128,
'min_tile_size': 256,
'scale': 2,
'phase': 'train'
}
'''
opt = {
'name': 'amalgam',
'dataroot_GT': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref'],
'dataroot_GT_weights': [1],
'force_multiple': 32,
'scale': 2,
'phase': 'test'
}
'''
ds = LmdbDatasetWithRef(opt)
import os
os.makedirs("debug", exist_ok=True)
for i in range(300, len(ds)):
print(i)
o = ds[i]
for k, v in o.items():
if 'path' not in k:
#if 'full' in k:
#masked = v[:3, :, :] * v[3]
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
#v = v[:3, :, :]
import torchvision
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))

View File

@ -5,6 +5,8 @@ import os
from bisect import bisect_left
import cv2
import torch
import numpy as np
import torchvision.transforms.functional as F
# Builds a dataset composed of a set of folders. Each folder represents a single high resolution image that has been
@ -23,8 +25,17 @@ class SingleImageDataset(data.Dataset):
self.weights = [1]
else:
self.weights = opt['weights']
# See if there is a cached directory listing and use that rather than re-scanning everything. This will greatly
# reduce startup costs.
self.chunks = []
for path, weight in zip(self.paths, self.weights):
chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()]
cache_path = os.path.join(path, 'cache.pth')
if os.path.exists(cache_path):
chunks = torch.load(cache_path)
else:
chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()]
torch.save(chunks, cache_path)
for w in range(weight):
self.chunks.extend(chunks)
@ -36,25 +47,18 @@ class SingleImageDataset(data.Dataset):
start += len(c)
self.len = start
def binary_search(elem, sorted_list):
# https://docs.python.org/3/library/bisect.html
'Locate the leftmost value exactly equal to x'
i = bisect_left(sorted_list, elem)
if i != len(sorted_list) and sorted_list[i] == elem:
return i
return -1
def resize_point(self, point, orig_dim, new_dim):
oh, ow = orig_dim
nh, nw = new_dim
dh, dw = float(nh) / float(oh), float(nw) / float(ow)
point[0] = int(dh * float(point[0]))
point[1] = int(dw * float(point[1]))
point = int(dh * float(point[0])), int(dw * float(point[1]))
return point
def __getitem__(self, item):
chunk_ind = self.binary_search(item, self.starting_indices)
hq, hq_ref, hq_center, path = self.chunks[item-self.starting_indices[chunk_ind]]
chunk_ind = bisect_left(self.starting_indices, item)
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
hq, hq_ref, hq_center, hq_mask, path = self.chunks[chunk_ind][item-self.starting_indices[chunk_ind]]
# Enforce size constraints
h, w, _ = hq.shape
@ -63,6 +67,7 @@ class SingleImageDataset(data.Dataset):
target_size = (self.target_hq_size, self.target_hq_size)
hq = cv2.resize(hq, target_size, interpolation=cv2.INTER_LINEAR)
hq_ref = cv2.resize(hq_ref, target_size, interpolation=cv2.INTER_LINEAR)
hq_mask = cv2.resize(hq_mask, target_size, interpolation=cv2.INTER_LINEAR)
hq_center = self.resize_point(hq_center, (h, w), target_size)
h, w = self.target_hq_size, self.target_hq_size
hq_multiple = self.multiple * self.scale # Multiple must apply to LQ image.
@ -71,6 +76,7 @@ class SingleImageDataset(data.Dataset):
hq_center = self.resize_point(hq_center, hq.shape[:1], (h, w))
hq = hq[:h, :w, :]
hq_ref = hq_ref[:h, :w, :]
hq_mask = hq_mask[:h, :w, :]
# Synthesize the LQ image
if self.for_eval:
@ -78,20 +84,70 @@ class SingleImageDataset(data.Dataset):
else:
lq = cv2.resize(hq, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR)
lq_ref = cv2.resize(hq_ref, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR)
lq_center = self.resize_point(hq_center, (h, w), lq.shape[:1])
lq_mask = cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR)
lq_center = self.resize_point(hq_center, (h, w), lq.shape[:2])
# Corrupt the LQ image
lq = self.corruptor.corrupt_images([lq])
lq = self.corruptor.corrupt_images([lq])[0]
# Convert to torch tensor
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hq, (2, 0, 1)))).float()
hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(hq_ref, (2, 0, 1)))).float()
lq = F.to_tensor(lq)
lq_ref = F.to_tensor(lq_ref)
hq_mask = torch.from_numpy(np.ascontiguousarray(hq_mask)).unsqueeze(dim=0)
hq_ref = torch.cat([hq_ref, hq_mask], dim=0)
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(lq, (2, 0, 1)))).float()
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(lq_ref, (2, 0, 1)))).float()
lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0)
lq_ref = torch.cat([lq_ref, lq_mask], dim=0)
return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': lq_center, 'gt_center': hq_center,
'LQ_path': path, 'GT_path': path}
def __len__(self):
return self.len
self.corruptor = ImageCorruptor(opt)
self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
self.for_eval = opt['eval'] if 'eval' in opt.keys() else False
self.scale = opt['scale'] if not self.for_eval else 1
self.paths = opt['paths']
if not isinstance(self.paths, list):
self.paths = [self.paths]
self.weights = [1]
else:
self.weights = opt['weights']
for path, weight in zip(self.paths, self.weights):
chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()]
for w in range(weight):
self.chunks.extend(chunks)
if __name__ == '__main__':
opt = {
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\images\\flickr\\testbed'],
'weights': [1],
'target_size': 128,
'force_multiple': 32,
'scale': 2,
'eval': False,
'fixed_corruptions': ['jpeg'],
'random_corruptions': ['color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise', 'saturation'],
'num_corrupts_per_image': 1
}
ds = SingleImageDataset(opt)
import os
os.makedirs("debug", exist_ok=True)
for i in range(0, len(ds)):
o = ds[i]
for k, v in o.items():
if 'path' not in k and 'center' not in k:
#if 'full' in k:
#masked = v[:3, :, :] * v[3]
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
#v = v[:3, :, :]
import torchvision
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
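The chunk lookup in __getitem__ above (bisect_left over starting_indices) is easier to see in isolation; a minimal sketch with hypothetical offsets:
from bisect import bisect_left
starting_indices = [0, 10, 25]   # hypothetical cumulative chunk start offsets
item = 12
chunk_ind = bisect_left(starting_indices, item)
chunk_ind = chunk_ind if chunk_ind < len(starting_indices) and starting_indices[chunk_ind] == item else chunk_ind - 1
print(chunk_ind, item - starting_indices[chunk_ind])   # -> 1 2: the third image of the second chunk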

View File

@ -87,13 +87,18 @@ def _read_img_lmdb(env, key, size):
return img
def read_img(env, path, size=None):
"""read image by cv2 or from lmdb
def read_img(env, path, size=None, rgb=False):
"""read image by cv2 or from lmdb or from a buffer (in which case path=buffer)
return: Numpy float32, HWC, BGR, [0,1]"""
if env is None: # img
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
else:
elif env == 'lmdb':
img = _read_img_lmdb(env, path, size)
elif env == 'buffer':
img = cv2.imdecode(path, cv2.IMREAD_UNCHANGED)
else:
raise NotImplementedError("Unsupported env: %s" % (env,))
if img is None:
print("Image error: %s" % (path,))
img = img.astype(np.float32) / 255.
@ -102,6 +107,9 @@ def read_img(env, path, size=None):
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
if rgb:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
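A minimal sketch of the two read_img branches this commit actually exercises (plain files via ChunkWithReference, and in-memory buffers via the JPEG corruption); the path is hypothetical:
# Sketch only.
import numpy as np
from data.util import read_img
img = read_img(None, 'chunk_folder/ref.jpg', rgb=True)      # file read: float32 HWC RGB in [0, 1]
raw = np.fromfile('chunk_folder/ref.jpg', dtype=np.uint8)   # bytes of an already-encoded image
img2 = read_img('buffer', raw, rgb=True)                    # decode the same image from a buffer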

View File

@ -16,17 +16,17 @@ def main():
mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs)
split_img = False
opt = {}
opt['n_thread'] = 12
opt['n_thread'] = 0
opt['compression_level'] = 90 # JPEG compression quality rating.
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If reading raw images during training, use 0 for faster IO speed.
if mode == 'single':
opt['dest'] = 'file'
opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_segments'
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_with_refs'
opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half'
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\new_tiles'
opt['crop_sz'] = [256, 512, 1024] # the size of each sub-image
opt['step'] = 256 # step of the sliding crop window
opt['step'] = [256, 512, 1024] # step of the sliding crop window
opt['thres_sz'] = 128 # size threshold
opt['resize_final_img'] = [1, .5, .25]
opt['only_resize'] = False
@ -158,7 +158,7 @@ class FileWriter:
# Writes the given reference image to the db and returns its ID.
def write_reference_image(self, ref_img, path):
ref_img, _ = ref_img # Encoded with a center point, which is irrelevant for the reference image.
ref_img, _, _ = ref_img # Encoded with a center point, which is irrelevant for the reference image.
img_name = osp.basename(path).replace(".jpg", "").replace(".png", "")
self.ref_center_points[img_name] = {}
self.save_image(img_name, "ref.jpg", ref_img)
@ -170,14 +170,18 @@ class FileWriter:
def write_tile_image(self, ref_id, tile_image):
id = self.get_next_unique_id()
ref_name = self.ref_ids_to_names[ref_id]
img, center = tile_image
self.ref_center_points[ref_name][id] = center
img, center, tile_sz = tile_image
self.ref_center_points[ref_name][id] = center, tile_sz
self.save_image(ref_name, "%08i.jpg" % (id,), img)
return id
def close(self):
def flush(self):
for ref_name, cps in self.ref_center_points.items():
torch.save(cps, osp.join(self.folder, ref_name, "centers.pt"))
self.ref_center_points = {}
def close(self):
self.flush()
class TiledDataset(data.Dataset):
def __init__(self, opt, split_mode=False):
@ -192,8 +196,9 @@ class TiledDataset(data.Dataset):
else:
return self.get(index, False, False)
def get_for_scale(self, img, split_mode, left_img, crop_sz, resize_factor):
step = self.opt['step']
def get_for_scale(self, img, split_mode, left_image, crop_sz, step, resize_factor, ref_resize_factor):
assert not left_image # Split image not yet supported, False is the default value.
thres_sz = self.opt['thres_sz']
h, w, c = img.shape
@ -215,15 +220,14 @@ class TiledDataset(data.Dataset):
for y in w_space:
index += 1
crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
center_point = (x + crop_sz // 2, y + crop_sz // 2)
# Center point needs to be resized by ref_resize_factor - since it is relative to the reference image.
center_point = (int((x + crop_sz // 2) // ref_resize_factor), int((y + crop_sz // 2) // ref_resize_factor))
crop_img = np.ascontiguousarray(crop_img)
if 'resize_final_img' in self.opt.keys():
# Resize too.
center_point = (int(center_point[0] * resize_factor), int(center_point[1] * resize_factor))
crop_img = cv2.resize(crop_img, dsize, interpolation=cv2.INTER_AREA)
success, buffer = cv2.imencode(".jpg", crop_img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
assert success
results.append((buffer, center_point))
results.append((buffer, center_point, int(crop_sz // ref_resize_factor)))
return results
def get(self, index, split_mode, left_img):
@ -255,15 +259,16 @@ class TiledDataset(data.Dataset):
tile_dim = int(self.opt['crop_sz'][0] * self.opt['resize_final_img'][0])
dsize = (tile_dim, tile_dim)
ref_resize_factor = h / tile_dim
# Reference image should always be first entry in results.
ref_img = cv2.resize(img, dsize, interpolation=cv2.INTER_AREA)
success, ref_buffer = cv2.imencode(".jpg", ref_img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
assert success
results = [(ref_buffer, (-1,-1))]
results = [(ref_buffer, (-1,-1), (-1,-1))]
for crop_sz, resize_factor in zip(self.opt['crop_sz'], self.opt['resize_final_img']):
results.extend(self.get_for_scale(img, split_mode, left_img, crop_sz, resize_factor))
for crop_sz, resize_factor, step in zip(self.opt['crop_sz'], self.opt['resize_final_img'], self.opt['step']):
results.extend(self.get_for_scale(img, split_mode, left_img, crop_sz, step, resize_factor, ref_resize_factor))
return results, path
def __len__(self):
@ -286,6 +291,7 @@ def extract_single(opt, writer, split_img=False):
ref_id = writer.write_reference_image(imgs[0], path)
for tile in imgs[1:]:
writer.write_tile_image(ref_id, tile)
writer.flush()
writer.close()
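For reference, a hedged sketch of the centers.pt written by FileWriter.flush() after this change (the folder path is hypothetical): each entry now stores both the tile's center point and its edge length, which ChunkWithReference unpacks as `center, tile_width`.
# Sketch only.
import torch
centers = torch.load('new_tiles/some_ref_image/centers.pt')   # dict: tile id -> (center, tile size)
tile_id = next(iter(centers))
center, tile_sz = centers[tile_id]   # center appears to be (row, col) in reference-image coordinates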

View File

@ -238,6 +238,128 @@ class SPSRNet(nn.Module):
return x_out_branch, x_out, x_grad
class SwitchedSpsr(nn.Module):
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
super(SwitchedSpsr, self).__init__()
n_upscale = int(math.log(upscale, 2))
# switch options
transformation_filters = nf
switch_filters = nf
switch_reductions = 3
switch_processing_layers = 2
self.transformation_counts = xforms
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
switch_processing_layers, self.transformation_counts, use_exp2=True)
pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
transformation_filters, kernel_size=3, depth=3,
weight_init_factor=.1)
# Feature branch
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=True)
self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=True)
self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
# Grad branch
self.get_g_nopadding = ImageGradientNoPadding()
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions,
switch_processing_layers, self.transformation_counts // 2, use_exp2=True)
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts // 2, init_temp=init_temperature,
add_scalable_noise_to_transforms=True)
# Upsampling
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
self.grad_hr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
# Conv used to output grad branch shortcut.
self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)
# Conjoin branch.
# Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5),
transformation_filters, kernel_size=3, depth=4,
weight_init_factor=.1)
pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1)
self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=True)
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw]
self.attentions = None
self.init_temperature = init_temperature
self.final_temperature_step = 10000
def forward(self, x):
x_grad = self.get_g_nopadding(x)
x = self.model_fea_conv(x)
x1, a1 = self.sw1(x, True)
x2, a2 = self.sw2(x1, True)
x_fea = self.feature_lr_conv(x2)
x_fea = self.feature_hr_conv2(x_fea)
x_b_fea = self.b_fea_conv(x_grad)
x_grad, a3 = self.sw_grad(x_b_fea, att_in=torch.cat([x1, x_b_fea], dim=1), output_attention_weights=True)
x_grad = self.grad_lr_conv(x_grad)
x_grad = self.grad_hr_conv(x_grad)
x_out_branch = self.upsample_grad(x_grad)
x_out_branch = self.grad_branch_output_conv(x_out_branch)
x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1)
x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=x_fea, identity=x_fea, output_attention_weights=True)
x_out = self.final_lr_conv(x__branch_pretrain_cat)
x_out = self.upsample(x_out)
x_out = self.final_hr_conv1(x_out)
x_out = self.final_hr_conv2(x_out)
self.attentions = [a1, a2, a3, a4]
return x_out_branch, x_out, x_grad
def set_temperature(self, temp):
[sw.set_temperature(temp) for sw in self.switches]
def update_for_step(self, step, experiments_path='.'):
if self.attentions:
temp = max(1, 1 + self.init_temperature *
(self.final_temperature_step - step) / self.final_temperature_step)
self.set_temperature(temp)
if step % 200 == 0:
output_path = os.path.join(experiments_path, "attention_maps", "a%i")
prefix = "attention_map_%i_%%i.png" % (step,)
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
def get_debug_values(self, step):
temp = self.switches[0].switch.temperature
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
means = [i[0] for i in mean_hists]
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
val = {"switch_temperature": temp}
for i in range(len(means)):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
return val
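A hedged smoke-test sketch for the new SwitchedSpsr, assuming the repo's building blocks (ConfigurableSwitchComputer, ConvGnLelu, etc.) are importable and the class is in scope:
# Sketch only; shapes assume each UpconvBlock upsamples 2x.
import torch
net = SwitchedSpsr(in_nc=3, out_nc=3, nf=64, xforms=8, upscale=4, init_temperature=10)
branch, out, grad = net(torch.randn(1, 3, 32, 32))
# branch and out should be 3-channel images at 4x resolution (1, 3, 128, 128);
# grad is the pre-upsample nf-channel gradient feature map.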
class RefJoiner(nn.Module):
def __init__(self, nf):
super(RefJoiner, self).__init__()

View File

@ -49,6 +49,10 @@ def define_G(opt, net_key='network_G', scale=None):
elif which_model == 'spsr_net_improved':
netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == "spsr_switched":
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
elif which_model == "spsr_switched_with_ref2" or which_model == "spsr3":
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = spsr.SwitchedSpsrWithRef2(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
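For context, a hypothetical network_G options block that would route through the new 'spsr_switched' branch; key names other than those read in the diff above (nf, num_transforms, scale, temperature) are assumptions.
# Sketch only.
opt_net = {
    'which_model_G': 'spsr_switched',   # assumed dispatch key consumed as which_model
    'nf': 64,
    'num_transforms': 8,                # optional; defaults to 8
    'scale': 4,
    'temperature': 10,                  # optional initial switch temperature; defaults to 10
}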