Clean up video stuff
parent
8464cae168
commit
4e44b8a1aa
@@ -1,167 +0,0 @@
'''
Vimeo90K dataset
support reading images from lmdb, image folder and memcached
'''
import os.path as osp
import random
import pickle
import logging
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
try:
    import mc  # import memcached
except ImportError:
    pass
logger = logging.getLogger('base')


class Vimeo90KDataset(data.Dataset):
    '''
    Reading the training Vimeo90K dataset
    key example: 00001_0001 (_1, ..., _7)
    GT (Ground-Truth): 4th frame;
    LQ (Low-Quality): support reading N LQ frames, N = 1, 3, 5, 7 centered with 4th frame
    '''

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        # temporal augmentation
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
            ','.join(str(x) for x in opt['interval_list']), self.random_reverse))

        self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
        self.data_type = self.opt['data_type']
        self.LR_input = False if opt['target_size'] == opt['LQ_size'] else True  # low resolution inputs

        #### determine the LQ frame list
        '''
        N | frames
        1 | 4
        3 | 3,4,5
        5 | 2,3,4,5,6
        7 | 1,2,3,4,5,6,7
        '''
        self.LQ_frames_list = []
        for i in range(opt['N_frames']):
            self.LQ_frames_list.append(i + (9 - opt['N_frames']) // 2)

        #### directly load image keys
        if self.data_type == 'lmdb':
            self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
            logger.info('Using lmdb meta info for cache keys.')
        elif opt['cache_keys']:
            logger.info('Using cache keys: {}'.format(opt['cache_keys']))
            self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
        else:
            raise ValueError(
                'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
        assert self.paths_GT, 'Error: GT path is empty.'

        if self.data_type == 'lmdb':
            self.GT_env, self.LQ_env = None, None
        elif self.data_type == 'mc':  # memcached
            self.mclient = None
        elif self.data_type == 'img':
            pass
        else:
            raise ValueError('Wrong data type: {}'.format(self.data_type))

    def _init_lmdb(self):
        # https://github.com/chainer/chainermn/issues/129
        self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
                                meminit=False)
        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
                                meminit=False)

    def _ensure_memcached(self):
        if self.mclient is None:
            # specify the config files
            server_list_config_file = None
            client_config_file = None
            self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
                                                          client_config_file)

    def _read_img_mc(self, path):
        ''' Return BGR, HWC, [0, 255], uint8'''
        value = mc.pyvector()
        self.mclient.Get(path, value)
        value_buf = mc.ConvertBuffer(value)
        img_array = np.frombuffer(value_buf, np.uint8)
        img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
        return img

    def __getitem__(self, index):
        if self.data_type == 'mc':
            self._ensure_memcached()
        elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
            self._init_lmdb()

        scale = self.opt['scale']
        GT_size = self.opt['target_size']
        key = self.paths_GT[index]
        name_a, name_b = key.split('_')
        #### get the GT image (as the center frame)
        if self.data_type == 'mc':
            img_GT = self._read_img_mc(osp.join(self.GT_root, name_a, name_b, '4.png'))
            img_GT = img_GT.astype(np.float32) / 255.
        elif self.data_type == 'lmdb':
            img_GT = util.read_img(self.GT_env, key + '_4', (3, 256, 448))
        else:
            img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b, 'im4.png'))

        #### get LQ images
        LQ_size_tuple = (3, 64, 112) if self.LR_input else (3, 256, 448)
        img_LQ_l = []
        for v in self.LQ_frames_list:
            if self.data_type == 'mc':
                img_LQ = self._read_img_mc(
                    osp.join(self.LQ_root, name_a, name_b, '{}.png'.format(v)))
                img_LQ = img_LQ.astype(np.float32) / 255.
            elif self.data_type == 'lmdb':
                img_LQ = util.read_img(self.LQ_env, key + '_{}'.format(v), LQ_size_tuple)
            else:
                img_LQ = util.read_img(None,
                                       osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v)))
            img_LQ_l.append(img_LQ)

        if self.opt['phase'] == 'train':
            C, H, W = LQ_size_tuple  # LQ size
            # randomly crop
            if self.LR_input:
                LQ_size = GT_size // scale
                rnd_h = random.randint(0, max(0, H - LQ_size))
                rnd_w = random.randint(0, max(0, W - LQ_size))
                img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
                rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
                img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
            else:
                rnd_h = random.randint(0, max(0, H - GT_size))
                rnd_w = random.randint(0, max(0, W - GT_size))
                img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
                img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]

            # augmentation - flip, rotate
            img_LQ_l.append(img_GT)
            rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
            img_LQ_l = rlt[0:-1]
            img_GT = rlt[-1]

        # stack LQ images to NHWC, N is the frame number
        img_LQs = np.stack(img_LQ_l, axis=0)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_GT = img_GT[:, :, [2, 1, 0]]
        img_LQs = img_LQs[:, :, :, [2, 1, 0]]
        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
                                                                     (0, 3, 1, 2)))).float()
        return {'LQs': img_LQs, 'GT': img_GT, 'key': key}

    def __len__(self):
        return len(self.paths_GT)
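Not part of the deleted file: a minimal, self-contained sketch of the frame-index arithmetic that the docstring and __init__ above describe (the helper name lq_frame_indices is hypothetical).

# Standalone check of the LQ frame-index mapping used in __init__ above.
def lq_frame_indices(n_frames):
    # Vimeo90K septuplets are numbered im1.png ... im7.png with im4 as GT,
    # so an N-frame window centred on frame 4 starts at (9 - N) // 2.
    return [i + (9 - n_frames) // 2 for i in range(n_frames)]

assert lq_frame_indices(1) == [4]
assert lq_frame_indices(3) == [3, 4, 5]
assert lq_frame_indices(5) == [2, 3, 4, 5, 6]
assert lq_frame_indices(7) == [1, 2, 3, 4, 5, 6, 7]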
@@ -1,84 +0,0 @@
import os.path as osp
import torch
import torch.utils.data as data
import data.util as util


class VideoTestDataset(data.Dataset):
    """
    A video test dataset. Support:
    Vid4
    REDS4
    Vimeo90K-Test

    no need to prepare LMDB files
    """

    def __init__(self, opt):
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.half_N_frames = opt['N_frames'] // 2
        self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
        self.data_type = self.opt['data_type']
        self.data_info = {'path_LQ': [], 'path_GT': [], 'folder': [], 'idx': [], 'border': []}
        if self.data_type == 'lmdb':
            raise ValueError('No need to use LMDB during validation/test.')
        #### Generate data info and cache data
        self.imgs_LQ, self.imgs_GT = {}, {}
        if opt['name'].lower() in ['vid4', 'reds4']:
            subfolders_LQ = util.glob_file_list(self.LQ_root)
            subfolders_GT = util.glob_file_list(self.GT_root)
            for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT):
                subfolder_name = osp.basename(subfolder_GT)
                img_paths_LQ = util.glob_file_list(subfolder_LQ)
                img_paths_GT = util.glob_file_list(subfolder_GT)
                max_idx = len(img_paths_LQ)
                assert max_idx == len(
                    img_paths_GT), 'Different number of images in LQ and GT folders'
                self.data_info['path_LQ'].extend(img_paths_LQ)
                self.data_info['path_GT'].extend(img_paths_GT)
                self.data_info['folder'].extend([subfolder_name] * max_idx)
                for i in range(max_idx):
                    self.data_info['idx'].append('{}/{}'.format(i, max_idx))
                border_l = [0] * max_idx
                for i in range(self.half_N_frames):
                    border_l[i] = 1
                    border_l[max_idx - i - 1] = 1
                self.data_info['border'].extend(border_l)

                if self.cache_data:
                    self.imgs_LQ[subfolder_name] = util.read_img_seq(img_paths_LQ)
                    self.imgs_GT[subfolder_name] = util.read_img_seq(img_paths_GT)
        elif opt['name'].lower() in ['vimeo90k-test']:
            pass  # TODO
        else:
            raise ValueError(
                'Not support video test dataset. Support Vid4, REDS4 and Vimeo90k-Test.')

    def __getitem__(self, index):
        # path_LQ = self.data_info['path_LQ'][index]
        # path_GT = self.data_info['path_GT'][index]
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]

        if self.cache_data:
            select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'],
                                               padding=self.opt['padding'])
            imgs_LQ = self.imgs_LQ[folder].index_select(0, torch.LongTensor(select_idx))
            img_GT = self.imgs_GT[folder][idx]
        else:
            pass  # TODO

        return {
            'LQs': imgs_LQ,
            'GT': img_GT,
            'folder': folder,
            'idx': self.data_info['idx'][index],
            'border': border
        }

    def __len__(self):
        return len(self.data_info['path_GT'])
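Not part of the deleted file: a hedged sketch of what a neighbour-index helper such as util.index_generation typically does at sequence borders ('replicate' padding assumed). The deleted data/util.py is the authority; neighbour_indices is a hypothetical name used only for illustration.

def neighbour_indices(center, total, n_frames, padding='replicate'):
    # Pick n_frames indices centred on `center`, clamping at the clip borders.
    half = n_frames // 2
    indices = []
    for i in range(center - half, center + half + 1):
        if padding == 'replicate':
            i = min(max(i, 0), total - 1)  # repeat the first/last frame at the borders
        indices.append(i)
    return indices

# e.g. a 5-frame window at the start of a 10-frame clip:
# neighbour_indices(0, 10, 5) -> [0, 0, 0, 1, 2]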
@@ -1,166 +0,0 @@
import logging
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
from models.loss import CharbonnierLoss

logger = logging.getLogger('base')


class VideoBaseModel(BaseModel):
    def __init__(self, opt):
        super(VideoBaseModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            if train_opt['ft_tsa_only']:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'tsa_fusion' in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': tsa_fusion_params,
                        'lr': train_opt['lr_G']
                    },
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))

            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQs'].to(self.device)
        if need_GT:
            self.real_H = data['GT'].to(self.device)

    def set_params_lr_zero(self):
        # fix normal module
        self.optimizers[0].param_groups[0]['lr'] = 0

    def optimize_parameters(self, step):
        if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
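Not part of the deleted file: a hedged sketch of the Charbonnier penalty selected by pixel_criterion == 'cb'. The deleted models/loss.py is the authority; the eps value here is an assumption, and charbonnier_loss is a hypothetical name.

import torch


def charbonnier_loss(pred, target, eps=1e-6):
    # Differentiable, L1-like penalty sqrt((x - y)^2 + eps), summed to match the
    # reduction='sum' used by the other loss branches above.
    return torch.sum(torch.sqrt((pred - target) ** 2 + eps))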
@@ -1,368 +0,0 @@
'''Network architecture for DUF:
Deep Video Super-Resolution Network Using Dynamic Upsampling Filters
Without Explicit Motion Compensation (CVPR18)
https://github.com/yhjo09/VSR-DUF

For all the models below, [adapt_official] is only necessary when
loading the weights converted from the official TensorFlow weights.
Please set it to [False] if you are training the model from scratch.
'''

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def adapt_official(Rx, scale=4):
    '''Adapt the weights translated from the official tensorflow weights
    Not necessary if you are training from scratch'''
    x = Rx.clone()
    x1 = x[:, ::3, :, :]
    x2 = x[:, 1::3, :, :]
    x3 = x[:, 2::3, :, :]

    Rx[:, :scale**2, :, :] = x1
    Rx[:, scale**2:2 * (scale**2), :, :] = x2
    Rx[:, 2 * (scale**2):, :, :] = x3

    return Rx


class DenseBlock(nn.Module):
    '''Dense block
    for the second denseblock, t_reduce = True'''

    def __init__(self, nf=64, ng=32, t_reduce=False):
        super(DenseBlock, self).__init__()
        self.t_reduce = t_reduce
        if self.t_reduce:
            pad = (0, 1, 1)
        else:
            pad = (1, 1, 1)
        self.bn3d_1 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3)
        self.conv3d_1 = nn.Conv3d(nf, nf, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
        self.bn3d_2 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3)
        self.conv3d_2 = nn.Conv3d(nf, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)
        self.bn3d_3 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3)
        self.conv3d_3 = nn.Conv3d(nf + ng, nf + ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
                                  bias=True)
        self.bn3d_4 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3)
        self.conv3d_4 = nn.Conv3d(nf + ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)
        self.bn3d_5 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3)
        self.conv3d_5 = nn.Conv3d(nf + 2 * ng, nf + 2 * ng, (1, 1, 1), stride=(1, 1, 1),
                                  padding=(0, 0, 0), bias=True)
        self.bn3d_6 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3)
        self.conv3d_6 = nn.Conv3d(nf + 2 * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad,
                                  bias=True)

    def forward(self, x):
        '''x: [B, C, T, H, W]
        C: nf -> nf + 3 * ng
        T: 1) 7 -> 7 (t_reduce=False);
           2) 7 -> 7 - 2 * 3 = 1 (t_reduce=True)'''
        x1 = self.conv3d_1(F.relu(self.bn3d_1(x), inplace=True))
        x1 = self.conv3d_2(F.relu(self.bn3d_2(x1), inplace=True))
        if self.t_reduce:
            x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
        else:
            x1 = torch.cat((x, x1), 1)

        x2 = self.conv3d_3(F.relu(self.bn3d_3(x1), inplace=True))
        x2 = self.conv3d_4(F.relu(self.bn3d_4(x2), inplace=True))
        if self.t_reduce:
            x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
        else:
            x2 = torch.cat((x1, x2), 1)

        x3 = self.conv3d_5(F.relu(self.bn3d_5(x2), inplace=True))
        x3 = self.conv3d_6(F.relu(self.bn3d_6(x3), inplace=True))
        if self.t_reduce:
            x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
        else:
            x3 = torch.cat((x2, x3), 1)
        return x3


class DynamicUpsamplingFilter_3C(nn.Module):
    '''dynamic upsampling filter with 3 channels applying the same filters
    filter_size: filter size of the generated filters, shape (C, kH, kW)'''

    def __init__(self, filter_size=(1, 5, 5)):
        super(DynamicUpsamplingFilter_3C, self).__init__()
        # generate a local expansion filter, used similar to im2col
        nF = np.prod(filter_size)
        expand_filter_np = np.reshape(np.eye(nF, nF),
                                      (nF, filter_size[0], filter_size[1], filter_size[2]))
        expand_filter = torch.from_numpy(expand_filter_np).float()
        self.expand_filter = torch.cat((expand_filter, expand_filter, expand_filter),
                                       0)  # [75, 1, 5, 5]

    def forward(self, x, filters):
        '''x: input image, [B, 3, H, W]
        filters: generate dynamic filters, [B, F, R, H, W], e.g., [B, 25, 16, H, W]
            F: prod of filter kernel size, e.g., 5*5 = 25
            R: used for upsampling, similar to pixel shuffle, e.g., 4*4 = 16 for x4
        Return: filtered image, [B, 3*R, H, W]
        '''
        B, nF, R, H, W = filters.size()
        # using group convolution
        input_expand = F.conv2d(x, self.expand_filter.type_as(x), padding=2,
                                groups=3)  # [B, 75, H, W] similar to im2col
        input_expand = input_expand.view(B, 3, nF, H, W).permute(0, 3, 4, 1, 2)  # [B, H, W, 3, 25]
        filters = filters.permute(0, 3, 4, 1, 2)  # [B, H, W, 25, 16]
        out = torch.matmul(input_expand, filters)  # [B, H, W, 3, 16]
        return out.permute(0, 3, 4, 1, 2).view(B, 3 * R, H, W)  # [B, 3*16, H, W]


class DUF_16L(nn.Module):
    '''Official DUF structure with 16 layers'''

    def __init__(self, scale=4, adapt_official=False):
        super(DUF_16L, self).__init__()
        self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
        self.dense_block_1 = DenseBlock(64, 64 // 2, t_reduce=False)  # 64 + 32 * 3 = 160, T = 7
        self.dense_block_2 = DenseBlock(160, 64 // 2, t_reduce=True)  # 160 + 32 * 3 = 256, T = 1
        self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3)
        self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
                                  bias=True)

        self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
                                   bias=True)
        self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
                                   padding=(0, 0, 0), bias=True)

        self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
                                   bias=True)
        self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
                                   padding=(0, 0, 0), bias=True)

        self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))

        self.scale = scale
        self.adapt_official = adapt_official

    def forward(self, x):
        '''
        x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
        Generate filters and image residual:
        Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
        Rx: [B, 3*16, 1, H, W]
        '''
        B, T, C, H, W = x.size()
        x = x.permute(0, 2, 1, 3, 4)  # [B, C, T, H, W] for Conv3D
        x_center = x[:, :, T // 2, :, :]

        x = self.conv3d_1(x)
        x = self.dense_block_1(x)
        x = self.dense_block_2(x)  # reduce T to 1
        x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)

        # image residual
        Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))  # [B, 3*16, 1, H, W]

        # filter
        Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))  # [B, 25*16, 1, H, W]
        Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)

        # Adapt to official model weights
        if self.adapt_official:
            adapt_official(Rx, scale=self.scale)

        # dynamic filter
        out = self.dynamic_filter(x_center, Fx)  # [B, 3*R, H, W]
        out += Rx.squeeze_(2)
        out = F.pixel_shuffle(out, self.scale)  # [B, 3, H, W]

        return out


class DenseBlock_28L(nn.Module):
    '''The first part of the dense blocks used in DUF_28L
    Temporal dimension remains the same here'''

    def __init__(self, nf=64, ng=16):
        super(DenseBlock_28L, self).__init__()
        pad = (1, 1, 1)

        dense_block_l = []
        for i in range(0, 9):
            dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
            dense_block_l.append(nn.ReLU())
            dense_block_l.append(
                nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
                          bias=True))

            dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
            dense_block_l.append(nn.ReLU())
            dense_block_l.append(
                nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True))

        self.dense_blocks = nn.ModuleList(dense_block_l)

    def forward(self, x):
        '''x: [B, C, T, H, W]
        C: 64 -> 208;
        T: 7 -> 7 (no temporal reduction)'''
        for i in range(0, len(self.dense_blocks), 6):
            y = x
            for j in range(6):
                y = self.dense_blocks[i + j](y)
            x = torch.cat((x, y), 1)
        return x


class DUF_28L(nn.Module):
    '''Official DUF structure with 28 layers'''

    def __init__(self, scale=4, adapt_official=False):
        super(DUF_28L, self).__init__()
        self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
        self.dense_block_1 = DenseBlock_28L(64, 16)  # 64 + 16 * 9 = 208, T = 7
        self.dense_block_2 = DenseBlock(208, 16, t_reduce=True)  # 208 + 16 * 3 = 256, T = 1
        self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3)
        self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
                                  bias=True)

        self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
                                   bias=True)
        self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
                                   padding=(0, 0, 0), bias=True)

        self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
                                   bias=True)
        self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
                                   padding=(0, 0, 0), bias=True)

        self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))

        self.scale = scale
        self.adapt_official = adapt_official

    def forward(self, x):
        '''
        x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
        Generate filters and image residual:
        Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
        Rx: [B, 3*16, 1, H, W]
        '''
        B, T, C, H, W = x.size()
        x = x.permute(0, 2, 1, 3, 4)  # [B,C,T,H,W] for Conv3D
        x_center = x[:, :, T // 2, :, :]
        x = self.conv3d_1(x)
        x = self.dense_block_1(x)
        x = self.dense_block_2(x)  # reduce T to 1
        x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)

        # image residual
        Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))  # [B, 3*16, 1, H, W]

        # filter
        Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))  # [B, 25*16, 1, H, W]
        Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)

        # Adapt to official model weights
        if self.adapt_official:
            adapt_official(Rx, scale=self.scale)

        # dynamic filter
        out = self.dynamic_filter(x_center, Fx)  # [B, 3*R, H, W]
        out += Rx.squeeze_(2)
        out = F.pixel_shuffle(out, self.scale)  # [B, 3, H, W]
        return out


class DenseBlock_52L(nn.Module):
    '''The first part of the dense blocks used in DUF_52L
    Temporal dimension remains the same here'''

    def __init__(self, nf=64, ng=16):
        super(DenseBlock_52L, self).__init__()
        pad = (1, 1, 1)

        dense_block_l = []
        for i in range(0, 21):
            dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
            dense_block_l.append(nn.ReLU())
            dense_block_l.append(
                nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
                          bias=True))

            dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3))
            dense_block_l.append(nn.ReLU())
            dense_block_l.append(
                nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True))

        self.dense_blocks = nn.ModuleList(dense_block_l)

    def forward(self, x):
        '''x: [B, C, T, H, W]
        C: 64 -> 400;
        T: 7 -> 7 (no temporal reduction)'''
        for i in range(0, len(self.dense_blocks), 6):
            y = x
            for j in range(6):
                y = self.dense_blocks[i + j](y)
            x = torch.cat((x, y), 1)
        return x


class DUF_52L(nn.Module):
    '''Official DUF structure with 52 layers'''

    def __init__(self, scale=4, adapt_official=False):
        super(DUF_52L, self).__init__()
        self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
        self.dense_block_1 = DenseBlock_52L(64, 16)  # 64 + 21 * 16 = 400, T = 7
        self.dense_block_2 = DenseBlock(400, 16, t_reduce=True)  # 400 + 16 * 3 = 448, T = 1

        self.bn3d_2 = nn.BatchNorm3d(448, eps=1e-3, momentum=1e-3)
        self.conv3d_2 = nn.Conv3d(448, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
                                  bias=True)

        self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
                                   bias=True)
        self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
                                   padding=(0, 0, 0), bias=True)

        self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0),
                                   bias=True)
        self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1),
                                   padding=(0, 0, 0), bias=True)

        self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5))

        self.scale = scale
        self.adapt_official = adapt_official

    def forward(self, x):
        '''
        x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D
        Generate filters and image residual:
        Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C
        Rx: [B, 3*16, 1, H, W]
        '''
        B, T, C, H, W = x.size()
        x = x.permute(0, 2, 1, 3, 4)  # [B,C,T,H,W] for Conv3D
        x_center = x[:, :, T // 2, :, :]
        x = self.conv3d_1(x)
        x = self.dense_block_1(x)
        x = self.dense_block_2(x)
        x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True)

        # image residual
        Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))  # [B, 3*16, 1, H, W]

        # filter
        Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))  # [B, 25*16, 1, H, W]
        Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1)

        # Adapt to official model weights
        if self.adapt_official:
            adapt_official(Rx, scale=self.scale)

        # dynamic filter
        out = self.dynamic_filter(x_center, Fx)  # [B, 3*R, H, W]
        out += Rx.squeeze_(2)
        out = F.pixel_shuffle(out, self.scale)  # [B, 3, H, W]
        return out
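Not part of the deleted file: a quick shape check of the dynamic-filter + residual + pixel-shuffle path described in the docstrings above, assuming DynamicUpsamplingFilter_3C from this file is still importable. Shapes follow the comments in DUF_16L.forward; the random tensors are placeholders.

# Shape walkthrough of the DUF output stage (illustration only).
import torch
import torch.nn.functional as F

B, H, W, scale = 2, 32, 48, 4
duf_filter = DynamicUpsamplingFilter_3C((1, 5, 5))
x_center = torch.randn(B, 3, H, W)                              # centre LR frame
Fx = torch.softmax(torch.randn(B, 25, scale**2, H, W), dim=1)   # per-pixel 5x5 filters
Rx = torch.randn(B, 3 * scale**2, H, W)                         # residual after squeeze

out = duf_filter(x_center, Fx) + Rx                             # [B, 3*16, H, W]
out = F.pixel_shuffle(out, scale)                               # [B, 3, 4H, 4W]
assert out.shape == (B, 3, H * scale, W * scale)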
@@ -1,312 +0,0 @@
''' network architecture for EDVR '''
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util
try:
    from models.archs.dcn.deform_conv import ModulatedDeformConvPack as DCN
except ImportError:
    raise ImportError('Failed to import DCNv2 module.')


class Predeblur_ResNet_Pyramid(nn.Module):
    def __init__(self, nf=128, HR_in=False):
        '''
        HR_in: True if the inputs are high spatial size
        '''

        super(Predeblur_ResNet_Pyramid, self).__init__()
        self.HR_in = True if HR_in else False
        if self.HR_in:
            self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
            self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        else:
            self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
        basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
        self.RB_L1_1 = basic_block()
        self.RB_L1_2 = basic_block()
        self.RB_L1_3 = basic_block()
        self.RB_L1_4 = basic_block()
        self.RB_L1_5 = basic_block()
        self.RB_L2_1 = basic_block()
        self.RB_L2_2 = basic_block()
        self.RB_L3_1 = basic_block()
        self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        if self.HR_in:
            L1_fea = self.lrelu(self.conv_first_1(x))
            L1_fea = self.lrelu(self.conv_first_2(L1_fea))
            L1_fea = self.lrelu(self.conv_first_3(L1_fea))
        else:
            L1_fea = self.lrelu(self.conv_first(x))
        L2_fea = self.lrelu(self.deblur_L2_conv(L1_fea))
        L3_fea = self.lrelu(self.deblur_L3_conv(L2_fea))
        L3_fea = F.interpolate(self.RB_L3_1(L3_fea), scale_factor=2, mode='bilinear',
                               align_corners=False)
        L2_fea = self.RB_L2_1(L2_fea) + L3_fea
        L2_fea = F.interpolate(self.RB_L2_2(L2_fea), scale_factor=2, mode='bilinear',
                               align_corners=False)
        L1_fea = self.RB_L1_2(self.RB_L1_1(L1_fea)) + L2_fea
        out = self.RB_L1_5(self.RB_L1_4(self.RB_L1_3(L1_fea)))
        return out


class PCD_Align(nn.Module):
    ''' Alignment module using Pyramid, Cascading and Deformable convolution
    with 3 pyramid levels.
    '''

    def __init__(self, nf=64, groups=8):
        super(PCD_Align, self).__init__()
        # L3: level 3, 1/4 spatial size
        self.L3_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # L1: level 1, original spatial size
        self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # Cascading DCN
        self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                               extra_offset_mask=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_fea_l, ref_fea_l):
        '''align other neighboring frames to the reference frame in the feature level
        nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features
        '''
        # L3
        L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1)
        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
        L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset]))
        # L2
        L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1)
        L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
        L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)
        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1)))
        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
        L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset])
        L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)
        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
        # L1
        L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1)
        L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
        L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)
        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
        L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset])
        L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)
        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
        # Cascading
        offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv1(offset))
        offset = self.lrelu(self.cas_offset_conv2(offset))
        L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset]))

        return L1_fea


class TSA_Fusion(nn.Module):
    ''' Temporal Spatial Attention fusion module
    Temporal: correlation;
    Spatial: 3 pyramid levels.
    '''

    def __init__(self, nf=64, nframes=5, center=2):
        super(TSA_Fusion, self).__init__()
        self.center = center
        # temporal attention (before fusion conv)
        self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # fusion conv: using 1x1 to save parameters and computation
        self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)

        # spatial attention (after fusion conv)
        self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.avgpool = nn.AvgPool2d(3, stride=2, padding=1)
        self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True)
        self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
        self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, aligned_fea):
        B, N, C, H, W = aligned_fea.size()  # N video frames
        #### temporal attention
        emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone())
        emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W)  # [B, N, C(nf), H, W]

        cor_l = []
        for i in range(N):
            emb_nbr = emb[:, i, :, :, :]
            cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1)  # B, 1, H, W
            cor_l.append(cor_tmp)
        cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1))  # B, N, H, W
        cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W)
        aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob

        #### fusion
        fea = self.lrelu(self.fea_fusion(aligned_fea))

        #### spatial attention
        att = self.lrelu(self.sAtt_1(aligned_fea))
        att_max = self.maxpool(att)
        att_avg = self.avgpool(att)
        att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1)))
        # pyramid levels
        att_L = self.lrelu(self.sAtt_L1(att))
        att_max = self.maxpool(att_L)
        att_avg = self.avgpool(att_L)
        att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1)))
        att_L = self.lrelu(self.sAtt_L3(att_L))
        att_L = F.interpolate(att_L, scale_factor=2, mode='bilinear', align_corners=False)

        att = self.lrelu(self.sAtt_3(att))
        att = att + att_L
        att = self.lrelu(self.sAtt_4(att))
        att = F.interpolate(att, scale_factor=2, mode='bilinear', align_corners=False)
        att = self.sAtt_5(att)
        att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att)))
        att = torch.sigmoid(att)

        fea = fea * att * 2 + att_add
        return fea


class EDVR(nn.Module):
    def __init__(self, nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10, center=None,
                 predeblur=False, HR_in=False, w_TSA=True):
        super(EDVR, self).__init__()
        self.nf = nf
        self.center = nframes // 2 if center is None else center
        self.is_predeblur = True if predeblur else False
        self.HR_in = True if HR_in else False
        self.w_TSA = w_TSA
        ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)

        #### extract features (for each frame)
        if self.is_predeblur:
            self.pre_deblur = Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
            self.conv_1x1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        else:
            if self.HR_in:
                self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
                self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
                self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            else:
                self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
        self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
        self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.pcd_align = PCD_Align(nf=nf, groups=groups)
        if self.w_TSA:
            self.tsa_fusion = TSA_Fusion(nf=nf, nframes=nframes, center=self.center)
        else:
            self.tsa_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)

        #### reconstruction
        self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, back_RBs)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)

        #### activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        B, N, C, H, W = x.size()  # N video frames
        x_center = x[:, self.center, :, :, :].contiguous()

        #### extract LR features
        # L1
        if self.is_predeblur:
            L1_fea = self.pre_deblur(x.view(-1, C, H, W))
            L1_fea = self.conv_1x1(L1_fea)
            if self.HR_in:
                H, W = H // 4, W // 4
        else:
            if self.HR_in:
                L1_fea = self.lrelu(self.conv_first_1(x.view(-1, C, H, W)))
                L1_fea = self.lrelu(self.conv_first_2(L1_fea))
                L1_fea = self.lrelu(self.conv_first_3(L1_fea))
                H, W = H // 4, W // 4
            else:
                L1_fea = self.lrelu(self.conv_first(x.view(-1, C, H, W)))
        L1_fea = self.feature_extraction(L1_fea)
        # L2
        L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))
        L2_fea = self.lrelu(self.fea_L2_conv2(L2_fea))
        # L3
        L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))
        L3_fea = self.lrelu(self.fea_L3_conv2(L3_fea))

        L1_fea = L1_fea.view(B, N, -1, H, W)
        L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2)
        L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4)

        #### pcd align
        # ref feature list
        ref_fea_l = [
            L1_fea[:, self.center, :, :, :].clone(), L2_fea[:, self.center, :, :, :].clone(),
            L3_fea[:, self.center, :, :, :].clone()
        ]
        aligned_fea = []
        for i in range(N):
            nbr_fea_l = [
                L1_fea[:, i, :, :, :].clone(), L2_fea[:, i, :, :, :].clone(),
                L3_fea[:, i, :, :, :].clone()
            ]
            aligned_fea.append(self.pcd_align(nbr_fea_l, ref_fea_l))
        aligned_fea = torch.stack(aligned_fea, dim=1)  # [B, N, C, H, W]

        if not self.w_TSA:
            aligned_fea = aligned_fea.view(B, -1, H, W)
        fea = self.tsa_fusion(aligned_fea)

        out = self.recon_trunk(fea)
        out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
        out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        out = self.lrelu(self.HRconv(out))
        out = self.conv_last(out)
        if self.HR_in:
            base = x_center
        else:
            base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
        out += base
        return out
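Not part of the deleted file: a sketch of the input/output contract of EDVR above. Running it assumes a CUDA device and the compiled DCNv2 extension imported at the top of the file; the input size is an arbitrary example (H and W should be multiples of 4).

# Illustration of EDVR's expected shapes (x4 super-resolution of the centre frame).
import torch

model = EDVR(nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10).cuda()
lqs = torch.randn(1, 5, 3, 64, 64).cuda()   # [B, N, C, H, W], N = 5 input frames
with torch.no_grad():
    sr = model(lqs)                         # PCD align -> TSA fuse -> reconstruct
print(sr.shape)                             # torch.Size([1, 3, 256, 256])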
@@ -1,137 +0,0 @@
'''PyTorch implementation of TOFlow
Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
Code reference:
1. https://github.com/anchen1011/toflow
2. https://github.com/Coldog2333/pytoflow
'''

import torch
import torch.nn as nn
from .arch_util import flow_warp


def normalize(x):
    mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
    std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
    return (x - mean) / std


def denormalize(x):
    mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
    std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
    return x * std + mean


class SpyNet_Block(nn.Module):
    '''A submodule of SpyNet.'''

    def __init__(self):
        super(SpyNet_Block, self).__init__()

        self.block = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(16), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))

    def forward(self, x):
        '''
        input: x: [ref im, nbr im, initial flow] - (B, 8, H, W)
        output: estimated flow - (B, 2, H, W)
        '''
        return self.block(x)


class SpyNet(nn.Module):
    '''SpyNet for estimating optical flow
    Ranjan et al., Optical Flow Estimation using a Spatial Pyramid Network, 2016'''

    def __init__(self):
        super(SpyNet, self).__init__()

        self.blocks = nn.ModuleList([SpyNet_Block() for _ in range(4)])

    def forward(self, ref, nbr):
        '''Estimating optical flow in coarse level, upsample, and estimate in fine level
        input: ref: reference image - [B, 3, H, W]
               nbr: the neighboring image to be warped - [B, 3, H, W]
        output: estimated optical flow - [B, 2, H, W]
        '''
        B, C, H, W = ref.size()
        ref = [ref]
        nbr = [nbr]

        for _ in range(3):
            ref.insert(
                0,
                nn.functional.avg_pool2d(input=ref[0], kernel_size=2, stride=2,
                                         count_include_pad=False))
            nbr.insert(
                0,
                nn.functional.avg_pool2d(input=nbr[0], kernel_size=2, stride=2,
                                         count_include_pad=False))

        flow = torch.zeros(B, 2, H // 16, W // 16).type_as(ref[0])

        for i in range(4):
            flow_up = nn.functional.interpolate(input=flow, scale_factor=2, mode='bilinear',
                                                align_corners=True) * 2.0
            flow = flow_up + self.blocks[i](torch.cat(
                [ref[i], flow_warp(nbr[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
        return flow


class TOFlow(nn.Module):
    def __init__(self, adapt_official=False):
        super(TOFlow, self).__init__()

        self.SpyNet = SpyNet()

        self.conv_3x7_64_9x9 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
        self.conv_64_64_9x9 = nn.Conv2d(64, 64, 9, 1, 4)
        self.conv_64_64_1x1 = nn.Conv2d(64, 64, 1)
        self.conv_64_3_1x1 = nn.Conv2d(64, 3, 1)

        self.relu = nn.ReLU(inplace=True)

        self.adapt_official = adapt_official  # True if using translated official weights else False

    def forward(self, x):
        """
        input: x: input frames - [B, 7, 3, H, W]
        output: SR reference frame - [B, 3, H, W]
        """

        B, T, C, H, W = x.size()
        x = normalize(x.view(-1, C, H, W)).view(B, T, C, H, W)

        ref_idx = 3
        x_ref = x[:, ref_idx, :, :, :]

        # In the official torch code, the 0-th frame is the reference frame
        if self.adapt_official:
            x = x[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
            ref_idx = 0

        x_warped = []
        for i in range(7):
            if i == ref_idx:
                x_warped.append(x_ref)
            else:
                x_nbr = x[:, i, :, :, :]
                flow = self.SpyNet(x_ref, x_nbr).permute(0, 2, 3, 1)
                x_warped.append(flow_warp(x_nbr, flow))
        x_warped = torch.stack(x_warped, dim=1)

        x = x_warped.view(B, -1, H, W)
        x = self.relu(self.conv_3x7_64_9x9(x))
        x = self.relu(self.conv_64_64_9x9(x))
        x = self.relu(self.conv_64_64_1x1(x))
        x = self.conv_64_3_1x1(x) + x_ref

        return denormalize(x)
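Not part of the deleted file: a hedged sketch of the backward warping that arch_util.flow_warp is commonly implemented with (grid_sample on a flow-shifted sampling grid). The deleted arch_util.py remains the authority; flow_warp_sketch is a hypothetical name.

import torch
import torch.nn.functional as F


def flow_warp_sketch(x, flow):
    # x: [B, C, H, W]; flow: [B, H, W, 2] in pixels, as SpyNet/TOFlow pass it above.
    B, C, H, W = x.size()
    grid_x = torch.arange(W).view(1, W).expand(H, W)
    grid_y = torch.arange(H).view(H, 1).expand(H, W)
    grid = torch.stack((grid_x, grid_y), dim=2).float().to(x.device)  # [H, W, 2]
    vgrid = grid.unsqueeze(0) + flow                                  # shift by the flow
    # normalise absolute coordinates to [-1, 1] for grid_sample
    vgrid_x = 2.0 * vgrid[..., 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[..., 1] / max(H - 1, 1) - 1.0
    return F.grid_sample(x, torch.stack((vgrid_x, vgrid_y), dim=3), align_corners=True)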
@@ -1,7 +0,0 @@
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack,
                          deform_conv, modulated_deform_conv)

__all__ = [
    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
    'modulated_deform_conv'
]
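Not part of the deleted files: a minimal usage sketch for the exports above. It assumes the deform_conv_cuda extension is built and a GPU is available; the package path follows the import used in the EDVR architecture file earlier in this commit.

import torch
from models.archs.dcn import DeformConvPack

# DeformConvPack predicts its own sampling offsets via an internal conv layer.
dcn = DeformConvPack(64, 64, kernel_size=3, stride=1, padding=1,
                     deformable_groups=8).cuda()
feat = torch.randn(2, 64, 32, 32).cuda()
out = dcn(feat)              # offsets computed internally, then deformable conv applied
print(out.shape)             # torch.Size([2, 64, 32, 32])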
@@ -1,291 +0,0 @@
import math
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
from . import deform_conv_cuda
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
||||
class DeformConvFunction(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1,
|
||||
deformable_groups=1, im2col_step=64):
|
||||
if input is not None and input.dim() != 4:
|
||||
raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(
|
||||
input.dim()))
|
||||
ctx.stride = _pair(stride)
|
||||
ctx.padding = _pair(padding)
|
||||
ctx.dilation = _pair(dilation)
|
||||
ctx.groups = groups
|
||||
ctx.deformable_groups = deformable_groups
|
||||
ctx.im2col_step = im2col_step
|
||||
|
||||
ctx.save_for_backward(input, offset, weight)
|
||||
|
||||
output = input.new_empty(
|
||||
DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
|
||||
|
||||
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
|
||||
|
||||
if not input.is_cuda:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
|
||||
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
|
||||
deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output,
|
||||
ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
|
||||
weight.size(2), ctx.stride[1], ctx.stride[0],
|
||||
ctx.padding[1], ctx.padding[0],
|
||||
ctx.dilation[1], ctx.dilation[0], ctx.groups,
|
||||
ctx.deformable_groups, cur_im2col_step)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
input, offset, weight = ctx.saved_tensors
|
||||
|
||||
grad_input = grad_offset = grad_weight = None
|
||||
|
||||
if not grad_output.is_cuda:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
|
||||
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
|
||||
|
||||
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
||||
grad_input = torch.zeros_like(input)
|
||||
grad_offset = torch.zeros_like(offset)
|
||||
deform_conv_cuda.deform_conv_backward_input_cuda(
|
||||
input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0],
|
||||
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
|
||||
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
|
||||
ctx.deformable_groups, cur_im2col_step)
|
||||
|
||||
if ctx.needs_input_grad[2]:
|
||||
grad_weight = torch.zeros_like(weight)
|
||||
deform_conv_cuda.deform_conv_backward_parameters_cuda(
|
||||
input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1],
|
||||
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
|
||||
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
|
||||
ctx.deformable_groups, 1, cur_im2col_step)
|
||||
|
||||
return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
|
||||
|
||||
@staticmethod
|
||||
def _output_size(input, weight, padding, dilation, stride):
|
||||
channels = weight.size(0)
|
||||
output_size = (input.size(0), channels)
|
||||
for d in range(input.dim() - 2):
|
||||
in_size = input.size(d + 2)
|
||||
pad = padding[d]
|
||||
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
|
||||
stride_ = stride[d]
|
||||
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
|
||||
if not all(map(lambda s: s > 0, output_size)):
|
||||
raise ValueError("convolution input is too small (output would be {})".format('x'.join(
|
||||
map(str, output_size))))
|
||||
return output_size
|
||||
|
||||
|
||||
class ModulatedDeformConvFunction(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1,
|
||||
groups=1, deformable_groups=1):
|
||||
ctx.stride = stride
|
||||
ctx.padding = padding
|
||||
ctx.dilation = dilation
|
||||
ctx.groups = groups
|
||||
ctx.deformable_groups = deformable_groups
|
||||
ctx.with_bias = bias is not None
|
||||
if not ctx.with_bias:
|
||||
bias = input.new_empty(1) # fake tensor
|
||||
if not input.is_cuda:
|
||||
raise NotImplementedError
|
||||
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
|
||||
or input.requires_grad:
|
||||
ctx.save_for_backward(input, offset, mask, weight, bias)
|
||||
output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
|
||||
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
|
||||
deform_conv_cuda.modulated_deform_conv_cuda_forward(
|
||||
input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2],
|
||||
weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
|
||||
ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
if not grad_output.is_cuda:
|
||||
raise NotImplementedError
|
||||
input, offset, mask, weight, bias = ctx.saved_tensors
|
||||
grad_input = torch.zeros_like(input)
|
||||
grad_offset = torch.zeros_like(offset)
|
||||
grad_mask = torch.zeros_like(mask)
|
||||
grad_weight = torch.zeros_like(weight)
|
||||
grad_bias = torch.zeros_like(bias)
|
||||
deform_conv_cuda.modulated_deform_conv_cuda_backward(
|
||||
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight,
|
||||
grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3],
|
||||
ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
|
||||
ctx.groups, ctx.deformable_groups, ctx.with_bias)
|
||||
if not ctx.with_bias:
|
||||
grad_bias = None
|
||||
|
||||
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None,
|
||||
None)
|
||||
|
||||
@staticmethod
|
||||
def _infer_shape(ctx, input, weight):
|
||||
n = input.size(0)
|
||||
channels_out = weight.size(0)
|
||||
height, width = input.shape[2:4]
|
||||
kernel_h, kernel_w = weight.shape[2:4]
|
||||
height_out = (height + 2 * ctx.padding - (ctx.dilation *
|
||||
(kernel_h - 1) + 1)) // ctx.stride + 1
|
||||
width_out = (width + 2 * ctx.padding - (ctx.dilation *
|
||||
(kernel_w - 1) + 1)) // ctx.stride + 1
|
||||
return n, channels_out, height_out, width_out
|
||||
|
||||
|
||||
deform_conv = DeformConvFunction.apply
|
||||
modulated_deform_conv = ModulatedDeformConvFunction.apply
|
||||
|
||||
|
||||
class DeformConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, deformable_groups=1, bias=False):
        super(DeformConv, self).__init__()

        assert not bias
        assert in_channels % groups == 0, \
            'in_channels {} is not divisible by groups {}'.format(in_channels, groups)
        assert out_channels % groups == 0, \
            'out_channels {} is not divisible by groups {}'.format(out_channels, groups)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))

        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x, offset):
        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
                           self.groups, self.deformable_groups)

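# Worked example (our addition): reset_parameters above uses the bound
# stdv = 1 / sqrt(in_channels * k_h * k_w). For a 64-channel, 3x3 layer:
#   n = 64 * 3 * 3 = 576,  stdv = 1 / sqrt(576) ~ 0.0417
# so every weight is drawn uniformly from roughly [-0.0417, 0.0417]. The helper name is ours.
def _example_init_bound(in_channels, kernel_size):
    n = in_channels
    for k in _pair(kernel_size):
        n *= k
    return 1. / math.sqrt(n)
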
class DeformConvPack(DeformConv):
    def __init__(self, *args, **kwargs):
        super(DeformConvPack, self).__init__(*args, **kwargs)

        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        offset = self.conv_offset(x)
        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
                           self.groups, self.deformable_groups)

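# Usage sketch (our addition): DeformConvPack predicts its own offsets with conv_offset,
# so the caller only passes a feature map. A CUDA device and the compiled deform_conv_cuda
# extension are assumed, since these kernels are CUDA-only. The function name is hypothetical.
def _example_deform_conv_pack():
    if not torch.cuda.is_available():
        return None
    pack = DeformConvPack(64, 64, 3, padding=1, deformable_groups=8).cuda()
    x = torch.randn(2, 64, 32, 32, device='cuda')
    y = pack(x)  # same spatial size: 3x3 kernel, padding 1, stride 1
    assert y.shape == (2, 64, 32, 32)
    return y
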
class ModulatedDeformConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, deformable_groups=1, bias=True):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x, offset, mask):
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
                                     self.padding, self.dilation, self.groups,
                                     self.deformable_groups)

class ModulatedDeformConvPack(ModulatedDeformConv):
    def __init__(self, *args, extra_offset_mask=False, **kwargs):
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)

        self.extra_offset_mask = extra_offset_mask
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, x):
        if self.extra_offset_mask:
            # x = [input, features]
            out = self.conv_offset_mask(x[1])
            x = x[0]
        else:
            out = self.conv_offset_mask(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 100:
            logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean))

        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
                                     self.padding, self.dilation, self.groups,
                                     self.deformable_groups)

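# Usage sketch (our addition): with extra_offset_mask=True the module expects a two-element
# list [features_to_convolve, features_to_predict_offsets_from], which is how a PCD-style
# alignment module would feed it; with the default False it behaves like a drop-in conv.
# A CUDA device and the compiled deform_conv_cuda extension are assumed; the name is ours.
def _example_modulated_pack():
    if not torch.cuda.is_available():
        return None
    dcn = ModulatedDeformConvPack(64, 64, 3, padding=1, deformable_groups=8,
                                  extra_offset_mask=True).cuda()
    feat = torch.randn(2, 64, 32, 32, device='cuda')
    guide = torch.randn(2, 64, 32, 32, device='cuda')  # offsets/masks are predicted from this
    out = dcn([feat, guide])
    assert out.shape == (2, 64, 32, 32)
    return out
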
@ -1,22 +0,0 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


def make_cuda_ext(name, sources):

    return CUDAExtension(
        name='{}'.format(name), sources=[p for p in sources], extra_compile_args={
            'cxx': [],
            'nvcc': [
                '-D__CUDA_NO_HALF_OPERATORS__',
                '-D__CUDA_NO_HALF_CONVERSIONS__',
                '-D__CUDA_NO_HALF2_OPERATORS__',
            ]
        })


setup(
    name='deform_conv', ext_modules=[
        make_cuda_ext(name='deform_conv_cuda',
                      sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu'])
    ], cmdclass={'build_ext': BuildExtension}, zip_safe=False)

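# Build sketch (our addition; the exact command is standard setuptools usage, not taken from
# this repo's docs): a setup.py like the one above is normally compiled in place, producing
# the deform_conv_cuda extension that the Python module imports.
def _example_build_extension():
    # roughly equivalent to running: python setup.py build_ext --inplace
    import subprocess
    import sys
    subprocess.run([sys.executable, 'setup.py', 'build_ext', '--inplace'], check=True)
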
@ -1,695 +0,0 @@
|
|||
// modified from
|
||||
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
|
||||
const int channels, const int height, const int width,
|
||||
const int ksize_h, const int ksize_w, const int pad_h,
|
||||
const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int parallel_imgs, const int deformable_group,
|
||||
at::Tensor data_col);
|
||||
|
||||
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
|
||||
const int channels, const int height, const int width,
|
||||
const int ksize_h, const int ksize_w, const int pad_h,
|
||||
const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int parallel_imgs, const int deformable_group,
|
||||
at::Tensor grad_im);
|
||||
|
||||
void deformable_col2im_coord(
|
||||
const at::Tensor data_col, const at::Tensor data_im,
|
||||
const at::Tensor data_offset, const int channels, const int height,
|
||||
const int width, const int ksize_h, const int ksize_w, const int pad_h,
|
||||
const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
||||
const int deformable_group, at::Tensor grad_offset);
|
||||
|
||||
void modulated_deformable_im2col_cuda(
|
||||
const at::Tensor data_im, const at::Tensor data_offset,
|
||||
const at::Tensor data_mask, const int batch_size, const int channels,
|
||||
const int height_im, const int width_im, const int height_col,
|
||||
const int width_col, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int deformable_group,
|
||||
at::Tensor data_col);
|
||||
|
||||
void modulated_deformable_col2im_cuda(
|
||||
const at::Tensor data_col, const at::Tensor data_offset,
|
||||
const at::Tensor data_mask, const int batch_size, const int channels,
|
||||
const int height_im, const int width_im, const int height_col,
|
||||
const int width_col, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int deformable_group,
|
||||
at::Tensor grad_im);
|
||||
|
||||
void modulated_deformable_col2im_coord_cuda(
|
||||
const at::Tensor data_col, const at::Tensor data_im,
|
||||
const at::Tensor data_offset, const at::Tensor data_mask,
|
||||
const int batch_size, const int channels, const int height_im,
|
||||
const int width_im, const int height_col, const int width_col,
|
||||
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w, const int dilation_h,
|
||||
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
|
||||
at::Tensor grad_mask);
|
||||
|
||||
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
|
||||
at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
|
||||
int padW, int dilationH, int dilationW, int group,
|
||||
int deformable_group) {
|
||||
AT_CHECK(weight.ndimension() == 4,
|
||||
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
|
||||
"but got: %s",
|
||||
weight.ndimension());
|
||||
|
||||
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
||||
|
||||
AT_CHECK(kW > 0 && kH > 0,
|
||||
"kernel size should be greater than zero, but got kH: %d kW: %d", kH,
|
||||
kW);
|
||||
|
||||
AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
|
||||
"kernel size should be consistent with weight, ",
|
||||
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
|
||||
kW, weight.size(2), weight.size(3));
|
||||
|
||||
AT_CHECK(dW > 0 && dH > 0,
|
||||
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
|
||||
|
||||
AT_CHECK(
|
||||
dilationW > 0 && dilationH > 0,
|
||||
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
|
||||
dilationH, dilationW);
|
||||
|
||||
int ndim = input.ndimension();
|
||||
int dimf = 0;
|
||||
int dimh = 1;
|
||||
int dimw = 2;
|
||||
|
||||
if (ndim == 4) {
|
||||
dimf++;
|
||||
dimh++;
|
||||
dimw++;
|
||||
}
|
||||
|
||||
AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
|
||||
ndim);
|
||||
|
||||
long nInputPlane = weight.size(1) * group;
|
||||
long inputHeight = input.size(dimh);
|
||||
long inputWidth = input.size(dimw);
|
||||
long nOutputPlane = weight.size(0);
|
||||
long outputHeight =
|
||||
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||
long outputWidth =
|
||||
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||
|
||||
AT_CHECK(nInputPlane % deformable_group == 0,
|
||||
"input channels must divide deformable group size");
|
||||
|
||||
if (outputWidth < 1 || outputHeight < 1)
|
||||
AT_ERROR(
|
||||
"Given input size: (%ld x %ld x %ld). "
|
||||
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
|
||||
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
|
||||
outputWidth);
|
||||
|
||||
AT_CHECK(input.size(1) == nInputPlane,
|
||||
"invalid number of input planes, expected: %d, but got: %d",
|
||||
nInputPlane, input.size(1));
|
||||
|
||||
AT_CHECK((inputHeight >= kH && inputWidth >= kW),
|
||||
"input image is smaller than kernel");
|
||||
|
||||
AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
|
||||
"invalid spatial size of offset, expected height: %d width: %d, but "
|
||||
"got height: %d width: %d",
|
||||
outputHeight, outputWidth, offset.size(2), offset.size(3));
|
||||
|
||||
AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
|
||||
"invalid number of channels of offset");
|
||||
|
||||
if (gradOutput != NULL) {
|
||||
AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
|
||||
"invalid number of gradOutput planes, expected: %d, but got: %d",
|
||||
nOutputPlane, gradOutput->size(dimf));
|
||||
|
||||
AT_CHECK((gradOutput->size(dimh) == outputHeight &&
|
||||
gradOutput->size(dimw) == outputWidth),
|
||||
"invalid size of gradOutput, expected height: %d width: %d , but "
|
||||
"got height: %d width: %d",
|
||||
outputHeight, outputWidth, gradOutput->size(dimh),
|
||||
gradOutput->size(dimw));
|
||||
}
|
||||
}
|
||||
|
||||
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
|
||||
at::Tensor offset, at::Tensor output,
|
||||
at::Tensor columns, at::Tensor ones, int kW,
|
||||
int kH, int dW, int dH, int padW, int padH,
|
||||
int dilationW, int dilationH, int group,
|
||||
int deformable_group, int im2col_step) {
|
||||
// todo: resize columns to include im2col: done
|
||||
// todo: add im2col_step as input
|
||||
// todo: add new output buffer and transpose it to output (or directly
|
||||
// transpose output) todo: possibly change data indexing because of
|
||||
// parallel_imgs
|
||||
|
||||
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
|
||||
dilationH, dilationW, group, deformable_group);
|
||||
|
||||
input = input.contiguous();
|
||||
offset = offset.contiguous();
|
||||
weight = weight.contiguous();
|
||||
|
||||
int batch = 1;
|
||||
if (input.ndimension() == 3) {
|
||||
// Force batch
|
||||
batch = 0;
|
||||
input.unsqueeze_(0);
|
||||
offset.unsqueeze_(0);
|
||||
}
|
||||
|
||||
// todo: assert batchsize dividable by im2col_step
|
||||
|
||||
long batchSize = input.size(0);
|
||||
long nInputPlane = input.size(1);
|
||||
long inputHeight = input.size(2);
|
||||
long inputWidth = input.size(3);
|
||||
|
||||
long nOutputPlane = weight.size(0);
|
||||
|
||||
long outputWidth =
|
||||
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||
long outputHeight =
|
||||
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||
|
||||
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
||||
|
||||
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
|
||||
outputHeight, outputWidth});
|
||||
columns = at::zeros(
|
||||
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
||||
input.options());
|
||||
|
||||
if (ones.ndimension() != 2 ||
|
||||
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
|
||||
ones = at::ones({outputHeight, outputWidth}, input.options());
|
||||
}
|
||||
|
||||
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||
inputHeight, inputWidth});
|
||||
offset =
|
||||
offset.view({batchSize / im2col_step, im2col_step,
|
||||
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
at::Tensor output_buffer =
|
||||
at::zeros({batchSize / im2col_step, nOutputPlane,
|
||||
im2col_step * outputHeight, outputWidth},
|
||||
output.options());
|
||||
|
||||
output_buffer = output_buffer.view(
|
||||
{output_buffer.size(0), group, output_buffer.size(1) / group,
|
||||
output_buffer.size(2), output_buffer.size(3)});
|
||||
|
||||
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
||||
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group, columns);
|
||||
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||
weight.size(2), weight.size(3)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
output_buffer[elt][g] = output_buffer[elt][g]
|
||||
.flatten(1)
|
||||
.addmm_(weight[g].flatten(1), columns[g])
|
||||
.view_as(output_buffer[elt][g]);
|
||||
}
|
||||
}
|
||||
|
||||
output_buffer = output_buffer.view(
|
||||
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
|
||||
output_buffer.size(3), output_buffer.size(4)});
|
||||
|
||||
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
|
||||
im2col_step, outputHeight, outputWidth});
|
||||
output_buffer.transpose_(1, 2);
|
||||
output.copy_(output_buffer);
|
||||
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
||||
|
||||
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
offset = offset.view(
|
||||
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
if (batch == 0) {
|
||||
output = output.view({nOutputPlane, outputHeight, outputWidth});
|
||||
input = input.view({nInputPlane, inputHeight, inputWidth});
|
||||
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
|
||||
at::Tensor gradOutput, at::Tensor gradInput,
|
||||
at::Tensor gradOffset, at::Tensor weight,
|
||||
at::Tensor columns, int kW, int kH, int dW,
|
||||
int dH, int padW, int padH, int dilationW,
|
||||
int dilationH, int group,
|
||||
int deformable_group, int im2col_step) {
|
||||
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
|
||||
dilationH, dilationW, group, deformable_group);
|
||||
|
||||
input = input.contiguous();
|
||||
offset = offset.contiguous();
|
||||
gradOutput = gradOutput.contiguous();
|
||||
weight = weight.contiguous();
|
||||
|
||||
int batch = 1;
|
||||
|
||||
if (input.ndimension() == 3) {
|
||||
// Force batch
|
||||
batch = 0;
|
||||
input = input.view({1, input.size(0), input.size(1), input.size(2)});
|
||||
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
|
||||
gradOutput = gradOutput.view(
|
||||
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
||||
}
|
||||
|
||||
long batchSize = input.size(0);
|
||||
long nInputPlane = input.size(1);
|
||||
long inputHeight = input.size(2);
|
||||
long inputWidth = input.size(3);
|
||||
|
||||
long nOutputPlane = weight.size(0);
|
||||
|
||||
long outputWidth =
|
||||
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||
long outputHeight =
|
||||
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||
|
||||
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
||||
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
columns = at::zeros(
|
||||
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
||||
input.options());
|
||||
|
||||
// change order of grad output
|
||||
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
||||
nOutputPlane, outputHeight, outputWidth});
|
||||
gradOutput.transpose_(1, 2);
|
||||
|
||||
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||
inputHeight, inputWidth});
|
||||
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||
inputHeight, inputWidth});
|
||||
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
|
||||
deformable_group * 2 * kH * kW, outputHeight,
|
||||
outputWidth});
|
||||
offset =
|
||||
offset.view({batchSize / im2col_step, im2col_step,
|
||||
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
||||
// divide into groups
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||
weight.size(2), weight.size(3)});
|
||||
gradOutput = gradOutput.view(
|
||||
{gradOutput.size(0), group, gradOutput.size(1) / group,
|
||||
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
||||
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
|
||||
}
|
||||
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
gradOutput = gradOutput.view(
|
||||
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
|
||||
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
|
||||
|
||||
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
|
||||
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
|
||||
dilationH, dilationW, im2col_step, deformable_group,
|
||||
gradOffset[elt]);
|
||||
|
||||
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group, gradInput[elt]);
|
||||
}
|
||||
|
||||
gradOutput.transpose_(1, 2);
|
||||
gradOutput =
|
||||
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
||||
|
||||
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
gradOffset = gradOffset.view(
|
||||
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
offset = offset.view(
|
||||
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
if (batch == 0) {
|
||||
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
||||
input = input.view({nInputPlane, inputHeight, inputWidth});
|
||||
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
|
||||
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
||||
gradOffset =
|
||||
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int deform_conv_backward_parameters_cuda(
|
||||
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
|
||||
at::Tensor gradWeight, // at::Tensor gradBias,
|
||||
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
|
||||
int padW, int padH, int dilationW, int dilationH, int group,
|
||||
int deformable_group, float scale, int im2col_step) {
|
||||
// todo: transpose and reshape outGrad
|
||||
// todo: reshape columns
|
||||
// todo: add im2col_step as input
|
||||
|
||||
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
|
||||
padW, dilationH, dilationW, group, deformable_group);
|
||||
|
||||
input = input.contiguous();
|
||||
offset = offset.contiguous();
|
||||
gradOutput = gradOutput.contiguous();
|
||||
|
||||
int batch = 1;
|
||||
|
||||
if (input.ndimension() == 3) {
|
||||
// Force batch
|
||||
batch = 0;
|
||||
input = input.view(
|
||||
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
|
||||
gradOutput = gradOutput.view(
|
||||
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
||||
}
|
||||
|
||||
long batchSize = input.size(0);
|
||||
long nInputPlane = input.size(1);
|
||||
long inputHeight = input.size(2);
|
||||
long inputWidth = input.size(3);
|
||||
|
||||
long nOutputPlane = gradWeight.size(0);
|
||||
|
||||
long outputWidth =
|
||||
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||
long outputHeight =
|
||||
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||
|
||||
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
||||
|
||||
columns = at::zeros(
|
||||
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
||||
input.options());
|
||||
|
||||
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
||||
nOutputPlane, outputHeight, outputWidth});
|
||||
gradOutput.transpose_(1, 2);
|
||||
|
||||
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
|
||||
gradOutputBuffer =
|
||||
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
|
||||
outputHeight, outputWidth});
|
||||
gradOutputBuffer.copy_(gradOutput);
|
||||
gradOutputBuffer =
|
||||
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
|
||||
im2col_step * outputHeight, outputWidth});
|
||||
|
||||
gradOutput.transpose_(1, 2);
|
||||
gradOutput =
|
||||
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
||||
|
||||
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||
inputHeight, inputWidth});
|
||||
offset =
|
||||
offset.view({batchSize / im2col_step, im2col_step,
|
||||
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
||||
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group, columns);
|
||||
|
||||
// divide into group
|
||||
gradOutputBuffer = gradOutputBuffer.view(
|
||||
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
|
||||
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
gradWeight =
|
||||
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
|
||||
gradWeight.size(2), gradWeight.size(3)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
gradWeight[g] = gradWeight[g]
|
||||
.flatten(1)
|
||||
.addmm_(gradOutputBuffer[elt][g].flatten(1),
|
||||
columns[g].transpose(1, 0), 1.0, scale)
|
||||
.view_as(gradWeight[g]);
|
||||
}
|
||||
gradOutputBuffer = gradOutputBuffer.view(
|
||||
{gradOutputBuffer.size(0),
|
||||
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
|
||||
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
|
||||
gradWeight.size(2), gradWeight.size(3),
|
||||
gradWeight.size(4)});
|
||||
}
|
||||
|
||||
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
offset = offset.view(
|
||||
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
if (batch == 0) {
|
||||
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
||||
input = input.view({nInputPlane, inputHeight, inputWidth});
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
void modulated_deform_conv_cuda_forward(
|
||||
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
||||
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
|
||||
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
|
||||
const int pad_h, const int pad_w, const int dilation_h,
|
||||
const int dilation_w, const int group, const int deformable_group,
|
||||
const bool with_bias) {
|
||||
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
||||
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
||||
|
||||
const int batch = input.size(0);
|
||||
const int channels = input.size(1);
|
||||
const int height = input.size(2);
|
||||
const int width = input.size(3);
|
||||
|
||||
const int channels_out = weight.size(0);
|
||||
const int channels_kernel = weight.size(1);
|
||||
const int kernel_h_ = weight.size(2);
|
||||
const int kernel_w_ = weight.size(3);
|
||||
|
||||
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
||||
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
||||
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
||||
if (channels != channels_kernel * group)
|
||||
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
||||
channels, channels_kernel * group);
|
||||
|
||||
const int height_out =
|
||||
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
||||
const int width_out =
|
||||
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
||||
|
||||
if (ones.ndimension() != 2 ||
|
||||
ones.size(0) * ones.size(1) < height_out * width_out) {
|
||||
// Resize plane and fill with ones...
|
||||
ones = at::ones({height_out, width_out}, input.options());
|
||||
}
|
||||
|
||||
// resize output
|
||||
output = output.view({batch, channels_out, height_out, width_out}).zero_();
|
||||
// resize temporary columns
|
||||
columns =
|
||||
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
|
||||
input.options());
|
||||
|
||||
output = output.view({output.size(0), group, output.size(1) / group,
|
||||
output.size(2), output.size(3)});
|
||||
|
||||
for (int b = 0; b < batch; b++) {
|
||||
modulated_deformable_im2col_cuda(
|
||||
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
||||
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, deformable_group, columns);
|
||||
|
||||
// divide into group
|
||||
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||
weight.size(2), weight.size(3)});
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
output[b][g] = output[b][g]
|
||||
.flatten(1)
|
||||
.addmm_(weight[g].flatten(1), columns[g])
|
||||
.view_as(output[b][g]);
|
||||
}
|
||||
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
}
|
||||
|
||||
output = output.view({output.size(0), output.size(1) * output.size(2),
|
||||
output.size(3), output.size(4)});
|
||||
|
||||
if (with_bias) {
|
||||
output += bias.view({1, bias.size(0), 1, 1});
|
||||
}
|
||||
}
|
||||
|
||||
void modulated_deform_conv_cuda_backward(
|
||||
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
||||
at::Tensor offset, at::Tensor mask, at::Tensor columns,
|
||||
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
|
||||
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
|
||||
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
|
||||
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
|
||||
const bool with_bias) {
|
||||
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
||||
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
||||
|
||||
const int batch = input.size(0);
|
||||
const int channels = input.size(1);
|
||||
const int height = input.size(2);
|
||||
const int width = input.size(3);
|
||||
|
||||
const int channels_kernel = weight.size(1);
|
||||
const int kernel_h_ = weight.size(2);
|
||||
const int kernel_w_ = weight.size(3);
|
||||
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
||||
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
||||
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
||||
if (channels != channels_kernel * group)
|
||||
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
||||
channels, channels_kernel * group);
|
||||
|
||||
const int height_out =
|
||||
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
||||
const int width_out =
|
||||
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
||||
|
||||
if (ones.ndimension() != 2 ||
|
||||
ones.size(0) * ones.size(1) < height_out * width_out) {
|
||||
// Resize plane and fill with ones...
|
||||
ones = at::ones({height_out, width_out}, input.options());
|
||||
}
|
||||
|
||||
grad_input = grad_input.view({batch, channels, height, width});
|
||||
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
|
||||
input.options());
|
||||
|
||||
grad_output =
|
||||
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
|
||||
grad_output.size(2), grad_output.size(3)});
|
||||
|
||||
for (int b = 0; b < batch; b++) {
|
||||
// divide into group
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||
weight.size(2), weight.size(3)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
||||
grad_output[b][g].flatten(1), 0.0f, 1.0f);
|
||||
}
|
||||
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
|
||||
// gradient w.r.t. input coordinate data
|
||||
modulated_deformable_col2im_coord_cuda(
|
||||
columns, input[b], offset[b], mask[b], 1, channels, height, width,
|
||||
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
|
||||
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
|
||||
grad_mask[b]);
|
||||
// gradient w.r.t. input data
|
||||
modulated_deformable_col2im_cuda(
|
||||
columns, offset[b], mask[b], 1, channels, height, width, height_out,
|
||||
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, deformable_group, grad_input[b]);
|
||||
|
||||
// gradient w.r.t. weight, dWeight should accumulate across the batch and
|
||||
// group
|
||||
modulated_deformable_im2col_cuda(
|
||||
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
||||
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, deformable_group, columns);
|
||||
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
|
||||
grad_weight.size(1), grad_weight.size(2),
|
||||
grad_weight.size(3)});
|
||||
if (with_bias)
|
||||
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
grad_weight[g] =
|
||||
grad_weight[g]
|
||||
.flatten(1)
|
||||
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
|
||||
.view_as(grad_weight[g]);
|
||||
if (with_bias) {
|
||||
grad_bias[g] =
|
||||
grad_bias[g]
|
||||
.view({-1, 1})
|
||||
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
|
||||
.view(-1);
|
||||
}
|
||||
}
|
||||
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
|
||||
grad_weight.size(2), grad_weight.size(3),
|
||||
grad_weight.size(4)});
|
||||
if (with_bias)
|
||||
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
|
||||
}
|
||||
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
|
||||
grad_output.size(2), grad_output.size(3),
|
||||
grad_output.size(4)});
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
|
||||
"deform forward (CUDA)");
|
||||
m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
|
||||
"deform_conv_backward_input (CUDA)");
|
||||
m.def("deform_conv_backward_parameters_cuda",
|
||||
&deform_conv_backward_parameters_cuda,
|
||||
"deform_conv_backward_parameters (CUDA)");
|
||||
m.def("modulated_deform_conv_cuda_forward",
|
||||
&modulated_deform_conv_cuda_forward,
|
||||
"modulated deform conv forward (CUDA)");
|
||||
m.def("modulated_deform_conv_cuda_backward",
|
||||
&modulated_deform_conv_cuda_backward,
|
||||
"modulated deform conv backward (CUDA)");
|
||||
}
|
|
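# Sanity-check sketch (our addition): once the extension is compiled, the five functions bound
# in PYBIND11_MODULE above are reachable from Python under the module name given in setup.py
# ('deform_conv_cuda'). The import fails unless the build succeeded; the helper name is ours.
def _example_check_bindings():
    import deform_conv_cuda
    expected = [
        'deform_conv_forward_cuda',
        'deform_conv_backward_input_cuda',
        'deform_conv_backward_parameters_cuda',
        'modulated_deform_conv_cuda_forward',
        'modulated_deform_conv_cuda_backward',
    ]
    return all(hasattr(deform_conv_cuda, name) for name in expected)
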
@ -1,866 +0,0 @@
|
|||
/*!
|
||||
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
|
||||
*
|
||||
* COPYRIGHT
|
||||
*
|
||||
* All contributions by the University of California:
|
||||
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
|
||||
* All rights reserved.
|
||||
*
|
||||
* All other contributions:
|
||||
* Copyright (c) 2014-2017, the respective contributors
|
||||
* All rights reserved.
|
||||
*
|
||||
* Caffe uses a shared copyright model: each contributor holds copyright over
|
||||
* their contributions to Caffe. The project versioning records all such
|
||||
* contribution and copyright details. If a contributor wants to further mark
|
||||
* their specific copyright on a particular contribution, they should indicate
|
||||
* their copyright solely in the commit message of the change when it is
|
||||
* committed.
|
||||
*
|
||||
* LICENSE
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
* CONTRIBUTION AGREEMENT
|
||||
*
|
||||
* By contributing to the BVLC/caffe repository through pull-request, comment,
|
||||
* or otherwise, the contributor releases their content to the
|
||||
* license and copyright terms herein.
|
||||
*
|
||||
***************** END Caffe Copyright Notice and Disclaimer ********************
|
||||
*
|
||||
* Copyright (c) 2018 Microsoft
|
||||
* Licensed under The MIT License [see LICENSE for details]
|
||||
* \file modulated_deformable_im2col.cuh
|
||||
* \brief Function definitions of converting an image to
|
||||
* column matrix based on kernel, padding, dilation, and offset.
|
||||
* These functions are mainly used in deformable convolution operators.
|
||||
* \ref: https://arxiv.org/abs/1703.06211
|
||||
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
|
||||
*/
|
||||
|
||||
// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <stdio.h>
|
||||
#include <math.h>
|
||||
#include <float.h>
|
||||
|
||||
using namespace at;
|
||||
|
||||
#define CUDA_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
||||
i += blockDim.x * gridDim.x)
|
||||
|
||||
const int CUDA_NUM_THREADS = 1024;
|
||||
const int kMaxGridNum = 65535;
|
||||
|
||||
inline int GET_BLOCKS(const int N)
|
||||
{
|
||||
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
|
||||
const int height, const int width, scalar_t h, scalar_t w)
|
||||
{
|
||||
|
||||
int h_low = floor(h);
|
||||
int w_low = floor(w);
|
||||
int h_high = h_low + 1;
|
||||
int w_high = w_low + 1;
|
||||
|
||||
scalar_t lh = h - h_low;
|
||||
scalar_t lw = w - w_low;
|
||||
scalar_t hh = 1 - lh, hw = 1 - lw;
|
||||
|
||||
scalar_t v1 = 0;
|
||||
if (h_low >= 0 && w_low >= 0)
|
||||
v1 = bottom_data[h_low * data_width + w_low];
|
||||
scalar_t v2 = 0;
|
||||
if (h_low >= 0 && w_high <= width - 1)
|
||||
v2 = bottom_data[h_low * data_width + w_high];
|
||||
scalar_t v3 = 0;
|
||||
if (h_high <= height - 1 && w_low >= 0)
|
||||
v3 = bottom_data[h_high * data_width + w_low];
|
||||
scalar_t v4 = 0;
|
||||
if (h_high <= height - 1 && w_high <= width - 1)
|
||||
v4 = bottom_data[h_high * data_width + w_high];
|
||||
|
||||
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
||||
|
||||
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
|
||||
const int h, const int w, const int height, const int width)
|
||||
{
|
||||
|
||||
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
|
||||
{
|
||||
//empty
|
||||
return 0;
|
||||
}
|
||||
|
||||
int argmax_h_low = floor(argmax_h);
|
||||
int argmax_w_low = floor(argmax_w);
|
||||
int argmax_h_high = argmax_h_low + 1;
|
||||
int argmax_w_high = argmax_w_low + 1;
|
||||
|
||||
scalar_t weight = 0;
|
||||
if (h == argmax_h_low && w == argmax_w_low)
|
||||
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
|
||||
if (h == argmax_h_low && w == argmax_w_high)
|
||||
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
|
||||
if (h == argmax_h_high && w == argmax_w_low)
|
||||
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
|
||||
if (h == argmax_h_high && w == argmax_w_high)
|
||||
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
|
||||
return weight;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
|
||||
const int height, const int width, const scalar_t *im_data,
|
||||
const int data_width, const int bp_dir)
|
||||
{
|
||||
|
||||
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
|
||||
{
|
||||
//empty
|
||||
return 0;
|
||||
}
|
||||
|
||||
int argmax_h_low = floor(argmax_h);
|
||||
int argmax_w_low = floor(argmax_w);
|
||||
int argmax_h_high = argmax_h_low + 1;
|
||||
int argmax_w_high = argmax_w_low + 1;
|
||||
|
||||
scalar_t weight = 0;
|
||||
|
||||
if (bp_dir == 0)
|
||||
{
|
||||
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
||||
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
|
||||
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
||||
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
||||
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
||||
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
|
||||
}
|
||||
else if (bp_dir == 1)
|
||||
{
|
||||
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
||||
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
|
||||
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
||||
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
||||
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
||||
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
|
||||
}
|
||||
|
||||
return weight;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
|
||||
const int height, const int width, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
|
||||
const int batch_size, const int num_channels, const int deformable_group,
|
||||
const int height_col, const int width_col,
|
||||
scalar_t *data_col)
|
||||
{
|
||||
CUDA_KERNEL_LOOP(index, n)
|
||||
{
|
||||
// index index of output matrix
|
||||
const int w_col = index % width_col;
|
||||
const int h_col = (index / width_col) % height_col;
|
||||
const int b_col = (index / width_col / height_col) % batch_size;
|
||||
const int c_im = (index / width_col / height_col) / batch_size;
|
||||
const int c_col = c_im * kernel_h * kernel_w;
|
||||
|
||||
// compute deformable group index
|
||||
const int deformable_group_index = c_im / channel_per_deformable_group;
|
||||
|
||||
const int h_in = h_col * stride_h - pad_h;
|
||||
const int w_in = w_col * stride_w - pad_w;
|
||||
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
|
||||
//const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
|
||||
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
|
||||
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
|
||||
|
||||
for (int i = 0; i < kernel_h; ++i)
|
||||
{
|
||||
for (int j = 0; j < kernel_w; ++j)
|
||||
{
|
||||
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
|
||||
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
|
||||
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
||||
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
||||
scalar_t val = static_cast<scalar_t>(0);
|
||||
const scalar_t h_im = h_in + i * dilation_h + offset_h;
|
||||
const scalar_t w_im = w_in + j * dilation_w + offset_w;
|
||||
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
|
||||
{
|
||||
//const scalar_t map_h = i * dilation_h + offset_h;
|
||||
//const scalar_t map_w = j * dilation_w + offset_w;
|
||||
//const int cur_height = height - h_in;
|
||||
//const int cur_width = width - w_in;
|
||||
//val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
|
||||
val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
|
||||
}
|
||||
*data_col_ptr = val;
|
||||
data_col_ptr += batch_size * height_col * width_col;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void deformable_im2col(
|
||||
const at::Tensor data_im, const at::Tensor data_offset, const int channels,
|
||||
const int height, const int width, const int ksize_h, const int ksize_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
||||
const int deformable_group, at::Tensor data_col)
|
||||
{
|
||||
// num_axes should be smaller than block size
|
||||
// todo: check parallel_imgs is correctly passed in
|
||||
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
||||
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
||||
int num_kernels = channels * height_col * width_col * parallel_imgs;
|
||||
int channel_per_deformable_group = channels / deformable_group;
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
|
||||
const scalar_t *data_im_ = data_im.data<scalar_t>();
|
||||
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
||||
scalar_t *data_col_ = data_col.data<scalar_t>();
|
||||
|
||||
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
||||
num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
|
||||
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
|
||||
channel_per_deformable_group, parallel_imgs, channels, deformable_group,
|
||||
height_col, width_col, data_col_);
|
||||
}));
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void deformable_col2im_gpu_kernel(
|
||||
const int n, const scalar_t *data_col, const scalar_t *data_offset,
|
||||
const int channels, const int height, const int width,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int channel_per_deformable_group,
|
||||
const int batch_size, const int deformable_group,
|
||||
const int height_col, const int width_col,
|
||||
scalar_t *grad_im)
|
||||
{
|
||||
CUDA_KERNEL_LOOP(index, n)
|
||||
{
|
||||
const int j = (index / width_col / height_col / batch_size) % kernel_w;
|
||||
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
||||
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
|
||||
// compute the start and end of the output
|
||||
|
||||
const int deformable_group_index = c / channel_per_deformable_group;
|
||||
|
||||
int w_out = index % width_col;
|
||||
int h_out = (index / width_col) % height_col;
|
||||
int b = (index / width_col / height_col) % batch_size;
|
||||
int w_in = w_out * stride_w - pad_w;
|
||||
int h_in = h_out * stride_h - pad_h;
|
||||
|
||||
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
|
||||
2 * kernel_h * kernel_w * height_col * width_col;
|
||||
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
|
||||
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
|
||||
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
||||
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
||||
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
|
||||
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
|
||||
|
||||
const scalar_t cur_top_grad = data_col[index];
|
||||
const int cur_h = (int)cur_inv_h_data;
|
||||
const int cur_w = (int)cur_inv_w_data;
|
||||
for (int dy = -2; dy <= 2; dy++)
|
||||
{
|
||||
for (int dx = -2; dx <= 2; dx++)
|
||||
{
|
||||
if (cur_h + dy >= 0 && cur_h + dy < height &&
|
||||
cur_w + dx >= 0 && cur_w + dx < width &&
|
||||
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
|
||||
abs(cur_inv_w_data - (cur_w + dx)) < 1)
|
||||
{
|
||||
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
|
||||
scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
|
||||
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void deformable_col2im(
|
||||
const at::Tensor data_col, const at::Tensor data_offset, const int channels,
|
||||
const int height, const int width, const int ksize_h,
|
||||
const int ksize_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int parallel_imgs, const int deformable_group,
|
||||
at::Tensor grad_im)
|
||||
{
|
||||
|
||||
// todo: make sure parallel_imgs is passed in correctly
|
||||
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
||||
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
||||
int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
|
||||
int channel_per_deformable_group = channels / deformable_group;
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
|
||||
const scalar_t *data_col_ = data_col.data<scalar_t>();
|
||||
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
||||
scalar_t *grad_im_ = grad_im.data<scalar_t>();
|
||||
|
||||
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
||||
num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
|
||||
ksize_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, channel_per_deformable_group,
|
||||
parallel_imgs, deformable_group, height_col, width_col, grad_im_);
|
||||
}));
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
|
||||
const scalar_t *data_im, const scalar_t *data_offset,
|
||||
const int channels, const int height, const int width,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int channel_per_deformable_group,
|
||||
const int batch_size, const int offset_channels, const int deformable_group,
|
||||
const int height_col, const int width_col, scalar_t *grad_offset)
|
||||
{
|
||||
CUDA_KERNEL_LOOP(index, n)
|
||||
{
|
||||
scalar_t val = 0;
|
||||
int w = index % width_col;
|
||||
int h = (index / width_col) % height_col;
|
||||
int c = (index / width_col / height_col) % offset_channels;
|
||||
int b = (index / width_col / height_col) / offset_channels;
|
||||
// compute the start and end of the output
|
||||
|
||||
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
|
||||
const int col_step = kernel_h * kernel_w;
|
||||
int cnt = 0;
|
||||
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
|
||||
batch_size * width_col * height_col;
|
||||
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
|
||||
channel_per_deformable_group / kernel_h / kernel_w * height * width;
|
||||
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
|
||||
kernel_h * kernel_w * height_col * width_col;
|
||||
|
||||
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
|
||||
|
||||
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
|
||||
{
|
||||
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
|
||||
const int bp_dir = offset_c % 2;
|
||||
|
||||
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
|
||||
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
||||
int w_out = col_pos % width_col;
|
||||
int h_out = (col_pos / width_col) % height_col;
|
||||
int w_in = w_out * stride_w - pad_w;
|
||||
int h_in = h_out * stride_h - pad_h;
|
||||
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
|
||||
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
|
||||
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
||||
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
||||
scalar_t inv_h = h_in + i * dilation_h + offset_h;
|
||||
scalar_t inv_w = w_in + j * dilation_w + offset_w;
|
||||
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
|
||||
{
|
||||
inv_h = inv_w = -2;
|
||||
}
|
||||
const scalar_t weight = get_coordinate_weight(
|
||||
inv_h, inv_w,
|
||||
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
|
||||
val += weight * data_col_ptr[col_pos];
|
||||
cnt += 1;
|
||||
}
|
||||
|
||||
grad_offset[index] = val;
|
||||
}
|
||||
}
|
||||
|
||||
void deformable_col2im_coord(
|
||||
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
|
||||
const int channels, const int height, const int width, const int ksize_h,
|
||||
const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
|
||||
const int stride_w, const int dilation_h, const int dilation_w,
|
||||
const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
|
||||
{
|
||||
|
||||
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
||||
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
||||
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
|
||||
int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
|
||||
const scalar_t *data_col_ = data_col.data<scalar_t>();
|
||||
const scalar_t *data_im_ = data_im.data<scalar_t>();
|
||||
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
||||
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
|
||||
|
||||
deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
||||
num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
|
||||
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, channel_per_deformable_group,
|
||||
parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
|
||||
height_col, width_col, grad_offset_);
|
||||
}));
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
|
||||
const int height, const int width, scalar_t h, scalar_t w)
|
||||
{
|
||||
int h_low = floor(h);
|
||||
int w_low = floor(w);
|
||||
int h_high = h_low + 1;
|
||||
int w_high = w_low + 1;
|
||||
|
||||
scalar_t lh = h - h_low;
|
||||
scalar_t lw = w - w_low;
|
||||
scalar_t hh = 1 - lh, hw = 1 - lw;
|
||||
|
||||
scalar_t v1 = 0;
|
||||
if (h_low >= 0 && w_low >= 0)
|
||||
v1 = bottom_data[h_low * data_width + w_low];
|
||||
scalar_t v2 = 0;
|
||||
if (h_low >= 0 && w_high <= width - 1)
|
||||
v2 = bottom_data[h_low * data_width + w_high];
|
||||
scalar_t v3 = 0;
|
||||
if (h_high <= height - 1 && w_low >= 0)
|
||||
v3 = bottom_data[h_high * data_width + w_low];
|
||||
scalar_t v4 = 0;
|
||||
if (h_high <= height - 1 && w_high <= width - 1)
|
||||
v4 = bottom_data[h_high * data_width + w_high];
|
||||
|
||||
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
||||
|
||||
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
||||
return val;
|
||||
}
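// Standard bilinear interpolation at a fractional location (h, w): the four neighbouring
// pixels are blended with weights that sum to 1. For example, h = 1.25, w = 2.5 gives
// 0.375 * I[1][2] + 0.375 * I[1][3] + 0.125 * I[2][2] + 0.125 * I[2][3].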
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
|
||||
const int h, const int w, const int height, const int width)
|
||||
{
|
||||
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
|
||||
{
|
||||
//empty
|
||||
return 0;
|
||||
}
|
||||
|
||||
int argmax_h_low = floor(argmax_h);
|
||||
int argmax_w_low = floor(argmax_w);
|
||||
int argmax_h_high = argmax_h_low + 1;
|
||||
int argmax_w_high = argmax_w_low + 1;
|
||||
|
||||
scalar_t weight = 0;
|
||||
if (h == argmax_h_low && w == argmax_w_low)
|
||||
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
|
||||
if (h == argmax_h_low && w == argmax_w_high)
|
||||
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
|
||||
if (h == argmax_h_high && w == argmax_w_low)
|
||||
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
|
||||
if (h == argmax_h_high && w == argmax_w_high)
|
||||
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
|
||||
return weight;
|
||||
}
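// Returns the bilinear weight linking the integer pixel (h, w) to the fractional sampling
// location (argmax_h, argmax_w); the col2im kernel uses it to scatter column gradients
// back onto the input image.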
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
|
||||
const int height, const int width, const scalar_t *im_data,
|
||||
const int data_width, const int bp_dir)
|
||||
{
|
||||
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
|
||||
{
|
||||
//empty
|
||||
return 0;
|
||||
}
|
||||
|
||||
int argmax_h_low = floor(argmax_h);
|
||||
int argmax_w_low = floor(argmax_w);
|
||||
int argmax_h_high = argmax_h_low + 1;
|
||||
int argmax_w_high = argmax_w_low + 1;
|
||||
|
||||
scalar_t weight = 0;
|
||||
|
||||
if (bp_dir == 0)
|
||||
{
|
||||
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
||||
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
|
||||
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
||||
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
||||
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
||||
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
|
||||
}
|
||||
else if (bp_dir == 1)
|
||||
{
|
||||
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
||||
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
|
||||
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
||||
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
||||
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
||||
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
|
||||
}
|
||||
|
||||
return weight;
|
||||
}
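// bp_dir selects which partial derivative of the bilinear sample is accumulated:
// 0 -> derivative w.r.t. the vertical offset (h), 1 -> w.r.t. the horizontal offset (w).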
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
|
||||
const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
|
||||
const int height, const int width, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int channel_per_deformable_group,
|
||||
const int batch_size, const int num_channels, const int deformable_group,
|
||||
const int height_col, const int width_col,
|
||||
scalar_t *data_col)
|
||||
{
|
||||
CUDA_KERNEL_LOOP(index, n)
|
||||
{
|
||||
// index of the output matrix
|
||||
const int w_col = index % width_col;
|
||||
const int h_col = (index / width_col) % height_col;
|
||||
const int b_col = (index / width_col / height_col) % batch_size;
|
||||
const int c_im = (index / width_col / height_col) / batch_size;
|
||||
const int c_col = c_im * kernel_h * kernel_w;
|
||||
|
||||
// compute deformable group index
|
||||
const int deformable_group_index = c_im / channel_per_deformable_group;
|
||||
|
||||
const int h_in = h_col * stride_h - pad_h;
|
||||
const int w_in = w_col * stride_w - pad_w;
|
||||
|
||||
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
|
||||
//const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
|
||||
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
|
||||
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
|
||||
|
||||
const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
|
||||
|
||||
for (int i = 0; i < kernel_h; ++i)
|
||||
{
|
||||
for (int j = 0; j < kernel_w; ++j)
|
||||
{
|
||||
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
|
||||
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
|
||||
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
|
||||
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
||||
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
||||
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
|
||||
scalar_t val = static_cast<scalar_t>(0);
|
||||
const scalar_t h_im = h_in + i * dilation_h + offset_h;
|
||||
const scalar_t w_im = w_in + j * dilation_w + offset_w;
|
||||
//if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
|
||||
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
|
||||
{
|
||||
//const float map_h = i * dilation_h + offset_h;
|
||||
//const float map_w = j * dilation_w + offset_w;
|
||||
//const int cur_height = height - h_in;
|
||||
//const int cur_width = width - w_in;
|
||||
//val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
|
||||
val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
|
||||
}
|
||||
*data_col_ptr = val * mask;
|
||||
data_col_ptr += batch_size * height_col * width_col;
|
||||
//data_col_ptr += height_col * width_col;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
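// data_col is laid out as (channels * kernel_h * kernel_w, batch, height_col, width_col);
// each thread above fills the kernel_h * kernel_w sampled, mask-weighted values for one
// (channel, sample, output location).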
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
|
||||
const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
|
||||
const int channels, const int height, const int width,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int channel_per_deformable_group,
|
||||
const int batch_size, const int deformable_group,
|
||||
const int height_col, const int width_col,
|
||||
scalar_t *grad_im)
|
||||
{
|
||||
CUDA_KERNEL_LOOP(index, n)
|
||||
{
|
||||
const int j = (index / width_col / height_col / batch_size) % kernel_w;
|
||||
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
||||
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
|
||||
// compute the start and end of the output
|
||||
|
||||
const int deformable_group_index = c / channel_per_deformable_group;
|
||||
|
||||
int w_out = index % width_col;
|
||||
int h_out = (index / width_col) % height_col;
|
||||
int b = (index / width_col / height_col) % batch_size;
|
||||
int w_in = w_out * stride_w - pad_w;
|
||||
int h_in = h_out * stride_h - pad_h;
|
||||
|
||||
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
|
||||
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
|
||||
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
|
||||
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
|
||||
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
|
||||
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
||||
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
||||
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
|
||||
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
|
||||
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
|
||||
|
||||
const scalar_t cur_top_grad = data_col[index] * mask;
|
||||
const int cur_h = (int)cur_inv_h_data;
|
||||
const int cur_w = (int)cur_inv_w_data;
|
||||
for (int dy = -2; dy <= 2; dy++)
|
||||
{
|
||||
for (int dx = -2; dx <= 2; dx++)
|
||||
{
|
||||
if (cur_h + dy >= 0 && cur_h + dy < height &&
|
||||
cur_w + dx >= 0 && cur_w + dx < width &&
|
||||
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
|
||||
abs(cur_inv_w_data - (cur_w + dx)) < 1)
|
||||
{
|
||||
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
|
||||
scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
|
||||
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
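// Gradient w.r.t. the input image: every column gradient (scaled by its mask) is scattered
// with atomicAdd onto the integer pixels that contributed to its bilinear sample; the +/-2
// search window combined with the < 1 distance test selects those neighbours around the
// truncated coordinate.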
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
|
||||
const scalar_t *data_col, const scalar_t *data_im,
|
||||
const scalar_t *data_offset, const scalar_t *data_mask,
|
||||
const int channels, const int height, const int width,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int channel_per_deformable_group,
|
||||
const int batch_size, const int offset_channels, const int deformable_group,
|
||||
const int height_col, const int width_col,
|
||||
scalar_t *grad_offset, scalar_t *grad_mask)
|
||||
{
|
||||
CUDA_KERNEL_LOOP(index, n)
|
||||
{
|
||||
scalar_t val = 0, mval = 0;
|
||||
int w = index % width_col;
|
||||
int h = (index / width_col) % height_col;
|
||||
int c = (index / width_col / height_col) % offset_channels;
|
||||
int b = (index / width_col / height_col) / offset_channels;
|
||||
// compute the start and end of the output
|
||||
|
||||
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
|
||||
const int col_step = kernel_h * kernel_w;
|
||||
int cnt = 0;
|
||||
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
|
||||
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
|
||||
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
|
||||
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
|
||||
|
||||
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
|
||||
|
||||
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
|
||||
{
|
||||
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
|
||||
const int bp_dir = offset_c % 2;
|
||||
|
||||
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
|
||||
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
||||
int w_out = col_pos % width_col;
|
||||
int h_out = (col_pos / width_col) % height_col;
|
||||
int w_in = w_out * stride_w - pad_w;
|
||||
int h_in = h_out * stride_h - pad_h;
|
||||
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
|
||||
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
|
||||
const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
|
||||
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
||||
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
||||
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
|
||||
scalar_t inv_h = h_in + i * dilation_h + offset_h;
|
||||
scalar_t inv_w = w_in + j * dilation_w + offset_w;
|
||||
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
|
||||
{
|
||||
inv_h = inv_w = -2;
|
||||
}
|
||||
else
|
||||
{
|
||||
mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
|
||||
}
|
||||
const scalar_t weight = dmcn_get_coordinate_weight(
|
||||
inv_h, inv_w,
|
||||
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
|
||||
val += weight * data_col_ptr[col_pos] * mask;
|
||||
cnt += 1;
|
||||
}
|
||||
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
|
||||
grad_offset[index] = val;
|
||||
if (offset_c % 2 == 0)
|
||||
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
|
||||
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
|
||||
}
|
||||
}
|
||||
|
||||
void modulated_deformable_im2col_cuda(
|
||||
const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
|
||||
const int batch_size, const int channels, const int height_im, const int width_im,
|
||||
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int deformable_group, at::Tensor data_col)
|
||||
{
|
||||
// num_axes should be smaller than block size
|
||||
const int channel_per_deformable_group = channels / deformable_group;
|
||||
const int num_kernels = channels * batch_size * height_col * width_col;
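// One CUDA thread per (input channel, sample, output location); each thread samples and
// writes its kernel_h * kernel_w values into data_col.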
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
|
||||
const scalar_t *data_im_ = data_im.data<scalar_t>();
|
||||
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
||||
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
|
||||
scalar_t *data_col_ = data_col.data<scalar_t>();
|
||||
|
||||
modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
||||
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kernel_w,
|
||||
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
|
||||
batch_size, channels, deformable_group, height_col, width_col, data_col_);
|
||||
}));
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
}
|
||||
|
||||
void modulated_deformable_col2im_cuda(
|
||||
const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
|
||||
const int batch_size, const int channels, const int height_im, const int width_im,
|
||||
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int deformable_group, at::Tensor grad_im)
|
||||
{
|
||||
|
||||
const int channel_per_deformable_group = channels / deformable_group;
|
||||
const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
|
||||
const scalar_t *data_col_ = data_col.data<scalar_t>();
|
||||
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
||||
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
|
||||
scalar_t *grad_im_ = grad_im.data<scalar_t>();
|
||||
|
||||
modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
||||
num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
|
||||
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, channel_per_deformable_group,
|
||||
batch_size, deformable_group, height_col, width_col, grad_im_);
|
||||
}));
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
}
|
||||
|
||||
void modulated_deformable_col2im_coord_cuda(
|
||||
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
|
||||
const int batch_size, const int channels, const int height_im, const int width_im,
|
||||
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int deformable_group,
|
||||
at::Tensor grad_offset, at::Tensor grad_mask)
|
||||
{
|
||||
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
|
||||
const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
|
||||
const scalar_t *data_col_ = data_col.data<scalar_t>();
|
||||
const scalar_t *data_im_ = data_im.data<scalar_t>();
|
||||
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
||||
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
|
||||
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
|
||||
scalar_t *grad_mask_ = grad_mask.data<scalar_t>();
|
||||
|
||||
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
||||
num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
|
||||
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, channel_per_deformable_group,
|
||||
batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
|
||||
grad_offset_, grad_mask_);
|
||||
}));
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
}
|
@@ -1,208 +0,0 @@
'''
|
||||
Test Vid4 (SR) and REDS4 (SR-clean, SR-blur, deblur-clean, deblur-compression) datasets
|
||||
'''
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import glob
|
||||
import logging
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
import utils.util as util
|
||||
import data.util as data_util
|
||||
import models.archs.EDVR_arch as EDVR_arch
|
||||
|
||||
|
||||
def main():
|
||||
#################
|
||||
# configurations
|
||||
#################
|
||||
device = torch.device('cuda')
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
data_mode = 'Vid4' # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
|
||||
# Vid4: SR
|
||||
# REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
|
||||
# blur (deblur-clean), blur_comp (deblur-compression).
|
||||
stage = 1 # 1 or 2, use two stage strategy for REDS dataset.
|
||||
flip_test = False
|
||||
############################################################################
|
||||
#### model
|
||||
if data_mode == 'Vid4':
|
||||
if stage == 1:
|
||||
model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
|
||||
else:
|
||||
raise ValueError('Vid4 does not support stage 2.')
|
||||
elif data_mode == 'sharp_bicubic':
|
||||
if stage == 1:
|
||||
model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
|
||||
else:
|
||||
model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
|
||||
elif data_mode == 'blur_bicubic':
|
||||
if stage == 1:
|
||||
model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
|
||||
else:
|
||||
model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
|
||||
elif data_mode == 'blur':
|
||||
if stage == 1:
|
||||
model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
|
||||
else:
|
||||
model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
|
||||
elif data_mode == 'blur_comp':
|
||||
if stage == 1:
|
||||
model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
|
||||
else:
|
||||
model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if data_mode == 'Vid4':
|
||||
N_in = 7 # use N_in images to restore one HR image
|
||||
else:
|
||||
N_in = 5
|
||||
|
||||
predeblur, HR_in = False, False
|
||||
back_RBs = 40
|
||||
if data_mode == 'blur_bicubic':
|
||||
predeblur = True
|
||||
if data_mode == 'blur' or data_mode == 'blur_comp':
|
||||
predeblur, HR_in = True, True
|
||||
if stage == 2:
|
||||
HR_in = True
|
||||
back_RBs = 20
|
||||
model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
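# The positional arguments correspond to nf, nframes, groups, front_RBs and back_RBs in
# EDVR_arch.EDVR: 128 feature channels, N_in input frames, 8 deformable groups, 5 front
# residual blocks and back_RBs reconstruction blocks.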
|
||||
|
||||
#### dataset
|
||||
if data_mode == 'Vid4':
|
||||
test_dataset_folder = '../datasets/Vid4/BIx4'
|
||||
GT_dataset_folder = '../datasets/Vid4/GT'
|
||||
else:
|
||||
if stage == 1:
|
||||
test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
|
||||
else:
|
||||
test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
|
||||
print('You should modify the test_dataset_folder path for stage 2')
|
||||
GT_dataset_folder = '../datasets/REDS4/GT'
|
||||
|
||||
#### evaluation
|
||||
crop_border = 0
|
||||
border_frame = N_in // 2 # border frames when evaluating
|
||||
# temporal padding mode
|
||||
if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
|
||||
padding = 'new_info'
|
||||
else:
|
||||
padding = 'replicate'
|
||||
save_imgs = True
|
||||
|
||||
save_folder = '../results/{}'.format(data_mode)
|
||||
util.mkdirs(save_folder)
|
||||
util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
#### log info
|
||||
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
|
||||
logger.info('Padding mode: {}'.format(padding))
|
||||
logger.info('Model path: {}'.format(model_path))
|
||||
logger.info('Save images: {}'.format(save_imgs))
|
||||
logger.info('Flip test: {}'.format(flip_test))
|
||||
|
||||
#### set up the models
|
||||
model.load_state_dict(torch.load(model_path), strict=True)
|
||||
model.eval()
|
||||
model = model.to(device)
|
||||
|
||||
avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
|
||||
subfolder_name_l = []
|
||||
|
||||
subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
|
||||
subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
|
||||
# for each subfolder
|
||||
for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
|
||||
subfolder_name = osp.basename(subfolder)
|
||||
subfolder_name_l.append(subfolder_name)
|
||||
save_subfolder = osp.join(save_folder, subfolder_name)
|
||||
|
||||
img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
|
||||
max_idx = len(img_path_l)
|
||||
if save_imgs:
|
||||
util.mkdirs(save_subfolder)
|
||||
|
||||
#### read LQ and GT images
|
||||
imgs_LQ = data_util.read_img_seq(subfolder)
|
||||
img_GT_l = []
|
||||
for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
|
||||
img_GT_l.append(data_util.read_img(None, img_GT_path))
|
||||
|
||||
avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0
|
||||
|
||||
# process each image
|
||||
for img_idx, img_path in enumerate(img_path_l):
|
||||
img_name = osp.splitext(osp.basename(img_path))[0]
|
||||
select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
|
||||
imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
|
||||
|
||||
if flip_test:
|
||||
output = util.flipx4_forward(model, imgs_in)
|
||||
else:
|
||||
output = util.single_forward(model, imgs_in)
|
||||
output = util.tensor2img(output.squeeze(0))
|
||||
|
||||
if save_imgs:
|
||||
cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)
|
||||
|
||||
# calculate PSNR
|
||||
output = output / 255.
|
||||
GT = np.copy(img_GT_l[img_idx])
|
||||
# For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
|
||||
if data_mode == 'Vid4': # bgr2y, [0, 1]
|
||||
GT = data_util.bgr2ycbcr(GT, only_y=True)
|
||||
output = data_util.bgr2ycbcr(output, only_y=True)
|
||||
|
||||
output, GT = util.crop_border([output, GT], crop_border)
|
||||
crt_psnr = util.calculate_psnr(output * 255, GT * 255)
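# PSNR is 10 * log10(255^2 / MSE) on 8-bit-range values, hence the * 255 scaling here.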
|
||||
logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))
|
||||
|
||||
if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
|
||||
avg_psnr_center += crt_psnr
|
||||
N_center += 1
|
||||
else: # border frames
|
||||
avg_psnr_border += crt_psnr
|
||||
N_border += 1
|
||||
|
||||
avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
|
||||
avg_psnr_center = avg_psnr_center / N_center
|
||||
avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
|
||||
avg_psnr_l.append(avg_psnr)
|
||||
avg_psnr_center_l.append(avg_psnr_center)
|
||||
avg_psnr_border_l.append(avg_psnr_border)
|
||||
|
||||
logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
|
||||
'Center PSNR: {:.6f} dB for {} frames; '
|
||||
'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
|
||||
(N_center + N_border),
|
||||
avg_psnr_center, N_center,
|
||||
avg_psnr_border, N_border))
|
||||
|
||||
logger.info('################ Tidy Outputs ################')
|
||||
for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
|
||||
avg_psnr_center_l, avg_psnr_border_l):
|
||||
logger.info('Folder {} - Average PSNR: {:.6f} dB. '
|
||||
'Center PSNR: {:.6f} dB. '
|
||||
'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
|
||||
psnr_border))
|
||||
logger.info('################ Final Results ################')
|
||||
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
|
||||
logger.info('Padding mode: {}'.format(padding))
|
||||
logger.info('Model path: {}'.format(model_path))
|
||||
logger.info('Save images: {}'.format(save_imgs))
|
||||
logger.info('Flip test: {}'.format(flip_test))
|
||||
logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
|
||||
'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
|
||||
sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
|
||||
sum(avg_psnr_center_l) / len(avg_psnr_center_l),
|
||||
sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@@ -1,264 +0,0 @@
"""
|
||||
DUF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets
|
||||
write to txt log file
|
||||
"""
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import glob
|
||||
import logging
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
import utils.util as util
|
||||
import data.util as data_util
|
||||
import models.archs.DUF_arch as DUF_arch
|
||||
|
||||
|
||||
def main():
|
||||
#################
|
||||
# configurations
|
||||
#################
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS)
|
||||
|
||||
# Possible combinations: (2, 16), (3, 16), (4, 16), (4, 28), (4, 52)
|
||||
scale = 4
|
||||
layer = 52
|
||||
assert (scale, layer) in [(2, 16), (3, 16), (4, 16), (4, 28),
|
||||
(4, 52)], 'Unrecognized (scale, layer) combination'
|
||||
|
||||
# model
|
||||
N_in = 7
|
||||
model_path = '../experiments/pretrained_models/DUF_x{}_{}L_official.pth'.format(scale, layer)
|
||||
adapt_official = 'official' in model_path
|
||||
DUF_downsampling = True # True | False
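# When True, the GT frames are blurred and subsampled on the fly via util.DUF_downsample
# (matching the official DUF pipeline) and used as the network input; when False, the
# pre-generated LR frames in test_dataset_folder are fed in directly.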
|
||||
if layer == 16:
|
||||
model = DUF_arch.DUF_16L(scale=scale, adapt_official=adapt_official)
|
||||
elif layer == 28:
|
||||
model = DUF_arch.DUF_28L(scale=scale, adapt_official=adapt_official)
|
||||
elif layer == 52:
|
||||
model = DUF_arch.DUF_52L(scale=scale, adapt_official=adapt_official)
|
||||
|
||||
#### dataset
|
||||
if data_mode == 'Vid4':
|
||||
test_dataset_folder = '../datasets/Vid4/BIx4/*'
|
||||
else: # sharp_bicubic (REDS)
|
||||
test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)
|
||||
|
||||
#### evaluation
|
||||
crop_border = 8
|
||||
border_frame = N_in // 2 # border frames when evaluating
|
||||
# temporal padding mode
|
||||
padding = 'new_info' # different from the official testing code, which pads zeros.
|
||||
save_imgs = True
|
||||
############################################################################
|
||||
device = torch.device('cuda')
|
||||
save_folder = '../results/{}'.format(data_mode)
|
||||
util.mkdirs(save_folder)
|
||||
util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
#### log info
|
||||
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
|
||||
logger.info('Padding mode: {}'.format(padding))
|
||||
logger.info('Model path: {}'.format(model_path))
|
||||
logger.info('Save images: {}'.format(save_imgs))
|
||||
|
||||
def read_image(img_path):
|
||||
'''read one image from img_path
|
||||
Return img: HWC, BGR, [0,1], numpy
|
||||
'''
|
||||
img_GT = cv2.imread(img_path)
|
||||
img = img_GT.astype(np.float32) / 255.
|
||||
return img
|
||||
|
||||
def read_seq_imgs(img_seq_path):
|
||||
'''read a sequence of images'''
|
||||
img_path_l = sorted(glob.glob(img_seq_path + '/*'))
|
||||
img_l = [read_image(v) for v in img_path_l]
|
||||
# stack to TCHW, RGB, [0,1], torch
|
||||
imgs = np.stack(img_l, axis=0)
|
||||
imgs = imgs[:, :, :, [2, 1, 0]]
|
||||
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
|
||||
return imgs
|
||||
|
||||
def index_generation(crt_i, max_n, N, padding='reflection'):
|
||||
'''
|
||||
padding: replicate | reflection | new_info | circle
|
||||
'''
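# Illustrative example: with N = 7 and max_n = 10 (frame indices 0-9), crt_i = 0 gives
#   'replicate'  -> [0, 0, 0, 0, 1, 2, 3]
#   'reflection' -> [3, 2, 1, 0, 1, 2, 3]
#   'new_info'   -> [6, 5, 4, 0, 1, 2, 3]
#   'circle'     -> [4, 5, 6, 0, 1, 2, 3]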
|
||||
max_n = max_n - 1
|
||||
n_pad = N // 2
|
||||
return_l = []
|
||||
|
||||
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
|
||||
if i < 0:
|
||||
if padding == 'replicate':
|
||||
add_idx = 0
|
||||
elif padding == 'reflection':
|
||||
add_idx = -i
|
||||
elif padding == 'new_info':
|
||||
add_idx = (crt_i + n_pad) + (-i)
|
||||
elif padding == 'circle':
|
||||
add_idx = N + i
|
||||
else:
|
||||
raise ValueError('Wrong padding mode')
|
||||
elif i > max_n:
|
||||
if padding == 'replicate':
|
||||
add_idx = max_n
|
||||
elif padding == 'reflection':
|
||||
add_idx = max_n * 2 - i
|
||||
elif padding == 'new_info':
|
||||
add_idx = (crt_i - n_pad) - (i - max_n)
|
||||
elif padding == 'circle':
|
||||
add_idx = i - N
|
||||
else:
|
||||
raise ValueError('Wrong padding mode')
|
||||
else:
|
||||
add_idx = i
|
||||
return_l.append(add_idx)
|
||||
return return_l
|
||||
|
||||
def single_forward(model, imgs_in):
|
||||
with torch.no_grad():
|
||||
model_output = model(imgs_in)
|
||||
if isinstance(model_output, list) or isinstance(model_output, tuple):
|
||||
output = model_output[0]
|
||||
else:
|
||||
output = model_output
|
||||
return output
|
||||
|
||||
sub_folder_l = sorted(glob.glob(test_dataset_folder))
|
||||
#### set up the models
|
||||
model.load_state_dict(torch.load(model_path), strict=True)
|
||||
model.eval()
|
||||
model = model.to(device)
|
||||
|
||||
avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
|
||||
sub_folder_name_l = []
|
||||
|
||||
# for each sub-folder
|
||||
for sub_folder in sub_folder_l:
|
||||
sub_folder_name = sub_folder.split('/')[-1]
|
||||
sub_folder_name_l.append(sub_folder_name)
|
||||
save_sub_folder = osp.join(save_folder, sub_folder_name)
|
||||
|
||||
img_path_l = sorted(glob.glob(sub_folder + '/*'))
|
||||
max_idx = len(img_path_l)
|
||||
|
||||
if save_imgs:
|
||||
util.mkdirs(save_sub_folder)
|
||||
|
||||
#### read LR images
|
||||
imgs = read_seq_imgs(sub_folder)
|
||||
#### read GT images
|
||||
img_GT_l = []
|
||||
if data_mode == 'Vid4':
|
||||
sub_folder_GT = osp.join(sub_folder.replace('/BIx4/', '/GT/'), '*')
|
||||
else:
|
||||
sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
|
||||
for img_GT_path in sorted(glob.glob(sub_folder_GT)):
|
||||
img_GT_l.append(read_image(img_GT_path))
|
||||
|
||||
# When using the downsampling in DUF official code, we downsample the HR images
|
||||
if DUF_downsampling:
|
||||
sub_folder = sub_folder_GT
|
||||
img_path_l = sorted(glob.glob(sub_folder))
|
||||
max_idx = len(img_path_l)
|
||||
imgs = read_seq_imgs(sub_folder[:-2])
|
||||
|
||||
avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
|
||||
cal_n_border, cal_n_center = 0, 0
|
||||
|
||||
# process each image
|
||||
for img_idx, img_path in enumerate(img_path_l):
|
||||
c_idx = int(osp.splitext(osp.basename(img_path))[0])
|
||||
select_idx = index_generation(c_idx, max_idx, N_in, padding=padding)
|
||||
# get input images
|
||||
imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
|
||||
|
||||
# Downsample the HR images
|
||||
H, W = imgs_in.size(3), imgs_in.size(4)
|
||||
if DUF_downsampling:
|
||||
imgs_in = util.DUF_downsample(imgs_in, scale=scale)
|
||||
|
||||
output = single_forward(model, imgs_in)
|
||||
|
||||
# Crop to the original shape
|
||||
if scale == 3:
|
||||
pad_h = 3 - (H % 3)
|
||||
pad_w = 3 - (W % 3)
|
||||
if pad_h > 0:
|
||||
output = output[:, :, :-pad_h, :]
|
||||
if pad_w > 0:
|
||||
output = output[:, :, :, :-pad_w]
|
||||
output_f = output.data.float().cpu().squeeze(0)
|
||||
|
||||
output = util.tensor2img(output_f)
|
||||
|
||||
# save imgs
|
||||
if save_imgs:
|
||||
cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output)
|
||||
|
||||
#### calculate PSNR
|
||||
output = output / 255.
|
||||
GT = np.copy(img_GT_l[img_idx])
|
||||
# For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
|
||||
if data_mode == 'Vid4': # bgr2y, [0, 1]
|
||||
GT = data_util.bgr2ycbcr(GT)
|
||||
output = data_util.bgr2ycbcr(output)
|
||||
if crop_border == 0:
|
||||
cropped_output = output
|
||||
cropped_GT = GT
|
||||
else:
|
||||
cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
|
||||
cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
|
||||
crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
|
||||
logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr))
|
||||
|
||||
if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
|
||||
avg_psnr_center += crt_psnr
|
||||
cal_n_center += 1
|
||||
else: # border frames
|
||||
avg_psnr_border += crt_psnr
|
||||
cal_n_border += 1
|
||||
|
||||
avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border)
|
||||
avg_psnr_center = avg_psnr_center / cal_n_center
|
||||
if cal_n_border == 0:
|
||||
avg_psnr_border = 0
|
||||
else:
|
||||
avg_psnr_border = avg_psnr_border / cal_n_border
|
||||
|
||||
logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
|
||||
'Center PSNR: {:.6f} dB for {} frames; '
|
||||
'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr,
|
||||
(cal_n_center + cal_n_border),
|
||||
avg_psnr_center, cal_n_center,
|
||||
avg_psnr_border, cal_n_border))
|
||||
|
||||
avg_psnr_l.append(avg_psnr)
|
||||
avg_psnr_center_l.append(avg_psnr_center)
|
||||
avg_psnr_border_l.append(avg_psnr_border)
|
||||
|
||||
logger.info('################ Tidy Outputs ################')
|
||||
for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
|
||||
avg_psnr_center_l, avg_psnr_border_l):
|
||||
logger.info('Folder {} - Average PSNR: {:.6f} dB. '
|
||||
'Center PSNR: {:.6f} dB. '
|
||||
'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border))
|
||||
logger.info('################ Final Results ################')
|
||||
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
|
||||
logger.info('Padding mode: {}'.format(padding))
|
||||
logger.info('Model path: {}'.format(model_path))
|
||||
logger.info('Save images: {}'.format(save_imgs))
|
||||
logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
|
||||
'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
|
||||
sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
|
||||
sum(avg_psnr_center_l) / len(avg_psnr_center_l),
|
||||
sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@@ -1,230 +0,0 @@
"""
|
||||
TOF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets
|
||||
write to txt log file
|
||||
"""
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import glob
|
||||
import logging
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
import utils.util as util
|
||||
import data.util as data_util
|
||||
import models.archs.TOF_arch as TOF_arch
|
||||
|
||||
|
||||
def main():
|
||||
#################
|
||||
# configurations
|
||||
#################
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS)
|
||||
|
||||
# model
|
||||
N_in = 7
|
||||
model_path = '../experiments/pretrained_models/TOF_official.pth'
|
||||
adapt_official = 'official' in model_path
|
||||
model = TOF_arch.TOFlow(adapt_official=adapt_official)
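# TOFlow expects frames already bicubically upscaled to the GT resolution (hence the
# BIx4up_direct inputs below); it warps the neighbouring frames onto the centre frame
# with estimated optical flow and fuses them into the restored output.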
|
||||
|
||||
#### dataset
|
||||
if data_mode == 'Vid4':
|
||||
test_dataset_folder = '../datasets/Vid4/BIx4up_direct/*'
|
||||
else:
|
||||
test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)
|
||||
|
||||
#### evaluation
|
||||
crop_border = 0
|
||||
border_frame = N_in // 2 # border frames when evaluating
|
||||
# temporal padding mode
|
||||
padding = 'new_info' # different from the official setting
|
||||
save_imgs = True
|
||||
############################################################################
|
||||
device = torch.device('cuda')
|
||||
save_folder = '../results/{}'.format(data_mode)
|
||||
util.mkdirs(save_folder)
|
||||
util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
#### log info
|
||||
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
|
||||
logger.info('Padding mode: {}'.format(padding))
|
||||
logger.info('Model path: {}'.format(model_path))
|
||||
logger.info('Save images: {}'.format(save_imgs))
|
||||
|
||||
def read_image(img_path):
|
||||
'''read one image from img_path
|
||||
Return img: HWC, BGR, [0,1], numpy
|
||||
'''
|
||||
img_GT = cv2.imread(img_path)
|
||||
img = img_GT.astype(np.float32) / 255.
|
||||
return img
|
||||
|
||||
def read_seq_imgs(img_seq_path):
|
||||
'''read a sequence of images'''
|
||||
img_path_l = sorted(glob.glob(img_seq_path + '/*'))
|
||||
img_l = [read_image(v) for v in img_path_l]
|
||||
# stack to TCHW, RGB, [0,1], torch
|
||||
imgs = np.stack(img_l, axis=0)
|
||||
imgs = imgs[:, :, :, [2, 1, 0]]
|
||||
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
|
||||
return imgs
|
||||
|
||||
def index_generation(crt_i, max_n, N, padding='reflection'):
|
||||
'''
|
||||
padding: replicate | reflection | new_info | circle
|
||||
'''
|
||||
max_n = max_n - 1
|
||||
n_pad = N // 2
|
||||
return_l = []
|
||||
|
||||
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
|
||||
if i < 0:
|
||||
if padding == 'replicate':
|
||||
add_idx = 0
|
||||
elif padding == 'reflection':
|
||||
add_idx = -i
|
||||
elif padding == 'new_info':
|
||||
add_idx = (crt_i + n_pad) + (-i)
|
||||
elif padding == 'circle':
|
||||
add_idx = N + i
|
||||
else:
|
||||
raise ValueError('Wrong padding mode')
|
||||
elif i > max_n:
|
||||
if padding == 'replicate':
|
||||
add_idx = max_n
|
||||
elif padding == 'reflection':
|
||||
add_idx = max_n * 2 - i
|
||||
elif padding == 'new_info':
|
||||
add_idx = (crt_i - n_pad) - (i - max_n)
|
||||
elif padding == 'circle':
|
||||
add_idx = i - N
|
||||
else:
|
||||
raise ValueError('Wrong padding mode')
|
||||
else:
|
||||
add_idx = i
|
||||
return_l.append(add_idx)
|
||||
return return_l
|
||||
|
||||
def single_forward(model, imgs_in):
|
||||
with torch.no_grad():
|
||||
model_output = model(imgs_in)
|
||||
if isinstance(model_output, list) or isinstance(model_output, tuple):
|
||||
output = model_output[0]
|
||||
else:
|
||||
output = model_output
|
||||
return output
|
||||
|
||||
sub_folder_l = sorted(glob.glob(test_dataset_folder))
|
||||
#### set up the models
|
||||
model.load_state_dict(torch.load(model_path), strict=True)
|
||||
model.eval()
|
||||
model = model.to(device)
|
||||
|
||||
avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
|
||||
sub_folder_name_l = []
|
||||
|
||||
# for each sub-folder
|
||||
for sub_folder in sub_folder_l:
|
||||
sub_folder_name = sub_folder.split('/')[-1]
|
||||
sub_folder_name_l.append(sub_folder_name)
|
||||
save_sub_folder = osp.join(save_folder, sub_folder_name)
|
||||
|
||||
img_path_l = sorted(glob.glob(sub_folder + '/*'))
|
||||
max_idx = len(img_path_l)
|
||||
|
||||
if save_imgs:
|
||||
util.mkdirs(save_sub_folder)
|
||||
|
||||
#### read LR images
|
||||
imgs = read_seq_imgs(sub_folder)
|
||||
#### read GT images
|
||||
img_GT_l = []
|
||||
if data_mode == 'Vid4':
|
||||
sub_folder_GT = osp.join(sub_folder.replace('/BIx4up_direct/', '/GT/'), '*')
|
||||
else:
|
||||
sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
|
||||
for img_GT_path in sorted(glob.glob(sub_folder_GT)):
|
||||
img_GT_l.append(read_image(img_GT_path))
|
||||
|
||||
avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
|
||||
cal_n_border, cal_n_center = 0, 0
|
||||
|
||||
# process each image
|
||||
for img_idx, img_path in enumerate(img_path_l):
|
||||
c_idx = int(osp.splitext(osp.basename(img_path))[0])
|
||||
select_idx = index_generation(c_idx, max_idx, N_in, padding=padding)
|
||||
# get input images
|
||||
imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
|
||||
output = single_forward(model, imgs_in)
|
||||
output_f = output.data.float().cpu().squeeze(0)
|
||||
|
||||
output = util.tensor2img(output_f)
|
||||
|
||||
# save imgs
|
||||
if save_imgs:
|
||||
cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output)
|
||||
|
||||
#### calculate PSNR
|
||||
output = output / 255.
|
||||
GT = np.copy(img_GT_l[img_idx])
|
||||
# For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
|
||||
if data_mode == 'Vid4': # bgr2y, [0, 1]
|
||||
GT = data_util.bgr2ycbcr(GT)
|
||||
output = data_util.bgr2ycbcr(output)
|
||||
if crop_border == 0:
|
||||
cropped_output = output
|
||||
cropped_GT = GT
|
||||
else:
|
||||
cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
|
||||
cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
|
||||
crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
|
||||
logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr))
|
||||
|
||||
if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames
|
||||
avg_psnr_center += crt_psnr
|
||||
cal_n_center += 1
|
||||
else: # border frames
|
||||
avg_psnr_border += crt_psnr
|
||||
cal_n_border += 1
|
||||
|
||||
avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border)
|
||||
avg_psnr_center = avg_psnr_center / cal_n_center
|
||||
if cal_n_border == 0:
|
||||
avg_psnr_border = 0
|
||||
else:
|
||||
avg_psnr_border = avg_psnr_border / cal_n_border
|
||||
|
||||
logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
|
||||
'Center PSNR: {:.6f} dB for {} frames; '
|
||||
'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr,
|
||||
(cal_n_center + cal_n_border),
|
||||
avg_psnr_center, cal_n_center,
|
||||
avg_psnr_border, cal_n_border))
|
||||
|
||||
avg_psnr_l.append(avg_psnr)
|
||||
avg_psnr_center_l.append(avg_psnr_center)
|
||||
avg_psnr_border_l.append(avg_psnr_border)
|
||||
|
||||
logger.info('################ Tidy Outputs ################')
|
||||
for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
|
||||
avg_psnr_center_l, avg_psnr_border_l):
|
||||
logger.info('Folder {} - Average PSNR: {:.6f} dB. '
|
||||
'Center PSNR: {:.6f} dB. '
|
||||
'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border))
|
||||
logger.info('################ Final Results ################')
|
||||
logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
|
||||
logger.info('Padding mode: {}'.format(padding))
|
||||
logger.info('Model path: {}'.format(model_path))
|
||||
logger.info('Save images: {}'.format(save_imgs))
|
||||
logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
|
||||
'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
|
||||
sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
|
||||
sum(avg_psnr_center_l) / len(avg_psnr_center_l),
|
||||
sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()