Change GT_size to target_size
This commit is contained in:
parent
cc834bd5a3
commit
af5dfaa90d
127
codes/data/GTLQ_dataset.py
Normal file
127
codes/data/GTLQ_dataset.py
Normal file
|
@ -0,0 +1,127 @@
|
|||
import random
|
||||
import numpy as np
|
||||
import cv2
|
||||
import lmdb
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
import data.util as util
|
||||
|
||||
|
||||
class LQGTDataset(data.Dataset):
|
||||
"""
|
||||
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
|
||||
If only GT images are provided, generate LQ images on-the-fly.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(LQGTDataset, self).__init__()
|
||||
self.opt = opt
|
||||
self.data_type = self.opt['data_type']
|
||||
self.paths_LQ, self.paths_GT = None, None
|
||||
self.sizes_LQ, self.sizes_GT = None, None
|
||||
self.LQ_env, self.GT_env = None, None # environments for lmdb
|
||||
|
||||
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
||||
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
|
||||
assert self.paths_GT, 'Error: GT path is empty.'
|
||||
if self.paths_LQ and self.paths_GT:
|
||||
assert len(self.paths_LQ) == len(
|
||||
self.paths_GT
|
||||
), 'GT and LQ datasets have different number of images - {}, {}.'.format(
|
||||
len(self.paths_LQ), len(self.paths_GT))
|
||||
self.random_scale_list = [1]
|
||||
|
||||
def _init_lmdb(self):
|
||||
# https://github.com/chainer/chainermn/issues/129
|
||||
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
|
||||
meminit=False)
|
||||
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
|
||||
meminit=False)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
|
||||
self._init_lmdb()
|
||||
GT_path, LQ_path = None, None
|
||||
scale = self.opt['scale']
|
||||
GT_size = self.opt['GT_size']
|
||||
|
||||
# get GT image
|
||||
GT_path = self.paths_GT[index]
|
||||
resolution = [int(s) for s in self.sizes_GT[index].split('_')
|
||||
] if self.data_type == 'lmdb' else None
|
||||
img_GT = util.read_img(self.GT_env, GT_path, resolution)
|
||||
if self.opt['phase'] != 'train': # modcrop in the validation / test phase
|
||||
img_GT = util.modcrop(img_GT, scale)
|
||||
if self.opt['color']: # change color space if necessary
|
||||
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
|
||||
|
||||
# get LQ image
|
||||
if self.paths_LQ:
|
||||
LQ_path = self.paths_LQ[index]
|
||||
resolution = [int(s) for s in self.sizes_LQ[index].split('_')
|
||||
] if self.data_type == 'lmdb' else None
|
||||
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
|
||||
else: # down-sampling on-the-fly
|
||||
# randomly scale during training
|
||||
if self.opt['phase'] == 'train':
|
||||
random_scale = random.choice(self.random_scale_list)
|
||||
H_s, W_s, _ = img_GT.shape
|
||||
|
||||
def _mod(n, random_scale, scale, thres):
|
||||
rlt = int(n * random_scale)
|
||||
rlt = (rlt // scale) * scale
|
||||
return thres if rlt < thres else rlt
|
||||
|
||||
H_s = _mod(H_s, random_scale, scale, GT_size)
|
||||
W_s = _mod(W_s, random_scale, scale, GT_size)
|
||||
img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
|
||||
if img_GT.ndim == 2:
|
||||
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
H, W, _ = img_GT.shape
|
||||
# using matlab imresize
|
||||
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
|
||||
if img_LQ.ndim == 2:
|
||||
img_LQ = np.expand_dims(img_LQ, axis=2)
|
||||
|
||||
if self.opt['phase'] == 'train':
|
||||
# if the image size is too small
|
||||
H, W, _ = img_GT.shape
|
||||
if H < GT_size or W < GT_size:
|
||||
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||
# using matlab imresize
|
||||
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
|
||||
if img_LQ.ndim == 2:
|
||||
img_LQ = np.expand_dims(img_LQ, axis=2)
|
||||
|
||||
H, W, C = img_LQ.shape
|
||||
LQ_size = GT_size // scale
|
||||
|
||||
# randomly crop
|
||||
rnd_h = random.randint(0, max(0, H - LQ_size))
|
||||
rnd_w = random.randint(0, max(0, W - LQ_size))
|
||||
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
|
||||
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
|
||||
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
|
||||
|
||||
# augmentation - flip, rotate
|
||||
img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
|
||||
self.opt['use_rot'])
|
||||
|
||||
if self.opt['color']: # change color space if necessary
|
||||
img_LQ = util.channel_convert(C, self.opt['color'],
|
||||
[img_LQ])[0] # TODO during val no definition
|
||||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
if img_GT.shape[2] == 3:
|
||||
img_GT = img_GT[:, :, [2, 1, 0]]
|
||||
img_LQ = img_LQ[:, :, [2, 1, 0]]
|
||||
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
|
||||
|
||||
if LQ_path is None:
|
||||
LQ_path = GT_path
|
||||
return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths_GT)
|
|
@ -43,7 +43,7 @@ class LQGTDataset(data.Dataset):
|
|||
self._init_lmdb()
|
||||
GT_path, LQ_path = None, None
|
||||
scale = self.opt['scale']
|
||||
GT_size = self.opt['GT_size']
|
||||
GT_size = self.opt['target_size']
|
||||
|
||||
# get GT image
|
||||
GT_path = self.paths_GT[index]
|
||||
|
|
|
@ -41,7 +41,7 @@ class REDSDataset(data.Dataset):
|
|||
self.half_N_frames = opt['N_frames'] // 2
|
||||
self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
|
||||
self.data_type = self.opt['data_type']
|
||||
self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs
|
||||
self.LR_input = False if opt['target_size'] == opt['LQ_size'] else True # low resolution inputs
|
||||
#### directly load image keys
|
||||
if self.data_type == 'lmdb':
|
||||
self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
||||
|
@ -107,7 +107,7 @@ class REDSDataset(data.Dataset):
|
|||
self._init_lmdb()
|
||||
|
||||
scale = self.opt['scale']
|
||||
GT_size = self.opt['GT_size']
|
||||
GT_size = self.opt['target_size']
|
||||
key = self.paths_GT[index]
|
||||
name_a, name_b = key.split('_')
|
||||
center_frame_idx = int(name_b)
|
||||
|
|
|
@ -38,7 +38,7 @@ class Vimeo90KDataset(data.Dataset):
|
|||
|
||||
self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
|
||||
self.data_type = self.opt['data_type']
|
||||
self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs
|
||||
self.LR_input = False if opt['target_size'] == opt['LQ_size'] else True # low resolution inputs
|
||||
|
||||
#### determine the LQ frame list
|
||||
'''
|
||||
|
@ -104,7 +104,7 @@ class Vimeo90KDataset(data.Dataset):
|
|||
self._init_lmdb()
|
||||
|
||||
scale = self.opt['scale']
|
||||
GT_size = self.opt['GT_size']
|
||||
GT_size = self.opt['target_size']
|
||||
key = self.paths_GT[index]
|
||||
name_a, name_b = key.split('_')
|
||||
#### get the GT image (as the center frame)
|
||||
|
|
|
@ -23,7 +23,7 @@ def main():
|
|||
opt['use_shuffle'] = True
|
||||
opt['n_workers'] = 8
|
||||
opt['batch_size'] = 16
|
||||
opt['GT_size'] = 256
|
||||
opt['target_size'] = 256
|
||||
opt['LQ_size'] = 64
|
||||
opt['scale'] = 4
|
||||
opt['use_flip'] = True
|
||||
|
@ -43,7 +43,7 @@ def main():
|
|||
opt['use_shuffle'] = True
|
||||
opt['n_workers'] = 8
|
||||
opt['batch_size'] = 16
|
||||
opt['GT_size'] = 256
|
||||
opt['target_size'] = 256
|
||||
opt['LQ_size'] = 64
|
||||
opt['scale'] = 4
|
||||
opt['use_flip'] = True
|
||||
|
@ -62,7 +62,7 @@ def main():
|
|||
opt['use_shuffle'] = True
|
||||
opt['n_workers'] = 8
|
||||
opt['batch_size'] = 16
|
||||
opt['GT_size'] = 128
|
||||
opt['target_size'] = 128
|
||||
opt['scale'] = 4
|
||||
opt['use_flip'] = True
|
||||
opt['use_rot'] = True
|
||||
|
|
|
@ -3,20 +3,23 @@ import models.archs.SRResNet_arch as SRResNet_arch
|
|||
import models.archs.discriminator_vgg_arch as SRGAN_arch
|
||||
import models.archs.RRDBNet_arch as RRDBNet_arch
|
||||
import models.archs.EDVR_arch as EDVR_arch
|
||||
|
||||
import math
|
||||
|
||||
# Generator
|
||||
def define_G(opt):
|
||||
opt_net = opt['network_G']
|
||||
which_model = opt_net['which_model_G']
|
||||
scale = opt['scale']
|
||||
|
||||
# image restoration
|
||||
if which_model == 'MSRResNet':
|
||||
netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||
elif which_model == 'RRDBNet':
|
||||
# RRDB does scaling in two steps, so take the sqrt of the scale we actually want to achieve and feed it to RRDB.
|
||||
scale_per_step = math.sqrt(scale)
|
||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'])
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step)
|
||||
# video restoration
|
||||
elif which_model == 'EDVR':
|
||||
netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
|
||||
|
@ -24,6 +27,7 @@ def define_G(opt):
|
|||
back_RBs=opt_net['back_RBs'], center=opt_net['center'],
|
||||
predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
|
||||
w_TSA=opt_net['w_TSA'])
|
||||
|
||||
else:
|
||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||
|
||||
|
@ -32,7 +36,7 @@ def define_G(opt):
|
|||
|
||||
# Discriminator
|
||||
def define_D(opt):
|
||||
img_sz = opt['datasets']['train']['GT_size']
|
||||
img_sz = opt['datasets']['train']['target_size']
|
||||
opt_net = opt['network_D']
|
||||
which_model = opt_net['which_model_D']
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ datasets:
|
|||
use_shuffle: true
|
||||
n_workers: 3 # per GPU
|
||||
batch_size: 32
|
||||
GT_size: 256
|
||||
target_size: 256
|
||||
LQ_size: 64
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
|
|
|
@ -22,7 +22,7 @@ datasets:
|
|||
use_shuffle: true
|
||||
n_workers: 3 # per GPU
|
||||
batch_size: 32
|
||||
GT_size: 256
|
||||
target_size: 256
|
||||
LQ_size: 64
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
|
|
|
@ -4,28 +4,28 @@ use_tb_logger: true
|
|||
model: srgan
|
||||
distortion: sr
|
||||
scale: 4
|
||||
gpu_ids: [2]
|
||||
gpu_ids: [0]
|
||||
|
||||
#### datasets
|
||||
datasets:
|
||||
train:
|
||||
name: DIV2K
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
|
||||
dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
|
||||
dataroot_GT: ../datasets/div2k/DIV2K800_sub
|
||||
dataroot_LQ: ../datasets/div2k/DIV2K800_sub_bicLRx4
|
||||
|
||||
use_shuffle: true
|
||||
n_workers: 6 # per GPU
|
||||
n_workers: 16 # per GPU
|
||||
batch_size: 16
|
||||
GT_size: 128
|
||||
target_size: 128
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
color: RGB
|
||||
val:
|
||||
name: val_set14
|
||||
name: div2kval
|
||||
mode: LQGT
|
||||
dataroot_GT: ../datasets/val_set14/Set14
|
||||
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
|
||||
dataroot_GT: ../datasets/div2k/div2k_valid_hr
|
||||
dataroot_LQ: ../datasets/div2k/div2k_valid_lr_bicubic
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
|
@ -41,7 +41,7 @@ network_D:
|
|||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ../experiments/pretrained_models/RRDB_PSNR_x4.pth
|
||||
pretrain_model_G: ../experiments/RRDB_PSNR_x4.pth
|
||||
strict_load: true
|
||||
resume_state: ~
|
||||
|
||||
|
@ -73,9 +73,9 @@ train:
|
|||
D_init_iters: 0
|
||||
|
||||
manual_seed: 10
|
||||
val_freq: !!float 5e3
|
||||
val_freq: !!float 5e2
|
||||
|
||||
#### logger
|
||||
logger:
|
||||
print_freq: 100
|
||||
save_checkpoint_freq: !!float 5e3
|
||||
print_freq: 50
|
||||
save_checkpoint_freq: !!float 5e2
|
||||
|
|
|
@ -20,7 +20,7 @@ datasets:
|
|||
use_shuffle: true
|
||||
n_workers: 6 # per GPU
|
||||
batch_size: 16
|
||||
GT_size: 128
|
||||
target_size: 128
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
color: RGB
|
||||
|
|
|
@ -20,7 +20,7 @@ datasets:
|
|||
use_shuffle: true
|
||||
n_workers: 6 # per GPU
|
||||
batch_size: 16
|
||||
GT_size: 128
|
||||
target_size: 128
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
color: RGB
|
||||
|
|
Loading…
Reference in New Issue
Block a user