From 037933ba66054bb69db47d499694e6d249fba0ec Mon Sep 17 00:00:00 2001 From: XintaoWang Date: Fri, 23 Aug 2019 21:42:47 +0800 Subject: [PATCH] mmsr --- .flake8 | 6 + .gitignore | 121 +++ .style.yapf | 4 + README.md | 47 ++ codes/data/LQGT_dataset.py | 127 ++++ codes/data/LQ_dataset.py | 47 ++ codes/data/REDS_dataset.py | 210 ++++++ codes/data/Vimeo90K_dataset.py | 167 +++++ codes/data/__init__.py | 49 ++ codes/data/data_sampler.py | 65 ++ codes/data/util.py | 543 ++++++++++++++ codes/data/video_test_dataset.py | 84 +++ codes/data_scripts/create_lmdb.py | 411 +++++++++++ codes/data_scripts/extract_subimages.py | 141 ++++ codes/data_scripts/generate_LR_Vimeo90K.m | 49 ++ codes/data_scripts/generate_mod_LR_bic.m | 82 +++ codes/data_scripts/generate_mod_LR_bic.py | 81 ++ .../data_scripts/prepare_DIV2K_x4_dataset.sh | 42 ++ codes/data_scripts/regroup_REDS.py | 11 + codes/data_scripts/rename.py | 19 + codes/data_scripts/test_dataloader.py | 104 +++ codes/metrics/calculate_PSNR_SSIM.m | 261 +++++++ codes/metrics/calculate_PSNR_SSIM.py | 147 ++++ codes/models/SRGAN_model.py | 267 +++++++ codes/models/SR_model.py | 170 +++++ codes/models/Video_base_model.py | 166 +++++ codes/models/__init__.py | 19 + codes/models/archs/DUF_arch.py | 368 ++++++++++ codes/models/archs/EDVR_arch.py | 312 ++++++++ codes/models/archs/RRDBNet_arch.py | 73 ++ codes/models/archs/SRResNet_arch.py | 55 ++ codes/models/archs/TOF_arch.py | 137 ++++ codes/models/archs/__init__.py | 0 codes/models/archs/arch_util.py | 79 ++ codes/models/archs/dcn/__init__.py | 7 + codes/models/archs/dcn/deform_conv.py | 291 ++++++++ codes/models/archs/dcn/setup.py | 22 + .../models/archs/dcn/src/deform_conv_cuda.cpp | 695 ++++++++++++++++++ codes/models/archs/discriminator_vgg_arch.py | 88 +++ codes/models/base_model.py | 116 +++ codes/models/loss.py | 74 ++ codes/models/lr_scheduler.py | 144 ++++ codes/models/networks.py | 57 ++ codes/options/__init__.py | 0 codes/options/options.py | 116 +++ codes/options/test/test_ESRGAN.yml | 32 + codes/options/test/test_SRGAN.yml | 32 + codes/options/test/test_SRResNet.yml | 48 ++ codes/options/train/train_EDVR_M.yml | 80 ++ codes/options/train/train_EDVR_woTSA_M.yml | 71 ++ codes/options/train/train_ESRGAN.yml | 81 ++ codes/options/train/train_SRGAN.yml | 85 +++ codes/options/train/train_SRResNet.yml | 70 ++ codes/run_scripts.sh | 10 + .../scripts/back_projection/backprojection.m | 20 + codes/scripts/back_projection/main_bp.m | 22 + .../back_projection/main_reverse_filter.m | 25 + codes/scripts/transfer_params_MSRResNet.py | 27 + codes/test.py | 105 +++ codes/test_Vid4_REDS4_with_GT.py | 208 ++++++ codes/test_Vid4_REDS4_with_GT_DUF.py | 264 +++++++ codes/test_Vid4_REDS4_with_GT_TOF.py | 230 ++++++ codes/train.py | 310 ++++++++ codes/utils/__init__.py | 0 codes/utils/util.py | 327 ++++++++ .../Put pretrained models here. 
| 0 66 files changed, 8121 insertions(+) create mode 100644 .flake8 create mode 100644 .gitignore create mode 100644 .style.yapf create mode 100644 README.md create mode 100644 codes/data/LQGT_dataset.py create mode 100644 codes/data/LQ_dataset.py create mode 100644 codes/data/REDS_dataset.py create mode 100644 codes/data/Vimeo90K_dataset.py create mode 100644 codes/data/__init__.py create mode 100644 codes/data/data_sampler.py create mode 100644 codes/data/util.py create mode 100644 codes/data/video_test_dataset.py create mode 100644 codes/data_scripts/create_lmdb.py create mode 100644 codes/data_scripts/extract_subimages.py create mode 100644 codes/data_scripts/generate_LR_Vimeo90K.m create mode 100644 codes/data_scripts/generate_mod_LR_bic.m create mode 100644 codes/data_scripts/generate_mod_LR_bic.py create mode 100644 codes/data_scripts/prepare_DIV2K_x4_dataset.sh create mode 100644 codes/data_scripts/regroup_REDS.py create mode 100644 codes/data_scripts/rename.py create mode 100644 codes/data_scripts/test_dataloader.py create mode 100644 codes/metrics/calculate_PSNR_SSIM.m create mode 100644 codes/metrics/calculate_PSNR_SSIM.py create mode 100644 codes/models/SRGAN_model.py create mode 100644 codes/models/SR_model.py create mode 100644 codes/models/Video_base_model.py create mode 100644 codes/models/__init__.py create mode 100644 codes/models/archs/DUF_arch.py create mode 100644 codes/models/archs/EDVR_arch.py create mode 100644 codes/models/archs/RRDBNet_arch.py create mode 100644 codes/models/archs/SRResNet_arch.py create mode 100755 codes/models/archs/TOF_arch.py create mode 100644 codes/models/archs/__init__.py create mode 100644 codes/models/archs/arch_util.py create mode 100644 codes/models/archs/dcn/__init__.py create mode 100644 codes/models/archs/dcn/deform_conv.py create mode 100644 codes/models/archs/dcn/setup.py create mode 100644 codes/models/archs/dcn/src/deform_conv_cuda.cpp create mode 100644 codes/models/archs/discriminator_vgg_arch.py create mode 100644 codes/models/base_model.py create mode 100644 codes/models/loss.py create mode 100644 codes/models/lr_scheduler.py create mode 100644 codes/models/networks.py create mode 100644 codes/options/__init__.py create mode 100644 codes/options/options.py create mode 100644 codes/options/test/test_ESRGAN.yml create mode 100644 codes/options/test/test_SRGAN.yml create mode 100644 codes/options/test/test_SRResNet.yml create mode 100644 codes/options/train/train_EDVR_M.yml create mode 100644 codes/options/train/train_EDVR_woTSA_M.yml create mode 100644 codes/options/train/train_ESRGAN.yml create mode 100644 codes/options/train/train_SRGAN.yml create mode 100644 codes/options/train/train_SRResNet.yml create mode 100644 codes/run_scripts.sh create mode 100644 codes/scripts/back_projection/backprojection.m create mode 100644 codes/scripts/back_projection/main_bp.m create mode 100644 codes/scripts/back_projection/main_reverse_filter.m create mode 100644 codes/scripts/transfer_params_MSRResNet.py create mode 100644 codes/test.py create mode 100644 codes/test_Vid4_REDS4_with_GT.py create mode 100644 codes/test_Vid4_REDS4_with_GT_DUF.py create mode 100644 codes/test_Vid4_REDS4_with_GT_TOF.py create mode 100644 codes/train.py create mode 100644 codes/utils/__init__.py create mode 100644 codes/utils/util.py create mode 100644 experiments/pretrained_models/Put pretrained models here. 
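Since this is a standard `git format-patch` email, it can be applied directly to a checkout before following the README setup that the patch itself adds. A minimal sketch, assuming the patch is saved as `0001-mmsr.patch` and using a scratch branch name — both names are illustrative and not part of the commit:

```sh
# Names below are assumptions for illustration; the patch defines neither.
git checkout -b mmsr-import          # scratch branch for the import
git apply --stat 0001-mmsr.patch     # preview the diffstat without applying
git am 0001-mmsr.patch               # apply the commit, keeping author and message
```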
diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..009995b3 --- /dev/null +++ b/.flake8 @@ -0,0 +1,6 @@ +[flake8] +ignore = + # Too many leading '#' for block comment (E266) + E266 + +max-line-length=100 \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..b0abdf54 --- /dev/null +++ b/.gitignore @@ -0,0 +1,121 @@ +experiments/* +results/* +tb_logger/* +datasets/* +.vscode + +*.html +*.png +*.jpg +*.gif +*.pth +*.pytorch +*.zip +*.cu + +# template + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 00000000..9f0e0622 --- /dev/null +++ b/.style.yapf @@ -0,0 +1,4 @@ +[style] +BASED_ON_STYLE = pep8 +COLUMN_LIMIT = 100 +SPLIT_BEFORE_NAMED_ASSIGNS = false \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 00000000..4b8f4ffc --- /dev/null +++ b/README.md @@ -0,0 +1,47 @@ +# MMSR + +MMSR is an open source image and video super-resolution toolbox based on PyTorch. It is a part of the [open-mmlab](https://github.com/open-mmlab) project developed by [Multimedia Laboratory, CUHK](http://mmlab.ie.cuhk.edu.hk/). MMSR is based on our previous projects: [BasicSR](https://github.com/xinntao/BasicSR), [ESRGAN](https://github.com/xinntao/ESRGAN)and [EDVR](https://github.com/xinntao/EDVR). + +### Highlights +- **A unified framework** suitable for image and video super-resolution tasks. It is also easy to adapt to other restoration tasks, e.g., deblurring, denoising, etc. +- **State of the art**: It includes several winning methods in competitions: such as ESRGAN (PIRM18), EDVR (NTIRE19). +- **Easy to extend**: It is easy to try new research ideas based on the code base. + + +### Updates +[2019-07-25] MMSR v0.1 is released. + +## Dependencies and Installation + +- Python 3 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux)) +- [PyTorch >= 1.0](https://pytorch.org/) +- NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) +- [Deformable Convolution](https://arxiv.org/abs/1703.06211). We use [mmdetection](https://github.com/open-mmlab/mmdetection)'s dcn implementation. Please first compile it. 
+ ``` + cd ./codes/models/archs/dcn + python setup.py develop + ``` +- Python packages: `pip install numpy opencv-python lmdb pyyaml` +- TensorBoard: + - PyTorch >= 1.1: `pip install tb-nightly future` + - PyTorch == 1.0: `pip install tensorboardX` + +## Dataset Preparation +We use datasets in LMDB format for faster IO speed. Please refer to [DATASETS.md](datasets/DATASETS.md) for more details. + +## Training and Testing +Please see [wiki - Training and Testing](https://github.com/open-mmlab/mmsr/wiki/Training-and-Testing) for the basic usage, *i.e.,* training and testing. + +## Model Zoo and Baselines +Results and pre-trained models are available in the [wiki - Model Zoo](https://github.com/open-mmlab/mmsr/wiki/Model-Zoo). + +## Contributing +We appreciate all contributions. Please refer to [mmdetection](https://github.com/open-mmlab/mmdetection/blob/master/CONTRIBUTING.md) for the contributing guidelines. + +**Python code style**
+We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style. We use [flake8](http://flake8.pycqa.org/en/latest/) as the linter and [yapf](https://github.com/google/yapf) as the formatter. Please upgrade to the latest yapf (>=0.27.0) and refer to the [yapf configuration](.style.yapf) and [flake8 configuration](.flake8). + +> Before you create a PR, make sure that your code lints and is formatted by yapf. + +## License +This project is released under the Apache 2.0 license. diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py new file mode 100644 index 00000000..a52817eb --- /dev/null +++ b/codes/data/LQGT_dataset.py @@ -0,0 +1,127 @@ +import random +import numpy as np +import cv2 +import lmdb +import torch +import torch.utils.data as data +import data.util as util + + +class LQGTDataset(data.Dataset): + """ + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs. + If only GT images are provided, generate LQ images on-the-fly. + """ + + def __init__(self, opt): + super(LQGTDataset, self).__init__() + self.opt = opt + self.data_type = self.opt['data_type'] + self.paths_LQ, self.paths_GT = None, None + self.sizes_LQ, self.sizes_GT = None, None + self.LQ_env, self.GT_env = None, None # environments for lmdb + + self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) + self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) + assert self.paths_GT, 'Error: GT path is empty.' + if self.paths_LQ and self.paths_GT: + assert len(self.paths_LQ) == len( + self.paths_GT + ), 'GT and LQ datasets have different number of images - {}, {}.'.format( + len(self.paths_LQ), len(self.paths_GT)) + self.random_scale_list = [1] + + def _init_lmdb(self): + # https://github.com/chainer/chainermn/issues/129 + self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, + meminit=False) + self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, + meminit=False) + + def __getitem__(self, index): + if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): + self._init_lmdb() + GT_path, LQ_path = None, None + scale = self.opt['scale'] + GT_size = self.opt['GT_size'] + + # get GT image + GT_path = self.paths_GT[index] + resolution = [int(s) for s in self.sizes_GT[index].split('_') + ] if self.data_type == 'lmdb' else None + img_GT = util.read_img(self.GT_env, GT_path, resolution) + if self.opt['phase'] != 'train': # modcrop in the validation / test phase + img_GT = util.modcrop(img_GT, scale) + if self.opt['color']: # change color space if necessary + img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] + + # get LQ image + if self.paths_LQ: + LQ_path = self.paths_LQ[index] + resolution = [int(s) for s in self.sizes_LQ[index].split('_') + ] if self.data_type == 'lmdb' else None + img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) + else: # down-sampling on-the-fly + # randomly scale during training + if self.opt['phase'] == 'train': + random_scale = random.choice(self.random_scale_list) + H_s, W_s, _ = img_GT.shape + + def _mod(n, random_scale, scale, thres): + rlt = int(n * random_scale) + rlt = (rlt // scale) * scale + return thres if rlt < thres else rlt + + H_s = _mod(H_s, random_scale, scale, GT_size) + W_s = _mod(W_s, random_scale, scale, GT_size) + img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR) + if img_GT.ndim == 2: + img_GT = 
cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) + + H, W, _ = img_GT.shape + # using matlab imresize + img_LQ = util.imresize_np(img_GT, 1 / scale, True) + if img_LQ.ndim == 2: + img_LQ = np.expand_dims(img_LQ, axis=2) + + if self.opt['phase'] == 'train': + # if the image size is too small + H, W, _ = img_GT.shape + if H < GT_size or W < GT_size: + img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) + # using matlab imresize + img_LQ = util.imresize_np(img_GT, 1 / scale, True) + if img_LQ.ndim == 2: + img_LQ = np.expand_dims(img_LQ, axis=2) + + H, W, C = img_LQ.shape + LQ_size = GT_size // scale + + # randomly crop + rnd_h = random.randint(0, max(0, H - LQ_size)) + rnd_w = random.randint(0, max(0, W - LQ_size)) + img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] + rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) + img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] + + # augmentation - flip, rotate + img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], + self.opt['use_rot']) + + if self.opt['color']: # change color space if necessary + img_LQ = util.channel_convert(C, self.opt['color'], + [img_LQ])[0] # TODO during val no definition + + # BGR to RGB, HWC to CHW, numpy to tensor + if img_GT.shape[2] == 3: + img_GT = img_GT[:, :, [2, 1, 0]] + img_LQ = img_LQ[:, :, [2, 1, 0]] + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() + img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() + + if LQ_path is None: + LQ_path = GT_path + return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} + + def __len__(self): + return len(self.paths_GT) diff --git a/codes/data/LQ_dataset.py b/codes/data/LQ_dataset.py new file mode 100644 index 00000000..e7ad7dd3 --- /dev/null +++ b/codes/data/LQ_dataset.py @@ -0,0 +1,47 @@ +import numpy as np +import lmdb +import torch +import torch.utils.data as data +import data.util as util + + +class LQDataset(data.Dataset): + '''Read LQ images only in the test phase.''' + + def __init__(self, opt): + super(LQDataset, self).__init__() + self.opt = opt + self.data_type = self.opt['data_type'] + self.paths_LQ, self.paths_GT = None, None + self.LQ_env = None # environment for lmdb + + self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) + assert self.paths_LQ, 'Error: LQ paths are empty.'
+ + def _init_lmdb(self): + self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, + meminit=False) + + def __getitem__(self, index): + if self.data_type == 'lmdb' and self.LQ_env is None: + self._init_lmdb() + LQ_path = None + + # get LQ image + LQ_path = self.paths_LQ[index] + resolution = [int(s) for s in self.sizes_LQ[index].split('_') + ] if self.data_type == 'lmdb' else None + img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) + H, W, C = img_LQ.shape + + if self.opt['color']: # change color space if necessary + img_LQ = util.channel_convert(C, self.opt['color'], [img_LQ])[0] + + # BGR to RGB, HWC to CHW, numpy to tensor + if img_LQ.shape[2] == 3: + img_LQ = img_LQ[:, :, [2, 1, 0]] + img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() + + return {'LQ': img_LQ, 'LQ_path': LQ_path} + + def __len__(self): + return len(self.paths_LQ) diff --git a/codes/data/REDS_dataset.py b/codes/data/REDS_dataset.py new file mode 100644 index 00000000..5643dbb2 --- /dev/null +++ b/codes/data/REDS_dataset.py @@ -0,0 +1,210 @@ +''' +REDS dataset +support reading images from lmdb, image folder and memcached +''' +import os.path as osp +import random +import pickle +import logging +import numpy as np +import cv2 +import lmdb +import torch +import torch.utils.data as data +import data.util as util +try: + import mc # import memcached +except ImportError: + pass + +logger = logging.getLogger('base') + + +class REDSDataset(data.Dataset): + ''' + Reading the training REDS dataset + key example: 000_00000000 + GT: Ground-Truth; + LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames + support reading N LQ frames, N = 1, 3, 5, 7 + ''' + + def __init__(self, opt): + super(REDSDataset, self).__init__() + self.opt = opt + # temporal augmentation + self.interval_list = opt['interval_list'] + self.random_reverse = opt['random_reverse'] + logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format( + ','.join(str(x) for x in opt['interval_list']), self.random_reverse)) + + self.half_N_frames = opt['N_frames'] // 2 + self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] + self.data_type = self.opt['data_type'] + self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs + #### directly load image keys + if self.data_type == 'lmdb': + self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT']) + logger.info('Using lmdb meta info for cache keys.') + elif opt['cache_keys']: + logger.info('Using cache keys: {}'.format(opt['cache_keys'])) + self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys'] + else: + raise ValueError( + 'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]') + + # remove the REDS4 for testing + self.paths_GT = [ + v for v in self.paths_GT if v.split('_')[0] not in ['000', '011', '015', '020'] + ] + assert self.paths_GT, 'Error: GT path is empty.'
+ + if self.data_type == 'lmdb': + self.GT_env, self.LQ_env = None, None + elif self.data_type == 'mc': # memcached + self.mclient = None + elif self.data_type == 'img': + pass + else: + raise ValueError('Wrong data type: {}'.format(self.data_type)) + + def _init_lmdb(self): + # https://github.com/chainer/chainermn/issues/129 + self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, + meminit=False) + self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, + meminit=False) + + def _ensure_memcached(self): + if self.mclient is None: + # specify the config files + server_list_config_file = None + client_config_file = None + self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, + client_config_file) + + def _read_img_mc(self, path): + ''' Return BGR, HWC, [0, 255], uint8''' + value = mc.pyvector() + self.mclient.Get(path, value) + value_buf = mc.ConvertBuffer(value) + img_array = np.frombuffer(value_buf, np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED) + return img + + def _read_img_mc_BGR(self, path, name_a, name_b): + ''' Read BGR channels separately and then combine for 1M limits in cluster''' + img_B = self._read_img_mc(osp.join(path + '_B', name_a, name_b + '.png')) + img_G = self._read_img_mc(osp.join(path + '_G', name_a, name_b + '.png')) + img_R = self._read_img_mc(osp.join(path + '_R', name_a, name_b + '.png')) + img = cv2.merge((img_B, img_G, img_R)) + return img + + def __getitem__(self, index): + if self.data_type == 'mc': + self._ensure_memcached() + elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): + self._init_lmdb() + + scale = self.opt['scale'] + GT_size = self.opt['GT_size'] + key = self.paths_GT[index] + name_a, name_b = key.split('_') + center_frame_idx = int(name_b) + + #### determine the neighbor frames + interval = random.choice(self.interval_list) + if self.opt['border_mode']: + direction = 1 # 1: forward; 0: backward + N_frames = self.opt['N_frames'] + if self.random_reverse and random.random() < 0.5: + direction = random.choice([0, 1]) + if center_frame_idx + interval * (N_frames - 1) > 99: + direction = 0 + elif center_frame_idx - interval * (N_frames - 1) < 0: + direction = 1 + # get the neighbor list + if direction == 1: + neighbor_list = list( + range(center_frame_idx, center_frame_idx + interval * N_frames, interval)) + else: + neighbor_list = list( + range(center_frame_idx, center_frame_idx - interval * N_frames, -interval)) + name_b = '{:08d}'.format(neighbor_list[0]) + else: + # ensure not exceeding the borders + while (center_frame_idx + self.half_N_frames * interval > + 99) or (center_frame_idx - self.half_N_frames * interval < 0): + center_frame_idx = random.randint(0, 99) + # get the neighbor list + neighbor_list = list( + range(center_frame_idx - self.half_N_frames * interval, + center_frame_idx + self.half_N_frames * interval + 1, interval)) + if self.random_reverse and random.random() < 0.5: + neighbor_list.reverse() + name_b = '{:08d}'.format(neighbor_list[self.half_N_frames]) + + assert len( + neighbor_list) == self.opt['N_frames'], 'Wrong length of neighbor list: {}'.format( + len(neighbor_list)) + + #### get the GT image (as the center frame) + if self.data_type == 'mc': + img_GT = self._read_img_mc_BGR(self.GT_root, name_a, name_b) + img_GT = img_GT.astype(np.float32) / 255. 
+ elif self.data_type == 'lmdb': + img_GT = util.read_img(self.GT_env, key, (3, 720, 1280)) + else: + img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b + '.png')) + + #### get LQ images + LQ_size_tuple = (3, 180, 320) if self.LR_input else (3, 720, 1280) + img_LQ_l = [] + for v in neighbor_list: + img_LQ_path = osp.join(self.LQ_root, name_a, '{:08d}.png'.format(v)) + if self.data_type == 'mc': + if self.LR_input: + img_LQ = self._read_img_mc(img_LQ_path) + else: + img_LQ = self._read_img_mc_BGR(self.LQ_root, name_a, '{:08d}'.format(v)) + img_LQ = img_LQ.astype(np.float32) / 255. + elif self.data_type == 'lmdb': + img_LQ = util.read_img(self.LQ_env, '{}_{:08d}'.format(name_a, v), LQ_size_tuple) + else: + img_LQ = util.read_img(None, img_LQ_path) + img_LQ_l.append(img_LQ) + + if self.opt['phase'] == 'train': + C, H, W = LQ_size_tuple # LQ size + # randomly crop + if self.LR_input: + LQ_size = GT_size // scale + rnd_h = random.randint(0, max(0, H - LQ_size)) + rnd_w = random.randint(0, max(0, W - LQ_size)) + img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l] + rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) + img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :] + else: + rnd_h = random.randint(0, max(0, H - GT_size)) + rnd_w = random.randint(0, max(0, W - GT_size)) + img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l] + img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] + + # augmentation - flip, rotate + img_LQ_l.append(img_GT) + rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot']) + img_LQ_l = rlt[0:-1] + img_GT = rlt[-1] + + # stack LQ images to NHWC, N is the frame number + img_LQs = np.stack(img_LQ_l, axis=0) + # BGR to RGB, HWC to CHW, numpy to tensor + img_GT = img_GT[:, :, [2, 1, 0]] + img_LQs = img_LQs[:, :, :, [2, 1, 0]] + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() + img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs, + (0, 3, 1, 2)))).float() + return {'LQs': img_LQs, 'GT': img_GT, 'key': key} + + def __len__(self): + return len(self.paths_GT) diff --git a/codes/data/Vimeo90K_dataset.py b/codes/data/Vimeo90K_dataset.py new file mode 100644 index 00000000..914e1dba --- /dev/null +++ b/codes/data/Vimeo90K_dataset.py @@ -0,0 +1,167 @@ +''' +Vimeo90K dataset +support reading images from lmdb, image folder and memcached +''' +import os.path as osp +import random +import pickle +import logging +import numpy as np +import cv2 +import lmdb +import torch +import torch.utils.data as data +import data.util as util +try: + import mc # import memcached +except ImportError: + pass +logger = logging.getLogger('base') + + +class Vimeo90KDataset(data.Dataset): + ''' + Reading the training Vimeo90K dataset + key example: 00001_0001 (_1, ..., _7) + GT (Ground-Truth): 4th frame; + LQ (Low-Quality): support reading N LQ frames, N = 1, 3, 5, 7 centered with 4th frame + ''' + + def __init__(self, opt): + super(Vimeo90KDataset, self).__init__() + self.opt = opt + # temporal augmentation + self.interval_list = opt['interval_list'] + self.random_reverse = opt['random_reverse'] + logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format( + ','.join(str(x) for x in opt['interval_list']), self.random_reverse)) + + self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] + self.data_type = self.opt['data_type'] + self.LR_input = False if opt['GT_size'] == 
opt['LQ_size'] else True # low resolution inputs + + #### determine the LQ frame list + ''' + N | frames + 1 | 4 + 3 | 3,4,5 + 5 | 2,3,4,5,6 + 7 | 1,2,3,4,5,6,7 + ''' + self.LQ_frames_list = [] + for i in range(opt['N_frames']): + self.LQ_frames_list.append(i + (9 - opt['N_frames']) // 2) + + #### directly load image keys + if self.data_type == 'lmdb': + self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT']) + logger.info('Using lmdb meta info for cache keys.') + elif opt['cache_keys']: + logger.info('Using cache keys: {}'.format(opt['cache_keys'])) + self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys'] + else: + raise ValueError( + 'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]') + assert self.paths_GT, 'Error: GT path is empty.' + + if self.data_type == 'lmdb': + self.GT_env, self.LQ_env = None, None + elif self.data_type == 'mc': # memcached + self.mclient = None + elif self.data_type == 'img': + pass + else: + raise ValueError('Wrong data type: {}'.format(self.data_type)) + + def _init_lmdb(self): + # https://github.com/chainer/chainermn/issues/129 + self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, + meminit=False) + self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, + meminit=False) + + def _ensure_memcached(self): + if self.mclient is None: + # specify the config files + server_list_config_file = None + client_config_file = None + self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, + client_config_file) + + def _read_img_mc(self, path): + ''' Return BGR, HWC, [0, 255], uint8''' + value = mc.pyvector() + self.mclient.Get(path, value) + value_buf = mc.ConvertBuffer(value) + img_array = np.frombuffer(value_buf, np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED) + return img + + def __getitem__(self, index): + if self.data_type == 'mc': + self._ensure_memcached() + elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): + self._init_lmdb() + + scale = self.opt['scale'] + GT_size = self.opt['GT_size'] + key = self.paths_GT[index] + name_a, name_b = key.split('_') + #### get the GT image (as the center frame) + if self.data_type == 'mc': + img_GT = self._read_img_mc(osp.join(self.GT_root, name_a, name_b, '4.png')) + img_GT = img_GT.astype(np.float32) / 255. + elif self.data_type == 'lmdb': + img_GT = util.read_img(self.GT_env, key + '_4', (3, 256, 448)) + else: + img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b, 'im4.png')) + + #### get LQ images + LQ_size_tuple = (3, 64, 112) if self.LR_input else (3, 256, 448) + img_LQ_l = [] + for v in self.LQ_frames_list: + if self.data_type == 'mc': + img_LQ = self._read_img_mc( + osp.join(self.LQ_root, name_a, name_b, '{}.png'.format(v))) + img_LQ = img_LQ.astype(np.float32) / 255. 
+ elif self.data_type == 'lmdb': + img_LQ = util.read_img(self.LQ_env, key + '_{}'.format(v), LQ_size_tuple) + else: + img_LQ = util.read_img(None, + osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v))) + img_LQ_l.append(img_LQ) + + if self.opt['phase'] == 'train': + C, H, W = LQ_size_tuple # LQ size + # randomly crop + if self.LR_input: + LQ_size = GT_size // scale + rnd_h = random.randint(0, max(0, H - LQ_size)) + rnd_w = random.randint(0, max(0, W - LQ_size)) + img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l] + rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) + img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :] + else: + rnd_h = random.randint(0, max(0, H - GT_size)) + rnd_w = random.randint(0, max(0, W - GT_size)) + img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l] + img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] + + # augmentation - flip, rotate + img_LQ_l.append(img_GT) + rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot']) + img_LQ_l = rlt[0:-1] + img_GT = rlt[-1] + + # stack LQ images to NHWC, N is the frame number + img_LQs = np.stack(img_LQ_l, axis=0) + # BGR to RGB, HWC to CHW, numpy to tensor + img_GT = img_GT[:, :, [2, 1, 0]] + img_LQs = img_LQs[:, :, :, [2, 1, 0]] + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() + img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs, + (0, 3, 1, 2)))).float() + return {'LQs': img_LQs, 'GT': img_GT, 'key': key} + + def __len__(self): + return len(self.paths_GT) diff --git a/codes/data/__init__.py b/codes/data/__init__.py new file mode 100644 index 00000000..ce10a834 --- /dev/null +++ b/codes/data/__init__.py @@ -0,0 +1,49 @@ +"""create dataset and dataloader""" +import logging +import torch +import torch.utils.data + + +def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): + phase = dataset_opt['phase'] + if phase == 'train': + if opt['dist']: + world_size = torch.distributed.get_world_size() + num_workers = dataset_opt['n_workers'] + assert dataset_opt['batch_size'] % world_size == 0 + batch_size = dataset_opt['batch_size'] // world_size + shuffle = False + else: + num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids']) + batch_size = dataset_opt['batch_size'] + shuffle = True + return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, + num_workers=num_workers, sampler=sampler, drop_last=True, + pin_memory=False) + else: + return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, + pin_memory=False) + + +def create_dataset(dataset_opt): + mode = dataset_opt['mode'] + # datasets for image restoration + if mode == 'LQ': + from data.LQ_dataset import LQDataset as D + elif mode == 'LQGT': + from data.LQGT_dataset import LQGTDataset as D + # datasets for video restoration + elif mode == 'REDS': + from data.REDS_dataset import REDSDataset as D + elif mode == 'Vimeo90K': + from data.Vimeo90K_dataset import Vimeo90KDataset as D + elif mode == 'video_test': + from data.video_test_dataset import VideoTestDataset as D + else: + raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) + dataset = D(dataset_opt) + + logger = logging.getLogger('base') + logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, + dataset_opt['name'])) + return dataset diff --git a/codes/data/data_sampler.py b/codes/data/data_sampler.py new file mode 
100644 index 00000000..9c409418 --- /dev/null +++ b/codes/data/data_sampler.py @@ -0,0 +1,65 @@ +""" +Modified from torch.utils.data.distributed.DistributedSampler +Support enlarging the dataset for *iteration-oriented* training, for saving time when restart the +dataloader after each epoch +""" +import math +import torch +from torch.utils.data.sampler import Sampler +import torch.distributed as dist + + +class DistIterSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dsize = len(self.dataset) + indices = [v % dsize for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/codes/data/util.py b/codes/data/util.py new file mode 100644 index 00000000..516c6ae9 --- /dev/null +++ b/codes/data/util.py @@ -0,0 +1,543 @@ +import os +import math +import pickle +import random +import numpy as np +import glob +import torch +import cv2 + +#################### +# Files & IO +#################### + +###################### get image path list ###################### +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def _get_paths_from_images(path): + """get image path list from image folder""" + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +def _get_paths_from_lmdb(dataroot): + """get image path list from lmdb meta info""" + meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb')) + paths = meta_info['keys'] + sizes = meta_info['resolution'] + if len(sizes) == 1: + sizes = sizes * len(paths) + return paths, sizes + + +def get_image_paths(data_type, 
dataroot): + """get image path list + support lmdb or image files""" + paths, sizes = None, None + if dataroot is not None: + if data_type == 'lmdb': + paths, sizes = _get_paths_from_lmdb(dataroot) + elif data_type == 'img': + paths = sorted(_get_paths_from_images(dataroot)) + else: + raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) + return paths, sizes + + +def glob_file_list(root): + return sorted(glob.glob(os.path.join(root, '*'))) + + +###################### read images ###################### +def _read_img_lmdb(env, key, size): + """read image from lmdb with key (w/ and w/o fixed size) + size: (C, H, W) tuple""" + with env.begin(write=False) as txn: + buf = txn.get(key.encode('ascii')) + img_flat = np.frombuffer(buf, dtype=np.uint8) + C, H, W = size + img = img_flat.reshape(H, W, C) + return img + + +def read_img(env, path, size=None): + """read image by cv2 or from lmdb + return: Numpy float32, HWC, BGR, [0,1]""" + if env is None: # img + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + else: + img = _read_img_lmdb(env, path, size) + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +def read_img_seq(path): + """Read a sequence of images from a given folder path + Args: + path (list/str): list of image paths/image folder path + + Returns: + imgs (Tensor): size (T, C, H, W), RGB, [0, 1] + """ + if type(path) is list: + img_path_l = path + else: + img_path_l = sorted(glob.glob(os.path.join(path, '*'))) + img_l = [read_img(None, v) for v in img_path_l] + # stack to Torch tensor + imgs = np.stack(img_l, axis=0) + imgs = imgs[:, :, :, [2, 1, 0]] + imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() + return imgs + + +def index_generation(crt_i, max_n, N, padding='reflection'): + """Generate an index list for reading N frames from a sequence of images + Args: + crt_i (int): current center index + max_n (int): max number of the sequence of images (calculated from 1) + N (int): reading N frames + padding (str): padding mode, one of replicate | reflection | new_info | circle + Example: crt_i = 0, N = 5 + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + new_info: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + Returns: + return_l (list [int]): a list of indexes + """ + max_n = max_n - 1 + n_pad = N // 2 + return_l = [] + + for i in range(crt_i - n_pad, crt_i + n_pad + 1): + if i < 0: + if padding == 'replicate': + add_idx = 0 + elif padding == 'reflection': + add_idx = -i + elif padding == 'new_info': + add_idx = (crt_i + n_pad) + (-i) + elif padding == 'circle': + add_idx = N + i + else: + raise ValueError('Wrong padding mode') + elif i > max_n: + if padding == 'replicate': + add_idx = max_n + elif padding == 'reflection': + add_idx = max_n * 2 - i + elif padding == 'new_info': + add_idx = (crt_i - n_pad) - (i - max_n) + elif padding == 'circle': + add_idx = i - N + else: + raise ValueError('Wrong padding mode') + else: + add_idx = i + return_l.append(add_idx) + return return_l + + +#################### +# image processing +# process on numpy image +#################### + + +def augment(img_list, hflip=True, rot=True): + """horizontal flip OR rotate (0, 90, 180, 270 degrees)""" + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = 
img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +def augment_flow(img_list, flow_list, hflip=True, rot=True): + """horizontal flip OR rotate (0, 90, 180, 270 degrees) with flows""" + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: + flow = flow[:, ::-1, :] + flow[:, :, 0] *= -1 + if vflip: + flow = flow[::-1, :, :] + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + rlt_img_list = [_augment(img) for img in img_list] + rlt_flow_list = [_augment_flow(flow) for flow in flow_list] + + return rlt_img_list, rlt_flow_list + + +def channel_convert(in_c, tar_type, img_list): + """conversion among BGR, gray and y""" + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +def rgb2ycbcr(img, only_y=True): + """same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + """ + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + """bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + """ + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + """same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + """ + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. 
+ return rlt.astype(in_img_type) + + +def modcrop(img_in, scale): + """img_in: Numpy, HWC or HW""" + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +#################### +# Functions +#################### + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (( + (absx > 1) * (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: CHW RGB [0,1] + # output: CHW RGB [0,1] w/o round + + in_C, in_H, in_W = img.size() + _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. 
The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i]) + + return out_2 + + +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC BGR [0,1] + # output: HWC BGR [0,1] w/o round + img = torch.from_numpy(img) + + in_H, in_W, in_C = img.size() + _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. 
+ + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i]) + out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i]) + out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i]) + + return out_2.numpy() + + +if __name__ == '__main__': + # test imresize function + # read images + img = cv2.imread('test.png') + img = img * 1.0 / 255 + img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() + # imresize + scale = 1 / 4 + import time + total_time = 0 + for i in range(10): + start_time = time.time() + rlt = imresize(img, scale, antialiasing=True) + use_time = time.time() - start_time + total_time += use_time + print('average time: {}'.format(total_time / 10)) + + import torchvision.utils + torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0, + normalize=False) diff --git a/codes/data/video_test_dataset.py b/codes/data/video_test_dataset.py new file mode 100644 index 00000000..ce891958 --- /dev/null +++ b/codes/data/video_test_dataset.py @@ -0,0 +1,84 @@ +import os.path as osp +import torch +import torch.utils.data as data +import data.util as util + + +class VideoTestDataset(data.Dataset): + """ + A video test dataset. 
Support: + Vid4 + REDS4 + Vimeo90K-Test + + no need to prepare LMDB files + """ + + def __init__(self, opt): + super(VideoTestDataset, self).__init__() + self.opt = opt + self.cache_data = opt['cache_data'] + self.half_N_frames = opt['N_frames'] // 2 + self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] + self.data_type = self.opt['data_type'] + self.data_info = {'path_LQ': [], 'path_GT': [], 'folder': [], 'idx': [], 'border': []} + if self.data_type == 'lmdb': + raise ValueError('No need to use LMDB during validation/test.') + #### Generate data info and cache data + self.imgs_LQ, self.imgs_GT = {}, {} + if opt['name'].lower() in ['vid4', 'reds4']: + subfolders_LQ = util.glob_file_list(self.LQ_root) + subfolders_GT = util.glob_file_list(self.GT_root) + for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT): + subfolder_name = osp.basename(subfolder_GT) + img_paths_LQ = util.glob_file_list(subfolder_LQ) + img_paths_GT = util.glob_file_list(subfolder_GT) + max_idx = len(img_paths_LQ) + assert max_idx == len( + img_paths_GT), 'Different number of images in LQ and GT folders' + self.data_info['path_LQ'].extend(img_paths_LQ) + self.data_info['path_GT'].extend(img_paths_GT) + self.data_info['folder'].extend([subfolder_name] * max_idx) + for i in range(max_idx): + self.data_info['idx'].append('{}/{}'.format(i, max_idx)) + border_l = [0] * max_idx + for i in range(self.half_N_frames): + border_l[i] = 1 + border_l[max_idx - i - 1] = 1 + self.data_info['border'].extend(border_l) + + if self.cache_data: + self.imgs_LQ[subfolder_name] = util.read_img_seq(img_paths_LQ) + self.imgs_GT[subfolder_name] = util.read_img_seq(img_paths_GT) + elif opt['name'].lower() in ['vimeo90k-test']: + pass # TODO + else: + raise ValueError( + 'Not support video test dataset. 
Support Vid4, REDS4 and Vimeo90k-Test.') + + def __getitem__(self, index): + # path_LQ = self.data_info['path_LQ'][index] + # path_GT = self.data_info['path_GT'][index] + folder = self.data_info['folder'][index] + idx, max_idx = self.data_info['idx'][index].split('/') + idx, max_idx = int(idx), int(max_idx) + border = self.data_info['border'][index] + + if self.cache_data: + select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'], + padding=self.opt['padding']) + imgs_LQ = self.imgs_LQ[folder].index_select(0, torch.LongTensor(select_idx)) + img_GT = self.imgs_GT[folder][idx] + else: + pass # TODO + + return { + 'LQs': imgs_LQ, + 'GT': img_GT, + 'folder': folder, + 'idx': self.data_info['idx'][index], + 'border': border + } + + def __len__(self): + return len(self.data_info['path_GT']) diff --git a/codes/data_scripts/create_lmdb.py b/codes/data_scripts/create_lmdb.py new file mode 100644 index 00000000..c44b5c78 --- /dev/null +++ b/codes/data_scripts/create_lmdb.py @@ -0,0 +1,411 @@ +"""Create lmdb files for [General images (291 images/DIV2K) | Vimeo90K | REDS] training datasets""" + +import sys +import os.path as osp +import glob +import pickle +from multiprocessing import Pool +import numpy as np +import lmdb +import cv2 + +sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) +import data.util as data_util # noqa: E402 +import utils.util as util # noqa: E402 + + +def main(): + dataset = 'DIV2K_demo' # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo |test + mode = 'GT' # used for vimeo90k and REDS datasets + # vimeo90k: GT | LR | flow + # REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp + # train_sharp_flowx4 + if dataset == 'vimeo90k': + vimeo90k(mode) + elif dataset == 'REDS': + REDS(mode) + elif dataset == 'general': + opt = {} + opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub' + opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb' + opt['name'] = 'DIV2K800_sub_GT' + general_image_folder(opt) + elif dataset == 'DIV2K_demo': + opt = {} + ## GT + opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub' + opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb' + opt['name'] = 'DIV2K800_sub_GT' + general_image_folder(opt) + ## LR + opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4' + opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb' + opt['name'] = 'DIV2K800_sub_bicLRx4' + general_image_folder(opt) + elif dataset == 'test': + test_lmdb('../../datasets/REDS/train_sharp_wval.lmdb', 'REDS') + + +def read_image_worker(path, key): + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + return (key, img) + + +def general_image_folder(opt): + """Create lmdb for general image folders + Users should define the keys, such as: '0321_s035' for DIV2K sub-images + If all the images have the same resolution, it will only store one copy of resolution info. + Otherwise, it will store every resolution info. + """ + #### configurations + read_all_imgs = False # whether real all images to memory with multiprocessing + # Set False for use limited memory + BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False + n_thread = 40 + ######################################################## + img_folder = opt['img_folder'] + lmdb_save_path = opt['lmdb_save_path'] + meta_info = {'name': opt['name']} + if not lmdb_save_path.endswith('.lmdb'): + raise ValueError("lmdb_save_path must end with \'lmdb\'.") + if osp.exists(lmdb_save_path): + print('Folder [{:s}] already exists. 
Exit...'.format(lmdb_save_path)) + sys.exit(1) + + #### read all the image paths to a list + print('Reading image path list ...') + all_img_list = sorted(glob.glob(osp.join(img_folder, '*'))) + keys = [] + for img_path in all_img_list: + keys.append(osp.splitext(osp.basename(img_path))[0]) + + if read_all_imgs: + #### read all images to memory (multiprocessing) + dataset = {} # store all image data. list cannot keep the order, use dict + print('Read images with multiprocessing, #thread: {} ...'.format(n_thread)) + pbar = util.ProgressBar(len(all_img_list)) + + def mycallback(arg): + '''get the image data and update pbar''' + key = arg[0] + dataset[key] = arg[1] + pbar.update('Reading {}'.format(key)) + + pool = Pool(n_thread) + for path, key in zip(all_img_list, keys): + pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) + pool.close() + pool.join() + print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list))) + + #### create lmdb environment + data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes + print('data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(all_img_list) + env = lmdb.open(lmdb_save_path, map_size=data_size * 10) + + #### write data to lmdb + pbar = util.ProgressBar(len(all_img_list)) + txn = env.begin(write=True) + resolutions = [] + for idx, (path, key) in enumerate(zip(all_img_list, keys)): + pbar.update('Write {}'.format(key)) + key_byte = key.encode('ascii') + data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED) + if data.ndim == 2: + H, W = data.shape + C = 1 + else: + H, W, C = data.shape + txn.put(key_byte, data) + resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W)) + if not read_all_imgs and idx % BATCH == 0: + txn.commit() + txn = env.begin(write=True) + txn.commit() + env.close() + print('Finish writing lmdb.') + + #### create meta information + # check whether all the images are the same size + assert len(keys) == len(resolutions) + if len(set(resolutions)) <= 1: + meta_info['resolution'] = [resolutions[0]] + meta_info['keys'] = keys + print('All images have the same resolution. Simplify the meta info.') + else: + meta_info['resolution'] = resolutions + meta_info['keys'] = keys + print('Not all images have the same resolution. 
Save meta info for each image.') + + pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) + print('Finish creating lmdb meta info.') + + +def vimeo90k(mode): + """Create lmdb for the Vimeo90K dataset, each image with a fixed size + GT: [3, 256, 448] + Now only need the 4th frame, e.g., 00001_0001_4 + LR: [3, 64, 112] + 1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7 + key: + Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001 + + flow: downsampled flow: [3, 360, 320], keys: 00001_0001_4_[p3, p2, p1, n1, n2, n3] + Each flow is calculated with GT images by PWCNet and then downsampled by 1/4 + Flow map is quantized by mmcv and saved in png format + """ + #### configurations + read_all_imgs = False # whether real all images to memory with multiprocessing + # Set False for use limited memory + BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False + if mode == 'GT': + img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences' + lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb' + txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' + H_dst, W_dst = 256, 448 + elif mode == 'LR': + img_folder = '../../datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences' + lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb' + txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' + H_dst, W_dst = 64, 112 + elif mode == 'flow': + img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences_flowx4' + lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_flowx4.lmdb' + txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' + H_dst, W_dst = 128, 112 + else: + raise ValueError('Wrong dataset mode: {}'.format(mode)) + n_thread = 40 + ######################################################## + if not lmdb_save_path.endswith('.lmdb'): + raise ValueError("lmdb_save_path must end with \'lmdb\'.") + if osp.exists(lmdb_save_path): + print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) + sys.exit(1) + + #### read all the image paths to a list + print('Reading image path list ...') + with open(txt_file) as f: + train_l = f.readlines() + train_l = [v.strip() for v in train_l] + all_img_list = [] + keys = [] + for line in train_l: + folder = line.split('/')[0] + sub_folder = line.split('/')[1] + all_img_list.extend(glob.glob(osp.join(img_folder, folder, sub_folder, '*'))) + if mode == 'flow': + for j in range(1, 4): + keys.append('{}_{}_4_n{}'.format(folder, sub_folder, j)) + keys.append('{}_{}_4_p{}'.format(folder, sub_folder, j)) + else: + for j in range(7): + keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1)) + all_img_list = sorted(all_img_list) + keys = sorted(keys) + if mode == 'GT': # only read the 4th frame for the GT mode + print('Only keep the 4th frame.') + all_img_list = [v for v in all_img_list if v.endswith('im4.png')] + keys = [v for v in keys if v.endswith('_4')] + + if read_all_imgs: + #### read all images to memory (multiprocessing) + dataset = {} # store all image data. 
list cannot keep the order, use dict + print('Read images with multiprocessing, #thread: {} ...'.format(n_thread)) + pbar = util.ProgressBar(len(all_img_list)) + + def mycallback(arg): + """get the image data and update pbar""" + key = arg[0] + dataset[key] = arg[1] + pbar.update('Reading {}'.format(key)) + + pool = Pool(n_thread) + for path, key in zip(all_img_list, keys): + pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) + pool.close() + pool.join() + print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list))) + + #### write data to lmdb + data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes + print('data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(all_img_list) + env = lmdb.open(lmdb_save_path, map_size=data_size * 10) + txn = env.begin(write=True) + pbar = util.ProgressBar(len(all_img_list)) + for idx, (path, key) in enumerate(zip(all_img_list, keys)): + pbar.update('Write {}'.format(key)) + key_byte = key.encode('ascii') + data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED) + if 'flow' in mode: + H, W = data.shape + assert H == H_dst and W == W_dst, 'different shape.' + else: + H, W, C = data.shape + assert H == H_dst and W == W_dst and C == 3, 'different shape.' + txn.put(key_byte, data) + if not read_all_imgs and idx % BATCH == 0: + txn.commit() + txn = env.begin(write=True) + txn.commit() + env.close() + print('Finish writing lmdb.') + + #### create meta information + meta_info = {} + if mode == 'GT': + meta_info['name'] = 'Vimeo90K_train_GT' + elif mode == 'LR': + meta_info['name'] = 'Vimeo90K_train_LR' + elif mode == 'flow': + meta_info['name'] = 'Vimeo90K_train_flowx4' + channel = 1 if 'flow' in mode else 3 + meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst) + key_set = set() + for key in keys: + if mode == 'flow': + a, b, _, _ = key.split('_') + else: + a, b, _ = key.split('_') + key_set.add('{}_{}'.format(a, b)) + meta_info['keys'] = list(key_set) + pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) + print('Finish creating lmdb meta info.') + + +def REDS(mode): + """Create lmdb for the REDS dataset, each image with a fixed size + GT: [3, 720, 1280], key: 000_00000000 + LR: [3, 180, 320], key: 000_00000000 + key: 000_00000000 + + flow: downsampled flow: [3, 360, 320], keys: 000_00000005_[p2, p1, n1, n2] + Each flow is calculated with the GT images by PWCNet and then downsampled by 1/4 + Flow map is quantized by mmcv and saved in png format + """ + #### configurations + read_all_imgs = False # whether real all images to memory with multiprocessing + # Set False for use limited memory + BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False + if mode == 'train_sharp': + img_folder = '../../datasets/REDS/train_sharp' + lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb' + H_dst, W_dst = 720, 1280 + elif mode == 'train_sharp_bicubic': + img_folder = '../../datasets/REDS/train_sharp_bicubic' + lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb' + H_dst, W_dst = 180, 320 + elif mode == 'train_blur_bicubic': + img_folder = '../../datasets/REDS/train_blur_bicubic' + lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb' + H_dst, W_dst = 180, 320 + elif mode == 'train_blur': + img_folder = '../../datasets/REDS/train_blur' + lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb' + H_dst, W_dst = 720, 1280 + elif mode == 'train_blur_comp': + img_folder = 
'../../datasets/REDS/train_blur_comp' + lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb' + H_dst, W_dst = 720, 1280 + elif mode == 'train_sharp_flowx4': + img_folder = '../../datasets/REDS/train_sharp_flowx4' + lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb' + H_dst, W_dst = 360, 320 + n_thread = 40 + ######################################################## + if not lmdb_save_path.endswith('.lmdb'): + raise ValueError("lmdb_save_path must end with \'lmdb\'.") + if osp.exists(lmdb_save_path): + print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) + sys.exit(1) + + #### read all the image paths to a list + print('Reading image path list ...') + all_img_list = data_util._get_paths_from_images(img_folder) + keys = [] + for img_path in all_img_list: + split_rlt = img_path.split('/') + folder = split_rlt[-2] + img_name = split_rlt[-1].split('.png')[0] + keys.append(folder + '_' + img_name) + + if read_all_imgs: + #### read all images to memory (multiprocessing) + dataset = {} # store all image data. list cannot keep the order, use dict + print('Read images with multiprocessing, #thread: {} ...'.format(n_thread)) + pbar = util.ProgressBar(len(all_img_list)) + + def mycallback(arg): + '''get the image data and update pbar''' + key = arg[0] + dataset[key] = arg[1] + pbar.update('Reading {}'.format(key)) + + pool = Pool(n_thread) + for path, key in zip(all_img_list, keys): + pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) + pool.close() + pool.join() + print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list))) + + #### create lmdb environment + data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes + print('data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(all_img_list) + env = lmdb.open(lmdb_save_path, map_size=data_size * 10) + + #### write data to lmdb + pbar = util.ProgressBar(len(all_img_list)) + txn = env.begin(write=True) + for idx, (path, key) in enumerate(zip(all_img_list, keys)): + pbar.update('Write {}'.format(key)) + key_byte = key.encode('ascii') + data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED) + if 'flow' in mode: + H, W = data.shape + assert H == H_dst and W == W_dst, 'different shape.' + else: + H, W, C = data.shape + assert H == H_dst and W == W_dst and C == 3, 'different shape.' 
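+            # store the raw ndarray bytes (as decoded by cv2.imread) under the
+            # 'folder_frame' ASCII key; the shape is recorded once in meta_info.pkl
+            # rather than alongside each value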
+ txn.put(key_byte, data) + if not read_all_imgs and idx % BATCH == 0: + txn.commit() + txn = env.begin(write=True) + txn.commit() + env.close() + print('Finish writing lmdb.') + + #### create meta information + meta_info = {} + meta_info['name'] = 'REDS_{}_wval'.format(mode) + channel = 1 if 'flow' in mode else 3 + meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst) + meta_info['keys'] = keys + pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) + print('Finish creating lmdb meta info.') + + +def test_lmdb(dataroot, dataset='REDS'): + env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False) + meta_info = pickle.load(open(osp.join(dataroot, 'meta_info.pkl'), "rb")) + print('Name: ', meta_info['name']) + print('Resolution: ', meta_info['resolution']) + print('# keys: ', len(meta_info['keys'])) + # read one image + if dataset == 'vimeo90k': + key = '00001_0001_4' + else: + key = '000_00000000' + print('Reading {} for test.'.format(key)) + with env.begin(write=False) as txn: + buf = txn.get(key.encode('ascii')) + img_flat = np.frombuffer(buf, dtype=np.uint8) + C, H, W = [int(s) for s in meta_info['resolution'].split('_')] + img = img_flat.reshape(H, W, C) + cv2.imwrite('test.png', img) + + +if __name__ == "__main__": + main() diff --git a/codes/data_scripts/extract_subimages.py b/codes/data_scripts/extract_subimages.py new file mode 100644 index 00000000..a8185df2 --- /dev/null +++ b/codes/data_scripts/extract_subimages.py @@ -0,0 +1,141 @@ +"""A multi-thread tool to crop large images to sub-images for faster IO.""" +import os +import os.path as osp +import sys +from multiprocessing import Pool +import numpy as np +import cv2 +from PIL import Image +sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) +from utils.util import ProgressBar # noqa: E402 +import data.util as data_util # noqa: E402 + + +def main(): + mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs) + opt = {} + opt['n_thread'] = 20 + opt['compression_level'] = 3 # 3 is the default value in cv2 + # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer + # compression time. If read raw images during training, use 0 for faster IO speed. + if mode == 'single': + opt['input_folder'] = '../../datasets/DIV2K/DIV2K_train_HR' + opt['save_folder'] = '../../datasets/DIV2K/DIV2K800_sub' + opt['crop_sz'] = 480 # the size of each sub-image + opt['step'] = 240 # step of the sliding crop window + opt['thres_sz'] = 48 # size threshold + extract_signle(opt) + elif mode == 'pair': + GT_folder = '../../datasets/DIV2K/DIV2K_train_HR' + LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4' + save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub' + save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4' + scale_ratio = 4 + crop_sz = 480 # the size of each sub-image (GT) + step = 240 # step of the sliding crop window (GT) + thres_sz = 48 # size threshold + ######################################################################## + # check that all the GT and LR images have correct scale ratio + img_GT_list = data_util._get_paths_from_images(GT_folder) + img_LR_list = data_util._get_paths_from_images(LR_folder) + assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.' 
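+    # sanity-check that every GT/LR pair really differs by the expected scale
+    # factor before cropping aligned sub-images from the two folders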
+ for path_GT, path_LR in zip(img_GT_list, img_LR_list): + img_GT = Image.open(path_GT) + img_LR = Image.open(path_LR) + w_GT, h_GT = img_GT.size + w_LR, h_LR = img_LR.size + assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501 + w_GT, scale_ratio, w_LR, path_GT) + assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501 + w_GT, scale_ratio, w_LR, path_GT) + # check crop size, step and threshold size + assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format( + scale_ratio) + assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio) + assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format( + scale_ratio) + print('process GT...') + opt['input_folder'] = GT_folder + opt['save_folder'] = save_GT_folder + opt['crop_sz'] = crop_sz + opt['step'] = step + opt['thres_sz'] = thres_sz + extract_signle(opt) + print('process LR...') + opt['input_folder'] = LR_folder + opt['save_folder'] = save_LR_folder + opt['crop_sz'] = crop_sz // scale_ratio + opt['step'] = step // scale_ratio + opt['thres_sz'] = thres_sz // scale_ratio + extract_signle(opt) + assert len(data_util._get_paths_from_images(save_GT_folder)) == len( + data_util._get_paths_from_images( + save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.' + else: + raise ValueError('Wrong mode.') + + +def extract_signle(opt): + input_folder = opt['input_folder'] + save_folder = opt['save_folder'] + if not osp.exists(save_folder): + os.makedirs(save_folder) + print('mkdir [{:s}] ...'.format(save_folder)) + else: + print('Folder [{:s}] already exists. Exit...'.format(save_folder)) + sys.exit(1) + img_list = data_util._get_paths_from_images(input_folder) + + def update(arg): + pbar.update(arg) + + pbar = ProgressBar(len(img_list)) + + pool = Pool(opt['n_thread']) + for path in img_list: + pool.apply_async(worker, args=(path, opt), callback=update) + pool.close() + pool.join() + print('All subprocesses done.') + + +def worker(path, opt): + crop_sz = opt['crop_sz'] + step = opt['step'] + thres_sz = opt['thres_sz'] + img_name = osp.basename(path) + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + + n_channels = len(img.shape) + if n_channels == 2: + h, w = img.shape + elif n_channels == 3: + h, w, c = img.shape + else: + raise ValueError('Wrong image shape - {}'.format(n_channels)) + + h_space = np.arange(0, h - crop_sz + 1, step) + if h - (h_space[-1] + crop_sz) > thres_sz: + h_space = np.append(h_space, h - crop_sz) + w_space = np.arange(0, w - crop_sz + 1, step) + if w - (w_space[-1] + crop_sz) > thres_sz: + w_space = np.append(w_space, w - crop_sz) + + index = 0 + for x in h_space: + for y in w_space: + index += 1 + if n_channels == 2: + crop_img = img[x:x + crop_sz, y:y + crop_sz] + else: + crop_img = img[x:x + crop_sz, y:y + crop_sz, :] + crop_img = np.ascontiguousarray(crop_img) + cv2.imwrite( + osp.join(opt['save_folder'], + img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img, + [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) + return 'Processing {:s} ...'.format(img_name) + + +if __name__ == '__main__': + main() diff --git a/codes/data_scripts/generate_LR_Vimeo90K.m b/codes/data_scripts/generate_LR_Vimeo90K.m new file mode 100644 index 00000000..acce7898 --- /dev/null +++ b/codes/data_scripts/generate_LR_Vimeo90K.m @@ -0,0 +1,49 @@ +function generate_LR_Vimeo90K() +%% matlab code to genetate 
bicubic-downsampled for Vimeo90K dataset + +up_scale = 4; +mod_scale = 4; +idx = 0; +filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png'); +for i = 1 : length(filepaths) + [~,imname,ext] = fileparts(filepaths(i).name); + folder_path = filepaths(i).folder; + save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4'); + if ~exist(save_LR_folder, 'dir') + mkdir(save_LR_folder); + end + if isempty(imname) + disp('Ignore . folder.'); + elseif strcmp(imname, '.') + disp('Ignore .. folder.'); + else + idx = idx + 1; + str_rlt = sprintf('%d\t%s.\n', idx, imname); + fprintf(str_rlt); + % read image + img = imread(fullfile(folder_path, [imname, ext])); + img = im2double(img); + % modcrop + img = modcrop(img, mod_scale); + % LR + im_LR = imresize(img, 1/up_scale, 'bicubic'); + if exist('save_LR_folder', 'var') + imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); + end + end +end +end + +%% modcrop +function img = modcrop(img, modulo) +if size(img,3) == 1 + sz = size(img); + sz = sz - mod(sz, modulo); + img = img(1:sz(1), 1:sz(2)); +else + tmpsz = size(img); + sz = tmpsz(1:2); + sz = sz - mod(sz, modulo); + img = img(1:sz(1), 1:sz(2),:); +end +end diff --git a/codes/data_scripts/generate_mod_LR_bic.m b/codes/data_scripts/generate_mod_LR_bic.m new file mode 100644 index 00000000..05a9c61a --- /dev/null +++ b/codes/data_scripts/generate_mod_LR_bic.m @@ -0,0 +1,82 @@ +function generate_mod_LR_bic() +%% matlab code to genetate mod images, bicubic-downsampled LR, bicubic_upsampled images. + +%% set parameters +% comment the unnecessary line +input_folder = '../../datasets/DIV2K/DIV2K800'; +% save_mod_folder = ''; +save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4'; +% save_bic_folder = ''; + +up_scale = 4; +mod_scale = 4; + +if exist('save_mod_folder', 'var') + if exist(save_mod_folder, 'dir') + disp(['It will cover ', save_mod_folder]); + else + mkdir(save_mod_folder); + end +end +if exist('save_LR_folder', 'var') + if exist(save_LR_folder, 'dir') + disp(['It will cover ', save_LR_folder]); + else + mkdir(save_LR_folder); + end +end +if exist('save_bic_folder', 'var') + if exist(save_bic_folder, 'dir') + disp(['It will cover ', save_bic_folder]); + else + mkdir(save_bic_folder); + end +end + +idx = 0; +filepaths = dir(fullfile(input_folder,'*.*')); +for i = 1 : length(filepaths) + [paths,imname,ext] = fileparts(filepaths(i).name); + if isempty(imname) + disp('Ignore . folder.'); + elseif strcmp(imname, '.') + disp('Ignore .. 
folder.'); + else + idx = idx + 1; + str_rlt = sprintf('%d\t%s.\n', idx, imname); + fprintf(str_rlt); + % read image + img = imread(fullfile(input_folder, [imname, ext])); + img = im2double(img); + % modcrop + img = modcrop(img, mod_scale); + if exist('save_mod_folder', 'var') + imwrite(img, fullfile(save_mod_folder, [imname, '.png'])); + end + % LR + im_LR = imresize(img, 1/up_scale, 'bicubic'); + if exist('save_LR_folder', 'var') + imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); + end + % Bicubic + if exist('save_bic_folder', 'var') + im_B = imresize(im_LR, up_scale, 'bicubic'); + imwrite(im_B, fullfile(save_bic_folder, [imname, '.png'])); + end + end +end +end + +%% modcrop +function img = modcrop(img, modulo) +if size(img,3) == 1 + sz = size(img); + sz = sz - mod(sz, modulo); + img = img(1:sz(1), 1:sz(2)); +else + tmpsz = size(img); + sz = tmpsz(1:2); + sz = sz - mod(sz, modulo); + img = img(1:sz(1), 1:sz(2),:); +end +end diff --git a/codes/data_scripts/generate_mod_LR_bic.py b/codes/data_scripts/generate_mod_LR_bic.py new file mode 100644 index 00000000..59b313a8 --- /dev/null +++ b/codes/data_scripts/generate_mod_LR_bic.py @@ -0,0 +1,81 @@ +import os +import sys +import cv2 +import numpy as np + +try: + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from data.util import imresize_np +except ImportError: + pass + + +def generate_mod_LR_bic(): + # set parameters + up_scale = 4 + mod_scale = 4 + # set data dir + sourcedir = '/data/datasets/img' + savedir = '/data/datasets/mod' + + saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale)) + saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale)) + saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale)) + + if not os.path.isdir(sourcedir): + print('Error: No source data found') + exit(0) + if not os.path.isdir(savedir): + os.mkdir(savedir) + + if not os.path.isdir(os.path.join(savedir, 'HR')): + os.mkdir(os.path.join(savedir, 'HR')) + if not os.path.isdir(os.path.join(savedir, 'LR')): + os.mkdir(os.path.join(savedir, 'LR')) + if not os.path.isdir(os.path.join(savedir, 'Bic')): + os.mkdir(os.path.join(savedir, 'Bic')) + + if not os.path.isdir(saveHRpath): + os.mkdir(saveHRpath) + else: + print('It will cover ' + str(saveHRpath)) + + if not os.path.isdir(saveLRpath): + os.mkdir(saveLRpath) + else: + print('It will cover ' + str(saveLRpath)) + + if not os.path.isdir(saveBicpath): + os.mkdir(saveBicpath) + else: + print('It will cover ' + str(saveBicpath)) + + filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')] + num_files = len(filepaths) + + # prepare data with augementation + for i in range(num_files): + filename = filepaths[i] + print('No.{} -- Processing {}'.format(i, filename)) + # read image + image = cv2.imread(os.path.join(sourcedir, filename)) + + width = int(np.floor(image.shape[1] / mod_scale)) + height = int(np.floor(image.shape[0] / mod_scale)) + # modcrop + if len(image.shape) == 3: + image_HR = image[0:mod_scale * height, 0:mod_scale * width, :] + else: + image_HR = image[0:mod_scale * height, 0:mod_scale * width] + # LR + image_LR = imresize_np(image_HR, 1 / up_scale, True) + # bic + image_Bic = imresize_np(image_LR, up_scale, True) + + cv2.imwrite(os.path.join(saveHRpath, filename), image_HR) + cv2.imwrite(os.path.join(saveLRpath, filename), image_LR) + cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic) + + +if __name__ == "__main__": + generate_mod_LR_bic() diff --git a/codes/data_scripts/prepare_DIV2K_x4_dataset.sh 
b/codes/data_scripts/prepare_DIV2K_x4_dataset.sh new file mode 100644 index 00000000..a53bd1f0 --- /dev/null +++ b/codes/data_scripts/prepare_DIV2K_x4_dataset.sh @@ -0,0 +1,42 @@ + + +echo "Prepare DIV2K X4 datasets..." +cd ../../datasets +mkdir DIV2K +cd DIV2K + +#### Step 1 +echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..." +# GT +FOLDER=DIV2K_train_HR +FILE=DIV2K_train_HR.zip +if [ ! -d "$FOLDER" ]; then + if [ ! -f "$FILE" ]; then + echo "Downloading $FILE..." + wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE + fi + unzip $FILE +fi +# LR +FOLDER=DIV2K_train_LR_bicubic +FILE=DIV2K_train_LR_bicubic_X4.zip +if [ ! -d "$FOLDER" ]; then + if [ ! -f "$FILE" ]; then + echo "Downloading $FILE..." + wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE + fi + unzip $FILE +fi + +#### Step 2 +echo "Step 2: Rename the LR images..." +cd ../../codes/data_scripts +python rename.py + +#### Step 4 +echo "Step 4: Crop to sub-images..." +python extract_subimages.py + +#### Step 5 +echo "Step5: Create LMDB files..." +python create_lmdb.py diff --git a/codes/data_scripts/regroup_REDS.py b/codes/data_scripts/regroup_REDS.py new file mode 100644 index 00000000..7c8fa928 --- /dev/null +++ b/codes/data_scripts/regroup_REDS.py @@ -0,0 +1,11 @@ +import os +import glob + +train_path = '/home/xtwang/datasets/REDS/train_sharp_bicubic/X4' +val_path = '/home/xtwang/datasets/REDS/val_sharp_bicubic/X4' + +# mv the val set +val_folders = glob.glob(os.path.join(val_path, '*')) +for folder in val_folders: + new_folder_idx = '{:03d}'.format(int(folder.split('/')[-1]) + 240) + os.system('cp -r {} {}'.format(folder, os.path.join(train_path, new_folder_idx))) diff --git a/codes/data_scripts/rename.py b/codes/data_scripts/rename.py new file mode 100644 index 00000000..f8a19552 --- /dev/null +++ b/codes/data_scripts/rename.py @@ -0,0 +1,19 @@ +import os +import glob + + +def main(): + folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4' + DIV2K(folder) + print('Finished.') + + +def DIV2K(path): + img_path_l = glob.glob(os.path.join(path, '*')) + for img_path in img_path_l: + new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '') + os.rename(img_path, new_path) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/codes/data_scripts/test_dataloader.py b/codes/data_scripts/test_dataloader.py new file mode 100644 index 00000000..5f580793 --- /dev/null +++ b/codes/data_scripts/test_dataloader.py @@ -0,0 +1,104 @@ +import sys +import os.path as osp +import math +import torchvision.utils + +sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) +from data import create_dataloader, create_dataset # noqa: E402 +from utils import util # noqa: E402 + + +def main(): + dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub + opt = {} + opt['dist'] = False + opt['gpu_ids'] = [0] + if dataset == 'REDS': + opt['name'] = 'test_REDS' + opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb' + opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb' + opt['mode'] = 'REDS' + opt['N_frames'] = 5 + opt['phase'] = 'train' + opt['use_shuffle'] = True + opt['n_workers'] = 8 + opt['batch_size'] = 16 + opt['GT_size'] = 256 + opt['LQ_size'] = 64 + opt['scale'] = 4 + opt['use_flip'] = True + opt['use_rot'] = True + opt['interval_list'] = [1] + opt['random_reverse'] = False + opt['border_mode'] = False + opt['cache_keys'] = None + opt['data_type'] = 'lmdb' # img | lmdb | mc + elif dataset == 'Vimeo90K': + 
opt['name'] = 'test_Vimeo90K' + opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb' + opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb' + opt['mode'] = 'Vimeo90K' + opt['N_frames'] = 7 + opt['phase'] = 'train' + opt['use_shuffle'] = True + opt['n_workers'] = 8 + opt['batch_size'] = 16 + opt['GT_size'] = 256 + opt['LQ_size'] = 64 + opt['scale'] = 4 + opt['use_flip'] = True + opt['use_rot'] = True + opt['interval_list'] = [1] + opt['random_reverse'] = False + opt['border_mode'] = False + opt['cache_keys'] = None + opt['data_type'] = 'lmdb' # img | lmdb | mc + elif dataset == 'DIV2K800_sub': + opt['name'] = 'DIV2K800' + opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb' + opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb' + opt['mode'] = 'LQGT' + opt['phase'] = 'train' + opt['use_shuffle'] = True + opt['n_workers'] = 8 + opt['batch_size'] = 16 + opt['GT_size'] = 128 + opt['scale'] = 4 + opt['use_flip'] = True + opt['use_rot'] = True + opt['color'] = 'RGB' + opt['data_type'] = 'lmdb' # img | lmdb + else: + raise ValueError('Please implement by yourself.') + + util.mkdir('tmp') + train_set = create_dataset(opt) + train_loader = create_dataloader(train_set, opt, opt, None) + nrow = int(math.sqrt(opt['batch_size'])) + padding = 2 if opt['phase'] == 'train' else 0 + + print('start...') + for i, data in enumerate(train_loader): + if i > 5: + break + print(i) + if dataset == 'REDS' or dataset == 'Vimeo90K': + LQs = data['LQs'] + else: + LQ = data['LQ'] + GT = data['GT'] + + if dataset == 'REDS' or dataset == 'Vimeo90K': + for j in range(LQs.size(1)): + torchvision.utils.save_image(LQs[:, j, :, :, :], + 'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow, + padding=padding, normalize=False) + else: + torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow, + padding=padding, normalize=False) + torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding, + normalize=False) + + +if __name__ == "__main__": + main() diff --git a/codes/metrics/calculate_PSNR_SSIM.m b/codes/metrics/calculate_PSNR_SSIM.m new file mode 100644 index 00000000..471de83f --- /dev/null +++ b/codes/metrics/calculate_PSNR_SSIM.m @@ -0,0 +1,261 @@ +function calculate_PSNR_SSIM() + +% GT and SR folder +folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'; +folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'; +scale = 4; +suffix = ''; % suffix for SR images +test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels +if test_Y + fprintf('Tesing Y channel.\n'); +else + fprintf('Tesing RGB channels.\n'); +end +filepaths = dir(fullfile(folder_GT, '*.png')); +PSNR_all = zeros(1, length(filepaths)); +SSIM_all = zeros(1, length(filepaths)); + +for idx_im = 1:length(filepaths) + im_name = filepaths(idx_im).name; + im_GT = imread(fullfile(folder_GT, im_name)); + im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png'])); + + if test_Y % evaluate on Y channel in YCbCr color space + if size(im_GT, 3) == 3 + im_GT_YCbCr = rgb2ycbcr(im2double(im_GT)); + im_GT_in = im_GT_YCbCr(:,:,1); + im_SR_YCbCr = rgb2ycbcr(im2double(im_SR)); + im_SR_in = im_SR_YCbCr(:,:,1); + else + im_GT_in = im2double(im_GT); + im_SR_in = im2double(im_SR); + end + else % evaluate on RGB channels + im_GT_in = im2double(im_GT); + im_SR_in = im2double(im_SR); + end + + % calculate PSNR and SSIM + PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale); + SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 
255, im_SR_in * 255, scale); + fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im)); +end + +fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all)); +end + +function res = calculate_PSNR(GT, SR, border) +% remove border +GT = GT(border+1:end-border, border+1:end-border, :); +SR = SR(border+1:end-border, border+1:end-border, :); +% calculate PNSR (assume in [0,255]) +error = GT(:) - SR(:); +mse = mean(error.^2); +res = 10 * log10(255^2/mse); +end + +function res = calculate_SSIM(GT, SR, border) +GT = GT(border+1:end-border, border+1:end-border, :); +SR = SR(border+1:end-border, border+1:end-border, :); +% calculate SSIM +mssim = zeros(1, size(SR, 3)); +for i = 1:size(SR,3) + [mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i)); +end +res = mean(mssim); +end + +function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L) + +%======================================================================== +%SSIM Index, Version 1.0 +%Copyright(c) 2003 Zhou Wang +%All Rights Reserved. +% +%The author is with Howard Hughes Medical Institute, and Laboratory +%for Computational Vision at Center for Neural Science and Courant +%Institute of Mathematical Sciences, New York University. +% +%---------------------------------------------------------------------- +%Permission to use, copy, or modify this software and its documentation +%for educational and research purposes only and without fee is hereby +%granted, provided that this copyright notice and the original authors' +%names appear on all copies and supporting documentation. This program +%shall not be used, rewritten, or adapted as the basis of a commercial +%software or hardware product without first obtaining permission of the +%authors. The authors make no representations about the suitability of +%this software for any purpose. It is provided "as is" without express +%or implied warranty. +%---------------------------------------------------------------------- +% +%This is an implementation of the algorithm for calculating the +%Structural SIMilarity (SSIM) index between two images. Please refer +%to the following paper: +% +%Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image +%quality assessment: From error measurement to structural similarity" +%IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004. +% +%Kindly report any suggestions or corrections to zhouwang@ieee.org +% +%---------------------------------------------------------------------- +% +%Input : (1) img1: the first image being compared +% (2) img2: the second image being compared +% (3) K: constants in the SSIM index formula (see the above +% reference). defualt value: K = [0.01 0.03] +% (4) window: local window for statistics (see the above +% reference). default widnow is Gaussian given by +% window = fspecial('gaussian', 11, 1.5); +% (5) L: dynamic range of the images. default: L = 255 +% +%Output: (1) mssim: the mean SSIM index value between 2 images. +% If one of the images being compared is regarded as +% perfect quality, then mssim can be considered as the +% quality measure of the other image. +% If img1 = img2, then mssim = 1. +% (2) ssim_map: the SSIM index map of the test image. The map +% has a smaller size than the input images. The actual size: +% size(img1) - size(window) + 1. 
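+%            Each map entry follows
+%            SSIM(x,y) = ((2*mu_x*mu_y + C1)*(2*sigma_xy + C2)) ...
+%                        / ((mu_x^2 + mu_y^2 + C1)*(sigma_x^2 + sigma_y^2 + C2)).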
+% +%Default Usage: +% Given 2 test images img1 and img2, whose dynamic range is 0-255 +% +% [mssim ssim_map] = ssim_index(img1, img2); +% +%Advanced Usage: +% User defined parameters. For example +% +% K = [0.05 0.05]; +% window = ones(8); +% L = 100; +% [mssim ssim_map] = ssim_index(img1, img2, K, window, L); +% +%See the results: +% +% mssim %Gives the mssim value +% imshow(max(0, ssim_map).^4) %Shows the SSIM index map +% +%======================================================================== + + +if (nargin < 2 || nargin > 5) + ssim_index = -Inf; + ssim_map = -Inf; + return; +end + +if (size(img1) ~= size(img2)) + ssim_index = -Inf; + ssim_map = -Inf; + return; +end + +[M, N] = size(img1); + +if (nargin == 2) + if ((M < 11) || (N < 11)) + ssim_index = -Inf; + ssim_map = -Inf; + return + end + window = fspecial('gaussian', 11, 1.5); % + K(1) = 0.01; % default settings + K(2) = 0.03; % + L = 255; % +end + +if (nargin == 3) + if ((M < 11) || (N < 11)) + ssim_index = -Inf; + ssim_map = -Inf; + return + end + window = fspecial('gaussian', 11, 1.5); + L = 255; + if (length(K) == 2) + if (K(1) < 0 || K(2) < 0) + ssim_index = -Inf; + ssim_map = -Inf; + return; + end + else + ssim_index = -Inf; + ssim_map = -Inf; + return; + end +end + +if (nargin == 4) + [H, W] = size(window); + if ((H*W) < 4 || (H > M) || (W > N)) + ssim_index = -Inf; + ssim_map = -Inf; + return + end + L = 255; + if (length(K) == 2) + if (K(1) < 0 || K(2) < 0) + ssim_index = -Inf; + ssim_map = -Inf; + return; + end + else + ssim_index = -Inf; + ssim_map = -Inf; + return; + end +end + +if (nargin == 5) + [H, W] = size(window); + if ((H*W) < 4 || (H > M) || (W > N)) + ssim_index = -Inf; + ssim_map = -Inf; + return + end + if (length(K) == 2) + if (K(1) < 0 || K(2) < 0) + ssim_index = -Inf; + ssim_map = -Inf; + return; + end + else + ssim_index = -Inf; + ssim_map = -Inf; + return; + end +end + +C1 = (K(1)*L)^2; +C2 = (K(2)*L)^2; +window = window/sum(sum(window)); +img1 = double(img1); +img2 = double(img2); + +mu1 = filter2(window, img1, 'valid'); +mu2 = filter2(window, img2, 'valid'); +mu1_sq = mu1.*mu1; +mu2_sq = mu2.*mu2; +mu1_mu2 = mu1.*mu2; +sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; +sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; +sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; + +if (C1 > 0 && C2 > 0) + ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); +else + numerator1 = 2*mu1_mu2 + C1; + numerator2 = 2*sigma12 + C2; + denominator1 = mu1_sq + mu2_sq + C1; + denominator2 = sigma1_sq + sigma2_sq + C2; + ssim_map = ones(size(mu1)); + index = (denominator1.*denominator2 > 0); + ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index)); + index = (denominator1 ~= 0) & (denominator2 == 0); + ssim_map(index) = numerator1(index)./denominator1(index); +end + +mssim = mean2(ssim_map); + +end diff --git a/codes/metrics/calculate_PSNR_SSIM.py b/codes/metrics/calculate_PSNR_SSIM.py new file mode 100644 index 00000000..fecc0310 --- /dev/null +++ b/codes/metrics/calculate_PSNR_SSIM.py @@ -0,0 +1,147 @@ +''' +calculate the PSNR and SSIM. 
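+PSNR is computed as 20 * log10(255 / sqrt(MSE)); SSIM uses an 11x11 Gaussian
+window (sigma 1.5), following Wang et al., 2004.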
+same as MATLAB's results +''' +import os +import math +import numpy as np +import cv2 +import glob + + +def main(): + # Configurations + + # GT - Ground-truth; + # Gen: Generated / Restored / Recovered images + folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5' + folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5' + + crop_border = 4 + suffix = '' # suffix for Gen images + test_Y = False # True: test Y channel only; False: test RGB channels + + PSNR_all = [] + SSIM_all = [] + img_list = sorted(glob.glob(folder_GT + '/*')) + + if test_Y: + print('Testing Y channel.') + else: + print('Testing RGB channels.') + + for i, img_path in enumerate(img_list): + base_name = os.path.splitext(os.path.basename(img_path))[0] + im_GT = cv2.imread(img_path) / 255. + im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255. + + if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space + im_GT_in = bgr2ycbcr(im_GT) + im_Gen_in = bgr2ycbcr(im_Gen) + else: + im_GT_in = im_GT + im_Gen_in = im_Gen + + # crop borders + if im_GT_in.ndim == 3: + cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :] + cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :] + elif im_GT_in.ndim == 2: + cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border] + cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border] + else: + raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim)) + + # calculate PSNR and SSIM + PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255) + + SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255) + print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format( + i + 1, base_name, PSNR, SSIM)) + PSNR_all.append(PSNR) + SSIM_all.append(SSIM) + print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format( + sum(PSNR_all) / len(PSNR_all), + sum(SSIM_all) / len(SSIM_all))) + + +def calculate_psnr(img1, img2): + # img1 and img2 have range [0, 255] + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +def calculate_ssim(img1, img2): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1, img2)) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def 
bgr2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +if __name__ == '__main__': + main() diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py new file mode 100644 index 00000000..77bb0fca --- /dev/null +++ b/codes/models/SRGAN_model.py @@ -0,0 +1,267 @@ +import logging +from collections import OrderedDict +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from .base_model import BaseModel +from models.loss import GANLoss + +logger = logging.getLogger('base') + + +class SRGANModel(BaseModel): + def __init__(self, opt): + super(SRGANModel, self).__init__(opt) + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + train_opt = opt['train'] + + # define networks and load pretrained models + self.netG = networks.define_G(opt).to(self.device) + if opt['dist']: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + else: + self.netG = DataParallel(self.netG) + if self.is_train: + self.netD = networks.define_D(opt).to(self.device) + if opt['dist']: + self.netD = DistributedDataParallel(self.netD, + device_ids=[torch.cuda.current_device()]) + else: + self.netD = DataParallel(self.netD) + + self.netG.train() + self.netD.train() + + # define losses, optimizer and scheduler + if self.is_train: + # G pixel loss + if train_opt['pixel_weight'] > 0: + l_pix_type = train_opt['pixel_criterion'] + if l_pix_type == 'l1': + self.cri_pix = nn.L1Loss().to(self.device) + elif l_pix_type == 'l2': + self.cri_pix = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) + self.l_pix_w = train_opt['pixel_weight'] + else: + logger.info('Remove pixel loss.') + self.cri_pix = None + + # G feature loss + if train_opt['feature_weight'] > 0: + l_fea_type = train_opt['feature_criterion'] + if l_fea_type == 'l1': + self.cri_fea = nn.L1Loss().to(self.device) + elif l_fea_type == 'l2': + self.cri_fea = nn.MSELoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) + self.l_fea_w = train_opt['feature_weight'] + else: + logger.info('Remove feature loss.') + self.cri_fea = None + if self.cri_fea: # load VGG perceptual loss + self.netF = networks.define_F(opt, use_bn=False).to(self.device) + if opt['dist']: + self.netF = DistributedDataParallel(self.netF, + device_ids=[torch.cuda.current_device()]) + else: + self.netF = DataParallel(self.netF) + + # GD gan loss + self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) + self.l_gan_w = train_opt['gan_weight'] + # D_update_ratio and D_init_iters + self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 + self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 + + # optimizers + # G + wd_G = train_opt['weight_decay_G'] if 
train_opt['weight_decay_G'] else 0 + optim_params = [] + for k, v in self.netG.named_parameters(): # can optimize for a part of the model + if v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + weight_decay=wd_G, + betas=(train_opt['beta1_G'], train_opt['beta2_G'])) + self.optimizers.append(self.optimizer_G) + # D + wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], + weight_decay=wd_D, + betas=(train_opt['beta1_D'], train_opt['beta2_D'])) + self.optimizers.append(self.optimizer_D) + + # schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'])) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + raise NotImplementedError('MultiStepLR learning rate scheme is enough.') + + self.log_dict = OrderedDict() + + self.print_network() # print network + self.load() # load G and D if needed + + def feed_data(self, data, need_GT=True): + self.var_L = data['LQ'].to(self.device) # LQ + if need_GT: + self.var_H = data['GT'].to(self.device) # GT + input_ref = data['ref'] if 'ref' in data else data['GT'] + self.var_ref = input_ref.to(self.device) + + def optimize_parameters(self, step): + # G + for p in self.netD.parameters(): + p.requires_grad = False + + self.optimizer_G.zero_grad() + self.fake_H = self.netG(self.var_L) + + l_g_total = 0 + if step % self.D_update_ratio == 0 and step > self.D_init_iters: + if self.cri_pix: # pixel loss + l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) + l_g_total += l_g_pix + if self.cri_fea: # feature loss + real_fea = self.netF(self.var_H).detach() + fake_fea = self.netF(self.fake_H) + l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) + l_g_total += l_g_fea + + pred_g_fake = self.netD(self.fake_H) + if self.opt['train']['gan_type'] == 'gan': + l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) + elif self.opt['train']['gan_type'] == 'ragan': + pred_d_real = self.netD(self.var_ref).detach() + l_g_gan = self.l_gan_w * ( + self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 + l_g_total += l_g_gan + + l_g_total.backward() + self.optimizer_G.step() + + # D + for p in self.netD.parameters(): + p.requires_grad = True + + self.optimizer_D.zero_grad() + l_d_total = 0 + pred_d_real = self.netD(self.var_ref) + pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G + if self.opt['train']['gan_type'] == 'gan': + l_d_real = self.cri_gan(pred_d_real, True) + l_d_fake = self.cri_gan(pred_d_fake, False) + l_d_total = l_d_real + l_d_fake + elif self.opt['train']['gan_type'] == 'ragan': + l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) + l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) + l_d_total = (l_d_real + l_d_fake) / 2 + + 
l_d_total.backward() + self.optimizer_D.step() + + # set log + if step % self.D_update_ratio == 0 and step > self.D_init_iters: + if self.cri_pix: + self.log_dict['l_g_pix'] = l_g_pix.item() + if self.cri_fea: + self.log_dict['l_g_fea'] = l_g_fea.item() + self.log_dict['l_g_gan'] = l_g_gan.item() + + self.log_dict['l_d_real'] = l_d_real.item() + self.log_dict['l_d_fake'] = l_d_fake.item() + self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) + self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) + + def test(self): + self.netG.eval() + with torch.no_grad(): + self.fake_H = self.netG(self.var_L) + self.netG.train() + + def get_current_log(self): + return self.log_dict + + def get_current_visuals(self, need_GT=True): + out_dict = OrderedDict() + out_dict['LQ'] = self.var_L.detach()[0].float().cpu() + out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() + if need_GT: + out_dict['GT'] = self.var_H.detach()[0].float().cpu() + return out_dict + + def print_network(self): + # Generator + s, n = self.get_network_description(self.netG) + if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, + self.netG.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netG.__class__.__name__) + if self.rank <= 0: + logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + if self.is_train: + # Discriminator + s, n = self.get_network_description(self.netD) + if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD, + DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netD.__class__.__name__, + self.netD.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netD.__class__.__name__) + if self.rank <= 0: + logger.info('Network D structure: {}, with parameters: {:,d}'.format( + net_struc_str, n)) + logger.info(s) + + if self.cri_fea: # F, Perceptual Network + s, n = self.get_network_description(self.netF) + if isinstance(self.netF, nn.DataParallel) or isinstance( + self.netF, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netF.__class__.__name__, + self.netF.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netF.__class__.__name__) + if self.rank <= 0: + logger.info('Network F structure: {}, with parameters: {:,d}'.format( + net_struc_str, n)) + logger.info(s) + + def load(self): + load_path_G = self.opt['path']['pretrain_model_G'] + if load_path_G is not None: + logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) + load_path_D = self.opt['path']['pretrain_model_D'] + if self.opt['is_train'] and load_path_D is not None: + logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) + self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) + + def save(self, iter_step): + self.save_network(self.netG, 'G', iter_step) + self.save_network(self.netD, 'D', iter_step) diff --git a/codes/models/SR_model.py b/codes/models/SR_model.py new file mode 100644 index 00000000..6782762a --- /dev/null +++ b/codes/models/SR_model.py @@ -0,0 +1,170 @@ +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from .base_model import BaseModel +from models.loss import CharbonnierLoss + 
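+# PSNR-oriented single-image SR: one generator (netG) optimized with a pixel-wise
+# L1 / L2 / Charbonnier loss; no discriminator or perceptual (VGG) loss is used.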
+logger = logging.getLogger('base') + + +class SRModel(BaseModel): + def __init__(self, opt): + super(SRModel, self).__init__(opt) + + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + train_opt = opt['train'] + + # define network and load pretrained models + self.netG = networks.define_G(opt).to(self.device) + if opt['dist']: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + else: + self.netG = DataParallel(self.netG) + # print network + self.print_network() + self.load() + + if self.is_train: + self.netG.train() + + # loss + loss_type = train_opt['pixel_criterion'] + if loss_type == 'l1': + self.cri_pix = nn.L1Loss().to(self.device) + elif loss_type == 'l2': + self.cri_pix = nn.MSELoss().to(self.device) + elif loss_type == 'cb': + self.cri_pix = CharbonnierLoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) + self.l_pix_w = train_opt['pixel_weight'] + + # optimizers + wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 + optim_params = [] + for k, v in self.netG.named_parameters(): # can optimize for a part of the model + if v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + weight_decay=wd_G, + betas=(train_opt['beta1'], train_opt['beta2'])) + self.optimizers.append(self.optimizer_G) + + # schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'])) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + raise NotImplementedError('MultiStepLR learning rate scheme is enough.') + + self.log_dict = OrderedDict() + + def feed_data(self, data, need_GT=True): + self.var_L = data['LQ'].to(self.device) # LQ + if need_GT: + self.real_H = data['GT'].to(self.device) # GT + + def optimize_parameters(self, step): + self.optimizer_G.zero_grad() + self.fake_H = self.netG(self.var_L) + l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) + l_pix.backward() + self.optimizer_G.step() + + # set log + self.log_dict['l_pix'] = l_pix.item() + + def test(self): + self.netG.eval() + with torch.no_grad(): + self.fake_H = self.netG(self.var_L) + self.netG.train() + + def test_x8(self): + # from https://github.com/thstkdgus35/EDSR-PyTorch + self.netG.eval() + + def _transform(v, op): + # if self.precision != 'single': v = v.float() + v2np = v.data.cpu().numpy() + if op == 'v': + tfnp = v2np[:, :, :, ::-1].copy() + elif op == 'h': + tfnp = v2np[:, :, ::-1, :].copy() + elif op == 't': + tfnp = v2np.transpose((0, 1, 3, 2)).copy() + + ret = torch.Tensor(tfnp).to(self.device) + # if self.precision == 'half': ret = ret.half() + + return ret + + lr_list = [self.var_L] + for tf in 'v', 'h', 't': + lr_list.extend([_transform(t, tf) for t in lr_list]) + with torch.no_grad(): + sr_list = [self.netG(aug) for aug in lr_list] + for i in range(len(sr_list)): + 
if i > 3: + sr_list[i] = _transform(sr_list[i], 't') + if i % 4 > 1: + sr_list[i] = _transform(sr_list[i], 'h') + if (i % 4) % 2 == 1: + sr_list[i] = _transform(sr_list[i], 'v') + + output_cat = torch.cat(sr_list, dim=0) + self.fake_H = output_cat.mean(dim=0, keepdim=True) + self.netG.train() + + def get_current_log(self): + return self.log_dict + + def get_current_visuals(self, need_GT=True): + out_dict = OrderedDict() + out_dict['LQ'] = self.var_L.detach()[0].float().cpu() + out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() + if need_GT: + out_dict['GT'] = self.real_H.detach()[0].float().cpu() + return out_dict + + def print_network(self): + s, n = self.get_network_description(self.netG) + if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, + self.netG.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netG.__class__.__name__) + if self.rank <= 0: + logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + + def load(self): + load_path_G = self.opt['path']['pretrain_model_G'] + if load_path_G is not None: + logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) + + def save(self, iter_label): + self.save_network(self.netG, 'G', iter_label) diff --git a/codes/models/Video_base_model.py b/codes/models/Video_base_model.py new file mode 100644 index 00000000..eb85fc5c --- /dev/null +++ b/codes/models/Video_base_model.py @@ -0,0 +1,166 @@ +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from .base_model import BaseModel +from models.loss import CharbonnierLoss + +logger = logging.getLogger('base') + + +class VideoBaseModel(BaseModel): + def __init__(self, opt): + super(VideoBaseModel, self).__init__(opt) + + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + train_opt = opt['train'] + + # define network and load pretrained models + self.netG = networks.define_G(opt).to(self.device) + if opt['dist']: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + else: + self.netG = DataParallel(self.netG) + # print network + self.print_network() + self.load() + + if self.is_train: + self.netG.train() + + #### loss + loss_type = train_opt['pixel_criterion'] + if loss_type == 'l1': + self.cri_pix = nn.L1Loss(reduction='sum').to(self.device) + elif loss_type == 'l2': + self.cri_pix = nn.MSELoss(reduction='sum').to(self.device) + elif loss_type == 'cb': + self.cri_pix = CharbonnierLoss().to(self.device) + else: + raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) + self.l_pix_w = train_opt['pixel_weight'] + + #### optimizers + wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 + if train_opt['ft_tsa_only']: + normal_params = [] + tsa_fusion_params = [] + for k, v in self.netG.named_parameters(): + if v.requires_grad: + if 'tsa_fusion' in k: + tsa_fusion_params.append(v) + else: + normal_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + optim_params = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['lr_G'] + }, + { + 
'params': tsa_fusion_params, + 'lr': train_opt['lr_G'] + }, + ] + else: + optim_params = [] + for k, v in self.netG.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + weight_decay=wd_G, + betas=(train_opt['beta1'], train_opt['beta2'])) + self.optimizers.append(self.optimizer_G) + + #### schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'])) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + raise NotImplementedError() + + self.log_dict = OrderedDict() + + def feed_data(self, data, need_GT=True): + self.var_L = data['LQs'].to(self.device) + if need_GT: + self.real_H = data['GT'].to(self.device) + + def set_params_lr_zero(self): + # fix normal module + self.optimizers[0].param_groups[0]['lr'] = 0 + + def optimize_parameters(self, step): + if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']: + self.set_params_lr_zero() + + self.optimizer_G.zero_grad() + self.fake_H = self.netG(self.var_L) + + l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) + l_pix.backward() + self.optimizer_G.step() + + # set log + self.log_dict['l_pix'] = l_pix.item() + + def test(self): + self.netG.eval() + with torch.no_grad(): + self.fake_H = self.netG(self.var_L) + self.netG.train() + + def get_current_log(self): + return self.log_dict + + def get_current_visuals(self, need_GT=True): + out_dict = OrderedDict() + out_dict['LQ'] = self.var_L.detach()[0].float().cpu() + out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() + if need_GT: + out_dict['GT'] = self.real_H.detach()[0].float().cpu() + return out_dict + + def print_network(self): + s, n = self.get_network_description(self.netG) + if isinstance(self.netG, nn.DataParallel): + net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, + self.netG.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netG.__class__.__name__) + if self.rank <= 0: + logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + + def load(self): + load_path_G = self.opt['path']['pretrain_model_G'] + if load_path_G is not None: + logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) + + def save(self, iter_label): + self.save_network(self.netG, 'G', iter_label) diff --git a/codes/models/__init__.py b/codes/models/__init__.py new file mode 100644 index 00000000..c95004c9 --- /dev/null +++ b/codes/models/__init__.py @@ -0,0 +1,19 @@ +import logging +logger = logging.getLogger('base') + + +def create_model(opt): + model = opt['model'] + # image restoration + if model == 'sr': # PSNR-oriented super resolution + from .SR_model import SRModel as M + elif model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN + from .SRGAN_model import SRGANModel as M + # video restoration + 
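+    # VideoBaseModel handles multi-frame (video) SR training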
elif model == 'video_base': + from .Video_base_model import VideoBaseModel as M + else: + raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) + m = M(opt) + logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) + return m diff --git a/codes/models/archs/DUF_arch.py b/codes/models/archs/DUF_arch.py new file mode 100644 index 00000000..319cea87 --- /dev/null +++ b/codes/models/archs/DUF_arch.py @@ -0,0 +1,368 @@ +'''Network architecture for DUF: +Deep Video Super-Resolution Network Using Dynamic Upsampling Filters +Without Explicit Motion Compensation (CVPR18) +https://github.com/yhjo09/VSR-DUF + +For all the models below, [adapt_official] is only necessary when +loading the weights converted from the official TensorFlow weights. +Please set it to [False] if you are training the model from scratch. +''' + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def adapt_official(Rx, scale=4): + '''Adapt the weights translated from the official tensorflow weights + Not necessary if you are training from scratch''' + x = Rx.clone() + x1 = x[:, ::3, :, :] + x2 = x[:, 1::3, :, :] + x3 = x[:, 2::3, :, :] + + Rx[:, :scale**2, :, :] = x1 + Rx[:, scale**2:2 * (scale**2), :, :] = x2 + Rx[:, 2 * (scale**2):, :, :] = x3 + + return Rx + + +class DenseBlock(nn.Module): + '''Dense block + for the second denseblock, t_reduced = True''' + + def __init__(self, nf=64, ng=32, t_reduce=False): + super(DenseBlock, self).__init__() + self.t_reduce = t_reduce + if self.t_reduce: + pad = (0, 1, 1) + else: + pad = (1, 1, 1) + self.bn3d_1 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3) + self.conv3d_1 = nn.Conv3d(nf, nf, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True) + self.bn3d_2 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3) + self.conv3d_2 = nn.Conv3d(nf, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True) + self.bn3d_3 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3) + self.conv3d_3 = nn.Conv3d(nf + ng, nf + ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), + bias=True) + self.bn3d_4 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3) + self.conv3d_4 = nn.Conv3d(nf + ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True) + self.bn3d_5 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3) + self.conv3d_5 = nn.Conv3d(nf + 2 * ng, nf + 2 * ng, (1, 1, 1), stride=(1, 1, 1), + padding=(0, 0, 0), bias=True) + self.bn3d_6 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3) + self.conv3d_6 = nn.Conv3d(nf + 2 * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, + bias=True) + + def forward(self, x): + '''x: [B, C, T, H, W] + C: nf -> nf + 3 * ng + T: 1) 7 -> 7 (t_reduce=False); + 2) 7 -> 7 - 2 * 3 = 1 (t_reduce=True)''' + x1 = self.conv3d_1(F.relu(self.bn3d_1(x), inplace=True)) + x1 = self.conv3d_2(F.relu(self.bn3d_2(x1), inplace=True)) + if self.t_reduce: + x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1) + else: + x1 = torch.cat((x, x1), 1) + + x2 = self.conv3d_3(F.relu(self.bn3d_3(x1), inplace=True)) + x2 = self.conv3d_4(F.relu(self.bn3d_4(x2), inplace=True)) + if self.t_reduce: + x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1) + else: + x2 = torch.cat((x1, x2), 1) + + x3 = self.conv3d_5(F.relu(self.bn3d_5(x2), inplace=True)) + x3 = self.conv3d_6(F.relu(self.bn3d_6(x3), inplace=True)) + if self.t_reduce: + x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1) + else: + x3 = torch.cat((x2, x3), 1) + return x3 + + +class DynamicUpsamplingFilter_3C(nn.Module): + '''dynamic upsampling filter with 3 channels applying the same filters + 
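+    (implemented as an im2col-style grouped convolution followed by a per-pixel matmul with the predicted filters)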
filter_size: filter size of the generated filters, shape (C, kH, kW)''' + + def __init__(self, filter_size=(1, 5, 5)): + super(DynamicUpsamplingFilter_3C, self).__init__() + # generate a local expansion filter, used similar to im2col + nF = np.prod(filter_size) + expand_filter_np = np.reshape(np.eye(nF, nF), + (nF, filter_size[0], filter_size[1], filter_size[2])) + expand_filter = torch.from_numpy(expand_filter_np).float() + self.expand_filter = torch.cat((expand_filter, expand_filter, expand_filter), + 0) # [75, 1, 5, 5] + + def forward(self, x, filters): + '''x: input image, [B, 3, H, W] + filters: generate dynamic filters, [B, F, R, H, W], e.g., [B, 25, 16, H, W] + F: prod of filter kernel size, e.g., 5*5 = 25 + R: used for upsampling, similar to pixel shuffle, e.g., 4*4 = 16 for x4 + Return: filtered image, [B, 3*R, H, W] + ''' + B, nF, R, H, W = filters.size() + # using group convolution + input_expand = F.conv2d(x, self.expand_filter.type_as(x), padding=2, + groups=3) # [B, 75, H, W] similar to im2col + input_expand = input_expand.view(B, 3, nF, H, W).permute(0, 3, 4, 1, 2) # [B, H, W, 3, 25] + filters = filters.permute(0, 3, 4, 1, 2) # [B, H, W, 25, 16] + out = torch.matmul(input_expand, filters) # [B, H, W, 3, 16] + return out.permute(0, 3, 4, 1, 2).view(B, 3 * R, H, W) # [B, 3*16, H, W] + + +class DUF_16L(nn.Module): + '''Official DUF structure with 16 layers''' + + def __init__(self, scale=4, adapt_official=False): + super(DUF_16L, self).__init__() + self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True) + self.dense_block_1 = DenseBlock(64, 64 // 2, t_reduce=False) # 64 + 32 * 3 = 160, T = 7 + self.dense_block_2 = DenseBlock(160, 64 // 2, t_reduce=True) # 160 + 32 * 3 = 256, T = 1 + self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3) + self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), + bias=True) + + self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), + bias=True) + self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), + padding=(0, 0, 0), bias=True) + + self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), + bias=True) + self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), + padding=(0, 0, 0), bias=True) + + self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5)) + + self.scale = scale + self.adapt_official = adapt_official + + def forward(self, x): + ''' + x: [B, T, C, H, W], T = 7. 
reshape to [B, C, T, H, W] for Conv3D + Generate filters and image residual: + Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C + Rx: [B, 3*16, 1, H, W] + ''' + B, T, C, H, W = x.size() + x = x.permute(0, 2, 1, 3, 4) # [B, C, T, H, W] for Conv3D + x_center = x[:, :, T // 2, :, :] + + x = self.conv3d_1(x) + x = self.dense_block_1(x) + x = self.dense_block_2(x) # reduce T to 1 + x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True) + + # image residual + Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W] + + # filter + Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W] + Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1) + + # Adapt to official model weights + if self.adapt_official: + adapt_official(Rx, scale=self.scale) + + # dynamic filter + out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W] + out += Rx.squeeze_(2) + out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W] + + return out + + +class DenseBlock_28L(nn.Module): + '''The first part of the dense blocks used in DUF_28L + Temporal dimension remains the same here''' + + def __init__(self, nf=64, ng=16): + super(DenseBlock_28L, self).__init__() + pad = (1, 1, 1) + + dense_block_l = [] + for i in range(0, 9): + dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3)) + dense_block_l.append(nn.ReLU()) + dense_block_l.append( + nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), + bias=True)) + + dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3)) + dense_block_l.append(nn.ReLU()) + dense_block_l.append( + nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)) + + self.dense_blocks = nn.ModuleList(dense_block_l) + + def forward(self, x): + '''x: [B, C, T, H, W] + C: 1) 64 -> 208; + T: 1) 7 -> 7; (t_reduce=True)''' + for i in range(0, len(self.dense_blocks), 6): + y = x + for j in range(6): + y = self.dense_blocks[i + j](y) + x = torch.cat((x, y), 1) + return x + + +class DUF_28L(nn.Module): + '''Official DUF structure with 28 layers''' + + def __init__(self, scale=4, adapt_official=False): + super(DUF_28L, self).__init__() + self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True) + self.dense_block_1 = DenseBlock_28L(64, 16) # 64 + 16 * 9 = 208, T = 7 + self.dense_block_2 = DenseBlock(208, 16, t_reduce=True) # 208 + 16 * 3 = 256, T = 1 + self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3) + self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), + bias=True) + + self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), + bias=True) + self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), + padding=(0, 0, 0), bias=True) + + self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), + bias=True) + self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), + padding=(0, 0, 0), bias=True) + + self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5)) + + self.scale = scale + self.adapt_official = adapt_official + + def forward(self, x): + ''' + x: [B, T, C, H, W], T = 7. 
reshape to [B, C, T, H, W] for Conv3D + Generate filters and image residual: + Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C + Rx: [B, 3*16, 1, H, W] + ''' + B, T, C, H, W = x.size() + x = x.permute(0, 2, 1, 3, 4) # [B,C,T,H,W] for Conv3D + x_center = x[:, :, T // 2, :, :] + x = self.conv3d_1(x) + x = self.dense_block_1(x) + x = self.dense_block_2(x) # reduce T to 1 + x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True) + + # image residual + Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W] + + # filter + Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W] + Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1) + + # Adapt to official model weights + if self.adapt_official: + adapt_official(Rx, scale=self.scale) + + # dynamic filter + out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W] + out += Rx.squeeze_(2) + out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W] + return out + + +class DenseBlock_52L(nn.Module): + '''The first part of the dense blocks used in DUF_52L + Temporal dimension remains the same here''' + + def __init__(self, nf=64, ng=16): + super(DenseBlock_52L, self).__init__() + pad = (1, 1, 1) + + dense_block_l = [] + for i in range(0, 21): + dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3)) + dense_block_l.append(nn.ReLU()) + dense_block_l.append( + nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), + bias=True)) + + dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3)) + dense_block_l.append(nn.ReLU()) + dense_block_l.append( + nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)) + + self.dense_blocks = nn.ModuleList(dense_block_l) + + def forward(self, x): + '''x: [B, C, T, H, W] + C: 1) 64 -> 400; + T: 1) 7 -> 7; (t_reduce=True)''' + for i in range(0, len(self.dense_blocks), 6): + y = x + for j in range(6): + y = self.dense_blocks[i + j](y) + x = torch.cat((x, y), 1) + return x + + +class DUF_52L(nn.Module): + '''Official DUF structure with 52 layers''' + + def __init__(self, scale=4, adapt_official=False): + super(DUF_52L, self).__init__() + self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True) + self.dense_block_1 = DenseBlock_52L(64, 16) # 64 + 21 * 9 = 400, T = 7 + self.dense_block_2 = DenseBlock(400, 16, t_reduce=True) # 400 + 16 * 3 = 448, T = 1 + + self.bn3d_2 = nn.BatchNorm3d(448, eps=1e-3, momentum=1e-3) + self.conv3d_2 = nn.Conv3d(448, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), + bias=True) + + self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), + bias=True) + self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), + padding=(0, 0, 0), bias=True) + + self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), + bias=True) + self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), + padding=(0, 0, 0), bias=True) + + self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5)) + + self.scale = scale + self.adapt_official = adapt_official + + def forward(self, x): + ''' + x: [B, T, C, H, W], T = 7. 
reshape to [B, C, T, H, W] for Conv3D + Generate filters and image residual: + Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C + Rx: [B, 3*16, 1, H, W] + ''' + B, T, C, H, W = x.size() + x = x.permute(0, 2, 1, 3, 4) # [B,C,T,H,W] for Conv3D + x_center = x[:, :, T // 2, :, :] + x = self.conv3d_1(x) + x = self.dense_block_1(x) + x = self.dense_block_2(x) + x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True) + + # image residual + Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W] + + # filter + Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W] + Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1) + + # Adapt to official model weights + if self.adapt_official: + adapt_official(Rx, scale=self.scale) + + # dynamic filter + out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W] + out += Rx.squeeze_(2) + out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W] + return out diff --git a/codes/models/archs/EDVR_arch.py b/codes/models/archs/EDVR_arch.py new file mode 100644 index 00000000..df9c0325 --- /dev/null +++ b/codes/models/archs/EDVR_arch.py @@ -0,0 +1,312 @@ +''' network architecture for EDVR ''' +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import models.archs.arch_util as arch_util +try: + from models.archs.dcn.deform_conv import ModulatedDeformConvPack as DCN +except ImportError: + raise ImportError('Failed to import DCNv2 module.') + + +class Predeblur_ResNet_Pyramid(nn.Module): + def __init__(self, nf=128, HR_in=False): + ''' + HR_in: True if the inputs are high spatial size + ''' + + super(Predeblur_ResNet_Pyramid, self).__init__() + self.HR_in = True if HR_in else False + if self.HR_in: + self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True) + self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) + self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) + else: + self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True) + basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) + self.RB_L1_1 = basic_block() + self.RB_L1_2 = basic_block() + self.RB_L1_3 = basic_block() + self.RB_L1_4 = basic_block() + self.RB_L1_5 = basic_block() + self.RB_L2_1 = basic_block() + self.RB_L2_2 = basic_block() + self.RB_L3_1 = basic_block() + self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) + self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, x): + if self.HR_in: + L1_fea = self.lrelu(self.conv_first_1(x)) + L1_fea = self.lrelu(self.conv_first_2(L1_fea)) + L1_fea = self.lrelu(self.conv_first_3(L1_fea)) + else: + L1_fea = self.lrelu(self.conv_first(x)) + L2_fea = self.lrelu(self.deblur_L2_conv(L1_fea)) + L3_fea = self.lrelu(self.deblur_L3_conv(L2_fea)) + L3_fea = F.interpolate(self.RB_L3_1(L3_fea), scale_factor=2, mode='bilinear', + align_corners=False) + L2_fea = self.RB_L2_1(L2_fea) + L3_fea + L2_fea = F.interpolate(self.RB_L2_2(L2_fea), scale_factor=2, mode='bilinear', + align_corners=False) + L1_fea = self.RB_L1_2(self.RB_L1_1(L1_fea)) + L2_fea + out = self.RB_L1_5(self.RB_L1_4(self.RB_L1_3(L1_fea))) + return out + + +class PCD_Align(nn.Module): + ''' Alignment module using Pyramid, Cascading and Deformable convolution + with 3 pyramid levels. 
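+    Offsets are estimated coarse-to-fine: each level reuses the x2-upsampled (and x2-scaled) offset and the upsampled feature from the level below, followed by a cascading refinement at L1.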
+ ''' + + def __init__(self, nf=64, groups=8): + super(PCD_Align, self).__init__() + # L3: level 3, 1/4 spatial size + self.L3_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff + self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, + extra_offset_mask=True) + # L2: level 2, 1/2 spatial size + self.L2_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff + self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset + self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, + extra_offset_mask=True) + self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea + # L1: level 1, original spatial size + self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff + self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset + self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, + extra_offset_mask=True) + self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea + # Cascading DCN + self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff + self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, + extra_offset_mask=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, nbr_fea_l, ref_fea_l): + '''align other neighboring frames to the reference frame in the feature level + nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features + ''' + # L3 + L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1) + L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset)) + L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset)) + L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset])) + # L2 + L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1) + L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset)) + L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False) + L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1))) + L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset)) + L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset]) + L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False) + L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1))) + # L1 + L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1) + L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset)) + L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False) + L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1))) + L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset)) + L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset]) + L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False) + L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1)) + # Cascading + offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1) + offset = self.lrelu(self.cas_offset_conv1(offset)) + offset = self.lrelu(self.cas_offset_conv2(offset)) + L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset])) + + return L1_fea + + +class 
TSA_Fusion(nn.Module): + ''' Temporal Spatial Attention fusion module + Temporal: correlation; + Spatial: 3 pyramid levels. + ''' + + def __init__(self, nf=64, nframes=5, center=2): + super(TSA_Fusion, self).__init__() + self.center = center + # temporal attention (before fusion conv) + self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + # fusion conv: using 1x1 to save parameters and computation + self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True) + + # spatial attention (after fusion conv) + self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + self.avgpool = nn.AvgPool2d(3, stride=2, padding=1) + self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True) + self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True) + self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True) + self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) + self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True) + self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, aligned_fea): + B, N, C, H, W = aligned_fea.size() # N video frames + #### temporal attention + emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone()) + emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W) # [B, N, C(nf), H, W] + + cor_l = [] + for i in range(N): + emb_nbr = emb[:, i, :, :, :] + cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1) # B, 1, H, W + cor_l.append(cor_tmp) + cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1)) # B, N, H, W + cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W) + aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob + + #### fusion + fea = self.lrelu(self.fea_fusion(aligned_fea)) + + #### spatial attention + att = self.lrelu(self.sAtt_1(aligned_fea)) + att_max = self.maxpool(att) + att_avg = self.avgpool(att) + att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1))) + # pyramid levels + att_L = self.lrelu(self.sAtt_L1(att)) + att_max = self.maxpool(att_L) + att_avg = self.avgpool(att_L) + att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1))) + att_L = self.lrelu(self.sAtt_L3(att_L)) + att_L = F.interpolate(att_L, scale_factor=2, mode='bilinear', align_corners=False) + + att = self.lrelu(self.sAtt_3(att)) + att = att + att_L + att = self.lrelu(self.sAtt_4(att)) + att = F.interpolate(att, scale_factor=2, mode='bilinear', align_corners=False) + att = self.sAtt_5(att) + att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att))) + att = torch.sigmoid(att) + + fea = fea * att * 2 + att_add + return fea + + +class EDVR(nn.Module): + def __init__(self, nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10, center=None, + predeblur=False, HR_in=False, w_TSA=True): + super(EDVR, self).__init__() + self.nf = nf + self.center = nframes // 2 if center is None else center + self.is_predeblur = True if predeblur else False + self.HR_in = True if HR_in else False + self.w_TSA = w_TSA + ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) + + #### extract features (for each frame) + if self.is_predeblur: + self.pre_deblur = Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in) + self.conv_1x1 = nn.Conv2d(nf, nf, 1, 1, bias=True) + else: + if self.HR_in: + self.conv_first_1 = 
nn.Conv2d(3, nf, 3, 1, 1, bias=True) + self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) + self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) + else: + self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True) + self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, front_RBs) + self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) + self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) + self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + self.pcd_align = PCD_Align(nf=nf, groups=groups) + if self.w_TSA: + self.tsa_fusion = TSA_Fusion(nf=nf, nframes=nframes, center=self.center) + else: + self.tsa_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True) + + #### reconstruction + self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, back_RBs) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True) + self.pixel_shuffle = nn.PixelShuffle(2) + self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True) + + #### activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, x): + B, N, C, H, W = x.size() # N video frames + x_center = x[:, self.center, :, :, :].contiguous() + + #### extract LR features + # L1 + if self.is_predeblur: + L1_fea = self.pre_deblur(x.view(-1, C, H, W)) + L1_fea = self.conv_1x1(L1_fea) + if self.HR_in: + H, W = H // 4, W // 4 + else: + if self.HR_in: + L1_fea = self.lrelu(self.conv_first_1(x.view(-1, C, H, W))) + L1_fea = self.lrelu(self.conv_first_2(L1_fea)) + L1_fea = self.lrelu(self.conv_first_3(L1_fea)) + H, W = H // 4, W // 4 + else: + L1_fea = self.lrelu(self.conv_first(x.view(-1, C, H, W))) + L1_fea = self.feature_extraction(L1_fea) + # L2 + L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea)) + L2_fea = self.lrelu(self.fea_L2_conv2(L2_fea)) + # L3 + L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea)) + L3_fea = self.lrelu(self.fea_L3_conv2(L3_fea)) + + L1_fea = L1_fea.view(B, N, -1, H, W) + L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2) + L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4) + + #### pcd align + # ref feature list + ref_fea_l = [ + L1_fea[:, self.center, :, :, :].clone(), L2_fea[:, self.center, :, :, :].clone(), + L3_fea[:, self.center, :, :, :].clone() + ] + aligned_fea = [] + for i in range(N): + nbr_fea_l = [ + L1_fea[:, i, :, :, :].clone(), L2_fea[:, i, :, :, :].clone(), + L3_fea[:, i, :, :, :].clone() + ] + aligned_fea.append(self.pcd_align(nbr_fea_l, ref_fea_l)) + aligned_fea = torch.stack(aligned_fea, dim=1) # [B, N, C, H, W] + + if not self.w_TSA: + aligned_fea = aligned_fea.view(B, -1, H, W) + fea = self.tsa_fusion(aligned_fea) + + out = self.recon_trunk(fea) + out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) + out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) + out = self.lrelu(self.HRconv(out)) + out = self.conv_last(out) + if self.HR_in: + base = x_center + else: + base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False) + out += base + return out diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py new file mode 100644 index 00000000..9d61256c --- /dev/null +++ b/codes/models/archs/RRDBNet_arch.py @@ -0,0 +1,73 @@ +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import models.archs.arch_util as arch_util + + +class ResidualDenseBlock_5C(nn.Module): + def 
__init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + arch_util.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], + 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class RRDBNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out diff --git a/codes/models/archs/SRResNet_arch.py b/codes/models/archs/SRResNet_arch.py new file mode 100644 index 00000000..6e622ac3 --- /dev/null +++ b/codes/models/archs/SRResNet_arch.py @@ -0,0 +1,55 @@ +import functools +import torch.nn as nn +import torch.nn.functional as F +import models.archs.arch_util as arch_util + + +class MSRResNet(nn.Module): + ''' modified SRResNet''' + + def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): + super(MSRResNet, self).__init__() + self.upscale = upscale + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) + self.recon_trunk = arch_util.make_layer(basic_block, nb) + + # upsampling + if self.upscale == 2: + self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) + self.pixel_shuffle = nn.PixelShuffle(2) + elif self.upscale == 3: + self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) + self.pixel_shuffle = nn.PixelShuffle(3) + elif self.upscale == 4: + self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) + self.pixel_shuffle = nn.PixelShuffle(2) + + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + 
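+        # map the nf feature channels back to the output image channels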
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + # initialization + arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last], + 0.1) + if self.upscale == 4: + arch_util.initialize_weights(self.upconv2, 0.1) + + def forward(self, x): + fea = self.lrelu(self.conv_first(x)) + out = self.recon_trunk(fea) + + if self.upscale == 4: + out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) + out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) + elif self.upscale == 3 or self.upscale == 2: + out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) + + out = self.conv_last(self.lrelu(self.HRconv(out))) + base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) + out += base + return out diff --git a/codes/models/archs/TOF_arch.py b/codes/models/archs/TOF_arch.py new file mode 100755 index 00000000..02d7a914 --- /dev/null +++ b/codes/models/archs/TOF_arch.py @@ -0,0 +1,137 @@ +'''PyTorch implementation of TOFlow +Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018 +Code reference: +1. https://github.com/anchen1011/toflow +2. https://github.com/Coldog2333/pytoflow +''' + +import torch +import torch.nn as nn +from .arch_util import flow_warp + + +def normalize(x): + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x) + return (x - mean) / std + + +def denormalize(x): + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x) + return x * std + mean + + +class SpyNet_Block(nn.Module): + '''A submodule of SpyNet.''' + + def __init__(self): + super(SpyNet_Block, self).__init__() + + self.block = nn.Sequential( + nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), + nn.BatchNorm2d(32), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), + nn.BatchNorm2d(64), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), + nn.BatchNorm2d(32), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), + nn.BatchNorm2d(16), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) + + def forward(self, x): + ''' + input: x: [ref im, nbr im, initial flow] - (B, 8, H, W) + output: estimated flow - (B, 2, H, W) + ''' + return self.block(x) + + +class SpyNet(nn.Module): + '''SpyNet for estimating optical flow + Ranjan et al., Optical Flow Estimation using a Spatial Pyramid Network, 2016''' + + def __init__(self): + super(SpyNet, self).__init__() + + self.blocks = nn.ModuleList([SpyNet_Block() for _ in range(4)]) + + def forward(self, ref, nbr): + '''Estimating optical flow in coarse level, upsample, and estimate in fine level + input: ref: reference image - [B, 3, H, W] + nbr: the neighboring image to be warped - [B, 3, H, W] + output: estimated optical flow - [B, 2, H, W] + ''' + B, C, H, W = ref.size() + ref = [ref] + nbr = [nbr] + + for _ in range(3): + ref.insert( + 0, + nn.functional.avg_pool2d(input=ref[0], kernel_size=2, stride=2, + count_include_pad=False)) + nbr.insert( + 0, + nn.functional.avg_pool2d(input=nbr[0], kernel_size=2, stride=2, + count_include_pad=False)) + + flow = torch.zeros(B, 2, H // 16, W // 16).type_as(ref[0]) + + 
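+        # coarse-to-fine: at each pyramid level, upsample the flow (doubling its magnitude) and add a residual estimate from the corresponding block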
for i in range(4): + flow_up = nn.functional.interpolate(input=flow, scale_factor=2, mode='bilinear', + align_corners=True) * 2.0 + flow = flow_up + self.blocks[i](torch.cat( + [ref[i], flow_warp(nbr[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1)) + return flow + + +class TOFlow(nn.Module): + def __init__(self, adapt_official=False): + super(TOFlow, self).__init__() + + self.SpyNet = SpyNet() + + self.conv_3x7_64_9x9 = nn.Conv2d(3 * 7, 64, 9, 1, 4) + self.conv_64_64_9x9 = nn.Conv2d(64, 64, 9, 1, 4) + self.conv_64_64_1x1 = nn.Conv2d(64, 64, 1) + self.conv_64_3_1x1 = nn.Conv2d(64, 3, 1) + + self.relu = nn.ReLU(inplace=True) + + self.adapt_official = adapt_official # True if using translated official weights else False + + def forward(self, x): + """ + input: x: input frames - [B, 7, 3, H, W] + output: SR reference frame - [B, 3, H, W] + """ + + B, T, C, H, W = x.size() + x = normalize(x.view(-1, C, H, W)).view(B, T, C, H, W) + + ref_idx = 3 + x_ref = x[:, ref_idx, :, :, :] + + # In the official torch code, the 0-th frame is the reference frame + if self.adapt_official: + x = x[:, [3, 0, 1, 2, 4, 5, 6], :, :, :] + ref_idx = 0 + + x_warped = [] + for i in range(7): + if i == ref_idx: + x_warped.append(x_ref) + else: + x_nbr = x[:, i, :, :, :] + flow = self.SpyNet(x_ref, x_nbr).permute(0, 2, 3, 1) + x_warped.append(flow_warp(x_nbr, flow)) + x_warped = torch.stack(x_warped, dim=1) + + x = x_warped.view(B, -1, H, W) + x = self.relu(self.conv_3x7_64_9x9(x)) + x = self.relu(self.conv_64_64_9x9(x)) + x = self.relu(self.conv_64_64_1x1(x)) + x = self.conv_64_3_1x1(x) + x_ref + + return denormalize(x) diff --git a/codes/models/archs/__init__.py b/codes/models/archs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py new file mode 100644 index 00000000..ca5d7fa9 --- /dev/null +++ b/codes/models/archs/arch_util.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +import torch.nn.init as init +import torch.nn.functional as F + + +def initialize_weights(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualBlock_noBN(nn.Module): + '''Residual block w/o BN + ---Conv-ReLU-Conv-+- + |________________| + ''' + + def __init__(self, nf=64): + super(ResidualBlock_noBN, self).__init__() + self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + # initialization + initialize_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = F.relu(self.conv1(x), inplace=True) + out = self.conv2(out) + return identity + out + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): + """Warp an image or feature map with optical flow + Args: + x (Tensor): size (N, C, H, W) + flow (Tensor): size (N, H, W, 2), normal value + interp_mode (str): 'nearest' or 'bilinear' + padding_mode (str): 'zeros' or 
'border' or 'reflection' + + Returns: + Tensor: warped image or feature map + """ + assert x.size()[-2:] == flow.size()[1:3] + B, C, H, W = x.size() + # mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + grid = grid.type_as(x) + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) + return output diff --git a/codes/models/archs/dcn/__init__.py b/codes/models/archs/dcn/__init__.py new file mode 100644 index 00000000..1c85e1f0 --- /dev/null +++ b/codes/models/archs/dcn/__init__.py @@ -0,0 +1,7 @@ +from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, + deform_conv, modulated_deform_conv) + +__all__ = [ + 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', + 'modulated_deform_conv' +] diff --git a/codes/models/archs/dcn/deform_conv.py b/codes/models/archs/dcn/deform_conv.py new file mode 100644 index 00000000..f97cb1c8 --- /dev/null +++ b/codes/models/archs/dcn/deform_conv.py @@ -0,0 +1,291 @@ +import math +import logging + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from . import deform_conv_cuda + +logger = logging.getLogger('base') + + +class DeformConvFunction(Function): + @staticmethod + def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1, + deformable_groups=1, im2col_step=64): + if input is not None and input.dim() != 4: + raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format( + input.dim())) + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty( + DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output, + ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], + ctx.padding[1], ctx.padding[0], + ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + deform_conv_cuda.deform_conv_backward_input_cuda( + input, offset, grad_output, grad_input, grad_offset, 
weight, ctx.bufs_[0], + weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + deform_conv_cuda.deform_conv_backward_parameters_cuda( + input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1], + weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, 1, cur_im2col_step) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError("convolution input is too small (output would be {})".format('x'.join( + map(str, output_size)))) + return output_size + + +class ModulatedDeformConvFunction(Function): + @staticmethod + def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1, + groups=1, deformable_groups=1): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad \ + or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + deform_conv_cuda.modulated_deform_conv_cuda_forward( + input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2], + weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, + ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + deform_conv_cuda.modulated_deform_conv_cuda_backward( + input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight, + grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3], + ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, + None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - (ctx.dilation * + (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * 
ctx.padding - (ctx.dilation * + (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply + + +class DeformConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, + groups=1, deformable_groups=1, bias=False): + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, \ + 'in_channels {} cannot be divisible by groups {}'.format( + in_channels, groups) + assert out_channels % groups == 0, \ + 'out_channels {} cannot be divisible by groups {}'.format( + out_channels, groups) + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, x, offset): + return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) + + +class DeformConvPack(DeformConv): + def __init__(self, *args, **kwargs): + super(DeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) + + +class ModulatedDeformConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, + groups=1, deformable_groups=1, bias=True): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. 
/ math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups, + self.deformable_groups) + + +class ModulatedDeformConvPack(ModulatedDeformConv): + def __init__(self, *args, extra_offset_mask=False, **kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) + + self.extra_offset_mask = extra_offset_mask + self.conv_offset_mask = nn.Conv2d( + self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset_mask.weight.data.zero_() + self.conv_offset_mask.bias.data.zero_() + + def forward(self, x): + if self.extra_offset_mask: + # x = [input, features] + out = self.conv_offset_mask(x[1]) + x = x[0] + else: + out = self.conv_offset_mask(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_mean = torch.mean(torch.abs(offset)) + if offset_mean > 100: + logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean)) + + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups, + self.deformable_groups) diff --git a/codes/models/archs/dcn/setup.py b/codes/models/archs/dcn/setup.py new file mode 100644 index 00000000..094d961f --- /dev/null +++ b/codes/models/archs/dcn/setup.py @@ -0,0 +1,22 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +def make_cuda_ext(name, sources): + + return CUDAExtension( + name='{}'.format(name), sources=[p for p in sources], extra_compile_args={ + 'cxx': [], + 'nvcc': [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + ] + }) + + +setup( + name='deform_conv', ext_modules=[ + make_cuda_ext(name='deform_conv_cuda', + sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']) + ], cmdclass={'build_ext': BuildExtension}, zip_safe=False) diff --git a/codes/models/archs/dcn/src/deform_conv_cuda.cpp b/codes/models/archs/dcn/src/deform_conv_cuda.cpp new file mode 100644 index 00000000..c4563ed8 --- /dev/null +++ b/codes/models/archs/dcn/src/deform_conv_cuda.cpp @@ -0,0 +1,695 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + 
const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) { + AT_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + AT_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + AT_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + AT_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + AT_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + AT_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + 
AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + AT_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + AT_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + AT_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + AT_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), 
output_buffer.size(3)}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, 
outputHeight, + outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * 
padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const 
int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w 
+ 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda, + "deform forward (CUDA)"); + m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda, + "deform_conv_backward_input (CUDA)"); + m.def("deform_conv_backward_parameters_cuda", + &deform_conv_backward_parameters_cuda, + "deform_conv_backward_parameters (CUDA)"); + m.def("modulated_deform_conv_cuda_forward", + 
&modulated_deform_conv_cuda_forward, + "modulated deform conv forward (CUDA)"); + m.def("modulated_deform_conv_cuda_backward", + &modulated_deform_conv_cuda_backward, + "modulated deform conv backward (CUDA)"); +} diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py new file mode 100644 index 00000000..27dd6a1f --- /dev/null +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import torchvision + + +class Discriminator_VGG_128(nn.Module): + def __init__(self, in_nc, nf): + super(Discriminator_VGG_128, self).__init__() + # [64, 128, 128] + self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) + self.bn0_1 = nn.BatchNorm2d(nf, affine=True) + # [64, 64, 64] + self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) + self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) + self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) + self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) + # [128, 32, 32] + self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) + self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) + self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) + self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) + # [256, 16, 16] + self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) + self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) + # [512, 8, 8] + self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) + self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) + + self.linear1 = nn.Linear(512 * 4 * 4, 100) + self.linear2 = nn.Linear(100, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.lrelu(self.conv0_0(x)) + fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) + + fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) + fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) + + fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) + fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) + + fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) + fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) + + fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) + fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) + + fea = fea.view(fea.size(0), -1) + fea = self.lrelu(self.linear1(fea)) + out = self.linear2(fea) + return out + + +class VGGFeatureExtractor(nn.Module): + def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, + device=torch.device('cpu')): + super(VGGFeatureExtractor, self).__init__() + self.use_input_norm = use_input_norm + if use_bn: + model = torchvision.models.vgg19_bn(pretrained=True) + else: + model = torchvision.models.vgg19(pretrained=True) + if self.use_input_norm: + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) + # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) + # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] + self.register_buffer('mean', mean) + self.register_buffer('std', std) + self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) + # No need to BP to variable + for k, v in self.features.named_parameters(): + v.requires_grad = False + + def forward(self, x): + # 
Assume input range is [0, 1] + if self.use_input_norm: + x = (x - self.mean) / self.std + output = self.features(x) + return output diff --git a/codes/models/base_model.py b/codes/models/base_model.py new file mode 100644 index 00000000..8a5d2225 --- /dev/null +++ b/codes/models/base_model.py @@ -0,0 +1,116 @@ +import os +from collections import OrderedDict +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel + + +class BaseModel(): + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def get_current_losses(self): + pass + + def print_network(self): + pass + + def save(self, label): + pass + + def load(self): + pass + + def _set_lr(self, lr_groups_l): + """Set learning rate for warmup + lr_groups_l: list for lr_groups. each for a optimizer""" + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler""" + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, cur_iter, warmup_iter=-1): + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if cur_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + + def get_network_description(self, network): + """Get the string and total parameters of the network""" + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + return str(network), sum(map(lambda x: x.numel(), network.parameters())) + + def save_network(self, network, network_label, iter_label): + save_filename = '{}_{}.pth'.format(iter_label, network_label) + save_path = os.path.join(self.opt['path']['models'], save_filename) + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + state_dict = network.state_dict() + for key, param in state_dict.items(): + state_dict[key] = param.cpu() + torch.save(state_dict, save_path) + + def load_network(self, load_path, network, strict=True): + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + load_net = torch.load(load_path) + load_net_clean = OrderedDict() # remove unnecessary 'module.' 
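+        # checkpoints saved from DataParallel / DistributedDataParallel wrappers prefix every key
+        # with 'module.'; strip it so the state dict matches the bare network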
+ for k, v in load_net.items(): + if k.startswith('module.'): + load_net_clean[k[7:]] = v + else: + load_net_clean[k] = v + network.load_state_dict(load_net_clean, strict=strict) + + def save_training_state(self, epoch, iter_step): + """Save training state during training, which will be used for resuming""" + state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + save_filename = '{}.state'.format(iter_step) + save_path = os.path.join(self.opt['path']['training_state'], save_filename) + torch.save(state, save_path) + + def resume_training(self, resume_state): + """Resume the optimizers and schedulers for training""" + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) diff --git a/codes/models/loss.py b/codes/models/loss.py new file mode 100644 index 00000000..c9f75f89 --- /dev/null +++ b/codes/models/loss.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn + + +class CharbonnierLoss(nn.Module): + """Charbonnier Loss (L1)""" + + def __init__(self, eps=1e-6): + super(CharbonnierLoss, self).__init__() + self.eps = eps + + def forward(self, x, y): + diff = x - y + loss = torch.sum(torch.sqrt(diff * diff + self.eps)) + return loss + + +# Define GAN loss: [vanilla | lsgan | wgan-gp] +class GANLoss(nn.Module): + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type.lower() + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'gan' or self.gan_type == 'ragan': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan-gp': + + def wgan_loss(input, target): + # target is boolean + return -1 * input.mean() if target else input.mean() + + self.loss = wgan_loss + else: + raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) + + def get_target_label(self, input, target_is_real): + if self.gan_type == 'wgan-gp': + return target_is_real + if target_is_real: + return torch.empty_like(input).fill_(self.real_label_val) + else: + return torch.empty_like(input).fill_(self.fake_label_val) + + def forward(self, input, target_is_real): + target_label = self.get_target_label(input, target_is_real) + loss = self.loss(input, target_label) + return loss + + +class GradientPenaltyLoss(nn.Module): + def __init__(self, device=torch.device('cpu')): + super(GradientPenaltyLoss, self).__init__() + self.register_buffer('grad_outputs', torch.Tensor()) + self.grad_outputs = self.grad_outputs.to(device) + + def get_grad_outputs(self, input): + if self.grad_outputs.size() != input.size(): + self.grad_outputs.resize_(input.size()).fill_(1.0) + return self.grad_outputs + + def forward(self, interp, interp_crit): + grad_outputs = self.get_grad_outputs(interp_crit) + grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, + grad_outputs=grad_outputs, create_graph=True, + retain_graph=True, only_inputs=True)[0] + grad_interp = 
grad_interp.view(grad_interp.size(0), -1) + grad_interp_norm = grad_interp.norm(2, dim=1) + + loss = ((grad_interp_norm - 1)**2).mean() + return loss diff --git a/codes/models/lr_scheduler.py b/codes/models/lr_scheduler.py new file mode 100644 index 00000000..be7a92f0 --- /dev/null +++ b/codes/models/lr_scheduler.py @@ -0,0 +1,144 @@ +import math +from collections import Counter +from collections import defaultdict +import torch +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepLR_Restart(_LRScheduler): + def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, + clear_state=False, last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.clear_state = clear_state + self.restarts = restarts if restarts else [0] + self.restarts = [v + 1 for v in self.restarts] + self.restart_weights = weights if weights else [1] + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + if self.clear_state: + self.optimizer.state = defaultdict(dict) + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [ + group['lr'] * self.gamma**self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + +class CosineAnnealingLR_Restart(_LRScheduler): + def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): + self.T_period = T_period + self.T_max = self.T_period[0] # current T period + self.eta_min = eta_min + self.restarts = restarts if restarts else [0] + self.restarts = [v + 1 for v in self.restarts] + self.restart_weights = weights if weights else [1] + self.last_restart = 0 + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' 
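+        # each entry of T_period is the length (in iterations) of one cosine cycle; at every index
+        # listed in restarts the LR is reset to initial_lr scaled by the matching restart weight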
+ super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch == 0: + return self.base_lrs + elif self.last_epoch in self.restarts: + self.last_restart = self.last_epoch + self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / + (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * + (group['lr'] - self.eta_min) + self.eta_min + for group in self.optimizer.param_groups] + + +if __name__ == "__main__": + optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, + betas=(0.9, 0.99)) + ############################## + # MultiStepLR_Restart + ############################## + ## Original + lr_steps = [200000, 400000, 600000, 800000] + restarts = None + restart_weights = None + + ## two + lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] + restarts = [500000] + restart_weights = [1] + + ## four + lr_steps = [ + 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, + 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 + ] + restarts = [250000, 500000, 750000] + restart_weights = [1, 1, 1] + + scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, + clear_state=False) + + ############################## + # Cosine Annealing Restart + ############################## + ## two + T_period = [500000, 500000] + restarts = [500000] + restart_weights = [1] + + ## four + T_period = [250000, 250000, 250000, 250000] + restarts = [250000, 500000, 750000] + restart_weights = [1, 1, 1] + + scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, + weights=restart_weights) + + ############################## + # Draw figure + ############################## + N_iter = 1000000 + lr_l = list(range(N_iter)) + for i in range(N_iter): + scheduler.step() + current_lr = optimizer.param_groups[0]['lr'] + lr_l[i] = current_lr + + import matplotlib as mpl + from matplotlib import pyplot as plt + import matplotlib.ticker as mtick + mpl.style.use('default') + import seaborn + seaborn.set(style='whitegrid') + seaborn.set_context('paper') + + plt.figure(1) + plt.subplot(111) + plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) + plt.title('Title', fontsize=16, color='k') + plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') + legend = plt.legend(loc='upper right', shadow=False) + ax = plt.gca() + labels = ax.get_xticks().tolist() + for k, v in enumerate(labels): + labels[k] = str(int(v / 1000)) + 'K' + ax.set_xticklabels(labels) + ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) + + ax.set_ylabel('Learning rate') + ax.set_xlabel('Iteration') + fig = plt.gcf() + plt.show() diff --git a/codes/models/networks.py b/codes/models/networks.py new file mode 100644 index 00000000..2a679134 --- /dev/null +++ b/codes/models/networks.py @@ -0,0 +1,57 @@ +import torch +import models.archs.SRResNet_arch as SRResNet_arch 
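+# discriminator_vgg_arch also provides the VGG feature extractor used for the perceptual loss, hence the SRGAN_arch alias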
+import models.archs.discriminator_vgg_arch as SRGAN_arch +import models.archs.RRDBNet_arch as RRDBNet_arch +import models.archs.EDVR_arch as EDVR_arch + + +# Generator +def define_G(opt): + opt_net = opt['network_G'] + which_model = opt_net['which_model_G'] + + # image restoration + if which_model == 'MSRResNet': + netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], + nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) + elif which_model == 'RRDBNet': + netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], + nf=opt_net['nf'], nb=opt_net['nb']) + # video restoration + elif which_model == 'EDVR': + netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'], + groups=opt_net['groups'], front_RBs=opt_net['front_RBs'], + back_RBs=opt_net['back_RBs'], center=opt_net['center'], + predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'], + w_TSA=opt_net['w_TSA']) + else: + raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) + + return netG + + +# Discriminator +def define_D(opt): + opt_net = opt['network_D'] + which_model = opt_net['which_model_D'] + + if which_model == 'discriminator_vgg_128': + netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf']) + else: + raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) + return netD + + +# Define network used for perceptual loss +def define_F(opt, use_bn=False): + gpu_ids = opt['gpu_ids'] + device = torch.device('cuda' if gpu_ids else 'cpu') + # PyTorch pretrained VGG19-54, before ReLU. + if use_bn: + feature_layer = 49 + else: + feature_layer = 34 + netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, + use_input_norm=True, device=device) + netF.eval() # No need to train + return netF diff --git a/codes/options/__init__.py b/codes/options/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/options/options.py b/codes/options/options.py new file mode 100644 index 00000000..99181b34 --- /dev/null +++ b/codes/options/options.py @@ -0,0 +1,116 @@ +import os +import os.path as osp +import logging +import yaml +from utils.util import OrderedYaml +Loader, Dumper = OrderedYaml() + + +def parse(opt_path, is_train=True): + with open(opt_path, mode='r') as f: + opt = yaml.load(f, Loader=Loader) + # export CUDA_VISIBLE_DEVICES + gpu_list = ','.join(str(x) for x in opt['gpu_ids']) + os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list + print('export CUDA_VISIBLE_DEVICES=' + gpu_list) + + opt['is_train'] = is_train + if opt['distortion'] == 'sr': + scale = opt['scale'] + + # datasets + for phase, dataset in opt['datasets'].items(): + phase = phase.split('_')[0] + dataset['phase'] = phase + if opt['distortion'] == 'sr': + dataset['scale'] = scale + is_lmdb = False + if dataset.get('dataroot_GT', None) is not None: + dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) + if dataset['dataroot_GT'].endswith('lmdb'): + is_lmdb = True + if dataset.get('dataroot_LQ', None) is not None: + dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) + if dataset['dataroot_LQ'].endswith('lmdb'): + is_lmdb = True + dataset['data_type'] = 'lmdb' if is_lmdb else 'img' + if dataset['mode'].endswith('mc'): # for memcached + dataset['data_type'] = 'mc' + dataset['mode'] = dataset['mode'].replace('_mc', '') + + # path + for key, path in opt['path'].items(): + if path and key in opt['path'] and key != 'strict_load': + opt['path'][key] = osp.expanduser(path) 
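+    # the repository root sits three directory levels above codes/options/options.py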
+ opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) + if is_train: + experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_state'] = osp.join(experiments_root, 'training_state') + opt['path']['log'] = experiments_root + opt['path']['val_images'] = osp.join(experiments_root, 'val_images') + + # change some options for debug mode + if 'debug' in opt['name']: + opt['train']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = osp.join(opt['path']['root'], 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + + # network + if opt['distortion'] == 'sr': + opt['network_G']['scale'] = scale + + return opt + + +def dict2str(opt, indent_l=1): + '''dict to string for logger''' + msg = '' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_l * 2) + k + ':[\n' + msg += dict2str(v, indent_l + 1) + msg += ' ' * (indent_l * 2) + ']\n' + else: + msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' + return msg + + +class NoneDict(dict): + def __missing__(self, key): + return None + + +# convert to NoneDict, which return None for missing key. +def dict_to_nonedict(opt): + if isinstance(opt, dict): + new_opt = dict() + for key, sub_opt in opt.items(): + new_opt[key] = dict_to_nonedict(sub_opt) + return NoneDict(**new_opt) + elif isinstance(opt, list): + return [dict_to_nonedict(sub_opt) for sub_opt in opt] + else: + return opt + + +def check_resume(opt, resume_iter): + '''Check resume states and pretrain_model paths''' + logger = logging.getLogger('base') + if opt['path']['resume_state']: + if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( + 'pretrain_model_D', None) is not None: + logger.warning('pretrain_model path will be ignored when resuming training.') + + opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], + '{}_G.pth'.format(resume_iter)) + logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) + if 'gan' in opt['model']: + opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], + '{}_D.pth'.format(resume_iter)) + logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) diff --git a/codes/options/test/test_ESRGAN.yml b/codes/options/test/test_ESRGAN.yml new file mode 100644 index 00000000..3522f217 --- /dev/null +++ b/codes/options/test/test_ESRGAN.yml @@ -0,0 +1,32 @@ +name: RRDB_ESRGAN_x4 +suffix: ~ # add suffix to saved images +model: sr +distortion: sr +scale: 4 +crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels +gpu_ids: [0] + +datasets: + test_1: # the 1st test dataset + name: set5 + mode: LQGT + dataroot_GT: ../datasets/val_set5/Set5 + dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 + test_2: # the 2st test dataset + name: set14 + mode: LQGT + dataroot_GT: ../datasets/val_set14/Set14 + dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 + +#### network structures +network_G: + which_model_G: RRDBNet + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 23 + upscale: 4 + +#### path +path: + pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth diff --git a/codes/options/test/test_SRGAN.yml b/codes/options/test/test_SRGAN.yml new file mode 100644 index 00000000..21eea625 --- /dev/null +++ b/codes/options/test/test_SRGAN.yml @@ -0,0 +1,32 @@ +name: MSRGANx4 +suffix: ~ # add suffix to saved images +model: sr +distortion: sr +scale: 4 +crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels +gpu_ids: [0] + +datasets: + test_1: # the 1st test dataset + name: set5 + mode: LQGT + dataroot_GT: ../datasets/val_set5/Set5 + dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 + test_2: # the 2st test dataset + name: set14 + mode: LQGT + dataroot_GT: ../datasets/val_set14/Set14 + dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 + +#### network structures +network_G: + which_model_G: MSRResNet + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 16 + upscale: 4 + +#### path +path: + pretrain_model_G: ../experiments/pretrained_models/MSRGANx4.pth diff --git a/codes/options/test/test_SRResNet.yml b/codes/options/test/test_SRResNet.yml new file mode 100644 index 00000000..b30b3b44 --- /dev/null +++ b/codes/options/test/test_SRResNet.yml @@ -0,0 +1,48 @@ +name: MSRResNetx4 +suffix: ~ # add suffix to saved images +model: sr +distortion: sr +scale: 4 +crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels +gpu_ids: [0] + +datasets: + test_1: # the 1st test dataset + name: set5 + mode: LQGT + dataroot_GT: ../datasets/val_set5/Set5 + dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 + test_2: # the 2st test dataset + name: set14 + mode: LQGT + dataroot_GT: ../datasets/val_set14/Set14 + dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 + test_3: + name: bsd100 + mode: LQGT + dataroot_GT: ../datasets/BSD/BSDS100 + dataroot_LQ: ../datasets/BSD/BSDS100_bicLRx4 + test_4: + name: urban100 + mode: LQGT + dataroot_GT: ../datasets/urban100 + dataroot_LQ: ../datasets/urban100_bicLRx4 + test_5: + name: div2k100 + mode: LQGT + dataroot_GT: ../datasets/DIV2K100/DIV2K_valid_HR + dataroot_LQ: ../datasets/DIV2K100/DIV2K_valid_bicLRx4 + + +#### network structures +network_G: + which_model_G: MSRResNet + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 16 + upscale: 4 + +#### path +path: + pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth diff --git a/codes/options/train/train_EDVR_M.yml b/codes/options/train/train_EDVR_M.yml new file mode 100644 index 00000000..ed0916c0 --- /dev/null +++ b/codes/options/train/train_EDVR_M.yml @@ -0,0 +1,80 @@ +#### general settings +name: 002_EDVR_EDVRwoTSAIni_lr4e-4_600k_REDS_LrCAR4S_fixTSA50k_new +use_tb_logger: true +model: video_base +distortion: sr +scale: 4 +gpu_ids: [0,1,2,3,4,5,6,7] + +#### datasets +datasets: + train: + name: REDS + mode: REDS + interval_list: [1] + random_reverse: false + border_mode: false + dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb + dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb + cache_keys: ~ + + N_frames: 5 + use_shuffle: true + n_workers: 3 # per GPU + batch_size: 32 + GT_size: 256 + LQ_size: 64 + use_flip: true + use_rot: true + color: RGB + val: + name: REDS4 + mode: video_test + dataroot_GT: ../datasets/REDS4/GT + dataroot_LQ: ../datasets/REDS4/sharp_bicubic + cache_data: True + N_frames: 5 + padding: new_info + +#### network structures +network_G: + which_model_G: EDVR + nf: 64 + nframes: 5 + groups: 8 + front_RBs: 5 + back_RBs: 10 + predeblur: false + HR_in: false + w_TSA: true + +#### path +path: + pretrain_model_G: ../experiments/pretrained_models/EDVR_REDS_SR_M_woTSA.pth + strict_load: false + resume_state: ~ + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 4e-4 + lr_scheme: CosineAnnealingLR_Restart + beta1: 0.9 + beta2: 0.99 + niter: 600000 + ft_tsa_only: 50000 + warmup_iter: -1 # -1: no warm up + T_period: [50000, 100000, 150000, 150000, 150000] + restarts: [50000, 150000, 300000, 450000] + restart_weights: [1, 1, 1, 1] + eta_min: !!float 1e-7 + + pixel_criterion: cb + pixel_weight: 1.0 + val_freq: !!float 5e3 + + manual_seed: 0 + +#### logger +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_EDVR_woTSA_M.yml b/codes/options/train/train_EDVR_woTSA_M.yml new file mode 100644 index 00000000..9f48573c --- /dev/null +++ b/codes/options/train/train_EDVR_woTSA_M.yml @@ -0,0 +1,71 @@ +#### general settings +name: 001_EDVRwoTSA_scratch_lr4e-4_600k_REDS_LrCAR4S +use_tb_logger: true +model: video_base +distortion: sr +scale: 4 +gpu_ids: [0,1,2,3,4,5,6,7] + +#### datasets +datasets: + train: + name: REDS + mode: REDS + interval_list: [1] + random_reverse: false + border_mode: false + dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb + dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb + cache_keys: ~ + + N_frames: 5 + use_shuffle: true + n_workers: 3 # per GPU + batch_size: 32 + GT_size: 256 + 
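+      # GT/LQ crop sizes follow the x4 scale: 256-pixel HR patches pair with 64-pixel LR patches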
LQ_size: 64 + use_flip: true + use_rot: true + color: RGB + +#### network structures +network_G: + which_model_G: EDVR + nf: 64 + nframes: 5 + groups: 8 + front_RBs: 5 + back_RBs: 10 + predeblur: false + HR_in: false + w_TSA: false + +#### path +path: + pretrain_model_G: ~ + strict_load: true + resume_state: ~ + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 4e-4 + lr_scheme: CosineAnnealingLR_Restart + beta1: 0.9 + beta2: 0.99 + niter: 600000 + warmup_iter: -1 # -1: no warm up + T_period: [150000, 150000, 150000, 150000] + restarts: [150000, 300000, 450000] + restart_weights: [1, 1, 1] + eta_min: !!float 1e-7 + + pixel_criterion: cb + pixel_weight: 1.0 + val_freq: !!float 5e3 + + manual_seed: 0 + +#### logger +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_ESRGAN.yml b/codes/options/train/train_ESRGAN.yml new file mode 100644 index 00000000..720f8652 --- /dev/null +++ b/codes/options/train/train_ESRGAN.yml @@ -0,0 +1,81 @@ +#### general settings +name: 003_RRDB_ESRGANx4_DIV2K +use_tb_logger: true +model: srgan +distortion: sr +scale: 4 +gpu_ids: [2] + +#### datasets +datasets: + train: + name: DIV2K + mode: LQGT + dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb + dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb + + use_shuffle: true + n_workers: 6 # per GPU + batch_size: 16 + GT_size: 128 + use_flip: true + use_rot: true + color: RGB + val: + name: val_set14 + mode: LQGT + dataroot_GT: ../datasets/val_set14/Set14 + dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 + +#### network structures +network_G: + which_model_G: RRDBNet + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 23 +network_D: + which_model_D: discriminator_vgg_128 + in_nc: 3 + nf: 64 + +#### path +path: + pretrain_model_G: ../experiments/pretrained_models/RRDB_PSNR_x4.pth + strict_load: true + resume_state: ~ + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 1e-4 + weight_decay_G: 0 + beta1_G: 0.9 + beta2_G: 0.99 + lr_D: !!float 1e-4 + weight_decay_D: 0 + beta1_D: 0.9 + beta2_D: 0.99 + lr_scheme: MultiStepLR + + niter: 400000 + warmup_iter: -1 # no warm up + lr_steps: [50000, 100000, 200000, 300000] + lr_gamma: 0.5 + + pixel_criterion: l1 + pixel_weight: !!float 1e-2 + feature_criterion: l1 + feature_weight: 1 + gan_type: ragan # gan | ragan + gan_weight: !!float 5e-3 + + D_update_ratio: 1 + D_init_iters: 0 + + manual_seed: 10 + val_freq: !!float 5e3 + +#### logger +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_SRGAN.yml b/codes/options/train/train_SRGAN.yml new file mode 100644 index 00000000..8aa48727 --- /dev/null +++ b/codes/options/train/train_SRGAN.yml @@ -0,0 +1,85 @@ +# Not exactly the same as SRGAN in +# With 16 Residual blocks w/o BN + +#### general settings +name: 002_SRGANx4_MSRResNetx4Ini_DIV2K +use_tb_logger: true +model: srgan +distortion: sr +scale: 4 +gpu_ids: [1] + +#### datasets +datasets: + train: + name: DIV2K + mode: LQGT + dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb + dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb + + use_shuffle: true + n_workers: 6 # per GPU + batch_size: 16 + GT_size: 128 + use_flip: true + use_rot: true + color: RGB + val: + name: val_set14 + mode: LQGT + dataroot_GT: ../datasets/val_set14/Set14 + dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 + +#### network structures +network_G: + which_model_G: MSRResNet + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 16 + upscale: 4 +network_D: + which_model_D: 
discriminator_vgg_128 + in_nc: 3 + nf: 64 + +#### path +path: + pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth + strict_load: true + resume_state: ~ + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 1e-4 + weight_decay_G: 0 + beta1_G: 0.9 + beta2_G: 0.99 + lr_D: !!float 1e-4 + weight_decay_D: 0 + beta1_D: 0.9 + beta2_D: 0.99 + lr_scheme: MultiStepLR + + niter: 400000 + warmup_iter: -1 # no warm up + lr_steps: [50000, 100000, 200000, 300000] + lr_gamma: 0.5 + + pixel_criterion: l1 + pixel_weight: !!float 1e-2 + feature_criterion: l1 + feature_weight: 1 + gan_type: gan # gan | ragan + gan_weight: !!float 5e-3 + + D_update_ratio: 1 + D_init_iters: 0 + + manual_seed: 10 + val_freq: !!float 5e3 + +#### logger +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/options/train/train_SRResNet.yml b/codes/options/train/train_SRResNet.yml new file mode 100644 index 00000000..be6fd4da --- /dev/null +++ b/codes/options/train/train_SRResNet.yml @@ -0,0 +1,70 @@ +# Not exactly the same as SRResNet in +# With 16 Residual blocks w/o BN + +#### general settings +name: 001_MSRResNetx4_scratch_DIV2K +use_tb_logger: true +model: sr +distortion: sr +scale: 4 +gpu_ids: [0] + +#### datasets +datasets: + train: + name: DIV2K + mode: LQGT + dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb + dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb + + use_shuffle: true + n_workers: 6 # per GPU + batch_size: 16 + GT_size: 128 + use_flip: true + use_rot: true + color: RGB + val: + name: val_set5 + mode: LQGT + dataroot_GT: ../datasets/val_set5/Set5 + dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 + +#### network structures +network_G: + which_model_G: MSRResNet + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 16 + upscale: 4 + +#### path +path: + pretrain_model_G: ~ + strict_load: true + resume_state: ~ + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 2e-4 + lr_scheme: CosineAnnealingLR_Restart + beta1: 0.9 + beta2: 0.99 + niter: 1000000 + warmup_iter: -1 # no warm up + T_period: [250000, 250000, 250000, 250000] + restarts: [250000, 500000, 750000] + restart_weights: [1, 1, 1] + eta_min: !!float 1e-7 + + pixel_criterion: l1 + pixel_weight: 1.0 + + manual_seed: 10 + val_freq: !!float 5e3 + +#### logger +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 diff --git a/codes/run_scripts.sh b/codes/run_scripts.sh new file mode 100644 index 00000000..3e7c4945 --- /dev/null +++ b/codes/run_scripts.sh @@ -0,0 +1,10 @@ +# single GPU training (image SR) +python train.py -opt options/train/train_SRResNet.yml +python train.py -opt options/train/train_SRGAN.yml +python train.py -opt options/train/train_ESRGAN.yml + + +# distributed training (video SR) +# 8 GPUs +python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_woTSA_M.yml --launcher pytorch +python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_M.yml --launcher pytorch \ No newline at end of file diff --git a/codes/scripts/back_projection/backprojection.m b/codes/scripts/back_projection/backprojection.m new file mode 100644 index 00000000..496d93f4 --- /dev/null +++ b/codes/scripts/back_projection/backprojection.m @@ -0,0 +1,20 @@ +function [im_h] = backprojection(im_h, im_l, maxIter) + +[row_l, col_l,~] = size(im_l); +[row_h, col_h,~] = size(im_h); + +p = fspecial('gaussian', 5, 1); +p = p.^2; +p = p./sum(p(:)); + +im_l = double(im_l); +im_h = 
double(im_h); + +for ii = 1:maxIter + im_l_s = imresize(im_h, [row_l, col_l], 'bicubic'); + im_diff = im_l - im_l_s; + im_diff = imresize(im_diff, [row_h, col_h], 'bicubic'); + im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same'); + im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same'); + im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same'); +end diff --git a/codes/scripts/back_projection/main_bp.m b/codes/scripts/back_projection/main_bp.m new file mode 100644 index 00000000..40c137ed --- /dev/null +++ b/codes/scripts/back_projection/main_bp.m @@ -0,0 +1,22 @@ +clear; close all; clc; + +LR_folder = './LR'; % LR +preout_folder = './results'; % pre output +save_folder = './results_20bp'; +filepaths = dir(fullfile(preout_folder, '*.png')); +max_iter = 20; + +if ~ exist(save_folder, 'dir') + mkdir(save_folder); +end + +for idx_im = 1:length(filepaths) + fprintf([num2str(idx_im) '\n']); + im_name = filepaths(idx_im).name; + im_LR = im2double(imread(fullfile(LR_folder, im_name))); + im_out = im2double(imread(fullfile(preout_folder, im_name))); + %tic + im_out = backprojection(im_out, im_LR, max_iter); + %toc + imwrite(im_out, fullfile(save_folder, im_name)); +end diff --git a/codes/scripts/back_projection/main_reverse_filter.m b/codes/scripts/back_projection/main_reverse_filter.m new file mode 100644 index 00000000..63f2edcf --- /dev/null +++ b/codes/scripts/back_projection/main_reverse_filter.m @@ -0,0 +1,25 @@ +clear; close all; clc; + +LR_folder = './LR'; % LR +preout_folder = './results'; % pre output +save_folder = './results_20if'; +filepaths = dir(fullfile(preout_folder, '*.png')); +max_iter = 20; + +if ~ exist(save_folder, 'dir') + mkdir(save_folder); +end + +for idx_im = 1:length(filepaths) + fprintf([num2str(idx_im) '\n']); + im_name = filepaths(idx_im).name; + im_LR = im2double(imread(fullfile(LR_folder, im_name))); + im_out = im2double(imread(fullfile(preout_folder, im_name))); + J = imresize(im_LR,4,'bicubic'); + %tic + for m = 1:max_iter + im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic')); + end + %toc + imwrite(im_out, fullfile(save_folder, im_name)); +end diff --git a/codes/scripts/transfer_params_MSRResNet.py b/codes/scripts/transfer_params_MSRResNet.py new file mode 100644 index 00000000..70dafa4d --- /dev/null +++ b/codes/scripts/transfer_params_MSRResNet.py @@ -0,0 +1,27 @@ +import os.path as osp +import sys +import torch +try: + sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) + import models.archs.SRResNet_arch as SRResNet_arch +except ImportError: + pass + +pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth') +crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3) +crt_net = crt_model.state_dict() + +for k, v in crt_net.items(): + if k in pretrained_net and 'upconv1' not in k: + crt_net[k] = pretrained_net[k] + print('replace ... 
', k) + +# x4 -> x3 +crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2 +crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2 +crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2 +crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2 +crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2 +crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2 + +torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth') diff --git a/codes/test.py b/codes/test.py new file mode 100644 index 00000000..b07a44b7 --- /dev/null +++ b/codes/test.py @@ -0,0 +1,105 @@ +import os.path as osp +import logging +import time +import argparse +from collections import OrderedDict + +import options.options as option +import utils.util as util +from data.util import bgr2ycbcr +from data import create_dataset, create_dataloader +from models import create_model + +#### options +parser = argparse.ArgumentParser() +parser.add_argument('-opt', type=str, required=True, help='Path to options YMAL file.') +opt = option.parse(parser.parse_args().opt, is_train=False) +opt = option.dict_to_nonedict(opt) + +util.mkdirs( + (path for key, path in opt['path'].items() + if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) +util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) +logger = logging.getLogger('base') +logger.info(option.dict2str(opt)) + +#### Create test dataset and dataloader +test_loaders = [] +for phase, dataset_opt in sorted(opt['datasets'].items()): + test_set = create_dataset(dataset_opt) + test_loader = create_dataloader(test_set, dataset_opt) + logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) + test_loaders.append(test_loader) + +model = create_model(opt) +for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + logger.info('\nTesting [{:s}]...'.format(test_set_name)) + test_start_time = time.time() + dataset_dir = osp.join(opt['path']['results_root'], test_set_name) + util.mkdir(dataset_dir) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + test_results['psnr_y'] = [] + test_results['ssim_y'] = [] + + for data in test_loader: + need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True + model.feed_data(data, need_GT=need_GT) + img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0] + img_name = osp.splitext(osp.basename(img_path))[0] + + model.test() + visuals = model.get_current_visuals(need_GT=need_GT) + + sr_img = util.tensor2img(visuals['rlt']) # uint8 + + # save images + suffix = opt['suffix'] + if suffix: + save_img_path = osp.join(dataset_dir, img_name + suffix + '.png') + else: + save_img_path = osp.join(dataset_dir, img_name + '.png') + util.save_img(sr_img, save_img_path) + + # calculate PSNR and SSIM + if need_GT: + gt_img = util.tensor2img(visuals['GT']) + sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) + psnr = util.calculate_psnr(sr_img, gt_img) + ssim = util.calculate_ssim(sr_img, gt_img) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + + if gt_img.shape[2] == 3: # RGB image + sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True) + gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True) + + psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255) + ssim_y = 
util.calculate_ssim(sr_img_y * 255, gt_img_y * 255) + test_results['psnr_y'].append(psnr_y) + test_results['ssim_y'].append(ssim_y) + logger.info( + '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'. + format(img_name, psnr, ssim, psnr_y, ssim_y)) + else: + logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim)) + else: + logger.info(img_name) + + if need_GT: # metrics + # Average PSNR/SSIM results + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + logger.info( + '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format( + test_set_name, ave_psnr, ave_ssim)) + if test_results['psnr_y'] and test_results['ssim_y']: + ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) + ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) + logger.info( + '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'. + format(ave_psnr_y, ave_ssim_y)) diff --git a/codes/test_Vid4_REDS4_with_GT.py b/codes/test_Vid4_REDS4_with_GT.py new file mode 100644 index 00000000..7c90c493 --- /dev/null +++ b/codes/test_Vid4_REDS4_with_GT.py @@ -0,0 +1,208 @@ +''' +Test Vid4 (SR) and REDS4 (SR-clean, SR-blur, deblur-clean, deblur-compression) datasets +''' + +import os +import os.path as osp +import glob +import logging +import numpy as np +import cv2 +import torch + +import utils.util as util +import data.util as data_util +import models.archs.EDVR_arch as EDVR_arch + + +def main(): + ################# + # configurations + ################# + device = torch.device('cuda') + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + data_mode = 'Vid4' # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp + # Vid4: SR + # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur); + # blur (deblur-clean), blur_comp (deblur-compression). + stage = 1 # 1 or 2, use two stage strategy for REDS dataset. 
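For reference, a minimal self-contained sketch of the RGB and Y-channel PSNR that test.py above reports through util.crop_border, data.util.bgr2ycbcr and util.calculate_psnr. The random image pair and the helper names bgr2y/psnr are illustrative only; the luma coefficients are the MATLAB-style BT.601 ones that the repository's bgr2ycbcr is assumed to implement.

import numpy as np

def bgr2y(img):
    # BGR float image in [0, 1] -> luma in [0, 1] (MATLAB rgb2ycbcr convention)
    b, g, r = img[..., 0], img[..., 1], img[..., 2]
    return (24.966 * b + 128.553 * g + 65.481 * r + 16.0) / 255.0

def psnr(img1, img2):
    # PSNR in dB for images with range [0, 255]
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 20 * np.log10(255.0 / np.sqrt(mse))

# hypothetical SR / GT pair (in practice these come from cv2.imread or tensor2img)
sr_img = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
gt_img = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)

# crop `scale` pixels from each border before measuring, as test.py does
scale = 4
sr_c = sr_img[scale:-scale, scale:-scale]
gt_c = gt_img[scale:-scale, scale:-scale]

psnr_rgb = psnr(sr_c, gt_c)
psnr_y = psnr(bgr2y(sr_c / 255.) * 255, bgr2y(gt_c / 255.) * 255)
print('PSNR: {:.4f} dB, PSNR_Y: {:.4f} dB'.format(psnr_rgb, psnr_y))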
+ flip_test = False + ############################################################################ + #### model + if data_mode == 'Vid4': + if stage == 1: + model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth' + else: + raise ValueError('Vid4 does not support stage 2.') + elif data_mode == 'sharp_bicubic': + if stage == 1: + model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth' + else: + model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth' + elif data_mode == 'blur_bicubic': + if stage == 1: + model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth' + else: + model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth' + elif data_mode == 'blur': + if stage == 1: + model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth' + else: + model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth' + elif data_mode == 'blur_comp': + if stage == 1: + model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth' + else: + model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth' + else: + raise NotImplementedError + + if data_mode == 'Vid4': + N_in = 7 # use N_in images to restore one HR image + else: + N_in = 5 + + predeblur, HR_in = False, False + back_RBs = 40 + if data_mode == 'blur_bicubic': + predeblur = True + if data_mode == 'blur' or data_mode == 'blur_comp': + predeblur, HR_in = True, True + if stage == 2: + HR_in = True + back_RBs = 20 + model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in) + + #### dataset + if data_mode == 'Vid4': + test_dataset_folder = '../datasets/Vid4/BIx4' + GT_dataset_folder = '../datasets/Vid4/GT' + else: + if stage == 1: + test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode) + else: + test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4' + print('You should modify the test_dataset_folder path for stage 2') + GT_dataset_folder = '../datasets/REDS4/GT' + + #### evaluation + crop_border = 0 + border_frame = N_in // 2 # border frames when evaluate + # temporal padding mode + if data_mode == 'Vid4' or data_mode == 'sharp_bicubic': + padding = 'new_info' + else: + padding = 'replicate' + save_imgs = True + + save_folder = '../results/{}'.format(data_mode) + util.mkdirs(save_folder) + util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) + logger = logging.getLogger('base') + + #### log info + logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) + logger.info('Padding mode: {}'.format(padding)) + logger.info('Model path: {}'.format(model_path)) + logger.info('Save images: {}'.format(save_imgs)) + logger.info('Flip test: {}'.format(flip_test)) + + #### set up the models + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + model = model.to(device) + + avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] + subfolder_name_l = [] + + subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*'))) + subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*'))) + # for each subfolder + for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l): + subfolder_name = osp.basename(subfolder) + subfolder_name_l.append(subfolder_name) + save_subfolder = osp.join(save_folder, subfolder_name) + + img_path_l = sorted(glob.glob(osp.join(subfolder, '*'))) + max_idx = len(img_path_l) + if save_imgs: + util.mkdirs(save_subfolder) + + #### read LQ and GT images + imgs_LQ = 
data_util.read_img_seq(subfolder) + img_GT_l = [] + for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))): + img_GT_l.append(data_util.read_img(None, img_GT_path)) + + avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0 + + # process each image + for img_idx, img_path in enumerate(img_path_l): + img_name = osp.splitext(osp.basename(img_path))[0] + select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding) + imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) + + if flip_test: + output = util.flipx4_forward(model, imgs_in) + else: + output = util.single_forward(model, imgs_in) + output = util.tensor2img(output.squeeze(0)) + + if save_imgs: + cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output) + + # calculate PSNR + output = output / 255. + GT = np.copy(img_GT_l[img_idx]) + # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel + if data_mode == 'Vid4': # bgr2y, [0, 1] + GT = data_util.bgr2ycbcr(GT, only_y=True) + output = data_util.bgr2ycbcr(output, only_y=True) + + output, GT = util.crop_border([output, GT], crop_border) + crt_psnr = util.calculate_psnr(output * 255, GT * 255) + logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr)) + + if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames + avg_psnr_center += crt_psnr + N_center += 1 + else: # border frames + avg_psnr_border += crt_psnr + N_border += 1 + + avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border) + avg_psnr_center = avg_psnr_center / N_center + avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border + avg_psnr_l.append(avg_psnr) + avg_psnr_center_l.append(avg_psnr_center) + avg_psnr_border_l.append(avg_psnr_border) + + logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' + 'Center PSNR: {:.6f} dB for {} frames; ' + 'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr, + (N_center + N_border), + avg_psnr_center, N_center, + avg_psnr_border, N_border)) + + logger.info('################ Tidy Outputs ################') + for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l, + avg_psnr_center_l, avg_psnr_border_l): + logger.info('Folder {} - Average PSNR: {:.6f} dB. ' + 'Center PSNR: {:.6f} dB. ' + 'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, + psnr_border)) + logger.info('################ Final Results ################') + logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) + logger.info('Padding mode: {}'.format(padding)) + logger.info('Model path: {}'.format(model_path)) + logger.info('Save images: {}'.format(save_imgs)) + logger.info('Flip test: {}'.format(flip_test)) + logger.info('Total Average PSNR: {:.6f} dB for {} clips. ' + 'Center PSNR: {:.6f} dB. 
Border PSNR: {:.6f} dB.'.format( + sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l), + sum(avg_psnr_center_l) / len(avg_psnr_center_l), + sum(avg_psnr_border_l) / len(avg_psnr_border_l))) + + +if __name__ == '__main__': + main() diff --git a/codes/test_Vid4_REDS4_with_GT_DUF.py b/codes/test_Vid4_REDS4_with_GT_DUF.py new file mode 100644 index 00000000..fcec690d --- /dev/null +++ b/codes/test_Vid4_REDS4_with_GT_DUF.py @@ -0,0 +1,264 @@ +""" +DUF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets +write to txt log file +""" + +import os +import os.path as osp +import glob +import logging +import numpy as np +import cv2 +import torch + +import utils.util as util +import data.util as data_util +import models.archs.DUF_arch as DUF_arch + + +def main(): + ################# + # configurations + ################# + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS) + + # Possible combinations: (2, 16), (3, 16), (4, 16), (4, 28), (4, 52) + scale = 4 + layer = 52 + assert (scale, layer) in [(2, 16), (3, 16), (4, 16), (4, 28), + (4, 52)], 'Unrecognized (scale, layer) combination' + + # model + N_in = 7 + model_path = '../experiments/pretrained_models/DUF_x{}_{}L_official.pth'.format(scale, layer) + adapt_official = True if 'official' in model_path else False + DUF_downsampling = True # True | False + if layer == 16: + model = DUF_arch.DUF_16L(scale=scale, adapt_official=adapt_official) + elif layer == 28: + model = DUF_arch.DUF_28L(scale=scale, adapt_official=adapt_official) + elif layer == 52: + model = DUF_arch.DUF_52L(scale=scale, adapt_official=adapt_official) + + #### dataset + if data_mode == 'Vid4': + test_dataset_folder = '../datasets/Vid4/BIx4/*' + else: # sharp_bicubic (REDS) + test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode) + + #### evaluation + crop_border = 8 + border_frame = N_in // 2 # border frames when evaluate + # temporal padding mode + padding = 'new_info' # different from the official testing codes, which pads zeros. + save_imgs = True + ############################################################################ + device = torch.device('cuda') + save_folder = '../results/{}'.format(data_mode) + util.mkdirs(save_folder) + util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) + logger = logging.getLogger('base') + + #### log info + logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) + logger.info('Padding mode: {}'.format(padding)) + logger.info('Model path: {}'.format(model_path)) + logger.info('Save images: {}'.format(save_imgs)) + + def read_image(img_path): + '''read one image from img_path + Return img: HWC, BGR, [0,1], numpy + ''' + img_GT = cv2.imread(img_path) + img = img_GT.astype(np.float32) / 255. 
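Both the EDVR and DUF scripts build each network input from N_in neighbouring frames, and the `padding` option only matters near clip boundaries. Below is a condensed, illustrative reimplementation of the index selection logic (the same behaviour as data.util.index_generation and the nested index_generation helper in this file), together with the indices each mode yields for the first frame of a 100-frame clip with a 5-frame window; neighbour_indices is not a name from the repository.

def neighbour_indices(crt_i, max_n, N, padding='new_info'):
    # pick N frame indices centred on crt_i, mapping out-of-range indices
    # back into the clip according to `padding`
    max_n, n_pad = max_n - 1, N // 2
    out = []
    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:
            i = {'replicate': 0, 'reflection': -i,
                 'new_info': crt_i + n_pad - i, 'circle': N + i}[padding]
        elif i > max_n:
            i = {'replicate': max_n, 'reflection': 2 * max_n - i,
                 'new_info': crt_i - n_pad - (i - max_n), 'circle': i - N}[padding]
        out.append(i)
    return out

for mode in ['replicate', 'reflection', 'new_info', 'circle']:
    print(mode, neighbour_indices(0, 100, 5, mode))
# replicate  [0, 0, 0, 1, 2]
# reflection [2, 1, 0, 1, 2]
# new_info   [4, 3, 0, 1, 2]
# circle     [3, 4, 0, 1, 2]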
+ return img + + def read_seq_imgs(img_seq_path): + '''read a sequence of images''' + img_path_l = sorted(glob.glob(img_seq_path + '/*')) + img_l = [read_image(v) for v in img_path_l] + # stack to TCHW, RGB, [0,1], torch + imgs = np.stack(img_l, axis=0) + imgs = imgs[:, :, :, [2, 1, 0]] + imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() + return imgs + + def index_generation(crt_i, max_n, N, padding='reflection'): + ''' + padding: replicate | reflection | new_info | circle + ''' + max_n = max_n - 1 + n_pad = N // 2 + return_l = [] + + for i in range(crt_i - n_pad, crt_i + n_pad + 1): + if i < 0: + if padding == 'replicate': + add_idx = 0 + elif padding == 'reflection': + add_idx = -i + elif padding == 'new_info': + add_idx = (crt_i + n_pad) + (-i) + elif padding == 'circle': + add_idx = N + i + else: + raise ValueError('Wrong padding mode') + elif i > max_n: + if padding == 'replicate': + add_idx = max_n + elif padding == 'reflection': + add_idx = max_n * 2 - i + elif padding == 'new_info': + add_idx = (crt_i - n_pad) - (i - max_n) + elif padding == 'circle': + add_idx = i - N + else: + raise ValueError('Wrong padding mode') + else: + add_idx = i + return_l.append(add_idx) + return return_l + + def single_forward(model, imgs_in): + with torch.no_grad(): + model_output = model(imgs_in) + if isinstance(model_output, list) or isinstance(model_output, tuple): + output = model_output[0] + else: + output = model_output + return output + + sub_folder_l = sorted(glob.glob(test_dataset_folder)) + #### set up the models + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + model = model.to(device) + + avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] + sub_folder_name_l = [] + + # for each sub-folder + for sub_folder in sub_folder_l: + sub_folder_name = sub_folder.split('/')[-1] + sub_folder_name_l.append(sub_folder_name) + save_sub_folder = osp.join(save_folder, sub_folder_name) + + img_path_l = sorted(glob.glob(sub_folder + '/*')) + max_idx = len(img_path_l) + + if save_imgs: + util.mkdirs(save_sub_folder) + + #### read LR images + imgs = read_seq_imgs(sub_folder) + #### read GT images + img_GT_l = [] + if data_mode == 'Vid4': + sub_folder_GT = osp.join(sub_folder.replace('/BIx4/', '/GT/'), '*') + else: + sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*') + for img_GT_path in sorted(glob.glob(sub_folder_GT)): + img_GT_l.append(read_image(img_GT_path)) + + # When using the downsampling in DUF official code, we downsample the HR images + if DUF_downsampling: + sub_folder = sub_folder_GT + img_path_l = sorted(glob.glob(sub_folder)) + max_idx = len(img_path_l) + imgs = read_seq_imgs(sub_folder[:-2]) + + avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0 + cal_n_border, cal_n_center = 0, 0 + + # process each image + for img_idx, img_path in enumerate(img_path_l): + c_idx = int(osp.splitext(osp.basename(img_path))[0]) + select_idx = index_generation(c_idx, max_idx, N_in, padding=padding) + # get input images + imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) + + # Downsample the HR images + H, W = imgs_in.size(3), imgs_in.size(4) + if DUF_downsampling: + imgs_in = util.DUF_downsample(imgs_in, scale=scale) + + output = single_forward(model, imgs_in) + + # Crop to the original shape + if scale == 3: + pad_h = 3 - (H % 3) + pad_w = 3 - (W % 3) + if pad_h > 0: + output = output[:, :, :-pad_h, :] + if pad_w > 0: + output = output[:, :, :, :-pad_w] + output_f = 
output.data.float().cpu().squeeze(0) + + output = util.tensor2img(output_f) + + # save imgs + if save_imgs: + cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output) + + #### calculate PSNR + output = output / 255. + GT = np.copy(img_GT_l[img_idx]) + # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels + if data_mode == 'Vid4': # bgr2y, [0, 1] + GT = data_util.bgr2ycbcr(GT) + output = data_util.bgr2ycbcr(output) + if crop_border == 0: + cropped_output = output + cropped_GT = GT + else: + cropped_output = output[crop_border:-crop_border, crop_border:-crop_border] + cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border] + crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255) + logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr)) + + if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames + avg_psnr_center += crt_psnr + cal_n_center += 1 + else: # border frames + avg_psnr_border += crt_psnr + cal_n_border += 1 + + avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border) + avg_psnr_center = avg_psnr_center / cal_n_center + if cal_n_border == 0: + avg_psnr_border = 0 + else: + avg_psnr_border = avg_psnr_border / cal_n_border + + logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' + 'Center PSNR: {:.6f} dB for {} frames; ' + 'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr, + (cal_n_center + cal_n_border), + avg_psnr_center, cal_n_center, + avg_psnr_border, cal_n_border)) + + avg_psnr_l.append(avg_psnr) + avg_psnr_center_l.append(avg_psnr_center) + avg_psnr_border_l.append(avg_psnr_border) + + logger.info('################ Tidy Outputs ################') + for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l, + avg_psnr_center_l, avg_psnr_border_l): + logger.info('Folder {} - Average PSNR: {:.6f} dB. ' + 'Center PSNR: {:.6f} dB. ' + 'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border)) + logger.info('################ Final Results ################') + logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) + logger.info('Padding mode: {}'.format(padding)) + logger.info('Model path: {}'.format(model_path)) + logger.info('Save images: {}'.format(save_imgs)) + logger.info('Total Average PSNR: {:.6f} dB for {} clips. ' + 'Center PSNR: {:.6f} dB. 
Border PSNR: {:.6f} dB.'.format( + sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l), + sum(avg_psnr_center_l) / len(avg_psnr_center_l), + sum(avg_psnr_border_l) / len(avg_psnr_border_l))) + + +if __name__ == '__main__': + main() diff --git a/codes/test_Vid4_REDS4_with_GT_TOF.py b/codes/test_Vid4_REDS4_with_GT_TOF.py new file mode 100644 index 00000000..da8fc300 --- /dev/null +++ b/codes/test_Vid4_REDS4_with_GT_TOF.py @@ -0,0 +1,230 @@ +""" +TOF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets +write to txt log file +""" + +import os +import os.path as osp +import glob +import logging +import numpy as np +import cv2 +import torch + +import utils.util as util +import data.util as data_util +import models.archs.TOF_arch as TOF_arch + + +def main(): + ################# + # configurations + ################# + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS) + + # model + N_in = 7 + model_path = '../experiments/pretrained_models/TOF_official.pth' + adapt_official = True if 'official' in model_path else False + model = TOF_arch.TOFlow(adapt_official=adapt_official) + + #### dataset + if data_mode == 'Vid4': + test_dataset_folder = '../datasets/Vid4/BIx4up_direct/*' + else: + test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode) + + #### evaluation + crop_border = 0 + border_frame = N_in // 2 # border frames when evaluate + # temporal padding mode + padding = 'new_info' # different from the official setting + save_imgs = True + ############################################################################ + device = torch.device('cuda') + save_folder = '../results/{}'.format(data_mode) + util.mkdirs(save_folder) + util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) + logger = logging.getLogger('base') + + #### log info + logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) + logger.info('Padding mode: {}'.format(padding)) + logger.info('Model path: {}'.format(model_path)) + logger.info('Save images: {}'.format(save_imgs)) + + def read_image(img_path): + '''read one image from img_path + Return img: HWC, BGR, [0,1], numpy + ''' + img_GT = cv2.imread(img_path) + img = img_GT.astype(np.float32) / 255. 
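These video test scripts feed the network a 1 x T x C x H x W batch of neighbouring LR frames. A short sketch of how that tensor is assembled and sliced, mirroring read_seq_imgs below (and data.util.read_img_seq) plus the index_select call in the per-frame loop; the clip path and the hard-coded select_idx are placeholders for illustration.

import glob
import os.path as osp

import cv2
import numpy as np
import torch

def load_clip(folder):
    # read all frames in `folder` into a TxCxHxW float tensor, RGB, [0, 1]
    paths = sorted(glob.glob(osp.join(folder, '*')))
    frames = [cv2.imread(p).astype(np.float32) / 255. for p in paths]  # HWC, BGR
    imgs = np.stack(frames, axis=0)[:, :, :, ::-1]                     # BGR -> RGB
    return torch.from_numpy(np.ascontiguousarray(imgs.transpose(0, 3, 1, 2))).float()

# hypothetical clip folder; 7 neighbouring frames around frame 11
clip = load_clip('../datasets/Vid4/BIx4/calendar')
select_idx = [8, 9, 10, 11, 12, 13, 14]   # in practice produced by index_generation()
imgs_in = clip.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0)  # 1xTxCxHxW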
+ return img + + def read_seq_imgs(img_seq_path): + '''read a sequence of images''' + img_path_l = sorted(glob.glob(img_seq_path + '/*')) + img_l = [read_image(v) for v in img_path_l] + # stack to TCHW, RGB, [0,1], torch + imgs = np.stack(img_l, axis=0) + imgs = imgs[:, :, :, [2, 1, 0]] + imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() + return imgs + + def index_generation(crt_i, max_n, N, padding='reflection'): + ''' + padding: replicate | reflection | new_info | circle + ''' + max_n = max_n - 1 + n_pad = N // 2 + return_l = [] + + for i in range(crt_i - n_pad, crt_i + n_pad + 1): + if i < 0: + if padding == 'replicate': + add_idx = 0 + elif padding == 'reflection': + add_idx = -i + elif padding == 'new_info': + add_idx = (crt_i + n_pad) + (-i) + elif padding == 'circle': + add_idx = N + i + else: + raise ValueError('Wrong padding mode') + elif i > max_n: + if padding == 'replicate': + add_idx = max_n + elif padding == 'reflection': + add_idx = max_n * 2 - i + elif padding == 'new_info': + add_idx = (crt_i - n_pad) - (i - max_n) + elif padding == 'circle': + add_idx = i - N + else: + raise ValueError('Wrong padding mode') + else: + add_idx = i + return_l.append(add_idx) + return return_l + + def single_forward(model, imgs_in): + with torch.no_grad(): + model_output = model(imgs_in) + if isinstance(model_output, list) or isinstance(model_output, tuple): + output = model_output[0] + else: + output = model_output + return output + + sub_folder_l = sorted(glob.glob(test_dataset_folder)) + #### set up the models + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + model = model.to(device) + + avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] + sub_folder_name_l = [] + + # for each sub-folder + for sub_folder in sub_folder_l: + sub_folder_name = sub_folder.split('/')[-1] + sub_folder_name_l.append(sub_folder_name) + save_sub_folder = osp.join(save_folder, sub_folder_name) + + img_path_l = sorted(glob.glob(sub_folder + '/*')) + max_idx = len(img_path_l) + + if save_imgs: + util.mkdirs(save_sub_folder) + + #### read LR images + imgs = read_seq_imgs(sub_folder) + #### read GT images + img_GT_l = [] + if data_mode == 'Vid4': + sub_folder_GT = osp.join(sub_folder.replace('/BIx4up_direct/', '/GT/'), '*') + else: + sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*') + for img_GT_path in sorted(glob.glob(sub_folder_GT)): + img_GT_l.append(read_image(img_GT_path)) + + avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0 + cal_n_border, cal_n_center = 0, 0 + + # process each image + for img_idx, img_path in enumerate(img_path_l): + c_idx = int(osp.splitext(osp.basename(img_path))[0]) + select_idx = index_generation(c_idx, max_idx, N_in, padding=padding) + # get input images + imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) + output = single_forward(model, imgs_in) + output_f = output.data.float().cpu().squeeze(0) + + output = util.tensor2img(output_f) + + # save imgs + if save_imgs: + cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output) + + #### calculate PSNR + output = output / 255. 
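The per-clip statistics a few lines below are split into "center" and "border" frames: a frame counts as a border frame when it lies within N_in // 2 frames of either end of the clip, i.e. when its temporal window had to be padded. A small illustration of that split (the 41-frame clip length is just an example):

N_in = 7
border_frame = N_in // 2            # 3 frames at each end of the clip
max_idx = 41                        # illustrative clip length

center = [i for i in range(max_idx) if border_frame <= i < max_idx - border_frame]
border = [i for i in range(max_idx) if i not in center]
# center: frames 3..37 (35 frames); border: 0, 1, 2, 38, 39, 40 (6 frames)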
+ GT = np.copy(img_GT_l[img_idx]) + # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels + if data_mode == 'Vid4': # bgr2y, [0, 1] + GT = data_util.bgr2ycbcr(GT) + output = data_util.bgr2ycbcr(output) + if crop_border == 0: + cropped_output = output + cropped_GT = GT + else: + cropped_output = output[crop_border:-crop_border, crop_border:-crop_border] + cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border] + crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255) + logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr)) + + if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames + avg_psnr_center += crt_psnr + cal_n_center += 1 + else: # border frames + avg_psnr_border += crt_psnr + cal_n_border += 1 + + avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border) + avg_psnr_center = avg_psnr_center / cal_n_center + if cal_n_border == 0: + avg_psnr_border = 0 + else: + avg_psnr_border = avg_psnr_border / cal_n_border + + logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' + 'Center PSNR: {:.6f} dB for {} frames; ' + 'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr, + (cal_n_center + cal_n_border), + avg_psnr_center, cal_n_center, + avg_psnr_border, cal_n_border)) + + avg_psnr_l.append(avg_psnr) + avg_psnr_center_l.append(avg_psnr_center) + avg_psnr_border_l.append(avg_psnr_border) + + logger.info('################ Tidy Outputs ################') + for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l, + avg_psnr_center_l, avg_psnr_border_l): + logger.info('Folder {} - Average PSNR: {:.6f} dB. ' + 'Center PSNR: {:.6f} dB. ' + 'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border)) + logger.info('################ Final Results ################') + logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) + logger.info('Padding mode: {}'.format(padding)) + logger.info('Model path: {}'.format(model_path)) + logger.info('Save images: {}'.format(save_imgs)) + logger.info('Total Average PSNR: {:.6f} dB for {} clips. ' + 'Center PSNR: {:.6f} dB. 
Border PSNR: {:.6f} dB.'.format( + sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l), + sum(avg_psnr_center_l) / len(avg_psnr_center_l), + sum(avg_psnr_border_l) / len(avg_psnr_border_l))) + + +if __name__ == '__main__': + main() diff --git a/codes/train.py b/codes/train.py new file mode 100644 index 00000000..c8c29bde --- /dev/null +++ b/codes/train.py @@ -0,0 +1,310 @@ +import os +import math +import argparse +import random +import logging + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from data.data_sampler import DistIterSampler + +import options.options as option +from utils import util +from data import create_dataloader, create_dataset +from models import create_model + + +def init_dist(backend='nccl', **kwargs): + """initialization for distributed training""" + if mp.get_start_method(allow_none=True) != 'spawn': + mp.set_start_method('spawn') + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def main(): + #### options + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, help='Path to option YAML file.') + parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + opt = option.parse(args.opt, is_train=True) + + #### distributed training settings + if args.launcher == 'none': # disabled distributed training + opt['dist'] = False + rank = -1 + print('Disabled distributed training.') + else: + opt['dist'] = True + init_dist() + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + #### loading resume state if exists + if opt['path'].get('resume_state', None): + # distributed resuming: all load into default GPU + device_id = torch.cuda.current_device() + resume_state = torch.load(opt['path']['resume_state'], + map_location=lambda storage, loc: storage.cuda(device_id)) + option.check_resume(opt, resume_state['iter']) # check resume options + else: + resume_state = None + + #### mkdir and loggers + if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) + if resume_state is None: + util.mkdir_and_rename( + opt['path']['experiments_root']) # rename experiment folder if exists + util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' + and 'pretrain_model' not in key and 'resume' not in key)) + + # config loggers. Before it, the log will not work + util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) + logger = logging.getLogger('base') + logger.info(option.dict2str(opt)) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + version = float(torch.__version__[0:3]) + if version >= 1.1: # PyTorch 1.1 + from torch.utils.tensorboard import SummaryWriter + else: + logger.info( + 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]'.format(version)) + from tensorboardX import SummaryWriter + tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) + else: + util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) + logger = logging.getLogger('base') + + # convert to NoneDict, which returns None for missing keys + opt = option.dict_to_nonedict(opt) + + #### random seed + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + if rank <= 0: + logger.info('Random seed: {}'.format(seed)) + util.set_random_seed(seed) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + #### create train and val dataloader + dataset_ratio = 200 # enlarge the size of each epoch + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = create_dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) + total_iters = int(opt['train']['niter']) + total_epochs = int(math.ceil(total_iters / train_size)) + if opt['dist']: + train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) + total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) + else: + train_sampler = None + train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) + if rank <= 0: + logger.info('Number of train images: {:,d}, iters: {:,d}'.format( + len(train_set), train_size)) + logger.info('Total epochs needed: {:d} for iters {:,d}'.format( + total_epochs, total_iters)) + elif phase == 'val': + val_set = create_dataset(dataset_opt) + val_loader = create_dataloader(val_set, dataset_opt, opt, None) + if rank <= 0: + logger.info('Number of val images in [{:s}]: {:d}'.format( + dataset_opt['name'], len(val_set))) + else: + raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) + assert train_loader is not None + + #### create model + model = create_model(opt) + + #### resume training + if resume_state: + logger.info('Resuming training from epoch: {}, iter: {}.'.format( + resume_state['epoch'], resume_state['iter'])) + + start_epoch = resume_state['epoch'] + current_step = resume_state['iter'] + model.resume_training(resume_state) # handle optimizers and schedulers + else: + current_step = 0 + start_epoch = 0 + + #### training + logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) + for epoch in range(start_epoch, total_epochs + 1): + if opt['dist']: + train_sampler.set_epoch(epoch) + for _, train_data in enumerate(train_loader): + current_step += 1 + if current_step > total_iters: + break + #### update learning rate + model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) + + #### training + model.feed_data(train_data) + model.optimize_parameters(current_step) + + #### log + if current_step % opt['logger']['print_freq'] == 0: + logs = model.get_current_log() + message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step) + for v in model.get_current_learning_rate(): + message += '{:.3e},'.format(v) + message += ')] ' + for k, v in logs.items(): + message += '{:s}: {:.4e} '.format(k, v) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + if rank <= 0: + tb_logger.add_scalar(k, v, current_step) + if rank <= 0: + logger.info(message) + #### validation + if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: + if opt['model'] in ['sr', 'srgan'] and rank <= 0: # image 
restoration validation + # does not support multi-GPU validation + pbar = util.ProgressBar(len(val_loader)) + avg_psnr = 0. + idx = 0 + for val_data in val_loader: + idx += 1 + img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] + img_dir = os.path.join(opt['path']['val_images'], img_name) + util.mkdir(img_dir) + + model.feed_data(val_data) + model.test() + + visuals = model.get_current_visuals() + sr_img = util.tensor2img(visuals['rlt']) # uint8 + gt_img = util.tensor2img(visuals['GT']) # uint8 + + # Save SR images for reference + save_img_path = os.path.join(img_dir, + '{:s}_{:d}.png'.format(img_name, current_step)) + util.save_img(sr_img, save_img_path) + + # calculate PSNR + sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) + avg_psnr += util.calculate_psnr(sr_img, gt_img) + pbar.update('Test {}'.format(img_name)) + + avg_psnr = avg_psnr / idx + + # log + logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + tb_logger.add_scalar('psnr', avg_psnr, current_step) + else: # video restoration validation + if opt['dist']: + # multi-GPU testing + psnr_rlt = {} # with border and center frames + if rank == 0: + pbar = util.ProgressBar(len(val_set)) + for idx in range(rank, len(val_set), world_size): + val_data = val_set[idx] + val_data['LQs'].unsqueeze_(0) + val_data['GT'].unsqueeze_(0) + folder = val_data['folder'] + idx_d, max_idx = val_data['idx'].split('/') + idx_d, max_idx = int(idx_d), int(max_idx) + if psnr_rlt.get(folder, None) is None: + psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32, + device='cuda') + # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda') + model.feed_data(val_data) + model.test() + visuals = model.get_current_visuals() + rlt_img = util.tensor2img(visuals['rlt']) # uint8 + gt_img = util.tensor2img(visuals['GT']) # uint8 + # calculate PSNR + psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img) + + if rank == 0: + for _ in range(world_size): + pbar.update('Test {} - {}/{}'.format(folder, idx_d, max_idx)) + # # collect data + for _, v in psnr_rlt.items(): + dist.reduce(v, 0) + dist.barrier() + + if rank == 0: + psnr_rlt_avg = {} + psnr_total_avg = 0. + for k, v in psnr_rlt.items(): + psnr_rlt_avg[k] = torch.mean(v).cpu().item() + psnr_total_avg += psnr_rlt_avg[k] + psnr_total_avg /= len(psnr_rlt) + log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg) + for k, v in psnr_rlt_avg.items(): + log_s += ' {}: {:.4e}'.format(k, v) + logger.info(log_s) + if opt['use_tb_logger'] and 'debug' not in opt['name']: + tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) + for k, v in psnr_rlt_avg.items(): + tb_logger.add_scalar(k, v, current_step) + else: + pbar = util.ProgressBar(len(val_loader)) + psnr_rlt = {} # with border and center frames + psnr_rlt_avg = {} + psnr_total_avg = 0. 
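Stepping back to the dataloader setup earlier in train.py: with distributed training the DistIterSampler repeats the dataset dataset_ratio (200) times within each epoch, so the number of outer epochs needed to reach niter iterations shrinks by the same factor. A quick worked example of that arithmetic; the dataset size below is hypothetical, not taken from an actual DIV2K run.

import math

len_train_set = 32000            # hypothetical number of GT/LQ sub-image pairs
batch_size = 16
niter = 1000000                  # opt['train']['niter']
dataset_ratio = 200              # enlargement factor used by DistIterSampler

train_size = int(math.ceil(len_train_set / batch_size))                   # 2000 iters per plain epoch
total_epochs_single = int(math.ceil(niter / train_size))                  # 500 epochs without dist
total_epochs_dist = int(math.ceil(niter / (train_size * dataset_ratio)))  # 3 epochs with dist
print(train_size, total_epochs_single, total_epochs_dist)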
+ for val_data in val_loader: + folder = val_data['folder'][0] + idx_d = val_data['idx'].item() + # border = val_data['border'].item() + if psnr_rlt.get(folder, None) is None: + psnr_rlt[folder] = [] + + model.feed_data(val_data) + model.test() + visuals = model.get_current_visuals() + rlt_img = util.tensor2img(visuals['rlt']) # uint8 + gt_img = util.tensor2img(visuals['GT']) # uint8 + + # calculate PSNR + psnr = util.calculate_psnr(rlt_img, gt_img) + psnr_rlt[folder].append(psnr) + pbar.update('Test {} - {}'.format(folder, idx_d)) + for k, v in psnr_rlt.items(): + psnr_rlt_avg[k] = sum(v) / len(v) + psnr_total_avg += psnr_rlt_avg[k] + psnr_total_avg /= len(psnr_rlt) + log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg) + for k, v in psnr_rlt_avg.items(): + log_s += ' {}: {:.4e}'.format(k, v) + logger.info(log_s) + if opt['use_tb_logger'] and 'debug' not in opt['name']: + tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) + for k, v in psnr_rlt_avg.items(): + tb_logger.add_scalar(k, v, current_step) + + #### save models and training states + if current_step % opt['logger']['save_checkpoint_freq'] == 0: + if rank <= 0: + logger.info('Saving models and training states.') + model.save(current_step) + model.save_training_state(epoch, current_step) + + if rank <= 0: + logger.info('Saving the final model.') + model.save('latest') + logger.info('End of training.') + tb_logger.close() + + +if __name__ == '__main__': + main() diff --git a/codes/utils/__init__.py b/codes/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/utils/util.py b/codes/utils/util.py new file mode 100644 index 00000000..a8924fa8 --- /dev/null +++ b/codes/utils/util.py @@ -0,0 +1,327 @@ +import os +import sys +import time +import math +import torch.nn.functional as F +from datetime import datetime +import random +import logging +from collections import OrderedDict +import numpy as np +import cv2 +import torch +from torchvision.utils import make_grid +from shutil import get_terminal_size + +import yaml +try: + from yaml import CLoader as Loader, CDumper as Dumper +except ImportError: + from yaml import Loader, Dumper + + +def OrderedYaml(): + '''yaml orderedDict support''' + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +#################### +# miscellaneous +#################### + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + logger = logging.getLogger('base') + logger.info('Path already exists. 
Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): + '''set up logger''' + lg = logging.getLogger(logger_name) + formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', + datefmt='%y-%m-%d %H:%M:%S') + lg.setLevel(level) + if tofile: + log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) + fh = logging.FileHandler(log_file, mode='w') + fh.setFormatter(formatter) + lg.addHandler(fh) + if screen: + sh = logging.StreamHandler() + sh.setFormatter(formatter) + lg.addHandler(sh) + + +#################### +# image convert +#################### +def crop_border(img_list, crop_border): + """Crop borders of images + Args: + img_list (list [Numpy]): HWC + crop_border (int): crop border for each end of height and weight + + Returns: + (list [Numpy]): cropped image list + """ + if crop_border == 0: + return img_list + else: + return [v[crop_border:-crop_border, crop_border:-crop_border] for v in img_list] + + +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +def save_img(img, img_path, mode='RGB'): + cv2.imwrite(img_path, img) + + +def DUF_downsample(x, scale=4): + """Downsamping with Gaussian kernel used in the DUF official code + + Args: + x (Tensor, [B, T, C, H, W]): frames to be downsampled. + scale (int): downsampling factor: 2 | 3 | 4. 
+ """ + + assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale) + + def gkern(kernlen=13, nsig=1.6): + import scipy.ndimage.filters as fi + inp = np.zeros((kernlen, kernlen)) + # set element at the middle to one, a dirac delta + inp[kernlen // 2, kernlen // 2] = 1 + # gaussian-smooth the dirac, resulting in a gaussian filter mask + return fi.gaussian_filter(inp, nsig) + + B, T, C, H, W = x.size() + x = x.view(-1, 1, H, W) + pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 # 6 is the pad of the gaussian filter + r_h, r_w = 0, 0 + if scale == 3: + r_h = 3 - (H % 3) + r_w = 3 - (W % 3) + x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], 'reflect') + + gaussian_filter = torch.from_numpy(gkern(13, 0.4 * scale)).type_as(x).unsqueeze(0).unsqueeze(0) + x = F.conv2d(x, gaussian_filter, stride=scale) + x = x[:, :, 2:-2, 2:-2] + x = x.view(B, T, C, x.size(2), x.size(3)) + return x + + +def single_forward(model, inp): + """PyTorch model forward (single test), it is just a simple warpper + Args: + model (PyTorch model) + inp (Tensor): inputs defined by the model + + Returns: + output (Tensor): outputs of the model. float, in CPU + """ + with torch.no_grad(): + model_output = model(inp) + if isinstance(model_output, list) or isinstance(model_output, tuple): + output = model_output[0] + else: + output = model_output + output = output.data.float().cpu() + return output + + +def flipx4_forward(model, inp): + """Flip testing with X4 self ensemble, i.e., normal, flip H, flip W, flip H and W + Args: + model (PyTorch model) + inp (Tensor): inputs defined by the model + + Returns: + output (Tensor): outputs of the model. float, in CPU + """ + # normal + output_f = single_forward(model, inp) + + # flip W + output = single_forward(model, torch.flip(inp, (-1, ))) + output_f = output_f + torch.flip(output, (-1, )) + # flip H + output = single_forward(model, torch.flip(inp, (-2, ))) + output_f = output_f + torch.flip(output, (-2, )) + # flip both H and W + output = single_forward(model, torch.flip(inp, (-2, -1))) + output_f = output_f + torch.flip(output, (-2, -1)) + + return output_f / 4 + + +#################### +# metric +#################### + + +def calculate_psnr(img1, img2): + # img1 and img2 have range [0, 255] + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +def calculate_ssim(img1, img2): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + 
ssims.append(ssim(img1, img2)) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +class ProgressBar(object): + '''A progress bar which can print the progress + modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py + ''' + + def __init__(self, task_num=0, bar_width=50, start=True): + self.task_num = task_num + max_bar_width = self._get_max_bar_width() + self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width) + self.completed = 0 + if start: + self.start() + + def _get_max_bar_width(self): + terminal_width, _ = get_terminal_size() + max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) + if max_bar_width < 10: + print('terminal width is too small ({}), please consider widen the terminal for better ' + 'progressbar visualization'.format(terminal_width)) + max_bar_width = 10 + return max_bar_width + + def start(self): + if self.task_num > 0: + sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format( + ' ' * self.bar_width, self.task_num, 'Start...')) + else: + sys.stdout.write('completed: 0, elapsed: 0s') + sys.stdout.flush() + self.start_time = time.time() + + def update(self, msg='In progress...'): + self.completed += 1 + elapsed = time.time() - self.start_time + fps = self.completed / elapsed + if self.task_num > 0: + percentage = self.completed / float(self.task_num) + eta = int(elapsed * (1 - percentage) / percentage + 0.5) + mark_width = int(self.bar_width * percentage) + bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) + sys.stdout.write('\033[2F') # cursor up 2 lines + sys.stdout.write('\033[J') # clean the output (remove extra chars since last display) + sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format( + bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg)) + else: + sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( + self.completed, int(elapsed + 0.5), fps)) + sys.stdout.flush() diff --git a/experiments/pretrained_models/Put pretrained models here. b/experiments/pretrained_models/Put pretrained models here. new file mode 100644 index 00000000..e69de29b
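For completeness, a small usage sketch of single_forward and flipx4_forward from codes/utils/util.py: the x4 self-ensemble runs the model on the input and its H-, W- and HW-flipped versions, un-flips the outputs and averages them, so for a flip-equivariant operator it reproduces a single forward pass. A plain nearest-neighbour upsampler stands in for an SR network here purely to keep the example self-contained; run it from the codes/ directory so that utils.util resolves.

import torch
import torch.nn as nn

import utils.util as util  # the helpers defined above

# toy flip-equivariant "model": x4 nearest-neighbour upsampling (not an SR network)
model = nn.Upsample(scale_factor=4, mode='nearest')

lq = torch.rand(1, 3, 32, 32)               # random 1xCxHxW "LR image"
out_plain = util.single_forward(model, lq)  # 1x3x128x128, float, on CPU
out_ens = util.flipx4_forward(model, lq)    # flip self-ensemble average

print(out_plain.shape, torch.allclose(out_plain, out_ens))  # expected: True for this toy model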