From 4e44b8a1aad60b9a6c098869cc3115c90b39a425 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 25 May 2020 19:20:49 -0600 Subject: [PATCH] Clean up video stuff --- codes/data/Vimeo90K_dataset.py | 167 ---- codes/data/video_test_dataset.py | 84 -- codes/models/Video_base_model.py | 166 ---- codes/models/archs/DUF_arch.py | 368 -------- codes/models/archs/EDVR_arch.py | 312 ------- codes/models/archs/TOF_arch.py | 137 --- codes/models/archs/dcn/__init__.py | 7 - codes/models/archs/dcn/deform_conv.py | 291 ------ codes/models/archs/dcn/setup.py | 22 - .../models/archs/dcn/src/deform_conv_cuda.cpp | 695 -------------- .../archs/dcn/src/deform_conv_cuda_kernel.cu | 866 ------------------ codes/test_Vid4_REDS4_with_GT.py | 208 ----- codes/test_Vid4_REDS4_with_GT_DUF.py | 264 ------ codes/test_Vid4_REDS4_with_GT_TOF.py | 230 ----- 14 files changed, 3817 deletions(-) delete mode 100644 codes/data/Vimeo90K_dataset.py delete mode 100644 codes/data/video_test_dataset.py delete mode 100644 codes/models/Video_base_model.py delete mode 100644 codes/models/archs/DUF_arch.py delete mode 100644 codes/models/archs/EDVR_arch.py delete mode 100755 codes/models/archs/TOF_arch.py delete mode 100644 codes/models/archs/dcn/__init__.py delete mode 100644 codes/models/archs/dcn/deform_conv.py delete mode 100644 codes/models/archs/dcn/setup.py delete mode 100644 codes/models/archs/dcn/src/deform_conv_cuda.cpp delete mode 100644 codes/models/archs/dcn/src/deform_conv_cuda_kernel.cu delete mode 100644 codes/test_Vid4_REDS4_with_GT.py delete mode 100644 codes/test_Vid4_REDS4_with_GT_DUF.py delete mode 100644 codes/test_Vid4_REDS4_with_GT_TOF.py diff --git a/codes/data/Vimeo90K_dataset.py b/codes/data/Vimeo90K_dataset.py deleted file mode 100644 index 324e3a10..00000000 --- a/codes/data/Vimeo90K_dataset.py +++ /dev/null @@ -1,167 +0,0 @@ -''' -Vimeo90K dataset -support reading images from lmdb, image folder and memcached -''' -import os.path as osp -import random -import pickle -import logging -import numpy as np -import cv2 -import lmdb -import torch -import torch.utils.data as data -import data.util as util -try: - import mc # import memcached -except ImportError: - pass -logger = logging.getLogger('base') - - -class Vimeo90KDataset(data.Dataset): - ''' - Reading the training Vimeo90K dataset - key example: 00001_0001 (_1, ..., _7) - GT (Ground-Truth): 4th frame; - LQ (Low-Quality): support reading N LQ frames, N = 1, 3, 5, 7 centered with 4th frame - ''' - - def __init__(self, opt): - super(Vimeo90KDataset, self).__init__() - self.opt = opt - # temporal augmentation - self.interval_list = opt['interval_list'] - self.random_reverse = opt['random_reverse'] - logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format( - ','.join(str(x) for x in opt['interval_list']), self.random_reverse)) - - self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] - self.data_type = self.opt['data_type'] - self.LR_input = False if opt['target_size'] == opt['LQ_size'] else True # low resolution inputs - - #### determine the LQ frame list - ''' - N | frames - 1 | 4 - 3 | 3,4,5 - 5 | 2,3,4,5,6 - 7 | 1,2,3,4,5,6,7 - ''' - self.LQ_frames_list = [] - for i in range(opt['N_frames']): - self.LQ_frames_list.append(i + (9 - opt['N_frames']) // 2) - - #### directly load image keys - if self.data_type == 'lmdb': - self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT']) - logger.info('Using lmdb meta info for cache keys.') - elif opt['cache_keys']: - logger.info('Using cache keys: {}'.format(opt['cache_keys'])) - self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys'] - else: - raise ValueError( - 'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]') - assert self.paths_GT, 'Error: GT path is empty.' - - if self.data_type == 'lmdb': - self.GT_env, self.LQ_env = None, None - elif self.data_type == 'mc': # memcached - self.mclient = None - elif self.data_type == 'img': - pass - else: - raise ValueError('Wrong data type: {}'.format(self.data_type)) - - def _init_lmdb(self): - # https://github.com/chainer/chainermn/issues/129 - self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, - meminit=False) - self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, - meminit=False) - - def _ensure_memcached(self): - if self.mclient is None: - # specify the config files - server_list_config_file = None - client_config_file = None - self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, - client_config_file) - - def _read_img_mc(self, path): - ''' Return BGR, HWC, [0, 255], uint8''' - value = mc.pyvector() - self.mclient.Get(path, value) - value_buf = mc.ConvertBuffer(value) - img_array = np.frombuffer(value_buf, np.uint8) - img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED) - return img - - def __getitem__(self, index): - if self.data_type == 'mc': - self._ensure_memcached() - elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): - self._init_lmdb() - - scale = self.opt['scale'] - GT_size = self.opt['target_size'] - key = self.paths_GT[index] - name_a, name_b = key.split('_') - #### get the GT image (as the center frame) - if self.data_type == 'mc': - img_GT = self._read_img_mc(osp.join(self.GT_root, name_a, name_b, '4.png')) - img_GT = img_GT.astype(np.float32) / 255. - elif self.data_type == 'lmdb': - img_GT = util.read_img(self.GT_env, key + '_4', (3, 256, 448)) - else: - img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b, 'im4.png')) - - #### get LQ images - LQ_size_tuple = (3, 64, 112) if self.LR_input else (3, 256, 448) - img_LQ_l = [] - for v in self.LQ_frames_list: - if self.data_type == 'mc': - img_LQ = self._read_img_mc( - osp.join(self.LQ_root, name_a, name_b, '{}.png'.format(v))) - img_LQ = img_LQ.astype(np.float32) / 255. - elif self.data_type == 'lmdb': - img_LQ = util.read_img(self.LQ_env, key + '_{}'.format(v), LQ_size_tuple) - else: - img_LQ = util.read_img(None, - osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v))) - img_LQ_l.append(img_LQ) - - if self.opt['phase'] == 'train': - C, H, W = LQ_size_tuple # LQ size - # randomly crop - if self.LR_input: - LQ_size = GT_size // scale - rnd_h = random.randint(0, max(0, H - LQ_size)) - rnd_w = random.randint(0, max(0, W - LQ_size)) - img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l] - rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) - img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :] - else: - rnd_h = random.randint(0, max(0, H - GT_size)) - rnd_w = random.randint(0, max(0, W - GT_size)) - img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l] - img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] - - # augmentation - flip, rotate - img_LQ_l.append(img_GT) - rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot']) - img_LQ_l = rlt[0:-1] - img_GT = rlt[-1] - - # stack LQ images to NHWC, N is the frame number - img_LQs = np.stack(img_LQ_l, axis=0) - # BGR to RGB, HWC to CHW, numpy to tensor - img_GT = img_GT[:, :, [2, 1, 0]] - img_LQs = img_LQs[:, :, :, [2, 1, 0]] - img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() - img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs, - (0, 3, 1, 2)))).float() - return {'LQs': img_LQs, 'GT': img_GT, 'key': key} - - def __len__(self): - return len(self.paths_GT) diff --git a/codes/data/video_test_dataset.py b/codes/data/video_test_dataset.py deleted file mode 100644 index ce891958..00000000 --- a/codes/data/video_test_dataset.py +++ /dev/null @@ -1,84 +0,0 @@ -import os.path as osp -import torch -import torch.utils.data as data -import data.util as util - - -class VideoTestDataset(data.Dataset): - """ - A video test dataset. Support: - Vid4 - REDS4 - Vimeo90K-Test - - no need to prepare LMDB files - """ - - def __init__(self, opt): - super(VideoTestDataset, self).__init__() - self.opt = opt - self.cache_data = opt['cache_data'] - self.half_N_frames = opt['N_frames'] // 2 - self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] - self.data_type = self.opt['data_type'] - self.data_info = {'path_LQ': [], 'path_GT': [], 'folder': [], 'idx': [], 'border': []} - if self.data_type == 'lmdb': - raise ValueError('No need to use LMDB during validation/test.') - #### Generate data info and cache data - self.imgs_LQ, self.imgs_GT = {}, {} - if opt['name'].lower() in ['vid4', 'reds4']: - subfolders_LQ = util.glob_file_list(self.LQ_root) - subfolders_GT = util.glob_file_list(self.GT_root) - for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT): - subfolder_name = osp.basename(subfolder_GT) - img_paths_LQ = util.glob_file_list(subfolder_LQ) - img_paths_GT = util.glob_file_list(subfolder_GT) - max_idx = len(img_paths_LQ) - assert max_idx == len( - img_paths_GT), 'Different number of images in LQ and GT folders' - self.data_info['path_LQ'].extend(img_paths_LQ) - self.data_info['path_GT'].extend(img_paths_GT) - self.data_info['folder'].extend([subfolder_name] * max_idx) - for i in range(max_idx): - self.data_info['idx'].append('{}/{}'.format(i, max_idx)) - border_l = [0] * max_idx - for i in range(self.half_N_frames): - border_l[i] = 1 - border_l[max_idx - i - 1] = 1 - self.data_info['border'].extend(border_l) - - if self.cache_data: - self.imgs_LQ[subfolder_name] = util.read_img_seq(img_paths_LQ) - self.imgs_GT[subfolder_name] = util.read_img_seq(img_paths_GT) - elif opt['name'].lower() in ['vimeo90k-test']: - pass # TODO - else: - raise ValueError( - 'Not support video test dataset. Support Vid4, REDS4 and Vimeo90k-Test.') - - def __getitem__(self, index): - # path_LQ = self.data_info['path_LQ'][index] - # path_GT = self.data_info['path_GT'][index] - folder = self.data_info['folder'][index] - idx, max_idx = self.data_info['idx'][index].split('/') - idx, max_idx = int(idx), int(max_idx) - border = self.data_info['border'][index] - - if self.cache_data: - select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'], - padding=self.opt['padding']) - imgs_LQ = self.imgs_LQ[folder].index_select(0, torch.LongTensor(select_idx)) - img_GT = self.imgs_GT[folder][idx] - else: - pass # TODO - - return { - 'LQs': imgs_LQ, - 'GT': img_GT, - 'folder': folder, - 'idx': self.data_info['idx'][index], - 'border': border - } - - def __len__(self): - return len(self.data_info['path_GT']) diff --git a/codes/models/Video_base_model.py b/codes/models/Video_base_model.py deleted file mode 100644 index eb85fc5c..00000000 --- a/codes/models/Video_base_model.py +++ /dev/null @@ -1,166 +0,0 @@ -import logging -from collections import OrderedDict - -import torch -import torch.nn as nn -from torch.nn.parallel import DataParallel, DistributedDataParallel -import models.networks as networks -import models.lr_scheduler as lr_scheduler -from .base_model import BaseModel -from models.loss import CharbonnierLoss - -logger = logging.getLogger('base') - - -class VideoBaseModel(BaseModel): - def __init__(self, opt): - super(VideoBaseModel, self).__init__(opt) - - if opt['dist']: - self.rank = torch.distributed.get_rank() - else: - self.rank = -1 # non dist training - train_opt = opt['train'] - - # define network and load pretrained models - self.netG = networks.define_G(opt).to(self.device) - if opt['dist']: - self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) - else: - self.netG = DataParallel(self.netG) - # print network - self.print_network() - self.load() - - if self.is_train: - self.netG.train() - - #### loss - loss_type = train_opt['pixel_criterion'] - if loss_type == 'l1': - self.cri_pix = nn.L1Loss(reduction='sum').to(self.device) - elif loss_type == 'l2': - self.cri_pix = nn.MSELoss(reduction='sum').to(self.device) - elif loss_type == 'cb': - self.cri_pix = CharbonnierLoss().to(self.device) - else: - raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) - self.l_pix_w = train_opt['pixel_weight'] - - #### optimizers - wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 - if train_opt['ft_tsa_only']: - normal_params = [] - tsa_fusion_params = [] - for k, v in self.netG.named_parameters(): - if v.requires_grad: - if 'tsa_fusion' in k: - tsa_fusion_params.append(v) - else: - normal_params.append(v) - else: - if self.rank <= 0: - logger.warning('Params [{:s}] will not optimize.'.format(k)) - optim_params = [ - { # add normal params first - 'params': normal_params, - 'lr': train_opt['lr_G'] - }, - { - 'params': tsa_fusion_params, - 'lr': train_opt['lr_G'] - }, - ] - else: - optim_params = [] - for k, v in self.netG.named_parameters(): - if v.requires_grad: - optim_params.append(v) - else: - if self.rank <= 0: - logger.warning('Params [{:s}] will not optimize.'.format(k)) - - self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], - weight_decay=wd_G, - betas=(train_opt['beta1'], train_opt['beta2'])) - self.optimizers.append(self.optimizer_G) - - #### schedulers - if train_opt['lr_scheme'] == 'MultiStepLR': - for optimizer in self.optimizers: - self.schedulers.append( - lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], - restarts=train_opt['restarts'], - weights=train_opt['restart_weights'], - gamma=train_opt['lr_gamma'], - clear_state=train_opt['clear_state'])) - elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': - for optimizer in self.optimizers: - self.schedulers.append( - lr_scheduler.CosineAnnealingLR_Restart( - optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], - restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) - else: - raise NotImplementedError() - - self.log_dict = OrderedDict() - - def feed_data(self, data, need_GT=True): - self.var_L = data['LQs'].to(self.device) - if need_GT: - self.real_H = data['GT'].to(self.device) - - def set_params_lr_zero(self): - # fix normal module - self.optimizers[0].param_groups[0]['lr'] = 0 - - def optimize_parameters(self, step): - if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']: - self.set_params_lr_zero() - - self.optimizer_G.zero_grad() - self.fake_H = self.netG(self.var_L) - - l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) - l_pix.backward() - self.optimizer_G.step() - - # set log - self.log_dict['l_pix'] = l_pix.item() - - def test(self): - self.netG.eval() - with torch.no_grad(): - self.fake_H = self.netG(self.var_L) - self.netG.train() - - def get_current_log(self): - return self.log_dict - - def get_current_visuals(self, need_GT=True): - out_dict = OrderedDict() - out_dict['LQ'] = self.var_L.detach()[0].float().cpu() - out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() - if need_GT: - out_dict['GT'] = self.real_H.detach()[0].float().cpu() - return out_dict - - def print_network(self): - s, n = self.get_network_description(self.netG) - if isinstance(self.netG, nn.DataParallel): - net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, - self.netG.module.__class__.__name__) - else: - net_struc_str = '{}'.format(self.netG.__class__.__name__) - if self.rank <= 0: - logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) - logger.info(s) - - def load(self): - load_path_G = self.opt['path']['pretrain_model_G'] - if load_path_G is not None: - logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) - self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) - - def save(self, iter_label): - self.save_network(self.netG, 'G', iter_label) diff --git a/codes/models/archs/DUF_arch.py b/codes/models/archs/DUF_arch.py deleted file mode 100644 index 319cea87..00000000 --- a/codes/models/archs/DUF_arch.py +++ /dev/null @@ -1,368 +0,0 @@ -'''Network architecture for DUF: -Deep Video Super-Resolution Network Using Dynamic Upsampling Filters -Without Explicit Motion Compensation (CVPR18) -https://github.com/yhjo09/VSR-DUF - -For all the models below, [adapt_official] is only necessary when -loading the weights converted from the official TensorFlow weights. -Please set it to [False] if you are training the model from scratch. -''' - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def adapt_official(Rx, scale=4): - '''Adapt the weights translated from the official tensorflow weights - Not necessary if you are training from scratch''' - x = Rx.clone() - x1 = x[:, ::3, :, :] - x2 = x[:, 1::3, :, :] - x3 = x[:, 2::3, :, :] - - Rx[:, :scale**2, :, :] = x1 - Rx[:, scale**2:2 * (scale**2), :, :] = x2 - Rx[:, 2 * (scale**2):, :, :] = x3 - - return Rx - - -class DenseBlock(nn.Module): - '''Dense block - for the second denseblock, t_reduced = True''' - - def __init__(self, nf=64, ng=32, t_reduce=False): - super(DenseBlock, self).__init__() - self.t_reduce = t_reduce - if self.t_reduce: - pad = (0, 1, 1) - else: - pad = (1, 1, 1) - self.bn3d_1 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3) - self.conv3d_1 = nn.Conv3d(nf, nf, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True) - self.bn3d_2 = nn.BatchNorm3d(nf, eps=1e-3, momentum=1e-3) - self.conv3d_2 = nn.Conv3d(nf, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True) - self.bn3d_3 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3) - self.conv3d_3 = nn.Conv3d(nf + ng, nf + ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), - bias=True) - self.bn3d_4 = nn.BatchNorm3d(nf + ng, eps=1e-3, momentum=1e-3) - self.conv3d_4 = nn.Conv3d(nf + ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True) - self.bn3d_5 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3) - self.conv3d_5 = nn.Conv3d(nf + 2 * ng, nf + 2 * ng, (1, 1, 1), stride=(1, 1, 1), - padding=(0, 0, 0), bias=True) - self.bn3d_6 = nn.BatchNorm3d(nf + 2 * ng, eps=1e-3, momentum=1e-3) - self.conv3d_6 = nn.Conv3d(nf + 2 * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, - bias=True) - - def forward(self, x): - '''x: [B, C, T, H, W] - C: nf -> nf + 3 * ng - T: 1) 7 -> 7 (t_reduce=False); - 2) 7 -> 7 - 2 * 3 = 1 (t_reduce=True)''' - x1 = self.conv3d_1(F.relu(self.bn3d_1(x), inplace=True)) - x1 = self.conv3d_2(F.relu(self.bn3d_2(x1), inplace=True)) - if self.t_reduce: - x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1) - else: - x1 = torch.cat((x, x1), 1) - - x2 = self.conv3d_3(F.relu(self.bn3d_3(x1), inplace=True)) - x2 = self.conv3d_4(F.relu(self.bn3d_4(x2), inplace=True)) - if self.t_reduce: - x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1) - else: - x2 = torch.cat((x1, x2), 1) - - x3 = self.conv3d_5(F.relu(self.bn3d_5(x2), inplace=True)) - x3 = self.conv3d_6(F.relu(self.bn3d_6(x3), inplace=True)) - if self.t_reduce: - x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1) - else: - x3 = torch.cat((x2, x3), 1) - return x3 - - -class DynamicUpsamplingFilter_3C(nn.Module): - '''dynamic upsampling filter with 3 channels applying the same filters - filter_size: filter size of the generated filters, shape (C, kH, kW)''' - - def __init__(self, filter_size=(1, 5, 5)): - super(DynamicUpsamplingFilter_3C, self).__init__() - # generate a local expansion filter, used similar to im2col - nF = np.prod(filter_size) - expand_filter_np = np.reshape(np.eye(nF, nF), - (nF, filter_size[0], filter_size[1], filter_size[2])) - expand_filter = torch.from_numpy(expand_filter_np).float() - self.expand_filter = torch.cat((expand_filter, expand_filter, expand_filter), - 0) # [75, 1, 5, 5] - - def forward(self, x, filters): - '''x: input image, [B, 3, H, W] - filters: generate dynamic filters, [B, F, R, H, W], e.g., [B, 25, 16, H, W] - F: prod of filter kernel size, e.g., 5*5 = 25 - R: used for upsampling, similar to pixel shuffle, e.g., 4*4 = 16 for x4 - Return: filtered image, [B, 3*R, H, W] - ''' - B, nF, R, H, W = filters.size() - # using group convolution - input_expand = F.conv2d(x, self.expand_filter.type_as(x), padding=2, - groups=3) # [B, 75, H, W] similar to im2col - input_expand = input_expand.view(B, 3, nF, H, W).permute(0, 3, 4, 1, 2) # [B, H, W, 3, 25] - filters = filters.permute(0, 3, 4, 1, 2) # [B, H, W, 25, 16] - out = torch.matmul(input_expand, filters) # [B, H, W, 3, 16] - return out.permute(0, 3, 4, 1, 2).view(B, 3 * R, H, W) # [B, 3*16, H, W] - - -class DUF_16L(nn.Module): - '''Official DUF structure with 16 layers''' - - def __init__(self, scale=4, adapt_official=False): - super(DUF_16L, self).__init__() - self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True) - self.dense_block_1 = DenseBlock(64, 64 // 2, t_reduce=False) # 64 + 32 * 3 = 160, T = 7 - self.dense_block_2 = DenseBlock(160, 64 // 2, t_reduce=True) # 160 + 32 * 3 = 256, T = 1 - self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3) - self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), - bias=True) - - self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), - bias=True) - self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), - padding=(0, 0, 0), bias=True) - - self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), - bias=True) - self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), - padding=(0, 0, 0), bias=True) - - self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5)) - - self.scale = scale - self.adapt_official = adapt_official - - def forward(self, x): - ''' - x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D - Generate filters and image residual: - Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C - Rx: [B, 3*16, 1, H, W] - ''' - B, T, C, H, W = x.size() - x = x.permute(0, 2, 1, 3, 4) # [B, C, T, H, W] for Conv3D - x_center = x[:, :, T // 2, :, :] - - x = self.conv3d_1(x) - x = self.dense_block_1(x) - x = self.dense_block_2(x) # reduce T to 1 - x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True) - - # image residual - Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W] - - # filter - Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W] - Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1) - - # Adapt to official model weights - if self.adapt_official: - adapt_official(Rx, scale=self.scale) - - # dynamic filter - out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W] - out += Rx.squeeze_(2) - out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W] - - return out - - -class DenseBlock_28L(nn.Module): - '''The first part of the dense blocks used in DUF_28L - Temporal dimension remains the same here''' - - def __init__(self, nf=64, ng=16): - super(DenseBlock_28L, self).__init__() - pad = (1, 1, 1) - - dense_block_l = [] - for i in range(0, 9): - dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3)) - dense_block_l.append(nn.ReLU()) - dense_block_l.append( - nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), - bias=True)) - - dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3)) - dense_block_l.append(nn.ReLU()) - dense_block_l.append( - nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)) - - self.dense_blocks = nn.ModuleList(dense_block_l) - - def forward(self, x): - '''x: [B, C, T, H, W] - C: 1) 64 -> 208; - T: 1) 7 -> 7; (t_reduce=True)''' - for i in range(0, len(self.dense_blocks), 6): - y = x - for j in range(6): - y = self.dense_blocks[i + j](y) - x = torch.cat((x, y), 1) - return x - - -class DUF_28L(nn.Module): - '''Official DUF structure with 28 layers''' - - def __init__(self, scale=4, adapt_official=False): - super(DUF_28L, self).__init__() - self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True) - self.dense_block_1 = DenseBlock_28L(64, 16) # 64 + 16 * 9 = 208, T = 7 - self.dense_block_2 = DenseBlock(208, 16, t_reduce=True) # 208 + 16 * 3 = 256, T = 1 - self.bn3d_2 = nn.BatchNorm3d(256, eps=1e-3, momentum=1e-3) - self.conv3d_2 = nn.Conv3d(256, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), - bias=True) - - self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), - bias=True) - self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), - padding=(0, 0, 0), bias=True) - - self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), - bias=True) - self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), - padding=(0, 0, 0), bias=True) - - self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5)) - - self.scale = scale - self.adapt_official = adapt_official - - def forward(self, x): - ''' - x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D - Generate filters and image residual: - Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C - Rx: [B, 3*16, 1, H, W] - ''' - B, T, C, H, W = x.size() - x = x.permute(0, 2, 1, 3, 4) # [B,C,T,H,W] for Conv3D - x_center = x[:, :, T // 2, :, :] - x = self.conv3d_1(x) - x = self.dense_block_1(x) - x = self.dense_block_2(x) # reduce T to 1 - x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True) - - # image residual - Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W] - - # filter - Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W] - Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1) - - # Adapt to official model weights - if self.adapt_official: - adapt_official(Rx, scale=self.scale) - - # dynamic filter - out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W] - out += Rx.squeeze_(2) - out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W] - return out - - -class DenseBlock_52L(nn.Module): - '''The first part of the dense blocks used in DUF_52L - Temporal dimension remains the same here''' - - def __init__(self, nf=64, ng=16): - super(DenseBlock_52L, self).__init__() - pad = (1, 1, 1) - - dense_block_l = [] - for i in range(0, 21): - dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3)) - dense_block_l.append(nn.ReLU()) - dense_block_l.append( - nn.Conv3d(nf + i * ng, nf + i * ng, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), - bias=True)) - - dense_block_l.append(nn.BatchNorm3d(nf + i * ng, eps=1e-3, momentum=1e-3)) - dense_block_l.append(nn.ReLU()) - dense_block_l.append( - nn.Conv3d(nf + i * ng, ng, (3, 3, 3), stride=(1, 1, 1), padding=pad, bias=True)) - - self.dense_blocks = nn.ModuleList(dense_block_l) - - def forward(self, x): - '''x: [B, C, T, H, W] - C: 1) 64 -> 400; - T: 1) 7 -> 7; (t_reduce=True)''' - for i in range(0, len(self.dense_blocks), 6): - y = x - for j in range(6): - y = self.dense_blocks[i + j](y) - x = torch.cat((x, y), 1) - return x - - -class DUF_52L(nn.Module): - '''Official DUF structure with 52 layers''' - - def __init__(self, scale=4, adapt_official=False): - super(DUF_52L, self).__init__() - self.conv3d_1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True) - self.dense_block_1 = DenseBlock_52L(64, 16) # 64 + 21 * 9 = 400, T = 7 - self.dense_block_2 = DenseBlock(400, 16, t_reduce=True) # 400 + 16 * 3 = 448, T = 1 - - self.bn3d_2 = nn.BatchNorm3d(448, eps=1e-3, momentum=1e-3) - self.conv3d_2 = nn.Conv3d(448, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), - bias=True) - - self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), - bias=True) - self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), - padding=(0, 0, 0), bias=True) - - self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), - bias=True) - self.conv3d_f2 = nn.Conv3d(512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), - padding=(0, 0, 0), bias=True) - - self.dynamic_filter = DynamicUpsamplingFilter_3C((1, 5, 5)) - - self.scale = scale - self.adapt_official = adapt_official - - def forward(self, x): - ''' - x: [B, T, C, H, W], T = 7. reshape to [B, C, T, H, W] for Conv3D - Generate filters and image residual: - Fx: [B, 25, 16, H, W] for DynamicUpsamplingFilter_3C - Rx: [B, 3*16, 1, H, W] - ''' - B, T, C, H, W = x.size() - x = x.permute(0, 2, 1, 3, 4) # [B,C,T,H,W] for Conv3D - x_center = x[:, :, T // 2, :, :] - x = self.conv3d_1(x) - x = self.dense_block_1(x) - x = self.dense_block_2(x) - x = F.relu(self.conv3d_2(F.relu(self.bn3d_2(x), inplace=True)), inplace=True) - - # image residual - Rx = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) # [B, 3*16, 1, H, W] - - # filter - Fx = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) # [B, 25*16, 1, H, W] - Fx = F.softmax(Fx.view(B, 25, self.scale**2, H, W), dim=1) - - # Adapt to official model weights - if self.adapt_official: - adapt_official(Rx, scale=self.scale) - - # dynamic filter - out = self.dynamic_filter(x_center, Fx) # [B, 3*R, H, W] - out += Rx.squeeze_(2) - out = F.pixel_shuffle(out, self.scale) # [B, 3, H, W] - return out diff --git a/codes/models/archs/EDVR_arch.py b/codes/models/archs/EDVR_arch.py deleted file mode 100644 index df9c0325..00000000 --- a/codes/models/archs/EDVR_arch.py +++ /dev/null @@ -1,312 +0,0 @@ -''' network architecture for EDVR ''' -import functools -import torch -import torch.nn as nn -import torch.nn.functional as F -import models.archs.arch_util as arch_util -try: - from models.archs.dcn.deform_conv import ModulatedDeformConvPack as DCN -except ImportError: - raise ImportError('Failed to import DCNv2 module.') - - -class Predeblur_ResNet_Pyramid(nn.Module): - def __init__(self, nf=128, HR_in=False): - ''' - HR_in: True if the inputs are high spatial size - ''' - - super(Predeblur_ResNet_Pyramid, self).__init__() - self.HR_in = True if HR_in else False - if self.HR_in: - self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True) - self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) - self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) - else: - self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True) - basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) - self.RB_L1_1 = basic_block() - self.RB_L1_2 = basic_block() - self.RB_L1_3 = basic_block() - self.RB_L1_4 = basic_block() - self.RB_L1_5 = basic_block() - self.RB_L2_1 = basic_block() - self.RB_L2_2 = basic_block() - self.RB_L3_1 = basic_block() - self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) - self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - - def forward(self, x): - if self.HR_in: - L1_fea = self.lrelu(self.conv_first_1(x)) - L1_fea = self.lrelu(self.conv_first_2(L1_fea)) - L1_fea = self.lrelu(self.conv_first_3(L1_fea)) - else: - L1_fea = self.lrelu(self.conv_first(x)) - L2_fea = self.lrelu(self.deblur_L2_conv(L1_fea)) - L3_fea = self.lrelu(self.deblur_L3_conv(L2_fea)) - L3_fea = F.interpolate(self.RB_L3_1(L3_fea), scale_factor=2, mode='bilinear', - align_corners=False) - L2_fea = self.RB_L2_1(L2_fea) + L3_fea - L2_fea = F.interpolate(self.RB_L2_2(L2_fea), scale_factor=2, mode='bilinear', - align_corners=False) - L1_fea = self.RB_L1_2(self.RB_L1_1(L1_fea)) + L2_fea - out = self.RB_L1_5(self.RB_L1_4(self.RB_L1_3(L1_fea))) - return out - - -class PCD_Align(nn.Module): - ''' Alignment module using Pyramid, Cascading and Deformable convolution - with 3 pyramid levels. - ''' - - def __init__(self, nf=64, groups=8): - super(PCD_Align, self).__init__() - # L3: level 3, 1/4 spatial size - self.L3_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff - self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, - extra_offset_mask=True) - # L2: level 2, 1/2 spatial size - self.L2_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff - self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset - self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, - extra_offset_mask=True) - self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea - # L1: level 1, original spatial size - self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff - self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset - self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, - extra_offset_mask=True) - self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea - # Cascading DCN - self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff - self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - - self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, - extra_offset_mask=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - - def forward(self, nbr_fea_l, ref_fea_l): - '''align other neighboring frames to the reference frame in the feature level - nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features - ''' - # L3 - L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1) - L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset)) - L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset)) - L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset])) - # L2 - L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1) - L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset)) - L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False) - L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1))) - L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset)) - L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset]) - L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False) - L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1))) - # L1 - L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1) - L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset)) - L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False) - L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1))) - L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset)) - L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset]) - L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False) - L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1)) - # Cascading - offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1) - offset = self.lrelu(self.cas_offset_conv1(offset)) - offset = self.lrelu(self.cas_offset_conv2(offset)) - L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset])) - - return L1_fea - - -class TSA_Fusion(nn.Module): - ''' Temporal Spatial Attention fusion module - Temporal: correlation; - Spatial: 3 pyramid levels. - ''' - - def __init__(self, nf=64, nframes=5, center=2): - super(TSA_Fusion, self).__init__() - self.center = center - # temporal attention (before fusion conv) - self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - - # fusion conv: using 1x1 to save parameters and computation - self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True) - - # spatial attention (after fusion conv) - self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True) - self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) - self.avgpool = nn.AvgPool2d(3, stride=2, padding=1) - self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True) - self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True) - self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True) - self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) - self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True) - self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - - def forward(self, aligned_fea): - B, N, C, H, W = aligned_fea.size() # N video frames - #### temporal attention - emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone()) - emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W) # [B, N, C(nf), H, W] - - cor_l = [] - for i in range(N): - emb_nbr = emb[:, i, :, :, :] - cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1) # B, 1, H, W - cor_l.append(cor_tmp) - cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1)) # B, N, H, W - cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W) - aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob - - #### fusion - fea = self.lrelu(self.fea_fusion(aligned_fea)) - - #### spatial attention - att = self.lrelu(self.sAtt_1(aligned_fea)) - att_max = self.maxpool(att) - att_avg = self.avgpool(att) - att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1))) - # pyramid levels - att_L = self.lrelu(self.sAtt_L1(att)) - att_max = self.maxpool(att_L) - att_avg = self.avgpool(att_L) - att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1))) - att_L = self.lrelu(self.sAtt_L3(att_L)) - att_L = F.interpolate(att_L, scale_factor=2, mode='bilinear', align_corners=False) - - att = self.lrelu(self.sAtt_3(att)) - att = att + att_L - att = self.lrelu(self.sAtt_4(att)) - att = F.interpolate(att, scale_factor=2, mode='bilinear', align_corners=False) - att = self.sAtt_5(att) - att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att))) - att = torch.sigmoid(att) - - fea = fea * att * 2 + att_add - return fea - - -class EDVR(nn.Module): - def __init__(self, nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10, center=None, - predeblur=False, HR_in=False, w_TSA=True): - super(EDVR, self).__init__() - self.nf = nf - self.center = nframes // 2 if center is None else center - self.is_predeblur = True if predeblur else False - self.HR_in = True if HR_in else False - self.w_TSA = w_TSA - ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) - - #### extract features (for each frame) - if self.is_predeblur: - self.pre_deblur = Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in) - self.conv_1x1 = nn.Conv2d(nf, nf, 1, 1, bias=True) - else: - if self.HR_in: - self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True) - self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) - self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) - else: - self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True) - self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, front_RBs) - self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) - self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) - self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - - self.pcd_align = PCD_Align(nf=nf, groups=groups) - if self.w_TSA: - self.tsa_fusion = TSA_Fusion(nf=nf, nframes=nframes, center=self.center) - else: - self.tsa_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True) - - #### reconstruction - self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, back_RBs) - #### upsampling - self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) - self.upconv2 = nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True) - self.pixel_shuffle = nn.PixelShuffle(2) - self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True) - self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True) - - #### activation function - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - - def forward(self, x): - B, N, C, H, W = x.size() # N video frames - x_center = x[:, self.center, :, :, :].contiguous() - - #### extract LR features - # L1 - if self.is_predeblur: - L1_fea = self.pre_deblur(x.view(-1, C, H, W)) - L1_fea = self.conv_1x1(L1_fea) - if self.HR_in: - H, W = H // 4, W // 4 - else: - if self.HR_in: - L1_fea = self.lrelu(self.conv_first_1(x.view(-1, C, H, W))) - L1_fea = self.lrelu(self.conv_first_2(L1_fea)) - L1_fea = self.lrelu(self.conv_first_3(L1_fea)) - H, W = H // 4, W // 4 - else: - L1_fea = self.lrelu(self.conv_first(x.view(-1, C, H, W))) - L1_fea = self.feature_extraction(L1_fea) - # L2 - L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea)) - L2_fea = self.lrelu(self.fea_L2_conv2(L2_fea)) - # L3 - L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea)) - L3_fea = self.lrelu(self.fea_L3_conv2(L3_fea)) - - L1_fea = L1_fea.view(B, N, -1, H, W) - L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2) - L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4) - - #### pcd align - # ref feature list - ref_fea_l = [ - L1_fea[:, self.center, :, :, :].clone(), L2_fea[:, self.center, :, :, :].clone(), - L3_fea[:, self.center, :, :, :].clone() - ] - aligned_fea = [] - for i in range(N): - nbr_fea_l = [ - L1_fea[:, i, :, :, :].clone(), L2_fea[:, i, :, :, :].clone(), - L3_fea[:, i, :, :, :].clone() - ] - aligned_fea.append(self.pcd_align(nbr_fea_l, ref_fea_l)) - aligned_fea = torch.stack(aligned_fea, dim=1) # [B, N, C, H, W] - - if not self.w_TSA: - aligned_fea = aligned_fea.view(B, -1, H, W) - fea = self.tsa_fusion(aligned_fea) - - out = self.recon_trunk(fea) - out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) - out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) - out = self.lrelu(self.HRconv(out)) - out = self.conv_last(out) - if self.HR_in: - base = x_center - else: - base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False) - out += base - return out diff --git a/codes/models/archs/TOF_arch.py b/codes/models/archs/TOF_arch.py deleted file mode 100755 index 02d7a914..00000000 --- a/codes/models/archs/TOF_arch.py +++ /dev/null @@ -1,137 +0,0 @@ -'''PyTorch implementation of TOFlow -Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018 -Code reference: -1. https://github.com/anchen1011/toflow -2. https://github.com/Coldog2333/pytoflow -''' - -import torch -import torch.nn as nn -from .arch_util import flow_warp - - -def normalize(x): - mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x) - std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x) - return (x - mean) / std - - -def denormalize(x): - mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x) - std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x) - return x * std + mean - - -class SpyNet_Block(nn.Module): - '''A submodule of SpyNet.''' - - def __init__(self): - super(SpyNet_Block, self).__init__() - - self.block = nn.Sequential( - nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), - nn.BatchNorm2d(32), nn.ReLU(inplace=True), - nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), - nn.BatchNorm2d(64), nn.ReLU(inplace=True), - nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), - nn.BatchNorm2d(32), nn.ReLU(inplace=True), - nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), - nn.BatchNorm2d(16), nn.ReLU(inplace=True), - nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) - - def forward(self, x): - ''' - input: x: [ref im, nbr im, initial flow] - (B, 8, H, W) - output: estimated flow - (B, 2, H, W) - ''' - return self.block(x) - - -class SpyNet(nn.Module): - '''SpyNet for estimating optical flow - Ranjan et al., Optical Flow Estimation using a Spatial Pyramid Network, 2016''' - - def __init__(self): - super(SpyNet, self).__init__() - - self.blocks = nn.ModuleList([SpyNet_Block() for _ in range(4)]) - - def forward(self, ref, nbr): - '''Estimating optical flow in coarse level, upsample, and estimate in fine level - input: ref: reference image - [B, 3, H, W] - nbr: the neighboring image to be warped - [B, 3, H, W] - output: estimated optical flow - [B, 2, H, W] - ''' - B, C, H, W = ref.size() - ref = [ref] - nbr = [nbr] - - for _ in range(3): - ref.insert( - 0, - nn.functional.avg_pool2d(input=ref[0], kernel_size=2, stride=2, - count_include_pad=False)) - nbr.insert( - 0, - nn.functional.avg_pool2d(input=nbr[0], kernel_size=2, stride=2, - count_include_pad=False)) - - flow = torch.zeros(B, 2, H // 16, W // 16).type_as(ref[0]) - - for i in range(4): - flow_up = nn.functional.interpolate(input=flow, scale_factor=2, mode='bilinear', - align_corners=True) * 2.0 - flow = flow_up + self.blocks[i](torch.cat( - [ref[i], flow_warp(nbr[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1)) - return flow - - -class TOFlow(nn.Module): - def __init__(self, adapt_official=False): - super(TOFlow, self).__init__() - - self.SpyNet = SpyNet() - - self.conv_3x7_64_9x9 = nn.Conv2d(3 * 7, 64, 9, 1, 4) - self.conv_64_64_9x9 = nn.Conv2d(64, 64, 9, 1, 4) - self.conv_64_64_1x1 = nn.Conv2d(64, 64, 1) - self.conv_64_3_1x1 = nn.Conv2d(64, 3, 1) - - self.relu = nn.ReLU(inplace=True) - - self.adapt_official = adapt_official # True if using translated official weights else False - - def forward(self, x): - """ - input: x: input frames - [B, 7, 3, H, W] - output: SR reference frame - [B, 3, H, W] - """ - - B, T, C, H, W = x.size() - x = normalize(x.view(-1, C, H, W)).view(B, T, C, H, W) - - ref_idx = 3 - x_ref = x[:, ref_idx, :, :, :] - - # In the official torch code, the 0-th frame is the reference frame - if self.adapt_official: - x = x[:, [3, 0, 1, 2, 4, 5, 6], :, :, :] - ref_idx = 0 - - x_warped = [] - for i in range(7): - if i == ref_idx: - x_warped.append(x_ref) - else: - x_nbr = x[:, i, :, :, :] - flow = self.SpyNet(x_ref, x_nbr).permute(0, 2, 3, 1) - x_warped.append(flow_warp(x_nbr, flow)) - x_warped = torch.stack(x_warped, dim=1) - - x = x_warped.view(B, -1, H, W) - x = self.relu(self.conv_3x7_64_9x9(x)) - x = self.relu(self.conv_64_64_9x9(x)) - x = self.relu(self.conv_64_64_1x1(x)) - x = self.conv_64_3_1x1(x) + x_ref - - return denormalize(x) diff --git a/codes/models/archs/dcn/__init__.py b/codes/models/archs/dcn/__init__.py deleted file mode 100644 index 1c85e1f0..00000000 --- a/codes/models/archs/dcn/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, - deform_conv, modulated_deform_conv) - -__all__ = [ - 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', - 'modulated_deform_conv' -] diff --git a/codes/models/archs/dcn/deform_conv.py b/codes/models/archs/dcn/deform_conv.py deleted file mode 100644 index f97cb1c8..00000000 --- a/codes/models/archs/dcn/deform_conv.py +++ /dev/null @@ -1,291 +0,0 @@ -import math -import logging - -import torch -import torch.nn as nn -from torch.autograd import Function -from torch.autograd.function import once_differentiable -from torch.nn.modules.utils import _pair - -from . import deform_conv_cuda - -logger = logging.getLogger('base') - - -class DeformConvFunction(Function): - @staticmethod - def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1, - deformable_groups=1, im2col_step=64): - if input is not None and input.dim() != 4: - raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format( - input.dim())) - ctx.stride = _pair(stride) - ctx.padding = _pair(padding) - ctx.dilation = _pair(dilation) - ctx.groups = groups - ctx.deformable_groups = deformable_groups - ctx.im2col_step = im2col_step - - ctx.save_for_backward(input, offset, weight) - - output = input.new_empty( - DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) - - ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones - - if not input.is_cuda: - raise NotImplementedError - else: - cur_im2col_step = min(ctx.im2col_step, input.shape[0]) - assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' - deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output, - ctx.bufs_[0], ctx.bufs_[1], weight.size(3), - weight.size(2), ctx.stride[1], ctx.stride[0], - ctx.padding[1], ctx.padding[0], - ctx.dilation[1], ctx.dilation[0], ctx.groups, - ctx.deformable_groups, cur_im2col_step) - return output - - @staticmethod - @once_differentiable - def backward(ctx, grad_output): - input, offset, weight = ctx.saved_tensors - - grad_input = grad_offset = grad_weight = None - - if not grad_output.is_cuda: - raise NotImplementedError - else: - cur_im2col_step = min(ctx.im2col_step, input.shape[0]) - assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' - - if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: - grad_input = torch.zeros_like(input) - grad_offset = torch.zeros_like(offset) - deform_conv_cuda.deform_conv_backward_input_cuda( - input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0], - weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], - ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, - ctx.deformable_groups, cur_im2col_step) - - if ctx.needs_input_grad[2]: - grad_weight = torch.zeros_like(weight) - deform_conv_cuda.deform_conv_backward_parameters_cuda( - input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1], - weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], - ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, - ctx.deformable_groups, 1, cur_im2col_step) - - return (grad_input, grad_offset, grad_weight, None, None, None, None, None) - - @staticmethod - def _output_size(input, weight, padding, dilation, stride): - channels = weight.size(0) - output_size = (input.size(0), channels) - for d in range(input.dim() - 2): - in_size = input.size(d + 2) - pad = padding[d] - kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 - stride_ = stride[d] - output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) - if not all(map(lambda s: s > 0, output_size)): - raise ValueError("convolution input is too small (output would be {})".format('x'.join( - map(str, output_size)))) - return output_size - - -class ModulatedDeformConvFunction(Function): - @staticmethod - def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1, - groups=1, deformable_groups=1): - ctx.stride = stride - ctx.padding = padding - ctx.dilation = dilation - ctx.groups = groups - ctx.deformable_groups = deformable_groups - ctx.with_bias = bias is not None - if not ctx.with_bias: - bias = input.new_empty(1) # fake tensor - if not input.is_cuda: - raise NotImplementedError - if weight.requires_grad or mask.requires_grad or offset.requires_grad \ - or input.requires_grad: - ctx.save_for_backward(input, offset, mask, weight, bias) - output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) - ctx._bufs = [input.new_empty(0), input.new_empty(0)] - deform_conv_cuda.modulated_deform_conv_cuda_forward( - input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2], - weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, - ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias) - return output - - @staticmethod - @once_differentiable - def backward(ctx, grad_output): - if not grad_output.is_cuda: - raise NotImplementedError - input, offset, mask, weight, bias = ctx.saved_tensors - grad_input = torch.zeros_like(input) - grad_offset = torch.zeros_like(offset) - grad_mask = torch.zeros_like(mask) - grad_weight = torch.zeros_like(weight) - grad_bias = torch.zeros_like(bias) - deform_conv_cuda.modulated_deform_conv_cuda_backward( - input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight, - grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3], - ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, - ctx.groups, ctx.deformable_groups, ctx.with_bias) - if not ctx.with_bias: - grad_bias = None - - return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, - None) - - @staticmethod - def _infer_shape(ctx, input, weight): - n = input.size(0) - channels_out = weight.size(0) - height, width = input.shape[2:4] - kernel_h, kernel_w = weight.shape[2:4] - height_out = (height + 2 * ctx.padding - (ctx.dilation * - (kernel_h - 1) + 1)) // ctx.stride + 1 - width_out = (width + 2 * ctx.padding - (ctx.dilation * - (kernel_w - 1) + 1)) // ctx.stride + 1 - return n, channels_out, height_out, width_out - - -deform_conv = DeformConvFunction.apply -modulated_deform_conv = ModulatedDeformConvFunction.apply - - -class DeformConv(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, - groups=1, deformable_groups=1, bias=False): - super(DeformConv, self).__init__() - - assert not bias - assert in_channels % groups == 0, \ - 'in_channels {} cannot be divisible by groups {}'.format( - in_channels, groups) - assert out_channels % groups == 0, \ - 'out_channels {} cannot be divisible by groups {}'.format( - out_channels, groups) - - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = _pair(kernel_size) - self.stride = _pair(stride) - self.padding = _pair(padding) - self.dilation = _pair(dilation) - self.groups = groups - self.deformable_groups = deformable_groups - - self.weight = nn.Parameter( - torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) - - self.reset_parameters() - - def reset_parameters(self): - n = self.in_channels - for k in self.kernel_size: - n *= k - stdv = 1. / math.sqrt(n) - self.weight.data.uniform_(-stdv, stdv) - - def forward(self, x, offset): - return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, - self.groups, self.deformable_groups) - - -class DeformConvPack(DeformConv): - def __init__(self, *args, **kwargs): - super(DeformConvPack, self).__init__(*args, **kwargs) - - self.conv_offset = nn.Conv2d( - self.in_channels, - self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], - kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), - bias=True) - self.init_offset() - - def init_offset(self): - self.conv_offset.weight.data.zero_() - self.conv_offset.bias.data.zero_() - - def forward(self, x): - offset = self.conv_offset(x) - return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, - self.groups, self.deformable_groups) - - -class ModulatedDeformConv(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, - groups=1, deformable_groups=1, bias=True): - super(ModulatedDeformConv, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = _pair(kernel_size) - self.stride = stride - self.padding = padding - self.dilation = dilation - self.groups = groups - self.deformable_groups = deformable_groups - self.with_bias = bias - - self.weight = nn.Parameter( - torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - else: - self.register_parameter('bias', None) - self.reset_parameters() - - def reset_parameters(self): - n = self.in_channels - for k in self.kernel_size: - n *= k - stdv = 1. / math.sqrt(n) - self.weight.data.uniform_(-stdv, stdv) - if self.bias is not None: - self.bias.data.zero_() - - def forward(self, x, offset, mask): - return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, - self.padding, self.dilation, self.groups, - self.deformable_groups) - - -class ModulatedDeformConvPack(ModulatedDeformConv): - def __init__(self, *args, extra_offset_mask=False, **kwargs): - super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) - - self.extra_offset_mask = extra_offset_mask - self.conv_offset_mask = nn.Conv2d( - self.in_channels, - self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], - kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), - bias=True) - self.init_offset() - - def init_offset(self): - self.conv_offset_mask.weight.data.zero_() - self.conv_offset_mask.bias.data.zero_() - - def forward(self, x): - if self.extra_offset_mask: - # x = [input, features] - out = self.conv_offset_mask(x[1]) - x = x[0] - else: - out = self.conv_offset_mask(x) - o1, o2, mask = torch.chunk(out, 3, dim=1) - offset = torch.cat((o1, o2), dim=1) - mask = torch.sigmoid(mask) - - offset_mean = torch.mean(torch.abs(offset)) - if offset_mean > 100: - logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean)) - - return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, - self.padding, self.dilation, self.groups, - self.deformable_groups) diff --git a/codes/models/archs/dcn/setup.py b/codes/models/archs/dcn/setup.py deleted file mode 100644 index 094d961f..00000000 --- a/codes/models/archs/dcn/setup.py +++ /dev/null @@ -1,22 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - - -def make_cuda_ext(name, sources): - - return CUDAExtension( - name='{}'.format(name), sources=[p for p in sources], extra_compile_args={ - 'cxx': [], - 'nvcc': [ - '-D__CUDA_NO_HALF_OPERATORS__', - '-D__CUDA_NO_HALF_CONVERSIONS__', - '-D__CUDA_NO_HALF2_OPERATORS__', - ] - }) - - -setup( - name='deform_conv', ext_modules=[ - make_cuda_ext(name='deform_conv_cuda', - sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']) - ], cmdclass={'build_ext': BuildExtension}, zip_safe=False) diff --git a/codes/models/archs/dcn/src/deform_conv_cuda.cpp b/codes/models/archs/dcn/src/deform_conv_cuda.cpp deleted file mode 100644 index c4563ed8..00000000 --- a/codes/models/archs/dcn/src/deform_conv_cuda.cpp +++ /dev/null @@ -1,695 +0,0 @@ -// modify from -// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c - -#include - -#include -#include - -void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, - const int channels, const int height, const int width, - const int ksize_h, const int ksize_w, const int pad_h, - const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int parallel_imgs, const int deformable_group, - at::Tensor data_col); - -void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, - const int channels, const int height, const int width, - const int ksize_h, const int ksize_w, const int pad_h, - const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int parallel_imgs, const int deformable_group, - at::Tensor grad_im); - -void deformable_col2im_coord( - const at::Tensor data_col, const at::Tensor data_im, - const at::Tensor data_offset, const int channels, const int height, - const int width, const int ksize_h, const int ksize_w, const int pad_h, - const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, const int parallel_imgs, - const int deformable_group, at::Tensor grad_offset); - -void modulated_deformable_im2col_cuda( - const at::Tensor data_im, const at::Tensor data_offset, - const at::Tensor data_mask, const int batch_size, const int channels, - const int height_im, const int width_im, const int height_col, - const int width_col, const int kernel_h, const int kenerl_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, const int deformable_group, - at::Tensor data_col); - -void modulated_deformable_col2im_cuda( - const at::Tensor data_col, const at::Tensor data_offset, - const at::Tensor data_mask, const int batch_size, const int channels, - const int height_im, const int width_im, const int height_col, - const int width_col, const int kernel_h, const int kenerl_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, const int deformable_group, - at::Tensor grad_im); - -void modulated_deformable_col2im_coord_cuda( - const at::Tensor data_col, const at::Tensor data_im, - const at::Tensor data_offset, const at::Tensor data_mask, - const int batch_size, const int channels, const int height_im, - const int width_im, const int height_col, const int width_col, - const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, - const int stride_h, const int stride_w, const int dilation_h, - const int dilation_w, const int deformable_group, at::Tensor grad_offset, - at::Tensor grad_mask); - -void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, - at::Tensor weight, int kH, int kW, int dH, int dW, int padH, - int padW, int dilationH, int dilationW, int group, - int deformable_group) { - AT_CHECK(weight.ndimension() == 4, - "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " - "but got: %s", - weight.ndimension()); - - AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); - - AT_CHECK(kW > 0 && kH > 0, - "kernel size should be greater than zero, but got kH: %d kW: %d", kH, - kW); - - AT_CHECK((weight.size(2) == kH && weight.size(3) == kW), - "kernel size should be consistent with weight, ", - "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, - kW, weight.size(2), weight.size(3)); - - AT_CHECK(dW > 0 && dH > 0, - "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); - - AT_CHECK( - dilationW > 0 && dilationH > 0, - "dilation should be greater than 0, but got dilationH: %d dilationW: %d", - dilationH, dilationW); - - int ndim = input.ndimension(); - int dimf = 0; - int dimh = 1; - int dimw = 2; - - if (ndim == 4) { - dimf++; - dimh++; - dimw++; - } - - AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", - ndim); - - long nInputPlane = weight.size(1) * group; - long inputHeight = input.size(dimh); - long inputWidth = input.size(dimw); - long nOutputPlane = weight.size(0); - long outputHeight = - (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; - long outputWidth = - (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; - - AT_CHECK(nInputPlane % deformable_group == 0, - "input channels must divide deformable group size"); - - if (outputWidth < 1 || outputHeight < 1) - AT_ERROR( - "Given input size: (%ld x %ld x %ld). " - "Calculated output size: (%ld x %ld x %ld). Output size is too small", - nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, - outputWidth); - - AT_CHECK(input.size(1) == nInputPlane, - "invalid number of input planes, expected: %d, but got: %d", - nInputPlane, input.size(1)); - - AT_CHECK((inputHeight >= kH && inputWidth >= kW), - "input image is smaller than kernel"); - - AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), - "invalid spatial size of offset, expected height: %d width: %d, but " - "got height: %d width: %d", - outputHeight, outputWidth, offset.size(2), offset.size(3)); - - AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), - "invalid number of channels of offset"); - - if (gradOutput != NULL) { - AT_CHECK(gradOutput->size(dimf) == nOutputPlane, - "invalid number of gradOutput planes, expected: %d, but got: %d", - nOutputPlane, gradOutput->size(dimf)); - - AT_CHECK((gradOutput->size(dimh) == outputHeight && - gradOutput->size(dimw) == outputWidth), - "invalid size of gradOutput, expected height: %d width: %d , but " - "got height: %d width: %d", - outputHeight, outputWidth, gradOutput->size(dimh), - gradOutput->size(dimw)); - } -} - -int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, - at::Tensor offset, at::Tensor output, - at::Tensor columns, at::Tensor ones, int kW, - int kH, int dW, int dH, int padW, int padH, - int dilationW, int dilationH, int group, - int deformable_group, int im2col_step) { - // todo: resize columns to include im2col: done - // todo: add im2col_step as input - // todo: add new output buffer and transpose it to output (or directly - // transpose output) todo: possibly change data indexing because of - // parallel_imgs - - shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, - dilationH, dilationW, group, deformable_group); - - input = input.contiguous(); - offset = offset.contiguous(); - weight = weight.contiguous(); - - int batch = 1; - if (input.ndimension() == 3) { - // Force batch - batch = 0; - input.unsqueeze_(0); - offset.unsqueeze_(0); - } - - // todo: assert batchsize dividable by im2col_step - - long batchSize = input.size(0); - long nInputPlane = input.size(1); - long inputHeight = input.size(2); - long inputWidth = input.size(3); - - long nOutputPlane = weight.size(0); - - long outputWidth = - (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; - long outputHeight = - (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; - - AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); - - output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, - outputHeight, outputWidth}); - columns = at::zeros( - {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, - input.options()); - - if (ones.ndimension() != 2 || - ones.size(0) * ones.size(1) < outputHeight * outputWidth) { - ones = at::ones({outputHeight, outputWidth}, input.options()); - } - - input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, - inputHeight, inputWidth}); - offset = - offset.view({batchSize / im2col_step, im2col_step, - deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - - at::Tensor output_buffer = - at::zeros({batchSize / im2col_step, nOutputPlane, - im2col_step * outputHeight, outputWidth}, - output.options()); - - output_buffer = output_buffer.view( - {output_buffer.size(0), group, output_buffer.size(1) / group, - output_buffer.size(2), output_buffer.size(3)}); - - for (int elt = 0; elt < batchSize / im2col_step; elt++) { - deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, - inputWidth, kH, kW, padH, padW, dH, dW, dilationH, - dilationW, im2col_step, deformable_group, columns); - - columns = columns.view({group, columns.size(0) / group, columns.size(1)}); - weight = weight.view({group, weight.size(0) / group, weight.size(1), - weight.size(2), weight.size(3)}); - - for (int g = 0; g < group; g++) { - output_buffer[elt][g] = output_buffer[elt][g] - .flatten(1) - .addmm_(weight[g].flatten(1), columns[g]) - .view_as(output_buffer[elt][g]); - } - } - - output_buffer = output_buffer.view( - {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), - output_buffer.size(3), output_buffer.size(4)}); - - output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, - im2col_step, outputHeight, outputWidth}); - output_buffer.transpose_(1, 2); - output.copy_(output_buffer); - output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); - - input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); - offset = offset.view( - {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - - if (batch == 0) { - output = output.view({nOutputPlane, outputHeight, outputWidth}); - input = input.view({nInputPlane, inputHeight, inputWidth}); - offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); - } - - return 1; -} - -int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, - at::Tensor gradOutput, at::Tensor gradInput, - at::Tensor gradOffset, at::Tensor weight, - at::Tensor columns, int kW, int kH, int dW, - int dH, int padW, int padH, int dilationW, - int dilationH, int group, - int deformable_group, int im2col_step) { - shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, - dilationH, dilationW, group, deformable_group); - - input = input.contiguous(); - offset = offset.contiguous(); - gradOutput = gradOutput.contiguous(); - weight = weight.contiguous(); - - int batch = 1; - - if (input.ndimension() == 3) { - // Force batch - batch = 0; - input = input.view({1, input.size(0), input.size(1), input.size(2)}); - offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); - gradOutput = gradOutput.view( - {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); - } - - long batchSize = input.size(0); - long nInputPlane = input.size(1); - long inputHeight = input.size(2); - long inputWidth = input.size(3); - - long nOutputPlane = weight.size(0); - - long outputWidth = - (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; - long outputHeight = - (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; - - AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); - gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); - columns = at::zeros( - {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, - input.options()); - - // change order of grad output - gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, - nOutputPlane, outputHeight, outputWidth}); - gradOutput.transpose_(1, 2); - - gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, - inputHeight, inputWidth}); - input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, - inputHeight, inputWidth}); - gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, - deformable_group * 2 * kH * kW, outputHeight, - outputWidth}); - offset = - offset.view({batchSize / im2col_step, im2col_step, - deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - - for (int elt = 0; elt < batchSize / im2col_step; elt++) { - // divide into groups - columns = columns.view({group, columns.size(0) / group, columns.size(1)}); - weight = weight.view({group, weight.size(0) / group, weight.size(1), - weight.size(2), weight.size(3)}); - gradOutput = gradOutput.view( - {gradOutput.size(0), group, gradOutput.size(1) / group, - gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); - - for (int g = 0; g < group; g++) { - columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), - gradOutput[elt][g].flatten(1), 0.0f, 1.0f); - } - - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - gradOutput = gradOutput.view( - {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), - gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); - - deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, - inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, - dilationH, dilationW, im2col_step, deformable_group, - gradOffset[elt]); - - deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, - inputWidth, kH, kW, padH, padW, dH, dW, dilationH, - dilationW, im2col_step, deformable_group, gradInput[elt]); - } - - gradOutput.transpose_(1, 2); - gradOutput = - gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); - - gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); - input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); - gradOffset = gradOffset.view( - {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - offset = offset.view( - {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - - if (batch == 0) { - gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); - input = input.view({nInputPlane, inputHeight, inputWidth}); - gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); - offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); - gradOffset = - gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); - } - - return 1; -} - -int deform_conv_backward_parameters_cuda( - at::Tensor input, at::Tensor offset, at::Tensor gradOutput, - at::Tensor gradWeight, // at::Tensor gradBias, - at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, - int padW, int padH, int dilationW, int dilationH, int group, - int deformable_group, float scale, int im2col_step) { - // todo: transpose and reshape outGrad - // todo: reshape columns - // todo: add im2col_step as input - - shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, - padW, dilationH, dilationW, group, deformable_group); - - input = input.contiguous(); - offset = offset.contiguous(); - gradOutput = gradOutput.contiguous(); - - int batch = 1; - - if (input.ndimension() == 3) { - // Force batch - batch = 0; - input = input.view( - at::IntList({1, input.size(0), input.size(1), input.size(2)})); - gradOutput = gradOutput.view( - {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); - } - - long batchSize = input.size(0); - long nInputPlane = input.size(1); - long inputHeight = input.size(2); - long inputWidth = input.size(3); - - long nOutputPlane = gradWeight.size(0); - - long outputWidth = - (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; - long outputHeight = - (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; - - AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); - - columns = at::zeros( - {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, - input.options()); - - gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, - nOutputPlane, outputHeight, outputWidth}); - gradOutput.transpose_(1, 2); - - at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); - gradOutputBuffer = - gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, - outputHeight, outputWidth}); - gradOutputBuffer.copy_(gradOutput); - gradOutputBuffer = - gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, - im2col_step * outputHeight, outputWidth}); - - gradOutput.transpose_(1, 2); - gradOutput = - gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); - - input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, - inputHeight, inputWidth}); - offset = - offset.view({batchSize / im2col_step, im2col_step, - deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - - for (int elt = 0; elt < batchSize / im2col_step; elt++) { - deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, - inputWidth, kH, kW, padH, padW, dH, dW, dilationH, - dilationW, im2col_step, deformable_group, columns); - - // divide into group - gradOutputBuffer = gradOutputBuffer.view( - {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, - gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); - columns = columns.view({group, columns.size(0) / group, columns.size(1)}); - gradWeight = - gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), - gradWeight.size(2), gradWeight.size(3)}); - - for (int g = 0; g < group; g++) { - gradWeight[g] = gradWeight[g] - .flatten(1) - .addmm_(gradOutputBuffer[elt][g].flatten(1), - columns[g].transpose(1, 0), 1.0, scale) - .view_as(gradWeight[g]); - } - gradOutputBuffer = gradOutputBuffer.view( - {gradOutputBuffer.size(0), - gradOutputBuffer.size(1) * gradOutputBuffer.size(2), - gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), - gradWeight.size(2), gradWeight.size(3), - gradWeight.size(4)}); - } - - input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); - offset = offset.view( - {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - - if (batch == 0) { - gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); - input = input.view({nInputPlane, inputHeight, inputWidth}); - } - - return 1; -} - -void modulated_deform_conv_cuda_forward( - at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, - at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, - int kernel_h, int kernel_w, const int stride_h, const int stride_w, - const int pad_h, const int pad_w, const int dilation_h, - const int dilation_w, const int group, const int deformable_group, - const bool with_bias) { - AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); - AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); - - const int batch = input.size(0); - const int channels = input.size(1); - const int height = input.size(2); - const int width = input.size(3); - - const int channels_out = weight.size(0); - const int channels_kernel = weight.size(1); - const int kernel_h_ = weight.size(2); - const int kernel_w_ = weight.size(3); - - if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) - AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", - kernel_h_, kernel_w, kernel_h_, kernel_w_); - if (channels != channels_kernel * group) - AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", - channels, channels_kernel * group); - - const int height_out = - (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_out = - (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - - if (ones.ndimension() != 2 || - ones.size(0) * ones.size(1) < height_out * width_out) { - // Resize plane and fill with ones... - ones = at::ones({height_out, width_out}, input.options()); - } - - // resize output - output = output.view({batch, channels_out, height_out, width_out}).zero_(); - // resize temporary columns - columns = - at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, - input.options()); - - output = output.view({output.size(0), group, output.size(1) / group, - output.size(2), output.size(3)}); - - for (int b = 0; b < batch; b++) { - modulated_deformable_im2col_cuda( - input[b], offset[b], mask[b], 1, channels, height, width, height_out, - width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, deformable_group, columns); - - // divide into group - weight = weight.view({group, weight.size(0) / group, weight.size(1), - weight.size(2), weight.size(3)}); - columns = columns.view({group, columns.size(0) / group, columns.size(1)}); - - for (int g = 0; g < group; g++) { - output[b][g] = output[b][g] - .flatten(1) - .addmm_(weight[g].flatten(1), columns[g]) - .view_as(output[b][g]); - } - - weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), - weight.size(3), weight.size(4)}); - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - } - - output = output.view({output.size(0), output.size(1) * output.size(2), - output.size(3), output.size(4)}); - - if (with_bias) { - output += bias.view({1, bias.size(0), 1, 1}); - } -} - -void modulated_deform_conv_cuda_backward( - at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, - at::Tensor offset, at::Tensor mask, at::Tensor columns, - at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, - at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, - int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, - int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, - const bool with_bias) { - AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); - AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); - - const int batch = input.size(0); - const int channels = input.size(1); - const int height = input.size(2); - const int width = input.size(3); - - const int channels_kernel = weight.size(1); - const int kernel_h_ = weight.size(2); - const int kernel_w_ = weight.size(3); - if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) - AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", - kernel_h_, kernel_w, kernel_h_, kernel_w_); - if (channels != channels_kernel * group) - AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", - channels, channels_kernel * group); - - const int height_out = - (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_out = - (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - - if (ones.ndimension() != 2 || - ones.size(0) * ones.size(1) < height_out * width_out) { - // Resize plane and fill with ones... - ones = at::ones({height_out, width_out}, input.options()); - } - - grad_input = grad_input.view({batch, channels, height, width}); - columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, - input.options()); - - grad_output = - grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, - grad_output.size(2), grad_output.size(3)}); - - for (int b = 0; b < batch; b++) { - // divide int group - columns = columns.view({group, columns.size(0) / group, columns.size(1)}); - weight = weight.view({group, weight.size(0) / group, weight.size(1), - weight.size(2), weight.size(3)}); - - for (int g = 0; g < group; g++) { - columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), - grad_output[b][g].flatten(1), 0.0f, 1.0f); - } - - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), - weight.size(3), weight.size(4)}); - - // gradient w.r.t. input coordinate data - modulated_deformable_col2im_coord_cuda( - columns, input[b], offset[b], mask[b], 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, - stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], - grad_mask[b]); - // gradient w.r.t. input data - modulated_deformable_col2im_cuda( - columns, offset[b], mask[b], 1, channels, height, width, height_out, - width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, deformable_group, grad_input[b]); - - // gradient w.r.t. weight, dWeight should accumulate across the batch and - // group - modulated_deformable_im2col_cuda( - input[b], offset[b], mask[b], 1, channels, height, width, height_out, - width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, deformable_group, columns); - - columns = columns.view({group, columns.size(0) / group, columns.size(1)}); - grad_weight = grad_weight.view({group, grad_weight.size(0) / group, - grad_weight.size(1), grad_weight.size(2), - grad_weight.size(3)}); - if (with_bias) - grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); - - for (int g = 0; g < group; g++) { - grad_weight[g] = - grad_weight[g] - .flatten(1) - .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) - .view_as(grad_weight[g]); - if (with_bias) { - grad_bias[g] = - grad_bias[g] - .view({-1, 1}) - .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) - .view(-1); - } - } - - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), - grad_weight.size(2), grad_weight.size(3), - grad_weight.size(4)}); - if (with_bias) - grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); - } - grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), - grad_output.size(2), grad_output.size(3), - grad_output.size(4)}); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda, - "deform forward (CUDA)"); - m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda, - "deform_conv_backward_input (CUDA)"); - m.def("deform_conv_backward_parameters_cuda", - &deform_conv_backward_parameters_cuda, - "deform_conv_backward_parameters (CUDA)"); - m.def("modulated_deform_conv_cuda_forward", - &modulated_deform_conv_cuda_forward, - "modulated deform conv forward (CUDA)"); - m.def("modulated_deform_conv_cuda_backward", - &modulated_deform_conv_cuda_backward, - "modulated deform conv backward (CUDA)"); -} diff --git a/codes/models/archs/dcn/src/deform_conv_cuda_kernel.cu b/codes/models/archs/dcn/src/deform_conv_cuda_kernel.cu deleted file mode 100644 index a2b94286..00000000 --- a/codes/models/archs/dcn/src/deform_conv_cuda_kernel.cu +++ /dev/null @@ -1,866 +0,0 @@ -/*! - ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** - * - * COPYRIGHT - * - * All contributions by the University of California: - * Copyright (c) 2014-2017 The Regents of the University of California (Regents) - * All rights reserved. - * - * All other contributions: - * Copyright (c) 2014-2017, the respective contributors - * All rights reserved. - * - * Caffe uses a shared copyright model: each contributor holds copyright over - * their contributions to Caffe. The project versioning records all such - * contribution and copyright details. If a contributor wants to further mark - * their specific copyright on a particular contribution, they should indicate - * their copyright solely in the commit message of the change when it is - * committed. - * - * LICENSE - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * CONTRIBUTION AGREEMENT - * - * By contributing to the BVLC/caffe repository through pull-request, comment, - * or otherwise, the contributor releases their content to the - * license and copyright terms herein. - * - ***************** END Caffe Copyright Notice and Disclaimer ******************** - * - * Copyright (c) 2018 Microsoft - * Licensed under The MIT License [see LICENSE for details] - * \file modulated_deformable_im2col.cuh - * \brief Function definitions of converting an image to - * column matrix based on kernel, padding, dilation, and offset. - * These functions are mainly used in deformable convolution operators. - * \ref: https://arxiv.org/abs/1703.06211 - * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng - */ - -// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu - -#include -#include -#include -#include -#include - -using namespace at; - -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - -const int CUDA_NUM_THREADS = 1024; -const int kMaxGridNum = 65535; - -inline int GET_BLOCKS(const int N) -{ - return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); -} - -template -__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, - const int height, const int width, scalar_t h, scalar_t w) -{ - - int h_low = floor(h); - int w_low = floor(w); - int h_high = h_low + 1; - int w_high = w_low + 1; - - scalar_t lh = h - h_low; - scalar_t lw = w - w_low; - scalar_t hh = 1 - lh, hw = 1 - lw; - - scalar_t v1 = 0; - if (h_low >= 0 && w_low >= 0) - v1 = bottom_data[h_low * data_width + w_low]; - scalar_t v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - v2 = bottom_data[h_low * data_width + w_high]; - scalar_t v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - v3 = bottom_data[h_high * data_width + w_low]; - scalar_t v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - v4 = bottom_data[h_high * data_width + w_high]; - - scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; -} - -template -__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, - const int h, const int w, const int height, const int width) -{ - - if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) - { - //empty - return 0; - } - - int argmax_h_low = floor(argmax_h); - int argmax_w_low = floor(argmax_w); - int argmax_h_high = argmax_h_low + 1; - int argmax_w_high = argmax_w_low + 1; - - scalar_t weight = 0; - if (h == argmax_h_low && w == argmax_w_low) - weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); - if (h == argmax_h_low && w == argmax_w_high) - weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); - if (h == argmax_h_high && w == argmax_w_low) - weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); - if (h == argmax_h_high && w == argmax_w_high) - weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); - return weight; -} - -template -__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, - const int height, const int width, const scalar_t *im_data, - const int data_width, const int bp_dir) -{ - - if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) - { - //empty - return 0; - } - - int argmax_h_low = floor(argmax_h); - int argmax_w_low = floor(argmax_w); - int argmax_h_high = argmax_h_low + 1; - int argmax_w_high = argmax_w_low + 1; - - scalar_t weight = 0; - - if (bp_dir == 0) - { - if (argmax_h_low >= 0 && argmax_w_low >= 0) - weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; - if (argmax_h_low >= 0 && argmax_w_high <= width - 1) - weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; - if (argmax_h_high <= height - 1 && argmax_w_low >= 0) - weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; - if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) - weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; - } - else if (bp_dir == 1) - { - if (argmax_h_low >= 0 && argmax_w_low >= 0) - weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; - if (argmax_h_low >= 0 && argmax_w_high <= width - 1) - weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; - if (argmax_h_high <= height - 1 && argmax_w_low >= 0) - weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; - if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) - weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; - } - - return weight; -} - -template -__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, const int channel_per_deformable_group, - const int batch_size, const int num_channels, const int deformable_group, - const int height_col, const int width_col, - scalar_t *data_col) -{ - CUDA_KERNEL_LOOP(index, n) - { - // index index of output matrix - const int w_col = index % width_col; - const int h_col = (index / width_col) % height_col; - const int b_col = (index / width_col / height_col) % batch_size; - const int c_im = (index / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; - - // compute deformable group index - const int deformable_group_index = c_im / channel_per_deformable_group; - - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; - //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; - const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; - const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; - - for (int i = 0; i < kernel_h; ++i) - { - for (int j = 0; j < kernel_w; ++j) - { - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; - const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; - const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; - scalar_t val = static_cast(0); - const scalar_t h_im = h_in + i * dilation_h + offset_h; - const scalar_t w_im = w_in + j * dilation_w + offset_w; - if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) - { - //const scalar_t map_h = i * dilation_h + offset_h; - //const scalar_t map_w = j * dilation_w + offset_w; - //const int cur_height = height - h_in; - //const int cur_width = width - w_in; - //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); - val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); - } - *data_col_ptr = val; - data_col_ptr += batch_size * height_col * width_col; - } - } - } -} - -void deformable_im2col( - const at::Tensor data_im, const at::Tensor data_offset, const int channels, - const int height, const int width, const int ksize_h, const int ksize_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, const int parallel_imgs, - const int deformable_group, at::Tensor data_col) -{ - // num_axes should be smaller than block size - // todo: check parallel_imgs is correctly passed in - int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; - int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; - int num_kernels = channels * height_col * width_col * parallel_imgs; - int channel_per_deformable_group = channels / deformable_group; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - data_im.scalar_type(), "deformable_im2col_gpu", ([&] { - const scalar_t *data_im_ = data_im.data(); - const scalar_t *data_offset_ = data_offset.data(); - scalar_t *data_col_ = data_col.data(); - - deformable_im2col_gpu_kernel<<>>( - num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, - channel_per_deformable_group, parallel_imgs, channels, deformable_group, - height_col, width_col, data_col_); - })); - - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); - } -} - -template -__global__ void deformable_col2im_gpu_kernel( - const int n, const scalar_t *data_col, const scalar_t *data_offset, - const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, const int deformable_group, - const int height_col, const int width_col, - scalar_t *grad_im) -{ - CUDA_KERNEL_LOOP(index, n) - { - const int j = (index / width_col / height_col / batch_size) % kernel_w; - const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; - const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; - // compute the start and end of the output - - const int deformable_group_index = c / channel_per_deformable_group; - - int w_out = index % width_col; - int h_out = (index / width_col) % height_col; - int b = (index / width_col / height_col) % batch_size; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - - const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * - 2 * kernel_h * kernel_w * height_col * width_col; - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; - const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; - const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; - const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; - const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; - const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; - - const scalar_t cur_top_grad = data_col[index]; - const int cur_h = (int)cur_inv_h_data; - const int cur_w = (int)cur_inv_w_data; - for (int dy = -2; dy <= 2; dy++) - { - for (int dx = -2; dx <= 2; dx++) - { - if (cur_h + dy >= 0 && cur_h + dy < height && - cur_w + dx >= 0 && cur_w + dx < width && - abs(cur_inv_h_data - (cur_h + dy)) < 1 && - abs(cur_inv_w_data - (cur_w + dx)) < 1) - { - int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; - scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); - atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); - } - } - } - } -} - -void deformable_col2im( - const at::Tensor data_col, const at::Tensor data_offset, const int channels, - const int height, const int width, const int ksize_h, - const int ksize_w, const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int parallel_imgs, const int deformable_group, - at::Tensor grad_im) -{ - - // todo: make sure parallel_imgs is passed in correctly - int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; - int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; - int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; - int channel_per_deformable_group = channels / deformable_group; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - data_col.scalar_type(), "deformable_col2im_gpu", ([&] { - const scalar_t *data_col_ = data_col.data(); - const scalar_t *data_offset_ = data_offset.data(); - scalar_t *grad_im_ = grad_im.data(); - - deformable_col2im_gpu_kernel<<>>( - num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, - ksize_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, channel_per_deformable_group, - parallel_imgs, deformable_group, height_col, width_col, grad_im_); - })); - - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); - } -} - -template -__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, - const scalar_t *data_im, const scalar_t *data_offset, - const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, const int offset_channels, const int deformable_group, - const int height_col, const int width_col, scalar_t *grad_offset) -{ - CUDA_KERNEL_LOOP(index, n) - { - scalar_t val = 0; - int w = index % width_col; - int h = (index / width_col) % height_col; - int c = (index / width_col / height_col) % offset_channels; - int b = (index / width_col / height_col) / offset_channels; - // compute the start and end of the output - - const int deformable_group_index = c / (2 * kernel_h * kernel_w); - const int col_step = kernel_h * kernel_w; - int cnt = 0; - const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * - batch_size * width_col * height_col; - const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * - channel_per_deformable_group / kernel_h / kernel_w * height * width; - const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * - kernel_h * kernel_w * height_col * width_col; - - const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; - - for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) - { - const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; - - int j = (col_pos / width_col / height_col / batch_size) % kernel_w; - int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; - int w_out = col_pos % width_col; - int h_out = (col_pos / width_col) % height_col; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); - const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); - const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; - const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; - scalar_t inv_h = h_in + i * dilation_h + offset_h; - scalar_t inv_w = w_in + j * dilation_w + offset_w; - if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) - { - inv_h = inv_w = -2; - } - const scalar_t weight = get_coordinate_weight( - inv_h, inv_w, - height, width, data_im_ptr + cnt * height * width, width, bp_dir); - val += weight * data_col_ptr[col_pos]; - cnt += 1; - } - - grad_offset[index] = val; - } -} - -void deformable_col2im_coord( - const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, - const int channels, const int height, const int width, const int ksize_h, - const int ksize_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, const int dilation_h, const int dilation_w, - const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) -{ - - int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; - int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; - int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; - int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { - const scalar_t *data_col_ = data_col.data(); - const scalar_t *data_im_ = data_im.data(); - const scalar_t *data_offset_ = data_offset.data(); - scalar_t *grad_offset_ = grad_offset.data(); - - deformable_col2im_coord_gpu_kernel<<>>( - num_kernels, data_col_, data_im_, data_offset_, channels, height, width, - ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, channel_per_deformable_group, - parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, - height_col, width_col, grad_offset_); - })); -} - -template -__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, - const int height, const int width, scalar_t h, scalar_t w) -{ - int h_low = floor(h); - int w_low = floor(w); - int h_high = h_low + 1; - int w_high = w_low + 1; - - scalar_t lh = h - h_low; - scalar_t lw = w - w_low; - scalar_t hh = 1 - lh, hw = 1 - lw; - - scalar_t v1 = 0; - if (h_low >= 0 && w_low >= 0) - v1 = bottom_data[h_low * data_width + w_low]; - scalar_t v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - v2 = bottom_data[h_low * data_width + w_high]; - scalar_t v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - v3 = bottom_data[h_high * data_width + w_low]; - scalar_t v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - v4 = bottom_data[h_high * data_width + w_high]; - - scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; -} - -template -__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, - const int h, const int w, const int height, const int width) -{ - if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) - { - //empty - return 0; - } - - int argmax_h_low = floor(argmax_h); - int argmax_w_low = floor(argmax_w); - int argmax_h_high = argmax_h_low + 1; - int argmax_w_high = argmax_w_low + 1; - - scalar_t weight = 0; - if (h == argmax_h_low && w == argmax_w_low) - weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); - if (h == argmax_h_low && w == argmax_w_high) - weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); - if (h == argmax_h_high && w == argmax_w_low) - weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); - if (h == argmax_h_high && w == argmax_w_high) - weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); - return weight; -} - -template -__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, - const int height, const int width, const scalar_t *im_data, - const int data_width, const int bp_dir) -{ - if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) - { - //empty - return 0; - } - - int argmax_h_low = floor(argmax_h); - int argmax_w_low = floor(argmax_w); - int argmax_h_high = argmax_h_low + 1; - int argmax_w_high = argmax_w_low + 1; - - scalar_t weight = 0; - - if (bp_dir == 0) - { - if (argmax_h_low >= 0 && argmax_w_low >= 0) - weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; - if (argmax_h_low >= 0 && argmax_w_high <= width - 1) - weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; - if (argmax_h_high <= height - 1 && argmax_w_low >= 0) - weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; - if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) - weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; - } - else if (bp_dir == 1) - { - if (argmax_h_low >= 0 && argmax_w_low >= 0) - weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; - if (argmax_h_low >= 0 && argmax_w_high <= width - 1) - weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; - if (argmax_h_high <= height - 1 && argmax_w_low >= 0) - weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; - if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) - weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; - } - - return weight; -} - -template -__global__ void modulated_deformable_im2col_gpu_kernel(const int n, - const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, const int num_channels, const int deformable_group, - const int height_col, const int width_col, - scalar_t *data_col) -{ - CUDA_KERNEL_LOOP(index, n) - { - // index index of output matrix - const int w_col = index % width_col; - const int h_col = (index / width_col) % height_col; - const int b_col = (index / width_col / height_col) % batch_size; - const int c_im = (index / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; - - // compute deformable group index - const int deformable_group_index = c_im / channel_per_deformable_group; - - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - - scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; - //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; - const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; - const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; - - const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; - - for (int i = 0; i < kernel_h; ++i) - { - for (int j = 0; j < kernel_w; ++j) - { - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; - const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; - const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; - const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; - const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; - scalar_t val = static_cast(0); - const scalar_t h_im = h_in + i * dilation_h + offset_h; - const scalar_t w_im = w_in + j * dilation_w + offset_w; - //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { - if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) - { - //const float map_h = i * dilation_h + offset_h; - //const float map_w = j * dilation_w + offset_w; - //const int cur_height = height - h_in; - //const int cur_width = width - w_in; - //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); - val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); - } - *data_col_ptr = val * mask; - data_col_ptr += batch_size * height_col * width_col; - //data_col_ptr += height_col * width_col; - } - } - } -} - -template -__global__ void modulated_deformable_col2im_gpu_kernel(const int n, - const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, - const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, const int deformable_group, - const int height_col, const int width_col, - scalar_t *grad_im) -{ - CUDA_KERNEL_LOOP(index, n) - { - const int j = (index / width_col / height_col / batch_size) % kernel_w; - const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; - const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; - // compute the start and end of the output - - const int deformable_group_index = c / channel_per_deformable_group; - - int w_out = index % width_col; - int h_out = (index / width_col) % height_col; - int b = (index / width_col / height_col) % batch_size; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - - const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; - const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; - const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; - const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; - const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; - const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; - const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; - const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; - const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; - - const scalar_t cur_top_grad = data_col[index] * mask; - const int cur_h = (int)cur_inv_h_data; - const int cur_w = (int)cur_inv_w_data; - for (int dy = -2; dy <= 2; dy++) - { - for (int dx = -2; dx <= 2; dx++) - { - if (cur_h + dy >= 0 && cur_h + dy < height && - cur_w + dx >= 0 && cur_w + dx < width && - abs(cur_inv_h_data - (cur_h + dy)) < 1 && - abs(cur_inv_w_data - (cur_w + dx)) < 1) - { - int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; - scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); - atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); - } - } - } - } -} - -template -__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, - const scalar_t *data_col, const scalar_t *data_im, - const scalar_t *data_offset, const scalar_t *data_mask, - const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, const int offset_channels, const int deformable_group, - const int height_col, const int width_col, - scalar_t *grad_offset, scalar_t *grad_mask) -{ - CUDA_KERNEL_LOOP(index, n) - { - scalar_t val = 0, mval = 0; - int w = index % width_col; - int h = (index / width_col) % height_col; - int c = (index / width_col / height_col) % offset_channels; - int b = (index / width_col / height_col) / offset_channels; - // compute the start and end of the output - - const int deformable_group_index = c / (2 * kernel_h * kernel_w); - const int col_step = kernel_h * kernel_w; - int cnt = 0; - const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; - const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; - const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; - const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; - - const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; - - for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) - { - const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; - - int j = (col_pos / width_col / height_col / batch_size) % kernel_w; - int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; - int w_out = col_pos % width_col; - int h_out = (col_pos / width_col) % height_col; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); - const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); - const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); - const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; - const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; - const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; - scalar_t inv_h = h_in + i * dilation_h + offset_h; - scalar_t inv_w = w_in + j * dilation_w + offset_w; - if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) - { - inv_h = inv_w = -2; - } - else - { - mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); - } - const scalar_t weight = dmcn_get_coordinate_weight( - inv_h, inv_w, - height, width, data_im_ptr + cnt * height * width, width, bp_dir); - val += weight * data_col_ptr[col_pos] * mask; - cnt += 1; - } - // KERNEL_ASSIGN(grad_offset[index], offset_req, val); - grad_offset[index] = val; - if (offset_c % 2 == 0) - // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); - grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; - } -} - -void modulated_deformable_im2col_cuda( - const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, - const int batch_size, const int channels, const int height_im, const int width_im, - const int height_col, const int width_col, const int kernel_h, const int kenerl_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group, at::Tensor data_col) -{ - // num_axes should be smaller than block size - const int channel_per_deformable_group = channels / deformable_group; - const int num_kernels = channels * batch_size * height_col * width_col; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { - const scalar_t *data_im_ = data_im.data(); - const scalar_t *data_offset_ = data_offset.data(); - const scalar_t *data_mask_ = data_mask.data(); - scalar_t *data_col_ = data_col.data(); - - modulated_deformable_im2col_gpu_kernel<<>>( - num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, - batch_size, channels, deformable_group, height_col, width_col, data_col_); - })); - - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); - } -} - -void modulated_deformable_col2im_cuda( - const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, - const int batch_size, const int channels, const int height_im, const int width_im, - const int height_col, const int width_col, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group, at::Tensor grad_im) -{ - - const int channel_per_deformable_group = channels / deformable_group; - const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { - const scalar_t *data_col_ = data_col.data(); - const scalar_t *data_offset_ = data_offset.data(); - const scalar_t *data_mask_ = data_mask.data(); - scalar_t *grad_im_ = grad_im.data(); - - modulated_deformable_col2im_gpu_kernel<<>>( - num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, - kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, - dilation_h, dilation_w, channel_per_deformable_group, - batch_size, deformable_group, height_col, width_col, grad_im_); - })); - - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); - } -} - -void modulated_deformable_col2im_coord_cuda( - const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, - const int batch_size, const int channels, const int height_im, const int width_im, - const int height_col, const int width_col, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group, - at::Tensor grad_offset, at::Tensor grad_mask) -{ - const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; - const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { - const scalar_t *data_col_ = data_col.data(); - const scalar_t *data_im_ = data_im.data(); - const scalar_t *data_offset_ = data_offset.data(); - const scalar_t *data_mask_ = data_mask.data(); - scalar_t *grad_offset_ = grad_offset.data(); - scalar_t *grad_mask_ = grad_mask.data(); - - modulated_deformable_col2im_coord_gpu_kernel<<>>( - num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, - kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, channel_per_deformable_group, - batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, - grad_offset_, grad_mask_); - })); - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); - } -} diff --git a/codes/test_Vid4_REDS4_with_GT.py b/codes/test_Vid4_REDS4_with_GT.py deleted file mode 100644 index 7c90c493..00000000 --- a/codes/test_Vid4_REDS4_with_GT.py +++ /dev/null @@ -1,208 +0,0 @@ -''' -Test Vid4 (SR) and REDS4 (SR-clean, SR-blur, deblur-clean, deblur-compression) datasets -''' - -import os -import os.path as osp -import glob -import logging -import numpy as np -import cv2 -import torch - -import utils.util as util -import data.util as data_util -import models.archs.EDVR_arch as EDVR_arch - - -def main(): - ################# - # configurations - ################# - device = torch.device('cuda') - os.environ['CUDA_VISIBLE_DEVICES'] = '0' - data_mode = 'Vid4' # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp - # Vid4: SR - # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur); - # blur (deblur-clean), blur_comp (deblur-compression). - stage = 1 # 1 or 2, use two stage strategy for REDS dataset. - flip_test = False - ############################################################################ - #### model - if data_mode == 'Vid4': - if stage == 1: - model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth' - else: - raise ValueError('Vid4 does not support stage 2.') - elif data_mode == 'sharp_bicubic': - if stage == 1: - model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth' - else: - model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth' - elif data_mode == 'blur_bicubic': - if stage == 1: - model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth' - else: - model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth' - elif data_mode == 'blur': - if stage == 1: - model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth' - else: - model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth' - elif data_mode == 'blur_comp': - if stage == 1: - model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth' - else: - model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth' - else: - raise NotImplementedError - - if data_mode == 'Vid4': - N_in = 7 # use N_in images to restore one HR image - else: - N_in = 5 - - predeblur, HR_in = False, False - back_RBs = 40 - if data_mode == 'blur_bicubic': - predeblur = True - if data_mode == 'blur' or data_mode == 'blur_comp': - predeblur, HR_in = True, True - if stage == 2: - HR_in = True - back_RBs = 20 - model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in) - - #### dataset - if data_mode == 'Vid4': - test_dataset_folder = '../datasets/Vid4/BIx4' - GT_dataset_folder = '../datasets/Vid4/GT' - else: - if stage == 1: - test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode) - else: - test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4' - print('You should modify the test_dataset_folder path for stage 2') - GT_dataset_folder = '../datasets/REDS4/GT' - - #### evaluation - crop_border = 0 - border_frame = N_in // 2 # border frames when evaluate - # temporal padding mode - if data_mode == 'Vid4' or data_mode == 'sharp_bicubic': - padding = 'new_info' - else: - padding = 'replicate' - save_imgs = True - - save_folder = '../results/{}'.format(data_mode) - util.mkdirs(save_folder) - util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) - logger = logging.getLogger('base') - - #### log info - logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) - logger.info('Padding mode: {}'.format(padding)) - logger.info('Model path: {}'.format(model_path)) - logger.info('Save images: {}'.format(save_imgs)) - logger.info('Flip test: {}'.format(flip_test)) - - #### set up the models - model.load_state_dict(torch.load(model_path), strict=True) - model.eval() - model = model.to(device) - - avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] - subfolder_name_l = [] - - subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*'))) - subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*'))) - # for each subfolder - for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l): - subfolder_name = osp.basename(subfolder) - subfolder_name_l.append(subfolder_name) - save_subfolder = osp.join(save_folder, subfolder_name) - - img_path_l = sorted(glob.glob(osp.join(subfolder, '*'))) - max_idx = len(img_path_l) - if save_imgs: - util.mkdirs(save_subfolder) - - #### read LQ and GT images - imgs_LQ = data_util.read_img_seq(subfolder) - img_GT_l = [] - for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))): - img_GT_l.append(data_util.read_img(None, img_GT_path)) - - avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0 - - # process each image - for img_idx, img_path in enumerate(img_path_l): - img_name = osp.splitext(osp.basename(img_path))[0] - select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding) - imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) - - if flip_test: - output = util.flipx4_forward(model, imgs_in) - else: - output = util.single_forward(model, imgs_in) - output = util.tensor2img(output.squeeze(0)) - - if save_imgs: - cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output) - - # calculate PSNR - output = output / 255. - GT = np.copy(img_GT_l[img_idx]) - # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel - if data_mode == 'Vid4': # bgr2y, [0, 1] - GT = data_util.bgr2ycbcr(GT, only_y=True) - output = data_util.bgr2ycbcr(output, only_y=True) - - output, GT = util.crop_border([output, GT], crop_border) - crt_psnr = util.calculate_psnr(output * 255, GT * 255) - logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr)) - - if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames - avg_psnr_center += crt_psnr - N_center += 1 - else: # border frames - avg_psnr_border += crt_psnr - N_border += 1 - - avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border) - avg_psnr_center = avg_psnr_center / N_center - avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border - avg_psnr_l.append(avg_psnr) - avg_psnr_center_l.append(avg_psnr_center) - avg_psnr_border_l.append(avg_psnr_border) - - logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' - 'Center PSNR: {:.6f} dB for {} frames; ' - 'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr, - (N_center + N_border), - avg_psnr_center, N_center, - avg_psnr_border, N_border)) - - logger.info('################ Tidy Outputs ################') - for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l, - avg_psnr_center_l, avg_psnr_border_l): - logger.info('Folder {} - Average PSNR: {:.6f} dB. ' - 'Center PSNR: {:.6f} dB. ' - 'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, - psnr_border)) - logger.info('################ Final Results ################') - logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) - logger.info('Padding mode: {}'.format(padding)) - logger.info('Model path: {}'.format(model_path)) - logger.info('Save images: {}'.format(save_imgs)) - logger.info('Flip test: {}'.format(flip_test)) - logger.info('Total Average PSNR: {:.6f} dB for {} clips. ' - 'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format( - sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l), - sum(avg_psnr_center_l) / len(avg_psnr_center_l), - sum(avg_psnr_border_l) / len(avg_psnr_border_l))) - - -if __name__ == '__main__': - main() diff --git a/codes/test_Vid4_REDS4_with_GT_DUF.py b/codes/test_Vid4_REDS4_with_GT_DUF.py deleted file mode 100644 index fcec690d..00000000 --- a/codes/test_Vid4_REDS4_with_GT_DUF.py +++ /dev/null @@ -1,264 +0,0 @@ -""" -DUF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets -write to txt log file -""" - -import os -import os.path as osp -import glob -import logging -import numpy as np -import cv2 -import torch - -import utils.util as util -import data.util as data_util -import models.archs.DUF_arch as DUF_arch - - -def main(): - ################# - # configurations - ################# - os.environ['CUDA_VISIBLE_DEVICES'] = '0' - data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS) - - # Possible combinations: (2, 16), (3, 16), (4, 16), (4, 28), (4, 52) - scale = 4 - layer = 52 - assert (scale, layer) in [(2, 16), (3, 16), (4, 16), (4, 28), - (4, 52)], 'Unrecognized (scale, layer) combination' - - # model - N_in = 7 - model_path = '../experiments/pretrained_models/DUF_x{}_{}L_official.pth'.format(scale, layer) - adapt_official = True if 'official' in model_path else False - DUF_downsampling = True # True | False - if layer == 16: - model = DUF_arch.DUF_16L(scale=scale, adapt_official=adapt_official) - elif layer == 28: - model = DUF_arch.DUF_28L(scale=scale, adapt_official=adapt_official) - elif layer == 52: - model = DUF_arch.DUF_52L(scale=scale, adapt_official=adapt_official) - - #### dataset - if data_mode == 'Vid4': - test_dataset_folder = '../datasets/Vid4/BIx4/*' - else: # sharp_bicubic (REDS) - test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode) - - #### evaluation - crop_border = 8 - border_frame = N_in // 2 # border frames when evaluate - # temporal padding mode - padding = 'new_info' # different from the official testing codes, which pads zeros. - save_imgs = True - ############################################################################ - device = torch.device('cuda') - save_folder = '../results/{}'.format(data_mode) - util.mkdirs(save_folder) - util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) - logger = logging.getLogger('base') - - #### log info - logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) - logger.info('Padding mode: {}'.format(padding)) - logger.info('Model path: {}'.format(model_path)) - logger.info('Save images: {}'.format(save_imgs)) - - def read_image(img_path): - '''read one image from img_path - Return img: HWC, BGR, [0,1], numpy - ''' - img_GT = cv2.imread(img_path) - img = img_GT.astype(np.float32) / 255. - return img - - def read_seq_imgs(img_seq_path): - '''read a sequence of images''' - img_path_l = sorted(glob.glob(img_seq_path + '/*')) - img_l = [read_image(v) for v in img_path_l] - # stack to TCHW, RGB, [0,1], torch - imgs = np.stack(img_l, axis=0) - imgs = imgs[:, :, :, [2, 1, 0]] - imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() - return imgs - - def index_generation(crt_i, max_n, N, padding='reflection'): - ''' - padding: replicate | reflection | new_info | circle - ''' - max_n = max_n - 1 - n_pad = N // 2 - return_l = [] - - for i in range(crt_i - n_pad, crt_i + n_pad + 1): - if i < 0: - if padding == 'replicate': - add_idx = 0 - elif padding == 'reflection': - add_idx = -i - elif padding == 'new_info': - add_idx = (crt_i + n_pad) + (-i) - elif padding == 'circle': - add_idx = N + i - else: - raise ValueError('Wrong padding mode') - elif i > max_n: - if padding == 'replicate': - add_idx = max_n - elif padding == 'reflection': - add_idx = max_n * 2 - i - elif padding == 'new_info': - add_idx = (crt_i - n_pad) - (i - max_n) - elif padding == 'circle': - add_idx = i - N - else: - raise ValueError('Wrong padding mode') - else: - add_idx = i - return_l.append(add_idx) - return return_l - - def single_forward(model, imgs_in): - with torch.no_grad(): - model_output = model(imgs_in) - if isinstance(model_output, list) or isinstance(model_output, tuple): - output = model_output[0] - else: - output = model_output - return output - - sub_folder_l = sorted(glob.glob(test_dataset_folder)) - #### set up the models - model.load_state_dict(torch.load(model_path), strict=True) - model.eval() - model = model.to(device) - - avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] - sub_folder_name_l = [] - - # for each sub-folder - for sub_folder in sub_folder_l: - sub_folder_name = sub_folder.split('/')[-1] - sub_folder_name_l.append(sub_folder_name) - save_sub_folder = osp.join(save_folder, sub_folder_name) - - img_path_l = sorted(glob.glob(sub_folder + '/*')) - max_idx = len(img_path_l) - - if save_imgs: - util.mkdirs(save_sub_folder) - - #### read LR images - imgs = read_seq_imgs(sub_folder) - #### read GT images - img_GT_l = [] - if data_mode == 'Vid4': - sub_folder_GT = osp.join(sub_folder.replace('/BIx4/', '/GT/'), '*') - else: - sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*') - for img_GT_path in sorted(glob.glob(sub_folder_GT)): - img_GT_l.append(read_image(img_GT_path)) - - # When using the downsampling in DUF official code, we downsample the HR images - if DUF_downsampling: - sub_folder = sub_folder_GT - img_path_l = sorted(glob.glob(sub_folder)) - max_idx = len(img_path_l) - imgs = read_seq_imgs(sub_folder[:-2]) - - avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0 - cal_n_border, cal_n_center = 0, 0 - - # process each image - for img_idx, img_path in enumerate(img_path_l): - c_idx = int(osp.splitext(osp.basename(img_path))[0]) - select_idx = index_generation(c_idx, max_idx, N_in, padding=padding) - # get input images - imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) - - # Downsample the HR images - H, W = imgs_in.size(3), imgs_in.size(4) - if DUF_downsampling: - imgs_in = util.DUF_downsample(imgs_in, scale=scale) - - output = single_forward(model, imgs_in) - - # Crop to the original shape - if scale == 3: - pad_h = 3 - (H % 3) - pad_w = 3 - (W % 3) - if pad_h > 0: - output = output[:, :, :-pad_h, :] - if pad_w > 0: - output = output[:, :, :, :-pad_w] - output_f = output.data.float().cpu().squeeze(0) - - output = util.tensor2img(output_f) - - # save imgs - if save_imgs: - cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output) - - #### calculate PSNR - output = output / 255. - GT = np.copy(img_GT_l[img_idx]) - # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels - if data_mode == 'Vid4': # bgr2y, [0, 1] - GT = data_util.bgr2ycbcr(GT) - output = data_util.bgr2ycbcr(output) - if crop_border == 0: - cropped_output = output - cropped_GT = GT - else: - cropped_output = output[crop_border:-crop_border, crop_border:-crop_border] - cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border] - crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255) - logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr)) - - if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames - avg_psnr_center += crt_psnr - cal_n_center += 1 - else: # border frames - avg_psnr_border += crt_psnr - cal_n_border += 1 - - avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border) - avg_psnr_center = avg_psnr_center / cal_n_center - if cal_n_border == 0: - avg_psnr_border = 0 - else: - avg_psnr_border = avg_psnr_border / cal_n_border - - logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' - 'Center PSNR: {:.6f} dB for {} frames; ' - 'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr, - (cal_n_center + cal_n_border), - avg_psnr_center, cal_n_center, - avg_psnr_border, cal_n_border)) - - avg_psnr_l.append(avg_psnr) - avg_psnr_center_l.append(avg_psnr_center) - avg_psnr_border_l.append(avg_psnr_border) - - logger.info('################ Tidy Outputs ################') - for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l, - avg_psnr_center_l, avg_psnr_border_l): - logger.info('Folder {} - Average PSNR: {:.6f} dB. ' - 'Center PSNR: {:.6f} dB. ' - 'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border)) - logger.info('################ Final Results ################') - logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) - logger.info('Padding mode: {}'.format(padding)) - logger.info('Model path: {}'.format(model_path)) - logger.info('Save images: {}'.format(save_imgs)) - logger.info('Total Average PSNR: {:.6f} dB for {} clips. ' - 'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format( - sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l), - sum(avg_psnr_center_l) / len(avg_psnr_center_l), - sum(avg_psnr_border_l) / len(avg_psnr_border_l))) - - -if __name__ == '__main__': - main() diff --git a/codes/test_Vid4_REDS4_with_GT_TOF.py b/codes/test_Vid4_REDS4_with_GT_TOF.py deleted file mode 100644 index da8fc300..00000000 --- a/codes/test_Vid4_REDS4_with_GT_TOF.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -TOF testing script, test Vid4 (SR) and REDS4 (SR-clean) datasets -write to txt log file -""" - -import os -import os.path as osp -import glob -import logging -import numpy as np -import cv2 -import torch - -import utils.util as util -import data.util as data_util -import models.archs.TOF_arch as TOF_arch - - -def main(): - ################# - # configurations - ################# - os.environ['CUDA_VISIBLE_DEVICES'] = '0' - data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS) - - # model - N_in = 7 - model_path = '../experiments/pretrained_models/TOF_official.pth' - adapt_official = True if 'official' in model_path else False - model = TOF_arch.TOFlow(adapt_official=adapt_official) - - #### dataset - if data_mode == 'Vid4': - test_dataset_folder = '../datasets/Vid4/BIx4up_direct/*' - else: - test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode) - - #### evaluation - crop_border = 0 - border_frame = N_in // 2 # border frames when evaluate - # temporal padding mode - padding = 'new_info' # different from the official setting - save_imgs = True - ############################################################################ - device = torch.device('cuda') - save_folder = '../results/{}'.format(data_mode) - util.mkdirs(save_folder) - util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) - logger = logging.getLogger('base') - - #### log info - logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) - logger.info('Padding mode: {}'.format(padding)) - logger.info('Model path: {}'.format(model_path)) - logger.info('Save images: {}'.format(save_imgs)) - - def read_image(img_path): - '''read one image from img_path - Return img: HWC, BGR, [0,1], numpy - ''' - img_GT = cv2.imread(img_path) - img = img_GT.astype(np.float32) / 255. - return img - - def read_seq_imgs(img_seq_path): - '''read a sequence of images''' - img_path_l = sorted(glob.glob(img_seq_path + '/*')) - img_l = [read_image(v) for v in img_path_l] - # stack to TCHW, RGB, [0,1], torch - imgs = np.stack(img_l, axis=0) - imgs = imgs[:, :, :, [2, 1, 0]] - imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() - return imgs - - def index_generation(crt_i, max_n, N, padding='reflection'): - ''' - padding: replicate | reflection | new_info | circle - ''' - max_n = max_n - 1 - n_pad = N // 2 - return_l = [] - - for i in range(crt_i - n_pad, crt_i + n_pad + 1): - if i < 0: - if padding == 'replicate': - add_idx = 0 - elif padding == 'reflection': - add_idx = -i - elif padding == 'new_info': - add_idx = (crt_i + n_pad) + (-i) - elif padding == 'circle': - add_idx = N + i - else: - raise ValueError('Wrong padding mode') - elif i > max_n: - if padding == 'replicate': - add_idx = max_n - elif padding == 'reflection': - add_idx = max_n * 2 - i - elif padding == 'new_info': - add_idx = (crt_i - n_pad) - (i - max_n) - elif padding == 'circle': - add_idx = i - N - else: - raise ValueError('Wrong padding mode') - else: - add_idx = i - return_l.append(add_idx) - return return_l - - def single_forward(model, imgs_in): - with torch.no_grad(): - model_output = model(imgs_in) - if isinstance(model_output, list) or isinstance(model_output, tuple): - output = model_output[0] - else: - output = model_output - return output - - sub_folder_l = sorted(glob.glob(test_dataset_folder)) - #### set up the models - model.load_state_dict(torch.load(model_path), strict=True) - model.eval() - model = model.to(device) - - avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] - sub_folder_name_l = [] - - # for each sub-folder - for sub_folder in sub_folder_l: - sub_folder_name = sub_folder.split('/')[-1] - sub_folder_name_l.append(sub_folder_name) - save_sub_folder = osp.join(save_folder, sub_folder_name) - - img_path_l = sorted(glob.glob(sub_folder + '/*')) - max_idx = len(img_path_l) - - if save_imgs: - util.mkdirs(save_sub_folder) - - #### read LR images - imgs = read_seq_imgs(sub_folder) - #### read GT images - img_GT_l = [] - if data_mode == 'Vid4': - sub_folder_GT = osp.join(sub_folder.replace('/BIx4up_direct/', '/GT/'), '*') - else: - sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*') - for img_GT_path in sorted(glob.glob(sub_folder_GT)): - img_GT_l.append(read_image(img_GT_path)) - - avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0 - cal_n_border, cal_n_center = 0, 0 - - # process each image - for img_idx, img_path in enumerate(img_path_l): - c_idx = int(osp.splitext(osp.basename(img_path))[0]) - select_idx = index_generation(c_idx, max_idx, N_in, padding=padding) - # get input images - imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) - output = single_forward(model, imgs_in) - output_f = output.data.float().cpu().squeeze(0) - - output = util.tensor2img(output_f) - - # save imgs - if save_imgs: - cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output) - - #### calculate PSNR - output = output / 255. - GT = np.copy(img_GT_l[img_idx]) - # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels - if data_mode == 'Vid4': # bgr2y, [0, 1] - GT = data_util.bgr2ycbcr(GT) - output = data_util.bgr2ycbcr(output) - if crop_border == 0: - cropped_output = output - cropped_GT = GT - else: - cropped_output = output[crop_border:-crop_border, crop_border:-crop_border] - cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border] - crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255) - logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr)) - - if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames - avg_psnr_center += crt_psnr - cal_n_center += 1 - else: # border frames - avg_psnr_border += crt_psnr - cal_n_border += 1 - - avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border) - avg_psnr_center = avg_psnr_center / cal_n_center - if cal_n_border == 0: - avg_psnr_border = 0 - else: - avg_psnr_border = avg_psnr_border / cal_n_border - - logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' - 'Center PSNR: {:.6f} dB for {} frames; ' - 'Border PSNR: {:.6f} dB for {} frames.'.format(sub_folder_name, avg_psnr, - (cal_n_center + cal_n_border), - avg_psnr_center, cal_n_center, - avg_psnr_border, cal_n_border)) - - avg_psnr_l.append(avg_psnr) - avg_psnr_center_l.append(avg_psnr_center) - avg_psnr_border_l.append(avg_psnr_border) - - logger.info('################ Tidy Outputs ################') - for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l, - avg_psnr_center_l, avg_psnr_border_l): - logger.info('Folder {} - Average PSNR: {:.6f} dB. ' - 'Center PSNR: {:.6f} dB. ' - 'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border)) - logger.info('################ Final Results ################') - logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) - logger.info('Padding mode: {}'.format(padding)) - logger.info('Model path: {}'.format(model_path)) - logger.info('Save images: {}'.format(save_imgs)) - logger.info('Total Average PSNR: {:.6f} dB for {} clips. ' - 'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format( - sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l), - sum(avg_psnr_center_l) / len(avg_psnr_center_l), - sum(avg_psnr_border_l) / len(avg_psnr_border_l))) - - -if __name__ == '__main__': - main()