diff --git a/codes/data/GTLQ_dataset.py b/codes/data/GTLQ_dataset.py new file mode 100644 index 00000000..a52817eb --- /dev/null +++ b/codes/data/GTLQ_dataset.py @@ -0,0 +1,127 @@ +import random +import numpy as np +import cv2 +import lmdb +import torch +import torch.utils.data as data +import data.util as util + + +class LQGTDataset(data.Dataset): + """ + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs. + If only GT images are provided, generate LQ images on-the-fly. + """ + + def __init__(self, opt): + super(LQGTDataset, self).__init__() + self.opt = opt + self.data_type = self.opt['data_type'] + self.paths_LQ, self.paths_GT = None, None + self.sizes_LQ, self.sizes_GT = None, None + self.LQ_env, self.GT_env = None, None # environments for lmdb + + self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) + self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) + assert self.paths_GT, 'Error: GT path is empty.' + if self.paths_LQ and self.paths_GT: + assert len(self.paths_LQ) == len( + self.paths_GT + ), 'GT and LQ datasets have different number of images - {}, {}.'.format( + len(self.paths_LQ), len(self.paths_GT)) + self.random_scale_list = [1] + + def _init_lmdb(self): + # https://github.com/chainer/chainermn/issues/129 + self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, + meminit=False) + self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, + meminit=False) + + def __getitem__(self, index): + if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): + self._init_lmdb() + GT_path, LQ_path = None, None + scale = self.opt['scale'] + GT_size = self.opt['GT_size'] + + # get GT image + GT_path = self.paths_GT[index] + resolution = [int(s) for s in self.sizes_GT[index].split('_') + ] if self.data_type == 'lmdb' else None + img_GT = util.read_img(self.GT_env, GT_path, resolution) + if self.opt['phase'] != 'train': # modcrop in the validation / test phase + img_GT = util.modcrop(img_GT, scale) + if self.opt['color']: # change color space if necessary + img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] + + # get LQ image + if self.paths_LQ: + LQ_path = self.paths_LQ[index] + resolution = [int(s) for s in self.sizes_LQ[index].split('_') + ] if self.data_type == 'lmdb' else None + img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) + else: # down-sampling on-the-fly + # randomly scale during training + if self.opt['phase'] == 'train': + random_scale = random.choice(self.random_scale_list) + H_s, W_s, _ = img_GT.shape + + def _mod(n, random_scale, scale, thres): + rlt = int(n * random_scale) + rlt = (rlt // scale) * scale + return thres if rlt < thres else rlt + + H_s = _mod(H_s, random_scale, scale, GT_size) + W_s = _mod(W_s, random_scale, scale, GT_size) + img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR) + if img_GT.ndim == 2: + img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) + + H, W, _ = img_GT.shape + # using matlab imresize + img_LQ = util.imresize_np(img_GT, 1 / scale, True) + if img_LQ.ndim == 2: + img_LQ = np.expand_dims(img_LQ, axis=2) + + if self.opt['phase'] == 'train': + # if the image size is too small + H, W, _ = img_GT.shape + if H < GT_size or W < GT_size: + img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) + # using matlab imresize + img_LQ = util.imresize_np(img_GT, 1 / scale, True) + if img_LQ.ndim == 2: + img_LQ = np.expand_dims(img_LQ, axis=2) + + H, W, C = img_LQ.shape + LQ_size = GT_size // scale + + # randomly crop + rnd_h = random.randint(0, max(0, H - LQ_size)) + rnd_w = random.randint(0, max(0, W - LQ_size)) + img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] + rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) + img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] + + # augmentation - flip, rotate + img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], + self.opt['use_rot']) + + if self.opt['color']: # change color space if necessary + img_LQ = util.channel_convert(C, self.opt['color'], + [img_LQ])[0] # TODO during val no definition + + # BGR to RGB, HWC to CHW, numpy to tensor + if img_GT.shape[2] == 3: + img_GT = img_GT[:, :, [2, 1, 0]] + img_LQ = img_LQ[:, :, [2, 1, 0]] + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() + img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() + + if LQ_path is None: + LQ_path = GT_path + return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} + + def __len__(self): + return len(self.paths_GT) diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index a52817eb..1d7645b4 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -43,7 +43,7 @@ class LQGTDataset(data.Dataset): self._init_lmdb() GT_path, LQ_path = None, None scale = self.opt['scale'] - GT_size = self.opt['GT_size'] + GT_size = self.opt['target_size'] # get GT image GT_path = self.paths_GT[index] diff --git a/codes/data/REDS_dataset.py b/codes/data/REDS_dataset.py index 5643dbb2..36f69dcd 100644 --- a/codes/data/REDS_dataset.py +++ b/codes/data/REDS_dataset.py @@ -41,7 +41,7 @@ class REDSDataset(data.Dataset): self.half_N_frames = opt['N_frames'] // 2 self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] self.data_type = self.opt['data_type'] - self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs + self.LR_input = False if opt['target_size'] == opt['LQ_size'] else True # low resolution inputs #### directly load image keys if self.data_type == 'lmdb': self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT']) @@ -107,7 +107,7 @@ class REDSDataset(data.Dataset): self._init_lmdb() scale = self.opt['scale'] - GT_size = self.opt['GT_size'] + GT_size = self.opt['target_size'] key = self.paths_GT[index] name_a, name_b = key.split('_') center_frame_idx = int(name_b) diff --git a/codes/data/Vimeo90K_dataset.py b/codes/data/Vimeo90K_dataset.py index 914e1dba..324e3a10 100644 --- a/codes/data/Vimeo90K_dataset.py +++ b/codes/data/Vimeo90K_dataset.py @@ -38,7 +38,7 @@ class Vimeo90KDataset(data.Dataset): self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] self.data_type = self.opt['data_type'] - self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs + self.LR_input = False if opt['target_size'] == opt['LQ_size'] else True # low resolution inputs #### determine the LQ frame list ''' @@ -104,7 +104,7 @@ class Vimeo90KDataset(data.Dataset): self._init_lmdb() scale = self.opt['scale'] - GT_size = self.opt['GT_size'] + GT_size = self.opt['target_size'] key = self.paths_GT[index] name_a, name_b = key.split('_') #### get the GT image (as the center frame) diff --git a/codes/data_scripts/test_dataloader.py b/codes/data_scripts/test_dataloader.py index 5f580793..73df5242 100644 --- a/codes/data_scripts/test_dataloader.py +++ b/codes/data_scripts/test_dataloader.py @@ -23,7 +23,7 @@ def main(): opt['use_shuffle'] = True opt['n_workers'] = 8 opt['batch_size'] = 16 - opt['GT_size'] = 256 + opt['target_size'] = 256 opt['LQ_size'] = 64 opt['scale'] = 4 opt['use_flip'] = True @@ -43,7 +43,7 @@ def main(): opt['use_shuffle'] = True opt['n_workers'] = 8 opt['batch_size'] = 16 - opt['GT_size'] = 256 + opt['target_size'] = 256 opt['LQ_size'] = 64 opt['scale'] = 4 opt['use_flip'] = True @@ -62,7 +62,7 @@ def main(): opt['use_shuffle'] = True opt['n_workers'] = 8 opt['batch_size'] = 16 - opt['GT_size'] = 128 + opt['target_size'] = 128 opt['scale'] = 4 opt['use_flip'] = True opt['use_rot'] = True diff --git a/codes/models/networks.py b/codes/models/networks.py index ddd666a4..2b79249b 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -3,20 +3,23 @@ import models.archs.SRResNet_arch as SRResNet_arch import models.archs.discriminator_vgg_arch as SRGAN_arch import models.archs.RRDBNet_arch as RRDBNet_arch import models.archs.EDVR_arch as EDVR_arch - +import math # Generator def define_G(opt): opt_net = opt['network_G'] which_model = opt_net['which_model_G'] + scale = opt['scale'] # image restoration if which_model == 'MSRResNet': netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) elif which_model == 'RRDBNet': + # RRDB does scaling in two steps, so take the sqrt of the scale we actually want to achieve and feed it to RRDB. + scale_per_step = math.sqrt(scale) netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb']) + nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step) # video restoration elif which_model == 'EDVR': netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'], @@ -24,6 +27,7 @@ def define_G(opt): back_RBs=opt_net['back_RBs'], center=opt_net['center'], predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'], w_TSA=opt_net['w_TSA']) + else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) @@ -32,7 +36,7 @@ def define_G(opt): # Discriminator def define_D(opt): - img_sz = opt['datasets']['train']['GT_size'] + img_sz = opt['datasets']['train']['target_size'] opt_net = opt['network_D'] which_model = opt_net['which_model_D'] diff --git a/codes/options/train/train_EDVR_M.yml b/codes/options/train/train_EDVR_M.yml index ed0916c0..8a1c55df 100644 --- a/codes/options/train/train_EDVR_M.yml +++ b/codes/options/train/train_EDVR_M.yml @@ -22,7 +22,7 @@ datasets: use_shuffle: true n_workers: 3 # per GPU batch_size: 32 - GT_size: 256 + target_size: 256 LQ_size: 64 use_flip: true use_rot: true diff --git a/codes/options/train/train_EDVR_woTSA_M.yml b/codes/options/train/train_EDVR_woTSA_M.yml index 9f48573c..cd30f2ab 100644 --- a/codes/options/train/train_EDVR_woTSA_M.yml +++ b/codes/options/train/train_EDVR_woTSA_M.yml @@ -22,7 +22,7 @@ datasets: use_shuffle: true n_workers: 3 # per GPU batch_size: 32 - GT_size: 256 + target_size: 256 LQ_size: 64 use_flip: true use_rot: true diff --git a/codes/options/train/train_ESRGAN.yml b/codes/options/train/train_ESRGAN.yml index 720f8652..cd6e09fb 100644 --- a/codes/options/train/train_ESRGAN.yml +++ b/codes/options/train/train_ESRGAN.yml @@ -4,28 +4,28 @@ use_tb_logger: true model: srgan distortion: sr scale: 4 -gpu_ids: [2] +gpu_ids: [0] #### datasets datasets: train: name: DIV2K mode: LQGT - dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb - dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb + dataroot_GT: ../datasets/div2k/DIV2K800_sub + dataroot_LQ: ../datasets/div2k/DIV2K800_sub_bicLRx4 use_shuffle: true - n_workers: 6 # per GPU + n_workers: 16 # per GPU batch_size: 16 - GT_size: 128 + target_size: 128 use_flip: true use_rot: true color: RGB val: - name: val_set14 + name: div2kval mode: LQGT - dataroot_GT: ../datasets/val_set14/Set14 - dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 + dataroot_GT: ../datasets/div2k/div2k_valid_hr + dataroot_LQ: ../datasets/div2k/div2k_valid_lr_bicubic #### network structures network_G: @@ -41,7 +41,7 @@ network_D: #### path path: - pretrain_model_G: ../experiments/pretrained_models/RRDB_PSNR_x4.pth + pretrain_model_G: ../experiments/RRDB_PSNR_x4.pth strict_load: true resume_state: ~ @@ -73,9 +73,9 @@ train: D_init_iters: 0 manual_seed: 10 - val_freq: !!float 5e3 + val_freq: !!float 5e2 #### logger logger: - print_freq: 100 - save_checkpoint_freq: !!float 5e3 + print_freq: 50 + save_checkpoint_freq: !!float 5e2 diff --git a/codes/options/train/train_SRGAN.yml b/codes/options/train/train_SRGAN.yml index 8aa48727..6835601c 100644 --- a/codes/options/train/train_SRGAN.yml +++ b/codes/options/train/train_SRGAN.yml @@ -20,7 +20,7 @@ datasets: use_shuffle: true n_workers: 6 # per GPU batch_size: 16 - GT_size: 128 + target_size: 128 use_flip: true use_rot: true color: RGB diff --git a/codes/options/train/train_SRResNet.yml b/codes/options/train/train_SRResNet.yml index be6fd4da..15468dce 100644 --- a/codes/options/train/train_SRResNet.yml +++ b/codes/options/train/train_SRResNet.yml @@ -20,7 +20,7 @@ datasets: use_shuffle: true n_workers: 6 # per GPU batch_size: 16 - GT_size: 128 + target_size: 128 use_flip: true use_rot: true color: RGB