forked from mrq/DL-Art-School
mmsr
This commit is contained in:
parent
58b175161c
commit
037933ba66
6
.flake8
Normal file
6
.flake8
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
[flake8]
|
||||||
|
ignore =
|
||||||
|
# Too many leading '#' for block comment (E266)
|
||||||
|
E266
|
||||||
|
|
||||||
|
max-line-length=100
|
121
.gitignore
vendored
Normal file
121
.gitignore
vendored
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
experiments/*
|
||||||
|
results/*
|
||||||
|
tb_logger/*
|
||||||
|
datasets/*
|
||||||
|
.vscode
|
||||||
|
|
||||||
|
*.html
|
||||||
|
*.png
|
||||||
|
*.jpg
|
||||||
|
*.gif
|
||||||
|
*.pth
|
||||||
|
*.pytorch
|
||||||
|
*.zip
|
||||||
|
*.cu
|
||||||
|
|
||||||
|
# template
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
4
.style.yapf
Normal file
4
.style.yapf
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
[style]
|
||||||
|
BASED_ON_STYLE = pep8
|
||||||
|
COLUMN_LIMIT = 100
|
||||||
|
SPLIT_BEFORE_NAMED_ASSIGNS = false
|
47
README.md
Normal file
47
README.md
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
# MMSR
|
||||||
|
|
||||||
|
MMSR is an open source image and video super-resolution toolbox based on PyTorch. It is a part of the [open-mmlab](https://github.com/open-mmlab) project developed by [Multimedia Laboratory, CUHK](http://mmlab.ie.cuhk.edu.hk/). MMSR is based on our previous projects: [BasicSR](https://github.com/xinntao/BasicSR), [ESRGAN](https://github.com/xinntao/ESRGAN)and [EDVR](https://github.com/xinntao/EDVR).
|
||||||
|
|
||||||
|
### Highlights
|
||||||
|
- **A unified framework** suitable for image and video super-resolution tasks. It is also easy to adapt to other restoration tasks, e.g., deblurring, denoising, etc.
|
||||||
|
- **State of the art**: It includes several winning methods in competitions: such as ESRGAN (PIRM18), EDVR (NTIRE19).
|
||||||
|
- **Easy to extend**: It is easy to try new research ideas based on the code base.
|
||||||
|
|
||||||
|
|
||||||
|
### Updates
|
||||||
|
[2019-07-25] MMSR v0.1 is released.
|
||||||
|
|
||||||
|
## Dependencies and Installation
|
||||||
|
|
||||||
|
- Python 3 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux))
|
||||||
|
- [PyTorch >= 1.0](https://pytorch.org/)
|
||||||
|
- NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
|
||||||
|
- [Deformable Convolution](https://arxiv.org/abs/1703.06211). We use [mmdetection](https://github.com/open-mmlab/mmdetection)'s dcn implementation. Please first compile it.
|
||||||
|
```
|
||||||
|
cd ./codes/models/archs/dcn
|
||||||
|
python setup.py develop
|
||||||
|
```
|
||||||
|
- Python packages: `pip install numpy opencv-python lmdb pyyaml`
|
||||||
|
- TensorBoard:
|
||||||
|
- PyTorch >= 1.1: `pip install tb-nightly future`
|
||||||
|
- PyTorch == 1.0: `pip install tensorboardX`
|
||||||
|
|
||||||
|
## Dataset Preparation
|
||||||
|
We use datasets in LDMB format for faster IO speed. Please refer to [DATASETS.md](datasets/DATASETS.md) for more details.
|
||||||
|
|
||||||
|
## Training and Testing
|
||||||
|
Please see [wiki- Training and Testing](https://github.com/open-mmlab/mmsr/wiki/Training-and-Testing) for the basic usage, *i.e.,* training and testing.
|
||||||
|
|
||||||
|
## Model Zoo and Baselines
|
||||||
|
Results and pre-trained models are available in the [wiki-Model Zoo](https://github.com/open-mmlab/mmsr/wiki/Model-Zoo).
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
We appreciate all contributions. Please refer to [mmdetection](https://github.com/open-mmlab/mmdetection/blob/master/CONTRIBUTING.md) for contributing guideline.
|
||||||
|
|
||||||
|
**Python code style**<br/>
|
||||||
|
We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style. We use [flake8](http://flake8.pycqa.org/en/latest/) as the linter and [yapf](https://github.com/google/yapf) as the formatter. Please upgrade to the latest yapf (>=0.27.0) and refer to the [yapf configuration](.style.yapf) and [flake8 configuration](.flake8).
|
||||||
|
|
||||||
|
> Before you create a PR, make sure that your code lints and is formatted by yapf.
|
||||||
|
|
||||||
|
## License
|
||||||
|
This project is released under the Apache 2.0 license.
|
127
codes/data/LQGT_dataset.py
Normal file
127
codes/data/LQGT_dataset.py
Normal file
|
@ -0,0 +1,127 @@
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import lmdb
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
import data.util as util
|
||||||
|
|
||||||
|
|
||||||
|
class LQGTDataset(data.Dataset):
|
||||||
|
"""
|
||||||
|
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
|
||||||
|
If only GT images are provided, generate LQ images on-the-fly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, opt):
|
||||||
|
super(LQGTDataset, self).__init__()
|
||||||
|
self.opt = opt
|
||||||
|
self.data_type = self.opt['data_type']
|
||||||
|
self.paths_LQ, self.paths_GT = None, None
|
||||||
|
self.sizes_LQ, self.sizes_GT = None, None
|
||||||
|
self.LQ_env, self.GT_env = None, None # environments for lmdb
|
||||||
|
|
||||||
|
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
||||||
|
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
|
||||||
|
assert self.paths_GT, 'Error: GT path is empty.'
|
||||||
|
if self.paths_LQ and self.paths_GT:
|
||||||
|
assert len(self.paths_LQ) == len(
|
||||||
|
self.paths_GT
|
||||||
|
), 'GT and LQ datasets have different number of images - {}, {}.'.format(
|
||||||
|
len(self.paths_LQ), len(self.paths_GT))
|
||||||
|
self.random_scale_list = [1]
|
||||||
|
|
||||||
|
def _init_lmdb(self):
|
||||||
|
# https://github.com/chainer/chainermn/issues/129
|
||||||
|
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
|
||||||
|
meminit=False)
|
||||||
|
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
|
||||||
|
meminit=False)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
|
||||||
|
self._init_lmdb()
|
||||||
|
GT_path, LQ_path = None, None
|
||||||
|
scale = self.opt['scale']
|
||||||
|
GT_size = self.opt['GT_size']
|
||||||
|
|
||||||
|
# get GT image
|
||||||
|
GT_path = self.paths_GT[index]
|
||||||
|
resolution = [int(s) for s in self.sizes_GT[index].split('_')
|
||||||
|
] if self.data_type == 'lmdb' else None
|
||||||
|
img_GT = util.read_img(self.GT_env, GT_path, resolution)
|
||||||
|
if self.opt['phase'] != 'train': # modcrop in the validation / test phase
|
||||||
|
img_GT = util.modcrop(img_GT, scale)
|
||||||
|
if self.opt['color']: # change color space if necessary
|
||||||
|
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
|
||||||
|
|
||||||
|
# get LQ image
|
||||||
|
if self.paths_LQ:
|
||||||
|
LQ_path = self.paths_LQ[index]
|
||||||
|
resolution = [int(s) for s in self.sizes_LQ[index].split('_')
|
||||||
|
] if self.data_type == 'lmdb' else None
|
||||||
|
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
|
||||||
|
else: # down-sampling on-the-fly
|
||||||
|
# randomly scale during training
|
||||||
|
if self.opt['phase'] == 'train':
|
||||||
|
random_scale = random.choice(self.random_scale_list)
|
||||||
|
H_s, W_s, _ = img_GT.shape
|
||||||
|
|
||||||
|
def _mod(n, random_scale, scale, thres):
|
||||||
|
rlt = int(n * random_scale)
|
||||||
|
rlt = (rlt // scale) * scale
|
||||||
|
return thres if rlt < thres else rlt
|
||||||
|
|
||||||
|
H_s = _mod(H_s, random_scale, scale, GT_size)
|
||||||
|
W_s = _mod(W_s, random_scale, scale, GT_size)
|
||||||
|
img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
|
||||||
|
if img_GT.ndim == 2:
|
||||||
|
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
|
||||||
|
|
||||||
|
H, W, _ = img_GT.shape
|
||||||
|
# using matlab imresize
|
||||||
|
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
|
||||||
|
if img_LQ.ndim == 2:
|
||||||
|
img_LQ = np.expand_dims(img_LQ, axis=2)
|
||||||
|
|
||||||
|
if self.opt['phase'] == 'train':
|
||||||
|
# if the image size is too small
|
||||||
|
H, W, _ = img_GT.shape
|
||||||
|
if H < GT_size or W < GT_size:
|
||||||
|
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||||
|
# using matlab imresize
|
||||||
|
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
|
||||||
|
if img_LQ.ndim == 2:
|
||||||
|
img_LQ = np.expand_dims(img_LQ, axis=2)
|
||||||
|
|
||||||
|
H, W, C = img_LQ.shape
|
||||||
|
LQ_size = GT_size // scale
|
||||||
|
|
||||||
|
# randomly crop
|
||||||
|
rnd_h = random.randint(0, max(0, H - LQ_size))
|
||||||
|
rnd_w = random.randint(0, max(0, W - LQ_size))
|
||||||
|
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
|
||||||
|
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
|
||||||
|
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
|
||||||
|
|
||||||
|
# augmentation - flip, rotate
|
||||||
|
img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
|
||||||
|
self.opt['use_rot'])
|
||||||
|
|
||||||
|
if self.opt['color']: # change color space if necessary
|
||||||
|
img_LQ = util.channel_convert(C, self.opt['color'],
|
||||||
|
[img_LQ])[0] # TODO during val no definition
|
||||||
|
|
||||||
|
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||||
|
if img_GT.shape[2] == 3:
|
||||||
|
img_GT = img_GT[:, :, [2, 1, 0]]
|
||||||
|
img_LQ = img_LQ[:, :, [2, 1, 0]]
|
||||||
|
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||||
|
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
|
||||||
|
|
||||||
|
if LQ_path is None:
|
||||||
|
LQ_path = GT_path
|
||||||
|
return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.paths_GT)
|
47
codes/data/LQ_dataset.py
Normal file
47
codes/data/LQ_dataset.py
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
import numpy as np
|
||||||
|
import lmdb
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
import data.util as util
|
||||||
|
|
||||||
|
|
||||||
|
class LQDataset(data.Dataset):
|
||||||
|
'''Read LQ images only in the test phase.'''
|
||||||
|
|
||||||
|
def __init__(self, opt):
|
||||||
|
super(LQDataset, self).__init__()
|
||||||
|
self.opt = opt
|
||||||
|
self.paths_LQ, self.paths_GT = None, None
|
||||||
|
self.LQ_env = None # environment for lmdb
|
||||||
|
|
||||||
|
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
|
||||||
|
assert self.paths_LQ, 'Error: LQ paths are empty.'
|
||||||
|
|
||||||
|
def _init_lmdb(self):
|
||||||
|
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
|
||||||
|
meminit=False)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
if self.data_type == 'lmdb' and self.LQ_env is None:
|
||||||
|
self._init_lmdb()
|
||||||
|
LQ_path = None
|
||||||
|
|
||||||
|
# get LQ image
|
||||||
|
LQ_path = self.LQ_path[index]
|
||||||
|
resolution = [int(s) for s in self.sizes_LQ[index].split('_')
|
||||||
|
] if self.data_type == 'lmdb' else None
|
||||||
|
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
|
||||||
|
H, W, C = img_LQ.shape
|
||||||
|
|
||||||
|
if self.opt['color']: # change color space if necessary
|
||||||
|
img_LQ = util.channel_convert(C, self.opt['color'], [img_LQ])[0]
|
||||||
|
|
||||||
|
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||||
|
if img_LQ.shape[2] == 3:
|
||||||
|
img_LQ = img_LQ[:, :, [2, 1, 0]]
|
||||||
|
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
|
||||||
|
|
||||||
|
return {'LQ': img_LQ, 'LQ_path': LQ_path}
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.paths_LQ)
|
210
codes/data/REDS_dataset.py
Normal file
210
codes/data/REDS_dataset.py
Normal file
|
@ -0,0 +1,210 @@
|
||||||
|
'''
|
||||||
|
REDS dataset
|
||||||
|
support reading images from lmdb, image folder and memcached
|
||||||
|
'''
|
||||||
|
import os.path as osp
|
||||||
|
import random
|
||||||
|
import pickle
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import lmdb
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
import data.util as util
|
||||||
|
try:
|
||||||
|
import mc # import memcached
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
|
||||||
|
class REDSDataset(data.Dataset):
|
||||||
|
'''
|
||||||
|
Reading the training REDS dataset
|
||||||
|
key example: 000_00000000
|
||||||
|
GT: Ground-Truth;
|
||||||
|
LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames
|
||||||
|
support reading N LQ frames, N = 1, 3, 5, 7
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, opt):
|
||||||
|
super(REDSDataset, self).__init__()
|
||||||
|
self.opt = opt
|
||||||
|
# temporal augmentation
|
||||||
|
self.interval_list = opt['interval_list']
|
||||||
|
self.random_reverse = opt['random_reverse']
|
||||||
|
logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
|
||||||
|
','.join(str(x) for x in opt['interval_list']), self.random_reverse))
|
||||||
|
|
||||||
|
self.half_N_frames = opt['N_frames'] // 2
|
||||||
|
self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
|
||||||
|
self.data_type = self.opt['data_type']
|
||||||
|
self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs
|
||||||
|
#### directly load image keys
|
||||||
|
if self.data_type == 'lmdb':
|
||||||
|
self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
||||||
|
logger.info('Using lmdb meta info for cache keys.')
|
||||||
|
elif opt['cache_keys']:
|
||||||
|
logger.info('Using cache keys: {}'.format(opt['cache_keys']))
|
||||||
|
self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
|
||||||
|
|
||||||
|
# remove the REDS4 for testing
|
||||||
|
self.paths_GT = [
|
||||||
|
v for v in self.paths_GT if v.split('_')[0] not in ['000', '011', '015', '020']
|
||||||
|
]
|
||||||
|
assert self.paths_GT, 'Error: GT path is empty.'
|
||||||
|
|
||||||
|
if self.data_type == 'lmdb':
|
||||||
|
self.GT_env, self.LQ_env = None, None
|
||||||
|
elif self.data_type == 'mc': # memcached
|
||||||
|
self.mclient = None
|
||||||
|
elif self.data_type == 'img':
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong data type: {}'.format(self.data_type))
|
||||||
|
|
||||||
|
def _init_lmdb(self):
|
||||||
|
# https://github.com/chainer/chainermn/issues/129
|
||||||
|
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
|
||||||
|
meminit=False)
|
||||||
|
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
|
||||||
|
meminit=False)
|
||||||
|
|
||||||
|
def _ensure_memcached(self):
|
||||||
|
if self.mclient is None:
|
||||||
|
# specify the config files
|
||||||
|
server_list_config_file = None
|
||||||
|
client_config_file = None
|
||||||
|
self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
|
||||||
|
client_config_file)
|
||||||
|
|
||||||
|
def _read_img_mc(self, path):
|
||||||
|
''' Return BGR, HWC, [0, 255], uint8'''
|
||||||
|
value = mc.pyvector()
|
||||||
|
self.mclient.Get(path, value)
|
||||||
|
value_buf = mc.ConvertBuffer(value)
|
||||||
|
img_array = np.frombuffer(value_buf, np.uint8)
|
||||||
|
img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def _read_img_mc_BGR(self, path, name_a, name_b):
|
||||||
|
''' Read BGR channels separately and then combine for 1M limits in cluster'''
|
||||||
|
img_B = self._read_img_mc(osp.join(path + '_B', name_a, name_b + '.png'))
|
||||||
|
img_G = self._read_img_mc(osp.join(path + '_G', name_a, name_b + '.png'))
|
||||||
|
img_R = self._read_img_mc(osp.join(path + '_R', name_a, name_b + '.png'))
|
||||||
|
img = cv2.merge((img_B, img_G, img_R))
|
||||||
|
return img
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
if self.data_type == 'mc':
|
||||||
|
self._ensure_memcached()
|
||||||
|
elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
|
||||||
|
self._init_lmdb()
|
||||||
|
|
||||||
|
scale = self.opt['scale']
|
||||||
|
GT_size = self.opt['GT_size']
|
||||||
|
key = self.paths_GT[index]
|
||||||
|
name_a, name_b = key.split('_')
|
||||||
|
center_frame_idx = int(name_b)
|
||||||
|
|
||||||
|
#### determine the neighbor frames
|
||||||
|
interval = random.choice(self.interval_list)
|
||||||
|
if self.opt['border_mode']:
|
||||||
|
direction = 1 # 1: forward; 0: backward
|
||||||
|
N_frames = self.opt['N_frames']
|
||||||
|
if self.random_reverse and random.random() < 0.5:
|
||||||
|
direction = random.choice([0, 1])
|
||||||
|
if center_frame_idx + interval * (N_frames - 1) > 99:
|
||||||
|
direction = 0
|
||||||
|
elif center_frame_idx - interval * (N_frames - 1) < 0:
|
||||||
|
direction = 1
|
||||||
|
# get the neighbor list
|
||||||
|
if direction == 1:
|
||||||
|
neighbor_list = list(
|
||||||
|
range(center_frame_idx, center_frame_idx + interval * N_frames, interval))
|
||||||
|
else:
|
||||||
|
neighbor_list = list(
|
||||||
|
range(center_frame_idx, center_frame_idx - interval * N_frames, -interval))
|
||||||
|
name_b = '{:08d}'.format(neighbor_list[0])
|
||||||
|
else:
|
||||||
|
# ensure not exceeding the borders
|
||||||
|
while (center_frame_idx + self.half_N_frames * interval >
|
||||||
|
99) or (center_frame_idx - self.half_N_frames * interval < 0):
|
||||||
|
center_frame_idx = random.randint(0, 99)
|
||||||
|
# get the neighbor list
|
||||||
|
neighbor_list = list(
|
||||||
|
range(center_frame_idx - self.half_N_frames * interval,
|
||||||
|
center_frame_idx + self.half_N_frames * interval + 1, interval))
|
||||||
|
if self.random_reverse and random.random() < 0.5:
|
||||||
|
neighbor_list.reverse()
|
||||||
|
name_b = '{:08d}'.format(neighbor_list[self.half_N_frames])
|
||||||
|
|
||||||
|
assert len(
|
||||||
|
neighbor_list) == self.opt['N_frames'], 'Wrong length of neighbor list: {}'.format(
|
||||||
|
len(neighbor_list))
|
||||||
|
|
||||||
|
#### get the GT image (as the center frame)
|
||||||
|
if self.data_type == 'mc':
|
||||||
|
img_GT = self._read_img_mc_BGR(self.GT_root, name_a, name_b)
|
||||||
|
img_GT = img_GT.astype(np.float32) / 255.
|
||||||
|
elif self.data_type == 'lmdb':
|
||||||
|
img_GT = util.read_img(self.GT_env, key, (3, 720, 1280))
|
||||||
|
else:
|
||||||
|
img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b + '.png'))
|
||||||
|
|
||||||
|
#### get LQ images
|
||||||
|
LQ_size_tuple = (3, 180, 320) if self.LR_input else (3, 720, 1280)
|
||||||
|
img_LQ_l = []
|
||||||
|
for v in neighbor_list:
|
||||||
|
img_LQ_path = osp.join(self.LQ_root, name_a, '{:08d}.png'.format(v))
|
||||||
|
if self.data_type == 'mc':
|
||||||
|
if self.LR_input:
|
||||||
|
img_LQ = self._read_img_mc(img_LQ_path)
|
||||||
|
else:
|
||||||
|
img_LQ = self._read_img_mc_BGR(self.LQ_root, name_a, '{:08d}'.format(v))
|
||||||
|
img_LQ = img_LQ.astype(np.float32) / 255.
|
||||||
|
elif self.data_type == 'lmdb':
|
||||||
|
img_LQ = util.read_img(self.LQ_env, '{}_{:08d}'.format(name_a, v), LQ_size_tuple)
|
||||||
|
else:
|
||||||
|
img_LQ = util.read_img(None, img_LQ_path)
|
||||||
|
img_LQ_l.append(img_LQ)
|
||||||
|
|
||||||
|
if self.opt['phase'] == 'train':
|
||||||
|
C, H, W = LQ_size_tuple # LQ size
|
||||||
|
# randomly crop
|
||||||
|
if self.LR_input:
|
||||||
|
LQ_size = GT_size // scale
|
||||||
|
rnd_h = random.randint(0, max(0, H - LQ_size))
|
||||||
|
rnd_w = random.randint(0, max(0, W - LQ_size))
|
||||||
|
img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
|
||||||
|
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
|
||||||
|
img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
|
||||||
|
else:
|
||||||
|
rnd_h = random.randint(0, max(0, H - GT_size))
|
||||||
|
rnd_w = random.randint(0, max(0, W - GT_size))
|
||||||
|
img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
|
||||||
|
img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
|
||||||
|
|
||||||
|
# augmentation - flip, rotate
|
||||||
|
img_LQ_l.append(img_GT)
|
||||||
|
rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
|
||||||
|
img_LQ_l = rlt[0:-1]
|
||||||
|
img_GT = rlt[-1]
|
||||||
|
|
||||||
|
# stack LQ images to NHWC, N is the frame number
|
||||||
|
img_LQs = np.stack(img_LQ_l, axis=0)
|
||||||
|
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||||
|
img_GT = img_GT[:, :, [2, 1, 0]]
|
||||||
|
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
|
||||||
|
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||||
|
img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
|
||||||
|
(0, 3, 1, 2)))).float()
|
||||||
|
return {'LQs': img_LQs, 'GT': img_GT, 'key': key}
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.paths_GT)
|
167
codes/data/Vimeo90K_dataset.py
Normal file
167
codes/data/Vimeo90K_dataset.py
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
'''
|
||||||
|
Vimeo90K dataset
|
||||||
|
support reading images from lmdb, image folder and memcached
|
||||||
|
'''
|
||||||
|
import os.path as osp
|
||||||
|
import random
|
||||||
|
import pickle
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import lmdb
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
import data.util as util
|
||||||
|
try:
|
||||||
|
import mc # import memcached
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
|
||||||
|
class Vimeo90KDataset(data.Dataset):
|
||||||
|
'''
|
||||||
|
Reading the training Vimeo90K dataset
|
||||||
|
key example: 00001_0001 (_1, ..., _7)
|
||||||
|
GT (Ground-Truth): 4th frame;
|
||||||
|
LQ (Low-Quality): support reading N LQ frames, N = 1, 3, 5, 7 centered with 4th frame
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, opt):
|
||||||
|
super(Vimeo90KDataset, self).__init__()
|
||||||
|
self.opt = opt
|
||||||
|
# temporal augmentation
|
||||||
|
self.interval_list = opt['interval_list']
|
||||||
|
self.random_reverse = opt['random_reverse']
|
||||||
|
logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
|
||||||
|
','.join(str(x) for x in opt['interval_list']), self.random_reverse))
|
||||||
|
|
||||||
|
self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
|
||||||
|
self.data_type = self.opt['data_type']
|
||||||
|
self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs
|
||||||
|
|
||||||
|
#### determine the LQ frame list
|
||||||
|
'''
|
||||||
|
N | frames
|
||||||
|
1 | 4
|
||||||
|
3 | 3,4,5
|
||||||
|
5 | 2,3,4,5,6
|
||||||
|
7 | 1,2,3,4,5,6,7
|
||||||
|
'''
|
||||||
|
self.LQ_frames_list = []
|
||||||
|
for i in range(opt['N_frames']):
|
||||||
|
self.LQ_frames_list.append(i + (9 - opt['N_frames']) // 2)
|
||||||
|
|
||||||
|
#### directly load image keys
|
||||||
|
if self.data_type == 'lmdb':
|
||||||
|
self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
||||||
|
logger.info('Using lmdb meta info for cache keys.')
|
||||||
|
elif opt['cache_keys']:
|
||||||
|
logger.info('Using cache keys: {}'.format(opt['cache_keys']))
|
||||||
|
self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
|
||||||
|
assert self.paths_GT, 'Error: GT path is empty.'
|
||||||
|
|
||||||
|
if self.data_type == 'lmdb':
|
||||||
|
self.GT_env, self.LQ_env = None, None
|
||||||
|
elif self.data_type == 'mc': # memcached
|
||||||
|
self.mclient = None
|
||||||
|
elif self.data_type == 'img':
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong data type: {}'.format(self.data_type))
|
||||||
|
|
||||||
|
def _init_lmdb(self):
|
||||||
|
# https://github.com/chainer/chainermn/issues/129
|
||||||
|
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
|
||||||
|
meminit=False)
|
||||||
|
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
|
||||||
|
meminit=False)
|
||||||
|
|
||||||
|
def _ensure_memcached(self):
|
||||||
|
if self.mclient is None:
|
||||||
|
# specify the config files
|
||||||
|
server_list_config_file = None
|
||||||
|
client_config_file = None
|
||||||
|
self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
|
||||||
|
client_config_file)
|
||||||
|
|
||||||
|
def _read_img_mc(self, path):
|
||||||
|
''' Return BGR, HWC, [0, 255], uint8'''
|
||||||
|
value = mc.pyvector()
|
||||||
|
self.mclient.Get(path, value)
|
||||||
|
value_buf = mc.ConvertBuffer(value)
|
||||||
|
img_array = np.frombuffer(value_buf, np.uint8)
|
||||||
|
img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
if self.data_type == 'mc':
|
||||||
|
self._ensure_memcached()
|
||||||
|
elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
|
||||||
|
self._init_lmdb()
|
||||||
|
|
||||||
|
scale = self.opt['scale']
|
||||||
|
GT_size = self.opt['GT_size']
|
||||||
|
key = self.paths_GT[index]
|
||||||
|
name_a, name_b = key.split('_')
|
||||||
|
#### get the GT image (as the center frame)
|
||||||
|
if self.data_type == 'mc':
|
||||||
|
img_GT = self._read_img_mc(osp.join(self.GT_root, name_a, name_b, '4.png'))
|
||||||
|
img_GT = img_GT.astype(np.float32) / 255.
|
||||||
|
elif self.data_type == 'lmdb':
|
||||||
|
img_GT = util.read_img(self.GT_env, key + '_4', (3, 256, 448))
|
||||||
|
else:
|
||||||
|
img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b, 'im4.png'))
|
||||||
|
|
||||||
|
#### get LQ images
|
||||||
|
LQ_size_tuple = (3, 64, 112) if self.LR_input else (3, 256, 448)
|
||||||
|
img_LQ_l = []
|
||||||
|
for v in self.LQ_frames_list:
|
||||||
|
if self.data_type == 'mc':
|
||||||
|
img_LQ = self._read_img_mc(
|
||||||
|
osp.join(self.LQ_root, name_a, name_b, '{}.png'.format(v)))
|
||||||
|
img_LQ = img_LQ.astype(np.float32) / 255.
|
||||||
|
elif self.data_type == 'lmdb':
|
||||||
|
img_LQ = util.read_img(self.LQ_env, key + '_{}'.format(v), LQ_size_tuple)
|
||||||
|
else:
|
||||||
|
img_LQ = util.read_img(None,
|
||||||
|
osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v)))
|
||||||
|
img_LQ_l.append(img_LQ)
|
||||||
|
|
||||||
|
if self.opt['phase'] == 'train':
|
||||||
|
C, H, W = LQ_size_tuple # LQ size
|
||||||
|
# randomly crop
|
||||||
|
if self.LR_input:
|
||||||
|
LQ_size = GT_size // scale
|
||||||
|
rnd_h = random.randint(0, max(0, H - LQ_size))
|
||||||
|
rnd_w = random.randint(0, max(0, W - LQ_size))
|
||||||
|
img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
|
||||||
|
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
|
||||||
|
img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
|
||||||
|
else:
|
||||||
|
rnd_h = random.randint(0, max(0, H - GT_size))
|
||||||
|
rnd_w = random.randint(0, max(0, W - GT_size))
|
||||||
|
img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
|
||||||
|
img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
|
||||||
|
|
||||||
|
# augmentation - flip, rotate
|
||||||
|
img_LQ_l.append(img_GT)
|
||||||
|
rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
|
||||||
|
img_LQ_l = rlt[0:-1]
|
||||||
|
img_GT = rlt[-1]
|
||||||
|
|
||||||
|
# stack LQ images to NHWC, N is the frame number
|
||||||
|
img_LQs = np.stack(img_LQ_l, axis=0)
|
||||||
|
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||||
|
img_GT = img_GT[:, :, [2, 1, 0]]
|
||||||
|
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
|
||||||
|
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||||
|
img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
|
||||||
|
(0, 3, 1, 2)))).float()
|
||||||
|
return {'LQs': img_LQs, 'GT': img_GT, 'key': key}
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.paths_GT)
|
49
codes/data/__init__.py
Normal file
49
codes/data/__init__.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
"""create dataset and dataloader"""
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
||||||
|
phase = dataset_opt['phase']
|
||||||
|
if phase == 'train':
|
||||||
|
if opt['dist']:
|
||||||
|
world_size = torch.distributed.get_world_size()
|
||||||
|
num_workers = dataset_opt['n_workers']
|
||||||
|
assert dataset_opt['batch_size'] % world_size == 0
|
||||||
|
batch_size = dataset_opt['batch_size'] // world_size
|
||||||
|
shuffle = False
|
||||||
|
else:
|
||||||
|
num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
|
||||||
|
batch_size = dataset_opt['batch_size']
|
||||||
|
shuffle = True
|
||||||
|
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
||||||
|
num_workers=num_workers, sampler=sampler, drop_last=True,
|
||||||
|
pin_memory=False)
|
||||||
|
else:
|
||||||
|
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
|
||||||
|
pin_memory=False)
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(dataset_opt):
|
||||||
|
mode = dataset_opt['mode']
|
||||||
|
# datasets for image restoration
|
||||||
|
if mode == 'LQ':
|
||||||
|
from data.LQ_dataset import LQDataset as D
|
||||||
|
elif mode == 'LQGT':
|
||||||
|
from data.LQGT_dataset import LQGTDataset as D
|
||||||
|
# datasets for video restoration
|
||||||
|
elif mode == 'REDS':
|
||||||
|
from data.REDS_dataset import REDSDataset as D
|
||||||
|
elif mode == 'Vimeo90K':
|
||||||
|
from data.Vimeo90K_dataset import Vimeo90KDataset as D
|
||||||
|
elif mode == 'video_test':
|
||||||
|
from data.video_test_dataset import VideoTestDataset as D
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
||||||
|
dataset = D(dataset_opt)
|
||||||
|
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
|
||||||
|
dataset_opt['name']))
|
||||||
|
return dataset
|
65
codes/data/data_sampler.py
Normal file
65
codes/data/data_sampler.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
"""
|
||||||
|
Modified from torch.utils.data.distributed.DistributedSampler
|
||||||
|
Support enlarging the dataset for *iteration-oriented* training, for saving time when restart the
|
||||||
|
dataloader after each epoch
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch.utils.data.sampler import Sampler
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
class DistIterSampler(Sampler):
|
||||||
|
"""Sampler that restricts data loading to a subset of the dataset.
|
||||||
|
|
||||||
|
It is especially useful in conjunction with
|
||||||
|
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
|
||||||
|
process can pass a DistributedSampler instance as a DataLoader sampler,
|
||||||
|
and load a subset of the original dataset that is exclusive to it.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
Dataset is assumed to be of constant size.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
dataset: Dataset used for sampling.
|
||||||
|
num_replicas (optional): Number of processes participating in
|
||||||
|
distributed training.
|
||||||
|
rank (optional): Rank of the current process within num_replicas.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
|
||||||
|
if num_replicas is None:
|
||||||
|
if not dist.is_available():
|
||||||
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
|
num_replicas = dist.get_world_size()
|
||||||
|
if rank is None:
|
||||||
|
if not dist.is_available():
|
||||||
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
|
rank = dist.get_rank()
|
||||||
|
self.dataset = dataset
|
||||||
|
self.num_replicas = num_replicas
|
||||||
|
self.rank = rank
|
||||||
|
self.epoch = 0
|
||||||
|
self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
|
||||||
|
self.total_size = self.num_samples * self.num_replicas
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
# deterministically shuffle based on epoch
|
||||||
|
g = torch.Generator()
|
||||||
|
g.manual_seed(self.epoch)
|
||||||
|
indices = torch.randperm(self.total_size, generator=g).tolist()
|
||||||
|
|
||||||
|
dsize = len(self.dataset)
|
||||||
|
indices = [v % dsize for v in indices]
|
||||||
|
|
||||||
|
# subsample
|
||||||
|
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||||
|
assert len(indices) == self.num_samples
|
||||||
|
|
||||||
|
return iter(indices)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def set_epoch(self, epoch):
|
||||||
|
self.epoch = epoch
|
543
codes/data/util.py
Normal file
543
codes/data/util.py
Normal file
|
@ -0,0 +1,543 @@
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import glob
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
####################
|
||||||
|
# Files & IO
|
||||||
|
####################
|
||||||
|
|
||||||
|
###################### get image path list ######################
|
||||||
|
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
|
||||||
|
|
||||||
|
|
||||||
|
def is_image_file(filename):
|
||||||
|
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_paths_from_images(path):
|
||||||
|
"""get image path list from image folder"""
|
||||||
|
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
|
||||||
|
images = []
|
||||||
|
for dirpath, _, fnames in sorted(os.walk(path)):
|
||||||
|
for fname in sorted(fnames):
|
||||||
|
if is_image_file(fname):
|
||||||
|
img_path = os.path.join(dirpath, fname)
|
||||||
|
images.append(img_path)
|
||||||
|
assert images, '{:s} has no valid image file'.format(path)
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
def _get_paths_from_lmdb(dataroot):
|
||||||
|
"""get image path list from lmdb meta info"""
|
||||||
|
meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
|
||||||
|
paths = meta_info['keys']
|
||||||
|
sizes = meta_info['resolution']
|
||||||
|
if len(sizes) == 1:
|
||||||
|
sizes = sizes * len(paths)
|
||||||
|
return paths, sizes
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_paths(data_type, dataroot):
|
||||||
|
"""get image path list
|
||||||
|
support lmdb or image files"""
|
||||||
|
paths, sizes = None, None
|
||||||
|
if dataroot is not None:
|
||||||
|
if data_type == 'lmdb':
|
||||||
|
paths, sizes = _get_paths_from_lmdb(dataroot)
|
||||||
|
elif data_type == 'img':
|
||||||
|
paths = sorted(_get_paths_from_images(dataroot))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
|
||||||
|
return paths, sizes
|
||||||
|
|
||||||
|
|
||||||
|
def glob_file_list(root):
|
||||||
|
return sorted(glob.glob(os.path.join(root, '*')))
|
||||||
|
|
||||||
|
|
||||||
|
###################### read images ######################
|
||||||
|
def _read_img_lmdb(env, key, size):
|
||||||
|
"""read image from lmdb with key (w/ and w/o fixed size)
|
||||||
|
size: (C, H, W) tuple"""
|
||||||
|
with env.begin(write=False) as txn:
|
||||||
|
buf = txn.get(key.encode('ascii'))
|
||||||
|
img_flat = np.frombuffer(buf, dtype=np.uint8)
|
||||||
|
C, H, W = size
|
||||||
|
img = img_flat.reshape(H, W, C)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def read_img(env, path, size=None):
|
||||||
|
"""read image by cv2 or from lmdb
|
||||||
|
return: Numpy float32, HWC, BGR, [0,1]"""
|
||||||
|
if env is None: # img
|
||||||
|
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
|
else:
|
||||||
|
img = _read_img_lmdb(env, path, size)
|
||||||
|
img = img.astype(np.float32) / 255.
|
||||||
|
if img.ndim == 2:
|
||||||
|
img = np.expand_dims(img, axis=2)
|
||||||
|
# some images have 4 channels
|
||||||
|
if img.shape[2] > 3:
|
||||||
|
img = img[:, :, :3]
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def read_img_seq(path):
|
||||||
|
"""Read a sequence of images from a given folder path
|
||||||
|
Args:
|
||||||
|
path (list/str): list of image paths/image folder path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
imgs (Tensor): size (T, C, H, W), RGB, [0, 1]
|
||||||
|
"""
|
||||||
|
if type(path) is list:
|
||||||
|
img_path_l = path
|
||||||
|
else:
|
||||||
|
img_path_l = sorted(glob.glob(os.path.join(path, '*')))
|
||||||
|
img_l = [read_img(None, v) for v in img_path_l]
|
||||||
|
# stack to Torch tensor
|
||||||
|
imgs = np.stack(img_l, axis=0)
|
||||||
|
imgs = imgs[:, :, :, [2, 1, 0]]
|
||||||
|
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
|
||||||
|
return imgs
|
||||||
|
|
||||||
|
|
||||||
|
def index_generation(crt_i, max_n, N, padding='reflection'):
|
||||||
|
"""Generate an index list for reading N frames from a sequence of images
|
||||||
|
Args:
|
||||||
|
crt_i (int): current center index
|
||||||
|
max_n (int): max number of the sequence of images (calculated from 1)
|
||||||
|
N (int): reading N frames
|
||||||
|
padding (str): padding mode, one of replicate | reflection | new_info | circle
|
||||||
|
Example: crt_i = 0, N = 5
|
||||||
|
replicate: [0, 0, 0, 1, 2]
|
||||||
|
reflection: [2, 1, 0, 1, 2]
|
||||||
|
new_info: [4, 3, 0, 1, 2]
|
||||||
|
circle: [3, 4, 0, 1, 2]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
return_l (list [int]): a list of indexes
|
||||||
|
"""
|
||||||
|
max_n = max_n - 1
|
||||||
|
n_pad = N // 2
|
||||||
|
return_l = []
|
||||||
|
|
||||||
|
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
|
||||||
|
if i < 0:
|
||||||
|
if padding == 'replicate':
|
||||||
|
add_idx = 0
|
||||||
|
elif padding == 'reflection':
|
||||||
|
add_idx = -i
|
||||||
|
elif padding == 'new_info':
|
||||||
|
add_idx = (crt_i + n_pad) + (-i)
|
||||||
|
elif padding == 'circle':
|
||||||
|
add_idx = N + i
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong padding mode')
|
||||||
|
elif i > max_n:
|
||||||
|
if padding == 'replicate':
|
||||||
|
add_idx = max_n
|
||||||
|
elif padding == 'reflection':
|
||||||
|
add_idx = max_n * 2 - i
|
||||||
|
elif padding == 'new_info':
|
||||||
|
add_idx = (crt_i - n_pad) - (i - max_n)
|
||||||
|
elif padding == 'circle':
|
||||||
|
add_idx = i - N
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong padding mode')
|
||||||
|
else:
|
||||||
|
add_idx = i
|
||||||
|
return_l.append(add_idx)
|
||||||
|
return return_l
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# image processing
|
||||||
|
# process on numpy image
|
||||||
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
def augment(img_list, hflip=True, rot=True):
|
||||||
|
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
|
||||||
|
hflip = hflip and random.random() < 0.5
|
||||||
|
vflip = rot and random.random() < 0.5
|
||||||
|
rot90 = rot and random.random() < 0.5
|
||||||
|
|
||||||
|
def _augment(img):
|
||||||
|
if hflip:
|
||||||
|
img = img[:, ::-1, :]
|
||||||
|
if vflip:
|
||||||
|
img = img[::-1, :, :]
|
||||||
|
if rot90:
|
||||||
|
img = img.transpose(1, 0, 2)
|
||||||
|
return img
|
||||||
|
|
||||||
|
return [_augment(img) for img in img_list]
|
||||||
|
|
||||||
|
|
||||||
|
def augment_flow(img_list, flow_list, hflip=True, rot=True):
|
||||||
|
"""horizontal flip OR rotate (0, 90, 180, 270 degrees) with flows"""
|
||||||
|
hflip = hflip and random.random() < 0.5
|
||||||
|
vflip = rot and random.random() < 0.5
|
||||||
|
rot90 = rot and random.random() < 0.5
|
||||||
|
|
||||||
|
def _augment(img):
|
||||||
|
if hflip:
|
||||||
|
img = img[:, ::-1, :]
|
||||||
|
if vflip:
|
||||||
|
img = img[::-1, :, :]
|
||||||
|
if rot90:
|
||||||
|
img = img.transpose(1, 0, 2)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def _augment_flow(flow):
|
||||||
|
if hflip:
|
||||||
|
flow = flow[:, ::-1, :]
|
||||||
|
flow[:, :, 0] *= -1
|
||||||
|
if vflip:
|
||||||
|
flow = flow[::-1, :, :]
|
||||||
|
flow[:, :, 1] *= -1
|
||||||
|
if rot90:
|
||||||
|
flow = flow.transpose(1, 0, 2)
|
||||||
|
flow = flow[:, :, [1, 0]]
|
||||||
|
return flow
|
||||||
|
|
||||||
|
rlt_img_list = [_augment(img) for img in img_list]
|
||||||
|
rlt_flow_list = [_augment_flow(flow) for flow in flow_list]
|
||||||
|
|
||||||
|
return rlt_img_list, rlt_flow_list
|
||||||
|
|
||||||
|
|
||||||
|
def channel_convert(in_c, tar_type, img_list):
|
||||||
|
"""conversion among BGR, gray and y"""
|
||||||
|
if in_c == 3 and tar_type == 'gray': # BGR to gray
|
||||||
|
gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
|
||||||
|
return [np.expand_dims(img, axis=2) for img in gray_list]
|
||||||
|
elif in_c == 3 and tar_type == 'y': # BGR to y
|
||||||
|
y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
|
||||||
|
return [np.expand_dims(img, axis=2) for img in y_list]
|
||||||
|
elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
|
||||||
|
return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
|
||||||
|
else:
|
||||||
|
return img_list
|
||||||
|
|
||||||
|
|
||||||
|
def rgb2ycbcr(img, only_y=True):
|
||||||
|
"""same as matlab rgb2ycbcr
|
||||||
|
only_y: only return Y channel
|
||||||
|
Input:
|
||||||
|
uint8, [0, 255]
|
||||||
|
float, [0, 1]
|
||||||
|
"""
|
||||||
|
in_img_type = img.dtype
|
||||||
|
img.astype(np.float32)
|
||||||
|
if in_img_type != np.uint8:
|
||||||
|
img *= 255.
|
||||||
|
# convert
|
||||||
|
if only_y:
|
||||||
|
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
||||||
|
else:
|
||||||
|
rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
|
||||||
|
[24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
|
||||||
|
if in_img_type == np.uint8:
|
||||||
|
rlt = rlt.round()
|
||||||
|
else:
|
||||||
|
rlt /= 255.
|
||||||
|
return rlt.astype(in_img_type)
|
||||||
|
|
||||||
|
|
||||||
|
def bgr2ycbcr(img, only_y=True):
|
||||||
|
"""bgr version of rgb2ycbcr
|
||||||
|
only_y: only return Y channel
|
||||||
|
Input:
|
||||||
|
uint8, [0, 255]
|
||||||
|
float, [0, 1]
|
||||||
|
"""
|
||||||
|
in_img_type = img.dtype
|
||||||
|
img.astype(np.float32)
|
||||||
|
if in_img_type != np.uint8:
|
||||||
|
img *= 255.
|
||||||
|
# convert
|
||||||
|
if only_y:
|
||||||
|
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
||||||
|
else:
|
||||||
|
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
|
||||||
|
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
|
||||||
|
if in_img_type == np.uint8:
|
||||||
|
rlt = rlt.round()
|
||||||
|
else:
|
||||||
|
rlt /= 255.
|
||||||
|
return rlt.astype(in_img_type)
|
||||||
|
|
||||||
|
|
||||||
|
def ycbcr2rgb(img):
|
||||||
|
"""same as matlab ycbcr2rgb
|
||||||
|
Input:
|
||||||
|
uint8, [0, 255]
|
||||||
|
float, [0, 1]
|
||||||
|
"""
|
||||||
|
in_img_type = img.dtype
|
||||||
|
img.astype(np.float32)
|
||||||
|
if in_img_type != np.uint8:
|
||||||
|
img *= 255.
|
||||||
|
# convert
|
||||||
|
rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
|
||||||
|
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
|
||||||
|
if in_img_type == np.uint8:
|
||||||
|
rlt = rlt.round()
|
||||||
|
else:
|
||||||
|
rlt /= 255.
|
||||||
|
return rlt.astype(in_img_type)
|
||||||
|
|
||||||
|
|
||||||
|
def modcrop(img_in, scale):
|
||||||
|
"""img_in: Numpy, HWC or HW"""
|
||||||
|
img = np.copy(img_in)
|
||||||
|
if img.ndim == 2:
|
||||||
|
H, W = img.shape
|
||||||
|
H_r, W_r = H % scale, W % scale
|
||||||
|
img = img[:H - H_r, :W - W_r]
|
||||||
|
elif img.ndim == 3:
|
||||||
|
H, W, C = img.shape
|
||||||
|
H_r, W_r = H % scale, W % scale
|
||||||
|
img = img[:H - H_r, :W - W_r, :]
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# Functions
|
||||||
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
# matlab 'imresize' function, now only support 'bicubic'
|
||||||
|
def cubic(x):
|
||||||
|
absx = torch.abs(x)
|
||||||
|
absx2 = absx**2
|
||||||
|
absx3 = absx**3
|
||||||
|
return (1.5 * absx3 - 2.5 * absx2 + 1) * (
|
||||||
|
(absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((
|
||||||
|
(absx > 1) * (absx <= 2)).type_as(absx))
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
|
||||||
|
if (scale < 1) and (antialiasing):
|
||||||
|
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
|
||||||
|
kernel_width = kernel_width / scale
|
||||||
|
|
||||||
|
# Output-space coordinates
|
||||||
|
x = torch.linspace(1, out_length, out_length)
|
||||||
|
|
||||||
|
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
||||||
|
# in output space maps to 0.5 in input space, and 0.5+scale in output
|
||||||
|
# space maps to 1.5 in input space.
|
||||||
|
u = x / scale + 0.5 * (1 - 1 / scale)
|
||||||
|
|
||||||
|
# What is the left-most pixel that can be involved in the computation?
|
||||||
|
left = torch.floor(u - kernel_width / 2)
|
||||||
|
|
||||||
|
# What is the maximum number of pixels that can be involved in the
|
||||||
|
# computation? Note: it's OK to use an extra pixel here; if the
|
||||||
|
# corresponding weights are all zero, it will be eliminated at the end
|
||||||
|
# of this function.
|
||||||
|
P = math.ceil(kernel_width) + 2
|
||||||
|
|
||||||
|
# The indices of the input pixels involved in computing the k-th output
|
||||||
|
# pixel are in row k of the indices matrix.
|
||||||
|
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
|
||||||
|
1, P).expand(out_length, P)
|
||||||
|
|
||||||
|
# The weights used to compute the k-th output pixel are in row k of the
|
||||||
|
# weights matrix.
|
||||||
|
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
|
||||||
|
# apply cubic kernel
|
||||||
|
if (scale < 1) and (antialiasing):
|
||||||
|
weights = scale * cubic(distance_to_center * scale)
|
||||||
|
else:
|
||||||
|
weights = cubic(distance_to_center)
|
||||||
|
# Normalize the weights matrix so that each row sums to 1.
|
||||||
|
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
||||||
|
weights = weights / weights_sum.expand(out_length, P)
|
||||||
|
|
||||||
|
# If a column in weights is all zero, get rid of it. only consider the first and last column.
|
||||||
|
weights_zero_tmp = torch.sum((weights == 0), 0)
|
||||||
|
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
||||||
|
indices = indices.narrow(1, 1, P - 2)
|
||||||
|
weights = weights.narrow(1, 1, P - 2)
|
||||||
|
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
||||||
|
indices = indices.narrow(1, 0, P - 2)
|
||||||
|
weights = weights.narrow(1, 0, P - 2)
|
||||||
|
weights = weights.contiguous()
|
||||||
|
indices = indices.contiguous()
|
||||||
|
sym_len_s = -indices.min() + 1
|
||||||
|
sym_len_e = indices.max() - in_length
|
||||||
|
indices = indices + sym_len_s - 1
|
||||||
|
return weights, indices, int(sym_len_s), int(sym_len_e)
|
||||||
|
|
||||||
|
|
||||||
|
def imresize(img, scale, antialiasing=True):
|
||||||
|
# Now the scale should be the same for H and W
|
||||||
|
# input: img: CHW RGB [0,1]
|
||||||
|
# output: CHW RGB [0,1] w/o round
|
||||||
|
|
||||||
|
in_C, in_H, in_W = img.size()
|
||||||
|
_, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
||||||
|
kernel_width = 4
|
||||||
|
kernel = 'cubic'
|
||||||
|
|
||||||
|
# Return the desired dimension order for performing the resize. The
|
||||||
|
# strategy is to perform the resize first along the dimension with the
|
||||||
|
# smallest scale factor.
|
||||||
|
# Now we do not support this.
|
||||||
|
|
||||||
|
# get weights and indices
|
||||||
|
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
||||||
|
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
||||||
|
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
||||||
|
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
||||||
|
# process H dimension
|
||||||
|
# symmetric copying
|
||||||
|
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
|
||||||
|
img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
|
||||||
|
|
||||||
|
sym_patch = img[:, :sym_len_Hs, :]
|
||||||
|
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||||
|
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||||
|
img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
|
||||||
|
|
||||||
|
sym_patch = img[:, -sym_len_He:, :]
|
||||||
|
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||||
|
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||||
|
img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
||||||
|
|
||||||
|
out_1 = torch.FloatTensor(in_C, out_H, in_W)
|
||||||
|
kernel_width = weights_H.size(1)
|
||||||
|
for i in range(out_H):
|
||||||
|
idx = int(indices_H[i][0])
|
||||||
|
out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
||||||
|
out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
||||||
|
out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
||||||
|
|
||||||
|
# process W dimension
|
||||||
|
# symmetric copying
|
||||||
|
out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
|
||||||
|
out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
|
||||||
|
|
||||||
|
sym_patch = out_1[:, :, :sym_len_Ws]
|
||||||
|
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
||||||
|
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
||||||
|
out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
|
||||||
|
|
||||||
|
sym_patch = out_1[:, :, -sym_len_We:]
|
||||||
|
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
||||||
|
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
||||||
|
out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
||||||
|
|
||||||
|
out_2 = torch.FloatTensor(in_C, out_H, out_W)
|
||||||
|
kernel_width = weights_W.size(1)
|
||||||
|
for i in range(out_W):
|
||||||
|
idx = int(indices_W[i][0])
|
||||||
|
out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
|
||||||
|
out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
|
||||||
|
out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])
|
||||||
|
|
||||||
|
return out_2
|
||||||
|
|
||||||
|
|
||||||
|
def imresize_np(img, scale, antialiasing=True):
|
||||||
|
# Now the scale should be the same for H and W
|
||||||
|
# input: img: Numpy, HWC BGR [0,1]
|
||||||
|
# output: HWC BGR [0,1] w/o round
|
||||||
|
img = torch.from_numpy(img)
|
||||||
|
|
||||||
|
in_H, in_W, in_C = img.size()
|
||||||
|
_, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
||||||
|
kernel_width = 4
|
||||||
|
kernel = 'cubic'
|
||||||
|
|
||||||
|
# Return the desired dimension order for performing the resize. The
|
||||||
|
# strategy is to perform the resize first along the dimension with the
|
||||||
|
# smallest scale factor.
|
||||||
|
# Now we do not support this.
|
||||||
|
|
||||||
|
# get weights and indices
|
||||||
|
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
||||||
|
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
||||||
|
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
||||||
|
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
||||||
|
# process H dimension
|
||||||
|
# symmetric copying
|
||||||
|
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
|
||||||
|
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
|
||||||
|
|
||||||
|
sym_patch = img[:sym_len_Hs, :, :]
|
||||||
|
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
||||||
|
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
||||||
|
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
|
||||||
|
|
||||||
|
sym_patch = img[-sym_len_He:, :, :]
|
||||||
|
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
||||||
|
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
||||||
|
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
||||||
|
|
||||||
|
out_1 = torch.FloatTensor(out_H, in_W, in_C)
|
||||||
|
kernel_width = weights_H.size(1)
|
||||||
|
for i in range(out_H):
|
||||||
|
idx = int(indices_H[i][0])
|
||||||
|
out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
|
||||||
|
out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
|
||||||
|
out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])
|
||||||
|
|
||||||
|
# process W dimension
|
||||||
|
# symmetric copying
|
||||||
|
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
|
||||||
|
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
|
||||||
|
|
||||||
|
sym_patch = out_1[:, :sym_len_Ws, :]
|
||||||
|
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||||
|
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||||
|
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
|
||||||
|
|
||||||
|
sym_patch = out_1[:, -sym_len_We:, :]
|
||||||
|
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||||
|
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||||
|
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
||||||
|
|
||||||
|
out_2 = torch.FloatTensor(out_H, out_W, in_C)
|
||||||
|
kernel_width = weights_W.size(1)
|
||||||
|
for i in range(out_W):
|
||||||
|
idx = int(indices_W[i][0])
|
||||||
|
out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
|
||||||
|
out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
|
||||||
|
out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])
|
||||||
|
|
||||||
|
return out_2.numpy()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# test imresize function
|
||||||
|
# read images
|
||||||
|
img = cv2.imread('test.png')
|
||||||
|
img = img * 1.0 / 255
|
||||||
|
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
||||||
|
# imresize
|
||||||
|
scale = 1 / 4
|
||||||
|
import time
|
||||||
|
total_time = 0
|
||||||
|
for i in range(10):
|
||||||
|
start_time = time.time()
|
||||||
|
rlt = imresize(img, scale, antialiasing=True)
|
||||||
|
use_time = time.time() - start_time
|
||||||
|
total_time += use_time
|
||||||
|
print('average time: {}'.format(total_time / 10))
|
||||||
|
|
||||||
|
import torchvision.utils
|
||||||
|
torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0,
|
||||||
|
normalize=False)
|
84
codes/data/video_test_dataset.py
Normal file
84
codes/data/video_test_dataset.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
import os.path as osp
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
import data.util as util
|
||||||
|
|
||||||
|
|
||||||
|
class VideoTestDataset(data.Dataset):
|
||||||
|
"""
|
||||||
|
A video test dataset. Support:
|
||||||
|
Vid4
|
||||||
|
REDS4
|
||||||
|
Vimeo90K-Test
|
||||||
|
|
||||||
|
no need to prepare LMDB files
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, opt):
|
||||||
|
super(VideoTestDataset, self).__init__()
|
||||||
|
self.opt = opt
|
||||||
|
self.cache_data = opt['cache_data']
|
||||||
|
self.half_N_frames = opt['N_frames'] // 2
|
||||||
|
self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
|
||||||
|
self.data_type = self.opt['data_type']
|
||||||
|
self.data_info = {'path_LQ': [], 'path_GT': [], 'folder': [], 'idx': [], 'border': []}
|
||||||
|
if self.data_type == 'lmdb':
|
||||||
|
raise ValueError('No need to use LMDB during validation/test.')
|
||||||
|
#### Generate data info and cache data
|
||||||
|
self.imgs_LQ, self.imgs_GT = {}, {}
|
||||||
|
if opt['name'].lower() in ['vid4', 'reds4']:
|
||||||
|
subfolders_LQ = util.glob_file_list(self.LQ_root)
|
||||||
|
subfolders_GT = util.glob_file_list(self.GT_root)
|
||||||
|
for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT):
|
||||||
|
subfolder_name = osp.basename(subfolder_GT)
|
||||||
|
img_paths_LQ = util.glob_file_list(subfolder_LQ)
|
||||||
|
img_paths_GT = util.glob_file_list(subfolder_GT)
|
||||||
|
max_idx = len(img_paths_LQ)
|
||||||
|
assert max_idx == len(
|
||||||
|
img_paths_GT), 'Different number of images in LQ and GT folders'
|
||||||
|
self.data_info['path_LQ'].extend(img_paths_LQ)
|
||||||
|
self.data_info['path_GT'].extend(img_paths_GT)
|
||||||
|
self.data_info['folder'].extend([subfolder_name] * max_idx)
|
||||||
|
for i in range(max_idx):
|
||||||
|
self.data_info['idx'].append('{}/{}'.format(i, max_idx))
|
||||||
|
border_l = [0] * max_idx
|
||||||
|
for i in range(self.half_N_frames):
|
||||||
|
border_l[i] = 1
|
||||||
|
border_l[max_idx - i - 1] = 1
|
||||||
|
self.data_info['border'].extend(border_l)
|
||||||
|
|
||||||
|
if self.cache_data:
|
||||||
|
self.imgs_LQ[subfolder_name] = util.read_img_seq(img_paths_LQ)
|
||||||
|
self.imgs_GT[subfolder_name] = util.read_img_seq(img_paths_GT)
|
||||||
|
elif opt['name'].lower() in ['vimeo90k-test']:
|
||||||
|
pass # TODO
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'Not support video test dataset. Support Vid4, REDS4 and Vimeo90k-Test.')
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
# path_LQ = self.data_info['path_LQ'][index]
|
||||||
|
# path_GT = self.data_info['path_GT'][index]
|
||||||
|
folder = self.data_info['folder'][index]
|
||||||
|
idx, max_idx = self.data_info['idx'][index].split('/')
|
||||||
|
idx, max_idx = int(idx), int(max_idx)
|
||||||
|
border = self.data_info['border'][index]
|
||||||
|
|
||||||
|
if self.cache_data:
|
||||||
|
select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'],
|
||||||
|
padding=self.opt['padding'])
|
||||||
|
imgs_LQ = self.imgs_LQ[folder].index_select(0, torch.LongTensor(select_idx))
|
||||||
|
img_GT = self.imgs_GT[folder][idx]
|
||||||
|
else:
|
||||||
|
pass # TODO
|
||||||
|
|
||||||
|
return {
|
||||||
|
'LQs': imgs_LQ,
|
||||||
|
'GT': img_GT,
|
||||||
|
'folder': folder,
|
||||||
|
'idx': self.data_info['idx'][index],
|
||||||
|
'border': border
|
||||||
|
}
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data_info['path_GT'])
|
411
codes/data_scripts/create_lmdb.py
Normal file
411
codes/data_scripts/create_lmdb.py
Normal file
|
@ -0,0 +1,411 @@
|
||||||
|
"""Create lmdb files for [General images (291 images/DIV2K) | Vimeo90K | REDS] training datasets"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os.path as osp
|
||||||
|
import glob
|
||||||
|
import pickle
|
||||||
|
from multiprocessing import Pool
|
||||||
|
import numpy as np
|
||||||
|
import lmdb
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||||
|
import data.util as data_util # noqa: E402
|
||||||
|
import utils.util as util # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
dataset = 'DIV2K_demo' # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo |test
|
||||||
|
mode = 'GT' # used for vimeo90k and REDS datasets
|
||||||
|
# vimeo90k: GT | LR | flow
|
||||||
|
# REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp
|
||||||
|
# train_sharp_flowx4
|
||||||
|
if dataset == 'vimeo90k':
|
||||||
|
vimeo90k(mode)
|
||||||
|
elif dataset == 'REDS':
|
||||||
|
REDS(mode)
|
||||||
|
elif dataset == 'general':
|
||||||
|
opt = {}
|
||||||
|
opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
|
||||||
|
opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
|
||||||
|
opt['name'] = 'DIV2K800_sub_GT'
|
||||||
|
general_image_folder(opt)
|
||||||
|
elif dataset == 'DIV2K_demo':
|
||||||
|
opt = {}
|
||||||
|
## GT
|
||||||
|
opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
|
||||||
|
opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
|
||||||
|
opt['name'] = 'DIV2K800_sub_GT'
|
||||||
|
general_image_folder(opt)
|
||||||
|
## LR
|
||||||
|
opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
|
||||||
|
opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
|
||||||
|
opt['name'] = 'DIV2K800_sub_bicLRx4'
|
||||||
|
general_image_folder(opt)
|
||||||
|
elif dataset == 'test':
|
||||||
|
test_lmdb('../../datasets/REDS/train_sharp_wval.lmdb', 'REDS')
|
||||||
|
|
||||||
|
|
||||||
|
def read_image_worker(path, key):
|
||||||
|
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
|
return (key, img)
|
||||||
|
|
||||||
|
|
||||||
|
def general_image_folder(opt):
|
||||||
|
"""Create lmdb for general image folders
|
||||||
|
Users should define the keys, such as: '0321_s035' for DIV2K sub-images
|
||||||
|
If all the images have the same resolution, it will only store one copy of resolution info.
|
||||||
|
Otherwise, it will store every resolution info.
|
||||||
|
"""
|
||||||
|
#### configurations
|
||||||
|
read_all_imgs = False # whether real all images to memory with multiprocessing
|
||||||
|
# Set False for use limited memory
|
||||||
|
BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
|
||||||
|
n_thread = 40
|
||||||
|
########################################################
|
||||||
|
img_folder = opt['img_folder']
|
||||||
|
lmdb_save_path = opt['lmdb_save_path']
|
||||||
|
meta_info = {'name': opt['name']}
|
||||||
|
if not lmdb_save_path.endswith('.lmdb'):
|
||||||
|
raise ValueError("lmdb_save_path must end with \'lmdb\'.")
|
||||||
|
if osp.exists(lmdb_save_path):
|
||||||
|
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
#### read all the image paths to a list
|
||||||
|
print('Reading image path list ...')
|
||||||
|
all_img_list = sorted(glob.glob(osp.join(img_folder, '*')))
|
||||||
|
keys = []
|
||||||
|
for img_path in all_img_list:
|
||||||
|
keys.append(osp.splitext(osp.basename(img_path))[0])
|
||||||
|
|
||||||
|
if read_all_imgs:
|
||||||
|
#### read all images to memory (multiprocessing)
|
||||||
|
dataset = {} # store all image data. list cannot keep the order, use dict
|
||||||
|
print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
|
||||||
|
pbar = util.ProgressBar(len(all_img_list))
|
||||||
|
|
||||||
|
def mycallback(arg):
|
||||||
|
'''get the image data and update pbar'''
|
||||||
|
key = arg[0]
|
||||||
|
dataset[key] = arg[1]
|
||||||
|
pbar.update('Reading {}'.format(key))
|
||||||
|
|
||||||
|
pool = Pool(n_thread)
|
||||||
|
for path, key in zip(all_img_list, keys):
|
||||||
|
pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
|
||||||
|
|
||||||
|
#### create lmdb environment
|
||||||
|
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
|
||||||
|
print('data size per image is: ', data_size_per_img)
|
||||||
|
data_size = data_size_per_img * len(all_img_list)
|
||||||
|
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
|
||||||
|
|
||||||
|
#### write data to lmdb
|
||||||
|
pbar = util.ProgressBar(len(all_img_list))
|
||||||
|
txn = env.begin(write=True)
|
||||||
|
resolutions = []
|
||||||
|
for idx, (path, key) in enumerate(zip(all_img_list, keys)):
|
||||||
|
pbar.update('Write {}'.format(key))
|
||||||
|
key_byte = key.encode('ascii')
|
||||||
|
data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
|
if data.ndim == 2:
|
||||||
|
H, W = data.shape
|
||||||
|
C = 1
|
||||||
|
else:
|
||||||
|
H, W, C = data.shape
|
||||||
|
txn.put(key_byte, data)
|
||||||
|
resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
|
||||||
|
if not read_all_imgs and idx % BATCH == 0:
|
||||||
|
txn.commit()
|
||||||
|
txn = env.begin(write=True)
|
||||||
|
txn.commit()
|
||||||
|
env.close()
|
||||||
|
print('Finish writing lmdb.')
|
||||||
|
|
||||||
|
#### create meta information
|
||||||
|
# check whether all the images are the same size
|
||||||
|
assert len(keys) == len(resolutions)
|
||||||
|
if len(set(resolutions)) <= 1:
|
||||||
|
meta_info['resolution'] = [resolutions[0]]
|
||||||
|
meta_info['keys'] = keys
|
||||||
|
print('All images have the same resolution. Simplify the meta info.')
|
||||||
|
else:
|
||||||
|
meta_info['resolution'] = resolutions
|
||||||
|
meta_info['keys'] = keys
|
||||||
|
print('Not all images have the same resolution. Save meta info for each image.')
|
||||||
|
|
||||||
|
pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
|
||||||
|
print('Finish creating lmdb meta info.')
|
||||||
|
|
||||||
|
|
||||||
|
def vimeo90k(mode):
|
||||||
|
"""Create lmdb for the Vimeo90K dataset, each image with a fixed size
|
||||||
|
GT: [3, 256, 448]
|
||||||
|
Now only need the 4th frame, e.g., 00001_0001_4
|
||||||
|
LR: [3, 64, 112]
|
||||||
|
1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
|
||||||
|
key:
|
||||||
|
Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001
|
||||||
|
|
||||||
|
flow: downsampled flow: [3, 360, 320], keys: 00001_0001_4_[p3, p2, p1, n1, n2, n3]
|
||||||
|
Each flow is calculated with GT images by PWCNet and then downsampled by 1/4
|
||||||
|
Flow map is quantized by mmcv and saved in png format
|
||||||
|
"""
|
||||||
|
#### configurations
|
||||||
|
read_all_imgs = False # whether real all images to memory with multiprocessing
|
||||||
|
# Set False for use limited memory
|
||||||
|
BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
|
||||||
|
if mode == 'GT':
|
||||||
|
img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
|
||||||
|
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
|
||||||
|
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
|
||||||
|
H_dst, W_dst = 256, 448
|
||||||
|
elif mode == 'LR':
|
||||||
|
img_folder = '../../datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences'
|
||||||
|
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
|
||||||
|
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
|
||||||
|
H_dst, W_dst = 64, 112
|
||||||
|
elif mode == 'flow':
|
||||||
|
img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences_flowx4'
|
||||||
|
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_flowx4.lmdb'
|
||||||
|
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
|
||||||
|
H_dst, W_dst = 128, 112
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong dataset mode: {}'.format(mode))
|
||||||
|
n_thread = 40
|
||||||
|
########################################################
|
||||||
|
if not lmdb_save_path.endswith('.lmdb'):
|
||||||
|
raise ValueError("lmdb_save_path must end with \'lmdb\'.")
|
||||||
|
if osp.exists(lmdb_save_path):
|
||||||
|
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
#### read all the image paths to a list
|
||||||
|
print('Reading image path list ...')
|
||||||
|
with open(txt_file) as f:
|
||||||
|
train_l = f.readlines()
|
||||||
|
train_l = [v.strip() for v in train_l]
|
||||||
|
all_img_list = []
|
||||||
|
keys = []
|
||||||
|
for line in train_l:
|
||||||
|
folder = line.split('/')[0]
|
||||||
|
sub_folder = line.split('/')[1]
|
||||||
|
all_img_list.extend(glob.glob(osp.join(img_folder, folder, sub_folder, '*')))
|
||||||
|
if mode == 'flow':
|
||||||
|
for j in range(1, 4):
|
||||||
|
keys.append('{}_{}_4_n{}'.format(folder, sub_folder, j))
|
||||||
|
keys.append('{}_{}_4_p{}'.format(folder, sub_folder, j))
|
||||||
|
else:
|
||||||
|
for j in range(7):
|
||||||
|
keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
|
||||||
|
all_img_list = sorted(all_img_list)
|
||||||
|
keys = sorted(keys)
|
||||||
|
if mode == 'GT': # only read the 4th frame for the GT mode
|
||||||
|
print('Only keep the 4th frame.')
|
||||||
|
all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
|
||||||
|
keys = [v for v in keys if v.endswith('_4')]
|
||||||
|
|
||||||
|
if read_all_imgs:
|
||||||
|
#### read all images to memory (multiprocessing)
|
||||||
|
dataset = {} # store all image data. list cannot keep the order, use dict
|
||||||
|
print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
|
||||||
|
pbar = util.ProgressBar(len(all_img_list))
|
||||||
|
|
||||||
|
def mycallback(arg):
|
||||||
|
"""get the image data and update pbar"""
|
||||||
|
key = arg[0]
|
||||||
|
dataset[key] = arg[1]
|
||||||
|
pbar.update('Reading {}'.format(key))
|
||||||
|
|
||||||
|
pool = Pool(n_thread)
|
||||||
|
for path, key in zip(all_img_list, keys):
|
||||||
|
pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
|
||||||
|
|
||||||
|
#### write data to lmdb
|
||||||
|
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
|
||||||
|
print('data size per image is: ', data_size_per_img)
|
||||||
|
data_size = data_size_per_img * len(all_img_list)
|
||||||
|
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
|
||||||
|
txn = env.begin(write=True)
|
||||||
|
pbar = util.ProgressBar(len(all_img_list))
|
||||||
|
for idx, (path, key) in enumerate(zip(all_img_list, keys)):
|
||||||
|
pbar.update('Write {}'.format(key))
|
||||||
|
key_byte = key.encode('ascii')
|
||||||
|
data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
|
if 'flow' in mode:
|
||||||
|
H, W = data.shape
|
||||||
|
assert H == H_dst and W == W_dst, 'different shape.'
|
||||||
|
else:
|
||||||
|
H, W, C = data.shape
|
||||||
|
assert H == H_dst and W == W_dst and C == 3, 'different shape.'
|
||||||
|
txn.put(key_byte, data)
|
||||||
|
if not read_all_imgs and idx % BATCH == 0:
|
||||||
|
txn.commit()
|
||||||
|
txn = env.begin(write=True)
|
||||||
|
txn.commit()
|
||||||
|
env.close()
|
||||||
|
print('Finish writing lmdb.')
|
||||||
|
|
||||||
|
#### create meta information
|
||||||
|
meta_info = {}
|
||||||
|
if mode == 'GT':
|
||||||
|
meta_info['name'] = 'Vimeo90K_train_GT'
|
||||||
|
elif mode == 'LR':
|
||||||
|
meta_info['name'] = 'Vimeo90K_train_LR'
|
||||||
|
elif mode == 'flow':
|
||||||
|
meta_info['name'] = 'Vimeo90K_train_flowx4'
|
||||||
|
channel = 1 if 'flow' in mode else 3
|
||||||
|
meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
|
||||||
|
key_set = set()
|
||||||
|
for key in keys:
|
||||||
|
if mode == 'flow':
|
||||||
|
a, b, _, _ = key.split('_')
|
||||||
|
else:
|
||||||
|
a, b, _ = key.split('_')
|
||||||
|
key_set.add('{}_{}'.format(a, b))
|
||||||
|
meta_info['keys'] = list(key_set)
|
||||||
|
pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
|
||||||
|
print('Finish creating lmdb meta info.')
|
||||||
|
|
||||||
|
|
||||||
|
def REDS(mode):
|
||||||
|
"""Create lmdb for the REDS dataset, each image with a fixed size
|
||||||
|
GT: [3, 720, 1280], key: 000_00000000
|
||||||
|
LR: [3, 180, 320], key: 000_00000000
|
||||||
|
key: 000_00000000
|
||||||
|
|
||||||
|
flow: downsampled flow: [3, 360, 320], keys: 000_00000005_[p2, p1, n1, n2]
|
||||||
|
Each flow is calculated with the GT images by PWCNet and then downsampled by 1/4
|
||||||
|
Flow map is quantized by mmcv and saved in png format
|
||||||
|
"""
|
||||||
|
#### configurations
|
||||||
|
read_all_imgs = False # whether real all images to memory with multiprocessing
|
||||||
|
# Set False for use limited memory
|
||||||
|
BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
|
||||||
|
if mode == 'train_sharp':
|
||||||
|
img_folder = '../../datasets/REDS/train_sharp'
|
||||||
|
lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb'
|
||||||
|
H_dst, W_dst = 720, 1280
|
||||||
|
elif mode == 'train_sharp_bicubic':
|
||||||
|
img_folder = '../../datasets/REDS/train_sharp_bicubic'
|
||||||
|
lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
|
||||||
|
H_dst, W_dst = 180, 320
|
||||||
|
elif mode == 'train_blur_bicubic':
|
||||||
|
img_folder = '../../datasets/REDS/train_blur_bicubic'
|
||||||
|
lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb'
|
||||||
|
H_dst, W_dst = 180, 320
|
||||||
|
elif mode == 'train_blur':
|
||||||
|
img_folder = '../../datasets/REDS/train_blur'
|
||||||
|
lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb'
|
||||||
|
H_dst, W_dst = 720, 1280
|
||||||
|
elif mode == 'train_blur_comp':
|
||||||
|
img_folder = '../../datasets/REDS/train_blur_comp'
|
||||||
|
lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb'
|
||||||
|
H_dst, W_dst = 720, 1280
|
||||||
|
elif mode == 'train_sharp_flowx4':
|
||||||
|
img_folder = '../../datasets/REDS/train_sharp_flowx4'
|
||||||
|
lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb'
|
||||||
|
H_dst, W_dst = 360, 320
|
||||||
|
n_thread = 40
|
||||||
|
########################################################
|
||||||
|
if not lmdb_save_path.endswith('.lmdb'):
|
||||||
|
raise ValueError("lmdb_save_path must end with \'lmdb\'.")
|
||||||
|
if osp.exists(lmdb_save_path):
|
||||||
|
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
#### read all the image paths to a list
|
||||||
|
print('Reading image path list ...')
|
||||||
|
all_img_list = data_util._get_paths_from_images(img_folder)
|
||||||
|
keys = []
|
||||||
|
for img_path in all_img_list:
|
||||||
|
split_rlt = img_path.split('/')
|
||||||
|
folder = split_rlt[-2]
|
||||||
|
img_name = split_rlt[-1].split('.png')[0]
|
||||||
|
keys.append(folder + '_' + img_name)
|
||||||
|
|
||||||
|
if read_all_imgs:
|
||||||
|
#### read all images to memory (multiprocessing)
|
||||||
|
dataset = {} # store all image data. list cannot keep the order, use dict
|
||||||
|
print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
|
||||||
|
pbar = util.ProgressBar(len(all_img_list))
|
||||||
|
|
||||||
|
def mycallback(arg):
|
||||||
|
'''get the image data and update pbar'''
|
||||||
|
key = arg[0]
|
||||||
|
dataset[key] = arg[1]
|
||||||
|
pbar.update('Reading {}'.format(key))
|
||||||
|
|
||||||
|
pool = Pool(n_thread)
|
||||||
|
for path, key in zip(all_img_list, keys):
|
||||||
|
pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
|
||||||
|
|
||||||
|
#### create lmdb environment
|
||||||
|
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
|
||||||
|
print('data size per image is: ', data_size_per_img)
|
||||||
|
data_size = data_size_per_img * len(all_img_list)
|
||||||
|
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
|
||||||
|
|
||||||
|
#### write data to lmdb
|
||||||
|
pbar = util.ProgressBar(len(all_img_list))
|
||||||
|
txn = env.begin(write=True)
|
||||||
|
for idx, (path, key) in enumerate(zip(all_img_list, keys)):
|
||||||
|
pbar.update('Write {}'.format(key))
|
||||||
|
key_byte = key.encode('ascii')
|
||||||
|
data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
|
if 'flow' in mode:
|
||||||
|
H, W = data.shape
|
||||||
|
assert H == H_dst and W == W_dst, 'different shape.'
|
||||||
|
else:
|
||||||
|
H, W, C = data.shape
|
||||||
|
assert H == H_dst and W == W_dst and C == 3, 'different shape.'
|
||||||
|
txn.put(key_byte, data)
|
||||||
|
if not read_all_imgs and idx % BATCH == 0:
|
||||||
|
txn.commit()
|
||||||
|
txn = env.begin(write=True)
|
||||||
|
txn.commit()
|
||||||
|
env.close()
|
||||||
|
print('Finish writing lmdb.')
|
||||||
|
|
||||||
|
#### create meta information
|
||||||
|
meta_info = {}
|
||||||
|
meta_info['name'] = 'REDS_{}_wval'.format(mode)
|
||||||
|
channel = 1 if 'flow' in mode else 3
|
||||||
|
meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
|
||||||
|
meta_info['keys'] = keys
|
||||||
|
pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
|
||||||
|
print('Finish creating lmdb meta info.')
|
||||||
|
|
||||||
|
|
||||||
|
def test_lmdb(dataroot, dataset='REDS'):
|
||||||
|
env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
|
||||||
|
meta_info = pickle.load(open(osp.join(dataroot, 'meta_info.pkl'), "rb"))
|
||||||
|
print('Name: ', meta_info['name'])
|
||||||
|
print('Resolution: ', meta_info['resolution'])
|
||||||
|
print('# keys: ', len(meta_info['keys']))
|
||||||
|
# read one image
|
||||||
|
if dataset == 'vimeo90k':
|
||||||
|
key = '00001_0001_4'
|
||||||
|
else:
|
||||||
|
key = '000_00000000'
|
||||||
|
print('Reading {} for test.'.format(key))
|
||||||
|
with env.begin(write=False) as txn:
|
||||||
|
buf = txn.get(key.encode('ascii'))
|
||||||
|
img_flat = np.frombuffer(buf, dtype=np.uint8)
|
||||||
|
C, H, W = [int(s) for s in meta_info['resolution'].split('_')]
|
||||||
|
img = img_flat.reshape(H, W, C)
|
||||||
|
cv2.imwrite('test.png', img)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
141
codes/data_scripts/extract_subimages.py
Normal file
141
codes/data_scripts/extract_subimages.py
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
"""A multi-thread tool to crop large images to sub-images for faster IO."""
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import sys
|
||||||
|
from multiprocessing import Pool
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from PIL import Image
|
||||||
|
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||||
|
from utils.util import ProgressBar # noqa: E402
|
||||||
|
import data.util as data_util # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs)
|
||||||
|
opt = {}
|
||||||
|
opt['n_thread'] = 20
|
||||||
|
opt['compression_level'] = 3 # 3 is the default value in cv2
|
||||||
|
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
||||||
|
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||||
|
if mode == 'single':
|
||||||
|
opt['input_folder'] = '../../datasets/DIV2K/DIV2K_train_HR'
|
||||||
|
opt['save_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
|
||||||
|
opt['crop_sz'] = 480 # the size of each sub-image
|
||||||
|
opt['step'] = 240 # step of the sliding crop window
|
||||||
|
opt['thres_sz'] = 48 # size threshold
|
||||||
|
extract_signle(opt)
|
||||||
|
elif mode == 'pair':
|
||||||
|
GT_folder = '../../datasets/DIV2K/DIV2K_train_HR'
|
||||||
|
LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
|
||||||
|
save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub'
|
||||||
|
save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
|
||||||
|
scale_ratio = 4
|
||||||
|
crop_sz = 480 # the size of each sub-image (GT)
|
||||||
|
step = 240 # step of the sliding crop window (GT)
|
||||||
|
thres_sz = 48 # size threshold
|
||||||
|
########################################################################
|
||||||
|
# check that all the GT and LR images have correct scale ratio
|
||||||
|
img_GT_list = data_util._get_paths_from_images(GT_folder)
|
||||||
|
img_LR_list = data_util._get_paths_from_images(LR_folder)
|
||||||
|
assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.'
|
||||||
|
for path_GT, path_LR in zip(img_GT_list, img_LR_list):
|
||||||
|
img_GT = Image.open(path_GT)
|
||||||
|
img_LR = Image.open(path_LR)
|
||||||
|
w_GT, h_GT = img_GT.size
|
||||||
|
w_LR, h_LR = img_LR.size
|
||||||
|
assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501
|
||||||
|
w_GT, scale_ratio, w_LR, path_GT)
|
||||||
|
assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501
|
||||||
|
w_GT, scale_ratio, w_LR, path_GT)
|
||||||
|
# check crop size, step and threshold size
|
||||||
|
assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format(
|
||||||
|
scale_ratio)
|
||||||
|
assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio)
|
||||||
|
assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format(
|
||||||
|
scale_ratio)
|
||||||
|
print('process GT...')
|
||||||
|
opt['input_folder'] = GT_folder
|
||||||
|
opt['save_folder'] = save_GT_folder
|
||||||
|
opt['crop_sz'] = crop_sz
|
||||||
|
opt['step'] = step
|
||||||
|
opt['thres_sz'] = thres_sz
|
||||||
|
extract_signle(opt)
|
||||||
|
print('process LR...')
|
||||||
|
opt['input_folder'] = LR_folder
|
||||||
|
opt['save_folder'] = save_LR_folder
|
||||||
|
opt['crop_sz'] = crop_sz // scale_ratio
|
||||||
|
opt['step'] = step // scale_ratio
|
||||||
|
opt['thres_sz'] = thres_sz // scale_ratio
|
||||||
|
extract_signle(opt)
|
||||||
|
assert len(data_util._get_paths_from_images(save_GT_folder)) == len(
|
||||||
|
data_util._get_paths_from_images(
|
||||||
|
save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.'
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong mode.')
|
||||||
|
|
||||||
|
|
||||||
|
def extract_signle(opt):
|
||||||
|
input_folder = opt['input_folder']
|
||||||
|
save_folder = opt['save_folder']
|
||||||
|
if not osp.exists(save_folder):
|
||||||
|
os.makedirs(save_folder)
|
||||||
|
print('mkdir [{:s}] ...'.format(save_folder))
|
||||||
|
else:
|
||||||
|
print('Folder [{:s}] already exists. Exit...'.format(save_folder))
|
||||||
|
sys.exit(1)
|
||||||
|
img_list = data_util._get_paths_from_images(input_folder)
|
||||||
|
|
||||||
|
def update(arg):
|
||||||
|
pbar.update(arg)
|
||||||
|
|
||||||
|
pbar = ProgressBar(len(img_list))
|
||||||
|
|
||||||
|
pool = Pool(opt['n_thread'])
|
||||||
|
for path in img_list:
|
||||||
|
pool.apply_async(worker, args=(path, opt), callback=update)
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
print('All subprocesses done.')
|
||||||
|
|
||||||
|
|
||||||
|
def worker(path, opt):
|
||||||
|
crop_sz = opt['crop_sz']
|
||||||
|
step = opt['step']
|
||||||
|
thres_sz = opt['thres_sz']
|
||||||
|
img_name = osp.basename(path)
|
||||||
|
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
|
n_channels = len(img.shape)
|
||||||
|
if n_channels == 2:
|
||||||
|
h, w = img.shape
|
||||||
|
elif n_channels == 3:
|
||||||
|
h, w, c = img.shape
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong image shape - {}'.format(n_channels))
|
||||||
|
|
||||||
|
h_space = np.arange(0, h - crop_sz + 1, step)
|
||||||
|
if h - (h_space[-1] + crop_sz) > thres_sz:
|
||||||
|
h_space = np.append(h_space, h - crop_sz)
|
||||||
|
w_space = np.arange(0, w - crop_sz + 1, step)
|
||||||
|
if w - (w_space[-1] + crop_sz) > thres_sz:
|
||||||
|
w_space = np.append(w_space, w - crop_sz)
|
||||||
|
|
||||||
|
index = 0
|
||||||
|
for x in h_space:
|
||||||
|
for y in w_space:
|
||||||
|
index += 1
|
||||||
|
if n_channels == 2:
|
||||||
|
crop_img = img[x:x + crop_sz, y:y + crop_sz]
|
||||||
|
else:
|
||||||
|
crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
|
||||||
|
crop_img = np.ascontiguousarray(crop_img)
|
||||||
|
cv2.imwrite(
|
||||||
|
osp.join(opt['save_folder'],
|
||||||
|
img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img,
|
||||||
|
[cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
|
||||||
|
return 'Processing {:s} ...'.format(img_name)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
49
codes/data_scripts/generate_LR_Vimeo90K.m
Normal file
49
codes/data_scripts/generate_LR_Vimeo90K.m
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
function generate_LR_Vimeo90K()
|
||||||
|
%% matlab code to genetate bicubic-downsampled for Vimeo90K dataset
|
||||||
|
|
||||||
|
up_scale = 4;
|
||||||
|
mod_scale = 4;
|
||||||
|
idx = 0;
|
||||||
|
filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png');
|
||||||
|
for i = 1 : length(filepaths)
|
||||||
|
[~,imname,ext] = fileparts(filepaths(i).name);
|
||||||
|
folder_path = filepaths(i).folder;
|
||||||
|
save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4');
|
||||||
|
if ~exist(save_LR_folder, 'dir')
|
||||||
|
mkdir(save_LR_folder);
|
||||||
|
end
|
||||||
|
if isempty(imname)
|
||||||
|
disp('Ignore . folder.');
|
||||||
|
elseif strcmp(imname, '.')
|
||||||
|
disp('Ignore .. folder.');
|
||||||
|
else
|
||||||
|
idx = idx + 1;
|
||||||
|
str_rlt = sprintf('%d\t%s.\n', idx, imname);
|
||||||
|
fprintf(str_rlt);
|
||||||
|
% read image
|
||||||
|
img = imread(fullfile(folder_path, [imname, ext]));
|
||||||
|
img = im2double(img);
|
||||||
|
% modcrop
|
||||||
|
img = modcrop(img, mod_scale);
|
||||||
|
% LR
|
||||||
|
im_LR = imresize(img, 1/up_scale, 'bicubic');
|
||||||
|
if exist('save_LR_folder', 'var')
|
||||||
|
imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
%% modcrop
|
||||||
|
function img = modcrop(img, modulo)
|
||||||
|
if size(img,3) == 1
|
||||||
|
sz = size(img);
|
||||||
|
sz = sz - mod(sz, modulo);
|
||||||
|
img = img(1:sz(1), 1:sz(2));
|
||||||
|
else
|
||||||
|
tmpsz = size(img);
|
||||||
|
sz = tmpsz(1:2);
|
||||||
|
sz = sz - mod(sz, modulo);
|
||||||
|
img = img(1:sz(1), 1:sz(2),:);
|
||||||
|
end
|
||||||
|
end
|
82
codes/data_scripts/generate_mod_LR_bic.m
Normal file
82
codes/data_scripts/generate_mod_LR_bic.m
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
function generate_mod_LR_bic()
|
||||||
|
%% matlab code to genetate mod images, bicubic-downsampled LR, bicubic_upsampled images.
|
||||||
|
|
||||||
|
%% set parameters
|
||||||
|
% comment the unnecessary line
|
||||||
|
input_folder = '../../datasets/DIV2K/DIV2K800';
|
||||||
|
% save_mod_folder = '';
|
||||||
|
save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4';
|
||||||
|
% save_bic_folder = '';
|
||||||
|
|
||||||
|
up_scale = 4;
|
||||||
|
mod_scale = 4;
|
||||||
|
|
||||||
|
if exist('save_mod_folder', 'var')
|
||||||
|
if exist(save_mod_folder, 'dir')
|
||||||
|
disp(['It will cover ', save_mod_folder]);
|
||||||
|
else
|
||||||
|
mkdir(save_mod_folder);
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if exist('save_LR_folder', 'var')
|
||||||
|
if exist(save_LR_folder, 'dir')
|
||||||
|
disp(['It will cover ', save_LR_folder]);
|
||||||
|
else
|
||||||
|
mkdir(save_LR_folder);
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if exist('save_bic_folder', 'var')
|
||||||
|
if exist(save_bic_folder, 'dir')
|
||||||
|
disp(['It will cover ', save_bic_folder]);
|
||||||
|
else
|
||||||
|
mkdir(save_bic_folder);
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
idx = 0;
|
||||||
|
filepaths = dir(fullfile(input_folder,'*.*'));
|
||||||
|
for i = 1 : length(filepaths)
|
||||||
|
[paths,imname,ext] = fileparts(filepaths(i).name);
|
||||||
|
if isempty(imname)
|
||||||
|
disp('Ignore . folder.');
|
||||||
|
elseif strcmp(imname, '.')
|
||||||
|
disp('Ignore .. folder.');
|
||||||
|
else
|
||||||
|
idx = idx + 1;
|
||||||
|
str_rlt = sprintf('%d\t%s.\n', idx, imname);
|
||||||
|
fprintf(str_rlt);
|
||||||
|
% read image
|
||||||
|
img = imread(fullfile(input_folder, [imname, ext]));
|
||||||
|
img = im2double(img);
|
||||||
|
% modcrop
|
||||||
|
img = modcrop(img, mod_scale);
|
||||||
|
if exist('save_mod_folder', 'var')
|
||||||
|
imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
|
||||||
|
end
|
||||||
|
% LR
|
||||||
|
im_LR = imresize(img, 1/up_scale, 'bicubic');
|
||||||
|
if exist('save_LR_folder', 'var')
|
||||||
|
imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
|
||||||
|
end
|
||||||
|
% Bicubic
|
||||||
|
if exist('save_bic_folder', 'var')
|
||||||
|
im_B = imresize(im_LR, up_scale, 'bicubic');
|
||||||
|
imwrite(im_B, fullfile(save_bic_folder, [imname, '.png']));
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
%% modcrop
|
||||||
|
function img = modcrop(img, modulo)
|
||||||
|
if size(img,3) == 1
|
||||||
|
sz = size(img);
|
||||||
|
sz = sz - mod(sz, modulo);
|
||||||
|
img = img(1:sz(1), 1:sz(2));
|
||||||
|
else
|
||||||
|
tmpsz = size(img);
|
||||||
|
sz = tmpsz(1:2);
|
||||||
|
sz = sz - mod(sz, modulo);
|
||||||
|
img = img(1:sz(1), 1:sz(2),:);
|
||||||
|
end
|
||||||
|
end
|
81
codes/data_scripts/generate_mod_LR_bic.py
Normal file
81
codes/data_scripts/generate_mod_LR_bic.py
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from data.util import imresize_np
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def generate_mod_LR_bic():
|
||||||
|
# set parameters
|
||||||
|
up_scale = 4
|
||||||
|
mod_scale = 4
|
||||||
|
# set data dir
|
||||||
|
sourcedir = '/data/datasets/img'
|
||||||
|
savedir = '/data/datasets/mod'
|
||||||
|
|
||||||
|
saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale))
|
||||||
|
saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
|
||||||
|
saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))
|
||||||
|
|
||||||
|
if not os.path.isdir(sourcedir):
|
||||||
|
print('Error: No source data found')
|
||||||
|
exit(0)
|
||||||
|
if not os.path.isdir(savedir):
|
||||||
|
os.mkdir(savedir)
|
||||||
|
|
||||||
|
if not os.path.isdir(os.path.join(savedir, 'HR')):
|
||||||
|
os.mkdir(os.path.join(savedir, 'HR'))
|
||||||
|
if not os.path.isdir(os.path.join(savedir, 'LR')):
|
||||||
|
os.mkdir(os.path.join(savedir, 'LR'))
|
||||||
|
if not os.path.isdir(os.path.join(savedir, 'Bic')):
|
||||||
|
os.mkdir(os.path.join(savedir, 'Bic'))
|
||||||
|
|
||||||
|
if not os.path.isdir(saveHRpath):
|
||||||
|
os.mkdir(saveHRpath)
|
||||||
|
else:
|
||||||
|
print('It will cover ' + str(saveHRpath))
|
||||||
|
|
||||||
|
if not os.path.isdir(saveLRpath):
|
||||||
|
os.mkdir(saveLRpath)
|
||||||
|
else:
|
||||||
|
print('It will cover ' + str(saveLRpath))
|
||||||
|
|
||||||
|
if not os.path.isdir(saveBicpath):
|
||||||
|
os.mkdir(saveBicpath)
|
||||||
|
else:
|
||||||
|
print('It will cover ' + str(saveBicpath))
|
||||||
|
|
||||||
|
filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')]
|
||||||
|
num_files = len(filepaths)
|
||||||
|
|
||||||
|
# prepare data with augementation
|
||||||
|
for i in range(num_files):
|
||||||
|
filename = filepaths[i]
|
||||||
|
print('No.{} -- Processing {}'.format(i, filename))
|
||||||
|
# read image
|
||||||
|
image = cv2.imread(os.path.join(sourcedir, filename))
|
||||||
|
|
||||||
|
width = int(np.floor(image.shape[1] / mod_scale))
|
||||||
|
height = int(np.floor(image.shape[0] / mod_scale))
|
||||||
|
# modcrop
|
||||||
|
if len(image.shape) == 3:
|
||||||
|
image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
|
||||||
|
else:
|
||||||
|
image_HR = image[0:mod_scale * height, 0:mod_scale * width]
|
||||||
|
# LR
|
||||||
|
image_LR = imresize_np(image_HR, 1 / up_scale, True)
|
||||||
|
# bic
|
||||||
|
image_Bic = imresize_np(image_LR, up_scale, True)
|
||||||
|
|
||||||
|
cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
|
||||||
|
cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
|
||||||
|
cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
generate_mod_LR_bic()
|
42
codes/data_scripts/prepare_DIV2K_x4_dataset.sh
Normal file
42
codes/data_scripts/prepare_DIV2K_x4_dataset.sh
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
|
||||||
|
|
||||||
|
echo "Prepare DIV2K X4 datasets..."
|
||||||
|
cd ../../datasets
|
||||||
|
mkdir DIV2K
|
||||||
|
cd DIV2K
|
||||||
|
|
||||||
|
#### Step 1
|
||||||
|
echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..."
|
||||||
|
# GT
|
||||||
|
FOLDER=DIV2K_train_HR
|
||||||
|
FILE=DIV2K_train_HR.zip
|
||||||
|
if [ ! -d "$FOLDER" ]; then
|
||||||
|
if [ ! -f "$FILE" ]; then
|
||||||
|
echo "Downloading $FILE..."
|
||||||
|
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
|
||||||
|
fi
|
||||||
|
unzip $FILE
|
||||||
|
fi
|
||||||
|
# LR
|
||||||
|
FOLDER=DIV2K_train_LR_bicubic
|
||||||
|
FILE=DIV2K_train_LR_bicubic_X4.zip
|
||||||
|
if [ ! -d "$FOLDER" ]; then
|
||||||
|
if [ ! -f "$FILE" ]; then
|
||||||
|
echo "Downloading $FILE..."
|
||||||
|
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
|
||||||
|
fi
|
||||||
|
unzip $FILE
|
||||||
|
fi
|
||||||
|
|
||||||
|
#### Step 2
|
||||||
|
echo "Step 2: Rename the LR images..."
|
||||||
|
cd ../../codes/data_scripts
|
||||||
|
python rename.py
|
||||||
|
|
||||||
|
#### Step 4
|
||||||
|
echo "Step 4: Crop to sub-images..."
|
||||||
|
python extract_subimages.py
|
||||||
|
|
||||||
|
#### Step 5
|
||||||
|
echo "Step5: Create LMDB files..."
|
||||||
|
python create_lmdb.py
|
11
codes/data_scripts/regroup_REDS.py
Normal file
11
codes/data_scripts/regroup_REDS.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
|
||||||
|
train_path = '/home/xtwang/datasets/REDS/train_sharp_bicubic/X4'
|
||||||
|
val_path = '/home/xtwang/datasets/REDS/val_sharp_bicubic/X4'
|
||||||
|
|
||||||
|
# mv the val set
|
||||||
|
val_folders = glob.glob(os.path.join(val_path, '*'))
|
||||||
|
for folder in val_folders:
|
||||||
|
new_folder_idx = '{:03d}'.format(int(folder.split('/')[-1]) + 240)
|
||||||
|
os.system('cp -r {} {}'.format(folder, os.path.join(train_path, new_folder_idx)))
|
19
codes/data_scripts/rename.py
Normal file
19
codes/data_scripts/rename.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
|
||||||
|
DIV2K(folder)
|
||||||
|
print('Finished.')
|
||||||
|
|
||||||
|
|
||||||
|
def DIV2K(path):
|
||||||
|
img_path_l = glob.glob(os.path.join(path, '*'))
|
||||||
|
for img_path in img_path_l:
|
||||||
|
new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
|
||||||
|
os.rename(img_path, new_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
104
codes/data_scripts/test_dataloader.py
Normal file
104
codes/data_scripts/test_dataloader.py
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
import sys
|
||||||
|
import os.path as osp
|
||||||
|
import math
|
||||||
|
import torchvision.utils
|
||||||
|
|
||||||
|
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||||
|
from data import create_dataloader, create_dataset # noqa: E402
|
||||||
|
from utils import util # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub
|
||||||
|
opt = {}
|
||||||
|
opt['dist'] = False
|
||||||
|
opt['gpu_ids'] = [0]
|
||||||
|
if dataset == 'REDS':
|
||||||
|
opt['name'] = 'test_REDS'
|
||||||
|
opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb'
|
||||||
|
opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
|
||||||
|
opt['mode'] = 'REDS'
|
||||||
|
opt['N_frames'] = 5
|
||||||
|
opt['phase'] = 'train'
|
||||||
|
opt['use_shuffle'] = True
|
||||||
|
opt['n_workers'] = 8
|
||||||
|
opt['batch_size'] = 16
|
||||||
|
opt['GT_size'] = 256
|
||||||
|
opt['LQ_size'] = 64
|
||||||
|
opt['scale'] = 4
|
||||||
|
opt['use_flip'] = True
|
||||||
|
opt['use_rot'] = True
|
||||||
|
opt['interval_list'] = [1]
|
||||||
|
opt['random_reverse'] = False
|
||||||
|
opt['border_mode'] = False
|
||||||
|
opt['cache_keys'] = None
|
||||||
|
opt['data_type'] = 'lmdb' # img | lmdb | mc
|
||||||
|
elif dataset == 'Vimeo90K':
|
||||||
|
opt['name'] = 'test_Vimeo90K'
|
||||||
|
opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
|
||||||
|
opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
|
||||||
|
opt['mode'] = 'Vimeo90K'
|
||||||
|
opt['N_frames'] = 7
|
||||||
|
opt['phase'] = 'train'
|
||||||
|
opt['use_shuffle'] = True
|
||||||
|
opt['n_workers'] = 8
|
||||||
|
opt['batch_size'] = 16
|
||||||
|
opt['GT_size'] = 256
|
||||||
|
opt['LQ_size'] = 64
|
||||||
|
opt['scale'] = 4
|
||||||
|
opt['use_flip'] = True
|
||||||
|
opt['use_rot'] = True
|
||||||
|
opt['interval_list'] = [1]
|
||||||
|
opt['random_reverse'] = False
|
||||||
|
opt['border_mode'] = False
|
||||||
|
opt['cache_keys'] = None
|
||||||
|
opt['data_type'] = 'lmdb' # img | lmdb | mc
|
||||||
|
elif dataset == 'DIV2K800_sub':
|
||||||
|
opt['name'] = 'DIV2K800'
|
||||||
|
opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
|
||||||
|
opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
|
||||||
|
opt['mode'] = 'LQGT'
|
||||||
|
opt['phase'] = 'train'
|
||||||
|
opt['use_shuffle'] = True
|
||||||
|
opt['n_workers'] = 8
|
||||||
|
opt['batch_size'] = 16
|
||||||
|
opt['GT_size'] = 128
|
||||||
|
opt['scale'] = 4
|
||||||
|
opt['use_flip'] = True
|
||||||
|
opt['use_rot'] = True
|
||||||
|
opt['color'] = 'RGB'
|
||||||
|
opt['data_type'] = 'lmdb' # img | lmdb
|
||||||
|
else:
|
||||||
|
raise ValueError('Please implement by yourself.')
|
||||||
|
|
||||||
|
util.mkdir('tmp')
|
||||||
|
train_set = create_dataset(opt)
|
||||||
|
train_loader = create_dataloader(train_set, opt, opt, None)
|
||||||
|
nrow = int(math.sqrt(opt['batch_size']))
|
||||||
|
padding = 2 if opt['phase'] == 'train' else 0
|
||||||
|
|
||||||
|
print('start...')
|
||||||
|
for i, data in enumerate(train_loader):
|
||||||
|
if i > 5:
|
||||||
|
break
|
||||||
|
print(i)
|
||||||
|
if dataset == 'REDS' or dataset == 'Vimeo90K':
|
||||||
|
LQs = data['LQs']
|
||||||
|
else:
|
||||||
|
LQ = data['LQ']
|
||||||
|
GT = data['GT']
|
||||||
|
|
||||||
|
if dataset == 'REDS' or dataset == 'Vimeo90K':
|
||||||
|
for j in range(LQs.size(1)):
|
||||||
|
torchvision.utils.save_image(LQs[:, j, :, :, :],
|
||||||
|
'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow,
|
||||||
|
padding=padding, normalize=False)
|
||||||
|
else:
|
||||||
|
torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow,
|
||||||
|
padding=padding, normalize=False)
|
||||||
|
torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding,
|
||||||
|
normalize=False)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
261
codes/metrics/calculate_PSNR_SSIM.m
Normal file
261
codes/metrics/calculate_PSNR_SSIM.m
Normal file
|
@ -0,0 +1,261 @@
|
||||||
|
function calculate_PSNR_SSIM()
|
||||||
|
|
||||||
|
% GT and SR folder
|
||||||
|
folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5';
|
||||||
|
folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5';
|
||||||
|
scale = 4;
|
||||||
|
suffix = ''; % suffix for SR images
|
||||||
|
test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels
|
||||||
|
if test_Y
|
||||||
|
fprintf('Tesing Y channel.\n');
|
||||||
|
else
|
||||||
|
fprintf('Tesing RGB channels.\n');
|
||||||
|
end
|
||||||
|
filepaths = dir(fullfile(folder_GT, '*.png'));
|
||||||
|
PSNR_all = zeros(1, length(filepaths));
|
||||||
|
SSIM_all = zeros(1, length(filepaths));
|
||||||
|
|
||||||
|
for idx_im = 1:length(filepaths)
|
||||||
|
im_name = filepaths(idx_im).name;
|
||||||
|
im_GT = imread(fullfile(folder_GT, im_name));
|
||||||
|
im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png']));
|
||||||
|
|
||||||
|
if test_Y % evaluate on Y channel in YCbCr color space
|
||||||
|
if size(im_GT, 3) == 3
|
||||||
|
im_GT_YCbCr = rgb2ycbcr(im2double(im_GT));
|
||||||
|
im_GT_in = im_GT_YCbCr(:,:,1);
|
||||||
|
im_SR_YCbCr = rgb2ycbcr(im2double(im_SR));
|
||||||
|
im_SR_in = im_SR_YCbCr(:,:,1);
|
||||||
|
else
|
||||||
|
im_GT_in = im2double(im_GT);
|
||||||
|
im_SR_in = im2double(im_SR);
|
||||||
|
end
|
||||||
|
else % evaluate on RGB channels
|
||||||
|
im_GT_in = im2double(im_GT);
|
||||||
|
im_SR_in = im2double(im_SR);
|
||||||
|
end
|
||||||
|
|
||||||
|
% calculate PSNR and SSIM
|
||||||
|
PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale);
|
||||||
|
SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale);
|
||||||
|
fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im));
|
||||||
|
end
|
||||||
|
|
||||||
|
fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all));
|
||||||
|
end
|
||||||
|
|
||||||
|
function res = calculate_PSNR(GT, SR, border)
|
||||||
|
% remove border
|
||||||
|
GT = GT(border+1:end-border, border+1:end-border, :);
|
||||||
|
SR = SR(border+1:end-border, border+1:end-border, :);
|
||||||
|
% calculate PNSR (assume in [0,255])
|
||||||
|
error = GT(:) - SR(:);
|
||||||
|
mse = mean(error.^2);
|
||||||
|
res = 10 * log10(255^2/mse);
|
||||||
|
end
|
||||||
|
|
||||||
|
function res = calculate_SSIM(GT, SR, border)
|
||||||
|
GT = GT(border+1:end-border, border+1:end-border, :);
|
||||||
|
SR = SR(border+1:end-border, border+1:end-border, :);
|
||||||
|
% calculate SSIM
|
||||||
|
mssim = zeros(1, size(SR, 3));
|
||||||
|
for i = 1:size(SR,3)
|
||||||
|
[mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i));
|
||||||
|
end
|
||||||
|
res = mean(mssim);
|
||||||
|
end
|
||||||
|
|
||||||
|
function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L)
|
||||||
|
|
||||||
|
%========================================================================
|
||||||
|
%SSIM Index, Version 1.0
|
||||||
|
%Copyright(c) 2003 Zhou Wang
|
||||||
|
%All Rights Reserved.
|
||||||
|
%
|
||||||
|
%The author is with Howard Hughes Medical Institute, and Laboratory
|
||||||
|
%for Computational Vision at Center for Neural Science and Courant
|
||||||
|
%Institute of Mathematical Sciences, New York University.
|
||||||
|
%
|
||||||
|
%----------------------------------------------------------------------
|
||||||
|
%Permission to use, copy, or modify this software and its documentation
|
||||||
|
%for educational and research purposes only and without fee is hereby
|
||||||
|
%granted, provided that this copyright notice and the original authors'
|
||||||
|
%names appear on all copies and supporting documentation. This program
|
||||||
|
%shall not be used, rewritten, or adapted as the basis of a commercial
|
||||||
|
%software or hardware product without first obtaining permission of the
|
||||||
|
%authors. The authors make no representations about the suitability of
|
||||||
|
%this software for any purpose. It is provided "as is" without express
|
||||||
|
%or implied warranty.
|
||||||
|
%----------------------------------------------------------------------
|
||||||
|
%
|
||||||
|
%This is an implementation of the algorithm for calculating the
|
||||||
|
%Structural SIMilarity (SSIM) index between two images. Please refer
|
||||||
|
%to the following paper:
|
||||||
|
%
|
||||||
|
%Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
|
||||||
|
%quality assessment: From error measurement to structural similarity"
|
||||||
|
%IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004.
|
||||||
|
%
|
||||||
|
%Kindly report any suggestions or corrections to zhouwang@ieee.org
|
||||||
|
%
|
||||||
|
%----------------------------------------------------------------------
|
||||||
|
%
|
||||||
|
%Input : (1) img1: the first image being compared
|
||||||
|
% (2) img2: the second image being compared
|
||||||
|
% (3) K: constants in the SSIM index formula (see the above
|
||||||
|
% reference). defualt value: K = [0.01 0.03]
|
||||||
|
% (4) window: local window for statistics (see the above
|
||||||
|
% reference). default widnow is Gaussian given by
|
||||||
|
% window = fspecial('gaussian', 11, 1.5);
|
||||||
|
% (5) L: dynamic range of the images. default: L = 255
|
||||||
|
%
|
||||||
|
%Output: (1) mssim: the mean SSIM index value between 2 images.
|
||||||
|
% If one of the images being compared is regarded as
|
||||||
|
% perfect quality, then mssim can be considered as the
|
||||||
|
% quality measure of the other image.
|
||||||
|
% If img1 = img2, then mssim = 1.
|
||||||
|
% (2) ssim_map: the SSIM index map of the test image. The map
|
||||||
|
% has a smaller size than the input images. The actual size:
|
||||||
|
% size(img1) - size(window) + 1.
|
||||||
|
%
|
||||||
|
%Default Usage:
|
||||||
|
% Given 2 test images img1 and img2, whose dynamic range is 0-255
|
||||||
|
%
|
||||||
|
% [mssim ssim_map] = ssim_index(img1, img2);
|
||||||
|
%
|
||||||
|
%Advanced Usage:
|
||||||
|
% User defined parameters. For example
|
||||||
|
%
|
||||||
|
% K = [0.05 0.05];
|
||||||
|
% window = ones(8);
|
||||||
|
% L = 100;
|
||||||
|
% [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
|
||||||
|
%
|
||||||
|
%See the results:
|
||||||
|
%
|
||||||
|
% mssim %Gives the mssim value
|
||||||
|
% imshow(max(0, ssim_map).^4) %Shows the SSIM index map
|
||||||
|
%
|
||||||
|
%========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
if (nargin < 2 || nargin > 5)
|
||||||
|
ssim_index = -Inf;
|
||||||
|
ssim_map = -Inf;
|
||||||
|
return;
|
||||||
|
end
|
||||||
|
|
||||||
|
if (size(img1) ~= size(img2))
|
||||||
|
ssim_index = -Inf;
|
||||||
|
ssim_map = -Inf;
|
||||||
|
return;
|
||||||
|
end
|
||||||
|
|
||||||
|
[M, N] = size(img1);
|
||||||
|
|
||||||
|
if (nargin == 2)
|
||||||
|
if ((M < 11) || (N < 11))
|
||||||
|
ssim_index = -Inf;
|
||||||
|
ssim_map = -Inf;
|
||||||
|
return
|
||||||
|
end
|
||||||
|
window = fspecial('gaussian', 11, 1.5); %
|
||||||
|
K(1) = 0.01; % default settings
|
||||||
|
K(2) = 0.03; %
|
||||||
|
L = 255; %
|
||||||
|
end
|
||||||
|
|
||||||
|
if (nargin == 3)
|
||||||
|
if ((M < 11) || (N < 11))
|
||||||
|
ssim_index = -Inf;
|
||||||
|
ssim_map = -Inf;
|
||||||
|
return
|
||||||
|
end
|
||||||
|
window = fspecial('gaussian', 11, 1.5);
|
||||||
|
L = 255;
|
||||||
|
if (length(K) == 2)
|
||||||
|
if (K(1) < 0 || K(2) < 0)
|
||||||
|
ssim_index = -Inf;
|
||||||
|
ssim_map = -Inf;
|
||||||
|
return;
|
||||||
|
end
|
||||||
|
else
|
||||||
|
ssim_index = -Inf;
|
||||||
|
ssim_map = -Inf;
|
||||||
|
return;
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
if (nargin == 4)
|
||||||
|
[H, W] = size(window);
|
||||||
|
if ((H*W) < 4 || (H > M) || (W > N))
|
||||||
|
ssim_index = -Inf;
|
||||||
|
ssim_map = -Inf;
|
||||||
|
return
|
||||||
|
end
|
||||||
|
L = 255;
|
||||||
|
if (length(K) == 2)
|
||||||
|
if (K(1) < 0 || K(2) < 0)
|
||||||
|
ssim_index = -Inf;
|
||||||
|
ssim_map = -Inf;
|
||||||
|
return;
|
||||||
|
end
|
||||||
|
else
|
||||||
|
ssim_index = -Inf;
|
||||||
|
ssim_map = -Inf;
|
||||||
|
return;
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
if (nargin == 5)
|
||||||
|
[H, W] = size(window);
|
||||||
|
if ((H*W) < 4 || (H > M) || (W > N))
|
||||||
|
ssim_index = -Inf;
|
||||||
|
ssim_map = -Inf;
|
||||||
|
return
|
||||||
|
end
|
||||||
|
if (length(K) == 2)
|
||||||
|
if (K(1) < 0 || K(2) < 0)
|
||||||
|
ssim_index = -Inf;
|
||||||
|
ssim_map = -Inf;
|
||||||
|
return;
|
||||||
|
end
|
||||||
|
else
|
||||||
|
ssim_index = -Inf;
|
||||||
|
ssim_map = -Inf;
|
||||||
|
return;
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
C1 = (K(1)*L)^2;
|
||||||
|
C2 = (K(2)*L)^2;
|
||||||
|
window = window/sum(sum(window));
|
||||||
|
img1 = double(img1);
|
||||||
|
img2 = double(img2);
|
||||||
|
|
||||||
|
mu1 = filter2(window, img1, 'valid');
|
||||||
|
mu2 = filter2(window, img2, 'valid');
|
||||||
|
mu1_sq = mu1.*mu1;
|
||||||
|
mu2_sq = mu2.*mu2;
|
||||||
|
mu1_mu2 = mu1.*mu2;
|
||||||
|
sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
|
||||||
|
sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
|
||||||
|
sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
|
||||||
|
|
||||||
|
if (C1 > 0 && C2 > 0)
|
||||||
|
ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
|
||||||
|
else
|
||||||
|
numerator1 = 2*mu1_mu2 + C1;
|
||||||
|
numerator2 = 2*sigma12 + C2;
|
||||||
|
denominator1 = mu1_sq + mu2_sq + C1;
|
||||||
|
denominator2 = sigma1_sq + sigma2_sq + C2;
|
||||||
|
ssim_map = ones(size(mu1));
|
||||||
|
index = (denominator1.*denominator2 > 0);
|
||||||
|
ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
|
||||||
|
index = (denominator1 ~= 0) & (denominator2 == 0);
|
||||||
|
ssim_map(index) = numerator1(index)./denominator1(index);
|
||||||
|
end
|
||||||
|
|
||||||
|
mssim = mean2(ssim_map);
|
||||||
|
|
||||||
|
end
|
147
codes/metrics/calculate_PSNR_SSIM.py
Normal file
147
codes/metrics/calculate_PSNR_SSIM.py
Normal file
|
@ -0,0 +1,147 @@
|
||||||
|
'''
|
||||||
|
calculate the PSNR and SSIM.
|
||||||
|
same as MATLAB's results
|
||||||
|
'''
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import glob
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Configurations
|
||||||
|
|
||||||
|
# GT - Ground-truth;
|
||||||
|
# Gen: Generated / Restored / Recovered images
|
||||||
|
folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'
|
||||||
|
folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'
|
||||||
|
|
||||||
|
crop_border = 4
|
||||||
|
suffix = '' # suffix for Gen images
|
||||||
|
test_Y = False # True: test Y channel only; False: test RGB channels
|
||||||
|
|
||||||
|
PSNR_all = []
|
||||||
|
SSIM_all = []
|
||||||
|
img_list = sorted(glob.glob(folder_GT + '/*'))
|
||||||
|
|
||||||
|
if test_Y:
|
||||||
|
print('Testing Y channel.')
|
||||||
|
else:
|
||||||
|
print('Testing RGB channels.')
|
||||||
|
|
||||||
|
for i, img_path in enumerate(img_list):
|
||||||
|
base_name = os.path.splitext(os.path.basename(img_path))[0]
|
||||||
|
im_GT = cv2.imread(img_path) / 255.
|
||||||
|
im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255.
|
||||||
|
|
||||||
|
if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space
|
||||||
|
im_GT_in = bgr2ycbcr(im_GT)
|
||||||
|
im_Gen_in = bgr2ycbcr(im_Gen)
|
||||||
|
else:
|
||||||
|
im_GT_in = im_GT
|
||||||
|
im_Gen_in = im_Gen
|
||||||
|
|
||||||
|
# crop borders
|
||||||
|
if im_GT_in.ndim == 3:
|
||||||
|
cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :]
|
||||||
|
cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :]
|
||||||
|
elif im_GT_in.ndim == 2:
|
||||||
|
cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border]
|
||||||
|
cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border]
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim))
|
||||||
|
|
||||||
|
# calculate PSNR and SSIM
|
||||||
|
PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255)
|
||||||
|
|
||||||
|
SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255)
|
||||||
|
print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format(
|
||||||
|
i + 1, base_name, PSNR, SSIM))
|
||||||
|
PSNR_all.append(PSNR)
|
||||||
|
SSIM_all.append(SSIM)
|
||||||
|
print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format(
|
||||||
|
sum(PSNR_all) / len(PSNR_all),
|
||||||
|
sum(SSIM_all) / len(SSIM_all)))
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_psnr(img1, img2):
|
||||||
|
# img1 and img2 have range [0, 255]
|
||||||
|
img1 = img1.astype(np.float64)
|
||||||
|
img2 = img2.astype(np.float64)
|
||||||
|
mse = np.mean((img1 - img2)**2)
|
||||||
|
if mse == 0:
|
||||||
|
return float('inf')
|
||||||
|
return 20 * math.log10(255.0 / math.sqrt(mse))
|
||||||
|
|
||||||
|
|
||||||
|
def ssim(img1, img2):
|
||||||
|
C1 = (0.01 * 255)**2
|
||||||
|
C2 = (0.03 * 255)**2
|
||||||
|
|
||||||
|
img1 = img1.astype(np.float64)
|
||||||
|
img2 = img2.astype(np.float64)
|
||||||
|
kernel = cv2.getGaussianKernel(11, 1.5)
|
||||||
|
window = np.outer(kernel, kernel.transpose())
|
||||||
|
|
||||||
|
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
||||||
|
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||||
|
mu1_sq = mu1**2
|
||||||
|
mu2_sq = mu2**2
|
||||||
|
mu1_mu2 = mu1 * mu2
|
||||||
|
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||||
|
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||||
|
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||||
|
|
||||||
|
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
||||||
|
(sigma1_sq + sigma2_sq + C2))
|
||||||
|
return ssim_map.mean()
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_ssim(img1, img2):
|
||||||
|
'''calculate SSIM
|
||||||
|
the same outputs as MATLAB's
|
||||||
|
img1, img2: [0, 255]
|
||||||
|
'''
|
||||||
|
if not img1.shape == img2.shape:
|
||||||
|
raise ValueError('Input images must have the same dimensions.')
|
||||||
|
if img1.ndim == 2:
|
||||||
|
return ssim(img1, img2)
|
||||||
|
elif img1.ndim == 3:
|
||||||
|
if img1.shape[2] == 3:
|
||||||
|
ssims = []
|
||||||
|
for i in range(3):
|
||||||
|
ssims.append(ssim(img1, img2))
|
||||||
|
return np.array(ssims).mean()
|
||||||
|
elif img1.shape[2] == 1:
|
||||||
|
return ssim(np.squeeze(img1), np.squeeze(img2))
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong input image dimensions.')
|
||||||
|
|
||||||
|
|
||||||
|
def bgr2ycbcr(img, only_y=True):
|
||||||
|
'''same as matlab rgb2ycbcr
|
||||||
|
only_y: only return Y channel
|
||||||
|
Input:
|
||||||
|
uint8, [0, 255]
|
||||||
|
float, [0, 1]
|
||||||
|
'''
|
||||||
|
in_img_type = img.dtype
|
||||||
|
img.astype(np.float32)
|
||||||
|
if in_img_type != np.uint8:
|
||||||
|
img *= 255.
|
||||||
|
# convert
|
||||||
|
if only_y:
|
||||||
|
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
||||||
|
else:
|
||||||
|
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
|
||||||
|
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
|
||||||
|
if in_img_type == np.uint8:
|
||||||
|
rlt = rlt.round()
|
||||||
|
else:
|
||||||
|
rlt /= 255.
|
||||||
|
return rlt.astype(in_img_type)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
267
codes/models/SRGAN_model.py
Normal file
267
codes/models/SRGAN_model.py
Normal file
|
@ -0,0 +1,267 @@
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||||
|
import models.networks as networks
|
||||||
|
import models.lr_scheduler as lr_scheduler
|
||||||
|
from .base_model import BaseModel
|
||||||
|
from models.loss import GANLoss
|
||||||
|
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
|
||||||
|
class SRGANModel(BaseModel):
|
||||||
|
def __init__(self, opt):
|
||||||
|
super(SRGANModel, self).__init__(opt)
|
||||||
|
if opt['dist']:
|
||||||
|
self.rank = torch.distributed.get_rank()
|
||||||
|
else:
|
||||||
|
self.rank = -1 # non dist training
|
||||||
|
train_opt = opt['train']
|
||||||
|
|
||||||
|
# define networks and load pretrained models
|
||||||
|
self.netG = networks.define_G(opt).to(self.device)
|
||||||
|
if opt['dist']:
|
||||||
|
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
||||||
|
else:
|
||||||
|
self.netG = DataParallel(self.netG)
|
||||||
|
if self.is_train:
|
||||||
|
self.netD = networks.define_D(opt).to(self.device)
|
||||||
|
if opt['dist']:
|
||||||
|
self.netD = DistributedDataParallel(self.netD,
|
||||||
|
device_ids=[torch.cuda.current_device()])
|
||||||
|
else:
|
||||||
|
self.netD = DataParallel(self.netD)
|
||||||
|
|
||||||
|
self.netG.train()
|
||||||
|
self.netD.train()
|
||||||
|
|
||||||
|
# define losses, optimizer and scheduler
|
||||||
|
if self.is_train:
|
||||||
|
# G pixel loss
|
||||||
|
if train_opt['pixel_weight'] > 0:
|
||||||
|
l_pix_type = train_opt['pixel_criterion']
|
||||||
|
if l_pix_type == 'l1':
|
||||||
|
self.cri_pix = nn.L1Loss().to(self.device)
|
||||||
|
elif l_pix_type == 'l2':
|
||||||
|
self.cri_pix = nn.MSELoss().to(self.device)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
|
||||||
|
self.l_pix_w = train_opt['pixel_weight']
|
||||||
|
else:
|
||||||
|
logger.info('Remove pixel loss.')
|
||||||
|
self.cri_pix = None
|
||||||
|
|
||||||
|
# G feature loss
|
||||||
|
if train_opt['feature_weight'] > 0:
|
||||||
|
l_fea_type = train_opt['feature_criterion']
|
||||||
|
if l_fea_type == 'l1':
|
||||||
|
self.cri_fea = nn.L1Loss().to(self.device)
|
||||||
|
elif l_fea_type == 'l2':
|
||||||
|
self.cri_fea = nn.MSELoss().to(self.device)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
|
||||||
|
self.l_fea_w = train_opt['feature_weight']
|
||||||
|
else:
|
||||||
|
logger.info('Remove feature loss.')
|
||||||
|
self.cri_fea = None
|
||||||
|
if self.cri_fea: # load VGG perceptual loss
|
||||||
|
self.netF = networks.define_F(opt, use_bn=False).to(self.device)
|
||||||
|
if opt['dist']:
|
||||||
|
self.netF = DistributedDataParallel(self.netF,
|
||||||
|
device_ids=[torch.cuda.current_device()])
|
||||||
|
else:
|
||||||
|
self.netF = DataParallel(self.netF)
|
||||||
|
|
||||||
|
# GD gan loss
|
||||||
|
self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
|
||||||
|
self.l_gan_w = train_opt['gan_weight']
|
||||||
|
# D_update_ratio and D_init_iters
|
||||||
|
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
|
||||||
|
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
|
||||||
|
|
||||||
|
# optimizers
|
||||||
|
# G
|
||||||
|
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
|
||||||
|
optim_params = []
|
||||||
|
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
|
||||||
|
if v.requires_grad:
|
||||||
|
optim_params.append(v)
|
||||||
|
else:
|
||||||
|
if self.rank <= 0:
|
||||||
|
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||||
|
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
|
||||||
|
weight_decay=wd_G,
|
||||||
|
betas=(train_opt['beta1_G'], train_opt['beta2_G']))
|
||||||
|
self.optimizers.append(self.optimizer_G)
|
||||||
|
# D
|
||||||
|
wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
|
||||||
|
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
|
||||||
|
weight_decay=wd_D,
|
||||||
|
betas=(train_opt['beta1_D'], train_opt['beta2_D']))
|
||||||
|
self.optimizers.append(self.optimizer_D)
|
||||||
|
|
||||||
|
# schedulers
|
||||||
|
if train_opt['lr_scheme'] == 'MultiStepLR':
|
||||||
|
for optimizer in self.optimizers:
|
||||||
|
self.schedulers.append(
|
||||||
|
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
|
||||||
|
restarts=train_opt['restarts'],
|
||||||
|
weights=train_opt['restart_weights'],
|
||||||
|
gamma=train_opt['lr_gamma'],
|
||||||
|
clear_state=train_opt['clear_state']))
|
||||||
|
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
|
||||||
|
for optimizer in self.optimizers:
|
||||||
|
self.schedulers.append(
|
||||||
|
lr_scheduler.CosineAnnealingLR_Restart(
|
||||||
|
optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
|
||||||
|
restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
|
||||||
|
|
||||||
|
self.log_dict = OrderedDict()
|
||||||
|
|
||||||
|
self.print_network() # print network
|
||||||
|
self.load() # load G and D if needed
|
||||||
|
|
||||||
|
def feed_data(self, data, need_GT=True):
|
||||||
|
self.var_L = data['LQ'].to(self.device) # LQ
|
||||||
|
if need_GT:
|
||||||
|
self.var_H = data['GT'].to(self.device) # GT
|
||||||
|
input_ref = data['ref'] if 'ref' in data else data['GT']
|
||||||
|
self.var_ref = input_ref.to(self.device)
|
||||||
|
|
||||||
|
def optimize_parameters(self, step):
|
||||||
|
# G
|
||||||
|
for p in self.netD.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
|
||||||
|
self.optimizer_G.zero_grad()
|
||||||
|
self.fake_H = self.netG(self.var_L)
|
||||||
|
|
||||||
|
l_g_total = 0
|
||||||
|
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||||
|
if self.cri_pix: # pixel loss
|
||||||
|
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
|
||||||
|
l_g_total += l_g_pix
|
||||||
|
if self.cri_fea: # feature loss
|
||||||
|
real_fea = self.netF(self.var_H).detach()
|
||||||
|
fake_fea = self.netF(self.fake_H)
|
||||||
|
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
|
||||||
|
l_g_total += l_g_fea
|
||||||
|
|
||||||
|
pred_g_fake = self.netD(self.fake_H)
|
||||||
|
if self.opt['train']['gan_type'] == 'gan':
|
||||||
|
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
||||||
|
elif self.opt['train']['gan_type'] == 'ragan':
|
||||||
|
pred_d_real = self.netD(self.var_ref).detach()
|
||||||
|
l_g_gan = self.l_gan_w * (
|
||||||
|
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||||
|
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||||
|
l_g_total += l_g_gan
|
||||||
|
|
||||||
|
l_g_total.backward()
|
||||||
|
self.optimizer_G.step()
|
||||||
|
|
||||||
|
# D
|
||||||
|
for p in self.netD.parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
|
||||||
|
self.optimizer_D.zero_grad()
|
||||||
|
l_d_total = 0
|
||||||
|
pred_d_real = self.netD(self.var_ref)
|
||||||
|
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
|
||||||
|
if self.opt['train']['gan_type'] == 'gan':
|
||||||
|
l_d_real = self.cri_gan(pred_d_real, True)
|
||||||
|
l_d_fake = self.cri_gan(pred_d_fake, False)
|
||||||
|
l_d_total = l_d_real + l_d_fake
|
||||||
|
elif self.opt['train']['gan_type'] == 'ragan':
|
||||||
|
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
|
||||||
|
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
|
||||||
|
l_d_total = (l_d_real + l_d_fake) / 2
|
||||||
|
|
||||||
|
l_d_total.backward()
|
||||||
|
self.optimizer_D.step()
|
||||||
|
|
||||||
|
# set log
|
||||||
|
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||||
|
if self.cri_pix:
|
||||||
|
self.log_dict['l_g_pix'] = l_g_pix.item()
|
||||||
|
if self.cri_fea:
|
||||||
|
self.log_dict['l_g_fea'] = l_g_fea.item()
|
||||||
|
self.log_dict['l_g_gan'] = l_g_gan.item()
|
||||||
|
|
||||||
|
self.log_dict['l_d_real'] = l_d_real.item()
|
||||||
|
self.log_dict['l_d_fake'] = l_d_fake.item()
|
||||||
|
self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
|
||||||
|
self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
|
||||||
|
|
||||||
|
def test(self):
|
||||||
|
self.netG.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
self.fake_H = self.netG(self.var_L)
|
||||||
|
self.netG.train()
|
||||||
|
|
||||||
|
def get_current_log(self):
|
||||||
|
return self.log_dict
|
||||||
|
|
||||||
|
def get_current_visuals(self, need_GT=True):
|
||||||
|
out_dict = OrderedDict()
|
||||||
|
out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
|
||||||
|
out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
|
||||||
|
if need_GT:
|
||||||
|
out_dict['GT'] = self.var_H.detach()[0].float().cpu()
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
def print_network(self):
|
||||||
|
# Generator
|
||||||
|
s, n = self.get_network_description(self.netG)
|
||||||
|
if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
|
||||||
|
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
|
||||||
|
self.netG.module.__class__.__name__)
|
||||||
|
else:
|
||||||
|
net_struc_str = '{}'.format(self.netG.__class__.__name__)
|
||||||
|
if self.rank <= 0:
|
||||||
|
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
|
||||||
|
logger.info(s)
|
||||||
|
if self.is_train:
|
||||||
|
# Discriminator
|
||||||
|
s, n = self.get_network_description(self.netD)
|
||||||
|
if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
|
||||||
|
DistributedDataParallel):
|
||||||
|
net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
|
||||||
|
self.netD.module.__class__.__name__)
|
||||||
|
else:
|
||||||
|
net_struc_str = '{}'.format(self.netD.__class__.__name__)
|
||||||
|
if self.rank <= 0:
|
||||||
|
logger.info('Network D structure: {}, with parameters: {:,d}'.format(
|
||||||
|
net_struc_str, n))
|
||||||
|
logger.info(s)
|
||||||
|
|
||||||
|
if self.cri_fea: # F, Perceptual Network
|
||||||
|
s, n = self.get_network_description(self.netF)
|
||||||
|
if isinstance(self.netF, nn.DataParallel) or isinstance(
|
||||||
|
self.netF, DistributedDataParallel):
|
||||||
|
net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
|
||||||
|
self.netF.module.__class__.__name__)
|
||||||
|
else:
|
||||||
|
net_struc_str = '{}'.format(self.netF.__class__.__name__)
|
||||||
|
if self.rank <= 0:
|
||||||
|
logger.info('Network F structure: {}, with parameters: {:,d}'.format(
|
||||||
|
net_struc_str, n))
|
||||||
|
logger.info(s)
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
load_path_G = self.opt['path']['pretrain_model_G']
|
||||||
|
if load_path_G is not None:
|
||||||
|
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
|
||||||
|
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
|
||||||
|
load_path_D = self.opt['path']['pretrain_model_D']
|
||||||
|
if self.opt['is_train'] and load_path_D is not None:
|
||||||
|
logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
|
||||||
|
self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])
|
||||||
|
|
||||||
|
def save(self, iter_step):
|
||||||
|
self.save_network(self.netG, 'G', iter_step)
|
||||||
|
self.save_network(self.netD, 'D', iter_step)
|
170
codes/models/SR_model.py
Normal file
170
codes/models/SR_model.py
Normal file
|
@ -0,0 +1,170 @@
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||||
|
import models.networks as networks
|
||||||
|
import models.lr_scheduler as lr_scheduler
|
||||||
|
from .base_model import BaseModel
|
||||||
|
from models.loss import CharbonnierLoss
|
||||||
|
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
|
||||||
|
class SRModel(BaseModel):
|
||||||
|
def __init__(self, opt):
|
||||||
|
super(SRModel, self).__init__(opt)
|
||||||
|
|
||||||
|
if opt['dist']:
|
||||||
|
self.rank = torch.distributed.get_rank()
|
||||||
|
else:
|
||||||
|
self.rank = -1 # non dist training
|
||||||
|
train_opt = opt['train']
|
||||||
|
|
||||||
|
# define network and load pretrained models
|
||||||
|
self.netG = networks.define_G(opt).to(self.device)
|
||||||
|
if opt['dist']:
|
||||||
|
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
||||||
|
else:
|
||||||
|
self.netG = DataParallel(self.netG)
|
||||||
|
# print network
|
||||||
|
self.print_network()
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
if self.is_train:
|
||||||
|
self.netG.train()
|
||||||
|
|
||||||
|
# loss
|
||||||
|
loss_type = train_opt['pixel_criterion']
|
||||||
|
if loss_type == 'l1':
|
||||||
|
self.cri_pix = nn.L1Loss().to(self.device)
|
||||||
|
elif loss_type == 'l2':
|
||||||
|
self.cri_pix = nn.MSELoss().to(self.device)
|
||||||
|
elif loss_type == 'cb':
|
||||||
|
self.cri_pix = CharbonnierLoss().to(self.device)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
|
||||||
|
self.l_pix_w = train_opt['pixel_weight']
|
||||||
|
|
||||||
|
# optimizers
|
||||||
|
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
|
||||||
|
optim_params = []
|
||||||
|
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
|
||||||
|
if v.requires_grad:
|
||||||
|
optim_params.append(v)
|
||||||
|
else:
|
||||||
|
if self.rank <= 0:
|
||||||
|
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||||
|
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
|
||||||
|
weight_decay=wd_G,
|
||||||
|
betas=(train_opt['beta1'], train_opt['beta2']))
|
||||||
|
self.optimizers.append(self.optimizer_G)
|
||||||
|
|
||||||
|
# schedulers
|
||||||
|
if train_opt['lr_scheme'] == 'MultiStepLR':
|
||||||
|
for optimizer in self.optimizers:
|
||||||
|
self.schedulers.append(
|
||||||
|
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
|
||||||
|
restarts=train_opt['restarts'],
|
||||||
|
weights=train_opt['restart_weights'],
|
||||||
|
gamma=train_opt['lr_gamma'],
|
||||||
|
clear_state=train_opt['clear_state']))
|
||||||
|
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
|
||||||
|
for optimizer in self.optimizers:
|
||||||
|
self.schedulers.append(
|
||||||
|
lr_scheduler.CosineAnnealingLR_Restart(
|
||||||
|
optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
|
||||||
|
restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
|
||||||
|
|
||||||
|
self.log_dict = OrderedDict()
|
||||||
|
|
||||||
|
def feed_data(self, data, need_GT=True):
|
||||||
|
self.var_L = data['LQ'].to(self.device) # LQ
|
||||||
|
if need_GT:
|
||||||
|
self.real_H = data['GT'].to(self.device) # GT
|
||||||
|
|
||||||
|
def optimize_parameters(self, step):
|
||||||
|
self.optimizer_G.zero_grad()
|
||||||
|
self.fake_H = self.netG(self.var_L)
|
||||||
|
l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
|
||||||
|
l_pix.backward()
|
||||||
|
self.optimizer_G.step()
|
||||||
|
|
||||||
|
# set log
|
||||||
|
self.log_dict['l_pix'] = l_pix.item()
|
||||||
|
|
||||||
|
def test(self):
|
||||||
|
self.netG.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
self.fake_H = self.netG(self.var_L)
|
||||||
|
self.netG.train()
|
||||||
|
|
||||||
|
def test_x8(self):
|
||||||
|
# from https://github.com/thstkdgus35/EDSR-PyTorch
|
||||||
|
self.netG.eval()
|
||||||
|
|
||||||
|
def _transform(v, op):
|
||||||
|
# if self.precision != 'single': v = v.float()
|
||||||
|
v2np = v.data.cpu().numpy()
|
||||||
|
if op == 'v':
|
||||||
|
tfnp = v2np[:, :, :, ::-1].copy()
|
||||||
|
elif op == 'h':
|
||||||
|
tfnp = v2np[:, :, ::-1, :].copy()
|
||||||
|
elif op == 't':
|
||||||
|
tfnp = v2np.transpose((0, 1, 3, 2)).copy()
|
||||||
|
|
||||||
|
ret = torch.Tensor(tfnp).to(self.device)
|
||||||
|
# if self.precision == 'half': ret = ret.half()
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
lr_list = [self.var_L]
|
||||||
|
for tf in 'v', 'h', 't':
|
||||||
|
lr_list.extend([_transform(t, tf) for t in lr_list])
|
||||||
|
with torch.no_grad():
|
||||||
|
sr_list = [self.netG(aug) for aug in lr_list]
|
||||||
|
for i in range(len(sr_list)):
|
||||||
|
if i > 3:
|
||||||
|
sr_list[i] = _transform(sr_list[i], 't')
|
||||||
|
if i % 4 > 1:
|
||||||
|
sr_list[i] = _transform(sr_list[i], 'h')
|
||||||
|
if (i % 4) % 2 == 1:
|
||||||
|
sr_list[i] = _transform(sr_list[i], 'v')
|
||||||
|
|
||||||
|
output_cat = torch.cat(sr_list, dim=0)
|
||||||
|
self.fake_H = output_cat.mean(dim=0, keepdim=True)
|
||||||
|
self.netG.train()
|
||||||
|
|
||||||
|
def get_current_log(self):
|
||||||
|
return self.log_dict
|
||||||
|
|
||||||
|
def get_current_visuals(self, need_GT=True):
|
||||||
|
out_dict = OrderedDict()
|
||||||
|
out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
|
||||||
|
out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
|
||||||
|
if need_GT:
|
||||||
|
out_dict['GT'] = self.real_H.detach()[0].float().cpu()
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
def print_network(self):
|
||||||
|
s, n = self.get_network_description(self.netG)
|
||||||
|
if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
|
||||||
|
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
|
||||||
|
self.netG.module.__class__.__name__)
|
||||||
|
else:
|
||||||
|
net_struc_str = '{}'.format(self.netG.__class__.__name__)
|
||||||
|
if self.rank <= 0:
|
||||||
|
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
|
||||||
|
logger.info(s)
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
load_path_G = self.opt['path']['pretrain_model_G']
|
||||||
|
if load_path_G is not None:
|
||||||
|
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
|
||||||
|
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
|
||||||
|
|
||||||
|
def save(self, iter_label):
|
||||||
|
self.save_network(self.netG, 'G', iter_label)
|
166
codes/models/Video_base_model.py
Normal file
166
codes/models/Video_base_model.py
Normal file
|
@ -0,0 +1,166 @@
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||||
|
import models.networks as networks
|
||||||
|
import models.lr_scheduler as lr_scheduler
|
||||||
|
from .base_model import BaseModel
|
||||||
|
from models.loss import CharbonnierLoss
|
||||||
|
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
|
||||||
|
class VideoBaseModel(BaseModel):
|
||||||
|
def __init__(self, opt):
|
||||||
|
super(VideoBaseModel, self).__init__(opt)
|
||||||
|
|
||||||
|
if opt['dist']:
|
||||||
|
self.rank = torch.distributed.get_rank()
|
||||||
|
else:
|
||||||
|
self.rank = -1 # non dist training
|
||||||
|
train_opt = opt['train']
|
||||||
|
|
||||||
|
# define network and load pretrained models
|
||||||
|
self.netG = networks.define_G(opt).to(self.device)
|
||||||
|
if opt['dist']:
|
||||||
|
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
||||||
|
else:
|
||||||
|
self.netG = DataParallel(self.netG)
|
||||||
|
# print network
|
||||||
|
self.print_network()
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
if self.is_train:
|
||||||
|
self.netG.train()
|
||||||
|
|
||||||
|
#### loss
|
||||||
|
loss_type = train_opt['pixel_criterion']
|
||||||
|
if loss_type == 'l1':
|
||||||
|
self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
|
||||||
|
elif loss_type == 'l2':
|
||||||
|
self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
|
||||||
|
elif loss_type == 'cb':
|
||||||
|
self.cri_pix = CharbonnierLoss().to(self.device)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
|
||||||
|
self.l_pix_w = train_opt['pixel_weight']
|
||||||
|
|
||||||
|
#### optimizers
|
||||||
|
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
|
||||||
|
if train_opt['ft_tsa_only']:
|
||||||
|
normal_params = []
|
||||||
|
tsa_fusion_params = []
|
||||||
|
for k, v in self.netG.named_parameters():
|
||||||
|
if v.requires_grad:
|
||||||
|
if 'tsa_fusion' in k:
|
||||||
|
tsa_fusion_params.append(v)
|
||||||
|
else:
|
||||||
|
normal_params.append(v)
|
||||||
|
else:
|
||||||
|
if self.rank <= 0:
|
||||||
|
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||||
|
optim_params = [
|
||||||
|
{ # add normal params first
|
||||||
|
'params': normal_params,
|
||||||
|
'lr': train_opt['lr_G']
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'params': tsa_fusion_params,
|
||||||
|
'lr': train_opt['lr_G']
|
||||||
|
},
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
optim_params = []
|
||||||
|
for k, v in self.netG.named_parameters():
|
||||||
|
if v.requires_grad:
|
||||||
|
optim_params.append(v)
|
||||||
|
else:
|
||||||
|
if self.rank <= 0:
|
||||||
|
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||||
|
|
||||||
|
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
|
||||||
|
weight_decay=wd_G,
|
||||||
|
betas=(train_opt['beta1'], train_opt['beta2']))
|
||||||
|
self.optimizers.append(self.optimizer_G)
|
||||||
|
|
||||||
|
#### schedulers
|
||||||
|
if train_opt['lr_scheme'] == 'MultiStepLR':
|
||||||
|
for optimizer in self.optimizers:
|
||||||
|
self.schedulers.append(
|
||||||
|
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
|
||||||
|
restarts=train_opt['restarts'],
|
||||||
|
weights=train_opt['restart_weights'],
|
||||||
|
gamma=train_opt['lr_gamma'],
|
||||||
|
clear_state=train_opt['clear_state']))
|
||||||
|
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
|
||||||
|
for optimizer in self.optimizers:
|
||||||
|
self.schedulers.append(
|
||||||
|
lr_scheduler.CosineAnnealingLR_Restart(
|
||||||
|
optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
|
||||||
|
restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
self.log_dict = OrderedDict()
|
||||||
|
|
||||||
|
def feed_data(self, data, need_GT=True):
|
||||||
|
self.var_L = data['LQs'].to(self.device)
|
||||||
|
if need_GT:
|
||||||
|
self.real_H = data['GT'].to(self.device)
|
||||||
|
|
||||||
|
def set_params_lr_zero(self):
|
||||||
|
# fix normal module
|
||||||
|
self.optimizers[0].param_groups[0]['lr'] = 0
|
||||||
|
|
||||||
|
def optimize_parameters(self, step):
|
||||||
|
if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
|
||||||
|
self.set_params_lr_zero()
|
||||||
|
|
||||||
|
self.optimizer_G.zero_grad()
|
||||||
|
self.fake_H = self.netG(self.var_L)
|
||||||
|
|
||||||
|
l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
|
||||||
|
l_pix.backward()
|
||||||
|
self.optimizer_G.step()
|
||||||
|
|
||||||
|
# set log
|
||||||
|
self.log_dict['l_pix'] = l_pix.item()
|
||||||
|
|
||||||
|
def test(self):
|
||||||
|
self.netG.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
self.fake_H = self.netG(self.var_L)
|
||||||
|
self.netG.train()
|
||||||
|
|
||||||
|
def get_current_log(self):
|
||||||
|
return self.log_dict
|
||||||
|
|
||||||
|
def get_current_visuals(self, need_GT=True):
|
||||||
|
out_dict = OrderedDict()
|
||||||
|
out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
|
||||||
|
out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
|
||||||
|
if need_GT:
|
||||||
|
out_dict['GT'] = self.real_H.detach()[0].float().cpu()
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
def print_network(self):
|
||||||
|
s, n = self.get_network_description(self.netG)
|
||||||
|
if isinstance(self.netG, nn.DataParallel):
|
||||||
|
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
|
||||||
|
self.netG.module.__class__.__name__)
|
||||||
|
else:
|
||||||
|
net_struc_str = '{}'.format(self.netG.__class__.__name__)
|
||||||
|
if self.rank <= 0:
|
||||||
|
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
|
||||||
|
logger.info(s)
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
load_path_G = self.opt['path']['pretrain_model_G']
|
||||||
|
if load_path_G is not None:
|
||||||
|
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
|
||||||
|
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
|
||||||
|
|
||||||
|
def save(self, iter_label):
|
||||||
|
self.save_network(self.netG, 'G', iter_label)
|
19
codes/models/__init__.py
Normal file
19
codes/models/__init__.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(opt):
|
||||||
|
model = opt['model']
|
||||||
|
# image restoration
|
||||||
|
if model == 'sr': # PSNR-oriented super resolution
|
||||||
|
from .SR_model import SRModel as M
|
||||||
|
elif model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN
|
||||||
|
from .SRGAN_model import SRGANModel as M
|
||||||
|
# video restoration
|
||||||
|
elif model == 'video_base':
|
||||||
|
from .Video_base_model import VideoBaseModel as M
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
|
||||||
|
m = M(opt)
|
||||||
|
logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
|
||||||
|
return m
|
368
codes/models/archs/DUF_arch.py
Normal file
368
codes/models/archs/DUF_arch.py
Normal file
|
@ -0,0 +1,368 @@
|
||||||
|
'''Network architecture for DUF:
|
||||||
|
Deep Video Super-Resolution Network Using Dynamic Upsampling Filters
|
||||||
|
Without Explicit Motion Compensation (CVPR18)
|
||||||
|
https://github.com/yhjo09/VSR-DUF
|
||||||
|
|
||||||
|
For all the models below, [adapt_official] is only necessary when
|
||||||
|
loading the weights converted from the official TensorFlow weights.
|
||||||
|
Please set it to [False] if you are training the model from scratch.
|
||||||
|
'''
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_official(Rx, scale=4):
|
||||||
|
'''Adapt the weights translated from the official tensorflow weights
|
||||||
|
Not necessary if you are training from scratch'''
|
||||||
|
x = Rx.clone()
|
||||||
|
x1 = x[:, ::3, :, :]
|
||||||
|
x2 = x[:, 1::3, :, :]
|
||||||
|
x3 = x[:, 2::3, :, :]
|
||||||
|
|
||||||
|
Rx[:, :scale**2, :, :] = x1
|
||||||
|
Rx[:, scale**2:2 * (scale**2), :, :] = x2
|
||||||
|
Rx[:, 2 * (scale**2):, :, :] = x3
|
||||||
|
|
||||||
|
return Rx
|
||||||
|
|
||||||
|
|
||||||
|
class DenseBlock(nn.Module):
|
||||||
|
'''Dense block
|
||||||
|
for the second denseblock, t_reduced = True'''
|
||||||
|
|
||||||
|
def __init__(self, nf=64, ng=32, t_reduce=False):
|
||||||
|
super(DenseBlock, self).__init__()
|
||||||
|
self.t_reduce = t_reduce
|
||||||
|
if self.t_reduce:
|
||||||
|
pad = (0, 1, 1)
|
||||||
|
else:
|
||||||
|
pad = (1, 1, 1)
|
||||||
|
self.bn3d_1 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3)
|
||||||
|
self.conv3d_1 = nn.Conv3d(nf, nf, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
|
||||||
|
self.bn3d_2 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3)
|
||||||
|
self.conv3d_2 = nn.Conv3d(nf, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)
|
||||||
|
self.bn3d_3 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3)
|
||||||
|
self.conv3d_3 = nn.Conv3d(nf + ng, nf + ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
|
||||||
|
bias=True)
|
||||||
|
self.bn3d_4 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3)
|
||||||
|
self.conv3d_4 = nn.Conv3d(nf + ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)
|
||||||
|
self.bn3d_5 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3)
|
||||||
|
self.conv3d_5 = nn.Conv3d(nf + 2 * ng, nf + 2 * ng, (1, 1, 1), stride=(1, 1, 1),
|
||||||
|
padding=(0, 0, 0), bias=True)
|
||||||
|
self.bn3d_6 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3)
|
||||||
|
self.conv3d_6 = nn.Conv3d(nf + 2 * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad,
|
||||||
|
bias=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''x: [B, C, T, H, W]
|
||||||
|
C: nf -> nf + 3 * ng
|
||||||
|
T: 1) 7 -> 7 (t_reduce=False);
|
||||||
|
2) 7 -> 7 - 2 * 3 = 1 (t_reduce=True)'''
|
||||||
|
x1 = self.conv3d_1(F.relu(self.bn3d_1(x), inplace=True))
|
||||||
|
x1 = self.conv3d_2(F.relu(self.bn3d_2(x1), inplace=True))
|
||||||
|
if self.t_reduce:
|
||||||
|
x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
|
||||||
|
else:
|
||||||
|
x1 = torch.cat((x, x1), 1)
|
||||||
|
|
||||||
|
x2 = self.conv3d_3(F.relu(self.bn3d_3(x1), inplace=True))
|
||||||
|
x2 = self.conv3d_4(F.relu(self.bn3d_4(x2), inplace=True))
|
||||||
|
if self.t_reduce:
|
||||||
|
x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
|
||||||
|
else:
|
||||||
|
x2 = torch.cat((x1, x2), 1)
|
||||||
|
|
||||||
|
x3 = self.conv3d_5(F.relu(self.bn3d_5(x2), inplace=True))
|
||||||
|
x3 = self.conv3d_6(F.relu(self.bn3d_6(x3), inplace=True))
|
||||||
|
if self.t_reduce:
|
||||||
|
x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
|
||||||
|
else:
|
||||||
|
x3 = torch.cat((x2, x3), 1)
|
||||||
|
return x3
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicUpsamplingFilter_3C(nn.Module):
|
||||||
|
'''dynamic upsampling filter with 3 channels applying the same filters
|
||||||
|
filter_size: filter size of the generated filters, shape (C, kH, kW)'''
|
||||||
|
|
||||||
|
def __init__(self, filter_size=(1, 5, 5)):
|
||||||
|
super(DynamicUpsamplingFilter_3C, self).__init__()
|
||||||
|
# generate a local expansion filter, used similar to im2col
|
||||||
|
nF = np.prod(filter_size)
|
||||||
|
expand_filter_np = np.reshape(np.eye(nF, nF),
|
||||||
|
(nF, filter_size[0], filter_size[1], filter_size[2]))
|
||||||
|
expand_filter = torch.from_numpy(expand_filter_np).float()
|
||||||
|
self.expand_filter = torch.cat((expand_filter, expand_filter, expand_filter),
|
||||||
|
0) # [75, 1, 5, 5]
|
||||||
|
|
||||||
|
def forward(self, x, filters):
|
||||||
|
'''x: input image, [B, 3, H, W]
|
||||||
|
filters: generate dynamic filters, [B, F, R, H, W], e.g., [B, 25, 16, H, W]
|
||||||
|
F: prod of filter kernel size, e.g., 5*5 = 25
|
||||||
|
R: used for upsampling, similar to pixel shuffle, e.g., 4*4 = 16 for x4
|
||||||
|
Return: filtered image, [B, 3*R, H, W]
|
||||||
|
'''
|
||||||
|
B, nF, R, H, W = filters.size()
|
||||||
|
# using group convolution
|
||||||
|
input_expand = F.conv2d(x, self.expand_filter.type_as(x), padding=2,
|
||||||
|
groups=3) # [B, 75, H, W] similar to im2col
|
||||||
|
input_expand = input_expand.view(B, 3, nF, H, W).permute(0, 3, 4, 1, 2) # [B, H, W, 3, 25]
|
||||||
|
filters = filters.permute(0, 3, 4, 1, 2) # [B, H, W, 25, 16]
|
||||||
|
out = torch.matmul(input_expand, filters) # [B, H, W, 3, 16]
|
||||||
|
return out.permute(0, 3, 4, 1, 2).view(B, 3 * R, H, W) # [B, 3*16, H, W]
|
||||||
|
|
||||||
|
|
||||||
|
class DUF_16L(nn.Module):
|
||||||
|
'''Official DUF structure with 16 layers'''
|
||||||
|
|
||||||
|
def __init__(self, scale=4, adapt_official=False):
|
||||||
|
super(DUF_16L, self).__init__()
|
||||||
|
self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
|
||||||
|
self.dense_block_1 = DenseBlock(64, 64 // 2, t_reduce=False) # 64 + 32 * 3 = 160, T = 7
|
||||||
|
self.dense_block_2 = DenseBlock(160, 64 // 2, t_reduce=True) # 160 + 32 * 3 = 256, T = 1
|
||||||
|
self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3)
|
||||||
|
self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
|
||||||
|
bias=True)
|
||||||
|
|
||||||
|
self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
|
||||||
|
bias=True)
|
||||||
|
self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
|
||||||
|
padding=(0, 0, 0), bias=True)
|
||||||
|
|
||||||
|
self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
|
||||||
|
bias=True)
|
||||||
|
self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
|
||||||
|
padding=(0, 0, 0), bias=True)
|
||||||
|
|
||||||
|
self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
|
||||||
|
|
||||||
|
self.scale = scale
|
||||||
|
self.adapt_official = adapt_official
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''
|
||||||
|
x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
|
||||||
|
Generate filters and image residual:
|
||||||
|
Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
|
||||||
|
Rx: [B, 3*16, 1, H, W]
|
||||||
|
'''
|
||||||
|
B, T, C, H, W = x.size()
|
||||||
|
x = x.permute(0, 2, 1, 3, 4) # [B, C, T, H, W] for Conv3D
|
||||||
|
x_center = x[:, :, T // 2, :, :]
|
||||||
|
|
||||||
|
x = self.conv3d_1(x)
|
||||||
|
x = self.dense_block_1(x)
|
||||||
|
x = self.dense_block_2(x) # reduce T to 1
|
||||||
|
x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)
|
||||||
|
|
||||||
|
# image residual
|
||||||
|
Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W]
|
||||||
|
|
||||||
|
# filter
|
||||||
|
Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W]
|
||||||
|
Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)
|
||||||
|
|
||||||
|
# Adapt to official model weights
|
||||||
|
if self.adapt_official:
|
||||||
|
adapt_official(Rx, scale=self.scale)
|
||||||
|
|
||||||
|
# dynamic filter
|
||||||
|
out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W]
|
||||||
|
out += Rx.squeeze_(2)
|
||||||
|
out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W]
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DenseBlock_28L(nn.Module):
|
||||||
|
'''The first part of the dense blocks used in DUF_28L
|
||||||
|
Temporal dimension remains the same here'''
|
||||||
|
|
||||||
|
def __init__(self, nf=64, ng=16):
|
||||||
|
super(DenseBlock_28L, self).__init__()
|
||||||
|
pad = (1, 1, 1)
|
||||||
|
|
||||||
|
dense_block_l = []
|
||||||
|
for i in range(0, 9):
|
||||||
|
dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
|
||||||
|
dense_block_l.append(nn.ReLU())
|
||||||
|
dense_block_l.append(
|
||||||
|
nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
|
||||||
|
bias=True))
|
||||||
|
|
||||||
|
dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
|
||||||
|
dense_block_l.append(nn.ReLU())
|
||||||
|
dense_block_l.append(
|
||||||
|
nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True))
|
||||||
|
|
||||||
|
self.dense_blocks = nn.ModuleList(dense_block_l)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''x: [B, C, T, H, W]
|
||||||
|
C: 1) 64 -> 208;
|
||||||
|
T: 1) 7 -> 7; (t_reduce=True)'''
|
||||||
|
for i in range(0, len(self.dense_blocks), 6):
|
||||||
|
y = x
|
||||||
|
for j in range(6):
|
||||||
|
y = self.dense_blocks[i + j](y)
|
||||||
|
x = torch.cat((x, y), 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DUF_28L(nn.Module):
|
||||||
|
'''Official DUF structure with 28 layers'''
|
||||||
|
|
||||||
|
def __init__(self, scale=4, adapt_official=False):
|
||||||
|
super(DUF_28L, self).__init__()
|
||||||
|
self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
|
||||||
|
self.dense_block_1 = DenseBlock_28L(64, 16) # 64 + 16 * 9 = 208, T = 7
|
||||||
|
self.dense_block_2 = DenseBlock(208, 16, t_reduce=True) # 208 + 16 * 3 = 256, T = 1
|
||||||
|
self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3)
|
||||||
|
self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
|
||||||
|
bias=True)
|
||||||
|
|
||||||
|
self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
|
||||||
|
bias=True)
|
||||||
|
self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
|
||||||
|
padding=(0, 0, 0), bias=True)
|
||||||
|
|
||||||
|
self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
|
||||||
|
bias=True)
|
||||||
|
self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
|
||||||
|
padding=(0, 0, 0), bias=True)
|
||||||
|
|
||||||
|
self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
|
||||||
|
|
||||||
|
self.scale = scale
|
||||||
|
self.adapt_official = adapt_official
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''
|
||||||
|
x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
|
||||||
|
Generate filters and image residual:
|
||||||
|
Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
|
||||||
|
Rx: [B, 3*16, 1, H, W]
|
||||||
|
'''
|
||||||
|
B, T, C, H, W = x.size()
|
||||||
|
x = x.permute(0, 2, 1, 3, 4) # [B,C,T,H,W] for Conv3D
|
||||||
|
x_center = x[:, :, T // 2, :, :]
|
||||||
|
x = self.conv3d_1(x)
|
||||||
|
x = self.dense_block_1(x)
|
||||||
|
x = self.dense_block_2(x) # reduce T to 1
|
||||||
|
x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)
|
||||||
|
|
||||||
|
# image residual
|
||||||
|
Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W]
|
||||||
|
|
||||||
|
# filter
|
||||||
|
Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W]
|
||||||
|
Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)
|
||||||
|
|
||||||
|
# Adapt to official model weights
|
||||||
|
if self.adapt_official:
|
||||||
|
adapt_official(Rx, scale=self.scale)
|
||||||
|
|
||||||
|
# dynamic filter
|
||||||
|
out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W]
|
||||||
|
out += Rx.squeeze_(2)
|
||||||
|
out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W]
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DenseBlock_52L(nn.Module):
|
||||||
|
'''The first part of the dense blocks used in DUF_52L
|
||||||
|
Temporal dimension remains the same here'''
|
||||||
|
|
||||||
|
def __init__(self, nf=64, ng=16):
|
||||||
|
super(DenseBlock_52L, self).__init__()
|
||||||
|
pad = (1, 1, 1)
|
||||||
|
|
||||||
|
dense_block_l = []
|
||||||
|
for i in range(0, 21):
|
||||||
|
dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
|
||||||
|
dense_block_l.append(nn.ReLU())
|
||||||
|
dense_block_l.append(
|
||||||
|
nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
|
||||||
|
bias=True))
|
||||||
|
|
||||||
|
dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
|
||||||
|
dense_block_l.append(nn.ReLU())
|
||||||
|
dense_block_l.append(
|
||||||
|
nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True))
|
||||||
|
|
||||||
|
self.dense_blocks = nn.ModuleList(dense_block_l)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''x: [B, C, T, H, W]
|
||||||
|
C: 1) 64 -> 400;
|
||||||
|
T: 1) 7 -> 7; (t_reduce=True)'''
|
||||||
|
for i in range(0, len(self.dense_blocks), 6):
|
||||||
|
y = x
|
||||||
|
for j in range(6):
|
||||||
|
y = self.dense_blocks[i + j](y)
|
||||||
|
x = torch.cat((x, y), 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DUF_52L(nn.Module):
|
||||||
|
'''Official DUF structure with 52 layers'''
|
||||||
|
|
||||||
|
def __init__(self, scale=4, adapt_official=False):
|
||||||
|
super(DUF_52L, self).__init__()
|
||||||
|
self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
|
||||||
|
self.dense_block_1 = DenseBlock_52L(64, 16) # 64 + 21 * 9 = 400, T = 7
|
||||||
|
self.dense_block_2 = DenseBlock(400, 16, t_reduce=True) # 400 + 16 * 3 = 448, T = 1
|
||||||
|
|
||||||
|
self.bn3d_2 = nn.BatchNorm3d(448, eps=1e-3, momentum=1e-3)
|
||||||
|
self.conv3d_2 = nn.Conv3d(448, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
|
||||||
|
bias=True)
|
||||||
|
|
||||||
|
self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
|
||||||
|
bias=True)
|
||||||
|
self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
|
||||||
|
padding=(0, 0, 0), bias=True)
|
||||||
|
|
||||||
|
self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
|
||||||
|
bias=True)
|
||||||
|
self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
|
||||||
|
padding=(0, 0, 0), bias=True)
|
||||||
|
|
||||||
|
self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
|
||||||
|
|
||||||
|
self.scale = scale
|
||||||
|
self.adapt_official = adapt_official
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''
|
||||||
|
x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
|
||||||
|
Generate filters and image residual:
|
||||||
|
Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
|
||||||
|
Rx: [B, 3*16, 1, H, W]
|
||||||
|
'''
|
||||||
|
B, T, C, H, W = x.size()
|
||||||
|
x = x.permute(0, 2, 1, 3, 4) # [B,C,T,H,W] for Conv3D
|
||||||
|
x_center = x[:, :, T // 2, :, :]
|
||||||
|
x = self.conv3d_1(x)
|
||||||
|
x = self.dense_block_1(x)
|
||||||
|
x = self.dense_block_2(x)
|
||||||
|
x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)
|
||||||
|
|
||||||
|
# image residual
|
||||||
|
Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W]
|
||||||
|
|
||||||
|
# filter
|
||||||
|
Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W]
|
||||||
|
Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)
|
||||||
|
|
||||||
|
# Adapt to official model weights
|
||||||
|
if self.adapt_official:
|
||||||
|
adapt_official(Rx, scale=self.scale)
|
||||||
|
|
||||||
|
# dynamic filter
|
||||||
|
out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W]
|
||||||
|
out += Rx.squeeze_(2)
|
||||||
|
out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W]
|
||||||
|
return out
|
312
codes/models/archs/EDVR_arch.py
Normal file
312
codes/models/archs/EDVR_arch.py
Normal file
|
@ -0,0 +1,312 @@
|
||||||
|
''' network architecture for EDVR '''
|
||||||
|
import functools
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import models.archs.arch_util as arch_util
|
||||||
|
try:
|
||||||
|
from models.archs.dcn.deform_conv import ModulatedDeformConvPack as DCN
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError('Failed to import DCNv2 module.')
|
||||||
|
|
||||||
|
|
||||||
|
class Predeblur_ResNet_Pyramid(nn.Module):
|
||||||
|
def __init__(self, nf=128, HR_in=False):
|
||||||
|
'''
|
||||||
|
HR_in: True if the inputs are high spatial size
|
||||||
|
'''
|
||||||
|
|
||||||
|
super(Predeblur_ResNet_Pyramid, self).__init__()
|
||||||
|
self.HR_in = True if HR_in else False
|
||||||
|
if self.HR_in:
|
||||||
|
self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
|
||||||
|
self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
||||||
|
self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
||||||
|
else:
|
||||||
|
self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
|
||||||
|
basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
|
||||||
|
self.RB_L1_1 = basic_block()
|
||||||
|
self.RB_L1_2 = basic_block()
|
||||||
|
self.RB_L1_3 = basic_block()
|
||||||
|
self.RB_L1_4 = basic_block()
|
||||||
|
self.RB_L1_5 = basic_block()
|
||||||
|
self.RB_L2_1 = basic_block()
|
||||||
|
self.RB_L2_2 = basic_block()
|
||||||
|
self.RB_L3_1 = basic_block()
|
||||||
|
self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
||||||
|
self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
||||||
|
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.HR_in:
|
||||||
|
L1_fea = self.lrelu(self.conv_first_1(x))
|
||||||
|
L1_fea = self.lrelu(self.conv_first_2(L1_fea))
|
||||||
|
L1_fea = self.lrelu(self.conv_first_3(L1_fea))
|
||||||
|
else:
|
||||||
|
L1_fea = self.lrelu(self.conv_first(x))
|
||||||
|
L2_fea = self.lrelu(self.deblur_L2_conv(L1_fea))
|
||||||
|
L3_fea = self.lrelu(self.deblur_L3_conv(L2_fea))
|
||||||
|
L3_fea = F.interpolate(self.RB_L3_1(L3_fea), scale_factor=2, mode='bilinear',
|
||||||
|
align_corners=False)
|
||||||
|
L2_fea = self.RB_L2_1(L2_fea) + L3_fea
|
||||||
|
L2_fea = F.interpolate(self.RB_L2_2(L2_fea), scale_factor=2, mode='bilinear',
|
||||||
|
align_corners=False)
|
||||||
|
L1_fea = self.RB_L1_2(self.RB_L1_1(L1_fea)) + L2_fea
|
||||||
|
out = self.RB_L1_5(self.RB_L1_4(self.RB_L1_3(L1_fea)))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PCD_Align(nn.Module):
|
||||||
|
''' Alignment module using Pyramid, Cascading and Deformable convolution
|
||||||
|
with 3 pyramid levels.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, nf=64, groups=8):
|
||||||
|
super(PCD_Align, self).__init__()
|
||||||
|
# L3: level 3, 1/4 spatial size
|
||||||
|
self.L3_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
|
||||||
|
self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
|
||||||
|
extra_offset_mask=True)
|
||||||
|
# L2: level 2, 1/2 spatial size
|
||||||
|
self.L2_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
|
||||||
|
self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset
|
||||||
|
self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
|
||||||
|
extra_offset_mask=True)
|
||||||
|
self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea
|
||||||
|
# L1: level 1, original spatial size
|
||||||
|
self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
|
||||||
|
self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset
|
||||||
|
self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
|
||||||
|
extra_offset_mask=True)
|
||||||
|
self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea
|
||||||
|
# Cascading DCN
|
||||||
|
self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
|
||||||
|
self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
|
self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
|
||||||
|
extra_offset_mask=True)
|
||||||
|
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, nbr_fea_l, ref_fea_l):
|
||||||
|
'''align other neighboring frames to the reference frame in the feature level
|
||||||
|
nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features
|
||||||
|
'''
|
||||||
|
# L3
|
||||||
|
L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1)
|
||||||
|
L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
|
||||||
|
L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
|
||||||
|
L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset]))
|
||||||
|
# L2
|
||||||
|
L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1)
|
||||||
|
L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
|
||||||
|
L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)
|
||||||
|
L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1)))
|
||||||
|
L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
|
||||||
|
L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset])
|
||||||
|
L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)
|
||||||
|
L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
|
||||||
|
# L1
|
||||||
|
L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1)
|
||||||
|
L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
|
||||||
|
L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)
|
||||||
|
L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
|
||||||
|
L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
|
||||||
|
L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset])
|
||||||
|
L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)
|
||||||
|
L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
|
||||||
|
# Cascading
|
||||||
|
offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
|
||||||
|
offset = self.lrelu(self.cas_offset_conv1(offset))
|
||||||
|
offset = self.lrelu(self.cas_offset_conv2(offset))
|
||||||
|
L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset]))
|
||||||
|
|
||||||
|
return L1_fea
|
||||||
|
|
||||||
|
|
||||||
|
class TSA_Fusion(nn.Module):
|
||||||
|
''' Temporal Spatial Attention fusion module
|
||||||
|
Temporal: correlation;
|
||||||
|
Spatial: 3 pyramid levels.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, nf=64, nframes=5, center=2):
|
||||||
|
super(TSA_Fusion, self).__init__()
|
||||||
|
self.center = center
|
||||||
|
# temporal attention (before fusion conv)
|
||||||
|
self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
|
# fusion conv: using 1x1 to save parameters and computation
|
||||||
|
self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
|
||||||
|
|
||||||
|
# spatial attention (after fusion conv)
|
||||||
|
self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
|
||||||
|
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
|
||||||
|
self.avgpool = nn.AvgPool2d(3, stride=2, padding=1)
|
||||||
|
self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True)
|
||||||
|
self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True)
|
||||||
|
self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
|
||||||
|
self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
|
||||||
|
self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
|
||||||
|
self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True)
|
||||||
|
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, aligned_fea):
|
||||||
|
B, N, C, H, W = aligned_fea.size() # N video frames
|
||||||
|
#### temporal attention
|
||||||
|
emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone())
|
||||||
|
emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W) # [B, N, C(nf), H, W]
|
||||||
|
|
||||||
|
cor_l = []
|
||||||
|
for i in range(N):
|
||||||
|
emb_nbr = emb[:, i, :, :, :]
|
||||||
|
cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1) # B, 1, H, W
|
||||||
|
cor_l.append(cor_tmp)
|
||||||
|
cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1)) # B, N, H, W
|
||||||
|
cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W)
|
||||||
|
aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob
|
||||||
|
|
||||||
|
#### fusion
|
||||||
|
fea = self.lrelu(self.fea_fusion(aligned_fea))
|
||||||
|
|
||||||
|
#### spatial attention
|
||||||
|
att = self.lrelu(self.sAtt_1(aligned_fea))
|
||||||
|
att_max = self.maxpool(att)
|
||||||
|
att_avg = self.avgpool(att)
|
||||||
|
att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1)))
|
||||||
|
# pyramid levels
|
||||||
|
att_L = self.lrelu(self.sAtt_L1(att))
|
||||||
|
att_max = self.maxpool(att_L)
|
||||||
|
att_avg = self.avgpool(att_L)
|
||||||
|
att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1)))
|
||||||
|
att_L = self.lrelu(self.sAtt_L3(att_L))
|
||||||
|
att_L = F.interpolate(att_L, scale_factor=2, mode='bilinear', align_corners=False)
|
||||||
|
|
||||||
|
att = self.lrelu(self.sAtt_3(att))
|
||||||
|
att = att + att_L
|
||||||
|
att = self.lrelu(self.sAtt_4(att))
|
||||||
|
att = F.interpolate(att, scale_factor=2, mode='bilinear', align_corners=False)
|
||||||
|
att = self.sAtt_5(att)
|
||||||
|
att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att)))
|
||||||
|
att = torch.sigmoid(att)
|
||||||
|
|
||||||
|
fea = fea * att * 2 + att_add
|
||||||
|
return fea
|
||||||
|
|
||||||
|
|
||||||
|
class EDVR(nn.Module):
|
||||||
|
def __init__(self, nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10, center=None,
|
||||||
|
predeblur=False, HR_in=False, w_TSA=True):
|
||||||
|
super(EDVR, self).__init__()
|
||||||
|
self.nf = nf
|
||||||
|
self.center = nframes // 2 if center is None else center
|
||||||
|
self.is_predeblur = True if predeblur else False
|
||||||
|
self.HR_in = True if HR_in else False
|
||||||
|
self.w_TSA = w_TSA
|
||||||
|
ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
|
||||||
|
|
||||||
|
#### extract features (for each frame)
|
||||||
|
if self.is_predeblur:
|
||||||
|
self.pre_deblur = Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
|
||||||
|
self.conv_1x1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
|
||||||
|
else:
|
||||||
|
if self.HR_in:
|
||||||
|
self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
|
||||||
|
self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
||||||
|
self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
||||||
|
else:
|
||||||
|
self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
|
||||||
|
self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
|
||||||
|
self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
||||||
|
self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
||||||
|
self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
|
self.pcd_align = PCD_Align(nf=nf, groups=groups)
|
||||||
|
if self.w_TSA:
|
||||||
|
self.tsa_fusion = TSA_Fusion(nf=nf, nframes=nframes, center=self.center)
|
||||||
|
else:
|
||||||
|
self.tsa_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
|
||||||
|
|
||||||
|
#### reconstruction
|
||||||
|
self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, back_RBs)
|
||||||
|
#### upsampling
|
||||||
|
self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
|
||||||
|
self.upconv2 = nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)
|
||||||
|
self.pixel_shuffle = nn.PixelShuffle(2)
|
||||||
|
self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
|
||||||
|
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
|
#### activation function
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, N, C, H, W = x.size() # N video frames
|
||||||
|
x_center = x[:, self.center, :, :, :].contiguous()
|
||||||
|
|
||||||
|
#### extract LR features
|
||||||
|
# L1
|
||||||
|
if self.is_predeblur:
|
||||||
|
L1_fea = self.pre_deblur(x.view(-1, C, H, W))
|
||||||
|
L1_fea = self.conv_1x1(L1_fea)
|
||||||
|
if self.HR_in:
|
||||||
|
H, W = H // 4, W // 4
|
||||||
|
else:
|
||||||
|
if self.HR_in:
|
||||||
|
L1_fea = self.lrelu(self.conv_first_1(x.view(-1, C, H, W)))
|
||||||
|
L1_fea = self.lrelu(self.conv_first_2(L1_fea))
|
||||||
|
L1_fea = self.lrelu(self.conv_first_3(L1_fea))
|
||||||
|
H, W = H // 4, W // 4
|
||||||
|
else:
|
||||||
|
L1_fea = self.lrelu(self.conv_first(x.view(-1, C, H, W)))
|
||||||
|
L1_fea = self.feature_extraction(L1_fea)
|
||||||
|
# L2
|
||||||
|
L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))
|
||||||
|
L2_fea = self.lrelu(self.fea_L2_conv2(L2_fea))
|
||||||
|
# L3
|
||||||
|
L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))
|
||||||
|
L3_fea = self.lrelu(self.fea_L3_conv2(L3_fea))
|
||||||
|
|
||||||
|
L1_fea = L1_fea.view(B, N, -1, H, W)
|
||||||
|
L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2)
|
||||||
|
L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4)
|
||||||
|
|
||||||
|
#### pcd align
|
||||||
|
# ref feature list
|
||||||
|
ref_fea_l = [
|
||||||
|
L1_fea[:, self.center, :, :, :].clone(), L2_fea[:, self.center, :, :, :].clone(),
|
||||||
|
L3_fea[:, self.center, :, :, :].clone()
|
||||||
|
]
|
||||||
|
aligned_fea = []
|
||||||
|
for i in range(N):
|
||||||
|
nbr_fea_l = [
|
||||||
|
L1_fea[:, i, :, :, :].clone(), L2_fea[:, i, :, :, :].clone(),
|
||||||
|
L3_fea[:, i, :, :, :].clone()
|
||||||
|
]
|
||||||
|
aligned_fea.append(self.pcd_align(nbr_fea_l, ref_fea_l))
|
||||||
|
aligned_fea = torch.stack(aligned_fea, dim=1) # [B, N, C, H, W]
|
||||||
|
|
||||||
|
if not self.w_TSA:
|
||||||
|
aligned_fea = aligned_fea.view(B, -1, H, W)
|
||||||
|
fea = self.tsa_fusion(aligned_fea)
|
||||||
|
|
||||||
|
out = self.recon_trunk(fea)
|
||||||
|
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
|
||||||
|
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
|
||||||
|
out = self.lrelu(self.HRconv(out))
|
||||||
|
out = self.conv_last(out)
|
||||||
|
if self.HR_in:
|
||||||
|
base = x_center
|
||||||
|
else:
|
||||||
|
base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
|
||||||
|
out += base
|
||||||
|
return out
|
73
codes/models/archs/RRDBNet_arch.py
Normal file
73
codes/models/archs/RRDBNet_arch.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
import functools
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import models.archs.arch_util as arch_util
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualDenseBlock_5C(nn.Module):
|
||||||
|
def __init__(self, nf=64, gc=32, bias=True):
|
||||||
|
super(ResidualDenseBlock_5C, self).__init__()
|
||||||
|
# gc: growth channel, i.e. intermediate channels
|
||||||
|
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||||
|
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||||
|
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||||
|
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||||
|
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
|
||||||
|
# initialization
|
||||||
|
arch_util.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5],
|
||||||
|
0.1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x1 = self.lrelu(self.conv1(x))
|
||||||
|
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||||
|
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||||
|
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||||
|
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||||
|
return x5 * 0.2 + x
|
||||||
|
|
||||||
|
|
||||||
|
class RRDB(nn.Module):
|
||||||
|
'''Residual in Residual Dense Block'''
|
||||||
|
|
||||||
|
def __init__(self, nf, gc=32):
|
||||||
|
super(RRDB, self).__init__()
|
||||||
|
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||||
|
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||||
|
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.RDB1(x)
|
||||||
|
out = self.RDB2(out)
|
||||||
|
out = self.RDB3(out)
|
||||||
|
return out * 0.2 + x
|
||||||
|
|
||||||
|
|
||||||
|
class RRDBNet(nn.Module):
|
||||||
|
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
||||||
|
super(RRDBNet, self).__init__()
|
||||||
|
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||||
|
|
||||||
|
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||||
|
self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb)
|
||||||
|
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
#### upsampling
|
||||||
|
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fea = self.conv_first(x)
|
||||||
|
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||||
|
fea = fea + trunk
|
||||||
|
|
||||||
|
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||||
|
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||||
|
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||||
|
|
||||||
|
return out
|
55
codes/models/archs/SRResNet_arch.py
Normal file
55
codes/models/archs/SRResNet_arch.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
import functools
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import models.archs.arch_util as arch_util
|
||||||
|
|
||||||
|
|
||||||
|
class MSRResNet(nn.Module):
|
||||||
|
''' modified SRResNet'''
|
||||||
|
|
||||||
|
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
|
||||||
|
super(MSRResNet, self).__init__()
|
||||||
|
self.upscale = upscale
|
||||||
|
|
||||||
|
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||||
|
basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
|
||||||
|
self.recon_trunk = arch_util.make_layer(basic_block, nb)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
if self.upscale == 2:
|
||||||
|
self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
|
||||||
|
self.pixel_shuffle = nn.PixelShuffle(2)
|
||||||
|
elif self.upscale == 3:
|
||||||
|
self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
|
||||||
|
self.pixel_shuffle = nn.PixelShuffle(3)
|
||||||
|
elif self.upscale == 4:
|
||||||
|
self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
|
||||||
|
self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
|
||||||
|
self.pixel_shuffle = nn.PixelShuffle(2)
|
||||||
|
|
||||||
|
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
|
# activation function
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
|
|
||||||
|
# initialization
|
||||||
|
arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last],
|
||||||
|
0.1)
|
||||||
|
if self.upscale == 4:
|
||||||
|
arch_util.initialize_weights(self.upconv2, 0.1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fea = self.lrelu(self.conv_first(x))
|
||||||
|
out = self.recon_trunk(fea)
|
||||||
|
|
||||||
|
if self.upscale == 4:
|
||||||
|
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
|
||||||
|
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
|
||||||
|
elif self.upscale == 3 or self.upscale == 2:
|
||||||
|
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
|
||||||
|
|
||||||
|
out = self.conv_last(self.lrelu(self.HRconv(out)))
|
||||||
|
base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
|
||||||
|
out += base
|
||||||
|
return out
|
137
codes/models/archs/TOF_arch.py
Executable file
137
codes/models/archs/TOF_arch.py
Executable file
|
@ -0,0 +1,137 @@
|
||||||
|
'''PyTorch implementation of TOFlow
|
||||||
|
Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
|
||||||
|
Code reference:
|
||||||
|
1. https://github.com/anchen1011/toflow
|
||||||
|
2. https://github.com/Coldog2333/pytoflow
|
||||||
|
'''
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from .arch_util import flow_warp
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(x):
|
||||||
|
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
|
||||||
|
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
|
||||||
|
return (x - mean) / std
|
||||||
|
|
||||||
|
|
||||||
|
def denormalize(x):
|
||||||
|
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
|
||||||
|
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
|
||||||
|
return x * std + mean
|
||||||
|
|
||||||
|
|
||||||
|
class SpyNet_Block(nn.Module):
|
||||||
|
'''A submodule of SpyNet.'''
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(SpyNet_Block, self).__init__()
|
||||||
|
|
||||||
|
self.block = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
|
||||||
|
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
|
||||||
|
nn.BatchNorm2d(64), nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
|
||||||
|
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
|
||||||
|
nn.BatchNorm2d(16), nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''
|
||||||
|
input: x: [ref im, nbr im, initial flow] - (B, 8, H, W)
|
||||||
|
output: estimated flow - (B, 2, H, W)
|
||||||
|
'''
|
||||||
|
return self.block(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SpyNet(nn.Module):
|
||||||
|
'''SpyNet for estimating optical flow
|
||||||
|
Ranjan et al., Optical Flow Estimation using a Spatial Pyramid Network, 2016'''
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(SpyNet, self).__init__()
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList([SpyNet_Block() for _ in range(4)])
|
||||||
|
|
||||||
|
def forward(self, ref, nbr):
|
||||||
|
'''Estimating optical flow in coarse level, upsample, and estimate in fine level
|
||||||
|
input: ref: reference image - [B, 3, H, W]
|
||||||
|
nbr: the neighboring image to be warped - [B, 3, H, W]
|
||||||
|
output: estimated optical flow - [B, 2, H, W]
|
||||||
|
'''
|
||||||
|
B, C, H, W = ref.size()
|
||||||
|
ref = [ref]
|
||||||
|
nbr = [nbr]
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
ref.insert(
|
||||||
|
0,
|
||||||
|
nn.functional.avg_pool2d(input=ref[0], kernel_size=2, stride=2,
|
||||||
|
count_include_pad=False))
|
||||||
|
nbr.insert(
|
||||||
|
0,
|
||||||
|
nn.functional.avg_pool2d(input=nbr[0], kernel_size=2, stride=2,
|
||||||
|
count_include_pad=False))
|
||||||
|
|
||||||
|
flow = torch.zeros(B, 2, H // 16, W // 16).type_as(ref[0])
|
||||||
|
|
||||||
|
for i in range(4):
|
||||||
|
flow_up = nn.functional.interpolate(input=flow, scale_factor=2, mode='bilinear',
|
||||||
|
align_corners=True) * 2.0
|
||||||
|
flow = flow_up + self.blocks[i](torch.cat(
|
||||||
|
[ref[i], flow_warp(nbr[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
|
||||||
|
return flow
|
||||||
|
|
||||||
|
|
||||||
|
class TOFlow(nn.Module):
|
||||||
|
def __init__(self, adapt_official=False):
|
||||||
|
super(TOFlow, self).__init__()
|
||||||
|
|
||||||
|
self.SpyNet = SpyNet()
|
||||||
|
|
||||||
|
self.conv_3x7_64_9x9 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
|
||||||
|
self.conv_64_64_9x9 = nn.Conv2d(64, 64, 9, 1, 4)
|
||||||
|
self.conv_64_64_1x1 = nn.Conv2d(64, 64, 1)
|
||||||
|
self.conv_64_3_1x1 = nn.Conv2d(64, 3, 1)
|
||||||
|
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.adapt_official = adapt_official # True if using translated official weights else False
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
input: x: input frames - [B, 7, 3, H, W]
|
||||||
|
output: SR reference frame - [B, 3, H, W]
|
||||||
|
"""
|
||||||
|
|
||||||
|
B, T, C, H, W = x.size()
|
||||||
|
x = normalize(x.view(-1, C, H, W)).view(B, T, C, H, W)
|
||||||
|
|
||||||
|
ref_idx = 3
|
||||||
|
x_ref = x[:, ref_idx, :, :, :]
|
||||||
|
|
||||||
|
# In the official torch code, the 0-th frame is the reference frame
|
||||||
|
if self.adapt_official:
|
||||||
|
x = x[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
|
||||||
|
ref_idx = 0
|
||||||
|
|
||||||
|
x_warped = []
|
||||||
|
for i in range(7):
|
||||||
|
if i == ref_idx:
|
||||||
|
x_warped.append(x_ref)
|
||||||
|
else:
|
||||||
|
x_nbr = x[:, i, :, :, :]
|
||||||
|
flow = self.SpyNet(x_ref, x_nbr).permute(0, 2, 3, 1)
|
||||||
|
x_warped.append(flow_warp(x_nbr, flow))
|
||||||
|
x_warped = torch.stack(x_warped, dim=1)
|
||||||
|
|
||||||
|
x = x_warped.view(B, -1, H, W)
|
||||||
|
x = self.relu(self.conv_3x7_64_9x9(x))
|
||||||
|
x = self.relu(self.conv_64_64_9x9(x))
|
||||||
|
x = self.relu(self.conv_64_64_1x1(x))
|
||||||
|
x = self.conv_64_3_1x1(x) + x_ref
|
||||||
|
|
||||||
|
return denormalize(x)
|
0
codes/models/archs/__init__.py
Normal file
0
codes/models/archs/__init__.py
Normal file
79
codes/models/archs/arch_util.py
Normal file
79
codes/models/archs/arch_util.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.init as init
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_weights(net_l, scale=1):
|
||||||
|
if not isinstance(net_l, list):
|
||||||
|
net_l = [net_l]
|
||||||
|
for net in net_l:
|
||||||
|
for m in net.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||||
|
m.weight.data *= scale # for residual block
|
||||||
|
if m.bias is not None:
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||||
|
m.weight.data *= scale
|
||||||
|
if m.bias is not None:
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
init.constant_(m.weight, 1)
|
||||||
|
init.constant_(m.bias.data, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def make_layer(block, n_layers):
|
||||||
|
layers = []
|
||||||
|
for _ in range(n_layers):
|
||||||
|
layers.append(block())
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock_noBN(nn.Module):
|
||||||
|
'''Residual block w/o BN
|
||||||
|
---Conv-ReLU-Conv-+-
|
||||||
|
|________________|
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, nf=64):
|
||||||
|
super(ResidualBlock_noBN, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
|
# initialization
|
||||||
|
initialize_weights([self.conv1, self.conv2], 0.1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
identity = x
|
||||||
|
out = F.relu(self.conv1(x), inplace=True)
|
||||||
|
out = self.conv2(out)
|
||||||
|
return identity + out
|
||||||
|
|
||||||
|
|
||||||
|
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
|
||||||
|
"""Warp an image or feature map with optical flow
|
||||||
|
Args:
|
||||||
|
x (Tensor): size (N, C, H, W)
|
||||||
|
flow (Tensor): size (N, H, W, 2), normal value
|
||||||
|
interp_mode (str): 'nearest' or 'bilinear'
|
||||||
|
padding_mode (str): 'zeros' or 'border' or 'reflection'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: warped image or feature map
|
||||||
|
"""
|
||||||
|
assert x.size()[-2:] == flow.size()[1:3]
|
||||||
|
B, C, H, W = x.size()
|
||||||
|
# mesh grid
|
||||||
|
grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
|
||||||
|
grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
|
||||||
|
grid.requires_grad = False
|
||||||
|
grid = grid.type_as(x)
|
||||||
|
vgrid = grid + flow
|
||||||
|
# scale grid to [-1,1]
|
||||||
|
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
|
||||||
|
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
|
||||||
|
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
|
||||||
|
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
|
||||||
|
return output
|
7
codes/models/archs/dcn/__init__.py
Normal file
7
codes/models/archs/dcn/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack,
|
||||||
|
deform_conv, modulated_deform_conv)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
|
||||||
|
'modulated_deform_conv'
|
||||||
|
]
|
291
codes/models/archs/dcn/deform_conv.py
Normal file
291
codes/models/archs/dcn/deform_conv.py
Normal file
|
@ -0,0 +1,291 @@
|
||||||
|
import math
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.autograd import Function
|
||||||
|
from torch.autograd.function import once_differentiable
|
||||||
|
from torch.nn.modules.utils import _pair
|
||||||
|
|
||||||
|
from . import deform_conv_cuda
|
||||||
|
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
|
||||||
|
class DeformConvFunction(Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1,
|
||||||
|
deformable_groups=1, im2col_step=64):
|
||||||
|
if input is not None and input.dim() != 4:
|
||||||
|
raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(
|
||||||
|
input.dim()))
|
||||||
|
ctx.stride = _pair(stride)
|
||||||
|
ctx.padding = _pair(padding)
|
||||||
|
ctx.dilation = _pair(dilation)
|
||||||
|
ctx.groups = groups
|
||||||
|
ctx.deformable_groups = deformable_groups
|
||||||
|
ctx.im2col_step = im2col_step
|
||||||
|
|
||||||
|
ctx.save_for_backward(input, offset, weight)
|
||||||
|
|
||||||
|
output = input.new_empty(
|
||||||
|
DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
|
||||||
|
|
||||||
|
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
|
||||||
|
|
||||||
|
if not input.is_cuda:
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
|
||||||
|
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
|
||||||
|
deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output,
|
||||||
|
ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
|
||||||
|
weight.size(2), ctx.stride[1], ctx.stride[0],
|
||||||
|
ctx.padding[1], ctx.padding[0],
|
||||||
|
ctx.dilation[1], ctx.dilation[0], ctx.groups,
|
||||||
|
ctx.deformable_groups, cur_im2col_step)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@once_differentiable
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
input, offset, weight = ctx.saved_tensors
|
||||||
|
|
||||||
|
grad_input = grad_offset = grad_weight = None
|
||||||
|
|
||||||
|
if not grad_output.is_cuda:
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
|
||||||
|
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
|
||||||
|
|
||||||
|
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
||||||
|
grad_input = torch.zeros_like(input)
|
||||||
|
grad_offset = torch.zeros_like(offset)
|
||||||
|
deform_conv_cuda.deform_conv_backward_input_cuda(
|
||||||
|
input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0],
|
||||||
|
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
|
||||||
|
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
|
||||||
|
ctx.deformable_groups, cur_im2col_step)
|
||||||
|
|
||||||
|
if ctx.needs_input_grad[2]:
|
||||||
|
grad_weight = torch.zeros_like(weight)
|
||||||
|
deform_conv_cuda.deform_conv_backward_parameters_cuda(
|
||||||
|
input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1],
|
||||||
|
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
|
||||||
|
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
|
||||||
|
ctx.deformable_groups, 1, cur_im2col_step)
|
||||||
|
|
||||||
|
return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _output_size(input, weight, padding, dilation, stride):
|
||||||
|
channels = weight.size(0)
|
||||||
|
output_size = (input.size(0), channels)
|
||||||
|
for d in range(input.dim() - 2):
|
||||||
|
in_size = input.size(d + 2)
|
||||||
|
pad = padding[d]
|
||||||
|
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
|
||||||
|
stride_ = stride[d]
|
||||||
|
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
|
||||||
|
if not all(map(lambda s: s > 0, output_size)):
|
||||||
|
raise ValueError("convolution input is too small (output would be {})".format('x'.join(
|
||||||
|
map(str, output_size))))
|
||||||
|
return output_size
|
||||||
|
|
||||||
|
|
||||||
|
class ModulatedDeformConvFunction(Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1,
|
||||||
|
groups=1, deformable_groups=1):
|
||||||
|
ctx.stride = stride
|
||||||
|
ctx.padding = padding
|
||||||
|
ctx.dilation = dilation
|
||||||
|
ctx.groups = groups
|
||||||
|
ctx.deformable_groups = deformable_groups
|
||||||
|
ctx.with_bias = bias is not None
|
||||||
|
if not ctx.with_bias:
|
||||||
|
bias = input.new_empty(1) # fake tensor
|
||||||
|
if not input.is_cuda:
|
||||||
|
raise NotImplementedError
|
||||||
|
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
|
||||||
|
or input.requires_grad:
|
||||||
|
ctx.save_for_backward(input, offset, mask, weight, bias)
|
||||||
|
output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
|
||||||
|
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
|
||||||
|
deform_conv_cuda.modulated_deform_conv_cuda_forward(
|
||||||
|
input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2],
|
||||||
|
weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
|
||||||
|
ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@once_differentiable
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
if not grad_output.is_cuda:
|
||||||
|
raise NotImplementedError
|
||||||
|
input, offset, mask, weight, bias = ctx.saved_tensors
|
||||||
|
grad_input = torch.zeros_like(input)
|
||||||
|
grad_offset = torch.zeros_like(offset)
|
||||||
|
grad_mask = torch.zeros_like(mask)
|
||||||
|
grad_weight = torch.zeros_like(weight)
|
||||||
|
grad_bias = torch.zeros_like(bias)
|
||||||
|
deform_conv_cuda.modulated_deform_conv_cuda_backward(
|
||||||
|
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight,
|
||||||
|
grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3],
|
||||||
|
ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
|
||||||
|
ctx.groups, ctx.deformable_groups, ctx.with_bias)
|
||||||
|
if not ctx.with_bias:
|
||||||
|
grad_bias = None
|
||||||
|
|
||||||
|
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None,
|
||||||
|
None)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _infer_shape(ctx, input, weight):
|
||||||
|
n = input.size(0)
|
||||||
|
channels_out = weight.size(0)
|
||||||
|
height, width = input.shape[2:4]
|
||||||
|
kernel_h, kernel_w = weight.shape[2:4]
|
||||||
|
height_out = (height + 2 * ctx.padding - (ctx.dilation *
|
||||||
|
(kernel_h - 1) + 1)) // ctx.stride + 1
|
||||||
|
width_out = (width + 2 * ctx.padding - (ctx.dilation *
|
||||||
|
(kernel_w - 1) + 1)) // ctx.stride + 1
|
||||||
|
return n, channels_out, height_out, width_out
|
||||||
|
|
||||||
|
|
||||||
|
deform_conv = DeformConvFunction.apply
|
||||||
|
modulated_deform_conv = ModulatedDeformConvFunction.apply
|
||||||
|
|
||||||
|
|
||||||
|
class DeformConv(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
|
||||||
|
groups=1, deformable_groups=1, bias=False):
|
||||||
|
super(DeformConv, self).__init__()
|
||||||
|
|
||||||
|
assert not bias
|
||||||
|
assert in_channels % groups == 0, \
|
||||||
|
'in_channels {} cannot be divisible by groups {}'.format(
|
||||||
|
in_channels, groups)
|
||||||
|
assert out_channels % groups == 0, \
|
||||||
|
'out_channels {} cannot be divisible by groups {}'.format(
|
||||||
|
out_channels, groups)
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.kernel_size = _pair(kernel_size)
|
||||||
|
self.stride = _pair(stride)
|
||||||
|
self.padding = _pair(padding)
|
||||||
|
self.dilation = _pair(dilation)
|
||||||
|
self.groups = groups
|
||||||
|
self.deformable_groups = deformable_groups
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
n = self.in_channels
|
||||||
|
for k in self.kernel_size:
|
||||||
|
n *= k
|
||||||
|
stdv = 1. / math.sqrt(n)
|
||||||
|
self.weight.data.uniform_(-stdv, stdv)
|
||||||
|
|
||||||
|
def forward(self, x, offset):
|
||||||
|
return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
|
||||||
|
self.groups, self.deformable_groups)
|
||||||
|
|
||||||
|
|
||||||
|
class DeformConvPack(DeformConv):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(DeformConvPack, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.conv_offset = nn.Conv2d(
|
||||||
|
self.in_channels,
|
||||||
|
self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
|
||||||
|
kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
|
||||||
|
bias=True)
|
||||||
|
self.init_offset()
|
||||||
|
|
||||||
|
def init_offset(self):
|
||||||
|
self.conv_offset.weight.data.zero_()
|
||||||
|
self.conv_offset.bias.data.zero_()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
offset = self.conv_offset(x)
|
||||||
|
return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
|
||||||
|
self.groups, self.deformable_groups)
|
||||||
|
|
||||||
|
|
||||||
|
class ModulatedDeformConv(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
|
||||||
|
groups=1, deformable_groups=1, bias=True):
|
||||||
|
super(ModulatedDeformConv, self).__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.kernel_size = _pair(kernel_size)
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.dilation = dilation
|
||||||
|
self.groups = groups
|
||||||
|
self.deformable_groups = deformable_groups
|
||||||
|
self.with_bias = bias
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
||||||
|
else:
|
||||||
|
self.register_parameter('bias', None)
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
n = self.in_channels
|
||||||
|
for k in self.kernel_size:
|
||||||
|
n *= k
|
||||||
|
stdv = 1. / math.sqrt(n)
|
||||||
|
self.weight.data.uniform_(-stdv, stdv)
|
||||||
|
if self.bias is not None:
|
||||||
|
self.bias.data.zero_()
|
||||||
|
|
||||||
|
def forward(self, x, offset, mask):
|
||||||
|
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
|
||||||
|
self.padding, self.dilation, self.groups,
|
||||||
|
self.deformable_groups)
|
||||||
|
|
||||||
|
|
||||||
|
class ModulatedDeformConvPack(ModulatedDeformConv):
|
||||||
|
def __init__(self, *args, extra_offset_mask=False, **kwargs):
|
||||||
|
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.extra_offset_mask = extra_offset_mask
|
||||||
|
self.conv_offset_mask = nn.Conv2d(
|
||||||
|
self.in_channels,
|
||||||
|
self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
|
||||||
|
kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
|
||||||
|
bias=True)
|
||||||
|
self.init_offset()
|
||||||
|
|
||||||
|
def init_offset(self):
|
||||||
|
self.conv_offset_mask.weight.data.zero_()
|
||||||
|
self.conv_offset_mask.bias.data.zero_()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.extra_offset_mask:
|
||||||
|
# x = [input, features]
|
||||||
|
out = self.conv_offset_mask(x[1])
|
||||||
|
x = x[0]
|
||||||
|
else:
|
||||||
|
out = self.conv_offset_mask(x)
|
||||||
|
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
||||||
|
offset = torch.cat((o1, o2), dim=1)
|
||||||
|
mask = torch.sigmoid(mask)
|
||||||
|
|
||||||
|
offset_mean = torch.mean(torch.abs(offset))
|
||||||
|
if offset_mean > 100:
|
||||||
|
logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean))
|
||||||
|
|
||||||
|
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
|
||||||
|
self.padding, self.dilation, self.groups,
|
||||||
|
self.deformable_groups)
|
22
codes/models/archs/dcn/setup.py
Normal file
22
codes/models/archs/dcn/setup.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
from setuptools import setup
|
||||||
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
|
|
||||||
|
|
||||||
|
def make_cuda_ext(name, sources):
|
||||||
|
|
||||||
|
return CUDAExtension(
|
||||||
|
name='{}'.format(name), sources=[p for p in sources], extra_compile_args={
|
||||||
|
'cxx': [],
|
||||||
|
'nvcc': [
|
||||||
|
'-D__CUDA_NO_HALF_OPERATORS__',
|
||||||
|
'-D__CUDA_NO_HALF_CONVERSIONS__',
|
||||||
|
'-D__CUDA_NO_HALF2_OPERATORS__',
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='deform_conv', ext_modules=[
|
||||||
|
make_cuda_ext(name='deform_conv_cuda',
|
||||||
|
sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu'])
|
||||||
|
], cmdclass={'build_ext': BuildExtension}, zip_safe=False)
|
695
codes/models/archs/dcn/src/deform_conv_cuda.cpp
Normal file
695
codes/models/archs/dcn/src/deform_conv_cuda.cpp
Normal file
|
@ -0,0 +1,695 @@
|
||||||
|
// modify from
|
||||||
|
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
|
||||||
|
const int channels, const int height, const int width,
|
||||||
|
const int ksize_h, const int ksize_w, const int pad_h,
|
||||||
|
const int pad_w, const int stride_h, const int stride_w,
|
||||||
|
const int dilation_h, const int dilation_w,
|
||||||
|
const int parallel_imgs, const int deformable_group,
|
||||||
|
at::Tensor data_col);
|
||||||
|
|
||||||
|
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
|
||||||
|
const int channels, const int height, const int width,
|
||||||
|
const int ksize_h, const int ksize_w, const int pad_h,
|
||||||
|
const int pad_w, const int stride_h, const int stride_w,
|
||||||
|
const int dilation_h, const int dilation_w,
|
||||||
|
const int parallel_imgs, const int deformable_group,
|
||||||
|
at::Tensor grad_im);
|
||||||
|
|
||||||
|
void deformable_col2im_coord(
|
||||||
|
const at::Tensor data_col, const at::Tensor data_im,
|
||||||
|
const at::Tensor data_offset, const int channels, const int height,
|
||||||
|
const int width, const int ksize_h, const int ksize_w, const int pad_h,
|
||||||
|
const int pad_w, const int stride_h, const int stride_w,
|
||||||
|
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
||||||
|
const int deformable_group, at::Tensor grad_offset);
|
||||||
|
|
||||||
|
void modulated_deformable_im2col_cuda(
|
||||||
|
const at::Tensor data_im, const at::Tensor data_offset,
|
||||||
|
const at::Tensor data_mask, const int batch_size, const int channels,
|
||||||
|
const int height_im, const int width_im, const int height_col,
|
||||||
|
const int width_col, const int kernel_h, const int kenerl_w,
|
||||||
|
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||||
|
const int dilation_h, const int dilation_w, const int deformable_group,
|
||||||
|
at::Tensor data_col);
|
||||||
|
|
||||||
|
void modulated_deformable_col2im_cuda(
|
||||||
|
const at::Tensor data_col, const at::Tensor data_offset,
|
||||||
|
const at::Tensor data_mask, const int batch_size, const int channels,
|
||||||
|
const int height_im, const int width_im, const int height_col,
|
||||||
|
const int width_col, const int kernel_h, const int kenerl_w,
|
||||||
|
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||||
|
const int dilation_h, const int dilation_w, const int deformable_group,
|
||||||
|
at::Tensor grad_im);
|
||||||
|
|
||||||
|
void modulated_deformable_col2im_coord_cuda(
|
||||||
|
const at::Tensor data_col, const at::Tensor data_im,
|
||||||
|
const at::Tensor data_offset, const at::Tensor data_mask,
|
||||||
|
const int batch_size, const int channels, const int height_im,
|
||||||
|
const int width_im, const int height_col, const int width_col,
|
||||||
|
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
|
||||||
|
const int stride_h, const int stride_w, const int dilation_h,
|
||||||
|
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
|
||||||
|
at::Tensor grad_mask);
|
||||||
|
|
||||||
|
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
|
||||||
|
at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
|
||||||
|
int padW, int dilationH, int dilationW, int group,
|
||||||
|
int deformable_group) {
|
||||||
|
AT_CHECK(weight.ndimension() == 4,
|
||||||
|
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
|
||||||
|
"but got: %s",
|
||||||
|
weight.ndimension());
|
||||||
|
|
||||||
|
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
||||||
|
|
||||||
|
AT_CHECK(kW > 0 && kH > 0,
|
||||||
|
"kernel size should be greater than zero, but got kH: %d kW: %d", kH,
|
||||||
|
kW);
|
||||||
|
|
||||||
|
AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
|
||||||
|
"kernel size should be consistent with weight, ",
|
||||||
|
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
|
||||||
|
kW, weight.size(2), weight.size(3));
|
||||||
|
|
||||||
|
AT_CHECK(dW > 0 && dH > 0,
|
||||||
|
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
|
||||||
|
|
||||||
|
AT_CHECK(
|
||||||
|
dilationW > 0 && dilationH > 0,
|
||||||
|
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
|
||||||
|
dilationH, dilationW);
|
||||||
|
|
||||||
|
int ndim = input.ndimension();
|
||||||
|
int dimf = 0;
|
||||||
|
int dimh = 1;
|
||||||
|
int dimw = 2;
|
||||||
|
|
||||||
|
if (ndim == 4) {
|
||||||
|
dimf++;
|
||||||
|
dimh++;
|
||||||
|
dimw++;
|
||||||
|
}
|
||||||
|
|
||||||
|
AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
|
||||||
|
ndim);
|
||||||
|
|
||||||
|
long nInputPlane = weight.size(1) * group;
|
||||||
|
long inputHeight = input.size(dimh);
|
||||||
|
long inputWidth = input.size(dimw);
|
||||||
|
long nOutputPlane = weight.size(0);
|
||||||
|
long outputHeight =
|
||||||
|
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||||
|
long outputWidth =
|
||||||
|
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||||
|
|
||||||
|
AT_CHECK(nInputPlane % deformable_group == 0,
|
||||||
|
"input channels must divide deformable group size");
|
||||||
|
|
||||||
|
if (outputWidth < 1 || outputHeight < 1)
|
||||||
|
AT_ERROR(
|
||||||
|
"Given input size: (%ld x %ld x %ld). "
|
||||||
|
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
|
||||||
|
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
|
||||||
|
outputWidth);
|
||||||
|
|
||||||
|
AT_CHECK(input.size(1) == nInputPlane,
|
||||||
|
"invalid number of input planes, expected: %d, but got: %d",
|
||||||
|
nInputPlane, input.size(1));
|
||||||
|
|
||||||
|
AT_CHECK((inputHeight >= kH && inputWidth >= kW),
|
||||||
|
"input image is smaller than kernel");
|
||||||
|
|
||||||
|
AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
|
||||||
|
"invalid spatial size of offset, expected height: %d width: %d, but "
|
||||||
|
"got height: %d width: %d",
|
||||||
|
outputHeight, outputWidth, offset.size(2), offset.size(3));
|
||||||
|
|
||||||
|
AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
|
||||||
|
"invalid number of channels of offset");
|
||||||
|
|
||||||
|
if (gradOutput != NULL) {
|
||||||
|
AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
|
||||||
|
"invalid number of gradOutput planes, expected: %d, but got: %d",
|
||||||
|
nOutputPlane, gradOutput->size(dimf));
|
||||||
|
|
||||||
|
AT_CHECK((gradOutput->size(dimh) == outputHeight &&
|
||||||
|
gradOutput->size(dimw) == outputWidth),
|
||||||
|
"invalid size of gradOutput, expected height: %d width: %d , but "
|
||||||
|
"got height: %d width: %d",
|
||||||
|
outputHeight, outputWidth, gradOutput->size(dimh),
|
||||||
|
gradOutput->size(dimw));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
|
||||||
|
at::Tensor offset, at::Tensor output,
|
||||||
|
at::Tensor columns, at::Tensor ones, int kW,
|
||||||
|
int kH, int dW, int dH, int padW, int padH,
|
||||||
|
int dilationW, int dilationH, int group,
|
||||||
|
int deformable_group, int im2col_step) {
|
||||||
|
// todo: resize columns to include im2col: done
|
||||||
|
// todo: add im2col_step as input
|
||||||
|
// todo: add new output buffer and transpose it to output (or directly
|
||||||
|
// transpose output) todo: possibly change data indexing because of
|
||||||
|
// parallel_imgs
|
||||||
|
|
||||||
|
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
|
||||||
|
dilationH, dilationW, group, deformable_group);
|
||||||
|
|
||||||
|
input = input.contiguous();
|
||||||
|
offset = offset.contiguous();
|
||||||
|
weight = weight.contiguous();
|
||||||
|
|
||||||
|
int batch = 1;
|
||||||
|
if (input.ndimension() == 3) {
|
||||||
|
// Force batch
|
||||||
|
batch = 0;
|
||||||
|
input.unsqueeze_(0);
|
||||||
|
offset.unsqueeze_(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo: assert batchsize dividable by im2col_step
|
||||||
|
|
||||||
|
long batchSize = input.size(0);
|
||||||
|
long nInputPlane = input.size(1);
|
||||||
|
long inputHeight = input.size(2);
|
||||||
|
long inputWidth = input.size(3);
|
||||||
|
|
||||||
|
long nOutputPlane = weight.size(0);
|
||||||
|
|
||||||
|
long outputWidth =
|
||||||
|
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||||
|
long outputHeight =
|
||||||
|
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||||
|
|
||||||
|
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
||||||
|
|
||||||
|
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
|
||||||
|
outputHeight, outputWidth});
|
||||||
|
columns = at::zeros(
|
||||||
|
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
||||||
|
input.options());
|
||||||
|
|
||||||
|
if (ones.ndimension() != 2 ||
|
||||||
|
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
|
||||||
|
ones = at::ones({outputHeight, outputWidth}, input.options());
|
||||||
|
}
|
||||||
|
|
||||||
|
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||||
|
inputHeight, inputWidth});
|
||||||
|
offset =
|
||||||
|
offset.view({batchSize / im2col_step, im2col_step,
|
||||||
|
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||||
|
|
||||||
|
at::Tensor output_buffer =
|
||||||
|
at::zeros({batchSize / im2col_step, nOutputPlane,
|
||||||
|
im2col_step * outputHeight, outputWidth},
|
||||||
|
output.options());
|
||||||
|
|
||||||
|
output_buffer = output_buffer.view(
|
||||||
|
{output_buffer.size(0), group, output_buffer.size(1) / group,
|
||||||
|
output_buffer.size(2), output_buffer.size(3)});
|
||||||
|
|
||||||
|
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
||||||
|
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
||||||
|
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||||
|
dilationW, im2col_step, deformable_group, columns);
|
||||||
|
|
||||||
|
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||||
|
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||||
|
weight.size(2), weight.size(3)});
|
||||||
|
|
||||||
|
for (int g = 0; g < group; g++) {
|
||||||
|
output_buffer[elt][g] = output_buffer[elt][g]
|
||||||
|
.flatten(1)
|
||||||
|
.addmm_(weight[g].flatten(1), columns[g])
|
||||||
|
.view_as(output_buffer[elt][g]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output_buffer = output_buffer.view(
|
||||||
|
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
|
||||||
|
output_buffer.size(3), output_buffer.size(4)});
|
||||||
|
|
||||||
|
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
|
||||||
|
im2col_step, outputHeight, outputWidth});
|
||||||
|
output_buffer.transpose_(1, 2);
|
||||||
|
output.copy_(output_buffer);
|
||||||
|
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
||||||
|
|
||||||
|
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||||
|
offset = offset.view(
|
||||||
|
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||||
|
|
||||||
|
if (batch == 0) {
|
||||||
|
output = output.view({nOutputPlane, outputHeight, outputWidth});
|
||||||
|
input = input.view({nInputPlane, inputHeight, inputWidth});
|
||||||
|
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
|
||||||
|
at::Tensor gradOutput, at::Tensor gradInput,
|
||||||
|
at::Tensor gradOffset, at::Tensor weight,
|
||||||
|
at::Tensor columns, int kW, int kH, int dW,
|
||||||
|
int dH, int padW, int padH, int dilationW,
|
||||||
|
int dilationH, int group,
|
||||||
|
int deformable_group, int im2col_step) {
|
||||||
|
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
|
||||||
|
dilationH, dilationW, group, deformable_group);
|
||||||
|
|
||||||
|
input = input.contiguous();
|
||||||
|
offset = offset.contiguous();
|
||||||
|
gradOutput = gradOutput.contiguous();
|
||||||
|
weight = weight.contiguous();
|
||||||
|
|
||||||
|
int batch = 1;
|
||||||
|
|
||||||
|
if (input.ndimension() == 3) {
|
||||||
|
// Force batch
|
||||||
|
batch = 0;
|
||||||
|
input = input.view({1, input.size(0), input.size(1), input.size(2)});
|
||||||
|
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
|
||||||
|
gradOutput = gradOutput.view(
|
||||||
|
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
||||||
|
}
|
||||||
|
|
||||||
|
long batchSize = input.size(0);
|
||||||
|
long nInputPlane = input.size(1);
|
||||||
|
long inputHeight = input.size(2);
|
||||||
|
long inputWidth = input.size(3);
|
||||||
|
|
||||||
|
long nOutputPlane = weight.size(0);
|
||||||
|
|
||||||
|
long outputWidth =
|
||||||
|
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||||
|
long outputHeight =
|
||||||
|
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||||
|
|
||||||
|
AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
|
||||||
|
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||||
|
columns = at::zeros(
|
||||||
|
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
||||||
|
input.options());
|
||||||
|
|
||||||
|
// change order of grad output
|
||||||
|
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
||||||
|
nOutputPlane, outputHeight, outputWidth});
|
||||||
|
gradOutput.transpose_(1, 2);
|
||||||
|
|
||||||
|
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||||
|
inputHeight, inputWidth});
|
||||||
|
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||||
|
inputHeight, inputWidth});
|
||||||
|
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
|
||||||
|
deformable_group * 2 * kH * kW, outputHeight,
|
||||||
|
outputWidth});
|
||||||
|
offset =
|
||||||
|
offset.view({batchSize / im2col_step, im2col_step,
|
||||||
|
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||||
|
|
||||||
|
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
||||||
|
// divide into groups
|
||||||
|
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||||
|
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||||
|
weight.size(2), weight.size(3)});
|
||||||
|
gradOutput = gradOutput.view(
|
||||||
|
{gradOutput.size(0), group, gradOutput.size(1) / group,
|
||||||
|
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
|
||||||
|
|
||||||
|
for (int g = 0; g < group; g++) {
|
||||||
|
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
||||||
|
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
columns =
|
||||||
|
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||||
|
gradOutput = gradOutput.view(
|
||||||
|
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
|
||||||
|
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
|
||||||
|
|
||||||
|
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
|
||||||
|
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
|
||||||
|
dilationH, dilationW, im2col_step, deformable_group,
|
||||||
|
gradOffset[elt]);
|
||||||
|
|
||||||
|
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
|
||||||
|
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||||
|
dilationW, im2col_step, deformable_group, gradInput[elt]);
|
||||||
|
}
|
||||||
|
|
||||||
|
gradOutput.transpose_(1, 2);
|
||||||
|
gradOutput =
|
||||||
|
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
||||||
|
|
||||||
|
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||||
|
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||||
|
gradOffset = gradOffset.view(
|
||||||
|
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||||
|
offset = offset.view(
|
||||||
|
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||||
|
|
||||||
|
if (batch == 0) {
|
||||||
|
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
||||||
|
input = input.view({nInputPlane, inputHeight, inputWidth});
|
||||||
|
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
|
||||||
|
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
||||||
|
gradOffset =
|
||||||
|
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int deform_conv_backward_parameters_cuda(
|
||||||
|
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
|
||||||
|
at::Tensor gradWeight, // at::Tensor gradBias,
|
||||||
|
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
|
||||||
|
int padW, int padH, int dilationW, int dilationH, int group,
|
||||||
|
int deformable_group, float scale, int im2col_step) {
|
||||||
|
// todo: transpose and reshape outGrad
|
||||||
|
// todo: reshape columns
|
||||||
|
// todo: add im2col_step as input
|
||||||
|
|
||||||
|
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
|
||||||
|
padW, dilationH, dilationW, group, deformable_group);
|
||||||
|
|
||||||
|
input = input.contiguous();
|
||||||
|
offset = offset.contiguous();
|
||||||
|
gradOutput = gradOutput.contiguous();
|
||||||
|
|
||||||
|
int batch = 1;
|
||||||
|
|
||||||
|
if (input.ndimension() == 3) {
|
||||||
|
// Force batch
|
||||||
|
batch = 0;
|
||||||
|
input = input.view(
|
||||||
|
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
|
||||||
|
gradOutput = gradOutput.view(
|
||||||
|
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
||||||
|
}
|
||||||
|
|
||||||
|
long batchSize = input.size(0);
|
||||||
|
long nInputPlane = input.size(1);
|
||||||
|
long inputHeight = input.size(2);
|
||||||
|
long inputWidth = input.size(3);
|
||||||
|
|
||||||
|
long nOutputPlane = gradWeight.size(0);
|
||||||
|
|
||||||
|
long outputWidth =
|
||||||
|
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||||
|
long outputHeight =
|
||||||
|
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||||
|
|
||||||
|
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
||||||
|
|
||||||
|
columns = at::zeros(
|
||||||
|
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
||||||
|
input.options());
|
||||||
|
|
||||||
|
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
||||||
|
nOutputPlane, outputHeight, outputWidth});
|
||||||
|
gradOutput.transpose_(1, 2);
|
||||||
|
|
||||||
|
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
|
||||||
|
gradOutputBuffer =
|
||||||
|
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
|
||||||
|
outputHeight, outputWidth});
|
||||||
|
gradOutputBuffer.copy_(gradOutput);
|
||||||
|
gradOutputBuffer =
|
||||||
|
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
|
||||||
|
im2col_step * outputHeight, outputWidth});
|
||||||
|
|
||||||
|
gradOutput.transpose_(1, 2);
|
||||||
|
gradOutput =
|
||||||
|
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
||||||
|
|
||||||
|
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||||
|
inputHeight, inputWidth});
|
||||||
|
offset =
|
||||||
|
offset.view({batchSize / im2col_step, im2col_step,
|
||||||
|
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||||
|
|
||||||
|
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
||||||
|
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
||||||
|
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||||
|
dilationW, im2col_step, deformable_group, columns);
|
||||||
|
|
||||||
|
// divide into group
|
||||||
|
gradOutputBuffer = gradOutputBuffer.view(
|
||||||
|
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
|
||||||
|
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
|
||||||
|
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||||
|
gradWeight =
|
||||||
|
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
|
||||||
|
gradWeight.size(2), gradWeight.size(3)});
|
||||||
|
|
||||||
|
for (int g = 0; g < group; g++) {
|
||||||
|
gradWeight[g] = gradWeight[g]
|
||||||
|
.flatten(1)
|
||||||
|
.addmm_(gradOutputBuffer[elt][g].flatten(1),
|
||||||
|
columns[g].transpose(1, 0), 1.0, scale)
|
||||||
|
.view_as(gradWeight[g]);
|
||||||
|
}
|
||||||
|
gradOutputBuffer = gradOutputBuffer.view(
|
||||||
|
{gradOutputBuffer.size(0),
|
||||||
|
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
|
||||||
|
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
|
||||||
|
columns =
|
||||||
|
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||||
|
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
|
||||||
|
gradWeight.size(2), gradWeight.size(3),
|
||||||
|
gradWeight.size(4)});
|
||||||
|
}
|
||||||
|
|
||||||
|
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||||
|
offset = offset.view(
|
||||||
|
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||||
|
|
||||||
|
if (batch == 0) {
|
||||||
|
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
||||||
|
input = input.view({nInputPlane, inputHeight, inputWidth});
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void modulated_deform_conv_cuda_forward(
|
||||||
|
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
||||||
|
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
|
||||||
|
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
|
||||||
|
const int pad_h, const int pad_w, const int dilation_h,
|
||||||
|
const int dilation_w, const int group, const int deformable_group,
|
||||||
|
const bool with_bias) {
|
||||||
|
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
||||||
|
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
||||||
|
|
||||||
|
const int batch = input.size(0);
|
||||||
|
const int channels = input.size(1);
|
||||||
|
const int height = input.size(2);
|
||||||
|
const int width = input.size(3);
|
||||||
|
|
||||||
|
const int channels_out = weight.size(0);
|
||||||
|
const int channels_kernel = weight.size(1);
|
||||||
|
const int kernel_h_ = weight.size(2);
|
||||||
|
const int kernel_w_ = weight.size(3);
|
||||||
|
|
||||||
|
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
||||||
|
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
||||||
|
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
||||||
|
if (channels != channels_kernel * group)
|
||||||
|
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
||||||
|
channels, channels_kernel * group);
|
||||||
|
|
||||||
|
const int height_out =
|
||||||
|
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
||||||
|
const int width_out =
|
||||||
|
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
||||||
|
|
||||||
|
if (ones.ndimension() != 2 ||
|
||||||
|
ones.size(0) * ones.size(1) < height_out * width_out) {
|
||||||
|
// Resize plane and fill with ones...
|
||||||
|
ones = at::ones({height_out, width_out}, input.options());
|
||||||
|
}
|
||||||
|
|
||||||
|
// resize output
|
||||||
|
output = output.view({batch, channels_out, height_out, width_out}).zero_();
|
||||||
|
// resize temporary columns
|
||||||
|
columns =
|
||||||
|
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
|
||||||
|
input.options());
|
||||||
|
|
||||||
|
output = output.view({output.size(0), group, output.size(1) / group,
|
||||||
|
output.size(2), output.size(3)});
|
||||||
|
|
||||||
|
for (int b = 0; b < batch; b++) {
|
||||||
|
modulated_deformable_im2col_cuda(
|
||||||
|
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
||||||
|
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||||
|
dilation_h, dilation_w, deformable_group, columns);
|
||||||
|
|
||||||
|
// divide into group
|
||||||
|
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||||
|
weight.size(2), weight.size(3)});
|
||||||
|
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||||
|
|
||||||
|
for (int g = 0; g < group; g++) {
|
||||||
|
output[b][g] = output[b][g]
|
||||||
|
.flatten(1)
|
||||||
|
.addmm_(weight[g].flatten(1), columns[g])
|
||||||
|
.view_as(output[b][g]);
|
||||||
|
}
|
||||||
|
|
||||||
|
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||||
|
weight.size(3), weight.size(4)});
|
||||||
|
columns =
|
||||||
|
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||||
|
}
|
||||||
|
|
||||||
|
output = output.view({output.size(0), output.size(1) * output.size(2),
|
||||||
|
output.size(3), output.size(4)});
|
||||||
|
|
||||||
|
if (with_bias) {
|
||||||
|
output += bias.view({1, bias.size(0), 1, 1});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void modulated_deform_conv_cuda_backward(
|
||||||
|
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
||||||
|
at::Tensor offset, at::Tensor mask, at::Tensor columns,
|
||||||
|
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
|
||||||
|
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
|
||||||
|
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
|
||||||
|
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
|
||||||
|
const bool with_bias) {
|
||||||
|
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
||||||
|
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
||||||
|
|
||||||
|
const int batch = input.size(0);
|
||||||
|
const int channels = input.size(1);
|
||||||
|
const int height = input.size(2);
|
||||||
|
const int width = input.size(3);
|
||||||
|
|
||||||
|
const int channels_kernel = weight.size(1);
|
||||||
|
const int kernel_h_ = weight.size(2);
|
||||||
|
const int kernel_w_ = weight.size(3);
|
||||||
|
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
||||||
|
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
||||||
|
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
||||||
|
if (channels != channels_kernel * group)
|
||||||
|
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
||||||
|
channels, channels_kernel * group);
|
||||||
|
|
||||||
|
const int height_out =
|
||||||
|
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
||||||
|
const int width_out =
|
||||||
|
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
||||||
|
|
||||||
|
if (ones.ndimension() != 2 ||
|
||||||
|
ones.size(0) * ones.size(1) < height_out * width_out) {
|
||||||
|
// Resize plane and fill with ones...
|
||||||
|
ones = at::ones({height_out, width_out}, input.options());
|
||||||
|
}
|
||||||
|
|
||||||
|
grad_input = grad_input.view({batch, channels, height, width});
|
||||||
|
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
|
||||||
|
input.options());
|
||||||
|
|
||||||
|
grad_output =
|
||||||
|
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
|
||||||
|
grad_output.size(2), grad_output.size(3)});
|
||||||
|
|
||||||
|
for (int b = 0; b < batch; b++) {
|
||||||
|
// divide int group
|
||||||
|
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||||
|
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||||
|
weight.size(2), weight.size(3)});
|
||||||
|
|
||||||
|
for (int g = 0; g < group; g++) {
|
||||||
|
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
||||||
|
grad_output[b][g].flatten(1), 0.0f, 1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
columns =
|
||||||
|
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||||
|
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||||
|
weight.size(3), weight.size(4)});
|
||||||
|
|
||||||
|
// gradient w.r.t. input coordinate data
|
||||||
|
modulated_deformable_col2im_coord_cuda(
|
||||||
|
columns, input[b], offset[b], mask[b], 1, channels, height, width,
|
||||||
|
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
|
||||||
|
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
|
||||||
|
grad_mask[b]);
|
||||||
|
// gradient w.r.t. input data
|
||||||
|
modulated_deformable_col2im_cuda(
|
||||||
|
columns, offset[b], mask[b], 1, channels, height, width, height_out,
|
||||||
|
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||||
|
dilation_h, dilation_w, deformable_group, grad_input[b]);
|
||||||
|
|
||||||
|
// gradient w.r.t. weight, dWeight should accumulate across the batch and
|
||||||
|
// group
|
||||||
|
modulated_deformable_im2col_cuda(
|
||||||
|
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
||||||
|
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||||
|
dilation_h, dilation_w, deformable_group, columns);
|
||||||
|
|
||||||
|
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||||
|
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
|
||||||
|
grad_weight.size(1), grad_weight.size(2),
|
||||||
|
grad_weight.size(3)});
|
||||||
|
if (with_bias)
|
||||||
|
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
|
||||||
|
|
||||||
|
for (int g = 0; g < group; g++) {
|
||||||
|
grad_weight[g] =
|
||||||
|
grad_weight[g]
|
||||||
|
.flatten(1)
|
||||||
|
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
|
||||||
|
.view_as(grad_weight[g]);
|
||||||
|
if (with_bias) {
|
||||||
|
grad_bias[g] =
|
||||||
|
grad_bias[g]
|
||||||
|
.view({-1, 1})
|
||||||
|
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
|
||||||
|
.view(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
columns =
|
||||||
|
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||||
|
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
|
||||||
|
grad_weight.size(2), grad_weight.size(3),
|
||||||
|
grad_weight.size(4)});
|
||||||
|
if (with_bias)
|
||||||
|
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
|
||||||
|
}
|
||||||
|
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
|
||||||
|
grad_output.size(2), grad_output.size(3),
|
||||||
|
grad_output.size(4)});
|
||||||
|
}
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
|
||||||
|
"deform forward (CUDA)");
|
||||||
|
m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
|
||||||
|
"deform_conv_backward_input (CUDA)");
|
||||||
|
m.def("deform_conv_backward_parameters_cuda",
|
||||||
|
&deform_conv_backward_parameters_cuda,
|
||||||
|
"deform_conv_backward_parameters (CUDA)");
|
||||||
|
m.def("modulated_deform_conv_cuda_forward",
|
||||||
|
&modulated_deform_conv_cuda_forward,
|
||||||
|
"modulated deform conv forward (CUDA)");
|
||||||
|
m.def("modulated_deform_conv_cuda_backward",
|
||||||
|
&modulated_deform_conv_cuda_backward,
|
||||||
|
"modulated deform conv backward (CUDA)");
|
||||||
|
}
|
88
codes/models/archs/discriminator_vgg_arch.py
Normal file
88
codes/models/archs/discriminator_vgg_arch.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
|
||||||
|
class Discriminator_VGG_128(nn.Module):
|
||||||
|
def __init__(self, in_nc, nf):
|
||||||
|
super(Discriminator_VGG_128, self).__init__()
|
||||||
|
# [64, 128, 128]
|
||||||
|
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||||
|
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
||||||
|
self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
|
||||||
|
# [64, 64, 64]
|
||||||
|
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
|
||||||
|
self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||||
|
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
|
||||||
|
self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
|
||||||
|
# [128, 32, 32]
|
||||||
|
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
|
||||||
|
self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||||
|
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
|
||||||
|
self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
|
||||||
|
# [256, 16, 16]
|
||||||
|
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
|
||||||
|
self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||||
|
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||||
|
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||||
|
# [512, 8, 8]
|
||||||
|
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
|
||||||
|
self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||||
|
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||||
|
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
|
||||||
|
|
||||||
|
self.linear1 = nn.Linear(512 * 4 * 4, 100)
|
||||||
|
self.linear2 = nn.Linear(100, 1)
|
||||||
|
|
||||||
|
# activation function
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fea = self.lrelu(self.conv0_0(x))
|
||||||
|
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
|
||||||
|
|
||||||
|
fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
|
||||||
|
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
|
||||||
|
|
||||||
|
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
|
||||||
|
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
|
||||||
|
|
||||||
|
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
|
||||||
|
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
|
||||||
|
|
||||||
|
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
|
||||||
|
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
|
||||||
|
|
||||||
|
fea = fea.view(fea.size(0), -1)
|
||||||
|
fea = self.lrelu(self.linear1(fea))
|
||||||
|
out = self.linear2(fea)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class VGGFeatureExtractor(nn.Module):
|
||||||
|
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
|
||||||
|
device=torch.device('cpu')):
|
||||||
|
super(VGGFeatureExtractor, self).__init__()
|
||||||
|
self.use_input_norm = use_input_norm
|
||||||
|
if use_bn:
|
||||||
|
model = torchvision.models.vgg19_bn(pretrained=True)
|
||||||
|
else:
|
||||||
|
model = torchvision.models.vgg19(pretrained=True)
|
||||||
|
if self.use_input_norm:
|
||||||
|
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
|
||||||
|
# [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
|
||||||
|
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
|
||||||
|
# [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
|
||||||
|
self.register_buffer('mean', mean)
|
||||||
|
self.register_buffer('std', std)
|
||||||
|
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
|
||||||
|
# No need to BP to variable
|
||||||
|
for k, v in self.features.named_parameters():
|
||||||
|
v.requires_grad = False
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Assume input range is [0, 1]
|
||||||
|
if self.use_input_norm:
|
||||||
|
x = (x - self.mean) / self.std
|
||||||
|
output = self.features(x)
|
||||||
|
return output
|
116
codes/models/base_model.py
Normal file
116
codes/models/base_model.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
import os
|
||||||
|
from collections import OrderedDict
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel():
|
||||||
|
def __init__(self, opt):
|
||||||
|
self.opt = opt
|
||||||
|
self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
|
||||||
|
self.is_train = opt['is_train']
|
||||||
|
self.schedulers = []
|
||||||
|
self.optimizers = []
|
||||||
|
|
||||||
|
def feed_data(self, data):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def optimize_parameters(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_current_visuals(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_current_losses(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def print_network(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save(self, label):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _set_lr(self, lr_groups_l):
|
||||||
|
"""Set learning rate for warmup
|
||||||
|
lr_groups_l: list for lr_groups. each for a optimizer"""
|
||||||
|
for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
|
||||||
|
for param_group, lr in zip(optimizer.param_groups, lr_groups):
|
||||||
|
param_group['lr'] = lr
|
||||||
|
|
||||||
|
def _get_init_lr(self):
|
||||||
|
"""Get the initial lr, which is set by the scheduler"""
|
||||||
|
init_lr_groups_l = []
|
||||||
|
for optimizer in self.optimizers:
|
||||||
|
init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
|
||||||
|
return init_lr_groups_l
|
||||||
|
|
||||||
|
def update_learning_rate(self, cur_iter, warmup_iter=-1):
|
||||||
|
for scheduler in self.schedulers:
|
||||||
|
scheduler.step()
|
||||||
|
# set up warm-up learning rate
|
||||||
|
if cur_iter < warmup_iter:
|
||||||
|
# get initial lr for each group
|
||||||
|
init_lr_g_l = self._get_init_lr()
|
||||||
|
# modify warming-up learning rates
|
||||||
|
warm_up_lr_l = []
|
||||||
|
for init_lr_g in init_lr_g_l:
|
||||||
|
warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
|
||||||
|
# set learning rate
|
||||||
|
self._set_lr(warm_up_lr_l)
|
||||||
|
|
||||||
|
def get_current_learning_rate(self):
|
||||||
|
return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
|
||||||
|
|
||||||
|
def get_network_description(self, network):
|
||||||
|
"""Get the string and total parameters of the network"""
|
||||||
|
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
||||||
|
network = network.module
|
||||||
|
return str(network), sum(map(lambda x: x.numel(), network.parameters()))
|
||||||
|
|
||||||
|
def save_network(self, network, network_label, iter_label):
|
||||||
|
save_filename = '{}_{}.pth'.format(iter_label, network_label)
|
||||||
|
save_path = os.path.join(self.opt['path']['models'], save_filename)
|
||||||
|
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
||||||
|
network = network.module
|
||||||
|
state_dict = network.state_dict()
|
||||||
|
for key, param in state_dict.items():
|
||||||
|
state_dict[key] = param.cpu()
|
||||||
|
torch.save(state_dict, save_path)
|
||||||
|
|
||||||
|
def load_network(self, load_path, network, strict=True):
|
||||||
|
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
||||||
|
network = network.module
|
||||||
|
load_net = torch.load(load_path)
|
||||||
|
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
||||||
|
for k, v in load_net.items():
|
||||||
|
if k.startswith('module.'):
|
||||||
|
load_net_clean[k[7:]] = v
|
||||||
|
else:
|
||||||
|
load_net_clean[k] = v
|
||||||
|
network.load_state_dict(load_net_clean, strict=strict)
|
||||||
|
|
||||||
|
def save_training_state(self, epoch, iter_step):
|
||||||
|
"""Save training state during training, which will be used for resuming"""
|
||||||
|
state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
|
||||||
|
for s in self.schedulers:
|
||||||
|
state['schedulers'].append(s.state_dict())
|
||||||
|
for o in self.optimizers:
|
||||||
|
state['optimizers'].append(o.state_dict())
|
||||||
|
save_filename = '{}.state'.format(iter_step)
|
||||||
|
save_path = os.path.join(self.opt['path']['training_state'], save_filename)
|
||||||
|
torch.save(state, save_path)
|
||||||
|
|
||||||
|
def resume_training(self, resume_state):
|
||||||
|
"""Resume the optimizers and schedulers for training"""
|
||||||
|
resume_optimizers = resume_state['optimizers']
|
||||||
|
resume_schedulers = resume_state['schedulers']
|
||||||
|
assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
|
||||||
|
assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
|
||||||
|
for i, o in enumerate(resume_optimizers):
|
||||||
|
self.optimizers[i].load_state_dict(o)
|
||||||
|
for i, s in enumerate(resume_schedulers):
|
||||||
|
self.schedulers[i].load_state_dict(s)
|
74
codes/models/loss.py
Normal file
74
codes/models/loss.py
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class CharbonnierLoss(nn.Module):
|
||||||
|
"""Charbonnier Loss (L1)"""
|
||||||
|
|
||||||
|
def __init__(self, eps=1e-6):
|
||||||
|
super(CharbonnierLoss, self).__init__()
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
diff = x - y
|
||||||
|
loss = torch.sum(torch.sqrt(diff * diff + self.eps))
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
# Define GAN loss: [vanilla | lsgan | wgan-gp]
|
||||||
|
class GANLoss(nn.Module):
|
||||||
|
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
|
||||||
|
super(GANLoss, self).__init__()
|
||||||
|
self.gan_type = gan_type.lower()
|
||||||
|
self.real_label_val = real_label_val
|
||||||
|
self.fake_label_val = fake_label_val
|
||||||
|
|
||||||
|
if self.gan_type == 'gan' or self.gan_type == 'ragan':
|
||||||
|
self.loss = nn.BCEWithLogitsLoss()
|
||||||
|
elif self.gan_type == 'lsgan':
|
||||||
|
self.loss = nn.MSELoss()
|
||||||
|
elif self.gan_type == 'wgan-gp':
|
||||||
|
|
||||||
|
def wgan_loss(input, target):
|
||||||
|
# target is boolean
|
||||||
|
return -1 * input.mean() if target else input.mean()
|
||||||
|
|
||||||
|
self.loss = wgan_loss
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
|
||||||
|
|
||||||
|
def get_target_label(self, input, target_is_real):
|
||||||
|
if self.gan_type == 'wgan-gp':
|
||||||
|
return target_is_real
|
||||||
|
if target_is_real:
|
||||||
|
return torch.empty_like(input).fill_(self.real_label_val)
|
||||||
|
else:
|
||||||
|
return torch.empty_like(input).fill_(self.fake_label_val)
|
||||||
|
|
||||||
|
def forward(self, input, target_is_real):
|
||||||
|
target_label = self.get_target_label(input, target_is_real)
|
||||||
|
loss = self.loss(input, target_label)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class GradientPenaltyLoss(nn.Module):
|
||||||
|
def __init__(self, device=torch.device('cpu')):
|
||||||
|
super(GradientPenaltyLoss, self).__init__()
|
||||||
|
self.register_buffer('grad_outputs', torch.Tensor())
|
||||||
|
self.grad_outputs = self.grad_outputs.to(device)
|
||||||
|
|
||||||
|
def get_grad_outputs(self, input):
|
||||||
|
if self.grad_outputs.size() != input.size():
|
||||||
|
self.grad_outputs.resize_(input.size()).fill_(1.0)
|
||||||
|
return self.grad_outputs
|
||||||
|
|
||||||
|
def forward(self, interp, interp_crit):
|
||||||
|
grad_outputs = self.get_grad_outputs(interp_crit)
|
||||||
|
grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
|
||||||
|
grad_outputs=grad_outputs, create_graph=True,
|
||||||
|
retain_graph=True, only_inputs=True)[0]
|
||||||
|
grad_interp = grad_interp.view(grad_interp.size(0), -1)
|
||||||
|
grad_interp_norm = grad_interp.norm(2, dim=1)
|
||||||
|
|
||||||
|
loss = ((grad_interp_norm - 1)**2).mean()
|
||||||
|
return loss
|
144
codes/models/lr_scheduler.py
Normal file
144
codes/models/lr_scheduler.py
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
import math
|
||||||
|
from collections import Counter
|
||||||
|
from collections import defaultdict
|
||||||
|
import torch
|
||||||
|
from torch.optim.lr_scheduler import _LRScheduler
|
||||||
|
|
||||||
|
|
||||||
|
class MultiStepLR_Restart(_LRScheduler):
|
||||||
|
def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
|
||||||
|
clear_state=False, last_epoch=-1):
|
||||||
|
self.milestones = Counter(milestones)
|
||||||
|
self.gamma = gamma
|
||||||
|
self.clear_state = clear_state
|
||||||
|
self.restarts = restarts if restarts else [0]
|
||||||
|
self.restarts = [v + 1 for v in self.restarts]
|
||||||
|
self.restart_weights = weights if weights else [1]
|
||||||
|
assert len(self.restarts) == len(
|
||||||
|
self.restart_weights), 'restarts and their weights do not match.'
|
||||||
|
super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
|
def get_lr(self):
|
||||||
|
if self.last_epoch in self.restarts:
|
||||||
|
if self.clear_state:
|
||||||
|
self.optimizer.state = defaultdict(dict)
|
||||||
|
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
|
||||||
|
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
|
||||||
|
if self.last_epoch not in self.milestones:
|
||||||
|
return [group['lr'] for group in self.optimizer.param_groups]
|
||||||
|
return [
|
||||||
|
group['lr'] * self.gamma**self.milestones[self.last_epoch]
|
||||||
|
for group in self.optimizer.param_groups
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class CosineAnnealingLR_Restart(_LRScheduler):
|
||||||
|
def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
|
||||||
|
self.T_period = T_period
|
||||||
|
self.T_max = self.T_period[0] # current T period
|
||||||
|
self.eta_min = eta_min
|
||||||
|
self.restarts = restarts if restarts else [0]
|
||||||
|
self.restarts = [v + 1 for v in self.restarts]
|
||||||
|
self.restart_weights = weights if weights else [1]
|
||||||
|
self.last_restart = 0
|
||||||
|
assert len(self.restarts) == len(
|
||||||
|
self.restart_weights), 'restarts and their weights do not match.'
|
||||||
|
super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
|
def get_lr(self):
|
||||||
|
if self.last_epoch == 0:
|
||||||
|
return self.base_lrs
|
||||||
|
elif self.last_epoch in self.restarts:
|
||||||
|
self.last_restart = self.last_epoch
|
||||||
|
self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
|
||||||
|
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
|
||||||
|
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
|
||||||
|
elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
|
||||||
|
return [
|
||||||
|
group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
|
||||||
|
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
|
||||||
|
]
|
||||||
|
return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
|
||||||
|
(1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
|
||||||
|
(group['lr'] - self.eta_min) + self.eta_min
|
||||||
|
for group in self.optimizer.param_groups]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
|
||||||
|
betas=(0.9, 0.99))
|
||||||
|
##############################
|
||||||
|
# MultiStepLR_Restart
|
||||||
|
##############################
|
||||||
|
## Original
|
||||||
|
lr_steps = [200000, 400000, 600000, 800000]
|
||||||
|
restarts = None
|
||||||
|
restart_weights = None
|
||||||
|
|
||||||
|
## two
|
||||||
|
lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
|
||||||
|
restarts = [500000]
|
||||||
|
restart_weights = [1]
|
||||||
|
|
||||||
|
## four
|
||||||
|
lr_steps = [
|
||||||
|
50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
|
||||||
|
600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
|
||||||
|
]
|
||||||
|
restarts = [250000, 500000, 750000]
|
||||||
|
restart_weights = [1, 1, 1]
|
||||||
|
|
||||||
|
scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
|
||||||
|
clear_state=False)
|
||||||
|
|
||||||
|
##############################
|
||||||
|
# Cosine Annealing Restart
|
||||||
|
##############################
|
||||||
|
## two
|
||||||
|
T_period = [500000, 500000]
|
||||||
|
restarts = [500000]
|
||||||
|
restart_weights = [1]
|
||||||
|
|
||||||
|
## four
|
||||||
|
T_period = [250000, 250000, 250000, 250000]
|
||||||
|
restarts = [250000, 500000, 750000]
|
||||||
|
restart_weights = [1, 1, 1]
|
||||||
|
|
||||||
|
scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
|
||||||
|
weights=restart_weights)
|
||||||
|
|
||||||
|
##############################
|
||||||
|
# Draw figure
|
||||||
|
##############################
|
||||||
|
N_iter = 1000000
|
||||||
|
lr_l = list(range(N_iter))
|
||||||
|
for i in range(N_iter):
|
||||||
|
scheduler.step()
|
||||||
|
current_lr = optimizer.param_groups[0]['lr']
|
||||||
|
lr_l[i] = current_lr
|
||||||
|
|
||||||
|
import matplotlib as mpl
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
import matplotlib.ticker as mtick
|
||||||
|
mpl.style.use('default')
|
||||||
|
import seaborn
|
||||||
|
seaborn.set(style='whitegrid')
|
||||||
|
seaborn.set_context('paper')
|
||||||
|
|
||||||
|
plt.figure(1)
|
||||||
|
plt.subplot(111)
|
||||||
|
plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
|
||||||
|
plt.title('Title', fontsize=16, color='k')
|
||||||
|
plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
|
||||||
|
legend = plt.legend(loc='upper right', shadow=False)
|
||||||
|
ax = plt.gca()
|
||||||
|
labels = ax.get_xticks().tolist()
|
||||||
|
for k, v in enumerate(labels):
|
||||||
|
labels[k] = str(int(v / 1000)) + 'K'
|
||||||
|
ax.set_xticklabels(labels)
|
||||||
|
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
|
||||||
|
|
||||||
|
ax.set_ylabel('Learning rate')
|
||||||
|
ax.set_xlabel('Iteration')
|
||||||
|
fig = plt.gcf()
|
||||||
|
plt.show()
|
57
codes/models/networks.py
Normal file
57
codes/models/networks.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
import torch
|
||||||
|
import models.archs.SRResNet_arch as SRResNet_arch
|
||||||
|
import models.archs.discriminator_vgg_arch as SRGAN_arch
|
||||||
|
import models.archs.RRDBNet_arch as RRDBNet_arch
|
||||||
|
import models.archs.EDVR_arch as EDVR_arch
|
||||||
|
|
||||||
|
|
||||||
|
# Generator
|
||||||
|
def define_G(opt):
|
||||||
|
opt_net = opt['network_G']
|
||||||
|
which_model = opt_net['which_model_G']
|
||||||
|
|
||||||
|
# image restoration
|
||||||
|
if which_model == 'MSRResNet':
|
||||||
|
netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
|
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||||
|
elif which_model == 'RRDBNet':
|
||||||
|
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
|
nf=opt_net['nf'], nb=opt_net['nb'])
|
||||||
|
# video restoration
|
||||||
|
elif which_model == 'EDVR':
|
||||||
|
netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
|
||||||
|
groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
|
||||||
|
back_RBs=opt_net['back_RBs'], center=opt_net['center'],
|
||||||
|
predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
|
||||||
|
w_TSA=opt_net['w_TSA'])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||||
|
|
||||||
|
return netG
|
||||||
|
|
||||||
|
|
||||||
|
# Discriminator
|
||||||
|
def define_D(opt):
|
||||||
|
opt_net = opt['network_D']
|
||||||
|
which_model = opt_net['which_model_D']
|
||||||
|
|
||||||
|
if which_model == 'discriminator_vgg_128':
|
||||||
|
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||||
|
return netD
|
||||||
|
|
||||||
|
|
||||||
|
# Define network used for perceptual loss
|
||||||
|
def define_F(opt, use_bn=False):
|
||||||
|
gpu_ids = opt['gpu_ids']
|
||||||
|
device = torch.device('cuda' if gpu_ids else 'cpu')
|
||||||
|
# PyTorch pretrained VGG19-54, before ReLU.
|
||||||
|
if use_bn:
|
||||||
|
feature_layer = 49
|
||||||
|
else:
|
||||||
|
feature_layer = 34
|
||||||
|
netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
|
||||||
|
use_input_norm=True, device=device)
|
||||||
|
netF.eval() # No need to train
|
||||||
|
return netF
|
0
codes/options/__init__.py
Normal file
0
codes/options/__init__.py
Normal file
116
codes/options/options.py
Normal file
116
codes/options/options.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import logging
|
||||||
|
import yaml
|
||||||
|
from utils.util import OrderedYaml
|
||||||
|
Loader, Dumper = OrderedYaml()
|
||||||
|
|
||||||
|
|
||||||
|
def parse(opt_path, is_train=True):
|
||||||
|
with open(opt_path, mode='r') as f:
|
||||||
|
opt = yaml.load(f, Loader=Loader)
|
||||||
|
# export CUDA_VISIBLE_DEVICES
|
||||||
|
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
||||||
|
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
||||||
|
|
||||||
|
opt['is_train'] = is_train
|
||||||
|
if opt['distortion'] == 'sr':
|
||||||
|
scale = opt['scale']
|
||||||
|
|
||||||
|
# datasets
|
||||||
|
for phase, dataset in opt['datasets'].items():
|
||||||
|
phase = phase.split('_')[0]
|
||||||
|
dataset['phase'] = phase
|
||||||
|
if opt['distortion'] == 'sr':
|
||||||
|
dataset['scale'] = scale
|
||||||
|
is_lmdb = False
|
||||||
|
if dataset.get('dataroot_GT', None) is not None:
|
||||||
|
dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
|
||||||
|
if dataset['dataroot_GT'].endswith('lmdb'):
|
||||||
|
is_lmdb = True
|
||||||
|
if dataset.get('dataroot_LQ', None) is not None:
|
||||||
|
dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
|
||||||
|
if dataset['dataroot_LQ'].endswith('lmdb'):
|
||||||
|
is_lmdb = True
|
||||||
|
dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
|
||||||
|
if dataset['mode'].endswith('mc'): # for memcached
|
||||||
|
dataset['data_type'] = 'mc'
|
||||||
|
dataset['mode'] = dataset['mode'].replace('_mc', '')
|
||||||
|
|
||||||
|
# path
|
||||||
|
for key, path in opt['path'].items():
|
||||||
|
if path and key in opt['path'] and key != 'strict_load':
|
||||||
|
opt['path'][key] = osp.expanduser(path)
|
||||||
|
opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
|
||||||
|
if is_train:
|
||||||
|
experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
|
||||||
|
opt['path']['experiments_root'] = experiments_root
|
||||||
|
opt['path']['models'] = osp.join(experiments_root, 'models')
|
||||||
|
opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
|
||||||
|
opt['path']['log'] = experiments_root
|
||||||
|
opt['path']['val_images'] = osp.join(experiments_root, 'val_images')
|
||||||
|
|
||||||
|
# change some options for debug mode
|
||||||
|
if 'debug' in opt['name']:
|
||||||
|
opt['train']['val_freq'] = 8
|
||||||
|
opt['logger']['print_freq'] = 1
|
||||||
|
opt['logger']['save_checkpoint_freq'] = 8
|
||||||
|
else: # test
|
||||||
|
results_root = osp.join(opt['path']['root'], 'results', opt['name'])
|
||||||
|
opt['path']['results_root'] = results_root
|
||||||
|
opt['path']['log'] = results_root
|
||||||
|
|
||||||
|
# network
|
||||||
|
if opt['distortion'] == 'sr':
|
||||||
|
opt['network_G']['scale'] = scale
|
||||||
|
|
||||||
|
return opt
|
||||||
|
|
||||||
|
|
||||||
|
def dict2str(opt, indent_l=1):
|
||||||
|
'''dict to string for logger'''
|
||||||
|
msg = ''
|
||||||
|
for k, v in opt.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
msg += ' ' * (indent_l * 2) + k + ':[\n'
|
||||||
|
msg += dict2str(v, indent_l + 1)
|
||||||
|
msg += ' ' * (indent_l * 2) + ']\n'
|
||||||
|
else:
|
||||||
|
msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
class NoneDict(dict):
|
||||||
|
def __missing__(self, key):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# convert to NoneDict, which return None for missing key.
|
||||||
|
def dict_to_nonedict(opt):
|
||||||
|
if isinstance(opt, dict):
|
||||||
|
new_opt = dict()
|
||||||
|
for key, sub_opt in opt.items():
|
||||||
|
new_opt[key] = dict_to_nonedict(sub_opt)
|
||||||
|
return NoneDict(**new_opt)
|
||||||
|
elif isinstance(opt, list):
|
||||||
|
return [dict_to_nonedict(sub_opt) for sub_opt in opt]
|
||||||
|
else:
|
||||||
|
return opt
|
||||||
|
|
||||||
|
|
||||||
|
def check_resume(opt, resume_iter):
|
||||||
|
'''Check resume states and pretrain_model paths'''
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
if opt['path']['resume_state']:
|
||||||
|
if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
|
||||||
|
'pretrain_model_D', None) is not None:
|
||||||
|
logger.warning('pretrain_model path will be ignored when resuming training.')
|
||||||
|
|
||||||
|
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
|
||||||
|
'{}_G.pth'.format(resume_iter))
|
||||||
|
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
|
||||||
|
if 'gan' in opt['model']:
|
||||||
|
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
|
||||||
|
'{}_D.pth'.format(resume_iter))
|
||||||
|
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
|
32
codes/options/test/test_ESRGAN.yml
Normal file
32
codes/options/test/test_ESRGAN.yml
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
name: RRDB_ESRGAN_x4
|
||||||
|
suffix: ~ # add suffix to saved images
|
||||||
|
model: sr
|
||||||
|
distortion: sr
|
||||||
|
scale: 4
|
||||||
|
crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
|
||||||
|
gpu_ids: [0]
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
test_1: # the 1st test dataset
|
||||||
|
name: set5
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/val_set5/Set5
|
||||||
|
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
|
||||||
|
test_2: # the 2st test dataset
|
||||||
|
name: set14
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/val_set14/Set14
|
||||||
|
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
|
||||||
|
|
||||||
|
#### network structures
|
||||||
|
network_G:
|
||||||
|
which_model_G: RRDBNet
|
||||||
|
in_nc: 3
|
||||||
|
out_nc: 3
|
||||||
|
nf: 64
|
||||||
|
nb: 23
|
||||||
|
upscale: 4
|
||||||
|
|
||||||
|
#### path
|
||||||
|
path:
|
||||||
|
pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth
|
32
codes/options/test/test_SRGAN.yml
Normal file
32
codes/options/test/test_SRGAN.yml
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
name: MSRGANx4
|
||||||
|
suffix: ~ # add suffix to saved images
|
||||||
|
model: sr
|
||||||
|
distortion: sr
|
||||||
|
scale: 4
|
||||||
|
crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
|
||||||
|
gpu_ids: [0]
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
test_1: # the 1st test dataset
|
||||||
|
name: set5
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/val_set5/Set5
|
||||||
|
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
|
||||||
|
test_2: # the 2st test dataset
|
||||||
|
name: set14
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/val_set14/Set14
|
||||||
|
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
|
||||||
|
|
||||||
|
#### network structures
|
||||||
|
network_G:
|
||||||
|
which_model_G: MSRResNet
|
||||||
|
in_nc: 3
|
||||||
|
out_nc: 3
|
||||||
|
nf: 64
|
||||||
|
nb: 16
|
||||||
|
upscale: 4
|
||||||
|
|
||||||
|
#### path
|
||||||
|
path:
|
||||||
|
pretrain_model_G: ../experiments/pretrained_models/MSRGANx4.pth
|
48
codes/options/test/test_SRResNet.yml
Normal file
48
codes/options/test/test_SRResNet.yml
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
name: MSRResNetx4
|
||||||
|
suffix: ~ # add suffix to saved images
|
||||||
|
model: sr
|
||||||
|
distortion: sr
|
||||||
|
scale: 4
|
||||||
|
crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
|
||||||
|
gpu_ids: [0]
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
test_1: # the 1st test dataset
|
||||||
|
name: set5
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/val_set5/Set5
|
||||||
|
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
|
||||||
|
test_2: # the 2st test dataset
|
||||||
|
name: set14
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/val_set14/Set14
|
||||||
|
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
|
||||||
|
test_3:
|
||||||
|
name: bsd100
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/BSD/BSDS100
|
||||||
|
dataroot_LQ: ../datasets/BSD/BSDS100_bicLRx4
|
||||||
|
test_4:
|
||||||
|
name: urban100
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/urban100
|
||||||
|
dataroot_LQ: ../datasets/urban100_bicLRx4
|
||||||
|
test_5:
|
||||||
|
name: div2k100
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/DIV2K100/DIV2K_valid_HR
|
||||||
|
dataroot_LQ: ../datasets/DIV2K100/DIV2K_valid_bicLRx4
|
||||||
|
|
||||||
|
|
||||||
|
#### network structures
|
||||||
|
network_G:
|
||||||
|
which_model_G: MSRResNet
|
||||||
|
in_nc: 3
|
||||||
|
out_nc: 3
|
||||||
|
nf: 64
|
||||||
|
nb: 16
|
||||||
|
upscale: 4
|
||||||
|
|
||||||
|
#### path
|
||||||
|
path:
|
||||||
|
pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
|
80
codes/options/train/train_EDVR_M.yml
Normal file
80
codes/options/train/train_EDVR_M.yml
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
#### general settings
|
||||||
|
name: 002_EDVR_EDVRwoTSAIni_lr4e-4_600k_REDS_LrCAR4S_fixTSA50k_new
|
||||||
|
use_tb_logger: true
|
||||||
|
model: video_base
|
||||||
|
distortion: sr
|
||||||
|
scale: 4
|
||||||
|
gpu_ids: [0,1,2,3,4,5,6,7]
|
||||||
|
|
||||||
|
#### datasets
|
||||||
|
datasets:
|
||||||
|
train:
|
||||||
|
name: REDS
|
||||||
|
mode: REDS
|
||||||
|
interval_list: [1]
|
||||||
|
random_reverse: false
|
||||||
|
border_mode: false
|
||||||
|
dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
|
||||||
|
dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
|
||||||
|
cache_keys: ~
|
||||||
|
|
||||||
|
N_frames: 5
|
||||||
|
use_shuffle: true
|
||||||
|
n_workers: 3 # per GPU
|
||||||
|
batch_size: 32
|
||||||
|
GT_size: 256
|
||||||
|
LQ_size: 64
|
||||||
|
use_flip: true
|
||||||
|
use_rot: true
|
||||||
|
color: RGB
|
||||||
|
val:
|
||||||
|
name: REDS4
|
||||||
|
mode: video_test
|
||||||
|
dataroot_GT: ../datasets/REDS4/GT
|
||||||
|
dataroot_LQ: ../datasets/REDS4/sharp_bicubic
|
||||||
|
cache_data: True
|
||||||
|
N_frames: 5
|
||||||
|
padding: new_info
|
||||||
|
|
||||||
|
#### network structures
|
||||||
|
network_G:
|
||||||
|
which_model_G: EDVR
|
||||||
|
nf: 64
|
||||||
|
nframes: 5
|
||||||
|
groups: 8
|
||||||
|
front_RBs: 5
|
||||||
|
back_RBs: 10
|
||||||
|
predeblur: false
|
||||||
|
HR_in: false
|
||||||
|
w_TSA: true
|
||||||
|
|
||||||
|
#### path
|
||||||
|
path:
|
||||||
|
pretrain_model_G: ../experiments/pretrained_models/EDVR_REDS_SR_M_woTSA.pth
|
||||||
|
strict_load: false
|
||||||
|
resume_state: ~
|
||||||
|
|
||||||
|
#### training settings: learning rate scheme, loss
|
||||||
|
train:
|
||||||
|
lr_G: !!float 4e-4
|
||||||
|
lr_scheme: CosineAnnealingLR_Restart
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.99
|
||||||
|
niter: 600000
|
||||||
|
ft_tsa_only: 50000
|
||||||
|
warmup_iter: -1 # -1: no warm up
|
||||||
|
T_period: [50000, 100000, 150000, 150000, 150000]
|
||||||
|
restarts: [50000, 150000, 300000, 450000]
|
||||||
|
restart_weights: [1, 1, 1, 1]
|
||||||
|
eta_min: !!float 1e-7
|
||||||
|
|
||||||
|
pixel_criterion: cb
|
||||||
|
pixel_weight: 1.0
|
||||||
|
val_freq: !!float 5e3
|
||||||
|
|
||||||
|
manual_seed: 0
|
||||||
|
|
||||||
|
#### logger
|
||||||
|
logger:
|
||||||
|
print_freq: 100
|
||||||
|
save_checkpoint_freq: !!float 5e3
|
71
codes/options/train/train_EDVR_woTSA_M.yml
Normal file
71
codes/options/train/train_EDVR_woTSA_M.yml
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
#### general settings
|
||||||
|
name: 001_EDVRwoTSA_scratch_lr4e-4_600k_REDS_LrCAR4S
|
||||||
|
use_tb_logger: true
|
||||||
|
model: video_base
|
||||||
|
distortion: sr
|
||||||
|
scale: 4
|
||||||
|
gpu_ids: [0,1,2,3,4,5,6,7]
|
||||||
|
|
||||||
|
#### datasets
|
||||||
|
datasets:
|
||||||
|
train:
|
||||||
|
name: REDS
|
||||||
|
mode: REDS
|
||||||
|
interval_list: [1]
|
||||||
|
random_reverse: false
|
||||||
|
border_mode: false
|
||||||
|
dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
|
||||||
|
dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
|
||||||
|
cache_keys: ~
|
||||||
|
|
||||||
|
N_frames: 5
|
||||||
|
use_shuffle: true
|
||||||
|
n_workers: 3 # per GPU
|
||||||
|
batch_size: 32
|
||||||
|
GT_size: 256
|
||||||
|
LQ_size: 64
|
||||||
|
use_flip: true
|
||||||
|
use_rot: true
|
||||||
|
color: RGB
|
||||||
|
|
||||||
|
#### network structures
|
||||||
|
network_G:
|
||||||
|
which_model_G: EDVR
|
||||||
|
nf: 64
|
||||||
|
nframes: 5
|
||||||
|
groups: 8
|
||||||
|
front_RBs: 5
|
||||||
|
back_RBs: 10
|
||||||
|
predeblur: false
|
||||||
|
HR_in: false
|
||||||
|
w_TSA: false
|
||||||
|
|
||||||
|
#### path
|
||||||
|
path:
|
||||||
|
pretrain_model_G: ~
|
||||||
|
strict_load: true
|
||||||
|
resume_state: ~
|
||||||
|
|
||||||
|
#### training settings: learning rate scheme, loss
|
||||||
|
train:
|
||||||
|
lr_G: !!float 4e-4
|
||||||
|
lr_scheme: CosineAnnealingLR_Restart
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.99
|
||||||
|
niter: 600000
|
||||||
|
warmup_iter: -1 # -1: no warm up
|
||||||
|
T_period: [150000, 150000, 150000, 150000]
|
||||||
|
restarts: [150000, 300000, 450000]
|
||||||
|
restart_weights: [1, 1, 1]
|
||||||
|
eta_min: !!float 1e-7
|
||||||
|
|
||||||
|
pixel_criterion: cb
|
||||||
|
pixel_weight: 1.0
|
||||||
|
val_freq: !!float 5e3
|
||||||
|
|
||||||
|
manual_seed: 0
|
||||||
|
|
||||||
|
#### logger
|
||||||
|
logger:
|
||||||
|
print_freq: 100
|
||||||
|
save_checkpoint_freq: !!float 5e3
|
81
codes/options/train/train_ESRGAN.yml
Normal file
81
codes/options/train/train_ESRGAN.yml
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
#### general settings
|
||||||
|
name: 003_RRDB_ESRGANx4_DIV2K
|
||||||
|
use_tb_logger: true
|
||||||
|
model: srgan
|
||||||
|
distortion: sr
|
||||||
|
scale: 4
|
||||||
|
gpu_ids: [2]
|
||||||
|
|
||||||
|
#### datasets
|
||||||
|
datasets:
|
||||||
|
train:
|
||||||
|
name: DIV2K
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
|
||||||
|
dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
|
||||||
|
|
||||||
|
use_shuffle: true
|
||||||
|
n_workers: 6 # per GPU
|
||||||
|
batch_size: 16
|
||||||
|
GT_size: 128
|
||||||
|
use_flip: true
|
||||||
|
use_rot: true
|
||||||
|
color: RGB
|
||||||
|
val:
|
||||||
|
name: val_set14
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/val_set14/Set14
|
||||||
|
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
|
||||||
|
|
||||||
|
#### network structures
|
||||||
|
network_G:
|
||||||
|
which_model_G: RRDBNet
|
||||||
|
in_nc: 3
|
||||||
|
out_nc: 3
|
||||||
|
nf: 64
|
||||||
|
nb: 23
|
||||||
|
network_D:
|
||||||
|
which_model_D: discriminator_vgg_128
|
||||||
|
in_nc: 3
|
||||||
|
nf: 64
|
||||||
|
|
||||||
|
#### path
|
||||||
|
path:
|
||||||
|
pretrain_model_G: ../experiments/pretrained_models/RRDB_PSNR_x4.pth
|
||||||
|
strict_load: true
|
||||||
|
resume_state: ~
|
||||||
|
|
||||||
|
#### training settings: learning rate scheme, loss
|
||||||
|
train:
|
||||||
|
lr_G: !!float 1e-4
|
||||||
|
weight_decay_G: 0
|
||||||
|
beta1_G: 0.9
|
||||||
|
beta2_G: 0.99
|
||||||
|
lr_D: !!float 1e-4
|
||||||
|
weight_decay_D: 0
|
||||||
|
beta1_D: 0.9
|
||||||
|
beta2_D: 0.99
|
||||||
|
lr_scheme: MultiStepLR
|
||||||
|
|
||||||
|
niter: 400000
|
||||||
|
warmup_iter: -1 # no warm up
|
||||||
|
lr_steps: [50000, 100000, 200000, 300000]
|
||||||
|
lr_gamma: 0.5
|
||||||
|
|
||||||
|
pixel_criterion: l1
|
||||||
|
pixel_weight: !!float 1e-2
|
||||||
|
feature_criterion: l1
|
||||||
|
feature_weight: 1
|
||||||
|
gan_type: ragan # gan | ragan
|
||||||
|
gan_weight: !!float 5e-3
|
||||||
|
|
||||||
|
D_update_ratio: 1
|
||||||
|
D_init_iters: 0
|
||||||
|
|
||||||
|
manual_seed: 10
|
||||||
|
val_freq: !!float 5e3
|
||||||
|
|
||||||
|
#### logger
|
||||||
|
logger:
|
||||||
|
print_freq: 100
|
||||||
|
save_checkpoint_freq: !!float 5e3
|
85
codes/options/train/train_SRGAN.yml
Normal file
85
codes/options/train/train_SRGAN.yml
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
# Not exactly the same as SRGAN in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
|
||||||
|
# With 16 Residual blocks w/o BN
|
||||||
|
|
||||||
|
#### general settings
|
||||||
|
name: 002_SRGANx4_MSRResNetx4Ini_DIV2K
|
||||||
|
use_tb_logger: true
|
||||||
|
model: srgan
|
||||||
|
distortion: sr
|
||||||
|
scale: 4
|
||||||
|
gpu_ids: [1]
|
||||||
|
|
||||||
|
#### datasets
|
||||||
|
datasets:
|
||||||
|
train:
|
||||||
|
name: DIV2K
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
|
||||||
|
dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
|
||||||
|
|
||||||
|
use_shuffle: true
|
||||||
|
n_workers: 6 # per GPU
|
||||||
|
batch_size: 16
|
||||||
|
GT_size: 128
|
||||||
|
use_flip: true
|
||||||
|
use_rot: true
|
||||||
|
color: RGB
|
||||||
|
val:
|
||||||
|
name: val_set14
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/val_set14/Set14
|
||||||
|
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
|
||||||
|
|
||||||
|
#### network structures
|
||||||
|
network_G:
|
||||||
|
which_model_G: MSRResNet
|
||||||
|
in_nc: 3
|
||||||
|
out_nc: 3
|
||||||
|
nf: 64
|
||||||
|
nb: 16
|
||||||
|
upscale: 4
|
||||||
|
network_D:
|
||||||
|
which_model_D: discriminator_vgg_128
|
||||||
|
in_nc: 3
|
||||||
|
nf: 64
|
||||||
|
|
||||||
|
#### path
|
||||||
|
path:
|
||||||
|
pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
|
||||||
|
strict_load: true
|
||||||
|
resume_state: ~
|
||||||
|
|
||||||
|
#### training settings: learning rate scheme, loss
|
||||||
|
train:
|
||||||
|
lr_G: !!float 1e-4
|
||||||
|
weight_decay_G: 0
|
||||||
|
beta1_G: 0.9
|
||||||
|
beta2_G: 0.99
|
||||||
|
lr_D: !!float 1e-4
|
||||||
|
weight_decay_D: 0
|
||||||
|
beta1_D: 0.9
|
||||||
|
beta2_D: 0.99
|
||||||
|
lr_scheme: MultiStepLR
|
||||||
|
|
||||||
|
niter: 400000
|
||||||
|
warmup_iter: -1 # no warm up
|
||||||
|
lr_steps: [50000, 100000, 200000, 300000]
|
||||||
|
lr_gamma: 0.5
|
||||||
|
|
||||||
|
pixel_criterion: l1
|
||||||
|
pixel_weight: !!float 1e-2
|
||||||
|
feature_criterion: l1
|
||||||
|
feature_weight: 1
|
||||||
|
gan_type: gan # gan | ragan
|
||||||
|
gan_weight: !!float 5e-3
|
||||||
|
|
||||||
|
D_update_ratio: 1
|
||||||
|
D_init_iters: 0
|
||||||
|
|
||||||
|
manual_seed: 10
|
||||||
|
val_freq: !!float 5e3
|
||||||
|
|
||||||
|
#### logger
|
||||||
|
logger:
|
||||||
|
print_freq: 100
|
||||||
|
save_checkpoint_freq: !!float 5e3
|
70
codes/options/train/train_SRResNet.yml
Normal file
70
codes/options/train/train_SRResNet.yml
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
# Not exactly the same as SRResNet in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
|
||||||
|
# With 16 Residual blocks w/o BN
|
||||||
|
|
||||||
|
#### general settings
|
||||||
|
name: 001_MSRResNetx4_scratch_DIV2K
|
||||||
|
use_tb_logger: true
|
||||||
|
model: sr
|
||||||
|
distortion: sr
|
||||||
|
scale: 4
|
||||||
|
gpu_ids: [0]
|
||||||
|
|
||||||
|
#### datasets
|
||||||
|
datasets:
|
||||||
|
train:
|
||||||
|
name: DIV2K
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
|
||||||
|
dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
|
||||||
|
|
||||||
|
use_shuffle: true
|
||||||
|
n_workers: 6 # per GPU
|
||||||
|
batch_size: 16
|
||||||
|
GT_size: 128
|
||||||
|
use_flip: true
|
||||||
|
use_rot: true
|
||||||
|
color: RGB
|
||||||
|
val:
|
||||||
|
name: val_set5
|
||||||
|
mode: LQGT
|
||||||
|
dataroot_GT: ../datasets/val_set5/Set5
|
||||||
|
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
|
||||||
|
|
||||||
|
#### network structures
|
||||||
|
network_G:
|
||||||
|
which_model_G: MSRResNet
|
||||||
|
in_nc: 3
|
||||||
|
out_nc: 3
|
||||||
|
nf: 64
|
||||||
|
nb: 16
|
||||||
|
upscale: 4
|
||||||
|
|
||||||
|
#### path
|
||||||
|
path:
|
||||||
|
pretrain_model_G: ~
|
||||||
|
strict_load: true
|
||||||
|
resume_state: ~
|
||||||
|
|
||||||
|
#### training settings: learning rate scheme, loss
|
||||||
|
train:
|
||||||
|
lr_G: !!float 2e-4
|
||||||
|
lr_scheme: CosineAnnealingLR_Restart
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.99
|
||||||
|
niter: 1000000
|
||||||
|
warmup_iter: -1 # no warm up
|
||||||
|
T_period: [250000, 250000, 250000, 250000]
|
||||||
|
restarts: [250000, 500000, 750000]
|
||||||
|
restart_weights: [1, 1, 1]
|
||||||
|
eta_min: !!float 1e-7
|
||||||
|
|
||||||
|
pixel_criterion: l1
|
||||||
|
pixel_weight: 1.0
|
||||||
|
|
||||||
|
manual_seed: 10
|
||||||
|
val_freq: !!float 5e3
|
||||||
|
|
||||||
|
#### logger
|
||||||
|
logger:
|
||||||
|
print_freq: 100
|
||||||
|
save_checkpoint_freq: !!float 5e3
|
10
codes/run_scripts.sh
Normal file
10
codes/run_scripts.sh
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
# single GPU training (image SR)
|
||||||
|
python train.py -opt options/train/train_SRResNet.yml
|
||||||
|
python train.py -opt options/train/train_SRGAN.yml
|
||||||
|
python train.py -opt options/train/train_ESRGAN.yml
|
||||||
|
|
||||||
|
|
||||||
|
# distributed training (video SR)
|
||||||
|
# 8 GPUs
|
||||||
|
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_woTSA_M.yml --launcher pytorch
|
||||||
|
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_M.yml --launcher pytorch
|
20
codes/scripts/back_projection/backprojection.m
Normal file
20
codes/scripts/back_projection/backprojection.m
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
function [im_h] = backprojection(im_h, im_l, maxIter)
|
||||||
|
|
||||||
|
[row_l, col_l,~] = size(im_l);
|
||||||
|
[row_h, col_h,~] = size(im_h);
|
||||||
|
|
||||||
|
p = fspecial('gaussian', 5, 1);
|
||||||
|
p = p.^2;
|
||||||
|
p = p./sum(p(:));
|
||||||
|
|
||||||
|
im_l = double(im_l);
|
||||||
|
im_h = double(im_h);
|
||||||
|
|
||||||
|
for ii = 1:maxIter
|
||||||
|
im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
|
||||||
|
im_diff = im_l - im_l_s;
|
||||||
|
im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
|
||||||
|
im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
|
||||||
|
im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
|
||||||
|
im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
|
||||||
|
end
|
22
codes/scripts/back_projection/main_bp.m
Normal file
22
codes/scripts/back_projection/main_bp.m
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
clear; close all; clc;
|
||||||
|
|
||||||
|
LR_folder = './LR'; % LR
|
||||||
|
preout_folder = './results'; % pre output
|
||||||
|
save_folder = './results_20bp';
|
||||||
|
filepaths = dir(fullfile(preout_folder, '*.png'));
|
||||||
|
max_iter = 20;
|
||||||
|
|
||||||
|
if ~ exist(save_folder, 'dir')
|
||||||
|
mkdir(save_folder);
|
||||||
|
end
|
||||||
|
|
||||||
|
for idx_im = 1:length(filepaths)
|
||||||
|
fprintf([num2str(idx_im) '\n']);
|
||||||
|
im_name = filepaths(idx_im).name;
|
||||||
|
im_LR = im2double(imread(fullfile(LR_folder, im_name)));
|
||||||
|
im_out = im2double(imread(fullfile(preout_folder, im_name)));
|
||||||
|
%tic
|
||||||
|
im_out = backprojection(im_out, im_LR, max_iter);
|
||||||
|
%toc
|
||||||
|
imwrite(im_out, fullfile(save_folder, im_name));
|
||||||
|
end
|
25
codes/scripts/back_projection/main_reverse_filter.m
Normal file
25
codes/scripts/back_projection/main_reverse_filter.m
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
clear; close all; clc;
|
||||||
|
|
||||||
|
LR_folder = './LR'; % LR
|
||||||
|
preout_folder = './results'; % pre output
|
||||||
|
save_folder = './results_20if';
|
||||||
|
filepaths = dir(fullfile(preout_folder, '*.png'));
|
||||||
|
max_iter = 20;
|
||||||
|
|
||||||
|
if ~ exist(save_folder, 'dir')
|
||||||
|
mkdir(save_folder);
|
||||||
|
end
|
||||||
|
|
||||||
|
for idx_im = 1:length(filepaths)
|
||||||
|
fprintf([num2str(idx_im) '\n']);
|
||||||
|
im_name = filepaths(idx_im).name;
|
||||||
|
im_LR = im2double(imread(fullfile(LR_folder, im_name)));
|
||||||
|
im_out = im2double(imread(fullfile(preout_folder, im_name)));
|
||||||
|
J = imresize(im_LR,4,'bicubic');
|
||||||
|
%tic
|
||||||
|
for m = 1:max_iter
|
||||||
|
im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
|
||||||
|
end
|
||||||
|
%toc
|
||||||
|
imwrite(im_out, fullfile(save_folder, im_name));
|
||||||
|
end
|
27
codes/scripts/transfer_params_MSRResNet.py
Normal file
27
codes/scripts/transfer_params_MSRResNet.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
import os.path as osp
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
try:
|
||||||
|
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||||
|
import models.archs.SRResNet_arch as SRResNet_arch
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth')
|
||||||
|
crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
|
||||||
|
crt_net = crt_model.state_dict()
|
||||||
|
|
||||||
|
for k, v in crt_net.items():
|
||||||
|
if k in pretrained_net and 'upconv1' not in k:
|
||||||
|
crt_net[k] = pretrained_net[k]
|
||||||
|
print('replace ... ', k)
|
||||||
|
|
||||||
|
# x4 -> x3
|
||||||
|
crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2
|
||||||
|
crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2
|
||||||
|
crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2
|
||||||
|
crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2
|
||||||
|
crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2
|
||||||
|
crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2
|
||||||
|
|
||||||
|
torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')
|
105
codes/test.py
Normal file
105
codes/test.py
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
import os.path as osp
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import options.options as option
|
||||||
|
import utils.util as util
|
||||||
|
from data.util import bgr2ycbcr
|
||||||
|
from data import create_dataset, create_dataloader
|
||||||
|
from models import create_model
|
||||||
|
|
||||||
|
#### options
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-opt', type=str, required=True, help='Path to options YMAL file.')
|
||||||
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||||
|
opt = option.dict_to_nonedict(opt)
|
||||||
|
|
||||||
|
util.mkdirs(
|
||||||
|
(path for key, path in opt['path'].items()
|
||||||
|
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
|
||||||
|
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
|
||||||
|
screen=True, tofile=True)
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
logger.info(option.dict2str(opt))
|
||||||
|
|
||||||
|
#### Create test dataset and dataloader
|
||||||
|
test_loaders = []
|
||||||
|
for phase, dataset_opt in sorted(opt['datasets'].items()):
|
||||||
|
test_set = create_dataset(dataset_opt)
|
||||||
|
test_loader = create_dataloader(test_set, dataset_opt)
|
||||||
|
logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
|
||||||
|
test_loaders.append(test_loader)
|
||||||
|
|
||||||
|
model = create_model(opt)
|
||||||
|
for test_loader in test_loaders:
|
||||||
|
test_set_name = test_loader.dataset.opt['name']
|
||||||
|
logger.info('\nTesting [{:s}]...'.format(test_set_name))
|
||||||
|
test_start_time = time.time()
|
||||||
|
dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
|
||||||
|
util.mkdir(dataset_dir)
|
||||||
|
|
||||||
|
test_results = OrderedDict()
|
||||||
|
test_results['psnr'] = []
|
||||||
|
test_results['ssim'] = []
|
||||||
|
test_results['psnr_y'] = []
|
||||||
|
test_results['ssim_y'] = []
|
||||||
|
|
||||||
|
for data in test_loader:
|
||||||
|
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
||||||
|
model.feed_data(data, need_GT=need_GT)
|
||||||
|
img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
|
||||||
|
img_name = osp.splitext(osp.basename(img_path))[0]
|
||||||
|
|
||||||
|
model.test()
|
||||||
|
visuals = model.get_current_visuals(need_GT=need_GT)
|
||||||
|
|
||||||
|
sr_img = util.tensor2img(visuals['rlt']) # uint8
|
||||||
|
|
||||||
|
# save images
|
||||||
|
suffix = opt['suffix']
|
||||||
|
if suffix:
|
||||||
|
save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
|
||||||
|
else:
|
||||||
|
save_img_path = osp.join(dataset_dir, img_name + '.png')
|
||||||
|
util.save_img(sr_img, save_img_path)
|
||||||
|
|
||||||
|
# calculate PSNR and SSIM
|
||||||
|
if need_GT:
|
||||||
|
gt_img = util.tensor2img(visuals['GT'])
|
||||||
|
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||||
|
psnr = util.calculate_psnr(sr_img, gt_img)
|
||||||
|
ssim = util.calculate_ssim(sr_img, gt_img)
|
||||||
|
test_results['psnr'].append(psnr)
|
||||||
|
test_results['ssim'].append(ssim)
|
||||||
|
|
||||||
|
if gt_img.shape[2] == 3: # RGB image
|
||||||
|
sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
|
||||||
|
gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)
|
||||||
|
|
||||||
|
psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
|
||||||
|
ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
|
||||||
|
test_results['psnr_y'].append(psnr_y)
|
||||||
|
test_results['ssim_y'].append(ssim_y)
|
||||||
|
logger.info(
|
||||||
|
'{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.
|
||||||
|
format(img_name, psnr, ssim, psnr_y, ssim_y))
|
||||||
|
else:
|
||||||
|
logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
|
||||||
|
else:
|
||||||
|
logger.info(img_name)
|
||||||
|
|
||||||
|
if need_GT: # metrics
|
||||||
|
# Average PSNR/SSIM results
|
||||||
|
ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
|
||||||
|
ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
|
||||||
|
logger.info(
|
||||||
|
'----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format(
|
||||||
|
test_set_name, ave_psnr, ave_ssim))
|
||||||
|
if test_results['psnr_y'] and test_results['ssim_y']:
|
||||||
|
ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
|
||||||
|
ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
|
||||||
|
logger.info(
|
||||||
|
'----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'.
|
||||||
|
format(ave_psnr_y, ave_ssim_y))
|
208
codes/test_Vid4_REDS4_with_GT.py
Normal file
208
codes/test_Vid4_REDS4_with_GT.py
Normal file
|
@ -0,0 +1,208 @@
|
||||||
|
'''
|
||||||
|
Test Vid4 (SR) and REDS4 (SR-clean, SR-blur, deblur-clean, deblur-compression) datasets
|
||||||
|
'''
|
||||||
|
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import glob
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import utils.util as util
|
||||||
|
import data.util as data_util
|
||||||
|
import models.archs.EDVR_arch as EDVR_arch
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
#################
|
||||||
|
# configurations
|
||||||
|
#################
|
||||||
|
device = torch.device('cuda')
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
|
data_mode = 'Vid4' # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
|
||||||
|
# Vid4: SR
|
||||||
|
# REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
|
||||||
|
# blur (deblur-clean), blur_comp (deblur-compression).
|
||||||
|
stage = 1 # 1 or 2, use two stage strategy for REDS dataset.
|
||||||
|
flip_test = False
|
||||||
|
############################################################################
|
||||||
|
#### model
|
||||||
|
if data_mode == 'Vid4':
|
||||||
|
if stage == 1:
|
||||||
|
model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
|
||||||
|
else:
|
||||||
|
raise ValueError('Vid4 does not support stage 2.')
|
||||||
|
elif data_mode == 'sharp_bicubic':
|
||||||
|
if stage == 1:
|
||||||
|
model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
|
||||||
|
else:
|
||||||
|
model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
|
||||||
|
elif data_mode == 'blur_bicubic':
|
||||||
|
if stage == 1:
|
||||||
|
model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
|
||||||
|
else:
|
||||||
|
model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
|
||||||
|
elif data_mode == 'blur':
|
||||||
|
if stage == 1:
|
||||||
|
model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
|
||||||
|
else:
|
||||||
|
model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
|
||||||
|
elif data_mode == 'blur_comp':
|
||||||
|
if stage == 1:
|
||||||
|
model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
|
||||||
|
else:
|
||||||
|
model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
if data_mode == 'Vid4':
|
||||||
|
N_in = 7 # use N_in images to restore one HR image
|
||||||
|
else:
|
||||||
|
N_in = 5
|
||||||
|
|
||||||
|
predeblur, HR_in = False, False
|
||||||
|
back_RBs = 40
|
||||||
|
if data_mode == 'blur_bicubic':
|
||||||
|
predeblur = True
|
||||||
|
if data_mode == 'blur' or data_mode == 'blur_comp':
|
||||||
|
predeblur, HR_in = True, True
|
||||||
|
if stage == 2:
|
||||||
|
HR_in = True
|
||||||
|
back_RBs = 20
|
||||||
|
model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
|
||||||
|
|
||||||
|
#### dataset
|
||||||
|
if data_mode == 'Vid4':
|
||||||
|
test_dataset_folder = '../datasets/Vid4/BIx4'
|
||||||
|
GT_dataset_folder = '../datasets/Vid4/GT'
|
||||||
|
else:
|
||||||
|
if stage == 1:
|
||||||
|
test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
|
||||||
|
else:
|
||||||
|
test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
|
||||||
|
print('You should modify the test_dataset_folder path for stage 2')
|
||||||
|
GT_dataset_folder = '../datasets/REDS4/GT'
|
||||||
|
|
||||||
|
#### evaluation
|
||||||
|
crop_border = 0
|
||||||
|
border_frame = N_in // 2 # border frames when evaluate
|
||||||
|
# temporal padding mode
|
||||||
|
if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
|
||||||
|
padding = 'new_info'
|
||||||
|
else:
|
||||||
|
padding = 'replicate'
|
||||||
|
save_imgs = True
|
||||||
|
|
||||||
|
save_folder = '../results/{}'.format(data_mode)
|
||||||
|
util.mkdirs(save_folder)
|
||||||
|
util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
#### log info
|
||||||
|
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
|
||||||
|
logger.info('Padding mode: {}'.format(padding))
|
||||||
|
logger.info('Model path: {}'.format(model_path))
|
||||||
|
logger.info('Save images: {}'.format(save_imgs))
|
||||||
|
logger.info('Flip test: {}'.format(flip_test))
|
||||||
|
|
||||||
|
#### set up the models
|
||||||
|
model.load_state_dict(torch.load(model_path), strict=True)
|
||||||
|
model.eval()
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
|
||||||
|
subfolder_name_l = []
|
||||||
|
|
||||||
|
subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
|
||||||
|
subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
|
||||||
|
# for each subfolder
|
||||||
|
for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
|
||||||
|
subfolder_name = osp.basename(subfolder)
|
||||||
|
subfolder_name_l.append(subfolder_name)
|
||||||
|
save_subfolder = osp.join(save_folder, subfolder_name)
|
||||||
|
|
||||||
|
img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
|
||||||
|
max_idx = len(img_path_l)
|
||||||
|
if save_imgs:
|
||||||
|
util.mkdirs(save_subfolder)
|
||||||
|
|
||||||
|
#### read LQ and GT images
|
||||||
|
imgs_LQ = data_util.read_img_seq(subfolder)
|
||||||
|
img_GT_l = []
|
||||||
|
for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
|
||||||
|
img_GT_l.append(data_util.read_img(None, img_GT_path))
|
||||||
|
|
||||||
|
avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0
|
||||||
|
|
||||||
|
# process each image
|
||||||
|
for img_idx, img_path in enumerate(img_path_l):
|
||||||
|
img_name = osp.splitext(osp.basename(img_path))[0]
|
||||||
|
select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
|
||||||
|
imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
if flip_test:
|
||||||
|
output = util.flipx4_forward(model, imgs_in)
|
||||||
|
else:
|
||||||
|
output = util.single_forward(model, imgs_in)
|
||||||
|
output = util.tensor2img(output.squeeze(0))
|
||||||
|
|
||||||
|
if save_imgs:
|
||||||
|
cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)
|
||||||
|
|
||||||
|
# calculate PSNR
|
||||||
|
output = output / 255.
|
||||||
|
GT = np.copy(img_GT_l[img_idx])
|
||||||
|
# For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
|
||||||
|
if data_mode == 'Vid4': # bgr2y, [0, 1]
|
||||||
|
GT = data_util.bgr2ycbcr(GT, only_y=True)
|
||||||
|
output = data_util.bgr2ycbcr(output, only_y=True)
|
||||||
|
|
||||||
|
output, GT = util.crop_border([output, GT], crop_border)
|
||||||
|
crt_psnr = util.calculate_psnr(output * 255, GT * 255)
|
||||||
|
logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))
|
||||||
|
|
||||||
|
if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
|
||||||
|
avg_psnr_center += crt_psnr
|
||||||
|
N_center += 1
|
||||||
|
else: # border frames
|
||||||
|
avg_psnr_border += crt_psnr
|
||||||
|
N_border += 1
|
||||||
|
|
||||||
|
avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
|
||||||
|
avg_psnr_center = avg_psnr_center / N_center
|
||||||
|
avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
|
||||||
|
avg_psnr_l.append(avg_psnr)
|
||||||
|
avg_psnr_center_l.append(avg_psnr_center)
|
||||||
|
avg_psnr_border_l.append(avg_psnr_border)
|
||||||
|
|
||||||
|
logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
|
||||||
|
'Center PSNR: {:.6f} dB for {} frames; '
|
||||||
|
'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
|
||||||
|
(N_center + N_border),
|
||||||
|
avg_psnr_center, N_center,
|
||||||
|
avg_psnr_border, N_border))
|
||||||
|
|
||||||
|
logger.info('################ Tidy Outputs ################')
|
||||||
|
for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
|
||||||
|
avg_psnr_center_l, avg_psnr_border_l):
|
||||||
|
logger.info('Folder {} - Average PSNR: {:.6f} dB. '
|
||||||
|
'Center PSNR: {:.6f} dB. '
|
||||||
|
'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
|
||||||
|
psnr_border))
|
||||||
|
logger.info('################ Final Results ################')
|
||||||
|
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
|
||||||
|
logger.info('Padding mode: {}'.format(padding))
|
||||||
|
logger.info('Model path: {}'.format(model_path))
|
||||||
|
logger.info('Save images: {}'.format(save_imgs))
|
||||||
|
logger.info('Flip test: {}'.format(flip_test))
|
||||||
|
logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
|
||||||
|
'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
|
||||||
|
sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
|
||||||
|
sum(avg_psnr_center_l) / len(avg_psnr_center_l),
|
||||||
|
sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
264
codes/test_Vid4_REDS4_with_GT_DUF.py
Normal file
264
codes/test_Vid4_REDS4_with_GT_DUF.py
Normal file
|
@ -0,0 +1,264 @@
|
||||||
|
"""
|
||||||
|
DUF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets
|
||||||
|
write to txt log file
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import glob
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import utils.util as util
|
||||||
|
import data.util as data_util
|
||||||
|
import models.archs.DUF_arch as DUF_arch
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
#################
|
||||||
|
# configurations
|
||||||
|
#################
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
|
data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS)
|
||||||
|
|
||||||
|
# Possible combinations: (2, 16), (3, 16), (4, 16), (4, 28), (4, 52)
|
||||||
|
scale = 4
|
||||||
|
layer = 52
|
||||||
|
assert (scale, layer) in [(2, 16), (3, 16), (4, 16), (4, 28),
|
||||||
|
(4, 52)], 'Unrecognized (scale, layer) combination'
|
||||||
|
|
||||||
|
# model
|
||||||
|
N_in = 7
|
||||||
|
model_path = '../experiments/pretrained_models/DUF_x{}_{}L_official.pth'.format(scale, layer)
|
||||||
|
adapt_official = True if 'official' in model_path else False
|
||||||
|
DUF_downsampling = True # True | False
|
||||||
|
if layer == 16:
|
||||||
|
model = DUF_arch.DUF_16L(scale=scale, adapt_official=adapt_official)
|
||||||
|
elif layer == 28:
|
||||||
|
model = DUF_arch.DUF_28L(scale=scale, adapt_official=adapt_official)
|
||||||
|
elif layer == 52:
|
||||||
|
model = DUF_arch.DUF_52L(scale=scale, adapt_official=adapt_official)
|
||||||
|
|
||||||
|
#### dataset
|
||||||
|
if data_mode == 'Vid4':
|
||||||
|
test_dataset_folder = '../datasets/Vid4/BIx4/*'
|
||||||
|
else: # sharp_bicubic (REDS)
|
||||||
|
test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)
|
||||||
|
|
||||||
|
#### evaluation
|
||||||
|
crop_border = 8
|
||||||
|
border_frame = N_in // 2 # border frames when evaluate
|
||||||
|
# temporal padding mode
|
||||||
|
padding = 'new_info' # different from the official testing codes, which pads zeros.
|
||||||
|
save_imgs = True
|
||||||
|
############################################################################
|
||||||
|
device = torch.device('cuda')
|
||||||
|
save_folder = '../results/{}'.format(data_mode)
|
||||||
|
util.mkdirs(save_folder)
|
||||||
|
util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
#### log info
|
||||||
|
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
|
||||||
|
logger.info('Padding mode: {}'.format(padding))
|
||||||
|
logger.info('Model path: {}'.format(model_path))
|
||||||
|
logger.info('Save images: {}'.format(save_imgs))
|
||||||
|
|
||||||
|
def read_image(img_path):
|
||||||
|
'''read one image from img_path
|
||||||
|
Return img: HWC, BGR, [0,1], numpy
|
||||||
|
'''
|
||||||
|
img_GT = cv2.imread(img_path)
|
||||||
|
img = img_GT.astype(np.float32) / 255.
|
||||||
|
return img
|
||||||
|
|
||||||
|
def read_seq_imgs(img_seq_path):
|
||||||
|
'''read a sequence of images'''
|
||||||
|
img_path_l = sorted(glob.glob(img_seq_path + '/*'))
|
||||||
|
img_l = [read_image(v) for v in img_path_l]
|
||||||
|
# stack to TCHW, RGB, [0,1], torch
|
||||||
|
imgs = np.stack(img_l, axis=0)
|
||||||
|
imgs = imgs[:, :, :, [2, 1, 0]]
|
||||||
|
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
|
||||||
|
return imgs
|
||||||
|
|
||||||
|
def index_generation(crt_i, max_n, N, padding='reflection'):
|
||||||
|
'''
|
||||||
|
padding: replicate | reflection | new_info | circle
|
||||||
|
'''
|
||||||
|
max_n = max_n - 1
|
||||||
|
n_pad = N // 2
|
||||||
|
return_l = []
|
||||||
|
|
||||||
|
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
|
||||||
|
if i < 0:
|
||||||
|
if padding == 'replicate':
|
||||||
|
add_idx = 0
|
||||||
|
elif padding == 'reflection':
|
||||||
|
add_idx = -i
|
||||||
|
elif padding == 'new_info':
|
||||||
|
add_idx = (crt_i + n_pad) + (-i)
|
||||||
|
elif padding == 'circle':
|
||||||
|
add_idx = N + i
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong padding mode')
|
||||||
|
elif i > max_n:
|
||||||
|
if padding == 'replicate':
|
||||||
|
add_idx = max_n
|
||||||
|
elif padding == 'reflection':
|
||||||
|
add_idx = max_n * 2 - i
|
||||||
|
elif padding == 'new_info':
|
||||||
|
add_idx = (crt_i - n_pad) - (i - max_n)
|
||||||
|
elif padding == 'circle':
|
||||||
|
add_idx = i - N
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong padding mode')
|
||||||
|
else:
|
||||||
|
add_idx = i
|
||||||
|
return_l.append(add_idx)
|
||||||
|
return return_l
|
||||||
|
|
||||||
|
def single_forward(model, imgs_in):
|
||||||
|
with torch.no_grad():
|
||||||
|
model_output = model(imgs_in)
|
||||||
|
if isinstance(model_output, list) or isinstance(model_output, tuple):
|
||||||
|
output = model_output[0]
|
||||||
|
else:
|
||||||
|
output = model_output
|
||||||
|
return output
|
||||||
|
|
||||||
|
sub_folder_l = sorted(glob.glob(test_dataset_folder))
|
||||||
|
#### set up the models
|
||||||
|
model.load_state_dict(torch.load(model_path), strict=True)
|
||||||
|
model.eval()
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
|
||||||
|
sub_folder_name_l = []
|
||||||
|
|
||||||
|
# for each sub-folder
|
||||||
|
for sub_folder in sub_folder_l:
|
||||||
|
sub_folder_name = sub_folder.split('/')[-1]
|
||||||
|
sub_folder_name_l.append(sub_folder_name)
|
||||||
|
save_sub_folder = osp.join(save_folder, sub_folder_name)
|
||||||
|
|
||||||
|
img_path_l = sorted(glob.glob(sub_folder + '/*'))
|
||||||
|
max_idx = len(img_path_l)
|
||||||
|
|
||||||
|
if save_imgs:
|
||||||
|
util.mkdirs(save_sub_folder)
|
||||||
|
|
||||||
|
#### read LR images
|
||||||
|
imgs = read_seq_imgs(sub_folder)
|
||||||
|
#### read GT images
|
||||||
|
img_GT_l = []
|
||||||
|
if data_mode == 'Vid4':
|
||||||
|
sub_folder_GT = osp.join(sub_folder.replace('/BIx4/', '/GT/'), '*')
|
||||||
|
else:
|
||||||
|
sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
|
||||||
|
for img_GT_path in sorted(glob.glob(sub_folder_GT)):
|
||||||
|
img_GT_l.append(read_image(img_GT_path))
|
||||||
|
|
||||||
|
# When using the downsampling in DUF official code, we downsample the HR images
|
||||||
|
if DUF_downsampling:
|
||||||
|
sub_folder = sub_folder_GT
|
||||||
|
img_path_l = sorted(glob.glob(sub_folder))
|
||||||
|
max_idx = len(img_path_l)
|
||||||
|
imgs = read_seq_imgs(sub_folder[:-2])
|
||||||
|
|
||||||
|
avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
|
||||||
|
cal_n_border, cal_n_center = 0, 0
|
||||||
|
|
||||||
|
# process each image
|
||||||
|
for img_idx, img_path in enumerate(img_path_l):
|
||||||
|
c_idx = int(osp.splitext(osp.basename(img_path))[0])
|
||||||
|
select_idx = index_generation(c_idx, max_idx, N_in, padding=padding)
|
||||||
|
# get input images
|
||||||
|
imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
# Downsample the HR images
|
||||||
|
H, W = imgs_in.size(3), imgs_in.size(4)
|
||||||
|
if DUF_downsampling:
|
||||||
|
imgs_in = util.DUF_downsample(imgs_in, scale=scale)
|
||||||
|
|
||||||
|
output = single_forward(model, imgs_in)
|
||||||
|
|
||||||
|
# Crop to the original shape
|
||||||
|
if scale == 3:
|
||||||
|
pad_h = 3 - (H % 3)
|
||||||
|
pad_w = 3 - (W % 3)
|
||||||
|
if pad_h > 0:
|
||||||
|
output = output[:, :, :-pad_h, :]
|
||||||
|
if pad_w > 0:
|
||||||
|
output = output[:, :, :, :-pad_w]
|
||||||
|
output_f = output.data.float().cpu().squeeze(0)
|
||||||
|
|
||||||
|
output = util.tensor2img(output_f)
|
||||||
|
|
||||||
|
# save imgs
|
||||||
|
if save_imgs:
|
||||||
|
cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output)
|
||||||
|
|
||||||
|
#### calculate PSNR
|
||||||
|
output = output / 255.
|
||||||
|
GT = np.copy(img_GT_l[img_idx])
|
||||||
|
# For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
|
||||||
|
if data_mode == 'Vid4': # bgr2y, [0, 1]
|
||||||
|
GT = data_util.bgr2ycbcr(GT)
|
||||||
|
output = data_util.bgr2ycbcr(output)
|
||||||
|
if crop_border == 0:
|
||||||
|
cropped_output = output
|
||||||
|
cropped_GT = GT
|
||||||
|
else:
|
||||||
|
cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
|
||||||
|
cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
|
||||||
|
crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
|
||||||
|
logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr))
|
||||||
|
|
||||||
|
if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
|
||||||
|
avg_psnr_center += crt_psnr
|
||||||
|
cal_n_center += 1
|
||||||
|
else: # border frames
|
||||||
|
avg_psnr_border += crt_psnr
|
||||||
|
cal_n_border += 1
|
||||||
|
|
||||||
|
avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border)
|
||||||
|
avg_psnr_center = avg_psnr_center / cal_n_center
|
||||||
|
if cal_n_border == 0:
|
||||||
|
avg_psnr_border = 0
|
||||||
|
else:
|
||||||
|
avg_psnr_border = avg_psnr_border / cal_n_border
|
||||||
|
|
||||||
|
logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
|
||||||
|
'Center PSNR: {:.6f} dB for {} frames; '
|
||||||
|
'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr,
|
||||||
|
(cal_n_center + cal_n_border),
|
||||||
|
avg_psnr_center, cal_n_center,
|
||||||
|
avg_psnr_border, cal_n_border))
|
||||||
|
|
||||||
|
avg_psnr_l.append(avg_psnr)
|
||||||
|
avg_psnr_center_l.append(avg_psnr_center)
|
||||||
|
avg_psnr_border_l.append(avg_psnr_border)
|
||||||
|
|
||||||
|
logger.info('################ Tidy Outputs ################')
|
||||||
|
for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
|
||||||
|
avg_psnr_center_l, avg_psnr_border_l):
|
||||||
|
logger.info('Folder {} - Average PSNR: {:.6f} dB. '
|
||||||
|
'Center PSNR: {:.6f} dB. '
|
||||||
|
'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border))
|
||||||
|
logger.info('################ Final Results ################')
|
||||||
|
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
|
||||||
|
logger.info('Padding mode: {}'.format(padding))
|
||||||
|
logger.info('Model path: {}'.format(model_path))
|
||||||
|
logger.info('Save images: {}'.format(save_imgs))
|
||||||
|
logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
|
||||||
|
'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
|
||||||
|
sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
|
||||||
|
sum(avg_psnr_center_l) / len(avg_psnr_center_l),
|
||||||
|
sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
230
codes/test_Vid4_REDS4_with_GT_TOF.py
Normal file
230
codes/test_Vid4_REDS4_with_GT_TOF.py
Normal file
|
@ -0,0 +1,230 @@
|
||||||
|
"""
|
||||||
|
TOF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets
|
||||||
|
write to txt log file
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import glob
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import utils.util as util
|
||||||
|
import data.util as data_util
|
||||||
|
import models.archs.TOF_arch as TOF_arch
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
#################
|
||||||
|
# configurations
|
||||||
|
#################
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
|
data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS)
|
||||||
|
|
||||||
|
# model
|
||||||
|
N_in = 7
|
||||||
|
model_path = '../experiments/pretrained_models/TOF_official.pth'
|
||||||
|
adapt_official = True if 'official' in model_path else False
|
||||||
|
model = TOF_arch.TOFlow(adapt_official=adapt_official)
|
||||||
|
|
||||||
|
#### dataset
|
||||||
|
if data_mode == 'Vid4':
|
||||||
|
test_dataset_folder = '../datasets/Vid4/BIx4up_direct/*'
|
||||||
|
else:
|
||||||
|
test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)
|
||||||
|
|
||||||
|
#### evaluation
|
||||||
|
crop_border = 0
|
||||||
|
border_frame = N_in // 2 # border frames when evaluate
|
||||||
|
# temporal padding mode
|
||||||
|
padding = 'new_info' # different from the official setting
|
||||||
|
save_imgs = True
|
||||||
|
############################################################################
|
||||||
|
device = torch.device('cuda')
|
||||||
|
save_folder = '../results/{}'.format(data_mode)
|
||||||
|
util.mkdirs(save_folder)
|
||||||
|
util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
#### log info
|
||||||
|
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
|
||||||
|
logger.info('Padding mode: {}'.format(padding))
|
||||||
|
logger.info('Model path: {}'.format(model_path))
|
||||||
|
logger.info('Save images: {}'.format(save_imgs))
|
||||||
|
|
||||||
|
def read_image(img_path):
|
||||||
|
'''read one image from img_path
|
||||||
|
Return img: HWC, BGR, [0,1], numpy
|
||||||
|
'''
|
||||||
|
img_GT = cv2.imread(img_path)
|
||||||
|
img = img_GT.astype(np.float32) / 255.
|
||||||
|
return img
|
||||||
|
|
||||||
|
def read_seq_imgs(img_seq_path):
|
||||||
|
'''read a sequence of images'''
|
||||||
|
img_path_l = sorted(glob.glob(img_seq_path + '/*'))
|
||||||
|
img_l = [read_image(v) for v in img_path_l]
|
||||||
|
# stack to TCHW, RGB, [0,1], torch
|
||||||
|
imgs = np.stack(img_l, axis=0)
|
||||||
|
imgs = imgs[:, :, :, [2, 1, 0]]
|
||||||
|
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
|
||||||
|
return imgs
|
||||||
|
|
||||||
|
def index_generation(crt_i, max_n, N, padding='reflection'):
|
||||||
|
'''
|
||||||
|
padding: replicate | reflection | new_info | circle
|
||||||
|
'''
|
||||||
|
max_n = max_n - 1
|
||||||
|
n_pad = N // 2
|
||||||
|
return_l = []
|
||||||
|
|
||||||
|
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
|
||||||
|
if i < 0:
|
||||||
|
if padding == 'replicate':
|
||||||
|
add_idx = 0
|
||||||
|
elif padding == 'reflection':
|
||||||
|
add_idx = -i
|
||||||
|
elif padding == 'new_info':
|
||||||
|
add_idx = (crt_i + n_pad) + (-i)
|
||||||
|
elif padding == 'circle':
|
||||||
|
add_idx = N + i
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong padding mode')
|
||||||
|
elif i > max_n:
|
||||||
|
if padding == 'replicate':
|
||||||
|
add_idx = max_n
|
||||||
|
elif padding == 'reflection':
|
||||||
|
add_idx = max_n * 2 - i
|
||||||
|
elif padding == 'new_info':
|
||||||
|
add_idx = (crt_i - n_pad) - (i - max_n)
|
||||||
|
elif padding == 'circle':
|
||||||
|
add_idx = i - N
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong padding mode')
|
||||||
|
else:
|
||||||
|
add_idx = i
|
||||||
|
return_l.append(add_idx)
|
||||||
|
return return_l
|
||||||
|
|
||||||
|
def single_forward(model, imgs_in):
|
||||||
|
with torch.no_grad():
|
||||||
|
model_output = model(imgs_in)
|
||||||
|
if isinstance(model_output, list) or isinstance(model_output, tuple):
|
||||||
|
output = model_output[0]
|
||||||
|
else:
|
||||||
|
output = model_output
|
||||||
|
return output
|
||||||
|
|
||||||
|
sub_folder_l = sorted(glob.glob(test_dataset_folder))
|
||||||
|
#### set up the models
|
||||||
|
model.load_state_dict(torch.load(model_path), strict=True)
|
||||||
|
model.eval()
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
|
||||||
|
sub_folder_name_l = []
|
||||||
|
|
||||||
|
# for each sub-folder
|
||||||
|
for sub_folder in sub_folder_l:
|
||||||
|
sub_folder_name = sub_folder.split('/')[-1]
|
||||||
|
sub_folder_name_l.append(sub_folder_name)
|
||||||
|
save_sub_folder = osp.join(save_folder, sub_folder_name)
|
||||||
|
|
||||||
|
img_path_l = sorted(glob.glob(sub_folder + '/*'))
|
||||||
|
max_idx = len(img_path_l)
|
||||||
|
|
||||||
|
if save_imgs:
|
||||||
|
util.mkdirs(save_sub_folder)
|
||||||
|
|
||||||
|
#### read LR images
|
||||||
|
imgs = read_seq_imgs(sub_folder)
|
||||||
|
#### read GT images
|
||||||
|
img_GT_l = []
|
||||||
|
if data_mode == 'Vid4':
|
||||||
|
sub_folder_GT = osp.join(sub_folder.replace('/BIx4up_direct/', '/GT/'), '*')
|
||||||
|
else:
|
||||||
|
sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
|
||||||
|
for img_GT_path in sorted(glob.glob(sub_folder_GT)):
|
||||||
|
img_GT_l.append(read_image(img_GT_path))
|
||||||
|
|
||||||
|
avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
|
||||||
|
cal_n_border, cal_n_center = 0, 0
|
||||||
|
|
||||||
|
# process each image
|
||||||
|
for img_idx, img_path in enumerate(img_path_l):
|
||||||
|
c_idx = int(osp.splitext(osp.basename(img_path))[0])
|
||||||
|
select_idx = index_generation(c_idx, max_idx, N_in, padding=padding)
|
||||||
|
# get input images
|
||||||
|
imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
|
||||||
|
output = single_forward(model, imgs_in)
|
||||||
|
output_f = output.data.float().cpu().squeeze(0)
|
||||||
|
|
||||||
|
output = util.tensor2img(output_f)
|
||||||
|
|
||||||
|
# save imgs
|
||||||
|
if save_imgs:
|
||||||
|
cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output)
|
||||||
|
|
||||||
|
#### calculate PSNR
|
||||||
|
output = output / 255.
|
||||||
|
GT = np.copy(img_GT_l[img_idx])
|
||||||
|
# For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
|
||||||
|
if data_mode == 'Vid4': # bgr2y, [0, 1]
|
||||||
|
GT = data_util.bgr2ycbcr(GT)
|
||||||
|
output = data_util.bgr2ycbcr(output)
|
||||||
|
if crop_border == 0:
|
||||||
|
cropped_output = output
|
||||||
|
cropped_GT = GT
|
||||||
|
else:
|
||||||
|
cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
|
||||||
|
cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
|
||||||
|
crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
|
||||||
|
logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr))
|
||||||
|
|
||||||
|
if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
|
||||||
|
avg_psnr_center += crt_psnr
|
||||||
|
cal_n_center += 1
|
||||||
|
else: # border frames
|
||||||
|
avg_psnr_border += crt_psnr
|
||||||
|
cal_n_border += 1
|
||||||
|
|
||||||
|
avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border)
|
||||||
|
avg_psnr_center = avg_psnr_center / cal_n_center
|
||||||
|
if cal_n_border == 0:
|
||||||
|
avg_psnr_border = 0
|
||||||
|
else:
|
||||||
|
avg_psnr_border = avg_psnr_border / cal_n_border
|
||||||
|
|
||||||
|
logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
|
||||||
|
'Center PSNR: {:.6f} dB for {} frames; '
|
||||||
|
'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr,
|
||||||
|
(cal_n_center + cal_n_border),
|
||||||
|
avg_psnr_center, cal_n_center,
|
||||||
|
avg_psnr_border, cal_n_border))
|
||||||
|
|
||||||
|
avg_psnr_l.append(avg_psnr)
|
||||||
|
avg_psnr_center_l.append(avg_psnr_center)
|
||||||
|
avg_psnr_border_l.append(avg_psnr_border)
|
||||||
|
|
||||||
|
logger.info('################ Tidy Outputs ################')
|
||||||
|
for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
|
||||||
|
avg_psnr_center_l, avg_psnr_border_l):
|
||||||
|
logger.info('Folder {} - Average PSNR: {:.6f} dB. '
|
||||||
|
'Center PSNR: {:.6f} dB. '
|
||||||
|
'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border))
|
||||||
|
logger.info('################ Final Results ################')
|
||||||
|
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
|
||||||
|
logger.info('Padding mode: {}'.format(padding))
|
||||||
|
logger.info('Model path: {}'.format(model_path))
|
||||||
|
logger.info('Save images: {}'.format(save_imgs))
|
||||||
|
logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
|
||||||
|
'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
|
||||||
|
sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
|
||||||
|
sum(avg_psnr_center_l) / len(avg_psnr_center_l),
|
||||||
|
sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
310
codes/train.py
Normal file
310
codes/train.py
Normal file
|
@ -0,0 +1,310 @@
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from data.data_sampler import DistIterSampler
|
||||||
|
|
||||||
|
import options.options as option
|
||||||
|
from utils import util
|
||||||
|
from data import create_dataloader, create_dataset
|
||||||
|
from models import create_model
|
||||||
|
|
||||||
|
|
||||||
|
def init_dist(backend='nccl', **kwargs):
|
||||||
|
"""initialization for distributed training"""
|
||||||
|
if mp.get_start_method(allow_none=True) != 'spawn':
|
||||||
|
mp.set_start_method('spawn')
|
||||||
|
rank = int(os.environ['RANK'])
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
torch.cuda.set_device(rank % num_gpus)
|
||||||
|
dist.init_process_group(backend=backend, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
#### options
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.')
|
||||||
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
|
help='job launcher')
|
||||||
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
args = parser.parse_args()
|
||||||
|
opt = option.parse(args.opt, is_train=True)
|
||||||
|
|
||||||
|
#### distributed training settings
|
||||||
|
if args.launcher == 'none': # disabled distributed training
|
||||||
|
opt['dist'] = False
|
||||||
|
rank = -1
|
||||||
|
print('Disabled distributed training.')
|
||||||
|
else:
|
||||||
|
opt['dist'] = True
|
||||||
|
init_dist()
|
||||||
|
world_size = torch.distributed.get_world_size()
|
||||||
|
rank = torch.distributed.get_rank()
|
||||||
|
|
||||||
|
#### loading resume state if exists
|
||||||
|
if opt['path'].get('resume_state', None):
|
||||||
|
# distributed resuming: all load into default GPU
|
||||||
|
device_id = torch.cuda.current_device()
|
||||||
|
resume_state = torch.load(opt['path']['resume_state'],
|
||||||
|
map_location=lambda storage, loc: storage.cuda(device_id))
|
||||||
|
option.check_resume(opt, resume_state['iter']) # check resume options
|
||||||
|
else:
|
||||||
|
resume_state = None
|
||||||
|
|
||||||
|
#### mkdir and loggers
|
||||||
|
if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
|
||||||
|
if resume_state is None:
|
||||||
|
util.mkdir_and_rename(
|
||||||
|
opt['path']['experiments_root']) # rename experiment folder if exists
|
||||||
|
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
|
||||||
|
and 'pretrain_model' not in key and 'resume' not in key))
|
||||||
|
|
||||||
|
# config loggers. Before it, the log will not work
|
||||||
|
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
||||||
|
screen=True, tofile=True)
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
logger.info(option.dict2str(opt))
|
||||||
|
# tensorboard logger
|
||||||
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||||
|
version = float(torch.__version__[0:3])
|
||||||
|
if version >= 1.1: # PyTorch 1.1
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
|
||||||
|
else:
|
||||||
|
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
# convert to NoneDict, which returns None for missing keys
|
||||||
|
opt = option.dict_to_nonedict(opt)
|
||||||
|
|
||||||
|
#### random seed
|
||||||
|
seed = opt['train']['manual_seed']
|
||||||
|
if seed is None:
|
||||||
|
seed = random.randint(1, 10000)
|
||||||
|
if rank <= 0:
|
||||||
|
logger.info('Random seed: {}'.format(seed))
|
||||||
|
util.set_random_seed(seed)
|
||||||
|
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
# torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
|
#### create train and val dataloader
|
||||||
|
dataset_ratio = 200 # enlarge the size of each epoch
|
||||||
|
for phase, dataset_opt in opt['datasets'].items():
|
||||||
|
if phase == 'train':
|
||||||
|
train_set = create_dataset(dataset_opt)
|
||||||
|
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
|
||||||
|
total_iters = int(opt['train']['niter'])
|
||||||
|
total_epochs = int(math.ceil(total_iters / train_size))
|
||||||
|
if opt['dist']:
|
||||||
|
train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
|
||||||
|
total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
||||||
|
else:
|
||||||
|
train_sampler = None
|
||||||
|
train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
|
||||||
|
if rank <= 0:
|
||||||
|
logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
||||||
|
len(train_set), train_size))
|
||||||
|
logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
||||||
|
total_epochs, total_iters))
|
||||||
|
elif phase == 'val':
|
||||||
|
val_set = create_dataset(dataset_opt)
|
||||||
|
val_loader = create_dataloader(val_set, dataset_opt, opt, None)
|
||||||
|
if rank <= 0:
|
||||||
|
logger.info('Number of val images in [{:s}]: {:d}'.format(
|
||||||
|
dataset_opt['name'], len(val_set)))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
|
||||||
|
assert train_loader is not None
|
||||||
|
|
||||||
|
#### create model
|
||||||
|
model = create_model(opt)
|
||||||
|
|
||||||
|
#### resume training
|
||||||
|
if resume_state:
|
||||||
|
logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
||||||
|
resume_state['epoch'], resume_state['iter']))
|
||||||
|
|
||||||
|
start_epoch = resume_state['epoch']
|
||||||
|
current_step = resume_state['iter']
|
||||||
|
model.resume_training(resume_state) # handle optimizers and schedulers
|
||||||
|
else:
|
||||||
|
current_step = 0
|
||||||
|
start_epoch = 0
|
||||||
|
|
||||||
|
#### training
|
||||||
|
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
|
||||||
|
for epoch in range(start_epoch, total_epochs + 1):
|
||||||
|
if opt['dist']:
|
||||||
|
train_sampler.set_epoch(epoch)
|
||||||
|
for _, train_data in enumerate(train_loader):
|
||||||
|
current_step += 1
|
||||||
|
if current_step > total_iters:
|
||||||
|
break
|
||||||
|
#### update learning rate
|
||||||
|
model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
|
||||||
|
|
||||||
|
#### training
|
||||||
|
model.feed_data(train_data)
|
||||||
|
model.optimize_parameters(current_step)
|
||||||
|
|
||||||
|
#### log
|
||||||
|
if current_step % opt['logger']['print_freq'] == 0:
|
||||||
|
logs = model.get_current_log()
|
||||||
|
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
||||||
|
for v in model.get_current_learning_rate():
|
||||||
|
message += '{:.3e},'.format(v)
|
||||||
|
message += ')] '
|
||||||
|
for k, v in logs.items():
|
||||||
|
message += '{:s}: {:.4e} '.format(k, v)
|
||||||
|
# tensorboard logger
|
||||||
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||||
|
if rank <= 0:
|
||||||
|
tb_logger.add_scalar(k, v, current_step)
|
||||||
|
if rank <= 0:
|
||||||
|
logger.info(message)
|
||||||
|
#### validation
|
||||||
|
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||||
|
if opt['model'] in ['sr', 'srgan'] and rank <= 0: # image restoration validation
|
||||||
|
# does not support multi-GPU validation
|
||||||
|
pbar = util.ProgressBar(len(val_loader))
|
||||||
|
avg_psnr = 0.
|
||||||
|
idx = 0
|
||||||
|
for val_data in val_loader:
|
||||||
|
idx += 1
|
||||||
|
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
|
||||||
|
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||||
|
util.mkdir(img_dir)
|
||||||
|
|
||||||
|
model.feed_data(val_data)
|
||||||
|
model.test()
|
||||||
|
|
||||||
|
visuals = model.get_current_visuals()
|
||||||
|
sr_img = util.tensor2img(visuals['rlt']) # uint8
|
||||||
|
gt_img = util.tensor2img(visuals['GT']) # uint8
|
||||||
|
|
||||||
|
# Save SR images for reference
|
||||||
|
save_img_path = os.path.join(img_dir,
|
||||||
|
'{:s}_{:d}.png'.format(img_name, current_step))
|
||||||
|
util.save_img(sr_img, save_img_path)
|
||||||
|
|
||||||
|
# calculate PSNR
|
||||||
|
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||||
|
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||||
|
pbar.update('Test {}'.format(img_name))
|
||||||
|
|
||||||
|
avg_psnr = avg_psnr / idx
|
||||||
|
|
||||||
|
# log
|
||||||
|
logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
|
||||||
|
# tensorboard logger
|
||||||
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||||
|
tb_logger.add_scalar('psnr', avg_psnr, current_step)
|
||||||
|
else: # video restoration validation
|
||||||
|
if opt['dist']:
|
||||||
|
# multi-GPU testing
|
||||||
|
psnr_rlt = {} # with border and center frames
|
||||||
|
if rank == 0:
|
||||||
|
pbar = util.ProgressBar(len(val_set))
|
||||||
|
for idx in range(rank, len(val_set), world_size):
|
||||||
|
val_data = val_set[idx]
|
||||||
|
val_data['LQs'].unsqueeze_(0)
|
||||||
|
val_data['GT'].unsqueeze_(0)
|
||||||
|
folder = val_data['folder']
|
||||||
|
idx_d, max_idx = val_data['idx'].split('/')
|
||||||
|
idx_d, max_idx = int(idx_d), int(max_idx)
|
||||||
|
if psnr_rlt.get(folder, None) is None:
|
||||||
|
psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32,
|
||||||
|
device='cuda')
|
||||||
|
# tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
|
||||||
|
model.feed_data(val_data)
|
||||||
|
model.test()
|
||||||
|
visuals = model.get_current_visuals()
|
||||||
|
rlt_img = util.tensor2img(visuals['rlt']) # uint8
|
||||||
|
gt_img = util.tensor2img(visuals['GT']) # uint8
|
||||||
|
# calculate PSNR
|
||||||
|
psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
for _ in range(world_size):
|
||||||
|
pbar.update('Test {} - {}/{}'.format(folder, idx_d, max_idx))
|
||||||
|
# # collect data
|
||||||
|
for _, v in psnr_rlt.items():
|
||||||
|
dist.reduce(v, 0)
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
psnr_rlt_avg = {}
|
||||||
|
psnr_total_avg = 0.
|
||||||
|
for k, v in psnr_rlt.items():
|
||||||
|
psnr_rlt_avg[k] = torch.mean(v).cpu().item()
|
||||||
|
psnr_total_avg += psnr_rlt_avg[k]
|
||||||
|
psnr_total_avg /= len(psnr_rlt)
|
||||||
|
log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
|
||||||
|
for k, v in psnr_rlt_avg.items():
|
||||||
|
log_s += ' {}: {:.4e}'.format(k, v)
|
||||||
|
logger.info(log_s)
|
||||||
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||||
|
tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
|
||||||
|
for k, v in psnr_rlt_avg.items():
|
||||||
|
tb_logger.add_scalar(k, v, current_step)
|
||||||
|
else:
|
||||||
|
pbar = util.ProgressBar(len(val_loader))
|
||||||
|
psnr_rlt = {} # with border and center frames
|
||||||
|
psnr_rlt_avg = {}
|
||||||
|
psnr_total_avg = 0.
|
||||||
|
for val_data in val_loader:
|
||||||
|
folder = val_data['folder'][0]
|
||||||
|
idx_d = val_data['idx'].item()
|
||||||
|
# border = val_data['border'].item()
|
||||||
|
if psnr_rlt.get(folder, None) is None:
|
||||||
|
psnr_rlt[folder] = []
|
||||||
|
|
||||||
|
model.feed_data(val_data)
|
||||||
|
model.test()
|
||||||
|
visuals = model.get_current_visuals()
|
||||||
|
rlt_img = util.tensor2img(visuals['rlt']) # uint8
|
||||||
|
gt_img = util.tensor2img(visuals['GT']) # uint8
|
||||||
|
|
||||||
|
# calculate PSNR
|
||||||
|
psnr = util.calculate_psnr(rlt_img, gt_img)
|
||||||
|
psnr_rlt[folder].append(psnr)
|
||||||
|
pbar.update('Test {} - {}'.format(folder, idx_d))
|
||||||
|
for k, v in psnr_rlt.items():
|
||||||
|
psnr_rlt_avg[k] = sum(v) / len(v)
|
||||||
|
psnr_total_avg += psnr_rlt_avg[k]
|
||||||
|
psnr_total_avg /= len(psnr_rlt)
|
||||||
|
log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
|
||||||
|
for k, v in psnr_rlt_avg.items():
|
||||||
|
log_s += ' {}: {:.4e}'.format(k, v)
|
||||||
|
logger.info(log_s)
|
||||||
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||||
|
tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
|
||||||
|
for k, v in psnr_rlt_avg.items():
|
||||||
|
tb_logger.add_scalar(k, v, current_step)
|
||||||
|
|
||||||
|
#### save models and training states
|
||||||
|
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||||
|
if rank <= 0:
|
||||||
|
logger.info('Saving models and training states.')
|
||||||
|
model.save(current_step)
|
||||||
|
model.save_training_state(epoch, current_step)
|
||||||
|
|
||||||
|
if rank <= 0:
|
||||||
|
logger.info('Saving the final model.')
|
||||||
|
model.save('latest')
|
||||||
|
logger.info('End of training.')
|
||||||
|
tb_logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
0
codes/utils/__init__.py
Normal file
0
codes/utils/__init__.py
Normal file
327
codes/utils/util.py
Normal file
327
codes/utils/util.py
Normal file
|
@ -0,0 +1,327 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import math
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from datetime import datetime
|
||||||
|
import random
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
from shutil import get_terminal_size
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
try:
|
||||||
|
from yaml import CLoader as Loader, CDumper as Dumper
|
||||||
|
except ImportError:
|
||||||
|
from yaml import Loader, Dumper
|
||||||
|
|
||||||
|
|
||||||
|
def OrderedYaml():
|
||||||
|
'''yaml orderedDict support'''
|
||||||
|
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
|
||||||
|
|
||||||
|
def dict_representer(dumper, data):
|
||||||
|
return dumper.represent_dict(data.items())
|
||||||
|
|
||||||
|
def dict_constructor(loader, node):
|
||||||
|
return OrderedDict(loader.construct_pairs(node))
|
||||||
|
|
||||||
|
Dumper.add_representer(OrderedDict, dict_representer)
|
||||||
|
Loader.add_constructor(_mapping_tag, dict_constructor)
|
||||||
|
return Loader, Dumper
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# miscellaneous
|
||||||
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
def get_timestamp():
|
||||||
|
return datetime.now().strftime('%y%m%d-%H%M%S')
|
||||||
|
|
||||||
|
|
||||||
|
def mkdir(path):
|
||||||
|
if not os.path.exists(path):
|
||||||
|
os.makedirs(path)
|
||||||
|
|
||||||
|
|
||||||
|
def mkdirs(paths):
|
||||||
|
if isinstance(paths, str):
|
||||||
|
mkdir(paths)
|
||||||
|
else:
|
||||||
|
for path in paths:
|
||||||
|
mkdir(path)
|
||||||
|
|
||||||
|
|
||||||
|
def mkdir_and_rename(path):
|
||||||
|
if os.path.exists(path):
|
||||||
|
new_name = path + '_archived_' + get_timestamp()
|
||||||
|
print('Path already exists. Rename it to [{:s}]'.format(new_name))
|
||||||
|
logger = logging.getLogger('base')
|
||||||
|
logger.info('Path already exists. Rename it to [{:s}]'.format(new_name))
|
||||||
|
os.rename(path, new_name)
|
||||||
|
os.makedirs(path)
|
||||||
|
|
||||||
|
|
||||||
|
def set_random_seed(seed):
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
|
||||||
|
'''set up logger'''
|
||||||
|
lg = logging.getLogger(logger_name)
|
||||||
|
formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
|
||||||
|
datefmt='%y-%m-%d %H:%M:%S')
|
||||||
|
lg.setLevel(level)
|
||||||
|
if tofile:
|
||||||
|
log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
|
||||||
|
fh = logging.FileHandler(log_file, mode='w')
|
||||||
|
fh.setFormatter(formatter)
|
||||||
|
lg.addHandler(fh)
|
||||||
|
if screen:
|
||||||
|
sh = logging.StreamHandler()
|
||||||
|
sh.setFormatter(formatter)
|
||||||
|
lg.addHandler(sh)
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# image convert
|
||||||
|
####################
|
||||||
|
def crop_border(img_list, crop_border):
|
||||||
|
"""Crop borders of images
|
||||||
|
Args:
|
||||||
|
img_list (list [Numpy]): HWC
|
||||||
|
crop_border (int): crop border for each end of height and weight
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(list [Numpy]): cropped image list
|
||||||
|
"""
|
||||||
|
if crop_border == 0:
|
||||||
|
return img_list
|
||||||
|
else:
|
||||||
|
return [v[crop_border:-crop_border, crop_border:-crop_border] for v in img_list]
|
||||||
|
|
||||||
|
|
||||||
|
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
|
||||||
|
'''
|
||||||
|
Converts a torch Tensor into an image Numpy array
|
||||||
|
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
|
||||||
|
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
|
||||||
|
'''
|
||||||
|
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
|
||||||
|
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
|
||||||
|
n_dim = tensor.dim()
|
||||||
|
if n_dim == 4:
|
||||||
|
n_img = len(tensor)
|
||||||
|
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
|
||||||
|
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
||||||
|
elif n_dim == 3:
|
||||||
|
img_np = tensor.numpy()
|
||||||
|
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
||||||
|
elif n_dim == 2:
|
||||||
|
img_np = tensor.numpy()
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
|
||||||
|
if out_type == np.uint8:
|
||||||
|
img_np = (img_np * 255.0).round()
|
||||||
|
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
|
||||||
|
return img_np.astype(out_type)
|
||||||
|
|
||||||
|
|
||||||
|
def save_img(img, img_path, mode='RGB'):
|
||||||
|
cv2.imwrite(img_path, img)
|
||||||
|
|
||||||
|
|
||||||
|
def DUF_downsample(x, scale=4):
|
||||||
|
"""Downsamping with Gaussian kernel used in the DUF official code
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor, [B, T, C, H, W]): frames to be downsampled.
|
||||||
|
scale (int): downsampling factor: 2 | 3 | 4.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale)
|
||||||
|
|
||||||
|
def gkern(kernlen=13, nsig=1.6):
|
||||||
|
import scipy.ndimage.filters as fi
|
||||||
|
inp = np.zeros((kernlen, kernlen))
|
||||||
|
# set element at the middle to one, a dirac delta
|
||||||
|
inp[kernlen // 2, kernlen // 2] = 1
|
||||||
|
# gaussian-smooth the dirac, resulting in a gaussian filter mask
|
||||||
|
return fi.gaussian_filter(inp, nsig)
|
||||||
|
|
||||||
|
B, T, C, H, W = x.size()
|
||||||
|
x = x.view(-1, 1, H, W)
|
||||||
|
pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 # 6 is the pad of the gaussian filter
|
||||||
|
r_h, r_w = 0, 0
|
||||||
|
if scale == 3:
|
||||||
|
r_h = 3 - (H % 3)
|
||||||
|
r_w = 3 - (W % 3)
|
||||||
|
x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], 'reflect')
|
||||||
|
|
||||||
|
gaussian_filter = torch.from_numpy(gkern(13, 0.4 * scale)).type_as(x).unsqueeze(0).unsqueeze(0)
|
||||||
|
x = F.conv2d(x, gaussian_filter, stride=scale)
|
||||||
|
x = x[:, :, 2:-2, 2:-2]
|
||||||
|
x = x.view(B, T, C, x.size(2), x.size(3))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def single_forward(model, inp):
|
||||||
|
"""PyTorch model forward (single test), it is just a simple warpper
|
||||||
|
Args:
|
||||||
|
model (PyTorch model)
|
||||||
|
inp (Tensor): inputs defined by the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output (Tensor): outputs of the model. float, in CPU
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
model_output = model(inp)
|
||||||
|
if isinstance(model_output, list) or isinstance(model_output, tuple):
|
||||||
|
output = model_output[0]
|
||||||
|
else:
|
||||||
|
output = model_output
|
||||||
|
output = output.data.float().cpu()
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def flipx4_forward(model, inp):
|
||||||
|
"""Flip testing with X4 self ensemble, i.e., normal, flip H, flip W, flip H and W
|
||||||
|
Args:
|
||||||
|
model (PyTorch model)
|
||||||
|
inp (Tensor): inputs defined by the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output (Tensor): outputs of the model. float, in CPU
|
||||||
|
"""
|
||||||
|
# normal
|
||||||
|
output_f = single_forward(model, inp)
|
||||||
|
|
||||||
|
# flip W
|
||||||
|
output = single_forward(model, torch.flip(inp, (-1, )))
|
||||||
|
output_f = output_f + torch.flip(output, (-1, ))
|
||||||
|
# flip H
|
||||||
|
output = single_forward(model, torch.flip(inp, (-2, )))
|
||||||
|
output_f = output_f + torch.flip(output, (-2, ))
|
||||||
|
# flip both H and W
|
||||||
|
output = single_forward(model, torch.flip(inp, (-2, -1)))
|
||||||
|
output_f = output_f + torch.flip(output, (-2, -1))
|
||||||
|
|
||||||
|
return output_f / 4
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# metric
|
||||||
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_psnr(img1, img2):
|
||||||
|
# img1 and img2 have range [0, 255]
|
||||||
|
img1 = img1.astype(np.float64)
|
||||||
|
img2 = img2.astype(np.float64)
|
||||||
|
mse = np.mean((img1 - img2)**2)
|
||||||
|
if mse == 0:
|
||||||
|
return float('inf')
|
||||||
|
return 20 * math.log10(255.0 / math.sqrt(mse))
|
||||||
|
|
||||||
|
|
||||||
|
def ssim(img1, img2):
|
||||||
|
C1 = (0.01 * 255)**2
|
||||||
|
C2 = (0.03 * 255)**2
|
||||||
|
|
||||||
|
img1 = img1.astype(np.float64)
|
||||||
|
img2 = img2.astype(np.float64)
|
||||||
|
kernel = cv2.getGaussianKernel(11, 1.5)
|
||||||
|
window = np.outer(kernel, kernel.transpose())
|
||||||
|
|
||||||
|
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
||||||
|
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||||
|
mu1_sq = mu1**2
|
||||||
|
mu2_sq = mu2**2
|
||||||
|
mu1_mu2 = mu1 * mu2
|
||||||
|
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||||
|
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||||
|
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||||
|
|
||||||
|
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
||||||
|
(sigma1_sq + sigma2_sq + C2))
|
||||||
|
return ssim_map.mean()
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_ssim(img1, img2):
|
||||||
|
'''calculate SSIM
|
||||||
|
the same outputs as MATLAB's
|
||||||
|
img1, img2: [0, 255]
|
||||||
|
'''
|
||||||
|
if not img1.shape == img2.shape:
|
||||||
|
raise ValueError('Input images must have the same dimensions.')
|
||||||
|
if img1.ndim == 2:
|
||||||
|
return ssim(img1, img2)
|
||||||
|
elif img1.ndim == 3:
|
||||||
|
if img1.shape[2] == 3:
|
||||||
|
ssims = []
|
||||||
|
for i in range(3):
|
||||||
|
ssims.append(ssim(img1, img2))
|
||||||
|
return np.array(ssims).mean()
|
||||||
|
elif img1.shape[2] == 1:
|
||||||
|
return ssim(np.squeeze(img1), np.squeeze(img2))
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong input image dimensions.')
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressBar(object):
|
||||||
|
'''A progress bar which can print the progress
|
||||||
|
modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, task_num=0, bar_width=50, start=True):
|
||||||
|
self.task_num = task_num
|
||||||
|
max_bar_width = self._get_max_bar_width()
|
||||||
|
self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
|
||||||
|
self.completed = 0
|
||||||
|
if start:
|
||||||
|
self.start()
|
||||||
|
|
||||||
|
def _get_max_bar_width(self):
|
||||||
|
terminal_width, _ = get_terminal_size()
|
||||||
|
max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
|
||||||
|
if max_bar_width < 10:
|
||||||
|
print('terminal width is too small ({}), please consider widen the terminal for better '
|
||||||
|
'progressbar visualization'.format(terminal_width))
|
||||||
|
max_bar_width = 10
|
||||||
|
return max_bar_width
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
if self.task_num > 0:
|
||||||
|
sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
|
||||||
|
' ' * self.bar_width, self.task_num, 'Start...'))
|
||||||
|
else:
|
||||||
|
sys.stdout.write('completed: 0, elapsed: 0s')
|
||||||
|
sys.stdout.flush()
|
||||||
|
self.start_time = time.time()
|
||||||
|
|
||||||
|
def update(self, msg='In progress...'):
|
||||||
|
self.completed += 1
|
||||||
|
elapsed = time.time() - self.start_time
|
||||||
|
fps = self.completed / elapsed
|
||||||
|
if self.task_num > 0:
|
||||||
|
percentage = self.completed / float(self.task_num)
|
||||||
|
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
|
||||||
|
mark_width = int(self.bar_width * percentage)
|
||||||
|
bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
|
||||||
|
sys.stdout.write('\033[2F') # cursor up 2 lines
|
||||||
|
sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
|
||||||
|
sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
|
||||||
|
bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
|
||||||
|
else:
|
||||||
|
sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
|
||||||
|
self.completed, int(elapsed + 0.5), fps))
|
||||||
|
sys.stdout.flush()
|
Loading…
Reference in New Issue
Block a user