Clean up video stuff

James Betker 2020-05-25 19:20:49 -06:00
parent 8464cae168
commit 4e44b8a1aa
14 changed files with 0 additions and 3817 deletions

@@ -1,167 +0,0 @@
'''
Vimeo90K dataset
supports reading images from LMDB, image folders, and memcached
'''
import os.path as osp
import random
import pickle
import logging
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
try:
import mc # import memcached
except ImportError:
pass
logger = logging.getLogger('base')
class Vimeo90KDataset(data.Dataset):
'''
Reading the training Vimeo90K dataset
key example: 00001_0001 (_1, ..., _7)
GT (Ground-Truth): 4th frame;
LQ (Low-Quality): supports reading N LQ frames (N = 1, 3, 5, 7) centered on the 4th frame
'''
def __init__(self, opt):
super(Vimeo90KDataset, self).__init__()
self.opt = opt
# temporal augmentation
self.interval_list = opt['interval_list']
self.random_reverse = opt['random_reverse']
logger.info('Temporal augmentation interval list: [{}]; random reverse: {}.'.format(
','.join(str(x) for x in opt['interval_list']), self.random_reverse))
self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
self.data_type = self.opt['data_type']
self.LR_input = opt['target_size'] != opt['LQ_size']  # True when the inputs are low-resolution
#### determine the LQ frame list
'''
N | frames
1 | 4
3 | 3,4,5
5 | 2,3,4,5,6
7 | 1,2,3,4,5,6,7
'''
self.LQ_frames_list = []
for i in range(opt['N_frames']):
self.LQ_frames_list.append(i + (9 - opt['N_frames']) // 2)
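# e.g. N_frames = 5 gives [2, 3, 4, 5, 6], i.e. frames im2..im6 centered on the GT frame im4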
#### directly load image keys
if self.data_type == 'lmdb':
self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
logger.info('Using lmdb meta info for cache keys.')
elif opt['cache_keys']:
logger.info('Using cache keys: {}'.format(opt['cache_keys']))
self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
else:
raise ValueError(
'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
assert self.paths_GT, 'Error: GT path is empty.'
if self.data_type == 'lmdb':
self.GT_env, self.LQ_env = None, None
elif self.data_type == 'mc': # memcached
self.mclient = None
elif self.data_type == 'img':
pass
else:
raise ValueError('Wrong data type: {}'.format(self.data_type))
def _init_lmdb(self):
# https://github.com/chainer/chainermn/issues/129
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
meminit=False)
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
meminit=False)
def _ensure_memcached(self):
if self.mclient is None:
# specify the config files
server_list_config_file = None
client_config_file = None
self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
client_config_file)
def _read_img_mc(self, path):
''' Return BGR, HWC, [0, 255], uint8'''
value = mc.pyvector()
self.mclient.Get(path, value)
value_buf = mc.ConvertBuffer(value)
img_array = np.frombuffer(value_buf, np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
return img
def __getitem__(self, index):
if self.data_type == 'mc':
self._ensure_memcached()
elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
self._init_lmdb()
scale = self.opt['scale']
GT_size = self.opt['target_size']
key = self.paths_GT[index]
name_a, name_b = key.split('_')
#### get the GT image (as the center frame)
if self.data_type == 'mc':
img_GT = self._read_img_mc(osp.join(self.GT_root, name_a, name_b, '4.png'))
img_GT = img_GT.astype(np.float32) / 255.
elif self.data_type == 'lmdb':
img_GT = util.read_img(self.GT_env, key + '_4', (3, 256, 448))
else:
img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b, 'im4.png'))
#### get LQ images
LQ_size_tuple = (3, 64, 112) if self.LR_input else (3, 256, 448)
img_LQ_l = []
for v in self.LQ_frames_list:
if self.data_type == 'mc':
img_LQ = self._read_img_mc(
osp.join(self.LQ_root, name_a, name_b, '{}.png'.format(v)))
img_LQ = img_LQ.astype(np.float32) / 255.
elif self.data_type == 'lmdb':
img_LQ = util.read_img(self.LQ_env, key + '_{}'.format(v), LQ_size_tuple)
else:
img_LQ = util.read_img(None,
osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v)))
img_LQ_l.append(img_LQ)
if self.opt['phase'] == 'train':
C, H, W = LQ_size_tuple # LQ size
# randomly crop
if self.LR_input:
LQ_size = GT_size // scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
else:
rnd_h = random.randint(0, max(0, H - GT_size))
rnd_w = random.randint(0, max(0, W - GT_size))
img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
# augmentation - flip, rotate
img_LQ_l.append(img_GT)
rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
img_LQ_l = rlt[0:-1]
img_GT = rlt[-1]
# stack LQ images to NHWC, N is the frame number
img_LQs = np.stack(img_LQ_l, axis=0)
# BGR to RGB, HWC to CHW, numpy to tensor
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
(0, 3, 1, 2)))).float()
return {'LQs': img_LQs, 'GT': img_GT, 'key': key}
def __len__(self):
return len(self.paths_GT)
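# Hypothetical configuration sketch (keys mirror the opt[...] reads above; the values are
# illustrative only, not the project's defaults):
#   opt = {'interval_list': [1], 'random_reverse': False,
#          'dataroot_GT': '/path/to/vimeo90k_train_GT.lmdb',
#          'dataroot_LQ': '/path/to/vimeo90k_train_LR7frames.lmdb',
#          'data_type': 'lmdb', 'cache_keys': None, 'N_frames': 7, 'scale': 4,
#          'target_size': 256, 'LQ_size': 64, 'phase': 'train',
#          'use_flip': True, 'use_rot': True}
#   sample = Vimeo90KDataset(opt)[0]  # {'LQs': [N, 3, h, w], 'GT': [3, H, W], 'key': '00001_0001'}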

@@ -1,84 +0,0 @@
import os.path as osp
import torch
import torch.utils.data as data
import data.util as util
class VideoTestDataset(data.Dataset):
"""
A video test dataset. Supports:
Vid4
REDS4
Vimeo90K-Test
No need to prepare LMDB files.
"""
def __init__(self, opt):
super(VideoTestDataset, self).__init__()
self.opt = opt
self.cache_data = opt['cache_data']
self.half_N_frames = opt['N_frames'] // 2
self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
self.data_type = self.opt['data_type']
self.data_info = {'path_LQ': [], 'path_GT': [], 'folder': [], 'idx': [], 'border': []}
if self.data_type == 'lmdb':
raise ValueError('No need to use LMDB during validation/test.')
#### Generate data info and cache data
self.imgs_LQ, self.imgs_GT = {}, {}
if opt['name'].lower() in ['vid4', 'reds4']:
subfolders_LQ = util.glob_file_list(self.LQ_root)
subfolders_GT = util.glob_file_list(self.GT_root)
for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT):
subfolder_name = osp.basename(subfolder_GT)
img_paths_LQ = util.glob_file_list(subfolder_LQ)
img_paths_GT = util.glob_file_list(subfolder_GT)
max_idx = len(img_paths_LQ)
assert max_idx == len(
img_paths_GT), 'Different number of images in LQ and GT folders'
self.data_info['path_LQ'].extend(img_paths_LQ)
self.data_info['path_GT'].extend(img_paths_GT)
self.data_info['folder'].extend([subfolder_name] * max_idx)
for i in range(max_idx):
self.data_info['idx'].append('{}/{}'.format(i, max_idx))
border_l = [0] * max_idx
for i in range(self.half_N_frames):
border_l[i] = 1
border_l[max_idx - i - 1] = 1
self.data_info['border'].extend(border_l)
if self.cache_data:
self.imgs_LQ[subfolder_name] = util.read_img_seq(img_paths_LQ)
self.imgs_GT[subfolder_name] = util.read_img_seq(img_paths_GT)
elif opt['name'].lower() in ['vimeo90k-test']:
pass # TODO
else:
raise ValueError(
'Unsupported video test dataset. Supported: Vid4, REDS4 and Vimeo90K-Test.')
def __getitem__(self, index):
# path_LQ = self.data_info['path_LQ'][index]
# path_GT = self.data_info['path_GT'][index]
folder = self.data_info['folder'][index]
idx, max_idx = self.data_info['idx'][index].split('/')
idx, max_idx = int(idx), int(max_idx)
border = self.data_info['border'][index]
if self.cache_data:
select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'],
padding=self.opt['padding'])
imgs_LQ = self.imgs_LQ[folder].index_select(0, torch.LongTensor(select_idx))
img_GT = self.imgs_GT[folder][idx]
else:
pass # TODO
return {
'LQs': imgs_LQ,
'GT': img_GT,
'folder': folder,
'idx': self.data_info['idx'][index],
'border': border
}
def __len__(self):
return len(self.data_info['path_GT'])
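# Note: 'border' is 1 for the first and last N_frames // 2 positions of each clip, i.e. the
# frames whose temporal window has to be padded when evaluating.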

@@ -1,166 +0,0 @@
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
from models.loss import CharbonnierLoss
logger = logging.getLogger('base')
class VideoBaseModel(BaseModel):
def __init__(self, opt):
super(VideoBaseModel, self).__init__(opt)
if opt['dist']:
self.rank = torch.distributed.get_rank()
else:
self.rank = -1 # non dist training
train_opt = opt['train']
# define network and load pretrained models
self.netG = networks.define_G(opt).to(self.device)
if opt['dist']:
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
else:
self.netG = DataParallel(self.netG)
# print network
self.print_network()
self.load()
if self.is_train:
self.netG.train()
#### loss
loss_type = train_opt['pixel_criterion']
if loss_type == 'l1':
self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
elif loss_type == 'l2':
self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
elif loss_type == 'cb':
self.cri_pix = CharbonnierLoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
self.l_pix_w = train_opt['pixel_weight']
#### optimizers
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
if train_opt['ft_tsa_only']:
normal_params = []
tsa_fusion_params = []
for k, v in self.netG.named_parameters():
if v.requires_grad:
if 'tsa_fusion' in k:
tsa_fusion_params.append(v)
else:
normal_params.append(v)
else:
if self.rank <= 0:
logger.warning('Params [{:s}] will not be optimized.'.format(k))
optim_params = [
{ # add normal params first
'params': normal_params,
'lr': train_opt['lr_G']
},
{
'params': tsa_fusion_params,
'lr': train_opt['lr_G']
},
]
else:
optim_params = []
for k, v in self.netG.named_parameters():
if v.requires_grad:
optim_params.append(v)
else:
if self.rank <= 0:
logger.warning('Params [{:s}] will not be optimized.'.format(k))
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
weight_decay=wd_G,
betas=(train_opt['beta1'], train_opt['beta2']))
self.optimizers.append(self.optimizer_G)
#### schedulers
if train_opt['lr_scheme'] == 'MultiStepLR':
for optimizer in self.optimizers:
self.schedulers.append(
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
restarts=train_opt['restarts'],
weights=train_opt['restart_weights'],
gamma=train_opt['lr_gamma'],
clear_state=train_opt['clear_state']))
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
for optimizer in self.optimizers:
self.schedulers.append(
lr_scheduler.CosineAnnealingLR_Restart(
optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
else:
raise NotImplementedError()
self.log_dict = OrderedDict()
def feed_data(self, data, need_GT=True):
self.var_L = data['LQs'].to(self.device)
if need_GT:
self.real_H = data['GT'].to(self.device)
def set_params_lr_zero(self):
# fix normal module
self.optimizers[0].param_groups[0]['lr'] = 0
def optimize_parameters(self, step):
if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
self.set_params_lr_zero()
self.optimizer_G.zero_grad()
self.fake_H = self.netG(self.var_L)
l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
l_pix.backward()
self.optimizer_G.step()
# set log
self.log_dict['l_pix'] = l_pix.item()
def test(self):
self.netG.eval()
with torch.no_grad():
self.fake_H = self.netG(self.var_L)
self.netG.train()
def get_current_log(self):
return self.log_dict
def get_current_visuals(self, need_GT=True):
out_dict = OrderedDict()
out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
if need_GT:
out_dict['GT'] = self.real_H.detach()[0].float().cpu()
return out_dict
def print_network(self):
s, n = self.get_network_description(self.netG)
if isinstance(self.netG, nn.DataParallel):
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
self.netG.module.__class__.__name__)
else:
net_struc_str = '{}'.format(self.netG.__class__.__name__)
if self.rank <= 0:
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
logger.info(s)
def load(self):
load_path_G = self.opt['path']['pretrain_model_G']
if load_path_G is not None:
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
def save(self, iter_label):
self.save_network(self.netG, 'G', iter_label)
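# Note on ft_tsa_only: when train_opt['ft_tsa_only'] > 0, two parameter groups are created and
# optimize_parameters() zeroes the learning rate of the non-TSA group for the first
# ft_tsa_only steps, so only the tsa_fusion module is trained during that warm-up phase.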

@@ -1,368 +0,0 @@
'''Network architecture for DUF:
Deep Video Super-Resolution Network Using Dynamic Upsampling Filters
Without Explicit Motion Compensation (CVPR18)
https://github.com/yhjo09/VSR-DUF
For all the models below, [adapt_official] is only necessary when
loading the weights converted from the official TensorFlow weights.
Please set it to [False] if you are training the model from scratch.
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def adapt_official(Rx, scale=4):
'''Adapt weights translated from the official TensorFlow weights.
Not necessary if you are training from scratch.'''
x = Rx.clone()
x1 = x[:, ::3, :, :]
x2 = x[:, 1::3, :, :]
x3 = x[:, 2::3, :, :]
Rx[:, :scale**2, :, :] = x1
Rx[:, scale**2:2 * (scale**2), :, :] = x2
Rx[:, 2 * (scale**2):, :, :] = x3
return Rx
class DenseBlock(nn.Module):
'''Dense block
for the second dense block, t_reduce = True'''
def __init__(self, nf=64, ng=32, t_reduce=False):
super(DenseBlock, self).__init__()
self.t_reduce = t_reduce
if self.t_reduce:
pad = (0, 1, 1)
else:
pad = (1, 1, 1)
self.bn3d_1 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3)
self.conv3d_1 = nn.Conv3d(nf, nf, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
self.bn3d_2 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3)
self.conv3d_2 = nn.Conv3d(nf, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)
self.bn3d_3 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3)
self.conv3d_3 = nn.Conv3d(nf + ng, nf + ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.bn3d_4 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3)
self.conv3d_4 = nn.Conv3d(nf + ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)
self.bn3d_5 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3)
self.conv3d_5 = nn.Conv3d(nf + 2 * ng, nf + 2 * ng, (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.bn3d_6 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3)
self.conv3d_6 = nn.Conv3d(nf + 2 * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad,
bias=True)
def forward(self, x):
'''x: [B, C, T, H, W]
C: nf -> nf + 3 * ng
T: 1) 7 -> 7 (t_reduce=False);
2) 7 -> 7 - 2 * 3 = 1 (t_reduce=True)'''
x1 = self.conv3d_1(F.relu(self.bn3d_1(x), inplace=True))
x1 = self.conv3d_2(F.relu(self.bn3d_2(x1), inplace=True))
if self.t_reduce:
x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
else:
x1 = torch.cat((x, x1), 1)
x2 = self.conv3d_3(F.relu(self.bn3d_3(x1), inplace=True))
x2 = self.conv3d_4(F.relu(self.bn3d_4(x2), inplace=True))
if self.t_reduce:
x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
else:
x2 = torch.cat((x1, x2), 1)
x3 = self.conv3d_5(F.relu(self.bn3d_5(x2), inplace=True))
x3 = self.conv3d_6(F.relu(self.bn3d_6(x3), inplace=True))
if self.t_reduce:
x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
else:
x3 = torch.cat((x2, x3), 1)
return x3
class DynamicUpsamplingFilter_3C(nn.Module):
'''Dynamic upsampling filter that applies the same filters to all 3 channels.
filter_size: size of the generated filters, shape (C, kH, kW)'''
def __init__(self, filter_size=(1, 5, 5)):
super(DynamicUpsamplingFilter_3C, self).__init__()
# generate a local expansion filter, used similar to im2col
nF = np.prod(filter_size)
expand_filter_np = np.reshape(np.eye(nF, nF),
(nF, filter_size[0], filter_size[1], filter_size[2]))
expand_filter = torch.from_numpy(expand_filter_np).float()
self.expand_filter = torch.cat((expand_filter, expand_filter, expand_filter),
0) # [75, 1, 5, 5]
def forward(self, x, filters):
'''x: input image, [B, 3, H, W]
filters: generated dynamic filters, [B, F, R, H, W], e.g., [B, 25, 16, H, W]
F: prod of filter kernel size, e.g., 5*5 = 25
R: used for upsampling, similar to pixel shuffle, e.g., 4*4 = 16 for x4
Return: filtered image, [B, 3*R, H, W]
'''
B, nF, R, H, W = filters.size()
# using group convolution
input_expand = F.conv2d(x, self.expand_filter.type_as(x), padding=2,
groups=3) # [B, 75, H, W] similar to im2col
input_expand = input_expand.view(B, 3, nF, H, W).permute(0, 3, 4, 1, 2) # [B, H, W, 3, 25]
filters = filters.permute(0, 3, 4, 1, 2) # [B, H, W, 25, 16]
out = torch.matmul(input_expand, filters) # [B, H, W, 3, 16]
return out.permute(0, 3, 4, 1, 2).view(B, 3 * R, H, W) # [B, 3*16, H, W]
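# Shape example for scale 4: x [B, 3, H, W] and filters [B, 25, 16, H, W] yield [B, 48, H, W],
# which F.pixel_shuffle(., 4) later rearranges into a [B, 3, 4H, 4W] image.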
class DUF_16L(nn.Module):
'''Official DUF structure with 16 layers'''
def __init__(self, scale=4, adapt_official=False):
super(DUF_16L, self).__init__()
self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
self.dense_block_1 = DenseBlock(64, 64 // 2, t_reduce=False) # 64 + 32 * 3 = 160, T = 7
self.dense_block_2 = DenseBlock(160, 64 // 2, t_reduce=True) # 160 + 32 * 3 = 256, T = 1
self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3)
self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
bias=True)
self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
self.scale = scale
self.adapt_official = adapt_official
def forward(self, x):
'''
x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
Generate filters and image residual:
Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
Rx: [B, 3*16, 1, H, W]
'''
B, T, C, H, W = x.size()
x = x.permute(0, 2, 1, 3, 4) # [B, C, T, H, W] for Conv3D
x_center = x[:, :, T // 2, :, :]
x = self.conv3d_1(x)
x = self.dense_block_1(x)
x = self.dense_block_2(x) # reduce T to 1
x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)
# image residual
Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W]
# filter
Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W]
Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)
# Adapt to official model weights
if self.adapt_official:
adapt_official(Rx, scale=self.scale)
# dynamic filter
out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W]
out += Rx.squeeze_(2)
out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W]
return out
class DenseBlock_28L(nn.Module):
'''The first part of the dense blocks used in DUF_28L
Temporal dimension remains the same here'''
def __init__(self, nf=64, ng=16):
super(DenseBlock_28L, self).__init__()
pad = (1, 1, 1)
dense_block_l = []
for i in range(0, 9):
dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
dense_block_l.append(nn.ReLU())
dense_block_l.append(
nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True))
dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
dense_block_l.append(nn.ReLU())
dense_block_l.append(
nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True))
self.dense_blocks = nn.ModuleList(dense_block_l)
def forward(self, x):
'''x: [B, C, T, H, W]
C: 64 -> 208;
T: 7 -> 7 (temporal size is unchanged)'''
for i in range(0, len(self.dense_blocks), 6):
y = x
for j in range(6):
y = self.dense_blocks[i + j](y)
x = torch.cat((x, y), 1)
return x
class DUF_28L(nn.Module):
'''Official DUF structure with 28 layers'''
def __init__(self, scale=4, adapt_official=False):
super(DUF_28L, self).__init__()
self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
self.dense_block_1 = DenseBlock_28L(64, 16) # 64 + 16 * 9 = 208, T = 7
self.dense_block_2 = DenseBlock(208, 16, t_reduce=True) # 208 + 16 * 3 = 256, T = 1
self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3)
self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
bias=True)
self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
self.scale = scale
self.adapt_official = adapt_official
def forward(self, x):
'''
x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
Generate filters and image residual:
Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
Rx: [B, 3*16, 1, H, W]
'''
B, T, C, H, W = x.size()
x = x.permute(0, 2, 1, 3, 4) # [B,C,T,H,W] for Conv3D
x_center = x[:, :, T // 2, :, :]
x = self.conv3d_1(x)
x = self.dense_block_1(x)
x = self.dense_block_2(x) # reduce T to 1
x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)
# image residual
Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W]
# filter
Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W]
Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)
# Adapt to official model weights
if self.adapt_official:
adapt_official(Rx, scale=self.scale)
# dynamic filter
out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W]
out += Rx.squeeze_(2)
out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W]
return out
class DenseBlock_52L(nn.Module):
'''The first part of the dense blocks used in DUF_52L
Temporal dimension remains the same here'''
def __init__(self, nf=64, ng=16):
super(DenseBlock_52L, self).__init__()
pad = (1, 1, 1)
dense_block_l = []
for i in range(0, 21):
dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
dense_block_l.append(nn.ReLU())
dense_block_l.append(
nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True))
dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
dense_block_l.append(nn.ReLU())
dense_block_l.append(
nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True))
self.dense_blocks = nn.ModuleList(dense_block_l)
def forward(self, x):
'''x: [B, C, T, H, W]
C: 64 -> 400;
T: 7 -> 7 (temporal size is unchanged)'''
for i in range(0, len(self.dense_blocks), 6):
y = x
for j in range(6):
y = self.dense_blocks[i + j](y)
x = torch.cat((x, y), 1)
return x
class DUF_52L(nn.Module):
'''Official DUF structure with 52 layers'''
def __init__(self, scale=4, adapt_official=False):
super(DUF_52L, self).__init__()
self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
self.dense_block_1 = DenseBlock_52L(64, 16) # 64 + 21 * 16 = 400, T = 7
self.dense_block_2 = DenseBlock(400, 16, t_reduce=True) # 400 + 16 * 3 = 448, T = 1
self.bn3d_2 = nn.BatchNorm3d(448, eps=1e-3, momentum=1e-3)
self.conv3d_2 = nn.Conv3d(448, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
bias=True)
self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
bias=True)
self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
padding=(0, 0, 0), bias=True)
self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
self.scale = scale
self.adapt_official = adapt_official
def forward(self, x):
'''
x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
Generate filters and image residual:
Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
Rx: [B, 3*16, 1, H, W]
'''
B, T, C, H, W = x.size()
x = x.permute(0, 2, 1, 3, 4) # [B,C,T,H,W] for Conv3D
x_center = x[:, :, T // 2, :, :]
x = self.conv3d_1(x)
x = self.dense_block_1(x)
x = self.dense_block_2(x)
x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)
# image residual
Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W]
# filter
Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W]
Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)
# Adapt to official model weights
if self.adapt_official:
adapt_official(Rx, scale=self.scale)
# dynamic filter
out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W]
out += Rx.squeeze_(2)
out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W]
return out
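# Hypothetical smoke test (shapes only, assuming a standard torch install; not part of the
# original file):
if __name__ == '__main__':
    net = DUF_16L(scale=4)
    frames = torch.randn(1, 7, 3, 32, 32)  # [B, T, C, H, W]; the DUF models expect T = 7
    print(net(frames).shape)  # torch.Size([1, 3, 128, 128])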

@@ -1,312 +0,0 @@
''' network architecture for EDVR '''
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util
try:
from models.archs.dcn.deform_conv import ModulatedDeformConvPack as DCN
except ImportError:
raise ImportError('Failed to import DCNv2 module.')
class Predeblur_ResNet_Pyramid(nn.Module):
def __init__(self, nf=128, HR_in=False):
'''
HR_in: True if the inputs are at the high (HR) spatial resolution
'''
super(Predeblur_ResNet_Pyramid, self).__init__()
self.HR_in = bool(HR_in)
if self.HR_in:
self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
else:
self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
self.RB_L1_1 = basic_block()
self.RB_L1_2 = basic_block()
self.RB_L1_3 = basic_block()
self.RB_L1_4 = basic_block()
self.RB_L1_5 = basic_block()
self.RB_L2_1 = basic_block()
self.RB_L2_2 = basic_block()
self.RB_L3_1 = basic_block()
self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, x):
if self.HR_in:
L1_fea = self.lrelu(self.conv_first_1(x))
L1_fea = self.lrelu(self.conv_first_2(L1_fea))
L1_fea = self.lrelu(self.conv_first_3(L1_fea))
else:
L1_fea = self.lrelu(self.conv_first(x))
L2_fea = self.lrelu(self.deblur_L2_conv(L1_fea))
L3_fea = self.lrelu(self.deblur_L3_conv(L2_fea))
L3_fea = F.interpolate(self.RB_L3_1(L3_fea), scale_factor=2, mode='bilinear',
align_corners=False)
L2_fea = self.RB_L2_1(L2_fea) + L3_fea
L2_fea = F.interpolate(self.RB_L2_2(L2_fea), scale_factor=2, mode='bilinear',
align_corners=False)
L1_fea = self.RB_L1_2(self.RB_L1_1(L1_fea)) + L2_fea
out = self.RB_L1_5(self.RB_L1_4(self.RB_L1_3(L1_fea)))
return out
class PCD_Align(nn.Module):
''' Alignment module using Pyramid, Cascading and Deformable convolution
with 3 pyramid levels.
'''
def __init__(self, nf=64, groups=8):
super(PCD_Align, self).__init__()
# L3: level 3, 1/4 spatial size
self.L3_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
extra_offset_mask=True)
# L2: level 2, 1/2 spatial size
self.L2_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset
self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
extra_offset_mask=True)
self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea
# L1: level 1, original spatial size
self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset
self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
extra_offset_mask=True)
self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea
# Cascading DCN
self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
extra_offset_mask=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, nbr_fea_l, ref_fea_l):
'''Align neighboring frames to the reference frame at the feature level
nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features
'''
# L3
L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1)
L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset]))
# L2
L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1)
L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)
L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1)))
L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset])
L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)
L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
# L1
L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1)
L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)
L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset])
L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)
L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
# Cascading
offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
offset = self.lrelu(self.cas_offset_conv1(offset))
offset = self.lrelu(self.cas_offset_conv2(offset))
L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset]))
return L1_fea
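# Note: offsets predicted at a coarser level are bilinearly upsampled and multiplied by 2
# before being fused; offsets are measured in pixels, so their magnitude doubles when the
# spatial resolution doubles.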
class TSA_Fusion(nn.Module):
''' Temporal Spatial Attention fusion module
Temporal: correlation;
Spatial: 3 pyramid levels.
'''
def __init__(self, nf=64, nframes=5, center=2):
super(TSA_Fusion, self).__init__()
self.center = center
# temporal attention (before fusion conv)
self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
# fusion conv: using 1x1 to save parameters and computation
self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
# spatial attention (after fusion conv)
self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
self.avgpool = nn.AvgPool2d(3, stride=2, padding=1)
self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True)
self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True)
self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, aligned_fea):
B, N, C, H, W = aligned_fea.size() # N video frames
#### temporal attention
emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone())
emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W) # [B, N, C(nf), H, W]
cor_l = []
for i in range(N):
emb_nbr = emb[:, i, :, :, :]
cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1) # B, 1, H, W
cor_l.append(cor_tmp)
cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1)) # B, N, H, W
cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W)
aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob
#### fusion
fea = self.lrelu(self.fea_fusion(aligned_fea))
#### spatial attention
att = self.lrelu(self.sAtt_1(aligned_fea))
att_max = self.maxpool(att)
att_avg = self.avgpool(att)
att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1)))
# pyramid levels
att_L = self.lrelu(self.sAtt_L1(att))
att_max = self.maxpool(att_L)
att_avg = self.avgpool(att_L)
att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1)))
att_L = self.lrelu(self.sAtt_L3(att_L))
att_L = F.interpolate(att_L, scale_factor=2, mode='bilinear', align_corners=False)
att = self.lrelu(self.sAtt_3(att))
att = att + att_L
att = self.lrelu(self.sAtt_4(att))
att = F.interpolate(att, scale_factor=2, mode='bilinear', align_corners=False)
att = self.sAtt_5(att)
att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att)))
att = torch.sigmoid(att)
fea = fea * att * 2 + att_add
return fea
class EDVR(nn.Module):
def __init__(self, nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10, center=None,
predeblur=False, HR_in=False, w_TSA=True):
super(EDVR, self).__init__()
self.nf = nf
self.center = nframes // 2 if center is None else center
self.is_predeblur = bool(predeblur)
self.HR_in = bool(HR_in)
self.w_TSA = w_TSA
ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
#### extract features (for each frame)
if self.is_predeblur:
self.pre_deblur = Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
self.conv_1x1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
else:
if self.HR_in:
self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
else:
self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.pcd_align = PCD_Align(nf=nf, groups=groups)
if self.w_TSA:
self.tsa_fusion = TSA_Fusion(nf=nf, nframes=nframes, center=self.center)
else:
self.tsa_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
#### reconstruction
self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, back_RBs)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)
self.pixel_shuffle = nn.PixelShuffle(2)
self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)
#### activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, x):
B, N, C, H, W = x.size() # N video frames
x_center = x[:, self.center, :, :, :].contiguous()
#### extract LR features
# L1
if self.is_predeblur:
L1_fea = self.pre_deblur(x.view(-1, C, H, W))
L1_fea = self.conv_1x1(L1_fea)
if self.HR_in:
H, W = H // 4, W // 4
else:
if self.HR_in:
L1_fea = self.lrelu(self.conv_first_1(x.view(-1, C, H, W)))
L1_fea = self.lrelu(self.conv_first_2(L1_fea))
L1_fea = self.lrelu(self.conv_first_3(L1_fea))
H, W = H // 4, W // 4
else:
L1_fea = self.lrelu(self.conv_first(x.view(-1, C, H, W)))
L1_fea = self.feature_extraction(L1_fea)
# L2
L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))
L2_fea = self.lrelu(self.fea_L2_conv2(L2_fea))
# L3
L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))
L3_fea = self.lrelu(self.fea_L3_conv2(L3_fea))
L1_fea = L1_fea.view(B, N, -1, H, W)
L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2)
L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4)
#### pcd align
# ref feature list
ref_fea_l = [
L1_fea[:, self.center, :, :, :].clone(), L2_fea[:, self.center, :, :, :].clone(),
L3_fea[:, self.center, :, :, :].clone()
]
aligned_fea = []
for i in range(N):
nbr_fea_l = [
L1_fea[:, i, :, :, :].clone(), L2_fea[:, i, :, :, :].clone(),
L3_fea[:, i, :, :, :].clone()
]
aligned_fea.append(self.pcd_align(nbr_fea_l, ref_fea_l))
aligned_fea = torch.stack(aligned_fea, dim=1) # [B, N, C, H, W]
if not self.w_TSA:
aligned_fea = aligned_fea.view(B, -1, H, W)
fea = self.tsa_fusion(aligned_fea)
out = self.recon_trunk(fea)
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
out = self.lrelu(self.HRconv(out))
out = self.conv_last(out)
if self.HR_in:
base = x_center
else:
base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
out += base
return out
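# Hypothetical smoke test (shapes only; requires the compiled deform_conv_cuda extension and a
# CUDA device; not part of the original file):
if __name__ == '__main__':
    net = EDVR(nf=64, nframes=5).cuda()
    frames = torch.randn(1, 5, 3, 64, 64).cuda()  # [B, N, C, H, W] low-resolution frames
    print(net(frames).shape)  # torch.Size([1, 3, 256, 256]) -- x4 upscaling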

@@ -1,137 +0,0 @@
'''PyTorch implementation of TOFlow
Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
Code reference:
1. https://github.com/anchen1011/toflow
2. https://github.com/Coldog2333/pytoflow
'''
import torch
import torch.nn as nn
from .arch_util import flow_warp
def normalize(x):
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
return (x - mean) / std
def denormalize(x):
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
return x * std + mean
class SpyNet_Block(nn.Module):
'''A submodule of SpyNet.'''
def __init__(self):
super(SpyNet_Block, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
nn.BatchNorm2d(64), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
nn.BatchNorm2d(16), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
def forward(self, x):
'''
input: x: [ref im, nbr im, initial flow] - (B, 8, H, W)
output: estimated flow - (B, 2, H, W)
'''
return self.block(x)
class SpyNet(nn.Module):
'''SpyNet for estimating optical flow
Ranjan et al., Optical Flow Estimation using a Spatial Pyramid Network, 2016'''
def __init__(self):
super(SpyNet, self).__init__()
self.blocks = nn.ModuleList([SpyNet_Block() for _ in range(4)])
def forward(self, ref, nbr):
'''Estimate optical flow at the coarse level, upsample, and refine at finer levels
input: ref: reference image - [B, 3, H, W]
nbr: the neighboring image to be warped - [B, 3, H, W]
output: estimated optical flow - [B, 2, H, W]
'''
B, C, H, W = ref.size()
ref = [ref]
nbr = [nbr]
for _ in range(3):
ref.insert(
0,
nn.functional.avg_pool2d(input=ref[0], kernel_size=2, stride=2,
count_include_pad=False))
nbr.insert(
0,
nn.functional.avg_pool2d(input=nbr[0], kernel_size=2, stride=2,
count_include_pad=False))
flow = torch.zeros(B, 2, H // 16, W // 16).type_as(ref[0])
for i in range(4):
flow_up = nn.functional.interpolate(input=flow, scale_factor=2, mode='bilinear',
align_corners=True) * 2.0
flow = flow_up + self.blocks[i](torch.cat(
[ref[i], flow_warp(nbr[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
return flow
class TOFlow(nn.Module):
def __init__(self, adapt_official=False):
super(TOFlow, self).__init__()
self.SpyNet = SpyNet()
self.conv_3x7_64_9x9 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
self.conv_64_64_9x9 = nn.Conv2d(64, 64, 9, 1, 4)
self.conv_64_64_1x1 = nn.Conv2d(64, 64, 1)
self.conv_64_3_1x1 = nn.Conv2d(64, 3, 1)
self.relu = nn.ReLU(inplace=True)
self.adapt_official = adapt_official # True if using translated official weights else False
def forward(self, x):
"""
input: x: input frames - [B, 7, 3, H, W]
output: SR reference frame - [B, 3, H, W]
"""
B, T, C, H, W = x.size()
x = normalize(x.view(-1, C, H, W)).view(B, T, C, H, W)
ref_idx = 3
x_ref = x[:, ref_idx, :, :, :]
# In the official torch code, the 0-th frame is the reference frame
if self.adapt_official:
x = x[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
ref_idx = 0
x_warped = []
for i in range(7):
if i == ref_idx:
x_warped.append(x_ref)
else:
x_nbr = x[:, i, :, :, :]
flow = self.SpyNet(x_ref, x_nbr).permute(0, 2, 3, 1)
x_warped.append(flow_warp(x_nbr, flow))
x_warped = torch.stack(x_warped, dim=1)
x = x_warped.view(B, -1, H, W)
x = self.relu(self.conv_3x7_64_9x9(x))
x = self.relu(self.conv_64_64_9x9(x))
x = self.relu(self.conv_64_64_1x1(x))
x = self.conv_64_3_1x1(x) + x_ref
return denormalize(x)
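# Hypothetical smoke test (shapes only; H and W should be multiples of 16 so the SpyNet
# pyramid downsamples cleanly; not part of the original file):
if __name__ == '__main__':
    net = TOFlow()
    frames = torch.randn(1, 7, 3, 64, 64)  # [B, 7, 3, H, W]
    print(net(frames).shape)  # torch.Size([1, 3, 64, 64])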

@@ -1,7 +0,0 @@
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack,
deform_conv, modulated_deform_conv)
__all__ = [
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv'
]

@@ -1,291 +0,0 @@
import math
import logging
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from . import deform_conv_cuda
logger = logging.getLogger('base')
class DeformConvFunction(Function):
@staticmethod
def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1,
deformable_groups=1, im2col_step=64):
if input is not None and input.dim() != 4:
raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(
input.dim()))
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(
DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
if not input.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output,
ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0],
ctx.dilation[1], ctx.dilation[0], ctx.groups,
ctx.deformable_groups, cur_im2col_step)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
if not grad_output.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
deform_conv_cuda.deform_conv_backward_input_cuda(
input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0],
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
ctx.deformable_groups, cur_im2col_step)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
deform_conv_cuda.deform_conv_backward_parameters_cuda(
input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1],
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
ctx.deformable_groups, 1, cur_im2col_step)
return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
if not all(map(lambda s: s > 0, output_size)):
raise ValueError("convolution input is too small (output would be {})".format('x'.join(
map(str, output_size))))
return output_size
class ModulatedDeformConvFunction(Function):
@staticmethod
def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1,
groups=1, deformable_groups=1):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
deform_conv_cuda.modulated_deform_conv_cuda_forward(
input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2],
weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
deform_conv_cuda.modulated_deform_conv_cuda_backward(
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight,
grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3],
ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None,
None)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * ctx.padding - (ctx.dilation *
(kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding - (ctx.dilation *
(kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply
class DeformConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
groups=1, deformable_groups=1, bias=False):
super(DeformConv, self).__init__()
assert not bias
assert in_channels % groups == 0, \
'in_channels {} is not divisible by groups {}'.format(
in_channels, groups)
assert out_channels % groups == 0, \
'out_channels {} is not divisible by groups {}'.format(
out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, x, offset):
return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
class DeformConvPack(DeformConv):
def __init__(self, *args, **kwargs):
super(DeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
class ModulatedDeformConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
groups=1, deformable_groups=1, bias=True):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, x, offset, mask):
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups,
self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
def __init__(self, *args, extra_offset_mask=False, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.extra_offset_mask = extra_offset_mask
self.conv_offset_mask = nn.Conv2d(
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
def forward(self, x):
if self.extra_offset_mask:
# x = [input, features]
out = self.conv_offset_mask(x[1])
x = x[0]
else:
out = self.conv_offset_mask(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
offset_mean = torch.mean(torch.abs(offset))
if offset_mean > 100:
logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean))
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups,
self.deformable_groups)
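# Note: with extra_offset_mask=True (as used by PCD_Align in the EDVR architecture), forward()
# expects a two-element list [features_to_deform, features_the_offsets_are_predicted_from]
# rather than a single tensor.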

@@ -1,22 +0,0 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
def make_cuda_ext(name, sources):
return CUDAExtension(
name='{}'.format(name), sources=[p for p in sources], extra_compile_args={
'cxx': [],
'nvcc': [
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
]
})
setup(
name='deform_conv', ext_modules=[
make_cuda_ext(name='deform_conv_cuda',
sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu'])
], cmdclass={'build_ext': BuildExtension}, zip_safe=False)
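# Typical build step (assumed, not part of the original file): running
#   python setup.py build_ext --inplace
# in this directory compiles the deform_conv_cuda extension from the listed .cpp/.cu sources.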

@@ -1,695 +0,0 @@
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
#include <torch/extension.h>
#include <cmath>
#include <vector>
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
const int channels, const int height, const int width,
const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor data_col);
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
const int channels, const int height, const int width,
const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor grad_im);
void deformable_col2im_coord(
const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const int channels, const int height,
const int width, const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, at::Tensor grad_offset);
void modulated_deformable_im2col_cuda(
const at::Tensor data_im, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor data_col);
void modulated_deformable_col2im_cuda(
const at::Tensor data_col, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor grad_im);
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
at::Tensor grad_mask);
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
int padW, int dilationH, int dilationW, int group,
int deformable_group) {
AT_CHECK(weight.ndimension() == 4,
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
"but got: %s",
weight.ndimension());
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
AT_CHECK(kW > 0 && kH > 0,
"kernel size should be greater than zero, but got kH: %d kW: %d", kH,
kW);
AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
"kernel size should be consistent with weight, ",
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
kW, weight.size(2), weight.size(3));
AT_CHECK(dW > 0 && dH > 0,
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
AT_CHECK(
dilationW > 0 && dilationH > 0,
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
dilationH, dilationW);
int ndim = input.ndimension();
int dimf = 0;
int dimh = 1;
int dimw = 2;
if (ndim == 4) {
dimf++;
dimh++;
dimw++;
}
AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
ndim);
long nInputPlane = weight.size(1) * group;
long inputHeight = input.size(dimh);
long inputWidth = input.size(dimw);
long nOutputPlane = weight.size(0);
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
AT_CHECK(nInputPlane % deformable_group == 0,
"input channels must divide deformable group size");
if (outputWidth < 1 || outputHeight < 1)
AT_ERROR(
"Given input size: (%ld x %ld x %ld). "
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
outputWidth);
AT_CHECK(input.size(1) == nInputPlane,
"invalid number of input planes, expected: %d, but got: %d",
nInputPlane, input.size(1));
AT_CHECK((inputHeight >= kH && inputWidth >= kW),
"input image is smaller than kernel");
AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
"invalid spatial size of offset, expected height: %d width: %d, but "
"got height: %d width: %d",
outputHeight, outputWidth, offset.size(2), offset.size(3));
AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
"invalid number of channels of offset");
if (gradOutput != NULL) {
AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
"invalid number of gradOutput planes, expected: %d, but got: %d",
nOutputPlane, gradOutput->size(dimf));
AT_CHECK((gradOutput->size(dimh) == outputHeight &&
gradOutput->size(dimw) == outputWidth),
"invalid size of gradOutput, expected height: %d width: %d , but "
"got height: %d width: %d",
outputHeight, outputWidth, gradOutput->size(dimh),
gradOutput->size(dimw));
}
}
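// Forward pass of (non-modulated) deformable convolution: for every chunk of
// im2col_step images, deformable_im2col gathers the bilinearly sampled taps
// into `columns`, a grouped GEMM (addmm_) multiplies them with the weights,
// and the buffered result is transposed back into (N, C_out, H_out, W_out).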
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step) {
// todo: resize columns to include im2col: done
// todo: add im2col_step as input
// todo: add new output buffer and transpose it to output (or directly
// transpose output)
// todo: possibly change data indexing because of parallel_imgs
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, group, deformable_group);
input = input.contiguous();
offset = offset.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input.unsqueeze_(0);
offset.unsqueeze_(0);
}
// todo: assert that batchSize is divisible by im2col_step
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
outputHeight, outputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
ones = at::ones({outputHeight, outputWidth}, input.options());
}
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
at::Tensor output_buffer =
at::zeros({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth},
output.options());
output_buffer = output_buffer.view(
{output_buffer.size(0), group, output_buffer.size(1) / group,
output_buffer.size(2), output_buffer.size(3)});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
output_buffer[elt][g] = output_buffer[elt][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output_buffer[elt][g]);
}
}
output_buffer = output_buffer.view(
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
output_buffer.size(3), output_buffer.size(4)});
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step, outputHeight, outputWidth});
output_buffer.transpose_(1, 2);
output.copy_(output_buffer);
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
output = output.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
}
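// Backward pass w.r.t. input and offset: gradOutput is first mapped back into
// column space with a grouped GEMM, then deformable_col2im_coord accumulates
// gradOffset and deformable_col2im scatters the gradients into gradInput.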
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW,
int dilationH, int group,
int deformable_group, int im2col_step) {
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, group, deformable_group);
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view({1, input.size(0), input.size(1), input.size(2)});
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
// change order of grad output
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight,
outputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
gradOutput = gradOutput.view(
{gradOutput.size(0), group, gradOutput.size(1) / group,
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
for (int g = 0; g < group; g++) {
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradOutput = gradOutput.view(
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
dilationH, dilationW, im2col_step, deformable_group,
gradOffset[elt]);
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, gradInput[elt]);
}
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
gradOffset = gradOffset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
gradOffset =
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
}
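// Backward pass w.r.t. the weights: deformable_im2col is re-run on the saved
// input/offset pair and gradOutput x columns^T is accumulated into gradWeight
// per group, scaled by `scale`.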
int deform_conv_backward_parameters_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, float scale, int im2col_step) {
// todo: transpose and reshape outGrad
// todo: reshape columns
// todo: add im2col_step as input
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
padW, dilationH, dilationW, group, deformable_group);
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view(
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = gradWeight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
outputHeight, outputWidth});
gradOutputBuffer.copy_(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
// divide into groups
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
gradWeight =
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3)});
for (int g = 0; g < group; g++) {
gradWeight[g] = gradWeight[g]
.flatten(1)
.addmm_(gradOutputBuffer[elt][g].flatten(1),
columns[g].transpose(1, 0), 1.0, scale)
.view_as(gradWeight[g]);
}
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0),
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3),
gradWeight.size(4)});
}
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
}
return 1;
}
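// Modulated deformable convolution (DCNv2) forward: same im2col + grouped
// GEMM scheme as above, but every sampled tap is additionally weighted by the
// learned modulation mask inside the im2col kernel, and an optional bias is
// added at the end.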
void modulated_deform_conv_cuda_forward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int deformable_group,
const bool with_bias) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
// resize output
output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns
columns =
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
input.options());
output = output.view({output.size(0), group, output.size(1) / group,
output.size(2), output.size(3)});
for (int b = 0; b < batch; b++) {
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
// divide into groups
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++) {
output[b][g] = output[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output[b][g]);
}
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
output = output.view({output.size(0), output.size(1) * output.size(2),
output.size(3), output.size(4)});
if (with_bias) {
output += bias.view({1, bias.size(0), 1, 1});
}
}
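// Modulated deformable convolution (DCNv2) backward: per image, computes the
// gradients w.r.t. offset and mask (col2im_coord), the input (col2im) and the
// weights / bias (im2col followed by grouped GEMMs against grad_output).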
void modulated_deform_conv_cuda_backward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
input.options());
grad_output =
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
grad_output.size(2), grad_output.size(3)});
for (int b = 0; b < batch; b++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
grad_output[b][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(
columns, input[b], offset[b], mask[b], 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(
columns, offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and
// group
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
grad_weight.size(1), grad_weight.size(2),
grad_weight.size(3)});
if (with_bias)
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
for (int g = 0; g < group; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
.view_as(grad_weight[g]);
if (with_bias) {
grad_bias[g] =
grad_bias[g]
.view({-1, 1})
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
.view(-1);
}
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2), grad_weight.size(3),
grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
}
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
grad_output.size(2), grad_output.size(3),
grad_output.size(4)});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
"deform forward (CUDA)");
m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
"deform_conv_backward_input (CUDA)");
m.def("deform_conv_backward_parameters_cuda",
&deform_conv_backward_parameters_cuda,
"deform_conv_backward_parameters (CUDA)");
m.def("modulated_deform_conv_cuda_forward",
&modulated_deform_conv_cuda_forward,
"modulated deform conv forward (CUDA)");
m.def("modulated_deform_conv_cuda_backward",
&modulated_deform_conv_cuda_backward,
"modulated deform conv backward (CUDA)");
}

View File

@@ -1,866 +0,0 @@
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer ********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>
#include <math.h>
#include <float.h>
using namespace at;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
const int kMaxGridNum = 65535;
inline int GET_BLOCKS(const int N)
{
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
}
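// Bilinear interpolation of bottom_data at the fractional location (h, w);
// corner pixels that fall outside the image simply contribute zero.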
template <typename scalar_t>
__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
const int height, const int width, scalar_t h, scalar_t w)
{
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = bottom_data[h_low * data_width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
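// Weight that the sample taken at (argmax_h, argmax_w) assigns to the integer
// pixel (h, w); used by col2im to scatter output gradients back onto the
// input image.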
template <typename scalar_t>
__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
const int h, const int w, const int height, const int width)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
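// Derivative of the bilinear sample at (argmax_h, argmax_w) w.r.t. the
// sampling coordinate itself (bp_dir == 0: h direction, bp_dir == 1: w
// direction); used to backpropagate into the predicted offsets.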
template <typename scalar_t>
__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
const int height, const int width, const scalar_t *im_data,
const int data_width, const int bp_dir)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (bp_dir == 0)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
else if (bp_dir == 1)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
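// One thread per column entry: decode (c_im, b_col, h_col, w_col) from the
// linear index, read the two offsets of every kernel tap and write the
// bilinearly sampled value into the column buffer of shape
// (C * kH * kW, N, H_col, W_col).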
template <typename scalar_t>
__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
const int batch_size, const int num_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
//const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
scalar_t val = static_cast<scalar_t>(0);
const scalar_t h_im = h_in + i * dilation_h + offset_h;
const scalar_t w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
{
//const scalar_t map_h = i * dilation_h + offset_h;
//const scalar_t map_w = j * dilation_w + offset_w;
//const int cur_height = height - h_in;
//const int cur_width = width - w_in;
//val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
void deformable_im2col(
const at::Tensor data_im, const at::Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h, const int ksize_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, at::Tensor data_col)
{
// num_axes should be smaller than block size
// todo: check parallel_imgs is correctly passed in
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
scalar_t *data_col_ = data_col.data<scalar_t>();
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, parallel_imgs, channels, deformable_group,
height_col, width_col, data_col_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
}
}
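// Inverse of deformable_im2col: every column entry scatters its gradient onto
// the (at most four) input pixels surrounding its fractional sampling
// location. The 5x5 dy/dx window combined with the |distance| < 1 test
// enumerates exactly those neighbours, and atomicAdd resolves write conflicts
// between threads.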
template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel(
const int n, const scalar_t *data_col, const scalar_t *data_offset,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_im)
{
CUDA_KERNEL_LOOP(index, n)
{
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
2 * kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
const scalar_t cur_top_grad = data_col[index];
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++)
{
for (int dx = -2; dx <= 2; dx++)
{
if (cur_h + dy >= 0 && cur_h + dy < height &&
cur_w + dx >= 0 && cur_w + dx < width &&
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1)
{
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
void deformable_col2im(
const at::Tensor data_col, const at::Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor grad_im)
{
// todo: make sure parallel_imgs is passed in correctly
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
scalar_t *grad_im_ = grad_im.data<scalar_t>();
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
parallel_imgs, deformable_group, height_col, width_col, grad_im_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
const scalar_t *data_im, const scalar_t *data_offset,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int height_col, const int width_col, scalar_t *grad_offset)
{
CUDA_KERNEL_LOOP(index, n)
{
scalar_t val = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
batch_size * width_col * height_col;
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_h / kernel_w * height * width;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
scalar_t inv_h = h_in + i * dilation_h + offset_h;
scalar_t inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
{
inv_h = inv_w = -2;
}
const scalar_t weight = get_coordinate_weight(
inv_h, inv_w,
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
val += weight * data_col_ptr[col_pos];
cnt += 1;
}
grad_offset[index] = val;
}
}
void deformable_col2im_coord(
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
const int channels, const int height, const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
const int stride_w, const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
{
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data<scalar_t>();
const scalar_t *data_im_ = data_im.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
height_col, width_col, grad_offset_);
}));
}
template <typename scalar_t>
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
const int height, const int width, scalar_t h, scalar_t w)
{
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = bottom_data[h_low * data_width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
const int h, const int w, const int height, const int width)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template <typename scalar_t>
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
const int height, const int width, const scalar_t *im_data,
const int data_width, const int bp_dir)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (bp_dir == 0)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
else if (bp_dir == 1)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
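// Same indexing as deformable_im2col_gpu_kernel above, except that every
// sampled value is multiplied by its modulation mask entry before being
// written to the column buffer.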
template <typename scalar_t>
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int num_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
//const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
scalar_t val = static_cast<scalar_t>(0);
const scalar_t h_im = h_in + i * dilation_h + offset_h;
const scalar_t w_im = w_in + j * dilation_w + offset_w;
//if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
{
//const float map_h = i * dilation_h + offset_h;
//const float map_w = j * dilation_w + offset_w;
//const int cur_height = height - h_in;
//const int cur_width = width - w_in;
//val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col;
//data_col_ptr += height_col * width_col;
}
}
}
}
template <typename scalar_t>
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_im)
{
CUDA_KERNEL_LOOP(index, n)
{
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
const scalar_t cur_top_grad = data_col[index] * mask;
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++)
{
for (int dx = -2; dx <= 2; dx++)
{
if (cur_h + dy >= 0 && cur_h + dy < height &&
cur_w + dx >= 0 && cur_w + dx < width &&
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1)
{
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
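// Gradient kernel for the offsets and the modulation mask: `val` accumulates
// the offset gradient (masked coordinate weights), while `mval` accumulates
// the mask gradient from the unmasked bilinear samples and is written once
// per (h, w) offset pair.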
template <typename scalar_t>
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
const scalar_t *data_col, const scalar_t *data_im,
const scalar_t *data_offset, const scalar_t *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_offset, scalar_t *grad_mask)
{
CUDA_KERNEL_LOOP(index, n)
{
scalar_t val = 0, mval = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
scalar_t inv_h = h_in + i * dilation_h + offset_h;
scalar_t inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
{
inv_h = inv_w = -2;
}
else
{
mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
}
const scalar_t weight = dmcn_get_coordinate_weight(
inv_h, inv_w,
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
val += weight * data_col_ptr[col_pos] * mask;
cnt += 1;
}
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
grad_offset[index] = val;
if (offset_c % 2 == 0)
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
}
}
void modulated_deformable_im2col_cuda(
const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor data_col)
{
// num_axes should be smaller than block size
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
scalar_t *data_col_ = data_col.data<scalar_t>();
modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
batch_size, channels, deformable_group, height_col, width_col, data_col_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
void modulated_deformable_col2im_cuda(
const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_im)
{
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
scalar_t *grad_im_ = grad_im.data<scalar_t>();
modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
batch_size, deformable_group, height_col, width_col, grad_im_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group,
at::Tensor grad_offset, at::Tensor grad_mask)
{
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data<scalar_t>();
const scalar_t *data_im_ = data_im.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
scalar_t *grad_mask_ = grad_mask.data<scalar_t>();
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
grad_offset_, grad_mask_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
}
}

View File

@@ -1,208 +0,0 @@
'''
Test Vid4 (SR) and REDS4 (SR-clean, SR-blur, deblur-clean, deblur-compression) datasets
'''
import os
import os.path as osp
import glob
import logging
import numpy as np
import cv2
import torch
import utils.util as util
import data.util as data_util
import models.archs.EDVR_arch as EDVR_arch
def main():
#################
# configurations
#################
device = torch.device('cuda')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
data_mode = 'Vid4' # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
# Vid4: SR
# REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
# blur (deblur-clean), blur_comp (deblur-compression).
stage = 1 # 1 or 2, use two stage strategy for REDS dataset.
flip_test = False
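# Stage 1 runs EDVR directly on the degraded inputs; stage 2 (REDS only)
# refines the stage-1 outputs, so it takes HR-sized inputs and uses fewer
# residual blocks (see HR_in / back_RBs below). Vid4 has no stage-2 model.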
############################################################################
#### model
if data_mode == 'Vid4':
if stage == 1:
model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
else:
raise ValueError('Vid4 does not support stage 2.')
elif data_mode == 'sharp_bicubic':
if stage == 1:
model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
else:
model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
elif data_mode == 'blur_bicubic':
if stage == 1:
model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
else:
model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
elif data_mode == 'blur':
if stage == 1:
model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
else:
model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
elif data_mode == 'blur_comp':
if stage == 1:
model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
else:
model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
else:
raise NotImplementedError
if data_mode == 'Vid4':
N_in = 7 # use N_in images to restore one HR image
else:
N_in = 5
predeblur, HR_in = False, False
back_RBs = 40
if data_mode == 'blur_bicubic':
predeblur = True
if data_mode == 'blur' or data_mode == 'blur_comp':
predeblur, HR_in = True, True
if stage == 2:
HR_in = True
back_RBs = 20
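    # Positional args are presumably (nf, nframes, groups, front_RBs, back_RBs) per
    # EDVR_arch: 128 features, N_in input frames, 8 deformable-conv groups, 5 front
    # residual blocks and back_RBs back residual blocks.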
model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
#### dataset
if data_mode == 'Vid4':
test_dataset_folder = '../datasets/Vid4/BIx4'
GT_dataset_folder = '../datasets/Vid4/GT'
else:
if stage == 1:
test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
else:
test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
print('You should modify the test_dataset_folder path for stage 2')
GT_dataset_folder = '../datasets/REDS4/GT'
#### evaluation
crop_border = 0
    border_frame = N_in // 2  # number of border frames excluded from the center-frame statistics
# temporal padding mode
if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
padding = 'new_info'
else:
padding = 'replicate'
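    # 'new_info' fills out-of-range positions with frames from the other side of the
    # temporal window so all N_in inputs stay distinct; 'replicate' simply repeats the
    # edge frame (see index_generation in data.util and in the DUF/TOF scripts below).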
save_imgs = True
save_folder = '../results/{}'.format(data_mode)
util.mkdirs(save_folder)
util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
logger = logging.getLogger('base')
#### log info
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
logger.info('Padding mode: {}'.format(padding))
logger.info('Model path: {}'.format(model_path))
logger.info('Save images: {}'.format(save_imgs))
logger.info('Flip test: {}'.format(flip_test))
#### set up the models
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)
avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
subfolder_name_l = []
subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
# for each subfolder
for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
subfolder_name = osp.basename(subfolder)
subfolder_name_l.append(subfolder_name)
save_subfolder = osp.join(save_folder, subfolder_name)
img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
max_idx = len(img_path_l)
if save_imgs:
util.mkdirs(save_subfolder)
#### read LQ and GT images
imgs_LQ = data_util.read_img_seq(subfolder)
img_GT_l = []
for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
img_GT_l.append(data_util.read_img(None, img_GT_path))
avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0
# process each image
for img_idx, img_path in enumerate(img_path_l):
img_name = osp.splitext(osp.basename(img_path))[0]
select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
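            # flipx4_forward presumably averages the outputs over flipped copies of the
            # input (self-ensemble); single_forward is a plain forward pass.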
if flip_test:
output = util.flipx4_forward(model, imgs_in)
else:
output = util.single_forward(model, imgs_in)
output = util.tensor2img(output.squeeze(0))
if save_imgs:
cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)
# calculate PSNR
output = output / 255.
GT = np.copy(img_GT_l[img_idx])
# For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
if data_mode == 'Vid4': # bgr2y, [0, 1]
GT = data_util.bgr2ycbcr(GT, only_y=True)
output = data_util.bgr2ycbcr(output, only_y=True)
output, GT = util.crop_border([output, GT], crop_border)
crt_psnr = util.calculate_psnr(output * 255, GT * 255)
logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))
if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
avg_psnr_center += crt_psnr
N_center += 1
else: # border frames
avg_psnr_border += crt_psnr
N_border += 1
avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
avg_psnr_center = avg_psnr_center / N_center
avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
avg_psnr_l.append(avg_psnr)
avg_psnr_center_l.append(avg_psnr_center)
avg_psnr_border_l.append(avg_psnr_border)
logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
'Center PSNR: {:.6f} dB for {} frames; '
'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
(N_center + N_border),
avg_psnr_center, N_center,
avg_psnr_border, N_border))
logger.info('################ Tidy Outputs ################')
for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
avg_psnr_center_l, avg_psnr_border_l):
logger.info('Folder {} - Average PSNR: {:.6f} dB. '
'Center PSNR: {:.6f} dB. '
'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
psnr_border))
logger.info('################ Final Results ################')
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
logger.info('Padding mode: {}'.format(padding))
logger.info('Model path: {}'.format(model_path))
logger.info('Save images: {}'.format(save_imgs))
logger.info('Flip test: {}'.format(flip_test))
logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
sum(avg_psnr_center_l) / len(avg_psnr_center_l),
sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
if __name__ == '__main__':
main()

View File

@@ -1,264 +0,0 @@
"""
DUF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets
write to txt log file
"""
import os
import os.path as osp
import glob
import logging
import numpy as np
import cv2
import torch
import utils.util as util
import data.util as data_util
import models.archs.DUF_arch as DUF_arch
def main():
#################
# configurations
#################
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS)
# Possible combinations: (2, 16), (3, 16), (4, 16), (4, 28), (4, 52)
scale = 4
layer = 52
assert (scale, layer) in [(2, 16), (3, 16), (4, 16), (4, 28),
(4, 52)], 'Unrecognized (scale, layer) combination'
# model
N_in = 7
model_path = '../experiments/pretrained_models/DUF_x{}_{}L_official.pth'.format(scale, layer)
    adapt_official = 'official' in model_path
DUF_downsampling = True # True | False
if layer == 16:
model = DUF_arch.DUF_16L(scale=scale, adapt_official=adapt_official)
elif layer == 28:
model = DUF_arch.DUF_28L(scale=scale, adapt_official=adapt_official)
elif layer == 52:
model = DUF_arch.DUF_52L(scale=scale, adapt_official=adapt_official)
#### dataset
if data_mode == 'Vid4':
test_dataset_folder = '../datasets/Vid4/BIx4/*'
else: # sharp_bicubic (REDS)
test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)
#### evaluation
crop_border = 8
    border_frame = N_in // 2  # number of border frames excluded from the center-frame statistics
# temporal padding mode
    padding = 'new_info'  # different from the official test code, which pads with zeros.
save_imgs = True
############################################################################
device = torch.device('cuda')
save_folder = '../results/{}'.format(data_mode)
util.mkdirs(save_folder)
util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
logger = logging.getLogger('base')
#### log info
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
logger.info('Padding mode: {}'.format(padding))
logger.info('Model path: {}'.format(model_path))
logger.info('Save images: {}'.format(save_imgs))
def read_image(img_path):
'''read one image from img_path
Return img: HWC, BGR, [0,1], numpy
'''
img_GT = cv2.imread(img_path)
img = img_GT.astype(np.float32) / 255.
return img
def read_seq_imgs(img_seq_path):
'''read a sequence of images'''
img_path_l = sorted(glob.glob(img_seq_path + '/*'))
img_l = [read_image(v) for v in img_path_l]
# stack to TCHW, RGB, [0,1], torch
imgs = np.stack(img_l, axis=0)
imgs = imgs[:, :, :, [2, 1, 0]]
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
return imgs
def index_generation(crt_i, max_n, N, padding='reflection'):
'''
padding: replicate | reflection | new_info | circle
'''
max_n = max_n - 1
n_pad = N // 2
return_l = []
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
if i < 0:
if padding == 'replicate':
add_idx = 0
elif padding == 'reflection':
add_idx = -i
elif padding == 'new_info':
add_idx = (crt_i + n_pad) + (-i)
elif padding == 'circle':
add_idx = N + i
else:
raise ValueError('Wrong padding mode')
elif i > max_n:
if padding == 'replicate':
add_idx = max_n
elif padding == 'reflection':
add_idx = max_n * 2 - i
elif padding == 'new_info':
add_idx = (crt_i - n_pad) - (i - max_n)
elif padding == 'circle':
add_idx = i - N
else:
raise ValueError('Wrong padding mode')
else:
add_idx = i
return_l.append(add_idx)
return return_l
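    # Worked example (hypothetical values): index_generation(crt_i=0, max_n=5, N=5) gives
    #   'replicate'  -> [0, 0, 0, 1, 2]
    #   'reflection' -> [2, 1, 0, 1, 2]
    #   'new_info'   -> [4, 3, 0, 1, 2]
    #   'circle'     -> [3, 4, 0, 1, 2]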
def single_forward(model, imgs_in):
with torch.no_grad():
model_output = model(imgs_in)
            if isinstance(model_output, (list, tuple)):
output = model_output[0]
else:
output = model_output
return output
sub_folder_l = sorted(glob.glob(test_dataset_folder))
#### set up the models
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)
avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
sub_folder_name_l = []
# for each sub-folder
for sub_folder in sub_folder_l:
sub_folder_name = sub_folder.split('/')[-1]
sub_folder_name_l.append(sub_folder_name)
save_sub_folder = osp.join(save_folder, sub_folder_name)
img_path_l = sorted(glob.glob(sub_folder + '/*'))
max_idx = len(img_path_l)
if save_imgs:
util.mkdirs(save_sub_folder)
#### read LR images
imgs = read_seq_imgs(sub_folder)
#### read GT images
img_GT_l = []
if data_mode == 'Vid4':
sub_folder_GT = osp.join(sub_folder.replace('/BIx4/', '/GT/'), '*')
else:
sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
for img_GT_path in sorted(glob.glob(sub_folder_GT)):
img_GT_l.append(read_image(img_GT_path))
        # When using the downsampling from the official DUF code, downsample the HR images to create the LR inputs
if DUF_downsampling:
sub_folder = sub_folder_GT
img_path_l = sorted(glob.glob(sub_folder))
max_idx = len(img_path_l)
imgs = read_seq_imgs(sub_folder[:-2])
avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
cal_n_border, cal_n_center = 0, 0
# process each image
for img_idx, img_path in enumerate(img_path_l):
c_idx = int(osp.splitext(osp.basename(img_path))[0])
select_idx = index_generation(c_idx, max_idx, N_in, padding=padding)
# get input images
imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
# Downsample the HR images
H, W = imgs_in.size(3), imgs_in.size(4)
if DUF_downsampling:
imgs_in = util.DUF_downsample(imgs_in, scale=scale)
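                # util.DUF_downsample is expected to apply the Gaussian blur + subsampling
                # used by the official DUF code, so the generated LR frames match the
                # degradation the model was trained on.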
output = single_forward(model, imgs_in)
# Crop to the original shape
if scale == 3:
pad_h = 3 - (H % 3)
pad_w = 3 - (W % 3)
if pad_h > 0:
output = output[:, :, :-pad_h, :]
if pad_w > 0:
output = output[:, :, :, :-pad_w]
output_f = output.data.float().cpu().squeeze(0)
output = util.tensor2img(output_f)
# save imgs
if save_imgs:
cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output)
#### calculate PSNR
output = output / 255.
GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
if data_mode == 'Vid4': # bgr2y, [0, 1]
GT = data_util.bgr2ycbcr(GT)
output = data_util.bgr2ycbcr(output)
if crop_border == 0:
cropped_output = output
cropped_GT = GT
else:
cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr))
if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
avg_psnr_center += crt_psnr
cal_n_center += 1
else: # border frames
avg_psnr_border += crt_psnr
cal_n_border += 1
avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border)
avg_psnr_center = avg_psnr_center / cal_n_center
if cal_n_border == 0:
avg_psnr_border = 0
else:
avg_psnr_border = avg_psnr_border / cal_n_border
logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
'Center PSNR: {:.6f} dB for {} frames; '
'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr,
(cal_n_center + cal_n_border),
avg_psnr_center, cal_n_center,
avg_psnr_border, cal_n_border))
avg_psnr_l.append(avg_psnr)
avg_psnr_center_l.append(avg_psnr_center)
avg_psnr_border_l.append(avg_psnr_border)
logger.info('################ Tidy Outputs ################')
for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
avg_psnr_center_l, avg_psnr_border_l):
logger.info('Folder {} - Average PSNR: {:.6f} dB. '
'Center PSNR: {:.6f} dB. '
'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border))
logger.info('################ Final Results ################')
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
logger.info('Padding mode: {}'.format(padding))
logger.info('Model path: {}'.format(model_path))
logger.info('Save images: {}'.format(save_imgs))
logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
sum(avg_psnr_center_l) / len(avg_psnr_center_l),
sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
if __name__ == '__main__':
main()

View File

@@ -1,230 +0,0 @@
"""
TOF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets
write to txt log file
"""
import os
import os.path as osp
import glob
import logging
import numpy as np
import cv2
import torch
import utils.util as util
import data.util as data_util
import models.archs.TOF_arch as TOF_arch
def main():
#################
# configurations
#################
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS)
# model
N_in = 7
model_path = '../experiments/pretrained_models/TOF_official.pth'
    adapt_official = 'official' in model_path
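    # adapt_official presumably toggles compatibility tweaks for the officially
    # released TOFlow weights (inferred from the model_path check above).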
model = TOF_arch.TOFlow(adapt_official=adapt_official)
#### dataset
if data_mode == 'Vid4':
test_dataset_folder = '../datasets/Vid4/BIx4up_direct/*'
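        # TOFlow presumably expects LR frames that are already bicubically upsampled to
        # the target resolution, hence the 'BIx4up_direct' folder instead of the BIx4
        # inputs used by the EDVR/DUF scripts.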
else:
test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)
#### evaluation
crop_border = 0
    border_frame = N_in // 2  # number of border frames excluded from the center-frame statistics
# temporal padding mode
padding = 'new_info' # different from the official setting
save_imgs = True
############################################################################
device = torch.device('cuda')
save_folder = '../results/{}'.format(data_mode)
util.mkdirs(save_folder)
util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
logger = logging.getLogger('base')
#### log info
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
logger.info('Padding mode: {}'.format(padding))
logger.info('Model path: {}'.format(model_path))
logger.info('Save images: {}'.format(save_imgs))
def read_image(img_path):
'''read one image from img_path
Return img: HWC, BGR, [0,1], numpy
'''
img_GT = cv2.imread(img_path)
img = img_GT.astype(np.float32) / 255.
return img
def read_seq_imgs(img_seq_path):
'''read a sequence of images'''
img_path_l = sorted(glob.glob(img_seq_path + '/*'))
img_l = [read_image(v) for v in img_path_l]
# stack to TCHW, RGB, [0,1], torch
imgs = np.stack(img_l, axis=0)
imgs = imgs[:, :, :, [2, 1, 0]]
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
return imgs
def index_generation(crt_i, max_n, N, padding='reflection'):
'''
padding: replicate | reflection | new_info | circle
'''
max_n = max_n - 1
n_pad = N // 2
return_l = []
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
if i < 0:
if padding == 'replicate':
add_idx = 0
elif padding == 'reflection':
add_idx = -i
elif padding == 'new_info':
add_idx = (crt_i + n_pad) + (-i)
elif padding == 'circle':
add_idx = N + i
else:
raise ValueError('Wrong padding mode')
elif i > max_n:
if padding == 'replicate':
add_idx = max_n
elif padding == 'reflection':
add_idx = max_n * 2 - i
elif padding == 'new_info':
add_idx = (crt_i - n_pad) - (i - max_n)
elif padding == 'circle':
add_idx = i - N
else:
raise ValueError('Wrong padding mode')
else:
add_idx = i
return_l.append(add_idx)
return return_l
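    # Identical helper to index_generation in the DUF test script above; see the worked
    # example of the padding modes there.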
def single_forward(model, imgs_in):
with torch.no_grad():
model_output = model(imgs_in)
            if isinstance(model_output, (list, tuple)):
output = model_output[0]
else:
output = model_output
return output
sub_folder_l = sorted(glob.glob(test_dataset_folder))
#### set up the models
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)
avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
sub_folder_name_l = []
# for each sub-folder
for sub_folder in sub_folder_l:
sub_folder_name = sub_folder.split('/')[-1]
sub_folder_name_l.append(sub_folder_name)
save_sub_folder = osp.join(save_folder, sub_folder_name)
img_path_l = sorted(glob.glob(sub_folder + '/*'))
max_idx = len(img_path_l)
if save_imgs:
util.mkdirs(save_sub_folder)
#### read LR images
imgs = read_seq_imgs(sub_folder)
#### read GT images
img_GT_l = []
if data_mode == 'Vid4':
sub_folder_GT = osp.join(sub_folder.replace('/BIx4up_direct/', '/GT/'), '*')
else:
sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
for img_GT_path in sorted(glob.glob(sub_folder_GT)):
img_GT_l.append(read_image(img_GT_path))
avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
cal_n_border, cal_n_center = 0, 0
# process each image
for img_idx, img_path in enumerate(img_path_l):
c_idx = int(osp.splitext(osp.basename(img_path))[0])
select_idx = index_generation(c_idx, max_idx, N_in, padding=padding)
# get input images
imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
output = single_forward(model, imgs_in)
output_f = output.data.float().cpu().squeeze(0)
output = util.tensor2img(output_f)
# save imgs
if save_imgs:
cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output)
#### calculate PSNR
output = output / 255.
GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
if data_mode == 'Vid4': # bgr2y, [0, 1]
GT = data_util.bgr2ycbcr(GT)
output = data_util.bgr2ycbcr(output)
if crop_border == 0:
cropped_output = output
cropped_GT = GT
else:
cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr))
if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
avg_psnr_center += crt_psnr
cal_n_center += 1
else: # border frames
avg_psnr_border += crt_psnr
cal_n_border += 1
avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border)
avg_psnr_center = avg_psnr_center / cal_n_center
if cal_n_border == 0:
avg_psnr_border = 0
else:
avg_psnr_border = avg_psnr_border / cal_n_border
logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
'Center PSNR: {:.6f} dB for {} frames; '
'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr,
(cal_n_center + cal_n_border),
avg_psnr_center, cal_n_center,
avg_psnr_border, cal_n_border))
avg_psnr_l.append(avg_psnr)
avg_psnr_center_l.append(avg_psnr_center)
avg_psnr_border_l.append(avg_psnr_border)
logger.info('################ Tidy Outputs ################')
for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
avg_psnr_center_l, avg_psnr_border_l):
logger.info('Folder {} - Average PSNR: {:.6f} dB. '
'Center PSNR: {:.6f} dB. '
'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border))
logger.info('################ Final Results ################')
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
logger.info('Padding mode: {}'.format(padding))
logger.info('Model path: {}'.format(model_path))
logger.info('Save images: {}'.format(save_imgs))
logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
sum(avg_psnr_center_l) / len(avg_psnr_center_l),
sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
if __name__ == '__main__':
main()