Finish up single_image_dataset work
Sweet!
This commit is contained in:
parent 1cf73c2cce
commit ce4613ecb9
@@ -1,210 +0,0 @@
'''
REDS dataset
support reading images from lmdb, image folder and memcached
'''
import os.path as osp
import random
import pickle
import logging
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
try:
    import mc  # import memcached
except ImportError:
    pass

logger = logging.getLogger('base')


class REDSDataset(data.Dataset):
    '''
    Reading the training REDS dataset
    key example: 000_00000000
    GT: Ground-Truth;
    LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames
    support reading N LQ frames, N = 1, 3, 5, 7
    '''

    def __init__(self, opt):
        super(REDSDataset, self).__init__()
        self.opt = opt
        # temporal augmentation
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
            ','.join(str(x) for x in opt['interval_list']), self.random_reverse))

        self.half_N_frames = opt['N_frames'] // 2
        self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
        self.data_type = self.opt['data_type']
        self.LR_input = False if opt['target_size'] == opt['LQ_size'] else True  # low resolution inputs
        #### directly load image keys
        if self.data_type == 'lmdb':
            self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
            logger.info('Using lmdb meta info for cache keys.')
        elif opt['cache_keys']:
            logger.info('Using cache keys: {}'.format(opt['cache_keys']))
            self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
        else:
            raise ValueError(
                'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')

        # remove the REDS4 for testing
        self.paths_GT = [
            v for v in self.paths_GT if v.split('_')[0] not in ['000', '011', '015', '020']
        ]
        assert self.paths_GT, 'Error: GT path is empty.'

        if self.data_type == 'lmdb':
            self.GT_env, self.LQ_env = None, None
        elif self.data_type == 'mc':  # memcached
            self.mclient = None
        elif self.data_type == 'img':
            pass
        else:
            raise ValueError('Wrong data type: {}'.format(self.data_type))

    def _init_lmdb(self):
        # https://github.com/chainer/chainermn/issues/129
        self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
                                meminit=False)
        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
                                meminit=False)

    def _ensure_memcached(self):
        if self.mclient is None:
            # specify the config files
            server_list_config_file = None
            client_config_file = None
            self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
                                                          client_config_file)

    def _read_img_mc(self, path):
        ''' Return BGR, HWC, [0, 255], uint8'''
        value = mc.pyvector()
        self.mclient.Get(path, value)
        value_buf = mc.ConvertBuffer(value)
        img_array = np.frombuffer(value_buf, np.uint8)
        img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
        return img

    def _read_img_mc_BGR(self, path, name_a, name_b):
        ''' Read BGR channels separately and then combine for 1M limits in cluster'''
        img_B = self._read_img_mc(osp.join(path + '_B', name_a, name_b + '.png'))
        img_G = self._read_img_mc(osp.join(path + '_G', name_a, name_b + '.png'))
        img_R = self._read_img_mc(osp.join(path + '_R', name_a, name_b + '.png'))
        img = cv2.merge((img_B, img_G, img_R))
        return img

    def __getitem__(self, index):
        if self.data_type == 'mc':
            self._ensure_memcached()
        elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
            self._init_lmdb()

        scale = self.opt['scale']
        GT_size = self.opt['target_size']
        key = self.paths_GT[index]
        name_a, name_b = key.split('_')
        center_frame_idx = int(name_b)

        #### determine the neighbor frames
        interval = random.choice(self.interval_list)
        if self.opt['border_mode']:
            direction = 1  # 1: forward; 0: backward
            N_frames = self.opt['N_frames']
            if self.random_reverse and random.random() < 0.5:
                direction = random.choice([0, 1])
            if center_frame_idx + interval * (N_frames - 1) > 99:
                direction = 0
            elif center_frame_idx - interval * (N_frames - 1) < 0:
                direction = 1
            # get the neighbor list
            if direction == 1:
                neighbor_list = list(
                    range(center_frame_idx, center_frame_idx + interval * N_frames, interval))
            else:
                neighbor_list = list(
                    range(center_frame_idx, center_frame_idx - interval * N_frames, -interval))
            name_b = '{:08d}'.format(neighbor_list[0])
        else:
            # ensure not exceeding the borders
            while (center_frame_idx + self.half_N_frames * interval >
                   99) or (center_frame_idx - self.half_N_frames * interval < 0):
                center_frame_idx = random.randint(0, 99)
            # get the neighbor list
            neighbor_list = list(
                range(center_frame_idx - self.half_N_frames * interval,
                      center_frame_idx + self.half_N_frames * interval + 1, interval))
            if self.random_reverse and random.random() < 0.5:
                neighbor_list.reverse()
            name_b = '{:08d}'.format(neighbor_list[self.half_N_frames])

        assert len(
            neighbor_list) == self.opt['N_frames'], 'Wrong length of neighbor list: {}'.format(
                len(neighbor_list))

        #### get the GT image (as the center frame)
        if self.data_type == 'mc':
            img_GT = self._read_img_mc_BGR(self.GT_root, name_a, name_b)
            img_GT = img_GT.astype(np.float32) / 255.
        elif self.data_type == 'lmdb':
            img_GT = util.read_img(self.GT_env, key, (3, 720, 1280))
        else:
            img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b + '.png'))

        #### get LQ images
        LQ_size_tuple = (3, 180, 320) if self.LR_input else (3, 720, 1280)
        img_LQ_l = []
        for v in neighbor_list:
            img_LQ_path = osp.join(self.LQ_root, name_a, '{:08d}.png'.format(v))
            if self.data_type == 'mc':
                if self.LR_input:
                    img_LQ = self._read_img_mc(img_LQ_path)
                else:
                    img_LQ = self._read_img_mc_BGR(self.LQ_root, name_a, '{:08d}'.format(v))
                img_LQ = img_LQ.astype(np.float32) / 255.
            elif self.data_type == 'lmdb':
                img_LQ = util.read_img(self.LQ_env, '{}_{:08d}'.format(name_a, v), LQ_size_tuple)
            else:
                img_LQ = util.read_img(None, img_LQ_path)
            img_LQ_l.append(img_LQ)

        if self.opt['phase'] == 'train':
            C, H, W = LQ_size_tuple  # LQ size
            # randomly crop
            if self.LR_input:
                LQ_size = GT_size // scale
                rnd_h = random.randint(0, max(0, H - LQ_size))
                rnd_w = random.randint(0, max(0, W - LQ_size))
                img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
                rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
                img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
            else:
                rnd_h = random.randint(0, max(0, H - GT_size))
                rnd_w = random.randint(0, max(0, W - GT_size))
                img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
                img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]

            # augmentation - flip, rotate
            img_LQ_l.append(img_GT)
            rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
            img_LQ_l = rlt[0:-1]
            img_GT = rlt[-1]

        # stack LQ images to NHWC, N is the frame number
        img_LQs = np.stack(img_LQ_l, axis=0)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_GT = img_GT[:, :, [2, 1, 0]]
        img_LQs = img_LQs[:, :, :, [2, 1, 0]]
        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
                                                                     (0, 3, 1, 2)))).float()
        return {'LQs': img_LQs, 'GT': img_GT, 'key': key}

    def __len__(self):
        return len(self.paths_GT)
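The deleted REDS loader's core logic is the temporal sampling above: pick a random frame interval, re-roll the center frame until the whole window of N_frames neighbors fits inside the clip's 0-99 frame range, and optionally reverse the order. A standalone sketch of that selection (assuming the 100-frames-per-clip convention used above):

import random

def sample_neighbors(center_idx, n_frames=5, interval_list=(1, 2), random_reverse=True):
    half = n_frames // 2
    interval = random.choice(interval_list)
    # Re-roll the center frame until the whole window fits in [0, 99].
    while center_idx + half * interval > 99 or center_idx - half * interval < 0:
        center_idx = random.randint(0, 99)
    neighbors = list(range(center_idx - half * interval,
                           center_idx + half * interval + 1, interval))
    if random_reverse and random.random() < 0.5:
        neighbors.reverse()
    return neighbors

print(sample_neighbors(50))  # e.g. [48, 49, 50, 51, 52] or its reverse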
@@ -1,28 +1,36 @@
import os.path as osp
from data import util
import torch
import numpy as np


# Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
class ChunkWithReference:
    def __init__(self, opt, path):
        self.opt = opt
        self.path = path
        self.path = path.path
        self.ref = None  # This is loaded on the fly.
        self.cache_ref = opt['cache_ref'] if 'cache_ref' in opt.keys() else True
        self.tiles = util.get_image_paths('img', path)
        self.tiles, _ = util.get_image_paths('img', path)
        self.centers = None

    def __getitem__(self, item):
        # Load centers on the fly and always cache.
        if self.centers is None:
            self.centers = torch.load(osp.join(self.path, "centers.pt"))
        if self.cache_ref:
            if self.ref is None:
                self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"))
                self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
                self.centers = torch.load(osp.join(self.path, "centers.pt"))
            ref = self.ref
            centers = self.centers
        else:
            self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"))
            self.centers = torch.load(osp.join(self.path, "centers.pt"))
            self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
        tile = util.read_img(None, self.tiles[item], rgb=True)
        tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
        center, tile_width = self.centers[tile_id]
        mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
        mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1

        return self.tiles[item], ref, centers[item], path
        return tile, ref, center, mask, self.tiles[item]

    def __len__(self):
        return len(self.tiles)
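Each chunk directory is expected to hold the tile JPEGs plus a ref.jpg and a centers.pt written by the tiling script changed later in this commit. A rough usage sketch (the import path and data directory are hypothetical; os.scandir yields the DirEntry objects the constructor now expects):

import os
from data.chunk_with_reference import ChunkWithReference  # hypothetical import path

opt = {'cache_ref': True}
# Each subdirectory of the dataset root is one chunk.
chunks = [ChunkWithReference(opt, d) for d in os.scandir('/data/tiles') if d.is_dir()]
tile, ref, center, mask, tile_path = chunks[0][0]  # the new five-element return tuple
print(tile.shape, mask.shape, center)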
@@ -1,72 +1,92 @@
import random

import cv2
import numpy as np
from data.util import read_img
from PIL import Image
from io import BytesIO


# Performs image corruption on a list of images from a configurable set of corruption
# options.
class ImageCorruptor:
    def __init__(self, opt):
        self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 2
        self.corruptions_enabled = opt['corruptions']
        self.fixed_corruptions = opt['fixed_corruptions']
        self.random_corruptions = opt['random_corruptions']

    def corrupt_images(self, imgs):
        augmentations = random.choice(self.corruptions_enabled, k=self.num_corrupts)
        augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
        # Source of entropy, which should be used across all images.
        rand_int = random.randint(1, 999999)
        rand_int_f = random.randint(1, 999999)
        rand_int_a = random.randint(1, 999999)

        corrupted_imgs = []
        for img in imgs:
            for aug in self.fixed_corruptions:
                img = self.apply_corruption(img, aug, rand_int_f)
            for aug in augmentations:
                if 'color_quantization' in aug:
                    # Color quantization
                    quant_div = 2 ** random.randint(1, 4)
                    augmentation_tensor[AUG_TENSOR_COLOR_QUANT] = float(quant_div) / 5.0
                img = self.apply_corruption(img, aug, rand_int_a)
            corrupted_imgs.append(img)

                    pass
                elif 'gaussian_blur' in aug:
                    # Gaussian Blur
                    kernel = random.randint(1, 3) * 3
                    image = cv2.GaussianBlur(image, (kernel, kernel), 3)
                    augmentation_tensor[AUG_TENSOR_BLUR] = float(kernel) / 9
                elif 'median_blur' in aug:
                    # Median Blur
                    kernel = random.randint(1, 3) * 3
                    image = cv2.medianBlur(image, kernel)
                    augmentation_tensor[AUG_TENSOR_BLUR] = float(kernel) / 9
                elif 'motion_blur' in aug:
                    # Motion blur
                    intensity = random.randrange(1, 9)
                    image = self.motion_blur(image, intensity, random.randint(0, 360))
                    augmentation_tensor[AUG_TENSOR_BLUR] = intensity / 9
                elif 'smooth_blur' in aug:
                    # Smooth blur
                    kernel = random.randint(1, 3) * 3
                    image = cv2.blur(image, ksize=kernel)
                    augmentation_tensor[AUG_TENSOR_BLUR] = kernel / 9
                elif 'block_noise' in aug:
                    # Block noise
                    noise_intensity = random.randint(3, 10)
                    image += np.random.randn()
                    pass
                elif 'lq_resampling' in aug:
                    # Bicubic LR->HR
                    pass
                elif 'color_shift' in aug:
                    # Color shift
                    pass
                elif 'interlacing' in aug:
                    # Interlacing distortion
                    pass
                elif 'chromatic_aberration' in aug:
                    # Chromatic aberration
                    pass
                elif 'noise' in aug:
                    # Noise
                    pass
                elif 'jpeg' in aug:
                    # JPEG compression
                    pass
                elif 'saturation' in aug:
                    # Lightening / saturation
                    pass
        return corrupted_imgs

    def apply_corruption(self, img, aug, rand_int):
        if 'color_quantization' in aug:
            # Color quantization
            quant_div = 2 ** ((rand_int % 3) + 2)
            img = img * 255
            img = (img // quant_div) * quant_div
            img = img / 255
        elif 'gaussian_blur' in aug:
            # Gaussian Blur
            kernel = 2 * (rand_int % 3) + 1
            img = cv2.GaussianBlur(img, (kernel, kernel), 3)
        elif 'motion_blur' in aug:
            # Motion blur
            intensity = 2 * (rand_int % 3) + 1
            angle = (rand_int // 3) % 360
            k = np.zeros((intensity, intensity), dtype=np.float32)
            k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
            k = cv2.warpAffine(k, cv2.getRotationMatrix2D((intensity / 2 - 0.5, intensity / 2 - 0.5), angle, 1.0),
                               (intensity, intensity))
            k = k * (1.0 / np.sum(k))
            img = cv2.filter2D(img, -1, k)
        elif 'smooth_blur' in aug:
            # Smooth blur
            kernel = 2 * (rand_int % 3) + 1
            img = cv2.blur(img, ksize=(kernel, kernel))
        elif 'block_noise' in aug:
            # Large distortion blocks in part of an img, such as is used to mask out a face.
            pass
        elif 'lq_resampling' in aug:
            # Bicubic LR->HR
            pass
        elif 'color_shift' in aug:
            # Color shift
            pass
        elif 'interlacing' in aug:
            # Interlacing distortion
            pass
        elif 'chromatic_aberration' in aug:
            # Chromatic aberration
            pass
        elif 'noise' in aug:
            # Block noise
            noise_intensity = (rand_int % 4 + 2) / 255.0  # Between 1-4
            img += np.random.randn() * noise_intensity
        elif 'jpeg' in aug:
            # JPEG compression
            qf = (rand_int % 20 + 10)  # Between 10-30
            # cv2's jpeg compression is "odd". It introduces artifacts. Use PIL instead.
            img = (img * 255).astype(np.uint8)
            img = Image.fromarray(img)
            buffer = BytesIO()
            img.save(buffer, "JPEG", quality=qf, optimize=True)
            buffer.seek(0)
            jpeg_img_bytes = np.asarray(bytearray(buffer.read()), dtype="uint8")
            img = read_img("buffer", jpeg_img_bytes, rgb=True)
        elif 'saturation' in aug:
            # Lightening / saturation
            saturation = float(rand_int % 10) * .03
            img = np.clip(img + saturation, a_max=1, a_min=0)

        return img
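A minimal usage sketch for the corruptor (the module path is assumed; note that __init__ still reads the otherwise-unused 'corruptions' key, so it must be present):

import numpy as np
from data.image_corruptor import ImageCorruptor  # assumed module path

opt = {
    'corruptions': [],                   # legacy key, still read by __init__
    'fixed_corruptions': ['jpeg'],       # always applied, in order
    'random_corruptions': ['gaussian_blur', 'noise', 'saturation'],
    'num_corrupts_per_image': 1,         # drawn from random_corruptions per call
}
corruptor = ImageCorruptor(opt)
img = np.random.rand(64, 64, 3).astype(np.float32)  # HWC float in [0, 1]
corrupted = corruptor.corrupt_images([img])[0]      # same shape, corrupted copy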
@@ -1,231 +0,0 @@
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image, ImageOps
from io import BytesIO
import torchvision.transforms.functional as F
import lmdb
import pyarrow


# Reads full-quality images and pulls tiles from them. Also extracts LR renderings of the full image with cues as to
# where those tiles came from.
class LmdbDatasetWithRef(data.Dataset):

    def __init__(self, opt):
        super(LmdbDatasetWithRef, self).__init__()
        self.opt = opt
        self.db = lmdb.open(self.opt['lmdb_path'], subdir=True, readonly=True, lock=False, readahead=False, meminit=False)
        self.data_type = 'img'
        self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
        with self.db.begin(write=False) as txn:
            self.keys = pyarrow.deserialize(txn.get(b'__keys__'))
            self.len = pyarrow.deserialize(txn.get(b'__len__'))

    def motion_blur(self, image, size, angle):
        k = np.zeros((size, size), dtype=np.float32)
        k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
        k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
        k = k * (1.0 / np.sum(k))
        return cv2.filter2D(image, -1, k)

    def resize_point(self, point, orig_dim, new_dim):
        oh, ow = orig_dim
        nh, nw = new_dim
        dh, dw = float(nh) / float(oh), float(nw) / float(ow)
        point[0] = int(dh * float(point[0]))
        point[1] = int(dw * float(point[1]))
        return point

    def augment_tile(self, img_GT, img_LQ, strength=1):
        scale = self.opt['scale']
        GT_size = self.opt['target_size']

        H, W, _ = img_GT.shape
        assert H >= GT_size and W >= GT_size

        LQ_size = GT_size // scale
        img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
        img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)

        if self.opt['use_blurring']:
            # Pick randomly between gaussian, motion, or no blur.
            blur_det = random.randint(0, 100)
            blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
            blur_magnitude = max(1, int(blur_magnitude*strength))
            if blur_det < 40:
                blur_sig = int(random.randrange(0, int(blur_magnitude)))
                img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
            elif blur_det < 70:
                img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360))

        return img_GT, img_LQ

    # Converts img_LQ to PIL and performs JPG compression corruptions and grayscale on the image, then returns it.
    def pil_augment(self, img_LQ, strength=1):
        img_LQ = (img_LQ * 255).astype(np.uint8)
        img_LQ = Image.fromarray(img_LQ)
        if self.opt['use_compression_artifacts'] and random.random() > .25:
            sub_lo = 90 * strength
            sub_hi = 30 * strength
            qf = random.randrange(100 - sub_lo, 100 - sub_hi)
            corruption_buffer = BytesIO()
            img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
            corruption_buffer.seek(0)
            img_LQ = Image.open(corruption_buffer)

        if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
            img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')

        return img_LQ

    def __getitem__(self, index):
        scale = self.opt['scale']

        # get the hq image and the ref image
        key = self.keys[index]
        ref_key = key[:key.index('_')]
        with self.db.begin(write=False) as txn:
            bytes_ref = txn.get(ref_key.encode())
            bytes_tile = txn.get(key.encode())
        unpacked_ref = pyarrow.deserialize(bytes_ref)
        unpacked_tile = pyarrow.deserialize(bytes_tile)
        gt_fullsize_ref = unpacked_ref[0]
        img_GT, gt_center = unpacked_tile

        # TODO: synthesize gt_mask.
        gt_mask = np.ones(img_GT.shape[:2])
        orig_gt_dim = gt_fullsize_ref.shape[:2]

        # Synthesize LQ by downsampling.
        if self.opt['phase'] == 'train':
            GT_size = self.opt['target_size']
            random_scale = random.choice(self.random_scale_list)
            if len(img_GT.shape) == 2:
                print("ERRAR:")
                print(img_GT.shape)
                print(full_path)
            H_s, W_s, _ = img_GT.shape

            def _mod(n, random_scale, scale, thres):
                rlt = int(n * random_scale)
                rlt = (rlt // scale) * scale
                return thres if rlt < thres else rlt

            H_s = _mod(H_s, random_scale, scale, GT_size)
            W_s = _mod(W_s, random_scale, scale, GT_size)
            img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
            if img_GT.ndim == 2:
                img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)

        H, W, _ = img_GT.shape

        # using matlab imresize
        img_LQ = util.imresize_np(img_GT, 1 / scale, True)
        lq_fullsize_ref = util.imresize_np(gt_fullsize_ref, 1 / scale, True)
        if img_LQ.ndim == 2:
            img_LQ = np.expand_dims(img_LQ, axis=2)
        lq_mask, lq_center = gt_mask, self.resize_point(gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2])
        orig_lq_dim = lq_fullsize_ref.shape[:2]

        # Enforce force_resize constraints via clipping.
        h, w, _ = img_LQ.shape
        if h % self.force_multiple != 0 or w % self.force_multiple != 0:
            h, w = (h - h % self.force_multiple), (w - w % self.force_multiple)
            img_LQ = img_LQ[:h, :w, :]
            lq_fullsize_ref = lq_fullsize_ref[:h, :w, :]
            h *= scale
            w *= scale
            img_GT = img_GT[:h, :w]
            gt_fullsize_ref = gt_fullsize_ref[:h, :w, :]

        if self.opt['phase'] == 'train':
            img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
            gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)

        # Scale masks.
        lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
        gt_mask = cv2.resize(gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)

        # Scale center coords
        lq_center = self.resize_point(lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2])
        gt_center = self.resize_point(gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2])

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
            lq_fullsize_ref = cv2.cvtColor(lq_fullsize_ref, cv2.COLOR_BGR2RGB)
            gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB)

        # LQ needs to go to a PIL image to perform the compression-artifact transformation.
        if self.opt['phase'] == 'train':
            img_LQ = self.pil_augment(img_LQ)
            lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)

        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
        img_LQ = F.to_tensor(img_LQ)
        lq_fullsize_ref = F.to_tensor(lq_fullsize_ref)
        lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0)
        gt_mask = torch.from_numpy(np.ascontiguousarray(gt_mask)).unsqueeze(dim=0)

        if 'lq_noise' in self.opt.keys():
            lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
            img_LQ += lq_noise
            lq_fullsize_ref += lq_noise

        # Apply the masks to the full images.
        gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
        lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)

        d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
             'lq_center': lq_center, 'gt_center': gt_center,
             'LQ_path': key, 'GT_path': key}
        return d

    def __len__(self):
        return self.len

if __name__ == '__main__':
    opt = {
        'name': 'amalgam',
        'lmdb_path': 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref',
        'use_flip': True,
        'use_compression_artifacts': True,
        'use_blurring': True,
        'use_rot': True,
        'lq_noise': 5,
        'target_size': 128,
        'min_tile_size': 256,
        'scale': 2,
        'phase': 'train'
    }
    '''
    opt = {
        'name': 'amalgam',
        'dataroot_GT': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref'],
        'dataroot_GT_weights': [1],
        'force_multiple': 32,
        'scale': 2,
        'phase': 'test'
    }
    '''

    ds = LmdbDatasetWithRef(opt)
    import os
    os.makedirs("debug", exist_ok=True)
    for i in range(300, len(ds)):
        print(i)
        o = ds[i]
        for k, v in o.items():
            if 'path' not in k:
                #if 'full' in k:
                #masked = v[:3, :, :] * v[3]
                #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
                #v = v[:3, :, :]
                import torchvision
                torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
@@ -5,6 +5,8 @@ import os
from bisect import bisect_left
import cv2
import torch
import numpy as np
import torchvision.transforms.functional as F


# Builds a dataset composed of a set of folders. Each folder represents a single high resolution image that has been
@@ -23,8 +25,17 @@ class SingleImageDataset(data.Dataset):
            self.weights = [1]
        else:
            self.weights = opt['weights']

        # See if there is a cached directory listing and use that rather than re-scanning everything. This will greatly
        # reduce startup costs.
        self.chunks = []
        for path, weight in zip(self.paths, self.weights):
            chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()]
            cache_path = os.path.join(path, 'cache.pth')
            if os.path.exists(cache_path):
                chunks = torch.load(cache_path)
            else:
                chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()]
                torch.save(chunks, cache_path)
            for w in range(weight):
                self.chunks.extend(chunks)
@@ -36,25 +47,18 @@ class SingleImageDataset(data.Dataset):
            start += len(c)
        self.len = start

    def binary_search(elem, sorted_list):
        # https://docs.python.org/3/library/bisect.html
        'Locate the leftmost value exactly equal to x'
        i = bisect_left(sorted_list, elem)
        if i != len(sorted_list) and sorted_list[i] == elem:
            return i
        return -1

    def resize_point(self, point, orig_dim, new_dim):
        oh, ow = orig_dim
        nh, nw = new_dim
        dh, dw = float(nh) / float(oh), float(nw) / float(ow)
        point[0] = int(dh * float(point[0]))
        point[1] = int(dw * float(point[1]))
        point = int(dh * float(point[0])), int(dw * float(point[1]))
        return point

    def __getitem__(self, item):
        chunk_ind = self.binary_search(item, self.starting_indices)
        hq, hq_ref, hq_center, path = self.chunks[item-self.starting_indices[chunk_ind]]
        chunk_ind = bisect_left(self.starting_indices, item)
        chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
        hq, hq_ref, hq_center, hq_mask, path = self.chunks[chunk_ind][item-self.starting_indices[chunk_ind]]

        # Enforce size constraints
        h, w, _ = hq.shape
@@ -63,6 +67,7 @@ class SingleImageDataset(data.Dataset):
            target_size = (self.target_hq_size, self.target_hq_size)
            hq = cv2.resize(hq, target_size, interpolation=cv2.INTER_LINEAR)
            hq_ref = cv2.resize(hq_ref, target_size, interpolation=cv2.INTER_LINEAR)
            hq_mask = cv2.resize(hq_mask, target_size, interpolation=cv2.INTER_LINEAR)
            hq_center = self.resize_point(hq_center, (h, w), target_size)
            h, w = self.target_hq_size, self.target_hq_size
        hq_multiple = self.multiple * self.scale  # Multiple must apply to LQ image.
@@ -71,6 +76,7 @@ class SingleImageDataset(data.Dataset):
            hq_center = self.resize_point(hq_center, hq.shape[:1], (h, w))
            hq = hq[:h, :w, :]
            hq_ref = hq_ref[:h, :w, :]
            hq_mask = hq_mask[:h, :w, :]

        # Synthesize the LQ image
        if self.for_eval:
@@ -78,20 +84,70 @@ class SingleImageDataset(data.Dataset):
        else:
            lq = cv2.resize(hq, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR)
            lq_ref = cv2.resize(hq_ref, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR)
            lq_center = self.resize_point(hq_center, (h, w), lq.shape[:1])
            lq_mask = cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR)
            lq_center = self.resize_point(hq_center, (h, w), lq.shape[:2])

        # Corrupt the LQ image
        lq = self.corruptor.corrupt_images([lq])
        lq = self.corruptor.corrupt_images([lq])[0]

        # Convert to torch tensor
        hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hq, (2, 0, 1)))).float()
        hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(hq_ref, (2, 0, 1)))).float()
        lq = F.to_tensor(lq)
        lq_ref = F.to_tensor(lq_ref)
        hq_mask = torch.from_numpy(np.ascontiguousarray(hq_mask)).unsqueeze(dim=0)
        hq_ref = torch.cat([hq_ref, hq_mask], dim=0)
        lq = torch.from_numpy(np.ascontiguousarray(np.transpose(lq, (2, 0, 1)))).float()
        lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(lq_ref, (2, 0, 1)))).float()
        lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0)
        lq_ref = torch.cat([lq_ref, lq_mask], dim=0)

        return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
                'lq_center': lq_center, 'gt_center': hq_center,
                'LQ_path': path, 'GT_path': path}

    def __len__(self):
        return self.len


        self.corruptor = ImageCorruptor(opt)
        self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
        self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
        self.for_eval = opt['eval'] if 'eval' in opt.keys() else False
        self.scale = opt['scale'] if not self.for_eval else 1
        self.paths = opt['paths']
        if not isinstance(self.paths, list):
            self.paths = [self.paths]
            self.weights = [1]
        else:
            self.weights = opt['weights']
        for path, weight in zip(self.paths, self.weights):
            chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()]
            for w in range(weight):
                self.chunks.extend(chunks)

if __name__ == '__main__':
    opt = {
        'name': 'amalgam',
        'paths': ['F:\\4k6k\\datasets\\images\\flickr\\testbed'],
        'weights': [1],
        'target_size': 128,
        'force_multiple': 32,
        'scale': 2,
        'eval': False,
        'fixed_corruptions': ['jpeg'],
        'random_corruptions': ['color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise', 'saturation'],
        'num_corrupts_per_image': 1
    }

    ds = SingleImageDataset(opt)
    import os
    os.makedirs("debug", exist_ok=True)
    for i in range(0, len(ds)):
        o = ds[i]
        for k, v in o.items():
            if 'path' not in k and 'center' not in k:
                #if 'full' in k:
                #masked = v[:3, :, :] * v[3]
                #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
                #v = v[:3, :, :]
                import torchvision
                torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
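The bisect_left lookup above replaces the hand-rolled binary_search: it maps a global item index to a chunk plus a local index via starting_indices, stepping back one slot unless the index lands exactly on a chunk boundary. A self-contained sketch of that lookup, with assumed chunk lengths of 3, 5 and 2:

from bisect import bisect_left

starting_indices = [0, 3, 8]  # chunks of length 3, 5 and 2 -> 10 items total

def locate(item):
    i = bisect_left(starting_indices, item)
    if i == len(starting_indices) or starting_indices[i] != item:
        i -= 1  # step back unless item is exactly a chunk's first element
    return i, item - starting_indices[i]

assert locate(0) == (0, 0)
assert locate(4) == (1, 1)
assert locate(9) == (2, 1)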
@@ -87,13 +87,18 @@ def _read_img_lmdb(env, key, size):
    return img


def read_img(env, path, size=None):
    """read image by cv2 or from lmdb
def read_img(env, path, size=None, rgb=False):
    """read image by cv2 or from lmdb or from a buffer (in which case path=buffer)
    return: Numpy float32, HWC, BGR, [0,1]"""
    if env is None:  # img
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    else:
    elif env == 'lmdb':
        img = _read_img_lmdb(env, path, size)
    elif env == 'buffer':
        img = cv2.imdecode(path, cv2.IMREAD_UNCHANGED)
    else:
        raise NotImplementedError("Unsupported env: %s" % (env,))

    if img is None:
        print("Image error: %s" % (path,))
    img = img.astype(np.float32) / 255.
@@ -102,6 +107,9 @@ def read_img(env, path, size=None):
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]

    if rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
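With this change, env selects the source: None reads a file from disk, 'buffer' decodes an in-memory byte array (the path argument carries the bytes), and lmdb reads keep their existing path. A quick sketch of the buffer route, mirroring how ImageCorruptor round-trips JPEGs above:

import cv2
import numpy as np
from data.util import read_img

raw = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)
ok, jpg = cv2.imencode(".jpg", raw)  # jpg is a uint8 array of encoded bytes
assert ok
img = read_img('buffer', jpg, rgb=True)
print(img.shape, img.dtype)          # (32, 32, 3) float32 in [0, 1], RGB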
@@ -16,17 +16,17 @@ def main():
    mode = 'single'  # single (one input folder) | pair (extract corresponding GT and LR pairs)
    split_img = False
    opt = {}
    opt['n_thread'] = 12
    opt['n_thread'] = 0
    opt['compression_level'] = 90  # JPEG compression quality rating.
    # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
    # compression time. If read raw images during training, use 0 for faster IO speed.

    if mode == 'single':
        opt['dest'] = 'file'
        opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_segments'
        opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_with_refs'
        opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half'
        opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\new_tiles'
        opt['crop_sz'] = [256, 512, 1024]  # the size of each sub-image
        opt['step'] = 256  # step of the sliding crop window
        opt['step'] = [256, 512, 1024]  # step of the sliding crop window
        opt['thres_sz'] = 128  # size threshold
        opt['resize_final_img'] = [1, .5, .25]
        opt['only_resize'] = False
@@ -158,7 +158,7 @@ class FileWriter:

    # Writes the given reference image to the db and returns its ID.
    def write_reference_image(self, ref_img, path):
        ref_img, _ = ref_img  # Encoded with a center point, which is irrelevant for the reference image.
        ref_img, _, _ = ref_img  # Encoded with a center point, which is irrelevant for the reference image.
        img_name = osp.basename(path).replace(".jpg", "").replace(".png", "")
        self.ref_center_points[img_name] = {}
        self.save_image(img_name, "ref.jpg", ref_img)
@@ -170,14 +170,18 @@ class FileWriter:
    def write_tile_image(self, ref_id, tile_image):
        id = self.get_next_unique_id()
        ref_name = self.ref_ids_to_names[ref_id]
        img, center = tile_image
        self.ref_center_points[ref_name][id] = center
        img, center, tile_sz = tile_image
        self.ref_center_points[ref_name][id] = center, tile_sz
        self.save_image(ref_name, "%08i.jpg" % (id,), img)
        return id

    def close(self):
    def flush(self):
        for ref_name, cps in self.ref_center_points.items():
            torch.save(cps, osp.join(self.folder, ref_name, "centers.pt"))
        self.ref_center_points = {}

    def close(self):
        self.flush()

class TiledDataset(data.Dataset):
    def __init__(self, opt, split_mode=False):
@@ -192,8 +196,9 @@ class TiledDataset(data.Dataset):
        else:
            return self.get(index, False, False)

    def get_for_scale(self, img, split_mode, left_img, crop_sz, resize_factor):
        step = self.opt['step']
    def get_for_scale(self, img, split_mode, left_image, crop_sz, step, resize_factor, ref_resize_factor):
        assert not left_image  # Split image not yet supported, False is the default value.

        thres_sz = self.opt['thres_sz']

        h, w, c = img.shape
@@ -215,15 +220,14 @@ class TiledDataset(data.Dataset):
            for y in w_space:
                index += 1
                crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
                center_point = (x + crop_sz // 2, y + crop_sz // 2)
                # Center point needs to be resized by ref_resize_factor - since it is relative to the reference image.
                center_point = (int((x + crop_sz // 2) // ref_resize_factor), int((y + crop_sz // 2) // ref_resize_factor))
                crop_img = np.ascontiguousarray(crop_img)
                if 'resize_final_img' in self.opt.keys():
                    # Resize too.
                    center_point = (int(center_point[0] * resize_factor), int(center_point[1] * resize_factor))
                    crop_img = cv2.resize(crop_img, dsize, interpolation=cv2.INTER_AREA)
                success, buffer = cv2.imencode(".jpg", crop_img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
                assert success
                results.append((buffer, center_point))
                results.append((buffer, center_point, int(crop_sz // ref_resize_factor)))
        return results

    def get(self, index, split_mode, left_img):
@@ -255,15 +259,16 @@ class TiledDataset(data.Dataset):

        tile_dim = int(self.opt['crop_sz'][0] * self.opt['resize_final_img'][0])
        dsize = (tile_dim, tile_dim)
        ref_resize_factor = h / tile_dim

        # Reference image should always be first entry in results.
        ref_img = cv2.resize(img, dsize, interpolation=cv2.INTER_AREA)
        success, ref_buffer = cv2.imencode(".jpg", ref_img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
        assert success
        results = [(ref_buffer, (-1,-1))]
        results = [(ref_buffer, (-1,-1), (-1,-1))]

        for crop_sz, resize_factor in zip(self.opt['crop_sz'], self.opt['resize_final_img']):
            results.extend(self.get_for_scale(img, split_mode, left_img, crop_sz, resize_factor))
        for crop_sz, resize_factor, step in zip(self.opt['crop_sz'], self.opt['resize_final_img'], self.opt['step']):
            results.extend(self.get_for_scale(img, split_mode, left_img, crop_sz, step, resize_factor, ref_resize_factor))
        return results, path

    def __len__(self):
@@ -286,6 +291,7 @@ def extract_single(opt, writer, split_img=False):
    ref_id = writer.write_reference_image(imgs[0], path)
    for tile in imgs[1:]:
        writer.write_tile_image(ref_id, tile)
    writer.flush()
    writer.close()
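The new bookkeeping divides tile centers and sizes by ref_resize_factor so they are expressed in the reference image's coordinate system. A worked example with assumed values:

h = 4096                          # source image height
tile_dim = int(256 * 1)           # crop_sz[0] * resize_final_img[0]
ref_resize_factor = h / tile_dim  # 16.0
x, crop_sz = 1024, 256
center = int((x + crop_sz // 2) // ref_resize_factor)  # 72, in 256px ref coords
tile_sz = int(crop_sz // ref_resize_factor)            # 16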
@@ -238,6 +238,128 @@ class SPSRNet(nn.Module):
        return x_out_branch, x_out, x_grad


class SwitchedSpsr(nn.Module):
    def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
        super(SwitchedSpsr, self).__init__()
        n_upscale = int(math.log(upscale, 2))

        # switch options
        transformation_filters = nf
        switch_filters = nf
        switch_reductions = 3
        switch_processing_layers = 2
        self.transformation_counts = xforms
        multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
                                        switch_processing_layers, self.transformation_counts, use_exp2=True)
        pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
        transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
                                         transformation_filters, kernel_size=3, depth=3,
                                         weight_init_factor=.1)

        # Feature branch
        self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
        self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                              pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                              attention_norm=True,
                                              transform_count=self.transformation_counts, init_temp=init_temperature,
                                              add_scalable_noise_to_transforms=True)
        self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                              pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                              attention_norm=True,
                                              transform_count=self.transformation_counts, init_temp=init_temperature,
                                              add_scalable_noise_to_transforms=True)
        self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
        self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)

        # Grad branch
        self.get_g_nopadding = ImageGradientNoPadding()
        self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
        mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions,
                                       switch_processing_layers, self.transformation_counts // 2, use_exp2=True)
        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
                                                  pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                  attention_norm=True,
                                                  transform_count=self.transformation_counts // 2, init_temp=init_temperature,
                                                  add_scalable_noise_to_transforms=True)
        # Upsampling
        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
        self.grad_hr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
        # Conv used to output grad branch shortcut.
        self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)

        # Conjoin branch.
        # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
        transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5),
                                             transformation_filters, kernel_size=3, depth=4,
                                             weight_init_factor=.1)
        pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1)
        self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                              pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat,
                                                              attention_norm=True,
                                                              transform_count=self.transformation_counts, init_temp=init_temperature,
                                                              add_scalable_noise_to_transforms=True)
        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
        self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
        self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw]
        self.attentions = None
        self.init_temperature = init_temperature
        self.final_temperature_step = 10000

    def forward(self, x):
        x_grad = self.get_g_nopadding(x)
        x = self.model_fea_conv(x)

        x1, a1 = self.sw1(x, True)
        x2, a2 = self.sw2(x1, True)
        x_fea = self.feature_lr_conv(x2)
        x_fea = self.feature_hr_conv2(x_fea)

        x_b_fea = self.b_fea_conv(x_grad)
        x_grad, a3 = self.sw_grad(x_b_fea, att_in=torch.cat([x1, x_b_fea], dim=1), output_attention_weights=True)
        x_grad = self.grad_lr_conv(x_grad)
        x_grad = self.grad_hr_conv(x_grad)
        x_out_branch = self.upsample_grad(x_grad)
        x_out_branch = self.grad_branch_output_conv(x_out_branch)

        x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1)
        x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=x_fea, identity=x_fea, output_attention_weights=True)
        x_out = self.final_lr_conv(x__branch_pretrain_cat)
        x_out = self.upsample(x_out)
        x_out = self.final_hr_conv1(x_out)
        x_out = self.final_hr_conv2(x_out)

        self.attentions = [a1, a2, a3, a4]

        return x_out_branch, x_out, x_grad

    def set_temperature(self, temp):
        [sw.set_temperature(temp) for sw in self.switches]

    def update_for_step(self, step, experiments_path='.'):
        if self.attentions:
            temp = max(1, 1 + self.init_temperature *
                       (self.final_temperature_step - step) / self.final_temperature_step)
            self.set_temperature(temp)
            if step % 200 == 0:
                output_path = os.path.join(experiments_path, "attention_maps", "a%i")
                prefix = "attention_map_%i_%%i.png" % (step,)
                [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]

    def get_debug_values(self, step):
        temp = self.switches[0].switch.temperature
        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
        means = [i[0] for i in mean_hists]
        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
        val = {"switch_temperature": temp}
        for i in range(len(means)):
            val["switch_%i_specificity" % (i,)] = means[i]
            val["switch_%i_histogram" % (i,)] = hists[i]
        return val


class RefJoiner(nn.Module):
    def __init__(self, nf):
        super(RefJoiner, self).__init__()
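update_for_step anneals the switch temperature linearly from 1 + init_temperature at step 0 down to 1 at final_temperature_step, gradually sharpening each switch's soft attention toward hard selection. The schedule in isolation:

def switch_temperature(step, init_temperature=10, final_temperature_step=10000):
    # Linear decay, clamped at 1 once the final step is reached.
    return max(1, 1 + init_temperature * (final_temperature_step - step) / final_temperature_step)

print(switch_temperature(0))      # 11.0
print(switch_temperature(5000))   # 6.0
print(switch_temperature(10000))  # 1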
@@ -49,6 +49,10 @@ def define_G(opt, net_key='network_G', scale=None):
    elif which_model == 'spsr_net_improved':
        netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                                      nb=opt_net['nb'], upscale=opt_net['scale'])
    elif which_model == "spsr_switched":
        xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
        netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
                                 init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
    elif which_model == "spsr_switched_with_ref2" or which_model == "spsr3":
        xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
        netG = spsr.SwitchedSpsrWithRef2(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
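For reference, a sketch of the generator options that would route into the new spsr_switched branch (the which_model_G selector key follows this codebase's usual opt layout and is an assumption; only nf and scale are required by the code above):

opt_net = {
    'which_model_G': 'spsr_switched',  # assumed selector key
    'nf': 64,
    'scale': 4,
    'num_transforms': 8,  # optional, defaults to 8
    'temperature': 10,    # optional, defaults to 10
}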