XintaoWang 2019-08-23 21:42:47 +08:00
parent 58b175161c
commit 037933ba66
66 changed files with 8121 additions and 0 deletions

6
.flake8 Normal file

@@ -0,0 +1,6 @@
[flake8]
ignore =
# Too many leading '#' for block comment (E266)
E266
max-line-length=100

121
.gitignore vendored Normal file

@@ -0,0 +1,121 @@
experiments/*
results/*
tb_logger/*
datasets/*
.vscode
*.html
*.png
*.jpg
*.gif
*.pth
*.pytorch
*.zip
*.cu
# template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/

4
.style.yapf Normal file

@@ -0,0 +1,4 @@
[style]
BASED_ON_STYLE = pep8
COLUMN_LIMIT = 100
SPLIT_BEFORE_NAMED_ASSIGNS = false

47
README.md Normal file

@@ -0,0 +1,47 @@
# MMSR
MMSR is an open-source image and video super-resolution toolbox based on PyTorch. It is part of the [open-mmlab](https://github.com/open-mmlab) project developed by [Multimedia Laboratory, CUHK](http://mmlab.ie.cuhk.edu.hk/). MMSR is based on our previous projects: [BasicSR](https://github.com/xinntao/BasicSR), [ESRGAN](https://github.com/xinntao/ESRGAN) and [EDVR](https://github.com/xinntao/EDVR).
### Highlights
- **A unified framework** suitable for image and video super-resolution tasks. It is also easy to adapt to other restoration tasks such as deblurring and denoising.
- **State of the art**: It includes several methods that won restoration competitions, such as ESRGAN (PIRM18) and EDVR (NTIRE19).
- **Easy to extend**: New research ideas can be tried quickly on top of the codebase.
### Updates
[2019-07-25] MMSR v0.1 is released.
## Dependencies and Installation
- Python 3 (we recommend [Anaconda](https://www.anaconda.com/download/#linux))
- [PyTorch >= 1.0](https://pytorch.org/)
- NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
- [Deformable Convolution](https://arxiv.org/abs/1703.06211). We use the DCN implementation from [mmdetection](https://github.com/open-mmlab/mmdetection). Please compile it first:
```
cd ./codes/models/archs/dcn
python setup.py develop
```
- Python packages: `pip install numpy opencv-python lmdb pyyaml`
- TensorBoard:
- PyTorch >= 1.1: `pip install tb-nightly future`
- PyTorch == 1.0: `pip install tensorboardX`
## Dataset Preparation
We use datasets in LMDB format for faster IO. Please refer to [DATASETS.md](datasets/DATASETS.md) for more details.
## Training and Testing
Please see [wiki - Training and Testing](https://github.com/open-mmlab/mmsr/wiki/Training-and-Testing) for the basic usage, *i.e.,* training and testing.
## Model Zoo and Baselines
Results and pre-trained models are available in the [wiki-Model Zoo](https://github.com/open-mmlab/mmsr/wiki/Model-Zoo).
## Contributing
We appreciate all contributions. Please refer to [mmdetection](https://github.com/open-mmlab/mmdetection/blob/master/CONTRIBUTING.md) for the contributing guidelines.
**Python code style**<br/>
We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style. We use [flake8](http://flake8.pycqa.org/en/latest/) as the linter and [yapf](https://github.com/google/yapf) as the formatter. Please upgrade to the latest yapf (>=0.27.0) and refer to the [yapf configuration](.style.yapf) and [flake8 configuration](.flake8).
> Before you create a PR, make sure that your code lints and is formatted by yapf.
## License
This project is released under the Apache 2.0 license.

127
codes/data/LQGT_dataset.py Normal file

@@ -0,0 +1,127 @@
import random
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
class LQGTDataset(data.Dataset):
"""
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
If only GT images are provided, generate LQ images on-the-fly.
"""
def __init__(self, opt):
super(LQGTDataset, self).__init__()
self.opt = opt
self.data_type = self.opt['data_type']
self.paths_LQ, self.paths_GT = None, None
self.sizes_LQ, self.sizes_GT = None, None
self.LQ_env, self.GT_env = None, None # environments for lmdb
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
assert self.paths_GT, 'Error: GT path is empty.'
if self.paths_LQ and self.paths_GT:
assert len(self.paths_LQ) == len(
self.paths_GT
), 'GT and LQ datasets have different number of images - {}, {}.'.format(
len(self.paths_LQ), len(self.paths_GT))
self.random_scale_list = [1]
def _init_lmdb(self):
# https://github.com/chainer/chainermn/issues/129
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
meminit=False)
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
def __getitem__(self, index):
if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
self._init_lmdb()
GT_path, LQ_path = None, None
scale = self.opt['scale']
GT_size = self.opt['GT_size']
# get GT image
GT_path = self.paths_GT[index]
resolution = [int(s) for s in self.sizes_GT[index].split('_')
] if self.data_type == 'lmdb' else None
img_GT = util.read_img(self.GT_env, GT_path, resolution)
if self.opt['phase'] != 'train': # modcrop in the validation / test phase
img_GT = util.modcrop(img_GT, scale)
if self.opt['color']: # change color space if necessary
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
# get LQ image
if self.paths_LQ:
LQ_path = self.paths_LQ[index]
resolution = [int(s) for s in self.sizes_LQ[index].split('_')
] if self.data_type == 'lmdb' else None
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
else: # down-sampling on-the-fly
# randomly scale during training
if self.opt['phase'] == 'train':
random_scale = random.choice(self.random_scale_list)
H_s, W_s, _ = img_GT.shape
def _mod(n, random_scale, scale, thres):
rlt = int(n * random_scale)
rlt = (rlt // scale) * scale
return thres if rlt < thres else rlt
H_s = _mod(H_s, random_scale, scale, GT_size)
W_s = _mod(W_s, random_scale, scale, GT_size)
img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
if img_GT.ndim == 2:
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
H, W, _ = img_GT.shape
# using matlab imresize
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)
if self.opt['phase'] == 'train':
# if the image size is too small
H, W, _ = img_GT.shape
if H < GT_size or W < GT_size:
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
# using matlab imresize
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)
H, W, C = img_LQ.shape
LQ_size = GT_size // scale
# randomly crop
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
# augmentation - flip, rotate
img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
self.opt['use_rot'])
if self.opt['color']: # change color space if necessary
img_LQ = util.channel_convert(C, self.opt['color'],
[img_LQ])[0] # TODO: C is not defined when this branch runs during validation
# BGR to RGB, HWC to CHW, numpy to tensor
if img_GT.shape[2] == 3:
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQ = img_LQ[:, :, [2, 1, 0]]
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
if LQ_path is None:
LQ_path = GT_path
return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
def __len__(self):
return len(self.paths_GT)
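
A minimal usage sketch for `LQGTDataset` (not part of the commit; the option keys mirror those read in `__init__`/`__getitem__` above, while the paths and sizes are illustrative):

```
opt = {
    'data_type': 'img',  # or 'lmdb'
    'dataroot_GT': '../datasets/DIV2K/DIV2K800_sub',          # illustrative path
    'dataroot_LQ': '../datasets/DIV2K/DIV2K800_sub_bicLRx4',  # illustrative path
    'phase': 'train',
    'scale': 4,
    'GT_size': 128,
    'use_flip': True,
    'use_rot': True,
    'color': None,
}
dataset = LQGTDataset(opt)
sample = dataset[0]
print(sample['GT'].shape)  # torch.Size([3, 128, 128])
print(sample['LQ'].shape)  # torch.Size([3, 32, 32]) -- GT_size // scale
```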

47
codes/data/LQ_dataset.py Normal file

@@ -0,0 +1,47 @@
import numpy as np
import lmdb
import torch
import torch.utils.data as data
import data.util as util
class LQDataset(data.Dataset):
'''Read LQ images only in the test phase.'''
def __init__(self, opt):
super(LQDataset, self).__init__()
self.opt = opt
self.data_type = self.opt['data_type']
self.paths_LQ, self.paths_GT = None, None
self.LQ_env = None # environment for lmdb
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
assert self.paths_LQ, 'Error: LQ paths are empty.'
def _init_lmdb(self):
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
def __getitem__(self, index):
if self.data_type == 'lmdb' and self.LQ_env is None:
self._init_lmdb()
LQ_path = None
# get LQ image
LQ_path = self.paths_LQ[index]
resolution = [int(s) for s in self.sizes_LQ[index].split('_')
] if self.data_type == 'lmdb' else None
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
H, W, C = img_LQ.shape
if self.opt['color']: # change color space if necessary
img_LQ = util.channel_convert(C, self.opt['color'], [img_LQ])[0]
# BGR to RGB, HWC to CHW, numpy to tensor
if img_LQ.shape[2] == 3:
img_LQ = img_LQ[:, :, [2, 1, 0]]
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
return {'LQ': img_LQ, 'LQ_path': LQ_path}
def __len__(self):
return len(self.paths_LQ)
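
A test-phase sketch for `LQDataset` (hypothetical paths; the keys match the code above):

```
opt = {
    'data_type': 'img',
    'dataroot_LQ': '../datasets/Set14_bicLRx4',  # illustrative path
    'color': None,
}
test_set = LQDataset(opt)
item = test_set[0]
print(item['LQ'].shape, item['LQ_path'])  # CHW float tensor in [0, 1], plus its path
```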

210
codes/data/REDS_dataset.py Normal file

@@ -0,0 +1,210 @@
'''
REDS dataset
support reading images from lmdb, image folder and memcached
'''
import os.path as osp
import random
import pickle
import logging
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
try:
import mc # import memcached
except ImportError:
pass
logger = logging.getLogger('base')
class REDSDataset(data.Dataset):
'''
Reading the training REDS dataset
key example: 000_00000000
GT: Ground-Truth;
LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames
support reading N LQ frames, N = 1, 3, 5, 7
'''
def __init__(self, opt):
super(REDSDataset, self).__init__()
self.opt = opt
# temporal augmentation
self.interval_list = opt['interval_list']
self.random_reverse = opt['random_reverse']
logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
','.join(str(x) for x in opt['interval_list']), self.random_reverse))
self.half_N_frames = opt['N_frames'] // 2
self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
self.data_type = self.opt['data_type']
self.LR_input = opt['GT_size'] != opt['LQ_size'] # whether inputs are low-resolution
#### directly load image keys
if self.data_type == 'lmdb':
self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
logger.info('Using lmdb meta info for cache keys.')
elif opt['cache_keys']:
logger.info('Using cache keys: {}'.format(opt['cache_keys']))
self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
else:
raise ValueError(
'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
# remove the REDS4 for testing
self.paths_GT = [
v for v in self.paths_GT if v.split('_')[0] not in ['000', '011', '015', '020']
]
assert self.paths_GT, 'Error: GT path is empty.'
if self.data_type == 'lmdb':
self.GT_env, self.LQ_env = None, None
elif self.data_type == 'mc': # memcached
self.mclient = None
elif self.data_type == 'img':
pass
else:
raise ValueError('Wrong data type: {}'.format(self.data_type))
def _init_lmdb(self):
# https://github.com/chainer/chainermn/issues/129
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
meminit=False)
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
def _ensure_memcached(self):
if self.mclient is None:
# specify the config files
server_list_config_file = None
client_config_file = None
self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
client_config_file)
def _read_img_mc(self, path):
''' Return BGR, HWC, [0, 255], uint8'''
value = mc.pyvector()
self.mclient.Get(path, value)
value_buf = mc.ConvertBuffer(value)
img_array = np.frombuffer(value_buf, np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
return img
def _read_img_mc_BGR(self, path, name_a, name_b):
''' Read the B, G and R channels separately and merge them, to stay under the 1M value-size limit on the cluster'''
img_B = self._read_img_mc(osp.join(path + '_B', name_a, name_b + '.png'))
img_G = self._read_img_mc(osp.join(path + '_G', name_a, name_b + '.png'))
img_R = self._read_img_mc(osp.join(path + '_R', name_a, name_b + '.png'))
img = cv2.merge((img_B, img_G, img_R))
return img
def __getitem__(self, index):
if self.data_type == 'mc':
self._ensure_memcached()
elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
self._init_lmdb()
scale = self.opt['scale']
GT_size = self.opt['GT_size']
key = self.paths_GT[index]
name_a, name_b = key.split('_')
center_frame_idx = int(name_b)
#### determine the neighbor frames
interval = random.choice(self.interval_list)
if self.opt['border_mode']:
direction = 1 # 1: forward; 0: backward
N_frames = self.opt['N_frames']
if self.random_reverse and random.random() < 0.5:
direction = random.choice([0, 1])
if center_frame_idx + interval * (N_frames - 1) > 99:
direction = 0
elif center_frame_idx - interval * (N_frames - 1) < 0:
direction = 1
# get the neighbor list
if direction == 1:
neighbor_list = list(
range(center_frame_idx, center_frame_idx + interval * N_frames, interval))
else:
neighbor_list = list(
range(center_frame_idx, center_frame_idx - interval * N_frames, -interval))
name_b = '{:08d}'.format(neighbor_list[0])
else:
# ensure not exceeding the borders
while (center_frame_idx + self.half_N_frames * interval >
99) or (center_frame_idx - self.half_N_frames * interval < 0):
center_frame_idx = random.randint(0, 99)
# get the neighbor list
neighbor_list = list(
range(center_frame_idx - self.half_N_frames * interval,
center_frame_idx + self.half_N_frames * interval + 1, interval))
if self.random_reverse and random.random() < 0.5:
neighbor_list.reverse()
name_b = '{:08d}'.format(neighbor_list[self.half_N_frames])
assert len(
neighbor_list) == self.opt['N_frames'], 'Wrong length of neighbor list: {}'.format(
len(neighbor_list))
#### get the GT image (as the center frame)
if self.data_type == 'mc':
img_GT = self._read_img_mc_BGR(self.GT_root, name_a, name_b)
img_GT = img_GT.astype(np.float32) / 255.
elif self.data_type == 'lmdb':
img_GT = util.read_img(self.GT_env, key, (3, 720, 1280))
else:
img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b + '.png'))
#### get LQ images
LQ_size_tuple = (3, 180, 320) if self.LR_input else (3, 720, 1280)
img_LQ_l = []
for v in neighbor_list:
img_LQ_path = osp.join(self.LQ_root, name_a, '{:08d}.png'.format(v))
if self.data_type == 'mc':
if self.LR_input:
img_LQ = self._read_img_mc(img_LQ_path)
else:
img_LQ = self._read_img_mc_BGR(self.LQ_root, name_a, '{:08d}'.format(v))
img_LQ = img_LQ.astype(np.float32) / 255.
elif self.data_type == 'lmdb':
img_LQ = util.read_img(self.LQ_env, '{}_{:08d}'.format(name_a, v), LQ_size_tuple)
else:
img_LQ = util.read_img(None, img_LQ_path)
img_LQ_l.append(img_LQ)
if self.opt['phase'] == 'train':
C, H, W = LQ_size_tuple # LQ size
# randomly crop
if self.LR_input:
LQ_size = GT_size // scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
else:
rnd_h = random.randint(0, max(0, H - GT_size))
rnd_w = random.randint(0, max(0, W - GT_size))
img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
# augmentation - flip, rotate
img_LQ_l.append(img_GT)
rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
img_LQ_l = rlt[0:-1]
img_GT = rlt[-1]
# stack LQ images to NHWC, N is the frame number
img_LQs = np.stack(img_LQ_l, axis=0)
# BGR to RGB, HWC to CHW, numpy to tensor
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
(0, 3, 1, 2)))).float()
return {'LQs': img_LQs, 'GT': img_GT, 'key': key}
def __len__(self):
return len(self.paths_GT)
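
For intuition, here is the non-border neighbor-frame computation from `__getitem__` above, traced with concrete values (a sketch, not part of the file):

```
half_N_frames, center_frame_idx, interval = 2, 40, 2  # N_frames = 5
neighbor_list = list(
    range(center_frame_idx - half_N_frames * interval,
          center_frame_idx + half_N_frames * interval + 1, interval))
print(neighbor_list)  # [36, 38, 40, 42, 44] -- five frames centered on 40
```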

167
codes/data/Vimeo90K_dataset.py Normal file

@@ -0,0 +1,167 @@
'''
Vimeo90K dataset
support reading images from lmdb, image folder and memcached
'''
import os.path as osp
import random
import pickle
import logging
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
try:
import mc # import memcached
except ImportError:
pass
logger = logging.getLogger('base')
class Vimeo90KDataset(data.Dataset):
'''
Reading the training Vimeo90K dataset
key example: 00001_0001 (_1, ..., _7)
GT (Ground-Truth): 4th frame;
LQ (Low-Quality): support reading N LQ frames, N = 1, 3, 5, 7 centered with 4th frame
'''
def __init__(self, opt):
super(Vimeo90KDataset, self).__init__()
self.opt = opt
# temporal augmentation
self.interval_list = opt['interval_list']
self.random_reverse = opt['random_reverse']
logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
','.join(str(x) for x in opt['interval_list']), self.random_reverse))
self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
self.data_type = self.opt['data_type']
self.LR_input = opt['GT_size'] != opt['LQ_size'] # whether inputs are low-resolution
#### determine the LQ frame list
'''
N | frames
1 | 4
3 | 3,4,5
5 | 2,3,4,5,6
7 | 1,2,3,4,5,6,7
'''
self.LQ_frames_list = []
for i in range(opt['N_frames']):
self.LQ_frames_list.append(i + (9 - opt['N_frames']) // 2)
#### directly load image keys
if self.data_type == 'lmdb':
self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
logger.info('Using lmdb meta info for cache keys.')
elif opt['cache_keys']:
logger.info('Using cache keys: {}'.format(opt['cache_keys']))
self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
else:
raise ValueError(
'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
assert self.paths_GT, 'Error: GT path is empty.'
if self.data_type == 'lmdb':
self.GT_env, self.LQ_env = None, None
elif self.data_type == 'mc': # memcached
self.mclient = None
elif self.data_type == 'img':
pass
else:
raise ValueError('Wrong data type: {}'.format(self.data_type))
def _init_lmdb(self):
# https://github.com/chainer/chainermn/issues/129
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
meminit=False)
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
def _ensure_memcached(self):
if self.mclient is None:
# specify the config files
server_list_config_file = None
client_config_file = None
self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
client_config_file)
def _read_img_mc(self, path):
''' Return BGR, HWC, [0, 255], uint8'''
value = mc.pyvector()
self.mclient.Get(path, value)
value_buf = mc.ConvertBuffer(value)
img_array = np.frombuffer(value_buf, np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
return img
def __getitem__(self, index):
if self.data_type == 'mc':
self._ensure_memcached()
elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
self._init_lmdb()
scale = self.opt['scale']
GT_size = self.opt['GT_size']
key = self.paths_GT[index]
name_a, name_b = key.split('_')
#### get the GT image (as the center frame)
if self.data_type == 'mc':
img_GT = self._read_img_mc(osp.join(self.GT_root, name_a, name_b, '4.png'))
img_GT = img_GT.astype(np.float32) / 255.
elif self.data_type == 'lmdb':
img_GT = util.read_img(self.GT_env, key + '_4', (3, 256, 448))
else:
img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b, 'im4.png'))
#### get LQ images
LQ_size_tuple = (3, 64, 112) if self.LR_input else (3, 256, 448)
img_LQ_l = []
for v in self.LQ_frames_list:
if self.data_type == 'mc':
img_LQ = self._read_img_mc(
osp.join(self.LQ_root, name_a, name_b, '{}.png'.format(v)))
img_LQ = img_LQ.astype(np.float32) / 255.
elif self.data_type == 'lmdb':
img_LQ = util.read_img(self.LQ_env, key + '_{}'.format(v), LQ_size_tuple)
else:
img_LQ = util.read_img(None,
osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v)))
img_LQ_l.append(img_LQ)
if self.opt['phase'] == 'train':
C, H, W = LQ_size_tuple # LQ size
# randomly crop
if self.LR_input:
LQ_size = GT_size // scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
else:
rnd_h = random.randint(0, max(0, H - GT_size))
rnd_w = random.randint(0, max(0, W - GT_size))
img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
# augmentation - flip, rotate
img_LQ_l.append(img_GT)
rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
img_LQ_l = rlt[0:-1]
img_GT = rlt[-1]
# stack LQ images to NHWC, N is the frame number
img_LQs = np.stack(img_LQ_l, axis=0)
# BGR to RGB, HWC to CHW, numpy to tensor
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
(0, 3, 1, 2)))).float()
return {'LQs': img_LQs, 'GT': img_GT, 'key': key}
def __len__(self):
return len(self.paths_GT)
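
The LQ frame indices computed in `__init__` for each supported `N_frames`, reproducing the table in the class body (a standalone check, not part of the file):

```
for n_frames in (1, 3, 5, 7):
    print(n_frames, [i + (9 - n_frames) // 2 for i in range(n_frames)])
# 1 [4]
# 3 [3, 4, 5]
# 5 [2, 3, 4, 5, 6]
# 7 [1, 2, 3, 4, 5, 6, 7]
```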

49
codes/data/__init__.py Normal file

@@ -0,0 +1,49 @@
"""create dataset and dataloader"""
import logging
import torch
import torch.utils.data
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
phase = dataset_opt['phase']
if phase == 'train':
if opt['dist']:
world_size = torch.distributed.get_world_size()
num_workers = dataset_opt['n_workers']
assert dataset_opt['batch_size'] % world_size == 0
batch_size = dataset_opt['batch_size'] // world_size
shuffle = False
else:
num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
batch_size = dataset_opt['batch_size']
shuffle = True
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
num_workers=num_workers, sampler=sampler, drop_last=True,
pin_memory=False)
else:
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
pin_memory=False)
def create_dataset(dataset_opt):
mode = dataset_opt['mode']
# datasets for image restoration
if mode == 'LQ':
from data.LQ_dataset import LQDataset as D
elif mode == 'LQGT':
from data.LQGT_dataset import LQGTDataset as D
# datasets for video restoration
elif mode == 'REDS':
from data.REDS_dataset import REDSDataset as D
elif mode == 'Vimeo90K':
from data.Vimeo90K_dataset import Vimeo90KDataset as D
elif mode == 'video_test':
from data.video_test_dataset import VideoTestDataset as D
else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt)
logger = logging.getLogger('base')
logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
dataset_opt['name']))
return dataset
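
Putting the two factories together; a sketch with illustrative option values (the keys are the ones the factories and `LQGTDataset` actually read):

```
dataset_opt = {
    'name': 'DIV2K', 'mode': 'LQGT', 'phase': 'train',
    'data_type': 'img', 'scale': 4, 'GT_size': 128,
    'dataroot_GT': '../datasets/DIV2K/DIV2K800_sub',          # illustrative
    'dataroot_LQ': '../datasets/DIV2K/DIV2K800_sub_bicLRx4',  # illustrative
    'use_flip': True, 'use_rot': True, 'color': None,
    'n_workers': 6, 'batch_size': 16,
}
opt = {'dist': False, 'gpu_ids': [0]}
train_set = create_dataset(dataset_opt)
train_loader = create_dataloader(train_set, dataset_opt, opt, sampler=None)
```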

65
codes/data/data_sampler.py Normal file

@@ -0,0 +1,65 @@
"""
Modified from torch.utils.data.distributed.DistributedSampler
Supports enlarging the dataset for *iteration-oriented* training, saving the time of
restarting the dataloader after each epoch
"""
import math
import torch
from torch.utils.data.sampler import Sampler
import torch.distributed as dist
class DistIterSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Arguments:
dataset: Dataset used for sampling.
num_replicas (optional): Number of processes participating in
distributed training.
rank (optional): Rank of the current process within num_replicas.
"""
def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(self.total_size, generator=g).tolist()
dsize = len(self.dataset)
indices = [v % dsize for v in indices]
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
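
A minimal sketch of `DistIterSampler` on a stand-in dataset; passing `num_replicas`/`rank` explicitly means `torch.distributed` need not be initialized here:

```
import torch
from torch.utils.data import DataLoader, TensorDataset

train_set = TensorDataset(torch.arange(1000))  # stand-in dataset, 1000 items
sampler = DistIterSampler(train_set, num_replicas=4, rank=0, ratio=100)
print(len(sampler))  # ceil(1000 * 100 / 4) = 25000 indices per rank per pass
loader = DataLoader(train_set, batch_size=16, sampler=sampler)
for epoch in range(2):
    sampler.set_epoch(epoch)  # re-seeds the deterministic shuffle
    for _batch in loader:
        break
```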

543
codes/data/util.py Normal file

@@ -0,0 +1,543 @@
import os
import math
import pickle
import random
import numpy as np
import glob
import torch
import cv2
####################
# Files & IO
####################
###################### get image path list ######################
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def _get_paths_from_images(path):
"""get image path list from image folder"""
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
images = []
for dirpath, _, fnames in sorted(os.walk(path)):
for fname in sorted(fnames):
if is_image_file(fname):
img_path = os.path.join(dirpath, fname)
images.append(img_path)
assert images, '{:s} has no valid image file'.format(path)
return images
def _get_paths_from_lmdb(dataroot):
"""get image path list from lmdb meta info"""
meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
paths = meta_info['keys']
sizes = meta_info['resolution']
if len(sizes) == 1:
sizes = sizes * len(paths)
return paths, sizes
def get_image_paths(data_type, dataroot):
"""get image path list
support lmdb or image files"""
paths, sizes = None, None
if dataroot is not None:
if data_type == 'lmdb':
paths, sizes = _get_paths_from_lmdb(dataroot)
elif data_type == 'img':
paths = sorted(_get_paths_from_images(dataroot))
else:
raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
return paths, sizes
def glob_file_list(root):
return sorted(glob.glob(os.path.join(root, '*')))
###################### read images ######################
def _read_img_lmdb(env, key, size):
"""read image from lmdb with key (w/ and w/o fixed size)
size: (C, H, W) tuple"""
with env.begin(write=False) as txn:
buf = txn.get(key.encode('ascii'))
img_flat = np.frombuffer(buf, dtype=np.uint8)
C, H, W = size
img = img_flat.reshape(H, W, C)
return img
def read_img(env, path, size=None):
"""read image by cv2 or from lmdb
return: Numpy float32, HWC, BGR, [0,1]"""
if env is None: # img
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
else:
img = _read_img_lmdb(env, path, size)
img = img.astype(np.float32) / 255.
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
return img
def read_img_seq(path):
"""Read a sequence of images from a given folder path
Args:
path (list/str): list of image paths/image folder path
Returns:
imgs (Tensor): size (T, C, H, W), RGB, [0, 1]
"""
if type(path) is list:
img_path_l = path
else:
img_path_l = sorted(glob.glob(os.path.join(path, '*')))
img_l = [read_img(None, v) for v in img_path_l]
# stack to Torch tensor
imgs = np.stack(img_l, axis=0)
imgs = imgs[:, :, :, [2, 1, 0]]
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
return imgs
def index_generation(crt_i, max_n, N, padding='reflection'):
"""Generate an index list for reading N frames from a sequence of images
Args:
crt_i (int): current center index
max_n (int): total number of images in the sequence (counted from 1)
N (int): reading N frames
padding (str): padding mode, one of replicate | reflection | new_info | circle
Example: crt_i = 0, N = 5
replicate: [0, 0, 0, 1, 2]
reflection: [2, 1, 0, 1, 2]
new_info: [4, 3, 0, 1, 2]
circle: [3, 4, 0, 1, 2]
Returns:
return_l (list [int]): a list of indexes
"""
max_n = max_n - 1
n_pad = N // 2
return_l = []
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
if i < 0:
if padding == 'replicate':
add_idx = 0
elif padding == 'reflection':
add_idx = -i
elif padding == 'new_info':
add_idx = (crt_i + n_pad) + (-i)
elif padding == 'circle':
add_idx = N + i
else:
raise ValueError('Wrong padding mode')
elif i > max_n:
if padding == 'replicate':
add_idx = max_n
elif padding == 'reflection':
add_idx = max_n * 2 - i
elif padding == 'new_info':
add_idx = (crt_i - n_pad) - (i - max_n)
elif padding == 'circle':
add_idx = i - N
else:
raise ValueError('Wrong padding mode')
else:
add_idx = i
return_l.append(add_idx)
return return_l
####################
# image processing
# process on numpy image
####################
def augment(img_list, hflip=True, rot=True):
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(img) for img in img_list]
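# Note (added for clarity, not in the original file): all images passed to
# augment() in one call share the same randomly drawn hflip/vflip/rot90 flags,
# so paired LQ/GT images stay spatially aligned after augmentation.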
def augment_flow(img_list, flow_list, hflip=True, rot=True):
"""horizontal flip OR rotate (0, 90, 180, 270 degrees) with flows"""
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
def _augment_flow(flow):
if hflip:
flow = flow[:, ::-1, :]
flow[:, :, 0] *= -1
if vflip:
flow = flow[::-1, :, :]
flow[:, :, 1] *= -1
if rot90:
flow = flow.transpose(1, 0, 2)
flow = flow[:, :, [1, 0]]
return flow
rlt_img_list = [_augment(img) for img in img_list]
rlt_flow_list = [_augment_flow(flow) for flow in flow_list]
return rlt_img_list, rlt_flow_list
def channel_convert(in_c, tar_type, img_list):
"""conversion among BGR, gray and y"""
if in_c == 3 and tar_type == 'gray': # BGR to gray
gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
return [np.expand_dims(img, axis=2) for img in gray_list]
elif in_c == 3 and tar_type == 'y': # BGR to y
y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
return [np.expand_dims(img, axis=2) for img in y_list]
elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
else:
return img_list
def rgb2ycbcr(img, only_y=True):
"""same as matlab rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
# convert
if only_y:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
else:
rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
[24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
"""bgr version of rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
# convert
if only_y:
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
else:
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
return rlt.astype(in_img_type)
def ycbcr2rgb(img):
"""same as matlab ycbcr2rgb
Input:
uint8, [0, 255]
float, [0, 1]
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
# convert
rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
return rlt.astype(in_img_type)
def modcrop(img_in, scale):
"""img_in: Numpy, HWC or HW"""
img = np.copy(img_in)
if img.ndim == 2:
H, W = img.shape
H_r, W_r = H % scale, W % scale
img = img[:H - H_r, :W - W_r]
elif img.ndim == 3:
H, W, C = img.shape
H_r, W_r = H % scale, W % scale
img = img[:H - H_r, :W - W_r, :]
else:
raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
return img
####################
# Functions
####################
# matlab 'imresize' function; currently only 'bicubic' is supported
def cubic(x):
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5 * absx3 - 2.5 * absx2 + 1) * (
(absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((
(absx > 1) * (absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
if (scale < 1) and (antialiasing):
# Use a modified kernel (with a larger width) to simultaneously interpolate and antialias
kernel_width = kernel_width / scale
# Output-space coordinates
x = torch.linspace(1, out_length, out_length)
# Input-space coordinates. Calculate the inverse mapping such that 0.5
# in output space maps to 0.5 in input space, and 0.5+scale in output
# space maps to 1.5 in input space.
u = x / scale + 0.5 * (1 - 1 / scale)
# What is the left-most pixel that can be involved in the computation?
left = torch.floor(u - kernel_width / 2)
# What is the maximum number of pixels that can be involved in the
# computation? Note: it's OK to use an extra pixel here; if the
# corresponding weights are all zero, it will be eliminated at the end
# of this function.
P = math.ceil(kernel_width) + 2
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
1, P).expand(out_length, P)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
# apply cubic kernel
if (scale < 1) and (antialiasing):
weights = scale * cubic(distance_to_center * scale)
else:
weights = cubic(distance_to_center)
# Normalize the weights matrix so that each row sums to 1.
weights_sum = torch.sum(weights, 1).view(out_length, 1)
weights = weights / weights_sum.expand(out_length, P)
# If a column in weights is all zero, get rid of it. only consider the first and last column.
weights_zero_tmp = torch.sum((weights == 0), 0)
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
indices = indices.narrow(1, 1, P - 2)
weights = weights.narrow(1, 1, P - 2)
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
indices = indices.narrow(1, 0, P - 2)
weights = weights.narrow(1, 0, P - 2)
weights = weights.contiguous()
indices = indices.contiguous()
sym_len_s = -indices.min() + 1
sym_len_e = indices.max() - in_length
indices = indices + sym_len_s - 1
return weights, indices, int(sym_len_s), int(sym_len_e)
def imresize(img, scale, antialiasing=True):
# Now the scale should be the same for H and W
# input: img: CHW RGB [0,1]
# output: CHW RGB [0,1] w/o round
in_C, in_H, in_W = img.size()
_, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
kernel_width = 4
kernel = 'cubic'
# Return the desired dimension order for performing the resize. The
# strategy is to perform the resize first along the dimension with the
# smallest scale factor.
# Now we do not support this.
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
sym_patch = img[:, :sym_len_Hs, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
sym_patch = img[:, -sym_len_He:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(in_C, out_H, in_W)
kernel_width = weights_H.size(1)
for i in range(out_H):
idx = int(indices_H[i][0])
out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
sym_patch = out_1[:, :, :sym_len_Ws]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
sym_patch = out_1[:, :, -sym_len_We:]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(in_C, out_H, out_W)
kernel_width = weights_W.size(1)
for i in range(out_W):
idx = int(indices_W[i][0])
out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])
return out_2
def imresize_np(img, scale, antialiasing=True):
# Now the scale should be the same for H and W
# input: img: Numpy, HWC BGR [0,1]
# output: HWC BGR [0,1] w/o round
img = torch.from_numpy(img)
in_H, in_W, in_C = img.size()
_, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
kernel_width = 4
kernel = 'cubic'
# Return the desired dimension order for performing the resize. The
# strategy is to perform the resize first along the dimension with the
# smallest scale factor.
# Now we do not support this.
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
sym_patch = img[:sym_len_Hs, :, :]
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(0, inv_idx)
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
sym_patch = img[-sym_len_He:, :, :]
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(0, inv_idx)
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(out_H, in_W, in_C)
kernel_width = weights_H.size(1)
for i in range(out_H):
idx = int(indices_H[i][0])
out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
sym_patch = out_1[:, :sym_len_Ws, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
sym_patch = out_1[:, -sym_len_We:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(out_H, out_W, in_C)
kernel_width = weights_W.size(1)
for i in range(out_W):
idx = int(indices_W[i][0])
out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])
return out_2.numpy()
if __name__ == '__main__':
# test imresize function
# read images
img = cv2.imread('test.png')
img = img * 1.0 / 255
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
# imresize
scale = 1 / 4
import time
total_time = 0
for i in range(10):
start_time = time.time()
rlt = imresize(img, scale, antialiasing=True)
use_time = time.time() - start_time
total_time += use_time
print('average time: {}'.format(total_time / 10))
import torchvision.utils
torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0,
normalize=False)
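
A quick check of `index_generation` against its docstring examples (a standalone sketch, assuming it is run from the `codes` directory so that `data.util` is importable):

```
from data.util import index_generation

print(index_generation(0, 100, 5, padding='replicate'))   # [0, 0, 0, 1, 2]
print(index_generation(0, 100, 5, padding='reflection'))  # [2, 1, 0, 1, 2]
print(index_generation(0, 100, 5, padding='new_info'))    # [4, 3, 0, 1, 2]
print(index_generation(0, 100, 5, padding='circle'))      # [3, 4, 0, 1, 2]
```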

84
codes/data/video_test_dataset.py Normal file

@@ -0,0 +1,84 @@
import os.path as osp
import torch
import torch.utils.data as data
import data.util as util
class VideoTestDataset(data.Dataset):
"""
A video test dataset. Support:
Vid4
REDS4
Vimeo90K-Test
no need to prepare LMDB files
"""
def __init__(self, opt):
super(VideoTestDataset, self).__init__()
self.opt = opt
self.cache_data = opt['cache_data']
self.half_N_frames = opt['N_frames'] // 2
self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
self.data_type = self.opt['data_type']
self.data_info = {'path_LQ': [], 'path_GT': [], 'folder': [], 'idx': [], 'border': []}
if self.data_type == 'lmdb':
raise ValueError('No need to use LMDB during validation/test.')
#### Generate data info and cache data
self.imgs_LQ, self.imgs_GT = {}, {}
if opt['name'].lower() in ['vid4', 'reds4']:
subfolders_LQ = util.glob_file_list(self.LQ_root)
subfolders_GT = util.glob_file_list(self.GT_root)
for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT):
subfolder_name = osp.basename(subfolder_GT)
img_paths_LQ = util.glob_file_list(subfolder_LQ)
img_paths_GT = util.glob_file_list(subfolder_GT)
max_idx = len(img_paths_LQ)
assert max_idx == len(
img_paths_GT), 'Different number of images in LQ and GT folders'
self.data_info['path_LQ'].extend(img_paths_LQ)
self.data_info['path_GT'].extend(img_paths_GT)
self.data_info['folder'].extend([subfolder_name] * max_idx)
for i in range(max_idx):
self.data_info['idx'].append('{}/{}'.format(i, max_idx))
border_l = [0] * max_idx
for i in range(self.half_N_frames):
border_l[i] = 1
border_l[max_idx - i - 1] = 1
self.data_info['border'].extend(border_l)
if self.cache_data:
self.imgs_LQ[subfolder_name] = util.read_img_seq(img_paths_LQ)
self.imgs_GT[subfolder_name] = util.read_img_seq(img_paths_GT)
elif opt['name'].lower() in ['vimeo90k-test']:
pass # TODO
else:
raise ValueError(
'Unsupported video test dataset. Supported: Vid4, REDS4 and Vimeo90K-Test.')
def __getitem__(self, index):
# path_LQ = self.data_info['path_LQ'][index]
# path_GT = self.data_info['path_GT'][index]
folder = self.data_info['folder'][index]
idx, max_idx = self.data_info['idx'][index].split('/')
idx, max_idx = int(idx), int(max_idx)
border = self.data_info['border'][index]
if self.cache_data:
select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'],
padding=self.opt['padding'])
imgs_LQ = self.imgs_LQ[folder].index_select(0, torch.LongTensor(select_idx))
img_GT = self.imgs_GT[folder][idx]
else:
pass # TODO
return {
'LQs': imgs_LQ,
'GT': img_GT,
'folder': folder,
'idx': self.data_info['idx'][index],
'border': border
}
def __len__(self):
return len(self.data_info['path_GT'])
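
How the `border` flags from `__init__` are laid out for a folder of `max_idx` frames (traced with concrete values; not part of the file):

```
max_idx, half_N_frames = 7, 2  # N_frames = 5
border_l = [0] * max_idx
for i in range(half_N_frames):
    border_l[i] = 1
    border_l[max_idx - i - 1] = 1
print(border_l)  # [1, 1, 0, 0, 0, 1, 1] -- frames near either end are border frames
```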

411
codes/data_scripts/create_lmdb.py Normal file

@@ -0,0 +1,411 @@
"""Create lmdb files for [General images (291 images/DIV2K) | Vimeo90K | REDS] training datasets"""
import sys
import os.path as osp
import glob
import pickle
from multiprocessing import Pool
import numpy as np
import lmdb
import cv2
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
import data.util as data_util # noqa: E402
import utils.util as util # noqa: E402
def main():
dataset = 'DIV2K_demo' # vimeo90k | REDS | general (e.g., DIV2K, 291) | DIV2K_demo | test
mode = 'GT' # used for vimeo90k and REDS datasets
# vimeo90k: GT | LR | flow
# REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp
# train_sharp_flowx4
if dataset == 'vimeo90k':
vimeo90k(mode)
elif dataset == 'REDS':
REDS(mode)
elif dataset == 'general':
opt = {}
opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
opt['name'] = 'DIV2K800_sub_GT'
general_image_folder(opt)
elif dataset == 'DIV2K_demo':
opt = {}
## GT
opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
opt['name'] = 'DIV2K800_sub_GT'
general_image_folder(opt)
## LR
opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
opt['name'] = 'DIV2K800_sub_bicLRx4'
general_image_folder(opt)
elif dataset == 'test':
test_lmdb('../../datasets/REDS/train_sharp_wval.lmdb', 'REDS')
def read_image_worker(path, key):
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
return (key, img)
def general_image_folder(opt):
"""Create lmdb for general image folders
Users should define the keys, such as: '0321_s035' for DIV2K sub-images
If all the images have the same resolution, it will only store one copy of resolution info.
Otherwise, it will store every resolution info.
"""
#### configurations
read_all_imgs = False # whether to read all images into memory with multiprocessing
# Set to False when memory is limited
BATCH = 5000 # if read_all_imgs is False, commit the lmdb transaction after every BATCH images
n_thread = 40
########################################################
img_folder = opt['img_folder']
lmdb_save_path = opt['lmdb_save_path']
meta_info = {'name': opt['name']}
if not lmdb_save_path.endswith('.lmdb'):
raise ValueError("lmdb_save_path must end with \'lmdb\'.")
if osp.exists(lmdb_save_path):
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
sys.exit(1)
#### read all the image paths to a list
print('Reading image path list ...')
all_img_list = sorted(glob.glob(osp.join(img_folder, '*')))
keys = []
for img_path in all_img_list:
keys.append(osp.splitext(osp.basename(img_path))[0])
if read_all_imgs:
#### read all images to memory (multiprocessing)
dataset = {} # store all image data. list cannot keep the order, use dict
print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
pbar = util.ProgressBar(len(all_img_list))
def mycallback(arg):
'''get the image data and update pbar'''
key = arg[0]
dataset[key] = arg[1]
pbar.update('Reading {}'.format(key))
pool = Pool(n_thread)
for path, key in zip(all_img_list, keys):
pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
pool.close()
pool.join()
print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
#### create lmdb environment
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
print('data size per image is: ', data_size_per_img)
data_size = data_size_per_img * len(all_img_list)
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
#### write data to lmdb
pbar = util.ProgressBar(len(all_img_list))
txn = env.begin(write=True)
resolutions = []
for idx, (path, key) in enumerate(zip(all_img_list, keys)):
pbar.update('Write {}'.format(key))
key_byte = key.encode('ascii')
data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
if data.ndim == 2:
H, W = data.shape
C = 1
else:
H, W, C = data.shape
txn.put(key_byte, data)
resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
if not read_all_imgs and idx % BATCH == 0:
txn.commit()
txn = env.begin(write=True)
txn.commit()
env.close()
print('Finish writing lmdb.')
#### create meta information
# check whether all the images are the same size
assert len(keys) == len(resolutions)
if len(set(resolutions)) <= 1:
meta_info['resolution'] = [resolutions[0]]
meta_info['keys'] = keys
print('All images have the same resolution. Simplify the meta info.')
else:
meta_info['resolution'] = resolutions
meta_info['keys'] = keys
print('Not all images have the same resolution. Save meta info for each image.')
pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
print('Finish creating lmdb meta info.')
def vimeo90k(mode):
"""Create lmdb for the Vimeo90K dataset, each image with a fixed size
GT: [3, 256, 448]
Now only need the 4th frame, e.g., 00001_0001_4
LR: [3, 64, 112]
1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
key:
Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001
flow: downsampled flow: [3, 360, 320], keys: 00001_0001_4_[p3, p2, p1, n1, n2, n3]
Each flow is calculated with GT images by PWCNet and then downsampled by 1/4
Flow map is quantized by mmcv and saved in png format
"""
#### configurations
read_all_imgs = False # whether to read all images into memory with multiprocessing
# Set to False when memory is limited
BATCH = 5000 # if read_all_imgs is False, commit the lmdb transaction after every BATCH images
if mode == 'GT':
img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
H_dst, W_dst = 256, 448
elif mode == 'LR':
img_folder = '../../datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences'
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
H_dst, W_dst = 64, 112
elif mode == 'flow':
img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences_flowx4'
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_flowx4.lmdb'
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
H_dst, W_dst = 128, 112
else:
raise ValueError('Wrong dataset mode: {}'.format(mode))
n_thread = 40
########################################################
if not lmdb_save_path.endswith('.lmdb'):
raise ValueError("lmdb_save_path must end with \'lmdb\'.")
if osp.exists(lmdb_save_path):
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
sys.exit(1)
#### read all the image paths to a list
print('Reading image path list ...')
with open(txt_file) as f:
train_l = f.readlines()
train_l = [v.strip() for v in train_l]
all_img_list = []
keys = []
for line in train_l:
folder = line.split('/')[0]
sub_folder = line.split('/')[1]
all_img_list.extend(glob.glob(osp.join(img_folder, folder, sub_folder, '*')))
if mode == 'flow':
for j in range(1, 4):
keys.append('{}_{}_4_n{}'.format(folder, sub_folder, j))
keys.append('{}_{}_4_p{}'.format(folder, sub_folder, j))
else:
for j in range(7):
keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
all_img_list = sorted(all_img_list)
keys = sorted(keys)
if mode == 'GT': # only read the 4th frame for the GT mode
print('Only keep the 4th frame.')
all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
keys = [v for v in keys if v.endswith('_4')]
if read_all_imgs:
#### read all images to memory (multiprocessing)
dataset = {} # store all image data. list cannot keep the order, use dict
print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
pbar = util.ProgressBar(len(all_img_list))
def mycallback(arg):
"""get the image data and update pbar"""
key = arg[0]
dataset[key] = arg[1]
pbar.update('Reading {}'.format(key))
pool = Pool(n_thread)
for path, key in zip(all_img_list, keys):
pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
pool.close()
pool.join()
print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
#### write data to lmdb
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
print('data size per image is: ', data_size_per_img)
data_size = data_size_per_img * len(all_img_list)
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
txn = env.begin(write=True)
pbar = util.ProgressBar(len(all_img_list))
for idx, (path, key) in enumerate(zip(all_img_list, keys)):
pbar.update('Write {}'.format(key))
key_byte = key.encode('ascii')
data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
if 'flow' in mode:
H, W = data.shape
assert H == H_dst and W == W_dst, 'different shape.'
else:
H, W, C = data.shape
assert H == H_dst and W == W_dst and C == 3, 'different shape.'
txn.put(key_byte, data)
if not read_all_imgs and idx % BATCH == 0:
txn.commit()
txn = env.begin(write=True)
txn.commit()
env.close()
print('Finish writing lmdb.')
#### create meta information
meta_info = {}
if mode == 'GT':
meta_info['name'] = 'Vimeo90K_train_GT'
elif mode == 'LR':
meta_info['name'] = 'Vimeo90K_train_LR'
elif mode == 'flow':
meta_info['name'] = 'Vimeo90K_train_flowx4'
channel = 1 if 'flow' in mode else 3
meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
key_set = set()
for key in keys:
if mode == 'flow':
a, b, _, _ = key.split('_')
else:
a, b, _ = key.split('_')
key_set.add('{}_{}'.format(a, b))
meta_info['keys'] = list(key_set)
pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
print('Finish creating lmdb meta info.')
def REDS(mode):
"""Create lmdb for the REDS dataset, each image with a fixed size
GT: [3, 720, 1280], key: 000_00000000
LR: [3, 180, 320], key: 000_00000000
key: 000_00000000
flow: downsampled flow: [3, 360, 320], keys: 000_00000005_[p2, p1, n1, n2]
Each flow is calculated with the GT images by PWCNet and then downsampled by 1/4
Flow map is quantized by mmcv and saved in png format
"""
#### configurations
read_all_imgs = False # whether to read all images into memory with multiprocessing
# Set to False when memory is limited
BATCH = 5000 # if read_all_imgs is False, commit the lmdb transaction after every BATCH images
if mode == 'train_sharp':
img_folder = '../../datasets/REDS/train_sharp'
lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb'
H_dst, W_dst = 720, 1280
elif mode == 'train_sharp_bicubic':
img_folder = '../../datasets/REDS/train_sharp_bicubic'
lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
H_dst, W_dst = 180, 320
elif mode == 'train_blur_bicubic':
img_folder = '../../datasets/REDS/train_blur_bicubic'
lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb'
H_dst, W_dst = 180, 320
elif mode == 'train_blur':
img_folder = '../../datasets/REDS/train_blur'
lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb'
H_dst, W_dst = 720, 1280
elif mode == 'train_blur_comp':
img_folder = '../../datasets/REDS/train_blur_comp'
lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb'
H_dst, W_dst = 720, 1280
elif mode == 'train_sharp_flowx4':
img_folder = '../../datasets/REDS/train_sharp_flowx4'
lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb'
H_dst, W_dst = 360, 320
else:
raise ValueError('Wrong dataset mode: {}'.format(mode))
n_thread = 40
########################################################
if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
if osp.exists(lmdb_save_path):
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
sys.exit(1)
#### read all the image paths to a list
print('Reading image path list ...')
all_img_list = data_util._get_paths_from_images(img_folder)
keys = []
for img_path in all_img_list:
split_rlt = img_path.split('/')
folder = split_rlt[-2]
img_name = split_rlt[-1].split('.png')[0]
keys.append(folder + '_' + img_name)
if read_all_imgs:
#### read all images to memory (multiprocessing)
        dataset = {}  # store all image data; results arrive out of order, so use a dict instead of a list
print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
pbar = util.ProgressBar(len(all_img_list))
def mycallback(arg):
'''get the image data and update pbar'''
key = arg[0]
dataset[key] = arg[1]
pbar.update('Reading {}'.format(key))
pool = Pool(n_thread)
for path, key in zip(all_img_list, keys):
pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
pool.close()
pool.join()
print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
#### create lmdb environment
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
print('data size per image is: ', data_size_per_img)
data_size = data_size_per_img * len(all_img_list)
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
#### write data to lmdb
pbar = util.ProgressBar(len(all_img_list))
txn = env.begin(write=True)
for idx, (path, key) in enumerate(zip(all_img_list, keys)):
pbar.update('Write {}'.format(key))
key_byte = key.encode('ascii')
data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
if 'flow' in mode:
H, W = data.shape
assert H == H_dst and W == W_dst, 'different shape.'
else:
H, W, C = data.shape
assert H == H_dst and W == W_dst and C == 3, 'different shape.'
txn.put(key_byte, data)
if not read_all_imgs and idx % BATCH == 0:
txn.commit()
txn = env.begin(write=True)
txn.commit()
env.close()
print('Finish writing lmdb.')
#### create meta information
meta_info = {}
meta_info['name'] = 'REDS_{}_wval'.format(mode)
channel = 1 if 'flow' in mode else 3
meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
meta_info['keys'] = keys
pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
print('Finish creating lmdb meta info.')
def test_lmdb(dataroot, dataset='REDS'):
env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
meta_info = pickle.load(open(osp.join(dataroot, 'meta_info.pkl'), "rb"))
print('Name: ', meta_info['name'])
print('Resolution: ', meta_info['resolution'])
print('# keys: ', len(meta_info['keys']))
# read one image
if dataset == 'vimeo90k':
key = '00001_0001_4'
else:
key = '000_00000000'
print('Reading {} for test.'.format(key))
with env.begin(write=False) as txn:
buf = txn.get(key.encode('ascii'))
img_flat = np.frombuffer(buf, dtype=np.uint8)
C, H, W = [int(s) for s in meta_info['resolution'].split('_')]
img = img_flat.reshape(H, W, C)
cv2.imwrite('test.png', img)
if __name__ == "__main__":
main()
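For reference, a minimal sketch of how a video dataloader might fetch a frame window from the lmdb produced above; the path, clip, and frame indices are hypothetical examples:

```
import pickle
import lmdb
import numpy as np

root = '../../datasets/REDS/train_sharp_wval.lmdb'  # hypothetical path
meta = pickle.load(open(root + '/meta_info.pkl', 'rb'))
C, H, W = [int(s) for s in meta['resolution'].split('_')]
env = lmdb.open(root, readonly=True, lock=False, readahead=False, meminit=False)
clip, center = '000', 10  # hypothetical clip / center frame
with env.begin(write=False) as txn:
    frames = [
        np.frombuffer(txn.get('{}_{:08d}'.format(clip, i).encode('ascii')),
                      dtype=np.uint8).reshape(H, W, C)
        for i in range(center - 2, center + 3)  # a 5-frame window
    ]
```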

View File

@ -0,0 +1,141 @@
"""A multi-thread tool to crop large images to sub-images for faster IO."""
import os
import os.path as osp
import sys
from multiprocessing import Pool
import numpy as np
import cv2
from PIL import Image
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
from utils.util import ProgressBar # noqa: E402
import data.util as data_util # noqa: E402
def main():
mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs)
opt = {}
opt['n_thread'] = 20
    opt['compression_level'] = 3  # 3 is the default value in cv2
    # CV_IMWRITE_PNG_COMPRESSION ranges from 0 to 9. A higher value means a smaller file size but
    # a longer compression time. If you read raw images during training, use 0 for faster IO.
if mode == 'single':
opt['input_folder'] = '../../datasets/DIV2K/DIV2K_train_HR'
opt['save_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
opt['crop_sz'] = 480 # the size of each sub-image
opt['step'] = 240 # step of the sliding crop window
opt['thres_sz'] = 48 # size threshold
        extract_single(opt)
elif mode == 'pair':
GT_folder = '../../datasets/DIV2K/DIV2K_train_HR'
LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub'
save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
scale_ratio = 4
crop_sz = 480 # the size of each sub-image (GT)
step = 240 # step of the sliding crop window (GT)
thres_sz = 48 # size threshold
########################################################################
# check that all the GT and LR images have correct scale ratio
img_GT_list = data_util._get_paths_from_images(GT_folder)
img_LR_list = data_util._get_paths_from_images(LR_folder)
assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.'
for path_GT, path_LR in zip(img_GT_list, img_LR_list):
img_GT = Image.open(path_GT)
img_LR = Image.open(path_LR)
w_GT, h_GT = img_GT.size
w_LR, h_LR = img_LR.size
        assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR width [{:d}] for {:s}.'.format(  # noqa: E501
            w_GT, scale_ratio, w_LR, path_GT)
        assert h_GT / h_LR == scale_ratio, 'GT height [{:d}] is not {:d}X as LR height [{:d}] for {:s}.'.format(  # noqa: E501
            h_GT, scale_ratio, h_LR, path_GT)
# check crop size, step and threshold size
        assert crop_sz % scale_ratio == 0, 'crop size is not a multiple of scale ratio {:d}.'.format(
            scale_ratio)
        assert step % scale_ratio == 0, 'step is not a multiple of scale ratio {:d}.'.format(scale_ratio)
        assert thres_sz % scale_ratio == 0, 'thres_sz is not a multiple of scale ratio {:d}.'.format(
            scale_ratio)
print('process GT...')
opt['input_folder'] = GT_folder
opt['save_folder'] = save_GT_folder
opt['crop_sz'] = crop_sz
opt['step'] = step
opt['thres_sz'] = thres_sz
        extract_single(opt)
print('process LR...')
opt['input_folder'] = LR_folder
opt['save_folder'] = save_LR_folder
opt['crop_sz'] = crop_sz // scale_ratio
opt['step'] = step // scale_ratio
opt['thres_sz'] = thres_sz // scale_ratio
        extract_single(opt)
assert len(data_util._get_paths_from_images(save_GT_folder)) == len(
data_util._get_paths_from_images(
save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.'
else:
raise ValueError('Wrong mode.')
def extract_single(opt):
input_folder = opt['input_folder']
save_folder = opt['save_folder']
if not osp.exists(save_folder):
os.makedirs(save_folder)
print('mkdir [{:s}] ...'.format(save_folder))
else:
print('Folder [{:s}] already exists. Exit...'.format(save_folder))
sys.exit(1)
img_list = data_util._get_paths_from_images(input_folder)
def update(arg):
pbar.update(arg)
pbar = ProgressBar(len(img_list))
pool = Pool(opt['n_thread'])
for path in img_list:
pool.apply_async(worker, args=(path, opt), callback=update)
pool.close()
pool.join()
print('All subprocesses done.')
def worker(path, opt):
crop_sz = opt['crop_sz']
step = opt['step']
thres_sz = opt['thres_sz']
img_name = osp.basename(path)
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
n_channels = len(img.shape)
if n_channels == 2:
h, w = img.shape
elif n_channels == 3:
h, w, c = img.shape
else:
raise ValueError('Wrong image shape - {}'.format(n_channels))
h_space = np.arange(0, h - crop_sz + 1, step)
if h - (h_space[-1] + crop_sz) > thres_sz:
h_space = np.append(h_space, h - crop_sz)
w_space = np.arange(0, w - crop_sz + 1, step)
if w - (w_space[-1] + crop_sz) > thres_sz:
w_space = np.append(w_space, w - crop_sz)
index = 0
for x in h_space:
for y in w_space:
index += 1
if n_channels == 2:
crop_img = img[x:x + crop_sz, y:y + crop_sz]
else:
crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
crop_img = np.ascontiguousarray(crop_img)
cv2.imwrite(
osp.join(opt['save_folder'],
img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img,
[cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
return 'Processing {:s} ...'.format(img_name)
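# Sanity check on the sliding-window logic in worker() above: a minimal sketch
# with a hypothetical 1356x2040 (HxW) DIV2K-sized image and the default
# pair-mode settings.
def _demo_crop_grid():
    h, w, crop_sz, step, thres_sz = 1356, 2040, 480, 240, 48
    h_space = np.arange(0, h - crop_sz + 1, step)  # [0, 240, 480, 720]
    if h - (h_space[-1] + crop_sz) > thres_sz:     # 156 leftover px > 48
        h_space = np.append(h_space, h - crop_sz)  # so add a final crop at 876
    w_space = np.arange(0, w - crop_sz + 1, step)  # [0, 240, ..., 1440]
    if w - (w_space[-1] + crop_sz) > thres_sz:     # 120 leftover px > 48
        w_space = np.append(w_space, w - crop_sz)  # so add a final crop at 1560
    return len(h_space) * len(w_space)             # 40 sub-images for this image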
if __name__ == '__main__':
main()

View File

@ -0,0 +1,49 @@
function generate_LR_Vimeo90K()
%% matlab code to generate bicubic-downsampled LR images for the Vimeo90K dataset
up_scale = 4;
mod_scale = 4;
idx = 0;
filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png');
for i = 1 : length(filepaths)
[~,imname,ext] = fileparts(filepaths(i).name);
folder_path = filepaths(i).folder;
save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4');
if ~exist(save_LR_folder, 'dir')
mkdir(save_LR_folder);
end
if isempty(imname)
disp('Ignore . folder.');
elseif strcmp(imname, '.')
disp('Ignore .. folder.');
else
idx = idx + 1;
str_rlt = sprintf('%d\t%s.\n', idx, imname);
fprintf(str_rlt);
% read image
img = imread(fullfile(folder_path, [imname, ext]));
img = im2double(img);
% modcrop
img = modcrop(img, mod_scale);
% LR
im_LR = imresize(img, 1/up_scale, 'bicubic');
if exist('save_LR_folder', 'var')
imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
end
end
end
end
%% modcrop
function img = modcrop(img, modulo)
if size(img,3) == 1
sz = size(img);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2));
else
tmpsz = size(img);
sz = tmpsz(1:2);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2),:);
end
end

View File

@ -0,0 +1,82 @@
function generate_mod_LR_bic()
%% matlab code to generate modcropped images, bicubic-downsampled LR images, and bicubic-upsampled images.
%% set parameters
% comment the unnecessary line
input_folder = '../../datasets/DIV2K/DIV2K800';
% save_mod_folder = '';
save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4';
% save_bic_folder = '';
up_scale = 4;
mod_scale = 4;
if exist('save_mod_folder', 'var')
if exist(save_mod_folder, 'dir')
        disp(['It will overwrite ', save_mod_folder]);
else
mkdir(save_mod_folder);
end
end
if exist('save_LR_folder', 'var')
if exist(save_LR_folder, 'dir')
        disp(['It will overwrite ', save_LR_folder]);
else
mkdir(save_LR_folder);
end
end
if exist('save_bic_folder', 'var')
if exist(save_bic_folder, 'dir')
        disp(['It will overwrite ', save_bic_folder]);
else
mkdir(save_bic_folder);
end
end
idx = 0;
filepaths = dir(fullfile(input_folder,'*.*'));
for i = 1 : length(filepaths)
    [~,imname,ext] = fileparts(filepaths(i).name);
if isempty(imname)
disp('Ignore . folder.');
elseif strcmp(imname, '.')
disp('Ignore .. folder.');
else
idx = idx + 1;
str_rlt = sprintf('%d\t%s.\n', idx, imname);
fprintf(str_rlt);
% read image
img = imread(fullfile(input_folder, [imname, ext]));
img = im2double(img);
% modcrop
img = modcrop(img, mod_scale);
if exist('save_mod_folder', 'var')
imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
end
% LR
im_LR = imresize(img, 1/up_scale, 'bicubic');
if exist('save_LR_folder', 'var')
imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
end
% Bicubic
if exist('save_bic_folder', 'var')
im_B = imresize(im_LR, up_scale, 'bicubic');
imwrite(im_B, fullfile(save_bic_folder, [imname, '.png']));
end
end
end
end
%% modcrop
function img = modcrop(img, modulo)
if size(img,3) == 1
sz = size(img);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2));
else
tmpsz = size(img);
sz = tmpsz(1:2);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2),:);
end
end
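For readers following the Python pipeline, a minimal numpy equivalent of the modcrop helper above (a sketch; works for both HW and HWC arrays):

```
import numpy as np

def modcrop(img, modulo):
    # Crop H and W down to the nearest multiples of `modulo`.
    h = img.shape[0] - img.shape[0] % modulo
    w = img.shape[1] - img.shape[1] % modulo
    return img[:h, :w, ...]
```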

View File

@ -0,0 +1,81 @@
import os
import sys
import cv2
import numpy as np
try:
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.util import imresize_np
except ImportError:
pass
def generate_mod_LR_bic():
# set parameters
up_scale = 4
mod_scale = 4
# set data dir
sourcedir = '/data/datasets/img'
savedir = '/data/datasets/mod'
saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale))
saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))
if not os.path.isdir(sourcedir):
        print('Error: no source data found.')
        sys.exit(1)
if not os.path.isdir(savedir):
os.mkdir(savedir)
if not os.path.isdir(os.path.join(savedir, 'HR')):
os.mkdir(os.path.join(savedir, 'HR'))
if not os.path.isdir(os.path.join(savedir, 'LR')):
os.mkdir(os.path.join(savedir, 'LR'))
if not os.path.isdir(os.path.join(savedir, 'Bic')):
os.mkdir(os.path.join(savedir, 'Bic'))
if not os.path.isdir(saveHRpath):
os.mkdir(saveHRpath)
else:
        print('It will overwrite ' + str(saveHRpath))
if not os.path.isdir(saveLRpath):
os.mkdir(saveLRpath)
else:
        print('It will overwrite ' + str(saveLRpath))
if not os.path.isdir(saveBicpath):
os.mkdir(saveBicpath)
else:
        print('It will overwrite ' + str(saveBicpath))
filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')]
num_files = len(filepaths)
    # process each image: modcrop, then generate the LR and Bic versions
for i in range(num_files):
filename = filepaths[i]
print('No.{} -- Processing {}'.format(i, filename))
# read image
image = cv2.imread(os.path.join(sourcedir, filename))
width = int(np.floor(image.shape[1] / mod_scale))
height = int(np.floor(image.shape[0] / mod_scale))
# modcrop
if len(image.shape) == 3:
image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
else:
image_HR = image[0:mod_scale * height, 0:mod_scale * width]
# LR
image_LR = imresize_np(image_HR, 1 / up_scale, True)
# bic
image_Bic = imresize_np(image_LR, up_scale, True)
cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
if __name__ == "__main__":
generate_mod_LR_bic()

View File

@ -0,0 +1,42 @@
echo "Prepare DIV2K X4 datasets..."
cd ../../datasets
mkdir DIV2K
cd DIV2K
#### Step 1
echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..."
# GT
FOLDER=DIV2K_train_HR
FILE=DIV2K_train_HR.zip
if [ ! -d "$FOLDER" ]; then
if [ ! -f "$FILE" ]; then
echo "Downloading $FILE..."
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
fi
unzip $FILE
fi
# LR
FOLDER=DIV2K_train_LR_bicubic
FILE=DIV2K_train_LR_bicubic_X4.zip
if [ ! -d "$FOLDER" ]; then
if [ ! -f "$FILE" ]; then
echo "Downloading $FILE..."
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
fi
unzip $FILE
fi
#### Step 2
echo "Step 2: Rename the LR images..."
cd ../../codes/data_scripts
python rename.py
#### Step 3
echo "Step 3: Crop to sub-images..."
python extract_subimages.py
#### Step 4
echo "Step 4: Create LMDB files..."
python create_lmdb.py

View File

@ -0,0 +1,11 @@
import os
import glob
train_path = '/home/xtwang/datasets/REDS/train_sharp_bicubic/X4'
val_path = '/home/xtwang/datasets/REDS/val_sharp_bicubic/X4'
# copy the val clips into the train set, renumbering them to start at 240
val_folders = glob.glob(os.path.join(val_path, '*'))
for folder in val_folders:
new_folder_idx = '{:03d}'.format(int(folder.split('/')[-1]) + 240)
os.system('cp -r {} {}'.format(folder, os.path.join(train_path, new_folder_idx)))

View File

@ -0,0 +1,19 @@
import os
import glob
def main():
folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
DIV2K(folder)
print('Finished.')
def DIV2K(path):
img_path_l = glob.glob(os.path.join(path, '*'))
for img_path in img_path_l:
new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
os.rename(img_path, new_path)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,104 @@
import sys
import os.path as osp
import math
import torchvision.utils
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
from data import create_dataloader, create_dataset # noqa: E402
from utils import util # noqa: E402
def main():
dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub
opt = {}
opt['dist'] = False
opt['gpu_ids'] = [0]
if dataset == 'REDS':
opt['name'] = 'test_REDS'
opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb'
opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
opt['mode'] = 'REDS'
opt['N_frames'] = 5
opt['phase'] = 'train'
opt['use_shuffle'] = True
opt['n_workers'] = 8
opt['batch_size'] = 16
opt['GT_size'] = 256
opt['LQ_size'] = 64
opt['scale'] = 4
opt['use_flip'] = True
opt['use_rot'] = True
opt['interval_list'] = [1]
opt['random_reverse'] = False
opt['border_mode'] = False
opt['cache_keys'] = None
opt['data_type'] = 'lmdb' # img | lmdb | mc
elif dataset == 'Vimeo90K':
opt['name'] = 'test_Vimeo90K'
opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
opt['mode'] = 'Vimeo90K'
opt['N_frames'] = 7
opt['phase'] = 'train'
opt['use_shuffle'] = True
opt['n_workers'] = 8
opt['batch_size'] = 16
opt['GT_size'] = 256
opt['LQ_size'] = 64
opt['scale'] = 4
opt['use_flip'] = True
opt['use_rot'] = True
opt['interval_list'] = [1]
opt['random_reverse'] = False
opt['border_mode'] = False
opt['cache_keys'] = None
opt['data_type'] = 'lmdb' # img | lmdb | mc
elif dataset == 'DIV2K800_sub':
opt['name'] = 'DIV2K800'
opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
opt['mode'] = 'LQGT'
opt['phase'] = 'train'
opt['use_shuffle'] = True
opt['n_workers'] = 8
opt['batch_size'] = 16
opt['GT_size'] = 128
opt['scale'] = 4
opt['use_flip'] = True
opt['use_rot'] = True
opt['color'] = 'RGB'
opt['data_type'] = 'lmdb' # img | lmdb
else:
        raise ValueError('Dataset not supported; please implement it yourself.')
util.mkdir('tmp')
train_set = create_dataset(opt)
train_loader = create_dataloader(train_set, opt, opt, None)
nrow = int(math.sqrt(opt['batch_size']))
padding = 2 if opt['phase'] == 'train' else 0
print('start...')
for i, data in enumerate(train_loader):
if i > 5:
break
print(i)
if dataset == 'REDS' or dataset == 'Vimeo90K':
LQs = data['LQs']
else:
LQ = data['LQ']
GT = data['GT']
if dataset == 'REDS' or dataset == 'Vimeo90K':
for j in range(LQs.size(1)):
torchvision.utils.save_image(LQs[:, j, :, :, :],
'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow,
padding=padding, normalize=False)
else:
torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow,
padding=padding, normalize=False)
torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding,
normalize=False)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,261 @@
function calculate_PSNR_SSIM()
% GT and SR folder
folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5';
folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5';
scale = 4;
suffix = ''; % suffix for SR images
test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels
if test_Y
    fprintf('Testing Y channel.\n');
else
    fprintf('Testing RGB channels.\n');
end
filepaths = dir(fullfile(folder_GT, '*.png'));
PSNR_all = zeros(1, length(filepaths));
SSIM_all = zeros(1, length(filepaths));
for idx_im = 1:length(filepaths)
im_name = filepaths(idx_im).name;
im_GT = imread(fullfile(folder_GT, im_name));
im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png']));
if test_Y % evaluate on Y channel in YCbCr color space
if size(im_GT, 3) == 3
im_GT_YCbCr = rgb2ycbcr(im2double(im_GT));
im_GT_in = im_GT_YCbCr(:,:,1);
im_SR_YCbCr = rgb2ycbcr(im2double(im_SR));
im_SR_in = im_SR_YCbCr(:,:,1);
else
im_GT_in = im2double(im_GT);
im_SR_in = im2double(im_SR);
end
else % evaluate on RGB channels
im_GT_in = im2double(im_GT);
im_SR_in = im2double(im_SR);
end
% calculate PSNR and SSIM
PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale);
SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale);
fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im));
end
fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all));
end
function res = calculate_PSNR(GT, SR, border)
% remove border
GT = GT(border+1:end-border, border+1:end-border, :);
SR = SR(border+1:end-border, border+1:end-border, :);
% calculate PSNR (assume images in [0,255])
error = GT(:) - SR(:);
mse = mean(error.^2);
res = 10 * log10(255^2/mse);
end
function res = calculate_SSIM(GT, SR, border)
GT = GT(border+1:end-border, border+1:end-border, :);
SR = SR(border+1:end-border, border+1:end-border, :);
% calculate SSIM
mssim = zeros(1, size(SR, 3));
for i = 1:size(SR,3)
[mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i));
end
res = mean(mssim);
end
function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L)
%========================================================================
%SSIM Index, Version 1.0
%Copyright(c) 2003 Zhou Wang
%All Rights Reserved.
%
%The author is with Howard Hughes Medical Institute, and Laboratory
%for Computational Vision at Center for Neural Science and Courant
%Institute of Mathematical Sciences, New York University.
%
%----------------------------------------------------------------------
%Permission to use, copy, or modify this software and its documentation
%for educational and research purposes only and without fee is hereby
%granted, provided that this copyright notice and the original authors'
%names appear on all copies and supporting documentation. This program
%shall not be used, rewritten, or adapted as the basis of a commercial
%software or hardware product without first obtaining permission of the
%authors. The authors make no representations about the suitability of
%this software for any purpose. It is provided "as is" without express
%or implied warranty.
%----------------------------------------------------------------------
%
%This is an implementation of the algorithm for calculating the
%Structural SIMilarity (SSIM) index between two images. Please refer
%to the following paper:
%
%Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
%quality assessment: From error visibility to structural similarity,"
%IEEE Transactions on Image Processing, vol. 13, no. 4, Apr. 2004.
%
%Kindly report any suggestions or corrections to zhouwang@ieee.org
%
%----------------------------------------------------------------------
%
%Input : (1) img1: the first image being compared
% (2) img2: the second image being compared
% (3) K: constants in the SSIM index formula (see the above
% reference). default value: K = [0.01 0.03]
% (4) window: local window for statistics (see the above
% reference). default window is Gaussian given by
% window = fspecial('gaussian', 11, 1.5);
% (5) L: dynamic range of the images. default: L = 255
%
%Output: (1) mssim: the mean SSIM index value between 2 images.
% If one of the images being compared is regarded as
% perfect quality, then mssim can be considered as the
% quality measure of the other image.
% If img1 = img2, then mssim = 1.
% (2) ssim_map: the SSIM index map of the test image. The map
% has a smaller size than the input images. The actual size:
% size(img1) - size(window) + 1.
%
%Default Usage:
% Given 2 test images img1 and img2, whose dynamic range is 0-255
%
% [mssim ssim_map] = ssim_index(img1, img2);
%
%Advanced Usage:
% User defined parameters. For example
%
% K = [0.05 0.05];
% window = ones(8);
% L = 100;
% [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
%
%See the results:
%
% mssim %Gives the mssim value
% imshow(max(0, ssim_map).^4) %Shows the SSIM index map
%
%========================================================================
if (nargin < 2 || nargin > 5)
mssim = -Inf;
ssim_map = -Inf;
return;
end
if (size(img1) ~= size(img2))
mssim = -Inf;
ssim_map = -Inf;
return;
end
[M, N] = size(img1);
if (nargin == 2)
if ((M < 11) || (N < 11))
mssim = -Inf;
ssim_map = -Inf;
return
end
window = fspecial('gaussian', 11, 1.5); %
K(1) = 0.01; % default settings
K(2) = 0.03; %
L = 255; %
end
if (nargin == 3)
if ((M < 11) || (N < 11))
mssim = -Inf;
ssim_map = -Inf;
return
end
window = fspecial('gaussian', 11, 1.5);
L = 255;
if (length(K) == 2)
if (K(1) < 0 || K(2) < 0)
mssim = -Inf;
ssim_map = -Inf;
return;
end
else
mssim = -Inf;
ssim_map = -Inf;
return;
end
end
if (nargin == 4)
[H, W] = size(window);
if ((H*W) < 4 || (H > M) || (W > N))
mssim = -Inf;
ssim_map = -Inf;
return
end
L = 255;
if (length(K) == 2)
if (K(1) < 0 || K(2) < 0)
ssim_index = -Inf;
ssim_map = -Inf;
return;
end
else
mssim = -Inf;
ssim_map = -Inf;
return;
end
end
if (nargin == 5)
[H, W] = size(window);
if ((H*W) < 4 || (H > M) || (W > N))
mssim = -Inf;
ssim_map = -Inf;
return
end
if (length(K) == 2)
if (K(1) < 0 || K(2) < 0)
mssim = -Inf;
ssim_map = -Inf;
return;
end
else
mssim = -Inf;
ssim_map = -Inf;
return;
end
end
C1 = (K(1)*L)^2;
C2 = (K(2)*L)^2;
window = window/sum(sum(window));
img1 = double(img1);
img2 = double(img2);
mu1 = filter2(window, img1, 'valid');
mu2 = filter2(window, img2, 'valid');
mu1_sq = mu1.*mu1;
mu2_sq = mu2.*mu2;
mu1_mu2 = mu1.*mu2;
sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
if (C1 > 0 && C2 > 0)
ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
else
numerator1 = 2*mu1_mu2 + C1;
numerator2 = 2*sigma12 + C2;
denominator1 = mu1_sq + mu2_sq + C1;
denominator2 = sigma1_sq + sigma2_sq + C2;
ssim_map = ones(size(mu1));
index = (denominator1.*denominator2 > 0);
ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
index = (denominator1 ~= 0) & (denominator2 == 0);
ssim_map(index) = numerator1(index)./denominator1(index);
end
mssim = mean2(ssim_map);
end

View File

@ -0,0 +1,147 @@
'''
Calculate PSNR and SSIM.
Produces the same results as MATLAB's implementation.
'''
import os
import math
import numpy as np
import cv2
import glob
def main():
# Configurations
# GT - Ground-truth;
# Gen: Generated / Restored / Recovered images
folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'
folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'
crop_border = 4
suffix = '' # suffix for Gen images
test_Y = False # True: test Y channel only; False: test RGB channels
PSNR_all = []
SSIM_all = []
img_list = sorted(glob.glob(folder_GT + '/*'))
if test_Y:
print('Testing Y channel.')
else:
print('Testing RGB channels.')
for i, img_path in enumerate(img_list):
base_name = os.path.splitext(os.path.basename(img_path))[0]
im_GT = cv2.imread(img_path) / 255.
im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255.
if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space
im_GT_in = bgr2ycbcr(im_GT)
im_Gen_in = bgr2ycbcr(im_Gen)
else:
im_GT_in = im_GT
im_Gen_in = im_Gen
# crop borders
if im_GT_in.ndim == 3:
cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :]
cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :]
elif im_GT_in.ndim == 2:
cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border]
cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border]
else:
raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim))
# calculate PSNR and SSIM
PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255)
SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255)
print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format(
i + 1, base_name, PSNR, SSIM))
PSNR_all.append(PSNR)
SSIM_all.append(SSIM)
print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format(
sum(PSNR_all) / len(PSNR_all),
sum(SSIM_all) / len(SSIM_all)))
def calculate_psnr(img1, img2):
# img1 and img2 have range [0, 255]
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
mse = np.mean((img1 - img2)**2)
if mse == 0:
return float('inf')
return 20 * math.log10(255.0 / math.sqrt(mse))
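# For intuition, a quick numeric check of the formula above (a sketch, using
# calculate_psnr defined here): a constant per-pixel error of 1 on the [0, 255]
# scale gives MSE = 1, hence PSNR = 20 * log10(255) ~ 48.13 dB.
def _demo_psnr():
    a = np.zeros((8, 8), dtype=np.float64)
    b = a + 1.0                    # constant error of 1
    return calculate_psnr(a, b)    # 48.1308... dB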
def ssim(img1, img2):
C1 = (0.01 * 255)**2
C2 = (0.03 * 255)**2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
def calculate_ssim(img1, img2):
'''calculate SSIM
the same outputs as MATLAB's
img1, img2: [0, 255]
'''
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
if img1.ndim == 2:
return ssim(img1, img2)
elif img1.ndim == 3:
if img1.shape[2] == 3:
ssims = []
for i in range(3):
                ssims.append(ssim(img1[..., i], img2[..., i]))  # per-channel SSIM
return np.array(ssims).mean()
elif img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
else:
raise ValueError('Wrong input image dimensions.')
def bgr2ycbcr(img, only_y=True):
'''same as matlab rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
'''
in_img_type = img.dtype
    img = img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
# convert
if only_y:
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
else:
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
return rlt.astype(in_img_type)
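# The coefficients above are the BT.601 ones used by MATLAB's rgb2ycbcr,
# reordered for OpenCV's BGR channel layout. A minimal sanity check (a sketch,
# using bgr2ycbcr defined here):
def _demo_bgr2ycbcr_white():
    white = np.ones((1, 1, 3), dtype=np.float64)  # BGR white, float in [0, 1]
    return bgr2ycbcr(white)  # ~0.9216, i.e. (16 + 219) / 255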
if __name__ == '__main__':
main()

267
codes/models/SRGAN_model.py Normal file
View File

@ -0,0 +1,267 @@
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
from models.loss import GANLoss
logger = logging.getLogger('base')
class SRGANModel(BaseModel):
def __init__(self, opt):
super(SRGANModel, self).__init__(opt)
if opt['dist']:
self.rank = torch.distributed.get_rank()
else:
self.rank = -1 # non dist training
train_opt = opt['train']
# define networks and load pretrained models
self.netG = networks.define_G(opt).to(self.device)
if opt['dist']:
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
else:
self.netG = DataParallel(self.netG)
if self.is_train:
self.netD = networks.define_D(opt).to(self.device)
if opt['dist']:
self.netD = DistributedDataParallel(self.netD,
device_ids=[torch.cuda.current_device()])
else:
self.netD = DataParallel(self.netD)
self.netG.train()
self.netD.train()
# define losses, optimizer and scheduler
if self.is_train:
# G pixel loss
if train_opt['pixel_weight'] > 0:
l_pix_type = train_opt['pixel_criterion']
if l_pix_type == 'l1':
self.cri_pix = nn.L1Loss().to(self.device)
elif l_pix_type == 'l2':
self.cri_pix = nn.MSELoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
self.l_pix_w = train_opt['pixel_weight']
else:
logger.info('Remove pixel loss.')
self.cri_pix = None
# G feature loss
if train_opt['feature_weight'] > 0:
l_fea_type = train_opt['feature_criterion']
if l_fea_type == 'l1':
self.cri_fea = nn.L1Loss().to(self.device)
elif l_fea_type == 'l2':
self.cri_fea = nn.MSELoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
self.l_fea_w = train_opt['feature_weight']
else:
logger.info('Remove feature loss.')
self.cri_fea = None
if self.cri_fea: # load VGG perceptual loss
self.netF = networks.define_F(opt, use_bn=False).to(self.device)
if opt['dist']:
self.netF = DistributedDataParallel(self.netF,
device_ids=[torch.cuda.current_device()])
else:
self.netF = DataParallel(self.netF)
# GD gan loss
self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
self.l_gan_w = train_opt['gan_weight']
# D_update_ratio and D_init_iters
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
# optimizers
# G
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
optim_params = []
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
if v.requires_grad:
optim_params.append(v)
else:
if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
weight_decay=wd_G,
betas=(train_opt['beta1_G'], train_opt['beta2_G']))
self.optimizers.append(self.optimizer_G)
# D
wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
weight_decay=wd_D,
betas=(train_opt['beta1_D'], train_opt['beta2_D']))
self.optimizers.append(self.optimizer_D)
# schedulers
if train_opt['lr_scheme'] == 'MultiStepLR':
for optimizer in self.optimizers:
self.schedulers.append(
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
restarts=train_opt['restarts'],
weights=train_opt['restart_weights'],
gamma=train_opt['lr_gamma'],
clear_state=train_opt['clear_state']))
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
for optimizer in self.optimizers:
self.schedulers.append(
lr_scheduler.CosineAnnealingLR_Restart(
optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
else:
                raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR_Restart schemes are supported.')
self.log_dict = OrderedDict()
self.print_network() # print network
self.load() # load G and D if needed
def feed_data(self, data, need_GT=True):
self.var_L = data['LQ'].to(self.device) # LQ
if need_GT:
self.var_H = data['GT'].to(self.device) # GT
input_ref = data['ref'] if 'ref' in data else data['GT']
self.var_ref = input_ref.to(self.device)
def optimize_parameters(self, step):
# G
for p in self.netD.parameters():
p.requires_grad = False
self.optimizer_G.zero_grad()
self.fake_H = self.netG(self.var_L)
l_g_total = 0
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix: # pixel loss
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
l_g_total += l_g_pix
if self.cri_fea: # feature loss
real_fea = self.netF(self.var_H).detach()
fake_fea = self.netF(self.fake_H)
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
l_g_total += l_g_fea
pred_g_fake = self.netD(self.fake_H)
if self.opt['train']['gan_type'] == 'gan':
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
elif self.opt['train']['gan_type'] == 'ragan':
pred_d_real = self.netD(self.var_ref).detach()
l_g_gan = self.l_gan_w * (
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
l_g_total += l_g_gan
l_g_total.backward()
self.optimizer_G.step()
# D
for p in self.netD.parameters():
p.requires_grad = True
self.optimizer_D.zero_grad()
l_d_total = 0
pred_d_real = self.netD(self.var_ref)
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
if self.opt['train']['gan_type'] == 'gan':
l_d_real = self.cri_gan(pred_d_real, True)
l_d_fake = self.cri_gan(pred_d_fake, False)
l_d_total = l_d_real + l_d_fake
elif self.opt['train']['gan_type'] == 'ragan':
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
l_d_total = (l_d_real + l_d_fake) / 2
l_d_total.backward()
self.optimizer_D.step()
# set log
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix:
self.log_dict['l_g_pix'] = l_g_pix.item()
if self.cri_fea:
self.log_dict['l_g_fea'] = l_g_fea.item()
self.log_dict['l_g_gan'] = l_g_gan.item()
self.log_dict['l_d_real'] = l_d_real.item()
self.log_dict['l_d_fake'] = l_d_fake.item()
self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
def test(self):
self.netG.eval()
with torch.no_grad():
self.fake_H = self.netG(self.var_L)
self.netG.train()
def get_current_log(self):
return self.log_dict
def get_current_visuals(self, need_GT=True):
out_dict = OrderedDict()
out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
if need_GT:
out_dict['GT'] = self.var_H.detach()[0].float().cpu()
return out_dict
def print_network(self):
# Generator
s, n = self.get_network_description(self.netG)
if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
self.netG.module.__class__.__name__)
else:
net_struc_str = '{}'.format(self.netG.__class__.__name__)
if self.rank <= 0:
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
logger.info(s)
if self.is_train:
# Discriminator
s, n = self.get_network_description(self.netD)
if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
DistributedDataParallel):
net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
self.netD.module.__class__.__name__)
else:
net_struc_str = '{}'.format(self.netD.__class__.__name__)
if self.rank <= 0:
logger.info('Network D structure: {}, with parameters: {:,d}'.format(
net_struc_str, n))
logger.info(s)
if self.cri_fea: # F, Perceptual Network
s, n = self.get_network_description(self.netF)
if isinstance(self.netF, nn.DataParallel) or isinstance(
self.netF, DistributedDataParallel):
net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
self.netF.module.__class__.__name__)
else:
net_struc_str = '{}'.format(self.netF.__class__.__name__)
if self.rank <= 0:
logger.info('Network F structure: {}, with parameters: {:,d}'.format(
net_struc_str, n))
logger.info(s)
def load(self):
load_path_G = self.opt['path']['pretrain_model_G']
if load_path_G is not None:
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
load_path_D = self.opt['path']['pretrain_model_D']
if self.opt['is_train'] and load_path_D is not None:
logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])
def save(self, iter_step):
self.save_network(self.netG, 'G', iter_step)
self.save_network(self.netD, 'D', iter_step)
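For clarity, the relativistic average GAN ('ragan') objective used in optimize_parameters above, written in isolation. A minimal sketch assuming raw discriminator logits and that GANLoss reduces to BCE-with-logits for the 'gan'/'ragan' settings:

```
import torch
import torch.nn.functional as F

def ragan_d_loss(pred_real, pred_fake):
    # Real samples should score above the average fake score, and vice versa.
    l_real = F.binary_cross_entropy_with_logits(
        pred_real - pred_fake.mean(), torch.ones_like(pred_real))
    l_fake = F.binary_cross_entropy_with_logits(
        pred_fake - pred_real.mean(), torch.zeros_like(pred_fake))
    return (l_real + l_fake) / 2
```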

170
codes/models/SR_model.py Normal file
View File

@ -0,0 +1,170 @@
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
from models.loss import CharbonnierLoss
logger = logging.getLogger('base')
class SRModel(BaseModel):
def __init__(self, opt):
super(SRModel, self).__init__(opt)
if opt['dist']:
self.rank = torch.distributed.get_rank()
else:
self.rank = -1 # non dist training
train_opt = opt['train']
# define network and load pretrained models
self.netG = networks.define_G(opt).to(self.device)
if opt['dist']:
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
else:
self.netG = DataParallel(self.netG)
# print network
self.print_network()
self.load()
if self.is_train:
self.netG.train()
# loss
loss_type = train_opt['pixel_criterion']
if loss_type == 'l1':
self.cri_pix = nn.L1Loss().to(self.device)
elif loss_type == 'l2':
self.cri_pix = nn.MSELoss().to(self.device)
elif loss_type == 'cb':
self.cri_pix = CharbonnierLoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
self.l_pix_w = train_opt['pixel_weight']
# optimizers
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
optim_params = []
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
if v.requires_grad:
optim_params.append(v)
else:
if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
weight_decay=wd_G,
betas=(train_opt['beta1'], train_opt['beta2']))
self.optimizers.append(self.optimizer_G)
# schedulers
if train_opt['lr_scheme'] == 'MultiStepLR':
for optimizer in self.optimizers:
self.schedulers.append(
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
restarts=train_opt['restarts'],
weights=train_opt['restart_weights'],
gamma=train_opt['lr_gamma'],
clear_state=train_opt['clear_state']))
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
for optimizer in self.optimizers:
self.schedulers.append(
lr_scheduler.CosineAnnealingLR_Restart(
optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
else:
                raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR_Restart schemes are supported.')
self.log_dict = OrderedDict()
def feed_data(self, data, need_GT=True):
self.var_L = data['LQ'].to(self.device) # LQ
if need_GT:
self.real_H = data['GT'].to(self.device) # GT
def optimize_parameters(self, step):
self.optimizer_G.zero_grad()
self.fake_H = self.netG(self.var_L)
l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
l_pix.backward()
self.optimizer_G.step()
# set log
self.log_dict['l_pix'] = l_pix.item()
def test(self):
self.netG.eval()
with torch.no_grad():
self.fake_H = self.netG(self.var_L)
self.netG.train()
def test_x8(self):
# from https://github.com/thstkdgus35/EDSR-PyTorch
self.netG.eval()
def _transform(v, op):
# if self.precision != 'single': v = v.float()
v2np = v.data.cpu().numpy()
if op == 'v':
tfnp = v2np[:, :, :, ::-1].copy()
elif op == 'h':
tfnp = v2np[:, :, ::-1, :].copy()
elif op == 't':
tfnp = v2np.transpose((0, 1, 3, 2)).copy()
ret = torch.Tensor(tfnp).to(self.device)
# if self.precision == 'half': ret = ret.half()
return ret
lr_list = [self.var_L]
for tf in 'v', 'h', 't':
lr_list.extend([_transform(t, tf) for t in lr_list])
with torch.no_grad():
sr_list = [self.netG(aug) for aug in lr_list]
for i in range(len(sr_list)):
if i > 3:
sr_list[i] = _transform(sr_list[i], 't')
if i % 4 > 1:
sr_list[i] = _transform(sr_list[i], 'h')
if (i % 4) % 2 == 1:
sr_list[i] = _transform(sr_list[i], 'v')
output_cat = torch.cat(sr_list, dim=0)
self.fake_H = output_cat.mean(dim=0, keepdim=True)
self.netG.train()
def get_current_log(self):
return self.log_dict
def get_current_visuals(self, need_GT=True):
out_dict = OrderedDict()
out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
if need_GT:
out_dict['GT'] = self.real_H.detach()[0].float().cpu()
return out_dict
def print_network(self):
s, n = self.get_network_description(self.netG)
if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
self.netG.module.__class__.__name__)
else:
net_struc_str = '{}'.format(self.netG.__class__.__name__)
if self.rank <= 0:
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
logger.info(s)
def load(self):
load_path_G = self.opt['path']['pretrain_model_G']
if load_path_G is not None:
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
def save(self, iter_label):
self.save_network(self.netG, 'G', iter_label)
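A self-contained check of the index arithmetic in test_x8 above: the 8 augmentations are ordered so that bit 0 of the index is a vertical flip, bit 1 a horizontal flip, and bit 2 a transpose, and the t -> h -> v inversion order undoes them exactly. A sketch; tf is a torch stand-in for the numpy-based _transform closure:

```
import torch

def tf(v, op):
    if op == 'v':
        return torch.flip(v, dims=[3])  # flip W, like v2np[:, :, :, ::-1]
    if op == 'h':
        return torch.flip(v, dims=[2])  # flip H, like v2np[:, :, ::-1, :]
    return v.transpose(2, 3)            # 't': swap H and W

x = torch.arange(12.).view(1, 1, 3, 4)
aug = [x]
for op in ('v', 'h', 't'):
    aug.extend([tf(t, op) for t in aug])  # 1 -> 2 -> 4 -> 8 augmented copies
for i, y in enumerate(aug):
    if i > 3:
        y = tf(y, 't')
    if i % 4 > 1:
        y = tf(y, 'h')
    if (i % 4) % 2 == 1:
        y = tf(y, 'v')
    assert torch.equal(y, x)              # every branch inverts cleanly
```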

View File

@ -0,0 +1,166 @@
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
from models.loss import CharbonnierLoss
logger = logging.getLogger('base')
class VideoBaseModel(BaseModel):
def __init__(self, opt):
super(VideoBaseModel, self).__init__(opt)
if opt['dist']:
self.rank = torch.distributed.get_rank()
else:
self.rank = -1 # non dist training
train_opt = opt['train']
# define network and load pretrained models
self.netG = networks.define_G(opt).to(self.device)
if opt['dist']:
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
else:
self.netG = DataParallel(self.netG)
# print network
self.print_network()
self.load()
if self.is_train:
self.netG.train()
#### loss
loss_type = train_opt['pixel_criterion']
if loss_type == 'l1':
self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
elif loss_type == 'l2':
self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
elif loss_type == 'cb':
self.cri_pix = CharbonnierLoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
self.l_pix_w = train_opt['pixel_weight']
#### optimizers
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
if train_opt['ft_tsa_only']:
normal_params = []
tsa_fusion_params = []
for k, v in self.netG.named_parameters():
if v.requires_grad:
if 'tsa_fusion' in k:
tsa_fusion_params.append(v)
else:
normal_params.append(v)
else:
if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
optim_params = [
{ # add normal params first
'params': normal_params,
'lr': train_opt['lr_G']
},
{
'params': tsa_fusion_params,
'lr': train_opt['lr_G']
},
]
else:
optim_params = []
for k, v in self.netG.named_parameters():
if v.requires_grad:
optim_params.append(v)
else:
if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
weight_decay=wd_G,
betas=(train_opt['beta1'], train_opt['beta2']))
self.optimizers.append(self.optimizer_G)
#### schedulers
if train_opt['lr_scheme'] == 'MultiStepLR':
for optimizer in self.optimizers:
self.schedulers.append(
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
restarts=train_opt['restarts'],
weights=train_opt['restart_weights'],
gamma=train_opt['lr_gamma'],
clear_state=train_opt['clear_state']))
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
for optimizer in self.optimizers:
self.schedulers.append(
lr_scheduler.CosineAnnealingLR_Restart(
optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
else:
raise NotImplementedError()
self.log_dict = OrderedDict()
def feed_data(self, data, need_GT=True):
self.var_L = data['LQs'].to(self.device)
if need_GT:
self.real_H = data['GT'].to(self.device)
def set_params_lr_zero(self):
# fix normal module
self.optimizers[0].param_groups[0]['lr'] = 0
def optimize_parameters(self, step):
if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
self.set_params_lr_zero()
self.optimizer_G.zero_grad()
self.fake_H = self.netG(self.var_L)
l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
l_pix.backward()
self.optimizer_G.step()
# set log
self.log_dict['l_pix'] = l_pix.item()
def test(self):
self.netG.eval()
with torch.no_grad():
self.fake_H = self.netG(self.var_L)
self.netG.train()
def get_current_log(self):
return self.log_dict
def get_current_visuals(self, need_GT=True):
out_dict = OrderedDict()
out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
if need_GT:
out_dict['GT'] = self.real_H.detach()[0].float().cpu()
return out_dict
def print_network(self):
s, n = self.get_network_description(self.netG)
if isinstance(self.netG, nn.DataParallel):
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
self.netG.module.__class__.__name__)
else:
net_struc_str = '{}'.format(self.netG.__class__.__name__)
if self.rank <= 0:
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
logger.info(s)
def load(self):
load_path_G = self.opt['path']['pretrain_model_G']
if load_path_G is not None:
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
def save(self, iter_label):
self.save_network(self.netG, 'G', iter_label)
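The two parameter groups above exist so that set_params_lr_zero can freeze everything except the TSA fusion module while the ft_tsa_only warm-up is active. A minimal sketch of the mechanism, with hypothetical stand-in modules:

```
import torch
import torch.nn as nn

backbone, tsa = nn.Linear(4, 4), nn.Linear(4, 4)    # hypothetical modules
optimizer = torch.optim.Adam([
    {'params': backbone.parameters(), 'lr': 4e-4},  # group 0: normal params
    {'params': tsa.parameters(), 'lr': 4e-4},       # group 1: tsa_fusion
])
optimizer.param_groups[0]['lr'] = 0  # what set_params_lr_zero does each step
```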

19
codes/models/__init__.py Normal file
View File

@ -0,0 +1,19 @@
import logging
logger = logging.getLogger('base')
def create_model(opt):
model = opt['model']
# image restoration
if model == 'sr': # PSNR-oriented super resolution
from .SR_model import SRModel as M
elif model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN
from .SRGAN_model import SRGANModel as M
# video restoration
elif model == 'video_base':
from .Video_base_model import VideoBaseModel as M
else:
raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
m = M(opt)
logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
return m

View File

@ -0,0 +1,368 @@
'''Network architecture for DUF:
Deep Video Super-Resolution Network Using Dynamic Upsampling Filters
Without Explicit Motion Compensation (CVPR18)
https://github.com/yhjo09/VSR-DUF
For all the models below, [adapt_official] is only necessary when
loading the weights converted from the official TensorFlow weights.
Please set it to [False] if you are training the model from scratch.
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def adapt_official(Rx, scale=4):
'''Adapt the weights translated from the official tensorflow weights
Not necessary if you are training from scratch'''
x = Rx.clone()
x1 = x[:, ::3, :, :]
x2 = x[:, 1::3, :, :]
x3 = x[:, 2::3, :, :]
Rx[:, :scale**2, :, :] = x1
Rx[:, scale**2:2 * (scale**2), :, :] = x2
Rx[:, 2 * (scale**2):, :, :] = x3
return Rx
class DenseBlock(nn.Module):
    '''Dense block.
    For the second dense block, t_reduce = True.'''
def __init__(self, nf=64, ng=32, t_reduce=False):
super(DenseBlock, self).__init__()
self.t_reduce = t_reduce
if self.t_reduce:
pad = (0, 1, 1)
else:
pad = (1, 1, 1)
self.bn3d_1 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3)
self.conv3d_1 = nn.Conv3d(nf, nf, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
self.bn3d_2 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3)
self.conv3d_2 = nn.Conv3d(nf, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)
self.bn3d_3 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3)
self.conv3d_3 = nn.Conv3d(nf + ng, nf + ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.bn3d_4 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3)
self.conv3d_4 = nn.Conv3d(nf + ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)
self.bn3d_5 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3)
self.conv3d_5 = nn.Conv3d(nf + 2 * ng, nf + 2 * ng, (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.bn3d_6 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3)
self.conv3d_6 = nn.Conv3d(nf + 2 * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad,
bias=True)
def forward(self, x):
'''x: [B, C, T, H, W]
C: nf -> nf + 3 * ng
T: 1) 7 -> 7 (t_reduce=False);
2) 7 -> 7 - 2 * 3 = 1 (t_reduce=True)'''
x1 = self.conv3d_1(F.relu(self.bn3d_1(x), inplace=True))
x1 = self.conv3d_2(F.relu(self.bn3d_2(x1), inplace=True))
if self.t_reduce:
x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
else:
x1 = torch.cat((x, x1), 1)
x2 = self.conv3d_3(F.relu(self.bn3d_3(x1), inplace=True))
x2 = self.conv3d_4(F.relu(self.bn3d_4(x2), inplace=True))
if self.t_reduce:
x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
else:
x2 = torch.cat((x1, x2), 1)
x3 = self.conv3d_5(F.relu(self.bn3d_5(x2), inplace=True))
x3 = self.conv3d_6(F.relu(self.bn3d_6(x3), inplace=True))
if self.t_reduce:
x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
else:
x3 = torch.cat((x2, x3), 1)
return x3
class DynamicUpsamplingFilter_3C(nn.Module):
'''dynamic upsampling filter with 3 channels applying the same filters
filter_size: filter size of the generated filters, shape (C, kH, kW)'''
def __init__(self, filter_size=(1, 5, 5)):
super(DynamicUpsamplingFilter_3C, self).__init__()
# generate a local expansion filter, used similar to im2col
nF = np.prod(filter_size)
expand_filter_np = np.reshape(np.eye(nF, nF),
(nF, filter_size[0], filter_size[1], filter_size[2]))
expand_filter = torch.from_numpy(expand_filter_np).float()
self.expand_filter = torch.cat((expand_filter, expand_filter, expand_filter),
0) # [75, 1, 5, 5]
def forward(self, x, filters):
'''x: input image, [B, 3, H, W]
        filters: generated dynamic filters, [B, F, R, H, W], e.g., [B, 25, 16, H, W]
F: prod of filter kernel size, e.g., 5*5 = 25
R: used for upsampling, similar to pixel shuffle, e.g., 4*4 = 16 for x4
Return: filtered image, [B, 3*R, H, W]
'''
B, nF, R, H, W = filters.size()
# using group convolution
input_expand = F.conv2d(x, self.expand_filter.type_as(x), padding=2,
groups=3) # [B, 75, H, W] similar to im2col
input_expand = input_expand.view(B, 3, nF, H, W).permute(0, 3, 4, 1, 2) # [B, H, W, 3, 25]
filters = filters.permute(0, 3, 4, 1, 2) # [B, H, W, 25, 16]
out = torch.matmul(input_expand, filters) # [B, H, W, 3, 16]
return out.permute(0, 3, 4, 1, 2).view(B, 3 * R, H, W) # [B, 3*16, H, W]
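# A minimal shape walk-through of DynamicUpsamplingFilter_3C above with random
# tensors and hypothetical sizes, to make the im2col-via-grouped-convolution
# trick concrete (a sketch):
def _demo_dynamic_filter_shapes():
    duf = DynamicUpsamplingFilter_3C((1, 5, 5))
    x = torch.randn(2, 3, 16, 16)  # [B, 3, H, W] center frame
    filt = torch.softmax(torch.randn(2, 25, 16, 16, 16), dim=1)  # [B, 25, R, H, W]
    out = duf(x, filt)             # [B, 3 * R, H, W]
    # F.pixel_shuffle(out, 4) would then give the [B, 3, 4H, 4W] frame.
    return out.shape               # torch.Size([2, 48, 16, 16])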
class DUF_16L(nn.Module):
'''Official DUF structure with 16 layers'''
def __init__(self, scale=4, adapt_official=False):
super(DUF_16L, self).__init__()
self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
self.dense_block_1 = DenseBlock(64, 64 // 2, t_reduce=False) # 64 + 32 * 3 = 160, T = 7
self.dense_block_2 = DenseBlock(160, 64 // 2, t_reduce=True) # 160 + 32 * 3 = 256, T = 1
self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3)
self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
bias=True)
self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
self.scale = scale
self.adapt_official = adapt_official
def forward(self, x):
'''
x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
Generate filters and image residual:
Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
Rx: [B, 3*16, 1, H, W]
'''
B, T, C, H, W = x.size()
x = x.permute(0, 2, 1, 3, 4) # [B, C, T, H, W] for Conv3D
x_center = x[:, :, T // 2, :, :]
x = self.conv3d_1(x)
x = self.dense_block_1(x)
x = self.dense_block_2(x) # reduce T to 1
x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)
# image residual
Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W]
# filter
Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W]
Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)
# Adapt to official model weights
if self.adapt_official:
adapt_official(Rx, scale=self.scale)
# dynamic filter
out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W]
out += Rx.squeeze_(2)
out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W]
return out
class DenseBlock_28L(nn.Module):
'''The first part of the dense blocks used in DUF_28L
Temporal dimension remains the same here'''
def __init__(self, nf=64, ng=16):
super(DenseBlock_28L, self).__init__()
pad = (1, 1, 1)
dense_block_l = []
for i in range(0, 9):
dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
dense_block_l.append(nn.ReLU())
dense_block_l.append(
nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True))
dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
dense_block_l.append(nn.ReLU())
dense_block_l.append(
nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True))
self.dense_blocks = nn.ModuleList(dense_block_l)
def forward(self, x):
        '''x: [B, C, T, H, W]
        C: 64 -> 208
        T: 7 -> 7 (temporal size unchanged)'''
for i in range(0, len(self.dense_blocks), 6):
y = x
for j in range(6):
y = self.dense_blocks[i + j](y)
x = torch.cat((x, y), 1)
return x
class DUF_28L(nn.Module):
'''Official DUF structure with 28 layers'''
def __init__(self, scale=4, adapt_official=False):
super(DUF_28L, self).__init__()
self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
self.dense_block_1 = DenseBlock_28L(64, 16) # 64 + 16 * 9 = 208, T = 7
self.dense_block_2 = DenseBlock(208, 16, t_reduce=True) # 208 + 16 * 3 = 256, T = 1
self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3)
self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
bias=True)
self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
self.scale = scale
self.adapt_official = adapt_official
def forward(self, x):
'''
x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
Generate filters and image residual:
Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
Rx: [B, 3*16, 1, H, W]
'''
B, T, C, H, W = x.size()
x = x.permute(0, 2, 1, 3, 4) # [B,C,T,H,W] for Conv3D
x_center = x[:, :, T // 2, :, :]
x = self.conv3d_1(x)
x = self.dense_block_1(x)
x = self.dense_block_2(x) # reduce T to 1
x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)
# image residual
Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W]
# filter
Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W]
Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)
# Adapt to official model weights
if self.adapt_official:
adapt_official(Rx, scale=self.scale)
# dynamic filter
out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W]
out += Rx.squeeze_(2)
        out = F.pixel_shuffle(out, self.scale) # [B, 3, H*scale, W*scale]
return out
class DenseBlock_52L(nn.Module):
'''The first part of the dense blocks used in DUF_52L
Temporal dimension remains the same here'''
def __init__(self, nf=64, ng=16):
super(DenseBlock_52L, self).__init__()
pad = (1, 1, 1)
dense_block_l = []
for i in range(0, 21):
dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
dense_block_l.append(nn.ReLU())
dense_block_l.append(
nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True))
dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
dense_block_l.append(nn.ReLU())
dense_block_l.append(
nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True))
self.dense_blocks = nn.ModuleList(dense_block_l)
def forward(self, x):
        '''x: [B, C, T, H, W]
        C: 64 -> 400; T: 7 -> 7 (temporal dimension preserved)'''
for i in range(0, len(self.dense_blocks), 6):
y = x
for j in range(6):
y = self.dense_blocks[i + j](y)
x = torch.cat((x, y), 1)
return x
class DUF_52L(nn.Module):
'''Official DUF structure with 52 layers'''
def __init__(self, scale=4, adapt_official=False):
super(DUF_52L, self).__init__()
self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
        self.dense_block_1 = DenseBlock_52L(64, 16) # 64 + 16 * 21 = 400, T = 7
self.dense_block_2 = DenseBlock(400, 16, t_reduce=True) # 400 + 16 * 3 = 448, T = 1
self.bn3d_2 = nn.BatchNorm3d(448, eps=1e-3, momentum=1e-3)
self.conv3d_2 = nn.Conv3d(448, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
bias=True)
self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
self.scale = scale
self.adapt_official = adapt_official
def forward(self, x):
'''
x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
Generate filters and image residual:
Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
Rx: [B, 3*16, 1, H, W]
'''
B, T, C, H, W = x.size()
x = x.permute(0, 2, 1, 3, 4) # [B,C,T,H,W] for Conv3D
x_center = x[:, :, T // 2, :, :]
x = self.conv3d_1(x)
x = self.dense_block_1(x)
x = self.dense_block_2(x)
x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)
# image residual
Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W]
# filter
Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W]
Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)
# Adapt to official model weights
if self.adapt_official:
adapt_official(Rx, scale=self.scale)
# dynamic filter
out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W]
out += Rx.squeeze_(2)
        out = F.pixel_shuffle(out, self.scale) # [B, 3, H*scale, W*scale]
return out
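if __name__ == '__main__':
    # Minimal smoke test, not part of the original file: run DUF_28L on a dummy
    # 7-frame clip and check the x4-upscaled output shape. eval() keeps the
    # BatchNorm3d layers on their running statistics.
    net = DUF_28L(scale=4).eval()
    clip = torch.randn(1, 7, 3, 32, 32)  # [B, T, C, H, W]
    with torch.no_grad():
        sr = net(clip)
    print(sr.shape)  # expected: torch.Size([1, 3, 128, 128])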

312
codes/models/archs/EDVR_arch.py Normal file
View File

@ -0,0 +1,312 @@
''' network architecture for EDVR '''
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util
try:
from models.archs.dcn.deform_conv import ModulatedDeformConvPack as DCN
except ImportError:
raise ImportError('Failed to import DCNv2 module.')
class Predeblur_ResNet_Pyramid(nn.Module):
def __init__(self, nf=128, HR_in=False):
        '''
        HR_in: True if the inputs are already at the high-resolution spatial size
        '''
super(Predeblur_ResNet_Pyramid, self).__init__()
self.HR_in = True if HR_in else False
if self.HR_in:
self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
else:
self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
self.RB_L1_1 = basic_block()
self.RB_L1_2 = basic_block()
self.RB_L1_3 = basic_block()
self.RB_L1_4 = basic_block()
self.RB_L1_5 = basic_block()
self.RB_L2_1 = basic_block()
self.RB_L2_2 = basic_block()
self.RB_L3_1 = basic_block()
self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, x):
if self.HR_in:
L1_fea = self.lrelu(self.conv_first_1(x))
L1_fea = self.lrelu(self.conv_first_2(L1_fea))
L1_fea = self.lrelu(self.conv_first_3(L1_fea))
else:
L1_fea = self.lrelu(self.conv_first(x))
L2_fea = self.lrelu(self.deblur_L2_conv(L1_fea))
L3_fea = self.lrelu(self.deblur_L3_conv(L2_fea))
L3_fea = F.interpolate(self.RB_L3_1(L3_fea), scale_factor=2, mode='bilinear',
align_corners=False)
L2_fea = self.RB_L2_1(L2_fea) + L3_fea
L2_fea = F.interpolate(self.RB_L2_2(L2_fea), scale_factor=2, mode='bilinear',
align_corners=False)
L1_fea = self.RB_L1_2(self.RB_L1_1(L1_fea)) + L2_fea
out = self.RB_L1_5(self.RB_L1_4(self.RB_L1_3(L1_fea)))
return out
class PCD_Align(nn.Module):
''' Alignment module using Pyramid, Cascading and Deformable convolution
with 3 pyramid levels.
'''
def __init__(self, nf=64, groups=8):
super(PCD_Align, self).__init__()
# L3: level 3, 1/4 spatial size
self.L3_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
extra_offset_mask=True)
# L2: level 2, 1/2 spatial size
self.L2_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset
self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
extra_offset_mask=True)
self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea
# L1: level 1, original spatial size
self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset
self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
extra_offset_mask=True)
self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea
# Cascading DCN
self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
extra_offset_mask=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, nbr_fea_l, ref_fea_l):
        '''align the neighboring frames to the reference frame at the feature level
        nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B, C, H, W] features
        '''
# L3
L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1)
L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset]))
# L2
L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1)
L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)
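        # offsets from the coarser level are doubled after x2 upsampling so
        # their displacement magnitudes match the finer scale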
L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1)))
L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset])
L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)
L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
# L1
L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1)
L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)
L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset])
L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)
L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
# Cascading
offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
offset = self.lrelu(self.cas_offset_conv1(offset))
offset = self.lrelu(self.cas_offset_conv2(offset))
L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset]))
return L1_fea
class TSA_Fusion(nn.Module):
''' Temporal Spatial Attention fusion module
Temporal: correlation;
Spatial: 3 pyramid levels.
'''
def __init__(self, nf=64, nframes=5, center=2):
super(TSA_Fusion, self).__init__()
self.center = center
# temporal attention (before fusion conv)
self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
# fusion conv: using 1x1 to save parameters and computation
self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
# spatial attention (after fusion conv)
self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
self.avgpool = nn.AvgPool2d(3, stride=2, padding=1)
self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True)
self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True)
self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, aligned_fea):
B, N, C, H, W = aligned_fea.size() # N video frames
#### temporal attention
emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone())
emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W) # [B, N, C(nf), H, W]
cor_l = []
for i in range(N):
emb_nbr = emb[:, i, :, :, :]
cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1) # B, 1, H, W
cor_l.append(cor_tmp)
cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1)) # B, N, H, W
cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W)
aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob
#### fusion
fea = self.lrelu(self.fea_fusion(aligned_fea))
#### spatial attention
att = self.lrelu(self.sAtt_1(aligned_fea))
att_max = self.maxpool(att)
att_avg = self.avgpool(att)
att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1)))
# pyramid levels
att_L = self.lrelu(self.sAtt_L1(att))
att_max = self.maxpool(att_L)
att_avg = self.avgpool(att_L)
att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1)))
att_L = self.lrelu(self.sAtt_L3(att_L))
att_L = F.interpolate(att_L, scale_factor=2, mode='bilinear', align_corners=False)
att = self.lrelu(self.sAtt_3(att))
att = att + att_L
att = self.lrelu(self.sAtt_4(att))
att = F.interpolate(att, scale_factor=2, mode='bilinear', align_corners=False)
att = self.sAtt_5(att)
att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att)))
att = torch.sigmoid(att)
fea = fea * att * 2 + att_add
return fea
class EDVR(nn.Module):
def __init__(self, nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10, center=None,
predeblur=False, HR_in=False, w_TSA=True):
super(EDVR, self).__init__()
self.nf = nf
self.center = nframes // 2 if center is None else center
self.is_predeblur = True if predeblur else False
self.HR_in = True if HR_in else False
self.w_TSA = w_TSA
ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
#### extract features (for each frame)
if self.is_predeblur:
self.pre_deblur = Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
self.conv_1x1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
else:
if self.HR_in:
self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
else:
self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.pcd_align = PCD_Align(nf=nf, groups=groups)
if self.w_TSA:
self.tsa_fusion = TSA_Fusion(nf=nf, nframes=nframes, center=self.center)
else:
self.tsa_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
#### reconstruction
self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, back_RBs)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)
self.pixel_shuffle = nn.PixelShuffle(2)
self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)
#### activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, x):
B, N, C, H, W = x.size() # N video frames
x_center = x[:, self.center, :, :, :].contiguous()
#### extract LR features
# L1
if self.is_predeblur:
L1_fea = self.pre_deblur(x.view(-1, C, H, W))
L1_fea = self.conv_1x1(L1_fea)
if self.HR_in:
H, W = H // 4, W // 4
else:
if self.HR_in:
L1_fea = self.lrelu(self.conv_first_1(x.view(-1, C, H, W)))
L1_fea = self.lrelu(self.conv_first_2(L1_fea))
L1_fea = self.lrelu(self.conv_first_3(L1_fea))
H, W = H // 4, W // 4
else:
L1_fea = self.lrelu(self.conv_first(x.view(-1, C, H, W)))
L1_fea = self.feature_extraction(L1_fea)
# L2
L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))
L2_fea = self.lrelu(self.fea_L2_conv2(L2_fea))
# L3
L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))
L3_fea = self.lrelu(self.fea_L3_conv2(L3_fea))
L1_fea = L1_fea.view(B, N, -1, H, W)
L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2)
L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4)
#### pcd align
# ref feature list
ref_fea_l = [
L1_fea[:, self.center, :, :, :].clone(), L2_fea[:, self.center, :, :, :].clone(),
L3_fea[:, self.center, :, :, :].clone()
]
aligned_fea = []
for i in range(N):
nbr_fea_l = [
L1_fea[:, i, :, :, :].clone(), L2_fea[:, i, :, :, :].clone(),
L3_fea[:, i, :, :, :].clone()
]
aligned_fea.append(self.pcd_align(nbr_fea_l, ref_fea_l))
aligned_fea = torch.stack(aligned_fea, dim=1) # [B, N, C, H, W]
if not self.w_TSA:
aligned_fea = aligned_fea.view(B, -1, H, W)
fea = self.tsa_fusion(aligned_fea)
out = self.recon_trunk(fea)
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
out = self.lrelu(self.HRconv(out))
out = self.conv_last(out)
if self.HR_in:
base = x_center
else:
base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
out += base
return out
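if __name__ == '__main__':
    # Minimal smoke test, not part of the original file: EDVR x4 SR on a dummy
    # 5-frame clip; requires the compiled DCN extension and a CUDA device.
    net = EDVR(nf=64, nframes=5).cuda().eval()
    clip = torch.randn(1, 5, 3, 32, 32).cuda()  # [B, N, C, H, W]
    with torch.no_grad():
        sr = net(clip)
    print(sr.shape)  # expected: torch.Size([1, 3, 128, 128])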

73
codes/models/archs/RRDBNet_arch.py Normal file
View File

@ -0,0 +1,73 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
arch_util.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5],
0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
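if __name__ == '__main__':
    # Minimal smoke test, not part of the original file. nb=23 follows the
    # ESRGAN setting; the x4 factor is fixed by the two nearest-neighbor x2
    # stages in forward().
    net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
    with torch.no_grad():
        print(net(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 3, 128, 128])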

55
codes/models/archs/SRResNet_arch.py Normal file
View File

@ -0,0 +1,55 @@
import functools
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util
class MSRResNet(nn.Module):
''' modified SRResNet'''
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
super(MSRResNet, self).__init__()
self.upscale = upscale
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
self.recon_trunk = arch_util.make_layer(basic_block, nb)
# upsampling
if self.upscale == 2:
self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
self.pixel_shuffle = nn.PixelShuffle(2)
elif self.upscale == 3:
self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
self.pixel_shuffle = nn.PixelShuffle(3)
elif self.upscale == 4:
self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
self.pixel_shuffle = nn.PixelShuffle(2)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
# initialization
arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last],
0.1)
if self.upscale == 4:
arch_util.initialize_weights(self.upconv2, 0.1)
def forward(self, x):
fea = self.lrelu(self.conv_first(x))
out = self.recon_trunk(fea)
if self.upscale == 4:
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
elif self.upscale == 3 or self.upscale == 2:
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
out = self.conv_last(self.lrelu(self.HRconv(out)))
base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
out += base
return out
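if __name__ == '__main__':
    # Minimal smoke test, not part of the original file (torch is imported here
    # because this module only imports torch.nn at the top): check each
    # supported scale; upscale=3 uses a single PixelShuffle(3) stage.
    import torch
    for s in (2, 3, 4):
        net = MSRResNet(upscale=s)
        with torch.no_grad():
            print(s, net(torch.randn(1, 3, 24, 24)).shape)  # [1, 3, 24*s, 24*s]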

137
codes/models/archs/TOF_arch.py Executable file
View File

@ -0,0 +1,137 @@
'''PyTorch implementation of TOFlow
Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
Code reference:
1. https://github.com/anchen1011/toflow
2. https://github.com/Coldog2333/pytoflow
'''
import torch
import torch.nn as nn
from .arch_util import flow_warp
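# the mean/std below are the ImageNet channel statistics commonly used to
# normalize inputs of pretrained networks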
def normalize(x):
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
return (x - mean) / std
def denormalize(x):
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
return x * std + mean
class SpyNet_Block(nn.Module):
'''A submodule of SpyNet.'''
def __init__(self):
super(SpyNet_Block, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
nn.BatchNorm2d(64), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
nn.BatchNorm2d(16), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
def forward(self, x):
'''
input: x: [ref im, nbr im, initial flow] - (B, 8, H, W)
output: estimated flow - (B, 2, H, W)
'''
return self.block(x)
class SpyNet(nn.Module):
'''SpyNet for estimating optical flow
Ranjan et al., Optical Flow Estimation using a Spatial Pyramid Network, 2016'''
def __init__(self):
super(SpyNet, self).__init__()
self.blocks = nn.ModuleList([SpyNet_Block() for _ in range(4)])
def forward(self, ref, nbr):
        '''Estimate optical flow at the coarsest level, then repeatedly upsample and refine
        input: ref: reference image - [B, 3, H, W]
               nbr: the neighboring image to be warped - [B, 3, H, W]
        output: estimated optical flow - [B, 2, H, W]
        '''
B, C, H, W = ref.size()
ref = [ref]
nbr = [nbr]
for _ in range(3):
ref.insert(
0,
nn.functional.avg_pool2d(input=ref[0], kernel_size=2, stride=2,
count_include_pad=False))
nbr.insert(
0,
nn.functional.avg_pool2d(input=nbr[0], kernel_size=2, stride=2,
count_include_pad=False))
flow = torch.zeros(B, 2, H // 16, W // 16).type_as(ref[0])
for i in range(4):
flow_up = nn.functional.interpolate(input=flow, scale_factor=2, mode='bilinear',
align_corners=True) * 2.0
flow = flow_up + self.blocks[i](torch.cat(
[ref[i], flow_warp(nbr[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
return flow
class TOFlow(nn.Module):
def __init__(self, adapt_official=False):
super(TOFlow, self).__init__()
self.SpyNet = SpyNet()
self.conv_3x7_64_9x9 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
self.conv_64_64_9x9 = nn.Conv2d(64, 64, 9, 1, 4)
self.conv_64_64_1x1 = nn.Conv2d(64, 64, 1)
self.conv_64_3_1x1 = nn.Conv2d(64, 3, 1)
self.relu = nn.ReLU(inplace=True)
self.adapt_official = adapt_official # True if using translated official weights else False
def forward(self, x):
"""
input: x: input frames - [B, 7, 3, H, W]
output: SR reference frame - [B, 3, H, W]
"""
B, T, C, H, W = x.size()
x = normalize(x.view(-1, C, H, W)).view(B, T, C, H, W)
ref_idx = 3
x_ref = x[:, ref_idx, :, :, :]
# In the official torch code, the 0-th frame is the reference frame
if self.adapt_official:
x = x[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
ref_idx = 0
x_warped = []
for i in range(7):
if i == ref_idx:
x_warped.append(x_ref)
else:
x_nbr = x[:, i, :, :, :]
flow = self.SpyNet(x_ref, x_nbr).permute(0, 2, 3, 1)
x_warped.append(flow_warp(x_nbr, flow))
x_warped = torch.stack(x_warped, dim=1)
x = x_warped.view(B, -1, H, W)
x = self.relu(self.conv_3x7_64_9x9(x))
x = self.relu(self.conv_64_64_9x9(x))
x = self.relu(self.conv_64_64_1x1(x))
x = self.conv_64_3_1x1(x) + x_ref
return denormalize(x)
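if __name__ == '__main__':
    # Minimal smoke test, not part of the original file: TOFlow maps a 7-frame
    # clip to the restored reference frame at the same resolution; H and W must
    # be divisible by 16 for the SpyNet pyramid to align.
    net = TOFlow().eval()
    clip = torch.randn(1, 7, 3, 64, 64)
    with torch.no_grad():
        print(net(clip).shape)  # torch.Size([1, 3, 64, 64])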

View File

79
codes/models/archs/arch_util.py Normal file
View File

@ -0,0 +1,79 @@
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
def initialize_weights(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
for net in net_l:
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias.data, 0.0)
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.Sequential(*layers)
class ResidualBlock_noBN(nn.Module):
'''Residual block w/o BN
---Conv-ReLU-Conv-+-
|________________|
'''
def __init__(self, nf=64):
super(ResidualBlock_noBN, self).__init__()
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
# initialization
initialize_weights([self.conv1, self.conv2], 0.1)
def forward(self, x):
identity = x
out = F.relu(self.conv1(x), inplace=True)
out = self.conv2(out)
return identity + out
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
"""Warp an image or feature map with optical flow
Args:
x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), displacements in pixels (not normalized)
interp_mode (str): 'nearest' or 'bilinear'
padding_mode (str): 'zeros' or 'border' or 'reflection'
Returns:
Tensor: warped image or feature map
"""
assert x.size()[-2:] == flow.size()[1:3]
B, C, H, W = x.size()
# mesh grid
grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
grid.requires_grad = False
grid = grid.type_as(x)
vgrid = grid + flow
# scale grid to [-1,1]
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
return output
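if __name__ == '__main__':
    # Minimal sanity check, not part of the original file: a zero flow field is
    # an identity warp, so the output keeps the input size (and, under the
    # PyTorch 1.0/1.1 grid_sample semantics this code targets, its values).
    x = torch.randn(2, 3, 8, 8)
    out = flow_warp(x, torch.zeros(2, 8, 8, 2))
    print(out.shape)  # torch.Size([2, 3, 8, 8])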

7
codes/models/archs/dcn/__init__.py Normal file
View File

@ -0,0 +1,7 @@
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack,
deform_conv, modulated_deform_conv)
__all__ = [
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv'
]

291
codes/models/archs/dcn/deform_conv.py Normal file
View File

@ -0,0 +1,291 @@
import math
import logging
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from . import deform_conv_cuda
logger = logging.getLogger('base')
class DeformConvFunction(Function):
@staticmethod
def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1,
deformable_groups=1, im2col_step=64):
if input is not None and input.dim() != 4:
raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(
input.dim()))
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(
DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
if not input.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output,
ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0],
ctx.dilation[1], ctx.dilation[0], ctx.groups,
ctx.deformable_groups, cur_im2col_step)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
if not grad_output.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
deform_conv_cuda.deform_conv_backward_input_cuda(
input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0],
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
ctx.deformable_groups, cur_im2col_step)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
deform_conv_cuda.deform_conv_backward_parameters_cuda(
input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1],
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
ctx.deformable_groups, 1, cur_im2col_step)
return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
if not all(map(lambda s: s > 0, output_size)):
raise ValueError("convolution input is too small (output would be {})".format('x'.join(
map(str, output_size))))
return output_size
class ModulatedDeformConvFunction(Function):
@staticmethod
def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1,
groups=1, deformable_groups=1):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
deform_conv_cuda.modulated_deform_conv_cuda_forward(
input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2],
weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
deform_conv_cuda.modulated_deform_conv_cuda_backward(
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight,
grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3],
ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None,
None)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * ctx.padding - (ctx.dilation *
(kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding - (ctx.dilation *
(kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply
class DeformConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
groups=1, deformable_groups=1, bias=False):
super(DeformConv, self).__init__()
assert not bias
        assert in_channels % groups == 0, \
            'in_channels {} is not divisible by groups {}'.format(in_channels, groups)
        assert out_channels % groups == 0, \
            'out_channels {} is not divisible by groups {}'.format(out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, x, offset):
return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
class DeformConvPack(DeformConv):
def __init__(self, *args, **kwargs):
super(DeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
class ModulatedDeformConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
groups=1, deformable_groups=1, bias=True):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, x, offset, mask):
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups,
self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
def __init__(self, *args, extra_offset_mask=False, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.extra_offset_mask = extra_offset_mask
self.conv_offset_mask = nn.Conv2d(
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
def forward(self, x):
if self.extra_offset_mask:
# x = [input, features]
out = self.conv_offset_mask(x[1])
x = x[0]
else:
out = self.conv_offset_mask(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
offset_mean = torch.mean(torch.abs(offset))
if offset_mean > 100:
logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean))
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups,
self.deformable_groups)
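if __name__ == '__main__':
    # Minimal smoke test, not part of the original file: the Pack variant
    # predicts its own offsets and mask, so it is a drop-in replacement for a
    # 3x3 Conv2d; requires the compiled deform_conv_cuda extension and a GPU.
    dcn = ModulatedDeformConvPack(16, 16, 3, stride=1, padding=1,
                                  deformable_groups=2).cuda()
    x = torch.randn(2, 16, 24, 24).cuda()
    print(dcn(x).shape)  # torch.Size([2, 16, 24, 24])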

22
codes/models/archs/dcn/setup.py Normal file
View File

@ -0,0 +1,22 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
def make_cuda_ext(name, sources):
return CUDAExtension(
name='{}'.format(name), sources=[p for p in sources], extra_compile_args={
'cxx': [],
'nvcc': [
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
]
})
setup(
name='deform_conv', ext_modules=[
make_cuda_ext(name='deform_conv_cuda',
sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu'])
], cmdclass={'build_ext': BuildExtension}, zip_safe=False)

695
codes/models/archs/dcn/src/deform_conv_cuda.cpp Normal file
View File

@ -0,0 +1,695 @@
// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
#include <torch/extension.h>
#include <cmath>
#include <vector>
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
const int channels, const int height, const int width,
const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor data_col);
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
const int channels, const int height, const int width,
const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor grad_im);
void deformable_col2im_coord(
const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const int channels, const int height,
const int width, const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, at::Tensor grad_offset);
void modulated_deformable_im2col_cuda(
const at::Tensor data_im, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
    const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor data_col);
void modulated_deformable_col2im_cuda(
const at::Tensor data_col, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
    const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor grad_im);
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
at::Tensor grad_mask);
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
int padW, int dilationH, int dilationW, int group,
int deformable_group) {
AT_CHECK(weight.ndimension() == 4,
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
"but got: %s",
weight.ndimension());
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
AT_CHECK(kW > 0 && kH > 0,
"kernel size should be greater than zero, but got kH: %d kW: %d", kH,
kW);
AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
"kernel size should be consistent with weight, ",
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
kW, weight.size(2), weight.size(3));
AT_CHECK(dW > 0 && dH > 0,
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
AT_CHECK(
dilationW > 0 && dilationH > 0,
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
dilationH, dilationW);
int ndim = input.ndimension();
int dimf = 0;
int dimh = 1;
int dimw = 2;
if (ndim == 4) {
dimf++;
dimh++;
dimw++;
}
AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
ndim);
long nInputPlane = weight.size(1) * group;
long inputHeight = input.size(dimh);
long inputWidth = input.size(dimw);
long nOutputPlane = weight.size(0);
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
AT_CHECK(nInputPlane % deformable_group == 0,
"input channels must divide deformable group size");
if (outputWidth < 1 || outputHeight < 1)
AT_ERROR(
"Given input size: (%ld x %ld x %ld). "
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
outputWidth);
AT_CHECK(input.size(1) == nInputPlane,
"invalid number of input planes, expected: %d, but got: %d",
nInputPlane, input.size(1));
AT_CHECK((inputHeight >= kH && inputWidth >= kW),
"input image is smaller than kernel");
AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
"invalid spatial size of offset, expected height: %d width: %d, but "
"got height: %d width: %d",
outputHeight, outputWidth, offset.size(2), offset.size(3));
AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
"invalid number of channels of offset");
if (gradOutput != NULL) {
AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
"invalid number of gradOutput planes, expected: %d, but got: %d",
nOutputPlane, gradOutput->size(dimf));
AT_CHECK((gradOutput->size(dimh) == outputHeight &&
gradOutput->size(dimw) == outputWidth),
"invalid size of gradOutput, expected height: %d width: %d , but "
"got height: %d width: %d",
outputHeight, outputWidth, gradOutput->size(dimh),
gradOutput->size(dimw));
}
}
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step) {
// todo: resize columns to include im2col: done
// todo: add im2col_step as input
// todo: add new output buffer and transpose it to output (or directly
// transpose output) todo: possibly change data indexing because of
// parallel_imgs
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, group, deformable_group);
input = input.contiguous();
offset = offset.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input.unsqueeze_(0);
offset.unsqueeze_(0);
}
  // todo: assert batch size divisible by im2col_step
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
outputHeight, outputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
ones = at::ones({outputHeight, outputWidth}, input.options());
}
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
at::Tensor output_buffer =
at::zeros({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth},
output.options());
output_buffer = output_buffer.view(
{output_buffer.size(0), group, output_buffer.size(1) / group,
output_buffer.size(2), output_buffer.size(3)});
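  // Process the batch in chunks of im2col_step images: each chunk is unfolded
  // once into `columns`, then multiplied group-by-group against the weights.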
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
output_buffer[elt][g] = output_buffer[elt][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output_buffer[elt][g]);
}
}
output_buffer = output_buffer.view(
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
output_buffer.size(3), output_buffer.size(4)});
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step, outputHeight, outputWidth});
output_buffer.transpose_(1, 2);
output.copy_(output_buffer);
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
output = output.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
}
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW,
int dilationH, int group,
int deformable_group, int im2col_step) {
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, group, deformable_group);
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view({1, input.size(0), input.size(1), input.size(2)});
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
// change order of grad output
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight,
outputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
gradOutput = gradOutput.view(
{gradOutput.size(0), group, gradOutput.size(1) / group,
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
for (int g = 0; g < group; g++) {
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradOutput = gradOutput.view(
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
dilationH, dilationW, im2col_step, deformable_group,
gradOffset[elt]);
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, gradInput[elt]);
}
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
gradOffset = gradOffset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
gradOffset =
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
}
int deform_conv_backward_parameters_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, float scale, int im2col_step) {
// todo: transpose and reshape outGrad
// todo: reshape columns
// todo: add im2col_step as input
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
padW, dilationH, dilationW, group, deformable_group);
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view(
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = gradWeight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
outputHeight, outputWidth});
gradOutputBuffer.copy_(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
    // divide into groups
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
gradWeight =
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3)});
for (int g = 0; g < group; g++) {
gradWeight[g] = gradWeight[g]
.flatten(1)
.addmm_(gradOutputBuffer[elt][g].flatten(1),
columns[g].transpose(1, 0), 1.0, scale)
.view_as(gradWeight[g]);
}
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0),
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3),
gradWeight.size(4)});
}
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
}
return 1;
}
void modulated_deform_conv_cuda_forward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int deformable_group,
const bool with_bias) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
             kernel_h, kernel_w, kernel_h_, kernel_w_);
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
             channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
// resize output
output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns
columns =
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
input.options());
output = output.view({output.size(0), group, output.size(1) / group,
output.size(2), output.size(3)});
for (int b = 0; b < batch; b++) {
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
    // divide into groups
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++) {
output[b][g] = output[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output[b][g]);
}
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
output = output.view({output.size(0), output.size(1) * output.size(2),
output.size(3), output.size(4)});
if (with_bias) {
output += bias.view({1, bias.size(0), 1, 1});
}
}
void modulated_deform_conv_cuda_backward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
input.options());
grad_output =
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
grad_output.size(2), grad_output.size(3)});
for (int b = 0; b < batch; b++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
grad_output[b][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(
columns, input[b], offset[b], mask[b], 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(
columns, offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and
// group
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
grad_weight.size(1), grad_weight.size(2),
grad_weight.size(3)});
if (with_bias)
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
for (int g = 0; g < group; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
.view_as(grad_weight[g]);
if (with_bias) {
grad_bias[g] =
grad_bias[g]
.view({-1, 1})
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
.view(-1);
}
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2), grad_weight.size(3),
grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
}
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
grad_output.size(2), grad_output.size(3),
grad_output.size(4)});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
"deform forward (CUDA)");
m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
"deform_conv_backward_input (CUDA)");
m.def("deform_conv_backward_parameters_cuda",
&deform_conv_backward_parameters_cuda,
"deform_conv_backward_parameters (CUDA)");
m.def("modulated_deform_conv_cuda_forward",
&modulated_deform_conv_cuda_forward,
"modulated deform conv forward (CUDA)");
m.def("modulated_deform_conv_cuda_backward",
&modulated_deform_conv_cuda_backward,
"modulated deform conv backward (CUDA)");
}
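For reference, the `height_out` / `width_out` expressions in the CUDA entry points above are the standard convolution output-size arithmetic. A minimal Python sketch of that formula (illustrative only, not part of this file):

```python
# Mirrors: (size + 2*pad - (dilation*(kernel - 1) + 1)) / stride + 1
def conv_out_size(size, kernel, stride=1, pad=0, dilation=1):
    return (size + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1

assert conv_out_size(64, kernel=3, stride=1, pad=1) == 64  # shape-preserving conv
assert conv_out_size(64, kernel=4, stride=2, pad=1) == 32  # stride-2 downsampling
```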

View File

@ -0,0 +1,88 @@
import torch
import torch.nn as nn
import torchvision
class Discriminator_VGG_128(nn.Module):
def __init__(self, in_nc, nf):
super(Discriminator_VGG_128, self).__init__()
# [64, 128, 128]
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
# [64, 64, 64]
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
# [128, 32, 32]
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
# [256, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
# [512, 8, 8]
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
self.linear1 = nn.Linear(512 * 4 * 4, 100)
self.linear2 = nn.Linear(100, 1)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.lrelu(self.conv0_0(x))
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
fea = fea.view(fea.size(0), -1)
fea = self.lrelu(self.linear1(fea))
out = self.linear2(fea)
return out
class VGGFeatureExtractor(nn.Module):
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
device=torch.device('cpu')):
super(VGGFeatureExtractor, self).__init__()
self.use_input_norm = use_input_norm
if use_bn:
model = torchvision.models.vgg19_bn(pretrained=True)
else:
model = torchvision.models.vgg19(pretrained=True)
if self.use_input_norm:
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
# [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
# [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
self.register_buffer('mean', mean)
self.register_buffer('std', std)
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
# No need to BP to variable
for k, v in self.features.named_parameters():
v.requires_grad = False
def forward(self, x):
# Assume input range is [0, 1]
if self.use_input_norm:
x = (x - self.mean) / self.std
output = self.features(x)
return output
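A hypothetical usage sketch for the two modules above (the shapes follow the channel comments in `Discriminator_VGG_128`; the variable names here are assumptions, not from this file):

```python
import torch

# Five stride-2 stages reduce 128 -> 4, matching Linear(512 * 4 * 4, 100).
netD = Discriminator_VGG_128(in_nc=3, nf=64)
logits = netD(torch.randn(2, 3, 128, 128))  # -> shape [2, 1]

# feature_layer=34 selects the VGG19 conv5_4 output, before its ReLU.
netF = VGGFeatureExtractor(feature_layer=34)
feat = netF(torch.rand(2, 3, 128, 128))     # input assumed in [0, 1]
```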

116
codes/models/base_model.py Normal file
View File

@ -0,0 +1,116 @@
import os
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
class BaseModel():
def __init__(self, opt):
self.opt = opt
self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
self.is_train = opt['is_train']
self.schedulers = []
self.optimizers = []
def feed_data(self, data):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
pass
def get_current_losses(self):
pass
def print_network(self):
pass
def save(self, label):
pass
def load(self):
pass
def _set_lr(self, lr_groups_l):
"""Set learning rate for warmup
lr_groups_l: list of lr_groups, one for each optimizer"""
for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
for param_group, lr in zip(optimizer.param_groups, lr_groups):
param_group['lr'] = lr
def _get_init_lr(self):
"""Get the initial lr, which is set by the scheduler"""
init_lr_groups_l = []
for optimizer in self.optimizers:
init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
return init_lr_groups_l
def update_learning_rate(self, cur_iter, warmup_iter=-1):
for scheduler in self.schedulers:
scheduler.step()
# set up warm-up learning rate
if cur_iter < warmup_iter:
# get initial lr for each group
init_lr_g_l = self._get_init_lr()
# modify warming-up learning rates
warm_up_lr_l = []
for init_lr_g in init_lr_g_l:
warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
# set learning rate
self._set_lr(warm_up_lr_l)
def get_current_learning_rate(self):
return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
def get_network_description(self, network):
"""Get the string and total parameters of the network"""
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
network = network.module
return str(network), sum(map(lambda x: x.numel(), network.parameters()))
def save_network(self, network, network_label, iter_label):
save_filename = '{}_{}.pth'.format(iter_label, network_label)
save_path = os.path.join(self.opt['path']['models'], save_filename)
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
network = network.module
state_dict = network.state_dict()
for key, param in state_dict.items():
state_dict[key] = param.cpu()
torch.save(state_dict, save_path)
def load_network(self, load_path, network, strict=True):
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
network = network.module
load_net = torch.load(load_path)
load_net_clean = OrderedDict() # remove unnecessary 'module.'
for k, v in load_net.items():
if k.startswith('module.'):
load_net_clean[k[7:]] = v
else:
load_net_clean[k] = v
network.load_state_dict(load_net_clean, strict=strict)
def save_training_state(self, epoch, iter_step):
"""Save training state during training, which will be used for resuming"""
state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
for s in self.schedulers:
state['schedulers'].append(s.state_dict())
for o in self.optimizers:
state['optimizers'].append(o.state_dict())
save_filename = '{}.state'.format(iter_step)
save_path = os.path.join(self.opt['path']['training_state'], save_filename)
torch.save(state, save_path)
def resume_training(self, resume_state):
"""Resume the optimizers and schedulers for training"""
resume_optimizers = resume_state['optimizers']
resume_schedulers = resume_state['schedulers']
assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
for i, o in enumerate(resume_optimizers):
self.optimizers[i].load_state_dict(o)
for i, s in enumerate(resume_schedulers):
self.schedulers[i].load_state_dict(s)
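`update_learning_rate` above implements linear warm-up: before `warmup_iter`, each group's learning rate is scaled as `initial_lr / warmup_iter * cur_iter`. A worked sketch (illustrative; 4e-4 and 4000 are placeholder values):

```python
init_lr, warmup_iter = 4e-4, 4000
for cur_iter in (1, 2000, 4000):
    lr = init_lr / warmup_iter * cur_iter if cur_iter < warmup_iter else init_lr
    print(cur_iter, lr)  # 1e-07, 2e-04, then the full 4e-04 once warm-up ends
```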

74
codes/models/loss.py Normal file
View File

@ -0,0 +1,74 @@
import torch
import torch.nn as nn
class CharbonnierLoss(nn.Module):
"""Charbonnier Loss (L1)"""
def __init__(self, eps=1e-6):
super(CharbonnierLoss, self).__init__()
self.eps = eps
def forward(self, x, y):
diff = x - y
loss = torch.sum(torch.sqrt(diff * diff + self.eps))
return loss
# Define GAN loss: [gan | ragan | lsgan | wgan-gp]
class GANLoss(nn.Module):
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
super(GANLoss, self).__init__()
self.gan_type = gan_type.lower()
self.real_label_val = real_label_val
self.fake_label_val = fake_label_val
if self.gan_type == 'gan' or self.gan_type == 'ragan':
self.loss = nn.BCEWithLogitsLoss()
elif self.gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif self.gan_type == 'wgan-gp':
def wgan_loss(input, target):
# target is boolean
return -1 * input.mean() if target else input.mean()
self.loss = wgan_loss
else:
raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
def get_target_label(self, input, target_is_real):
if self.gan_type == 'wgan-gp':
return target_is_real
if target_is_real:
return torch.empty_like(input).fill_(self.real_label_val)
else:
return torch.empty_like(input).fill_(self.fake_label_val)
def forward(self, input, target_is_real):
target_label = self.get_target_label(input, target_is_real)
loss = self.loss(input, target_label)
return loss
class GradientPenaltyLoss(nn.Module):
def __init__(self, device=torch.device('cpu')):
super(GradientPenaltyLoss, self).__init__()
self.register_buffer('grad_outputs', torch.Tensor())
self.grad_outputs = self.grad_outputs.to(device)
def get_grad_outputs(self, input):
if self.grad_outputs.size() != input.size():
self.grad_outputs.resize_(input.size()).fill_(1.0)
return self.grad_outputs
def forward(self, interp, interp_crit):
grad_outputs = self.get_grad_outputs(interp_crit)
grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
grad_outputs=grad_outputs, create_graph=True,
retain_graph=True, only_inputs=True)[0]
grad_interp = grad_interp.view(grad_interp.size(0), -1)
grad_interp_norm = grad_interp.norm(2, dim=1)
loss = ((grad_interp_norm - 1)**2).mean()
return loss
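A hypothetical usage sketch for `GANLoss` above (variable names are assumptions; with `'gan'`, the target is a filled label tensor and the criterion is `BCEWithLogitsLoss`):

```python
import torch

cri_gan = GANLoss('gan', real_label_val=1.0, fake_label_val=0.0)
pred_fake = torch.randn(4, 1)                  # discriminator logits on fakes
l_g_gan = cri_gan(pred_fake, True)             # generator: push fakes toward "real"
l_d_fake = cri_gan(pred_fake.detach(), False)  # discriminator: label fakes as fake
```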

View File

@ -0,0 +1,144 @@
import math
from collections import Counter
from collections import defaultdict
import torch
from torch.optim.lr_scheduler import _LRScheduler
class MultiStepLR_Restart(_LRScheduler):
def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
clear_state=False, last_epoch=-1):
self.milestones = Counter(milestones)
self.gamma = gamma
self.clear_state = clear_state
self.restarts = restarts if restarts else [0]
self.restarts = [v + 1 for v in self.restarts]
self.restart_weights = weights if weights else [1]
assert len(self.restarts) == len(
self.restart_weights), 'restarts and their weights do not match.'
super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
def get_lr(self):
if self.last_epoch in self.restarts:
if self.clear_state:
self.optimizer.state = defaultdict(dict)
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
if self.last_epoch not in self.milestones:
return [group['lr'] for group in self.optimizer.param_groups]
return [
group['lr'] * self.gamma**self.milestones[self.last_epoch]
for group in self.optimizer.param_groups
]
class CosineAnnealingLR_Restart(_LRScheduler):
def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
self.T_period = T_period
self.T_max = self.T_period[0] # current T period
self.eta_min = eta_min
self.restarts = restarts if restarts else [0]
self.restarts = [v + 1 for v in self.restarts]
self.restart_weights = weights if weights else [1]
self.last_restart = 0
assert len(self.restarts) == len(
self.restart_weights), 'restarts and their weights do not match.'
super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
def get_lr(self):
if self.last_epoch == 0:
return self.base_lrs
elif self.last_epoch in self.restarts:
self.last_restart = self.last_epoch
self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
return [
group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
]
return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
(1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
(group['lr'] - self.eta_min) + self.eta_min
for group in self.optimizer.param_groups]
if __name__ == "__main__":
optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
betas=(0.9, 0.99))
##############################
# MultiStepLR_Restart
##############################
## Original
lr_steps = [200000, 400000, 600000, 800000]
restarts = None
restart_weights = None
## two
lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
restarts = [500000]
restart_weights = [1]
## four
lr_steps = [
50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
]
restarts = [250000, 500000, 750000]
restart_weights = [1, 1, 1]
scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
clear_state=False)
##############################
# Cosine Annealing Restart
##############################
## two
T_period = [500000, 500000]
restarts = [500000]
restart_weights = [1]
## four
T_period = [250000, 250000, 250000, 250000]
restarts = [250000, 500000, 750000]
restart_weights = [1, 1, 1]
scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
weights=restart_weights)
##############################
# Draw figure
##############################
N_iter = 1000000
lr_l = list(range(N_iter))
for i in range(N_iter):
scheduler.step()
current_lr = optimizer.param_groups[0]['lr']
lr_l[i] = current_lr
import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.ticker as mtick
mpl.style.use('default')
import seaborn
seaborn.set(style='whitegrid')
seaborn.set_context('paper')
plt.figure(1)
plt.subplot(111)
plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
plt.title('Title', fontsize=16, color='k')
plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
legend = plt.legend(loc='upper right', shadow=False)
ax = plt.gca()
labels = ax.get_xticks().tolist()
for k, v in enumerate(labels):
labels[k] = str(int(v / 1000)) + 'K'
ax.set_xticklabels(labels)
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
ax.set_ylabel('Learning rate')
ax.set_xlabel('Iteration')
fig = plt.gcf()
plt.show()
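A small sanity-check sketch for `MultiStepLR_Restart` above (illustrative values; note that the constructor shifts `restarts` by +1 internally, so the reset fires at iteration 7 here):

```python
import torch

optimizer = torch.optim.Adam([torch.zeros(3, 3)], lr=2e-4)
scheduler = MultiStepLR_Restart(optimizer, milestones=[4, 8], restarts=[6],
                                weights=[0.5], gamma=0.1)
for it in range(1, 11):
    scheduler.step()
    print(it, optimizer.param_groups[0]['lr'])
# 2e-4 until it=3; *0.1 at milestone 4 (-> 2e-5);
# reset to initial_lr * 0.5 at it=7 (-> 1e-4); *0.1 at milestone 8 (-> 1e-5)
```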

57
codes/models/networks.py Normal file
View File

@ -0,0 +1,57 @@
import torch
import models.archs.SRResNet_arch as SRResNet_arch
import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.RRDBNet_arch as RRDBNet_arch
import models.archs.EDVR_arch as EDVR_arch
# Generator
def define_G(opt):
opt_net = opt['network_G']
which_model = opt_net['which_model_G']
# image restoration
if which_model == 'MSRResNet':
netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == 'RRDBNet':
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'])
# video restoration
elif which_model == 'EDVR':
netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
back_RBs=opt_net['back_RBs'], center=opt_net['center'],
predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
w_TSA=opt_net['w_TSA'])
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG
# Discriminator
def define_D(opt):
opt_net = opt['network_D']
which_model = opt_net['which_model_D']
if which_model == 'discriminator_vgg_128':
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD
# Define network used for perceptual loss
def define_F(opt, use_bn=False):
gpu_ids = opt['gpu_ids']
device = torch.device('cuda' if gpu_ids else 'cpu')
# PyTorch pretrained VGG19-54, before ReLU.
if use_bn:
feature_layer = 49
else:
feature_layer = 34
netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
use_input_norm=True, device=device)
netF.eval() # No need to train
return netF
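A hypothetical minimal `opt` dict accepted by `define_G` above (keys read from the function body; the values are placeholders, not a shipped config):

```python
opt = {'network_G': {'which_model_G': 'MSRResNet',
                     'in_nc': 3, 'out_nc': 3, 'nf': 64, 'nb': 16, 'scale': 4}}
netG = define_G(opt)  # -> SRResNet_arch.MSRResNet(..., upscale=4)
```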

116
codes/options/options.py Normal file
View File

@ -0,0 +1,116 @@
import os
import os.path as osp
import logging
import yaml
from utils.util import OrderedYaml
Loader, Dumper = OrderedYaml()
def parse(opt_path, is_train=True):
with open(opt_path, mode='r') as f:
opt = yaml.load(f, Loader=Loader)
# export CUDA_VISIBLE_DEVICES
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
opt['is_train'] = is_train
if opt['distortion'] == 'sr':
scale = opt['scale']
# datasets
for phase, dataset in opt['datasets'].items():
phase = phase.split('_')[0]
dataset['phase'] = phase
if opt['distortion'] == 'sr':
dataset['scale'] = scale
is_lmdb = False
if dataset.get('dataroot_GT', None) is not None:
dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
if dataset['dataroot_GT'].endswith('lmdb'):
is_lmdb = True
if dataset.get('dataroot_LQ', None) is not None:
dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
if dataset['dataroot_LQ'].endswith('lmdb'):
is_lmdb = True
dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
if dataset['mode'].endswith('mc'): # for memcached
dataset['data_type'] = 'mc'
dataset['mode'] = dataset['mode'].replace('_mc', '')
# path
for key, path in opt['path'].items():
if path and key in opt['path'] and key != 'strict_load':
opt['path'][key] = osp.expanduser(path)
opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
if is_train:
experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
opt['path']['experiments_root'] = experiments_root
opt['path']['models'] = osp.join(experiments_root, 'models')
opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
opt['path']['log'] = experiments_root
opt['path']['val_images'] = osp.join(experiments_root, 'val_images')
# change some options for debug mode
if 'debug' in opt['name']:
opt['train']['val_freq'] = 8
opt['logger']['print_freq'] = 1
opt['logger']['save_checkpoint_freq'] = 8
else: # test
results_root = osp.join(opt['path']['root'], 'results', opt['name'])
opt['path']['results_root'] = results_root
opt['path']['log'] = results_root
# network
if opt['distortion'] == 'sr':
opt['network_G']['scale'] = scale
return opt
def dict2str(opt, indent_l=1):
'''dict to string for logger'''
msg = ''
for k, v in opt.items():
if isinstance(v, dict):
msg += ' ' * (indent_l * 2) + k + ':[\n'
msg += dict2str(v, indent_l + 1)
msg += ' ' * (indent_l * 2) + ']\n'
else:
msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
return msg
class NoneDict(dict):
def __missing__(self, key):
return None
# convert to NoneDict, which returns None for missing keys.
def dict_to_nonedict(opt):
if isinstance(opt, dict):
new_opt = dict()
for key, sub_opt in opt.items():
new_opt[key] = dict_to_nonedict(sub_opt)
return NoneDict(**new_opt)
elif isinstance(opt, list):
return [dict_to_nonedict(sub_opt) for sub_opt in opt]
else:
return opt
def check_resume(opt, resume_iter):
'''Check resume states and pretrain_model paths'''
logger = logging.getLogger('base')
if opt['path']['resume_state']:
if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
'pretrain_model_D', None) is not None:
logger.warning('pretrain_model path will be ignored when resuming training.')
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
'{}_G.pth'.format(resume_iter))
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
if 'gan' in opt['model']:
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
'{}_D.pth'.format(resume_iter))
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
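A short sketch of the `NoneDict` convention above (illustrative): after `dict_to_nonedict`, missing option keys read as `None` instead of raising `KeyError`:

```python
opt = dict_to_nonedict({'train': {'lr_G': 1e-4}})
assert opt['train']['lr_G'] == 1e-4
assert opt['train']['warmup_iter'] is None  # missing key -> None
```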

View File

@ -0,0 +1,32 @@
name: RRDB_ESRGAN_x4
suffix: ~ # add suffix to saved images
model: sr
distortion: sr
scale: 4
crop_border: ~ # crop border during evaluation. If None (~), crop 'scale' pixels
gpu_ids: [0]
datasets:
test_1: # the 1st test dataset
name: set5
mode: LQGT
dataroot_GT: ../datasets/val_set5/Set5
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
test_2: # the 2nd test dataset
name: set14
mode: LQGT
dataroot_GT: ../datasets/val_set14/Set14
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
#### network structures
network_G:
which_model_G: RRDBNet
in_nc: 3
out_nc: 3
nf: 64
nb: 23
upscale: 4
#### path
path:
pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth

View File

@ -0,0 +1,32 @@
name: MSRGANx4
suffix: ~ # add suffix to saved images
model: sr
distortion: sr
scale: 4
crop_border: ~ # crop border during evaluation. If None (~), crop 'scale' pixels
gpu_ids: [0]
datasets:
test_1: # the 1st test dataset
name: set5
mode: LQGT
dataroot_GT: ../datasets/val_set5/Set5
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
test_2: # the 2nd test dataset
name: set14
mode: LQGT
dataroot_GT: ../datasets/val_set14/Set14
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
#### network structures
network_G:
which_model_G: MSRResNet
in_nc: 3
out_nc: 3
nf: 64
nb: 16
upscale: 4
#### path
path:
pretrain_model_G: ../experiments/pretrained_models/MSRGANx4.pth

View File

@ -0,0 +1,48 @@
name: MSRResNetx4
suffix: ~ # add suffix to saved images
model: sr
distortion: sr
scale: 4
crop_border: ~ # crop border during evaluation. If None (~), crop 'scale' pixels
gpu_ids: [0]
datasets:
test_1: # the 1st test dataset
name: set5
mode: LQGT
dataroot_GT: ../datasets/val_set5/Set5
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
test_2: # the 2nd test dataset
name: set14
mode: LQGT
dataroot_GT: ../datasets/val_set14/Set14
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
test_3:
name: bsd100
mode: LQGT
dataroot_GT: ../datasets/BSD/BSDS100
dataroot_LQ: ../datasets/BSD/BSDS100_bicLRx4
test_4:
name: urban100
mode: LQGT
dataroot_GT: ../datasets/urban100
dataroot_LQ: ../datasets/urban100_bicLRx4
test_5:
name: div2k100
mode: LQGT
dataroot_GT: ../datasets/DIV2K100/DIV2K_valid_HR
dataroot_LQ: ../datasets/DIV2K100/DIV2K_valid_bicLRx4
#### network structures
network_G:
which_model_G: MSRResNet
in_nc: 3
out_nc: 3
nf: 64
nb: 16
upscale: 4
#### path
path:
pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth

View File

@ -0,0 +1,80 @@
#### general settings
name: 002_EDVR_EDVRwoTSAIni_lr4e-4_600k_REDS_LrCAR4S_fixTSA50k_new
use_tb_logger: true
model: video_base
distortion: sr
scale: 4
gpu_ids: [0,1,2,3,4,5,6,7]
#### datasets
datasets:
train:
name: REDS
mode: REDS
interval_list: [1]
random_reverse: false
border_mode: false
dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
cache_keys: ~
N_frames: 5
use_shuffle: true
n_workers: 3 # per GPU
batch_size: 32
GT_size: 256
LQ_size: 64
use_flip: true
use_rot: true
color: RGB
val:
name: REDS4
mode: video_test
dataroot_GT: ../datasets/REDS4/GT
dataroot_LQ: ../datasets/REDS4/sharp_bicubic
cache_data: True
N_frames: 5
padding: new_info
#### network structures
network_G:
which_model_G: EDVR
nf: 64
nframes: 5
groups: 8
front_RBs: 5
back_RBs: 10
predeblur: false
HR_in: false
w_TSA: true
#### path
path:
pretrain_model_G: ../experiments/pretrained_models/EDVR_REDS_SR_M_woTSA.pth
strict_load: false
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 4e-4
lr_scheme: CosineAnnealingLR_Restart
beta1: 0.9
beta2: 0.99
niter: 600000
ft_tsa_only: 50000
warmup_iter: -1 # -1: no warm up
T_period: [50000, 100000, 150000, 150000, 150000]
restarts: [50000, 150000, 300000, 450000]
restart_weights: [1, 1, 1, 1]
eta_min: !!float 1e-7
pixel_criterion: cb
pixel_weight: 1.0
val_freq: !!float 5e3
manual_seed: 0
#### logger
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3

View File

@ -0,0 +1,71 @@
#### general settings
name: 001_EDVRwoTSA_scratch_lr4e-4_600k_REDS_LrCAR4S
use_tb_logger: true
model: video_base
distortion: sr
scale: 4
gpu_ids: [0,1,2,3,4,5,6,7]
#### datasets
datasets:
train:
name: REDS
mode: REDS
interval_list: [1]
random_reverse: false
border_mode: false
dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
cache_keys: ~
N_frames: 5
use_shuffle: true
n_workers: 3 # per GPU
batch_size: 32
GT_size: 256
LQ_size: 64
use_flip: true
use_rot: true
color: RGB
#### network structures
network_G:
which_model_G: EDVR
nf: 64
nframes: 5
groups: 8
front_RBs: 5
back_RBs: 10
predeblur: false
HR_in: false
w_TSA: false
#### path
path:
pretrain_model_G: ~
strict_load: true
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 4e-4
lr_scheme: CosineAnnealingLR_Restart
beta1: 0.9
beta2: 0.99
niter: 600000
warmup_iter: -1 # -1: no warm up
T_period: [150000, 150000, 150000, 150000]
restarts: [150000, 300000, 450000]
restart_weights: [1, 1, 1]
eta_min: !!float 1e-7
pixel_criterion: cb
pixel_weight: 1.0
val_freq: !!float 5e3
manual_seed: 0
#### logger
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3

View File

@ -0,0 +1,81 @@
#### general settings
name: 003_RRDB_ESRGANx4_DIV2K
use_tb_logger: true
model: srgan
distortion: sr
scale: 4
gpu_ids: [2]
#### datasets
datasets:
train:
name: DIV2K
mode: LQGT
dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
use_shuffle: true
n_workers: 6 # per GPU
batch_size: 16
GT_size: 128
use_flip: true
use_rot: true
color: RGB
val:
name: val_set14
mode: LQGT
dataroot_GT: ../datasets/val_set14/Set14
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
#### network structures
network_G:
which_model_G: RRDBNet
in_nc: 3
out_nc: 3
nf: 64
nb: 23
network_D:
which_model_D: discriminator_vgg_128
in_nc: 3
nf: 64
#### path
path:
pretrain_model_G: ../experiments/pretrained_models/RRDB_PSNR_x4.pth
strict_load: true
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 1e-4
weight_decay_G: 0
beta1_G: 0.9
beta2_G: 0.99
lr_D: !!float 1e-4
weight_decay_D: 0
beta1_D: 0.9
beta2_D: 0.99
lr_scheme: MultiStepLR
niter: 400000
warmup_iter: -1 # no warm up
lr_steps: [50000, 100000, 200000, 300000]
lr_gamma: 0.5
pixel_criterion: l1
pixel_weight: !!float 1e-2
feature_criterion: l1
feature_weight: 1
gan_type: ragan # gan | ragan
gan_weight: !!float 5e-3
D_update_ratio: 1
D_init_iters: 0
manual_seed: 10
val_freq: !!float 5e3
#### logger
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3

View File

@ -0,0 +1,85 @@
# Not exactly the same as SRGAN in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
# With 16 Residual blocks w/o BN
#### general settings
name: 002_SRGANx4_MSRResNetx4Ini_DIV2K
use_tb_logger: true
model: srgan
distortion: sr
scale: 4
gpu_ids: [1]
#### datasets
datasets:
train:
name: DIV2K
mode: LQGT
dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
use_shuffle: true
n_workers: 6 # per GPU
batch_size: 16
GT_size: 128
use_flip: true
use_rot: true
color: RGB
val:
name: val_set14
mode: LQGT
dataroot_GT: ../datasets/val_set14/Set14
dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
#### network structures
network_G:
which_model_G: MSRResNet
in_nc: 3
out_nc: 3
nf: 64
nb: 16
upscale: 4
network_D:
which_model_D: discriminator_vgg_128
in_nc: 3
nf: 64
#### path
path:
pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
strict_load: true
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 1e-4
weight_decay_G: 0
beta1_G: 0.9
beta2_G: 0.99
lr_D: !!float 1e-4
weight_decay_D: 0
beta1_D: 0.9
beta2_D: 0.99
lr_scheme: MultiStepLR
niter: 400000
warmup_iter: -1 # no warm up
lr_steps: [50000, 100000, 200000, 300000]
lr_gamma: 0.5
pixel_criterion: l1
pixel_weight: !!float 1e-2
feature_criterion: l1
feature_weight: 1
gan_type: gan # gan | ragan
gan_weight: !!float 5e-3
D_update_ratio: 1
D_init_iters: 0
manual_seed: 10
val_freq: !!float 5e3
#### logger
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3

View File

@ -0,0 +1,70 @@
# Not exactly the same as SRResNet in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
# With 16 Residual blocks w/o BN
#### general settings
name: 001_MSRResNetx4_scratch_DIV2K
use_tb_logger: true
model: sr
distortion: sr
scale: 4
gpu_ids: [0]
#### datasets
datasets:
train:
name: DIV2K
mode: LQGT
dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
use_shuffle: true
n_workers: 6 # per GPU
batch_size: 16
GT_size: 128
use_flip: true
use_rot: true
color: RGB
val:
name: val_set5
mode: LQGT
dataroot_GT: ../datasets/val_set5/Set5
dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
#### network structures
network_G:
which_model_G: MSRResNet
in_nc: 3
out_nc: 3
nf: 64
nb: 16
upscale: 4
#### path
path:
pretrain_model_G: ~
strict_load: true
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 2e-4
lr_scheme: CosineAnnealingLR_Restart
beta1: 0.9
beta2: 0.99
niter: 1000000
warmup_iter: -1 # no warm up
T_period: [250000, 250000, 250000, 250000]
restarts: [250000, 500000, 750000]
restart_weights: [1, 1, 1]
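# note: the four T_period segments sum to niter (4 x 250000 = 1000000), so the
# final cosine cycle ends at the last iteration; restarts fall at the boundaries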
eta_min: !!float 1e-7
pixel_criterion: l1
pixel_weight: 1.0
manual_seed: 10
val_freq: !!float 5e3
#### logger
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3

10
codes/run_scripts.sh Normal file
View File

@ -0,0 +1,10 @@
# single GPU training (image SR)
python train.py -opt options/train/train_SRResNet.yml
python train.py -opt options/train/train_SRGAN.yml
python train.py -opt options/train/train_ESRGAN.yml
# distributed training (video SR)
# 8 GPUs
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_woTSA_M.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_M.yml --launcher pytorch

View File

@ -0,0 +1,20 @@
function [im_h] = backprojection(im_h, im_l, maxIter)
[row_l, col_l,~] = size(im_l);
[row_h, col_h,~] = size(im_h);
p = fspecial('gaussian', 5, 1);
p = p.^2;
p = p./sum(p(:));
im_l = double(im_l);
im_h = double(im_h);
for ii = 1:maxIter
im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
im_diff = im_l - im_l_s;
im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
end
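The MATLAB routine above is classic iterative back-projection: downsample the current HR estimate, take the residual against the LR input, upsample it, and add it back through a smoothing kernel. A rough Python sketch of the same loop (illustrative only; OpenCV bicubic resizing and a plain Gaussian stand in for MATLAB's `imresize` and the squared, normalized Gaussian kernel used above):

```python
import cv2
import numpy as np

def backprojection(im_h, im_l, max_iter=20):
    # Push im_h toward consistency with the low-resolution observation im_l.
    h, w = im_h.shape[:2]
    im_h = im_h.astype(np.float64)
    im_l = im_l.astype(np.float64)
    for _ in range(max_iter):
        im_l_s = cv2.resize(im_h, (im_l.shape[1], im_l.shape[0]),
                            interpolation=cv2.INTER_CUBIC)
        im_diff = cv2.resize(im_l - im_l_s, (w, h),
                             interpolation=cv2.INTER_CUBIC)
        im_h += cv2.GaussianBlur(im_diff, (5, 5), 1)
    return im_h
```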

View File

@ -0,0 +1,22 @@
clear; close all; clc;
LR_folder = './LR'; % LR
preout_folder = './results'; % pre output
save_folder = './results_20bp';
filepaths = dir(fullfile(preout_folder, '*.png'));
max_iter = 20;
if ~ exist(save_folder, 'dir')
mkdir(save_folder);
end
for idx_im = 1:length(filepaths)
fprintf([num2str(idx_im) '\n']);
im_name = filepaths(idx_im).name;
im_LR = im2double(imread(fullfile(LR_folder, im_name)));
im_out = im2double(imread(fullfile(preout_folder, im_name)));
%tic
im_out = backprojection(im_out, im_LR, max_iter);
%toc
imwrite(im_out, fullfile(save_folder, im_name));
end

View File

@ -0,0 +1,25 @@
clear; close all; clc;
LR_folder = './LR'; % LR
preout_folder = './results'; % pre output
save_folder = './results_20if';
filepaths = dir(fullfile(preout_folder, '*.png'));
max_iter = 20;
if ~ exist(save_folder, 'dir')
mkdir(save_folder);
end
for idx_im = 1:length(filepaths)
fprintf([num2str(idx_im) '\n']);
im_name = filepaths(idx_im).name;
im_LR = im2double(imread(fullfile(LR_folder, im_name)));
im_out = im2double(imread(fullfile(preout_folder, im_name)));
J = imresize(im_LR,4,'bicubic');
%tic
for m = 1:max_iter
im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
end
%toc
imwrite(im_out, fullfile(save_folder, im_name));
end

View File

@ -0,0 +1,27 @@
import os.path as osp
import sys
import torch
try:
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
import models.archs.SRResNet_arch as SRResNet_arch
except ImportError:
pass
pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth')
crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
crt_net = crt_model.state_dict()
for k, v in crt_net.items():
if k in pretrained_net and 'upconv1' not in k:
crt_net[k] = pretrained_net[k]
print('replace ... ', k)
# x4 -> x3
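# (the x4 model's upconv1 feeds PixelShuffle(2), so it has 64 * 4 = 256 output
# channels, while an x3 model needs 64 * 9 = 576 for PixelShuffle(3); the slices
# below tile the pretrained filters to fill all 576 channels and halve their
# values as a rough warm start, not an exact x4-to-x3 mapping)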
crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2
crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2
crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2
crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2
crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2
crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2
torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')

105
codes/test.py Normal file
View File

@ -0,0 +1,105 @@
import os.path as osp
import logging
import time
import argparse
from collections import OrderedDict
import options.options as option
import utils.util as util
from data.util import bgr2ycbcr
from data import create_dataset, create_dataloader
from models import create_model
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
util.mkdirs(
(path for key, path in opt['path'].items()
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
screen=True, tofile=True)
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))
#### Create test dataset and dataloader
test_loaders = []
for phase, dataset_opt in sorted(opt['datasets'].items()):
test_set = create_dataset(dataset_opt)
test_loader = create_dataloader(test_set, dataset_opt)
logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
test_loaders.append(test_loader)
model = create_model(opt)
for test_loader in test_loaders:
test_set_name = test_loader.dataset.opt['name']
logger.info('\nTesting [{:s}]...'.format(test_set_name))
test_start_time = time.time()
dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
util.mkdir(dataset_dir)
test_results = OrderedDict()
test_results['psnr'] = []
test_results['ssim'] = []
test_results['psnr_y'] = []
test_results['ssim_y'] = []
for data in test_loader:
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
model.feed_data(data, need_GT=need_GT)
img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
img_name = osp.splitext(osp.basename(img_path))[0]
model.test()
visuals = model.get_current_visuals(need_GT=need_GT)
sr_img = util.tensor2img(visuals['rlt']) # uint8
# save images
suffix = opt['suffix']
if suffix:
save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
else:
save_img_path = osp.join(dataset_dir, img_name + '.png')
util.save_img(sr_img, save_img_path)
# calculate PSNR and SSIM
if need_GT:
gt_img = util.tensor2img(visuals['GT'])
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
psnr = util.calculate_psnr(sr_img, gt_img)
ssim = util.calculate_ssim(sr_img, gt_img)
test_results['psnr'].append(psnr)
test_results['ssim'].append(ssim)
if gt_img.shape[2] == 3: # RGB image
sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)
psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
test_results['psnr_y'].append(psnr_y)
test_results['ssim_y'].append(ssim_y)
logger.info(
'{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.
format(img_name, psnr, ssim, psnr_y, ssim_y))
else:
logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
else:
logger.info(img_name)
if need_GT: # metrics
# Average PSNR/SSIM results
ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
logger.info(
'----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format(
test_set_name, ave_psnr, ave_ssim))
if test_results['psnr_y'] and test_results['ssim_y']:
ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
logger.info(
'----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'.
format(ave_psnr_y, ave_ssim_y))

View File

@ -0,0 +1,208 @@
'''
Test Vid4 (SR) and REDS4 (SR-clean, SR-blur, deblur-clean, deblur-compression) datasets
'''
import os
import os.path as osp
import glob
import logging
import numpy as np
import cv2
import torch
import utils.util as util
import data.util as data_util
import models.archs.EDVR_arch as EDVR_arch
def main():
#################
# configurations
#################
device = torch.device('cuda')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
data_mode = 'Vid4' # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
# Vid4: SR
# REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
# blur (deblur-clean), blur_comp (deblur-compression).
stage = 1 # 1 or 2, use the two-stage strategy for the REDS dataset.
flip_test = False
############################################################################
#### model
if data_mode == 'Vid4':
if stage == 1:
model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
else:
raise ValueError('Vid4 does not support stage 2.')
elif data_mode == 'sharp_bicubic':
if stage == 1:
model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
else:
model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
elif data_mode == 'blur_bicubic':
if stage == 1:
model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
else:
model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
elif data_mode == 'blur':
if stage == 1:
model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
else:
model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
elif data_mode == 'blur_comp':
if stage == 1:
model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
else:
model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
else:
raise NotImplementedError
if data_mode == 'Vid4':
N_in = 7 # use N_in images to restore one HR image
else:
N_in = 5
predeblur, HR_in = False, False
back_RBs = 40
if data_mode == 'blur_bicubic':
predeblur = True
if data_mode == 'blur' or data_mode == 'blur_comp':
predeblur, HR_in = True, True
if stage == 2:
HR_in = True
back_RBs = 20
model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
#### dataset
if data_mode == 'Vid4':
test_dataset_folder = '../datasets/Vid4/BIx4'
GT_dataset_folder = '../datasets/Vid4/GT'
else:
if stage == 1:
test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
else:
test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
print('You should modify the test_dataset_folder path for stage 2')
GT_dataset_folder = '../datasets/REDS4/GT'
#### evaluation
crop_border = 0
border_frame = N_in // 2 # border frames when evaluating
# temporal padding mode
if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
padding = 'new_info'
else:
padding = 'replicate'
save_imgs = True
save_folder = '../results/{}'.format(data_mode)
util.mkdirs(save_folder)
util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
logger = logging.getLogger('base')
#### log info
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
logger.info('Padding mode: {}'.format(padding))
logger.info('Model path: {}'.format(model_path))
logger.info('Save images: {}'.format(save_imgs))
logger.info('Flip test: {}'.format(flip_test))
#### set up the models
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)
avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
subfolder_name_l = []
subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
# for each subfolder
for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
subfolder_name = osp.basename(subfolder)
subfolder_name_l.append(subfolder_name)
save_subfolder = osp.join(save_folder, subfolder_name)
img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
max_idx = len(img_path_l)
if save_imgs:
util.mkdirs(save_subfolder)
#### read LQ and GT images
imgs_LQ = data_util.read_img_seq(subfolder)
img_GT_l = []
for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
img_GT_l.append(data_util.read_img(None, img_GT_path))
avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0
# process each image
for img_idx, img_path in enumerate(img_path_l):
img_name = osp.splitext(osp.basename(img_path))[0]
select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
if flip_test:
output = util.flipx4_forward(model, imgs_in)
else:
output = util.single_forward(model, imgs_in)
output = util.tensor2img(output.squeeze(0))
if save_imgs:
cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)
# calculate PSNR
output = output / 255.
GT = np.copy(img_GT_l[img_idx])
# For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
if data_mode == 'Vid4': # bgr2y, [0, 1]
GT = data_util.bgr2ycbcr(GT, only_y=True)
output = data_util.bgr2ycbcr(output, only_y=True)
output, GT = util.crop_border([output, GT], crop_border)
crt_psnr = util.calculate_psnr(output * 255, GT * 255)
logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))
if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
avg_psnr_center += crt_psnr
N_center += 1
else: # border frames
avg_psnr_border += crt_psnr
N_border += 1
avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
avg_psnr_center = avg_psnr_center / N_center
avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
avg_psnr_l.append(avg_psnr)
avg_psnr_center_l.append(avg_psnr_center)
avg_psnr_border_l.append(avg_psnr_border)
logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
'Center PSNR: {:.6f} dB for {} frames; '
'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
(N_center + N_border),
avg_psnr_center, N_center,
avg_psnr_border, N_border))
logger.info('################ Tidy Outputs ################')
for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
avg_psnr_center_l, avg_psnr_border_l):
logger.info('Folder {} - Average PSNR: {:.6f} dB. '
'Center PSNR: {:.6f} dB. '
'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
psnr_border))
logger.info('################ Final Results ################')
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
logger.info('Padding mode: {}'.format(padding))
logger.info('Model path: {}'.format(model_path))
logger.info('Save images: {}'.format(save_imgs))
logger.info('Flip test: {}'.format(flip_test))
logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
sum(avg_psnr_center_l) / len(avg_psnr_center_l),
sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
if __name__ == '__main__':
main()

View File

@ -0,0 +1,264 @@
"""
DUF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets
write to txt log file
"""
import os
import os.path as osp
import glob
import logging
import numpy as np
import cv2
import torch
import utils.util as util
import data.util as data_util
import models.archs.DUF_arch as DUF_arch
def main():
#################
# configurations
#################
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS)
# Possible combinations: (2, 16), (3, 16), (4, 16), (4, 28), (4, 52)
scale = 4
layer = 52
assert (scale, layer) in [(2, 16), (3, 16), (4, 16), (4, 28),
(4, 52)], 'Unrecognized (scale, layer) combination'
# model
N_in = 7
model_path = '../experiments/pretrained_models/DUF_x{}_{}L_official.pth'.format(scale, layer)
adapt_official = True if 'official' in model_path else False
DUF_downsampling = True # True | False
if layer == 16:
model = DUF_arch.DUF_16L(scale=scale, adapt_official=adapt_official)
elif layer == 28:
model = DUF_arch.DUF_28L(scale=scale, adapt_official=adapt_official)
elif layer == 52:
model = DUF_arch.DUF_52L(scale=scale, adapt_official=adapt_official)
#### dataset
if data_mode == 'Vid4':
test_dataset_folder = '../datasets/Vid4/BIx4/*'
else: # sharp_bicubic (REDS)
test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)
#### evaluation
crop_border = 8
border_frame = N_in // 2 # border frames when evaluating
# temporal padding mode
padding = 'new_info' # different from the official testing code, which pads with zeros.
save_imgs = True
############################################################################
device = torch.device('cuda')
save_folder = '../results/{}'.format(data_mode)
util.mkdirs(save_folder)
util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
logger = logging.getLogger('base')
#### log info
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
logger.info('Padding mode: {}'.format(padding))
logger.info('Model path: {}'.format(model_path))
logger.info('Save images: {}'.format(save_imgs))
def read_image(img_path):
'''read one image from img_path
Return img: HWC, BGR, [0,1], numpy
'''
img_GT = cv2.imread(img_path)
img = img_GT.astype(np.float32) / 255.
return img
def read_seq_imgs(img_seq_path):
'''read a sequence of images'''
img_path_l = sorted(glob.glob(img_seq_path + '/*'))
img_l = [read_image(v) for v in img_path_l]
# stack to TCHW, RGB, [0,1], torch
imgs = np.stack(img_l, axis=0)
imgs = imgs[:, :, :, [2, 1, 0]]
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
return imgs
def index_generation(crt_i, max_n, N, padding='reflection'):
'''
padding: replicate | reflection | new_info | circle
'''
max_n = max_n - 1
n_pad = N // 2
return_l = []
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
if i < 0:
if padding == 'replicate':
add_idx = 0
elif padding == 'reflection':
add_idx = -i
elif padding == 'new_info':
add_idx = (crt_i + n_pad) + (-i)
elif padding == 'circle':
add_idx = N + i
else:
raise ValueError('Wrong padding mode')
elif i > max_n:
if padding == 'replicate':
add_idx = max_n
elif padding == 'reflection':
add_idx = max_n * 2 - i
elif padding == 'new_info':
add_idx = (crt_i - n_pad) - (i - max_n)
elif padding == 'circle':
add_idx = i - N
else:
raise ValueError('Wrong padding mode')
else:
add_idx = i
return_l.append(add_idx)
return return_l
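# Worked example (illustrative): N = 5, crt_i = 0, max_n = 20 gives
# 'replicate' -> [0, 0, 0, 1, 2], 'reflection' -> [2, 1, 0, 1, 2],
# 'new_info' -> [4, 3, 0, 1, 2], 'circle' -> [3, 4, 0, 1, 2]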
def single_forward(model, imgs_in):
with torch.no_grad():
model_output = model(imgs_in)
if isinstance(model_output, list) or isinstance(model_output, tuple):
output = model_output[0]
else:
output = model_output
return output
sub_folder_l = sorted(glob.glob(test_dataset_folder))
#### set up the models
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)
avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
sub_folder_name_l = []
# for each sub-folder
for sub_folder in sub_folder_l:
sub_folder_name = sub_folder.split('/')[-1]
sub_folder_name_l.append(sub_folder_name)
save_sub_folder = osp.join(save_folder, sub_folder_name)
img_path_l = sorted(glob.glob(sub_folder + '/*'))
max_idx = len(img_path_l)
if save_imgs:
util.mkdirs(save_sub_folder)
#### read LR images
imgs = read_seq_imgs(sub_folder)
#### read GT images
img_GT_l = []
if data_mode == 'Vid4':
sub_folder_GT = osp.join(sub_folder.replace('/BIx4/', '/GT/'), '*')
else:
sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
for img_GT_path in sorted(glob.glob(sub_folder_GT)):
img_GT_l.append(read_image(img_GT_path))
# When using the downsampling from the official DUF code, we downsample the HR images
if DUF_downsampling:
sub_folder = sub_folder_GT
img_path_l = sorted(glob.glob(sub_folder))
max_idx = len(img_path_l)
imgs = read_seq_imgs(sub_folder[:-2])
avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
cal_n_border, cal_n_center = 0, 0
# process each image
for img_idx, img_path in enumerate(img_path_l):
c_idx = int(osp.splitext(osp.basename(img_path))[0])
select_idx = index_generation(c_idx, max_idx, N_in, padding=padding)
# get input images
imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
# Downsample the HR images
H, W = imgs_in.size(3), imgs_in.size(4)
if DUF_downsampling:
imgs_in = util.DUF_downsample(imgs_in, scale=scale)
output = single_forward(model, imgs_in)
# Crop to the original shape
if scale == 3:
pad_h = 3 - (H % 3)
pad_w = 3 - (W % 3)
if pad_h > 0:
output = output[:, :, :-pad_h, :]
if pad_w > 0:
output = output[:, :, :, :-pad_w]
output_f = output.data.float().cpu().squeeze(0)
output = util.tensor2img(output_f)
# save imgs
if save_imgs:
cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output)
#### calculate PSNR
output = output / 255.
GT = np.copy(img_GT_l[img_idx])
# For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
if data_mode == 'Vid4': # bgr2y, [0, 1]
GT = data_util.bgr2ycbcr(GT)
output = data_util.bgr2ycbcr(output)
if crop_border == 0:
cropped_output = output
cropped_GT = GT
else:
cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr))
if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
avg_psnr_center += crt_psnr
cal_n_center += 1
else: # border frames
avg_psnr_border += crt_psnr
cal_n_border += 1
avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border)
avg_psnr_center = avg_psnr_center / cal_n_center
if cal_n_border == 0:
avg_psnr_border = 0
else:
avg_psnr_border = avg_psnr_border / cal_n_border
logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
'Center PSNR: {:.6f} dB for {} frames; '
'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr,
(cal_n_center + cal_n_border),
avg_psnr_center, cal_n_center,
avg_psnr_border, cal_n_border))
avg_psnr_l.append(avg_psnr)
avg_psnr_center_l.append(avg_psnr_center)
avg_psnr_border_l.append(avg_psnr_border)
logger.info('################ Tidy Outputs ################')
for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
avg_psnr_center_l, avg_psnr_border_l):
logger.info('Folder {} - Average PSNR: {:.6f} dB. '
'Center PSNR: {:.6f} dB. '
'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border))
logger.info('################ Final Results ################')
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
logger.info('Padding mode: {}'.format(padding))
logger.info('Model path: {}'.format(model_path))
logger.info('Save images: {}'.format(save_imgs))
logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
sum(avg_psnr_center_l) / len(avg_psnr_center_l),
sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
if __name__ == '__main__':
main()

View File

@ -0,0 +1,230 @@
"""
TOF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets
write to txt log file
"""
import os
import os.path as osp
import glob
import logging
import numpy as np
import cv2
import torch
import utils.util as util
import data.util as data_util
import models.archs.TOF_arch as TOF_arch
def main():
#################
# configurations
#################
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS)
# model
N_in = 7
model_path = '../experiments/pretrained_models/TOF_official.pth'
adapt_official = True if 'official' in model_path else False
model = TOF_arch.TOFlow(adapt_official=adapt_official)
#### dataset
if data_mode == 'Vid4':
test_dataset_folder = '../datasets/Vid4/BIx4up_direct/*'
else:
test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)
#### evaluation
crop_border = 0
border_frame = N_in // 2 # border frames when evaluating
# temporal padding mode
padding = 'new_info' # different from the official setting
save_imgs = True
############################################################################
device = torch.device('cuda')
save_folder = '../results/{}'.format(data_mode)
util.mkdirs(save_folder)
util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
logger = logging.getLogger('base')
#### log info
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
logger.info('Padding mode: {}'.format(padding))
logger.info('Model path: {}'.format(model_path))
logger.info('Save images: {}'.format(save_imgs))
def read_image(img_path):
'''read one image from img_path
Return img: HWC, BGR, [0,1], numpy
'''
img_GT = cv2.imread(img_path)
img = img_GT.astype(np.float32) / 255.
return img
def read_seq_imgs(img_seq_path):
'''read a sequence of images'''
img_path_l = sorted(glob.glob(img_seq_path + '/*'))
img_l = [read_image(v) for v in img_path_l]
# stack to TCHW, RGB, [0,1], torch
imgs = np.stack(img_l, axis=0)
imgs = imgs[:, :, :, [2, 1, 0]]
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
return imgs
def index_generation(crt_i, max_n, N, padding='reflection'):
'''
padding: replicate | reflection | new_info | circle
'''
max_n = max_n - 1
n_pad = N // 2
return_l = []
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
if i < 0:
if padding == 'replicate':
add_idx = 0
elif padding == 'reflection':
add_idx = -i
elif padding == 'new_info':
add_idx = (crt_i + n_pad) + (-i)
elif padding == 'circle':
add_idx = N + i
else:
raise ValueError('Wrong padding mode')
elif i > max_n:
if padding == 'replicate':
add_idx = max_n
elif padding == 'reflection':
add_idx = max_n * 2 - i
elif padding == 'new_info':
add_idx = (crt_i - n_pad) - (i - max_n)
elif padding == 'circle':
add_idx = i - N
else:
raise ValueError('Wrong padding mode')
else:
add_idx = i
return_l.append(add_idx)
return return_l
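For clarity, a small traced example of the four padding modes (the 20-frame clip length is an arbitrary assumption; the index lists were worked out by hand from the function above):

```
# Neighbor indices for a 7-frame window (N=7) at the first frame (crt_i=0)
# of a 20-frame clip, for each temporal padding mode:
for mode in ['replicate', 'reflection', 'new_info', 'circle']:
    print(mode, index_generation(0, 20, 7, padding=mode))
# replicate  -> [0, 0, 0, 0, 1, 2, 3]
# reflection -> [3, 2, 1, 0, 1, 2, 3]
# new_info   -> [6, 5, 4, 0, 1, 2, 3]
# circle     -> [4, 5, 6, 0, 1, 2, 3]
```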
def single_forward(model, imgs_in):
with torch.no_grad():
model_output = model(imgs_in)
            if isinstance(model_output, (list, tuple)):
output = model_output[0]
else:
output = model_output
return output
sub_folder_l = sorted(glob.glob(test_dataset_folder))
#### set up the models
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)
avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
sub_folder_name_l = []
# for each sub-folder
for sub_folder in sub_folder_l:
sub_folder_name = sub_folder.split('/')[-1]
sub_folder_name_l.append(sub_folder_name)
save_sub_folder = osp.join(save_folder, sub_folder_name)
img_path_l = sorted(glob.glob(sub_folder + '/*'))
max_idx = len(img_path_l)
if save_imgs:
util.mkdirs(save_sub_folder)
#### read LR images
imgs = read_seq_imgs(sub_folder)
#### read GT images
img_GT_l = []
if data_mode == 'Vid4':
sub_folder_GT = osp.join(sub_folder.replace('/BIx4up_direct/', '/GT/'), '*')
else:
sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
for img_GT_path in sorted(glob.glob(sub_folder_GT)):
img_GT_l.append(read_image(img_GT_path))
avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
cal_n_border, cal_n_center = 0, 0
# process each image
for img_idx, img_path in enumerate(img_path_l):
c_idx = int(osp.splitext(osp.basename(img_path))[0])
select_idx = index_generation(c_idx, max_idx, N_in, padding=padding)
# get input images
imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
output = single_forward(model, imgs_in)
output_f = output.data.float().cpu().squeeze(0)
output = util.tensor2img(output_f)
# save imgs
if save_imgs:
cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output)
#### calculate PSNR
output = output / 255.
GT = np.copy(img_GT_l[img_idx])
# For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
if data_mode == 'Vid4': # bgr2y, [0, 1]
GT = data_util.bgr2ycbcr(GT)
output = data_util.bgr2ycbcr(output)
if crop_border == 0:
cropped_output = output
cropped_GT = GT
else:
cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
            logger.info('{:3d} - {:08d}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr))
if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
avg_psnr_center += crt_psnr
cal_n_center += 1
else: # border frames
avg_psnr_border += crt_psnr
cal_n_border += 1
avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border)
avg_psnr_center = avg_psnr_center / cal_n_center
if cal_n_border == 0:
avg_psnr_border = 0
else:
avg_psnr_border = avg_psnr_border / cal_n_border
logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
'Center PSNR: {:.6f} dB for {} frames; '
'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr,
(cal_n_center + cal_n_border),
avg_psnr_center, cal_n_center,
avg_psnr_border, cal_n_border))
avg_psnr_l.append(avg_psnr)
avg_psnr_center_l.append(avg_psnr_center)
avg_psnr_border_l.append(avg_psnr_border)
logger.info('################ Tidy Outputs ################')
for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
avg_psnr_center_l, avg_psnr_border_l):
logger.info('Folder {} - Average PSNR: {:.6f} dB. '
'Center PSNR: {:.6f} dB. '
'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border))
logger.info('################ Final Results ################')
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
logger.info('Padding mode: {}'.format(padding))
logger.info('Model path: {}'.format(model_path))
logger.info('Save images: {}'.format(save_imgs))
logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
sum(avg_psnr_center_l) / len(avg_psnr_center_l),
sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
if __name__ == '__main__':
main()

310
codes/train.py Normal file
View File

@ -0,0 +1,310 @@
import os
import math
import argparse
import random
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from data.data_sampler import DistIterSampler
import options.options as option
from utils import util
from data import create_dataloader, create_dataset
from models import create_model
def init_dist(backend='nccl', **kwargs):
"""initialization for distributed training"""
if mp.get_start_method(allow_none=True) != 'spawn':
mp.set_start_method('spawn')
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
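For context, a sketch of how this entry point is typically launched; the command below is an assumption based on the `--launcher pytorch` option parsed in `main()`, not something stated in this commit. `init_dist` reads the `RANK` environment variable, which `torch.distributed.launch` sets for each worker process.

```
# Hypothetical launch command (the option file path is a placeholder):
#   python -m torch.distributed.launch --nproc_per_node=8 train.py \
#       -opt path/to/train.yml --launcher pytorch
# A single-GPU run simply omits --launcher (it defaults to 'none').
```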
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)
#### distributed training settings
if args.launcher == 'none': # disabled distributed training
opt['dist'] = False
rank = -1
print('Disabled distributed training.')
else:
opt['dist'] = True
init_dist()
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
#### loading resume state if exists
if opt['path'].get('resume_state', None):
# distributed resuming: all load into default GPU
device_id = torch.cuda.current_device()
resume_state = torch.load(opt['path']['resume_state'],
map_location=lambda storage, loc: storage.cuda(device_id))
option.check_resume(opt, resume_state['iter']) # check resume options
else:
resume_state = None
#### mkdir and loggers
if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
if resume_state is None:
util.mkdir_and_rename(
opt['path']['experiments_root']) # rename experiment folder if exists
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
and 'pretrain_model' not in key and 'resume' not in key))
# config loggers. Before it, the log will not work
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
screen=True, tofile=True)
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
version = float(torch.__version__[0:3])
if version >= 1.1: # PyTorch 1.1
from torch.utils.tensorboard import SummaryWriter
else:
logger.info(
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
from tensorboardX import SummaryWriter
tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
else:
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
logger = logging.getLogger('base')
# convert to NoneDict, which returns None for missing keys
opt = option.dict_to_nonedict(opt)
#### random seed
seed = opt['train']['manual_seed']
if seed is None:
seed = random.randint(1, 10000)
if rank <= 0:
logger.info('Random seed: {}'.format(seed))
util.set_random_seed(seed)
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
#### create train and val dataloader
dataset_ratio = 200 # enlarge the size of each epoch
for phase, dataset_opt in opt['datasets'].items():
if phase == 'train':
train_set = create_dataset(dataset_opt)
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
total_iters = int(opt['train']['niter'])
total_epochs = int(math.ceil(total_iters / train_size))
if opt['dist']:
train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
else:
train_sampler = None
train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
if rank <= 0:
logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
len(train_set), train_size))
logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
total_epochs, total_iters))
elif phase == 'val':
val_set = create_dataset(dataset_opt)
val_loader = create_dataloader(val_set, dataset_opt, opt, None)
if rank <= 0:
logger.info('Number of val images in [{:s}]: {:d}'.format(
dataset_opt['name'], len(val_set)))
else:
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
assert train_loader is not None
#### create model
model = create_model(opt)
#### resume training
if resume_state:
logger.info('Resuming training from epoch: {}, iter: {}.'.format(
resume_state['epoch'], resume_state['iter']))
start_epoch = resume_state['epoch']
current_step = resume_state['iter']
model.resume_training(resume_state) # handle optimizers and schedulers
else:
current_step = 0
start_epoch = 0
#### training
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
for epoch in range(start_epoch, total_epochs + 1):
if opt['dist']:
train_sampler.set_epoch(epoch)
for _, train_data in enumerate(train_loader):
current_step += 1
if current_step > total_iters:
break
#### update learning rate
model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
#### training
model.feed_data(train_data)
model.optimize_parameters(current_step)
#### log
if current_step % opt['logger']['print_freq'] == 0:
logs = model.get_current_log()
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
for v in model.get_current_learning_rate():
message += '{:.3e},'.format(v)
message += ')] '
for k, v in logs.items():
message += '{:s}: {:.4e} '.format(k, v)
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
if rank <= 0:
tb_logger.add_scalar(k, v, current_step)
if rank <= 0:
logger.info(message)
#### validation
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
if opt['model'] in ['sr', 'srgan'] and rank <= 0: # image restoration validation
# does not support multi-GPU validation
pbar = util.ProgressBar(len(val_loader))
avg_psnr = 0.
idx = 0
for val_data in val_loader:
idx += 1
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
img_dir = os.path.join(opt['path']['val_images'], img_name)
util.mkdir(img_dir)
model.feed_data(val_data)
model.test()
visuals = model.get_current_visuals()
sr_img = util.tensor2img(visuals['rlt']) # uint8
gt_img = util.tensor2img(visuals['GT']) # uint8
# Save SR images for reference
save_img_path = os.path.join(img_dir,
'{:s}_{:d}.png'.format(img_name, current_step))
util.save_img(sr_img, save_img_path)
# calculate PSNR
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
pbar.update('Test {}'.format(img_name))
avg_psnr = avg_psnr / idx
# log
logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
tb_logger.add_scalar('psnr', avg_psnr, current_step)
else: # video restoration validation
if opt['dist']:
# multi-GPU testing
psnr_rlt = {} # with border and center frames
if rank == 0:
pbar = util.ProgressBar(len(val_set))
for idx in range(rank, len(val_set), world_size):
val_data = val_set[idx]
val_data['LQs'].unsqueeze_(0)
val_data['GT'].unsqueeze_(0)
folder = val_data['folder']
idx_d, max_idx = val_data['idx'].split('/')
idx_d, max_idx = int(idx_d), int(max_idx)
if psnr_rlt.get(folder, None) is None:
psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32,
device='cuda')
# tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
model.feed_data(val_data)
model.test()
visuals = model.get_current_visuals()
rlt_img = util.tensor2img(visuals['rlt']) # uint8
gt_img = util.tensor2img(visuals['GT']) # uint8
# calculate PSNR
psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img)
if rank == 0:
for _ in range(world_size):
pbar.update('Test {} - {}/{}'.format(folder, idx_d, max_idx))
                        # collect results from all ranks onto rank 0
                        for v in psnr_rlt.values():
                            dist.reduce(v, 0)
dist.barrier()
if rank == 0:
psnr_rlt_avg = {}
psnr_total_avg = 0.
for k, v in psnr_rlt.items():
psnr_rlt_avg[k] = torch.mean(v).cpu().item()
psnr_total_avg += psnr_rlt_avg[k]
psnr_total_avg /= len(psnr_rlt)
log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
for k, v in psnr_rlt_avg.items():
log_s += ' {}: {:.4e}'.format(k, v)
logger.info(log_s)
if opt['use_tb_logger'] and 'debug' not in opt['name']:
tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
for k, v in psnr_rlt_avg.items():
tb_logger.add_scalar(k, v, current_step)
else:
pbar = util.ProgressBar(len(val_loader))
psnr_rlt = {} # with border and center frames
psnr_rlt_avg = {}
psnr_total_avg = 0.
for val_data in val_loader:
folder = val_data['folder'][0]
idx_d = val_data['idx'].item()
# border = val_data['border'].item()
if psnr_rlt.get(folder, None) is None:
psnr_rlt[folder] = []
model.feed_data(val_data)
model.test()
visuals = model.get_current_visuals()
rlt_img = util.tensor2img(visuals['rlt']) # uint8
gt_img = util.tensor2img(visuals['GT']) # uint8
# calculate PSNR
psnr = util.calculate_psnr(rlt_img, gt_img)
psnr_rlt[folder].append(psnr)
pbar.update('Test {} - {}'.format(folder, idx_d))
for k, v in psnr_rlt.items():
psnr_rlt_avg[k] = sum(v) / len(v)
psnr_total_avg += psnr_rlt_avg[k]
psnr_total_avg /= len(psnr_rlt)
log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
for k, v in psnr_rlt_avg.items():
log_s += ' {}: {:.4e}'.format(k, v)
logger.info(log_s)
if opt['use_tb_logger'] and 'debug' not in opt['name']:
tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
for k, v in psnr_rlt_avg.items():
tb_logger.add_scalar(k, v, current_step)
#### save models and training states
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
if rank <= 0:
logger.info('Saving models and training states.')
model.save(current_step)
model.save_training_state(epoch, current_step)
    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
        # tb_logger is only defined when the TensorBoard logger is enabled
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            tb_logger.close()
if __name__ == '__main__':
main()

0
codes/utils/__init__.py Normal file
View File

327
codes/utils/util.py Normal file
View File

@ -0,0 +1,327 @@
import os
import sys
import time
import math
import torch.nn.functional as F
from datetime import datetime
import random
import logging
from collections import OrderedDict
import numpy as np
import cv2
import torch
from torchvision.utils import make_grid
from shutil import get_terminal_size
import yaml
try:
from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
from yaml import Loader, Dumper
def OrderedYaml():
'''yaml orderedDict support'''
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
def dict_representer(dumper, data):
return dumper.represent_dict(data.items())
def dict_constructor(loader, node):
return OrderedDict(loader.construct_pairs(node))
Dumper.add_representer(OrderedDict, dict_representer)
Loader.add_constructor(_mapping_tag, dict_constructor)
return Loader, Dumper
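A minimal usage sketch (the inline YAML string is illustrative): the returned Loader builds OrderedDicts, so option files keep their key order.

```
Loader, Dumper = OrderedYaml()
cfg = yaml.load('b: 1\na: 2\n', Loader=Loader)
print(type(cfg).__name__, list(cfg))  # OrderedDict ['b', 'a']
```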
####################
# miscellaneous
####################
def get_timestamp():
return datetime.now().strftime('%y%m%d-%H%M%S')
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def mkdirs(paths):
if isinstance(paths, str):
mkdir(paths)
else:
for path in paths:
mkdir(path)
def mkdir_and_rename(path):
if os.path.exists(path):
new_name = path + '_archived_' + get_timestamp()
print('Path already exists. Rename it to [{:s}]'.format(new_name))
logger = logging.getLogger('base')
logger.info('Path already exists. Rename it to [{:s}]'.format(new_name))
os.rename(path, new_name)
os.makedirs(path)
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
'''set up logger'''
lg = logging.getLogger(logger_name)
formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
datefmt='%y-%m-%d %H:%M:%S')
lg.setLevel(level)
if tofile:
log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
fh = logging.FileHandler(log_file, mode='w')
fh.setFormatter(formatter)
lg.addHandler(fh)
if screen:
sh = logging.StreamHandler()
sh.setFormatter(formatter)
lg.addHandler(sh)
####################
# image convert
####################
def crop_border(img_list, crop_border):
"""Crop borders of images
Args:
img_list (list [Numpy]): HWC
        crop_border (int): crop border for each end of height and width
Returns:
(list [Numpy]): cropped image list
"""
if crop_border == 0:
return img_list
else:
return [v[crop_border:-crop_border, crop_border:-crop_border] for v in img_list]
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
'''
Converts a torch Tensor into an image Numpy array
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
'''
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
n_dim = tensor.dim()
if n_dim == 4:
n_img = len(tensor)
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
elif n_dim == 3:
img_np = tensor.numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
elif n_dim == 2:
img_np = tensor.numpy()
else:
        raise TypeError(
            'Only 4D, 3D, and 2D tensors are supported, but got a tensor with {:d} dimensions'.format(n_dim))
if out_type == np.uint8:
img_np = (img_np * 255.0).round()
        # Important. Unlike MATLAB, numpy.uint8() will not round by default; it truncates.
return img_np.astype(out_type)
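A quick shape/dtype sketch with random input (illustrative only):

```
t = torch.rand(3, 8, 8)          # CHW, RGB, [0, 1]
img = tensor2img(t)
print(img.shape, img.dtype)      # (8, 8, 3) uint8 -- HWC, BGR
```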
def save_img(img, img_path, mode='RGB'):
cv2.imwrite(img_path, img)
def DUF_downsample(x, scale=4):
"""Downsamping with Gaussian kernel used in the DUF official code
Args:
x (Tensor, [B, T, C, H, W]): frames to be downsampled.
scale (int): downsampling factor: 2 | 3 | 4.
"""
assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale)
def gkern(kernlen=13, nsig=1.6):
import scipy.ndimage.filters as fi
inp = np.zeros((kernlen, kernlen))
# set element at the middle to one, a dirac delta
inp[kernlen // 2, kernlen // 2] = 1
# gaussian-smooth the dirac, resulting in a gaussian filter mask
return fi.gaussian_filter(inp, nsig)
B, T, C, H, W = x.size()
x = x.view(-1, 1, H, W)
pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 # 6 is the pad of the gaussian filter
r_h, r_w = 0, 0
if scale == 3:
r_h = 3 - (H % 3)
r_w = 3 - (W % 3)
x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], 'reflect')
gaussian_filter = torch.from_numpy(gkern(13, 0.4 * scale)).type_as(x).unsqueeze(0).unsqueeze(0)
x = F.conv2d(x, gaussian_filter, stride=scale)
x = x[:, :, 2:-2, 2:-2]
x = x.view(B, T, C, x.size(2), x.size(3))
return x
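A shape sanity check with random frames (illustrative): x4 downsampling maps 64x64 inputs to 16x16, with the reflection padding absorbing the 13-tap Gaussian pre-filter.

```
x = torch.rand(1, 5, 3, 64, 64)  # B, T, C, H, W
y = DUF_downsample(x, scale=4)
print(y.shape)                   # torch.Size([1, 5, 3, 16, 16])
```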
def single_forward(model, inp):
"""PyTorch model forward (single test), it is just a simple warpper
Args:
model (PyTorch model)
inp (Tensor): inputs defined by the model
Returns:
output (Tensor): outputs of the model. float, in CPU
"""
with torch.no_grad():
model_output = model(inp)
        if isinstance(model_output, (list, tuple)):
output = model_output[0]
else:
output = model_output
output = output.data.float().cpu()
return output
def flipx4_forward(model, inp):
"""Flip testing with X4 self ensemble, i.e., normal, flip H, flip W, flip H and W
Args:
model (PyTorch model)
inp (Tensor): inputs defined by the model
Returns:
output (Tensor): outputs of the model. float, in CPU
"""
# normal
output_f = single_forward(model, inp)
# flip W
output = single_forward(model, torch.flip(inp, (-1, )))
output_f = output_f + torch.flip(output, (-1, ))
# flip H
output = single_forward(model, torch.flip(inp, (-2, )))
output_f = output_f + torch.flip(output, (-2, ))
# flip both H and W
output = single_forward(model, torch.flip(inp, (-2, -1)))
output_f = output_f + torch.flip(output, (-2, -1))
return output_f / 4
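One sanity check worth noting (a toy sketch; `_Identity` is a stand-in, not a real SR model): because each output is flipped back before averaging, a flip-equivariant model such as the identity passes through `flipx4_forward` unchanged.

```
class _Identity(torch.nn.Module):   # toy stand-in for an SR model
    def forward(self, x):
        return x

x = torch.rand(1, 3, 8, 8)
assert torch.allclose(flipx4_forward(_Identity(), x), x)
```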
####################
# metric
####################
def calculate_psnr(img1, img2):
# img1 and img2 have range [0, 255]
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
mse = np.mean((img1 - img2)**2)
if mse == 0:
return float('inf')
return 20 * math.log10(255.0 / math.sqrt(mse))
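A worked example (illustrative): a uniform error of one gray level out of 255 gives MSE = 1, so PSNR = 20 * log10(255) ≈ 48.13 dB.

```
a = np.zeros((4, 4))
b = np.ones((4, 4))
print(calculate_psnr(a, b))  # -> 48.1308... dB
```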
def ssim(img1, img2):
C1 = (0.01 * 255)**2
C2 = (0.03 * 255)**2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
def calculate_ssim(img1, img2):
'''calculate SSIM
the same outputs as MATLAB's
img1, img2: [0, 255]
'''
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
if img1.ndim == 2:
return ssim(img1, img2)
elif img1.ndim == 3:
if img1.shape[2] == 3:
ssims = []
            for i in range(3):
                # per-channel SSIM, averaged below
                ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
return np.array(ssims).mean()
elif img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
else:
raise ValueError('Wrong input image dimensions.')
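A quick self-consistency check (random image, illustrative): the SSIM of an image with itself is 1.

```
img = (np.random.rand(32, 32) * 255).astype(np.float64)
assert abs(calculate_ssim(img, img) - 1.0) < 1e-9
```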
class ProgressBar(object):
'''A progress bar which can print the progress
modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
'''
def __init__(self, task_num=0, bar_width=50, start=True):
self.task_num = task_num
max_bar_width = self._get_max_bar_width()
self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
self.completed = 0
if start:
self.start()
def _get_max_bar_width(self):
terminal_width, _ = get_terminal_size()
max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
if max_bar_width < 10:
            print('terminal width is too small ({}), please consider widening the terminal '
                  'for better progress bar visualization'.format(terminal_width))
max_bar_width = 10
return max_bar_width
def start(self):
if self.task_num > 0:
sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
' ' * self.bar_width, self.task_num, 'Start...'))
else:
sys.stdout.write('completed: 0, elapsed: 0s')
sys.stdout.flush()
self.start_time = time.time()
def update(self, msg='In progress...'):
self.completed += 1
elapsed = time.time() - self.start_time
fps = self.completed / elapsed
if self.task_num > 0:
percentage = self.completed / float(self.task_num)
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
mark_width = int(self.bar_width * percentage)
bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
sys.stdout.write('\033[2F') # cursor up 2 lines
sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
else:
sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
self.completed, int(elapsed + 0.5), fps))
sys.stdout.flush()
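Typical usage (illustrative; the sleep stands in for real work):

```
pbar = ProgressBar(task_num=100)
for i in range(100):
    time.sleep(0.01)
    pbar.update('item {}'.format(i))
```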