diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..009995b3
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,6 @@
+[flake8]
+ignore =
+ # Too many leading '#' for block comment (E266)
+ E266
+
+max-line-length=100
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..b0abdf54
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,121 @@
+experiments/*
+results/*
+tb_logger/*
+datasets/*
+.vscode
+
+*.html
+*.png
+*.jpg
+*.gif
+*.pth
+*.pytorch
+*.zip
+*.cu
+
+# template
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
diff --git a/.style.yapf b/.style.yapf
new file mode 100644
index 00000000..9f0e0622
--- /dev/null
+++ b/.style.yapf
@@ -0,0 +1,4 @@
+[style]
+BASED_ON_STYLE = pep8
+COLUMN_LIMIT = 100
+SPLIT_BEFORE_NAMED_ASSIGNS = false
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 00000000..4b8f4ffc
--- /dev/null
+++ b/README.md
@@ -0,0 +1,47 @@
+# MMSR
+
+MMSR is an open-source image and video super-resolution toolbox based on PyTorch. It is part of the [open-mmlab](https://github.com/open-mmlab) project developed by [Multimedia Laboratory, CUHK](http://mmlab.ie.cuhk.edu.hk/). MMSR is based on our previous projects: [BasicSR](https://github.com/xinntao/BasicSR), [ESRGAN](https://github.com/xinntao/ESRGAN) and [EDVR](https://github.com/xinntao/EDVR).
+
+### Highlights
+- **A unified framework** suitable for image and video super-resolution tasks. It is also easy to adapt to other restoration tasks, e.g., deblurring and denoising.
+- **State of the art**: It includes several competition-winning methods, such as ESRGAN (PIRM18) and EDVR (NTIRE19).
+- **Easy to extend**: It is easy to try new research ideas on top of the codebase.
+
+
+### Updates
+[2019-07-25] MMSR v0.1 is released.
+
+## Dependencies and Installation
+
+- Python 3 (we recommend [Anaconda](https://www.anaconda.com/download/#linux))
+- [PyTorch >= 1.0](https://pytorch.org/)
+- NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
+- [Deformable Convolution](https://arxiv.org/abs/1703.06211). We use the DCN implementation from [mmdetection](https://github.com/open-mmlab/mmdetection). Please compile it first:
+ ```
+ cd ./codes/models/archs/dcn
+ python setup.py develop
+ ```
+- Python packages: `pip install numpy opencv-python lmdb pyyaml`
+- TensorBoard:
+ - PyTorch >= 1.1: `pip install tb-nightly future`
+ - PyTorch == 1.0: `pip install tensorboardX`
+
+## Dataset Preparation
+We use datasets in LMDB format for faster I/O speed. Please refer to [DATASETS.md](datasets/DATASETS.md) for more details.
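+
+For example, a minimal sketch of generating the LMDB files with the provided script (the dataset choice and the input/output paths are configured by editing the variables at the top of `main()` in the script):
+```
+cd ./codes/data_scripts
+python create_lmdb.py
+```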
+
+## Training and Testing
+Please see [wiki - Training and Testing](https://github.com/open-mmlab/mmsr/wiki/Training-and-Testing) for the basic usage, *i.e.*, training and testing.
+
+## Model Zoo and Baselines
+Results and pre-trained models are available in the [wiki - Model Zoo](https://github.com/open-mmlab/mmsr/wiki/Model-Zoo).
+
+## Contributing
+We appreciate all contributions. Please refer to [mmdetection](https://github.com/open-mmlab/mmdetection/blob/master/CONTRIBUTING.md) for the contributing guidelines.
+
+**Python code style**
+We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style. We use [flake8](http://flake8.pycqa.org/en/latest/) as the linter and [yapf](https://github.com/google/yapf) as the formatter. Please upgrade to the latest yapf (>=0.27.0) and refer to the [yapf configuration](.style.yapf) and [flake8 configuration](.flake8).
+
+> Before you create a PR, make sure that your code lints and is formatted by yapf.
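+
+For example, a sketch assuming both tools are installed and the Python sources live under `./codes`:
+```
+flake8 ./codes
+yapf -r -i ./codes
+```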
+
+## License
+This project is released under the Apache 2.0 license.
diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py
new file mode 100644
index 00000000..a52817eb
--- /dev/null
+++ b/codes/data/LQGT_dataset.py
@@ -0,0 +1,127 @@
+import random
+import numpy as np
+import cv2
+import lmdb
+import torch
+import torch.utils.data as data
+import data.util as util
+
+
+class LQGTDataset(data.Dataset):
+ """
+    Read LQ (Low Quality, e.g., LR (Low Resolution), blurry) and GT image pairs.
+ If only GT images are provided, generate LQ images on-the-fly.
+ """
+
+ def __init__(self, opt):
+ super(LQGTDataset, self).__init__()
+ self.opt = opt
+ self.data_type = self.opt['data_type']
+ self.paths_LQ, self.paths_GT = None, None
+ self.sizes_LQ, self.sizes_GT = None, None
+ self.LQ_env, self.GT_env = None, None # environments for lmdb
+
+ self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
+ self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
+ assert self.paths_GT, 'Error: GT path is empty.'
+ if self.paths_LQ and self.paths_GT:
+ assert len(self.paths_LQ) == len(
+ self.paths_GT
+ ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
+ len(self.paths_LQ), len(self.paths_GT))
+ self.random_scale_list = [1]
+
+ def _init_lmdb(self):
+ # https://github.com/chainer/chainermn/issues/129
+ self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
+ meminit=False)
+ self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
+ meminit=False)
+
+ def __getitem__(self, index):
+ if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
+ self._init_lmdb()
+ GT_path, LQ_path = None, None
+ scale = self.opt['scale']
+ GT_size = self.opt['GT_size']
+
+ # get GT image
+ GT_path = self.paths_GT[index]
+ resolution = [int(s) for s in self.sizes_GT[index].split('_')
+ ] if self.data_type == 'lmdb' else None
+ img_GT = util.read_img(self.GT_env, GT_path, resolution)
+ if self.opt['phase'] != 'train': # modcrop in the validation / test phase
+ img_GT = util.modcrop(img_GT, scale)
+ if self.opt['color']: # change color space if necessary
+ img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
+
+ # get LQ image
+ if self.paths_LQ:
+ LQ_path = self.paths_LQ[index]
+ resolution = [int(s) for s in self.sizes_LQ[index].split('_')
+ ] if self.data_type == 'lmdb' else None
+ img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
+ else: # down-sampling on-the-fly
+ # randomly scale during training
+ if self.opt['phase'] == 'train':
+ random_scale = random.choice(self.random_scale_list)
+ H_s, W_s, _ = img_GT.shape
+
+ def _mod(n, random_scale, scale, thres):
+ rlt = int(n * random_scale)
+ rlt = (rlt // scale) * scale
+ return thres if rlt < thres else rlt
+
+ H_s = _mod(H_s, random_scale, scale, GT_size)
+ W_s = _mod(W_s, random_scale, scale, GT_size)
+ img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
+ if img_GT.ndim == 2:
+ img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
+
+ H, W, _ = img_GT.shape
+ # using matlab imresize
+ img_LQ = util.imresize_np(img_GT, 1 / scale, True)
+ if img_LQ.ndim == 2:
+ img_LQ = np.expand_dims(img_LQ, axis=2)
+
+ if self.opt['phase'] == 'train':
+ # if the image size is too small
+ H, W, _ = img_GT.shape
+ if H < GT_size or W < GT_size:
+ img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
+ # using matlab imresize
+ img_LQ = util.imresize_np(img_GT, 1 / scale, True)
+ if img_LQ.ndim == 2:
+ img_LQ = np.expand_dims(img_LQ, axis=2)
+
+ H, W, C = img_LQ.shape
+ LQ_size = GT_size // scale
+
+ # randomly crop
+ rnd_h = random.randint(0, max(0, H - LQ_size))
+ rnd_w = random.randint(0, max(0, W - LQ_size))
+ img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
+ rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
+ img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
+
+ # augmentation - flip, rotate
+ img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
+ self.opt['use_rot'])
+
+ if self.opt['color']: # change color space if necessary
+ img_LQ = util.channel_convert(C, self.opt['color'],
+                                          [img_LQ])[0]  # TODO: C is undefined during validation
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ if img_GT.shape[2] == 3:
+ img_GT = img_GT[:, :, [2, 1, 0]]
+ img_LQ = img_LQ[:, :, [2, 1, 0]]
+ img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
+ img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
+
+ if LQ_path is None:
+ LQ_path = GT_path
+ return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
+
+ def __len__(self):
+ return len(self.paths_GT)
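+
+
+# An illustrative option dict (example values only), sketching the keys this dataset reads:
+#   opt = {'data_type': 'img', 'dataroot_GT': '...', 'dataroot_LQ': '...',
+#          'scale': 4, 'GT_size': 128, 'phase': 'train', 'color': None,
+#          'use_flip': True, 'use_rot': True}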
diff --git a/codes/data/LQ_dataset.py b/codes/data/LQ_dataset.py
new file mode 100644
index 00000000..e7ad7dd3
--- /dev/null
+++ b/codes/data/LQ_dataset.py
@@ -0,0 +1,47 @@
+import numpy as np
+import lmdb
+import torch
+import torch.utils.data as data
+import data.util as util
+
+
+class LQDataset(data.Dataset):
+ '''Read LQ images only in the test phase.'''
+
+ def __init__(self, opt):
+ super(LQDataset, self).__init__()
+ self.opt = opt
+ self.paths_LQ, self.paths_GT = None, None
+ self.LQ_env = None # environment for lmdb
+
+        self.data_type = opt['data_type']
+        self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
+ assert self.paths_LQ, 'Error: LQ paths are empty.'
+
+ def _init_lmdb(self):
+ self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
+ meminit=False)
+
+ def __getitem__(self, index):
+ if self.data_type == 'lmdb' and self.LQ_env is None:
+ self._init_lmdb()
+ LQ_path = None
+
+ # get LQ image
+        LQ_path = self.paths_LQ[index]
+ resolution = [int(s) for s in self.sizes_LQ[index].split('_')
+ ] if self.data_type == 'lmdb' else None
+ img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
+ H, W, C = img_LQ.shape
+
+ if self.opt['color']: # change color space if necessary
+ img_LQ = util.channel_convert(C, self.opt['color'], [img_LQ])[0]
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ if img_LQ.shape[2] == 3:
+ img_LQ = img_LQ[:, :, [2, 1, 0]]
+ img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
+
+ return {'LQ': img_LQ, 'LQ_path': LQ_path}
+
+ def __len__(self):
+ return len(self.paths_LQ)
diff --git a/codes/data/REDS_dataset.py b/codes/data/REDS_dataset.py
new file mode 100644
index 00000000..5643dbb2
--- /dev/null
+++ b/codes/data/REDS_dataset.py
@@ -0,0 +1,210 @@
+'''
+REDS dataset
+support reading images from lmdb, image folder and memcached
+'''
+import os.path as osp
+import random
+import pickle
+import logging
+import numpy as np
+import cv2
+import lmdb
+import torch
+import torch.utils.data as data
+import data.util as util
+try:
+ import mc # import memcached
+except ImportError:
+ pass
+
+logger = logging.getLogger('base')
+
+
+class REDSDataset(data.Dataset):
+ '''
+ Reading the training REDS dataset
+ key example: 000_00000000
+ GT: Ground-Truth;
+ LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames
+ support reading N LQ frames, N = 1, 3, 5, 7
+ '''
+
+ def __init__(self, opt):
+ super(REDSDataset, self).__init__()
+ self.opt = opt
+ # temporal augmentation
+ self.interval_list = opt['interval_list']
+ self.random_reverse = opt['random_reverse']
+ logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
+ ','.join(str(x) for x in opt['interval_list']), self.random_reverse))
+
+ self.half_N_frames = opt['N_frames'] // 2
+ self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
+ self.data_type = self.opt['data_type']
+        self.LR_input = opt['GT_size'] != opt['LQ_size']  # low resolution inputs
+ #### directly load image keys
+ if self.data_type == 'lmdb':
+ self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
+ logger.info('Using lmdb meta info for cache keys.')
+ elif opt['cache_keys']:
+ logger.info('Using cache keys: {}'.format(opt['cache_keys']))
+ self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
+ else:
+ raise ValueError(
+ 'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
+
+ # remove the REDS4 for testing
+ self.paths_GT = [
+ v for v in self.paths_GT if v.split('_')[0] not in ['000', '011', '015', '020']
+ ]
+ assert self.paths_GT, 'Error: GT path is empty.'
+
+ if self.data_type == 'lmdb':
+ self.GT_env, self.LQ_env = None, None
+ elif self.data_type == 'mc': # memcached
+ self.mclient = None
+ elif self.data_type == 'img':
+ pass
+ else:
+ raise ValueError('Wrong data type: {}'.format(self.data_type))
+
+ def _init_lmdb(self):
+ # https://github.com/chainer/chainermn/issues/129
+ self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
+ meminit=False)
+ self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
+ meminit=False)
+
+ def _ensure_memcached(self):
+ if self.mclient is None:
+ # specify the config files
+ server_list_config_file = None
+ client_config_file = None
+ self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
+ client_config_file)
+
+ def _read_img_mc(self, path):
+ ''' Return BGR, HWC, [0, 255], uint8'''
+ value = mc.pyvector()
+ self.mclient.Get(path, value)
+ value_buf = mc.ConvertBuffer(value)
+ img_array = np.frombuffer(value_buf, np.uint8)
+ img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
+ return img
+
+ def _read_img_mc_BGR(self, path, name_a, name_b):
+ ''' Read BGR channels separately and then combine for 1M limits in cluster'''
+ img_B = self._read_img_mc(osp.join(path + '_B', name_a, name_b + '.png'))
+ img_G = self._read_img_mc(osp.join(path + '_G', name_a, name_b + '.png'))
+ img_R = self._read_img_mc(osp.join(path + '_R', name_a, name_b + '.png'))
+ img = cv2.merge((img_B, img_G, img_R))
+ return img
+
+ def __getitem__(self, index):
+ if self.data_type == 'mc':
+ self._ensure_memcached()
+ elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
+ self._init_lmdb()
+
+ scale = self.opt['scale']
+ GT_size = self.opt['GT_size']
+ key = self.paths_GT[index]
+ name_a, name_b = key.split('_')
+ center_frame_idx = int(name_b)
+
+ #### determine the neighbor frames
+ interval = random.choice(self.interval_list)
+ if self.opt['border_mode']:
+ direction = 1 # 1: forward; 0: backward
+ N_frames = self.opt['N_frames']
+ if self.random_reverse and random.random() < 0.5:
+ direction = random.choice([0, 1])
+ if center_frame_idx + interval * (N_frames - 1) > 99:
+ direction = 0
+ elif center_frame_idx - interval * (N_frames - 1) < 0:
+ direction = 1
+ # get the neighbor list
+ if direction == 1:
+ neighbor_list = list(
+ range(center_frame_idx, center_frame_idx + interval * N_frames, interval))
+ else:
+ neighbor_list = list(
+ range(center_frame_idx, center_frame_idx - interval * N_frames, -interval))
+ name_b = '{:08d}'.format(neighbor_list[0])
+ else:
+ # ensure not exceeding the borders
+ while (center_frame_idx + self.half_N_frames * interval >
+ 99) or (center_frame_idx - self.half_N_frames * interval < 0):
+ center_frame_idx = random.randint(0, 99)
+ # get the neighbor list
+ neighbor_list = list(
+ range(center_frame_idx - self.half_N_frames * interval,
+ center_frame_idx + self.half_N_frames * interval + 1, interval))
+ if self.random_reverse and random.random() < 0.5:
+ neighbor_list.reverse()
+ name_b = '{:08d}'.format(neighbor_list[self.half_N_frames])
+
+ assert len(
+ neighbor_list) == self.opt['N_frames'], 'Wrong length of neighbor list: {}'.format(
+ len(neighbor_list))
+
+ #### get the GT image (as the center frame)
+ if self.data_type == 'mc':
+ img_GT = self._read_img_mc_BGR(self.GT_root, name_a, name_b)
+ img_GT = img_GT.astype(np.float32) / 255.
+ elif self.data_type == 'lmdb':
+ img_GT = util.read_img(self.GT_env, key, (3, 720, 1280))
+ else:
+ img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b + '.png'))
+
+ #### get LQ images
+ LQ_size_tuple = (3, 180, 320) if self.LR_input else (3, 720, 1280)
+ img_LQ_l = []
+ for v in neighbor_list:
+ img_LQ_path = osp.join(self.LQ_root, name_a, '{:08d}.png'.format(v))
+ if self.data_type == 'mc':
+ if self.LR_input:
+ img_LQ = self._read_img_mc(img_LQ_path)
+ else:
+ img_LQ = self._read_img_mc_BGR(self.LQ_root, name_a, '{:08d}'.format(v))
+ img_LQ = img_LQ.astype(np.float32) / 255.
+ elif self.data_type == 'lmdb':
+ img_LQ = util.read_img(self.LQ_env, '{}_{:08d}'.format(name_a, v), LQ_size_tuple)
+ else:
+ img_LQ = util.read_img(None, img_LQ_path)
+ img_LQ_l.append(img_LQ)
+
+ if self.opt['phase'] == 'train':
+ C, H, W = LQ_size_tuple # LQ size
+ # randomly crop
+ if self.LR_input:
+ LQ_size = GT_size // scale
+ rnd_h = random.randint(0, max(0, H - LQ_size))
+ rnd_w = random.randint(0, max(0, W - LQ_size))
+ img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
+ rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
+ img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
+ else:
+ rnd_h = random.randint(0, max(0, H - GT_size))
+ rnd_w = random.randint(0, max(0, W - GT_size))
+ img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
+ img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
+
+ # augmentation - flip, rotate
+ img_LQ_l.append(img_GT)
+ rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
+ img_LQ_l = rlt[0:-1]
+ img_GT = rlt[-1]
+
+ # stack LQ images to NHWC, N is the frame number
+ img_LQs = np.stack(img_LQ_l, axis=0)
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_GT = img_GT[:, :, [2, 1, 0]]
+ img_LQs = img_LQs[:, :, :, [2, 1, 0]]
+ img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
+ img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
+ (0, 3, 1, 2)))).float()
+ return {'LQs': img_LQs, 'GT': img_GT, 'key': key}
+
+ def __len__(self):
+ return len(self.paths_GT)
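+
+
+# An illustrative option dict (example values only), sketching the keys this dataset reads:
+#   opt = {'data_type': 'lmdb', 'dataroot_GT': '...', 'dataroot_LQ': '...',
+#          'N_frames': 5, 'interval_list': [1], 'random_reverse': False,
+#          'border_mode': False, 'cache_keys': None, 'GT_size': 256, 'LQ_size': 64,
+#          'scale': 4, 'phase': 'train', 'use_flip': True, 'use_rot': True}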
diff --git a/codes/data/Vimeo90K_dataset.py b/codes/data/Vimeo90K_dataset.py
new file mode 100644
index 00000000..914e1dba
--- /dev/null
+++ b/codes/data/Vimeo90K_dataset.py
@@ -0,0 +1,167 @@
+'''
+Vimeo90K dataset
+support reading images from lmdb, image folder and memcached
+'''
+import os.path as osp
+import random
+import pickle
+import logging
+import numpy as np
+import cv2
+import lmdb
+import torch
+import torch.utils.data as data
+import data.util as util
+try:
+ import mc # import memcached
+except ImportError:
+ pass
+logger = logging.getLogger('base')
+
+
+class Vimeo90KDataset(data.Dataset):
+ '''
+ Reading the training Vimeo90K dataset
+ key example: 00001_0001 (_1, ..., _7)
+ GT (Ground-Truth): 4th frame;
+ LQ (Low-Quality): support reading N LQ frames, N = 1, 3, 5, 7 centered with 4th frame
+ '''
+
+ def __init__(self, opt):
+ super(Vimeo90KDataset, self).__init__()
+ self.opt = opt
+ # temporal augmentation
+ self.interval_list = opt['interval_list']
+ self.random_reverse = opt['random_reverse']
+ logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
+ ','.join(str(x) for x in opt['interval_list']), self.random_reverse))
+
+ self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
+ self.data_type = self.opt['data_type']
+        self.LR_input = opt['GT_size'] != opt['LQ_size']  # low resolution inputs
+
+ #### determine the LQ frame list
+ '''
+ N | frames
+ 1 | 4
+ 3 | 3,4,5
+ 5 | 2,3,4,5,6
+ 7 | 1,2,3,4,5,6,7
+ '''
+ self.LQ_frames_list = []
+ for i in range(opt['N_frames']):
+ self.LQ_frames_list.append(i + (9 - opt['N_frames']) // 2)
+
+ #### directly load image keys
+ if self.data_type == 'lmdb':
+ self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
+ logger.info('Using lmdb meta info for cache keys.')
+ elif opt['cache_keys']:
+ logger.info('Using cache keys: {}'.format(opt['cache_keys']))
+ self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
+ else:
+ raise ValueError(
+ 'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
+ assert self.paths_GT, 'Error: GT path is empty.'
+
+ if self.data_type == 'lmdb':
+ self.GT_env, self.LQ_env = None, None
+ elif self.data_type == 'mc': # memcached
+ self.mclient = None
+ elif self.data_type == 'img':
+ pass
+ else:
+ raise ValueError('Wrong data type: {}'.format(self.data_type))
+
+ def _init_lmdb(self):
+ # https://github.com/chainer/chainermn/issues/129
+ self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
+ meminit=False)
+ self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
+ meminit=False)
+
+ def _ensure_memcached(self):
+ if self.mclient is None:
+ # specify the config files
+ server_list_config_file = None
+ client_config_file = None
+ self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
+ client_config_file)
+
+ def _read_img_mc(self, path):
+ ''' Return BGR, HWC, [0, 255], uint8'''
+ value = mc.pyvector()
+ self.mclient.Get(path, value)
+ value_buf = mc.ConvertBuffer(value)
+ img_array = np.frombuffer(value_buf, np.uint8)
+ img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
+ return img
+
+ def __getitem__(self, index):
+ if self.data_type == 'mc':
+ self._ensure_memcached()
+ elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
+ self._init_lmdb()
+
+ scale = self.opt['scale']
+ GT_size = self.opt['GT_size']
+ key = self.paths_GT[index]
+ name_a, name_b = key.split('_')
+ #### get the GT image (as the center frame)
+ if self.data_type == 'mc':
+ img_GT = self._read_img_mc(osp.join(self.GT_root, name_a, name_b, '4.png'))
+ img_GT = img_GT.astype(np.float32) / 255.
+ elif self.data_type == 'lmdb':
+ img_GT = util.read_img(self.GT_env, key + '_4', (3, 256, 448))
+ else:
+ img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b, 'im4.png'))
+
+ #### get LQ images
+ LQ_size_tuple = (3, 64, 112) if self.LR_input else (3, 256, 448)
+ img_LQ_l = []
+ for v in self.LQ_frames_list:
+ if self.data_type == 'mc':
+ img_LQ = self._read_img_mc(
+ osp.join(self.LQ_root, name_a, name_b, '{}.png'.format(v)))
+ img_LQ = img_LQ.astype(np.float32) / 255.
+ elif self.data_type == 'lmdb':
+ img_LQ = util.read_img(self.LQ_env, key + '_{}'.format(v), LQ_size_tuple)
+ else:
+ img_LQ = util.read_img(None,
+ osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v)))
+ img_LQ_l.append(img_LQ)
+
+ if self.opt['phase'] == 'train':
+ C, H, W = LQ_size_tuple # LQ size
+ # randomly crop
+ if self.LR_input:
+ LQ_size = GT_size // scale
+ rnd_h = random.randint(0, max(0, H - LQ_size))
+ rnd_w = random.randint(0, max(0, W - LQ_size))
+ img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
+ rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
+ img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
+ else:
+ rnd_h = random.randint(0, max(0, H - GT_size))
+ rnd_w = random.randint(0, max(0, W - GT_size))
+ img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
+ img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
+
+ # augmentation - flip, rotate
+ img_LQ_l.append(img_GT)
+ rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
+ img_LQ_l = rlt[0:-1]
+ img_GT = rlt[-1]
+
+ # stack LQ images to NHWC, N is the frame number
+ img_LQs = np.stack(img_LQ_l, axis=0)
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_GT = img_GT[:, :, [2, 1, 0]]
+ img_LQs = img_LQs[:, :, :, [2, 1, 0]]
+ img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
+ img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
+ (0, 3, 1, 2)))).float()
+ return {'LQs': img_LQs, 'GT': img_GT, 'key': key}
+
+ def __len__(self):
+ return len(self.paths_GT)
diff --git a/codes/data/__init__.py b/codes/data/__init__.py
new file mode 100644
index 00000000..ce10a834
--- /dev/null
+++ b/codes/data/__init__.py
@@ -0,0 +1,49 @@
+"""create dataset and dataloader"""
+import logging
+import torch
+import torch.utils.data
+
+
+def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
+ phase = dataset_opt['phase']
+ if phase == 'train':
+ if opt['dist']:
+ world_size = torch.distributed.get_world_size()
+ num_workers = dataset_opt['n_workers']
+ assert dataset_opt['batch_size'] % world_size == 0
+ batch_size = dataset_opt['batch_size'] // world_size
+ shuffle = False
+ else:
+ num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
+ batch_size = dataset_opt['batch_size']
+ shuffle = True
+ return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
+ num_workers=num_workers, sampler=sampler, drop_last=True,
+ pin_memory=False)
+ else:
+ return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
+ pin_memory=False)
+
+
+def create_dataset(dataset_opt):
+ mode = dataset_opt['mode']
+ # datasets for image restoration
+ if mode == 'LQ':
+ from data.LQ_dataset import LQDataset as D
+ elif mode == 'LQGT':
+ from data.LQGT_dataset import LQGTDataset as D
+ # datasets for video restoration
+ elif mode == 'REDS':
+ from data.REDS_dataset import REDSDataset as D
+ elif mode == 'Vimeo90K':
+ from data.Vimeo90K_dataset import Vimeo90KDataset as D
+ elif mode == 'video_test':
+ from data.video_test_dataset import VideoTestDataset as D
+ else:
+ raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
+ dataset = D(dataset_opt)
+
+ logger = logging.getLogger('base')
+ logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
+ dataset_opt['name']))
+ return dataset
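+
+
+# Usage sketch (illustrative): given option dicts `opt` and `dataset_opt` parsed from a
+# training configuration,
+#   train_set = create_dataset(dataset_opt)
+#   train_loader = create_dataloader(train_set, dataset_opt, opt, sampler=None)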
diff --git a/codes/data/data_sampler.py b/codes/data/data_sampler.py
new file mode 100644
index 00000000..9c409418
--- /dev/null
+++ b/codes/data/data_sampler.py
@@ -0,0 +1,65 @@
+"""
+Modified from torch.utils.data.distributed.DistributedSampler
+Supports enlarging the dataset for *iteration-oriented* training, saving time when restarting the
+dataloader after each epoch.
+"""
+import math
+import torch
+from torch.utils.data.sampler import Sampler
+import torch.distributed as dist
+
+
+class DistIterSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+
+ .. note::
+ Dataset is assumed to be of constant size.
+
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+        rank (optional): Rank of the current process within num_replicas.
+        ratio (optional): Enlarging ratio of the dataset length (default: 100).
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(self.total_size, generator=g).tolist()
+
+ dsize = len(self.dataset)
+ indices = [v % dsize for v in indices]
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
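+
+
+# Usage sketch (illustrative): for distributed, iteration-oriented training, build the
+# sampler from the current world size and rank, hand it to the dataloader, and call
+# set_epoch() when a deterministic reshuffle is desired.
+#   sampler = DistIterSampler(train_set, num_replicas=world_size, rank=rank)
+#   sampler.set_epoch(epoch)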
diff --git a/codes/data/util.py b/codes/data/util.py
new file mode 100644
index 00000000..516c6ae9
--- /dev/null
+++ b/codes/data/util.py
@@ -0,0 +1,543 @@
+import os
+import math
+import pickle
+import random
+import numpy as np
+import glob
+import torch
+import cv2
+
+####################
+# Files & IO
+####################
+
+###################### get image path list ######################
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def _get_paths_from_images(path):
+ """get image path list from image folder"""
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+
+
+def _get_paths_from_lmdb(dataroot):
+ """get image path list from lmdb meta info"""
+ meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
+ paths = meta_info['keys']
+ sizes = meta_info['resolution']
+ if len(sizes) == 1:
+ sizes = sizes * len(paths)
+ return paths, sizes
+
+
+def get_image_paths(data_type, dataroot):
+ """get image path list
+ support lmdb or image files"""
+ paths, sizes = None, None
+ if dataroot is not None:
+ if data_type == 'lmdb':
+ paths, sizes = _get_paths_from_lmdb(dataroot)
+ elif data_type == 'img':
+ paths = sorted(_get_paths_from_images(dataroot))
+ else:
+ raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
+ return paths, sizes
+
+
+def glob_file_list(root):
+ return sorted(glob.glob(os.path.join(root, '*')))
+
+
+###################### read images ######################
+def _read_img_lmdb(env, key, size):
+ """read image from lmdb with key (w/ and w/o fixed size)
+ size: (C, H, W) tuple"""
+ with env.begin(write=False) as txn:
+ buf = txn.get(key.encode('ascii'))
+ img_flat = np.frombuffer(buf, dtype=np.uint8)
+ C, H, W = size
+ img = img_flat.reshape(H, W, C)
+ return img
+
+
+def read_img(env, path, size=None):
+ """read image by cv2 or from lmdb
+ return: Numpy float32, HWC, BGR, [0,1]"""
+ if env is None: # img
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ else:
+ img = _read_img_lmdb(env, path, size)
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ # some images have 4 channels
+ if img.shape[2] > 3:
+ img = img[:, :, :3]
+ return img
+
+
+def read_img_seq(path):
+ """Read a sequence of images from a given folder path
+ Args:
+ path (list/str): list of image paths/image folder path
+
+ Returns:
+ imgs (Tensor): size (T, C, H, W), RGB, [0, 1]
+ """
+ if type(path) is list:
+ img_path_l = path
+ else:
+ img_path_l = sorted(glob.glob(os.path.join(path, '*')))
+ img_l = [read_img(None, v) for v in img_path_l]
+ # stack to Torch tensor
+ imgs = np.stack(img_l, axis=0)
+ imgs = imgs[:, :, :, [2, 1, 0]]
+ imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
+ return imgs
+
+
+def index_generation(crt_i, max_n, N, padding='reflection'):
+ """Generate an index list for reading N frames from a sequence of images
+ Args:
+ crt_i (int): current center index
+ max_n (int): max number of the sequence of images (calculated from 1)
+ N (int): reading N frames
+ padding (str): padding mode, one of replicate | reflection | new_info | circle
+ Example: crt_i = 0, N = 5
+ replicate: [0, 0, 0, 1, 2]
+ reflection: [2, 1, 0, 1, 2]
+ new_info: [4, 3, 0, 1, 2]
+ circle: [3, 4, 0, 1, 2]
+
+ Returns:
+ return_l (list [int]): a list of indexes
+ """
+ max_n = max_n - 1
+ n_pad = N // 2
+ return_l = []
+
+ for i in range(crt_i - n_pad, crt_i + n_pad + 1):
+ if i < 0:
+ if padding == 'replicate':
+ add_idx = 0
+ elif padding == 'reflection':
+ add_idx = -i
+ elif padding == 'new_info':
+ add_idx = (crt_i + n_pad) + (-i)
+ elif padding == 'circle':
+ add_idx = N + i
+ else:
+ raise ValueError('Wrong padding mode')
+ elif i > max_n:
+ if padding == 'replicate':
+ add_idx = max_n
+ elif padding == 'reflection':
+ add_idx = max_n * 2 - i
+ elif padding == 'new_info':
+ add_idx = (crt_i - n_pad) - (i - max_n)
+ elif padding == 'circle':
+ add_idx = i - N
+ else:
+ raise ValueError('Wrong padding mode')
+ else:
+ add_idx = i
+ return_l.append(add_idx)
+ return return_l
+
+
+####################
+# image processing
+# process on numpy image
+####################
+
+
+def augment(img_list, hflip=True, rot=True):
+ """horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ return [_augment(img) for img in img_list]
+
+
+def augment_flow(img_list, flow_list, hflip=True, rot=True):
+ """horizontal flip OR rotate (0, 90, 180, 270 degrees) with flows"""
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ def _augment_flow(flow):
+ if hflip:
+ flow = flow[:, ::-1, :]
+ flow[:, :, 0] *= -1
+ if vflip:
+ flow = flow[::-1, :, :]
+ flow[:, :, 1] *= -1
+ if rot90:
+ flow = flow.transpose(1, 0, 2)
+ flow = flow[:, :, [1, 0]]
+ return flow
+
+ rlt_img_list = [_augment(img) for img in img_list]
+ rlt_flow_list = [_augment_flow(flow) for flow in flow_list]
+
+ return rlt_img_list, rlt_flow_list
+
+
+def channel_convert(in_c, tar_type, img_list):
+ """conversion among BGR, gray and y"""
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in gray_list]
+ elif in_c == 3 and tar_type == 'y': # BGR to y
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in y_list]
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+ else:
+ return img_list
+
+
+def rgb2ycbcr(img, only_y=True):
+ """same as matlab rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ """
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+ """bgr version of rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ """
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+ """same as matlab ycbcr2rgb
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ """
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def modcrop(img_in, scale):
+ """img_in: Numpy, HWC or HW"""
+ img = np.copy(img_in)
+ if img.ndim == 2:
+ H, W = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r]
+ elif img.ndim == 3:
+ H, W, C = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r, :]
+ else:
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+ return img
+
+
+####################
+# Functions
+####################
+
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * (
+ (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((
+ (absx > 1) * (absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ if (scale < 1) and (antialiasing):
+        # Use a modified kernel to simultaneously interpolate and antialias; larger kernel width
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ P = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+ 1, P).expand(out_length, P)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+ # apply cubic kernel
+ if (scale < 1) and (antialiasing):
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, P)
+
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, P - 2)
+ weights = weights.narrow(1, 1, P - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, P - 2)
+ weights = weights.narrow(1, 0, P - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+def imresize(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: CHW RGB [0,1]
+ # output: CHW RGB [0,1] w/o round
+
+ in_C, in_H, in_W = img.size()
+ _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:, :sym_len_Hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_He:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+ out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+ out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_Ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_We:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
+ out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
+ out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])
+
+ return out_2
+
+
+def imresize_np(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: Numpy, HWC BGR [0,1]
+ # output: HWC BGR [0,1] w/o round
+ img = torch.from_numpy(img)
+
+ in_H, in_W, in_C = img.size()
+ _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:sym_len_Hs, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[-sym_len_He:, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
+ out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
+ out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :sym_len_Ws, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, -sym_len_We:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
+ out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
+ out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])
+
+ return out_2.numpy()
+
+
+if __name__ == '__main__':
+ # test imresize function
+ # read images
+ img = cv2.imread('test.png')
+ img = img * 1.0 / 255
+ img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
+ # imresize
+ scale = 1 / 4
+ import time
+ total_time = 0
+ for i in range(10):
+ start_time = time.time()
+ rlt = imresize(img, scale, antialiasing=True)
+ use_time = time.time() - start_time
+ total_time += use_time
+ print('average time: {}'.format(total_time / 10))
+
+ import torchvision.utils
+ torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0,
+ normalize=False)
diff --git a/codes/data/video_test_dataset.py b/codes/data/video_test_dataset.py
new file mode 100644
index 00000000..ce891958
--- /dev/null
+++ b/codes/data/video_test_dataset.py
@@ -0,0 +1,84 @@
+import os.path as osp
+import torch
+import torch.utils.data as data
+import data.util as util
+
+
+class VideoTestDataset(data.Dataset):
+ """
+ A video test dataset. Support:
+ Vid4
+ REDS4
+ Vimeo90K-Test
+
+ no need to prepare LMDB files
+ """
+
+ def __init__(self, opt):
+ super(VideoTestDataset, self).__init__()
+ self.opt = opt
+ self.cache_data = opt['cache_data']
+ self.half_N_frames = opt['N_frames'] // 2
+ self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
+ self.data_type = self.opt['data_type']
+ self.data_info = {'path_LQ': [], 'path_GT': [], 'folder': [], 'idx': [], 'border': []}
+ if self.data_type == 'lmdb':
+ raise ValueError('No need to use LMDB during validation/test.')
+ #### Generate data info and cache data
+ self.imgs_LQ, self.imgs_GT = {}, {}
+ if opt['name'].lower() in ['vid4', 'reds4']:
+ subfolders_LQ = util.glob_file_list(self.LQ_root)
+ subfolders_GT = util.glob_file_list(self.GT_root)
+ for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT):
+ subfolder_name = osp.basename(subfolder_GT)
+ img_paths_LQ = util.glob_file_list(subfolder_LQ)
+ img_paths_GT = util.glob_file_list(subfolder_GT)
+ max_idx = len(img_paths_LQ)
+ assert max_idx == len(
+ img_paths_GT), 'Different number of images in LQ and GT folders'
+ self.data_info['path_LQ'].extend(img_paths_LQ)
+ self.data_info['path_GT'].extend(img_paths_GT)
+ self.data_info['folder'].extend([subfolder_name] * max_idx)
+ for i in range(max_idx):
+ self.data_info['idx'].append('{}/{}'.format(i, max_idx))
+ border_l = [0] * max_idx
+ for i in range(self.half_N_frames):
+ border_l[i] = 1
+ border_l[max_idx - i - 1] = 1
+ self.data_info['border'].extend(border_l)
+
+ if self.cache_data:
+ self.imgs_LQ[subfolder_name] = util.read_img_seq(img_paths_LQ)
+ self.imgs_GT[subfolder_name] = util.read_img_seq(img_paths_GT)
+ elif opt['name'].lower() in ['vimeo90k-test']:
+ pass # TODO
+ else:
+ raise ValueError(
+                'Unsupported video test dataset. Supported: Vid4, REDS4, Vimeo90K-Test.')
+
+ def __getitem__(self, index):
+ # path_LQ = self.data_info['path_LQ'][index]
+ # path_GT = self.data_info['path_GT'][index]
+ folder = self.data_info['folder'][index]
+ idx, max_idx = self.data_info['idx'][index].split('/')
+ idx, max_idx = int(idx), int(max_idx)
+ border = self.data_info['border'][index]
+
+ if self.cache_data:
+ select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'],
+ padding=self.opt['padding'])
+ imgs_LQ = self.imgs_LQ[folder].index_select(0, torch.LongTensor(select_idx))
+ img_GT = self.imgs_GT[folder][idx]
+ else:
+ pass # TODO
+
+ return {
+ 'LQs': imgs_LQ,
+ 'GT': img_GT,
+ 'folder': folder,
+ 'idx': self.data_info['idx'][index],
+ 'border': border
+ }
+
+ def __len__(self):
+ return len(self.data_info['path_GT'])
diff --git a/codes/data_scripts/create_lmdb.py b/codes/data_scripts/create_lmdb.py
new file mode 100644
index 00000000..c44b5c78
--- /dev/null
+++ b/codes/data_scripts/create_lmdb.py
@@ -0,0 +1,411 @@
+"""Create lmdb files for [General images (291 images/DIV2K) | Vimeo90K | REDS] training datasets"""
+
+import sys
+import os.path as osp
+import glob
+import pickle
+from multiprocessing import Pool
+import numpy as np
+import lmdb
+import cv2
+
+sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
+import data.util as data_util # noqa: E402
+import utils.util as util # noqa: E402
+
+
+def main():
+ dataset = 'DIV2K_demo' # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo |test
+ mode = 'GT' # used for vimeo90k and REDS datasets
+ # vimeo90k: GT | LR | flow
+ # REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp
+ # train_sharp_flowx4
+ if dataset == 'vimeo90k':
+ vimeo90k(mode)
+ elif dataset == 'REDS':
+ REDS(mode)
+ elif dataset == 'general':
+ opt = {}
+ opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
+ opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
+ opt['name'] = 'DIV2K800_sub_GT'
+ general_image_folder(opt)
+ elif dataset == 'DIV2K_demo':
+ opt = {}
+ ## GT
+ opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
+ opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
+ opt['name'] = 'DIV2K800_sub_GT'
+ general_image_folder(opt)
+ ## LR
+ opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
+ opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
+ opt['name'] = 'DIV2K800_sub_bicLRx4'
+ general_image_folder(opt)
+ elif dataset == 'test':
+ test_lmdb('../../datasets/REDS/train_sharp_wval.lmdb', 'REDS')
+
+
+def read_image_worker(path, key):
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ return (key, img)
+
+
+def general_image_folder(opt):
+ """Create lmdb for general image folders
+ Users should define the keys, such as: '0321_s035' for DIV2K sub-images
+ If all the images have the same resolution, it will only store one copy of resolution info.
+ Otherwise, it will store every resolution info.
+ """
+ #### configurations
+    read_all_imgs = False  # whether to read all images into memory with multiprocessing
+    # Set False to use limited memory
+ BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
+ n_thread = 40
+ ########################################################
+ img_folder = opt['img_folder']
+ lmdb_save_path = opt['lmdb_save_path']
+ meta_info = {'name': opt['name']}
+ if not lmdb_save_path.endswith('.lmdb'):
+        raise ValueError("lmdb_save_path must end with '.lmdb'.")
+ if osp.exists(lmdb_save_path):
+ print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
+ sys.exit(1)
+
+ #### read all the image paths to a list
+ print('Reading image path list ...')
+ all_img_list = sorted(glob.glob(osp.join(img_folder, '*')))
+ keys = []
+ for img_path in all_img_list:
+ keys.append(osp.splitext(osp.basename(img_path))[0])
+
+ if read_all_imgs:
+ #### read all images to memory (multiprocessing)
+ dataset = {} # store all image data. list cannot keep the order, use dict
+ print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
+ pbar = util.ProgressBar(len(all_img_list))
+
+ def mycallback(arg):
+ '''get the image data and update pbar'''
+ key = arg[0]
+ dataset[key] = arg[1]
+ pbar.update('Reading {}'.format(key))
+
+ pool = Pool(n_thread)
+ for path, key in zip(all_img_list, keys):
+ pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
+ pool.close()
+ pool.join()
+ print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
+
+ #### create lmdb environment
+ data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
+ print('data size per image is: ', data_size_per_img)
+ data_size = data_size_per_img * len(all_img_list)
+ env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
+
+ #### write data to lmdb
+ pbar = util.ProgressBar(len(all_img_list))
+ txn = env.begin(write=True)
+ resolutions = []
+ for idx, (path, key) in enumerate(zip(all_img_list, keys)):
+ pbar.update('Write {}'.format(key))
+ key_byte = key.encode('ascii')
+ data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ if data.ndim == 2:
+ H, W = data.shape
+ C = 1
+ else:
+ H, W, C = data.shape
+ txn.put(key_byte, data)
+ resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
+ if not read_all_imgs and idx % BATCH == 0:
+ txn.commit()
+ txn = env.begin(write=True)
+ txn.commit()
+ env.close()
+ print('Finish writing lmdb.')
+
+ #### create meta information
+ # check whether all the images are the same size
+ assert len(keys) == len(resolutions)
+ if len(set(resolutions)) <= 1:
+ meta_info['resolution'] = [resolutions[0]]
+ meta_info['keys'] = keys
+ print('All images have the same resolution. Simplify the meta info.')
+ else:
+ meta_info['resolution'] = resolutions
+ meta_info['keys'] = keys
+ print('Not all images have the same resolution. Save meta info for each image.')
+
+ pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
+ print('Finish creating lmdb meta info.')
+
+
+def vimeo90k(mode):
+ """Create lmdb for the Vimeo90K dataset, each image with a fixed size
+ GT: [3, 256, 448]
+ Now only need the 4th frame, e.g., 00001_0001_4
+ LR: [3, 64, 112]
+ 1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
+ key:
+ Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001
+
+ flow: downsampled flow: [3, 360, 320], keys: 00001_0001_4_[p3, p2, p1, n1, n2, n3]
+ Each flow is calculated with GT images by PWCNet and then downsampled by 1/4
+ Flow map is quantized by mmcv and saved in png format
+ """
+ #### configurations
+    read_all_imgs = False  # whether to read all images into memory with multiprocessing
+    # Set False to use limited memory
+ BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
+ if mode == 'GT':
+ img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
+ lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
+ txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
+ H_dst, W_dst = 256, 448
+ elif mode == 'LR':
+ img_folder = '../../datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences'
+ lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
+ txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
+ H_dst, W_dst = 64, 112
+ elif mode == 'flow':
+ img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences_flowx4'
+ lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_flowx4.lmdb'
+ txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
+ H_dst, W_dst = 128, 112
+ else:
+ raise ValueError('Wrong dataset mode: {}'.format(mode))
+ n_thread = 40
+ ########################################################
+ if not lmdb_save_path.endswith('.lmdb'):
+        raise ValueError("lmdb_save_path must end with '.lmdb'.")
+ if osp.exists(lmdb_save_path):
+ print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
+ sys.exit(1)
+
+ #### read all the image paths to a list
+ print('Reading image path list ...')
+ with open(txt_file) as f:
+ train_l = f.readlines()
+ train_l = [v.strip() for v in train_l]
+ all_img_list = []
+ keys = []
+ for line in train_l:
+ folder = line.split('/')[0]
+ sub_folder = line.split('/')[1]
+ all_img_list.extend(glob.glob(osp.join(img_folder, folder, sub_folder, '*')))
+ if mode == 'flow':
+ for j in range(1, 4):
+ keys.append('{}_{}_4_n{}'.format(folder, sub_folder, j))
+ keys.append('{}_{}_4_p{}'.format(folder, sub_folder, j))
+ else:
+ for j in range(7):
+ keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
+ all_img_list = sorted(all_img_list)
+ keys = sorted(keys)
+ if mode == 'GT': # only read the 4th frame for the GT mode
+ print('Only keep the 4th frame.')
+ all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
+ keys = [v for v in keys if v.endswith('_4')]
+
+ if read_all_imgs:
+ #### read all images to memory (multiprocessing)
+        dataset = {}  # store all image data; use a dict so results can be matched by key
+ print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
+ pbar = util.ProgressBar(len(all_img_list))
+
+ def mycallback(arg):
+ """get the image data and update pbar"""
+ key = arg[0]
+ dataset[key] = arg[1]
+ pbar.update('Reading {}'.format(key))
+
+ pool = Pool(n_thread)
+ for path, key in zip(all_img_list, keys):
+ pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
+ pool.close()
+ pool.join()
+ print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
+
+    #### create lmdb environment and write data
+ data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
+ print('data size per image is: ', data_size_per_img)
+ data_size = data_size_per_img * len(all_img_list)
+ env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
+ txn = env.begin(write=True)
+ pbar = util.ProgressBar(len(all_img_list))
+ for idx, (path, key) in enumerate(zip(all_img_list, keys)):
+ pbar.update('Write {}'.format(key))
+ key_byte = key.encode('ascii')
+ data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ if 'flow' in mode:
+ H, W = data.shape
+ assert H == H_dst and W == W_dst, 'different shape.'
+ else:
+ H, W, C = data.shape
+ assert H == H_dst and W == W_dst and C == 3, 'different shape.'
+ txn.put(key_byte, data)
+ if not read_all_imgs and idx % BATCH == 0:
+ txn.commit()
+ txn = env.begin(write=True)
+ txn.commit()
+ env.close()
+ print('Finish writing lmdb.')
+
+ #### create meta information
+ meta_info = {}
+ if mode == 'GT':
+ meta_info['name'] = 'Vimeo90K_train_GT'
+ elif mode == 'LR':
+ meta_info['name'] = 'Vimeo90K_train_LR'
+ elif mode == 'flow':
+ meta_info['name'] = 'Vimeo90K_train_flowx4'
+ channel = 1 if 'flow' in mode else 3
+ meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
+ key_set = set()
+ for key in keys:
+ if mode == 'flow':
+ a, b, _, _ = key.split('_')
+ else:
+ a, b, _ = key.split('_')
+ key_set.add('{}_{}'.format(a, b))
+ meta_info['keys'] = list(key_set)
+ pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
+ print('Finish creating lmdb meta info.')
+
+
+def REDS(mode):
+ """Create lmdb for the REDS dataset, each image with a fixed size
+ GT: [3, 720, 1280], key: 000_00000000
+ LR: [3, 180, 320], key: 000_00000000
+ key: 000_00000000
+
+ flow: downsampled flow: [3, 360, 320], keys: 000_00000005_[p2, p1, n1, n2]
+ Each flow is calculated with the GT images by PWCNet and then downsampled by 1/4
+ Flow map is quantized by mmcv and saved in png format
+ """
+ #### configurations
+    read_all_imgs = False  # whether to read all images into memory with multiprocessing
+    # Set to False to limit memory usage
+    BATCH = 5000  # if read_all_imgs is False, commit the lmdb transaction every BATCH images
+ if mode == 'train_sharp':
+ img_folder = '../../datasets/REDS/train_sharp'
+ lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb'
+ H_dst, W_dst = 720, 1280
+ elif mode == 'train_sharp_bicubic':
+ img_folder = '../../datasets/REDS/train_sharp_bicubic'
+ lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
+ H_dst, W_dst = 180, 320
+ elif mode == 'train_blur_bicubic':
+ img_folder = '../../datasets/REDS/train_blur_bicubic'
+ lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb'
+ H_dst, W_dst = 180, 320
+ elif mode == 'train_blur':
+ img_folder = '../../datasets/REDS/train_blur'
+ lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb'
+ H_dst, W_dst = 720, 1280
+ elif mode == 'train_blur_comp':
+ img_folder = '../../datasets/REDS/train_blur_comp'
+ lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb'
+ H_dst, W_dst = 720, 1280
+ elif mode == 'train_sharp_flowx4':
+ img_folder = '../../datasets/REDS/train_sharp_flowx4'
+ lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb'
+        H_dst, W_dst = 360, 320
+    else:
+        raise ValueError('Wrong dataset mode: {}'.format(mode))
+    n_thread = 40
+ ########################################################
+ if not lmdb_save_path.endswith('.lmdb'):
+ raise ValueError("lmdb_save_path must end with \'lmdb\'.")
+ if osp.exists(lmdb_save_path):
+ print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
+ sys.exit(1)
+
+ #### read all the image paths to a list
+ print('Reading image path list ...')
+ all_img_list = data_util._get_paths_from_images(img_folder)
+ keys = []
+ for img_path in all_img_list:
+ split_rlt = img_path.split('/')
+ folder = split_rlt[-2]
+ img_name = split_rlt[-1].split('.png')[0]
+ keys.append(folder + '_' + img_name)
+
+ if read_all_imgs:
+ #### read all images to memory (multiprocessing)
+        dataset = {}  # store all image data; use a dict so results can be matched by key
+ print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
+ pbar = util.ProgressBar(len(all_img_list))
+
+ def mycallback(arg):
+ '''get the image data and update pbar'''
+ key = arg[0]
+ dataset[key] = arg[1]
+ pbar.update('Reading {}'.format(key))
+
+ pool = Pool(n_thread)
+ for path, key in zip(all_img_list, keys):
+ pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
+ pool.close()
+ pool.join()
+ print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
+
+ #### create lmdb environment
+ data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
+ print('data size per image is: ', data_size_per_img)
+ data_size = data_size_per_img * len(all_img_list)
+ env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
+
+ #### write data to lmdb
+ pbar = util.ProgressBar(len(all_img_list))
+ txn = env.begin(write=True)
+ for idx, (path, key) in enumerate(zip(all_img_list, keys)):
+ pbar.update('Write {}'.format(key))
+ key_byte = key.encode('ascii')
+ data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ if 'flow' in mode:
+ H, W = data.shape
+ assert H == H_dst and W == W_dst, 'different shape.'
+ else:
+ H, W, C = data.shape
+ assert H == H_dst and W == W_dst and C == 3, 'different shape.'
+ txn.put(key_byte, data)
+ if not read_all_imgs and idx % BATCH == 0:
+ txn.commit()
+ txn = env.begin(write=True)
+ txn.commit()
+ env.close()
+ print('Finish writing lmdb.')
+
+ #### create meta information
+ meta_info = {}
+ meta_info['name'] = 'REDS_{}_wval'.format(mode)
+ channel = 1 if 'flow' in mode else 3
+ meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
+ meta_info['keys'] = keys
+ pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
+ print('Finish creating lmdb meta info.')
+
+
+def test_lmdb(dataroot, dataset='REDS'):
+ env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
+ meta_info = pickle.load(open(osp.join(dataroot, 'meta_info.pkl'), "rb"))
+ print('Name: ', meta_info['name'])
+ print('Resolution: ', meta_info['resolution'])
+ print('# keys: ', len(meta_info['keys']))
+ # read one image
+ if dataset == 'vimeo90k':
+ key = '00001_0001_4'
+ else:
+ key = '000_00000000'
+ print('Reading {} for test.'.format(key))
+ with env.begin(write=False) as txn:
+ buf = txn.get(key.encode('ascii'))
+ img_flat = np.frombuffer(buf, dtype=np.uint8)
+ C, H, W = [int(s) for s in meta_info['resolution'].split('_')]
+ img = img_flat.reshape(H, W, C)
+ cv2.imwrite('test.png', img)
+
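+# Example usage (the paths below are the lmdb folders created by the functions above):
+#   test_lmdb('../../datasets/REDS/train_sharp_wval.lmdb', dataset='REDS')
+#   test_lmdb('../../datasets/vimeo90k/vimeo90k_train_GT.lmdb', dataset='vimeo90k')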
+
+if __name__ == "__main__":
+ main()
diff --git a/codes/data_scripts/extract_subimages.py b/codes/data_scripts/extract_subimages.py
new file mode 100644
index 00000000..a8185df2
--- /dev/null
+++ b/codes/data_scripts/extract_subimages.py
@@ -0,0 +1,141 @@
+"""A multi-thread tool to crop large images to sub-images for faster IO."""
+import os
+import os.path as osp
+import sys
+from multiprocessing import Pool
+import numpy as np
+import cv2
+from PIL import Image
+sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
+from utils.util import ProgressBar # noqa: E402
+import data.util as data_util # noqa: E402
+
+
+def main():
+ mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs)
+ opt = {}
+ opt['n_thread'] = 20
+ opt['compression_level'] = 3 # 3 is the default value in cv2
+ # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
+    # compression time. If raw images are read during training, use 0 for faster IO.
+ if mode == 'single':
+ opt['input_folder'] = '../../datasets/DIV2K/DIV2K_train_HR'
+ opt['save_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
+ opt['crop_sz'] = 480 # the size of each sub-image
+ opt['step'] = 240 # step of the sliding crop window
+ opt['thres_sz'] = 48 # size threshold
+        extract_single(opt)
+ elif mode == 'pair':
+ GT_folder = '../../datasets/DIV2K/DIV2K_train_HR'
+ LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
+ save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub'
+ save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
+ scale_ratio = 4
+ crop_sz = 480 # the size of each sub-image (GT)
+ step = 240 # step of the sliding crop window (GT)
+ thres_sz = 48 # size threshold
+ ########################################################################
+ # check that all the GT and LR images have correct scale ratio
+ img_GT_list = data_util._get_paths_from_images(GT_folder)
+ img_LR_list = data_util._get_paths_from_images(LR_folder)
+ assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.'
+ for path_GT, path_LR in zip(img_GT_list, img_LR_list):
+ img_GT = Image.open(path_GT)
+ img_LR = Image.open(path_LR)
+ w_GT, h_GT = img_GT.size
+ w_LR, h_LR = img_LR.size
+        assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X the LR width [{:d}] for {:s}.'.format(  # noqa: E501
+            w_GT, scale_ratio, w_LR, path_GT)
+        assert h_GT / h_LR == scale_ratio, 'GT height [{:d}] is not {:d}X the LR height [{:d}] for {:s}.'.format(  # noqa: E501
+            h_GT, scale_ratio, h_LR, path_GT)
+ # check crop size, step and threshold size
+ assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format(
+ scale_ratio)
+ assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio)
+ assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format(
+ scale_ratio)
+ print('process GT...')
+ opt['input_folder'] = GT_folder
+ opt['save_folder'] = save_GT_folder
+ opt['crop_sz'] = crop_sz
+ opt['step'] = step
+ opt['thres_sz'] = thres_sz
+        extract_single(opt)
+ print('process LR...')
+ opt['input_folder'] = LR_folder
+ opt['save_folder'] = save_LR_folder
+ opt['crop_sz'] = crop_sz // scale_ratio
+ opt['step'] = step // scale_ratio
+ opt['thres_sz'] = thres_sz // scale_ratio
+        extract_single(opt)
+ assert len(data_util._get_paths_from_images(save_GT_folder)) == len(
+ data_util._get_paths_from_images(
+ save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.'
+ else:
+ raise ValueError('Wrong mode.')
+
+
+def extract_single(opt):
+ input_folder = opt['input_folder']
+ save_folder = opt['save_folder']
+ if not osp.exists(save_folder):
+ os.makedirs(save_folder)
+ print('mkdir [{:s}] ...'.format(save_folder))
+ else:
+ print('Folder [{:s}] already exists. Exit...'.format(save_folder))
+ sys.exit(1)
+ img_list = data_util._get_paths_from_images(input_folder)
+
+ def update(arg):
+ pbar.update(arg)
+
+ pbar = ProgressBar(len(img_list))
+
+ pool = Pool(opt['n_thread'])
+ for path in img_list:
+ pool.apply_async(worker, args=(path, opt), callback=update)
+ pool.close()
+ pool.join()
+ print('All subprocesses done.')
+
+
+def worker(path, opt):
+ crop_sz = opt['crop_sz']
+ step = opt['step']
+ thres_sz = opt['thres_sz']
+ img_name = osp.basename(path)
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+
+ n_channels = len(img.shape)
+ if n_channels == 2:
+ h, w = img.shape
+ elif n_channels == 3:
+ h, w, c = img.shape
+ else:
+ raise ValueError('Wrong image shape - {}'.format(n_channels))
+
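+    # Slide a crop_sz window over the image with the given step; if the leftover border is
+    # larger than thres_sz, add one extra crop flush with the image edge so it is not discarded.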
+ h_space = np.arange(0, h - crop_sz + 1, step)
+ if h - (h_space[-1] + crop_sz) > thres_sz:
+ h_space = np.append(h_space, h - crop_sz)
+ w_space = np.arange(0, w - crop_sz + 1, step)
+ if w - (w_space[-1] + crop_sz) > thres_sz:
+ w_space = np.append(w_space, w - crop_sz)
+
+ index = 0
+ for x in h_space:
+ for y in w_space:
+ index += 1
+ if n_channels == 2:
+ crop_img = img[x:x + crop_sz, y:y + crop_sz]
+ else:
+ crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
+ crop_img = np.ascontiguousarray(crop_img)
+ cv2.imwrite(
+ osp.join(opt['save_folder'],
+ img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img,
+ [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
+ return 'Processing {:s} ...'.format(img_name)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/codes/data_scripts/generate_LR_Vimeo90K.m b/codes/data_scripts/generate_LR_Vimeo90K.m
new file mode 100644
index 00000000..acce7898
--- /dev/null
+++ b/codes/data_scripts/generate_LR_Vimeo90K.m
@@ -0,0 +1,49 @@
+function generate_LR_Vimeo90K()
+%% MATLAB code to generate bicubic-downsampled LR images for the Vimeo90K dataset
+
+up_scale = 4;
+mod_scale = 4;
+idx = 0;
+filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png');
+for i = 1 : length(filepaths)
+ [~,imname,ext] = fileparts(filepaths(i).name);
+ folder_path = filepaths(i).folder;
+ save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4');
+ if ~exist(save_LR_folder, 'dir')
+ mkdir(save_LR_folder);
+ end
+ if isempty(imname)
+ disp('Ignore . folder.');
+ elseif strcmp(imname, '.')
+ disp('Ignore .. folder.');
+ else
+ idx = idx + 1;
+ str_rlt = sprintf('%d\t%s.\n', idx, imname);
+ fprintf(str_rlt);
+ % read image
+ img = imread(fullfile(folder_path, [imname, ext]));
+ img = im2double(img);
+ % modcrop
+ img = modcrop(img, mod_scale);
+ % LR
+ im_LR = imresize(img, 1/up_scale, 'bicubic');
+ if exist('save_LR_folder', 'var')
+ imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
+ end
+ end
+end
+end
+
+%% modcrop
+function img = modcrop(img, modulo)
+if size(img,3) == 1
+ sz = size(img);
+ sz = sz - mod(sz, modulo);
+ img = img(1:sz(1), 1:sz(2));
+else
+ tmpsz = size(img);
+ sz = tmpsz(1:2);
+ sz = sz - mod(sz, modulo);
+ img = img(1:sz(1), 1:sz(2),:);
+end
+end
diff --git a/codes/data_scripts/generate_mod_LR_bic.m b/codes/data_scripts/generate_mod_LR_bic.m
new file mode 100644
index 00000000..05a9c61a
--- /dev/null
+++ b/codes/data_scripts/generate_mod_LR_bic.m
@@ -0,0 +1,82 @@
+function generate_mod_LR_bic()
+%% MATLAB code to generate mod-cropped images, bicubic-downsampled LR, and bicubic-upsampled images.
+
+%% set parameters
+% comment the unnecessary line
+input_folder = '../../datasets/DIV2K/DIV2K800';
+% save_mod_folder = '';
+save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4';
+% save_bic_folder = '';
+
+up_scale = 4;
+mod_scale = 4;
+
+if exist('save_mod_folder', 'var')
+ if exist(save_mod_folder, 'dir')
+        disp(['It will overwrite ', save_mod_folder]);
+ else
+ mkdir(save_mod_folder);
+ end
+end
+if exist('save_LR_folder', 'var')
+ if exist(save_LR_folder, 'dir')
+        disp(['It will overwrite ', save_LR_folder]);
+ else
+ mkdir(save_LR_folder);
+ end
+end
+if exist('save_bic_folder', 'var')
+ if exist(save_bic_folder, 'dir')
+        disp(['It will overwrite ', save_bic_folder]);
+ else
+ mkdir(save_bic_folder);
+ end
+end
+
+idx = 0;
+filepaths = dir(fullfile(input_folder,'*.*'));
+for i = 1 : length(filepaths)
+    [~,imname,ext] = fileparts(filepaths(i).name);
+ if isempty(imname)
+ disp('Ignore . folder.');
+ elseif strcmp(imname, '.')
+ disp('Ignore .. folder.');
+ else
+ idx = idx + 1;
+ str_rlt = sprintf('%d\t%s.\n', idx, imname);
+ fprintf(str_rlt);
+ % read image
+ img = imread(fullfile(input_folder, [imname, ext]));
+ img = im2double(img);
+ % modcrop
+ img = modcrop(img, mod_scale);
+ if exist('save_mod_folder', 'var')
+ imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
+ end
+ % LR
+ im_LR = imresize(img, 1/up_scale, 'bicubic');
+ if exist('save_LR_folder', 'var')
+ imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
+ end
+ % Bicubic
+ if exist('save_bic_folder', 'var')
+ im_B = imresize(im_LR, up_scale, 'bicubic');
+ imwrite(im_B, fullfile(save_bic_folder, [imname, '.png']));
+ end
+ end
+end
+end
+
+%% modcrop
+function img = modcrop(img, modulo)
+if size(img,3) == 1
+ sz = size(img);
+ sz = sz - mod(sz, modulo);
+ img = img(1:sz(1), 1:sz(2));
+else
+ tmpsz = size(img);
+ sz = tmpsz(1:2);
+ sz = sz - mod(sz, modulo);
+ img = img(1:sz(1), 1:sz(2),:);
+end
+end
diff --git a/codes/data_scripts/generate_mod_LR_bic.py b/codes/data_scripts/generate_mod_LR_bic.py
new file mode 100644
index 00000000..59b313a8
--- /dev/null
+++ b/codes/data_scripts/generate_mod_LR_bic.py
@@ -0,0 +1,81 @@
+import os
+import sys
+import cv2
+import numpy as np
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from data.util import imresize_np  # noqa: E402
+
+
+def generate_mod_LR_bic():
+ # set parameters
+ up_scale = 4
+ mod_scale = 4
+ # set data dir
+ sourcedir = '/data/datasets/img'
+ savedir = '/data/datasets/mod'
+
+ saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale))
+ saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
+ saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))
+
+ if not os.path.isdir(sourcedir):
+ print('Error: No source data found')
+        sys.exit(1)
+ if not os.path.isdir(savedir):
+ os.mkdir(savedir)
+
+ if not os.path.isdir(os.path.join(savedir, 'HR')):
+ os.mkdir(os.path.join(savedir, 'HR'))
+ if not os.path.isdir(os.path.join(savedir, 'LR')):
+ os.mkdir(os.path.join(savedir, 'LR'))
+ if not os.path.isdir(os.path.join(savedir, 'Bic')):
+ os.mkdir(os.path.join(savedir, 'Bic'))
+
+ if not os.path.isdir(saveHRpath):
+ os.mkdir(saveHRpath)
+ else:
+        print('It will overwrite ' + str(saveHRpath))
+
+ if not os.path.isdir(saveLRpath):
+ os.mkdir(saveLRpath)
+ else:
+        print('It will overwrite ' + str(saveLRpath))
+
+ if not os.path.isdir(saveBicpath):
+ os.mkdir(saveBicpath)
+ else:
+        print('It will overwrite ' + str(saveBicpath))
+
+ filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')]
+ num_files = len(filepaths)
+
+    # process each image: modcrop, bicubic downsample, and bicubic upsample
+ for i in range(num_files):
+ filename = filepaths[i]
+ print('No.{} -- Processing {}'.format(i, filename))
+ # read image
+ image = cv2.imread(os.path.join(sourcedir, filename))
+
+ width = int(np.floor(image.shape[1] / mod_scale))
+ height = int(np.floor(image.shape[0] / mod_scale))
+ # modcrop
+ if len(image.shape) == 3:
+ image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
+ else:
+ image_HR = image[0:mod_scale * height, 0:mod_scale * width]
+ # LR
+ image_LR = imresize_np(image_HR, 1 / up_scale, True)
+ # bic
+ image_Bic = imresize_np(image_LR, up_scale, True)
+
+ cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
+ cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
+ cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
+
+
+if __name__ == "__main__":
+ generate_mod_LR_bic()
diff --git a/codes/data_scripts/prepare_DIV2K_x4_dataset.sh b/codes/data_scripts/prepare_DIV2K_x4_dataset.sh
new file mode 100644
index 00000000..a53bd1f0
--- /dev/null
+++ b/codes/data_scripts/prepare_DIV2K_x4_dataset.sh
@@ -0,0 +1,42 @@
+#!/usr/bin/env bash
+
+echo "Prepare DIV2K X4 datasets..."
+cd ../../datasets
+mkdir DIV2K
+cd DIV2K
+
+#### Step 1
+echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..."
+# GT
+FOLDER=DIV2K_train_HR
+FILE=DIV2K_train_HR.zip
+if [ ! -d "$FOLDER" ]; then
+ if [ ! -f "$FILE" ]; then
+ echo "Downloading $FILE..."
+ wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
+ fi
+ unzip $FILE
+fi
+# LR
+FOLDER=DIV2K_train_LR_bicubic
+FILE=DIV2K_train_LR_bicubic_X4.zip
+if [ ! -d "$FOLDER" ]; then
+ if [ ! -f "$FILE" ]; then
+ echo "Downloading $FILE..."
+ wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
+ fi
+ unzip $FILE
+fi
+
+#### Step 2
+echo "Step 2: Rename the LR images..."
+cd ../../codes/data_scripts
+python rename.py
+
+#### Step 4
+echo "Step 4: Crop to sub-images..."
+python extract_subimages.py
+
+#### Step 5
+echo "Step5: Create LMDB files..."
+python create_lmdb.py
diff --git a/codes/data_scripts/regroup_REDS.py b/codes/data_scripts/regroup_REDS.py
new file mode 100644
index 00000000..7c8fa928
--- /dev/null
+++ b/codes/data_scripts/regroup_REDS.py
@@ -0,0 +1,11 @@
+import os
+import glob
+
+train_path = '/home/xtwang/datasets/REDS/train_sharp_bicubic/X4'
+val_path = '/home/xtwang/datasets/REDS/val_sharp_bicubic/X4'
+
+# copy the val clips into the train folder, shifting each folder index by 240 (the "wval" split)
+val_folders = glob.glob(os.path.join(val_path, '*'))
+for folder in val_folders:
+ new_folder_idx = '{:03d}'.format(int(folder.split('/')[-1]) + 240)
+ os.system('cp -r {} {}'.format(folder, os.path.join(train_path, new_folder_idx)))
diff --git a/codes/data_scripts/rename.py b/codes/data_scripts/rename.py
new file mode 100644
index 00000000..f8a19552
--- /dev/null
+++ b/codes/data_scripts/rename.py
@@ -0,0 +1,19 @@
+import os
+import glob
+
+
+def main():
+ folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
+ DIV2K(folder)
+ print('Finished.')
+
+
+def DIV2K(path):
+ img_path_l = glob.glob(os.path.join(path, '*'))
+ for img_path in img_path_l:
+ new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
+ os.rename(img_path, new_path)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/codes/data_scripts/test_dataloader.py b/codes/data_scripts/test_dataloader.py
new file mode 100644
index 00000000..5f580793
--- /dev/null
+++ b/codes/data_scripts/test_dataloader.py
@@ -0,0 +1,104 @@
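+"""Quick sanity check for the data loaders: build a dataset from a hard-coded option dict,
+iterate a few batches, and dump the LQ/GT tensors as image grids into ./tmp for inspection."""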
+import sys
+import os.path as osp
+import math
+import torchvision.utils
+
+sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
+from data import create_dataloader, create_dataset # noqa: E402
+from utils import util # noqa: E402
+
+
+def main():
+ dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub
+ opt = {}
+ opt['dist'] = False
+ opt['gpu_ids'] = [0]
+ if dataset == 'REDS':
+ opt['name'] = 'test_REDS'
+ opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb'
+ opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
+ opt['mode'] = 'REDS'
+ opt['N_frames'] = 5
+ opt['phase'] = 'train'
+ opt['use_shuffle'] = True
+ opt['n_workers'] = 8
+ opt['batch_size'] = 16
+ opt['GT_size'] = 256
+ opt['LQ_size'] = 64
+ opt['scale'] = 4
+ opt['use_flip'] = True
+ opt['use_rot'] = True
+ opt['interval_list'] = [1]
+ opt['random_reverse'] = False
+ opt['border_mode'] = False
+ opt['cache_keys'] = None
+ opt['data_type'] = 'lmdb' # img | lmdb | mc
+ elif dataset == 'Vimeo90K':
+ opt['name'] = 'test_Vimeo90K'
+ opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
+ opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
+ opt['mode'] = 'Vimeo90K'
+ opt['N_frames'] = 7
+ opt['phase'] = 'train'
+ opt['use_shuffle'] = True
+ opt['n_workers'] = 8
+ opt['batch_size'] = 16
+ opt['GT_size'] = 256
+ opt['LQ_size'] = 64
+ opt['scale'] = 4
+ opt['use_flip'] = True
+ opt['use_rot'] = True
+ opt['interval_list'] = [1]
+ opt['random_reverse'] = False
+ opt['border_mode'] = False
+ opt['cache_keys'] = None
+ opt['data_type'] = 'lmdb' # img | lmdb | mc
+ elif dataset == 'DIV2K800_sub':
+ opt['name'] = 'DIV2K800'
+ opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
+ opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
+ opt['mode'] = 'LQGT'
+ opt['phase'] = 'train'
+ opt['use_shuffle'] = True
+ opt['n_workers'] = 8
+ opt['batch_size'] = 16
+ opt['GT_size'] = 128
+ opt['scale'] = 4
+ opt['use_flip'] = True
+ opt['use_rot'] = True
+ opt['color'] = 'RGB'
+ opt['data_type'] = 'lmdb' # img | lmdb
+ else:
+        raise ValueError('Dataset [{}] is not supported in this test script.'.format(dataset))
+
+ util.mkdir('tmp')
+ train_set = create_dataset(opt)
+ train_loader = create_dataloader(train_set, opt, opt, None)
+ nrow = int(math.sqrt(opt['batch_size']))
+ padding = 2 if opt['phase'] == 'train' else 0
+
+ print('start...')
+ for i, data in enumerate(train_loader):
+ if i > 5:
+ break
+ print(i)
+ if dataset == 'REDS' or dataset == 'Vimeo90K':
+ LQs = data['LQs']
+ else:
+ LQ = data['LQ']
+ GT = data['GT']
+
+ if dataset == 'REDS' or dataset == 'Vimeo90K':
+ for j in range(LQs.size(1)):
+ torchvision.utils.save_image(LQs[:, j, :, :, :],
+ 'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow,
+ padding=padding, normalize=False)
+ else:
+ torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow,
+ padding=padding, normalize=False)
+ torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding,
+ normalize=False)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/codes/metrics/calculate_PSNR_SSIM.m b/codes/metrics/calculate_PSNR_SSIM.m
new file mode 100644
index 00000000..471de83f
--- /dev/null
+++ b/codes/metrics/calculate_PSNR_SSIM.m
@@ -0,0 +1,261 @@
+function calculate_PSNR_SSIM()
+
+% GT and SR folder
+folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5';
+folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5';
+scale = 4;
+suffix = ''; % suffix for SR images
+test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels
+if test_Y
+    fprintf('Testing Y channel.\n');
+else
+    fprintf('Testing RGB channels.\n');
+end
+filepaths = dir(fullfile(folder_GT, '*.png'));
+PSNR_all = zeros(1, length(filepaths));
+SSIM_all = zeros(1, length(filepaths));
+
+for idx_im = 1:length(filepaths)
+ im_name = filepaths(idx_im).name;
+ im_GT = imread(fullfile(folder_GT, im_name));
+ im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png']));
+
+ if test_Y % evaluate on Y channel in YCbCr color space
+ if size(im_GT, 3) == 3
+ im_GT_YCbCr = rgb2ycbcr(im2double(im_GT));
+ im_GT_in = im_GT_YCbCr(:,:,1);
+ im_SR_YCbCr = rgb2ycbcr(im2double(im_SR));
+ im_SR_in = im_SR_YCbCr(:,:,1);
+ else
+ im_GT_in = im2double(im_GT);
+ im_SR_in = im2double(im_SR);
+ end
+ else % evaluate on RGB channels
+ im_GT_in = im2double(im_GT);
+ im_SR_in = im2double(im_SR);
+ end
+
+ % calculate PSNR and SSIM
+ PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale);
+ SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale);
+ fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im));
+end
+
+fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all));
+end
+
+function res = calculate_PSNR(GT, SR, border)
+% remove border
+GT = GT(border+1:end-border, border+1:end-border, :);
+SR = SR(border+1:end-border, border+1:end-border, :);
+% calculate PSNR (assume values in [0,255])
+error = GT(:) - SR(:);
+mse = mean(error.^2);
+res = 10 * log10(255^2/mse);
+end
+
+function res = calculate_SSIM(GT, SR, border)
+GT = GT(border+1:end-border, border+1:end-border, :);
+SR = SR(border+1:end-border, border+1:end-border, :);
+% calculate SSIM
+mssim = zeros(1, size(SR, 3));
+for i = 1:size(SR,3)
+ [mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i));
+end
+res = mean(mssim);
+end
+
+function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L)
+
+%========================================================================
+%SSIM Index, Version 1.0
+%Copyright(c) 2003 Zhou Wang
+%All Rights Reserved.
+%
+%The author is with Howard Hughes Medical Institute, and Laboratory
+%for Computational Vision at Center for Neural Science and Courant
+%Institute of Mathematical Sciences, New York University.
+%
+%----------------------------------------------------------------------
+%Permission to use, copy, or modify this software and its documentation
+%for educational and research purposes only and without fee is hereby
+%granted, provided that this copyright notice and the original authors'
+%names appear on all copies and supporting documentation. This program
+%shall not be used, rewritten, or adapted as the basis of a commercial
+%software or hardware product without first obtaining permission of the
+%authors. The authors make no representations about the suitability of
+%this software for any purpose. It is provided "as is" without express
+%or implied warranty.
+%----------------------------------------------------------------------
+%
+%This is an implementation of the algorithm for calculating the
+%Structural SIMilarity (SSIM) index between two images. Please refer
+%to the following paper:
+%
+%Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
+%quality assessment: From error measurement to structural similarity"
+%IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004.
+%
+%Kindly report any suggestions or corrections to zhouwang@ieee.org
+%
+%----------------------------------------------------------------------
+%
+%Input : (1) img1: the first image being compared
+% (2) img2: the second image being compared
+% (3) K: constants in the SSIM index formula (see the above
+% reference). defualt value: K = [0.01 0.03]
+% (4) window: local window for statistics (see the above
+% reference). default widnow is Gaussian given by
+% window = fspecial('gaussian', 11, 1.5);
+% (5) L: dynamic range of the images. default: L = 255
+%
+%Output: (1) mssim: the mean SSIM index value between 2 images.
+% If one of the images being compared is regarded as
+% perfect quality, then mssim can be considered as the
+% quality measure of the other image.
+% If img1 = img2, then mssim = 1.
+% (2) ssim_map: the SSIM index map of the test image. The map
+% has a smaller size than the input images. The actual size:
+% size(img1) - size(window) + 1.
+%
+%Default Usage:
+% Given 2 test images img1 and img2, whose dynamic range is 0-255
+%
+% [mssim ssim_map] = ssim_index(img1, img2);
+%
+%Advanced Usage:
+% User defined parameters. For example
+%
+% K = [0.05 0.05];
+% window = ones(8);
+% L = 100;
+% [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
+%
+%See the results:
+%
+% mssim %Gives the mssim value
+% imshow(max(0, ssim_map).^4) %Shows the SSIM index map
+%
+%========================================================================
+
+
+if (nargin < 2 || nargin > 5)
+   mssim = -Inf;
+ ssim_map = -Inf;
+ return;
+end
+
+if (size(img1) ~= size(img2))
+   mssim = -Inf;
+ ssim_map = -Inf;
+ return;
+end
+
+[M, N] = size(img1);
+
+if (nargin == 2)
+ if ((M < 11) || (N < 11))
+       mssim = -Inf;
+ ssim_map = -Inf;
+ return
+ end
+ window = fspecial('gaussian', 11, 1.5); %
+ K(1) = 0.01; % default settings
+ K(2) = 0.03; %
+ L = 255; %
+end
+
+if (nargin == 3)
+ if ((M < 11) || (N < 11))
+       mssim = -Inf;
+ ssim_map = -Inf;
+ return
+ end
+ window = fspecial('gaussian', 11, 1.5);
+ L = 255;
+ if (length(K) == 2)
+ if (K(1) < 0 || K(2) < 0)
+           mssim = -Inf;
+ ssim_map = -Inf;
+ return;
+ end
+ else
+       mssim = -Inf;
+ ssim_map = -Inf;
+ return;
+ end
+end
+
+if (nargin == 4)
+ [H, W] = size(window);
+ if ((H*W) < 4 || (H > M) || (W > N))
+       mssim = -Inf;
+ ssim_map = -Inf;
+ return
+ end
+ L = 255;
+ if (length(K) == 2)
+ if (K(1) < 0 || K(2) < 0)
+           mssim = -Inf;
+ ssim_map = -Inf;
+ return;
+ end
+ else
+       mssim = -Inf;
+ ssim_map = -Inf;
+ return;
+ end
+end
+
+if (nargin == 5)
+ [H, W] = size(window);
+ if ((H*W) < 4 || (H > M) || (W > N))
+       mssim = -Inf;
+ ssim_map = -Inf;
+ return
+ end
+ if (length(K) == 2)
+ if (K(1) < 0 || K(2) < 0)
+           mssim = -Inf;
+ ssim_map = -Inf;
+ return;
+ end
+ else
+       mssim = -Inf;
+ ssim_map = -Inf;
+ return;
+ end
+end
+
+C1 = (K(1)*L)^2;
+C2 = (K(2)*L)^2;
+window = window/sum(sum(window));
+img1 = double(img1);
+img2 = double(img2);
+
+mu1 = filter2(window, img1, 'valid');
+mu2 = filter2(window, img2, 'valid');
+mu1_sq = mu1.*mu1;
+mu2_sq = mu2.*mu2;
+mu1_mu2 = mu1.*mu2;
+sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
+sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
+sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
+
+if (C1 > 0 && C2 > 0)
+ ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
+else
+ numerator1 = 2*mu1_mu2 + C1;
+ numerator2 = 2*sigma12 + C2;
+ denominator1 = mu1_sq + mu2_sq + C1;
+ denominator2 = sigma1_sq + sigma2_sq + C2;
+ ssim_map = ones(size(mu1));
+ index = (denominator1.*denominator2 > 0);
+ ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
+ index = (denominator1 ~= 0) & (denominator2 == 0);
+ ssim_map(index) = numerator1(index)./denominator1(index);
+end
+
+mssim = mean2(ssim_map);
+
+end
diff --git a/codes/metrics/calculate_PSNR_SSIM.py b/codes/metrics/calculate_PSNR_SSIM.py
new file mode 100644
index 00000000..fecc0310
--- /dev/null
+++ b/codes/metrics/calculate_PSNR_SSIM.py
@@ -0,0 +1,147 @@
+'''
+Calculate PSNR and SSIM.
+The results match MATLAB's implementation.
+'''
+import os
+import math
+import numpy as np
+import cv2
+import glob
+
+
+def main():
+ # Configurations
+
+ # GT - Ground-truth;
+ # Gen: Generated / Restored / Recovered images
+ folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'
+ folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'
+
+ crop_border = 4
+ suffix = '' # suffix for Gen images
+ test_Y = False # True: test Y channel only; False: test RGB channels
+
+ PSNR_all = []
+ SSIM_all = []
+ img_list = sorted(glob.glob(folder_GT + '/*'))
+
+ if test_Y:
+ print('Testing Y channel.')
+ else:
+ print('Testing RGB channels.')
+
+ for i, img_path in enumerate(img_list):
+ base_name = os.path.splitext(os.path.basename(img_path))[0]
+ im_GT = cv2.imread(img_path) / 255.
+ im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255.
+
+ if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space
+ im_GT_in = bgr2ycbcr(im_GT)
+ im_Gen_in = bgr2ycbcr(im_Gen)
+ else:
+ im_GT_in = im_GT
+ im_Gen_in = im_Gen
+
+ # crop borders
+ if im_GT_in.ndim == 3:
+ cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :]
+ cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :]
+ elif im_GT_in.ndim == 2:
+ cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border]
+ cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border]
+ else:
+ raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim))
+
+ # calculate PSNR and SSIM
+ PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255)
+
+ SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255)
+ print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format(
+ i + 1, base_name, PSNR, SSIM))
+ PSNR_all.append(PSNR)
+ SSIM_all.append(SSIM)
+ print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format(
+ sum(PSNR_all) / len(PSNR_all),
+ sum(SSIM_all) / len(SSIM_all)))
+
+
+def calculate_psnr(img1, img2):
+ # img1 and img2 have range [0, 255]
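+    # PSNR = 20 * log10(MAX / sqrt(MSE)) = 10 * log10(MAX^2 / MSE), with MAX = 255 here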
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+def ssim(img1, img2):
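+    # Stability constants from the SSIM paper (Wang et al., 2004):
+    # C1 = (K1*L)^2 and C2 = (K2*L)^2 with K1 = 0.01, K2 = 0.03 and dynamic range L = 255.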
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+def calculate_ssim(img1, img2):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+                ssims.append(ssim(img1[..., i], img2[..., i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+def bgr2ycbcr(img, only_y=True):
+    '''same results as MATLAB's rgb2ycbcr, but the input is expected in BGR order (OpenCV)
+    only_y: only return the Y channel
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
new file mode 100644
index 00000000..77bb0fca
--- /dev/null
+++ b/codes/models/SRGAN_model.py
@@ -0,0 +1,267 @@
+import logging
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+import models.networks as networks
+import models.lr_scheduler as lr_scheduler
+from .base_model import BaseModel
+from models.loss import GANLoss
+
+logger = logging.getLogger('base')
+
+
+class SRGANModel(BaseModel):
+ def __init__(self, opt):
+ super(SRGANModel, self).__init__(opt)
+ if opt['dist']:
+ self.rank = torch.distributed.get_rank()
+ else:
+ self.rank = -1 # non dist training
+ train_opt = opt['train']
+
+ # define networks and load pretrained models
+ self.netG = networks.define_G(opt).to(self.device)
+ if opt['dist']:
+ self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
+ else:
+ self.netG = DataParallel(self.netG)
+ if self.is_train:
+ self.netD = networks.define_D(opt).to(self.device)
+ if opt['dist']:
+ self.netD = DistributedDataParallel(self.netD,
+ device_ids=[torch.cuda.current_device()])
+ else:
+ self.netD = DataParallel(self.netD)
+
+ self.netG.train()
+ self.netD.train()
+
+ # define losses, optimizer and scheduler
+ if self.is_train:
+ # G pixel loss
+ if train_opt['pixel_weight'] > 0:
+ l_pix_type = train_opt['pixel_criterion']
+ if l_pix_type == 'l1':
+ self.cri_pix = nn.L1Loss().to(self.device)
+ elif l_pix_type == 'l2':
+ self.cri_pix = nn.MSELoss().to(self.device)
+ else:
+ raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
+ self.l_pix_w = train_opt['pixel_weight']
+ else:
+ logger.info('Remove pixel loss.')
+ self.cri_pix = None
+
+ # G feature loss
+ if train_opt['feature_weight'] > 0:
+ l_fea_type = train_opt['feature_criterion']
+ if l_fea_type == 'l1':
+ self.cri_fea = nn.L1Loss().to(self.device)
+ elif l_fea_type == 'l2':
+ self.cri_fea = nn.MSELoss().to(self.device)
+ else:
+ raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
+ self.l_fea_w = train_opt['feature_weight']
+ else:
+ logger.info('Remove feature loss.')
+ self.cri_fea = None
+ if self.cri_fea: # load VGG perceptual loss
+ self.netF = networks.define_F(opt, use_bn=False).to(self.device)
+ if opt['dist']:
+ self.netF = DistributedDataParallel(self.netF,
+ device_ids=[torch.cuda.current_device()])
+ else:
+ self.netF = DataParallel(self.netF)
+
+ # GD gan loss
+ self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
+ self.l_gan_w = train_opt['gan_weight']
+ # D_update_ratio and D_init_iters
+ self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
+ self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
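+            # G is updated only every D_update_ratio steps and only after the first D_init_iters
+            # iterations, giving D a warm-up phase before G receives adversarial gradients.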
+
+ # optimizers
+ # G
+ wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+ optim_params = []
+ for k, v in self.netG.named_parameters(): # can optimize for a part of the model
+ if v.requires_grad:
+ optim_params.append(v)
+ else:
+ if self.rank <= 0:
+ logger.warning('Params [{:s}] will not optimize.'.format(k))
+ self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
+ weight_decay=wd_G,
+ betas=(train_opt['beta1_G'], train_opt['beta2_G']))
+ self.optimizers.append(self.optimizer_G)
+ # D
+ wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
+ weight_decay=wd_D,
+ betas=(train_opt['beta1_D'], train_opt['beta2_D']))
+ self.optimizers.append(self.optimizer_D)
+
+ # schedulers
+ if train_opt['lr_scheme'] == 'MultiStepLR':
+ for optimizer in self.optimizers:
+ self.schedulers.append(
+ lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
+ restarts=train_opt['restarts'],
+ weights=train_opt['restart_weights'],
+ gamma=train_opt['lr_gamma'],
+ clear_state=train_opt['clear_state']))
+ elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
+ for optimizer in self.optimizers:
+ self.schedulers.append(
+ lr_scheduler.CosineAnnealingLR_Restart(
+ optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
+ restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
+ else:
+                raise NotImplementedError(
+                    'Unsupported learning rate scheme: {}'.format(train_opt['lr_scheme']))
+
+ self.log_dict = OrderedDict()
+
+ self.print_network() # print network
+ self.load() # load G and D if needed
+
+ def feed_data(self, data, need_GT=True):
+ self.var_L = data['LQ'].to(self.device) # LQ
+ if need_GT:
+ self.var_H = data['GT'].to(self.device) # GT
+ input_ref = data['ref'] if 'ref' in data else data['GT']
+ self.var_ref = input_ref.to(self.device)
+
+ def optimize_parameters(self, step):
+ # G
+ for p in self.netD.parameters():
+ p.requires_grad = False
+
+ self.optimizer_G.zero_grad()
+ self.fake_H = self.netG(self.var_L)
+
+ l_g_total = 0
+ if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+ if self.cri_pix: # pixel loss
+ l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
+ l_g_total += l_g_pix
+ if self.cri_fea: # feature loss
+ real_fea = self.netF(self.var_H).detach()
+ fake_fea = self.netF(self.fake_H)
+ l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
+ l_g_total += l_g_fea
+
+ pred_g_fake = self.netD(self.fake_H)
+ if self.opt['train']['gan_type'] == 'gan':
+ l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
+ elif self.opt['train']['gan_type'] == 'ragan':
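+                # Relativistic average GAN (RaGAN): compare each prediction against the mean
+                # prediction of the opposite class and average the two directions of the loss.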
+ pred_d_real = self.netD(self.var_ref).detach()
+ l_g_gan = self.l_gan_w * (
+ self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
+ self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+ l_g_total += l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_G.step()
+
+ # D
+ for p in self.netD.parameters():
+ p.requires_grad = True
+
+ self.optimizer_D.zero_grad()
+ l_d_total = 0
+ pred_d_real = self.netD(self.var_ref)
+ pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
+ if self.opt['train']['gan_type'] == 'gan':
+ l_d_real = self.cri_gan(pred_d_real, True)
+ l_d_fake = self.cri_gan(pred_d_fake, False)
+ l_d_total = l_d_real + l_d_fake
+ elif self.opt['train']['gan_type'] == 'ragan':
+ l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
+ l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
+ l_d_total = (l_d_real + l_d_fake) / 2
+
+ l_d_total.backward()
+ self.optimizer_D.step()
+
+ # set log
+ if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+ if self.cri_pix:
+ self.log_dict['l_g_pix'] = l_g_pix.item()
+ if self.cri_fea:
+ self.log_dict['l_g_fea'] = l_g_fea.item()
+ self.log_dict['l_g_gan'] = l_g_gan.item()
+
+ self.log_dict['l_d_real'] = l_d_real.item()
+ self.log_dict['l_d_fake'] = l_d_fake.item()
+ self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
+ self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
+
+ def test(self):
+ self.netG.eval()
+ with torch.no_grad():
+ self.fake_H = self.netG(self.var_L)
+ self.netG.train()
+
+ def get_current_log(self):
+ return self.log_dict
+
+ def get_current_visuals(self, need_GT=True):
+ out_dict = OrderedDict()
+ out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
+ out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
+ if need_GT:
+ out_dict['GT'] = self.var_H.detach()[0].float().cpu()
+ return out_dict
+
+ def print_network(self):
+ # Generator
+ s, n = self.get_network_description(self.netG)
+ if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
+ net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+ self.netG.module.__class__.__name__)
+ else:
+ net_struc_str = '{}'.format(self.netG.__class__.__name__)
+ if self.rank <= 0:
+ logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+ logger.info(s)
+ if self.is_train:
+ # Discriminator
+ s, n = self.get_network_description(self.netD)
+ if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
+ DistributedDataParallel):
+ net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
+ self.netD.module.__class__.__name__)
+ else:
+ net_struc_str = '{}'.format(self.netD.__class__.__name__)
+ if self.rank <= 0:
+ logger.info('Network D structure: {}, with parameters: {:,d}'.format(
+ net_struc_str, n))
+ logger.info(s)
+
+ if self.cri_fea: # F, Perceptual Network
+ s, n = self.get_network_description(self.netF)
+ if isinstance(self.netF, nn.DataParallel) or isinstance(
+ self.netF, DistributedDataParallel):
+ net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
+ self.netF.module.__class__.__name__)
+ else:
+ net_struc_str = '{}'.format(self.netF.__class__.__name__)
+ if self.rank <= 0:
+ logger.info('Network F structure: {}, with parameters: {:,d}'.format(
+ net_struc_str, n))
+ logger.info(s)
+
+ def load(self):
+ load_path_G = self.opt['path']['pretrain_model_G']
+ if load_path_G is not None:
+ logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
+ self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
+ load_path_D = self.opt['path']['pretrain_model_D']
+ if self.opt['is_train'] and load_path_D is not None:
+ logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
+ self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])
+
+ def save(self, iter_step):
+ self.save_network(self.netG, 'G', iter_step)
+ self.save_network(self.netD, 'D', iter_step)
diff --git a/codes/models/SR_model.py b/codes/models/SR_model.py
new file mode 100644
index 00000000..6782762a
--- /dev/null
+++ b/codes/models/SR_model.py
@@ -0,0 +1,170 @@
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+import models.networks as networks
+import models.lr_scheduler as lr_scheduler
+from .base_model import BaseModel
+from models.loss import CharbonnierLoss
+
+logger = logging.getLogger('base')
+
+
+class SRModel(BaseModel):
+ def __init__(self, opt):
+ super(SRModel, self).__init__(opt)
+
+ if opt['dist']:
+ self.rank = torch.distributed.get_rank()
+ else:
+ self.rank = -1 # non dist training
+ train_opt = opt['train']
+
+ # define network and load pretrained models
+ self.netG = networks.define_G(opt).to(self.device)
+ if opt['dist']:
+ self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
+ else:
+ self.netG = DataParallel(self.netG)
+ # print network
+ self.print_network()
+ self.load()
+
+ if self.is_train:
+ self.netG.train()
+
+ # loss
+ loss_type = train_opt['pixel_criterion']
+ if loss_type == 'l1':
+ self.cri_pix = nn.L1Loss().to(self.device)
+ elif loss_type == 'l2':
+ self.cri_pix = nn.MSELoss().to(self.device)
+ elif loss_type == 'cb':
+ self.cri_pix = CharbonnierLoss().to(self.device)
+ else:
+ raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
+ self.l_pix_w = train_opt['pixel_weight']
+
+ # optimizers
+ wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+ optim_params = []
+ for k, v in self.netG.named_parameters(): # can optimize for a part of the model
+ if v.requires_grad:
+ optim_params.append(v)
+ else:
+ if self.rank <= 0:
+ logger.warning('Params [{:s}] will not optimize.'.format(k))
+ self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
+ weight_decay=wd_G,
+ betas=(train_opt['beta1'], train_opt['beta2']))
+ self.optimizers.append(self.optimizer_G)
+
+ # schedulers
+ if train_opt['lr_scheme'] == 'MultiStepLR':
+ for optimizer in self.optimizers:
+ self.schedulers.append(
+ lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
+ restarts=train_opt['restarts'],
+ weights=train_opt['restart_weights'],
+ gamma=train_opt['lr_gamma'],
+ clear_state=train_opt['clear_state']))
+ elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
+ for optimizer in self.optimizers:
+ self.schedulers.append(
+ lr_scheduler.CosineAnnealingLR_Restart(
+ optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
+ restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
+ else:
+                raise NotImplementedError(
+                    'Unsupported learning rate scheme: {}'.format(train_opt['lr_scheme']))
+
+ self.log_dict = OrderedDict()
+
+ def feed_data(self, data, need_GT=True):
+ self.var_L = data['LQ'].to(self.device) # LQ
+ if need_GT:
+ self.real_H = data['GT'].to(self.device) # GT
+
+ def optimize_parameters(self, step):
+ self.optimizer_G.zero_grad()
+ self.fake_H = self.netG(self.var_L)
+ l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
+ l_pix.backward()
+ self.optimizer_G.step()
+
+ # set log
+ self.log_dict['l_pix'] = l_pix.item()
+
+ def test(self):
+ self.netG.eval()
+ with torch.no_grad():
+ self.fake_H = self.netG(self.var_L)
+ self.netG.train()
+
+ def test_x8(self):
+ # from https://github.com/thstkdgus35/EDSR-PyTorch
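+        # Geometric self-ensemble: run the network on 8 transformed copies of the input
+        # (flips and transpose), undo each transform on the output, and average the results.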
+ self.netG.eval()
+
+ def _transform(v, op):
+ # if self.precision != 'single': v = v.float()
+ v2np = v.data.cpu().numpy()
+ if op == 'v':
+ tfnp = v2np[:, :, :, ::-1].copy()
+ elif op == 'h':
+ tfnp = v2np[:, :, ::-1, :].copy()
+ elif op == 't':
+ tfnp = v2np.transpose((0, 1, 3, 2)).copy()
+
+ ret = torch.Tensor(tfnp).to(self.device)
+ # if self.precision == 'half': ret = ret.half()
+
+ return ret
+
+ lr_list = [self.var_L]
+ for tf in 'v', 'h', 't':
+ lr_list.extend([_transform(t, tf) for t in lr_list])
+ with torch.no_grad():
+ sr_list = [self.netG(aug) for aug in lr_list]
+ for i in range(len(sr_list)):
+ if i > 3:
+ sr_list[i] = _transform(sr_list[i], 't')
+ if i % 4 > 1:
+ sr_list[i] = _transform(sr_list[i], 'h')
+ if (i % 4) % 2 == 1:
+ sr_list[i] = _transform(sr_list[i], 'v')
+
+ output_cat = torch.cat(sr_list, dim=0)
+ self.fake_H = output_cat.mean(dim=0, keepdim=True)
+ self.netG.train()
+
+ def get_current_log(self):
+ return self.log_dict
+
+ def get_current_visuals(self, need_GT=True):
+ out_dict = OrderedDict()
+ out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
+ out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
+ if need_GT:
+ out_dict['GT'] = self.real_H.detach()[0].float().cpu()
+ return out_dict
+
+ def print_network(self):
+ s, n = self.get_network_description(self.netG)
+ if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
+ net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+ self.netG.module.__class__.__name__)
+ else:
+ net_struc_str = '{}'.format(self.netG.__class__.__name__)
+ if self.rank <= 0:
+ logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+ logger.info(s)
+
+ def load(self):
+ load_path_G = self.opt['path']['pretrain_model_G']
+ if load_path_G is not None:
+ logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
+ self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
+
+ def save(self, iter_label):
+ self.save_network(self.netG, 'G', iter_label)
diff --git a/codes/models/Video_base_model.py b/codes/models/Video_base_model.py
new file mode 100644
index 00000000..eb85fc5c
--- /dev/null
+++ b/codes/models/Video_base_model.py
@@ -0,0 +1,166 @@
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+import models.networks as networks
+import models.lr_scheduler as lr_scheduler
+from .base_model import BaseModel
+from models.loss import CharbonnierLoss
+
+logger = logging.getLogger('base')
+
+
+class VideoBaseModel(BaseModel):
+ def __init__(self, opt):
+ super(VideoBaseModel, self).__init__(opt)
+
+ if opt['dist']:
+ self.rank = torch.distributed.get_rank()
+ else:
+ self.rank = -1 # non dist training
+ train_opt = opt['train']
+
+ # define network and load pretrained models
+ self.netG = networks.define_G(opt).to(self.device)
+ if opt['dist']:
+ self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
+ else:
+ self.netG = DataParallel(self.netG)
+ # print network
+ self.print_network()
+ self.load()
+
+ if self.is_train:
+ self.netG.train()
+
+ #### loss
+ loss_type = train_opt['pixel_criterion']
+ if loss_type == 'l1':
+ self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
+ elif loss_type == 'l2':
+ self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
+ elif loss_type == 'cb':
+ self.cri_pix = CharbonnierLoss().to(self.device)
+ else:
+ raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
+ self.l_pix_w = train_opt['pixel_weight']
+
+ #### optimizers
+ wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
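+        # When ft_tsa_only is set, the TSA fusion parameters get their own param group so the
+        # rest of the network can be frozen (lr set to 0) for the first ft_tsa_only iterations.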
+ if train_opt['ft_tsa_only']:
+ normal_params = []
+ tsa_fusion_params = []
+ for k, v in self.netG.named_parameters():
+ if v.requires_grad:
+ if 'tsa_fusion' in k:
+ tsa_fusion_params.append(v)
+ else:
+ normal_params.append(v)
+ else:
+ if self.rank <= 0:
+ logger.warning('Params [{:s}] will not optimize.'.format(k))
+ optim_params = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['lr_G']
+ },
+ {
+ 'params': tsa_fusion_params,
+ 'lr': train_opt['lr_G']
+ },
+ ]
+ else:
+ optim_params = []
+ for k, v in self.netG.named_parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ else:
+ if self.rank <= 0:
+ logger.warning('Params [{:s}] will not optimize.'.format(k))
+
+ self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
+ weight_decay=wd_G,
+ betas=(train_opt['beta1'], train_opt['beta2']))
+ self.optimizers.append(self.optimizer_G)
+
+ #### schedulers
+ if train_opt['lr_scheme'] == 'MultiStepLR':
+ for optimizer in self.optimizers:
+ self.schedulers.append(
+ lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
+ restarts=train_opt['restarts'],
+ weights=train_opt['restart_weights'],
+ gamma=train_opt['lr_gamma'],
+ clear_state=train_opt['clear_state']))
+ elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
+ for optimizer in self.optimizers:
+ self.schedulers.append(
+ lr_scheduler.CosineAnnealingLR_Restart(
+ optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
+ restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
+ else:
+            raise NotImplementedError(
+                'Unsupported learning rate scheme: {}'.format(train_opt['lr_scheme']))
+
+ self.log_dict = OrderedDict()
+
+ def feed_data(self, data, need_GT=True):
+ self.var_L = data['LQs'].to(self.device)
+ if need_GT:
+ self.real_H = data['GT'].to(self.device)
+
+ def set_params_lr_zero(self):
+        # freeze the normal (non-TSA) parameter group by setting its learning rate to 0
+ self.optimizers[0].param_groups[0]['lr'] = 0
+
+ def optimize_parameters(self, step):
+ if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
+ self.set_params_lr_zero()
+
+ self.optimizer_G.zero_grad()
+ self.fake_H = self.netG(self.var_L)
+
+ l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
+ l_pix.backward()
+ self.optimizer_G.step()
+
+ # set log
+ self.log_dict['l_pix'] = l_pix.item()
+
+ def test(self):
+ self.netG.eval()
+ with torch.no_grad():
+ self.fake_H = self.netG(self.var_L)
+ self.netG.train()
+
+ def get_current_log(self):
+ return self.log_dict
+
+ def get_current_visuals(self, need_GT=True):
+ out_dict = OrderedDict()
+ out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
+ out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
+ if need_GT:
+ out_dict['GT'] = self.real_H.detach()[0].float().cpu()
+ return out_dict
+
+ def print_network(self):
+ s, n = self.get_network_description(self.netG)
+        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
+ net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+ self.netG.module.__class__.__name__)
+ else:
+ net_struc_str = '{}'.format(self.netG.__class__.__name__)
+ if self.rank <= 0:
+ logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+ logger.info(s)
+
+ def load(self):
+ load_path_G = self.opt['path']['pretrain_model_G']
+ if load_path_G is not None:
+ logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
+ self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
+
+ def save(self, iter_label):
+ self.save_network(self.netG, 'G', iter_label)
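+
+# Typical usage (sketch): the training script calls feed_data() and optimize_parameters(step)
+# each iteration and reads losses via get_current_log(); validation uses test() followed by
+# get_current_visuals().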
diff --git a/codes/models/__init__.py b/codes/models/__init__.py
new file mode 100644
index 00000000..c95004c9
--- /dev/null
+++ b/codes/models/__init__.py
@@ -0,0 +1,19 @@
+import logging
+logger = logging.getLogger('base')
+
+
+def create_model(opt):
+ model = opt['model']
+ # image restoration
+ if model == 'sr': # PSNR-oriented super resolution
+ from .SR_model import SRModel as M
+ elif model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN
+ from .SRGAN_model import SRGANModel as M
+ # video restoration
+ elif model == 'video_base':
+ from .Video_base_model import VideoBaseModel as M
+ else:
+ raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
+ m = M(opt)
+ logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
+ return m
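+
+# Sketch: `opt` is the parsed options dict; e.g. opt['model'] = 'video_base' yields a
+# VideoBaseModel whose generator is built by networks.define_G(opt).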
diff --git a/codes/models/archs/DUF_arch.py b/codes/models/archs/DUF_arch.py
new file mode 100644
index 00000000..319cea87
--- /dev/null
+++ b/codes/models/archs/DUF_arch.py
@@ -0,0 +1,368 @@
+'''Network architecture for DUF:
+Deep Video Super-Resolution Network Using Dynamic Upsampling Filters
+Without Explicit Motion Compensation (CVPR18)
+https://github.com/yhjo09/VSR-DUF
+
+For all the models below, [adapt_official] is only necessary when
+loading the weights converted from the official TensorFlow weights.
+Please set it to [False] if you are training the model from scratch.
+'''
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def adapt_official(Rx, scale=4):
+    '''Adapt the weights translated from the official TensorFlow weights.
+    Not necessary if you are training from scratch.'''
+ x = Rx.clone()
+ x1 = x[:, ::3, :, :]
+ x2 = x[:, 1::3, :, :]
+ x3 = x[:, 2::3, :, :]
+
+ Rx[:, :scale**2, :, :] = x1
+ Rx[:, scale**2:2 * (scale**2), :, :] = x2
+ Rx[:, 2 * (scale**2):, :, :] = x3
+
+ return Rx
+
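+# A quick sanity sketch of the regrouping above (assuming scale=4, i.e. 3 * scale**2 = 48
+# residual channels): the converted weights emit channels interleaved per sub-position as
+# (c0, c1, c2, c0, c1, c2, ...); the slicing reorders them into three contiguous blocks of
+# scale**2 channels, which is the layout F.pixel_shuffle expects for a 3-channel output.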
+
+class DenseBlock(nn.Module):
+ '''Dense block
+    for the second dense block, set t_reduce=True'''
+
+ def __init__(self, nf=64, ng=32, t_reduce=False):
+ super(DenseBlock, self).__init__()
+ self.t_reduce = t_reduce
+ if self.t_reduce:
+ pad = (0, 1, 1)
+ else:
+ pad = (1, 1, 1)
+ self.bn3d_1 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3)
+ self.conv3d_1 = nn.Conv3d(nf, nf, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+ self.bn3d_2 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3)
+ self.conv3d_2 = nn.Conv3d(nf, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)
+ self.bn3d_3 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3)
+ self.conv3d_3 = nn.Conv3d(nf + ng, nf + ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
+ bias=True)
+ self.bn3d_4 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3)
+ self.conv3d_4 = nn.Conv3d(nf + ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)
+ self.bn3d_5 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3)
+ self.conv3d_5 = nn.Conv3d(nf + 2 * ng, nf + 2 * ng, (1, 1, 1), stride=(1, 1, 1),
+ padding=(0, 0, 0), bias=True)
+ self.bn3d_6 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3)
+ self.conv3d_6 = nn.Conv3d(nf + 2 * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad,
+ bias=True)
+
+ def forward(self, x):
+ '''x: [B, C, T, H, W]
+ C: nf -> nf + 3 * ng
+ T: 1) 7 -> 7 (t_reduce=False);
+ 2) 7 -> 7 - 2 * 3 = 1 (t_reduce=True)'''
+ x1 = self.conv3d_1(F.relu(self.bn3d_1(x), inplace=True))
+ x1 = self.conv3d_2(F.relu(self.bn3d_2(x1), inplace=True))
+ if self.t_reduce:
+ x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
+ else:
+ x1 = torch.cat((x, x1), 1)
+
+ x2 = self.conv3d_3(F.relu(self.bn3d_3(x1), inplace=True))
+ x2 = self.conv3d_4(F.relu(self.bn3d_4(x2), inplace=True))
+ if self.t_reduce:
+ x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
+ else:
+ x2 = torch.cat((x1, x2), 1)
+
+ x3 = self.conv3d_5(F.relu(self.bn3d_5(x2), inplace=True))
+ x3 = self.conv3d_6(F.relu(self.bn3d_6(x3), inplace=True))
+ if self.t_reduce:
+ x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
+ else:
+ x3 = torch.cat((x2, x3), 1)
+ return x3
+
+
+class DynamicUpsamplingFilter_3C(nn.Module):
+    '''Dynamic upsampling filter that applies the same filters to all 3 color channels.
+    filter_size: kernel size of the generated filters, shape (C, kH, kW)'''
+
+ def __init__(self, filter_size=(1, 5, 5)):
+ super(DynamicUpsamplingFilter_3C, self).__init__()
+ # generate a local expansion filter, used similar to im2col
+ nF = np.prod(filter_size)
+ expand_filter_np = np.reshape(np.eye(nF, nF),
+ (nF, filter_size[0], filter_size[1], filter_size[2]))
+ expand_filter = torch.from_numpy(expand_filter_np).float()
+ self.expand_filter = torch.cat((expand_filter, expand_filter, expand_filter),
+ 0) # [75, 1, 5, 5]
+
+ def forward(self, x, filters):
+ '''x: input image, [B, 3, H, W]
+        filters: the generated dynamic filters, [B, F, R, H, W], e.g., [B, 25, 16, H, W]
+ F: prod of filter kernel size, e.g., 5*5 = 25
+ R: used for upsampling, similar to pixel shuffle, e.g., 4*4 = 16 for x4
+ Return: filtered image, [B, 3*R, H, W]
+ '''
+ B, nF, R, H, W = filters.size()
+ # using group convolution
+ input_expand = F.conv2d(x, self.expand_filter.type_as(x), padding=2,
+ groups=3) # [B, 75, H, W] similar to im2col
+ input_expand = input_expand.view(B, 3, nF, H, W).permute(0, 3, 4, 1, 2) # [B, H, W, 3, 25]
+ filters = filters.permute(0, 3, 4, 1, 2) # [B, H, W, 25, 16]
+ out = torch.matmul(input_expand, filters) # [B, H, W, 3, 16]
+ return out.permute(0, 3, 4, 1, 2).view(B, 3 * R, H, W) # [B, 3*16, H, W]
+
+
+class DUF_16L(nn.Module):
+ '''Official DUF structure with 16 layers'''
+
+ def __init__(self, scale=4, adapt_official=False):
+ super(DUF_16L, self).__init__()
+ self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
+ self.dense_block_1 = DenseBlock(64, 64 // 2, t_reduce=False) # 64 + 32 * 3 = 160, T = 7
+ self.dense_block_2 = DenseBlock(160, 64 // 2, t_reduce=True) # 160 + 32 * 3 = 256, T = 1
+ self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3)
+ self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
+ bias=True)
+
+ self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
+ bias=True)
+ self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
+ padding=(0, 0, 0), bias=True)
+
+ self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
+ bias=True)
+ self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
+ padding=(0, 0, 0), bias=True)
+
+ self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
+
+ self.scale = scale
+ self.adapt_official = adapt_official
+
+ def forward(self, x):
+ '''
+ x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
+ Generate filters and image residual:
+ Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
+ Rx: [B, 3*16, 1, H, W]
+ '''
+ B, T, C, H, W = x.size()
+ x = x.permute(0, 2, 1, 3, 4) # [B, C, T, H, W] for Conv3D
+ x_center = x[:, :, T // 2, :, :]
+
+ x = self.conv3d_1(x)
+ x = self.dense_block_1(x)
+ x = self.dense_block_2(x) # reduce T to 1
+ x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)
+
+ # image residual
+ Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W]
+
+ # filter
+ Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W]
+ Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)
+
+ # Adapt to official model weights
+ if self.adapt_official:
+ adapt_official(Rx, scale=self.scale)
+
+ # dynamic filter
+ out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W]
+ out += Rx.squeeze_(2)
+ out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W]
+
+ return out
+
+
+class DenseBlock_28L(nn.Module):
+ '''The first part of the dense blocks used in DUF_28L
+ Temporal dimension remains the same here'''
+
+ def __init__(self, nf=64, ng=16):
+ super(DenseBlock_28L, self).__init__()
+ pad = (1, 1, 1)
+
+ dense_block_l = []
+ for i in range(0, 9):
+ dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
+ dense_block_l.append(nn.ReLU())
+ dense_block_l.append(
+ nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
+ bias=True))
+
+ dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
+ dense_block_l.append(nn.ReLU())
+ dense_block_l.append(
+ nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True))
+
+ self.dense_blocks = nn.ModuleList(dense_block_l)
+
+ def forward(self, x):
+ '''x: [B, C, T, H, W]
+        C: 64 -> 208
+        T: 7 -> 7 (the temporal dimension is kept)'''
+ for i in range(0, len(self.dense_blocks), 6):
+ y = x
+ for j in range(6):
+ y = self.dense_blocks[i + j](y)
+ x = torch.cat((x, y), 1)
+ return x
+
+
+class DUF_28L(nn.Module):
+ '''Official DUF structure with 28 layers'''
+
+ def __init__(self, scale=4, adapt_official=False):
+ super(DUF_28L, self).__init__()
+ self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
+ self.dense_block_1 = DenseBlock_28L(64, 16) # 64 + 16 * 9 = 208, T = 7
+ self.dense_block_2 = DenseBlock(208, 16, t_reduce=True) # 208 + 16 * 3 = 256, T = 1
+ self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3)
+ self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
+ bias=True)
+
+ self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
+ bias=True)
+ self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
+ padding=(0, 0, 0), bias=True)
+
+ self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
+ bias=True)
+ self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
+ padding=(0, 0, 0), bias=True)
+
+ self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
+
+ self.scale = scale
+ self.adapt_official = adapt_official
+
+ def forward(self, x):
+ '''
+ x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
+ Generate filters and image residual:
+ Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
+ Rx: [B, 3*16, 1, H, W]
+ '''
+ B, T, C, H, W = x.size()
+ x = x.permute(0, 2, 1, 3, 4) # [B,C,T,H,W] for Conv3D
+ x_center = x[:, :, T // 2, :, :]
+ x = self.conv3d_1(x)
+ x = self.dense_block_1(x)
+ x = self.dense_block_2(x) # reduce T to 1
+ x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)
+
+ # image residual
+ Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W]
+
+ # filter
+ Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W]
+ Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)
+
+ # Adapt to official model weights
+ if self.adapt_official:
+ adapt_official(Rx, scale=self.scale)
+
+ # dynamic filter
+ out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W]
+ out += Rx.squeeze_(2)
+ out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W]
+ return out
+
+
+class DenseBlock_52L(nn.Module):
+ '''The first part of the dense blocks used in DUF_52L
+ Temporal dimension remains the same here'''
+
+ def __init__(self, nf=64, ng=16):
+ super(DenseBlock_52L, self).__init__()
+ pad = (1, 1, 1)
+
+ dense_block_l = []
+ for i in range(0, 21):
+ dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
+ dense_block_l.append(nn.ReLU())
+ dense_block_l.append(
+ nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
+ bias=True))
+
+ dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
+ dense_block_l.append(nn.ReLU())
+ dense_block_l.append(
+ nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True))
+
+ self.dense_blocks = nn.ModuleList(dense_block_l)
+
+ def forward(self, x):
+ '''x: [B, C, T, H, W]
+        C: 64 -> 400
+        T: 7 -> 7 (the temporal dimension is kept)'''
+ for i in range(0, len(self.dense_blocks), 6):
+ y = x
+ for j in range(6):
+ y = self.dense_blocks[i + j](y)
+ x = torch.cat((x, y), 1)
+ return x
+
+
+class DUF_52L(nn.Module):
+ '''Official DUF structure with 52 layers'''
+
+ def __init__(self, scale=4, adapt_official=False):
+ super(DUF_52L, self).__init__()
+ self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
+        self.dense_block_1 = DenseBlock_52L(64, 16)  # 64 + 21 * 16 = 400, T = 7
+ self.dense_block_2 = DenseBlock(400, 16, t_reduce=True) # 400 + 16 * 3 = 448, T = 1
+
+ self.bn3d_2 = nn.BatchNorm3d(448, eps=1e-3, momentum=1e-3)
+ self.conv3d_2 = nn.Conv3d(448, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
+ bias=True)
+
+ self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
+ bias=True)
+ self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
+ padding=(0, 0, 0), bias=True)
+
+ self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
+ bias=True)
+ self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
+ padding=(0, 0, 0), bias=True)
+
+ self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
+
+ self.scale = scale
+ self.adapt_official = adapt_official
+
+ def forward(self, x):
+ '''
+ x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
+ Generate filters and image residual:
+ Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
+ Rx: [B, 3*16, 1, H, W]
+ '''
+ B, T, C, H, W = x.size()
+ x = x.permute(0, 2, 1, 3, 4) # [B,C,T,H,W] for Conv3D
+ x_center = x[:, :, T // 2, :, :]
+ x = self.conv3d_1(x)
+ x = self.dense_block_1(x)
+ x = self.dense_block_2(x)
+ x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)
+
+ # image residual
+ Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W]
+
+ # filter
+ Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W]
+ Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)
+
+ # Adapt to official model weights
+ if self.adapt_official:
+ adapt_official(Rx, scale=self.scale)
+
+ # dynamic filter
+ out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W]
+ out += Rx.squeeze_(2)
+ out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W]
+ return out
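+
+
+# Usage sketch: the DUF variants above each take 7 consecutive frames and, with the default
+# scale=4, return a single 4x super-resolved center frame, e.g.
+#   net = DUF_16L(scale=4, adapt_official=False)
+#   sr = net(lqs)  # lqs: [B, 7, 3, H, W] -> sr: [B, 3, 4H, 4W]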
diff --git a/codes/models/archs/EDVR_arch.py b/codes/models/archs/EDVR_arch.py
new file mode 100644
index 00000000..df9c0325
--- /dev/null
+++ b/codes/models/archs/EDVR_arch.py
@@ -0,0 +1,312 @@
+'''Network architecture for EDVR'''
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import models.archs.arch_util as arch_util
+try:
+ from models.archs.dcn.deform_conv import ModulatedDeformConvPack as DCN
+except ImportError:
+ raise ImportError('Failed to import DCNv2 module.')
+
+
+class Predeblur_ResNet_Pyramid(nn.Module):
+ def __init__(self, nf=128, HR_in=False):
+ '''
+        HR_in: True if the input is already at the high-resolution (HR) spatial size
+ '''
+
+ super(Predeblur_ResNet_Pyramid, self).__init__()
+ self.HR_in = True if HR_in else False
+ if self.HR_in:
+ self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
+ self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
+ self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
+ else:
+ self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
+ basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
+ self.RB_L1_1 = basic_block()
+ self.RB_L1_2 = basic_block()
+ self.RB_L1_3 = basic_block()
+ self.RB_L1_4 = basic_block()
+ self.RB_L1_5 = basic_block()
+ self.RB_L2_1 = basic_block()
+ self.RB_L2_2 = basic_block()
+ self.RB_L3_1 = basic_block()
+ self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
+ self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def forward(self, x):
+ if self.HR_in:
+ L1_fea = self.lrelu(self.conv_first_1(x))
+ L1_fea = self.lrelu(self.conv_first_2(L1_fea))
+ L1_fea = self.lrelu(self.conv_first_3(L1_fea))
+ else:
+ L1_fea = self.lrelu(self.conv_first(x))
+ L2_fea = self.lrelu(self.deblur_L2_conv(L1_fea))
+ L3_fea = self.lrelu(self.deblur_L3_conv(L2_fea))
+ L3_fea = F.interpolate(self.RB_L3_1(L3_fea), scale_factor=2, mode='bilinear',
+ align_corners=False)
+ L2_fea = self.RB_L2_1(L2_fea) + L3_fea
+ L2_fea = F.interpolate(self.RB_L2_2(L2_fea), scale_factor=2, mode='bilinear',
+ align_corners=False)
+ L1_fea = self.RB_L1_2(self.RB_L1_1(L1_fea)) + L2_fea
+ out = self.RB_L1_5(self.RB_L1_4(self.RB_L1_3(L1_fea)))
+ return out
+
+
+class PCD_Align(nn.Module):
+ ''' Alignment module using Pyramid, Cascading and Deformable convolution
+ with 3 pyramid levels.
+ '''
+
+ def __init__(self, nf=64, groups=8):
+ super(PCD_Align, self).__init__()
+ # L3: level 3, 1/4 spatial size
+ self.L3_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
+ self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
+ extra_offset_mask=True)
+ # L2: level 2, 1/2 spatial size
+ self.L2_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
+ self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset
+ self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
+ extra_offset_mask=True)
+ self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea
+ # L1: level 1, original spatial size
+ self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
+ self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset
+ self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
+ extra_offset_mask=True)
+ self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea
+ # Cascading DCN
+ self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
+ self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+
+ self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
+ extra_offset_mask=True)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def forward(self, nbr_fea_l, ref_fea_l):
+        '''Align the neighboring frames to the reference frame at the feature level.
+ nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features
+ '''
+ # L3
+ L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1)
+ L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
+ L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
+ L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset]))
+ # L2
+ L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1)
+ L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
+ L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)
+ L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1)))
+ L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
+ L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset])
+ L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)
+ L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
+ # L1
+ L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1)
+ L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
+ L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)
+ L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
+ L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
+ L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset])
+ L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)
+ L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
+ # Cascading
+ offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
+ offset = self.lrelu(self.cas_offset_conv1(offset))
+ offset = self.lrelu(self.cas_offset_conv2(offset))
+ L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset]))
+
+ return L1_fea
+
+
+class TSA_Fusion(nn.Module):
+ ''' Temporal Spatial Attention fusion module
+ Temporal: correlation;
+ Spatial: 3 pyramid levels.
+ '''
+
+ def __init__(self, nf=64, nframes=5, center=2):
+ super(TSA_Fusion, self).__init__()
+ self.center = center
+ # temporal attention (before fusion conv)
+ self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+
+ # fusion conv: using 1x1 to save parameters and computation
+ self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
+
+ # spatial attention (after fusion conv)
+ self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
+ self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
+ self.avgpool = nn.AvgPool2d(3, stride=2, padding=1)
+ self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True)
+ self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True)
+ self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
+ self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
+ self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
+ self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def forward(self, aligned_fea):
+ B, N, C, H, W = aligned_fea.size() # N video frames
+ #### temporal attention
+ emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone())
+ emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W) # [B, N, C(nf), H, W]
+
+ cor_l = []
+ for i in range(N):
+ emb_nbr = emb[:, i, :, :, :]
+ cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1) # B, 1, H, W
+ cor_l.append(cor_tmp)
+ cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1)) # B, N, H, W
+ cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W)
+ aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob
+
+ #### fusion
+ fea = self.lrelu(self.fea_fusion(aligned_fea))
+
+ #### spatial attention
+ att = self.lrelu(self.sAtt_1(aligned_fea))
+ att_max = self.maxpool(att)
+ att_avg = self.avgpool(att)
+ att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1)))
+ # pyramid levels
+ att_L = self.lrelu(self.sAtt_L1(att))
+ att_max = self.maxpool(att_L)
+ att_avg = self.avgpool(att_L)
+ att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1)))
+ att_L = self.lrelu(self.sAtt_L3(att_L))
+ att_L = F.interpolate(att_L, scale_factor=2, mode='bilinear', align_corners=False)
+
+ att = self.lrelu(self.sAtt_3(att))
+ att = att + att_L
+ att = self.lrelu(self.sAtt_4(att))
+ att = F.interpolate(att, scale_factor=2, mode='bilinear', align_corners=False)
+ att = self.sAtt_5(att)
+ att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att)))
+ att = torch.sigmoid(att)
+
+ fea = fea * att * 2 + att_add
+ return fea
+
+
+class EDVR(nn.Module):
+ def __init__(self, nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10, center=None,
+ predeblur=False, HR_in=False, w_TSA=True):
+ super(EDVR, self).__init__()
+ self.nf = nf
+ self.center = nframes // 2 if center is None else center
+ self.is_predeblur = True if predeblur else False
+ self.HR_in = True if HR_in else False
+ self.w_TSA = w_TSA
+ ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
+
+ #### extract features (for each frame)
+ if self.is_predeblur:
+ self.pre_deblur = Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
+ self.conv_1x1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
+ else:
+ if self.HR_in:
+ self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
+ self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
+ self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
+ else:
+ self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
+ self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
+ self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
+ self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
+ self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+
+ self.pcd_align = PCD_Align(nf=nf, groups=groups)
+ if self.w_TSA:
+ self.tsa_fusion = TSA_Fusion(nf=nf, nframes=nframes, center=self.center)
+ else:
+ self.tsa_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
+
+ #### reconstruction
+ self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, back_RBs)
+ #### upsampling
+ self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
+ self.upconv2 = nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)
+ self.pixel_shuffle = nn.PixelShuffle(2)
+ self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)
+
+ #### activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def forward(self, x):
+ B, N, C, H, W = x.size() # N video frames
+ x_center = x[:, self.center, :, :, :].contiguous()
+
+ #### extract LR features
+ # L1
+ if self.is_predeblur:
+ L1_fea = self.pre_deblur(x.view(-1, C, H, W))
+ L1_fea = self.conv_1x1(L1_fea)
+ if self.HR_in:
+ H, W = H // 4, W // 4
+ else:
+ if self.HR_in:
+ L1_fea = self.lrelu(self.conv_first_1(x.view(-1, C, H, W)))
+ L1_fea = self.lrelu(self.conv_first_2(L1_fea))
+ L1_fea = self.lrelu(self.conv_first_3(L1_fea))
+ H, W = H // 4, W // 4
+ else:
+ L1_fea = self.lrelu(self.conv_first(x.view(-1, C, H, W)))
+ L1_fea = self.feature_extraction(L1_fea)
+ # L2
+ L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))
+ L2_fea = self.lrelu(self.fea_L2_conv2(L2_fea))
+ # L3
+ L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))
+ L3_fea = self.lrelu(self.fea_L3_conv2(L3_fea))
+
+ L1_fea = L1_fea.view(B, N, -1, H, W)
+ L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2)
+ L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4)
+
+ #### pcd align
+ # ref feature list
+ ref_fea_l = [
+ L1_fea[:, self.center, :, :, :].clone(), L2_fea[:, self.center, :, :, :].clone(),
+ L3_fea[:, self.center, :, :, :].clone()
+ ]
+ aligned_fea = []
+ for i in range(N):
+ nbr_fea_l = [
+ L1_fea[:, i, :, :, :].clone(), L2_fea[:, i, :, :, :].clone(),
+ L3_fea[:, i, :, :, :].clone()
+ ]
+ aligned_fea.append(self.pcd_align(nbr_fea_l, ref_fea_l))
+ aligned_fea = torch.stack(aligned_fea, dim=1) # [B, N, C, H, W]
+
+ if not self.w_TSA:
+ aligned_fea = aligned_fea.view(B, -1, H, W)
+ fea = self.tsa_fusion(aligned_fea)
+
+ out = self.recon_trunk(fea)
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+ out = self.lrelu(self.HRconv(out))
+ out = self.conv_last(out)
+ if self.HR_in:
+ base = x_center
+ else:
+ base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
+ out += base
+ return out
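+
+
+# Usage sketch (requires the compiled DCNv2 extension): 4x video SR from 5 LR frames, e.g.
+#   net = EDVR(nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10)
+#   sr = net(lqs)  # lqs: [B, 5, 3, H, W] -> sr: [B, 3, 4H, 4W]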
diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py
new file mode 100644
index 00000000..9d61256c
--- /dev/null
+++ b/codes/models/archs/RRDBNet_arch.py
@@ -0,0 +1,73 @@
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import models.archs.arch_util as arch_util
+
+
+class ResidualDenseBlock_5C(nn.Module):
+ def __init__(self, nf=64, gc=32, bias=True):
+ super(ResidualDenseBlock_5C, self).__init__()
+ # gc: growth channel, i.e. intermediate channels
+ self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
+ self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
+ self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
+ self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
+ self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ # initialization
+ arch_util.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5],
+ 0.1)
+
+ def forward(self, x):
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+ '''Residual in Residual Dense Block'''
+
+ def __init__(self, nf, gc=32):
+ super(RRDB, self).__init__()
+ self.RDB1 = ResidualDenseBlock_5C(nf, gc)
+ self.RDB2 = ResidualDenseBlock_5C(nf, gc)
+ self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+
+ def forward(self, x):
+ out = self.RDB1(x)
+ out = self.RDB2(out)
+ out = self.RDB3(out)
+ return out * 0.2 + x
+
+
+class RRDBNet(nn.Module):
+ def __init__(self, in_nc, out_nc, nf, nb, gc=32):
+ super(RRDBNet, self).__init__()
+ RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+
+ self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+ self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb)
+ self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ #### upsampling
+ self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ fea = self.conv_first(x)
+ trunk = self.trunk_conv(self.RRDB_trunk(fea))
+ fea = fea + trunk
+
+ fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
+ fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
+ out = self.conv_last(self.lrelu(self.HRconv(fea)))
+
+ return out
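+
+
+# Usage sketch: RRDBNet (the ESRGAN generator) upsamples by a fixed factor of 4 via the two
+# nearest-neighbor 2x steps above; nb=23 is the configuration commonly used for ESRGAN, e.g.
+#   net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
+#   sr = net(lr)  # lr: [B, 3, H, W] -> sr: [B, 3, 4H, 4W]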
diff --git a/codes/models/archs/SRResNet_arch.py b/codes/models/archs/SRResNet_arch.py
new file mode 100644
index 00000000..6e622ac3
--- /dev/null
+++ b/codes/models/archs/SRResNet_arch.py
@@ -0,0 +1,55 @@
+import functools
+import torch.nn as nn
+import torch.nn.functional as F
+import models.archs.arch_util as arch_util
+
+
+class MSRResNet(nn.Module):
+ ''' modified SRResNet'''
+
+ def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
+ super(MSRResNet, self).__init__()
+ self.upscale = upscale
+
+ self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+ basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
+ self.recon_trunk = arch_util.make_layer(basic_block, nb)
+
+ # upsampling
+ if self.upscale == 2:
+ self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
+ self.pixel_shuffle = nn.PixelShuffle(2)
+ elif self.upscale == 3:
+ self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
+ self.pixel_shuffle = nn.PixelShuffle(3)
+ elif self.upscale == 4:
+ self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
+ self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
+ self.pixel_shuffle = nn.PixelShuffle(2)
+
+ self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ # initialization
+ arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last],
+ 0.1)
+ if self.upscale == 4:
+ arch_util.initialize_weights(self.upconv2, 0.1)
+
+ def forward(self, x):
+ fea = self.lrelu(self.conv_first(x))
+ out = self.recon_trunk(fea)
+
+ if self.upscale == 4:
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+ elif self.upscale == 3 or self.upscale == 2:
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+
+ out = self.conv_last(self.lrelu(self.HRconv(out)))
+ base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
+ out += base
+ return out
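+
+
+# Usage sketch: a PSNR-oriented SRResNet-style generator; upscale may be 2, 3 or 4, e.g.
+#   net = MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=4)
+#   sr = net(lr)  # lr: [B, 3, H, W] -> sr: [B, 3, 4H, 4W]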
diff --git a/codes/models/archs/TOF_arch.py b/codes/models/archs/TOF_arch.py
new file mode 100755
index 00000000..02d7a914
--- /dev/null
+++ b/codes/models/archs/TOF_arch.py
@@ -0,0 +1,137 @@
+'''PyTorch implementation of TOFlow
+Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
+Code reference:
+1. https://github.com/anchen1011/toflow
+2. https://github.com/Coldog2333/pytoflow
+'''
+
+import torch
+import torch.nn as nn
+from .arch_util import flow_warp
+
+
+def normalize(x):
+ mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
+ std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
+ return (x - mean) / std
+
+
+def denormalize(x):
+ mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
+ std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
+ return x * std + mean
+
+
+class SpyNet_Block(nn.Module):
+ '''A submodule of SpyNet.'''
+
+ def __init__(self):
+ super(SpyNet_Block, self).__init__()
+
+ self.block = nn.Sequential(
+ nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
+ nn.BatchNorm2d(32), nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
+ nn.BatchNorm2d(64), nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
+ nn.BatchNorm2d(32), nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
+ nn.BatchNorm2d(16), nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
+
+ def forward(self, x):
+ '''
+ input: x: [ref im, nbr im, initial flow] - (B, 8, H, W)
+ output: estimated flow - (B, 2, H, W)
+ '''
+ return self.block(x)
+
+
+class SpyNet(nn.Module):
+ '''SpyNet for estimating optical flow
+ Ranjan et al., Optical Flow Estimation using a Spatial Pyramid Network, 2016'''
+
+ def __init__(self):
+ super(SpyNet, self).__init__()
+
+ self.blocks = nn.ModuleList([SpyNet_Block() for _ in range(4)])
+
+ def forward(self, ref, nbr):
+        '''Estimate optical flow at a coarse level, then upsample and refine at finer levels
+ input: ref: reference image - [B, 3, H, W]
+ nbr: the neighboring image to be warped - [B, 3, H, W]
+ output: estimated optical flow - [B, 2, H, W]
+ '''
+ B, C, H, W = ref.size()
+ ref = [ref]
+ nbr = [nbr]
+
+ for _ in range(3):
+ ref.insert(
+ 0,
+ nn.functional.avg_pool2d(input=ref[0], kernel_size=2, stride=2,
+ count_include_pad=False))
+ nbr.insert(
+ 0,
+ nn.functional.avg_pool2d(input=nbr[0], kernel_size=2, stride=2,
+ count_include_pad=False))
+
+ flow = torch.zeros(B, 2, H // 16, W // 16).type_as(ref[0])
+
+ for i in range(4):
+ flow_up = nn.functional.interpolate(input=flow, scale_factor=2, mode='bilinear',
+ align_corners=True) * 2.0
+ flow = flow_up + self.blocks[i](torch.cat(
+ [ref[i], flow_warp(nbr[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
+ return flow
+
+
+class TOFlow(nn.Module):
+ def __init__(self, adapt_official=False):
+ super(TOFlow, self).__init__()
+
+ self.SpyNet = SpyNet()
+
+ self.conv_3x7_64_9x9 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
+ self.conv_64_64_9x9 = nn.Conv2d(64, 64, 9, 1, 4)
+ self.conv_64_64_1x1 = nn.Conv2d(64, 64, 1)
+ self.conv_64_3_1x1 = nn.Conv2d(64, 3, 1)
+
+ self.relu = nn.ReLU(inplace=True)
+
+ self.adapt_official = adapt_official # True if using translated official weights else False
+
+ def forward(self, x):
+ """
+ input: x: input frames - [B, 7, 3, H, W]
+ output: SR reference frame - [B, 3, H, W]
+ """
+
+ B, T, C, H, W = x.size()
+ x = normalize(x.view(-1, C, H, W)).view(B, T, C, H, W)
+
+ ref_idx = 3
+ x_ref = x[:, ref_idx, :, :, :]
+
+ # In the official torch code, the 0-th frame is the reference frame
+ if self.adapt_official:
+ x = x[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
+ ref_idx = 0
+
+ x_warped = []
+ for i in range(7):
+ if i == ref_idx:
+ x_warped.append(x_ref)
+ else:
+ x_nbr = x[:, i, :, :, :]
+ flow = self.SpyNet(x_ref, x_nbr).permute(0, 2, 3, 1)
+ x_warped.append(flow_warp(x_nbr, flow))
+ x_warped = torch.stack(x_warped, dim=1)
+
+ x = x_warped.view(B, -1, H, W)
+ x = self.relu(self.conv_3x7_64_9x9(x))
+ x = self.relu(self.conv_64_64_9x9(x))
+ x = self.relu(self.conv_64_64_1x1(x))
+ x = self.conv_64_3_1x1(x) + x_ref
+
+ return denormalize(x)
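+
+
+# Usage sketch: TOFlow expects 7 frames already at the output resolution (e.g. bicubically
+# pre-upscaled LR frames for SR) and returns the enhanced reference (4th) frame, e.g.
+#   net = TOFlow(adapt_official=False)
+#   out = net(x)  # x: [B, 7, 3, H, W] -> out: [B, 3, H, W]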
diff --git a/codes/models/archs/__init__.py b/codes/models/archs/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
new file mode 100644
index 00000000..ca5d7fa9
--- /dev/null
+++ b/codes/models/archs/arch_util.py
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+import torch.nn.functional as F
+
+
+def initialize_weights(net_l, scale=1):
+ if not isinstance(net_l, list):
+ net_l = [net_l]
+ for net in net_l:
+ for m in net.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+ m.weight.data *= scale # for residual block
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias.data, 0.0)
+
+
+def make_layer(block, n_layers):
+ layers = []
+ for _ in range(n_layers):
+ layers.append(block())
+ return nn.Sequential(*layers)
+
+
+class ResidualBlock_noBN(nn.Module):
+ '''Residual block w/o BN
+ ---Conv-ReLU-Conv-+-
+ |________________|
+ '''
+
+ def __init__(self, nf=64):
+ super(ResidualBlock_noBN, self).__init__()
+ self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+
+ # initialization
+ initialize_weights([self.conv1, self.conv2], 0.1)
+
+ def forward(self, x):
+ identity = x
+ out = F.relu(self.conv1(x), inplace=True)
+ out = self.conv2(out)
+ return identity + out
+
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
+ """Warp an image or feature map with optical flow
+ Args:
+ x (Tensor): size (N, C, H, W)
+ flow (Tensor): size (N, H, W, 2), normal value
+ interp_mode (str): 'nearest' or 'bilinear'
+ padding_mode (str): 'zeros' or 'border' or 'reflection'
+
+ Returns:
+ Tensor: warped image or feature map
+ """
+ assert x.size()[-2:] == flow.size()[1:3]
+ B, C, H, W = x.size()
+ # mesh grid
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
+ grid.requires_grad = False
+ grid = grid.type_as(x)
+ vgrid = grid + flow
+ # scale grid to [-1,1]
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
+ return output
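+
+
+# Usage sketch for flow_warp(): warp a neighboring frame (or feature map) towards a reference.
+#   x:    [N, C, H, W] tensor to warp
+#   flow: [N, H, W, 2] per-pixel (dx, dy) displacements in pixels
+#   warped = flow_warp(x, flow)  # samples x at (grid + flow) via F.grid_sample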
diff --git a/codes/models/archs/dcn/__init__.py b/codes/models/archs/dcn/__init__.py
new file mode 100644
index 00000000..1c85e1f0
--- /dev/null
+++ b/codes/models/archs/dcn/__init__.py
@@ -0,0 +1,7 @@
+from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack,
+ deform_conv, modulated_deform_conv)
+
+__all__ = [
+ 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
+ 'modulated_deform_conv'
+]
diff --git a/codes/models/archs/dcn/deform_conv.py b/codes/models/archs/dcn/deform_conv.py
new file mode 100644
index 00000000..f97cb1c8
--- /dev/null
+++ b/codes/models/archs/dcn/deform_conv.py
@@ -0,0 +1,291 @@
+import math
+import logging
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from . import deform_conv_cuda
+
+logger = logging.getLogger('base')
+
+
+class DeformConvFunction(Function):
+ @staticmethod
+ def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1,
+ deformable_groups=1, im2col_step=64):
+ if input is not None and input.dim() != 4:
+ raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(
+ input.dim()))
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.im2col_step = im2col_step
+
+ ctx.save_for_backward(input, offset, weight)
+
+ output = input.new_empty(
+ DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
+
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
+
+ if not input.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+ deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output,
+ ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0],
+ ctx.padding[1], ctx.padding[0],
+ ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, cur_im2col_step)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, weight = ctx.saved_tensors
+
+ grad_input = grad_offset = grad_weight = None
+
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ deform_conv_cuda.deform_conv_backward_input_cuda(
+ input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0],
+ weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, cur_im2col_step)
+
+ if ctx.needs_input_grad[2]:
+ grad_weight = torch.zeros_like(weight)
+ deform_conv_cuda.deform_conv_backward_parameters_cuda(
+ input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1],
+ weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, 1, cur_im2col_step)
+
+ return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
+
+ @staticmethod
+ def _output_size(input, weight, padding, dilation, stride):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = padding[d]
+ kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError("convolution input is too small (output would be {})".format('x'.join(
+ map(str, output_size))))
+ return output_size
+
+
+class ModulatedDeformConvFunction(Function):
+ @staticmethod
+ def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1,
+ groups=1, deformable_groups=1):
+ ctx.stride = stride
+ ctx.padding = padding
+ ctx.dilation = dilation
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.with_bias = bias is not None
+ if not ctx.with_bias:
+ bias = input.new_empty(1) # fake tensor
+ if not input.is_cuda:
+ raise NotImplementedError
+ if weight.requires_grad or mask.requires_grad or offset.requires_grad \
+ or input.requires_grad:
+ ctx.save_for_backward(input, offset, mask, weight, bias)
+ output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+ deform_conv_cuda.modulated_deform_conv_cuda_forward(
+ input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2],
+ weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
+ ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ input, offset, mask, weight, bias = ctx.saved_tensors
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ grad_mask = torch.zeros_like(mask)
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(bias)
+ deform_conv_cuda.modulated_deform_conv_cuda_backward(
+ input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight,
+ grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3],
+ ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ if not ctx.with_bias:
+ grad_bias = None
+
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None,
+ None)
+
+ @staticmethod
+ def _infer_shape(ctx, input, weight):
+ n = input.size(0)
+ channels_out = weight.size(0)
+ height, width = input.shape[2:4]
+ kernel_h, kernel_w = weight.shape[2:4]
+ height_out = (height + 2 * ctx.padding - (ctx.dilation *
+ (kernel_h - 1) + 1)) // ctx.stride + 1
+ width_out = (width + 2 * ctx.padding - (ctx.dilation *
+ (kernel_w - 1) + 1)) // ctx.stride + 1
+ return n, channels_out, height_out, width_out
+
+
+deform_conv = DeformConvFunction.apply
+modulated_deform_conv = ModulatedDeformConvFunction.apply
+
+
+class DeformConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
+ groups=1, deformable_groups=1, bias=False):
+ super(DeformConv, self).__init__()
+
+ assert not bias
+ assert in_channels % groups == 0, \
+            'in_channels {} is not divisible by groups {}'.format(
+ in_channels, groups)
+ assert out_channels % groups == 0, \
+            'out_channels {} is not divisible by groups {}'.format(
+ out_channels, groups)
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+
+ def forward(self, x, offset):
+ return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
+
+
+class DeformConvPack(DeformConv):
+ def __init__(self, *args, **kwargs):
+ super(DeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ offset = self.conv_offset(x)
+ return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
+ groups=1, deformable_groups=1, bias=True):
+ super(ModulatedDeformConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ self.with_bias = bias
+
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.zero_()
+
+ def forward(self, x, offset, mask):
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
+ self.padding, self.dilation, self.groups,
+ self.deformable_groups)
+
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+ def __init__(self, *args, extra_offset_mask=False, **kwargs):
+ super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+ self.extra_offset_mask = extra_offset_mask
+ self.conv_offset_mask = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset_mask.weight.data.zero_()
+ self.conv_offset_mask.bias.data.zero_()
+
+ def forward(self, x):
+ if self.extra_offset_mask:
+ # x = [input, features]
+ out = self.conv_offset_mask(x[1])
+ x = x[0]
+ else:
+ out = self.conv_offset_mask(x)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+
+ offset_mean = torch.mean(torch.abs(offset))
+ if offset_mean > 100:
+ logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean))
+
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
+ self.padding, self.dilation, self.groups,
+ self.deformable_groups)
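+
+
+# Usage sketch: with extra_offset_mask=True (as in EDVR's PCD_Align), the module is called as
+#   out = dcn([feat_to_align, offset_feat])
+# so offsets and masks are predicted from offset_feat while the deformable convolution is
+# applied to feat_to_align; with the default False, both come from the single input tensor.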
diff --git a/codes/models/archs/dcn/setup.py b/codes/models/archs/dcn/setup.py
new file mode 100644
index 00000000..094d961f
--- /dev/null
+++ b/codes/models/archs/dcn/setup.py
@@ -0,0 +1,22 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+
+def make_cuda_ext(name, sources):
+
+ return CUDAExtension(
+ name='{}'.format(name), sources=[p for p in sources], extra_compile_args={
+ 'cxx': [],
+ 'nvcc': [
+ '-D__CUDA_NO_HALF_OPERATORS__',
+ '-D__CUDA_NO_HALF_CONVERSIONS__',
+ '-D__CUDA_NO_HALF2_OPERATORS__',
+ ]
+ })
+
+
+setup(
+ name='deform_conv', ext_modules=[
+ make_cuda_ext(name='deform_conv_cuda',
+ sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu'])
+ ], cmdclass={'build_ext': BuildExtension}, zip_safe=False)
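+
+# Note: building this extension yields the `deform_conv_cuda` module imported by
+# deform_conv.py; it must be compiled with the same PyTorch/CUDA versions used at runtime.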
diff --git a/codes/models/archs/dcn/src/deform_conv_cuda.cpp b/codes/models/archs/dcn/src/deform_conv_cuda.cpp
new file mode 100644
index 00000000..c4563ed8
--- /dev/null
+++ b/codes/models/archs/dcn/src/deform_conv_cuda.cpp
@@ -0,0 +1,695 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+
+#include <cmath>
+#include <vector>
+
+void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor data_col);
+
+void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im);
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const int channels, const int height,
+ const int width, const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor grad_offset);
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor data_col);
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor grad_im);
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im,
+ const int width_im, const int height_col, const int width_col,
+    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w, const int dilation_h,
+ const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+ at::Tensor grad_mask);
+
+void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+ at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+ int padW, int dilationH, int dilationW, int group,
+ int deformable_group) {
+ AT_CHECK(weight.ndimension() == 4,
+ "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+ "but got: %s",
+ weight.ndimension());
+
+ AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ AT_CHECK(kW > 0 && kH > 0,
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+ kW);
+
+ AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+ "kernel size should be consistent with weight, ",
+ "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+ kW, weight.size(2), weight.size(3));
+
+ AT_CHECK(dW > 0 && dH > 0,
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+ AT_CHECK(
+ dilationW > 0 && dilationH > 0,
+ "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+ dilationH, dilationW);
+
+ int ndim = input.ndimension();
+ int dimf = 0;
+ int dimh = 1;
+ int dimw = 2;
+
+ if (ndim == 4) {
+ dimf++;
+ dimh++;
+ dimw++;
+ }
+
+ AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+ ndim);
+
+ long nInputPlane = weight.size(1) * group;
+ long inputHeight = input.size(dimh);
+ long inputWidth = input.size(dimw);
+ long nOutputPlane = weight.size(0);
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+ AT_CHECK(nInputPlane % deformable_group == 0,
+ "input channels must divide deformable group size");
+
+ if (outputWidth < 1 || outputHeight < 1)
+ AT_ERROR(
+ "Given input size: (%ld x %ld x %ld). "
+ "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+ nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+ outputWidth);
+
+ AT_CHECK(input.size(1) == nInputPlane,
+ "invalid number of input planes, expected: %d, but got: %d",
+ nInputPlane, input.size(1));
+
+ AT_CHECK((inputHeight >= kH && inputWidth >= kW),
+ "input image is smaller than kernel");
+
+ AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+ "invalid spatial size of offset, expected height: %d width: %d, but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+ AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+ "invalid number of channels of offset");
+
+ if (gradOutput != NULL) {
+ AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
+ "invalid number of gradOutput planes, expected: %d, but got: %d",
+ nOutputPlane, gradOutput->size(dimf));
+
+ AT_CHECK((gradOutput->size(dimh) == outputHeight &&
+ gradOutput->size(dimw) == outputWidth),
+ "invalid size of gradOutput, expected height: %d width: %d , but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, gradOutput->size(dimh),
+ gradOutput->size(dimw));
+ }
+}
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ // todo: resize columns to include im2col: done
+ // todo: add im2col_step as input
+ // todo: add new output buffer and transpose it to output (or directly
+ // transpose output)
+ // todo: possibly change data indexing because of parallel_imgs
+
+ shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input.unsqueeze_(0);
+ offset.unsqueeze_(0);
+ }
+
+ // todo: assert batchSize divisible by im2col_step
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+ outputHeight, outputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+ ones = at::ones({outputHeight, outputWidth}, input.options());
+ }
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ at::Tensor output_buffer =
+ at::zeros({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth},
+ output.options());
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), group, output_buffer.size(1) / group,
+ output_buffer.size(2), output_buffer.size(3)});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ output_buffer[elt][g] = output_buffer[elt][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output_buffer[elt][g]);
+ }
+ }
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+ output_buffer.size(3), output_buffer.size(4)});
+
+ output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step, outputHeight, outputWidth});
+ output_buffer.transpose_(1, 2);
+ output.copy_(output_buffer);
+ output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ output = output.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view({1, input.size(0), input.size(1), input.size(2)});
+ offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ // change order of grad output
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight,
+ outputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ // divide into groups
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), group, gradOutput.size(1) / group,
+ gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+ gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+ deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+ inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW, im2col_step, deformable_group,
+ gradOffset[elt]);
+
+ deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, gradInput[elt]);
+ }
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ gradOffset = gradOffset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ gradOffset =
+ gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ // todo: transpose and reshape outGrad
+ // todo: reshape columns
+ // todo: add im2col_step as input
+
+ shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+ padW, dilationH, dilationW, group, deformable_group);
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view(
+ at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = gradWeight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+ outputHeight, outputWidth});
+ gradOutputBuffer.copy_(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth});
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ // divide into group
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+ gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ gradWeight =
+ gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ gradWeight[g] = gradWeight[g]
+ .flatten(1)
+ .addmm_(gradOutputBuffer[elt][g].flatten(1),
+ columns[g].transpose(1, 0), 1.0, scale)
+ .view_as(gradWeight[g]);
+ }
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0),
+ gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+ gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3),
+ gradWeight.size(4)});
+ }
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ }
+
+ return 1;
+}
+
+void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_out = weight.size(0);
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ // resize output
+ output = output.view({batch, channels_out, height_out, width_out}).zero_();
+ // resize temporary columns
+ columns =
+ at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+ input.options());
+
+ output = output.view({output.size(0), group, output.size(1) / group,
+ output.size(2), output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ // divide into group
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+ for (int g = 0; g < group; g++) {
+ output[b][g] = output[b][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output[b][g]);
+ }
+
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ }
+
+ output = output.view({output.size(0), output.size(1) * output.size(2),
+ output.size(3), output.size(4)});
+
+ if (with_bias) {
+ output += bias.view({1, bias.size(0), 1, 1});
+ }
+}
+
+void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ grad_input = grad_input.view({batch, channels, height, width});
+ columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+ input.options());
+
+ grad_output =
+ grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+ grad_output.size(2), grad_output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ // divide into groups
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ grad_output[b][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+
+ // gradient w.r.t. input coordinate data
+ modulated_deformable_col2im_coord_cuda(
+ columns, input[b], offset[b], mask[b], 1, channels, height, width,
+ height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+ stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+ grad_mask[b]);
+ // gradient w.r.t. input data
+ modulated_deformable_col2im_cuda(
+ columns, offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+ // gradient w.r.t. weight, dWeight should accumulate across the batch and
+ // group
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+ grad_weight.size(1), grad_weight.size(2),
+ grad_weight.size(3)});
+ if (with_bias)
+ grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+ for (int g = 0; g < group; g++) {
+ grad_weight[g] =
+ grad_weight[g]
+ .flatten(1)
+ .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+ .view_as(grad_weight[g]);
+ if (with_bias) {
+ grad_bias[g] =
+ grad_bias[g]
+ .view({-1, 1})
+ .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+ .view(-1);
+ }
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+ grad_weight.size(2), grad_weight.size(3),
+ grad_weight.size(4)});
+ if (with_bias)
+ grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+ }
+ grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+ grad_output.size(2), grad_output.size(3),
+ grad_output.size(4)});
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
+ "deform forward (CUDA)");
+ m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
+ "deform_conv_backward_input (CUDA)");
+ m.def("deform_conv_backward_parameters_cuda",
+ &deform_conv_backward_parameters_cuda,
+ "deform_conv_backward_parameters (CUDA)");
+ m.def("modulated_deform_conv_cuda_forward",
+ &modulated_deform_conv_cuda_forward,
+ "modulated deform conv forward (CUDA)");
+ m.def("modulated_deform_conv_cuda_backward",
+ &modulated_deform_conv_cuda_backward,
+ "modulated deform conv backward (CUDA)");
+}
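
The functions bound above are the raw CUDA kernels; the autograd wrappers that allocate `output`, `columns` and `ones` and call them live on the Python side of `models/archs/dcn` (not shown in this hunk). As a reading aid for `shape_check` and the forward pass, the sketch below (illustrative only, the helper name is ours) reproduces the output-size and offset-channel arithmetic those checks enforce:

```python
def deform_conv_output_shape(h, w, kernel=3, stride=1, pad=1, dilation=1,
                             deformable_groups=1):
    """Mirror the arithmetic enforced by shape_check (illustrative helper)."""
    out_h = (h + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1
    out_w = (w + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1
    offset_channels = deformable_groups * 2 * kernel * kernel  # (dy, dx) per tap
    return out_h, out_w, offset_channels


# e.g. a 64x64 feature map, 3x3 deformable kernel, 8 deformable groups (the EDVR config default)
print(deform_conv_output_shape(64, 64, kernel=3, deformable_groups=8))  # (64, 64, 144)
```

With a 3x3 kernel, stride 1 and padding 1 the spatial size is preserved, and each deformable group contributes `2 * kH * kW` offset channels, which is exactly the `deformable_group * 2 * kH * kW` condition checked above.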
diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
new file mode 100644
index 00000000..27dd6a1f
--- /dev/null
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -0,0 +1,88 @@
+import torch
+import torch.nn as nn
+import torchvision
+
+
+class Discriminator_VGG_128(nn.Module):
+ def __init__(self, in_nc, nf):
+ super(Discriminator_VGG_128, self).__init__()
+ # [64, 128, 128]
+ self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+ self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
+ self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
+ # [64, 64, 64]
+ self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
+ self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
+ self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
+ self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
+ # [128, 32, 32]
+ self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
+ self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
+ self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
+ self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
+ # [256, 16, 16]
+ self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
+ self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
+ self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
+ self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
+ # [512, 8, 8]
+ self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
+ self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
+ self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
+ self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
+
+ self.linear1 = nn.Linear(512 * 4 * 4, 100)
+ self.linear2 = nn.Linear(100, 1)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ fea = self.lrelu(self.conv0_0(x))
+ fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
+
+ fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
+ fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
+
+ fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
+ fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
+
+ fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
+ fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
+
+ fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
+ fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
+
+ fea = fea.view(fea.size(0), -1)
+ fea = self.lrelu(self.linear1(fea))
+ out = self.linear2(fea)
+ return out
+
+
+class VGGFeatureExtractor(nn.Module):
+ def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
+ device=torch.device('cpu')):
+ super(VGGFeatureExtractor, self).__init__()
+ self.use_input_norm = use_input_norm
+ if use_bn:
+ model = torchvision.models.vgg19_bn(pretrained=True)
+ else:
+ model = torchvision.models.vgg19(pretrained=True)
+ if self.use_input_norm:
+ mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
+ # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
+ std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
+ # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
+ self.register_buffer('mean', mean)
+ self.register_buffer('std', std)
+ self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
+ # The extractor is frozen: no need to backprop into its parameters
+ for k, v in self.features.named_parameters():
+ v.requires_grad = False
+
+ def forward(self, x):
+ # Assume input range is [0, 1]
+ if self.use_input_norm:
+ x = (x - self.mean) / self.std
+ output = self.features(x)
+ return output
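
A minimal sketch of how `VGGFeatureExtractor` is typically used for the perceptual (feature) loss; `define_F` in `models/networks.py` below builds it with the same arguments (`feature_layer=34`, i.e. VGG19 conv5_4 before its ReLU). Tensor names are illustrative, the import assumes running from within `codes/` as the scripts do, and the first call downloads torchvision's pretrained VGG19 weights:

```python
import torch
import torch.nn as nn

from models.archs.discriminator_vgg_arch import VGGFeatureExtractor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
netF = VGGFeatureExtractor(feature_layer=34, use_bn=False,
                           use_input_norm=True, device=device).to(device).eval()
cri_fea = nn.L1Loss()

sr = torch.rand(1, 3, 128, 128, device=device)  # generator output in [0, 1]
gt = torch.rand(1, 3, 128, 128, device=device)  # ground truth in [0, 1]

with torch.no_grad():
    fea_gt = netF(gt)                # targets carry no gradient
loss_fea = cri_fea(netF(sr), fea_gt)
```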
diff --git a/codes/models/base_model.py b/codes/models/base_model.py
new file mode 100644
index 00000000..8a5d2225
--- /dev/null
+++ b/codes/models/base_model.py
@@ -0,0 +1,116 @@
+import os
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel
+
+
+class BaseModel():
+ def __init__(self, opt):
+ self.opt = opt
+ self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
+ self.is_train = opt['is_train']
+ self.schedulers = []
+ self.optimizers = []
+
+ def feed_data(self, data):
+ pass
+
+ def optimize_parameters(self):
+ pass
+
+ def get_current_visuals(self):
+ pass
+
+ def get_current_losses(self):
+ pass
+
+ def print_network(self):
+ pass
+
+ def save(self, label):
+ pass
+
+ def load(self):
+ pass
+
+ def _set_lr(self, lr_groups_l):
+ """Set learning rate for warmup
+ lr_groups_l: list for lr_groups. each for a optimizer"""
+ for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
+ for param_group, lr in zip(optimizer.param_groups, lr_groups):
+ param_group['lr'] = lr
+
+ def _get_init_lr(self):
+ """Get the initial lr, which is set by the scheduler"""
+ init_lr_groups_l = []
+ for optimizer in self.optimizers:
+ init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
+ return init_lr_groups_l
+
+ def update_learning_rate(self, cur_iter, warmup_iter=-1):
+ for scheduler in self.schedulers:
+ scheduler.step()
+ # set up warm-up learning rate
+ if cur_iter < warmup_iter:
+ # get initial lr for each group
+ init_lr_g_l = self._get_init_lr()
+ # modify warming-up learning rates
+ warm_up_lr_l = []
+ for init_lr_g in init_lr_g_l:
+ warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
+ # set learning rate
+ self._set_lr(warm_up_lr_l)
+
+ def get_current_learning_rate(self):
+ return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
+
+ def get_network_description(self, network):
+ """Get the string and total parameters of the network"""
+ if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
+ network = network.module
+ return str(network), sum(map(lambda x: x.numel(), network.parameters()))
+
+ def save_network(self, network, network_label, iter_label):
+ save_filename = '{}_{}.pth'.format(iter_label, network_label)
+ save_path = os.path.join(self.opt['path']['models'], save_filename)
+ if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
+ network = network.module
+ state_dict = network.state_dict()
+ for key, param in state_dict.items():
+ state_dict[key] = param.cpu()
+ torch.save(state_dict, save_path)
+
+ def load_network(self, load_path, network, strict=True):
+ if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
+ network = network.module
+ load_net = torch.load(load_path)
+ load_net_clean = OrderedDict() # remove unnecessary 'module.'
+ for k, v in load_net.items():
+ if k.startswith('module.'):
+ load_net_clean[k[7:]] = v
+ else:
+ load_net_clean[k] = v
+ network.load_state_dict(load_net_clean, strict=strict)
+
+ def save_training_state(self, epoch, iter_step):
+ """Save training state during training, which will be used for resuming"""
+ state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
+ for s in self.schedulers:
+ state['schedulers'].append(s.state_dict())
+ for o in self.optimizers:
+ state['optimizers'].append(o.state_dict())
+ save_filename = '{}.state'.format(iter_step)
+ save_path = os.path.join(self.opt['path']['training_state'], save_filename)
+ torch.save(state, save_path)
+
+ def resume_training(self, resume_state):
+ """Resume the optimizers and schedulers for training"""
+ resume_optimizers = resume_state['optimizers']
+ resume_schedulers = resume_state['schedulers']
+ assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
+ assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
+ for i, o in enumerate(resume_optimizers):
+ self.optimizers[i].load_state_dict(o)
+ for i, s in enumerate(resume_schedulers):
+ self.schedulers[i].load_state_dict(s)
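
`BaseModel` only provides the bookkeeping (optimizers, schedulers, warm-up, saving and resuming); concrete models override the empty hooks. The subclass below is a schematic illustration of that interface, not one of the repo's models; the `LQ`/`GT` keys follow the LQGT dataloader convention used in the configs:

```python
import torch
import torch.nn as nn

from models.base_model import BaseModel


class ToySRModel(BaseModel):
    """Minimal illustration of the BaseModel hooks; not part of the repo."""

    def __init__(self, opt, netG):
        super(ToySRModel, self).__init__(opt)
        self.netG = netG.to(self.device)
        self.cri_pix = nn.L1Loss().to(self.device)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=2e-4)
        self.optimizers.append(self.optimizer_G)

    def feed_data(self, data):
        self.var_L = data['LQ'].to(self.device)   # low-quality input
        self.real_H = data['GT'].to(self.device)  # ground-truth target

    def optimize_parameters(self):
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)
        l_pix = self.cri_pix(self.fake_H, self.real_H)
        l_pix.backward()
        self.optimizer_G.step()

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
```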
diff --git a/codes/models/loss.py b/codes/models/loss.py
new file mode 100644
index 00000000..c9f75f89
--- /dev/null
+++ b/codes/models/loss.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+
+
+class CharbonnierLoss(nn.Module):
+ """Charbonnier Loss (L1)"""
+
+ def __init__(self, eps=1e-6):
+ super(CharbonnierLoss, self).__init__()
+ self.eps = eps
+
+ def forward(self, x, y):
+ diff = x - y
+ loss = torch.sum(torch.sqrt(diff * diff + self.eps))
+ return loss
+
+
+# Define GAN loss: [gan | ragan | lsgan | wgan-gp]
+class GANLoss(nn.Module):
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
+ super(GANLoss, self).__init__()
+ self.gan_type = gan_type.lower()
+ self.real_label_val = real_label_val
+ self.fake_label_val = fake_label_val
+
+ if self.gan_type == 'gan' or self.gan_type == 'ragan':
+ self.loss = nn.BCEWithLogitsLoss()
+ elif self.gan_type == 'lsgan':
+ self.loss = nn.MSELoss()
+ elif self.gan_type == 'wgan-gp':
+
+ def wgan_loss(input, target):
+ # target is boolean
+ return -1 * input.mean() if target else input.mean()
+
+ self.loss = wgan_loss
+ else:
+ raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
+
+ def get_target_label(self, input, target_is_real):
+ if self.gan_type == 'wgan-gp':
+ return target_is_real
+ if target_is_real:
+ return torch.empty_like(input).fill_(self.real_label_val)
+ else:
+ return torch.empty_like(input).fill_(self.fake_label_val)
+
+ def forward(self, input, target_is_real):
+ target_label = self.get_target_label(input, target_is_real)
+ loss = self.loss(input, target_label)
+ return loss
+
+
+class GradientPenaltyLoss(nn.Module):
+ def __init__(self, device=torch.device('cpu')):
+ super(GradientPenaltyLoss, self).__init__()
+ self.register_buffer('grad_outputs', torch.Tensor())
+ self.grad_outputs = self.grad_outputs.to(device)
+
+ def get_grad_outputs(self, input):
+ if self.grad_outputs.size() != input.size():
+ self.grad_outputs.resize_(input.size()).fill_(1.0)
+ return self.grad_outputs
+
+ def forward(self, interp, interp_crit):
+ grad_outputs = self.get_grad_outputs(interp_crit)
+ grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
+ grad_outputs=grad_outputs, create_graph=True,
+ retain_graph=True, only_inputs=True)[0]
+ grad_interp = grad_interp.view(grad_interp.size(0), -1)
+ grad_interp_norm = grad_interp.norm(2, dim=1)
+
+ loss = ((grad_interp_norm - 1)**2).mean()
+ return loss
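
A short usage sketch for `GANLoss` (tensors are illustrative; the real wiring, including the relativistic `ragan` pairing and the interpolates fed to `GradientPenaltyLoss` for WGAN-GP, lives in the SRGAN model class):

```python
import torch

from models.loss import GANLoss

cri_gan = GANLoss('gan', real_label_val=1.0, fake_label_val=0.0)

pred_real = torch.randn(4, 1)  # discriminator logits for real samples
pred_fake = torch.randn(4, 1)  # discriminator logits for generated samples

# discriminator step: real towards the real label, fake towards the fake label
l_d = cri_gan(pred_real, True) + cri_gan(pred_fake, False)

# generator step: make the discriminator call the generated samples real
l_g = cri_gan(pred_fake, True)
```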
diff --git a/codes/models/lr_scheduler.py b/codes/models/lr_scheduler.py
new file mode 100644
index 00000000..be7a92f0
--- /dev/null
+++ b/codes/models/lr_scheduler.py
@@ -0,0 +1,144 @@
+import math
+from collections import Counter
+from collections import defaultdict
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class MultiStepLR_Restart(_LRScheduler):
+ def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
+ clear_state=False, last_epoch=-1):
+ self.milestones = Counter(milestones)
+ self.gamma = gamma
+ self.clear_state = clear_state
+ self.restarts = restarts if restarts else [0]
+ self.restarts = [v + 1 for v in self.restarts]
+ self.restart_weights = weights if weights else [1]
+ assert len(self.restarts) == len(
+ self.restart_weights), 'restarts and their weights do not match.'
+ super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch in self.restarts:
+ if self.clear_state:
+ self.optimizer.state = defaultdict(dict)
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+ return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
+ if self.last_epoch not in self.milestones:
+ return [group['lr'] for group in self.optimizer.param_groups]
+ return [
+ group['lr'] * self.gamma**self.milestones[self.last_epoch]
+ for group in self.optimizer.param_groups
+ ]
+
+
+class CosineAnnealingLR_Restart(_LRScheduler):
+ def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
+ self.T_period = T_period
+ self.T_max = self.T_period[0] # current T period
+ self.eta_min = eta_min
+ self.restarts = restarts if restarts else [0]
+ self.restarts = [v + 1 for v in self.restarts]
+ self.restart_weights = weights if weights else [1]
+ self.last_restart = 0
+ assert len(self.restarts) == len(
+ self.restart_weights), 'restarts and their weights do not match.'
+ super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch == 0:
+ return self.base_lrs
+ elif self.last_epoch in self.restarts:
+ self.last_restart = self.last_epoch
+ self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+ return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
+ elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
+ return [
+ group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
+ for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+ ]
+ return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
+ (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
+ (group['lr'] - self.eta_min) + self.eta_min
+ for group in self.optimizer.param_groups]
+
+
+if __name__ == "__main__":
+ optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
+ betas=(0.9, 0.99))
+ ##############################
+ # MultiStepLR_Restart
+ ##############################
+ ## Original
+ lr_steps = [200000, 400000, 600000, 800000]
+ restarts = None
+ restart_weights = None
+
+ ## two
+ lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
+ restarts = [500000]
+ restart_weights = [1]
+
+ ## four
+ lr_steps = [
+ 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
+ 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
+ ]
+ restarts = [250000, 500000, 750000]
+ restart_weights = [1, 1, 1]
+
+ scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
+ clear_state=False)
+
+ ##############################
+ # Cosine Annealing Restart
+ ##############################
+ ## two
+ T_period = [500000, 500000]
+ restarts = [500000]
+ restart_weights = [1]
+
+ ## four
+ T_period = [250000, 250000, 250000, 250000]
+ restarts = [250000, 500000, 750000]
+ restart_weights = [1, 1, 1]
+
+ scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
+ weights=restart_weights)
+
+ ##############################
+ # Draw figure
+ ##############################
+ N_iter = 1000000
+ lr_l = list(range(N_iter))
+ for i in range(N_iter):
+ scheduler.step()
+ current_lr = optimizer.param_groups[0]['lr']
+ lr_l[i] = current_lr
+
+ import matplotlib as mpl
+ from matplotlib import pyplot as plt
+ import matplotlib.ticker as mtick
+ mpl.style.use('default')
+ import seaborn
+ seaborn.set(style='whitegrid')
+ seaborn.set_context('paper')
+
+ plt.figure(1)
+ plt.subplot(111)
+ plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
+ plt.title('Title', fontsize=16, color='k')
+ plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
+ legend = plt.legend(loc='upper right', shadow=False)
+ ax = plt.gca()
+ labels = ax.get_xticks().tolist()
+ for k, v in enumerate(labels):
+ labels[k] = str(int(v / 1000)) + 'K'
+ ax.set_xticklabels(labels)
+ ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
+
+ ax.set_ylabel('Learning rate')
+ ax.set_xlabel('Iteration')
+ fig = plt.gcf()
+ plt.show()
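
For reference, the chained update in `CosineAnnealingLR_Restart.get_lr` realizes the usual SGDR schedule: within each period the learning rate follows

$$\eta_t = \eta_{\min} + \tfrac{1}{2}\,(\eta_{\max} - \eta_{\min})\Bigl(1 + \cos\bigl(\tfrac{T_{cur}}{T_{max}}\,\pi\bigr)\Bigr),$$

where $T_{cur}$ counts iterations since the last restart, $T_{max}$ is the `T_period` entry for the current cycle, and $\eta_{\max}$ is the initial learning rate scaled by the matching `restart_weights` entry at each restart.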
diff --git a/codes/models/networks.py b/codes/models/networks.py
new file mode 100644
index 00000000..2a679134
--- /dev/null
+++ b/codes/models/networks.py
@@ -0,0 +1,57 @@
+import torch
+import models.archs.SRResNet_arch as SRResNet_arch
+import models.archs.discriminator_vgg_arch as SRGAN_arch
+import models.archs.RRDBNet_arch as RRDBNet_arch
+import models.archs.EDVR_arch as EDVR_arch
+
+
+# Generator
+def define_G(opt):
+ opt_net = opt['network_G']
+ which_model = opt_net['which_model_G']
+
+ # image restoration
+ if which_model == 'MSRResNet':
+ netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
+ nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
+ elif which_model == 'RRDBNet':
+ netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
+ nf=opt_net['nf'], nb=opt_net['nb'])
+ # video restoration
+ elif which_model == 'EDVR':
+ netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
+ groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
+ back_RBs=opt_net['back_RBs'], center=opt_net['center'],
+ predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
+ w_TSA=opt_net['w_TSA'])
+ else:
+ raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
+
+ return netG
+
+
+# Discriminator
+def define_D(opt):
+ opt_net = opt['network_D']
+ which_model = opt_net['which_model_D']
+
+ if which_model == 'discriminator_vgg_128':
+ netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
+ else:
+ raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
+ return netD
+
+
+# Define network used for perceptual loss
+def define_F(opt, use_bn=False):
+ gpu_ids = opt['gpu_ids']
+ device = torch.device('cuda' if gpu_ids else 'cpu')
+ # PyTorch pretrained VGG19-54, before ReLU.
+ if use_bn:
+ feature_layer = 49
+ else:
+ feature_layer = 34
+ netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
+ use_input_norm=True, device=device)
+ netF.eval() # No need to train
+ return netF
diff --git a/codes/options/__init__.py b/codes/options/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/options/options.py b/codes/options/options.py
new file mode 100644
index 00000000..99181b34
--- /dev/null
+++ b/codes/options/options.py
@@ -0,0 +1,116 @@
+import os
+import os.path as osp
+import logging
+import yaml
+from utils.util import OrderedYaml
+Loader, Dumper = OrderedYaml()
+
+
+def parse(opt_path, is_train=True):
+ with open(opt_path, mode='r') as f:
+ opt = yaml.load(f, Loader=Loader)
+ # export CUDA_VISIBLE_DEVICES
+ gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
+ os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
+ print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
+
+ opt['is_train'] = is_train
+ if opt['distortion'] == 'sr':
+ scale = opt['scale']
+
+ # datasets
+ for phase, dataset in opt['datasets'].items():
+ phase = phase.split('_')[0]
+ dataset['phase'] = phase
+ if opt['distortion'] == 'sr':
+ dataset['scale'] = scale
+ is_lmdb = False
+ if dataset.get('dataroot_GT', None) is not None:
+ dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
+ if dataset['dataroot_GT'].endswith('lmdb'):
+ is_lmdb = True
+ if dataset.get('dataroot_LQ', None) is not None:
+ dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
+ if dataset['dataroot_LQ'].endswith('lmdb'):
+ is_lmdb = True
+ dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
+ if dataset['mode'].endswith('mc'): # for memcached
+ dataset['data_type'] = 'mc'
+ dataset['mode'] = dataset['mode'].replace('_mc', '')
+
+ # path
+ for key, path in opt['path'].items():
+ if path and key in opt['path'] and key != 'strict_load':
+ opt['path'][key] = osp.expanduser(path)
+ opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
+ if is_train:
+ experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
+ opt['path']['experiments_root'] = experiments_root
+ opt['path']['models'] = osp.join(experiments_root, 'models')
+ opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
+ opt['path']['log'] = experiments_root
+ opt['path']['val_images'] = osp.join(experiments_root, 'val_images')
+
+ # change some options for debug mode
+ if 'debug' in opt['name']:
+ opt['train']['val_freq'] = 8
+ opt['logger']['print_freq'] = 1
+ opt['logger']['save_checkpoint_freq'] = 8
+ else: # test
+ results_root = osp.join(opt['path']['root'], 'results', opt['name'])
+ opt['path']['results_root'] = results_root
+ opt['path']['log'] = results_root
+
+ # network
+ if opt['distortion'] == 'sr':
+ opt['network_G']['scale'] = scale
+
+ return opt
+
+
+def dict2str(opt, indent_l=1):
+ '''dict to string for logger'''
+ msg = ''
+ for k, v in opt.items():
+ if isinstance(v, dict):
+ msg += ' ' * (indent_l * 2) + k + ':[\n'
+ msg += dict2str(v, indent_l + 1)
+ msg += ' ' * (indent_l * 2) + ']\n'
+ else:
+ msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
+ return msg
+
+
+class NoneDict(dict):
+ def __missing__(self, key):
+ return None
+
+
+# Convert to NoneDict, which returns None for missing keys.
+def dict_to_nonedict(opt):
+ if isinstance(opt, dict):
+ new_opt = dict()
+ for key, sub_opt in opt.items():
+ new_opt[key] = dict_to_nonedict(sub_opt)
+ return NoneDict(**new_opt)
+ elif isinstance(opt, list):
+ return [dict_to_nonedict(sub_opt) for sub_opt in opt]
+ else:
+ return opt
+
+
+def check_resume(opt, resume_iter):
+ '''Check resume states and pretrain_model paths'''
+ logger = logging.getLogger('base')
+ if opt['path']['resume_state']:
+ if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
+ 'pretrain_model_D', None) is not None:
+ logger.warning('pretrain_model path will be ignored when resuming training.')
+
+ opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
+ '{}_G.pth'.format(resume_iter))
+ logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
+ if 'gan' in opt['model']:
+ opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
+ '{}_D.pth'.format(resume_iter))
+ logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
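
How the entry points consume this module (see `codes/test.py` later in this diff for the real call sites); the point of `dict_to_nonedict` is that absent options read as `None` instead of raising `KeyError`. The path and the unset key below are illustrative:

```python
import options.options as option

opt = option.parse('options/train/train_SRResNet.yml', is_train=True)
opt = option.dict_to_nonedict(opt)

print(opt['train']['pixel_criterion'])    # 'l1' from the YAML
print(opt['train']['some_unset_option'])  # None, no KeyError (hypothetical key)
```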
diff --git a/codes/options/test/test_ESRGAN.yml b/codes/options/test/test_ESRGAN.yml
new file mode 100644
index 00000000..3522f217
--- /dev/null
+++ b/codes/options/test/test_ESRGAN.yml
@@ -0,0 +1,32 @@
+name: RRDB_ESRGAN_x4
+suffix: ~ # add suffix to saved images
+model: sr
+distortion: sr
+scale: 4
+crop_border: ~ # crop border during evaluation. If None (~), crop `scale` pixels
+gpu_ids: [0]
+
+datasets:
+ test_1: # the 1st test dataset
+ name: set5
+ mode: LQGT
+ dataroot_GT: ../datasets/val_set5/Set5
+ dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
+ test_2: # the 2nd test dataset
+ name: set14
+ mode: LQGT
+ dataroot_GT: ../datasets/val_set14/Set14
+ dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
+
+#### network structures
+network_G:
+ which_model_G: RRDBNet
+ in_nc: 3
+ out_nc: 3
+ nf: 64
+ nb: 23
+ upscale: 4
+
+#### path
+path:
+ pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth
diff --git a/codes/options/test/test_SRGAN.yml b/codes/options/test/test_SRGAN.yml
new file mode 100644
index 00000000..21eea625
--- /dev/null
+++ b/codes/options/test/test_SRGAN.yml
@@ -0,0 +1,32 @@
+name: MSRGANx4
+suffix: ~ # add suffix to saved images
+model: sr
+distortion: sr
+scale: 4
+crop_border: ~ # crop border during evaluation. If None (~), crop `scale` pixels
+gpu_ids: [0]
+
+datasets:
+ test_1: # the 1st test dataset
+ name: set5
+ mode: LQGT
+ dataroot_GT: ../datasets/val_set5/Set5
+ dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
+ test_2: # the 2nd test dataset
+ name: set14
+ mode: LQGT
+ dataroot_GT: ../datasets/val_set14/Set14
+ dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
+
+#### network structures
+network_G:
+ which_model_G: MSRResNet
+ in_nc: 3
+ out_nc: 3
+ nf: 64
+ nb: 16
+ upscale: 4
+
+#### path
+path:
+ pretrain_model_G: ../experiments/pretrained_models/MSRGANx4.pth
diff --git a/codes/options/test/test_SRResNet.yml b/codes/options/test/test_SRResNet.yml
new file mode 100644
index 00000000..b30b3b44
--- /dev/null
+++ b/codes/options/test/test_SRResNet.yml
@@ -0,0 +1,48 @@
+name: MSRResNetx4
+suffix: ~ # add suffix to saved images
+model: sr
+distortion: sr
+scale: 4
+crop_border: ~ # crop border during evaluation. If None (~), crop `scale` pixels
+gpu_ids: [0]
+
+datasets:
+ test_1: # the 1st test dataset
+ name: set5
+ mode: LQGT
+ dataroot_GT: ../datasets/val_set5/Set5
+ dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
+ test_2: # the 2nd test dataset
+ name: set14
+ mode: LQGT
+ dataroot_GT: ../datasets/val_set14/Set14
+ dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
+ test_3:
+ name: bsd100
+ mode: LQGT
+ dataroot_GT: ../datasets/BSD/BSDS100
+ dataroot_LQ: ../datasets/BSD/BSDS100_bicLRx4
+ test_4:
+ name: urban100
+ mode: LQGT
+ dataroot_GT: ../datasets/urban100
+ dataroot_LQ: ../datasets/urban100_bicLRx4
+ test_5:
+ name: div2k100
+ mode: LQGT
+ dataroot_GT: ../datasets/DIV2K100/DIV2K_valid_HR
+ dataroot_LQ: ../datasets/DIV2K100/DIV2K_valid_bicLRx4
+
+
+#### network structures
+network_G:
+ which_model_G: MSRResNet
+ in_nc: 3
+ out_nc: 3
+ nf: 64
+ nb: 16
+ upscale: 4
+
+#### path
+path:
+ pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
diff --git a/codes/options/train/train_EDVR_M.yml b/codes/options/train/train_EDVR_M.yml
new file mode 100644
index 00000000..ed0916c0
--- /dev/null
+++ b/codes/options/train/train_EDVR_M.yml
@@ -0,0 +1,80 @@
+#### general settings
+name: 002_EDVR_EDVRwoTSAIni_lr4e-4_600k_REDS_LrCAR4S_fixTSA50k_new
+use_tb_logger: true
+model: video_base
+distortion: sr
+scale: 4
+gpu_ids: [0,1,2,3,4,5,6,7]
+
+#### datasets
+datasets:
+ train:
+ name: REDS
+ mode: REDS
+ interval_list: [1]
+ random_reverse: false
+ border_mode: false
+ dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
+ dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
+ cache_keys: ~
+
+ N_frames: 5
+ use_shuffle: true
+ n_workers: 3 # per GPU
+ batch_size: 32
+ GT_size: 256
+ LQ_size: 64
+ use_flip: true
+ use_rot: true
+ color: RGB
+ val:
+ name: REDS4
+ mode: video_test
+ dataroot_GT: ../datasets/REDS4/GT
+ dataroot_LQ: ../datasets/REDS4/sharp_bicubic
+ cache_data: True
+ N_frames: 5
+ padding: new_info
+
+#### network structures
+network_G:
+ which_model_G: EDVR
+ nf: 64
+ nframes: 5
+ groups: 8
+ front_RBs: 5
+ back_RBs: 10
+ predeblur: false
+ HR_in: false
+ w_TSA: true
+
+#### path
+path:
+ pretrain_model_G: ../experiments/pretrained_models/EDVR_REDS_SR_M_woTSA.pth
+ strict_load: false
+ resume_state: ~
+
+#### training settings: learning rate scheme, loss
+train:
+ lr_G: !!float 4e-4
+ lr_scheme: CosineAnnealingLR_Restart
+ beta1: 0.9
+ beta2: 0.99
+ niter: 600000
+ ft_tsa_only: 50000
+ warmup_iter: -1 # -1: no warm up
+ T_period: [50000, 100000, 150000, 150000, 150000]
+ restarts: [50000, 150000, 300000, 450000]
+ restart_weights: [1, 1, 1, 1]
+ eta_min: !!float 1e-7
+
+ pixel_criterion: cb
+ pixel_weight: 1.0
+ val_freq: !!float 5e3
+
+ manual_seed: 0
+
+#### logger
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
diff --git a/codes/options/train/train_EDVR_woTSA_M.yml b/codes/options/train/train_EDVR_woTSA_M.yml
new file mode 100644
index 00000000..9f48573c
--- /dev/null
+++ b/codes/options/train/train_EDVR_woTSA_M.yml
@@ -0,0 +1,71 @@
+#### general settings
+name: 001_EDVRwoTSA_scratch_lr4e-4_600k_REDS_LrCAR4S
+use_tb_logger: true
+model: video_base
+distortion: sr
+scale: 4
+gpu_ids: [0,1,2,3,4,5,6,7]
+
+#### datasets
+datasets:
+ train:
+ name: REDS
+ mode: REDS
+ interval_list: [1]
+ random_reverse: false
+ border_mode: false
+ dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
+ dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
+ cache_keys: ~
+
+ N_frames: 5
+ use_shuffle: true
+ n_workers: 3 # per GPU
+ batch_size: 32
+ GT_size: 256
+ LQ_size: 64
+ use_flip: true
+ use_rot: true
+ color: RGB
+
+#### network structures
+network_G:
+ which_model_G: EDVR
+ nf: 64
+ nframes: 5
+ groups: 8
+ front_RBs: 5
+ back_RBs: 10
+ predeblur: false
+ HR_in: false
+ w_TSA: false
+
+#### path
+path:
+ pretrain_model_G: ~
+ strict_load: true
+ resume_state: ~
+
+#### training settings: learning rate scheme, loss
+train:
+ lr_G: !!float 4e-4
+ lr_scheme: CosineAnnealingLR_Restart
+ beta1: 0.9
+ beta2: 0.99
+ niter: 600000
+ warmup_iter: -1 # -1: no warm up
+ T_period: [150000, 150000, 150000, 150000]
+ restarts: [150000, 300000, 450000]
+ restart_weights: [1, 1, 1]
+ eta_min: !!float 1e-7
+
+ pixel_criterion: cb
+ pixel_weight: 1.0
+ val_freq: !!float 5e3
+
+ manual_seed: 0
+
+#### logger
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
diff --git a/codes/options/train/train_ESRGAN.yml b/codes/options/train/train_ESRGAN.yml
new file mode 100644
index 00000000..720f8652
--- /dev/null
+++ b/codes/options/train/train_ESRGAN.yml
@@ -0,0 +1,81 @@
+#### general settings
+name: 003_RRDB_ESRGANx4_DIV2K
+use_tb_logger: true
+model: srgan
+distortion: sr
+scale: 4
+gpu_ids: [2]
+
+#### datasets
+datasets:
+ train:
+ name: DIV2K
+ mode: LQGT
+ dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
+ dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
+
+ use_shuffle: true
+ n_workers: 6 # per GPU
+ batch_size: 16
+ GT_size: 128
+ use_flip: true
+ use_rot: true
+ color: RGB
+ val:
+ name: val_set14
+ mode: LQGT
+ dataroot_GT: ../datasets/val_set14/Set14
+ dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
+
+#### network structures
+network_G:
+ which_model_G: RRDBNet
+ in_nc: 3
+ out_nc: 3
+ nf: 64
+ nb: 23
+network_D:
+ which_model_D: discriminator_vgg_128
+ in_nc: 3
+ nf: 64
+
+#### path
+path:
+ pretrain_model_G: ../experiments/pretrained_models/RRDB_PSNR_x4.pth
+ strict_load: true
+ resume_state: ~
+
+#### training settings: learning rate scheme, loss
+train:
+ lr_G: !!float 1e-4
+ weight_decay_G: 0
+ beta1_G: 0.9
+ beta2_G: 0.99
+ lr_D: !!float 1e-4
+ weight_decay_D: 0
+ beta1_D: 0.9
+ beta2_D: 0.99
+ lr_scheme: MultiStepLR
+
+ niter: 400000
+ warmup_iter: -1 # no warm up
+ lr_steps: [50000, 100000, 200000, 300000]
+ lr_gamma: 0.5
+
+ pixel_criterion: l1
+ pixel_weight: !!float 1e-2
+ feature_criterion: l1
+ feature_weight: 1
+ gan_type: ragan # gan | ragan
+ gan_weight: !!float 5e-3
+
+ D_update_ratio: 1
+ D_init_iters: 0
+
+ manual_seed: 10
+ val_freq: !!float 5e3
+
+#### logger
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
diff --git a/codes/options/train/train_SRGAN.yml b/codes/options/train/train_SRGAN.yml
new file mode 100644
index 00000000..8aa48727
--- /dev/null
+++ b/codes/options/train/train_SRGAN.yml
@@ -0,0 +1,85 @@
+# Not exactly the same as the original SRGAN:
+# uses 16 residual blocks w/o BN
+
+#### general settings
+name: 002_SRGANx4_MSRResNetx4Ini_DIV2K
+use_tb_logger: true
+model: srgan
+distortion: sr
+scale: 4
+gpu_ids: [1]
+
+#### datasets
+datasets:
+ train:
+ name: DIV2K
+ mode: LQGT
+ dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
+ dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
+
+ use_shuffle: true
+ n_workers: 6 # per GPU
+ batch_size: 16
+ GT_size: 128
+ use_flip: true
+ use_rot: true
+ color: RGB
+ val:
+ name: val_set14
+ mode: LQGT
+ dataroot_GT: ../datasets/val_set14/Set14
+ dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
+
+#### network structures
+network_G:
+ which_model_G: MSRResNet
+ in_nc: 3
+ out_nc: 3
+ nf: 64
+ nb: 16
+ upscale: 4
+network_D:
+ which_model_D: discriminator_vgg_128
+ in_nc: 3
+ nf: 64
+
+#### path
+path:
+ pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
+ strict_load: true
+ resume_state: ~
+
+#### training settings: learning rate scheme, loss
+train:
+ lr_G: !!float 1e-4
+ weight_decay_G: 0
+ beta1_G: 0.9
+ beta2_G: 0.99
+ lr_D: !!float 1e-4
+ weight_decay_D: 0
+ beta1_D: 0.9
+ beta2_D: 0.99
+ lr_scheme: MultiStepLR
+
+ niter: 400000
+ warmup_iter: -1 # no warm up
+ lr_steps: [50000, 100000, 200000, 300000]
+ lr_gamma: 0.5
+
+ pixel_criterion: l1
+ pixel_weight: !!float 1e-2
+ feature_criterion: l1
+ feature_weight: 1
+ gan_type: gan # gan | ragan
+ gan_weight: !!float 5e-3
+
+ D_update_ratio: 1
+ D_init_iters: 0
+
+ manual_seed: 10
+ val_freq: !!float 5e3
+
+#### logger
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
diff --git a/codes/options/train/train_SRResNet.yml b/codes/options/train/train_SRResNet.yml
new file mode 100644
index 00000000..be6fd4da
--- /dev/null
+++ b/codes/options/train/train_SRResNet.yml
@@ -0,0 +1,70 @@
+# Not exactly the same as the original SRResNet:
+# uses 16 residual blocks w/o BN
+
+#### general settings
+name: 001_MSRResNetx4_scratch_DIV2K
+use_tb_logger: true
+model: sr
+distortion: sr
+scale: 4
+gpu_ids: [0]
+
+#### datasets
+datasets:
+ train:
+ name: DIV2K
+ mode: LQGT
+ dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
+ dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
+
+ use_shuffle: true
+ n_workers: 6 # per GPU
+ batch_size: 16
+ GT_size: 128
+ use_flip: true
+ use_rot: true
+ color: RGB
+ val:
+ name: val_set5
+ mode: LQGT
+ dataroot_GT: ../datasets/val_set5/Set5
+ dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
+
+#### network structures
+network_G:
+ which_model_G: MSRResNet
+ in_nc: 3
+ out_nc: 3
+ nf: 64
+ nb: 16
+ upscale: 4
+
+#### path
+path:
+ pretrain_model_G: ~
+ strict_load: true
+ resume_state: ~
+
+#### training settings: learning rate scheme, loss
+train:
+ lr_G: !!float 2e-4
+ lr_scheme: CosineAnnealingLR_Restart
+ beta1: 0.9
+ beta2: 0.99
+ niter: 1000000
+ warmup_iter: -1 # no warm up
+ T_period: [250000, 250000, 250000, 250000]
+ restarts: [250000, 500000, 750000]
+ restart_weights: [1, 1, 1]
+ eta_min: !!float 1e-7
+
+ pixel_criterion: l1
+ pixel_weight: 1.0
+
+ manual_seed: 10
+ val_freq: !!float 5e3
+
+#### logger
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
diff --git a/codes/run_scripts.sh b/codes/run_scripts.sh
new file mode 100644
index 00000000..3e7c4945
--- /dev/null
+++ b/codes/run_scripts.sh
@@ -0,0 +1,10 @@
+# single GPU training (image SR)
+python train.py -opt options/train/train_SRResNet.yml
+python train.py -opt options/train/train_SRGAN.yml
+python train.py -opt options/train/train_ESRGAN.yml
+
+
+# distributed training (video SR)
+# 8 GPUs
+python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_woTSA_M.yml --launcher pytorch
+python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_M.yml --launcher pytorch
\ No newline at end of file
diff --git a/codes/scripts/back_projection/backprojection.m b/codes/scripts/back_projection/backprojection.m
new file mode 100644
index 00000000..496d93f4
--- /dev/null
+++ b/codes/scripts/back_projection/backprojection.m
@@ -0,0 +1,20 @@
+function [im_h] = backprojection(im_h, im_l, maxIter)
+
+[row_l, col_l,~] = size(im_l);
+[row_h, col_h,~] = size(im_h);
+
+p = fspecial('gaussian', 5, 1);
+p = p.^2;
+p = p./sum(p(:));
+
+im_l = double(im_l);
+im_h = double(im_h);
+
+for ii = 1:maxIter
+ im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
+ im_diff = im_l - im_l_s;
+ im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
+ im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
+ im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
+ im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
+end
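
Each iteration of the loop above performs the classic iterative back-projection update, applied channel by channel:

$$I_h^{(t+1)} = I_h^{(t)} + p \ast \uparrow_{\mathrm{bic}}\!\bigl(I_l - \downarrow_{\mathrm{bic}} I_h^{(t)}\bigr),$$

where $\downarrow_{\mathrm{bic}}$ and $\uparrow_{\mathrm{bic}}$ are bicubic resizes between the LR and HR grids, and $p$ is the squared, re-normalized 5×5 Gaussian built at the top of the function.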
diff --git a/codes/scripts/back_projection/main_bp.m b/codes/scripts/back_projection/main_bp.m
new file mode 100644
index 00000000..40c137ed
--- /dev/null
+++ b/codes/scripts/back_projection/main_bp.m
@@ -0,0 +1,22 @@
+clear; close all; clc;
+
+LR_folder = './LR'; % LR
+preout_folder = './results'; % pre output
+save_folder = './results_20bp';
+filepaths = dir(fullfile(preout_folder, '*.png'));
+max_iter = 20;
+
+if ~ exist(save_folder, 'dir')
+ mkdir(save_folder);
+end
+
+for idx_im = 1:length(filepaths)
+ fprintf([num2str(idx_im) '\n']);
+ im_name = filepaths(idx_im).name;
+ im_LR = im2double(imread(fullfile(LR_folder, im_name)));
+ im_out = im2double(imread(fullfile(preout_folder, im_name)));
+ %tic
+ im_out = backprojection(im_out, im_LR, max_iter);
+ %toc
+ imwrite(im_out, fullfile(save_folder, im_name));
+end
diff --git a/codes/scripts/back_projection/main_reverse_filter.m b/codes/scripts/back_projection/main_reverse_filter.m
new file mode 100644
index 00000000..63f2edcf
--- /dev/null
+++ b/codes/scripts/back_projection/main_reverse_filter.m
@@ -0,0 +1,25 @@
+clear; close all; clc;
+
+LR_folder = './LR'; % LR
+preout_folder = './results'; % pre output
+save_folder = './results_20if';
+filepaths = dir(fullfile(preout_folder, '*.png'));
+max_iter = 20;
+
+if ~ exist(save_folder, 'dir')
+ mkdir(save_folder);
+end
+
+for idx_im = 1:length(filepaths)
+ fprintf([num2str(idx_im) '\n']);
+ im_name = filepaths(idx_im).name;
+ im_LR = im2double(imread(fullfile(LR_folder, im_name)));
+ im_out = im2double(imread(fullfile(preout_folder, im_name)));
+ J = imresize(im_LR,4,'bicubic');
+ %tic
+ for m = 1:max_iter
+ im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
+ end
+ %toc
+ imwrite(im_out, fullfile(save_folder, im_name));
+end
diff --git a/codes/scripts/transfer_params_MSRResNet.py b/codes/scripts/transfer_params_MSRResNet.py
new file mode 100644
index 00000000..70dafa4d
--- /dev/null
+++ b/codes/scripts/transfer_params_MSRResNet.py
@@ -0,0 +1,27 @@
+import os.path as osp
+import sys
+import torch
+try:
+ sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
+ import models.archs.SRResNet_arch as SRResNet_arch
+except ImportError:
+ pass
+
+pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth')
+crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
+crt_net = crt_model.state_dict()
+
+for k, v in crt_net.items():
+ if k in pretrained_net and 'upconv1' not in k:
+ crt_net[k] = pretrained_net[k]
+ print('replace ... ', k)
+
+# x4 -> x3
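+# The pretrained x4 model's upconv1 outputs nf*4 = 256 channels (feeding a 2x pixel-shuffle),
+# while the x3 model needs nf*9 = 576 (for a 3x pixel-shuffle). The 256 pretrained filters are
+# tiled into the first 512 slots and the first 64 reused for the remainder, halved so the
+# initial activation scale stays close to the pretrained one (an initialization heuristic;
+# the resulting weights are only a starting point for fine-tuning).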
+crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2
+crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2
+crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2
+crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2
+crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2
+crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2
+
+torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')
diff --git a/codes/test.py b/codes/test.py
new file mode 100644
index 00000000..b07a44b7
--- /dev/null
+++ b/codes/test.py
@@ -0,0 +1,105 @@
+import os.path as osp
+import logging
+import time
+import argparse
+from collections import OrderedDict
+
+import options.options as option
+import utils.util as util
+from data.util import bgr2ycbcr
+from data import create_dataset, create_dataloader
+from models import create_model
+
+#### options
+parser = argparse.ArgumentParser()
+parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.')
+opt = option.parse(parser.parse_args().opt, is_train=False)
+opt = option.dict_to_nonedict(opt)
+
+util.mkdirs(
+ (path for key, path in opt['path'].items()
+ if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
+util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
+ screen=True, tofile=True)
+logger = logging.getLogger('base')
+logger.info(option.dict2str(opt))
+
+#### Create test dataset and dataloader
+test_loaders = []
+for phase, dataset_opt in sorted(opt['datasets'].items()):
+ test_set = create_dataset(dataset_opt)
+ test_loader = create_dataloader(test_set, dataset_opt)
+ logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
+ test_loaders.append(test_loader)
+
+model = create_model(opt)
+for test_loader in test_loaders:
+ test_set_name = test_loader.dataset.opt['name']
+ logger.info('\nTesting [{:s}]...'.format(test_set_name))
+ test_start_time = time.time()
+ dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
+ util.mkdir(dataset_dir)
+
+ test_results = OrderedDict()
+ test_results['psnr'] = []
+ test_results['ssim'] = []
+ test_results['psnr_y'] = []
+ test_results['ssim_y'] = []
+
+ for data in test_loader:
+        need_GT = test_loader.dataset.opt['dataroot_GT'] is not None
+ model.feed_data(data, need_GT=need_GT)
+ img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
+ img_name = osp.splitext(osp.basename(img_path))[0]
+
+ model.test()
+ visuals = model.get_current_visuals(need_GT=need_GT)
+
+ sr_img = util.tensor2img(visuals['rlt']) # uint8
+
+ # save images
+ suffix = opt['suffix']
+ if suffix:
+ save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
+ else:
+ save_img_path = osp.join(dataset_dir, img_name + '.png')
+ util.save_img(sr_img, save_img_path)
+
+ # calculate PSNR and SSIM
+ if need_GT:
+ gt_img = util.tensor2img(visuals['GT'])
+ sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+ psnr = util.calculate_psnr(sr_img, gt_img)
+ ssim = util.calculate_ssim(sr_img, gt_img)
+ test_results['psnr'].append(psnr)
+ test_results['ssim'].append(ssim)
+
+ if gt_img.shape[2] == 3: # RGB image
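+                # also report metrics on the Y (luma) channel, the usual convention in SR papers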
+ sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
+ gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)
+
+ psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
+ ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
+ test_results['psnr_y'].append(psnr_y)
+ test_results['ssim_y'].append(ssim_y)
+ logger.info(
+ '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.
+ format(img_name, psnr, ssim, psnr_y, ssim_y))
+ else:
+ logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
+ else:
+ logger.info(img_name)
+
+ if need_GT: # metrics
+ # Average PSNR/SSIM results
+ ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
+ ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
+ logger.info(
+ '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format(
+ test_set_name, ave_psnr, ave_ssim))
+ if test_results['psnr_y'] and test_results['ssim_y']:
+ ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
+ ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
+ logger.info(
+ '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'.
+ format(ave_psnr_y, ave_ssim_y))
diff --git a/codes/test_Vid4_REDS4_with_GT.py b/codes/test_Vid4_REDS4_with_GT.py
new file mode 100644
index 00000000..7c90c493
--- /dev/null
+++ b/codes/test_Vid4_REDS4_with_GT.py
@@ -0,0 +1,208 @@
+'''
+Test Vid4 (SR) and REDS4 (SR-clean, SR-blur, deblur-clean, deblur-compression) datasets
+'''
+
+import os
+import os.path as osp
+import glob
+import logging
+import numpy as np
+import cv2
+import torch
+
+import utils.util as util
+import data.util as data_util
+import models.archs.EDVR_arch as EDVR_arch
+
+
+def main():
+ #################
+ # configurations
+ #################
+ device = torch.device('cuda')
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+ data_mode = 'Vid4' # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
+ # Vid4: SR
+ # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
+ # blur (deblur-clean), blur_comp (deblur-compression).
+ stage = 1 # 1 or 2, use two stage strategy for REDS dataset.
+ flip_test = False
+ ############################################################################
+ #### model
+ if data_mode == 'Vid4':
+ if stage == 1:
+ model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
+ else:
+ raise ValueError('Vid4 does not support stage 2.')
+ elif data_mode == 'sharp_bicubic':
+ if stage == 1:
+ model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
+ else:
+ model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
+ elif data_mode == 'blur_bicubic':
+ if stage == 1:
+ model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
+ else:
+ model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
+ elif data_mode == 'blur':
+ if stage == 1:
+ model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
+ else:
+ model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
+ elif data_mode == 'blur_comp':
+ if stage == 1:
+ model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
+ else:
+ model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
+ else:
+ raise NotImplementedError
+
+ if data_mode == 'Vid4':
+ N_in = 7 # use N_in images to restore one HR image
+ else:
+ N_in = 5
+
+ predeblur, HR_in = False, False
+ back_RBs = 40
+ if data_mode == 'blur_bicubic':
+ predeblur = True
+ if data_mode == 'blur' or data_mode == 'blur_comp':
+ predeblur, HR_in = True, True
+ if stage == 2:
+ HR_in = True
+ back_RBs = 20
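+    # EDVR(nf, nframes, groups, front_RBs, back_RBs, ...): 128 feature channels, N_in input
+    # frames, 8 deformable-conv groups, 5 residual blocks for feature extraction and
+    # back_RBs residual blocks for reconstruction.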
+ model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
+
+ #### dataset
+ if data_mode == 'Vid4':
+ test_dataset_folder = '../datasets/Vid4/BIx4'
+ GT_dataset_folder = '../datasets/Vid4/GT'
+ else:
+ if stage == 1:
+ test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
+ else:
+ test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
+ print('You should modify the test_dataset_folder path for stage 2')
+ GT_dataset_folder = '../datasets/REDS4/GT'
+
+ #### evaluation
+ crop_border = 0
+    border_frame = N_in // 2  # border frames when evaluating
+ # temporal padding mode
+ if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
+ padding = 'new_info'
+ else:
+ padding = 'replicate'
+ save_imgs = True
+
+ save_folder = '../results/{}'.format(data_mode)
+ util.mkdirs(save_folder)
+ util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
+ logger = logging.getLogger('base')
+
+ #### log info
+ logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
+ logger.info('Padding mode: {}'.format(padding))
+ logger.info('Model path: {}'.format(model_path))
+ logger.info('Save images: {}'.format(save_imgs))
+ logger.info('Flip test: {}'.format(flip_test))
+
+ #### set up the models
+ model.load_state_dict(torch.load(model_path), strict=True)
+ model.eval()
+ model = model.to(device)
+
+ avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
+ subfolder_name_l = []
+
+ subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
+ subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
+ # for each subfolder
+ for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
+ subfolder_name = osp.basename(subfolder)
+ subfolder_name_l.append(subfolder_name)
+ save_subfolder = osp.join(save_folder, subfolder_name)
+
+ img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
+ max_idx = len(img_path_l)
+ if save_imgs:
+ util.mkdirs(save_subfolder)
+
+ #### read LQ and GT images
+ imgs_LQ = data_util.read_img_seq(subfolder)
+ img_GT_l = []
+ for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
+ img_GT_l.append(data_util.read_img(None, img_GT_path))
+
+ avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0
+
+ # process each image
+ for img_idx, img_path in enumerate(img_path_l):
+ img_name = osp.splitext(osp.basename(img_path))[0]
+ select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
+ imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
+
+ if flip_test:
+ output = util.flipx4_forward(model, imgs_in)
+ else:
+ output = util.single_forward(model, imgs_in)
+ output = util.tensor2img(output.squeeze(0))
+
+ if save_imgs:
+ cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)
+
+ # calculate PSNR
+ output = output / 255.
+ GT = np.copy(img_GT_l[img_idx])
+ # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
+ if data_mode == 'Vid4': # bgr2y, [0, 1]
+ GT = data_util.bgr2ycbcr(GT, only_y=True)
+ output = data_util.bgr2ycbcr(output, only_y=True)
+
+ output, GT = util.crop_border([output, GT], crop_border)
+ crt_psnr = util.calculate_psnr(output * 255, GT * 255)
+ logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))
+
+ if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
+ avg_psnr_center += crt_psnr
+ N_center += 1
+ else: # border frames
+ avg_psnr_border += crt_psnr
+ N_border += 1
+
+ avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
+ avg_psnr_center = avg_psnr_center / N_center
+ avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
+ avg_psnr_l.append(avg_psnr)
+ avg_psnr_center_l.append(avg_psnr_center)
+ avg_psnr_border_l.append(avg_psnr_border)
+
+ logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
+ 'Center PSNR: {:.6f} dB for {} frames; '
+ 'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
+ (N_center + N_border),
+ avg_psnr_center, N_center,
+ avg_psnr_border, N_border))
+
+ logger.info('################ Tidy Outputs ################')
+ for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
+ avg_psnr_center_l, avg_psnr_border_l):
+ logger.info('Folder {} - Average PSNR: {:.6f} dB. '
+ 'Center PSNR: {:.6f} dB. '
+ 'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
+ psnr_border))
+ logger.info('################ Final Results ################')
+ logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
+ logger.info('Padding mode: {}'.format(padding))
+ logger.info('Model path: {}'.format(model_path))
+ logger.info('Save images: {}'.format(save_imgs))
+ logger.info('Flip test: {}'.format(flip_test))
+ logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
+ 'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
+ sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
+ sum(avg_psnr_center_l) / len(avg_psnr_center_l),
+ sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/codes/test_Vid4_REDS4_with_GT_DUF.py b/codes/test_Vid4_REDS4_with_GT_DUF.py
new file mode 100644
index 00000000..fcec690d
--- /dev/null
+++ b/codes/test_Vid4_REDS4_with_GT_DUF.py
@@ -0,0 +1,264 @@
+"""
+DUF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets
+write to txt log file
+"""
+
+import os
+import os.path as osp
+import glob
+import logging
+import numpy as np
+import cv2
+import torch
+
+import utils.util as util
+import data.util as data_util
+import models.archs.DUF_arch as DUF_arch
+
+
+def main():
+ #################
+ # configurations
+ #################
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+ data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS)
+
+ # Possible combinations: (2, 16), (3, 16), (4, 16), (4, 28), (4, 52)
+ scale = 4
+ layer = 52
+ assert (scale, layer) in [(2, 16), (3, 16), (4, 16), (4, 28),
+ (4, 52)], 'Unrecognized (scale, layer) combination'
+
+ # model
+ N_in = 7
+ model_path = '../experiments/pretrained_models/DUF_x{}_{}L_official.pth'.format(scale, layer)
+    adapt_official = 'official' in model_path
+ DUF_downsampling = True # True | False
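+    # If True, the GT frames are blurred and downsampled on the fly with util.DUF_downsample
+    # (as in the official DUF code) instead of reading pre-generated LR frames.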
+ if layer == 16:
+ model = DUF_arch.DUF_16L(scale=scale, adapt_official=adapt_official)
+ elif layer == 28:
+ model = DUF_arch.DUF_28L(scale=scale, adapt_official=adapt_official)
+ elif layer == 52:
+ model = DUF_arch.DUF_52L(scale=scale, adapt_official=adapt_official)
+
+ #### dataset
+ if data_mode == 'Vid4':
+ test_dataset_folder = '../datasets/Vid4/BIx4/*'
+ else: # sharp_bicubic (REDS)
+ test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)
+
+ #### evaluation
+ crop_border = 8
+    border_frame = N_in // 2  # border frames when evaluating
+ # temporal padding mode
+    padding = 'new_info'  # different from the official testing code, which pads zeros.
+ save_imgs = True
+ ############################################################################
+ device = torch.device('cuda')
+ save_folder = '../results/{}'.format(data_mode)
+ util.mkdirs(save_folder)
+ util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
+ logger = logging.getLogger('base')
+
+ #### log info
+ logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
+ logger.info('Padding mode: {}'.format(padding))
+ logger.info('Model path: {}'.format(model_path))
+ logger.info('Save images: {}'.format(save_imgs))
+
+ def read_image(img_path):
+ '''read one image from img_path
+ Return img: HWC, BGR, [0,1], numpy
+ '''
+ img_GT = cv2.imread(img_path)
+ img = img_GT.astype(np.float32) / 255.
+ return img
+
+ def read_seq_imgs(img_seq_path):
+ '''read a sequence of images'''
+ img_path_l = sorted(glob.glob(img_seq_path + '/*'))
+ img_l = [read_image(v) for v in img_path_l]
+ # stack to TCHW, RGB, [0,1], torch
+ imgs = np.stack(img_l, axis=0)
+ imgs = imgs[:, :, :, [2, 1, 0]]
+ imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
+ return imgs
+
+ def index_generation(crt_i, max_n, N, padding='reflection'):
+ '''
+ padding: replicate | reflection | new_info | circle
+ '''
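+        # replicate:  repeat the edge frame
+        # reflection: mirror the index around the clip boundary
+        # new_info:   take extra frames from the other temporal side of the window
+        # circle:     shift the out-of-range index by the window length N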
+ max_n = max_n - 1
+ n_pad = N // 2
+ return_l = []
+
+ for i in range(crt_i - n_pad, crt_i + n_pad + 1):
+ if i < 0:
+ if padding == 'replicate':
+ add_idx = 0
+ elif padding == 'reflection':
+ add_idx = -i
+ elif padding == 'new_info':
+ add_idx = (crt_i + n_pad) + (-i)
+ elif padding == 'circle':
+ add_idx = N + i
+ else:
+ raise ValueError('Wrong padding mode')
+ elif i > max_n:
+ if padding == 'replicate':
+ add_idx = max_n
+ elif padding == 'reflection':
+ add_idx = max_n * 2 - i
+ elif padding == 'new_info':
+ add_idx = (crt_i - n_pad) - (i - max_n)
+ elif padding == 'circle':
+ add_idx = i - N
+ else:
+ raise ValueError('Wrong padding mode')
+ else:
+ add_idx = i
+ return_l.append(add_idx)
+ return return_l
+
+ def single_forward(model, imgs_in):
+ with torch.no_grad():
+ model_output = model(imgs_in)
+ if isinstance(model_output, list) or isinstance(model_output, tuple):
+ output = model_output[0]
+ else:
+ output = model_output
+ return output
+
+ sub_folder_l = sorted(glob.glob(test_dataset_folder))
+ #### set up the models
+ model.load_state_dict(torch.load(model_path), strict=True)
+ model.eval()
+ model = model.to(device)
+
+ avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
+ sub_folder_name_l = []
+
+ # for each sub-folder
+ for sub_folder in sub_folder_l:
+ sub_folder_name = sub_folder.split('/')[-1]
+ sub_folder_name_l.append(sub_folder_name)
+ save_sub_folder = osp.join(save_folder, sub_folder_name)
+
+ img_path_l = sorted(glob.glob(sub_folder + '/*'))
+ max_idx = len(img_path_l)
+
+ if save_imgs:
+ util.mkdirs(save_sub_folder)
+
+ #### read LR images
+ imgs = read_seq_imgs(sub_folder)
+ #### read GT images
+ img_GT_l = []
+ if data_mode == 'Vid4':
+ sub_folder_GT = osp.join(sub_folder.replace('/BIx4/', '/GT/'), '*')
+ else:
+ sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
+ for img_GT_path in sorted(glob.glob(sub_folder_GT)):
+ img_GT_l.append(read_image(img_GT_path))
+
+ # When using the downsampling in DUF official code, we downsample the HR images
+ if DUF_downsampling:
+ sub_folder = sub_folder_GT
+ img_path_l = sorted(glob.glob(sub_folder))
+ max_idx = len(img_path_l)
+ imgs = read_seq_imgs(sub_folder[:-2])
+
+ avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
+ cal_n_border, cal_n_center = 0, 0
+
+ # process each image
+ for img_idx, img_path in enumerate(img_path_l):
+ c_idx = int(osp.splitext(osp.basename(img_path))[0])
+ select_idx = index_generation(c_idx, max_idx, N_in, padding=padding)
+ # get input images
+ imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
+
+ # Downsample the HR images
+ H, W = imgs_in.size(3), imgs_in.size(4)
+ if DUF_downsampling:
+ imgs_in = util.DUF_downsample(imgs_in, scale=scale)
+
+ output = single_forward(model, imgs_in)
+
+ # Crop to the original shape
+ if scale == 3:
+ pad_h = 3 - (H % 3)
+ pad_w = 3 - (W % 3)
+ if pad_h > 0:
+ output = output[:, :, :-pad_h, :]
+ if pad_w > 0:
+ output = output[:, :, :, :-pad_w]
+ output_f = output.data.float().cpu().squeeze(0)
+
+ output = util.tensor2img(output_f)
+
+ # save imgs
+ if save_imgs:
+ cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output)
+
+ #### calculate PSNR
+ output = output / 255.
+ GT = np.copy(img_GT_l[img_idx])
+ # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
+ if data_mode == 'Vid4': # bgr2y, [0, 1]
+ GT = data_util.bgr2ycbcr(GT)
+ output = data_util.bgr2ycbcr(output)
+ if crop_border == 0:
+ cropped_output = output
+ cropped_GT = GT
+ else:
+ cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
+ cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
+ crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
+ logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr))
+
+ if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
+ avg_psnr_center += crt_psnr
+ cal_n_center += 1
+ else: # border frames
+ avg_psnr_border += crt_psnr
+ cal_n_border += 1
+
+ avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border)
+ avg_psnr_center = avg_psnr_center / cal_n_center
+ if cal_n_border == 0:
+ avg_psnr_border = 0
+ else:
+ avg_psnr_border = avg_psnr_border / cal_n_border
+
+ logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
+ 'Center PSNR: {:.6f} dB for {} frames; '
+ 'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr,
+ (cal_n_center + cal_n_border),
+ avg_psnr_center, cal_n_center,
+ avg_psnr_border, cal_n_border))
+
+ avg_psnr_l.append(avg_psnr)
+ avg_psnr_center_l.append(avg_psnr_center)
+ avg_psnr_border_l.append(avg_psnr_border)
+
+ logger.info('################ Tidy Outputs ################')
+ for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
+ avg_psnr_center_l, avg_psnr_border_l):
+ logger.info('Folder {} - Average PSNR: {:.6f} dB. '
+ 'Center PSNR: {:.6f} dB. '
+ 'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border))
+ logger.info('################ Final Results ################')
+ logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
+ logger.info('Padding mode: {}'.format(padding))
+ logger.info('Model path: {}'.format(model_path))
+ logger.info('Save images: {}'.format(save_imgs))
+ logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
+ 'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
+ sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
+ sum(avg_psnr_center_l) / len(avg_psnr_center_l),
+ sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/codes/test_Vid4_REDS4_with_GT_TOF.py b/codes/test_Vid4_REDS4_with_GT_TOF.py
new file mode 100644
index 00000000..da8fc300
--- /dev/null
+++ b/codes/test_Vid4_REDS4_with_GT_TOF.py
@@ -0,0 +1,230 @@
+"""
+TOF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets
+write to txt log file
+"""
+
+import os
+import os.path as osp
+import glob
+import logging
+import numpy as np
+import cv2
+import torch
+
+import utils.util as util
+import data.util as data_util
+import models.archs.TOF_arch as TOF_arch
+
+
+def main():
+ #################
+ # configurations
+ #################
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+ data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS)
+
+ # model
+ N_in = 7
+ model_path = '../experiments/pretrained_models/TOF_official.pth'
+    adapt_official = 'official' in model_path
+ model = TOF_arch.TOFlow(adapt_official=adapt_official)
+
+ #### dataset
+ if data_mode == 'Vid4':
+ test_dataset_folder = '../datasets/Vid4/BIx4up_direct/*'
+ else:
+ test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)
+
+ #### evaluation
+ crop_border = 0
+    border_frame = N_in // 2  # border frames when evaluating
+ # temporal padding mode
+ padding = 'new_info' # different from the official setting
+ save_imgs = True
+ ############################################################################
+ device = torch.device('cuda')
+ save_folder = '../results/{}'.format(data_mode)
+ util.mkdirs(save_folder)
+ util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
+ logger = logging.getLogger('base')
+
+ #### log info
+ logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
+ logger.info('Padding mode: {}'.format(padding))
+ logger.info('Model path: {}'.format(model_path))
+ logger.info('Save images: {}'.format(save_imgs))
+
+ def read_image(img_path):
+ '''read one image from img_path
+ Return img: HWC, BGR, [0,1], numpy
+ '''
+ img_GT = cv2.imread(img_path)
+ img = img_GT.astype(np.float32) / 255.
+ return img
+
+ def read_seq_imgs(img_seq_path):
+ '''read a sequence of images'''
+ img_path_l = sorted(glob.glob(img_seq_path + '/*'))
+ img_l = [read_image(v) for v in img_path_l]
+ # stack to TCHW, RGB, [0,1], torch
+ imgs = np.stack(img_l, axis=0)
+ imgs = imgs[:, :, :, [2, 1, 0]]
+ imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
+ return imgs
+
+ def index_generation(crt_i, max_n, N, padding='reflection'):
+ '''
+ padding: replicate | reflection | new_info | circle
+ '''
+ max_n = max_n - 1
+ n_pad = N // 2
+ return_l = []
+
+ for i in range(crt_i - n_pad, crt_i + n_pad + 1):
+ if i < 0:
+ if padding == 'replicate':
+ add_idx = 0
+ elif padding == 'reflection':
+ add_idx = -i
+ elif padding == 'new_info':
+ add_idx = (crt_i + n_pad) + (-i)
+ elif padding == 'circle':
+ add_idx = N + i
+ else:
+ raise ValueError('Wrong padding mode')
+ elif i > max_n:
+ if padding == 'replicate':
+ add_idx = max_n
+ elif padding == 'reflection':
+ add_idx = max_n * 2 - i
+ elif padding == 'new_info':
+ add_idx = (crt_i - n_pad) - (i - max_n)
+ elif padding == 'circle':
+ add_idx = i - N
+ else:
+ raise ValueError('Wrong padding mode')
+ else:
+ add_idx = i
+ return_l.append(add_idx)
+ return return_l
+
+ def single_forward(model, imgs_in):
+ with torch.no_grad():
+ model_output = model(imgs_in)
+ if isinstance(model_output, list) or isinstance(model_output, tuple):
+ output = model_output[0]
+ else:
+ output = model_output
+ return output
+
+ sub_folder_l = sorted(glob.glob(test_dataset_folder))
+ #### set up the models
+ model.load_state_dict(torch.load(model_path), strict=True)
+ model.eval()
+ model = model.to(device)
+
+ avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
+ sub_folder_name_l = []
+
+ # for each sub-folder
+ for sub_folder in sub_folder_l:
+ sub_folder_name = sub_folder.split('/')[-1]
+ sub_folder_name_l.append(sub_folder_name)
+ save_sub_folder = osp.join(save_folder, sub_folder_name)
+
+ img_path_l = sorted(glob.glob(sub_folder + '/*'))
+ max_idx = len(img_path_l)
+
+ if save_imgs:
+ util.mkdirs(save_sub_folder)
+
+ #### read LR images
+ imgs = read_seq_imgs(sub_folder)
+ #### read GT images
+ img_GT_l = []
+ if data_mode == 'Vid4':
+ sub_folder_GT = osp.join(sub_folder.replace('/BIx4up_direct/', '/GT/'), '*')
+ else:
+ sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
+ for img_GT_path in sorted(glob.glob(sub_folder_GT)):
+ img_GT_l.append(read_image(img_GT_path))
+
+ avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
+ cal_n_border, cal_n_center = 0, 0
+
+ # process each image
+ for img_idx, img_path in enumerate(img_path_l):
+ c_idx = int(osp.splitext(osp.basename(img_path))[0])
+ select_idx = index_generation(c_idx, max_idx, N_in, padding=padding)
+ # get input images
+ imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
+ output = single_forward(model, imgs_in)
+ output_f = output.data.float().cpu().squeeze(0)
+
+ output = util.tensor2img(output_f)
+
+ # save imgs
+ if save_imgs:
+ cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output)
+
+ #### calculate PSNR
+ output = output / 255.
+ GT = np.copy(img_GT_l[img_idx])
+ # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
+ if data_mode == 'Vid4': # bgr2y, [0, 1]
+ GT = data_util.bgr2ycbcr(GT)
+ output = data_util.bgr2ycbcr(output)
+ if crop_border == 0:
+ cropped_output = output
+ cropped_GT = GT
+ else:
+ cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
+ cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
+ crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
+ logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr))
+
+ if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
+ avg_psnr_center += crt_psnr
+ cal_n_center += 1
+ else: # border frames
+ avg_psnr_border += crt_psnr
+ cal_n_border += 1
+
+ avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border)
+ avg_psnr_center = avg_psnr_center / cal_n_center
+ if cal_n_border == 0:
+ avg_psnr_border = 0
+ else:
+ avg_psnr_border = avg_psnr_border / cal_n_border
+
+ logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
+ 'Center PSNR: {:.6f} dB for {} frames; '
+ 'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr,
+ (cal_n_center + cal_n_border),
+ avg_psnr_center, cal_n_center,
+ avg_psnr_border, cal_n_border))
+
+ avg_psnr_l.append(avg_psnr)
+ avg_psnr_center_l.append(avg_psnr_center)
+ avg_psnr_border_l.append(avg_psnr_border)
+
+ logger.info('################ Tidy Outputs ################')
+ for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
+ avg_psnr_center_l, avg_psnr_border_l):
+ logger.info('Folder {} - Average PSNR: {:.6f} dB. '
+ 'Center PSNR: {:.6f} dB. '
+ 'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border))
+ logger.info('################ Final Results ################')
+ logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
+ logger.info('Padding mode: {}'.format(padding))
+ logger.info('Model path: {}'.format(model_path))
+ logger.info('Save images: {}'.format(save_imgs))
+ logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
+ 'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
+ sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
+ sum(avg_psnr_center_l) / len(avg_psnr_center_l),
+ sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/codes/train.py b/codes/train.py
new file mode 100644
index 00000000..c8c29bde
--- /dev/null
+++ b/codes/train.py
@@ -0,0 +1,310 @@
+import os
+import math
+import argparse
+import random
+import logging
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from data.data_sampler import DistIterSampler
+
+import options.options as option
+from utils import util
+from data import create_dataloader, create_dataset
+from models import create_model
+
+
+def init_dist(backend='nccl', **kwargs):
+ """initialization for distributed training"""
+ if mp.get_start_method(allow_none=True) != 'spawn':
+ mp.set_start_method('spawn')
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def main():
+ #### options
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+ parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
+ help='job launcher')
+ parser.add_argument('--local_rank', type=int, default=0)
+ args = parser.parse_args()
+ opt = option.parse(args.opt, is_train=True)
+
+ #### distributed training settings
+ if args.launcher == 'none': # disabled distributed training
+ opt['dist'] = False
+ rank = -1
+ print('Disabled distributed training.')
+ else:
+ opt['dist'] = True
+ init_dist()
+ world_size = torch.distributed.get_world_size()
+ rank = torch.distributed.get_rank()
+
+ #### loading resume state if exists
+ if opt['path'].get('resume_state', None):
+ # distributed resuming: all load into default GPU
+ device_id = torch.cuda.current_device()
+ resume_state = torch.load(opt['path']['resume_state'],
+ map_location=lambda storage, loc: storage.cuda(device_id))
+ option.check_resume(opt, resume_state['iter']) # check resume options
+ else:
+ resume_state = None
+
+ #### mkdir and loggers
+ if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
+ if resume_state is None:
+ util.mkdir_and_rename(
+ opt['path']['experiments_root']) # rename experiment folder if exists
+ util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
+ and 'pretrain_model' not in key and 'resume' not in key))
+
+ # config loggers. Before it, the log will not work
+ util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
+ screen=True, tofile=True)
+ logger = logging.getLogger('base')
+ logger.info(option.dict2str(opt))
+ # tensorboard logger
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
+ version = float(torch.__version__[0:3])
+ if version >= 1.1: # PyTorch 1.1
+ from torch.utils.tensorboard import SummaryWriter
+ else:
+ logger.info(
+ 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
+ from tensorboardX import SummaryWriter
+ tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
+ else:
+ util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
+ logger = logging.getLogger('base')
+
+ # convert to NoneDict, which returns None for missing keys
+ opt = option.dict_to_nonedict(opt)
+
+ #### random seed
+ seed = opt['train']['manual_seed']
+ if seed is None:
+ seed = random.randint(1, 10000)
+ if rank <= 0:
+ logger.info('Random seed: {}'.format(seed))
+ util.set_random_seed(seed)
+
+ torch.backends.cudnn.benchmark = True
+ # torch.backends.cudnn.deterministic = True
+
+ #### create train and val dataloader
+ dataset_ratio = 200 # enlarge the size of each epoch
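+    # DistIterSampler repeats the dataset dataset_ratio times per "epoch", so fewer epoch
+    # restarts (and sampler re-shuffles) are needed to reach the target iteration count.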
+ for phase, dataset_opt in opt['datasets'].items():
+ if phase == 'train':
+ train_set = create_dataset(dataset_opt)
+ train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
+ total_iters = int(opt['train']['niter'])
+ total_epochs = int(math.ceil(total_iters / train_size))
+ if opt['dist']:
+ train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
+ total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
+ else:
+ train_sampler = None
+ train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
+ if rank <= 0:
+ logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
+ len(train_set), train_size))
+ logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
+ total_epochs, total_iters))
+ elif phase == 'val':
+ val_set = create_dataset(dataset_opt)
+ val_loader = create_dataloader(val_set, dataset_opt, opt, None)
+ if rank <= 0:
+ logger.info('Number of val images in [{:s}]: {:d}'.format(
+ dataset_opt['name'], len(val_set)))
+ else:
+ raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
+ assert train_loader is not None
+
+ #### create model
+ model = create_model(opt)
+
+ #### resume training
+ if resume_state:
+ logger.info('Resuming training from epoch: {}, iter: {}.'.format(
+ resume_state['epoch'], resume_state['iter']))
+
+ start_epoch = resume_state['epoch']
+ current_step = resume_state['iter']
+ model.resume_training(resume_state) # handle optimizers and schedulers
+ else:
+ current_step = 0
+ start_epoch = 0
+
+ #### training
+ logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
+ for epoch in range(start_epoch, total_epochs + 1):
+ if opt['dist']:
+ train_sampler.set_epoch(epoch)
+ for _, train_data in enumerate(train_loader):
+ current_step += 1
+ if current_step > total_iters:
+ break
+ #### update learning rate
+ model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
+
+ #### training
+ model.feed_data(train_data)
+ model.optimize_parameters(current_step)
+
+ #### log
+ if current_step % opt['logger']['print_freq'] == 0:
+ logs = model.get_current_log()
+ message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
+ for v in model.get_current_learning_rate():
+ message += '{:.3e},'.format(v)
+ message += ')] '
+ for k, v in logs.items():
+ message += '{:s}: {:.4e} '.format(k, v)
+ # tensorboard logger
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
+ if rank <= 0:
+ tb_logger.add_scalar(k, v, current_step)
+ if rank <= 0:
+ logger.info(message)
+ #### validation
+ if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
+ if opt['model'] in ['sr', 'srgan'] and rank <= 0: # image restoration validation
+ # does not support multi-GPU validation
+ pbar = util.ProgressBar(len(val_loader))
+ avg_psnr = 0.
+ idx = 0
+ for val_data in val_loader:
+ idx += 1
+ img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
+ img_dir = os.path.join(opt['path']['val_images'], img_name)
+ util.mkdir(img_dir)
+
+ model.feed_data(val_data)
+ model.test()
+
+ visuals = model.get_current_visuals()
+ sr_img = util.tensor2img(visuals['rlt']) # uint8
+ gt_img = util.tensor2img(visuals['GT']) # uint8
+
+ # Save SR images for reference
+ save_img_path = os.path.join(img_dir,
+ '{:s}_{:d}.png'.format(img_name, current_step))
+ util.save_img(sr_img, save_img_path)
+
+ # calculate PSNR
+ sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+ avg_psnr += util.calculate_psnr(sr_img, gt_img)
+ pbar.update('Test {}'.format(img_name))
+
+ avg_psnr = avg_psnr / idx
+
+ # log
+ logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
+ # tensorboard logger
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
+ tb_logger.add_scalar('psnr', avg_psnr, current_step)
+ else: # video restoration validation
+ if opt['dist']:
+ # multi-GPU testing
+ psnr_rlt = {} # with border and center frames
+ if rank == 0:
+ pbar = util.ProgressBar(len(val_set))
+ for idx in range(rank, len(val_set), world_size):
+ val_data = val_set[idx]
+ val_data['LQs'].unsqueeze_(0)
+ val_data['GT'].unsqueeze_(0)
+ folder = val_data['folder']
+ idx_d, max_idx = val_data['idx'].split('/')
+ idx_d, max_idx = int(idx_d), int(max_idx)
+ if psnr_rlt.get(folder, None) is None:
+ psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32,
+ device='cuda')
+ # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
+ model.feed_data(val_data)
+ model.test()
+ visuals = model.get_current_visuals()
+ rlt_img = util.tensor2img(visuals['rlt']) # uint8
+ gt_img = util.tensor2img(visuals['GT']) # uint8
+ # calculate PSNR
+ psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img)
+
+ if rank == 0:
+ for _ in range(world_size):
+ pbar.update('Test {} - {}/{}'.format(folder, idx_d, max_idx))
+ # # collect data
+ for _, v in psnr_rlt.items():
+ dist.reduce(v, 0)
+ dist.barrier()
+
+ if rank == 0:
+ psnr_rlt_avg = {}
+ psnr_total_avg = 0.
+ for k, v in psnr_rlt.items():
+ psnr_rlt_avg[k] = torch.mean(v).cpu().item()
+ psnr_total_avg += psnr_rlt_avg[k]
+ psnr_total_avg /= len(psnr_rlt)
+ log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
+ for k, v in psnr_rlt_avg.items():
+ log_s += ' {}: {:.4e}'.format(k, v)
+ logger.info(log_s)
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
+ tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
+ for k, v in psnr_rlt_avg.items():
+ tb_logger.add_scalar(k, v, current_step)
+ else:
+ pbar = util.ProgressBar(len(val_loader))
+ psnr_rlt = {} # with border and center frames
+ psnr_rlt_avg = {}
+ psnr_total_avg = 0.
+ for val_data in val_loader:
+ folder = val_data['folder'][0]
+ idx_d = val_data['idx'].item()
+ # border = val_data['border'].item()
+ if psnr_rlt.get(folder, None) is None:
+ psnr_rlt[folder] = []
+
+ model.feed_data(val_data)
+ model.test()
+ visuals = model.get_current_visuals()
+ rlt_img = util.tensor2img(visuals['rlt']) # uint8
+ gt_img = util.tensor2img(visuals['GT']) # uint8
+
+ # calculate PSNR
+ psnr = util.calculate_psnr(rlt_img, gt_img)
+ psnr_rlt[folder].append(psnr)
+ pbar.update('Test {} - {}'.format(folder, idx_d))
+ for k, v in psnr_rlt.items():
+ psnr_rlt_avg[k] = sum(v) / len(v)
+ psnr_total_avg += psnr_rlt_avg[k]
+ psnr_total_avg /= len(psnr_rlt)
+ log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
+ for k, v in psnr_rlt_avg.items():
+ log_s += ' {}: {:.4e}'.format(k, v)
+ logger.info(log_s)
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
+ tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
+ for k, v in psnr_rlt_avg.items():
+ tb_logger.add_scalar(k, v, current_step)
+
+ #### save models and training states
+ if current_step % opt['logger']['save_checkpoint_freq'] == 0:
+ if rank <= 0:
+ logger.info('Saving models and training states.')
+ model.save(current_step)
+ model.save_training_state(epoch, current_step)
+
+ if rank <= 0:
+ logger.info('Saving the final model.')
+ model.save('latest')
+ logger.info('End of training.')
+ tb_logger.close()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/codes/utils/__init__.py b/codes/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/utils/util.py b/codes/utils/util.py
new file mode 100644
index 00000000..a8924fa8
--- /dev/null
+++ b/codes/utils/util.py
@@ -0,0 +1,327 @@
+import os
+import sys
+import time
+import math
+import torch.nn.functional as F
+from datetime import datetime
+import random
+import logging
+from collections import OrderedDict
+import numpy as np
+import cv2
+import torch
+from torchvision.utils import make_grid
+from shutil import get_terminal_size
+
+import yaml
+try:
+ from yaml import CLoader as Loader, CDumper as Dumper
+except ImportError:
+ from yaml import Loader, Dumper
+
+
+def OrderedYaml():
+ '''yaml orderedDict support'''
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+ def dict_representer(dumper, data):
+ return dumper.represent_dict(data.items())
+
+ def dict_constructor(loader, node):
+ return OrderedDict(loader.construct_pairs(node))
+
+ Dumper.add_representer(OrderedDict, dict_representer)
+ Loader.add_constructor(_mapping_tag, dict_constructor)
+ return Loader, Dumper
+
+
+####################
+# miscellaneous
+####################
+
+
+def get_timestamp():
+ return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def mkdirs(paths):
+ if isinstance(paths, str):
+ mkdir(paths)
+ else:
+ for path in paths:
+ mkdir(path)
+
+
+def mkdir_and_rename(path):
+ if os.path.exists(path):
+ new_name = path + '_archived_' + get_timestamp()
+        print('Path already exists. Renaming it to [{:s}]'.format(new_name))
+        logger = logging.getLogger('base')
+        logger.info('Path already exists. Renaming it to [{:s}]'.format(new_name))
+ os.rename(path, new_name)
+ os.makedirs(path)
+
+
+def set_random_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
+ '''set up logger'''
+ lg = logging.getLogger(logger_name)
+ formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
+ datefmt='%y-%m-%d %H:%M:%S')
+ lg.setLevel(level)
+ if tofile:
+ log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
+ fh = logging.FileHandler(log_file, mode='w')
+ fh.setFormatter(formatter)
+ lg.addHandler(fh)
+ if screen:
+ sh = logging.StreamHandler()
+ sh.setFormatter(formatter)
+ lg.addHandler(sh)
+
+
+####################
+# image convert
+####################
+def crop_border(img_list, crop_border):
+ """Crop borders of images
+ Args:
+ img_list (list [Numpy]): HWC
+        crop_border (int): crop border for each end of height and width
+
+ Returns:
+ (list [Numpy]): cropped image list
+ """
+ if crop_border == 0:
+ return img_list
+ else:
+ return [v[crop_border:-crop_border, crop_border:-crop_border] for v in img_list]
+
+
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+ '''
+ Converts a torch Tensor into an image Numpy array
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+ '''
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
+ n_dim = tensor.dim()
+ if n_dim == 4:
+ n_img = len(tensor)
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 3:
+ img_np = tensor.numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 2:
+ img_np = tensor.numpy()
+ else:
+ raise TypeError(
+            'Only supports 4D, 3D and 2D tensors; received a {:d}D tensor'.format(n_dim))
+ if out_type == np.uint8:
+ img_np = (img_np * 255.0).round()
+        # Important. Unlike MATLAB, numpy.uint8() will NOT round by default.
+ return img_np.astype(out_type)
+
+
+def save_img(img, img_path, mode='RGB'):
+ cv2.imwrite(img_path, img)
+
+
+def DUF_downsample(x, scale=4):
+ """Downsamping with Gaussian kernel used in the DUF official code
+
+ Args:
+ x (Tensor, [B, T, C, H, W]): frames to be downsampled.
+ scale (int): downsampling factor: 2 | 3 | 4.
+ """
+
+ assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale)
+
+ def gkern(kernlen=13, nsig=1.6):
+ import scipy.ndimage.filters as fi
+ inp = np.zeros((kernlen, kernlen))
+ # set element at the middle to one, a dirac delta
+ inp[kernlen // 2, kernlen // 2] = 1
+ # gaussian-smooth the dirac, resulting in a gaussian filter mask
+ return fi.gaussian_filter(inp, nsig)
+
+ B, T, C, H, W = x.size()
+ x = x.view(-1, 1, H, W)
+ pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 # 6 is the pad of the gaussian filter
+ r_h, r_w = 0, 0
+ if scale == 3:
+ r_h = 3 - (H % 3)
+ r_w = 3 - (W % 3)
+ x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], 'reflect')
+
+ gaussian_filter = torch.from_numpy(gkern(13, 0.4 * scale)).type_as(x).unsqueeze(0).unsqueeze(0)
+ x = F.conv2d(x, gaussian_filter, stride=scale)
+ x = x[:, :, 2:-2, 2:-2]
+ x = x.view(B, T, C, x.size(2), x.size(3))
+ return x
+
+
+def single_forward(model, inp):
+ """PyTorch model forward (single test), it is just a simple warpper
+ Args:
+ model (PyTorch model)
+ inp (Tensor): inputs defined by the model
+
+ Returns:
+ output (Tensor): outputs of the model. float, in CPU
+ """
+ with torch.no_grad():
+ model_output = model(inp)
+ if isinstance(model_output, list) or isinstance(model_output, tuple):
+ output = model_output[0]
+ else:
+ output = model_output
+ output = output.data.float().cpu()
+ return output
+
+
+def flipx4_forward(model, inp):
+ """Flip testing with X4 self ensemble, i.e., normal, flip H, flip W, flip H and W
+ Args:
+ model (PyTorch model)
+ inp (Tensor): inputs defined by the model
+
+ Returns:
+ output (Tensor): outputs of the model. float, in CPU
+ """
+ # normal
+ output_f = single_forward(model, inp)
+
+ # flip W
+ output = single_forward(model, torch.flip(inp, (-1, )))
+ output_f = output_f + torch.flip(output, (-1, ))
+ # flip H
+ output = single_forward(model, torch.flip(inp, (-2, )))
+ output_f = output_f + torch.flip(output, (-2, ))
+ # flip both H and W
+ output = single_forward(model, torch.flip(inp, (-2, -1)))
+ output_f = output_f + torch.flip(output, (-2, -1))
+
+ return output_f / 4
+
+
+####################
+# metric
+####################
+
+
+def calculate_psnr(img1, img2):
+ # img1 and img2 have range [0, 255]
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+def calculate_ssim(img1, img2):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+                ssims.append(ssim(img1[..., i], img2[..., i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+class ProgressBar(object):
+ '''A progress bar which can print the progress
+ modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
+ '''
+
+ def __init__(self, task_num=0, bar_width=50, start=True):
+ self.task_num = task_num
+ max_bar_width = self._get_max_bar_width()
+ self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
+ self.completed = 0
+ if start:
+ self.start()
+
+ def _get_max_bar_width(self):
+ terminal_width, _ = get_terminal_size()
+ max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
+ if max_bar_width < 10:
+            print('terminal width is too small ({}), please consider widening the terminal for '
+                  'a better progress bar visualization'.format(terminal_width))
+ max_bar_width = 10
+ return max_bar_width
+
+ def start(self):
+ if self.task_num > 0:
+ sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
+ ' ' * self.bar_width, self.task_num, 'Start...'))
+ else:
+ sys.stdout.write('completed: 0, elapsed: 0s')
+ sys.stdout.flush()
+ self.start_time = time.time()
+
+ def update(self, msg='In progress...'):
+ self.completed += 1
+ elapsed = time.time() - self.start_time
+ fps = self.completed / elapsed
+ if self.task_num > 0:
+ percentage = self.completed / float(self.task_num)
+ eta = int(elapsed * (1 - percentage) / percentage + 0.5)
+ mark_width = int(self.bar_width * percentage)
+ bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
+ sys.stdout.write('\033[2F') # cursor up 2 lines
+ sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
+ sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
+ bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
+ else:
+ sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
+ self.completed, int(elapsed + 0.5), fps))
+ sys.stdout.flush()
diff --git a/experiments/pretrained_models/Put pretrained models here. b/experiments/pretrained_models/Put pretrained models here.
new file mode 100644
index 00000000..e69de29b